diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index bd0848a2d63f..ee58b8ac244d 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -241,97 +241,85 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: freqs_cis = [] - # Use float32 for MPS compatibility - dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): - emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype) + emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype) freqs_cis.append(emb) return freqs_cis def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: + device = ids.device + if ids.device.type == "mps": + ids = ids.to("cpu") + result = [] for i in range(len(self.axes_dim)): freqs = self.freqs_cis[i].to(ids.device) index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) - return torch.cat(result, dim=-1) + return torch.cat(result, dim=-1).to(device) def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): - batch_size = len(hidden_states) - p_h = p_w = self.patch_size - device = hidden_states[0].device + batch_size, channels, height, width = hidden_states.shape + p = self.patch_size + post_patch_height, post_patch_width = height // p, width // p + image_seq_len = post_patch_height * post_patch_width + device = hidden_states.device + encoder_seq_len = attention_mask.shape[1] l_effective_cap_len = attention_mask.sum(dim=1).tolist() - # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape - img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] - l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes] - - max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))) - max_img_len = max(l_effective_img_len) + seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len] + max_seq_len = max(seq_lengths) + # Create position IDs position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) - for i in range(batch_size): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // p_h, W // p_w - assert H_tokens * W_tokens == img_len + for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + # add caption position ids + position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device) + position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) - position_ids[i, cap_len : cap_len + img_len, 0] = cap_len + # add image position ids row_ids = ( - torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() + torch.arange(post_patch_height, dtype=torch.int32, device=device) + .view(-1, 1) + .repeat(1, post_patch_width) + .flatten() ) col_ids = ( - torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + torch.arange(post_patch_width, dtype=torch.int32, device=device) + .view(1, -1) + .repeat(post_patch_height, 1) + .flatten() ) - position_ids[i, cap_len : cap_len + img_len, 1] = row_ids - position_ids[i, cap_len : cap_len + img_len, 2] = col_ids + position_ids[i, cap_seq_len:seq_len, 1] = row_ids + position_ids[i, cap_seq_len:seq_len, 2] = col_ids + # Get combined rotary embeddings freqs_cis = self._get_freqs_cis(position_ids) - cap_freqs_cis_shape = list(freqs_cis.shape) - cap_freqs_cis_shape[1] = attention_mask.shape[1] - cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - for i in range(batch_size): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len] - - flat_hidden_states = [] - for i in range(batch_size): - img = hidden_states[i] - C, H, W = img.size() - img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) - flat_hidden_states.append(img) - hidden_states = flat_hidden_states - padded_img_embed = torch.zeros( - batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype + # create separate rotary embeddings for captions and images + cap_freqs_cis = torch.zeros( + batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype ) - padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) - for i in range(batch_size): - padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i] - padded_img_mask[i, : l_effective_img_len[i]] = True - - return ( - padded_img_embed, - padded_img_mask, - img_sizes, - l_effective_cap_len, - l_effective_img_len, - freqs_cis, - cap_freqs_cis, - img_freqs_cis, - max_seq_len, + img_freqs_cis = torch.zeros( + batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) + + for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len] + img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len] + + # image patch embeddings + hidden_states = ( + hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p) + .permute(0, 2, 4, 3, 5, 1) + .flatten(3) + .flatten(1, 2) ) + return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths + class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): r""" @@ -471,75 +459,63 @@ def forward( hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - use_mask_in_transformer: bool = True, + encoder_attention_mask: torch.Tensor, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: - batch_size = hidden_states.size(0) - # 1. Condition, positional & patch embedding + batch_size, _, height, width = hidden_states.shape + temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) ( hidden_states, - hidden_mask, - hidden_sizes, - encoder_hidden_len, - hidden_len, - joint_rotary_emb, - encoder_rotary_emb, - hidden_rotary_emb, - max_seq_len, - ) = self.rope_embedder(hidden_states, attention_mask) + context_rotary_emb, + noise_rotary_emb, + rotary_emb, + encoder_seq_lengths, + seq_lengths, + ) = self.rope_embedder(hidden_states, encoder_attention_mask) hidden_states = self.x_embedder(hidden_states) # 2. Context & noise refinement for layer in self.context_refiner: - # NOTE: mask not used for performance - encoder_hidden_states = layer( - encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb - ) + encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb) for layer in self.noise_refiner: - # NOTE: mask not used for performance - hidden_states = layer( - hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb - ) + hidden_states = layer(hidden_states, None, noise_rotary_emb, temb) + + # 3. Joint Transformer blocks + max_seq_len = max(seq_lengths) + use_mask = len(set(seq_lengths)) > 1 + + attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len] + joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i] + + hidden_states = joint_hidden_states - # 3. Attention mask preparation - mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) - padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) - for i in range(batch_size): - cap_len = encoder_hidden_len[i] - img_len = hidden_len[i] - mask[i, : cap_len + img_len] = True - padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len] - padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len] - hidden_states = padded_hidden_states - - # 4. Transformer blocks for layer in self.layers: - # NOTE: mask not used for performance if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( - layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb + layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb ) else: - hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb) + hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb) - # 5. Output norm & projection & unpatchify + # 4. Output norm & projection hidden_states = self.norm_out(hidden_states, temb) - height_tokens = width_tokens = self.config.patch_size + # 5. Unpatchify + p = self.config.patch_size output = [] - for i in range(len(hidden_sizes)): - height, width = hidden_sizes[i] - begin = encoder_hidden_len[i] - end = begin + (height // height_tokens) * (width // width_tokens) + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): output.append( - hidden_states[i][begin:end] - .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels) + hidden_states[i][encoder_seq_len:seq_len] + .view(height // p, width // p, p, p, self.out_channels) .permute(4, 0, 2, 1, 3) .flatten(3, 4) .flatten(1, 2) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 801ed25093a3..3478a6140e8e 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -24,8 +24,6 @@ from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - is_bs4_available, - is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring, @@ -44,12 +42,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -if is_bs4_available(): - pass - -if is_ftfy_available(): - pass - EXAMPLE_DOC_STRING = """ Examples: ```py @@ -525,7 +517,6 @@ def __call__( system_prompt: Optional[str] = None, cfg_trunc_ratio: float = 1.0, cfg_normalization: bool = True, - use_mask_in_transformer: bool = True, max_sequence_length: int = 256, ) -> Union[ImagePipelineOutput, Tuple]: """ @@ -597,8 +588,6 @@ def __call__( The ratio of the timestep interval to apply normalization-based guidance scale. cfg_normalization (`bool`, *optional*, defaults to `True`): Whether to apply normalization-based guidance scale. - use_mask_in_transformer (`bool`, *optional*, defaults to `True`): - Whether to use attention mask in `Lumina2Transformer2DModel`. Set `False` for performance gain. max_sequence_length (`int`, defaults to `256`): Maximum sequence length to use with the `prompt`. @@ -704,8 +693,7 @@ def __call__( hidden_states=latents, timestep=current_timestep, encoder_hidden_states=prompt_embeds, - attention_mask=prompt_attention_mask, - use_mask_in_transformer=use_mask_in_transformer, + encoder_attention_mask=prompt_attention_mask, return_dict=False, )[0] @@ -715,8 +703,7 @@ def __call__( hidden_states=latents, timestep=current_timestep, encoder_hidden_states=negative_prompt_embeds, - attention_mask=negative_prompt_attention_mask, - use_mask_in_transformer=use_mask_in_transformer, + encoder_attention_mask=negative_prompt_attention_mask, return_dict=False, )[0] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) diff --git a/tests/models/transformers/test_models_transformer_lumina2.py b/tests/models/transformers/test_models_transformer_lumina2.py index e89f160433bd..4db3ae68aa94 100644 --- a/tests/models/transformers/test_models_transformer_lumina2.py +++ b/tests/models/transformers/test_models_transformer_lumina2.py @@ -51,7 +51,7 @@ def dummy_input(self): "hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep, - "attention_mask": attention_mask, + "encoder_attention_mask": attention_mask, } @property