imsuperkong commited on
Commit
d3bdeec
·
1 Parent(s): dc47947

Upload 6 files

Browse files
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ torchvision
3
+ timm==0.6.12
4
+ gradio==3.40.1
5
+ diffusers==0.17.1
6
+ numpy==1.20.3
7
+ wget
sd/core.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from diffusers import StableDiffusionPipeline
5
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
6
+ from typing import Any, Callable, Dict, List, Optional, Union
7
+ from sd.pnp_utils import register_time, register_attention_control_efficient_kv_w_mask, register_conv_control_efficient_w_mask
8
+ import torch.nn as nn
9
+ from sd.dift_sd import MyUNet2DConditionModel, OneStepSDPipeline
10
+ import ipdb
11
+ from tqdm import tqdm
12
+ from lib.midas import MiDas
13
+
14
+ class DDIMBackward(StableDiffusionPipeline):
15
+ def __init__(
16
+ self, vae, text_encoder, tokenizer, unet, scheduler,
17
+ safety_checker, feature_extractor,
18
+ requires_safety_checker: bool = True,
19
+ device='cuda', model_id='ckpt/stable-diffusion-2-1-base',depth_model='dpt_swin2_large_384'
20
+ ):
21
+ super().__init__(
22
+ vae, text_encoder, tokenizer, unet, scheduler,
23
+ safety_checker, feature_extractor, requires_safety_checker,
24
+ )
25
+
26
+ self.dift_unet = MyUNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float16 if 'cuda' in device else torch.float32)
27
+ self.onestep_pipe = OneStepSDPipeline.from_pretrained(model_id, unet=self.dift_unet, safety_checker=None, torch_dtype=torch.float16 if 'cuda' in device else torch.float32)
28
+ self.onestep_pipe = self.onestep_pipe.to(device)
29
+
30
+ if 'cuda' in device:
31
+ self.onestep_pipe.enable_attention_slicing()
32
+ self.onestep_pipe.enable_xformers_memory_efficient_attention()
33
+ self.ensemble_size = 4
34
+ self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
35
+
36
+ self.midas_model = MiDas(device,model_type=depth_model)
37
+
38
+ self.torch_dtype=torch.float16 if 'cuda' in device else torch.float32
39
+
40
+
41
+ @torch.no_grad()
42
+ def __call__(
43
+ self,
44
+ prompt: Union[str, List[str]] = None,
45
+ height: Optional[int] = None,
46
+ width: Optional[int] = None,
47
+ num_inference_steps: int = 50,
48
+ guidance_scale: float = 7.5,
49
+ negative_prompt: Optional[Union[str, List[str]]] = None,
50
+ num_images_per_prompt: Optional[int] = 1,
51
+ eta: float = 0.0,
52
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
53
+ latents: Optional[torch.FloatTensor] = None,
54
+ prompt_embeds: Optional[torch.FloatTensor] = None,
55
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
56
+ output_type: Optional[str] = "pil",
57
+ return_dict: bool = True,
58
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
59
+ callback_steps: int = 1,
60
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
61
+ t_start=None,
62
+ ):
63
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
64
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
65
+ self.check_inputs(
66
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
67
+ )
68
+
69
+ if prompt is not None and isinstance(prompt, str):
70
+ batch_size = 1
71
+ elif prompt is not None and isinstance(prompt, list):
72
+ batch_size = len(prompt)
73
+ else:
74
+ batch_size = prompt_embeds.shape[0]
75
+
76
+ device = self._execution_device
77
+ do_classifier_free_guidance = guidance_scale > 1.0
78
+ prompt_embeds = self._encode_prompt(
79
+ prompt,
80
+ device,
81
+ num_images_per_prompt,
82
+ do_classifier_free_guidance,
83
+ negative_prompt,
84
+ prompt_embeds=prompt_embeds,
85
+ negative_prompt_embeds=negative_prompt_embeds,
86
+ )
87
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
88
+ timesteps = self.scheduler.timesteps
89
+ num_channels_latents = self.unet.in_channels
90
+ latents = self.prepare_latents(
91
+ batch_size * num_images_per_prompt,
92
+ num_channels_latents,
93
+ height,
94
+ width,
95
+ prompt_embeds.dtype,
96
+ device,
97
+ generator,
98
+ latents,
99
+ )
100
+
101
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
102
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
103
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
104
+ for i, t in enumerate(timesteps):
105
+ if t_start and t >= t_start:
106
+ progress_bar.update()
107
+ continue
108
+
109
+ # expand the latents if we are doing classifier free guidance
110
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
111
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
112
+
113
+ # predict the noise residual
114
+ noise_pred = self.unet(
115
+ latent_model_input,
116
+ t,
117
+ encoder_hidden_states=prompt_embeds,
118
+ cross_attention_kwargs=cross_attention_kwargs,
119
+ ).sample
120
+
121
+ # perform guidance
122
+ if do_classifier_free_guidance:
123
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
124
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
125
+
126
+ # compute the previous noisy sample x_t -> x_t-1
127
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
128
+
129
+ # call the callback, if provided
130
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
131
+ progress_bar.update()
132
+ if callback is not None and i % callback_steps == 0:
133
+ callback(i, t, latents)
134
+
135
+ if output_type == "latent":
136
+ image = latents
137
+ has_nsfw_concept = None
138
+ elif output_type == "pil":
139
+ image = self.decode_latents(latents)
140
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
141
+ image = self.numpy_to_pil(image)
142
+ else:
143
+ image = self.decode_latents(latents)
144
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
145
+
146
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
147
+ self.final_offload_hook.offload()
148
+
149
+ if not return_dict:
150
+ return (image, has_nsfw_concept)
151
+
152
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
153
+
154
+ def denoise_w_injection(
155
+ self,
156
+ prompt: Union[str, List[str]] = None,
157
+ height: Optional[int] = None,
158
+ width: Optional[int] = None,
159
+ num_inference_steps: int = 50,
160
+ guidance_scale: float = 7.5,
161
+ negative_prompt: Optional[Union[str, List[str]]] = None,
162
+ num_images_per_prompt: Optional[int] = 1,
163
+ eta: float = 0.0,
164
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
165
+ latents: Optional[torch.FloatTensor] = None,
166
+ prompt_embeds: Optional[torch.FloatTensor] = None,
167
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
168
+ output_type: Optional[str] = "pil",
169
+ return_dict: bool = True,
170
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
171
+ callback_steps: int = 1,
172
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
173
+ t_start=None,
174
+ attn=0.8,
175
+ f=0.5,
176
+ latent_mask=None,
177
+ guidance_loss_scale=0,
178
+ cfg_decay=False,
179
+ cfg_norm=False,
180
+ lr=1.0,
181
+ up_ft_indexes=[1,2],
182
+ img_tensor=None,
183
+ early_stop=50,
184
+ intrinsic=None, extrinsic=None, threshold=20,depth=None,
185
+ ):
186
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
187
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
188
+ self.check_inputs(
189
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
190
+ )
191
+
192
+ if prompt is not None and isinstance(prompt, str):
193
+ batch_size = 1
194
+ elif prompt is not None and isinstance(prompt, list):
195
+ batch_size = len(prompt)
196
+ else:
197
+ batch_size = prompt_embeds.shape[0]
198
+
199
+ device = self._execution_device
200
+ do_classifier_free_guidance = guidance_scale > 1.0
201
+ prompt_embeds = self._encode_prompt(
202
+ prompt,
203
+ device,
204
+ num_images_per_prompt,
205
+ do_classifier_free_guidance,
206
+ negative_prompt,
207
+ prompt_embeds=prompt_embeds,
208
+ negative_prompt_embeds=negative_prompt_embeds,
209
+ )
210
+ if do_classifier_free_guidance:
211
+ prompt_embeds = torch.cat((prompt_embeds[1:], prompt_embeds[1:], prompt_embeds[:1]), dim=0)
212
+ else:
213
+ prompt_embeds = torch.cat([prompt_embeds]*2, dim=0)
214
+
215
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
216
+ timesteps = self.scheduler.timesteps
217
+ num_channels_latents = self.unet.in_channels
218
+ latents = self.prepare_latents(
219
+ batch_size * num_images_per_prompt,
220
+ num_channels_latents,
221
+ height,
222
+ width,
223
+ prompt_embeds.dtype,
224
+ device,
225
+ generator,
226
+ latents,
227
+ )
228
+
229
+ kv_injection_timesteps = self.scheduler.timesteps[:int(len(self.scheduler.timesteps) * attn)]
230
+ f_injection_timesteps = self.scheduler.timesteps[:int(len(self.scheduler.timesteps) * f)]
231
+ register_attention_control_efficient_kv_w_mask(self, kv_injection_timesteps, mask=latent_mask, do_classifier_free_guidance=do_classifier_free_guidance)
232
+ register_conv_control_efficient_w_mask(self, f_injection_timesteps, mask=latent_mask)
233
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
234
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
235
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
236
+ for i, t in enumerate(timesteps):
237
+ if t_start and t >= t_start:
238
+ progress_bar.update()
239
+ continue
240
+ if i > early_stop: guidance_loss_scale = 0 # Early stop (optional)
241
+ # if t > 300: guidance_loss_scale = 0 # Early stop (optional)
242
+ register_time(self, t.item())
243
+ # Set requires grad
244
+ if guidance_loss_scale != 0:
245
+ latents = latents.detach().requires_grad_()
246
+
247
+ # expand the latents if we are doing classifier free guidance
248
+ latent_model_input = latents # latents: ori_z + wrap_z
249
+ if do_classifier_free_guidance:
250
+ latent_model_input = torch.cat([latent_model_input, latent_model_input[1:]], dim=0)
251
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
252
+
253
+ # predict the noise residual
254
+ if guidance_loss_scale != 0:
255
+ with torch.no_grad():
256
+ noise_pred = self.unet(
257
+ latent_model_input,
258
+ t,
259
+ encoder_hidden_states=prompt_embeds,
260
+ cross_attention_kwargs=cross_attention_kwargs,
261
+ ).sample
262
+ else:
263
+ with torch.no_grad():
264
+ noise_pred = self.unet(
265
+ latent_model_input,
266
+ t,
267
+ encoder_hidden_states=prompt_embeds,
268
+ cross_attention_kwargs=cross_attention_kwargs,
269
+ ).sample
270
+
271
+ # perform guidance
272
+ if do_classifier_free_guidance:
273
+ cfg_scale = guidance_scale
274
+ if cfg_decay: cfg_scale = 1 + guidance_scale * (1-i/num_inference_steps)
275
+ noise_pred_text, wrap_noise_pred_text, wrap_noise_pred_uncond = noise_pred.chunk(3)
276
+ noise_pred = wrap_noise_pred_text + cfg_scale * (wrap_noise_pred_text - wrap_noise_pred_uncond)
277
+ else:
278
+ noise_pred_text, wrap_noise_pred_text = noise_pred.chunk(3)
279
+ noise_pred = wrap_noise_pred_text
280
+
281
+ if cfg_norm:
282
+ noise_pred = noise_pred * (torch.linalg.norm(wrap_noise_pred_uncond) / torch.linalg.norm(noise_pred))
283
+
284
+ if guidance_loss_scale != 0:
285
+ for up_ft_index in up_ft_indexes:
286
+
287
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
288
+ alpha_prod_t_prev = (
289
+ self.scheduler.alphas_cumprod[timesteps[i - 0]]
290
+ if i > 0 else self.scheduler.final_alpha_cumprod
291
+ )
292
+
293
+ mu = alpha_prod_t ** 0.5
294
+ mu_prev = alpha_prod_t_prev ** 0.5
295
+ sigma = (1 - alpha_prod_t) ** 0.5
296
+ sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
297
+
298
+ pred_x0 = (latents - sigma_prev * noise_pred[:latents.shape[0]]) / mu_prev
299
+
300
+ unet_ft_all = self.onestep_pipe(
301
+ latents=pred_x0[:1].repeat(self.ensemble_size, 1, 1, 1),
302
+ t=t,
303
+ up_ft_indices=[up_ft_index],
304
+ prompt_embeds=prompt_embeds[:1].repeat(self.ensemble_size, 1, 1)
305
+ )
306
+ unet_ft1 = unet_ft_all['up_ft'][up_ft_index].mean(0, keepdim=True) # 1,c,h,w
307
+ unet_ft1_norm = unet_ft1 / torch.norm(unet_ft1, dim=1, keepdim=True)
308
+
309
+ unet_ft1_norm = self.midas_model.wrap_img_tensor_w_fft_ext(
310
+ unet_ft1_norm.to(self.torch_dtype),
311
+ torch.from_numpy(depth).to(device).to(self.torch_dtype),
312
+ intrinsic,
313
+ extrinsic[:3,:3], extrinsic[:3,3], threshold=threshold).to(self.torch_dtype)
314
+
315
+ unet_ft_all = self.onestep_pipe(
316
+ latents=pred_x0[1:2].repeat(self.ensemble_size, 1, 1, 1),
317
+ t=t,
318
+ up_ft_indices=[up_ft_index],
319
+ prompt_embeds=prompt_embeds[:1].repeat(self.ensemble_size, 1, 1)
320
+ )
321
+ unet_ft2 = unet_ft_all['up_ft'][up_ft_index].mean(0, keepdim=True) # 1,c,h,w
322
+ unet_ft2_norm = unet_ft2 / torch.norm(unet_ft2, dim=1, keepdim=True)
323
+ c = unet_ft2.shape[1]
324
+ loss = (-self.cos(unet_ft1_norm.squeeze().view(c, -1).T, unet_ft2_norm.squeeze().view(c, -1).T).mean() + 1) / 2.
325
+ # Get gradient
326
+ cond_grad = torch.autograd.grad(loss * guidance_loss_scale, latents)[0][1:2]
327
+
328
+ # compute the previous noisy sample x_t -> x_t-1
329
+ noise_pred_ = noise_pred - sigma_prev * cond_grad*lr
330
+ noise_pred_ = torch.cat([noise_pred_text, noise_pred_], dim=0)
331
+
332
+ # compute the previous noisy sample x_t -> x_t-1
333
+ with torch.no_grad():
334
+ latents = self.scheduler.step(noise_pred_, t, latents, **extra_step_kwargs).prev_sample
335
+ # call the callback, if provided
336
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
337
+ progress_bar.update()
338
+ if callback is not None and i % callback_steps == 0:
339
+ callback(i, t, latents)
340
+
341
+ if output_type == "latent":
342
+ image = latents
343
+ has_nsfw_concept = None
344
+ elif output_type == "pil":
345
+ with torch.no_grad():
346
+ image = self.decode_latents(latents)
347
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
348
+ image = self.numpy_to_pil(image)
349
+ else:
350
+ image = self.decode_latents(latents)
351
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
352
+
353
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
354
+ self.final_offload_hook.offload()
355
+
356
+ if not return_dict:
357
+ return (image, has_nsfw_concept)
358
+
359
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
360
+ @torch.no_grad()
361
+ def decoder(self, latents):
362
+ with torch.autocast(device_type=self.device, dtype=torch.float32):
363
+ latents = 1 / 0.18215 * latents
364
+ imgs = self.vae.decode(latents).sample
365
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
366
+ return imgs
367
+
368
+
369
+ def ddim_inversion_w_grad(self, latent, cond, stop_t, guidance_loss_scale=1.0, lr=1.0):
370
+ timesteps = reversed(self.scheduler.timesteps)
371
+ with torch.autocast(device_type=self.device, dtype=torch.float32):
372
+
373
+ for i, t in enumerate(tqdm(timesteps)):
374
+ if t >= stop_t:
375
+ break
376
+
377
+ if guidance_loss_scale != 0:
378
+ latent = latent.detach().requires_grad_()
379
+ cond_batch = cond.repeat(latent.shape[0], 1, 1)
380
+
381
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
382
+ alpha_prod_t_prev = (
383
+ self.scheduler.alphas_cumprod[timesteps[i - 1]]
384
+ if i > 0 else self.scheduler.final_alpha_cumprod
385
+ )
386
+
387
+ mu = alpha_prod_t ** 0.5
388
+ mu_prev = alpha_prod_t_prev ** 0.5
389
+ sigma = (1 - alpha_prod_t) ** 0.5
390
+ sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
391
+
392
+ eps = self.onestep_pipe.unet(latent, t, encoder_hidden_states=cond_batch, up_ft_indices=[3], output_eps=True)['eps']
393
+ pred_x0 = (latent - sigma_prev * eps) / mu_prev
394
+
395
+ unet_ft_all = self.onestep_pipe(
396
+ latents=pred_x0[:1].repeat(self.ensemble_size, 1, 1, 1),
397
+ t=t,
398
+ up_ft_indices=[1],
399
+ prompt_embeds=cond_batch[:1].repeat(self.ensemble_size, 1, 1)
400
+ )
401
+ unet_ft1 = unet_ft_all['up_ft'][1].mean(0, keepdim=True) # 1,c,h,w
402
+ unet_ft1_norm = unet_ft1 / torch.norm(unet_ft1, dim=1, keepdim=True)
403
+
404
+ unet_ft_all = self.onestep_pipe(
405
+ latents=pred_x0[1:2].repeat(self.ensemble_size, 1, 1, 1),
406
+ t=t,
407
+ up_ft_indices=[1],
408
+ prompt_embeds=cond_batch[:1].repeat(self.ensemble_size, 1, 1)
409
+ )
410
+ unet_ft2 = unet_ft_all['up_ft'][1].mean(0, keepdim=True) # 1,c,h,w
411
+ unet_ft2_norm = unet_ft2 / torch.norm(unet_ft2, dim=1, keepdim=True)
412
+ c = unet_ft2.shape[1]
413
+ loss = (-self.cos(unet_ft1_norm.squeeze().view(c, -1).T.detach(), unet_ft2_norm.squeeze().view(c, -1).T).mean() + 1) / 2.
414
+ print(f'loss: {loss.item()}')
415
+ # Get gradient
416
+ cond_grad = torch.autograd.grad(loss * guidance_loss_scale, latent)[0]
417
+
418
+ # latent = latent.detach() - cond_grad * lr
419
+ latent = mu * pred_x0 + sigma * eps - cond_grad * lr
420
+
421
+ return latent
422
+
423
+ @torch.no_grad()
424
+ def DDPM_forward(x_t_dot, t_start, delta_t, ddpm_scheduler, generator):
425
+ # just simple implementation, this should have an analytical expression
426
+ # TODO: implementation analytical form
427
+ for delta in range(1, delta_t):
428
+ # noise = torch.randn_like(x_t_dot, generator=generator)
429
+ noise = torch.empty_like(x_t_dot).normal_(generator=generator)
430
+
431
+ beta = ddpm_scheduler.betas[t_start+delta]
432
+ std_ = beta ** 0.5
433
+ mu_ = ((1 - beta) ** 0.5) * x_t_dot
434
+ x_t_dot = mu_ + std_ * noise
435
+ return x_t_dot
sd/dift_sd.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import StableDiffusionPipeline
2
+ import torch
3
+ import torch.nn as nn
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from typing import Any, Callable, Dict, List, Optional, Union
7
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
8
+ from diffusers import DDIMScheduler
9
+ import gc
10
+ from PIL import Image
11
+
12
+ class MyUNet2DConditionModel(UNet2DConditionModel):
13
+ def forward(
14
+ self,
15
+ sample: torch.FloatTensor,
16
+ timestep: Union[torch.Tensor, float, int],
17
+ up_ft_indices,
18
+ encoder_hidden_states: torch.Tensor,
19
+ class_labels: Optional[torch.Tensor] = None,
20
+ timestep_cond: Optional[torch.Tensor] = None,
21
+ attention_mask: Optional[torch.Tensor] = None,
22
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
23
+ output_eps=False):
24
+ r"""
25
+ Args:
26
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
27
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
28
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
29
+ cross_attention_kwargs (`dict`, *optional*):
30
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
31
+ `self.processor` in
32
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
33
+ """
34
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
35
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
36
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
37
+ # on the fly if necessary.
38
+ default_overall_up_factor = 2**self.num_upsamplers
39
+
40
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
41
+ forward_upsample_size = False
42
+ upsample_size = None
43
+
44
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
45
+ # logger.info("Forward upsample size to force interpolation output size.")
46
+ forward_upsample_size = True
47
+
48
+ # prepare attention_mask
49
+ if attention_mask is not None:
50
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
51
+ attention_mask = attention_mask.unsqueeze(1)
52
+
53
+ # 0. center input if necessary
54
+ if self.config.center_input_sample:
55
+ sample = 2 * sample - 1.0
56
+
57
+ # 1. time
58
+ timesteps = timestep
59
+ if not torch.is_tensor(timesteps):
60
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
61
+ # This would be a good case for the `match` statement (Python 3.10+)
62
+ is_mps = sample.device.type == "mps"
63
+ if isinstance(timestep, float):
64
+ dtype = torch.float32 if is_mps else torch.float64
65
+ else:
66
+ dtype = torch.int32 if is_mps else torch.int64
67
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
68
+ elif len(timesteps.shape) == 0:
69
+ timesteps = timesteps[None].to(sample.device)
70
+
71
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
72
+ timesteps = timesteps.expand(sample.shape[0])
73
+
74
+ t_emb = self.time_proj(timesteps)
75
+
76
+ # timesteps does not contain any weights and will always return f32 tensors
77
+ # but time_embedding might actually be running in fp16. so we need to cast here.
78
+ # there might be better ways to encapsulate this.
79
+ t_emb = t_emb.to(dtype=self.dtype)
80
+
81
+ emb = self.time_embedding(t_emb, timestep_cond)
82
+
83
+ if self.class_embedding is not None:
84
+ if class_labels is None:
85
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
86
+
87
+ if self.config.class_embed_type == "timestep":
88
+ class_labels = self.time_proj(class_labels)
89
+
90
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
91
+ emb = emb + class_emb
92
+
93
+ # 2. pre-process
94
+ sample = self.conv_in(sample)
95
+
96
+ # 3. down
97
+ down_block_res_samples = (sample,)
98
+ for downsample_block in self.down_blocks:
99
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
100
+ sample, res_samples = downsample_block(
101
+ hidden_states=sample,
102
+ temb=emb,
103
+ encoder_hidden_states=encoder_hidden_states,
104
+ attention_mask=attention_mask,
105
+ cross_attention_kwargs=cross_attention_kwargs,
106
+ )
107
+ else:
108
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
109
+
110
+ down_block_res_samples += res_samples
111
+
112
+ # 4. mid
113
+ if self.mid_block is not None:
114
+ sample = self.mid_block(
115
+ sample,
116
+ emb,
117
+ encoder_hidden_states=encoder_hidden_states,
118
+ attention_mask=attention_mask,
119
+ cross_attention_kwargs=cross_attention_kwargs,
120
+ )
121
+
122
+ # 5. up
123
+ up_ft = {}
124
+ for i, upsample_block in enumerate(self.up_blocks):
125
+
126
+ if i > np.max(up_ft_indices):
127
+ break
128
+
129
+ is_final_block = i == len(self.up_blocks) - 1
130
+
131
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
132
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
133
+
134
+ # if we have not reached the final block and need to forward the
135
+ # upsample size, we do it here
136
+ if not is_final_block and forward_upsample_size:
137
+ upsample_size = down_block_res_samples[-1].shape[2:]
138
+
139
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
140
+ sample = upsample_block(
141
+ hidden_states=sample,
142
+ temb=emb,
143
+ res_hidden_states_tuple=res_samples,
144
+ encoder_hidden_states=encoder_hidden_states,
145
+ cross_attention_kwargs=cross_attention_kwargs,
146
+ upsample_size=upsample_size,
147
+ attention_mask=attention_mask,
148
+ )
149
+ else:
150
+ sample = upsample_block(
151
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
152
+ )
153
+
154
+ if i in up_ft_indices:
155
+ up_ft[i] = sample
156
+
157
+ output = {}
158
+ output['up_ft'] = up_ft
159
+ if output_eps:
160
+ sample = self.conv_norm_out(sample)
161
+ sample = self.conv_act(sample)
162
+ sample = self.conv_out(sample)
163
+ output['eps'] = sample
164
+ return output
165
+
166
+ class OneStepSDPipeline(StableDiffusionPipeline):
167
+ # @torch.no_grad()
168
+ def __call__(
169
+ self,
170
+
171
+ t,
172
+ up_ft_indices,
173
+ negative_prompt: Optional[Union[str, List[str]]] = None,
174
+ img_tensor=None,
175
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
176
+ prompt_embeds: Optional[torch.FloatTensor] = None,
177
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
178
+ callback_steps: int = 1,
179
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
180
+ latents=None
181
+ ):
182
+
183
+ device = self._execution_device
184
+ if latents is None:
185
+ latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor
186
+ t = torch.tensor(t.clone().detach(), dtype=torch.long, device=device)
187
+ noise = torch.randn_like(latents).to(device)
188
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
189
+ unet_output = self.unet(latents_noisy,
190
+ t,
191
+ up_ft_indices,
192
+ encoder_hidden_states=prompt_embeds,
193
+ cross_attention_kwargs=cross_attention_kwargs)
194
+ return unet_output
195
+
196
+
197
+ class SDFeaturizer:
198
+ def __init__(self, sd_id='ckpt/stable-diffusion-2-1-base'):
199
+ unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet")
200
+ onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None)
201
+ onestep_pipe.vae.decoder = None
202
+ onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler")
203
+ gc.collect()
204
+ onestep_pipe = onestep_pipe.to("cuda")
205
+ onestep_pipe.enable_attention_slicing()
206
+ onestep_pipe.enable_xformers_memory_efficient_attention()
207
+ self.pipe = onestep_pipe
208
+
209
+ @torch.no_grad()
210
+ def forward(self,
211
+ img_tensor,
212
+ prompt,
213
+ t=261,
214
+ up_ft_index=1,
215
+ ensemble_size=8):
216
+ '''
217
+ Args:
218
+ img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W]
219
+ prompt: the prompt to use, a string
220
+ t: the time step to use, should be an int in the range of [0, 1000]
221
+ up_ft_index: which upsampling block of the U-Net to extract feature, you can choose [0, 1, 2, 3]
222
+ ensemble_size: the number of repeated images used in the batch to extract features
223
+ Return:
224
+ unet_ft: a torch tensor in the shape of [1, c, h, w]
225
+ '''
226
+ img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
227
+ prompt_embeds = self.pipe._encode_prompt(
228
+ prompt=prompt,
229
+ device='cuda',
230
+ num_images_per_prompt=1,
231
+ do_classifier_free_guidance=False) # [1, 77, dim]
232
+ prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1)
233
+ unet_ft_all = self.pipe(
234
+ img_tensor=img_tensor,
235
+ t=t,
236
+ up_ft_indices=[up_ft_index],
237
+ prompt_embeds=prompt_embeds)
238
+ unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w
239
+ unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w
240
+ return unet_ft
sd/gradio_utils.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import copy
3
+ import math
4
+ import os
5
+ import urllib.request
6
+ from typing import List, Optional, Tuple
7
+
8
+ import numpy as np
9
+ import PIL
10
+ import PIL.Image
11
+ import PIL.ImageDraw
12
+ import torch
13
+ import torch.optim
14
+ from tqdm import tqdm
15
+ import ipdb
16
+
17
+ def tensor_to_PIL(img: torch.Tensor) -> PIL.Image.Image:
18
+ """
19
+ Converts a tensor image to a PIL Image.
20
+
21
+ Args:
22
+ img (torch.Tensor): The tensor image of shape [batch_size, num_channels, height, width].
23
+
24
+ Returns:
25
+ A PIL Image object.
26
+ """
27
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
28
+ return PIL.Image.fromarray(img[0].cpu().numpy(), "RGB")
29
+
30
+
31
+ def get_ellipse_coords(
32
+ point: Tuple[int, int], radius: int = 5
33
+ ) -> Tuple[int, int, int, int]:
34
+ """
35
+ Returns the coordinates of an ellipse centered at the given point.
36
+
37
+ Args:
38
+ point (Tuple[int, int]): The center point of the ellipse.
39
+ radius (int): The radius of the ellipse.
40
+
41
+ Returns:
42
+ A tuple containing the coordinates of the ellipse in the format (x_min, y_min, x_max, y_max).
43
+ """
44
+ center = point
45
+ return (
46
+ center[0] - radius,
47
+ center[1] - radius,
48
+ center[0] + radius,
49
+ center[1] + radius,
50
+ )
51
+
52
+
53
+
54
+ def draw_handle_target_points(
55
+ img: PIL.Image.Image,
56
+ # handle_points: List[Tuple[int, int]],
57
+ target_points: List[Tuple[int, int]],
58
+ radius: int = 5):
59
+ """
60
+ Draws handle and target points with arrow pointing towards the target point.
61
+
62
+ Args:
63
+ img (PIL.Image.Image): The image to draw on.
64
+ handle_points (List[Tuple[int, int]]): A list of handle [x,y] points.
65
+ target_points (List[Tuple[int, int]]): A list of target [x,y] points.
66
+ radius (int): The radius of the handle and target points.
67
+ """
68
+ if not isinstance(img, PIL.Image.Image):
69
+ img = PIL.Image.fromarray(img)
70
+
71
+ # if len(handle_points) == len(target_points) + 1:
72
+ # target_points = copy.deepcopy(target_points) + [None]
73
+
74
+ draw = PIL.ImageDraw.Draw(img)
75
+ for handle_point, target_point in zip(target_points, target_points):
76
+ # handle_point = [handle_point[1], handle_point[0]]
77
+ # Draw the handle point
78
+ # ipdb.set_trace()
79
+
80
+ target_coords = get_ellipse_coords(target_point, radius)
81
+ draw.ellipse((target_coords), fill="red")
82
+
83
+ return np.array(img)
84
+
85
+
sd/pnp_utils.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import random
4
+ import numpy as np
5
+ import ipdb
6
+ import torch.nn.functional as F
7
+
8
+ def seed_everything(seed):
9
+ torch.manual_seed(seed)
10
+ torch.cuda.manual_seed(seed)
11
+ random.seed(seed)
12
+ np.random.seed(seed)
13
+
14
+ def register_time(model, t):
15
+ conv_module = model.unet.up_blocks[1].resnets[1]
16
+ setattr(conv_module, 't', t)
17
+ down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1]}
18
+ up_res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
19
+ for res in up_res_dict:
20
+ for block in up_res_dict[res]:
21
+ module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
22
+ setattr(module, 't', t)
23
+ for res in down_res_dict:
24
+ for block in down_res_dict[res]:
25
+ module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1
26
+ setattr(module, 't', t)
27
+ module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn1
28
+ setattr(module, 't', t)
29
+
30
+
31
+ def load_source_latents_t(t, latents_path):
32
+ latents_t_path = os.path.join(latents_path, f'noisy_latents_{t}.pt')
33
+ assert os.path.exists(latents_t_path), f'Missing latents at t {t} path {latents_t_path}'
34
+ latents = torch.load(latents_t_path)
35
+ return latents
36
+
37
+ def register_attention_control_efficient(model, injection_schedule):
38
+ def sa_forward(self):
39
+ to_out = self.to_out
40
+ if type(to_out) is torch.nn.modules.container.ModuleList:
41
+ to_out = self.to_out[0]
42
+ else:
43
+ to_out = self.to_out
44
+
45
+ def forward(x, encoder_hidden_states=None, attention_mask=None):
46
+ batch_size, sequence_length, dim = x.shape
47
+ h = self.heads
48
+
49
+ is_cross = encoder_hidden_states is not None
50
+ encoder_hidden_states = encoder_hidden_states if is_cross else x
51
+ if not is_cross and self.injection_schedule is not None and (
52
+ self.t in self.injection_schedule or self.t == 1000):
53
+ q = self.to_q(x)
54
+ k = self.to_k(encoder_hidden_states)
55
+
56
+ source_batch_size = int(q.shape[0] // 3)
57
+ # inject unconditional
58
+ q[source_batch_size:2 * source_batch_size] = q[:source_batch_size]
59
+ k[source_batch_size:2 * source_batch_size] = k[:source_batch_size]
60
+ # inject conditional
61
+ q[2 * source_batch_size:] = q[:source_batch_size]
62
+ k[2 * source_batch_size:] = k[:source_batch_size]
63
+
64
+ q = self.head_to_batch_dim(q)
65
+ k = self.head_to_batch_dim(k)
66
+ else:
67
+ q = self.to_q(x)
68
+ k = self.to_k(encoder_hidden_states)
69
+ q = self.head_to_batch_dim(q)
70
+ k = self.head_to_batch_dim(k)
71
+
72
+ v = self.to_v(encoder_hidden_states)
73
+ v = self.head_to_batch_dim(v)
74
+
75
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
76
+
77
+ if attention_mask is not None:
78
+ attention_mask = attention_mask.reshape(batch_size, -1)
79
+ max_neg_value = -torch.finfo(sim.dtype).max
80
+ attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
81
+ sim.masked_fill_(~attention_mask, max_neg_value)
82
+
83
+ # attention, what we cannot get enough of
84
+ attn = sim.softmax(dim=-1)
85
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
86
+ out = self.batch_to_head_dim(out)
87
+
88
+ return to_out(out)
89
+
90
+ return forward
91
+ res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
92
+ for res in res_dict:
93
+ for block in res_dict[res]:
94
+ module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
95
+ module.forward = sa_forward(module)
96
+ setattr(module, 'injection_schedule', injection_schedule)
97
+
98
+ def register_attention_control_efficient_kv(model, injection_schedule):
99
+ def sa_forward(self):
100
+ to_out = self.to_out
101
+ if type(to_out) is torch.nn.modules.container.ModuleList:
102
+ to_out = self.to_out[0]
103
+ else:
104
+ to_out = self.to_out
105
+
106
+ def forward(x, encoder_hidden_states=None, attention_mask=None):
107
+ batch_size, sequence_length, dim = x.shape
108
+ h = self.heads
109
+ # if encoder_hidden_states is None:
110
+ # ipdb.set_trace()
111
+
112
+ is_cross = encoder_hidden_states is not None
113
+ encoder_hidden_states = encoder_hidden_states if is_cross else x
114
+
115
+ q = self.to_q(x)
116
+ q = self.head_to_batch_dim(q)
117
+
118
+ if not is_cross and self.injection_schedule is not None and (
119
+ self.t in self.injection_schedule or self.t == 1000):
120
+ # q = self.to_q(x)
121
+ k = self.to_k(encoder_hidden_states)
122
+ v = self.to_v(encoder_hidden_states)
123
+
124
+ source_batch_size = int(v.shape[0] // 3)
125
+ # inject unconditional
126
+ k[source_batch_size:2 * source_batch_size] = k[:source_batch_size]
127
+ v[source_batch_size:2 * source_batch_size] = v[:source_batch_size]
128
+
129
+ # inject conditional
130
+ k[2 * source_batch_size:] = k[:source_batch_size]
131
+ v[2 * source_batch_size:] = v[:source_batch_size]
132
+
133
+ # q = self.head_to_batch_dim(q)
134
+ k = self.head_to_batch_dim(k)
135
+ v = self.head_to_batch_dim(v)
136
+ else:
137
+ # q = self.to_q(x)
138
+ k = self.to_k(encoder_hidden_states)
139
+ # q = self.head_to_batch_dim(q)
140
+ k = self.head_to_batch_dim(k)
141
+
142
+ v = self.to_v(encoder_hidden_states)
143
+ v = self.head_to_batch_dim(v)
144
+
145
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
146
+
147
+ if attention_mask is not None:
148
+ attention_mask = attention_mask.reshape(batch_size, -1)
149
+ max_neg_value = -torch.finfo(sim.dtype).max
150
+ attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
151
+ sim.masked_fill_(~attention_mask, max_neg_value)
152
+
153
+ # attention, what we cannot get enough of
154
+ attn = sim.softmax(dim=-1)
155
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
156
+ out = self.batch_to_head_dim(out)
157
+
158
+ return to_out(out)
159
+
160
+ return forward
161
+
162
+ res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
163
+ for res in res_dict:
164
+ for block in res_dict[res]:
165
+ module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
166
+ module.forward = sa_forward(module)
167
+ setattr(module, 'injection_schedule', injection_schedule)
168
+
169
+
170
+ def register_conv_control_efficient(model, injection_schedule):
171
+ def conv_forward(self):
172
+ def forward(input_tensor, temb):
173
+ hidden_states = input_tensor
174
+
175
+ hidden_states = self.norm1(hidden_states)
176
+ hidden_states = self.nonlinearity(hidden_states)
177
+
178
+ if self.upsample is not None:
179
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
180
+ if hidden_states.shape[0] >= 64:
181
+ input_tensor = input_tensor.contiguous()
182
+ hidden_states = hidden_states.contiguous()
183
+ input_tensor = self.upsample(input_tensor)
184
+ hidden_states = self.upsample(hidden_states)
185
+ elif self.downsample is not None:
186
+ input_tensor = self.downsample(input_tensor)
187
+ hidden_states = self.downsample(hidden_states)
188
+
189
+ hidden_states = self.conv1(hidden_states)
190
+
191
+ if temb is not None:
192
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
193
+
194
+ if temb is not None and self.time_embedding_norm == "default":
195
+ hidden_states = hidden_states + temb
196
+
197
+ hidden_states = self.norm2(hidden_states)
198
+
199
+ if temb is not None and self.time_embedding_norm == "scale_shift":
200
+ scale, shift = torch.chunk(temb, 2, dim=1)
201
+ hidden_states = hidden_states * (1 + scale) + shift
202
+
203
+ hidden_states = self.nonlinearity(hidden_states)
204
+
205
+ hidden_states = self.dropout(hidden_states)
206
+ hidden_states = self.conv2(hidden_states)
207
+ if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
208
+ source_batch_size = int(hidden_states.shape[0] // 3)
209
+ # inject unconditional
210
+ hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size]
211
+ # inject conditional
212
+ hidden_states[2 * source_batch_size:] = hidden_states[:source_batch_size]
213
+
214
+ if self.conv_shortcut is not None:
215
+ input_tensor = self.conv_shortcut(input_tensor)
216
+
217
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
218
+
219
+ return output_tensor
220
+
221
+ return forward
222
+
223
+ conv_module = model.unet.up_blocks[1].resnets[1]
224
+ conv_module.forward = conv_forward(conv_module)
225
+ setattr(conv_module, 'injection_schedule', injection_schedule)
226
+
227
+
228
+ def register_attention_control_efficient_kv_2nd_to_1st(model, injection_schedule, mask=None):
229
+ def sa_forward(self):
230
+ to_out = self.to_out
231
+ if type(to_out) is torch.nn.modules.container.ModuleList:
232
+ to_out = self.to_out[0]
233
+ else:
234
+ to_out = self.to_out
235
+
236
+ def forward(x, mask=mask, encoder_hidden_states=None, attention_mask=None):
237
+ batch_size, sequence_length, dim = x.shape
238
+ h = self.heads
239
+ # if encoder_hidden_states is None:
240
+ # ipdb.set_trace()
241
+ is_cross = encoder_hidden_states is not None
242
+ encoder_hidden_states = encoder_hidden_states if is_cross else x
243
+
244
+ q = self.to_q(x)
245
+ q = self.head_to_batch_dim(q)
246
+
247
+ if not is_cross and self.injection_schedule is not None and (
248
+ self.t in self.injection_schedule or self.t == 1000):
249
+ # q = self.to_q(x)
250
+ target_size = int(np.sqrt(encoder_hidden_states.shape[1]))
251
+ target_mask = F.interpolate(mask.unsqueeze(1),size=(target_size, target_size))[:,0,:,:]
252
+ target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1)
253
+ k = self.to_k(encoder_hidden_states) # k: bx256x1280
254
+ v = self.to_v(encoder_hidden_states)
255
+
256
+ source_batch_size = int(v.shape[0] // 2)
257
+ # inject
258
+ k[:source_batch_size] = k[source_batch_size:2 * source_batch_size] * (1-target_mask) + k[:source_batch_size] * target_mask
259
+ v[:source_batch_size] = v[source_batch_size:2 * source_batch_size] * (1-target_mask) + v[:source_batch_size] * target_mask
260
+
261
+ # q = self.head_to_batch_dim(q)
262
+ k = self.head_to_batch_dim(k)
263
+ v = self.head_to_batch_dim(v)
264
+ else:
265
+ # q = self.to_q(x)
266
+ k = self.to_k(encoder_hidden_states)
267
+ # q = self.head_to_batch_dim(q)
268
+ k = self.head_to_batch_dim(k)
269
+
270
+ v = self.to_v(encoder_hidden_states)
271
+ v = self.head_to_batch_dim(v)
272
+
273
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
274
+
275
+ if attention_mask is not None:
276
+ attention_mask = attention_mask.reshape(batch_size, -1)
277
+ max_neg_value = -torch.finfo(sim.dtype).max
278
+ attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
279
+ sim.masked_fill_(~attention_mask, max_neg_value)
280
+
281
+ # attention, what we cannot get enough of
282
+ attn = sim.softmax(dim=-1)
283
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
284
+ out = self.batch_to_head_dim(out)
285
+
286
+ return to_out(out)
287
+
288
+ return forward
289
+
290
+ # res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
291
+ res_dict = {1: [1, 2], 2: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
292
+
293
+ for res in res_dict:
294
+ for block in res_dict[res]:
295
+ module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
296
+ module.forward = sa_forward(module)
297
+ setattr(module, 'injection_schedule', injection_schedule)
298
+
299
+ def register_conv_control_efficient_2nd_to_1st(model, injection_schedule, mask=None):
300
+ def conv_forward(self):
301
+ def forward(input_tensor, temb):
302
+ hidden_states = input_tensor
303
+
304
+ hidden_states = self.norm1(hidden_states)
305
+ hidden_states = self.nonlinearity(hidden_states)
306
+
307
+ if self.upsample is not None:
308
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
309
+ if hidden_states.shape[0] >= 64:
310
+ input_tensor = input_tensor.contiguous()
311
+ hidden_states = hidden_states.contiguous()
312
+ input_tensor = self.upsample(input_tensor)
313
+ hidden_states = self.upsample(hidden_states)
314
+ elif self.downsample is not None:
315
+ input_tensor = self.downsample(input_tensor)
316
+ hidden_states = self.downsample(hidden_states)
317
+
318
+ hidden_states = self.conv1(hidden_states)
319
+
320
+ if temb is not None:
321
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
322
+
323
+ if temb is not None and self.time_embedding_norm == "default":
324
+ hidden_states = hidden_states + temb
325
+
326
+ hidden_states = self.norm2(hidden_states)
327
+
328
+ if temb is not None and self.time_embedding_norm == "scale_shift":
329
+ scale, shift = torch.chunk(temb, 2, dim=1)
330
+ hidden_states = hidden_states * (1 + scale) + shift
331
+
332
+ hidden_states = self.nonlinearity(hidden_states)
333
+
334
+ hidden_states = self.dropout(hidden_states)
335
+ hidden_states = self.conv2(hidden_states)
336
+ if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
337
+ source_batch_size = int(hidden_states.shape[0] // 2)
338
+ # inject unconditional
339
+ # hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size]
340
+ # inject conditional
341
+ target_size = int(np.sqrt(hidden_states.shape[-1]))
342
+ target_mask = F.interpolate(mask.unsqueeze(1),size=(target_size, target_size))[:,0,:,:]
343
+ target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1)
344
+
345
+ hidden_states[:source_batch_size] = hidden_states[source_batch_size:] * (1-target_mask) + hidden_states[:source_batch_size] * target_mask
346
+
347
+ if self.conv_shortcut is not None:
348
+ input_tensor = self.conv_shortcut(input_tensor)
349
+
350
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
351
+
352
+ return output_tensor
353
+
354
+ return forward
355
+
356
+ conv_module = model.unet.up_blocks[1].resnets[1]
357
+ conv_module.forward = conv_forward(conv_module)
358
+ setattr(conv_module, 'injection_schedule', injection_schedule)
359
+
360
+
361
+ def register_attention_control_efficient_qk_w_mask(model, injection_schedule, mask):
362
+ def sa_forward(self):
363
+ to_out = self.to_out
364
+ if type(to_out) is torch.nn.modules.container.ModuleList:
365
+ to_out = self.to_out[0]
366
+ else:
367
+ to_out = self.to_out
368
+
369
+ def forward(x, encoder_hidden_states=None, attention_mask=None):
370
+ batch_size, sequence_length, dim = x.shape
371
+ h = self.heads
372
+
373
+ is_cross = encoder_hidden_states is not None
374
+ encoder_hidden_states = encoder_hidden_states if is_cross else x
375
+ if not is_cross and self.injection_schedule is not None and (
376
+ self.t in self.injection_schedule or self.t == 1000):
377
+ q = self.to_q(x)
378
+ k = self.to_k(encoder_hidden_states)
379
+
380
+ target_size = int(np.sqrt(encoder_hidden_states.shape[1]))
381
+ target_mask = F.interpolate(mask.unsqueeze(1),size=(target_size, target_size))[:,0,:,:]
382
+ target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1)
383
+
384
+ source_batch_size = int(q.shape[0] // 3)
385
+ # inject unconditional
386
+ q[source_batch_size:2 * source_batch_size] = q[:source_batch_size] * target_mask + q[source_batch_size:2 * source_batch_size] * (1 - target_mask)
387
+ k[source_batch_size:2 * source_batch_size] = k[:source_batch_size] * target_mask + k[source_batch_size:2 * source_batch_size] * (1 - target_mask)
388
+ # inject conditional
389
+ q[2 * source_batch_size:] = q[:source_batch_size] * target_mask + q[2 * source_batch_size:] * (1 - target_mask)
390
+ k[2 * source_batch_size:] = k[:source_batch_size] * target_mask + k[2 * source_batch_size:] * (1 - target_mask)
391
+
392
+ q = self.head_to_batch_dim(q)
393
+ k = self.head_to_batch_dim(k)
394
+ else:
395
+ q = self.to_q(x)
396
+ k = self.to_k(encoder_hidden_states)
397
+ q = self.head_to_batch_dim(q)
398
+ k = self.head_to_batch_dim(k)
399
+
400
+ v = self.to_v(encoder_hidden_states)
401
+ v = self.head_to_batch_dim(v)
402
+
403
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
404
+
405
+ if attention_mask is not None:
406
+ attention_mask = attention_mask.reshape(batch_size, -1)
407
+ max_neg_value = -torch.finfo(sim.dtype).max
408
+ attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
409
+ sim.masked_fill_(~attention_mask, max_neg_value)
410
+
411
+ # attention, what we cannot get enough of
412
+ attn = sim.softmax(dim=-1)
413
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
414
+ out = self.batch_to_head_dim(out)
415
+
416
+ return to_out(out)
417
+
418
+ return forward
419
+ res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
420
+
421
+ for res in res_dict:
422
+ for block in res_dict[res]:
423
+ module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
424
+ module.forward = sa_forward(module)
425
+ setattr(module, 'injection_schedule', injection_schedule)
426
+
427
+ def register_attention_control_efficient_kv_w_mask(model, injection_schedule, mask, do_classifier_free_guidance):
428
+ def sa_forward(self):
429
+ to_out = self.to_out
430
+ if type(to_out) is torch.nn.modules.container.ModuleList:
431
+ to_out = self.to_out[0]
432
+ else:
433
+ to_out = self.to_out
434
+
435
+ def forward(x, encoder_hidden_states=None, attention_mask=None):
436
+ batch_size, sequence_length, dim = x.shape
437
+ h = self.heads
438
+
439
+ is_cross = encoder_hidden_states is not None
440
+ encoder_hidden_states = encoder_hidden_states if is_cross else x
441
+
442
+ q = self.to_q(x)
443
+ q = self.head_to_batch_dim(q)
444
+
445
+ if not is_cross and self.injection_schedule is not None and (
446
+ self.t in self.injection_schedule or self.t == 1000):
447
+ # if False:
448
+ k = self.to_k(encoder_hidden_states) # k: bx256x1280
449
+ v = self.to_v(encoder_hidden_states)
450
+
451
+ target_size = int(np.sqrt(encoder_hidden_states.shape[1]))
452
+ target_mask = F.interpolate(mask.unsqueeze(1),size=(target_size, target_size))[:,0,:,:]
453
+ target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1)
454
+
455
+ source_batch_size = int(v.shape[0] // 3)
456
+ if do_classifier_free_guidance:
457
+ # inject unconditional
458
+ v[source_batch_size:2 * source_batch_size] = v[:source_batch_size] * target_mask + v[source_batch_size:2 * source_batch_size] * (1 - target_mask)
459
+ k[source_batch_size:2 * source_batch_size] = k[:source_batch_size] * target_mask + k[source_batch_size:2 * source_batch_size] * (1 - target_mask)
460
+ # inject conditional
461
+ v[2 * source_batch_size:] = v[:source_batch_size] * target_mask + v[2 * source_batch_size:] * (1 - target_mask)
462
+ k[2 * source_batch_size:] = k[:source_batch_size] * target_mask + k[2 * source_batch_size:] * (1 - target_mask)
463
+ else:
464
+ v[source_batch_size:2 * source_batch_size] = v[:source_batch_size] * target_mask + v[source_batch_size:2 * source_batch_size] * (1 - target_mask)
465
+ k[source_batch_size:2 * source_batch_size] = k[:source_batch_size] * target_mask + k[source_batch_size:2 * source_batch_size] * (1 - target_mask)
466
+
467
+ k = self.head_to_batch_dim(k)
468
+ v = self.head_to_batch_dim(v)
469
+ else:
470
+ # q = self.to_q(x)
471
+ k = self.to_k(encoder_hidden_states)
472
+ # q = self.head_to_batch_dim(q)
473
+ k = self.head_to_batch_dim(k)
474
+
475
+ v = self.to_v(encoder_hidden_states)
476
+ v = self.head_to_batch_dim(v)
477
+
478
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
479
+
480
+ if attention_mask is not None:
481
+ attention_mask = attention_mask.reshape(batch_size, -1)
482
+ max_neg_value = -torch.finfo(sim.dtype).max
483
+ attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
484
+ sim.masked_fill_(~attention_mask, max_neg_value)
485
+
486
+ # attention, what we cannot get enough of
487
+ attn = sim.softmax(dim=-1)
488
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
489
+ out = self.batch_to_head_dim(out)
490
+
491
+ return to_out(out)
492
+
493
+ return forward
494
+ res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
495
+ # res_dict = {1: [2], 2: [2], 3: [2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
496
+
497
+ for res in res_dict:
498
+ for block in res_dict[res]:
499
+ module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
500
+ module.forward = sa_forward(module)
501
+ setattr(module, 'injection_schedule', injection_schedule)
502
+ # down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1]}
503
+ # for res in down_res_dict:
504
+ # for block in down_res_dict[res]:
505
+ # module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1
506
+ # module.forward = sa_forward(module)
507
+ # setattr(module, 'injection_schedule', injection_schedule)
508
+
509
+ def register_conv_control_efficient_w_mask(model, injection_schedule, mask):
510
+ def conv_forward(self):
511
+ def forward(input_tensor, temb):
512
+ hidden_states = input_tensor
513
+
514
+ hidden_states = self.norm1(hidden_states)
515
+ hidden_states = self.nonlinearity(hidden_states)
516
+
517
+ if self.upsample is not None:
518
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
519
+ if hidden_states.shape[0] >= 64:
520
+ input_tensor = input_tensor.contiguous()
521
+ hidden_states = hidden_states.contiguous()
522
+ input_tensor = self.upsample(input_tensor)
523
+ hidden_states = self.upsample(hidden_states)
524
+ elif self.downsample is not None:
525
+ input_tensor = self.downsample(input_tensor)
526
+ hidden_states = self.downsample(hidden_states)
527
+
528
+ hidden_states = self.conv1(hidden_states)
529
+
530
+ if temb is not None:
531
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
532
+
533
+ if temb is not None and self.time_embedding_norm == "default":
534
+ hidden_states = hidden_states + temb
535
+
536
+ hidden_states = self.norm2(hidden_states)
537
+
538
+ if temb is not None and self.time_embedding_norm == "scale_shift":
539
+ scale, shift = torch.chunk(temb, 2, dim=1)
540
+ hidden_states = hidden_states * (1 + scale) + shift
541
+
542
+ hidden_states = self.nonlinearity(hidden_states)
543
+
544
+ hidden_states = self.dropout(hidden_states)
545
+ hidden_states = self.conv2(hidden_states)
546
+ if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
547
+ # if False:
548
+ source_batch_size = int(hidden_states.shape[0] // 3)
549
+ target_size = int(np.sqrt(hidden_states.shape[-1]))
550
+ target_mask = F.interpolate(mask.unsqueeze(1),size=(target_size, target_size))[:,0,:,:]
551
+ target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1)
552
+
553
+ # inject unconditional
554
+ hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size] * target_mask + hidden_states[source_batch_size:2 * source_batch_size] * (1-target_mask)
555
+ # inject conditional
556
+ hidden_states[2 * source_batch_size:] = hidden_states[:source_batch_size] * target_mask + hidden_states[2 * source_batch_size:] * (1-target_mask)
557
+
558
+ if self.conv_shortcut is not None:
559
+ input_tensor = self.conv_shortcut(input_tensor)
560
+
561
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
562
+
563
+ return output_tensor
564
+
565
+ return forward
566
+
567
+ conv_module = model.unet.up_blocks[1].resnets[1]
568
+ conv_module.forward = conv_forward(conv_module)
569
+ setattr(conv_module, 'injection_schedule', injection_schedule)
weights/dpt_beit_large_512.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e9e900747e9e8b3112df716979219836a27716277b3d0dc53889cbba8b82328
3
+ size 1581966003