Skip to content

Commit 08f1636

Browse files
zewenli98peri044
authored andcommitted
fix: convert_module_to_trt_engine (#2728)
1 parent bec91fb commit 08f1636

File tree

4 files changed

+41
-37
lines changed

4 files changed

+41
-37
lines changed

docsrc/py_api/dynamo.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ Functions
2222

2323
.. autofunction:: export
2424

25+
.. autofunction:: convert_module_to_trt_engine
26+
2527

2628

2729
Classes

py/torch_tensorrt/_compile.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import collections.abc
34
import logging
45
from enum import Enum
56
from typing import Any, Callable, List, Optional, Sequence, Set
@@ -237,8 +238,6 @@ def compile(
237238
return compiled_fx_module
238239
elif target_ir == _IRType.dynamo:
239240
# Prepare torch and torchtrt inputs
240-
import collections.abc
241-
242241
from torch_tensorrt.dynamo.utils import prepare_inputs
243242

244243
if not isinstance(input_list, collections.abc.Sequence):
@@ -342,10 +341,19 @@ def convert_method_to_trt_engine(
342341
"convert_method_to_trt_engine call is not supported for ir=fx"
343342
)
344343
elif target_ir == _IRType.dynamo:
344+
# Prepare torch and torchtrt inputs
345+
from torch_tensorrt.dynamo.utils import prepare_inputs
346+
347+
if not isinstance(inputs, collections.abc.Sequence):
348+
inputs = [inputs]
349+
350+
# Export the module
351+
torchtrt_inputs = prepare_inputs(inputs)
352+
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
353+
345354
return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return]
346-
module,
355+
exp_program,
347356
inputs=inputs,
348-
method_name=method_name,
349357
enabled_precisions=enabled_precisions_set,
350358
**kwargs,
351359
)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
422422

423423

424424
def convert_module_to_trt_engine(
425-
module: torch.fx.GraphModule,
426-
method_name: str = "forward",
425+
exported_program: ExportedProgram,
427426
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
428427
enabled_precisions: (
429428
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -453,15 +452,15 @@ def convert_module_to_trt_engine(
453452
calibrator: object = None,
454453
allow_shape_tensors: bool = False,
455454
) -> bytes:
456-
"""Convert a GraphModule module method to a serialized TensorRT engine
455+
"""Convert an ExportedProgram to a serialized TensorRT engine
457456
458-
Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
457+
Converts an ExportedProgram to a serialized TensorRT engine given a dictionary of conversion settings
459458
460459
Arguments:
461-
module (torch.fx.GraphModule): Source module
460+
exported_program (torch.export.ExportedProgram): Source module
462461
463462
Keyword Args:
464-
inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
463+
inputs (Optional[Sequence[torch_tensorrt.Input | torch.Tensor]]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
465464
torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
466465
to select device type. ::
467466
@@ -476,30 +475,11 @@ def convert_module_to_trt_engine(
476475
), # Dynamic input shape for input #2
477476
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
478477
]
479-
480-
method_name (str): Name of method to convert
481-
input_signature Union(List, Tuple, torch_tensorrt.Input, torch.Tensor): A formatted collection of input specifications for the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
482-
torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** ::
483-
484-
input_signature=([
485-
torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
486-
torch_tensorrt.Input(
487-
min_shape=(1, 224, 224, 3),
488-
opt_shape=(1, 512, 512, 3),
489-
max_shape=(1, 1024, 1024, 3),
490-
dtype=torch.int32
491-
format=torch.channel_last
492-
), # Dynamic input shape for input #2
493-
], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3
494-
495-
device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
496-
497-
device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
498-
478+
enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use
499479
debug (bool): Whether to print out verbose debugging information
500480
workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
501481
min_block_size (int): Minimum number of operators per TRT-Engine Block
502-
torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
482+
torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage
503483
pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
504484
max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
505485
version_compatible (bool): Provide version forward-compatibility for engine plan files
@@ -566,13 +546,25 @@ def convert_module_to_trt_engine(
566546
"dla_global_dram_size": dla_global_dram_size,
567547
}
568548

549+
# Decompose the exported program
550+
exported_program = exported_program.run_decompositions(
551+
get_decompositions(enable_experimental_decompositions)
552+
)
553+
gm = exported_program.module()
554+
logger.debug("Input graph: " + str(gm.graph))
555+
556+
# Apply lowering on the graph module
557+
torch_inputs = get_torch_inputs(input_list, device)
558+
gm = apply_lowering_passes(gm, torch_inputs)
559+
logger.debug("Lowered Input graph: " + str(gm.graph))
560+
569561
settings = CompilationSettings(**compilation_options)
570562
logger.info("Compilation Settings: %s\n", settings)
571563
try:
572-
interpreter_result = interpret_module_to_result(module, input_list, settings)
564+
interpreter_result = interpret_module_to_result(gm, input_list, settings)
573565
except UnsupportedOperatorException:
574566
logger.error(
575-
f"Conversion of module {module} not currently fully supported or convertible!",
567+
f"Conversion of module {gm} not currently fully supported or convertible!",
576568
exc_info=True,
577569
)
578570
except Exception as e:

tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py renamed to tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
88

99

10-
class TestConvertMethodToTrtEngine(unittest.TestCase):
10+
class TestConvertModuleToTrtEngine(unittest.TestCase):
1111
def test_convert_module(self):
1212
class Test(torch.nn.Module):
1313
def forward(self, a, b):
@@ -18,19 +18,21 @@ def forward(self, a, b):
1818

1919
# Create a model
2020
model = Test()
21-
symbolic_traced_gm = torch.fx.symbolic_trace(model)
21+
exp_program = torch.export.export(model, (input_data_0, input_data_1))
2222

2323
# Convert to TensorRT engine
2424
trt_engine_str = torch_tensorrt.dynamo.convert_module_to_trt_engine(
25-
symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1]
25+
exp_program, inputs=(input_data_0, input_data_1)
2626
)
2727

2828
# Deserialize the TensorRT engine
2929
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
3030
engine = runtime.deserialize_cuda_engine(trt_engine_str)
3131

3232
# Inference on TRT Engine
33-
py_trt_module = PythonTorchTensorRTModule(engine, ["a", "b"], ["output0"])
33+
py_trt_module = PythonTorchTensorRTModule(
34+
engine, ["arg0_1", "arg1_1"], ["output0"]
35+
)
3436
trt_output = py_trt_module(input_data_0, input_data_1).cpu()
3537

3638
# Inference on PyTorch model

0 commit comments

Comments
 (0)