multimodalart committed
Commit 98b9d34 · 1 Parent(s): d40f6d4

Demo first version

Files changed (2)
  1. app.py +142 -14
  2. src/utils/edit_directions.py +3 -0
app.py CHANGED
@@ -9,24 +9,152 @@ import torch
 import requests
 from PIL import Image
 
-from lavis.models import load_model_and_preprocess
+from transformers import AutoProcessor, BlipForConditionalGeneration
 
+from diffusers import UNet2DConditionModel, DDIMScheduler
 from src.utils.ddim_inv import DDIMInversion
 from src.utils.scheduler import DDIMInverseScheduler
+from src.utils.edit_directions import construct_direction, construct_direction_prompts
+from src.utils.edit_pipeline import EditingPipeline
+#from src.make_edit_direction import load_sentence_embeddings
 
-def update(name):
-    os.system('''python src/inversion.py \
-            --input_image "assets/test_images/cats/cat_1.png" \
-            --results_folder "output/test_cat"
-            ''')
-    return f"Inverted!"
+torch_dtype = torch.float16
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-with gr.Blocks() as demo:
-    gr.Markdown("Start typing below and then click **Run** to see the output.")
-    with gr.Row():
-        inp = gr.Textbox(placeholder="Do you want to invert?")
-        out = gr.Textbox()
-    btn = gr.Button("Run")
-    btn.click(fn=update, inputs=inp, outputs=out)
+blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+blip_model_large.to(device)
+
+unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to(device)
+pipe_inversion = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype, unet=unet).to(device)
+pipe_inversion.scheduler = DDIMInverseScheduler.from_config(pipe_inversion.scheduler.config)
+
+pipe_editing = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype, unet=unet).to(device)
+pipe_editing.scheduler = DDIMScheduler.from_config(pipe_editing.scheduler.config)
+
+def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
+    with torch.no_grad():
+        l_embeddings = []
+        for sent in l_sentences:
+            text_inputs = tokenizer(
+                sent,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
+            l_embeddings.append(prompt_embeds)
+    return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
+
+def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
+    inputs = processor(images=image, return_tensors="pt").to(device)
+
+    if use_float_16:
+        inputs = inputs.to(torch.float16)
+
+    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
 
+    if tokenizer is not None:
+        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    else:
+        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+    return generated_caption
+
+def generate_inversion(prompt, image, num_ddim_steps=50):
+    image = image.resize((512, 512), Image.Resampling.LANCZOS)
+    x_inv, x_inv_image, x_dec_img = pipe_inversion(
+        prompt,
+        guidance_scale=1,
+        num_inversion_steps=num_ddim_steps,
+        img=image,
+        torch_dtype=torch_dtype
+    )
+    return x_inv[0]
+
+def swap_blip_model_cpu_gpu(device_to):
+    if torch.cuda.is_available():
+        blip_model_large.to(device_to)
+
+def run_captioning(image):
+    caption = generate_caption(blip_processor_large, blip_model_large, image).strip()
+    swap_blip_model_cpu_gpu("cpu")
+    return caption
+
+def run_editing(image, original_prompt, edit_prompt, ddim_steps=50, xa_guidance=0.1, negative_guidance_scale=5.0):
+    inverted_noise = generate_inversion(original_prompt, image)
+    source_prompt_embeddings = load_sentence_embeddings([original_prompt], pipe_editing.tokenizer, pipe_editing.text_encoder, device="cuda")
+    target_prompt_embeddings = load_sentence_embeddings([edit_prompt], pipe_editing.tokenizer, pipe_editing.text_encoder, device="cuda")
+    rec_pil, edit_pil = pipe_editing(
+        original_prompt,
+        num_inference_steps=ddim_steps,
+        x_in=inverted_noise.unsqueeze(0),
+        edit_dir=construct_direction_prompts(source_prompt_embeddings, target_prompt_embeddings),
+        guidance_amount=xa_guidance,
+        guidance_scale=negative_guidance_scale,
+        negative_prompt=original_prompt  # use the unedited prompt for the negative prompt
+    )
+    return edit_pil[0]
+
+def run_editing_quality(image, item_from, item_from_other, item_to, item_to_other, ddim_steps=50, xa_guidance=0.1, negative_guidance_scale=5.0):
+    caption = generate_caption(blip_processor_large, blip_model_large, image).strip()
+    item_from_selected = item_from if item_from_other == "" else item_from_other
+    item_to_selected = item_to if item_to_other == "" else item_to_other
+    inverted_noise = generate_inversion(caption, image)
+    emb_dir = f"assets/embeddings_sd_1.4"
+    embs_a = torch.load(os.path.join(emb_dir, f"{item_from_selected}.pt"))
+    embs_b = torch.load(os.path.join(emb_dir, f"{item_to_selected}.pt"))
+    edit_dir = (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
+    rec_pil, edit_pil = pipe_editing(
+        caption,
+        num_inference_steps=ddim_steps,
+        x_in=inverted_noise.unsqueeze(0),
+        edit_dir=edit_dir,
+        guidance_amount=xa_guidance,
+        guidance_scale=negative_guidance_scale,
+        negative_prompt=caption  # use the unedited caption for the negative prompt
+    )
+    return edit_pil[0]
+css = '''
+#generate_button{height: 100%}
+#quality_description{text-align: center; margin-top: 1em}
+'''
+with gr.Blocks(css=css) as demo:
+    gr.Markdown('''## Edit with Words - Pix2Pix Zero demo
+    Upload an image to edit it. You can try `Fast mode` with prompts, or `Quality mode` with custom edit directions.
+    ''')
+    with gr.Row():
+        with gr.Column():
+            image = gr.Image(label="Upload your image", type="pil", shape=(512, 512))
+            with gr.Tabs():
+                with gr.TabItem("Fast mode"):
+                    with gr.Row():
+                        with gr.Column(scale=10):
+                            original_prompt = gr.Textbox(label="Image description - either type a caption for the image above or generate it")
+                        with gr.Column(scale=1, min_width=180):
+                            btn_caption = gr.Button("Generate caption", elem_id="generate_button")
+                    edit_prompt = gr.Textbox(label="Edit prompt - what you would like to edit in the image above. Change one thing at a time")
+                    btn_edit_fast = gr.Button("Edit Image")
+                with gr.TabItem("Quality mode"):
+                    gr.Markdown("Quality mode is temporarily limited to 4 categories. It will soon be dynamic, so you can create your own edit directions.", elem_id="quality_description")
+                    with gr.Row():
+                        with gr.Column():
+                            item_from = gr.Dropdown(label="What to identify in your image", choices=["cat", "dog", "horse", "zebra"], value="cat")
+                            item_from_other = gr.Textbox(visible=False, label="Type what to identify in your image")
+                            item_from.change(lambda choice: gr.Dropdown.update(visible=choice == "Other"), item_from, item_from_other)
+                        with gr.Column():
+                            item_to = gr.Dropdown(label="What to replace what you identified with", choices=["cat", "dog", "horse", "zebra"], value="dog")
+                            item_to_other = gr.Textbox(visible=False, label="Type what to replace what you identified with")
+                            item_to.change(lambda choice: gr.Dropdown.update(visible=choice == "Other"), item_to, item_to_other)
+                    btn_edit_quality = gr.Button("Edit Image")
+
+        with gr.Column():
+            image_output = gr.Image(label="Image with edits", type="pil", shape=(512, 512))
+
+    btn_caption.click(fn=run_captioning, inputs=image, outputs=original_prompt)
+
+    btn_edit_fast.click(fn=run_editing, inputs=[image, original_prompt, edit_prompt], outputs=[image_output])
+    btn_edit_quality.click(fn=run_editing_quality, inputs=[image, item_from, item_from_other, item_to, item_to_other], outputs=[image_output])
 demo.launch()
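
Note: the Quality mode above loads precomputed sentence-embedding files from `assets/embeddings_sd_1.4/{item}.pt` and averages them with `.mean(0)`. Those files are not part of this commit; the sketch below shows one way such a file could be produced with the same tokenizer and text encoder. The helper name, the example sentences, and the assumption that the stored tensor is simply a stack of per-sentence embeddings are illustrative, not taken from the repo.

```python
import os
import torch

# Hypothetical helper (not in this commit): save stacked per-sentence CLIP text embeddings
# so that torch.load(...).mean(0) matches the averaging done in run_editing_quality above.
def save_item_embeddings(item, sentences, tokenizer, text_encoder,
                         out_dir="assets/embeddings_sd_1.4", device="cuda"):
    embs = []
    with torch.no_grad():
        for sent in sentences:
            tokens = tokenizer(sent, padding="max_length", max_length=tokenizer.model_max_length,
                               truncation=True, return_tensors="pt")
            # last_hidden_state of the CLIP text encoder: [1, 77, 768] for SD 1.4
            embs.append(text_encoder(tokens.input_ids.to(device))[0])
    os.makedirs(out_dir, exist_ok=True)
    torch.save(torch.cat(embs, dim=0), os.path.join(out_dir, f"{item}.pt"))

# Illustrative usage with the pipeline objects defined in app.py:
# save_item_embeddings("dog", ["a photo of a dog", "a dog sitting on the grass"],
#                      pipe_editing.tokenizer, pipe_editing.text_encoder)
```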
src/utils/edit_directions.py CHANGED
@@ -27,3 +27,6 @@ def construct_direction(task_name):
         return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
     else:
         raise NotImplementedError
+
+def construct_direction_prompts(source_prompt_embeddings, target_prompt_embeddings):
+    return (target_prompt_embeddings.mean(0)-source_prompt_embeddings.mean(0)).unsqueeze(0)
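
For context, `construct_direction_prompts` mirrors the precomputed-embedding path in `construct_direction`, but operates on embeddings computed on the fly for a single source/target prompt pair. A minimal usage sketch, reusing the `load_sentence_embeddings` helper and `pipe_editing` objects defined in `app.py` above (the prompts are illustrative):

```python
# Sketch: build an edit direction directly from two prompts instead of precomputed .pt files.
source_embs = load_sentence_embeddings(["a photo of a cat"], pipe_editing.tokenizer, pipe_editing.text_encoder)
target_embs = load_sentence_embeddings(["a photo of a dog"], pipe_editing.tokenizer, pipe_editing.text_encoder)
edit_dir = construct_direction_prompts(source_embs, target_embs)  # same [1, 77, 768] shape as construct_direction("cat2dog")
```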