boomcheng committed (verified)
Commit ab89d18 · 1 Parent(s): 83a2683

Upload hico_pipeline.py

Files changed (1)
  1. hico_pipeline.py +1277 -0
hico_pipeline.py ADDED
@@ -0,0 +1,1277 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
24
+
25
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
26
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
27
+ from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
28
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
29
+ from diffusers.schedulers import KarrasDiffusionSchedulers
30
+ from diffusers.utils import (
31
+ deprecate,
32
+ is_accelerate_available,
33
+ is_accelerate_version,
34
+ is_compiled_module,
35
+ logging,
36
+ randn_tensor,
37
+ replace_example_docstring,
38
+ )
39
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
40
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
41
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
42
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
43
+ import pdb
44
+ import time
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ EXAMPLE_DOC_STRING = """
50
+ Examples:
51
+ ```py
52
+ >>> # !pip install opencv-python transformers accelerate
53
+ >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
54
+ >>> from diffusers.utils import load_image
55
+ >>> import numpy as np
56
+ >>> import torch
57
+
58
+ >>> import cv2
59
+ >>> from PIL import Image
60
+
61
+ >>> # download an image
62
+ >>> image = load_image(
63
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
64
+ ... )
65
+ >>> image = np.array(image)
66
+
67
+ >>> # get canny image
68
+ >>> image = cv2.Canny(image, 100, 200)
69
+ >>> image = image[:, :, None]
70
+ >>> image = np.concatenate([image, image, image], axis=2)
71
+ >>> canny_image = Image.fromarray(image)
72
+
73
+ >>> # load control net and stable diffusion v1-5
74
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
75
+ >>> pipe = StableDiffusionControlNetPipeline.from_pretrained(
76
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
77
+ ... )
78
+
79
+ >>> # speed up diffusion process with faster scheduler and memory optimization
80
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
81
+ >>> # remove following line if xformers is not installed
82
+ >>> pipe.enable_xformers_memory_efficient_attention()
83
+
84
+ >>> pipe.enable_model_cpu_offload()
85
+
86
+ >>> # generate image
87
+ >>> generator = torch.manual_seed(0)
88
+ >>> image = pipe(
89
+ ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
90
+ ... ).images[0]
91
+ ```
92
+ """
93
+
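+ # A minimal usage sketch for the multi-layout pipeline defined below. It assumes a
+ # layout-aware ControlNet checkpoint and a list of per-region condition images that
+ # match `layo_prompt` one-to-one; `"<path-to-layout-controlnet>"` and `region_images`
+ # are hypothetical placeholders, and the exact `image` format depends on how the
+ # ControlNet(s) are registered with the pipeline.
+ #
+ # layout_controlnet = ControlNetModel.from_pretrained("<path-to-layout-controlnet>", torch_dtype=torch.float16)
+ # pipe = StableDiffusionControlNetMultiLayoutPipeline.from_pretrained(
+ #     "runwayml/stable-diffusion-v1-5", controlnet=layout_controlnet, torch_dtype=torch.float16
+ # ).to("cuda")
+ # result = pipe(
+ #     prompt="a person riding a horse on the beach",
+ #     layo_prompt=["a person", "a horse", "the beach"],   # one sub-prompt per layout region
+ #     image=region_images,                                # one condition image per sub-prompt
+ #     fuse_type="sum",                                    # "sum" (default) or "avg"
+ #     num_inference_steps=50,
+ # ).images[0]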
94
+
95
+ class StableDiffusionControlNetMultiLayoutPipeline(
96
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
97
+ ):
98
+ r"""
99
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
100
+
101
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
102
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
103
+
104
+ The pipeline also inherits the following loading methods:
105
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
106
+
107
+ Args:
108
+ vae ([`AutoencoderKL`]):
109
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
110
+ text_encoder ([`~transformers.CLIPTextModel`]):
111
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
112
+ tokenizer ([`~transformers.CLIPTokenizer`]):
113
+ A `CLIPTokenizer` to tokenize text.
114
+ unet ([`UNet2DConditionModel`]):
115
+ A `UNet2DConditionModel` to denoise the encoded image latents.
116
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
117
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
118
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
119
+ additional conditioning.
120
+ scheduler ([`SchedulerMixin`]):
121
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
122
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
123
+ safety_checker ([`StableDiffusionSafetyChecker`]):
124
+ Classification module that estimates whether generated images could be considered offensive or harmful.
125
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
126
+ about a model's potential harms.
127
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
128
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
129
+ """
130
+ _optional_components = ["safety_checker", "feature_extractor"]
131
+
132
+ def __init__(
133
+ self,
134
+ vae: AutoencoderKL,
135
+ text_encoder: CLIPTextModel,
136
+ tokenizer: CLIPTokenizer,
137
+ unet: UNet2DConditionModel,
138
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
139
+ scheduler: KarrasDiffusionSchedulers,
140
+ safety_checker: StableDiffusionSafetyChecker,
141
+ feature_extractor: CLIPImageProcessor,
142
+ requires_safety_checker: bool = True,
143
+ ):
144
+ super().__init__()
145
+
146
+ if safety_checker is None and requires_safety_checker:
147
+ logger.warning(
148
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
149
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
150
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
151
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
152
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
153
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
154
+ )
155
+
156
+ if safety_checker is not None and feature_extractor is None:
157
+ raise ValueError(
158
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
159
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
160
+ )
161
+
162
+ if isinstance(controlnet, (list, tuple)):
163
+ controlnet = MultiControlNetModel(controlnet)
164
+
165
+ self.register_modules(
166
+ vae=vae,
167
+ text_encoder=text_encoder,
168
+ tokenizer=tokenizer,
169
+ unet=unet,
170
+ controlnet=controlnet,
171
+ scheduler=scheduler,
172
+ safety_checker=safety_checker,
173
+ feature_extractor=feature_extractor,
174
+ )
175
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
176
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
177
+ self.control_image_processor = VaeImageProcessor(
178
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
179
+ )
180
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
181
+
182
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
183
+ def enable_vae_slicing(self):
184
+ r"""
185
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
186
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
187
+ """
188
+ self.vae.enable_slicing()
189
+
190
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
191
+ def disable_vae_slicing(self):
192
+ r"""
193
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
194
+ computing decoding in one step.
195
+ """
196
+ self.vae.disable_slicing()
197
+
198
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
199
+ def enable_vae_tiling(self):
200
+ r"""
201
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
202
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
203
+ processing larger images.
204
+ """
205
+ self.vae.enable_tiling()
206
+
207
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
208
+ def disable_vae_tiling(self):
209
+ r"""
210
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
211
+ computing decoding in one step.
212
+ """
213
+ self.vae.disable_tiling()
214
+
215
+ def enable_model_cpu_offload(self, gpu_id=0):
216
+ r"""
217
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
218
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
219
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
220
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
221
+ """
222
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
223
+ from accelerate import cpu_offload_with_hook
224
+ else:
225
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
226
+
227
+ device = torch.device(f"cuda:{gpu_id}")
228
+
229
+ hook = None
230
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
231
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
232
+
233
+ if self.safety_checker is not None:
234
+ # the safety checker can offload the vae again
235
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
236
+
237
+ # the ControlNet hook has to be offloaded manually as it alternates with the unet
238
+ cpu_offload_with_hook(self.controlnet, device)
239
+
240
+ # We'll offload the last model manually.
241
+ self.final_offload_hook = hook
242
+
243
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
244
+ def _encode_prompt(
245
+ self,
246
+ prompt,
247
+ device,
248
+ num_images_per_prompt,
249
+ do_classifier_free_guidance,
250
+ negative_prompt=None,
251
+ prompt_embeds: Optional[torch.FloatTensor] = None,
252
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
253
+ lora_scale: Optional[float] = None,
254
+ ):
255
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
256
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
257
+
258
+ prompt_embeds_tuple = self.encode_prompt(
259
+ prompt=prompt,
260
+ device=device,
261
+ num_images_per_prompt=num_images_per_prompt,
262
+ do_classifier_free_guidance=do_classifier_free_guidance,
263
+ negative_prompt=negative_prompt,
264
+ prompt_embeds=prompt_embeds,
265
+ negative_prompt_embeds=negative_prompt_embeds,
266
+ lora_scale=lora_scale,
267
+ )
268
+
269
+ # concatenate for backwards comp
270
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
271
+
272
+ return prompt_embeds
273
+
274
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
275
+ def encode_prompt(
276
+ self,
277
+ prompt,
278
+ device,
279
+ num_images_per_prompt,
280
+ do_classifier_free_guidance,
281
+ negative_prompt=None,
282
+ prompt_embeds: Optional[torch.FloatTensor] = None,
283
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
284
+ lora_scale: Optional[float] = None,
285
+ ):
286
+ r"""
287
+ Encodes the prompt into text encoder hidden states.
288
+
289
+ Args:
290
+ prompt (`str` or `List[str]`, *optional*):
291
+ prompt to be encoded
292
+ device: (`torch.device`):
293
+ torch device
294
+ num_images_per_prompt (`int`):
295
+ number of images that should be generated per prompt
296
+ do_classifier_free_guidance (`bool`):
297
+ whether to use classifier free guidance or not
298
+ negative_prompt (`str` or `List[str]`, *optional*):
299
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
300
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
301
+ less than `1`).
302
+ prompt_embeds (`torch.FloatTensor`, *optional*):
303
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
304
+ provided, text embeddings will be generated from `prompt` input argument.
305
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
306
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
307
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
308
+ argument.
309
+ lora_scale (`float`, *optional*):
310
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
311
+ """
312
+ # set lora scale so that monkey patched LoRA
313
+ # function of text encoder can correctly access it
314
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
315
+ self._lora_scale = lora_scale
316
+
317
+ # dynamically adjust the LoRA scale
318
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
319
+
320
+ if prompt is not None and isinstance(prompt, str):
321
+ batch_size = 1
322
+ elif prompt is not None and isinstance(prompt, list):
323
+ batch_size = len(prompt)
324
+ else:
325
+ batch_size = prompt_embeds.shape[0]
326
+
327
+ if prompt_embeds is None:
328
+ # textual inversion: process multi-vector tokens if necessary
329
+ if isinstance(self, TextualInversionLoaderMixin):
330
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
331
+
332
+ text_inputs = self.tokenizer(
333
+ prompt,
334
+ padding="max_length",
335
+ max_length=self.tokenizer.model_max_length,
336
+ truncation=True,
337
+ return_tensors="pt",
338
+ )
339
+ text_input_ids = text_inputs.input_ids
340
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
341
+
342
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
343
+ text_input_ids, untruncated_ids
344
+ ):
345
+ removed_text = self.tokenizer.batch_decode(
346
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
347
+ )
348
+ logger.warning(
349
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
350
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
351
+ )
352
+
353
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
354
+ attention_mask = text_inputs.attention_mask.to(device)
355
+ else:
356
+ attention_mask = None
357
+
358
+ prompt_embeds = self.text_encoder(
359
+ text_input_ids.to(device),
360
+ attention_mask=attention_mask,
361
+ )
362
+ prompt_embeds = prompt_embeds[0]
363
+
364
+ if self.text_encoder is not None:
365
+ prompt_embeds_dtype = self.text_encoder.dtype
366
+ elif self.unet is not None:
367
+ prompt_embeds_dtype = self.unet.dtype
368
+ else:
369
+ prompt_embeds_dtype = prompt_embeds.dtype
370
+
371
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
372
+
373
+ bs_embed, seq_len, _ = prompt_embeds.shape
374
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
375
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
376
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
377
+
378
+ # get unconditional embeddings for classifier free guidance
379
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
380
+ uncond_tokens: List[str]
381
+ if negative_prompt is None:
382
+ uncond_tokens = [""] * batch_size
383
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
384
+ raise TypeError(
385
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
386
+ f" {type(prompt)}."
387
+ )
388
+ elif isinstance(negative_prompt, str):
389
+ uncond_tokens = [negative_prompt]
390
+ elif batch_size != len(negative_prompt):
391
+ raise ValueError(
392
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
393
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
394
+ " the batch size of `prompt`."
395
+ )
396
+ else:
397
+ uncond_tokens = negative_prompt
398
+
399
+ # textual inversion: process multi-vector tokens if necessary
400
+ if isinstance(self, TextualInversionLoaderMixin):
401
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
402
+
403
+ max_length = prompt_embeds.shape[1]
404
+ uncond_input = self.tokenizer(
405
+ uncond_tokens,
406
+ padding="max_length",
407
+ max_length=max_length,
408
+ truncation=True,
409
+ return_tensors="pt",
410
+ )
411
+
412
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
413
+ attention_mask = uncond_input.attention_mask.to(device)
414
+ else:
415
+ attention_mask = None
416
+
417
+ negative_prompt_embeds = self.text_encoder(
418
+ uncond_input.input_ids.to(device),
419
+ attention_mask=attention_mask,
420
+ )
421
+ negative_prompt_embeds = negative_prompt_embeds[0]
422
+
423
+ if do_classifier_free_guidance:
424
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
425
+ seq_len = negative_prompt_embeds.shape[1]
426
+
427
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
428
+
429
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
430
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
431
+
432
+ return prompt_embeds, negative_prompt_embeds
433
+
434
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
435
+ def run_safety_checker(self, image, device, dtype):
436
+ if self.safety_checker is None:
437
+ has_nsfw_concept = None
438
+ else:
439
+ if torch.is_tensor(image):
440
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
441
+ else:
442
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
443
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
444
+ image, has_nsfw_concept = self.safety_checker(
445
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
446
+ )
447
+ return image, has_nsfw_concept
448
+
449
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
450
+ def decode_latents(self, latents):
451
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
452
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
453
+
454
+ latents = 1 / self.vae.config.scaling_factor * latents
455
+ image = self.vae.decode(latents, return_dict=False)[0]
456
+ image = (image / 2 + 0.5).clamp(0, 1)
457
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
458
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
459
+ return image
460
+
461
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
462
+ def prepare_extra_step_kwargs(self, generator, eta):
463
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
464
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
465
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
466
+ # and should be between [0, 1]
467
+
468
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
469
+ extra_step_kwargs = {}
470
+ if accepts_eta:
471
+ extra_step_kwargs["eta"] = eta
472
+
473
+ # check if the scheduler accepts generator
474
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
475
+ if accepts_generator:
476
+ extra_step_kwargs["generator"] = generator
477
+ return extra_step_kwargs
478
+
479
+ def check_inputs(
480
+ self,
481
+ prompt,
482
+ image,
483
+ callback_steps,
484
+ negative_prompt=None,
485
+ prompt_embeds=None,
486
+ negative_prompt_embeds=None,
487
+ controlnet_conditioning_scale=1.0,
488
+ control_guidance_start=0.0,
489
+ control_guidance_end=1.0,
490
+ ):
491
+ if (callback_steps is None) or (
492
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
493
+ ):
494
+ raise ValueError(
495
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
496
+ f" {type(callback_steps)}."
497
+ )
498
+
499
+ if prompt is not None and prompt_embeds is not None:
500
+ raise ValueError(
501
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
502
+ " only forward one of the two."
503
+ )
504
+ elif prompt is None and prompt_embeds is None:
505
+ raise ValueError(
506
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
507
+ )
508
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
509
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
510
+
511
+ if negative_prompt is not None and negative_prompt_embeds is not None:
512
+ raise ValueError(
513
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
514
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
515
+ )
516
+
517
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
518
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
519
+ raise ValueError(
520
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
521
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
522
+ f" {negative_prompt_embeds.shape}."
523
+ )
524
+
525
+ # `prompt` needs more sophisticated handling when there are multiple
526
+ # conditionings.
527
+ if isinstance(self.controlnet, MultiControlNetModel):
528
+ if isinstance(prompt, list):
529
+ logger.warning(
530
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
531
+ " prompts. The conditionings will be fixed across the prompts."
532
+ )
533
+
534
+ # Check `image`
535
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
536
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
537
+ )
538
+ if (
539
+ isinstance(self.controlnet, ControlNetModel)
540
+ or is_compiled
541
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
542
+ ):
543
+ self.check_image(image, prompt, prompt_embeds)
544
+ elif (
545
+ isinstance(self.controlnet, MultiControlNetModel)
546
+ or is_compiled
547
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
548
+ ):
549
+ if not isinstance(image, list):
550
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
551
+
552
+ # When `image` is a nested list:
553
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
554
+ elif any(isinstance(i, list) for i in image):
555
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
556
+ #elif len(image) != len(self.controlnet.nets):
557
+ # raise ValueError(
558
+ # f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
559
+ # )
560
+
561
+ for image_ in image:
562
+ self.check_image(image_, prompt, prompt_embeds)
563
+ else:
564
+ assert False
565
+
566
+ # Check `controlnet_conditioning_scale`
567
+ if (
568
+ isinstance(self.controlnet, ControlNetModel)
569
+ or is_compiled
570
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
571
+ ):
572
+ if not isinstance(controlnet_conditioning_scale, float):
573
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
574
+ elif (
575
+ isinstance(self.controlnet, MultiControlNetModel)
576
+ or is_compiled
577
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
578
+ ):
579
+ if isinstance(controlnet_conditioning_scale, list):
580
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
581
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
582
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
583
+ self.controlnet.nets
584
+ ):
585
+ raise ValueError(
586
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
587
+ " the same length as the number of controlnets"
588
+ )
589
+ else:
590
+ assert False
591
+
592
+ if not isinstance(control_guidance_start, (tuple, list)):
593
+ control_guidance_start = [control_guidance_start]
594
+
595
+ if not isinstance(control_guidance_end, (tuple, list)):
596
+ control_guidance_end = [control_guidance_end]
597
+
598
+ if len(control_guidance_start) != len(control_guidance_end):
599
+ raise ValueError(
600
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
601
+ )
602
+
603
+ if isinstance(self.controlnet, MultiControlNetModel):
604
+ if len(control_guidance_start) != len(self.controlnet.nets):
605
+ raise ValueError(
606
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
607
+ )
608
+
609
+ for start, end in zip(control_guidance_start, control_guidance_end):
610
+ if start >= end:
611
+ raise ValueError(
612
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
613
+ )
614
+ if start < 0.0:
615
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
616
+ if end > 1.0:
617
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
618
+
619
+ def check_image(self, image, prompt, prompt_embeds):
620
+ image_is_pil = isinstance(image, PIL.Image.Image)
621
+ image_is_tensor = isinstance(image, torch.Tensor)
622
+ image_is_np = isinstance(image, np.ndarray)
623
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
624
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
625
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
626
+
627
+ if (
628
+ not image_is_pil
629
+ and not image_is_tensor
630
+ and not image_is_np
631
+ and not image_is_pil_list
632
+ and not image_is_tensor_list
633
+ and not image_is_np_list
634
+ ):
635
+ raise TypeError(
636
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
637
+ )
638
+
639
+ if image_is_pil:
640
+ image_batch_size = 1
641
+ else:
642
+ image_batch_size = len(image)
643
+
644
+ if prompt is not None and isinstance(prompt, str):
645
+ prompt_batch_size = 1
646
+ elif prompt is not None and isinstance(prompt, list):
647
+ prompt_batch_size = len(prompt)
648
+ elif prompt_embeds is not None:
649
+ prompt_batch_size = prompt_embeds.shape[0]
650
+
651
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
652
+ raise ValueError(
653
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
654
+ )
655
+
656
+ def prepare_image(
657
+ self,
658
+ image,
659
+ width,
660
+ height,
661
+ batch_size,
662
+ num_images_per_prompt,
663
+ device,
664
+ dtype,
665
+ do_classifier_free_guidance=False,
666
+ guess_mode=False,
667
+ ):
668
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
669
+ image_batch_size = image.shape[0]
670
+
671
+ if image_batch_size == 1:
672
+ repeat_by = batch_size
673
+ else:
674
+ # image batch size is the same as prompt batch size
675
+ repeat_by = num_images_per_prompt
676
+
677
+ image = image.repeat_interleave(repeat_by, dim=0)
678
+
679
+ image = image.to(device=device, dtype=dtype)
680
+
681
+ if do_classifier_free_guidance and not guess_mode:
682
+ image = torch.cat([image] * 2)
683
+
684
+ return image
685
+
686
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
687
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
688
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
689
+ if isinstance(generator, list) and len(generator) != batch_size:
690
+ raise ValueError(
691
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
692
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
693
+ )
694
+
695
+ if latents is None:
696
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
697
+ else:
698
+ latents = latents.to(device)
699
+ #torch.save(latents, '/home/jovyan/myh-data-ceph-0/code/layout_chengbo/diffusers-layout/examples/controlnet/nogen_or.pt')
700
+ # scale the initial noise by the standard deviation required by the scheduler
701
+ latents = latents * self.scheduler.init_noise_sigma
702
+ #torch.save(latents, '/home/jovyan/myh-data-ceph-0/code/layout_chengbo/diffusers-layout/examples/controlnet/nogen_or_scale.pt')
703
+ return latents
704
+
705
+ @torch.no_grad()
706
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
707
+ def __call__(
708
+ self,
709
+ prompt: Union[str, List[str]] = None,
710
+ layo_prompt: Union[str, List[str]] = None,
711
+ #layo_cond: Union[torch.FloatTensor] = None,
712
+ fuse_type: str = 'sum',
713
+ image: PipelineImageInput = None,
714
+ height: Optional[int] = None,
715
+ width: Optional[int] = None,
716
+ num_inference_steps: int = 50,
717
+ guidance_scale: float = 7.5,
718
+ negative_prompt: Optional[Union[str, List[str]]] = None,
719
+ num_images_per_prompt: Optional[int] = 1,
720
+ eta: float = 0.0,
721
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
722
+ latents: Optional[torch.FloatTensor] = None,
723
+ prompt_embeds: Optional[torch.FloatTensor] = None,
724
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
725
+ output_type: Optional[str] = "pil",
726
+ return_dict: bool = True,
727
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
728
+ callback_steps: int = 1,
729
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
730
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
731
+ guess_mode: bool = False,
732
+ control_guidance_start: Union[float, List[float]] = 0.0,
733
+ control_guidance_end: Union[float, List[float]] = 1.0,
734
+ ):
735
+ r"""
736
+ The call function to the pipeline for generation.
737
+
738
+ Args:
739
+ prompt (`str` or `List[str]`, *optional*):
740
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
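+ layo_prompt (`List[str]`):
+ Per-layout (per-region) sub-prompts. Each entry is encoded separately and paired with the
+ corresponding entry of `image` when running the ControlNet, so the two lists are expected to
+ have the same length.
+ fuse_type (`str`, *optional*, defaults to `"sum"`):
+ How the per-region ControlNet residuals are fused: `"sum"` adds them across regions, while
+ `"avg"` additionally divides by the number of regions. (A `"mask"` branch collects the
+ per-region features but is not consumed by the active code path.)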
741
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
742
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
743
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
744
+ specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
745
+ accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
746
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
747
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
748
+ input to a single ControlNet.
749
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
750
+ The height in pixels of the generated image.
751
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
752
+ The width in pixels of the generated image.
753
+ num_inference_steps (`int`, *optional*, defaults to 50):
754
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
755
+ expense of slower inference.
756
+ guidance_scale (`float`, *optional*, defaults to 7.5):
757
+ A higher guidance scale value encourages the model to generate images closely linked to the text
758
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
759
+ negative_prompt (`str` or `List[str]`, *optional*):
760
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
761
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
762
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
763
+ The number of images to generate per prompt.
764
+ eta (`float`, *optional*, defaults to 0.0):
765
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
766
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
767
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
768
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
769
+ generation deterministic.
770
+ latents (`torch.FloatTensor`, *optional*):
771
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
772
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
773
+ tensor is generated by sampling using the supplied random `generator`.
774
+ prompt_embeds (`torch.FloatTensor`, *optional*):
775
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
776
+ provided, text embeddings are generated from the `prompt` input argument.
777
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
778
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
779
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
780
+ output_type (`str`, *optional*, defaults to `"pil"`):
781
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
782
+ return_dict (`bool`, *optional*, defaults to `True`):
783
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
784
+ plain tuple.
785
+ callback (`Callable`, *optional*):
786
+ A function that calls every `callback_steps` steps during inference. The function is called with the
787
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
788
+ callback_steps (`int`, *optional*, defaults to 1):
789
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
790
+ every step.
791
+ cross_attention_kwargs (`dict`, *optional*):
792
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
793
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
794
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
795
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
796
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
797
+ the corresponding scale as a list.
798
+ guess_mode (`bool`, *optional*, defaults to `False`):
799
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
800
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
801
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
802
+ The percentage of total steps at which the ControlNet starts applying.
803
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
804
+ The percentage of total steps at which the ControlNet stops applying.
805
+
806
+ Examples:
807
+
808
+ Returns:
809
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
810
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
811
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
812
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
813
+ "not-safe-for-work" (nsfw) content.
814
+ """
815
+ #pdb.set_trace()
816
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
817
+
818
+ # align format for control guidance
819
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
820
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
821
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
822
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
823
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
824
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
825
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
826
+ control_guidance_end
827
+ ]
828
+
829
+ # 1. Check inputs. Raise error if not correct
830
+ self.check_inputs(
831
+ prompt,
832
+ image,
833
+ callback_steps,
834
+ negative_prompt,
835
+ prompt_embeds,
836
+ negative_prompt_embeds,
837
+ controlnet_conditioning_scale,
838
+ control_guidance_start,
839
+ control_guidance_end,
840
+ )
841
+ #pdb.set_trace()
842
+
843
+ # 2. Define call parameters
844
+ if prompt is not None and isinstance(prompt, str):
845
+ batch_size = 1
846
+ elif prompt is not None and isinstance(prompt, list):
847
+ batch_size = len(prompt)
848
+ else:
849
+ batch_size = prompt_embeds.shape[0]
850
+
851
+ device = self._execution_device
852
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
853
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
854
+ # corresponds to doing no classifier free guidance.
855
+ do_classifier_free_guidance = guidance_scale >= 1.0
856
+
857
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
858
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
859
+
860
+ global_pool_conditions = (
861
+ controlnet.config.global_pool_conditions
862
+ if isinstance(controlnet, ControlNetModel)
863
+ else controlnet.nets[0].config.global_pool_conditions
864
+ )
865
+ guess_mode = guess_mode or global_pool_conditions
866
+
867
+ # 3. Encode input prompt
868
+ text_encoder_lora_scale = (
869
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
870
+ )
871
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
872
+ prompt,
873
+ device,
874
+ num_images_per_prompt,
875
+ do_classifier_free_guidance,
876
+ negative_prompt,
877
+ prompt_embeds=prompt_embeds,
878
+ negative_prompt_embeds=negative_prompt_embeds,
879
+ lora_scale=text_encoder_lora_scale,
880
+ )
881
+ # For classifier free guidance, we need to do two forward passes.
882
+ # Here we concatenate the unconditional and text embeddings into a single batch
883
+ # to avoid doing two forward passes
884
+
885
+ ################################# modify boom ##############################
886
+ #pdb.set_trace()
887
+ # 3-1. Encode the per-layout sub-prompts
888
+ list_prompt_embeds = []
889
+ for dot_prompt in layo_prompt:
890
+ text_inputs = self.tokenizer(
891
+ dot_prompt,
892
+ padding="max_length",
893
+ max_length=self.tokenizer.model_max_length,
894
+ truncation=True,
895
+ return_tensors="pt",
896
+ )
897
+ text_input_ids = text_inputs.input_ids
898
+
899
+ dot_prompt_embeds = self.text_encoder(
900
+ text_input_ids.to(device),
901
+ )
902
+ dot_prompt_embeds = dot_prompt_embeds[0]
903
+ list_prompt_embeds.append(dot_prompt_embeds)
904
+ bs_prompt_embeds = torch.stack(list_prompt_embeds).squeeze() # bs, 77, 768
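+ # NOTE: bs_prompt_embeds is only referenced by the commented-out batched variant further below;
+ # the active per-region loop uses list_prompt_embeds directly.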
905
+ # t1 = time.time()
906
+ # text_inputs = self.tokenizer(
907
+ # layo_prompt,
908
+ # padding="max_length",
909
+ # max_length=self.tokenizer.model_max_length,
910
+ # truncation=True,
911
+ # return_tensors="pt",
912
+ # )
913
+ # text_input_ids = text_inputs.input_ids.to(device)
914
+ # bs_prompt_embeds = self.text_encoder(text_input_ids)[0]
915
+ # t2 = time.time()
916
+
917
+
918
+ if do_classifier_free_guidance:
919
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
920
+
921
+ # 4. Prepare image
922
+ if isinstance(controlnet, ControlNetModel):
923
+ image = self.prepare_image(
924
+ image=image,
925
+ width=width,
926
+ height=height,
927
+ batch_size=batch_size * num_images_per_prompt,
928
+ num_images_per_prompt=num_images_per_prompt,
929
+ device=device,
930
+ dtype=controlnet.dtype,
931
+ do_classifier_free_guidance=do_classifier_free_guidance,
932
+ guess_mode=guess_mode,
933
+ )
934
+ height, width = image.shape[-2:]
935
+ elif isinstance(controlnet, MultiControlNetModel):
936
+ images = []
937
+
938
+ for image_ in image:
939
+ image_ = self.prepare_image(
940
+ image=image_,
941
+ width=width,
942
+ height=height,
943
+ batch_size=batch_size * num_images_per_prompt,
944
+ num_images_per_prompt=num_images_per_prompt,
945
+ device=device,
946
+ dtype=controlnet.dtype,
947
+ do_classifier_free_guidance=do_classifier_free_guidance,
948
+ guess_mode=guess_mode,
949
+ )
950
+
951
+ images.append(image_)
952
+
953
+ image = images
954
+ height, width = image[0].shape[-2:]
955
+ else:
956
+ assert False
957
+
958
+ # 5. Prepare timesteps
959
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
960
+ timesteps = self.scheduler.timesteps
961
+
962
+ # 6. Prepare latent variables
963
+ num_channels_latents = self.unet.config.in_channels
964
+ latents = self.prepare_latents(
965
+ batch_size * num_images_per_prompt,
966
+ num_channels_latents,
967
+ height,
968
+ width,
969
+ prompt_embeds.dtype,
970
+ device,
971
+ generator,
972
+ latents,
973
+ )
974
+
975
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
976
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
977
+
978
+ # 7.1 Create tensor stating which controlnets to keep
979
+ controlnet_keep = []
980
+ for i in range(len(timesteps)):
981
+ keeps = [
982
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
983
+ for s, e in zip(control_guidance_start, control_guidance_end)
984
+ ]
985
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
986
+
987
+ def fuse_mask_single_block(f_mask, f_mid_block_res_sample):
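+ # Zeroes each region's feature map outside that region's boolean mask, then sums the masked
+ # maps over regions. Defined for the "mask" fuse path but not invoked in the code below.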
988
+ # [4, 8, 8], [4, 1280, 8, 8]
989
+ fus_feat = []
990
+ for mii in range(len(f_mask)):
991
+ mask_block = torch.masked_fill(f_mid_block_res_sample[mii], ~f_mask[mii], 0)
992
+ fus_feat.append(mask_block)
993
+ mask_fus = torch.sum(torch.stack(fus_feat), dim=0) # [1280, 8, 8]
994
+ return mask_fus
995
+
996
+ def fuse_mask_down(f_mask, f_down_block_res_samples):
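+ # For every down-block feature map, downsamples the region masks to the feature resolution via
+ # strided slicing and applies fuse_mask_single_block; like the helper above, it is not invoked
+ # in the code below.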
997
+ # 12, [10, 320, 64, 64] -> 12, [320, 64, 64]
998
+ fus_feat = []
999
+ size_mask = f_mask.shape[-1]
1000
+ for ii in range(len(f_down_block_res_samples)):
1001
+ dot_down_block_res_samples = f_down_block_res_samples[ii]
1002
+ size_dot = dot_down_block_res_samples.shape[-1]
1003
+ bins = int(size_mask / size_dot)
1004
+ dot_mask = f_mask[:,::bins,::bins]
1005
+ dot_fuse_block = fuse_mask_single_block(dot_mask, dot_down_block_res_samples)
1006
+ fus_feat.append(dot_fuse_block)
1007
+ return fus_feat
1008
+
1009
+ #pdb.set_trace()
1010
+ # 8. Denoising loop
1011
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1012
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1013
+ for i, t in enumerate(timesteps):
1014
+ # expand the latents if we are doing classifier free guidance
1015
+
1016
+
1017
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1018
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1019
+
1020
+ # controlnet(s) inference
1021
+ if guess_mode and do_classifier_free_guidance:
1022
+ # Infer ControlNet only for the conditional batch.
1023
+ control_model_input = latents # [1, 4, 64, 64]
1024
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1025
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1026
+
1027
+ else:
1028
+ control_model_input = latent_model_input # [2, 4, 64, 64]
1029
+ controlnet_prompt_embeds = prompt_embeds # [2, 77, 768]
1030
+
1031
+
1032
+
1033
+ if isinstance(controlnet_keep[i], list):
1034
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1035
+ else:
1036
+ controlnet_cond_scale = controlnet_conditioning_scale
1037
+ if isinstance(controlnet_cond_scale, list):
1038
+ controlnet_cond_scale = controlnet_cond_scale[0]
1039
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1040
+
1041
+ """
1042
+ down_samples, mid_sample = self.controlnet(
1043
+ control_model_input, # [2, 4, 64, 64]
1044
+ t,
1045
+ encoder_hidden_states=controlnet_prompt_embeds, # [2, 77, 768]
1046
+ controlnet_cond=cond_image, # [2, 3, 512, 512]
1047
+ conditioning_scale=cond_scale,
1048
+ guess_mode=guess_mode,
1049
+ return_dict=False,
1050
+ )
1051
+ ############# save ############
1052
+ #torch.save(image[jj], "dm_image_%s_%d.pt" % (i, jj))
1053
+ #torch.save(mid_sample, "dm_mid_%s_%d.pt" % (i, jj))
1054
+ #torch.save(down_samples, "dm_down_%s_%d.pt" % (i, jj))
1055
+
1056
+ if fuse_type == "mask":
1057
+ fuse_down_samples.append(down_samples)
1058
+ fuse_mid_samples.append(mid_sample)
1059
+
1060
+ if jj == 0:
1061
+ down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
1062
+ else:
1063
+ down_block_res_samples = [
1064
+ samples_prev + samples_curr
1065
+ for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
1066
+ ]
1067
+ mid_block_res_sample += mid_sample
1068
+ """
1069
+
1070
+ #pdb.set_trace()
1071
+ if True:
1072
+ # inference-time: per-region (single) ControlNet passes
1073
+ fuse_down_samples = []
1074
+ fuse_mid_samples = []
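+ # For each layout region jj: pair its sub-prompt embedding with the shared negative embedding,
+ # run the ControlNet on that region's condition image, and accumulate the resulting down/mid
+ # residuals across regions (summed here; optionally averaged below when fuse_type == "avg").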
1075
+ for jj in range(len(image)):
1076
+ dot_prompt_embeds = list_prompt_embeds[jj]
1077
+ dot_prompt_embeds = torch.cat([negative_prompt_embeds, dot_prompt_embeds])
1078
+ controlnet_prompt_embeds = dot_prompt_embeds
1079
+
1080
+ cond_image = image[jj]
1081
+ #down_block_res_samples, mid_block_res_sample = self.controlnet(
1082
+ down_samples, mid_sample = self.controlnet(
1083
+ control_model_input, # [2, 4, 64, 64]
1084
+ t,
1085
+ encoder_hidden_states=controlnet_prompt_embeds, # [2, 77, 768]
1086
+ controlnet_cond=cond_image, # [2, 3, 512, 512]
1087
+ conditioning_scale=cond_scale,
1088
+ guess_mode=guess_mode,
1089
+ return_dict=False,
1090
+ )
1091
+ # ############# save ############
1092
+ # #torch.save(image[jj], "dm_image_%s_%d.pt" % (i, jj))
1093
+ # #torch.save(mid_sample, "dm_mid_%s_%d.pt" % (i, jj))
1094
+ # #torch.save(down_samples, "dm_down_%s_%d.pt" % (i, jj))
1095
+
1096
+ if fuse_type == "mask":
1097
+ fuse_down_samples.append(down_samples)
1098
+ fuse_mid_samples.append(mid_sample)
1099
+
1100
+ if jj == 0:
1101
+ down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
1102
+ else:
1103
+ down_block_res_samples = [
1104
+ samples_prev + samples_curr
1105
+ for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
1106
+ ]
1107
+ mid_block_res_sample += mid_sample
1108
+ if False:
1109
+ # inference-time: batched ControlNet pass (disabled)
1110
+ # BNS = len(image)
1111
+ # BNS_control_model_input = torch.repeat_interleave(control_model_input, repeats=BNS, dim=0)
1112
+ # BNS_controlnet_prompt_embeds = torch.repeat_interleave(controlnet_prompt_embeds, repeats=BNS, dim=0)
1113
+ # BNS_cond_image = torch.repeat_interleave(cond_image, repeats=BNS, dim=0)
1114
+
1115
+ # down_samples, mid_sample = self.controlnet(
1116
+ # BNS_control_model_input, # [2, 4, 64, 64]
1117
+ # t,
1118
+ # encoder_hidden_states=BNS_controlnet_prompt_embeds, # [2, 77, 768]
1119
+ # controlnet_cond=BNS_cond_image, # [2, 3, 512, 512]
1120
+ # conditioning_scale=cond_scale,
1121
+ # guess_mode=guess_mode,
1122
+ # return_dict=False,
1123
+ # )
1124
+
1125
+
1126
+ # negative_mid_sample = torch.sum(mid_sample[::2], dim=0, keepdim=True)
1127
+ # positive_mid_sample = torch.sum(mid_sample[1::2], dim=0, keepdim=True)
1128
+ # negative_down_samples = tuple(torch.sum(x[::2], dim=0, keepdim=True) for x in down_samples)
1129
+ # positive_down_samples = tuple(torch.sum(x[1::2], dim=0, keepdim=True) for x in down_samples)
1130
+ # mid_block_res_sample = torch.cat((negative_mid_sample, positive_mid_sample), dim=0)
1131
+
1132
+ # down_block_res_samples = tuple(torch.cat((neg, pos), dim=0) for neg, pos in zip(negative_down_samples, positive_down_samples))
1133
+ BNS = len(image)
1134
+
1135
+ dot_prompt_embeds_batch = []
1136
+ cond_images_batch = []
1137
+
1138
+ # cond_images_batch_n = []
1139
+ # cond_images_batch_p = []
1140
+ for jj in range(BNS):
1141
+ dot_prompt_embeds = list_prompt_embeds[jj]
1142
+ dot_prompt_embeds = torch.cat([negative_prompt_embeds, dot_prompt_embeds], dim=0)
1143
+ dot_prompt_embeds_batch.append(dot_prompt_embeds)
1144
+
1145
+ cond_image = image[jj]
1146
+ cond_images_batch.append(cond_image)
1147
+
1148
+ # cond_image_n = cond_image[0].unsqueeze(0)
1149
+ # cond_image_p = cond_image[1].unsqueeze(0)
1150
+ # cond_images_batch_n.append(cond_image_n)
1151
+ # cond_images_batch_p.append(cond_image_p)
1152
+
1153
+ # cond_images_batch_n = torch.cat(cond_images_batch_n,dim=0)
1154
+ # cond_images_batch_p = torch.cat(cond_images_batch_p,dim=0)
1155
+ # cond_images_batch= torch.cat((cond_images_batch_n,cond_images_batch_p),dim=0)
1156
+
1157
+
1158
+ dot_prompt_embeds_batch = torch.cat(dot_prompt_embeds_batch,dim=0) # [21*2, 77, 768]
1159
+ cond_images_batch = torch.cat(cond_images_batch,dim=0) # [21*2, 3, 512, 512]
1160
+ control_model_input = torch.repeat_interleave(control_model_input, repeats=BNS, dim=0)
1161
+
1162
+ # negative_prompt_embeds_test = torch.repeat_interleave(negative_prompt_embeds, repeats=BNS, dim=0)
1163
+ # dot_prompt_embeds_batch = torch.cat((negative_prompt_embeds_test,bs_prompt_embeds),dim=0)
1164
+
1165
+ down_samples, mid_sample = self.controlnet(
1166
+ control_model_input,
1167
+ t,
1168
+ encoder_hidden_states=dot_prompt_embeds_batch,
1169
+ controlnet_cond=cond_images_batch,
1170
+ conditioning_scale=cond_scale,
1171
+ guess_mode=guess_mode,
1172
+ return_dict=False,
1173
+ )
1174
+
1175
+
1176
+ # Option 1: split into per-region (uncond, cond) pairs and sum across regions
1177
+
1178
+ mid_block_res_sample = sum(torch.split(mid_sample, 2, dim=0))
1179
+ down_block_res_samples = [sum(torch.split(sample, 2, dim=0)) for sample in down_samples]
1180
+
1181
+ # negative_mid_sample = torch.sum(mid_sample[::2], dim=0, keepdim=True)
1182
+ # positive_mid_sample = torch.sum(mid_sample[1::2], dim=0, keepdim=True)
1183
+ # negative_down_samples = tuple(torch.sum(x[::2], dim=0, keepdim=True) for x in down_samples)
1184
+ # positive_down_samples = tuple(torch.sum(x[1::2], dim=0, keepdim=True) for x in down_samples)
1185
+ # mid_block_res_sample = torch.cat((negative_mid_sample, positive_mid_sample), dim=0)
1186
+ # down_block_res_samples = tuple(torch.cat((neg, pos), dim=0) for neg, pos in zip(negative_down_samples, positive_down_samples))
1187
+
1188
+ # Option 2: take the first and second halves of mid_sample and sum each half separately
1189
+ # half_size = mid_sample.size(0) // 2
1190
+ # negative_mid_sample = torch.sum(mid_sample[:half_size], dim=0, keepdim=True)
1191
+ # positive_mid_sample = torch.sum(mid_sample[half_size:], dim=0, keepdim=True)
1192
+
1193
+ # # For each tensor in down_samples, take the first and second halves and sum each separately
1194
+ # negative_down_samples = tuple(torch.sum(x[:half_size], dim=0, keepdim=True) for x in down_samples)
1195
+ # positive_down_samples = tuple(torch.sum(x[half_size:], dim=0, keepdim=True) for x in down_samples)
1196
+
1197
+ # # Concatenate the negative and positive results along dim 0
1198
+ # mid_block_res_sample = torch.cat((negative_mid_sample, positive_mid_sample), dim=0)
1199
+ # down_block_res_samples = tuple(torch.cat((neg, pos), dim=0) for neg, pos in zip(negative_down_samples, positive_down_samples))
1200
+
1201
+
1202
+
1203
+
1204
+ if fuse_type == "avg":
1205
+ mid_block_res_sample = mid_block_res_sample / len(image) # [2, 1280, 8, 8]
1206
+ down_block_res_samples = [d/len(image) for d in down_block_res_samples] # 12, [[2, 320, 64, 64], ...]
1207
+ else:
1208
+ # sum
1209
+ pass
1210
+
1211
+ # t4 = time.time()
1212
+ if guess_mode and do_classifier_free_guidance:
1213
+
1214
+ # ControlNet was inferred only for the conditional batch.
1215
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1216
+ # add 0 to the unconditional batch to keep it unchanged.
1217
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1218
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1219
+
1220
+ # t5 = time.time()
1221
+ # predict the noise residual
1222
+
1223
+ noise_pred = self.unet(
1224
+ latent_model_input,
1225
+ t,
1226
+ encoder_hidden_states=prompt_embeds,
1227
+ cross_attention_kwargs=cross_attention_kwargs,
1228
+ down_block_additional_residuals=down_block_res_samples,
1229
+ mid_block_additional_residual=mid_block_res_sample,
1230
+ return_dict=False,
1231
+ )[0]
1232
+
1233
+ # perform guidance
1234
+ if do_classifier_free_guidance:
1235
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1236
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1237
+
1238
+ # compute the previous noisy sample x_t -> x_t-1
1239
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1240
+
1241
+ #torch.save(latents, "dm_latents_%s.pt" % i)
1242
+
1243
+ # call the callback, if provided
1244
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1245
+ progress_bar.update()
1246
+ if callback is not None and i % callback_steps == 0:
1247
+ callback(i, t, latents)
1248
+
1249
+ # If we do sequential model offloading, let's offload unet and controlnet
1250
+ # manually for max memory savings
1251
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1252
+ self.unet.to("cpu")
1253
+ self.controlnet.to("cpu")
1254
+ torch.cuda.empty_cache()
1255
+
1256
+ if not output_type == "latent":
1257
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1258
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1259
+ else:
1260
+ image = latents
1261
+ has_nsfw_concept = None
1262
+
1263
+ if has_nsfw_concept is None:
1264
+ do_denormalize = [True] * image.shape[0]
1265
+ else:
1266
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1267
+
1268
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1269
+
1270
+ # Offload last model to CPU
1271
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1272
+ self.final_offload_hook.offload()
1273
+
1274
+ if not return_dict:
1275
+ return (image, has_nsfw_concept)
1276
+
1277
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)