diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index 844cb6789a..d7cfe15694 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -12,7 +12,7 @@
     TRTInterpreterResult,
 )
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
+from torch_tensorrt.dynamo.utils import get_torch_inputs
 
 
 def interpret_module_to_result(
@@ -29,7 +29,6 @@ def interpret_module_to_result(
         TRTInterpreterResult
     """
     torch_inputs = get_torch_inputs(inputs, settings.device)
-    module.to(to_torch_device(settings.device))
     module_outputs = module(*torch_inputs)
 
     if not isinstance(module_outputs, (list, tuple)):
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index 2443e33d50..49e6dd9d3e 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -11,12 +11,9 @@
 
 # Modify import location of utilities based on Torch version
 if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
-    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
+    from torch._inductor.freezing import ConstantFolder
 else:
-    from torch._inductor.constant_folding import (
-        ConstantFolder,
-        replace_node_with_constant,
-    )
+    from torch._inductor.constant_folding import ConstantFolder
 
 logger = logging.getLogger(__name__)
 
@@ -36,7 +33,9 @@ def constant_fold(
     cf.run()
 
     for node, constant in cf.node_replacements.items():
-        replace_node_with_constant(gm, node, constant)
+        replace_node_with_constant(
+            gm, node, torch.nn.Parameter(constant.cuda(), requires_grad=False)
+        )
 
     erased_params = []
     for node in gm.graph.nodes:
@@ -60,6 +59,40 @@ def constant_fold(
     return gm
 
 
+def replace_node_with_constant(
+    gm: torch.fx.GraphModule, node: torch.fx.Node, constant: torch.Tensor
+) -> None:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/bcf35c6ae62bb6560befa3550e37a8283944e5f4/torch/_inductor/constant_folding.py#L17-L43
+
+    Modified to register parameters, instead of buffers for frozen constants
+    """
+    g = gm.graph
+
+    if not hasattr(gm, "_frozen_param_count"):
+        gm._frozen_param_count = 0
+
+    i = gm._frozen_param_count
+
+    while True:
+        qualname = f"_frozen_param{i}"
+        if not hasattr(gm, qualname):
+            break
+        i += 1
+
+    gm._frozen_param_count = i + 1
+
+    with g.inserting_before(node):
+        new_input_node = g.create_node("get_attr", qualname, (), {})
+        node.replace_all_uses_with(new_input_node)
+        new_input_node.meta.update(node.meta)
+        g.erase_node(node)
+
+    # Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
+    gm.register_parameter(qualname, constant)
+    setattr(gm, qualname, constant)
+
+
 # TODO: Delete this class when the following code is fixed in nightly:
 # https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
 class _TorchTensorRTConstantFolder(ConstantFolder):  # type: ignore[misc]
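
Below is a minimal, self-contained sketch (not part of the diff above) of the technique the new replace_node_with_constant applies: evaluate a constant-producing FX node ahead of time, register the result as an nn.Parameter rather than a buffer, and rewire the graph to read it through get_attr. ToyModule and fold_constant_add are illustrative names, not from the PR, and the constant is kept on the CPU here, whereas the diff moves folded constants to CUDA via constant.cuda().

import operator

import torch
import torch.fx


class ToyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `torch.ones(3) + 2` does not depend on the input, so it can be folded.
        return x * (torch.ones(3) + 2)


def fold_constant_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is operator.add:
            # Evaluate the constant eagerly and register it as a frozen parameter.
            value = torch.ones(3) + 2
            qualname = "_frozen_param0"
            gm.register_parameter(
                qualname, torch.nn.Parameter(value, requires_grad=False)
            )
            # Replace the folded node with a get_attr that reads the parameter.
            with gm.graph.inserting_before(node):
                get_attr_node = gm.graph.create_node("get_attr", qualname, (), {})
            node.replace_all_uses_with(get_attr_node)
            gm.graph.erase_node(node)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm


gm = fold_constant_add(torch.fx.symbolic_trace(ToyModule()))
print(gm.code)                      # forward now reads self._frozen_param0
print(dict(gm.named_parameters()))  # the folded constant is an nn.Parameter, not a buffer

The diff registers the folded tensors as Parameters rather than buffers and places them on the GPU up front, which is presumably why the explicit module.to(...) call in interpret_module_to_result becomes unnecessary.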