@@ -24,7 +24,7 @@ def main():
 from torch.utils.data import DataLoader
 from torchvision import transforms
 
-import sparseml.core.session as sml
+import sparseml.core.session as session_manager
 from sparseml.core.event import EventType
 from sparseml.core.framework import Framework
 from sparseml.pytorch.utils import (
@@ -40,8 +40,8 @@ def main():
     device = "cuda:0"
 
     # set up SparseML session
-    sml.create_session()
-    session = sml.active_session()
+    session_manager.create_session()
+    session = session_manager.active_session()
 
     # download model
     model = torchvision.models.mobilenet_v2(
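For orientation, a minimal sketch of the session workflow the example script follows after the rename. Only create_session and active_session appear in the hunks above; the initialize call, its keyword arguments, and the recipe path are assumptions added for illustration.

import torchvision

import sparseml.core.session as session_manager
from sparseml.core.framework import Framework

# create the global SparseML session and grab a handle to it (as in the diff)
session_manager.create_session()
session = session_manager.active_session()

# download the model (constructor arguments are elided in the hunk above)
model = torchvision.models.mobilenet_v2()

# assumed call, for illustration only: hand the model and a sparsification
# recipe to the session; the exact initialize() signature is not shown here
session.initialize(framework=Framework.pytorch, model=model, recipe="recipe.yaml")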
7 changes: 5 additions & 2 deletions setup.py
@@ -187,8 +187,11 @@ def _setup_entry_points() -> Dict:
         ]
     )
 
-    entry_points["console_scripts"].append(
-        "sparseml.transformers.export_onnx=sparseml.transformers.export:main"
+    entry_points["console_scripts"].extend(
+        [
+            "sparseml.transformers.export_onnx=sparseml.transformers.export:main",
+            "sparseml.transformers.export_onnx_refactor=sparseml.transformers.sparsification.obcq.export:main",  # noqa 501
+        ]
     )
 
     # image classification integration
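For readers unfamiliar with console_scripts entry points, a minimal sketch of how an entry of the form "<command>=<module>:<function>" behaves once the package is installed; the package and command names below are placeholders, not sparseml's real ones.

from setuptools import setup

# installing a package with this configuration creates an executable named
# "example.export_onnx" on PATH that imports example_package.export and calls
# its main() function -- the same mechanism the new export_onnx_refactor
# command above relies on
setup(
    name="example-package",
    entry_points={
        "console_scripts": [
            "example.export_onnx=example_package.export:main",
        ]
    },
)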
5 changes: 1 addition & 4 deletions src/sparseml/modifiers/quantization/base.py
@@ -14,7 +14,7 @@
 
 from typing import Any, Dict, List, Optional
 
-from sparseml.core import Event, Modifier, State
+from sparseml.core import Event, Modifier
 
 
 __all__ = ["QuantizationModifier"]
@@ -136,6 +136,3 @@ def check_should_disable_observer(self, event: Event) -> bool:
         if event.current_index >= disable_epoch:
             return True
         return False
-
-    def on_initialize_structure(self, state: State, **kwargs):
-        pass # nothing needed for this modifier
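Read together with the next file, the effect is to move the hook from the framework-agnostic class into the PyTorch implementation. A rough sketch of that relationship, assuming on_initialize_structure is a standard Modifier lifecycle hook and that the PyTorch class subclasses this one; class names and bodies below are illustrative only.

# illustrative sketch, not the real sparseml classes
class Modifier:
    def on_initialize_structure(self, state, **kwargs):
        ...  # lifecycle hook for (re)building structure such as QAT wrappers

class QuantizationModifier(Modifier):
    pass  # no longer carries a no-op override of on_initialize_structure

class QuantizationModifierPyTorch(QuantizationModifier):
    def on_initialize_structure(self, state, **kwargs):
        ...  # real implementation added in src/sparseml/modifiers/quantization/pytorch.py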
47 changes: 27 additions & 20 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -71,6 +71,11 @@ def __init__(self, **kwargs):
             self.scheme_overrides, self.scheme
         )
 
+    def on_initialize_structure(self, state: State, **kwargs):
+        module = state.model.model
+        self._enable_module_qat(module)
+        state.model.model.apply(torch.quantization.disable_observer)
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         raise_if_torch_quantization_not_available()
         if self.end and self.end != -1:
@@ -84,6 +89,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         if self.calculate_start() == -1: # one-shot
             self._enable_module_qat(module)
+            self._calibrate_if_possible(module)
             self._disable_quantization_observer(module)
 
         return True
@@ -122,30 +128,31 @@ def _disable_quantization_observer(self, model: Module):
         self.quantization_observer_disabled_ = True
 
     def _enable_module_qat(self, module: Module):
-        # fuse conv-bn-relu blocks prior to quantization emulation
-        self._fuse(module)
-
-        # add quantization_schemes to target submodules
-        set_quantization_schemes(
-            module,
-            scheme=self.scheme,
-            scheme_overrides=self.scheme_overrides,
-            ignore=self.ignore,
-            strict=self.strict,
-        )
+        module.apply(torch.quantization.enable_observer)
 
-        # fix for freezing batchnorm statistics when not fusing BN with convs.
-        # pytorch only supports freezing batchnorm statistics for fused modules.
-        # this fix wraps BN modules adding with a new module class that supports
-        # methods related to freezing/unfreezing BN statistics.
-        configure_module_bn_wrappers(module)
+        if not self.qat_enabled_:
+            # fuse conv-bn-relu blocks prior to quantization emulation
+            self._fuse(module)
+
+            # add quantization_schemes to target submodules
+            set_quantization_schemes(
+                module,
+                scheme=self.scheme,
+                scheme_overrides=self.scheme_overrides,
+                ignore=self.ignore,
+                strict=self.strict,
+            )
 
-        # convert target qconfig layers to QAT modules with FakeQuantize
-        convert_module_qat_from_schemes(module)
+            # fix for freezing batchnorm statistics when not fusing BN with convs.
+            # pytorch only supports freezing batchnorm statistics for fused modules.
+            # this fix wraps BN modules adding with a new module class that supports
+            # methods related to freezing/unfreezing BN statistics.
+            configure_module_bn_wrappers(module)
 
-        self.qat_enabled_ = True
+            # convert target qconfig layers to QAT modules with FakeQuantize
+            convert_module_qat_from_schemes(module)
 
-        self._calibrate_if_possible(module)
+        self.qat_enabled_ = True
 
     def _fuse(self, module: Module):
         if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
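As background for the observer handling above, a minimal self-contained sketch of the enable/disable observer pattern, using plain torch.quantization QAT preparation in place of SparseML's scheme helpers; the toy model and qconfig choice are illustrative assumptions, not code from this PR.

import torch
import torch.nn as nn

# toy module prepared for QAT so it contains FakeQuantize/observer submodules
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# observers collect quantization statistics while enabled; this mirrors the
# module.apply(torch.quantization.enable_observer) call added above
model.apply(torch.quantization.enable_observer)
model(torch.randn(1, 3, 32, 32))

# freezing observers once calibration is done mirrors on_initialize_structure
# ending with torch.quantization.disable_observer
model.apply(torch.quantization.disable_observer)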