import math

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from flux.transformer_flux import FluxTransformer2DModel
from flux.pipeline_flux_chameleon import FluxPipeline
# Assumed module path for the simplified Qwen2-VL wrapper used below, matching
# the layout of the flux/ modules above; adjust if your checkout differs.
from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel

MODEL_ID = "Djrango/Qwen2vl-Flux"
class Qwen2Connector(nn.Module):
    def __init__(self, input_dim=3584, output_dim=4096):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)
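

# A minimal shape check for the connector (illustrative only; the token count
# of 256 matches the processor's fixed 256*28*28 pixel budget configured below):
#   _connector = Qwen2Connector()
#   _features = torch.randn(1, 256, 3584)   # (batch, image tokens, Qwen2-VL dim)
#   _projected = _connector(_features)      # -> (1, 256, 4096)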
class FluxInterface:
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
        self.dtype = torch.bfloat16
        self.models = None
        self.qwen2vl_processor = None
        self.pipeline = None
        self.MODEL_ID = MODEL_ID
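
    # Weights are loaded lazily on the first generate() call, so the Gradio UI
    # can come up before the large checkpoints finish downloading.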

    def load_models(self):
        if self.models is not None:
            return

        # Load FLUX components
        tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder")
        text_encoder_two = T5EncoderModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder_2")
        tokenizer_two = T5TokenizerFast.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer_2")

        # Load VAE and transformer; assumed to sit in their own subfolders of
        # flux/, consistent with the components above (loading both from the
        # bare "flux" folder would point them at the same config).
        vae = AutoencoderKL.from_pretrained(self.MODEL_ID, subfolder="flux/vae")
        transformer = FluxTransformer2DModel.from_pretrained(self.MODEL_ID, subfolder="flux/transformer")
        # shift=1 leaves the flow-matching sigma schedule un-shifted
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(self.MODEL_ID, subfolder="flux/scheduler", shift=1)

        # Load Qwen2VL components from the qwen2-vl folder
        qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(self.MODEL_ID, subfolder="qwen2-vl")

        # Load the connector weights from the qwen2-vl folder
        connector = Qwen2Connector()
        connector_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/connector.pt"
        connector_state = torch.hub.load_state_dict_from_url(connector_path, map_location=self.device)
        connector.load_state_dict(connector_state)

        # Load T5 embedder
        self.t5_context_embedder = nn.Linear(4096, 3072)
        t5_embedder_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/t5_embedder.pt"
        t5_embedder_state = torch.hub.load_state_dict_from_url(t5_embedder_path, map_location=self.device)
        self.t5_context_embedder.load_state_dict(t5_embedder_state)
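
        # The two projection heads play different roles: the connector maps
        # Qwen2-VL image features (3584 -> 4096) to stand in for the usual text
        # sequence embeddings, while t5_context_embedder maps real T5 features
        # (4096 -> 3072) into the transformer's inner attention width.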

        # Move models to device, cast to bfloat16, and switch to eval mode
        models = [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, self.t5_context_embedder]
        for model in models:
            model.to(self.device).to(self.dtype)
            model.eval()

        self.models = {
            'tokenizer': tokenizer,
            'text_encoder': text_encoder,
            'text_encoder_two': text_encoder_two,
            'tokenizer_two': tokenizer_two,
            'vae': vae,
            'transformer': transformer,
            'scheduler': scheduler,
            'qwen2vl': qwen2vl,
            'connector': connector,
        }

        # Initialize processor and pipeline
        self.qwen2vl_processor = AutoProcessor.from_pretrained(
            self.MODEL_ID,
            subfolder="qwen2-vl",
            min_pixels=256 * 28 * 28,
            max_pixels=256 * 28 * 28,
        )
        self.pipeline = FluxPipeline(
            transformer=transformer,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
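
        # The T5 encoder/tokenizer are deliberately not handed to the pipeline:
        # their embeddings are precomputed in compute_t5_text_embeddings() and
        # injected through the t5_prompt_embeds argument at call time.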

    def resize_image(self, img, max_pixels=1050000):
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)

        width, height = img.size
        num_pixels = width * height

        if num_pixels > max_pixels:
            scale = math.sqrt(max_pixels / num_pixels)
            new_width = int(width * scale)
            new_height = int(height * scale)
            # Snap dimensions down to multiples of 8, as the 8x-downsampling VAE requires
            new_width = new_width - (new_width % 8)
            new_height = new_height - (new_height % 8)
            img = img.resize((new_width, new_height), Image.LANCZOS)

        return img
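
    # Worked example of the resize: a 1920x1080 input has 2,073,600 pixels, so
    # scale = sqrt(1050000 / 2073600) ≈ 0.712, giving 1366x768, which snaps
    # down to 1360x768 after rounding to multiples of 8.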

    def process_image(self, image):
        message = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ]
        text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)

        with torch.no_grad():
            inputs = self.qwen2vl_processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device)
            output_hidden_state, image_token_mask, image_grid_thw = self.models['qwen2vl'](**inputs)
            image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
            image_hidden_state = self.models['connector'](image_hidden_state)

        return image_hidden_state, image_grid_thw
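
    # The "Describe this image." turn above is scaffolding: the model is run
    # once for its hidden states at the image-token positions (selected via
    # image_token_mask), not to actually produce a caption.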

    def compute_t5_text_embeddings(self, prompt):
        """Compute T5 embeddings for the text prompt; returns None for an empty prompt."""
        if not prompt:
            return None

        with torch.no_grad():
            text_inputs = self.models['tokenizer_two'](
                prompt,
                padding="max_length",
                max_length=256,
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            prompt_embeds = self.models['text_encoder_two'](text_inputs.input_ids)[0]
            prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
            prompt_embeds = self.t5_context_embedder(prompt_embeds)

        return prompt_embeds
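
    # Returning None for an empty prompt makes generate() fall back to pure
    # image-variation mode, conditioned on the Qwen2-VL features alone.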

    def compute_text_embeddings(self, prompt=""):
        with torch.no_grad():
            text_inputs = self.models['tokenizer'](
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            prompt_embeds = self.models['text_encoder'](
                text_inputs.input_ids,
                output_hidden_states=False,
            )
            pooled_prompt_embeds = prompt_embeds.pooler_output.to(self.dtype)
        return pooled_prompt_embeds
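
    # Only CLIP's pooled output is used here (FLUX consumes it as a global
    # conditioning vector); the per-token sequence conditioning comes from the
    # Qwen2-VL image features instead.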

    def generate(self, input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None):
        try:
            if seed is not None:
                torch.manual_seed(seed)

            self.load_models()

            # Process input image
            input_image = self.resize_image(input_image)
            qwen2_hidden_state, image_grid_thw = self.process_image(input_image)

            # The pooled CLIP embedding is computed from an empty prompt; the
            # image features carry the content
            pooled_prompt_embeds = self.compute_text_embeddings("")

            # Get T5 embeddings if a prompt is provided
            t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)

            # Generate images (Gradio sliders may deliver floats, so cast counts to int)
            num_images = int(num_images)
            output_images = self.pipeline(
                prompt_embeds=qwen2_hidden_state.repeat(num_images, 1, 1),
                pooled_prompt_embeds=pooled_prompt_embeds,
                t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
                num_inference_steps=int(num_inference_steps),
                guidance_scale=guidance_scale,
            ).images
            return output_images
        except Exception as e:
            print(f"Error during generation: {e}")
            raise gr.Error(f"Generation failed: {e}")


# Initialize the interface
interface = FluxInterface()
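
# Programmatic usage example (illustrative; the file path and prompt are placeholders):
#   images = interface.generate(Image.open("cat.png"), prompt="a watercolor cat", seed=42)
#   images[0].save("variation_0.png")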

# Create Gradio interface
with gr.Blocks(title="Qwen2vl-Flux Demo") as demo:
    gr.Markdown("""
    # 🎨 Qwen2vl-Flux Image Variation Demo
    Upload an image and get AI-generated variations. You can optionally add a text prompt to guide the generation.
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="pil")
            prompt = gr.Textbox(
                label="Optional Text Prompt (longer, more detailed prompts tend to work best)",
                placeholder="Enter text prompt here (optional)",
            )
            with gr.Row():
                guidance = gr.Slider(minimum=1, maximum=10, value=3.5, label="Guidance Scale")
                steps = gr.Slider(minimum=1, maximum=50, value=28, step=1, label="Number of Steps")
            num_images = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Number of Images")
            seed = gr.Number(label="Random Seed (optional)", precision=0)
            submit_btn = gr.Button("Generate Variations", variant="primary")
        with gr.Column():
            output_gallery = gr.Gallery(label="Generated Variations", columns=2, show_label=True)

    # Set up the generation function
    submit_btn.click(
        fn=interface.generate,
        inputs=[input_image, prompt, guidance, steps, num_images, seed],
        outputs=output_gallery,
    )
gr.Markdown(""" | |
### Notes: | |
- Higher guidance scale values result in outputs that more closely follow the prompt | |
- More steps generally produce better quality but take longer | |
- Set a seed for reproducible results | |
""") | |
# Launch the app | |
if __name__ == "__main__": | |
demo.launch() |