prithivMLmods committed
Commit dd5b015 · verified · 1 Parent(s): f5e2b63

Create controlnet_union.py

Files changed (1)
  1. controlnet_union.py +1108 -0
controlnet_union.py ADDED
 @@ -0,0 +1,1108 @@
34
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
35
+ #
36
+ # Licensed under the Apache License, Version 2.0 (the "License");
37
+ # you may not use this file except in compliance with the License.
38
+ # You may obtain a copy of the License at
39
+ #
40
+ # http://www.apache.org/licenses/LICENSE-2.0
41
+ #
42
+ # Unless required by applicable law or agreed to in writing, software
43
+ # distributed under the License is distributed on an "AS IS" BASIS,
44
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
45
+ # See the License for the specific language governing permissions and
46
+ # limitations under the License.
47
+ from collections import OrderedDict
48
+ from dataclasses import dataclass
49
+ from typing import Any, Dict, List, Optional, Tuple, Union
50
+
51
+ import torch
52
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
53
+ from diffusers.loaders import FromOriginalModelMixin
54
+ from diffusers.models.attention_processor import (
55
+ ADDED_KV_ATTENTION_PROCESSORS,
56
+ CROSS_ATTENTION_PROCESSORS,
57
+ AttentionProcessor,
58
+ AttnAddedKVProcessor,
59
+ AttnProcessor,
60
+ )
61
+ from diffusers.models.embeddings import (
62
+ TextImageProjection,
63
+ TextImageTimeEmbedding,
64
+ TextTimeEmbedding,
65
+ TimestepEmbedding,
66
+ Timesteps,
67
+ )
68
+ from diffusers.models.modeling_utils import ModelMixin
69
+ from diffusers.models.unets.unet_2d_blocks import (
70
+ CrossAttnDownBlock2D,
71
+ DownBlock2D,
72
+ UNetMidBlock2DCrossAttn,
73
+ get_down_block,
74
+ )
75
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
76
+ from diffusers.utils import BaseOutput, logging
77
+ from torch import nn
78
+ from torch.nn import functional as F
79
+
80
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
81
+
82
+
83
+ # Transformer Block
84
+ # Used to exchange info between different conditions and input image
85
+ # With reference to https://github.com/TencentARC/T2I-Adapter/blob/SD/ldm/modules/encoders/adapter.py#L147
86
+ class QuickGELU(nn.Module):
87
+ def forward(self, x: torch.Tensor):
88
+ return x * torch.sigmoid(1.702 * x)
89
+
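# Editorial sketch (not part of the original commit): QuickGELU is the cheap sigmoid
# approximation of GELU used in CLIP-style transformer blocks. A quick numerical check,
# for illustration only:
def _example_quickgelu_close_to_gelu():
    x = torch.linspace(-3.0, 3.0, steps=101)
    quick = QuickGELU()(x)
    exact = F.gelu(x)
    # over this range the two curves differ by at most ~0.02
    return (quick - exact).abs().max()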
90
+
91
+ class LayerNorm(nn.LayerNorm):
92
+ """Subclass torch's LayerNorm to handle fp16."""
93
+
94
+ def forward(self, x: torch.Tensor):
95
+ orig_type = x.dtype
96
+ ret = super().forward(x)
97
+ return ret.type(orig_type)
98
+
99
+
100
+ class ResidualAttentionBlock(nn.Module):
101
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
102
+ super().__init__()
103
+
104
+ self.attn = nn.MultiheadAttention(d_model, n_head)
105
+ self.ln_1 = LayerNorm(d_model)
106
+ self.mlp = nn.Sequential(
107
+ OrderedDict(
108
+ [
109
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
110
+ ("gelu", QuickGELU()),
111
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
112
+ ]
113
+ )
114
+ )
115
+ self.ln_2 = LayerNorm(d_model)
116
+ self.attn_mask = attn_mask
117
+
118
+ def attention(self, x: torch.Tensor):
119
+ self.attn_mask = (
120
+ self.attn_mask.to(dtype=x.dtype, device=x.device)
121
+ if self.attn_mask is not None
122
+ else None
123
+ )
124
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
125
+
126
+ def forward(self, x: torch.Tensor):
127
+ x = x + self.attention(self.ln_1(x))
128
+ x = x + self.mlp(self.ln_2(x))
129
+ return x
130
+
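# Editorial sketch (not part of the original commit): the block is shape-preserving, so
# it can be stacked with nn.Sequential as done below in ControlNetModel_Union. The
# tensor layout follows nn.MultiheadAttention's default (sequence, batch, d_model);
# the sizes here are illustrative only.
def _example_residual_attention_block_shapes():
    block = ResidualAttentionBlock(d_model=320, n_head=8)
    tokens = torch.randn(7, 2, 320)
    out = block(tokens)
    assert out.shape == tokens.shape
    return out.shape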
131
+
132
+ # -----------------------------------------------------------------------------------------------------
133
+
134
+
135
+ @dataclass
136
+ class ControlNetOutput(BaseOutput):
137
+ """
138
+ The output of [`ControlNetModel`].
139
+ Args:
140
+ down_block_res_samples (`tuple[torch.Tensor]`):
141
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
142
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
143
+ used to condition the original UNet's downsampling activations.
144
+ mid_block_res_sample (`torch.Tensor`):
145
+ The activation of the middle block (the lowest sample resolution). The tensor should be of shape
146
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
147
+ Output can be used to condition the original UNet's middle block activation.
148
+ """
149
+
150
+ down_block_res_samples: Tuple[torch.Tensor]
151
+ mid_block_res_sample: torch.Tensor
152
+
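# Editorial sketch (not part of the original commit): how the two fields above are
# typically handed to the base UNet. The keyword names follow the standard
# `UNet2DConditionModel.forward` signature in diffusers; every other name here is a
# placeholder.
def _example_apply_controlnet_residuals(
    unet, latents, timestep, prompt_embeds, added_cond_kwargs, controlnet_output
):
    return unet(
        latents,
        timestep,
        encoder_hidden_states=prompt_embeds,
        added_cond_kwargs=added_cond_kwargs,
        down_block_additional_residuals=controlnet_output.down_block_res_samples,
        mid_block_additional_residual=controlnet_output.mid_block_res_sample,
    ).sample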
153
+
154
+ class ControlNetConditioningEmbedding(nn.Module):
155
+ """
156
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
157
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
158
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
159
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
160
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
161
+ model) to encode image-space conditions ... into feature maps ..."
162
+ """
163
+
164
+ # original setting is (16, 32, 96, 256)
165
+ def __init__(
166
+ self,
167
+ conditioning_embedding_channels: int,
168
+ conditioning_channels: int = 3,
169
+ block_out_channels: Tuple[int] = (48, 96, 192, 384),
170
+ ):
171
+ super().__init__()
172
+
173
+ self.conv_in = nn.Conv2d(
174
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
175
+ )
176
+
177
+ self.blocks = nn.ModuleList([])
178
+
179
+ for i in range(len(block_out_channels) - 1):
180
+ channel_in = block_out_channels[i]
181
+ channel_out = block_out_channels[i + 1]
182
+ self.blocks.append(
183
+ nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)
184
+ )
185
+ self.blocks.append(
186
+ nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)
187
+ )
188
+
189
+ self.conv_out = zero_module(
190
+ nn.Conv2d(
191
+ block_out_channels[-1],
192
+ conditioning_embedding_channels,
193
+ kernel_size=3,
194
+ padding=1,
195
+ )
196
+ )
197
+
198
+ def forward(self, conditioning):
199
+ embedding = self.conv_in(conditioning)
200
+ embedding = F.silu(embedding)
201
+
202
+ for block in self.blocks:
203
+ embedding = block(embedding)
204
+ embedding = F.silu(embedding)
205
+
206
+ embedding = self.conv_out(embedding)
207
+
208
+ return embedding
209
+
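# Editorial sketch (not part of the original commit): with a 4-entry block_out_channels
# the embedding applies three stride-2 convolutions, i.e. an 8x spatial reduction that
# maps an image-space condition onto the latent grid. The values below are illustrative.
def _example_conditioning_embedding_shapes():
    embed = ControlNetConditioningEmbedding(
        conditioning_embedding_channels=320,
        conditioning_channels=3,
        block_out_channels=(48, 96, 192, 384),
    )
    cond = torch.randn(1, 3, 512, 512)
    out = embed(cond)
    assert out.shape == (1, 320, 64, 64)  # 512 / 2**3 = 64
    return out.shape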
210
+
211
+ class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalModelMixin):
212
+ """
213
+ A union ControlNet model that accepts multiple types of image-space control conditions in a single network.
214
+ Args:
215
+ in_channels (`int`, defaults to 4):
216
+ The number of channels in the input sample.
217
+ flip_sin_to_cos (`bool`, defaults to `True`):
218
+ Whether to flip the sin to cos in the time embedding.
219
+ freq_shift (`int`, defaults to 0):
220
+ The frequency shift to apply to the time embedding.
221
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
222
+ The tuple of downsample blocks to use.
223
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
224
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
225
+ The tuple of output channels for each block.
226
+ layers_per_block (`int`, defaults to 2):
227
+ The number of layers per block.
228
+ downsample_padding (`int`, defaults to 1):
229
+ The padding to use for the downsampling convolution.
230
+ mid_block_scale_factor (`float`, defaults to 1):
231
+ The scale factor to use for the mid block.
232
+ act_fn (`str`, defaults to "silu"):
233
+ The activation function to use.
234
+ norm_num_groups (`int`, *optional*, defaults to 32):
235
+ The number of groups to use for the normalization. If None, normalization and activation layers are skipped
236
+ in post-processing.
237
+ norm_eps (`float`, defaults to 1e-5):
238
+ The epsilon to use for the normalization.
239
+ cross_attention_dim (`int`, defaults to 1280):
240
+ The dimension of the cross attention features.
241
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
242
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
243
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
244
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
245
+ encoder_hid_dim (`int`, *optional*, defaults to None):
246
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
247
+ dimension to `cross_attention_dim`.
248
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
249
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
250
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
251
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
252
+ The dimension of the attention heads.
253
+ use_linear_projection (`bool`, defaults to `False`):
254
+ class_embed_type (`str`, *optional*, defaults to `None`):
255
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
256
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
257
+ addition_embed_type (`str`, *optional*, defaults to `None`):
258
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
259
+ "text". "text" will use the `TextTimeEmbedding` layer.
260
+ num_class_embeds (`int`, *optional*, defaults to `None`):
261
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
262
+ class conditioning with `class_embed_type` equal to `None`.
263
+ upcast_attention (`bool`, defaults to `False`):
264
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
265
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
266
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
267
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
268
+ `class_embed_type="projection"`.
269
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
270
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
271
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
272
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
273
+ global_pool_conditions (`bool`, defaults to `False`):
274
+ """
275
+
276
+ _supports_gradient_checkpointing = True
277
+
278
+ @register_to_config
279
+ def __init__(
280
+ self,
281
+ in_channels: int = 4,
282
+ conditioning_channels: int = 3,
283
+ flip_sin_to_cos: bool = True,
284
+ freq_shift: int = 0,
285
+ down_block_types: Tuple[str] = (
286
+ "CrossAttnDownBlock2D",
287
+ "CrossAttnDownBlock2D",
288
+ "CrossAttnDownBlock2D",
289
+ "DownBlock2D",
290
+ ),
291
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
292
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
293
+ layers_per_block: int = 2,
294
+ downsample_padding: int = 1,
295
+ mid_block_scale_factor: float = 1,
296
+ act_fn: str = "silu",
297
+ norm_num_groups: Optional[int] = 32,
298
+ norm_eps: float = 1e-5,
299
+ cross_attention_dim: int = 1280,
300
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
301
+ encoder_hid_dim: Optional[int] = None,
302
+ encoder_hid_dim_type: Optional[str] = None,
303
+ attention_head_dim: Union[int, Tuple[int]] = 8,
304
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
305
+ use_linear_projection: bool = False,
306
+ class_embed_type: Optional[str] = None,
307
+ addition_embed_type: Optional[str] = None,
308
+ addition_time_embed_dim: Optional[int] = None,
309
+ num_class_embeds: Optional[int] = None,
310
+ upcast_attention: bool = False,
311
+ resnet_time_scale_shift: str = "default",
312
+ projection_class_embeddings_input_dim: Optional[int] = None,
313
+ controlnet_conditioning_channel_order: str = "rgb",
314
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
315
+ global_pool_conditions: bool = False,
316
+ addition_embed_type_num_heads=64,
317
+ num_control_type=6,
318
+ ):
319
+ super().__init__()
320
+
321
+ # If `num_attention_heads` is not defined (which is the case for most models)
322
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
323
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
324
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
325
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
326
+ # which is why we correct for the naming here.
327
+ num_attention_heads = num_attention_heads or attention_head_dim
328
+
329
+ # Check inputs
330
+ if len(block_out_channels) != len(down_block_types):
331
+ raise ValueError(
332
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
333
+ )
334
+
335
+ if not isinstance(only_cross_attention, bool) and len(
336
+ only_cross_attention
337
+ ) != len(down_block_types):
338
+ raise ValueError(
339
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
340
+ )
341
+
342
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
343
+ down_block_types
344
+ ):
345
+ raise ValueError(
346
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
347
+ )
348
+
349
+ if isinstance(transformer_layers_per_block, int):
350
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
351
+ down_block_types
352
+ )
353
+
354
+ # input
355
+ conv_in_kernel = 3
356
+ conv_in_padding = (conv_in_kernel - 1) // 2
357
+ self.conv_in = nn.Conv2d(
358
+ in_channels,
359
+ block_out_channels[0],
360
+ kernel_size=conv_in_kernel,
361
+ padding=conv_in_padding,
362
+ )
363
+
364
+ # time
365
+ time_embed_dim = block_out_channels[0] * 4
366
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
367
+ timestep_input_dim = block_out_channels[0]
368
+ self.time_embedding = TimestepEmbedding(
369
+ timestep_input_dim,
370
+ time_embed_dim,
371
+ act_fn=act_fn,
372
+ )
373
+
374
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
375
+ encoder_hid_dim_type = "text_proj"
376
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
377
+ logger.info(
378
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
379
+ )
380
+
381
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
382
+ raise ValueError(
383
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
384
+ )
385
+
386
+ if encoder_hid_dim_type == "text_proj":
387
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
388
+ elif encoder_hid_dim_type == "text_image_proj":
389
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
390
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
391
+ # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
392
+ self.encoder_hid_proj = TextImageProjection(
393
+ text_embed_dim=encoder_hid_dim,
394
+ image_embed_dim=cross_attention_dim,
395
+ cross_attention_dim=cross_attention_dim,
396
+ )
397
+
398
+ elif encoder_hid_dim_type is not None:
399
+ raise ValueError(
400
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
401
+ )
402
+ else:
403
+ self.encoder_hid_proj = None
404
+
405
+ # class embedding
406
+ if class_embed_type is None and num_class_embeds is not None:
407
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
408
+ elif class_embed_type == "timestep":
409
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
410
+ elif class_embed_type == "identity":
411
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
412
+ elif class_embed_type == "projection":
413
+ if projection_class_embeddings_input_dim is None:
414
+ raise ValueError(
415
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
416
+ )
417
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
418
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
419
+ # 2. it projects from an arbitrary input dimension.
420
+ #
421
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
422
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
423
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
424
+ self.class_embedding = TimestepEmbedding(
425
+ projection_class_embeddings_input_dim, time_embed_dim
426
+ )
427
+ else:
428
+ self.class_embedding = None
429
+
430
+ if addition_embed_type == "text":
431
+ if encoder_hid_dim is not None:
432
+ text_time_embedding_from_dim = encoder_hid_dim
433
+ else:
434
+ text_time_embedding_from_dim = cross_attention_dim
435
+
436
+ self.add_embedding = TextTimeEmbedding(
437
+ text_time_embedding_from_dim,
438
+ time_embed_dim,
439
+ num_heads=addition_embed_type_num_heads,
440
+ )
441
+ elif addition_embed_type == "text_image":
442
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
443
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
444
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
445
+ self.add_embedding = TextImageTimeEmbedding(
446
+ text_embed_dim=cross_attention_dim,
447
+ image_embed_dim=cross_attention_dim,
448
+ time_embed_dim=time_embed_dim,
449
+ )
450
+ elif addition_embed_type == "text_time":
451
+ self.add_time_proj = Timesteps(
452
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
453
+ )
454
+ self.add_embedding = TimestepEmbedding(
455
+ projection_class_embeddings_input_dim, time_embed_dim
456
+ )
457
+
458
+ elif addition_embed_type is not None:
459
+ raise ValueError(
460
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
461
+ )
462
+
463
+ # control net conditioning embedding
464
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
465
+ conditioning_embedding_channels=block_out_channels[0],
466
+ block_out_channels=conditioning_embedding_out_channels,
467
+ conditioning_channels=conditioning_channels,
468
+ )
469
+
470
+ # Copyright by Qi Xin(2024/07/06)
471
+ # Condition Transformer(fuse single/multi conditions with input image)
472
+ # The Condition Transformer augments the feature representation of the conditions
473
+ # The overall design is somewhat like a ResNet: the output of the Condition Transformer predicts a condition bias that is added to the original condition feature.
474
+ # num_control_type = 6
475
+ num_trans_channel = 320
476
+ num_trans_head = 8
477
+ num_trans_layer = 1
478
+ num_proj_channel = 320
479
+ task_scale_factor = num_trans_channel**0.5
480
+
481
+ self.task_embedding = nn.Parameter(
482
+ task_scale_factor * torch.randn(num_control_type, num_trans_channel)
483
+ )
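# Editorial note (not part of the original commit): the attribute name below keeps the
# original misspelling of `transformer_layers`; renaming it would change the state_dict
# keys and break loading of existing pretrained checkpoints.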
484
+ self.transformer_layes = nn.Sequential(
485
+ *[
486
+ ResidualAttentionBlock(num_trans_channel, num_trans_head)
487
+ for _ in range(num_trans_layer)
488
+ ]
489
+ )
490
+ self.spatial_ch_projs = zero_module(
491
+ nn.Linear(num_trans_channel, num_proj_channel)
492
+ )
493
+ # -----------------------------------------------------------------------------------------------------
494
+
495
+ # Copyright by Qi Xin(2024/07/06)
496
+ # Control Encoder to distinguish different control conditions
497
+ # A simple but effective module (a sinusoidal projection followed by a `TimestepEmbedding`) that injects the control-type info into the time embedding.
498
+ self.control_type_proj = Timesteps(
499
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
500
+ )
501
+ self.control_add_embedding = TimestepEmbedding(
502
+ addition_time_embed_dim * num_control_type, time_embed_dim
503
+ )
504
+ # -----------------------------------------------------------------------------------------------------
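# Editorial sketch (not part of the original commit), tracing shapes through the control
# encoder above, assuming addition_time_embed_dim=256 and num_control_type=6
# (illustrative values only):
#   control_type flags        (batch, 6), e.g. [1, 0, 0, 0, 0, 0] for a single condition
#   control_type_proj         applied to the flattened flags -> (batch * 6, 256)
#   reshape in forward()      (batch, 6 * 256) = (batch, 1536)
#   control_add_embedding     maps 1536 -> time_embed_dim and is added to the time embedding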
505
+
506
+ self.down_blocks = nn.ModuleList([])
507
+ self.controlnet_down_blocks = nn.ModuleList([])
508
+
509
+ if isinstance(only_cross_attention, bool):
510
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
511
+
512
+ if isinstance(attention_head_dim, int):
513
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
514
+
515
+ if isinstance(num_attention_heads, int):
516
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
517
+
518
+ # down
519
+ output_channel = block_out_channels[0]
520
+
521
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
522
+ controlnet_block = zero_module(controlnet_block)
523
+ self.controlnet_down_blocks.append(controlnet_block)
524
+
525
+ for i, down_block_type in enumerate(down_block_types):
526
+ input_channel = output_channel
527
+ output_channel = block_out_channels[i]
528
+ is_final_block = i == len(block_out_channels) - 1
529
+
530
+ down_block = get_down_block(
531
+ down_block_type,
532
+ num_layers=layers_per_block,
533
+ transformer_layers_per_block=transformer_layers_per_block[i],
534
+ in_channels=input_channel,
535
+ out_channels=output_channel,
536
+ temb_channels=time_embed_dim,
537
+ add_downsample=not is_final_block,
538
+ resnet_eps=norm_eps,
539
+ resnet_act_fn=act_fn,
540
+ resnet_groups=norm_num_groups,
541
+ cross_attention_dim=cross_attention_dim,
542
+ num_attention_heads=num_attention_heads[i],
543
+ attention_head_dim=attention_head_dim[i]
544
+ if attention_head_dim[i] is not None
545
+ else output_channel,
546
+ downsample_padding=downsample_padding,
547
+ use_linear_projection=use_linear_projection,
548
+ only_cross_attention=only_cross_attention[i],
549
+ upcast_attention=upcast_attention,
550
+ resnet_time_scale_shift=resnet_time_scale_shift,
551
+ )
552
+ self.down_blocks.append(down_block)
553
+
554
+ for _ in range(layers_per_block):
555
+ controlnet_block = nn.Conv2d(
556
+ output_channel, output_channel, kernel_size=1
557
+ )
558
+ controlnet_block = zero_module(controlnet_block)
559
+ self.controlnet_down_blocks.append(controlnet_block)
560
+
561
+ if not is_final_block:
562
+ controlnet_block = nn.Conv2d(
563
+ output_channel, output_channel, kernel_size=1
564
+ )
565
+ controlnet_block = zero_module(controlnet_block)
566
+ self.controlnet_down_blocks.append(controlnet_block)
567
+
568
+ # mid
569
+ mid_block_channel = block_out_channels[-1]
570
+
571
+ controlnet_block = nn.Conv2d(
572
+ mid_block_channel, mid_block_channel, kernel_size=1
573
+ )
574
+ controlnet_block = zero_module(controlnet_block)
575
+ self.controlnet_mid_block = controlnet_block
576
+
577
+ self.mid_block = UNetMidBlock2DCrossAttn(
578
+ transformer_layers_per_block=transformer_layers_per_block[-1],
579
+ in_channels=mid_block_channel,
580
+ temb_channels=time_embed_dim,
581
+ resnet_eps=norm_eps,
582
+ resnet_act_fn=act_fn,
583
+ output_scale_factor=mid_block_scale_factor,
584
+ resnet_time_scale_shift=resnet_time_scale_shift,
585
+ cross_attention_dim=cross_attention_dim,
586
+ num_attention_heads=num_attention_heads[-1],
587
+ resnet_groups=norm_num_groups,
588
+ use_linear_projection=use_linear_projection,
589
+ upcast_attention=upcast_attention,
590
+ )
591
+
592
+ @classmethod
593
+ def from_unet(
594
+ cls,
595
+ unet: UNet2DConditionModel,
596
+ controlnet_conditioning_channel_order: str = "rgb",
597
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
598
+ load_weights_from_unet: bool = True,
599
+ ):
600
+ r"""
601
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
602
+ Parameters:
603
+ unet (`UNet2DConditionModel`):
604
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
605
+ where applicable.
606
+ """
607
+ transformer_layers_per_block = (
608
+ unet.config.transformer_layers_per_block
609
+ if "transformer_layers_per_block" in unet.config
610
+ else 1
611
+ )
612
+ encoder_hid_dim = (
613
+ unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
614
+ )
615
+ encoder_hid_dim_type = (
616
+ unet.config.encoder_hid_dim_type
617
+ if "encoder_hid_dim_type" in unet.config
618
+ else None
619
+ )
620
+ addition_embed_type = (
621
+ unet.config.addition_embed_type
622
+ if "addition_embed_type" in unet.config
623
+ else None
624
+ )
625
+ addition_time_embed_dim = (
626
+ unet.config.addition_time_embed_dim
627
+ if "addition_time_embed_dim" in unet.config
628
+ else None
629
+ )
630
+
631
+ controlnet = cls(
632
+ encoder_hid_dim=encoder_hid_dim,
633
+ encoder_hid_dim_type=encoder_hid_dim_type,
634
+ addition_embed_type=addition_embed_type,
635
+ addition_time_embed_dim=addition_time_embed_dim,
636
+ transformer_layers_per_block=transformer_layers_per_block,
637
+ # transformer_layers_per_block=[1, 2, 5],
638
+ in_channels=unet.config.in_channels,
639
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
640
+ freq_shift=unet.config.freq_shift,
641
+ down_block_types=unet.config.down_block_types,
642
+ only_cross_attention=unet.config.only_cross_attention,
643
+ block_out_channels=unet.config.block_out_channels,
644
+ layers_per_block=unet.config.layers_per_block,
645
+ downsample_padding=unet.config.downsample_padding,
646
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
647
+ act_fn=unet.config.act_fn,
648
+ norm_num_groups=unet.config.norm_num_groups,
649
+ norm_eps=unet.config.norm_eps,
650
+ cross_attention_dim=unet.config.cross_attention_dim,
651
+ attention_head_dim=unet.config.attention_head_dim,
652
+ num_attention_heads=unet.config.num_attention_heads,
653
+ use_linear_projection=unet.config.use_linear_projection,
654
+ class_embed_type=unet.config.class_embed_type,
655
+ num_class_embeds=unet.config.num_class_embeds,
656
+ upcast_attention=unet.config.upcast_attention,
657
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
658
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
659
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
660
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
661
+ )
662
+
663
+ if load_weights_from_unet:
664
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
665
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
666
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
667
+
668
+ if controlnet.class_embedding:
669
+ controlnet.class_embedding.load_state_dict(
670
+ unet.class_embedding.state_dict()
671
+ )
672
+
673
+ controlnet.down_blocks.load_state_dict(
674
+ unet.down_blocks.state_dict(), strict=False
675
+ )
676
+ controlnet.mid_block.load_state_dict(
677
+ unet.mid_block.state_dict(), strict=False
678
+ )
679
+
680
+ return controlnet
681
+
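# Editorial sketch (not part of the original commit): typical construction from an SDXL
# base UNet; the repo id is only an example.
#   unet = UNet2DConditionModel.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
#   )
#   controlnet = ControlNetModel_Union.from_unet(unet, load_weights_from_unet=True)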
682
+ @property
683
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
684
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
685
+ r"""
686
+ Returns:
687
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
688
+ indexed by its weight name.
689
+ """
690
+ # set recursively
691
+ processors = {}
692
+
693
+ def fn_recursive_add_processors(
694
+ name: str,
695
+ module: torch.nn.Module,
696
+ processors: Dict[str, AttentionProcessor],
697
+ ):
698
+ if hasattr(module, "get_processor"):
699
+ processors[f"{name}.processor"] = module.get_processor(
700
+ return_deprecated_lora=True
701
+ )
702
+
703
+ for sub_name, child in module.named_children():
704
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
705
+
706
+ return processors
707
+
708
+ for name, module in self.named_children():
709
+ fn_recursive_add_processors(name, module, processors)
710
+
711
+ return processors
712
+
713
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
714
+ def set_attn_processor(
715
+ self,
716
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
717
+ _remove_lora=False,
718
+ ):
719
+ r"""
720
+ Sets the attention processor to use to compute attention.
721
+ Parameters:
722
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
723
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
724
+ for **all** `Attention` layers.
725
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
726
+ processor. This is strongly recommended when setting trainable attention processors.
727
+ """
728
+ count = len(self.attn_processors.keys())
729
+
730
+ if isinstance(processor, dict) and len(processor) != count:
731
+ raise ValueError(
732
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
733
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
734
+ )
735
+
736
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
737
+ if hasattr(module, "set_processor"):
738
+ if not isinstance(processor, dict):
739
+ module.set_processor(processor, _remove_lora=_remove_lora)
740
+ else:
741
+ module.set_processor(
742
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
743
+ )
744
+
745
+ for sub_name, child in module.named_children():
746
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
747
+
748
+ for name, module in self.named_children():
749
+ fn_recursive_attn_processor(name, module, processor)
750
+
751
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
752
+ def set_default_attn_processor(self):
753
+ """
754
+ Disables custom attention processors and sets the default attention implementation.
755
+ """
756
+ if all(
757
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
758
+ for proc in self.attn_processors.values()
759
+ ):
760
+ processor = AttnAddedKVProcessor()
761
+ elif all(
762
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
763
+ for proc in self.attn_processors.values()
764
+ ):
765
+ processor = AttnProcessor()
766
+ else:
767
+ raise ValueError(
768
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
769
+ )
770
+
771
+ self.set_attn_processor(processor, _remove_lora=True)
772
+
773
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
774
+ def set_attention_slice(self, slice_size):
775
+ r"""
776
+ Enable sliced attention computation.
777
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
778
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
779
+ Args:
780
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
781
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
782
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
783
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
784
+ must be a multiple of `slice_size`.
785
+ """
786
+ sliceable_head_dims = []
787
+
788
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
789
+ if hasattr(module, "set_attention_slice"):
790
+ sliceable_head_dims.append(module.sliceable_head_dim)
791
+
792
+ for child in module.children():
793
+ fn_recursive_retrieve_sliceable_dims(child)
794
+
795
+ # retrieve number of attention layers
796
+ for module in self.children():
797
+ fn_recursive_retrieve_sliceable_dims(module)
798
+
799
+ num_sliceable_layers = len(sliceable_head_dims)
800
+
801
+ if slice_size == "auto":
802
+ # half the attention head size is usually a good trade-off between
803
+ # speed and memory
804
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
805
+ elif slice_size == "max":
806
+ # make smallest slice possible
807
+ slice_size = num_sliceable_layers * [1]
808
+
809
+ slice_size = (
810
+ num_sliceable_layers * [slice_size]
811
+ if not isinstance(slice_size, list)
812
+ else slice_size
813
+ )
814
+
815
+ if len(slice_size) != len(sliceable_head_dims):
816
+ raise ValueError(
817
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
818
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
819
+ )
820
+
821
+ for i in range(len(slice_size)):
822
+ size = slice_size[i]
823
+ dim = sliceable_head_dims[i]
824
+ if size is not None and size > dim:
825
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
826
+
827
+ # Recursively walk through all the children.
828
+ # Any children which exposes the set_attention_slice method
829
+ # gets the message
830
+ def fn_recursive_set_attention_slice(
831
+ module: torch.nn.Module, slice_size: List[int]
832
+ ):
833
+ if hasattr(module, "set_attention_slice"):
834
+ module.set_attention_slice(slice_size.pop())
835
+
836
+ for child in module.children():
837
+ fn_recursive_set_attention_slice(child, slice_size)
838
+
839
+ reversed_slice_size = list(reversed(slice_size))
840
+ for module in self.children():
841
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
842
+
843
+ def _set_gradient_checkpointing(self, module, value=False):
844
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
845
+ module.gradient_checkpointing = value
846
+
847
+ def forward(
848
+ self,
849
+ sample: torch.FloatTensor,
850
+ timestep: Union[torch.Tensor, float, int],
851
+ encoder_hidden_states: torch.Tensor,
852
+ controlnet_cond_list: List[torch.FloatTensor],
853
+ conditioning_scale: float = 1.0,
854
+ class_labels: Optional[torch.Tensor] = None,
855
+ timestep_cond: Optional[torch.Tensor] = None,
856
+ attention_mask: Optional[torch.Tensor] = None,
857
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
858
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
859
+ guess_mode: bool = False,
860
+ return_dict: bool = True,
861
+ ) -> Union[ControlNetOutput, Tuple]:
862
+ """
863
+ The [`ControlNetModel`] forward method.
864
+ Args:
865
+ sample (`torch.FloatTensor`):
866
+ The noisy input tensor.
867
+ timestep (`Union[torch.Tensor, float, int]`):
868
+ The number of timesteps to denoise an input.
869
+ encoder_hidden_states (`torch.Tensor`):
870
+ The encoder hidden states.
871
+ controlnet_cond_list (`List[torch.FloatTensor]`):
872
+ The list of conditional input tensors, one entry per control type; each active entry is an image-space condition of shape `(batch_size, channels, height, width)`.
873
+ conditioning_scale (`float`, defaults to `1.0`):
874
+ The scale factor for ControlNet outputs.
875
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
876
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
877
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
878
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
879
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
880
+ embeddings.
881
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
882
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
883
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
884
+ negative values to the attention scores corresponding to "discard" tokens.
885
+ added_cond_kwargs (`dict`):
886
+ Additional conditions for the Stable Diffusion XL UNet.
887
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
888
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
889
+ guess_mode (`bool`, defaults to `False`):
890
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
891
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
892
+ return_dict (`bool`, defaults to `True`):
893
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
894
+ Returns:
895
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
896
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
897
+ returned where the first element is the sample tensor.
898
+ """
899
+ # check channel order
900
+ channel_order = self.config.controlnet_conditioning_channel_order
901
+
902
+ if channel_order == "rgb":
903
+ # in rgb order by default
904
+ ...
905
+ # elif channel_order == "bgr":
906
+ # controlnet_cond = torch.flip(controlnet_cond, dims=[1])
907
+ else:
908
+ raise ValueError(
909
+ f"unknown `controlnet_conditioning_channel_order`: {channel_order}"
910
+ )
911
+
912
+ # prepare attention_mask
913
+ if attention_mask is not None:
914
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
915
+ attention_mask = attention_mask.unsqueeze(1)
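# Editorial note (not part of the original commit): a mask of shape (batch, key_tokens)
# with 1 = keep / 0 = drop becomes an additive bias of shape (batch, 1, key_tokens)
# holding 0.0 for kept tokens and -10000.0 for dropped ones.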
916
+
917
+ # 1. time
918
+ timesteps = timestep
919
+ if not torch.is_tensor(timesteps):
920
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
921
+ # This would be a good case for the `match` statement (Python 3.10+)
922
+ is_mps = sample.device.type == "mps"
923
+ if isinstance(timestep, float):
924
+ dtype = torch.float32 if is_mps else torch.float64
925
+ else:
926
+ dtype = torch.int32 if is_mps else torch.int64
927
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
928
+ elif len(timesteps.shape) == 0:
929
+ timesteps = timesteps[None].to(sample.device)
930
+
931
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
932
+ timesteps = timesteps.expand(sample.shape[0])
933
+
934
+ t_emb = self.time_proj(timesteps)
935
+
936
+ # timesteps does not contain any weights and will always return f32 tensors
937
+ # but time_embedding might actually be running in fp16. so we need to cast here.
938
+ # there might be better ways to encapsulate this.
939
+ t_emb = t_emb.to(dtype=sample.dtype)
940
+
941
+ emb = self.time_embedding(t_emb, timestep_cond)
942
+ aug_emb = None
943
+
944
+ if self.class_embedding is not None:
945
+ if class_labels is None:
946
+ raise ValueError(
947
+ "class_labels should be provided when num_class_embeds > 0"
948
+ )
949
+
950
+ if self.config.class_embed_type == "timestep":
951
+ class_labels = self.time_proj(class_labels)
952
+
953
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
954
+ emb = emb + class_emb
955
+
956
+ if self.config.addition_embed_type is not None:
957
+ if self.config.addition_embed_type == "text":
958
+ aug_emb = self.add_embedding(encoder_hidden_states)
959
+
960
+ elif self.config.addition_embed_type == "text_time":
961
+ if "text_embeds" not in added_cond_kwargs:
962
+ raise ValueError(
963
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
964
+ )
965
+ text_embeds = added_cond_kwargs.get("text_embeds")
966
+ if "time_ids" not in added_cond_kwargs:
967
+ raise ValueError(
968
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
969
+ )
970
+ time_ids = added_cond_kwargs.get("time_ids")
971
+ time_embeds = self.add_time_proj(time_ids.flatten())
972
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
973
+
974
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
975
+ add_embeds = add_embeds.to(emb.dtype)
976
+ aug_emb = self.add_embedding(add_embeds)
977
+
978
+ # Copyright by Qi Xin(2024/07/06)
979
+ # inject control type info to time embedding to distinguish different control conditions
980
+ control_type = added_cond_kwargs.get("control_type")
981
+ control_embeds = self.control_type_proj(control_type.flatten())
982
+ control_embeds = control_embeds.reshape((t_emb.shape[0], -1))
983
+ control_embeds = control_embeds.to(emb.dtype)
984
+ control_emb = self.control_add_embedding(control_embeds)
985
+ emb = emb + control_emb
986
+ # ---------------------------------------------------------------------------------
987
+
988
+ emb = emb + aug_emb if aug_emb is not None else emb
989
+
990
+ # 2. pre-process
991
+ sample = self.conv_in(sample)
992
+ indices = torch.nonzero(control_type[0])
993
+
994
+ # Copyright by Qi Xin(2024/07/06)
995
+ # add single/multi conditions to the input image.
996
+ # Condition Transformer provides an easy and effective way to fuse different features naturally
997
+ inputs = []
998
+ condition_list = []
999
+
1000
+ for idx in range(indices.shape[0] + 1):
1001
+ if idx == indices.shape[0]:
1002
+ controlnet_cond = sample
1003
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C
1004
+ else:
1005
+ controlnet_cond = self.controlnet_cond_embedding(
1006
+ controlnet_cond_list[indices[idx][0]]
1007
+ )
1008
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) # N * C
1009
+ feat_seq = feat_seq + self.task_embedding[indices[idx][0]]
1010
+
1011
+ inputs.append(feat_seq.unsqueeze(1))
1012
+ condition_list.append(controlnet_cond)
1013
+
1014
+ x = torch.cat(inputs, dim=1) # NxLxC
1015
+ x = self.transformer_layes(x)
1016
+
1017
+ controlnet_cond_fuser = sample * 0.0
1018
+ for idx in range(indices.shape[0]):
1019
+ alpha = self.spatial_ch_projs(x[:, idx])
1020
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
1021
+ controlnet_cond_fuser += condition_list[idx] + alpha
1022
+
1023
+ sample = sample + controlnet_cond_fuser
1024
+ # -------------------------------------------------------------------------------------------
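# Editorial sketch (not part of the original commit) of the fusion above, with
# illustrative SDXL-style shapes (batch N, 320 channels, 128x128 latent grid):
#   each active condition image -> controlnet_cond_embedding -> (N, 320, 128, 128)
#   pooled token per condition  -> mean over H, W (+ task embedding) -> (N, 320)
#   token stack [conds..., sample] -> (N, n_active + 1, 320) -> transformer_layes
#   spatial_ch_projs(token)     -> per-condition bias `alpha`, broadcast over H and W
#   sample <- sample + sum_i(condition_i + alpha_i)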
1025
+
1026
+ # 3. down
1027
+ down_block_res_samples = (sample,)
1028
+ for downsample_block in self.down_blocks:
1029
+ if (
1030
+ hasattr(downsample_block, "has_cross_attention")
1031
+ and downsample_block.has_cross_attention
1032
+ ):
1033
+ sample, res_samples = downsample_block(
1034
+ hidden_states=sample,
1035
+ temb=emb,
1036
+ encoder_hidden_states=encoder_hidden_states,
1037
+ attention_mask=attention_mask,
1038
+ cross_attention_kwargs=cross_attention_kwargs,
1039
+ )
1040
+ else:
1041
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1042
+
1043
+ down_block_res_samples += res_samples
1044
+
1045
+ # 4. mid
1046
+ if self.mid_block is not None:
1047
+ sample = self.mid_block(
1048
+ sample,
1049
+ emb,
1050
+ encoder_hidden_states=encoder_hidden_states,
1051
+ attention_mask=attention_mask,
1052
+ cross_attention_kwargs=cross_attention_kwargs,
1053
+ )
1054
+
1055
+ # 5. Control net blocks
1056
+
1057
+ controlnet_down_block_res_samples = ()
1058
+
1059
+ for down_block_res_sample, controlnet_block in zip(
1060
+ down_block_res_samples, self.controlnet_down_blocks
1061
+ ):
1062
+ down_block_res_sample = controlnet_block(down_block_res_sample)
1063
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (
1064
+ down_block_res_sample,
1065
+ )
1066
+
1067
+ down_block_res_samples = controlnet_down_block_res_samples
1068
+
1069
+ mid_block_res_sample = self.controlnet_mid_block(sample)
1070
+
1071
+ # 6. scaling
1072
+ if guess_mode and not self.config.global_pool_conditions:
1073
+ scales = torch.logspace(
1074
+ -1, 0, len(down_block_res_samples) + 1, device=sample.device
1075
+ ) # 0.1 to 1.0
1076
+ scales = scales * conditioning_scale
1077
+ down_block_res_samples = [
1078
+ sample * scale for sample, scale in zip(down_block_res_samples, scales)
1079
+ ]
1080
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
1081
+ else:
1082
+ down_block_res_samples = [
1083
+ sample * conditioning_scale for sample in down_block_res_samples
1084
+ ]
1085
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
1086
+
1087
+ if self.config.global_pool_conditions:
1088
+ down_block_res_samples = [
1089
+ torch.mean(sample, dim=(2, 3), keepdim=True)
1090
+ for sample in down_block_res_samples
1091
+ ]
1092
+ mid_block_res_sample = torch.mean(
1093
+ mid_block_res_sample, dim=(2, 3), keepdim=True
1094
+ )
1095
+
1096
+ if not return_dict:
1097
+ return (down_block_res_samples, mid_block_res_sample)
1098
+
1099
+ return ControlNetOutput(
1100
+ down_block_res_samples=down_block_res_samples,
1101
+ mid_block_res_sample=mid_block_res_sample,
1102
+ )
1103
+
1104
+
1105
+ def zero_module(module):
1106
+ for p in module.parameters():
1107
+ nn.init.zeros_(p)
1108
+ return module
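# Editorial sketch (not part of the original commit): preparing inputs for a single
# control condition and calling the union ControlNet. All shapes assume an SDXL-style
# setup (4-channel 128x128 latents for a 1024x1024 image), and the slot index used for
# the condition (0) is an illustrative assumption, not a documented mapping.
def _example_forward_call(controlnet, prompt_embeds, pooled_prompt_embeds, time_ids):
    batch = prompt_embeds.shape[0]
    latents = torch.randn(batch, 4, 128, 128)
    timestep = torch.tensor(500)

    # one image-space condition in slot 0; the five unused slots hold placeholders that
    # forward() never touches because their control_type flags are zero
    cond_image = torch.randn(batch, 3, 1024, 1024)
    controlnet_cond_list = [cond_image, 0, 0, 0, 0, 0]
    control_type = torch.zeros(batch, 6)
    control_type[:, 0] = 1.0

    added_cond_kwargs = {
        "text_embeds": pooled_prompt_embeds,
        "time_ids": time_ids,
        "control_type": control_type,
    }
    return controlnet(
        latents,
        timestep,
        encoder_hidden_states=prompt_embeds,
        controlnet_cond_list=controlnet_cond_list,
        conditioning_scale=1.0,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=True,
    )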