# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
from PIL import Image
import torch
import gradio as gr
from lavis.models import load_model_and_preprocess
from diffusers import DDIMScheduler
from src.utils.ddim_inv import DDIMInversion
from src.utils.edit_directions import construct_direction
from src.utils.scheduler import DDIMInverseScheduler
from src.utils.edit_pipeline import EditingPipeline
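# Gradio demo for pix2pix-zero: caption the input with BLIP, invert it to a noise latent with DDIM,
# then regenerate while steering cross-attention along a precomputed edit direction (e.g. cat -> dog).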
def main():
NUM_DDIM_STEPS = 50
TORCH_DTYPE = torch.float16
XA_GUIDANCE = 0.1
DIR_SCALE = 1.0
MODEL_NAME = 'CompVis/stable-diffusion-v1-4'
NEGATIVE_GUIDANCE_SCALE = 5.0
DEVICE = "cuda"
# if torch.cuda.is_available():
# DEVICE = "cuda"
# else:
# DEVICE = "cpu"
# print(f"Using {DEVICE}")
model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=DEVICE)
pipe = EditingPipeline.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to(DEVICE)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
inv_pipe = DDIMInversion.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to(DEVICE)
inv_pipe.scheduler = DDIMInverseScheduler.from_config(inv_pipe.scheduler.config)
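# Internal task names (used to look up edit directions) and the labels shown in the UI radio buttons.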
TASKS = ["dog2cat","cat2dog","horse2zebra","zebra2horse","horse2llama","dog2capy"]
TASK_OPTIONS = ["Dog to Cat", "Cat to Dog", "Horse to Zebra", "Zebra to Horse", "Horse to Llama", "Dog to Capy"]
def edit_real_image(
og_img,
task,
seed,
xa_guidance,
num_ddim_steps,
dir_scale
):
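# Edit a real image: resize to 512x512, caption it with BLIP, invert it to a noise latent,
# then regenerate with the chosen edit direction applied.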
torch.cuda.manual_seed(seed)
# do inversion first, get inversion and generated prompt
curr_img = og_img.resize((512,512), Image.Resampling.LANCZOS)
_image = vis_processors["eval"](curr_img).unsqueeze(0).to(DEVICE)
prompt_str = model_blip.generate({"image": _image})[0]
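# Run DDIM inversion at guidance scale 1 to recover the noise latent that reconstructs the input image.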
x_inv, _, _ = inv_pipe(
prompt_str,
guidance_scale=1,
num_inversion_steps=num_ddim_steps,  # match the user-selected number of editing steps
img=curr_img,
torch_dtype=TORCH_DTYPE
)
task_str = TASKS[task]
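# Regenerate from the inverted latent, shifting the prompt embedding along the task's edit direction
# while cross-attention guidance preserves the original structure.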
rec_pil, edit_pil = pipe(
prompt_str,
num_inference_steps=num_ddim_steps,
x_in=x_inv[0].unsqueeze(0),
edit_dir=construct_direction(task_str)*dir_scale,
guidance_amount=xa_guidance,
guidance_scale=NEGATIVE_GUIDANCE_SCALE,
negative_prompt=prompt_str # use the unedited prompt for the negative prompt
)
return prompt_str, edit_pil[0]
def edit_real_image_example():
test_img = Image.open("./assets/test_images/cats/cat_4.png")
seed = 42
task = 1
prompt_str, edited_img = edit_real_image(test_img, task, seed, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE)
return test_img, seed, "Cat to Dog", prompt_str, edited_img, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE
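# Synthetic-image editing: sample a random latent, generate an image from the prompt,
# and apply the same edit direction to it.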
def edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps):
torch.cuda.manual_seed(seed)
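# Initial latent for Stable Diffusion v1.x at 512x512 resolution: 4 channels, 64x64 spatial.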
x = torch.randn((1,4,64,64), device=DEVICE)
task_str = TASKS[task]
rec_pil, edit_pil = pipe(
prompt_str,
num_inference_steps=num_ddim_steps,
x_in=x,
edit_dir=construct_direction(task_str),
guidance_amount=xa_guidance,
guidance_scale=NEGATIVE_GUIDANCE_SCALE,
negative_prompt="" # use the empty string for the negative prompt
)
return rec_pil[0], edit_pil[0]
def edit_synth_image_example():
seed = 42
task = 1
xa_guidance = XA_GUIDANCE
num_ddim_steps = NUM_DDIM_STEPS
prompt_str = "A cute white cat sitting on top of the fridge"
recon_img, edited_img = edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps)
return seed, "Cat to Dog", xa_guidance, num_ddim_steps, prompt_str, recon_img, edited_img
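# Build the Gradio UI: one tab for editing uploaded (real) images and one for editing synthetic images
# generated from a prompt.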
with gr.Blocks() as demo:
gr.Markdown("""
### Zero-shot Image-to-Image Translation (https://github.com/pix2pixzero/pix2pix-zero)
Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, Jun-Yan Zhu <br/>
- For real images:
- Upload an image of a dog, cat or horse,
- Choose one of the task options to turn it into another animal!
- Changing Parameters:
- Increase the direction scale if the result is not cat-like (or whichever target animal) enough.
- If the quality is not high enough, increase the number of DDIM steps.
- Increase the cross attention guidance to better preserve the original image structure. <br/>
- For synthetic images:
- Enter a prompt about dogs/cats/horses
- Choose a task option
""")
with gr.Tab("Real Image"):
with gr.Row():
seed = gr.Number(value=42, precision=0, label="Seed", interactive=True)  # precision=0 keeps the seed an integer
real_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
real_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=0, label="Num DDIM steps", interactive=True)
real_edit_dir_scale = gr.Number(value=DIR_SCALE, label="Edit Direction Scale", interactive=True)
real_generate_button = gr.Button("Generate")
real_load_sample_button = gr.Button("Load Example")
with gr.Row():
task_name = gr.Radio(
label='Task Name',
choices=TASK_OPTIONS,
value=TASK_OPTIONS[0],
type="index",
show_label=True,
interactive=True,
)
with gr.Row():
recon_text = gr.Textbox(lines=1, label="Reconstructed Text", interactive=False)
with gr.Row():
input_image = gr.Image(label="Input Image", type="pil", interactive=True)
output_image = gr.Image(label="Output Image", type="pil", interactive=False)
with gr.Tab("Synthetic Images"):
with gr.Row():
synth_seed = gr.Number(value=42, precision=0, label="Seed", interactive=True)  # precision=0 keeps the seed an integer
synth_prompt = gr.Textbox(lines=1, label="Prompt", interactive=True)
synth_generate_button = gr.Button("Generate")
synth_load_sample_button = gr.Button("Load Example")
with gr.Row():
synth_task_name = gr.Radio(
label='Task Name',
choices=TASK_OPTIONS,
value=TASK_OPTIONS[0],
type="index",
show_label=True,
interactive=True,
)
synth_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
synth_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=0, label="Num DDIM steps", interactive=True)
with gr.Row():
synth_input_image = gr.Image(label="Input Image", type="pil", interactive=False)
synth_output_image = gr.Image(label="Output Image", type="pil", interactive=False)
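# Wire the buttons: "Generate" runs the corresponding edit, "Load Example" fills the tab with a worked example.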
real_generate_button.click(
fn=edit_real_image,
inputs=[
input_image, task_name, seed, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale
],
outputs=[recon_text, output_image]
)
real_load_sample_button.click(
fn=edit_real_image_example,
inputs=[],
outputs=[input_image, seed, task_name, recon_text, output_image, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale]
)
synth_generate_button.click(
fn=edit_synthetic_image,
inputs=[synth_seed, synth_task_name, synth_prompt, synth_xa_guidance, synth_num_ddim_steps],
outputs=[synth_input_image, synth_output_image]
)
synth_load_sample_button.click(
fn=edit_synth_image_example,
inputs=[],
outputs=[synth_seed, synth_task_name, synth_xa_guidance, synth_num_ddim_steps, synth_prompt, synth_input_image, synth_output_image]
)
demo.queue(concurrency_count=1)
demo.launch(share=False, server_name="0.0.0.0")
if __name__ == "__main__":
main()