66from torch .testing ._internal .common_utils import run_tests
77from torch_tensorrt import Input
88
9+ grid_sampler_aten_ops = {
10+ "torch.ops.aten.grid_sampler" : torch .ops .aten .grid_sampler ,
11+ "torch.ops.aten.grid_sampler_2d" : torch .ops .aten .grid_sampler_2d ,
12+ "torch.ops.aten.grid_sampler.default" : torch .ops .aten .grid_sampler .default ,
13+ "torch.ops.aten.grid_sampler_2d.default" : torch .ops .aten .grid_sampler_2d .default ,
14+ }
15+
916grid_sampler_ops = [
1017 (
1118 "input_grid_interpolation_nearest_sample_fill" ,
12- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 0 , True )),
19+ "torch.ops.aten.grid_sampler" ,
20+ (lambda x , grid , op : op (x , grid , 0 , 0 , True )),
1321 [1 , 1 , 5 , 5 ],
1422 [1 , 5 , 2 , 2 ],
1523 ),
1624 (
1725 "input_grid_interpolation_nearest_sample_clamp" ,
18- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 1 , True )),
26+ "torch.ops.aten.grid_sampler" ,
27+ (lambda x , grid , op : op (x , grid , 0 , 1 , True )),
1928 [1 , 1 , 5 , 5 ],
2029 [1 , 5 , 2 , 2 ],
2130 ),
2231 (
2332 "input_grid_interpolation_nearest_sample_reflect" ,
24- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 2 , True )),
33+ "torch.ops.aten.grid_sampler" ,
34+ (lambda x , grid , op : op (x , grid , 0 , 2 , True )),
2535 [1 , 1 , 5 , 5 ],
2636 [1 , 5 , 2 , 2 ],
2737 ),
2838 (
2939 "input_grid_interpolation_linear_sample_fill" ,
30- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 0 , True )),
40+ "torch.ops.aten.grid_sampler" ,
41+ (lambda x , grid , op : op (x , grid , 1 , 0 , True )),
3142 [1 , 1 , 5 , 5 ],
3243 [1 , 5 , 2 , 2 ],
3344 ),
3445 (
3546 "input_grid_interpolation_linear_sample_clamp" ,
36- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 1 , True )),
47+ "torch.ops.aten.grid_sampler" ,
48+ (lambda x , grid , op : op (x , grid , 1 , 1 , True )),
3749 [1 , 1 , 5 , 5 ],
3850 [1 , 5 , 2 , 2 ],
3951 ),
4052 (
4153 "input_grid_interpolation_linear_sample_reflect" ,
42- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 2 , True )),
54+ "torch.ops.aten.grid_sampler" ,
55+ (lambda x , grid , op : op (x , grid , 1 , 2 , True )),
4356 [1 , 1 , 5 , 5 ],
4457 [1 , 5 , 2 , 2 ],
4558 ),
4659 (
4760 "input_grid_interpolation_cubic_sample_fill" ,
48- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 0 , True )),
61+ "torch.ops.aten.grid_sampler" ,
62+ (lambda x , grid , op : op (x , grid , 2 , 0 , True )),
4963 [1 , 1 , 5 , 5 ],
5064 [1 , 5 , 2 , 2 ],
5165 ),
5266 (
5367 "input_grid_interpolation_cubic_sample_clamp" ,
54- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 1 , True )),
68+ "torch.ops.aten.grid_sampler" ,
69+ (lambda x , grid , op : op (x , grid , 2 , 1 , True )),
5570 [1 , 1 , 5 , 5 ],
5671 [1 , 5 , 2 , 2 ],
5772 ),
5873 (
5974 "input_grid_interpolation_cubic_sample_reflect" ,
60- (lambda x , grid : torch .ops .aten .grid_sampler (x , grid , 0 , 2 , True )),
61- [1 , 1 , 5 , 5 ],
62- [1 , 5 , 2 , 2 ],
63- ),
64- (
65- "input_grid_interpolation_nearest_sample_fill_2d" ,
66- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 0 , True )),
67- [1 , 1 , 5 , 5 ],
68- [1 , 5 , 2 , 2 ],
69- ),
70- (
71- "input_grid_interpolation_nearest_sample_clamp_2d" ,
72- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 1 , True )),
73- [1 , 1 , 5 , 5 ],
74- [1 , 5 , 2 , 2 ],
75- ),
76- (
77- "input_grid_interpolation_nearest_sample_reflect_2d" ,
78- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 2 , True )),
79- [1 , 1 , 5 , 5 ],
80- [1 , 5 , 2 , 2 ],
81- ),
82- (
83- "input_grid_interpolation_linear_sample_fill_2d" ,
84- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 0 , True )),
85- [1 , 1 , 5 , 5 ],
86- [1 , 5 , 2 , 2 ],
87- ),
88- (
89- "input_grid_interpolation_linear_sample_clamp_2d" ,
90- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 1 , True )),
91- [1 , 1 , 5 , 5 ],
92- [1 , 5 , 2 , 2 ],
93- ),
94- (
95- "input_grid_interpolation_linear_sample_reflect_2d" ,
96- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 2 , True )),
97- [1 , 1 , 5 , 5 ],
98- [1 , 5 , 2 , 2 ],
99- ),
100- (
101- "input_grid_interpolation_cubic_sample_fill_2d" ,
102- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 0 , True )),
103- [1 , 1 , 5 , 5 ],
104- [1 , 5 , 2 , 2 ],
105- ),
106- (
107- "input_grid_interpolation_cubic_sample_clamp_2d" ,
108- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 1 , True )),
109- [1 , 1 , 5 , 5 ],
110- [1 , 5 , 2 , 2 ],
111- ),
112- (
113- "input_grid_interpolation_cubic_sample_reflect_2d" ,
114- (lambda x , grid : torch .ops .aten .grid_sampler_2d (x , grid , 0 , 2 , True )),
75+ "torch.ops.aten.grid_sampler" ,
76+ (lambda x , grid , op : op (x , grid , 2 , 2 , True )),
11577 [1 , 1 , 5 , 5 ],
11678 [1 , 5 , 2 , 2 ],
11779 ),
@@ -126,19 +88,98 @@ class TestGridConverter(DispatchTestCase):
12688 grid_sampler_op [1 ],
12789 grid_sampler_op [2 ],
12890 grid_sampler_op [3 ],
91+ grid_sampler_op [4 ],
92+ )
93+ for grid_sampler_op in grid_sampler_ops
94+ ]
95+ )
96+ def test_grid (self , _ , op_name , op , input_shape , dim_shape ):
97+ class TestModule (nn .Module ):
98+ def __init__ (self , grid_sampler_op ):
99+ super ().__init__ ()
100+ self .grid_sampler_op = grid_sampler_op
101+
102+ def forward (self , x ):
103+ grid = torch .randint (- 1 , 1 , dim_shape , dtype = torch .float32 )
104+ return self .grid_sampler_op (x , grid , grid_sampler_aten_ops [op_name ])
105+
106+ inputs = [torch .randn (input_shape , dtype = torch .float32 )]
107+ grid_model = TestModule (op )
108+ self .run_test (grid_model , inputs )
109+
110+ @parameterized .expand (
111+ [
112+ (
113+ grid_sampler_op [0 ],
114+ grid_sampler_op [1 ] + "_2d" ,
115+ grid_sampler_op [2 ],
116+ grid_sampler_op [3 ],
117+ grid_sampler_op [4 ],
118+ )
119+ for grid_sampler_op in grid_sampler_ops
120+ ]
121+ )
122+ def test_grid_2d (self , _ , op_name , op , input_shape , dim_shape ):
123+ class TestModule (nn .Module ):
124+ def __init__ (self , grid_sampler_op ):
125+ super ().__init__ ()
126+ self .grid_sampler_op = grid_sampler_op
127+
128+ def forward (self , x ):
129+ grid = torch .randint (- 1 , 1 , dim_shape , dtype = torch .float32 )
130+ return self .grid_sampler_op (x , grid , grid_sampler_aten_ops [op_name ])
131+
132+ inputs = [torch .randn (input_shape , dtype = torch .float32 )]
133+ grid_model = TestModule (op )
134+ self .run_test (grid_model , inputs )
135+
136+ @parameterized .expand (
137+ [
138+ (
139+ grid_sampler_op [0 ],
140+ grid_sampler_op [1 ] + ".default" ,
141+ grid_sampler_op [2 ],
142+ grid_sampler_op [3 ],
143+ grid_sampler_op [4 ],
144+ )
145+ for grid_sampler_op in grid_sampler_ops
146+ ]
147+ )
148+ def test_grid_default (self , _ , op_name , op , input_shape , dim_shape ):
149+ class TestModule (nn .Module ):
150+ def __init__ (self , grid_sampler_op ):
151+ super ().__init__ ()
152+ self .grid_sampler_op = grid_sampler_op
153+
154+ def forward (self , x ):
155+ grid = torch .randint (- 1 , 1 , dim_shape , dtype = torch .float32 )
156+ return self .grid_sampler_op (x , grid , grid_sampler_aten_ops [op_name ])
157+
158+ inputs = [torch .randn (input_shape , dtype = torch .float32 )]
159+ grid_model = TestModule (op )
160+ self .run_test (grid_model , inputs )
161+
162+ @parameterized .expand (
163+ [
164+ (
165+ grid_sampler_op [0 ],
166+ grid_sampler_op [1 ] + "_2d.default" ,
167+ grid_sampler_op [2 ],
168+ grid_sampler_op [3 ],
169+ grid_sampler_op [4 ],
129170 )
130171 for grid_sampler_op in grid_sampler_ops
131172 ]
132173 )
133- def test_grid (self , _ , op , input_shape , dim_shape ):
174+ def test_grid_2d_default (self , _ , op_name , op , input_shape , dim_shape ):
134175 class TestModule (nn .Module ):
135176 def __init__ (self , grid_sampler_op ):
136177 super ().__init__ ()
137178 self .grid_sampler_op = grid_sampler_op
138179
139180 def forward (self , x ):
140181 grid = torch .randint (- 1 , 1 , dim_shape , dtype = torch .float32 )
141- return self .grid_sampler_op (x , grid )
182+ return self .grid_sampler_op (x , grid , grid_sampler_aten_ops [ op_name ] )
142183
143184 inputs = [torch .randn (input_shape , dtype = torch .float32 )]
144185 grid_model = TestModule (op )
0 commit comments