@@ -249,7 +249,7 @@ def run_test(
249249 input = inputs [num_input ]
250250 if input .dtype in (torch .int64 , torch .float64 ):
251251 dtype_32bit = (
252- torch .int32 if (input .dtype == torch .int64 ) else torch .int64
252+ torch .int32 if (input .dtype == torch .int64 ) else torch .float32
253253 )
254254 # should we modify graph here to insert clone nodes?
255255 # ideally not required
@@ -259,7 +259,7 @@ def run_test(
259259 input .to (dtype_32bit ),
260260 ]
261261 + list (trt_inputs [num_input + 1 :])
262- )
262+ )
263263
264264 trt_input_specs = [Input .from_tensor (i ) for i in trt_inputs ]
265265 input_specs = [Input .from_tensor (i ) for i in inputs ]
@@ -270,7 +270,7 @@ def run_test(
270270 mod ,
271271 input_specs ,
272272 compilation_settings .device ,
273- truncate_long_and_double = compilation_settings .truncate_long_and_double ,
273+ truncate_double = compilation_settings .truncate_double ,
274274 )
275275
276276 _LOGGER .debug (f"Compilation settings: { compilation_settings } " )
@@ -316,7 +316,7 @@ def run_test_compare_tensor_attributes_only(
316316 # We replicate this behavior here
317317 compilation_settings = CompilationSettings (
318318 enabled_precisions = {dtype ._from (precision )},
319- truncate_long_and_double = True ,
319+ truncate_and_double = True ,
320320 debug = True ,
321321 )
322322
0 commit comments