Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
TRTInterpreterResult,
)
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
from torch_tensorrt.dynamo.utils import get_torch_inputs


def interpret_module_to_result(
Expand All @@ -29,7 +29,6 @@ def interpret_module_to_result(
TRTInterpreterResult
"""
torch_inputs = get_torch_inputs(inputs, settings.device)
module.to(to_torch_device(settings.device))
module_outputs = module(*torch_inputs)

if not isinstance(module_outputs, (list, tuple)):
Expand Down
45 changes: 39 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@

# Modify import location of utilities based on Torch version
if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
from torch._inductor.freezing import ConstantFolder
else:
from torch._inductor.constant_folding import (
ConstantFolder,
replace_node_with_constant,
)
from torch._inductor.constant_folding import ConstantFolder

logger = logging.getLogger(__name__)

Expand All @@ -36,7 +33,9 @@ def constant_fold(
cf.run()

for node, constant in cf.node_replacements.items():
replace_node_with_constant(gm, node, constant)
replace_node_with_constant(
gm, node, torch.nn.Parameter(constant.cuda(), requires_grad=False)
)

erased_params = []
for node in gm.graph.nodes:
Expand All @@ -60,6 +59,40 @@ def constant_fold(
return gm


def replace_node_with_constant(
gm: torch.fx.GraphModule, node: torch.fx.Node, constant: torch.Tensor
) -> None:
"""Adapted from:
https://github.com/pytorch/pytorch/blob/bcf35c6ae62bb6560befa3550e37a8283944e5f4/torch/_inductor/constant_folding.py#L17-L43

Modified to register parameters, instead of buffers for frozen constants
"""
g = gm.graph

if not hasattr(gm, "_frozen_param_count"):
gm._frozen_param_count = 0

i = gm._frozen_param_count

while True:
qualname = f"_frozen_param{i}"
if not hasattr(gm, qualname):
break
i += 1

gm._frozen_param_count = i + 1

with g.inserting_before(node):
new_input_node = g.create_node("get_attr", qualname, (), {})
node.replace_all_uses_with(new_input_node)
new_input_node.meta.update(node.meta)
g.erase_node(node)

# Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
gm.register_parameter(qualname, constant)
setattr(gm, qualname, constant)


# TODO: Delete this class when the following code is fixed in nightly:
# https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]
Expand Down