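"""Gradio app for CLIP-guided diffusion sampling.

A DDPM fine-tuned on anime art (muneebable/ddpm-celebahq-finetuned-anime-art) is
sampled with a DDIM scheduler; at each denoising step the intermediate sample is
nudged toward a text prompt using the gradient of a CLIP-based loss.
"""
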
import gradio as gr
import numpy as np
import torch
from diffusers import DDPMPipeline, DDIMScheduler
import open_clip
import torchvision
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
# Initialize device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load CLIP model
clip_model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
clip_model.to(device)
# Transform to preprocess images
tfms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])

# CLIP loss: cosine distance between the CLIP image and text embeddings
def clip_loss(image, text_features):
    # Ensure image is in the correct format (B, C, H, W)
    if image.dim() == 3:
        image = image.unsqueeze(0)
    # Resize and normalize for CLIP
    image = tfms(image)
    # Encode the image and compare it to the text embedding
    image_features = clip_model.encode_image(image)
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    loss = (1 - torch.cosine_similarity(image_features, text_features)).mean()
    return loss

# Load the fine-tuned diffusion model
model_repo_id = "muneebable/ddpm-celebahq-finetuned-anime-art"
image_pipe = DDPMPipeline.from_pretrained(model_repo_id)
image_pipe.to(device)

# Load a DDIM scheduler so we can sample in far fewer steps than the original DDPM schedule
scheduler = DDIMScheduler.from_pretrained(model_repo_id)
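
# CLIP guidance: at each denoising step, predict the clean image x0, score it
# against the prompt embedding with clip_loss, and nudge the noisy sample x
# along the negative gradient of that loss before taking the scheduler step.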
def generate_image(prompt, guidance_scale, num_steps):
    scheduler.set_timesteps(num_inference_steps=num_steps)

    # Embed the prompt with CLIP as our guidance target
    text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = clip_model.encode_text(text)

    # Start from pure noise
    x = torch.randn(1, 3, 256, 256).to(device)
    n_cuts = 4

    for i, t in tqdm(enumerate(scheduler.timesteps)):
        model_input = scheduler.scale_model_input(x, t)

        # Predict the noise residual
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        cond_grad = 0
        for cut in range(n_cuts):
            # Set requires_grad on x
            x = x.detach().requires_grad_()
            # Get the predicted x0
            x0 = scheduler.step(noise_pred, t, x).pred_original_sample
            # Calculate the CLIP loss
            loss = clip_loss(x0, text_features) * guidance_scale
            # Get the gradient (divide by n_cuts since we want the average)
            cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts

        # Modify x based on this gradient, scaled by sqrt(alpha_bar)
        alpha_bar = scheduler.alphas_cumprod[i]
        x = x.detach() + cond_grad * alpha_bar.sqrt()

        # Now take a step with the scheduler
        x = scheduler.step(noise_pred, t, x).prev_sample

    # Map from [-1, 1] to [0, 1] and convert the tensor to a PIL Image
    x = x.squeeze(0).permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    x = (x * 255).byte().numpy()
    return Image.fromarray(x)

# Gradio interface
def gradio_interface(prompt, guidance_scale, num_steps):
    return generate_image(prompt, guidance_scale, num_steps)


iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Prompt", value="Red Rose (still life), red flower painting"),
        gr.Slider(minimum=1, maximum=20, step=1, label="Guidance Scale", value=8),
        gr.Slider(minimum=10, maximum=100, step=10, label="Number of Steps", value=50)
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="CLIP-Guided Diffusion Image Generation",
    description="Generate images using CLIP-guided diffusion. Enter a prompt, adjust the guidance scale, and set the number of steps.",
    examples=[
["A serene landscape with mountains and a lake", 10, 2],
["A futuristic cityscape at night", 15, 5],
["Red Rose (still life), red flower painting", 5, 5]
    ]
)
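
# Note: pass share=True to launch() for a temporary public URL when running outside Hugging Face Spaces.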
iface.launch()