import random

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import open_clip
from diffusers import DDIMScheduler, DDPMPipeline
from PIL import Image
from tqdm import tqdm

# Initialize device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model
clip_model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
clip_model.to(device)

# Transforms to prepare the diffusion outputs for CLIP.
# Note: no ToTensor() here - the inputs are already batched tensors, not PIL images.
tfms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


# CLIP loss: 1 - cosine similarity between image and text embeddings
def clip_loss(image, text_features):
    # Map the diffusion output from [-1, 1] to [0, 1] before CLIP normalization
    image = (image / 2 + 0.5).clamp(0, 1)
    image_features = clip_model.encode_image(tfms(image))
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    loss = (1 - torch.cosine_similarity(image_features, text_features)).mean()
    return loss


# Load diffusion model
model_repo_id = "muneebable/ddpm-celebahq-finetuned-anime-art"  # Replace with desired model repo
image_pipe = DDPMPipeline.from_pretrained(model_repo_id)
image_pipe.to(device)

# Load a DDIM scheduler so we can sample in fewer steps than DDPM
scheduler = DDIMScheduler.from_pretrained(model_repo_id)
scheduler.set_timesteps(num_inference_steps=40)  # default; overridden per request in infer()


# Gradio inference function
def infer(
    prompt,
    negative_prompt,  # accepted for API symmetry; not used by this CLIP-guidance loop
    seed,
    randomize_seed,
    width,   # the UNet generates fixed 256x256 images, so width/height are ignored
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    generator = torch.manual_seed(seed)

    scheduler.set_timesteps(num_inference_steps=num_inference_steps)

    # Embed the prompt with CLIP
    text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text)

    # Start from random noise (a batch of 4 images), seeded by the generator
    x = torch.randn(4, 3, 256, 256, generator=generator).to(device)

    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        # Accumulate the CLIP-guidance gradient over several passes ("cuts")
        cond_grad = 0
        for cut in range(4):
            x = x.detach().requires_grad_()
            # Predict the denoised image x0 and score it against the prompt
            x0 = scheduler.step(noise_pred, t, x).pred_original_sample
            loss = clip_loss(x0, text_features) * guidance_scale
            cond_grad -= torch.autograd.grad(loss, x)[0] / 4

        # Nudge x along the conditioning gradient, scaled by sqrt(alpha_bar)
        alpha_bar = scheduler.alphas_cumprod[i]
        x = x.detach() + cond_grad * alpha_bar.sqrt()

        # Take the regular DDIM step
        x = scheduler.step(noise_pred, t, x).prev_sample

    # Convert the output batch to a single image grid in [0, 1]
    grid = torchvision.utils.make_grid(x.detach(), nrow=4)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    result_image = Image.fromarray((im.numpy() * 255).astype(np.uint8))

    return result_image, seed


# Gradio app
with gr.Blocks() as demo:
    prompt = gr.Textbox(placeholder="Enter your prompt", label="Prompt")
    run_button = gr.Button("Generate")
    result = gr.Image(label="Generated Image")

    with gr.Accordion("Advanced Settings"):
        negative_prompt = gr.Textbox(label="Negative Prompt")
        seed = gr.Slider(0, np.iinfo(np.int32).max, value=0, label="Seed")
        randomize_seed = gr.Checkbox(True, label="Randomize Seed")
        width = gr.Slider(256, 1024, value=512, label="Width")
        height = gr.Slider(256, 1024, value=512, label="Height")
        guidance_scale = gr.Slider(0.0, 10.0, value=7.5, label="Guidance Scale")
        num_inference_steps = gr.Slider(1, 50, value=50, label="Steps")

    run_button.click(
        infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

demo.queue().launch()