@@ -497,6 +497,59 @@ TEST(Converters, ATenConvTransposeConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenConvTranspose2dWithWeightsAsTensorsConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(48, 56, 3, 3, strides=[504, 9, 3, 1])):
+      %2 : int = prim::Constant[value=-128]()
+      %3 : float = prim::Constant[value=3.5]()
+      %4 : int = prim::Constant[value=0]()
+      %5 : int = prim::Constant[value=127]()
+      %quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
+      %6 : int = prim::Constant[value=6]()
+      %7 : int = prim::Constant[value=56]()
+      %8 : Device = prim::Constant[value="cuda:0"]()
+      %9 : None = prim::Constant()
+      %10 : int[] = prim::ListConstruct(%7)
+      %11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
+      %12 : int[] = prim::ListConstruct(%7)
+      %13 : int = prim::Constant[value=1]()
+      %14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
+      %quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
+      %15 : None = prim::Constant()
+      %16 : bool = prim::Constant[value=1]()
+      %17 : int = prim::Constant[value=1]() # Adjusted padding
+      %17.1: int = prim::Constant[value=0]() # Adjusted out_padding
+      %18 : int = prim::Constant[value=1]() # Adjusted dilation
+      %19 : int = prim::Constant[value=2]() # Adjusted stride
+      %20 : int = prim::Constant[value=1]()
+      %21 : int[] = prim::ListConstruct(%17)
+      %22 : int[] = prim::ListConstruct(%17, %17)
+      %23 : int[] = prim::ListConstruct(%18, %18)
+      %23.1: int[] = prim::ListConstruct(%17.1, %17.1)
+      %24 : int[] = prim::ListConstruct(%19, %19)
+      %25 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %24, %22, %23, %16, %23.1, %17, %16, %16, %16, %16)
+      return (%25))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 48, 2, 200}, {at::kCUDA});
+  auto w = at::randint(1, 2, {48, 56, 3, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in, jit_w});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in, trt_w}, nvinfer1::DataType::kINT8);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
 TEST(Converters, ATenConvTransposeNoBiasConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor,
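Aside (not part of the patch): the added test constructs the per-channel fake quantization of the deconvolution weights directly in TorchScript IR via `aten::full` and `aten::fake_quantize_per_channel_affine`. A minimal eager-mode sketch of that same weight-quantization pattern, using the same shapes and values as the test graph and assuming a PyTorch version that accepts floating-point zero-points, could look roughly like this (run on CPU here for simplicity; the test itself places everything on `cuda:0`):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Deconvolution weights shaped [in_channels, out_channels, kH, kW], as in the test graph.
  auto w = torch::randint(1, 2, {48, 56, 3, 3}, torch::kFloat);

  // One scale / zero-point per output channel (axis = 1), mirroring the aten::full
  // tensors built in the IR (scale = 3.5, zero-point = 1).
  auto scale = torch::full({56}, 3.5, torch::kFloat);
  auto zero_point = torch::full({56}, 1, torch::kFloat);  // float zero-points need a recent PyTorch

  // Per-channel fake quantization over the output-channel axis, with the same
  // quant_min/quant_max bounds the IR passes (-128, 127).
  auto quant_wts = torch::fake_quantize_per_channel_affine(
      w, scale, zero_point, /*axis=*/1, /*quant_min=*/-128, /*quant_max=*/127);

  std::cout << quant_wts.sizes() << std::endl;
  return 0;
}
```

Because the quantized weights arrive as a graph input rather than a frozen constant, the converter has to handle weight tensors that are not known at conversion time, which is what distinguishes this test from the existing conv-transpose tests in the hunk.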