Vijish committed
Commit 21e2504 · verified · 1 Parent(s): a1e7a71

Upload stable_diffusion_xl_reference.py

Files changed (1)
  stable_diffusion_xl_reference.py  +818  -0
stable_diffusion_xl_reference.py ADDED
@@ -0,0 +1,818 @@
# Based on stable_diffusion_reference.py

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch

from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unets.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UpBlock2D,
)
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.utils import PIL_INTERPOLATION, logging
from diffusers.utils.torch_utils import randn_tensor


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import UniPCMultistepScheduler
        >>> from diffusers.utils import load_image

        >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")

        >>> pipe = StableDiffusionXLReferencePipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0",
                torch_dtype=torch.float16,
                use_safetensors=True,
                variant="fp16").to('cuda:0')

        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        >>> result_img = pipe(ref_image=input_image,
                prompt="1girl",
                num_inference_steps=20,
                reference_attn=True,
                reference_adain=True).images[0]

        >>> result_img.show()
        ```
"""


def torch_dfs(model: torch.nn.Module):
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
    def _default_height_width(self, height, width, image):
        # NOTE: It is possible that a list of images has different
        # dimensions for each image, so just checking the first image
        # is not _exactly_ correct, but it is simple.
        while isinstance(image, list):
            image = image[0]

        if height is None:
            if isinstance(image, PIL.Image.Image):
                height = image.height
            elif isinstance(image, torch.Tensor):
                height = image.shape[2]

            height = (height // 8) * 8  # round down to nearest multiple of 8

        if width is None:
            if isinstance(image, PIL.Image.Image):
                width = image.width
            elif isinstance(image, torch.Tensor):
                width = image.shape[3]

            width = (width // 8) * 8

        return height, width

    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]

            if isinstance(image[0], PIL.Image.Image):
                images = []

                for image_ in image:
                    image_ = image_.convert("RGB")
                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                    image_ = np.array(image_)
                    image_ = image_[None, :]
                    images.append(image_)

                image = images

                image = np.concatenate(image, axis=0)
                image = np.array(image).astype(np.float32) / 255.0
                image = (image - 0.5) / 0.5
                image = image.transpose(0, 3, 1, 2)
                image = torch.from_numpy(image)

            elif isinstance(image[0], torch.Tensor):
                image = torch.stack(image, dim=0)

        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image

    def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
        refimage = refimage.to(device=device)
        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
            self.upcast_vae()
            refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
        if refimage.dtype != self.vae.dtype:
            refimage = refimage.to(dtype=self.vae.dtype)
        # encode the reference image into latent space so we can concatenate it to the latents
        if isinstance(generator, list):
            ref_image_latents = [
                self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
                for i in range(batch_size)
            ]
            ref_image_latents = torch.cat(ref_image_latents, dim=0)
        else:
            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents

        # duplicate ref_image_latents for each generation per prompt, using mps friendly method
        if ref_image_latents.shape[0] < batch_size:
            if not batch_size % ref_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)

        ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents

        # aligning device to prevent device errors when concatenating it with the latent model input
        ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
        return ref_image_latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        ref_image: Union[torch.Tensor, PIL.Image.Image] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        attention_auto_machine_weight: float = 1.0,
        gn_auto_machine_weight: float = 1.0,
        style_fidelity: float = 0.5,
        reference_attn: bool = True,
        reference_adain: bool = True,
    ):
        assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."

        # 0. Default height and width to unet
        # height, width = self._default_height_width(height, width, ref_image)

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Preprocess reference image
        ref_image = self.prepare_image(
            image=ref_image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=prompt_embeds.dtype,
        )

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare reference latent variables
        ref_image_latents = self.prepare_ref_latents(
            ref_image,
            batch_size * num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 9. Modify self-attention and group norm
        MODE = "write"
        uc_mask = (
            torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
            .type_as(ref_image_latents)
            .bool()
        )

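        # The closures below implement the reference-only trick by toggling the shared MODE flag:
        # in "write" mode the UNet runs on the noised reference latents and each patched block
        # caches its self-attention hidden states (and, for AdaIN, per-channel mean/variance);
        # in "read" mode the actual denoising pass attends to and re-normalizes with those
        # cached values, blended by style_fidelity for the unconditional half of the batch.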
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            encoder_attention_mask: Optional[torch.Tensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            class_labels: Optional[torch.LongTensor] = None,
        ):
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # 1. Self-Attention
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if MODE == "write":
                    self.bank.append(norm_hidden_states.detach().clone())
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                if MODE == "read":
                    if attention_auto_machine_weight > self.attn_weight:
                        attn_output_uc = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
                            # attention_mask=attention_mask,
                            **cross_attention_kwargs,
                        )
                        attn_output_c = attn_output_uc.clone()
                        if do_classifier_free_guidance and style_fidelity > 0:
                            attn_output_c[uc_mask] = self.attn1(
                                norm_hidden_states[uc_mask],
                                encoder_hidden_states=norm_hidden_states[uc_mask],
                                **cross_attention_kwargs,
                            )
                        attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
                        self.bank.clear()
                    else:
                        attn_output = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                            attention_mask=attention_mask,
                            **cross_attention_kwargs,
                        )
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )

                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states

            return hidden_states

        def hacked_mid_forward(self, *args, **kwargs):
            eps = 1e-6
            x = self.original_forward(*args, **kwargs)
            if MODE == "write":
                if gn_auto_machine_weight >= self.gn_weight:
                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                    self.mean_bank.append(mean)
                    self.var_bank.append(var)
            if MODE == "read":
                if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
                    var_acc = sum(self.var_bank) / float(len(self.var_bank))
                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                    x_uc = (((x - mean) / std) * std_acc) + mean_acc
                    x_c = x_uc.clone()
                    if do_classifier_free_guidance and style_fidelity > 0:
                        x_c[uc_mask] = x[uc_mask]
                    x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
                self.mean_bank = []
                self.var_bank = []
            return x

        def hack_CrossAttnDownBlock2D_forward(
            self,
            hidden_states: torch.Tensor,
            temb: Optional[torch.Tensor] = None,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.Tensor] = None,
        ):
            eps = 1e-6

            # TODO(Patrick, William) - attention mask is not used
            output_states = ()

            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
                if MODE == "write":
                    if gn_auto_machine_weight >= self.gn_weight:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if MODE == "read":
                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if do_classifier_free_guidance and style_fidelity > 0:
                            hidden_states_c[uc_mask] = hidden_states[uc_mask]
                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc

                output_states = output_states + (hidden_states,)

            if MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)

                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs):
            eps = 1e-6

            output_states = ()

            for i, resnet in enumerate(self.resnets):
                hidden_states = resnet(hidden_states, temb)

                if MODE == "write":
                    if gn_auto_machine_weight >= self.gn_weight:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if MODE == "read":
                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if do_classifier_free_guidance and style_fidelity > 0:
                            hidden_states_c[uc_mask] = hidden_states[uc_mask]
                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc

                output_states = output_states + (hidden_states,)

            if MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)

                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_CrossAttnUpBlock2D_forward(
            self,
            hidden_states: torch.Tensor,
            res_hidden_states_tuple: Tuple[torch.Tensor, ...],
            temb: Optional[torch.Tensor] = None,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.Tensor] = None,
            encoder_attention_mask: Optional[torch.Tensor] = None,
        ):
            eps = 1e-6
            # TODO(Patrick, William) - attention mask is not used
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

                if MODE == "write":
                    if gn_auto_machine_weight >= self.gn_weight:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if MODE == "read":
                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if do_classifier_free_guidance and style_fidelity > 0:
                            hidden_states_c[uc_mask] = hidden_states[uc_mask]
                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc

            if MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        def hacked_UpBlock2D_forward(
            self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
        ):
            eps = 1e-6
            for i, resnet in enumerate(self.resnets):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)

                if MODE == "write":
                    if gn_auto_machine_weight >= self.gn_weight:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        self.mean_bank.append([mean])
                        self.var_bank.append([var])
                if MODE == "read":
                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
                        hidden_states_c = hidden_states_uc.clone()
                        if do_classifier_free_guidance and style_fidelity > 0:
                            hidden_states_c[uc_mask] = hidden_states[uc_mask]
                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc

            if MODE == "read":
                self.mean_bank = []
                self.var_bank = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

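        # Patch the UNet in place: every BasicTransformerBlock gets the hacked self-attention
        # forward (reading is gated by attention_auto_machine_weight against its attn_weight),
        # and the mid/down/up blocks get the hacked forwards above (writing is gated by
        # gn_auto_machine_weight against their gn_weight), so the banks are filled during the
        # "write" pass and consumed during the "read" pass of the denoising loop below.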
        if reference_attn:
            attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

            for i, module in enumerate(attn_modules):
                module._original_inner_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                module.bank = []
                module.attn_weight = float(i) / float(len(attn_modules))

        if reference_adain:
            gn_modules = [self.unet.mid_block]
            self.unet.mid_block.gn_weight = 0

            down_blocks = self.unet.down_blocks
            for w, module in enumerate(down_blocks):
                module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
                gn_modules.append(module)

            up_blocks = self.unet.up_blocks
            for w, module in enumerate(up_blocks):
                module.gn_weight = float(w) / float(len(up_blocks))
                gn_modules.append(module)

            for i, module in enumerate(gn_modules):
                if getattr(module, "original_forward", None) is None:
                    module.original_forward = module.forward
                if i == 0:
                    # mid_block
                    module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
                elif isinstance(module, CrossAttnDownBlock2D):
                    module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
                elif isinstance(module, DownBlock2D):
                    module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
                elif isinstance(module, CrossAttnUpBlock2D):
                    module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
                elif isinstance(module, UpBlock2D):
                    module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
                module.mean_bank = []
                module.var_bank = []
                module.gn_weight *= 2

        # 10. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 11. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 11.1 Apply denoising_end
        if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                # ref only part
                noise = randn_tensor(
                    ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
                )
                ref_xt = self.scheduler.add_noise(
                    ref_image_latents,
                    noise,
                    t.reshape(
                        1,
                    ),
                )
                ref_xt = self.scheduler.scale_model_input(ref_xt, t)

                MODE = "write"

                self.unet(
                    ref_xt,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )

                # predict the noise residual
                MODE = "read"
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            image = latents
            return StableDiffusionXLPipelineOutput(images=image)

        # apply watermark if available
        if self.watermark is not None:
            image = self.watermark.apply_watermark(image)

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
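
Because this file ships as a standalone community pipeline rather than a class packaged inside diffusers, one way to run it is through the `custom_pipeline` argument of `DiffusionPipeline.from_pretrained`. The sketch below is illustrative, not part of the uploaded file; it assumes a local copy saved as `./stable_diffusion_xl_reference.py`, the `stabilityai/stable-diffusion-xl-base-1.0` base weights, a diffusers version that accepts a `.py` path for `custom_pipeline`, and an available CUDA device.

```py
# Illustrative sketch: load the pipeline class defined above via diffusers' custom_pipeline
# mechanism. The file path, model id, and output filename are assumptions for this example.
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="./stable_diffusion_xl_reference.py",  # path to the file in this commit
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")

ref = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
image = pipe(
    prompt="1girl",
    ref_image=ref,
    num_inference_steps=20,
    reference_attn=True,    # reference-only self-attention
    reference_adain=True,   # AdaIN on cached group-norm statistics
).images[0]
image.save("reference_only_result.png")
```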