import os
import pdb
import argparse

import gradio as gr
import numpy as np
import torch
import requests
from PIL import Image

from transformers import AutoProcessor, BlipForConditionalGeneration
from diffusers import UNet2DConditionModel, DDIMScheduler

from src.utils.ddim_inv import DDIMInversion
from src.utils.scheduler import DDIMInverseScheduler
from src.utils.edit_directions import construct_direction, construct_direction_prompts
from src.utils.edit_pipeline import EditingPipeline
#from src.make_edit_direction import load_sentence_embeddings

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is only used on GPU: half-precision inference is poorly supported on
# CPU in PyTorch, so fall back to float32 there.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Directory containing this script. Asset paths are resolved against it so the
# app does not depend on the current working directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
# Precomputed embeddings used by "Quality mode" (one `<name>.pt` file per concept).
_EMBEDDINGS_DIR = os.path.join(_APP_DIR, "assets", "embeddings_sd_1.4")

blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large.to(device)

# One UNet instance is passed to both pipelines so its weights are loaded only once.
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch_dtype
).to(device)

pipe_inversion = DDIMInversion.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype, unet=unet
).to(device)
pipe_inversion.scheduler = DDIMInverseScheduler.from_config(pipe_inversion.scheduler.config)

pipe_editing = EditingPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype, unet=unet
).to(device)
pipe_editing.scheduler = DDIMScheduler.from_config(pipe_editing.scheduler.config)


def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
    """Encode sentences with a text encoder and return their mean embedding.

    Args:
        l_sentences: list of prompt strings.
        tokenizer: tokenizer paired with ``text_encoder``. Each sentence is
            padded/truncated to ``tokenizer.model_max_length``.
        text_encoder: text encoder whose first output is used as the embedding.
        device: device the token ids are moved to before encoding. It must
            match the device ``text_encoder`` lives on.

    Returns:
        The per-sentence encoder outputs concatenated along dim 0, averaged
        over that dim, and given a leading dimension of size 1.
    """
    with torch.no_grad():
        l_embeddings = []
        for sent in l_sentences:
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
    return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)


def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
    """Caption ``image`` with a BLIP-style captioning model.

    Args:
        processor: image processor that produces ``pixel_values``. It is also
            used for decoding when no ``tokenizer`` is given.
        model: captioning model exposing ``generate(pixel_values=..., max_length=...)``.
        image: input image (PIL image in this app).
        tokenizer: optional tokenizer used to decode instead of ``processor``.
        use_float_16: cast the processed inputs to float16 before generation.

    Returns:
        The first decoded caption, with special tokens removed.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    if use_float_16:
        inputs = inputs.to(torch.float16)

    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)

    decoder = tokenizer if tokenizer is not None else processor
    return decoder.batch_decode(generated_ids, skip_special_tokens=True)[0]


def generate_inversion(prompt, image, num_ddim_steps=50):
    """DDIM-invert ``image`` (resized to 512x512) conditioned on ``prompt``.

    Returns the first element of the inverted latents returned by the
    inversion pipeline.
    """
    image = image.resize((512, 512), Image.Resampling.LANCZOS)
    x_inv, _x_inv_image, _x_dec_img = pipe_inversion(
        prompt,
        guidance_scale=1,
        num_inversion_steps=num_ddim_steps,
        img=image,
        torch_dtype=torch_dtype,
    )
    return x_inv[0]


def _require_image(image):
    """Raise a user-facing Gradio error when no image has been uploaded."""
    if image is None:
        raise gr.Error("Please upload an image first.")


def _embedding_path(name):
    """Return the path of the precomputed embedding file for concept ``name``.

    ``name`` may come from free-text input or a direct API call. It is therefore
    restricted to a bare file name: anything containing a path component is
    rejected, so no file outside ``_EMBEDDINGS_DIR`` can be loaded.

    Raises:
        gr.Error: if the name is empty or unsafe, or if no embedding file exists for it.
    """
    if not name or os.path.basename(name) != name:
        raise gr.Error(f"Invalid edit direction name: {name!r}")
    path = os.path.join(_EMBEDDINGS_DIR, f"{name}.pt")
    if not os.path.isfile(path):
        raise gr.Error(f"No precomputed embeddings found for {name!r}.")
    return path


def run_captioning(image):
    """Gradio handler: return a stripped BLIP caption for ``image``."""
    _require_image(image)
    return generate_caption(blip_processor_large, blip_model_large, image).strip()


def run_editing(image, original_prompt, edit_prompt, ddim_steps=50, xa_guidance=0.1, negative_guidance_scale=5.0):
    """Gradio "Fast mode" handler: edit ``image`` from ``original_prompt`` towards ``edit_prompt``.

    The edit direction is built from the text-encoder embeddings of the two
    prompts via ``construct_direction_prompts``. Returns the first edited image.
    """
    _require_image(image)
    num_steps = int(ddim_steps)
    # Invert with the same number of steps used for editing, as the reference
    # pix2pix-zero script does.
    inverted_noise = generate_inversion(original_prompt, image, num_ddim_steps=num_steps)

    # Use the module-level `device` (not a hard-coded "cuda") so this also
    # works on the CPU fallback.
    source_prompt_embeddings = load_sentence_embeddings(
        [original_prompt], pipe_editing.tokenizer, pipe_editing.text_encoder, device=device
    )
    target_prompt_embeddings = load_sentence_embeddings(
        [edit_prompt], pipe_editing.tokenizer, pipe_editing.text_encoder, device=device
    )
    edit_dir = construct_direction_prompts(source_prompt_embeddings, target_prompt_embeddings)

    _rec_pil, edit_pil = pipe_editing(
        original_prompt,
        num_inference_steps=num_steps,
        x_in=inverted_noise.unsqueeze(0),
        edit_dir=edit_dir,
        guidance_amount=xa_guidance,
        guidance_scale=negative_guidance_scale,
        negative_prompt=original_prompt,  # use the unedited prompt for the negative prompt
    )
    return edit_pil[0]


def run_editing_quality(image, item_from, item_from_other, item_to, item_to_other, ddim_steps=50, xa_guidance=0.1, negative_guidance_scale=5.0):
    """Gradio "Quality mode" handler: edit ``image`` using precomputed concept embeddings.

    The image is auto-captioned and inverted. The edit direction is the mean of
    the target concept's stored embeddings minus the mean of the source
    concept's. A non-blank ``*_other`` text overrides the matching dropdown
    value. Returns the first edited image.
    """
    _require_image(image)
    item_from_selected = (item_from_other or "").strip() or item_from
    item_to_selected = (item_to_other or "").strip() or item_to
    # Validate the names before the expensive captioning/inversion work.
    path_from = _embedding_path(item_from_selected)
    path_to = _embedding_path(item_to_selected)

    num_steps = int(ddim_steps)
    caption = generate_caption(blip_processor_large, blip_model_large, image).strip()
    inverted_noise = generate_inversion(caption, image, num_ddim_steps=num_steps)

    # NOTE(security): torch.load unpickles its input. Only files inside
    # _EMBEDDINGS_DIR, validated by _embedding_path, reach this call.
    embs_a = torch.load(path_from, map_location=device)
    embs_b = torch.load(path_to, map_location=device)
    edit_dir = (embs_b.mean(0) - embs_a.mean(0)).unsqueeze(0)

    _rec_pil, edit_pil = pipe_editing(
        caption,
        num_inference_steps=num_steps,
        x_in=inverted_noise.unsqueeze(0),
        edit_dir=edit_dir,
        guidance_amount=xa_guidance,
        guidance_scale=negative_guidance_scale,
        negative_prompt=caption,  # use the unedited prompt for the negative prompt
    )
    return edit_pil[0]


css = '''
#generate_button{height: 100%}
#quality_description{text-align: center; margin-top: 1em}
'''

# Concepts offered in Quality mode. Each one needs a `<name>.pt` file in _EMBEDDINGS_DIR.
QUALITY_CHOICES = ["cat", "dog", "horse", "zebra"]


def _show_if_other(choice):
    """Show the free-text box only when the dropdown value is "Other".

    NOTE: "Other" is not currently in QUALITY_CHOICES, so the boxes stay hidden.
    """
    return gr.Textbox.update(visible=choice == "Other")


with gr.Blocks(css=css) as demo:
    gr.Markdown('''## Edit with Words - Pix2Pix Zero demo
Upload an image to edit it. You can try `Fast mode` with prompts, or `Quality mode` with custom edit directions.
''')
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload your image", type="pil", shape=(512, 512))
            with gr.Tabs():
                with gr.TabItem("Fast mode"):
                    with gr.Row():
                        with gr.Column(scale=10):
                            original_prompt = gr.Textbox(label="Image description - type a caption for the image or generate it")
                        with gr.Column(scale=1, min_width=180):
                            btn_caption = gr.Button("Generate caption", elem_id="generate_button")
                    edit_prompt = gr.Textbox(label="Edit prompt - what would you like to edit in the image above. Change one thing at a time")
                    btn_edit_fast = gr.Button("Edit Image")
                with gr.TabItem("Quality mode"):
                    gr.Markdown("Quality mode temporarily set to only 4 categories. Soon to be dynamic, so you can create your own edit directions.", elem_id="quality_description")
                    with gr.Row():
                        with gr.Column():
                            item_from = gr.Dropdown(label="What to identify in your image", choices=QUALITY_CHOICES, value="cat")
                            item_from_other = gr.Textbox(visible=False, label="Type what to identify on your image")
                            item_from.change(_show_if_other, item_from, item_from_other)
                        with gr.Column():
                            item_to = gr.Dropdown(label="What to replace what you identified for", choices=QUALITY_CHOICES, value="dog")
                            item_to_other = gr.Textbox(visible=False, label="Type what to replace what you identified for")
                            item_to.change(_show_if_other, item_to, item_to_other)
                    btn_edit_quality = gr.Button("Edit Image")
            with gr.Accordion(label="Advanced settings", open=False):
                steps = gr.Slider(minimum=2, maximum=50, step=1, value=50, label="Inference Steps")
                xa_guidance = gr.Slider(minimum=0.0, maximum=10.0, step=0.05, value=0.1, label="xa guidance")
                negative_scale = gr.Slider(minimum=0.0, maximum=20.0, step=0.1, value=5.0, label="Negative Guidance Scale")
        with gr.Column():
            image_output = gr.Image(label="Image with edits", type="pil", shape=(512, 512))

    btn_caption.click(fn=run_captioning, inputs=image, outputs=original_prompt)
    btn_edit_fast.click(
        fn=run_editing,
        inputs=[image, original_prompt, edit_prompt, steps, xa_guidance, negative_scale],
        outputs=[image_output],
    )
    btn_edit_quality.click(
        fn=run_editing_quality,
        inputs=[image, item_from, item_from_other, item_to, item_to_other, steps, xa_guidance, negative_scale],
        outputs=[image_output],
    )

    gr.Examples(
        examples=[
            [os.path.join(_APP_DIR, "assets/test_images/cats/cat_1.png"), "cat", "", "dog", ""],
            [os.path.join(_APP_DIR, "assets/test_images/cats/cat_2.png"), "cat", "", "horse", ""],
            [os.path.join(_APP_DIR, "assets/test_images/dogs/dog_1.png"), "dog", "", "horse", ""],
            [os.path.join(_APP_DIR, "assets/test_images/dogs/dog_2.png"), "dog", "", "cat", ""],
        ],
        inputs=[image, item_from, item_from_other, item_to, item_to_other],
        outputs=image_output,
        fn=run_editing_quality,
        cache_examples=True,
    )

demo.launch()