toshas's picture
initial commit
a45988a
import gc
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=4,
num_layers=1,
num_single_layers=1,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 8,
"width": 8,
"max_sequence_length": 48,
"output_type": "np",
}
return inputs
def test_flux_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
# For some reasons, they don't show large differences
assert max_diff > 1e-6
def test_flux_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
prompt,
prompt_2=None,
device=torch_device,
max_sequence_length=inputs["max_sequence_length"],
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
@slow
@require_torch_gpu
class FluxPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
return {
"prompt": "A photo of a cat",
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
"generator": generator,
}
# TODO: Dhruv. Move large model tests to a dedicated runner)
@unittest.skip("We cannot run inference on this model with the current CI hardware")
def test_flux_inference(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
[0.36132812, 0.30004883, 0.25830078],
[0.36669922, 0.31103516, 0.23754883],
[0.34814453, 0.29248047, 0.23583984],
[0.35791016, 0.30981445, 0.23999023],
[0.36328125, 0.31274414, 0.2607422],
[0.37304688, 0.32177734, 0.26171875],
[0.3671875, 0.31933594, 0.25756836],
[0.36035156, 0.31103516, 0.2578125],
[0.3857422, 0.33789062, 0.27563477],
[0.3701172, 0.31982422, 0.265625],
],
dtype=np.float32,
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4