@@ -546,6 +546,68 @@ def forward(self, x, src, dim, index):
546546 f"Select_scatter TRT outputs don't match with the original model." ,
547547 )
548548
549+ def test_lowering_select_scatter_multidimension_module (self ):
550+ class selectScatter (torch .nn .Module ):
551+ def __init__ (self , * args , ** kwargs ) -> None :
552+ super ().__init__ (* args , ** kwargs )
553+
554+ def forward (self , x , src , dim , index ):
555+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
556+ return y
557+
558+ # Operations expected to be removed in the traced graph after decompositions
559+ expected_ops = {torch .ops .aten .scatter .src , torch .ops .aten .unsqueeze .default }
560+ unexpected_ops = {
561+ torch .ops .aten .select_scatter .default ,
562+ torch .ops .aten .slice_scatter .default ,
563+ }
564+
565+ inputs = [torch .zeros (2 , 3 , 4 ).cuda (), torch .ones (2 , 4 ).cuda (), 1 , 0 ]
566+
567+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
568+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
569+ fx_graph ,
570+ inputs ,
571+ expected_ops = expected_ops ,
572+ unexpected_ops = unexpected_ops ,
573+ min_block_size = 1 ,
574+ )
575+
576+ self .assertEquals (
577+ len (unexpected_ops_seen ),
578+ 0 ,
579+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
580+ )
581+
582+ self .assertEquals (
583+ len (expected_ops_unseen ),
584+ 0 ,
585+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
586+ )
587+
588+ torch ._dynamo .reset ()
589+
590+ # Validate that the results between Torch and Torch-TRT are similar
591+ optimized_model = torch_tensorrt .compile (
592+ fx_graph ,
593+ "torch_compile" ,
594+ inputs ,
595+ min_block_size = 1 ,
596+ truncate_long_and_double = True ,
597+ pass_through_build_failures = True ,
598+ )
599+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
600+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
601+
602+ max_diff = float (
603+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
604+ )
605+ self .assertAlmostEqual (
606+ max_diff ,
607+ 0 ,
608+ DECIMALS_OF_AGREEMENT ,
609+ f"Select_scatter TRT outputs don't match with the original model." ,
610+ )
549611
550612if __name__ == "__main__" :
551613 run_tests ()
0 commit comments