Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 800d08b

Browse files
author
Sara Adkins
committed
Merge branch 'main' into prod_smooth_quant
2 parents 788c16b + b622bba commit 800d08b

File tree

34 files changed

+2192
-114
lines changed

34 files changed

+2192
-114
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -795,3 +795,6 @@ fabric.properties
795795
*.resources
796796
test-results/
797797
integrations/pytorch/pytorch_vision*
798+
799+
# local log files
800+
nm_temp_test_logs/*

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ More information on installation such as optional dependencies and requirements
128128

129129
### Recipes
130130

131-
To enable flexibility, ease of use, and repeatability, SparseML uses a declarative interface called `recipes` for specifying the sparsity-related algorithms and hyperparamters that should be applied by SparseML.
131+
To enable flexibility, ease of use, and repeatability, SparseML uses a declarative interface called `recipes` for specifying the sparsity-related algorithms and hyperparameters that should be applied by SparseML.
132132

133133
`Recipes` are YAML-files formatted as a list of `modifiers`, which encode the instructions for SparseML. Example `modifiers` can be anything from setting the learning rate to encoding the hyperparameters of the gradual magnitude pruning algorithm. The SparseML system parses the `recipes` into a native format for each framework and applies the modifications to the model and training pipeline.
134134

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@
7979
_transformers_deps = _pytorch_deps + [
8080
f"{'nm-transformers' if is_release else 'nm-transformers-nightly'}"
8181
f"~={version_nm_deps}",
82-
"datasets<=2.11",
82+
"datasets<=2.14.6",
8383
"scikit-learn",
8484
"seqeval",
8585
"einops",

src/sparseml/core/lifecycle/event.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -198,7 +198,7 @@ def optim_pre_step_events(self) -> List[Event]:
198198
and self.type_ is not None
199199
and self.type_ != EventType.OPTIM_POST_STEP
200200
):
201-
raise ValueError("optim pre step must be called after optim post step")
201+
raise ValueError("optim pre step must be called before optim post step")
202202

203203
if (
204204
self.type_first == EventType.LOSS_CALCULATED

src/sparseml/core/lifecycle/session.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,9 @@ def reset(self):
5050
except Exception:
5151
pass
5252

53+
if self.state and self.state.data:
54+
# reset data if it exists
55+
self.state.data.reset()
5356
self.state = None
5457
self.recipe_container = RecipeContainer()
5558
self.modifiers = []

src/sparseml/core/model/base.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -126,3 +126,11 @@ def get_matching_layer(
126126
:param model: model to search for targets
127127
"""
128128
raise NotImplementedError()
129+
130+
def qat_active(self) -> bool:
131+
"""
132+
Checks if quantization aware training is set up in the model
133+
134+
:return: True if QAT is active in any layer, False otherwise
135+
"""
136+
raise NotImplementedError()

src/sparseml/core/model/pytorch.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
get_matching_layer,
2626
get_param,
2727
get_params,
28+
qat_active,
2829
set_layer,
2930
set_param,
3031
)
@@ -105,3 +106,11 @@ def get_matching_layer(
105106
:param model: model to search for targets
106107
"""
107108
return get_matching_layer(target, name_to_match, model)
109+
110+
def qat_active(self) -> bool:
111+
"""
112+
Checks if quantization aware training is set up in the model
113+
114+
:return: True if QAT is active in any layer, False otherwise
115+
"""
116+
return qat_active(self.model)

src/sparseml/core/recipe/modifier.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -113,4 +113,4 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:
113113
"""
114114
:return: the dictionary representation of the modifier
115115
"""
116-
return {self.type: self.args}
116+
return {self.type: self.args, "group": f"{self.group}_modifiers"}

src/sparseml/core/recipe/recipe.py

Lines changed: 59 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -399,18 +399,20 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:
399399
... targets: ['re:.*weight']
400400
... '''
401401
>>> recipe = Recipe.create_instance(recipe_str)
402-
>>> recipe.dict()
403-
Traceback (most recent call last):
404-
...
405-
KeyError: 'group'
402+
>>> recipe_dict = recipe.dict()
403+
>>> stage = recipe_dict["stages"]["test"]
404+
>>> pruning_mods = stage[0]['modifiers']['pruning']
405+
>>> modifier_args = pruning_mods[0]['ConstantPruningModifier']
406+
>>> modifier_args == {'start': 0.0, 'end': 2.0, 'targets': ['re:.*weight']}
407+
True
406408
407409
:return: A dictionary representation of the recipe
408410
"""
409411
dict_ = super().dict(*args, **kwargs)
410412
stages = {}
411413

412414
for stage in dict_["stages"]:
413-
name = stage["group"]
415+
name = f"{stage['group']}_stage"
414416
del stage["group"]
415417

416418
if name not in stages:
@@ -422,6 +424,58 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:
422424

423425
return dict_
424426

427+
def yaml(self, file_path: Optional[str] = None) -> str:
428+
"""
429+
Return a yaml string representation of the recipe.
430+
431+
:param file_path: optional file path to save yaml to
432+
:return: The yaml string representation of the recipe
433+
"""
434+
file_stream = None if file_path is None else open(file_path, "w")
435+
yaml_dict = self._get_yaml_dict()
436+
437+
ret = yaml.dump(
438+
yaml_dict, stream=file_stream, allow_unicode=True, sort_keys=False
439+
)
440+
441+
if file_stream is not None:
442+
file_stream.close()
443+
444+
return ret
445+
446+
def _get_yaml_dict(self) -> Dict[str, Any]:
447+
"""
448+
Get a dictionary representation of the recipe for yaml serialization
449+
The returned dict will only contain information necessary for yaml
450+
serialization (ignores metadata, version, etc), and must not be used
451+
in place of the dict method
452+
453+
:return: A dictionary representation of the recipe for yaml serialization
454+
"""
455+
456+
def _modifier_group_to_dict(modifier_group: List[Dict[str, Any]]):
457+
# convert a list of modifiers to a dict of modifiers
458+
return {
459+
key: value
460+
for modifier in modifier_group
461+
for key, value in modifier.items()
462+
}
463+
464+
def _stage_to_dict(stage: List[Dict[str, Any]]):
465+
# convert a list of stages to a dict of modifiers
466+
return {
467+
modifier_group_name: _modifier_group_to_dict(modifier_group)
468+
for stage_modifiers in stage
469+
for modifier_group_name, modifier_group in stage_modifiers[
470+
"modifiers"
471+
].items()
472+
}
473+
474+
return {
475+
stage_name: _stage_to_dict(stage=stage)
476+
for stage_name, stage in self.dict()["stages"].items()
477+
}
478+
425479

426480
@dataclass
427481
class RecipeTuple:

src/sparseml/core/state.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,14 @@ class Data:
4545
test: Optional[ModifiableData] = None
4646
calib: Optional[ModifiableData] = None
4747

48+
def reset(self):
49+
"""
50+
Reset self to initial state
51+
"""
52+
attribs = Data().__dict__
53+
for attrib_name, attrib_value in attribs.items():
54+
setattr(self, attrib_name, attrib_value)
55+
4856

4957
@dataclass
5058
class Hardware:

0 commit comments

Comments (0)