@@ -420,6 +420,264 @@ def forward(self, x):
420420 f"MaxPool3d TRT outputs don't match with the original model." ,
421421 )
422422
423+ def test_lowering_slice_scatter_dimZero_module (self ):
424+ class sliceScatter (torch .nn .Module ):
425+ def __init__ (self , * args , ** kwargs ) -> None :
426+ super ().__init__ (* args , ** kwargs )
427+
428+ def forward (self , x , src , dim , start , end , step ):
429+ y = torch .ops .aten .slice_scatter .default (x , src , dim , start , end , step )
430+ return y
431+
432+ # Operations expected to be removed in the traced graph after decompositions
433+ expected_ops = {
434+ torch .ops .aten .scatter .src ,
435+ }
436+ unexpected_ops = {torch .ops .aten .slice_scatter }
437+
438+ inputs = [torch .zeros (8 , 8 ).cuda (), torch .ones (2 , 8 ).cuda (), 0 , 6 , None , 1 ]
439+
440+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
441+
442+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
443+ fx_graph ,
444+ inputs ,
445+ expected_ops = expected_ops ,
446+ unexpected_ops = unexpected_ops ,
447+ min_block_size = 1 ,
448+ )
449+
450+ self .assertEquals (
451+ len (unexpected_ops_seen ),
452+ 0 ,
453+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
454+ )
455+
456+ self .assertEquals (
457+ len (expected_ops_unseen ),
458+ 0 ,
459+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
460+ )
461+
462+ torch ._dynamo .reset ()
463+
464+ # Validate that the results between Torch and Torch-TRT are similar
465+ optimized_model = torch_tensorrt .compile (
466+ fx_graph ,
467+ "torch_compile" ,
468+ inputs ,
469+ min_block_size = 1 ,
470+ truncate_long_and_double = True ,
471+ pass_through_build_failures = True ,
472+ )
473+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
474+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
475+
476+ max_diff = float (
477+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
478+ )
479+ self .assertAlmostEqual (
480+ max_diff ,
481+ 0 ,
482+ DECIMALS_OF_AGREEMENT ,
483+ f"Slice_scatter TRT outputs don't match with the original model." ,
484+ )
485+
486+ def test_lowering_slice_scatter_dimOne_module (self ):
487+ class sliceScatter (torch .nn .Module ):
488+ def __init__ (self , * args , ** kwargs ) -> None :
489+ super ().__init__ (* args , ** kwargs )
490+
491+ def forward (self , x , src , dim , start = None , end = None , step = 1 ):
492+ y = torch .ops .aten .slice_scatter (x , src , dim , start , end , step )
493+ return y
494+
495+ # Operations expected to be removed in the traced graph after decompositions
496+ expected_ops = {
497+ torch .ops .aten .scatter .src ,
498+ }
499+ unexpected_ops = {torch .ops .aten .select_scatter }
500+
501+ inputs = [torch .zeros (8 , 8 ).cuda (), torch .ones (8 , 2 ).cuda (), 1 , 6 , None , 1 ]
502+
503+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
504+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
505+ fx_graph ,
506+ inputs ,
507+ expected_ops = expected_ops ,
508+ unexpected_ops = unexpected_ops ,
509+ min_block_size = 1 ,
510+ )
511+
512+ self .assertEquals (
513+ len (unexpected_ops_seen ),
514+ 0 ,
515+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
516+ )
517+
518+ self .assertEquals (
519+ len (expected_ops_unseen ),
520+ 0 ,
521+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
522+ )
523+
524+ torch ._dynamo .reset ()
525+
526+ # Validate that the results between Torch and Torch-TRT are similar
527+ optimized_model = torch_tensorrt .compile (
528+ fx_graph ,
529+ "torch_compile" ,
530+ inputs ,
531+ min_block_size = 1 ,
532+ truncate_long_and_double = True ,
533+ pass_through_build_failures = True ,
534+ )
535+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
536+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
537+
538+ max_diff = float (
539+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
540+ )
541+ self .assertAlmostEqual (
542+ max_diff ,
543+ 0 ,
544+ DECIMALS_OF_AGREEMENT ,
545+ f"Slice_scatter TRT outputs don't match with the original model." ,
546+ )
547+
548+ def test_lowering_slice_scatter_dimZero_StepTwo_module (self ):
549+ class sliceScatter (torch .nn .Module ):
550+ def __init__ (self , * args , ** kwargs ) -> None :
551+ super ().__init__ (* args , ** kwargs )
552+
553+ def forward (self , x , src , dim , start , end , step ):
554+ y = torch .ops .aten .slice_scatter .default (x , src , dim , start , end , step )
555+ return y
556+
557+ # Operations expected to be removed in the traced graph after decompositions
558+ expected_ops = {
559+ torch .ops .aten .scatter .src ,
560+ }
561+ unexpected_ops = {torch .ops .aten .slice_scatter }
562+
563+ inputs = [torch .zeros (8 , 8 ).cuda (), torch .ones (2 , 8 ).cuda (), 0 , 2 , 6 , 2 ]
564+
565+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
566+
567+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
568+ fx_graph ,
569+ inputs ,
570+ expected_ops = expected_ops ,
571+ unexpected_ops = unexpected_ops ,
572+ min_block_size = 1 ,
573+ )
574+
575+ self .assertEquals (
576+ len (unexpected_ops_seen ),
577+ 0 ,
578+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
579+ )
580+
581+ self .assertEquals (
582+ len (expected_ops_unseen ),
583+ 0 ,
584+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
585+ )
586+
587+ torch ._dynamo .reset ()
588+
589+ # Validate that the results between Torch and Torch-TRT are similar
590+ optimized_model = torch_tensorrt .compile (
591+ fx_graph ,
592+ "torch_compile" ,
593+ inputs ,
594+ min_block_size = 1 ,
595+ truncate_long_and_double = True ,
596+ pass_through_build_failures = True ,
597+ )
598+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
599+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
600+
601+ max_diff = float (
602+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
603+ )
604+ self .assertAlmostEqual (
605+ max_diff ,
606+ 0 ,
607+ DECIMALS_OF_AGREEMENT ,
608+ f"Slice_scatter TRT outputs don't match with the original model." ,
609+ )
610+
611+ def test_lowering_slice_scatter_dimOne_3d_module (self ):
612+ class sliceScatter (torch .nn .Module ):
613+ def __init__ (self , * args , ** kwargs ) -> None :
614+ super ().__init__ (* args , ** kwargs )
615+
616+ def forward (self , x , src , dim , start , end , step ):
617+ y = torch .ops .aten .slice_scatter .default (x , src , dim , start , end , step )
618+ return y
619+
620+ # Operations expected to be removed in the traced graph after decompositions
621+ expected_ops = {
622+ torch .ops .aten .scatter .src ,
623+ }
624+ unexpected_ops = {torch .ops .aten .slice_scatter }
625+
626+ inputs = [
627+ torch .zeros (8 , 8 , 8 ).cuda (),
628+ torch .ones (8 , 2 , 8 ).cuda (),
629+ 1 ,
630+ 6 ,
631+ None ,
632+ 1 ,
633+ ]
634+
635+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
636+
637+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
638+ fx_graph ,
639+ inputs ,
640+ expected_ops = expected_ops ,
641+ unexpected_ops = unexpected_ops ,
642+ min_block_size = 1 ,
643+ )
644+
645+ self .assertEquals (
646+ len (unexpected_ops_seen ),
647+ 0 ,
648+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
649+ )
650+
651+ self .assertEquals (
652+ len (expected_ops_unseen ),
653+ 0 ,
654+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
655+ )
656+
657+ torch ._dynamo .reset ()
658+
659+ # Validate that the results between Torch and Torch-TRT are similar
660+ optimized_model = torch_tensorrt .compile (
661+ fx_graph ,
662+ "torch_compile" ,
663+ inputs ,
664+ min_block_size = 1 ,
665+ truncate_long_and_double = True ,
666+ pass_through_build_failures = True ,
667+ )
668+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
669+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
670+
671+ max_diff = float (
672+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
673+ )
674+ self .assertAlmostEqual (
675+ max_diff ,
676+ 0 ,
677+ DECIMALS_OF_AGREEMENT ,
678+ f"Slice_scatter TRT outputs don't match with the original model." ,
679+ )
680+
423681
424682if __name__ == "__main__" :
425683 run_tests ()
0 commit comments