From acbc6a5832920c68854b5b28750a7f7d88e3c709 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Feb 2025 03:41:25 +0100 Subject: [PATCH 1/9] up --- .../models/transformers/transformer_lumina2.py | 10 ++++++---- src/diffusers/pipelines/lumina2/pipeline_lumina2.py | 8 -------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index bd0848a2d63f..720cdc4e5a46 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -241,20 +241,22 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: freqs_cis = [] - # Use float32 for MPS compatibility - dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): - emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype) + emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64) freqs_cis.append(emb) return freqs_cis def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: + device = ids.device + if ids.device.type == "mps": + ids = ids.to("cpu") + result = [] for i in range(len(self.axes_dim)): freqs = self.freqs_cis[i].to(ids.device) index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) - return torch.cat(result, dim=-1) + return torch.cat(result, dim=-1).to(device) def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): batch_size = len(hidden_states) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 801ed25093a3..7a1ec906d7ce 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -24,8 +24,6 @@ from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - is_bs4_available, - is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring, @@ -44,12 +42,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -if is_bs4_available(): - pass - -if is_ftfy_available(): - pass - EXAMPLE_DOC_STRING = """ Examples: ```py From 62c4a9fa60158e7fed6fdf51eb9326799ff5a0f8 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Feb 2025 08:31:55 +0100 Subject: [PATCH 2/9] up --- .../transformers/transformer_lumina2.py | 119 ++++++------------ 1 file changed, 40 insertions(+), 79 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 720cdc4e5a46..ea5eb5e102b0 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -259,80 +259,45 @@ def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: return torch.cat(result, dim=-1).to(device) def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): - batch_size = len(hidden_states) - p_h = p_w = self.patch_size - device = hidden_states[0].device + # Get batch info and dimensions + batch_size, _, height, width = hidden_states.shape + patch_height, patch_width = height 
// self.patch_size, width // self.patch_size + num_patches = patch_height * patch_width + device = hidden_states.device + # Get caption lengths and calculate max sequence length l_effective_cap_len = attention_mask.sum(dim=1).tolist() - # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape - img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] - l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes] - - max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))) - max_img_len = max(l_effective_img_len) + max_seq_len = max(l_effective_cap_len) + num_patches + # Create position IDs position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) - + for i in range(batch_size): cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // p_h, W // p_w - assert H_tokens * W_tokens == img_len - + + # Set caption positions position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) - position_ids[i, cap_len : cap_len + img_len, 0] = cap_len - row_ids = ( - torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() - ) - col_ids = ( - torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() - ) - position_ids[i, cap_len : cap_len + img_len, 1] = row_ids - position_ids[i, cap_len : cap_len + img_len, 2] = col_ids - + position_ids[i, cap_len : cap_len + num_patches, 0] = cap_len + + # Set image patch positions + row_ids = torch.arange(patch_height, dtype=torch.int32, device=device).view(-1, 1).repeat(1, patch_width).flatten() + col_ids = torch.arange(patch_width, dtype=torch.int32, device=device).view(1, -1).repeat(patch_height, 1).flatten() + position_ids[i, cap_len : cap_len + num_patches, 1] = row_ids + position_ids[i, cap_len : cap_len + num_patches, 2] = col_ids + + # Get frequencies freqs_cis = self._get_freqs_cis(position_ids) - cap_freqs_cis_shape = list(freqs_cis.shape) - cap_freqs_cis_shape[1] = attention_mask.shape[1] - cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) + # Split frequencies for captions and images + cap_freqs_cis = torch.zeros(batch_size, attention_mask.shape[1], freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) + img_freqs_cis = torch.zeros(batch_size, num_patches, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) for i in range(batch_size): cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len] + img_freqs_cis[i, :num_patches] = freqs_cis[i, cap_len : cap_len + num_patches] - flat_hidden_states = [] - for i in range(batch_size): - img = hidden_states[i] - C, H, W = img.size() - img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) - flat_hidden_states.append(img) - hidden_states = flat_hidden_states - padded_img_embed = torch.zeros( - batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype - ) - padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) - for i in 
range(batch_size): - padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i] - padded_img_mask[i, : l_effective_img_len[i]] = True - - return ( - padded_img_embed, - padded_img_mask, - img_sizes, - l_effective_cap_len, - l_effective_img_len, - freqs_cis, - cap_freqs_cis, - img_freqs_cis, - max_seq_len, - ) + return cap_freqs_cis, img_freqs_cis class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): @@ -477,22 +442,19 @@ def forward( use_mask_in_transformer: bool = True, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: - batch_size = hidden_states.size(0) + + batch_size, _, height, width = hidden_states.shape + image_seq_len = (height // self.config.patch_size) * (width // self.config.patch_size) + + text_seq_len = encoder_hidden_states.shape[1] + + l_effective_text_seq_len = attention_mask.sum(dim=1).tolist() + max_seq_len = max(l_effective_text_seq_len) + image_seq_len # 1. Condition, positional & patch embedding temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) - ( - hidden_states, - hidden_mask, - hidden_sizes, - encoder_hidden_len, - hidden_len, - joint_rotary_emb, - encoder_rotary_emb, - hidden_rotary_emb, - max_seq_len, - ) = self.rope_embedder(hidden_states, attention_mask) + encoder_rotary_emb, hidden_rotary_emb = self.rope_embedder(hidden_states, attention_mask) hidden_states = self.x_embedder(hidden_states) @@ -506,15 +468,15 @@ def forward( for layer in self.noise_refiner: # NOTE: mask not used for performance hidden_states = layer( - hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb + hidden_states, None, hidden_rotary_emb, temb ) # 3. Attention mask preparation mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) for i in range(batch_size): - cap_len = encoder_hidden_len[i] - img_len = hidden_len[i] + cap_len = l_effective_text_seq_len[i] + img_len = image_seq_len mask[i, : cap_len + img_len] = True padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len] padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len] @@ -535,10 +497,9 @@ def forward( height_tokens = width_tokens = self.config.patch_size output = [] - for i in range(len(hidden_sizes)): - height, width = hidden_sizes[i] - begin = encoder_hidden_len[i] - end = begin + (height // height_tokens) * (width // width_tokens) + for i in range(batch_size): + begin = l_effective_text_seq_len[i] + end = begin + image_seq_len output.append( hidden_states[i][begin:end] .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels) From c5412b9a2991b20494b4c826508bdbe67547e370 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Feb 2025 18:55:47 +0100 Subject: [PATCH 3/9] up --- .../transformers/transformer_lumina2.py | 115 +++++++++++------- .../pipelines/lumina2/pipeline_lumina2.py | 4 +- 2 files changed, 70 insertions(+), 49 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index ea5eb5e102b0..71ed05060c8e 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -260,9 +260,10 @@ def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): # 
Get batch info and dimensions - batch_size, _, height, width = hidden_states.shape - patch_height, patch_width = height // self.patch_size, width // self.patch_size - num_patches = patch_height * patch_width + batch_size, channels, height, width = hidden_states.shape + p = self.patch_size + post_patch_height, post_patch_width = height // p, width // p + num_patches = post_patch_height * post_patch_width device = hidden_states.device # Get caption lengths and calculate max sequence length @@ -271,17 +272,27 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): # Create position IDs position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) - + for i in range(batch_size): cap_len = l_effective_cap_len[i] - + # Set caption positions position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) position_ids[i, cap_len : cap_len + num_patches, 0] = cap_len - + # Set image patch positions - row_ids = torch.arange(patch_height, dtype=torch.int32, device=device).view(-1, 1).repeat(1, patch_width).flatten() - col_ids = torch.arange(patch_width, dtype=torch.int32, device=device).view(1, -1).repeat(patch_height, 1).flatten() + row_ids = ( + torch.arange(post_patch_height, dtype=torch.int32, device=device) + .view(-1, 1) + .repeat(1, post_patch_width) + .flatten() + ) + col_ids = ( + torch.arange(post_patch_width, dtype=torch.int32, device=device) + .view(1, -1) + .repeat(post_patch_height, 1) + .flatten() + ) position_ids[i, cap_len : cap_len + num_patches, 1] = row_ids position_ids[i, cap_len : cap_len + num_patches, 2] = col_ids @@ -289,7 +300,9 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): freqs_cis = self._get_freqs_cis(position_ids) # Split frequencies for captions and images - cap_freqs_cis = torch.zeros(batch_size, attention_mask.shape[1], freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) + cap_freqs_cis = torch.zeros( + batch_size, attention_mask.shape[1], freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) img_freqs_cis = torch.zeros(batch_size, num_patches, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) for i in range(batch_size): @@ -297,7 +310,17 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] img_freqs_cis[i, :num_patches] = freqs_cis[i, cap_len : cap_len + num_patches] - return cap_freqs_cis, img_freqs_cis + # patch embeddings + hidden_states = ( + hidden_states.view( + batch_size, channels, post_patch_height, self.patch_size, post_patch_width, self.patch_size + ) + .permute(0, 2, 4, 3, 5, 1) + .flatten(3) + .flatten(1, 2) + ) + + return hidden_states, freqs_cis, cap_freqs_cis, img_freqs_cis class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): @@ -438,71 +461,69 @@ def forward( hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, - attention_mask: torch.Tensor, + encoder_attention_mask: torch.Tensor, use_mask_in_transformer: bool = True, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: - + # 1. 
Condition, positional & patch embedding batch_size, _, height, width = hidden_states.shape - image_seq_len = (height // self.config.patch_size) * (width // self.config.patch_size) + p = self.config.patch_size + post_patch_height, post_patch_width = height // p, width // p + num_patches = post_patch_height * post_patch_width - text_seq_len = encoder_hidden_states.shape[1] - - l_effective_text_seq_len = attention_mask.sum(dim=1).tolist() - max_seq_len = max(l_effective_text_seq_len) + image_seq_len + # effective_text_seq_lengths is based on actual caption length, so it's different for each prompt in a batch + effective_encoder_seq_lengths = encoder_attention_mask.sum(dim=1).tolist() + seq_lengths = [ + encoder_seq_len + num_patches for encoder_seq_len in effective_encoder_seq_lengths + ] # Add num_patches to each length + max_seq_len = max(seq_lengths) - # 1. Condition, positional & patch embedding temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) - encoder_rotary_emb, hidden_rotary_emb = self.rope_embedder(hidden_states, attention_mask) + hidden_states, rotary_emb, context_rotary_emb, noise_rotary_emb = self.rope_embedder( + hidden_states, encoder_attention_mask + ) hidden_states = self.x_embedder(hidden_states) # 2. Context & noise refinement for layer in self.context_refiner: - # NOTE: mask not used for performance encoder_hidden_states = layer( - encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb + encoder_hidden_states, encoder_attention_mask if use_mask_in_transformer else None, context_rotary_emb ) for layer in self.noise_refiner: - # NOTE: mask not used for performance - hidden_states = layer( - hidden_states, None, hidden_rotary_emb, temb - ) + hidden_states = layer(hidden_states, None, noise_rotary_emb, temb) + + # 3. Joint Transformer blocks + attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) + for i, (effective_encoder_seq_len, seq_len) in enumerate(zip(effective_encoder_seq_lengths, seq_lengths)): + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :effective_encoder_seq_len] = encoder_hidden_states[i, :effective_encoder_seq_len] + joint_hidden_states[i, effective_encoder_seq_len:seq_len] = hidden_states[i] + + hidden_states = joint_hidden_states - # 3. Attention mask preparation - mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) - padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) - for i in range(batch_size): - cap_len = l_effective_text_seq_len[i] - img_len = image_seq_len - mask[i, : cap_len + img_len] = True - padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len] - padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len] - hidden_states = padded_hidden_states - - # 4. 
Transformer blocks for layer in self.layers: - # NOTE: mask not used for performance if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( - layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb + layer, hidden_states, attention_mask if use_mask_in_transformer else None, rotary_emb, temb ) else: - hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb) + hidden_states = layer( + hidden_states, attention_mask if use_mask_in_transformer else None, rotary_emb, temb + ) - # 5. Output norm & projection & unpatchify + # 4. Output norm & projection hidden_states = self.norm_out(hidden_states, temb) - height_tokens = width_tokens = self.config.patch_size + # 5. Unpatchify output = [] - for i in range(batch_size): - begin = l_effective_text_seq_len[i] - end = begin + image_seq_len + for i, (effective_encoder_seq_len, seq_len) in enumerate(zip(effective_encoder_seq_lengths, seq_lengths)): output.append( - hidden_states[i][begin:end] - .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels) + hidden_states[i][effective_encoder_seq_len:seq_len] + .view(post_patch_height, post_patch_width, p, p, self.out_channels) .permute(4, 0, 2, 1, 3) .flatten(3, 4) .flatten(1, 2) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 7a1ec906d7ce..54ec5295e81c 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -696,7 +696,7 @@ def __call__( hidden_states=latents, timestep=current_timestep, encoder_hidden_states=prompt_embeds, - attention_mask=prompt_attention_mask, + encoder_attention_mask=prompt_attention_mask, use_mask_in_transformer=use_mask_in_transformer, return_dict=False, )[0] @@ -707,7 +707,7 @@ def __call__( hidden_states=latents, timestep=current_timestep, encoder_hidden_states=negative_prompt_embeds, - attention_mask=negative_prompt_attention_mask, + encoder_attention_mask=negative_prompt_attention_mask, use_mask_in_transformer=use_mask_in_transformer, return_dict=False, )[0] From bde26d3af83af6c6ce3399fbf3fa74249d37688e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 13 Feb 2025 01:03:50 +0100 Subject: [PATCH 4/9] fix for mps --- src/diffusers/models/transformers/transformer_lumina2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 71ed05060c8e..b3512c952c49 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -241,8 +241,9 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: freqs_cis = [] + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): - emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64) + emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype) freqs_cis.append(emb) return freqs_cis From 30a80085462efa3d94e8b716f5ccd291b8f26961 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 13 Feb 2025 05:11:17 +0100 Subject: [PATCH 5/9] up --- .../transformers/transformer_lumina2.py | 
73 +++++++++---------- 1 file changed, 34 insertions(+), 39 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index b3512c952c49..f7ac35c92302 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -264,22 +264,21 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): batch_size, channels, height, width = hidden_states.shape p = self.patch_size post_patch_height, post_patch_width = height // p, width // p - num_patches = post_patch_height * post_patch_width + image_seq_len = post_patch_height * post_patch_width device = hidden_states.device - # Get caption lengths and calculate max sequence length + encoder_seq_len = attention_mask.shape[1] l_effective_cap_len = attention_mask.sum(dim=1).tolist() - max_seq_len = max(l_effective_cap_len) + num_patches + seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len] + max_seq_len = max(seq_lengths) # Create position IDs position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) - for i in range(batch_size): - cap_len = l_effective_cap_len[i] - + for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): # Set caption positions - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) - position_ids[i, cap_len : cap_len + num_patches, 0] = cap_len + position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device) + position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len # Set image patch positions row_ids = ( @@ -294,34 +293,33 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): .repeat(post_patch_height, 1) .flatten() ) - position_ids[i, cap_len : cap_len + num_patches, 1] = row_ids - position_ids[i, cap_len : cap_len + num_patches, 2] = col_ids + position_ids[i, cap_seq_len:seq_len, 1] = row_ids + position_ids[i, cap_seq_len:seq_len, 2] = col_ids # Get frequencies freqs_cis = self._get_freqs_cis(position_ids) # Split frequencies for captions and images cap_freqs_cis = torch.zeros( - batch_size, attention_mask.shape[1], freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype + ) + img_freqs_cis = torch.zeros( + batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype ) - img_freqs_cis = torch.zeros(batch_size, num_patches, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype) - for i in range(batch_size): - cap_len = l_effective_cap_len[i] - cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :num_patches] = freqs_cis[i, cap_len : cap_len + num_patches] + for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): + cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len] + img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len] # patch embeddings hidden_states = ( - hidden_states.view( - batch_size, channels, post_patch_height, self.patch_size, post_patch_width, self.patch_size - ) + hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p) .permute(0, 2, 4, 3, 5, 1) .flatten(3) .flatten(1, 2) ) - return hidden_states, freqs_cis, cap_freqs_cis, img_freqs_cis + return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths class 
Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): @@ -468,22 +466,17 @@ def forward( ) -> Union[torch.Tensor, Transformer2DModelOutput]: # 1. Condition, positional & patch embedding batch_size, _, height, width = hidden_states.shape - p = self.config.patch_size - post_patch_height, post_patch_width = height // p, width // p - num_patches = post_patch_height * post_patch_width - - # effective_text_seq_lengths is based on actual caption length, so it's different for each prompt in a batch - effective_encoder_seq_lengths = encoder_attention_mask.sum(dim=1).tolist() - seq_lengths = [ - encoder_seq_len + num_patches for encoder_seq_len in effective_encoder_seq_lengths - ] # Add num_patches to each length - max_seq_len = max(seq_lengths) temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) - hidden_states, rotary_emb, context_rotary_emb, noise_rotary_emb = self.rope_embedder( - hidden_states, encoder_attention_mask - ) + ( + hidden_states, + context_rotary_emb, + noise_rotary_emb, + rotary_emb, + encoder_seq_lengths, + seq_lengths, + ) = self.rope_embedder(hidden_states, encoder_attention_mask) hidden_states = self.x_embedder(hidden_states) @@ -497,12 +490,13 @@ def forward( hidden_states = layer(hidden_states, None, noise_rotary_emb, temb) # 3. Joint Transformer blocks + max_seq_len = max(seq_lengths) attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) - for i, (effective_encoder_seq_len, seq_len) in enumerate(zip(effective_encoder_seq_lengths, seq_lengths)): + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): attention_mask[i, :seq_len] = True - joint_hidden_states[i, :effective_encoder_seq_len] = encoder_hidden_states[i, :effective_encoder_seq_len] - joint_hidden_states[i, effective_encoder_seq_len:seq_len] = hidden_states[i] + joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len] + joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i] hidden_states = joint_hidden_states @@ -520,11 +514,12 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) # 5. 
Unpatchify + p = self.config.patch_size output = [] - for i, (effective_encoder_seq_len, seq_len) in enumerate(zip(effective_encoder_seq_lengths, seq_lengths)): + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): output.append( - hidden_states[i][effective_encoder_seq_len:seq_len] - .view(post_patch_height, post_patch_width, p, p, self.out_channels) + hidden_states[i][encoder_seq_len:seq_len] + .view(height // p, width // p, p, p, self.out_channels) .permute(4, 0, 2, 1, 3) .flatten(3, 4) .flatten(1, 2) From 79ed8c1688dacb2daa3ceff771c658d71e8ef877 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 13 Feb 2025 05:20:44 +0100 Subject: [PATCH 6/9] up --- .../models/transformers/transformer_lumina2.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index f7ac35c92302..8359d35e18d2 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -260,7 +260,6 @@ def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: return torch.cat(result, dim=-1).to(device) def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): - # Get batch info and dimensions batch_size, channels, height, width = hidden_states.shape p = self.patch_size post_patch_height, post_patch_width = height // p, width // p @@ -276,11 +275,11 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): - # Set caption positions + # add caption position ids position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device) position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len - # Set image patch positions + # add image position ids row_ids = ( torch.arange(post_patch_height, dtype=torch.int32, device=device) .view(-1, 1) @@ -296,10 +295,10 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): position_ids[i, cap_seq_len:seq_len, 1] = row_ids position_ids[i, cap_seq_len:seq_len, 2] = col_ids - # Get frequencies + # Get combined rotary embeddings freqs_cis = self._get_freqs_cis(position_ids) - # Split frequencies for captions and images + # create separate rotary embeddings for captions and images cap_freqs_cis = torch.zeros( batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype ) @@ -311,7 +310,7 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len] img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len] - # patch embeddings + # image patch embeddings hidden_states = ( hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p) .permute(0, 2, 4, 3, 5, 1) From fc3aa8c6e6b90ab0a142cccad84505a32c5fca44 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 13 Feb 2025 05:36:30 +0100 Subject: [PATCH 7/9] flip the default + always use mask in context refiner --- .../models/transformers/transformer_lumina2.py | 12 ++++-------- src/diffusers/pipelines/lumina2/pipeline_lumina2.py | 9 +++++---- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 
8359d35e18d2..007765aa018f 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -460,7 +460,7 @@ def forward( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, - use_mask_in_transformer: bool = True, + use_mask: bool = True, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: # 1. Condition, positional & patch embedding @@ -481,9 +481,7 @@ def forward( # 2. Context & noise refinement for layer in self.context_refiner: - encoder_hidden_states = layer( - encoder_hidden_states, encoder_attention_mask if use_mask_in_transformer else None, context_rotary_emb - ) + encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb) for layer in self.noise_refiner: hidden_states = layer(hidden_states, None, noise_rotary_emb, temb) @@ -502,12 +500,10 @@ def forward( for layer in self.layers: if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( - layer, hidden_states, attention_mask if use_mask_in_transformer else None, rotary_emb, temb + layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb ) else: - hidden_states = layer( - hidden_states, attention_mask if use_mask_in_transformer else None, rotary_emb, temb - ) + hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb) # 4. Output norm & projection hidden_states = self.norm_out(hidden_states, temb) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 54ec5295e81c..a1306842f043 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -517,7 +517,7 @@ def __call__( system_prompt: Optional[str] = None, cfg_trunc_ratio: float = 1.0, cfg_normalization: bool = True, - use_mask_in_transformer: bool = True, + use_mask_in_transformer: bool = False, max_sequence_length: int = 256, ) -> Union[ImagePipelineOutput, Tuple]: """ @@ -590,7 +590,8 @@ def __call__( cfg_normalization (`bool`, *optional*, defaults to `True`): Whether to apply normalization-based guidance scale. use_mask_in_transformer (`bool`, *optional*, defaults to `True`): - Whether to use attention mask in `Lumina2Transformer2DModel`. Set `False` for performance gain. + Whether to use attention mask in `Lumina2Transformer2DModel` for the transformer blocks. Only need to + set `True` when you pass a list of prompts with different lengths. max_sequence_length (`int`, defaults to `256`): Maximum sequence length to use with the `prompt`. 
@@ -697,7 +698,7 @@ def __call__( timestep=current_timestep, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, - use_mask_in_transformer=use_mask_in_transformer, + use_mask=use_mask_in_transformer, return_dict=False, )[0] @@ -708,7 +709,7 @@ def __call__( timestep=current_timestep, encoder_hidden_states=negative_prompt_embeds, encoder_attention_mask=negative_prompt_attention_mask, - use_mask_in_transformer=use_mask_in_transformer, + use_mask=use_mask_in_transformer, return_dict=False, )[0] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) From ef7a14bea63c13e5268defdfe95315b811b6dad4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 13 Feb 2025 06:15:22 +0100 Subject: [PATCH 8/9] test --- tests/models/transformers/test_models_transformer_lumina2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/transformers/test_models_transformer_lumina2.py b/tests/models/transformers/test_models_transformer_lumina2.py index e89f160433bd..4db3ae68aa94 100644 --- a/tests/models/transformers/test_models_transformer_lumina2.py +++ b/tests/models/transformers/test_models_transformer_lumina2.py @@ -51,7 +51,7 @@ def dummy_input(self): "hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep, - "attention_mask": attention_mask, + "encoder_attention_mask": attention_mask, } @property From eb8333885889ad9817ebe1c1a382dbf08a9daf74 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 15 Feb 2025 01:36:15 +0100 Subject: [PATCH 9/9] remove use_mask_in_transformer --- src/diffusers/models/transformers/transformer_lumina2.py | 3 ++- src/diffusers/pipelines/lumina2/pipeline_lumina2.py | 6 ------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 007765aa018f..ee58b8ac244d 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -460,7 +460,6 @@ def forward( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, - use_mask: bool = True, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: # 1. Condition, positional & patch embedding @@ -488,6 +487,8 @@ def forward( # 3. Joint Transformer blocks max_seq_len = max(seq_lengths) + use_mask = len(set(seq_lengths)) > 1 + attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index a1306842f043..3478a6140e8e 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -517,7 +517,6 @@ def __call__( system_prompt: Optional[str] = None, cfg_trunc_ratio: float = 1.0, cfg_normalization: bool = True, - use_mask_in_transformer: bool = False, max_sequence_length: int = 256, ) -> Union[ImagePipelineOutput, Tuple]: """ @@ -589,9 +588,6 @@ def __call__( The ratio of the timestep interval to apply normalization-based guidance scale. cfg_normalization (`bool`, *optional*, defaults to `True`): Whether to apply normalization-based guidance scale. 
- use_mask_in_transformer (`bool`, *optional*, defaults to `True`): - Whether to use attention mask in `Lumina2Transformer2DModel` for the transformer blocks. Only need to - set `True` when you pass a list of prompts with different lengths. max_sequence_length (`int`, defaults to `256`): Maximum sequence length to use with the `prompt`. @@ -698,7 +694,6 @@ def __call__( timestep=current_timestep, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, - use_mask=use_mask_in_transformer, return_dict=False, )[0] @@ -709,7 +704,6 @@ def __call__( timestep=current_timestep, encoder_hidden_states=negative_prompt_embeds, encoder_attention_mask=negative_prompt_attention_mask, - use_mask=use_mask_in_transformer, return_dict=False, )[0] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
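
For reference, the joint caption-plus-image position-id layout built by the refactored RoPE embedder's `forward` in the patches above (caption tokens counted up on axis 0, then a constant caption-length value on axis 0 and the patch row/column grid on axes 1 and 2) can be reproduced with the standalone sketch below. This is an illustration, not code from the series; the `build_position_ids` helper and the example shapes are hypothetical.

```py
import torch


def build_position_ids(cap_lens, post_patch_height, post_patch_width):
    # Illustrative helper (not part of the PR): builds a (batch, max_seq_len, 3)
    # position-id tensor with caption tokens first and the flattened patch grid after them.
    image_seq_len = post_patch_height * post_patch_width
    seq_lengths = [cap_len + image_seq_len for cap_len in cap_lens]
    max_seq_len = max(seq_lengths)

    position_ids = torch.zeros(len(cap_lens), max_seq_len, 3, dtype=torch.int32)
    # Equivalent to the view/repeat/flatten pattern used in the patches.
    row_ids = torch.arange(post_patch_height, dtype=torch.int32).repeat_interleave(post_patch_width)
    col_ids = torch.arange(post_patch_width, dtype=torch.int32).repeat(post_patch_height)

    for i, (cap_len, seq_len) in enumerate(zip(cap_lens, seq_lengths)):
        position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32)  # caption axis counts up
        position_ids[i, cap_len:seq_len, 0] = cap_len  # then stays constant over image tokens
        position_ids[i, cap_len:seq_len, 1] = row_ids  # patch row index
        position_ids[i, cap_len:seq_len, 2] = col_ids  # patch column index
    return position_ids


# Two prompts with different caption lengths, each denoising a 4x4 grid of latent patches.
ids = build_position_ids(cap_lens=[3, 5], post_patch_height=4, post_patch_width=4)
print(ids.shape)  # torch.Size([2, 21, 3]); the shorter sample keeps zero padding past its 19 tokens
```

Because per-sample sequence lengths differ only through the caption length, the final patch can drop the `use_mask_in_transformer` flag and instead enable the joint attention mask automatically whenever `len(set(seq_lengths)) > 1`.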