File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change 1414import copy
1515from typing import TYPE_CHECKING , Dict , List , Union
1616
17+ from torch import nn
18+
1719from ..utils import logging
1820
1921
@@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
5254 weight_for_adapter ,
5355 blocks_with_transformer ,
5456 transformer_per_block ,
55- unet . state_dict () ,
57+ model = unet ,
5658 default_scale = default_scale ,
5759 )
5860 for weight_for_adapter in weight_scales
@@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
6567 scales : Union [float , Dict ],
6668 blocks_with_transformer : Dict [str , int ],
6769 transformer_per_block : Dict [str , int ],
68- state_dict : None ,
70+ model : nn . Module ,
6971 default_scale : float = 1.0 ,
7072):
7173 """
@@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
154156
155157 del scales [updown ]
156158
159+ state_dict = model .state_dict ()
157160 for layer in scales .keys ():
158161 if not any (_translate_into_actual_layer_name (layer ) in module for module in state_dict .keys ()):
159162 raise ValueError (
You can’t perform that action at this time.
0 commit comments