
Commit 0eaf565

Authored by Sara Adkins

Support for Stacking Recipes (#1897)

* initial recipe re-loading
* loading for input recipe
* persist structure across recipe loads
* clean up fn names
* clean up duplicated code
* delete extra file
* unit tests
* fix failing test
* quantization edge cases
* quant tests
* fixes for stage name clashes
* clean up documentation
* add apply flag during finalization as well
* clarity comments
* fix unit test

1 parent f088321 · commit 0eaf565

File tree: 8 files changed (+339, -12 lines)

src/sparseml/core/lifecycle/session.py

Lines changed: 9 additions & 0 deletions
@@ -78,7 +78,11 @@ def pre_initialize_structure(
             if data is not None:
                 mod_data.append(data)
 
+        # mark which modifiers have already had their structures initialized
+        # so when we consolidate the next recipe this info isn't lost
         self.initialized_structure = True
+        applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
+        self.recipe_container.update_applied_stages(applied_stage_names)
 
         return mod_data
 
@@ -113,6 +117,8 @@ def finalize(self, **kwargs) -> List[Any]:
             mod_data.append(data)
 
         self.finalized = True
+        applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
+        self.recipe_container.update_applied_stages(applied_stage_names)
 
         return mod_data
 
@@ -169,6 +175,9 @@ def _check_compile_recipe(self):
            self.modifiers = self.recipe_container.compiled_recipe.create_modifier(
                self.state.framework
            )
+            for mod in self.modifiers:
+                if mod.unique_id in self.recipe_container.applied_stages:
+                    mod.applied = True
 
    def _check_setup_event_lifecycle(self, event_type: EventType):
        if self.event_lifecycle is not None:
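The inline comments above capture the intent of this change: whenever structure is pre-initialized or the session is finalized, the unique IDs of applied stages are recorded in the recipe container, and when the next (stacked) recipe is compiled into fresh modifiers those IDs restore the applied flags. A minimal sketch of that round trip, using toy stand-ins rather than the real StageModifiers and RecipeContainer classes:

# Toy sketch of the applied-stage bookkeeping; Stage and Container are
# illustrative stand-ins, not the actual SparseML classes.
from dataclasses import dataclass, field
from typing import List

@dataclass
class Stage:
    unique_id: str
    applied: bool = False

@dataclass
class Container:
    applied_stages: List[str] = field(default_factory=list)

    def update_applied_stages(self, new_stages: List[str]) -> None:
        for name in new_stages:
            if name not in self.applied_stages:
                self.applied_stages.append(name)

container = Container()

# first recipe load: its stage gets applied and is recorded by unique ID
first = [Stage("sparse_stage_0", applied=True)]
container.update_applied_stages([s.unique_id for s in first if s.applied])

# a second recipe is stacked on top; compilation rebuilds all stage objects,
# so applied flags are restored from the container (as in _check_compile_recipe)
recompiled = [Stage("sparse_stage_0"), Stage("quant_stage_1")]
for stage in recompiled:
    if stage.unique_id in container.applied_stages:
        stage.applied = True

assert [s.unique_id for s in recompiled if s.applied] == ["sparse_stage_0"]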

src/sparseml/core/modifier/stage.py

Lines changed: 28 additions & 2 deletions
@@ -36,11 +36,14 @@ class StageModifiers(ModifierInterface, BaseModel):
     :param modifiers: The modifiers to apply as a stage
     :param index: The index of the stage, if applicable
     :param group: The group name of the stage, if applicable
+    :param applied: Flag for indicating if this stage has already been
+        applied to the model, through structure initialization or finalization
     """
 
     modifiers: List["Modifier"] = Field(default_factory=list)
     index: Optional[int] = None
     group: Optional[str] = None
+    applied: bool = False
 
     @property
     def initialized_structure(self) -> bool:
@@ -66,6 +69,13 @@ def finalized(self) -> bool:
         """
         return all(mod.finalized for mod in self.modifiers)
 
+    @property
+    def unique_id(self) -> str:
+        """
+        :return: ID for stage containing the name and index
+        """
+        return self.group + "_" + str(self.index)
+
     def check_initialized(self):
         """
         Check if all of the stage modifiers have been initialized, and log a warning
@@ -103,7 +113,7 @@ def calculate_end(self) -> float:
 
     def pre_initialize_structure(self, state: "State", **kwargs):
         """
-        Pre initialize the structure for all stage modifiers
+        Pre initialize the structure for all stage modifiers and mark the stage as applied
 
         :param state: The current state of the training
         :param kwargs: Additional kwargs to pass to the modifier(s)
@@ -112,6 +122,8 @@ def pre_initialize_structure(self, state: "State", **kwargs):
         for modifier in self.modifiers:
             modifier.pre_initialize_structure(state, **kwargs)
 
+        self.applied = True
+
     def initialize(self, state: "State", **kwargs):
         """
         Initialize all the stage modifiers
@@ -120,20 +132,30 @@ def initialize(self, state: "State", **kwargs):
         :param kwargs: Additional kwargs to pass to the modifier(s)
             initialize method
         """
+
+        if self.applied:
+            return
+
         for modifier in self.modifiers:
             modifier.initialize(state, **kwargs)
 
     def finalize(self, state: "State", **kwargs):
         """
-        Finalize all the stage modifiers
+        Finalize all the stage modifiers and mark the stage as applied
 
         :param state: The state of current session
         :param kwargs: Additional kwargs to pass to the modifier(s)
             finalize method
         """
+
+        if self.applied:
+            return
+
         for modifier in self.modifiers:
             modifier.finalize(state, **kwargs)
 
+        self.applied = True
+
     def update_event(self, state: "State", event: "Event", **kwargs):
         """
         Propagate the event to all the stage modifiers
@@ -143,5 +165,9 @@ def update_event(self, state: "State", event: "Event", **kwargs):
         :param kwargs: Additional kwargs to pass to the modifier(s)
             update_event method
         """
+
+        if self.applied:
+            return
+
         for modifier in self.modifiers:
             modifier.update_event(state, event, **kwargs)
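The new applied flag turns initialize, finalize, and update_event into no-ops for a stage an earlier recipe already ran, and unique_id is what that flag is keyed on across recipe reloads. A quick illustration (values are hypothetical; state is assumed to be an existing State instance):

# Hypothetical usage of the new StageModifiers fields
stage = StageModifiers(modifiers=[], index=0, group="test_stage")
print(stage.unique_id)   # "test_stage_0"  (group + "_" + index)

stage.applied = True
stage.initialize(state)  # returns immediately, modifiers are not re-initialized
stage.finalize(state)    # likewise skipped for an already-applied stage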

src/sparseml/core/recipe/container.py

Lines changed: 13 additions & 0 deletions
@@ -44,10 +44,12 @@ class RecipeContainer:
 
     :param compiled_recipe: the compiled recipe from the recipes list
     :param recipes: the list of RecipeTuple instances to be compiled
+    :param applied_stages: list of recipe stages that have already been applied
     """
 
     compiled_recipe: Optional[Recipe] = None
     recipes: List[RecipeTuple] = field(default_factory=list)
+    applied_stages: List[str] = field(default_factory=list)
 
     def update(
         self,
@@ -118,6 +120,17 @@ def update(
 
         return kwargs
 
+    def update_applied_stages(self, new_stages: List[str]):
+        """
+        Updates the applied_stages list with new stages, indicating their structure
+        has already been applied
+
+        :param new_stages: new stage names to add
+        """
+        for stage in new_stages:
+            if stage not in self.applied_stages:
+                self.applied_stages.append(stage)
+
     def check_compile_recipe(self) -> bool:
         """
         Check if the recipes need to be compiled into a single recipe and
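update_applied_stages is deliberately idempotent, so the same stage reported from both pre_initialize_structure and finalize is only recorded once. For example (stage names are illustrative):

container = RecipeContainer()
container.update_applied_stages(["test_stage_0", "quant_stage_1"])
container.update_applied_stages(["test_stage_0"])  # duplicate name is ignored
print(container.applied_stages)  # ['test_stage_0', 'quant_stage_1']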

src/sparseml/core/recipe/recipe.py

Lines changed: 12 additions & 10 deletions
@@ -494,20 +494,22 @@ def _modifier_group_to_dict(modifier_group: List[Dict[str, Any]]):
                 for key, value in modifier.items()
             }
 
-        def _stage_to_dict(stage: List[Dict[str, Any]]):
-            # convert a list of stages to a dict of modifiers
+        def _stage_to_dict(stage: Dict[str, Any]):
+            # convert a stage to a dict of modifiers
             return {
                 modifier_group_name: _modifier_group_to_dict(modifier_group)
-                for stage_modifiers in stage
-                for modifier_group_name, modifier_group in stage_modifiers[
-                    "modifiers"
-                ].items()
+                for modifier_group_name, modifier_group in stage["modifiers"].items()
             }
 
-        return {
-            stage_name: _stage_to_dict(stage=stage)
-            for stage_name, stage in self.dict()["stages"].items()
-        }
+        final_dict = {}
+        for stage_name, stages in self.dict()["stages"].items():
+            if len(stages) == 1:
+                final_dict[stage_name] = _stage_to_dict(stages[0])
+            else:
+                for idx, stage in enumerate(stages):
+                    final_dict[stage_name + "_" + str(idx)] = _stage_to_dict(stage)
+
+        return final_dict
 
 
 @dataclass
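This is the piece that resolves stage name clashes when recipes are stacked: a stage name that appears once keeps its name, while repeated names are disambiguated with an index suffix. A standalone sketch of just that keying logic (toy data, not the Recipe class itself):

def stages_to_dict(stages_by_name):
    # mirrors the keying logic above: unique names pass through,
    # clashing names get "_<idx>" suffixes
    final_dict = {}
    for stage_name, stages in stages_by_name.items():
        if len(stages) == 1:
            final_dict[stage_name] = stages[0]
        else:
            for idx, stage in enumerate(stages):
                final_dict[stage_name + "_" + str(idx)] = stage
    return final_dict

print(stages_to_dict({"test_stage": ["recipe_a"]}))
# {'test_stage': 'recipe_a'}
print(stages_to_dict({"test_stage": ["recipe_a", "recipe_b"]}))
# {'test_stage_0': 'recipe_a', 'test_stage_1': 'recipe_b'}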

src/sparseml/modifiers/quantization/utils/quantize.py

Lines changed: 43 additions & 0 deletions
@@ -52,6 +52,8 @@
     "add_input_activation_quant_wrappers",
     "add_output_activation_observers",
     "raise_if_torch_quantization_not_available",
+    "raise_if_already_quantized",
+    "is_module_quantized",
 ]
 
 
@@ -148,6 +150,18 @@ def set_quantization_schemes(
            # submodule type or graph section set to ignore, skip
            continue
 
+        if isinstance(submodule, torch_quantization.QuantWrapper):
+            # special case to catch QuantizableMatMul children
+            if ignore and _match_submodule_name_or_type(
+                submodule.module, submodule_name, ignore
+            ):
+                continue
+
+        if is_qat_helper_module(submodule):
+            # ignore children of an already quantized module, if there is a clash it
+            # will have been caught in the parent
+            continue
+
        # override default scheme if necessary
        override_key = _match_submodule_name_or_type(
            submodule, submodule_name, scheme_overrides
@@ -162,6 +176,7 @@ def set_quantization_schemes(
            wrap_qat_targets[submodule_name] = submodule_scheme
        elif is_module_type_override or is_quantizable_module(submodule):
            # is base quantizable module or user specifically targeted module type
+            raise_if_already_quantized(submodule_name, submodule)
            submodule.quantization_scheme = submodule_scheme
 
    # inject any targeted QATWrappers
@@ -351,6 +366,34 @@ def raise_if_torch_quantization_not_available():
        )
 
 
+def raise_if_already_quantized(module_name: str, module: Module):
+    """
+    :param module_name: name of module to check for quantization
+    :param module: module to check for quantization
+    :raises: RuntimeError if module is already quantized, it cannot be re-quantized
+    """
+    if is_module_quantized(module):
+        raise RuntimeError(
+            f"Unable to quantize module {module_name}, as it has already been "
+            "quantized. Ensure your input recipe does not contain multiple "
+            "QuantizationModifiers that act on the same module. "
+        )
+
+
+def is_module_quantized(module: Module) -> bool:
+    """
+    :param module: module to check for quantization
+    :return: True if the module is quantized, False otherwise
+    """
+    if hasattr(module, "quantization_scheme") and isinstance(
+        module.quantization_scheme, QuantizationScheme
+    ):
+        return True
+    if isinstance(module, torch_quantization.QuantWrapper):
+        return True
+    return False
+
+
 def _match_submodule_name_or_type(
     submodule: Module, submodule_name: str, names_or_types: List[str]
 ) -> Optional[str]:
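These helpers cover the quantization edge case called out in the commit message: with stacked recipes the same module could be targeted by two QuantizationModifiers, so a module that already carries a scheme (or is already wrapped) raises instead of being silently re-quantized. A hedged usage sketch, assuming torch is installed and using an arbitrary layer:

import torch.nn as nn

from sparseml.modifiers.quantization.utils.quantize import (
    is_module_quantized,
    raise_if_already_quantized,
)

layer = nn.Linear(16, 16)
assert not is_module_quantized(layer)                 # no quantization scheme attached yet
raise_if_already_quantized("model.layers.0", layer)   # passes silently

# once a first QuantizationModifier attaches a scheme (or wraps the module),
# a second attempt on the same layer raises:
# layer.quantization_scheme = QuantizationScheme(...)  # done by the first stage
# raise_if_already_quantized("model.layers.0", layer)  # -> RuntimeError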

tests/sparseml/core/lifecycle/test_session.py

Lines changed: 3 additions & 0 deletions
@@ -80,6 +80,9 @@ def test_session_initialize_propagates_layer_prefix_to_model(
 
 class ModifierMock(ModifierInterface):
     initialized_ = False
+    applied = False
+    group = "test"
+    unique_id = "test_0"
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__()

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+test_stage:
+  obcq_modifiers:
+    SparseGPTModifier:
+      sparsity: 0.7
+      block_size: 128
+      sequential_update: True
+      quantize: False
+      percdamp: 0.01
+      prunen: 0
+      prunem: 0
+      targets: [
+        "model.layers.0"
+      ]
+      target_ids: ["attention_mask", "position_ids"]
