ljp committed on
Commit 9bc8f3c
1 Parent(s): 5838577

Create controlnet_sd3.py

Files changed (1)
  1. controlnet_sd3.py +552 -0
controlnet_sd3.py ADDED
@@ -0,0 +1,552 @@
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from packaging import version
+
+ import diffusers
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
+ from diffusers.models.attention import JointTransformerBlock
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     is_torch_version,
+     logging,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+ from diffusers.models.controlnet import BaseOutput, zero_module
+ from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
+ from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
+ from torch.nn import functional as F
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ class ControlNetConditioningEmbedding(nn.Module):
+     """
+     Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+     [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+     training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+     convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+     (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+     model) to encode image-space conditions ... into feature maps ..."
+     """
+
+     def __init__(
+         self,
+         conditioning_embedding_channels: int,
+         conditioning_channels: int = 3,
+         block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+     ):
+         super().__init__()
+
+         self.conv_in = nn.Conv2d(
+             conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
+         )
+
+         self.blocks = nn.ModuleList([])
+
+         for i in range(len(block_out_channels) - 1):
+             channel_in = block_out_channels[i]
+             channel_out = block_out_channels[i + 1]
+             self.blocks.append(
+                 nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)
+             )
+             self.blocks.append(
+                 nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)
+             )
+
+         self.conv_out = zero_module(
+             nn.Conv2d(
+                 block_out_channels[-1],
+                 conditioning_embedding_channels,
+                 kernel_size=3,
+                 padding=1,
+             )
+         )
+
+     def forward(self, conditioning):
+         embedding = self.conv_in(conditioning)
+         embedding = F.silu(embedding)
+
+         for block in self.blocks:
+             embedding = block(embedding)
+             embedding = F.silu(embedding)
+
+         embedding = self.conv_out(embedding)
+
+         return embedding
+
+
+ @dataclass
+ class SD3ControlNetOutput(BaseOutput):
+     controlnet_block_samples: Tuple[torch.Tensor]
+
+
+ class SD3ControlNetModel(
+     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
+ ):
+     _supports_gradient_checkpointing = True
+
+     @register_to_config
+     def __init__(
+         self,
+         sample_size: int = 128,
+         patch_size: int = 2,
+         in_channels: int = 16,
+         num_layers: int = 18,
+         attention_head_dim: int = 64,
+         num_attention_heads: int = 18,
+         joint_attention_dim: int = 4096,
+         caption_projection_dim: int = 1152,
+         pooled_projection_dim: int = 2048,
+         out_channels: int = 16,
+         pos_embed_max_size: int = 96,
+         conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
+             16,
+             32,
+             96,
+             256,
+         ),
+         conditioning_channels: int = 3,
+     ):
+         """
+         conditioning_channels: number of channels of the condition image (pixel space)
+         conditioning_embedding_out_channels: intermediate channel widths of the conditioning embedder
+         """
+         super().__init__()
+         default_out_channels = in_channels
+         self.out_channels = (
+             out_channels if out_channels is not None else default_out_channels
+         )
+         self.inner_dim = num_attention_heads * attention_head_dim
+
+         self.pos_embed = PatchEmbed(
+             height=sample_size,
+             width=sample_size,
+             patch_size=patch_size,
+             in_channels=in_channels,
+             embed_dim=self.inner_dim,
+             pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
+         )
+         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+             embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+         )
+         self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+         # control net conditioning embedding
+         # self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+         #     conditioning_embedding_channels=default_out_channels,
+         #     block_out_channels=conditioning_embedding_out_channels,
+         #     conditioning_channels=conditioning_channels,
+         # )
+
+         # `attention_head_dim` is doubled to account for the mixing.
+         # It needs to be crafted when we get the actual checkpoints.
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 JointTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     # diffusers >= 0.30 expects the per-head dim here; older releases expected the full inner_dim.
+                     attention_head_dim=attention_head_dim
+                     if version.parse(diffusers.__version__) >= version.parse("0.30.0.dev0")
+                     else self.inner_dim,
+                     context_pre_only=False,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         # controlnet_blocks
+         self.controlnet_blocks = nn.ModuleList([])
+         for _ in range(len(self.transformer_blocks)):
+             controlnet_block = zero_module(nn.Linear(self.inner_dim, self.inner_dim))
+             self.controlnet_blocks.append(controlnet_block)
+
+         # control condition embedding
+         pos_embed_cond = PatchEmbed(
+             height=sample_size,
+             width=sample_size,
+             patch_size=patch_size,
+             in_channels=in_channels + 1,
+             embed_dim=self.inner_dim,
+             pos_embed_type=None,
+         )
+         # pos_embed_cond = nn.Linear(in_channels + 1, self.inner_dim)
+         self.pos_embed_cond = zero_module(pos_embed_cond)
+
+         self.gradient_checkpointing = False
+
+     # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+     def enable_forward_chunking(
+         self, chunk_size: Optional[int] = None, dim: int = 0
+     ) -> None:
+         """
+         Enables [feed forward
+         chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) in the feed-forward layers.
+
+         Parameters:
+             chunk_size (`int`, *optional*):
+                 The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                 over each tensor of dim=`dim`.
+             dim (`int`, *optional*, defaults to `0`):
+                 The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                 or dim=1 (sequence length).
+         """
+         if dim not in [0, 1]:
+             raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+         # By default chunk size is 1
+         chunk_size = chunk_size or 1
+
+         def fn_recursive_feed_forward(
+             module: torch.nn.Module, chunk_size: int, dim: int
+         ):
+             if hasattr(module, "set_chunk_feed_forward"):
+                 module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+             for child in module.children():
+                 fn_recursive_feed_forward(child, chunk_size, dim)
+
+         for module in self.children():
+             fn_recursive_feed_forward(module, chunk_size, dim)
+
+     @property
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+     def attn_processors(self) -> Dict[str, AttentionProcessor]:
+         r"""
+         Returns:
+             `dict` of attention processors: A dictionary containing all attention processors used in the model,
+             indexed by weight name.
+         """
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(
+             name: str,
+             module: torch.nn.Module,
+             processors: Dict[str, AttentionProcessor],
+         ):
+             if hasattr(module, "get_processor"):
+                 processors[f"{name}.processor"] = module.get_processor(
+                     return_deprecated_lora=True
+                 )
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+     def set_attn_processor(
+         self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
+     ):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                 for **all** `Attention` layers.
+
+                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                 processor. This is strongly recommended when setting trainable attention processors.
+
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+     def fuse_qkv_projections(self):
+         """
+         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+         are fused. For cross-attention modules, key and value projection matrices are fused.
+
+         <Tip warning={true}>
+
+         This API is 🧪 experimental.
+
+         </Tip>
+         """
+         self.original_attn_processors = None
+
+         for _, attn_processor in self.attn_processors.items():
+             if "Added" in str(attn_processor.__class__.__name__):
+                 raise ValueError(
+                     "`fuse_qkv_projections()` is not supported for models having added KV projections."
+                 )
+
+         self.original_attn_processors = self.attn_processors
+
+         for module in self.modules():
+             if isinstance(module, Attention):
+                 module.fuse_projections(fuse=True)
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+     def unfuse_qkv_projections(self):
+         """Disables the fused QKV projection if enabled.
+
+         <Tip warning={true}>
+
+         This API is 🧪 experimental.
+
+         </Tip>
+
+         """
+         if self.original_attn_processors is not None:
+             self.set_attn_processor(self.original_attn_processors)
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if hasattr(module, "gradient_checkpointing"):
+             module.gradient_checkpointing = value
+
+     @classmethod
+     def from_transformer(
+         cls, transformer, num_layers=None, load_weights_from_transformer=True
+     ):
+         config = transformer.config
+         config["num_layers"] = num_layers or config.num_layers
+         controlnet = cls(**config)
+
+         if load_weights_from_transformer:
+             controlnet.pos_embed.load_state_dict(
+                 transformer.pos_embed.state_dict(), strict=False
+             )
+             controlnet.time_text_embed.load_state_dict(
+                 transformer.time_text_embed.state_dict(), strict=False
+             )
+             controlnet.context_embedder.load_state_dict(
+                 transformer.context_embedder.state_dict(), strict=False
+             )
+             controlnet.transformer_blocks.load_state_dict(
+                 transformer.transformer_blocks.state_dict(), strict=False
+             )
+
+         return controlnet
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         controlnet_cond: torch.Tensor,
+         conditioning_scale: float = 1.0,
+         encoder_hidden_states: torch.FloatTensor = None,
+         pooled_projections: torch.FloatTensor = None,
+         timestep: torch.LongTensor = None,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = True,
+     ) -> Union[Tuple, SD3ControlNetOutput]:
+         """
+         The [`SD3ControlNetModel`] forward method.
+
+         Args:
+             hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                 Input `hidden_states`.
+             controlnet_cond (`torch.Tensor` of shape `(batch_size, in_channels + 1, height, width)`):
+                 The conditioning input in latent space; it is patch-embedded by `pos_embed_cond` and added to
+                 `hidden_states`.
+             conditioning_scale (`float`, defaults to `1.0`):
+                 The scale factor for ControlNet outputs.
+             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+             pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+                 from the embeddings of input conditions.
+             timestep (`torch.LongTensor`):
+                 Used to indicate denoising step.
+             joint_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return an [`SD3ControlNetOutput`] instead of a plain tuple.
+
+         Returns:
+             If `return_dict` is True, an [`SD3ControlNetOutput`] is returned, otherwise a `tuple` whose first element
+             is the list of scaled per-block control residuals.
+         """
+         if joint_attention_kwargs is not None:
+             joint_attention_kwargs = joint_attention_kwargs.copy()
+             lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+         else:
+             lora_scale = 1.0
+
+         if USE_PEFT_BACKEND:
+             # weight the lora layers by setting `lora_scale` for each PEFT layer
+             scale_lora_layers(self, lora_scale)
+         else:
+             if (
+                 joint_attention_kwargs is not None
+                 and joint_attention_kwargs.get("scale", None) is not None
+             ):
+                 logger.warning(
+                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                 )
+
+         height, width = hidden_states.shape[-2:]
+
+         hidden_states = self.pos_embed(
+             hidden_states
+         )  # takes care of adding positional embeddings too. b,c,H,W -> b, N, C
+         temb = self.time_text_embed(timestep, pooled_projections)
+         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+         # add condition
+         hidden_states = hidden_states + self.pos_embed_cond(controlnet_cond)
+
+         block_res_samples = ()
+
+         for block in self.transformer_blocks:
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = (
+                     {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 )
+                 # the block returns both streams; unpack them as in the non-checkpointed branch
+                 encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     **ckpt_kwargs,
+                 )
+
+             else:
+                 encoder_hidden_states, hidden_states = block(
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     temb=temb,
+                 )
+
+             block_res_samples = block_res_samples + (hidden_states,)
+
+         controlnet_block_res_samples = ()
+         for block_res_sample, controlnet_block in zip(
+             block_res_samples, self.controlnet_blocks
+         ):
+             block_res_sample = controlnet_block(block_res_sample)
+             controlnet_block_res_samples = controlnet_block_res_samples + (
+                 block_res_sample,
+             )
+
+         # scaling
+         controlnet_block_res_samples = [
+             sample * conditioning_scale for sample in controlnet_block_res_samples
+         ]
+
+         if USE_PEFT_BACKEND:
+             # remove `lora_scale` from each PEFT layer
+             unscale_lora_layers(self, lora_scale)
+
+         if not return_dict:
+             return (controlnet_block_res_samples,)
+
+         return SD3ControlNetOutput(
+             controlnet_block_samples=controlnet_block_res_samples
+         )
+
+     def invert_copy_paste(self, controlnet_block_samples):
+         # Append the reversed list of residuals, doubling the number of control signals.
+         controlnet_block_samples = controlnet_block_samples + controlnet_block_samples[::-1]
+         return controlnet_block_samples
+
+ class SD3MultiControlNetModel(ModelMixin):
+     r"""
+     `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
+
+     This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
+     compatible with `SD3ControlNetModel`.
+
+     Args:
+         controlnets (`List[SD3ControlNetModel]`):
+             Provides additional conditioning to the transformer during the denoising process. You must set multiple
+             `SD3ControlNetModel` as a list.
+     """
+
+     def __init__(self, controlnets):
+         super().__init__()
+         self.nets = nn.ModuleList(controlnets)
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         controlnet_cond: List[torch.Tensor],
+         conditioning_scale: List[float],
+         pooled_projections: torch.FloatTensor,
+         encoder_hidden_states: torch.FloatTensor = None,
+         timestep: torch.LongTensor = None,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = True,
+     ) -> Union[SD3ControlNetOutput, Tuple]:
+         for i, (image, scale, controlnet) in enumerate(
+             zip(controlnet_cond, conditioning_scale, self.nets)
+         ):
+             block_samples = controlnet(
+                 hidden_states=hidden_states,
+                 timestep=timestep,
+                 encoder_hidden_states=encoder_hidden_states,
+                 pooled_projections=pooled_projections,
+                 controlnet_cond=image,
+                 conditioning_scale=scale,
+                 joint_attention_kwargs=joint_attention_kwargs,
+                 return_dict=return_dict,
+             )
+
+             # merge samples
+             if i == 0:
+                 control_block_samples = block_samples
+             else:
+                 control_block_samples = [
+                     control_block_sample + block_sample
+                     for control_block_sample, block_sample in zip(
+                         control_block_samples[0], block_samples[0]
+                     )
+                 ]
+                 control_block_samples = (tuple(control_block_samples),)
+
+         return control_block_samples
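
For reference, a minimal smoke-test sketch of the model defined in this commit. It assumes the file above is importable as `controlnet_sd3` and that a diffusers release with SD3 support (roughly 0.29 or newer) is installed; the tiny config values and random tensors are illustrative stand-ins, not the configuration of any released checkpoint. When a full SD3 transformer is available, `SD3ControlNetModel.from_transformer(...)` is the intended way to initialize the control branch instead.

import torch

from controlnet_sd3 import SD3ControlNetModel  # assumed import path for the file above

# Tiny illustrative config: inner_dim = num_attention_heads * attention_head_dim = 256,
# and caption_projection_dim must equal inner_dim for the joint blocks to line up.
controlnet = SD3ControlNetModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=64,
    num_attention_heads=4,
    joint_attention_dim=64,
    caption_projection_dim=256,
    pooled_projection_dim=64,
    out_channels=16,
    pos_embed_max_size=32,
)

batch = 1
latents = torch.randn(batch, 16, 32, 32)        # noisy SD3 latents, (b, c, H, W)
cond = torch.randn(batch, 16 + 1, 32, 32)       # latent-space condition plus one extra channel
prompt_embeds = torch.randn(batch, 77, 64)      # token embeddings of width joint_attention_dim
pooled = torch.randn(batch, 64)                 # pooled embedding of width pooled_projection_dim
timestep = torch.tensor([500], dtype=torch.long)

out = controlnet(
    hidden_states=latents,
    controlnet_cond=cond,
    conditioning_scale=1.0,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled,
    timestep=timestep,
)

# One residual per transformer block, each of shape (batch, num_patches, inner_dim) = (1, 256, 256).
for sample in out.controlnet_block_samples:
    print(sample.shape)

Because `pos_embed_cond` and the `controlnet_blocks` are wrapped in `zero_module`, the residuals printed here are all zeros until the model is trained.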