prithivMLmods committed
Commit 72a0c9a · verified · 1 parent: 942592d

Create pipeline_fill_sd_xl.py

Files changed (1): pipeline_fill_sd_xl.py (+545, -0)
pipeline_fill_sd_xl.py ADDED
from typing import List, Optional, Union

import cv2
import PIL.Image
import torch
import torch.nn.functional as F
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import randn_tensor
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from controlnet_union import ControlNetModel_Union


def latents_to_rgb(latents):
    # Cheap linear projection from SDXL latent space to RGB for intermediate previews
    # (no VAE decode), followed by light denoising/blurring and an 8x upscale.
    weights = ((60, -60, 25, -70), (60, -5, 15, -50), (60, 10, -5, -35))

    weights_tensor = torch.t(
        torch.tensor(weights, dtype=latents.dtype).to(latents.device)
    )
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(
        latents.device
    )
    rgb_tensor = torch.einsum(
        "...lxy,lr -> ...rxy", latents, weights_tensor
    ) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)  # CHW -> HWC

    denoised_image = cv2.fastNlMeansDenoisingColored(image_array, None, 10, 10, 7, 21)
    blurred_image = cv2.GaussianBlur(denoised_image, (5, 5), 0)
    final_image = PIL.Image.fromarray(blurred_image)

    width, height = final_image.size
    final_image = final_image.resize(
        (width * 8, height * 8), PIL.Image.Resampling.LANCZOS
    )

    return final_image


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    **kwargs,
):
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    timesteps = scheduler.timesteps

    return timesteps, num_inference_steps


class StableDiffusionXLFillPipeline(DiffusionPipeline, StableDiffusionMixin):
    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: ControlNetModel_Union,
        scheduler: KarrasDiffusionSchedulers,
        force_zeros_for_empty_prompt: bool = True,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
        )
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_convert_rgb=True,
            do_normalize=False,
        )

        self.register_to_config(
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
        )

    def encode_prompt(
        self,
        prompt: str,
        device: Optional[torch.device] = None,
        do_classifier_free_guidance: bool = True,
    ):
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)

        # Define tokenizers and text encoders
        tokenizers = (
            [self.tokenizer, self.tokenizer_2]
            if self.tokenizer is not None
            else [self.tokenizer_2]
        )
        text_encoders = (
            [self.text_encoder, self.text_encoder_2]
            if self.text_encoder is not None
            else [self.text_encoder_2]
        )

        prompt_2 = prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # Tokenize and encode the prompt with both text encoders
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids

            prompt_embeds = text_encoder(
                text_input_ids.to(device), output_hidden_states=True
            )

            # Only the pooled output of the final text encoder is kept
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        # Get unconditional embeddings for classifier-free guidance.
        # zero_out_negative_prompt is fixed to True, so the elif branch below is
        # effectively unused and kept for completeness.
        zero_out_negative_prompt = True
        negative_prompt_embeds = None
        negative_pooled_prompt_embeds = None

        if do_classifier_free_guidance and zero_out_negative_prompt:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = ""
            negative_prompt_2 = negative_prompt

            # normalize str to list
            negative_prompt = (
                batch_size * [negative_prompt]
                if isinstance(negative_prompt, str)
                else negative_prompt
            )
            negative_prompt_2 = (
                batch_size * [negative_prompt_2]
                if isinstance(negative_prompt_2, str)
                else negative_prompt_2
            )

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(
                uncond_tokens, tokenizers, text_encoders
            ):
                max_length = prompt_embeds.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # Only the pooled output of the final text encoder is kept
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            if self.text_encoder_2 is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(
                    dtype=self.text_encoder_2.dtype, device=device
                )
            else:
                negative_prompt_embeds = negative_prompt_embeds.to(
                    dtype=self.unet.dtype, device=device
                )

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, 1, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * 1, seq_len, -1
            )

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
        if do_classifier_free_guidance:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
                1, 1
            ).view(bs_embed * 1, -1)

        return (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

    def check_inputs(
        self,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
        image,
        controlnet_conditioning_scale=1.0,
    ):
        if prompt_embeds is None:
            raise ValueError(
                "Provide `prompt_embeds`. Cannot leave `prompt_embeds` undefined."
            )

        if negative_prompt_embeds is None:
            raise ValueError(
                "Provide `negative_prompt_embeds`. Cannot leave `negative_prompt_embeds` undefined."
            )

        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        # Check `image`
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
        )
        if (
            isinstance(self.controlnet, ControlNetModel_Union)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
        ):
            if not isinstance(image, PIL.Image.Image):
                raise TypeError(
                    f"`image` must be passed and must be a PIL image, but is {type(image)}"
                )

        else:
            assert False

        # Check `controlnet_conditioning_scale`
        if (
            isinstance(self.controlnet, ControlNetModel_Union)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
        ):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError(
                    "For single controlnet: `controlnet_conditioning_scale` must be of type `float`."
                )
        else:
            assert False

    def prepare_image(self, image, device, dtype, do_classifier_free_guidance=False):
        image = self.control_image_processor.preprocess(image).to(dtype=torch.float32)

        image_batch_size = image.shape[0]

        image = image.repeat_interleave(image_batch_size, dim=0)
        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)

        return image

    def prepare_latents(
        self, batch_size, num_channels_latents, height, width, dtype, device
    ):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )

        latents = randn_tensor(shape, device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @torch.no_grad()
    def __call__(
        self,
        prompt_embeds: torch.Tensor,
        negative_prompt_embeds: torch.Tensor,
        pooled_prompt_embeds: torch.Tensor,
        negative_pooled_prompt_embeds: torch.Tensor,
        image: PipelineImageInput = None,
        num_inference_steps: int = 8,
        guidance_scale: float = 1.5,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    ):
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            image,
            controlnet_conditioning_scale,
        )

        self._guidance_scale = guidance_scale

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device

        # 3. Prepare image
        if isinstance(self.controlnet, ControlNetModel_Union):
            image = self.prepare_image(
                image=image,
                device=device,
                dtype=self.controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
            )
            height, width = image.shape[-2:]
        else:
            assert False

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device
        )
        self._num_timesteps = len(timesteps)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
        )

        # 6. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds

        # SDXL micro-conditioning: (original_size, crops_coords_top_left, target_size)
        add_time_ids = negative_add_time_ids = torch.tensor(
            image.shape[-2:] + torch.Size([0, 0]) + image.shape[-2:]
        ).unsqueeze(0)

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0
            )
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size, 1)

        # Only slot 6 of the union control list carries a condition (the fill image);
        # the control_type flag vector selects the matching control mode.
        controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0]
        union_control_type = (
            torch.Tensor([0, 0, 0, 0, 0, 0, 1, 0])
            .to(device, dtype=prompt_embeds.dtype)
            .repeat(batch_size * 2, 1)
        )

        added_cond_kwargs = {
            "text_embeds": add_text_embeds,
            "time_ids": add_time_ids,
            "control_type": union_control_type,
        }

        controlnet_prompt_embeds = prompt_embeds
        controlnet_added_cond_kwargs = added_cond_kwargs

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # controlnet(s) inference
                control_model_input = latent_model_input

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond_list=controlnet_image_list,
                    conditioning_scale=controlnet_conditioning_scale,
                    guess_mode=False,
                    added_cond_kwargs=controlnet_added_cond_kwargs,
                    return_dict=False,
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=None,
                    cross_attention_kwargs={},
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]

                if i == 2:
                    # After the third step, keep only the conditional batch and turn off
                    # classifier-free guidance for the remaining steps to save compute.
                    prompt_embeds = prompt_embeds[-1:]
                    add_text_embeds = add_text_embeds[-1:]
                    add_time_ids = add_time_ids[-1:]
                    union_control_type = union_control_type[-1:]

                    added_cond_kwargs = {
                        "text_embeds": add_text_embeds,
                        "time_ids": add_time_ids,
                        "control_type": union_control_type,
                    }

                    controlnet_prompt_embeds = prompt_embeds
                    controlnet_added_cond_kwargs = added_cond_kwargs

                    image = image[-1:]
                    controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0]

                    self._guidance_scale = 0.0

                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    # stream a cheap RGB preview of the current latents
                    yield latents_to_rgb(latents)

        latents = latents / self.vae.config.scaling_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image)[0]

        # Offload all models
        self.maybe_free_model_hooks()

        # The fully decoded image is yielded last
        yield image
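
Usage sketch (not part of this commit): the snippet below shows one plausible way to wire the pipeline together, assuming diffusers-style `from_pretrained` loaders for `ControlNetModel_Union` and an SDXL base checkpoint. The repo ids, file names, and the TCD scheduler choice are illustrative placeholders, not confirmed by this file. Note that `__call__` is a generator: it yields low-resolution latent previews during denoising and the fully decoded image last.

# Illustrative usage sketch; repo ids, file names, scheduler, and the
# ControlNetModel_Union loading call are assumptions, not part of this commit.
import torch
from diffusers import AutoencoderKL, TCDScheduler
from PIL import Image

from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline

controlnet = ControlNetModel_Union.from_pretrained(  # assumes a diffusers-style loader
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLFillPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",  # any SDXL base checkpoint (illustrative)
    vae=vae,
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt("a cozy living room, high quality", "cuda", True)

# Control image: the region to fill should be blanked out (hypothetical file name).
cnet_image = Image.open("masked_input.png")

# Iterate the generator; the last yielded item is the final decoded image.
result = None
for result in pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    image=cnet_image,
):
    pass

result.save("filled.png")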