PhanDuy committed on
Commit 899139e · verified · 1 parent: a63a2f3

Upload sdxl_unet2_trt.py

Files changed (1)
  1. sdxl_unet2_trt.py +1386 -0
sdxl_unet2_trt.py ADDED
@@ -0,0 +1,1386 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from transformers import (
20
+ CLIPImageProcessor,
21
+ CLIPTextModel,
22
+ CLIPTextModelWithProjection,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ )
26
+
27
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
28
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
29
+ from diffusers.loaders import (
30
+ FromSingleFileMixin,
31
+ IPAdapterMixin,
32
+ StableDiffusionXLLoraLoaderMixin,
33
+ TextualInversionLoaderMixin,
34
+ )
35
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
36
+ from diffusers.models.attention_processor import (
37
+ AttnProcessor2_0,
38
+ FusedAttnProcessor2_0,
39
+ XFormersAttnProcessor,
40
+ )
41
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
42
+ from diffusers.schedulers import KarrasDiffusionSchedulers
43
+ from diffusers.utils import (
44
+ USE_PEFT_BACKEND,
45
+ deprecate,
46
+ is_invisible_watermark_available,
47
+ is_torch_xla_available,
48
+ logging,
49
+ replace_example_docstring,
50
+ scale_lora_layers,
51
+ unscale_lora_layers,
52
+ )
53
+ from diffusers.utils.torch_utils import randn_tensor
54
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
55
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
56
+
57
+ import numpy as np
58
+ import torch
59
+ import onnxruntime as ort
60
+
61
+ # from tritonclient.http import InferenceServerClient, InferInput
62
+ # import tritonclient.http as httpclient
63
+
64
+ # triton_client = InferenceServerClient(url='localhost:1234')
65
+
66
+ if is_invisible_watermark_available():
67
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker  # absolute import: this file is a standalone module, not part of a package
68
+
69
+ if is_torch_xla_available():
70
+ import torch_xla.core.xla_model as xm
71
+
72
+ XLA_AVAILABLE = True
73
+ else:
74
+ XLA_AVAILABLE = False
75
+
76
+
77
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
78
+
79
+ EXAMPLE_DOC_STRING = """
80
+ Examples:
81
+ ```py
82
+ >>> import torch
83
+ >>> from diffusers import StableDiffusionXLPipeline
84
+
85
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
86
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
87
+ ... )
88
+ >>> pipe = pipe.to("cuda")
89
+
90
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
91
+ >>> image = pipe(prompt).images[0]
92
+ ```
93
+ """
94
+
95
+
96
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
97
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
98
+ r"""
99
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
100
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
101
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
102
+
103
+ Args:
104
+ noise_cfg (`torch.Tensor`):
105
+ The predicted noise tensor for the guided diffusion process.
106
+ noise_pred_text (`torch.Tensor`):
107
+ The predicted noise tensor for the text-guided diffusion process.
108
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
109
+ A rescale factor applied to the noise predictions.
110
+
111
+ Returns:
112
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
113
+ """
114
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
115
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
116
+ # rescale the results from guidance (fixes overexposure)
117
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
118
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
119
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
120
+ return noise_cfg
121
+
122
+
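+ # Illustrative note (not part of the upstream helper): guidance_rescale=0.0 leaves the
+ # prediction untouched, while guidance_rescale=1.0 fully matches its per-sample std to the
+ # text-conditioned prediction. A quick sanity check, assuming `torch` is available:
+ # >>> x, t = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
+ # >>> torch.allclose(rescale_noise_cfg(x, t, guidance_rescale=0.0), x)
+ # True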
123
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
124
+ def retrieve_timesteps(
125
+ scheduler,
126
+ num_inference_steps: Optional[int] = None,
127
+ device: Optional[Union[str, torch.device]] = None,
128
+ timesteps: Optional[List[int]] = None,
129
+ sigmas: Optional[List[float]] = None,
130
+ **kwargs,
131
+ ):
132
+ r"""
133
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
134
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
135
+
136
+ Args:
137
+ scheduler (`SchedulerMixin`):
138
+ The scheduler to get timesteps from.
139
+ num_inference_steps (`int`):
140
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
141
+ must be `None`.
142
+ device (`str` or `torch.device`, *optional*):
143
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
144
+ timesteps (`List[int]`, *optional*):
145
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
146
+ `num_inference_steps` and `sigmas` must be `None`.
147
+ sigmas (`List[float]`, *optional*):
148
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
149
+ `num_inference_steps` and `timesteps` must be `None`.
150
+
151
+ Returns:
152
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
153
+ second element is the number of inference steps.
154
+ """
155
+ if timesteps is not None and sigmas is not None:
156
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
157
+ if timesteps is not None:
158
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
159
+ if not accepts_timesteps:
160
+ raise ValueError(
161
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
162
+ f" timestep schedules. Please check whether you are using the correct scheduler."
163
+ )
164
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
165
+ timesteps = scheduler.timesteps
166
+ num_inference_steps = len(timesteps)
167
+ elif sigmas is not None:
168
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
169
+ if not accept_sigmas:
170
+ raise ValueError(
171
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
172
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
173
+ )
174
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
175
+ timesteps = scheduler.timesteps
176
+ num_inference_steps = len(timesteps)
177
+ else:
178
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
179
+ timesteps = scheduler.timesteps
180
+ return timesteps, num_inference_steps
181
+
182
+
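+ # Illustrative sketch of the three supported call patterns (assuming the scheduler's
+ # `set_timesteps` accepts the corresponding keyword; not all schedulers do):
+ # >>> ts, n = retrieve_timesteps(scheduler, num_inference_steps=4, device="cuda")
+ # >>> ts, n = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249], device="cuda")
+ # >>> ts, n = retrieve_timesteps(scheduler, sigmas=[14.6, 3.6, 0.9, 0.0], device="cuda")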
183
+ class MyStableDiffusionXLPipeline(
184
+ DiffusionPipeline,
185
+ StableDiffusionMixin,
186
+ FromSingleFileMixin,
187
+ StableDiffusionXLLoraLoaderMixin,
188
+ TextualInversionLoaderMixin,
189
+ IPAdapterMixin,
190
+ ):
191
+ r"""
192
+ Pipeline for text-to-image generation using Stable Diffusion XL.
193
+
194
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
195
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
196
+
197
+ The pipeline also inherits the following loading methods:
198
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
199
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
200
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
201
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
202
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
203
+
204
+ Args:
205
+ vae ([`AutoencoderKL`]):
206
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
207
+ text_encoder ([`CLIPTextModel`]):
208
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
209
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
210
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
211
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
212
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
213
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
214
+ specifically the
215
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
216
+ variant.
217
+ tokenizer (`CLIPTokenizer`):
218
+ Tokenizer of class
219
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
220
+ tokenizer_2 (`CLIPTokenizer`):
221
+ Second Tokenizer of class
222
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
223
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
224
+ scheduler ([`SchedulerMixin`]):
225
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
226
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
227
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
228
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
229
+ `stabilityai/stable-diffusion-xl-base-1.0`.
230
+ add_watermarker (`bool`, *optional*):
231
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
232
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
233
+ watermarker will be used.
234
+ """
235
+
236
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
237
+ _optional_components = [
238
+ "tokenizer",
239
+ "tokenizer_2",
240
+ "text_encoder",
241
+ "text_encoder_2",
242
+ "image_encoder",
243
+ "feature_extractor",
244
+ ]
245
+ _callback_tensor_inputs = [
246
+ "latents",
247
+ "prompt_embeds",
248
+ "add_text_embeds",
249
+ "add_time_ids",
250
+ ]
251
+
252
+ def __init__(
253
+ self,
254
+ vae: AutoencoderKL,
255
+ text_encoder: CLIPTextModel,
256
+ text_encoder_2: CLIPTextModelWithProjection,
257
+ tokenizer: CLIPTokenizer,
258
+ tokenizer_2: CLIPTokenizer,
259
+ unet: UNet2DConditionModel,
260
+ scheduler: KarrasDiffusionSchedulers,
261
+ image_encoder: CLIPVisionModelWithProjection = None,
262
+ feature_extractor: CLIPImageProcessor = None,
263
+ force_zeros_for_empty_prompt: bool = True,
264
+ add_watermarker: Optional[bool] = None,
265
+ ):
266
+ super().__init__()
267
+
268
+ self.register_modules(
269
+ vae=vae,
270
+ text_encoder=text_encoder,
271
+ text_encoder_2=text_encoder_2,
272
+ tokenizer=tokenizer,
273
+ tokenizer_2=tokenizer_2,
274
+ unet=unet,
275
+ scheduler=scheduler,
276
+ image_encoder=image_encoder,
277
+ feature_extractor=feature_extractor,
278
+ )
279
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
280
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
281
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
282
+
283
+ self.default_sample_size = self.unet.config.sample_size
284
+
285
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
286
+
287
+ if add_watermarker:
288
+ self.watermark = StableDiffusionXLWatermarker()
289
+ else:
290
+ self.watermark = None
291
+
292
+ def encode_prompt(
293
+ self,
294
+ prompt: str,
295
+ prompt_2: Optional[str] = None,
296
+ device: Optional[torch.device] = None,
297
+ num_images_per_prompt: int = 1,
298
+ do_classifier_free_guidance: bool = True,
299
+ negative_prompt: Optional[str] = None,
300
+ negative_prompt_2: Optional[str] = None,
301
+ prompt_embeds: Optional[torch.Tensor] = None,
302
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
303
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
304
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
305
+ lora_scale: Optional[float] = None,
306
+ clip_skip: Optional[int] = None,
307
+ ):
308
+ r"""
309
+ Encodes the prompt into text encoder hidden states.
310
+
311
+ Args:
312
+ prompt (`str` or `List[str]`, *optional*):
313
+ prompt to be encoded
314
+ prompt_2 (`str` or `List[str]`, *optional*):
315
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
316
+ used in both text-encoders
317
+ device (`torch.device`):
318
+ torch device
319
+ num_images_per_prompt (`int`):
320
+ number of images that should be generated per prompt
321
+ do_classifier_free_guidance (`bool`):
322
+ whether to use classifier free guidance or not
323
+ negative_prompt (`str` or `List[str]`, *optional*):
324
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
325
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
326
+ less than `1`).
327
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
328
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
329
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
330
+ prompt_embeds (`torch.Tensor`, *optional*):
331
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
332
+ provided, text embeddings will be generated from `prompt` input argument.
333
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
334
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
335
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
336
+ argument.
337
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
338
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
339
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
340
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
341
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
342
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
343
+ input argument.
344
+ lora_scale (`float`, *optional*):
345
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
346
+ clip_skip (`int`, *optional*):
347
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
348
+ the output of the pre-final layer will be used for computing the prompt embeddings.
349
+ """
350
+ device = device or self._execution_device
351
+
352
+ # set lora scale so that monkey patched LoRA
353
+ # function of text encoder can correctly access it
354
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
355
+ self._lora_scale = lora_scale
356
+
357
+ # dynamically adjust the LoRA scale
358
+ if self.text_encoder is not None:
359
+ if not USE_PEFT_BACKEND:
360
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
361
+ else:
362
+ scale_lora_layers(self.text_encoder, lora_scale)
363
+
364
+ if self.text_encoder_2 is not None:
365
+ if not USE_PEFT_BACKEND:
366
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
367
+ else:
368
+ scale_lora_layers(self.text_encoder_2, lora_scale)
369
+
370
+ prompt = [prompt] if isinstance(prompt, str) else prompt
371
+
372
+ if prompt is not None:
373
+ batch_size = len(prompt)
374
+ else:
375
+ batch_size = prompt_embeds.shape[0]
376
+
377
+ # Define tokenizers and text encoders
378
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
379
+ text_encoders = (
380
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
381
+ )
382
+
383
+ if prompt_embeds is None:
384
+ prompt_2 = prompt_2 or prompt
385
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
386
+
387
+ # textual inversion: process multi-vector tokens if necessary
388
+ prompt_embeds_list = []
389
+ prompts = [prompt, prompt_2]
390
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
391
+ if isinstance(self, TextualInversionLoaderMixin):
392
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
393
+
394
+ text_inputs = tokenizer(
395
+ prompt,
396
+ padding="max_length",
397
+ max_length=tokenizer.model_max_length,
398
+ truncation=True,
399
+ return_tensors="pt",
400
+ )
401
+
402
+ text_input_ids = text_inputs.input_ids
403
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
404
+
405
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
406
+ text_input_ids, untruncated_ids
407
+ ):
408
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
409
+ logger.warning(
410
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
411
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
412
+ )
413
+
414
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
415
+
416
+ # We are only ever interested in the pooled output of the final text encoder
417
+ pooled_prompt_embeds = prompt_embeds[0]
418
+ if clip_skip is None:
419
+ prompt_embeds = prompt_embeds.hidden_states[-2]
420
+ else:
421
+ # "2" because SDXL always indexes from the penultimate layer.
422
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
423
+
424
+ prompt_embeds_list.append(prompt_embeds)
425
+
426
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
427
+
428
+ # get unconditional embeddings for classifier free guidance
429
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
430
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
431
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
432
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
433
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
434
+ negative_prompt = negative_prompt or ""
435
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
436
+
437
+ # normalize str to list
438
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
439
+ negative_prompt_2 = (
440
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
441
+ )
442
+
443
+ uncond_tokens: List[str]
444
+ if prompt is not None and type(prompt) is not type(negative_prompt):
445
+ raise TypeError(
446
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
447
+ f" {type(prompt)}."
448
+ )
449
+ elif batch_size != len(negative_prompt):
450
+ raise ValueError(
451
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
452
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
453
+ " the batch size of `prompt`."
454
+ )
455
+ else:
456
+ uncond_tokens = [negative_prompt, negative_prompt_2]
457
+
458
+ negative_prompt_embeds_list = []
459
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
460
+ if isinstance(self, TextualInversionLoaderMixin):
461
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
462
+
463
+ max_length = prompt_embeds.shape[1]
464
+ uncond_input = tokenizer(
465
+ negative_prompt,
466
+ padding="max_length",
467
+ max_length=max_length,
468
+ truncation=True,
469
+ return_tensors="pt",
470
+ )
471
+
472
+ negative_prompt_embeds = text_encoder(
473
+ uncond_input.input_ids.to(device),
474
+ output_hidden_states=True,
475
+ )
476
+ # We are only ever interested in the pooled output of the final text encoder
477
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
478
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
479
+
480
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
481
+
482
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
483
+
484
+ if self.text_encoder_2 is not None:
485
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
486
+ else:
487
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
488
+
489
+ bs_embed, seq_len, _ = prompt_embeds.shape
490
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
491
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
492
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
493
+
494
+ if do_classifier_free_guidance:
495
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
496
+ seq_len = negative_prompt_embeds.shape[1]
497
+
498
+ if self.text_encoder_2 is not None:
499
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
500
+ else:
501
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
502
+
503
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
504
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
505
+
506
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
507
+ bs_embed * num_images_per_prompt, -1
508
+ )
509
+ if do_classifier_free_guidance:
510
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
511
+ bs_embed * num_images_per_prompt, -1
512
+ )
513
+
514
+ if self.text_encoder is not None:
515
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
516
+ # Retrieve the original scale by scaling back the LoRA layers
517
+ unscale_lora_layers(self.text_encoder, lora_scale)
518
+
519
+ if self.text_encoder_2 is not None:
520
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
521
+ # Retrieve the original scale by scaling back the LoRA layers
522
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
523
+
524
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
525
+
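+ # Illustrative note: with the standard SDXL text encoders (hidden sizes 768 and 1280, 77-token
+ # context), `prompt_embeds` comes back as (batch_size * num_images_per_prompt, 77, 2048) and
+ # `pooled_prompt_embeds` as (batch_size * num_images_per_prompt, 1280); the negative tensors
+ # mirror these shapes when classifier-free guidance is enabled.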
526
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
527
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
528
+ dtype = next(self.image_encoder.parameters()).dtype
529
+
530
+ if not isinstance(image, torch.Tensor):
531
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
532
+
533
+ image = image.to(device=device, dtype=dtype)
534
+ if output_hidden_states:
535
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
536
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
537
+ uncond_image_enc_hidden_states = self.image_encoder(
538
+ torch.zeros_like(image), output_hidden_states=True
539
+ ).hidden_states[-2]
540
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
541
+ num_images_per_prompt, dim=0
542
+ )
543
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
544
+ else:
545
+ image_embeds = self.image_encoder(image).image_embeds
546
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
547
+ uncond_image_embeds = torch.zeros_like(image_embeds)
548
+
549
+ return image_embeds, uncond_image_embeds
550
+
551
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
552
+ def prepare_ip_adapter_image_embeds(
553
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
554
+ ):
555
+ image_embeds = []
556
+ if do_classifier_free_guidance:
557
+ negative_image_embeds = []
558
+ if ip_adapter_image_embeds is None:
559
+ if not isinstance(ip_adapter_image, list):
560
+ ip_adapter_image = [ip_adapter_image]
561
+
562
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
563
+ raise ValueError(
564
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
565
+ )
566
+
567
+ for single_ip_adapter_image, image_proj_layer in zip(
568
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
569
+ ):
570
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
571
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
572
+ single_ip_adapter_image, device, 1, output_hidden_state
573
+ )
574
+
575
+ image_embeds.append(single_image_embeds[None, :])
576
+ if do_classifier_free_guidance:
577
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
578
+ else:
579
+ for single_image_embeds in ip_adapter_image_embeds:
580
+ if do_classifier_free_guidance:
581
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
582
+ negative_image_embeds.append(single_negative_image_embeds)
583
+ image_embeds.append(single_image_embeds)
584
+
585
+ ip_adapter_image_embeds = []
586
+ for i, single_image_embeds in enumerate(image_embeds):
587
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
588
+ if do_classifier_free_guidance:
589
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
590
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
591
+
592
+ single_image_embeds = single_image_embeds.to(device=device)
593
+ ip_adapter_image_embeds.append(single_image_embeds)
594
+
595
+ return ip_adapter_image_embeds
596
+
597
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
598
+ def prepare_extra_step_kwargs(self, generator, eta):
599
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
600
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
601
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
602
+ # and should be between [0, 1]
603
+
604
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
605
+ extra_step_kwargs = {}
606
+ if accepts_eta:
607
+ extra_step_kwargs["eta"] = eta
608
+
609
+ # check if the scheduler accepts generator
610
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
611
+ if accepts_generator:
612
+ extra_step_kwargs["generator"] = generator
613
+ return extra_step_kwargs
614
+
615
+ def check_inputs(
616
+ self,
617
+ prompt,
618
+ prompt_2,
619
+ height,
620
+ width,
621
+ callback_steps,
622
+ negative_prompt=None,
623
+ negative_prompt_2=None,
624
+ prompt_embeds=None,
625
+ negative_prompt_embeds=None,
626
+ pooled_prompt_embeds=None,
627
+ negative_pooled_prompt_embeds=None,
628
+ ip_adapter_image=None,
629
+ ip_adapter_image_embeds=None,
630
+ callback_on_step_end_tensor_inputs=None,
631
+ ):
632
+ if height % 8 != 0 or width % 8 != 0:
633
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
634
+
635
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
636
+ raise ValueError(
637
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
638
+ f" {type(callback_steps)}."
639
+ )
640
+
641
+ if callback_on_step_end_tensor_inputs is not None and not all(
642
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
643
+ ):
644
+ raise ValueError(
645
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
646
+ )
647
+
648
+ if prompt is not None and prompt_embeds is not None:
649
+ raise ValueError(
650
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
651
+ " only forward one of the two."
652
+ )
653
+ elif prompt_2 is not None and prompt_embeds is not None:
654
+ raise ValueError(
655
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
656
+ " only forward one of the two."
657
+ )
658
+ elif prompt is None and prompt_embeds is None:
659
+ raise ValueError(
660
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
661
+ )
662
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
663
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
664
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
665
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
666
+
667
+ if negative_prompt is not None and negative_prompt_embeds is not None:
668
+ raise ValueError(
669
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
670
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
671
+ )
672
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
673
+ raise ValueError(
674
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
675
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
676
+ )
677
+
678
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
679
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
680
+ raise ValueError(
681
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
682
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
683
+ f" {negative_prompt_embeds.shape}."
684
+ )
685
+
686
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
687
+ raise ValueError(
688
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
689
+ )
690
+
691
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
692
+ raise ValueError(
693
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
694
+ )
695
+
696
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
697
+ raise ValueError(
698
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
699
+ )
700
+
701
+ if ip_adapter_image_embeds is not None:
702
+ if not isinstance(ip_adapter_image_embeds, list):
703
+ raise ValueError(
704
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
705
+ )
706
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
707
+ raise ValueError(
708
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
709
+ )
710
+
711
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
712
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
713
+ shape = (
714
+ batch_size,
715
+ num_channels_latents,
716
+ int(height) // self.vae_scale_factor,
717
+ int(width) // self.vae_scale_factor,
718
+ )
719
+ if isinstance(generator, list) and len(generator) != batch_size:
720
+ raise ValueError(
721
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
722
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
723
+ )
724
+
725
+ if latents is None:
726
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
727
+ else:
728
+ latents = latents.to(device)
729
+
730
+ # scale the initial noise by the standard deviation required by the scheduler
731
+ latents = latents * self.scheduler.init_noise_sigma
732
+ return latents
733
+
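+ # Illustrative note: with the standard SDXL VAE (scale factor 8) and a 4-channel UNet, a
+ # 1024x1024 request yields latents of shape (batch_size, 4, 128, 128), already scaled by
+ # `scheduler.init_noise_sigma`.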
734
+ def _get_add_time_ids(
735
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
736
+ ):
737
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
738
+
739
+ passed_add_embed_dim = (
740
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
741
+ )
742
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
743
+
744
+ if expected_add_embed_dim != passed_add_embed_dim:
745
+ raise ValueError(
746
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
747
+ )
748
+
749
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
750
+ return add_time_ids
751
+
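+ # Illustrative note: for the SDXL base configuration the added time ids are the six
+ # micro-conditioning values (original_h, original_w, crop_top, crop_left, target_h, target_w),
+ # e.g. (1024, 1024, 0, 0, 1024, 1024), and the consistency check amounts to
+ # 6 * addition_time_embed_dim (256) + projection_dim (1280) == 2816 input features.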
752
+ def upcast_vae(self):
753
+ dtype = self.vae.dtype
754
+ self.vae.to(dtype=torch.float32)
755
+ use_torch_2_0_or_xformers = isinstance(
756
+ self.vae.decoder.mid_block.attentions[0].processor,
757
+ (
758
+ AttnProcessor2_0,
759
+ XFormersAttnProcessor,
760
+ FusedAttnProcessor2_0,
761
+ ),
762
+ )
763
+ # if xformers or torch_2_0 is used attention block does not need
764
+ # to be in float32 which can save lots of memory
765
+ if use_torch_2_0_or_xformers:
766
+ self.vae.post_quant_conv.to(dtype)
767
+ self.vae.decoder.conv_in.to(dtype)
768
+ self.vae.decoder.mid_block.to(dtype)
769
+
770
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
771
+ def get_guidance_scale_embedding(
772
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
773
+ ) -> torch.Tensor:
774
+ """
775
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
776
+
777
+ Args:
778
+ w (`torch.Tensor`):
779
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
780
+ embedding_dim (`int`, *optional*, defaults to 512):
781
+ Dimension of the embeddings to generate.
782
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
783
+ Data type of the generated embeddings.
784
+
785
+ Returns:
786
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
787
+ """
788
+ assert len(w.shape) == 1
789
+ w = w * 1000.0
790
+
791
+ half_dim = embedding_dim // 2
792
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
793
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
794
+ emb = w.to(dtype)[:, None] * emb[None, :]
795
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
796
+ if embedding_dim % 2 == 1: # zero pad
797
+ emb = torch.nn.functional.pad(emb, (0, 1))
798
+ assert emb.shape == (w.shape[0], embedding_dim)
799
+ return emb
800
+
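+ # Illustrative note: in the upstream SDXL pipeline this embedding is only consumed when the
+ # UNet has `time_cond_proj_dim` set (e.g. LCM-style distilled UNets); for a batch of guidance
+ # scales, e.g. get_guidance_scale_embedding(torch.tensor([7.5]), embedding_dim=256), the result
+ # has shape (1, 256).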
801
+ @property
802
+ def guidance_scale(self):
803
+ return self._guidance_scale
804
+
805
+ @property
806
+ def guidance_rescale(self):
807
+ return self._guidance_rescale
808
+
809
+ @property
810
+ def clip_skip(self):
811
+ return self._clip_skip
812
+
813
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
814
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
815
+ # corresponds to doing no classifier free guidance.
816
+ @property
817
+ def do_classifier_free_guidance(self):
818
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
819
+
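+ # Illustrative note: with the default guidance_scale of 5.0 and a UNet without
+ # `time_cond_proj_dim`, classifier-free guidance is active, so the UNet sees a doubled batch
+ # (negative conditioning followed by positive conditioning) on every step.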
820
+ @property
821
+ def cross_attention_kwargs(self):
822
+ return self._cross_attention_kwargs
823
+
824
+ @property
825
+ def denoising_end(self):
826
+ return self._denoising_end
827
+
828
+ @property
829
+ def num_timesteps(self):
830
+ return self._num_timesteps
831
+
832
+ @property
833
+ def interrupt(self):
834
+ return self._interrupt
835
+
836
+ @torch.no_grad()
837
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
838
+ def __call__(
839
+ self,
840
+ prompt: Union[str, List[str]] = None,
841
+ prompt_2: Optional[Union[str, List[str]]] = None,
842
+ height: Optional[int] = None,
843
+ width: Optional[int] = None,
844
+ num_inference_steps: int = 1,
845
+ timesteps: List[int] = None,
846
+ sigmas: List[float] = None,
847
+ denoising_end: Optional[float] = None,
848
+ guidance_scale: float = 5.0,
849
+ negative_prompt: Optional[Union[str, List[str]]] = None,
850
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
851
+ num_images_per_prompt: Optional[int] = 1,
852
+ eta: float = 0.0,
853
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
854
+ latents: Optional[torch.Tensor] = None,
855
+ prompt_embeds: Optional[torch.Tensor] = None,
856
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
857
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
858
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
859
+ ip_adapter_image: Optional[PipelineImageInput] = None,
860
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
861
+ output_type: Optional[str] = "pil",
862
+ return_dict: bool = True,
863
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
864
+ guidance_rescale: float = 0.0,
865
+ original_size: Optional[Tuple[int, int]] = None,
866
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
867
+ target_size: Optional[Tuple[int, int]] = None,
868
+ negative_original_size: Optional[Tuple[int, int]] = None,
869
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
870
+ negative_target_size: Optional[Tuple[int, int]] = None,
871
+ clip_skip: Optional[int] = None,
872
+ callback_on_step_end: Optional[
873
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
874
+ ] = None,
875
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
876
+ **kwargs,
877
+ ):
878
+ r"""
879
+ Function invoked when calling the pipeline for generation.
880
+
881
+ Args:
882
+ prompt (`str` or `List[str]`, *optional*):
883
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
884
+ instead.
885
+ prompt_2 (`str` or `List[str]`, *optional*):
886
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
887
+ used in both text-encoders
888
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
889
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
890
+ Anything below 512 pixels won't work well for
891
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
892
+ and checkpoints that are not specifically fine-tuned on low resolutions.
893
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
894
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
895
+ Anything below 512 pixels won't work well for
896
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
897
+ and checkpoints that are not specifically fine-tuned on low resolutions.
898
+ num_inference_steps (`int`, *optional*, defaults to 1):
899
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
900
+ expense of slower inference.
901
+ timesteps (`List[int]`, *optional*):
902
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
903
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
904
+ passed will be used. Must be in descending order.
905
+ sigmas (`List[float]`, *optional*):
906
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
907
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
908
+ will be used.
909
+ denoising_end (`float`, *optional*):
910
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
911
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
912
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
913
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
914
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
915
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
916
+ guidance_scale (`float`, *optional*, defaults to 5.0):
917
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
918
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
919
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
920
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
921
+ usually at the expense of lower image quality.
922
+ negative_prompt (`str` or `List[str]`, *optional*):
923
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
924
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
925
+ less than `1`).
926
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
927
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
928
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
929
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
930
+ The number of images to generate per prompt.
931
+ eta (`float`, *optional*, defaults to 0.0):
932
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
933
+ [`schedulers.DDIMScheduler`], will be ignored for others.
934
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
935
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
936
+ to make generation deterministic.
937
+ latents (`torch.Tensor`, *optional*):
938
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
939
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
940
+ tensor will be generated by sampling using the supplied random `generator`.
941
+ prompt_embeds (`torch.Tensor`, *optional*):
942
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
943
+ provided, text embeddings will be generated from `prompt` input argument.
944
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
945
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
946
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
947
+ argument.
948
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
949
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
950
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
951
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
952
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
953
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
954
+ input argument.
955
+ ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
956
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
957
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
958
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
959
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
960
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
961
+ output_type (`str`, *optional*, defaults to `"pil"`):
962
+ The output format of the generated image. Choose between
963
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
964
+ return_dict (`bool`, *optional*, defaults to `True`):
965
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
966
+ of a plain tuple.
967
+ cross_attention_kwargs (`dict`, *optional*):
968
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
969
+ `self.processor` in
970
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
971
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
972
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
973
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
974
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
975
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
976
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
977
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
978
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
979
+ explained in section 2.2 of
980
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
981
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
982
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
983
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
984
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
985
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
986
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
987
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
988
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
989
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
990
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
991
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
992
+ micro-conditioning as explained in section 2.2 of
993
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
994
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
995
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
996
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
997
+ micro-conditioning as explained in section 2.2 of
998
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
999
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1000
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1001
+ To negatively condition the generation process based on a target image resolution. It should be the same
1002
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1003
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1004
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1005
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1006
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1007
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
1008
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1009
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1010
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1011
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1012
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1013
+ `._callback_tensor_inputs` attribute of your pipeline class.
1014
+
1015
+ Examples:
1016
+
1017
+ Returns:
1018
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
1019
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1020
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1021
+ """
1022
+
1023
+ callback = kwargs.pop("callback", None)
1024
+ callback_steps = kwargs.pop("callback_steps", None)
1025
+
1026
+ if callback is not None:
1027
+ deprecate(
1028
+ "callback",
1029
+ "1.0.0",
1030
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1031
+ )
1032
+ if callback_steps is not None:
1033
+ deprecate(
1034
+ "callback_steps",
1035
+ "1.0.0",
1036
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1037
+ )
1038
+
1039
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1040
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1041
+
1042
+ # 0. Default height and width to unet
1043
+ height = height or self.default_sample_size * self.vae_scale_factor
1044
+ width = width or self.default_sample_size * self.vae_scale_factor
1045
+
1046
+ original_size = original_size or (height, width)
1047
+ target_size = target_size or (height, width)
1048
+
1049
+ # 1. Check inputs. Raise error if not correct
1050
+ self.check_inputs(
1051
+ prompt,
1052
+ prompt_2,
1053
+ height,
1054
+ width,
1055
+ callback_steps,
1056
+ negative_prompt,
1057
+ negative_prompt_2,
1058
+ prompt_embeds,
1059
+ negative_prompt_embeds,
1060
+ pooled_prompt_embeds,
1061
+ negative_pooled_prompt_embeds,
1062
+ ip_adapter_image,
1063
+ ip_adapter_image_embeds,
1064
+ callback_on_step_end_tensor_inputs,
1065
+ )
1066
+
1067
+ self._guidance_scale = guidance_scale
1068
+ self._guidance_rescale = guidance_rescale
1069
+ self._clip_skip = clip_skip
1070
+ self._cross_attention_kwargs = cross_attention_kwargs
1071
+ self._denoising_end = denoising_end
1072
+ self._interrupt = False
1073
+
1074
+ # 2. Define call parameters
1075
+ if prompt is not None and isinstance(prompt, str):
1076
+ batch_size = 1
1077
+ elif prompt is not None and isinstance(prompt, list):
1078
+ batch_size = len(prompt)
1079
+ else:
1080
+ batch_size = prompt_embeds.shape[0]
1081
+
1082
+ device = self._execution_device
1083
+
1084
+ # 3. Encode input prompt
1085
+ lora_scale = (
1086
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1087
+ )
1088
+
1089
+ (
1090
+ prompt_embeds,
1091
+ negative_prompt_embeds,
1092
+ pooled_prompt_embeds,
1093
+ negative_pooled_prompt_embeds,
1094
+ ) = self.encode_prompt(
1095
+ prompt=prompt,
1096
+ prompt_2=prompt_2,
1097
+ device=device,
1098
+ num_images_per_prompt=num_images_per_prompt,
1099
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1100
+ negative_prompt=negative_prompt,
1101
+ negative_prompt_2=negative_prompt_2,
1102
+ prompt_embeds=prompt_embeds,
1103
+ negative_prompt_embeds=negative_prompt_embeds,
1104
+ pooled_prompt_embeds=pooled_prompt_embeds,
1105
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1106
+ lora_scale=lora_scale,
1107
+ clip_skip=self.clip_skip,
1108
+ )
1109
+
1110
+ # 4. Prepare timesteps
1111
+ timesteps, num_inference_steps = retrieve_timesteps(
1112
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1113
+ )
1114
+
1115
+ # 5. Prepare latent variables
1116
+ num_channels_latents = self.unet.config.in_channels
1117
+ latents = self.prepare_latents(
1118
+ batch_size * num_images_per_prompt,
1119
+ num_channels_latents,
1120
+ height,
1121
+ width,
1122
+ prompt_embeds.dtype,
1123
+ device,
1124
+ generator,
1125
+ latents,
1126
+ )
1127
+
1128
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1129
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1130
+
1131
+ # 7. Prepare added time ids & embeddings
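+ # SDXL micro-conditioning: `add_time_ids` packs (original_size, crops_coords_top_left,
+ # target_size) and is fed to the UNet alongside the pooled text embeddings.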
1132
+ add_text_embeds = pooled_prompt_embeds
1133
+ if self.text_encoder_2 is None:
1134
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1135
+ else:
1136
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1137
+
1138
+ add_time_ids = self._get_add_time_ids(
1139
+ original_size,
1140
+ crops_coords_top_left,
1141
+ target_size,
1142
+ dtype=prompt_embeds.dtype,
1143
+ text_encoder_projection_dim=text_encoder_projection_dim,
1144
+ )
1145
+ if negative_original_size is not None and negative_target_size is not None:
1146
+ negative_add_time_ids = self._get_add_time_ids(
1147
+ negative_original_size,
1148
+ negative_crops_coords_top_left,
1149
+ negative_target_size,
1150
+ dtype=prompt_embeds.dtype,
1151
+ text_encoder_projection_dim=text_encoder_projection_dim,
1152
+ )
1153
+ else:
1154
+ negative_add_time_ids = add_time_ids
1155
+
1156
+ if self.do_classifier_free_guidance:
1157
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1158
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1159
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1160
+
1161
+ prompt_embeds = prompt_embeds.to(device)
1162
+ add_text_embeds = add_text_embeds.to(device)
1163
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1164
+
1165
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1166
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1167
+ ip_adapter_image,
1168
+ ip_adapter_image_embeds,
1169
+ device,
1170
+ batch_size * num_images_per_prompt,
1171
+ self.do_classifier_free_guidance,
1172
+ )
1173
+
1174
+ # 8. Denoising loop
1175
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1176
+
1177
+ # 8.1 Apply denoising_end
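+ # A `denoising_end` in (0, 1) truncates the schedule so only the first fraction of the
+ # denoising process runs here (e.g. for a base + refiner split where the refiner model
+ # finishes the remaining steps).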
1178
+ if (
1179
+ self.denoising_end is not None
1180
+ and isinstance(self.denoising_end, float)
1181
+ and self.denoising_end > 0
1182
+ and self.denoising_end < 1
1183
+ ):
1184
+ discrete_timestep_cutoff = int(
1185
+ round(
1186
+ self.scheduler.config.num_train_timesteps
1187
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1188
+ )
1189
+ )
1190
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1191
+ timesteps = timesteps[:num_inference_steps]
1192
+
1193
+ # 9. Optionally get Guidance Scale Embedding
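+ # `time_cond_proj_dim` is only set for guidance-distilled UNets (e.g. LCM-style models);
+ # for those, the guidance scale is embedded and passed as `timestep_cond` instead of
+ # doubling the batch for classifier-free guidance.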
1194
+ timestep_cond = None
1195
+ if self.unet.config.time_cond_proj_dim is not None:
1196
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1197
+ timestep_cond = self.get_guidance_scale_embedding(
1198
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1199
+ ).to(device=device, dtype=latents.dtype)
1200
+
1201
+ self._num_timesteps = len(timesteps)
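+ # NOTE: the progress bar total is 1 (rather than the usual `num_inference_steps`) because
+ # this debug version returns after comparing the first UNet prediction against ONNX below.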
1202
+ with self.progress_bar(total=1) as progress_bar:
1203
+ for i, t in enumerate(timesteps):
1204
+ if self.interrupt:
1205
+ continue
1206
+
1207
+ # expand the latents if we are doing classifier free guidance
1208
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1209
+ # latent_model_input = latents
1210
+
1211
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1212
+
1213
+ # predict the noise residual
1214
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1215
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1216
+ added_cond_kwargs["image_embeds"] = image_embeds
1217
+
1218
+ noise_pred = self.unet(
1219
+ latent_model_input,
1220
+ t,
1221
+ encoder_hidden_states=prompt_embeds,
1222
+ timestep_cond=timestep_cond,
1223
+ cross_attention_kwargs=self.cross_attention_kwargs,
1224
+ added_cond_kwargs=added_cond_kwargs,
1225
+ return_dict=False,
1226
+ )[0]
1227
+
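+ # The PyTorch UNet prediction above is kept as the reference; the UNet is then deleted
+ # (presumably to free GPU memory before loading the ONNX model), so the loop cannot
+ # continue past this first step.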
1228
+ del self.unet
1229
+ # ---------------- ONNX comparison: begin ----------------
1230
+ session = ort.InferenceSession('/home/tiennv/trang/model_unet_quantize_full_onnx_2/unet_onnx_2/model.onnx')
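+ # As written, the session is rebuilt from disk on every step and runs on onnxruntime's
+ # default providers. A more typical setup (sketch, assuming a GPU build of onnxruntime)
+ # would create it once before the denoising loop and pin it to CUDA, e.g.:
+ # session = ort.InferenceSession(
+ #     '/home/tiennv/trang/model_unet_quantize_full_onnx_2/unet_onnx_2/model.onnx',
+ #     providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+ # )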
1231
+
1232
+ if prompt_embeds.dtype == torch.float16:
1233
+ prompt_embeds = prompt_embeds.float()
1234
+ if latent_model_input.dtype == torch.float16:
1235
+ latent_model_input = latent_model_input.float()
1236
+ if t.dtype == torch.float16:
1237
+ t = t.float()
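+ # The casts above assume the ONNX graph was exported with FP32 inputs, so any FP16
+ # tensors are upcast before being handed to onnxruntime.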
1238
+ # print('prompt_embeds',prompt_embeds.dtype)
1239
+ # print('latent_model_input',latent_model_input.dtype)
1240
+ # print('t',t.dtype)
1241
+
1242
+ onnx_inputs = {
1243
+ "encoder_hidden_state": prompt_embeds.cpu().numpy(),
1244
+ "sample": latent_model_input.cpu().numpy(),
1245
+ "timestep": torch.tensor([t]).cpu().numpy(),
1246
+ }
1247
+ onnx_output = session.run(None, onnx_inputs)
1248
+ # print('onnx_output', onnx_output)
1249
+ # print('onnx_output', type(onnx_output))
1250
+ onnx_output = onnx_output[0]
1251
+ # onnx_output = onnx_output.cpu()
1252
+ # print('onnx_output', onnx_output)
1253
+ # print('onnx_output', type(onnx_output))
1254
+ noise_pred = noise_pred.cpu()
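+ # Compare the PyTorch and quantized-ONNX predictions at a tight (1e-3) and a loose (1e-1)
+ # absolute tolerance to gauge the quantization error.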
1255
+ is_close = np.allclose(noise_pred.numpy(), onnx_output, atol=1e-3)
1256
+ is_close_1 = np.allclose(noise_pred.numpy(), onnx_output, atol=1e-1)
1257
+
1258
+ print("Outputs are close:", is_close)
1259
+ print("Outputs are close:", is_close_1)
1260
+ # ---------------- ONNX comparison: end ------------------
1261
+
1262
+ # Comparison-only run: stop after the first denoising step.
1263
+ return
1264
+
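+ # Alternative serving path, kept commented out for reference: send the same three inputs
+ # to a Triton Inference Server model named 'unet' over HTTP (tritonclient) and read back
+ # the 'predict_noise' output.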
1265
+ # ---------------- Triton alternative (disabled): begin ----------------
1266
+ # inputs = []
1267
+
1268
+ # # Ensure `prompt_embeds` is FP32
1269
+ # prompt_embeds = prompt_embeds.to(dtype=torch.float32)
1270
+ # encoder_hidden_state_input = InferInput('encoder_hidden_state', prompt_embeds.shape, "FP32")
1271
+ # encoder_hidden_state_input.set_data_from_numpy(prompt_embeds.cpu().numpy().astype(np.float32))
1272
+ # inputs.append(encoder_hidden_state_input)
1273
+
1274
+ # # Ensure `latent_model_input` is FP32
1275
+ # latent_model_input = latent_model_input.to(dtype=torch.float32)
1276
+ # sample_input = InferInput('sample', latent_model_input.shape, "FP32")
1277
+ # sample_input.set_data_from_numpy(latent_model_input.cpu().numpy().astype(np.float32))
1278
+ # inputs.append(sample_input)
1279
+
1280
+ # # Ensure `t` is FP32
1281
+ # t = t.to(dtype=torch.float32)
1282
+ # timestep_input = InferInput('timestep', [1], "FP32")
1283
+ # timestep_input.set_data_from_numpy(t.unsqueeze(0).cpu().numpy().astype(np.float32))
1284
+ # inputs.append(timestep_input)
1285
+
1286
+ # outputs = [httpclient.InferRequestedOutput('predict_noise')]
1287
+
1288
+ # response = triton_client.infer(model_name='unet', inputs=inputs, outputs=outputs)
1289
+
1290
+ # noise_pred = response.as_numpy('predict_noise')
1291
+ # noise_pred = torch.from_numpy(noise_pred)
1292
+ # noise_pred = noise_pred.to(self.device)
1293
+ # ---------------- Triton alternative (disabled): end ------------------
1294
+
1295
+ # perform guidance
1296
+ if self.do_classifier_free_guidance:
1297
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1298
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1299
+
1300
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1301
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1302
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1303
+
1304
+ # compute the previous noisy sample x_t -> x_t-1
1305
+ latents_dtype = latents.dtype
1306
+ try:
1307
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1308
+ except Exception as e:
1309
+ print("Scheduler step error:", e)
1310
+ raise
1311
+
1312
+ if latents.dtype != latents_dtype:
1313
+ if torch.backends.mps.is_available():
1314
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
1315
+ latents = latents.to(latents_dtype)
1316
+
1317
+ if callback_on_step_end is not None:
1318
+ callback_kwargs = {}
1319
+ for k in callback_on_step_end_tensor_inputs:
1320
+ callback_kwargs[k] = locals()[k]
1321
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1322
+
1323
+ latents = callback_outputs.pop("latents", latents)
1324
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1325
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1326
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1327
+
1328
+ # call the callback, if provided
1329
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1330
+ progress_bar.update()
1331
+ if callback is not None and i % callback_steps == 0:
1332
+ step_idx = i // getattr(self.scheduler, "order", 1)
1333
+ callback(step_idx, t, latents)
1334
+
1335
+ if XLA_AVAILABLE:
1336
+ xm.mark_step()
1337
+
1338
+ if not output_type == "latent":
1339
+ # make sure the VAE is in float32 mode, as it overflows in float16
1340
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1341
+
1342
+ if needs_upcasting:
1343
+ self.upcast_vae()
1344
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1345
+ elif latents.dtype != self.vae.dtype:
1346
+ if torch.backends.mps.is_available():
1347
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
1348
+ self.vae = self.vae.to(latents.dtype)
1349
+
1350
+ # unscale/denormalize the latents
1351
+ # denormalize with the mean and std if available and not None
1352
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1353
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1354
+ if has_latents_mean and has_latents_std:
1355
+ latents_mean = (
1356
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1357
+ )
1358
+ latents_std = (
1359
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1360
+ )
1361
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1362
+ else:
1363
+ latents = latents / self.vae.config.scaling_factor
1364
+
1365
+ image = self.vae.decode(latents, return_dict=False)[0]
1366
+
1367
+ # cast back to fp16 if needed
1368
+ if needs_upcasting:
1369
+ self.vae.to(dtype=torch.float16)
1370
+ else:
1371
+ image = latents
1372
+
1373
+ if not output_type == "latent":
1374
+ # apply watermark if available
1375
+ if self.watermark is not None:
1376
+ image = self.watermark.apply_watermark(image)
1377
+
1378
+ image = self.image_processor.postprocess(image, output_type=output_type)
1379
+
1380
+ # Offload all models
1381
+ self.maybe_free_model_hooks()
1382
+
1383
+ if not return_dict:
1384
+ return (image,)
1385
+
1386
+ return StableDiffusionXLPipelineOutput(images=image)