import unittest

import torch
from transformers import AutoTokenizer, Gemma2Config, Gemma2Model

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    Lumina2Pipeline,
    Lumina2Transformer2DModel,
)

from ..test_pipelines_common import PipelineTesterMixin
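

# Fast CPU tests for the Lumina2 text-to-image pipeline. PipelineTesterMixin
# supplies the shared pipeline test suite; this class provides the
# pipeline-specific dummy components and inputs that those tests consume.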
class Lumina2PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    pipeline_class = Lumina2Pipeline
    params = frozenset(
        [
            "prompt",
            "height",
            "width",
            "guidance_scale",
            "negative_prompt",
            "prompt_embeds",
            "negative_prompt_embeds",
        ]
    )
    batch_params = frozenset(["prompt", "negative_prompt"])
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )

    # Feature flags for the shared test suite: DDUF serialization is not
    # supported by this pipeline and xformers attention is not tested, while
    # the layerwise-casting tests are enabled.
    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True

    def get_dummy_components(self):
        # Deliberately tiny model configs so the pipeline builds and runs in
        # seconds on CPU; reseed before each module for deterministic weights.
        torch.manual_seed(0)
        transformer = Lumina2Transformer2DModel(
            sample_size=4,
            patch_size=2,
            in_channels=4,
            hidden_size=8,
            num_layers=2,
            num_attention_heads=1,
            num_kv_heads=1,
            multiple_of=16,
            ffn_dim_multiplier=None,
            norm_eps=1e-5,
            scaling_factor=1.0,
            axes_dim_rope=[4, 2, 2],
            cap_feat_dim=8,
        )

        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=4,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")

        torch.manual_seed(0)
        config = Gemma2Config(
            head_dim=4,
            hidden_size=8,
            intermediate_size=8,
            num_attention_heads=2,
            num_hidden_layers=2,
            num_key_value_heads=2,
            sliding_window=2,
        )
        text_encoder = Gemma2Model(config)

        components = {
            "transformer": transformer,
            "vae": vae.eval(),  # the VAE is only used for inference here
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        # Seeded Generator objects are not reliable on MPS, so fall back to the
        # global RNG there; otherwise seed a CPU generator for reproducibility.
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 32,
            "width": 32,
            "output_type": "np",
        }
        return inputs
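
    # A minimal end-to-end smoke test (a sketch added here, not part of the
    # original suite; PipelineTesterMixin already exercises inference more
    # thoroughly). It wires the dummy components and inputs together and checks
    # only the output shape implied by height=width=32 and output_type="np".
    def test_lumina2_inference_smoke(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs("cpu")
        image = pipe(**inputs).images[0]

        # One 32x32 RGB image as a numpy array.
        self.assertEqual(image.shape, (32, 32, 3))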