@@ -2157,3 +2157,98 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
 
         pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
         pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
+
+    @require_peft_version_greater("0.14.0")
+    def test_layerwise_casting_peft_input_autocast_denoiser(self):
+        r"""
+        A test that checks if layerwise casting works correctly with PEFT layers and that the forward pass does not
+        fail. This is different from `test_layerwise_casting_inference_denoiser`, as that test disables the
+        application of layerwise cast hooks on the PEFT layers (relevant logic in
+        `models.modeling_utils.ModelMixin.enable_layerwise_casting`). In this test, we enable layerwise casting on
+        the PEFT layers as well. If run with PEFT version <= 0.14.0, this test will fail with the following error:
+
+        ```
+        RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float
+        ```
+
+        See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
+        """
+
+        from diffusers.hooks.layerwise_casting import (
+            _PEFT_AUTOCAST_DISABLE_HOOK,
+            DEFAULT_SKIP_MODULES_PATTERN,
+            SUPPORTED_PYTORCH_LAYERS,
+            apply_layerwise_casting,
+        )
+
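+        # Weights are stored in float8_e4m3fn and upcast to float32 on the fly by the layerwise casting hooks.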
+        storage_dtype = torch.float8_e4m3fn
+        compute_dtype = torch.float32
+
+        def check_module(denoiser):
+            # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
+            for name, module in denoiser.named_modules():
+                if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+                    continue
+                dtype_to_check = storage_dtype
+                if any(re.search(pattern, name) for pattern in patterns_to_check):
+                    dtype_to_check = compute_dtype
+                if getattr(module, "weight", None) is not None:
+                    self.assertEqual(module.weight.dtype, dtype_to_check)
+                if getattr(module, "bias", None) is not None:
+                    self.assertEqual(module.bias.dtype, dtype_to_check)
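+                # PEFT (LoRA) layers must also carry the input-autocast-disable hook so the forward pass does not fail (see the docstring above).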
+                if isinstance(module, BaseTunerLayer):
+                    self.assertTrue(getattr(module, "_diffusers_hook", None) is not None)
+                    self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
+
+        # 1. Test forward with add_adapter
+        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device, dtype=compute_dtype)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
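+        # PEFT modules are intentionally not added to the skip patterns, so layerwise casting is applied to them as well.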
+        patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+        if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
+            patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)
+
+        apply_layerwise_casting(
+            denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
+        )
+        check_module(denoiser)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        # 2. Test forward with load_lora_weights
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device, dtype=compute_dtype)
+            pipe.set_progress_bar_config(disable=None)
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
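+            # Re-apply layerwise casting on the reloaded denoiser and repeat the dtype/hook checks.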
+            apply_layerwise_casting(
+                denoiser,
+                storage_dtype=storage_dtype,
+                compute_dtype=compute_dtype,
+                skip_modules_pattern=patterns_to_check,
+            )
+            check_module(denoiser)
+
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+            pipe(**inputs, generator=torch.manual_seed(0))[0]