1111
1212# Modify import location of utilities based on Torch version
1313if version .parse (sanitized_torch_version ()) < version .parse ("2.1.1" ):
14- from torch ._inductor .freezing import ConstantFolder , replace_node_with_constant
14+ from torch ._inductor .freezing import ConstantFolder
1515else :
16- from torch ._inductor .constant_folding import (
17- ConstantFolder ,
18- replace_node_with_constant ,
19- )
16+ from torch ._inductor .constant_folding import ConstantFolder
2017
2118logger = logging .getLogger (__name__ )
2219
@@ -36,7 +33,9 @@ def constant_fold(
3633 cf .run ()
3734
3835 for node , constant in cf .node_replacements .items ():
39- replace_node_with_constant (gm , node , constant )
36+ replace_node_with_constant (
37+ gm , node , torch .nn .Parameter (constant .cuda (), requires_grad = False )
38+ )
4039
4140 erased_params = []
4241 for node in gm .graph .nodes :
@@ -55,6 +54,40 @@ def constant_fold(
5554 return gm
5655
5756
57+ def replace_node_with_constant (
58+ gm : torch .fx .GraphModule , node : torch .fx .Node , constant : torch .Tensor
59+ ) -> None :
60+ """Adapted from:
61+ https://github.com/pytorch/pytorch/blob/bcf35c6ae62bb6560befa3550e37a8283944e5f4/torch/_inductor/constant_folding.py#L17-L43
62+
63+ Modified to register parameters, instead of buffers for frozen constants
64+ """
65+ g = gm .graph
66+
67+ if not hasattr (gm , "_frozen_param_count" ):
68+ gm ._frozen_param_count = 0
69+
70+ i = gm ._frozen_param_count
71+
72+ while True :
73+ qualname = f"_frozen_param{ i } "
74+ if not hasattr (gm , qualname ):
75+ break
76+ i += 1
77+
78+ gm ._frozen_param_count = i + 1
79+
80+ with g .inserting_before (node ):
81+ new_input_node = g .create_node ("get_attr" , qualname , (), {})
82+ node .replace_all_uses_with (new_input_node )
83+ new_input_node .meta .update (node .meta )
84+ g .erase_node (node )
85+
86+ # Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
87+ gm .register_parameter (qualname , constant )
88+ setattr (gm , qualname , constant )
89+
90+
5891# TODO: Delete this class when the following code is fixed in nightly:
5992# https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
6093class _TorchTensorRTConstantFolder (ConstantFolder ): # type: ignore[misc]
0 commit comments