Tonic commited on
Commit
7de6816
Β·
1 Parent(s): 9942bf6

Create pipeline_calls.py

Browse files
Files changed (1) hide show
  1. pipeline_calls.py +552 -0
pipeline_calls.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+ from typing import Any
18
+ import torch
19
+ import numpy as np
20
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
21
+ from diffusers.image_processor import PipelineImageInput
22
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
23
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
24
+ from diffusers import StableDiffusionPanoramaPipeline
25
+ from PIL import Image
26
+ import copy
27
+
28
+ T = torch.Tensor
29
+ TN = T | None
30
+
31
+
32
+ def get_depth_map(image: Image, feature_processor: DPTImageProcessor, depth_estimator: DPTForDepthEstimation) -> Image:
33
+ image = feature_processor(images=image, return_tensors="pt").pixel_values.to("cuda")
34
+ with torch.no_grad(), torch.autocast("cuda"):
35
+ depth_map = depth_estimator(image).predicted_depth
36
+
37
+ depth_map = torch.nn.functional.interpolate(
38
+ depth_map.unsqueeze(1),
39
+ size=(1024, 1024),
40
+ mode="bicubic",
41
+ align_corners=False,
42
+ )
43
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
44
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
45
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
46
+ image = torch.cat([depth_map] * 3, dim=1)
47
+
48
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
49
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
50
+ return image
51
+
52
+
53
+ def concat_zero_control(control_reisduel: T) -> T:
54
+ b = control_reisduel.shape[0] // 2
55
+ zerso_reisduel = torch.zeros_like(control_reisduel[0:1])
56
+ return torch.cat((zerso_reisduel, control_reisduel[:b], zerso_reisduel, control_reisduel[b::]))
57
+
58
+
59
+ @torch.no_grad()
60
+ def controlnet_call(
61
+ pipeline: StableDiffusionXLControlNetPipeline,
62
+ prompt: str | list[str] = None,
63
+ prompt_2: str | list[str] | None = None,
64
+ image: PipelineImageInput = None,
65
+ height: int | None = None,
66
+ width: int | None = None,
67
+ num_inference_steps: int = 50,
68
+ guidance_scale: float = 5.0,
69
+ negative_prompt: str | list[str] | None = None,
70
+ negative_prompt_2: str | list[str] | None = None,
71
+ num_images_per_prompt: int = 1,
72
+ eta: float = 0.0,
73
+ generator: torch.Generator | None = None,
74
+ latents: TN = None,
75
+ prompt_embeds: TN = None,
76
+ negative_prompt_embeds: TN = None,
77
+ pooled_prompt_embeds: TN = None,
78
+ negative_pooled_prompt_embeds: TN = None,
79
+ cross_attention_kwargs: dict[str, Any] | None = None,
80
+ controlnet_conditioning_scale: float | list[float] = 1.0,
81
+ control_guidance_start: float | list[float] = 0.0,
82
+ control_guidance_end: float | list[float] = 1.0,
83
+ original_size: tuple[int, int] = None,
84
+ crops_coords_top_left: tuple[int, int] = (0, 0),
85
+ target_size: tuple[int, int] | None = None,
86
+ negative_original_size: tuple[int, int] | None = None,
87
+ negative_crops_coords_top_left: tuple[int, int] = (0, 0),
88
+ negative_target_size:tuple[int, int] | None = None,
89
+ clip_skip: int | None = None,
90
+ ) -> list[Image]:
91
+ controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet
92
+
93
+ # align format for control guidance
94
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
95
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
96
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
97
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
98
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
99
+ mult = 1
100
+ control_guidance_start, control_guidance_end = (
101
+ mult * [control_guidance_start],
102
+ mult * [control_guidance_end],
103
+ )
104
+
105
+ # 1. Check inputs. Raise error if not correct
106
+ pipeline.check_inputs(
107
+ prompt,
108
+ prompt_2,
109
+ image,
110
+ 1,
111
+ negative_prompt,
112
+ negative_prompt_2,
113
+ prompt_embeds,
114
+ negative_prompt_embeds,
115
+ pooled_prompt_embeds,
116
+ negative_pooled_prompt_embeds,
117
+ controlnet_conditioning_scale,
118
+ control_guidance_start,
119
+ control_guidance_end,
120
+ )
121
+
122
+ pipeline._guidance_scale = guidance_scale
123
+
124
+ # 2. Define call parameters
125
+ if prompt is not None and isinstance(prompt, str):
126
+ batch_size = 1
127
+ elif prompt is not None and isinstance(prompt, list):
128
+ batch_size = len(prompt)
129
+ else:
130
+ batch_size = prompt_embeds.shape[0]
131
+
132
+ device = pipeline._execution_device
133
+
134
+ # 3. Encode input prompt
135
+ text_encoder_lora_scale = (
136
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
137
+ )
138
+ (
139
+ prompt_embeds,
140
+ negative_prompt_embeds,
141
+ pooled_prompt_embeds,
142
+ negative_pooled_prompt_embeds,
143
+ ) = pipeline.encode_prompt(
144
+ prompt,
145
+ prompt_2,
146
+ device,
147
+ 1,
148
+ True,
149
+ negative_prompt,
150
+ negative_prompt_2,
151
+ prompt_embeds=prompt_embeds,
152
+ negative_prompt_embeds=negative_prompt_embeds,
153
+ pooled_prompt_embeds=pooled_prompt_embeds,
154
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
155
+ lora_scale=text_encoder_lora_scale,
156
+ clip_skip=clip_skip,
157
+ )
158
+
159
+ # 4. Prepare image
160
+ if isinstance(controlnet, ControlNetModel):
161
+ image = pipeline.prepare_image(
162
+ image=image,
163
+ width=width,
164
+ height=height,
165
+ batch_size=1,
166
+ num_images_per_prompt=1,
167
+ device=device,
168
+ dtype=controlnet.dtype,
169
+ do_classifier_free_guidance=True,
170
+ guess_mode=False,
171
+ )
172
+ height, width = image.shape[-2:]
173
+ image = torch.stack([image[0]] * num_images_per_prompt + [image[1]] * num_images_per_prompt)
174
+ else:
175
+ assert False
176
+ # 5. Prepare timesteps
177
+ pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
178
+ timesteps = pipeline.scheduler.timesteps
179
+
180
+ # 6. Prepare latent variables
181
+ num_channels_latents = pipeline.unet.config.in_channels
182
+ latents = pipeline.prepare_latents(
183
+ 1 + num_images_per_prompt,
184
+ num_channels_latents,
185
+ height,
186
+ width,
187
+ prompt_embeds.dtype,
188
+ device,
189
+ generator,
190
+ latents,
191
+ )
192
+
193
+ # 6.5 Optionally get Guidance Scale Embedding
194
+ timestep_cond = None
195
+
196
+ # 7. Prepare extra step kwargs.
197
+ extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
198
+
199
+ # 7.1 Create tensor stating which controlnets to keep
200
+ controlnet_keep = []
201
+ for i in range(len(timesteps)):
202
+ keeps = [
203
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
204
+ for s, e in zip(control_guidance_start, control_guidance_end)
205
+ ]
206
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
207
+
208
+ # 7.2 Prepare added time ids & embeddings
209
+ if isinstance(image, list):
210
+ original_size = original_size or image[0].shape[-2:]
211
+ else:
212
+ original_size = original_size or image.shape[-2:]
213
+ target_size = target_size or (height, width)
214
+
215
+ add_text_embeds = pooled_prompt_embeds
216
+ if pipeline.text_encoder_2 is None:
217
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
218
+ else:
219
+ text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim
220
+
221
+ add_time_ids = pipeline._get_add_time_ids(
222
+ original_size,
223
+ crops_coords_top_left,
224
+ target_size,
225
+ dtype=prompt_embeds.dtype,
226
+ text_encoder_projection_dim=text_encoder_projection_dim,
227
+ )
228
+
229
+ if negative_original_size is not None and negative_target_size is not None:
230
+ negative_add_time_ids = pipeline._get_add_time_ids(
231
+ negative_original_size,
232
+ negative_crops_coords_top_left,
233
+ negative_target_size,
234
+ dtype=prompt_embeds.dtype,
235
+ text_encoder_projection_dim=text_encoder_projection_dim,
236
+ )
237
+ else:
238
+ negative_add_time_ids = add_time_ids
239
+
240
+ prompt_embeds = torch.stack([prompt_embeds[0]] + [prompt_embeds[1]] * num_images_per_prompt)
241
+ negative_prompt_embeds = torch.stack([negative_prompt_embeds[0]] + [negative_prompt_embeds[1]] * num_images_per_prompt)
242
+ negative_pooled_prompt_embeds = torch.stack([negative_pooled_prompt_embeds[0]] + [negative_pooled_prompt_embeds[1]] * num_images_per_prompt)
243
+ add_text_embeds = torch.stack([add_text_embeds[0]] + [add_text_embeds[1]] * num_images_per_prompt)
244
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
245
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
246
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
247
+
248
+ prompt_embeds = prompt_embeds.to(device)
249
+ add_text_embeds = add_text_embeds.to(device)
250
+ add_time_ids = add_time_ids.to(device).repeat(1 + num_images_per_prompt, 1)
251
+ batch_size = num_images_per_prompt + 1
252
+ # 8. Denoising loop
253
+ num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
254
+ is_unet_compiled = is_compiled_module(pipeline.unet)
255
+ is_controlnet_compiled = is_compiled_module(pipeline.controlnet)
256
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
257
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
258
+ controlnet_prompt_embeds = torch.cat((prompt_embeds[1:batch_size], prompt_embeds[1:batch_size]))
259
+ controlnet_added_cond_kwargs = {key: torch.cat((item[1:batch_size,], item[1:batch_size])) for key, item in added_cond_kwargs.items()}
260
+ with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
261
+ for i, t in enumerate(timesteps):
262
+ # Relevant thread:
263
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
264
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
265
+ torch._inductor.cudagraph_mark_step_begin()
266
+ # expand the latents if we are doing classifier free guidance
267
+ latent_model_input = torch.cat([latents] * 2)
268
+ latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
269
+
270
+ # controlnet(s) inference
271
+ control_model_input = torch.cat((latent_model_input[1:batch_size,], latent_model_input[batch_size+1:]))
272
+
273
+ if isinstance(controlnet_keep[i], list):
274
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
275
+ else:
276
+ controlnet_cond_scale = controlnet_conditioning_scale
277
+ if isinstance(controlnet_cond_scale, list):
278
+ controlnet_cond_scale = controlnet_cond_scale[0]
279
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
280
+ if cond_scale > 0:
281
+ down_block_res_samples, mid_block_res_sample = pipeline.controlnet(
282
+ control_model_input,
283
+ t,
284
+ encoder_hidden_states=controlnet_prompt_embeds,
285
+ controlnet_cond=image,
286
+ conditioning_scale=cond_scale,
287
+ guess_mode=False,
288
+ added_cond_kwargs=controlnet_added_cond_kwargs,
289
+ return_dict=False,
290
+ )
291
+
292
+ mid_block_res_sample = concat_zero_control(mid_block_res_sample)
293
+ down_block_res_samples = [concat_zero_control(down_block_res_sample) for down_block_res_sample in down_block_res_samples]
294
+ else:
295
+ mid_block_res_sample = down_block_res_samples = None
296
+ # predict the noise residual
297
+ noise_pred = pipeline.unet(
298
+ latent_model_input,
299
+ t,
300
+ encoder_hidden_states=prompt_embeds,
301
+ timestep_cond=timestep_cond,
302
+ cross_attention_kwargs=cross_attention_kwargs,
303
+ down_block_additional_residuals=down_block_res_samples,
304
+ mid_block_additional_residual=mid_block_res_sample,
305
+ added_cond_kwargs=added_cond_kwargs,
306
+ return_dict=False,
307
+ )[0]
308
+
309
+ # perform guidance
310
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
311
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
312
+
313
+ # compute the previous noisy sample x_t -> x_t-1
314
+ latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
315
+
316
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
317
+ progress_bar.update()
318
+
319
+ # manually for max memory savings
320
+ if pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast:
321
+ pipeline.upcast_vae()
322
+ latents = latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype)
323
+
324
+ # make sure the VAE is in float32 mode, as it overflows in float16
325
+ needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast
326
+
327
+ if needs_upcasting:
328
+ pipeline.upcast_vae()
329
+ latents = latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype)
330
+
331
+ image = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]
332
+
333
+ # cast back to fp16 if needed
334
+ if needs_upcasting:
335
+ pipeline.vae.to(dtype=torch.float16)
336
+
337
+ if pipeline.watermark is not None:
338
+ image = pipeline.watermark.apply_watermark(image)
339
+
340
+ image = pipeline.image_processor.postprocess(image, output_type='pil')
341
+
342
+ # Offload all models
343
+ pipeline.maybe_free_model_hooks()
344
+ return image
345
+
346
+
347
+ @torch.no_grad()
348
+ def panorama_call(
349
+ pipeline: StableDiffusionPanoramaPipeline,
350
+ prompt: list[str],
351
+ height: int | None = 512,
352
+ width: int | None = 2048,
353
+ num_inference_steps: int = 50,
354
+ guidance_scale: float = 7.5,
355
+ view_batch_size: int = 1,
356
+ negative_prompt: str | list[str] | None = None,
357
+ num_images_per_prompt: int | None = 1,
358
+ eta: float = 0.0,
359
+ generator: torch.Generator | None = None,
360
+ reference_latent: TN = None,
361
+ latents: TN = None,
362
+ prompt_embeds: TN = None,
363
+ negative_prompt_embeds: TN = None,
364
+ cross_attention_kwargs: dict[str, Any] | None = None,
365
+ circular_padding: bool = False,
366
+ clip_skip: int | None = None,
367
+ stride=8
368
+ ) -> list[Image]:
369
+ # 0. Default height and width to unet
370
+ height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
371
+ width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
372
+
373
+ # 1. Check inputs. Raise error if not correct
374
+ pipeline.check_inputs(
375
+ prompt, height, width, 1, negative_prompt, prompt_embeds, negative_prompt_embeds
376
+ )
377
+
378
+ device = pipeline._execution_device
379
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
380
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
381
+ # corresponds to doing no classifier free guidance.
382
+ do_classifier_free_guidance = guidance_scale > 1.0
383
+
384
+ # 3. Encode input prompt
385
+ text_encoder_lora_scale = (
386
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
387
+ )
388
+ prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
389
+ prompt,
390
+ device,
391
+ num_images_per_prompt,
392
+ do_classifier_free_guidance,
393
+ negative_prompt,
394
+ prompt_embeds=prompt_embeds,
395
+ negative_prompt_embeds=negative_prompt_embeds,
396
+ lora_scale=text_encoder_lora_scale,
397
+ clip_skip=clip_skip,
398
+ )
399
+ # For classifier free guidance, we need to do two forward passes.
400
+ # Here we concatenate the unconditional and text embeddings into a single batch
401
+ # to avoid doing two forward passes
402
+
403
+ # 4. Prepare timesteps
404
+ pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
405
+ timesteps = pipeline.scheduler.timesteps
406
+
407
+ # 5. Prepare latent variables
408
+ num_channels_latents = pipeline.unet.config.in_channels
409
+ latents = pipeline.prepare_latents(
410
+ 1,
411
+ num_channels_latents,
412
+ height,
413
+ width,
414
+ prompt_embeds.dtype,
415
+ device,
416
+ generator,
417
+ latents,
418
+ )
419
+ if reference_latent is None:
420
+ reference_latent = torch.randn(1, 4, pipeline.unet.config.sample_size, pipeline.unet.config.sample_size,
421
+ generator=generator)
422
+ reference_latent = reference_latent.to(device=device, dtype=pipeline.unet.dtype)
423
+ # 6. Define panorama grid and initialize views for synthesis.
424
+ # prepare batch grid
425
+ views = pipeline.get_views(height, width, circular_padding=circular_padding, stride=stride)
426
+ views_batch = [views[i: i + view_batch_size] for i in range(0, len(views), view_batch_size)]
427
+ views_scheduler_status = [copy.deepcopy(pipeline.scheduler.__dict__)] * len(views_batch)
428
+ count = torch.zeros_like(latents)
429
+ value = torch.zeros_like(latents)
430
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
431
+ extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
432
+
433
+ # 8. Denoising loop
434
+ # Each denoising step also includes refinement of the latents with respect to the
435
+ # views.
436
+ num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
437
+
438
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds[:1],
439
+ *[negative_prompt_embeds[1:]] * view_batch_size]
440
+ )
441
+ prompt_embeds = torch.cat([prompt_embeds[:1],
442
+ *[prompt_embeds[1:]] * view_batch_size]
443
+ )
444
+
445
+ with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
446
+ for i, t in enumerate(timesteps):
447
+ count.zero_()
448
+ value.zero_()
449
+
450
+ # generate views
451
+ # Here, we iterate through different spatial crops of the latents and denoise them. These
452
+ # denoised (latent) crops are then averaged to produce the final latent
453
+ # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
454
+ # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
455
+ # Batch views denoise
456
+ for j, batch_view in enumerate(views_batch):
457
+ vb_size = len(batch_view)
458
+ # get the latents corresponding to the current view coordinates
459
+ if circular_padding:
460
+ latents_for_view = []
461
+ for h_start, h_end, w_start, w_end in batch_view:
462
+ if w_end > latents.shape[3]:
463
+ # Add circular horizontal padding
464
+ latent_view = torch.cat(
465
+ (
466
+ latents[:, :, h_start:h_end, w_start:],
467
+ latents[:, :, h_start:h_end, : w_end - latents.shape[3]],
468
+ ),
469
+ dim=-1,
470
+ )
471
+ else:
472
+ latent_view = latents[:, :, h_start:h_end, w_start:w_end]
473
+ latents_for_view.append(latent_view)
474
+ latents_for_view = torch.cat(latents_for_view)
475
+ else:
476
+ latents_for_view = torch.cat(
477
+ [
478
+ latents[:, :, h_start:h_end, w_start:w_end]
479
+ for h_start, h_end, w_start, w_end in batch_view
480
+ ]
481
+ )
482
+ # rematch block's scheduler status
483
+ pipeline.scheduler.__dict__.update(views_scheduler_status[j])
484
+
485
+ # expand the latents if we are doing classifier free guidance
486
+ latent_reference_plus_view = torch.cat((reference_latent, latents_for_view))
487
+ latent_model_input = latent_reference_plus_view.repeat(2, 1, 1, 1)
488
+ prompt_embeds_input = torch.cat([negative_prompt_embeds[: 1 + vb_size],
489
+ prompt_embeds[: 1 + vb_size]]
490
+ )
491
+ latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
492
+ # predict the noise residual
493
+ # return
494
+ noise_pred = pipeline.unet(
495
+ latent_model_input,
496
+ t,
497
+ encoder_hidden_states=prompt_embeds_input,
498
+ cross_attention_kwargs=cross_attention_kwargs,
499
+ ).sample
500
+
501
+ # perform guidance
502
+
503
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
504
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
505
+ # compute the previous noisy sample x_t -> x_t-1
506
+ latent_reference_plus_view = pipeline.scheduler.step(
507
+ noise_pred, t, latent_reference_plus_view, **extra_step_kwargs
508
+ ).prev_sample
509
+ if j == len(views_batch) - 1:
510
+ reference_latent = latent_reference_plus_view[:1]
511
+ latents_denoised_batch = latent_reference_plus_view[1:]
512
+ # save views scheduler status after sample
513
+ views_scheduler_status[j] = copy.deepcopy(pipeline.scheduler.__dict__)
514
+
515
+ # extract value from batch
516
+ for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
517
+ latents_denoised_batch.chunk(vb_size), batch_view
518
+ ):
519
+ if circular_padding and w_end > latents.shape[3]:
520
+ # Case for circular padding
521
+ value[:, :, h_start:h_end, w_start:] += latents_view_denoised[
522
+ :, :, h_start:h_end, : latents.shape[3] - w_start
523
+ ]
524
+ value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[
525
+ :, :, h_start:h_end,
526
+ latents.shape[3] - w_start:
527
+ ]
528
+ count[:, :, h_start:h_end, w_start:] += 1
529
+ count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1
530
+ else:
531
+ value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
532
+ count[:, :, h_start:h_end, w_start:w_end] += 1
533
+
534
+ # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
535
+ latents = torch.where(count > 0, value / count, value)
536
+
537
+ # call the callback, if provided
538
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
539
+ progress_bar.update()
540
+
541
+ if circular_padding:
542
+ image = pipeline.decode_latents_with_padding(latents)
543
+ else:
544
+ image = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]
545
+ reference_image = pipeline.vae.decode(reference_latent / pipeline.vae.config.scaling_factor, return_dict=False)[0]
546
+ # image, has_nsfw_concept = pipeline.run_safety_checker(image, device, prompt_embeds.dtype)
547
+ # reference_image, _ = pipeline.run_safety_checker(reference_image, device, prompt_embeds.dtype)
548
+
549
+ image = pipeline.image_processor.postprocess(image, output_type='pil', do_denormalize=[True])
550
+ reference_image = pipeline.image_processor.postprocess(reference_image, output_type='pil', do_denormalize=[True])
551
+ pipeline.maybe_free_model_hooks()
552
+ return reference_image + image