File tree Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Original file line number Diff line number Diff line change @@ -475,14 +475,13 @@ def forward(self, x):
475475        optimized_model_results  =  optimized_model (* inputs ).detach ().cpu ()
476476        torch_model_results  =  fx_graph (* inputs ).detach ().cpu ()
477477
478-         max_diff  =  float (
479-             torch .max (torch .abs (optimized_model_results  -  torch_model_results ))
480-         )
481-         self .assertAlmostEqual (
482-             max_diff ,
483-             0 ,
484-             DECIMALS_OF_AGREEMENT ,
485-             f"Select_scatter TRT outputs don't match with the original model." ,
478+         optimized_model_results_shape  =  optimized_model_results .size ()
479+         torch_model_results_shape  =  torch_model_results .size ()
480+ 
481+         self .assertEquals (
482+             optimized_model_results_shape ,
483+             torch_model_results_shape ,
484+             f"The optimized model results shape and torch model results shape should be equal in empty_like" ,
486485        )
487486
488487
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments