
Commit e8418d4

Add 16A8W FCNode support with BMM dependency fix
Add 16A8W quantization support for FCNode operations, with a BMM dependency fix, in the ExecuTorch ARM backend. This follows the pattern established for the linear, mul, sigmoid, tanh, slice, view/transpose, and cat operations, extending int16 support to FCNode operations.

Changes:
- Add INT16 dtype validation support in op_bmm.py
- Add the test_addmm_16a8w_tosa_INT test function
- Enable test_addmm.py in the test targets configuration
- Fix the BMM dependency for FCNode operations

The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while keeping weight storage efficient.

Differential Revision: [D80512504](https://our.internmc.facebook.com/intern/diff/D80512504/)

[ghstack-poisoned]
1 parent a237e06 commit e8418d4
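
For context, a minimal sketch of how a 16A8W quantizer is built, mirroring the helper this commit adds to test_addmm.py (all names are taken from the diff below; this is illustrative, not part of the commit):

from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa_specification import TosaSpecification

# Target the TOSA 1.0 INT profile with the int16 extension enabled.
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT+int16")
quantizer = TOSAQuantizer(tosa_spec)

# Symmetric config: 16-bit activations, 8-bit weights, per-tensor here.
quantizer.set_global(get_symmetric_a16w8_quantization_config(is_per_channel=False))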

File tree: 3 files changed (+111, -3)


backends/arm/operators/op_bmm.py (3 additions, 2 deletions)

@@ -55,7 +55,7 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [*inputs, output],
-            [ts.DType.INT8, ts.DType.FP32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
             output.tosa_spec,
         )

@@ -93,7 +93,8 @@ def define_node(
         if output.dtype == ts.DType.INT8:
             output_qparams = get_output_qparams(node)[0]
             final_output_scale = (
-                input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor()  # type: ignore[possibly-undefined]  # pyre-ignore[61]
+                input_qparams[0].get_scale_per_tensor()
+                * input_qparams[1].get_scale_per_tensor()  # type: ignore[possibly-undefined]  # pyre-ignore[61]
             ) / output_qparams.get_scale_per_tensor()

             build_rescale(
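
The second hunk only re-wraps the scale expression, but it is the arithmetic the rescale depends on: the integer accumulator is implicitly scaled by the product of the two input scales, and rescaling to the output dtype divides by the output scale. A small numeric sketch (illustrative values, not taken from the diff):

# Illustrative values only; real scales come from the quantization params.
s_in0, s_in1 = 0.02, 0.01   # per-tensor scales of the two BMM inputs
s_out = 0.05                # per-tensor scale of the BMM output

# Accumulator scale is s_in0 * s_in1, so the rescale multiplier back to
# the output dtype is (s_in0 * s_in1) / s_out.
final_output_scale = (s_in0 * s_in1) / s_out   # 0.004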

backends/arm/test/ops/test_addmm.py (107 additions, 1 deletion)

@@ -5,16 +5,23 @@

 from typing import Tuple

+import pytest
 import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)

-from executorch.backends.arm.test import common
+from executorch.backends.arm.test import common, conftest
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
     TosaPipelineFP,
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize

 aten_op = "torch.ops.aten.addmm.default"

@@ -182,3 +189,102 @@ def test_addmm_vgf_INT(test_data: input_t1):
         tosa_version="TOSA-1.0+INT",
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_addmm_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_addmm_16a8w_tosa_INT(test_data: input_t1):
+    """Test addmm (FC layer) operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Addmm(),
+        (*test_data,),
+        aten_op=[],
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_addmm_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 addmm operations"
+)
+def test_addmm_16a8w_u55_INT16(test_data: input_t1):
+    """Test addmm (FC layer) operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        Addmm(),
+        (*test_data,),
+        aten_ops=[],
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_addmm_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 addmm operations"
+)
+def test_addmm_16a8w_u85_INT16(test_data: input_t1):
+    """Test addmm (FC layer) operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        Addmm(),
+        (*test_data,),
+        aten_ops=[],
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_addmm_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
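
With test_addmm.py enabled in the test targets (next hunk), the new TOSA test should also be runnable directly through pytest's keyword filter. A plausible invocation, assuming a local checkout with the ARM backend test dependencies installed (the exact setup may differ):

python -m pytest backends/arm/test/ops/test_addmm.py -k test_addmm_16a8w_tosa_INT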

backends/arm/test/targets.bzl (1 addition, 0 deletions)

@@ -14,6 +14,7 @@ def define_arm_tests():
     # Operators
     test_files += [
         "ops/test_add.py",
+        "ops/test_addmm.py",
         "ops/test_avg_pool2d.py",
         "ops/test_cat.py",
         "ops/test_linear.py",
