atlury committed on
Commit 16a860b · verified · 1 Parent(s): fc33955

Delete src

src/attentionhacked_garmnet.py DELETED
@@ -1,670 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Any, Dict, Optional
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
- from diffusers.utils import USE_PEFT_BACKEND
21
- from diffusers.utils.torch_utils import maybe_allow_in_graph
22
- from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
23
- from diffusers.models.attention_processor import Attention
24
- from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
- from diffusers.models.lora import LoRACompatibleLinear
26
- from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
27
-
28
-
29
- def _chunked_feed_forward(
30
- ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
31
- ):
32
- # "feed_forward_chunk_size" can be used to save memory
33
- if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
- raise ValueError(
35
- f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
- )
37
-
38
- num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
- if lora_scale is None:
40
- ff_output = torch.cat(
41
- [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
42
- dim=chunk_dim,
43
- )
44
- else:
45
- # TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete
46
- ff_output = torch.cat(
47
- [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
48
- dim=chunk_dim,
49
- )
50
-
51
- return ff_output
52
-
53
-
54
- @maybe_allow_in_graph
55
- class GatedSelfAttentionDense(nn.Module):
56
- r"""
57
- A gated self-attention dense layer that combines visual features and object features.
58
-
59
- Parameters:
60
- query_dim (`int`): The number of channels in the query.
61
- context_dim (`int`): The number of channels in the context.
62
- n_heads (`int`): The number of heads to use for attention.
63
- d_head (`int`): The number of channels in each head.
64
- """
65
-
66
- def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
67
- super().__init__()
68
-
69
- # we need a linear projection since we need cat visual feature and obj feature
70
- self.linear = nn.Linear(context_dim, query_dim)
71
-
72
- self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
73
- self.ff = FeedForward(query_dim, activation_fn="geglu")
74
-
75
- self.norm1 = nn.LayerNorm(query_dim)
76
- self.norm2 = nn.LayerNorm(query_dim)
77
-
78
- self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
79
- self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
80
-
81
- self.enabled = True
82
-
83
- def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
84
- if not self.enabled:
85
- return x
86
-
87
- n_visual = x.shape[1]
88
- objs = self.linear(objs)
89
-
90
- x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
91
- x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
92
-
93
- return x
94
-
95
-
96
- @maybe_allow_in_graph
97
- class BasicTransformerBlock(nn.Module):
98
- r"""
99
- A basic Transformer block.
100
-
101
- Parameters:
102
- dim (`int`): The number of channels in the input and output.
103
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
104
- attention_head_dim (`int`): The number of channels in each head.
105
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
106
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
107
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
108
- num_embeds_ada_norm (:
109
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
110
- attention_bias (:
111
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
112
- only_cross_attention (`bool`, *optional*):
113
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
114
- double_self_attention (`bool`, *optional*):
115
- Whether to use two self-attention layers. In this case no cross attention layers are used.
116
- upcast_attention (`bool`, *optional*):
117
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
118
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
119
- Whether to use learnable elementwise affine parameters for normalization.
120
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
121
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
122
- final_dropout (`bool` *optional*, defaults to False):
123
- Whether to apply a final dropout after the last feed-forward layer.
124
- attention_type (`str`, *optional*, defaults to `"default"`):
125
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
126
- positional_embeddings (`str`, *optional*, defaults to `None`):
127
- The type of positional embeddings to apply to.
128
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
129
- The maximum number of positional embeddings to apply.
130
- """
131
-
132
- def __init__(
133
- self,
134
- dim: int,
135
- num_attention_heads: int,
136
- attention_head_dim: int,
137
- dropout=0.0,
138
- cross_attention_dim: Optional[int] = None,
139
- activation_fn: str = "geglu",
140
- num_embeds_ada_norm: Optional[int] = None,
141
- attention_bias: bool = False,
142
- only_cross_attention: bool = False,
143
- double_self_attention: bool = False,
144
- upcast_attention: bool = False,
145
- norm_elementwise_affine: bool = True,
146
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
147
- norm_eps: float = 1e-5,
148
- final_dropout: bool = False,
149
- attention_type: str = "default",
150
- positional_embeddings: Optional[str] = None,
151
- num_positional_embeddings: Optional[int] = None,
152
- ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
153
- ada_norm_bias: Optional[int] = None,
154
- ff_inner_dim: Optional[int] = None,
155
- ff_bias: bool = True,
156
- attention_out_bias: bool = True,
157
- ):
158
- super().__init__()
159
- self.only_cross_attention = only_cross_attention
160
-
161
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
162
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
163
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
164
- self.use_layer_norm = norm_type == "layer_norm"
165
- self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
166
-
167
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
168
- raise ValueError(
169
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
170
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
171
- )
172
-
173
- if positional_embeddings and (num_positional_embeddings is None):
174
- raise ValueError(
175
- "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
176
- )
177
-
178
- if positional_embeddings == "sinusoidal":
179
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
180
- else:
181
- self.pos_embed = None
182
-
183
- # Define 3 blocks. Each block has its own normalization layer.
184
- # 1. Self-Attn
185
- if self.use_ada_layer_norm:
186
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
187
- elif self.use_ada_layer_norm_zero:
188
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
189
- elif self.use_ada_layer_norm_continuous:
190
- self.norm1 = AdaLayerNormContinuous(
191
- dim,
192
- ada_norm_continous_conditioning_embedding_dim,
193
- norm_elementwise_affine,
194
- norm_eps,
195
- ada_norm_bias,
196
- "rms_norm",
197
- )
198
- else:
199
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
200
-
201
- self.attn1 = Attention(
202
- query_dim=dim,
203
- heads=num_attention_heads,
204
- dim_head=attention_head_dim,
205
- dropout=dropout,
206
- bias=attention_bias,
207
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
208
- upcast_attention=upcast_attention,
209
- out_bias=attention_out_bias,
210
- )
211
-
212
- # 2. Cross-Attn
213
- if cross_attention_dim is not None or double_self_attention:
214
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
215
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
216
- # the second cross attention block.
217
- if self.use_ada_layer_norm:
218
- self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
219
- elif self.use_ada_layer_norm_continuous:
220
- self.norm2 = AdaLayerNormContinuous(
221
- dim,
222
- ada_norm_continous_conditioning_embedding_dim,
223
- norm_elementwise_affine,
224
- norm_eps,
225
- ada_norm_bias,
226
- "rms_norm",
227
- )
228
- else:
229
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
230
-
231
- self.attn2 = Attention(
232
- query_dim=dim,
233
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
234
- heads=num_attention_heads,
235
- dim_head=attention_head_dim,
236
- dropout=dropout,
237
- bias=attention_bias,
238
- upcast_attention=upcast_attention,
239
- out_bias=attention_out_bias,
240
- ) # is self-attn if encoder_hidden_states is none
241
- else:
242
- self.norm2 = None
243
- self.attn2 = None
244
-
245
- # 3. Feed-forward
246
- if self.use_ada_layer_norm_continuous:
247
- self.norm3 = AdaLayerNormContinuous(
248
- dim,
249
- ada_norm_continous_conditioning_embedding_dim,
250
- norm_elementwise_affine,
251
- norm_eps,
252
- ada_norm_bias,
253
- "layer_norm",
254
- )
255
- elif not self.use_ada_layer_norm_single:
256
- self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
257
-
258
- self.ff = FeedForward(
259
- dim,
260
- dropout=dropout,
261
- activation_fn=activation_fn,
262
- final_dropout=final_dropout,
263
- inner_dim=ff_inner_dim,
264
- bias=ff_bias,
265
- )
266
-
267
- # 4. Fuser
268
- if attention_type == "gated" or attention_type == "gated-text-image":
269
- self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
270
-
271
- # 5. Scale-shift for PixArt-Alpha.
272
- if self.use_ada_layer_norm_single:
273
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
274
-
275
- # let chunk size default to None
276
- self._chunk_size = None
277
- self._chunk_dim = 0
278
-
279
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
280
- # Sets chunk feed-forward
281
- self._chunk_size = chunk_size
282
- self._chunk_dim = dim
283
-
284
- def forward(
285
- self,
286
- hidden_states: torch.FloatTensor,
287
- attention_mask: Optional[torch.FloatTensor] = None,
288
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
289
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
290
- timestep: Optional[torch.LongTensor] = None,
291
- cross_attention_kwargs: Dict[str, Any] = None,
292
- class_labels: Optional[torch.LongTensor] = None,
293
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
294
- ) -> torch.FloatTensor:
295
- # Notice that normalization is always applied before the real computation in the following blocks.
296
- # 0. Self-Attention
297
- batch_size = hidden_states.shape[0]
298
- if self.use_ada_layer_norm:
299
- norm_hidden_states = self.norm1(hidden_states, timestep)
300
- elif self.use_ada_layer_norm_zero:
301
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
302
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
303
- )
304
- elif self.use_layer_norm:
305
- norm_hidden_states = self.norm1(hidden_states)
306
- elif self.use_ada_layer_norm_continuous:
307
- norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
308
- elif self.use_ada_layer_norm_single:
309
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
310
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
311
- ).chunk(6, dim=1)
312
- norm_hidden_states = self.norm1(hidden_states)
313
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
314
- norm_hidden_states = norm_hidden_states.squeeze(1)
315
- else:
316
- raise ValueError("Incorrect norm used")
317
-
318
- if self.pos_embed is not None:
319
- norm_hidden_states = self.pos_embed(norm_hidden_states)
320
-
321
- garment_features = []
322
- garment_features.append(norm_hidden_states)
323
-
324
- # 1. Retrieve lora scale.
325
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
326
-
327
- # 2. Prepare GLIGEN inputs
328
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
329
- gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
330
-
331
- attn_output = self.attn1(
332
- norm_hidden_states,
333
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
334
- attention_mask=attention_mask,
335
- **cross_attention_kwargs,
336
- )
337
- if self.use_ada_layer_norm_zero:
338
- attn_output = gate_msa.unsqueeze(1) * attn_output
339
- elif self.use_ada_layer_norm_single:
340
- attn_output = gate_msa * attn_output
341
-
342
- hidden_states = attn_output + hidden_states
343
- if hidden_states.ndim == 4:
344
- hidden_states = hidden_states.squeeze(1)
345
-
346
- # 2.5 GLIGEN Control
347
- if gligen_kwargs is not None:
348
- hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
349
-
350
- # 3. Cross-Attention
351
- if self.attn2 is not None:
352
- if self.use_ada_layer_norm:
353
- norm_hidden_states = self.norm2(hidden_states, timestep)
354
- elif self.use_ada_layer_norm_zero or self.use_layer_norm:
355
- norm_hidden_states = self.norm2(hidden_states)
356
- elif self.use_ada_layer_norm_single:
357
- # For PixArt norm2 isn't applied here:
358
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
359
- norm_hidden_states = hidden_states
360
- elif self.use_ada_layer_norm_continuous:
361
- norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
362
- else:
363
- raise ValueError("Incorrect norm")
364
-
365
- if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
366
- norm_hidden_states = self.pos_embed(norm_hidden_states)
367
-
368
- attn_output = self.attn2(
369
- norm_hidden_states,
370
- encoder_hidden_states=encoder_hidden_states,
371
- attention_mask=encoder_attention_mask,
372
- **cross_attention_kwargs,
373
- )
374
- hidden_states = attn_output + hidden_states
375
-
376
- # 4. Feed-forward
377
- if self.use_ada_layer_norm_continuous:
378
- norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
379
- elif not self.use_ada_layer_norm_single:
380
- norm_hidden_states = self.norm3(hidden_states)
381
-
382
- if self.use_ada_layer_norm_zero:
383
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
384
-
385
- if self.use_ada_layer_norm_single:
386
- norm_hidden_states = self.norm2(hidden_states)
387
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
388
-
389
- if self._chunk_size is not None:
390
- # "feed_forward_chunk_size" can be used to save memory
391
- ff_output = _chunked_feed_forward(
392
- self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
393
- )
394
- else:
395
- ff_output = self.ff(norm_hidden_states, scale=lora_scale)
396
-
397
- if self.use_ada_layer_norm_zero:
398
- ff_output = gate_mlp.unsqueeze(1) * ff_output
399
- elif self.use_ada_layer_norm_single:
400
- ff_output = gate_mlp * ff_output
401
-
402
- hidden_states = ff_output + hidden_states
403
- if hidden_states.ndim == 4:
404
- hidden_states = hidden_states.squeeze(1)
405
-
406
- return hidden_states, garment_features
407
-
408
-
409
- @maybe_allow_in_graph
410
- class TemporalBasicTransformerBlock(nn.Module):
411
- r"""
412
- A basic Transformer block for video like data.
413
-
414
- Parameters:
415
- dim (`int`): The number of channels in the input and output.
416
- time_mix_inner_dim (`int`): The number of channels for temporal attention.
417
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
418
- attention_head_dim (`int`): The number of channels in each head.
419
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
420
- """
421
-
422
- def __init__(
423
- self,
424
- dim: int,
425
- time_mix_inner_dim: int,
426
- num_attention_heads: int,
427
- attention_head_dim: int,
428
- cross_attention_dim: Optional[int] = None,
429
- ):
430
- super().__init__()
431
- self.is_res = dim == time_mix_inner_dim
432
-
433
- self.norm_in = nn.LayerNorm(dim)
434
-
435
- # Define 3 blocks. Each block has its own normalization layer.
436
- # 1. Self-Attn
437
- self.norm_in = nn.LayerNorm(dim)
438
- self.ff_in = FeedForward(
439
- dim,
440
- dim_out=time_mix_inner_dim,
441
- activation_fn="geglu",
442
- )
443
-
444
- self.norm1 = nn.LayerNorm(time_mix_inner_dim)
445
- self.attn1 = Attention(
446
- query_dim=time_mix_inner_dim,
447
- heads=num_attention_heads,
448
- dim_head=attention_head_dim,
449
- cross_attention_dim=None,
450
- )
451
-
452
- # 2. Cross-Attn
453
- if cross_attention_dim is not None:
454
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
455
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
456
- # the second cross attention block.
457
- self.norm2 = nn.LayerNorm(time_mix_inner_dim)
458
- self.attn2 = Attention(
459
- query_dim=time_mix_inner_dim,
460
- cross_attention_dim=cross_attention_dim,
461
- heads=num_attention_heads,
462
- dim_head=attention_head_dim,
463
- ) # is self-attn if encoder_hidden_states is none
464
- else:
465
- self.norm2 = None
466
- self.attn2 = None
467
-
468
- # 3. Feed-forward
469
- self.norm3 = nn.LayerNorm(time_mix_inner_dim)
470
- self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
471
-
472
- # let chunk size default to None
473
- self._chunk_size = None
474
- self._chunk_dim = None
475
-
476
- def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
477
- # Sets chunk feed-forward
478
- self._chunk_size = chunk_size
479
- # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
480
- self._chunk_dim = 1
481
-
482
- def forward(
483
- self,
484
- hidden_states: torch.FloatTensor,
485
- num_frames: int,
486
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
487
- ) -> torch.FloatTensor:
488
- # Notice that normalization is always applied before the real computation in the following blocks.
489
- # 0. Self-Attention
490
- batch_size = hidden_states.shape[0]
491
-
492
- batch_frames, seq_length, channels = hidden_states.shape
493
- batch_size = batch_frames // num_frames
494
-
495
- hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
496
- hidden_states = hidden_states.permute(0, 2, 1, 3)
497
- hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
498
-
499
- residual = hidden_states
500
- hidden_states = self.norm_in(hidden_states)
501
-
502
- if self._chunk_size is not None:
503
- hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
504
- else:
505
- hidden_states = self.ff_in(hidden_states)
506
-
507
- if self.is_res:
508
- hidden_states = hidden_states + residual
509
-
510
- norm_hidden_states = self.norm1(hidden_states)
511
- attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
512
- hidden_states = attn_output + hidden_states
513
-
514
- # 3. Cross-Attention
515
- if self.attn2 is not None:
516
- norm_hidden_states = self.norm2(hidden_states)
517
- attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
518
- hidden_states = attn_output + hidden_states
519
-
520
- # 4. Feed-forward
521
- norm_hidden_states = self.norm3(hidden_states)
522
-
523
- if self._chunk_size is not None:
524
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
525
- else:
526
- ff_output = self.ff(norm_hidden_states)
527
-
528
- if self.is_res:
529
- hidden_states = ff_output + hidden_states
530
- else:
531
- hidden_states = ff_output
532
-
533
- hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
534
- hidden_states = hidden_states.permute(0, 2, 1, 3)
535
- hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
536
-
537
- return hidden_states
538
-
539
-
540
- class SkipFFTransformerBlock(nn.Module):
541
- def __init__(
542
- self,
543
- dim: int,
544
- num_attention_heads: int,
545
- attention_head_dim: int,
546
- kv_input_dim: int,
547
- kv_input_dim_proj_use_bias: bool,
548
- dropout=0.0,
549
- cross_attention_dim: Optional[int] = None,
550
- attention_bias: bool = False,
551
- attention_out_bias: bool = True,
552
- ):
553
- super().__init__()
554
- if kv_input_dim != dim:
555
- self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
556
- else:
557
- self.kv_mapper = None
558
-
559
- self.norm1 = RMSNorm(dim, 1e-06)
560
-
561
- self.attn1 = Attention(
562
- query_dim=dim,
563
- heads=num_attention_heads,
564
- dim_head=attention_head_dim,
565
- dropout=dropout,
566
- bias=attention_bias,
567
- cross_attention_dim=cross_attention_dim,
568
- out_bias=attention_out_bias,
569
- )
570
-
571
- self.norm2 = RMSNorm(dim, 1e-06)
572
-
573
- self.attn2 = Attention(
574
- query_dim=dim,
575
- cross_attention_dim=cross_attention_dim,
576
- heads=num_attention_heads,
577
- dim_head=attention_head_dim,
578
- dropout=dropout,
579
- bias=attention_bias,
580
- out_bias=attention_out_bias,
581
- )
582
-
583
- def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
584
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
585
-
586
- if self.kv_mapper is not None:
587
- encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
588
-
589
- norm_hidden_states = self.norm1(hidden_states)
590
-
591
- attn_output = self.attn1(
592
- norm_hidden_states,
593
- encoder_hidden_states=encoder_hidden_states,
594
- **cross_attention_kwargs,
595
- )
596
-
597
- hidden_states = attn_output + hidden_states
598
-
599
- norm_hidden_states = self.norm2(hidden_states)
600
-
601
- attn_output = self.attn2(
602
- norm_hidden_states,
603
- encoder_hidden_states=encoder_hidden_states,
604
- **cross_attention_kwargs,
605
- )
606
-
607
- hidden_states = attn_output + hidden_states
608
-
609
- return hidden_states
610
-
611
-
612
- class FeedForward(nn.Module):
613
- r"""
614
- A feed-forward layer.
615
-
616
- Parameters:
617
- dim (`int`): The number of channels in the input.
618
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
619
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
620
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
621
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
622
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
623
- bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
624
- """
625
-
626
- def __init__(
627
- self,
628
- dim: int,
629
- dim_out: Optional[int] = None,
630
- mult: int = 4,
631
- dropout: float = 0.0,
632
- activation_fn: str = "geglu",
633
- final_dropout: bool = False,
634
- inner_dim=None,
635
- bias: bool = True,
636
- ):
637
- super().__init__()
638
- if inner_dim is None:
639
- inner_dim = int(dim * mult)
640
- dim_out = dim_out if dim_out is not None else dim
641
- linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
642
-
643
- if activation_fn == "gelu":
644
- act_fn = GELU(dim, inner_dim, bias=bias)
645
- if activation_fn == "gelu-approximate":
646
- act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
647
- elif activation_fn == "geglu":
648
- act_fn = GEGLU(dim, inner_dim, bias=bias)
649
- elif activation_fn == "geglu-approximate":
650
- act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
651
-
652
- self.net = nn.ModuleList([])
653
- # project in
654
- self.net.append(act_fn)
655
- # project dropout
656
- self.net.append(nn.Dropout(dropout))
657
- # project out
658
- self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
659
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
660
- if final_dropout:
661
- self.net.append(nn.Dropout(dropout))
662
-
663
- def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
664
- compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
665
- for module in self.net:
666
- if isinstance(module, compatible_cls):
667
- hidden_states = module(hidden_states, scale)
668
- else:
669
- hidden_states = module(hidden_states)
670
- return hidden_states
src/attentionhacked_tryon.py DELETED
@@ -1,679 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Any, Dict, Optional
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
- from diffusers.utils import USE_PEFT_BACKEND
21
- from diffusers.utils.torch_utils import maybe_allow_in_graph
22
- from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
23
- from diffusers.models.attention_processor import Attention
24
- from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
- from diffusers.models.lora import LoRACompatibleLinear
26
- from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
27
-
28
-
29
- def _chunked_feed_forward(
30
- ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
31
- ):
32
- # "feed_forward_chunk_size" can be used to save memory
33
- if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
- raise ValueError(
35
- f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
- )
37
-
38
- num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
- if lora_scale is None:
40
- ff_output = torch.cat(
41
- [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
42
- dim=chunk_dim,
43
- )
44
- else:
45
- # TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete
46
- ff_output = torch.cat(
47
- [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
48
- dim=chunk_dim,
49
- )
50
-
51
- return ff_output
52
-
53
-
54
- @maybe_allow_in_graph
55
- class GatedSelfAttentionDense(nn.Module):
56
- r"""
57
- A gated self-attention dense layer that combines visual features and object features.
58
-
59
- Parameters:
60
- query_dim (`int`): The number of channels in the query.
61
- context_dim (`int`): The number of channels in the context.
62
- n_heads (`int`): The number of heads to use for attention.
63
- d_head (`int`): The number of channels in each head.
64
- """
65
-
66
- def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
67
- super().__init__()
68
-
69
- # we need a linear projection since we need cat visual feature and obj feature
70
- self.linear = nn.Linear(context_dim, query_dim)
71
-
72
- self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
73
- self.ff = FeedForward(query_dim, activation_fn="geglu")
74
-
75
- self.norm1 = nn.LayerNorm(query_dim)
76
- self.norm2 = nn.LayerNorm(query_dim)
77
-
78
- self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
79
- self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
80
-
81
- self.enabled = True
82
-
83
- def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
84
- if not self.enabled:
85
- return x
86
-
87
- n_visual = x.shape[1]
88
- objs = self.linear(objs)
89
-
90
- x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
91
- x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
92
-
93
- return x
94
-
95
-
96
- @maybe_allow_in_graph
97
- class BasicTransformerBlock(nn.Module):
98
- r"""
99
- A basic Transformer block.
100
-
101
- Parameters:
102
- dim (`int`): The number of channels in the input and output.
103
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
104
- attention_head_dim (`int`): The number of channels in each head.
105
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
106
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
107
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
108
- num_embeds_ada_norm (:
109
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
110
- attention_bias (:
111
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
112
- only_cross_attention (`bool`, *optional*):
113
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
114
- double_self_attention (`bool`, *optional*):
115
- Whether to use two self-attention layers. In this case no cross attention layers are used.
116
- upcast_attention (`bool`, *optional*):
117
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
118
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
119
- Whether to use learnable elementwise affine parameters for normalization.
120
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
121
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
122
- final_dropout (`bool` *optional*, defaults to False):
123
- Whether to apply a final dropout after the last feed-forward layer.
124
- attention_type (`str`, *optional*, defaults to `"default"`):
125
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
126
- positional_embeddings (`str`, *optional*, defaults to `None`):
127
- The type of positional embeddings to apply to.
128
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
129
- The maximum number of positional embeddings to apply.
130
- """
131
-
132
- def __init__(
133
- self,
134
- dim: int,
135
- num_attention_heads: int,
136
- attention_head_dim: int,
137
- dropout=0.0,
138
- cross_attention_dim: Optional[int] = None,
139
- activation_fn: str = "geglu",
140
- num_embeds_ada_norm: Optional[int] = None,
141
- attention_bias: bool = False,
142
- only_cross_attention: bool = False,
143
- double_self_attention: bool = False,
144
- upcast_attention: bool = False,
145
- norm_elementwise_affine: bool = True,
146
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
147
- norm_eps: float = 1e-5,
148
- final_dropout: bool = False,
149
- attention_type: str = "default",
150
- positional_embeddings: Optional[str] = None,
151
- num_positional_embeddings: Optional[int] = None,
152
- ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
153
- ada_norm_bias: Optional[int] = None,
154
- ff_inner_dim: Optional[int] = None,
155
- ff_bias: bool = True,
156
- attention_out_bias: bool = True,
157
- ):
158
- super().__init__()
159
- self.only_cross_attention = only_cross_attention
160
-
161
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
162
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
163
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
164
- self.use_layer_norm = norm_type == "layer_norm"
165
- self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
166
-
167
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
168
- raise ValueError(
169
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
170
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
171
- )
172
-
173
- if positional_embeddings and (num_positional_embeddings is None):
174
- raise ValueError(
175
- "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
176
- )
177
-
178
- if positional_embeddings == "sinusoidal":
179
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
180
- else:
181
- self.pos_embed = None
182
-
183
- # Define 3 blocks. Each block has its own normalization layer.
184
- # 1. Self-Attn
185
- if self.use_ada_layer_norm:
186
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
187
- elif self.use_ada_layer_norm_zero:
188
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
189
- elif self.use_ada_layer_norm_continuous:
190
- self.norm1 = AdaLayerNormContinuous(
191
- dim,
192
- ada_norm_continous_conditioning_embedding_dim,
193
- norm_elementwise_affine,
194
- norm_eps,
195
- ada_norm_bias,
196
- "rms_norm",
197
- )
198
- else:
199
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
200
-
201
- self.attn1 = Attention(
202
- query_dim=dim,
203
- heads=num_attention_heads,
204
- dim_head=attention_head_dim,
205
- dropout=dropout,
206
- bias=attention_bias,
207
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
208
- upcast_attention=upcast_attention,
209
- out_bias=attention_out_bias,
210
- )
211
-
212
- # 2. Cross-Attn
213
- if cross_attention_dim is not None or double_self_attention:
214
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
215
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
216
- # the second cross attention block.
217
- if self.use_ada_layer_norm:
218
- self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
219
- elif self.use_ada_layer_norm_continuous:
220
- self.norm2 = AdaLayerNormContinuous(
221
- dim,
222
- ada_norm_continous_conditioning_embedding_dim,
223
- norm_elementwise_affine,
224
- norm_eps,
225
- ada_norm_bias,
226
- "rms_norm",
227
- )
228
- else:
229
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
230
-
231
- self.attn2 = Attention(
232
- query_dim=dim,
233
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
234
- heads=num_attention_heads,
235
- dim_head=attention_head_dim,
236
- dropout=dropout,
237
- bias=attention_bias,
238
- upcast_attention=upcast_attention,
239
- out_bias=attention_out_bias,
240
- ) # is self-attn if encoder_hidden_states is none
241
- else:
242
- self.norm2 = None
243
- self.attn2 = None
244
-
245
- # 3. Feed-forward
246
- if self.use_ada_layer_norm_continuous:
247
- self.norm3 = AdaLayerNormContinuous(
248
- dim,
249
- ada_norm_continous_conditioning_embedding_dim,
250
- norm_elementwise_affine,
251
- norm_eps,
252
- ada_norm_bias,
253
- "layer_norm",
254
- )
255
- elif not self.use_ada_layer_norm_single:
256
- self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
257
-
258
- self.ff = FeedForward(
259
- dim,
260
- dropout=dropout,
261
- activation_fn=activation_fn,
262
- final_dropout=final_dropout,
263
- inner_dim=ff_inner_dim,
264
- bias=ff_bias,
265
- )
266
-
267
- # 4. Fuser
268
- if attention_type == "gated" or attention_type == "gated-text-image":
269
- self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
270
-
271
- # 5. Scale-shift for PixArt-Alpha.
272
- if self.use_ada_layer_norm_single:
273
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
274
-
275
- # let chunk size default to None
276
- self._chunk_size = None
277
- self._chunk_dim = 0
278
-
279
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
280
- # Sets chunk feed-forward
281
- self._chunk_size = chunk_size
282
- self._chunk_dim = dim
283
-
284
- def forward(
285
- self,
286
- hidden_states: torch.FloatTensor,
287
- attention_mask: Optional[torch.FloatTensor] = None,
288
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
289
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
290
- timestep: Optional[torch.LongTensor] = None,
291
- cross_attention_kwargs: Dict[str, Any] = None,
292
- class_labels: Optional[torch.LongTensor] = None,
293
- garment_features=None,
294
- curr_garment_feat_idx=0,
295
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
296
- ) -> torch.FloatTensor:
297
- # Notice that normalization is always applied before the real computation in the following blocks.
298
- # 0. Self-Attention
299
- batch_size = hidden_states.shape[0]
300
-
301
-
302
-
303
- if self.use_ada_layer_norm:
304
- norm_hidden_states = self.norm1(hidden_states, timestep)
305
- elif self.use_ada_layer_norm_zero:
306
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
307
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
308
- )
309
- elif self.use_layer_norm:
310
- norm_hidden_states = self.norm1(hidden_states)
311
- elif self.use_ada_layer_norm_continuous:
312
- norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
313
- elif self.use_ada_layer_norm_single:
314
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
315
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
316
- ).chunk(6, dim=1)
317
- norm_hidden_states = self.norm1(hidden_states)
318
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
319
- norm_hidden_states = norm_hidden_states.squeeze(1)
320
- else:
321
- raise ValueError("Incorrect norm used")
322
-
323
- if self.pos_embed is not None:
324
- norm_hidden_states = self.pos_embed(norm_hidden_states)
325
-
326
- # 1. Retrieve lora scale.
327
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
328
-
329
- # 2. Prepare GLIGEN inputs
330
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
331
- gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
332
-
333
-
334
- modify_norm_hidden_states = torch.cat([norm_hidden_states,garment_features[curr_garment_feat_idx]], dim=1)
335
- curr_garment_feat_idx +=1
336
- attn_output = self.attn1(
337
- #norm_hidden_states,
338
- modify_norm_hidden_states,
339
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
340
- attention_mask=attention_mask,
341
- **cross_attention_kwargs,
342
- )
343
- if self.use_ada_layer_norm_zero:
344
- attn_output = gate_msa.unsqueeze(1) * attn_output
345
- elif self.use_ada_layer_norm_single:
346
- attn_output = gate_msa * attn_output
347
-
348
- hidden_states = attn_output[:,:hidden_states.shape[-2],:] + hidden_states
349
-
350
-
351
-
352
-
353
- if hidden_states.ndim == 4:
354
- hidden_states = hidden_states.squeeze(1)
355
-
356
- # 2.5 GLIGEN Control
357
- if gligen_kwargs is not None:
358
- hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
359
-
360
- # 3. Cross-Attention
361
- if self.attn2 is not None:
362
- if self.use_ada_layer_norm:
363
- norm_hidden_states = self.norm2(hidden_states, timestep)
364
- elif self.use_ada_layer_norm_zero or self.use_layer_norm:
365
- norm_hidden_states = self.norm2(hidden_states)
366
- elif self.use_ada_layer_norm_single:
367
- # For PixArt norm2 isn't applied here:
368
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
369
- norm_hidden_states = hidden_states
370
- elif self.use_ada_layer_norm_continuous:
371
- norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
372
- else:
373
- raise ValueError("Incorrect norm")
374
-
375
- if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
376
- norm_hidden_states = self.pos_embed(norm_hidden_states)
377
-
378
- attn_output = self.attn2(
379
- norm_hidden_states,
380
- encoder_hidden_states=encoder_hidden_states,
381
- attention_mask=encoder_attention_mask,
382
- **cross_attention_kwargs,
383
- )
384
- hidden_states = attn_output + hidden_states
385
-
386
- # 4. Feed-forward
387
- if self.use_ada_layer_norm_continuous:
388
- norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
389
- elif not self.use_ada_layer_norm_single:
390
- norm_hidden_states = self.norm3(hidden_states)
391
-
392
- if self.use_ada_layer_norm_zero:
393
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
394
-
395
- if self.use_ada_layer_norm_single:
396
- norm_hidden_states = self.norm2(hidden_states)
397
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
398
-
399
- if self._chunk_size is not None:
400
- # "feed_forward_chunk_size" can be used to save memory
401
- ff_output = _chunked_feed_forward(
402
- self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
403
- )
404
- else:
405
- ff_output = self.ff(norm_hidden_states, scale=lora_scale)
406
-
407
- if self.use_ada_layer_norm_zero:
408
- ff_output = gate_mlp.unsqueeze(1) * ff_output
409
- elif self.use_ada_layer_norm_single:
410
- ff_output = gate_mlp * ff_output
411
-
412
- hidden_states = ff_output + hidden_states
413
- if hidden_states.ndim == 4:
414
- hidden_states = hidden_states.squeeze(1)
415
- return hidden_states,curr_garment_feat_idx
416
-
417
-
418
- @maybe_allow_in_graph
419
- class TemporalBasicTransformerBlock(nn.Module):
420
- r"""
421
- A basic Transformer block for video like data.
422
-
423
- Parameters:
424
- dim (`int`): The number of channels in the input and output.
425
- time_mix_inner_dim (`int`): The number of channels for temporal attention.
426
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
427
- attention_head_dim (`int`): The number of channels in each head.
428
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
429
- """
430
-
431
- def __init__(
432
- self,
433
- dim: int,
434
- time_mix_inner_dim: int,
435
- num_attention_heads: int,
436
- attention_head_dim: int,
437
- cross_attention_dim: Optional[int] = None,
438
- ):
439
- super().__init__()
440
- self.is_res = dim == time_mix_inner_dim
441
-
442
- self.norm_in = nn.LayerNorm(dim)
443
-
444
- # Define 3 blocks. Each block has its own normalization layer.
445
- # 1. Self-Attn
446
- self.norm_in = nn.LayerNorm(dim)
447
- self.ff_in = FeedForward(
448
- dim,
449
- dim_out=time_mix_inner_dim,
450
- activation_fn="geglu",
451
- )
452
-
453
- self.norm1 = nn.LayerNorm(time_mix_inner_dim)
454
- self.attn1 = Attention(
455
- query_dim=time_mix_inner_dim,
456
- heads=num_attention_heads,
457
- dim_head=attention_head_dim,
458
- cross_attention_dim=None,
459
- )
460
-
461
- # 2. Cross-Attn
462
- if cross_attention_dim is not None:
463
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
464
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
465
- # the second cross attention block.
466
- self.norm2 = nn.LayerNorm(time_mix_inner_dim)
467
- self.attn2 = Attention(
468
- query_dim=time_mix_inner_dim,
469
- cross_attention_dim=cross_attention_dim,
470
- heads=num_attention_heads,
471
- dim_head=attention_head_dim,
472
- ) # is self-attn if encoder_hidden_states is none
473
- else:
474
- self.norm2 = None
475
- self.attn2 = None
476
-
477
- # 3. Feed-forward
478
- self.norm3 = nn.LayerNorm(time_mix_inner_dim)
479
- self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
480
-
481
- # let chunk size default to None
482
- self._chunk_size = None
483
- self._chunk_dim = None
484
-
485
- def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
486
- # Sets chunk feed-forward
487
- self._chunk_size = chunk_size
488
- # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
489
- self._chunk_dim = 1
490
-
491
- def forward(
492
- self,
493
- hidden_states: torch.FloatTensor,
494
- num_frames: int,
495
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
496
- ) -> torch.FloatTensor:
497
- # Notice that normalization is always applied before the real computation in the following blocks.
498
- # 0. Self-Attention
499
- batch_size = hidden_states.shape[0]
500
-
501
- batch_frames, seq_length, channels = hidden_states.shape
502
- batch_size = batch_frames // num_frames
503
-
504
- hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
505
- hidden_states = hidden_states.permute(0, 2, 1, 3)
506
- hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
507
-
508
- residual = hidden_states
509
- hidden_states = self.norm_in(hidden_states)
510
-
511
- if self._chunk_size is not None:
512
- hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
513
- else:
514
- hidden_states = self.ff_in(hidden_states)
515
-
516
- if self.is_res:
517
- hidden_states = hidden_states + residual
518
-
519
- norm_hidden_states = self.norm1(hidden_states)
520
- attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
521
- hidden_states = attn_output + hidden_states
522
-
523
- # 3. Cross-Attention
524
- if self.attn2 is not None:
525
- norm_hidden_states = self.norm2(hidden_states)
526
- attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
527
- hidden_states = attn_output + hidden_states
528
-
529
- # 4. Feed-forward
530
- norm_hidden_states = self.norm3(hidden_states)
531
-
532
- if self._chunk_size is not None:
533
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
534
- else:
535
- ff_output = self.ff(norm_hidden_states)
536
-
537
- if self.is_res:
538
- hidden_states = ff_output + hidden_states
539
- else:
540
- hidden_states = ff_output
541
-
542
- hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
543
- hidden_states = hidden_states.permute(0, 2, 1, 3)
544
- hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
545
-
546
- return hidden_states
547
-
548
-
549
- class SkipFFTransformerBlock(nn.Module):
550
- def __init__(
551
- self,
552
- dim: int,
553
- num_attention_heads: int,
554
- attention_head_dim: int,
555
- kv_input_dim: int,
556
- kv_input_dim_proj_use_bias: bool,
557
- dropout=0.0,
558
- cross_attention_dim: Optional[int] = None,
559
- attention_bias: bool = False,
560
- attention_out_bias: bool = True,
561
- ):
562
- super().__init__()
563
- if kv_input_dim != dim:
564
- self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
565
- else:
566
- self.kv_mapper = None
567
-
568
- self.norm1 = RMSNorm(dim, 1e-06)
569
-
570
- self.attn1 = Attention(
571
- query_dim=dim,
572
- heads=num_attention_heads,
573
- dim_head=attention_head_dim,
574
- dropout=dropout,
575
- bias=attention_bias,
576
- cross_attention_dim=cross_attention_dim,
577
- out_bias=attention_out_bias,
578
- )
579
-
580
- self.norm2 = RMSNorm(dim, 1e-06)
581
-
582
- self.attn2 = Attention(
583
- query_dim=dim,
584
- cross_attention_dim=cross_attention_dim,
585
- heads=num_attention_heads,
586
- dim_head=attention_head_dim,
587
- dropout=dropout,
588
- bias=attention_bias,
589
- out_bias=attention_out_bias,
590
- )
591
-
592
- def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
593
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
594
-
595
- if self.kv_mapper is not None:
596
- encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
597
-
598
- norm_hidden_states = self.norm1(hidden_states)
599
-
600
- attn_output = self.attn1(
601
- norm_hidden_states,
602
- encoder_hidden_states=encoder_hidden_states,
603
- **cross_attention_kwargs,
604
- )
605
-
606
- hidden_states = attn_output + hidden_states
607
-
608
- norm_hidden_states = self.norm2(hidden_states)
609
-
610
- attn_output = self.attn2(
611
- norm_hidden_states,
612
- encoder_hidden_states=encoder_hidden_states,
613
- **cross_attention_kwargs,
614
- )
615
-
616
- hidden_states = attn_output + hidden_states
617
-
618
- return hidden_states
619
-
620
-
621
- class FeedForward(nn.Module):
622
- r"""
623
- A feed-forward layer.
624
-
625
- Parameters:
626
- dim (`int`): The number of channels in the input.
627
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
628
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
629
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
630
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
631
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
632
- bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
633
- """
634
-
635
- def __init__(
636
- self,
637
- dim: int,
638
- dim_out: Optional[int] = None,
639
- mult: int = 4,
640
- dropout: float = 0.0,
641
- activation_fn: str = "geglu",
642
- final_dropout: bool = False,
643
- inner_dim=None,
644
- bias: bool = True,
645
- ):
646
- super().__init__()
647
- if inner_dim is None:
648
- inner_dim = int(dim * mult)
649
- dim_out = dim_out if dim_out is not None else dim
650
- linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
651
-
652
- if activation_fn == "gelu":
653
- act_fn = GELU(dim, inner_dim, bias=bias)
654
- if activation_fn == "gelu-approximate":
655
- act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
656
- elif activation_fn == "geglu":
657
- act_fn = GEGLU(dim, inner_dim, bias=bias)
658
- elif activation_fn == "geglu-approximate":
659
- act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
660
-
661
- self.net = nn.ModuleList([])
662
- # project in
663
- self.net.append(act_fn)
664
- # project dropout
665
- self.net.append(nn.Dropout(dropout))
666
- # project out
667
- self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
668
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
669
- if final_dropout:
670
- self.net.append(nn.Dropout(dropout))
671
-
672
- def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
673
- compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
674
- for module in self.net:
675
- if isinstance(module, compatible_cls):
676
- hidden_states = module(hidden_states, scale)
677
- else:
678
- hidden_states = module(hidden_states)
679
- return hidden_states
src/transformerhacked_garmnet.py DELETED
@@ -1,460 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, Optional
16
-
17
- import torch
18
- import torch.nn.functional as F
19
- from torch import nn
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.models.embeddings import ImagePositionalEmbeddings
23
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
24
- from src.attentionhacked_garmnet import BasicTransformerBlock
25
- from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection
26
- from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
27
- from diffusers.models.modeling_utils import ModelMixin
28
- from diffusers.models.normalization import AdaLayerNormSingle
29
-
30
-
31
- @dataclass
32
- class Transformer2DModelOutput(BaseOutput):
33
- """
34
- The output of [`Transformer2DModel`].
35
-
36
- Args:
37
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
38
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
39
- distributions for the unnoised latent pixels.
40
- """
41
-
42
- sample: torch.FloatTensor
43
-
44
-
45
- class Transformer2DModel(ModelMixin, ConfigMixin):
46
- """
47
- A 2D Transformer model for image-like data.
48
-
49
- Parameters:
50
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
51
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
52
- in_channels (`int`, *optional*):
53
- The number of channels in the input and output (specify if the input is **continuous**).
54
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
55
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
56
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
57
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
58
- This is fixed during training since it is used to learn a number of position embeddings.
59
- num_vector_embeds (`int`, *optional*):
60
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
61
- Includes the class for the masked latent pixel.
62
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
63
- num_embeds_ada_norm ( `int`, *optional*):
64
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
65
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
66
- added to the hidden states.
67
-
68
- During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
69
- attention_bias (`bool`, *optional*):
70
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
71
- """
72
-
73
- _supports_gradient_checkpointing = True
74
-
75
- @register_to_config
76
- def __init__(
77
- self,
78
- num_attention_heads: int = 16,
79
- attention_head_dim: int = 88,
80
- in_channels: Optional[int] = None,
81
- out_channels: Optional[int] = None,
82
- num_layers: int = 1,
83
- dropout: float = 0.0,
84
- norm_num_groups: int = 32,
85
- cross_attention_dim: Optional[int] = None,
86
- attention_bias: bool = False,
87
- sample_size: Optional[int] = None,
88
- num_vector_embeds: Optional[int] = None,
89
- patch_size: Optional[int] = None,
90
- activation_fn: str = "geglu",
91
- num_embeds_ada_norm: Optional[int] = None,
92
- use_linear_projection: bool = False,
93
- only_cross_attention: bool = False,
94
- double_self_attention: bool = False,
95
- upcast_attention: bool = False,
96
- norm_type: str = "layer_norm",
97
- norm_elementwise_affine: bool = True,
98
- norm_eps: float = 1e-5,
99
- attention_type: str = "default",
100
- caption_channels: int = None,
101
- ):
102
- super().__init__()
103
- self.use_linear_projection = use_linear_projection
104
- self.num_attention_heads = num_attention_heads
105
- self.attention_head_dim = attention_head_dim
106
- inner_dim = num_attention_heads * attention_head_dim
107
-
108
- conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
109
- linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
110
-
111
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
112
- # Define whether input is continuous or discrete depending on configuration
113
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
114
- self.is_input_vectorized = num_vector_embeds is not None
115
- self.is_input_patches = in_channels is not None and patch_size is not None
116
-
117
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
118
- deprecation_message = (
119
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
120
- " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
121
- " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
122
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
123
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
124
- )
125
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
126
- norm_type = "ada_norm"
127
-
128
- if self.is_input_continuous and self.is_input_vectorized:
129
- raise ValueError(
130
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
131
- " sure that either `in_channels` or `num_vector_embeds` is None."
132
- )
133
- elif self.is_input_vectorized and self.is_input_patches:
134
- raise ValueError(
135
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
136
- " sure that either `num_vector_embeds` or `num_patches` is None."
137
- )
138
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
139
- raise ValueError(
140
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
141
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
142
- )
143
-
144
- # 2. Define input layers
145
- if self.is_input_continuous:
146
- self.in_channels = in_channels
147
-
148
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
149
- if use_linear_projection:
150
- self.proj_in = linear_cls(in_channels, inner_dim)
151
- else:
152
- self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
153
- elif self.is_input_vectorized:
154
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
155
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
156
-
157
- self.height = sample_size
158
- self.width = sample_size
159
- self.num_vector_embeds = num_vector_embeds
160
- self.num_latent_pixels = self.height * self.width
161
-
162
- self.latent_image_embedding = ImagePositionalEmbeddings(
163
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
164
- )
165
- elif self.is_input_patches:
166
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
167
-
168
- self.height = sample_size
169
- self.width = sample_size
170
-
171
- self.patch_size = patch_size
172
- interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
173
- interpolation_scale = max(interpolation_scale, 1)
174
- self.pos_embed = PatchEmbed(
175
- height=sample_size,
176
- width=sample_size,
177
- patch_size=patch_size,
178
- in_channels=in_channels,
179
- embed_dim=inner_dim,
180
- interpolation_scale=interpolation_scale,
181
- )
182
-
183
- # 3. Define transformers blocks
184
- self.transformer_blocks = nn.ModuleList(
185
- [
186
- BasicTransformerBlock(
187
- inner_dim,
188
- num_attention_heads,
189
- attention_head_dim,
190
- dropout=dropout,
191
- cross_attention_dim=cross_attention_dim,
192
- activation_fn=activation_fn,
193
- num_embeds_ada_norm=num_embeds_ada_norm,
194
- attention_bias=attention_bias,
195
- only_cross_attention=only_cross_attention,
196
- double_self_attention=double_self_attention,
197
- upcast_attention=upcast_attention,
198
- norm_type=norm_type,
199
- norm_elementwise_affine=norm_elementwise_affine,
200
- norm_eps=norm_eps,
201
- attention_type=attention_type,
202
- )
203
- for d in range(num_layers)
204
- ]
205
- )
206
-
207
- # 4. Define output layers
208
- self.out_channels = in_channels if out_channels is None else out_channels
209
- if self.is_input_continuous:
210
- # TODO: should use out_channels for continuous projections
211
- if use_linear_projection:
212
- self.proj_out = linear_cls(inner_dim, in_channels)
213
- else:
214
- self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
215
- elif self.is_input_vectorized:
216
- self.norm_out = nn.LayerNorm(inner_dim)
217
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
218
- elif self.is_input_patches and norm_type != "ada_norm_single":
219
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
220
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
221
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
222
- elif self.is_input_patches and norm_type == "ada_norm_single":
223
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
224
- self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
225
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
226
-
227
- # 5. PixArt-Alpha blocks.
228
- self.adaln_single = None
229
- self.use_additional_conditions = False
230
- if norm_type == "ada_norm_single":
231
- self.use_additional_conditions = self.config.sample_size == 128
232
- # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
233
- # additional conditions until we find better name
234
- self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
235
-
236
- self.caption_projection = None
237
- if caption_channels is not None:
238
- self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
239
-
240
- self.gradient_checkpointing = False
241
-
242
- def _set_gradient_checkpointing(self, module, value=False):
243
- if hasattr(module, "gradient_checkpointing"):
244
- module.gradient_checkpointing = value
245
-
246
- def forward(
247
- self,
248
- hidden_states: torch.Tensor,
249
- encoder_hidden_states: Optional[torch.Tensor] = None,
250
- timestep: Optional[torch.LongTensor] = None,
251
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
252
- class_labels: Optional[torch.LongTensor] = None,
253
- cross_attention_kwargs: Dict[str, Any] = None,
254
- attention_mask: Optional[torch.Tensor] = None,
255
- encoder_attention_mask: Optional[torch.Tensor] = None,
256
- return_dict: bool = True,
257
- ):
258
- """
259
- The [`Transformer2DModel`] forward method.
260
-
261
- Args:
262
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
263
- Input `hidden_states`.
264
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
265
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
266
- self-attention.
267
- timestep ( `torch.LongTensor`, *optional*):
268
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
269
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
270
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
271
- `AdaLayerZeroNorm`.
272
- cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
273
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
274
- `self.processor` in
275
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
276
- attention_mask ( `torch.Tensor`, *optional*):
277
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
278
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
279
- negative values to the attention scores corresponding to "discard" tokens.
280
- encoder_attention_mask ( `torch.Tensor`, *optional*):
281
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
282
-
283
- * Mask `(batch, sequence_length)` True = keep, False = discard.
284
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
285
-
286
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
287
- above. This bias will be added to the cross-attention scores.
288
- return_dict (`bool`, *optional*, defaults to `True`):
289
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
290
- tuple.
291
-
292
- Returns:
293
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
294
- `tuple` where the first element is the sample tensor.
295
- """
296
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
297
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
298
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
299
- # expects mask of shape:
300
- # [batch, key_tokens]
301
- # adds singleton query_tokens dimension:
302
- # [batch, 1, key_tokens]
303
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
304
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
305
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
306
- if attention_mask is not None and attention_mask.ndim == 2:
307
- # assume that mask is expressed as:
308
- # (1 = keep, 0 = discard)
309
- # convert mask into a bias that can be added to attention scores:
310
- # (keep = +0, discard = -10000.0)
311
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
312
- attention_mask = attention_mask.unsqueeze(1)
313
-
314
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
315
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
316
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
317
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
318
-
319
- # Retrieve lora scale.
320
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
321
-
322
- # 1. Input
323
- if self.is_input_continuous:
324
- batch, _, height, width = hidden_states.shape
325
- residual = hidden_states
326
-
327
- hidden_states = self.norm(hidden_states)
328
- if not self.use_linear_projection:
329
- hidden_states = (
330
- self.proj_in(hidden_states, scale=lora_scale)
331
- if not USE_PEFT_BACKEND
332
- else self.proj_in(hidden_states)
333
- )
334
- inner_dim = hidden_states.shape[1]
335
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
336
- else:
337
- inner_dim = hidden_states.shape[1]
338
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
339
- hidden_states = (
340
- self.proj_in(hidden_states, scale=lora_scale)
341
- if not USE_PEFT_BACKEND
342
- else self.proj_in(hidden_states)
343
- )
344
-
345
- elif self.is_input_vectorized:
346
- hidden_states = self.latent_image_embedding(hidden_states)
347
- elif self.is_input_patches:
348
- height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
349
- hidden_states = self.pos_embed(hidden_states)
350
-
351
- if self.adaln_single is not None:
352
- if self.use_additional_conditions and added_cond_kwargs is None:
353
- raise ValueError(
354
- "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
355
- )
356
- batch_size = hidden_states.shape[0]
357
- timestep, embedded_timestep = self.adaln_single(
358
- timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
359
- )
360
-
361
- # 2. Blocks
362
- if self.caption_projection is not None:
363
- batch_size = hidden_states.shape[0]
364
- encoder_hidden_states = self.caption_projection(encoder_hidden_states)
365
- encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
366
-
367
- garment_features = []
368
- for block in self.transformer_blocks:
369
- if self.training and self.gradient_checkpointing:
370
-
371
- def create_custom_forward(module, return_dict=None):
372
- def custom_forward(*inputs):
373
- if return_dict is not None:
374
- return module(*inputs, return_dict=return_dict)
375
- else:
376
- return module(*inputs)
377
-
378
- return custom_forward
379
-
380
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
381
- hidden_states,out_garment_feat = torch.utils.checkpoint.checkpoint(
382
- create_custom_forward(block),
383
- hidden_states,
384
- attention_mask,
385
- encoder_hidden_states,
386
- encoder_attention_mask,
387
- timestep,
388
- cross_attention_kwargs,
389
- class_labels,
390
- **ckpt_kwargs,
391
- )
392
- else:
393
- hidden_states,out_garment_feat = block(
394
- hidden_states,
395
- attention_mask=attention_mask,
396
- encoder_hidden_states=encoder_hidden_states,
397
- encoder_attention_mask=encoder_attention_mask,
398
- timestep=timestep,
399
- cross_attention_kwargs=cross_attention_kwargs,
400
- class_labels=class_labels,
401
- )
402
- garment_features += out_garment_feat
403
- # 3. Output
404
- if self.is_input_continuous:
405
- if not self.use_linear_projection:
406
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
407
- hidden_states = (
408
- self.proj_out(hidden_states, scale=lora_scale)
409
- if not USE_PEFT_BACKEND
410
- else self.proj_out(hidden_states)
411
- )
412
- else:
413
- hidden_states = (
414
- self.proj_out(hidden_states, scale=lora_scale)
415
- if not USE_PEFT_BACKEND
416
- else self.proj_out(hidden_states)
417
- )
418
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
419
-
420
- output = hidden_states + residual
421
- elif self.is_input_vectorized:
422
- hidden_states = self.norm_out(hidden_states)
423
- logits = self.out(hidden_states)
424
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
425
- logits = logits.permute(0, 2, 1)
426
-
427
- # log(p(x_0))
428
- output = F.log_softmax(logits.double(), dim=1).float()
429
-
430
- if self.is_input_patches:
431
- if self.config.norm_type != "ada_norm_single":
432
- conditioning = self.transformer_blocks[0].norm1.emb(
433
- timestep, class_labels, hidden_dtype=hidden_states.dtype
434
- )
435
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
436
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
437
- hidden_states = self.proj_out_2(hidden_states)
438
- elif self.config.norm_type == "ada_norm_single":
439
- shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
440
- hidden_states = self.norm_out(hidden_states)
441
- # Modulation
442
- hidden_states = hidden_states * (1 + scale) + shift
443
- hidden_states = self.proj_out(hidden_states)
444
- hidden_states = hidden_states.squeeze(1)
445
-
446
- # unpatchify
447
- if self.adaln_single is None:
448
- height = width = int(hidden_states.shape[1] ** 0.5)
449
- hidden_states = hidden_states.reshape(
450
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
451
- )
452
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
453
- output = hidden_states.reshape(
454
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
455
- )
456
-
457
- if not return_dict:
458
- return (output,) ,garment_features
459
-
460
- return Transformer2DModelOutput(sample=output),garment_features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/transformerhacked_tryon.py DELETED
@@ -1,467 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, Optional
16
-
17
- import torch
18
- import torch.nn.functional as F
19
- from torch import nn
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.models.embeddings import ImagePositionalEmbeddings
23
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
24
- from src.attentionhacked_tryon import BasicTransformerBlock
25
- from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection
26
- from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
27
- from diffusers.models.modeling_utils import ModelMixin
28
- from diffusers.models.normalization import AdaLayerNormSingle
29
-
30
-
31
- @dataclass
32
- class Transformer2DModelOutput(BaseOutput):
33
- """
34
- The output of [`Transformer2DModel`].
35
-
36
- Args:
37
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
38
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
39
- distributions for the unnoised latent pixels.
40
- """
41
-
42
- sample: torch.FloatTensor
43
-
44
-
45
- class Transformer2DModel(ModelMixin, ConfigMixin):
46
- """
47
- A 2D Transformer model for image-like data.
48
-
49
- Parameters:
50
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
51
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
52
- in_channels (`int`, *optional*):
53
- The number of channels in the input and output (specify if the input is **continuous**).
54
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
55
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
56
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
57
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
58
- This is fixed during training since it is used to learn a number of position embeddings.
59
- num_vector_embeds (`int`, *optional*):
60
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
61
- Includes the class for the masked latent pixel.
62
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
63
- num_embeds_ada_norm ( `int`, *optional*):
64
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
65
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
66
- added to the hidden states.
67
-
68
- During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
69
- attention_bias (`bool`, *optional*):
70
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
71
- """
72
-
73
- _supports_gradient_checkpointing = True
74
-
75
- @register_to_config
76
- def __init__(
77
- self,
78
- num_attention_heads: int = 16,
79
- attention_head_dim: int = 88,
80
- in_channels: Optional[int] = None,
81
- out_channels: Optional[int] = None,
82
- num_layers: int = 1,
83
- dropout: float = 0.0,
84
- norm_num_groups: int = 32,
85
- cross_attention_dim: Optional[int] = None,
86
- attention_bias: bool = False,
87
- sample_size: Optional[int] = None,
88
- num_vector_embeds: Optional[int] = None,
89
- patch_size: Optional[int] = None,
90
- activation_fn: str = "geglu",
91
- num_embeds_ada_norm: Optional[int] = None,
92
- use_linear_projection: bool = False,
93
- only_cross_attention: bool = False,
94
- double_self_attention: bool = False,
95
- upcast_attention: bool = False,
96
- norm_type: str = "layer_norm",
97
- norm_elementwise_affine: bool = True,
98
- norm_eps: float = 1e-5,
99
- attention_type: str = "default",
100
- caption_channels: int = None,
101
- ):
102
- super().__init__()
103
- self.use_linear_projection = use_linear_projection
104
- self.num_attention_heads = num_attention_heads
105
- self.attention_head_dim = attention_head_dim
106
- inner_dim = num_attention_heads * attention_head_dim
107
-
108
- conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
109
- linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
110
-
111
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
112
- # Define whether input is continuous or discrete depending on configuration
113
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
114
- self.is_input_vectorized = num_vector_embeds is not None
115
- self.is_input_patches = in_channels is not None and patch_size is not None
116
-
117
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
118
- deprecation_message = (
119
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
120
- " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
121
- " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
122
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
123
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
124
- )
125
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
126
- norm_type = "ada_norm"
127
-
128
- if self.is_input_continuous and self.is_input_vectorized:
129
- raise ValueError(
130
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
131
- " sure that either `in_channels` or `num_vector_embeds` is None."
132
- )
133
- elif self.is_input_vectorized and self.is_input_patches:
134
- raise ValueError(
135
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
136
- " sure that either `num_vector_embeds` or `num_patches` is None."
137
- )
138
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
139
- raise ValueError(
140
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
141
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
142
- )
143
-
144
- # 2. Define input layers
145
- if self.is_input_continuous:
146
- self.in_channels = in_channels
147
-
148
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
149
- if use_linear_projection:
150
- self.proj_in = linear_cls(in_channels, inner_dim)
151
- else:
152
- self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
153
- elif self.is_input_vectorized:
154
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
155
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
156
-
157
- self.height = sample_size
158
- self.width = sample_size
159
- self.num_vector_embeds = num_vector_embeds
160
- self.num_latent_pixels = self.height * self.width
161
-
162
- self.latent_image_embedding = ImagePositionalEmbeddings(
163
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
164
- )
165
- elif self.is_input_patches:
166
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
167
-
168
- self.height = sample_size
169
- self.width = sample_size
170
-
171
- self.patch_size = patch_size
172
- interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
173
- interpolation_scale = max(interpolation_scale, 1)
174
- self.pos_embed = PatchEmbed(
175
- height=sample_size,
176
- width=sample_size,
177
- patch_size=patch_size,
178
- in_channels=in_channels,
179
- embed_dim=inner_dim,
180
- interpolation_scale=interpolation_scale,
181
- )
182
-
183
- # 3. Define transformers blocks
184
- self.transformer_blocks = nn.ModuleList(
185
- [
186
- BasicTransformerBlock(
187
- inner_dim,
188
- num_attention_heads,
189
- attention_head_dim,
190
- dropout=dropout,
191
- cross_attention_dim=cross_attention_dim,
192
- activation_fn=activation_fn,
193
- num_embeds_ada_norm=num_embeds_ada_norm,
194
- attention_bias=attention_bias,
195
- only_cross_attention=only_cross_attention,
196
- double_self_attention=double_self_attention,
197
- upcast_attention=upcast_attention,
198
- norm_type=norm_type,
199
- norm_elementwise_affine=norm_elementwise_affine,
200
- norm_eps=norm_eps,
201
- attention_type=attention_type,
202
- )
203
- for d in range(num_layers)
204
- ]
205
- )
206
-
207
- # 4. Define output layers
208
- self.out_channels = in_channels if out_channels is None else out_channels
209
- if self.is_input_continuous:
210
- # TODO: should use out_channels for continuous projections
211
- if use_linear_projection:
212
- self.proj_out = linear_cls(inner_dim, in_channels)
213
- else:
214
- self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
215
- elif self.is_input_vectorized:
216
- self.norm_out = nn.LayerNorm(inner_dim)
217
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
218
- elif self.is_input_patches and norm_type != "ada_norm_single":
219
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
220
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
221
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
222
- elif self.is_input_patches and norm_type == "ada_norm_single":
223
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
224
- self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
225
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
226
-
227
- # 5. PixArt-Alpha blocks.
228
- self.adaln_single = None
229
- self.use_additional_conditions = False
230
- if norm_type == "ada_norm_single":
231
- self.use_additional_conditions = self.config.sample_size == 128
232
- # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
233
- # additional conditions until we find better name
234
- self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
235
-
236
- self.caption_projection = None
237
- if caption_channels is not None:
238
- self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
239
-
240
- self.gradient_checkpointing = False
241
-
242
- def _set_gradient_checkpointing(self, module, value=False):
243
- if hasattr(module, "gradient_checkpointing"):
244
- module.gradient_checkpointing = value
245
-
246
- def forward(
247
- self,
248
- hidden_states: torch.Tensor,
249
- encoder_hidden_states: Optional[torch.Tensor] = None,
250
- timestep: Optional[torch.LongTensor] = None,
251
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
252
- class_labels: Optional[torch.LongTensor] = None,
253
- cross_attention_kwargs: Dict[str, Any] = None,
254
- attention_mask: Optional[torch.Tensor] = None,
255
- encoder_attention_mask: Optional[torch.Tensor] = None,
256
- garment_features=None,
257
- curr_garment_feat_idx=0,
258
- return_dict: bool = True,
259
- ):
260
- """
261
- The [`Transformer2DModel`] forward method.
262
-
263
- Args:
264
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
265
- Input `hidden_states`.
266
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
267
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
268
- self-attention.
269
- timestep ( `torch.LongTensor`, *optional*):
270
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
271
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
272
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
273
- `AdaLayerZeroNorm`.
274
- cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
275
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
276
- `self.processor` in
277
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
278
- attention_mask ( `torch.Tensor`, *optional*):
279
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
280
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
281
- negative values to the attention scores corresponding to "discard" tokens.
282
- encoder_attention_mask ( `torch.Tensor`, *optional*):
283
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
284
-
285
- * Mask `(batch, sequence_length)` True = keep, False = discard.
286
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
287
-
288
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
289
- above. This bias will be added to the cross-attention scores.
290
- return_dict (`bool`, *optional*, defaults to `True`):
291
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
292
- tuple.
293
-
294
- Returns:
295
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
296
- `tuple` where the first element is the sample tensor.
297
- """
298
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
299
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
300
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
301
- # expects mask of shape:
302
- # [batch, key_tokens]
303
- # adds singleton query_tokens dimension:
304
- # [batch, 1, key_tokens]
305
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
306
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
307
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
308
- if attention_mask is not None and attention_mask.ndim == 2:
309
- # assume that mask is expressed as:
310
- # (1 = keep, 0 = discard)
311
- # convert mask into a bias that can be added to attention scores:
312
- # (keep = +0, discard = -10000.0)
313
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
314
- attention_mask = attention_mask.unsqueeze(1)
315
-
316
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
317
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
318
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
319
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
320
-
321
- # Retrieve lora scale.
322
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
323
-
324
- # 1. Input
325
- if self.is_input_continuous:
326
- batch, _, height, width = hidden_states.shape
327
- residual = hidden_states
328
-
329
- hidden_states = self.norm(hidden_states)
330
- if not self.use_linear_projection:
331
- hidden_states = (
332
- self.proj_in(hidden_states, scale=lora_scale)
333
- if not USE_PEFT_BACKEND
334
- else self.proj_in(hidden_states)
335
- )
336
- inner_dim = hidden_states.shape[1]
337
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
338
- else:
339
- inner_dim = hidden_states.shape[1]
340
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
341
- hidden_states = (
342
- self.proj_in(hidden_states, scale=lora_scale)
343
- if not USE_PEFT_BACKEND
344
- else self.proj_in(hidden_states)
345
- )
346
-
347
- elif self.is_input_vectorized:
348
- hidden_states = self.latent_image_embedding(hidden_states)
349
- elif self.is_input_patches:
350
- height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
351
- hidden_states = self.pos_embed(hidden_states)
352
-
353
- if self.adaln_single is not None:
354
- if self.use_additional_conditions and added_cond_kwargs is None:
355
- raise ValueError(
356
- "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
357
- )
358
- batch_size = hidden_states.shape[0]
359
- timestep, embedded_timestep = self.adaln_single(
360
- timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
361
- )
362
-
363
- # 2. Blocks
364
- if self.caption_projection is not None:
365
- batch_size = hidden_states.shape[0]
366
- encoder_hidden_states = self.caption_projection(encoder_hidden_states)
367
- encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
368
-
369
-
370
- for block in self.transformer_blocks:
371
- if self.training and self.gradient_checkpointing:
372
-
373
- def create_custom_forward(module, return_dict=None):
374
- def custom_forward(*inputs):
375
- if return_dict is not None:
376
- return module(*inputs, return_dict=return_dict)
377
- else:
378
- return module(*inputs)
379
-
380
- return custom_forward
381
-
382
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
383
- hidden_states,curr_garment_feat_idx = torch.utils.checkpoint.checkpoint(
384
- create_custom_forward(block),
385
- hidden_states,
386
- attention_mask,
387
- encoder_hidden_states,
388
- encoder_attention_mask,
389
- timestep,
390
- cross_attention_kwargs,
391
- class_labels,
392
- garment_features,
393
- curr_garment_feat_idx,
394
- **ckpt_kwargs,
395
- )
396
- else:
397
- hidden_states,curr_garment_feat_idx = block(
398
- hidden_states,
399
- attention_mask=attention_mask,
400
- encoder_hidden_states=encoder_hidden_states,
401
- encoder_attention_mask=encoder_attention_mask,
402
- timestep=timestep,
403
- cross_attention_kwargs=cross_attention_kwargs,
404
- class_labels=class_labels,
405
- garment_features=garment_features,
406
- curr_garment_feat_idx=curr_garment_feat_idx,
407
- )
408
-
409
-
410
- # 3. Output
411
- if self.is_input_continuous:
412
- if not self.use_linear_projection:
413
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
414
- hidden_states = (
415
- self.proj_out(hidden_states, scale=lora_scale)
416
- if not USE_PEFT_BACKEND
417
- else self.proj_out(hidden_states)
418
- )
419
- else:
420
- hidden_states = (
421
- self.proj_out(hidden_states, scale=lora_scale)
422
- if not USE_PEFT_BACKEND
423
- else self.proj_out(hidden_states)
424
- )
425
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
426
-
427
- output = hidden_states + residual
428
- elif self.is_input_vectorized:
429
- hidden_states = self.norm_out(hidden_states)
430
- logits = self.out(hidden_states)
431
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
432
- logits = logits.permute(0, 2, 1)
433
-
434
- # log(p(x_0))
435
- output = F.log_softmax(logits.double(), dim=1).float()
436
-
437
- if self.is_input_patches:
438
- if self.config.norm_type != "ada_norm_single":
439
- conditioning = self.transformer_blocks[0].norm1.emb(
440
- timestep, class_labels, hidden_dtype=hidden_states.dtype
441
- )
442
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
443
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
444
- hidden_states = self.proj_out_2(hidden_states)
445
- elif self.config.norm_type == "ada_norm_single":
446
- shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
447
- hidden_states = self.norm_out(hidden_states)
448
- # Modulation
449
- hidden_states = hidden_states * (1 + scale) + shift
450
- hidden_states = self.proj_out(hidden_states)
451
- hidden_states = hidden_states.squeeze(1)
452
-
453
- # unpatchify
454
- if self.adaln_single is None:
455
- height = width = int(hidden_states.shape[1] ** 0.5)
456
- hidden_states = hidden_states.reshape(
457
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
458
- )
459
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
460
- output = hidden_states.reshape(
461
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
462
- )
463
-
464
- if not return_dict:
465
- return (output,),curr_garment_feat_idx
466
-
467
- return Transformer2DModelOutput(sample=output),curr_garment_feat_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/tryon_pipeline.py DELETED
@@ -1,1896 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import inspect
16
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
-
18
- import numpy as np
19
- import PIL.Image
20
- import torch
21
- from transformers import (
22
- CLIPImageProcessor,
23
- CLIPTextModel,
24
- CLIPTextModelWithProjection,
25
- CLIPTokenizer,
26
- CLIPVisionModelWithProjection,
27
- )
28
-
29
- from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
- from diffusers.loaders import (
31
- FromSingleFileMixin,
32
- IPAdapterMixin,
33
- StableDiffusionXLLoraLoaderMixin,
34
- TextualInversionLoaderMixin,
35
- )
36
- from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
37
- from diffusers.models.attention_processor import (
38
- AttnProcessor2_0,
39
- FusedAttnProcessor2_0,
40
- LoRAAttnProcessor2_0,
41
- LoRAXFormersAttnProcessor,
42
- XFormersAttnProcessor,
43
- )
44
- from diffusers.models.lora import adjust_lora_scale_text_encoder
45
- from diffusers.schedulers import KarrasDiffusionSchedulers
46
- from diffusers.utils import (
47
- USE_PEFT_BACKEND,
48
- deprecate,
49
- is_invisible_watermark_available,
50
- is_torch_xla_available,
51
- logging,
52
- replace_example_docstring,
53
- scale_lora_layers,
54
- unscale_lora_layers,
55
- )
56
- from diffusers.utils.torch_utils import randn_tensor
57
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
58
-
59
-
60
-
61
- if is_torch_xla_available():
62
- import torch_xla.core.xla_model as xm
63
-
64
- XLA_AVAILABLE = True
65
- else:
66
- XLA_AVAILABLE = False
67
-
68
-
69
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
70
-
71
-
72
- EXAMPLE_DOC_STRING = """
73
- Examples:
74
- ```py
75
- >>> import torch
76
- >>> from diffusers import StableDiffusionXLInpaintPipeline
77
- >>> from diffusers.utils import load_image
78
-
79
- >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
80
- ... "stabilityai/stable-diffusion-xl-base-1.0",
81
- ... torch_dtype=torch.float16,
82
- ... variant="fp16",
83
- ... use_safetensors=True,
84
- ... )
85
- >>> pipe.to("cuda")
86
-
87
- >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
88
- >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
89
-
90
- >>> init_image = load_image(img_url).convert("RGB")
91
- >>> mask_image = load_image(mask_url).convert("RGB")
92
-
93
- >>> prompt = "A majestic tiger sitting on a bench"
94
- >>> image = pipe(
95
- ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
96
- ... ).images[0]
97
- ```
98
- """
99
-
100
-
101
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
102
- def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
103
- """
104
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
105
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
106
- """
107
- std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
108
- std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
109
- # rescale the results from guidance (fixes overexposure)
110
- noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
111
- # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
112
- noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
113
- return noise_cfg
114
-
115
-
116
- def mask_pil_to_torch(mask, height, width):
117
- # preprocess mask
118
- if isinstance(mask, (PIL.Image.Image, np.ndarray)):
119
- mask = [mask]
120
-
121
- if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
122
- mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
123
- mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
124
- mask = mask.astype(np.float32) / 255.0
125
- elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
126
- mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
127
-
128
- mask = torch.from_numpy(mask)
129
- return mask
130
-
131
-
132
- def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
133
- """
134
- Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
135
- converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
136
- ``image`` and ``1`` for the ``mask``.
137
-
138
- The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
139
- binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
140
-
141
- Args:
142
- image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
143
- It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
144
- ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
145
- mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
146
- It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
147
- ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
148
-
149
-
150
- Raises:
151
- ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
152
- should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
153
- TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
154
- (ot the other way around).
155
-
156
- Returns:
157
- tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
158
- dimensions: ``batch x channels x height x width``.
159
- """
160
-
161
- # checkpoint. TOD(Yiyi) - need to clean this up later
162
- deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
163
- deprecate(
164
- "prepare_mask_and_masked_image",
165
- "0.30.0",
166
- deprecation_message,
167
- )
168
- if image is None:
169
- raise ValueError("`image` input cannot be undefined.")
170
-
171
- if mask is None:
172
- raise ValueError("`mask_image` input cannot be undefined.")
173
-
174
- if isinstance(image, torch.Tensor):
175
- if not isinstance(mask, torch.Tensor):
176
- mask = mask_pil_to_torch(mask, height, width)
177
-
178
- if image.ndim == 3:
179
- image = image.unsqueeze(0)
180
-
181
- # Batch and add channel dim for single mask
182
- if mask.ndim == 2:
183
- mask = mask.unsqueeze(0).unsqueeze(0)
184
-
185
- # Batch single mask or add channel dim
186
- if mask.ndim == 3:
187
- # Single batched mask, no channel dim or single mask not batched but channel dim
188
- if mask.shape[0] == 1:
189
- mask = mask.unsqueeze(0)
190
-
191
- # Batched masks no channel dim
192
- else:
193
- mask = mask.unsqueeze(1)
194
-
195
- assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
196
- # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
197
- assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
198
-
199
- # Check image is in [-1, 1]
200
- # if image.min() < -1 or image.max() > 1:
201
- # raise ValueError("Image should be in [-1, 1] range")
202
-
203
- # Check mask is in [0, 1]
204
- if mask.min() < 0 or mask.max() > 1:
205
- raise ValueError("Mask should be in [0, 1] range")
206
-
207
- # Binarize mask
208
- mask[mask < 0.5] = 0
209
- mask[mask >= 0.5] = 1
210
-
211
- # Image as float32
212
- image = image.to(dtype=torch.float32)
213
- elif isinstance(mask, torch.Tensor):
214
- raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
215
- else:
216
- # preprocess image
217
- if isinstance(image, (PIL.Image.Image, np.ndarray)):
218
- image = [image]
219
- if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
220
- # resize all images w.r.t passed height an width
221
- image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
222
- image = [np.array(i.convert("RGB"))[None, :] for i in image]
223
- image = np.concatenate(image, axis=0)
224
- elif isinstance(image, list) and isinstance(image[0], np.ndarray):
225
- image = np.concatenate([i[None, :] for i in image], axis=0)
226
-
227
- image = image.transpose(0, 3, 1, 2)
228
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
229
-
230
- mask = mask_pil_to_torch(mask, height, width)
231
- mask[mask < 0.5] = 0
232
- mask[mask >= 0.5] = 1
233
-
234
- if image.shape[1] == 4:
235
- # images are in latent space and thus can't
236
- # be masked set masked_image to None
237
- # we assume that the checkpoint is not an inpainting
238
- # checkpoint. TOD(Yiyi) - need to clean this up later
239
- masked_image = None
240
- else:
241
- masked_image = image * (mask < 0.5)
242
-
243
- # n.b. ensure backwards compatibility as old function does not return image
244
- if return_image:
245
- return mask, masked_image, image
246
-
247
- return mask, masked_image
248
-
249
-
250
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
251
- def retrieve_latents(
252
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
253
- ):
254
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
255
- return encoder_output.latent_dist.sample(generator)
256
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
257
- return encoder_output.latent_dist.mode()
258
- elif hasattr(encoder_output, "latents"):
259
- return encoder_output.latents
260
- else:
261
- raise AttributeError("Could not access latents of provided encoder_output")
262
-
263
-
264
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
265
- def retrieve_timesteps(
266
- scheduler,
267
- num_inference_steps: Optional[int] = None,
268
- device: Optional[Union[str, torch.device]] = None,
269
- timesteps: Optional[List[int]] = None,
270
- **kwargs,
271
- ):
272
- """
273
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
274
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
275
-
276
- Args:
277
- scheduler (`SchedulerMixin`):
278
- The scheduler to get timesteps from.
279
- num_inference_steps (`int`):
280
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
281
- `timesteps` must be `None`.
282
- device (`str` or `torch.device`, *optional*):
283
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
284
- timesteps (`List[int]`, *optional*):
285
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
286
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
287
- must be `None`.
288
-
289
- Returns:
290
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
291
- second element is the number of inference steps.
292
- """
293
- if timesteps is not None:
294
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
295
- if not accepts_timesteps:
296
- raise ValueError(
297
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
298
- f" timestep schedules. Please check whether you are using the correct scheduler."
299
- )
300
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
301
- timesteps = scheduler.timesteps
302
- num_inference_steps = len(timesteps)
303
- else:
304
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
305
- timesteps = scheduler.timesteps
306
- return timesteps, num_inference_steps
307
-
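
As a quick usage sketch of `retrieve_timesteps` (illustrative only; `DDIMScheduler` is just an example scheduler, and `diffusers`/`torch` are assumed to be installed):

from diffusers import DDIMScheduler

scheduler = DDIMScheduler()  # any scheduler exposing `set_timesteps` works here
# Standard path: let the scheduler derive the schedule from a step count.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
# Passing `timesteps=[...]` instead only works for schedulers whose `set_timesteps`
# accepts a `timesteps` argument; otherwise the helper raises a ValueError.
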
308
-
309
- class StableDiffusionXLInpaintPipeline(
310
- DiffusionPipeline,
311
- TextualInversionLoaderMixin,
312
- StableDiffusionXLLoraLoaderMixin,
313
- FromSingleFileMixin,
314
- IPAdapterMixin,
315
- ):
316
- r"""
317
-     Pipeline for text-guided image inpainting using Stable Diffusion XL.
318
-
319
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
320
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
321
-
322
- The pipeline also inherits the following loading methods:
323
- - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
324
- - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
325
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
326
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
327
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
328
-
329
- Args:
330
- vae ([`AutoencoderKL`]):
331
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
332
- text_encoder ([`CLIPTextModel`]):
333
- Frozen text-encoder. Stable Diffusion XL uses the text portion of
334
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
335
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
336
-         text_encoder_2 ([`CLIPTextModelWithProjection`]):
337
- Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
338
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
339
- specifically the
340
- [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
341
- variant.
342
- tokenizer (`CLIPTokenizer`):
343
- Tokenizer of class
344
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
345
- tokenizer_2 (`CLIPTokenizer`):
346
- Second Tokenizer of class
347
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
348
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
349
- scheduler ([`SchedulerMixin`]):
350
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
351
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
352
- requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
353
-             Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config
354
- of `stabilityai/stable-diffusion-xl-refiner-1-0`.
355
- force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
356
- Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
357
- `stabilityai/stable-diffusion-xl-base-1-0`.
358
- add_watermarker (`bool`, *optional*):
359
- Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
360
- watermark output images. If not defined, it will default to True if the package is installed, otherwise no
361
- watermarker will be used.
362
- """
363
-
364
- model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
365
-
366
- _optional_components = [
367
- "tokenizer",
368
- "tokenizer_2",
369
- "text_encoder",
370
- "text_encoder_2",
371
- "image_encoder",
372
- "feature_extractor",
373
- "unet_encoder",
374
- ]
375
- _callback_tensor_inputs = [
376
- "latents",
377
- "prompt_embeds",
378
- "negative_prompt_embeds",
379
- "add_text_embeds",
380
- "add_time_ids",
381
- "negative_pooled_prompt_embeds",
382
- "add_neg_time_ids",
383
- "mask",
384
- "masked_image_latents",
385
- ]
386
-
387
- def __init__(
388
- self,
389
- vae: AutoencoderKL,
390
- text_encoder: CLIPTextModel,
391
- text_encoder_2: CLIPTextModelWithProjection,
392
- tokenizer: CLIPTokenizer,
393
- tokenizer_2: CLIPTokenizer,
394
- unet: UNet2DConditionModel,
395
- unet_encoder: UNet2DConditionModel,
396
- scheduler: KarrasDiffusionSchedulers,
397
- image_encoder: CLIPVisionModelWithProjection = None,
398
- feature_extractor: CLIPImageProcessor = None,
399
- requires_aesthetics_score: bool = False,
400
- force_zeros_for_empty_prompt: bool = True,
401
- ):
402
- super().__init__()
403
-
404
- self.register_modules(
405
- vae=vae,
406
- text_encoder=text_encoder,
407
- text_encoder_2=text_encoder_2,
408
- tokenizer=tokenizer,
409
- tokenizer_2=tokenizer_2,
410
- unet=unet,
411
- image_encoder=image_encoder,
412
- unet_encoder=unet_encoder,
413
- feature_extractor=feature_extractor,
414
- scheduler=scheduler,
415
- )
416
- self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
417
- self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
418
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
419
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
420
- self.mask_processor = VaeImageProcessor(
421
- vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
422
- )
423
-
424
-
425
-
426
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
427
- def enable_vae_slicing(self):
428
- r"""
429
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
430
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
431
- """
432
- self.vae.enable_slicing()
433
-
434
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
435
- def disable_vae_slicing(self):
436
- r"""
437
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
438
- computing decoding in one step.
439
- """
440
- self.vae.disable_slicing()
441
-
442
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
443
- def enable_vae_tiling(self):
444
- r"""
445
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
446
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
447
- processing larger images.
448
- """
449
- self.vae.enable_tiling()
450
-
451
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
452
- def disable_vae_tiling(self):
453
- r"""
454
- Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
455
- computing decoding in one step.
456
- """
457
- self.vae.disable_tiling()
458
-
459
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
460
- def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
461
- dtype = next(self.image_encoder.parameters()).dtype
462
- # print(image.shape)
463
- if not isinstance(image, torch.Tensor):
464
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
465
-
466
- image = image.to(device=device, dtype=dtype)
467
- if output_hidden_states:
468
- image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
469
- image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
470
- uncond_image_enc_hidden_states = self.image_encoder(
471
- torch.zeros_like(image), output_hidden_states=True
472
- ).hidden_states[-2]
473
- uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
474
- num_images_per_prompt, dim=0
475
- )
476
- return image_enc_hidden_states, uncond_image_enc_hidden_states
477
- else:
478
- image_embeds = self.image_encoder(image).image_embeds
479
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
480
- uncond_image_embeds = torch.zeros_like(image_embeds)
481
-
482
- return image_embeds, uncond_image_embeds
483
-
484
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
485
- def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
486
- # if not isinstance(ip_adapter_image, list):
487
- # ip_adapter_image = [ip_adapter_image]
488
-
489
- # if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
490
- # raise ValueError(
491
- # f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
492
- # )
493
- output_hidden_state = not isinstance(self.unet.encoder_hid_proj, ImageProjection)
494
- # print(output_hidden_state)
495
- image_embeds, negative_image_embeds = self.encode_image(
496
- ip_adapter_image, device, 1, output_hidden_state
497
- )
498
- # print(single_image_embeds.shape)
499
- # single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
500
- # single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
501
- # print(single_image_embeds.shape)
502
- if self.do_classifier_free_guidance:
503
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
504
- image_embeds = image_embeds.to(device)
505
-
506
-
507
- return image_embeds
508
-
509
-
510
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
511
- def encode_prompt(
512
- self,
513
- prompt: str,
514
- prompt_2: Optional[str] = None,
515
- device: Optional[torch.device] = None,
516
- num_images_per_prompt: int = 1,
517
- do_classifier_free_guidance: bool = True,
518
- negative_prompt: Optional[str] = None,
519
- negative_prompt_2: Optional[str] = None,
520
- prompt_embeds: Optional[torch.FloatTensor] = None,
521
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
522
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
523
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
524
- lora_scale: Optional[float] = None,
525
- clip_skip: Optional[int] = None,
526
- ):
527
- r"""
528
- Encodes the prompt into text encoder hidden states.
529
-
530
- Args:
531
- prompt (`str` or `List[str]`, *optional*):
532
- prompt to be encoded
533
- prompt_2 (`str` or `List[str]`, *optional*):
534
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
535
- used in both text-encoders
536
- device: (`torch.device`):
537
- torch device
538
- num_images_per_prompt (`int`):
539
- number of images that should be generated per prompt
540
- do_classifier_free_guidance (`bool`):
541
- whether to use classifier free guidance or not
542
- negative_prompt (`str` or `List[str]`, *optional*):
543
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
544
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
545
- less than `1`).
546
- negative_prompt_2 (`str` or `List[str]`, *optional*):
547
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
548
- `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
549
- prompt_embeds (`torch.FloatTensor`, *optional*):
550
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
551
- provided, text embeddings will be generated from `prompt` input argument.
552
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
553
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
554
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
555
- argument.
556
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
557
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
558
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
559
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
560
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
561
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
562
- input argument.
563
- lora_scale (`float`, *optional*):
564
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
565
- clip_skip (`int`, *optional*):
566
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
567
- the output of the pre-final layer will be used for computing the prompt embeddings.
568
- """
569
- device = device or self._execution_device
570
-
571
- # set lora scale so that monkey patched LoRA
572
- # function of text encoder can correctly access it
573
- if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
574
- self._lora_scale = lora_scale
575
-
576
- # dynamically adjust the LoRA scale
577
- if self.text_encoder is not None:
578
- if not USE_PEFT_BACKEND:
579
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
580
- else:
581
- scale_lora_layers(self.text_encoder, lora_scale)
582
-
583
- if self.text_encoder_2 is not None:
584
- if not USE_PEFT_BACKEND:
585
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
586
- else:
587
- scale_lora_layers(self.text_encoder_2, lora_scale)
588
-
589
- prompt = [prompt] if isinstance(prompt, str) else prompt
590
-
591
- if prompt is not None:
592
- batch_size = len(prompt)
593
- else:
594
- batch_size = prompt_embeds.shape[0]
595
-
596
- # Define tokenizers and text encoders
597
- tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
598
- text_encoders = (
599
- [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
600
- )
601
-
602
- if prompt_embeds is None:
603
- prompt_2 = prompt_2 or prompt
604
- prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
605
-
606
-             # textual inversion: process multi-vector tokens if necessary
607
- prompt_embeds_list = []
608
- prompts = [prompt, prompt_2]
609
- for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
610
- if isinstance(self, TextualInversionLoaderMixin):
611
- prompt = self.maybe_convert_prompt(prompt, tokenizer)
612
-
613
- text_inputs = tokenizer(
614
- prompt,
615
- padding="max_length",
616
- max_length=tokenizer.model_max_length,
617
- truncation=True,
618
- return_tensors="pt",
619
- )
620
-
621
- text_input_ids = text_inputs.input_ids
622
- untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
623
-
624
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
625
- text_input_ids, untruncated_ids
626
- ):
627
- removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
628
- logger.warning(
629
- "The following part of your input was truncated because CLIP can only handle sequences up to"
630
- f" {tokenizer.model_max_length} tokens: {removed_text}"
631
- )
632
-
633
- prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
634
-
635
-                 # We are only interested in the pooled output of the final text encoder
636
- pooled_prompt_embeds = prompt_embeds[0]
637
- if clip_skip is None:
638
- prompt_embeds = prompt_embeds.hidden_states[-2]
639
- else:
640
- # "2" because SDXL always indexes from the penultimate layer.
641
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
642
-
643
- prompt_embeds_list.append(prompt_embeds)
644
-
645
- prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
646
-
647
- # get unconditional embeddings for classifier free guidance
648
- zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
649
- if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
650
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
651
- negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
652
- elif do_classifier_free_guidance and negative_prompt_embeds is None:
653
- negative_prompt = negative_prompt or ""
654
- negative_prompt_2 = negative_prompt_2 or negative_prompt
655
-
656
- # normalize str to list
657
- negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
658
- negative_prompt_2 = (
659
- batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
660
- )
661
-
662
- uncond_tokens: List[str]
663
- if prompt is not None and type(prompt) is not type(negative_prompt):
664
- raise TypeError(
665
-                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
666
- f" {type(prompt)}."
667
- )
668
- elif batch_size != len(negative_prompt):
669
- raise ValueError(
670
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
671
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
672
- " the batch size of `prompt`."
673
- )
674
- else:
675
- uncond_tokens = [negative_prompt, negative_prompt_2]
676
-
677
- negative_prompt_embeds_list = []
678
- for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
679
- if isinstance(self, TextualInversionLoaderMixin):
680
- negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
681
-
682
- max_length = prompt_embeds.shape[1]
683
- uncond_input = tokenizer(
684
- negative_prompt,
685
- padding="max_length",
686
- max_length=max_length,
687
- truncation=True,
688
- return_tensors="pt",
689
- )
690
-
691
- negative_prompt_embeds = text_encoder(
692
- uncond_input.input_ids.to(device),
693
- output_hidden_states=True,
694
- )
695
-                 # We are only interested in the pooled output of the final text encoder
696
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
697
- negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
698
-
699
- negative_prompt_embeds_list.append(negative_prompt_embeds)
700
-
701
- negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
702
-
703
- if self.text_encoder_2 is not None:
704
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
705
- else:
706
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
707
-
708
- bs_embed, seq_len, _ = prompt_embeds.shape
709
- # duplicate text embeddings for each generation per prompt, using mps friendly method
710
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
711
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
712
-
713
- if do_classifier_free_guidance:
714
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
715
- seq_len = negative_prompt_embeds.shape[1]
716
-
717
- if self.text_encoder_2 is not None:
718
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
719
- else:
720
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
721
-
722
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
723
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
724
-
725
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
726
- bs_embed * num_images_per_prompt, -1
727
- )
728
- if do_classifier_free_guidance:
729
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
730
- bs_embed * num_images_per_prompt, -1
731
- )
732
-
733
- if self.text_encoder is not None:
734
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
735
- # Retrieve the original scale by scaling back the LoRA layers
736
- unscale_lora_layers(self.text_encoder, lora_scale)
737
-
738
- if self.text_encoder_2 is not None:
739
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
740
- # Retrieve the original scale by scaling back the LoRA layers
741
- unscale_lora_layers(self.text_encoder_2, lora_scale)
742
-
743
- return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
744
-
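
As a rough sketch of how the tuple returned by `encode_prompt` is consumed with classifier-free guidance (hypothetical `pipe` instance; the prompt text is illustrative):

import torch

prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
    prompt="a model wearing a plain white t-shirt",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)
# For CFG the negative embeddings are concatenated in front of the positive ones
# along the batch dimension before being passed to the UNet.
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
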
745
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
746
- def prepare_extra_step_kwargs(self, generator, eta):
747
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
748
-         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
750
-         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
750
- # and should be between [0, 1]
751
-
752
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
753
- extra_step_kwargs = {}
754
- if accepts_eta:
755
- extra_step_kwargs["eta"] = eta
756
-
757
- # check if the scheduler accepts generator
758
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
759
- if accepts_generator:
760
- extra_step_kwargs["generator"] = generator
761
- return extra_step_kwargs
762
-
763
- def check_inputs(
764
- self,
765
- prompt,
766
- prompt_2,
767
- image,
768
- mask_image,
769
- height,
770
- width,
771
- strength,
772
- callback_steps,
773
- output_type,
774
- negative_prompt=None,
775
- negative_prompt_2=None,
776
- prompt_embeds=None,
777
- negative_prompt_embeds=None,
778
- callback_on_step_end_tensor_inputs=None,
779
- padding_mask_crop=None,
780
- ):
781
- if strength < 0 or strength > 1:
782
-         raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
783
-
784
- if height % 8 != 0 or width % 8 != 0:
785
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
786
-
787
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
788
- raise ValueError(
789
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
790
- f" {type(callback_steps)}."
791
- )
792
-
793
- if callback_on_step_end_tensor_inputs is not None and not all(
794
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
795
- ):
796
- raise ValueError(
797
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
798
- )
799
-
800
- if prompt is not None and prompt_embeds is not None:
801
- raise ValueError(
802
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
803
- " only forward one of the two."
804
- )
805
- elif prompt_2 is not None and prompt_embeds is not None:
806
- raise ValueError(
807
- f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
808
- " only forward one of the two."
809
- )
810
- elif prompt is None and prompt_embeds is None:
811
- raise ValueError(
812
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
813
- )
814
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
815
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
816
- elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
817
- raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
818
-
819
- if negative_prompt is not None and negative_prompt_embeds is not None:
820
- raise ValueError(
821
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
822
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
823
- )
824
- elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
825
- raise ValueError(
826
- f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
827
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
828
- )
829
-
830
- if prompt_embeds is not None and negative_prompt_embeds is not None:
831
- if prompt_embeds.shape != negative_prompt_embeds.shape:
832
- raise ValueError(
833
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
834
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
835
- f" {negative_prompt_embeds.shape}."
836
- )
837
- if padding_mask_crop is not None:
838
- if not isinstance(image, PIL.Image.Image):
839
- raise ValueError(
840
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
841
- )
842
- if not isinstance(mask_image, PIL.Image.Image):
843
- raise ValueError(
844
- f"The mask image should be a PIL image when inpainting mask crop, but is of type"
845
- f" {type(mask_image)}."
846
- )
847
- if output_type != "pil":
848
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
849
-
850
- def prepare_latents(
851
- self,
852
- batch_size,
853
- num_channels_latents,
854
- height,
855
- width,
856
- dtype,
857
- device,
858
- generator,
859
- latents=None,
860
- image=None,
861
- timestep=None,
862
- is_strength_max=True,
863
- add_noise=True,
864
- return_noise=False,
865
- return_image_latents=False,
866
- ):
867
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
868
- if isinstance(generator, list) and len(generator) != batch_size:
869
- raise ValueError(
870
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
871
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
872
- )
873
-
874
- if (image is None or timestep is None) and not is_strength_max:
875
- raise ValueError(
876
-                 "Since strength < 1, initial latents are to be initialised as a combination of Image + Noise. "
877
- "However, either the image or the noise timestep has not been provided."
878
- )
879
-
880
- if image.shape[1] == 4:
881
- image_latents = image.to(device=device, dtype=dtype)
882
- image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
883
- elif return_image_latents or (latents is None and not is_strength_max):
884
- image = image.to(device=device, dtype=dtype)
885
- image_latents = self._encode_vae_image(image=image, generator=generator)
886
- image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
887
-
888
- if latents is None and add_noise:
889
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
890
-             # if strength is 1. then initialise the latents to noise, else initialise to image + noise
891
- latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
892
- # if pure noise then scale the initial latents by the Scheduler's init sigma
893
- latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
894
- elif add_noise:
895
- noise = latents.to(device)
896
- latents = noise * self.scheduler.init_noise_sigma
897
- else:
898
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
899
- latents = image_latents.to(device)
900
-
901
- outputs = (latents,)
902
-
903
- if return_noise:
904
- outputs += (noise,)
905
-
906
- if return_image_latents:
907
- outputs += (image_latents,)
908
-
909
- return outputs
910
-
911
- def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
912
- dtype = image.dtype
913
- if self.vae.config.force_upcast:
914
- image = image.float()
915
- self.vae.to(dtype=torch.float32)
916
-
917
- if isinstance(generator, list):
918
- image_latents = [
919
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
920
- for i in range(image.shape[0])
921
- ]
922
- image_latents = torch.cat(image_latents, dim=0)
923
- else:
924
- image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
925
-
926
- if self.vae.config.force_upcast:
927
- self.vae.to(dtype)
928
-
929
- image_latents = image_latents.to(dtype)
930
- image_latents = self.vae.config.scaling_factor * image_latents
931
-
932
- return image_latents
933
-
934
- def prepare_mask_latents(
935
- self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
936
- ):
937
- # resize the mask to latents shape as we concatenate the mask to the latents
938
- # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
939
- # and half precision
940
- mask = torch.nn.functional.interpolate(
941
- mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
942
- )
943
- mask = mask.to(device=device, dtype=dtype)
944
-
945
- # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
946
- if mask.shape[0] < batch_size:
947
- if not batch_size % mask.shape[0] == 0:
948
- raise ValueError(
949
- "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
950
- f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
951
- " of masks that you pass is divisible by the total requested batch size."
952
- )
953
- mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
954
-
955
- mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
956
- if masked_image is not None and masked_image.shape[1] == 4:
957
- masked_image_latents = masked_image
958
- else:
959
- masked_image_latents = None
960
-
961
- if masked_image is not None:
962
- if masked_image_latents is None:
963
- masked_image = masked_image.to(device=device, dtype=dtype)
964
- masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
965
-
966
- if masked_image_latents.shape[0] < batch_size:
967
- if not batch_size % masked_image_latents.shape[0] == 0:
968
- raise ValueError(
969
- "The passed images and the required batch size don't match. Images are supposed to be duplicated"
970
- f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
971
- " Make sure the number of images that you pass is divisible by the total requested batch size."
972
- )
973
- masked_image_latents = masked_image_latents.repeat(
974
- batch_size // masked_image_latents.shape[0], 1, 1, 1
975
- )
976
-
977
- masked_image_latents = (
978
- torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
979
- )
980
-
981
-         # aligning device to prevent device errors when concatenating it with the latent model input
982
- masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
983
-
984
- return mask, masked_image_latents
985
-
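
For intuition, the mask handed to `prepare_mask_latents` is simply interpolated down to the latent grid, e.g. a 1024x768 pixel mask with `vae_scale_factor = 8` becomes 128x96 (a minimal sketch, assuming `torch` is available):

import torch

mask = torch.rand(1, 1, 1024, 768)  # (B, 1, H, W) mask in pixel space
vae_scale_factor = 8                # 2 ** (len(vae.config.block_out_channels) - 1) for the SDXL VAE
latent_mask = torch.nn.functional.interpolate(
    mask, size=(1024 // vae_scale_factor, 768 // vae_scale_factor)
)
print(latent_mask.shape)            # torch.Size([1, 1, 128, 96])
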
986
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
987
- def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
988
- # get the original timestep using init_timestep
989
- if denoising_start is None:
990
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
991
- t_start = max(num_inference_steps - init_timestep, 0)
992
- else:
993
- t_start = 0
994
-
995
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
996
-
997
- # Strength is irrelevant if we directly request a timestep to start at;
998
- # that is, strength is determined by the denoising_start instead.
999
- if denoising_start is not None:
1000
- discrete_timestep_cutoff = int(
1001
- round(
1002
- self.scheduler.config.num_train_timesteps
1003
- - (denoising_start * self.scheduler.config.num_train_timesteps)
1004
- )
1005
- )
1006
-
1007
- num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
1008
- if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
1009
- # if the scheduler is a 2nd order scheduler we might have to do +1
1010
- # because `num_inference_steps` might be even given that every timestep
1011
- # (except the highest one) is duplicated. If `num_inference_steps` is even it would
1012
- # mean that we cut the timesteps in the middle of the denoising step
1013
-                 # (between the 1st and 2nd derivative) which leads to incorrect results. By adding 1
1014
-                 # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
1015
- num_inference_steps = num_inference_steps + 1
1016
-
1017
- # because t_n+1 >= t_n, we slice the timesteps starting from the end
1018
- timesteps = timesteps[-num_inference_steps:]
1019
- return timesteps, num_inference_steps
1020
-
1021
- return timesteps, num_inference_steps - t_start
1022
-
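
A concrete example of the strength-to-steps arithmetic in `get_timesteps` when `denoising_start` is not set (illustrative numbers only):

num_inference_steps, strength = 30, 0.6
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 18
t_start = max(num_inference_steps - init_timestep, 0)                          # 12
# Only the scheduler's last 18 timesteps are run; strength = 1.0 keeps all 30.
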
1023
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
1024
- def _get_add_time_ids(
1025
- self,
1026
- original_size,
1027
- crops_coords_top_left,
1028
- target_size,
1029
- aesthetic_score,
1030
- negative_aesthetic_score,
1031
- negative_original_size,
1032
- negative_crops_coords_top_left,
1033
- negative_target_size,
1034
- dtype,
1035
- text_encoder_projection_dim=None,
1036
- ):
1037
- if self.config.requires_aesthetics_score:
1038
- add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
1039
- add_neg_time_ids = list(
1040
- negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
1041
- )
1042
- else:
1043
- add_time_ids = list(original_size + crops_coords_top_left + target_size)
1044
- add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
1045
-
1046
- passed_add_embed_dim = (
1047
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
1048
- )
1049
- expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
1050
-
1051
- if (
1052
- expected_add_embed_dim > passed_add_embed_dim
1053
- and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
1054
- ):
1055
- raise ValueError(
1056
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
1057
- )
1058
- elif (
1059
- expected_add_embed_dim < passed_add_embed_dim
1060
- and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
1061
- ):
1062
- raise ValueError(
1063
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
1064
- )
1065
- elif expected_add_embed_dim != passed_add_embed_dim:
1066
- raise ValueError(
1067
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
1068
- )
1069
-
1070
- add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
1071
- add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
1072
-
1073
- return add_time_ids, add_neg_time_ids
1074
-
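
For reference, with `requires_aesthetics_score` disabled the added time ids are just the SDXL micro-conditioning tuples flattened to six numbers (a sketch with illustrative sizes):

original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
print(add_time_ids)  # [1024, 1024, 0, 0, 1024, 1024]
# With `requires_aesthetics_score=True` the target size is replaced by the aesthetic
# score, e.g. [1024, 1024, 0, 0, 6.0].
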
1075
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
1076
- def upcast_vae(self):
1077
- dtype = self.vae.dtype
1078
- self.vae.to(dtype=torch.float32)
1079
- use_torch_2_0_or_xformers = isinstance(
1080
- self.vae.decoder.mid_block.attentions[0].processor,
1081
- (
1082
- AttnProcessor2_0,
1083
- XFormersAttnProcessor,
1084
- LoRAXFormersAttnProcessor,
1085
- LoRAAttnProcessor2_0,
1086
- ),
1087
- )
1088
- # if xformers or torch_2_0 is used attention block does not need
1089
- # to be in float32 which can save lots of memory
1090
- if use_torch_2_0_or_xformers:
1091
- self.vae.post_quant_conv.to(dtype)
1092
- self.vae.decoder.conv_in.to(dtype)
1093
- self.vae.decoder.mid_block.to(dtype)
1094
-
1095
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
1096
- def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
1097
- r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
1098
-
1099
- The suffixes after the scaling factors represent the stages where they are being applied.
1100
-
1101
- Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
1102
- that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
1103
-
1104
- Args:
1105
- s1 (`float`):
1106
- Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
1107
- mitigate "oversmoothing effect" in the enhanced denoising process.
1108
- s2 (`float`):
1109
- Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
1110
- mitigate "oversmoothing effect" in the enhanced denoising process.
1111
- b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
1112
- b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
1113
- """
1114
- if not hasattr(self, "unet"):
1115
- raise ValueError("The pipeline must have `unet` for using FreeU.")
1116
- self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
1117
-
1118
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
1119
- def disable_freeu(self):
1120
- """Disables the FreeU mechanism if enabled."""
1121
- self.unet.disable_freeu()
1122
-
1123
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
1124
- def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
1125
- """
1126
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
1127
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
1128
-
1129
- <Tip warning={true}>
1130
-
1131
-         This API is 🧪 experimental.
1132
-
1133
- </Tip>
1134
-
1135
- Args:
1136
- unet (`bool`, defaults to `True`): To apply fusion on the UNet.
1137
- vae (`bool`, defaults to `True`): To apply fusion on the VAE.
1138
- """
1139
- self.fusing_unet = False
1140
- self.fusing_vae = False
1141
-
1142
- if unet:
1143
- self.fusing_unet = True
1144
- self.unet.fuse_qkv_projections()
1145
- self.unet.set_attn_processor(FusedAttnProcessor2_0())
1146
-
1147
- if vae:
1148
- if not isinstance(self.vae, AutoencoderKL):
1149
- raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
1150
-
1151
- self.fusing_vae = True
1152
- self.vae.fuse_qkv_projections()
1153
- self.vae.set_attn_processor(FusedAttnProcessor2_0())
1154
-
1155
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
1156
- def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
1157
- """Disable QKV projection fusion if enabled.
1158
-
1159
- <Tip warning={true}>
1160
-
1161
-         This API is 🧪 experimental.
1162
-
1163
- </Tip>
1164
-
1165
- Args:
1166
- unet (`bool`, defaults to `True`): To apply fusion on the UNet.
1167
- vae (`bool`, defaults to `True`): To apply fusion on the VAE.
1168
-
1169
- """
1170
- if unet:
1171
- if not self.fusing_unet:
1172
- logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
1173
- else:
1174
- self.unet.unfuse_qkv_projections()
1175
- self.fusing_unet = False
1176
-
1177
- if vae:
1178
- if not self.fusing_vae:
1179
- logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
1180
- else:
1181
- self.vae.unfuse_qkv_projections()
1182
- self.fusing_vae = False
1183
-
1184
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
1185
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
1186
- """
1187
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
1188
-
1189
- Args:
1190
-             w (`torch.Tensor`):
1191
-                 guidance scale values for which to generate embedding vectors
1192
- embedding_dim (`int`, *optional*, defaults to 512):
1193
- dimension of the embeddings to generate
1194
- dtype:
1195
- data type of the generated embeddings
1196
-
1197
- Returns:
1198
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
1199
- """
1200
- assert len(w.shape) == 1
1201
- w = w * 1000.0
1202
-
1203
- half_dim = embedding_dim // 2
1204
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
1205
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
1206
- emb = w.to(dtype)[:, None] * emb[None, :]
1207
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
1208
- if embedding_dim % 2 == 1: # zero pad
1209
- emb = torch.nn.functional.pad(emb, (0, 1))
1210
- assert emb.shape == (w.shape[0], embedding_dim)
1211
- return emb
1212
-
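
A minimal shape check for `get_guidance_scale_embedding` (hypothetical `pipe` instance, assuming `torch` is available):

import torch

w = torch.tensor([7.5, 5.0])  # one guidance scale per sample in the batch
emb = pipe.get_guidance_scale_embedding(w, embedding_dim=256)
print(emb.shape)              # torch.Size([2, 256])
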
1213
- @property
1214
- def guidance_scale(self):
1215
- return self._guidance_scale
1216
-
1217
- @property
1218
- def guidance_rescale(self):
1219
- return self._guidance_rescale
1220
-
1221
- @property
1222
- def clip_skip(self):
1223
- return self._clip_skip
1224
-
1225
-     # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
1226
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1227
- # corresponds to doing no classifier free guidance.
1228
- @property
1229
- def do_classifier_free_guidance(self):
1230
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1231
-
1232
- @property
1233
- def cross_attention_kwargs(self):
1234
- return self._cross_attention_kwargs
1235
-
1236
- @property
1237
- def denoising_end(self):
1238
- return self._denoising_end
1239
-
1240
- @property
1241
- def denoising_start(self):
1242
- return self._denoising_start
1243
-
1244
- @property
1245
- def num_timesteps(self):
1246
- return self._num_timesteps
1247
-
1248
- @property
1249
- def interrupt(self):
1250
- return self._interrupt
1251
-
1252
- @torch.no_grad()
1253
- @replace_example_docstring(EXAMPLE_DOC_STRING)
1254
- def __call__(
1255
- self,
1256
- prompt: Union[str, List[str]] = None,
1257
- prompt_2: Optional[Union[str, List[str]]] = None,
1258
- image: PipelineImageInput = None,
1259
- mask_image: PipelineImageInput = None,
1260
- masked_image_latents: torch.FloatTensor = None,
1261
- height: Optional[int] = None,
1262
- width: Optional[int] = None,
1263
- padding_mask_crop: Optional[int] = None,
1264
- strength: float = 0.9999,
1265
- num_inference_steps: int = 50,
1266
- timesteps: List[int] = None,
1267
- denoising_start: Optional[float] = None,
1268
- denoising_end: Optional[float] = None,
1269
- guidance_scale: float = 7.5,
1270
- negative_prompt: Optional[Union[str, List[str]]] = None,
1271
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
1272
- num_images_per_prompt: Optional[int] = 1,
1273
- eta: float = 0.0,
1274
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1275
- latents: Optional[torch.FloatTensor] = None,
1276
- prompt_embeds: Optional[torch.FloatTensor] = None,
1277
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1278
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1279
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1280
- ip_adapter_image: Optional[PipelineImageInput] = None,
1281
- output_type: Optional[str] = "pil",
1282
-         cloth=None,
1283
-         pose_img=None,
1284
- text_embeds_cloth=None,
1285
- return_dict: bool = True,
1286
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1287
- guidance_rescale: float = 0.0,
1288
- original_size: Tuple[int, int] = None,
1289
- crops_coords_top_left: Tuple[int, int] = (0, 0),
1290
- target_size: Tuple[int, int] = None,
1291
- negative_original_size: Optional[Tuple[int, int]] = None,
1292
- negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1293
- negative_target_size: Optional[Tuple[int, int]] = None,
1294
- aesthetic_score: float = 6.0,
1295
- negative_aesthetic_score: float = 2.5,
1296
- clip_skip: Optional[int] = None,
1297
- pooled_prompt_embeds_c=None,
1298
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1299
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1300
- **kwargs,
1301
- ):
1302
- r"""
1303
- Function invoked when calling the pipeline for generation.
1304
-
1305
- Args:
1306
- prompt (`str` or `List[str]`, *optional*):
1307
-                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
1308
- instead.
1309
- prompt_2 (`str` or `List[str]`, *optional*):
1310
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1311
- used in both text-encoders
1312
- image (`PIL.Image.Image`):
1313
- `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
1314
- be masked out with `mask_image` and repainted according to `prompt`.
1315
- mask_image (`PIL.Image.Image`):
1316
- `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1317
- repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
1318
- to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
1319
- instead of 3, so the expected shape would be `(B, H, W, 1)`.
1320
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1321
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
1322
- Anything below 512 pixels won't work well for
1323
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1324
- and checkpoints that are not specifically fine-tuned on low resolutions.
1325
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1326
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
1327
- Anything below 512 pixels won't work well for
1328
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1329
- and checkpoints that are not specifically fine-tuned on low resolutions.
1330
- padding_mask_crop (`int`, *optional*, defaults to `None`):
1331
-                 The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied to image and mask_image. If
1332
-                 `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
1333
-                 contains all masked areas, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
1334
-                 the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
1335
-                 and contains information irrelevant for inpainting, such as the background.
1336
- strength (`float`, *optional*, defaults to 0.9999):
1337
- Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
1338
- between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
1339
- `strength`. The number of denoising steps depends on the amount of noise initially added. When
1340
- `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
1341
- iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
1342
- portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
1343
- integer, the value of `strength` will be ignored.
1344
- num_inference_steps (`int`, *optional*, defaults to 50):
1345
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1346
- expense of slower inference.
1347
- timesteps (`List[int]`, *optional*):
1348
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1349
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1350
- passed will be used. Must be in descending order.
1351
- denoising_start (`float`, *optional*):
1352
- When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
1353
- bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
1354
- it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
1355
- strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
1356
- is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
1357
- Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1358
- denoising_end (`float`, *optional*):
1359
- When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1360
- completed before it is intentionally prematurely terminated. As a result, the returned sample will
1361
- still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
1362
- denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
1363
- final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
1364
- forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1365
- Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1366
- guidance_scale (`float`, *optional*, defaults to 7.5):
1367
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1368
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
1369
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1370
-                 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1371
- usually at the expense of lower image quality.
1372
- negative_prompt (`str` or `List[str]`, *optional*):
1373
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
1374
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1375
- less than `1`).
1376
- negative_prompt_2 (`str` or `List[str]`, *optional*):
1377
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1378
- `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1379
- prompt_embeds (`torch.FloatTensor`, *optional*):
1380
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1381
- provided, text embeddings will be generated from `prompt` input argument.
1382
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1383
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1384
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1385
- argument.
1386
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1387
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1388
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
1389
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1390
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1391
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1392
- input argument.
1393
-             ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1394
- num_images_per_prompt (`int`, *optional*, defaults to 1):
1395
- The number of images to generate per prompt.
1396
-                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1397
- Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1398
- [`schedulers.DDIMScheduler`], will be ignored for others.
1399
- generator (`torch.Generator`, *optional*):
1400
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1401
- to make generation deterministic.
1402
- latents (`torch.FloatTensor`, *optional*):
1403
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1404
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1405
-                 tensor will be generated by sampling using the supplied random `generator`.
1406
- output_type (`str`, *optional*, defaults to `"pil"`):
1407
-                 The output format of the generated image. Choose between
1408
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1409
- return_dict (`bool`, *optional*, defaults to `True`):
1410
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1411
- plain tuple.
1412
- cross_attention_kwargs (`dict`, *optional*):
1413
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1414
- `self.processor` in
1415
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1416
- original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1417
- If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1418
- `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1419
- explained in section 2.2 of
1420
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1421
- crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1422
- `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1423
- `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1424
- `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1425
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1426
- target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1427
- For most cases, `target_size` should be set to the desired height and width of the generated image. If
1428
- not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1429
- section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1430
- negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1431
- To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1432
- micro-conditioning as explained in section 2.2 of
1433
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1434
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1435
- negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1436
- To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
1437
- micro-conditioning as explained in section 2.2 of
1438
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1439
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1440
- negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1441
- To negatively condition the generation process based on a target image resolution. It should be the same
1442
- as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1443
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1444
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1445
- aesthetic_score (`float`, *optional*, defaults to 6.0):
1446
- Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
1447
- Part of SDXL's micro-conditioning as explained in section 2.2 of
1448
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1449
- negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
1450
- Part of SDXL's micro-conditioning as explained in section 2.2 of
1451
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
1452
- simulate an aesthetic score of the generated image by influencing the negative text condition.
1453
- clip_skip (`int`, *optional*):
1454
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1455
- the output of the pre-final layer will be used for computing the prompt embeddings.
1456
- callback_on_step_end (`Callable`, *optional*):
1457
- A function that is called at the end of each denoising step during inference. The function is called
1458
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1459
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1460
- `callback_on_step_end_tensor_inputs`.
1461
- callback_on_step_end_tensor_inputs (`List`, *optional*):
1462
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1463
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1464
- `._callback_tensor_inputs` attribute of your pipeline class.
1465
-
1466
- Examples:
1467
-
1468
- Returns:
1469
- [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
1470
- [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1471
- `tuple`. When returning a tuple, the first element is a list with the generated images.
1472
- """
1473
-
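As a concrete illustration of the `callback_on_step_end` hook documented above, here is a minimal sketch of a step callback; the function name, the `pipe` handle, and the logging it performs are hypothetical, but the signature and the returned dict follow the docstring and the way the denoising loop later invokes the callback.

def log_latent_stats(pipe, step, timestep, callback_kwargs):
    # callback_kwargs only contains the tensors listed in callback_on_step_end_tensor_inputs
    latents = callback_kwargs["latents"]
    print(f"step {step} (t={timestep}): latent std = {latents.std().item():.3f}")
    # the pipeline pops tensors back out of the returned dict, so modified values are picked up
    return callback_kwargs

# hypothetical invocation:
# result = pipe(..., callback_on_step_end=log_latent_stats,
#               callback_on_step_end_tensor_inputs=["latents"])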
1474
- callback = kwargs.pop("callback", None)
1475
- callback_steps = kwargs.pop("callback_steps", None)
1476
-
1477
- if callback is not None:
1478
- deprecate(
1479
- "callback",
1480
- "1.0.0",
1481
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1482
- )
1483
- if callback_steps is not None:
1484
- deprecate(
1485
- "callback_steps",
1486
- "1.0.0",
1487
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1488
- )
1489
-
1490
- # 0. Default height and width to unet
1491
- height = height or self.unet.config.sample_size * self.vae_scale_factor
1492
- width = width or self.unet.config.sample_size * self.vae_scale_factor
1493
-
1494
- # 1. Check inputs
1495
- self.check_inputs(
1496
- prompt,
1497
- prompt_2,
1498
- image,
1499
- mask_image,
1500
- height,
1501
- width,
1502
- strength,
1503
- callback_steps,
1504
- output_type,
1505
- negative_prompt,
1506
- negative_prompt_2,
1507
- prompt_embeds,
1508
- negative_prompt_embeds,
1509
- callback_on_step_end_tensor_inputs,
1510
- padding_mask_crop,
1511
- )
1512
-
1513
- self._guidance_scale = guidance_scale
1514
- self._guidance_rescale = guidance_rescale
1515
- self._clip_skip = clip_skip
1516
- self._cross_attention_kwargs = cross_attention_kwargs
1517
- self._denoising_end = denoising_end
1518
- self._denoising_start = denoising_start
1519
- self._interrupt = False
1520
-
1521
- # 2. Define call parameters
1522
- if prompt is not None and isinstance(prompt, str):
1523
- batch_size = 1
1524
- elif prompt is not None and isinstance(prompt, list):
1525
- batch_size = len(prompt)
1526
- else:
1527
- batch_size = prompt_embeds.shape[0]
1528
-
1529
- device = self._execution_device
1530
-
1531
- # 3. Encode input prompt
1532
- text_encoder_lora_scale = (
1533
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1534
- )
1535
-
1536
- (
1537
- prompt_embeds,
1538
- negative_prompt_embeds,
1539
- pooled_prompt_embeds,
1540
- negative_pooled_prompt_embeds,
1541
- ) = self.encode_prompt(
1542
- prompt=prompt,
1543
- prompt_2=prompt_2,
1544
- device=device,
1545
- num_images_per_prompt=num_images_per_prompt,
1546
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1547
- negative_prompt=negative_prompt,
1548
- negative_prompt_2=negative_prompt_2,
1549
- prompt_embeds=prompt_embeds,
1550
- negative_prompt_embeds=negative_prompt_embeds,
1551
- pooled_prompt_embeds=pooled_prompt_embeds,
1552
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1553
- lora_scale=text_encoder_lora_scale,
1554
- clip_skip=self.clip_skip,
1555
- )
1556
-
1557
- # 4. set timesteps
1558
- def denoising_value_valid(dnv):
1559
- return isinstance(dnv, float) and 0 < dnv < 1
1560
-
1561
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1562
- timesteps, num_inference_steps = self.get_timesteps(
1563
- num_inference_steps,
1564
- strength,
1565
- device,
1566
- denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
1567
- )
1568
- # check that the number of inference steps is not < 1, as this doesn't make sense
1569
- if num_inference_steps < 1:
1570
- raise ValueError(
1571
- f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
1572
- f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
1573
- )
1574
- # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
1575
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1576
- # create a boolean to check if the strength is set to 1; if so, initialise the latents with pure noise
1577
- is_strength_max = strength == 1.0
1578
-
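The `strength`/`denoising_start` handling above relies on `get_timesteps` truncating the schedule so that only the final `strength` fraction of steps is actually run. A rough sketch of that truncation, modelled on comparable diffusers img2img/inpaint pipelines (the real method also handles `denoising_start`), is:

def truncate_timesteps_by_strength(scheduler, num_inference_steps, strength):
    # e.g. strength=0.5 keeps the final 50% of the schedule, so noise is added at the halfway timestep
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler.timesteps[t_start * scheduler.order :]
    return timesteps, num_inference_steps - t_start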
1579
- # 5. Preprocess mask and image
1580
- if padding_mask_crop is not None:
1581
- crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
1582
- resize_mode = "fill"
1583
- else:
1584
- crops_coords = None
1585
- resize_mode = "default"
1586
-
1587
- original_image = image
1588
- init_image = self.image_processor.preprocess(
1589
- image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
1590
- )
1591
- init_image = init_image.to(dtype=torch.float32)
1592
-
1593
- mask = self.mask_processor.preprocess(
1594
- mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
1595
- )
1596
- if masked_image_latents is not None:
1597
- masked_image = masked_image_latents
1598
- elif init_image.shape[1] == 4:
1599
- # if images are already in latent space, we can't mask them
1600
- masked_image = None
1601
- else:
1602
- masked_image = init_image * (mask < 0.5)
1603
-
1604
- # 6. Prepare latent variables
1605
- num_channels_latents = self.vae.config.latent_channels
1606
- num_channels_unet = self.unet.config.in_channels
1607
- return_image_latents = num_channels_unet == 4
1608
-
1609
- add_noise = self.denoising_start is None
1610
- latents_outputs = self.prepare_latents(
1611
- batch_size * num_images_per_prompt,
1612
- num_channels_latents,
1613
- height,
1614
- width,
1615
- prompt_embeds.dtype,
1616
- device,
1617
- generator,
1618
- latents,
1619
- image=init_image,
1620
- timestep=latent_timestep,
1621
- is_strength_max=is_strength_max,
1622
- add_noise=add_noise,
1623
- return_noise=True,
1624
- return_image_latents=return_image_latents,
1625
- )
1626
-
1627
- if return_image_latents:
1628
- latents, noise, image_latents = latents_outputs
1629
- else:
1630
- latents, noise = latents_outputs
1631
-
1632
- # 7. Prepare mask latent variables
1633
- mask, masked_image_latents = self.prepare_mask_latents(
1634
- mask,
1635
- masked_image,
1636
- batch_size * num_images_per_prompt,
1637
- height,
1638
- width,
1639
- prompt_embeds.dtype,
1640
- device,
1641
- generator,
1642
- self.do_classifier_free_guidance,
1643
- )
1644
- pose_img = pose_img.to(device=device, dtype=prompt_embeds.dtype)
1645
-
1646
- pose_img = self.vae.encode(pose_img).latent_dist.sample()
1647
- pose_img = pose_img * self.vae.config.scaling_factor
1648
-
1649
- # pose_img = self._encode_vae_image(pose_img, generator=generator)
1650
-
1651
- pose_img = (
1652
- torch.cat([pose_img] * 2) if self.do_classifier_free_guidance else pose_img
1653
- )
1654
- cloth = self._encode_vae_image(cloth, generator=generator)
1655
-
1656
- # # 8. Check that sizes of mask, masked image and latents match
1657
- # if num_channels_unet == 9:
1658
- # # default case for runwayml/stable-diffusion-inpainting
1659
- # num_channels_mask = mask.shape[1]
1660
- # num_channels_masked_image = masked_image_latents.shape[1]
1661
- # if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
1662
- # raise ValueError(
1663
- # f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
1664
- # f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1665
- # f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1666
- # f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1667
- # " `pipeline.unet` or your `mask_image` or `image` input."
1668
- # )
1669
- # elif num_channels_unet != 4:
1670
- # raise ValueError(
1671
- # f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
1672
- # )
1673
- # 8.1 Prepare extra step kwargs.
1674
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1675
-
1676
- # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1677
- height, width = latents.shape[-2:]
1678
- height = height * self.vae_scale_factor
1679
- width = width * self.vae_scale_factor
1680
-
1681
- original_size = original_size or (height, width)
1682
- target_size = target_size or (height, width)
1683
-
1684
- # 10. Prepare added time ids & embeddings
1685
- if negative_original_size is None:
1686
- negative_original_size = original_size
1687
- if negative_target_size is None:
1688
- negative_target_size = target_size
1689
-
1690
- add_text_embeds = pooled_prompt_embeds
1691
- if self.text_encoder_2 is None:
1692
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1693
- else:
1694
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1695
-
1696
- add_time_ids, add_neg_time_ids = self._get_add_time_ids(
1697
- original_size,
1698
- crops_coords_top_left,
1699
- target_size,
1700
- aesthetic_score,
1701
- negative_aesthetic_score,
1702
- negative_original_size,
1703
- negative_crops_coords_top_left,
1704
- negative_target_size,
1705
- dtype=prompt_embeds.dtype,
1706
- text_encoder_projection_dim=text_encoder_projection_dim,
1707
- )
1708
- add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1709
-
1710
- if self.do_classifier_free_guidance:
1711
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1712
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1713
- add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1714
- add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
1715
-
1716
- prompt_embeds = prompt_embeds.to(device)
1717
- add_text_embeds = add_text_embeds.to(device)
1718
- add_time_ids = add_time_ids.to(device)
1719
-
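For reference, the SDXL micro-conditioning handled by `_get_add_time_ids` above amounts to flattening the size and crop tuples into a small integer vector that is fed to the UNet alongside the pooled text embedding. A simplified sketch of the non-refiner variant (the actual method also builds the negative ids and can substitute an aesthetic score for `target_size`):

import torch

def make_add_time_ids(original_size, crops_coords_top_left, target_size, dtype):
    # (orig_h, orig_w, crop_top, crop_left, target_h, target_w) -> shape (1, 6)
    ids = list(original_size + crops_coords_top_left + target_size)
    return torch.tensor([ids], dtype=dtype)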
1720
- if ip_adapter_image is not None:
1721
- image_embeds = self.prepare_ip_adapter_image_embeds(
1722
- ip_adapter_image, device, batch_size * num_images_per_prompt
1723
- )
1724
-
1725
- # project once, outside the denoising loop
1726
- image_embeds = self.unet.encoder_hid_proj(image_embeds).to(prompt_embeds.dtype)
1727
-
1728
-
1729
- # 11. Denoising loop
1730
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1731
-
1732
- if (
1733
- self.denoising_end is not None
1734
- and self.denoising_start is not None
1735
- and denoising_value_valid(self.denoising_end)
1736
- and denoising_value_valid(self.denoising_start)
1737
- and self.denoising_start >= self.denoising_end
1738
- ):
1739
- raise ValueError(
1740
- f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
1741
- + f" {self.denoising_end} when using type float."
1742
- )
1743
- elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
1744
- discrete_timestep_cutoff = int(
1745
- round(
1746
- self.scheduler.config.num_train_timesteps
1747
- - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1748
- )
1749
- )
1750
- num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1751
- timesteps = timesteps[:num_inference_steps]
1752
-
1753
- # 11.1 Optionally get Guidance Scale Embedding
1754
- timestep_cond = None
1755
- if self.unet.config.time_cond_proj_dim is not None:
1756
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1757
- timestep_cond = self.get_guidance_scale_embedding(
1758
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1759
- ).to(device=device, dtype=latents.dtype)
1760
-
1761
-
1762
-
1763
- self._num_timesteps = len(timesteps)
1764
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1765
- for i, t in enumerate(timesteps):
1766
- if self.interrupt:
1767
- continue
1768
- # expand the latents if we are doing classifier free guidance
1769
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1770
-
1771
- # concat latents, mask, masked_image_latents in the channel dimension
1772
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1773
-
1774
-
1775
- # bsz = mask.shape[0]
1776
- if num_channels_unet == 13:
1777
- latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents, pose_img], dim=1)
1778
-
1779
- # if num_channels_unet == 9:
1780
- # latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1781
-
1782
- # predict the noise residual
1783
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1784
- if ip_adapter_image is not None:
1785
- added_cond_kwargs["image_embeds"] = image_embeds
1786
- # down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
1787
- down, reference_features = self.unet_encoder(cloth, t, text_embeds_cloth, return_dict=False)
1788
- # print(type(reference_features))
1789
- # print(reference_features)
1790
- reference_features = list(reference_features)
1791
- # print(len(reference_features))
1792
- # for elem in reference_features:
1793
- # print(elem.shape)
1794
- # exit(1)
1795
- if self.do_classifier_free_guidance:
1796
- reference_features = [torch.cat([torch.zeros_like(d), d]) for d in reference_features]
1797
-
1798
-
1799
- noise_pred = self.unet(
1800
- latent_model_input,
1801
- t,
1802
- encoder_hidden_states=prompt_embeds,
1803
- timestep_cond=timestep_cond,
1804
- cross_attention_kwargs=self.cross_attention_kwargs,
1805
- added_cond_kwargs=added_cond_kwargs,
1806
- return_dict=False,
1807
- garment_features=reference_features,
1808
- )[0]
1809
- # noise_pred = self.unet(latent_model_input, t,
1810
- # prompt_embeds,timestep_cond=timestep_cond,cross_attention_kwargs=self.cross_attention_kwargs,added_cond_kwargs=added_cond_kwargs,down_block_additional_attn=down ).sample
1811
-
1812
-
1813
- # perform guidance
1814
- if self.do_classifier_free_guidance:
1815
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1816
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1817
-
1818
- if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1819
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1820
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1821
-
1822
- # compute the previous noisy sample x_t -> x_t-1
1823
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1824
-
1825
- if num_channels_unet == 4:
1826
- init_latents_proper = image_latents
1827
- if self.do_classifier_free_guidance:
1828
- init_mask, _ = mask.chunk(2)
1829
- else:
1830
- init_mask = mask
1831
-
1832
- if i < len(timesteps) - 1:
1833
- noise_timestep = timesteps[i + 1]
1834
- init_latents_proper = self.scheduler.add_noise(
1835
- init_latents_proper, noise, torch.tensor([noise_timestep])
1836
- )
1837
-
1838
- latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1839
-
1840
- if callback_on_step_end is not None:
1841
- callback_kwargs = {}
1842
- for k in callback_on_step_end_tensor_inputs:
1843
- callback_kwargs[k] = locals()[k]
1844
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1845
-
1846
- latents = callback_outputs.pop("latents", latents)
1847
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1848
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1849
- add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1850
- negative_pooled_prompt_embeds = callback_outputs.pop(
1851
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1852
- )
1853
- add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1854
- add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
1855
- mask = callback_outputs.pop("mask", mask)
1856
- masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
1857
-
1858
- # call the callback, if provided
1859
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1860
- progress_bar.update()
1861
- if callback is not None and i % callback_steps == 0:
1862
- step_idx = i // getattr(self.scheduler, "order", 1)
1863
- callback(step_idx, t, latents)
1864
-
1865
- if XLA_AVAILABLE:
1866
- xm.mark_step()
1867
-
1868
- if not output_type == "latent":
1869
- # make sure the VAE is in float32 mode, as it overflows in float16
1870
- needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1871
-
1872
- if needs_upcasting:
1873
- self.upcast_vae()
1874
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1875
-
1876
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1877
-
1878
- # cast back to fp16 if needed
1879
- if needs_upcasting:
1880
- self.vae.to(dtype=torch.float16)
1881
- # else:
1882
- # return StableDiffusionXLPipelineOutput(images=latents)
1883
-
1884
-
1885
- image = self.image_processor.postprocess(image, output_type=output_type)
1886
-
1887
- if padding_mask_crop is not None:
1888
- image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
1889
-
1890
- # Offload all models
1891
- self.maybe_free_model_hooks()
1892
-
1893
- # if not return_dict:
1894
- return (image,)
1895
-
1896
- # return StableDiffusionXLPipelineOutput(images=image)
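The guidance step in the loop above calls `rescale_noise_cfg` when `guidance_rescale > 0`. For context, a sketch of that helper as it appears in diffusers' SDXL pipelines (following Sec. 3.4 of the referenced paper): the guided prediction is rescaled towards the standard deviation of the text-conditioned prediction to reduce over-exposure.

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # match the std of the guided prediction to that of the text-conditioned prediction
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # blend with the unrescaled prediction to avoid overly flat images
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg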
 
src/unet_block_hacked_garmnet.py DELETED
The diff for this file is too large to render. See raw diff
 
src/unet_block_hacked_tryon.py DELETED
The diff for this file is too large to render. See raw diff
 
src/unet_hacked_garmnet.py DELETED
@@ -1,1284 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, List, Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.nn as nn
19
- import torch.utils.checkpoint
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.loaders import UNet2DConditionLoadersMixin
23
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
24
- from diffusers.models.activations import get_activation
25
- from diffusers.models.attention_processor import (
26
- ADDED_KV_ATTENTION_PROCESSORS,
27
- CROSS_ATTENTION_PROCESSORS,
28
- Attention,
29
- AttentionProcessor,
30
- AttnAddedKVProcessor,
31
- AttnProcessor,
32
- )
33
- from einops import rearrange
34
-
35
- from diffusers.models.embeddings import (
36
- GaussianFourierProjection,
37
- ImageHintTimeEmbedding,
38
- ImageProjection,
39
- ImageTimeEmbedding,
40
- PositionNet,
41
- TextImageProjection,
42
- TextImageTimeEmbedding,
43
- TextTimeEmbedding,
44
- TimestepEmbedding,
45
- Timesteps,
46
- )
47
- from diffusers.models.modeling_utils import ModelMixin
48
- from src.unet_block_hacked_garmnet import (
49
- UNetMidBlock2D,
50
- UNetMidBlock2DCrossAttn,
51
- UNetMidBlock2DSimpleCrossAttn,
52
- get_down_block,
53
- get_up_block,
54
- )
55
- from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
56
- from diffusers.models.transformer_2d import Transformer2DModel
57
-
58
-
59
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
60
-
61
-
62
- def zero_module(module):
63
- for p in module.parameters():
64
- nn.init.zeros_(p)
65
- return module
66
-
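`zero_module` above zero-initialises every parameter of a module. The usual motivation (as in ControlNet-style designs, and the commented-out `warp_zeros` further down) is to gate a newly added branch so that it contributes nothing at the start of fine-tuning; a hypothetical usage sketch:

import torch.nn as nn

# a new 1x1 projection whose output starts at exactly zero,
# so the pretrained UNet behaviour is unchanged at step 0
extra_branch_out = zero_module(nn.Conv2d(320, 320, kernel_size=1, padding=0))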
67
- @dataclass
68
- class UNet2DConditionOutput(BaseOutput):
69
- """
70
- The output of [`UNet2DConditionModel`].
71
-
72
- Args:
73
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
74
- The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
75
- """
76
-
77
- sample: torch.FloatTensor = None
78
-
79
-
80
- class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
81
- r"""
82
- A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
83
- shaped output.
84
-
85
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
86
- for all models (such as downloading or saving).
87
-
88
- Parameters:
89
- sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
90
- Height and width of input/output sample.
91
- in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
92
- out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
93
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
94
- flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
95
- Whether to flip the sin to cos in the time embedding.
96
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
97
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
98
- The tuple of downsample blocks to use.
99
- mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
100
- Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
101
- `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
102
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
103
- The tuple of upsample blocks to use.
104
- only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
105
- Whether to include self-attention in the basic transformer blocks, see
106
- [`~models.attention.BasicTransformerBlock`].
107
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
108
- The tuple of output channels for each block.
109
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
110
- downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
111
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
112
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
113
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
114
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
115
- If `None`, normalization and activation layers is skipped in post-processing.
116
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
117
- cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
118
- The dimension of the cross attention features.
119
- transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
120
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
121
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
122
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
123
- reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
124
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
125
- blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
126
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
127
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
128
- encoder_hid_dim (`int`, *optional*, defaults to None):
129
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
130
- dimension to `cross_attention_dim`.
131
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
132
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
133
- embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
134
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
135
- num_attention_heads (`int`, *optional*):
136
- The number of attention heads. If not defined, defaults to `attention_head_dim`
137
- resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
138
- for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
139
- class_embed_type (`str`, *optional*, defaults to `None`):
140
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
141
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
142
- addition_embed_type (`str`, *optional*, defaults to `None`):
143
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
144
- "text". "text" will use the `TextTimeEmbedding` layer.
145
- addition_time_embed_dim (`int`, *optional*, defaults to `None`):
146
- Dimension for the timestep embeddings.
147
- num_class_embeds (`int`, *optional*, defaults to `None`):
148
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
149
- class conditioning with `class_embed_type` equal to `None`.
150
- time_embedding_type (`str`, *optional*, defaults to `positional`):
151
- The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
152
- time_embedding_dim (`int`, *optional*, defaults to `None`):
153
- An optional override for the dimension of the projected time embedding.
154
- time_embedding_act_fn (`str`, *optional*, defaults to `None`):
155
- Optional activation function to use only once on the time embeddings before they are passed to the rest of
156
- the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
157
- timestep_post_act (`str`, *optional*, defaults to `None`):
158
- The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
159
- time_cond_proj_dim (`int`, *optional*, defaults to `None`):
160
- The dimension of `cond_proj` layer in the timestep embedding.
161
- conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
162
- conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
163
- projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
164
- `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
165
- class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
166
- embeddings with the class embeddings.
167
- mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
168
- Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
169
- `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
170
- `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
171
- otherwise.
172
- """
173
-
174
- _supports_gradient_checkpointing = True
175
-
176
- @register_to_config
177
- def __init__(
178
- self,
179
- sample_size: Optional[int] = None,
180
- in_channels: int = 4,
181
- out_channels: int = 4,
182
- center_input_sample: bool = False,
183
- flip_sin_to_cos: bool = True,
184
- freq_shift: int = 0,
185
- down_block_types: Tuple[str] = (
186
- "CrossAttnDownBlock2D",
187
- "CrossAttnDownBlock2D",
188
- "CrossAttnDownBlock2D",
189
- "DownBlock2D",
190
- ),
191
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
192
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
193
- only_cross_attention: Union[bool, Tuple[bool]] = False,
194
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
195
- layers_per_block: Union[int, Tuple[int]] = 2,
196
- downsample_padding: int = 1,
197
- mid_block_scale_factor: float = 1,
198
- dropout: float = 0.0,
199
- act_fn: str = "silu",
200
- norm_num_groups: Optional[int] = 32,
201
- norm_eps: float = 1e-5,
202
- cross_attention_dim: Union[int, Tuple[int]] = 1280,
203
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
204
- reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
205
- encoder_hid_dim: Optional[int] = None,
206
- encoder_hid_dim_type: Optional[str] = None,
207
- attention_head_dim: Union[int, Tuple[int]] = 8,
208
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
209
- dual_cross_attention: bool = False,
210
- use_linear_projection: bool = False,
211
- class_embed_type: Optional[str] = None,
212
- addition_embed_type: Optional[str] = None,
213
- addition_time_embed_dim: Optional[int] = None,
214
- num_class_embeds: Optional[int] = None,
215
- upcast_attention: bool = False,
216
- resnet_time_scale_shift: str = "default",
217
- resnet_skip_time_act: bool = False,
218
- resnet_out_scale_factor: int = 1.0,
219
- time_embedding_type: str = "positional",
220
- time_embedding_dim: Optional[int] = None,
221
- time_embedding_act_fn: Optional[str] = None,
222
- timestep_post_act: Optional[str] = None,
223
- time_cond_proj_dim: Optional[int] = None,
224
- conv_in_kernel: int = 3,
225
- conv_out_kernel: int = 3,
226
- projection_class_embeddings_input_dim: Optional[int] = None,
227
- attention_type: str = "default",
228
- class_embeddings_concat: bool = False,
229
- mid_block_only_cross_attention: Optional[bool] = None,
230
- cross_attention_norm: Optional[str] = None,
231
- addition_embed_type_num_heads=64,
232
- ):
233
- super().__init__()
234
-
235
- self.sample_size = sample_size
236
-
237
- if num_attention_heads is not None:
238
- raise ValueError(
239
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
240
- )
241
-
242
- # If `num_attention_heads` is not defined (which is the case for most models)
243
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
244
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
245
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
246
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
247
- # which is why we correct for the naming here.
248
- num_attention_heads = num_attention_heads or attention_head_dim
249
-
250
- # Check inputs
251
- if len(down_block_types) != len(up_block_types):
252
- raise ValueError(
253
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
254
- )
255
-
256
- if len(block_out_channels) != len(down_block_types):
257
- raise ValueError(
258
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
259
- )
260
-
261
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
262
- raise ValueError(
263
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
264
- )
265
-
266
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
267
- raise ValueError(
268
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
269
- )
270
-
271
- if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
272
- raise ValueError(
273
- f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
274
- )
275
-
276
- if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
277
- raise ValueError(
278
- f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
279
- )
280
-
281
- if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
282
- raise ValueError(
283
- f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
284
- )
285
- if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
286
- for layer_number_per_block in transformer_layers_per_block:
287
- if isinstance(layer_number_per_block, list):
288
- raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
289
-
290
- # input
291
- conv_in_padding = (conv_in_kernel - 1) // 2
292
- self.conv_in = nn.Conv2d(
293
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
294
- )
295
-
296
- # time
297
- if time_embedding_type == "fourier":
298
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
299
- if time_embed_dim % 2 != 0:
300
- raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
301
- self.time_proj = GaussianFourierProjection(
302
- time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
303
- )
304
- timestep_input_dim = time_embed_dim
305
- elif time_embedding_type == "positional":
306
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
307
-
308
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
309
- timestep_input_dim = block_out_channels[0]
310
- else:
311
- raise ValueError(
312
- f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
313
- )
314
-
315
- self.time_embedding = TimestepEmbedding(
316
- timestep_input_dim,
317
- time_embed_dim,
318
- act_fn=act_fn,
319
- post_act_fn=timestep_post_act,
320
- cond_proj_dim=time_cond_proj_dim,
321
- )
322
-
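For the `positional` branch above, `Timesteps` produces a fixed sinusoidal embedding that `TimestepEmbedding` then maps through two linear layers. A rough sketch of the sinusoidal part, approximating diffusers' `get_timestep_embedding` (frequency-shift details omitted):

import math
import torch

def sinusoidal_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, max_period=10000):
    # timesteps: 1-D tensor of timestep indices; dim must be even
    half_dim = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / half_dim)
    args = timesteps.float()[:, None] * freqs[None, :].to(timesteps.device)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if flip_sin_to_cos:
        # swap the sin and cos halves, matching flip_sin_to_cos=True
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    return emb  # shape (batch, dim)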
323
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
324
- encoder_hid_dim_type = "text_proj"
325
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
326
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
327
-
328
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
329
- raise ValueError(
330
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
331
- )
332
-
333
- if encoder_hid_dim_type == "text_proj":
334
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
335
- elif encoder_hid_dim_type == "text_image_proj":
336
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
337
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
338
- # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
339
- self.encoder_hid_proj = TextImageProjection(
340
- text_embed_dim=encoder_hid_dim,
341
- image_embed_dim=cross_attention_dim,
342
- cross_attention_dim=cross_attention_dim,
343
- )
344
- elif encoder_hid_dim_type == "image_proj":
345
- # Kandinsky 2.2
346
- self.encoder_hid_proj = ImageProjection(
347
- image_embed_dim=encoder_hid_dim,
348
- cross_attention_dim=cross_attention_dim,
349
- )
350
- elif encoder_hid_dim_type is not None:
351
- raise ValueError(
352
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
353
- )
354
- else:
355
- self.encoder_hid_proj = None
356
-
357
- # class embedding
358
- if class_embed_type is None and num_class_embeds is not None:
359
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
360
- elif class_embed_type == "timestep":
361
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
362
- elif class_embed_type == "identity":
363
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
364
- elif class_embed_type == "projection":
365
- if projection_class_embeddings_input_dim is None:
366
- raise ValueError(
367
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
368
- )
369
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
370
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
371
- # 2. it projects from an arbitrary input dimension.
372
- #
373
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
374
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
375
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
376
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
377
- elif class_embed_type == "simple_projection":
378
- if projection_class_embeddings_input_dim is None:
379
- raise ValueError(
380
- "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
381
- )
382
- self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
383
- else:
384
- self.class_embedding = None
385
-
386
- if addition_embed_type == "text":
387
- if encoder_hid_dim is not None:
388
- text_time_embedding_from_dim = encoder_hid_dim
389
- else:
390
- text_time_embedding_from_dim = cross_attention_dim
391
-
392
- self.add_embedding = TextTimeEmbedding(
393
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
394
- )
395
- elif addition_embed_type == "text_image":
396
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
397
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
398
- # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
399
- self.add_embedding = TextImageTimeEmbedding(
400
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
401
- )
402
- elif addition_embed_type == "text_time":
403
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
404
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
405
- elif addition_embed_type == "image":
406
- # Kandinsky 2.2
407
- self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
408
- elif addition_embed_type == "image_hint":
409
- # Kandinsky 2.2 ControlNet
410
- self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
411
- elif addition_embed_type is not None:
412
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
413
-
414
- if time_embedding_act_fn is None:
415
- self.time_embed_act = None
416
- else:
417
- self.time_embed_act = get_activation(time_embedding_act_fn)
418
-
419
- self.down_blocks = nn.ModuleList([])
420
- self.up_blocks = nn.ModuleList([])
421
-
422
- if isinstance(only_cross_attention, bool):
423
- if mid_block_only_cross_attention is None:
424
- mid_block_only_cross_attention = only_cross_attention
425
-
426
- only_cross_attention = [only_cross_attention] * len(down_block_types)
427
-
428
- if mid_block_only_cross_attention is None:
429
- mid_block_only_cross_attention = False
430
-
431
- if isinstance(num_attention_heads, int):
432
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
433
-
434
- if isinstance(attention_head_dim, int):
435
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
436
-
437
- if isinstance(cross_attention_dim, int):
438
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
439
-
440
- if isinstance(layers_per_block, int):
441
- layers_per_block = [layers_per_block] * len(down_block_types)
442
-
443
- if isinstance(transformer_layers_per_block, int):
444
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
445
- if class_embeddings_concat:
446
- # The time embeddings are concatenated with the class embeddings. The dimension of the
447
- # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
448
- # regular time embeddings
449
- blocks_time_embed_dim = time_embed_dim * 2
450
- else:
451
- blocks_time_embed_dim = time_embed_dim
452
-
453
- # down
454
- output_channel = block_out_channels[0]
455
- for i, down_block_type in enumerate(down_block_types):
456
- input_channel = output_channel
457
- output_channel = block_out_channels[i]
458
- is_final_block = i == len(block_out_channels) - 1
459
-
460
- down_block = get_down_block(
461
- down_block_type,
462
- num_layers=layers_per_block[i],
463
- transformer_layers_per_block=transformer_layers_per_block[i],
464
- in_channels=input_channel,
465
- out_channels=output_channel,
466
- temb_channels=blocks_time_embed_dim,
467
- add_downsample=not is_final_block,
468
- resnet_eps=norm_eps,
469
- resnet_act_fn=act_fn,
470
- resnet_groups=norm_num_groups,
471
- cross_attention_dim=cross_attention_dim[i],
472
- num_attention_heads=num_attention_heads[i],
473
- downsample_padding=downsample_padding,
474
- dual_cross_attention=dual_cross_attention,
475
- use_linear_projection=use_linear_projection,
476
- only_cross_attention=only_cross_attention[i],
477
- upcast_attention=upcast_attention,
478
- resnet_time_scale_shift=resnet_time_scale_shift,
479
- attention_type=attention_type,
480
- resnet_skip_time_act=resnet_skip_time_act,
481
- resnet_out_scale_factor=resnet_out_scale_factor,
482
- cross_attention_norm=cross_attention_norm,
483
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
484
- dropout=dropout,
485
- )
486
- self.down_blocks.append(down_block)
487
-
488
- # mid
489
- if mid_block_type == "UNetMidBlock2DCrossAttn":
490
- self.mid_block = UNetMidBlock2DCrossAttn(
491
- transformer_layers_per_block=transformer_layers_per_block[-1],
492
- in_channels=block_out_channels[-1],
493
- temb_channels=blocks_time_embed_dim,
494
- dropout=dropout,
495
- resnet_eps=norm_eps,
496
- resnet_act_fn=act_fn,
497
- output_scale_factor=mid_block_scale_factor,
498
- resnet_time_scale_shift=resnet_time_scale_shift,
499
- cross_attention_dim=cross_attention_dim[-1],
500
- num_attention_heads=num_attention_heads[-1],
501
- resnet_groups=norm_num_groups,
502
- dual_cross_attention=dual_cross_attention,
503
- use_linear_projection=use_linear_projection,
504
- upcast_attention=upcast_attention,
505
- attention_type=attention_type,
506
- )
507
- elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
508
- self.mid_block = UNetMidBlock2DSimpleCrossAttn(
509
- in_channels=block_out_channels[-1],
510
- temb_channels=blocks_time_embed_dim,
511
- dropout=dropout,
512
- resnet_eps=norm_eps,
513
- resnet_act_fn=act_fn,
514
- output_scale_factor=mid_block_scale_factor,
515
- cross_attention_dim=cross_attention_dim[-1],
516
- attention_head_dim=attention_head_dim[-1],
517
- resnet_groups=norm_num_groups,
518
- resnet_time_scale_shift=resnet_time_scale_shift,
519
- skip_time_act=resnet_skip_time_act,
520
- only_cross_attention=mid_block_only_cross_attention,
521
- cross_attention_norm=cross_attention_norm,
522
- )
523
- elif mid_block_type == "UNetMidBlock2D":
524
- self.mid_block = UNetMidBlock2D(
525
- in_channels=block_out_channels[-1],
526
- temb_channels=blocks_time_embed_dim,
527
- dropout=dropout,
528
- num_layers=0,
529
- resnet_eps=norm_eps,
530
- resnet_act_fn=act_fn,
531
- output_scale_factor=mid_block_scale_factor,
532
- resnet_groups=norm_num_groups,
533
- resnet_time_scale_shift=resnet_time_scale_shift,
534
- add_attention=False,
535
- )
536
- elif mid_block_type is None:
537
- self.mid_block = None
538
- else:
539
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
540
-
541
- # count how many layers upsample the images
542
- self.num_upsamplers = 0
543
-
544
- # up
545
- reversed_block_out_channels = list(reversed(block_out_channels))
546
- reversed_num_attention_heads = list(reversed(num_attention_heads))
547
- reversed_layers_per_block = list(reversed(layers_per_block))
548
- reversed_cross_attention_dim = list(reversed(cross_attention_dim))
549
- reversed_transformer_layers_per_block = (
550
- list(reversed(transformer_layers_per_block))
551
- if reverse_transformer_layers_per_block is None
552
- else reverse_transformer_layers_per_block
553
- )
554
- only_cross_attention = list(reversed(only_cross_attention))
555
-
556
- output_channel = reversed_block_out_channels[0]
557
- for i, up_block_type in enumerate(up_block_types):
558
- is_final_block = i == len(block_out_channels) - 1
559
-
560
- prev_output_channel = output_channel
561
- output_channel = reversed_block_out_channels[i]
562
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
563
-
564
- # add upsample block for all BUT final layer
565
- if not is_final_block:
566
- add_upsample = True
567
- self.num_upsamplers += 1
568
- else:
569
- add_upsample = False
570
- up_block = get_up_block(
571
- up_block_type,
572
- num_layers=reversed_layers_per_block[i] + 1,
573
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
574
- in_channels=input_channel,
575
- out_channels=output_channel,
576
- prev_output_channel=prev_output_channel,
577
- temb_channels=blocks_time_embed_dim,
578
- add_upsample=add_upsample,
579
- resnet_eps=norm_eps,
580
- resnet_act_fn=act_fn,
581
- resolution_idx=i,
582
- resnet_groups=norm_num_groups,
583
- cross_attention_dim=reversed_cross_attention_dim[i],
584
- num_attention_heads=reversed_num_attention_heads[i],
585
- dual_cross_attention=dual_cross_attention,
586
- use_linear_projection=use_linear_projection,
587
- only_cross_attention=only_cross_attention[i],
588
- upcast_attention=upcast_attention,
589
- resnet_time_scale_shift=resnet_time_scale_shift,
590
- attention_type=attention_type,
591
- resnet_skip_time_act=resnet_skip_time_act,
592
- resnet_out_scale_factor=resnet_out_scale_factor,
593
- cross_attention_norm=cross_attention_norm,
594
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
595
- dropout=dropout,
596
- )
597
-
598
- self.up_blocks.append(up_block)
599
- prev_output_channel = output_channel
600
-
601
-
602
-
603
-
604
- # encode_output_chs = [
605
- # # 320,
606
- # # 320,
607
- # # 320,
608
- # 1280,
609
- # 1280,
610
- # 1280,
611
- # 1280,
612
- # 640,
613
- # 640
614
- # ]
615
-
616
- # encode_output_chs2 = [
617
- # # 320,
618
- # # 320,
619
- # # 320,
620
- # 1280,
621
- # 1280,
622
- # 640,
623
- # 640,
624
- # 640,
625
- # 320
626
- # ]
627
-
628
- # encode_num_head_chs3 = [
629
- # # 5,
630
- # # 5,
631
- # # 10,
632
- # 20,
633
- # 20,
634
- # 20,
635
- # 10,
636
- # 10,
637
- # 10
638
- # ]
639
-
640
-
641
- # encode_num_layers_chs4 = [
642
- # # 1,
643
- # # 1,
644
- # # 2,
645
- # 10,
646
- # 10,
647
- # 10,
648
- # 2,
649
- # 2,
650
- # 2
651
- # ]
652
-
653
-
654
- # self.warp_blks = nn.ModuleList([])
655
- # self.warp_zeros = nn.ModuleList([])
656
-
657
- # for in_ch, cont_ch,num_head,num_layers in zip(encode_output_chs, encode_output_chs2,encode_num_head_chs3,encode_num_layers_chs4):
658
- # # dim_head = in_ch // self.num_heads
659
- # # dim_head = dim_head // dim_head_denorm
660
-
661
- # self.warp_blks.append(Transformer2DModel(
662
- # num_attention_heads=num_head,
663
- # attention_head_dim=64,
664
- # in_channels=in_ch,
665
- # num_layers = num_layers,
666
- # cross_attention_dim = cont_ch,
667
- # ))
668
-
669
- # self.warp_zeros.append(zero_module(nn.Conv2d(in_ch, in_ch, 1, padding=0)))
670
-
671
-
672
-
673
- # out
674
- if norm_num_groups is not None:
675
- self.conv_norm_out = nn.GroupNorm(
676
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
677
- )
678
-
679
- self.conv_act = get_activation(act_fn)
680
-
681
- else:
682
- self.conv_norm_out = None
683
- self.conv_act = None
684
-
685
- conv_out_padding = (conv_out_kernel - 1) // 2
686
- self.conv_out = nn.Conv2d(
687
- block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
688
- )
689
-
690
- if attention_type in ["gated", "gated-text-image"]:
691
- positive_len = 768
692
- if isinstance(cross_attention_dim, int):
693
- positive_len = cross_attention_dim
694
- elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
695
- positive_len = cross_attention_dim[0]
696
-
697
- feature_type = "text-only" if attention_type == "gated" else "text-image"
698
- self.position_net = PositionNet(
699
- positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
700
- )
701
-
702
-
703
-
704
-
705
- @property
706
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
707
- r"""
708
- Returns:
709
- `dict` of attention processors: A dictionary containing all attention processors used in the model,
710
- indexed by its weight name.
711
- """
712
- # set recursively
713
- processors = {}
714
-
715
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
716
- if hasattr(module, "get_processor"):
717
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
718
-
719
- for sub_name, child in module.named_children():
720
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
721
-
722
- return processors
723
-
724
- for name, module in self.named_children():
725
- fn_recursive_add_processors(name, module, processors)
726
-
727
- return processors
728
-
729
- def set_attn_processor(
730
- self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
731
- ):
732
- r"""
733
- Sets the attention processor to use to compute attention.
734
-
735
- Parameters:
736
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
737
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
738
- for **all** `Attention` layers.
739
-
740
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
741
- processor. This is strongly recommended when setting trainable attention processors.
742
-
743
- """
744
- count = len(self.attn_processors.keys())
745
-
746
- if isinstance(processor, dict) and len(processor) != count:
747
- raise ValueError(
748
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
749
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
750
- )
751
-
752
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
753
- if hasattr(module, "set_processor"):
754
- if not isinstance(processor, dict):
755
- module.set_processor(processor, _remove_lora=_remove_lora)
756
- else:
757
- module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
758
-
759
- for sub_name, child in module.named_children():
760
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
761
-
762
- for name, module in self.named_children():
763
- fn_recursive_attn_processor(name, module, processor)
764
-
765
- def set_default_attn_processor(self):
766
- """
767
- Disables custom attention processors and sets the default attention implementation.
768
- """
769
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
770
- processor = AttnAddedKVProcessor()
771
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
772
- processor = AttnProcessor()
773
- else:
774
- raise ValueError(
775
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
776
- )
777
-
778
- self.set_attn_processor(processor, _remove_lora=True)
779
-
780
- def set_attention_slice(self, slice_size):
781
- r"""
782
- Enable sliced attention computation.
783
-
784
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
785
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
786
-
787
- Args:
788
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
789
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
790
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
791
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
792
- must be a multiple of `slice_size`.
793
- """
794
- sliceable_head_dims = []
795
-
796
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
797
- if hasattr(module, "set_attention_slice"):
798
- sliceable_head_dims.append(module.sliceable_head_dim)
799
-
800
- for child in module.children():
801
- fn_recursive_retrieve_sliceable_dims(child)
802
-
803
- # retrieve number of attention layers
804
- for module in self.children():
805
- fn_recursive_retrieve_sliceable_dims(module)
806
-
807
- num_sliceable_layers = len(sliceable_head_dims)
808
-
809
- if slice_size == "auto":
810
- # half the attention head size is usually a good trade-off between
811
- # speed and memory
812
- slice_size = [dim // 2 for dim in sliceable_head_dims]
813
- elif slice_size == "max":
814
- # make smallest slice possible
815
- slice_size = num_sliceable_layers * [1]
816
-
817
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
818
-
819
- if len(slice_size) != len(sliceable_head_dims):
820
- raise ValueError(
821
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
822
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
823
- )
824
-
825
- for i in range(len(slice_size)):
826
- size = slice_size[i]
827
- dim = sliceable_head_dims[i]
828
- if size is not None and size > dim:
829
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
830
-
831
- # Recursively walk through all the children.
832
- # Any children which exposes the set_attention_slice method
833
- # gets the message
834
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
835
- if hasattr(module, "set_attention_slice"):
836
- module.set_attention_slice(slice_size.pop())
837
-
838
- for child in module.children():
839
- fn_recursive_set_attention_slice(child, slice_size)
840
-
841
- reversed_slice_size = list(reversed(slice_size))
842
- for module in self.children():
843
- fn_recursive_set_attention_slice(module, reversed_slice_size)
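Editorial usage sketch, assuming a constructed `unet`: the three accepted forms of `slice_size` described in the docstring above.

unet.set_attention_slice("auto")  # halve each sliceable head dimension, attention runs in two steps
unet.set_attention_slice("max")   # one slice at a time, maximum memory saving
unet.set_attention_slice(2)       # explicit slice size; must not exceed any sliceable head dimension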
844
-
845
- def _set_gradient_checkpointing(self, module, value=False):
846
- if hasattr(module, "gradient_checkpointing"):
847
- module.gradient_checkpointing = value
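Editorial note: this hook is what the generic `ModelMixin` helpers call on every child block that exposes a `gradient_checkpointing` flag, so checkpointing can be toggled from the outside, for example:

unet.enable_gradient_checkpointing()   # recompute activations in the backward pass to save memory
unet.disable_gradient_checkpointing()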
848
-
849
- def enable_freeu(self, s1, s2, b1, b2):
850
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
851
-
852
- The suffixes after the scaling factors represent the stage blocks where they are being applied.
853
-
854
- Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
855
- are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
856
-
857
- Args:
858
- s1 (`float`):
859
- Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
860
- mitigate the "oversmoothing effect" in the enhanced denoising process.
861
- s2 (`float`):
862
- Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
863
- mitigate the "oversmoothing effect" in the enhanced denoising process.
864
- b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
865
- b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
866
- """
867
- for i, upsample_block in enumerate(self.up_blocks):
868
- setattr(upsample_block, "s1", s1)
869
- setattr(upsample_block, "s2", s2)
870
- setattr(upsample_block, "b1", b1)
871
- setattr(upsample_block, "b2", b2)
872
-
873
- def disable_freeu(self):
874
- """Disables the FreeU mechanism."""
875
- freeu_keys = {"s1", "s2", "b1", "b2"}
876
- for i, upsample_block in enumerate(self.up_blocks):
877
- for k in freeu_keys:
878
- if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
879
- setattr(upsample_block, k, None)
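Editorial usage sketch, assuming a constructed `unet`. The scaling factors below are the values often quoted for Stable Diffusion v1.5 in the FreeU repository; they are an assumption here, not something this file prescribes.

unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)  # attenuate skip features, amplify backbone features
# ... run denoising with FreeU active ...
unet.disable_freeu()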
880
-
881
- def fuse_qkv_projections(self):
882
- """
883
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
884
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
885
-
886
- <Tip warning={true}>
887
-
888
- This API is πŸ§ͺ experimental.
889
-
890
- </Tip>
891
- """
892
- self.original_attn_processors = None
893
-
894
- for _, attn_processor in self.attn_processors.items():
895
- if "Added" in str(attn_processor.__class__.__name__):
896
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
897
-
898
- self.original_attn_processors = self.attn_processors
899
-
900
- for module in self.modules():
901
- if isinstance(module, Attention):
902
- module.fuse_projections(fuse=True)
903
-
904
- def unfuse_qkv_projections(self):
905
- """Disables the fused QKV projection if enabled.
906
-
907
- <Tip warning={true}>
908
-
909
- This API is πŸ§ͺ experimental.
910
-
911
- </Tip>
912
-
913
- """
914
- if self.original_attn_processors is not None:
915
- self.set_attn_processor(self.original_attn_processors)
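Editorial usage sketch, assuming a constructed `unet`: fused projections are an inference-time optimization and are typically undone before swapping attention processors again.

unet.fuse_qkv_projections()
# ... run inference with fused query/key/value projections ...
unet.unfuse_qkv_projections()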
916
-
917
- def forward(
918
- self,
919
- sample: torch.FloatTensor,
920
- timestep: Union[torch.Tensor, float, int],
921
- encoder_hidden_states: torch.Tensor,
922
- class_labels: Optional[torch.Tensor] = None,
923
- timestep_cond: Optional[torch.Tensor] = None,
924
- attention_mask: Optional[torch.Tensor] = None,
925
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
926
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
927
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
928
- mid_block_additional_residual: Optional[torch.Tensor] = None,
929
- down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
930
- encoder_attention_mask: Optional[torch.Tensor] = None,
931
- return_dict: bool = True,
932
- ) -> Union[UNet2DConditionOutput, Tuple]:
933
- r"""
934
- The [`UNet2DConditionModel`] forward method.
935
-
936
- Args:
937
- sample (`torch.FloatTensor`):
938
- The noisy input tensor with the following shape `(batch, channel, height, width)`.
939
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
940
- encoder_hidden_states (`torch.FloatTensor`):
941
- The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
942
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
943
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
944
- timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
945
- Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
946
- through the `self.time_embedding` layer to obtain the timestep embeddings.
947
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
948
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
949
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
950
- negative values to the attention scores corresponding to "discard" tokens.
951
- cross_attention_kwargs (`dict`, *optional*):
952
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
953
- `self.processor` in
954
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
955
- added_cond_kwargs: (`dict`, *optional*):
956
- A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
957
- are passed along to the UNet blocks.
958
- down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
959
- A tuple of tensors that if specified are added to the residuals of down unet blocks.
960
- mid_block_additional_residual: (`torch.Tensor`, *optional*):
961
- A tensor that if specified is added to the residual of the middle unet block.
962
- encoder_attention_mask (`torch.Tensor`):
963
- A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
964
- `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
965
- which adds large negative values to the attention scores corresponding to "discard" tokens.
966
- return_dict (`bool`, *optional*, defaults to `True`):
967
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
968
- tuple.
969
- cross_attention_kwargs (`dict`, *optional*):
970
- A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
971
- added_cond_kwargs: (`dict`, *optional*):
972
- A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
973
- are passed along to the UNet blocks.
974
- down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
975
- additional residuals to be added to UNet long skip connections from down blocks to up blocks for
976
- example from ControlNet side model(s)
977
- mid_block_additional_residual (`torch.Tensor`, *optional*):
978
- additional residual to be added to UNet mid block output, for example from ControlNet side model
979
- down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
980
- additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
981
-
982
- Returns:
983
- [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
984
- If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
985
- a `tuple` is returned where the first element is the sample tensor.
986
- """
987
- # By default samples have to be at least a multiple of the overall upsampling factor.
988
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
989
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
990
- # on the fly if necessary.
991
- default_overall_up_factor = 2**self.num_upsamplers
992
-
993
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
994
- forward_upsample_size = False
995
- upsample_size = None
996
-
997
- for dim in sample.shape[-2:]:
998
- if dim % default_overall_up_factor != 0:
999
- # Forward upsample size to force interpolation output size.
1000
- forward_upsample_size = True
1001
- break
1002
-
1003
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1004
- # expects mask of shape:
1005
- # [batch, key_tokens]
1006
- # adds singleton query_tokens dimension:
1007
- # [batch, 1, key_tokens]
1008
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1009
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1010
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1011
- if attention_mask is not None:
1012
- # assume that mask is expressed as:
1013
- # (1 = keep, 0 = discard)
1014
- # convert mask into a bias that can be added to attention scores:
1015
- # (keep = +0, discard = -10000.0)
1016
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1017
- attention_mask = attention_mask.unsqueeze(1)
1018
-
1019
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
1020
- if encoder_attention_mask is not None:
1021
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1022
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1023
-
1024
- # 0. center input if necessary
1025
- if self.config.center_input_sample:
1026
- sample = 2 * sample - 1.0
1027
-
1028
- # 1. time
1029
- timesteps = timestep
1030
- if not torch.is_tensor(timesteps):
1031
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1032
- # This would be a good case for the `match` statement (Python 3.10+)
1033
- is_mps = sample.device.type == "mps"
1034
- if isinstance(timestep, float):
1035
- dtype = torch.float32 if is_mps else torch.float64
1036
- else:
1037
- dtype = torch.int32 if is_mps else torch.int64
1038
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1039
- elif len(timesteps.shape) == 0:
1040
- timesteps = timesteps[None].to(sample.device)
1041
-
1042
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1043
- timesteps = timesteps.expand(sample.shape[0])
1044
-
1045
- t_emb = self.time_proj(timesteps)
1046
-
1047
- # `Timesteps` does not contain any weights and will always return f32 tensors
1048
- # but time_embedding might actually be running in fp16. so we need to cast here.
1049
- # there might be better ways to encapsulate this.
1050
- t_emb = t_emb.to(dtype=sample.dtype)
1051
-
1052
- emb = self.time_embedding(t_emb, timestep_cond)
1053
- aug_emb = None
1054
-
1055
- if self.class_embedding is not None:
1056
- if class_labels is None:
1057
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
1058
-
1059
- if self.config.class_embed_type == "timestep":
1060
- class_labels = self.time_proj(class_labels)
1061
-
1062
- # `Timesteps` does not contain any weights and will always return f32 tensors
1063
- # there might be better ways to encapsulate this.
1064
- class_labels = class_labels.to(dtype=sample.dtype)
1065
-
1066
- class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1067
-
1068
- if self.config.class_embeddings_concat:
1069
- emb = torch.cat([emb, class_emb], dim=-1)
1070
- else:
1071
- emb = emb + class_emb
1072
-
1073
- if self.config.addition_embed_type == "text":
1074
- aug_emb = self.add_embedding(encoder_hidden_states)
1075
- elif self.config.addition_embed_type == "text_image":
1076
- # Kandinsky 2.1 - style
1077
- if "image_embeds" not in added_cond_kwargs:
1078
- raise ValueError(
1079
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1080
- )
1081
-
1082
- image_embs = added_cond_kwargs.get("image_embeds")
1083
- text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1084
- aug_emb = self.add_embedding(text_embs, image_embs)
1085
- elif self.config.addition_embed_type == "text_time":
1086
- # SDXL - style
1087
- if "text_embeds" not in added_cond_kwargs:
1088
- raise ValueError(
1089
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1090
- )
1091
- text_embeds = added_cond_kwargs.get("text_embeds")
1092
- if "time_ids" not in added_cond_kwargs:
1093
- raise ValueError(
1094
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1095
- )
1096
- time_ids = added_cond_kwargs.get("time_ids")
1097
- time_embeds = self.add_time_proj(time_ids.flatten())
1098
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1099
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1100
- add_embeds = add_embeds.to(emb.dtype)
1101
- aug_emb = self.add_embedding(add_embeds)
1102
- elif self.config.addition_embed_type == "image":
1103
- # Kandinsky 2.2 - style
1104
- if "image_embeds" not in added_cond_kwargs:
1105
- raise ValueError(
1106
- f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1107
- )
1108
- image_embs = added_cond_kwargs.get("image_embeds")
1109
- aug_emb = self.add_embedding(image_embs)
1110
- elif self.config.addition_embed_type == "image_hint":
1111
- # Kandinsky 2.2 - style
1112
- if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1113
- raise ValueError(
1114
- f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1115
- )
1116
- image_embs = added_cond_kwargs.get("image_embeds")
1117
- hint = added_cond_kwargs.get("hint")
1118
- aug_emb, hint = self.add_embedding(image_embs, hint)
1119
- sample = torch.cat([sample, hint], dim=1)
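Editorial sketch of the `added_cond_kwargs` that the SDXL-style "text_time" branch above expects. `unet`, `noisy_latents`, `timestep` and `text_embeds` are assumed to exist, and all shapes are illustrative; `time_ids` packs original size, crop coordinates and target size.

added_cond_kwargs = {
    "text_embeds": torch.randn(1, 1280),  # pooled text embedding (illustrative size)
    "time_ids": torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.float32),
}
out, garment_features = unet(
    noisy_latents, timestep, encoder_hidden_states=text_embeds, added_cond_kwargs=added_cond_kwargs
)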
1120
-
1121
- emb = emb + aug_emb if aug_emb is not None else emb
1122
-
1123
- if self.time_embed_act is not None:
1124
- emb = self.time_embed_act(emb)
1125
-
1126
- if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1127
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1128
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1129
- # Kandinsky 2.1 - style
1130
- if "image_embeds" not in added_cond_kwargs:
1131
- raise ValueError(
1132
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1133
- )
1134
-
1135
- image_embeds = added_cond_kwargs.get("image_embeds")
1136
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1137
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1138
- # Kandinsky 2.2 - style
1139
- if "image_embeds" not in added_cond_kwargs:
1140
- raise ValueError(
1141
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1142
- )
1143
- image_embeds = added_cond_kwargs.get("image_embeds")
1144
- encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1145
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1146
- if "image_embeds" not in added_cond_kwargs:
1147
- raise ValueError(
1148
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1149
- )
1150
- image_embeds = added_cond_kwargs.get("image_embeds")
1151
- image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
1152
- encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
1153
-
1154
- # 2. pre-process
1155
- sample = self.conv_in(sample)
1156
- garment_features = []
1157
-
1158
- # 2.5 GLIGEN position net
1159
- if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1160
- cross_attention_kwargs = cross_attention_kwargs.copy()
1161
- gligen_args = cross_attention_kwargs.pop("gligen")
1162
- cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1163
-
1164
-
1165
- # 3. down
1166
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1167
- if USE_PEFT_BACKEND:
1168
- # weight the lora layers by setting `lora_scale` for each PEFT layer
1169
- scale_lora_layers(self, lora_scale)
1170
-
1171
- is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1172
- # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1173
- is_adapter = down_intrablock_additional_residuals is not None
1174
- # maintain backward compatibility for legacy usage, where
1175
- # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1176
- # but can only use one or the other
1177
- if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1178
- deprecate(
1179
- "T2I should not use down_block_additional_residuals",
1180
- "1.3.0",
1181
- "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1182
- and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1183
- for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1184
- standard_warn=False,
1185
- )
1186
- down_intrablock_additional_residuals = down_block_additional_residuals
1187
- is_adapter = True
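Editorial sketch of which keyword carries which side-model output, assuming `unet`, a `controlnet` and the usual latent/timestep/text tensors already exist; only the argument wiring is the point here.

down_res, mid_res = controlnet(
    noisy_latents, timestep, encoder_hidden_states=text_embeds,
    controlnet_cond=condition_image, return_dict=False,
)
out, garment_features = unet(
    noisy_latents, timestep, encoder_hidden_states=text_embeds,
    down_block_additional_residuals=down_res,  # ControlNet residuals for the long skip connections
    mid_block_additional_residual=mid_res,     # ControlNet mid-block residual
)
# A T2I-Adapter would instead pass its per-resolution features via
# `down_intrablock_additional_residuals=...`.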
1188
-
1189
- down_block_res_samples = (sample,)
1190
- for downsample_block in self.down_blocks:
1191
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1192
- # For t2i-adapter CrossAttnDownBlock2D
1193
- additional_residuals = {}
1194
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
1195
- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1196
-
1197
- sample, res_samples, out_garment_feat = downsample_block(
1198
- hidden_states=sample,
1199
- temb=emb,
1200
- encoder_hidden_states=encoder_hidden_states,
1201
- attention_mask=attention_mask,
1202
- cross_attention_kwargs=cross_attention_kwargs,
1203
- encoder_attention_mask=encoder_attention_mask,
1204
- **additional_residuals,
1205
- )
1206
- garment_features += out_garment_feat
1207
- else:
1208
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
1209
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
1210
- sample += down_intrablock_additional_residuals.pop(0)
1211
-
1212
- down_block_res_samples += res_samples
1213
-
1214
-
1215
- if is_controlnet:
1216
- new_down_block_res_samples = ()
1217
-
1218
- for down_block_res_sample, down_block_additional_residual in zip(
1219
- down_block_res_samples, down_block_additional_residuals
1220
- ):
1221
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
1222
- new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1223
-
1224
- down_block_res_samples = new_down_block_res_samples
1225
-
1226
- # 4. mid
1227
- if self.mid_block is not None:
1228
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1229
- sample, out_garment_feat = self.mid_block(
1230
- sample,
1231
- emb,
1232
- encoder_hidden_states=encoder_hidden_states,
1233
- attention_mask=attention_mask,
1234
- cross_attention_kwargs=cross_attention_kwargs,
1235
- encoder_attention_mask=encoder_attention_mask,
1236
- )
1237
- garment_features += out_garment_feat
1238
-
1239
- else:
1240
- sample = self.mid_block(sample, emb)
1241
-
1242
- # To support T2I-Adapter-XL
1243
- if (
1244
- is_adapter
1245
- and len(down_intrablock_additional_residuals) > 0
1246
- and sample.shape == down_intrablock_additional_residuals[0].shape
1247
- ):
1248
- sample += down_intrablock_additional_residuals.pop(0)
1249
-
1250
- if is_controlnet:
1251
- sample = sample + mid_block_additional_residual
1252
-
1253
-
1254
-
1255
- # 5. up
1256
- for i, upsample_block in enumerate(self.up_blocks):
1257
- is_final_block = i == len(self.up_blocks) - 1
1258
-
1259
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1260
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1261
-
1262
- # if we have not reached the final block and need to forward the
1263
- # upsample size, we do it here
1264
- if not is_final_block and forward_upsample_size:
1265
- upsample_size = down_block_res_samples[-1].shape[2:]
1266
-
1267
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1268
- sample, out_garment_feat = upsample_block(
1269
- hidden_states=sample,
1270
- temb=emb,
1271
- res_hidden_states_tuple=res_samples,
1272
- encoder_hidden_states=encoder_hidden_states,
1273
- cross_attention_kwargs=cross_attention_kwargs,
1274
- upsample_size=upsample_size,
1275
- attention_mask=attention_mask,
1276
- encoder_attention_mask=encoder_attention_mask,
1277
- )
1278
- garment_features += out_garment_feat
1279
-
1280
-
1281
- if not return_dict:
1282
- return (sample,), garment_features
1283
-
1284
- return UNet2DConditionOutput(sample=sample), garment_features
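Editorial usage sketch for the modified forward pass above: unlike the stock diffusers UNet, it also returns the garment feature maps collected from the cross-attention down, mid and up blocks. `unet` is assumed to be a constructed instance whose `cross_attention_dim` is 1280 (the usual default here); all shapes are illustrative.

noisy_latents = torch.randn(1, 4, 64, 64)  # (batch, channels, height, width)
timestep = torch.tensor([10])
text_embeds = torch.randn(1, 77, 1280)     # encoder hidden states

out, garment_features = unet(noisy_latents, timestep, encoder_hidden_states=text_embeds)
sample = out.sample                        # UNet2DConditionOutput when return_dict=True (the default)
print(len(garment_features), "garment feature maps collected")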
 
src/unet_hacked_tryon.py DELETED
@@ -1,1395 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, List, Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.nn as nn
19
- import torch.utils.checkpoint
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.loaders import UNet2DConditionLoadersMixin
23
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
24
- from diffusers.models.activations import get_activation
25
- from diffusers.models.attention_processor import (
26
- ADDED_KV_ATTENTION_PROCESSORS,
27
- CROSS_ATTENTION_PROCESSORS,
28
- Attention,
29
- AttentionProcessor,
30
- AttnAddedKVProcessor,
31
- AttnProcessor,
32
- )
33
- from einops import rearrange
34
-
35
- from diffusers.models.embeddings import (
36
- GaussianFourierProjection,
37
- ImageHintTimeEmbedding,
38
- ImageProjection,
39
- ImageTimeEmbedding,
40
- PositionNet,
41
- TextImageProjection,
42
- TextImageTimeEmbedding,
43
- TextTimeEmbedding,
44
- TimestepEmbedding,
45
- Timesteps,
46
- )
47
-
48
-
49
- from diffusers.models.modeling_utils import ModelMixin
50
- from src.unet_block_hacked_tryon import (
51
- UNetMidBlock2D,
52
- UNetMidBlock2DCrossAttn,
53
- UNetMidBlock2DSimpleCrossAttn,
54
- get_down_block,
55
- get_up_block,
56
- )
57
- from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
58
- from diffusers.models.transformer_2d import Transformer2DModel
59
- import math
60
-
61
- from ip_adapter.ip_adapter import Resampler
62
-
63
-
64
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
-
66
-
67
- # def FeedForward(dim, mult=4):
68
- # inner_dim = int(dim * mult)
69
- # return nn.Sequential(
70
- # nn.LayerNorm(dim),
71
- # nn.Linear(dim, inner_dim, bias=False),
72
- # nn.GELU(),
73
- # nn.Linear(inner_dim, dim, bias=False),
74
- # )
75
-
76
-
77
-
78
- # def reshape_tensor(x, heads):
79
- # bs, length, width = x.shape
80
- # # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
81
- # x = x.view(bs, length, heads, -1)
82
- # # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
83
- # x = x.transpose(1, 2)
84
- # # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
85
- # x = x.reshape(bs, heads, length, -1)
86
- # return x
87
-
88
-
89
- # class PerceiverAttention(nn.Module):
90
- # def __init__(self, *, dim, dim_head=64, heads=8):
91
- # super().__init__()
92
- # self.scale = dim_head**-0.5
93
- # self.dim_head = dim_head
94
- # self.heads = heads
95
- # inner_dim = dim_head * heads
96
-
97
- # self.norm1 = nn.LayerNorm(dim)
98
- # self.norm2 = nn.LayerNorm(dim)
99
-
100
- # self.to_q = nn.Linear(dim, inner_dim, bias=False)
101
- # self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
102
- # self.to_out = nn.Linear(inner_dim, dim, bias=False)
103
-
104
- # def forward(self, x, latents):
105
- # """
106
- # Args:
107
- # x (torch.Tensor): image features
108
- # shape (b, n1, D)
109
- # latent (torch.Tensor): latent features
110
- # shape (b, n2, D)
111
- # """
112
- # x = self.norm1(x)
113
- # latents = self.norm2(latents)
114
-
115
- # b, l, _ = latents.shape
116
-
117
- # q = self.to_q(latents)
118
- # kv_input = torch.cat((x, latents), dim=-2)
119
- # k, v = self.to_kv(kv_input).chunk(2, dim=-1)
120
-
121
- # q = reshape_tensor(q, self.heads)
122
- # k = reshape_tensor(k, self.heads)
123
- # v = reshape_tensor(v, self.heads)
124
-
125
- # # attention
126
- # scale = 1 / math.sqrt(math.sqrt(self.dim_head))
127
- # weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
128
- # weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
129
- # out = weight @ v
130
-
131
- # out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
132
-
133
- # return self.to_out(out)
134
-
135
-
136
- # class Resampler(nn.Module):
137
- # def __init__(
138
- # self,
139
- # dim=1024,
140
- # depth=8,
141
- # dim_head=64,
142
- # heads=16,
143
- # num_queries=8,
144
- # embedding_dim=768,
145
- # output_dim=1024,
146
- # ff_mult=4,
147
- # max_seq_len: int = 257, # CLIP tokens + CLS token
148
- # apply_pos_emb: bool = False,
149
- # num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
150
- # ):
151
- # super().__init__()
152
-
153
- # self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
154
-
155
- # self.proj_in = nn.Linear(embedding_dim, dim)
156
-
157
- # self.proj_out = nn.Linear(dim, output_dim)
158
- # self.norm_out = nn.LayerNorm(output_dim)
159
-
160
- # self.layers = nn.ModuleList([])
161
- # for _ in range(depth):
162
- # self.layers.append(
163
- # nn.ModuleList(
164
- # [
165
- # PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
166
- # FeedForward(dim=dim, mult=ff_mult),
167
- # ]
168
- # )
169
- # )
170
-
171
- # def forward(self, x):
172
-
173
- # latents = self.latents.repeat(x.size(0), 1, 1)
174
-
175
- # x = self.proj_in(x)
176
-
177
-
178
- # for attn, ff in self.layers:
179
- # latents = attn(x, latents) + latents
180
- # latents = ff(latents) + latents
181
-
182
- # latents = self.proj_out(latents)
183
- # return self.norm_out(latents)
184
-
185
-
186
- def zero_module(module):
187
- for p in module.parameters():
188
- nn.init.zeros_(p)
189
- return module
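Editorial note: this is the familiar ControlNet-style zero initialization, used so that a newly added projection starts out contributing nothing (the commented-out `warp_zeros` modules elsewhere in this commit use it the same way). A minimal sketch:

proj = zero_module(nn.Conv2d(320, 320, kernel_size=1, padding=0))
assert all(torch.all(p == 0).item() for p in proj.parameters())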
190
-
191
- @dataclass
192
- class UNet2DConditionOutput(BaseOutput):
193
- """
194
- The output of [`UNet2DConditionModel`].
195
-
196
- Args:
197
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
198
- The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
199
- """
200
-
201
- sample: torch.FloatTensor = None
202
-
203
-
204
- class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
205
- r"""
206
- A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
207
- shaped output.
208
-
209
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
210
- for all models (such as downloading or saving).
211
-
212
- Parameters:
213
- sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
214
- Height and width of input/output sample.
215
- in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
216
- out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
217
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
218
- flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
219
- Whether to flip the sin to cos in the time embedding.
220
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
221
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
222
- The tuple of downsample blocks to use.
223
- mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
224
- Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
225
- `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
226
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
227
- The tuple of upsample blocks to use.
228
- only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
229
- Whether to include self-attention in the basic transformer blocks, see
230
- [`~models.attention.BasicTransformerBlock`].
231
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
232
- The tuple of output channels for each block.
233
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
234
- downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
235
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
236
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
237
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
238
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
239
- If `None`, normalization and activation layers is skipped in post-processing.
240
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
241
- cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
242
- The dimension of the cross attention features.
243
- transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
244
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
245
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
246
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
247
- reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
248
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
249
- blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
250
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
251
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
252
- encoder_hid_dim (`int`, *optional*, defaults to None):
253
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
254
- dimension to `cross_attention_dim`.
255
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
256
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
257
- embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
258
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
259
- num_attention_heads (`int`, *optional*):
260
- The number of attention heads. If not defined, defaults to `attention_head_dim`
261
- resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
262
- for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
263
- class_embed_type (`str`, *optional*, defaults to `None`):
264
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
265
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
266
- addition_embed_type (`str`, *optional*, defaults to `None`):
267
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
268
- "text". "text" will use the `TextTimeEmbedding` layer.
269
- addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
270
- Dimension for the timestep embeddings.
271
- num_class_embeds (`int`, *optional*, defaults to `None`):
272
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
273
- class conditioning with `class_embed_type` equal to `None`.
274
- time_embedding_type (`str`, *optional*, defaults to `positional`):
275
- The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
276
- time_embedding_dim (`int`, *optional*, defaults to `None`):
277
- An optional override for the dimension of the projected time embedding.
278
- time_embedding_act_fn (`str`, *optional*, defaults to `None`):
279
- Optional activation function to use only once on the time embeddings before they are passed to the rest of
280
- the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
281
- timestep_post_act (`str`, *optional*, defaults to `None`):
282
- The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
283
- time_cond_proj_dim (`int`, *optional*, defaults to `None`):
284
- The dimension of `cond_proj` layer in the timestep embedding.
285
- conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
286
- conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
287
- projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
288
- `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
289
- class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
290
- embeddings with the class embeddings.
291
- mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
292
- Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
293
- `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
294
- `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
295
- otherwise.
296
- """
297
-
298
- _supports_gradient_checkpointing = True
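Editorial sketch of one plausible way to construct this hacked try-on UNet with pretrained weights; the checkpoint id and subfolder are hypothetical placeholders, and `torch` is already imported at the top of the file.

from src.unet_hacked_tryon import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "path-or-repo-id-of-a-tryon-checkpoint",  # hypothetical
    subfolder="unet",                         # hypothetical
    torch_dtype=torch.float16,
)
unet.requires_grad_(False)                    # typical for inference-only use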
299
-
300
- @register_to_config
301
- def __init__(
302
- self,
303
- sample_size: Optional[int] = None,
304
- in_channels: int = 4,
305
- out_channels: int = 4,
306
- center_input_sample: bool = False,
307
- flip_sin_to_cos: bool = True,
308
- freq_shift: int = 0,
309
- down_block_types: Tuple[str] = (
310
- "CrossAttnDownBlock2D",
311
- "CrossAttnDownBlock2D",
312
- "CrossAttnDownBlock2D",
313
- "DownBlock2D",
314
- ),
315
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
316
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
317
- only_cross_attention: Union[bool, Tuple[bool]] = False,
318
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
319
- layers_per_block: Union[int, Tuple[int]] = 2,
320
- downsample_padding: int = 1,
321
- mid_block_scale_factor: float = 1,
322
- dropout: float = 0.0,
323
- act_fn: str = "silu",
324
- norm_num_groups: Optional[int] = 32,
325
- norm_eps: float = 1e-5,
326
- cross_attention_dim: Union[int, Tuple[int]] = 1280,
327
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
328
- reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
329
- encoder_hid_dim: Optional[int] = None,
330
- encoder_hid_dim_type: Optional[str] = None,
331
- attention_head_dim: Union[int, Tuple[int]] = 8,
332
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
333
- dual_cross_attention: bool = False,
334
- use_linear_projection: bool = False,
335
- class_embed_type: Optional[str] = None,
336
- addition_embed_type: Optional[str] = None,
337
- addition_time_embed_dim: Optional[int] = None,
338
- num_class_embeds: Optional[int] = None,
339
- upcast_attention: bool = False,
340
- resnet_time_scale_shift: str = "default",
341
- resnet_skip_time_act: bool = False,
342
- resnet_out_scale_factor: int = 1.0,
343
- time_embedding_type: str = "positional",
344
- time_embedding_dim: Optional[int] = None,
345
- time_embedding_act_fn: Optional[str] = None,
346
- timestep_post_act: Optional[str] = None,
347
- time_cond_proj_dim: Optional[int] = None,
348
- conv_in_kernel: int = 3,
349
- conv_out_kernel: int = 3,
350
- projection_class_embeddings_input_dim: Optional[int] = None,
351
- attention_type: str = "default",
352
- class_embeddings_concat: bool = False,
353
- mid_block_only_cross_attention: Optional[bool] = None,
354
- cross_attention_norm: Optional[str] = None,
355
- addition_embed_type_num_heads=64,
356
- ):
357
- super().__init__()
358
-
359
- self.sample_size = sample_size
360
-
361
- if num_attention_heads is not None:
362
- raise ValueError(
363
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
364
- )
365
-
366
- # If `num_attention_heads` is not defined (which is the case for most models)
367
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
368
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
369
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
370
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
371
- # which is why we correct for the naming here.
372
- num_attention_heads = num_attention_heads or attention_head_dim
373
-
374
- # Check inputs
375
- if len(down_block_types) != len(up_block_types):
376
- raise ValueError(
377
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
378
- )
379
-
380
- if len(block_out_channels) != len(down_block_types):
381
- raise ValueError(
382
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
383
- )
384
-
385
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
386
- raise ValueError(
387
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
388
- )
389
-
390
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
391
- raise ValueError(
392
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
393
- )
394
-
395
- if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
396
- raise ValueError(
397
- f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
398
- )
399
-
400
- if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
401
- raise ValueError(
402
- f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
403
- )
404
-
405
- if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
406
- raise ValueError(
407
- f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
408
- )
409
- if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
410
- for layer_number_per_block in transformer_layers_per_block:
411
- if isinstance(layer_number_per_block, list):
412
- raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
413
-
414
- # input
415
- conv_in_padding = (conv_in_kernel - 1) // 2
416
- self.conv_in = nn.Conv2d(
417
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
418
- )
419
-
420
- # time
421
- if time_embedding_type == "fourier":
422
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
423
- if time_embed_dim % 2 != 0:
424
- raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
425
- self.time_proj = GaussianFourierProjection(
426
- time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
427
- )
428
- timestep_input_dim = time_embed_dim
429
- elif time_embedding_type == "positional":
430
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
431
-
432
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
433
- timestep_input_dim = block_out_channels[0]
434
- else:
435
- raise ValueError(
436
- f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
437
- )
438
-
439
- self.time_embedding = TimestepEmbedding(
440
- timestep_input_dim,
441
- time_embed_dim,
442
- act_fn=act_fn,
443
- post_act_fn=timestep_post_act,
444
- cond_proj_dim=time_cond_proj_dim,
445
- )
446
-
447
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
448
- encoder_hid_dim_type = "text_proj"
449
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
450
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
451
-
452
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
453
- raise ValueError(
454
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
455
- )
456
-
457
- if encoder_hid_dim_type == "text_proj":
458
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
459
- elif encoder_hid_dim_type == "text_image_proj":
460
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
461
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
462
- # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
463
- self.encoder_hid_proj = TextImageProjection(
464
- text_embed_dim=encoder_hid_dim,
465
- image_embed_dim=cross_attention_dim,
466
- cross_attention_dim=cross_attention_dim,
467
- )
468
- elif encoder_hid_dim_type == "image_proj":
469
- # Kandinsky 2.2
470
- self.encoder_hid_proj = ImageProjection(
471
- image_embed_dim=encoder_hid_dim,
472
- cross_attention_dim=cross_attention_dim,
473
- )
474
- elif encoder_hid_dim_type == "ip_image_proj":
475
- # Kandinsky 2.2
476
- self.encoder_hid_proj = Resampler(
477
- dim=1280,
478
- depth=4,
479
- dim_head=64,
480
- heads=20,
481
- num_queries=16,
482
- embedding_dim=encoder_hid_dim,
483
- output_dim=self.config.cross_attention_dim,
484
- ff_mult=4,
485
- )
486
-
487
-
488
- elif encoder_hid_dim_type is not None:
489
- raise ValueError(
490
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
491
- )
492
- else:
493
- self.encoder_hid_proj = None
494
-
495
- # class embedding
496
- if class_embed_type is None and num_class_embeds is not None:
497
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
498
- elif class_embed_type == "timestep":
499
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
500
- elif class_embed_type == "identity":
501
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
502
- elif class_embed_type == "projection":
503
- if projection_class_embeddings_input_dim is None:
504
- raise ValueError(
505
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
506
- )
507
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
508
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
509
- # 2. it projects from an arbitrary input dimension.
510
- #
511
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
512
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
513
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
514
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
515
- elif class_embed_type == "simple_projection":
516
- if projection_class_embeddings_input_dim is None:
517
- raise ValueError(
518
- "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
519
- )
520
- self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
521
- else:
522
- self.class_embedding = None
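The comment above is the key point for the 'projection' variant: `TimestepEmbedding` is essentially two linear layers with an activation, so it can project any fixed-size conditioning vector. A small sketch with assumed dimensions:

    import torch
    from diffusers.models.embeddings import TimestepEmbedding

    proj = TimestepEmbedding(in_channels=512, time_embed_dim=1280)  # assumed sizes
    class_vec = torch.randn(4, 512)       # arbitrary per-sample vector, no sinusoidal encoding
    class_emb = proj(class_vec)           # (4, 1280); summed with or concatenated to the time embedding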
523
-
524
- if addition_embed_type == "text":
525
- if encoder_hid_dim is not None:
526
- text_time_embedding_from_dim = encoder_hid_dim
527
- else:
528
- text_time_embedding_from_dim = cross_attention_dim
529
-
530
- self.add_embedding = TextTimeEmbedding(
531
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
532
- )
533
- elif addition_embed_type == "text_image":
534
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
535
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
536
- # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
537
- self.add_embedding = TextImageTimeEmbedding(
538
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
539
- )
540
- elif addition_embed_type == "text_time":
541
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
542
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
543
- elif addition_embed_type == "image":
544
- # Kandinsky 2.2
545
- self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
546
- elif addition_embed_type == "image_hint":
547
- # Kandinsky 2.2 ControlNet
548
- self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
549
- elif addition_embed_type is not None:
550
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
551
-
552
- if time_embedding_act_fn is None:
553
- self.time_embed_act = None
554
- else:
555
- self.time_embed_act = get_activation(time_embedding_act_fn)
556
-
557
- self.down_blocks = nn.ModuleList([])
558
- self.up_blocks = nn.ModuleList([])
559
-
560
- if isinstance(only_cross_attention, bool):
561
- if mid_block_only_cross_attention is None:
562
- mid_block_only_cross_attention = only_cross_attention
563
-
564
- only_cross_attention = [only_cross_attention] * len(down_block_types)
565
-
566
- if mid_block_only_cross_attention is None:
567
- mid_block_only_cross_attention = False
568
-
569
- if isinstance(num_attention_heads, int):
570
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
571
-
572
- if isinstance(attention_head_dim, int):
573
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
574
-
575
- if isinstance(cross_attention_dim, int):
576
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
577
-
578
- if isinstance(layers_per_block, int):
579
- layers_per_block = [layers_per_block] * len(down_block_types)
580
-
581
- if isinstance(transformer_layers_per_block, int):
582
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
583
- if class_embeddings_concat:
584
- # The time embeddings are concatenated with the class embeddings. The dimension of the
585
- # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
586
- # regular time embeddings
587
- blocks_time_embed_dim = time_embed_dim * 2
588
- else:
589
- blocks_time_embed_dim = time_embed_dim
590
-
591
- # down
592
- output_channel = block_out_channels[0]
593
- for i, down_block_type in enumerate(down_block_types):
594
- input_channel = output_channel
595
- output_channel = block_out_channels[i]
596
- is_final_block = i == len(block_out_channels) - 1
597
-
598
- down_block = get_down_block(
599
- down_block_type,
600
- num_layers=layers_per_block[i],
601
- transformer_layers_per_block=transformer_layers_per_block[i],
602
- in_channels=input_channel,
603
- out_channels=output_channel,
604
- temb_channels=blocks_time_embed_dim,
605
- add_downsample=not is_final_block,
606
- resnet_eps=norm_eps,
607
- resnet_act_fn=act_fn,
608
- resnet_groups=norm_num_groups,
609
- cross_attention_dim=cross_attention_dim[i],
610
- num_attention_heads=num_attention_heads[i],
611
- downsample_padding=downsample_padding,
612
- dual_cross_attention=dual_cross_attention,
613
- use_linear_projection=use_linear_projection,
614
- only_cross_attention=only_cross_attention[i],
615
- upcast_attention=upcast_attention,
616
- resnet_time_scale_shift=resnet_time_scale_shift,
617
- attention_type=attention_type,
618
- resnet_skip_time_act=resnet_skip_time_act,
619
- resnet_out_scale_factor=resnet_out_scale_factor,
620
- cross_attention_norm=cross_attention_norm,
621
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
622
- dropout=dropout,
623
- )
624
- self.down_blocks.append(down_block)
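As a concrete illustration of the channel bookkeeping in this loop, assuming `block_out_channels=(320, 640, 1280)`: each down block takes the previous block's output width as its input width, and only the final block skips the downsampler:

    block_out_channels = (320, 640, 1280)    # assumed example config
    output_channel = block_out_channels[0]
    for i, _ in enumerate(block_out_channels):
        input_channel, output_channel = output_channel, block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1
        print(f"block {i}: {input_channel} -> {output_channel}, add_downsample={not is_final_block}")
    # block 0: 320 -> 320, add_downsample=True
    # block 1: 320 -> 640, add_downsample=True
    # block 2: 640 -> 1280, add_downsample=False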
625
-
626
- # mid
627
- if mid_block_type == "UNetMidBlock2DCrossAttn":
628
- self.mid_block = UNetMidBlock2DCrossAttn(
629
- transformer_layers_per_block=transformer_layers_per_block[-1],
630
- in_channels=block_out_channels[-1],
631
- temb_channels=blocks_time_embed_dim,
632
- dropout=dropout,
633
- resnet_eps=norm_eps,
634
- resnet_act_fn=act_fn,
635
- output_scale_factor=mid_block_scale_factor,
636
- resnet_time_scale_shift=resnet_time_scale_shift,
637
- cross_attention_dim=cross_attention_dim[-1],
638
- num_attention_heads=num_attention_heads[-1],
639
- resnet_groups=norm_num_groups,
640
- dual_cross_attention=dual_cross_attention,
641
- use_linear_projection=use_linear_projection,
642
- upcast_attention=upcast_attention,
643
- attention_type=attention_type,
644
- )
645
- elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
646
- self.mid_block = UNetMidBlock2DSimpleCrossAttn(
647
- in_channels=block_out_channels[-1],
648
- temb_channels=blocks_time_embed_dim,
649
- dropout=dropout,
650
- resnet_eps=norm_eps,
651
- resnet_act_fn=act_fn,
652
- output_scale_factor=mid_block_scale_factor,
653
- cross_attention_dim=cross_attention_dim[-1],
654
- attention_head_dim=attention_head_dim[-1],
655
- resnet_groups=norm_num_groups,
656
- resnet_time_scale_shift=resnet_time_scale_shift,
657
- skip_time_act=resnet_skip_time_act,
658
- only_cross_attention=mid_block_only_cross_attention,
659
- cross_attention_norm=cross_attention_norm,
660
- )
661
- elif mid_block_type == "UNetMidBlock2D":
662
- self.mid_block = UNetMidBlock2D(
663
- in_channels=block_out_channels[-1],
664
- temb_channels=blocks_time_embed_dim,
665
- dropout=dropout,
666
- num_layers=0,
667
- resnet_eps=norm_eps,
668
- resnet_act_fn=act_fn,
669
- output_scale_factor=mid_block_scale_factor,
670
- resnet_groups=norm_num_groups,
671
- resnet_time_scale_shift=resnet_time_scale_shift,
672
- add_attention=False,
673
- )
674
- elif mid_block_type is None:
675
- self.mid_block = None
676
- else:
677
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
678
-
679
- # count how many layers upsample the images
680
- self.num_upsamplers = 0
681
-
682
- # up
683
- reversed_block_out_channels = list(reversed(block_out_channels))
684
- reversed_num_attention_heads = list(reversed(num_attention_heads))
685
- reversed_layers_per_block = list(reversed(layers_per_block))
686
- reversed_cross_attention_dim = list(reversed(cross_attention_dim))
687
- reversed_transformer_layers_per_block = (
688
- list(reversed(transformer_layers_per_block))
689
- if reverse_transformer_layers_per_block is None
690
- else reverse_transformer_layers_per_block
691
- )
692
- only_cross_attention = list(reversed(only_cross_attention))
693
-
694
- output_channel = reversed_block_out_channels[0]
695
- for i, up_block_type in enumerate(up_block_types):
696
- is_final_block = i == len(block_out_channels) - 1
697
-
698
- prev_output_channel = output_channel
699
- output_channel = reversed_block_out_channels[i]
700
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
701
-
702
- # add upsample block for all BUT final layer
703
- if not is_final_block:
704
- add_upsample = True
705
- self.num_upsamplers += 1
706
- else:
707
- add_upsample = False
708
- up_block = get_up_block(
709
- up_block_type,
710
- num_layers=reversed_layers_per_block[i] + 1,
711
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
712
- in_channels=input_channel,
713
- out_channels=output_channel,
714
- prev_output_channel=prev_output_channel,
715
- temb_channels=blocks_time_embed_dim,
716
- add_upsample=add_upsample,
717
- resnet_eps=norm_eps,
718
- resnet_act_fn=act_fn,
719
- resolution_idx=i,
720
- resnet_groups=norm_num_groups,
721
- cross_attention_dim=reversed_cross_attention_dim[i],
722
- num_attention_heads=reversed_num_attention_heads[i],
723
- dual_cross_attention=dual_cross_attention,
724
- use_linear_projection=use_linear_projection,
725
- only_cross_attention=only_cross_attention[i],
726
- upcast_attention=upcast_attention,
727
- resnet_time_scale_shift=resnet_time_scale_shift,
728
- attention_type=attention_type,
729
- resnet_skip_time_act=resnet_skip_time_act,
730
- resnet_out_scale_factor=resnet_out_scale_factor,
731
- cross_attention_norm=cross_attention_norm,
732
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
733
- dropout=dropout,
734
- )
735
-
736
- self.up_blocks.append(up_block)
737
- prev_output_channel = output_channel
738
-
739
-
740
-
741
-
742
- # out
743
- if norm_num_groups is not None:
744
- self.conv_norm_out = nn.GroupNorm(
745
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
746
- )
747
-
748
- self.conv_act = get_activation(act_fn)
749
-
750
- else:
751
- self.conv_norm_out = None
752
- self.conv_act = None
753
-
754
- conv_out_padding = (conv_out_kernel - 1) // 2
755
- self.conv_out = nn.Conv2d(
756
- block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
757
- )
758
-
759
- if attention_type in ["gated", "gated-text-image"]:
760
- positive_len = 768
761
- if isinstance(cross_attention_dim, int):
762
- positive_len = cross_attention_dim
763
- elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
764
- positive_len = cross_attention_dim[0]
765
-
766
- feature_type = "text-only" if attention_type == "gated" else "text-image"
767
- self.position_net = PositionNet(
768
- positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
769
- )
770
-
771
-
772
-
773
- from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
774
-
775
- attn_procs = {}
776
- for name in self.attn_processors.keys():
777
- cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
778
- if name.startswith("mid_block"):
779
- hidden_size = self.config.block_out_channels[-1]
780
- elif name.startswith("up_blocks"):
781
- block_id = int(name[len("up_blocks.")])
782
- hidden_size = list(reversed(self.config.block_out_channels))[block_id]
783
- elif name.startswith("down_blocks"):
784
- block_id = int(name[len("down_blocks.")])
785
- hidden_size = self.config.block_out_channels[block_id]
786
- if cross_attention_dim is None:
787
- attn_procs[name] = AttnProcessor()
788
- else:
789
- layer_name = name.split(".processor")[0]
790
- attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=16)
791
- self.set_attn_processor(attn_procs)
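The block above mirrors the usual IP-Adapter setup: self-attention layers (names ending in `attn1.processor`) keep the plain processor, while cross-attention layers get an `IPAttnProcessor` whose `num_tokens=16` matches the Resampler's `num_queries` and whose `hidden_size` is read off `block_out_channels` from the layer name. A standalone sketch of that name-to-width mapping, assuming `block_out_channels=(320, 640, 1280)`:

    block_out_channels = (320, 640, 1280)    # assumed example config

    def hidden_size_for(name: str) -> int:
        if name.startswith("mid_block"):
            return block_out_channels[-1]
        if name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            return list(reversed(block_out_channels))[block_id]
        block_id = int(name[len("down_blocks.")])
        return block_out_channels[block_id]

    print(hidden_size_for("mid_block.attentions.0.transformer_blocks.0.attn2.processor"))      # 1280
    print(hidden_size_for("up_blocks.0.attentions.0.transformer_blocks.0.attn2.processor"))    # 1280
    print(hidden_size_for("down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"))  # 640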
792
-
793
-
794
- @property
795
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
796
- r"""
797
- Returns:
798
- `dict` of attention processors: A dictionary containing all attention processors used in the model,
799
- indexed by its weight name.
800
- """
801
- # set recursively
802
- processors = {}
803
-
804
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
805
- if hasattr(module, "get_processor"):
806
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
807
-
808
- for sub_name, child in module.named_children():
809
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
810
-
811
- return processors
812
-
813
- for name, module in self.named_children():
814
- fn_recursive_add_processors(name, module, processors)
815
-
816
- return processors
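Usage sketch, assuming `unet` is an already-instantiated model of this class: the property returns one entry per attention layer, keyed by the module path.

    # `unet` is assumed to be an existing instance of this UNet class
    for name, proc in unet.attn_processors.items():
        print(name, type(proc).__name__)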
817
-
818
- def set_attn_processor(
819
- self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
820
- ):
821
- r"""
822
- Sets the attention processor to use to compute attention.
823
-
824
- Parameters:
825
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
826
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
827
- for **all** `Attention` layers.
828
-
829
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
830
- processor. This is strongly recommended when setting trainable attention processors.
831
-
832
- """
833
- count = len(self.attn_processors.keys())
834
-
835
- if isinstance(processor, dict) and len(processor) != count:
836
- raise ValueError(
837
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
838
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
839
- )
840
-
841
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
842
- if hasattr(module, "set_processor"):
843
- if not isinstance(processor, dict):
844
- module.set_processor(processor, _remove_lora=_remove_lora)
845
- else:
846
- module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
847
-
848
- for sub_name, child in module.named_children():
849
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
850
-
851
- for name, module in self.named_children():
852
- fn_recursive_attn_processor(name, module, processor)
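Usage sketch, again assuming an existing `unet`: a single processor instance is broadcast to every attention layer, while a dict keyed like `attn_processors` targets layers individually.

    from diffusers.models.attention_processor import AttnProcessor2_0

    # `unet` is assumed to be an existing instance of this UNet class
    unet.set_attn_processor(AttnProcessor2_0())                          # same processor everywhere
    procs = {name: AttnProcessor2_0() for name in unet.attn_processors}  # or one per layer
    unet.set_attn_processor(procs)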
853
-
854
- def set_default_attn_processor(self):
855
- """
856
- Disables custom attention processors and sets the default attention implementation.
857
- """
858
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
859
- processor = AttnAddedKVProcessor()
860
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
861
- processor = AttnProcessor()
862
- else:
863
- raise ValueError(
864
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
865
- )
866
-
867
- self.set_attn_processor(processor, _remove_lora=True)
868
-
869
- def set_attention_slice(self, slice_size):
870
- r"""
871
- Enable sliced attention computation.
872
-
873
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
874
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
875
-
876
- Args:
877
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
878
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
879
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
880
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
881
- must be a multiple of `slice_size`.
882
- """
883
- sliceable_head_dims = []
884
-
885
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
886
- if hasattr(module, "set_attention_slice"):
887
- sliceable_head_dims.append(module.sliceable_head_dim)
888
-
889
- for child in module.children():
890
- fn_recursive_retrieve_sliceable_dims(child)
891
-
892
- # retrieve number of attention layers
893
- for module in self.children():
894
- fn_recursive_retrieve_sliceable_dims(module)
895
-
896
- num_sliceable_layers = len(sliceable_head_dims)
897
-
898
- if slice_size == "auto":
899
- # half the attention head size is usually a good trade-off between
900
- # speed and memory
901
- slice_size = [dim // 2 for dim in sliceable_head_dims]
902
- elif slice_size == "max":
903
- # make smallest slice possible
904
- slice_size = num_sliceable_layers * [1]
905
-
906
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
907
-
908
- if len(slice_size) != len(sliceable_head_dims):
909
- raise ValueError(
910
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
911
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
912
- )
913
-
914
- for i in range(len(slice_size)):
915
- size = slice_size[i]
916
- dim = sliceable_head_dims[i]
917
- if size is not None and size > dim:
918
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
919
-
920
- # Recursively walk through all the children.
921
- # Any children which exposes the set_attention_slice method
922
- # gets the message
923
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
924
- if hasattr(module, "set_attention_slice"):
925
- module.set_attention_slice(slice_size.pop())
926
-
927
- for child in module.children():
928
- fn_recursive_set_attention_slice(child, slice_size)
929
-
930
- reversed_slice_size = list(reversed(slice_size))
931
- for module in self.children():
932
- fn_recursive_set_attention_slice(module, reversed_slice_size)
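Usage sketch, assuming an existing `unet`: "auto" halves every sliceable head dimension, "max" runs one slice at a time, and an int (or list of ints) sets the slice size explicitly.

    # `unet` is assumed to be an existing instance of this UNet class
    unet.set_attention_slice("auto")   # half of each sliceable head dim per slice
    unet.set_attention_slice("max")    # most memory-frugal: slice size 1 everywhere
    unet.set_attention_slice(4)        # explicit slice size applied to every layer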
933
-
934
- def _set_gradient_checkpointing(self, module, value=False):
935
- if hasattr(module, "gradient_checkpointing"):
936
- module.gradient_checkpointing = value
937
-
938
- def enable_freeu(self, s1, s2, b1, b2):
939
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
940
-
941
- The suffixes after the scaling factors represent the stage blocks where they are being applied.
942
-
943
- Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
944
- are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
945
-
946
- Args:
947
- s1 (`float`):
948
- Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
949
- mitigate the "oversmoothing effect" in the enhanced denoising process.
950
- s2 (`float`):
951
- Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
952
- mitigate the "oversmoothing effect" in the enhanced denoising process.
953
- b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
954
- b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
955
- """
956
- for i, upsample_block in enumerate(self.up_blocks):
957
- setattr(upsample_block, "s1", s1)
958
- setattr(upsample_block, "s2", s2)
959
- setattr(upsample_block, "b1", b1)
960
- setattr(upsample_block, "b2", b2)
961
-
962
- def disable_freeu(self):
963
- """Disables the FreeU mechanism."""
964
- freeu_keys = {"s1", "s2", "b1", "b2"}
965
- for i, upsample_block in enumerate(self.up_blocks):
966
- for k in freeu_keys:
967
- if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
968
- setattr(upsample_block, k, None)
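Usage sketch, assuming an existing `unet`; the scaling values below are illustrative only, the FreeU repository lists recommended settings per base model:

    # `unet` is assumed to be an existing instance of this UNet class; values are illustrative
    unet.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
    # ... run inference with FreeU enabled ...
    unet.disable_freeu()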
969
-
970
- def fuse_qkv_projections(self):
971
- """
972
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
973
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
974
-
975
- <Tip warning={true}>
976
-
977
- This API is πŸ§ͺ experimental.
978
-
979
- </Tip>
980
- """
981
- self.original_attn_processors = None
982
-
983
- for _, attn_processor in self.attn_processors.items():
984
- if "Added" in str(attn_processor.__class__.__name__):
985
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
986
-
987
- self.original_attn_processors = self.attn_processors
988
-
989
- for module in self.modules():
990
- if isinstance(module, Attention):
991
- module.fuse_projections(fuse=True)
992
-
993
- def unfuse_qkv_projections(self):
994
- """Disables the fused QKV projection if enabled.
995
-
996
- <Tip warning={true}>
997
-
998
- This API is πŸ§ͺ experimental.
999
-
1000
- </Tip>
1001
-
1002
- """
1003
- if self.original_attn_processors is not None:
1004
- self.set_attn_processor(self.original_attn_processors)
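Usage sketch, assuming an existing `unet`: fusing merges the separate Q/K/V (or K/V) projections into a single matmul per attention module, and `unfuse_qkv_projections` restores the processors saved in `original_attn_processors`.

    # `unet` is assumed to be an existing instance of this UNet class; the API is experimental
    unet.fuse_qkv_projections()
    # ... run inference with fused projections ...
    unet.unfuse_qkv_projections()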
1005
-
1006
- def forward(
1007
- self,
1008
- sample: torch.FloatTensor,
1009
- timestep: Union[torch.Tensor, float, int],
1010
- encoder_hidden_states: torch.Tensor,
1011
- class_labels: Optional[torch.Tensor] = None,
1012
- timestep_cond: Optional[torch.Tensor] = None,
1013
- attention_mask: Optional[torch.Tensor] = None,
1014
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1015
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1016
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1017
- mid_block_additional_residual: Optional[torch.Tensor] = None,
1018
- down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1019
- encoder_attention_mask: Optional[torch.Tensor] = None,
1020
- return_dict: bool = True,
1021
- garment_features: Optional[Tuple[torch.Tensor]] = None,
1022
- ) -> Union[UNet2DConditionOutput, Tuple]:
1023
- r"""
1024
- The [`UNet2DConditionModel`] forward method.
1025
-
1026
- Args:
1027
- sample (`torch.FloatTensor`):
1028
- The noisy input tensor with the following shape `(batch, channel, height, width)`.
1029
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
1030
- encoder_hidden_states (`torch.FloatTensor`):
1031
- The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1032
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1033
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1034
- timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1035
- Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1036
- through the `self.time_embedding` layer to obtain the timestep embeddings.
1037
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1038
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1039
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1040
- negative values to the attention scores corresponding to "discard" tokens.
1041
- cross_attention_kwargs (`dict`, *optional*):
1042
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1043
- `self.processor` in
1044
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1045
- added_cond_kwargs: (`dict`, *optional*):
1046
- A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1047
- are passed along to the UNet blocks.
1048
- down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1049
- A tuple of tensors that if specified are added to the residuals of the down UNet blocks (the long
1050
- skip connections to the up blocks), for example from ControlNet side model(s).
1051
- mid_block_additional_residual: (`torch.Tensor`, *optional*):
1052
- A tensor that if specified is added to the residual of the middle UNet block, for example from a
1053
- ControlNet side model.
1054
- encoder_attention_mask (`torch.Tensor`):
1055
- A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1056
- `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1057
- which adds large negative values to the attention scores corresponding to "discard" tokens.
1058
- return_dict (`bool`, *optional*, defaults to `True`):
1059
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1060
- tuple.
1061
- down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1062
- Additional residuals to be added within the UNet down blocks, for example from T2I-Adapter side
1063
- model(s).
1064
- garment_features (`tuple` of `torch.Tensor`, *optional*):
1065
- Reference garment features that are passed, together with `curr_garment_feat_idx`, to the
1066
- cross-attention down, mid and up blocks.
1071
-
1072
- Returns:
1073
- [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1074
- If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
1075
- a `tuple` is returned where the first element is the sample tensor.
1076
- """
1077
- # By default samples have to be at least a multiple of the overall upsampling factor.
1078
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1079
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
1080
- # on the fly if necessary.
1081
- default_overall_up_factor = 2**self.num_upsamplers
1082
-
1083
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1084
- forward_upsample_size = False
1085
- upsample_size = None
1086
-
1087
- for dim in sample.shape[-2:]:
1088
- if dim % default_overall_up_factor != 0:
1089
- # Forward upsample size to force interpolation output size.
1090
- forward_upsample_size = True
1091
- break
1092
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1093
- # expects mask of shape:
1094
- # [batch, key_tokens]
1095
- # adds singleton query_tokens dimension:
1096
- # [batch, 1, key_tokens]
1097
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1098
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1099
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1100
- if attention_mask is not None:
1101
- # assume that mask is expressed as:
1102
- # (1 = keep, 0 = discard)
1103
- # convert mask into a bias that can be added to attention scores:
1104
- # (keep = +0, discard = -10000.0)
1105
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1106
- attention_mask = attention_mask.unsqueeze(1)
1107
-
1108
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
1109
- if encoder_attention_mask is not None:
1110
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1111
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
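The conversion above, shown standalone: a 0/1 key-token mask becomes an additive bias (0 for keep, -10000 for discard) with a singleton query dimension so it broadcasts over the attention scores.

    import torch

    attention_mask = torch.tensor([[1, 1, 0, 0]])             # (batch, key_tokens), 1 = keep
    bias = (1 - attention_mask.to(torch.float32)) * -10000.0
    bias = bias.unsqueeze(1)                                  # (batch, 1, key_tokens)
    # kept tokens add ~0 to the scores, masked tokens add -10000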
1112
-
1113
- # 0. center input if necessary
1114
- if self.config.center_input_sample:
1115
- sample = 2 * sample - 1.0
1116
-
1117
- # 1. time
1118
- timesteps = timestep
1119
- if not torch.is_tensor(timesteps):
1120
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1121
- # This would be a good case for the `match` statement (Python 3.10+)
1122
- is_mps = sample.device.type == "mps"
1123
- if isinstance(timestep, float):
1124
- dtype = torch.float32 if is_mps else torch.float64
1125
- else:
1126
- dtype = torch.int32 if is_mps else torch.int64
1127
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1128
- elif len(timesteps.shape) == 0:
1129
- timesteps = timesteps[None].to(sample.device)
1130
-
1131
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1132
- timesteps = timesteps.expand(sample.shape[0])
1133
-
1134
- t_emb = self.time_proj(timesteps)
1135
-
1136
- # `Timesteps` does not contain any weights and will always return f32 tensors
1137
- # but time_embedding might actually be running in fp16. so we need to cast here.
1138
- # there might be better ways to encapsulate this.
1139
- t_emb = t_emb.to(dtype=sample.dtype)
1140
-
1141
- emb = self.time_embedding(t_emb, timestep_cond)
1142
- aug_emb = None
1143
-
1144
- if self.class_embedding is not None:
1145
- if class_labels is None:
1146
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
1147
-
1148
- if self.config.class_embed_type == "timestep":
1149
- class_labels = self.time_proj(class_labels)
1150
-
1151
- # `Timesteps` does not contain any weights and will always return f32 tensors
1152
- # there might be better ways to encapsulate this.
1153
- class_labels = class_labels.to(dtype=sample.dtype)
1154
-
1155
- class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1156
-
1157
- if self.config.class_embeddings_concat:
1158
- emb = torch.cat([emb, class_emb], dim=-1)
1159
- else:
1160
- emb = emb + class_emb
1161
-
1162
- if self.config.addition_embed_type == "text":
1163
- aug_emb = self.add_embedding(encoder_hidden_states)
1164
- elif self.config.addition_embed_type == "text_image":
1165
- # Kandinsky 2.1 - style
1166
- if "image_embeds" not in added_cond_kwargs:
1167
- raise ValueError(
1168
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1169
- )
1170
-
1171
- image_embs = added_cond_kwargs.get("image_embeds")
1172
- text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1173
- aug_emb = self.add_embedding(text_embs, image_embs)
1174
- elif self.config.addition_embed_type == "text_time":
1175
- # SDXL - style
1176
- if "text_embeds" not in added_cond_kwargs:
1177
- raise ValueError(
1178
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1179
- )
1180
- text_embeds = added_cond_kwargs.get("text_embeds")
1181
- if "time_ids" not in added_cond_kwargs:
1182
- raise ValueError(
1183
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1184
- )
1185
- time_ids = added_cond_kwargs.get("time_ids")
1186
- time_embeds = self.add_time_proj(time_ids.flatten())
1187
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1188
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1189
- add_embeds = add_embeds.to(emb.dtype)
1190
- aug_emb = self.add_embedding(add_embeds)
1191
- elif self.config.addition_embed_type == "image":
1192
- # Kandinsky 2.2 - style
1193
- if "image_embeds" not in added_cond_kwargs:
1194
- raise ValueError(
1195
- f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1196
- )
1197
- image_embs = added_cond_kwargs.get("image_embeds")
1198
- aug_emb = self.add_embedding(image_embs)
1199
- elif self.config.addition_embed_type == "image_hint":
1200
- # Kandinsky 2.2 - style
1201
- if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1202
- raise ValueError(
1203
- f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1204
- )
1205
- image_embs = added_cond_kwargs.get("image_embeds")
1206
- hint = added_cond_kwargs.get("hint")
1207
- aug_emb, hint = self.add_embedding(image_embs, hint)
1208
- sample = torch.cat([sample, hint], dim=1)
1209
-
1210
- emb = emb + aug_emb if aug_emb is not None else emb
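A standalone sketch of the SDXL-style 'text_time' branch above, with assumed sizes: the six micro-conditioning ids per sample get sinusoidal features, are flattened back per batch element, and are concatenated to the pooled text embedding before the extra `add_embedding` MLP.

    import torch
    from diffusers.models.embeddings import Timesteps

    add_time_proj = Timesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0)  # assumed addition_time_embed_dim=256

    text_embeds = torch.randn(2, 1280)                               # pooled text embedding (assumed width)
    time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * 2)    # assumed (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
    time_embeds = add_time_proj(time_ids.flatten())                  # (12, 256)
    time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))    # (2, 1536)
    add_embeds = torch.cat([text_embeds, time_embeds], dim=-1)       # (2, 2816), fed to add_embedding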
1211
-
1212
- if self.time_embed_act is not None:
1213
- emb = self.time_embed_act(emb)
1214
-
1215
- if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1216
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1217
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1218
- # Kandinsky 2.1 - style
1219
- if "image_embeds" not in added_cond_kwargs:
1220
- raise ValueError(
1221
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1222
- )
1223
-
1224
- image_embeds = added_cond_kwargs.get("image_embeds")
1225
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1226
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1227
- # Kandinsky 2.2 - style
1228
- if "image_embeds" not in added_cond_kwargs:
1229
- raise ValueError(
1230
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1231
- )
1232
- image_embeds = added_cond_kwargs.get("image_embeds")
1233
- encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1234
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1235
- if "image_embeds" not in added_cond_kwargs:
1236
- raise ValueError(
1237
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1238
- )
1239
- image_embeds = added_cond_kwargs.get("image_embeds")
1240
- # print(image_embeds.shape)
1241
- # image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
1242
- encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
1243
-
1244
- # 2. pre-process
1245
- sample = self.conv_in(sample)
1246
-
1247
- # 2.5 GLIGEN position net
1248
- if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1249
- cross_attention_kwargs = cross_attention_kwargs.copy()
1250
- gligen_args = cross_attention_kwargs.pop("gligen")
1251
- cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1252
-
1253
-
1254
- curr_garment_feat_idx = 0
1255
-
1256
-
1257
- # 3. down
1258
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1259
- if USE_PEFT_BACKEND:
1260
- # weight the lora layers by setting `lora_scale` for each PEFT layer
1261
- scale_lora_layers(self, lora_scale)
1262
-
1263
- is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1264
- # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1265
- is_adapter = down_intrablock_additional_residuals is not None
1266
- # maintain backward compatibility for legacy usage, where
1267
- # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1268
- # but can only use one or the other
1269
- if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1270
- deprecate(
1271
- "T2I should not use down_block_additional_residuals",
1272
- "1.3.0",
1273
- "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1274
- and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1275
- for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
1276
- standard_warn=False,
1277
- )
1278
- down_intrablock_additional_residuals = down_block_additional_residuals
1279
- is_adapter = True
1280
-
1281
- down_block_res_samples = (sample,)
1282
- for downsample_block in self.down_blocks:
1283
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1284
- # For t2i-adapter CrossAttnDownBlock2D
1285
- additional_residuals = {}
1286
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
1287
- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1288
-
1289
- sample, res_samples, curr_garment_feat_idx = downsample_block(
1290
- hidden_states=sample,
1291
- temb=emb,
1292
- encoder_hidden_states=encoder_hidden_states,
1293
- attention_mask=attention_mask,
1294
- cross_attention_kwargs=cross_attention_kwargs,
1295
- encoder_attention_mask=encoder_attention_mask,
1296
- garment_features=garment_features,
1297
- curr_garment_feat_idx=curr_garment_feat_idx,
1298
- **additional_residuals,
1299
- )
1300
- else:
1301
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
1302
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
1303
- sample += down_intrablock_additional_residuals.pop(0)
1304
-
1305
- down_block_res_samples += res_samples
1306
-
1307
-
1308
- if is_controlnet:
1309
- new_down_block_res_samples = ()
1310
-
1311
- for down_block_res_sample, down_block_additional_residual in zip(
1312
- down_block_res_samples, down_block_additional_residuals
1313
- ):
1314
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
1315
- new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1316
-
1317
- down_block_res_samples = new_down_block_res_samples
1318
-
1319
- # 4. mid
1320
- if self.mid_block is not None:
1321
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1322
- sample, curr_garment_feat_idx = self.mid_block(
1323
- sample,
1324
- emb,
1325
- encoder_hidden_states=encoder_hidden_states,
1326
- attention_mask=attention_mask,
1327
- cross_attention_kwargs=cross_attention_kwargs,
1328
- encoder_attention_mask=encoder_attention_mask,
1329
- garment_features=garment_features,
1330
- curr_garment_feat_idx=curr_garment_feat_idx,
1331
- )
1332
- else:
1333
- sample = self.mid_block(sample, emb)
1334
-
1335
- # To support T2I-Adapter-XL
1336
- if (
1337
- is_adapter
1338
- and len(down_intrablock_additional_residuals) > 0
1339
- and sample.shape == down_intrablock_additional_residuals[0].shape
1340
- ):
1341
- sample += down_intrablock_additional_residuals.pop(0)
1342
-
1343
- if is_controlnet:
1344
- sample = sample + mid_block_additional_residual
1345
-
1346
-
1347
-
1348
- # 5. up
1349
- for i, upsample_block in enumerate(self.up_blocks):
1350
- is_final_block = i == len(self.up_blocks) - 1
1351
-
1352
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1353
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1354
-
1355
- # if we have not reached the final block and need to forward the
1356
- # upsample size, we do it here
1357
- if not is_final_block and forward_upsample_size:
1358
- upsample_size = down_block_res_samples[-1].shape[2:]
1359
-
1360
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1361
- sample, curr_garment_feat_idx = upsample_block(
1362
- hidden_states=sample,
1363
- temb=emb,
1364
- res_hidden_states_tuple=res_samples,
1365
- encoder_hidden_states=encoder_hidden_states,
1366
- cross_attention_kwargs=cross_attention_kwargs,
1367
- upsample_size=upsample_size,
1368
- attention_mask=attention_mask,
1369
- encoder_attention_mask=encoder_attention_mask,
1370
- garment_features=garment_features,
1371
- curr_garment_feat_idx=curr_garment_feat_idx,
1372
- )
1373
-
1374
- else:
1375
- sample = upsample_block(
1376
- hidden_states=sample,
1377
- temb=emb,
1378
- res_hidden_states_tuple=res_samples,
1379
- upsample_size=upsample_size,
1380
- scale=lora_scale,
1381
- )
1382
- # 6. post-process
1383
- if self.conv_norm_out:
1384
- sample = self.conv_norm_out(sample)
1385
- sample = self.conv_act(sample)
1386
- sample = self.conv_out(sample)
1387
-
1388
- if USE_PEFT_BACKEND:
1389
- # remove `lora_scale` from each PEFT layer
1390
- unscale_lora_layers(self, lora_scale)
1391
-
1392
- if not return_dict:
1393
- return (sample,)
1394
-
1395
- return UNet2DConditionOutput(sample=sample)