@@ -420,30 +420,91 @@ def forward(self, x):
420420 f"MaxPool3d TRT outputs don't match with the original model." ,
421421 )
422422
423-
424- def test_lowering_select_scatter_module (self ):
425- class selectScatter (torch .nn .Module ):
423+ def test_lowering_slice_scatter_dimZero_module (self ):
424+ class sliceScatter (torch .nn .Module ):
426425 def __init__ (self , * args , ** kwargs ) -> None :
427426 super ().__init__ (* args , ** kwargs )
428427
429- def forward (self , x , src , dim , start ):
430- y = self .slice_scatter (x , src , dim , start )
428+ def forward (self , x , src , dim , start = None , end = None , step = 1 ):
429+ y = self .slice_scatter (x , src , dim , start , end , step )
431430 return y
432431
433432 # Operations expected to be removed in the traced graph after decompositions
434433 expected_ops = {
435- torch .ops .aten .lt .default ,
436- torch .ops .aten .lt .default ,
437- torch .ops .aten .expand .default ,
438- torch .ops .aten .eq .default ,
439- torch .ops .aten .where .default ,
434+ torch .ops .aten .slice .Tensor ,
435+ torch .ops .aten .squeeze .dim ,
436+ torch .ops .aten .cat .default ,
437+ torch .ops .aten .index .Tensor ,
438+ }
439+ unexpected_ops = {torch .ops .aten .select_scatter }
440+
441+ inputs = [torch .zeros (8 , 8 ).cuda (), torch .ones (2 , 8 ).cuda (), 0 , 6 ]
442+
443+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
444+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
445+ fx_graph ,
446+ inputs ,
447+ expected_ops = expected_ops ,
448+ unexpected_ops = unexpected_ops ,
449+ min_block_size = 1 ,
450+ )
451+
452+ self .assertEquals (
453+ len (unexpected_ops_seen ),
454+ 0 ,
455+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
456+ )
457+
458+ self .assertEquals (
459+ len (expected_ops_unseen ),
460+ 0 ,
461+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
462+ )
463+
464+ torch ._dynamo .reset ()
465+
466+ # Validate that the results between Torch and Torch-TRT are similar
467+ optimized_model = torch_tensorrt .compile (
468+ fx_graph ,
469+ "torch_compile" ,
470+ inputs ,
471+ min_block_size = 1 ,
472+ pass_through_build_failures = True ,
473+ )
474+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
475+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
440476
477+ max_diff = float (
478+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
479+ )
480+ self .assertAlmostEqual (
481+ max_diff ,
482+ 0 ,
483+ DECIMALS_OF_AGREEMENT ,
484+ f"Slice_scatter TRT outputs don't match with the original model." ,
485+ )
486+
487+ def test_lowering_slice_scatter_dimOne_module (self ):
488+ class sliceScatter (torch .nn .Module ):
489+ def __init__ (self , * args , ** kwargs ) -> None :
490+ super ().__init__ (* args , ** kwargs )
491+
492+ def forward (self , x , src , dim , start = None , end = None , step = 1 ):
493+ y = self .slice_scatter (x , src , dim , start , end , step )
494+ return y
495+
496+ # Operations expected to be removed in the traced graph after decompositions
497+ expected_ops = {
498+ torch .ops .aten .slice .Tensor ,
499+ torch .ops .aten .squeeze .dim ,
500+ torch .ops .aten .cat .default ,
501+ torch .ops .aten .index .Tensor ,
441502 }
442503 unexpected_ops = {torch .ops .aten .select_scatter }
443504
444- inputs = [torch .randn ( 2 , 2 ) , torch .ones (2 ) ]
505+ inputs = [torch .zeros ( 8 , 8 ). cuda () , torch .ones (2 , 8 ). cuda (), 0 , 6 ]
445506
446- fx_graph = torch .fx .symbolic_trace (selectScatter ())
507+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
447508 unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
448509 fx_graph ,
449510 inputs ,
@@ -484,8 +545,9 @@ def forward(self, x, src, dim, start):
484545 max_diff ,
485546 0 ,
486547 DECIMALS_OF_AGREEMENT ,
487- f"Select_scatter TRT outputs don't match with the original model." ,
548+ f"Slice_scatter TRT outputs don't match with the original model." ,
488549 )
489550
551+
490552if __name__ == "__main__" :
491553 run_tests ()
0 commit comments