Spaces:
Runtime error
Runtime error
imsuperkong
commited on
Commit
·
d3bdeec
1
Parent(s):
dc47947
Upload 6 files
Browse files- requirements.txt +7 -0
- sd/core.py +435 -0
- sd/dift_sd.py +240 -0
- sd/gradio_utils.py +85 -0
- sd/pnp_utils.py +569 -0
- weights/dpt_beit_large_512.pt +3 -0
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.0.1
|
2 |
+
torchvision
|
3 |
+
timm==0.6.12
|
4 |
+
gradio==3.40.1
|
5 |
+
diffusers==0.17.1
|
6 |
+
numpy==1.20.3
|
7 |
+
wget
|
sd/core.py
ADDED
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from diffusers import StableDiffusionPipeline
|
5 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
6 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
7 |
+
from sd.pnp_utils import register_time, register_attention_control_efficient_kv_w_mask, register_conv_control_efficient_w_mask
|
8 |
+
import torch.nn as nn
|
9 |
+
from sd.dift_sd import MyUNet2DConditionModel, OneStepSDPipeline
|
10 |
+
import ipdb
|
11 |
+
from tqdm import tqdm
|
12 |
+
from lib.midas import MiDas
|
13 |
+
|
14 |
+
class DDIMBackward(StableDiffusionPipeline):
|
15 |
+
def __init__(
|
16 |
+
self, vae, text_encoder, tokenizer, unet, scheduler,
|
17 |
+
safety_checker, feature_extractor,
|
18 |
+
requires_safety_checker: bool = True,
|
19 |
+
device='cuda', model_id='ckpt/stable-diffusion-2-1-base',depth_model='dpt_swin2_large_384'
|
20 |
+
):
|
21 |
+
super().__init__(
|
22 |
+
vae, text_encoder, tokenizer, unet, scheduler,
|
23 |
+
safety_checker, feature_extractor, requires_safety_checker,
|
24 |
+
)
|
25 |
+
|
26 |
+
self.dift_unet = MyUNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float16 if 'cuda' in device else torch.float32)
|
27 |
+
self.onestep_pipe = OneStepSDPipeline.from_pretrained(model_id, unet=self.dift_unet, safety_checker=None, torch_dtype=torch.float16 if 'cuda' in device else torch.float32)
|
28 |
+
self.onestep_pipe = self.onestep_pipe.to(device)
|
29 |
+
|
30 |
+
if 'cuda' in device:
|
31 |
+
self.onestep_pipe.enable_attention_slicing()
|
32 |
+
self.onestep_pipe.enable_xformers_memory_efficient_attention()
|
33 |
+
self.ensemble_size = 4
|
34 |
+
self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
35 |
+
|
36 |
+
self.midas_model = MiDas(device,model_type=depth_model)
|
37 |
+
|
38 |
+
self.torch_dtype=torch.float16 if 'cuda' in device else torch.float32
|
39 |
+
|
40 |
+
|
41 |
+
@torch.no_grad()
|
42 |
+
def __call__(
|
43 |
+
self,
|
44 |
+
prompt: Union[str, List[str]] = None,
|
45 |
+
height: Optional[int] = None,
|
46 |
+
width: Optional[int] = None,
|
47 |
+
num_inference_steps: int = 50,
|
48 |
+
guidance_scale: float = 7.5,
|
49 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
50 |
+
num_images_per_prompt: Optional[int] = 1,
|
51 |
+
eta: float = 0.0,
|
52 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
53 |
+
latents: Optional[torch.FloatTensor] = None,
|
54 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
55 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
56 |
+
output_type: Optional[str] = "pil",
|
57 |
+
return_dict: bool = True,
|
58 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
59 |
+
callback_steps: int = 1,
|
60 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
61 |
+
t_start=None,
|
62 |
+
):
|
63 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
64 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
65 |
+
self.check_inputs(
|
66 |
+
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
67 |
+
)
|
68 |
+
|
69 |
+
if prompt is not None and isinstance(prompt, str):
|
70 |
+
batch_size = 1
|
71 |
+
elif prompt is not None and isinstance(prompt, list):
|
72 |
+
batch_size = len(prompt)
|
73 |
+
else:
|
74 |
+
batch_size = prompt_embeds.shape[0]
|
75 |
+
|
76 |
+
device = self._execution_device
|
77 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
78 |
+
prompt_embeds = self._encode_prompt(
|
79 |
+
prompt,
|
80 |
+
device,
|
81 |
+
num_images_per_prompt,
|
82 |
+
do_classifier_free_guidance,
|
83 |
+
negative_prompt,
|
84 |
+
prompt_embeds=prompt_embeds,
|
85 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
86 |
+
)
|
87 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
88 |
+
timesteps = self.scheduler.timesteps
|
89 |
+
num_channels_latents = self.unet.in_channels
|
90 |
+
latents = self.prepare_latents(
|
91 |
+
batch_size * num_images_per_prompt,
|
92 |
+
num_channels_latents,
|
93 |
+
height,
|
94 |
+
width,
|
95 |
+
prompt_embeds.dtype,
|
96 |
+
device,
|
97 |
+
generator,
|
98 |
+
latents,
|
99 |
+
)
|
100 |
+
|
101 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
102 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
103 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
104 |
+
for i, t in enumerate(timesteps):
|
105 |
+
if t_start and t >= t_start:
|
106 |
+
progress_bar.update()
|
107 |
+
continue
|
108 |
+
|
109 |
+
# expand the latents if we are doing classifier free guidance
|
110 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
111 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
112 |
+
|
113 |
+
# predict the noise residual
|
114 |
+
noise_pred = self.unet(
|
115 |
+
latent_model_input,
|
116 |
+
t,
|
117 |
+
encoder_hidden_states=prompt_embeds,
|
118 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
119 |
+
).sample
|
120 |
+
|
121 |
+
# perform guidance
|
122 |
+
if do_classifier_free_guidance:
|
123 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
124 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
125 |
+
|
126 |
+
# compute the previous noisy sample x_t -> x_t-1
|
127 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
128 |
+
|
129 |
+
# call the callback, if provided
|
130 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
131 |
+
progress_bar.update()
|
132 |
+
if callback is not None and i % callback_steps == 0:
|
133 |
+
callback(i, t, latents)
|
134 |
+
|
135 |
+
if output_type == "latent":
|
136 |
+
image = latents
|
137 |
+
has_nsfw_concept = None
|
138 |
+
elif output_type == "pil":
|
139 |
+
image = self.decode_latents(latents)
|
140 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
141 |
+
image = self.numpy_to_pil(image)
|
142 |
+
else:
|
143 |
+
image = self.decode_latents(latents)
|
144 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
145 |
+
|
146 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
147 |
+
self.final_offload_hook.offload()
|
148 |
+
|
149 |
+
if not return_dict:
|
150 |
+
return (image, has_nsfw_concept)
|
151 |
+
|
152 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
153 |
+
|
154 |
+
def denoise_w_injection(
|
155 |
+
self,
|
156 |
+
prompt: Union[str, List[str]] = None,
|
157 |
+
height: Optional[int] = None,
|
158 |
+
width: Optional[int] = None,
|
159 |
+
num_inference_steps: int = 50,
|
160 |
+
guidance_scale: float = 7.5,
|
161 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
162 |
+
num_images_per_prompt: Optional[int] = 1,
|
163 |
+
eta: float = 0.0,
|
164 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
165 |
+
latents: Optional[torch.FloatTensor] = None,
|
166 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
167 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
168 |
+
output_type: Optional[str] = "pil",
|
169 |
+
return_dict: bool = True,
|
170 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
171 |
+
callback_steps: int = 1,
|
172 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
173 |
+
t_start=None,
|
174 |
+
attn=0.8,
|
175 |
+
f=0.5,
|
176 |
+
latent_mask=None,
|
177 |
+
guidance_loss_scale=0,
|
178 |
+
cfg_decay=False,
|
179 |
+
cfg_norm=False,
|
180 |
+
lr=1.0,
|
181 |
+
up_ft_indexes=[1,2],
|
182 |
+
img_tensor=None,
|
183 |
+
early_stop=50,
|
184 |
+
intrinsic=None, extrinsic=None, threshold=20,depth=None,
|
185 |
+
):
|
186 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
187 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
188 |
+
self.check_inputs(
|
189 |
+
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
190 |
+
)
|
191 |
+
|
192 |
+
if prompt is not None and isinstance(prompt, str):
|
193 |
+
batch_size = 1
|
194 |
+
elif prompt is not None and isinstance(prompt, list):
|
195 |
+
batch_size = len(prompt)
|
196 |
+
else:
|
197 |
+
batch_size = prompt_embeds.shape[0]
|
198 |
+
|
199 |
+
device = self._execution_device
|
200 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
201 |
+
prompt_embeds = self._encode_prompt(
|
202 |
+
prompt,
|
203 |
+
device,
|
204 |
+
num_images_per_prompt,
|
205 |
+
do_classifier_free_guidance,
|
206 |
+
negative_prompt,
|
207 |
+
prompt_embeds=prompt_embeds,
|
208 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
209 |
+
)
|
210 |
+
if do_classifier_free_guidance:
|
211 |
+
prompt_embeds = torch.cat((prompt_embeds[1:], prompt_embeds[1:], prompt_embeds[:1]), dim=0)
|
212 |
+
else:
|
213 |
+
prompt_embeds = torch.cat([prompt_embeds]*2, dim=0)
|
214 |
+
|
215 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
216 |
+
timesteps = self.scheduler.timesteps
|
217 |
+
num_channels_latents = self.unet.in_channels
|
218 |
+
latents = self.prepare_latents(
|
219 |
+
batch_size * num_images_per_prompt,
|
220 |
+
num_channels_latents,
|
221 |
+
height,
|
222 |
+
width,
|
223 |
+
prompt_embeds.dtype,
|
224 |
+
device,
|
225 |
+
generator,
|
226 |
+
latents,
|
227 |
+
)
|
228 |
+
|
229 |
+
kv_injection_timesteps = self.scheduler.timesteps[:int(len(self.scheduler.timesteps) * attn)]
|
230 |
+
f_injection_timesteps = self.scheduler.timesteps[:int(len(self.scheduler.timesteps) * f)]
|
231 |
+
register_attention_control_efficient_kv_w_mask(self, kv_injection_timesteps, mask=latent_mask, do_classifier_free_guidance=do_classifier_free_guidance)
|
232 |
+
register_conv_control_efficient_w_mask(self, f_injection_timesteps, mask=latent_mask)
|
233 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
234 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
235 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
236 |
+
for i, t in enumerate(timesteps):
|
237 |
+
if t_start and t >= t_start:
|
238 |
+
progress_bar.update()
|
239 |
+
continue
|
240 |
+
if i > early_stop: guidance_loss_scale = 0 # Early stop (optional)
|
241 |
+
# if t > 300: guidance_loss_scale = 0 # Early stop (optional)
|
242 |
+
register_time(self, t.item())
|
243 |
+
# Set requires grad
|
244 |
+
if guidance_loss_scale != 0:
|
245 |
+
latents = latents.detach().requires_grad_()
|
246 |
+
|
247 |
+
# expand the latents if we are doing classifier free guidance
|
248 |
+
latent_model_input = latents # latents: ori_z + wrap_z
|
249 |
+
if do_classifier_free_guidance:
|
250 |
+
latent_model_input = torch.cat([latent_model_input, latent_model_input[1:]], dim=0)
|
251 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
252 |
+
|
253 |
+
# predict the noise residual
|
254 |
+
if guidance_loss_scale != 0:
|
255 |
+
with torch.no_grad():
|
256 |
+
noise_pred = self.unet(
|
257 |
+
latent_model_input,
|
258 |
+
t,
|
259 |
+
encoder_hidden_states=prompt_embeds,
|
260 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
261 |
+
).sample
|
262 |
+
else:
|
263 |
+
with torch.no_grad():
|
264 |
+
noise_pred = self.unet(
|
265 |
+
latent_model_input,
|
266 |
+
t,
|
267 |
+
encoder_hidden_states=prompt_embeds,
|
268 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
269 |
+
).sample
|
270 |
+
|
271 |
+
# perform guidance
|
272 |
+
if do_classifier_free_guidance:
|
273 |
+
cfg_scale = guidance_scale
|
274 |
+
if cfg_decay: cfg_scale = 1 + guidance_scale * (1-i/num_inference_steps)
|
275 |
+
noise_pred_text, wrap_noise_pred_text, wrap_noise_pred_uncond = noise_pred.chunk(3)
|
276 |
+
noise_pred = wrap_noise_pred_text + cfg_scale * (wrap_noise_pred_text - wrap_noise_pred_uncond)
|
277 |
+
else:
|
278 |
+
noise_pred_text, wrap_noise_pred_text = noise_pred.chunk(3)
|
279 |
+
noise_pred = wrap_noise_pred_text
|
280 |
+
|
281 |
+
if cfg_norm:
|
282 |
+
noise_pred = noise_pred * (torch.linalg.norm(wrap_noise_pred_uncond) / torch.linalg.norm(noise_pred))
|
283 |
+
|
284 |
+
if guidance_loss_scale != 0:
|
285 |
+
for up_ft_index in up_ft_indexes:
|
286 |
+
|
287 |
+
alpha_prod_t = self.scheduler.alphas_cumprod[t]
|
288 |
+
alpha_prod_t_prev = (
|
289 |
+
self.scheduler.alphas_cumprod[timesteps[i - 0]]
|
290 |
+
if i > 0 else self.scheduler.final_alpha_cumprod
|
291 |
+
)
|
292 |
+
|
293 |
+
mu = alpha_prod_t ** 0.5
|
294 |
+
mu_prev = alpha_prod_t_prev ** 0.5
|
295 |
+
sigma = (1 - alpha_prod_t) ** 0.5
|
296 |
+
sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
|
297 |
+
|
298 |
+
pred_x0 = (latents - sigma_prev * noise_pred[:latents.shape[0]]) / mu_prev
|
299 |
+
|
300 |
+
unet_ft_all = self.onestep_pipe(
|
301 |
+
latents=pred_x0[:1].repeat(self.ensemble_size, 1, 1, 1),
|
302 |
+
t=t,
|
303 |
+
up_ft_indices=[up_ft_index],
|
304 |
+
prompt_embeds=prompt_embeds[:1].repeat(self.ensemble_size, 1, 1)
|
305 |
+
)
|
306 |
+
unet_ft1 = unet_ft_all['up_ft'][up_ft_index].mean(0, keepdim=True) # 1,c,h,w
|
307 |
+
unet_ft1_norm = unet_ft1 / torch.norm(unet_ft1, dim=1, keepdim=True)
|
308 |
+
|
309 |
+
unet_ft1_norm = self.midas_model.wrap_img_tensor_w_fft_ext(
|
310 |
+
unet_ft1_norm.to(self.torch_dtype),
|
311 |
+
torch.from_numpy(depth).to(device).to(self.torch_dtype),
|
312 |
+
intrinsic,
|
313 |
+
extrinsic[:3,:3], extrinsic[:3,3], threshold=threshold).to(self.torch_dtype)
|
314 |
+
|
315 |
+
unet_ft_all = self.onestep_pipe(
|
316 |
+
latents=pred_x0[1:2].repeat(self.ensemble_size, 1, 1, 1),
|
317 |
+
t=t,
|
318 |
+
up_ft_indices=[up_ft_index],
|
319 |
+
prompt_embeds=prompt_embeds[:1].repeat(self.ensemble_size, 1, 1)
|
320 |
+
)
|
321 |
+
unet_ft2 = unet_ft_all['up_ft'][up_ft_index].mean(0, keepdim=True) # 1,c,h,w
|
322 |
+
unet_ft2_norm = unet_ft2 / torch.norm(unet_ft2, dim=1, keepdim=True)
|
323 |
+
c = unet_ft2.shape[1]
|
324 |
+
loss = (-self.cos(unet_ft1_norm.squeeze().view(c, -1).T, unet_ft2_norm.squeeze().view(c, -1).T).mean() + 1) / 2.
|
325 |
+
# Get gradient
|
326 |
+
cond_grad = torch.autograd.grad(loss * guidance_loss_scale, latents)[0][1:2]
|
327 |
+
|
328 |
+
# compute the previous noisy sample x_t -> x_t-1
|
329 |
+
noise_pred_ = noise_pred - sigma_prev * cond_grad*lr
|
330 |
+
noise_pred_ = torch.cat([noise_pred_text, noise_pred_], dim=0)
|
331 |
+
|
332 |
+
# compute the previous noisy sample x_t -> x_t-1
|
333 |
+
with torch.no_grad():
|
334 |
+
latents = self.scheduler.step(noise_pred_, t, latents, **extra_step_kwargs).prev_sample
|
335 |
+
# call the callback, if provided
|
336 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
337 |
+
progress_bar.update()
|
338 |
+
if callback is not None and i % callback_steps == 0:
|
339 |
+
callback(i, t, latents)
|
340 |
+
|
341 |
+
if output_type == "latent":
|
342 |
+
image = latents
|
343 |
+
has_nsfw_concept = None
|
344 |
+
elif output_type == "pil":
|
345 |
+
with torch.no_grad():
|
346 |
+
image = self.decode_latents(latents)
|
347 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
348 |
+
image = self.numpy_to_pil(image)
|
349 |
+
else:
|
350 |
+
image = self.decode_latents(latents)
|
351 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
352 |
+
|
353 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
354 |
+
self.final_offload_hook.offload()
|
355 |
+
|
356 |
+
if not return_dict:
|
357 |
+
return (image, has_nsfw_concept)
|
358 |
+
|
359 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
360 |
+
@torch.no_grad()
|
361 |
+
def decoder(self, latents):
|
362 |
+
with torch.autocast(device_type=self.device, dtype=torch.float32):
|
363 |
+
latents = 1 / 0.18215 * latents
|
364 |
+
imgs = self.vae.decode(latents).sample
|
365 |
+
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
366 |
+
return imgs
|
367 |
+
|
368 |
+
|
369 |
+
def ddim_inversion_w_grad(self, latent, cond, stop_t, guidance_loss_scale=1.0, lr=1.0):
|
370 |
+
timesteps = reversed(self.scheduler.timesteps)
|
371 |
+
with torch.autocast(device_type=self.device, dtype=torch.float32):
|
372 |
+
|
373 |
+
for i, t in enumerate(tqdm(timesteps)):
|
374 |
+
if t >= stop_t:
|
375 |
+
break
|
376 |
+
|
377 |
+
if guidance_loss_scale != 0:
|
378 |
+
latent = latent.detach().requires_grad_()
|
379 |
+
cond_batch = cond.repeat(latent.shape[0], 1, 1)
|
380 |
+
|
381 |
+
alpha_prod_t = self.scheduler.alphas_cumprod[t]
|
382 |
+
alpha_prod_t_prev = (
|
383 |
+
self.scheduler.alphas_cumprod[timesteps[i - 1]]
|
384 |
+
if i > 0 else self.scheduler.final_alpha_cumprod
|
385 |
+
)
|
386 |
+
|
387 |
+
mu = alpha_prod_t ** 0.5
|
388 |
+
mu_prev = alpha_prod_t_prev ** 0.5
|
389 |
+
sigma = (1 - alpha_prod_t) ** 0.5
|
390 |
+
sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
|
391 |
+
|
392 |
+
eps = self.onestep_pipe.unet(latent, t, encoder_hidden_states=cond_batch, up_ft_indices=[3], output_eps=True)['eps']
|
393 |
+
pred_x0 = (latent - sigma_prev * eps) / mu_prev
|
394 |
+
|
395 |
+
unet_ft_all = self.onestep_pipe(
|
396 |
+
latents=pred_x0[:1].repeat(self.ensemble_size, 1, 1, 1),
|
397 |
+
t=t,
|
398 |
+
up_ft_indices=[1],
|
399 |
+
prompt_embeds=cond_batch[:1].repeat(self.ensemble_size, 1, 1)
|
400 |
+
)
|
401 |
+
unet_ft1 = unet_ft_all['up_ft'][1].mean(0, keepdim=True) # 1,c,h,w
|
402 |
+
unet_ft1_norm = unet_ft1 / torch.norm(unet_ft1, dim=1, keepdim=True)
|
403 |
+
|
404 |
+
unet_ft_all = self.onestep_pipe(
|
405 |
+
latents=pred_x0[1:2].repeat(self.ensemble_size, 1, 1, 1),
|
406 |
+
t=t,
|
407 |
+
up_ft_indices=[1],
|
408 |
+
prompt_embeds=cond_batch[:1].repeat(self.ensemble_size, 1, 1)
|
409 |
+
)
|
410 |
+
unet_ft2 = unet_ft_all['up_ft'][1].mean(0, keepdim=True) # 1,c,h,w
|
411 |
+
unet_ft2_norm = unet_ft2 / torch.norm(unet_ft2, dim=1, keepdim=True)
|
412 |
+
c = unet_ft2.shape[1]
|
413 |
+
loss = (-self.cos(unet_ft1_norm.squeeze().view(c, -1).T.detach(), unet_ft2_norm.squeeze().view(c, -1).T).mean() + 1) / 2.
|
414 |
+
print(f'loss: {loss.item()}')
|
415 |
+
# Get gradient
|
416 |
+
cond_grad = torch.autograd.grad(loss * guidance_loss_scale, latent)[0]
|
417 |
+
|
418 |
+
# latent = latent.detach() - cond_grad * lr
|
419 |
+
latent = mu * pred_x0 + sigma * eps - cond_grad * lr
|
420 |
+
|
421 |
+
return latent
|
422 |
+
|
423 |
+
@torch.no_grad()
|
424 |
+
def DDPM_forward(x_t_dot, t_start, delta_t, ddpm_scheduler, generator):
|
425 |
+
# just simple implementation, this should have an analytical expression
|
426 |
+
# TODO: implementation analytical form
|
427 |
+
for delta in range(1, delta_t):
|
428 |
+
# noise = torch.randn_like(x_t_dot, generator=generator)
|
429 |
+
noise = torch.empty_like(x_t_dot).normal_(generator=generator)
|
430 |
+
|
431 |
+
beta = ddpm_scheduler.betas[t_start+delta]
|
432 |
+
std_ = beta ** 0.5
|
433 |
+
mu_ = ((1 - beta) ** 0.5) * x_t_dot
|
434 |
+
x_t_dot = mu_ + std_ * noise
|
435 |
+
return x_t_dot
|
sd/dift_sd.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import StableDiffusionPipeline
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
7 |
+
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
8 |
+
from diffusers import DDIMScheduler
|
9 |
+
import gc
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
class MyUNet2DConditionModel(UNet2DConditionModel):
|
13 |
+
def forward(
|
14 |
+
self,
|
15 |
+
sample: torch.FloatTensor,
|
16 |
+
timestep: Union[torch.Tensor, float, int],
|
17 |
+
up_ft_indices,
|
18 |
+
encoder_hidden_states: torch.Tensor,
|
19 |
+
class_labels: Optional[torch.Tensor] = None,
|
20 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
21 |
+
attention_mask: Optional[torch.Tensor] = None,
|
22 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
23 |
+
output_eps=False):
|
24 |
+
r"""
|
25 |
+
Args:
|
26 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
27 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
28 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
29 |
+
cross_attention_kwargs (`dict`, *optional*):
|
30 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
|
31 |
+
`self.processor` in
|
32 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
33 |
+
"""
|
34 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
35 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
36 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
37 |
+
# on the fly if necessary.
|
38 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
39 |
+
|
40 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
41 |
+
forward_upsample_size = False
|
42 |
+
upsample_size = None
|
43 |
+
|
44 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
45 |
+
# logger.info("Forward upsample size to force interpolation output size.")
|
46 |
+
forward_upsample_size = True
|
47 |
+
|
48 |
+
# prepare attention_mask
|
49 |
+
if attention_mask is not None:
|
50 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
51 |
+
attention_mask = attention_mask.unsqueeze(1)
|
52 |
+
|
53 |
+
# 0. center input if necessary
|
54 |
+
if self.config.center_input_sample:
|
55 |
+
sample = 2 * sample - 1.0
|
56 |
+
|
57 |
+
# 1. time
|
58 |
+
timesteps = timestep
|
59 |
+
if not torch.is_tensor(timesteps):
|
60 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
61 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
62 |
+
is_mps = sample.device.type == "mps"
|
63 |
+
if isinstance(timestep, float):
|
64 |
+
dtype = torch.float32 if is_mps else torch.float64
|
65 |
+
else:
|
66 |
+
dtype = torch.int32 if is_mps else torch.int64
|
67 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
68 |
+
elif len(timesteps.shape) == 0:
|
69 |
+
timesteps = timesteps[None].to(sample.device)
|
70 |
+
|
71 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
72 |
+
timesteps = timesteps.expand(sample.shape[0])
|
73 |
+
|
74 |
+
t_emb = self.time_proj(timesteps)
|
75 |
+
|
76 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
77 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
78 |
+
# there might be better ways to encapsulate this.
|
79 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
80 |
+
|
81 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
82 |
+
|
83 |
+
if self.class_embedding is not None:
|
84 |
+
if class_labels is None:
|
85 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
86 |
+
|
87 |
+
if self.config.class_embed_type == "timestep":
|
88 |
+
class_labels = self.time_proj(class_labels)
|
89 |
+
|
90 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
91 |
+
emb = emb + class_emb
|
92 |
+
|
93 |
+
# 2. pre-process
|
94 |
+
sample = self.conv_in(sample)
|
95 |
+
|
96 |
+
# 3. down
|
97 |
+
down_block_res_samples = (sample,)
|
98 |
+
for downsample_block in self.down_blocks:
|
99 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
100 |
+
sample, res_samples = downsample_block(
|
101 |
+
hidden_states=sample,
|
102 |
+
temb=emb,
|
103 |
+
encoder_hidden_states=encoder_hidden_states,
|
104 |
+
attention_mask=attention_mask,
|
105 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
106 |
+
)
|
107 |
+
else:
|
108 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
109 |
+
|
110 |
+
down_block_res_samples += res_samples
|
111 |
+
|
112 |
+
# 4. mid
|
113 |
+
if self.mid_block is not None:
|
114 |
+
sample = self.mid_block(
|
115 |
+
sample,
|
116 |
+
emb,
|
117 |
+
encoder_hidden_states=encoder_hidden_states,
|
118 |
+
attention_mask=attention_mask,
|
119 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
120 |
+
)
|
121 |
+
|
122 |
+
# 5. up
|
123 |
+
up_ft = {}
|
124 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
125 |
+
|
126 |
+
if i > np.max(up_ft_indices):
|
127 |
+
break
|
128 |
+
|
129 |
+
is_final_block = i == len(self.up_blocks) - 1
|
130 |
+
|
131 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
132 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
133 |
+
|
134 |
+
# if we have not reached the final block and need to forward the
|
135 |
+
# upsample size, we do it here
|
136 |
+
if not is_final_block and forward_upsample_size:
|
137 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
138 |
+
|
139 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
140 |
+
sample = upsample_block(
|
141 |
+
hidden_states=sample,
|
142 |
+
temb=emb,
|
143 |
+
res_hidden_states_tuple=res_samples,
|
144 |
+
encoder_hidden_states=encoder_hidden_states,
|
145 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
146 |
+
upsample_size=upsample_size,
|
147 |
+
attention_mask=attention_mask,
|
148 |
+
)
|
149 |
+
else:
|
150 |
+
sample = upsample_block(
|
151 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
152 |
+
)
|
153 |
+
|
154 |
+
if i in up_ft_indices:
|
155 |
+
up_ft[i] = sample
|
156 |
+
|
157 |
+
output = {}
|
158 |
+
output['up_ft'] = up_ft
|
159 |
+
if output_eps:
|
160 |
+
sample = self.conv_norm_out(sample)
|
161 |
+
sample = self.conv_act(sample)
|
162 |
+
sample = self.conv_out(sample)
|
163 |
+
output['eps'] = sample
|
164 |
+
return output
|
165 |
+
|
166 |
+
class OneStepSDPipeline(StableDiffusionPipeline):
|
167 |
+
# @torch.no_grad()
|
168 |
+
def __call__(
|
169 |
+
self,
|
170 |
+
|
171 |
+
t,
|
172 |
+
up_ft_indices,
|
173 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
174 |
+
img_tensor=None,
|
175 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
176 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
177 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
178 |
+
callback_steps: int = 1,
|
179 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
180 |
+
latents=None
|
181 |
+
):
|
182 |
+
|
183 |
+
device = self._execution_device
|
184 |
+
if latents is None:
|
185 |
+
latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor
|
186 |
+
t = torch.tensor(t.clone().detach(), dtype=torch.long, device=device)
|
187 |
+
noise = torch.randn_like(latents).to(device)
|
188 |
+
latents_noisy = self.scheduler.add_noise(latents, noise, t)
|
189 |
+
unet_output = self.unet(latents_noisy,
|
190 |
+
t,
|
191 |
+
up_ft_indices,
|
192 |
+
encoder_hidden_states=prompt_embeds,
|
193 |
+
cross_attention_kwargs=cross_attention_kwargs)
|
194 |
+
return unet_output
|
195 |
+
|
196 |
+
|
197 |
+
class SDFeaturizer:
|
198 |
+
def __init__(self, sd_id='ckpt/stable-diffusion-2-1-base'):
|
199 |
+
unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet")
|
200 |
+
onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None)
|
201 |
+
onestep_pipe.vae.decoder = None
|
202 |
+
onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler")
|
203 |
+
gc.collect()
|
204 |
+
onestep_pipe = onestep_pipe.to("cuda")
|
205 |
+
onestep_pipe.enable_attention_slicing()
|
206 |
+
onestep_pipe.enable_xformers_memory_efficient_attention()
|
207 |
+
self.pipe = onestep_pipe
|
208 |
+
|
209 |
+
@torch.no_grad()
|
210 |
+
def forward(self,
|
211 |
+
img_tensor,
|
212 |
+
prompt,
|
213 |
+
t=261,
|
214 |
+
up_ft_index=1,
|
215 |
+
ensemble_size=8):
|
216 |
+
'''
|
217 |
+
Args:
|
218 |
+
img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W]
|
219 |
+
prompt: the prompt to use, a string
|
220 |
+
t: the time step to use, should be an int in the range of [0, 1000]
|
221 |
+
up_ft_index: which upsampling block of the U-Net to extract feature, you can choose [0, 1, 2, 3]
|
222 |
+
ensemble_size: the number of repeated images used in the batch to extract features
|
223 |
+
Return:
|
224 |
+
unet_ft: a torch tensor in the shape of [1, c, h, w]
|
225 |
+
'''
|
226 |
+
img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
|
227 |
+
prompt_embeds = self.pipe._encode_prompt(
|
228 |
+
prompt=prompt,
|
229 |
+
device='cuda',
|
230 |
+
num_images_per_prompt=1,
|
231 |
+
do_classifier_free_guidance=False) # [1, 77, dim]
|
232 |
+
prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1)
|
233 |
+
unet_ft_all = self.pipe(
|
234 |
+
img_tensor=img_tensor,
|
235 |
+
t=t,
|
236 |
+
up_ft_indices=[up_ft_index],
|
237 |
+
prompt_embeds=prompt_embeds)
|
238 |
+
unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w
|
239 |
+
unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w
|
240 |
+
return unet_ft
|
sd/gradio_utils.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
import urllib.request
|
6 |
+
from typing import List, Optional, Tuple
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import PIL
|
10 |
+
import PIL.Image
|
11 |
+
import PIL.ImageDraw
|
12 |
+
import torch
|
13 |
+
import torch.optim
|
14 |
+
from tqdm import tqdm
|
15 |
+
import ipdb
|
16 |
+
|
17 |
+
def tensor_to_PIL(img: torch.Tensor) -> PIL.Image.Image:
|
18 |
+
"""
|
19 |
+
Converts a tensor image to a PIL Image.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
img (torch.Tensor): The tensor image of shape [batch_size, num_channels, height, width].
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
A PIL Image object.
|
26 |
+
"""
|
27 |
+
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
28 |
+
return PIL.Image.fromarray(img[0].cpu().numpy(), "RGB")
|
29 |
+
|
30 |
+
|
31 |
+
def get_ellipse_coords(
|
32 |
+
point: Tuple[int, int], radius: int = 5
|
33 |
+
) -> Tuple[int, int, int, int]:
|
34 |
+
"""
|
35 |
+
Returns the coordinates of an ellipse centered at the given point.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
point (Tuple[int, int]): The center point of the ellipse.
|
39 |
+
radius (int): The radius of the ellipse.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
A tuple containing the coordinates of the ellipse in the format (x_min, y_min, x_max, y_max).
|
43 |
+
"""
|
44 |
+
center = point
|
45 |
+
return (
|
46 |
+
center[0] - radius,
|
47 |
+
center[1] - radius,
|
48 |
+
center[0] + radius,
|
49 |
+
center[1] + radius,
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
def draw_handle_target_points(
|
55 |
+
img: PIL.Image.Image,
|
56 |
+
# handle_points: List[Tuple[int, int]],
|
57 |
+
target_points: List[Tuple[int, int]],
|
58 |
+
radius: int = 5):
|
59 |
+
"""
|
60 |
+
Draws handle and target points with arrow pointing towards the target point.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
img (PIL.Image.Image): The image to draw on.
|
64 |
+
handle_points (List[Tuple[int, int]]): A list of handle [x,y] points.
|
65 |
+
target_points (List[Tuple[int, int]]): A list of target [x,y] points.
|
66 |
+
radius (int): The radius of the handle and target points.
|
67 |
+
"""
|
68 |
+
if not isinstance(img, PIL.Image.Image):
|
69 |
+
img = PIL.Image.fromarray(img)
|
70 |
+
|
71 |
+
# if len(handle_points) == len(target_points) + 1:
|
72 |
+
# target_points = copy.deepcopy(target_points) + [None]
|
73 |
+
|
74 |
+
draw = PIL.ImageDraw.Draw(img)
|
75 |
+
for handle_point, target_point in zip(target_points, target_points):
|
76 |
+
# handle_point = [handle_point[1], handle_point[0]]
|
77 |
+
# Draw the handle point
|
78 |
+
# ipdb.set_trace()
|
79 |
+
|
80 |
+
target_coords = get_ellipse_coords(target_point, radius)
|
81 |
+
draw.ellipse((target_coords), fill="red")
|
82 |
+
|
83 |
+
return np.array(img)
|
84 |
+
|
85 |
+
|
sd/pnp_utils.py
ADDED
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import ipdb
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
def seed_everything(seed):
|
9 |
+
torch.manual_seed(seed)
|
10 |
+
torch.cuda.manual_seed(seed)
|
11 |
+
random.seed(seed)
|
12 |
+
np.random.seed(seed)
|
13 |
+
|
14 |
+
def register_time(model, t):
|
15 |
+
conv_module = model.unet.up_blocks[1].resnets[1]
|
16 |
+
setattr(conv_module, 't', t)
|
17 |
+
down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1]}
|
18 |
+
up_res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
|
19 |
+
for res in up_res_dict:
|
20 |
+
for block in up_res_dict[res]:
|
21 |
+
module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
|
22 |
+
setattr(module, 't', t)
|
23 |
+
for res in down_res_dict:
|
24 |
+
for block in down_res_dict[res]:
|
25 |
+
module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1
|
26 |
+
setattr(module, 't', t)
|
27 |
+
module = model.unet.mid_block.attentions[0].transformer_blocks[0].attn1
|
28 |
+
setattr(module, 't', t)
|
29 |
+
|
30 |
+
|
31 |
+
def load_source_latents_t(t, latents_path):
|
32 |
+
latents_t_path = os.path.join(latents_path, f'noisy_latents_{t}.pt')
|
33 |
+
assert os.path.exists(latents_t_path), f'Missing latents at t {t} path {latents_t_path}'
|
34 |
+
latents = torch.load(latents_t_path)
|
35 |
+
return latents
|
36 |
+
|
37 |
+
def register_attention_control_efficient(model, injection_schedule):
|
38 |
+
def sa_forward(self):
|
39 |
+
to_out = self.to_out
|
40 |
+
if type(to_out) is torch.nn.modules.container.ModuleList:
|
41 |
+
to_out = self.to_out[0]
|
42 |
+
else:
|
43 |
+
to_out = self.to_out
|
44 |
+
|
45 |
+
def forward(x, encoder_hidden_states=None, attention_mask=None):
|
46 |
+
batch_size, sequence_length, dim = x.shape
|
47 |
+
h = self.heads
|
48 |
+
|
49 |
+
is_cross = encoder_hidden_states is not None
|
50 |
+
encoder_hidden_states = encoder_hidden_states if is_cross else x
|
51 |
+
if not is_cross and self.injection_schedule is not None and (
|
52 |
+
self.t in self.injection_schedule or self.t == 1000):
|
53 |
+
q = self.to_q(x)
|
54 |
+
k = self.to_k(encoder_hidden_states)
|
55 |
+
|
56 |
+
source_batch_size = int(q.shape[0] // 3)
|
57 |
+
# inject unconditional
|
58 |
+
q[source_batch_size:2 * source_batch_size] = q[:source_batch_size]
|
59 |
+
k[source_batch_size:2 * source_batch_size] = k[:source_batch_size]
|
60 |
+
# inject conditional
|
61 |
+
q[2 * source_batch_size:] = q[:source_batch_size]
|
62 |
+
k[2 * source_batch_size:] = k[:source_batch_size]
|
63 |
+
|
64 |
+
q = self.head_to_batch_dim(q)
|
65 |
+
k = self.head_to_batch_dim(k)
|
66 |
+
else:
|
67 |
+
q = self.to_q(x)
|
68 |
+
k = self.to_k(encoder_hidden_states)
|
69 |
+
q = self.head_to_batch_dim(q)
|
70 |
+
k = self.head_to_batch_dim(k)
|
71 |
+
|
72 |
+
v = self.to_v(encoder_hidden_states)
|
73 |
+
v = self.head_to_batch_dim(v)
|
74 |
+
|
75 |
+
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
76 |
+
|
77 |
+
if attention_mask is not None:
|
78 |
+
attention_mask = attention_mask.reshape(batch_size, -1)
|
79 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
80 |
+
attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
|
81 |
+
sim.masked_fill_(~attention_mask, max_neg_value)
|
82 |
+
|
83 |
+
# attention, what we cannot get enough of
|
84 |
+
attn = sim.softmax(dim=-1)
|
85 |
+
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
86 |
+
out = self.batch_to_head_dim(out)
|
87 |
+
|
88 |
+
return to_out(out)
|
89 |
+
|
90 |
+
return forward
|
91 |
+
res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
|
92 |
+
for res in res_dict:
|
93 |
+
for block in res_dict[res]:
|
94 |
+
module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
|
95 |
+
module.forward = sa_forward(module)
|
96 |
+
setattr(module, 'injection_schedule', injection_schedule)
|
97 |
+
|
98 |
+
def register_attention_control_efficient_kv(model, injection_schedule):
|
99 |
+
def sa_forward(self):
|
100 |
+
to_out = self.to_out
|
101 |
+
if type(to_out) is torch.nn.modules.container.ModuleList:
|
102 |
+
to_out = self.to_out[0]
|
103 |
+
else:
|
104 |
+
to_out = self.to_out
|
105 |
+
|
106 |
+
def forward(x, encoder_hidden_states=None, attention_mask=None):
|
107 |
+
batch_size, sequence_length, dim = x.shape
|
108 |
+
h = self.heads
|
109 |
+
# if encoder_hidden_states is None:
|
110 |
+
# ipdb.set_trace()
|
111 |
+
|
112 |
+
is_cross = encoder_hidden_states is not None
|
113 |
+
encoder_hidden_states = encoder_hidden_states if is_cross else x
|
114 |
+
|
115 |
+
q = self.to_q(x)
|
116 |
+
q = self.head_to_batch_dim(q)
|
117 |
+
|
118 |
+
if not is_cross and self.injection_schedule is not None and (
|
119 |
+
self.t in self.injection_schedule or self.t == 1000):
|
120 |
+
# q = self.to_q(x)
|
121 |
+
k = self.to_k(encoder_hidden_states)
|
122 |
+
v = self.to_v(encoder_hidden_states)
|
123 |
+
|
124 |
+
source_batch_size = int(v.shape[0] // 3)
|
125 |
+
# inject unconditional
|
126 |
+
k[source_batch_size:2 * source_batch_size] = k[:source_batch_size]
|
127 |
+
v[source_batch_size:2 * source_batch_size] = v[:source_batch_size]
|
128 |
+
|
129 |
+
# inject conditional
|
130 |
+
k[2 * source_batch_size:] = k[:source_batch_size]
|
131 |
+
v[2 * source_batch_size:] = v[:source_batch_size]
|
132 |
+
|
133 |
+
# q = self.head_to_batch_dim(q)
|
134 |
+
k = self.head_to_batch_dim(k)
|
135 |
+
v = self.head_to_batch_dim(v)
|
136 |
+
else:
|
137 |
+
# q = self.to_q(x)
|
138 |
+
k = self.to_k(encoder_hidden_states)
|
139 |
+
# q = self.head_to_batch_dim(q)
|
140 |
+
k = self.head_to_batch_dim(k)
|
141 |
+
|
142 |
+
v = self.to_v(encoder_hidden_states)
|
143 |
+
v = self.head_to_batch_dim(v)
|
144 |
+
|
145 |
+
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
146 |
+
|
147 |
+
if attention_mask is not None:
|
148 |
+
attention_mask = attention_mask.reshape(batch_size, -1)
|
149 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
150 |
+
attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
|
151 |
+
sim.masked_fill_(~attention_mask, max_neg_value)
|
152 |
+
|
153 |
+
# attention, what we cannot get enough of
|
154 |
+
attn = sim.softmax(dim=-1)
|
155 |
+
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
156 |
+
out = self.batch_to_head_dim(out)
|
157 |
+
|
158 |
+
return to_out(out)
|
159 |
+
|
160 |
+
return forward
|
161 |
+
|
162 |
+
res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
|
163 |
+
for res in res_dict:
|
164 |
+
for block in res_dict[res]:
|
165 |
+
module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
|
166 |
+
module.forward = sa_forward(module)
|
167 |
+
setattr(module, 'injection_schedule', injection_schedule)
|
168 |
+
|
169 |
+
|
170 |
+
def register_conv_control_efficient(model, injection_schedule):
|
171 |
+
def conv_forward(self):
|
172 |
+
def forward(input_tensor, temb):
|
173 |
+
hidden_states = input_tensor
|
174 |
+
|
175 |
+
hidden_states = self.norm1(hidden_states)
|
176 |
+
hidden_states = self.nonlinearity(hidden_states)
|
177 |
+
|
178 |
+
if self.upsample is not None:
|
179 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
180 |
+
if hidden_states.shape[0] >= 64:
|
181 |
+
input_tensor = input_tensor.contiguous()
|
182 |
+
hidden_states = hidden_states.contiguous()
|
183 |
+
input_tensor = self.upsample(input_tensor)
|
184 |
+
hidden_states = self.upsample(hidden_states)
|
185 |
+
elif self.downsample is not None:
|
186 |
+
input_tensor = self.downsample(input_tensor)
|
187 |
+
hidden_states = self.downsample(hidden_states)
|
188 |
+
|
189 |
+
hidden_states = self.conv1(hidden_states)
|
190 |
+
|
191 |
+
if temb is not None:
|
192 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
193 |
+
|
194 |
+
if temb is not None and self.time_embedding_norm == "default":
|
195 |
+
hidden_states = hidden_states + temb
|
196 |
+
|
197 |
+
hidden_states = self.norm2(hidden_states)
|
198 |
+
|
199 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
200 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
201 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
202 |
+
|
203 |
+
hidden_states = self.nonlinearity(hidden_states)
|
204 |
+
|
205 |
+
hidden_states = self.dropout(hidden_states)
|
206 |
+
hidden_states = self.conv2(hidden_states)
|
207 |
+
if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
|
208 |
+
source_batch_size = int(hidden_states.shape[0] // 3)
|
209 |
+
# inject unconditional
|
210 |
+
hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size]
|
211 |
+
# inject conditional
|
212 |
+
hidden_states[2 * source_batch_size:] = hidden_states[:source_batch_size]
|
213 |
+
|
214 |
+
if self.conv_shortcut is not None:
|
215 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
216 |
+
|
217 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
218 |
+
|
219 |
+
return output_tensor
|
220 |
+
|
221 |
+
return forward
|
222 |
+
|
223 |
+
conv_module = model.unet.up_blocks[1].resnets[1]
|
224 |
+
conv_module.forward = conv_forward(conv_module)
|
225 |
+
setattr(conv_module, 'injection_schedule', injection_schedule)
|
226 |
+
|
227 |
+
|
228 |
+
def register_attention_control_efficient_kv_2nd_to_1st(model, injection_schedule, mask=None):
|
229 |
+
def sa_forward(self):
|
230 |
+
to_out = self.to_out
|
231 |
+
if type(to_out) is torch.nn.modules.container.ModuleList:
|
232 |
+
to_out = self.to_out[0]
|
233 |
+
else:
|
234 |
+
to_out = self.to_out
|
235 |
+
|
236 |
+
def forward(x, mask=mask, encoder_hidden_states=None, attention_mask=None):
|
237 |
+
batch_size, sequence_length, dim = x.shape
|
238 |
+
h = self.heads
|
239 |
+
# if encoder_hidden_states is None:
|
240 |
+
# ipdb.set_trace()
|
241 |
+
is_cross = encoder_hidden_states is not None
|
242 |
+
encoder_hidden_states = encoder_hidden_states if is_cross else x
|
243 |
+
|
244 |
+
q = self.to_q(x)
|
245 |
+
q = self.head_to_batch_dim(q)
|
246 |
+
|
247 |
+
if not is_cross and self.injection_schedule is not None and (
|
248 |
+
self.t in self.injection_schedule or self.t == 1000):
|
249 |
+
# q = self.to_q(x)
|
250 |
+
target_size = int(np.sqrt(encoder_hidden_states.shape[1]))
|
251 |
+
target_mask = F.interpolate(mask.unsqueeze(1),size=(target_size, target_size))[:,0,:,:]
|
252 |
+
target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1)
|
253 |
+
k = self.to_k(encoder_hidden_states) # k: bx256x1280
|
254 |
+
v = self.to_v(encoder_hidden_states)
|
255 |
+
|
256 |
+
source_batch_size = int(v.shape[0] // 2)
|
257 |
+
# inject
|
258 |
+
k[:source_batch_size] = k[source_batch_size:2 * source_batch_size] * (1-target_mask) + k[:source_batch_size] * target_mask
|
259 |
+
v[:source_batch_size] = v[source_batch_size:2 * source_batch_size] * (1-target_mask) + v[:source_batch_size] * target_mask
|
260 |
+
|
261 |
+
# q = self.head_to_batch_dim(q)
|
262 |
+
k = self.head_to_batch_dim(k)
|
263 |
+
v = self.head_to_batch_dim(v)
|
264 |
+
else:
|
265 |
+
# q = self.to_q(x)
|
266 |
+
k = self.to_k(encoder_hidden_states)
|
267 |
+
# q = self.head_to_batch_dim(q)
|
268 |
+
k = self.head_to_batch_dim(k)
|
269 |
+
|
270 |
+
v = self.to_v(encoder_hidden_states)
|
271 |
+
v = self.head_to_batch_dim(v)
|
272 |
+
|
273 |
+
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
274 |
+
|
275 |
+
if attention_mask is not None:
|
276 |
+
attention_mask = attention_mask.reshape(batch_size, -1)
|
277 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
278 |
+
attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
|
279 |
+
sim.masked_fill_(~attention_mask, max_neg_value)
|
280 |
+
|
281 |
+
# attention, what we cannot get enough of
|
282 |
+
attn = sim.softmax(dim=-1)
|
283 |
+
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
284 |
+
out = self.batch_to_head_dim(out)
|
285 |
+
|
286 |
+
return to_out(out)
|
287 |
+
|
288 |
+
return forward
|
289 |
+
|
290 |
+
# res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
|
291 |
+
res_dict = {1: [1, 2], 2: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
|
292 |
+
|
293 |
+
for res in res_dict:
|
294 |
+
for block in res_dict[res]:
|
295 |
+
module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
|
296 |
+
module.forward = sa_forward(module)
|
297 |
+
setattr(module, 'injection_schedule', injection_schedule)
|
298 |
+
|
299 |
+
def register_conv_control_efficient_2nd_to_1st(model, injection_schedule, mask=None):
|
300 |
+
def conv_forward(self):
|
301 |
+
def forward(input_tensor, temb):
|
302 |
+
hidden_states = input_tensor
|
303 |
+
|
304 |
+
hidden_states = self.norm1(hidden_states)
|
305 |
+
hidden_states = self.nonlinearity(hidden_states)
|
306 |
+
|
307 |
+
if self.upsample is not None:
|
308 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
309 |
+
if hidden_states.shape[0] >= 64:
|
310 |
+
input_tensor = input_tensor.contiguous()
|
311 |
+
hidden_states = hidden_states.contiguous()
|
312 |
+
input_tensor = self.upsample(input_tensor)
|
313 |
+
hidden_states = self.upsample(hidden_states)
|
314 |
+
elif self.downsample is not None:
|
315 |
+
input_tensor = self.downsample(input_tensor)
|
316 |
+
hidden_states = self.downsample(hidden_states)
|
317 |
+
|
318 |
+
hidden_states = self.conv1(hidden_states)
|
319 |
+
|
320 |
+
if temb is not None:
|
321 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
322 |
+
|
323 |
+
if temb is not None and self.time_embedding_norm == "default":
|
324 |
+
hidden_states = hidden_states + temb
|
325 |
+
|
326 |
+
hidden_states = self.norm2(hidden_states)
|
327 |
+
|
328 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
329 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
330 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
331 |
+
|
332 |
+
hidden_states = self.nonlinearity(hidden_states)
|
333 |
+
|
334 |
+
hidden_states = self.dropout(hidden_states)
|
335 |
+
hidden_states = self.conv2(hidden_states)
|
336 |
+
if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
|
337 |
+
source_batch_size = int(hidden_states.shape[0] // 2)
|
338 |
+
# inject unconditional
|
339 |
+
# hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size]
|
340 |
+
# inject conditional
|
341 |
+
target_size = int(np.sqrt(hidden_states.shape[-1]))
|
342 |
+
target_mask = F.interpolate(mask.unsqueeze(1),size=(target_size, target_size))[:,0,:,:]
|
343 |
+
target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1)
|
344 |
+
|
345 |
+
hidden_states[:source_batch_size] = hidden_states[source_batch_size:] * (1-target_mask) + hidden_states[:source_batch_size] * target_mask
|
346 |
+
|
347 |
+
if self.conv_shortcut is not None:
|
348 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
349 |
+
|
350 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
351 |
+
|
352 |
+
return output_tensor
|
353 |
+
|
354 |
+
return forward
|
355 |
+
|
356 |
+
conv_module = model.unet.up_blocks[1].resnets[1]
|
357 |
+
conv_module.forward = conv_forward(conv_module)
|
358 |
+
setattr(conv_module, 'injection_schedule', injection_schedule)
|
359 |
+
|
360 |
+
|
361 |
+
def register_attention_control_efficient_qk_w_mask(model, injection_schedule, mask):
|
362 |
+
def sa_forward(self):
|
363 |
+
to_out = self.to_out
|
364 |
+
if type(to_out) is torch.nn.modules.container.ModuleList:
|
365 |
+
to_out = self.to_out[0]
|
366 |
+
else:
|
367 |
+
to_out = self.to_out
|
368 |
+
|
369 |
+
def forward(x, encoder_hidden_states=None, attention_mask=None):
|
370 |
+
batch_size, sequence_length, dim = x.shape
|
371 |
+
h = self.heads
|
372 |
+
|
373 |
+
is_cross = encoder_hidden_states is not None
|
374 |
+
encoder_hidden_states = encoder_hidden_states if is_cross else x
|
375 |
+
if not is_cross and self.injection_schedule is not None and (
|
376 |
+
self.t in self.injection_schedule or self.t == 1000):
|
377 |
+
q = self.to_q(x)
|
378 |
+
k = self.to_k(encoder_hidden_states)
|
379 |
+
|
380 |
+
target_size = int(np.sqrt(encoder_hidden_states.shape[1]))
|
381 |
+
target_mask = F.interpolate(mask.unsqueeze(1),size=(target_size, target_size))[:,0,:,:]
|
382 |
+
target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1)
|
383 |
+
|
384 |
+
source_batch_size = int(q.shape[0] // 3)
|
385 |
+
# inject unconditional
|
386 |
+
q[source_batch_size:2 * source_batch_size] = q[:source_batch_size] * target_mask + q[source_batch_size:2 * source_batch_size] * (1 - target_mask)
|
387 |
+
k[source_batch_size:2 * source_batch_size] = k[:source_batch_size] * target_mask + k[source_batch_size:2 * source_batch_size] * (1 - target_mask)
|
388 |
+
# inject conditional
|
389 |
+
q[2 * source_batch_size:] = q[:source_batch_size] * target_mask + q[2 * source_batch_size:] * (1 - target_mask)
|
390 |
+
k[2 * source_batch_size:] = k[:source_batch_size] * target_mask + k[2 * source_batch_size:] * (1 - target_mask)
|
391 |
+
|
392 |
+
q = self.head_to_batch_dim(q)
|
393 |
+
k = self.head_to_batch_dim(k)
|
394 |
+
else:
|
395 |
+
q = self.to_q(x)
|
396 |
+
k = self.to_k(encoder_hidden_states)
|
397 |
+
q = self.head_to_batch_dim(q)
|
398 |
+
k = self.head_to_batch_dim(k)
|
399 |
+
|
400 |
+
v = self.to_v(encoder_hidden_states)
|
401 |
+
v = self.head_to_batch_dim(v)
|
402 |
+
|
403 |
+
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
404 |
+
|
405 |
+
if attention_mask is not None:
|
406 |
+
attention_mask = attention_mask.reshape(batch_size, -1)
|
407 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
408 |
+
attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
|
409 |
+
sim.masked_fill_(~attention_mask, max_neg_value)
|
410 |
+
|
411 |
+
# attention, what we cannot get enough of
|
412 |
+
attn = sim.softmax(dim=-1)
|
413 |
+
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
414 |
+
out = self.batch_to_head_dim(out)
|
415 |
+
|
416 |
+
return to_out(out)
|
417 |
+
|
418 |
+
return forward
|
419 |
+
res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
|
420 |
+
|
421 |
+
for res in res_dict:
|
422 |
+
for block in res_dict[res]:
|
423 |
+
module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
|
424 |
+
module.forward = sa_forward(module)
|
425 |
+
setattr(module, 'injection_schedule', injection_schedule)
|
426 |
+
|
427 |
+
def register_attention_control_efficient_kv_w_mask(model, injection_schedule, mask, do_classifier_free_guidance):
|
428 |
+
def sa_forward(self):
|
429 |
+
to_out = self.to_out
|
430 |
+
if type(to_out) is torch.nn.modules.container.ModuleList:
|
431 |
+
to_out = self.to_out[0]
|
432 |
+
else:
|
433 |
+
to_out = self.to_out
|
434 |
+
|
435 |
+
def forward(x, encoder_hidden_states=None, attention_mask=None):
|
436 |
+
batch_size, sequence_length, dim = x.shape
|
437 |
+
h = self.heads
|
438 |
+
|
439 |
+
is_cross = encoder_hidden_states is not None
|
440 |
+
encoder_hidden_states = encoder_hidden_states if is_cross else x
|
441 |
+
|
442 |
+
q = self.to_q(x)
|
443 |
+
q = self.head_to_batch_dim(q)
|
444 |
+
|
445 |
+
if not is_cross and self.injection_schedule is not None and (
|
446 |
+
self.t in self.injection_schedule or self.t == 1000):
|
447 |
+
# if False:
|
448 |
+
k = self.to_k(encoder_hidden_states) # k: bx256x1280
|
449 |
+
v = self.to_v(encoder_hidden_states)
|
450 |
+
|
451 |
+
target_size = int(np.sqrt(encoder_hidden_states.shape[1]))
|
452 |
+
target_mask = F.interpolate(mask.unsqueeze(1),size=(target_size, target_size))[:,0,:,:]
|
453 |
+
target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1)
|
454 |
+
|
455 |
+
source_batch_size = int(v.shape[0] // 3)
|
456 |
+
if do_classifier_free_guidance:
|
457 |
+
# inject unconditional
|
458 |
+
v[source_batch_size:2 * source_batch_size] = v[:source_batch_size] * target_mask + v[source_batch_size:2 * source_batch_size] * (1 - target_mask)
|
459 |
+
k[source_batch_size:2 * source_batch_size] = k[:source_batch_size] * target_mask + k[source_batch_size:2 * source_batch_size] * (1 - target_mask)
|
460 |
+
# inject conditional
|
461 |
+
v[2 * source_batch_size:] = v[:source_batch_size] * target_mask + v[2 * source_batch_size:] * (1 - target_mask)
|
462 |
+
k[2 * source_batch_size:] = k[:source_batch_size] * target_mask + k[2 * source_batch_size:] * (1 - target_mask)
|
463 |
+
else:
|
464 |
+
v[source_batch_size:2 * source_batch_size] = v[:source_batch_size] * target_mask + v[source_batch_size:2 * source_batch_size] * (1 - target_mask)
|
465 |
+
k[source_batch_size:2 * source_batch_size] = k[:source_batch_size] * target_mask + k[source_batch_size:2 * source_batch_size] * (1 - target_mask)
|
466 |
+
|
467 |
+
k = self.head_to_batch_dim(k)
|
468 |
+
v = self.head_to_batch_dim(v)
|
469 |
+
else:
|
470 |
+
# q = self.to_q(x)
|
471 |
+
k = self.to_k(encoder_hidden_states)
|
472 |
+
# q = self.head_to_batch_dim(q)
|
473 |
+
k = self.head_to_batch_dim(k)
|
474 |
+
|
475 |
+
v = self.to_v(encoder_hidden_states)
|
476 |
+
v = self.head_to_batch_dim(v)
|
477 |
+
|
478 |
+
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
479 |
+
|
480 |
+
if attention_mask is not None:
|
481 |
+
attention_mask = attention_mask.reshape(batch_size, -1)
|
482 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
483 |
+
attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
|
484 |
+
sim.masked_fill_(~attention_mask, max_neg_value)
|
485 |
+
|
486 |
+
# attention, what we cannot get enough of
|
487 |
+
attn = sim.softmax(dim=-1)
|
488 |
+
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
489 |
+
out = self.batch_to_head_dim(out)
|
490 |
+
|
491 |
+
return to_out(out)
|
492 |
+
|
493 |
+
return forward
|
494 |
+
res_dict = {1: [0, 1, 2], 2: [0, 1, 2], 3: [0, 1, 2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
|
495 |
+
# res_dict = {1: [2], 2: [2], 3: [2]} # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
|
496 |
+
|
497 |
+
for res in res_dict:
|
498 |
+
for block in res_dict[res]:
|
499 |
+
module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
|
500 |
+
module.forward = sa_forward(module)
|
501 |
+
setattr(module, 'injection_schedule', injection_schedule)
|
502 |
+
# down_res_dict = {0: [0, 1], 1: [0, 1], 2: [0, 1]}
|
503 |
+
# for res in down_res_dict:
|
504 |
+
# for block in down_res_dict[res]:
|
505 |
+
# module = model.unet.down_blocks[res].attentions[block].transformer_blocks[0].attn1
|
506 |
+
# module.forward = sa_forward(module)
|
507 |
+
# setattr(module, 'injection_schedule', injection_schedule)
|
508 |
+
|
509 |
+
def register_conv_control_efficient_w_mask(model, injection_schedule, mask):
|
510 |
+
def conv_forward(self):
|
511 |
+
def forward(input_tensor, temb):
|
512 |
+
hidden_states = input_tensor
|
513 |
+
|
514 |
+
hidden_states = self.norm1(hidden_states)
|
515 |
+
hidden_states = self.nonlinearity(hidden_states)
|
516 |
+
|
517 |
+
if self.upsample is not None:
|
518 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
519 |
+
if hidden_states.shape[0] >= 64:
|
520 |
+
input_tensor = input_tensor.contiguous()
|
521 |
+
hidden_states = hidden_states.contiguous()
|
522 |
+
input_tensor = self.upsample(input_tensor)
|
523 |
+
hidden_states = self.upsample(hidden_states)
|
524 |
+
elif self.downsample is not None:
|
525 |
+
input_tensor = self.downsample(input_tensor)
|
526 |
+
hidden_states = self.downsample(hidden_states)
|
527 |
+
|
528 |
+
hidden_states = self.conv1(hidden_states)
|
529 |
+
|
530 |
+
if temb is not None:
|
531 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
532 |
+
|
533 |
+
if temb is not None and self.time_embedding_norm == "default":
|
534 |
+
hidden_states = hidden_states + temb
|
535 |
+
|
536 |
+
hidden_states = self.norm2(hidden_states)
|
537 |
+
|
538 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
539 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
540 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
541 |
+
|
542 |
+
hidden_states = self.nonlinearity(hidden_states)
|
543 |
+
|
544 |
+
hidden_states = self.dropout(hidden_states)
|
545 |
+
hidden_states = self.conv2(hidden_states)
|
546 |
+
if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
|
547 |
+
# if False:
|
548 |
+
source_batch_size = int(hidden_states.shape[0] // 3)
|
549 |
+
target_size = int(np.sqrt(hidden_states.shape[-1]))
|
550 |
+
target_mask = F.interpolate(mask.unsqueeze(1),size=(target_size, target_size))[:,0,:,:]
|
551 |
+
target_mask = target_mask.view(target_mask.shape[0], -1).unsqueeze(-1)
|
552 |
+
|
553 |
+
# inject unconditional
|
554 |
+
hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size] * target_mask + hidden_states[source_batch_size:2 * source_batch_size] * (1-target_mask)
|
555 |
+
# inject conditional
|
556 |
+
hidden_states[2 * source_batch_size:] = hidden_states[:source_batch_size] * target_mask + hidden_states[2 * source_batch_size:] * (1-target_mask)
|
557 |
+
|
558 |
+
if self.conv_shortcut is not None:
|
559 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
560 |
+
|
561 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
562 |
+
|
563 |
+
return output_tensor
|
564 |
+
|
565 |
+
return forward
|
566 |
+
|
567 |
+
conv_module = model.unet.up_blocks[1].resnets[1]
|
568 |
+
conv_module.forward = conv_forward(conv_module)
|
569 |
+
setattr(conv_module, 'injection_schedule', injection_schedule)
|
weights/dpt_beit_large_512.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e9e900747e9e8b3112df716979219836a27716277b3d0dc53889cbba8b82328
|
3 |
+
size 1581966003
|