@@ -421,5 +421,71 @@ def forward(self, x):
421421 )
422422
423423
424+ def test_lowering_select_scatter_module (self ):
425+ class selectScatter (torch .nn .Module ):
426+ def __init__ (self , * args , ** kwargs ) -> None :
427+ super ().__init__ (* args , ** kwargs )
428+
429+ def forward (self , x , src , dim , start ):
430+ y = self .slice_scatter (x , src , dim , start )
431+ return y
432+
433+ # Operations expected to be removed in the traced graph after decompositions
434+ expected_ops = {
435+ torch .ops .aten .lt .default ,
436+ torch .ops .aten .lt .default ,
437+ torch .ops .aten .expand .default ,
438+ torch .ops .aten .eq .default ,
439+ torch .ops .aten .where .default ,
440+
441+ }
442+ unexpected_ops = {torch .ops .aten .select_scatter }
443+
444+ inputs = [torch .randn (2 , 2 ), torch .ones (2 )]
445+
446+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
447+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
448+ fx_graph ,
449+ inputs ,
450+ expected_ops = expected_ops ,
451+ unexpected_ops = unexpected_ops ,
452+ min_block_size = 1 ,
453+ )
454+
455+ self .assertEquals (
456+ len (unexpected_ops_seen ),
457+ 0 ,
458+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
459+ )
460+
461+ self .assertEquals (
462+ len (expected_ops_unseen ),
463+ 0 ,
464+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
465+ )
466+
467+ torch ._dynamo .reset ()
468+
469+ # Validate that the results between Torch and Torch-TRT are similar
470+ optimized_model = torch_tensorrt .compile (
471+ fx_graph ,
472+ "torch_compile" ,
473+ inputs ,
474+ min_block_size = 1 ,
475+ pass_through_build_failures = True ,
476+ )
477+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
478+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
479+
480+ max_diff = float (
481+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
482+ )
483+ self .assertAlmostEqual (
484+ max_diff ,
485+ 0 ,
486+ DECIMALS_OF_AGREEMENT ,
487+ f"Select_scatter TRT outputs don't match with the original model." ,
488+ )
489+
424490if __name__ == "__main__" :
425491 run_tests ()
0 commit comments