-
Couldn't load subscription status.
- Fork 262
[Torch FX] Compress PT2E Support #3663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
190f9d5
c52fcca
4e56cb5
9651ceb
14daeb5
3746815
0815dc5
1b8d940
7d35374
427ebc2
24dbfb6
4bb8c1a
8902842
3842538
2e70c2e
b1c9aad
33fe01c
d8e1006
88a8472
7a8e51a
fed5052
2866473
7171d56
3e3b067
5b7b210
71a479f
b24a59c
d12225a
9870ee2
8015629
0804218
1f1fda3
623ce46
d14a6eb
e91b455
448bf84
8e23572
36ddf53
07b730b
d5dd422
076a76b
2ce9eec
1bebf3e
ea81cfd
e82920f
82cc10b
beae508
8bd95df
aac9d3f
4278cfd
6fd5216
118b611
e9f3cd4
a969e58
71d0597
bf671ff
5f1c2de
6f81879
eb0ff16
8afeb9d
f491c8d
b9f3eff
09dabf6
68316a5
58b8992
e7bae1f
4b0d8ea
d4da34f
2b91658
93c3f19
932b296
67ab135
a23acaf
6462284
cf7e8d3
a07dc07
0506bca
f8675ad
52a7d5a
9e02948
a578fce
8ae6a80
c7210b8
2f8b296
75ccdcb
75cc255
3cdfe74
e4f9286
009c587
6e379c8
e45f796
f2ece8c
387d69c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| # Copyright (c) 2025 Intel Corporation | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,169 @@ | ||
| # Copyright (c) 2025 Intel Corporation | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Iterable, Optional | ||
|
|
||
| import torch | ||
|
|
||
| from nncf import AdvancedCompressionParameters | ||
| from nncf import BackupMode | ||
| from nncf import CompressionFormat | ||
| from nncf import CompressWeightsMode | ||
| from nncf import Dataset | ||
| from nncf import IgnoredScope | ||
| from nncf import SensitivityMetric | ||
| from nncf.common.graph.graph import NNCFGraph | ||
| from nncf.common.graph.graph import NNCFNode | ||
| from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer | ||
| from nncf.common.utils.backend import BackendType | ||
| from nncf.experimental.quantization.quantizer import Quantizer | ||
| from nncf.quantization.algorithms.algorithm import Algorithm | ||
| from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression as OriginalWeightCompression | ||
|
|
||
|
|
||
| class WeightsCompression(Algorithm): | ||
| """ | ||
| Post-training Weight Compression algorithm implementation. | ||
|
|
||
| Compresses weights of Linear and Embedding layers to 8-bit integer or | ||
| to 4-bit integer/float depending on mode, ratio and group size. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| mode: CompressWeightsMode, | ||
| quantizer: Quantizer, | ||
| ratio: float, | ||
| group_size: int, | ||
| ignored_scope: IgnoredScope, | ||
| all_layers: bool, | ||
| subset_size: int, | ||
| awq: bool, | ||
| scale_estimation: bool, | ||
| gptq: bool, | ||
| lora_correction: bool, | ||
| backup_mode: BackupMode, | ||
| sensitivity_metric: SensitivityMetric, | ||
| compression_format: CompressionFormat, | ||
| advanced_parameters: AdvancedCompressionParameters, | ||
| ) -> torch.fx.GraphModule: | ||
| """ | ||
| :param mode: Defines a mode for weight compression. | ||
| INT8_SYM stands for 8-bit integer symmetric quantization of all weights. | ||
| Weights are quantized symmetrically without zero point. | ||
| INT8_ASYM is the same as INT8_SYM mode, but weights are quantized to a primary precision asymmetrically | ||
| with a typical non-fixed zero point. | ||
| INT4_SYM stands for a mixed-precision weights quantization with 4-bit integer as a primary precision. | ||
| Weights are quantized to a primary precision symmetrically without zero point. | ||
| All embeddings and the last layer are always compressed to a backup_mode, which is INT8_ASYM, | ||
| by default. All others are quantized whether to 4-bit integer or to a backup_mode depending on | ||
| criteria and the given ratio. | ||
| INT4_ASYM is the same as INT4_SYM mode, but weights are quantized to a primary precision asymmetrically | ||
| with a typical non-fixed zero point. | ||
| :param quantizer: Quantizer to use in WeightCompression algorithm. | ||
| :param ratio: the ratio between primary and backup precisions (e.g. 0.9 means 90% of layers quantized to INT4 | ||
| and the rest to backup_mode). | ||
| :param group_size: number of weights (e.g. 128) in the channel dimension | ||
| that share quantization parameters (scale). The value -1 means no grouping. | ||
| :param ignored_scope: An ignored scope that defined the list of model control | ||
| flow graph nodes to be ignored during quantization. | ||
| :param all_layers: Indicates whether embeddings and last MatMul layers should be compressed to a primary | ||
| precision. By default, the backup precision is assigned for the embeddings and last MatMul layers. | ||
| :param subset_size: Number of data samples to calculate activation statistics used for assigning different | ||
| quantization precision. | ||
| :param awq: determines whether to use or not modified AWQ algorithm. | ||
| :param scale_estimation: determines whether to use or not scale estimation for 4 bit layers. | ||
| :param gptq: determines whether to use or not GPTQ algorithm. | ||
| :param lora_correction: determines whether to use or not LoRA Correction algorithm. | ||
| :param backup_mode: Defines a backup mode for mixed-precision weight compression. | ||
| NONE stands for original floating-point precision of the model weights. | ||
| In this mode, weights are retained in their original precision without any quantization. | ||
| INT8_SYM stands for 8-bit integer symmetric quantization without zero point. | ||
| INT8_ASYM stands for 8-bit integer asymmetric quantization with a typical non-fixed zero point. | ||
| :param sensitivity_metric: The sensitivity metric for assigning quantization precision to layers. In order to | ||
| preserve the accuracy of the model, the more sensitive layers receives a higher precision. | ||
| :param compression_format: Describes the format in which the model is saved after weight compression. | ||
| :param advanced_parameters: advanced parameters for algorithms in compression pipeline. | ||
| """ | ||
| self._quantizer = quantizer | ||
|
|
||
| self._mode = mode | ||
| self._awq = awq | ||
| self._gptq = gptq | ||
| self._scale_estimation = scale_estimation | ||
| self._subset_size = subset_size | ||
| self._advanced_parameters = advanced_parameters | ||
| self._lora_correction = lora_correction | ||
| self._ratio = ratio | ||
| self._group_size = group_size | ||
| self._all_layers = all_layers | ||
| self._backup_mode = backup_mode | ||
| self._sensitivity_metric = sensitivity_metric | ||
| self._compression_format = compression_format | ||
|
|
||
| self._algo = OriginalWeightCompression( | ||
| mode=self._mode, | ||
| ratio=self._ratio, | ||
| group_size=self._group_size, | ||
| ignored_scope=ignored_scope, | ||
| all_layers=self._all_layers, | ||
| sensitivity_metric=self._sensitivity_metric, | ||
| awq=self._awq, | ||
| subset_size=self._subset_size, | ||
| scale_estimation=self._scale_estimation, | ||
| gptq=self._gptq, | ||
| lora_correction=self._lora_correction, | ||
| backup_mode=self._backup_mode, | ||
| compression_format=self._compression_format, | ||
| advanced_parameters=self._advanced_parameters, | ||
| ) | ||
|
|
||
| def available_backends(self) -> list[BackendType]: | ||
| return self._algo.available_backends() | ||
|
|
||
| def apply( | ||
| self, | ||
| model: torch.fx.GraphModule, | ||
| graph: NNCFGraph, | ||
| statistic_points: Optional[StatisticPointsContainer] = None, | ||
| dataset: Optional[Dataset] = None, | ||
| ) -> torch.fx.GraphModule: | ||
| self._algo.set_backend_entity(model) | ||
|
|
||
| all_weight_params, ratio_defining_params, skipped_weight_params = ( | ||
| self._quantizer.get_weight_compression_parameters(model, graph) | ||
| ) | ||
|
|
||
| return self._algo.apply_with_parameters( | ||
| model, | ||
| graph, | ||
| dataset, | ||
| statistic_points, | ||
| all_weight_params, | ||
| ratio_defining_params, | ||
| skipped_weight_params, | ||
| ) | ||
|
|
||
| def get_statistic_points( | ||
| self, | ||
| model: torch.fx.GraphModule, | ||
| graph: NNCFGraph, | ||
| nodes_and_port_ids: Iterable[tuple[NNCFNode, int]], | ||
| ) -> StatisticPointsContainer: | ||
| """ | ||
| Returns statistic points, for which StatisticsCollector should collect statistics. | ||
|
|
||
| :param model: Model for statistics collection. | ||
| :param graph: Model graph. | ||
| :param nodes_and_port_ids: Nodes and port ids for which statistics should be collected. | ||
| :return: Statistic points, for which StatisticsCollector should collect statistics. | ||
| """ | ||
| return self._algo.get_statistic_points(model, graph, nodes_and_port_ids) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,11 +22,14 @@ | |
| from torch.fx.passes.infra.pass_manager import PassManager | ||
|
|
||
| import nncf | ||
| from nncf import AdvancedCompressionParameters | ||
| from nncf import Dataset | ||
| from nncf import SensitivityMetric | ||
| from nncf.common.factory import NNCFGraphFactory | ||
| from nncf.common.logging import nncf_logger | ||
| from nncf.common.utils.api_marker import api | ||
| from nncf.experimental.quantization.algorithms.post_training.algorithm import ExperimentalPostTrainingQuantization | ||
| from nncf.experimental.quantization.algorithms.weight_compression.algorithm import WeightsCompression | ||
| from nncf.experimental.torch.fx.constant_folding import constant_fold | ||
| from nncf.experimental.torch.fx.quantization.quantizer.openvino_adapter import OpenVINOQuantizerAdapter | ||
| from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer | ||
|
|
@@ -36,6 +39,7 @@ | |
| from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation | ||
| from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters | ||
| from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters | ||
| from nncf.quantization.algorithms.weight_compression.algorithm import get_weight_compression_configuration | ||
| from nncf.quantization.range_estimator import RangeEstimatorParameters | ||
|
|
||
|
|
||
|
|
@@ -157,3 +161,94 @@ def _quant_node_constraint(n: torch.fx.Node) -> bool: | |
| related to quantization | ||
| """ | ||
| return n.op == "call_function" and n.target in QUANTIZE_NODE_TARGETS | ||
|
|
||
|
|
||
| @api(canonical_alias="nncf.experimental.torch.fx.compress_pt2e") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please check that API docs reflect the new API correctly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you mean to ask about the method docstring or is there another API doc? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://openvinotoolkit.github.io/nncf/autoapi/nncf/ |
||
| def compress_pt2e( | ||
| model: torch.fx.GraphModule, | ||
| quantizer: Quantizer, | ||
| dataset: Optional[nncf.Dataset] = None, | ||
| awq: bool = False, | ||
| scale_estimation: bool = False, | ||
| gptq: bool = False, | ||
| lora_correction: bool = False, | ||
| subset_size: int = 128, | ||
| ratio: int = 1, | ||
| sensitivity_metric: Optional[SensitivityMetric] = None, | ||
| advanced_parameters: Optional[AdvancedCompressionParameters] = None, | ||
| ) -> torch.fx.GraphModule: | ||
| """ | ||
| Applies Weight Compression to the torch.fx.GraphModule provided model | ||
| using provided torch.ao quantizer. | ||
|
|
||
| :param model: A torch.fx.GraphModule instance to be quantized. | ||
| :param quantizer: Torch ao quantizer to annotate nodes in the graph with quantization setups | ||
| to convey the desired way of quantization. | ||
| :param dataset: A representative dataset for the | ||
| calibration process. | ||
| :param awq: Determines whether to use or not the modified AWQ algorithm. | ||
| :param scale_estimation: Determines whether to use or not scale estimation for 4-bit layers. | ||
| :param gptq: Determines whether to use or not GPTQ algorithm. | ||
| :param lora_correction: Determines whether to use or not LoRA Correction algorithm. | ||
| :param subset_size: Number of data samples to calculate activation statistics used for assigning different | ||
| quantization precision. | ||
| :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 | ||
| and the rest to INT8_ASYM). | ||
| :param sensitivity_metric: The sensitivity metric for assigning quantization precision to layers. In order to | ||
| preserve the accuracy of the model, the more sensitive layers receive a higher precision. | ||
| :param advanced_parameters: Advanced parameters for algorithms in the compression pipeline. | ||
| """ | ||
| if isinstance(quantizer, OpenVINOQuantizer) or hasattr(quantizer, "get_nncf_weight_compression_parameters"): | ||
| quantizer = OpenVINOQuantizerAdapter(quantizer) | ||
| compression_format = nncf.CompressionFormat.DQ | ||
| else: | ||
| # TODO Support Third party quantizers here. | ||
| msg = "Only OpenVINO Quantizer is supported currently." | ||
| raise nncf.InternalError(msg) | ||
|
|
||
| wc_config = quantizer.get_weight_compression_config() | ||
|
|
||
| mode = wc_config.get("mode", None) | ||
| awq = awq | ||
| gptq = gptq | ||
| scale_estimation = scale_estimation | ||
| subset_size = subset_size | ||
| advanced_parameters = advanced_parameters | ||
| lora_correction = lora_correction | ||
| ratio = ratio | ||
| group_size = wc_config.get("group_size", 128) | ||
| all_layers = wc_config.get("all_layers", False) | ||
| backup_mode = wc_config.get("backup_mode", nncf.BackupMode.INT8_ASYM) | ||
| sensitivity_metric = sensitivity_metric | ||
| compression_format = compression_format | ||
| ignored_scope = nncf.IgnoredScope() # This is already defined in the quantizer object | ||
|
|
||
| weight_compression_configuration = get_weight_compression_configuration( | ||
| mode, | ||
| dataset, | ||
| ratio, | ||
| group_size, | ||
| all_layers, | ||
| awq, | ||
| scale_estimation, | ||
| gptq, | ||
| lora_correction, | ||
| ignored_scope, | ||
| sensitivity_metric, | ||
| backup_mode, | ||
| advanced_parameters, | ||
| ) | ||
|
|
||
| quantization_algorithm = WeightsCompression( | ||
| quantizer=quantizer, | ||
| subset_size=subset_size, | ||
| compression_format=compression_format, | ||
| **weight_compression_configuration, | ||
| ) | ||
|
|
||
| # Here the model is annotated | ||
| transformed_model = quantizer.transform_prior_quantization(model) | ||
| nncf_graph = NNCFGraphFactory.create(transformed_model) | ||
| quantized_model = quantization_algorithm.apply(transformed_model, nncf_graph, dataset=dataset) | ||
| quantized_model = torch.fx.GraphModule(quantized_model, graph=quantized_model.graph) | ||
| return quantized_model | ||
Uh oh!
There was an error while loading. Please reload this page.