 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import os
 import sys
 import tempfile

 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

 from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    is_peft_available,
+    require_peft_backend,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)


 if is_peft_available():
@@ -145,3 +153,89 @@ def test_with_alpha_in_state_dict(self):
             "Loading from saved checkpoints should give same results.",
         )
         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
+
+
+@slow
+@require_torch_gpu
+@require_peft_backend
+@unittest.skip("We cannot run inference on this model with the current CI hardware")
+# TODO (DN6, sayakpaul): move these tests to a beefier GPU
+class FluxLoRAIntegrationTests(unittest.TestCase):
+    """internal note: The integration slices were obtained on audace."""
+
+    num_inference_steps = 10
+    seed = 0
+
+    def setUp(self):
+        super().setUp()
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+
+    def tearDown(self):
+        super().tearDown()
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_flux_the_last_ben(self):
+        self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
+        self.pipeline.fuse_lora()
+        self.pipeline.unload_lora_weights()
+        self.pipeline.enable_model_cpu_offload()
+
+        prompt = "jon snow eating pizza with ketchup"
+
+        out = self.pipeline(
+            prompt,
+            num_inference_steps=self.num_inference_steps,
+            guidance_scale=4.0,
+            output_type="np",
+            generator=torch.manual_seed(self.seed),
+        ).images
+        out_slice = out[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.1719, 0.1719, 0.1699, 0.1719, 0.1719, 0.1738, 0.1641, 0.1621, 0.2090])
+
+        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+    def test_flux_kohya(self):
+        self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
+        self.pipeline.fuse_lora()
+        self.pipeline.unload_lora_weights()
+        self.pipeline.enable_model_cpu_offload()
+
+        prompt = "The cat with a brain slug earring"
+        out = self.pipeline(
+            prompt,
+            num_inference_steps=self.num_inference_steps,
+            guidance_scale=4.5,
+            output_type="np",
+            generator=torch.manual_seed(self.seed),
+        ).images
+
+        out_slice = out[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
+
+        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+
+    def test_flux_xlabs(self):
+        self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
+        self.pipeline.fuse_lora()
+        self.pipeline.unload_lora_weights()
+        self.pipeline.enable_model_cpu_offload()
+
+        prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
+
+        out = self.pipeline(
+            prompt,
+            num_inference_steps=self.num_inference_steps,
+            guidance_scale=3.5,
+            output_type="np",
+            generator=torch.manual_seed(self.seed),
+        ).images
+        out_slice = out[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.3984, 0.4199, 0.4453, 0.4102, 0.4375, 0.4590, 0.4141, 0.4355, 0.4980])
+
+        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
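
Each of the three integration tests above exercises the same four-step workflow: load a LoRA checkpoint, fuse it into the base weights, unload the adapter modules, and run inference with model CPU offload enabled. The following is a minimal standalone sketch of that workflow, reusing the repository and checkpoint names from the first test; it assumes a CUDA GPU with enough memory for FLUX.1-dev in bfloat16, and the output path is illustrative.

import torch

from diffusers import FluxPipeline

# Load the base pipeline in bfloat16 (FLUX.1-dev weights are gated; access is assumed).
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Load a LoRA checkpoint, merge it into the base weights, then drop the adapter modules.
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()

# Move submodules to the GPU only while they run, keeping peak memory low.
pipeline.enable_model_cpu_offload()

image = pipeline(
    "jon snow eating pizza with ketchup",
    num_inference_steps=10,
    guidance_scale=4.0,
    generator=torch.manual_seed(0),
).images[0]
image.save("jon_snow.png")  # illustrative output path

Fusing and then unloading bakes the adapter into the base weights, so subsequent inference runs without per-step adapter overhead, which is the state the tests' expected slices were recorded in.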