import gc
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import (
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
)

from ..test_pipelines_common import PipelineTesterMixin


@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    """Fast tests that run `FluxPipeline` on tiny components.

    The transformer, CLIP text encoder and VAE are built from small configs
    right after seeding the global RNG; the T5 encoder and both tokenizers are
    loaded from the `hf-internal-testing` tiny checkpoints.
    """

    pipeline_class = FluxPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
    batch_params = frozenset(["prompt"])

    def get_dummy_components(self):
        """Build the dictionary of tiny components used to construct the pipeline.

        The global torch RNG is re-seeded with 0 immediately before each model
        is created, so the statement order below is significant.

        Returns:
            dict mapping the pipeline constructor argument names (`scheduler`,
            `text_encoder`, `text_encoder_2`, `tokenizer`, `tokenizer_2`,
            `transformer`, `vae`) to the corresponding objects.
        """
        torch.manual_seed(0)
        tiny_transformer = FluxTransformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=1,
            num_single_layers=1,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=32,
            pooled_projection_dim=32,
            axes_dims_rope=[4, 4, 8],
        )

        clip_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )

        torch.manual_seed(0)
        clip_encoder = CLIPTextModel(clip_config)

        torch.manual_seed(0)
        t5_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")

        clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        tiny_vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        flow_scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": flow_scheduler,
            "text_encoder": clip_encoder,
            "text_encoder_2": t5_encoder,
            "tokenizer": clip_tokenizer,
            "tokenizer_2": t5_tokenizer,
            "transformer": tiny_transformer,
            "vae": tiny_vae,
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return the keyword arguments for a small, seeded pipeline call.

        Args:
            device: Target device; only inspected to detect an "mps" prefix.
            seed: Seed for the random generator handed to the pipeline.

        Returns:
            dict of call arguments (prompt, generator, step count, guidance
            scale, 8x8 output size, max sequence length and numpy output type).
        """
        # For MPS devices the global RNG is seeded and the result of that call
        # is used as the generator; otherwise a dedicated CPU generator is used.
        generator = (
            torch.manual_seed(seed)
            if str(device).startswith("mps")
            else torch.Generator(device="cpu").manual_seed(seed)
        )

        return {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 8,
            "width": 8,
            "max_sequence_length": 48,
            "output_type": "np",
        }

    def test_flux_different_prompts(self):
        # Supplying a distinct `prompt_2` must change the generated image
        # compared with a call that only passes `prompt`.
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)

        baseline_kwargs = self.get_dummy_inputs(torch_device)
        baseline_image = pipe(**baseline_kwargs).images[0]

        variant_kwargs = {**self.get_dummy_inputs(torch_device), "prompt_2": "a different prompt"}
        variant_image = pipe(**variant_kwargs).images[0]

        largest_gap = np.abs(baseline_image - variant_image).max()

        # The two outputs should differ. For reasons not yet understood the
        # difference is small, hence the very loose threshold.
        assert largest_gap > 1e-6

    def test_flux_prompt_embeds(self):
        # Passing embeddings pre-computed with `encode_prompt` must give the
        # same image as passing the raw prompt.
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)

        image_from_prompt = pipe(**self.get_dummy_inputs(torch_device)).images[0]

        call_kwargs = self.get_dummy_inputs(torch_device)
        prompt_text = call_kwargs.pop("prompt")

        # `encode_prompt` returns three values; the third is not needed here.
        embeds, pooled_embeds, _ = pipe.encode_prompt(
            prompt_text,
            prompt_2=None,
            device=torch_device,
            max_sequence_length=call_kwargs["max_sequence_length"],
        )
        image_from_embeds = pipe(
            prompt_embeds=embeds,
            pooled_prompt_embeds=pooled_embeds,
            **call_kwargs,
        ).images[0]

        largest_gap = np.abs(image_from_prompt - image_from_embeds).max()
        assert largest_gap < 1e-4


@slow
@require_torch_gpu
class FluxPipelineSlowTests(unittest.TestCase):
    """Slow GPU test that runs the `black-forest-labs/FLUX.1-schnell` checkpoint."""

    pipeline_class = FluxPipeline
    repo_id = "black-forest-labs/FLUX.1-schnell"

    def setUp(self):
        """Run garbage collection and empty the CUDA cache before each test."""
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        """Run garbage collection and empty the CUDA cache after each test."""
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, seed=0):
        """Return the seeded keyword arguments for the full-model pipeline call.

        Args:
            device: Target device; only inspected to detect an "mps" prefix.
            seed: Seed for the random generator handed to the pipeline.

        Returns:
            dict of call arguments (prompt, step count, guidance scale, numpy
            output type and generator).
        """
        # Same generator selection as in the fast tests: global RNG seeding on
        # MPS, a dedicated CPU generator everywhere else.
        generator = (
            torch.manual_seed(seed)
            if str(device).startswith("mps")
            else torch.Generator(device="cpu").manual_seed(seed)
        )

        return {
            "prompt": "A photo of a cat",
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "output_type": "np",
            "generator": generator,
        }

    # TODO(Dhruv): Move large model tests to a dedicated runner.
    @unittest.skip("We cannot run inference on this model with the current CI hardware")
    def test_flux_inference(self):
        # Compare a slice of the generated image against stored reference
        # values using the cosine similarity distance.
        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
        pipe.enable_model_cpu_offload()

        call_kwargs = self.get_inputs(torch_device)
        image = pipe(**call_kwargs).images[0]

        image_slice = image[0, :10, :10]
        expected_slice = np.array(
            [
                [0.36132812, 0.30004883, 0.25830078],
                [0.36669922, 0.31103516, 0.23754883],
                [0.34814453, 0.29248047, 0.23583984],
                [0.35791016, 0.30981445, 0.23999023],
                [0.36328125, 0.31274414, 0.2607422],
                [0.37304688, 0.32177734, 0.26171875],
                [0.3671875, 0.31933594, 0.25756836],
                [0.36035156, 0.31103516, 0.2578125],
                [0.3857422, 0.33789062, 0.27563477],
                [0.3701172, 0.31982422, 0.265625],
            ],
            dtype=np.float32,
        )

        cosine_distance = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

        assert cosine_distance < 1e-4