import gradio as gr
import numpy as np
import torch
from diffusers import DDPMPipeline, DDIMScheduler
import open_clip
import torchvision
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F

# Initialize device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model
clip_model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
clip_model.to(device)
clip_model.eval()  # inference only; no CLIP weights are updated

# Transforms to preprocess images for CLIP
tfms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])
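# Note: the mean/std above are the standard CLIP preprocessing statistics. Unlike
# the `preprocess` pipeline returned by open_clip (which expects PIL images), this
# tensor-only pipeline is differentiable, so gradients from the CLIP loss can flow
# back to the image being generated.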

# CLIP loss: 1 - cosine similarity between the image and text embeddings
def clip_loss(image, text_features):
    # Ensure image is in the correct format (B, C, H, W)
    if image.dim() == 3:
        image = image.unsqueeze(0)
    # Apply transforms
    image = tfms(image)
    # Encode image and compare to the text embedding
    image_features = clip_model.encode_image(image)
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    loss = (1 - torch.cosine_similarity(image_features, text_features)).mean()
    return loss
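
# Hypothetical sanity check (not part of the app): with random inputs, clip_loss
# should return a single scalar in [0, 2], typically close to 1, since an
# arbitrary image and an unrelated prompt have low cosine similarity in CLIP space.
#
#   with torch.no_grad():
#       _img = torch.rand(1, 3, 256, 256, device=device)
#       _txt = clip_model.encode_text(open_clip.tokenize(["a test prompt"]).to(device))
#       print(clip_loss(_img, _txt).item())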

# Load Diffusion model
model_repo_id = "muneebable/ddpm-celebahq-finetuned-anime-art"
image_pipe = DDPMPipeline.from_pretrained(model_repo_id)
image_pipe.to(device)

# Load a DDIM scheduler (allows sampling in far fewer steps than the DDPM training schedule)
scheduler = DDIMScheduler.from_pretrained(model_repo_id)
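
# Guidance overview: at each denoising step, the scheduler's current estimate of
# the final image (pred_original_sample) is scored with clip_loss, and the gradient
# of that loss with respect to the noisy sample x is used to nudge x toward the
# prompt before the scheduler takes its usual step.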
def generate_image(prompt, guidance_scale, num_steps):
    scheduler.set_timesteps(num_inference_steps=int(num_steps))

    # We embed the prompt with CLIP as our guidance target
    text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = clip_model.encode_text(text)

    # Start from pure noise
    x = torch.randn(1, 3, 256, 256).to(device)
    n_cuts = 4

    for i, t in tqdm(enumerate(scheduler.timesteps)):
        model_input = scheduler.scale_model_input(x, t)

        # Predict the noise residual
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        cond_grad = 0
        for cut in range(n_cuts):
            # Set requires_grad on x
            x = x.detach().requires_grad_()

            # Get the predicted x0
            x0 = scheduler.step(noise_pred, t, x).pred_original_sample

            # Calculate loss
            loss = clip_loss(x0, text_features) * guidance_scale

            # Get gradient (scale by n_cuts since we want the average;
            # note tfms has no random crop here, so every cut sees the same view)
            cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts

        # Modify x based on this gradient
        alpha_bar = scheduler.alphas_cumprod[i]
        x = x.detach() + cond_grad * alpha_bar.sqrt()

        # Now step with the scheduler
        x = scheduler.step(noise_pred, t, x).prev_sample

    # Map from [-1, 1] to [0, 1] and convert the tensor to a PIL Image
    x = x.squeeze(0).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    x = (x * 255).byte().numpy()
    return Image.fromarray(x)
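
# Hypothetical standalone usage (outside the Gradio app), assuming a GPU so that
# sampling finishes in a reasonable time:
#
#   img = generate_image("Red Rose (still life), red flower painting", guidance_scale=8, num_steps=50)
#   img.save("rose.png")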

# Gradio interface
def gradio_interface(prompt, guidance_scale, num_steps):
    return generate_image(prompt, guidance_scale, num_steps)

iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Prompt", value="Red Rose (still life), red flower painting"),
        gr.Slider(minimum=1, maximum=20, step=1, label="Guidance Scale", value=8),
        gr.Slider(minimum=10, maximum=100, step=10, label="Number of Steps", value=50),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="CLIP-Guided Diffusion Image Generation",
    description="Generate images using CLIP-guided diffusion. Enter a prompt, adjust the guidance scale, and set the number of steps.",
    examples=[
        ["A serene landscape with mountains and a lake", 10, 20],
        ["A futuristic cityscape at night", 15, 50],
        ["Red Rose (still life), red flower painting", 5, 50],
    ],
)

iface.launch()