Commit a185e1a

sayakpaul and stevhliu authored
[tests] add a test on torch compile for varied resolutions (#11776)
* add test for checking compile on different shapes.
* update
* update
* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>
1 parent d93381c commit a185e1a

File tree

3 files changed: +58 -12 lines changed

docs/source/en/optimization/fp16.md

Lines changed: 22 additions & 0 deletions

````diff
@@ -150,6 +150,28 @@ pipeline(prompt, num_inference_steps=30).images[0]
 
 Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient.
 
+### Dynamic shape compilation
+
+> [!TIP]
+> Make sure to always use the nightly version of PyTorch for better support.
+
+`torch.compile` keeps track of input shapes and conditions, and if these are different, it recompiles the model. For example, if a model is compiled on a 1024x1024 resolution image and used on an image with a different resolution, it triggers recompilation.
+
+To avoid recompilation, add `dynamic=True` to try and generate a more dynamic kernel that avoids recompiling when conditions change.
+
+```diff
++ torch.fx.experimental._config.use_duck_shape = False
++ pipeline.unet = torch.compile(
+    pipeline.unet, fullgraph=True, dynamic=True
+)
+```
+
+Setting `use_duck_shape=False` tells the compiler not to reuse the same symbolic variable for input sizes that happen to be equal. For more details, check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
+
+Not all models may benefit from dynamic compilation out of the box and may require changes. Refer to this [PR](https://github.com/huggingface/diffusers/pull/11297/) that improved the [`AuraFlowPipeline`] implementation to benefit from dynamic compilation.
+
+Feel free to open an issue if dynamic compilation doesn't work as expected for a Diffusers model.
+
 ### Regional compilation
 
 
````
tests/models/test_modeling_common.py

Lines changed: 21 additions & 3 deletions

```diff
@@ -76,6 +76,7 @@
     require_torch_accelerator_with_training,
     require_torch_gpu,
     require_torch_multi_accelerator,
+    require_torch_version_greater,
     run_test_in_subprocess,
     slow,
     torch_all_close,
@@ -1907,6 +1908,8 @@ def test_push_to_hub_library_name(self):
 @is_torch_compile
 @slow
 class TorchCompileTesterMixin:
+    different_shapes_for_compilation = None
+
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
@@ -1957,14 +1960,14 @@ def test_torch_compile_repeated_blocks(self):
         _ = model(**inputs_dict)
 
     def test_compile_with_group_offloading(self):
+        if not self.model_class._supports_group_offloading:
+            pytest.skip("Model does not support group offloading.")
+
         torch._dynamo.config.cache_size_limit = 10000
 
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
 
-        if not getattr(model, "_supports_group_offloading", True):
-            return
-
         model.eval()
         # TODO: Can test for other group offloading kwargs later if needed.
         group_offload_kwargs = {
@@ -1981,6 +1984,21 @@
         _ = model(**inputs_dict)
         _ = model(**inputs_dict)
 
+    @require_torch_version_greater("2.7.1")
+    def test_compile_on_different_shapes(self):
+        if self.different_shapes_for_compilation is None:
+            pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+        torch.fx.experimental._config.use_duck_shape = False
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model = torch.compile(model, fullgraph=True, dynamic=True)
+
+        for height, width in self.different_shapes_for_compilation:
+            with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+                inputs_dict = self.prepare_dummy_input(height=height, width=width)
+                _ = model(**inputs_dict)
+
 
 @slow
 @require_torch_2
```
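
For intuition about what the new `test_compile_on_different_shapes` asserts, here is a self-contained sketch of the same pattern on a toy module rather than a diffusers model; the module and shapes are arbitrary and only illustrate the recompilation check.

```py
import torch

# Mirror the test setup: disable duck shaping so equal sizes do not share a symbol.
torch.fx.experimental._config.use_duck_shape = False

model = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1).eval()
model = torch.compile(model, fullgraph=True, dynamic=True)

# error_on_recompile turns any recompilation into a hard failure, so the loop
# only passes if the dynamic kernel really covers every shape.
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    for height, width in [(32, 32), (32, 64), (64, 64)]:
        _ = model(torch.randn(1, 4, height, width))
```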

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 15 additions & 9 deletions

```diff
@@ -91,10 +91,20 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
 
     @property
     def dummy_input(self):
+        return self.prepare_dummy_input()
+
+    @property
+    def input_shape(self):
+        return (16, 4)
+
+    @property
+    def output_shape(self):
+        return (16, 4)
+
+    def prepare_dummy_input(self, height=4, width=4):
         batch_size = 1
         num_latent_channels = 4
         num_image_channels = 3
-        height = width = 4
         sequence_length = 48
         embedding_dim = 32
 
@@ -114,14 +124,6 @@ def dummy_input(self):
             "timestep": timestep,
         }
 
-    @property
-    def input_shape(self):
-        return (16, 4)
-
-    @property
-    def output_shape(self):
-        return (16, 4)
-
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
             "patch_size": 1,
@@ -173,10 +175,14 @@ def test_gradient_checkpointing_is_applied(self):
 
 class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
+    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
 
     def prepare_init_args_and_inputs_for_common(self):
        return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
 
+    def prepare_dummy_input(self, height, width):
+        return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
+
 
 class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
```
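
Other model test suites could opt into the new check in the same way as the Flux tests above. A hypothetical sketch, where `MyTransformer2DModel` and `MyTransformerTests` are placeholders for a real model class and its existing tester, not actual diffusers classes:

```py
import unittest

# TorchCompileTesterMixin comes from the shared model test module shown above;
# the relative import path mirrors how the existing model tests pull it in.
from ..test_modeling_common import TorchCompileTesterMixin


class MyTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = MyTransformer2DModel  # placeholder model class
    # Latent-grid (height, width) pairs to run through a single dynamic compile.
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def prepare_init_args_and_inputs_for_common(self):
        return MyTransformerTests().prepare_init_args_and_inputs_for_common()

    def prepare_dummy_input(self, height, width):
        return MyTransformerTests().prepare_dummy_input(height=height, width=width)
```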
