@@ -341,6 +341,7 @@ def forward(
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
@@ -363,6 +364,8 @@ def forward(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
+            skip_layers (`list` of `int`, *optional*):
+                A list of layer indices to skip during the forward pass.

         Returns:
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -390,7 +393,10 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            # Skip specified layers
+            is_skip = True if skip_layers is not None and index_block in skip_layers else False
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -410,8 +416,7 @@ def custom_forward(*inputs):
                     joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
-
-            else:
+            elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
                     joint_attention_kwargs=joint_attention_kwargs,
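
To make the effect of the new `skip_layers` argument concrete, here is a minimal, self-contained sketch of the same idea: blocks whose index appears in `skip_layers` are bypassed, so the hidden states pass through that position unchanged. The `TinyBlockStack` module, its dimensions, and the dummy inputs are illustrative stand-ins, not the actual diffusers `SD3Transformer2DModel`.

from typing import List, Optional

import torch
import torch.nn as nn


class TinyBlockStack(nn.Module):
    """Illustrative stand-in for a transformer block stack (not the diffusers model)."""

    def __init__(self, dim: int = 8, num_blocks: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_blocks)])

    def forward(self, x: torch.Tensor, skip_layers: Optional[List[int]] = None) -> torch.Tensor:
        for index_block, block in enumerate(self.blocks):
            # Mirror the diff's check: bypass any block whose index is listed in skip_layers.
            if skip_layers is not None and index_block in skip_layers:
                continue
            x = block(x)
        return x


model = TinyBlockStack()
x = torch.randn(2, 8)
out_full = model(x)                      # every block applied
out_skipped = model(x, skip_layers=[1])  # block 1 bypassed entirely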