diff --git a/.github/workflows/call_precommit.yml b/.github/workflows/call_precommit.yml index 6c2f58ae9ce..ddabf24213a 100644 --- a/.github/workflows/call_precommit.yml +++ b/.github/workflows/call_precommit.yml @@ -136,6 +136,48 @@ jobs: env: NUM_WORKERS: 4 + executorch: + timeout-minutes: 40 + runs-on: ubuntu-latest-8-cores + defaults: + run: + shell: bash + env: + DEBIAN_FRONTEND: noninteractive + steps: + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get --assume-yes install gcc g++ build-essential ninja-build libgl1-mesa-dev libglib2.0-0 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + lfs: true + - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + with: + python-version: ${{ inputs.python_version }} + - name: Runner info + continue-on-error: true + run: | + cat /etc/*release + cat /proc/cpuinfo + - name: Override constraints + if: ${{ inputs.override_requirements != '' }} + run: python .github/scripts/override_constraints.py "${{ inputs.override_requirements }}" + shell: bash + - name: Install NNCF and test requirements + run: | + pip install . -r tests/executorch/requirements.txt + # ExecuTorch + # Editable install due to https://github.com/pytorch/executorch/issues/6475 + pip install --no-build-isolation -e git+https://github.com/anzr299/executorch.git@an/quantizer_nncf_pt2e_support#egg=executorch + - name: Print installed modules + run: pip list + - name: Run ExecuTorch precommit test scope + run: | + make test-executorch + env: + NUM_WORKERS: 4 + pytorch-cuda: timeout-minutes: 40 runs-on: aks-linux-4-cores-28gb-gpu-tesla-t4 diff --git a/Makefile b/Makefile index d420c8ce4f0..0cf574a1848 100644 --- a/Makefile +++ b/Makefile @@ -141,6 +141,9 @@ test-torch-cpu: test-torch-cuda: pytest ${COVERAGE_ARGS} tests/torch -ra -m "cuda and not weekly and not nightly and not models_hub and not legacy" --junitxml ${JUNITXML_PATH} +test-executorch: + pytest ${COVERAGE_ARGS} tests/executorch --junitxml ${JUNITXML_PATH} + test-torch-nightly: pytest ${COVERAGE_ARGS} tests/torch -m "nightly or legacy" --junitxml ${JUNITXML_PATH} $(DATA_ARG) diff --git a/src/nncf/experimental/quantization/algorithms/weight_compression/__init__.py b/src/nncf/experimental/quantization/algorithms/weight_compression/__init__.py new file mode 100644 index 00000000000..e5a42efc0ef --- /dev/null +++ b/src/nncf/experimental/quantization/algorithms/weight_compression/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
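A minimal usage sketch of the new compress_pt2e entry point introduced in quantize_pt2e.py below. The toy model is illustrative, and the OpenVINOQuantizer constructor arguments for weight-compression mode are defined outside this diff, so they are elided here:

import torch
import torch.nn as nn

from nncf.experimental.torch.fx import OpenVINOQuantizer
from nncf.experimental.torch.fx import compress_pt2e

# Illustrative toy model; any torch.nn.Module with Linear/Embedding weights works.
model = nn.Sequential(nn.Linear(64, 128), nn.SiLU(), nn.Linear(128, 64)).eval()
example_inputs = (torch.ones(1, 3, 64),)

# Capture the model as a torch.fx.GraphModule via torch.export.
fx_model = torch.export.export(model, example_inputs).module()

# Assumption: the quantizer is configured for weight compression;
# its exact constructor arguments are not part of this diff.
quantizer = OpenVINOQuantizer()

compressed_model = compress_pt2e(
    fx_model,
    quantizer=quantizer,
    ratio=1.0,        # fraction of ratio-defining layers kept in the primary precision
    subset_size=128,  # samples used for activation statistics on data-aware paths
)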
diff --git a/src/nncf/experimental/quantization/algorithms/weight_compression/algorithm.py b/src/nncf/experimental/quantization/algorithms/weight_compression/algorithm.py new file mode 100644 index 00000000000..7d541421d2a --- /dev/null +++ b/src/nncf/experimental/quantization/algorithms/weight_compression/algorithm.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable, Optional + +import torch + +from nncf import AdvancedCompressionParameters +from nncf import BackupMode +from nncf import CompressionFormat +from nncf import CompressWeightsMode +from nncf import Dataset +from nncf import IgnoredScope +from nncf import SensitivityMetric +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer +from nncf.common.utils.backend import BackendType +from nncf.experimental.quantization.quantizer import Quantizer +from nncf.quantization.algorithms.algorithm import Algorithm +from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression as OriginalWeightCompression + + +class WeightsCompression(Algorithm): + """ + Post-training Weight Compression algorithm implementation. + + Compresses weights of Linear and Embedding layers to 8-bit integer or + to 4-bit integer/float depending on mode, ratio and group size. + """ + + def __init__( + self, + mode: CompressWeightsMode, + quantizer: Quantizer, + ratio: float, + group_size: int, + ignored_scope: IgnoredScope, + all_layers: bool, + subset_size: int, + awq: bool, + scale_estimation: bool, + gptq: bool, + lora_correction: bool, + backup_mode: BackupMode, + sensitivity_metric: SensitivityMetric, + compression_format: CompressionFormat, + advanced_parameters: AdvancedCompressionParameters, + ) -> None: + """ + :param mode: Defines a mode for weight compression. + INT8_SYM stands for 8-bit integer symmetric quantization of all weights. + Weights are quantized symmetrically without zero point. + INT8_ASYM is the same as INT8_SYM mode, but weights are quantized to a primary precision asymmetrically + with a typical non-fixed zero point. + INT4_SYM stands for a mixed-precision weights quantization with 4-bit integer as a primary precision. + Weights are quantized to a primary precision symmetrically without zero point. + All embeddings and the last layer are always compressed to a backup_mode, which is INT8_ASYM, + by default. All others are quantized either to 4-bit integer or to a backup_mode depending on + criteria and the given ratio. + INT4_ASYM is the same as INT4_SYM mode, but weights are quantized to a primary precision asymmetrically + with a typical non-fixed zero point. + :param quantizer: Quantizer to use in WeightCompression algorithm. + :param ratio: the ratio between primary and backup precisions (e.g. 0.9 means 90% of layers quantized to INT4 + and the rest to backup_mode).
+ :param group_size: number of weights (e.g. 128) in the channel dimension + that share quantization parameters (scale). The value -1 means no grouping. + :param ignored_scope: An ignored scope that defines the list of model control + flow graph nodes to be ignored during quantization. + :param all_layers: Indicates whether embeddings and last MatMul layers should be compressed to a primary + precision. By default, the backup precision is assigned for the embeddings and last MatMul layers. + :param subset_size: Number of data samples to calculate activation statistics used for assigning different + quantization precision. + :param awq: determines whether to use the modified AWQ algorithm. + :param scale_estimation: determines whether to use scale estimation for 4-bit layers. + :param gptq: determines whether to use the GPTQ algorithm. + :param lora_correction: determines whether to use the LoRA Correction algorithm. + :param backup_mode: Defines a backup mode for mixed-precision weight compression. + NONE stands for original floating-point precision of the model weights. + In this mode, weights are retained in their original precision without any quantization. + INT8_SYM stands for 8-bit integer symmetric quantization without zero point. + INT8_ASYM stands for 8-bit integer asymmetric quantization with a typical non-fixed zero point. + :param sensitivity_metric: The sensitivity metric for assigning quantization precision to layers. In order to + preserve the accuracy of the model, the more sensitive layers receive a higher precision. + :param compression_format: Describes the format in which the model is saved after weight compression. + :param advanced_parameters: advanced parameters for algorithms in the compression pipeline. + """ + self._quantizer = quantizer + + self._mode = mode + self._awq = awq + self._gptq = gptq + self._scale_estimation = scale_estimation + self._subset_size = subset_size + self._advanced_parameters = advanced_parameters + self._lora_correction = lora_correction + self._ratio = ratio + self._group_size = group_size + self._all_layers = all_layers + self._backup_mode = backup_mode + self._sensitivity_metric = sensitivity_metric + self._compression_format = compression_format + + self._algo = OriginalWeightCompression( + mode=self._mode, + ratio=self._ratio, + group_size=self._group_size, + ignored_scope=ignored_scope, + all_layers=self._all_layers, + sensitivity_metric=self._sensitivity_metric, + awq=self._awq, + subset_size=self._subset_size, + scale_estimation=self._scale_estimation, + gptq=self._gptq, + lora_correction=self._lora_correction, + backup_mode=self._backup_mode, + compression_format=self._compression_format, + advanced_parameters=self._advanced_parameters, + ) + + def available_backends(self) -> list[BackendType]: + return self._algo.available_backends() + + def apply( + self, + model: torch.fx.GraphModule, + graph: NNCFGraph, + statistic_points: Optional[StatisticPointsContainer] = None, + dataset: Optional[Dataset] = None, + ) -> torch.fx.GraphModule: + self._algo.set_backend_entity(model) + + all_weight_params, ratio_defining_params, skipped_weight_params = ( + self._quantizer.get_weight_compression_parameters(model, graph) + ) + + return self._algo.apply_with_parameters( + model, + graph, + dataset, + statistic_points, + all_weight_params, + ratio_defining_params, + skipped_weight_params, + ) + + def get_statistic_points( + self, + model: torch.fx.GraphModule, + graph: NNCFGraph, + nodes_and_port_ids: Iterable[tuple[NNCFNode, int]],
+ ) -> StatisticPointsContainer: + """ + Returns statistic points, for which StatisticsCollector should collect statistics. + + :param model: Model for statistics collection. + :param graph: Model graph. + :param nodes_and_port_ids: Nodes and port ids for which statistics should be collected. + :return: Statistic points, for which StatisticsCollector should collect statistics. + """ + return self._algo.get_statistic_points(model, graph, nodes_and_port_ids) diff --git a/src/nncf/experimental/torch/fx/__init__.py b/src/nncf/experimental/torch/fx/__init__.py index 2ecdde60840..86cd9709f6b 100644 --- a/src/nncf/experimental/torch/fx/__init__.py +++ b/src/nncf/experimental/torch/fx/__init__.py @@ -9,5 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nncf.experimental.torch.fx.quantization.quantize_pt2e import compress_pt2e as compress_pt2e from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e as quantize_pt2e from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer as OpenVINOQuantizer diff --git a/src/nncf/experimental/torch/fx/quantization/quantize_pt2e.py b/src/nncf/experimental/torch/fx/quantization/quantize_pt2e.py index 3f0b3186310..b493a9461af 100644 --- a/src/nncf/experimental/torch/fx/quantization/quantize_pt2e.py +++ b/src/nncf/experimental/torch/fx/quantization/quantize_pt2e.py @@ -22,11 +22,14 @@ from torch.fx.passes.infra.pass_manager import PassManager import nncf +from nncf import AdvancedCompressionParameters from nncf import Dataset +from nncf import SensitivityMetric from nncf.common.factory import NNCFGraphFactory from nncf.common.logging import nncf_logger from nncf.common.utils.api_marker import api from nncf.experimental.quantization.algorithms.post_training.algorithm import ExperimentalPostTrainingQuantization +from nncf.experimental.quantization.algorithms.weight_compression.algorithm import WeightsCompression from nncf.experimental.torch.fx.constant_folding import constant_fold from nncf.experimental.torch.fx.quantization.quantizer.openvino_adapter import OpenVINOQuantizerAdapter from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer @@ -36,6 +39,7 @@ from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters +from nncf.quantization.algorithms.weight_compression.algorithm import get_weight_compression_configuration from nncf.quantization.range_estimator import RangeEstimatorParameters @@ -157,3 +161,85 @@ def _quant_node_constraint(n: torch.fx.Node) -> bool: related to quantization """ return n.op == "call_function" and n.target in QUANTIZE_NODE_TARGETS + + +@api(canonical_alias="nncf.experimental.torch.fx.compress_pt2e") +def compress_pt2e( + model: torch.fx.GraphModule, + quantizer: Quantizer, + dataset: Optional[nncf.Dataset] = None, + awq: bool = False, + scale_estimation: bool = False, + gptq: bool = False, + lora_correction: bool = False, + subset_size: int = 128, + ratio: float = 1.0, + sensitivity_metric: Optional[SensitivityMetric] = None, + advanced_parameters: Optional[AdvancedCompressionParameters] = None, +) -> torch.fx.GraphModule: + """ + Applies weight compression to the provided torch.fx.GraphModule model + using the provided torch.ao quantizer.
+ + :param model: A torch.fx.GraphModule instance to be compressed. + :param quantizer: Torch ao quantizer to annotate nodes in the graph with quantization setups + to convey the desired way of quantization. + :param dataset: A representative dataset for the + calibration process. + :param awq: Determines whether to use the modified AWQ algorithm. + :param scale_estimation: Determines whether to use scale estimation for 4-bit layers. + :param gptq: Determines whether to use the GPTQ algorithm. + :param lora_correction: Determines whether to use the LoRA Correction algorithm. + :param subset_size: Number of data samples to calculate activation statistics used for assigning different + quantization precision. + :param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4 + and the rest to INT8_ASYM). + :param sensitivity_metric: The sensitivity metric for assigning quantization precision to layers. In order to + preserve the accuracy of the model, the more sensitive layers receive a higher precision. + :param advanced_parameters: Advanced parameters for algorithms in the compression pipeline. + """ + if isinstance(quantizer, OpenVINOQuantizer) or hasattr(quantizer, "get_nncf_weight_compression_parameters"): + quantizer = OpenVINOQuantizerAdapter(quantizer) + compression_format = nncf.CompressionFormat.DQ + else: + # TODO: Support third-party quantizers here. + msg = "Only OpenVINO Quantizer is supported currently." + raise nncf.InternalError(msg) + + wc_config = quantizer.get_weight_compression_config() + + mode = wc_config.get("mode", None) + group_size = wc_config.get("group_size", 128) + all_layers = wc_config.get("all_layers", False) + backup_mode = wc_config.get("backup_mode", nncf.BackupMode.INT8_ASYM) + ignored_scope = nncf.IgnoredScope() # This is already defined in the quantizer object + + weight_compression_configuration = get_weight_compression_configuration( + mode, + dataset, + ratio, + group_size, + all_layers, + awq, + scale_estimation, + gptq, + lora_correction, + ignored_scope, + sensitivity_metric, + backup_mode, + advanced_parameters, + ) + + quantization_algorithm = WeightsCompression( + quantizer=quantizer, + subset_size=subset_size, + compression_format=compression_format, + **weight_compression_configuration, + ) + + # Here the model is annotated + transformed_model = quantizer.transform_prior_quantization(model) + nncf_graph = NNCFGraphFactory.create(transformed_model) + quantized_model = quantization_algorithm.apply(transformed_model, nncf_graph, dataset=dataset) + quantized_model = torch.fx.GraphModule(quantized_model, graph=quantized_model.graph) + return quantized_model diff --git a/src/nncf/experimental/torch/fx/quantization/quantizer/openvino_adapter.py b/src/nncf/experimental/torch/fx/quantization/quantizer/openvino_adapter.py index 2283d9d9dbb..5b4bd321780 100644 --- a/src/nncf/experimental/torch/fx/quantization/quantizer/openvino_adapter.py +++ b/src/nncf/experimental/torch/fx/quantization/quantizer/openvino_adapter.py @@ -9,12 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License.
+from typing import Any + import torch.fx from nncf.common.graph.graph import NNCFGraph from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup from nncf.experimental.quantization.quantizer import Quantizer from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer +from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters class OpenVINOQuantizerAdapter(Quantizer): @@ -30,3 +33,17 @@ def transform_prior_quantization(self, model: torch.fx.GraphModule) -> torch.fx. def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup: return self._quantizer.get_nncf_quantization_setup(model, nncf_graph) + + def get_weight_compression_parameters( + self, + model: torch.fx.GraphModule, + nncf_graph: NNCFGraph, + ) -> tuple[ + list[WeightCompressionParameters], + list[WeightCompressionParameters], + list[WeightCompressionParameters], + ]: + return self._quantizer.get_nncf_weight_compression_parameters(model, nncf_graph) + + def get_weight_compression_config(self) -> dict[str, Any]: + return self._quantizer.weight_compression_configuration diff --git a/src/nncf/quantization/algorithms/weight_compression/algorithm.py b/src/nncf/quantization/algorithms/weight_compression/algorithm.py index 50a5399a5c8..784f2c1ce91 100644 --- a/src/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/src/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -103,7 +103,7 @@ def get_weight_compression_configuration( group_size = 128 return { - "mode": mode, + "mode": mode if isinstance(mode, nncf.CompressWeightsMode) else nncf.CompressWeightsMode(mode), "ratio": ratio or 1, "group_size": group_size, "all_layers": all_layers or False, @@ -505,6 +505,23 @@ def _get_ratio_defining_params( return ratio_defining_params + def _get_backup_config(self, weight_dtype: TensorDataType) -> Optional[WeightCompressionConfig]: + """ + Returns the backup weight compression configuration based on the algorithm's backup mode. + + :param weight_dtype: Data type of the weight tensor. + :return: A WeightCompressionConfig object for the backup precision, or None if backup is + disabled or unsupported. + """ + if self._backup_mode == BackupMode.NONE: + return None + mode = ( + CompressWeightsMode.INT8_ASYM if self._backup_mode == BackupMode.INT8_ASYM else CompressWeightsMode.INT8_SYM + ) + if not self.is_weight_compression_supported(weight_dtype, mode): + return None + return WeightCompressionConfig(mode=mode) + def _get_primary_config(self, group_size: int) -> WeightCompressionConfig: codebook_values = None @@ -525,7 +542,6 @@ def _set_weight_compression_config( model: TModel, graph: NNCFGraph, statistics_points: StatisticPointsContainer, - group_size_values: dict[str, int], ) -> None: """ Sets the appropriate compression configuration for weights based on some criteria. @@ -541,12 +557,12 @@ primary_precision_weight_params = self._mixed_precision_algo.apply( model, graph, statistics_points, weight_params=ratio_defining_params ) - else: - primary_precision_weight_params = ratio_defining_params - - for weight_param in primary_precision_weight_params: - weight_param.compression_config = self._get_primary_config(group_size_values[weight_param.weight_name]) - + # ratio_defining_params are all in primary precision.
Update the parameters + # which need to be set to the backup precision. + for weight_param in ratio_defining_params: + if weight_param in primary_precision_weight_params: + continue + weight_param.compression_config = self._get_backup_config(weight_param.weight_dtype) # Check if group size is valid for each weight in ratio_defining_params failed_nodes = [] for w_params in ratio_defining_params: @@ -783,27 +799,66 @@ return is_supported_dtype and not no_bit_reduction + def _collect_statistics_and_statistic_points( + self, + model: TModel, + graph: NNCFGraph, + statistic_points: StatisticPointsContainer, + dataset: Dataset, + ratio_defining_params: list[WeightCompressionParameters], + all_weight_params: list[WeightCompressionParameters], + ) -> tuple[dict[str, WCTensorStatistic], StatisticPointsContainer]: + """ + Collects and computes statistics required for weight compression. + + :param model: Backend-specific model instance. + :param graph: Corresponding NNCFGraph of the model. + :param statistic_points: Container with pre-collected statistics, if available. + :param dataset: Dataset used to collect statistics when statistic_points is not provided. + :param ratio_defining_params: List of parameters defining compression ratios. + :param all_weight_params: List of all weight compression parameters. + :return: A tuple containing collected statistics for weight compression and the updated statistic_points. + """ + if not dataset or not (self._data_aware_mixed_precision or self._data_aware_compression): + return None, statistic_points + weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params + matmul_nodes_to_compress = [ + wp.node_with_weight + for wp in weight_params + if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes + ] + matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(matmul_nodes_to_compress, graph) + if statistic_points is None: + statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys()) + statistic_points = self._collect_statistics(dataset, graph, model, statistic_points) + statistics = self._get_statistics_for_weights_compression(matmul_input_to_output_nodes_map, statistic_points) + return statistics, statistic_points + def get_weight_compression_parameters( self, model: TModel, graph: NNCFGraph, - statistic_points: Optional[StatisticPointsContainer] = None, - dataset: Optional[Dataset] = None, - ) -> tuple[list[WeightCompressionParameters], Optional[dict[str, WCTensorStatistic]]]: + ) -> tuple[ + list[WeightCompressionParameters], + list[WeightCompressionParameters], + list[WeightCompressionParameters], + ]: """ - Generates a list of weight compression parameters based on the Weight Compression algorithm - configuration. Determines the appropriate quantization parameters for each node eligible for - weight compression. Also, Generates a mapping of target node names to the collected statistics - based on the provided statistic_points. If statistic_points is None, collects required - compression statistics on the given dataset. + This function does the following: + + * Generates a list of weight compression parameters based on the algorithm configuration. + * Determines the appropriate quantization parameters for each node eligible for weight compression. + * Generates a subset of parameters that can be compressed in both primary and backup precisions, + called ratio-defining parameters. All ratio-defining parameters are set to the primary precision.
+ * Generates a subset of parameters that will not be compressed, based on the ignored scope or + compression configuration restrictions. :param model: Backend-specific input model. :param graph: NNCFGraph instance. - :param statistic_points: Optional pre-collected statistic points. - :param dataset: Optional dataset for statistics collection. - :return: A tuple consisting of a list of weight compression parameters, based on the Weight - Compression algorithm configuration, and a mapping of target node names to the - collected statistics. + :return: A tuple consisting of a list of weight compression parameters that can be compressed, + a list of ratio-defining parameters, which is a subset of compressible weight parameters + that are allowed to be set to mixed precisions, and a list of weight compression parameters + that cannot be compressed. """ nodes_to_compress = self.get_nodes_to_compress(graph) @@ -828,8 +883,8 @@ weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph) weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph) reduction_axes = self._backend_entity.get_reduction_axes(node, weight_port_id, graph) - wc_config = None + + if is_target_node and self.is_weight_compression_supported(weight_dtype, self._mode): if ( self._group_size != -1 @@ -848,14 +903,7 @@ f"node name: {node.node_name}. The node will be in {self._backup_mode} mode." ) - if self._backup_mode != BackupMode.NONE: - mode = ( - CompressWeightsMode.INT8_ASYM - if self._backup_mode == BackupMode.INT8_ASYM - else CompressWeightsMode.INT8_SYM - ) - if self.is_weight_compression_supported(weight_dtype, mode): - wc_config = WeightCompressionConfig(mode=mode) + wc_config = self._get_backup_config(weight_dtype) weight_params = WeightCompressionParameters( weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config @@ -883,37 +931,11 @@ else: group_size_values = {w_params.weight_name: self._group_size for w_params in ratio_defining_params} - # Collect statistics for the weights compression - statistics = None - if (self._data_aware_mixed_precision or self._data_aware_compression) and dataset: - weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params - matmul_nodes_to_compress = [ - wp.node_with_weight - for wp in weight_params - if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes - ] - matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map( - matmul_nodes_to_compress, graph - ) - if statistic_points is None: - statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys()) - statistic_points = self._collect_statistics(dataset, graph, model, statistic_points) - statistics = self._get_statistics_for_weights_compression( - matmul_input_to_output_nodes_map, statistic_points - ) - - # Set weight compression configuration - self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points, group_size_values) - - # Print statistics - nncf_logger.info( - self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params) - ) - - # Filter all_weight_params and by excluding nodes that should remain in their original floating-point precision - all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params)) + # Set
each ratio-defining parameter to the primary config + for weight_param in ratio_defining_params: + weight_param.compression_config = self._get_primary_config(group_size_values[weight_param.weight_name]) - return all_weight_params, statistics + return all_weight_params, ratio_defining_params, skipped_weight_params def apply( self, @@ -925,7 +947,59 @@ self.set_backend_entity(model) # Get processed weight compression parameters ready for compression - all_weight_params, statistics = self.get_weight_compression_parameters(model, graph, statistic_points, dataset) + all_weight_params, ratio_defining_params, skipped_weight_params = self.get_weight_compression_parameters( + model, graph + ) + return self.apply_with_parameters( + model, + graph, + dataset, + statistic_points, + all_weight_params, + ratio_defining_params, + skipped_weight_params, + ) + + def apply_with_parameters( + self, + model: TModel, + graph: NNCFGraph, + dataset: Dataset, + statistic_points: StatisticPointsContainer, + all_weight_params: list[WeightCompressionParameters], + ratio_defining_params: list[WeightCompressionParameters], + skipped_weight_params: list[WeightCompressionParameters], + ) -> TModel: + """ + Applies the Weight Compression algorithm using precomputed parameters and optional + algorithms (AWQ, GPTQ, scale estimation, LoRA correction). The method collects + statistics, configures the weight compression parameters for the mixed-precision algorithm, + and performs the model transformation with the appropriate decompression operations. + + :param model: Backend-specific model to be compressed. + :param graph: NNCFGraph instance. + :param dataset: Dataset to collect statistics. + :param statistic_points: Statistics points object. + :param all_weight_params: List of all weight parameters. + :param ratio_defining_params: Subset of all_weight_params that determine mixed-precision ratios. + :param skipped_weight_params: List of parameters corresponding to weights intentionally skipped + from compression (e.g., due to ignored scopes or group size adjustments). + :return: Transformed model with compressed weights and inserted backend-specific decompressor. + """ + # Collect statistics for the weights compression + statistics, statistic_points = self._collect_statistics_and_statistic_points( + model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params + ) + # Set weight compression configuration + self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points) + + # Print statistics + nncf_logger.info( + self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params) + ) + + # Filter all_weight_params by excluding nodes that should remain in their original floating-point precision + all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params)) if self._awq: model = self.awq_algo.apply(model, graph, all_weight_params, statistics, self._backend_entity) @@ -1135,7 +1209,7 @@ def _get_statistics_for_weights_compression( :param matmul_input_to_output_nodes_map: A mapping from activation node and a port id to corresponding matmul nodes which accept this activation as an input. - :param statistic_points: Statistic points object. + :param statistic_points: Statistic points. :return: Collected statistics. """ # For each node we store statistics in a WCTensorStatistics data-class.
It contains the following fields: diff --git a/tests/executorch/__init__.py b/tests/executorch/__init__.py new file mode 100644 index 00000000000..e5a42efc0ef --- /dev/null +++ b/tests/executorch/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/executorch/conftest.py b/tests/executorch/conftest.py new file mode 100644 index 00000000000..ce9b7b42661 --- /dev/null +++ b/tests/executorch/conftest.py @@ -0,0 +1,11 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +pytest_plugins = ["tests.torch2.conftest"] diff --git a/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False.dot b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False.dot new file mode 100644 index 00000000000..0a9a27fd85b --- /dev/null +++ b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False.dot @@ -0,0 +1,169 @@ +strict digraph { +"0 attn_norm_weight" [id=0, type="get_attr"]; +"1 mlp_norm_weight" [id=1, type="get_attr"]; +"2 q_proj_weight_updated_constant0" [id=2, type="get_attr"]; +"3 symmetric_weights_decompressor_q_proj_weight_0" [id=3, type="call_module"]; +"4 k_proj_weight_updated_constant0" [id=4, type="get_attr"]; +"5 symmetric_weights_decompressor_k_proj_weight_0" [id=5, type="call_module"]; +"6 v_proj_weight_updated_constant0" [id=6, type="get_attr"]; +"7 symmetric_weights_decompressor_v_proj_weight_0" [id=7, type="call_module"]; +"8 o_proj_weight_updated_constant0" [id=8, type="get_attr"]; +"9 symmetric_weights_decompressor_o_proj_weight_0" [id=9, type="call_module"]; +"10 mlp_gate_proj_weight_updated_constant0" [id=10, type="get_attr"]; +"11 symmetric_weights_decompressor_mlp_gate_proj_weight_0" [id=11, type="call_module"]; +"12 mlp_up_proj_weight_updated_constant0" [id=12, type="get_attr"]; +"13 symmetric_weights_decompressor_mlp_up_proj_weight_0" [id=13, type="call_module"]; +"14 mlp_down_proj_weight_updated_constant0" [id=14, type="get_attr"]; +"15 asymmetric_weights_decompressor_mlp_down_proj_weight_0" [id=15, type="call_module"]; +"16 rope_cos" [id=16, type="get_attr"]; +"17 rope_sin" [id=17, type="get_attr"]; +"18 x_embed" [id=18, type=input]; +"19 arange" [id=19, type=arange]; +"20 _assert_tensor_metadata_default" [id=20, type="_assert_tensor_metadata"]; +"21 to" [id=21, type=to]; +"22 pow_1" [id=22, type=pow]; +"23 mean" 
[id=23, type=mean]; +"24 add" [id=24, type=add]; +"25 rsqrt" [id=25, type=rsqrt]; +"26 mul" [id=26, type=mul]; +"27 _assert_tensor_metadata_default_1" [id=27, type="_assert_tensor_metadata"]; +"28 to_1" [id=28, type=to]; +"29 mul_1" [id=29, type=mul]; +"30 linear" [id=30, type=linear]; +"31 view" [id=31, type=view]; +"32 transpose" [id=32, type=transpose]; +"33 linear_1" [id=33, type=linear]; +"34 view_1" [id=34, type=view]; +"35 transpose_1" [id=35, type=transpose]; +"36 linear_2" [id=36, type=linear]; +"37 view_2" [id=37, type=view]; +"38 transpose_2" [id=38, type=transpose]; +"39 index" [id=39, type=index]; +"40 index_1" [id=40, type=index]; +"41 mul_2" [id=41, type=mul]; +"42 slice_1" [id=42, type=slice]; +"43 slice_2" [id=43, type=slice]; +"44 neg" [id=44, type=neg]; +"45 cat" [id=45, type=cat]; +"46 mul_3" [id=46, type=mul]; +"47 add_1" [id=47, type=add]; +"48 mul_4" [id=48, type=mul]; +"49 slice_3" [id=49, type=slice]; +"50 slice_4" [id=50, type=slice]; +"51 neg_1" [id=51, type=neg]; +"52 cat_1" [id=52, type=cat]; +"53 mul_5" [id=53, type=mul]; +"54 add_2" [id=54, type=add]; +"55 scaled_dot_product_attention" [id=55, type="scaled_dot_product_attention"]; +"56 transpose_3" [id=56, type=transpose]; +"57 view_3" [id=57, type=view]; +"58 linear_3" [id=58, type=linear]; +"59 add_3" [id=59, type=add]; +"60 _assert_tensor_metadata_default_2" [id=60, type="_assert_tensor_metadata"]; +"61 to_2" [id=61, type=to]; +"62 pow_2" [id=62, type=pow]; +"63 mean_1" [id=63, type=mean]; +"64 add_4" [id=64, type=add]; +"65 rsqrt_1" [id=65, type=rsqrt]; +"66 mul_6" [id=66, type=mul]; +"67 _assert_tensor_metadata_default_3" [id=67, type="_assert_tensor_metadata"]; +"68 to_3" [id=68, type=to]; +"69 mul_7" [id=69, type=mul]; +"70 linear_4" [id=70, type=linear]; +"71 silu" [id=71, type=silu]; +"72 linear_5" [id=72, type=linear]; +"73 mul_8" [id=73, type=mul]; +"74 linear_6" [id=74, type=linear]; +"75 add_5" [id=75, type=add]; +"76 output" [id=76, type=output]; +"0 attn_norm_weight" -> "29 mul_1" [style=solid, label="(64,)"]; +"1 mlp_norm_weight" -> "69 mul_7" [style=solid, label="(64,)"]; +"2 q_proj_weight_updated_constant0" -> "3 symmetric_weights_decompressor_q_proj_weight_0" [style=solid, label="(2048, 1)"]; +"3 symmetric_weights_decompressor_q_proj_weight_0" -> "30 linear" [style=solid, label="(64, 64)"]; +"4 k_proj_weight_updated_constant0" -> "5 symmetric_weights_decompressor_k_proj_weight_0" [style=solid, label="(2048, 1)"]; +"5 symmetric_weights_decompressor_k_proj_weight_0" -> "33 linear_1" [style=solid, label="(64, 64)"]; +"6 v_proj_weight_updated_constant0" -> "7 symmetric_weights_decompressor_v_proj_weight_0" [style=solid, label="(2048, 1)"]; +"7 symmetric_weights_decompressor_v_proj_weight_0" -> "36 linear_2" [style=solid, label="(64, 64)"]; +"8 o_proj_weight_updated_constant0" -> "9 symmetric_weights_decompressor_o_proj_weight_0" [style=solid, label="(2048, 1)"]; +"9 symmetric_weights_decompressor_o_proj_weight_0" -> "58 linear_3" [style=solid, label="(64, 64)"]; +"10 mlp_gate_proj_weight_updated_constant0" -> "11 symmetric_weights_decompressor_mlp_gate_proj_weight_0" [style=solid, label="(4096, 1)"]; +"11 symmetric_weights_decompressor_mlp_gate_proj_weight_0" -> "70 linear_4" [style=solid, label="(128, 64)"]; +"12 mlp_up_proj_weight_updated_constant0" -> "13 symmetric_weights_decompressor_mlp_up_proj_weight_0" [style=solid, label="(4096, 1)"]; +"13 symmetric_weights_decompressor_mlp_up_proj_weight_0" -> "72 linear_5" [style=solid, label="(128, 64)"]; +"14 mlp_down_proj_weight_updated_constant0" 
-> "15 asymmetric_weights_decompressor_mlp_down_proj_weight_0" [style=solid, label="(64, 128)"]; +"15 asymmetric_weights_decompressor_mlp_down_proj_weight_0" -> "74 linear_6" [style=solid, label="(64, 128)"]; +"16 rope_cos" -> "39 index" [style=solid, label="(1, 1, 128, 16)"]; +"17 rope_sin" -> "40 index_1" [style=solid, label="(1, 1, 128, 16)"]; +"18 x_embed" -> "20 _assert_tensor_metadata_default" [style=solid, label="(1, 3, 64)"]; +"18 x_embed" -> "21 to" [style=solid, label="(1, 3, 64)"]; +"18 x_embed" -> "59 add_3" [style=solid, label="(1, 3, 64)"]; +"19 arange" -> "39 index" [style=solid, label="(3,)"]; +"19 arange" -> "40 index_1" [style=solid, label="(3,)"]; +"21 to" -> "22 pow_1" [style=solid, label="(1, 3, 64)"]; +"21 to" -> "26 mul" [style=solid, label="(1, 3, 64)"]; +"22 pow_1" -> "23 mean" [style=solid, label="(1, 3, 64)"]; +"23 mean" -> "24 add" [style=solid, label="(1, 3, 1)"]; +"24 add" -> "25 rsqrt" [style=solid, label="(1, 3, 1)"]; +"25 rsqrt" -> "26 mul" [style=solid, label="(1, 3, 1)"]; +"26 mul" -> "27 _assert_tensor_metadata_default_1" [style=solid, label="(1, 3, 64)"]; +"26 mul" -> "28 to_1" [style=solid, label="(1, 3, 64)"]; +"28 to_1" -> "29 mul_1" [style=solid, label="(1, 3, 64)"]; +"29 mul_1" -> "30 linear" [style=solid, label="(1, 3, 64)"]; +"29 mul_1" -> "33 linear_1" [style=solid, label="(1, 3, 64)"]; +"29 mul_1" -> "36 linear_2" [style=solid, label="(1, 3, 64)"]; +"30 linear" -> "31 view" [style=solid, label="(1, 3, 64)"]; +"31 view" -> "32 transpose" [style=solid, label="(1, 3, 4, 16)"]; +"32 transpose" -> "41 mul_2" [style=solid, label="(1, 4, 3, 16)"]; +"32 transpose" -> "42 slice_1" [style=solid, label="(1, 4, 3, 16)"]; +"32 transpose" -> "43 slice_2" [style=solid, label="(1, 4, 3, 16)"]; +"33 linear_1" -> "34 view_1" [style=solid, label="(1, 3, 64)"]; +"34 view_1" -> "35 transpose_1" [style=solid, label="(1, 3, 4, 16)"]; +"35 transpose_1" -> "48 mul_4" [style=solid, label="(1, 4, 3, 16)"]; +"35 transpose_1" -> "49 slice_3" [style=solid, label="(1, 4, 3, 16)"]; +"35 transpose_1" -> "50 slice_4" [style=solid, label="(1, 4, 3, 16)"]; +"36 linear_2" -> "37 view_2" [style=solid, label="(1, 3, 64)"]; +"37 view_2" -> "38 transpose_2" [style=solid, label="(1, 3, 4, 16)"]; +"38 transpose_2" -> "55 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"39 index" -> "41 mul_2" [style=solid, label="(1, 1, 3, 16)"]; +"39 index" -> "48 mul_4" [style=solid, label="(1, 1, 3, 16)"]; +"40 index_1" -> "46 mul_3" [style=solid, label="(1, 1, 3, 16)"]; +"40 index_1" -> "53 mul_5" [style=solid, label="(1, 1, 3, 16)"]; +"41 mul_2" -> "47 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"42 slice_1" -> "45 cat" [style=solid, label="(1, 4, 3, 8)"]; +"43 slice_2" -> "44 neg" [style=solid, label="(1, 4, 3, 8)"]; +"44 neg" -> "45 cat" [style=solid, label="(1, 4, 3, 8)"]; +"45 cat" -> "46 mul_3" [style=solid, label="(1, 4, 3, 16)"]; +"46 mul_3" -> "47 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"47 add_1" -> "55 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"48 mul_4" -> "54 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"49 slice_3" -> "52 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"50 slice_4" -> "51 neg_1" [style=solid, label="(1, 4, 3, 8)"]; +"51 neg_1" -> "52 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"52 cat_1" -> "53 mul_5" [style=solid, label="(1, 4, 3, 16)"]; +"53 mul_5" -> "54 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"54 add_2" -> "55 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"55 
scaled_dot_product_attention" -> "56 transpose_3" [style=solid, label="(1, 4, 3, 16)"]; +"56 transpose_3" -> "57 view_3" [style=solid, label="(1, 3, 4, 16)"]; +"57 view_3" -> "58 linear_3" [style=solid, label="(1, 3, 64)"]; +"58 linear_3" -> "59 add_3" [style=solid, label="(1, 3, 64)"]; +"59 add_3" -> "60 _assert_tensor_metadata_default_2" [style=solid, label="(1, 3, 64)"]; +"59 add_3" -> "61 to_2" [style=solid, label="(1, 3, 64)"]; +"61 to_2" -> "62 pow_2" [style=solid, label="(1, 3, 64)"]; +"61 to_2" -> "66 mul_6" [style=solid, label="(1, 3, 64)"]; +"61 to_2" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"62 pow_2" -> "63 mean_1" [style=solid, label="(1, 3, 64)"]; +"63 mean_1" -> "64 add_4" [style=solid, label="(1, 3, 1)"]; +"64 add_4" -> "65 rsqrt_1" [style=solid, label="(1, 3, 1)"]; +"65 rsqrt_1" -> "66 mul_6" [style=solid, label="(1, 3, 1)"]; +"66 mul_6" -> "67 _assert_tensor_metadata_default_3" [style=solid, label="(1, 3, 64)"]; +"66 mul_6" -> "68 to_3" [style=solid, label="(1, 3, 64)"]; +"68 to_3" -> "69 mul_7" [style=solid, label="(1, 3, 64)"]; +"69 mul_7" -> "70 linear_4" [style=solid, label="(1, 3, 64)"]; +"69 mul_7" -> "72 linear_5" [style=solid, label="(1, 3, 64)"]; +"70 linear_4" -> "71 silu" [style=solid, label="(1, 3, 128)"]; +"71 silu" -> "73 mul_8" [style=solid, label="(1, 3, 128)"]; +"72 linear_5" -> "73 mul_8" [style=solid, label="(1, 3, 128)"]; +"73 mul_8" -> "74 linear_6" [style=solid, label="(1, 3, 128)"]; +"74 linear_6" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"75 add_5" -> "76 output" [style=solid, label="(1, 3, 64)"]; +} diff --git a/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True.dot b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True.dot new file mode 100644 index 00000000000..254abcb9dc0 --- /dev/null +++ b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True.dot @@ -0,0 +1,169 @@ +strict digraph { +"0 attn_norm_weight" [id=0, type="get_attr"]; +"1 mlp_norm_weight" [id=1, type="get_attr"]; +"2 q_proj_weight_updated_constant0" [id=2, type="get_attr"]; +"3 symmetric_weights_decompressor_q_proj_weight_0" [id=3, type="call_module"]; +"4 k_proj_weight_updated_constant0" [id=4, type="get_attr"]; +"5 symmetric_weights_decompressor_k_proj_weight_0" [id=5, type="call_module"]; +"6 v_proj_weight_updated_constant0" [id=6, type="get_attr"]; +"7 symmetric_weights_decompressor_v_proj_weight_0" [id=7, type="call_module"]; +"8 o_proj_weight_updated_constant0" [id=8, type="get_attr"]; +"9 symmetric_weights_decompressor_o_proj_weight_0" [id=9, type="call_module"]; +"10 mlp_gate_proj_weight_updated_constant0" [id=10, type="get_attr"]; +"11 symmetric_weights_decompressor_mlp_gate_proj_weight_0" [id=11, type="call_module"]; +"12 mlp_up_proj_weight_updated_constant0" [id=12, type="get_attr"]; +"13 symmetric_weights_decompressor_mlp_up_proj_weight_0" [id=13, type="call_module"]; +"14 mlp_down_proj_weight_updated_constant0" [id=14, type="get_attr"]; +"15 symmetric_weights_decompressor_mlp_down_proj_weight_0" [id=15, type="call_module"]; +"16 rope_cos" [id=16, type="get_attr"]; +"17 rope_sin" [id=17, type="get_attr"]; +"18 x_embed" [id=18, type=input]; +"19 arange" [id=19, type=arange]; +"20 _assert_tensor_metadata_default" [id=20, type="_assert_tensor_metadata"]; +"21 to" [id=21, type=to]; +"22 pow_1" [id=22, type=pow]; +"23 mean" [id=23, type=mean]; +"24 add" [id=24, type=add]; +"25 
rsqrt" [id=25, type=rsqrt]; +"26 mul" [id=26, type=mul]; +"27 _assert_tensor_metadata_default_1" [id=27, type="_assert_tensor_metadata"]; +"28 to_1" [id=28, type=to]; +"29 mul_1" [id=29, type=mul]; +"30 linear" [id=30, type=linear]; +"31 view" [id=31, type=view]; +"32 transpose" [id=32, type=transpose]; +"33 linear_1" [id=33, type=linear]; +"34 view_1" [id=34, type=view]; +"35 transpose_1" [id=35, type=transpose]; +"36 linear_2" [id=36, type=linear]; +"37 view_2" [id=37, type=view]; +"38 transpose_2" [id=38, type=transpose]; +"39 index" [id=39, type=index]; +"40 index_1" [id=40, type=index]; +"41 mul_2" [id=41, type=mul]; +"42 slice_1" [id=42, type=slice]; +"43 slice_2" [id=43, type=slice]; +"44 neg" [id=44, type=neg]; +"45 cat" [id=45, type=cat]; +"46 mul_3" [id=46, type=mul]; +"47 add_1" [id=47, type=add]; +"48 mul_4" [id=48, type=mul]; +"49 slice_3" [id=49, type=slice]; +"50 slice_4" [id=50, type=slice]; +"51 neg_1" [id=51, type=neg]; +"52 cat_1" [id=52, type=cat]; +"53 mul_5" [id=53, type=mul]; +"54 add_2" [id=54, type=add]; +"55 scaled_dot_product_attention" [id=55, type="scaled_dot_product_attention"]; +"56 transpose_3" [id=56, type=transpose]; +"57 view_3" [id=57, type=view]; +"58 linear_3" [id=58, type=linear]; +"59 add_3" [id=59, type=add]; +"60 _assert_tensor_metadata_default_2" [id=60, type="_assert_tensor_metadata"]; +"61 to_2" [id=61, type=to]; +"62 pow_2" [id=62, type=pow]; +"63 mean_1" [id=63, type=mean]; +"64 add_4" [id=64, type=add]; +"65 rsqrt_1" [id=65, type=rsqrt]; +"66 mul_6" [id=66, type=mul]; +"67 _assert_tensor_metadata_default_3" [id=67, type="_assert_tensor_metadata"]; +"68 to_3" [id=68, type=to]; +"69 mul_7" [id=69, type=mul]; +"70 linear_4" [id=70, type=linear]; +"71 silu" [id=71, type=silu]; +"72 linear_5" [id=72, type=linear]; +"73 mul_8" [id=73, type=mul]; +"74 linear_6" [id=74, type=linear]; +"75 add_5" [id=75, type=add]; +"76 output" [id=76, type=output]; +"0 attn_norm_weight" -> "29 mul_1" [style=solid, label="(64,)"]; +"1 mlp_norm_weight" -> "69 mul_7" [style=solid, label="(64,)"]; +"2 q_proj_weight_updated_constant0" -> "3 symmetric_weights_decompressor_q_proj_weight_0" [style=solid, label="(2048, 1)"]; +"3 symmetric_weights_decompressor_q_proj_weight_0" -> "30 linear" [style=solid, label="(64, 64)"]; +"4 k_proj_weight_updated_constant0" -> "5 symmetric_weights_decompressor_k_proj_weight_0" [style=solid, label="(2048, 1)"]; +"5 symmetric_weights_decompressor_k_proj_weight_0" -> "33 linear_1" [style=solid, label="(64, 64)"]; +"6 v_proj_weight_updated_constant0" -> "7 symmetric_weights_decompressor_v_proj_weight_0" [style=solid, label="(2048, 1)"]; +"7 symmetric_weights_decompressor_v_proj_weight_0" -> "36 linear_2" [style=solid, label="(64, 64)"]; +"8 o_proj_weight_updated_constant0" -> "9 symmetric_weights_decompressor_o_proj_weight_0" [style=solid, label="(2048, 1)"]; +"9 symmetric_weights_decompressor_o_proj_weight_0" -> "58 linear_3" [style=solid, label="(64, 64)"]; +"10 mlp_gate_proj_weight_updated_constant0" -> "11 symmetric_weights_decompressor_mlp_gate_proj_weight_0" [style=solid, label="(4096, 1)"]; +"11 symmetric_weights_decompressor_mlp_gate_proj_weight_0" -> "70 linear_4" [style=solid, label="(128, 64)"]; +"12 mlp_up_proj_weight_updated_constant0" -> "13 symmetric_weights_decompressor_mlp_up_proj_weight_0" [style=solid, label="(4096, 1)"]; +"13 symmetric_weights_decompressor_mlp_up_proj_weight_0" -> "72 linear_5" [style=solid, label="(128, 64)"]; +"14 mlp_down_proj_weight_updated_constant0" -> "15 
symmetric_weights_decompressor_mlp_down_proj_weight_0" [style=solid, label="(4096, 1)"]; +"15 symmetric_weights_decompressor_mlp_down_proj_weight_0" -> "74 linear_6" [style=solid, label="(64, 128)"]; +"16 rope_cos" -> "39 index" [style=solid, label="(1, 1, 128, 16)"]; +"17 rope_sin" -> "40 index_1" [style=solid, label="(1, 1, 128, 16)"]; +"18 x_embed" -> "20 _assert_tensor_metadata_default" [style=solid, label="(1, 3, 64)"]; +"18 x_embed" -> "21 to" [style=solid, label="(1, 3, 64)"]; +"18 x_embed" -> "59 add_3" [style=solid, label="(1, 3, 64)"]; +"19 arange" -> "39 index" [style=solid, label="(3,)"]; +"19 arange" -> "40 index_1" [style=solid, label="(3,)"]; +"21 to" -> "22 pow_1" [style=solid, label="(1, 3, 64)"]; +"21 to" -> "26 mul" [style=solid, label="(1, 3, 64)"]; +"22 pow_1" -> "23 mean" [style=solid, label="(1, 3, 64)"]; +"23 mean" -> "24 add" [style=solid, label="(1, 3, 1)"]; +"24 add" -> "25 rsqrt" [style=solid, label="(1, 3, 1)"]; +"25 rsqrt" -> "26 mul" [style=solid, label="(1, 3, 1)"]; +"26 mul" -> "27 _assert_tensor_metadata_default_1" [style=solid, label="(1, 3, 64)"]; +"26 mul" -> "28 to_1" [style=solid, label="(1, 3, 64)"]; +"28 to_1" -> "29 mul_1" [style=solid, label="(1, 3, 64)"]; +"29 mul_1" -> "30 linear" [style=solid, label="(1, 3, 64)"]; +"29 mul_1" -> "33 linear_1" [style=solid, label="(1, 3, 64)"]; +"29 mul_1" -> "36 linear_2" [style=solid, label="(1, 3, 64)"]; +"30 linear" -> "31 view" [style=solid, label="(1, 3, 64)"]; +"31 view" -> "32 transpose" [style=solid, label="(1, 3, 4, 16)"]; +"32 transpose" -> "41 mul_2" [style=solid, label="(1, 4, 3, 16)"]; +"32 transpose" -> "42 slice_1" [style=solid, label="(1, 4, 3, 16)"]; +"32 transpose" -> "43 slice_2" [style=solid, label="(1, 4, 3, 16)"]; +"33 linear_1" -> "34 view_1" [style=solid, label="(1, 3, 64)"]; +"34 view_1" -> "35 transpose_1" [style=solid, label="(1, 3, 4, 16)"]; +"35 transpose_1" -> "48 mul_4" [style=solid, label="(1, 4, 3, 16)"]; +"35 transpose_1" -> "49 slice_3" [style=solid, label="(1, 4, 3, 16)"]; +"35 transpose_1" -> "50 slice_4" [style=solid, label="(1, 4, 3, 16)"]; +"36 linear_2" -> "37 view_2" [style=solid, label="(1, 3, 64)"]; +"37 view_2" -> "38 transpose_2" [style=solid, label="(1, 3, 4, 16)"]; +"38 transpose_2" -> "55 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"39 index" -> "41 mul_2" [style=solid, label="(1, 1, 3, 16)"]; +"39 index" -> "48 mul_4" [style=solid, label="(1, 1, 3, 16)"]; +"40 index_1" -> "46 mul_3" [style=solid, label="(1, 1, 3, 16)"]; +"40 index_1" -> "53 mul_5" [style=solid, label="(1, 1, 3, 16)"]; +"41 mul_2" -> "47 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"42 slice_1" -> "45 cat" [style=solid, label="(1, 4, 3, 8)"]; +"43 slice_2" -> "44 neg" [style=solid, label="(1, 4, 3, 8)"]; +"44 neg" -> "45 cat" [style=solid, label="(1, 4, 3, 8)"]; +"45 cat" -> "46 mul_3" [style=solid, label="(1, 4, 3, 16)"]; +"46 mul_3" -> "47 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"47 add_1" -> "55 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"48 mul_4" -> "54 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"49 slice_3" -> "52 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"50 slice_4" -> "51 neg_1" [style=solid, label="(1, 4, 3, 8)"]; +"51 neg_1" -> "52 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"52 cat_1" -> "53 mul_5" [style=solid, label="(1, 4, 3, 16)"]; +"53 mul_5" -> "54 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"54 add_2" -> "55 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"55 scaled_dot_product_attention" -> 
"56 transpose_3" [style=solid, label="(1, 4, 3, 16)"]; +"56 transpose_3" -> "57 view_3" [style=solid, label="(1, 3, 4, 16)"]; +"57 view_3" -> "58 linear_3" [style=solid, label="(1, 3, 64)"]; +"58 linear_3" -> "59 add_3" [style=solid, label="(1, 3, 64)"]; +"59 add_3" -> "60 _assert_tensor_metadata_default_2" [style=solid, label="(1, 3, 64)"]; +"59 add_3" -> "61 to_2" [style=solid, label="(1, 3, 64)"]; +"61 to_2" -> "62 pow_2" [style=solid, label="(1, 3, 64)"]; +"61 to_2" -> "66 mul_6" [style=solid, label="(1, 3, 64)"]; +"61 to_2" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"62 pow_2" -> "63 mean_1" [style=solid, label="(1, 3, 64)"]; +"63 mean_1" -> "64 add_4" [style=solid, label="(1, 3, 1)"]; +"64 add_4" -> "65 rsqrt_1" [style=solid, label="(1, 3, 1)"]; +"65 rsqrt_1" -> "66 mul_6" [style=solid, label="(1, 3, 1)"]; +"66 mul_6" -> "67 _assert_tensor_metadata_default_3" [style=solid, label="(1, 3, 64)"]; +"66 mul_6" -> "68 to_3" [style=solid, label="(1, 3, 64)"]; +"68 to_3" -> "69 mul_7" [style=solid, label="(1, 3, 64)"]; +"69 mul_7" -> "70 linear_4" [style=solid, label="(1, 3, 64)"]; +"69 mul_7" -> "72 linear_5" [style=solid, label="(1, 3, 64)"]; +"70 linear_4" -> "71 silu" [style=solid, label="(1, 3, 128)"]; +"71 silu" -> "73 mul_8" [style=solid, label="(1, 3, 128)"]; +"72 linear_5" -> "73 mul_8" [style=solid, label="(1, 3, 128)"]; +"73 mul_8" -> "74 linear_6" [style=solid, label="(1, 3, 128)"]; +"74 linear_6" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"75 add_5" -> "76 output" [style=solid, label="(1, 3, 64)"]; +} diff --git a/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False.dot b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False.dot new file mode 100644 index 00000000000..614e06a21ac --- /dev/null +++ b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False.dot @@ -0,0 +1,169 @@ +strict digraph { +"0 attn_norm_weight" [id=0, type="get_attr"]; +"1 mlp_norm_weight" [id=1, type="get_attr"]; +"2 q_proj_weight_updated_constant0" [id=2, type="get_attr"]; +"3 asymmetric_weights_decompressor_q_proj_weight_0" [id=3, type="call_module"]; +"4 k_proj_weight_updated_constant0" [id=4, type="get_attr"]; +"5 asymmetric_weights_decompressor_k_proj_weight_0" [id=5, type="call_module"]; +"6 v_proj_weight_updated_constant0" [id=6, type="get_attr"]; +"7 asymmetric_weights_decompressor_v_proj_weight_0" [id=7, type="call_module"]; +"8 o_proj_weight_updated_constant0" [id=8, type="get_attr"]; +"9 asymmetric_weights_decompressor_o_proj_weight_0" [id=9, type="call_module"]; +"10 mlp_gate_proj_weight_updated_constant0" [id=10, type="get_attr"]; +"11 asymmetric_weights_decompressor_mlp_gate_proj_weight_0" [id=11, type="call_module"]; +"12 mlp_up_proj_weight_updated_constant0" [id=12, type="get_attr"]; +"13 asymmetric_weights_decompressor_mlp_up_proj_weight_0" [id=13, type="call_module"]; +"14 mlp_down_proj_weight_updated_constant0" [id=14, type="get_attr"]; +"15 asymmetric_weights_decompressor_mlp_down_proj_weight_0" [id=15, type="call_module"]; +"16 rope_cos" [id=16, type="get_attr"]; +"17 rope_sin" [id=17, type="get_attr"]; +"18 x_embed" [id=18, type=input]; +"19 arange" [id=19, type=arange]; +"20 _assert_tensor_metadata_default" [id=20, type="_assert_tensor_metadata"]; +"21 to" [id=21, type=to]; +"22 pow_1" [id=22, type=pow]; +"23 mean" [id=23, type=mean]; +"24 add" [id=24, type=add]; +"25 rsqrt" [id=25, 
type=rsqrt]; +"26 mul" [id=26, type=mul]; +"27 _assert_tensor_metadata_default_1" [id=27, type="_assert_tensor_metadata"]; +"28 to_1" [id=28, type=to]; +"29 mul_1" [id=29, type=mul]; +"30 linear" [id=30, type=linear]; +"31 view" [id=31, type=view]; +"32 transpose" [id=32, type=transpose]; +"33 linear_1" [id=33, type=linear]; +"34 view_1" [id=34, type=view]; +"35 transpose_1" [id=35, type=transpose]; +"36 linear_2" [id=36, type=linear]; +"37 view_2" [id=37, type=view]; +"38 transpose_2" [id=38, type=transpose]; +"39 index" [id=39, type=index]; +"40 index_1" [id=40, type=index]; +"41 mul_2" [id=41, type=mul]; +"42 slice_1" [id=42, type=slice]; +"43 slice_2" [id=43, type=slice]; +"44 neg" [id=44, type=neg]; +"45 cat" [id=45, type=cat]; +"46 mul_3" [id=46, type=mul]; +"47 add_1" [id=47, type=add]; +"48 mul_4" [id=48, type=mul]; +"49 slice_3" [id=49, type=slice]; +"50 slice_4" [id=50, type=slice]; +"51 neg_1" [id=51, type=neg]; +"52 cat_1" [id=52, type=cat]; +"53 mul_5" [id=53, type=mul]; +"54 add_2" [id=54, type=add]; +"55 scaled_dot_product_attention" [id=55, type="scaled_dot_product_attention"]; +"56 transpose_3" [id=56, type=transpose]; +"57 view_3" [id=57, type=view]; +"58 linear_3" [id=58, type=linear]; +"59 add_3" [id=59, type=add]; +"60 _assert_tensor_metadata_default_2" [id=60, type="_assert_tensor_metadata"]; +"61 to_2" [id=61, type=to]; +"62 pow_2" [id=62, type=pow]; +"63 mean_1" [id=63, type=mean]; +"64 add_4" [id=64, type=add]; +"65 rsqrt_1" [id=65, type=rsqrt]; +"66 mul_6" [id=66, type=mul]; +"67 _assert_tensor_metadata_default_3" [id=67, type="_assert_tensor_metadata"]; +"68 to_3" [id=68, type=to]; +"69 mul_7" [id=69, type=mul]; +"70 linear_4" [id=70, type=linear]; +"71 silu" [id=71, type=silu]; +"72 linear_5" [id=72, type=linear]; +"73 mul_8" [id=73, type=mul]; +"74 linear_6" [id=74, type=linear]; +"75 add_5" [id=75, type=add]; +"76 output" [id=76, type=output]; +"0 attn_norm_weight" -> "29 mul_1" [style=solid, label="(64,)"]; +"1 mlp_norm_weight" -> "69 mul_7" [style=solid, label="(64,)"]; +"2 q_proj_weight_updated_constant0" -> "3 asymmetric_weights_decompressor_q_proj_weight_0" [style=solid, label="(64, 64)"]; +"3 asymmetric_weights_decompressor_q_proj_weight_0" -> "30 linear" [style=solid, label="(64, 64)"]; +"4 k_proj_weight_updated_constant0" -> "5 asymmetric_weights_decompressor_k_proj_weight_0" [style=solid, label="(64, 64)"]; +"5 asymmetric_weights_decompressor_k_proj_weight_0" -> "33 linear_1" [style=solid, label="(64, 64)"]; +"6 v_proj_weight_updated_constant0" -> "7 asymmetric_weights_decompressor_v_proj_weight_0" [style=solid, label="(64, 64)"]; +"7 asymmetric_weights_decompressor_v_proj_weight_0" -> "36 linear_2" [style=solid, label="(64, 64)"]; +"8 o_proj_weight_updated_constant0" -> "9 asymmetric_weights_decompressor_o_proj_weight_0" [style=solid, label="(64, 64)"]; +"9 asymmetric_weights_decompressor_o_proj_weight_0" -> "58 linear_3" [style=solid, label="(64, 64)"]; +"10 mlp_gate_proj_weight_updated_constant0" -> "11 asymmetric_weights_decompressor_mlp_gate_proj_weight_0" [style=solid, label="(128, 64)"]; +"11 asymmetric_weights_decompressor_mlp_gate_proj_weight_0" -> "70 linear_4" [style=solid, label="(128, 64)"]; +"12 mlp_up_proj_weight_updated_constant0" -> "13 asymmetric_weights_decompressor_mlp_up_proj_weight_0" [style=solid, label="(128, 64)"]; +"13 asymmetric_weights_decompressor_mlp_up_proj_weight_0" -> "72 linear_5" [style=solid, label="(128, 64)"]; +"14 mlp_down_proj_weight_updated_constant0" -> "15 
asymmetric_weights_decompressor_mlp_down_proj_weight_0" [style=solid, label="(64, 128)"]; +"15 asymmetric_weights_decompressor_mlp_down_proj_weight_0" -> "74 linear_6" [style=solid, label="(64, 128)"]; +"16 rope_cos" -> "39 index" [style=solid, label="(1, 1, 128, 16)"]; +"17 rope_sin" -> "40 index_1" [style=solid, label="(1, 1, 128, 16)"]; +"18 x_embed" -> "20 _assert_tensor_metadata_default" [style=solid, label="(1, 3, 64)"]; +"18 x_embed" -> "21 to" [style=solid, label="(1, 3, 64)"]; +"18 x_embed" -> "59 add_3" [style=solid, label="(1, 3, 64)"]; +"19 arange" -> "39 index" [style=solid, label="(3,)"]; +"19 arange" -> "40 index_1" [style=solid, label="(3,)"]; +"21 to" -> "22 pow_1" [style=solid, label="(1, 3, 64)"]; +"21 to" -> "26 mul" [style=solid, label="(1, 3, 64)"]; +"22 pow_1" -> "23 mean" [style=solid, label="(1, 3, 64)"]; +"23 mean" -> "24 add" [style=solid, label="(1, 3, 1)"]; +"24 add" -> "25 rsqrt" [style=solid, label="(1, 3, 1)"]; +"25 rsqrt" -> "26 mul" [style=solid, label="(1, 3, 1)"]; +"26 mul" -> "27 _assert_tensor_metadata_default_1" [style=solid, label="(1, 3, 64)"]; +"26 mul" -> "28 to_1" [style=solid, label="(1, 3, 64)"]; +"28 to_1" -> "29 mul_1" [style=solid, label="(1, 3, 64)"]; +"29 mul_1" -> "30 linear" [style=solid, label="(1, 3, 64)"]; +"29 mul_1" -> "33 linear_1" [style=solid, label="(1, 3, 64)"]; +"29 mul_1" -> "36 linear_2" [style=solid, label="(1, 3, 64)"]; +"30 linear" -> "31 view" [style=solid, label="(1, 3, 64)"]; +"31 view" -> "32 transpose" [style=solid, label="(1, 3, 4, 16)"]; +"32 transpose" -> "41 mul_2" [style=solid, label="(1, 4, 3, 16)"]; +"32 transpose" -> "42 slice_1" [style=solid, label="(1, 4, 3, 16)"]; +"32 transpose" -> "43 slice_2" [style=solid, label="(1, 4, 3, 16)"]; +"33 linear_1" -> "34 view_1" [style=solid, label="(1, 3, 64)"]; +"34 view_1" -> "35 transpose_1" [style=solid, label="(1, 3, 4, 16)"]; +"35 transpose_1" -> "48 mul_4" [style=solid, label="(1, 4, 3, 16)"]; +"35 transpose_1" -> "49 slice_3" [style=solid, label="(1, 4, 3, 16)"]; +"35 transpose_1" -> "50 slice_4" [style=solid, label="(1, 4, 3, 16)"]; +"36 linear_2" -> "37 view_2" [style=solid, label="(1, 3, 64)"]; +"37 view_2" -> "38 transpose_2" [style=solid, label="(1, 3, 4, 16)"]; +"38 transpose_2" -> "55 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"39 index" -> "41 mul_2" [style=solid, label="(1, 1, 3, 16)"]; +"39 index" -> "48 mul_4" [style=solid, label="(1, 1, 3, 16)"]; +"40 index_1" -> "46 mul_3" [style=solid, label="(1, 1, 3, 16)"]; +"40 index_1" -> "53 mul_5" [style=solid, label="(1, 1, 3, 16)"]; +"41 mul_2" -> "47 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"42 slice_1" -> "45 cat" [style=solid, label="(1, 4, 3, 8)"]; +"43 slice_2" -> "44 neg" [style=solid, label="(1, 4, 3, 8)"]; +"44 neg" -> "45 cat" [style=solid, label="(1, 4, 3, 8)"]; +"45 cat" -> "46 mul_3" [style=solid, label="(1, 4, 3, 16)"]; +"46 mul_3" -> "47 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"47 add_1" -> "55 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"48 mul_4" -> "54 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"49 slice_3" -> "52 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"50 slice_4" -> "51 neg_1" [style=solid, label="(1, 4, 3, 8)"]; +"51 neg_1" -> "52 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"52 cat_1" -> "53 mul_5" [style=solid, label="(1, 4, 3, 16)"]; +"53 mul_5" -> "54 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"54 add_2" -> "55 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"55 scaled_dot_product_attention" 
-> "56 transpose_3" [style=solid, label="(1, 4, 3, 16)"]; +"56 transpose_3" -> "57 view_3" [style=solid, label="(1, 3, 4, 16)"]; +"57 view_3" -> "58 linear_3" [style=solid, label="(1, 3, 64)"]; +"58 linear_3" -> "59 add_3" [style=solid, label="(1, 3, 64)"]; +"59 add_3" -> "60 _assert_tensor_metadata_default_2" [style=solid, label="(1, 3, 64)"]; +"59 add_3" -> "61 to_2" [style=solid, label="(1, 3, 64)"]; +"61 to_2" -> "62 pow_2" [style=solid, label="(1, 3, 64)"]; +"61 to_2" -> "66 mul_6" [style=solid, label="(1, 3, 64)"]; +"61 to_2" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"62 pow_2" -> "63 mean_1" [style=solid, label="(1, 3, 64)"]; +"63 mean_1" -> "64 add_4" [style=solid, label="(1, 3, 1)"]; +"64 add_4" -> "65 rsqrt_1" [style=solid, label="(1, 3, 1)"]; +"65 rsqrt_1" -> "66 mul_6" [style=solid, label="(1, 3, 1)"]; +"66 mul_6" -> "67 _assert_tensor_metadata_default_3" [style=solid, label="(1, 3, 64)"]; +"66 mul_6" -> "68 to_3" [style=solid, label="(1, 3, 64)"]; +"68 to_3" -> "69 mul_7" [style=solid, label="(1, 3, 64)"]; +"69 mul_7" -> "70 linear_4" [style=solid, label="(1, 3, 64)"]; +"69 mul_7" -> "72 linear_5" [style=solid, label="(1, 3, 64)"]; +"70 linear_4" -> "71 silu" [style=solid, label="(1, 3, 128)"]; +"71 silu" -> "73 mul_8" [style=solid, label="(1, 3, 128)"]; +"72 linear_5" -> "73 mul_8" [style=solid, label="(1, 3, 128)"]; +"73 mul_8" -> "74 linear_6" [style=solid, label="(1, 3, 128)"]; +"74 linear_6" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"75 add_5" -> "76 output" [style=solid, label="(1, 3, 64)"]; +} diff --git a/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False.dot b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False.dot new file mode 100644 index 00000000000..2841824b5a3 --- /dev/null +++ b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False.dot @@ -0,0 +1,24 @@ +strict digraph { +"0 linear_weight_updated_constant0" [id=0, type="get_attr"]; +"1 symmetric_weights_decompressor_linear_weight_0" [id=1, type="call_module"]; +"2 linear_bias" [id=2, type="get_attr"]; +"3 wte_weight_1_updated_constant0" [id=3, type="get_attr"]; +"4 asymmetric_weights_decompressor_wte_weight_1_0" [id=4, type="call_module"]; +"5 lm_head_bias" [id=5, type="get_attr"]; +"6 input_ids" [id=6, type=input]; +"7 embedding" [id=7, type=embedding]; +"8 linear" [id=8, type=linear]; +"9 linear_1" [id=9, type=linear]; +"10 output" [id=10, type=output]; +"0 linear_weight_updated_constant0" -> "1 symmetric_weights_decompressor_linear_weight_0" [style=solid, label="(2048, 1)"]; +"1 symmetric_weights_decompressor_linear_weight_0" -> "8 linear" [style=solid, label="(64, 64)"]; +"2 linear_bias" -> "8 linear" [style=solid, label="(64,)"]; +"3 wte_weight_1_updated_constant0" -> "4 asymmetric_weights_decompressor_wte_weight_1_0" [style=solid, label="(128, 64)"]; +"4 asymmetric_weights_decompressor_wte_weight_1_0" -> "7 embedding" [style=solid, label="(128, 64)"]; +"4 asymmetric_weights_decompressor_wte_weight_1_0" -> "9 linear_1" [style=solid, label="(128, 64)"]; +"5 lm_head_bias" -> "9 linear_1" [style=solid, label="(128,)"]; +"6 input_ids" -> "7 embedding" [style=solid, label="(5,)"]; +"7 embedding" -> "8 linear" [style=solid, label="(5, 64)"]; +"8 linear" -> "9 linear_1" [style=solid, label="(5, 64)"]; +"9 linear_1" -> "10 output" [style=solid, label="(5, 128)"]; +} diff --git 
a/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True.dot b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True.dot new file mode 100644 index 00000000000..0382f7e5934 --- /dev/null +++ b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True.dot @@ -0,0 +1,24 @@ +strict digraph { +"0 linear_weight_updated_constant0" [id=0, type="get_attr"]; +"1 symmetric_weights_decompressor_linear_weight_0" [id=1, type="call_module"]; +"2 linear_bias" [id=2, type="get_attr"]; +"3 wte_weight_1_updated_constant0" [id=3, type="get_attr"]; +"4 symmetric_weights_decompressor_wte_weight_1_0" [id=4, type="call_module"]; +"5 lm_head_bias" [id=5, type="get_attr"]; +"6 input_ids" [id=6, type=input]; +"7 embedding" [id=7, type=embedding]; +"8 linear" [id=8, type=linear]; +"9 linear_1" [id=9, type=linear]; +"10 output" [id=10, type=output]; +"0 linear_weight_updated_constant0" -> "1 symmetric_weights_decompressor_linear_weight_0" [style=solid, label="(2048, 1)"]; +"1 symmetric_weights_decompressor_linear_weight_0" -> "8 linear" [style=solid, label="(64, 64)"]; +"2 linear_bias" -> "8 linear" [style=solid, label="(64,)"]; +"3 wte_weight_1_updated_constant0" -> "4 symmetric_weights_decompressor_wte_weight_1_0" [style=solid, label="(4096, 1)"]; +"4 symmetric_weights_decompressor_wte_weight_1_0" -> "7 embedding" [style=solid, label="(128, 64)"]; +"4 symmetric_weights_decompressor_wte_weight_1_0" -> "9 linear_1" [style=solid, label="(128, 64)"]; +"5 lm_head_bias" -> "9 linear_1" [style=solid, label="(128,)"]; +"6 input_ids" -> "7 embedding" [style=solid, label="(5,)"]; +"7 embedding" -> "8 linear" [style=solid, label="(5, 64)"]; +"8 linear" -> "9 linear_1" [style=solid, label="(5, 64)"]; +"9 linear_1" -> "10 output" [style=solid, label="(5, 128)"]; +} diff --git a/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False.dot b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False.dot new file mode 100644 index 00000000000..03fc9e9c6a0 --- /dev/null +++ b/tests/executorch/data/fx/ao_export_compression_OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False.dot @@ -0,0 +1,24 @@ +strict digraph { +"0 linear_weight_updated_constant0" [id=0, type="get_attr"]; +"1 asymmetric_weights_decompressor_linear_weight_0" [id=1, type="call_module"]; +"2 linear_bias" [id=2, type="get_attr"]; +"3 wte_weight_1_updated_constant0" [id=3, type="get_attr"]; +"4 asymmetric_weights_decompressor_wte_weight_1_0" [id=4, type="call_module"]; +"5 lm_head_bias" [id=5, type="get_attr"]; +"6 input_ids" [id=6, type=input]; +"7 embedding" [id=7, type=embedding]; +"8 linear" [id=8, type=linear]; +"9 linear_1" [id=9, type=linear]; +"10 output" [id=10, type=output]; +"0 linear_weight_updated_constant0" -> "1 asymmetric_weights_decompressor_linear_weight_0" [style=solid, label="(64, 64)"]; +"1 asymmetric_weights_decompressor_linear_weight_0" -> "8 linear" [style=solid, label="(64, 64)"]; +"2 linear_bias" -> "8 linear" [style=solid, label="(64,)"]; +"3 wte_weight_1_updated_constant0" -> "4 asymmetric_weights_decompressor_wte_weight_1_0" [style=solid, label="(128, 64)"]; +"4 asymmetric_weights_decompressor_wte_weight_1_0" -> "7 embedding" [style=solid, label="(128, 64)"]; +"4 
asymmetric_weights_decompressor_wte_weight_1_0" -> "9 linear_1" [style=solid, label="(128, 64)"]; +"5 lm_head_bias" -> "9 linear_1" [style=solid, label="(128,)"]; +"6 input_ids" -> "7 embedding" [style=solid, label="(5,)"]; +"7 embedding" -> "8 linear" [style=solid, label="(5, 64)"]; +"8 linear" -> "9 linear_1" [style=solid, label="(5, 64)"]; +"9 linear_1" -> "10 output" [style=solid, label="(5, 128)"]; +} diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False.dot b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False.dot new file mode 100644 index 00000000000..076e46114eb --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False.dot @@ -0,0 +1,169 @@ +strict digraph { +"0 attn_norm_weight" [id=0, type="get_attr"]; +"1 mlp_norm_weight" [id=1, type="get_attr"]; +"2 rope_cos" [id=2, type="get_attr"]; +"3 rope_sin" [id=3, type="get_attr"]; +"4 x_embed" [id=4, type=input]; +"5 arange" [id=5, type=arange]; +"6 _assert_tensor_metadata_default" [id=6, type="_assert_tensor_metadata"]; +"7 to" [id=7, type=to]; +"8 pow_1" [id=8, type=pow]; +"9 mean" [id=9, type=mean]; +"10 add" [id=10, type=add]; +"11 rsqrt" [id=11, type=rsqrt]; +"12 mul" [id=12, type=mul]; +"13 _assert_tensor_metadata_default_1" [id=13, type="_assert_tensor_metadata"]; +"14 to_1" [id=14, type=to]; +"15 mul_1" [id=15, type=mul]; +"16 q_proj_weight_updated_constant0" [id=16, type="get_attr"]; +"17 symmetric_weights_decompressor_q_proj_weight_0" [id=17, type="call_module"]; +"18 linear" [id=18, type=linear]; +"19 view" [id=19, type=view]; +"20 transpose" [id=20, type=transpose]; +"21 k_proj_weight_updated_constant0" [id=21, type="get_attr"]; +"22 symmetric_weights_decompressor_k_proj_weight_0" [id=22, type="call_module"]; +"23 linear_1" [id=23, type=linear]; +"24 view_1" [id=24, type=view]; +"25 transpose_1" [id=25, type=transpose]; +"26 v_proj_weight_updated_constant0" [id=26, type="get_attr"]; +"27 symmetric_weights_decompressor_v_proj_weight_0" [id=27, type="call_module"]; +"28 linear_2" [id=28, type=linear]; +"29 view_2" [id=29, type=view]; +"30 transpose_2" [id=30, type=transpose]; +"31 index" [id=31, type=index]; +"32 index_1" [id=32, type=index]; +"33 mul_2" [id=33, type=mul]; +"34 slice_1" [id=34, type=slice]; +"35 slice_2" [id=35, type=slice]; +"36 neg" [id=36, type=neg]; +"37 cat" [id=37, type=cat]; +"38 mul_3" [id=38, type=mul]; +"39 add_1" [id=39, type=add]; +"40 mul_4" [id=40, type=mul]; +"41 slice_3" [id=41, type=slice]; +"42 slice_4" [id=42, type=slice]; +"43 neg_1" [id=43, type=neg]; +"44 cat_1" [id=44, type=cat]; +"45 mul_5" [id=45, type=mul]; +"46 add_2" [id=46, type=add]; +"47 scaled_dot_product_attention" [id=47, type="scaled_dot_product_attention"]; +"48 transpose_3" [id=48, type=transpose]; +"49 view_3" [id=49, type=view]; +"50 o_proj_weight_updated_constant0" [id=50, type="get_attr"]; +"51 symmetric_weights_decompressor_o_proj_weight_0" [id=51, type="call_module"]; +"52 linear_3" [id=52, type=linear]; +"53 add_3" [id=53, type=add]; +"54 _assert_tensor_metadata_default_2" [id=54, type="_assert_tensor_metadata"]; +"55 to_2" [id=55, type=to]; +"56 pow_2" [id=56, type=pow]; +"57 mean_1" [id=57, type=mean]; +"58 add_4" [id=58, type=add]; +"59 rsqrt_1" [id=59, type=rsqrt]; +"60 mul_6" [id=60, type=mul]; +"61 _assert_tensor_metadata_default_3" [id=61, type="_assert_tensor_metadata"]; +"62 to_3" [id=62, type=to]; +"63 mul_7" [id=63, 
type=mul]; +"64 mlp_gate_proj_weight_updated_constant0" [id=64, type="get_attr"]; +"65 symmetric_weights_decompressor_mlp_gate_proj_weight_0" [id=65, type="call_module"]; +"66 linear_4" [id=66, type=linear]; +"67 silu" [id=67, type=silu]; +"68 mlp_up_proj_weight_updated_constant0" [id=68, type="get_attr"]; +"69 asymmetric_weights_decompressor_mlp_up_proj_weight_0" [id=69, type="call_module"]; +"70 linear_5" [id=70, type=linear]; +"71 mul_8" [id=71, type=mul]; +"72 mlp_down_proj_weight_updated_constant0" [id=72, type="get_attr"]; +"73 asymmetric_weights_decompressor_mlp_down_proj_weight_0" [id=73, type="call_module"]; +"74 linear_6" [id=74, type=linear]; +"75 add_5" [id=75, type=add]; +"76 output" [id=76, type=output]; +"0 attn_norm_weight" -> "15 mul_1" [style=solid, label="(64,)"]; +"1 mlp_norm_weight" -> "63 mul_7" [style=solid, label="(64,)"]; +"2 rope_cos" -> "31 index" [style=solid, label="(1, 1, 128, 16)"]; +"3 rope_sin" -> "32 index_1" [style=solid, label="(1, 1, 128, 16)"]; +"4 x_embed" -> "6 _assert_tensor_metadata_default" [style=solid, label="(1, 3, 64)"]; +"4 x_embed" -> "7 to" [style=solid, label="(1, 3, 64)"]; +"4 x_embed" -> "53 add_3" [style=solid, label="(1, 3, 64)"]; +"5 arange" -> "31 index" [style=solid, label="(3,)"]; +"5 arange" -> "32 index_1" [style=solid, label="(3,)"]; +"7 to" -> "8 pow_1" [style=solid, label="(1, 3, 64)"]; +"7 to" -> "12 mul" [style=solid, label="(1, 3, 64)"]; +"8 pow_1" -> "9 mean" [style=solid, label="(1, 3, 64)"]; +"9 mean" -> "10 add" [style=solid, label="(1, 3, 1)"]; +"10 add" -> "11 rsqrt" [style=solid, label="(1, 3, 1)"]; +"11 rsqrt" -> "12 mul" [style=solid, label="(1, 3, 1)"]; +"12 mul" -> "13 _assert_tensor_metadata_default_1" [style=solid, label="(1, 3, 64)"]; +"12 mul" -> "14 to_1" [style=solid, label="(1, 3, 64)"]; +"14 to_1" -> "15 mul_1" [style=solid, label="(1, 3, 64)"]; +"15 mul_1" -> "18 linear" [style=solid, label="(1, 3, 64)"]; +"15 mul_1" -> "23 linear_1" [style=solid, label="(1, 3, 64)"]; +"15 mul_1" -> "28 linear_2" [style=solid, label="(1, 3, 64)"]; +"16 q_proj_weight_updated_constant0" -> "17 symmetric_weights_decompressor_q_proj_weight_0" [style=solid, label="(2048, 1)"]; +"17 symmetric_weights_decompressor_q_proj_weight_0" -> "18 linear" [style=solid, label="(64, 64)"]; +"18 linear" -> "19 view" [style=solid, label="(1, 3, 64)"]; +"19 view" -> "20 transpose" [style=solid, label="(1, 3, 4, 16)"]; +"20 transpose" -> "33 mul_2" [style=solid, label="(1, 4, 3, 16)"]; +"20 transpose" -> "34 slice_1" [style=solid, label="(1, 4, 3, 16)"]; +"20 transpose" -> "35 slice_2" [style=solid, label="(1, 4, 3, 16)"]; +"21 k_proj_weight_updated_constant0" -> "22 symmetric_weights_decompressor_k_proj_weight_0" [style=solid, label="(2048, 1)"]; +"22 symmetric_weights_decompressor_k_proj_weight_0" -> "23 linear_1" [style=solid, label="(64, 64)"]; +"23 linear_1" -> "24 view_1" [style=solid, label="(1, 3, 64)"]; +"24 view_1" -> "25 transpose_1" [style=solid, label="(1, 3, 4, 16)"]; +"25 transpose_1" -> "40 mul_4" [style=solid, label="(1, 4, 3, 16)"]; +"25 transpose_1" -> "41 slice_3" [style=solid, label="(1, 4, 3, 16)"]; +"25 transpose_1" -> "42 slice_4" [style=solid, label="(1, 4, 3, 16)"]; +"26 v_proj_weight_updated_constant0" -> "27 symmetric_weights_decompressor_v_proj_weight_0" [style=solid, label="(2048, 1)"]; +"27 symmetric_weights_decompressor_v_proj_weight_0" -> "28 linear_2" [style=solid, label="(64, 64)"]; +"28 linear_2" -> "29 view_2" [style=solid, label="(1, 3, 64)"]; +"29 view_2" -> "30 transpose_2" [style=solid, label="(1, 3, 4, 
16)"]; +"30 transpose_2" -> "47 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"31 index" -> "33 mul_2" [style=solid, label="(1, 1, 3, 16)"]; +"31 index" -> "40 mul_4" [style=solid, label="(1, 1, 3, 16)"]; +"32 index_1" -> "38 mul_3" [style=solid, label="(1, 1, 3, 16)"]; +"32 index_1" -> "45 mul_5" [style=solid, label="(1, 1, 3, 16)"]; +"33 mul_2" -> "39 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"34 slice_1" -> "37 cat" [style=solid, label="(1, 4, 3, 8)"]; +"35 slice_2" -> "36 neg" [style=solid, label="(1, 4, 3, 8)"]; +"36 neg" -> "37 cat" [style=solid, label="(1, 4, 3, 8)"]; +"37 cat" -> "38 mul_3" [style=solid, label="(1, 4, 3, 16)"]; +"38 mul_3" -> "39 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"39 add_1" -> "47 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"40 mul_4" -> "46 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"41 slice_3" -> "44 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"42 slice_4" -> "43 neg_1" [style=solid, label="(1, 4, 3, 8)"]; +"43 neg_1" -> "44 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"44 cat_1" -> "45 mul_5" [style=solid, label="(1, 4, 3, 16)"]; +"45 mul_5" -> "46 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"46 add_2" -> "47 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"47 scaled_dot_product_attention" -> "48 transpose_3" [style=solid, label="(1, 4, 3, 16)"]; +"48 transpose_3" -> "49 view_3" [style=solid, label="(1, 3, 4, 16)"]; +"49 view_3" -> "52 linear_3" [style=solid, label="(1, 3, 64)"]; +"50 o_proj_weight_updated_constant0" -> "51 symmetric_weights_decompressor_o_proj_weight_0" [style=solid, label="(2048, 1)"]; +"51 symmetric_weights_decompressor_o_proj_weight_0" -> "52 linear_3" [style=solid, label="(64, 64)"]; +"52 linear_3" -> "53 add_3" [style=solid, label="(1, 3, 64)"]; +"53 add_3" -> "54 _assert_tensor_metadata_default_2" [style=solid, label="(1, 3, 64)"]; +"53 add_3" -> "55 to_2" [style=solid, label="(1, 3, 64)"]; +"55 to_2" -> "56 pow_2" [style=solid, label="(1, 3, 64)"]; +"55 to_2" -> "60 mul_6" [style=solid, label="(1, 3, 64)"]; +"55 to_2" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"56 pow_2" -> "57 mean_1" [style=solid, label="(1, 3, 64)"]; +"57 mean_1" -> "58 add_4" [style=solid, label="(1, 3, 1)"]; +"58 add_4" -> "59 rsqrt_1" [style=solid, label="(1, 3, 1)"]; +"59 rsqrt_1" -> "60 mul_6" [style=solid, label="(1, 3, 1)"]; +"60 mul_6" -> "61 _assert_tensor_metadata_default_3" [style=solid, label="(1, 3, 64)"]; +"60 mul_6" -> "62 to_3" [style=solid, label="(1, 3, 64)"]; +"62 to_3" -> "63 mul_7" [style=solid, label="(1, 3, 64)"]; +"63 mul_7" -> "66 linear_4" [style=solid, label="(1, 3, 64)"]; +"63 mul_7" -> "70 linear_5" [style=solid, label="(1, 3, 64)"]; +"64 mlp_gate_proj_weight_updated_constant0" -> "65 symmetric_weights_decompressor_mlp_gate_proj_weight_0" [style=solid, label="(4096, 1)"]; +"65 symmetric_weights_decompressor_mlp_gate_proj_weight_0" -> "66 linear_4" [style=solid, label="(128, 64)"]; +"66 linear_4" -> "67 silu" [style=solid, label="(1, 3, 128)"]; +"67 silu" -> "71 mul_8" [style=solid, label="(1, 3, 128)"]; +"68 mlp_up_proj_weight_updated_constant0" -> "69 asymmetric_weights_decompressor_mlp_up_proj_weight_0" [style=solid, label="(128, 64)"]; +"69 asymmetric_weights_decompressor_mlp_up_proj_weight_0" -> "70 linear_5" [style=solid, label="(128, 64)"]; +"70 linear_5" -> "71 mul_8" [style=solid, label="(1, 3, 128)"]; +"71 mul_8" -> "74 linear_6" [style=solid, label="(1, 3, 128)"]; +"72 mlp_down_proj_weight_updated_constant0" -> "73 
asymmetric_weights_decompressor_mlp_down_proj_weight_0" [style=solid, label="(64, 128)"]; +"73 asymmetric_weights_decompressor_mlp_down_proj_weight_0" -> "74 linear_6" [style=solid, label="(64, 128)"]; +"74 linear_6" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"75 add_5" -> "76 output" [style=solid, label="(1, 3, 64)"]; +} diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False_awq_True_scale_estimation_True_ref_wc_scales.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False_awq_True_scale_estimation_True_ref_wc_scales.json new file mode 100644 index 00000000000..364f78db4aa --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False_awq_True_scale_estimation_True_ref_wc_scales.json @@ -0,0 +1,3664 @@ +{ + "symmetric_weights_decompressor_q_proj_weight_0": [ + [ + [ + -0.0135650634765625 + ], + [ + -0.014251708984375 + ] + ], + [ + [ + 0.015411376953125 + ], + [ + -0.01557159423828125 + ] + ], + [ + [ + -0.01427459716796875 + ], + [ + -0.01519012451171875 + ] + ], + [ + [ + 0.01224517822265625 + ], + [ + 0.01528167724609375 + ] + ], + [ + [ + 0.01348876953125 + ], + [ + -0.01556396484375 + ] + ], + [ + [ + 0.01401519775390625 + ], + [ + 0.0153350830078125 + ] + ], + [ + [ + 0.01560211181640625 + ], + [ + 0.01523590087890625 + ] + ], + [ + [ + 0.0160675048828125 + ], + [ + -0.0149993896484375 + ] + ], + [ + [ + -0.01445770263671875 + ], + [ + 0.01251220703125 + ] + ], + [ + [ + -0.0155487060546875 + ], + [ + 0.0126495361328125 + ] + ], + [ + [ + -0.01306915283203125 + ], + [ + -0.01422882080078125 + ] + ], + [ + [ + -0.0154571533203125 + ], + [ + -0.015594482421875 + ] + ], + [ + [ + -0.01308441162109375 + ], + [ + -0.01392364501953125 + ] + ], + [ + [ + -0.01512908935546875 + ], + [ + 0.01303863525390625 + ] + ], + [ + [ + -0.0134124755859375 + ], + [ + -0.0130462646484375 + ] + ], + [ + [ + 0.01467132568359375 + ], + [ + 0.01568603515625 + ] + ], + [ + [ + -0.01470184326171875 + ], + [ + 0.0147705078125 + ] + ], + [ + [ + -0.01477813720703125 + ], + [ + -0.01468658447265625 + ] + ], + [ + [ + -0.0128021240234375 + ], + [ + 0.0122833251953125 + ] + ], + [ + [ + -0.0158233642578125 + ], + [ + 0.0131378173828125 + ] + ], + [ + [ + -0.01537322998046875 + ], + [ + -0.01543426513671875 + ] + ], + [ + [ + -0.01548004150390625 + ], + [ + 0.01435089111328125 + ] + ], + [ + [ + -0.0130462646484375 + ], + [ + 0.01294708251953125 + ] + ], + [ + [ + 0.01348876953125 + ], + [ + 0.01446533203125 + ] + ], + [ + [ + -0.0157470703125 + ], + [ + -0.014892578125 + ] + ], + [ + [ + 0.01482391357421875 + ], + [ + -0.01473236083984375 + ] + ], + [ + [ + 0.0155029296875 + ], + [ + 0.0135040283203125 + ] + ], + [ + [ + 0.01456451416015625 + ], + [ + -0.016326904296875 + ] + ], + [ + [ + -0.01509857177734375 + ], + [ + -0.012847900390625 + ] + ], + [ + [ + -0.0152130126953125 + ], + [ + 0.01459503173828125 + ] + ], + [ + [ + -0.0153350830078125 + ], + [ + -0.01287078857421875 + ] + ], + [ + [ + 0.0136871337890625 + ], + [ + 0.014801025390625 + ] + ], + [ + [ + 0.01520538330078125 + ], + [ + -0.01514434814453125 + ] + ], + [ + [ + -0.01471710205078125 + ], + [ + 0.0155792236328125 + ] + ], + [ + [ + -0.01485443115234375 + ], + [ + 0.0147857666015625 + ] + ], + [ + [ + 0.01512908935546875 + ], + [ + -0.01381683349609375 + ] + ], + [ + [ + -0.015838623046875 + ], + [ + -0.01444244384765625 + ] + ], + [ + [ + 
-0.0146636962890625 + ], + [ + -0.01299285888671875 + ] + ], + [ + [ + -0.01495361328125 + ], + [ + -0.014801025390625 + ] + ], + [ + [ + -0.01396942138671875 + ], + [ + 0.0134124755859375 + ] + ], + [ + [ + -0.01490020751953125 + ], + [ + 0.015045166015625 + ] + ], + [ + [ + -0.01543426513671875 + ], + [ + 0.01514434814453125 + ] + ], + [ + [ + 0.01428985595703125 + ], + [ + 0.0141754150390625 + ] + ], + [ + [ + 0.014923095703125 + ], + [ + 0.01470947265625 + ] + ], + [ + [ + -0.01654052734375 + ], + [ + 0.01470947265625 + ] + ], + [ + [ + 0.0150299072265625 + ], + [ + 0.0132293701171875 + ] + ], + [ + [ + -0.0144500732421875 + ], + [ + -0.014556884765625 + ] + ], + [ + [ + -0.01354217529296875 + ], + [ + -0.01436614990234375 + ] + ], + [ + [ + 0.01250457763671875 + ], + [ + 0.014495849609375 + ] + ], + [ + [ + -0.01361846923828125 + ], + [ + -0.01445770263671875 + ] + ], + [ + [ + -0.0148162841796875 + ], + [ + 0.01213836669921875 + ] + ], + [ + [ + -0.0125274658203125 + ], + [ + -0.0152587890625 + ] + ], + [ + [ + -0.01308441162109375 + ], + [ + 0.01410675048828125 + ] + ], + [ + [ + -0.0150146484375 + ], + [ + 0.01324462890625 + ] + ], + [ + [ + -0.016021728515625 + ], + [ + 0.015289306640625 + ] + ], + [ + [ + -0.0143280029296875 + ], + [ + -0.0139617919921875 + ] + ], + [ + [ + -0.0147247314453125 + ], + [ + 0.0161590576171875 + ] + ], + [ + [ + -0.0119476318359375 + ], + [ + 0.0154571533203125 + ] + ], + [ + [ + -0.01476287841796875 + ], + [ + -0.0137176513671875 + ] + ], + [ + [ + 0.01558685302734375 + ], + [ + 0.013427734375 + ] + ], + [ + [ + -0.0167694091796875 + ], + [ + 0.01517486572265625 + ] + ], + [ + [ + 0.01235198974609375 + ], + [ + -0.01605224609375 + ] + ], + [ + [ + 0.015960693359375 + ], + [ + -0.015167236328125 + ] + ], + [ + [ + 0.01517486572265625 + ], + [ + 0.0162200927734375 + ] + ] + ], + "symmetric_weights_decompressor_k_proj_weight_0": [ + [ + [ + 0.0150604248046875 + ], + [ + -0.0138702392578125 + ] + ], + [ + [ + -0.01486968994140625 + ], + [ + -0.01424407958984375 + ] + ], + [ + [ + 0.01526641845703125 + ], + [ + -0.0126800537109375 + ] + ], + [ + [ + -0.01436614990234375 + ], + [ + -0.0157012939453125 + ] + ], + [ + [ + -0.01470947265625 + ], + [ + 0.013916015625 + ] + ], + [ + [ + -0.01371002197265625 + ], + [ + -0.01558685302734375 + ] + ], + [ + [ + 0.01265716552734375 + ], + [ + 0.01399993896484375 + ] + ], + [ + [ + -0.01520538330078125 + ], + [ + -0.01537322998046875 + ] + ], + [ + [ + 0.01538848876953125 + ], + [ + 0.0160064697265625 + ] + ], + [ + [ + -0.01537322998046875 + ], + [ + -0.01198577880859375 + ] + ], + [ + [ + -0.01551055908203125 + ], + [ + -0.01419830322265625 + ] + ], + [ + [ + -0.01544189453125 + ], + [ + -0.0127410888671875 + ] + ], + [ + [ + 0.014373779296875 + ], + [ + -0.01462554931640625 + ] + ], + [ + [ + 0.01326751708984375 + ], + [ + -0.015716552734375 + ] + ], + [ + [ + -0.01415252685546875 + ], + [ + -0.01483917236328125 + ] + ], + [ + [ + -0.01505279541015625 + ], + [ + 0.0154571533203125 + ] + ], + [ + [ + 0.01538848876953125 + ], + [ + -0.016021728515625 + ] + ], + [ + [ + -0.013916015625 + ], + [ + -0.01514434814453125 + ] + ], + [ + [ + 0.01401519775390625 + ], + [ + -0.01239776611328125 + ] + ], + [ + [ + -0.01540374755859375 + ], + [ + -0.0133209228515625 + ] + ], + [ + [ + 0.014617919921875 + ], + [ + 0.01727294921875 + ] + ], + [ + [ + 0.0156707763671875 + ], + [ + -0.0155792236328125 + ] + ], + [ + [ + 0.01384735107421875 + ], + [ + 0.01262664794921875 + ] + ], + [ + [ + -0.0143890380859375 + ], + [ + 
0.015106201171875 + ] + ], + [ + [ + 0.0154571533203125 + ], + [ + -0.01403045654296875 + ] + ], + [ + [ + 0.0149993896484375 + ], + [ + 0.012847900390625 + ] + ], + [ + [ + 0.01552581787109375 + ], + [ + -0.01554107666015625 + ] + ], + [ + [ + 0.01503753662109375 + ], + [ + 0.01519775390625 + ] + ], + [ + [ + 0.0144195556640625 + ], + [ + -0.01325225830078125 + ] + ], + [ + [ + -0.0159454345703125 + ], + [ + -0.01555633544921875 + ] + ], + [ + [ + -0.01416015625 + ], + [ + -0.01580810546875 + ] + ], + [ + [ + -0.01446533203125 + ], + [ + -0.01375579833984375 + ] + ], + [ + [ + 0.01214599609375 + ], + [ + -0.0137786865234375 + ] + ], + [ + [ + 0.01497650146484375 + ], + [ + 0.0144805908203125 + ] + ], + [ + [ + -0.01474761962890625 + ], + [ + -0.0155181884765625 + ] + ], + [ + [ + -0.01508331298828125 + ], + [ + -0.01496124267578125 + ] + ], + [ + [ + -0.01544189453125 + ], + [ + 0.014678955078125 + ] + ], + [ + [ + -0.01329803466796875 + ], + [ + -0.0157012939453125 + ] + ], + [ + [ + 0.01535797119140625 + ], + [ + -0.0161590576171875 + ] + ], + [ + [ + 0.01480865478515625 + ], + [ + -0.01407623291015625 + ] + ], + [ + [ + 0.01212310791015625 + ], + [ + 0.01406097412109375 + ] + ], + [ + [ + 0.012939453125 + ], + [ + 0.01445770263671875 + ] + ], + [ + [ + 0.01476287841796875 + ], + [ + -0.01544189453125 + ] + ], + [ + [ + 0.0135650634765625 + ], + [ + 0.01358795166015625 + ] + ], + [ + [ + -0.0150299072265625 + ], + [ + -0.014190673828125 + ] + ], + [ + [ + 0.01522064208984375 + ], + [ + 0.01520538330078125 + ] + ], + [ + [ + 0.0146942138671875 + ], + [ + -0.01531982421875 + ] + ], + [ + [ + 0.01305389404296875 + ], + [ + 0.0139312744140625 + ] + ], + [ + [ + 0.01507568359375 + ], + [ + -0.01461029052734375 + ] + ], + [ + [ + -0.015899658203125 + ], + [ + 0.01421356201171875 + ] + ], + [ + [ + 0.01385498046875 + ], + [ + 0.01284027099609375 + ] + ], + [ + [ + 0.01535797119140625 + ], + [ + 0.0152740478515625 + ] + ], + [ + [ + -0.0144805908203125 + ], + [ + 0.01386260986328125 + ] + ], + [ + [ + 0.0132598876953125 + ], + [ + -0.0147705078125 + ] + ], + [ + [ + -0.01397705078125 + ], + [ + 0.01549530029296875 + ] + ], + [ + [ + 0.0145111083984375 + ], + [ + -0.0167694091796875 + ] + ], + [ + [ + -0.0148773193359375 + ], + [ + 0.01532745361328125 + ] + ], + [ + [ + -0.0145263671875 + ], + [ + -0.01387786865234375 + ] + ], + [ + [ + 0.01473236083984375 + ], + [ + 0.016326904296875 + ] + ], + [ + [ + -0.01299285888671875 + ], + [ + 0.0149993896484375 + ] + ], + [ + [ + 0.013214111328125 + ], + [ + -0.01541900634765625 + ] + ], + [ + [ + -0.01316070556640625 + ], + [ + 0.0142822265625 + ] + ], + [ + [ + 0.01425933837890625 + ], + [ + -0.01212310791015625 + ] + ], + [ + [ + 0.0168914794921875 + ], + [ + -0.01407623291015625 + ] + ] + ], + "symmetric_weights_decompressor_v_proj_weight_0": [ + [ + [ + -0.0145721435546875 + ], + [ + -0.01470184326171875 + ] + ], + [ + [ + -0.01517486572265625 + ], + [ + -0.01496124267578125 + ] + ], + [ + [ + 0.013580322265625 + ], + [ + -0.0135040283203125 + ] + ], + [ + [ + 0.0142669677734375 + ], + [ + 0.014251708984375 + ] + ], + [ + [ + 0.0146942138671875 + ], + [ + 0.0164337158203125 + ] + ], + [ + [ + -0.0142364501953125 + ], + [ + -0.0138397216796875 + ] + ], + [ + [ + -0.0160064697265625 + ], + [ + 0.01447296142578125 + ] + ], + [ + [ + -0.01551055908203125 + ], + [ + -0.013824462890625 + ] + ], + [ + [ + -0.0135650634765625 + ], + [ + 0.0128326416015625 + ] + ], + [ + [ + -0.01386260986328125 + ], + [ + -0.0139312744140625 + ] + ], + [ + [ + 
-0.0142059326171875 + ], + [ + 0.01422119140625 + ] + ], + [ + [ + -0.01546478271484375 + ], + [ + -0.0157318115234375 + ] + ], + [ + [ + -0.01416015625 + ], + [ + -0.01371002197265625 + ] + ], + [ + [ + -0.0151519775390625 + ], + [ + 0.0147857666015625 + ] + ], + [ + [ + -0.0164031982421875 + ], + [ + -0.01531982421875 + ] + ], + [ + [ + -0.01323699951171875 + ], + [ + -0.01331329345703125 + ] + ], + [ + [ + 0.0156097412109375 + ], + [ + 0.01561737060546875 + ] + ], + [ + [ + 0.0145721435546875 + ], + [ + 0.0152587890625 + ] + ], + [ + [ + 0.01342010498046875 + ], + [ + 0.013824462890625 + ] + ], + [ + [ + 0.01375579833984375 + ], + [ + -0.012847900390625 + ] + ], + [ + [ + 0.015960693359375 + ], + [ + 0.0157623291015625 + ] + ], + [ + [ + 0.01479339599609375 + ], + [ + 0.012969970703125 + ] + ], + [ + [ + 0.0158233642578125 + ], + [ + -0.0147552490234375 + ] + ], + [ + [ + 0.0137481689453125 + ], + [ + 0.01409912109375 + ] + ], + [ + [ + -0.01373291015625 + ], + [ + -0.01508331298828125 + ] + ], + [ + [ + -0.01456451416015625 + ], + [ + 0.0151824951171875 + ] + ], + [ + [ + -0.01549530029296875 + ], + [ + 0.0151519775390625 + ] + ], + [ + [ + 0.012725830078125 + ], + [ + -0.01461029052734375 + ] + ], + [ + [ + -0.01531982421875 + ], + [ + 0.0142974853515625 + ] + ], + [ + [ + 0.01558685302734375 + ], + [ + 0.01357269287109375 + ] + ], + [ + [ + -0.01500701904296875 + ], + [ + -0.0123291015625 + ] + ], + [ + [ + -0.01526641845703125 + ], + [ + 0.0153961181640625 + ] + ], + [ + [ + 0.01474761962890625 + ], + [ + 0.0154876708984375 + ] + ], + [ + [ + -0.01513671875 + ], + [ + 0.015350341796875 + ] + ], + [ + [ + 0.0153961181640625 + ], + [ + 0.01528167724609375 + ] + ], + [ + [ + 0.0152435302734375 + ], + [ + 0.0153656005859375 + ] + ], + [ + [ + 0.0149993896484375 + ], + [ + -0.01336669921875 + ] + ], + [ + [ + 0.01336669921875 + ], + [ + 0.0147857666015625 + ] + ], + [ + [ + 0.01328277587890625 + ], + [ + -0.0137176513671875 + ] + ], + [ + [ + -0.01544952392578125 + ], + [ + 0.01535797119140625 + ] + ], + [ + [ + 0.0138702392578125 + ], + [ + -0.01288604736328125 + ] + ], + [ + [ + 0.01401519775390625 + ], + [ + -0.0158843994140625 + ] + ], + [ + [ + 0.01477813720703125 + ], + [ + 0.01238250732421875 + ] + ], + [ + [ + 0.01261138916015625 + ], + [ + -0.01371002197265625 + ] + ], + [ + [ + 0.01448822021484375 + ], + [ + -0.0145416259765625 + ] + ], + [ + [ + 0.01453399658203125 + ], + [ + 0.0154571533203125 + ] + ], + [ + [ + 0.014251708984375 + ], + [ + -0.0150604248046875 + ] + ], + [ + [ + -0.0154266357421875 + ], + [ + -0.0140228271484375 + ] + ], + [ + [ + 0.0145721435546875 + ], + [ + 0.015472412109375 + ] + ], + [ + [ + 0.01425933837890625 + ], + [ + -0.01351165771484375 + ] + ], + [ + [ + -0.01450347900390625 + ], + [ + -0.0159759521484375 + ] + ], + [ + [ + -0.01361083984375 + ], + [ + 0.01483917236328125 + ] + ], + [ + [ + -0.01447296142578125 + ], + [ + 0.01418304443359375 + ] + ], + [ + [ + -0.015106201171875 + ], + [ + 0.0139923095703125 + ] + ], + [ + [ + -0.014068603515625 + ], + [ + 0.01320648193359375 + ] + ], + [ + [ + -0.0155181884765625 + ], + [ + 0.01560211181640625 + ] + ], + [ + [ + -0.0155792236328125 + ], + [ + -0.0147247314453125 + ] + ], + [ + [ + 0.0147247314453125 + ], + [ + 0.0133209228515625 + ] + ], + [ + [ + 0.01415252685546875 + ], + [ + 0.0130615234375 + ] + ], + [ + [ + -0.01419830322265625 + ], + [ + -0.014251708984375 + ] + ], + [ + [ + -0.0134124755859375 + ], + [ + 0.01519775390625 + ] + ], + [ + [ + 0.01476287841796875 + ], + [ + 0.0138092041015625 
+ ] + ], + [ + [ + -0.0151824951171875 + ], + [ + 0.01494598388671875 + ] + ], + [ + [ + 0.015106201171875 + ], + [ + 0.01279449462890625 + ] + ] + ], + "symmetric_weights_decompressor_o_proj_weight_0": [ + [ + [ + 0.015625 + ], + [ + 0.014495849609375 + ] + ], + [ + [ + -0.01404571533203125 + ], + [ + -0.0152130126953125 + ] + ], + [ + [ + -0.01512908935546875 + ], + [ + 0.0160369873046875 + ] + ], + [ + [ + 0.01451873779296875 + ], + [ + -0.0155181884765625 + ] + ], + [ + [ + -0.01464080810546875 + ], + [ + -0.0139007568359375 + ] + ], + [ + [ + -0.0123138427734375 + ], + [ + 0.01412200927734375 + ] + ], + [ + [ + -0.01317596435546875 + ], + [ + 0.0151824951171875 + ] + ], + [ + [ + -0.01235198974609375 + ], + [ + -0.0142059326171875 + ] + ], + [ + [ + -0.0145263671875 + ], + [ + -0.0148162841796875 + ] + ], + [ + [ + 0.01427459716796875 + ], + [ + -0.01490020751953125 + ] + ], + [ + [ + 0.01490020751953125 + ], + [ + 0.01303863525390625 + ] + ], + [ + [ + 0.0155029296875 + ], + [ + -0.013946533203125 + ] + ], + [ + [ + 0.01409149169921875 + ], + [ + -0.01322174072265625 + ] + ], + [ + [ + 0.013427734375 + ], + [ + 0.0127716064453125 + ] + ], + [ + [ + 0.0142669677734375 + ], + [ + 0.01432037353515625 + ] + ], + [ + [ + -0.01528167724609375 + ], + [ + 0.01529693603515625 + ] + ], + [ + [ + 0.01393890380859375 + ], + [ + -0.01446533203125 + ] + ], + [ + [ + -0.01214599609375 + ], + [ + -0.01450347900390625 + ] + ], + [ + [ + 0.013275146484375 + ], + [ + -0.01328277587890625 + ] + ], + [ + [ + -0.01528167724609375 + ], + [ + -0.01406097412109375 + ] + ], + [ + [ + -0.01247406005859375 + ], + [ + -0.0160064697265625 + ] + ], + [ + [ + -0.01490020751953125 + ], + [ + -0.01470184326171875 + ] + ], + [ + [ + -0.01491546630859375 + ], + [ + -0.013702392578125 + ] + ], + [ + [ + -0.0145721435546875 + ], + [ + 0.01506805419921875 + ] + ], + [ + [ + -0.0150146484375 + ], + [ + 0.015380859375 + ] + ], + [ + [ + -0.0146484375 + ], + [ + 0.013946533203125 + ] + ], + [ + [ + 0.0121917724609375 + ], + [ + 0.01367950439453125 + ] + ], + [ + [ + -0.01552581787109375 + ], + [ + -0.015228271484375 + ] + ], + [ + [ + 0.0135650634765625 + ], + [ + -0.01288604736328125 + ] + ], + [ + [ + -0.015869140625 + ], + [ + 0.01409912109375 + ] + ], + [ + [ + -0.013946533203125 + ], + [ + -0.0148162841796875 + ] + ], + [ + [ + 0.01346588134765625 + ], + [ + -0.015533447265625 + ] + ], + [ + [ + 0.01334381103515625 + ], + [ + -0.0154571533203125 + ] + ], + [ + [ + -0.01387786865234375 + ], + [ + -0.0156707763671875 + ] + ], + [ + [ + 0.0160675048828125 + ], + [ + -0.0134429931640625 + ] + ], + [ + [ + 0.0123748779296875 + ], + [ + -0.01427459716796875 + ] + ], + [ + [ + -0.0137939453125 + ], + [ + 0.01299285888671875 + ] + ], + [ + [ + -0.015289306640625 + ], + [ + -0.01548004150390625 + ] + ], + [ + [ + 0.0142059326171875 + ], + [ + 0.0158233642578125 + ] + ], + [ + [ + -0.01528167724609375 + ], + [ + -0.013824462890625 + ] + ], + [ + [ + -0.01453399658203125 + ], + [ + -0.0151519775390625 + ] + ], + [ + [ + -0.01526641845703125 + ], + [ + 0.0164337158203125 + ] + ], + [ + [ + 0.01546478271484375 + ], + [ + -0.01494598388671875 + ] + ], + [ + [ + -0.01458740234375 + ], + [ + -0.01313018798828125 + ] + ], + [ + [ + -0.0141448974609375 + ], + [ + -0.0145721435546875 + ] + ], + [ + [ + -0.0144500732421875 + ], + [ + -0.012664794921875 + ] + ], + [ + [ + 0.0151824951171875 + ], + [ + 0.0142822265625 + ] + ], + [ + [ + 0.01434326171875 + ], + [ + -0.0160675048828125 + ] + ], + [ + [ + 0.01505279541015625 + ], + [ + 
-0.0137939453125 + ] + ], + [ + [ + 0.01270294189453125 + ], + [ + -0.0133056640625 + ] + ], + [ + [ + -0.01343536376953125 + ], + [ + -0.01441192626953125 + ] + ], + [ + [ + 0.0150146484375 + ], + [ + 0.01453399658203125 + ] + ], + [ + [ + -0.016143798828125 + ], + [ + -0.01445770263671875 + ] + ], + [ + [ + -0.0134735107421875 + ], + [ + 0.01480865478515625 + ] + ], + [ + [ + -0.0162506103515625 + ], + [ + 0.0152130126953125 + ] + ], + [ + [ + -0.01522064208984375 + ], + [ + -0.01541900634765625 + ] + ], + [ + [ + -0.01448822021484375 + ], + [ + 0.01557159423828125 + ] + ], + [ + [ + -0.01395416259765625 + ], + [ + 0.01319122314453125 + ] + ], + [ + [ + -0.0153350830078125 + ], + [ + -0.01532745361328125 + ] + ], + [ + [ + 0.016265869140625 + ], + [ + -0.0161285400390625 + ] + ], + [ + [ + -0.0131988525390625 + ], + [ + 0.015350341796875 + ] + ], + [ + [ + 0.0146331787109375 + ], + [ + -0.01483917236328125 + ] + ], + [ + [ + -0.01554107666015625 + ], + [ + -0.01318359375 + ] + ], + [ + [ + 0.0138092041015625 + ], + [ + 0.01560211181640625 + ] + ] + ], + "symmetric_weights_decompressor_mlp_gate_proj_weight_0": [ + [ + [ + -0.0156097412109375 + ], + [ + 0.0138702392578125 + ] + ], + [ + [ + 0.01531219482421875 + ], + [ + -0.01438140869140625 + ] + ], + [ + [ + 0.01373291015625 + ], + [ + 0.0133514404296875 + ] + ], + [ + [ + 0.01351165771484375 + ], + [ + -0.01241302490234375 + ] + ], + [ + [ + 0.01239776611328125 + ], + [ + -0.01500701904296875 + ] + ], + [ + [ + -0.0160064697265625 + ], + [ + -0.01306915283203125 + ] + ], + [ + [ + -0.0152587890625 + ], + [ + -0.01387786865234375 + ] + ], + [ + [ + -0.0160369873046875 + ], + [ + -0.01507568359375 + ] + ], + [ + [ + -0.0150604248046875 + ], + [ + -0.0146942138671875 + ] + ], + [ + [ + -0.0153350830078125 + ], + [ + 0.0147247314453125 + ] + ], + [ + [ + 0.01427459716796875 + ], + [ + -0.01500701904296875 + ] + ], + [ + [ + -0.0140380859375 + ], + [ + 0.01541900634765625 + ] + ], + [ + [ + 0.01519775390625 + ], + [ + 0.01490020751953125 + ] + ], + [ + [ + 0.01526641845703125 + ], + [ + 0.01348114013671875 + ] + ], + [ + [ + 0.01519012451171875 + ], + [ + -0.0141448974609375 + ] + ], + [ + [ + 0.0132904052734375 + ], + [ + 0.013275146484375 + ] + ], + [ + [ + -0.0136566162109375 + ], + [ + -0.016143798828125 + ] + ], + [ + [ + 0.0150604248046875 + ], + [ + 0.01561737060546875 + ] + ], + [ + [ + -0.01538848876953125 + ], + [ + 0.01464080810546875 + ] + ], + [ + [ + 0.016021728515625 + ], + [ + 0.01496124267578125 + ] + ], + [ + [ + 0.01239776611328125 + ], + [ + -0.01406097412109375 + ] + ], + [ + [ + -0.01380157470703125 + ], + [ + 0.015533447265625 + ] + ], + [ + [ + -0.015472412109375 + ], + [ + -0.01557159423828125 + ] + ], + [ + [ + -0.014190673828125 + ], + [ + 0.01348114013671875 + ] + ], + [ + [ + -0.01543426513671875 + ], + [ + -0.0142669677734375 + ] + ], + [ + [ + 0.014923095703125 + ], + [ + 0.01528167724609375 + ] + ], + [ + [ + 0.01294708251953125 + ], + [ + -0.014862060546875 + ] + ], + [ + [ + 0.01442718505859375 + ], + [ + -0.01514434814453125 + ] + ], + [ + [ + 0.01561737060546875 + ], + [ + -0.0137481689453125 + ] + ], + [ + [ + 0.0157928466796875 + ], + [ + 0.015838623046875 + ] + ], + [ + [ + -0.01508331298828125 + ], + [ + 0.0143585205078125 + ] + ], + [ + [ + 0.01557159423828125 + ], + [ + 0.0131988525390625 + ] + ], + [ + [ + 0.01296234130859375 + ], + [ + -0.01441192626953125 + ] + ], + [ + [ + 0.014129638671875 + ], + [ + 0.0147552490234375 + ] + ], + [ + [ + -0.014892578125 + ], + [ + -0.01434326171875 + ] + ], + [ 
+ [ + -0.0155487060546875 + ], + [ + 0.0153961181640625 + ] + ], + [ + [ + 0.01314544677734375 + ], + [ + 0.01385498046875 + ] + ], + [ + [ + -0.013671875 + ], + [ + 0.015106201171875 + ] + ], + [ + [ + 0.012725830078125 + ], + [ + 0.01401519775390625 + ] + ], + [ + [ + 0.0154876708984375 + ], + [ + -0.01436614990234375 + ] + ], + [ + [ + -0.0135955810546875 + ], + [ + 0.0159149169921875 + ] + ], + [ + [ + 0.01509857177734375 + ], + [ + 0.015533447265625 + ] + ], + [ + [ + 0.01290130615234375 + ], + [ + -0.012908935546875 + ] + ], + [ + [ + 0.01514434814453125 + ], + [ + 0.0147247314453125 + ] + ], + [ + [ + 0.0133056640625 + ], + [ + -0.0161590576171875 + ] + ], + [ + [ + 0.01409912109375 + ], + [ + -0.01456451416015625 + ] + ], + [ + [ + 0.0138092041015625 + ], + [ + -0.0165863037109375 + ] + ], + [ + [ + 0.01416015625 + ], + [ + 0.01491546630859375 + ] + ], + [ + [ + -0.01523590087890625 + ], + [ + 0.0150909423828125 + ] + ], + [ + [ + -0.0140533447265625 + ], + [ + -0.01312255859375 + ] + ], + [ + [ + -0.01364898681640625 + ], + [ + 0.01268768310546875 + ] + ], + [ + [ + -0.01406097412109375 + ], + [ + 0.01497650146484375 + ] + ], + [ + [ + 0.0128326416015625 + ], + [ + 0.01483917236328125 + ] + ], + [ + [ + -0.0146026611328125 + ], + [ + -0.01520538330078125 + ] + ], + [ + [ + 0.0151214599609375 + ], + [ + 0.0113372802734375 + ] + ], + [ + [ + -0.0147857666015625 + ], + [ + -0.015716552734375 + ] + ], + [ + [ + 0.01318359375 + ], + [ + -0.01543426513671875 + ] + ], + [ + [ + 0.01508331298828125 + ], + [ + -0.01529693603515625 + ] + ], + [ + [ + 0.01462554931640625 + ], + [ + -0.01311492919921875 + ] + ], + [ + [ + -0.0139007568359375 + ], + [ + 0.01496124267578125 + ] + ], + [ + [ + 0.0155792236328125 + ], + [ + 0.015899658203125 + ] + ], + [ + [ + 0.01395416259765625 + ], + [ + 0.0123291015625 + ] + ], + [ + [ + -0.01465606689453125 + ], + [ + -0.0148162841796875 + ] + ], + [ + [ + 0.01617431640625 + ], + [ + -0.0152130126953125 + ] + ], + [ + [ + -0.01348876953125 + ], + [ + -0.0154571533203125 + ] + ], + [ + [ + -0.01517486572265625 + ], + [ + -0.0145263671875 + ] + ], + [ + [ + 0.01546478271484375 + ], + [ + -0.01513671875 + ] + ], + [ + [ + 0.015594482421875 + ], + [ + -0.01428985595703125 + ] + ], + [ + [ + 0.0152435302734375 + ], + [ + -0.0138092041015625 + ] + ], + [ + [ + 0.0145263671875 + ], + [ + 0.01174163818359375 + ] + ], + [ + [ + -0.0145263671875 + ], + [ + 0.01326751708984375 + ] + ], + [ + [ + -0.01523590087890625 + ], + [ + 0.0143585205078125 + ] + ], + [ + [ + -0.01380157470703125 + ], + [ + -0.01544189453125 + ] + ], + [ + [ + 0.012054443359375 + ], + [ + -0.01401519775390625 + ] + ], + [ + [ + 0.01190185546875 + ], + [ + -0.016571044921875 + ] + ], + [ + [ + -0.01470184326171875 + ], + [ + -0.0139007568359375 + ] + ], + [ + [ + 0.013427734375 + ], + [ + -0.0148773193359375 + ] + ], + [ + [ + -0.01534271240234375 + ], + [ + 0.01479339599609375 + ] + ], + [ + [ + 0.01433563232421875 + ], + [ + 0.01558685302734375 + ] + ], + [ + [ + 0.0167999267578125 + ], + [ + -0.01342010498046875 + ] + ], + [ + [ + -0.0141754150390625 + ], + [ + -0.01506805419921875 + ] + ], + [ + [ + 0.01541900634765625 + ], + [ + -0.01486968994140625 + ] + ], + [ + [ + -0.01505279541015625 + ], + [ + 0.015533447265625 + ] + ], + [ + [ + 0.013519287109375 + ], + [ + 0.014434814453125 + ] + ], + [ + [ + 0.0151824951171875 + ], + [ + -0.01277923583984375 + ] + ], + [ + [ + 0.01611328125 + ], + [ + -0.0157470703125 + ] + ], + [ + [ + 0.01448822021484375 + ], + [ + 0.01453399658203125 + ] + ], + [ + 
[ + 0.0153350830078125 + ], + [ + 0.0142059326171875 + ] + ], + [ + [ + -0.014190673828125 + ], + [ + -0.013946533203125 + ] + ], + [ + [ + 0.014923095703125 + ], + [ + -0.01447296142578125 + ] + ], + [ + [ + 0.014495849609375 + ], + [ + 0.014404296875 + ] + ], + [ + [ + 0.016204833984375 + ], + [ + 0.015594482421875 + ] + ], + [ + [ + 0.01555633544921875 + ], + [ + -0.01470947265625 + ] + ], + [ + [ + -0.01280975341796875 + ], + [ + 0.0138092041015625 + ] + ], + [ + [ + -0.0149383544921875 + ], + [ + 0.0152587890625 + ] + ], + [ + [ + -0.0153961181640625 + ], + [ + -0.01477813720703125 + ] + ], + [ + [ + -0.01474761962890625 + ], + [ + -0.0145111083984375 + ] + ], + [ + [ + -0.01343536376953125 + ], + [ + 0.013824462890625 + ] + ], + [ + [ + 0.0166778564453125 + ], + [ + 0.014190673828125 + ] + ], + [ + [ + 0.01358795166015625 + ], + [ + 0.015838623046875 + ] + ], + [ + [ + -0.01520538330078125 + ], + [ + 0.01334381103515625 + ] + ], + [ + [ + -0.01416015625 + ], + [ + 0.013824462890625 + ] + ], + [ + [ + -0.01309967041015625 + ], + [ + 0.0156402587890625 + ] + ], + [ + [ + 0.0135955810546875 + ], + [ + 0.0158538818359375 + ] + ], + [ + [ + -0.01552581787109375 + ], + [ + 0.0140228271484375 + ] + ], + [ + [ + 0.01302337646484375 + ], + [ + -0.01416015625 + ] + ], + [ + [ + 0.013336181640625 + ], + [ + 0.01395416259765625 + ] + ], + [ + [ + 0.0152435302734375 + ], + [ + 0.01525115966796875 + ] + ], + [ + [ + 0.014373779296875 + ], + [ + 0.0148162841796875 + ] + ], + [ + [ + -0.01324462890625 + ], + [ + -0.01549530029296875 + ] + ], + [ + [ + 0.0152740478515625 + ], + [ + -0.01324462890625 + ] + ], + [ + [ + 0.015655517578125 + ], + [ + 0.01544952392578125 + ] + ], + [ + [ + 0.0155792236328125 + ], + [ + -0.01348114013671875 + ] + ], + [ + [ + -0.01526641845703125 + ], + [ + -0.0137176513671875 + ] + ], + [ + [ + -0.0145721435546875 + ], + [ + 0.01506805419921875 + ] + ], + [ + [ + -0.01479339599609375 + ], + [ + 0.0142059326171875 + ] + ], + [ + [ + 0.0159912109375 + ], + [ + 0.015106201171875 + ] + ], + [ + [ + 0.0155029296875 + ], + [ + -0.01354217529296875 + ] + ], + [ + [ + -0.01551055908203125 + ], + [ + 0.0157012939453125 + ] + ], + [ + [ + -0.0138397216796875 + ], + [ + 0.01361083984375 + ] + ], + [ + [ + 0.0142669677734375 + ], + [ + -0.01470947265625 + ] + ], + [ + [ + -0.0139312744140625 + ], + [ + -0.01308441162109375 + ] + ], + [ + [ + 0.01525115966796875 + ], + [ + -0.015869140625 + ] + ], + [ + [ + 0.01397705078125 + ], + [ + 0.01459503173828125 + ] + ], + [ + [ + -0.015838623046875 + ], + [ + -0.01488494873046875 + ] + ], + [ + [ + 0.01422882080078125 + ], + [ + -0.01251220703125 + ] + ], + [ + [ + -0.0144805908203125 + ], + [ + -0.013336181640625 + ] + ], + [ + [ + 0.01526641845703125 + ], + [ + -0.0143585205078125 + ] + ] + ], + "asymmetric_weights_decompressor_mlp_up_proj_weight_0": [ + [ + 0.0009670257568359375 + ], + [ + 0.0009627342224121094 + ], + [ + 0.0009636878967285156 + ], + [ + 0.0009675025939941406 + ], + [ + 0.0009679794311523438 + ], + [ + 0.00096893310546875 + ], + [ + 0.0009675025939941406 + ], + [ + 0.0009493827819824219 + ], + [ + 0.00096893310546875 + ], + [ + 0.0009484291076660156 + ], + [ + 0.000949859619140625 + ], + [ + 0.0009398460388183594 + ], + [ + 0.0009174346923828125 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009403228759765625 + ], + [ + 0.0009741783142089844 + ], + [ + 0.0008835792541503906 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009579658508300781 + ], + [ + 
0.00091552734375 + ], + [ + 0.0009326934814453125 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009627342224121094 + ], + [ + 0.0009765625 + ], + [ + 0.0009531974792480469 + ], + [ + 0.0009503364562988281 + ], + [ + 0.0009765625 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009169578552246094 + ], + [ + 0.0009493827819824219 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009741783142089844 + ], + [ + 0.0009379386901855469 + ], + [ + 0.0009403228759765625 + ], + [ + 0.00092315673828125 + ], + [ + 0.0009717941284179688 + ], + [ + 0.000911712646484375 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009360313415527344 + ], + [ + 0.0009765625 + ], + [ + 0.000972747802734375 + ], + [ + 0.0009288787841796875 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009121894836425781 + ], + [ + 0.0009512901306152344 + ], + [ + 0.0009007453918457031 + ], + [ + 0.0009570121765136719 + ], + [ + 0.0009136199951171875 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0009584426879882812 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009465217590332031 + ], + [ + 0.0009765625 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009670257568359375 + ], + [ + 0.0009775161743164062 + ], + [ + 0.0009627342224121094 + ], + [ + 0.0009360313415527344 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009737014770507812 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009551048278808594 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009765625 + ], + [ + 0.0008749961853027344 + ], + [ + 0.0009751319885253906 + ], + [ + 0.0009322166442871094 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009331703186035156 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009756088256835938 + ], + [ + 0.0009636878967285156 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009760856628417969 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009207725524902344 + ], + [ + 0.0009741783142089844 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0009217262268066406 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009579658508300781 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009641647338867188 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009646415710449219 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009202957153320312 + ], + [ + 0.0009560585021972656 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009784698486328125 + ], + [ + 0.00089263916015625 + ], + [ + 0.0009698867797851562 + ], + [ + 0.0009479522705078125 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0009250640869140625 + ], + [ + 0.000946044921875 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009794235229492188 + ], + [ + 0.0009698867797851562 + ], + [ + 0.0009255409240722656 + ], + [ + 0.0009746551513671875 + ], + [ + 0.0009427070617675781 + ], + [ + 0.0009365081787109375 + ], + [ + 0.0009627342224121094 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009694099426269531 + ], + [ + 0.0009379386901855469 + ], + [ + 0.0009407997131347656 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009360313415527344 + ], + [ + 0.0009374618530273438 + ], + [ + 0.000934600830078125 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009431838989257812 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009570121765136719 + ], + [ + 0.0009522438049316406 + ], + [ + 0.0009083747863769531 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009508132934570312 + ], + [ + 0.0009775161743164062 + ], + [ + 0.0009584426879882812 + ] + ], + "asymmetric_weights_decompressor_mlp_down_proj_weight_0": [ + [ + 0.0006890296936035156 + 
], + [ + 0.0006866455078125 + ], + [ + 0.0006794929504394531 + ], + [ + 0.0006856918334960938 + ], + [ + 0.0006885528564453125 + ], + [ + 0.0006775856018066406 + ], + [ + 0.00067901611328125 + ], + [ + 0.0006818771362304688 + ], + [ + 0.0006885528564453125 + ], + [ + 0.0006928443908691406 + ], + [ + 0.0006814002990722656 + ], + [ + 0.000690460205078125 + ], + [ + 0.0006756782531738281 + ], + [ + 0.0006895065307617188 + ], + [ + 0.0006847381591796875 + ], + [ + 0.0006761550903320312 + ], + [ + 0.0006814002990722656 + ], + [ + 0.0006885528564453125 + ], + [ + 0.000682830810546875 + ], + [ + 0.0006794929504394531 + ], + [ + 0.0006899833679199219 + ], + [ + 0.0006809234619140625 + ], + [ + 0.0006785392761230469 + ], + [ + 0.0006670951843261719 + ], + [ + 0.0006914138793945312 + ], + [ + 0.0006780624389648438 + ], + [ + 0.0006856918334960938 + ], + [ + 0.0006742477416992188 + ], + [ + 0.000690460205078125 + ], + [ + 0.0006909370422363281 + ], + [ + 0.0006885528564453125 + ], + [ + 0.0006866455078125 + ], + [ + 0.0006842613220214844 + ], + [ + 0.0006880760192871094 + ], + [ + 0.0006861686706542969 + ], + [ + 0.0006861686706542969 + ], + [ + 0.0006804466247558594 + ], + [ + 0.0006866455078125 + ], + [ + 0.0006761550903320312 + ], + [ + 0.0006871223449707031 + ], + [ + 0.0006875991821289062 + ], + [ + 0.0006780624389648438 + ], + [ + 0.0006880760192871094 + ], + [ + 0.0006909370422363281 + ], + [ + 0.0006718635559082031 + ], + [ + 0.0006723403930664062 + ], + [ + 0.0006895065307617188 + ], + [ + 0.0006694793701171875 + ], + [ + 0.0006737709045410156 + ], + [ + 0.0006885528564453125 + ], + [ + 0.0006785392761230469 + ], + [ + 0.0006885528564453125 + ], + [ + 0.0006804466247558594 + ], + [ + 0.0006866455078125 + ], + [ + 0.0006666183471679688 + ], + [ + 0.0006909370422363281 + ], + [ + 0.0006833076477050781 + ], + [ + 0.0006875991821289062 + ], + [ + 0.0006818771362304688 + ], + [ + 0.0006794929504394531 + ], + [ + 0.0006918907165527344 + ], + [ + 0.0006780624389648438 + ], + [ + 0.0006914138793945312 + ], + [ + 0.0006756782531738281 + ] + ] +} \ No newline at end of file diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False_ref_wc_param.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False_ref_wc_param.json new file mode 100644 index 00000000000..7cfdf2719df --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_False_ref_wc_param.json @@ -0,0 +1,128 @@ +[ + { + "weight_name": "q_proj_weight", + "node_with_weight": "linear", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "k_proj_weight", + "node_with_weight": "linear_1", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "v_proj_weight", + "node_with_weight": "linear_2", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "o_proj_weight", + "node_with_weight": "linear_3", + "weight_port_id": 
1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "mlp_gate_proj_weight", + "node_with_weight": "linear_4", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 128, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "mlp_up_proj_weight", + "node_with_weight": "linear_5", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 128, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "mlp_down_proj_weight", + "node_with_weight": "linear_6", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 128 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int8_asym", + "group_size": -1, + "codebook_values": null + } + } +] \ No newline at end of file diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True.dot b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True.dot new file mode 100644 index 00000000000..31fb9463c88 --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True.dot @@ -0,0 +1,169 @@ +strict digraph { +"0 attn_norm_weight" [id=0, type="get_attr"]; +"1 mlp_norm_weight" [id=1, type="get_attr"]; +"2 rope_cos" [id=2, type="get_attr"]; +"3 rope_sin" [id=3, type="get_attr"]; +"4 x_embed" [id=4, type=input]; +"5 arange" [id=5, type=arange]; +"6 _assert_tensor_metadata_default" [id=6, type="_assert_tensor_metadata"]; +"7 to" [id=7, type=to]; +"8 pow_1" [id=8, type=pow]; +"9 mean" [id=9, type=mean]; +"10 add" [id=10, type=add]; +"11 rsqrt" [id=11, type=rsqrt]; +"12 mul" [id=12, type=mul]; +"13 _assert_tensor_metadata_default_1" [id=13, type="_assert_tensor_metadata"]; +"14 to_1" [id=14, type=to]; +"15 mul_1" [id=15, type=mul]; +"16 q_proj_weight_updated_constant0" [id=16, type="get_attr"]; +"17 symmetric_weights_decompressor_q_proj_weight_0" [id=17, type="call_module"]; +"18 linear" [id=18, type=linear]; +"19 view" [id=19, type=view]; +"20 transpose" [id=20, type=transpose]; +"21 k_proj_weight_updated_constant0" [id=21, type="get_attr"]; +"22 symmetric_weights_decompressor_k_proj_weight_0" [id=22, type="call_module"]; +"23 linear_1" [id=23, type=linear]; +"24 view_1" [id=24, type=view]; +"25 transpose_1" [id=25, type=transpose]; +"26 v_proj_weight_updated_constant0" [id=26, type="get_attr"]; +"27 symmetric_weights_decompressor_v_proj_weight_0" [id=27, type="call_module"]; +"28 linear_2" [id=28, type=linear]; +"29 view_2" [id=29, type=view]; +"30 transpose_2" [id=30, type=transpose]; +"31 index" [id=31, type=index]; +"32 index_1" [id=32, type=index]; +"33 mul_2" [id=33, type=mul]; +"34 slice_1" [id=34, type=slice]; +"35 slice_2" [id=35, type=slice]; +"36 neg" [id=36, type=neg]; +"37 cat" [id=37, type=cat]; +"38 mul_3" [id=38, type=mul]; +"39 add_1" [id=39, type=add]; +"40 mul_4" [id=40, type=mul]; +"41 slice_3" [id=41, type=slice]; +"42 slice_4" [id=42, type=slice]; +"43 neg_1" [id=43, type=neg]; +"44 cat_1" [id=44, type=cat]; +"45 mul_5" [id=45, type=mul]; +"46 add_2" [id=46, type=add]; +"47 scaled_dot_product_attention" [id=47, 
type="scaled_dot_product_attention"]; +"48 transpose_3" [id=48, type=transpose]; +"49 view_3" [id=49, type=view]; +"50 o_proj_weight_updated_constant0" [id=50, type="get_attr"]; +"51 symmetric_weights_decompressor_o_proj_weight_0" [id=51, type="call_module"]; +"52 linear_3" [id=52, type=linear]; +"53 add_3" [id=53, type=add]; +"54 _assert_tensor_metadata_default_2" [id=54, type="_assert_tensor_metadata"]; +"55 to_2" [id=55, type=to]; +"56 pow_2" [id=56, type=pow]; +"57 mean_1" [id=57, type=mean]; +"58 add_4" [id=58, type=add]; +"59 rsqrt_1" [id=59, type=rsqrt]; +"60 mul_6" [id=60, type=mul]; +"61 _assert_tensor_metadata_default_3" [id=61, type="_assert_tensor_metadata"]; +"62 to_3" [id=62, type=to]; +"63 mul_7" [id=63, type=mul]; +"64 mlp_gate_proj_weight_updated_constant0" [id=64, type="get_attr"]; +"65 asymmetric_weights_decompressor_mlp_gate_proj_weight_0" [id=65, type="call_module"]; +"66 linear_4" [id=66, type=linear]; +"67 silu" [id=67, type=silu]; +"68 mlp_up_proj_weight_updated_constant0" [id=68, type="get_attr"]; +"69 asymmetric_weights_decompressor_mlp_up_proj_weight_0" [id=69, type="call_module"]; +"70 linear_5" [id=70, type=linear]; +"71 mul_8" [id=71, type=mul]; +"72 mlp_down_proj_weight_updated_constant0" [id=72, type="get_attr"]; +"73 symmetric_weights_decompressor_mlp_down_proj_weight_0" [id=73, type="call_module"]; +"74 linear_6" [id=74, type=linear]; +"75 add_5" [id=75, type=add]; +"76 output" [id=76, type=output]; +"0 attn_norm_weight" -> "15 mul_1" [style=solid, label="(64,)"]; +"1 mlp_norm_weight" -> "63 mul_7" [style=solid, label="(64,)"]; +"2 rope_cos" -> "31 index" [style=solid, label="(1, 1, 128, 16)"]; +"3 rope_sin" -> "32 index_1" [style=solid, label="(1, 1, 128, 16)"]; +"4 x_embed" -> "6 _assert_tensor_metadata_default" [style=solid, label="(1, 3, 64)"]; +"4 x_embed" -> "7 to" [style=solid, label="(1, 3, 64)"]; +"4 x_embed" -> "53 add_3" [style=solid, label="(1, 3, 64)"]; +"5 arange" -> "31 index" [style=solid, label="(3,)"]; +"5 arange" -> "32 index_1" [style=solid, label="(3,)"]; +"7 to" -> "8 pow_1" [style=solid, label="(1, 3, 64)"]; +"7 to" -> "12 mul" [style=solid, label="(1, 3, 64)"]; +"8 pow_1" -> "9 mean" [style=solid, label="(1, 3, 64)"]; +"9 mean" -> "10 add" [style=solid, label="(1, 3, 1)"]; +"10 add" -> "11 rsqrt" [style=solid, label="(1, 3, 1)"]; +"11 rsqrt" -> "12 mul" [style=solid, label="(1, 3, 1)"]; +"12 mul" -> "13 _assert_tensor_metadata_default_1" [style=solid, label="(1, 3, 64)"]; +"12 mul" -> "14 to_1" [style=solid, label="(1, 3, 64)"]; +"14 to_1" -> "15 mul_1" [style=solid, label="(1, 3, 64)"]; +"15 mul_1" -> "18 linear" [style=solid, label="(1, 3, 64)"]; +"15 mul_1" -> "23 linear_1" [style=solid, label="(1, 3, 64)"]; +"15 mul_1" -> "28 linear_2" [style=solid, label="(1, 3, 64)"]; +"16 q_proj_weight_updated_constant0" -> "17 symmetric_weights_decompressor_q_proj_weight_0" [style=solid, label="(2048, 1)"]; +"17 symmetric_weights_decompressor_q_proj_weight_0" -> "18 linear" [style=solid, label="(64, 64)"]; +"18 linear" -> "19 view" [style=solid, label="(1, 3, 64)"]; +"19 view" -> "20 transpose" [style=solid, label="(1, 3, 4, 16)"]; +"20 transpose" -> "33 mul_2" [style=solid, label="(1, 4, 3, 16)"]; +"20 transpose" -> "34 slice_1" [style=solid, label="(1, 4, 3, 16)"]; +"20 transpose" -> "35 slice_2" [style=solid, label="(1, 4, 3, 16)"]; +"21 k_proj_weight_updated_constant0" -> "22 symmetric_weights_decompressor_k_proj_weight_0" [style=solid, label="(2048, 1)"]; +"22 symmetric_weights_decompressor_k_proj_weight_0" -> "23 linear_1" 
[style=solid, label="(64, 64)"]; +"23 linear_1" -> "24 view_1" [style=solid, label="(1, 3, 64)"]; +"24 view_1" -> "25 transpose_1" [style=solid, label="(1, 3, 4, 16)"]; +"25 transpose_1" -> "40 mul_4" [style=solid, label="(1, 4, 3, 16)"]; +"25 transpose_1" -> "41 slice_3" [style=solid, label="(1, 4, 3, 16)"]; +"25 transpose_1" -> "42 slice_4" [style=solid, label="(1, 4, 3, 16)"]; +"26 v_proj_weight_updated_constant0" -> "27 symmetric_weights_decompressor_v_proj_weight_0" [style=solid, label="(2048, 1)"]; +"27 symmetric_weights_decompressor_v_proj_weight_0" -> "28 linear_2" [style=solid, label="(64, 64)"]; +"28 linear_2" -> "29 view_2" [style=solid, label="(1, 3, 64)"]; +"29 view_2" -> "30 transpose_2" [style=solid, label="(1, 3, 4, 16)"]; +"30 transpose_2" -> "47 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"31 index" -> "33 mul_2" [style=solid, label="(1, 1, 3, 16)"]; +"31 index" -> "40 mul_4" [style=solid, label="(1, 1, 3, 16)"]; +"32 index_1" -> "38 mul_3" [style=solid, label="(1, 1, 3, 16)"]; +"32 index_1" -> "45 mul_5" [style=solid, label="(1, 1, 3, 16)"]; +"33 mul_2" -> "39 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"34 slice_1" -> "37 cat" [style=solid, label="(1, 4, 3, 8)"]; +"35 slice_2" -> "36 neg" [style=solid, label="(1, 4, 3, 8)"]; +"36 neg" -> "37 cat" [style=solid, label="(1, 4, 3, 8)"]; +"37 cat" -> "38 mul_3" [style=solid, label="(1, 4, 3, 16)"]; +"38 mul_3" -> "39 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"39 add_1" -> "47 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"40 mul_4" -> "46 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"41 slice_3" -> "44 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"42 slice_4" -> "43 neg_1" [style=solid, label="(1, 4, 3, 8)"]; +"43 neg_1" -> "44 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"44 cat_1" -> "45 mul_5" [style=solid, label="(1, 4, 3, 16)"]; +"45 mul_5" -> "46 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"46 add_2" -> "47 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"47 scaled_dot_product_attention" -> "48 transpose_3" [style=solid, label="(1, 4, 3, 16)"]; +"48 transpose_3" -> "49 view_3" [style=solid, label="(1, 3, 4, 16)"]; +"49 view_3" -> "52 linear_3" [style=solid, label="(1, 3, 64)"]; +"50 o_proj_weight_updated_constant0" -> "51 symmetric_weights_decompressor_o_proj_weight_0" [style=solid, label="(2048, 1)"]; +"51 symmetric_weights_decompressor_o_proj_weight_0" -> "52 linear_3" [style=solid, label="(64, 64)"]; +"52 linear_3" -> "53 add_3" [style=solid, label="(1, 3, 64)"]; +"53 add_3" -> "54 _assert_tensor_metadata_default_2" [style=solid, label="(1, 3, 64)"]; +"53 add_3" -> "55 to_2" [style=solid, label="(1, 3, 64)"]; +"55 to_2" -> "56 pow_2" [style=solid, label="(1, 3, 64)"]; +"55 to_2" -> "60 mul_6" [style=solid, label="(1, 3, 64)"]; +"55 to_2" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"56 pow_2" -> "57 mean_1" [style=solid, label="(1, 3, 64)"]; +"57 mean_1" -> "58 add_4" [style=solid, label="(1, 3, 1)"]; +"58 add_4" -> "59 rsqrt_1" [style=solid, label="(1, 3, 1)"]; +"59 rsqrt_1" -> "60 mul_6" [style=solid, label="(1, 3, 1)"]; +"60 mul_6" -> "61 _assert_tensor_metadata_default_3" [style=solid, label="(1, 3, 64)"]; +"60 mul_6" -> "62 to_3" [style=solid, label="(1, 3, 64)"]; +"62 to_3" -> "63 mul_7" [style=solid, label="(1, 3, 64)"]; +"63 mul_7" -> "66 linear_4" [style=solid, label="(1, 3, 64)"]; +"63 mul_7" -> "70 linear_5" [style=solid, label="(1, 3, 64)"]; +"64 mlp_gate_proj_weight_updated_constant0" -> "65 
asymmetric_weights_decompressor_mlp_gate_proj_weight_0" [style=solid, label="(128, 64)"]; +"65 asymmetric_weights_decompressor_mlp_gate_proj_weight_0" -> "66 linear_4" [style=solid, label="(128, 64)"]; +"66 linear_4" -> "67 silu" [style=solid, label="(1, 3, 128)"]; +"67 silu" -> "71 mul_8" [style=solid, label="(1, 3, 128)"]; +"68 mlp_up_proj_weight_updated_constant0" -> "69 asymmetric_weights_decompressor_mlp_up_proj_weight_0" [style=solid, label="(128, 64)"]; +"69 asymmetric_weights_decompressor_mlp_up_proj_weight_0" -> "70 linear_5" [style=solid, label="(128, 64)"]; +"70 linear_5" -> "71 mul_8" [style=solid, label="(1, 3, 128)"]; +"71 mul_8" -> "74 linear_6" [style=solid, label="(1, 3, 128)"]; +"72 mlp_down_proj_weight_updated_constant0" -> "73 symmetric_weights_decompressor_mlp_down_proj_weight_0" [style=solid, label="(4096, 1)"]; +"73 symmetric_weights_decompressor_mlp_down_proj_weight_0" -> "74 linear_6" [style=solid, label="(64, 128)"]; +"74 linear_6" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"75 add_5" -> "76 output" [style=solid, label="(1, 3, 64)"]; +} diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True_awq_True_scale_estimation_True_ref_wc_scales.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True_awq_True_scale_estimation_True_ref_wc_scales.json new file mode 100644 index 00000000000..74b808d1245 --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True_awq_True_scale_estimation_True_ref_wc_scales.json @@ -0,0 +1,3728 @@ +{ + "symmetric_weights_decompressor_q_proj_weight_0": [ + [ + [ + -0.0135650634765625 + ], + [ + -0.014251708984375 + ] + ], + [ + [ + 0.015411376953125 + ], + [ + -0.01557159423828125 + ] + ], + [ + [ + -0.01427459716796875 + ], + [ + -0.01519012451171875 + ] + ], + [ + [ + 0.01224517822265625 + ], + [ + 0.01528167724609375 + ] + ], + [ + [ + 0.01348876953125 + ], + [ + -0.01556396484375 + ] + ], + [ + [ + 0.01401519775390625 + ], + [ + 0.0153350830078125 + ] + ], + [ + [ + 0.01560211181640625 + ], + [ + 0.01523590087890625 + ] + ], + [ + [ + 0.0160675048828125 + ], + [ + -0.0149993896484375 + ] + ], + [ + [ + -0.01445770263671875 + ], + [ + 0.01251220703125 + ] + ], + [ + [ + -0.0155487060546875 + ], + [ + 0.0126495361328125 + ] + ], + [ + [ + -0.01306915283203125 + ], + [ + -0.01422882080078125 + ] + ], + [ + [ + -0.0154571533203125 + ], + [ + -0.015594482421875 + ] + ], + [ + [ + -0.01308441162109375 + ], + [ + -0.01392364501953125 + ] + ], + [ + [ + -0.01512908935546875 + ], + [ + 0.01303863525390625 + ] + ], + [ + [ + -0.0134124755859375 + ], + [ + -0.0130462646484375 + ] + ], + [ + [ + 0.01467132568359375 + ], + [ + 0.01568603515625 + ] + ], + [ + [ + -0.01470184326171875 + ], + [ + 0.0147705078125 + ] + ], + [ + [ + -0.01477813720703125 + ], + [ + -0.01468658447265625 + ] + ], + [ + [ + -0.0128021240234375 + ], + [ + 0.0122833251953125 + ] + ], + [ + [ + -0.0158233642578125 + ], + [ + 0.0131378173828125 + ] + ], + [ + [ + -0.01537322998046875 + ], + [ + -0.01543426513671875 + ] + ], + [ + [ + -0.01548004150390625 + ], + [ + 0.01435089111328125 + ] + ], + [ + [ + -0.0130462646484375 + ], + [ + 0.01294708251953125 + ] + ], + [ + [ + 0.01348876953125 + ], + [ + 0.01446533203125 + ] + ], + [ + [ + -0.0157470703125 + ], + [ + -0.014892578125 + ] + ], + [ + [ + 0.01482391357421875 + ], + [ + -0.01473236083984375 + ] + ], + [ + [ + 0.0155029296875 + ], + [ + 
0.0135040283203125 + ] + ], + [ + [ + 0.01456451416015625 + ], + [ + -0.016326904296875 + ] + ], + [ + [ + -0.01509857177734375 + ], + [ + -0.012847900390625 + ] + ], + [ + [ + -0.0152130126953125 + ], + [ + 0.01459503173828125 + ] + ], + [ + [ + -0.0153350830078125 + ], + [ + -0.01287078857421875 + ] + ], + [ + [ + 0.0136871337890625 + ], + [ + 0.014801025390625 + ] + ], + [ + [ + 0.01520538330078125 + ], + [ + -0.01514434814453125 + ] + ], + [ + [ + -0.01471710205078125 + ], + [ + 0.0155792236328125 + ] + ], + [ + [ + -0.01485443115234375 + ], + [ + 0.0147857666015625 + ] + ], + [ + [ + 0.01512908935546875 + ], + [ + -0.01381683349609375 + ] + ], + [ + [ + -0.015838623046875 + ], + [ + -0.01444244384765625 + ] + ], + [ + [ + -0.0146636962890625 + ], + [ + -0.01299285888671875 + ] + ], + [ + [ + -0.01495361328125 + ], + [ + -0.014801025390625 + ] + ], + [ + [ + -0.01396942138671875 + ], + [ + 0.0134124755859375 + ] + ], + [ + [ + -0.01490020751953125 + ], + [ + 0.015045166015625 + ] + ], + [ + [ + -0.01543426513671875 + ], + [ + 0.01514434814453125 + ] + ], + [ + [ + 0.01428985595703125 + ], + [ + 0.0141754150390625 + ] + ], + [ + [ + 0.014923095703125 + ], + [ + 0.01470947265625 + ] + ], + [ + [ + -0.01654052734375 + ], + [ + 0.01470947265625 + ] + ], + [ + [ + 0.0150299072265625 + ], + [ + 0.0132293701171875 + ] + ], + [ + [ + -0.0144500732421875 + ], + [ + -0.014556884765625 + ] + ], + [ + [ + -0.01354217529296875 + ], + [ + -0.01436614990234375 + ] + ], + [ + [ + 0.01250457763671875 + ], + [ + 0.014495849609375 + ] + ], + [ + [ + -0.01361846923828125 + ], + [ + -0.01445770263671875 + ] + ], + [ + [ + -0.0148162841796875 + ], + [ + 0.01213836669921875 + ] + ], + [ + [ + -0.0125274658203125 + ], + [ + -0.0152587890625 + ] + ], + [ + [ + -0.01308441162109375 + ], + [ + 0.01410675048828125 + ] + ], + [ + [ + -0.0150146484375 + ], + [ + 0.01324462890625 + ] + ], + [ + [ + -0.016021728515625 + ], + [ + 0.015289306640625 + ] + ], + [ + [ + -0.0143280029296875 + ], + [ + -0.0139617919921875 + ] + ], + [ + [ + -0.0147247314453125 + ], + [ + 0.0161590576171875 + ] + ], + [ + [ + -0.0119476318359375 + ], + [ + 0.0154571533203125 + ] + ], + [ + [ + -0.01476287841796875 + ], + [ + -0.0137176513671875 + ] + ], + [ + [ + 0.01558685302734375 + ], + [ + 0.013427734375 + ] + ], + [ + [ + -0.0167694091796875 + ], + [ + 0.01517486572265625 + ] + ], + [ + [ + 0.01235198974609375 + ], + [ + -0.01605224609375 + ] + ], + [ + [ + 0.015960693359375 + ], + [ + -0.015167236328125 + ] + ], + [ + [ + 0.01517486572265625 + ], + [ + 0.0162200927734375 + ] + ] + ], + "symmetric_weights_decompressor_k_proj_weight_0": [ + [ + [ + 0.0150604248046875 + ], + [ + -0.0138702392578125 + ] + ], + [ + [ + -0.01486968994140625 + ], + [ + -0.01424407958984375 + ] + ], + [ + [ + 0.01526641845703125 + ], + [ + -0.0126800537109375 + ] + ], + [ + [ + -0.01436614990234375 + ], + [ + -0.0157012939453125 + ] + ], + [ + [ + -0.01470947265625 + ], + [ + 0.013916015625 + ] + ], + [ + [ + -0.01371002197265625 + ], + [ + -0.01558685302734375 + ] + ], + [ + [ + 0.01265716552734375 + ], + [ + 0.01399993896484375 + ] + ], + [ + [ + -0.01520538330078125 + ], + [ + -0.01537322998046875 + ] + ], + [ + [ + 0.01538848876953125 + ], + [ + 0.0160064697265625 + ] + ], + [ + [ + -0.01537322998046875 + ], + [ + -0.01198577880859375 + ] + ], + [ + [ + -0.01551055908203125 + ], + [ + -0.01419830322265625 + ] + ], + [ + [ + -0.01544189453125 + ], + [ + -0.0127410888671875 + ] + ], + [ + [ + 0.014373779296875 + ], + [ + -0.01462554931640625 + ] + ], + [ + [ 
+ 0.01326751708984375 + ], + [ + -0.015716552734375 + ] + ], + [ + [ + -0.01415252685546875 + ], + [ + -0.01483917236328125 + ] + ], + [ + [ + -0.01505279541015625 + ], + [ + 0.0154571533203125 + ] + ], + [ + [ + 0.01538848876953125 + ], + [ + -0.016021728515625 + ] + ], + [ + [ + -0.013916015625 + ], + [ + -0.01514434814453125 + ] + ], + [ + [ + 0.01401519775390625 + ], + [ + -0.01239776611328125 + ] + ], + [ + [ + -0.01540374755859375 + ], + [ + -0.0133209228515625 + ] + ], + [ + [ + 0.014617919921875 + ], + [ + 0.01727294921875 + ] + ], + [ + [ + 0.0156707763671875 + ], + [ + -0.0155792236328125 + ] + ], + [ + [ + 0.01384735107421875 + ], + [ + 0.01262664794921875 + ] + ], + [ + [ + -0.0143890380859375 + ], + [ + 0.015106201171875 + ] + ], + [ + [ + 0.0154571533203125 + ], + [ + -0.01403045654296875 + ] + ], + [ + [ + 0.0149993896484375 + ], + [ + 0.012847900390625 + ] + ], + [ + [ + 0.01552581787109375 + ], + [ + -0.01554107666015625 + ] + ], + [ + [ + 0.01503753662109375 + ], + [ + 0.01519775390625 + ] + ], + [ + [ + 0.0144195556640625 + ], + [ + -0.01325225830078125 + ] + ], + [ + [ + -0.0159454345703125 + ], + [ + -0.01555633544921875 + ] + ], + [ + [ + -0.01416015625 + ], + [ + -0.01580810546875 + ] + ], + [ + [ + -0.01446533203125 + ], + [ + -0.01375579833984375 + ] + ], + [ + [ + 0.01214599609375 + ], + [ + -0.0137786865234375 + ] + ], + [ + [ + 0.01497650146484375 + ], + [ + 0.0144805908203125 + ] + ], + [ + [ + -0.01474761962890625 + ], + [ + -0.0155181884765625 + ] + ], + [ + [ + -0.01508331298828125 + ], + [ + -0.01496124267578125 + ] + ], + [ + [ + -0.01544189453125 + ], + [ + 0.014678955078125 + ] + ], + [ + [ + -0.01329803466796875 + ], + [ + -0.0157012939453125 + ] + ], + [ + [ + 0.01535797119140625 + ], + [ + -0.0161590576171875 + ] + ], + [ + [ + 0.01480865478515625 + ], + [ + -0.01407623291015625 + ] + ], + [ + [ + 0.01212310791015625 + ], + [ + 0.01406097412109375 + ] + ], + [ + [ + 0.012939453125 + ], + [ + 0.01445770263671875 + ] + ], + [ + [ + 0.01476287841796875 + ], + [ + -0.01544189453125 + ] + ], + [ + [ + 0.0135650634765625 + ], + [ + 0.01358795166015625 + ] + ], + [ + [ + -0.0150299072265625 + ], + [ + -0.014190673828125 + ] + ], + [ + [ + 0.01522064208984375 + ], + [ + 0.01520538330078125 + ] + ], + [ + [ + 0.0146942138671875 + ], + [ + -0.01531982421875 + ] + ], + [ + [ + 0.01305389404296875 + ], + [ + 0.0139312744140625 + ] + ], + [ + [ + 0.01507568359375 + ], + [ + -0.01461029052734375 + ] + ], + [ + [ + -0.015899658203125 + ], + [ + 0.01421356201171875 + ] + ], + [ + [ + 0.01385498046875 + ], + [ + 0.01284027099609375 + ] + ], + [ + [ + 0.01535797119140625 + ], + [ + 0.0152740478515625 + ] + ], + [ + [ + -0.0144805908203125 + ], + [ + 0.01386260986328125 + ] + ], + [ + [ + 0.0132598876953125 + ], + [ + -0.0147705078125 + ] + ], + [ + [ + -0.01397705078125 + ], + [ + 0.01549530029296875 + ] + ], + [ + [ + 0.0145111083984375 + ], + [ + -0.0167694091796875 + ] + ], + [ + [ + -0.0148773193359375 + ], + [ + 0.01532745361328125 + ] + ], + [ + [ + -0.0145263671875 + ], + [ + -0.01387786865234375 + ] + ], + [ + [ + 0.01473236083984375 + ], + [ + 0.016326904296875 + ] + ], + [ + [ + -0.01299285888671875 + ], + [ + 0.0149993896484375 + ] + ], + [ + [ + 0.013214111328125 + ], + [ + -0.01541900634765625 + ] + ], + [ + [ + -0.01316070556640625 + ], + [ + 0.0142822265625 + ] + ], + [ + [ + 0.01425933837890625 + ], + [ + -0.01212310791015625 + ] + ], + [ + [ + 0.0168914794921875 + ], + [ + -0.01407623291015625 + ] + ] + ], + 
"symmetric_weights_decompressor_v_proj_weight_0": [ + [ + [ + -0.0145721435546875 + ], + [ + -0.01470184326171875 + ] + ], + [ + [ + -0.01517486572265625 + ], + [ + -0.01496124267578125 + ] + ], + [ + [ + 0.013580322265625 + ], + [ + -0.0135040283203125 + ] + ], + [ + [ + 0.0142669677734375 + ], + [ + 0.014251708984375 + ] + ], + [ + [ + 0.0146942138671875 + ], + [ + 0.0164337158203125 + ] + ], + [ + [ + -0.0142364501953125 + ], + [ + -0.0138397216796875 + ] + ], + [ + [ + -0.0160064697265625 + ], + [ + 0.01447296142578125 + ] + ], + [ + [ + -0.01551055908203125 + ], + [ + -0.013824462890625 + ] + ], + [ + [ + -0.0135650634765625 + ], + [ + 0.0128326416015625 + ] + ], + [ + [ + -0.01386260986328125 + ], + [ + -0.0139312744140625 + ] + ], + [ + [ + -0.0142059326171875 + ], + [ + 0.01422119140625 + ] + ], + [ + [ + -0.01546478271484375 + ], + [ + -0.0157318115234375 + ] + ], + [ + [ + -0.01416015625 + ], + [ + -0.01371002197265625 + ] + ], + [ + [ + -0.0151519775390625 + ], + [ + 0.0147857666015625 + ] + ], + [ + [ + -0.0164031982421875 + ], + [ + -0.01531982421875 + ] + ], + [ + [ + -0.01323699951171875 + ], + [ + -0.01331329345703125 + ] + ], + [ + [ + 0.0156097412109375 + ], + [ + 0.01561737060546875 + ] + ], + [ + [ + 0.0145721435546875 + ], + [ + 0.0152587890625 + ] + ], + [ + [ + 0.01342010498046875 + ], + [ + 0.013824462890625 + ] + ], + [ + [ + 0.01375579833984375 + ], + [ + -0.012847900390625 + ] + ], + [ + [ + 0.015960693359375 + ], + [ + 0.0157623291015625 + ] + ], + [ + [ + 0.01479339599609375 + ], + [ + 0.012969970703125 + ] + ], + [ + [ + 0.0158233642578125 + ], + [ + -0.0147552490234375 + ] + ], + [ + [ + 0.0137481689453125 + ], + [ + 0.01409912109375 + ] + ], + [ + [ + -0.01373291015625 + ], + [ + -0.01508331298828125 + ] + ], + [ + [ + -0.01456451416015625 + ], + [ + 0.0151824951171875 + ] + ], + [ + [ + -0.01549530029296875 + ], + [ + 0.0151519775390625 + ] + ], + [ + [ + 0.012725830078125 + ], + [ + -0.01461029052734375 + ] + ], + [ + [ + -0.01531982421875 + ], + [ + 0.0142974853515625 + ] + ], + [ + [ + 0.01558685302734375 + ], + [ + 0.01357269287109375 + ] + ], + [ + [ + -0.01500701904296875 + ], + [ + -0.0123291015625 + ] + ], + [ + [ + -0.01526641845703125 + ], + [ + 0.0153961181640625 + ] + ], + [ + [ + 0.01474761962890625 + ], + [ + 0.0154876708984375 + ] + ], + [ + [ + -0.01513671875 + ], + [ + 0.015350341796875 + ] + ], + [ + [ + 0.0153961181640625 + ], + [ + 0.01528167724609375 + ] + ], + [ + [ + 0.0152435302734375 + ], + [ + 0.0153656005859375 + ] + ], + [ + [ + 0.0149993896484375 + ], + [ + -0.01336669921875 + ] + ], + [ + [ + 0.01336669921875 + ], + [ + 0.0147857666015625 + ] + ], + [ + [ + 0.01328277587890625 + ], + [ + -0.0137176513671875 + ] + ], + [ + [ + -0.01544952392578125 + ], + [ + 0.01535797119140625 + ] + ], + [ + [ + 0.0138702392578125 + ], + [ + -0.01288604736328125 + ] + ], + [ + [ + 0.01401519775390625 + ], + [ + -0.0158843994140625 + ] + ], + [ + [ + 0.01477813720703125 + ], + [ + 0.01238250732421875 + ] + ], + [ + [ + 0.01261138916015625 + ], + [ + -0.01371002197265625 + ] + ], + [ + [ + 0.01448822021484375 + ], + [ + -0.0145416259765625 + ] + ], + [ + [ + 0.01453399658203125 + ], + [ + 0.0154571533203125 + ] + ], + [ + [ + 0.014251708984375 + ], + [ + -0.0150604248046875 + ] + ], + [ + [ + -0.0154266357421875 + ], + [ + -0.0140228271484375 + ] + ], + [ + [ + 0.0145721435546875 + ], + [ + 0.015472412109375 + ] + ], + [ + [ + 0.01425933837890625 + ], + [ + -0.01351165771484375 + ] + ], + [ + [ + -0.01450347900390625 + ], + [ + 
-0.0159759521484375 + ] + ], + [ + [ + -0.01361083984375 + ], + [ + 0.01483917236328125 + ] + ], + [ + [ + -0.01447296142578125 + ], + [ + 0.01418304443359375 + ] + ], + [ + [ + -0.015106201171875 + ], + [ + 0.0139923095703125 + ] + ], + [ + [ + -0.014068603515625 + ], + [ + 0.01320648193359375 + ] + ], + [ + [ + -0.0155181884765625 + ], + [ + 0.01560211181640625 + ] + ], + [ + [ + -0.0155792236328125 + ], + [ + -0.0147247314453125 + ] + ], + [ + [ + 0.0147247314453125 + ], + [ + 0.0133209228515625 + ] + ], + [ + [ + 0.01415252685546875 + ], + [ + 0.0130615234375 + ] + ], + [ + [ + -0.01419830322265625 + ], + [ + -0.014251708984375 + ] + ], + [ + [ + -0.0134124755859375 + ], + [ + 0.01519775390625 + ] + ], + [ + [ + 0.01476287841796875 + ], + [ + 0.0138092041015625 + ] + ], + [ + [ + -0.0151824951171875 + ], + [ + 0.01494598388671875 + ] + ], + [ + [ + 0.015106201171875 + ], + [ + 0.01279449462890625 + ] + ] + ], + "symmetric_weights_decompressor_o_proj_weight_0": [ + [ + [ + 0.015625 + ], + [ + 0.014495849609375 + ] + ], + [ + [ + -0.01404571533203125 + ], + [ + -0.0152130126953125 + ] + ], + [ + [ + -0.01512908935546875 + ], + [ + 0.0160369873046875 + ] + ], + [ + [ + 0.01451873779296875 + ], + [ + -0.0155181884765625 + ] + ], + [ + [ + -0.01464080810546875 + ], + [ + -0.0139007568359375 + ] + ], + [ + [ + -0.0123138427734375 + ], + [ + 0.01412200927734375 + ] + ], + [ + [ + -0.01317596435546875 + ], + [ + 0.0151824951171875 + ] + ], + [ + [ + -0.01235198974609375 + ], + [ + -0.0142059326171875 + ] + ], + [ + [ + -0.0145263671875 + ], + [ + -0.0148162841796875 + ] + ], + [ + [ + 0.01427459716796875 + ], + [ + -0.01490020751953125 + ] + ], + [ + [ + 0.01490020751953125 + ], + [ + 0.01303863525390625 + ] + ], + [ + [ + 0.0155029296875 + ], + [ + -0.013946533203125 + ] + ], + [ + [ + 0.01409149169921875 + ], + [ + -0.01322174072265625 + ] + ], + [ + [ + 0.013427734375 + ], + [ + 0.0127716064453125 + ] + ], + [ + [ + 0.0142669677734375 + ], + [ + 0.01432037353515625 + ] + ], + [ + [ + -0.01528167724609375 + ], + [ + 0.01529693603515625 + ] + ], + [ + [ + 0.01393890380859375 + ], + [ + -0.01446533203125 + ] + ], + [ + [ + -0.01214599609375 + ], + [ + -0.01450347900390625 + ] + ], + [ + [ + 0.013275146484375 + ], + [ + -0.01328277587890625 + ] + ], + [ + [ + -0.01528167724609375 + ], + [ + -0.01406097412109375 + ] + ], + [ + [ + -0.01247406005859375 + ], + [ + -0.0160064697265625 + ] + ], + [ + [ + -0.01490020751953125 + ], + [ + -0.01470184326171875 + ] + ], + [ + [ + -0.01491546630859375 + ], + [ + -0.013702392578125 + ] + ], + [ + [ + -0.0145721435546875 + ], + [ + 0.01506805419921875 + ] + ], + [ + [ + -0.0150146484375 + ], + [ + 0.015380859375 + ] + ], + [ + [ + -0.0146484375 + ], + [ + 0.013946533203125 + ] + ], + [ + [ + 0.0121917724609375 + ], + [ + 0.01367950439453125 + ] + ], + [ + [ + -0.01552581787109375 + ], + [ + -0.015228271484375 + ] + ], + [ + [ + 0.0135650634765625 + ], + [ + -0.01288604736328125 + ] + ], + [ + [ + -0.015869140625 + ], + [ + 0.01409912109375 + ] + ], + [ + [ + -0.013946533203125 + ], + [ + -0.0148162841796875 + ] + ], + [ + [ + 0.01346588134765625 + ], + [ + -0.015533447265625 + ] + ], + [ + [ + 0.01334381103515625 + ], + [ + -0.0154571533203125 + ] + ], + [ + [ + -0.01387786865234375 + ], + [ + -0.0156707763671875 + ] + ], + [ + [ + 0.0160675048828125 + ], + [ + -0.0134429931640625 + ] + ], + [ + [ + 0.0123748779296875 + ], + [ + -0.01427459716796875 + ] + ], + [ + [ + -0.0137939453125 + ], + [ + 0.01299285888671875 + ] + ], + [ + [ + -0.015289306640625 + ], 
+ [ + -0.01548004150390625 + ] + ], + [ + [ + 0.0142059326171875 + ], + [ + 0.0158233642578125 + ] + ], + [ + [ + -0.01528167724609375 + ], + [ + -0.013824462890625 + ] + ], + [ + [ + -0.01453399658203125 + ], + [ + -0.0151519775390625 + ] + ], + [ + [ + -0.01526641845703125 + ], + [ + 0.0164337158203125 + ] + ], + [ + [ + 0.01546478271484375 + ], + [ + -0.01494598388671875 + ] + ], + [ + [ + -0.01458740234375 + ], + [ + -0.01313018798828125 + ] + ], + [ + [ + -0.0141448974609375 + ], + [ + -0.0145721435546875 + ] + ], + [ + [ + -0.0144500732421875 + ], + [ + -0.012664794921875 + ] + ], + [ + [ + 0.0151824951171875 + ], + [ + 0.0142822265625 + ] + ], + [ + [ + 0.01434326171875 + ], + [ + -0.0160675048828125 + ] + ], + [ + [ + 0.01505279541015625 + ], + [ + -0.0137939453125 + ] + ], + [ + [ + 0.01270294189453125 + ], + [ + -0.0133056640625 + ] + ], + [ + [ + -0.01343536376953125 + ], + [ + -0.01441192626953125 + ] + ], + [ + [ + 0.0150146484375 + ], + [ + 0.01453399658203125 + ] + ], + [ + [ + -0.016143798828125 + ], + [ + -0.01445770263671875 + ] + ], + [ + [ + -0.0134735107421875 + ], + [ + 0.01480865478515625 + ] + ], + [ + [ + -0.0162506103515625 + ], + [ + 0.0152130126953125 + ] + ], + [ + [ + -0.01522064208984375 + ], + [ + -0.01541900634765625 + ] + ], + [ + [ + -0.01448822021484375 + ], + [ + 0.01557159423828125 + ] + ], + [ + [ + -0.01395416259765625 + ], + [ + 0.01319122314453125 + ] + ], + [ + [ + -0.0153350830078125 + ], + [ + -0.01532745361328125 + ] + ], + [ + [ + 0.016265869140625 + ], + [ + -0.0161285400390625 + ] + ], + [ + [ + -0.0131988525390625 + ], + [ + 0.015350341796875 + ] + ], + [ + [ + 0.0146331787109375 + ], + [ + -0.01483917236328125 + ] + ], + [ + [ + -0.01554107666015625 + ], + [ + -0.01318359375 + ] + ], + [ + [ + 0.0138092041015625 + ], + [ + 0.01560211181640625 + ] + ] + ], + "asymmetric_weights_decompressor_mlp_gate_proj_weight_0": [ + [ + 0.0009555816650390625 + ], + [ + 0.0009636878967285156 + ], + [ + 0.00091552734375 + ], + [ + 0.0009512901306152344 + ], + [ + 0.000949859619140625 + ], + [ + 0.0009474754333496094 + ], + [ + 0.0009431838989257812 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009593963623046875 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009398460388183594 + ], + [ + 0.0009694099426269531 + ], + [ + 0.0009374618530273438 + ], + [ + 0.0009660720825195312 + ], + [ + 0.0009570121765136719 + ], + [ + 0.0009589195251464844 + ], + [ + 0.0009474754333496094 + ], + [ + 0.0009517669677734375 + ], + [ + 0.000942230224609375 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009641647338867188 + ], + [ + 0.0009756088256835938 + ], + [ + 0.0009708404541015625 + ], + [ + 0.0009584426879882812 + ], + [ + 0.0009531974792480469 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009632110595703125 + ], + [ + 0.0009751319885253906 + ], + [ + 0.0009646415710449219 + ], + [ + 0.0009503364562988281 + ], + [ + 0.0009660720825195312 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009551048278808594 + ], + [ + 0.0009632110595703125 + ], + [ + 0.00096893310546875 + ], + [ + 0.00087738037109375 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009393692016601562 + ], + [ + 0.0009708404541015625 + ], + [ + 0.0009555816650390625 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009641647338867188 + ], + [ + 0.0009632110595703125 + ], + [ + 0.000972747802734375 + ], + [ + 0.0009322166442871094 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009284019470214844 + ], + [ + 0.0009598731994628906 + ], + [ + 
0.0009293556213378906 + ], + [ + 0.00092315673828125 + ], + [ + 0.0008797645568847656 + ], + [ + 0.0009746551513671875 + ], + [ + 0.0009598731994628906 + ], + [ + 0.0009226799011230469 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009741783142089844 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009245872497558594 + ], + [ + 0.0009694099426269531 + ], + [ + 0.0009393692016601562 + ], + [ + 0.0009298324584960938 + ], + [ + 0.0009636878967285156 + ], + [ + 0.0009665489196777344 + ], + [ + 0.0009307861328125 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009760856628417969 + ], + [ + 0.0009551048278808594 + ], + [ + 0.0008802413940429688 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009632110595703125 + ], + [ + 0.0009469985961914062 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009508132934570312 + ], + [ + 0.0009350776672363281 + ], + [ + 0.0009427070617675781 + ], + [ + 0.0009551048278808594 + ], + [ + 0.0009765625 + ], + [ + 0.0009560585021972656 + ], + [ + 0.0009598731994628906 + ], + [ + 0.0009636878967285156 + ], + [ + 0.0009670257568359375 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009646415710449219 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009274482727050781 + ], + [ + 0.0009093284606933594 + ], + [ + 0.0009398460388183594 + ], + [ + 0.0009493827819824219 + ], + [ + 0.0009708404541015625 + ], + [ + 0.0009756088256835938 + ], + [ + 0.0009412765502929688 + ], + [ + 0.0009632110595703125 + ], + [ + 0.0009260177612304688 + ], + [ + 0.0009589195251464844 + ], + [ + 0.0009484291076660156 + ], + [ + 0.000957489013671875 + ], + [ + 0.0009765625 + ], + [ + 0.0009641647338867188 + ], + [ + 0.000965118408203125 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009660720825195312 + ], + [ + 0.0009703636169433594 + ], + [ + 0.000949859619140625 + ], + [ + 0.0009622573852539062 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009379386901855469 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0009646415710449219 + ], + [ + 0.0009708404541015625 + ], + [ + 0.0009355545043945312 + ], + [ + 0.0009694099426269531 + ], + [ + 0.000942230224609375 + ], + [ + 0.0009441375732421875 + ], + [ + 0.0009751319885253906 + ], + [ + 0.000957489013671875 + ], + [ + 0.0009441375732421875 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009417533874511719 + ], + [ + 0.0009288787841796875 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009632110595703125 + ] + ], + "asymmetric_weights_decompressor_mlp_up_proj_weight_updated_constant0_0": [ + [ + 0.0009670257568359375 + ], + [ + 0.0009627342224121094 + ], + [ + 0.0009636878967285156 + ], + [ + 0.0009675025939941406 + ], + [ + 0.0009679794311523438 + ], + [ + 0.00096893310546875 + ], + [ + 0.0009675025939941406 + ], + [ + 0.0009493827819824219 + ], + [ + 0.00096893310546875 + ], + [ + 0.0009484291076660156 + ], + [ + 0.000949859619140625 + ], + [ + 0.0009398460388183594 + ], + [ + 0.0009174346923828125 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009403228759765625 + ], + [ + 0.0009741783142089844 + ], + [ + 0.0008835792541503906 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009579658508300781 + ], + [ + 0.00091552734375 + ], + [ + 0.0009326934814453125 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009627342224121094 + ], + [ + 0.0009765625 + ], + [ + 0.0009531974792480469 + ], + [ + 0.0009503364562988281 + ], + [ + 0.0009765625 + ], + [ + 
0.0009579658508300781 + ], + [ + 0.0009169578552246094 + ], + [ + 0.0009493827819824219 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009741783142089844 + ], + [ + 0.0009379386901855469 + ], + [ + 0.0009403228759765625 + ], + [ + 0.00092315673828125 + ], + [ + 0.0009717941284179688 + ], + [ + 0.000911712646484375 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009360313415527344 + ], + [ + 0.0009765625 + ], + [ + 0.000972747802734375 + ], + [ + 0.0009288787841796875 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009121894836425781 + ], + [ + 0.0009512901306152344 + ], + [ + 0.0009007453918457031 + ], + [ + 0.0009570121765136719 + ], + [ + 0.0009136199951171875 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0009584426879882812 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009465217590332031 + ], + [ + 0.0009765625 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009670257568359375 + ], + [ + 0.0009775161743164062 + ], + [ + 0.0009627342224121094 + ], + [ + 0.0009360313415527344 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009737014770507812 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009551048278808594 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009765625 + ], + [ + 0.0008749961853027344 + ], + [ + 0.0009751319885253906 + ], + [ + 0.0009322166442871094 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009331703186035156 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009756088256835938 + ], + [ + 0.0009636878967285156 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009760856628417969 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009207725524902344 + ], + [ + 0.0009741783142089844 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0009217262268066406 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009579658508300781 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009641647338867188 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009646415710449219 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009202957153320312 + ], + [ + 0.0009560585021972656 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009784698486328125 + ], + [ + 0.00089263916015625 + ], + [ + 0.0009698867797851562 + ], + [ + 0.0013704299926757812 + ], + [ + 0.00206756591796875 + ], + [ + 0.0013513565063476562 + ], + [ + 0.00225830078125 + ], + [ + 0.0021381378173828125 + ], + [ + 0.0024662017822265625 + ], + [ + 0.0025844573974609375 + ], + [ + 0.0015583038330078125 + ], + [ + 0.0023174285888671875 + ], + [ + 0.0025119781494140625 + ], + [ + 0.002399444580078125 + ], + [ + 0.0020198822021484375 + ], + [ + 0.00145721435546875 + ], + [ + 0.0021514892578125 + ], + [ + 0.0019207000732421875 + ], + [ + 0.0019245147705078125 + ], + [ + 0.0016422271728515625 + ], + [ + 0.00133514404296875 + ], + [ + 0.0024929046630859375 + ], + [ + 0.0015106201171875 + ], + [ + 0.0017309188842773438 + ], + [ + 0.0017538070678710938 + ], + [ + 0.00246429443359375 + ], + [ + 0.0012035369873046875 + ], + [ + 0.002346038818359375 + ], + [ + 0.0008511543273925781 + ], + [ + 0.001300811767578125 + ], + [ + 0.0024204254150390625 + ], + [ + 0.002277374267578125 + ], + [ + 0.00124359130859375 + ], + [ + 0.0018281936645507812 + ], + [ + 0.0013427734375 + ] + ], + "symmetric_weights_decompressor_mlp_down_proj_weight_updated_constant0_0": [ + [ + [ + -0.01129150390625 + ], + [ + 0.01030731201171875 + ], + [ + -0.0095977783203125 + ], + [ + -0.00717926025390625 + ] + ], + [ + [ + -0.0103607177734375 + ], + [ + 0.010406494140625 + ], + [ + 0.010711669921875 + ], + [ + -0.00701904296875 + ] + ], + [ + [ + 
-0.0098419189453125 + ], + [ + 0.01096343994140625 + ], + [ + -0.01012420654296875 + ], + [ + 0.00954437255859375 + ] + ], + [ + [ + 0.0105133056640625 + ], + [ + 0.01090240478515625 + ], + [ + -0.010833740234375 + ], + [ + -0.00818634033203125 + ] + ], + [ + [ + 0.01097869873046875 + ], + [ + 0.0105438232421875 + ], + [ + -0.01099395751953125 + ], + [ + -0.00856781005859375 + ] + ], + [ + [ + 0.0101318359375 + ], + [ + 0.0116119384765625 + ], + [ + 0.00989532470703125 + ], + [ + 0.01172637939453125 + ] + ], + [ + [ + -0.00850677490234375 + ], + [ + 0.0114288330078125 + ], + [ + 0.01036834716796875 + ], + [ + 0.0076446533203125 + ] + ], + [ + [ + -0.0112457275390625 + ], + [ + -0.0092315673828125 + ], + [ + 0.00942230224609375 + ], + [ + 0.007320404052734375 + ] + ], + [ + [ + 0.0104522705078125 + ], + [ + -0.00957489013671875 + ], + [ + -0.01071929931640625 + ], + [ + -0.00634002685546875 + ] + ], + [ + [ + -0.0102996826171875 + ], + [ + 0.01103973388671875 + ], + [ + -0.009124755859375 + ], + [ + -0.00803375244140625 + ] + ], + [ + [ + -0.0095367431640625 + ], + [ + 0.00888824462890625 + ], + [ + 0.01154327392578125 + ], + [ + 0.00800323486328125 + ] + ], + [ + [ + 0.009674072265625 + ], + [ + -0.0116119384765625 + ], + [ + 0.0104522705078125 + ], + [ + 0.00786590576171875 + ] + ], + [ + [ + 0.0091705322265625 + ], + [ + 0.00913238525390625 + ], + [ + -0.01096343994140625 + ], + [ + -0.007678985595703125 + ] + ], + [ + [ + 0.01239013671875 + ], + [ + 0.009857177734375 + ], + [ + 0.01012420654296875 + ], + [ + -0.007171630859375 + ] + ], + [ + [ + 0.01021575927734375 + ], + [ + 0.00972747802734375 + ], + [ + -0.01096343994140625 + ], + [ + 0.00801849365234375 + ] + ], + [ + [ + -0.01032257080078125 + ], + [ + -0.01013946533203125 + ], + [ + -0.01071929931640625 + ], + [ + -0.0102691650390625 + ] + ], + [ + [ + -0.0106964111328125 + ], + [ + 0.00943756103515625 + ], + [ + 0.01076507568359375 + ], + [ + 0.00707244873046875 + ] + ], + [ + [ + -0.0108795166015625 + ], + [ + -0.010406494140625 + ], + [ + 0.0109710693359375 + ], + [ + 0.00952911376953125 + ] + ], + [ + [ + -0.009552001953125 + ], + [ + 0.01085662841796875 + ], + [ + -0.00939178466796875 + ], + [ + -0.01177215576171875 + ] + ], + [ + [ + -0.0090179443359375 + ], + [ + 0.00785064697265625 + ], + [ + 0.00989532470703125 + ], + [ + 0.0099334716796875 + ] + ], + [ + [ + 0.0109100341796875 + ], + [ + -0.01056671142578125 + ], + [ + 0.0117950439453125 + ], + [ + 0.0103607177734375 + ] + ], + [ + [ + 0.01050567626953125 + ], + [ + -0.0103912353515625 + ], + [ + 0.01074981689453125 + ], + [ + 0.007213592529296875 + ] + ], + [ + [ + -0.009979248046875 + ], + [ + -0.01123046875 + ], + [ + -0.0108489990234375 + ], + [ + -0.00695037841796875 + ] + ], + [ + [ + -0.01064300537109375 + ], + [ + -0.01023101806640625 + ], + [ + 0.00847625732421875 + ], + [ + 0.00609588623046875 + ] + ], + [ + [ + -0.0100250244140625 + ], + [ + 0.0110015869140625 + ], + [ + -0.009124755859375 + ], + [ + -0.007610321044921875 + ] + ], + [ + [ + 0.01087188720703125 + ], + [ + 0.01104736328125 + ], + [ + 0.01092529296875 + ], + [ + 0.008697509765625 + ] + ], + [ + [ + -0.0101470947265625 + ], + [ + 0.0101318359375 + ], + [ + -0.01070404052734375 + ], + [ + -0.007740020751953125 + ] + ], + [ + [ + 0.010467529296875 + ], + [ + -0.01071929931640625 + ], + [ + -0.01088714599609375 + ], + [ + 0.00823974609375 + ] + ], + [ + [ + 0.0109710693359375 + ], + [ + 0.01070404052734375 + ], + [ + -0.0088653564453125 + ], + [ + -0.0120391845703125 + ] + ], + [ + [ + 
-0.0101470947265625 + ], + [ + 0.01103973388671875 + ], + [ + -0.01092529296875 + ], + [ + 0.00841522216796875 + ] + ], + [ + [ + -0.00984954833984375 + ], + [ + 0.00902557373046875 + ], + [ + 0.01081085205078125 + ], + [ + -0.0115203857421875 + ] + ], + [ + [ + 0.01021575927734375 + ], + [ + -0.0107421875 + ], + [ + 0.01123809814453125 + ], + [ + 0.00835418701171875 + ] + ], + [ + [ + 0.01099395751953125 + ], + [ + -0.0100250244140625 + ], + [ + 0.01085662841796875 + ], + [ + 0.006694793701171875 + ] + ], + [ + [ + 0.00936126708984375 + ], + [ + 0.01097869873046875 + ], + [ + 0.01055908203125 + ], + [ + 0.00826263427734375 + ] + ], + [ + [ + 0.01218414306640625 + ], + [ + -0.01041412353515625 + ], + [ + -0.01038360595703125 + ], + [ + 0.00843048095703125 + ] + ], + [ + [ + -0.01076507568359375 + ], + [ + -0.0114593505859375 + ], + [ + 0.00991058349609375 + ], + [ + -0.0055389404296875 + ] + ], + [ + [ + -0.0111236572265625 + ], + [ + -0.0110015869140625 + ], + [ + 0.0101776123046875 + ], + [ + 0.00720977783203125 + ] + ], + [ + [ + -0.00986480712890625 + ], + [ + 0.01038360595703125 + ], + [ + -0.01102447509765625 + ], + [ + 0.00872802734375 + ] + ], + [ + [ + -0.01039886474609375 + ], + [ + -0.00897216796875 + ], + [ + 0.01068115234375 + ], + [ + -0.006473541259765625 + ] + ], + [ + [ + -0.01056671142578125 + ], + [ + 0.0096588134765625 + ], + [ + 0.0109100341796875 + ], + [ + 0.00579071044921875 + ] + ], + [ + [ + -0.009613037109375 + ], + [ + -0.0108489990234375 + ], + [ + 0.0097198486328125 + ], + [ + -0.006763458251953125 + ] + ], + [ + [ + 0.0100555419921875 + ], + [ + -0.00954437255859375 + ], + [ + -0.009185791015625 + ], + [ + 0.006927490234375 + ] + ], + [ + [ + -0.01076507568359375 + ], + [ + -0.010528564453125 + ], + [ + 0.0106048583984375 + ], + [ + -0.007671356201171875 + ] + ], + [ + [ + -0.01036834716796875 + ], + [ + -0.01068115234375 + ], + [ + -0.01056671142578125 + ], + [ + -0.009033203125 + ] + ], + [ + [ + -0.01070404052734375 + ], + [ + 0.01039886474609375 + ], + [ + -0.00970458984375 + ], + [ + -0.005916595458984375 + ] + ], + [ + [ + 0.00968170166015625 + ], + [ + -0.010589599609375 + ], + [ + 0.00940704345703125 + ], + [ + -0.00543212890625 + ] + ], + [ + [ + -0.01090240478515625 + ], + [ + -0.010345458984375 + ], + [ + 0.01006317138671875 + ], + [ + 0.00695037841796875 + ] + ], + [ + [ + 0.00974273681640625 + ], + [ + -0.0087432861328125 + ], + [ + -0.009857177734375 + ], + [ + -0.006603240966796875 + ] + ], + [ + [ + 0.00936126708984375 + ], + [ + 0.010711669921875 + ], + [ + 0.0103912353515625 + ], + [ + 0.006847381591796875 + ] + ], + [ + [ + 0.00975799560546875 + ], + [ + -0.01107025146484375 + ], + [ + -0.01073455810546875 + ], + [ + -0.0077362060546875 + ] + ], + [ + [ + 0.009185791015625 + ], + [ + -0.01050567626953125 + ], + [ + 0.0096588134765625 + ], + [ + 0.00775909423828125 + ] + ], + [ + [ + 0.00922393798828125 + ], + [ + 0.00940704345703125 + ], + [ + 0.00949859619140625 + ], + [ + 0.005893707275390625 + ] + ], + [ + [ + 0.01010894775390625 + ], + [ + -0.0094757080078125 + ], + [ + 0.00902557373046875 + ], + [ + -0.00682830810546875 + ] + ], + [ + [ + -0.00994873046875 + ], + [ + 0.010467529296875 + ], + [ + -0.00936126708984375 + ], + [ + -0.006427764892578125 + ] + ], + [ + [ + -0.01031494140625 + ], + [ + 0.01061248779296875 + ], + [ + -0.0096893310546875 + ], + [ + -0.0110626220703125 + ] + ], + [ + [ + -0.00983428955078125 + ], + [ + -0.01062774658203125 + ], + [ + 0.0115203857421875 + ], + [ + 0.006656646728515625 + ] + ], + [ + [ + 
-0.0110015869140625 + ], + [ + -0.00907135009765625 + ], + [ + -0.0113067626953125 + ], + [ + 0.0066680908203125 + ] + ], + [ + [ + 0.00891876220703125 + ], + [ + 0.01165008544921875 + ], + [ + 0.00977325439453125 + ], + [ + 0.00894927978515625 + ] + ], + [ + [ + -0.01136016845703125 + ], + [ + 0.01145172119140625 + ], + [ + -0.010955810546875 + ], + [ + 0.010772705078125 + ] + ], + [ + [ + -0.010711669921875 + ], + [ + 0.010772705078125 + ], + [ + 0.01076507568359375 + ], + [ + -0.006195068359375 + ] + ], + [ + [ + 0.00940704345703125 + ], + [ + -0.01081085205078125 + ], + [ + 0.00977325439453125 + ], + [ + -0.00846099853515625 + ] + ], + [ + [ + -0.01116943359375 + ], + [ + 0.00909423828125 + ], + [ + -0.01003265380859375 + ], + [ + -0.0112762451171875 + ] + ], + [ + [ + -0.010650634765625 + ], + [ + -0.0108489990234375 + ], + [ + 0.0111541748046875 + ], + [ + 0.01007843017578125 + ] + ], + [ + [ + -0.00910186767578125 + ], + [ + -0.010528564453125 + ], + [ + 0.0102691650390625 + ], + [ + -0.005126953125 + ] + ] + ] +} \ No newline at end of file diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True_ref_wc_param.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True_ref_wc_param.json new file mode 100644 index 00000000000..e1baa81d0dc --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int4wo_sym_gs32_all_layers_True_ref_wc_param.json @@ -0,0 +1,128 @@ +[ + { + "weight_name": "q_proj_weight", + "node_with_weight": "linear", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "k_proj_weight", + "node_with_weight": "linear_1", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "v_proj_weight", + "node_with_weight": "linear_2", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "o_proj_weight", + "node_with_weight": "linear_3", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "mlp_gate_proj_weight", + "node_with_weight": "linear_4", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 128, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "mlp_up_proj_weight", + "node_with_weight": "linear_5", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 128, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "mlp_down_proj_weight", + "node_with_weight": "linear_6", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 128 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + 
"mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + } +] \ No newline at end of file diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False.dot b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False.dot new file mode 100644 index 00000000000..29de7b02841 --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False.dot @@ -0,0 +1,169 @@ +strict digraph { +"0 attn_norm_weight" [id=0, type="get_attr"]; +"1 mlp_norm_weight" [id=1, type="get_attr"]; +"2 rope_cos" [id=2, type="get_attr"]; +"3 rope_sin" [id=3, type="get_attr"]; +"4 x_embed" [id=4, type=input]; +"5 arange" [id=5, type=arange]; +"6 _assert_tensor_metadata_default" [id=6, type="_assert_tensor_metadata"]; +"7 to" [id=7, type=to]; +"8 pow_1" [id=8, type=pow]; +"9 mean" [id=9, type=mean]; +"10 add" [id=10, type=add]; +"11 rsqrt" [id=11, type=rsqrt]; +"12 mul" [id=12, type=mul]; +"13 _assert_tensor_metadata_default_1" [id=13, type="_assert_tensor_metadata"]; +"14 to_1" [id=14, type=to]; +"15 mul_1" [id=15, type=mul]; +"16 q_proj_weight_updated_constant0" [id=16, type="get_attr"]; +"17 asymmetric_weights_decompressor_q_proj_weight_0" [id=17, type="call_module"]; +"18 linear" [id=18, type=linear]; +"19 view" [id=19, type=view]; +"20 transpose" [id=20, type=transpose]; +"21 k_proj_weight_updated_constant0" [id=21, type="get_attr"]; +"22 asymmetric_weights_decompressor_k_proj_weight_0" [id=22, type="call_module"]; +"23 linear_1" [id=23, type=linear]; +"24 view_1" [id=24, type=view]; +"25 transpose_1" [id=25, type=transpose]; +"26 v_proj_weight_updated_constant0" [id=26, type="get_attr"]; +"27 asymmetric_weights_decompressor_v_proj_weight_0" [id=27, type="call_module"]; +"28 linear_2" [id=28, type=linear]; +"29 view_2" [id=29, type=view]; +"30 transpose_2" [id=30, type=transpose]; +"31 index" [id=31, type=index]; +"32 index_1" [id=32, type=index]; +"33 mul_2" [id=33, type=mul]; +"34 slice_1" [id=34, type=slice]; +"35 slice_2" [id=35, type=slice]; +"36 neg" [id=36, type=neg]; +"37 cat" [id=37, type=cat]; +"38 mul_3" [id=38, type=mul]; +"39 add_1" [id=39, type=add]; +"40 mul_4" [id=40, type=mul]; +"41 slice_3" [id=41, type=slice]; +"42 slice_4" [id=42, type=slice]; +"43 neg_1" [id=43, type=neg]; +"44 cat_1" [id=44, type=cat]; +"45 mul_5" [id=45, type=mul]; +"46 add_2" [id=46, type=add]; +"47 scaled_dot_product_attention" [id=47, type="scaled_dot_product_attention"]; +"48 transpose_3" [id=48, type=transpose]; +"49 view_3" [id=49, type=view]; +"50 o_proj_weight_updated_constant0" [id=50, type="get_attr"]; +"51 asymmetric_weights_decompressor_o_proj_weight_0" [id=51, type="call_module"]; +"52 linear_3" [id=52, type=linear]; +"53 add_3" [id=53, type=add]; +"54 _assert_tensor_metadata_default_2" [id=54, type="_assert_tensor_metadata"]; +"55 to_2" [id=55, type=to]; +"56 pow_2" [id=56, type=pow]; +"57 mean_1" [id=57, type=mean]; +"58 add_4" [id=58, type=add]; +"59 rsqrt_1" [id=59, type=rsqrt]; +"60 mul_6" [id=60, type=mul]; +"61 _assert_tensor_metadata_default_3" [id=61, type="_assert_tensor_metadata"]; +"62 to_3" [id=62, type=to]; +"63 mul_7" [id=63, type=mul]; +"64 mlp_gate_proj_weight_updated_constant0" [id=64, type="get_attr"]; +"65 asymmetric_weights_decompressor_mlp_gate_proj_weight_0" [id=65, type="call_module"]; +"66 linear_4" [id=66, type=linear]; +"67 silu" [id=67, type=silu]; +"68 mlp_up_proj_weight_updated_constant0" [id=68, 
type="get_attr"]; +"69 asymmetric_weights_decompressor_mlp_up_proj_weight_0" [id=69, type="call_module"]; +"70 linear_5" [id=70, type=linear]; +"71 mul_8" [id=71, type=mul]; +"72 mlp_down_proj_weight_updated_constant0" [id=72, type="get_attr"]; +"73 asymmetric_weights_decompressor_mlp_down_proj_weight_0" [id=73, type="call_module"]; +"74 linear_6" [id=74, type=linear]; +"75 add_5" [id=75, type=add]; +"76 output" [id=76, type=output]; +"0 attn_norm_weight" -> "15 mul_1" [style=solid, label="(64,)"]; +"1 mlp_norm_weight" -> "63 mul_7" [style=solid, label="(64,)"]; +"2 rope_cos" -> "31 index" [style=solid, label="(1, 1, 128, 16)"]; +"3 rope_sin" -> "32 index_1" [style=solid, label="(1, 1, 128, 16)"]; +"4 x_embed" -> "6 _assert_tensor_metadata_default" [style=solid, label="(1, 3, 64)"]; +"4 x_embed" -> "7 to" [style=solid, label="(1, 3, 64)"]; +"4 x_embed" -> "53 add_3" [style=solid, label="(1, 3, 64)"]; +"5 arange" -> "31 index" [style=solid, label="(3,)"]; +"5 arange" -> "32 index_1" [style=solid, label="(3,)"]; +"7 to" -> "8 pow_1" [style=solid, label="(1, 3, 64)"]; +"7 to" -> "12 mul" [style=solid, label="(1, 3, 64)"]; +"8 pow_1" -> "9 mean" [style=solid, label="(1, 3, 64)"]; +"9 mean" -> "10 add" [style=solid, label="(1, 3, 1)"]; +"10 add" -> "11 rsqrt" [style=solid, label="(1, 3, 1)"]; +"11 rsqrt" -> "12 mul" [style=solid, label="(1, 3, 1)"]; +"12 mul" -> "13 _assert_tensor_metadata_default_1" [style=solid, label="(1, 3, 64)"]; +"12 mul" -> "14 to_1" [style=solid, label="(1, 3, 64)"]; +"14 to_1" -> "15 mul_1" [style=solid, label="(1, 3, 64)"]; +"15 mul_1" -> "18 linear" [style=solid, label="(1, 3, 64)"]; +"15 mul_1" -> "23 linear_1" [style=solid, label="(1, 3, 64)"]; +"15 mul_1" -> "28 linear_2" [style=solid, label="(1, 3, 64)"]; +"16 q_proj_weight_updated_constant0" -> "17 asymmetric_weights_decompressor_q_proj_weight_0" [style=solid, label="(64, 64)"]; +"17 asymmetric_weights_decompressor_q_proj_weight_0" -> "18 linear" [style=solid, label="(64, 64)"]; +"18 linear" -> "19 view" [style=solid, label="(1, 3, 64)"]; +"19 view" -> "20 transpose" [style=solid, label="(1, 3, 4, 16)"]; +"20 transpose" -> "33 mul_2" [style=solid, label="(1, 4, 3, 16)"]; +"20 transpose" -> "34 slice_1" [style=solid, label="(1, 4, 3, 16)"]; +"20 transpose" -> "35 slice_2" [style=solid, label="(1, 4, 3, 16)"]; +"21 k_proj_weight_updated_constant0" -> "22 asymmetric_weights_decompressor_k_proj_weight_0" [style=solid, label="(64, 64)"]; +"22 asymmetric_weights_decompressor_k_proj_weight_0" -> "23 linear_1" [style=solid, label="(64, 64)"]; +"23 linear_1" -> "24 view_1" [style=solid, label="(1, 3, 64)"]; +"24 view_1" -> "25 transpose_1" [style=solid, label="(1, 3, 4, 16)"]; +"25 transpose_1" -> "40 mul_4" [style=solid, label="(1, 4, 3, 16)"]; +"25 transpose_1" -> "41 slice_3" [style=solid, label="(1, 4, 3, 16)"]; +"25 transpose_1" -> "42 slice_4" [style=solid, label="(1, 4, 3, 16)"]; +"26 v_proj_weight_updated_constant0" -> "27 asymmetric_weights_decompressor_v_proj_weight_0" [style=solid, label="(64, 64)"]; +"27 asymmetric_weights_decompressor_v_proj_weight_0" -> "28 linear_2" [style=solid, label="(64, 64)"]; +"28 linear_2" -> "29 view_2" [style=solid, label="(1, 3, 64)"]; +"29 view_2" -> "30 transpose_2" [style=solid, label="(1, 3, 4, 16)"]; +"30 transpose_2" -> "47 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"31 index" -> "33 mul_2" [style=solid, label="(1, 1, 3, 16)"]; +"31 index" -> "40 mul_4" [style=solid, label="(1, 1, 3, 16)"]; +"32 index_1" -> "38 mul_3" [style=solid, label="(1, 1, 3, 
16)"]; +"32 index_1" -> "45 mul_5" [style=solid, label="(1, 1, 3, 16)"]; +"33 mul_2" -> "39 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"34 slice_1" -> "37 cat" [style=solid, label="(1, 4, 3, 8)"]; +"35 slice_2" -> "36 neg" [style=solid, label="(1, 4, 3, 8)"]; +"36 neg" -> "37 cat" [style=solid, label="(1, 4, 3, 8)"]; +"37 cat" -> "38 mul_3" [style=solid, label="(1, 4, 3, 16)"]; +"38 mul_3" -> "39 add_1" [style=solid, label="(1, 4, 3, 16)"]; +"39 add_1" -> "47 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"40 mul_4" -> "46 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"41 slice_3" -> "44 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"42 slice_4" -> "43 neg_1" [style=solid, label="(1, 4, 3, 8)"]; +"43 neg_1" -> "44 cat_1" [style=solid, label="(1, 4, 3, 8)"]; +"44 cat_1" -> "45 mul_5" [style=solid, label="(1, 4, 3, 16)"]; +"45 mul_5" -> "46 add_2" [style=solid, label="(1, 4, 3, 16)"]; +"46 add_2" -> "47 scaled_dot_product_attention" [style=solid, label="(1, 4, 3, 16)"]; +"47 scaled_dot_product_attention" -> "48 transpose_3" [style=solid, label="(1, 4, 3, 16)"]; +"48 transpose_3" -> "49 view_3" [style=solid, label="(1, 3, 4, 16)"]; +"49 view_3" -> "52 linear_3" [style=solid, label="(1, 3, 64)"]; +"50 o_proj_weight_updated_constant0" -> "51 asymmetric_weights_decompressor_o_proj_weight_0" [style=solid, label="(64, 64)"]; +"51 asymmetric_weights_decompressor_o_proj_weight_0" -> "52 linear_3" [style=solid, label="(64, 64)"]; +"52 linear_3" -> "53 add_3" [style=solid, label="(1, 3, 64)"]; +"53 add_3" -> "54 _assert_tensor_metadata_default_2" [style=solid, label="(1, 3, 64)"]; +"53 add_3" -> "55 to_2" [style=solid, label="(1, 3, 64)"]; +"55 to_2" -> "56 pow_2" [style=solid, label="(1, 3, 64)"]; +"55 to_2" -> "60 mul_6" [style=solid, label="(1, 3, 64)"]; +"55 to_2" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"56 pow_2" -> "57 mean_1" [style=solid, label="(1, 3, 64)"]; +"57 mean_1" -> "58 add_4" [style=solid, label="(1, 3, 1)"]; +"58 add_4" -> "59 rsqrt_1" [style=solid, label="(1, 3, 1)"]; +"59 rsqrt_1" -> "60 mul_6" [style=solid, label="(1, 3, 1)"]; +"60 mul_6" -> "61 _assert_tensor_metadata_default_3" [style=solid, label="(1, 3, 64)"]; +"60 mul_6" -> "62 to_3" [style=solid, label="(1, 3, 64)"]; +"62 to_3" -> "63 mul_7" [style=solid, label="(1, 3, 64)"]; +"63 mul_7" -> "66 linear_4" [style=solid, label="(1, 3, 64)"]; +"63 mul_7" -> "70 linear_5" [style=solid, label="(1, 3, 64)"]; +"64 mlp_gate_proj_weight_updated_constant0" -> "65 asymmetric_weights_decompressor_mlp_gate_proj_weight_0" [style=solid, label="(128, 64)"]; +"65 asymmetric_weights_decompressor_mlp_gate_proj_weight_0" -> "66 linear_4" [style=solid, label="(128, 64)"]; +"66 linear_4" -> "67 silu" [style=solid, label="(1, 3, 128)"]; +"67 silu" -> "71 mul_8" [style=solid, label="(1, 3, 128)"]; +"68 mlp_up_proj_weight_updated_constant0" -> "69 asymmetric_weights_decompressor_mlp_up_proj_weight_0" [style=solid, label="(128, 64)"]; +"69 asymmetric_weights_decompressor_mlp_up_proj_weight_0" -> "70 linear_5" [style=solid, label="(128, 64)"]; +"70 linear_5" -> "71 mul_8" [style=solid, label="(1, 3, 128)"]; +"71 mul_8" -> "74 linear_6" [style=solid, label="(1, 3, 128)"]; +"72 mlp_down_proj_weight_updated_constant0" -> "73 asymmetric_weights_decompressor_mlp_down_proj_weight_0" [style=solid, label="(64, 128)"]; +"73 asymmetric_weights_decompressor_mlp_down_proj_weight_0" -> "74 linear_6" [style=solid, label="(64, 128)"]; +"74 linear_6" -> "75 add_5" [style=solid, label="(1, 3, 64)"]; +"75 add_5" -> "76 output" 
[style=solid, label="(1, 3, 64)"]; +} diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False_awq_False_scale_estimation_False_ref_wc_scales.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False_awq_False_scale_estimation_False_ref_wc_scales.json new file mode 100644 index 00000000000..40b1cc6c44e --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False_awq_False_scale_estimation_False_ref_wc_scales.json @@ -0,0 +1,1744 @@ +{ + "asymmetric_weights_decompressor_q_proj_weight_0": [ + [ + 0.0009570121765136719 + ], + [ + 0.0009775161743164062 + ], + [ + 0.00090789794921875 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009713172912597656 + ], + [ + 0.0009593963623046875 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009412765502929688 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009531974792480469 + ], + [ + 0.0009756088256835938 + ], + [ + 0.0009655952453613281 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009293556213378906 + ], + [ + 0.0009431838989257812 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009675025939941406 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009112358093261719 + ], + [ + 0.0009660720825195312 + ], + [ + 0.0009403228759765625 + ], + [ + 0.0009670257568359375 + ], + [ + 0.0009279251098632812 + ], + [ + 0.0009522438049316406 + ], + [ + 0.0008702278137207031 + ], + [ + 0.0009508132934570312 + ], + [ + 0.0009407997131347656 + ], + [ + 0.0009713172912597656 + ], + [ + 0.0009570121765136719 + ], + [ + 0.0009112358093261719 + ], + [ + 0.0009579658508300781 + ], + [ + 0.000926971435546875 + ], + [ + 0.0009055137634277344 + ], + [ + 0.0009469985961914062 + ], + [ + 0.0009598731994628906 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009455680847167969 + ], + [ + 0.0009098052978515625 + ], + [ + 0.0009479522705078125 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009589195251464844 + ], + [ + 0.0009512901306152344 + ], + [ + 0.0009641647338867188 + ], + [ + 0.0009608268737792969 + ], + [ + 0.0009350776672363281 + ], + [ + 0.000911712646484375 + ], + [ + 0.0009655952453613281 + ], + [ + 0.000949859619140625 + ], + [ + 0.00092315673828125 + ], + [ + 0.000957489013671875 + ], + [ + 0.0009412765502929688 + ], + [ + 0.0009593963623046875 + ], + [ + 0.0009217262268066406 + ], + [ + 0.0009026527404785156 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009531974792480469 + ], + [ + 0.0009198188781738281 + ], + [ + 0.0009183883666992188 + ], + [ + 0.000957489013671875 + ], + [ + 0.0009531974792480469 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009760856628417969 + ], + [ + 0.0009632110595703125 + ] + ], + "asymmetric_weights_decompressor_k_proj_weight_0": [ + [ + 0.0009379386901855469 + ], + [ + 0.0009522438049316406 + ], + [ + 0.0009541511535644531 + ], + [ + 0.0009598731994628906 + ], + [ + 0.0009517669677734375 + ], + [ + 0.000934600830078125 + ], + [ + 0.0009455680847167969 + ], + [ + 0.000926971435546875 + ], + [ + 0.000957489013671875 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009698867797851562 + ], + [ + 0.0009417533874511719 + ], + [ + 0.0009417533874511719 + ], + [ + 0.0009469985961914062 + ], + [ + 0.000965118408203125 + ], + [ + 0.000946044921875 + ], + [ + 0.0009469985961914062 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009698867797851562 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009522438049316406 + 
], + [ + 0.0009694099426269531 + ], + [ + 0.0009531974792480469 + ], + [ + 0.0009427070617675781 + ], + [ + 0.0009622573852539062 + ], + [ + 0.0009732246398925781 + ], + [ + 0.0009245872497558594 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009560585021972656 + ], + [ + 0.000949859619140625 + ], + [ + 0.0009469985961914062 + ], + [ + 0.0009622573852539062 + ], + [ + 0.00093841552734375 + ], + [ + 0.0009455680847167969 + ], + [ + 0.0009469985961914062 + ], + [ + 0.0009493827819824219 + ], + [ + 0.0009741783142089844 + ], + [ + 0.0009331703186035156 + ], + [ + 0.0009412765502929688 + ], + [ + 0.0009660720825195312 + ], + [ + 0.0009593963623046875 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009756088256835938 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009279251098632812 + ], + [ + 0.0009350776672363281 + ], + [ + 0.00093841552734375 + ], + [ + 0.0009512901306152344 + ], + [ + 0.0009531974792480469 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009493827819824219 + ], + [ + 0.0009551048278808594 + ], + [ + 0.0009503364562988281 + ], + [ + 0.0009622573852539062 + ], + [ + 0.0009250640869140625 + ], + [ + 0.0009508132934570312 + ], + [ + 0.0009503364562988281 + ], + [ + 0.0009627342224121094 + ], + [ + 0.0009059906005859375 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009765625 + ], + [ + 0.0009584426879882812 + ], + [ + 0.0009436607360839844 + ] + ], + "asymmetric_weights_decompressor_v_proj_weight_0": [ + [ + 0.0009713172912597656 + ], + [ + 0.0009407997131347656 + ], + [ + 0.000934600830078125 + ], + [ + 0.0009560585021972656 + ], + [ + 0.00090789794921875 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009274482727050781 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009093284606933594 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009374618530273438 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009140968322753906 + ], + [ + 0.0009531974792480469 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009226799011230469 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009160041809082031 + ], + [ + 0.0009560585021972656 + ], + [ + 0.0009713172912597656 + ], + [ + 0.000942230224609375 + ], + [ + 0.0009527206420898438 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009388923645019531 + ], + [ + 0.0009670257568359375 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009665489196777344 + ], + [ + 0.00096893310546875 + ], + [ + 0.0009512901306152344 + ], + [ + 0.0009665489196777344 + ], + [ + 0.000934600830078125 + ], + [ + 0.0009379386901855469 + ], + [ + 0.0009522438049316406 + ], + [ + 0.0009431838989257812 + ], + [ + 0.0009765625 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009608268737792969 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009694099426269531 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009665489196777344 + ], + [ + 0.0009260177612304688 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009188652038574219 + ], + [ + 0.0009365081787109375 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009708404541015625 + ], + [ + 0.0009398460388183594 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009446144104003906 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009593963623046875 + ], + [ + 0.0009503364562988281 + ], + [ + 0.0009560585021972656 + ], + [ + 0.0009608268737792969 + ], + [ + 0.0009760856628417969 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0008978843688964844 + ] + ], + 
"asymmetric_weights_decompressor_o_proj_weight_0": [ + [ + 0.0009369850158691406 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009160041809082031 + ], + [ + 0.0009713172912597656 + ], + [ + 0.0009646415710449219 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009622573852539062 + ], + [ + 0.0009102821350097656 + ], + [ + 0.000942230224609375 + ], + [ + 0.0009608268737792969 + ], + [ + 0.0009150505065917969 + ], + [ + 0.0009412765502929688 + ], + [ + 0.0009427070617675781 + ], + [ + 0.0009264945983886719 + ], + [ + 0.0009660720825195312 + ], + [ + 0.0009541511535644531 + ], + [ + 0.0009675025939941406 + ], + [ + 0.0009431838989257812 + ], + [ + 0.0009207725524902344 + ], + [ + 0.0009479522705078125 + ], + [ + 0.0009660720825195312 + ], + [ + 0.0009217262268066406 + ], + [ + 0.0009613037109375 + ], + [ + 0.0009675025939941406 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009465217590332031 + ], + [ + 0.0009403228759765625 + ], + [ + 0.00093841552734375 + ], + [ + 0.0009326934814453125 + ], + [ + 0.0009765625 + ], + [ + 0.0009708404541015625 + ], + [ + 0.0009565353393554688 + ], + [ + 0.0009765625 + ], + [ + 0.0009293556213378906 + ], + [ + 0.0009379386901855469 + ], + [ + 0.0009794235229492188 + ], + [ + 0.0009675025939941406 + ], + [ + 0.0009517669677734375 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0009622573852539062 + ], + [ + 0.000942230224609375 + ], + [ + 0.0009531974792480469 + ], + [ + 0.0009503364562988281 + ], + [ + 0.0009775161743164062 + ], + [ + 0.000972747802734375 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009307861328125 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009584426879882812 + ], + [ + 0.0009207725524902344 + ], + [ + 0.0008878707885742188 + ], + [ + 0.0009522438049316406 + ], + [ + 0.0009560585021972656 + ], + [ + 0.00093841552734375 + ], + [ + 0.0009641647338867188 + ], + [ + 0.000957489013671875 + ], + [ + 0.0009393692016601562 + ], + [ + 0.0009016990661621094 + ], + [ + 0.0009589195251464844 + ], + [ + 0.0009589195251464844 + ], + [ + 0.0009746551513671875 + ], + [ + 0.0009665489196777344 + ], + [ + 0.0009617805480957031 + ] + ], + "asymmetric_weights_decompressor_mlp_gate_proj_weight_0": [ + [ + 0.0009775161743164062 + ], + [ + 0.0009708404541015625 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009775161743164062 + ], + [ + 0.0009636878967285156 + ], + [ + 0.0009717941284179688 + ], + [ + 0.000934600830078125 + ], + [ + 0.0009636878967285156 + ], + [ + 0.00091552734375 + ], + [ + 0.0009522438049316406 + ], + [ + 0.0009431838989257812 + ], + [ + 0.0009474754333496094 + ], + [ + 0.0009431838989257812 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009593963623046875 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009598731994628906 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009374618530273438 + ], + [ + 0.0009660720825195312 + ], + [ + 0.0009632110595703125 + ], + [ + 0.0009503364562988281 + ], + [ + 0.0009474754333496094 + ], + [ + 0.0009636878967285156 + ], + [ + 0.0009646415710449219 + ], + [ + 0.0009646415710449219 + ], + [ + 0.0009679794311523438 + ], + [ + 0.00096893310546875 + ], + [ + 0.0009756088256835938 + ], + [ + 0.0009708404541015625 + ], + [ + 0.0009517669677734375 + ], + [ + 0.0009531974792480469 + ], + [ + 0.0009489059448242188 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009031295776367188 + ], + [ + 0.0009646415710449219 + ], + [ + 0.00096893310546875 + ], + [ + 0.0009598731994628906 + ], + [ + 0.0009245872497558594 + ], + [ + 0.0009675025939941406 + ], + [ + 0.0009632110595703125 + ], + [ + 
0.00096893310546875 + ], + [ + 0.0009126663208007812 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009427070617675781 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009555816650390625 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009641647338867188 + ], + [ + 0.0009713172912597656 + ], + [ + 0.0009698867797851562 + ], + [ + 0.0009388923645019531 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0008859634399414062 + ], + [ + 0.0009598731994628906 + ], + [ + 0.0009293556213378906 + ], + [ + 0.00092315673828125 + ], + [ + 0.0008797645568847656 + ], + [ + 0.0009746551513671875 + ], + [ + 0.0009541511535644531 + ], + [ + 0.0009226799011230469 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009741783142089844 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009245872497558594 + ], + [ + 0.0009694099426269531 + ], + [ + 0.0009512901306152344 + ], + [ + 0.0009183883666992188 + ], + [ + 0.0009636878967285156 + ], + [ + 0.0009665489196777344 + ], + [ + 0.0009307861328125 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009760856628417969 + ], + [ + 0.0009551048278808594 + ], + [ + 0.0008802413940429688 + ], + [ + 0.0009717941284179688 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009469985961914062 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009508132934570312 + ], + [ + 0.0009632110595703125 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0009322166442871094 + ], + [ + 0.0009765625 + ], + [ + 0.0009469985961914062 + ], + [ + 0.0009636878967285156 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009670257568359375 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009646415710449219 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009298324584960938 + ], + [ + 0.0008978843688964844 + ], + [ + 0.000949859619140625 + ], + [ + 0.0009593963623046875 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009756088256835938 + ], + [ + 0.0009417533874511719 + ], + [ + 0.0009555816650390625 + ], + [ + 0.0009527206420898438 + ], + [ + 0.000926971435546875 + ], + [ + 0.0009565353393554688 + ], + [ + 0.000957489013671875 + ], + [ + 0.0009765625 + ], + [ + 0.0009641647338867188 + ], + [ + 0.0009622573852539062 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009660720825195312 + ], + [ + 0.0009703636169433594 + ], + [ + 0.000949859619140625 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0009570121765136719 + ], + [ + 0.0009655952453613281 + ], + [ + 0.00096893310546875 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009608268737792969 + ], + [ + 0.0009360313415527344 + ], + [ + 0.0008497238159179688 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009412765502929688 + ], + [ + 0.0009751319885253906 + ], + [ + 0.000957489013671875 + ], + [ + 0.0009441375732421875 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009450912475585938 + ] + ], + "asymmetric_weights_decompressor_mlp_up_proj_weight_0": [ + [ + 0.0009703636169433594 + ], + [ + 0.0009417533874511719 + ], + [ + 0.0009379386901855469 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009632110595703125 + ], + [ + 0.0009670257568359375 + ], + [ + 0.0009627342224121094 + ], + [ + 0.0009636878967285156 + ], + [ + 0.0009698867797851562 + ], + [ + 0.0009479522705078125 + ], + [ + 0.0009713172912597656 + ], + [ + 0.0009675025939941406 + ], + [ + 0.0009493827819824219 + ], + [ + 0.00096893310546875 + ], + [ + 0.0009484291076660156 + ], + [ + 0.000949859619140625 + ], + [ + 0.0009565353393554688 + ], + [ + 0.0009121894836425781 + ], + [ + 0.0009679794311523438 + ], + [ + 
0.000957489013671875 + ], + [ + 0.0009665489196777344 + ], + [ + 0.0009403228759765625 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009388923645019531 + ], + [ + 0.0009760856628417969 + ], + [ + 0.0009570121765136719 + ], + [ + 0.0009322166442871094 + ], + [ + 0.0009326934814453125 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009555816650390625 + ], + [ + 0.0009765625 + ], + [ + 0.0009188652038574219 + ], + [ + 0.0009503364562988281 + ], + [ + 0.0009765625 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009031295776367188 + ], + [ + 0.0009493827819824219 + ], + [ + 0.0009274482727050781 + ], + [ + 0.0009741783142089844 + ], + [ + 0.0009379386901855469 + ], + [ + 0.0009403228759765625 + ], + [ + 0.00092315673828125 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009055137634277344 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009589195251464844 + ], + [ + 0.000972747802734375 + ], + [ + 0.0009694099426269531 + ], + [ + 0.0009641647338867188 + ], + [ + 0.0009164810180664062 + ], + [ + 0.0009512901306152344 + ], + [ + 0.0009007453918457031 + ], + [ + 0.0009570121765136719 + ], + [ + 0.0009188652038574219 + ], + [ + 0.0009560585021972656 + ], + [ + 0.0009551048278808594 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009465217590332031 + ], + [ + 0.0009765625 + ], + [ + 0.0009541511535644531 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009760856628417969 + ], + [ + 0.000957489013671875 + ], + [ + 0.0009493827819824219 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009737014770507812 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009608268737792969 + ], + [ + 0.0009217262268066406 + ], + [ + 0.0009765625 + ], + [ + 0.0008749961853027344 + ], + [ + 0.0009751319885253906 + ], + [ + 0.0009322166442871094 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009331703186035156 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009756088256835938 + ], + [ + 0.0009646415710449219 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009760856628417969 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009431838989257812 + ], + [ + 0.0009765625 + ], + [ + 0.0009403228759765625 + ], + [ + 0.000946044921875 + ], + [ + 0.0009775161743164062 + ], + [ + 0.0009350776672363281 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009641647338867188 + ], + [ + 0.0009522438049316406 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009646415710449219 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009412765502929688 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009784698486328125 + ], + [ + 0.0009140968322753906 + ], + [ + 0.0009698867797851562 + ], + [ + 0.0009546279907226562 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0008835792541503906 + ], + [ + 0.0009217262268066406 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009794235229492188 + ], + [ + 0.0009698867797851562 + ], + [ + 0.0008611679077148438 + ], + [ + 0.0009746551513671875 + ], + [ + 0.0009427070617675781 + ], + [ + 0.0008997917175292969 + ], + [ + 0.0009713172912597656 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009393692016601562 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009407997131347656 + ], + [ + 0.0009593963623046875 + ], + [ + 0.0009493827819824219 + ], + [ + 0.0008873939514160156 + ], + [ + 0.0009212493896484375 + ], + [ + 0.0009560585021972656 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009517669677734375 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009570121765136719 + ] + ], + 
"asymmetric_weights_decompressor_mlp_down_proj_weight_0": [ + [ + 0.0006866455078125 + ], + [ + 0.0006728172302246094 + ], + [ + 0.0006909370422363281 + ], + [ + 0.000690460205078125 + ], + [ + 0.0006866455078125 + ], + [ + 0.0006794929504394531 + ], + [ + 0.0006856918334960938 + ], + [ + 0.0006885528564453125 + ], + [ + 0.0006775856018066406 + ], + [ + 0.00067901611328125 + ], + [ + 0.0006818771362304688 + ], + [ + 0.0006895065307617188 + ], + [ + 0.0006823539733886719 + ], + [ + 0.0006814002990722656 + ], + [ + 0.000690460205078125 + ], + [ + 0.0006890296936035156 + ], + [ + 0.0006895065307617188 + ], + [ + 0.0006847381591796875 + ], + [ + 0.0006761550903320312 + ], + [ + 0.0006785392761230469 + ], + [ + 0.0006885528564453125 + ], + [ + 0.000682830810546875 + ], + [ + 0.0006794929504394531 + ], + [ + 0.0006899833679199219 + ], + [ + 0.0006804466247558594 + ], + [ + 0.0006785392761230469 + ], + [ + 0.00066375732421875 + ], + [ + 0.0006914138793945312 + ], + [ + 0.0006780624389648438 + ], + [ + 0.0006856918334960938 + ], + [ + 0.0006890296936035156 + ], + [ + 0.0006837844848632812 + ], + [ + 0.0006890296936035156 + ], + [ + 0.0006885528564453125 + ], + [ + 0.0006866455078125 + ], + [ + 0.0006885528564453125 + ], + [ + 0.0006880760192871094 + ], + [ + 0.0006861686706542969 + ], + [ + 0.0006861686706542969 + ], + [ + 0.0006804466247558594 + ], + [ + 0.0006866455078125 + ], + [ + 0.0006761550903320312 + ], + [ + 0.0006871223449707031 + ], + [ + 0.0006875991821289062 + ], + [ + 0.0006780624389648438 + ], + [ + 0.0006880760192871094 + ], + [ + 0.0006909370422363281 + ], + [ + 0.0006718635559082031 + ], + [ + 0.0006723403930664062 + ], + [ + 0.0006895065307617188 + ], + [ + 0.0006694793701171875 + ], + [ + 0.0006737709045410156 + ], + [ + 0.0006885528564453125 + ], + [ + 0.0006785392761230469 + ], + [ + 0.0006885528564453125 + ], + [ + 0.0006804466247558594 + ], + [ + 0.0006866455078125 + ], + [ + 0.0006666183471679688 + ], + [ + 0.0006909370422363281 + ], + [ + 0.0006833076477050781 + ], + [ + 0.0006875991821289062 + ], + [ + 0.0006818771362304688 + ], + [ + 0.0006794929504394531 + ], + [ + 0.0006918907165527344 + ] + ] +} \ No newline at end of file diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False_ref_wc_param.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False_ref_wc_param.json new file mode 100644 index 00000000000..69d4cf0f6a8 --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/LlamaDecoderOnly/int8wo_asym_gs-1_all_layers_False_ref_wc_param.json @@ -0,0 +1,128 @@ +[ + { + "weight_name": "q_proj_weight", + "node_with_weight": "linear", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int8_asym", + "group_size": -1, + "codebook_values": null + } + }, + { + "weight_name": "k_proj_weight", + "node_with_weight": "linear_1", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int8_asym", + "group_size": -1, + "codebook_values": null + } + }, + { + "weight_name": "v_proj_weight", + "node_with_weight": "linear_2", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int8_asym", + "group_size": -1, + "codebook_values": null + } + }, + 
{ + "weight_name": "o_proj_weight", + "node_with_weight": "linear_3", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int8_asym", + "group_size": -1, + "codebook_values": null + } + }, + { + "weight_name": "mlp_gate_proj_weight", + "node_with_weight": "linear_4", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 128, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int8_asym", + "group_size": -1, + "codebook_values": null + } + }, + { + "weight_name": "mlp_up_proj_weight", + "node_with_weight": "linear_5", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 128, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int8_asym", + "group_size": -1, + "codebook_values": null + } + }, + { + "weight_name": "mlp_down_proj_weight", + "node_with_weight": "linear_6", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 128 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int8_asym", + "group_size": -1, + "codebook_values": null + } + } +] \ No newline at end of file diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False.dot b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False.dot new file mode 100644 index 00000000000..b249fdf7ce3 --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False.dot @@ -0,0 +1,24 @@ +strict digraph { +"0 linear_bias" [id=0, type="get_attr"]; +"1 lm_head_bias" [id=1, type="get_attr"]; +"2 input_ids" [id=2, type=input]; +"3 wte_weight_1_updated_constant0" [id=3, type="get_attr"]; +"4 asymmetric_weights_decompressor_wte_weight_1_0" [id=4, type="call_module"]; +"5 embedding" [id=5, type=embedding]; +"6 linear_weight_updated_constant0" [id=6, type="get_attr"]; +"7 asymmetric_weights_decompressor_linear_weight_0" [id=7, type="call_module"]; +"8 linear" [id=8, type=linear]; +"9 linear_1" [id=9, type=linear]; +"10 output" [id=10, type=output]; +"0 linear_bias" -> "8 linear" [style=solid, label="(64,)"]; +"1 lm_head_bias" -> "9 linear_1" [style=solid, label="(128,)"]; +"2 input_ids" -> "5 embedding" [style=solid, label="(5,)"]; +"3 wte_weight_1_updated_constant0" -> "4 asymmetric_weights_decompressor_wte_weight_1_0" [style=solid, label="(128, 64)"]; +"4 asymmetric_weights_decompressor_wte_weight_1_0" -> "5 embedding" [style=solid, label="(128, 64)"]; +"4 asymmetric_weights_decompressor_wte_weight_1_0" -> "9 linear_1" [style=solid, label="(128, 64)"]; +"5 embedding" -> "8 linear" [style=solid, label="(5, 64)"]; +"6 linear_weight_updated_constant0" -> "7 asymmetric_weights_decompressor_linear_weight_0" [style=solid, label="(64, 64)"]; +"7 asymmetric_weights_decompressor_linear_weight_0" -> "8 linear" [style=solid, label="(64, 64)"]; +"8 linear" -> "9 linear_1" [style=solid, label="(5, 64)"]; +"9 linear_1" -> "10 output" [style=solid, label="(5, 128)"]; +} diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False_awq_True_scale_estimation_True_ref_wc_scales.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False_awq_True_scale_estimation_True_ref_wc_scales.json new file mode 100644 index 
00000000000..38bc9f7b51c --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False_awq_True_scale_estimation_True_ref_wc_scales.json @@ -0,0 +1,582 @@ +{ + "asymmetric_weights_decompressor_wte_weight_1_0": [ + [ + 0.0181884765625 + ], + [ + 0.025665283203125 + ], + [ + 0.01727294921875 + ], + [ + 0.015869140625 + ], + [ + 0.01837158203125 + ], + [ + 0.0209503173828125 + ], + [ + 0.0199127197265625 + ], + [ + 0.01641845703125 + ], + [ + 0.0213775634765625 + ], + [ + 0.01971435546875 + ], + [ + 0.020751953125 + ], + [ + 0.0206756591796875 + ], + [ + 0.018585205078125 + ], + [ + 0.017120361328125 + ], + [ + 0.016693115234375 + ], + [ + 0.01551055908203125 + ], + [ + 0.019378662109375 + ], + [ + 0.0218353271484375 + ], + [ + 0.018707275390625 + ], + [ + 0.018524169921875 + ], + [ + 0.0207672119140625 + ], + [ + 0.0210113525390625 + ], + [ + 0.017608642578125 + ], + [ + 0.016937255859375 + ], + [ + 0.0146331787109375 + ], + [ + 0.016754150390625 + ], + [ + 0.02288818359375 + ], + [ + 0.0201873779296875 + ], + [ + 0.0160675048828125 + ], + [ + 0.0161285400390625 + ], + [ + 0.0251617431640625 + ], + [ + 0.015899658203125 + ], + [ + 0.016143798828125 + ], + [ + 0.0206756591796875 + ], + [ + 0.0192718505859375 + ], + [ + 0.01537322998046875 + ], + [ + 0.017669677734375 + ], + [ + 0.0156402587890625 + ], + [ + 0.0193023681640625 + ], + [ + 0.021484375 + ], + [ + 0.018341064453125 + ], + [ + 0.017730712890625 + ], + [ + 0.0257110595703125 + ], + [ + 0.0167388916015625 + ], + [ + 0.017822265625 + ], + [ + 0.016204833984375 + ], + [ + 0.0133209228515625 + ], + [ + 0.0187835693359375 + ], + [ + 0.015716552734375 + ], + [ + 0.0193939208984375 + ], + [ + 0.018707275390625 + ], + [ + 0.0181427001953125 + ], + [ + 0.017822265625 + ], + [ + 0.018035888671875 + ], + [ + 0.01763916015625 + ], + [ + 0.0210418701171875 + ], + [ + 0.018341064453125 + ], + [ + 0.0186614990234375 + ], + [ + 0.0202789306640625 + ], + [ + 0.01519775390625 + ], + [ + 0.020172119140625 + ], + [ + 0.02069091796875 + ], + [ + 0.0180816650390625 + ], + [ + 0.0163726806640625 + ], + [ + 0.0164337158203125 + ], + [ + 0.017852783203125 + ], + [ + 0.018646240234375 + ], + [ + 0.0186614990234375 + ], + [ + 0.0171356201171875 + ], + [ + 0.0163116455078125 + ], + [ + 0.01611328125 + ], + [ + 0.0183868408203125 + ], + [ + 0.016571044921875 + ], + [ + 0.024322509765625 + ], + [ + 0.017547607421875 + ], + [ + 0.01885986328125 + ], + [ + 0.0171051025390625 + ], + [ + 0.0189971923828125 + ], + [ + 0.019134521484375 + ], + [ + 0.0159759521484375 + ], + [ + 0.020416259765625 + ], + [ + 0.0206756591796875 + ], + [ + 0.0185089111328125 + ], + [ + 0.0176544189453125 + ], + [ + 0.01861572265625 + ], + [ + 0.0157623291015625 + ], + [ + 0.0214691162109375 + ], + [ + 0.0176239013671875 + ], + [ + 0.0150299072265625 + ], + [ + 0.0193939208984375 + ], + [ + 0.02099609375 + ], + [ + 0.0237274169921875 + ], + [ + 0.0191802978515625 + ], + [ + 0.0176849365234375 + ], + [ + 0.01983642578125 + ], + [ + 0.0178070068359375 + ], + [ + 0.020050048828125 + ], + [ + 0.01355743408203125 + ], + [ + 0.01800537109375 + ], + [ + 0.019195556640625 + ], + [ + 0.0178375244140625 + ], + [ + 0.0227813720703125 + ], + [ + 0.01983642578125 + ], + [ + 0.019744873046875 + ], + [ + 0.0207977294921875 + ], + [ + 0.0200958251953125 + ], + [ + 0.0193939208984375 + ], + [ + 0.018280029296875 + ], + [ + 0.0204620361328125 + ], + [ + 0.0170745849609375 + ], + [ + 0.0171661376953125 + ], + [ + 0.0176849365234375 + ], 
+ [ + 0.015625 + ], + [ + 0.01715087890625 + ], + [ + 0.01885986328125 + ], + [ + 0.015869140625 + ], + [ + 0.0142364501953125 + ], + [ + 0.01629638671875 + ], + [ + 0.017852783203125 + ], + [ + 0.01678466796875 + ], + [ + 0.0186920166015625 + ], + [ + 0.0174560546875 + ], + [ + 0.016754150390625 + ], + [ + 0.0172119140625 + ], + [ + 0.0206756591796875 + ], + [ + 0.02178955078125 + ], + [ + 0.02001953125 + ], + [ + 0.0166473388671875 + ] + ], + "asymmetric_weights_decompressor_linear_weight_0": [ + [ + 0.0009713172912597656 + ], + [ + 0.0009407997131347656 + ], + [ + 0.000934600830078125 + ], + [ + 0.0009560585021972656 + ], + [ + 0.00090789794921875 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009274482727050781 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009093284606933594 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009374618530273438 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009140968322753906 + ], + [ + 0.0009531974792480469 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009226799011230469 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009160041809082031 + ], + [ + 0.0009560585021972656 + ], + [ + 0.0009713172912597656 + ], + [ + 0.000942230224609375 + ], + [ + 0.0009527206420898438 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009388923645019531 + ], + [ + 0.0009670257568359375 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009665489196777344 + ], + [ + 0.00096893310546875 + ], + [ + 0.0009512901306152344 + ], + [ + 0.0009665489196777344 + ], + [ + 0.000934600830078125 + ], + [ + 0.0009379386901855469 + ], + [ + 0.0009522438049316406 + ], + [ + 0.0009431838989257812 + ], + [ + 0.0009765625 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009608268737792969 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009694099426269531 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009665489196777344 + ], + [ + 0.0009260177612304688 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009188652038574219 + ], + [ + 0.0009365081787109375 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009708404541015625 + ], + [ + 0.0009398460388183594 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009446144104003906 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009593963623046875 + ], + [ + 0.0009503364562988281 + ], + [ + 0.0009560585021972656 + ], + [ + 0.0009608268737792969 + ], + [ + 0.0009760856628417969 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0008978843688964844 + ] + ] +} \ No newline at end of file diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False_ref_wc_param.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False_ref_wc_param.json new file mode 100644 index 00000000000..fd8fbda6f54 --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_False_ref_wc_param.json @@ -0,0 +1,38 @@ +[ + { + "weight_name": "wte_weight_1", + "node_with_weight": "embedding", + "weight_port_id": 0, + "weight_dtype": "float32", + "weight_shape": [ + 128, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int8_asym", + "group_size": -1, + "codebook_values": null + } + }, + { + "weight_name": "linear_weight", + "node_with_weight": "linear", + "weight_port_id": 1, + "weight_dtype": "float32", + 
"weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + } +] \ No newline at end of file diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True.dot b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True.dot new file mode 100644 index 00000000000..0a7bb5fe8f8 --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True.dot @@ -0,0 +1,24 @@ +strict digraph { +"0 linear_bias" [id=0, type="get_attr"]; +"1 lm_head_bias" [id=1, type="get_attr"]; +"2 input_ids" [id=2, type=input]; +"3 wte_weight_1_updated_constant0" [id=3, type="get_attr"]; +"4 symmetric_weights_decompressor_wte_weight_1_0" [id=4, type="call_module"]; +"5 embedding" [id=5, type=embedding]; +"6 linear_weight_updated_constant0" [id=6, type="get_attr"]; +"7 asymmetric_weights_decompressor_linear_weight_0" [id=7, type="call_module"]; +"8 linear" [id=8, type=linear]; +"9 linear_1" [id=9, type=linear]; +"10 output" [id=10, type=output]; +"0 linear_bias" -> "8 linear" [style=solid, label="(64,)"]; +"1 lm_head_bias" -> "9 linear_1" [style=solid, label="(128,)"]; +"2 input_ids" -> "5 embedding" [style=solid, label="(5,)"]; +"3 wte_weight_1_updated_constant0" -> "4 symmetric_weights_decompressor_wte_weight_1_0" [style=solid, label="(4096, 1)"]; +"4 symmetric_weights_decompressor_wte_weight_1_0" -> "5 embedding" [style=solid, label="(128, 64)"]; +"4 symmetric_weights_decompressor_wte_weight_1_0" -> "9 linear_1" [style=solid, label="(128, 64)"]; +"5 embedding" -> "8 linear" [style=solid, label="(5, 64)"]; +"6 linear_weight_updated_constant0" -> "7 asymmetric_weights_decompressor_linear_weight_0" [style=solid, label="(64, 64)"]; +"7 asymmetric_weights_decompressor_linear_weight_0" -> "8 linear" [style=solid, label="(64, 64)"]; +"8 linear" -> "9 linear_1" [style=solid, label="(5, 64)"]; +"9 linear_1" -> "10 output" [style=solid, label="(5, 128)"]; +} diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True_awq_True_scale_estimation_True_ref_wc_scales.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True_awq_True_scale_estimation_True_ref_wc_scales.json new file mode 100644 index 00000000000..24d93dde47d --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True_awq_True_scale_estimation_True_ref_wc_scales.json @@ -0,0 +1,1222 @@ +{ + "symmetric_weights_decompressor_wte_weight_1_0": [ + [ + [ + -0.290283203125 + ], + [ + -0.283447265625 + ] + ], + [ + [ + 0.313720703125 + ], + [ + -0.50439453125 + ] + ], + [ + [ + 0.28662109375 + ], + [ + -0.263916015625 + ] + ], + [ + [ + 0.255859375 + ], + [ + 0.1998291015625 + ] + ], + [ + [ + -0.271240234375 + ], + [ + -0.319580078125 + ] + ], + [ + [ + -0.3564453125 + ], + [ + 0.31103515625 + ] + ], + [ + [ + 0.3388671875 + ], + [ + -0.2958984375 + ] + ], + [ + [ + -0.2027587890625 + ], + [ + 0.32080078125 + ] + ], + [ + [ + -0.2509765625 + ], + [ + -0.423095703125 + ] + ], + [ + [ + -0.290771484375 + ], + [ + 0.337646484375 + ] + ], + [ + [ + 0.34912109375 + ], + [ + -0.3125 + ] + ], + [ + [ + 0.286865234375 + ], + [ + 0.3876953125 + ] + ], + [ + [ + 0.373046875 + ], + [ + 0.2486572265625 
+ ] + ], + [ + [ + 0.295166015625 + ], + [ + -0.238525390625 + ] + ], + [ + [ + -0.28857421875 + ], + [ + 0.2283935546875 + ] + ], + [ + [ + 0.28564453125 + ], + [ + 0.2127685546875 + ] + ], + [ + [ + -0.35009765625 + ], + [ + -0.275634765625 + ] + ], + [ + [ + -0.3984375 + ], + [ + -0.3095703125 + ] + ], + [ + [ + 0.28564453125 + ], + [ + 0.31591796875 + ] + ], + [ + [ + -0.351318359375 + ], + [ + -0.304931640625 + ] + ], + [ + [ + -0.316650390625 + ], + [ + -0.35595703125 + ] + ], + [ + [ + -0.3603515625 + ], + [ + -0.2188720703125 + ] + ], + [ + [ + -0.233642578125 + ], + [ + 0.303466796875 + ] + ], + [ + [ + 0.27685546875 + ], + [ + -0.26318359375 + ] + ], + [ + [ + 0.271240234375 + ], + [ + 0.2080078125 + ] + ], + [ + [ + 0.30126953125 + ], + [ + 0.26171875 + ] + ], + [ + [ + 0.3359375 + ], + [ + -0.393798828125 + ] + ], + [ + [ + 0.326904296875 + ], + [ + -0.316650390625 + ] + ], + [ + [ + -0.2403564453125 + ], + [ + -0.2607421875 + ] + ], + [ + [ + 0.2646484375 + ], + [ + -0.249267578125 + ] + ], + [ + [ + 0.47900390625 + ], + [ + 0.36181640625 + ] + ], + [ + [ + 0.25048828125 + ], + [ + 0.3310546875 + ] + ], + [ + [ + -0.25830078125 + ], + [ + 0.25634765625 + ] + ], + [ + [ + -0.349365234375 + ], + [ + -0.290283203125 + ] + ], + [ + [ + 0.269287109375 + ], + [ + 0.38134765625 + ] + ], + [ + [ + -0.274169921875 + ], + [ + -0.253662109375 + ] + ], + [ + [ + -0.251953125 + ], + [ + 0.302490234375 + ] + ], + [ + [ + 0.2353515625 + ], + [ + -0.262939453125 + ] + ], + [ + [ + -0.268310546875 + ], + [ + 0.3466796875 + ] + ], + [ + [ + -0.39892578125 + ], + [ + -0.27734375 + ] + ], + [ + [ + 0.2763671875 + ], + [ + -0.308349609375 + ] + ], + [ + [ + -0.254638671875 + ], + [ + -0.31689453125 + ] + ], + [ + [ + -0.36572265625 + ], + [ + 0.453857421875 + ] + ], + [ + [ + -0.2156982421875 + ], + [ + 0.290283203125 + ] + ], + [ + [ + 0.34716796875 + ], + [ + 0.346923828125 + ] + ], + [ + [ + 0.235595703125 + ], + [ + -0.266357421875 + ] + ], + [ + [ + -0.2052001953125 + ], + [ + -0.253173828125 + ] + ], + [ + [ + -0.259765625 + ], + [ + 0.339111328125 + ] + ], + [ + [ + 0.259521484375 + ], + [ + 0.25537109375 + ] + ], + [ + [ + -0.3212890625 + ], + [ + -0.283935546875 + ] + ], + [ + [ + 0.2685546875 + ], + [ + -0.314453125 + ] + ], + [ + [ + 0.2138671875 + ], + [ + 0.378662109375 + ] + ], + [ + [ + 0.256591796875 + ], + [ + -0.311279296875 + ] + ], + [ + [ + -0.272216796875 + ], + [ + 0.302490234375 + ] + ], + [ + [ + -0.330078125 + ], + [ + 0.22216796875 + ] + ], + [ + [ + -0.241943359375 + ], + [ + -0.442626953125 + ] + ], + [ + [ + -0.33740234375 + ], + [ + 0.235107421875 + ] + ], + [ + [ + 0.320068359375 + ], + [ + -0.275146484375 + ] + ], + [ + [ + -0.338623046875 + ], + [ + 0.3076171875 + ] + ], + [ + [ + 0.256591796875 + ], + [ + 0.2252197265625 + ] + ], + [ + [ + 0.34765625 + ], + [ + -0.29541015625 + ] + ], + [ + [ + -0.306396484375 + ], + [ + 0.353271484375 + ] + ], + [ + [ + 0.309326171875 + ], + [ + -0.231201171875 + ] + ], + [ + [ + 0.290283203125 + ], + [ + -0.2315673828125 + ] + ], + [ + [ + 0.263671875 + ], + [ + -0.26025390625 + ] + ], + [ + [ + 0.320556640625 + ], + [ + 0.212890625 + ] + ], + [ + [ + 0.341796875 + ], + [ + 0.2548828125 + ] + ], + [ + [ + -0.302001953125 + ], + [ + -0.3212890625 + ] + ], + [ + [ + -0.1842041015625 + ], + [ + 0.28271484375 + ] + ], + [ + [ + -0.2568359375 + ], + [ + 0.26318359375 + ] + ], + [ + [ + -0.2091064453125 + ], + [ + 0.304443359375 + ] + ], + [ + [ + -0.2381591796875 + ], + [ + 0.319580078125 + ] + ], + [ + [ + 0.264892578125 + ], + [ + 
0.252197265625 + ] + ], + [ + [ + -0.271728515625 + ], + [ + 0.42333984375 + ] + ], + [ + [ + -0.2264404296875 + ], + [ + -0.36376953125 + ] + ], + [ + [ + -0.296875 + ], + [ + 0.30419921875 + ] + ], + [ + [ + 0.268798828125 + ], + [ + -0.276123046875 + ] + ], + [ + [ + 0.2763671875 + ], + [ + 0.330810546875 + ] + ], + [ + [ + 0.314453125 + ], + [ + -0.29541015625 + ] + ], + [ + [ + -0.19921875 + ], + [ + 0.2685546875 + ] + ], + [ + [ + 0.32763671875 + ], + [ + -0.323486328125 + ] + ], + [ + [ + -0.3369140625 + ], + [ + 0.322265625 + ] + ], + [ + [ + 0.32763671875 + ], + [ + 0.293212890625 + ] + ], + [ + [ + 0.2374267578125 + ], + [ + 0.3037109375 + ] + ], + [ + [ + -0.295166015625 + ], + [ + 0.2978515625 + ] + ], + [ + [ + -0.306396484375 + ], + [ + -0.25634765625 + ] + ], + [ + [ + 0.314697265625 + ], + [ + -0.36962890625 + ] + ], + [ + [ + 0.2705078125 + ], + [ + 0.290283203125 + ] + ], + [ + [ + -0.2493896484375 + ], + [ + -0.263427734375 + ] + ], + [ + [ + -0.359619140625 + ], + [ + 0.2587890625 + ] + ], + [ + [ + -0.23583984375 + ], + [ + -0.348388671875 + ] + ], + [ + [ + -0.2421875 + ], + [ + -0.436767578125 + ] + ], + [ + [ + -0.349365234375 + ], + [ + 0.26220703125 + ] + ], + [ + [ + 0.25732421875 + ], + [ + 0.314697265625 + ] + ], + [ + [ + 0.21728515625 + ], + [ + 0.338134765625 + ] + ], + [ + [ + 0.301025390625 + ], + [ + -0.2666015625 + ] + ], + [ + [ + 0.314697265625 + ], + [ + -0.32421875 + ] + ], + [ + [ + 0.260986328125 + ], + [ + 0.262939453125 + ] + ], + [ + [ + 0.2110595703125 + ], + [ + -0.28759765625 + ] + ], + [ + [ + 0.339599609375 + ], + [ + 0.359375 + ] + ], + [ + [ + -0.289306640625 + ], + [ + 0.279296875 + ] + ], + [ + [ + -0.330810546875 + ], + [ + 0.395263671875 + ] + ], + [ + [ + 0.271240234375 + ], + [ + -0.32373046875 + ] + ], + [ + [ + 0.32568359375 + ], + [ + -0.3037109375 + ] + ], + [ + [ + 0.387939453125 + ], + [ + 0.3095703125 + ] + ], + [ + [ + 0.325927734375 + ], + [ + -0.314697265625 + ] + ], + [ + [ + -0.30615234375 + ], + [ + -0.346923828125 + ] + ], + [ + [ + -0.330322265625 + ], + [ + -0.28369140625 + ] + ], + [ + [ + 0.373046875 + ], + [ + 0.251708984375 + ] + ], + [ + [ + 0.259033203125 + ], + [ + -0.284912109375 + ] + ], + [ + [ + 0.2054443359375 + ], + [ + 0.29931640625 + ] + ], + [ + [ + -0.259033203125 + ], + [ + 0.304443359375 + ] + ], + [ + [ + -0.268310546875 + ], + [ + 0.2294921875 + ] + ], + [ + [ + 0.27392578125 + ], + [ + 0.312744140625 + ] + ], + [ + [ + 0.2205810546875 + ], + [ + -0.31884765625 + ] + ], + [ + [ + 0.286865234375 + ], + [ + -0.21923828125 + ] + ], + [ + [ + 0.2305908203125 + ], + [ + -0.2103271484375 + ] + ], + [ + [ + -0.292236328125 + ], + [ + -0.2314453125 + ] + ], + [ + [ + -0.300537109375 + ], + [ + -0.32080078125 + ] + ], + [ + [ + 0.292236328125 + ], + [ + 0.2412109375 + ] + ], + [ + [ + -0.30126953125 + ], + [ + 0.29443359375 + ] + ], + [ + [ + 0.277587890625 + ], + [ + -0.27880859375 + ] + ], + [ + [ + 0.26318359375 + ], + [ + 0.280029296875 + ] + ], + [ + [ + -0.32763671875 + ], + [ + -0.272216796875 + ] + ], + [ + [ + 0.272705078125 + ], + [ + -0.38623046875 + ] + ], + [ + [ + -0.249267578125 + ], + [ + 0.4208984375 + ] + ], + [ + [ + 0.290283203125 + ], + [ + 0.34765625 + ] + ], + [ + [ + -0.284912109375 + ], + [ + -0.34326171875 + ] + ] + ], + "asymmetric_weights_decompressor_linear_weight_0": [ + [ + 0.0009713172912597656 + ], + [ + 0.0009407997131347656 + ], + [ + 0.000934600830078125 + ], + [ + 0.0009560585021972656 + ], + [ + 0.00090789794921875 + ], + [ + 0.0009722709655761719 + ], + [ + 
0.0009274482727050781 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009093284606933594 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009374618530273438 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009140968322753906 + ], + [ + 0.0009531974792480469 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009226799011230469 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009160041809082031 + ], + [ + 0.0009560585021972656 + ], + [ + 0.0009713172912597656 + ], + [ + 0.000942230224609375 + ], + [ + 0.0009527206420898438 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009450912475585938 + ], + [ + 0.0009388923645019531 + ], + [ + 0.0009670257568359375 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009665489196777344 + ], + [ + 0.00096893310546875 + ], + [ + 0.0009512901306152344 + ], + [ + 0.0009665489196777344 + ], + [ + 0.000934600830078125 + ], + [ + 0.0009379386901855469 + ], + [ + 0.0009522438049316406 + ], + [ + 0.0009431838989257812 + ], + [ + 0.0009765625 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009608268737792969 + ], + [ + 0.0009489059448242188 + ], + [ + 0.0009694099426269531 + ], + [ + 0.0009603500366210938 + ], + [ + 0.0009665489196777344 + ], + [ + 0.0009260177612304688 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009188652038574219 + ], + [ + 0.0009365081787109375 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009708404541015625 + ], + [ + 0.0009398460388183594 + ], + [ + 0.0009655952453613281 + ], + [ + 0.0009446144104003906 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009593963623046875 + ], + [ + 0.0009503364562988281 + ], + [ + 0.0009560585021972656 + ], + [ + 0.0009608268737792969 + ], + [ + 0.0009760856628417969 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0008978843688964844 + ] + ] +} \ No newline at end of file diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True_ref_wc_param.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True_ref_wc_param.json new file mode 100644 index 00000000000..81205ac2ca8 --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int4wo_sym_gs32_all_layers_True_ref_wc_param.json @@ -0,0 +1,38 @@ +[ + { + "weight_name": "wte_weight_1", + "node_with_weight": "embedding", + "weight_port_id": 0, + "weight_dtype": "float32", + "weight_shape": [ + 128, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + }, + { + "weight_name": "linear_weight", + "node_with_weight": "linear", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int4_sym", + "group_size": 32, + "codebook_values": null + } + } +] \ No newline at end of file diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False.dot b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False.dot new file mode 100644 index 00000000000..b249fdf7ce3 --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False.dot @@ -0,0 +1,24 @@ +strict digraph { +"0 linear_bias" [id=0, 
type="get_attr"]; +"1 lm_head_bias" [id=1, type="get_attr"]; +"2 input_ids" [id=2, type=input]; +"3 wte_weight_1_updated_constant0" [id=3, type="get_attr"]; +"4 asymmetric_weights_decompressor_wte_weight_1_0" [id=4, type="call_module"]; +"5 embedding" [id=5, type=embedding]; +"6 linear_weight_updated_constant0" [id=6, type="get_attr"]; +"7 asymmetric_weights_decompressor_linear_weight_0" [id=7, type="call_module"]; +"8 linear" [id=8, type=linear]; +"9 linear_1" [id=9, type=linear]; +"10 output" [id=10, type=output]; +"0 linear_bias" -> "8 linear" [style=solid, label="(64,)"]; +"1 lm_head_bias" -> "9 linear_1" [style=solid, label="(128,)"]; +"2 input_ids" -> "5 embedding" [style=solid, label="(5,)"]; +"3 wte_weight_1_updated_constant0" -> "4 asymmetric_weights_decompressor_wte_weight_1_0" [style=solid, label="(128, 64)"]; +"4 asymmetric_weights_decompressor_wte_weight_1_0" -> "5 embedding" [style=solid, label="(128, 64)"]; +"4 asymmetric_weights_decompressor_wte_weight_1_0" -> "9 linear_1" [style=solid, label="(128, 64)"]; +"5 embedding" -> "8 linear" [style=solid, label="(5, 64)"]; +"6 linear_weight_updated_constant0" -> "7 asymmetric_weights_decompressor_linear_weight_0" [style=solid, label="(64, 64)"]; +"7 asymmetric_weights_decompressor_linear_weight_0" -> "8 linear" [style=solid, label="(64, 64)"]; +"8 linear" -> "9 linear_1" [style=solid, label="(5, 64)"]; +"9 linear_1" -> "10 output" [style=solid, label="(5, 128)"]; +} diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False_awq_False_scale_estimation_False_ref_wc_scales.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False_awq_False_scale_estimation_False_ref_wc_scales.json new file mode 100644 index 00000000000..edb48f38b53 --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False_awq_False_scale_estimation_False_ref_wc_scales.json @@ -0,0 +1,582 @@ +{ + "asymmetric_weights_decompressor_wte_weight_1_0": [ + [ + 0.01552581787109375 + ], + [ + 0.02117919921875 + ], + [ + 0.020660400390625 + ], + [ + 0.0182037353515625 + ], + [ + 0.0228118896484375 + ], + [ + 0.0202178955078125 + ], + [ + 0.0205841064453125 + ], + [ + 0.0165252685546875 + ], + [ + 0.0195770263671875 + ], + [ + 0.0199127197265625 + ], + [ + 0.018218994140625 + ], + [ + 0.0209197998046875 + ], + [ + 0.0209503173828125 + ], + [ + 0.0175628662109375 + ], + [ + 0.0171966552734375 + ], + [ + 0.0175628662109375 + ], + [ + 0.01415252685546875 + ], + [ + 0.0167388916015625 + ], + [ + 0.0225982666015625 + ], + [ + 0.018218994140625 + ], + [ + 0.0182342529296875 + ], + [ + 0.02288818359375 + ], + [ + 0.0174713134765625 + ], + [ + 0.01513671875 + ], + [ + 0.0233917236328125 + ], + [ + 0.0154266357421875 + ], + [ + 0.016845703125 + ], + [ + 0.0186767578125 + ], + [ + 0.018646240234375 + ], + [ + 0.0182952880859375 + ], + [ + 0.0182647705078125 + ], + [ + 0.017181396484375 + ], + [ + 0.0171966552734375 + ], + [ + 0.0185089111328125 + ], + [ + 0.0191497802734375 + ], + [ + 0.0159454345703125 + ], + [ + 0.02313232421875 + ], + [ + 0.0196075439453125 + ], + [ + 0.0168304443359375 + ], + [ + 0.015594482421875 + ], + [ + 0.01898193359375 + ], + [ + 0.021270751953125 + ], + [ + 0.015869140625 + ], + [ + 0.0191192626953125 + ], + [ + 0.0183563232421875 + ], + [ + 0.01557159423828125 + ], + [ + 0.02337646484375 + ], + [ + 0.01558685302734375 + ], + [ + 0.0152740478515625 + 
], + [ + 0.0184783935546875 + ], + [ + 0.016021728515625 + ], + [ + 0.0166473388671875 + ], + [ + 0.0171051025390625 + ], + [ + 0.0184326171875 + ], + [ + 0.0150909423828125 + ], + [ + 0.023773193359375 + ], + [ + 0.0170745849609375 + ], + [ + 0.0181121826171875 + ], + [ + 0.01715087890625 + ], + [ + 0.020843505859375 + ], + [ + 0.018280029296875 + ], + [ + 0.0178375244140625 + ], + [ + 0.01375579833984375 + ], + [ + 0.0179290771484375 + ], + [ + 0.0196075439453125 + ], + [ + 0.01708984375 + ], + [ + 0.0186920166015625 + ], + [ + 0.0255584716796875 + ], + [ + 0.02203369140625 + ], + [ + 0.0218505859375 + ], + [ + 0.0159759521484375 + ], + [ + 0.017852783203125 + ], + [ + 0.01922607421875 + ], + [ + 0.0218658447265625 + ], + [ + 0.0211029052734375 + ], + [ + 0.017547607421875 + ], + [ + 0.016937255859375 + ], + [ + 0.020782470703125 + ], + [ + 0.0189056396484375 + ], + [ + 0.01519775390625 + ], + [ + 0.01806640625 + ], + [ + 0.021728515625 + ], + [ + 0.0183868408203125 + ], + [ + 0.019927978515625 + ], + [ + 0.018463134765625 + ], + [ + 0.0167999267578125 + ], + [ + 0.017059326171875 + ], + [ + 0.01708984375 + ], + [ + 0.016143798828125 + ], + [ + 0.0185699462890625 + ], + [ + 0.018341064453125 + ], + [ + 0.01262664794921875 + ], + [ + 0.01849365234375 + ], + [ + 0.0159759521484375 + ], + [ + 0.019012451171875 + ], + [ + 0.01947021484375 + ], + [ + 0.0208282470703125 + ], + [ + 0.0182342529296875 + ], + [ + 0.0167999267578125 + ], + [ + 0.01523590087890625 + ], + [ + 0.021331787109375 + ], + [ + 0.0187225341796875 + ], + [ + 0.0179443359375 + ], + [ + 0.017608642578125 + ], + [ + 0.01416778564453125 + ], + [ + 0.0186614990234375 + ], + [ + 0.01302337646484375 + ], + [ + 0.018463134765625 + ], + [ + 0.0204010009765625 + ], + [ + 0.018463134765625 + ], + [ + 0.0205078125 + ], + [ + 0.0153350830078125 + ], + [ + 0.01751708984375 + ], + [ + 0.01922607421875 + ], + [ + 0.0174560546875 + ], + [ + 0.0154571533203125 + ], + [ + 0.01812744140625 + ], + [ + 0.019073486328125 + ], + [ + 0.017852783203125 + ], + [ + 0.0158538818359375 + ], + [ + 0.0195465087890625 + ], + [ + 0.0213470458984375 + ], + [ + 0.01995849609375 + ], + [ + 0.0166168212890625 + ], + [ + 0.019561767578125 + ], + [ + 0.0184478759765625 + ], + [ + 0.0162353515625 + ], + [ + 0.021270751953125 + ] + ], + "asymmetric_weights_decompressor_linear_weight_0": [ + [ + 0.0009274482727050781 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009093284606933594 + ], + [ + 0.0009679794311523438 + ], + [ + 0.0009703636169433594 + ], + [ + 0.0009570121765136719 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009617805480957031 + ], + [ + 0.0009212493896484375 + ], + [ + 0.0009541511535644531 + ], + [ + 0.000965118408203125 + ], + [ + 0.0009226799011230469 + ], + [ + 0.0009541511535644531 + ], + [ + 0.0009613037109375 + ], + [ + 0.0009560585021972656 + ], + [ + 0.0009713172912597656 + ], + [ + 0.000942230224609375 + ], + [ + 0.0009274482727050781 + ], + [ + 0.0009565353393554688 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009312629699707031 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0009579658508300781 + ], + [ + 0.0009436607360839844 + ], + [ + 0.0009632110595703125 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009670257568359375 + ], + [ + 0.0009665489196777344 + ], + [ + 0.000949859619140625 + ], + [ + 0.0009379386901855469 + ], + [ + 0.0009522438049316406 + ], + [ + 0.0009431838989257812 + ], + [ + 0.0009694099426269531 + ], + [ + 0.0009660720825195312 + ], + [ + 0.0009608268737792969 + ], + [ + 
0.0009489059448242188 + ], + [ + 0.0009322166442871094 + ], + [ + 0.0009675025939941406 + ], + [ + 0.0009665489196777344 + ], + [ + 0.0009264945983886719 + ], + [ + 0.0009684562683105469 + ], + [ + 0.0009188652038574219 + ], + [ + 0.0009365081787109375 + ], + [ + 0.00095367431640625 + ], + [ + 0.0009708404541015625 + ], + [ + 0.0009508132934570312 + ], + [ + 0.0009217262268066406 + ], + [ + 0.0009446144104003906 + ], + [ + 0.0009517669677734375 + ], + [ + 0.0009717941284179688 + ], + [ + 0.0009593963623046875 + ], + [ + 0.0009174346923828125 + ], + [ + 0.0009484291076660156 + ], + [ + 0.0009756088256835938 + ], + [ + 0.0009760856628417969 + ], + [ + 0.0009722709655761719 + ], + [ + 0.0008978843688964844 + ], + [ + 0.0009188652038574219 + ], + [ + 0.0009579658508300781 + ], + [ + 0.000946044921875 + ], + [ + 0.0009713172912597656 + ], + [ + 0.0009474754333496094 + ], + [ + 0.0009555816650390625 + ] + ] +} \ No newline at end of file diff --git a/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False_ref_wc_param.json b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False_ref_wc_param.json new file mode 100644 index 00000000000..49d45c1fffb --- /dev/null +++ b/tests/executorch/data/fx/compress_pt2e/OpenVINOQuantizer/short_transformer_shared/int8wo_asym_gs-1_all_layers_False_ref_wc_param.json @@ -0,0 +1,38 @@ +[ + { + "weight_name": "wte_weight_1", + "node_with_weight": "embedding", + "weight_port_id": 0, + "weight_dtype": "float32", + "weight_shape": [ + 128, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int8_asym", + "group_size": -1, + "codebook_values": null + } + }, + { + "weight_name": "linear_weight", + "node_with_weight": "linear", + "weight_port_id": 1, + "weight_dtype": "float32", + "weight_shape": [ + 64, + 64 + ], + "reduction_axes": [ + 1 + ], + "compression_config": { + "mode": "int8_asym", + "group_size": -1, + "codebook_values": null + } + } +] \ No newline at end of file diff --git a/tests/executorch/requirements.txt b/tests/executorch/requirements.txt new file mode 100644 index 00000000000..49faa6dac6a --- /dev/null +++ b/tests/executorch/requirements.txt @@ -0,0 +1,47 @@ +--pre +--extra-index-url https://download.pytorch.org/whl/nightly/cpu + +# Pytorch +torch==2.10.0.dev20250922+cpu +torchvision==0.25.0.dev20250922+cpu +torchao==0.14.0.dev20250922+cpu + +# Openvino +openvino==2025.3.0 + +# ONNX +onnx==1.17.0; python_version < '3.13' +onnx==1.18.0; python_version >= '3.13' +onnxruntime==1.21.1 + +# Copied from https://github.com/anzr299/executorch/blob/an/quantizer_nncf_pt2e_support/requirements-dev.txt +cmake>=3.29, <4.0.0 # For building binary targets in the wheel. +packaging>=24.2 # Lower bound required by setuptools +pip>=23 # For building the pip package. +pyyaml # Imported by the kernel codegen tools. +setuptools>=77.0.3 # For building the pip package contents. +wheel # For building the pip package archive. +zstd # Imported by resolve_buck.py. +certifi # Imported by resolve_buck.py. 
+
+# Copied from tests/torch2/requirements.txt
+addict>=2.4.0
+efficientnet_pytorch==0.7.1
+transformers==4.52.1
+
+sentence-transformers==4.1.0
+optimum-intel==1.24.0
+optimum==1.26.0
+accelerate==1.9.0
+fastdownload==0.0.7
+
+
+# Tests and examples
+pytest==8.0.2
+pytest-cov==4.1.0
+pytest-mock==3.12.0
+pytest-dependency==0.6.0
+pytest-ordering==0.6
+pytest-xdist==3.5.0
+pytest-forked==1.6.0
+pytest-split==0.9.0
diff --git a/tests/executorch/test_quantizer_compression.py b/tests/executorch/test_quantizer_compression.py
new file mode 100644
index 00000000000..464d0a4764a
--- /dev/null
+++ b/tests/executorch/test_quantizer_compression.py
@@ -0,0 +1,351 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import dataclasses
+import json
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial
+from typing import Any, Callable, Optional
+
+import pytest
+import torch
+import torch.fx
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
+from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e
+
+import nncf
+from executorch.backends.openvino.quantizer.quantizer import OpenVINOQuantizer
+from executorch.backends.openvino.quantizer.quantizer import QuantizationMode
+from nncf.common.graph import NNCFGraph
+from nncf.common.graph.graph import NNCFNode
+from nncf.common.utils.os import safe_open
+from nncf.experimental.torch.fx import compress_pt2e
+from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
+from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
+from tests.cross_fw.shared.nx_graph import compare_nx_graph_with_reference
+from tests.cross_fw.shared.paths import TEST_ROOT
+from tests.torch.test_models.llama import LlamaDecoderOnly
+from tests.torch.test_models.synthetic import ShortTransformer
+from tests.torch2.fx.helpers import get_torch_fx_model
+
+FX_PT2E_DIR = TEST_ROOT / "executorch" / "data" / "fx" / "compress_pt2e"
+FX_AO_DIR = TEST_ROOT / "executorch" / "data" / "fx" / "ao_export_compression_OpenVINOQuantizer"
+
+
+@dataclass
+class ModelCase:
+    model_builder: Callable[[], torch.nn.Module]
+    model_id: str
+    input_shape: tuple[int, ...]
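+
+
+# Reference artifacts on disk are looked up by a name derived from the quantizer
+# parameters via _string_from_quantizer_params() below
+# ("<mode>_gs<group_size>_all_layers_<flag>", optionally extended with the
+# awq/scale_estimation flags), plus a per-kind suffix: ".dot" for the reference
+# graph, "_ref_wc_param.json" and "_ref_wc_scales.json" for the JSON references.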
+ + +def get_dot_filename(model_name: str) -> str: + return model_name + ".dot" + + +def get_wc_param_filename(model_name: str) -> str: + return model_name + "_ref_wc_param.json" + + +def get_wc_scales_filename(model_name: str) -> str: + return model_name + "_ref_wc_scales.json" + + +def build_torch_fx_model(model_case: ModelCase) -> tuple[torch.fx.GraphModule, torch.Tensor]: + model = model_case.model_builder() + # ShortTransformer takes token ids; match prior synthetic tests (int32) + example_input = torch.ones(model_case.input_shape, dtype=torch.int32) + fx_model = get_torch_fx_model(model, example_input) + return fx_model, example_input + + +def _get_calibration_dataset(example_input: torch.Tensor) -> nncf.Dataset: + torch.manual_seed(42) + + def transform_fn(x): + return x.to("cpu") + + sample_1 = torch.randint_like(example_input, 0, 10) + sample_2 = torch.randint_like(example_input, 0, 10) + return nncf.Dataset([example_input, sample_1, sample_2], transform_fn) + + +def get_openvino_quantizer(*args, **kwargs) -> OpenVINOQuantizer: + return OpenVINOQuantizer(*args, **kwargs) + + +def _string_from_quantizer_params(qparams: dict[str, Any], pt2e_param: Optional[dict[str, Any]] = None) -> str: + mode = qparams.get("mode") + gs = qparams.get("group_size", "-1") + all_layers = qparams.get("all_layers", "False") + if pt2e_param is None: + return f"{mode.value}_gs{gs}_all_layers_{all_layers}" + awq = pt2e_param.get("awq", "False") + scale_estimation = pt2e_param.get("scale_estimation", "False") + return f"{mode.value}_gs{gs}_all_layers_{all_layers}_awq_{awq}_scale_estimation_{scale_estimation}" + + +def check_multiple_isinstance(object_to_check: Any, objects: list[Any]): + if not object_to_check: + return False + for obj in objects: + if isinstance(object_to_check, obj): + return True + return False + + +def get_scale_values_from_model(model: torch.fx.GraphModule): + node_to_scale_mapping = {} + decompressor_modules = [ + INT4AsymmetricWeightsDecompressor, + INT4SymmetricWeightsDecompressor, + INT8AsymmetricWeightsDecompressor, + INT8SymmetricWeightsDecompressor, + ] + for node in model.graph.nodes: + # print(node.name, node.target, node.meta) + node_module = getattr(model, node.target) if node.op == "call_module" else None + if not check_multiple_isinstance(node_module, decompressor_modules): + continue + state_dict_scale_name = f"{node.target}._scale" + node_to_scale_mapping[node.name] = model.state_dict()[state_dict_scale_name] + + return node_to_scale_mapping + + +def get_test_cases(): + test_cases = [] + for model in BASE_MODELS: + for qparam in QUANTIZER_PARAMS: + pt2e_params = PT2E_PARAMS + if qparam.get("mode") in {QuantizationMode.INT8WO_ASYM, QuantizationMode.INT8WO_SYM}: + pt2e_params = [{}] + for pt2e_param in pt2e_params: + test_cases.append( + ( + model, + qparam, + pt2e_param, + ) + ) + return test_cases + + +BASE_MODELS = ( + ModelCase(LlamaDecoderOnly, "LlamaDecoderOnly", [1, 3, 64]), + ModelCase(partial(ShortTransformer, 64, 128, True), "short_transformer_shared", [5]), +) + +QUANTIZER_PARAMS = ( + {"mode": QuantizationMode.INT8WO_ASYM}, + {"mode": QuantizationMode.INT4WO_SYM, "group_size": 32}, + {"mode": QuantizationMode.INT4WO_SYM, "group_size": 32, "all_layers": True}, +) + +PT2E_PARAMS = ({"awq": True, "scale_estimation": True},) + + +TEST_MODELS = get_test_cases() + + +TEST_MODEL_IDS = [ + f"{m.model_id}__{_string_from_quantizer_params(qparams, pt2e_param)}" for (m, qparams, pt2e_param) in TEST_MODELS +] + +INT8_COMPRESSION_MODES = [QuantizationMode.INT8WO_ASYM, 
                          QuantizationMode.INT8WO_SYM]
+
+
+@pytest.mark.parametrize(
+    ("model_case", "quantizer_params", "pt2e_params"),
+    TEST_MODELS,
+    ids=TEST_MODEL_IDS,
+)
+@pytest.mark.parametrize(
+    "quantizer_builder",
+    [get_openvino_quantizer],
+    ids=["OpenVINOQuantizer"],
+)
+def test_compress_pt2e(
+    quantizer_builder: Callable[..., OpenVINOQuantizer],
+    model_case: ModelCase,
+    quantizer_params,
+    pt2e_params,
+):
+    fx_model, example_input = build_torch_fx_model(model_case)
+    with torch.no_grad():
+        ref_out = fx_model(example_input)
+
+    calibration_dataset = _get_calibration_dataset(example_input)
+
+    # Build quantizer directly from quantizer_params (already includes mode/group_size)
+    quantizer = quantizer_builder(**quantizer_params)
+    mode = quantizer_params.get("mode")
+    ratio = 1 if mode in INT8_COMPRESSION_MODES else 0.8
+
+    quantized_model = compress_pt2e(fx_model, quantizer=quantizer, ratio=ratio, dataset=calibration_dataset)
+
+    with torch.no_grad():
+        out = quantized_model(example_input)
+    assert out.shape == ref_out.shape, "Compressed model output shape mismatch."
+
+    nncf_graph: NNCFGraph = GraphConverter.create_nncf_graph(quantized_model)
+    nx_graph = nncf_graph.get_graph_for_structure_analysis(extended=True)
+    param_string = _string_from_quantizer_params(quantizer_params)
+    path_to_dot = (
+        FX_PT2E_DIR / quantizer.__class__.__name__ / model_case.model_id / get_dot_filename(param_string)
+    ).as_posix()
+    compare_nx_graph_with_reference(nx_graph, path_to_dot)
+
+
+@pytest.mark.parametrize(
+    ("model_case", "quantizer_params", "pt2e_params"),
+    TEST_MODELS,
+    ids=TEST_MODEL_IDS,
+)
+@pytest.mark.parametrize(
+    "quantizer_builder",
+    [get_openvino_quantizer],
+    ids=["OpenVINOQuantizer"],
+)
+def test_compress_pt2e_scales(
+    quantizer_builder: Callable[..., OpenVINOQuantizer],
+    model_case: ModelCase,
+    quantizer_params,
+    pt2e_params,
+    regen_ref_data,
+):
+    fx_model, example_input = build_torch_fx_model(model_case)
+    with torch.no_grad():
+        ref_out = fx_model(example_input)
+
+    calibration_dataset = _get_calibration_dataset(example_input)
+
+    # Build quantizer directly from quantizer_params (already includes mode/group_size)
+    quantizer = quantizer_builder(**quantizer_params)
+    mode = quantizer_params.get("mode")
+    ratio = 1 if mode in INT8_COMPRESSION_MODES else 0.8
+    quantized_model = compress_pt2e(
+        fx_model, quantizer=quantizer, ratio=ratio, dataset=calibration_dataset, **pt2e_params
+    )
+
+    with torch.no_grad():
+        out = quantized_model(example_input)
+    assert out.shape == ref_out.shape, "Compressed model output shape mismatch."
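+
+    # The rest of this test compares the per-decompressor scale tensors against a
+    # committed reference JSON; the regen_ref_data fixture (defined outside this diff)
+    # rewrites the reference file instead when regeneration is requested.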
+
+    param_string = _string_from_quantizer_params(quantizer_params, pt2e_params)
+    ref_json_path = (
+        FX_PT2E_DIR / quantizer.__class__.__name__ / model_case.model_id / get_wc_scales_filename(param_string)
+    )
+
+    scales_list = get_scale_values_from_model(quantized_model)
+    scales_list = to_json_serializable(scales_list)
+
+    if regen_ref_data:
+        with safe_open(ref_json_path, "w") as file:
+            json.dump(scales_list, file, indent=4)
+
+    with safe_open(ref_json_path, "r") as f:
+        ref_data = json.load(f)
+
+    assert scales_list == ref_data, (
+        f"Scale values mismatch for {model_case.model_id} ({param_string}).\nRef: {ref_json_path}"
+    )
+
+
+@pytest.mark.parametrize(
+    ("model_case", "quantizer_params", "pt2e_params"),
+    TEST_MODELS,
+    ids=TEST_MODEL_IDS,
+)
+@pytest.mark.parametrize(
+    "quantizer_builder",
+    [get_openvino_quantizer],
+    ids=["OpenVINOQuantizer"],
+)
+def test_openvino_quantizer(
+    model_case: ModelCase,
+    quantizer_params,
+    quantizer_builder: Callable[..., OpenVINOQuantizer],
+    pt2e_params,
+):
+    fx_model, example_input = build_torch_fx_model(model_case)
+    quantizer = quantizer_builder(**quantizer_params)
+
+    prepared = prepare_pt2e(fx_model, quantizer)
+    prepared(example_input)
+    ao_quantized_model = convert_pt2e(prepared)
+
+    nncf_graph = GraphConverter.create_nncf_graph(ao_quantized_model)
+    nx_graph = nncf_graph.get_graph_for_structure_analysis(extended=True)
+
+    param_string = _string_from_quantizer_params(quantizer_params)
+    path_to_dot = (FX_AO_DIR / model_case.model_id / get_dot_filename(param_string)).as_posix()
+    compare_nx_graph_with_reference(nx_graph, path_to_dot)
+
+
+def to_json_serializable(obj: Any) -> Any:
+    if dataclasses.is_dataclass(obj):
+        return {k: to_json_serializable(v) for k, v in dataclasses.asdict(obj).items()}
+    if isinstance(obj, Enum):
+        return obj.value
+    if isinstance(obj, (list, tuple)):
+        return [to_json_serializable(x) for x in obj]
+    if isinstance(obj, torch.Tensor):
+        return obj.detach().cpu().tolist()
+    if isinstance(obj, dict):
+        return {k: to_json_serializable(v) for k, v in obj.items()}
+    if isinstance(obj, NNCFNode):
+        return obj.node_name
+    return obj
+
+
+@pytest.mark.parametrize(
+    ("model_case", "quantizer_params", "pt2e_params"),
+    TEST_MODELS,
+    ids=TEST_MODEL_IDS,
+)
+@pytest.mark.parametrize(
+    "quantizer_builder",
+    [get_openvino_quantizer],
+    ids=["OpenVINOQuantizer"],
+)
+def test_openvino_wc_params(
+    quantizer_builder: Callable[..., OpenVINOQuantizer],
+    model_case: ModelCase,
+    quantizer_params,
+    pt2e_params,
+    regen_ref_data,
+):
+    fx_model, _ = build_torch_fx_model(model_case)
+    nncf_graph: NNCFGraph = GraphConverter.create_nncf_graph(fx_model)
+
+    param_string = _string_from_quantizer_params(quantizer_params)
+    quantizer = quantizer_builder(**quantizer_params)
+
+    all_weight_params, *_ = quantizer.get_nncf_weight_compression_parameters(fx_model, nncf_graph)
+
+    wc_params = to_json_serializable(all_weight_params)
+
+    ref_json_path = (
+        FX_PT2E_DIR / quantizer.__class__.__name__ / model_case.model_id / get_wc_param_filename(param_string)
+    )
+
+    if regen_ref_data:
+        with safe_open(ref_json_path, "w") as file:
+            json.dump(wc_params, file, indent=4)
+
+    with safe_open(ref_json_path, "r") as f:
+        ref_data = json.load(f)
+
+    assert wc_params == ref_data, (
+        f"Weight compression parameters JSON mismatch for {model_case.model_id} ({param_string}).\nRef: {ref_json_path}"
+    )
diff --git a/tests/torch/test_models/__init__.py b/tests/torch/test_models/__init__.py
index 95cba87cc98..8dcdf092213 100644
--- a/tests/torch/test_models/__init__.py
+++ b/tests/torch/test_models/__init__.py
@@ -15,6 +15,7 @@
 from .googlenet import *
 from .inceptionv3 import *
 from .lenet import *
+from .llama import *
 from .pnasnet import *
 from .preact_resnet import *
 from .resnet import *
diff --git a/tests/torch/test_models/llama.py b/tests/torch/test_models/llama.py
new file mode 100644
index 00000000000..fae6b6be9de
--- /dev/null
+++ b/tests/torch/test_models/llama.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+EMBED_DIM = 64
+N_HEADS = 4
+HEAD_DIM = EMBED_DIM // N_HEADS
+# Same as Llama 3.2 config
+ROPE_THETA = 500000.0
+MAX_SEQ = 128
+BIAS = False
+
+
+class LlamaRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Copied from src/transformers/models/llama/modeling_llama.py
+        LlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+
+def _rotate_half(x):
+    """
+    Copied from src/transformers/models/llama/modeling_llama.py
+    Rotates half the hidden dims of the input.
+    """
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+class Rotary(nn.Module):
+    """
+    Precompute cos/sin tables for RoPE and apply them to q and k.
+    Copied from src/transformers/models/llama/modeling_llama.py
+    The cos and sin tables are initialized once in __init__.
+    """
+
+    # Llama applies rotary embeddings to q and k before attention; see modeling_llama
+    def __init__(self, head_dim: int, max_seq_len: int = MAX_SEQ, theta: float = ROPE_THETA, device=None):
+        super().__init__()
+        dtype = torch.float32
+        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=dtype, device=device) / head_dim))
+        t = torch.arange(max_seq_len, dtype=dtype, device=device)
+        freqs = torch.einsum("t,f->tf", t, inv_freq)  # (T, Hd/2)
+        emb = torch.cat((freqs, freqs), dim=-1)  # (T, Hd)
+        self.register_buffer("cos", emb.cos()[None, None, ...], persistent=False)  # (1,1,T,Hd)
+        self.register_buffer("sin", emb.sin()[None, None, ...], persistent=False)
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor, pos: torch.Tensor):
+        cos = self.cos[..., pos, :]
+        sin = self.sin[..., pos, :]
+        q_embed = (q * cos) + (_rotate_half(q) * sin)
+        k_embed = (k * cos) + (_rotate_half(k) * sin)
+        return q_embed, k_embed
+
+
+class LlamaMLP(nn.Module):
+    """
+    Copied from src/transformers/models/llama/modeling_llama.py
+    """
+
+    def __init__(self, dim: int, mult: int = 2):
+        super().__init__()
+        # mult is a scaling factor that defines the hidden/intermediate layer size
+        hidden = mult * dim
+        self.gate_proj = nn.Linear(dim, hidden, bias=BIAS)
+        self.up_proj = nn.Linear(dim, hidden, bias=BIAS)
+        self.down_proj = nn.Linear(hidden, dim, bias=BIAS)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class LlamaDecoderOnly(nn.Module):
+    """
+    One Llama-style transformer block (pre-norm attn + MLP) with RoPE and KV cache.
+    Forward takes embeddings only.
+    """
+
+    # The KV caching + past_key_values flow mirrors the HF implementation.
+    def __init__(self, dim: int = EMBED_DIM, n_heads: int = N_HEADS):
+        super().__init__()
+        assert dim % n_heads == 0
+        self.n_heads = n_heads
+        self.head_dim = dim // n_heads
+        self.scale = 1.0 / math.sqrt(self.head_dim)
+
+        self.attn_norm = LlamaRMSNorm(dim)
+        self.q_proj = nn.Linear(dim, dim, bias=BIAS)
+        self.k_proj = nn.Linear(dim, dim, bias=BIAS)
+        self.v_proj = nn.Linear(dim, dim, bias=BIAS)
+        self.o_proj = nn.Linear(dim, dim, bias=BIAS)
+        self.rope = Rotary(self.head_dim, MAX_SEQ, theta=ROPE_THETA)
+
+        self.mlp_norm = LlamaRMSNorm(dim)
+        self.mlp = LlamaMLP(dim)
+
+    def _attn(self, x: torch.Tensor, pos: torch.Tensor, past_kv: Optional[tuple[torch.Tensor, torch.Tensor]]):
+        """
+        Code from the LlamaAttention forward method. SDPA implementation, analogous to
+        model.config._attn_implementation = "sdpa".
+        """
+        B, T, C = x.shape
+        H, Hd = self.n_heads, self.head_dim
+
+        # QKV projections from hidden state x
+        q = self.q_proj(x).view(B, T, H, Hd).transpose(1, 2)
+        k = self.k_proj(x).view(B, T, H, Hd).transpose(1, 2)
+        v = self.v_proj(x).view(B, T, H, Hd).transpose(1, 2)
+
+        # RoPE
+        q, k = self.rope(q, k, pos)
+
+        # KV cache
+        if past_kv is not None:
+            pk, pv = past_kv  # (B,H,Tpast,Hd)
+            k = torch.cat([pk, k], dim=2)
+            v = torch.cat([pv, v], dim=2)
+
+        # Note: is_causal=True aligns the causal mask to the top-left corner, which is
+        # only correct when no past_kv offset is involved, as in the graph tests here.
+        y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, dropout_p=0.0)
+
+        y = y.transpose(1, 2).contiguous().view(B, T, C)
+        y = self.o_proj(y)
+        return y, (k, v)
+
+    def forward(
+        self,
+        x_embed: torch.Tensor,  # (B, T_new, C) embeddings only
+        past_kv: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # (B,H,Tpast,Hd)
+    ):
+        # positions for the *new* tokens only
+        past_len = 0 if past_kv is None else past_kv[0].size(2)
+        T_new = x_embed.size(1)
+        pos = torch.arange(past_len, past_len + T_new, device=x_embed.device)
+
+        # pre-norm attention + residual
+        y, _kv = self._attn(self.attn_norm(x_embed), pos, past_kv)
+        x = x_embed + y
+
+        # pre-norm MLP + residual
+        x = x + self.mlp(self.mlp_norm(x))
+        return x
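+
+
+if __name__ == "__main__":
+    # Minimal smoke-test sketch (an illustration, not part of the test suite): run the
+    # block once on random embeddings and check that the output shape is preserved.
+    block = LlamaDecoderOnly()
+    x = torch.randn(1, 3, EMBED_DIM)  # (batch, seq, dim) - matches the shape used in the pt2e tests
+    out = block(x)
+    assert out.shape == x.shape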