vidit98 commited on
Commit
0d2dd65
·
1 Parent(s): 9541d96

update code

Browse files
.gitattributes CHANGED
@@ -32,3 +32,37 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ assets/examples/ian.jpeg filter=lfs diff=lfs merge=lfs -text
36
+ assets/examples/resized_anm_38.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/examples/anm_8.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/examples/house.jpeg filter=lfs diff=lfs merge=lfs -text
39
+ assets/examples/door2.jpeg filter=lfs diff=lfs merge=lfs -text
40
+ assets/examples/door.jpeg filter=lfs diff=lfs merge=lfs -text
41
+ assets/examples/frn_38.jpg filter=lfs diff=lfs merge=lfs -text
42
+ assets/examples/park.webp filter=lfs diff=lfs merge=lfs -text
43
+ assets/examples/car1.webp filter=lfs diff=lfs merge=lfs -text
44
+ assets/examples/car.jpeg filter=lfs diff=lfs merge=lfs -text
45
+ assets/examples/house2.jpeg filter=lfs diff=lfs merge=lfs -text
46
+ assets/examples/Lancia.webp filter=lfs diff=lfs merge=lfs -text
47
+ assets/examples/obj_11.jpg filter=lfs diff=lfs merge=lfs -text
48
+ assets/examples/resized_anm_8.jpg filter=lfs diff=lfs merge=lfs -text
49
+ assets/examples/resized_frn_38.jpg filter=lfs diff=lfs merge=lfs -text
50
+ assets/examples/resized_obj_11.jpg filter=lfs diff=lfs merge=lfs -text
51
+ assets/examples/dog.jpeg filter=lfs diff=lfs merge=lfs -text
52
+ assets/examples/grasslands-national-park.jpeg filter=lfs diff=lfs merge=lfs -text
53
+ assets/examples/resized_obj_38.jpg filter=lfs diff=lfs merge=lfs -text
54
+ assets/examples/chair1.jpeg filter=lfs diff=lfs merge=lfs -text
55
+ assets/examples/chair.jpeg filter=lfs diff=lfs merge=lfs -text
56
+ assets/examples/obj_38.jpg filter=lfs diff=lfs merge=lfs -text
57
+ assets/examples/ran.webp filter=lfs diff=lfs merge=lfs -text
58
+ assets/examples/anm_38.jpg filter=lfs diff=lfs merge=lfs -text
59
+ assets/examples/carpet2.webp filter=lfs diff=lfs merge=lfs -text
60
+ assets/ironman.webp filter=lfs diff=lfs merge=lfs -text
61
+ assets/truck2.jpeg filter=lfs diff=lfs merge=lfs -text
62
+ assets/truck.png filter=lfs diff=lfs merge=lfs -text
63
+ assets/ski.jpg filter=lfs diff=lfs merge=lfs -text
64
+ assets/Teaser_Small.png filter=lfs diff=lfs merge=lfs -text
65
+ assets/examples filter=lfs diff=lfs merge=lfs -text
66
+ assets/GIF.gif filter=lfs diff=lfs merge=lfs -text
67
+ assets/hulk.jpeg filter=lfs diff=lfs merge=lfs -text
68
+ assets/lava.jpg filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,429 +1,484 @@
1
-
2
-
3
- import einops
4
  import gradio as gr
5
- import numpy as np
6
- import torch
7
- import random
8
- import os
9
- import subprocess
10
- import shlex
11
-
12
- from huggingface_hub import hf_hub_url, hf_hub_download
13
- from share import *
14
-
15
- from pytorch_lightning import seed_everything
16
- from annotator.util import resize_image, HWC3
17
- from annotator.OneFormer import OneformerSegmenter
18
- from cldm.model import create_model, load_state_dict
19
- from cldm.ddim_hacked import DDIMSamplerSpaCFG
20
- from ldm.models.autoencoder import DiagonalGaussianDistribution
21
-
22
- urls = {
23
- 'shi-labs/oneformer_coco_swin_large': ['150_16_swin_l_oneformer_coco_100ep.pth'],
24
- 'PAIR/PAIR-diffusion-sdv15-coco-finetune': ['pair_diffusion_epoch62.ckpt']
25
- }
26
-
27
- WTS_DICT = {
28
-
29
- }
30
-
31
- if os.path.exists('checkpoints') == False:
32
- os.mkdir('checkpoints')
33
- for repo in urls:
34
- files = urls[repo]
35
- for file in files:
36
- url = hf_hub_url(repo, file)
37
- name_ckp = url.split('/')[-1]
38
- WTS_DICT[repo] = hf_hub_download(repo_id=repo, filename=file, token=os.environ.get("ACCESS_TOKEN"))
39
-
40
- print(WTS_DICT)
41
- apply_segmentor = OneformerSegmenter(WTS_DICT['shi-labs/oneformer_coco_swin_large'])
42
-
43
- model = create_model('./configs/sap_fixed_hintnet_v15.yaml').cpu()
44
- model.load_state_dict(load_state_dict(WTS_DICT['PAIR/PAIR-diffusion-sdv15-coco-finetune'], location='cuda'))
45
- model = model.cuda()
46
- ddim_sampler = DDIMSamplerSpaCFG(model)
47
- _COLORS = []
48
- save_memory = False
49
-
50
- def gen_color():
51
- color = tuple(np.round(np.random.choice(range(256), size=3), 3))
52
- if color not in _COLORS and np.mean(color) != 0.0:
53
- _COLORS.append(color)
54
- else:
55
- gen_color()
56
-
57
-
58
- for _ in range(300):
59
- gen_color()
60
-
61
 
62
- class ImageComp:
63
- def __init__(self, edit_operation):
64
- self.input_img = None
65
- self.input_pmask = None
66
- self.input_segmask = None
67
 
68
- self.ref_img = None
69
- self.ref_pmask = None
70
- self.ref_segmask = None
71
-
72
- self.H = None
73
- self.W = None
74
- self.baseoutput = None
75
- self.kernel = np.ones((5, 5), np.uint8)
76
- self.edit_operation = edit_operation
77
-
78
- def init_input_canvas(self, img):
79
- img = HWC3(img)
80
- img = resize_image(img, 512)
81
- detected_mask = apply_segmentor(img, 'panoptic')[0]
82
- detected_seg = apply_segmentor(img, 'semantic')
83
 
84
- self.input_img = img
85
- self.input_pmask = detected_mask
86
- self.input_segmask = detected_seg
87
- self.H = img.shape[0]
88
- self.W = img.shape[1]
89
-
90
- detected_mask = detected_mask.cpu().numpy()
91
 
92
- uni = np.unique(detected_mask)
93
- color_mask = np.zeros((detected_mask.shape[0], detected_mask.shape[1], 3))
94
- for i in uni:
95
- color_mask[detected_mask == i] = _COLORS[i]
96
 
97
- output = color_mask*0.8 + img * 0.2
98
- self.baseoutput = output.astype(np.uint8)
99
- return self.baseoutput
100
-
101
- def init_ref_canvas(self, img):
102
- img = HWC3(img)
103
- img = resize_image(img, 512)
104
- detected_mask = apply_segmentor(img, 'panoptic')[0]
105
- detected_seg = apply_segmentor(img, 'semantic')
106
-
107
- self.ref_img = img
108
- self.ref_pmask = detected_mask
109
- self.ref_segmask = detected_seg
110
-
111
- detected_mask = detected_mask.cpu().numpy()
112
-
113
- uni = np.unique(detected_mask)
114
- color_mask = np.zeros((detected_mask.shape[0], detected_mask.shape[1], 3))
115
- for i in uni:
116
- color_mask[detected_mask == i] = _COLORS[i]
117
-
118
- output = color_mask*0.8 + img * 0.2
119
- self.baseoutput = output.astype(np.uint8)
120
- return self.baseoutput
121
-
122
- def _process_mask(self, mask, panoptic_mask, segmask):
123
- panoptic_mask_ = panoptic_mask + 1
124
- mask_ = resize_image(mask['mask'][:, :, 0], min(panoptic_mask.shape))
125
- mask_ = torch.tensor(mask_)
126
- maski = torch.zeros_like(mask_).cuda()
127
- maski[mask_ > 127] = 1
128
- mask = maski * panoptic_mask_
129
- unique_ids, counts = torch.unique(mask, return_counts=True)
130
- mask_id = unique_ids[torch.argmax(counts[1:]) + 1]
131
- final_mask = torch.zeros(mask.shape).cuda()
132
- final_mask[panoptic_mask_ == mask_id] = 1
133
-
134
- obj_class = maski * (segmask + 1)
135
- unique_ids, counts = torch.unique(obj_class, return_counts=True)
136
- obj_class = unique_ids[torch.argmax(counts[1:]) + 1] - 1
137
- return final_mask, obj_class
138
-
139
-
140
- def _edit_app(self, input_mask, ref_mask, whole_ref):
141
- input_pmask = self.input_pmask
142
- input_segmask = self.input_segmask
143
 
144
- if whole_ref:
145
- reference_mask = torch.ones(self.ref_pmask.shape).cuda()
146
- else:
147
- reference_mask, _ = self._process_mask(ref_mask, self.ref_pmask, self.ref_segmask)
148
 
149
- edit_mask, _ = self._process_mask(input_mask, self.input_pmask, self.input_segmask)
150
- ma = torch.max(input_pmask)
151
- input_pmask[edit_mask == 1] = ma + 1
152
- return reference_mask, input_pmask, input_segmask, edit_mask, ma
153
 
154
-
155
- def _edit(self, input_mask, ref_mask, whole_ref=False, inter=1):
156
- input_img = (self.input_img/127.5 - 1)
157
- input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
158
 
159
- reference_img = (self.ref_img/127.5 - 1)
160
- reference_img = torch.from_numpy(reference_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
161
 
162
- reference_mask, input_pmask, input_segmask, region_mask, ma = self._edit_app(input_mask, ref_mask, whole_ref)
 
163
 
164
- input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
165
- _, mean_feat_inpt, one_hot_inpt, empty_mask_flag_inpt = model.get_appearance(input_img, input_pmask, return_all=True)
 
 
 
166
 
167
- reference_mask = reference_mask.float().cuda().unsqueeze(0).unsqueeze(1)
168
- _, mean_feat_ref, _, _ = model.get_appearance(reference_img, reference_mask, return_all=True)
 
 
 
 
169
 
170
- if mean_feat_ref.shape[1] > 1:
171
- mean_feat_inpt[:, ma + 1] = (1 - inter) * mean_feat_inpt[:, ma + 1] + inter*mean_feat_ref[:, 1]
172
 
173
- splatted_feat = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt, one_hot_inpt)
174
- appearance = torch.nn.functional.normalize(splatted_feat) #l2 normaliz
175
- input_segmask = ((input_segmask+1)/ 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1)
176
- structure = torch.nn.functional.interpolate(input_segmask, (self.H, self.W))
177
- appearance = torch.nn.functional.interpolate(appearance, (self.H, self.W))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
 
179
 
180
- return structure, appearance, region_mask, input_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- def process(self, input_mask, ref_mask, prompt, a_prompt, n_prompt,
183
- num_samples, ddim_steps, guess_mode, strength,
184
- scale_s, scale_f, scale_t, seed, eta, masking=True,whole_ref=False,inter=1):
185
- structure, appearance, mask, img = self._edit(input_mask, ref_mask,
186
- whole_ref=whole_ref, inter=inter)
187
-
188
- null_structure = torch.zeros(structure.shape).cuda() - 1
189
- null_appearance = torch.zeros(appearance.shape).cuda()
190
-
191
- null_control = torch.cat([null_structure, null_appearance], dim=1)
192
- structure_control = torch.cat([structure, null_appearance], dim=1)
193
- full_control = torch.cat([structure, appearance], dim=1)
194
-
195
- null_control = torch.cat([null_control for _ in range(num_samples)], dim=0)
196
- structure_control = torch.cat([structure_control for _ in range(num_samples)], dim=0)
197
- full_control = torch.cat([full_control for _ in range(num_samples)], dim=0)
198
-
199
- #Masking for local edit
200
- if not masking:
201
- mask, x0 = None, None
202
- else:
203
- x0 = model.encode_first_stage(img)
204
- x0 = x0.sample() if isinstance(x0, DiagonalGaussianDistribution) else x0 # todo: check if we can set random number
205
- x0 = x0 * model.scale_factor
206
- mask = 1 - torch.tensor(mask).unsqueeze(0).unsqueeze(1).cuda()
207
- mask = torch.nn.functional.interpolate(mask, x0.shape[2:]).float()
208
-
209
- if seed == -1:
210
- seed = random.randint(0, 65535)
211
- seed_everything(seed)
212
 
213
- scale = [scale_s, scale_f, scale_t]
214
- print(scale)
215
- if save_memory:
216
- model.low_vram_shift(is_diffusing=False)
217
- # uc_cross = model.get_unconditional_conditioning(num_samples)
218
- uc_cross = model.get_learned_conditioning([n_prompt] * num_samples)
219
- cond = {"c_concat": [full_control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
220
- un_cond = {"c_concat": None if guess_mode else [null_control], "c_crossattn": [uc_cross]}
221
- un_cond_struct = {"c_concat": None if guess_mode else [structure_control], "c_crossattn": [uc_cross]}
222
- un_cond_struct_app = {"c_concat": None if guess_mode else [full_control], "c_crossattn": [uc_cross]}
223
 
224
- shape = (4, self.H // 8, self.W // 8)
 
225
 
226
- if save_memory:
227
- model.low_vram_shift(is_diffusing=True)
228
 
229
- model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
230
- samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
231
- shape, cond, verbose=False, eta=eta,
232
- unconditional_guidance_scale=scale, mask=mask, x0=x0,
233
- unconditional_conditioning=[un_cond, un_cond_struct, un_cond_struct_app ])
234
 
235
- if save_memory:
236
- model.low_vram_shift(is_diffusing=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
- x_samples = (model.decode_first_stage(samples) + 1) * 127.5
239
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')).cpu().numpy().clip(0, 255).astype(np.uint8)
240
 
241
- results = [x_samples[i] for i in range(num_samples)]
242
- return [] + results
 
 
 
 
243
 
244
 
245
- def init_input_canvas_wrapper(obj, *args):
246
- return obj.init_input_canvas(*args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
- def init_ref_canvas_wrapper(obj, *args):
249
- return obj.init_ref_canvas(*args)
 
 
 
 
 
250
 
251
- def process_wrapper(obj, *args):
252
- return obj.process(*args)
 
253
 
 
 
 
254
 
 
255
 
256
- css = """
257
- h1 {
258
- text-align: center;
259
- }
260
- .container {
261
- display: flex;
262
- justify-content: space-between
263
- }
264
-
265
- img {
266
- max-width: 100%
267
- padding-right: 100px;
268
- }
269
-
270
- .image {
271
- flex-basis: 40%
272
-
273
- }
274
-
275
- .text {
276
- font-size: 15px;
277
- padding-right: 20px;
278
- padding-left: 0px;
279
- }
280
- """
281
 
282
- def create_app_demo():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
 
284
  with gr.Row():
285
- gr.Markdown("## Object Level Appearance Editing")
286
  with gr.Row():
287
  gr.HTML(
288
  """
289
- <div class="container">
290
- <div class="text">
291
- <h4> Instructions </h4>
292
- <ol>
293
- <li>Upload an Input Image.</li>
294
- <li>Mark one of segmented objects in the <i>Select Object to Edit</i> tab.</li>
295
- <li>Upload an Reference Image.</li>
296
- <li>Mark one of segmented objects in the <i>Select Reference Object</i> tab, for the reference appearance.</li>
297
- <li>Enter a prompt and press <i>Run</i> button. (A very simple would also work) </li>
298
- </ol>
299
- </div>
300
- <div class="image">
301
- <img src="file/assets/GIF.gif" width="400"">
302
- </div>
303
- </div>
304
- """)
305
  with gr.Column():
306
  with gr.Row():
307
  img_edit = gr.State(ImageComp('edit_app'))
308
  with gr.Column():
309
- btn1 = gr.Button("Input Image")
310
  input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
311
  with gr.Column():
312
- btn2 = gr.Button("Select Object to Edit")
313
- input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy", tool="sketch")
314
- input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_mask], queue=False)
315
-
316
- # with gr.Row():
317
- with gr.Column():
318
- btn3 = gr.Button("Reference Image")
319
- ref_img = gr.Image(source='upload', label='Reference Image', type="numpy")
320
- with gr.Column():
321
- btn4 = gr.Button("Select Reference Object")
322
- reference_mask = gr.Image(source="upload", label='Select Object in Refernce Image', type="numpy", tool="sketch")
323
 
324
- ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[reference_mask], queue=False)
325
-
326
  with gr.Row():
327
- prompt = gr.Textbox(label="Prompt", value='A picture of truck')
328
- with gr.Column():
329
- interpolation = gr.Slider(label="Mixing ratio of appearance from reference object", minimum=0.1, maximum=1, value=1.0, step=0.1)
330
- whole_ref = gr.Checkbox(label='Use whole reference Image for appearance (Only useful for style transfers)', value=False)
 
 
 
 
 
 
 
331
  with gr.Row():
332
  run_button = gr.Button(label="Run")
 
333
 
334
  with gr.Row():
335
  result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
336
 
337
  with gr.Accordion("Advanced options", open=False):
338
- num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
 
339
  strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
340
  guess_mode = gr.Checkbox(label='Guess Mode', value=False)
341
  ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
342
- scale_t = gr.Slider(label="Guidance Scale Text", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
343
- scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0.1, maximum=30.0, value=8.0, step=0.1)
344
- scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0.1, maximum=30.0, value=5.0, step=0.1)
 
345
  seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
346
  eta = gr.Number(label="eta (DDIM)", value=0.0)
347
  masking = gr.Checkbox(label='Only edit the local region', value=True)
348
  a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
349
  n_prompt = gr.Textbox(label="Negative Prompt",
350
  value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
351
-
 
 
 
352
  with gr.Column():
353
  gr.Examples(
354
- examples=[['A picture of a truck', 'assets/truck.png','assets/truck2.jpeg', 892905419, 9, 7.6, 4.3],
355
- ['A picture of a ironman', 'assets/ironman.webp','assets/hulk.jpeg', 709736989, 9, 7.7, 8.1],
356
- ['A person skiing', 'assets/ski.jpg','assets/lava.jpg', 917723061, 9, 7.5, 4.4]],
357
- inputs=[prompt, input_image, ref_img, seed, scale_t, scale_f, scale_s],
 
 
 
 
 
 
 
358
  outputs=None,
359
  fn=None,
360
  cache_examples=False,
361
  )
362
- ips = [input_mask, reference_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
363
- scale_s, scale_f, scale_t, seed, eta, masking, whole_ref, interpolation]
 
 
364
  run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
 
365
 
366
 
367
-
368
- def create_struct_demo():
369
  with gr.Row():
370
- gr.Markdown("## Edit Structure (Comming soon!)")
371
-
372
- def create_both_demo():
373
  with gr.Row():
374
- gr.Markdown("## Edit Structure and Appearance Together (Comming soon!)")
 
 
 
 
 
 
 
 
 
 
 
375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
 
378
- block = gr.Blocks(css=css).queue()
379
  with block:
380
  gr.HTML(
381
  """
382
  <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
383
  <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
384
- PAIR Diffusion
385
  </h1>
386
- <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem">
387
- <a href="https://vidit98.github.io/" style="color:blue;">Vidit Goel</a><sup>1*</sup>,
388
- <a href="https://helia95.github.io/" style="color:blue;">Elia Peruzzo</a><sup>1,2*</sup>,
389
- <a href="https://yifanjiang19.github.io/" style="color:blue;">Yifan Jiang</a><sup>3</sup>,
390
- <a href="https://ir1d.github.io/" style="color:blue;">Dejia Xu</a><sup>3</sup>,
391
- <a href="http://disi.unitn.it/~sebe/" style="color:blue;">Nicu Sebe</a><sup>2</sup>, <br>
392
- <a href=" https://people.eecs.berkeley.edu/~trevor/" style="color:blue;">Trevor Darrell</a><sup>4</sup>,
393
- <a href="https://vita-group.github.io/" style="color:blue;">Zhangyang Wang</a><sup>1,3</sup>
394
- and <a href="https://www.humphreyshi.com/home" style="color:blue;">Humphrey Shi</a> <sup>1,5,6</sup> <br>
395
- [<a href="https://arxiv.org/abs/2303.17546" style="color:red;">arXiv</a>]
396
- [<a href="https://github.com/Picsart-AI-Research/PAIR-Diffusion" style="color:red;">GitHub</a>]
397
- </h2>
398
- <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
399
- <sup>1</sup>Picsart AI Resarch (PAIR), <sup>2</sup>UTrenton, <sup>3</sup>UT Austin, <sup>4</sup>UC Berkeley, <sup>5</sup>UOregon, <sup>6</sup>UIUC
400
- </h3>
401
  <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
402
- We built Structure and Appearance Paired (PAIR) Diffusion that allows reference image-guided appearance manipulation and
403
- structure editing of an image at an object level. PAIR diffusion models an image as composition of multiple objects and enables control
404
- over structure and appearance properties of the object. Describing object appearances using text can be challenging and ambiguous, PAIR Diffusion
405
- enables a user to control the appearance of an object using images. User can further use text as another degree of control for appearance.
406
- Having fine-grained control over appearance and structure at object level can be beneficial for future works in video and 3D beside image editing,
407
- where we need to have consistent appearance across time in case of video or across various viewing positions in case of 3D.
408
  </h2>
409
-
410
  </div>
411
  """)
412
 
413
- gr.HTML("""
414
- <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
415
- <br/>
416
- <a href="https://huggingface.co/spaces/PAIR/PAIR-Diffusion?duplicate=true">
417
- <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
418
- </p>""")
419
-
420
  with gr.Tab('Edit Appearance'):
421
  create_app_demo()
422
- with gr.Tab('Edit Structure'):
423
- create_struct_demo()
424
- with gr.Tab('Edit Both'):
425
- create_both_demo()
426
-
 
427
 
428
  block.queue(max_size=20)
429
- block.launch(debug=True)
 
 
 
 
 
1
  import gradio as gr
2
+ from pair_diff_demo import ImageComp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ # torch.cuda.set_per_process_memory_fraction(0.6)
 
 
 
 
5
 
6
+ def init_input_canvas_wrapper(obj, *args):
7
+ return obj.init_input_canvas(*args)
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ def init_ref_canvas_wrapper(obj, *args):
10
+ return obj.init_ref_canvas(*args)
 
 
 
 
 
11
 
12
+ def select_input_object_wrapper(obj, evt: gr.SelectData):
13
+ return obj.select_input_object(evt)
 
 
14
 
15
+ def select_ref_object_wrapper(obj, evt: gr.SelectData):
16
+ return obj.select_ref_object(evt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ def process_wrapper(obj, *args):
19
+ return obj.process(*args)
 
 
20
 
21
+ def set_multi_modal_wrapper(obj, *args):
22
+ return obj.set_multi_modal(*args)
 
 
23
 
24
+ def save_result_wrapper(obj, *args):
25
+ return obj.save_result(*args)
 
 
26
 
27
+ def return_input_img_wrapper(obj):
28
+ return obj.return_input_img()
29
 
30
+ def get_caption_wrapper(obj, *args):
31
+ return obj.get_caption(*args)
32
 
33
+ def multimodal_params(b):
34
+ if b:
35
+ return 10, 3, 6
36
+ else:
37
+ return 6, 8, 9
38
 
39
+ theme = gr.themes.Soft(
40
+ primary_hue="purple",
41
+ font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", 'monospace'],
42
+ ).set(
43
+ block_label_background_fill_dark='*neutral_800'
44
+ )
45
 
 
 
46
 
47
+ css = """
48
+ #customized_imbox {
49
+ min-height: 450px;
50
+ }
51
+ #customized_imbox>div[data-testid="image"] {
52
+ min-height: 450px;
53
+ }
54
+ #customized_imbox>div[data-testid="image"]>div {
55
+ min-height: 450px;
56
+ }
57
+ #customized_imbox>div[data-testid="image"]>iframe {
58
+ min-height: 450px;
59
+ }
60
+ #customized_imbox>div.unpadded_box {
61
+ min-height: 450px;
62
+ }
63
+ #myinst {
64
+ font-size: 0.8rem;
65
+ margin: 0rem;
66
+ color: #6B7280;
67
+ }
68
+ #maskinst {
69
+ text-align: justify;
70
+ min-width: 1200px;
71
+ }
72
+ #maskinst>img {
73
+ min-width:399px;
74
+ max-width:450px;
75
+ vertical-align: top;
76
+ display: inline-block;
77
+ }
78
+ #maskinst:after {
79
+ content: "";
80
+ width: 100%;
81
+ display: inline-block;
82
+ }
83
+ """
84
 
85
+ def create_app_demo():
86
 
87
+ with gr.Row():
88
+ gr.Markdown("## Object Level Appearance Editing")
89
+ with gr.Row():
90
+ gr.HTML(
91
+ """
92
+ <div style="text-align: left; max-width: 1200px;">
93
+ <h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
94
+ Instructions </h3>
95
+ <ol>
96
+ <li>Upload an Input Image.</li>
97
+ <li>Mark one of segmented objects in the <i>Select Object to Edit</i> tab.</li>
98
+ <li>Upload an Reference Image.</li>
99
+ <li>Mark one of segmented objects in the <i>Select Reference Object</i> tab, whose appearance needs to used in the selected input object.</li>
100
+ <li>Enter a prompt and press <i>Run</i> button. (A very simple would also work) </li>
101
+ </ol>
102
+ </ol>
103
+ </div>""")
104
+ with gr.Column():
105
+ with gr.Row():
106
+ img_edit = gr.State(ImageComp('edit_app'))
107
+ with gr.Column():
108
+ input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
109
+ with gr.Column():
110
+ input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy",)
111
+
112
+ with gr.Column():
113
+ ref_img = gr.Image(source='upload', label='Reference Image', type="numpy")
114
+ with gr.Column():
115
+ reference_mask = gr.Image(source="upload", label='Select Object in Refernce Image', type="numpy")
116
 
117
+ with gr.Row():
118
+ with gr.Column():
119
+ prompt = gr.Textbox(label="Prompt", value='A picture of truck')
120
+ mulitmod = gr.Checkbox(label='Multi-Modal', value=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
+ mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
 
 
 
 
 
 
 
 
 
123
 
124
+ input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_image], show_progress=True)
125
+ input_image.select(fn=select_input_object_wrapper, inputs=[img_edit], outputs=[input_mask, prompt])
126
 
127
+ ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[ref_img], show_progress=True)
128
+ ref_img.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[reference_mask])
129
 
130
+ with gr.Column():
131
+ interpolation = gr.Slider(label="Mixing ratio of appearance from reference object", minimum=0.1, maximum=1, value=1.0, step=0.1)
132
+ whole_ref = gr.Checkbox(label='Use whole reference Image for appearance (Only useful for style transfers)', visible=False)
133
+
134
+ # clear_button.click(fn=img_edit.clear_points, inputs=[], outputs=[input_mask, reference_mask])
135
 
136
+ with gr.Row():
137
+ run_button = gr.Button(label="Run")
138
+ save_button = gr.Button("Save")
139
+
140
+ with gr.Row():
141
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
142
+
143
+ with gr.Accordion("Advanced options", open=False):
144
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
145
+ image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
146
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
147
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
148
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
149
+ scale_t = gr.Slider(label="Guidance Scale Text", minimum=0., maximum=30.0, value=6.0, step=0.1)
150
+ scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0., maximum=30.0, value=8.0, step=0.1)
151
+ scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0., maximum=30.0, value=9.0, step=0.1)
152
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
153
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
154
+ masking = gr.Checkbox(label='Only edit the local region', value=True)
155
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
156
+ n_prompt = gr.Textbox(label="Negative Prompt",
157
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
158
+ dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=0, value=0, step=0)
159
+
160
+ with gr.Column():
161
+ gr.Examples(
162
+ examples=[['assets/examples/car.jpeg','assets/examples/ian.jpeg', '', 709736989, 6, 8, 9],
163
+ ['assets/examples/ian.jpeg','assets/examples/car.jpeg', '', 709736989, 6, 8, 9],
164
+ ['assets/examples/car.jpeg','assets/examples/ran.webp', '', 709736989, 6, 8, 9],
165
+ ['assets/examples/car.jpeg','assets/examples/car1.webp', '', 709736989, 6, 8, 9],
166
+ ['assets/examples/car1.webp','assets/examples/car.jpeg', '', 709736989, 6, 8, 9],
167
+ ['assets/examples/chair.jpeg','assets/examples/chair1.jpeg', '', 1106204668, 6, 8, 9],
168
+ ['assets/examples/house.jpeg','assets/examples/house2.jpeg', '', 1106204668, 6, 8, 9],
169
+ ['assets/examples/house2.jpeg','assets/examples/house.jpeg', '', 1106204668, 6, 8, 9],
170
+ ['assets/examples/park.webp','assets/examples/grasslands-national-park.jpeg', '', 1106204668, 6, 8, 9],
171
+ ['assets/examples/door.jpeg','assets/examples/door2.jpeg', '', 709736989, 6, 8, 9]],
172
+ inputs=[input_image, ref_img, prompt, seed, scale_t, scale_f, scale_s],
173
+ cache_examples=False,
174
+ )
175
 
176
+ mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
 
177
 
178
+ ips = [input_mask, reference_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
179
+ scale_s, scale_f, scale_t, seed, eta, dil, masking, whole_ref, interpolation]
180
+ ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
181
+ scale_s, scale_f, scale_t, seed, dil, interpolation]
182
+ run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
183
+ save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
184
 
185
 
186
+ def create_add_obj_demo():
187
+ with gr.Row():
188
+ gr.Markdown("## Add Objects to Image")
189
+ with gr.Row():
190
+ gr.HTML(
191
+ """
192
+ <div style="text-align: left; max-width: 1200px;">
193
+ <h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
194
+ Instructions </h3>
195
+ <ol>
196
+ <li> Upload an Input Image.</li>
197
+ <li>Draw the precise shape of object in the image where you want to add object in <i>Draw Object</i> tab.</li>
198
+ <li>Upload an Reference Image.</li>
199
+ <li>Click on the object in the Reference Image tab that you want to add in the Input Image.</li>
200
+ <li>Enter a prompt and press <i>Run</i> button. (A very simple would also work) </li>
201
+ </ol>
202
+ </ol>
203
+ </div>""")
204
+ with gr.Column():
205
+ with gr.Row():
206
+ img_edit = gr.State(ImageComp('add_obj'))
207
+ with gr.Column():
208
+ input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
209
+ with gr.Column():
210
+ input_mask = gr.Image(source="upload", label='Draw the desired Object', type="numpy", tool="sketch")
211
 
212
+ input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_image])
213
+ input_image.change(fn=return_input_img_wrapper, inputs=[img_edit], outputs=[input_mask], queue=False)
214
+
215
+ with gr.Column():
216
+ ref_img = gr.Image(source='upload', label='Reference Image', type="numpy")
217
+ with gr.Column():
218
+ reference_mask = gr.Image(source="upload", label='Selected Object in Refernce Image', type="numpy")
219
 
220
+ ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[ref_img], queue=False)
221
+ # ref_img.upload(fn=img_edit.init_ref_canvas, inputs=[ref_img], outputs=[ref_img])
222
+ ref_img.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[reference_mask])
223
 
224
+ with gr.Row():
225
+ prompt = gr.Textbox(label="Prompt", value='A picture of truck')
226
+ mulitmod = gr.Checkbox(label='Multi-Modal', value=False, visible=False)
227
 
228
+ mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
229
 
230
+ with gr.Row():
231
+ run_button = gr.Button(label="Run")
232
+ save_button = gr.Button("Save")
233
+
234
+ with gr.Row():
235
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
236
+
237
+ with gr.Accordion("Advanced options", open=False):
238
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
239
+ # image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
240
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
241
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
242
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
243
+ dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=5, value=2, step=1)
244
+ scale_t = gr.Slider(label="Guidance Scale Text", minimum=0., maximum=30.0, value=6.0, step=0.1)
245
+ scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0., maximum=30.0, value=8.0, step=0.1)
246
+ scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0., maximum=30.0, value=9.0, step=0.1)
247
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
248
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
249
+ masking = gr.Checkbox(label='Only edit the local region', value=True)
250
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
251
+ n_prompt = gr.Textbox(label="Negative Prompt",
252
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
253
+
254
+ mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
255
 
256
+ with gr.Column():
257
+ gr.Examples(
258
+ examples=[['assets/examples/chair.jpeg','assets/examples/carpet2.webp', 'A picture of living room with carpet', 892905419, 6, 8, 9],
259
+ ['assets/examples/chair.jpeg','assets/examples/chair1.jpeg', 'A picture of living room with a orange and white sofa', 892905419, 6, 8, 9],
260
+ ['assets/examples/park.webp','assets/examples/dog.jpeg', 'A picture of dog in the park', 892905419, 6, 8, 9]],
261
+ inputs=[input_image, ref_img, prompt, seed, scale_t, scale_f, scale_s],
262
+ outputs=None,
263
+ fn=None,
264
+ cache_examples=False,
265
+ )
266
+ ips = [input_mask, reference_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
267
+ scale_s, scale_f, scale_t, seed, eta, dil, masking]
268
+ ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
269
+ scale_s, scale_f, scale_t, seed, dil]
270
+ run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
271
+ save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
272
 
273
+ def create_obj_variation_demo():
274
  with gr.Row():
275
+ gr.Markdown("## Objects Variation")
276
  with gr.Row():
277
  gr.HTML(
278
  """
279
+ <div style="text-align: left; max-width: 1200px;">
280
+ <h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
281
+ Instructions </h3>
282
+ <ol>
283
+ <li> Upload an Input Image.</li>
284
+ <li>Click on object to have variations</li>
285
+ <li>Press <i>Run</i> button</li>
286
+ </ol>
287
+ </ol>
288
+ </div>""")
289
+
 
 
 
 
 
290
  with gr.Column():
291
  with gr.Row():
292
  img_edit = gr.State(ImageComp('edit_app'))
293
  with gr.Column():
 
294
  input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
295
  with gr.Column():
296
+ input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy",)
 
 
 
 
 
 
 
 
 
 
297
 
 
 
298
  with gr.Row():
299
+ prompt = gr.Textbox(label="Prompt", value='')
300
+ mulitmod = gr.Checkbox(label='Multi-Modal', value=False)
301
+
302
+
303
+ mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
304
+
305
+ input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_image])
306
+ input_image.select(fn=select_input_object_wrapper, inputs=[img_edit], outputs=[input_mask, prompt])
307
+ input_image.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, input_image], outputs=[], queue=False)
308
+ input_image.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[])
309
+
310
  with gr.Row():
311
  run_button = gr.Button(label="Run")
312
+ save_button = gr.Button("Save")
313
 
314
  with gr.Row():
315
  result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
316
 
317
  with gr.Accordion("Advanced options", open=False):
318
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=2)
319
+ # image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
320
  strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
321
  guess_mode = gr.Checkbox(label='Guess Mode', value=False)
322
  ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
323
+ dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=5, value=2, step=1)
324
+ scale_t = gr.Slider(label="Guidance Scale Text", minimum=0.0, maximum=30.0, value=6.0, step=0.1)
325
+ scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0.0, maximum=30.0, value=8.0, step=0.1)
326
+ scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0.0, maximum=30.0, value=9.0, step=0.1)
327
  seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
328
  eta = gr.Number(label="eta (DDIM)", value=0.0)
329
  masking = gr.Checkbox(label='Only edit the local region', value=True)
330
  a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
331
  n_prompt = gr.Textbox(label="Negative Prompt",
332
  value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
333
+
334
+
335
+ mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
336
+
337
  with gr.Column():
338
  gr.Examples(
339
+ examples=[['assets/examples/chair.jpeg' , 892905419, 6, 8, 9],
340
+ ['assets/examples/chair1.jpeg', 892905419, 6, 8, 9],
341
+ ['assets/examples/park.webp', 892905419, 6, 8, 9],
342
+ ['assets/examples/car.jpeg', 709736989, 6, 8, 9],
343
+ ['assets/examples/ian.jpeg', 709736989, 6, 8, 9],
344
+ ['assets/examples/chair.jpeg', 1106204668, 6, 8, 9],
345
+ ['assets/examples/door.jpeg', 709736989, 6, 8, 9],
346
+ ['assets/examples/carpet2.webp', 892905419, 6, 8, 9],
347
+ ['assets/examples/house.jpeg', 709736989, 6, 8, 9],
348
+ ['assets/examples/house2.jpeg', 709736989, 6, 8, 9],],
349
+ inputs=[input_image, seed, scale_t, scale_f, scale_s],
350
  outputs=None,
351
  fn=None,
352
  cache_examples=False,
353
  )
354
+ ips = [input_mask, input_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
355
+ scale_s, scale_f, scale_t, seed, eta, dil, masking]
356
+ ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
357
+ scale_s, scale_f, scale_t, seed, dil]
358
  run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
359
+ save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
360
 
361
 
362
+ def create_free_form_obj_variation_demo():
 
363
  with gr.Row():
364
+ gr.Markdown("## Objects Variation")
 
 
365
  with gr.Row():
366
+ gr.HTML(
367
+ """
368
+ <div style="text-align: left; max-width: 1200px;">
369
+ <h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
370
+ Instructions </h3>
371
+ <ol>
372
+ <li> Upload an Input Image.</li>
373
+ <li>Mask the region that you want to have variation</li>
374
+ <li>Press <i>Run</i> button</li>
375
+ </ol>
376
+ </ol>
377
+ </div>""")
378
 
379
+ with gr.Column():
380
+ with gr.Row():
381
+ img_edit = gr.State(ImageComp('edit_app'))
382
+ with gr.Column():
383
+ input_image = gr.Image(source='upload', label='Input Image', type="numpy", )
384
+ with gr.Column():
385
+ input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy", tool="sketch")
386
+
387
+ with gr.Row():
388
+ prompt = gr.Textbox(label="Prompt", value='')
389
+ ignore_structure = gr.Checkbox(label='Ignore Structure (Please provide a good caption)', visible=False)
390
+ mulitmod = gr.Checkbox(label='Multi-Modal', value=False)
391
+
392
+ mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
393
+
394
+ input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_mask])
395
+ input_mask.edit(fn=get_caption_wrapper, inputs=[img_edit, input_mask], outputs=[prompt])
396
+ input_image.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, input_image], outputs=[], queue=False)
397
+ # input_image.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[])
398
+
399
+ # input_image.edit(fn=img_edit.vis_mask, inputs=[input_image], outputs=[input_mask])
400
+
401
+ with gr.Row():
402
+ run_button = gr.Button(label="Run")
403
+ save_button = gr.Button("Save")
404
+
405
+ with gr.Row():
406
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
407
+
408
+ with gr.Accordion("Advanced options", open=False):
409
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=2)
410
+ # image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
411
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
412
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
413
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
414
+ dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=5, value=2, step=1)
415
+ scale_t = gr.Slider(label="Guidance Scale Text", minimum=0.0, maximum=30.0, value=6.0, step=0.1)
416
+ scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0.0, maximum=30.0, value=8.0, step=0.1)
417
+ scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0.0, maximum=30.0, value=9.0, step=0.1)
418
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
419
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
420
+ masking = gr.Checkbox(label='Only edit the local region', value=True)
421
+ free_form_obj_var = gr.Checkbox(label='', value=True)
422
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
423
+ n_prompt = gr.Textbox(label="Negative Prompt",
424
+ value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
425
+ interpolation = gr.Slider(label="Mixing ratio of appearance from reference object", minimum=0.0, maximum=0.1, step=0.1)
426
+
427
+ mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
428
+
429
+ with gr.Column():
430
+ gr.Examples(
431
+ examples=[['assets/examples/chair.jpeg' , 892905419, 6, 8, 9],
432
+ ['assets/examples/chair1.jpeg', 892905419, 6, 8, 9],
433
+ ['assets/examples/park.webp', 892905419, 6, 8, 9],
434
+ ['assets/examples/car.jpeg', 709736989, 6, 8, 9],
435
+ ['assets/examples/ian.jpeg', 709736989, 6, 8, 9],
436
+ ['assets/examples/chair.jpeg', 1106204668, 6, 8, 9],
437
+ ['assets/examples/door.jpeg', 709736989, 6, 8, 9],
438
+ ['assets/examples/carpet2.webp', 892905419, 6, 8, 9],
439
+ ['assets/examples/house.jpeg', 709736989, 6, 8, 9],
440
+ ['assets/examples/house2.jpeg', 709736989, 6, 8, 9],],
441
+ inputs=[input_image, seed, scale_t, scale_f, scale_s],
442
+ outputs=None,
443
+ fn=None,
444
+ cache_examples=False,
445
+ )
446
+ ips = [input_mask, input_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
447
+ scale_s, scale_f, scale_t, seed, eta, dil, masking, free_form_obj_var, dil, free_form_obj_var, ignore_structure]
448
+ ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
449
+ scale_s, scale_f, scale_t, seed, dil, interpolation, free_form_obj_var]
450
+ run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
451
+ save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
452
 
453
 
454
+ block = gr.Blocks(css=css, theme=theme).queue()
455
  with block:
456
  gr.HTML(
457
  """
458
  <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
459
  <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
460
+ PAIR Diffusion: A Comprehensive Multimodal Object-Level Image Editor
461
  </h1>
462
+ <h3 style="margin-top: 0.6rem; margin-bottom: 1rem">Picsart AI Research</h3>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
464
+ PAIR diffusion provides comprehensive multi-modal editing capabilities to edit real images without the need of inverting them. The current suite contains
465
+ <span style="color: #01feee;">Object Variation</span>, <span style="color: #4f82d9;">Edit Appearance of any object using a reference image and text</span>,
466
+ <span style="color: #d402bf;">Add any object from a reference image in the input image</span>. This operations can be mixed with each other to
467
+ develop new editing operations in future.
468
+ </ul>
 
469
  </h2>
 
470
  </div>
471
  """)
472
 
 
 
 
 
 
 
 
473
  with gr.Tab('Edit Appearance'):
474
  create_app_demo()
475
+ with gr.Tab('Object Variation Free Form Mask'):
476
+ create_free_form_obj_variation_demo()
477
+ with gr.Tab('Object Variation'):
478
+ create_obj_variation_demo()
479
+ with gr.Tab('Add Objects'):
480
+ create_add_obj_demo()
481
 
482
  block.queue(max_size=20)
483
+ block.launch(share=True)
484
+
assets/GIF.gif CHANGED

Git LFS Details

  • SHA256: e720b8c82526a982014b3eee781ba5d1a42c104e380444c536e4bbee21101a65
  • Pointer size: 131 Bytes
  • Size of remote file: 370 kB
assets/Teaser_Small.png ADDED

Git LFS Details

  • SHA256: dc29a44a9ddd8ec91b114b09b1229b1eb8d0740874a93e1a7d9ff92d7327b0b1
  • Pointer size: 131 Bytes
  • Size of remote file: 862 kB
assets/examples/Lancia.webp ADDED

Git LFS Details

  • SHA256: 628010d440fafc6d5e61691b543e7dd59bc11c76ec0d48b36890a96c22abc8a4
  • Pointer size: 131 Bytes
  • Size of remote file: 148 kB
assets/examples/car.jpeg ADDED

Git LFS Details

  • SHA256: 71a73a4ec6eab9e075eaa59879a884e0f663ad28d548ce8cf2e604166346874e
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
assets/examples/car1.webp ADDED

Git LFS Details

  • SHA256: c66b53e2f266d68f964574ba9d51dd70dbab478b63905a993f3784beb67bd3b7
  • Pointer size: 131 Bytes
  • Size of remote file: 122 kB
assets/examples/carpet2.webp ADDED

Git LFS Details

  • SHA256: 0bc055513cfcbfae7320e829fb84697c5a9d649edecd17652ddf54a77522af3c
  • Pointer size: 130 Bytes
  • Size of remote file: 69.4 kB
assets/examples/chair.jpeg ADDED

Git LFS Details

  • SHA256: d9d0040bcd7275bea283432c43abb69cdbe1fc32ff04c3994153d780057791b8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
assets/examples/chair1.jpeg ADDED

Git LFS Details

  • SHA256: 55af4d3b00a2ec95638bc4703adbca8409cf58ae19342326d0c0eac191179dcb
  • Pointer size: 129 Bytes
  • Size of remote file: 6.49 kB
assets/examples/dog.jpeg ADDED

Git LFS Details

  • SHA256: aa2d7acb2a06243b753d56306f45e22aae6e5b02bdf966ee7466ba517153cc11
  • Pointer size: 131 Bytes
  • Size of remote file: 278 kB
assets/examples/door.jpeg ADDED

Git LFS Details

  • SHA256: 424cc31f30b29060b8869c2ccb62c2f60010088b0fc3e9d8ea53d04fc21dbfbe
  • Pointer size: 130 Bytes
  • Size of remote file: 46.9 kB
assets/examples/door2.jpeg ADDED

Git LFS Details

  • SHA256: 32022ee30272376935e44df622d478c13d68686071ada7fe60ae36ffe44167da
  • Pointer size: 131 Bytes
  • Size of remote file: 540 kB
assets/examples/grasslands-national-park.jpeg ADDED

Git LFS Details

  • SHA256: 26690739225241d173d04b21661809dda464e59ac2af8da73178883094e508b6
  • Pointer size: 130 Bytes
  • Size of remote file: 66.1 kB
assets/examples/house.jpeg ADDED

Git LFS Details

  • SHA256: 89268b6097908e97cc8a56df3824ac6c589d86e415d2b1954e65255f1eddb595
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
assets/examples/house2.jpeg ADDED

Git LFS Details

  • SHA256: 753bbfa58d471be54f23b029d6764bd8682f77ff5c7f375b71ac8a10cb28342b
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
assets/examples/ian.jpeg ADDED

Git LFS Details

  • SHA256: b59d82f9b8cd2cc5a7864d85a1b0b51ad381c13370c6b70210b4ad1a267a9478
  • Pointer size: 131 Bytes
  • Size of remote file: 386 kB
assets/examples/park.webp ADDED

Git LFS Details

  • SHA256: d65c17257c64a793fa92311c91e2d9f31ec5d704ca644bbf3fb942de9769526e
  • Pointer size: 131 Bytes
  • Size of remote file: 731 kB
assets/examples/ran.webp ADDED

Git LFS Details

  • SHA256: 8f92f72d0a1286ff77cce5cdbd78fd1455fb5041c91ca5729f00055441bae13a
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
assets/hulk.jpeg CHANGED

Git LFS Details

  • SHA256: e7b2163b45349d71b40ac92b24e5dfa8559dcce5449c41740bf344d1a445e287
  • Pointer size: 130 Bytes
  • Size of remote file: 76.5 kB
assets/ironman.webp CHANGED

Git LFS Details

  • SHA256: 005c4adf045975ec4a328664e65258078ed90efe3e30bceba863f3c187404cc4
  • Pointer size: 130 Bytes
  • Size of remote file: 94.9 kB
assets/lava.jpg CHANGED

Git LFS Details

  • SHA256: 16cd431ad032a8058f6d6142e2e24d6cc7848837c44df50465085be875c931b3
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
assets/ski.jpg CHANGED

Git LFS Details

  • SHA256: c8f11aa5fdcbf78a3647a56a59c0cb4eb0d6da1dc83fc4ad247ad75e347a7476
  • Pointer size: 131 Bytes
  • Size of remote file: 200 kB
assets/truck.png CHANGED

Git LFS Details

  • SHA256: 86a0fa5c1d24bddd54db9631e717bcf56cac4f083e16d399c2495f4c766e4a9c
  • Pointer size: 130 Bytes
  • Size of remote file: 71.1 kB
assets/truck2.jpeg CHANGED

Git LFS Details

  • SHA256: 79ee52c4ab698b0702d34b7f6216db4899dcf08f5402d3e5f61b8e0f6408821a
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
cldm/appearance_networks.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Neighborhood Attention Transformer.
3
+ https://arxiv.org/abs/2204.07143
4
+
5
+ This source code is licensed under the license found in the
6
+ LICENSE file in the root directory of this source tree.
7
+ """
8
+ import torch
9
+ import torchvision
10
+ import torch.nn as nn
11
+ from timm.models.layers import trunc_normal_, DropPath
12
+ from timm.models.registry import register_model
13
+
14
+
15
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
16
+ IMAGENET_STD = [0.229, 0.224, 0.225]
17
+
18
+ class VGGPerceptualLoss(torch.nn.Module):
19
+ def __init__(self, resize=True):
20
+ super(VGGPerceptualLoss, self).__init__()
21
+ blocks = []
22
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
23
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
24
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
25
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
26
+ for bl in blocks:
27
+ for p in bl.parameters():
28
+ p.requires_grad = False
29
+ self.blocks = torch.nn.ModuleList(blocks)
30
+ self.transform = torch.nn.functional.interpolate
31
+ self.resize = resize
32
+ self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
33
+ self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
34
+
35
+ def forward(self, input, appearance_layers=[0,1,2,3]):
36
+ if input.shape[1] != 3:
37
+ input = input.repeat(1, 3, 1, 1)
38
+ target = target.repeat(1, 3, 1, 1)
39
+ input = (input-self.mean) / self.std
40
+ if self.resize:
41
+ input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
42
+ x = input
43
+ feats = []
44
+ for i, block in enumerate(self.blocks):
45
+ x = block(x)
46
+ if i in appearance_layers:
47
+ feats.append(x)
48
+
49
+ return feats
50
+
51
+
52
+ class DINOv2(torch.nn.Module):
53
+ def __init__(self, resize=True, size=224, model_type='dinov2_vitl14'):
54
+ super(DINOv2, self).__init__()
55
+ self.size=size
56
+ self.resize = resize
57
+ self.transform = torch.nn.functional.interpolate
58
+ self.model = torch.hub.load('facebookresearch/dinov2', model_type)
59
+ self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
60
+ self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
61
+
62
+ def forward(self, input, appearance_layers=[1,2]):
63
+ if input.shape[1] != 3:
64
+ input = input.repeat(1, 3, 1, 1)
65
+ target = target.repeat(1, 3, 1, 1)
66
+
67
+ if self.resize:
68
+ input = self.transform(input, mode='bicubic', size=(self.size, self.size), align_corners=False)
69
+ # mean = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1).to(input.device)
70
+ # std = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1).to(input.device)
71
+ input = (input-self.mean) / self.std
72
+ feats = self.model.get_intermediate_layers(input, self.model.n_blocks, reshape=True)
73
+ feats = [f.detach() for f in feats]
74
+
75
+ return feats
cldm/cldm.py CHANGED
@@ -10,7 +10,6 @@ from ldm.modules.diffusionmodules.util import (
10
  zero_module,
11
  timestep_embedding,
12
  )
13
- import torchvision
14
  from einops import rearrange, repeat
15
  from torchvision.utils import make_grid
16
  from ldm.modules.attention import SpatialTransformer
@@ -18,46 +17,9 @@ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSeq
18
  from ldm.models.diffusion.ddpm import LatentDiffusion
19
  from ldm.util import log_txt_as_img, exists, instantiate_from_config
20
  from ldm.models.diffusion.ddim import DDIMSampler
 
21
 
22
 
23
- class VGGPerceptualLoss(torch.nn.Module):
24
- def __init__(self, resize=True):
25
- super(VGGPerceptualLoss, self).__init__()
26
- blocks = []
27
- vgg_model = torchvision.models.vgg16(pretrained=True)
28
- print('Loaded VGG weights')
29
- blocks.append(vgg_model.features[:4].eval())
30
- blocks.append(vgg_model.features[4:9].eval())
31
- blocks.append(vgg_model.features[9:16].eval())
32
- blocks.append(vgg_model.features[16:23].eval())
33
-
34
- for bl in blocks:
35
- for p in bl.parameters():
36
- p.requires_grad = False
37
- self.blocks = torch.nn.ModuleList(blocks)
38
- self.transform = torch.nn.functional.interpolate
39
- self.resize = resize
40
- self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
41
- self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
42
- print('Initialized VGG model')
43
-
44
- def forward(self, input, feature_layers=[0, 1, 2, 3], style_layers=[1,]):
45
- if input.shape[1] != 3:
46
- input = input.repeat(1, 3, 1, 1)
47
- target = target.repeat(1, 3, 1, 1)
48
- input = (input-self.mean) / self.std
49
- if self.resize:
50
- input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
51
- x = input
52
- gram_matrices_all = []
53
- feats = []
54
- for i, block in enumerate(self.blocks):
55
- x = block(x)
56
- if i in style_layers:
57
- feats.append(x)
58
-
59
- return feats
60
-
61
 
62
 
63
  class ControlledUnetModel(UNetModel):
@@ -325,6 +287,7 @@ class ControlNet(nn.Module):
325
  def forward(self, x, hint, timesteps, context, **kwargs):
326
  t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
327
  emb = self.time_embed(t_emb)
 
328
  guided_hint = self.input_hint_block(hint, emb, context, x.shape)
329
 
330
  outs = []
@@ -343,57 +306,6 @@ class ControlNet(nn.Module):
343
  outs.append(self.middle_block_out(h, emb, context))
344
 
345
  return outs
346
-
347
- class Interpolate(nn.Module):
348
- def __init__(self, size, mode):
349
- super(Interpolate, self).__init__()
350
- self.interp = torch.nn.functional.interpolate
351
- self.size = size
352
- self.mode = mode
353
- self.factor = 8
354
-
355
- def forward(self, x):
356
- h,w = x.shape[2]//self.factor, x.shape[3]//self.factor
357
- x = self.interp(x, size=(h,w), mode=self.mode)
358
- return x
359
-
360
- class ControlNetSAP(ControlNet):
361
- def __init__(
362
- self,
363
- hint_channels,
364
- model_channels,
365
- input_hint_block='fixed',
366
- size = 64,
367
- mode='nearest',
368
- *args,
369
- **kwargs
370
- ):
371
- super().__init__( hint_channels=hint_channels, model_channels=model_channels, *args, **kwargs)
372
- #hint channels are atleast 128 dims
373
-
374
- if input_hint_block == 'learnable':
375
- ch = 2 ** (int(math.log2(hint_channels)))
376
- self.input_hint_block = TimestepEmbedSequential(
377
- conv_nd(self.dims, hint_channels, hint_channels, 3, padding=1),
378
- nn.SiLU(),
379
- conv_nd(self.dims, hint_channels, 2*ch, 3, padding=1, stride=2),
380
- nn.SiLU(),
381
- conv_nd(self.dims, 2*ch, 2*ch, 3, padding=1),
382
- nn.SiLU(),
383
- conv_nd(self.dims, 2*ch, 2*ch, 3, padding=1, stride=2),
384
- nn.SiLU(),
385
- conv_nd(self.dims, 2*ch, 2*ch, 3, padding=1),
386
- nn.SiLU(),
387
- conv_nd(self.dims, 2*ch, model_channels, 3, padding=1, stride=2),
388
- nn.SiLU(),
389
- zero_module(conv_nd(self.dims, model_channels, model_channels, 3, padding=1))
390
- )
391
- else:
392
- print("Only interpolation")
393
- self.input_hint_block = TimestepEmbedSequential(
394
- Interpolate(size, mode),
395
- zero_module(conv_nd(self.dims, hint_channels, model_channels, 3, padding=1)))
396
-
397
 
398
  class ControlLDM(LatentDiffusion):
399
 
@@ -420,11 +332,11 @@ class ControlLDM(LatentDiffusion):
420
  diffusion_model = self.model.diffusion_model
421
 
422
  cond_txt = torch.cat(cond['c_crossattn'], 1)
423
-
424
  if cond['c_concat'] is None:
425
  eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
426
  else:
427
- control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
 
428
  control = [c * scale for c, scale in zip(control, self.control_scales)]
429
  eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
430
 
@@ -443,7 +355,7 @@ class ControlLDM(LatentDiffusion):
443
  use_ddim = ddim_steps is not None
444
 
445
  log = dict()
446
- z, c = self.get_input(batch, self.first_stage_key, bs=N)
447
  c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
448
  N = min(z.shape[0], N)
449
  n_row = min(z.shape[0], n_row)
@@ -498,8 +410,9 @@ class ControlLDM(LatentDiffusion):
498
  @torch.no_grad()
499
  def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
500
  ddim_sampler = DDIMSampler(self)
501
- b, c, h, w = cond["c_concat"][0].shape
502
- shape = (self.channels, h // 8, w // 8)
 
503
  samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
504
  return samples, intermediates
505
 
@@ -525,24 +438,54 @@ class ControlLDM(LatentDiffusion):
525
  self.cond_stage_model = self.cond_stage_model.cuda()
526
 
527
 
528
- class SAP(ControlLDM):
 
529
  @torch.no_grad()
530
- def __init__(self,control_stage_config, control_key, only_mid_control, *args, **kwargs):
 
531
  super().__init__(control_stage_config=control_stage_config,
532
  control_key=control_key,
533
  only_mid_control=only_mid_control,
534
  *args, **kwargs)
535
- self.appearance_net = VGGPerceptualLoss().to(self.device)
536
- print("Loaded VGG model")
537
 
538
- def get_appearance(self, img, mask, return_all=False):
 
 
 
 
 
 
 
539
  img = (img + 1) * 0.5
540
- feat = self.appearance_net(img)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
  empty_mask_flag = torch.sum(mask, dim=(1,2,3)) == 0
542
 
543
 
544
  empty_appearance = torch.zeros(feat.shape).to(self.device)
545
- mask = torch.nn.functional.interpolate(mask.float(), (feat.shape[2:])).long()
546
  one_hot = torch.nn.functional.one_hot(mask[:,0]).permute(0,3,1,2).float()
547
 
548
  feat = torch.einsum('nchw, nmhw->nmchw', feat, one_hot)
@@ -552,32 +495,68 @@ class SAP(ControlLDM):
552
  mean_feat[:, 0] = torch.zeros(mean_feat[:,0].shape).to(self.device) #set edges in panopitc mask to empty appearance feature
553
 
554
  splatted_feat = torch.einsum('nmc, nmhw->nchw', mean_feat, one_hot)
555
- splatted_feat[empty_mask_flag] = empty_appearance[empty_mask_flag]
556
  splatted_feat = torch.nn.functional.normalize(splatted_feat) #l2 normalize on c dim
557
 
558
  if return_all:
559
  return splatted_feat, mean_feat, one_hot, empty_mask_flag
560
-
561
  return splatted_feat
562
-
 
563
  def get_input(self, batch, k, bs=None, *args, **kwargs):
564
  z, c, x_orig, x_recon = super(ControlLDM, self).get_input(batch, self.first_stage_key, return_first_stage_outputs=True , *args, **kwargs)
565
  structure = batch['seg'].unsqueeze(1)
566
  mask = batch['mask'].unsqueeze(1).to(self.device)
567
- appearance = self.get_appearance(x_orig, mask)
 
 
 
568
  if bs is not None:
569
  structure = structure[:bs]
570
- appearance = appearance[:bs]
571
-
572
  structure = structure.to(self.device)
573
- appearance = appearance.to(self.device)
574
  structure = structure.to(memory_format=torch.contiguous_format).float()
575
- appearance = appearance.to(memory_format=torch.contiguous_format).float()
576
- structure = torch.nn.functional.interpolate(structure, x_orig.shape[2:])
577
- appearance = torch.nn.functional.interpolate(appearance, x_orig.shape[2:])
578
- control = torch.cat([structure, appearance], dim=1)
579
- return z, dict(c_crossattn=[c], c_concat=[control])
580
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
  @torch.no_grad()
582
  def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
583
  quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=False,
@@ -588,11 +567,14 @@ class SAP(ControlLDM):
588
 
589
  log = dict()
590
  z, c = self.get_input(batch, self.first_stage_key, bs=N)
591
- c_cat, c = c["c_concat"][0][:N,], c["c_crossattn"][0][:N]
592
  N = min(z.shape[0], N)
593
  n_row = min(z.shape[0], n_row)
594
  log["reconstruction"] = self.decode_first_stage(z)
595
- log["control"] = c_cat[:, :1]
 
 
 
596
  log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
597
 
598
  if plot_diffusion_rows:
@@ -634,7 +616,7 @@ class SAP(ControlLDM):
634
 
635
  if unconditional_guidance_scale > 1.0:
636
  uc_cross = self.get_unconditional_conditioning(N)
637
- uc_cat = c_cat # torch.zeros_like(c_cat)
638
  uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
639
  samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
640
  batch_size=N, ddim=use_ddim,
@@ -646,3 +628,18 @@ class SAP(ControlLDM):
646
  log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
647
 
648
  return log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  zero_module,
11
  timestep_embedding,
12
  )
 
13
  from einops import rearrange, repeat
14
  from torchvision.utils import make_grid
15
  from ldm.modules.attention import SpatialTransformer
 
17
  from ldm.models.diffusion.ddpm import LatentDiffusion
18
  from ldm.util import log_txt_as_img, exists, instantiate_from_config
19
  from ldm.models.diffusion.ddim import DDIMSampler
20
+ from cldm.appearance_networks import VGGPerceptualLoss, DINOv2
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  class ControlledUnetModel(UNetModel):
 
287
  def forward(self, x, hint, timesteps, context, **kwargs):
288
  t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
289
  emb = self.time_embed(t_emb)
290
+ # hint = hint[:,:-1]
291
  guided_hint = self.input_hint_block(hint, emb, context, x.shape)
292
 
293
  outs = []
 
306
  outs.append(self.middle_block_out(h, emb, context))
307
 
308
  return outs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
  class ControlLDM(LatentDiffusion):
311
 
 
332
  diffusion_model = self.model.diffusion_model
333
 
334
  cond_txt = torch.cat(cond['c_crossattn'], 1)
 
335
  if cond['c_concat'] is None:
336
  eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
337
  else:
338
+ # control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
339
+ control = self.control_model(x=x_noisy, hint=cond['c_concat'][0], timesteps=t, context=cond_txt)
340
  control = [c * scale for c, scale in zip(control, self.control_scales)]
341
  eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
342
 
 
355
  use_ddim = ddim_steps is not None
356
 
357
  log = dict()
358
+ z, c = self.get_input(batch, self.first_stage_key, bs=N, logging=True)
359
  c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
360
  N = min(z.shape[0], N)
361
  n_row = min(z.shape[0], n_row)
 
410
  @torch.no_grad()
411
  def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
412
  ddim_sampler = DDIMSampler(self)
413
+ b, c, h, w = cond["c_concat"][0][0].shape if isinstance(cond["c_concat"][0], list) else cond["c_concat"][0].shape
414
+ # shape = (self.channels, h // 8, w // 8)
415
+ shape = (self.channels, h, w)
416
  samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
417
  return samples, intermediates
418
 
 
438
  self.cond_stage_model = self.cond_stage_model.cuda()
439
 
440
 
441
+
442
+ class PAIRDiffusion(ControlLDM):
443
  @torch.no_grad()
444
+ def __init__(self,control_stage_config, control_key, only_mid_control, app_net='vgg', app_layer_conc=(1,), app_layer_ca=(6,6,18,18),
445
+ appearance_net_locked=True, concat_multi_app=False, train_structure_variation_only=False, instruct=False, *args, **kwargs):
446
  super().__init__(control_stage_config=control_stage_config,
447
  control_key=control_key,
448
  only_mid_control=only_mid_control,
449
  *args, **kwargs)
 
 
450
 
451
+ self.appearance_net_conc = VGGPerceptualLoss().to(self.device)
452
+ self.appearance_net_ca = DINOv2().to(self.device)
453
+ self.appearance_net = VGGPerceptualLoss().to(self.device) #need to be removed no use
454
+ self.app_layer_conc = app_layer_conc
455
+ self.app_layer_ca = app_layer_ca
456
+
457
+
458
+ def get_appearance(self, net, layer, img, mask, return_all=False):
459
  img = (img + 1) * 0.5
460
+ feat = net(img)
461
+ splatted_feat = []
462
+ mean_feat = []
463
+ for fe_i in layer:
464
+ v = self.get_appearance_single(feat[fe_i], mask, return_all=return_all)
465
+ if return_all:
466
+ spl, me_f, one_hot, empty_mask = v
467
+ splatted_feat.append(spl)
468
+ mean_feat.append(me_f)
469
+ else:
470
+ splatted_feat.append(v)
471
+
472
+ if len(layer) == 1:
473
+ splatted_feat = splatted_feat[0]
474
+ # mean_feat = mean_feat[0]
475
+
476
+ del feat
477
+
478
+ if return_all:
479
+ return splatted_feat, mean_feat, one_hot, empty_mask
480
+
481
+ return splatted_feat
482
+
483
+ def get_appearance_single(self, feat, mask, return_all):
484
  empty_mask_flag = torch.sum(mask, dim=(1,2,3)) == 0
485
 
486
 
487
  empty_appearance = torch.zeros(feat.shape).to(self.device)
488
+ mask = torch.nn.functional.interpolate(mask.float(), size=(feat.shape[2], feat.shape[3])).long()
489
  one_hot = torch.nn.functional.one_hot(mask[:,0]).permute(0,3,1,2).float()
490
 
491
  feat = torch.einsum('nchw, nmhw->nmchw', feat, one_hot)
 
495
  mean_feat[:, 0] = torch.zeros(mean_feat[:,0].shape).to(self.device) #set edges in panopitc mask to empty appearance feature
496
 
497
  splatted_feat = torch.einsum('nmc, nmhw->nchw', mean_feat, one_hot)
498
+ splatted_feat[empty_mask_flag] = empty_appearance[empty_mask_flag]
499
  splatted_feat = torch.nn.functional.normalize(splatted_feat) #l2 normalize on c dim
500
 
501
  if return_all:
502
  return splatted_feat, mean_feat, one_hot, empty_mask_flag
 
503
  return splatted_feat
504
+
505
+
506
  def get_input(self, batch, k, bs=None, *args, **kwargs):
507
  z, c, x_orig, x_recon = super(ControlLDM, self).get_input(batch, self.first_stage_key, return_first_stage_outputs=True , *args, **kwargs)
508
  structure = batch['seg'].unsqueeze(1)
509
  mask = batch['mask'].unsqueeze(1).to(self.device)
510
+
511
+ appearance_conc = self.get_appearance(self.appearance_net_conc, self.app_layer_conc, x_orig, mask)
512
+ appearance_ca = self.get_appearance(self.appearance_net_ca, self.app_layer_ca, x_orig, mask)
513
+
514
  if bs is not None:
515
  structure = structure[:bs]
 
 
516
  structure = structure.to(self.device)
 
517
  structure = structure.to(memory_format=torch.contiguous_format).float()
518
+ structure = torch.nn.functional.interpolate(structure, z.shape[2:])
519
+
520
+ mask = torch.nn.functional.interpolate(mask.float(), z.shape[2:])
521
+
522
+ def format_appearance(appearance):
523
+ if isinstance(appearance, list):
524
+ if bs is not None:
525
+ appearance = [ap[:bs] for ap in appearance]
526
+ appearance = [ap.to(self.device) for ap in appearance]
527
+ appearance = [ap.to(memory_format=torch.contiguous_format).float() for ap in appearance]
528
+ appearance = [torch.nn.functional.interpolate(ap, z.shape[2:]) for ap in appearance]
529
+
530
+ else:
531
+ if bs is not None:
532
+ appearance = appearance[:bs]
533
+ appearance = appearance.to(self.device)
534
+ appearance = appearance.to(memory_format=torch.contiguous_format).float()
535
+ appearance = torch.nn.functional.interpolate(appearance, z.shape[2:])
536
+
537
+ return appearance
538
+
539
+ appearance_conc = format_appearance(appearance_conc)
540
+ appearance_ca = format_appearance(appearance_ca)
541
+
542
+ if isinstance(appearance_conc, list):
543
+ concat_control = torch.cat(appearance_conc, dim=1)
544
+ concat_control = torch.cat([structure, concat_control, mask], dim=1)
545
+ else:
546
+ concat_control = torch.cat([structure, appearance_conc, mask], dim=1)
547
+
548
+
549
+ if isinstance(appearance_ca, list):
550
+ control = []
551
+ for ap in appearance_ca:
552
+ control.append(torch.cat([structure, ap, mask], dim=1))
553
+ control.append(concat_control)
554
+ return z, dict(c_crossattn=[c], c_concat=[control])
555
+ else:
556
+ control = torch.cat([structure, appearance_ca, mask], dim=1)
557
+ control.append(concat_control)
558
+ return z, dict(c_crossattn=[c], c_concat=[control])
559
+
560
  @torch.no_grad()
561
  def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
562
  quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=False,
 
567
 
568
  log = dict()
569
  z, c = self.get_input(batch, self.first_stage_key, bs=N)
570
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
571
  N = min(z.shape[0], N)
572
  n_row = min(z.shape[0], n_row)
573
  log["reconstruction"] = self.decode_first_stage(z)
574
+ log["control"] = batch['mask'].unsqueeze(1)
575
+ if 'aug_mask' in batch:
576
+ log['aug_mask'] = batch['aug_mask'].unsqueeze(1)
577
+
578
  log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
579
 
580
  if plot_diffusion_rows:
 
616
 
617
  if unconditional_guidance_scale > 1.0:
618
  uc_cross = self.get_unconditional_conditioning(N)
619
+ uc_cat = list(c_cat) # torch.zeros_like(c_cat)
620
  uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
621
  samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
622
  batch_size=N, ddim=use_ddim,
 
628
  log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
629
 
630
  return log
631
+
632
+
633
+ def configure_optimizers(self):
634
+ lr = self.learning_rate
635
+
636
+ params = list(self.control_model.parameters())
637
+ if not self.sd_locked:
638
+ params += list(self.model.diffusion_model.output_blocks.parameters())
639
+ params += list(self.model.diffusion_model.out.parameters())
640
+
641
+ opt = torch.optim.AdamW(params, lr=lr)
642
+ return opt
643
+
644
+
645
+
cldm/controlnet.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch as th
3
+ import torch.nn as nn
4
+
5
+ from ldm.modules.diffusionmodules.util import (
6
+ conv_nd,
7
+ linear,
8
+ zero_module,
9
+ timestep_embedding,
10
+ )
11
+
12
+ from ldm.modules.attention import SpatialTransformer
13
+ from ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
14
+ from ldm.util import exists
15
+
16
+ torch.autograd.set_detect_anomaly(True)
17
+
18
+ class Interpolate(nn.Module):
19
+ def __init__(self, mode):
20
+ super(Interpolate, self).__init__()
21
+ self.interp = torch.nn.functional.interpolate
22
+ self.mode = mode
23
+ self.factor = 8
24
+
25
+ def forward(self, x):
26
+ return x
27
+
28
+ class ControlNetPAIR(nn.Module):
29
+ def __init__(
30
+ self,
31
+ image_size,
32
+ in_channels,
33
+ model_channels,
34
+ hint_channels,
35
+ concat_indices,
36
+ num_res_blocks,
37
+ attention_resolutions,
38
+ concat_channels=130,
39
+ dropout=0,
40
+ channel_mult=(1, 2, 4, 8),
41
+ mode='nearest',
42
+ conv_resample=True,
43
+ dims=2,
44
+ use_checkpoint=False,
45
+ use_fp16=False,
46
+ num_heads=-1,
47
+ num_head_channels=-1,
48
+ num_heads_upsample=-1,
49
+ use_scale_shift_norm=False,
50
+ resblock_updown=False,
51
+ use_new_attention_order=False,
52
+ use_spatial_transformer=False, # custom transformer support
53
+ transformer_depth=1, # custom transformer support
54
+ context_dim=None, # custom transformer support
55
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
56
+ legacy=True,
57
+ disable_self_attentions=None,
58
+ num_attention_blocks=None,
59
+ disable_middle_self_attn=False,
60
+ use_linear_in_transformer=False,
61
+ attn_class=['softmax', 'softmax', 'softmax', 'softmax'],
62
+ ):
63
+ super().__init__()
64
+ if use_spatial_transformer:
65
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
66
+
67
+ if context_dim is not None:
68
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
69
+ from omegaconf.listconfig import ListConfig
70
+ if type(context_dim) == ListConfig:
71
+ context_dim = list(context_dim)
72
+
73
+ if num_heads_upsample == -1:
74
+ num_heads_upsample = num_heads
75
+
76
+ if num_heads == -1:
77
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
78
+
79
+ if num_head_channels == -1:
80
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
81
+
82
+ self.dims = dims
83
+ self.image_size = image_size
84
+ self.in_channels = in_channels
85
+ self.model_channels = model_channels
86
+ if isinstance(num_res_blocks, int):
87
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
88
+ else:
89
+ if len(num_res_blocks) != len(channel_mult):
90
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
91
+ "as a list/tuple (per-level) with the same length as channel_mult")
92
+ self.num_res_blocks = num_res_blocks
93
+ if disable_self_attentions is not None:
94
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
95
+ assert len(disable_self_attentions) == len(channel_mult)
96
+ if num_attention_blocks is not None:
97
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
98
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
99
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
100
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
101
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
102
+ f"attention will still not be set.")
103
+
104
+ self.attention_resolutions = attention_resolutions
105
+ self.dropout = dropout
106
+ self.channel_mult = channel_mult
107
+ self.conv_resample = conv_resample
108
+ self.use_checkpoint = use_checkpoint
109
+ self.dtype = th.float16 if use_fp16 else th.float32
110
+ self.num_heads = num_heads
111
+ self.num_head_channels = num_head_channels
112
+ self.num_heads_upsample = num_heads_upsample
113
+ self.predict_codebook_ids = n_embed is not None
114
+
115
+ time_embed_dim = model_channels * 4
116
+ self.time_embed = nn.Sequential(
117
+ linear(model_channels, time_embed_dim),
118
+ nn.SiLU(),
119
+ linear(time_embed_dim, time_embed_dim),
120
+ )
121
+
122
+ self.input_blocks = nn.ModuleList(
123
+ [
124
+ TimestepEmbedSequential(
125
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
126
+ )
127
+ ]
128
+ )
129
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
130
+ self.concat_indices = concat_indices
131
+ self.hint_channels = hint_channels
132
+ h_ch = sum([hint_channels[i] for i in concat_indices ])
133
+
134
+ self.input_hint_block = TimestepEmbedSequential(
135
+ Interpolate('nearest'),
136
+ conv_nd(self.dims, concat_channels, self.model_channels, 3, padding=1),
137
+ nn.SiLU(),
138
+ zero_module(conv_nd(self.dims, self.model_channels, self.model_channels, 3, padding=1)))
139
+
140
+ self._feature_size = model_channels
141
+ input_block_chans = [model_channels]
142
+ ch = model_channels
143
+ ds = 1
144
+ for level, mult in enumerate(channel_mult):
145
+ for nr in range(self.num_res_blocks[level]):
146
+ layers = [
147
+ ResBlock(
148
+ ch,
149
+ time_embed_dim,
150
+ dropout,
151
+ out_channels=mult * model_channels,
152
+ dims=dims,
153
+ use_checkpoint=use_checkpoint,
154
+ use_scale_shift_norm=use_scale_shift_norm,
155
+ )
156
+ ]
157
+ ch = mult * model_channels
158
+ if ds in attention_resolutions:
159
+ if num_head_channels == -1:
160
+ dim_head = ch // num_heads
161
+ else:
162
+ num_heads = ch // num_head_channels
163
+ dim_head = num_head_channels
164
+ if legacy:
165
+ # num_heads = 1
166
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
167
+ if exists(disable_self_attentions):
168
+ disabled_sa = disable_self_attentions[level]
169
+ else:
170
+ disabled_sa = False
171
+
172
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
173
+ layers.append(
174
+ AttentionBlock(
175
+ ch,
176
+ use_checkpoint=use_checkpoint,
177
+ num_heads=num_heads,
178
+ num_head_channels=dim_head,
179
+ use_new_attention_order=use_new_attention_order,
180
+ ) if not use_spatial_transformer else SpatialTransformer(
181
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
182
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
183
+ use_checkpoint=use_checkpoint, attn1_mode=attn_class[level], obj_feat_dim=hint_channels[level]
184
+ )
185
+ )
186
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
187
+ self.zero_convs.append(self.make_zero_conv(ch))
188
+ self._feature_size += ch
189
+ input_block_chans.append(ch)
190
+ if level != len(channel_mult) - 1:
191
+ out_ch = ch
192
+ self.input_blocks.append(
193
+ TimestepEmbedSequential(
194
+ ResBlock(
195
+ ch,
196
+ time_embed_dim,
197
+ dropout,
198
+ out_channels=out_ch,
199
+ dims=dims,
200
+ use_checkpoint=use_checkpoint,
201
+ use_scale_shift_norm=use_scale_shift_norm,
202
+ down=True,
203
+ )
204
+ if resblock_updown
205
+ else Downsample(
206
+ ch, conv_resample, dims=dims, out_channels=out_ch
207
+ )
208
+ )
209
+ )
210
+ ch = out_ch
211
+ input_block_chans.append(ch)
212
+ self.zero_convs.append(self.make_zero_conv(ch))
213
+ ds *= 2
214
+ self._feature_size += ch
215
+
216
+ if num_head_channels == -1:
217
+ dim_head = ch // num_heads
218
+ else:
219
+ num_heads = ch // num_head_channels
220
+ dim_head = num_head_channels
221
+ if legacy:
222
+ # num_heads = 1
223
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
224
+ self.middle_block = TimestepEmbedSequential(
225
+ ResBlock(
226
+ ch,
227
+ time_embed_dim,
228
+ # hint_channels[-1],
229
+ dropout,
230
+ dims=dims,
231
+ use_checkpoint=use_checkpoint,
232
+ use_scale_shift_norm=use_scale_shift_norm,
233
+ ),
234
+ AttentionBlock(
235
+ ch,
236
+ use_checkpoint=use_checkpoint,
237
+ num_heads=num_heads,
238
+ num_head_channels=dim_head,
239
+ use_new_attention_order=use_new_attention_order,
240
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
241
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
242
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
243
+ use_checkpoint=use_checkpoint
244
+ ),
245
+ ResBlock(
246
+ ch,
247
+ time_embed_dim,
248
+ # hint_channels[-1],
249
+ dropout,
250
+ dims=dims,
251
+ use_checkpoint=use_checkpoint,
252
+ use_scale_shift_norm=use_scale_shift_norm,
253
+ ),
254
+ )
255
+ self.middle_block_out = self.make_zero_conv(ch)
256
+ self._feature_size += ch
257
+
258
+ def make_zero_conv(self, channels):
259
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
260
+
261
+ def forward(self, x, hint, timesteps, context, **kwargs):
262
+ hint_list = []
263
+ concat_hint = hint[-1]
264
+ hint_c = hint[:-1]
265
+
266
+ if not isinstance(hint_c, list):
267
+ for _ in range(len(self.channel_mult)):
268
+ hint_list.append(hint_c)
269
+ else:
270
+ hint_list = hint_c
271
+ while len(hint_list) < 4:
272
+ hint_list.append(hint_c[-1])
273
+
274
+ mask = hint_c[0][:,-1].unsqueeze(1) #panoptic
275
+
276
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
277
+ emb = self.time_embed(t_emb)
278
+
279
+ guided_hint = self.input_hint_block(concat_hint, emb, context, x.shape)
280
+ outs = []
281
+
282
+ h = x.type(self.dtype)
283
+
284
+ cnt = self.num_res_blocks[0] + 1
285
+ i = 0
286
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
287
+ if guided_hint is not None:
288
+ h = module(h, emb, context, hint_list[i], mask)
289
+ h += guided_hint
290
+ guided_hint = None
291
+ else:
292
+ h = module(h, emb, context, hint_list[i], mask)
293
+ outs.append(zero_conv(h, emb, context))
294
+
295
+ cnt -= 1
296
+ if cnt == 0:
297
+ if i<len(self.num_res_blocks):
298
+ cnt = self.num_res_blocks[i] + 1
299
+ else:
300
+ if (i+1)<len(self.num_res_blocks):
301
+ i += 1
302
+
303
+ h = self.middle_block(h, emb, context, hint_list[-1], mask)
304
+ outs.append(self.middle_block_out(h, emb, context))
305
+
306
+ return outs
cldm/ddim_hacked.py CHANGED
@@ -316,7 +316,6 @@ class DDIMSampler(object):
316
  return x_dec
317
 
318
 
319
-
320
  class DDIMSamplerSpaCFG(DDIMSampler):
321
  @torch.no_grad()
322
  def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@@ -332,8 +331,8 @@ class DDIMSamplerSpaCFG(DDIMSampler):
332
  model_uncond = self.model.apply_model(x, t, unconditional_conditioning[0])
333
  model_struct = self.model.apply_model(x, t, unconditional_conditioning[1])
334
  model_struct_app = self.model.apply_model(x, t, unconditional_conditioning[2])
335
- sT, sS, sF = unconditional_guidance_scale
336
- model_output = model_uncond + sS * (model_struct - model_uncond) + sF * (model_struct_app - model_struct) + sT * (model_t - model_struct_app)
337
 
338
  if self.model.parameterization == "v":
339
  e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
 
316
  return x_dec
317
 
318
 
 
319
  class DDIMSamplerSpaCFG(DDIMSampler):
320
  @torch.no_grad()
321
  def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
 
331
  model_uncond = self.model.apply_model(x, t, unconditional_conditioning[0])
332
  model_struct = self.model.apply_model(x, t, unconditional_conditioning[1])
333
  model_struct_app = self.model.apply_model(x, t, unconditional_conditioning[2])
334
+ sS, sF, sT = unconditional_guidance_scale
335
+ model_output = model_uncond + sS * (model_struct - model_uncond) + sF * (model_struct_app - model_struct) + sT * (model_t - model_uncond)
336
 
337
  if self.model.parameterization == "v":
338
  e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
cldm/logger.py CHANGED
@@ -114,16 +114,16 @@ class SetupCallback(Callback):
114
  OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
115
  os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
116
 
117
- else:
118
- # ModelCheckpoint callback created log directory --- remove it
119
- if not self.resume and os.path.exists(self.logdir):
120
- dst, name = os.path.split(self.logdir)
121
- dst = os.path.join(dst, "child_runs", name)
122
- os.makedirs(os.path.split(dst)[0], exist_ok=True)
123
- try:
124
- os.rename(self.logdir, dst)
125
- except FileNotFoundError:
126
- pass
127
 
128
 
129
  class ImageLogger(Callback):
 
114
  OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
115
  os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
116
 
117
+ # else:
118
+ # # ModelCheckpoint callback created log directory --- remove it
119
+ # if not self.resume and os.path.exists(self.logdir):
120
+ # dst, name = os.path.split(self.logdir)
121
+ # dst = os.path.join(dst, "child_runs", name)
122
+ # os.makedirs(os.path.split(dst)[0], exist_ok=True)
123
+ # try:
124
+ # os.rename(self.logdir, dst)
125
+ # except FileNotFoundError:
126
+ # pass
127
 
128
 
129
  class ImageLogger(Callback):
configs/{sap_fixed_hintnet_v15.yaml → pair_diff.yaml} RENAMED
@@ -1,9 +1,9 @@
1
  model:
2
- target: cldm.cldm.SAP
3
  learning_rate: 1.5e-05
4
  sd_locked: True
5
  only_mid_control: False
6
- init_ckpt: './models/sap_sd15_ini_fixed.ckpt'
7
  params:
8
  linear_start: 0.00085
9
  linear_end: 0.0120
@@ -21,14 +21,17 @@ model:
21
  scale_factor: 0.18215
22
  use_ema: False
23
  only_mid_control: False
 
 
24
 
25
  control_stage_config:
26
- target: cldm.cldm.ControlNetSAP
27
  params:
28
- input_hint_block: 'fixed'
29
  image_size: 32 # unused
30
  in_channels: 4
31
- hint_channels: 129 #(128 + 1)
 
 
32
  model_channels: 320
33
  attention_resolutions: [ 4, 2, 1 ]
34
  num_res_blocks: 2
@@ -39,6 +42,7 @@ model:
39
  context_dim: 768
40
  use_checkpoint: True
41
  legacy: False
 
42
 
43
  unet_config:
44
  target: cldm.cldm.ControlledUnetModel
@@ -87,16 +91,25 @@ model:
87
  data:
88
  target: cldm.data.DataModuleFromConfig
89
  params:
90
- batch_size: 4
91
  wrap: True
 
92
  train:
93
  target: dataset.txtseg.COCOTrain
94
  params:
 
 
 
 
95
  size: 512
96
  validation:
97
  target: dataset.txtseg.COCOValidation
98
  params:
99
  size: 512
 
 
 
 
100
 
101
 
102
  lightning:
@@ -111,4 +124,4 @@ lightning:
111
 
112
  trainer:
113
  benchmark: True
114
- accumulate_grad_batches: 4
 
1
  model:
2
+ target: cldm.cldm.PAIRDiffusion
3
  learning_rate: 1.5e-05
4
  sd_locked: True
5
  only_mid_control: False
6
+ init_ckpt: './models/pair_diff_init.ckpt'
7
  params:
8
  linear_start: 0.00085
9
  linear_end: 0.0120
 
21
  scale_factor: 0.18215
22
  use_ema: False
23
  only_mid_control: False
24
+ appearance_net_locked: True
25
+ app_net: 'DINO'
26
 
27
  control_stage_config:
28
+ target: cldm.controlnet.ControlNetPAIR
29
  params:
 
30
  image_size: 32 # unused
31
  in_channels: 4
32
+ concat_indices: [0,1]
33
+ concat_channels: 130
34
+ hint_channels: [1026, 1026, -1, -1] #(1024 + 2)
35
  model_channels: 320
36
  attention_resolutions: [ 4, 2, 1 ]
37
  num_res_blocks: 2
 
42
  context_dim: 768
43
  use_checkpoint: True
44
  legacy: False
45
+ attn_class: ['maskguided', 'maskguided', 'softmax', 'softmax']
46
 
47
  unet_config:
48
  target: cldm.cldm.ControlledUnetModel
 
91
  data:
92
  target: cldm.data.DataModuleFromConfig
93
  params:
94
+ batch_size: 2
95
  wrap: True
96
+ num_workers: 4
97
  train:
98
  target: dataset.txtseg.COCOTrain
99
  params:
100
+ image_dir:
101
+ caption_file:
102
+ panoptic_mask_dir:
103
+ seg_dir:
104
  size: 512
105
  validation:
106
  target: dataset.txtseg.COCOValidation
107
  params:
108
  size: 512
109
+ image_dir:
110
+ caption_file:
111
+ panoptic_mask_dir:
112
+ seg_dir:
113
 
114
 
115
  lightning:
 
124
 
125
  trainer:
126
  benchmark: True
127
+ accumulate_grad_batches: 2
ldm/ldm/util.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ import torch
4
+ from torch import optim
5
+ import numpy as np
6
+
7
+ from inspect import isfunction
8
+ from PIL import Image, ImageDraw, ImageFont
9
+
10
+
11
+ def log_txt_as_img(wh, xc, size=10):
12
+ # wh a tuple of (width, height)
13
+ # xc a list of captions to plot
14
+ b = len(xc)
15
+ txts = list()
16
+ for bi in range(b):
17
+ txt = Image.new("RGB", wh, color="white")
18
+ draw = ImageDraw.Draw(txt)
19
+ font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
20
+ nc = int(40 * (wh[0] / 256))
21
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
22
+
23
+ try:
24
+ draw.text((0, 0), lines, fill="black", font=font)
25
+ except UnicodeEncodeError:
26
+ print("Cant encode string for logging. Skipping.")
27
+
28
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
29
+ txts.append(txt)
30
+ txts = np.stack(txts)
31
+ txts = torch.tensor(txts)
32
+ return txts
33
+
34
+
35
+ def ismap(x):
36
+ if not isinstance(x, torch.Tensor):
37
+ return False
38
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
39
+
40
+
41
+ def isimage(x):
42
+ if not isinstance(x,torch.Tensor):
43
+ return False
44
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
45
+
46
+
47
+ def exists(x):
48
+ return x is not None
49
+
50
+
51
+ def default(val, d):
52
+ if exists(val):
53
+ return val
54
+ return d() if isfunction(d) else d
55
+
56
+
57
+ def mean_flat(tensor):
58
+ """
59
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
60
+ Take the mean over all non-batch dimensions.
61
+ """
62
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
63
+
64
+
65
+ def count_params(model, verbose=False):
66
+ total_params = sum(p.numel() for p in model.parameters())
67
+ if verbose:
68
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
69
+ return total_params
70
+
71
+
72
+ def instantiate_from_config(config):
73
+ if not "target" in config:
74
+ if config == '__is_first_stage__':
75
+ return None
76
+ elif config == "__is_unconditional__":
77
+ return None
78
+ raise KeyError("Expected key `target` to instantiate.")
79
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
80
+
81
+
82
+ def get_obj_from_str(string, reload=False):
83
+ module, cls = string.rsplit(".", 1)
84
+ if reload:
85
+ module_imp = importlib.import_module(module)
86
+ importlib.reload(module_imp)
87
+ return getattr(importlib.import_module(module, package=None), cls)
88
+
89
+
90
+ class AdamWwithEMAandWings(optim.Optimizer):
91
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
92
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
93
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
94
+ ema_power=1., param_names=()):
95
+ """AdamW that saves EMA versions of the parameters."""
96
+ if not 0.0 <= lr:
97
+ raise ValueError("Invalid learning rate: {}".format(lr))
98
+ if not 0.0 <= eps:
99
+ raise ValueError("Invalid epsilon value: {}".format(eps))
100
+ if not 0.0 <= betas[0] < 1.0:
101
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
102
+ if not 0.0 <= betas[1] < 1.0:
103
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
104
+ if not 0.0 <= weight_decay:
105
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
106
+ if not 0.0 <= ema_decay <= 1.0:
107
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
108
+ defaults = dict(lr=lr, betas=betas, eps=eps,
109
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
110
+ ema_power=ema_power, param_names=param_names)
111
+ super().__init__(params, defaults)
112
+
113
+ def __setstate__(self, state):
114
+ super().__setstate__(state)
115
+ for group in self.param_groups:
116
+ group.setdefault('amsgrad', False)
117
+
118
+ @torch.no_grad()
119
+ def step(self, closure=None):
120
+ """Performs a single optimization step.
121
+ Args:
122
+ closure (callable, optional): A closure that reevaluates the model
123
+ and returns the loss.
124
+ """
125
+ loss = None
126
+ if closure is not None:
127
+ with torch.enable_grad():
128
+ loss = closure()
129
+
130
+ for group in self.param_groups:
131
+ params_with_grad = []
132
+ grads = []
133
+ exp_avgs = []
134
+ exp_avg_sqs = []
135
+ ema_params_with_grad = []
136
+ state_sums = []
137
+ max_exp_avg_sqs = []
138
+ state_steps = []
139
+ amsgrad = group['amsgrad']
140
+ beta1, beta2 = group['betas']
141
+ ema_decay = group['ema_decay']
142
+ ema_power = group['ema_power']
143
+
144
+ for p in group['params']:
145
+ if p.grad is None:
146
+ continue
147
+ params_with_grad.append(p)
148
+ if p.grad.is_sparse:
149
+ raise RuntimeError('AdamW does not support sparse gradients')
150
+ grads.append(p.grad)
151
+
152
+ state = self.state[p]
153
+
154
+ # State initialization
155
+ if len(state) == 0:
156
+ state['step'] = 0
157
+ # Exponential moving average of gradient values
158
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
159
+ # Exponential moving average of squared gradient values
160
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
161
+ if amsgrad:
162
+ # Maintains max of all exp. moving avg. of sq. grad. values
163
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
164
+ # Exponential moving average of parameter values
165
+ state['param_exp_avg'] = p.detach().float().clone()
166
+
167
+ exp_avgs.append(state['exp_avg'])
168
+ exp_avg_sqs.append(state['exp_avg_sq'])
169
+ ema_params_with_grad.append(state['param_exp_avg'])
170
+
171
+ if amsgrad:
172
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
173
+
174
+ # update the steps for each param group update
175
+ state['step'] += 1
176
+ # record the step after step update
177
+ state_steps.append(state['step'])
178
+
179
+ optim._functional.adamw(params_with_grad,
180
+ grads,
181
+ exp_avgs,
182
+ exp_avg_sqs,
183
+ max_exp_avg_sqs,
184
+ state_steps,
185
+ amsgrad=amsgrad,
186
+ beta1=beta1,
187
+ beta2=beta2,
188
+ lr=group['lr'],
189
+ weight_decay=group['weight_decay'],
190
+ eps=group['eps'],
191
+ maximize=False)
192
+
193
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
194
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
195
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
196
+
197
+ return loss
ldm/models/diffusion/ddim.py CHANGED
@@ -194,9 +194,19 @@ class DDIMSampler(object):
194
  c_in = dict()
195
  for k in c:
196
  if isinstance(c[k], list):
197
- c_in[k] = [torch.cat([
198
- unconditional_conditioning[k][i],
199
- c[k][i]]) for i in range(len(c[k]))]
 
 
 
 
 
 
 
 
 
 
200
  else:
201
  c_in[k] = torch.cat([
202
  unconditional_conditioning[k],
@@ -333,4 +343,5 @@ class DDIMSampler(object):
333
  unconditional_guidance_scale=unconditional_guidance_scale,
334
  unconditional_conditioning=unconditional_conditioning)
335
  if callback: callback(i)
336
- return x_dec
 
 
194
  c_in = dict()
195
  for k in c:
196
  if isinstance(c[k], list):
197
+ c_in[k] = []
198
+ if isinstance(c[k][0], list):
199
+ for i in range(len(c[k])):
200
+ c_ = []
201
+ for j in range(len(c[k][i])):
202
+ c_.append(torch.cat([
203
+ unconditional_conditioning[k][i][j],
204
+ c[k][i][j]]) )
205
+ c_in[k].append(c_)
206
+ else:
207
+ c_in[k] = [torch.cat([
208
+ unconditional_conditioning[k][i],
209
+ c[k][i]]) for i in range(len(c[k]))]
210
  else:
211
  c_in[k] = torch.cat([
212
  unconditional_conditioning[k],
 
343
  unconditional_guidance_scale=unconditional_guidance_scale,
344
  unconditional_conditioning=unconditional_conditioning)
345
  if callback: callback(i)
346
+ return x_dec
347
+
ldm/modules/attention.py CHANGED
@@ -42,7 +42,7 @@ def init_(tensor):
42
  dim = tensor.shape[-1]
43
  std = 1 / math.sqrt(dim)
44
  tensor.uniform_(-std, std)
45
- return tensor
46
 
47
 
48
  # feedforward
@@ -143,7 +143,7 @@ class SpatialSelfAttention(nn.Module):
143
 
144
 
145
  class CrossAttention(nn.Module):
146
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
147
  super().__init__()
148
  inner_dim = dim_head * heads
149
  context_dim = default(context_dim, query_dim)
@@ -160,7 +160,7 @@ class CrossAttention(nn.Module):
160
  nn.Dropout(dropout)
161
  )
162
 
163
- def forward(self, x, context=None, mask=None):
164
  h = self.heads
165
 
166
  q = self.to_q(x)
@@ -194,6 +194,34 @@ class CrossAttention(nn.Module):
194
  return self.to_out(out)
195
 
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  class MemoryEfficientCrossAttention(nn.Module):
198
  # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
199
  def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
@@ -246,17 +274,19 @@ class MemoryEfficientCrossAttention(nn.Module):
246
  class BasicTransformerBlock(nn.Module):
247
  ATTENTION_MODES = {
248
  "softmax": CrossAttention, # vanilla attention
249
- "softmax-xformers": MemoryEfficientCrossAttention
 
250
  }
251
  def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
252
- disable_self_attn=False):
253
  super().__init__()
254
  attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
255
  assert attn_mode in self.ATTENTION_MODES
256
  attn_cls = self.ATTENTION_MODES[attn_mode]
 
257
  self.disable_self_attn = disable_self_attn
258
- self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
259
- context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
260
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
261
  self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
262
  heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
@@ -265,11 +295,17 @@ class BasicTransformerBlock(nn.Module):
265
  self.norm3 = nn.LayerNorm(dim)
266
  self.checkpoint = checkpoint
267
 
268
- def forward(self, x, context=None):
269
- return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
 
 
 
 
 
270
 
271
- def _forward(self, x, context=None):
272
- x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
 
273
  x = self.attn2(self.norm2(x), context=context) + x
274
  x = self.ff(self.norm3(x)) + x
275
  return x
@@ -287,7 +323,7 @@ class SpatialTransformer(nn.Module):
287
  def __init__(self, in_channels, n_heads, d_head,
288
  depth=1, dropout=0., context_dim=None,
289
  disable_self_attn=False, use_linear=False,
290
- use_checkpoint=True):
291
  super().__init__()
292
  if exists(context_dim) and not isinstance(context_dim, list):
293
  context_dim = [context_dim]
@@ -305,7 +341,8 @@ class SpatialTransformer(nn.Module):
305
 
306
  self.transformer_blocks = nn.ModuleList(
307
  [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
308
- disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
 
309
  for d in range(depth)]
310
  )
311
  if not use_linear:
@@ -318,11 +355,20 @@ class SpatialTransformer(nn.Module):
318
  self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
319
  self.use_linear = use_linear
320
 
321
- def forward(self, x, context=None):
322
  # note: if no context is given, cross-attention defaults to self-attention
323
  if not isinstance(context, list):
324
  context = [context]
 
 
 
 
 
325
  b, c, h, w = x.shape
 
 
 
 
326
  x_in = x
327
  x = self.norm(x)
328
  if not self.use_linear:
@@ -331,7 +377,7 @@ class SpatialTransformer(nn.Module):
331
  if self.use_linear:
332
  x = self.proj_in(x)
333
  for i, block in enumerate(self.transformer_blocks):
334
- x = block(x, context=context[i])
335
  if self.use_linear:
336
  x = self.proj_out(x)
337
  x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
 
42
  dim = tensor.shape[-1]
43
  std = 1 / math.sqrt(dim)
44
  tensor.uniform_(-std, std)
45
+ return tensor
46
 
47
 
48
  # feedforward
 
143
 
144
 
145
  class CrossAttention(nn.Module):
146
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., **kargs):
147
  super().__init__()
148
  inner_dim = dim_head * heads
149
  context_dim = default(context_dim, query_dim)
 
160
  nn.Dropout(dropout)
161
  )
162
 
163
+ def forward(self, x, context=None, mask=None, **kargs):
164
  h = self.heads
165
 
166
  q = self.to_q(x)
 
194
  return self.to_out(out)
195
 
196
 
197
+ class MaskGuidedSelfAttention(nn.Module):
198
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., obj_feat_dim=1024):
199
+ super().__init__()
200
+ #here context dim is for object features coming from image encoder
201
+ inner_dim = dim_head * heads
202
+ self.heads = heads
203
+
204
+ self.obj_feats_map = nn.Linear(obj_feat_dim, inner_dim)
205
+ self.to_v = nn.Linear(inner_dim, inner_dim, bias=False)
206
+
207
+ self.to_out = nn.Sequential(
208
+ nn.Linear(inner_dim, query_dim),
209
+ nn.Dropout(dropout)
210
+ )
211
+
212
+ self.scale = dim_head ** -0.5
213
+
214
+ def forward(self, x, context=None, mask=None, obj_mask=None, obj_feat=None):
215
+ _, _, ht, wd = obj_feat.shape
216
+ obj_feat = rearrange(obj_feat, 'b c h w -> b (h w) c').contiguous()
217
+ obj_feat = self.obj_feats_map(obj_feat)
218
+ v = self.to_v(obj_feat)
219
+ return self.to_out(v)
220
+
221
+
222
+
223
+
224
+
225
  class MemoryEfficientCrossAttention(nn.Module):
226
  # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
227
  def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
 
274
  class BasicTransformerBlock(nn.Module):
275
  ATTENTION_MODES = {
276
  "softmax": CrossAttention, # vanilla attention
277
+ "softmax-xformers": MemoryEfficientCrossAttention,
278
+ "maskguided": MaskGuidedSelfAttention
279
  }
280
  def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
281
+ disable_self_attn=False, attn1_mode="softmax", obj_feat_dim=1024):
282
  super().__init__()
283
  attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
284
  assert attn_mode in self.ATTENTION_MODES
285
  attn_cls = self.ATTENTION_MODES[attn_mode]
286
+ attn1_cls = self.ATTENTION_MODES[attn1_mode]
287
  self.disable_self_attn = disable_self_attn
288
+ self.attn1 = attn1_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
289
+ context_dim=context_dim if self.disable_self_attn else None, obj_feat_dim=obj_feat_dim) # is a self-attention if not self.disable_self_attn
290
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
291
  self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
292
  heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
 
295
  self.norm3 = nn.LayerNorm(dim)
296
  self.checkpoint = checkpoint
297
 
298
+ # self.ff_text_obj_feat = FeedForward(context_dim, dim_out=dim, mult=1, dropout=dropout, glu=gated_ff)
299
+
300
+ def forward(self, x, context=None, obj_mask=None, obj_feat=None):
301
+ if obj_mask is None:
302
+ # return self._forward(x, context, obj_mask, obj_feat)
303
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
304
+ return checkpoint(self._forward, (x, context, obj_mask, obj_feat), self.parameters(), self.checkpoint)
305
 
306
+ def _forward(self, x, context=None, obj_mask=None, obj_feat=None):
307
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None,
308
+ obj_mask=obj_mask, obj_feat=obj_feat) + x
309
  x = self.attn2(self.norm2(x), context=context) + x
310
  x = self.ff(self.norm3(x)) + x
311
  return x
 
323
  def __init__(self, in_channels, n_heads, d_head,
324
  depth=1, dropout=0., context_dim=None,
325
  disable_self_attn=False, use_linear=False,
326
+ use_checkpoint=True,attn1_mode='softmax',obj_feat_dim=None):
327
  super().__init__()
328
  if exists(context_dim) and not isinstance(context_dim, list):
329
  context_dim = [context_dim]
 
341
 
342
  self.transformer_blocks = nn.ModuleList(
343
  [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
344
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn1_mode=attn1_mode,
345
+ obj_feat_dim=obj_feat_dim)
346
  for d in range(depth)]
347
  )
348
  if not use_linear:
 
355
  self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
356
  self.use_linear = use_linear
357
 
358
+ def forward(self, x, context=None, obj_masks=None, obj_feats=None):
359
  # note: if no context is given, cross-attention defaults to self-attention
360
  if not isinstance(context, list):
361
  context = [context]
362
+ if not isinstance(obj_masks, list):
363
+ obj_masks = [obj_masks]
364
+ if not isinstance(obj_feats, list):
365
+ obj_feats = [obj_feats]
366
+
367
  b, c, h, w = x.shape
368
+ if obj_feats[0] is not None:
369
+ obj_feats = [torch.nn.functional.interpolate(ofe, [h,w]) for ofe in obj_feats]
370
+ obj_masks = [torch.nn.functional.interpolate(om, [h,w]) for om in obj_masks]
371
+
372
  x_in = x
373
  x = self.norm(x)
374
  if not self.use_linear:
 
377
  if self.use_linear:
378
  x = self.proj_in(x)
379
  for i, block in enumerate(self.transformer_blocks):
380
+ x = block(x, context=context[i], obj_mask=obj_masks[i], obj_feat=obj_feats[i])
381
  if self.use_linear:
382
  x = self.proj_out(x)
383
  x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
ldm/modules/diffusionmodules/openaimodel.py CHANGED
@@ -69,19 +69,31 @@ class TimestepBlock(nn.Module):
69
  Apply the module to `x` given `emb` timestep embeddings.
70
  """
71
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
74
  """
75
  A sequential module that passes timestep embeddings to the children that
76
  support it as an extra input.
77
  """
78
 
79
- def forward(self, x, emb, context=None, *args):
80
  for layer in self:
81
  if isinstance(layer, TimestepBlock):
82
  x = layer(x, emb)
83
  elif isinstance(layer, SpatialTransformer):
84
- x = layer(x, context)
 
 
85
  else:
86
  x = layer(x)
87
  return x
@@ -783,4 +795,4 @@ class UNetModel(nn.Module):
783
  if self.predict_codebook_ids:
784
  return self.id_predictor(h)
785
  else:
786
- return self.out(h)
 
69
  Apply the module to `x` given `emb` timestep embeddings.
70
  """
71
 
72
+ class TimestepBlockSpa(nn.Module):
73
+ """
74
+ Any module where forward() takes timestep embeddings as a second argument.
75
+ """
76
+
77
+ @abstractmethod
78
+ def forward(self, x, emb, obj_feat):
79
+ """
80
+ Apply the module to `x` given `emb` timestep embeddings.
81
+ """
82
 
83
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock, TimestepBlockSpa):
84
  """
85
  A sequential module that passes timestep embeddings to the children that
86
  support it as an extra input.
87
  """
88
 
89
+ def forward(self, x, emb, context=None, obj_feat=None,obj_masks=None, *args):
90
  for layer in self:
91
  if isinstance(layer, TimestepBlock):
92
  x = layer(x, emb)
93
  elif isinstance(layer, SpatialTransformer):
94
+ x = layer(x, context, obj_masks=obj_masks, obj_feats=obj_feat)
95
+ elif isinstance(layer, TimestepBlockSpa):
96
+ x = layer(x, emb, obj_feat)
97
  else:
98
  x = layer(x)
99
  return x
 
795
  if self.predict_codebook_ids:
796
  return self.id_predictor(h)
797
  else:
798
+ return self.out(h)
ldm/modules/diffusionmodules/util.py CHANGED
@@ -215,9 +215,10 @@ class SiLU(nn.Module):
215
 
216
 
217
  class GroupNorm32(nn.GroupNorm):
218
- def forward(self, x):
219
  return super().forward(x.float()).type(x.dtype)
220
 
 
221
  def conv_nd(dims, *args, **kwargs):
222
  """
223
  Create a 1D, 2D, or 3D convolution module.
 
215
 
216
 
217
  class GroupNorm32(nn.GroupNorm):
218
+ def forward(self, x, *args):
219
  return super().forward(x.float()).type(x.dtype)
220
 
221
+
222
  def conv_nd(dims, *args, **kwargs):
223
  """
224
  Create a 1D, 2D, or 3D convolution module.
ldm/modules/encoders/modules.py CHANGED
@@ -114,14 +114,14 @@ class FrozenCLIPEmbedder(AbstractEncoder):
114
  for param in self.parameters():
115
  param.requires_grad = False
116
 
117
- def forward(self, text):
118
  batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
119
  return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
120
  tokens = batch_encoding["input_ids"].to(self.device)
121
  outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
122
- if self.layer == "last":
123
  z = outputs.last_hidden_state
124
- elif self.layer == "pooled":
125
  z = outputs.pooler_output[:, None, :]
126
  else:
127
  z = outputs.hidden_states[self.layer_idx]
 
114
  for param in self.parameters():
115
  param.requires_grad = False
116
 
117
+ def forward(self, text, layer='last'):
118
  batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
119
  return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
120
  tokens = batch_encoding["input_ids"].to(self.device)
121
  outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
122
+ if layer == "last":
123
  z = outputs.last_hidden_state
124
+ elif layer == "pooled":
125
  z = outputs.pooler_output[:, None, :]
126
  else:
127
  z = outputs.hidden_states[self.layer_idx]
pair_diff_demo.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import einops
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ import random
7
+ import os
8
+ import json
9
+ import datetime
10
+ from huggingface_hub import hf_hub_url, hf_hub_download
11
+
12
+ from pytorch_lightning import seed_everything
13
+ from annotator.util import resize_image, HWC3
14
+ from annotator.OneFormer import OneformerSegmenter
15
+ from cldm.model import create_model, load_state_dict
16
+ from cldm.ddim_hacked import DDIMSamplerSpaCFG
17
+ from ldm.models.autoencoder import DiagonalGaussianDistribution
18
+
19
+
20
+ SEGMENT_MODEL_DICT = {
21
+ 'Oneformer': OneformerSegmenter,
22
+ }
23
+
24
+ MASK_MODEL_DICT = {
25
+ 'Oneformer': OneformerSegmenter,
26
+ }
27
+
28
+ urls = {
29
+ 'shi-labs/oneformer_coco_swin_large': ['150_16_swin_l_oneformer_coco_100ep.pth'],
30
+ 'PAIR/PAIR-diffusion-sdv15-coco-finetune': ['model_e91.ckpt']
31
+ }
32
+
33
+ WTS_DICT = {
34
+
35
+ }
36
+
37
+ if os.path.exists('checkpoints') == False:
38
+ os.mkdir('checkpoints')
39
+ for repo in urls:
40
+ files = urls[repo]
41
+ for file in files:
42
+ url = hf_hub_url(repo, file)
43
+ name_ckp = url.split('/')[-1]
44
+
45
+ WTS_DICT[repo] = hf_hub_download(repo_id=repo, filename=file)
46
+
47
+
48
+ #main model
49
+ model = create_model('configs/pair_diff.yaml').cpu()
50
+ model.load_state_dict(load_state_dict(WTS_DICT['PAIR/PAIR-diffusion-sdv15-coco-finetune'], location='cuda'))
51
+
52
+ save_dir = 'results/'
53
+
54
+ model = model.cuda()
55
+ ddim_sampler = DDIMSamplerSpaCFG(model)
56
+ save_memory = False
57
+
58
+
59
+ class ImageComp:
60
+ def __init__(self, edit_operation):
61
+ self.input_img = None
62
+ self.input_pmask = None
63
+ self.input_segmask = None
64
+ self.input_mask = None
65
+ self.input_points = []
66
+ self.input_scale = 1
67
+
68
+ self.ref_img = None
69
+ self.ref_pmask = None
70
+ self.ref_segmask = None
71
+ self.ref_mask = None
72
+ self.ref_points = []
73
+ self.ref_scale = 1
74
+
75
+ self.multi_modal = False
76
+
77
+ self.H = None
78
+ self.W = None
79
+ self.kernel = np.ones((5, 5), np.uint8)
80
+ self.edit_operation = edit_operation
81
+ self.init_segmentation_model()
82
+ os.makedirs(save_dir, exist_ok=True)
83
+
84
+ self.base_prompt = 'A picture of {}'
85
+
86
+ def init_segmentation_model(self, mask_model='Oneformer', segment_model='Oneformer'):
87
+ self.segment_model_name = segment_model
88
+ self.mask_model_name = mask_model
89
+
90
+ self.segment_model = SEGMENT_MODEL_DICT[segment_model](WTS_DICT['shi-labs/oneformer_coco_swin_large'])
91
+
92
+ if mask_model == 'Oneformer' and segment_model == 'Oneformer':
93
+ self.mask_model_inp = self.segment_model
94
+ self.mask_model_ref = self.segment_model
95
+ else:
96
+ self.mask_model_inp = MASK_MODEL_DICT[mask_model]()
97
+ self.mask_model_ref = MASK_MODEL_DICT[mask_model]()
98
+
99
+ print(f"Segmentation Models initialized with {mask_model} as mask and {segment_model} as segment")
100
+
101
+ def init_input_canvas(self, img):
102
+
103
+ img = HWC3(img)
104
+ img = resize_image(img, 512)
105
+ if self.segment_model_name == 'Oneformer':
106
+ detected_seg = self.segment_model(img, 'semantic')
107
+ elif self.segment_model_name == 'SAM':
108
+ raise NotImplementedError
109
+
110
+ if self.mask_model_name == 'Oneformer':
111
+ detected_mask = self.mask_model_inp(img, 'panoptic')[0]
112
+ elif self.mask_model_name == 'SAM':
113
+ detected_mask = self.mask_model_inp(img)
114
+
115
+ self.input_points = []
116
+ self.input_img = img
117
+ self.input_pmask = detected_mask
118
+ self.input_segmask = detected_seg
119
+ self.H = img.shape[0]
120
+ self.W = img.shape[1]
121
+
122
+ return img
123
+
124
+ def init_ref_canvas(self, img):
125
+
126
+ img = HWC3(img)
127
+ img = resize_image(img, 512)
128
+ if self.segment_model_name == 'Oneformer':
129
+ detected_seg = self.segment_model(img, 'semantic')
130
+ elif self.segment_model_name == 'SAM':
131
+ raise NotImplementedError
132
+
133
+ if self.mask_model_name == 'Oneformer':
134
+ detected_mask = self.mask_model_ref(img, 'panoptic')[0]
135
+ elif self.mask_model_name == 'SAM':
136
+ detected_mask = self.mask_model_ref(img)
137
+
138
+ self.ref_points = []
139
+ print("Initialized ref", img.shape)
140
+ self.ref_img = img
141
+ self.ref_pmask = detected_mask
142
+ self.ref_segmask = detected_seg
143
+
144
+ return img
145
+
146
+ def select_input_object(self, evt: gr.SelectData):
147
+ idx = list(np.array(evt.index) * self.input_scale)
148
+ self.input_points.append(idx)
149
+ if self.mask_model_name == 'Oneformer':
150
+ mask = self._get_mask_from_panoptic(np.array(self.input_points), self.input_pmask)
151
+ else:
152
+ mask = self.mask_model_inp(self.input_img, self.input_points)
153
+
154
+ c_ids = self.input_segmask[np.array(self.input_points)[:,1], np.array(self.input_points)[:,0]]
155
+ unique_ids, counts = torch.unique(c_ids, return_counts=True)
156
+ c_id = int(unique_ids[torch.argmax(counts)].cpu().detach().numpy())
157
+ category = self.segment_model.metadata.stuff_classes[c_id]
158
+ # print(self.segment_model.metadata.stuff_classes)
159
+
160
+ self.input_mask = mask
161
+ mask = mask.cpu().numpy()
162
+ output = mask[:,:,None] * self.input_img + (1 - mask[:,:,None]) * self.input_img * 0.2
163
+ return output.astype(np.uint8), self.base_prompt.format(category)
164
+
165
+ def select_ref_object(self, evt: gr.SelectData):
166
+ idx = list(np.array(evt.index) * self.ref_scale)
167
+ self.ref_points.append(idx)
168
+ if self.mask_model_name == 'Oneformer':
169
+ mask = self._get_mask_from_panoptic(np.array(self.ref_points), self.ref_pmask)
170
+ else:
171
+ mask = self.mask_model_ref(self.ref_img, self.ref_points)
172
+ c_ids = self.ref_segmask[np.array(self.ref_points)[:,1], np.array(self.ref_points)[:,0]]
173
+ unique_ids, counts = torch.unique(c_ids, return_counts=True)
174
+ c_id = int(unique_ids[torch.argmax(counts)].cpu().detach().numpy())
175
+ category = self.segment_model.metadata.stuff_classes[c_id]
176
+ print("Category of reference object is:", category)
177
+
178
+ self.ref_mask = mask
179
+ mask = mask.cpu().numpy()
180
+ output = mask[:,:,None] * self.ref_img + (1 - mask[:,:,None]) * self.ref_img * 0.2
181
+ return output.astype(np.uint8)
182
+
183
+ def clear_points(self):
184
+ self.input_points = []
185
+ self.ref_points = []
186
+ zeros_inp = np.zeros(self.input_img.shape)
187
+ zeros_ref = np.zeros(self.ref_img.shape)
188
+ return zeros_inp, zeros_ref
189
+
190
+ def return_input_img(self):
191
+ return self.input_img
192
+
193
+
194
+ def _get_mask_from_panoptic(self, points, panoptic_mask):
195
+ panoptic_mask_ = panoptic_mask + 1
196
+ ids = panoptic_mask_[points[:,1], points[:,0]]
197
+ unique_ids, counts = torch.unique(ids, return_counts=True)
198
+ mask_id = unique_ids[torch.argmax(counts)]
199
+ final_mask = torch.zeros(panoptic_mask.shape).cuda()
200
+ final_mask[panoptic_mask_ == mask_id] = 1
201
+
202
+ return final_mask
203
+
204
+
205
+ def _process_mask(self, mask, panoptic_mask, segmask):
206
+ obj_class = mask * (segmask + 1)
207
+ unique_ids, counts = torch.unique(obj_class, return_counts=True)
208
+ obj_class = unique_ids[torch.argmax(counts[1:]) + 1] - 1
209
+ return mask, obj_class
210
+
211
+
212
+ def _edit_app(self, whole_ref):
213
+ """
214
+ Manipulates the panoptic mask of input image to change appearance
215
+ """
216
+ input_pmask = self.input_pmask
217
+ input_segmask = self.input_segmask
218
+
219
+ if whole_ref:
220
+ reference_mask = torch.ones(self.ref_pmask.shape).cuda()
221
+ else:
222
+ reference_mask, _ = self._process_mask(self.ref_mask, self.ref_pmask, self.ref_segmask)
223
+
224
+ edit_mask, _ = self._process_mask(self.input_mask, self.input_pmask, self.input_segmask)
225
+ # tmp = cv2.dilate(edit_mask.squeeze().cpu().numpy(), self.kernel, iterations = 2)
226
+ # region_mask = torch.tensor(tmp).cuda()
227
+ region_mask = edit_mask
228
+ ma = torch.max(input_pmask)
229
+
230
+ input_pmask[edit_mask == 1] = ma + 1
231
+ return reference_mask, input_pmask, input_segmask, region_mask, ma
232
+
233
+ def _add_object(self, input_mask, dilation_fac):
234
+ """
235
+ Manipulates the panooptic mask of input image for adding objects
236
+ Args:
237
+ input_mask (numpy array): Region where new objects needs to be added
238
+ dilation factor (float): Controls edge merging region for adding objects
239
+
240
+ """
241
+ input_pmask = self.input_pmask
242
+ input_segmask = self.input_segmask
243
+ reference_mask, obj_class = self._process_mask(self.ref_mask, self.ref_pmask, self.ref_segmask)
244
+
245
+ tmp = cv2.dilate(input_mask['mask'][:, :, 0], self.kernel, iterations = int(dilation_fac))
246
+ region = torch.tensor(tmp)
247
+ region_mask = torch.zeros_like(region).cuda()
248
+ region_mask[region > 127] = 1
249
+
250
+ mask_ = torch.tensor(input_mask['mask'][:, :, 0])
251
+ edit_mask = torch.zeros_like(mask_).cuda()
252
+ edit_mask[mask_ > 127] = 1
253
+ ma = torch.max(input_pmask)
254
+ input_pmask[edit_mask == 1] = ma + 1
255
+ print(obj_class)
256
+ input_segmask[edit_mask == 1] = obj_class.long()
257
+
258
+ return reference_mask, input_pmask, input_segmask, region_mask, ma
259
+
260
+ def _edit(self, input_mask, ref_mask, dilation_fac=1, whole_ref=False, inter=1):
261
+ """
262
+ Entry point for all the appearance editing and add objects operations. The function manipulates the
263
+ appearance vectors and structure based on user input
264
+ Args:
265
+ input mask (numpy array): Region in input image which needs to be edited
266
+ dilation factor (float): Controls edge merging region for adding objects
267
+ whole_ref (bool): Flag for specifying if complete reference image should be used
268
+ inter (float): Interpolation of appearance between the reference appearance and the input appearance.
269
+ """
270
+ input_img = (self.input_img/127.5 - 1)
271
+ input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
272
+
273
+ reference_img = (self.ref_img/127.5 - 1)
274
+ reference_img = torch.from_numpy(reference_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
275
+
276
+ if self.edit_operation == 'add_obj':
277
+ reference_mask, input_pmask, input_segmask, region_mask, ma = self._add_object(input_mask, dilation_fac)
278
+ elif self.edit_operation == 'edit_app':
279
+ reference_mask, input_pmask, input_segmask, region_mask, ma = self._edit_app(whole_ref)
280
+
281
+ #concat featurees
282
+ input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
283
+ _, mean_feat_inpt_conc, one_hot_inpt_conc, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, input_img, input_pmask, return_all=True)
284
+
285
+ reference_mask = reference_mask.float().cuda().unsqueeze(0).unsqueeze(1)
286
+ _, mean_feat_ref_conc, _, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, reference_img, reference_mask, return_all=True)
287
+
288
+ # if mean_feat_ref.shape[1] > 1:
289
+ if isinstance(mean_feat_inpt_conc, list):
290
+ appearance_conc = []
291
+ for i in range(len(mean_feat_inpt_conc)):
292
+ mean_feat_inpt_conc[i][:, ma + 1] = (1 - inter) * mean_feat_inpt_conc[i][:, ma + 1] + inter*mean_feat_ref_conc[i][:, 1]
293
+ splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_conc[i], one_hot_inpt_conc)
294
+ splatted_feat_conc = torch.nn.functional.normalize(splatted_feat_conc)
295
+ splatted_feat_conc = torch.nn.functional.interpolate(splatted_feat_conc, (self.H//8, self.W//8))
296
+ appearance_conc.append(splatted_feat_conc)
297
+ appearance_conc = torch.cat(appearance_conc, dim=1)
298
+ else:
299
+ print("manipulating")
300
+ mean_feat_inpt_conc[:, ma + 1] = (1 - inter) * mean_feat_inpt_conc[:, ma + 1] + inter*mean_feat_ref_conc[:, 1]
301
+
302
+ splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_conc, one_hot_inpt_conc)
303
+ appearance_conc = torch.nn.functional.normalize(splatted_feat_conc) #l2 normaliz
304
+ appearance_conc = torch.nn.functional.interpolate(appearance_conc, (self.H//8, self.W//8))
305
+
306
+ #cross attention features
307
+ _, mean_feat_inpt_ca, one_hot_inpt_ca, _ = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, input_img, input_pmask, return_all=True)
308
+
309
+ _, mean_feat_ref_ca, _, _ = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, reference_img, reference_mask, return_all=True)
310
+
311
+ # if mean_feat_ref.shape[1] > 1:
312
+ if isinstance(mean_feat_inpt_ca, list):
313
+ appearance_ca = []
314
+ for i in range(len(mean_feat_inpt_ca)):
315
+ mean_feat_inpt_ca[i][:, ma + 1] = (1 - inter) * mean_feat_inpt_ca[i][:, ma + 1] + inter*mean_feat_ref_ca[i][:, 1]
316
+ splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_ca[i], one_hot_inpt_ca)
317
+ splatted_feat_ca = torch.nn.functional.normalize(splatted_feat_ca)
318
+ splatted_feat_ca = torch.nn.functional.interpolate(splatted_feat_ca, (self.H//8, self.W//8))
319
+ appearance_ca.append(splatted_feat_ca)
320
+ else:
321
+ print("manipulating")
322
+ mean_feat_inpt_ca[:, ma + 1] = (1 - inter) * mean_feat_inpt_ca[:, ma + 1] + inter*mean_feat_ref_ca[:, 1]
323
+
324
+ splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_ca, one_hot_inpt_ca)
325
+ appearance_ca = torch.nn.functional.normalize(splatted_feat_ca) #l2 normaliz
326
+ appearance_ca = torch.nn.functional.interpolate(appearance_ca, (self.H//8, self.W//8))
327
+
328
+
329
+
330
+ input_segmask = ((input_segmask+1)/ 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1)
331
+ structure = torch.nn.functional.interpolate(input_segmask, (self.H//8, self.W//8))
332
+
333
+
334
+ return structure, appearance_conc, appearance_ca, region_mask, input_img
335
+
336
+ def _edit_obj_var(self, input_mask, ignore_structure):
337
+ input_img = (self.input_img/127.5 - 1)
338
+ input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
339
+
340
+
341
+ input_pmask = self.input_pmask
342
+ input_segmask = self.input_segmask
343
+
344
+ ma = torch.max(input_pmask)
345
+ mask_ = torch.tensor(input_mask['mask'][:, :, 0])
346
+ edit_mask = torch.zeros_like(mask_).cuda()
347
+ edit_mask[mask_ > 127] = 1
348
+ tmp = edit_mask * (input_pmask + ma + 1)
349
+ if ignore_structure:
350
+ tmp = edit_mask
351
+
352
+ input_pmask = tmp * edit_mask + (1 - edit_mask) * input_pmask
353
+
354
+ input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
355
+
356
+ mask_ca_feat = self.input_pmask.float().cuda().unsqueeze(0).unsqueeze(1) if ignore_structure else input_pmask
357
+ print(torch.unique(mask_ca_feat))
358
+
359
+ appearance_conc,_,_,_ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, input_img, input_pmask, return_all=True)
360
+ appearance_ca = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, input_img, mask_ca_feat)
361
+
362
+ appearance_conc = torch.nn.functional.interpolate(appearance_conc, (self.H//8, self.W//8))
363
+ appearance_ca = [torch.nn.functional.interpolate(ap, (self.H//8, self.W//8)) for ap in appearance_ca]
364
+
365
+ input_segmask = ((input_segmask+1)/ 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1)
366
+ structure = torch.nn.functional.interpolate(input_segmask, (self.H//8, self.W//8))
367
+
368
+
369
+ tmp = input_mask['mask'][:, :, 0]
370
+ region = torch.tensor(tmp)
371
+ mask = torch.zeros_like(region).cuda()
372
+ mask[region > 127] = 1
373
+
374
+ return structure, appearance_conc, appearance_ca, mask, input_img
375
+
376
+ def get_caption(self, mask):
377
+ """
378
+ Generates the captions based on a set template
379
+ Args:
380
+ mask (numpy array): Region of image based on which caption needs to be generated
381
+ """
382
+ mask = mask['mask'][:, :, 0]
383
+ region = torch.tensor(mask).cuda()
384
+ mask = torch.zeros_like(region)
385
+ mask[region > 127] = 1
386
+
387
+ if torch.sum(mask) == 0:
388
+ return ""
389
+
390
+ c_ids = self.input_segmask * mask
391
+ unique_ids, counts = torch.unique(c_ids, return_counts=True)
392
+ c_id = int(unique_ids[torch.argmax(counts[1:]) + 1].cpu().detach().numpy())
393
+ category = self.segment_model.metadata.stuff_classes[c_id]
394
+
395
+ return self.base_prompt.format(category)
396
+
397
+ def save_result(self, input_mask, prompt, a_prompt, n_prompt,
398
+ ddim_steps, scale_s, scale_f, scale_t, seed, dilation_fac=1,inter=1,
399
+ free_form_obj_var=False, ignore_structure=False):
400
+ """
401
+ Saves the current results with all the meta data
402
+ """
403
+
404
+ meta_data = {}
405
+ meta_data['prompt'] = prompt
406
+ meta_data['a_prompt'] = a_prompt
407
+ meta_data['n_prompt'] = n_prompt
408
+ meta_data['seed'] = seed
409
+ meta_data['ddim_steps'] = ddim_steps
410
+ meta_data['scale_s'] = scale_s
411
+ meta_data['scale_f'] = scale_f
412
+ meta_data['scale_t'] = scale_t
413
+ meta_data['inter'] = inter
414
+ meta_data['dilation_fac'] = dilation_fac
415
+ meta_data['edit_operation'] = self.edit_operation
416
+
417
+ uuid = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
418
+ os.makedirs(f'{save_dir}/{uuid}')
419
+
420
+ with open(f'{save_dir}/{uuid}/meta.json', "w") as outfile:
421
+ json.dump(meta_data, outfile)
422
+ cv2.imwrite(f'{save_dir}/{uuid}/input.png', self.input_img[:,:,::-1])
423
+ cv2.imwrite(f'{save_dir}/{uuid}/ref.png', self.ref_img[:,:,::-1])
424
+ if self.ref_mask is not None:
425
+ cv2.imwrite(f'{save_dir}/{uuid}/ref_mask.png', self.ref_mask.cpu().squeeze().numpy() * 200)
426
+ for i in range(len(self.results)):
427
+ cv2.imwrite(f'{save_dir}/{uuid}/edit{i}.png', self.results[i][:,:,::-1])
428
+
429
+ if self.edit_operation == 'add_obj' or free_form_obj_var:
430
+ cv2.imwrite(f'{save_dir}/{uuid}/input_mask.png', input_mask['mask'] * 200)
431
+ else:
432
+ cv2.imwrite(f'{save_dir}/{uuid}/input_mask.png', self.input_mask.cpu().squeeze().numpy() * 200)
433
+
434
+ print("Saved results at", f'{save_dir}/{uuid}')
435
+
436
+
437
+ def process(self, input_mask, ref_mask, prompt, a_prompt, n_prompt,
438
+ num_samples, ddim_steps, guess_mode, strength,
439
+ scale_s, scale_f, scale_t, seed, eta, dilation_fac=1,masking=True,whole_ref=False,inter=1,
440
+ free_form_obj_var=False, ignore_structure=False):
441
+
442
+ print(prompt)
443
+ if free_form_obj_var:
444
+ print("Free form")
445
+ structure, appearance_conc, appearance_ca, mask, img = self._edit_obj_var(input_mask, ignore_structure)
446
+ else:
447
+ structure, appearance_conc, appearance_ca, mask, img = self._edit(input_mask, ref_mask, dilation_fac=dilation_fac,
448
+ whole_ref=whole_ref, inter=inter)
449
+
450
+ input_pmask = torch.nn.functional.interpolate(self.input_pmask.cuda().unsqueeze(0).unsqueeze(1).float(), (self.H//8, self.W//8))
451
+ input_pmask = input_pmask.to(memory_format=torch.contiguous_format)
452
+
453
+
454
+ if isinstance(appearance_ca, list):
455
+ null_appearance_ca = [torch.zeros(a.shape).cuda() for a in appearance_ca]
456
+ null_appearance_conc = torch.zeros(appearance_conc.shape).cuda()
457
+ null_structure = torch.zeros(structure.shape).cuda() - 1
458
+
459
+ null_control = [torch.cat([null_structure, napp, input_pmask * 0], dim=1) for napp in null_appearance_ca]
460
+ structure_control = [torch.cat([structure, napp, input_pmask], dim=1) for napp in null_appearance_ca]
461
+ full_control = [torch.cat([structure, napp, input_pmask], dim=1) for napp in appearance_ca]
462
+
463
+ null_control.append(torch.cat([null_structure, null_appearance_conc, null_structure * 0], dim=1))
464
+ structure_control.append(torch.cat([structure, null_appearance_conc, null_structure], dim=1))
465
+ full_control.append(torch.cat([structure, appearance_conc, input_pmask], dim=1))
466
+
467
+ null_control = [torch.cat([nc for _ in range(num_samples)], dim=0) for nc in null_control]
468
+ structure_control = [torch.cat([sc for _ in range(num_samples)], dim=0) for sc in structure_control]
469
+ full_control = [torch.cat([fc for _ in range(num_samples)], dim=0) for fc in full_control]
470
+
471
+ #Masking for local edit
472
+ if not masking:
473
+ mask, x0 = None, None
474
+ else:
475
+ x0 = model.encode_first_stage(img)
476
+ x0 = x0.sample() if isinstance(x0, DiagonalGaussianDistribution) else x0 # todo: check if we can set random number
477
+ x0 = x0 * model.scale_factor
478
+ mask = 1 - torch.tensor(mask).unsqueeze(0).unsqueeze(1).cuda()
479
+ mask = torch.nn.functional.interpolate(mask.float(), x0.shape[2:]).float()
480
+
481
+ if seed == -1:
482
+ seed = random.randint(0, 65535)
483
+ seed_everything(seed)
484
+
485
+ scale = [scale_s, scale_f, scale_t]
486
+ print(scale)
487
+ if save_memory:
488
+ model.low_vram_shift(is_diffusing=False)
489
+
490
+ uc_cross = model.get_learned_conditioning([n_prompt] * num_samples)
491
+ c_cross = model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)
492
+ cond = {"c_concat": [null_control], "c_crossattn": [c_cross]}
493
+ un_cond = {"c_concat": None if guess_mode else [null_control], "c_crossattn": [uc_cross]}
494
+ un_cond_struct = {"c_concat": None if guess_mode else [structure_control], "c_crossattn": [uc_cross]}
495
+ un_cond_struct_app = {"c_concat": None if guess_mode else [full_control], "c_crossattn": [uc_cross]}
496
+
497
+ shape = (4, self.H // 8, self.W // 8)
498
+
499
+ if save_memory:
500
+ model.low_vram_shift(is_diffusing=True)
501
+
502
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
503
+ samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
504
+ shape, cond, verbose=False, eta=eta,
505
+ unconditional_guidance_scale=scale, mask=mask, x0=x0,
506
+ unconditional_conditioning=[un_cond, un_cond_struct, un_cond_struct_app ])
507
+
508
+ if save_memory:
509
+ model.low_vram_shift(is_diffusing=False)
510
+
511
+ x_samples = (model.decode_first_stage(samples) + 1) * 127.5
512
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')).cpu().numpy().clip(0, 255).astype(np.uint8)
513
+
514
+ results = [x_samples[i] for i in range(num_samples)]
515
+ self.results = results
516
+ return [] + results
requirements.txt CHANGED
@@ -9,6 +9,7 @@ omegaconf==2.3.0
9
  open-clip-torch==2.0.2
10
  opencv-contrib-python==4.3.0.36
11
  opencv-python-headless==4.7.0.72
 
12
  prettytable==3.6.0
13
  pytorch-lightning==1.5.0
14
  safetensors==0.2.7
@@ -44,4 +45,4 @@ diffdist
44
  gdown
45
  huggingface_hub
46
  tqdm
47
- wget
 
9
  open-clip-torch==2.0.2
10
  opencv-contrib-python==4.3.0.36
11
  opencv-python-headless==4.7.0.72
12
+ pillow==9.4.0
13
  prettytable==3.6.0
14
  pytorch-lightning==1.5.0
15
  safetensors==0.2.7
 
45
  gdown
46
  huggingface_hub
47
  tqdm
48
+ wget