13 changes: 13 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -213,6 +213,19 @@ def slice_scatter_decomposition(
return output_tensor


@register_torch_trt_decomposition(
torch.ops.aten.select_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def select_scatter_decomposition(
input_tensor: torch.Tensor,
src_tensor: torch.Tensor,
dim: int,
index: int,
) -> torch.Tensor:
src_tensor = torch.unsqueeze(src_tensor, dim)
return torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1)


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
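As a quick sanity check on the select_scatter decomposition above (an editor's sketch, not part of the diff), the unsqueeze-plus-slice_scatter rewrite should reproduce torch.select_scatter exactly in eager mode:

# Editor's sketch: eager-mode check that the decomposition matches the composite op.
import torch

x = torch.zeros(2, 2)
src = torch.ones(2)

# Composite op: write `src` into row 0 of `x`.
reference = torch.select_scatter(x, src, dim=0, index=0)

# Decomposed form used above: unsqueeze `src` on `dim`, then scatter it back
# as a length-1 slice over [index, index + 1) with step 1.
decomposed = torch.slice_scatter(x, torch.unsqueeze(src, 0), 0, 0, 1, 1)

assert torch.equal(reference, decomposed)  # both are [[1., 1.], [0., 0.]]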
195 changes: 192 additions & 3 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -530,7 +530,7 @@ def forward(self, x, src, dim, start=None, end=None, step=1):
"torch_compile",
inputs,
min_block_size=1,
truncate_long_and_double=True,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -593,7 +593,7 @@ def forward(self, x, src, dim, start, end, step):
"torch_compile",
inputs,
min_block_size=1,
truncate_long_and_double=True,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -663,7 +663,7 @@ def forward(self, x, src, dim, start, end, step):
"torch_compile",
inputs,
min_block_size=1,
truncate_long_and_double=True,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -679,6 +679,195 @@ def forward(self, x, src, dim, start, end, step):
f"Slice_scatter TRT outputs don't match with the original model.",
)

def test_lowering_select_scatter_dimZero_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, src, dim, index):
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
return y

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
unexpected_ops = {
torch.ops.aten.select_scatter.default,
torch.ops.aten.slice_scatter.default,
}

inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 0, 0]

fx_graph = torch.fx.symbolic_trace(selectScatter())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)

def test_lowering_select_scatter_dimOne_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, src, dim, index):
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
return y

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
unexpected_ops = {
torch.ops.aten.select_scatter.default,
torch.ops.aten.slice_scatter.default,
}

inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0]

fx_graph = torch.fx.symbolic_trace(selectScatter())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)

def test_lowering_select_scatter_multidimension_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, src, dim, index):
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
return y

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
unexpected_ops = {
torch.ops.aten.select_scatter.default,
torch.ops.aten.slice_scatter.default,
}

inputs = [torch.zeros(2, 3, 4).cuda(), torch.ones(2, 4).cuda(), 1, 0]

fx_graph = torch.fx.symbolic_trace(selectScatter())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)


if __name__ == "__main__":
run_tests()
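Usage note: since the test module keeps the standard run_tests() entry point above, the new select_scatter cases should be runnable by executing the file directly, e.g. python tests/py/dynamo/lowering/test_decompositions.py, assuming a working torch_tensorrt installation and an available CUDA device (the test inputs are created with .cuda()).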