|
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import os
 import sys
 import tempfile
|
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

 from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    is_peft_available,
+    require_peft_backend,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)


 if is_peft_available():
@@ -145,3 +153,89 @@ def test_with_alpha_in_state_dict(self): |
145 | 153 | "Loading from saved checkpoints should give same results.", |
146 | 154 | ) |
147 | 155 | self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3)) |
+
+
+@slow
+@require_torch_gpu
+@require_peft_backend
+@unittest.skip("We cannot run inference on this model with the current CI hardware")
+# TODO (DN6, sayakpaul): move these tests to a beefier GPU
+class FluxLoRAIntegrationTests(unittest.TestCase):
+    """internal note: The integration slices were obtained on audace."""
+
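+    # Keep the step count low and the seed fixed so the expected slices stay reproducible.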
+    num_inference_steps = 10
+    seed = 0
+
+    def setUp(self):
+        super().setUp()
+
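+        # Release Python references and cached CUDA blocks so each test starts clean.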
+        gc.collect()
+        torch.cuda.empty_cache()
+
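+        # bfloat16 halves memory use relative to float32 while keeping its dynamic range.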
+        self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+
+    def tearDown(self):
+        super().tearDown()
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_flux_the_last_ben(self):
+        self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
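+        # Fusing folds the LoRA deltas into the base weights, so the adapter can be
+        # unloaded afterwards without changing the fused model's outputs.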
+        self.pipeline.fuse_lora()
+        self.pipeline.unload_lora_weights()
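+        # Offload idle submodules to CPU so inference fits on smaller GPUs.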
+        self.pipeline.enable_model_cpu_offload()
+
+        prompt = "jon snow eating pizza with ketchup"
+
+        out = self.pipeline(
+            prompt,
+            num_inference_steps=self.num_inference_steps,
+            guidance_scale=4.0,
+            output_type="np",
+            generator=torch.manual_seed(self.seed),
+        ).images
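+        # Compare a 3x3 corner patch of the last channel against reference values.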
+        out_slice = out[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.1719, 0.1719, 0.1699, 0.1719, 0.1719, 0.1738, 0.1641, 0.1621, 0.2090])
+
+        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+    def test_flux_kohya(self):
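+        # Kohya-style LoRA state dicts are detected and converted by load_lora_weights.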
+        self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
+        self.pipeline.fuse_lora()
+        self.pipeline.unload_lora_weights()
+        self.pipeline.enable_model_cpu_offload()
+
+        prompt = "The cat with a brain slug earring"
+        out = self.pipeline(
+            prompt,
+            num_inference_steps=self.num_inference_steps,
+            guidance_scale=4.5,
+            output_type="np",
+            generator=torch.manual_seed(self.seed),
+        ).images
+
+        out_slice = out[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
+
+        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+    def test_flux_xlabs(self):
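+        # XLabs Flux LoRAs ship in their own key layout, which load_lora_weights also converts.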
+        self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
+        self.pipeline.fuse_lora()
+        self.pipeline.unload_lora_weights()
+        self.pipeline.enable_model_cpu_offload()
+
+        prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
+
+        out = self.pipeline(
+            prompt,
+            num_inference_steps=self.num_inference_steps,
+            guidance_scale=3.5,
+            output_type="np",
+            generator=torch.manual_seed(self.seed),
+        ).images
+        out_slice = out[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.3984, 0.4199, 0.4453, 0.4102, 0.4375, 0.4590, 0.4141, 0.4355, 0.4980])
+
+        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)