Skip to content

Commit 50dbeab

Browse files
committed
slice_scatter decomposition
1 parent 7578fa8 commit 50dbeab

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,44 @@ def var_decomposition(
162162
return variance
163163

164164

165+
@register_torch_trt_decomposition(
166+
torch.ops.slice_scatter, registry=TORCH_TRT_DECOMPOSITIONS
167+
)
168+
def slice_scatter_decomposition(
169+
input_tensor: torch.Tensor,
170+
src_tensor: torch.Tensor,
171+
dim: int,
172+
start: Optional[int],
173+
end: Optional[int],
174+
step: int,
175+
):
176+
dim_size = input_tensor.shape[dim]
177+
input_tensor_shape = input_tensor.shape
178+
if start is not None and start < 0:
179+
start = start + dim_size
180+
if end is not None and end < 0:
181+
end = end + dim_size
182+
if start is None:
183+
start = 0
184+
if end is None:
185+
end = dim_size
186+
187+
src_dim = list(src_tensor.shape())
188+
src_dim[dim] = torch.floor_divide(end - start, step)
189+
src = torch.expand(src, src_dim)
190+
191+
if (start == 0 and end == dim_size and step == 0):
192+
return input_tensor
193+
mask = []
194+
if start != 0:
195+
mask.append(torch.ge(input_tensor_shape, start))
196+
if end != dim_size:
197+
mask.append(torch.ge(input_tensor_shape, end))
198+
if step != 1:
199+
mask.append(torch.eq(src_dim, 0))
200+
src_val = torch.masked(mask, src_dim, 0)
201+
return torch.where(mask, src_val,input_tensor)
202+
165203
def get_decompositions(
166204
enable_experimental_decompositions: bool = False,
167205
) -> Dict[OpOverload, Callable[[Any], Any]]:

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,5 +421,71 @@ def forward(self, x):
421421
)
422422

423423

424+
def test_lowering_select_scatter_module(self):
425+
class selectScatter(torch.nn.Module):
426+
def __init__(self, *args, **kwargs) -> None:
427+
super().__init__(*args, **kwargs)
428+
429+
def forward(self, x, src, dim, start):
430+
y = self.slice_scatter(x, src, dim, start)
431+
return y
432+
433+
# Operations expected to be removed in the traced graph after decompositions
434+
expected_ops = {
435+
torch.ops.aten.lt.default,
436+
torch.ops.aten.lt.default,
437+
torch.ops.aten.expand.default,
438+
torch.ops.aten.eq.default,
439+
torch.ops.aten.where.default,
440+
441+
}
442+
unexpected_ops = {torch.ops.aten.select_scatter}
443+
444+
inputs = [torch.randn(2, 2), torch.ones(2)]
445+
446+
fx_graph = torch.fx.symbolic_trace(selectScatter())
447+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
448+
fx_graph,
449+
inputs,
450+
expected_ops=expected_ops,
451+
unexpected_ops=unexpected_ops,
452+
min_block_size=1,
453+
)
454+
455+
self.assertEquals(
456+
len(unexpected_ops_seen),
457+
0,
458+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
459+
)
460+
461+
self.assertEquals(
462+
len(expected_ops_unseen),
463+
0,
464+
f"The following expected ops were not encountered: {expected_ops_unseen}",
465+
)
466+
467+
torch._dynamo.reset()
468+
469+
# Validate that the results between Torch and Torch-TRT are similar
470+
optimized_model = torch_tensorrt.compile(
471+
fx_graph,
472+
"torch_compile",
473+
inputs,
474+
min_block_size=1,
475+
pass_through_build_failures=True,
476+
)
477+
optimized_model_results = optimized_model(*inputs).detach().cpu()
478+
torch_model_results = fx_graph(*inputs).detach().cpu()
479+
480+
max_diff = float(
481+
torch.max(torch.abs(optimized_model_results - torch_model_results))
482+
)
483+
self.assertAlmostEqual(
484+
max_diff,
485+
0,
486+
DECIMALS_OF_AGREEMENT,
487+
f"Select_scatter TRT outputs don't match with the original model.",
488+
)
489+
424490
if __name__ == "__main__":
425491
run_tests()

0 commit comments

Comments
 (0)