# pix2pix-zero / app.py
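# Gradio demo for pix2pix-zero: captions the input image with BLIP, inverts it with DDIM,
# then re-synthesizes it along a text-derived edit direction with cross-attention guidance
# on Stable Diffusion v1.4.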
import os

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration
from diffusers import UNet2DConditionModel, DDIMScheduler
from src.utils.ddim_inv import DDIMInversion
from src.utils.scheduler import DDIMInverseScheduler
from src.utils.edit_directions import construct_direction, construct_direction_prompts
from src.utils.edit_pipeline import EditingPipeline
# load_sentence_embeddings is defined locally below instead of being imported from src.make_edit_direction
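# Model setup: BLIP captioner plus Stable Diffusion v1.4 inversion and editing pipelines.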
torch_dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large.to(device)
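# A single fp16 UNet is shared by the DDIM-inversion and editing pipelines below.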
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to(device)
pipe_inversion = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype, unet=unet).to(device)
pipe_inversion.scheduler = DDIMInverseScheduler.from_config(pipe_inversion.scheduler.config)
pipe_editing = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype, unet=unet).to(device)
pipe_editing.scheduler = DDIMScheduler.from_config(pipe_editing.scheduler.config)
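# Mean CLIP text embedding over a list of sentences; used to build edit directions from prompts.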
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
    with torch.no_grad():
        l_embeddings = []
        for sent in l_sentences:
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
    return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
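# Caption an image with BLIP; decodes with the processor unless a separate tokenizer is passed.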
def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
    inputs = processor(images=image, return_tensors="pt").to(device)
    if use_float_16:
        inputs = inputs.to(torch.float16)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
    if tokenizer is not None:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption
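# Resize to 512x512 and run DDIM inversion to recover the initial noise latent for the image.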
def generate_inversion(prompt, image, num_ddim_steps=50):
    image = image.resize((512, 512), Image.Resampling.LANCZOS)
    x_inv, x_inv_image, x_dec_img = pipe_inversion(
        prompt,
        guidance_scale=1,
        num_inversion_steps=num_ddim_steps,
        img=image,
        torch_dtype=torch_dtype
    )
    return x_inv[0]
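# Gradio callback: generate a caption for the uploaded image.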
def run_captioning(image):
    caption = generate_caption(blip_processor_large, blip_model_large, image).strip()
    return caption
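# Fast mode: builds the edit direction from the source- and edit-prompt embeddings on the fly.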
def run_editing(image, original_prompt, edit_prompt, ddim_steps=50, xa_guidance=0.1, negative_guidance_scale=5.0):
    inverted_noise = generate_inversion(original_prompt, image)
    source_prompt_embeddings = load_sentence_embeddings([original_prompt], pipe_editing.tokenizer, pipe_editing.text_encoder, device=device)
    target_prompt_embeddings = load_sentence_embeddings([edit_prompt], pipe_editing.tokenizer, pipe_editing.text_encoder, device=device)
    rec_pil, edit_pil = pipe_editing(
        original_prompt,
        num_inference_steps=ddim_steps,
        x_in=inverted_noise.unsqueeze(0),
        edit_dir=construct_direction_prompts(source_prompt_embeddings, target_prompt_embeddings),
        guidance_amount=xa_guidance,
        guidance_scale=negative_guidance_scale,
        negative_prompt=original_prompt  # use the unedited prompt for the negative prompt
    )
    return edit_pil[0]
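# Quality mode: the edit direction comes from precomputed embedding banks in assets/embeddings_sd_1.4.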
def run_editing_quality(image, item_from, item_from_other, item_to, item_to_other, ddim_steps=50, xa_guidance=0.1, negative_guidance_scale=5.0):
    caption = generate_caption(blip_processor_large, blip_model_large, image).strip()
    item_from_selected = item_from if item_from_other == "" else item_from_other
    item_to_selected = item_to if item_to_other == "" else item_to_other
    inverted_noise = generate_inversion(caption, image)
    emb_dir = "assets/embeddings_sd_1.4"
    embs_a = torch.load(os.path.join(emb_dir, f"{item_from_selected}.pt"), map_location=device)
    embs_b = torch.load(os.path.join(emb_dir, f"{item_to_selected}.pt"), map_location=device)
    edit_dir = (embs_b.mean(0) - embs_a.mean(0)).unsqueeze(0)
    rec_pil, edit_pil = pipe_editing(
        caption,
        num_inference_steps=ddim_steps,
        x_in=inverted_noise.unsqueeze(0),
        edit_dir=edit_dir,
        guidance_amount=xa_guidance,
        guidance_scale=negative_guidance_scale,
        negative_prompt=caption  # use the unedited prompt for the negative prompt
    )
    return edit_pil[0]
css = '''
#generate_button{height: 100%}
#quality_description{text-align: center; margin-top: 1em}
'''
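# Gradio interface: "Fast mode" edits from a prompt pair, "Quality mode" from precomputed directions.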
with gr.Blocks(css=css) as demo:
    gr.Markdown('''## Edit with Words - Pix2Pix Zero demo
Upload an image to edit it. You can try `Fast mode` with prompts, or `Quality mode` with precomputed edit directions.
''')
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload your image", type="pil", shape=(512, 512))
            with gr.Tabs():
                with gr.TabItem("Fast mode"):
                    with gr.Row():
                        with gr.Column(scale=10):
                            original_prompt = gr.Textbox(label="Image description - type a caption for the image or generate it")
                        with gr.Column(scale=1, min_width=180):
                            btn_caption = gr.Button("Generate caption", elem_id="generate_button")
                    edit_prompt = gr.Textbox(label="Edit prompt - what would you like to edit in the image above. Change one thing at a time")
                    btn_edit_fast = gr.Button("Edit Image")
with gr.TabItem("Quality mode"):
gr.Markdown("Quality mode temporarely set to only 4 categories. Soon to be dynamic, so you can create your own edit directions.", elem_id="quality_description")
with gr.Row():
with gr.Column():
item_from = gr.Dropdown(label="What to identify in your image", choices=["cat", "dog", "horse", "zebra"], value="cat")
item_from_other = gr.Textbox(visible=False, label="Type what to identify on your image")
item_from.change(lambda choice: gr.Dropdown.update(visible=choice=="Other"), item_from, item_from_other)
with gr.Column():
item_to = gr.Dropdown(label="What to replace what you identified for", choices=["cat", "dog", "horse", "zebra"], value="dog")
item_to_other = gr.Textbox(visible=False, label="Type what to replace what you identified for")
item_to.change(lambda choice: gr.Dropdown.update(visible=choice=="Other"), item_to, item_to_other)
btn_edit_quality = gr.Button("Edit Image")
            with gr.Accordion(label="Advanced settings", open=False):
                steps = gr.Slider(minimum=2, maximum=50, step=1, value=50, label="Inference Steps")
                xa_guidance = gr.Slider(minimum=0.0, maximum=10.0, step=0.05, value=0.1, label="Cross-attention (xa) guidance")
                negative_scale = gr.Slider(minimum=0.0, maximum=20.0, step=0.1, value=5.0, label="Negative Guidance Scale")
        with gr.Column():
            image_output = gr.Image(label="Image with edits", type="pil", shape=(512, 512))
    btn_caption.click(fn=run_captioning, inputs=image, outputs=original_prompt)
    btn_edit_fast.click(fn=run_editing, inputs=[image, original_prompt, edit_prompt, steps, xa_guidance, negative_scale], outputs=[image_output])
    btn_edit_quality.click(fn=run_editing_quality, inputs=[image, item_from, item_from_other, item_to, item_to_other, steps, xa_guidance, negative_scale], outputs=[image_output])
    gr.Examples(
        examples=[
            [os.path.join(os.path.dirname(__file__), "assets/test_images/cats/cat_1.png"), "cat", "", "dog", ""],
            [os.path.join(os.path.dirname(__file__), "assets/test_images/cats/cat_2.png"), "cat", "", "horse", ""],
            [os.path.join(os.path.dirname(__file__), "assets/test_images/dogs/dog_1.png"), "dog", "", "horse", ""],
            [os.path.join(os.path.dirname(__file__), "assets/test_images/dogs/dog_2.png"), "dog", "", "cat", ""],
        ],
        inputs=[image, item_from, item_from_other, item_to, item_to_other],
        outputs=image_output,
        fn=run_editing_quality,
        cache_examples=True,
    )
demo.launch()