3 changes: 3 additions & 0 deletions src/sparseml/modifiers/pruning/constant/pytorch.py
@@ -71,6 +71,9 @@ def on_update(self, state: State, event: Event, **kwargs):
             def apply_masks(module):
                 mask_name = param_mask_name()
                 if hasattr(module, mask_name):
+                    mask = getattr(module, mask_name)
+                    if mask.device != module.weight.device:
+                        setattr(module, mask_name, mask.to(module.weight.device))
                     module.weight *= getattr(module, mask_name)
 
             state.model.model.apply(apply_masks)
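The three added lines guard against the stored mask living on a different device than the weight it masks (for example after the model has been moved or re-sharded), by re-homing the mask onto the weight's device before the in-place multiply. Below is a minimal standalone sketch of the same pattern; MASK_NAME, attach_mask, and the toy model are hypothetical stand-ins for sparseml's param_mask_name() and modifier plumbing, which are not reproduced here.

import torch
import torch.nn as nn

MASK_NAME = "weight_mask"  # hypothetical stand-in for param_mask_name()


def attach_mask(module: nn.Module, sparsity: float = 0.5) -> None:
    weight = getattr(module, "weight", None)
    if not isinstance(weight, torch.Tensor):
        return
    # Store a binary mask as a plain attribute; unlike a registered buffer it
    # will not follow module.to(...) automatically, so it can go stale.
    mask = (torch.rand_like(weight) > sparsity).to(weight.dtype)
    setattr(module, MASK_NAME, mask)


def apply_masks(module: nn.Module) -> None:
    if hasattr(module, MASK_NAME):
        mask = getattr(module, MASK_NAME)
        if mask.device != module.weight.device:
            # Re-home the mask before multiplying, mirroring the hunk above.
            setattr(module, MASK_NAME, mask.to(module.weight.device))
        with torch.no_grad():
            module.weight *= getattr(module, MASK_NAME)


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
model.apply(attach_mask)
model.to("cuda" if torch.cuda.is_available() else "cpu")  # weights move; the plain-attribute masks do not
model.apply(apply_masks)  # masks are re-homed on demand, then applied in place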
8 changes: 0 additions & 8 deletions src/sparseml/transformers/finetune/runner.py
@@ -40,9 +40,7 @@
 )
 from sparseml.transformers.finetune.model_args import ModelArguments
 from sparseml.transformers.finetune.training_args import TrainingArguments
-from sparseml.utils.fsdp.context import summon_full_params_context
 from sparseml.utils.fsdp.helpers import is_fsdp_model, unwrap_and_export_model
-from sparseml.utils.pytorch import qat_active
 
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -287,12 +285,6 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
             session = session_manager.active_session()
             session.reset_stage()
 
-            # log model sparsity
-            with summon_full_params_context(self.trainer.model):
-                if self.trainer.accelerator.is_main_process:
-                    if not qat_active(self.trainer.model):
-                        self.trainer.log_model_sparsification()
-
             # synchronize and clean up memory
             self.trainer.accelerator.wait_for_everyone()
             self.trainer.model = get_session_model()
19 changes: 15 additions & 4 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -40,6 +40,7 @@
 )
 from sparseml.utils.fsdp.context import summon_full_params_context
 from sparseml.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp
+from sparseml.utils.pytorch import qat_active
 
 
 __all__ = [
@@ -137,7 +138,7 @@ def initialize_session(
         train_data = self.get_train_dataloader()
 
         self.accelerator.wait_for_everyone()
-        with summon_full_params_context(self.model):
+        with summon_full_params_context(self.model, offload_to_cpu=True):
             session_manager.initialize(
                 model=self.model,
                 teacher_model=self.teacher,  # TODO: what about for self/disable?
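The offload_to_cpu=True argument added in this hunk asks the FSDP parameter gather to stage the temporarily materialized full parameters in host memory rather than on the GPU. sparseml's own summon_full_params_context (in sparseml.utils.fsdp.context) is not shown in this diff; the sketch below is only an assumption about its shape, treating it as a thin wrapper over PyTorch's FSDP.summon_full_params that degrades to a no-op for non-FSDP models.

from contextlib import nullcontext

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def summon_full_params_sketch(model: nn.Module, offload_to_cpu: bool = False):
    # Hypothetical stand-in for sparseml.utils.fsdp.context.summon_full_params_context.
    if isinstance(model, FSDP):
        # Gather the unsharded parameters for the duration of the context;
        # offload_to_cpu=True keeps the gathered copies in host memory so the
        # GPU that already holds shards and optimizer state is not inflated.
        return FSDP.summon_full_params(model, offload_to_cpu=offload_to_cpu)
    # A plain (non-FSDP) model already exposes its full parameters.
    return nullcontext()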
@@ -370,9 +371,13 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
 
         self.accelerator.wait_for_everyone()
 
-        # Need to gather parameters across the GPUs before accessing layer weights
-        with summon_full_params_context(self.model):
-            self.log_model_sparsification()
+        # log model sparsity
+        with summon_full_params_context(self.model, offload_to_cpu=True):
+            if self.accelerator.is_main_process:
+                if not qat_active(self.model):
+                    self.log_model_sparsification()
+
+        self.accelerator.wait_for_everyone()
 
         return output
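The sparsity log is now guarded twice: only the main process writes it, and it is skipped while quantization-aware training is active, since fake-quantized weights would distort the zero counts. A hedged sketch of that gating pattern follows; the trainer argument is a hypothetical stand-in for the mixin instance (self) and is assumed to expose model, accelerator, and log_model_sparsification as the code above does, while both imports are the ones this diff already uses.

from sparseml.utils.fsdp.context import summon_full_params_context
from sparseml.utils.pytorch import qat_active


def log_sparsity_safely(trainer) -> None:
    # Gather full parameters (offloaded to CPU), log once from rank 0, and
    # skip the log entirely while QAT fake-quant wrappers are attached.
    with summon_full_params_context(trainer.model, offload_to_cpu=True):
        if trainer.accelerator.is_main_process:
            if not qat_active(trainer.model):
                trainer.log_model_sparsification()
    trainer.accelerator.wait_for_everyone()  # bring the other ranks back in sync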

@@ -434,6 +439,12 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
             accelerator=self.accelerator,
         )
 
+        # log model sparsity
+        with summon_full_params_context(self.model, offload_to_cpu=True):
+            if self.accelerator.is_main_process:
+                if not qat_active(self.model):
+                    self.log_model_sparsification()
+
         self.accelerator.wait_for_everyone()
 
     def save_model(
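qat_active itself comes from sparseml.utils.pytorch and is not part of this diff. As a rough illustration only, and an assumption about its behavior rather than its actual implementation, a check of this kind can be approximated by scanning the module tree for fake-quantization modules.

import torch.nn as nn
from torch.ao.quantization import FakeQuantize


def qat_active_sketch(model: nn.Module) -> bool:
    # Hypothetical stand-in for sparseml.utils.pytorch.qat_active: treat the
    # model as "QAT active" if any fake-quantize module is attached, since
    # sparsity statistics over fake-quantized weights would be misleading.
    return any(isinstance(module, FakeQuantize) for module in model.modules())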