1111
1212# Modify import location of utilities based on Torch version
1313if version .parse (sanitized_torch_version ()) < version .parse ("2.1.1" ):
14- from torch ._inductor .freezing import ConstantFolder , replace_node_with_constant
14+ from torch ._inductor .freezing import ConstantFolder
1515else :
16- from torch ._inductor .constant_folding import (
17- ConstantFolder ,
18- replace_node_with_constant ,
19- )
16+ from torch ._inductor .constant_folding import ConstantFolder
2017
2118logger = logging .getLogger (__name__ )
2219
@@ -36,7 +33,7 @@ def constant_fold(
3633 cf .run ()
3734
3835 for node , constant in cf .node_replacements .items ():
39- replace_node_with_constant (gm , node , constant )
36+ replace_node_with_constant (gm , node , torch . nn . Parameter ( constant . cuda ()) )
4037
4138 erased_params = []
4239 for node in gm .graph .nodes :
@@ -60,6 +57,35 @@ def constant_fold(
6057 return gm
6158
6259
60+ def replace_node_with_constant (
61+ gm : torch .fx .GraphModule , node : torch .fx .Node , constant : torch .Tensor
62+ ) -> None :
63+ g = gm .graph
64+
65+ if not hasattr (gm , "_frozen_param_count" ):
66+ gm ._frozen_param_count = 0
67+
68+ i = gm ._frozen_param_count
69+
70+ while True :
71+ qualname = f"_frozen_param{ i } "
72+ if not hasattr (gm , qualname ):
73+ break
74+ i += 1
75+
76+ gm ._frozen_param_count = i + 1
77+
78+ with g .inserting_before (node ):
79+ new_input_node = g .create_node ("get_attr" , qualname , (), {})
80+ node .replace_all_uses_with (new_input_node )
81+ new_input_node .meta .update (node .meta )
82+ g .erase_node (node )
83+
84+ # Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
85+ gm .register_parameter (qualname , constant )
86+ setattr (gm , qualname , constant )
87+
88+
6389# TODO: Delete this class when the following code is fixed in nightly:
6490# https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
6591class _TorchTensorRTConstantFolder (ConstantFolder ): # type: ignore[misc]
0 commit comments