
Commit 81e5084

dbogunowicz <bogunowicz@arrival.com>, bfineran, and KSGulin committed
[Fix] compose_staged failing with modifiers with dynamically inferred end_epoch (#682)
* initial commit
* added incrementing of epoch_end in base_stages plus a unit test
* Correct the formatting of the test input
* Merge tests together

Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Benjamin Fineran <[email protected]>
Co-authored-by: Konstantin Gulin <[email protected]>
1 parent da964e9 · commit 81e5084

File tree: 2 files changed, +112 -1 lines changed

src/sparseml/optim/manager.py
Lines changed: 20 additions & 1 deletion

@@ -223,9 +223,28 @@ def compose_staged(
         if not keep_original_epochs:
             # update additional modifier epochs
             base_end_epoch = base_recipe.max_epochs
+
+            # make sure that for the modifiers in base_stages
+            # with the initial attribute `end_epoch` = -1,
+            # this attribute value is replaced with `base_end_epoch`
+            for base_modifiers in base_stages.values():
+                for base_modifier in base_modifiers:
+                    if (
+                        hasattr(base_modifier, "end_epoch")
+                        and base_modifier.end_epoch == -1
+                    ):
+                        base_modifier._init_end = base_end_epoch
+                        base_modifier.end_epoch = base_end_epoch
+
             for additional_modifiers in additional_stages.values():
                 for additional_modifier in additional_modifiers:
-                    if hasattr(additional_modifier, "end_epoch"):
+                    if (
+                        hasattr(additional_modifier, "end_epoch")
+                        and additional_modifier.end_epoch != -1
+                    ):
+                        # if end_epoch == -1, the .end_epoch is being
+                        # assumed implicitly and does not need to be
+                        # incremented
                         additional_modifier.end_epoch += base_end_epoch
                     if hasattr(additional_modifier, "start_epoch"):
                         additional_modifier.start_epoch += base_end_epoch
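
To illustrate what this hunk does, here is a minimal, self-contained sketch of the epoch arithmetic. The `Modifier` dataclass and `compose_epochs` helper below are hypothetical stand-ins that only mirror the two loops above; they are not SparseML's real classes.

```python
from dataclasses import dataclass


@dataclass
class Modifier:
    """Stand-in for a SparseML modifier; only the fields used here."""

    start_epoch: float = 0.0
    end_epoch: float = -1.0  # -1 means "run until the recipe ends"


def compose_epochs(base_stages, additional_stages, base_end_epoch):
    """Mirror the epoch handling of the patched compose_staged hunk."""
    # pin implicit end epochs (-1) in the base stages to the base
    # recipe's final epoch, as the new first loop in the patch does
    for base_modifiers in base_stages.values():
        for modifier in base_modifiers:
            if modifier.end_epoch == -1:
                modifier.end_epoch = base_end_epoch
    # shift the additional stages to start after the base recipe; an
    # implicit end_epoch of -1 stays -1 rather than being incremented
    for additional_modifiers in additional_stages.values():
        for modifier in additional_modifiers:
            if modifier.end_epoch != -1:
                modifier.end_epoch += base_end_epoch
            modifier.start_epoch += base_end_epoch


base = {"stage_0": [Modifier(start_epoch=0.0, end_epoch=-1)]}
extra = {"stage_1": [Modifier(start_epoch=0.0, end_epoch=-1)]}
compose_epochs(base, extra, base_end_epoch=52)
print(base["stage_0"][0])   # Modifier(start_epoch=0.0, end_epoch=52)
print(extra["stage_1"][0])  # Modifier(start_epoch=52.0, end_epoch=-1)
```

Before this fix, a base modifier's `end_epoch = -1` was never resolved to a concrete epoch, so composing staged recipes failed when any modifier relied on a dynamically inferred end epoch.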

tests/sparseml/pytorch/optim/test_manager.py
Lines changed: 92 additions & 0 deletions

@@ -417,6 +417,88 @@
 
 """ # noqa: W293
 
+RECIPE_END_EPOCH_IMPLICIT = """
+training_modifiers:
+    - !EpochRangeModifier
+        start_epoch: 0.0
+        end_epoch: 52
+
+    - !SetLearningRateModifier
+        start_epoch: 50
+        learning_rate: 0.000002
+
+pruning_modifiers:
+    - !ConstantPruningModifier
+        start_epoch: 0.0
+        params: __ALL_PRUNABLE__
+
+quantization_modifiers:
+    - !QuantizationModifier
+        start_epoch: 50
+        submodules: ['model.0']
+"""
+
+COMPOSED_RECIPE_END_EPOCH_IMPLICIT = """version: 1.1.0
+
+stage_0:
+    __metadata__: None
+
+    stage_0_modifiers:
+        - !ConstantPruningModifier
+            end_epoch: 52
+            params: __ALL_PRUNABLE__
+            start_epoch: 0.0
+            update_frequency: -1
+
+        - !EpochRangeModifier
+            end_epoch: 52
+            start_epoch: 0.0
+
+        - !QuantizationModifier
+            end_epoch: 52
+            quantize_embeddings: True
+            quantize_linear_activations: True
+            reduce_range: False
+            start_epoch: 50
+            submodules: ['model.0']
+
+        - !SetLearningRateModifier
+            constant_logging: False
+            end_epoch: 52
+            learning_rate: 2e-06
+            start_epoch: 50
+
+
+stage_1:
+    __metadata__: None
+
+    stage_1_modifiers:
+        - !EpochRangeModifier
+            end_epoch: 104
+            start_epoch: 52.0
+
+        - !ConstantPruningModifier
+            end_epoch: -1.0
+            params: __ALL_PRUNABLE__
+            start_epoch: 52.0
+            update_frequency: -1
+
+        - !QuantizationModifier
+            end_epoch: -1.0
+            quantize_embeddings: True
+            quantize_linear_activations: True
+            reduce_range: False
+            start_epoch: 102
+            submodules: ['model.0']
+
+        - !SetLearningRateModifier
+            constant_logging: False
+            end_epoch: -1.0
+            learning_rate: 2e-06
+            start_epoch: 102
+
+
+""" # noqa: W293
 
 def _generate_fake_metadata(item1=("metadata", None), item2=("level", 1)):
     return {k: v for (k, v) in (item1, item2)}

@@ -558,6 +640,16 @@ def _generate_fake_metadata(item1=("metadata", None), item2=("level", 1)):
             False,
             True,
         ),
+        # Testing composing two recipes with modifiers containing
+        # implicit `end_epoch` attribution (i.e. `end_epoch = -1`)
+        (
+            RECIPE_END_EPOCH_IMPLICIT,
+            RECIPE_END_EPOCH_IMPLICIT,
+            None,
+            COMPOSED_RECIPE_END_EPOCH_IMPLICIT,
+            False,
+            False,
+        ),
     ],
 )
 def test_lifecycle_manager_staged(
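
The new parametrized case composes the implicit-end-epoch recipe with itself and checks the result against `COMPOSED_RECIPE_END_EPOCH_IMPLICIT`. A hypothetical sketch of the entry point under test follows; the call signature is inferred from the diff (`base_recipe`, `keep_original_epochs`) and from SparseML's `ScheduledModifierManager`, so treat it as an assumption rather than the verified API.

```python
# Sketch only: compose_staged's exact signature is inferred from the
# diff above, not verified against a SparseML release.
from sparseml.pytorch.optim import ScheduledModifierManager

# Compose the recipe with itself, as the new test case does. With
# keep_original_epochs=False, the base stage's end_epoch = -1 values
# are pinned to the base recipe's max_epochs (52 here), while the
# second stage keeps its implicit -1 values.
composed = ScheduledModifierManager.compose_staged(
    RECIPE_END_EPOCH_IMPLICIT,
    RECIPE_END_EPOCH_IMPLICIT,
    keep_original_epochs=False,
)
# expected to serialize to COMPOSED_RECIPE_END_EPOCH_IMPLICIT
print(str(composed))
```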
