levihsu committed on
Commit 7ab1dcb
1 Parent(s): d1f7251

Delete ootd

ootd/inference_ootd.py DELETED
@@ -1,133 +0,0 @@
1
- import pdb
2
- from pathlib import Path
3
- import sys
4
- PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
5
- sys.path.insert(0, str(PROJECT_ROOT))
6
- import os
7
-
8
- import torch
9
- import numpy as np
10
- from PIL import Image
11
- import cv2
12
-
13
- import random
14
- import time
15
- import pdb
16
-
17
- from pipelines_ootd.pipeline_ootd import OotdPipeline
18
- from pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel
19
- from pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel
20
- from diffusers import UniPCMultistepScheduler
21
- from diffusers import AutoencoderKL
22
-
23
- import torch.nn as nn
24
- import torch.nn.functional as F
25
- from transformers import AutoProcessor, CLIPVisionModelWithProjection
26
- from transformers import CLIPTextModel, CLIPTokenizer
27
-
28
- VIT_PATH = "openai/clip-vit-large-patch14"
29
- VAE_PATH = "levihsu/ootd"
30
- UNET_PATH = "levihsu/ootd"
31
- MODEL_PATH = "levihsu/ootd"
32
-
33
- class OOTDiffusion:
34
-
35
- def __init__(self, gpu_id):
36
- self.gpu_id = 'cuda:' + str(gpu_id)
37
-
38
- vae = AutoencoderKL.from_pretrained(
39
- VAE_PATH,
40
- subfolder="vae",
41
- torch_dtype=torch.float16,
42
- )
43
-
44
- unet_garm = UNetGarm2DConditionModel.from_pretrained(
45
- UNET_PATH,
46
- subfolder="unet_garm",
47
- torch_dtype=torch.float16,
48
- use_safetensors=True,
49
- )
50
- unet_vton = UNetVton2DConditionModel.from_pretrained(
51
- UNET_PATH,
52
- subfolder="unet_vton",
53
- torch_dtype=torch.float16,
54
- use_safetensors=True,
55
- )
56
-
57
- self.pipe = OotdPipeline.from_pretrained(
58
- MODEL_PATH,
59
- unet_garm=unet_garm,
60
- unet_vton=unet_vton,
61
- vae=vae,
62
- torch_dtype=torch.float16,
63
- variant="fp16",
64
- use_safetensors=True,
65
- safety_checker=None,
66
- requires_safety_checker=False,
67
- ).to(self.gpu_id)
68
-
69
- self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
70
-
71
- self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
72
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH).to(self.gpu_id)
73
-
74
- self.tokenizer = CLIPTokenizer.from_pretrained(
75
- MODEL_PATH,
76
- subfolder="tokenizer",
77
- )
78
- self.text_encoder = CLIPTextModel.from_pretrained(
79
- MODEL_PATH,
80
- subfolder="text_encoder",
81
- ).to(self.gpu_id)
82
-
83
-
84
- def tokenize_captions(self, captions, max_length):
85
- inputs = self.tokenizer(
86
- captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
87
- )
88
- return inputs.input_ids
89
-
90
-
91
- def __call__(self,
92
- model_type='hd',
93
- category='upperbody',
94
- image_garm=None,
95
- image_vton=None,
96
- mask=None,
97
- image_ori=None,
98
- num_samples=1,
99
- num_steps=20,
100
- image_scale=1.0,
101
- seed=-1,
102
- ):
103
- if seed == -1:
104
- random.seed(time.time())
105
- seed = random.randint(0, 2147483647)
106
- print('Initial seed: ' + str(seed))
107
- generator = torch.manual_seed(seed)
108
-
109
- with torch.no_grad():
110
- prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to(self.gpu_id)
111
- prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
112
- prompt_image = prompt_image.unsqueeze(1)
113
- if model_type == 'hd':
114
- prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to(self.gpu_id))[0]
115
- prompt_embeds[:, 1:] = prompt_image[:]
116
- elif model_type == 'dc':
117
- prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to(self.gpu_id))[0]
118
- prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
119
- else:
120
- raise ValueError("model_type must be \'hd\' or \'dc\'!")
121
-
122
- images = self.pipe(prompt_embeds=prompt_embeds,
123
- image_garm=image_garm,
124
- image_vton=image_vton,
125
- mask=mask,
126
- image_ori=image_ori,
127
- num_inference_steps=num_steps,
128
- image_guidance_scale=image_scale,
129
- num_images_per_prompt=num_samples,
130
- generator=generator,
131
- ).images
132
-
133
- return images
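
For reference, a minimal usage sketch of the deleted OOTDiffusion wrapper, based only on the class and __call__ signature shown above. The import path and image file names are placeholders, and it assumes the ootd package and the levihsu/ootd weights are still available locally.

from PIL import Image

from ootd.inference_ootd import OOTDiffusion  # import path assumed

model = OOTDiffusion(gpu_id=0)  # builds the fp16 OotdPipeline on cuda:0

# Inputs are PIL images; image_vton is the masked person image and
# mask marks the region to be repainted with the garment.
image_garm = Image.open("garment.jpg").convert("RGB")
image_vton = Image.open("person_masked.jpg").convert("RGB")
mask = Image.open("mask.png").convert("L")
image_ori = Image.open("person.jpg").convert("RGB")

images = model(
    model_type="hd",        # 'hd' uses an empty caption; 'dc' prepends the category text
    category="upperbody",   # only consulted when model_type == 'dc'
    image_garm=image_garm,
    image_vton=image_vton,
    mask=mask,
    image_ori=image_ori,
    num_samples=1,
    num_steps=20,
    image_scale=2.0,
    seed=42,
)
images[0].save("result.png")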
 
ootd/inference_ootd_dc.py DELETED
@@ -1,132 +0,0 @@
1
- import pdb
2
- from pathlib import Path
3
- import sys
4
- PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
5
- sys.path.insert(0, str(PROJECT_ROOT))
6
- import os
7
- import torch
8
- import numpy as np
9
- from PIL import Image
10
- import cv2
11
-
12
- import random
13
- import time
14
- import pdb
15
-
16
- from pipelines_ootd.pipeline_ootd import OotdPipeline
17
- from pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel
18
- from pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel
19
- from diffusers import UniPCMultistepScheduler
20
- from diffusers import AutoencoderKL
21
-
22
- import torch.nn as nn
23
- import torch.nn.functional as F
24
- from transformers import AutoProcessor, CLIPVisionModelWithProjection
25
- from transformers import CLIPTextModel, CLIPTokenizer
26
-
27
- VIT_PATH = "openai/clip-vit-large-patch14"
28
- VAE_PATH = "levihsu/ootd"
29
- UNET_PATH = "levihsu/ootd"
30
- MODEL_PATH = "levihsu/ootd"
31
-
32
- class OOTDiffusionDC:
33
-
34
- def __init__(self, gpu_id):
35
- self.gpu_id = 'cuda:' + str(gpu_id)
36
-
37
- vae = AutoencoderKL.from_pretrained(
38
- VAE_PATH,
39
- subfolder="vae",
40
- torch_dtype=torch.float16,
41
- )
42
-
43
- unet_garm = UNetGarm2DConditionModel.from_pretrained(
44
- UNET_PATH,
45
- subfolder="ootd_dc/checkpoint-36000/unet_garm",
46
- torch_dtype=torch.float16,
47
- use_safetensors=True,
48
- )
49
- unet_vton = UNetVton2DConditionModel.from_pretrained(
50
- UNET_PATH,
51
- subfolder="ootd_dc/checkpoint-36000/unet_vton",
52
- torch_dtype=torch.float16,
53
- use_safetensors=True,
54
- )
55
-
56
- self.pipe = OotdPipeline.from_pretrained(
57
- MODEL_PATH,
58
- unet_garm=unet_garm,
59
- unet_vton=unet_vton,
60
- vae=vae,
61
- torch_dtype=torch.float16,
62
- variant="fp16",
63
- use_safetensors=True,
64
- safety_checker=None,
65
- requires_safety_checker=False,
66
- ).to(self.gpu_id)
67
-
68
- self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
69
-
70
- self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
71
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH).to(self.gpu_id)
72
-
73
- self.tokenizer = CLIPTokenizer.from_pretrained(
74
- MODEL_PATH,
75
- subfolder="tokenizer",
76
- )
77
- self.text_encoder = CLIPTextModel.from_pretrained(
78
- MODEL_PATH,
79
- subfolder="text_encoder",
80
- ).to(self.gpu_id)
81
-
82
-
83
- def tokenize_captions(self, captions, max_length):
84
- inputs = self.tokenizer(
85
- captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
86
- )
87
- return inputs.input_ids
88
-
89
-
90
- def __call__(self,
91
- model_type='hd',
92
- category='upperbody',
93
- image_garm=None,
94
- image_vton=None,
95
- mask=None,
96
- image_ori=None,
97
- num_samples=1,
98
- num_steps=20,
99
- image_scale=1.0,
100
- seed=-1,
101
- ):
102
- if seed == -1:
103
- random.seed(time.time())
104
- seed = random.randint(0, 2147483647)
105
- print('Initial seed: ' + str(seed))
106
- generator = torch.manual_seed(seed)
107
-
108
- with torch.no_grad():
109
- prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to(self.gpu_id)
110
- prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
111
- prompt_image = prompt_image.unsqueeze(1)
112
- if model_type == 'hd':
113
- prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to(self.gpu_id))[0]
114
- prompt_embeds[:, 1:] = prompt_image[:]
115
- elif model_type == 'dc':
116
- prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to(self.gpu_id))[0]
117
- prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
118
- else:
119
- raise ValueError("model_type must be \'hd\' or \'dc\'!")
120
-
121
- images = self.pipe(prompt_embeds=prompt_embeds,
122
- image_garm=image_garm,
123
- image_vton=image_vton,
124
- mask=mask,
125
- image_ori=image_ori,
126
- num_inference_steps=num_steps,
127
- image_guidance_scale=image_scale,
128
- num_images_per_prompt=num_samples,
129
- generator=generator,
130
- ).images
131
-
132
- return images
 
ootd/inference_ootd_hd.py DELETED
@@ -1,132 +0,0 @@
1
- import pdb
2
- from pathlib import Path
3
- import sys
4
- PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
5
- sys.path.insert(0, str(PROJECT_ROOT))
6
- import os
7
- import torch
8
- import numpy as np
9
- from PIL import Image
10
- import cv2
11
-
12
- import random
13
- import time
14
- import pdb
15
-
16
- from pipelines_ootd.pipeline_ootd import OotdPipeline
17
- from pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel
18
- from pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel
19
- from diffusers import UniPCMultistepScheduler
20
- from diffusers import AutoencoderKL
21
-
22
- import torch.nn as nn
23
- import torch.nn.functional as F
24
- from transformers import AutoProcessor, CLIPVisionModelWithProjection
25
- from transformers import CLIPTextModel, CLIPTokenizer
26
-
27
- VIT_PATH = "openai/clip-vit-large-patch14"
28
- VAE_PATH = "levihsu/ootd"
29
- UNET_PATH = "levihsu/ootd"
30
- MODEL_PATH = "levihsu/ootd"
31
-
32
- class OOTDiffusionHD:
33
-
34
- def __init__(self, gpu_id):
35
- self.gpu_id = 'cuda:' + str(gpu_id)
36
-
37
- vae = AutoencoderKL.from_pretrained(
38
- VAE_PATH,
39
- subfolder="vae",
40
- torch_dtype=torch.float16,
41
- )
42
-
43
- unet_garm = UNetGarm2DConditionModel.from_pretrained(
44
- UNET_PATH,
45
- subfolder="ootd_hd/checkpoint-36000/unet_garm",
46
- torch_dtype=torch.float16,
47
- use_safetensors=True,
48
- )
49
- unet_vton = UNetVton2DConditionModel.from_pretrained(
50
- UNET_PATH,
51
- subfolder="ootd_hd/checkpoint-36000/unet_vton",
52
- torch_dtype=torch.float16,
53
- use_safetensors=True,
54
- )
55
-
56
- self.pipe = OotdPipeline.from_pretrained(
57
- MODEL_PATH,
58
- vae=vae,
59
- unet_garm=unet_garm,
60
- unet_vton=unet_vton,
61
- torch_dtype=torch.float16,
62
- variant="fp16",
63
- use_safetensors=True,
64
- safety_checker=None,
65
- requires_safety_checker=False,
66
- ).to(self.gpu_id)
67
-
68
- self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
69
-
70
- self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
71
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH).to(self.gpu_id)
72
-
73
- self.tokenizer = CLIPTokenizer.from_pretrained(
74
- MODEL_PATH,
75
- subfolder="tokenizer",
76
- )
77
- self.text_encoder = CLIPTextModel.from_pretrained(
78
- MODEL_PATH,
79
- subfolder="text_encoder",
80
- ).to(self.gpu_id)
81
-
82
-
83
- def tokenize_captions(self, captions, max_length):
84
- inputs = self.tokenizer(
85
- captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
86
- )
87
- return inputs.input_ids
88
-
89
-
90
- def __call__(self,
91
- model_type='hd',
92
- category='upperbody',
93
- image_garm=None,
94
- image_vton=None,
95
- mask=None,
96
- image_ori=None,
97
- num_samples=1,
98
- num_steps=20,
99
- image_scale=1.0,
100
- seed=-1,
101
- ):
102
- if seed == -1:
103
- random.seed(time.time())
104
- seed = random.randint(0, 2147483647)
105
- print('Initial seed: ' + str(seed))
106
- generator = torch.manual_seed(seed)
107
-
108
- with torch.no_grad():
109
- prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to(self.gpu_id)
110
- prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
111
- prompt_image = prompt_image.unsqueeze(1)
112
- if model_type == 'hd':
113
- prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to(self.gpu_id))[0]
114
- prompt_embeds[:, 1:] = prompt_image[:]
115
- elif model_type == 'dc':
116
- prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to(self.gpu_id))[0]
117
- prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
118
- else:
119
- raise ValueError("model_type must be \'hd\' or \'dc\'!")
120
-
121
- images = self.pipe(prompt_embeds=prompt_embeds,
122
- image_garm=image_garm,
123
- image_vton=image_vton,
124
- mask=mask,
125
- image_ori=image_ori,
126
- num_inference_steps=num_steps,
127
- image_guidance_scale=image_scale,
128
- num_images_per_prompt=num_samples,
129
- generator=generator,
130
- ).images
131
-
132
- return images
 
ootd/pipelines_ootd/attention_garm.py DELETED
@@ -1,402 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
- from typing import Any, Dict, Optional
17
-
18
- import torch
19
- from torch import nn
20
-
21
- from diffusers.utils import USE_PEFT_BACKEND
22
- from diffusers.utils.torch_utils import maybe_allow_in_graph
23
- from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
24
- from diffusers.models.attention_processor import Attention
25
- from diffusers.models.embeddings import SinusoidalPositionalEmbedding
26
- from diffusers.models.lora import LoRACompatibleLinear
27
- from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
28
-
29
-
30
- @maybe_allow_in_graph
31
- class GatedSelfAttentionDense(nn.Module):
32
- r"""
33
- A gated self-attention dense layer that combines visual features and object features.
34
-
35
- Parameters:
36
- query_dim (`int`): The number of channels in the query.
37
- context_dim (`int`): The number of channels in the context.
38
- n_heads (`int`): The number of heads to use for attention.
39
- d_head (`int`): The number of channels in each head.
40
- """
41
-
42
- def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
43
- super().__init__()
44
-
45
- # we need a linear projection since we need cat visual feature and obj feature
46
- self.linear = nn.Linear(context_dim, query_dim)
47
-
48
- self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
49
- self.ff = FeedForward(query_dim, activation_fn="geglu")
50
-
51
- self.norm1 = nn.LayerNorm(query_dim)
52
- self.norm2 = nn.LayerNorm(query_dim)
53
-
54
- self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
55
- self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
56
-
57
- self.enabled = True
58
-
59
- def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
60
- if not self.enabled:
61
- return x
62
-
63
- n_visual = x.shape[1]
64
- objs = self.linear(objs)
65
-
66
- x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
67
- x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
68
-
69
- return x
70
-
71
-
72
- @maybe_allow_in_graph
73
- class BasicTransformerBlock(nn.Module):
74
- r"""
75
- A basic Transformer block.
76
-
77
- Parameters:
78
- dim (`int`): The number of channels in the input and output.
79
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
80
- attention_head_dim (`int`): The number of channels in each head.
81
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
82
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
83
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
84
- num_embeds_ada_norm (:
85
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
86
- attention_bias (:
87
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
88
- only_cross_attention (`bool`, *optional*):
89
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
90
- double_self_attention (`bool`, *optional*):
91
- Whether to use two self-attention layers. In this case no cross attention layers are used.
92
- upcast_attention (`bool`, *optional*):
93
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
94
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
95
- Whether to use learnable elementwise affine parameters for normalization.
96
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
97
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
98
- final_dropout (`bool` *optional*, defaults to False):
99
- Whether to apply a final dropout after the last feed-forward layer.
100
- attention_type (`str`, *optional*, defaults to `"default"`):
101
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
102
- positional_embeddings (`str`, *optional*, defaults to `None`):
103
- The type of positional embeddings to apply to.
104
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
105
- The maximum number of positional embeddings to apply.
106
- """
107
-
108
- def __init__(
109
- self,
110
- dim: int,
111
- num_attention_heads: int,
112
- attention_head_dim: int,
113
- dropout=0.0,
114
- cross_attention_dim: Optional[int] = None,
115
- activation_fn: str = "geglu",
116
- num_embeds_ada_norm: Optional[int] = None,
117
- attention_bias: bool = False,
118
- only_cross_attention: bool = False,
119
- double_self_attention: bool = False,
120
- upcast_attention: bool = False,
121
- norm_elementwise_affine: bool = True,
122
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
123
- norm_eps: float = 1e-5,
124
- final_dropout: bool = False,
125
- attention_type: str = "default",
126
- positional_embeddings: Optional[str] = None,
127
- num_positional_embeddings: Optional[int] = None,
128
- ):
129
- super().__init__()
130
- self.only_cross_attention = only_cross_attention
131
-
132
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
133
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
134
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
135
- self.use_layer_norm = norm_type == "layer_norm"
136
-
137
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
138
- raise ValueError(
139
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
140
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
141
- )
142
-
143
- if positional_embeddings and (num_positional_embeddings is None):
144
- raise ValueError(
145
- "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
146
- )
147
-
148
- if positional_embeddings == "sinusoidal":
149
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
150
- else:
151
- self.pos_embed = None
152
-
153
- # Define 3 blocks. Each block has its own normalization layer.
154
- # 1. Self-Attn
155
- if self.use_ada_layer_norm:
156
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
157
- elif self.use_ada_layer_norm_zero:
158
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
159
- else:
160
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
161
-
162
- self.attn1 = Attention(
163
- query_dim=dim,
164
- heads=num_attention_heads,
165
- dim_head=attention_head_dim,
166
- dropout=dropout,
167
- bias=attention_bias,
168
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
169
- upcast_attention=upcast_attention,
170
- )
171
-
172
- # 2. Cross-Attn
173
- if cross_attention_dim is not None or double_self_attention:
174
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
175
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
176
- # the second cross attention block.
177
- self.norm2 = (
178
- AdaLayerNorm(dim, num_embeds_ada_norm)
179
- if self.use_ada_layer_norm
180
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
181
- )
182
- self.attn2 = Attention(
183
- query_dim=dim,
184
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
185
- heads=num_attention_heads,
186
- dim_head=attention_head_dim,
187
- dropout=dropout,
188
- bias=attention_bias,
189
- upcast_attention=upcast_attention,
190
- ) # is self-attn if encoder_hidden_states is none
191
- else:
192
- self.norm2 = None
193
- self.attn2 = None
194
-
195
- # 3. Feed-forward
196
- if not self.use_ada_layer_norm_single:
197
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
198
-
199
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
200
-
201
- # 4. Fuser
202
- if attention_type == "gated" or attention_type == "gated-text-image":
203
- self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
204
-
205
- # 5. Scale-shift for PixArt-Alpha.
206
- if self.use_ada_layer_norm_single:
207
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
208
-
209
- # let chunk size default to None
210
- self._chunk_size = None
211
- self._chunk_dim = 0
212
-
213
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
214
- # Sets chunk feed-forward
215
- self._chunk_size = chunk_size
216
- self._chunk_dim = dim
217
-
218
- def forward(
219
- self,
220
- hidden_states: torch.FloatTensor,
221
- spatial_attn_inputs = [],
222
- attention_mask: Optional[torch.FloatTensor] = None,
223
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
224
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
225
- timestep: Optional[torch.LongTensor] = None,
226
- cross_attention_kwargs: Dict[str, Any] = None,
227
- class_labels: Optional[torch.LongTensor] = None,
228
- ) -> torch.FloatTensor:
229
- # Notice that normalization is always applied before the real computation in the following blocks.
230
- # 0. Self-Attention
231
- batch_size = hidden_states.shape[0]
232
-
233
- spatial_attn_input = hidden_states
234
- spatial_attn_inputs.append(spatial_attn_input)
235
-
236
- if self.use_ada_layer_norm:
237
- norm_hidden_states = self.norm1(hidden_states, timestep)
238
- elif self.use_ada_layer_norm_zero:
239
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
240
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
241
- )
242
- elif self.use_layer_norm:
243
- norm_hidden_states = self.norm1(hidden_states)
244
- elif self.use_ada_layer_norm_single:
245
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
246
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
247
- ).chunk(6, dim=1)
248
- norm_hidden_states = self.norm1(hidden_states)
249
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
250
- norm_hidden_states = norm_hidden_states.squeeze(1)
251
- else:
252
- raise ValueError("Incorrect norm used")
253
-
254
- if self.pos_embed is not None:
255
- norm_hidden_states = self.pos_embed(norm_hidden_states)
256
-
257
- # 1. Retrieve lora scale.
258
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
259
-
260
- # 2. Prepare GLIGEN inputs
261
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
262
- gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
263
-
264
- attn_output = self.attn1(
265
- norm_hidden_states,
266
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
267
- attention_mask=attention_mask,
268
- **cross_attention_kwargs,
269
- )
270
- if self.use_ada_layer_norm_zero:
271
- attn_output = gate_msa.unsqueeze(1) * attn_output
272
- elif self.use_ada_layer_norm_single:
273
- attn_output = gate_msa * attn_output
274
-
275
- hidden_states = attn_output + hidden_states
276
- if hidden_states.ndim == 4:
277
- hidden_states = hidden_states.squeeze(1)
278
-
279
- # 2.5 GLIGEN Control
280
- if gligen_kwargs is not None:
281
- hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
282
-
283
- # 3. Cross-Attention
284
- if self.attn2 is not None:
285
- if self.use_ada_layer_norm:
286
- norm_hidden_states = self.norm2(hidden_states, timestep)
287
- elif self.use_ada_layer_norm_zero or self.use_layer_norm:
288
- norm_hidden_states = self.norm2(hidden_states)
289
- elif self.use_ada_layer_norm_single:
290
- # For PixArt norm2 isn't applied here:
291
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
292
- norm_hidden_states = hidden_states
293
- else:
294
- raise ValueError("Incorrect norm")
295
-
296
- if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
297
- norm_hidden_states = self.pos_embed(norm_hidden_states)
298
-
299
- attn_output = self.attn2(
300
- norm_hidden_states,
301
- encoder_hidden_states=encoder_hidden_states,
302
- attention_mask=encoder_attention_mask,
303
- **cross_attention_kwargs,
304
- )
305
- hidden_states = attn_output + hidden_states
306
-
307
- # 4. Feed-forward
308
- if not self.use_ada_layer_norm_single:
309
- norm_hidden_states = self.norm3(hidden_states)
310
-
311
- if self.use_ada_layer_norm_zero:
312
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
313
-
314
- if self.use_ada_layer_norm_single:
315
- norm_hidden_states = self.norm2(hidden_states)
316
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
317
-
318
- if self._chunk_size is not None:
319
- # "feed_forward_chunk_size" can be used to save memory
320
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
321
- raise ValueError(
322
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
323
- )
324
-
325
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
326
- ff_output = torch.cat(
327
- [
328
- self.ff(hid_slice, scale=lora_scale)
329
- for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
330
- ],
331
- dim=self._chunk_dim,
332
- )
333
- else:
334
- ff_output = self.ff(norm_hidden_states, scale=lora_scale)
335
-
336
- if self.use_ada_layer_norm_zero:
337
- ff_output = gate_mlp.unsqueeze(1) * ff_output
338
- elif self.use_ada_layer_norm_single:
339
- ff_output = gate_mlp * ff_output
340
-
341
- hidden_states = ff_output + hidden_states
342
- if hidden_states.ndim == 4:
343
- hidden_states = hidden_states.squeeze(1)
344
-
345
- return hidden_states, spatial_attn_inputs
346
-
347
-
348
- class FeedForward(nn.Module):
349
- r"""
350
- A feed-forward layer.
351
-
352
- Parameters:
353
- dim (`int`): The number of channels in the input.
354
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
355
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
356
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
357
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
358
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
359
- """
360
-
361
- def __init__(
362
- self,
363
- dim: int,
364
- dim_out: Optional[int] = None,
365
- mult: int = 4,
366
- dropout: float = 0.0,
367
- activation_fn: str = "geglu",
368
- final_dropout: bool = False,
369
- ):
370
- super().__init__()
371
- inner_dim = int(dim * mult)
372
- dim_out = dim_out if dim_out is not None else dim
373
- linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
374
-
375
- if activation_fn == "gelu":
376
- act_fn = GELU(dim, inner_dim)
377
- if activation_fn == "gelu-approximate":
378
- act_fn = GELU(dim, inner_dim, approximate="tanh")
379
- elif activation_fn == "geglu":
380
- act_fn = GEGLU(dim, inner_dim)
381
- elif activation_fn == "geglu-approximate":
382
- act_fn = ApproximateGELU(dim, inner_dim)
383
-
384
- self.net = nn.ModuleList([])
385
- # project in
386
- self.net.append(act_fn)
387
- # project dropout
388
- self.net.append(nn.Dropout(dropout))
389
- # project out
390
- self.net.append(linear_cls(inner_dim, dim_out))
391
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
392
- if final_dropout:
393
- self.net.append(nn.Dropout(dropout))
394
-
395
- def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
396
- compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
397
- for module in self.net:
398
- if isinstance(module, compatible_cls):
399
- hidden_states = module(hidden_states, scale)
400
- else:
401
- hidden_states = module(hidden_states)
402
- return hidden_states
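
As a quick sanity check, the FeedForward class above can be exercised on its own. This is only an illustrative sketch: the dim value is arbitrary, the import path is assumed, and it presumes a diffusers version that still provides the GEGLU/LoRACompatibleLinear imports used in this file.

import torch

from ootd.pipelines_ootd.attention_garm import FeedForward  # import path assumed

ff = FeedForward(dim=320, activation_fn="geglu")  # inner_dim = 4 * 320, dim_out = 320
x = torch.randn(1, 77, 320)
out = ff(x)                                       # GEGLU -> dropout -> linear
print(out.shape)                                  # torch.Size([1, 77, 320])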
 
ootd/pipelines_ootd/attention_vton.py DELETED
@@ -1,407 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
- from typing import Any, Dict, Optional
17
-
18
- import torch
19
- from torch import nn
20
-
21
- from diffusers.utils import USE_PEFT_BACKEND
22
- from diffusers.utils.torch_utils import maybe_allow_in_graph
23
- from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
24
- from diffusers.models.attention_processor import Attention
25
- from diffusers.models.embeddings import SinusoidalPositionalEmbedding
26
- from diffusers.models.lora import LoRACompatibleLinear
27
- from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
28
-
29
-
30
- @maybe_allow_in_graph
31
- class GatedSelfAttentionDense(nn.Module):
32
- r"""
33
- A gated self-attention dense layer that combines visual features and object features.
34
-
35
- Parameters:
36
- query_dim (`int`): The number of channels in the query.
37
- context_dim (`int`): The number of channels in the context.
38
- n_heads (`int`): The number of heads to use for attention.
39
- d_head (`int`): The number of channels in each head.
40
- """
41
-
42
- def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
43
- super().__init__()
44
-
45
- # we need a linear projection since we need cat visual feature and obj feature
46
- self.linear = nn.Linear(context_dim, query_dim)
47
-
48
- self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
49
- self.ff = FeedForward(query_dim, activation_fn="geglu")
50
-
51
- self.norm1 = nn.LayerNorm(query_dim)
52
- self.norm2 = nn.LayerNorm(query_dim)
53
-
54
- self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
55
- self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
56
-
57
- self.enabled = True
58
-
59
- def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
60
- if not self.enabled:
61
- return x
62
-
63
- n_visual = x.shape[1]
64
- objs = self.linear(objs)
65
-
66
- x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
67
- x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
68
-
69
- return x
70
-
71
-
72
- @maybe_allow_in_graph
73
- class BasicTransformerBlock(nn.Module):
74
- r"""
75
- A basic Transformer block.
76
-
77
- Parameters:
78
- dim (`int`): The number of channels in the input and output.
79
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
80
- attention_head_dim (`int`): The number of channels in each head.
81
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
82
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
83
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
84
- num_embeds_ada_norm (:
85
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
86
- attention_bias (:
87
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
88
- only_cross_attention (`bool`, *optional*):
89
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
90
- double_self_attention (`bool`, *optional*):
91
- Whether to use two self-attention layers. In this case no cross attention layers are used.
92
- upcast_attention (`bool`, *optional*):
93
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
94
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
95
- Whether to use learnable elementwise affine parameters for normalization.
96
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
97
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
98
- final_dropout (`bool` *optional*, defaults to False):
99
- Whether to apply a final dropout after the last feed-forward layer.
100
- attention_type (`str`, *optional*, defaults to `"default"`):
101
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
102
- positional_embeddings (`str`, *optional*, defaults to `None`):
103
- The type of positional embeddings to apply to.
104
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
105
- The maximum number of positional embeddings to apply.
106
- """
107
-
108
- def __init__(
109
- self,
110
- dim: int,
111
- num_attention_heads: int,
112
- attention_head_dim: int,
113
- dropout=0.0,
114
- cross_attention_dim: Optional[int] = None,
115
- activation_fn: str = "geglu",
116
- num_embeds_ada_norm: Optional[int] = None,
117
- attention_bias: bool = False,
118
- only_cross_attention: bool = False,
119
- double_self_attention: bool = False,
120
- upcast_attention: bool = False,
121
- norm_elementwise_affine: bool = True,
122
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
123
- norm_eps: float = 1e-5,
124
- final_dropout: bool = False,
125
- attention_type: str = "default",
126
- positional_embeddings: Optional[str] = None,
127
- num_positional_embeddings: Optional[int] = None,
128
- ):
129
- super().__init__()
130
- self.only_cross_attention = only_cross_attention
131
-
132
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
133
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
134
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
135
- self.use_layer_norm = norm_type == "layer_norm"
136
-
137
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
138
- raise ValueError(
139
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
140
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
141
- )
142
-
143
- if positional_embeddings and (num_positional_embeddings is None):
144
- raise ValueError(
145
- "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
146
- )
147
-
148
- if positional_embeddings == "sinusoidal":
149
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
150
- else:
151
- self.pos_embed = None
152
-
153
- # Define 3 blocks. Each block has its own normalization layer.
154
- # 1. Self-Attn
155
- if self.use_ada_layer_norm:
156
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
157
- elif self.use_ada_layer_norm_zero:
158
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
159
- else:
160
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
161
-
162
- self.attn1 = Attention(
163
- query_dim=dim,
164
- heads=num_attention_heads,
165
- dim_head=attention_head_dim,
166
- dropout=dropout,
167
- bias=attention_bias,
168
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
169
- upcast_attention=upcast_attention,
170
- )
171
-
172
- # 2. Cross-Attn
173
- if cross_attention_dim is not None or double_self_attention:
174
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
175
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
176
- # the second cross attention block.
177
- self.norm2 = (
178
- AdaLayerNorm(dim, num_embeds_ada_norm)
179
- if self.use_ada_layer_norm
180
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
181
- )
182
- self.attn2 = Attention(
183
- query_dim=dim,
184
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
185
- heads=num_attention_heads,
186
- dim_head=attention_head_dim,
187
- dropout=dropout,
188
- bias=attention_bias,
189
- upcast_attention=upcast_attention,
190
- ) # is self-attn if encoder_hidden_states is none
191
- else:
192
- self.norm2 = None
193
- self.attn2 = None
194
-
195
- # 3. Feed-forward
196
- if not self.use_ada_layer_norm_single:
197
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
198
-
199
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
200
-
201
- # 4. Fuser
202
- if attention_type == "gated" or attention_type == "gated-text-image":
203
- self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
204
-
205
- # 5. Scale-shift for PixArt-Alpha.
206
- if self.use_ada_layer_norm_single:
207
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
208
-
209
- # let chunk size default to None
210
- self._chunk_size = None
211
- self._chunk_dim = 0
212
-
213
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
214
- # Sets chunk feed-forward
215
- self._chunk_size = chunk_size
216
- self._chunk_dim = dim
217
-
218
- def forward(
219
- self,
220
- hidden_states: torch.FloatTensor,
221
- spatial_attn_inputs = [],
222
- spatial_attn_idx = 0,
223
- attention_mask: Optional[torch.FloatTensor] = None,
224
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
225
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
226
- timestep: Optional[torch.LongTensor] = None,
227
- cross_attention_kwargs: Dict[str, Any] = None,
228
- class_labels: Optional[torch.LongTensor] = None,
229
- ) -> torch.FloatTensor:
230
- # Notice that normalization is always applied before the real computation in the following blocks.
231
- # 0. Self-Attention
232
- batch_size = hidden_states.shape[0]
233
-
234
- spatial_attn_input = spatial_attn_inputs[spatial_attn_idx]
235
- spatial_attn_idx += 1
236
- hidden_states = torch.cat((hidden_states, spatial_attn_input), dim=1)
237
-
238
- if self.use_ada_layer_norm:
239
- norm_hidden_states = self.norm1(hidden_states, timestep)
240
- elif self.use_ada_layer_norm_zero:
241
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
242
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
243
- )
244
- elif self.use_layer_norm:
245
- norm_hidden_states = self.norm1(hidden_states)
246
- elif self.use_ada_layer_norm_single:
247
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
248
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
249
- ).chunk(6, dim=1)
250
- norm_hidden_states = self.norm1(hidden_states)
251
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
252
- norm_hidden_states = norm_hidden_states.squeeze(1)
253
- else:
254
- raise ValueError("Incorrect norm used")
255
-
256
- if self.pos_embed is not None:
257
- norm_hidden_states = self.pos_embed(norm_hidden_states)
258
-
259
- # 1. Retrieve lora scale.
260
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
261
-
262
- # 2. Prepare GLIGEN inputs
263
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
264
- gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
265
-
266
- attn_output = self.attn1(
267
- norm_hidden_states,
268
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
269
- attention_mask=attention_mask,
270
- **cross_attention_kwargs,
271
- )
272
- if self.use_ada_layer_norm_zero:
273
- attn_output = gate_msa.unsqueeze(1) * attn_output
274
- elif self.use_ada_layer_norm_single:
275
- attn_output = gate_msa * attn_output
276
-
277
-
278
- hidden_states = attn_output + hidden_states
279
- hidden_states, _ = hidden_states.chunk(2, dim=1)
280
-
281
- if hidden_states.ndim == 4:
282
- hidden_states = hidden_states.squeeze(1)
283
-
284
- # 2.5 GLIGEN Control
285
- if gligen_kwargs is not None:
286
- hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
287
-
288
- # 3. Cross-Attention
289
- if self.attn2 is not None:
290
- if self.use_ada_layer_norm:
291
- norm_hidden_states = self.norm2(hidden_states, timestep)
292
- elif self.use_ada_layer_norm_zero or self.use_layer_norm:
293
- norm_hidden_states = self.norm2(hidden_states)
294
- elif self.use_ada_layer_norm_single:
295
- # For PixArt norm2 isn't applied here:
296
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
297
- norm_hidden_states = hidden_states
298
- else:
299
- raise ValueError("Incorrect norm")
300
-
301
- if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
302
- norm_hidden_states = self.pos_embed(norm_hidden_states)
303
-
304
- attn_output = self.attn2(
305
- norm_hidden_states,
306
- encoder_hidden_states=encoder_hidden_states,
307
- attention_mask=encoder_attention_mask,
308
- **cross_attention_kwargs,
309
- )
310
- hidden_states = attn_output + hidden_states
311
-
312
- # 4. Feed-forward
313
- if not self.use_ada_layer_norm_single:
314
- norm_hidden_states = self.norm3(hidden_states)
315
-
316
- if self.use_ada_layer_norm_zero:
317
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
318
-
319
- if self.use_ada_layer_norm_single:
320
- norm_hidden_states = self.norm2(hidden_states)
321
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
322
-
323
- if self._chunk_size is not None:
324
- # "feed_forward_chunk_size" can be used to save memory
325
- if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
326
- raise ValueError(
327
- f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
328
- )
329
-
330
- num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
331
- ff_output = torch.cat(
332
- [
333
- self.ff(hid_slice, scale=lora_scale)
334
- for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
335
- ],
336
- dim=self._chunk_dim,
337
- )
338
- else:
339
- ff_output = self.ff(norm_hidden_states, scale=lora_scale)
340
-
341
- if self.use_ada_layer_norm_zero:
342
- ff_output = gate_mlp.unsqueeze(1) * ff_output
343
- elif self.use_ada_layer_norm_single:
344
- ff_output = gate_mlp * ff_output
345
-
346
- hidden_states = ff_output + hidden_states
347
- if hidden_states.ndim == 4:
348
- hidden_states = hidden_states.squeeze(1)
349
-
350
- return hidden_states, spatial_attn_inputs, spatial_attn_idx
351
-
352
-
353
- class FeedForward(nn.Module):
354
- r"""
355
- A feed-forward layer.
356
-
357
- Parameters:
358
- dim (`int`): The number of channels in the input.
359
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
360
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
361
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
362
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
363
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
364
- """
365
-
366
- def __init__(
367
- self,
368
- dim: int,
369
- dim_out: Optional[int] = None,
370
- mult: int = 4,
371
- dropout: float = 0.0,
372
- activation_fn: str = "geglu",
373
- final_dropout: bool = False,
374
- ):
375
- super().__init__()
376
- inner_dim = int(dim * mult)
377
- dim_out = dim_out if dim_out is not None else dim
378
- linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
379
-
380
- if activation_fn == "gelu":
381
- act_fn = GELU(dim, inner_dim)
382
- if activation_fn == "gelu-approximate":
383
- act_fn = GELU(dim, inner_dim, approximate="tanh")
384
- elif activation_fn == "geglu":
385
- act_fn = GEGLU(dim, inner_dim)
386
- elif activation_fn == "geglu-approximate":
387
- act_fn = ApproximateGELU(dim, inner_dim)
388
-
389
- self.net = nn.ModuleList([])
390
- # project in
391
- self.net.append(act_fn)
392
- # project dropout
393
- self.net.append(nn.Dropout(dropout))
394
- # project out
395
- self.net.append(linear_cls(inner_dim, dim_out))
396
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
397
- if final_dropout:
398
- self.net.append(nn.Dropout(dropout))
399
-
400
- def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
401
- compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
402
- for module in self.net:
403
- if isinstance(module, compatible_cls):
404
- hidden_states = module(hidden_states, scale)
405
- else:
406
- hidden_states = module(hidden_states)
407
- return hidden_states
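
The pairing between the two attention variants is easiest to see in isolation: the garment block (attention_garm.py) appends its hidden states to spatial_attn_inputs, and the matching try-on block above reads them back by index, concatenates them along the token axis before self-attention, and then keeps only its own half. A toy sketch of that handoff, with a plain tensor standing in for the attention call:

import torch

batch, seq, dim = 2, 8, 16

# Garment branch: each BasicTransformerBlock stores its hidden states.
spatial_attn_inputs = []
garm_hidden = torch.randn(batch, seq, dim)
spatial_attn_inputs.append(garm_hidden)

# Try-on branch: the matching block consumes them via a running index.
spatial_attn_idx = 0
vton_hidden = torch.randn(batch, seq, dim)
garm_feat = spatial_attn_inputs[spatial_attn_idx]
spatial_attn_idx += 1

joint = torch.cat((vton_hidden, garm_feat), dim=1)   # (batch, 2*seq, dim)
attn_output = joint                                  # stand-in for self.attn1(norm(joint))
joint = attn_output + joint
vton_hidden, _ = joint.chunk(2, dim=1)               # drop the garment half
print(vton_hidden.shape)                             # torch.Size([2, 8, 16])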
 
ootd/pipelines_ootd/pipeline_ootd.py DELETED
@@ -1,846 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
- import inspect
17
- from typing import Any, Callable, Dict, List, Optional, Union
18
-
19
- import numpy as np
20
- import PIL.Image
21
- import torch
22
- from packaging import version
23
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
24
-
25
- from transformers import AutoProcessor, CLIPVisionModelWithProjection
26
-
27
- from .unet_vton_2d_condition import UNetVton2DConditionModel
28
- from .unet_garm_2d_condition import UNetGarm2DConditionModel
29
-
30
- from diffusers.configuration_utils import FrozenDict
31
- from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
32
- from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
33
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
34
- from diffusers.models.lora import adjust_lora_scale_text_encoder
35
- from diffusers.schedulers import KarrasDiffusionSchedulers
36
- from diffusers.utils import (
37
- PIL_INTERPOLATION,
38
- USE_PEFT_BACKEND,
39
- deprecate,
40
- logging,
41
- replace_example_docstring,
42
- scale_lora_layers,
43
- unscale_lora_layers,
44
- )
45
- from diffusers.utils.torch_utils import randn_tensor
46
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
47
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
48
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
49
-
50
-
51
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
-
53
-
54
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
55
- def preprocess(image):
56
- deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
57
- deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
58
- if isinstance(image, torch.Tensor):
59
- return image
60
- elif isinstance(image, PIL.Image.Image):
61
- image = [image]
62
-
63
- if isinstance(image[0], PIL.Image.Image):
64
- w, h = image[0].size
65
- w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
66
-
67
- image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
68
- image = np.concatenate(image, axis=0)
69
- image = np.array(image).astype(np.float32) / 255.0
70
- image = image.transpose(0, 3, 1, 2)
71
- image = 2.0 * image - 1.0
72
- image = torch.from_numpy(image)
73
- elif isinstance(image[0], torch.Tensor):
74
- image = torch.cat(image, dim=0)
75
- return image
76
-
77
-
78
- class OotdPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
79
- r"""
80
- Args:
81
- vae ([`AutoencoderKL`]):
82
- Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
83
- text_encoder ([`~transformers.CLIPTextModel`]):
84
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
85
- tokenizer ([`~transformers.CLIPTokenizer`]):
86
- A `CLIPTokenizer` to tokenize text.
87
- unet ([`UNet2DConditionModel`]):
88
- A `UNet2DConditionModel` to denoise the encoded image latents.
89
- scheduler ([`SchedulerMixin`]):
90
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
91
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
92
- safety_checker ([`StableDiffusionSafetyChecker`]):
93
- Classification module that estimates whether generated images could be considered offensive or harmful.
94
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
95
- about a model's potential harms.
96
- feature_extractor ([`~transformers.CLIPImageProcessor`]):
97
- A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
98
- """
99
- model_cpu_offload_seq = "text_encoder->unet->vae"
100
- _optional_components = ["safety_checker", "feature_extractor"]
101
- _exclude_from_cpu_offload = ["safety_checker"]
102
- _callback_tensor_inputs = ["latents", "prompt_embeds", "vton_latents"]
103
-
104
- def __init__(
105
- self,
106
- vae: AutoencoderKL,
107
- text_encoder: CLIPTextModel,
108
- tokenizer: CLIPTokenizer,
109
- unet_garm: UNetGarm2DConditionModel,
110
- unet_vton: UNetVton2DConditionModel,
111
- scheduler: KarrasDiffusionSchedulers,
112
- safety_checker: StableDiffusionSafetyChecker,
113
- feature_extractor: CLIPImageProcessor,
114
- requires_safety_checker: bool = True,
115
- ):
116
- super().__init__()
117
-
118
- if safety_checker is None and requires_safety_checker:
119
- logger.warning(
120
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
121
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
122
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
123
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
124
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
125
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
126
- )
127
-
128
- if safety_checker is not None and feature_extractor is None:
129
- raise ValueError(
130
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
131
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
132
- )
133
-
134
- self.register_modules(
135
- vae=vae,
136
- text_encoder=text_encoder,
137
- tokenizer=tokenizer,
138
- unet_garm=unet_garm,
139
- unet_vton=unet_vton,
140
- scheduler=scheduler,
141
- safety_checker=safety_checker,
142
- feature_extractor=feature_extractor,
143
- )
144
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
145
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
146
- self.register_to_config(requires_safety_checker=requires_safety_checker)
147
-
148
- @torch.no_grad()
149
- def __call__(
150
- self,
151
- prompt: Union[str, List[str]] = None,
152
- image_garm: PipelineImageInput = None,
153
- image_vton: PipelineImageInput = None,
154
- mask: PipelineImageInput = None,
155
- image_ori: PipelineImageInput = None,
156
- num_inference_steps: int = 100,
157
- guidance_scale: float = 7.5,
158
- image_guidance_scale: float = 1.5,
159
- negative_prompt: Optional[Union[str, List[str]]] = None,
160
- num_images_per_prompt: Optional[int] = 1,
161
- eta: float = 0.0,
162
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
163
- latents: Optional[torch.FloatTensor] = None,
164
- prompt_embeds: Optional[torch.FloatTensor] = None,
165
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
166
- output_type: Optional[str] = "pil",
167
- return_dict: bool = True,
168
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
169
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
170
- **kwargs,
171
- ):
172
- r"""
173
- The call function to the pipeline for generation.
174
-
175
- Args:
176
- prompt (`str` or `List[str]`, *optional*):
177
- The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
178
- image_garm / image_vton / image_ori (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
179
- `Image` or tensor inputs: the garment image, the masked try-on image to be repainted according to `prompt`, and the
180
- original person image. Image latents can also be passed directly; if so, they are not encoded again.
181
- num_inference_steps (`int`, *optional*, defaults to 100):
182
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
183
- expense of slower inference.
184
- guidance_scale (`float`, *optional*, defaults to 7.5):
185
- A higher guidance scale value encourages the model to generate images closely linked to the text
186
- `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
187
- image_guidance_scale (`float`, *optional*, defaults to 1.5):
188
- Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
189
- `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
190
- linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
191
- value of at least `1`.
192
- negative_prompt (`str` or `List[str]`, *optional*):
193
- The prompt or prompts to guide what not to include in image generation. If not defined, you need to
194
- pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
195
- num_images_per_prompt (`int`, *optional*, defaults to 1):
196
- The number of images to generate per prompt.
197
- eta (`float`, *optional*, defaults to 0.0):
198
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
199
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
200
- generator (`torch.Generator`, *optional*):
201
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
202
- generation deterministic.
203
- latents (`torch.FloatTensor`, *optional*):
204
- Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
205
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
206
- tensor is generated by sampling using the supplied random `generator`.
207
- prompt_embeds (`torch.FloatTensor`, *optional*):
208
- Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
209
- provided, text embeddings are generated from the `prompt` input argument.
210
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
211
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
212
- not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
213
- output_type (`str`, *optional*, defaults to `"pil"`):
214
- The output format of the generated image. Choose between `PIL.Image` or `np.array`.
215
- return_dict (`bool`, *optional*, defaults to `True`):
216
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
217
- plain tuple.
218
- callback_on_step_end (`Callable`, *optional*):
219
- A function invoked at the end of each denoising step during inference. The function is called
220
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
221
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
222
- `callback_on_step_end_tensor_inputs`.
223
- callback_on_step_end_tensor_inputs (`List`, *optional*):
224
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
225
- will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
226
- `._callback_tensor_inputs` attribute of your pipeline class.
227
-
228
- Returns:
229
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
230
- If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
231
- otherwise a `tuple` is returned where the first element is a list with the generated images and the
232
- second element is a list of `bool`s indicating whether the corresponding generated image contains
233
- "not-safe-for-work" (nsfw) content.
234
- """
235
-
236
- callback = kwargs.pop("callback", None)
237
- callback_steps = kwargs.pop("callback_steps", None)
238
-
239
- if callback is not None:
240
- deprecate(
241
- "callback",
242
- "1.0.0",
243
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
244
- )
245
- if callback_steps is not None:
246
- deprecate(
247
- "callback_steps",
248
- "1.0.0",
249
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
250
- )
251
-
252
- # 0. Check inputs
253
- self.check_inputs(
254
- prompt,
255
- callback_steps,
256
- negative_prompt,
257
- prompt_embeds,
258
- negative_prompt_embeds,
259
- callback_on_step_end_tensor_inputs,
260
- )
261
- self._guidance_scale = guidance_scale
262
- self._image_guidance_scale = image_guidance_scale
263
-
264
- if (image_vton is None) or (image_garm is None):
265
- raise ValueError("`image_garm` and `image_vton` inputs cannot be undefined.")
266
-
267
- # 1. Define call parameters
268
- if prompt is not None and isinstance(prompt, str):
269
- batch_size = 1
270
- elif prompt is not None and isinstance(prompt, list):
271
- batch_size = len(prompt)
272
- else:
273
- batch_size = prompt_embeds.shape[0]
274
-
275
- device = self._execution_device
276
- # check if scheduler is in sigmas space
277
- scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
278
-
279
- # 2. Encode input prompt
280
- prompt_embeds = self._encode_prompt(
281
- prompt,
282
- device,
283
- num_images_per_prompt,
284
- self.do_classifier_free_guidance,
285
- negative_prompt,
286
- prompt_embeds=prompt_embeds,
287
- negative_prompt_embeds=negative_prompt_embeds,
288
- )
289
-
290
- # 3. Preprocess image
291
- image_garm = self.image_processor.preprocess(image_garm)
292
- image_vton = self.image_processor.preprocess(image_vton)
293
- image_ori = self.image_processor.preprocess(image_ori)
294
- mask = np.array(mask)
295
- mask[mask < 127] = 0
296
- mask[mask >= 127] = 255
297
- mask = torch.tensor(mask)
298
- mask = mask / 255
299
- mask = mask.reshape(-1, 1, mask.size(-2), mask.size(-1))
300
-
301
- # 4. set timesteps
302
- self.scheduler.set_timesteps(num_inference_steps, device=device)
303
- timesteps = self.scheduler.timesteps
304
-
305
- # 5. Prepare Image latents
306
- garm_latents = self.prepare_garm_latents(
307
- image_garm,
308
- batch_size,
309
- num_images_per_prompt,
310
- prompt_embeds.dtype,
311
- device,
312
- self.do_classifier_free_guidance,
313
- generator,
314
- )
315
-
316
- vton_latents, mask_latents, image_ori_latents = self.prepare_vton_latents(
317
- image_vton,
318
- mask,
319
- image_ori,
320
- batch_size,
321
- num_images_per_prompt,
322
- prompt_embeds.dtype,
323
- device,
324
- self.do_classifier_free_guidance,
325
- generator,
326
- )
327
-
328
- height, width = vton_latents.shape[-2:]
329
- height = height * self.vae_scale_factor
330
- width = width * self.vae_scale_factor
331
-
332
- # 6. Prepare latent variables
333
- num_channels_latents = self.vae.config.latent_channels
334
- latents = self.prepare_latents(
335
- batch_size * num_images_per_prompt,
336
- num_channels_latents,
337
- height,
338
- width,
339
- prompt_embeds.dtype,
340
- device,
341
- generator,
342
- latents,
343
- )
344
-
345
- noise = latents.clone()
346
-
347
- # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
348
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
349
-
350
- # 9. Denoising loop
351
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
352
- self._num_timesteps = len(timesteps)
353
-
354
- _, spatial_attn_outputs = self.unet_garm(
355
- garm_latents,
356
- 0,
357
- encoder_hidden_states=prompt_embeds,
358
- return_dict=False,
359
- )
360
-
361
- with self.progress_bar(total=num_inference_steps) as progress_bar:
362
- for i, t in enumerate(timesteps):
363
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
364
-
365
- # concat latents, image_latents in the channel dimension
366
- scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
367
- latent_vton_model_input = torch.cat([scaled_latent_model_input, vton_latents], dim=1)
368
- # latent_vton_model_input = scaled_latent_model_input + vton_latents
369
-
370
- spatial_attn_inputs = spatial_attn_outputs.copy()
371
-
372
- # predict the noise residual
373
- noise_pred = self.unet_vton(
374
- latent_vton_model_input,
375
- spatial_attn_inputs,
376
- t,
377
- encoder_hidden_states=prompt_embeds,
378
- return_dict=False,
379
- )[0]
380
-
381
- # Hack:
382
- # For Karras-style schedulers the model does classifier-free guidance using the
383
- # predicted_original_sample instead of the noise_pred. So we need to compute the
384
- # predicted_original_sample here if we are using a karras style scheduler.
385
- if scheduler_is_in_sigma_space:
386
- step_index = (self.scheduler.timesteps == t).nonzero()[0].item()
387
- sigma = self.scheduler.sigmas[step_index]
388
- noise_pred = latent_model_input - sigma * noise_pred
389
-
390
- # perform guidance
391
- if self.do_classifier_free_guidance:
392
- noise_pred_text_image, noise_pred_text = noise_pred.chunk(2)
393
- noise_pred = (
394
- noise_pred_text
395
- + self.image_guidance_scale * (noise_pred_text_image - noise_pred_text)
396
- )
397
-
398
- # Hack:
399
- # For Karras-style schedulers the model does classifier-free guidance using the
400
- # predicted_original_sample instead of the noise_pred. But the scheduler.step function
401
- # expects the noise_pred and computes the predicted_original_sample internally. So we
402
- # need to overwrite the noise_pred here such that the value of the computed
403
- # predicted_original_sample is correct.
404
- if scheduler_is_in_sigma_space:
405
- noise_pred = (noise_pred - latents) / (-sigma)
406
-
407
- # compute the previous noisy sample x_t -> x_t-1
408
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
409
-
410
- init_latents_proper = image_ori_latents * self.vae.config.scaling_factor
411
-
412
- # repainting
413
- if i < len(timesteps) - 1:
414
- noise_timestep = timesteps[i + 1]
415
- init_latents_proper = self.scheduler.add_noise(
416
- init_latents_proper, noise, torch.tensor([noise_timestep])
417
- )
418
-
419
- latents = (1 - mask_latents) * init_latents_proper + mask_latents * latents
420
-
421
- if callback_on_step_end is not None:
422
- callback_kwargs = {}
423
- for k in callback_on_step_end_tensor_inputs:
424
- callback_kwargs[k] = locals()[k]
425
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
426
-
427
- latents = callback_outputs.pop("latents", latents)
428
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
429
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
430
- vton_latents = callback_outputs.pop("vton_latents", vton_latents)
431
-
432
- # call the callback, if provided
433
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
434
- progress_bar.update()
435
- if callback is not None and i % callback_steps == 0:
436
- step_idx = i // getattr(self.scheduler, "order", 1)
437
- callback(step_idx, t, latents)
438
-
439
- if not output_type == "latent":
440
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
441
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
442
- else:
443
- image = latents
444
- has_nsfw_concept = None
445
-
446
- if has_nsfw_concept is None:
447
- do_denormalize = [True] * image.shape[0]
448
- else:
449
- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
450
-
451
- image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
452
-
453
- # Offload all models
454
- self.maybe_free_model_hooks()
455
-
456
- if not return_dict:
457
- return (image, has_nsfw_concept)
458
-
459
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
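A minimal usage sketch of the call signature above, assuming `pipe` is an already constructed instance of this pipeline class (construction elided); the image file names and parameter values are illustrative placeholders, not values defined by this file:

import torch
from PIL import Image

# `pipe` is assumed to be an instance of the pipeline defined in this file,
# already moved to the execution device; how it is built is elided here.
garment = Image.open("garment.jpg").convert("RGB")               # garment product shot
masked_person = Image.open("person_masked.jpg").convert("RGB")   # try-on area blanked out
mask = Image.open("mask.png").convert("L")                       # 0/255 try-on mask
person = Image.open("person.jpg").convert("RGB")                 # original person image

result = pipe(
    prompt="",                        # conditioning is mostly image-driven
    image_garm=garment,
    image_vton=masked_person,
    mask=mask,
    image_ori=person,
    num_inference_steps=20,
    image_guidance_scale=2.0,         # > 1 enables classifier-free (image) guidance
    generator=torch.manual_seed(0),   # reproducible sampling
)
result.images[0].save("tryon_result.png")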
460
-
461
- def _encode_prompt(
462
- self,
463
- prompt,
464
- device,
465
- num_images_per_prompt,
466
- do_classifier_free_guidance,
467
- negative_prompt=None,
468
- prompt_embeds: Optional[torch.FloatTensor] = None,
469
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
470
- ):
471
- r"""
472
- Encodes the prompt into text encoder hidden states.
473
-
474
- Args:
475
- prompt (`str` or `List[str]`, *optional*):
476
- prompt to be encoded
477
- device (`torch.device`):
478
- torch device
479
- num_images_per_prompt (`int`):
480
- number of images that should be generated per prompt
481
- do_classifier_free_guidance (`bool`):
482
- whether to use classifier free guidance or not
483
- negative_prompt (`str` or `List[str]`, *optional*):
484
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
485
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
486
- less than `1`).
487
- prompt_embeds (`torch.FloatTensor`, *optional*):
488
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
489
- provided, text embeddings will be generated from `prompt` input argument.
490
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
491
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
492
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
493
- argument.
494
- """
495
- if prompt is not None and isinstance(prompt, str):
496
- batch_size = 1
497
- elif prompt is not None and isinstance(prompt, list):
498
- batch_size = len(prompt)
499
- else:
500
- batch_size = prompt_embeds.shape[0]
501
-
502
- if prompt_embeds is None:
503
- # textual inversion: process multi-vector tokens if necessary
504
- if isinstance(self, TextualInversionLoaderMixin):
505
- prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
506
-
507
- text_inputs = self.tokenizer(
508
- prompt,
509
- padding="max_length",
510
- max_length=self.tokenizer.model_max_length,
511
- truncation=True,
512
- return_tensors="pt",
513
- )
514
- text_input_ids = text_inputs.input_ids
515
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
516
-
517
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
518
- text_input_ids, untruncated_ids
519
- ):
520
- removed_text = self.tokenizer.batch_decode(
521
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
522
- )
523
- logger.warning(
524
- "The following part of your input was truncated because CLIP can only handle sequences up to"
525
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
526
- )
527
-
528
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
529
- attention_mask = text_inputs.attention_mask.to(device)
530
- else:
531
- attention_mask = None
532
-
533
- prompt_embeds = self.text_encoder(
534
- text_input_ids.to(device),
535
- attention_mask=attention_mask,
536
- )
537
- prompt_embeds = prompt_embeds[0]
538
-
539
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
540
-
541
- bs_embed, seq_len, _ = prompt_embeds.shape
542
- # duplicate text embeddings for each generation per prompt, using mps friendly method
543
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
544
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
545
-
546
- # get unconditional embeddings for classifier free guidance
547
- if do_classifier_free_guidance and negative_prompt_embeds is None:
548
- uncond_tokens: List[str]
549
- if negative_prompt is None:
550
- uncond_tokens = [""] * batch_size
551
- elif type(prompt) is not type(negative_prompt):
552
- raise TypeError(
553
- f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
554
- f" {type(prompt)}."
555
- )
556
- elif isinstance(negative_prompt, str):
557
- uncond_tokens = [negative_prompt]
558
- elif batch_size != len(negative_prompt):
559
- raise ValueError(
560
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
561
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
562
- " the batch size of `prompt`."
563
- )
564
- else:
565
- uncond_tokens = negative_prompt
566
-
567
- # textual inversion: process multi-vector tokens if necessary
568
- if isinstance(self, TextualInversionLoaderMixin):
569
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
570
-
571
- max_length = prompt_embeds.shape[1]
572
- uncond_input = self.tokenizer(
573
- uncond_tokens,
574
- padding="max_length",
575
- max_length=max_length,
576
- truncation=True,
577
- return_tensors="pt",
578
- )
579
-
580
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
581
- attention_mask = uncond_input.attention_mask.to(device)
582
- else:
583
- attention_mask = None
584
-
585
- if do_classifier_free_guidance:
586
- prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
587
-
588
- return prompt_embeds
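The positive-prompt branch of `_encode_prompt` is standard CLIP text encoding with fixed-length padding followed by per-prompt duplication. A standalone sketch of that pattern (the public CLIP checkpoint is an assumption, not something fixed by this file):

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompts = ["a photo of a dress"]
inputs = tokenizer(
    prompts,
    padding="max_length",
    max_length=tokenizer.model_max_length,  # 77 for CLIP
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    embeds = text_encoder(inputs.input_ids)[0]          # (batch, 77, 768)

num_images_per_prompt = 2
embeds = embeds.repeat(1, num_images_per_prompt, 1)     # mps-friendly duplication
embeds = embeds.view(len(prompts) * num_images_per_prompt, inputs.input_ids.shape[-1], -1)
print(embeds.shape)                                     # torch.Size([2, 77, 768])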
589
-
590
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
591
- def run_safety_checker(self, image, device, dtype):
592
- if self.safety_checker is None:
593
- has_nsfw_concept = None
594
- else:
595
- if torch.is_tensor(image):
596
- feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
597
- else:
598
- feature_extractor_input = self.image_processor.numpy_to_pil(image)
599
- safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
600
- image, has_nsfw_concept = self.safety_checker(
601
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
602
- )
603
- return image, has_nsfw_concept
604
-
605
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
606
- def prepare_extra_step_kwargs(self, generator, eta):
607
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
608
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
609
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
610
- # and should be between [0, 1]
611
-
612
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
613
- extra_step_kwargs = {}
614
- if accepts_eta:
615
- extra_step_kwargs["eta"] = eta
616
-
617
- # check if the scheduler accepts generator
618
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
619
- if accepts_generator:
620
- extra_step_kwargs["generator"] = generator
621
- return extra_step_kwargs
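The signature inspection above is a general Python pattern for forwarding optional kwargs only to schedulers whose `step` accepts them; a tiny self-contained illustration:

import inspect

def step(sample, timestep, generator=None):   # toy stand-in for scheduler.step
    return sample

accepted = set(inspect.signature(step).parameters)
extra_step_kwargs = {}
if "eta" in accepted:
    extra_step_kwargs["eta"] = 0.0            # not accepted here, so skipped
if "generator" in accepted:
    extra_step_kwargs["generator"] = None     # accepted, so forwarded
print(extra_step_kwargs)                      # {'generator': None}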
622
-
623
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
624
- def decode_latents(self, latents):
625
- deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
626
- deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
627
-
628
- latents = 1 / self.vae.config.scaling_factor * latents
629
- image = self.vae.decode(latents, return_dict=False)[0]
630
- image = (image / 2 + 0.5).clamp(0, 1)
631
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
632
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
633
- return image
634
-
635
- def check_inputs(
636
- self,
637
- prompt,
638
- callback_steps,
639
- negative_prompt=None,
640
- prompt_embeds=None,
641
- negative_prompt_embeds=None,
642
- callback_on_step_end_tensor_inputs=None,
643
- ):
644
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
645
- raise ValueError(
646
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
647
- f" {type(callback_steps)}."
648
- )
649
-
650
- if callback_on_step_end_tensor_inputs is not None and not all(
651
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
652
- ):
653
- raise ValueError(
654
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
655
- )
656
-
657
- if prompt is not None and prompt_embeds is not None:
658
- raise ValueError(
659
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
660
- " only forward one of the two."
661
- )
662
- elif prompt is None and prompt_embeds is None:
663
- raise ValueError(
664
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
665
- )
666
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
667
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
668
-
669
- if negative_prompt is not None and negative_prompt_embeds is not None:
670
- raise ValueError(
671
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
672
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
673
- )
674
-
675
- if prompt_embeds is not None and negative_prompt_embeds is not None:
676
- if prompt_embeds.shape != negative_prompt_embeds.shape:
677
- raise ValueError(
678
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
679
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
680
- f" {negative_prompt_embeds.shape}."
681
- )
682
-
683
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
684
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
685
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
686
- if isinstance(generator, list) and len(generator) != batch_size:
687
- raise ValueError(
688
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
689
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
690
- )
691
-
692
- if latents is None:
693
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
694
- else:
695
- latents = latents.to(device)
696
-
697
- # scale the initial noise by the standard deviation required by the scheduler
698
- latents = latents * self.scheduler.init_noise_sigma
699
- return latents
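`prepare_latents` works in the VAE's latent grid: pixel height and width are divided by `vae_scale_factor`, and the Gaussian noise is scaled by the scheduler's `init_noise_sigma`. A shape sketch with illustrative values (the SD-style factor of 8 and 4 latent channels are assumptions, not read from a checkpoint config):

import torch
from diffusers import DDIMScheduler

vae_scale_factor = 8               # 2 ** (len(block_out_channels) - 1) for SD-style VAEs
num_channels_latents = 4
height, width = 1024, 768          # pixel resolution of the try-on image

shape = (1, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
generator = torch.manual_seed(0)
latents = torch.randn(shape, generator=generator)
latents = latents * DDIMScheduler().init_noise_sigma   # typically 1.0 for this scheduler
print(latents.shape)               # torch.Size([1, 4, 128, 96])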
700
-
701
- def prepare_garm_latents(
702
- self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
703
- ):
704
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
705
- raise ValueError(
706
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
707
- )
708
-
709
- image = image.to(device=device, dtype=dtype)
710
-
711
- batch_size = batch_size * num_images_per_prompt
712
-
713
- if image.shape[1] == 4:
714
- image_latents = image
715
- else:
716
- if isinstance(generator, list) and len(generator) != batch_size:
717
- raise ValueError(
718
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
719
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
720
- )
721
-
722
- if isinstance(generator, list):
723
- image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
724
- image_latents = torch.cat(image_latents, dim=0)
725
- else:
726
- image_latents = self.vae.encode(image).latent_dist.mode()
727
-
728
- if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
729
- additional_image_per_prompt = batch_size // image_latents.shape[0]
730
- image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
731
- elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
732
- raise ValueError(
733
- f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
734
- )
735
- else:
736
- image_latents = torch.cat([image_latents], dim=0)
737
-
738
- if do_classifier_free_guidance:
739
- uncond_image_latents = torch.zeros_like(image_latents)
740
- image_latents = torch.cat([image_latents, uncond_image_latents], dim=0)
741
-
742
- return image_latents
743
-
744
- def prepare_vton_latents(
745
- self, image, mask, image_ori, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
746
- ):
747
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
748
- raise ValueError(
749
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
750
- )
751
-
752
- image = image.to(device=device, dtype=dtype)
753
- image_ori = image_ori.to(device=device, dtype=dtype)
754
-
755
- batch_size = batch_size * num_images_per_prompt
756
-
757
- if image.shape[1] == 4:
758
- image_latents = image
759
- image_ori_latents = image_ori
760
- else:
761
- if isinstance(generator, list) and len(generator) != batch_size:
762
- raise ValueError(
763
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
764
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
765
- )
766
-
767
- if isinstance(generator, list):
768
- image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
769
- image_latents = torch.cat(image_latents, dim=0)
770
- image_ori_latents = [self.vae.encode(image_ori[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
771
- image_ori_latents = torch.cat(image_ori_latents, dim=0)
772
- else:
773
- image_latents = self.vae.encode(image).latent_dist.mode()
774
- image_ori_latents = self.vae.encode(image_ori).latent_dist.mode()
775
-
776
- mask = torch.nn.functional.interpolate(
777
- mask, size=(image_latents.size(-2), image_latents.size(-1))
778
- )
779
- mask = mask.to(device=device, dtype=dtype)
780
-
781
- if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
782
- additional_image_per_prompt = batch_size // image_latents.shape[0]
783
- image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
784
- mask = torch.cat([mask] * additional_image_per_prompt, dim=0)
785
- image_ori_latents = torch.cat([image_ori_latents] * additional_image_per_prompt, dim=0)
786
- elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
787
- raise ValueError(
788
- f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
789
- )
790
- else:
791
- image_latents = torch.cat([image_latents], dim=0)
792
- mask = torch.cat([mask], dim=0)
793
- image_ori_latents = torch.cat([image_ori_latents], dim=0)
794
-
795
- if do_classifier_free_guidance:
796
- # uncond_image_latents = torch.zeros_like(image_latents)
797
- image_latents = torch.cat([image_latents] * 2, dim=0)
798
-
799
- return image_latents, mask, image_ori_latents
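The try-on mask lives at two resolutions: it is binarized at the pixel level in `__call__` (threshold 127, scaled to [0, 1]) and then resized here to the latent grid so it can gate the repainting step. A standalone sketch of both operations (resolutions are illustrative):

import numpy as np
import torch
import torch.nn.functional as F

mask_np = (np.random.rand(1024, 768) * 255).astype(np.uint8)   # stand-in for a PIL mask
mask_np = np.where(mask_np >= 127, 255, 0).astype(np.float32)  # hard binarization
mask = torch.from_numpy(mask_np) / 255.0
mask = mask.reshape(-1, 1, mask.size(-2), mask.size(-1))       # (1, 1, H, W)

latent_size = (1024 // 8, 768 // 8)                            # latent grid of the VAE
mask_latent = F.interpolate(mask, size=latent_size)            # nearest-neighbour by default
print(mask_latent.shape)                                       # torch.Size([1, 1, 128, 96])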
800
-
801
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
802
- def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
803
- r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
804
-
805
- The suffixes after the scaling factors represent the stages where they are being applied.
806
-
807
- Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
808
- that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
809
-
810
- Args:
811
- s1 (`float`):
812
- Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
813
- mitigate "oversmoothing effect" in the enhanced denoising process.
814
- s2 (`float`):
815
- Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
816
- mitigate "oversmoothing effect" in the enhanced denoising process.
817
- b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
818
- b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
819
- """
820
- if not hasattr(self, "unet_vton"):
821
- raise ValueError("The pipeline must have `unet_vton` for using FreeU.")
822
- self.unet_vton.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
823
-
824
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
825
- def disable_freeu(self):
826
- """Disables the FreeU mechanism if enabled."""
827
- self.unet_vton.disable_freeu()
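FreeU is wired to the try-on UNet only. A usage note with starting values reported by the FreeU authors for SD-1.x-scale UNets (treat them as a hypothetical starting point, not settings tuned for this pipeline):

# `pipe` is an instance of this pipeline; values follow the FreeU authors'
# suggestions for SD-1.x-scale UNets and are untuned for virtual try-on.
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
# ... run inference as usual ...
pipe.disable_freeu()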
828
-
829
- @property
830
- def guidance_scale(self):
831
- return self._guidance_scale
832
-
833
- @property
834
- def image_guidance_scale(self):
835
- return self._image_guidance_scale
836
-
837
- @property
838
- def num_timesteps(self):
839
- return self._num_timesteps
840
-
841
- # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
842
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
843
- # corresponds to doing no classifier free guidance.
844
- @property
845
- def do_classifier_free_guidance(self):
846
- return self.image_guidance_scale >= 1.0
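Guidance in this pipeline is image-only: the batch is doubled with garment-conditioned and zeroed (unconditional) garment latents, and the two noise predictions are recombined with `image_guidance_scale` as in the denoising loop above. A toy tensor sketch of that recombination:

import torch

image_guidance_scale = 2.0
noise_pred = torch.randn(2, 4, 128, 96)        # doubled batch: [garment-conditioned | unconditional]
noise_pred_text_image, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_text + image_guidance_scale * (noise_pred_text_image - noise_pred_text)
print(guided.shape)                            # torch.Size([1, 4, 128, 96])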
ootd/pipelines_ootd/transformer_garm_2d.py DELETED
@@ -1,449 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
- from dataclasses import dataclass
17
- from typing import Any, Dict, Optional
18
-
19
- import torch
20
- import torch.nn.functional as F
21
- from torch import nn
22
-
23
- from .attention_garm import BasicTransformerBlock
24
-
25
- from diffusers.configuration_utils import ConfigMixin, register_to_config
26
- from diffusers.models.embeddings import ImagePositionalEmbeddings
27
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate
28
- # from diffusers.models.attention import BasicTransformerBlock
29
- from diffusers.models.embeddings import CaptionProjection, PatchEmbed
30
- from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
31
- from diffusers.models.modeling_utils import ModelMixin
32
- from diffusers.models.normalization import AdaLayerNormSingle
33
-
34
-
35
- @dataclass
36
- class Transformer2DModelOutput(BaseOutput):
37
- """
38
- The output of [`Transformer2DModel`].
39
-
40
- Args:
41
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
42
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
43
- distributions for the unnoised latent pixels.
44
- """
45
-
46
- sample: torch.FloatTensor
47
-
48
-
49
- class Transformer2DModel(ModelMixin, ConfigMixin):
50
- """
51
- A 2D Transformer model for image-like data.
52
-
53
- Parameters:
54
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
55
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
56
- in_channels (`int`, *optional*):
57
- The number of channels in the input and output (specify if the input is **continuous**).
58
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
59
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
60
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
61
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
62
- This is fixed during training since it is used to learn a number of position embeddings.
63
- num_vector_embeds (`int`, *optional*):
64
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
65
- Includes the class for the masked latent pixel.
66
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
67
- num_embeds_ada_norm ( `int`, *optional*):
68
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
69
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
70
- added to the hidden states.
71
-
72
- During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
73
- attention_bias (`bool`, *optional*):
74
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
75
- """
76
-
77
- @register_to_config
78
- def __init__(
79
- self,
80
- num_attention_heads: int = 16,
81
- attention_head_dim: int = 88,
82
- in_channels: Optional[int] = None,
83
- out_channels: Optional[int] = None,
84
- num_layers: int = 1,
85
- dropout: float = 0.0,
86
- norm_num_groups: int = 32,
87
- cross_attention_dim: Optional[int] = None,
88
- attention_bias: bool = False,
89
- sample_size: Optional[int] = None,
90
- num_vector_embeds: Optional[int] = None,
91
- patch_size: Optional[int] = None,
92
- activation_fn: str = "geglu",
93
- num_embeds_ada_norm: Optional[int] = None,
94
- use_linear_projection: bool = False,
95
- only_cross_attention: bool = False,
96
- double_self_attention: bool = False,
97
- upcast_attention: bool = False,
98
- norm_type: str = "layer_norm",
99
- norm_elementwise_affine: bool = True,
100
- norm_eps: float = 1e-5,
101
- attention_type: str = "default",
102
- caption_channels: Optional[int] = None,
103
- ):
104
- super().__init__()
105
- self.use_linear_projection = use_linear_projection
106
- self.num_attention_heads = num_attention_heads
107
- self.attention_head_dim = attention_head_dim
108
- inner_dim = num_attention_heads * attention_head_dim
109
-
110
- conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
111
- linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
112
-
113
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
114
- # Define whether input is continuous or discrete depending on configuration
115
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
116
- self.is_input_vectorized = num_vector_embeds is not None
117
- self.is_input_patches = in_channels is not None and patch_size is not None
118
-
119
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
120
- deprecation_message = (
121
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
122
- " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
123
- " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
124
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
125
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
126
- )
127
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
128
- norm_type = "ada_norm"
129
-
130
- if self.is_input_continuous and self.is_input_vectorized:
131
- raise ValueError(
132
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
133
- " sure that either `in_channels` or `num_vector_embeds` is None."
134
- )
135
- elif self.is_input_vectorized and self.is_input_patches:
136
- raise ValueError(
137
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
138
- " sure that either `num_vector_embeds` or `num_patches` is None."
139
- )
140
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
141
- raise ValueError(
142
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
143
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
144
- )
145
-
146
- # 2. Define input layers
147
- if self.is_input_continuous:
148
- self.in_channels = in_channels
149
-
150
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
151
- if use_linear_projection:
152
- self.proj_in = linear_cls(in_channels, inner_dim)
153
- else:
154
- self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
155
- elif self.is_input_vectorized:
156
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
157
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
158
-
159
- self.height = sample_size
160
- self.width = sample_size
161
- self.num_vector_embeds = num_vector_embeds
162
- self.num_latent_pixels = self.height * self.width
163
-
164
- self.latent_image_embedding = ImagePositionalEmbeddings(
165
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
166
- )
167
- elif self.is_input_patches:
168
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
169
-
170
- self.height = sample_size
171
- self.width = sample_size
172
-
173
- self.patch_size = patch_size
174
- interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
175
- interpolation_scale = max(interpolation_scale, 1)
176
- self.pos_embed = PatchEmbed(
177
- height=sample_size,
178
- width=sample_size,
179
- patch_size=patch_size,
180
- in_channels=in_channels,
181
- embed_dim=inner_dim,
182
- interpolation_scale=interpolation_scale,
183
- )
184
-
185
- # 3. Define transformers blocks
186
- self.transformer_blocks = nn.ModuleList(
187
- [
188
- BasicTransformerBlock(
189
- inner_dim,
190
- num_attention_heads,
191
- attention_head_dim,
192
- dropout=dropout,
193
- cross_attention_dim=cross_attention_dim,
194
- activation_fn=activation_fn,
195
- num_embeds_ada_norm=num_embeds_ada_norm,
196
- attention_bias=attention_bias,
197
- only_cross_attention=only_cross_attention,
198
- double_self_attention=double_self_attention,
199
- upcast_attention=upcast_attention,
200
- norm_type=norm_type,
201
- norm_elementwise_affine=norm_elementwise_affine,
202
- norm_eps=norm_eps,
203
- attention_type=attention_type,
204
- )
205
- for d in range(num_layers)
206
- ]
207
- )
208
-
209
- # 4. Define output layers
210
- self.out_channels = in_channels if out_channels is None else out_channels
211
- if self.is_input_continuous:
212
- # TODO: should use out_channels for continuous projections
213
- if use_linear_projection:
214
- self.proj_out = linear_cls(inner_dim, in_channels)
215
- else:
216
- self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
217
- elif self.is_input_vectorized:
218
- self.norm_out = nn.LayerNorm(inner_dim)
219
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
220
- elif self.is_input_patches and norm_type != "ada_norm_single":
221
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
222
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
223
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
224
- elif self.is_input_patches and norm_type == "ada_norm_single":
225
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
226
- self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
227
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
228
-
229
- # 5. PixArt-Alpha blocks.
230
- self.adaln_single = None
231
- self.use_additional_conditions = False
232
- if norm_type == "ada_norm_single":
233
- self.use_additional_conditions = self.config.sample_size == 128
234
- # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
235
- # additional conditions until we find better name
236
- self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
237
-
238
- self.caption_projection = None
239
- if caption_channels is not None:
240
- self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
241
-
242
- self.gradient_checkpointing = False
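Apart from the spatial-attention lists threaded through `forward`, this class mirrors the upstream diffusers `Transformer2DModel`. A shape sketch of the continuous-input path using the upstream class (dimensions are illustrative, not the OOTD checkpoint's configuration):

import torch
from diffusers.models import Transformer2DModel as UpstreamTransformer2DModel

block = UpstreamTransformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,            # inner_dim = 8 * 40 = 320
    in_channels=320,                  # continuous input: no patch_size / num_vector_embeds
    num_layers=1,
    cross_attention_dim=768,          # CLIP text embedding width
)
hidden_states = torch.randn(1, 320, 64, 48)      # (batch, channels, latent H, latent W)
encoder_hidden_states = torch.randn(1, 77, 768)  # tokenized prompt embeddings
out = block(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)                                 # torch.Size([1, 320, 64, 48])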
243
-
244
- def forward(
245
- self,
246
- hidden_states: torch.Tensor,
247
- spatial_attn_inputs = [],
248
- encoder_hidden_states: Optional[torch.Tensor] = None,
249
- timestep: Optional[torch.LongTensor] = None,
250
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
251
- class_labels: Optional[torch.LongTensor] = None,
252
- cross_attention_kwargs: Dict[str, Any] = None,
253
- attention_mask: Optional[torch.Tensor] = None,
254
- encoder_attention_mask: Optional[torch.Tensor] = None,
255
- return_dict: bool = True,
256
- ):
257
- """
258
- The [`Transformer2DModel`] forward method.
259
-
260
- Args:
261
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
262
- Input `hidden_states`.
263
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
264
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
265
- self-attention.
266
- timestep ( `torch.LongTensor`, *optional*):
267
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
268
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
269
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
270
- `AdaLayerZeroNorm`.
271
- cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
272
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
273
- `self.processor` in
274
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
275
- attention_mask ( `torch.Tensor`, *optional*):
276
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
277
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
278
- negative values to the attention scores corresponding to "discard" tokens.
279
- encoder_attention_mask ( `torch.Tensor`, *optional*):
280
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
281
-
282
- * Mask `(batch, sequence_length)` True = keep, False = discard.
283
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
284
-
285
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
286
- above. This bias will be added to the cross-attention scores.
287
- return_dict (`bool`, *optional*, defaults to `True`):
288
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
289
- tuple.
290
-
291
- Returns:
292
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
293
- `tuple` where the first element is the sample tensor.
294
- """
295
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
296
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
297
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
298
- # expects mask of shape:
299
- # [batch, key_tokens]
300
- # adds singleton query_tokens dimension:
301
- # [batch, 1, key_tokens]
302
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
303
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
304
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
305
- if attention_mask is not None and attention_mask.ndim == 2:
306
- # assume that mask is expressed as:
307
- # (1 = keep, 0 = discard)
308
- # convert mask into a bias that can be added to attention scores:
309
- # (keep = +0, discard = -10000.0)
310
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
311
- attention_mask = attention_mask.unsqueeze(1)
312
-
313
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
314
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
315
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
316
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
317
-
318
- # Retrieve lora scale.
319
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
320
-
321
- # 1. Input
322
- if self.is_input_continuous:
323
- batch, _, height, width = hidden_states.shape
324
- residual = hidden_states
325
-
326
- hidden_states = self.norm(hidden_states)
327
- if not self.use_linear_projection:
328
- hidden_states = (
329
- self.proj_in(hidden_states, scale=lora_scale)
330
- if not USE_PEFT_BACKEND
331
- else self.proj_in(hidden_states)
332
- )
333
- inner_dim = hidden_states.shape[1]
334
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
335
- else:
336
- inner_dim = hidden_states.shape[1]
337
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
338
- hidden_states = (
339
- self.proj_in(hidden_states, scale=lora_scale)
340
- if not USE_PEFT_BACKEND
341
- else self.proj_in(hidden_states)
342
- )
343
-
344
- elif self.is_input_vectorized:
345
- hidden_states = self.latent_image_embedding(hidden_states)
346
- elif self.is_input_patches:
347
- height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
348
- hidden_states = self.pos_embed(hidden_states)
349
-
350
- if self.adaln_single is not None:
351
- if self.use_additional_conditions and added_cond_kwargs is None:
352
- raise ValueError(
353
- "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
354
- )
355
- batch_size = hidden_states.shape[0]
356
- timestep, embedded_timestep = self.adaln_single(
357
- timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
358
- )
359
-
360
- # 2. Blocks
361
- if self.caption_projection is not None:
362
- batch_size = hidden_states.shape[0]
363
- encoder_hidden_states = self.caption_projection(encoder_hidden_states)
364
- encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
365
-
366
- for block in self.transformer_blocks:
367
- if self.training and self.gradient_checkpointing:
368
- hidden_states, spatial_attn_inputs = torch.utils.checkpoint.checkpoint(
369
- block,
370
- hidden_states,
371
- spatial_attn_inputs,
372
- attention_mask,
373
- encoder_hidden_states,
374
- encoder_attention_mask,
375
- timestep,
376
- cross_attention_kwargs,
377
- class_labels,
378
- use_reentrant=False,
379
- )
380
- else:
381
- hidden_states, spatial_attn_inputs = block(
382
- hidden_states,
383
- spatial_attn_inputs,
384
- attention_mask=attention_mask,
385
- encoder_hidden_states=encoder_hidden_states,
386
- encoder_attention_mask=encoder_attention_mask,
387
- timestep=timestep,
388
- cross_attention_kwargs=cross_attention_kwargs,
389
- class_labels=class_labels,
390
- )
391
-
392
- # 3. Output
393
- if self.is_input_continuous:
394
- if not self.use_linear_projection:
395
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
396
- hidden_states = (
397
- self.proj_out(hidden_states, scale=lora_scale)
398
- if not USE_PEFT_BACKEND
399
- else self.proj_out(hidden_states)
400
- )
401
- else:
402
- hidden_states = (
403
- self.proj_out(hidden_states, scale=lora_scale)
404
- if not USE_PEFT_BACKEND
405
- else self.proj_out(hidden_states)
406
- )
407
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
408
-
409
- output = hidden_states + residual
410
- elif self.is_input_vectorized:
411
- hidden_states = self.norm_out(hidden_states)
412
- logits = self.out(hidden_states)
413
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
414
- logits = logits.permute(0, 2, 1)
415
-
416
- # log(p(x_0))
417
- output = F.log_softmax(logits.double(), dim=1).float()
418
-
419
- if self.is_input_patches:
420
- if self.config.norm_type != "ada_norm_single":
421
- conditioning = self.transformer_blocks[0].norm1.emb(
422
- timestep, class_labels, hidden_dtype=hidden_states.dtype
423
- )
424
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
425
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
426
- hidden_states = self.proj_out_2(hidden_states)
427
- elif self.config.norm_type == "ada_norm_single":
428
- shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
429
- hidden_states = self.norm_out(hidden_states)
430
- # Modulation
431
- hidden_states = hidden_states * (1 + scale) + shift
432
- hidden_states = self.proj_out(hidden_states)
433
- hidden_states = hidden_states.squeeze(1)
434
-
435
- # unpatchify
436
- if self.adaln_single is None:
437
- height = width = int(hidden_states.shape[1] ** 0.5)
438
- hidden_states = hidden_states.reshape(
439
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
440
- )
441
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
442
- output = hidden_states.reshape(
443
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
444
- )
445
-
446
- if not return_dict:
447
- return (output,), spatial_attn_inputs
448
-
449
- return Transformer2DModelOutput(sample=output), spatial_attn_inputs
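The second return value is the glue between the two UNets: each garment transformer block appends its spatial self-attention features to this list, and the pipeline above replays a fresh copy of it into the try-on UNet at every denoising step. A plain-Python schematic of that hand-off (the block count and tensors are placeholders, not the real architecture):

def fake_unet_garm(latents, timestep, encoder_hidden_states):
    # each transformer block appends its spatial attention features in order
    spatial_attn_outputs = [f"block_{i}_features" for i in range(16)]   # placeholder count
    return latents, spatial_attn_outputs

def fake_unet_vton(latents, spatial_attn_inputs, timestep, encoder_hidden_states):
    # features are consumed in the same order they were produced
    assert len(spatial_attn_inputs) == 16
    return latents

_, spatial_attn_outputs = fake_unet_garm("garm_latents", 0, "prompt_embeds")
for t in range(20):                                    # denoising steps
    spatial_attn_inputs = spatial_attn_outputs.copy()  # fresh copy per step, as in the pipeline
    fake_unet_vton("vton_latents", spatial_attn_inputs, t, "prompt_embeds")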
ootd/pipelines_ootd/transformer_vton_2d.py DELETED
@@ -1,452 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
- from dataclasses import dataclass
17
- from typing import Any, Dict, Optional
18
-
19
- import torch
20
- import torch.nn.functional as F
21
- from torch import nn
22
-
23
- from .attention_vton import BasicTransformerBlock
24
-
25
- from diffusers.configuration_utils import ConfigMixin, register_to_config
26
- from diffusers.models.embeddings import ImagePositionalEmbeddings
27
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate
28
- # from diffusers.models.attention import BasicTransformerBlock
29
- from diffusers.models.embeddings import CaptionProjection, PatchEmbed
30
- from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
31
- from diffusers.models.modeling_utils import ModelMixin
32
- from diffusers.models.normalization import AdaLayerNormSingle
33
-
34
-
35
- @dataclass
36
- class Transformer2DModelOutput(BaseOutput):
37
- """
38
- The output of [`Transformer2DModel`].
39
-
40
- Args:
41
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
42
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
43
- distributions for the unnoised latent pixels.
44
- """
45
-
46
- sample: torch.FloatTensor
47
-
48
-
49
- class Transformer2DModel(ModelMixin, ConfigMixin):
50
- """
51
- A 2D Transformer model for image-like data.
52
-
53
- Parameters:
54
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
55
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
56
- in_channels (`int`, *optional*):
57
- The number of channels in the input and output (specify if the input is **continuous**).
58
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
59
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
60
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
61
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
62
- This is fixed during training since it is used to learn a number of position embeddings.
63
- num_vector_embeds (`int`, *optional*):
64
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
65
- Includes the class for the masked latent pixel.
66
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
67
- num_embeds_ada_norm ( `int`, *optional*):
68
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
69
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
70
- added to the hidden states.
71
-
72
- During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
73
- attention_bias (`bool`, *optional*):
74
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
75
- """
76
-
77
- @register_to_config
78
- def __init__(
79
- self,
80
- num_attention_heads: int = 16,
81
- attention_head_dim: int = 88,
82
- in_channels: Optional[int] = None,
83
- out_channels: Optional[int] = None,
84
- num_layers: int = 1,
85
- dropout: float = 0.0,
86
- norm_num_groups: int = 32,
87
- cross_attention_dim: Optional[int] = None,
88
- attention_bias: bool = False,
89
- sample_size: Optional[int] = None,
90
- num_vector_embeds: Optional[int] = None,
91
- patch_size: Optional[int] = None,
92
- activation_fn: str = "geglu",
93
- num_embeds_ada_norm: Optional[int] = None,
94
- use_linear_projection: bool = False,
95
- only_cross_attention: bool = False,
96
- double_self_attention: bool = False,
97
- upcast_attention: bool = False,
98
- norm_type: str = "layer_norm",
99
- norm_elementwise_affine: bool = True,
100
- norm_eps: float = 1e-5,
101
- attention_type: str = "default",
102
- caption_channels: int = None,
103
- ):
104
- super().__init__()
105
- self.use_linear_projection = use_linear_projection
106
- self.num_attention_heads = num_attention_heads
107
- self.attention_head_dim = attention_head_dim
108
- inner_dim = num_attention_heads * attention_head_dim
109
-
110
- conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
111
- linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
112
-
113
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
114
- # Define whether input is continuous or discrete depending on configuration
115
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
116
- self.is_input_vectorized = num_vector_embeds is not None
117
- self.is_input_patches = in_channels is not None and patch_size is not None
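The three flags above pick exactly one input mode from the constructor arguments; the continuous branch is the one an SD-style UNet exercises. A small sketch added for illustration (not part of the deleted file):

def classify_input_mode(in_channels=None, num_vector_embeds=None, patch_size=None):
    # Mirrors the three flags computed in __init__ above.
    is_continuous = (in_channels is not None) and (patch_size is None)
    is_vectorized = num_vector_embeds is not None
    is_patches = (in_channels is not None) and (patch_size is not None)
    return is_continuous, is_vectorized, is_patches

# Continuous latents, as used inside a Stable-Diffusion-style UNet block:
assert classify_input_mode(in_channels=320) == (True, False, False)
# Patch-based (DiT/PixArt-style) input:
assert classify_input_mode(in_channels=4, patch_size=2) == (False, False, True)
# Discrete (VQ) latents:
assert classify_input_mode(num_vector_embeds=8192) == (False, True, False)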
118
-
119
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
120
- deprecation_message = (
121
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
122
- " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
123
- " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
124
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
125
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
126
- )
127
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
128
- norm_type = "ada_norm"
129
-
130
- if self.is_input_continuous and self.is_input_vectorized:
131
- raise ValueError(
132
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
133
- " sure that either `in_channels` or `num_vector_embeds` is None."
134
- )
135
- elif self.is_input_vectorized and self.is_input_patches:
136
- raise ValueError(
137
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
138
- " sure that either `num_vector_embeds` or `num_patches` is None."
139
- )
140
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
141
- raise ValueError(
142
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
143
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
144
- )
145
-
146
- # 2. Define input layers
147
- if self.is_input_continuous:
148
- self.in_channels = in_channels
149
-
150
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
151
- if use_linear_projection:
152
- self.proj_in = linear_cls(in_channels, inner_dim)
153
- else:
154
- self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
155
- elif self.is_input_vectorized:
156
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
157
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
158
-
159
- self.height = sample_size
160
- self.width = sample_size
161
- self.num_vector_embeds = num_vector_embeds
162
- self.num_latent_pixels = self.height * self.width
163
-
164
- self.latent_image_embedding = ImagePositionalEmbeddings(
165
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
166
- )
167
- elif self.is_input_patches:
168
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
169
-
170
- self.height = sample_size
171
- self.width = sample_size
172
-
173
- self.patch_size = patch_size
174
- interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
175
- interpolation_scale = max(interpolation_scale, 1)
176
- self.pos_embed = PatchEmbed(
177
- height=sample_size,
178
- width=sample_size,
179
- patch_size=patch_size,
180
- in_channels=in_channels,
181
- embed_dim=inner_dim,
182
- interpolation_scale=interpolation_scale,
183
- )
184
-
185
- # 3. Define transformers blocks
186
- self.transformer_blocks = nn.ModuleList(
187
- [
188
- BasicTransformerBlock(
189
- inner_dim,
190
- num_attention_heads,
191
- attention_head_dim,
192
- dropout=dropout,
193
- cross_attention_dim=cross_attention_dim,
194
- activation_fn=activation_fn,
195
- num_embeds_ada_norm=num_embeds_ada_norm,
196
- attention_bias=attention_bias,
197
- only_cross_attention=only_cross_attention,
198
- double_self_attention=double_self_attention,
199
- upcast_attention=upcast_attention,
200
- norm_type=norm_type,
201
- norm_elementwise_affine=norm_elementwise_affine,
202
- norm_eps=norm_eps,
203
- attention_type=attention_type,
204
- )
205
- for d in range(num_layers)
206
- ]
207
- )
208
-
209
- # 4. Define output layers
210
- self.out_channels = in_channels if out_channels is None else out_channels
211
- if self.is_input_continuous:
212
- # TODO: should use out_channels for continuous projections
213
- if use_linear_projection:
214
- self.proj_out = linear_cls(inner_dim, in_channels)
215
- else:
216
- self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
217
- elif self.is_input_vectorized:
218
- self.norm_out = nn.LayerNorm(inner_dim)
219
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
220
- elif self.is_input_patches and norm_type != "ada_norm_single":
221
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
222
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
223
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
224
- elif self.is_input_patches and norm_type == "ada_norm_single":
225
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
226
- self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
227
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
228
-
229
- # 5. PixArt-Alpha blocks.
230
- self.adaln_single = None
231
- self.use_additional_conditions = False
232
- if norm_type == "ada_norm_single":
233
- self.use_additional_conditions = self.config.sample_size == 128
234
- # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
235
- # additional conditions until we find better name
236
- self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
237
-
238
- self.caption_projection = None
239
- if caption_channels is not None:
240
- self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
241
-
242
- self.gradient_checkpointing = False
243
-
244
- def forward(
245
- self,
246
- hidden_states: torch.Tensor,
247
- spatial_attn_inputs = [],
248
- spatial_attn_idx = 0,
249
- encoder_hidden_states: Optional[torch.Tensor] = None,
250
- timestep: Optional[torch.LongTensor] = None,
251
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
252
- class_labels: Optional[torch.LongTensor] = None,
253
- cross_attention_kwargs: Dict[str, Any] = None,
254
- attention_mask: Optional[torch.Tensor] = None,
255
- encoder_attention_mask: Optional[torch.Tensor] = None,
256
- return_dict: bool = True,
257
- ):
258
- """
259
- The [`Transformer2DModel`] forward method.
260
-
261
- Args:
262
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
263
- Input `hidden_states`.
264
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
265
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
266
- self-attention.
267
- timestep ( `torch.LongTensor`, *optional*):
268
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
269
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
270
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
271
- `AdaLayerZeroNorm`.
272
- cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
273
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
274
- `self.processor` in
275
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
276
- attention_mask ( `torch.Tensor`, *optional*):
277
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
278
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
279
- negative values to the attention scores corresponding to "discard" tokens.
280
- encoder_attention_mask ( `torch.Tensor`, *optional*):
281
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
282
-
283
- * Mask `(batch, sequence_length)` True = keep, False = discard.
284
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
285
-
286
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
287
- above. This bias will be added to the cross-attention scores.
288
- return_dict (`bool`, *optional*, defaults to `True`):
289
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
290
- tuple.
291
-
292
- Returns:
293
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
294
- `tuple` where the first element is the sample tensor.
295
- """
296
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
297
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
298
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
299
- # expects mask of shape:
300
- # [batch, key_tokens]
301
- # adds singleton query_tokens dimension:
302
- # [batch, 1, key_tokens]
303
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
304
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
305
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
306
- if attention_mask is not None and attention_mask.ndim == 2:
307
- # assume that mask is expressed as:
308
- # (1 = keep, 0 = discard)
309
- # convert mask into a bias that can be added to attention scores:
310
- # (keep = +0, discard = -10000.0)
311
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
312
- attention_mask = attention_mask.unsqueeze(1)
313
-
314
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
315
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
316
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
317
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
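The two conversions above turn a 0/1 keep-mask into an additive bias with a broadcastable query dimension. The arithmetic in isolation, as an illustrative snippet (not part of the deleted file):

import torch

mask = torch.tensor([[1, 1, 0]], dtype=torch.float16)  # (batch, key_tokens): 1 = keep, 0 = discard
bias = (1 - mask) * -10000.0                            # keep -> 0.0, discard -> -10000.0
bias = bias.unsqueeze(1)                                # (batch, 1, key_tokens), broadcasts over query tokens
# Added to raw attention scores, the -10000.0 entries push the softmax weight of discarded keys to ~0.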
318
-
319
- # Retrieve lora scale.
320
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
321
-
322
- # 1. Input
323
- if self.is_input_continuous:
324
- batch, _, height, width = hidden_states.shape
325
- residual = hidden_states
326
-
327
- hidden_states = self.norm(hidden_states)
328
- if not self.use_linear_projection:
329
- hidden_states = (
330
- self.proj_in(hidden_states, scale=lora_scale)
331
- if not USE_PEFT_BACKEND
332
- else self.proj_in(hidden_states)
333
- )
334
- inner_dim = hidden_states.shape[1]
335
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
336
- else:
337
- inner_dim = hidden_states.shape[1]
338
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
339
- hidden_states = (
340
- self.proj_in(hidden_states, scale=lora_scale)
341
- if not USE_PEFT_BACKEND
342
- else self.proj_in(hidden_states)
343
- )
344
-
345
- elif self.is_input_vectorized:
346
- hidden_states = self.latent_image_embedding(hidden_states)
347
- elif self.is_input_patches:
348
- height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
349
- hidden_states = self.pos_embed(hidden_states)
350
-
351
- if self.adaln_single is not None:
352
- if self.use_additional_conditions and added_cond_kwargs is None:
353
- raise ValueError(
354
- "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
355
- )
356
- batch_size = hidden_states.shape[0]
357
- timestep, embedded_timestep = self.adaln_single(
358
- timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
359
- )
360
-
361
- # 2. Blocks
362
- if self.caption_projection is not None:
363
- batch_size = hidden_states.shape[0]
364
- encoder_hidden_states = self.caption_projection(encoder_hidden_states)
365
- encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
366
-
367
- for block in self.transformer_blocks:
368
- if self.training and self.gradient_checkpointing:
369
- hidden_states, spatial_attn_inputs, spatial_attn_idx = torch.utils.checkpoint.checkpoint(
370
- block,
371
- hidden_states,
372
- spatial_attn_inputs,
373
- spatial_attn_idx,
374
- attention_mask,
375
- encoder_hidden_states,
376
- encoder_attention_mask,
377
- timestep,
378
- cross_attention_kwargs,
379
- class_labels,
380
- use_reentrant=False,
381
- )
382
- else:
383
- hidden_states, spatial_attn_inputs, spatial_attn_idx = block(
384
- hidden_states,
385
- spatial_attn_inputs,
386
- spatial_attn_idx,
387
- attention_mask=attention_mask,
388
- encoder_hidden_states=encoder_hidden_states,
389
- encoder_attention_mask=encoder_attention_mask,
390
- timestep=timestep,
391
- cross_attention_kwargs=cross_attention_kwargs,
392
- class_labels=class_labels,
393
- )
394
-
395
- # 3. Output
396
- if self.is_input_continuous:
397
- if not self.use_linear_projection:
398
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
399
- hidden_states = (
400
- self.proj_out(hidden_states, scale=lora_scale)
401
- if not USE_PEFT_BACKEND
402
- else self.proj_out(hidden_states)
403
- )
404
- else:
405
- hidden_states = (
406
- self.proj_out(hidden_states, scale=lora_scale)
407
- if not USE_PEFT_BACKEND
408
- else self.proj_out(hidden_states)
409
- )
410
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
411
-
412
- output = hidden_states + residual
413
- elif self.is_input_vectorized:
414
- hidden_states = self.norm_out(hidden_states)
415
- logits = self.out(hidden_states)
416
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
417
- logits = logits.permute(0, 2, 1)
418
-
419
- # log(p(x_0))
420
- output = F.log_softmax(logits.double(), dim=1).float()
421
-
422
- if self.is_input_patches:
423
- if self.config.norm_type != "ada_norm_single":
424
- conditioning = self.transformer_blocks[0].norm1.emb(
425
- timestep, class_labels, hidden_dtype=hidden_states.dtype
426
- )
427
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
428
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
429
- hidden_states = self.proj_out_2(hidden_states)
430
- elif self.config.norm_type == "ada_norm_single":
431
- shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
432
- hidden_states = self.norm_out(hidden_states)
433
- # Modulation
434
- hidden_states = hidden_states * (1 + scale) + shift
435
- hidden_states = self.proj_out(hidden_states)
436
- hidden_states = hidden_states.squeeze(1)
437
-
438
- # unpatchify
439
- if self.adaln_single is None:
440
- height = width = int(hidden_states.shape[1] ** 0.5)
441
- hidden_states = hidden_states.reshape(
442
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
443
- )
444
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
445
- output = hidden_states.reshape(
446
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
447
- )
448
-
449
- if not return_dict:
450
- return (output,), spatial_attn_inputs, spatial_attn_idx
451
-
452
- return Transformer2DModelOutput(sample=output), spatial_attn_inputs, spatial_attn_idx
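A shape walkthrough of the unpatchify step above (only reached for patch-based inputs), added for clarity with small illustrative dimensions (not part of the deleted file):

import torch

batch, out_channels, patch_size = 2, 4, 2
height = width = 8                                   # patches per side, so 64 tokens per sample
tokens = torch.randn(batch, height * width, patch_size * patch_size * out_channels)

x = tokens.reshape(-1, height, width, patch_size, patch_size, out_channels)  # (n, h, w, p, q, c)
x = torch.einsum("nhwpqc->nchpwq", x)                                        # channels first, patch pixels interleaved
image = x.reshape(-1, out_channels, height * patch_size, width * patch_size)
assert image.shape == (batch, out_channels, 16, 16)
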
ootd/pipelines_ootd/unet_garm_2d_blocks.py DELETED
The diff for this file is too large to render. See raw diff
 
ootd/pipelines_ootd/unet_garm_2d_condition.py DELETED
@@ -1,1183 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
- from dataclasses import dataclass
17
- from typing import Any, Dict, List, Optional, Tuple, Union
18
-
19
- import torch
20
- import torch.nn as nn
21
- import torch.utils.checkpoint
22
-
23
- from .unet_garm_2d_blocks import (
24
- UNetMidBlock2D,
25
- UNetMidBlock2DCrossAttn,
26
- UNetMidBlock2DSimpleCrossAttn,
27
- get_down_block,
28
- get_up_block,
29
- )
30
-
31
- from diffusers.configuration_utils import ConfigMixin, register_to_config
32
- from diffusers.loaders import UNet2DConditionLoadersMixin
33
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
34
- from diffusers.models.activations import get_activation
35
- from diffusers.models.attention_processor import (
36
- ADDED_KV_ATTENTION_PROCESSORS,
37
- CROSS_ATTENTION_PROCESSORS,
38
- AttentionProcessor,
39
- AttnAddedKVProcessor,
40
- AttnProcessor,
41
- )
42
- from diffusers.models.embeddings import (
43
- GaussianFourierProjection,
44
- ImageHintTimeEmbedding,
45
- ImageProjection,
46
- ImageTimeEmbedding,
47
- PositionNet,
48
- TextImageProjection,
49
- TextImageTimeEmbedding,
50
- TextTimeEmbedding,
51
- TimestepEmbedding,
52
- Timesteps,
53
- )
54
- from diffusers.models.modeling_utils import ModelMixin
55
- # from diffusers.models.unet_2d_blocks import (
56
- # UNetMidBlock2D,
57
- # UNetMidBlock2DCrossAttn,
58
- # UNetMidBlock2DSimpleCrossAttn,
59
- # get_down_block,
60
- # get_up_block,
61
- # )
62
-
63
-
64
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
-
66
-
67
- @dataclass
68
- class UNet2DConditionOutput(BaseOutput):
69
- """
70
- The output of [`UNet2DConditionModel`].
71
-
72
- Args:
73
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
74
- The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
75
- """
76
-
77
- sample: torch.FloatTensor = None
78
-
79
-
80
- class UNetGarm2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
81
- r"""
82
- A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
83
- shaped output.
84
-
85
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
86
- for all models (such as downloading or saving).
87
-
88
- Parameters:
89
- sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
90
- Height and width of input/output sample.
91
- in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
92
- out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
93
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
94
- flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
95
- Whether to flip the sin to cos in the time embedding.
96
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
97
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
98
- The tuple of downsample blocks to use.
99
- mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
100
- Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
101
- `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
102
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
103
- The tuple of upsample blocks to use.
104
- only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
105
- Whether to include self-attention in the basic transformer blocks, see
106
- [`~models.attention.BasicTransformerBlock`].
107
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
108
- The tuple of output channels for each block.
109
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
110
- downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
111
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
112
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
113
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
114
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
115
- If `None`, normalization and activation layers are skipped in post-processing.
116
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
117
- cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
118
- The dimension of the cross attention features.
119
- transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
120
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
121
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
122
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
123
- reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
124
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
125
- blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
126
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
127
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
128
- encoder_hid_dim (`int`, *optional*, defaults to None):
129
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
130
- dimension to `cross_attention_dim`.
131
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
132
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
133
- embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
134
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
135
- num_attention_heads (`int`, *optional*):
136
- The number of attention heads. If not defined, defaults to `attention_head_dim`
137
- resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
138
- for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
139
- class_embed_type (`str`, *optional*, defaults to `None`):
140
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
141
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
142
- addition_embed_type (`str`, *optional*, defaults to `None`):
143
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
144
- "text". "text" will use the `TextTimeEmbedding` layer.
145
- addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
146
- Dimension for the timestep embeddings.
147
- num_class_embeds (`int`, *optional*, defaults to `None`):
148
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
149
- class conditioning with `class_embed_type` equal to `None`.
150
- time_embedding_type (`str`, *optional*, defaults to `positional`):
151
- The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
152
- time_embedding_dim (`int`, *optional*, defaults to `None`):
153
- An optional override for the dimension of the projected time embedding.
154
- time_embedding_act_fn (`str`, *optional*, defaults to `None`):
155
- Optional activation function to use only once on the time embeddings before they are passed to the rest of
156
- the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
157
- timestep_post_act (`str`, *optional*, defaults to `None`):
158
- The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
159
- time_cond_proj_dim (`int`, *optional*, defaults to `None`):
160
- The dimension of `cond_proj` layer in the timestep embedding.
161
- conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
162
- conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
163
- projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
164
- `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
165
- class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
166
- embeddings with the class embeddings.
167
- mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
168
- Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
169
- `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
170
- `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
171
- otherwise.
172
- """
173
-
174
- _supports_gradient_checkpointing = True
175
-
176
- @register_to_config
177
- def __init__(
178
- self,
179
- sample_size: Optional[int] = None,
180
- in_channels: int = 4,
181
- out_channels: int = 4,
182
- center_input_sample: bool = False,
183
- flip_sin_to_cos: bool = True,
184
- freq_shift: int = 0,
185
- down_block_types: Tuple[str] = (
186
- "CrossAttnDownBlock2D",
187
- "CrossAttnDownBlock2D",
188
- "CrossAttnDownBlock2D",
189
- "DownBlock2D",
190
- ),
191
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
192
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
193
- only_cross_attention: Union[bool, Tuple[bool]] = False,
194
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
195
- layers_per_block: Union[int, Tuple[int]] = 2,
196
- downsample_padding: int = 1,
197
- mid_block_scale_factor: float = 1,
198
- dropout: float = 0.0,
199
- act_fn: str = "silu",
200
- norm_num_groups: Optional[int] = 32,
201
- norm_eps: float = 1e-5,
202
- cross_attention_dim: Union[int, Tuple[int]] = 1280,
203
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
204
- reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
205
- encoder_hid_dim: Optional[int] = None,
206
- encoder_hid_dim_type: Optional[str] = None,
207
- attention_head_dim: Union[int, Tuple[int]] = 8,
208
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
209
- dual_cross_attention: bool = False,
210
- use_linear_projection: bool = False,
211
- class_embed_type: Optional[str] = None,
212
- addition_embed_type: Optional[str] = None,
213
- addition_time_embed_dim: Optional[int] = None,
214
- num_class_embeds: Optional[int] = None,
215
- upcast_attention: bool = False,
216
- resnet_time_scale_shift: str = "default",
217
- resnet_skip_time_act: bool = False,
218
- resnet_out_scale_factor: int = 1.0,
219
- time_embedding_type: str = "positional",
220
- time_embedding_dim: Optional[int] = None,
221
- time_embedding_act_fn: Optional[str] = None,
222
- timestep_post_act: Optional[str] = None,
223
- time_cond_proj_dim: Optional[int] = None,
224
- conv_in_kernel: int = 3,
225
- conv_out_kernel: int = 3,
226
- projection_class_embeddings_input_dim: Optional[int] = None,
227
- attention_type: str = "default",
228
- class_embeddings_concat: bool = False,
229
- mid_block_only_cross_attention: Optional[bool] = None,
230
- cross_attention_norm: Optional[str] = None,
231
- addition_embed_type_num_heads=64,
232
- ):
233
- super().__init__()
234
-
235
- self.sample_size = sample_size
236
-
237
- if num_attention_heads is not None:
238
- raise ValueError(
239
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
240
- )
241
-
242
- # If `num_attention_heads` is not defined (which is the case for most models)
243
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
244
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
245
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
246
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
247
- # which is why we correct for the naming here.
248
- num_attention_heads = num_attention_heads or attention_head_dim
249
-
250
- # Check inputs
251
- if len(down_block_types) != len(up_block_types):
252
- raise ValueError(
253
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
254
- )
255
-
256
- if len(block_out_channels) != len(down_block_types):
257
- raise ValueError(
258
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
259
- )
260
-
261
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
262
- raise ValueError(
263
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
264
- )
265
-
266
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
267
- raise ValueError(
268
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
269
- )
270
-
271
- if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
272
- raise ValueError(
273
- f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
274
- )
275
-
276
- if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
277
- raise ValueError(
278
- f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
279
- )
280
-
281
- if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
282
- raise ValueError(
283
- f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
284
- )
285
- if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
286
- for layer_number_per_block in transformer_layers_per_block:
287
- if isinstance(layer_number_per_block, list):
288
- raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
289
-
290
- # input
291
- conv_in_padding = (conv_in_kernel - 1) // 2
292
- self.conv_in = nn.Conv2d(
293
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
294
- )
295
-
296
- # time
297
- if time_embedding_type == "fourier":
298
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
299
- if time_embed_dim % 2 != 0:
300
- raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
301
- self.time_proj = GaussianFourierProjection(
302
- time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
303
- )
304
- timestep_input_dim = time_embed_dim
305
- elif time_embedding_type == "positional":
306
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
307
-
308
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
309
- timestep_input_dim = block_out_channels[0]
310
- else:
311
- raise ValueError(
312
- f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
313
- )
314
-
315
- self.time_embedding = TimestepEmbedding(
316
- timestep_input_dim,
317
- time_embed_dim,
318
- act_fn=act_fn,
319
- post_act_fn=timestep_post_act,
320
- cond_proj_dim=time_cond_proj_dim,
321
- )
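For the default `time_embedding_type="positional"` branch above, the timestep path is a sinusoidal projection followed by a small MLP. A hedged sketch using the public diffusers layers (added for illustration, assuming a diffusers version compatible with this repo):

import torch
from diffusers.models.embeddings import Timesteps, TimestepEmbedding

block_out_channels = (320, 640, 1280, 1280)
time_embed_dim = block_out_channels[0] * 4               # 1280, as computed above for "positional"

time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(block_out_channels[0], time_embed_dim, act_fn="silu")

timesteps = torch.tensor([981, 961])                     # a batch of two diffusion timesteps
t_emb = time_proj(timesteps)                             # (2, 320) sinusoidal features
emb = time_embedding(t_emb)                              # (2, 1280) embedding added to every ResNet block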
322
-
323
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
324
- encoder_hid_dim_type = "text_proj"
325
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
326
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
327
-
328
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
329
- raise ValueError(
330
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
331
- )
332
-
333
- if encoder_hid_dim_type == "text_proj":
334
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
335
- elif encoder_hid_dim_type == "text_image_proj":
336
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
337
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
338
- # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
339
- self.encoder_hid_proj = TextImageProjection(
340
- text_embed_dim=encoder_hid_dim,
341
- image_embed_dim=cross_attention_dim,
342
- cross_attention_dim=cross_attention_dim,
343
- )
344
- elif encoder_hid_dim_type == "image_proj":
345
- # Kandinsky 2.2
346
- self.encoder_hid_proj = ImageProjection(
347
- image_embed_dim=encoder_hid_dim,
348
- cross_attention_dim=cross_attention_dim,
349
- )
350
- elif encoder_hid_dim_type is not None:
351
- raise ValueError(
352
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
353
- )
354
- else:
355
- self.encoder_hid_proj = None
356
-
357
- # class embedding
358
- if class_embed_type is None and num_class_embeds is not None:
359
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
360
- elif class_embed_type == "timestep":
361
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
362
- elif class_embed_type == "identity":
363
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
364
- elif class_embed_type == "projection":
365
- if projection_class_embeddings_input_dim is None:
366
- raise ValueError(
367
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
368
- )
369
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
370
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
371
- # 2. it projects from an arbitrary input dimension.
372
- #
373
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
374
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
375
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
376
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
377
- elif class_embed_type == "simple_projection":
378
- if projection_class_embeddings_input_dim is None:
379
- raise ValueError(
380
- "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
381
- )
382
- self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
383
- else:
384
- self.class_embedding = None
385
-
386
- if addition_embed_type == "text":
387
- if encoder_hid_dim is not None:
388
- text_time_embedding_from_dim = encoder_hid_dim
389
- else:
390
- text_time_embedding_from_dim = cross_attention_dim
391
-
392
- self.add_embedding = TextTimeEmbedding(
393
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
394
- )
395
- elif addition_embed_type == "text_image":
396
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
397
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
398
- # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
399
- self.add_embedding = TextImageTimeEmbedding(
400
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
401
- )
402
- elif addition_embed_type == "text_time":
403
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
404
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
405
- elif addition_embed_type == "image":
406
- # Kandinsky 2.2
407
- self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
408
- elif addition_embed_type == "image_hint":
409
- # Kandinsky 2.2 ControlNet
410
- self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
411
- elif addition_embed_type is not None:
412
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
413
-
414
- if time_embedding_act_fn is None:
415
- self.time_embed_act = None
416
- else:
417
- self.time_embed_act = get_activation(time_embedding_act_fn)
418
-
419
- self.down_blocks = nn.ModuleList([])
420
- self.up_blocks = nn.ModuleList([])
421
-
422
- if isinstance(only_cross_attention, bool):
423
- if mid_block_only_cross_attention is None:
424
- mid_block_only_cross_attention = only_cross_attention
425
-
426
- only_cross_attention = [only_cross_attention] * len(down_block_types)
427
-
428
- if mid_block_only_cross_attention is None:
429
- mid_block_only_cross_attention = False
430
-
431
- if isinstance(num_attention_heads, int):
432
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
433
-
434
- if isinstance(attention_head_dim, int):
435
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
436
-
437
- if isinstance(cross_attention_dim, int):
438
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
439
-
440
- if isinstance(layers_per_block, int):
441
- layers_per_block = [layers_per_block] * len(down_block_types)
442
-
443
- if isinstance(transformer_layers_per_block, int):
444
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
445
-
446
- if class_embeddings_concat:
447
- # The time embeddings are concatenated with the class embeddings. The dimension of the
448
- # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
449
- # regular time embeddings
450
- blocks_time_embed_dim = time_embed_dim * 2
451
- else:
452
- blocks_time_embed_dim = time_embed_dim
453
-
454
- # down
455
- output_channel = block_out_channels[0]
456
- for i, down_block_type in enumerate(down_block_types):
457
- input_channel = output_channel
458
- output_channel = block_out_channels[i]
459
- is_final_block = i == len(block_out_channels) - 1
460
-
461
- down_block = get_down_block(
462
- down_block_type,
463
- num_layers=layers_per_block[i],
464
- transformer_layers_per_block=transformer_layers_per_block[i],
465
- in_channels=input_channel,
466
- out_channels=output_channel,
467
- temb_channels=blocks_time_embed_dim,
468
- add_downsample=not is_final_block,
469
- resnet_eps=norm_eps,
470
- resnet_act_fn=act_fn,
471
- resnet_groups=norm_num_groups,
472
- cross_attention_dim=cross_attention_dim[i],
473
- num_attention_heads=num_attention_heads[i],
474
- downsample_padding=downsample_padding,
475
- dual_cross_attention=dual_cross_attention,
476
- use_linear_projection=use_linear_projection,
477
- only_cross_attention=only_cross_attention[i],
478
- upcast_attention=upcast_attention,
479
- resnet_time_scale_shift=resnet_time_scale_shift,
480
- attention_type=attention_type,
481
- resnet_skip_time_act=resnet_skip_time_act,
482
- resnet_out_scale_factor=resnet_out_scale_factor,
483
- cross_attention_norm=cross_attention_norm,
484
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
485
- dropout=dropout,
486
- )
487
- self.down_blocks.append(down_block)
488
-
489
- # mid
490
- if mid_block_type == "UNetMidBlock2DCrossAttn":
491
- self.mid_block = UNetMidBlock2DCrossAttn(
492
- transformer_layers_per_block=transformer_layers_per_block[-1],
493
- in_channels=block_out_channels[-1],
494
- temb_channels=blocks_time_embed_dim,
495
- dropout=dropout,
496
- resnet_eps=norm_eps,
497
- resnet_act_fn=act_fn,
498
- output_scale_factor=mid_block_scale_factor,
499
- resnet_time_scale_shift=resnet_time_scale_shift,
500
- cross_attention_dim=cross_attention_dim[-1],
501
- num_attention_heads=num_attention_heads[-1],
502
- resnet_groups=norm_num_groups,
503
- dual_cross_attention=dual_cross_attention,
504
- use_linear_projection=use_linear_projection,
505
- upcast_attention=upcast_attention,
506
- attention_type=attention_type,
507
- )
508
- elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
509
- self.mid_block = UNetMidBlock2DSimpleCrossAttn(
510
- in_channels=block_out_channels[-1],
511
- temb_channels=blocks_time_embed_dim,
512
- dropout=dropout,
513
- resnet_eps=norm_eps,
514
- resnet_act_fn=act_fn,
515
- output_scale_factor=mid_block_scale_factor,
516
- cross_attention_dim=cross_attention_dim[-1],
517
- attention_head_dim=attention_head_dim[-1],
518
- resnet_groups=norm_num_groups,
519
- resnet_time_scale_shift=resnet_time_scale_shift,
520
- skip_time_act=resnet_skip_time_act,
521
- only_cross_attention=mid_block_only_cross_attention,
522
- cross_attention_norm=cross_attention_norm,
523
- )
524
- elif mid_block_type == "UNetMidBlock2D":
525
- self.mid_block = UNetMidBlock2D(
526
- in_channels=block_out_channels[-1],
527
- temb_channels=blocks_time_embed_dim,
528
- dropout=dropout,
529
- num_layers=0,
530
- resnet_eps=norm_eps,
531
- resnet_act_fn=act_fn,
532
- output_scale_factor=mid_block_scale_factor,
533
- resnet_groups=norm_num_groups,
534
- resnet_time_scale_shift=resnet_time_scale_shift,
535
- add_attention=False,
536
- )
537
- elif mid_block_type is None:
538
- self.mid_block = None
539
- else:
540
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
541
-
542
- # count how many layers upsample the images
543
- self.num_upsamplers = 0
544
-
545
- # up
546
- reversed_block_out_channels = list(reversed(block_out_channels))
547
- reversed_num_attention_heads = list(reversed(num_attention_heads))
548
- reversed_layers_per_block = list(reversed(layers_per_block))
549
- reversed_cross_attention_dim = list(reversed(cross_attention_dim))
550
- reversed_transformer_layers_per_block = (
551
- list(reversed(transformer_layers_per_block))
552
- if reverse_transformer_layers_per_block is None
553
- else reverse_transformer_layers_per_block
554
- )
555
- only_cross_attention = list(reversed(only_cross_attention))
556
-
557
- output_channel = reversed_block_out_channels[0]
558
- for i, up_block_type in enumerate(up_block_types):
559
- is_final_block = i == len(block_out_channels) - 1
560
-
561
- prev_output_channel = output_channel
562
- output_channel = reversed_block_out_channels[i]
563
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
564
-
565
- # add upsample block for all BUT final layer
566
- if not is_final_block:
567
- add_upsample = True
568
- self.num_upsamplers += 1
569
- else:
570
- add_upsample = False
571
-
572
- up_block = get_up_block(
573
- up_block_type,
574
- num_layers=reversed_layers_per_block[i] + 1,
575
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
576
- in_channels=input_channel,
577
- out_channels=output_channel,
578
- prev_output_channel=prev_output_channel,
579
- temb_channels=blocks_time_embed_dim,
580
- add_upsample=add_upsample,
581
- resnet_eps=norm_eps,
582
- resnet_act_fn=act_fn,
583
- resolution_idx=i,
584
- resnet_groups=norm_num_groups,
585
- cross_attention_dim=reversed_cross_attention_dim[i],
586
- num_attention_heads=reversed_num_attention_heads[i],
587
- dual_cross_attention=dual_cross_attention,
588
- use_linear_projection=use_linear_projection,
589
- only_cross_attention=only_cross_attention[i],
590
- upcast_attention=upcast_attention,
591
- resnet_time_scale_shift=resnet_time_scale_shift,
592
- attention_type=attention_type,
593
- resnet_skip_time_act=resnet_skip_time_act,
594
- resnet_out_scale_factor=resnet_out_scale_factor,
595
- cross_attention_norm=cross_attention_norm,
596
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
597
- dropout=dropout,
598
- )
599
- self.up_blocks.append(up_block)
600
- prev_output_channel = output_channel
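With the default `block_out_channels=(320, 640, 1280, 1280)`, the channel bookkeeping of the loop above resolves as shown below; only the last up block skips the upsampler. A worked example added for clarity (not part of the deleted file):

block_out_channels = (320, 640, 1280, 1280)
reversed_channels = list(reversed(block_out_channels))          # [1280, 1280, 640, 320]

output_channel = reversed_channels[0]
for i in range(len(block_out_channels)):
    prev_output_channel = output_channel
    output_channel = reversed_channels[i]
    input_channel = reversed_channels[min(i + 1, len(block_out_channels) - 1)]
    is_final_block = i == len(block_out_channels) - 1
    print(i, prev_output_channel, output_channel, input_channel, not is_final_block)
# 0 1280 1280 1280 True   (adds an upsampler)
# 1 1280 1280  640 True
# 2 1280  640  320 True
# 3  640  320  320 False  (final block, no upsampler)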
601
-
602
- # out
603
- if norm_num_groups is not None:
604
- self.conv_norm_out = nn.GroupNorm(
605
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
606
- )
607
-
608
- self.conv_act = get_activation(act_fn)
609
-
610
- else:
611
- self.conv_norm_out = None
612
- self.conv_act = None
613
-
614
- conv_out_padding = (conv_out_kernel - 1) // 2
615
- self.conv_out = nn.Conv2d(
616
- block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
617
- )
618
-
619
- if attention_type in ["gated", "gated-text-image"]:
620
- positive_len = 768
621
- if isinstance(cross_attention_dim, int):
622
- positive_len = cross_attention_dim
623
- elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
624
- positive_len = cross_attention_dim[0]
625
-
626
- feature_type = "text-only" if attention_type == "gated" else "text-image"
627
- self.position_net = PositionNet(
628
- positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
629
- )
630
-
631
- @property
632
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
633
- r"""
634
- Returns:
635
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
636
- indexed by its weight name.
637
- """
638
- # set recursively
639
- processors = {}
640
-
641
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
642
- if hasattr(module, "get_processor"):
643
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
644
-
645
- for sub_name, child in module.named_children():
646
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
647
-
648
- return processors
649
-
650
- for name, module in self.named_children():
651
- fn_recursive_add_processors(name, module, processors)
652
-
653
- return processors
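The property above returns a flat dict keyed by module path, and `set_attn_processor` below accepts either one processor for everything or a dict using exactly those keys. A hedged usage sketch (added for illustration; `unet` stands for any already-instantiated model exposing this API, e.g. a hypothetical UNetGarm2DConditionModel instance):

from diffusers.models.attention_processor import AttnProcessor2_0

def switch_to_sdpa(unet):
    # `unet` is hypothetical here: any model exposing attn_processors / set_attn_processor as above.
    procs = unet.attn_processors          # {"down_blocks.0.attentions.0....processor": <proc>, ...}
    print(f"{len(procs)} attention processors")
    # Either one processor for every attention layer:
    unet.set_attn_processor(AttnProcessor2_0())
    # ...or a dict reusing exactly the keys returned by `attn_processors`:
    unet.set_attn_processor({name: AttnProcessor2_0() for name in procs})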
654
-
655
- def set_attn_processor(
656
- self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
657
- ):
658
- r"""
659
- Sets the attention processor to use to compute attention.
660
-
661
- Parameters:
662
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
663
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
664
- for **all** `Attention` layers.
665
-
666
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
667
- processor. This is strongly recommended when setting trainable attention processors.
668
-
669
- """
670
- count = len(self.attn_processors.keys())
671
-
672
- if isinstance(processor, dict) and len(processor) != count:
673
- raise ValueError(
674
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
675
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
676
- )
677
-
678
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
679
- if hasattr(module, "set_processor"):
680
- if not isinstance(processor, dict):
681
- module.set_processor(processor, _remove_lora=_remove_lora)
682
- else:
683
- module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
684
-
685
- for sub_name, child in module.named_children():
686
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
687
-
688
- for name, module in self.named_children():
689
- fn_recursive_attn_processor(name, module, processor)
690
-
691
- def set_default_attn_processor(self):
692
- """
693
- Disables custom attention processors and sets the default attention implementation.
694
- """
695
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
696
- processor = AttnAddedKVProcessor()
697
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
698
- processor = AttnProcessor()
699
- else:
700
- raise ValueError(
701
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
702
- )
703
-
704
- self.set_attn_processor(processor, _remove_lora=True)
705
-
706
- def set_attention_slice(self, slice_size):
707
- r"""
708
- Enable sliced attention computation.
709
-
710
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
711
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
712
-
713
- Args:
714
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
715
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
716
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
717
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
718
- must be a multiple of `slice_size`.
719
- """
720
- sliceable_head_dims = []
721
-
722
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
723
- if hasattr(module, "set_attention_slice"):
724
- sliceable_head_dims.append(module.sliceable_head_dim)
725
-
726
- for child in module.children():
727
- fn_recursive_retrieve_sliceable_dims(child)
728
-
729
- # retrieve number of attention layers
730
- for module in self.children():
731
- fn_recursive_retrieve_sliceable_dims(module)
732
-
733
- num_sliceable_layers = len(sliceable_head_dims)
734
-
735
- if slice_size == "auto":
736
- # half the attention head size is usually a good trade-off between
737
- # speed and memory
738
- slice_size = [dim // 2 for dim in sliceable_head_dims]
739
- elif slice_size == "max":
740
- # make smallest slice possible
741
- slice_size = num_sliceable_layers * [1]
742
-
743
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
744
-
745
- if len(slice_size) != len(sliceable_head_dims):
746
- raise ValueError(
747
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
748
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
749
- )
750
-
751
- for i in range(len(slice_size)):
752
- size = slice_size[i]
753
- dim = sliceable_head_dims[i]
754
- if size is not None and size > dim:
755
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
756
-
757
- # Recursively walk through all the children.
758
- # Any children which exposes the set_attention_slice method
759
- # gets the message
760
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
761
- if hasattr(module, "set_attention_slice"):
762
- module.set_attention_slice(slice_size.pop())
763
-
764
- for child in module.children():
765
- fn_recursive_set_attention_slice(child, slice_size)
766
-
767
- reversed_slice_size = list(reversed(slice_size))
768
- for module in self.children():
769
- fn_recursive_set_attention_slice(module, reversed_slice_size)
770
-
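As a sketch, the slicing interface above can be exercised like this (assuming `unet` is an instance of this model; the integer value is only illustrative and must divide `attention_head_dim`):

    unet.set_attention_slice("auto")  # halve each sliceable head dim, i.e. two slices per layer
    unet.set_attention_slice("max")   # one slice at a time for maximum memory savings
    unet.set_attention_slice(2)       # attention_head_dim // 2 slices per layer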
771
- def _set_gradient_checkpointing(self, module, value=False):
772
- if hasattr(module, "gradient_checkpointing"):
773
- module.gradient_checkpointing = value
774
-
775
- def enable_freeu(self, s1, s2, b1, b2):
776
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
777
-
778
- The suffixes after the scaling factors represent the stage blocks where they are being applied.
779
-
780
- Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
781
- are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
782
-
783
- Args:
784
- s1 (`float`):
785
- Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
786
- mitigate the "oversmoothing effect" in the enhanced denoising process.
787
- s2 (`float`):
788
- Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
789
- mitigate the "oversmoothing effect" in the enhanced denoising process.
790
- b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
791
- b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
792
- """
793
- for i, upsample_block in enumerate(self.up_blocks):
794
- setattr(upsample_block, "s1", s1)
795
- setattr(upsample_block, "s2", s2)
796
- setattr(upsample_block, "b1", b1)
797
- setattr(upsample_block, "b2", b2)
798
-
799
- def disable_freeu(self):
800
- """Disables the FreeU mechanism."""
801
- freeu_keys = {"s1", "s2", "b1", "b2"}
802
- for i, upsample_block in enumerate(self.up_blocks):
803
- for k in freeu_keys:
804
- if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
805
- setattr(upsample_block, k, None)
806
-
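A hedged usage sketch for the FreeU toggles above; the scaling factors below are illustrative placeholders only, since good values depend on the base model (see the FreeU repository referenced in the docstring):

    unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)  # placeholder values, tune per model
    # ... run denoising ...
    unet.disable_freeu()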
807
- def forward(
808
- self,
809
- sample: torch.FloatTensor,
810
- timestep: Union[torch.Tensor, float, int],
811
- encoder_hidden_states: torch.Tensor,
812
- class_labels: Optional[torch.Tensor] = None,
813
- timestep_cond: Optional[torch.Tensor] = None,
814
- attention_mask: Optional[torch.Tensor] = None,
815
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
816
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
817
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
818
- mid_block_additional_residual: Optional[torch.Tensor] = None,
819
- down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
820
- encoder_attention_mask: Optional[torch.Tensor] = None,
821
- return_dict: bool = True,
822
- ) -> Union[UNet2DConditionOutput, Tuple]:
823
- r"""
824
- The [`UNet2DConditionModel`] forward method.
825
-
826
- Args:
827
- sample (`torch.FloatTensor`):
828
- The noisy input tensor with the following shape `(batch, channel, height, width)`.
829
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
830
- encoder_hidden_states (`torch.FloatTensor`):
831
- The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
832
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
833
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
834
- timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
835
- Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
836
- through the `self.time_embedding` layer to obtain the timestep embeddings.
837
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
838
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
839
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
840
- negative values to the attention scores corresponding to "discard" tokens.
841
- cross_attention_kwargs (`dict`, *optional*):
842
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
843
- `self.processor` in
844
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
845
- added_cond_kwargs: (`dict`, *optional*):
846
- A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
847
- are passed along to the UNet blocks.
848
- down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
849
- A tuple of tensors that if specified are added to the residuals of down unet blocks.
850
- mid_block_additional_residual: (`torch.Tensor`, *optional*):
851
- A tensor that if specified is added to the residual of the middle unet block.
852
- encoder_attention_mask (`torch.Tensor`):
853
- A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
854
- `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
855
- which adds large negative values to the attention scores corresponding to "discard" tokens.
856
- return_dict (`bool`, *optional*, defaults to `True`):
857
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
858
- tuple.
859
- cross_attention_kwargs (`dict`, *optional*):
860
- A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
861
- added_cond_kwargs: (`dict`, *optional*):
862
- A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
863
- are passed along to the UNet blocks.
864
- down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
865
- additional residuals to be added to UNet long skip connections from down blocks to up blocks for
866
- example from ControlNet side model(s)
867
- mid_block_additional_residual (`torch.Tensor`, *optional*):
868
- additional residual to be added to UNet mid block output, for example from ControlNet side model
869
- down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
870
- additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
871
-
872
- Returns:
873
- [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
874
- If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
875
- a `tuple` is returned where the first element is the sample tensor.
876
- """
877
- # By default samples have to be at least a multiple of the overall upsampling factor.
878
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
879
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
880
- # on the fly if necessary.
881
- default_overall_up_factor = 2**self.num_upsamplers
882
-
883
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
884
- forward_upsample_size = False
885
- upsample_size = None
886
-
887
- for dim in sample.shape[-2:]:
888
- if dim % default_overall_up_factor != 0:
889
- # Forward upsample size to force interpolation output size.
890
- forward_upsample_size = True
891
- break
892
-
893
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
894
- # expects mask of shape:
895
- # [batch, key_tokens]
896
- # adds singleton query_tokens dimension:
897
- # [batch, 1, key_tokens]
898
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
899
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
900
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
901
- if attention_mask is not None:
902
- # assume that mask is expressed as:
903
- # (1 = keep, 0 = discard)
904
- # convert mask into a bias that can be added to attention scores:
905
- # (keep = +0, discard = -10000.0)
906
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
907
- attention_mask = attention_mask.unsqueeze(1)
908
-
909
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
910
- if encoder_attention_mask is not None:
911
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
912
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
913
-
914
- # 0. center input if necessary
915
- if self.config.center_input_sample:
916
- sample = 2 * sample - 1.0
917
-
918
- # 1. time
919
- timesteps = timestep
920
- if not torch.is_tensor(timesteps):
921
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
922
- # This would be a good case for the `match` statement (Python 3.10+)
923
- is_mps = sample.device.type == "mps"
924
- if isinstance(timestep, float):
925
- dtype = torch.float32 if is_mps else torch.float64
926
- else:
927
- dtype = torch.int32 if is_mps else torch.int64
928
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
929
- elif len(timesteps.shape) == 0:
930
- timesteps = timesteps[None].to(sample.device)
931
-
932
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
933
- timesteps = timesteps.expand(sample.shape[0])
934
-
935
- t_emb = self.time_proj(timesteps)
936
-
937
- # `Timesteps` does not contain any weights and will always return f32 tensors
938
- # but time_embedding might actually be running in fp16. so we need to cast here.
939
- # there might be better ways to encapsulate this.
940
- t_emb = t_emb.to(dtype=sample.dtype)
941
-
942
- emb = self.time_embedding(t_emb, timestep_cond)
943
- aug_emb = None
944
-
945
- if self.class_embedding is not None:
946
- if class_labels is None:
947
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
948
-
949
- if self.config.class_embed_type == "timestep":
950
- class_labels = self.time_proj(class_labels)
951
-
952
- # `Timesteps` does not contain any weights and will always return f32 tensors
953
- # there might be better ways to encapsulate this.
954
- class_labels = class_labels.to(dtype=sample.dtype)
955
-
956
- class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
957
-
958
- if self.config.class_embeddings_concat:
959
- emb = torch.cat([emb, class_emb], dim=-1)
960
- else:
961
- emb = emb + class_emb
962
-
963
- if self.config.addition_embed_type == "text":
964
- aug_emb = self.add_embedding(encoder_hidden_states)
965
- elif self.config.addition_embed_type == "text_image":
966
- # Kandinsky 2.1 - style
967
- if "image_embeds" not in added_cond_kwargs:
968
- raise ValueError(
969
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
970
- )
971
-
972
- image_embs = added_cond_kwargs.get("image_embeds")
973
- text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
974
- aug_emb = self.add_embedding(text_embs, image_embs)
975
- elif self.config.addition_embed_type == "text_time":
976
- # SDXL - style
977
- if "text_embeds" not in added_cond_kwargs:
978
- raise ValueError(
979
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
980
- )
981
- text_embeds = added_cond_kwargs.get("text_embeds")
982
- if "time_ids" not in added_cond_kwargs:
983
- raise ValueError(
984
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
985
- )
986
- time_ids = added_cond_kwargs.get("time_ids")
987
- time_embeds = self.add_time_proj(time_ids.flatten())
988
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
989
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
990
- add_embeds = add_embeds.to(emb.dtype)
991
- aug_emb = self.add_embedding(add_embeds)
992
- elif self.config.addition_embed_type == "image":
993
- # Kandinsky 2.2 - style
994
- if "image_embeds" not in added_cond_kwargs:
995
- raise ValueError(
996
- f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
997
- )
998
- image_embs = added_cond_kwargs.get("image_embeds")
999
- aug_emb = self.add_embedding(image_embs)
1000
- elif self.config.addition_embed_type == "image_hint":
1001
- # Kandinsky 2.2 - style
1002
- if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1003
- raise ValueError(
1004
- f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1005
- )
1006
- image_embs = added_cond_kwargs.get("image_embeds")
1007
- hint = added_cond_kwargs.get("hint")
1008
- aug_emb, hint = self.add_embedding(image_embs, hint)
1009
- sample = torch.cat([sample, hint], dim=1)
1010
-
1011
- emb = emb + aug_emb if aug_emb is not None else emb
1012
-
1013
- if self.time_embed_act is not None:
1014
- emb = self.time_embed_act(emb)
1015
-
1016
- if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1017
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1018
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1019
- # Kandinsky 2.1 - style
1020
- if "image_embeds" not in added_cond_kwargs:
1021
- raise ValueError(
1022
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1023
- )
1024
-
1025
- image_embeds = added_cond_kwargs.get("image_embeds")
1026
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1027
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1028
- # Kandinsky 2.2 - style
1029
- if "image_embeds" not in added_cond_kwargs:
1030
- raise ValueError(
1031
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1032
- )
1033
- image_embeds = added_cond_kwargs.get("image_embeds")
1034
- encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1035
- # 2. pre-process
1036
- sample = self.conv_in(sample)
1037
-
1038
- # 2.5 GLIGEN position net
1039
- if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1040
- cross_attention_kwargs = cross_attention_kwargs.copy()
1041
- gligen_args = cross_attention_kwargs.pop("gligen")
1042
- cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1043
-
1044
- # For Vton
1045
- spatial_attn_inputs = []
1046
-
1047
- # 3. down
1048
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1049
- if USE_PEFT_BACKEND:
1050
- # weight the lora layers by setting `lora_scale` for each PEFT layer
1051
- scale_lora_layers(self, lora_scale)
1052
-
1053
- is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1054
- # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1055
- is_adapter = down_intrablock_additional_residuals is not None
1056
- # maintain backward compatibility for legacy usage, where
1057
- # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1058
- # but can only use one or the other
1059
- if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1060
- deprecate(
1061
- "T2I should not use down_block_additional_residuals",
1062
- "1.3.0",
1063
- "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1064
- and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1065
- for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
1066
- standard_warn=False,
1067
- )
1068
- down_intrablock_additional_residuals = down_block_additional_residuals
1069
- is_adapter = True
1070
-
1071
- down_block_res_samples = (sample,)
1072
- for downsample_block in self.down_blocks:
1073
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1074
- # For t2i-adapter CrossAttnDownBlock2D
1075
- additional_residuals = {}
1076
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
1077
- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1078
-
1079
- sample, res_samples, spatial_attn_inputs = downsample_block(
1080
- hidden_states=sample,
1081
- spatial_attn_inputs=spatial_attn_inputs,
1082
- temb=emb,
1083
- encoder_hidden_states=encoder_hidden_states,
1084
- attention_mask=attention_mask,
1085
- cross_attention_kwargs=cross_attention_kwargs,
1086
- encoder_attention_mask=encoder_attention_mask,
1087
- **additional_residuals,
1088
- )
1089
- else:
1090
- sample, res_samples = downsample_block(
1091
- hidden_states=sample,
1092
- temb=emb,
1093
- scale=lora_scale,
1094
- )
1095
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
1096
- sample += down_intrablock_additional_residuals.pop(0)
1097
-
1098
- down_block_res_samples += res_samples
1099
-
1100
- # if is_controlnet:
1101
- # new_down_block_res_samples = ()
1102
-
1103
- # for down_block_res_sample, down_block_additional_residual in zip(
1104
- # down_block_res_samples, down_block_additional_residuals
1105
- # ):
1106
- # down_block_res_sample = down_block_res_sample + down_block_additional_residual
1107
- # new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1108
-
1109
- # down_block_res_samples = new_down_block_res_samples
1110
-
1111
- # 4. mid
1112
- if self.mid_block is not None:
1113
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1114
- sample, spatial_attn_inputs = self.mid_block(
1115
- sample,
1116
- spatial_attn_inputs=spatial_attn_inputs,
1117
- temb=emb,
1118
- encoder_hidden_states=encoder_hidden_states,
1119
- attention_mask=attention_mask,
1120
- cross_attention_kwargs=cross_attention_kwargs,
1121
- encoder_attention_mask=encoder_attention_mask,
1122
- )
1123
- else:
1124
- sample = self.mid_block(sample, emb)
1125
-
1126
- # To support T2I-Adapter-XL
1127
- if (
1128
- is_adapter
1129
- and len(down_intrablock_additional_residuals) > 0
1130
- and sample.shape == down_intrablock_additional_residuals[0].shape
1131
- ):
1132
- sample += down_intrablock_additional_residuals.pop(0)
1133
-
1134
- if is_controlnet:
1135
- sample = sample + mid_block_additional_residual
1136
-
1137
- # 5. up
1138
- for i, upsample_block in enumerate(self.up_blocks):
1139
- is_final_block = i == len(self.up_blocks) - 1
1140
-
1141
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1142
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1143
-
1144
- # if we have not reached the final block and need to forward the
1145
- # upsample size, we do it here
1146
- if not is_final_block and forward_upsample_size:
1147
- upsample_size = down_block_res_samples[-1].shape[2:]
1148
-
1149
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1150
- sample, spatial_attn_inputs = upsample_block(
1151
- hidden_states=sample,
1152
- spatial_attn_inputs=spatial_attn_inputs,
1153
- temb=emb,
1154
- res_hidden_states_tuple=res_samples,
1155
- encoder_hidden_states=encoder_hidden_states,
1156
- cross_attention_kwargs=cross_attention_kwargs,
1157
- upsample_size=upsample_size,
1158
- attention_mask=attention_mask,
1159
- encoder_attention_mask=encoder_attention_mask,
1160
- )
1161
- else:
1162
- sample = upsample_block(
1163
- hidden_states=sample,
1164
- temb=emb,
1165
- res_hidden_states_tuple=res_samples,
1166
- upsample_size=upsample_size,
1167
- scale=lora_scale,
1168
- )
1169
-
1170
- # 6. post-process
1171
- if self.conv_norm_out:
1172
- sample = self.conv_norm_out(sample)
1173
- sample = self.conv_act(sample)
1174
- sample = self.conv_out(sample)
1175
-
1176
- if USE_PEFT_BACKEND:
1177
- # remove `lora_scale` from each PEFT layer
1178
- unscale_lora_layers(self, lora_scale)
1179
-
1180
- if not return_dict:
1181
- return (sample,), spatial_attn_inputs
1182
-
1183
- return UNet2DConditionOutput(sample=sample), spatial_attn_inputs
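Note that, as the code above shows, this forward pass also collects a list of spatial attention features (`spatial_attn_inputs`) while traversing the down, mid, and up blocks and returns it next to the sample. A rough calling sketch, where the latent shape, `text_embeds`, and the surrounding denoising loop are assumptions rather than part of this file:

    import torch

    latents = torch.randn(1, 4, 64, 64, dtype=unet.dtype, device=unet.device)  # hypothetical latent
    out, spatial_attn_inputs = unet(
        sample=latents,
        timestep=10,
        encoder_hidden_states=text_embeds,  # (batch, seq_len, cross_attention_dim), assumed to exist
        return_dict=True,
    )
    noise_pred = out.sample  # UNet2DConditionOutput.sample
    # spatial_attn_inputs holds the per-block attention features collected during this pass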
 
ootd/pipelines_ootd/unet_vton_2d_blocks.py DELETED
The diff for this file is too large to render. See raw diff
 
ootd/pipelines_ootd/unet_vton_2d_condition.py DELETED
@@ -1,1183 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
- from dataclasses import dataclass
17
- from typing import Any, Dict, List, Optional, Tuple, Union
18
-
19
- import torch
20
- import torch.nn as nn
21
- import torch.utils.checkpoint
22
-
23
- from .unet_vton_2d_blocks import (
24
- UNetMidBlock2D,
25
- UNetMidBlock2DCrossAttn,
26
- UNetMidBlock2DSimpleCrossAttn,
27
- get_down_block,
28
- get_up_block,
29
- )
30
-
31
- from diffusers.configuration_utils import ConfigMixin, register_to_config
32
- from diffusers.loaders import UNet2DConditionLoadersMixin
33
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
34
- from diffusers.models.activations import get_activation
35
- from diffusers.models.attention_processor import (
36
- ADDED_KV_ATTENTION_PROCESSORS,
37
- CROSS_ATTENTION_PROCESSORS,
38
- AttentionProcessor,
39
- AttnAddedKVProcessor,
40
- AttnProcessor,
41
- )
42
- from diffusers.models.embeddings import (
43
- GaussianFourierProjection,
44
- ImageHintTimeEmbedding,
45
- ImageProjection,
46
- ImageTimeEmbedding,
47
- PositionNet,
48
- TextImageProjection,
49
- TextImageTimeEmbedding,
50
- TextTimeEmbedding,
51
- TimestepEmbedding,
52
- Timesteps,
53
- )
54
- from diffusers.models.modeling_utils import ModelMixin
55
- # from ..diffusers.src.diffusers.models.unet_2d_blocks import (
56
- # UNetMidBlock2D,
57
- # UNetMidBlock2DCrossAttn,
58
- # UNetMidBlock2DSimpleCrossAttn,
59
- # get_down_block,
60
- # get_up_block,
61
- # )
62
-
63
-
64
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
-
66
-
67
- @dataclass
68
- class UNet2DConditionOutput(BaseOutput):
69
- """
70
- The output of [`UNet2DConditionModel`].
71
-
72
- Args:
73
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
74
- The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
75
- """
76
-
77
- sample: torch.FloatTensor = None
78
-
79
-
80
- class UNetVton2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
81
- r"""
82
- A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
83
- shaped output.
84
-
85
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
86
- for all models (such as downloading or saving).
87
-
88
- Parameters:
89
- sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
90
- Height and width of input/output sample.
91
- in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
92
- out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
93
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
94
- flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
95
- Whether to flip the sin to cos in the time embedding.
96
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
97
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
98
- The tuple of downsample blocks to use.
99
- mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
100
- Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
101
- `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
102
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
103
- The tuple of upsample blocks to use.
104
- only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
105
- Whether to include self-attention in the basic transformer blocks, see
106
- [`~models.attention.BasicTransformerBlock`].
107
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
108
- The tuple of output channels for each block.
109
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
110
- downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
111
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
112
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
113
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
114
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
115
- If `None`, normalization and activation layers are skipped in post-processing.
116
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
117
- cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
118
- The dimension of the cross attention features.
119
- transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
120
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
121
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
122
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
123
- reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
124
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
125
- blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
126
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
127
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
128
- encoder_hid_dim (`int`, *optional*, defaults to None):
129
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
130
- dimension to `cross_attention_dim`.
131
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
132
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
133
- embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
134
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
135
- num_attention_heads (`int`, *optional*):
136
- The number of attention heads. If not defined, defaults to `attention_head_dim`
137
- resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
138
- for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
139
- class_embed_type (`str`, *optional*, defaults to `None`):
140
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
141
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
142
- addition_embed_type (`str`, *optional*, defaults to `None`):
143
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
144
- "text". "text" will use the `TextTimeEmbedding` layer.
145
- addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
146
- Dimension for the timestep embeddings.
147
- num_class_embeds (`int`, *optional*, defaults to `None`):
148
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
149
- class conditioning with `class_embed_type` equal to `None`.
150
- time_embedding_type (`str`, *optional*, defaults to `positional`):
151
- The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
152
- time_embedding_dim (`int`, *optional*, defaults to `None`):
153
- An optional override for the dimension of the projected time embedding.
154
- time_embedding_act_fn (`str`, *optional*, defaults to `None`):
155
- Optional activation function to use only once on the time embeddings before they are passed to the rest of
156
- the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
157
- timestep_post_act (`str`, *optional*, defaults to `None`):
158
- The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
159
- time_cond_proj_dim (`int`, *optional*, defaults to `None`):
160
- The dimension of `cond_proj` layer in the timestep embedding.
161
- conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
162
- conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
163
- projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
164
- `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
165
- class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
166
- embeddings with the class embeddings.
167
- mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
168
- Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
169
- `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
170
- `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
171
- otherwise.
172
- """
173
-
174
- _supports_gradient_checkpointing = True
175
-
176
- @register_to_config
177
- def __init__(
178
- self,
179
- sample_size: Optional[int] = None,
180
- in_channels: int = 4,
181
- out_channels: int = 4,
182
- center_input_sample: bool = False,
183
- flip_sin_to_cos: bool = True,
184
- freq_shift: int = 0,
185
- down_block_types: Tuple[str] = (
186
- "CrossAttnDownBlock2D",
187
- "CrossAttnDownBlock2D",
188
- "CrossAttnDownBlock2D",
189
- "DownBlock2D",
190
- ),
191
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
192
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
193
- only_cross_attention: Union[bool, Tuple[bool]] = False,
194
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
195
- layers_per_block: Union[int, Tuple[int]] = 2,
196
- downsample_padding: int = 1,
197
- mid_block_scale_factor: float = 1,
198
- dropout: float = 0.0,
199
- act_fn: str = "silu",
200
- norm_num_groups: Optional[int] = 32,
201
- norm_eps: float = 1e-5,
202
- cross_attention_dim: Union[int, Tuple[int]] = 1280,
203
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
204
- reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
205
- encoder_hid_dim: Optional[int] = None,
206
- encoder_hid_dim_type: Optional[str] = None,
207
- attention_head_dim: Union[int, Tuple[int]] = 8,
208
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
209
- dual_cross_attention: bool = False,
210
- use_linear_projection: bool = False,
211
- class_embed_type: Optional[str] = None,
212
- addition_embed_type: Optional[str] = None,
213
- addition_time_embed_dim: Optional[int] = None,
214
- num_class_embeds: Optional[int] = None,
215
- upcast_attention: bool = False,
216
- resnet_time_scale_shift: str = "default",
217
- resnet_skip_time_act: bool = False,
218
- resnet_out_scale_factor: int = 1.0,
219
- time_embedding_type: str = "positional",
220
- time_embedding_dim: Optional[int] = None,
221
- time_embedding_act_fn: Optional[str] = None,
222
- timestep_post_act: Optional[str] = None,
223
- time_cond_proj_dim: Optional[int] = None,
224
- conv_in_kernel: int = 3,
225
- conv_out_kernel: int = 3,
226
- projection_class_embeddings_input_dim: Optional[int] = None,
227
- attention_type: str = "default",
228
- class_embeddings_concat: bool = False,
229
- mid_block_only_cross_attention: Optional[bool] = None,
230
- cross_attention_norm: Optional[str] = None,
231
- addition_embed_type_num_heads=64,
232
- ):
233
- super().__init__()
234
-
235
- self.sample_size = sample_size
236
-
237
- if num_attention_heads is not None:
238
- raise ValueError(
239
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
240
- )
241
-
242
- # If `num_attention_heads` is not defined (which is the case for most models)
243
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
244
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
245
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
246
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
247
- # which is why we correct for the naming here.
248
- num_attention_heads = num_attention_heads or attention_head_dim
249
-
250
- # Check inputs
251
- if len(down_block_types) != len(up_block_types):
252
- raise ValueError(
253
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
254
- )
255
-
256
- if len(block_out_channels) != len(down_block_types):
257
- raise ValueError(
258
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
259
- )
260
-
261
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
262
- raise ValueError(
263
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
264
- )
265
-
266
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
267
- raise ValueError(
268
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
269
- )
270
-
271
- if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
272
- raise ValueError(
273
- f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
274
- )
275
-
276
- if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
277
- raise ValueError(
278
- f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
279
- )
280
-
281
- if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
282
- raise ValueError(
283
- f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
284
- )
285
- if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
286
- for layer_number_per_block in transformer_layers_per_block:
287
- if isinstance(layer_number_per_block, list):
288
- raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
289
-
290
- # input
291
- conv_in_padding = (conv_in_kernel - 1) // 2
292
- self.conv_in = nn.Conv2d(
293
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
294
- )
295
-
296
- # time
297
- if time_embedding_type == "fourier":
298
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
299
- if time_embed_dim % 2 != 0:
300
- raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
301
- self.time_proj = GaussianFourierProjection(
302
- time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
303
- )
304
- timestep_input_dim = time_embed_dim
305
- elif time_embedding_type == "positional":
306
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
307
-
308
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
309
- timestep_input_dim = block_out_channels[0]
310
- else:
311
- raise ValueError(
312
- f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
313
- )
314
-
315
- self.time_embedding = TimestepEmbedding(
316
- timestep_input_dim,
317
- time_embed_dim,
318
- act_fn=act_fn,
319
- post_act_fn=timestep_post_act,
320
- cond_proj_dim=time_cond_proj_dim,
321
- )
322
-
323
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
324
- encoder_hid_dim_type = "text_proj"
325
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
326
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
327
-
328
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
329
- raise ValueError(
330
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
331
- )
332
-
333
- if encoder_hid_dim_type == "text_proj":
334
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
335
- elif encoder_hid_dim_type == "text_image_proj":
336
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
337
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
338
- # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
339
- self.encoder_hid_proj = TextImageProjection(
340
- text_embed_dim=encoder_hid_dim,
341
- image_embed_dim=cross_attention_dim,
342
- cross_attention_dim=cross_attention_dim,
343
- )
344
- elif encoder_hid_dim_type == "image_proj":
345
- # Kandinsky 2.2
346
- self.encoder_hid_proj = ImageProjection(
347
- image_embed_dim=encoder_hid_dim,
348
- cross_attention_dim=cross_attention_dim,
349
- )
350
- elif encoder_hid_dim_type is not None:
351
- raise ValueError(
352
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
353
- )
354
- else:
355
- self.encoder_hid_proj = None
356
-
357
- # class embedding
358
- if class_embed_type is None and num_class_embeds is not None:
359
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
360
- elif class_embed_type == "timestep":
361
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
362
- elif class_embed_type == "identity":
363
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
364
- elif class_embed_type == "projection":
365
- if projection_class_embeddings_input_dim is None:
366
- raise ValueError(
367
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
368
- )
369
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
370
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
371
- # 2. it projects from an arbitrary input dimension.
372
- #
373
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
374
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
375
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
376
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
377
- elif class_embed_type == "simple_projection":
378
- if projection_class_embeddings_input_dim is None:
379
- raise ValueError(
380
- "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
381
- )
382
- self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
383
- else:
384
- self.class_embedding = None
385
-
386
- if addition_embed_type == "text":
387
- if encoder_hid_dim is not None:
388
- text_time_embedding_from_dim = encoder_hid_dim
389
- else:
390
- text_time_embedding_from_dim = cross_attention_dim
391
-
392
- self.add_embedding = TextTimeEmbedding(
393
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
394
- )
395
- elif addition_embed_type == "text_image":
396
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
397
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
398
- # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
399
- self.add_embedding = TextImageTimeEmbedding(
400
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
401
- )
402
- elif addition_embed_type == "text_time":
403
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
404
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
405
- elif addition_embed_type == "image":
406
- # Kandinsky 2.2
407
- self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
408
- elif addition_embed_type == "image_hint":
409
- # Kandinsky 2.2 ControlNet
410
- self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
411
- elif addition_embed_type is not None:
412
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
413
-
414
- if time_embedding_act_fn is None:
415
- self.time_embed_act = None
416
- else:
417
- self.time_embed_act = get_activation(time_embedding_act_fn)
418
-
419
- self.down_blocks = nn.ModuleList([])
420
- self.up_blocks = nn.ModuleList([])
421
-
422
- if isinstance(only_cross_attention, bool):
423
- if mid_block_only_cross_attention is None:
424
- mid_block_only_cross_attention = only_cross_attention
425
-
426
- only_cross_attention = [only_cross_attention] * len(down_block_types)
427
-
428
- if mid_block_only_cross_attention is None:
429
- mid_block_only_cross_attention = False
430
-
431
- if isinstance(num_attention_heads, int):
432
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
433
-
434
- if isinstance(attention_head_dim, int):
435
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
436
-
437
- if isinstance(cross_attention_dim, int):
438
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
439
-
440
- if isinstance(layers_per_block, int):
441
- layers_per_block = [layers_per_block] * len(down_block_types)
442
-
443
- if isinstance(transformer_layers_per_block, int):
444
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
445
-
446
- if class_embeddings_concat:
447
- # The time embeddings are concatenated with the class embeddings. The dimension of the
448
- # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
449
- # regular time embeddings
450
- blocks_time_embed_dim = time_embed_dim * 2
451
- else:
452
- blocks_time_embed_dim = time_embed_dim
453
-
454
- # down
455
- output_channel = block_out_channels[0]
456
- for i, down_block_type in enumerate(down_block_types):
457
- input_channel = output_channel
458
- output_channel = block_out_channels[i]
459
- is_final_block = i == len(block_out_channels) - 1
460
-
461
- down_block = get_down_block(
462
- down_block_type,
463
- num_layers=layers_per_block[i],
464
- transformer_layers_per_block=transformer_layers_per_block[i],
465
- in_channels=input_channel,
466
- out_channels=output_channel,
467
- temb_channels=blocks_time_embed_dim,
468
- add_downsample=not is_final_block,
469
- resnet_eps=norm_eps,
470
- resnet_act_fn=act_fn,
471
- resnet_groups=norm_num_groups,
472
- cross_attention_dim=cross_attention_dim[i],
473
- num_attention_heads=num_attention_heads[i],
474
- downsample_padding=downsample_padding,
475
- dual_cross_attention=dual_cross_attention,
476
- use_linear_projection=use_linear_projection,
477
- only_cross_attention=only_cross_attention[i],
478
- upcast_attention=upcast_attention,
479
- resnet_time_scale_shift=resnet_time_scale_shift,
480
- attention_type=attention_type,
481
- resnet_skip_time_act=resnet_skip_time_act,
482
- resnet_out_scale_factor=resnet_out_scale_factor,
483
- cross_attention_norm=cross_attention_norm,
484
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
485
- dropout=dropout,
486
- )
487
- self.down_blocks.append(down_block)
488
-
489
- # mid
490
- if mid_block_type == "UNetMidBlock2DCrossAttn":
491
- self.mid_block = UNetMidBlock2DCrossAttn(
492
- transformer_layers_per_block=transformer_layers_per_block[-1],
493
- in_channels=block_out_channels[-1],
494
- temb_channels=blocks_time_embed_dim,
495
- dropout=dropout,
496
- resnet_eps=norm_eps,
497
- resnet_act_fn=act_fn,
498
- output_scale_factor=mid_block_scale_factor,
499
- resnet_time_scale_shift=resnet_time_scale_shift,
500
- cross_attention_dim=cross_attention_dim[-1],
501
- num_attention_heads=num_attention_heads[-1],
502
- resnet_groups=norm_num_groups,
503
- dual_cross_attention=dual_cross_attention,
504
- use_linear_projection=use_linear_projection,
505
- upcast_attention=upcast_attention,
506
- attention_type=attention_type,
507
- )
508
- elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
509
- self.mid_block = UNetMidBlock2DSimpleCrossAttn(
510
- in_channels=block_out_channels[-1],
511
- temb_channels=blocks_time_embed_dim,
512
- dropout=dropout,
513
- resnet_eps=norm_eps,
514
- resnet_act_fn=act_fn,
515
- output_scale_factor=mid_block_scale_factor,
516
- cross_attention_dim=cross_attention_dim[-1],
517
- attention_head_dim=attention_head_dim[-1],
518
- resnet_groups=norm_num_groups,
519
- resnet_time_scale_shift=resnet_time_scale_shift,
520
- skip_time_act=resnet_skip_time_act,
521
- only_cross_attention=mid_block_only_cross_attention,
522
- cross_attention_norm=cross_attention_norm,
523
- )
524
- elif mid_block_type == "UNetMidBlock2D":
525
- self.mid_block = UNetMidBlock2D(
526
- in_channels=block_out_channels[-1],
527
- temb_channels=blocks_time_embed_dim,
528
- dropout=dropout,
529
- num_layers=0,
530
- resnet_eps=norm_eps,
531
- resnet_act_fn=act_fn,
532
- output_scale_factor=mid_block_scale_factor,
533
- resnet_groups=norm_num_groups,
534
- resnet_time_scale_shift=resnet_time_scale_shift,
535
- add_attention=False,
536
- )
537
- elif mid_block_type is None:
538
- self.mid_block = None
539
- else:
540
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
541
-
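For orientation, the branch above accepts exactly these values for the `mid_block_type` config entry (a sketch; the name `MID_BLOCK_CHOICES` is ours, while the strings map one-to-one to the classes imported from `unet_vton_2d_blocks`):

    MID_BLOCK_CHOICES = [
        "UNetMidBlock2DCrossAttn",        # default: cross-attention mid block
        "UNetMidBlock2DSimpleCrossAttn",  # simple cross-attention variant
        "UNetMidBlock2D",                 # attention-free mid block
        None,                             # skip the mid block entirely
    ]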
542
- # count how many layers upsample the images
543
- self.num_upsamplers = 0
544
-
545
- # up
546
- reversed_block_out_channels = list(reversed(block_out_channels))
547
- reversed_num_attention_heads = list(reversed(num_attention_heads))
548
- reversed_layers_per_block = list(reversed(layers_per_block))
549
- reversed_cross_attention_dim = list(reversed(cross_attention_dim))
550
- reversed_transformer_layers_per_block = (
551
- list(reversed(transformer_layers_per_block))
552
- if reverse_transformer_layers_per_block is None
553
- else reverse_transformer_layers_per_block
554
- )
555
- only_cross_attention = list(reversed(only_cross_attention))
556
-
557
- output_channel = reversed_block_out_channels[0]
558
- for i, up_block_type in enumerate(up_block_types):
559
- is_final_block = i == len(block_out_channels) - 1
560
-
561
- prev_output_channel = output_channel
562
- output_channel = reversed_block_out_channels[i]
563
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
564
-
565
- # add upsample block for all BUT final layer
566
- if not is_final_block:
567
- add_upsample = True
568
- self.num_upsamplers += 1
569
- else:
570
- add_upsample = False
571
-
572
- up_block = get_up_block(
573
- up_block_type,
574
- num_layers=reversed_layers_per_block[i] + 1,
575
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
576
- in_channels=input_channel,
577
- out_channels=output_channel,
578
- prev_output_channel=prev_output_channel,
579
- temb_channels=blocks_time_embed_dim,
580
- add_upsample=add_upsample,
581
- resnet_eps=norm_eps,
582
- resnet_act_fn=act_fn,
583
- resolution_idx=i,
584
- resnet_groups=norm_num_groups,
585
- cross_attention_dim=reversed_cross_attention_dim[i],
586
- num_attention_heads=reversed_num_attention_heads[i],
587
- dual_cross_attention=dual_cross_attention,
588
- use_linear_projection=use_linear_projection,
589
- only_cross_attention=only_cross_attention[i],
590
- upcast_attention=upcast_attention,
591
- resnet_time_scale_shift=resnet_time_scale_shift,
592
- attention_type=attention_type,
593
- resnet_skip_time_act=resnet_skip_time_act,
594
- resnet_out_scale_factor=resnet_out_scale_factor,
595
- cross_attention_norm=cross_attention_norm,
596
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
597
- dropout=dropout,
598
- )
599
- self.up_blocks.append(up_block)
600
- prev_output_channel = output_channel
601
-
602
- # out
603
- if norm_num_groups is not None:
604
- self.conv_norm_out = nn.GroupNorm(
605
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
606
- )
607
-
608
- self.conv_act = get_activation(act_fn)
609
-
610
- else:
611
- self.conv_norm_out = None
612
- self.conv_act = None
613
-
614
- conv_out_padding = (conv_out_kernel - 1) // 2
615
- self.conv_out = nn.Conv2d(
616
- block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
617
- )
618
-
619
- if attention_type in ["gated", "gated-text-image"]:
620
- positive_len = 768
621
- if isinstance(cross_attention_dim, int):
622
- positive_len = cross_attention_dim
623
- elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
624
- positive_len = cross_attention_dim[0]
625
-
626
- feature_type = "text-only" if attention_type == "gated" else "text-image"
627
- self.position_net = PositionNet(
628
- positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
629
- )
630
-
631
- @property
632
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
633
- r"""
634
- Returns:
635
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
636
- indexed by its weight name.
637
- """
638
- # set recursively
639
- processors = {}
640
-
641
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
642
- if hasattr(module, "get_processor"):
643
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
644
-
645
- for sub_name, child in module.named_children():
646
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
647
-
648
- return processors
649
-
650
- for name, module in self.named_children():
651
- fn_recursive_add_processors(name, module, processors)
652
-
653
- return processors
654
-
655
- def set_attn_processor(
656
- self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
657
- ):
658
- r"""
659
- Sets the attention processor to use to compute attention.
660
-
661
- Parameters:
662
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
663
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
664
- for **all** `Attention` layers.
665
-
666
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
667
- processor. This is strongly recommended when setting trainable attention processors.
668
-
669
- """
670
- count = len(self.attn_processors.keys())
671
-
672
- if isinstance(processor, dict) and len(processor) != count:
673
- raise ValueError(
674
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
675
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
676
- )
677
-
678
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
679
- if hasattr(module, "set_processor"):
680
- if not isinstance(processor, dict):
681
- module.set_processor(processor, _remove_lora=_remove_lora)
682
- else:
683
- module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
684
-
685
- for sub_name, child in module.named_children():
686
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
687
-
688
- for name, module in self.named_children():
689
- fn_recursive_attn_processor(name, module, processor)
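A hedged usage sketch (assuming `unet` is an instance of this class and that the installed diffusers version exposes `AttnProcessor2_0`):

    from diffusers.models.attention_processor import AttnProcessor2_0

    # Swap every attention layer to the PyTorch-2 SDPA processor in one call...
    unet.set_attn_processor(AttnProcessor2_0())

    # ...or pass a dict keyed on the names from `attn_processors` to choose per layer.
    unet.set_attn_processor({name: AttnProcessor2_0() for name in unet.attn_processors})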
690
-
691
- def set_default_attn_processor(self):
692
- """
693
- Disables custom attention processors and sets the default attention implementation.
694
- """
695
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
696
- processor = AttnAddedKVProcessor()
697
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
698
- processor = AttnProcessor()
699
- else:
700
- raise ValueError(
701
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
702
- )
703
-
704
- self.set_attn_processor(processor, _remove_lora=True)
705
-
706
- def set_attention_slice(self, slice_size):
707
- r"""
708
- Enable sliced attention computation.
709
-
710
- When this option is enabled, the attention module splits the input tensor into slices to compute attention in
711
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
712
-
713
- Args:
714
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
715
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
716
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
717
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
718
- must be a multiple of `slice_size`.
719
- """
720
- sliceable_head_dims = []
721
-
722
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
723
- if hasattr(module, "set_attention_slice"):
724
- sliceable_head_dims.append(module.sliceable_head_dim)
725
-
726
- for child in module.children():
727
- fn_recursive_retrieve_sliceable_dims(child)
728
-
729
- # retrieve number of attention layers
730
- for module in self.children():
731
- fn_recursive_retrieve_sliceable_dims(module)
732
-
733
- num_sliceable_layers = len(sliceable_head_dims)
734
-
735
- if slice_size == "auto":
736
- # half the attention head size is usually a good trade-off between
737
- # speed and memory
738
- slice_size = [dim // 2 for dim in sliceable_head_dims]
739
- elif slice_size == "max":
740
- # make smallest slice possible
741
- slice_size = num_sliceable_layers * [1]
742
-
743
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
744
-
745
- if len(slice_size) != len(sliceable_head_dims):
746
- raise ValueError(
747
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
748
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
749
- )
750
-
751
- for i in range(len(slice_size)):
752
- size = slice_size[i]
753
- dim = sliceable_head_dims[i]
754
- if size is not None and size > dim:
755
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
756
-
757
- # Recursively walk through all the children.
758
- # Any children which exposes the set_attention_slice method
759
- # gets the message
760
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
761
- if hasattr(module, "set_attention_slice"):
762
- module.set_attention_slice(slice_size.pop())
763
-
764
- for child in module.children():
765
- fn_recursive_set_attention_slice(child, slice_size)
766
-
767
- reversed_slice_size = list(reversed(slice_size))
768
- for module in self.children():
769
- fn_recursive_set_attention_slice(module, reversed_slice_size)
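A short sketch of the three accepted `slice_size` forms (assuming `unet` is an instance of this class):

    unet.set_attention_slice("auto")  # halve each sliceable head dim: attention runs in two steps
    unet.set_attention_slice("max")   # one slice at a time: slowest, most memory-frugal
    unet.set_attention_slice(2)       # fixed size; each sliceable head dim must be a multiple of 2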
770
-
771
- def _set_gradient_checkpointing(self, module, value=False):
772
- if hasattr(module, "gradient_checkpointing"):
773
- module.gradient_checkpointing = value
774
-
775
- def enable_freeu(self, s1, s2, b1, b2):
776
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
777
-
778
- The suffixes after the scaling factors represent the stage blocks where they are being applied.
779
-
780
- Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
781
- are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
782
-
783
- Args:
784
- s1 (`float`):
785
- Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
786
- mitigate the "oversmoothing effect" in the enhanced denoising process.
787
- s2 (`float`):
788
- Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
789
- mitigate the "oversmoothing effect" in the enhanced denoising process.
790
- b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
791
- b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
792
- """
793
- for i, upsample_block in enumerate(self.up_blocks):
794
- setattr(upsample_block, "s1", s1)
795
- setattr(upsample_block, "s2", s2)
796
- setattr(upsample_block, "b1", b1)
797
- setattr(upsample_block, "b2", b2)
798
-
799
- def disable_freeu(self):
800
- """Disables the FreeU mechanism."""
801
- freeu_keys = {"s1", "s2", "b1", "b2"}
802
- for i, upsample_block in enumerate(self.up_blocks):
803
- for k in freeu_keys:
804
- if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
805
- setattr(upsample_block, k, None)
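A hedged example of toggling FreeU (assuming `unet` is an instance of this class; the values below are those suggested in the FreeU repository for Stable Diffusion v1.x and are only a starting point):

    unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)  # damp skip features, amplify backbone
    # ... run inference ...
    unet.disable_freeu()  # restore the stock up-block behaviour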
806
-
807
- def forward(
808
- self,
809
- sample: torch.FloatTensor,
810
- spatial_attn_inputs,
811
- timestep: Union[torch.Tensor, float, int],
812
- encoder_hidden_states: torch.Tensor,
813
- class_labels: Optional[torch.Tensor] = None,
814
- timestep_cond: Optional[torch.Tensor] = None,
815
- attention_mask: Optional[torch.Tensor] = None,
816
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
817
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
818
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
819
- mid_block_additional_residual: Optional[torch.Tensor] = None,
820
- down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
821
- encoder_attention_mask: Optional[torch.Tensor] = None,
822
- return_dict: bool = True,
823
- ) -> Union[UNet2DConditionOutput, Tuple]:
824
- r"""
825
- The [`UNet2DConditionModel`] forward method.
826
-
827
- Args:
828
- sample (`torch.FloatTensor`):
829
- The noisy input tensor with the following shape `(batch, channel, height, width)`.
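spatial_attn_inputs:
    Spatial attention features handed in from the paired garment UNet (an addition to the stock
    diffusers signature); they are consumed block by block, tracked via `spatial_attn_idx`, and
    fused into the attention blocks below.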
830
- timestep (`torch.FloatTensor` or `float` or `int`): The current denoising timestep for the input.
831
- encoder_hidden_states (`torch.FloatTensor`):
832
- The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
833
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
834
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
835
- timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
836
- Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
837
- through the `self.time_embedding` layer to obtain the timestep embeddings.
838
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
839
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
840
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
841
- negative values to the attention scores corresponding to "discard" tokens.
842
- cross_attention_kwargs (`dict`, *optional*):
843
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
844
- `self.processor` in
845
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
846
- added_cond_kwargs: (`dict`, *optional*):
847
- A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
848
- are passed along to the UNet blocks.
853
- encoder_attention_mask (`torch.Tensor`):
854
- A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
855
- `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
856
- which adds large negative values to the attention scores corresponding to "discard" tokens.
857
- return_dict (`bool`, *optional*, defaults to `True`):
858
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
859
- tuple.
860
- cross_attention_kwargs (`dict`, *optional*):
861
- A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
862
- added_cond_kwargs: (`dict`, *optional*):
863
- A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
864
- are passed along to the UNet blocks.
865
- down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
866
- additional residuals to be added to UNet long skip connections from down blocks to up blocks for
867
- example from ControlNet side model(s)
868
- mid_block_additional_residual (`torch.Tensor`, *optional*):
869
- additional residual to be added to UNet mid block output, for example from ControlNet side model
870
- down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
871
- additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
872
-
873
- Returns:
874
- [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
875
- If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
876
- a `tuple` is returned where the first element is the sample tensor.
877
- """
878
- # By default samples have to be at least a multiple of the overall upsampling factor.
879
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
880
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
881
- # on the fly if necessary.
882
- default_overall_up_factor = 2**self.num_upsamplers
883
-
884
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
885
- forward_upsample_size = False
886
- upsample_size = None
887
-
888
- for dim in sample.shape[-2:]:
889
- if dim % default_overall_up_factor != 0:
890
- # Forward upsample size to force interpolation output size.
891
- forward_upsample_size = True
892
- break
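# e.g. with three upsamplers the overall factor is 2**3 = 8, so a latent whose height or
# width is not a multiple of 8 flips `forward_upsample_size`, and the up blocks are later
# handed the exact skip-connection resolution to interpolate back to.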
893
-
894
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
895
- # expects mask of shape:
896
- # [batch, key_tokens]
897
- # adds singleton query_tokens dimension:
898
- # [batch, 1, key_tokens]
899
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
900
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
901
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
902
- if attention_mask is not None:
903
- # assume that mask is expressed as:
904
- # (1 = keep, 0 = discard)
905
- # convert mask into a bias that can be added to attention scores:
906
- # (keep = +0, discard = -10000.0)
907
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
908
- attention_mask = attention_mask.unsqueeze(1)
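# e.g. a keep/discard mask of [1, 1, 0] becomes the additive bias [0.0, 0.0, -10000.0],
# which broadcasts over the query dimension and pushes the masked keys' scores towards
# -inf before the softmax.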
909
-
910
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
911
- if encoder_attention_mask is not None:
912
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
913
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
914
-
915
- # 0. center input if necessary
916
- if self.config.center_input_sample:
917
- sample = 2 * sample - 1.0
918
-
919
- # 1. time
920
- timesteps = timestep
921
- if not torch.is_tensor(timesteps):
922
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
923
- # This would be a good case for the `match` statement (Python 3.10+)
924
- is_mps = sample.device.type == "mps"
925
- if isinstance(timestep, float):
926
- dtype = torch.float32 if is_mps else torch.float64
927
- else:
928
- dtype = torch.int32 if is_mps else torch.int64
929
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
930
- elif len(timesteps.shape) == 0:
931
- timesteps = timesteps[None].to(sample.device)
932
-
933
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
934
- timesteps = timesteps.expand(sample.shape[0])
935
-
936
- t_emb = self.time_proj(timesteps)
937
-
938
- # `Timesteps` does not contain any weights and will always return f32 tensors
939
- # but time_embedding might actually be running in fp16. so we need to cast here.
940
- # there might be better ways to encapsulate this.
941
- t_emb = t_emb.to(dtype=sample.dtype)
942
-
943
- emb = self.time_embedding(t_emb, timestep_cond)
944
- aug_emb = None
945
-
946
- if self.class_embedding is not None:
947
- if class_labels is None:
948
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
949
-
950
- if self.config.class_embed_type == "timestep":
951
- class_labels = self.time_proj(class_labels)
952
-
953
- # `Timesteps` does not contain any weights and will always return f32 tensors
954
- # there might be better ways to encapsulate this.
955
- class_labels = class_labels.to(dtype=sample.dtype)
956
-
957
- class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
958
-
959
- if self.config.class_embeddings_concat:
960
- emb = torch.cat([emb, class_emb], dim=-1)
961
- else:
962
- emb = emb + class_emb
963
-
964
- if self.config.addition_embed_type == "text":
965
- aug_emb = self.add_embedding(encoder_hidden_states)
966
- elif self.config.addition_embed_type == "text_image":
967
- # Kandinsky 2.1 - style
968
- if "image_embeds" not in added_cond_kwargs:
969
- raise ValueError(
970
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
971
- )
972
-
973
- image_embs = added_cond_kwargs.get("image_embeds")
974
- text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
975
- aug_emb = self.add_embedding(text_embs, image_embs)
976
- elif self.config.addition_embed_type == "text_time":
977
- # SDXL - style
978
- if "text_embeds" not in added_cond_kwargs:
979
- raise ValueError(
980
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
981
- )
982
- text_embeds = added_cond_kwargs.get("text_embeds")
983
- if "time_ids" not in added_cond_kwargs:
984
- raise ValueError(
985
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
986
- )
987
- time_ids = added_cond_kwargs.get("time_ids")
988
- time_embeds = self.add_time_proj(time_ids.flatten())
989
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
990
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
991
- add_embeds = add_embeds.to(emb.dtype)
992
- aug_emb = self.add_embedding(add_embeds)
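# For SDXL-style conditioning, `time_ids` is typically the six integers
# (original_size + crop_top_left + target_size) per sample, which is why it is flattened
# through `add_time_proj` and reshaped back to (batch, -1) before being concatenated
# with `text_embeds`.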
993
- elif self.config.addition_embed_type == "image":
994
- # Kandinsky 2.2 - style
995
- if "image_embeds" not in added_cond_kwargs:
996
- raise ValueError(
997
- f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
998
- )
999
- image_embs = added_cond_kwargs.get("image_embeds")
1000
- aug_emb = self.add_embedding(image_embs)
1001
- elif self.config.addition_embed_type == "image_hint":
1002
- # Kandinsky 2.2 - style
1003
- if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1004
- raise ValueError(
1005
- f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1006
- )
1007
- image_embs = added_cond_kwargs.get("image_embeds")
1008
- hint = added_cond_kwargs.get("hint")
1009
- aug_emb, hint = self.add_embedding(image_embs, hint)
1010
- sample = torch.cat([sample, hint], dim=1)
1011
-
1012
- emb = emb + aug_emb if aug_emb is not None else emb
1013
-
1014
- if self.time_embed_act is not None:
1015
- emb = self.time_embed_act(emb)
1016
-
1017
- if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1018
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1019
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1020
- # Kandinsky 2.1 - style
1021
- if "image_embeds" not in added_cond_kwargs:
1022
- raise ValueError(
1023
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1024
- )
1025
-
1026
- image_embeds = added_cond_kwargs.get("image_embeds")
1027
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1028
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1029
- # Kandinsky 2.2 - style
1030
- if "image_embeds" not in added_cond_kwargs:
1031
- raise ValueError(
1032
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1033
- )
1034
- image_embeds = added_cond_kwargs.get("image_embeds")
1035
- encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1036
- # 2. pre-process
1037
- sample = self.conv_in(sample)
1038
-
1039
- # 2.5 GLIGEN position net
1040
- if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1041
- cross_attention_kwargs = cross_attention_kwargs.copy()
1042
- gligen_args = cross_attention_kwargs.pop("gligen")
1043
- cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1044
-
1045
- # for spatial attention
1046
- spatial_attn_idx = 0
1047
-
1048
- # 3. down
1049
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1050
- if USE_PEFT_BACKEND:
1051
- # weight the lora layers by setting `lora_scale` for each PEFT layer
1052
- scale_lora_layers(self, lora_scale)
1053
-
1054
- is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1055
- # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1056
- is_adapter = down_intrablock_additional_residuals is not None
1057
- # maintain backward compatibility for legacy usage, where
1058
- # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1059
- # but can only use one or the other
1060
- if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1061
- deprecate(
1062
- "T2I should not use down_block_additional_residuals",
1063
- "1.3.0",
1064
- "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1065
- and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1066
- for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1067
- standard_warn=False,
1068
- )
1069
- down_intrablock_additional_residuals = down_block_additional_residuals
1070
- is_adapter = True
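# In current usage ControlNet residuals arrive via `down_block_additional_residuals` /
# `mid_block_additional_residual`, while T2I-Adapter features arrive via
# `down_intrablock_additional_residuals`; the deprecation branch above only rewrites
# legacy calls that still route adapter features through the ControlNet argument.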
1071
-
1072
- down_block_res_samples = (sample,)
1073
- for downsample_block in self.down_blocks:
1074
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1075
- # For t2i-adapter CrossAttnDownBlock2D
1076
- additional_residuals = {}
1077
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
1078
- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1079
-
1080
- sample, res_samples, spatial_attn_inputs, spatial_attn_idx = downsample_block(
1081
- hidden_states=sample,
1082
- spatial_attn_inputs=spatial_attn_inputs,
1083
- spatial_attn_idx=spatial_attn_idx,
1084
- temb=emb,
1085
- encoder_hidden_states=encoder_hidden_states,
1086
- attention_mask=attention_mask,
1087
- cross_attention_kwargs=cross_attention_kwargs,
1088
- encoder_attention_mask=encoder_attention_mask,
1089
- **additional_residuals,
1090
- )
1091
- else:
1092
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
1093
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
1094
- sample += down_intrablock_additional_residuals.pop(0)
1095
-
1096
- down_block_res_samples += res_samples
1097
-
1098
- if is_controlnet:
1099
- new_down_block_res_samples = ()
1100
-
1101
- for down_block_res_sample, down_block_additional_residual in zip(
1102
- down_block_res_samples, down_block_additional_residuals
1103
- ):
1104
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
1105
- new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1106
-
1107
- down_block_res_samples = new_down_block_res_samples
1108
-
1109
- # 4. mid
1110
- if self.mid_block is not None:
1111
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1112
- sample, spatial_attn_inputs, spatial_attn_idx = self.mid_block(
1113
- sample,
1114
- spatial_attn_inputs=spatial_attn_inputs,
1115
- spatial_attn_idx=spatial_attn_idx,
1116
- temb=emb,
1117
- encoder_hidden_states=encoder_hidden_states,
1118
- attention_mask=attention_mask,
1119
- cross_attention_kwargs=cross_attention_kwargs,
1120
- encoder_attention_mask=encoder_attention_mask,
1121
- )
1122
- else:
1123
- sample = self.mid_block(sample, emb)
1124
-
1125
- # To support T2I-Adapter-XL
1126
- if (
1127
- is_adapter
1128
- and len(down_intrablock_additional_residuals) > 0
1129
- and sample.shape == down_intrablock_additional_residuals[0].shape
1130
- ):
1131
- sample += down_intrablock_additional_residuals.pop(0)
1132
-
1133
- if is_controlnet:
1134
- sample = sample + mid_block_additional_residual
1135
-
1136
- # 5. up
1137
- for i, upsample_block in enumerate(self.up_blocks):
1138
- is_final_block = i == len(self.up_blocks) - 1
1139
-
1140
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1141
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1142
-
1143
- # if we have not reached the final block and need to forward the
1144
- # upsample size, we do it here
1145
- if not is_final_block and forward_upsample_size:
1146
- upsample_size = down_block_res_samples[-1].shape[2:]
1147
-
1148
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1149
- sample, spatial_attn_inputs, spatial_attn_idx = upsample_block(
1150
- hidden_states=sample,
1151
- spatial_attn_inputs=spatial_attn_inputs,
1152
- spatial_attn_idx=spatial_attn_idx,
1153
- temb=emb,
1154
- res_hidden_states_tuple=res_samples,
1155
- encoder_hidden_states=encoder_hidden_states,
1156
- cross_attention_kwargs=cross_attention_kwargs,
1157
- upsample_size=upsample_size,
1158
- attention_mask=attention_mask,
1159
- encoder_attention_mask=encoder_attention_mask,
1160
- )
1161
- else:
1162
- sample = upsample_block(
1163
- hidden_states=sample,
1164
- temb=emb,
1165
- res_hidden_states_tuple=res_samples,
1166
- upsample_size=upsample_size,
1167
- scale=lora_scale,
1168
- )
1169
-
1170
- # 6. post-process
1171
- if self.conv_norm_out:
1172
- sample = self.conv_norm_out(sample)
1173
- sample = self.conv_act(sample)
1174
- sample = self.conv_out(sample)
1175
-
1176
- if USE_PEFT_BACKEND:
1177
- # remove `lora_scale` from each PEFT layer
1178
- unscale_lora_layers(self, lora_scale)
1179
-
1180
- if not return_dict:
1181
- return (sample,)
1182
-
1183
- return UNet2DConditionOutput(sample=sample)
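For orientation, a minimal sketch of how this modified forward pass is typically driven (the garment-UNet provenance of `spatial_attn_inputs` and the variable names are assumptions drawn from the surrounding OOTDiffusion code, not from this file):

    # latents: (batch, channels, h, w) noisy sample; prompt_embeds: (batch, seq_len, dim)
    # spatial_attn_inputs: the per-layer spatial attention features produced by the
    # paired garment UNet, fused into this UNet's attention blocks via spatial_attn_idx.
    noise_pred = unet_vton(
        latents,
        spatial_attn_inputs,
        timestep,
        encoder_hidden_states=prompt_embeds,
        return_dict=False,
    )[0]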