import inspect
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    SD3Transformer2DModel,
    StableDiffusion3PAGPipeline,
    StableDiffusion3Pipeline,
)
from diffusers.utils.testing_utils import (
    torch_device,
)

from ..test_pipelines_common import (
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
)


class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    """Fast tests for `StableDiffusion3PAGPipeline` assembled from small dummy components."""

    pipeline_class = StableDiffusion3PAGPipeline
    params = frozenset(
        (
            "prompt",
            "height",
            "width",
            "guidance_scale",
            "negative_prompt",
            "prompt_embeds",
            "negative_prompt_embeds",
        )
    )
    batch_params = frozenset(("prompt", "negative_prompt"))

    def get_dummy_components(self):
        """Return the keyword arguments used to construct the pipeline under test.

        The global torch seed is reset to 0 right before the transformer, each CLIP
        text encoder and the VAE are instantiated, so the construction order below
        must be kept as is.
        """
        torch.manual_seed(0)
        sd3_transformer = SD3Transformer2DModel(
            sample_size=32,
            patch_size=1,
            in_channels=4,
            num_layers=2,
            attention_head_dim=8,
            num_attention_heads=4,
            caption_projection_dim=32,
            joint_attention_dim=32,
            pooled_projection_dim=64,
            out_channels=4,
        )

        # Both CLIP text encoders share this configuration.
        clip_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )

        torch.manual_seed(0)
        clip_encoder = CLIPTextModelWithProjection(clip_config)

        torch.manual_seed(0)
        clip_encoder_2 = CLIPTextModelWithProjection(clip_config)

        t5_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")

        clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        clip_tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        autoencoder = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=4,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        flow_scheduler = FlowMatchEulerDiscreteScheduler()

        return dict(
            scheduler=flow_scheduler,
            text_encoder=clip_encoder,
            text_encoder_2=clip_encoder_2,
            text_encoder_3=t5_encoder,
            tokenizer=clip_tokenizer,
            tokenizer_2=clip_tokenizer_2,
            tokenizer_3=t5_tokenizer,
            transformer=sd3_transformer,
            vae=autoencoder,
        )

    def get_dummy_inputs(self, device, seed=0):
        """Return the call arguments for a two-step run with a generator seeded with `seed`.

        For devices whose string form starts with "mps" the generator is whatever
        `torch.manual_seed(seed)` returns; otherwise a CPU `torch.Generator` is used.
        """
        on_mps = str(device).startswith("mps")
        rng = torch.manual_seed(seed) if on_mps else torch.Generator(device="cpu").manual_seed(seed)

        return {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": rng,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "output_type": "np",
            "pag_scale": 0.0,
        }

    # Supplying distinct `prompt_2` / `prompt_3` must change the generated image.
    def test_stable_diffusion_3_different_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

        baseline = pipe(**self.get_dummy_inputs(torch_device)).images[0]

        call_kwargs = self.get_dummy_inputs(torch_device)
        call_kwargs.update(prompt_2="a different prompt", prompt_3="another different prompt")
        varied = pipe(**call_kwargs).images[0]

        # Outputs should be different here
        assert np.abs(baseline - varied).max() > 1e-2

    # Supplying distinct `negative_prompt_2` / `negative_prompt_3` must change the generated image.
    def test_stable_diffusion_3_different_negative_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

        baseline = pipe(**self.get_dummy_inputs(torch_device)).images[0]

        call_kwargs = self.get_dummy_inputs(torch_device)
        call_kwargs.update(negative_prompt_2="deformed", negative_prompt_3="blurry")
        varied = pipe(**call_kwargs).images[0]

        # Outputs should be different here
        assert np.abs(baseline - varied).max() > 1e-2

    # Calling the pipeline with the embeddings returned by `encode_prompt` must match
    # calling it with the raw prompt.
    def test_stable_diffusion_3_prompt_embeds(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

        output_with_prompt = pipe(**self.get_dummy_inputs(torch_device)).images[0]

        inputs = self.get_dummy_inputs(torch_device)
        prompt = inputs.pop("prompt")

        (
            cond_embeds,
            uncond_embeds,
            cond_pooled,
            uncond_pooled,
        ) = pipe.encode_prompt(
            prompt,
            prompt_2=None,
            prompt_3=None,
            do_classifier_free_guidance=inputs["guidance_scale"] > 1,
            device=torch_device,
        )

        embed_kwargs = {
            "prompt_embeds": cond_embeds,
            "negative_prompt_embeds": uncond_embeds,
            "pooled_prompt_embeds": cond_pooled,
            "negative_pooled_prompt_embeds": uncond_pooled,
        }
        output_with_embeds = pipe(**embed_kwargs, **inputs).images[0]

        assert np.abs(output_with_prompt - output_with_embeds).max() < 1e-4

    # Fusing and then unfusing the transformer's QKV projections must leave the output unchanged
    # within tolerance.
    def test_fused_qkv_projections(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        images = pipe(**self.get_dummy_inputs(device)).images
        original_image_slice = images[0, -3:, -3:, -1]

        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(
            pipe.transformer
        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        images = pipe(**self.get_dummy_inputs(device)).images
        image_slice_fused = images[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        images = pipe(**self.get_dummy_inputs(device)).images
        image_slice_disabled = images[0, -3:, -3:, -1]

        assert np.allclose(
            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
        ), "Fusion of QKV projections shouldn't affect the outputs."
        assert np.allclose(
            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        assert np.allclose(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."

    # With `pag_scale=0.0` the PAG pipeline must reproduce the base pipeline's output.
    def test_pag_disable_enable(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        # base pipeline (expect same output when pag is disabled)
        pipe_sd = StableDiffusion3Pipeline(**self.get_dummy_components())
        pipe_sd = pipe_sd.to(device)
        pipe_sd.set_progress_bar_config(disable=None)

        base_inputs = self.get_dummy_inputs(device)
        base_inputs.pop("pag_scale")
        assert (
            "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
        ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
        out = pipe_sd(**base_inputs).images[0, -3:, -3:, -1]

        # pag disabled with pag_scale=0.0 (built from a fresh set of components)
        pipe_pag = self.pipeline_class(**self.get_dummy_components())
        pipe_pag = pipe_pag.to(device)
        pipe_pag.set_progress_bar_config(disable=None)

        pag_inputs = self.get_dummy_inputs(device)
        pag_inputs["pag_scale"] = 0.0
        out_pag_disabled = pipe_pag(**pag_inputs).images[0, -3:, -3:, -1]

        assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3

    # Checks which attention processors `_set_pag_attn_processor` selects for several
    # `pag_applied_layers` patterns.
    def test_pag_applied_layers(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        # base pipeline
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        all_self_attn_layers = [name for name in pipe.transformer.attn_processors if "attn" in name]
        original_attn_procs = pipe.transformer.attn_processors

        def select_pag_layers(layers):
            # Apply the PAG processors for `layers` and return the resulting PAG processors.
            pipe._set_pag_attn_processor(pag_applied_layers=layers, do_classifier_free_guidance=False)
            return pipe.pag_attn_processors

        def restore_processors():
            # Put a copy of the processors captured above back on the transformer.
            pipe.transformer.set_attn_processor(original_attn_procs.copy())

        # First case runs on the freshly built pipeline, without a restore beforehand.
        assert set(select_pag_layers(["blocks.0", "blocks.1"])) == set(all_self_attn_layers)

        # blocks.0
        block_0_self_attn = ["transformer_blocks.0.attn.processor"]
        for layers in (["blocks.0"], ["blocks.0.attn"]):
            restore_processors()
            assert set(select_pag_layers(layers)) == set(block_0_self_attn)

        # Patterns expected to select exactly two processors.
        for layers in (["blocks.(0|1)"], ["blocks.0", r"blocks\.1"]):
            restore_processors()
            assert len(select_pag_layers(layers)) == 2