Commit 6239a6a

Merge branch 'main' into mochi

2 parents: 4b966d2 + 0d1d267

23 files changed: +3300 -5 lines

docs/source/en/_toctree.yml

Lines changed: 6 additions & 0 deletions
@@ -252,6 +252,8 @@
         title: SparseControlNetModel
       title: ControlNets
     - sections:
+      - local: api/models/allegro_transformer3d
+        title: AllegroTransformer3DModel
       - local: api/models/aura_flow_transformer2d
         title: AuraFlowTransformer2DModel
       - local: api/models/cogvideox_transformer3d
@@ -302,6 +304,8 @@
     - sections:
       - local: api/models/autoencoderkl
        title: AutoencoderKL
+      - local: api/models/autoencoderkl_allegro
+        title: AutoencoderKLAllegro
       - local: api/models/autoencoderkl_cogvideox
         title: AutoencoderKLCogVideoX
       - local: api/models/autoencoderkl_mochi
@@ -322,6 +326,8 @@
   sections:
   - local: api/pipelines/overview
     title: Overview
+  - local: api/pipelines/allegro
+    title: Allegro
   - local: api/pipelines/amused
     title: aMUSEd
   - local: api/pipelines/animatediff
docs/source/en/api/models/allegro_transformer3d.md

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License. -->
+
+# AllegroTransformer3DModel
+
+A Diffusion Transformer model for 3D data from [Allegro](https://github.com/rhymes-ai/Allegro), introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import AllegroTransformer3DModel
+
+transformer = AllegroTransformer3DModel.from_pretrained("rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## AllegroTransformer3DModel
+
+[[autodoc]] AllegroTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
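
For context, a transformer loaded this way is typically handed back to its pipeline as a component override rather than used standalone. A minimal sketch, assuming the standard `DiffusionPipeline.from_pretrained` component-kwarg pattern:

```python
import torch

from diffusers import AllegroPipeline, AllegroTransformer3DModel

# Load the transformer on its own, e.g. to pick its dtype independently.
transformer = AllegroTransformer3DModel.from_pretrained(
    "rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Reuse it as a component when assembling the full pipeline.
pipe = AllegroPipeline.from_pretrained(
    "rhymes-ai/Allegro", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
```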
docs/source/en/api/models/autoencoderkl_allegro.md

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License. -->
+
+# AutoencoderKLAllegro
+
+The 3D variational autoencoder (VAE) model with KL loss used in [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import AutoencoderKLAllegro
+
+vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
+```
+
+## AutoencoderKLAllegro
+
+[[autodoc]] AutoencoderKLAllegro
+    - decode
+    - encode
+    - all
+
+## AutoencoderKLOutput
+
+[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
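
Since `encode` and `decode` are the documented entry points, a round-trip sketch may help orient readers; the `(batch, channels, frames, height, width)` layout and the tiling call are assumptions carried over from how the other video VAEs in `diffusers` are used:

```python
import torch

from diffusers import AutoencoderKLAllegro

vae = AutoencoderKLAllegro.from_pretrained(
    "rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32
).to("cuda")
vae.enable_tiling()  # assumption: tiled inference keeps memory bounded, as with other video VAEs

# Assumed input layout: (batch, channels, frames, height, width)
video = torch.randn(1, 3, 24, 256, 256, device="cuda")

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()  # posterior sample in latent space
    reconstruction = vae.decode(latents).sample       # back to pixel space
```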
docs/source/en/api/pipelines/allegro.md

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License. -->
+
+# Allegro
+
+[Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, and Huan Yang.
+
+The abstract from the paper is:
+
+*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .*
+
+<Tip>
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+</Tip>
+
+## AllegroPipeline
+
+[[autodoc]] AllegroPipeline
+    - all
+    - __call__
+
+## AllegroPipelineOutput
+
+[[autodoc]] pipelines.allegro.pipeline_output.AllegroPipelineOutput
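
For orientation, an end-to-end text-to-video sketch; the prompt, step count, and `fps` are illustrative, and `export_to_video` is the usual `diffusers` helper for writing frames to disk:

```python
import torch

from diffusers import AllegroPipeline
from diffusers.utils import export_to_video

pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16).to("cuda")
pipe.vae.enable_tiling()  # assumption: tiled VAE decoding to fit the video latents in memory

prompt = "A seaside harbor at dawn, small boats bobbing on gentle waves."
video = pipe(prompt, guidance_scale=7.5, num_inference_steps=100).frames[0]
export_to_video(video, "allegro_harbor.mp4", fps=15)
```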

src/diffusers/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -77,9 +77,11 @@
 else:
     _import_structure["models"].extend(
         [
+            "AllegroTransformer3DModel",
             "AsymmetricAutoencoderKL",
             "AuraFlowTransformer2DModel",
             "AutoencoderKL",
+            "AutoencoderKLAllegro",
             "AutoencoderKLCogVideoX",
             "AutoencoderKLMochi",
             "AutoencoderKLTemporalDecoder",
@@ -239,6 +241,7 @@
 else:
     _import_structure["pipelines"].extend(
         [
+            "AllegroPipeline",
             "AltDiffusionImg2ImgPipeline",
             "AltDiffusionPipeline",
             "AmusedImg2ImgPipeline",
@@ -559,9 +562,11 @@
     from .utils.dummy_pt_objects import *  # noqa F403
 else:
     from .models import (
+        AllegroTransformer3DModel,
         AsymmetricAutoencoderKL,
         AuraFlowTransformer2DModel,
         AutoencoderKL,
+        AutoencoderKLAllegro,
         AutoencoderKLCogVideoX,
         AutoencoderKLMochi,
         AutoencoderKLTemporalDecoder,
@@ -702,6 +707,7 @@
     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
     from .pipelines import (
+        AllegroPipeline,
         AltDiffusionImg2ImgPipeline,
         AltDiffusionPipeline,
         AmusedImg2ImgPipeline,
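
Each new class is registered twice: as a string in `_import_structure`, consumed lazily at runtime, and as a real import in the type-checking branch so type checkers and IDEs resolve the symbols. A minimal sketch of the lazy half, loosely modeled on the `_LazyModule` pattern `diffusers` uses; the class and names here are illustrative, not the library's exact internals:

```python
import importlib
from types import ModuleType


class LazyModule(ModuleType):
    """Resolve attributes from submodules only on first access."""

    def __init__(self, name: str, import_structure: dict):
        super().__init__(name)
        # Map each public class name to the submodule that defines it.
        self._class_to_module = {
            cls: module for module, classes in import_structure.items() for cls in classes
        }

    def __getattr__(self, name: str):
        module_name = self._class_to_module.get(name)
        if module_name is None:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        # Importing the submodule (and, with it, heavy deps like torch) is deferred to here.
        module = importlib.import_module(f"{self.__name__}.{module_name}")
        return getattr(module, name)


# e.g. LazyModule("diffusers", {"models": ["AllegroTransformer3DModel"],
#                               "pipelines": ["AllegroPipeline"]})
```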

src/diffusers/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,7 @@
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
     _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
     _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
@@ -55,6 +56,7 @@
     _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
@@ -83,6 +85,7 @@
     from .autoencoders import (
         AsymmetricAutoencoderKL,
         AutoencoderKL,
+        AutoencoderKLAllegro,
         AutoencoderKLCogVideoX,
         AutoencoderKLMochi,
         AutoencoderKLTemporalDecoder,
@@ -100,6 +103,7 @@
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
     from .transformers import (
+        AllegroTransformer3DModel,
         AuraFlowTransformer2DModel,
         CogVideoXTransformer3DModel,
         CogView3PlusTransformer2DModel,

src/diffusers/models/attention_processor.py

Lines changed: 94 additions & 0 deletions
@@ -1523,6 +1523,100 @@ def __call__(
         return hidden_states, encoder_hidden_states
 
 
+class AllegroAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None and not attn.is_cross_attention:
+            from .embeddings import apply_rotary_emb_allegro
+
+            query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
+            key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class AuraFlowAttnProcessor2_0:
     """Attention processor used typically in processing Aura Flow."""
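
A quick way to see the processor in isolation is to attach it to a single `Attention` block; the dimensions below are illustrative, and no rotary embedding is passed, so the RoPE branch is skipped:

```python
import torch

from diffusers.models.attention_processor import AllegroAttnProcessor2_0, Attention

# A small self-attention block routed through the Allegro processor.
attn = Attention(query_dim=64, heads=4, dim_head=16, processor=AllegroAttnProcessor2_0())

hidden_states = torch.randn(2, 128, 64)  # (batch, sequence, channels)
out = attn(hidden_states)  # image_rotary_emb defaults to None, so no RoPE is applied
print(out.shape)  # torch.Size([2, 128, 64])
```

Inside the full model, the transformer blocks forward `image_rotary_emb`, which is when the `apply_rotary_emb_allegro` path runs.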

src/diffusers/models/autoencoders/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 from .autoencoder_asym_kl import AsymmetricAutoencoderKL
 from .autoencoder_kl import AutoencoderKL
+from .autoencoder_kl_allegro import AutoencoderKLAllegro
 from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
 from .autoencoder_kl_mochi import AutoencoderKLMochi
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
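
With this re-export in place, the class resolves from the subpackage as well as from the top-level package; a one-line check:

```python
from diffusers import AutoencoderKLAllegro as TopLevel
from diffusers.models.autoencoders import AutoencoderKLAllegro as Subpackage

assert TopLevel is Subpackage  # both names point at the same class object
```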
