
Commit b09c6d0

Authored by bfineran, markurtz, and rahul-tuli
[cherry-pick] transformers refactor (#538)
* Refactor of Transformers SparseML CLI and integrations (#536)
  * Refactor of Transformers SparseML CLI and integrations
  * Refactor export.py to use new pathways, fix make quality
  * Update src/sparseml/optim/manager.py (Co-authored-by: Rahul Tuli <[email protected]>)
  * Update src/sparseml/transformers/utils/model.py (Co-authored-by: Rahul Tuli <[email protected]>)
  * fixes from review
  * fixes from review and testing
  * bug fixes and logging
  * bug fixes for export and distillation
  * review fixes, quality fixes, style fixes
  * fix dependency issue
  * fix distillation tests
  * fill in docs and update style
  * fix issue with distillation improperly updating students inputs
  * fix quality
  * Update src/sparseml/pytorch/optim/modifier_distillation.py
  * add in better logging for missing and unexpected keys in model reload for transformers trainer
  * fix logging for transformers export (Co-authored-by: Rahul Tuli <[email protected]>)
* Fix model load bug and add logging to catch potential future issues (#537)
  * Fix model load bug and add logging to catch potential future issues
  * initial migration to generalize module sparsification information
  * propagate ModuleSparsificationInfo
  * report type of input tensors in export.py
  * minor bug fixes
  * ModuleSparsificationInfo docs
  * export onnx bugfix
  * bug fixes
  * make style
  * bug fix for quantization
  * revert to use ScheduledOptimizer due to bug with torch LambdaLR
  * remove language_modeling script
  * add end model sparsification log

Co-authored-by: Benjamin <[email protected]>
Co-authored-by: Mark Kurtz <[email protected]>
Co-authored-by: Rahul Tuli <[email protected]>
1 parent d1b0622 · commit b09c6d0


43 files changed (+2059, -1721 lines)

setup.py

Lines changed: 34 additions & 24 deletions
@@ -66,22 +66,22 @@
 
 _dev_deps = [
     "beautifulsoup4==4.9.3",
-    "black>=20.8b1",
-    "flake8>=3.8.3",
-    "isort>=5.7.0",
+    "black==21.5b2",
+    "flake8==3.9.2",
+    "isort==5.8.0",
     "m2r2~=0.2.7",
     "myst-parser~=0.14.0",
-    "rinohtype>=0.4.2",
-    "sphinx>=3.4.0",
-    "sphinx-copybutton>=0.3.0",
-    "sphinx-markdown-tables>=0.0.15",
-    "sphinx-multiversion==0.2.4",
-    "sphinx-pydantic>=0.1.0",
-    "sphinx-rtd-theme>=0.5.0",
+    "rinohtype~=0.4.2",
+    "sphinx~=3.5.0",
+    "sphinx-copybutton~=0.3.0",
+    "sphinx-markdown-tables~=0.0.15",
+    "sphinx-multiversion~=0.2.4",
+    "sphinx-pydantic~=0.1.0",
+    "sphinx-rtd-theme~=0.5.0",
     "wheel>=0.36.2",
-    "pytest>=6.0.0",
-    "pytest-mock>=3.6.1",
-    "flaky>=3.0.0",
+    "pytest~=6.2.0",
+    "pytest-mock~=3.6.0",
+    "flaky~=3.7.0",
     "sphinx-rtd-theme",
 ]
@@ -112,25 +112,35 @@ def _setup_extras() -> Dict:
     }
 
 
-_transformers_entry_point_template = (
-    "sparseml.transformers.train.{task}=sparseml.transformers.train.{task}:main"
-)
-
-
 def _setup_entry_points() -> Dict:
-    return {
+    entry_points = {
         "console_scripts": [
+            # sparsification
             "sparseml.benchmark=sparseml.benchmark.info:_main",
             "sparseml.framework=sparseml.framework.info:_main",
             "sparseml.sparsification=sparseml.sparsification.info:_main",
-            _transformers_entry_point_template.format(task="question_answering"),
-            _transformers_entry_point_template.format(task="text_classification"),
-            _transformers_entry_point_template.format(task="token_classification"),
-            _transformers_entry_point_template.format(task="language_modeling"),
-            "sparseml.transformers.export_onnx=sparseml.transformers.utils.export:main",
         ]
     }
 
+    # transformers integration
+    for task in [
+        "question_answering",
+        "text_classification",
+        "token_classification",
+    ]:
+        entry_points["console_scripts"].extend(
+            [
+                f"sparseml.transformers.{task}=sparseml.transformers.{task}:main",
+                f"sparseml.transformers.train.{task}=sparseml.transformers.{task}:main",
+            ]
+        )
+
+    entry_points["console_scripts"].append(
+        "sparseml.transformers.export_onnx=sparseml.transformers.export:main"
+    )
+
+    return entry_points
 
 
 def _setup_long_description() -> Tuple[str, str]:
     return open("README.md", "r", encoding="utf-8").read(), "text/markdown"
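
Note on the change above: the module-level string template is gone, and each transformers task now registers two console scripts, a new task-level alias (sparseml.transformers.{task}) and the existing train-prefixed name, both resolving to the same main. The language_modeling script is dropped, and export_onnx now points at sparseml.transformers.export rather than sparseml.transformers.utils.export. A standalone sketch (illustrative only, mirroring the loop in the diff) of the scripts it expands to:

# Illustrative expansion of the entry-point loop above; run with plain
# Python to print the console scripts that setup.py would register.
entry_points = {"console_scripts": []}

for task in [
    "question_answering",
    "text_classification",
    "token_classification",
]:
    entry_points["console_scripts"].extend(
        [
            f"sparseml.transformers.{task}=sparseml.transformers.{task}:main",
            f"sparseml.transformers.train.{task}=sparseml.transformers.{task}:main",
        ]
    )

entry_points["console_scripts"].append(
    "sparseml.transformers.export_onnx=sparseml.transformers.export:main"
)

for script in entry_points["console_scripts"]:
    print(script)  # e.g. sparseml.transformers.question_answering=...:main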

src/sparseml/keras/optim/manager.py

Lines changed: 4 additions & 3 deletions
@@ -18,14 +18,14 @@
 Also handles loading modifiers from yaml files
 """
 
-from typing import List, Union
+from typing import Any, Dict, List, Optional, Union
 
 from tensorflow import Tensor
 
 from sparseml.keras.optim.modifier import Modifier, ScheduledModifier
 from sparseml.keras.utils.compat import keras
 from sparseml.keras.utils.logger import KerasLogger
-from sparseml.optim import BaseManager, load_recipe_yaml_str
+from sparseml.optim import BaseManager, load_recipe_yaml_str, parse_recipe_variables
 from sparsezoo.objects import Recipe
 
 
@@ -41,7 +41,7 @@ class ScheduledModifierManager(BaseManager, Modifier):
     def from_yaml(
         file_path: Union[str, Recipe],
         add_modifiers: List[Modifier] = None,
-        **recipe_variables,
+        recipe_variables: Optional[Union[Dict[str, Any], str]] = None,
     ):
         """
         Convenience function used to create the manager of multiple modifiers from a
@@ -59,6 +59,7 @@ def from_yaml(
             with (i.e. num_epochs, init_lr)
         :return: ScheduledModifierManager() created from the recipe file
         """
+        recipe_variables = parse_recipe_variables(recipe_variables)
        yaml_str = load_recipe_yaml_str(file_path, **recipe_variables)
        modifiers = Modifier.load_list(yaml_str)
        if add_modifiers:
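
With this change, recipe variables move from **kwargs to a single explicit recipe_variables argument that accepts either a dict or a string. A usage sketch (assuming ScheduledModifierManager is exported from sparseml.keras.optim; the recipe path and variable values are placeholders):

from sparseml.keras.optim import ScheduledModifierManager

# Both calls are equivalent after parse_recipe_variables runs;
# "recipe.yaml" and the values below are hypothetical.
manager = ScheduledModifierManager.from_yaml(
    "recipe.yaml", recipe_variables={"num_epochs": 30, "init_lr": 0.001}
)
manager = ScheduledModifierManager.from_yaml(
    "recipe.yaml", recipe_variables="num_epochs=30,init_lr=0.001"
)

Callers that previously spread keyword arguments, e.g. from_yaml(path, num_epochs=30), must now pass them through the recipe_variables dict or string.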

src/sparseml/optim/helpers.py

Lines changed: 59 additions & 1 deletion
@@ -16,8 +16,10 @@
 Helper functions for base Modifier and Manger utilities
 """
 
+import json
 import re
-from typing import Any, Dict, Tuple, Union
+from contextlib import suppress
+from typing import Any, Dict, Optional, Tuple, Union
 
 import yaml
 
@@ -32,6 +34,7 @@
     "rewrite_recipe_yaml_string_with_classes",
     "update_recipe_variables",
     "evaluate_recipe_yaml_str_equations",
+    "parse_recipe_variables",
 ]
 
 
@@ -137,6 +140,61 @@ def rewrite_recipe_yaml_string_with_classes(recipe_contianer: Any) -> str:
     return pattern.sub(r"!\g<class_name>", updated_yaml_str)
 
 
+def parse_recipe_variables(
+    recipe_variables: Optional[Union[Dict[str, Any], str]] = None
+) -> Dict[str, Any]:
+    """
+    Parse input recipe_variables into a dictionary that can be used to overload
+    variables at the root of a recipe.
+    Supports dictionaries as well as parsing a string in either json or
+    csv key=value format
+
+    :param recipe_variables: the recipe_variables string or dictionary to parse
+        for variables used with overloading recipes
+    :return: the parsed recipe variables
+    """
+    if not recipe_variables:
+        return {}
+
+    if isinstance(recipe_variables, Dict):
+        return recipe_variables
+
+    if not isinstance(recipe_variables, str):
+        raise ValueError(
+            f"recipe_args must be a string for parsing, given {recipe_variables}"
+        )
+
+    # assume json first, try and parse
+    with suppress(Exception):
+        recipe_variables = json.loads(recipe_variables)
+        return recipe_variables
+
+    # assume csv, and standardize to format key=val
+    orig_recipe_variables = recipe_variables
+    recipe_vars_str = recipe_variables.replace(":", "=")
+    recipe_variables = {}
+    for arg_val in recipe_vars_str.split(","):
+        vals = arg_val.split("=")
+        if len(vals) != 2:
+            raise ValueError(
+                "Improper key=val given in csv for recipe variables with value "
+                f"{arg_val} in {orig_recipe_variables}"
+            )
+        key = vals[0].strip()
+        if any(char in key for char in ["{", "!", "=", "}"]):
+            raise ValueError(
+                "Improper key given in csv for recipe variables with value "
+                f"{key} in {orig_recipe_variables}"
+            )
+        val = vals[1].strip()
+        with suppress(Exception):
+            # check if val should be a number, otherwise fall back on string
+            val = float(val)
+        recipe_variables[key] = val
+
+    return recipe_variables
+
+
 def update_recipe_variables(recipe_yaml_str: str, variables: Dict[str, Any]) -> str:
     """
     :param recipe_yaml_str: YAML string of a SparseML recipe
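
A usage sketch of the new helper (assuming it is re-exported from sparseml.optim, as the keras import earlier in this diff suggests), covering the accepted input forms:

from sparseml.optim import parse_recipe_variables

# dicts pass through unchanged
print(parse_recipe_variables({"num_epochs": 30}))
# {'num_epochs': 30}

# JSON strings are tried first
print(parse_recipe_variables('{"num_epochs": 30, "init_lr": 0.001}'))
# {'num_epochs': 30, 'init_lr': 0.001}

# otherwise csv key=value; ":" also works as a separator, and numeric
# values are coerced to float
print(parse_recipe_variables("num_epochs=30, init_lr:0.001"))
# {'num_epochs': 30.0, 'init_lr': 0.001}

# None or an empty string yields {}
print(parse_recipe_variables(None))
# {}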

src/sparseml/optim/manager.py

Lines changed: 72 additions & 17 deletions
@@ -22,12 +22,8 @@
 from functools import cmp_to_key
 from typing import List
 
-from sparseml.optim.modifier import (
-    BaseModifier,
-    BaseObject,
-    BaseScheduled,
-    ModifierProp,
-)
+from sparseml.optim.modifier import BaseModifier, BaseObject, ModifierProp
+from sparseml.sparsification.types import SparsificationTypes
 from sparseml.utils import clean_path, create_parent_dirs
 
 
@@ -42,7 +38,7 @@ class BaseManager(BaseObject):
     :param modifiers: the modifiers to wrap
     """
 
-    def __init__(self, modifiers: List[BaseScheduled], **kwargs):
+    def __init__(self, modifiers: List[BaseModifier], **kwargs):
         super().__init__(**kwargs)
         # sort modifiers by when they start and end so that later modifiers
         # can overwrite in a deterministic order such as when initializing
@@ -57,44 +53,88 @@ def __del__(self):
     def __str__(self) -> str:
         return "\n".join(self.to_string_lines())
 
+    def __eq__(self, compare: object) -> bool:
+        return str(self) == str(compare)
+
     @ModifierProp(serializable=False)
-    def modifiers(self) -> List[BaseScheduled]:
+    def modifiers(self) -> List[BaseModifier]:
         """
         :return: list of all SparseML modifiers in the managed recipe
         """
         return self._modifiers
 
     @ModifierProp(serializable=False)
-    def epoch_modifiers(self) -> List[BaseScheduled]:
+    def epoch_modifiers(self) -> List[BaseModifier]:
         """
         :return: list of all SparseML modifiers in the managed recipe that modify the
             epoch range
         """
-        return [mod for mod in self._modifiers if "Epoch" in str(type(mod))]
+        return [
+            mod
+            for mod in self._modifiers
+            if SparsificationTypes.epoch in mod.sparsification_types
+        ]
 
     @ModifierProp(serializable=False)
-    def learning_rate_modifiers(self) -> List[BaseScheduled]:
+    def learning_rate_modifiers(self) -> List[BaseModifier]:
         """
         :return: list of all SparseML modifiers in the managed recipe that modify the
             LearningRate schedule
         """
-        return [mod for mod in self._modifiers if "LearningRate" in str(type(mod))]
+        return [
+            mod
+            for mod in self._modifiers
+            if SparsificationTypes.learning_rate in mod.sparsification_types
+        ]
 
     @ModifierProp(serializable=False)
-    def pruning_modifiers(self) -> List[BaseScheduled]:
+    def pruning_modifiers(self) -> List[BaseModifier]:
         """
         :return: list of all SparseML modifiers in the managed recipe that manage
             model sparsity
         """
-        return [mod for mod in self._modifiers if "Pruning" in str(type(mod))]
+        return [
+            mod
+            for mod in self._modifiers
+            if SparsificationTypes.pruning in mod.sparsification_types
+        ]
 
     @ModifierProp(serializable=False)
-    def quantization_modifiers(self) -> List[BaseScheduled]:
+    def quantization_modifiers(self) -> List[BaseModifier]:
         """
         :return: list of all SparseML modifiers in the managed recipe that manage
             model quantization
         """
-        return [mod for mod in self._modifiers if "Quantization" in str(type(mod))]
+        return [
+            mod
+            for mod in self._modifiers
+            if SparsificationTypes.quantization in mod.sparsification_types
+        ]
+
+    @ModifierProp(serializable=False)
+    def distillation_modifiers(self) -> List[BaseModifier]:
+        """
+        :return: list of all SparseML modifiers in the managed recipe that manage
+            Distillation
+        """
+        return [
+            mod
+            for mod in self._modifiers
+            if SparsificationTypes.distillation in mod.sparsification_types
+        ]
+
+    @ModifierProp(serializable=False)
+    def structured_modifiers(self) -> List[BaseModifier]:
+        """
+        :return: list of all SparseML modifiers in the managed recipe that manage
+            structure changes to a model such as layer pruning, fitler pruning,
+            and quantization
+        """
+        return [
+            mod
+            for mod in self._modifiers
+            if SparsificationTypes.structured in mod.sparsification_types
+        ]
 
     @ModifierProp(serializable=False)
     def min_epochs(self) -> int:
@@ -154,7 +194,7 @@ def to_string_lines(self) -> List[str]:
 
         return yaml_str_lines
 
-    def modifiers_to_string_lines(self, modifiers: List[BaseScheduled]) -> List[str]:
+    def modifiers_to_string_lines(self, modifiers: List[BaseModifier]) -> List[str]:
         """
         :param modifiers: the modifiers to convert into string / yaml representation
             for within the manage
@@ -176,3 +216,18 @@ def modifiers_to_string_lines(self, modifiers: List[BaseScheduled]) -> List[str]
             yaml_str_lines.append("")
 
         return yaml_str_lines
+
+    def qat_active(self, epoch: float) -> bool:
+        """
+        :param epoch: the epoch to check if quantization aware training will be
+            active during
+        :return: True if quantization aware training will be active at the start
+            of or within the given epoch, False otherwise
+        """
+        quant_modifiers = self.quantization_modifiers
+
+        return (
+            min(mod.start_epoch for mod in quant_modifiers) < epoch + 1
+            if quant_modifiers
+            else False
+        )
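
Two behavioral notes on the diff above: modifier classification now keys off the SparsificationTypes each modifier declares rather than substring matches on its class name, and the new qat_active treats quantization aware training as active once any quantization modifier's start_epoch falls before the end of the queried epoch (start_epoch < epoch + 1). A self-contained sketch of that arithmetic, using a stub object in place of a real quantization modifier:

from types import SimpleNamespace

# Hypothetical stand-in for a quantization modifier; only start_epoch
# matters for this check.
quant_modifiers = [SimpleNamespace(start_epoch=5.0)]


def qat_active(epoch: float) -> bool:
    # mirrors BaseManager.qat_active from the diff above
    return (
        min(mod.start_epoch for mod in quant_modifiers) < epoch + 1
        if quant_modifiers
        else False
    )


print(qat_active(3.0))  # False: QAT starts at 5.0, after epoch 3 ends
print(qat_active(4.5))  # True: 5.0 < 5.5, so QAT begins within this epoch
print(qat_active(6.0))  # True: QAT already started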

src/sparseml/optim/modifier.py

Lines changed: 8 additions & 0 deletions
@@ -25,6 +25,7 @@
 import yaml
 
 from sparseml.optim.helpers import evaluate_recipe_yaml_str_equations
+from sparseml.sparsification.types import SparsificationTypes
 from sparseml.utils import ALL_TOKEN, validate_str_iterable
 
 
@@ -466,6 +467,13 @@ def __repr__(self):
             self.props(only_serializable=False, format_repr=True),
         )
 
+    @ModifierProp(serializable=False)
+    def sparsification_types(self) -> List[SparsificationTypes]:
+        """
+        :return: the sparsification types this modifier instance will apply
+        """
+        return []
+
     @ModifierProp(serializable=True)
     def log_types(self) -> Union[None, str, List[str]]:
         """
