Spaces:
Runtime error
Runtime error
Upload 98 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +25 -1
- images/datasets/claysculpture.png +3 -0
- images/datasets/claytoys.png +3 -0
- images/datasets/cook.png +3 -0
- images/datasets/emoji.png +3 -0
- images/datasets/fabrictoys.png +3 -0
- images/datasets/icon.png +3 -0
- images/datasets/illustration.png +3 -0
- images/datasets/inkpainting.png +3 -0
- images/datasets/jadecarving.png +3 -0
- images/datasets/landscape.png +3 -0
- images/datasets/lego.png +3 -0
- images/datasets/linedraw.png +3 -0
- images/datasets/oilpainting.png +3 -0
- images/datasets/painting.png +3 -0
- images/datasets/pencilsketch.png +3 -0
- images/datasets/portrait.png +3 -0
- images/datasets/sandart.png +3 -0
- images/datasets/sketch.png +3 -0
- images/datasets/transformer.png +3 -0
- images/datasets/woodsculpture.png +3 -0
- images/datasets/zbrush.png +3 -0
- images/i2i.png +3 -0
- images/oneshot.png +3 -0
- images/t2i.png +3 -0
- images/teaser.png +3 -0
- library/__init__.py +0 -0
- library/adafactor_fused.py +138 -0
- library/attention_processors.py +227 -0
- library/config_util.py +716 -0
- library/custom_offloading_utils.py +227 -0
- library/custom_train_functions.py +559 -0
- library/deepspeed_utils.py +139 -0
- library/device_utils.py +84 -0
- library/flux_models.py +1237 -0
- library/flux_train_utils.py +582 -0
- library/flux_train_utils_recraft.py +659 -0
- library/flux_utils.py +472 -0
- library/huggingface_util.py +84 -0
- library/hypernetwork.py +223 -0
- library/ipex/__init__.py +180 -0
- library/ipex/attention.py +177 -0
- library/ipex/diffusers.py +312 -0
- library/ipex/gradscaler.py +183 -0
- library/ipex/hijacks.py +313 -0
- library/lpw_stable_diffusion.py +1233 -0
- library/model_util.py +1356 -0
- library/original_unet.py +1919 -0
- library/sai_model_spec.py +334 -0
- library/sd3_models.py +1413 -0
.gitattributes
CHANGED
@@ -43,4 +43,28 @@ asy_results
|
|
43 |
recraft_results
|
44 |
drop
|
45 |
SplitAsy
|
46 |
-
example*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
recraft_results
|
44 |
drop
|
45 |
SplitAsy
|
46 |
+
example*images/datasets/claysculpture.png filter=lfs diff=lfs merge=lfs -text
|
47 |
+
images/datasets/claytoys.png filter=lfs diff=lfs merge=lfs -text
|
48 |
+
images/datasets/cook.png filter=lfs diff=lfs merge=lfs -text
|
49 |
+
images/datasets/emoji.png filter=lfs diff=lfs merge=lfs -text
|
50 |
+
images/datasets/fabrictoys.png filter=lfs diff=lfs merge=lfs -text
|
51 |
+
images/datasets/icon.png filter=lfs diff=lfs merge=lfs -text
|
52 |
+
images/datasets/illustration.png filter=lfs diff=lfs merge=lfs -text
|
53 |
+
images/datasets/inkpainting.png filter=lfs diff=lfs merge=lfs -text
|
54 |
+
images/datasets/jadecarving.png filter=lfs diff=lfs merge=lfs -text
|
55 |
+
images/datasets/landscape.png filter=lfs diff=lfs merge=lfs -text
|
56 |
+
images/datasets/lego.png filter=lfs diff=lfs merge=lfs -text
|
57 |
+
images/datasets/linedraw.png filter=lfs diff=lfs merge=lfs -text
|
58 |
+
images/datasets/oilpainting.png filter=lfs diff=lfs merge=lfs -text
|
59 |
+
images/datasets/painting.png filter=lfs diff=lfs merge=lfs -text
|
60 |
+
images/datasets/pencilsketch.png filter=lfs diff=lfs merge=lfs -text
|
61 |
+
images/datasets/portrait.png filter=lfs diff=lfs merge=lfs -text
|
62 |
+
images/datasets/sandart.png filter=lfs diff=lfs merge=lfs -text
|
63 |
+
images/datasets/sketch.png filter=lfs diff=lfs merge=lfs -text
|
64 |
+
images/datasets/transformer.png filter=lfs diff=lfs merge=lfs -text
|
65 |
+
images/datasets/woodsculpture.png filter=lfs diff=lfs merge=lfs -text
|
66 |
+
images/datasets/zbrush.png filter=lfs diff=lfs merge=lfs -text
|
67 |
+
images/i2i.png filter=lfs diff=lfs merge=lfs -text
|
68 |
+
images/oneshot.png filter=lfs diff=lfs merge=lfs -text
|
69 |
+
images/t2i.png filter=lfs diff=lfs merge=lfs -text
|
70 |
+
images/teaser.png filter=lfs diff=lfs merge=lfs -text
|
images/datasets/claysculpture.png
ADDED
![]() |
Git LFS Details
|
images/datasets/claytoys.png
ADDED
![]() |
Git LFS Details
|
images/datasets/cook.png
ADDED
![]() |
Git LFS Details
|
images/datasets/emoji.png
ADDED
![]() |
Git LFS Details
|
images/datasets/fabrictoys.png
ADDED
![]() |
Git LFS Details
|
images/datasets/icon.png
ADDED
![]() |
Git LFS Details
|
images/datasets/illustration.png
ADDED
![]() |
Git LFS Details
|
images/datasets/inkpainting.png
ADDED
![]() |
Git LFS Details
|
images/datasets/jadecarving.png
ADDED
![]() |
Git LFS Details
|
images/datasets/landscape.png
ADDED
![]() |
Git LFS Details
|
images/datasets/lego.png
ADDED
![]() |
Git LFS Details
|
images/datasets/linedraw.png
ADDED
![]() |
Git LFS Details
|
images/datasets/oilpainting.png
ADDED
![]() |
Git LFS Details
|
images/datasets/painting.png
ADDED
![]() |
Git LFS Details
|
images/datasets/pencilsketch.png
ADDED
![]() |
Git LFS Details
|
images/datasets/portrait.png
ADDED
![]() |
Git LFS Details
|
images/datasets/sandart.png
ADDED
![]() |
Git LFS Details
|
images/datasets/sketch.png
ADDED
![]() |
Git LFS Details
|
images/datasets/transformer.png
ADDED
![]() |
Git LFS Details
|
images/datasets/woodsculpture.png
ADDED
![]() |
Git LFS Details
|
images/datasets/zbrush.png
ADDED
![]() |
Git LFS Details
|
images/i2i.png
ADDED
![]() |
Git LFS Details
|
images/oneshot.png
ADDED
![]() |
Git LFS Details
|
images/t2i.png
ADDED
![]() |
Git LFS Details
|
images/teaser.png
ADDED
![]() |
Git LFS Details
|
library/__init__.py
ADDED
File without changes
|
library/adafactor_fused.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from transformers import Adafactor
|
4 |
+
|
5 |
+
# stochastic rounding for bfloat16
|
6 |
+
# The implementation was provided by 2kpr. Thank you very much!
|
7 |
+
|
8 |
+
def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
9 |
+
"""
|
10 |
+
copies source into target using stochastic rounding
|
11 |
+
|
12 |
+
Args:
|
13 |
+
target: the target tensor with dtype=bfloat16
|
14 |
+
source: the target tensor with dtype=float32
|
15 |
+
"""
|
16 |
+
# create a random 16 bit integer
|
17 |
+
result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
|
18 |
+
|
19 |
+
# add the random number to the lower 16 bit of the mantissa
|
20 |
+
result.add_(source.view(dtype=torch.int32))
|
21 |
+
|
22 |
+
# mask off the lower 16 bit of the mantissa
|
23 |
+
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
|
24 |
+
|
25 |
+
# copy the higher 16 bit into the target tensor
|
26 |
+
target.copy_(result.view(dtype=torch.float32))
|
27 |
+
|
28 |
+
del result
|
29 |
+
|
30 |
+
|
31 |
+
@torch.no_grad()
|
32 |
+
def adafactor_step_param(self, p, group):
|
33 |
+
if p.grad is None:
|
34 |
+
return
|
35 |
+
grad = p.grad
|
36 |
+
if grad.dtype in {torch.float16, torch.bfloat16}:
|
37 |
+
grad = grad.float()
|
38 |
+
if grad.is_sparse:
|
39 |
+
raise RuntimeError("Adafactor does not support sparse gradients.")
|
40 |
+
|
41 |
+
state = self.state[p]
|
42 |
+
grad_shape = grad.shape
|
43 |
+
|
44 |
+
factored, use_first_moment = Adafactor._get_options(group, grad_shape)
|
45 |
+
# State Initialization
|
46 |
+
if len(state) == 0:
|
47 |
+
state["step"] = 0
|
48 |
+
|
49 |
+
if use_first_moment:
|
50 |
+
# Exponential moving average of gradient values
|
51 |
+
state["exp_avg"] = torch.zeros_like(grad)
|
52 |
+
if factored:
|
53 |
+
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
|
54 |
+
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
55 |
+
else:
|
56 |
+
state["exp_avg_sq"] = torch.zeros_like(grad)
|
57 |
+
|
58 |
+
state["RMS"] = 0
|
59 |
+
else:
|
60 |
+
if use_first_moment:
|
61 |
+
state["exp_avg"] = state["exp_avg"].to(grad)
|
62 |
+
if factored:
|
63 |
+
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
|
64 |
+
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
|
65 |
+
else:
|
66 |
+
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
|
67 |
+
|
68 |
+
p_data_fp32 = p
|
69 |
+
if p.dtype in {torch.float16, torch.bfloat16}:
|
70 |
+
p_data_fp32 = p_data_fp32.float()
|
71 |
+
|
72 |
+
state["step"] += 1
|
73 |
+
state["RMS"] = Adafactor._rms(p_data_fp32)
|
74 |
+
lr = Adafactor._get_lr(group, state)
|
75 |
+
|
76 |
+
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
77 |
+
update = (grad**2) + group["eps"][0]
|
78 |
+
if factored:
|
79 |
+
exp_avg_sq_row = state["exp_avg_sq_row"]
|
80 |
+
exp_avg_sq_col = state["exp_avg_sq_col"]
|
81 |
+
|
82 |
+
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
|
83 |
+
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
|
84 |
+
|
85 |
+
# Approximation of exponential moving average of square of gradient
|
86 |
+
update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
87 |
+
update.mul_(grad)
|
88 |
+
else:
|
89 |
+
exp_avg_sq = state["exp_avg_sq"]
|
90 |
+
|
91 |
+
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
|
92 |
+
update = exp_avg_sq.rsqrt().mul_(grad)
|
93 |
+
|
94 |
+
update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
95 |
+
update.mul_(lr)
|
96 |
+
|
97 |
+
if use_first_moment:
|
98 |
+
exp_avg = state["exp_avg"]
|
99 |
+
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
|
100 |
+
update = exp_avg
|
101 |
+
|
102 |
+
if group["weight_decay"] != 0:
|
103 |
+
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
|
104 |
+
|
105 |
+
p_data_fp32.add_(-update)
|
106 |
+
|
107 |
+
# if p.dtype in {torch.float16, torch.bfloat16}:
|
108 |
+
# p.copy_(p_data_fp32)
|
109 |
+
|
110 |
+
if p.dtype == torch.bfloat16:
|
111 |
+
copy_stochastic_(p, p_data_fp32)
|
112 |
+
elif p.dtype == torch.float16:
|
113 |
+
p.copy_(p_data_fp32)
|
114 |
+
|
115 |
+
|
116 |
+
@torch.no_grad()
|
117 |
+
def adafactor_step(self, closure=None):
|
118 |
+
"""
|
119 |
+
Performs a single optimization step
|
120 |
+
|
121 |
+
Arguments:
|
122 |
+
closure (callable, optional): A closure that reevaluates the model
|
123 |
+
and returns the loss.
|
124 |
+
"""
|
125 |
+
loss = None
|
126 |
+
if closure is not None:
|
127 |
+
loss = closure()
|
128 |
+
|
129 |
+
for group in self.param_groups:
|
130 |
+
for p in group["params"]:
|
131 |
+
adafactor_step_param(self, p, group)
|
132 |
+
|
133 |
+
return loss
|
134 |
+
|
135 |
+
|
136 |
+
def patch_adafactor_fused(optimizer: Adafactor):
|
137 |
+
optimizer.step_param = adafactor_step_param.__get__(optimizer)
|
138 |
+
optimizer.step = adafactor_step.__get__(optimizer)
|
library/attention_processors.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any
|
3 |
+
from einops import rearrange
|
4 |
+
import torch
|
5 |
+
from diffusers.models.attention_processor import Attention
|
6 |
+
|
7 |
+
|
8 |
+
# flash attention forwards and backwards
|
9 |
+
|
10 |
+
# https://arxiv.org/abs/2205.14135
|
11 |
+
|
12 |
+
EPSILON = 1e-6
|
13 |
+
|
14 |
+
|
15 |
+
class FlashAttentionFunction(torch.autograd.function.Function):
|
16 |
+
@staticmethod
|
17 |
+
@torch.no_grad()
|
18 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
19 |
+
"""Algorithm 2 in the paper"""
|
20 |
+
|
21 |
+
device = q.device
|
22 |
+
dtype = q.dtype
|
23 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
24 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
25 |
+
|
26 |
+
o = torch.zeros_like(q)
|
27 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
28 |
+
all_row_maxes = torch.full(
|
29 |
+
(*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
|
30 |
+
)
|
31 |
+
|
32 |
+
scale = q.shape[-1] ** -0.5
|
33 |
+
|
34 |
+
if mask is None:
|
35 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
36 |
+
else:
|
37 |
+
mask = rearrange(mask, "b n -> b 1 1 n")
|
38 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
39 |
+
|
40 |
+
row_splits = zip(
|
41 |
+
q.split(q_bucket_size, dim=-2),
|
42 |
+
o.split(q_bucket_size, dim=-2),
|
43 |
+
mask,
|
44 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
45 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
46 |
+
)
|
47 |
+
|
48 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
49 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
50 |
+
|
51 |
+
col_splits = zip(
|
52 |
+
k.split(k_bucket_size, dim=-2),
|
53 |
+
v.split(k_bucket_size, dim=-2),
|
54 |
+
)
|
55 |
+
|
56 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
57 |
+
k_start_index = k_ind * k_bucket_size
|
58 |
+
|
59 |
+
attn_weights = (
|
60 |
+
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
61 |
+
)
|
62 |
+
|
63 |
+
if row_mask is not None:
|
64 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
65 |
+
|
66 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
67 |
+
causal_mask = torch.ones(
|
68 |
+
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
|
69 |
+
).triu(q_start_index - k_start_index + 1)
|
70 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
71 |
+
|
72 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
73 |
+
attn_weights -= block_row_maxes
|
74 |
+
exp_weights = torch.exp(attn_weights)
|
75 |
+
|
76 |
+
if row_mask is not None:
|
77 |
+
exp_weights.masked_fill_(~row_mask, 0.0)
|
78 |
+
|
79 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
|
80 |
+
min=EPSILON
|
81 |
+
)
|
82 |
+
|
83 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
84 |
+
|
85 |
+
exp_values = torch.einsum(
|
86 |
+
"... i j, ... j d -> ... i d", exp_weights, vc
|
87 |
+
)
|
88 |
+
|
89 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
90 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
91 |
+
|
92 |
+
new_row_sums = (
|
93 |
+
exp_row_max_diff * row_sums
|
94 |
+
+ exp_block_row_max_diff * block_row_sums
|
95 |
+
)
|
96 |
+
|
97 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
|
98 |
+
(exp_block_row_max_diff / new_row_sums) * exp_values
|
99 |
+
)
|
100 |
+
|
101 |
+
row_maxes.copy_(new_row_maxes)
|
102 |
+
row_sums.copy_(new_row_sums)
|
103 |
+
|
104 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
105 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
106 |
+
|
107 |
+
return o
|
108 |
+
|
109 |
+
@staticmethod
|
110 |
+
@torch.no_grad()
|
111 |
+
def backward(ctx, do):
|
112 |
+
"""Algorithm 4 in the paper"""
|
113 |
+
|
114 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
115 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
116 |
+
|
117 |
+
device = q.device
|
118 |
+
|
119 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
120 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
121 |
+
|
122 |
+
dq = torch.zeros_like(q)
|
123 |
+
dk = torch.zeros_like(k)
|
124 |
+
dv = torch.zeros_like(v)
|
125 |
+
|
126 |
+
row_splits = zip(
|
127 |
+
q.split(q_bucket_size, dim=-2),
|
128 |
+
o.split(q_bucket_size, dim=-2),
|
129 |
+
do.split(q_bucket_size, dim=-2),
|
130 |
+
mask,
|
131 |
+
l.split(q_bucket_size, dim=-2),
|
132 |
+
m.split(q_bucket_size, dim=-2),
|
133 |
+
dq.split(q_bucket_size, dim=-2),
|
134 |
+
)
|
135 |
+
|
136 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
137 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
138 |
+
|
139 |
+
col_splits = zip(
|
140 |
+
k.split(k_bucket_size, dim=-2),
|
141 |
+
v.split(k_bucket_size, dim=-2),
|
142 |
+
dk.split(k_bucket_size, dim=-2),
|
143 |
+
dv.split(k_bucket_size, dim=-2),
|
144 |
+
)
|
145 |
+
|
146 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
147 |
+
k_start_index = k_ind * k_bucket_size
|
148 |
+
|
149 |
+
attn_weights = (
|
150 |
+
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
151 |
+
)
|
152 |
+
|
153 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
154 |
+
causal_mask = torch.ones(
|
155 |
+
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
|
156 |
+
).triu(q_start_index - k_start_index + 1)
|
157 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
158 |
+
|
159 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
160 |
+
|
161 |
+
if row_mask is not None:
|
162 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
163 |
+
|
164 |
+
p = exp_attn_weights / lc
|
165 |
+
|
166 |
+
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
|
167 |
+
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
|
168 |
+
|
169 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
170 |
+
ds = p * scale * (dp - D)
|
171 |
+
|
172 |
+
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
|
173 |
+
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
|
174 |
+
|
175 |
+
dqc.add_(dq_chunk)
|
176 |
+
dkc.add_(dk_chunk)
|
177 |
+
dvc.add_(dv_chunk)
|
178 |
+
|
179 |
+
return dq, dk, dv, None, None, None, None
|
180 |
+
|
181 |
+
|
182 |
+
class FlashAttnProcessor:
|
183 |
+
def __call__(
|
184 |
+
self,
|
185 |
+
attn: Attention,
|
186 |
+
hidden_states,
|
187 |
+
encoder_hidden_states=None,
|
188 |
+
attention_mask=None,
|
189 |
+
) -> Any:
|
190 |
+
q_bucket_size = 512
|
191 |
+
k_bucket_size = 1024
|
192 |
+
|
193 |
+
h = attn.heads
|
194 |
+
q = attn.to_q(hidden_states)
|
195 |
+
|
196 |
+
encoder_hidden_states = (
|
197 |
+
encoder_hidden_states
|
198 |
+
if encoder_hidden_states is not None
|
199 |
+
else hidden_states
|
200 |
+
)
|
201 |
+
encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
|
202 |
+
|
203 |
+
if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
|
204 |
+
context_k, context_v = attn.hypernetwork.forward(
|
205 |
+
hidden_states, encoder_hidden_states
|
206 |
+
)
|
207 |
+
context_k = context_k.to(hidden_states.dtype)
|
208 |
+
context_v = context_v.to(hidden_states.dtype)
|
209 |
+
else:
|
210 |
+
context_k = encoder_hidden_states
|
211 |
+
context_v = encoder_hidden_states
|
212 |
+
|
213 |
+
k = attn.to_k(context_k)
|
214 |
+
v = attn.to_v(context_v)
|
215 |
+
del encoder_hidden_states, hidden_states
|
216 |
+
|
217 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
218 |
+
|
219 |
+
out = FlashAttentionFunction.apply(
|
220 |
+
q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
|
221 |
+
)
|
222 |
+
|
223 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
224 |
+
|
225 |
+
out = attn.to_out[0](out)
|
226 |
+
out = attn.to_out[1](out)
|
227 |
+
return out
|
library/config_util.py
ADDED
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from dataclasses import (
|
3 |
+
asdict,
|
4 |
+
dataclass,
|
5 |
+
)
|
6 |
+
import functools
|
7 |
+
import random
|
8 |
+
from textwrap import dedent, indent
|
9 |
+
import json
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
# from toolz import curry
|
13 |
+
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
14 |
+
|
15 |
+
import toml
|
16 |
+
import voluptuous
|
17 |
+
from voluptuous import (
|
18 |
+
Any,
|
19 |
+
ExactSequence,
|
20 |
+
MultipleInvalid,
|
21 |
+
Object,
|
22 |
+
Required,
|
23 |
+
Schema,
|
24 |
+
)
|
25 |
+
from transformers import CLIPTokenizer
|
26 |
+
|
27 |
+
from . import train_util
|
28 |
+
from .train_util import (
|
29 |
+
DreamBoothSubset,
|
30 |
+
FineTuningSubset,
|
31 |
+
ControlNetSubset,
|
32 |
+
DreamBoothDataset,
|
33 |
+
FineTuningDataset,
|
34 |
+
ControlNetDataset,
|
35 |
+
DatasetGroup,
|
36 |
+
)
|
37 |
+
from .utils import setup_logging
|
38 |
+
|
39 |
+
setup_logging()
|
40 |
+
import logging
|
41 |
+
|
42 |
+
logger = logging.getLogger(__name__)
|
43 |
+
|
44 |
+
|
45 |
+
def add_config_arguments(parser: argparse.ArgumentParser):
|
46 |
+
parser.add_argument(
|
47 |
+
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
# TODO: inherit Params class in Subset, Dataset
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass
|
55 |
+
class BaseSubsetParams:
|
56 |
+
image_dir: Optional[str] = None
|
57 |
+
num_repeats: int = 1
|
58 |
+
shuffle_caption: bool = False
|
59 |
+
caption_separator: str = (",",)
|
60 |
+
keep_tokens: int = 0
|
61 |
+
keep_tokens_separator: str = (None,)
|
62 |
+
secondary_separator: Optional[str] = None
|
63 |
+
enable_wildcard: bool = False
|
64 |
+
color_aug: bool = False
|
65 |
+
flip_aug: bool = False
|
66 |
+
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
67 |
+
random_crop: bool = False
|
68 |
+
caption_prefix: Optional[str] = None
|
69 |
+
caption_suffix: Optional[str] = None
|
70 |
+
caption_dropout_rate: float = 0.0
|
71 |
+
caption_dropout_every_n_epochs: int = 0
|
72 |
+
caption_tag_dropout_rate: float = 0.0
|
73 |
+
token_warmup_min: int = 1
|
74 |
+
token_warmup_step: float = 0
|
75 |
+
custom_attributes: Optional[Dict[str, Any]] = None
|
76 |
+
|
77 |
+
|
78 |
+
@dataclass
|
79 |
+
class DreamBoothSubsetParams(BaseSubsetParams):
|
80 |
+
is_reg: bool = False
|
81 |
+
class_tokens: Optional[str] = None
|
82 |
+
caption_extension: str = ".caption"
|
83 |
+
cache_info: bool = False
|
84 |
+
alpha_mask: bool = False
|
85 |
+
|
86 |
+
|
87 |
+
@dataclass
|
88 |
+
class FineTuningSubsetParams(BaseSubsetParams):
|
89 |
+
metadata_file: Optional[str] = None
|
90 |
+
alpha_mask: bool = False
|
91 |
+
|
92 |
+
|
93 |
+
@dataclass
|
94 |
+
class ControlNetSubsetParams(BaseSubsetParams):
|
95 |
+
conditioning_data_dir: str = None
|
96 |
+
caption_extension: str = ".caption"
|
97 |
+
cache_info: bool = False
|
98 |
+
|
99 |
+
|
100 |
+
@dataclass
|
101 |
+
class BaseDatasetParams:
|
102 |
+
resolution: Optional[Tuple[int, int]] = None
|
103 |
+
network_multiplier: float = 1.0
|
104 |
+
debug_dataset: bool = False
|
105 |
+
|
106 |
+
|
107 |
+
@dataclass
|
108 |
+
class DreamBoothDatasetParams(BaseDatasetParams):
|
109 |
+
batch_size: int = 1
|
110 |
+
enable_bucket: bool = False
|
111 |
+
min_bucket_reso: int = 256
|
112 |
+
max_bucket_reso: int = 1024
|
113 |
+
bucket_reso_steps: int = 64
|
114 |
+
bucket_no_upscale: bool = False
|
115 |
+
prior_loss_weight: float = 1.0
|
116 |
+
|
117 |
+
|
118 |
+
@dataclass
|
119 |
+
class FineTuningDatasetParams(BaseDatasetParams):
|
120 |
+
batch_size: int = 1
|
121 |
+
enable_bucket: bool = False
|
122 |
+
min_bucket_reso: int = 256
|
123 |
+
max_bucket_reso: int = 1024
|
124 |
+
bucket_reso_steps: int = 64
|
125 |
+
bucket_no_upscale: bool = False
|
126 |
+
|
127 |
+
|
128 |
+
@dataclass
|
129 |
+
class ControlNetDatasetParams(BaseDatasetParams):
|
130 |
+
batch_size: int = 1
|
131 |
+
enable_bucket: bool = False
|
132 |
+
min_bucket_reso: int = 256
|
133 |
+
max_bucket_reso: int = 1024
|
134 |
+
bucket_reso_steps: int = 64
|
135 |
+
bucket_no_upscale: bool = False
|
136 |
+
|
137 |
+
|
138 |
+
@dataclass
|
139 |
+
class SubsetBlueprint:
|
140 |
+
params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
|
141 |
+
|
142 |
+
|
143 |
+
@dataclass
|
144 |
+
class DatasetBlueprint:
|
145 |
+
is_dreambooth: bool
|
146 |
+
is_controlnet: bool
|
147 |
+
params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
|
148 |
+
subsets: Sequence[SubsetBlueprint]
|
149 |
+
|
150 |
+
|
151 |
+
@dataclass
|
152 |
+
class DatasetGroupBlueprint:
|
153 |
+
datasets: Sequence[DatasetBlueprint]
|
154 |
+
|
155 |
+
|
156 |
+
@dataclass
|
157 |
+
class Blueprint:
|
158 |
+
dataset_group: DatasetGroupBlueprint
|
159 |
+
|
160 |
+
|
161 |
+
class ConfigSanitizer:
|
162 |
+
# @curry
|
163 |
+
@staticmethod
|
164 |
+
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
|
165 |
+
Schema(ExactSequence([klass, klass]))(value)
|
166 |
+
return tuple(value)
|
167 |
+
|
168 |
+
# @curry
|
169 |
+
@staticmethod
|
170 |
+
def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
|
171 |
+
Schema(Any(klass, ExactSequence([klass, klass])))(value)
|
172 |
+
try:
|
173 |
+
Schema(klass)(value)
|
174 |
+
return (value, value)
|
175 |
+
except:
|
176 |
+
return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
|
177 |
+
|
178 |
+
# subset schema
|
179 |
+
SUBSET_ASCENDABLE_SCHEMA = {
|
180 |
+
"color_aug": bool,
|
181 |
+
"face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
|
182 |
+
"flip_aug": bool,
|
183 |
+
"num_repeats": int,
|
184 |
+
"random_crop": bool,
|
185 |
+
"shuffle_caption": bool,
|
186 |
+
"keep_tokens": int,
|
187 |
+
"keep_tokens_separator": str,
|
188 |
+
"secondary_separator": str,
|
189 |
+
"caption_separator": str,
|
190 |
+
"enable_wildcard": bool,
|
191 |
+
"token_warmup_min": int,
|
192 |
+
"token_warmup_step": Any(float, int),
|
193 |
+
"caption_prefix": str,
|
194 |
+
"caption_suffix": str,
|
195 |
+
"custom_attributes": dict,
|
196 |
+
}
|
197 |
+
# DO means DropOut
|
198 |
+
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
199 |
+
"caption_dropout_every_n_epochs": int,
|
200 |
+
"caption_dropout_rate": Any(float, int),
|
201 |
+
"caption_tag_dropout_rate": Any(float, int),
|
202 |
+
}
|
203 |
+
# DB means DreamBooth
|
204 |
+
DB_SUBSET_ASCENDABLE_SCHEMA = {
|
205 |
+
"caption_extension": str,
|
206 |
+
"class_tokens": str,
|
207 |
+
"cache_info": bool,
|
208 |
+
}
|
209 |
+
DB_SUBSET_DISTINCT_SCHEMA = {
|
210 |
+
Required("image_dir"): str,
|
211 |
+
"is_reg": bool,
|
212 |
+
"alpha_mask": bool,
|
213 |
+
}
|
214 |
+
# FT means FineTuning
|
215 |
+
FT_SUBSET_DISTINCT_SCHEMA = {
|
216 |
+
Required("metadata_file"): str,
|
217 |
+
"image_dir": str,
|
218 |
+
"alpha_mask": bool,
|
219 |
+
}
|
220 |
+
CN_SUBSET_ASCENDABLE_SCHEMA = {
|
221 |
+
"caption_extension": str,
|
222 |
+
"cache_info": bool,
|
223 |
+
}
|
224 |
+
CN_SUBSET_DISTINCT_SCHEMA = {
|
225 |
+
Required("image_dir"): str,
|
226 |
+
Required("conditioning_data_dir"): str,
|
227 |
+
}
|
228 |
+
|
229 |
+
# datasets schema
|
230 |
+
DATASET_ASCENDABLE_SCHEMA = {
|
231 |
+
"batch_size": int,
|
232 |
+
"bucket_no_upscale": bool,
|
233 |
+
"bucket_reso_steps": int,
|
234 |
+
"enable_bucket": bool,
|
235 |
+
"max_bucket_reso": int,
|
236 |
+
"min_bucket_reso": int,
|
237 |
+
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
238 |
+
"network_multiplier": float,
|
239 |
+
}
|
240 |
+
|
241 |
+
# options handled by argparse but not handled by user config
|
242 |
+
ARGPARSE_SPECIFIC_SCHEMA = {
|
243 |
+
"debug_dataset": bool,
|
244 |
+
"max_token_length": Any(None, int),
|
245 |
+
"prior_loss_weight": Any(float, int),
|
246 |
+
}
|
247 |
+
# for handling default None value of argparse
|
248 |
+
ARGPARSE_NULLABLE_OPTNAMES = [
|
249 |
+
"face_crop_aug_range",
|
250 |
+
"resolution",
|
251 |
+
]
|
252 |
+
# prepare map because option name may differ among argparse and user config
|
253 |
+
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
|
254 |
+
"train_batch_size": "batch_size",
|
255 |
+
"dataset_repeats": "num_repeats",
|
256 |
+
}
|
257 |
+
|
258 |
+
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
|
259 |
+
assert support_dreambooth or support_finetuning or support_controlnet, (
|
260 |
+
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
|
261 |
+
+ " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
|
262 |
+
)
|
263 |
+
|
264 |
+
self.db_subset_schema = self.__merge_dict(
|
265 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
266 |
+
self.DB_SUBSET_DISTINCT_SCHEMA,
|
267 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
268 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
269 |
+
)
|
270 |
+
|
271 |
+
self.ft_subset_schema = self.__merge_dict(
|
272 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
273 |
+
self.FT_SUBSET_DISTINCT_SCHEMA,
|
274 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
275 |
+
)
|
276 |
+
|
277 |
+
self.cn_subset_schema = self.__merge_dict(
|
278 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
279 |
+
self.CN_SUBSET_DISTINCT_SCHEMA,
|
280 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
281 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
282 |
+
)
|
283 |
+
|
284 |
+
self.db_dataset_schema = self.__merge_dict(
|
285 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
286 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
287 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
288 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
289 |
+
{"subsets": [self.db_subset_schema]},
|
290 |
+
)
|
291 |
+
|
292 |
+
self.ft_dataset_schema = self.__merge_dict(
|
293 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
294 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
295 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
296 |
+
{"subsets": [self.ft_subset_schema]},
|
297 |
+
)
|
298 |
+
|
299 |
+
self.cn_dataset_schema = self.__merge_dict(
|
300 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
301 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
302 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
303 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
304 |
+
{"subsets": [self.cn_subset_schema]},
|
305 |
+
)
|
306 |
+
|
307 |
+
if support_dreambooth and support_finetuning:
|
308 |
+
|
309 |
+
def validate_flex_dataset(dataset_config: dict):
|
310 |
+
subsets_config = dataset_config.get("subsets", [])
|
311 |
+
|
312 |
+
if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
|
313 |
+
return Schema(self.cn_dataset_schema)(dataset_config)
|
314 |
+
# check dataset meets FT style
|
315 |
+
# NOTE: all FT subsets should have "metadata_file"
|
316 |
+
elif all(["metadata_file" in subset for subset in subsets_config]):
|
317 |
+
return Schema(self.ft_dataset_schema)(dataset_config)
|
318 |
+
# check dataset meets DB style
|
319 |
+
# NOTE: all DB subsets should have no "metadata_file"
|
320 |
+
elif all(["metadata_file" not in subset for subset in subsets_config]):
|
321 |
+
return Schema(self.db_dataset_schema)(dataset_config)
|
322 |
+
else:
|
323 |
+
raise voluptuous.Invalid(
|
324 |
+
"DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
|
325 |
+
)
|
326 |
+
|
327 |
+
self.dataset_schema = validate_flex_dataset
|
328 |
+
elif support_dreambooth:
|
329 |
+
if support_controlnet:
|
330 |
+
self.dataset_schema = self.cn_dataset_schema
|
331 |
+
else:
|
332 |
+
self.dataset_schema = self.db_dataset_schema
|
333 |
+
elif support_finetuning:
|
334 |
+
self.dataset_schema = self.ft_dataset_schema
|
335 |
+
elif support_controlnet:
|
336 |
+
self.dataset_schema = self.cn_dataset_schema
|
337 |
+
|
338 |
+
self.general_schema = self.__merge_dict(
|
339 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
340 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
341 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
|
342 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
|
343 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
344 |
+
)
|
345 |
+
|
346 |
+
self.user_config_validator = Schema(
|
347 |
+
{
|
348 |
+
"general": self.general_schema,
|
349 |
+
"datasets": [self.dataset_schema],
|
350 |
+
}
|
351 |
+
)
|
352 |
+
|
353 |
+
self.argparse_schema = self.__merge_dict(
|
354 |
+
self.general_schema,
|
355 |
+
self.ARGPARSE_SPECIFIC_SCHEMA,
|
356 |
+
{optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
|
357 |
+
{a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
|
358 |
+
)
|
359 |
+
|
360 |
+
self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
|
361 |
+
|
362 |
+
def sanitize_user_config(self, user_config: dict) -> dict:
|
363 |
+
try:
|
364 |
+
return self.user_config_validator(user_config)
|
365 |
+
except MultipleInvalid:
|
366 |
+
# TODO: エラー発生時のメッセージをわかりやすくする
|
367 |
+
logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
|
368 |
+
raise
|
369 |
+
|
370 |
+
# NOTE: In nature, argument parser result is not needed to be sanitize
|
371 |
+
# However this will help us to detect program bug
|
372 |
+
def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
|
373 |
+
try:
|
374 |
+
return self.argparse_config_validator(argparse_namespace)
|
375 |
+
except MultipleInvalid:
|
376 |
+
# XXX: this should be a bug
|
377 |
+
logger.error(
|
378 |
+
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
|
379 |
+
)
|
380 |
+
raise
|
381 |
+
|
382 |
+
# NOTE: value would be overwritten by latter dict if there is already the same key
|
383 |
+
@staticmethod
|
384 |
+
def __merge_dict(*dict_list: dict) -> dict:
|
385 |
+
merged = {}
|
386 |
+
for schema in dict_list:
|
387 |
+
# merged |= schema
|
388 |
+
for k, v in schema.items():
|
389 |
+
merged[k] = v
|
390 |
+
return merged
|
391 |
+
|
392 |
+
|
393 |
+
class BlueprintGenerator:
|
394 |
+
BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
|
395 |
+
|
396 |
+
def __init__(self, sanitizer: ConfigSanitizer):
|
397 |
+
self.sanitizer = sanitizer
|
398 |
+
|
399 |
+
# runtime_params is for parameters which is only configurable on runtime, such as tokenizer
|
400 |
+
def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
|
401 |
+
sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
|
402 |
+
sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
|
403 |
+
|
404 |
+
# convert argparse namespace to dict like config
|
405 |
+
# NOTE: it is ok to have extra entries in dict
|
406 |
+
optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
|
407 |
+
argparse_config = {
|
408 |
+
optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
|
409 |
+
}
|
410 |
+
|
411 |
+
general_config = sanitized_user_config.get("general", {})
|
412 |
+
|
413 |
+
dataset_blueprints = []
|
414 |
+
for dataset_config in sanitized_user_config.get("datasets", []):
|
415 |
+
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
|
416 |
+
subsets = dataset_config.get("subsets", [])
|
417 |
+
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
|
418 |
+
is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
|
419 |
+
if is_controlnet:
|
420 |
+
subset_params_klass = ControlNetSubsetParams
|
421 |
+
dataset_params_klass = ControlNetDatasetParams
|
422 |
+
elif is_dreambooth:
|
423 |
+
subset_params_klass = DreamBoothSubsetParams
|
424 |
+
dataset_params_klass = DreamBoothDatasetParams
|
425 |
+
else:
|
426 |
+
subset_params_klass = FineTuningSubsetParams
|
427 |
+
dataset_params_klass = FineTuningDatasetParams
|
428 |
+
|
429 |
+
subset_blueprints = []
|
430 |
+
for subset_config in subsets:
|
431 |
+
params = self.generate_params_by_fallbacks(
|
432 |
+
subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
|
433 |
+
)
|
434 |
+
subset_blueprints.append(SubsetBlueprint(params))
|
435 |
+
|
436 |
+
params = self.generate_params_by_fallbacks(
|
437 |
+
dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
|
438 |
+
)
|
439 |
+
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
|
440 |
+
|
441 |
+
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
|
442 |
+
|
443 |
+
return Blueprint(dataset_group_blueprint)
|
444 |
+
|
445 |
+
@staticmethod
|
446 |
+
def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
|
447 |
+
name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
|
448 |
+
search_value = BlueprintGenerator.search_value
|
449 |
+
default_params = asdict(param_klass())
|
450 |
+
param_names = default_params.keys()
|
451 |
+
|
452 |
+
params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
|
453 |
+
|
454 |
+
return param_klass(**params)
|
455 |
+
|
456 |
+
@staticmethod
|
457 |
+
def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
|
458 |
+
for cand in fallbacks:
|
459 |
+
value = cand.get(key)
|
460 |
+
if value is not None:
|
461 |
+
return value
|
462 |
+
|
463 |
+
return default_value
|
464 |
+
|
465 |
+
|
466 |
+
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
|
467 |
+
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
468 |
+
|
469 |
+
for dataset_blueprint in dataset_group_blueprint.datasets:
|
470 |
+
if dataset_blueprint.is_controlnet:
|
471 |
+
subset_klass = ControlNetSubset
|
472 |
+
dataset_klass = ControlNetDataset
|
473 |
+
elif dataset_blueprint.is_dreambooth:
|
474 |
+
subset_klass = DreamBoothSubset
|
475 |
+
dataset_klass = DreamBoothDataset
|
476 |
+
else:
|
477 |
+
subset_klass = FineTuningSubset
|
478 |
+
dataset_klass = FineTuningDataset
|
479 |
+
|
480 |
+
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
481 |
+
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
482 |
+
datasets.append(dataset)
|
483 |
+
|
484 |
+
# print info
|
485 |
+
info = ""
|
486 |
+
for i, dataset in enumerate(datasets):
|
487 |
+
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
488 |
+
is_controlnet = isinstance(dataset, ControlNetDataset)
|
489 |
+
info += dedent(
|
490 |
+
f"""\
|
491 |
+
[Dataset {i}]
|
492 |
+
batch_size: {dataset.batch_size}
|
493 |
+
resolution: {(dataset.width, dataset.height)}
|
494 |
+
enable_bucket: {dataset.enable_bucket}
|
495 |
+
network_multiplier: {dataset.network_multiplier}
|
496 |
+
"""
|
497 |
+
)
|
498 |
+
|
499 |
+
if dataset.enable_bucket:
|
500 |
+
info += indent(
|
501 |
+
dedent(
|
502 |
+
f"""\
|
503 |
+
min_bucket_reso: {dataset.min_bucket_reso}
|
504 |
+
max_bucket_reso: {dataset.max_bucket_reso}
|
505 |
+
bucket_reso_steps: {dataset.bucket_reso_steps}
|
506 |
+
bucket_no_upscale: {dataset.bucket_no_upscale}
|
507 |
+
\n"""
|
508 |
+
),
|
509 |
+
" ",
|
510 |
+
)
|
511 |
+
else:
|
512 |
+
info += "\n"
|
513 |
+
|
514 |
+
for j, subset in enumerate(dataset.subsets):
|
515 |
+
info += indent(
|
516 |
+
dedent(
|
517 |
+
f"""\
|
518 |
+
[Subset {j} of Dataset {i}]
|
519 |
+
image_dir: "{subset.image_dir}"
|
520 |
+
image_count: {subset.img_count}
|
521 |
+
num_repeats: {subset.num_repeats}
|
522 |
+
shuffle_caption: {subset.shuffle_caption}
|
523 |
+
keep_tokens: {subset.keep_tokens}
|
524 |
+
keep_tokens_separator: {subset.keep_tokens_separator}
|
525 |
+
caption_separator: {subset.caption_separator}
|
526 |
+
secondary_separator: {subset.secondary_separator}
|
527 |
+
enable_wildcard: {subset.enable_wildcard}
|
528 |
+
caption_dropout_rate: {subset.caption_dropout_rate}
|
529 |
+
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
|
530 |
+
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
531 |
+
caption_prefix: {subset.caption_prefix}
|
532 |
+
caption_suffix: {subset.caption_suffix}
|
533 |
+
color_aug: {subset.color_aug}
|
534 |
+
flip_aug: {subset.flip_aug}
|
535 |
+
face_crop_aug_range: {subset.face_crop_aug_range}
|
536 |
+
random_crop: {subset.random_crop}
|
537 |
+
token_warmup_min: {subset.token_warmup_min}
|
538 |
+
token_warmup_step: {subset.token_warmup_step}
|
539 |
+
alpha_mask: {subset.alpha_mask}
|
540 |
+
custom_attributes: {subset.custom_attributes}
|
541 |
+
"""
|
542 |
+
),
|
543 |
+
" ",
|
544 |
+
)
|
545 |
+
|
546 |
+
if is_dreambooth:
|
547 |
+
info += indent(
|
548 |
+
dedent(
|
549 |
+
f"""\
|
550 |
+
is_reg: {subset.is_reg}
|
551 |
+
class_tokens: {subset.class_tokens}
|
552 |
+
caption_extension: {subset.caption_extension}
|
553 |
+
\n"""
|
554 |
+
),
|
555 |
+
" ",
|
556 |
+
)
|
557 |
+
elif not is_controlnet:
|
558 |
+
info += indent(
|
559 |
+
dedent(
|
560 |
+
f"""\
|
561 |
+
metadata_file: {subset.metadata_file}
|
562 |
+
\n"""
|
563 |
+
),
|
564 |
+
" ",
|
565 |
+
)
|
566 |
+
|
567 |
+
logger.info(f"{info}")
|
568 |
+
|
569 |
+
# make buckets first because it determines the length of dataset
|
570 |
+
# and set the same seed for all datasets
|
571 |
+
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
572 |
+
for i, dataset in enumerate(datasets):
|
573 |
+
logger.info(f"[Dataset {i}]")
|
574 |
+
dataset.make_buckets()
|
575 |
+
dataset.set_seed(seed)
|
576 |
+
|
577 |
+
return DatasetGroup(datasets)
|
578 |
+
|
579 |
+
|
580 |
+
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
|
581 |
+
def extract_dreambooth_params(name: str) -> Tuple[int, str]:
|
582 |
+
tokens = name.split("_")
|
583 |
+
try:
|
584 |
+
n_repeats = int(tokens[0])
|
585 |
+
except ValueError as e:
|
586 |
+
logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
|
587 |
+
return 0, ""
|
588 |
+
caption_by_folder = "_".join(tokens[1:])
|
589 |
+
return n_repeats, caption_by_folder
|
590 |
+
|
591 |
+
def generate(base_dir: Optional[str], is_reg: bool):
|
592 |
+
if base_dir is None:
|
593 |
+
return []
|
594 |
+
|
595 |
+
base_dir: Path = Path(base_dir)
|
596 |
+
if not base_dir.is_dir():
|
597 |
+
return []
|
598 |
+
|
599 |
+
subsets_config = []
|
600 |
+
for subdir in base_dir.iterdir():
|
601 |
+
if not subdir.is_dir():
|
602 |
+
continue
|
603 |
+
|
604 |
+
num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
|
605 |
+
if num_repeats < 1:
|
606 |
+
continue
|
607 |
+
|
608 |
+
subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
|
609 |
+
subsets_config.append(subset_config)
|
610 |
+
|
611 |
+
return subsets_config
|
612 |
+
|
613 |
+
subsets_config = []
|
614 |
+
subsets_config += generate(train_data_dir, False)
|
615 |
+
subsets_config += generate(reg_data_dir, True)
|
616 |
+
|
617 |
+
return subsets_config
|
618 |
+
|
619 |
+
|
620 |
+
def generate_controlnet_subsets_config_by_subdirs(
|
621 |
+
train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
|
622 |
+
):
|
623 |
+
def generate(base_dir: Optional[str]):
|
624 |
+
if base_dir is None:
|
625 |
+
return []
|
626 |
+
|
627 |
+
base_dir: Path = Path(base_dir)
|
628 |
+
if not base_dir.is_dir():
|
629 |
+
return []
|
630 |
+
|
631 |
+
subsets_config = []
|
632 |
+
subset_config = {
|
633 |
+
"image_dir": train_data_dir,
|
634 |
+
"conditioning_data_dir": conditioning_data_dir,
|
635 |
+
"caption_extension": caption_extension,
|
636 |
+
"num_repeats": 1,
|
637 |
+
}
|
638 |
+
subsets_config.append(subset_config)
|
639 |
+
|
640 |
+
return subsets_config
|
641 |
+
|
642 |
+
subsets_config = []
|
643 |
+
subsets_config += generate(train_data_dir)
|
644 |
+
|
645 |
+
return subsets_config
|
646 |
+
|
647 |
+
|
648 |
+
def load_user_config(file: str) -> dict:
|
649 |
+
file: Path = Path(file)
|
650 |
+
if not file.is_file():
|
651 |
+
raise ValueError(f"file not found / ファイルが見つかりません: {file}")
|
652 |
+
|
653 |
+
if file.name.lower().endswith(".json"):
|
654 |
+
try:
|
655 |
+
with open(file, "r") as f:
|
656 |
+
config = json.load(f)
|
657 |
+
except Exception:
|
658 |
+
logger.error(
|
659 |
+
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
660 |
+
)
|
661 |
+
raise
|
662 |
+
elif file.name.lower().endswith(".toml"):
|
663 |
+
try:
|
664 |
+
config = toml.load(file)
|
665 |
+
except Exception:
|
666 |
+
logger.error(
|
667 |
+
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
668 |
+
)
|
669 |
+
raise
|
670 |
+
else:
|
671 |
+
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
672 |
+
|
673 |
+
return config
|
674 |
+
|
675 |
+
|
676 |
+
# for config test
|
677 |
+
if __name__ == "__main__":
|
678 |
+
parser = argparse.ArgumentParser()
|
679 |
+
parser.add_argument("--support_dreambooth", action="store_true")
|
680 |
+
parser.add_argument("--support_finetuning", action="store_true")
|
681 |
+
parser.add_argument("--support_controlnet", action="store_true")
|
682 |
+
parser.add_argument("--support_dropout", action="store_true")
|
683 |
+
parser.add_argument("dataset_config")
|
684 |
+
config_args, remain = parser.parse_known_args()
|
685 |
+
|
686 |
+
parser = argparse.ArgumentParser()
|
687 |
+
train_util.add_dataset_arguments(
|
688 |
+
parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
|
689 |
+
)
|
690 |
+
train_util.add_training_arguments(parser, config_args.support_dreambooth)
|
691 |
+
argparse_namespace = parser.parse_args(remain)
|
692 |
+
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
693 |
+
|
694 |
+
logger.info("[argparse_namespace]")
|
695 |
+
logger.info(f"{vars(argparse_namespace)}")
|
696 |
+
|
697 |
+
user_config = load_user_config(config_args.dataset_config)
|
698 |
+
|
699 |
+
logger.info("")
|
700 |
+
logger.info("[user_config]")
|
701 |
+
logger.info(f"{user_config}")
|
702 |
+
|
703 |
+
sanitizer = ConfigSanitizer(
|
704 |
+
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
|
705 |
+
)
|
706 |
+
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
707 |
+
|
708 |
+
logger.info("")
|
709 |
+
logger.info("[sanitized_user_config]")
|
710 |
+
logger.info(f"{sanitized_user_config}")
|
711 |
+
|
712 |
+
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
713 |
+
|
714 |
+
logger.info("")
|
715 |
+
logger.info("[blueprint]")
|
716 |
+
logger.info(f"{blueprint}")
|
library/custom_offloading_utils.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from concurrent.futures import ThreadPoolExecutor
|
2 |
+
import time
|
3 |
+
from typing import Optional
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from library.device_utils import clean_memory_on_device
|
8 |
+
|
9 |
+
|
10 |
+
def synchronize_device(device: torch.device):
|
11 |
+
if device.type == "cuda":
|
12 |
+
torch.cuda.synchronize()
|
13 |
+
elif device.type == "xpu":
|
14 |
+
torch.xpu.synchronize()
|
15 |
+
elif device.type == "mps":
|
16 |
+
torch.mps.synchronize()
|
17 |
+
|
18 |
+
|
19 |
+
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
20 |
+
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
21 |
+
|
22 |
+
weight_swap_jobs = []
|
23 |
+
|
24 |
+
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
|
25 |
+
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
26 |
+
# print(module_to_cpu.__class__, module_to_cuda.__class__)
|
27 |
+
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
28 |
+
# weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
29 |
+
|
30 |
+
modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
|
31 |
+
for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
|
32 |
+
if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
|
33 |
+
module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
|
34 |
+
if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
|
35 |
+
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
36 |
+
else:
|
37 |
+
if module_to_cuda.weight.data.device.type != device.type:
|
38 |
+
# print(
|
39 |
+
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
|
40 |
+
# )
|
41 |
+
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
|
42 |
+
|
43 |
+
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
44 |
+
|
45 |
+
stream = torch.cuda.Stream()
|
46 |
+
with torch.cuda.stream(stream):
|
47 |
+
# cuda to cpu
|
48 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
49 |
+
cuda_data_view.record_stream(stream)
|
50 |
+
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
51 |
+
|
52 |
+
stream.synchronize()
|
53 |
+
|
54 |
+
# cpu to cuda
|
55 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
56 |
+
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
57 |
+
module_to_cuda.weight.data = cuda_data_view
|
58 |
+
|
59 |
+
stream.synchronize()
|
60 |
+
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
61 |
+
|
62 |
+
|
63 |
+
def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
64 |
+
"""
|
65 |
+
not tested
|
66 |
+
"""
|
67 |
+
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
68 |
+
|
69 |
+
weight_swap_jobs = []
|
70 |
+
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
71 |
+
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
72 |
+
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
73 |
+
|
74 |
+
# device to cpu
|
75 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
76 |
+
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
77 |
+
|
78 |
+
synchronize_device()
|
79 |
+
|
80 |
+
# cpu to device
|
81 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
82 |
+
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
83 |
+
module_to_cuda.weight.data = cuda_data_view
|
84 |
+
|
85 |
+
synchronize_device()
|
86 |
+
|
87 |
+
|
88 |
+
def weighs_to_device(layer: nn.Module, device: torch.device):
|
89 |
+
for module in layer.modules():
|
90 |
+
if hasattr(module, "weight") and module.weight is not None:
|
91 |
+
module.weight.data = module.weight.data.to(device, non_blocking=True)
|
92 |
+
|
93 |
+
|
94 |
+
class Offloader:
|
95 |
+
"""
|
96 |
+
common offloading class
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
100 |
+
self.num_blocks = num_blocks
|
101 |
+
self.blocks_to_swap = blocks_to_swap
|
102 |
+
self.device = device
|
103 |
+
self.debug = debug
|
104 |
+
|
105 |
+
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
106 |
+
self.futures = {}
|
107 |
+
self.cuda_available = device.type == "cuda"
|
108 |
+
|
109 |
+
def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
|
110 |
+
if self.cuda_available:
|
111 |
+
swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
|
112 |
+
else:
|
113 |
+
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
|
114 |
+
|
115 |
+
def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
|
116 |
+
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
117 |
+
if self.debug:
|
118 |
+
start_time = time.perf_counter()
|
119 |
+
print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
|
120 |
+
|
121 |
+
self.swap_weight_devices(block_to_cpu, block_to_cuda)
|
122 |
+
|
123 |
+
if self.debug:
|
124 |
+
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
|
125 |
+
return bidx_to_cpu, bidx_to_cuda # , event
|
126 |
+
|
127 |
+
block_to_cpu = blocks[block_idx_to_cpu]
|
128 |
+
block_to_cuda = blocks[block_idx_to_cuda]
|
129 |
+
|
130 |
+
self.futures[block_idx_to_cuda] = self.thread_pool.submit(
|
131 |
+
move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
|
132 |
+
)
|
133 |
+
|
134 |
+
def _wait_blocks_move(self, block_idx):
|
135 |
+
if block_idx not in self.futures:
|
136 |
+
return
|
137 |
+
|
138 |
+
if self.debug:
|
139 |
+
print(f"Wait for block {block_idx}")
|
140 |
+
start_time = time.perf_counter()
|
141 |
+
|
142 |
+
future = self.futures.pop(block_idx)
|
143 |
+
_, bidx_to_cuda = future.result()
|
144 |
+
|
145 |
+
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
|
146 |
+
|
147 |
+
if self.debug:
|
148 |
+
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
|
149 |
+
|
150 |
+
|
151 |
+
class ModelOffloader(Offloader):
|
152 |
+
"""
|
153 |
+
supports forward offloading
|
154 |
+
"""
|
155 |
+
|
156 |
+
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
157 |
+
super().__init__(num_blocks, blocks_to_swap, device, debug)
|
158 |
+
|
159 |
+
# register backward hooks
|
160 |
+
self.remove_handles = []
|
161 |
+
for i, block in enumerate(blocks):
|
162 |
+
hook = self.create_backward_hook(blocks, i)
|
163 |
+
if hook is not None:
|
164 |
+
handle = block.register_full_backward_hook(hook)
|
165 |
+
self.remove_handles.append(handle)
|
166 |
+
|
167 |
+
def __del__(self):
|
168 |
+
for handle in self.remove_handles:
|
169 |
+
handle.remove()
|
170 |
+
|
171 |
+
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
|
172 |
+
# -1 for 0-based index
|
173 |
+
num_blocks_propagated = self.num_blocks - block_index - 1
|
174 |
+
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
|
175 |
+
waiting = block_index > 0 and block_index <= self.blocks_to_swap
|
176 |
+
|
177 |
+
if not swapping and not waiting:
|
178 |
+
return None
|
179 |
+
|
180 |
+
# create hook
|
181 |
+
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
|
182 |
+
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
|
183 |
+
block_idx_to_wait = block_index - 1
|
184 |
+
|
185 |
+
def backward_hook(module, grad_input, grad_output):
|
186 |
+
if self.debug:
|
187 |
+
print(f"Backward hook for block {block_index}")
|
188 |
+
|
189 |
+
if swapping:
|
190 |
+
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
191 |
+
if waiting:
|
192 |
+
self._wait_blocks_move(block_idx_to_wait)
|
193 |
+
return None
|
194 |
+
|
195 |
+
return backward_hook
|
196 |
+
|
197 |
+
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
|
198 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
199 |
+
return
|
200 |
+
|
201 |
+
if self.debug:
|
202 |
+
print("Prepare block devices before forward")
|
203 |
+
|
204 |
+
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
|
205 |
+
b.to(self.device)
|
206 |
+
weighs_to_device(b, self.device) # make sure weights are on device
|
207 |
+
|
208 |
+
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
|
209 |
+
b.to(self.device) # move block to device first
|
210 |
+
weighs_to_device(b, "cpu") # make sure weights are on cpu
|
211 |
+
|
212 |
+
synchronize_device(self.device)
|
213 |
+
clean_memory_on_device(self.device)
|
214 |
+
|
215 |
+
def wait_for_block(self, block_idx: int):
|
216 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
217 |
+
return
|
218 |
+
self._wait_blocks_move(block_idx)
|
219 |
+
|
220 |
+
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
|
221 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
222 |
+
return
|
223 |
+
if block_idx >= self.blocks_to_swap:
|
224 |
+
return
|
225 |
+
block_idx_to_cpu = block_idx
|
226 |
+
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
|
227 |
+
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
library/custom_train_functions.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import argparse
|
3 |
+
import random
|
4 |
+
import re
|
5 |
+
from typing import List, Optional, Union
|
6 |
+
from .utils import setup_logging
|
7 |
+
|
8 |
+
setup_logging()
|
9 |
+
import logging
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
15 |
+
if hasattr(noise_scheduler, "all_snr"):
|
16 |
+
return
|
17 |
+
|
18 |
+
alphas_cumprod = noise_scheduler.alphas_cumprod
|
19 |
+
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
20 |
+
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
21 |
+
alpha = sqrt_alphas_cumprod
|
22 |
+
sigma = sqrt_one_minus_alphas_cumprod
|
23 |
+
all_snr = (alpha / sigma) ** 2
|
24 |
+
|
25 |
+
noise_scheduler.all_snr = all_snr.to(device)
|
26 |
+
|
27 |
+
|
28 |
+
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
29 |
+
# fix beta: zero terminal SNR
|
30 |
+
logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
31 |
+
|
32 |
+
def enforce_zero_terminal_snr(betas):
|
33 |
+
# Convert betas to alphas_bar_sqrt
|
34 |
+
alphas = 1 - betas
|
35 |
+
alphas_bar = alphas.cumprod(0)
|
36 |
+
alphas_bar_sqrt = alphas_bar.sqrt()
|
37 |
+
|
38 |
+
# Store old values.
|
39 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
40 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
41 |
+
# Shift so last timestep is zero.
|
42 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
43 |
+
# Scale so first timestep is back to old value.
|
44 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
45 |
+
|
46 |
+
# Convert alphas_bar_sqrt to betas
|
47 |
+
alphas_bar = alphas_bar_sqrt**2
|
48 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
49 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
50 |
+
betas = 1 - alphas
|
51 |
+
return betas
|
52 |
+
|
53 |
+
betas = noise_scheduler.betas
|
54 |
+
betas = enforce_zero_terminal_snr(betas)
|
55 |
+
alphas = 1.0 - betas
|
56 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
57 |
+
|
58 |
+
# logger.info(f"original: {noise_scheduler.betas}")
|
59 |
+
# logger.info(f"fixed: {betas}")
|
60 |
+
|
61 |
+
noise_scheduler.betas = betas
|
62 |
+
noise_scheduler.alphas = alphas
|
63 |
+
noise_scheduler.alphas_cumprod = alphas_cumprod
|
64 |
+
|
65 |
+
|
66 |
+
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
|
67 |
+
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
68 |
+
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
69 |
+
if v_prediction:
|
70 |
+
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
|
71 |
+
else:
|
72 |
+
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
|
73 |
+
loss = loss * snr_weight
|
74 |
+
return loss
|
75 |
+
|
76 |
+
|
77 |
+
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
78 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
79 |
+
loss = loss * scale
|
80 |
+
return loss
|
81 |
+
|
82 |
+
|
83 |
+
def get_snr_scale(timesteps, noise_scheduler):
|
84 |
+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
85 |
+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
86 |
+
scale = snr_t / (snr_t + 1)
|
87 |
+
# # show debug info
|
88 |
+
# logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
|
89 |
+
return scale
|
90 |
+
|
91 |
+
|
92 |
+
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
93 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
94 |
+
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
95 |
+
loss = loss + loss / scale * v_pred_like_loss
|
96 |
+
return loss
|
97 |
+
|
98 |
+
|
99 |
+
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
|
100 |
+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
101 |
+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
102 |
+
if v_prediction:
|
103 |
+
weight = 1 / (snr_t + 1)
|
104 |
+
else:
|
105 |
+
weight = 1 / torch.sqrt(snr_t)
|
106 |
+
loss = weight * loss
|
107 |
+
return loss
|
108 |
+
|
109 |
+
|
110 |
+
# TODO train_utilと分散しているのでどちらかに寄せる
|
111 |
+
|
112 |
+
|
113 |
+
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
114 |
+
parser.add_argument(
|
115 |
+
"--min_snr_gamma",
|
116 |
+
type=float,
|
117 |
+
default=None,
|
118 |
+
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--scale_v_pred_loss_like_noise_pred",
|
122 |
+
action="store_true",
|
123 |
+
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"--v_pred_like_loss",
|
127 |
+
type=float,
|
128 |
+
default=None,
|
129 |
+
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけ��ものをlossに加算する",
|
130 |
+
)
|
131 |
+
parser.add_argument(
|
132 |
+
"--debiased_estimation_loss",
|
133 |
+
action="store_true",
|
134 |
+
help="debiased estimation loss / debiased estimation loss",
|
135 |
+
)
|
136 |
+
if support_weighted_captions:
|
137 |
+
parser.add_argument(
|
138 |
+
"--weighted_captions",
|
139 |
+
action="store_true",
|
140 |
+
default=False,
|
141 |
+
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
|
142 |
+
)
|
143 |
+
|
144 |
+
|
145 |
+
re_attention = re.compile(
|
146 |
+
r"""
|
147 |
+
\\\(|
|
148 |
+
\\\)|
|
149 |
+
\\\[|
|
150 |
+
\\]|
|
151 |
+
\\\\|
|
152 |
+
\\|
|
153 |
+
\(|
|
154 |
+
\[|
|
155 |
+
:([+-]?[.\d]+)\)|
|
156 |
+
\)|
|
157 |
+
]|
|
158 |
+
[^\\()\[\]:]+|
|
159 |
+
:
|
160 |
+
""",
|
161 |
+
re.X,
|
162 |
+
)
|
163 |
+
|
164 |
+
|
165 |
+
def parse_prompt_attention(text):
|
166 |
+
"""
|
167 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
168 |
+
Accepted tokens are:
|
169 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
170 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
171 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
172 |
+
\( - literal character '('
|
173 |
+
\[ - literal character '['
|
174 |
+
\) - literal character ')'
|
175 |
+
\] - literal character ']'
|
176 |
+
\\ - literal character '\'
|
177 |
+
anything else - just text
|
178 |
+
>>> parse_prompt_attention('normal text')
|
179 |
+
[['normal text', 1.0]]
|
180 |
+
>>> parse_prompt_attention('an (important) word')
|
181 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
182 |
+
>>> parse_prompt_attention('(unbalanced')
|
183 |
+
[['unbalanced', 1.1]]
|
184 |
+
>>> parse_prompt_attention('\(literal\]')
|
185 |
+
[['(literal]', 1.0]]
|
186 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
187 |
+
[['unnecessaryparens', 1.1]]
|
188 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
189 |
+
[['a ', 1.0],
|
190 |
+
['house', 1.5730000000000004],
|
191 |
+
[' ', 1.1],
|
192 |
+
['on', 1.0],
|
193 |
+
[' a ', 1.1],
|
194 |
+
['hill', 0.55],
|
195 |
+
[', sun, ', 1.1],
|
196 |
+
['sky', 1.4641000000000006],
|
197 |
+
['.', 1.1]]
|
198 |
+
"""
|
199 |
+
|
200 |
+
res = []
|
201 |
+
round_brackets = []
|
202 |
+
square_brackets = []
|
203 |
+
|
204 |
+
round_bracket_multiplier = 1.1
|
205 |
+
square_bracket_multiplier = 1 / 1.1
|
206 |
+
|
207 |
+
def multiply_range(start_position, multiplier):
|
208 |
+
for p in range(start_position, len(res)):
|
209 |
+
res[p][1] *= multiplier
|
210 |
+
|
211 |
+
for m in re_attention.finditer(text):
|
212 |
+
text = m.group(0)
|
213 |
+
weight = m.group(1)
|
214 |
+
|
215 |
+
if text.startswith("\\"):
|
216 |
+
res.append([text[1:], 1.0])
|
217 |
+
elif text == "(":
|
218 |
+
round_brackets.append(len(res))
|
219 |
+
elif text == "[":
|
220 |
+
square_brackets.append(len(res))
|
221 |
+
elif weight is not None and len(round_brackets) > 0:
|
222 |
+
multiply_range(round_brackets.pop(), float(weight))
|
223 |
+
elif text == ")" and len(round_brackets) > 0:
|
224 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
225 |
+
elif text == "]" and len(square_brackets) > 0:
|
226 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
227 |
+
else:
|
228 |
+
res.append([text, 1.0])
|
229 |
+
|
230 |
+
for pos in round_brackets:
|
231 |
+
multiply_range(pos, round_bracket_multiplier)
|
232 |
+
|
233 |
+
for pos in square_brackets:
|
234 |
+
multiply_range(pos, square_bracket_multiplier)
|
235 |
+
|
236 |
+
if len(res) == 0:
|
237 |
+
res = [["", 1.0]]
|
238 |
+
|
239 |
+
# merge runs of identical weights
|
240 |
+
i = 0
|
241 |
+
while i + 1 < len(res):
|
242 |
+
if res[i][1] == res[i + 1][1]:
|
243 |
+
res[i][0] += res[i + 1][0]
|
244 |
+
res.pop(i + 1)
|
245 |
+
else:
|
246 |
+
i += 1
|
247 |
+
|
248 |
+
return res
|
249 |
+
|
250 |
+
|
251 |
+
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
|
252 |
+
r"""
|
253 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
254 |
+
|
255 |
+
No padding, starting or ending token is included.
|
256 |
+
"""
|
257 |
+
tokens = []
|
258 |
+
weights = []
|
259 |
+
truncated = False
|
260 |
+
for text in prompt:
|
261 |
+
texts_and_weights = parse_prompt_attention(text)
|
262 |
+
text_token = []
|
263 |
+
text_weight = []
|
264 |
+
for word, weight in texts_and_weights:
|
265 |
+
# tokenize and discard the starting and the ending token
|
266 |
+
token = tokenizer(word).input_ids[1:-1]
|
267 |
+
text_token += token
|
268 |
+
# copy the weight by length of token
|
269 |
+
text_weight += [weight] * len(token)
|
270 |
+
# stop if the text is too long (longer than truncation limit)
|
271 |
+
if len(text_token) > max_length:
|
272 |
+
truncated = True
|
273 |
+
break
|
274 |
+
# truncate
|
275 |
+
if len(text_token) > max_length:
|
276 |
+
truncated = True
|
277 |
+
text_token = text_token[:max_length]
|
278 |
+
text_weight = text_weight[:max_length]
|
279 |
+
tokens.append(text_token)
|
280 |
+
weights.append(text_weight)
|
281 |
+
if truncated:
|
282 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
283 |
+
return tokens, weights
|
284 |
+
|
285 |
+
|
286 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
287 |
+
r"""
|
288 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
289 |
+
"""
|
290 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
291 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
292 |
+
for i in range(len(tokens)):
|
293 |
+
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
294 |
+
if no_boseos_middle:
|
295 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
296 |
+
else:
|
297 |
+
w = []
|
298 |
+
if len(weights[i]) == 0:
|
299 |
+
w = [1.0] * weights_length
|
300 |
+
else:
|
301 |
+
for j in range(max_embeddings_multiples):
|
302 |
+
w.append(1.0) # weight for starting token in this chunk
|
303 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
304 |
+
w.append(1.0) # weight for ending token in this chunk
|
305 |
+
w += [1.0] * (weights_length - len(w))
|
306 |
+
weights[i] = w[:]
|
307 |
+
|
308 |
+
return tokens, weights
|
309 |
+
|
310 |
+
|
311 |
+
def get_unweighted_text_embeddings(
|
312 |
+
tokenizer,
|
313 |
+
text_encoder,
|
314 |
+
text_input: torch.Tensor,
|
315 |
+
chunk_length: int,
|
316 |
+
clip_skip: int,
|
317 |
+
eos: int,
|
318 |
+
pad: int,
|
319 |
+
no_boseos_middle: Optional[bool] = True,
|
320 |
+
):
|
321 |
+
"""
|
322 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
323 |
+
it should be split into chunks and sent to the text encoder individually.
|
324 |
+
"""
|
325 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
326 |
+
if max_embeddings_multiples > 1:
|
327 |
+
text_embeddings = []
|
328 |
+
for i in range(max_embeddings_multiples):
|
329 |
+
# extract the i-th chunk
|
330 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
331 |
+
|
332 |
+
# cover the head and the tail by the starting and the ending tokens
|
333 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
334 |
+
if pad == eos: # v1
|
335 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
336 |
+
else: # v2
|
337 |
+
for j in range(len(text_input_chunk)):
|
338 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
339 |
+
text_input_chunk[j, -1] = eos
|
340 |
+
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
341 |
+
text_input_chunk[j, 1] = eos
|
342 |
+
|
343 |
+
if clip_skip is None or clip_skip == 1:
|
344 |
+
text_embedding = text_encoder(text_input_chunk)[0]
|
345 |
+
else:
|
346 |
+
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
347 |
+
text_embedding = enc_out["hidden_states"][-clip_skip]
|
348 |
+
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
349 |
+
|
350 |
+
if no_boseos_middle:
|
351 |
+
if i == 0:
|
352 |
+
# discard the ending token
|
353 |
+
text_embedding = text_embedding[:, :-1]
|
354 |
+
elif i == max_embeddings_multiples - 1:
|
355 |
+
# discard the starting token
|
356 |
+
text_embedding = text_embedding[:, 1:]
|
357 |
+
else:
|
358 |
+
# discard both starting and ending tokens
|
359 |
+
text_embedding = text_embedding[:, 1:-1]
|
360 |
+
|
361 |
+
text_embeddings.append(text_embedding)
|
362 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
363 |
+
else:
|
364 |
+
if clip_skip is None or clip_skip == 1:
|
365 |
+
text_embeddings = text_encoder(text_input)[0]
|
366 |
+
else:
|
367 |
+
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
368 |
+
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
369 |
+
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
370 |
+
return text_embeddings
|
371 |
+
|
372 |
+
|
373 |
+
def get_weighted_text_embeddings(
|
374 |
+
tokenizer,
|
375 |
+
text_encoder,
|
376 |
+
prompt: Union[str, List[str]],
|
377 |
+
device,
|
378 |
+
max_embeddings_multiples: Optional[int] = 3,
|
379 |
+
no_boseos_middle: Optional[bool] = False,
|
380 |
+
clip_skip=None,
|
381 |
+
):
|
382 |
+
r"""
|
383 |
+
Prompts can be assigned with local weights using brackets. For example,
|
384 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
385 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
386 |
+
|
387 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
388 |
+
|
389 |
+
Args:
|
390 |
+
prompt (`str` or `List[str]`):
|
391 |
+
The prompt or prompts to guide the image generation.
|
392 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
393 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
394 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
395 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
396 |
+
ending token in each of the chunk in the middle.
|
397 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
398 |
+
Skip the parsing of brackets.
|
399 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
400 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
401 |
+
"""
|
402 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
403 |
+
if isinstance(prompt, str):
|
404 |
+
prompt = [prompt]
|
405 |
+
|
406 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
|
407 |
+
|
408 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
409 |
+
max_length = max([len(token) for token in prompt_tokens])
|
410 |
+
|
411 |
+
max_embeddings_multiples = min(
|
412 |
+
max_embeddings_multiples,
|
413 |
+
(max_length - 1) // (tokenizer.model_max_length - 2) + 1,
|
414 |
+
)
|
415 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
416 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
417 |
+
|
418 |
+
# pad the length of tokens and weights
|
419 |
+
bos = tokenizer.bos_token_id
|
420 |
+
eos = tokenizer.eos_token_id
|
421 |
+
pad = tokenizer.pad_token_id
|
422 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
423 |
+
prompt_tokens,
|
424 |
+
prompt_weights,
|
425 |
+
max_length,
|
426 |
+
bos,
|
427 |
+
eos,
|
428 |
+
no_boseos_middle=no_boseos_middle,
|
429 |
+
chunk_length=tokenizer.model_max_length,
|
430 |
+
)
|
431 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
|
432 |
+
|
433 |
+
# get the embeddings
|
434 |
+
text_embeddings = get_unweighted_text_embeddings(
|
435 |
+
tokenizer,
|
436 |
+
text_encoder,
|
437 |
+
prompt_tokens,
|
438 |
+
tokenizer.model_max_length,
|
439 |
+
clip_skip,
|
440 |
+
eos,
|
441 |
+
pad,
|
442 |
+
no_boseos_middle=no_boseos_middle,
|
443 |
+
)
|
444 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
|
445 |
+
|
446 |
+
# assign weights to the prompts and normalize in the sense of mean
|
447 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
448 |
+
text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
|
449 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
450 |
+
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
451 |
+
|
452 |
+
return text_embeddings
|
453 |
+
|
454 |
+
|
455 |
+
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
456 |
+
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
457 |
+
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
458 |
+
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
459 |
+
for i in range(iterations):
|
460 |
+
r = random.random() * 2 + 2 # Rather than always going 2x,
|
461 |
+
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
|
462 |
+
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
|
463 |
+
if wn == 1 or hn == 1:
|
464 |
+
break # Lowest resolution is 1x1
|
465 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
466 |
+
|
467 |
+
|
468 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
469 |
+
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
470 |
+
if noise_offset is None:
|
471 |
+
return noise
|
472 |
+
if adaptive_noise_scale is not None:
|
473 |
+
# latent shape: (batch_size, channels, height, width)
|
474 |
+
# abs mean value for each channel
|
475 |
+
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
|
476 |
+
|
477 |
+
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
478 |
+
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
479 |
+
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
480 |
+
|
481 |
+
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
482 |
+
return noise
|
483 |
+
|
484 |
+
|
485 |
+
def apply_masked_loss(loss, batch):
|
486 |
+
if "conditioning_images" in batch:
|
487 |
+
# conditioning image is -1 to 1. we need to convert it to 0 to 1
|
488 |
+
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
489 |
+
mask_image = mask_image / 2 + 0.5
|
490 |
+
# print(f"conditioning_image: {mask_image.shape}")
|
491 |
+
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
|
492 |
+
# alpha mask is 0 to 1
|
493 |
+
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
|
494 |
+
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
|
495 |
+
else:
|
496 |
+
return loss
|
497 |
+
|
498 |
+
# resize to the same size as the loss
|
499 |
+
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
500 |
+
loss = loss * mask_image
|
501 |
+
return loss
|
502 |
+
|
503 |
+
|
504 |
+
"""
|
505 |
+
##########################################
|
506 |
+
# Perlin Noise
|
507 |
+
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
508 |
+
delta = (res[0] / shape[0], res[1] / shape[1])
|
509 |
+
d = (shape[0] // res[0], shape[1] // res[1])
|
510 |
+
|
511 |
+
grid = (
|
512 |
+
torch.stack(
|
513 |
+
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
514 |
+
dim=-1,
|
515 |
+
)
|
516 |
+
% 1
|
517 |
+
)
|
518 |
+
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
519 |
+
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
520 |
+
|
521 |
+
tile_grads = (
|
522 |
+
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
523 |
+
.repeat_interleave(d[0], 0)
|
524 |
+
.repeat_interleave(d[1], 1)
|
525 |
+
)
|
526 |
+
dot = lambda grad, shift: (
|
527 |
+
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
528 |
+
* grad[: shape[0], : shape[1]]
|
529 |
+
).sum(dim=-1)
|
530 |
+
|
531 |
+
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
532 |
+
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
533 |
+
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
534 |
+
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
535 |
+
t = fade(grid[: shape[0], : shape[1]])
|
536 |
+
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
537 |
+
|
538 |
+
|
539 |
+
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
540 |
+
noise = torch.zeros(shape, device=device)
|
541 |
+
frequency = 1
|
542 |
+
amplitude = 1
|
543 |
+
for _ in range(octaves):
|
544 |
+
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
545 |
+
frequency *= 2
|
546 |
+
amplitude *= persistence
|
547 |
+
return noise
|
548 |
+
|
549 |
+
|
550 |
+
def perlin_noise(noise, device, octaves):
|
551 |
+
_, c, w, h = noise.shape
|
552 |
+
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
553 |
+
noise_perlin = []
|
554 |
+
for _ in range(c):
|
555 |
+
noise_perlin.append(perlin())
|
556 |
+
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
557 |
+
noise += noise_perlin # broadcast for each batch
|
558 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
559 |
+
"""
|
library/deepspeed_utils.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
from accelerate import DeepSpeedPlugin, Accelerator
|
5 |
+
|
6 |
+
from .utils import setup_logging
|
7 |
+
|
8 |
+
setup_logging()
|
9 |
+
import logging
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
|
15 |
+
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
|
16 |
+
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
|
17 |
+
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
|
18 |
+
parser.add_argument(
|
19 |
+
"--offload_optimizer_device",
|
20 |
+
type=str,
|
21 |
+
default=None,
|
22 |
+
choices=[None, "cpu", "nvme"],
|
23 |
+
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--offload_optimizer_nvme_path",
|
27 |
+
type=str,
|
28 |
+
default=None,
|
29 |
+
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--offload_param_device",
|
33 |
+
type=str,
|
34 |
+
default=None,
|
35 |
+
choices=[None, "cpu", "nvme"],
|
36 |
+
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
"--offload_param_nvme_path",
|
40 |
+
type=str,
|
41 |
+
default=None,
|
42 |
+
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
43 |
+
)
|
44 |
+
parser.add_argument(
|
45 |
+
"--zero3_init_flag",
|
46 |
+
action="store_true",
|
47 |
+
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
|
48 |
+
"Only applicable with ZeRO Stage-3.",
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
"--zero3_save_16bit_model",
|
52 |
+
action="store_true",
|
53 |
+
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"--fp16_master_weights_and_gradients",
|
57 |
+
action="store_true",
|
58 |
+
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def prepare_deepspeed_args(args: argparse.Namespace):
|
63 |
+
if not args.deepspeed:
|
64 |
+
return
|
65 |
+
|
66 |
+
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
67 |
+
args.max_data_loader_n_workers = 1
|
68 |
+
|
69 |
+
|
70 |
+
def prepare_deepspeed_plugin(args: argparse.Namespace):
|
71 |
+
if not args.deepspeed:
|
72 |
+
return None
|
73 |
+
|
74 |
+
try:
|
75 |
+
import deepspeed
|
76 |
+
except ImportError as e:
|
77 |
+
logger.error(
|
78 |
+
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
|
79 |
+
)
|
80 |
+
exit(1)
|
81 |
+
|
82 |
+
deepspeed_plugin = DeepSpeedPlugin(
|
83 |
+
zero_stage=args.zero_stage,
|
84 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
85 |
+
gradient_clipping=args.max_grad_norm,
|
86 |
+
offload_optimizer_device=args.offload_optimizer_device,
|
87 |
+
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
|
88 |
+
offload_param_device=args.offload_param_device,
|
89 |
+
offload_param_nvme_path=args.offload_param_nvme_path,
|
90 |
+
zero3_init_flag=args.zero3_init_flag,
|
91 |
+
zero3_save_16bit_model=args.zero3_save_16bit_model,
|
92 |
+
)
|
93 |
+
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
|
94 |
+
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
95 |
+
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
96 |
+
)
|
97 |
+
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
98 |
+
if args.mixed_precision.lower() == "fp16":
|
99 |
+
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
100 |
+
if args.full_fp16 or args.fp16_master_weights_and_gradients:
|
101 |
+
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
|
102 |
+
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
|
103 |
+
logger.info("[DeepSpeed] full fp16 enable.")
|
104 |
+
else:
|
105 |
+
logger.info(
|
106 |
+
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
|
107 |
+
)
|
108 |
+
|
109 |
+
if args.offload_optimizer_device is not None:
|
110 |
+
logger.info("[DeepSpeed] start to manually build cpu_adam.")
|
111 |
+
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
112 |
+
logger.info("[DeepSpeed] building cpu_adam done.")
|
113 |
+
|
114 |
+
return deepspeed_plugin
|
115 |
+
|
116 |
+
|
117 |
+
# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
|
118 |
+
def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
119 |
+
# remove None from models
|
120 |
+
models = {k: v for k, v in models.items() if v is not None}
|
121 |
+
|
122 |
+
class DeepSpeedWrapper(torch.nn.Module):
|
123 |
+
def __init__(self, **kw_models) -> None:
|
124 |
+
super().__init__()
|
125 |
+
self.models = torch.nn.ModuleDict()
|
126 |
+
|
127 |
+
for key, model in kw_models.items():
|
128 |
+
if isinstance(model, list):
|
129 |
+
model = torch.nn.ModuleList(model)
|
130 |
+
assert isinstance(
|
131 |
+
model, torch.nn.Module
|
132 |
+
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
133 |
+
self.models.update(torch.nn.ModuleDict({key: model}))
|
134 |
+
|
135 |
+
def get_models(self):
|
136 |
+
return self.models
|
137 |
+
|
138 |
+
ds_model = DeepSpeedWrapper(**models)
|
139 |
+
return ds_model
|
library/device_utils.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import gc
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
try:
|
7 |
+
HAS_CUDA = torch.cuda.is_available()
|
8 |
+
except Exception:
|
9 |
+
HAS_CUDA = False
|
10 |
+
|
11 |
+
try:
|
12 |
+
HAS_MPS = torch.backends.mps.is_available()
|
13 |
+
except Exception:
|
14 |
+
HAS_MPS = False
|
15 |
+
|
16 |
+
try:
|
17 |
+
import intel_extension_for_pytorch as ipex # noqa
|
18 |
+
|
19 |
+
HAS_XPU = torch.xpu.is_available()
|
20 |
+
except Exception:
|
21 |
+
HAS_XPU = False
|
22 |
+
|
23 |
+
|
24 |
+
def clean_memory():
|
25 |
+
gc.collect()
|
26 |
+
if HAS_CUDA:
|
27 |
+
torch.cuda.empty_cache()
|
28 |
+
if HAS_XPU:
|
29 |
+
torch.xpu.empty_cache()
|
30 |
+
if HAS_MPS:
|
31 |
+
torch.mps.empty_cache()
|
32 |
+
|
33 |
+
|
34 |
+
def clean_memory_on_device(device: torch.device):
|
35 |
+
r"""
|
36 |
+
Clean memory on the specified device, will be called from training scripts.
|
37 |
+
"""
|
38 |
+
gc.collect()
|
39 |
+
|
40 |
+
# device may "cuda" or "cuda:0", so we need to check the type of device
|
41 |
+
if device.type == "cuda":
|
42 |
+
torch.cuda.empty_cache()
|
43 |
+
if device.type == "xpu":
|
44 |
+
torch.xpu.empty_cache()
|
45 |
+
if device.type == "mps":
|
46 |
+
torch.mps.empty_cache()
|
47 |
+
|
48 |
+
|
49 |
+
@functools.lru_cache(maxsize=None)
|
50 |
+
def get_preferred_device() -> torch.device:
|
51 |
+
r"""
|
52 |
+
Do not call this function from training scripts. Use accelerator.device instead.
|
53 |
+
"""
|
54 |
+
if HAS_CUDA:
|
55 |
+
device = torch.device("cuda")
|
56 |
+
elif HAS_XPU:
|
57 |
+
device = torch.device("xpu")
|
58 |
+
elif HAS_MPS:
|
59 |
+
device = torch.device("mps")
|
60 |
+
else:
|
61 |
+
device = torch.device("cpu")
|
62 |
+
print(f"get_preferred_device() -> {device}")
|
63 |
+
return device
|
64 |
+
|
65 |
+
|
66 |
+
def init_ipex():
|
67 |
+
"""
|
68 |
+
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
|
69 |
+
|
70 |
+
This function should run right after importing torch and before doing anything else.
|
71 |
+
|
72 |
+
If IPEX is not available, this function does nothing.
|
73 |
+
"""
|
74 |
+
try:
|
75 |
+
if HAS_XPU:
|
76 |
+
from library.ipex import ipex_init
|
77 |
+
|
78 |
+
is_initialized, error_message = ipex_init()
|
79 |
+
if not is_initialized:
|
80 |
+
print("failed to initialize ipex:", error_message)
|
81 |
+
else:
|
82 |
+
return
|
83 |
+
except Exception as e:
|
84 |
+
print("failed to initialize ipex:", e)
|
library/flux_models.py
ADDED
@@ -0,0 +1,1237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copy from FLUX repo: https://github.com/black-forest-labs/flux
|
2 |
+
# license: Apache-2.0 License
|
3 |
+
|
4 |
+
|
5 |
+
from concurrent.futures import Future, ThreadPoolExecutor
|
6 |
+
from dataclasses import dataclass
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
from typing import Dict, List, Optional, Union
|
11 |
+
|
12 |
+
from library import utils
|
13 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
14 |
+
|
15 |
+
init_ipex()
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from einops import rearrange
|
19 |
+
from torch import Tensor, nn
|
20 |
+
from torch.utils.checkpoint import checkpoint
|
21 |
+
from library import custom_offloading_utils
|
22 |
+
|
23 |
+
# USE_REENTRANT = True
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class FluxParams:
|
28 |
+
in_channels: int
|
29 |
+
vec_in_dim: int
|
30 |
+
context_in_dim: int
|
31 |
+
hidden_size: int
|
32 |
+
mlp_ratio: float
|
33 |
+
num_heads: int
|
34 |
+
depth: int
|
35 |
+
depth_single_blocks: int
|
36 |
+
axes_dim: list[int]
|
37 |
+
theta: int
|
38 |
+
qkv_bias: bool
|
39 |
+
guidance_embed: bool
|
40 |
+
|
41 |
+
|
42 |
+
# region autoencoder
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class AutoEncoderParams:
|
47 |
+
resolution: int
|
48 |
+
in_channels: int
|
49 |
+
ch: int
|
50 |
+
out_ch: int
|
51 |
+
ch_mult: list[int]
|
52 |
+
num_res_blocks: int
|
53 |
+
z_channels: int
|
54 |
+
scale_factor: float
|
55 |
+
shift_factor: float
|
56 |
+
|
57 |
+
|
58 |
+
def swish(x: Tensor) -> Tensor:
|
59 |
+
return x * torch.sigmoid(x)
|
60 |
+
|
61 |
+
|
62 |
+
class AttnBlock(nn.Module):
|
63 |
+
def __init__(self, in_channels: int):
|
64 |
+
super().__init__()
|
65 |
+
self.in_channels = in_channels
|
66 |
+
|
67 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
68 |
+
|
69 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
70 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
71 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
72 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
73 |
+
|
74 |
+
def attention(self, h_: Tensor) -> Tensor:
|
75 |
+
h_ = self.norm(h_)
|
76 |
+
q = self.q(h_)
|
77 |
+
k = self.k(h_)
|
78 |
+
v = self.v(h_)
|
79 |
+
|
80 |
+
b, c, h, w = q.shape
|
81 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
82 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
83 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
84 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
85 |
+
|
86 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
87 |
+
|
88 |
+
def forward(self, x: Tensor) -> Tensor:
|
89 |
+
return x + self.proj_out(self.attention(x))
|
90 |
+
|
91 |
+
|
92 |
+
class ResnetBlock(nn.Module):
|
93 |
+
def __init__(self, in_channels: int, out_channels: int):
|
94 |
+
super().__init__()
|
95 |
+
self.in_channels = in_channels
|
96 |
+
out_channels = in_channels if out_channels is None else out_channels
|
97 |
+
self.out_channels = out_channels
|
98 |
+
|
99 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
100 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
101 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
102 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
103 |
+
if self.in_channels != self.out_channels:
|
104 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
h = x
|
108 |
+
h = self.norm1(h)
|
109 |
+
h = swish(h)
|
110 |
+
h = self.conv1(h)
|
111 |
+
|
112 |
+
h = self.norm2(h)
|
113 |
+
h = swish(h)
|
114 |
+
h = self.conv2(h)
|
115 |
+
|
116 |
+
if self.in_channels != self.out_channels:
|
117 |
+
x = self.nin_shortcut(x)
|
118 |
+
|
119 |
+
return x + h
|
120 |
+
|
121 |
+
|
122 |
+
class Downsample(nn.Module):
|
123 |
+
def __init__(self, in_channels: int):
|
124 |
+
super().__init__()
|
125 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
126 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
127 |
+
|
128 |
+
def forward(self, x: Tensor):
|
129 |
+
pad = (0, 1, 0, 1)
|
130 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
131 |
+
x = self.conv(x)
|
132 |
+
return x
|
133 |
+
|
134 |
+
|
135 |
+
class Upsample(nn.Module):
|
136 |
+
def __init__(self, in_channels: int):
|
137 |
+
super().__init__()
|
138 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
139 |
+
|
140 |
+
def forward(self, x: Tensor):
|
141 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
142 |
+
x = self.conv(x)
|
143 |
+
return x
|
144 |
+
|
145 |
+
|
146 |
+
class Encoder(nn.Module):
|
147 |
+
def __init__(
|
148 |
+
self,
|
149 |
+
resolution: int,
|
150 |
+
in_channels: int,
|
151 |
+
ch: int,
|
152 |
+
ch_mult: list[int],
|
153 |
+
num_res_blocks: int,
|
154 |
+
z_channels: int,
|
155 |
+
):
|
156 |
+
super().__init__()
|
157 |
+
self.ch = ch
|
158 |
+
self.num_resolutions = len(ch_mult)
|
159 |
+
self.num_res_blocks = num_res_blocks
|
160 |
+
self.resolution = resolution
|
161 |
+
self.in_channels = in_channels
|
162 |
+
# downsampling
|
163 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
164 |
+
|
165 |
+
curr_res = resolution
|
166 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
167 |
+
self.in_ch_mult = in_ch_mult
|
168 |
+
self.down = nn.ModuleList()
|
169 |
+
block_in = self.ch
|
170 |
+
for i_level in range(self.num_resolutions):
|
171 |
+
block = nn.ModuleList()
|
172 |
+
attn = nn.ModuleList()
|
173 |
+
block_in = ch * in_ch_mult[i_level]
|
174 |
+
block_out = ch * ch_mult[i_level]
|
175 |
+
for _ in range(self.num_res_blocks):
|
176 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
177 |
+
block_in = block_out
|
178 |
+
down = nn.Module()
|
179 |
+
down.block = block
|
180 |
+
down.attn = attn
|
181 |
+
if i_level != self.num_resolutions - 1:
|
182 |
+
down.downsample = Downsample(block_in)
|
183 |
+
curr_res = curr_res // 2
|
184 |
+
self.down.append(down)
|
185 |
+
|
186 |
+
# middle
|
187 |
+
self.mid = nn.Module()
|
188 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
189 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
190 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
191 |
+
|
192 |
+
# end
|
193 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
194 |
+
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
195 |
+
|
196 |
+
def forward(self, x: Tensor) -> Tensor:
|
197 |
+
# downsampling
|
198 |
+
hs = [self.conv_in(x)]
|
199 |
+
for i_level in range(self.num_resolutions):
|
200 |
+
for i_block in range(self.num_res_blocks):
|
201 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
202 |
+
if len(self.down[i_level].attn) > 0:
|
203 |
+
h = self.down[i_level].attn[i_block](h)
|
204 |
+
hs.append(h)
|
205 |
+
if i_level != self.num_resolutions - 1:
|
206 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
207 |
+
|
208 |
+
# middle
|
209 |
+
h = hs[-1]
|
210 |
+
h = self.mid.block_1(h)
|
211 |
+
h = self.mid.attn_1(h)
|
212 |
+
h = self.mid.block_2(h)
|
213 |
+
# end
|
214 |
+
h = self.norm_out(h)
|
215 |
+
h = swish(h)
|
216 |
+
h = self.conv_out(h)
|
217 |
+
return h
|
218 |
+
|
219 |
+
|
220 |
+
class Decoder(nn.Module):
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
ch: int,
|
224 |
+
out_ch: int,
|
225 |
+
ch_mult: list[int],
|
226 |
+
num_res_blocks: int,
|
227 |
+
in_channels: int,
|
228 |
+
resolution: int,
|
229 |
+
z_channels: int,
|
230 |
+
):
|
231 |
+
super().__init__()
|
232 |
+
self.ch = ch
|
233 |
+
self.num_resolutions = len(ch_mult)
|
234 |
+
self.num_res_blocks = num_res_blocks
|
235 |
+
self.resolution = resolution
|
236 |
+
self.in_channels = in_channels
|
237 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
238 |
+
|
239 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
240 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
241 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
242 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
243 |
+
|
244 |
+
# z to block_in
|
245 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
246 |
+
|
247 |
+
# middle
|
248 |
+
self.mid = nn.Module()
|
249 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
250 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
251 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
252 |
+
|
253 |
+
# upsampling
|
254 |
+
self.up = nn.ModuleList()
|
255 |
+
for i_level in reversed(range(self.num_resolutions)):
|
256 |
+
block = nn.ModuleList()
|
257 |
+
attn = nn.ModuleList()
|
258 |
+
block_out = ch * ch_mult[i_level]
|
259 |
+
for _ in range(self.num_res_blocks + 1):
|
260 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
261 |
+
block_in = block_out
|
262 |
+
up = nn.Module()
|
263 |
+
up.block = block
|
264 |
+
up.attn = attn
|
265 |
+
if i_level != 0:
|
266 |
+
up.upsample = Upsample(block_in)
|
267 |
+
curr_res = curr_res * 2
|
268 |
+
self.up.insert(0, up) # prepend to get consistent order
|
269 |
+
|
270 |
+
# end
|
271 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
272 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
273 |
+
|
274 |
+
def forward(self, z: Tensor) -> Tensor:
|
275 |
+
# z to block_in
|
276 |
+
h = self.conv_in(z)
|
277 |
+
|
278 |
+
# middle
|
279 |
+
h = self.mid.block_1(h)
|
280 |
+
h = self.mid.attn_1(h)
|
281 |
+
h = self.mid.block_2(h)
|
282 |
+
|
283 |
+
# upsampling
|
284 |
+
for i_level in reversed(range(self.num_resolutions)):
|
285 |
+
for i_block in range(self.num_res_blocks + 1):
|
286 |
+
h = self.up[i_level].block[i_block](h)
|
287 |
+
if len(self.up[i_level].attn) > 0:
|
288 |
+
h = self.up[i_level].attn[i_block](h)
|
289 |
+
if i_level != 0:
|
290 |
+
h = self.up[i_level].upsample(h)
|
291 |
+
|
292 |
+
# end
|
293 |
+
h = self.norm_out(h)
|
294 |
+
h = swish(h)
|
295 |
+
h = self.conv_out(h)
|
296 |
+
return h
|
297 |
+
|
298 |
+
|
299 |
+
class DiagonalGaussian(nn.Module):
|
300 |
+
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
301 |
+
super().__init__()
|
302 |
+
self.sample = sample
|
303 |
+
self.chunk_dim = chunk_dim
|
304 |
+
|
305 |
+
def forward(self, z: Tensor) -> Tensor:
|
306 |
+
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
307 |
+
if self.sample:
|
308 |
+
std = torch.exp(0.5 * logvar)
|
309 |
+
return mean + std * torch.randn_like(mean)
|
310 |
+
else:
|
311 |
+
return mean
|
312 |
+
|
313 |
+
|
314 |
+
class AutoEncoder(nn.Module):
|
315 |
+
def __init__(self, params: AutoEncoderParams):
|
316 |
+
super().__init__()
|
317 |
+
self.encoder = Encoder(
|
318 |
+
resolution=params.resolution,
|
319 |
+
in_channels=params.in_channels,
|
320 |
+
ch=params.ch,
|
321 |
+
ch_mult=params.ch_mult,
|
322 |
+
num_res_blocks=params.num_res_blocks,
|
323 |
+
z_channels=params.z_channels,
|
324 |
+
)
|
325 |
+
self.decoder = Decoder(
|
326 |
+
resolution=params.resolution,
|
327 |
+
in_channels=params.in_channels,
|
328 |
+
ch=params.ch,
|
329 |
+
out_ch=params.out_ch,
|
330 |
+
ch_mult=params.ch_mult,
|
331 |
+
num_res_blocks=params.num_res_blocks,
|
332 |
+
z_channels=params.z_channels,
|
333 |
+
)
|
334 |
+
self.reg = DiagonalGaussian()
|
335 |
+
|
336 |
+
self.scale_factor = params.scale_factor
|
337 |
+
self.shift_factor = params.shift_factor
|
338 |
+
|
339 |
+
@property
|
340 |
+
def device(self) -> torch.device:
|
341 |
+
return next(self.parameters()).device
|
342 |
+
|
343 |
+
@property
|
344 |
+
def dtype(self) -> torch.dtype:
|
345 |
+
return next(self.parameters()).dtype
|
346 |
+
|
347 |
+
def encode(self, x: Tensor) -> Tensor:
|
348 |
+
z = self.reg(self.encoder(x))
|
349 |
+
z = self.scale_factor * (z - self.shift_factor)
|
350 |
+
return z
|
351 |
+
|
352 |
+
def decode(self, z: Tensor) -> Tensor:
|
353 |
+
z = z / self.scale_factor + self.shift_factor
|
354 |
+
return self.decoder(z)
|
355 |
+
|
356 |
+
def forward(self, x: Tensor) -> Tensor:
|
357 |
+
return self.decode(self.encode(x))
|
358 |
+
|
359 |
+
|
360 |
+
# endregion
|
361 |
+
# region config
|
362 |
+
|
363 |
+
|
364 |
+
@dataclass
|
365 |
+
class ModelSpec:
|
366 |
+
params: FluxParams
|
367 |
+
ae_params: AutoEncoderParams
|
368 |
+
ckpt_path: str | None
|
369 |
+
ae_path: str | None
|
370 |
+
# repo_id: str | None
|
371 |
+
# repo_flow: str | None
|
372 |
+
# repo_ae: str | None
|
373 |
+
|
374 |
+
|
375 |
+
configs = {
|
376 |
+
"dev": ModelSpec(
|
377 |
+
# repo_id="black-forest-labs/FLUX.1-dev",
|
378 |
+
# repo_flow="flux1-dev.sft",
|
379 |
+
# repo_ae="ae.sft",
|
380 |
+
ckpt_path=None, # os.getenv("FLUX_DEV"),
|
381 |
+
params=FluxParams(
|
382 |
+
in_channels=64,
|
383 |
+
vec_in_dim=768,
|
384 |
+
context_in_dim=4096,
|
385 |
+
hidden_size=3072,
|
386 |
+
mlp_ratio=4.0,
|
387 |
+
num_heads=24,
|
388 |
+
depth=19,
|
389 |
+
depth_single_blocks=38,
|
390 |
+
axes_dim=[16, 56, 56],
|
391 |
+
theta=10_000,
|
392 |
+
qkv_bias=True,
|
393 |
+
guidance_embed=True,
|
394 |
+
),
|
395 |
+
ae_path=None, # os.getenv("AE"),
|
396 |
+
ae_params=AutoEncoderParams(
|
397 |
+
resolution=256,
|
398 |
+
in_channels=3,
|
399 |
+
ch=128,
|
400 |
+
out_ch=3,
|
401 |
+
ch_mult=[1, 2, 4, 4],
|
402 |
+
num_res_blocks=2,
|
403 |
+
z_channels=16,
|
404 |
+
scale_factor=0.3611,
|
405 |
+
shift_factor=0.1159,
|
406 |
+
),
|
407 |
+
),
|
408 |
+
"schnell": ModelSpec(
|
409 |
+
# repo_id="black-forest-labs/FLUX.1-schnell",
|
410 |
+
# repo_flow="flux1-schnell.sft",
|
411 |
+
# repo_ae="ae.sft",
|
412 |
+
ckpt_path=None, # os.getenv("FLUX_SCHNELL"),
|
413 |
+
params=FluxParams(
|
414 |
+
in_channels=64,
|
415 |
+
vec_in_dim=768,
|
416 |
+
context_in_dim=4096,
|
417 |
+
hidden_size=3072,
|
418 |
+
mlp_ratio=4.0,
|
419 |
+
num_heads=24,
|
420 |
+
depth=19,
|
421 |
+
depth_single_blocks=38,
|
422 |
+
axes_dim=[16, 56, 56],
|
423 |
+
theta=10_000,
|
424 |
+
qkv_bias=True,
|
425 |
+
guidance_embed=False,
|
426 |
+
),
|
427 |
+
ae_path=None, # os.getenv("AE"),
|
428 |
+
ae_params=AutoEncoderParams(
|
429 |
+
resolution=256,
|
430 |
+
in_channels=3,
|
431 |
+
ch=128,
|
432 |
+
out_ch=3,
|
433 |
+
ch_mult=[1, 2, 4, 4],
|
434 |
+
num_res_blocks=2,
|
435 |
+
z_channels=16,
|
436 |
+
scale_factor=0.3611,
|
437 |
+
shift_factor=0.1159,
|
438 |
+
),
|
439 |
+
),
|
440 |
+
}
|
441 |
+
|
442 |
+
|
443 |
+
# endregion
|
444 |
+
|
445 |
+
# region math
|
446 |
+
|
447 |
+
|
448 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
|
449 |
+
q, k = apply_rope(q, k, pe)
|
450 |
+
|
451 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
452 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
453 |
+
|
454 |
+
return x
|
455 |
+
|
456 |
+
|
457 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
458 |
+
assert dim % 2 == 0
|
459 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
460 |
+
omega = 1.0 / (theta**scale)
|
461 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
462 |
+
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
463 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
464 |
+
return out.float()
|
465 |
+
|
466 |
+
|
467 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
468 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
469 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
470 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
471 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
472 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
473 |
+
|
474 |
+
|
475 |
+
# endregion
|
476 |
+
|
477 |
+
|
478 |
+
# region layers
|
479 |
+
|
480 |
+
|
481 |
+
# for cpu_offload_checkpointing
|
482 |
+
|
483 |
+
|
484 |
+
def to_cuda(x):
|
485 |
+
if isinstance(x, torch.Tensor):
|
486 |
+
return x.cuda()
|
487 |
+
elif isinstance(x, (list, tuple)):
|
488 |
+
return [to_cuda(elem) for elem in x]
|
489 |
+
elif isinstance(x, dict):
|
490 |
+
return {k: to_cuda(v) for k, v in x.items()}
|
491 |
+
else:
|
492 |
+
return x
|
493 |
+
|
494 |
+
|
495 |
+
def to_cpu(x):
|
496 |
+
if isinstance(x, torch.Tensor):
|
497 |
+
return x.cpu()
|
498 |
+
elif isinstance(x, (list, tuple)):
|
499 |
+
return [to_cpu(elem) for elem in x]
|
500 |
+
elif isinstance(x, dict):
|
501 |
+
return {k: to_cpu(v) for k, v in x.items()}
|
502 |
+
else:
|
503 |
+
return x
|
504 |
+
|
505 |
+
|
506 |
+
class EmbedND(nn.Module):
|
507 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
508 |
+
super().__init__()
|
509 |
+
self.dim = dim
|
510 |
+
self.theta = theta
|
511 |
+
self.axes_dim = axes_dim
|
512 |
+
|
513 |
+
def forward(self, ids: Tensor) -> Tensor:
|
514 |
+
n_axes = ids.shape[-1]
|
515 |
+
emb = torch.cat(
|
516 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
517 |
+
dim=-3,
|
518 |
+
)
|
519 |
+
|
520 |
+
return emb.unsqueeze(1)
|
521 |
+
|
522 |
+
|
523 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
524 |
+
"""
|
525 |
+
Create sinusoidal timestep embeddings.
|
526 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
527 |
+
These may be fractional.
|
528 |
+
:param dim: the dimension of the output.
|
529 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
530 |
+
:return: an (N, D) Tensor of positional embeddings.
|
531 |
+
"""
|
532 |
+
t = time_factor * t
|
533 |
+
half = dim // 2
|
534 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
535 |
+
|
536 |
+
args = t[:, None].float() * freqs[None]
|
537 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
538 |
+
if dim % 2:
|
539 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
540 |
+
if torch.is_floating_point(t):
|
541 |
+
embedding = embedding.to(t)
|
542 |
+
return embedding
|
543 |
+
|
544 |
+
|
545 |
+
class MLPEmbedder(nn.Module):
|
546 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
547 |
+
super().__init__()
|
548 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
549 |
+
self.silu = nn.SiLU()
|
550 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
551 |
+
|
552 |
+
self.gradient_checkpointing = False
|
553 |
+
|
554 |
+
def enable_gradient_checkpointing(self):
|
555 |
+
self.gradient_checkpointing = True
|
556 |
+
|
557 |
+
def disable_gradient_checkpointing(self):
|
558 |
+
self.gradient_checkpointing = False
|
559 |
+
|
560 |
+
def _forward(self, x: Tensor) -> Tensor:
|
561 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
562 |
+
|
563 |
+
def forward(self, *args, **kwargs):
|
564 |
+
if self.training and self.gradient_checkpointing:
|
565 |
+
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
566 |
+
else:
|
567 |
+
return self._forward(*args, **kwargs)
|
568 |
+
|
569 |
+
# def forward(self, x):
|
570 |
+
# if self.training and self.gradient_checkpointing:
|
571 |
+
# def create_custom_forward(func):
|
572 |
+
# def custom_forward(*inputs):
|
573 |
+
# return func(*inputs)
|
574 |
+
# return custom_forward
|
575 |
+
# return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT)
|
576 |
+
# else:
|
577 |
+
# return self._forward(x)
|
578 |
+
|
579 |
+
|
580 |
+
class RMSNorm(torch.nn.Module):
|
581 |
+
def __init__(self, dim: int):
|
582 |
+
super().__init__()
|
583 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
584 |
+
|
585 |
+
def forward(self, x: Tensor):
|
586 |
+
x_dtype = x.dtype
|
587 |
+
x = x.float()
|
588 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
589 |
+
# return (x * rrms).to(dtype=x_dtype) * self.scale
|
590 |
+
return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
|
591 |
+
|
592 |
+
|
593 |
+
class QKNorm(torch.nn.Module):
|
594 |
+
def __init__(self, dim: int):
|
595 |
+
super().__init__()
|
596 |
+
self.query_norm = RMSNorm(dim)
|
597 |
+
self.key_norm = RMSNorm(dim)
|
598 |
+
|
599 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
600 |
+
q = self.query_norm(q)
|
601 |
+
k = self.key_norm(k)
|
602 |
+
return q.to(v), k.to(v)
|
603 |
+
|
604 |
+
|
605 |
+
class SelfAttention(nn.Module):
|
606 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
607 |
+
super().__init__()
|
608 |
+
self.num_heads = num_heads
|
609 |
+
head_dim = dim // num_heads
|
610 |
+
|
611 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
612 |
+
self.norm = QKNorm(head_dim)
|
613 |
+
self.proj = nn.Linear(dim, dim)
|
614 |
+
|
615 |
+
# this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly
|
616 |
+
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
617 |
+
qkv = self.qkv(x)
|
618 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
619 |
+
q, k = self.norm(q, k, v)
|
620 |
+
x = attention(q, k, v, pe=pe)
|
621 |
+
x = self.proj(x)
|
622 |
+
return x
|
623 |
+
|
624 |
+
|
625 |
+
@dataclass
|
626 |
+
class ModulationOut:
|
627 |
+
shift: Tensor
|
628 |
+
scale: Tensor
|
629 |
+
gate: Tensor
|
630 |
+
|
631 |
+
|
632 |
+
class Modulation(nn.Module):
|
633 |
+
def __init__(self, dim: int, double: bool):
|
634 |
+
super().__init__()
|
635 |
+
self.is_double = double
|
636 |
+
self.multiplier = 6 if double else 3
|
637 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
638 |
+
|
639 |
+
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
640 |
+
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
641 |
+
|
642 |
+
return (
|
643 |
+
ModulationOut(*out[:3]),
|
644 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
645 |
+
)
|
646 |
+
|
647 |
+
|
648 |
+
class DoubleStreamBlock(nn.Module):
|
649 |
+
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
650 |
+
super().__init__()
|
651 |
+
|
652 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
653 |
+
self.num_heads = num_heads
|
654 |
+
self.hidden_size = hidden_size
|
655 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
656 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
657 |
+
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
658 |
+
|
659 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
660 |
+
self.img_mlp = nn.Sequential(
|
661 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
662 |
+
nn.GELU(approximate="tanh"),
|
663 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
664 |
+
)
|
665 |
+
|
666 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
667 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
668 |
+
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
669 |
+
|
670 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
671 |
+
self.txt_mlp = nn.Sequential(
|
672 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
673 |
+
nn.GELU(approximate="tanh"),
|
674 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
675 |
+
)
|
676 |
+
|
677 |
+
self.gradient_checkpointing = False
|
678 |
+
self.cpu_offload_checkpointing = False
|
679 |
+
|
680 |
+
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
681 |
+
self.gradient_checkpointing = True
|
682 |
+
self.cpu_offload_checkpointing = cpu_offload
|
683 |
+
|
684 |
+
def disable_gradient_checkpointing(self):
|
685 |
+
self.gradient_checkpointing = False
|
686 |
+
self.cpu_offload_checkpointing = False
|
687 |
+
|
688 |
+
def _forward(
|
689 |
+
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
|
690 |
+
) -> tuple[Tensor, Tensor]:
|
691 |
+
img_mod1, img_mod2 = self.img_mod(vec)
|
692 |
+
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
693 |
+
|
694 |
+
# prepare image for attention
|
695 |
+
img_modulated = self.img_norm1(img)
|
696 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
697 |
+
img_qkv = self.img_attn.qkv(img_modulated)
|
698 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
699 |
+
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
700 |
+
|
701 |
+
# prepare txt for attention
|
702 |
+
txt_modulated = self.txt_norm1(txt)
|
703 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
704 |
+
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
705 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
706 |
+
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
707 |
+
|
708 |
+
# run actual attention
|
709 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
710 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
711 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
712 |
+
|
713 |
+
# make attention mask if not None
|
714 |
+
attn_mask = None
|
715 |
+
if txt_attention_mask is not None:
|
716 |
+
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
|
717 |
+
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
|
718 |
+
attn_mask = torch.cat(
|
719 |
+
(attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1
|
720 |
+
) # b, seq_len + img_len
|
721 |
+
|
722 |
+
# broadcast attn_mask to all heads
|
723 |
+
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
|
724 |
+
|
725 |
+
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
726 |
+
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
727 |
+
|
728 |
+
# calculate the img blocks
|
729 |
+
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
730 |
+
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
731 |
+
|
732 |
+
# calculate the txt blocks
|
733 |
+
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
734 |
+
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
735 |
+
return img, txt
|
736 |
+
|
737 |
+
def forward(
|
738 |
+
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
|
739 |
+
) -> tuple[Tensor, Tensor]:
|
740 |
+
if self.training and self.gradient_checkpointing:
|
741 |
+
if not self.cpu_offload_checkpointing:
|
742 |
+
return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False)
|
743 |
+
# cpu offload checkpointing
|
744 |
+
|
745 |
+
def create_custom_forward(func):
|
746 |
+
def custom_forward(*inputs):
|
747 |
+
cuda_inputs = to_cuda(inputs)
|
748 |
+
outputs = func(*cuda_inputs)
|
749 |
+
return to_cpu(outputs)
|
750 |
+
|
751 |
+
return custom_forward
|
752 |
+
|
753 |
+
return torch.utils.checkpoint.checkpoint(
|
754 |
+
create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
|
755 |
+
)
|
756 |
+
|
757 |
+
else:
|
758 |
+
return self._forward(img, txt, vec, pe, txt_attention_mask)
|
759 |
+
|
760 |
+
|
761 |
+
class SingleStreamBlock(nn.Module):
|
762 |
+
"""
|
763 |
+
A DiT block with parallel linear layers as described in
|
764 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
765 |
+
"""
|
766 |
+
|
767 |
+
def __init__(
|
768 |
+
self,
|
769 |
+
hidden_size: int,
|
770 |
+
num_heads: int,
|
771 |
+
mlp_ratio: float = 4.0,
|
772 |
+
qk_scale: float | None = None,
|
773 |
+
):
|
774 |
+
super().__init__()
|
775 |
+
self.hidden_dim = hidden_size
|
776 |
+
self.num_heads = num_heads
|
777 |
+
head_dim = hidden_size // num_heads
|
778 |
+
self.scale = qk_scale or head_dim**-0.5
|
779 |
+
|
780 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
781 |
+
# qkv and mlp_in
|
782 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
783 |
+
# proj and mlp_out
|
784 |
+
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
785 |
+
|
786 |
+
self.norm = QKNorm(head_dim)
|
787 |
+
|
788 |
+
self.hidden_size = hidden_size
|
789 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
790 |
+
|
791 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
792 |
+
self.modulation = Modulation(hidden_size, double=False)
|
793 |
+
|
794 |
+
self.gradient_checkpointing = False
|
795 |
+
self.cpu_offload_checkpointing = False
|
796 |
+
|
797 |
+
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
798 |
+
self.gradient_checkpointing = True
|
799 |
+
self.cpu_offload_checkpointing = cpu_offload
|
800 |
+
|
801 |
+
def disable_gradient_checkpointing(self):
|
802 |
+
self.gradient_checkpointing = False
|
803 |
+
self.cpu_offload_checkpointing = False
|
804 |
+
|
805 |
+
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
806 |
+
mod, _ = self.modulation(vec)
|
807 |
+
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
808 |
+
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
809 |
+
|
810 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
811 |
+
q, k = self.norm(q, k, v)
|
812 |
+
|
813 |
+
# make attention mask if not None
|
814 |
+
attn_mask = None
|
815 |
+
if txt_attention_mask is not None:
|
816 |
+
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
|
817 |
+
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
|
818 |
+
attn_mask = torch.cat(
|
819 |
+
(
|
820 |
+
attn_mask,
|
821 |
+
torch.ones(
|
822 |
+
attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
|
823 |
+
),
|
824 |
+
),
|
825 |
+
dim=1,
|
826 |
+
) # b, seq_len + img_len = x_len
|
827 |
+
|
828 |
+
# broadcast attn_mask to all heads
|
829 |
+
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
|
830 |
+
|
831 |
+
# compute attention
|
832 |
+
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
833 |
+
|
834 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
835 |
+
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
836 |
+
return x + mod.gate * output
|
837 |
+
|
838 |
+
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
839 |
+
if self.training and self.gradient_checkpointing:
|
840 |
+
if not self.cpu_offload_checkpointing:
|
841 |
+
return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
|
842 |
+
|
843 |
+
# cpu offload checkpointing
|
844 |
+
|
845 |
+
def create_custom_forward(func):
|
846 |
+
def custom_forward(*inputs):
|
847 |
+
cuda_inputs = to_cuda(inputs)
|
848 |
+
outputs = func(*cuda_inputs)
|
849 |
+
return to_cpu(outputs)
|
850 |
+
|
851 |
+
return custom_forward
|
852 |
+
|
853 |
+
return torch.utils.checkpoint.checkpoint(
|
854 |
+
create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
|
855 |
+
)
|
856 |
+
else:
|
857 |
+
return self._forward(x, vec, pe, txt_attention_mask)
|
858 |
+
|
859 |
+
|
860 |
+
class LastLayer(nn.Module):
|
861 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
862 |
+
super().__init__()
|
863 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
864 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
865 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
866 |
+
|
867 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
868 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
869 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
870 |
+
x = self.linear(x)
|
871 |
+
return x
|
872 |
+
|
873 |
+
|
874 |
+
# endregion
|
875 |
+
|
876 |
+
|
877 |
+
class Flux(nn.Module):
|
878 |
+
"""
|
879 |
+
Transformer model for flow matching on sequences.
|
880 |
+
"""
|
881 |
+
|
882 |
+
def __init__(self, params: FluxParams):
|
883 |
+
super().__init__()
|
884 |
+
|
885 |
+
self.params = params
|
886 |
+
self.in_channels = params.in_channels
|
887 |
+
self.out_channels = self.in_channels
|
888 |
+
if params.hidden_size % params.num_heads != 0:
|
889 |
+
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
890 |
+
pe_dim = params.hidden_size // params.num_heads
|
891 |
+
if sum(params.axes_dim) != pe_dim:
|
892 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
893 |
+
self.hidden_size = params.hidden_size
|
894 |
+
self.num_heads = params.num_heads
|
895 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
896 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
897 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
898 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
899 |
+
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
900 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
901 |
+
|
902 |
+
self.double_blocks = nn.ModuleList(
|
903 |
+
[
|
904 |
+
DoubleStreamBlock(
|
905 |
+
self.hidden_size,
|
906 |
+
self.num_heads,
|
907 |
+
mlp_ratio=params.mlp_ratio,
|
908 |
+
qkv_bias=params.qkv_bias,
|
909 |
+
)
|
910 |
+
for _ in range(params.depth)
|
911 |
+
]
|
912 |
+
)
|
913 |
+
|
914 |
+
self.single_blocks = nn.ModuleList(
|
915 |
+
[
|
916 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
917 |
+
for _ in range(params.depth_single_blocks)
|
918 |
+
]
|
919 |
+
)
|
920 |
+
|
921 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
922 |
+
|
923 |
+
self.gradient_checkpointing = False
|
924 |
+
self.cpu_offload_checkpointing = False
|
925 |
+
self.blocks_to_swap = None
|
926 |
+
|
927 |
+
self.offloader_double = None
|
928 |
+
self.offloader_single = None
|
929 |
+
self.num_double_blocks = len(self.double_blocks)
|
930 |
+
self.num_single_blocks = len(self.single_blocks)
|
931 |
+
|
932 |
+
@property
|
933 |
+
def device(self):
|
934 |
+
return next(self.parameters()).device
|
935 |
+
|
936 |
+
@property
|
937 |
+
def dtype(self):
|
938 |
+
return next(self.parameters()).dtype
|
939 |
+
|
940 |
+
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
941 |
+
self.gradient_checkpointing = True
|
942 |
+
self.cpu_offload_checkpointing = cpu_offload
|
943 |
+
|
944 |
+
self.time_in.enable_gradient_checkpointing()
|
945 |
+
self.vector_in.enable_gradient_checkpointing()
|
946 |
+
if self.guidance_in.__class__ != nn.Identity:
|
947 |
+
self.guidance_in.enable_gradient_checkpointing()
|
948 |
+
|
949 |
+
for block in self.double_blocks + self.single_blocks:
|
950 |
+
block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
|
951 |
+
|
952 |
+
print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
|
953 |
+
|
954 |
+
def disable_gradient_checkpointing(self):
|
955 |
+
self.gradient_checkpointing = False
|
956 |
+
self.cpu_offload_checkpointing = False
|
957 |
+
|
958 |
+
self.time_in.disable_gradient_checkpointing()
|
959 |
+
self.vector_in.disable_gradient_checkpointing()
|
960 |
+
if self.guidance_in.__class__ != nn.Identity:
|
961 |
+
self.guidance_in.disable_gradient_checkpointing()
|
962 |
+
|
963 |
+
for block in self.double_blocks + self.single_blocks:
|
964 |
+
block.disable_gradient_checkpointing()
|
965 |
+
|
966 |
+
print("FLUX: Gradient checkpointing disabled.")
|
967 |
+
|
968 |
+
def enable_block_swap(self, num_blocks: int, device: torch.device):
|
969 |
+
self.blocks_to_swap = num_blocks
|
970 |
+
double_blocks_to_swap = num_blocks // 2
|
971 |
+
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
|
972 |
+
|
973 |
+
assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
|
974 |
+
f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
|
975 |
+
f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
|
976 |
+
)
|
977 |
+
|
978 |
+
self.offloader_double = custom_offloading_utils.ModelOffloader(
|
979 |
+
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
|
980 |
+
)
|
981 |
+
self.offloader_single = custom_offloading_utils.ModelOffloader(
|
982 |
+
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
|
983 |
+
)
|
984 |
+
print(
|
985 |
+
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
|
986 |
+
)
|
987 |
+
|
988 |
+
def move_to_device_except_swap_blocks(self, device: torch.device):
|
989 |
+
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
990 |
+
if self.blocks_to_swap:
|
991 |
+
save_double_blocks = self.double_blocks
|
992 |
+
save_single_blocks = self.single_blocks
|
993 |
+
self.double_blocks = None
|
994 |
+
self.single_blocks = None
|
995 |
+
|
996 |
+
self.to(device)
|
997 |
+
|
998 |
+
if self.blocks_to_swap:
|
999 |
+
self.double_blocks = save_double_blocks
|
1000 |
+
self.single_blocks = save_single_blocks
|
1001 |
+
|
1002 |
+
def prepare_block_swap_before_forward(self):
|
1003 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
1004 |
+
return
|
1005 |
+
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
|
1006 |
+
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
|
1007 |
+
|
1008 |
+
def forward(
|
1009 |
+
self,
|
1010 |
+
img: Tensor,
|
1011 |
+
img_ids: Tensor,
|
1012 |
+
txt: Tensor,
|
1013 |
+
txt_ids: Tensor,
|
1014 |
+
timesteps: Tensor,
|
1015 |
+
y: Tensor,
|
1016 |
+
guidance: Tensor | None = None,
|
1017 |
+
txt_attention_mask: Tensor | None = None,
|
1018 |
+
) -> Tensor:
|
1019 |
+
if img.ndim != 3 or txt.ndim != 3:
|
1020 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
1021 |
+
|
1022 |
+
# running on sequences img
|
1023 |
+
img = self.img_in(img)
|
1024 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
1025 |
+
if self.params.guidance_embed:
|
1026 |
+
if guidance is None:
|
1027 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
1028 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
1029 |
+
vec = vec + self.vector_in(y)
|
1030 |
+
txt = self.txt_in(txt)
|
1031 |
+
|
1032 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
1033 |
+
pe = self.pe_embedder(ids)
|
1034 |
+
|
1035 |
+
if not self.blocks_to_swap:
|
1036 |
+
for block in self.double_blocks:
|
1037 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1038 |
+
img = torch.cat((txt, img), 1)
|
1039 |
+
for block in self.single_blocks:
|
1040 |
+
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1041 |
+
else:
|
1042 |
+
for block_idx, block in enumerate(self.double_blocks):
|
1043 |
+
self.offloader_double.wait_for_block(block_idx)
|
1044 |
+
|
1045 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1046 |
+
|
1047 |
+
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
|
1048 |
+
|
1049 |
+
img = torch.cat((txt, img), 1)
|
1050 |
+
|
1051 |
+
for block_idx, block in enumerate(self.single_blocks):
|
1052 |
+
self.offloader_single.wait_for_block(block_idx)
|
1053 |
+
|
1054 |
+
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1055 |
+
|
1056 |
+
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
|
1057 |
+
|
1058 |
+
img = img[:, txt.shape[1] :, ...]
|
1059 |
+
|
1060 |
+
if self.training and self.cpu_offload_checkpointing:
|
1061 |
+
img = img.to(self.device)
|
1062 |
+
vec = vec.to(self.device)
|
1063 |
+
|
1064 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
1065 |
+
|
1066 |
+
return img
|
1067 |
+
|
1068 |
+
|
1069 |
+
"""
|
1070 |
+
class FluxUpper(nn.Module):
|
1071 |
+
""
|
1072 |
+
Transformer model for flow matching on sequences.
|
1073 |
+
""
|
1074 |
+
|
1075 |
+
def __init__(self, params: FluxParams):
|
1076 |
+
super().__init__()
|
1077 |
+
|
1078 |
+
self.params = params
|
1079 |
+
self.in_channels = params.in_channels
|
1080 |
+
self.out_channels = self.in_channels
|
1081 |
+
if params.hidden_size % params.num_heads != 0:
|
1082 |
+
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
1083 |
+
pe_dim = params.hidden_size // params.num_heads
|
1084 |
+
if sum(params.axes_dim) != pe_dim:
|
1085 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
1086 |
+
self.hidden_size = params.hidden_size
|
1087 |
+
self.num_heads = params.num_heads
|
1088 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
1089 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
1090 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
1091 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
1092 |
+
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
1093 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
1094 |
+
|
1095 |
+
self.double_blocks = nn.ModuleList(
|
1096 |
+
[
|
1097 |
+
DoubleStreamBlock(
|
1098 |
+
self.hidden_size,
|
1099 |
+
self.num_heads,
|
1100 |
+
mlp_ratio=params.mlp_ratio,
|
1101 |
+
qkv_bias=params.qkv_bias,
|
1102 |
+
)
|
1103 |
+
for _ in range(params.depth)
|
1104 |
+
]
|
1105 |
+
)
|
1106 |
+
|
1107 |
+
self.gradient_checkpointing = False
|
1108 |
+
|
1109 |
+
@property
|
1110 |
+
def device(self):
|
1111 |
+
return next(self.parameters()).device
|
1112 |
+
|
1113 |
+
@property
|
1114 |
+
def dtype(self):
|
1115 |
+
return next(self.parameters()).dtype
|
1116 |
+
|
1117 |
+
def enable_gradient_checkpointing(self):
|
1118 |
+
self.gradient_checkpointing = True
|
1119 |
+
|
1120 |
+
self.time_in.enable_gradient_checkpointing()
|
1121 |
+
self.vector_in.enable_gradient_checkpointing()
|
1122 |
+
if self.guidance_in.__class__ != nn.Identity:
|
1123 |
+
self.guidance_in.enable_gradient_checkpointing()
|
1124 |
+
|
1125 |
+
for block in self.double_blocks:
|
1126 |
+
block.enable_gradient_checkpointing()
|
1127 |
+
|
1128 |
+
print("FLUX: Gradient checkpointing enabled.")
|
1129 |
+
|
1130 |
+
def disable_gradient_checkpointing(self):
|
1131 |
+
self.gradient_checkpointing = False
|
1132 |
+
|
1133 |
+
self.time_in.disable_gradient_checkpointing()
|
1134 |
+
self.vector_in.disable_gradient_checkpointing()
|
1135 |
+
if self.guidance_in.__class__ != nn.Identity:
|
1136 |
+
self.guidance_in.disable_gradient_checkpointing()
|
1137 |
+
|
1138 |
+
for block in self.double_blocks:
|
1139 |
+
block.disable_gradient_checkpointing()
|
1140 |
+
|
1141 |
+
print("FLUX: Gradient checkpointing disabled.")
|
1142 |
+
|
1143 |
+
def forward(
|
1144 |
+
self,
|
1145 |
+
img: Tensor,
|
1146 |
+
img_ids: Tensor,
|
1147 |
+
txt: Tensor,
|
1148 |
+
txt_ids: Tensor,
|
1149 |
+
timesteps: Tensor,
|
1150 |
+
y: Tensor,
|
1151 |
+
guidance: Tensor | None = None,
|
1152 |
+
txt_attention_mask: Tensor | None = None,
|
1153 |
+
) -> Tensor:
|
1154 |
+
if img.ndim != 3 or txt.ndim != 3:
|
1155 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
1156 |
+
|
1157 |
+
# running on sequences img
|
1158 |
+
img = self.img_in(img)
|
1159 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
1160 |
+
if self.params.guidance_embed:
|
1161 |
+
if guidance is None:
|
1162 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
1163 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
1164 |
+
vec = vec + self.vector_in(y)
|
1165 |
+
txt = self.txt_in(txt)
|
1166 |
+
|
1167 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
1168 |
+
pe = self.pe_embedder(ids)
|
1169 |
+
|
1170 |
+
for block in self.double_blocks:
|
1171 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1172 |
+
|
1173 |
+
return img, txt, vec, pe
|
1174 |
+
|
1175 |
+
|
1176 |
+
class FluxLower(nn.Module):
|
1177 |
+
""
|
1178 |
+
Transformer model for flow matching on sequences.
|
1179 |
+
""
|
1180 |
+
|
1181 |
+
def __init__(self, params: FluxParams):
|
1182 |
+
super().__init__()
|
1183 |
+
self.hidden_size = params.hidden_size
|
1184 |
+
self.num_heads = params.num_heads
|
1185 |
+
self.out_channels = params.in_channels
|
1186 |
+
|
1187 |
+
self.single_blocks = nn.ModuleList(
|
1188 |
+
[
|
1189 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
1190 |
+
for _ in range(params.depth_single_blocks)
|
1191 |
+
]
|
1192 |
+
)
|
1193 |
+
|
1194 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
1195 |
+
|
1196 |
+
self.gradient_checkpointing = False
|
1197 |
+
|
1198 |
+
@property
|
1199 |
+
def device(self):
|
1200 |
+
return next(self.parameters()).device
|
1201 |
+
|
1202 |
+
@property
|
1203 |
+
def dtype(self):
|
1204 |
+
return next(self.parameters()).dtype
|
1205 |
+
|
1206 |
+
def enable_gradient_checkpointing(self):
|
1207 |
+
self.gradient_checkpointing = True
|
1208 |
+
|
1209 |
+
for block in self.single_blocks:
|
1210 |
+
block.enable_gradient_checkpointing()
|
1211 |
+
|
1212 |
+
print("FLUX: Gradient checkpointing enabled.")
|
1213 |
+
|
1214 |
+
def disable_gradient_checkpointing(self):
|
1215 |
+
self.gradient_checkpointing = False
|
1216 |
+
|
1217 |
+
for block in self.single_blocks:
|
1218 |
+
block.disable_gradient_checkpointing()
|
1219 |
+
|
1220 |
+
print("FLUX: Gradient checkpointing disabled.")
|
1221 |
+
|
1222 |
+
def forward(
|
1223 |
+
self,
|
1224 |
+
img: Tensor,
|
1225 |
+
txt: Tensor,
|
1226 |
+
vec: Tensor | None = None,
|
1227 |
+
pe: Tensor | None = None,
|
1228 |
+
txt_attention_mask: Tensor | None = None,
|
1229 |
+
) -> Tensor:
|
1230 |
+
img = torch.cat((txt, img), 1)
|
1231 |
+
for block in self.single_blocks:
|
1232 |
+
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1233 |
+
img = img[:, txt.shape[1] :, ...]
|
1234 |
+
|
1235 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
1236 |
+
return img
|
1237 |
+
"""
|
library/flux_train_utils.py
ADDED
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import toml
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from accelerate import Accelerator, PartialState
|
12 |
+
from transformers import CLIPTextModel
|
13 |
+
from tqdm import tqdm
|
14 |
+
from PIL import Image
|
15 |
+
from safetensors.torch import save_file
|
16 |
+
|
17 |
+
from library import flux_models, flux_utils, strategy_base, train_util
|
18 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
19 |
+
|
20 |
+
init_ipex()
|
21 |
+
|
22 |
+
from .utils import setup_logging, mem_eff_save_file
|
23 |
+
|
24 |
+
setup_logging()
|
25 |
+
import logging
|
26 |
+
|
27 |
+
logger = logging.getLogger(__name__)
|
28 |
+
|
29 |
+
|
30 |
+
# region sample images
|
31 |
+
|
32 |
+
|
33 |
+
def sample_images(
|
34 |
+
accelerator: Accelerator,
|
35 |
+
args: argparse.Namespace,
|
36 |
+
epoch,
|
37 |
+
steps,
|
38 |
+
flux,
|
39 |
+
ae,
|
40 |
+
text_encoders,
|
41 |
+
sample_prompts_te_outputs,
|
42 |
+
prompt_replacement=None,
|
43 |
+
):
|
44 |
+
if steps == 0:
|
45 |
+
if not args.sample_at_first:
|
46 |
+
return
|
47 |
+
else:
|
48 |
+
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
49 |
+
return
|
50 |
+
if args.sample_every_n_epochs is not None:
|
51 |
+
# sample_every_n_steps は無視する
|
52 |
+
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
53 |
+
return
|
54 |
+
else:
|
55 |
+
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
56 |
+
return
|
57 |
+
|
58 |
+
logger.info("")
|
59 |
+
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
60 |
+
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
|
61 |
+
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
62 |
+
return
|
63 |
+
|
64 |
+
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
65 |
+
|
66 |
+
# unwrap unet and text_encoder(s)
|
67 |
+
flux = accelerator.unwrap_model(flux)
|
68 |
+
if text_encoders is not None:
|
69 |
+
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
70 |
+
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
71 |
+
|
72 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
73 |
+
|
74 |
+
save_dir = args.output_dir + "/sample"
|
75 |
+
os.makedirs(save_dir, exist_ok=True)
|
76 |
+
|
77 |
+
# save random state to restore later
|
78 |
+
rng_state = torch.get_rng_state()
|
79 |
+
cuda_rng_state = None
|
80 |
+
try:
|
81 |
+
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
82 |
+
except Exception:
|
83 |
+
pass
|
84 |
+
|
85 |
+
if distributed_state.num_processes <= 1:
|
86 |
+
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
87 |
+
with torch.no_grad(), accelerator.autocast():
|
88 |
+
for prompt_dict in prompts:
|
89 |
+
sample_image_inference(
|
90 |
+
accelerator,
|
91 |
+
args,
|
92 |
+
flux,
|
93 |
+
text_encoders,
|
94 |
+
ae,
|
95 |
+
save_dir,
|
96 |
+
prompt_dict,
|
97 |
+
epoch,
|
98 |
+
steps,
|
99 |
+
sample_prompts_te_outputs,
|
100 |
+
prompt_replacement,
|
101 |
+
)
|
102 |
+
else:
|
103 |
+
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
104 |
+
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
105 |
+
per_process_prompts = [] # list of lists
|
106 |
+
for i in range(distributed_state.num_processes):
|
107 |
+
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
108 |
+
|
109 |
+
with torch.no_grad():
|
110 |
+
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
111 |
+
for prompt_dict in prompt_dict_lists[0]:
|
112 |
+
sample_image_inference(
|
113 |
+
accelerator,
|
114 |
+
args,
|
115 |
+
flux,
|
116 |
+
text_encoders,
|
117 |
+
ae,
|
118 |
+
save_dir,
|
119 |
+
prompt_dict,
|
120 |
+
epoch,
|
121 |
+
steps,
|
122 |
+
sample_prompts_te_outputs,
|
123 |
+
prompt_replacement,
|
124 |
+
)
|
125 |
+
|
126 |
+
torch.set_rng_state(rng_state)
|
127 |
+
if cuda_rng_state is not None:
|
128 |
+
torch.cuda.set_rng_state(cuda_rng_state)
|
129 |
+
|
130 |
+
clean_memory_on_device(accelerator.device)
|
131 |
+
|
132 |
+
|
133 |
+
def sample_image_inference(
|
134 |
+
accelerator: Accelerator,
|
135 |
+
args: argparse.Namespace,
|
136 |
+
flux: flux_models.Flux,
|
137 |
+
text_encoders: Optional[List[CLIPTextModel]],
|
138 |
+
ae: flux_models.AutoEncoder,
|
139 |
+
save_dir,
|
140 |
+
prompt_dict,
|
141 |
+
epoch,
|
142 |
+
steps,
|
143 |
+
sample_prompts_te_outputs,
|
144 |
+
prompt_replacement,
|
145 |
+
):
|
146 |
+
assert isinstance(prompt_dict, dict)
|
147 |
+
# negative_prompt = prompt_dict.get("negative_prompt")
|
148 |
+
sample_steps = prompt_dict.get("sample_steps", 20)
|
149 |
+
width = prompt_dict.get("width", 512)
|
150 |
+
height = prompt_dict.get("height", 512)
|
151 |
+
scale = prompt_dict.get("scale", 3.5)
|
152 |
+
seed = prompt_dict.get("seed")
|
153 |
+
# controlnet_image = prompt_dict.get("controlnet_image")
|
154 |
+
prompt: str = prompt_dict.get("prompt", "")
|
155 |
+
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
156 |
+
|
157 |
+
if prompt_replacement is not None:
|
158 |
+
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
159 |
+
# if negative_prompt is not None:
|
160 |
+
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
161 |
+
|
162 |
+
if seed is not None:
|
163 |
+
torch.manual_seed(seed)
|
164 |
+
torch.cuda.manual_seed(seed)
|
165 |
+
else:
|
166 |
+
# True random sample image generation
|
167 |
+
torch.seed()
|
168 |
+
torch.cuda.seed()
|
169 |
+
|
170 |
+
# if negative_prompt is None:
|
171 |
+
# negative_prompt = ""
|
172 |
+
|
173 |
+
height = max(64, height - height % 16) # round to divisible by 16
|
174 |
+
width = max(64, width - width % 16) # round to divisible by 16
|
175 |
+
logger.info(f"prompt: {prompt}")
|
176 |
+
# logger.info(f"negative_prompt: {negative_prompt}")
|
177 |
+
logger.info(f"height: {height}")
|
178 |
+
logger.info(f"width: {width}")
|
179 |
+
logger.info(f"sample_steps: {sample_steps}")
|
180 |
+
logger.info(f"scale: {scale}")
|
181 |
+
# logger.info(f"sample_sampler: {sampler_name}")
|
182 |
+
if seed is not None:
|
183 |
+
logger.info(f"seed: {seed}")
|
184 |
+
|
185 |
+
# encode prompts
|
186 |
+
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
187 |
+
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
188 |
+
|
189 |
+
text_encoder_conds = []
|
190 |
+
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
191 |
+
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
192 |
+
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
193 |
+
if text_encoders is not None:
|
194 |
+
print(f"Encoding prompt: {prompt}")
|
195 |
+
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
196 |
+
# strategy has apply_t5_attn_mask option
|
197 |
+
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
198 |
+
|
199 |
+
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
200 |
+
if len(text_encoder_conds) == 0:
|
201 |
+
text_encoder_conds = encoded_text_encoder_conds
|
202 |
+
else:
|
203 |
+
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
204 |
+
for i in range(len(encoded_text_encoder_conds)):
|
205 |
+
if encoded_text_encoder_conds[i] is not None:
|
206 |
+
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
207 |
+
|
208 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
209 |
+
|
210 |
+
# sample image
|
211 |
+
weight_dtype = ae.dtype # TOFO give dtype as argument
|
212 |
+
packed_latent_height = height // 16
|
213 |
+
packed_latent_width = width // 16
|
214 |
+
noise = torch.randn(
|
215 |
+
1,
|
216 |
+
packed_latent_height * packed_latent_width,
|
217 |
+
16 * 2 * 2,
|
218 |
+
device=accelerator.device,
|
219 |
+
dtype=weight_dtype,
|
220 |
+
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
|
221 |
+
)
|
222 |
+
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
223 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
224 |
+
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
225 |
+
|
226 |
+
with accelerator.autocast(), torch.no_grad():
|
227 |
+
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
|
228 |
+
|
229 |
+
x = x.float()
|
230 |
+
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
231 |
+
|
232 |
+
# latent to image
|
233 |
+
clean_memory_on_device(accelerator.device)
|
234 |
+
org_vae_device = ae.device # will be on cpu
|
235 |
+
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
|
236 |
+
with accelerator.autocast(), torch.no_grad():
|
237 |
+
x = ae.decode(x)
|
238 |
+
ae.to(org_vae_device)
|
239 |
+
clean_memory_on_device(accelerator.device)
|
240 |
+
|
241 |
+
x = x.clamp(-1, 1)
|
242 |
+
x = x.permute(0, 2, 3, 1)
|
243 |
+
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
244 |
+
|
245 |
+
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
246 |
+
# but adding 'enum' to the filename should be enough
|
247 |
+
|
248 |
+
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
249 |
+
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
250 |
+
seed_suffix = "" if seed is None else f"_{seed}"
|
251 |
+
i: int = prompt_dict["enum"]
|
252 |
+
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
253 |
+
image.save(os.path.join(save_dir, img_filename))
|
254 |
+
|
255 |
+
# send images to wandb if enabled
|
256 |
+
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
257 |
+
wandb_tracker = accelerator.get_tracker("wandb")
|
258 |
+
|
259 |
+
import wandb
|
260 |
+
|
261 |
+
# not to commit images to avoid inconsistency between training and logging steps
|
262 |
+
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
263 |
+
|
264 |
+
|
265 |
+
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
266 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
267 |
+
|
268 |
+
|
269 |
+
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
270 |
+
m = (y2 - y1) / (x2 - x1)
|
271 |
+
b = y1 - m * x1
|
272 |
+
return lambda x: m * x + b
|
273 |
+
|
274 |
+
|
275 |
+
def get_schedule(
|
276 |
+
num_steps: int,
|
277 |
+
image_seq_len: int,
|
278 |
+
base_shift: float = 0.5,
|
279 |
+
max_shift: float = 1.15,
|
280 |
+
shift: bool = True,
|
281 |
+
) -> list[float]:
|
282 |
+
# extra step for zero
|
283 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
284 |
+
|
285 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
286 |
+
if shift:
|
287 |
+
# eastimate mu based on linear estimation between two points
|
288 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
289 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
290 |
+
|
291 |
+
return timesteps.tolist()
|
292 |
+
|
293 |
+
|
294 |
+
def denoise(
|
295 |
+
model: flux_models.Flux,
|
296 |
+
img: torch.Tensor,
|
297 |
+
img_ids: torch.Tensor,
|
298 |
+
txt: torch.Tensor,
|
299 |
+
txt_ids: torch.Tensor,
|
300 |
+
vec: torch.Tensor,
|
301 |
+
timesteps: list[float],
|
302 |
+
guidance: float = 4.0,
|
303 |
+
t5_attn_mask: Optional[torch.Tensor] = None,
|
304 |
+
):
|
305 |
+
# this is ignored for schnell
|
306 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
307 |
+
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
308 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
309 |
+
model.prepare_block_swap_before_forward()
|
310 |
+
pred = model(
|
311 |
+
img=img,
|
312 |
+
img_ids=img_ids,
|
313 |
+
txt=txt,
|
314 |
+
txt_ids=txt_ids,
|
315 |
+
y=vec,
|
316 |
+
timesteps=t_vec,
|
317 |
+
guidance=guidance_vec,
|
318 |
+
txt_attention_mask=t5_attn_mask,
|
319 |
+
)
|
320 |
+
|
321 |
+
img = img + (t_prev - t_curr) * pred
|
322 |
+
|
323 |
+
model.prepare_block_swap_before_forward()
|
324 |
+
return img
|
325 |
+
|
326 |
+
|
327 |
+
# endregion
|
328 |
+
|
329 |
+
|
330 |
+
# region train
|
331 |
+
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
332 |
+
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
333 |
+
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
334 |
+
timesteps = timesteps.to(device)
|
335 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
336 |
+
|
337 |
+
sigma = sigmas[step_indices].flatten()
|
338 |
+
while len(sigma.shape) < n_dim:
|
339 |
+
sigma = sigma.unsqueeze(-1)
|
340 |
+
return sigma
|
341 |
+
|
342 |
+
|
343 |
+
def compute_density_for_timestep_sampling(
|
344 |
+
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
345 |
+
):
|
346 |
+
"""Compute the density for sampling the timesteps when doing SD3 training.
|
347 |
+
|
348 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
349 |
+
|
350 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
351 |
+
"""
|
352 |
+
if weighting_scheme == "logit_normal":
|
353 |
+
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
354 |
+
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
355 |
+
u = torch.nn.functional.sigmoid(u)
|
356 |
+
elif weighting_scheme == "mode":
|
357 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
358 |
+
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
359 |
+
else:
|
360 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
361 |
+
return u
|
362 |
+
|
363 |
+
|
364 |
+
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
365 |
+
"""Computes loss weighting scheme for SD3 training.
|
366 |
+
|
367 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
368 |
+
|
369 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
370 |
+
"""
|
371 |
+
if weighting_scheme == "sigma_sqrt":
|
372 |
+
weighting = (sigmas**-2.0).float()
|
373 |
+
elif weighting_scheme == "cosmap":
|
374 |
+
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
375 |
+
weighting = 2 / (math.pi * bot)
|
376 |
+
else:
|
377 |
+
weighting = torch.ones_like(sigmas)
|
378 |
+
return weighting
|
379 |
+
|
380 |
+
|
381 |
+
def get_noisy_model_input_and_timesteps(
|
382 |
+
args, noise_scheduler, latents, noise, device, dtype
|
383 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
384 |
+
bsz, _, h, w = latents.shape
|
385 |
+
sigmas = None
|
386 |
+
|
387 |
+
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
388 |
+
# Simple random t-based noise sampling
|
389 |
+
if args.timestep_sampling == "sigmoid":
|
390 |
+
# https://github.com/XLabs-AI/x-flux/tree/main
|
391 |
+
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
392 |
+
else:
|
393 |
+
t = torch.rand((bsz,), device=device)
|
394 |
+
|
395 |
+
timesteps = t * 1000.0
|
396 |
+
t = t.view(-1, 1, 1, 1)
|
397 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
398 |
+
elif args.timestep_sampling == "shift":
|
399 |
+
shift = args.discrete_flow_shift
|
400 |
+
logits_norm = torch.randn(bsz, device=device)
|
401 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
402 |
+
timesteps = logits_norm.sigmoid()
|
403 |
+
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
404 |
+
|
405 |
+
t = timesteps.view(-1, 1, 1, 1)
|
406 |
+
timesteps = timesteps * 1000.0
|
407 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
408 |
+
elif args.timestep_sampling == "flux_shift":
|
409 |
+
logits_norm = torch.randn(bsz, device=device)
|
410 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
411 |
+
timesteps = logits_norm.sigmoid()
|
412 |
+
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
413 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
414 |
+
|
415 |
+
t = timesteps.view(-1, 1, 1, 1)
|
416 |
+
timesteps = timesteps * 1000.0
|
417 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
418 |
+
else:
|
419 |
+
# Sample a random timestep for each image
|
420 |
+
# for weighting schemes where we sample timesteps non-uniformly
|
421 |
+
u = compute_density_for_timestep_sampling(
|
422 |
+
weighting_scheme=args.weighting_scheme,
|
423 |
+
batch_size=bsz,
|
424 |
+
logit_mean=args.logit_mean,
|
425 |
+
logit_std=args.logit_std,
|
426 |
+
mode_scale=args.mode_scale,
|
427 |
+
)
|
428 |
+
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
429 |
+
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
430 |
+
|
431 |
+
# Add noise according to flow matching.
|
432 |
+
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
433 |
+
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
434 |
+
|
435 |
+
return noisy_model_input, timesteps, sigmas
|
436 |
+
|
437 |
+
|
438 |
+
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
439 |
+
weighting = None
|
440 |
+
if args.model_prediction_type == "raw":
|
441 |
+
pass
|
442 |
+
elif args.model_prediction_type == "additive":
|
443 |
+
# add the model_pred to the noisy_model_input
|
444 |
+
model_pred = model_pred + noisy_model_input
|
445 |
+
elif args.model_prediction_type == "sigma_scaled":
|
446 |
+
# apply sigma scaling
|
447 |
+
model_pred = model_pred * (-sigmas) + noisy_model_input
|
448 |
+
|
449 |
+
# these weighting schemes use a uniform timestep sampling
|
450 |
+
# and instead post-weight the loss
|
451 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
452 |
+
|
453 |
+
return model_pred, weighting
|
454 |
+
|
455 |
+
|
456 |
+
def save_models(
|
457 |
+
ckpt_path: str,
|
458 |
+
flux: flux_models.Flux,
|
459 |
+
sai_metadata: Optional[dict],
|
460 |
+
save_dtype: Optional[torch.dtype] = None,
|
461 |
+
use_mem_eff_save: bool = False,
|
462 |
+
):
|
463 |
+
state_dict = {}
|
464 |
+
|
465 |
+
def update_sd(prefix, sd):
|
466 |
+
for k, v in sd.items():
|
467 |
+
key = prefix + k
|
468 |
+
if save_dtype is not None and v.dtype != save_dtype:
|
469 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
470 |
+
state_dict[key] = v
|
471 |
+
|
472 |
+
update_sd("", flux.state_dict())
|
473 |
+
|
474 |
+
if not use_mem_eff_save:
|
475 |
+
save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
476 |
+
else:
|
477 |
+
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
478 |
+
|
479 |
+
|
480 |
+
def save_flux_model_on_train_end(
|
481 |
+
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
|
482 |
+
):
|
483 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
484 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
485 |
+
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
486 |
+
|
487 |
+
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
488 |
+
|
489 |
+
|
490 |
+
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
491 |
+
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
492 |
+
def save_flux_model_on_epoch_end_or_stepwise(
|
493 |
+
args: argparse.Namespace,
|
494 |
+
on_epoch_end: bool,
|
495 |
+
accelerator,
|
496 |
+
save_dtype: torch.dtype,
|
497 |
+
epoch: int,
|
498 |
+
num_train_epochs: int,
|
499 |
+
global_step: int,
|
500 |
+
flux: flux_models.Flux,
|
501 |
+
):
|
502 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
503 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
504 |
+
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
505 |
+
|
506 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
507 |
+
args,
|
508 |
+
on_epoch_end,
|
509 |
+
accelerator,
|
510 |
+
True,
|
511 |
+
True,
|
512 |
+
epoch,
|
513 |
+
num_train_epochs,
|
514 |
+
global_step,
|
515 |
+
sd_saver,
|
516 |
+
None,
|
517 |
+
)
|
518 |
+
|
519 |
+
|
520 |
+
# endregion
|
521 |
+
|
522 |
+
|
523 |
+
def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
524 |
+
parser.add_argument(
|
525 |
+
"--clip_l",
|
526 |
+
type=str,
|
527 |
+
help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提",
|
528 |
+
)
|
529 |
+
parser.add_argument(
|
530 |
+
"--t5xxl",
|
531 |
+
type=str,
|
532 |
+
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
|
533 |
+
)
|
534 |
+
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
|
535 |
+
parser.add_argument(
|
536 |
+
"--t5xxl_max_token_length",
|
537 |
+
type=int,
|
538 |
+
default=None,
|
539 |
+
help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
|
540 |
+
" / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
|
541 |
+
)
|
542 |
+
parser.add_argument(
|
543 |
+
"--apply_t5_attn_mask",
|
544 |
+
action="store_true",
|
545 |
+
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
|
546 |
+
)
|
547 |
+
|
548 |
+
parser.add_argument(
|
549 |
+
"--guidance_scale",
|
550 |
+
type=float,
|
551 |
+
default=3.5,
|
552 |
+
help="the FLUX.1 dev variant is a guidance distilled model",
|
553 |
+
)
|
554 |
+
|
555 |
+
parser.add_argument(
|
556 |
+
"--timestep_sampling",
|
557 |
+
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
|
558 |
+
default="sigma",
|
559 |
+
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
|
560 |
+
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
|
561 |
+
)
|
562 |
+
parser.add_argument(
|
563 |
+
"--sigmoid_scale",
|
564 |
+
type=float,
|
565 |
+
default=1.0,
|
566 |
+
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
|
567 |
+
)
|
568 |
+
parser.add_argument(
|
569 |
+
"--model_prediction_type",
|
570 |
+
choices=["raw", "additive", "sigma_scaled"],
|
571 |
+
default="sigma_scaled",
|
572 |
+
help="How to interpret and process the model prediction: "
|
573 |
+
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
|
574 |
+
" / モデル予測の解釈と処理方法:"
|
575 |
+
"raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
|
576 |
+
)
|
577 |
+
parser.add_argument(
|
578 |
+
"--discrete_flow_shift",
|
579 |
+
type=float,
|
580 |
+
default=3.0,
|
581 |
+
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
582 |
+
)
|
library/flux_train_utils_recraft.py
ADDED
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import toml
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
9 |
+
import pdb
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from accelerate import Accelerator, PartialState
|
13 |
+
from transformers import CLIPTextModel
|
14 |
+
from tqdm import tqdm
|
15 |
+
from PIL import Image
|
16 |
+
from safetensors.torch import save_file
|
17 |
+
|
18 |
+
from library import flux_models, flux_utils, strategy_base, train_util
|
19 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
20 |
+
|
21 |
+
init_ipex()
|
22 |
+
|
23 |
+
from .utils import setup_logging, mem_eff_save_file
|
24 |
+
|
25 |
+
setup_logging()
|
26 |
+
import logging
|
27 |
+
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
# region sample images
|
32 |
+
|
33 |
+
def sample_images(
|
34 |
+
accelerator: Accelerator,
|
35 |
+
args: argparse.Namespace,
|
36 |
+
epoch,
|
37 |
+
steps,
|
38 |
+
flux,
|
39 |
+
ae,
|
40 |
+
text_encoders,
|
41 |
+
sample_prompts_te_outputs,
|
42 |
+
prompt_replacement=None,
|
43 |
+
sample_images_ae_outputs=None
|
44 |
+
):
|
45 |
+
if steps == 0:
|
46 |
+
if not args.sample_at_first:
|
47 |
+
return
|
48 |
+
else:
|
49 |
+
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
50 |
+
return
|
51 |
+
if args.sample_every_n_epochs is not None:
|
52 |
+
# sample_every_n_steps は無視する
|
53 |
+
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
54 |
+
return
|
55 |
+
else:
|
56 |
+
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
57 |
+
return
|
58 |
+
|
59 |
+
logger.info("")
|
60 |
+
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
61 |
+
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
|
62 |
+
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
63 |
+
return
|
64 |
+
|
65 |
+
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
66 |
+
|
67 |
+
# unwrap unet and text_encoder(s)
|
68 |
+
flux = accelerator.unwrap_model(flux)
|
69 |
+
if text_encoders is not None:
|
70 |
+
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
71 |
+
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
72 |
+
|
73 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
74 |
+
|
75 |
+
save_dir = args.output_dir + "/sample"
|
76 |
+
os.makedirs(save_dir, exist_ok=True)
|
77 |
+
|
78 |
+
# save random state to restore later
|
79 |
+
rng_state = torch.get_rng_state()
|
80 |
+
cuda_rng_state = None
|
81 |
+
try:
|
82 |
+
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
83 |
+
except Exception:
|
84 |
+
pass
|
85 |
+
|
86 |
+
if distributed_state.num_processes <= 1:
|
87 |
+
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
88 |
+
with torch.no_grad(), accelerator.autocast():
|
89 |
+
for prompt_dict in prompts:
|
90 |
+
sample_image_inference(
|
91 |
+
accelerator,
|
92 |
+
args,
|
93 |
+
flux,
|
94 |
+
text_encoders,
|
95 |
+
ae,
|
96 |
+
save_dir,
|
97 |
+
prompt_dict,
|
98 |
+
epoch,
|
99 |
+
steps,
|
100 |
+
sample_prompts_te_outputs,
|
101 |
+
prompt_replacement,
|
102 |
+
sample_images_ae_outputs
|
103 |
+
)
|
104 |
+
else:
|
105 |
+
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
106 |
+
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
107 |
+
per_process_prompts = [] # list of lists
|
108 |
+
for i in range(distributed_state.num_processes):
|
109 |
+
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
110 |
+
|
111 |
+
with torch.no_grad():
|
112 |
+
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
113 |
+
for prompt_dict in prompt_dict_lists[0]:
|
114 |
+
sample_image_inference(
|
115 |
+
accelerator,
|
116 |
+
args,
|
117 |
+
flux,
|
118 |
+
text_encoders,
|
119 |
+
ae,
|
120 |
+
save_dir,
|
121 |
+
prompt_dict,
|
122 |
+
epoch,
|
123 |
+
steps,
|
124 |
+
sample_prompts_te_outputs,
|
125 |
+
prompt_replacement,
|
126 |
+
sample_images_ae_outputs
|
127 |
+
)
|
128 |
+
|
129 |
+
torch.set_rng_state(rng_state)
|
130 |
+
if cuda_rng_state is not None:
|
131 |
+
torch.cuda.set_rng_state(cuda_rng_state)
|
132 |
+
|
133 |
+
clean_memory_on_device(accelerator.device)
|
134 |
+
|
135 |
+
|
136 |
+
def sample_image_inference(
|
137 |
+
accelerator: Accelerator,
|
138 |
+
args: argparse.Namespace,
|
139 |
+
flux: flux_models.Flux,
|
140 |
+
text_encoders: Optional[List[CLIPTextModel]],
|
141 |
+
ae: flux_models.AutoEncoder,
|
142 |
+
save_dir,
|
143 |
+
prompt_dict,
|
144 |
+
epoch,
|
145 |
+
steps,
|
146 |
+
sample_prompts_te_outputs,
|
147 |
+
prompt_replacement,
|
148 |
+
sample_images_ae_outputs
|
149 |
+
):
|
150 |
+
assert isinstance(prompt_dict, dict)
|
151 |
+
# negative_prompt = prompt_dict.get("negative_prompt")
|
152 |
+
sample_steps = prompt_dict.get("sample_steps", 20)
|
153 |
+
width = prompt_dict.get("width", 1024) if args.frame_num==4 else prompt_dict.get("width", 1056)
|
154 |
+
height = prompt_dict.get("height", 1024) if args.frame_num==4 else prompt_dict.get("height", 1056)
|
155 |
+
scale = prompt_dict.get("scale", 1.0)
|
156 |
+
seed = prompt_dict.get("seed")
|
157 |
+
# controlnet_image = prompt_dict.get("controlnet_image")
|
158 |
+
prompt: str = prompt_dict.get("prompt", "")
|
159 |
+
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
160 |
+
|
161 |
+
if prompt_replacement is not None:
|
162 |
+
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
163 |
+
# if negative_prompt is not None:
|
164 |
+
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
165 |
+
|
166 |
+
if seed is not None:
|
167 |
+
torch.manual_seed(seed)
|
168 |
+
torch.cuda.manual_seed(seed)
|
169 |
+
else:
|
170 |
+
# True random sample image generation
|
171 |
+
torch.seed()
|
172 |
+
torch.cuda.seed()
|
173 |
+
|
174 |
+
# if negative_prompt is None:
|
175 |
+
# negative_prompt = ""
|
176 |
+
|
177 |
+
height = max(64, height - height % 16) # round to divisible by 16
|
178 |
+
width = max(64, width - width % 16) # round to divisible by 16
|
179 |
+
logger.info(f"prompt: {prompt}")
|
180 |
+
# logger.info(f"negative_prompt: {negative_prompt}")
|
181 |
+
logger.info(f"height: {height}")
|
182 |
+
logger.info(f"width: {width}")
|
183 |
+
logger.info(f"sample_steps: {sample_steps}")
|
184 |
+
logger.info(f"scale: {scale}")
|
185 |
+
# logger.info(f"sample_sampler: {sampler_name}")
|
186 |
+
if seed is not None:
|
187 |
+
logger.info(f"seed: {seed}")
|
188 |
+
|
189 |
+
# encode prompts
|
190 |
+
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
191 |
+
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
192 |
+
|
193 |
+
text_encoder_conds = []
|
194 |
+
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
195 |
+
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
196 |
+
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
197 |
+
if text_encoders is not None:
|
198 |
+
print(f"Encoding prompt: {prompt}")
|
199 |
+
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
200 |
+
# strategy has apply_t5_attn_mask option
|
201 |
+
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
202 |
+
|
203 |
+
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
204 |
+
if len(text_encoder_conds) == 0:
|
205 |
+
text_encoder_conds = encoded_text_encoder_conds
|
206 |
+
else:
|
207 |
+
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
208 |
+
for i in range(len(encoded_text_encoder_conds)):
|
209 |
+
if encoded_text_encoder_conds[i] is not None:
|
210 |
+
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
211 |
+
|
212 |
+
if sample_images_ae_outputs and prompt in sample_images_ae_outputs:
|
213 |
+
ae_outputs = sample_images_ae_outputs[prompt]
|
214 |
+
else:
|
215 |
+
ae_outputs = None
|
216 |
+
|
217 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
218 |
+
|
219 |
+
# sample image
|
220 |
+
weight_dtype = ae.dtype # TOFO give dtype as argument
|
221 |
+
packed_latent_height = height // 16
|
222 |
+
packed_latent_width = width // 16
|
223 |
+
noise = torch.randn(
|
224 |
+
1,
|
225 |
+
packed_latent_height * packed_latent_width,
|
226 |
+
16 * 2 * 2,
|
227 |
+
device=accelerator.device,
|
228 |
+
dtype=weight_dtype,
|
229 |
+
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
|
230 |
+
)
|
231 |
+
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
232 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
233 |
+
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
234 |
+
|
235 |
+
with accelerator.autocast(), torch.no_grad():
|
236 |
+
x = denoise(args, flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs)
|
237 |
+
|
238 |
+
x = x.float()
|
239 |
+
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
240 |
+
|
241 |
+
# latent to image
|
242 |
+
clean_memory_on_device(accelerator.device)
|
243 |
+
org_vae_device = ae.device # will be on cpu
|
244 |
+
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
|
245 |
+
with accelerator.autocast(), torch.no_grad():
|
246 |
+
x = ae.decode(x)
|
247 |
+
ae.to(org_vae_device)
|
248 |
+
clean_memory_on_device(accelerator.device)
|
249 |
+
|
250 |
+
x = x.clamp(-1, 1)
|
251 |
+
x = x.permute(0, 2, 3, 1)
|
252 |
+
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
253 |
+
|
254 |
+
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
255 |
+
# but adding 'enum' to the filename should be enough
|
256 |
+
|
257 |
+
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
258 |
+
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
259 |
+
seed_suffix = "" if seed is None else f"_{seed}"
|
260 |
+
i: int = prompt_dict["enum"]
|
261 |
+
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
262 |
+
image.save(os.path.join(save_dir, img_filename))
|
263 |
+
|
264 |
+
# send images to wandb if enabled
|
265 |
+
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
266 |
+
wandb_tracker = accelerator.get_tracker("wandb")
|
267 |
+
|
268 |
+
import wandb
|
269 |
+
# not to commit images to avoid inconsistency between training and logging steps
|
270 |
+
wandb_tracker.log(
|
271 |
+
{f"sample_{i}": wandb.Image(
|
272 |
+
image,
|
273 |
+
caption=prompt # positive prompt as a caption
|
274 |
+
)},
|
275 |
+
commit=False
|
276 |
+
)
|
277 |
+
|
278 |
+
|
279 |
+
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
280 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
281 |
+
|
282 |
+
|
283 |
+
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
284 |
+
m = (y2 - y1) / (x2 - x1)
|
285 |
+
b = y1 - m * x1
|
286 |
+
return lambda x: m * x + b
|
287 |
+
|
288 |
+
|
289 |
+
def get_schedule(
|
290 |
+
num_steps: int,
|
291 |
+
image_seq_len: int,
|
292 |
+
base_shift: float = 0.5,
|
293 |
+
max_shift: float = 1.15,
|
294 |
+
shift: bool = True,
|
295 |
+
) -> list[float]:
|
296 |
+
# extra step for zero
|
297 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
298 |
+
|
299 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
300 |
+
if shift:
|
301 |
+
# eastimate mu based on linear estimation between two points
|
302 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
303 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
304 |
+
|
305 |
+
return timesteps.tolist()
|
306 |
+
|
307 |
+
|
308 |
+
def denoise(
|
309 |
+
args: argparse.Namespace,
|
310 |
+
model: flux_models.Flux,
|
311 |
+
img: torch.Tensor,
|
312 |
+
img_ids: torch.Tensor,
|
313 |
+
txt: torch.Tensor,
|
314 |
+
txt_ids: torch.Tensor,
|
315 |
+
vec: torch.Tensor,
|
316 |
+
timesteps: list[float],
|
317 |
+
guidance: float = 4.0,
|
318 |
+
t5_attn_mask: Optional[torch.Tensor] = None,
|
319 |
+
ae_outputs: torch.Tensor = None,
|
320 |
+
):
|
321 |
+
# this is ignored for schnell
|
322 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
323 |
+
img_ids = img_ids.to(img.device)
|
324 |
+
txt_ids = txt_ids.to(img.device)
|
325 |
+
vec = vec.to(img.device)
|
326 |
+
txt = txt.to(img.device)
|
327 |
+
|
328 |
+
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
329 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
330 |
+
model.prepare_block_swap_before_forward()
|
331 |
+
if args.frame_num == 4:
|
332 |
+
packed_latent_height, packed_latent_width = ae_outputs.shape[2]*2 // 2, ae_outputs.shape[3]*2 // 2
|
333 |
+
img = flux_utils.unpack_latents(img, packed_latent_height, packed_latent_width)
|
334 |
+
img[:,:, img.shape[2] // 2: img.shape[2], :img.shape[3] // 2] = ae_outputs
|
335 |
+
else:
|
336 |
+
packed_latent_height, packed_latent_width = ae_outputs.shape[2]*3 // 2, ae_outputs.shape[3]*3 // 2
|
337 |
+
img = flux_utils.unpack_latents(img, packed_latent_height, packed_latent_width)
|
338 |
+
img[:,:, 2*img.shape[2] // 3: img.shape[2], 2*img.shape[3] // 3:img.shape[3]] = ae_outputs
|
339 |
+
|
340 |
+
img = flux_utils.pack_latents(img)
|
341 |
+
pred = model(
|
342 |
+
img=img,
|
343 |
+
img_ids=img_ids,
|
344 |
+
txt=txt,
|
345 |
+
txt_ids=txt_ids,
|
346 |
+
y=vec,
|
347 |
+
timesteps=t_vec,
|
348 |
+
guidance=guidance_vec,
|
349 |
+
txt_attention_mask=t5_attn_mask,
|
350 |
+
)
|
351 |
+
|
352 |
+
img = img + (t_prev - t_curr) * pred
|
353 |
+
|
354 |
+
model.prepare_block_swap_before_forward()
|
355 |
+
return img
|
356 |
+
|
357 |
+
|
358 |
+
# endregion
|
359 |
+
|
360 |
+
|
361 |
+
# region train
|
362 |
+
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
363 |
+
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
364 |
+
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
365 |
+
timesteps = timesteps.to(device)
|
366 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
367 |
+
|
368 |
+
sigma = sigmas[step_indices].flatten()
|
369 |
+
while len(sigma.shape) < n_dim:
|
370 |
+
sigma = sigma.unsqueeze(-1)
|
371 |
+
return sigma
|
372 |
+
|
373 |
+
|
374 |
+
def compute_density_for_timestep_sampling(
|
375 |
+
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
376 |
+
):
|
377 |
+
"""Compute the density for sampling the timesteps when doing SD3 training.
|
378 |
+
|
379 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
380 |
+
|
381 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
382 |
+
"""
|
383 |
+
if weighting_scheme == "logit_normal":
|
384 |
+
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
385 |
+
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
386 |
+
u = torch.nn.functional.sigmoid(u)
|
387 |
+
elif weighting_scheme == "mode":
|
388 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
389 |
+
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
390 |
+
else:
|
391 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
392 |
+
return u
|
393 |
+
|
394 |
+
|
395 |
+
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
396 |
+
"""Computes loss weighting scheme for SD3 training.
|
397 |
+
|
398 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
399 |
+
|
400 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
401 |
+
"""
|
402 |
+
if weighting_scheme == "sigma_sqrt":
|
403 |
+
weighting = (sigmas**-2.0).float()
|
404 |
+
elif weighting_scheme == "cosmap":
|
405 |
+
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
406 |
+
weighting = 2 / (math.pi * bot)
|
407 |
+
else:
|
408 |
+
weighting = torch.ones_like(sigmas)
|
409 |
+
return weighting
|
410 |
+
|
411 |
+
|
412 |
+
def get_noisy_model_input_and_timesteps(
|
413 |
+
args, noise_scheduler, latents, noise, device, dtype
|
414 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
415 |
+
bsz, _, h, w = latents.shape
|
416 |
+
sigmas = None
|
417 |
+
|
418 |
+
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
419 |
+
# Simple random t-based noise sampling
|
420 |
+
if args.timestep_sampling == "sigmoid":
|
421 |
+
# https://github.com/XLabs-AI/x-flux/tree/main
|
422 |
+
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
423 |
+
else:
|
424 |
+
t = torch.rand((bsz,), device=device)
|
425 |
+
|
426 |
+
timesteps = t * 1000.0
|
427 |
+
t = t.view(-1, 1, 1, 1)
|
428 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
429 |
+
elif args.timestep_sampling == "shift":
|
430 |
+
shift = args.discrete_flow_shift
|
431 |
+
logits_norm = torch.randn(bsz, device=device)
|
432 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
433 |
+
timesteps = logits_norm.sigmoid()
|
434 |
+
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
435 |
+
|
436 |
+
t = timesteps.view(-1, 1, 1, 1)
|
437 |
+
timesteps = timesteps * 1000.0
|
438 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
439 |
+
elif args.timestep_sampling == "flux_shift":
|
440 |
+
logits_norm = torch.randn(bsz, device=device)
|
441 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
442 |
+
timesteps = logits_norm.sigmoid()
|
443 |
+
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
444 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
445 |
+
|
446 |
+
t = timesteps.view(-1, 1, 1, 1)
|
447 |
+
timesteps = timesteps * 1000.0
|
448 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
449 |
+
else:
|
450 |
+
# Sample a random timestep for each image
|
451 |
+
# for weighting schemes where we sample timesteps non-uniformly
|
452 |
+
u = compute_density_for_timestep_sampling(
|
453 |
+
weighting_scheme=args.weighting_scheme,
|
454 |
+
batch_size=bsz,
|
455 |
+
logit_mean=args.logit_mean,
|
456 |
+
logit_std=args.logit_std,
|
457 |
+
mode_scale=args.mode_scale,
|
458 |
+
)
|
459 |
+
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
460 |
+
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
461 |
+
|
462 |
+
# Add noise according to flow matching.
|
463 |
+
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
464 |
+
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
465 |
+
|
466 |
+
# 替换部分区域为原始latents
|
467 |
+
h, w = noisy_model_input.shape[2], noisy_model_input.shape[3]
|
468 |
+
# import pdb; pdb.set_trace()
|
469 |
+
if args.frame_num == 4:
|
470 |
+
noisy_model_input[:, :, h//2 : h, w//2 : w] = latents[:, :, h//2:h, w//2:w]
|
471 |
+
else:
|
472 |
+
noisy_model_input[:, :, 2*h//3 : h, 2*w//3 : w] = latents[:, :, 2*h//3:h, 2*w//3:w]
|
473 |
+
|
474 |
+
|
475 |
+
return noisy_model_input, timesteps, sigmas
|
476 |
+
|
477 |
+
|
478 |
+
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
479 |
+
weighting = None
|
480 |
+
if args.model_prediction_type == "raw":
|
481 |
+
pass
|
482 |
+
elif args.model_prediction_type == "additive":
|
483 |
+
# add the model_pred to the noisy_model_input
|
484 |
+
model_pred = model_pred + noisy_model_input
|
485 |
+
elif args.model_prediction_type == "sigma_scaled":
|
486 |
+
# apply sigma scaling
|
487 |
+
model_pred = model_pred * (-sigmas) + noisy_model_input
|
488 |
+
|
489 |
+
# these weighting schemes use a uniform timestep sampling
|
490 |
+
# and instead post-weight the loss
|
491 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
492 |
+
|
493 |
+
return model_pred, weighting
|
494 |
+
|
495 |
+
|
496 |
+
def save_models(
|
497 |
+
ckpt_path: str,
|
498 |
+
flux: flux_models.Flux,
|
499 |
+
sai_metadata: Optional[dict],
|
500 |
+
save_dtype: Optional[torch.dtype] = None,
|
501 |
+
use_mem_eff_save: bool = False,
|
502 |
+
):
|
503 |
+
state_dict = {}
|
504 |
+
|
505 |
+
def update_sd(prefix, sd):
|
506 |
+
for k, v in sd.items():
|
507 |
+
key = prefix + k
|
508 |
+
if save_dtype is not None and v.dtype != save_dtype:
|
509 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
510 |
+
state_dict[key] = v
|
511 |
+
|
512 |
+
update_sd("", flux.state_dict())
|
513 |
+
|
514 |
+
if not use_mem_eff_save:
|
515 |
+
save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
516 |
+
else:
|
517 |
+
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
518 |
+
|
519 |
+
|
520 |
+
def save_flux_model_on_train_end(
|
521 |
+
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
|
522 |
+
):
|
523 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
524 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
525 |
+
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
526 |
+
|
527 |
+
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
528 |
+
|
529 |
+
|
530 |
+
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
531 |
+
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
532 |
+
def save_flux_model_on_epoch_end_or_stepwise(
|
533 |
+
args: argparse.Namespace,
|
534 |
+
on_epoch_end: bool,
|
535 |
+
accelerator,
|
536 |
+
save_dtype: torch.dtype,
|
537 |
+
epoch: int,
|
538 |
+
num_train_epochs: int,
|
539 |
+
global_step: int,
|
540 |
+
flux: flux_models.Flux,
|
541 |
+
):
|
542 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
543 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
544 |
+
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
545 |
+
|
546 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
547 |
+
args,
|
548 |
+
on_epoch_end,
|
549 |
+
accelerator,
|
550 |
+
True,
|
551 |
+
True,
|
552 |
+
epoch,
|
553 |
+
num_train_epochs,
|
554 |
+
global_step,
|
555 |
+
sd_saver,
|
556 |
+
None,
|
557 |
+
)
|
558 |
+
|
559 |
+
|
560 |
+
# endregion
|
561 |
+
|
562 |
+
|
563 |
+
def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
564 |
+
parser.add_argument(
|
565 |
+
"--clip_l",
|
566 |
+
type=str,
|
567 |
+
help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提",
|
568 |
+
)
|
569 |
+
parser.add_argument(
|
570 |
+
"--t5xxl",
|
571 |
+
type=str,
|
572 |
+
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
|
573 |
+
)
|
574 |
+
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
|
575 |
+
parser.add_argument(
|
576 |
+
"--t5xxl_max_token_length",
|
577 |
+
type=int,
|
578 |
+
default=None,
|
579 |
+
help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
|
580 |
+
" / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
|
581 |
+
)
|
582 |
+
parser.add_argument(
|
583 |
+
"--apply_t5_attn_mask",
|
584 |
+
action="store_true",
|
585 |
+
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
|
586 |
+
)
|
587 |
+
parser.add_argument(
|
588 |
+
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
589 |
+
)
|
590 |
+
parser.add_argument(
|
591 |
+
"--cache_text_encoder_outputs_to_disk",
|
592 |
+
action="store_true",
|
593 |
+
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
594 |
+
)
|
595 |
+
parser.add_argument(
|
596 |
+
"--text_encoder_batch_size",
|
597 |
+
type=int,
|
598 |
+
default=None,
|
599 |
+
help="text encoder batch size (default: None, use dataset's batch size)"
|
600 |
+
+ " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)",
|
601 |
+
)
|
602 |
+
parser.add_argument(
|
603 |
+
"--disable_mmap_load_safetensors",
|
604 |
+
action="store_true",
|
605 |
+
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
|
606 |
+
)
|
607 |
+
|
608 |
+
# copy from Diffusers
|
609 |
+
parser.add_argument(
|
610 |
+
"--weighting_scheme",
|
611 |
+
type=str,
|
612 |
+
default="none",
|
613 |
+
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
|
614 |
+
)
|
615 |
+
parser.add_argument(
|
616 |
+
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
|
617 |
+
)
|
618 |
+
parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
|
619 |
+
parser.add_argument(
|
620 |
+
"--mode_scale",
|
621 |
+
type=float,
|
622 |
+
default=1.29,
|
623 |
+
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
624 |
+
)
|
625 |
+
parser.add_argument(
|
626 |
+
"--guidance_scale",
|
627 |
+
type=float,
|
628 |
+
default=3.5,
|
629 |
+
help="the FLUX.1 dev variant is a guidance distilled model",
|
630 |
+
)
|
631 |
+
|
632 |
+
parser.add_argument(
|
633 |
+
"--timestep_sampling",
|
634 |
+
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
|
635 |
+
default="sigma",
|
636 |
+
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
|
637 |
+
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
|
638 |
+
)
|
639 |
+
parser.add_argument(
|
640 |
+
"--sigmoid_scale",
|
641 |
+
type=float,
|
642 |
+
default=1.0,
|
643 |
+
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
|
644 |
+
)
|
645 |
+
parser.add_argument(
|
646 |
+
"--model_prediction_type",
|
647 |
+
choices=["raw", "additive", "sigma_scaled"],
|
648 |
+
default="sigma_scaled",
|
649 |
+
help="How to interpret and process the model prediction: "
|
650 |
+
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
|
651 |
+
" / モデル予測の解釈と処理方法:"
|
652 |
+
"raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
|
653 |
+
)
|
654 |
+
parser.add_argument(
|
655 |
+
"--discrete_flow_shift",
|
656 |
+
type=float,
|
657 |
+
default=3.0,
|
658 |
+
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
659 |
+
)
|
library/flux_utils.py
ADDED
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import replace
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
import einops
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from safetensors.torch import load_file
|
9 |
+
from safetensors import safe_open
|
10 |
+
from accelerate import init_empty_weights
|
11 |
+
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
|
12 |
+
|
13 |
+
from library.utils import setup_logging
|
14 |
+
|
15 |
+
setup_logging()
|
16 |
+
import logging
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
from library import flux_models
|
21 |
+
from library.utils import load_safetensors
|
22 |
+
|
23 |
+
MODEL_VERSION_FLUX_V1 = "flux1"
|
24 |
+
MODEL_NAME_DEV = "dev"
|
25 |
+
MODEL_NAME_SCHNELL = "schnell"
|
26 |
+
|
27 |
+
|
28 |
+
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
|
29 |
+
"""
|
30 |
+
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
|
31 |
+
|
32 |
+
Args:
|
33 |
+
ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
Tuple[bool, bool, Tuple[int, int], List[str]]:
|
37 |
+
- bool: Diffusersかどうかを示すフラグ。
|
38 |
+
- bool: Schnellかどうかを示すフラグ。
|
39 |
+
- Tuple[int, int]: ダブルブロックとシングルブロックの数。
|
40 |
+
- List[str]: チェックポイントに含まれるキーのリスト。
|
41 |
+
"""
|
42 |
+
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks
|
43 |
+
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
|
44 |
+
|
45 |
+
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
|
46 |
+
ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
|
47 |
+
if "00001-of-00003" in ckpt_path:
|
48 |
+
ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
|
49 |
+
else:
|
50 |
+
ckpt_paths = [ckpt_path]
|
51 |
+
|
52 |
+
keys = []
|
53 |
+
for ckpt_path in ckpt_paths:
|
54 |
+
with safe_open(ckpt_path, framework="pt") as f:
|
55 |
+
keys.extend(f.keys())
|
56 |
+
|
57 |
+
# if the key has annoying prefix, remove it
|
58 |
+
if keys[0].startswith("model.diffusion_model."):
|
59 |
+
keys = [key.replace("model.diffusion_model.", "") for key in keys]
|
60 |
+
|
61 |
+
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
|
62 |
+
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
|
63 |
+
|
64 |
+
# check number of double and single blocks
|
65 |
+
if not is_diffusers:
|
66 |
+
max_double_block_index = max(
|
67 |
+
[int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
|
68 |
+
)
|
69 |
+
max_single_block_index = max(
|
70 |
+
[int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
|
71 |
+
)
|
72 |
+
else:
|
73 |
+
max_double_block_index = max(
|
74 |
+
[
|
75 |
+
int(key.split(".")[1])
|
76 |
+
for key in keys
|
77 |
+
if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
|
78 |
+
]
|
79 |
+
)
|
80 |
+
max_single_block_index = max(
|
81 |
+
[
|
82 |
+
int(key.split(".")[1])
|
83 |
+
for key in keys
|
84 |
+
if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
|
85 |
+
]
|
86 |
+
)
|
87 |
+
|
88 |
+
num_double_blocks = max_double_block_index + 1
|
89 |
+
num_single_blocks = max_single_block_index + 1
|
90 |
+
|
91 |
+
return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
|
92 |
+
|
93 |
+
|
94 |
+
def load_flow_model(
|
95 |
+
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
96 |
+
) -> Tuple[bool, flux_models.Flux]:
|
97 |
+
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
|
98 |
+
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
99 |
+
|
100 |
+
# build model
|
101 |
+
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
|
102 |
+
with torch.device("meta"):
|
103 |
+
params = flux_models.configs[name].params
|
104 |
+
|
105 |
+
# set the number of blocks
|
106 |
+
if params.depth != num_double_blocks:
|
107 |
+
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
|
108 |
+
params = replace(params, depth=num_double_blocks)
|
109 |
+
if params.depth_single_blocks != num_single_blocks:
|
110 |
+
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
|
111 |
+
params = replace(params, depth_single_blocks=num_single_blocks)
|
112 |
+
|
113 |
+
model = flux_models.Flux(params)
|
114 |
+
if dtype is not None:
|
115 |
+
model = model.to(dtype)
|
116 |
+
|
117 |
+
# load_sft doesn't support torch.device
|
118 |
+
logger.info(f"Loading state dict from {ckpt_path}")
|
119 |
+
sd = {}
|
120 |
+
for ckpt_path in ckpt_paths:
|
121 |
+
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
|
122 |
+
|
123 |
+
# convert Diffusers to BFL
|
124 |
+
if is_diffusers:
|
125 |
+
logger.info("Converting Diffusers to BFL")
|
126 |
+
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
|
127 |
+
logger.info("Converted Diffusers to BFL")
|
128 |
+
|
129 |
+
# if the key has annoying prefix, remove it
|
130 |
+
for key in list(sd.keys()):
|
131 |
+
new_key = key.replace("model.diffusion_model.", "")
|
132 |
+
if new_key == key:
|
133 |
+
break # the model doesn't have annoying prefix
|
134 |
+
sd[new_key] = sd.pop(key)
|
135 |
+
|
136 |
+
info = model.load_state_dict(sd, strict=False, assign=True)
|
137 |
+
logger.info(f"Loaded Flux: {info}")
|
138 |
+
return is_schnell, model
|
139 |
+
|
140 |
+
|
141 |
+
def load_ae(
|
142 |
+
ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
|
143 |
+
) -> flux_models.AutoEncoder:
|
144 |
+
logger.info("Building AutoEncoder")
|
145 |
+
with torch.device("meta"):
|
146 |
+
# dev and schnell have the same AE params
|
147 |
+
ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)
|
148 |
+
|
149 |
+
logger.info(f"Loading state dict from {ckpt_path}")
|
150 |
+
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
151 |
+
info = ae.load_state_dict(sd, strict=False, assign=True)
|
152 |
+
logger.info(f"Loaded AE: {info}")
|
153 |
+
return ae
|
154 |
+
|
155 |
+
|
156 |
+
def load_clip_l(
|
157 |
+
ckpt_path: Optional[str],
|
158 |
+
dtype: torch.dtype,
|
159 |
+
device: Union[str, torch.device],
|
160 |
+
disable_mmap: bool = False,
|
161 |
+
state_dict: Optional[dict] = None,
|
162 |
+
) -> CLIPTextModel:
|
163 |
+
logger.info("Building CLIP-L")
|
164 |
+
CLIPL_CONFIG = {
|
165 |
+
"_name_or_path": "clip-vit-large-patch14/",
|
166 |
+
"architectures": ["CLIPModel"],
|
167 |
+
"initializer_factor": 1.0,
|
168 |
+
"logit_scale_init_value": 2.6592,
|
169 |
+
"model_type": "clip",
|
170 |
+
"projection_dim": 768,
|
171 |
+
# "text_config": {
|
172 |
+
"_name_or_path": "",
|
173 |
+
"add_cross_attention": False,
|
174 |
+
"architectures": None,
|
175 |
+
"attention_dropout": 0.0,
|
176 |
+
"bad_words_ids": None,
|
177 |
+
"bos_token_id": 0,
|
178 |
+
"chunk_size_feed_forward": 0,
|
179 |
+
"cross_attention_hidden_size": None,
|
180 |
+
"decoder_start_token_id": None,
|
181 |
+
"diversity_penalty": 0.0,
|
182 |
+
"do_sample": False,
|
183 |
+
"dropout": 0.0,
|
184 |
+
"early_stopping": False,
|
185 |
+
"encoder_no_repeat_ngram_size": 0,
|
186 |
+
"eos_token_id": 2,
|
187 |
+
"finetuning_task": None,
|
188 |
+
"forced_bos_token_id": None,
|
189 |
+
"forced_eos_token_id": None,
|
190 |
+
"hidden_act": "quick_gelu",
|
191 |
+
"hidden_size": 768,
|
192 |
+
"id2label": {"0": "LABEL_0", "1": "LABEL_1"},
|
193 |
+
"initializer_factor": 1.0,
|
194 |
+
"initializer_range": 0.02,
|
195 |
+
"intermediate_size": 3072,
|
196 |
+
"is_decoder": False,
|
197 |
+
"is_encoder_decoder": False,
|
198 |
+
"label2id": {"LABEL_0": 0, "LABEL_1": 1},
|
199 |
+
"layer_norm_eps": 1e-05,
|
200 |
+
"length_penalty": 1.0,
|
201 |
+
"max_length": 20,
|
202 |
+
"max_position_embeddings": 77,
|
203 |
+
"min_length": 0,
|
204 |
+
"model_type": "clip_text_model",
|
205 |
+
"no_repeat_ngram_size": 0,
|
206 |
+
"num_attention_heads": 12,
|
207 |
+
"num_beam_groups": 1,
|
208 |
+
"num_beams": 1,
|
209 |
+
"num_hidden_layers": 12,
|
210 |
+
"num_return_sequences": 1,
|
211 |
+
"output_attentions": False,
|
212 |
+
"output_hidden_states": False,
|
213 |
+
"output_scores": False,
|
214 |
+
"pad_token_id": 1,
|
215 |
+
"prefix": None,
|
216 |
+
"problem_type": None,
|
217 |
+
"projection_dim": 768,
|
218 |
+
"pruned_heads": {},
|
219 |
+
"remove_invalid_values": False,
|
220 |
+
"repetition_penalty": 1.0,
|
221 |
+
"return_dict": True,
|
222 |
+
"return_dict_in_generate": False,
|
223 |
+
"sep_token_id": None,
|
224 |
+
"task_specific_params": None,
|
225 |
+
"temperature": 1.0,
|
226 |
+
"tie_encoder_decoder": False,
|
227 |
+
"tie_word_embeddings": True,
|
228 |
+
"tokenizer_class": None,
|
229 |
+
"top_k": 50,
|
230 |
+
"top_p": 1.0,
|
231 |
+
"torch_dtype": None,
|
232 |
+
"torchscript": False,
|
233 |
+
"transformers_version": "4.16.0.dev0",
|
234 |
+
"use_bfloat16": False,
|
235 |
+
"vocab_size": 49408,
|
236 |
+
"hidden_act": "gelu",
|
237 |
+
"hidden_size": 1280,
|
238 |
+
"intermediate_size": 5120,
|
239 |
+
"num_attention_heads": 20,
|
240 |
+
"num_hidden_layers": 32,
|
241 |
+
# },
|
242 |
+
# "text_config_dict": {
|
243 |
+
"hidden_size": 768,
|
244 |
+
"intermediate_size": 3072,
|
245 |
+
"num_attention_heads": 12,
|
246 |
+
"num_hidden_layers": 12,
|
247 |
+
"projection_dim": 768,
|
248 |
+
# },
|
249 |
+
# "torch_dtype": "float32",
|
250 |
+
# "transformers_version": None,
|
251 |
+
}
|
252 |
+
config = CLIPConfig(**CLIPL_CONFIG)
|
253 |
+
with init_empty_weights():
|
254 |
+
clip = CLIPTextModel._from_config(config)
|
255 |
+
|
256 |
+
if state_dict is not None:
|
257 |
+
sd = state_dict
|
258 |
+
else:
|
259 |
+
logger.info(f"Loading state dict from {ckpt_path}")
|
260 |
+
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
261 |
+
info = clip.load_state_dict(sd, strict=False, assign=True)
|
262 |
+
logger.info(f"Loaded CLIP-L: {info}")
|
263 |
+
return clip
|
264 |
+
|
265 |
+
|
266 |
+
def load_t5xxl(
|
267 |
+
ckpt_path: str,
|
268 |
+
dtype: Optional[torch.dtype],
|
269 |
+
device: Union[str, torch.device],
|
270 |
+
disable_mmap: bool = False,
|
271 |
+
state_dict: Optional[dict] = None,
|
272 |
+
) -> T5EncoderModel:
|
273 |
+
T5_CONFIG_JSON = """
|
274 |
+
{
|
275 |
+
"architectures": [
|
276 |
+
"T5EncoderModel"
|
277 |
+
],
|
278 |
+
"classifier_dropout": 0.0,
|
279 |
+
"d_ff": 10240,
|
280 |
+
"d_kv": 64,
|
281 |
+
"d_model": 4096,
|
282 |
+
"decoder_start_token_id": 0,
|
283 |
+
"dense_act_fn": "gelu_new",
|
284 |
+
"dropout_rate": 0.1,
|
285 |
+
"eos_token_id": 1,
|
286 |
+
"feed_forward_proj": "gated-gelu",
|
287 |
+
"initializer_factor": 1.0,
|
288 |
+
"is_encoder_decoder": true,
|
289 |
+
"is_gated_act": true,
|
290 |
+
"layer_norm_epsilon": 1e-06,
|
291 |
+
"model_type": "t5",
|
292 |
+
"num_decoder_layers": 24,
|
293 |
+
"num_heads": 64,
|
294 |
+
"num_layers": 24,
|
295 |
+
"output_past": true,
|
296 |
+
"pad_token_id": 0,
|
297 |
+
"relative_attention_max_distance": 128,
|
298 |
+
"relative_attention_num_buckets": 32,
|
299 |
+
"tie_word_embeddings": false,
|
300 |
+
"torch_dtype": "float16",
|
301 |
+
"transformers_version": "4.41.2",
|
302 |
+
"use_cache": true,
|
303 |
+
"vocab_size": 32128
|
304 |
+
}
|
305 |
+
"""
|
306 |
+
config = json.loads(T5_CONFIG_JSON)
|
307 |
+
config = T5Config(**config)
|
308 |
+
with init_empty_weights():
|
309 |
+
t5xxl = T5EncoderModel._from_config(config)
|
310 |
+
|
311 |
+
if state_dict is not None:
|
312 |
+
sd = state_dict
|
313 |
+
else:
|
314 |
+
logger.info(f"Loading state dict from {ckpt_path}")
|
315 |
+
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
316 |
+
info = t5xxl.load_state_dict(sd, strict=False, assign=True)
|
317 |
+
logger.info(f"Loaded T5xxl: {info}")
|
318 |
+
return t5xxl
|
319 |
+
|
320 |
+
|
321 |
+
def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
|
322 |
+
# nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
|
323 |
+
return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
|
324 |
+
|
325 |
+
|
326 |
+
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
|
327 |
+
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
|
328 |
+
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
|
329 |
+
img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
|
330 |
+
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
331 |
+
return img_ids
|
332 |
+
|
333 |
+
|
334 |
+
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
|
335 |
+
"""
|
336 |
+
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
|
337 |
+
"""
|
338 |
+
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
|
339 |
+
return x
|
340 |
+
|
341 |
+
|
342 |
+
def pack_latents(x: torch.Tensor) -> torch.Tensor:
|
343 |
+
"""
|
344 |
+
x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
|
345 |
+
"""
|
346 |
+
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
347 |
+
return x
|
348 |
+
|
349 |
+
|
350 |
+
# region Diffusers
|
351 |
+
|
352 |
+
NUM_DOUBLE_BLOCKS = 19
|
353 |
+
NUM_SINGLE_BLOCKS = 38
|
354 |
+
|
355 |
+
BFL_TO_DIFFUSERS_MAP = {
|
356 |
+
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
|
357 |
+
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
|
358 |
+
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
|
359 |
+
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
|
360 |
+
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
|
361 |
+
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
|
362 |
+
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
|
363 |
+
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
|
364 |
+
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
|
365 |
+
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
|
366 |
+
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
|
367 |
+
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
|
368 |
+
"txt_in.weight": ["context_embedder.weight"],
|
369 |
+
"txt_in.bias": ["context_embedder.bias"],
|
370 |
+
"img_in.weight": ["x_embedder.weight"],
|
371 |
+
"img_in.bias": ["x_embedder.bias"],
|
372 |
+
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
|
373 |
+
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
|
374 |
+
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
|
375 |
+
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
|
376 |
+
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
|
377 |
+
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
|
378 |
+
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
|
379 |
+
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
|
380 |
+
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
|
381 |
+
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
|
382 |
+
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
|
383 |
+
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
|
384 |
+
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
|
385 |
+
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
|
386 |
+
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
|
387 |
+
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
|
388 |
+
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
|
389 |
+
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
|
390 |
+
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
|
391 |
+
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
|
392 |
+
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
|
393 |
+
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
|
394 |
+
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
|
395 |
+
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
|
396 |
+
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
|
397 |
+
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
|
398 |
+
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
|
399 |
+
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
|
400 |
+
"single_blocks.().linear2.weight": ["proj_out.weight"],
|
401 |
+
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
|
402 |
+
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
|
403 |
+
"single_blocks.().linear2.weight": ["proj_out.weight"],
|
404 |
+
"single_blocks.().linear2.bias": ["proj_out.bias"],
|
405 |
+
"final_layer.linear.weight": ["proj_out.weight"],
|
406 |
+
"final_layer.linear.bias": ["proj_out.bias"],
|
407 |
+
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
|
408 |
+
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
|
409 |
+
}
|
410 |
+
|
411 |
+
|
412 |
+
def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
|
413 |
+
# make reverse map from diffusers map
|
414 |
+
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
|
415 |
+
for b in range(num_double_blocks):
|
416 |
+
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
417 |
+
if key.startswith("double_blocks."):
|
418 |
+
block_prefix = f"transformer_blocks.{b}."
|
419 |
+
for i, weight in enumerate(weights):
|
420 |
+
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
|
421 |
+
for b in range(num_single_blocks):
|
422 |
+
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
423 |
+
if key.startswith("single_blocks."):
|
424 |
+
block_prefix = f"single_transformer_blocks.{b}."
|
425 |
+
for i, weight in enumerate(weights):
|
426 |
+
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
|
427 |
+
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
428 |
+
if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
|
429 |
+
for i, weight in enumerate(weights):
|
430 |
+
diffusers_to_bfl_map[weight] = (i, key)
|
431 |
+
return diffusers_to_bfl_map
|
432 |
+
|
433 |
+
|
434 |
+
def convert_diffusers_sd_to_bfl(
|
435 |
+
diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
|
436 |
+
) -> dict[str, torch.Tensor]:
|
437 |
+
diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)
|
438 |
+
|
439 |
+
# iterate over three safetensors files to reduce memory usage
|
440 |
+
flux_sd = {}
|
441 |
+
for diffusers_key, tensor in diffusers_sd.items():
|
442 |
+
if diffusers_key in diffusers_to_bfl_map:
|
443 |
+
index, bfl_key = diffusers_to_bfl_map[diffusers_key]
|
444 |
+
if bfl_key not in flux_sd:
|
445 |
+
flux_sd[bfl_key] = []
|
446 |
+
flux_sd[bfl_key].append((index, tensor))
|
447 |
+
else:
|
448 |
+
logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
|
449 |
+
raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}")
|
450 |
+
|
451 |
+
# concat tensors if multiple tensors are mapped to a single key, sort by index
|
452 |
+
for key, values in flux_sd.items():
|
453 |
+
if len(values) == 1:
|
454 |
+
flux_sd[key] = values[0][1]
|
455 |
+
else:
|
456 |
+
flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])
|
457 |
+
|
458 |
+
# special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
|
459 |
+
def swap_scale_shift(weight):
|
460 |
+
shift, scale = weight.chunk(2, dim=0)
|
461 |
+
new_weight = torch.cat([scale, shift], dim=0)
|
462 |
+
return new_weight
|
463 |
+
|
464 |
+
if "final_layer.adaLN_modulation.1.weight" in flux_sd:
|
465 |
+
flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
|
466 |
+
if "final_layer.adaLN_modulation.1.bias" in flux_sd:
|
467 |
+
flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])
|
468 |
+
|
469 |
+
return flux_sd
|
470 |
+
|
471 |
+
|
472 |
+
# endregion
|
library/huggingface_util.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union, BinaryIO
|
2 |
+
from huggingface_hub import HfApi
|
3 |
+
from pathlib import Path
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
from library.utils import fire_in_thread
|
7 |
+
from library.utils import setup_logging
|
8 |
+
setup_logging()
|
9 |
+
import logging
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
|
13 |
+
api = HfApi(
|
14 |
+
token=token,
|
15 |
+
)
|
16 |
+
try:
|
17 |
+
api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
18 |
+
return True
|
19 |
+
except:
|
20 |
+
return False
|
21 |
+
|
22 |
+
|
23 |
+
def upload(
|
24 |
+
args: argparse.Namespace,
|
25 |
+
src: Union[str, Path, bytes, BinaryIO],
|
26 |
+
dest_suffix: str = "",
|
27 |
+
force_sync_upload: bool = False,
|
28 |
+
):
|
29 |
+
repo_id = args.huggingface_repo_id
|
30 |
+
repo_type = args.huggingface_repo_type
|
31 |
+
token = args.huggingface_token
|
32 |
+
path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
|
33 |
+
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
|
34 |
+
api = HfApi(token=token)
|
35 |
+
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
36 |
+
try:
|
37 |
+
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
38 |
+
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
|
39 |
+
logger.error("===========================================")
|
40 |
+
logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
|
41 |
+
logger.error("===========================================")
|
42 |
+
|
43 |
+
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
|
44 |
+
|
45 |
+
def uploader():
|
46 |
+
try:
|
47 |
+
if is_folder:
|
48 |
+
api.upload_folder(
|
49 |
+
repo_id=repo_id,
|
50 |
+
repo_type=repo_type,
|
51 |
+
folder_path=src,
|
52 |
+
path_in_repo=path_in_repo,
|
53 |
+
)
|
54 |
+
else:
|
55 |
+
api.upload_file(
|
56 |
+
repo_id=repo_id,
|
57 |
+
repo_type=repo_type,
|
58 |
+
path_or_fileobj=src,
|
59 |
+
path_in_repo=path_in_repo,
|
60 |
+
)
|
61 |
+
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
|
62 |
+
logger.error("===========================================")
|
63 |
+
logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
|
64 |
+
logger.error("===========================================")
|
65 |
+
|
66 |
+
if args.async_upload and not force_sync_upload:
|
67 |
+
fire_in_thread(uploader)
|
68 |
+
else:
|
69 |
+
uploader()
|
70 |
+
|
71 |
+
|
72 |
+
def list_dir(
|
73 |
+
repo_id: str,
|
74 |
+
subfolder: str,
|
75 |
+
repo_type: str,
|
76 |
+
revision: str = "main",
|
77 |
+
token: str = None,
|
78 |
+
):
|
79 |
+
api = HfApi(
|
80 |
+
token=token,
|
81 |
+
)
|
82 |
+
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
83 |
+
file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
|
84 |
+
return file_list
|
library/hypernetwork.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from diffusers.models.attention_processor import (
|
4 |
+
Attention,
|
5 |
+
AttnProcessor2_0,
|
6 |
+
SlicedAttnProcessor,
|
7 |
+
XFormersAttnProcessor
|
8 |
+
)
|
9 |
+
|
10 |
+
try:
|
11 |
+
import xformers.ops
|
12 |
+
except:
|
13 |
+
xformers = None
|
14 |
+
|
15 |
+
|
16 |
+
loaded_networks = []
|
17 |
+
|
18 |
+
|
19 |
+
def apply_single_hypernetwork(
|
20 |
+
hypernetwork, hidden_states, encoder_hidden_states
|
21 |
+
):
|
22 |
+
context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
|
23 |
+
return context_k, context_v
|
24 |
+
|
25 |
+
|
26 |
+
def apply_hypernetworks(context_k, context_v, layer=None):
|
27 |
+
if len(loaded_networks) == 0:
|
28 |
+
return context_v, context_v
|
29 |
+
for hypernetwork in loaded_networks:
|
30 |
+
context_k, context_v = hypernetwork.forward(context_k, context_v)
|
31 |
+
|
32 |
+
context_k = context_k.to(dtype=context_k.dtype)
|
33 |
+
context_v = context_v.to(dtype=context_k.dtype)
|
34 |
+
|
35 |
+
return context_k, context_v
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
def xformers_forward(
|
40 |
+
self: XFormersAttnProcessor,
|
41 |
+
attn: Attention,
|
42 |
+
hidden_states: torch.Tensor,
|
43 |
+
encoder_hidden_states: torch.Tensor = None,
|
44 |
+
attention_mask: torch.Tensor = None,
|
45 |
+
):
|
46 |
+
batch_size, sequence_length, _ = (
|
47 |
+
hidden_states.shape
|
48 |
+
if encoder_hidden_states is None
|
49 |
+
else encoder_hidden_states.shape
|
50 |
+
)
|
51 |
+
|
52 |
+
attention_mask = attn.prepare_attention_mask(
|
53 |
+
attention_mask, sequence_length, batch_size
|
54 |
+
)
|
55 |
+
|
56 |
+
query = attn.to_q(hidden_states)
|
57 |
+
|
58 |
+
if encoder_hidden_states is None:
|
59 |
+
encoder_hidden_states = hidden_states
|
60 |
+
elif attn.norm_cross:
|
61 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
62 |
+
|
63 |
+
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
64 |
+
|
65 |
+
key = attn.to_k(context_k)
|
66 |
+
value = attn.to_v(context_v)
|
67 |
+
|
68 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
69 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
70 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
71 |
+
|
72 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
73 |
+
query,
|
74 |
+
key,
|
75 |
+
value,
|
76 |
+
attn_bias=attention_mask,
|
77 |
+
op=self.attention_op,
|
78 |
+
scale=attn.scale,
|
79 |
+
)
|
80 |
+
hidden_states = hidden_states.to(query.dtype)
|
81 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
82 |
+
|
83 |
+
# linear proj
|
84 |
+
hidden_states = attn.to_out[0](hidden_states)
|
85 |
+
# dropout
|
86 |
+
hidden_states = attn.to_out[1](hidden_states)
|
87 |
+
return hidden_states
|
88 |
+
|
89 |
+
|
90 |
+
def sliced_attn_forward(
|
91 |
+
self: SlicedAttnProcessor,
|
92 |
+
attn: Attention,
|
93 |
+
hidden_states: torch.Tensor,
|
94 |
+
encoder_hidden_states: torch.Tensor = None,
|
95 |
+
attention_mask: torch.Tensor = None,
|
96 |
+
):
|
97 |
+
batch_size, sequence_length, _ = (
|
98 |
+
hidden_states.shape
|
99 |
+
if encoder_hidden_states is None
|
100 |
+
else encoder_hidden_states.shape
|
101 |
+
)
|
102 |
+
attention_mask = attn.prepare_attention_mask(
|
103 |
+
attention_mask, sequence_length, batch_size
|
104 |
+
)
|
105 |
+
|
106 |
+
query = attn.to_q(hidden_states)
|
107 |
+
dim = query.shape[-1]
|
108 |
+
query = attn.head_to_batch_dim(query)
|
109 |
+
|
110 |
+
if encoder_hidden_states is None:
|
111 |
+
encoder_hidden_states = hidden_states
|
112 |
+
elif attn.norm_cross:
|
113 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
114 |
+
|
115 |
+
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
116 |
+
|
117 |
+
key = attn.to_k(context_k)
|
118 |
+
value = attn.to_v(context_v)
|
119 |
+
key = attn.head_to_batch_dim(key)
|
120 |
+
value = attn.head_to_batch_dim(value)
|
121 |
+
|
122 |
+
batch_size_attention, query_tokens, _ = query.shape
|
123 |
+
hidden_states = torch.zeros(
|
124 |
+
(batch_size_attention, query_tokens, dim // attn.heads),
|
125 |
+
device=query.device,
|
126 |
+
dtype=query.dtype,
|
127 |
+
)
|
128 |
+
|
129 |
+
for i in range(batch_size_attention // self.slice_size):
|
130 |
+
start_idx = i * self.slice_size
|
131 |
+
end_idx = (i + 1) * self.slice_size
|
132 |
+
|
133 |
+
query_slice = query[start_idx:end_idx]
|
134 |
+
key_slice = key[start_idx:end_idx]
|
135 |
+
attn_mask_slice = (
|
136 |
+
attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
137 |
+
)
|
138 |
+
|
139 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
140 |
+
|
141 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
142 |
+
|
143 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
144 |
+
|
145 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
146 |
+
|
147 |
+
# linear proj
|
148 |
+
hidden_states = attn.to_out[0](hidden_states)
|
149 |
+
# dropout
|
150 |
+
hidden_states = attn.to_out[1](hidden_states)
|
151 |
+
|
152 |
+
return hidden_states
|
153 |
+
|
154 |
+
|
155 |
+
def v2_0_forward(
|
156 |
+
self: AttnProcessor2_0,
|
157 |
+
attn: Attention,
|
158 |
+
hidden_states,
|
159 |
+
encoder_hidden_states=None,
|
160 |
+
attention_mask=None,
|
161 |
+
):
|
162 |
+
batch_size, sequence_length, _ = (
|
163 |
+
hidden_states.shape
|
164 |
+
if encoder_hidden_states is None
|
165 |
+
else encoder_hidden_states.shape
|
166 |
+
)
|
167 |
+
inner_dim = hidden_states.shape[-1]
|
168 |
+
|
169 |
+
if attention_mask is not None:
|
170 |
+
attention_mask = attn.prepare_attention_mask(
|
171 |
+
attention_mask, sequence_length, batch_size
|
172 |
+
)
|
173 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
174 |
+
# (batch, heads, source_length, target_length)
|
175 |
+
attention_mask = attention_mask.view(
|
176 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
177 |
+
)
|
178 |
+
|
179 |
+
query = attn.to_q(hidden_states)
|
180 |
+
|
181 |
+
if encoder_hidden_states is None:
|
182 |
+
encoder_hidden_states = hidden_states
|
183 |
+
elif attn.norm_cross:
|
184 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
185 |
+
|
186 |
+
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
187 |
+
|
188 |
+
key = attn.to_k(context_k)
|
189 |
+
value = attn.to_v(context_v)
|
190 |
+
|
191 |
+
head_dim = inner_dim // attn.heads
|
192 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
193 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
194 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
195 |
+
|
196 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
197 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
198 |
+
hidden_states = F.scaled_dot_product_attention(
|
199 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
200 |
+
)
|
201 |
+
|
202 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
203 |
+
batch_size, -1, attn.heads * head_dim
|
204 |
+
)
|
205 |
+
hidden_states = hidden_states.to(query.dtype)
|
206 |
+
|
207 |
+
# linear proj
|
208 |
+
hidden_states = attn.to_out[0](hidden_states)
|
209 |
+
# dropout
|
210 |
+
hidden_states = attn.to_out[1](hidden_states)
|
211 |
+
return hidden_states
|
212 |
+
|
213 |
+
|
214 |
+
def replace_attentions_for_hypernetwork():
|
215 |
+
import diffusers.models.attention_processor
|
216 |
+
|
217 |
+
diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
|
218 |
+
xformers_forward
|
219 |
+
)
|
220 |
+
diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
|
221 |
+
sliced_attn_forward
|
222 |
+
)
|
223 |
+
diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
|
library/ipex/__init__.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import contextlib
|
4 |
+
import torch
|
5 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
6 |
+
from .hijacks import ipex_hijacks
|
7 |
+
|
8 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
9 |
+
|
10 |
+
def ipex_init(): # pylint: disable=too-many-statements
|
11 |
+
try:
|
12 |
+
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
|
13 |
+
return True, "Skipping IPEX hijack"
|
14 |
+
else:
|
15 |
+
# Replace cuda with xpu:
|
16 |
+
torch.cuda.current_device = torch.xpu.current_device
|
17 |
+
torch.cuda.current_stream = torch.xpu.current_stream
|
18 |
+
torch.cuda.device = torch.xpu.device
|
19 |
+
torch.cuda.device_count = torch.xpu.device_count
|
20 |
+
torch.cuda.device_of = torch.xpu.device_of
|
21 |
+
torch.cuda.get_device_name = torch.xpu.get_device_name
|
22 |
+
torch.cuda.get_device_properties = torch.xpu.get_device_properties
|
23 |
+
torch.cuda.init = torch.xpu.init
|
24 |
+
torch.cuda.is_available = torch.xpu.is_available
|
25 |
+
torch.cuda.is_initialized = torch.xpu.is_initialized
|
26 |
+
torch.cuda.is_current_stream_capturing = lambda: False
|
27 |
+
torch.cuda.set_device = torch.xpu.set_device
|
28 |
+
torch.cuda.stream = torch.xpu.stream
|
29 |
+
torch.cuda.synchronize = torch.xpu.synchronize
|
30 |
+
torch.cuda.Event = torch.xpu.Event
|
31 |
+
torch.cuda.Stream = torch.xpu.Stream
|
32 |
+
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
33 |
+
torch.Tensor.cuda = torch.Tensor.xpu
|
34 |
+
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
35 |
+
torch.nn.Module.cuda = torch.nn.Module.xpu
|
36 |
+
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
37 |
+
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
38 |
+
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
39 |
+
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
40 |
+
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
|
41 |
+
torch.cuda._tls = torch.xpu.lazy_init._tls
|
42 |
+
torch.cuda.threading = torch.xpu.lazy_init.threading
|
43 |
+
torch.cuda.traceback = torch.xpu.lazy_init.traceback
|
44 |
+
torch.cuda.Optional = torch.xpu.Optional
|
45 |
+
torch.cuda.__cached__ = torch.xpu.__cached__
|
46 |
+
torch.cuda.__loader__ = torch.xpu.__loader__
|
47 |
+
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
|
48 |
+
torch.cuda.Tuple = torch.xpu.Tuple
|
49 |
+
torch.cuda.streams = torch.xpu.streams
|
50 |
+
torch.cuda._lazy_new = torch.xpu._lazy_new
|
51 |
+
torch.cuda.FloatStorage = torch.xpu.FloatStorage
|
52 |
+
torch.cuda.Any = torch.xpu.Any
|
53 |
+
torch.cuda.__doc__ = torch.xpu.__doc__
|
54 |
+
torch.cuda.default_generators = torch.xpu.default_generators
|
55 |
+
torch.cuda.HalfTensor = torch.xpu.HalfTensor
|
56 |
+
torch.cuda._get_device_index = torch.xpu._get_device_index
|
57 |
+
torch.cuda.__path__ = torch.xpu.__path__
|
58 |
+
torch.cuda.Device = torch.xpu.Device
|
59 |
+
torch.cuda.IntTensor = torch.xpu.IntTensor
|
60 |
+
torch.cuda.ByteStorage = torch.xpu.ByteStorage
|
61 |
+
torch.cuda.set_stream = torch.xpu.set_stream
|
62 |
+
torch.cuda.BoolStorage = torch.xpu.BoolStorage
|
63 |
+
torch.cuda.os = torch.xpu.os
|
64 |
+
torch.cuda.torch = torch.xpu.torch
|
65 |
+
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
|
66 |
+
torch.cuda.Union = torch.xpu.Union
|
67 |
+
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
|
68 |
+
torch.cuda.ShortTensor = torch.xpu.ShortTensor
|
69 |
+
torch.cuda.LongTensor = torch.xpu.LongTensor
|
70 |
+
torch.cuda.IntStorage = torch.xpu.IntStorage
|
71 |
+
torch.cuda.LongStorage = torch.xpu.LongStorage
|
72 |
+
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
73 |
+
torch.cuda.__package__ = torch.xpu.__package__
|
74 |
+
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
75 |
+
torch.cuda.CharTensor = torch.xpu.CharTensor
|
76 |
+
torch.cuda.List = torch.xpu.List
|
77 |
+
torch.cuda._lazy_init = torch.xpu._lazy_init
|
78 |
+
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
|
79 |
+
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
|
80 |
+
torch.cuda.ByteTensor = torch.xpu.ByteTensor
|
81 |
+
torch.cuda.StreamContext = torch.xpu.StreamContext
|
82 |
+
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
|
83 |
+
torch.cuda.ShortStorage = torch.xpu.ShortStorage
|
84 |
+
torch.cuda._lazy_call = torch.xpu._lazy_call
|
85 |
+
torch.cuda.HalfStorage = torch.xpu.HalfStorage
|
86 |
+
torch.cuda.random = torch.xpu.random
|
87 |
+
torch.cuda._device = torch.xpu._device
|
88 |
+
torch.cuda.classproperty = torch.xpu.classproperty
|
89 |
+
torch.cuda.__name__ = torch.xpu.__name__
|
90 |
+
torch.cuda._device_t = torch.xpu._device_t
|
91 |
+
torch.cuda.warnings = torch.xpu.warnings
|
92 |
+
torch.cuda.__spec__ = torch.xpu.__spec__
|
93 |
+
torch.cuda.BoolTensor = torch.xpu.BoolTensor
|
94 |
+
torch.cuda.CharStorage = torch.xpu.CharStorage
|
95 |
+
torch.cuda.__file__ = torch.xpu.__file__
|
96 |
+
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
97 |
+
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
98 |
+
|
99 |
+
# Memory:
|
100 |
+
torch.cuda.memory = torch.xpu.memory
|
101 |
+
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
102 |
+
torch.xpu.empty_cache = lambda: None
|
103 |
+
torch.cuda.empty_cache = torch.xpu.empty_cache
|
104 |
+
torch.cuda.memory_stats = torch.xpu.memory_stats
|
105 |
+
torch.cuda.memory_summary = torch.xpu.memory_summary
|
106 |
+
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
|
107 |
+
torch.cuda.memory_allocated = torch.xpu.memory_allocated
|
108 |
+
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
|
109 |
+
torch.cuda.memory_reserved = torch.xpu.memory_reserved
|
110 |
+
torch.cuda.memory_cached = torch.xpu.memory_reserved
|
111 |
+
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
|
112 |
+
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
|
113 |
+
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
|
114 |
+
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
|
115 |
+
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
|
116 |
+
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
|
117 |
+
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
|
118 |
+
|
119 |
+
# RNG:
|
120 |
+
torch.cuda.get_rng_state = torch.xpu.get_rng_state
|
121 |
+
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
|
122 |
+
torch.cuda.set_rng_state = torch.xpu.set_rng_state
|
123 |
+
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
|
124 |
+
torch.cuda.manual_seed = torch.xpu.manual_seed
|
125 |
+
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
|
126 |
+
torch.cuda.seed = torch.xpu.seed
|
127 |
+
torch.cuda.seed_all = torch.xpu.seed_all
|
128 |
+
torch.cuda.initial_seed = torch.xpu.initial_seed
|
129 |
+
|
130 |
+
# AMP:
|
131 |
+
torch.cuda.amp = torch.xpu.amp
|
132 |
+
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
133 |
+
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
134 |
+
|
135 |
+
if not hasattr(torch.cuda.amp, "common"):
|
136 |
+
torch.cuda.amp.common = contextlib.nullcontext()
|
137 |
+
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
|
138 |
+
|
139 |
+
try:
|
140 |
+
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
141 |
+
except Exception: # pylint: disable=broad-exception-caught
|
142 |
+
try:
|
143 |
+
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
|
144 |
+
gradscaler_init()
|
145 |
+
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
146 |
+
except Exception: # pylint: disable=broad-exception-caught
|
147 |
+
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
148 |
+
|
149 |
+
# C
|
150 |
+
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
151 |
+
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
152 |
+
ipex._C._DeviceProperties.major = 2024
|
153 |
+
ipex._C._DeviceProperties.minor = 0
|
154 |
+
|
155 |
+
# Fix functions with ipex:
|
156 |
+
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
157 |
+
torch._utils._get_available_device_type = lambda: "xpu"
|
158 |
+
torch.has_cuda = True
|
159 |
+
torch.cuda.has_half = True
|
160 |
+
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
|
161 |
+
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
162 |
+
torch.backends.cuda.is_built = lambda *args, **kwargs: True
|
163 |
+
torch.version.cuda = "12.1"
|
164 |
+
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
|
165 |
+
torch.cuda.get_device_properties.major = 12
|
166 |
+
torch.cuda.get_device_properties.minor = 1
|
167 |
+
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
168 |
+
torch.cuda.utilization = lambda *args, **kwargs: 0
|
169 |
+
|
170 |
+
ipex_hijacks()
|
171 |
+
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
|
172 |
+
try:
|
173 |
+
from .diffusers import ipex_diffusers
|
174 |
+
ipex_diffusers()
|
175 |
+
except Exception: # pylint: disable=broad-exception-caught
|
176 |
+
pass
|
177 |
+
torch.cuda.is_xpu_hijacked = True
|
178 |
+
except Exception as e:
|
179 |
+
return False, e
|
180 |
+
return True, None
|
library/ipex/attention.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
4 |
+
from functools import cache
|
5 |
+
|
6 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
7 |
+
|
8 |
+
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
|
9 |
+
|
10 |
+
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
|
11 |
+
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
12 |
+
|
13 |
+
# Find something divisible with the input_tokens
|
14 |
+
@cache
|
15 |
+
def find_slice_size(slice_size, slice_block_size):
|
16 |
+
while (slice_size * slice_block_size) > attention_slice_rate:
|
17 |
+
slice_size = slice_size // 2
|
18 |
+
if slice_size <= 1:
|
19 |
+
slice_size = 1
|
20 |
+
break
|
21 |
+
return slice_size
|
22 |
+
|
23 |
+
# Find slice sizes for SDPA
|
24 |
+
@cache
|
25 |
+
def find_sdpa_slice_sizes(query_shape, query_element_size):
|
26 |
+
if len(query_shape) == 3:
|
27 |
+
batch_size_attention, query_tokens, shape_three = query_shape
|
28 |
+
shape_four = 1
|
29 |
+
else:
|
30 |
+
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
31 |
+
|
32 |
+
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
33 |
+
block_size = batch_size_attention * slice_block_size
|
34 |
+
|
35 |
+
split_slice_size = batch_size_attention
|
36 |
+
split_2_slice_size = query_tokens
|
37 |
+
split_3_slice_size = shape_three
|
38 |
+
|
39 |
+
do_split = False
|
40 |
+
do_split_2 = False
|
41 |
+
do_split_3 = False
|
42 |
+
|
43 |
+
if block_size > sdpa_slice_trigger_rate:
|
44 |
+
do_split = True
|
45 |
+
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
46 |
+
if split_slice_size * slice_block_size > attention_slice_rate:
|
47 |
+
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
48 |
+
do_split_2 = True
|
49 |
+
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
50 |
+
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
51 |
+
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
52 |
+
do_split_3 = True
|
53 |
+
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
54 |
+
|
55 |
+
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
56 |
+
|
57 |
+
# Find slice sizes for BMM
|
58 |
+
@cache
|
59 |
+
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
|
60 |
+
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
|
61 |
+
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
|
62 |
+
block_size = batch_size_attention * slice_block_size
|
63 |
+
|
64 |
+
split_slice_size = batch_size_attention
|
65 |
+
split_2_slice_size = input_tokens
|
66 |
+
split_3_slice_size = mat2_atten_shape
|
67 |
+
|
68 |
+
do_split = False
|
69 |
+
do_split_2 = False
|
70 |
+
do_split_3 = False
|
71 |
+
|
72 |
+
if block_size > attention_slice_rate:
|
73 |
+
do_split = True
|
74 |
+
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
75 |
+
if split_slice_size * slice_block_size > attention_slice_rate:
|
76 |
+
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
|
77 |
+
do_split_2 = True
|
78 |
+
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
79 |
+
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
80 |
+
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
|
81 |
+
do_split_3 = True
|
82 |
+
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
83 |
+
|
84 |
+
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
85 |
+
|
86 |
+
|
87 |
+
original_torch_bmm = torch.bmm
|
88 |
+
def torch_bmm_32_bit(input, mat2, *, out=None):
|
89 |
+
if input.device.type != "xpu":
|
90 |
+
return original_torch_bmm(input, mat2, out=out)
|
91 |
+
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
|
92 |
+
|
93 |
+
# Slice BMM
|
94 |
+
if do_split:
|
95 |
+
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
96 |
+
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
97 |
+
for i in range(batch_size_attention // split_slice_size):
|
98 |
+
start_idx = i * split_slice_size
|
99 |
+
end_idx = (i + 1) * split_slice_size
|
100 |
+
if do_split_2:
|
101 |
+
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
102 |
+
start_idx_2 = i2 * split_2_slice_size
|
103 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
104 |
+
if do_split_3:
|
105 |
+
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
|
106 |
+
start_idx_3 = i3 * split_3_slice_size
|
107 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
108 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
|
109 |
+
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
110 |
+
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
111 |
+
out=out
|
112 |
+
)
|
113 |
+
else:
|
114 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
115 |
+
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
116 |
+
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
117 |
+
out=out
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
hidden_states[start_idx:end_idx] = original_torch_bmm(
|
121 |
+
input[start_idx:end_idx],
|
122 |
+
mat2[start_idx:end_idx],
|
123 |
+
out=out
|
124 |
+
)
|
125 |
+
torch.xpu.synchronize(input.device)
|
126 |
+
else:
|
127 |
+
return original_torch_bmm(input, mat2, out=out)
|
128 |
+
return hidden_states
|
129 |
+
|
130 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
131 |
+
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
132 |
+
if query.device.type != "xpu":
|
133 |
+
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
134 |
+
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
135 |
+
|
136 |
+
# Slice SDPA
|
137 |
+
if do_split:
|
138 |
+
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
139 |
+
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
140 |
+
for i in range(batch_size_attention // split_slice_size):
|
141 |
+
start_idx = i * split_slice_size
|
142 |
+
end_idx = (i + 1) * split_slice_size
|
143 |
+
if do_split_2:
|
144 |
+
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
145 |
+
start_idx_2 = i2 * split_2_slice_size
|
146 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
147 |
+
if do_split_3:
|
148 |
+
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
149 |
+
start_idx_3 = i3 * split_3_slice_size
|
150 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
151 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
|
152 |
+
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
153 |
+
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
154 |
+
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
155 |
+
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
156 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
157 |
+
)
|
158 |
+
else:
|
159 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
160 |
+
query[start_idx:end_idx, start_idx_2:end_idx_2],
|
161 |
+
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
162 |
+
value[start_idx:end_idx, start_idx_2:end_idx_2],
|
163 |
+
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
164 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
168 |
+
query[start_idx:end_idx],
|
169 |
+
key[start_idx:end_idx],
|
170 |
+
value[start_idx:end_idx],
|
171 |
+
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
172 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
173 |
+
)
|
174 |
+
torch.xpu.synchronize(query.device)
|
175 |
+
else:
|
176 |
+
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
177 |
+
return hidden_states
|
library/ipex/diffusers.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
4 |
+
import diffusers #0.24.0 # pylint: disable=import-error
|
5 |
+
from diffusers.models.attention_processor import Attention
|
6 |
+
from diffusers.utils import USE_PEFT_BACKEND
|
7 |
+
from functools import cache
|
8 |
+
|
9 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
10 |
+
|
11 |
+
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
12 |
+
|
13 |
+
@cache
|
14 |
+
def find_slice_size(slice_size, slice_block_size):
|
15 |
+
while (slice_size * slice_block_size) > attention_slice_rate:
|
16 |
+
slice_size = slice_size // 2
|
17 |
+
if slice_size <= 1:
|
18 |
+
slice_size = 1
|
19 |
+
break
|
20 |
+
return slice_size
|
21 |
+
|
22 |
+
@cache
|
23 |
+
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
|
24 |
+
if len(query_shape) == 3:
|
25 |
+
batch_size_attention, query_tokens, shape_three = query_shape
|
26 |
+
shape_four = 1
|
27 |
+
else:
|
28 |
+
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
29 |
+
if slice_size is not None:
|
30 |
+
batch_size_attention = slice_size
|
31 |
+
|
32 |
+
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
33 |
+
block_size = batch_size_attention * slice_block_size
|
34 |
+
|
35 |
+
split_slice_size = batch_size_attention
|
36 |
+
split_2_slice_size = query_tokens
|
37 |
+
split_3_slice_size = shape_three
|
38 |
+
|
39 |
+
do_split = False
|
40 |
+
do_split_2 = False
|
41 |
+
do_split_3 = False
|
42 |
+
|
43 |
+
if query_device_type != "xpu":
|
44 |
+
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
45 |
+
|
46 |
+
if block_size > attention_slice_rate:
|
47 |
+
do_split = True
|
48 |
+
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
49 |
+
if split_slice_size * slice_block_size > attention_slice_rate:
|
50 |
+
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
51 |
+
do_split_2 = True
|
52 |
+
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
53 |
+
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
54 |
+
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
55 |
+
do_split_3 = True
|
56 |
+
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
57 |
+
|
58 |
+
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
59 |
+
|
60 |
+
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
61 |
+
r"""
|
62 |
+
Processor for implementing sliced attention.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
slice_size (`int`, *optional*):
|
66 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
67 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, slice_size):
|
71 |
+
self.slice_size = slice_size
|
72 |
+
|
73 |
+
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
74 |
+
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
75 |
+
|
76 |
+
residual = hidden_states
|
77 |
+
|
78 |
+
input_ndim = hidden_states.ndim
|
79 |
+
|
80 |
+
if input_ndim == 4:
|
81 |
+
batch_size, channel, height, width = hidden_states.shape
|
82 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
83 |
+
|
84 |
+
batch_size, sequence_length, _ = (
|
85 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
86 |
+
)
|
87 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
88 |
+
|
89 |
+
if attn.group_norm is not None:
|
90 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
91 |
+
|
92 |
+
query = attn.to_q(hidden_states)
|
93 |
+
dim = query.shape[-1]
|
94 |
+
query = attn.head_to_batch_dim(query)
|
95 |
+
|
96 |
+
if encoder_hidden_states is None:
|
97 |
+
encoder_hidden_states = hidden_states
|
98 |
+
elif attn.norm_cross:
|
99 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
100 |
+
|
101 |
+
key = attn.to_k(encoder_hidden_states)
|
102 |
+
value = attn.to_v(encoder_hidden_states)
|
103 |
+
key = attn.head_to_batch_dim(key)
|
104 |
+
value = attn.head_to_batch_dim(value)
|
105 |
+
|
106 |
+
batch_size_attention, query_tokens, shape_three = query.shape
|
107 |
+
hidden_states = torch.zeros(
|
108 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
109 |
+
)
|
110 |
+
|
111 |
+
####################################################################
|
112 |
+
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
113 |
+
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
|
114 |
+
|
115 |
+
for i in range(batch_size_attention // split_slice_size):
|
116 |
+
start_idx = i * split_slice_size
|
117 |
+
end_idx = (i + 1) * split_slice_size
|
118 |
+
if do_split_2:
|
119 |
+
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
120 |
+
start_idx_2 = i2 * split_2_slice_size
|
121 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
122 |
+
if do_split_3:
|
123 |
+
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
124 |
+
start_idx_3 = i3 * split_3_slice_size
|
125 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
126 |
+
|
127 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
128 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
129 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
130 |
+
|
131 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
132 |
+
del query_slice
|
133 |
+
del key_slice
|
134 |
+
del attn_mask_slice
|
135 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
136 |
+
|
137 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
138 |
+
del attn_slice
|
139 |
+
else:
|
140 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
141 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
142 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
143 |
+
|
144 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
145 |
+
del query_slice
|
146 |
+
del key_slice
|
147 |
+
del attn_mask_slice
|
148 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
149 |
+
|
150 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
151 |
+
del attn_slice
|
152 |
+
torch.xpu.synchronize(query.device)
|
153 |
+
else:
|
154 |
+
query_slice = query[start_idx:end_idx]
|
155 |
+
key_slice = key[start_idx:end_idx]
|
156 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
157 |
+
|
158 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
159 |
+
del query_slice
|
160 |
+
del key_slice
|
161 |
+
del attn_mask_slice
|
162 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
163 |
+
|
164 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
165 |
+
del attn_slice
|
166 |
+
####################################################################
|
167 |
+
|
168 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
169 |
+
|
170 |
+
# linear proj
|
171 |
+
hidden_states = attn.to_out[0](hidden_states)
|
172 |
+
# dropout
|
173 |
+
hidden_states = attn.to_out[1](hidden_states)
|
174 |
+
|
175 |
+
if input_ndim == 4:
|
176 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
177 |
+
|
178 |
+
if attn.residual_connection:
|
179 |
+
hidden_states = hidden_states + residual
|
180 |
+
|
181 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
182 |
+
|
183 |
+
return hidden_states
|
184 |
+
|
185 |
+
|
186 |
+
class AttnProcessor:
|
187 |
+
r"""
|
188 |
+
Default processor for performing attention-related computations.
|
189 |
+
"""
|
190 |
+
|
191 |
+
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
192 |
+
encoder_hidden_states=None, attention_mask=None,
|
193 |
+
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
194 |
+
|
195 |
+
residual = hidden_states
|
196 |
+
|
197 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
198 |
+
|
199 |
+
if attn.spatial_norm is not None:
|
200 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
201 |
+
|
202 |
+
input_ndim = hidden_states.ndim
|
203 |
+
|
204 |
+
if input_ndim == 4:
|
205 |
+
batch_size, channel, height, width = hidden_states.shape
|
206 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
207 |
+
|
208 |
+
batch_size, sequence_length, _ = (
|
209 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
210 |
+
)
|
211 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
212 |
+
|
213 |
+
if attn.group_norm is not None:
|
214 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
215 |
+
|
216 |
+
query = attn.to_q(hidden_states, *args)
|
217 |
+
|
218 |
+
if encoder_hidden_states is None:
|
219 |
+
encoder_hidden_states = hidden_states
|
220 |
+
elif attn.norm_cross:
|
221 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
222 |
+
|
223 |
+
key = attn.to_k(encoder_hidden_states, *args)
|
224 |
+
value = attn.to_v(encoder_hidden_states, *args)
|
225 |
+
|
226 |
+
query = attn.head_to_batch_dim(query)
|
227 |
+
key = attn.head_to_batch_dim(key)
|
228 |
+
value = attn.head_to_batch_dim(value)
|
229 |
+
|
230 |
+
####################################################################
|
231 |
+
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
232 |
+
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
233 |
+
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
234 |
+
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
|
235 |
+
|
236 |
+
if do_split:
|
237 |
+
for i in range(batch_size_attention // split_slice_size):
|
238 |
+
start_idx = i * split_slice_size
|
239 |
+
end_idx = (i + 1) * split_slice_size
|
240 |
+
if do_split_2:
|
241 |
+
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
242 |
+
start_idx_2 = i2 * split_2_slice_size
|
243 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
244 |
+
if do_split_3:
|
245 |
+
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
246 |
+
start_idx_3 = i3 * split_3_slice_size
|
247 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
248 |
+
|
249 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
250 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
251 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
252 |
+
|
253 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
254 |
+
del query_slice
|
255 |
+
del key_slice
|
256 |
+
del attn_mask_slice
|
257 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
258 |
+
|
259 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
260 |
+
del attn_slice
|
261 |
+
else:
|
262 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
263 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
264 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
265 |
+
|
266 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
267 |
+
del query_slice
|
268 |
+
del key_slice
|
269 |
+
del attn_mask_slice
|
270 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
271 |
+
|
272 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
273 |
+
del attn_slice
|
274 |
+
else:
|
275 |
+
query_slice = query[start_idx:end_idx]
|
276 |
+
key_slice = key[start_idx:end_idx]
|
277 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
278 |
+
|
279 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
280 |
+
del query_slice
|
281 |
+
del key_slice
|
282 |
+
del attn_mask_slice
|
283 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
284 |
+
|
285 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
286 |
+
del attn_slice
|
287 |
+
torch.xpu.synchronize(query.device)
|
288 |
+
else:
|
289 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
290 |
+
hidden_states = torch.bmm(attention_probs, value)
|
291 |
+
####################################################################
|
292 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
293 |
+
|
294 |
+
# linear proj
|
295 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
296 |
+
# dropout
|
297 |
+
hidden_states = attn.to_out[1](hidden_states)
|
298 |
+
|
299 |
+
if input_ndim == 4:
|
300 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
301 |
+
|
302 |
+
if attn.residual_connection:
|
303 |
+
hidden_states = hidden_states + residual
|
304 |
+
|
305 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
306 |
+
|
307 |
+
return hidden_states
|
308 |
+
|
309 |
+
def ipex_diffusers():
|
310 |
+
#ARC GPUs can't allocate more than 4GB to a single block:
|
311 |
+
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
312 |
+
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
|
library/ipex/gradscaler.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
import torch
|
3 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
4 |
+
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
|
5 |
+
|
6 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
7 |
+
|
8 |
+
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
9 |
+
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
10 |
+
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
11 |
+
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
12 |
+
|
13 |
+
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
|
14 |
+
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
|
15 |
+
per_device_found_inf = _MultiDeviceReplicator(found_inf)
|
16 |
+
|
17 |
+
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
|
18 |
+
# There could be hundreds of grads, so we'd like to iterate through them just once.
|
19 |
+
# However, we don't know their devices or dtypes in advance.
|
20 |
+
|
21 |
+
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
|
22 |
+
# Google says mypy struggles with defaultdicts type annotations.
|
23 |
+
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
|
24 |
+
# sync grad to master weight
|
25 |
+
if hasattr(optimizer, "sync_grad"):
|
26 |
+
optimizer.sync_grad()
|
27 |
+
with torch.no_grad():
|
28 |
+
for group in optimizer.param_groups:
|
29 |
+
for param in group["params"]:
|
30 |
+
if param.grad is None:
|
31 |
+
continue
|
32 |
+
if (not allow_fp16) and param.grad.dtype == torch.float16:
|
33 |
+
raise ValueError("Attempting to unscale FP16 gradients.")
|
34 |
+
if param.grad.is_sparse:
|
35 |
+
# is_coalesced() == False means the sparse grad has values with duplicate indices.
|
36 |
+
# coalesce() deduplicates indices and adds all values that have the same index.
|
37 |
+
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
|
38 |
+
# so we should check the coalesced _values().
|
39 |
+
if param.grad.dtype is torch.float16:
|
40 |
+
param.grad = param.grad.coalesce()
|
41 |
+
to_unscale = param.grad._values()
|
42 |
+
else:
|
43 |
+
to_unscale = param.grad
|
44 |
+
|
45 |
+
# -: is there a way to split by device and dtype without appending in the inner loop?
|
46 |
+
to_unscale = to_unscale.to("cpu")
|
47 |
+
per_device_and_dtype_grads[to_unscale.device][
|
48 |
+
to_unscale.dtype
|
49 |
+
].append(to_unscale)
|
50 |
+
|
51 |
+
for _, per_dtype_grads in per_device_and_dtype_grads.items():
|
52 |
+
for grads in per_dtype_grads.values():
|
53 |
+
core._amp_foreach_non_finite_check_and_unscale_(
|
54 |
+
grads,
|
55 |
+
per_device_found_inf.get("cpu"),
|
56 |
+
per_device_inv_scale.get("cpu"),
|
57 |
+
)
|
58 |
+
|
59 |
+
return per_device_found_inf._per_device_tensors
|
60 |
+
|
61 |
+
def unscale_(self, optimizer):
|
62 |
+
"""
|
63 |
+
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
|
64 |
+
:meth:`unscale_` is optional, serving cases where you need to
|
65 |
+
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
|
66 |
+
between the backward pass(es) and :meth:`step`.
|
67 |
+
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
|
68 |
+
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
|
69 |
+
...
|
70 |
+
scaler.scale(loss).backward()
|
71 |
+
scaler.unscale_(optimizer)
|
72 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
73 |
+
scaler.step(optimizer)
|
74 |
+
scaler.update()
|
75 |
+
Args:
|
76 |
+
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
|
77 |
+
.. warning::
|
78 |
+
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
|
79 |
+
and only after all gradients for that optimizer's assigned parameters have been accumulated.
|
80 |
+
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
|
81 |
+
.. warning::
|
82 |
+
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
|
83 |
+
"""
|
84 |
+
if not self._enabled:
|
85 |
+
return
|
86 |
+
|
87 |
+
self._check_scale_growth_tracker("unscale_")
|
88 |
+
|
89 |
+
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
90 |
+
|
91 |
+
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
|
92 |
+
raise RuntimeError(
|
93 |
+
"unscale_() has already been called on this optimizer since the last update()."
|
94 |
+
)
|
95 |
+
elif optimizer_state["stage"] is OptState.STEPPED:
|
96 |
+
raise RuntimeError("unscale_() is being called after step().")
|
97 |
+
|
98 |
+
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
99 |
+
assert self._scale is not None
|
100 |
+
if device_supports_fp64:
|
101 |
+
inv_scale = self._scale.double().reciprocal().float()
|
102 |
+
else:
|
103 |
+
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
|
104 |
+
found_inf = torch.full(
|
105 |
+
(1,), 0.0, dtype=torch.float32, device=self._scale.device
|
106 |
+
)
|
107 |
+
|
108 |
+
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
|
109 |
+
optimizer, inv_scale, found_inf, False
|
110 |
+
)
|
111 |
+
optimizer_state["stage"] = OptState.UNSCALED
|
112 |
+
|
113 |
+
def update(self, new_scale=None):
|
114 |
+
"""
|
115 |
+
Updates the scale factor.
|
116 |
+
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
|
117 |
+
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
|
118 |
+
the scale is multiplied by ``growth_factor`` to increase it.
|
119 |
+
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
|
120 |
+
used directly, it's used to fill GradScaler's internal scale tensor. So if
|
121 |
+
``new_scale`` was a tensor, later in-place changes to that tensor will not further
|
122 |
+
affect the scale GradScaler uses internally.)
|
123 |
+
Args:
|
124 |
+
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
|
125 |
+
.. warning::
|
126 |
+
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
|
127 |
+
been invoked for all optimizers used this iteration.
|
128 |
+
"""
|
129 |
+
if not self._enabled:
|
130 |
+
return
|
131 |
+
|
132 |
+
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
|
133 |
+
|
134 |
+
if new_scale is not None:
|
135 |
+
# Accept a new user-defined scale.
|
136 |
+
if isinstance(new_scale, float):
|
137 |
+
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
138 |
+
else:
|
139 |
+
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
|
140 |
+
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
|
141 |
+
assert new_scale.numel() == 1, reason
|
142 |
+
assert new_scale.requires_grad is False, reason
|
143 |
+
self._scale.copy_(new_scale) # type: ignore[union-attr]
|
144 |
+
else:
|
145 |
+
# Consume shared inf/nan data collected from optimizers to update the scale.
|
146 |
+
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
|
147 |
+
found_infs = [
|
148 |
+
found_inf.to(device="cpu", non_blocking=True)
|
149 |
+
for state in self._per_optimizer_states.values()
|
150 |
+
for found_inf in state["found_inf_per_device"].values()
|
151 |
+
]
|
152 |
+
|
153 |
+
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
|
154 |
+
|
155 |
+
found_inf_combined = found_infs[0]
|
156 |
+
if len(found_infs) > 1:
|
157 |
+
for i in range(1, len(found_infs)):
|
158 |
+
found_inf_combined += found_infs[i]
|
159 |
+
|
160 |
+
to_device = _scale.device
|
161 |
+
_scale = _scale.to("cpu")
|
162 |
+
_growth_tracker = _growth_tracker.to("cpu")
|
163 |
+
|
164 |
+
core._amp_update_scale_(
|
165 |
+
_scale,
|
166 |
+
_growth_tracker,
|
167 |
+
found_inf_combined,
|
168 |
+
self._growth_factor,
|
169 |
+
self._backoff_factor,
|
170 |
+
self._growth_interval,
|
171 |
+
)
|
172 |
+
|
173 |
+
_scale = _scale.to(to_device)
|
174 |
+
_growth_tracker = _growth_tracker.to(to_device)
|
175 |
+
# To prepare for next iteration, clear the data collected from optimizers this iteration.
|
176 |
+
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
177 |
+
|
178 |
+
def gradscaler_init():
|
179 |
+
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
180 |
+
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
|
181 |
+
torch.xpu.amp.GradScaler.unscale_ = unscale_
|
182 |
+
torch.xpu.amp.GradScaler.update = update
|
183 |
+
return torch.xpu.amp.GradScaler
|
library/ipex/hijacks.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from functools import wraps
|
3 |
+
from contextlib import nullcontext
|
4 |
+
import torch
|
5 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
9 |
+
|
10 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
11 |
+
|
12 |
+
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
13 |
+
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
14 |
+
if isinstance(device_ids, list) and len(device_ids) > 1:
|
15 |
+
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
16 |
+
return module.to("xpu")
|
17 |
+
|
18 |
+
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
19 |
+
return nullcontext()
|
20 |
+
|
21 |
+
@property
|
22 |
+
def is_cuda(self):
|
23 |
+
return self.device.type == 'xpu' or self.device.type == 'cuda'
|
24 |
+
|
25 |
+
def check_device(device):
|
26 |
+
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
27 |
+
|
28 |
+
def return_xpu(device):
|
29 |
+
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
30 |
+
|
31 |
+
|
32 |
+
# Autocast
|
33 |
+
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
|
34 |
+
@wraps(torch.amp.autocast_mode.autocast.__init__)
|
35 |
+
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
|
36 |
+
if device_type == "cuda":
|
37 |
+
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
38 |
+
else:
|
39 |
+
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
40 |
+
|
41 |
+
# Latent Antialias CPU Offload:
|
42 |
+
original_interpolate = torch.nn.functional.interpolate
|
43 |
+
@wraps(torch.nn.functional.interpolate)
|
44 |
+
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
45 |
+
if antialias or align_corners is not None or mode == 'bicubic':
|
46 |
+
return_device = tensor.device
|
47 |
+
return_dtype = tensor.dtype
|
48 |
+
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
|
49 |
+
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
|
50 |
+
else:
|
51 |
+
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
52 |
+
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
53 |
+
|
54 |
+
|
55 |
+
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
56 |
+
original_from_numpy = torch.from_numpy
|
57 |
+
@wraps(torch.from_numpy)
|
58 |
+
def from_numpy(ndarray):
|
59 |
+
if ndarray.dtype == float:
|
60 |
+
return original_from_numpy(ndarray.astype('float32'))
|
61 |
+
else:
|
62 |
+
return original_from_numpy(ndarray)
|
63 |
+
|
64 |
+
original_as_tensor = torch.as_tensor
|
65 |
+
@wraps(torch.as_tensor)
|
66 |
+
def as_tensor(data, dtype=None, device=None):
|
67 |
+
if check_device(device):
|
68 |
+
device = return_xpu(device)
|
69 |
+
if isinstance(data, np.ndarray) and data.dtype == float and not (
|
70 |
+
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
|
71 |
+
return original_as_tensor(data, dtype=torch.float32, device=device)
|
72 |
+
else:
|
73 |
+
return original_as_tensor(data, dtype=dtype, device=device)
|
74 |
+
|
75 |
+
|
76 |
+
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
|
77 |
+
original_torch_bmm = torch.bmm
|
78 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
79 |
+
else:
|
80 |
+
# 32 bit attention workarounds for Alchemist:
|
81 |
+
try:
|
82 |
+
from .attention import torch_bmm_32_bit as original_torch_bmm
|
83 |
+
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
|
84 |
+
except Exception: # pylint: disable=broad-exception-caught
|
85 |
+
original_torch_bmm = torch.bmm
|
86 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
87 |
+
|
88 |
+
|
89 |
+
# Data Type Errors:
|
90 |
+
@wraps(torch.bmm)
|
91 |
+
def torch_bmm(input, mat2, *, out=None):
|
92 |
+
if input.dtype != mat2.dtype:
|
93 |
+
mat2 = mat2.to(input.dtype)
|
94 |
+
return original_torch_bmm(input, mat2, out=out)
|
95 |
+
|
96 |
+
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
97 |
+
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
98 |
+
if query.dtype != key.dtype:
|
99 |
+
key = key.to(dtype=query.dtype)
|
100 |
+
if query.dtype != value.dtype:
|
101 |
+
value = value.to(dtype=query.dtype)
|
102 |
+
if attn_mask is not None and query.dtype != attn_mask.dtype:
|
103 |
+
attn_mask = attn_mask.to(dtype=query.dtype)
|
104 |
+
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
105 |
+
|
106 |
+
# A1111 FP16
|
107 |
+
original_functional_group_norm = torch.nn.functional.group_norm
|
108 |
+
@wraps(torch.nn.functional.group_norm)
|
109 |
+
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
110 |
+
if weight is not None and input.dtype != weight.data.dtype:
|
111 |
+
input = input.to(dtype=weight.data.dtype)
|
112 |
+
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
113 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
114 |
+
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
|
115 |
+
|
116 |
+
# A1111 BF16
|
117 |
+
original_functional_layer_norm = torch.nn.functional.layer_norm
|
118 |
+
@wraps(torch.nn.functional.layer_norm)
|
119 |
+
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
120 |
+
if weight is not None and input.dtype != weight.data.dtype:
|
121 |
+
input = input.to(dtype=weight.data.dtype)
|
122 |
+
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
123 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
124 |
+
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
|
125 |
+
|
126 |
+
# Training
|
127 |
+
original_functional_linear = torch.nn.functional.linear
|
128 |
+
@wraps(torch.nn.functional.linear)
|
129 |
+
def functional_linear(input, weight, bias=None):
|
130 |
+
if input.dtype != weight.data.dtype:
|
131 |
+
input = input.to(dtype=weight.data.dtype)
|
132 |
+
if bias is not None and bias.data.dtype != weight.data.dtype:
|
133 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
134 |
+
return original_functional_linear(input, weight, bias=bias)
|
135 |
+
|
136 |
+
original_functional_conv2d = torch.nn.functional.conv2d
|
137 |
+
@wraps(torch.nn.functional.conv2d)
|
138 |
+
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
139 |
+
if input.dtype != weight.data.dtype:
|
140 |
+
input = input.to(dtype=weight.data.dtype)
|
141 |
+
if bias is not None and bias.data.dtype != weight.data.dtype:
|
142 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
143 |
+
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
144 |
+
|
145 |
+
# A1111 Embedding BF16
|
146 |
+
original_torch_cat = torch.cat
|
147 |
+
@wraps(torch.cat)
|
148 |
+
def torch_cat(tensor, *args, **kwargs):
|
149 |
+
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
150 |
+
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
151 |
+
else:
|
152 |
+
return original_torch_cat(tensor, *args, **kwargs)
|
153 |
+
|
154 |
+
# SwinIR BF16:
|
155 |
+
original_functional_pad = torch.nn.functional.pad
|
156 |
+
@wraps(torch.nn.functional.pad)
|
157 |
+
def functional_pad(input, pad, mode='constant', value=None):
|
158 |
+
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
159 |
+
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
160 |
+
else:
|
161 |
+
return original_functional_pad(input, pad, mode=mode, value=value)
|
162 |
+
|
163 |
+
|
164 |
+
original_torch_tensor = torch.tensor
|
165 |
+
@wraps(torch.tensor)
|
166 |
+
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
167 |
+
if check_device(device):
|
168 |
+
device = return_xpu(device)
|
169 |
+
if not device_supports_fp64:
|
170 |
+
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
|
171 |
+
if dtype == torch.float64:
|
172 |
+
dtype = torch.float32
|
173 |
+
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
|
174 |
+
dtype = torch.float32
|
175 |
+
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
|
176 |
+
|
177 |
+
original_Tensor_to = torch.Tensor.to
|
178 |
+
@wraps(torch.Tensor.to)
|
179 |
+
def Tensor_to(self, device=None, *args, **kwargs):
|
180 |
+
if check_device(device):
|
181 |
+
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
182 |
+
else:
|
183 |
+
return original_Tensor_to(self, device, *args, **kwargs)
|
184 |
+
|
185 |
+
original_Tensor_cuda = torch.Tensor.cuda
|
186 |
+
@wraps(torch.Tensor.cuda)
|
187 |
+
def Tensor_cuda(self, device=None, *args, **kwargs):
|
188 |
+
if check_device(device):
|
189 |
+
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
190 |
+
else:
|
191 |
+
return original_Tensor_cuda(self, device, *args, **kwargs)
|
192 |
+
|
193 |
+
original_Tensor_pin_memory = torch.Tensor.pin_memory
|
194 |
+
@wraps(torch.Tensor.pin_memory)
|
195 |
+
def Tensor_pin_memory(self, device=None, *args, **kwargs):
|
196 |
+
if device is None:
|
197 |
+
device = "xpu"
|
198 |
+
if check_device(device):
|
199 |
+
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
|
200 |
+
else:
|
201 |
+
return original_Tensor_pin_memory(self, device, *args, **kwargs)
|
202 |
+
|
203 |
+
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
204 |
+
@wraps(torch.UntypedStorage.__init__)
|
205 |
+
def UntypedStorage_init(*args, device=None, **kwargs):
|
206 |
+
if check_device(device):
|
207 |
+
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
208 |
+
else:
|
209 |
+
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
210 |
+
|
211 |
+
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
212 |
+
@wraps(torch.UntypedStorage.cuda)
|
213 |
+
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
214 |
+
if check_device(device):
|
215 |
+
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
216 |
+
else:
|
217 |
+
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
218 |
+
|
219 |
+
original_torch_empty = torch.empty
|
220 |
+
@wraps(torch.empty)
|
221 |
+
def torch_empty(*args, device=None, **kwargs):
|
222 |
+
if check_device(device):
|
223 |
+
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
224 |
+
else:
|
225 |
+
return original_torch_empty(*args, device=device, **kwargs)
|
226 |
+
|
227 |
+
original_torch_randn = torch.randn
|
228 |
+
@wraps(torch.randn)
|
229 |
+
def torch_randn(*args, device=None, dtype=None, **kwargs):
|
230 |
+
if dtype == bytes:
|
231 |
+
dtype = None
|
232 |
+
if check_device(device):
|
233 |
+
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
234 |
+
else:
|
235 |
+
return original_torch_randn(*args, device=device, **kwargs)
|
236 |
+
|
237 |
+
original_torch_ones = torch.ones
|
238 |
+
@wraps(torch.ones)
|
239 |
+
def torch_ones(*args, device=None, **kwargs):
|
240 |
+
if check_device(device):
|
241 |
+
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
242 |
+
else:
|
243 |
+
return original_torch_ones(*args, device=device, **kwargs)
|
244 |
+
|
245 |
+
original_torch_zeros = torch.zeros
|
246 |
+
@wraps(torch.zeros)
|
247 |
+
def torch_zeros(*args, device=None, **kwargs):
|
248 |
+
if check_device(device):
|
249 |
+
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
250 |
+
else:
|
251 |
+
return original_torch_zeros(*args, device=device, **kwargs)
|
252 |
+
|
253 |
+
original_torch_linspace = torch.linspace
|
254 |
+
@wraps(torch.linspace)
|
255 |
+
def torch_linspace(*args, device=None, **kwargs):
|
256 |
+
if check_device(device):
|
257 |
+
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
258 |
+
else:
|
259 |
+
return original_torch_linspace(*args, device=device, **kwargs)
|
260 |
+
|
261 |
+
original_torch_Generator = torch.Generator
|
262 |
+
@wraps(torch.Generator)
|
263 |
+
def torch_Generator(device=None):
|
264 |
+
if check_device(device):
|
265 |
+
return original_torch_Generator(return_xpu(device))
|
266 |
+
else:
|
267 |
+
return original_torch_Generator(device)
|
268 |
+
|
269 |
+
original_torch_load = torch.load
|
270 |
+
@wraps(torch.load)
|
271 |
+
def torch_load(f, map_location=None, *args, **kwargs):
|
272 |
+
if map_location is None:
|
273 |
+
map_location = "xpu"
|
274 |
+
if check_device(map_location):
|
275 |
+
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
|
276 |
+
else:
|
277 |
+
return original_torch_load(f, *args, map_location=map_location, **kwargs)
|
278 |
+
|
279 |
+
|
280 |
+
# Hijack Functions:
|
281 |
+
def ipex_hijacks():
|
282 |
+
torch.tensor = torch_tensor
|
283 |
+
torch.Tensor.to = Tensor_to
|
284 |
+
torch.Tensor.cuda = Tensor_cuda
|
285 |
+
torch.Tensor.pin_memory = Tensor_pin_memory
|
286 |
+
torch.UntypedStorage.__init__ = UntypedStorage_init
|
287 |
+
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
288 |
+
torch.empty = torch_empty
|
289 |
+
torch.randn = torch_randn
|
290 |
+
torch.ones = torch_ones
|
291 |
+
torch.zeros = torch_zeros
|
292 |
+
torch.linspace = torch_linspace
|
293 |
+
torch.Generator = torch_Generator
|
294 |
+
torch.load = torch_load
|
295 |
+
|
296 |
+
torch.backends.cuda.sdp_kernel = return_null_context
|
297 |
+
torch.nn.DataParallel = DummyDataParallel
|
298 |
+
torch.UntypedStorage.is_cuda = is_cuda
|
299 |
+
torch.amp.autocast_mode.autocast.__init__ = autocast_init
|
300 |
+
|
301 |
+
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
302 |
+
torch.nn.functional.group_norm = functional_group_norm
|
303 |
+
torch.nn.functional.layer_norm = functional_layer_norm
|
304 |
+
torch.nn.functional.linear = functional_linear
|
305 |
+
torch.nn.functional.conv2d = functional_conv2d
|
306 |
+
torch.nn.functional.interpolate = interpolate
|
307 |
+
torch.nn.functional.pad = functional_pad
|
308 |
+
|
309 |
+
torch.bmm = torch_bmm
|
310 |
+
torch.cat = torch_cat
|
311 |
+
if not device_supports_fp64:
|
312 |
+
torch.from_numpy = from_numpy
|
313 |
+
torch.as_tensor = as_tensor
|
library/lpw_stable_diffusion.py
ADDED
@@ -0,0 +1,1233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
|
2 |
+
# and modify to support SD2.x
|
3 |
+
|
4 |
+
import inspect
|
5 |
+
import re
|
6 |
+
from typing import Callable, List, Optional, Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import PIL.Image
|
10 |
+
import torch
|
11 |
+
from packaging import version
|
12 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
13 |
+
|
14 |
+
import diffusers
|
15 |
+
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
16 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
17 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
18 |
+
from diffusers.utils import logging
|
19 |
+
|
20 |
+
try:
|
21 |
+
from diffusers.utils import PIL_INTERPOLATION
|
22 |
+
except ImportError:
|
23 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
24 |
+
PIL_INTERPOLATION = {
|
25 |
+
"linear": PIL.Image.Resampling.BILINEAR,
|
26 |
+
"bilinear": PIL.Image.Resampling.BILINEAR,
|
27 |
+
"bicubic": PIL.Image.Resampling.BICUBIC,
|
28 |
+
"lanczos": PIL.Image.Resampling.LANCZOS,
|
29 |
+
"nearest": PIL.Image.Resampling.NEAREST,
|
30 |
+
}
|
31 |
+
else:
|
32 |
+
PIL_INTERPOLATION = {
|
33 |
+
"linear": PIL.Image.LINEAR,
|
34 |
+
"bilinear": PIL.Image.BILINEAR,
|
35 |
+
"bicubic": PIL.Image.BICUBIC,
|
36 |
+
"lanczos": PIL.Image.LANCZOS,
|
37 |
+
"nearest": PIL.Image.NEAREST,
|
38 |
+
}
|
39 |
+
# ------------------------------------------------------------------------------
|
40 |
+
|
41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
42 |
+
|
43 |
+
re_attention = re.compile(
|
44 |
+
r"""
|
45 |
+
\\\(|
|
46 |
+
\\\)|
|
47 |
+
\\\[|
|
48 |
+
\\]|
|
49 |
+
\\\\|
|
50 |
+
\\|
|
51 |
+
\(|
|
52 |
+
\[|
|
53 |
+
:([+-]?[.\d]+)\)|
|
54 |
+
\)|
|
55 |
+
]|
|
56 |
+
[^\\()\[\]:]+|
|
57 |
+
:
|
58 |
+
""",
|
59 |
+
re.X,
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
def parse_prompt_attention(text):
|
64 |
+
"""
|
65 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
66 |
+
Accepted tokens are:
|
67 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
68 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
69 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
70 |
+
\( - literal character '('
|
71 |
+
\[ - literal character '['
|
72 |
+
\) - literal character ')'
|
73 |
+
\] - literal character ']'
|
74 |
+
\\ - literal character '\'
|
75 |
+
anything else - just text
|
76 |
+
>>> parse_prompt_attention('normal text')
|
77 |
+
[['normal text', 1.0]]
|
78 |
+
>>> parse_prompt_attention('an (important) word')
|
79 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
80 |
+
>>> parse_prompt_attention('(unbalanced')
|
81 |
+
[['unbalanced', 1.1]]
|
82 |
+
>>> parse_prompt_attention('\(literal\]')
|
83 |
+
[['(literal]', 1.0]]
|
84 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
85 |
+
[['unnecessaryparens', 1.1]]
|
86 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
87 |
+
[['a ', 1.0],
|
88 |
+
['house', 1.5730000000000004],
|
89 |
+
[' ', 1.1],
|
90 |
+
['on', 1.0],
|
91 |
+
[' a ', 1.1],
|
92 |
+
['hill', 0.55],
|
93 |
+
[', sun, ', 1.1],
|
94 |
+
['sky', 1.4641000000000006],
|
95 |
+
['.', 1.1]]
|
96 |
+
"""
|
97 |
+
|
98 |
+
res = []
|
99 |
+
round_brackets = []
|
100 |
+
square_brackets = []
|
101 |
+
|
102 |
+
round_bracket_multiplier = 1.1
|
103 |
+
square_bracket_multiplier = 1 / 1.1
|
104 |
+
|
105 |
+
def multiply_range(start_position, multiplier):
|
106 |
+
for p in range(start_position, len(res)):
|
107 |
+
res[p][1] *= multiplier
|
108 |
+
|
109 |
+
for m in re_attention.finditer(text):
|
110 |
+
text = m.group(0)
|
111 |
+
weight = m.group(1)
|
112 |
+
|
113 |
+
if text.startswith("\\"):
|
114 |
+
res.append([text[1:], 1.0])
|
115 |
+
elif text == "(":
|
116 |
+
round_brackets.append(len(res))
|
117 |
+
elif text == "[":
|
118 |
+
square_brackets.append(len(res))
|
119 |
+
elif weight is not None and len(round_brackets) > 0:
|
120 |
+
multiply_range(round_brackets.pop(), float(weight))
|
121 |
+
elif text == ")" and len(round_brackets) > 0:
|
122 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
123 |
+
elif text == "]" and len(square_brackets) > 0:
|
124 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
125 |
+
else:
|
126 |
+
res.append([text, 1.0])
|
127 |
+
|
128 |
+
for pos in round_brackets:
|
129 |
+
multiply_range(pos, round_bracket_multiplier)
|
130 |
+
|
131 |
+
for pos in square_brackets:
|
132 |
+
multiply_range(pos, square_bracket_multiplier)
|
133 |
+
|
134 |
+
if len(res) == 0:
|
135 |
+
res = [["", 1.0]]
|
136 |
+
|
137 |
+
# merge runs of identical weights
|
138 |
+
i = 0
|
139 |
+
while i + 1 < len(res):
|
140 |
+
if res[i][1] == res[i + 1][1]:
|
141 |
+
res[i][0] += res[i + 1][0]
|
142 |
+
res.pop(i + 1)
|
143 |
+
else:
|
144 |
+
i += 1
|
145 |
+
|
146 |
+
return res
|
147 |
+
|
148 |
+
|
149 |
+
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
|
150 |
+
r"""
|
151 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
152 |
+
|
153 |
+
No padding, starting or ending token is included.
|
154 |
+
"""
|
155 |
+
tokens = []
|
156 |
+
weights = []
|
157 |
+
truncated = False
|
158 |
+
for text in prompt:
|
159 |
+
texts_and_weights = parse_prompt_attention(text)
|
160 |
+
text_token = []
|
161 |
+
text_weight = []
|
162 |
+
for word, weight in texts_and_weights:
|
163 |
+
# tokenize and discard the starting and the ending token
|
164 |
+
token = pipe.tokenizer(word).input_ids[1:-1]
|
165 |
+
text_token += token
|
166 |
+
# copy the weight by length of token
|
167 |
+
text_weight += [weight] * len(token)
|
168 |
+
# stop if the text is too long (longer than truncation limit)
|
169 |
+
if len(text_token) > max_length:
|
170 |
+
truncated = True
|
171 |
+
break
|
172 |
+
# truncate
|
173 |
+
if len(text_token) > max_length:
|
174 |
+
truncated = True
|
175 |
+
text_token = text_token[:max_length]
|
176 |
+
text_weight = text_weight[:max_length]
|
177 |
+
tokens.append(text_token)
|
178 |
+
weights.append(text_weight)
|
179 |
+
if truncated:
|
180 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
181 |
+
return tokens, weights
|
182 |
+
|
183 |
+
|
184 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
185 |
+
r"""
|
186 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
187 |
+
"""
|
188 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
189 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
190 |
+
for i in range(len(tokens)):
|
191 |
+
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
192 |
+
if no_boseos_middle:
|
193 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
194 |
+
else:
|
195 |
+
w = []
|
196 |
+
if len(weights[i]) == 0:
|
197 |
+
w = [1.0] * weights_length
|
198 |
+
else:
|
199 |
+
for j in range(max_embeddings_multiples):
|
200 |
+
w.append(1.0) # weight for starting token in this chunk
|
201 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
202 |
+
w.append(1.0) # weight for ending token in this chunk
|
203 |
+
w += [1.0] * (weights_length - len(w))
|
204 |
+
weights[i] = w[:]
|
205 |
+
|
206 |
+
return tokens, weights
|
207 |
+
|
208 |
+
|
209 |
+
def get_unweighted_text_embeddings(
|
210 |
+
pipe: StableDiffusionPipeline,
|
211 |
+
text_input: torch.Tensor,
|
212 |
+
chunk_length: int,
|
213 |
+
clip_skip: int,
|
214 |
+
eos: int,
|
215 |
+
pad: int,
|
216 |
+
no_boseos_middle: Optional[bool] = True,
|
217 |
+
):
|
218 |
+
"""
|
219 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
220 |
+
it should be split into chunks and sent to the text encoder individually.
|
221 |
+
"""
|
222 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
223 |
+
if max_embeddings_multiples > 1:
|
224 |
+
text_embeddings = []
|
225 |
+
for i in range(max_embeddings_multiples):
|
226 |
+
# extract the i-th chunk
|
227 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
228 |
+
|
229 |
+
# cover the head and the tail by the starting and the ending tokens
|
230 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
231 |
+
if pad == eos: # v1
|
232 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
233 |
+
else: # v2
|
234 |
+
for j in range(len(text_input_chunk)):
|
235 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
236 |
+
text_input_chunk[j, -1] = eos
|
237 |
+
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
238 |
+
text_input_chunk[j, 1] = eos
|
239 |
+
|
240 |
+
if clip_skip is None or clip_skip == 1:
|
241 |
+
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
242 |
+
else:
|
243 |
+
enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
244 |
+
text_embedding = enc_out["hidden_states"][-clip_skip]
|
245 |
+
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
|
246 |
+
|
247 |
+
if no_boseos_middle:
|
248 |
+
if i == 0:
|
249 |
+
# discard the ending token
|
250 |
+
text_embedding = text_embedding[:, :-1]
|
251 |
+
elif i == max_embeddings_multiples - 1:
|
252 |
+
# discard the starting token
|
253 |
+
text_embedding = text_embedding[:, 1:]
|
254 |
+
else:
|
255 |
+
# discard both starting and ending tokens
|
256 |
+
text_embedding = text_embedding[:, 1:-1]
|
257 |
+
|
258 |
+
text_embeddings.append(text_embedding)
|
259 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
260 |
+
else:
|
261 |
+
if clip_skip is None or clip_skip == 1:
|
262 |
+
text_embeddings = pipe.text_encoder(text_input)[0]
|
263 |
+
else:
|
264 |
+
enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
265 |
+
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
266 |
+
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
|
267 |
+
return text_embeddings
|
268 |
+
|
269 |
+
|
270 |
+
def get_weighted_text_embeddings(
|
271 |
+
pipe: StableDiffusionPipeline,
|
272 |
+
prompt: Union[str, List[str]],
|
273 |
+
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
274 |
+
max_embeddings_multiples: Optional[int] = 3,
|
275 |
+
no_boseos_middle: Optional[bool] = False,
|
276 |
+
skip_parsing: Optional[bool] = False,
|
277 |
+
skip_weighting: Optional[bool] = False,
|
278 |
+
clip_skip=None,
|
279 |
+
):
|
280 |
+
r"""
|
281 |
+
Prompts can be assigned with local weights using brackets. For example,
|
282 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
283 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
284 |
+
|
285 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
pipe (`StableDiffusionPipeline`):
|
289 |
+
Pipe to provide access to the tokenizer and the text encoder.
|
290 |
+
prompt (`str` or `List[str]`):
|
291 |
+
The prompt or prompts to guide the image generation.
|
292 |
+
uncond_prompt (`str` or `List[str]`):
|
293 |
+
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
|
294 |
+
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
295 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
296 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
297 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
298 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
299 |
+
ending token in each of the chunk in the middle.
|
300 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
301 |
+
Skip the parsing of brackets.
|
302 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
303 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
304 |
+
"""
|
305 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
306 |
+
if isinstance(prompt, str):
|
307 |
+
prompt = [prompt]
|
308 |
+
|
309 |
+
if not skip_parsing:
|
310 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
|
311 |
+
if uncond_prompt is not None:
|
312 |
+
if isinstance(uncond_prompt, str):
|
313 |
+
uncond_prompt = [uncond_prompt]
|
314 |
+
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
315 |
+
else:
|
316 |
+
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
|
317 |
+
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
318 |
+
if uncond_prompt is not None:
|
319 |
+
if isinstance(uncond_prompt, str):
|
320 |
+
uncond_prompt = [uncond_prompt]
|
321 |
+
uncond_tokens = [
|
322 |
+
token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
|
323 |
+
]
|
324 |
+
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
325 |
+
|
326 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
327 |
+
max_length = max([len(token) for token in prompt_tokens])
|
328 |
+
if uncond_prompt is not None:
|
329 |
+
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
330 |
+
|
331 |
+
max_embeddings_multiples = min(
|
332 |
+
max_embeddings_multiples,
|
333 |
+
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
|
334 |
+
)
|
335 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
336 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
337 |
+
|
338 |
+
# pad the length of tokens and weights
|
339 |
+
bos = pipe.tokenizer.bos_token_id
|
340 |
+
eos = pipe.tokenizer.eos_token_id
|
341 |
+
pad = pipe.tokenizer.pad_token_id
|
342 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
343 |
+
prompt_tokens,
|
344 |
+
prompt_weights,
|
345 |
+
max_length,
|
346 |
+
bos,
|
347 |
+
eos,
|
348 |
+
no_boseos_middle=no_boseos_middle,
|
349 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
350 |
+
)
|
351 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
|
352 |
+
if uncond_prompt is not None:
|
353 |
+
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
354 |
+
uncond_tokens,
|
355 |
+
uncond_weights,
|
356 |
+
max_length,
|
357 |
+
bos,
|
358 |
+
eos,
|
359 |
+
no_boseos_middle=no_boseos_middle,
|
360 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
361 |
+
)
|
362 |
+
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
|
363 |
+
|
364 |
+
# get the embeddings
|
365 |
+
text_embeddings = get_unweighted_text_embeddings(
|
366 |
+
pipe,
|
367 |
+
prompt_tokens,
|
368 |
+
pipe.tokenizer.model_max_length,
|
369 |
+
clip_skip,
|
370 |
+
eos,
|
371 |
+
pad,
|
372 |
+
no_boseos_middle=no_boseos_middle,
|
373 |
+
)
|
374 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
|
375 |
+
if uncond_prompt is not None:
|
376 |
+
uncond_embeddings = get_unweighted_text_embeddings(
|
377 |
+
pipe,
|
378 |
+
uncond_tokens,
|
379 |
+
pipe.tokenizer.model_max_length,
|
380 |
+
clip_skip,
|
381 |
+
eos,
|
382 |
+
pad,
|
383 |
+
no_boseos_middle=no_boseos_middle,
|
384 |
+
)
|
385 |
+
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
|
386 |
+
|
387 |
+
# assign weights to the prompts and normalize in the sense of mean
|
388 |
+
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
389 |
+
if (not skip_parsing) and (not skip_weighting):
|
390 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
391 |
+
text_embeddings *= prompt_weights.unsqueeze(-1)
|
392 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
393 |
+
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
394 |
+
if uncond_prompt is not None:
|
395 |
+
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
396 |
+
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
397 |
+
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
398 |
+
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
399 |
+
|
400 |
+
if uncond_prompt is not None:
|
401 |
+
return text_embeddings, uncond_embeddings
|
402 |
+
return text_embeddings, None
|
403 |
+
|
404 |
+
|
405 |
+
def preprocess_image(image):
|
406 |
+
w, h = image.size
|
407 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
408 |
+
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
409 |
+
image = np.array(image).astype(np.float32) / 255.0
|
410 |
+
image = image[None].transpose(0, 3, 1, 2)
|
411 |
+
image = torch.from_numpy(image)
|
412 |
+
return 2.0 * image - 1.0
|
413 |
+
|
414 |
+
|
415 |
+
def preprocess_mask(mask, scale_factor=8):
|
416 |
+
mask = mask.convert("L")
|
417 |
+
w, h = mask.size
|
418 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
419 |
+
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
420 |
+
mask = np.array(mask).astype(np.float32) / 255.0
|
421 |
+
mask = np.tile(mask, (4, 1, 1))
|
422 |
+
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
423 |
+
mask = 1 - mask # repaint white, keep black
|
424 |
+
mask = torch.from_numpy(mask)
|
425 |
+
return mask
|
426 |
+
|
427 |
+
|
428 |
+
def prepare_controlnet_image(
|
429 |
+
image: PIL.Image.Image,
|
430 |
+
width: int,
|
431 |
+
height: int,
|
432 |
+
batch_size: int,
|
433 |
+
num_images_per_prompt: int,
|
434 |
+
device: torch.device,
|
435 |
+
dtype: torch.dtype,
|
436 |
+
do_classifier_free_guidance: bool = False,
|
437 |
+
guess_mode: bool = False,
|
438 |
+
):
|
439 |
+
if not isinstance(image, torch.Tensor):
|
440 |
+
if isinstance(image, PIL.Image.Image):
|
441 |
+
image = [image]
|
442 |
+
|
443 |
+
if isinstance(image[0], PIL.Image.Image):
|
444 |
+
images = []
|
445 |
+
|
446 |
+
for image_ in image:
|
447 |
+
image_ = image_.convert("RGB")
|
448 |
+
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
449 |
+
image_ = np.array(image_)
|
450 |
+
image_ = image_[None, :]
|
451 |
+
images.append(image_)
|
452 |
+
|
453 |
+
image = images
|
454 |
+
|
455 |
+
image = np.concatenate(image, axis=0)
|
456 |
+
image = np.array(image).astype(np.float32) / 255.0
|
457 |
+
image = image.transpose(0, 3, 1, 2)
|
458 |
+
image = torch.from_numpy(image)
|
459 |
+
elif isinstance(image[0], torch.Tensor):
|
460 |
+
image = torch.cat(image, dim=0)
|
461 |
+
|
462 |
+
image_batch_size = image.shape[0]
|
463 |
+
|
464 |
+
if image_batch_size == 1:
|
465 |
+
repeat_by = batch_size
|
466 |
+
else:
|
467 |
+
# image batch size is the same as prompt batch size
|
468 |
+
repeat_by = num_images_per_prompt
|
469 |
+
|
470 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
471 |
+
|
472 |
+
image = image.to(device=device, dtype=dtype)
|
473 |
+
|
474 |
+
if do_classifier_free_guidance and not guess_mode:
|
475 |
+
image = torch.cat([image] * 2)
|
476 |
+
|
477 |
+
return image
|
478 |
+
|
479 |
+
|
480 |
+
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
481 |
+
r"""
|
482 |
+
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
483 |
+
weighting in prompt.
|
484 |
+
|
485 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
486 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
487 |
+
|
488 |
+
Args:
|
489 |
+
vae ([`AutoencoderKL`]):
|
490 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
491 |
+
text_encoder ([`CLIPTextModel`]):
|
492 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
493 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
494 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
495 |
+
tokenizer (`CLIPTokenizer`):
|
496 |
+
Tokenizer of class
|
497 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
498 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
499 |
+
scheduler ([`SchedulerMixin`]):
|
500 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
501 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
502 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
503 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
504 |
+
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
505 |
+
feature_extractor ([`CLIPFeatureExtractor`]):
|
506 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
507 |
+
"""
|
508 |
+
|
509 |
+
# if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
510 |
+
|
511 |
+
def __init__(
|
512 |
+
self,
|
513 |
+
vae: AutoencoderKL,
|
514 |
+
text_encoder: CLIPTextModel,
|
515 |
+
tokenizer: CLIPTokenizer,
|
516 |
+
unet: UNet2DConditionModel,
|
517 |
+
scheduler: SchedulerMixin,
|
518 |
+
# clip_skip: int,
|
519 |
+
safety_checker: StableDiffusionSafetyChecker,
|
520 |
+
feature_extractor: CLIPFeatureExtractor,
|
521 |
+
requires_safety_checker: bool = True,
|
522 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
523 |
+
clip_skip: int = 1,
|
524 |
+
):
|
525 |
+
super().__init__(
|
526 |
+
vae=vae,
|
527 |
+
text_encoder=text_encoder,
|
528 |
+
tokenizer=tokenizer,
|
529 |
+
unet=unet,
|
530 |
+
scheduler=scheduler,
|
531 |
+
safety_checker=safety_checker,
|
532 |
+
feature_extractor=feature_extractor,
|
533 |
+
requires_safety_checker=requires_safety_checker,
|
534 |
+
image_encoder=image_encoder,
|
535 |
+
)
|
536 |
+
self.custom_clip_skip = clip_skip
|
537 |
+
self.__init__additional__()
|
538 |
+
|
539 |
+
def __init__additional__(self):
|
540 |
+
if not hasattr(self, "vae_scale_factor"):
|
541 |
+
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
542 |
+
|
543 |
+
@property
|
544 |
+
def _execution_device(self):
|
545 |
+
r"""
|
546 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
547 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
548 |
+
hooks.
|
549 |
+
"""
|
550 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
551 |
+
return self.device
|
552 |
+
for module in self.unet.modules():
|
553 |
+
if (
|
554 |
+
hasattr(module, "_hf_hook")
|
555 |
+
and hasattr(module._hf_hook, "execution_device")
|
556 |
+
and module._hf_hook.execution_device is not None
|
557 |
+
):
|
558 |
+
return torch.device(module._hf_hook.execution_device)
|
559 |
+
return self.device
|
560 |
+
|
561 |
+
def _encode_prompt(
|
562 |
+
self,
|
563 |
+
prompt,
|
564 |
+
device,
|
565 |
+
num_images_per_prompt,
|
566 |
+
do_classifier_free_guidance,
|
567 |
+
negative_prompt,
|
568 |
+
max_embeddings_multiples,
|
569 |
+
):
|
570 |
+
r"""
|
571 |
+
Encodes the prompt into text encoder hidden states.
|
572 |
+
|
573 |
+
Args:
|
574 |
+
prompt (`str` or `list(int)`):
|
575 |
+
prompt to be encoded
|
576 |
+
device: (`torch.device`):
|
577 |
+
torch device
|
578 |
+
num_images_per_prompt (`int`):
|
579 |
+
number of images that should be generated per prompt
|
580 |
+
do_classifier_free_guidance (`bool`):
|
581 |
+
whether to use classifier free guidance or not
|
582 |
+
negative_prompt (`str` or `List[str]`):
|
583 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
584 |
+
if `guidance_scale` is less than `1`).
|
585 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
586 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
587 |
+
"""
|
588 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
589 |
+
|
590 |
+
if negative_prompt is None:
|
591 |
+
negative_prompt = [""] * batch_size
|
592 |
+
elif isinstance(negative_prompt, str):
|
593 |
+
negative_prompt = [negative_prompt] * batch_size
|
594 |
+
if batch_size != len(negative_prompt):
|
595 |
+
raise ValueError(
|
596 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
597 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
598 |
+
" the batch size of `prompt`."
|
599 |
+
)
|
600 |
+
|
601 |
+
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
|
602 |
+
pipe=self,
|
603 |
+
prompt=prompt,
|
604 |
+
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
605 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
606 |
+
clip_skip=self.custom_clip_skip,
|
607 |
+
)
|
608 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
609 |
+
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
610 |
+
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
611 |
+
|
612 |
+
if do_classifier_free_guidance:
|
613 |
+
bs_embed, seq_len, _ = uncond_embeddings.shape
|
614 |
+
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
615 |
+
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
616 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
617 |
+
|
618 |
+
return text_embeddings
|
619 |
+
|
620 |
+
def check_inputs(self, prompt, height, width, strength, callback_steps):
|
621 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
622 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
623 |
+
|
624 |
+
if strength < 0 or strength > 1:
|
625 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
626 |
+
|
627 |
+
if height % 8 != 0 or width % 8 != 0:
|
628 |
+
logger.info(f'{height} {width}')
|
629 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
630 |
+
|
631 |
+
if (callback_steps is None) or (
|
632 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
633 |
+
):
|
634 |
+
raise ValueError(
|
635 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
|
636 |
+
)
|
637 |
+
|
638 |
+
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
|
639 |
+
if is_text2img:
|
640 |
+
return self.scheduler.timesteps.to(device), num_inference_steps
|
641 |
+
else:
|
642 |
+
# get the original timestep using init_timestep
|
643 |
+
offset = self.scheduler.config.get("steps_offset", 0)
|
644 |
+
init_timestep = int(num_inference_steps * strength) + offset
|
645 |
+
init_timestep = min(init_timestep, num_inference_steps)
|
646 |
+
|
647 |
+
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
648 |
+
timesteps = self.scheduler.timesteps[t_start:].to(device)
|
649 |
+
return timesteps, num_inference_steps - t_start
|
650 |
+
|
651 |
+
def run_safety_checker(self, image, device, dtype):
|
652 |
+
if self.safety_checker is not None:
|
653 |
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
654 |
+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
|
655 |
+
else:
|
656 |
+
has_nsfw_concept = None
|
657 |
+
return image, has_nsfw_concept
|
658 |
+
|
659 |
+
def decode_latents(self, latents):
|
660 |
+
latents = 1 / 0.18215 * latents
|
661 |
+
image = self.vae.decode(latents).sample
|
662 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
663 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
664 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
665 |
+
return image
|
666 |
+
|
667 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
668 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
669 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
670 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
671 |
+
# and should be between [0, 1]
|
672 |
+
|
673 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
674 |
+
extra_step_kwargs = {}
|
675 |
+
if accepts_eta:
|
676 |
+
extra_step_kwargs["eta"] = eta
|
677 |
+
|
678 |
+
# check if the scheduler accepts generator
|
679 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
680 |
+
if accepts_generator:
|
681 |
+
extra_step_kwargs["generator"] = generator
|
682 |
+
return extra_step_kwargs
|
683 |
+
|
684 |
+
def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
|
685 |
+
if image is None:
|
686 |
+
shape = (
|
687 |
+
batch_size,
|
688 |
+
self.unet.in_channels,
|
689 |
+
height // self.vae_scale_factor,
|
690 |
+
width // self.vae_scale_factor,
|
691 |
+
)
|
692 |
+
|
693 |
+
if latents is None:
|
694 |
+
if device.type == "mps":
|
695 |
+
# randn does not work reproducibly on mps
|
696 |
+
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
697 |
+
else:
|
698 |
+
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
699 |
+
else:
|
700 |
+
if latents.shape != shape:
|
701 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
702 |
+
latents = latents.to(device)
|
703 |
+
|
704 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
705 |
+
latents = latents * self.scheduler.init_noise_sigma
|
706 |
+
return latents, None, None
|
707 |
+
else:
|
708 |
+
init_latent_dist = self.vae.encode(image).latent_dist
|
709 |
+
init_latents = init_latent_dist.sample(generator=generator)
|
710 |
+
init_latents = 0.18215 * init_latents
|
711 |
+
init_latents = torch.cat([init_latents] * batch_size, dim=0)
|
712 |
+
init_latents_orig = init_latents
|
713 |
+
shape = init_latents.shape
|
714 |
+
|
715 |
+
# add noise to latents using the timesteps
|
716 |
+
if device.type == "mps":
|
717 |
+
noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
718 |
+
else:
|
719 |
+
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
720 |
+
latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
721 |
+
return latents, init_latents_orig, noise
|
722 |
+
|
723 |
+
@torch.no_grad()
|
724 |
+
def __call__(
|
725 |
+
self,
|
726 |
+
prompt: Union[str, List[str]],
|
727 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
728 |
+
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
729 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
730 |
+
height: int = 512,
|
731 |
+
width: int = 512,
|
732 |
+
num_inference_steps: int = 50,
|
733 |
+
guidance_scale: float = 7.5,
|
734 |
+
strength: float = 0.8,
|
735 |
+
num_images_per_prompt: Optional[int] = 1,
|
736 |
+
eta: float = 0.0,
|
737 |
+
generator: Optional[torch.Generator] = None,
|
738 |
+
latents: Optional[torch.FloatTensor] = None,
|
739 |
+
max_embeddings_multiples: Optional[int] = 3,
|
740 |
+
output_type: Optional[str] = "pil",
|
741 |
+
return_dict: bool = True,
|
742 |
+
controlnet=None,
|
743 |
+
controlnet_image=None,
|
744 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
745 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
746 |
+
callback_steps: int = 1,
|
747 |
+
):
|
748 |
+
r"""
|
749 |
+
Function invoked when calling the pipeline for generation.
|
750 |
+
|
751 |
+
Args:
|
752 |
+
prompt (`str` or `List[str]`):
|
753 |
+
The prompt or prompts to guide the image generation.
|
754 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
755 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
756 |
+
if `guidance_scale` is less than `1`).
|
757 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
758 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
759 |
+
process.
|
760 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
761 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
762 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
763 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
764 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
765 |
+
height (`int`, *optional*, defaults to 512):
|
766 |
+
The height in pixels of the generated image.
|
767 |
+
width (`int`, *optional*, defaults to 512):
|
768 |
+
The width in pixels of the generated image.
|
769 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
770 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
771 |
+
expense of slower inference.
|
772 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
773 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
774 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
775 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
776 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
777 |
+
usually at the expense of lower image quality.
|
778 |
+
strength (`float`, *optional*, defaults to 0.8):
|
779 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
780 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
781 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
782 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
783 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
784 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
785 |
+
The number of images to generate per prompt.
|
786 |
+
eta (`float`, *optional*, defaults to 0.0):
|
787 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
788 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
789 |
+
generator (`torch.Generator`, *optional*):
|
790 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
791 |
+
deterministic.
|
792 |
+
latents (`torch.FloatTensor`, *optional*):
|
793 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
794 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
795 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
796 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
797 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
798 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
799 |
+
The output format of the generate image. Choose between
|
800 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
801 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
802 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
803 |
+
plain tuple.
|
804 |
+
controlnet (`diffusers.ControlNetModel`, *optional*):
|
805 |
+
A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
|
806 |
+
controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
|
807 |
+
`Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
|
808 |
+
inference.
|
809 |
+
callback (`Callable`, *optional*):
|
810 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
811 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
812 |
+
is_cancelled_callback (`Callable`, *optional*):
|
813 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
814 |
+
`True`, the inference will be cancelled.
|
815 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
816 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
817 |
+
called at every step.
|
818 |
+
|
819 |
+
Returns:
|
820 |
+
`None` if cancelled by `is_cancelled_callback`,
|
821 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
822 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
823 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
824 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
825 |
+
(nsfw) content, according to the `safety_checker`.
|
826 |
+
"""
|
827 |
+
if controlnet is not None and controlnet_image is None:
|
828 |
+
raise ValueError("controlnet_image must be provided if controlnet is not None.")
|
829 |
+
|
830 |
+
# 0. Default height and width to unet
|
831 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
832 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
833 |
+
|
834 |
+
# 1. Check inputs. Raise error if not correct
|
835 |
+
self.check_inputs(prompt, height, width, strength, callback_steps)
|
836 |
+
|
837 |
+
# 2. Define call parameters
|
838 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
839 |
+
device = self._execution_device
|
840 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
841 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
842 |
+
# corresponds to doing no classifier free guidance.
|
843 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
844 |
+
|
845 |
+
# 3. Encode input prompt
|
846 |
+
text_embeddings = self._encode_prompt(
|
847 |
+
prompt,
|
848 |
+
device,
|
849 |
+
num_images_per_prompt,
|
850 |
+
do_classifier_free_guidance,
|
851 |
+
negative_prompt,
|
852 |
+
max_embeddings_multiples,
|
853 |
+
)
|
854 |
+
dtype = text_embeddings.dtype
|
855 |
+
|
856 |
+
# 4. Preprocess image and mask
|
857 |
+
if isinstance(image, PIL.Image.Image):
|
858 |
+
image = preprocess_image(image)
|
859 |
+
if image is not None:
|
860 |
+
image = image.to(device=self.device, dtype=dtype)
|
861 |
+
if isinstance(mask_image, PIL.Image.Image):
|
862 |
+
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
|
863 |
+
if mask_image is not None:
|
864 |
+
mask = mask_image.to(device=self.device, dtype=dtype)
|
865 |
+
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
|
866 |
+
else:
|
867 |
+
mask = None
|
868 |
+
|
869 |
+
if controlnet_image is not None:
|
870 |
+
controlnet_image = prepare_controlnet_image(
|
871 |
+
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
|
872 |
+
)
|
873 |
+
|
874 |
+
# 5. set timesteps
|
875 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
876 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
|
877 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
878 |
+
|
879 |
+
# 6. Prepare latent variables
|
880 |
+
latents, init_latents_orig, noise = self.prepare_latents(
|
881 |
+
image,
|
882 |
+
latent_timestep,
|
883 |
+
batch_size * num_images_per_prompt,
|
884 |
+
height,
|
885 |
+
width,
|
886 |
+
dtype,
|
887 |
+
device,
|
888 |
+
generator,
|
889 |
+
latents,
|
890 |
+
)
|
891 |
+
|
892 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
893 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
894 |
+
|
895 |
+
# 8. Denoising loop
|
896 |
+
for i, t in enumerate(self.progress_bar(timesteps)):
|
897 |
+
# expand the latents if we are doing classifier free guidance
|
898 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
899 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
900 |
+
|
901 |
+
unet_additional_args = {}
|
902 |
+
if controlnet is not None:
|
903 |
+
down_block_res_samples, mid_block_res_sample = controlnet(
|
904 |
+
latent_model_input,
|
905 |
+
t,
|
906 |
+
encoder_hidden_states=text_embeddings,
|
907 |
+
controlnet_cond=controlnet_image,
|
908 |
+
conditioning_scale=1.0,
|
909 |
+
guess_mode=False,
|
910 |
+
return_dict=False,
|
911 |
+
)
|
912 |
+
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
|
913 |
+
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
|
914 |
+
|
915 |
+
# predict the noise residual
|
916 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
|
917 |
+
|
918 |
+
# perform guidance
|
919 |
+
if do_classifier_free_guidance:
|
920 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
921 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
922 |
+
|
923 |
+
# compute the previous noisy sample x_t -> x_t-1
|
924 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
925 |
+
|
926 |
+
if mask is not None:
|
927 |
+
# masking
|
928 |
+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
929 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
930 |
+
|
931 |
+
# call the callback, if provided
|
932 |
+
if i % callback_steps == 0:
|
933 |
+
if callback is not None:
|
934 |
+
callback(i, t, latents)
|
935 |
+
if is_cancelled_callback is not None and is_cancelled_callback():
|
936 |
+
return None
|
937 |
+
|
938 |
+
return latents
|
939 |
+
|
940 |
+
def latents_to_image(self, latents):
|
941 |
+
# 9. Post-processing
|
942 |
+
image = self.decode_latents(latents.to(self.vae.dtype))
|
943 |
+
image = self.numpy_to_pil(image)
|
944 |
+
return image
|
945 |
+
|
946 |
+
def text2img(
|
947 |
+
self,
|
948 |
+
prompt: Union[str, List[str]],
|
949 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
950 |
+
height: int = 512,
|
951 |
+
width: int = 512,
|
952 |
+
num_inference_steps: int = 50,
|
953 |
+
guidance_scale: float = 7.5,
|
954 |
+
num_images_per_prompt: Optional[int] = 1,
|
955 |
+
eta: float = 0.0,
|
956 |
+
generator: Optional[torch.Generator] = None,
|
957 |
+
latents: Optional[torch.FloatTensor] = None,
|
958 |
+
max_embeddings_multiples: Optional[int] = 3,
|
959 |
+
output_type: Optional[str] = "pil",
|
960 |
+
return_dict: bool = True,
|
961 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
962 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
963 |
+
callback_steps: int = 1,
|
964 |
+
):
|
965 |
+
r"""
|
966 |
+
Function for text-to-image generation.
|
967 |
+
Args:
|
968 |
+
prompt (`str` or `List[str]`):
|
969 |
+
The prompt or prompts to guide the image generation.
|
970 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
971 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
972 |
+
if `guidance_scale` is less than `1`).
|
973 |
+
height (`int`, *optional*, defaults to 512):
|
974 |
+
The height in pixels of the generated image.
|
975 |
+
width (`int`, *optional*, defaults to 512):
|
976 |
+
The width in pixels of the generated image.
|
977 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
978 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
979 |
+
expense of slower inference.
|
980 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
981 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
982 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
983 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
984 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
985 |
+
usually at the expense of lower image quality.
|
986 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
987 |
+
The number of images to generate per prompt.
|
988 |
+
eta (`float`, *optional*, defaults to 0.0):
|
989 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
990 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
991 |
+
generator (`torch.Generator`, *optional*):
|
992 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
993 |
+
deterministic.
|
994 |
+
latents (`torch.FloatTensor`, *optional*):
|
995 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
996 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
997 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
998 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
999 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1000 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1001 |
+
The output format of the generate image. Choose between
|
1002 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1003 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1004 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1005 |
+
plain tuple.
|
1006 |
+
callback (`Callable`, *optional*):
|
1007 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1008 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1009 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1010 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1011 |
+
`True`, the inference will be cancelled.
|
1012 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1013 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1014 |
+
called at every step.
|
1015 |
+
Returns:
|
1016 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1017 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1018 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1019 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1020 |
+
(nsfw) content, according to the `safety_checker`.
|
1021 |
+
"""
|
1022 |
+
return self.__call__(
|
1023 |
+
prompt=prompt,
|
1024 |
+
negative_prompt=negative_prompt,
|
1025 |
+
height=height,
|
1026 |
+
width=width,
|
1027 |
+
num_inference_steps=num_inference_steps,
|
1028 |
+
guidance_scale=guidance_scale,
|
1029 |
+
num_images_per_prompt=num_images_per_prompt,
|
1030 |
+
eta=eta,
|
1031 |
+
generator=generator,
|
1032 |
+
latents=latents,
|
1033 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1034 |
+
output_type=output_type,
|
1035 |
+
return_dict=return_dict,
|
1036 |
+
callback=callback,
|
1037 |
+
is_cancelled_callback=is_cancelled_callback,
|
1038 |
+
callback_steps=callback_steps,
|
1039 |
+
)
|
1040 |
+
|
1041 |
+
def img2img(
|
1042 |
+
self,
|
1043 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1044 |
+
prompt: Union[str, List[str]],
|
1045 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1046 |
+
strength: float = 0.8,
|
1047 |
+
num_inference_steps: Optional[int] = 50,
|
1048 |
+
guidance_scale: Optional[float] = 7.5,
|
1049 |
+
num_images_per_prompt: Optional[int] = 1,
|
1050 |
+
eta: Optional[float] = 0.0,
|
1051 |
+
generator: Optional[torch.Generator] = None,
|
1052 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1053 |
+
output_type: Optional[str] = "pil",
|
1054 |
+
return_dict: bool = True,
|
1055 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1056 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1057 |
+
callback_steps: int = 1,
|
1058 |
+
):
|
1059 |
+
r"""
|
1060 |
+
Function for image-to-image generation.
|
1061 |
+
Args:
|
1062 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1063 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1064 |
+
process.
|
1065 |
+
prompt (`str` or `List[str]`):
|
1066 |
+
The prompt or prompts to guide the image generation.
|
1067 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1068 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1069 |
+
if `guidance_scale` is less than `1`).
|
1070 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1071 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
1072 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
1073 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
1074 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
1075 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
1076 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1077 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1078 |
+
expense of slower inference. This parameter will be modulated by `strength`.
|
1079 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1080 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1081 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1082 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1083 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1084 |
+
usually at the expense of lower image quality.
|
1085 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1086 |
+
The number of images to generate per prompt.
|
1087 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1088 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1089 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1090 |
+
generator (`torch.Generator`, *optional*):
|
1091 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1092 |
+
deterministic.
|
1093 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1094 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1095 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1096 |
+
The output format of the generate image. Choose between
|
1097 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1098 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1099 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1100 |
+
plain tuple.
|
1101 |
+
callback (`Callable`, *optional*):
|
1102 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1103 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1104 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1105 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1106 |
+
`True`, the inference will be cancelled.
|
1107 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1108 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1109 |
+
called at every step.
|
1110 |
+
Returns:
|
1111 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1112 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1113 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1114 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1115 |
+
(nsfw) content, according to the `safety_checker`.
|
1116 |
+
"""
|
1117 |
+
return self.__call__(
|
1118 |
+
prompt=prompt,
|
1119 |
+
negative_prompt=negative_prompt,
|
1120 |
+
image=image,
|
1121 |
+
num_inference_steps=num_inference_steps,
|
1122 |
+
guidance_scale=guidance_scale,
|
1123 |
+
strength=strength,
|
1124 |
+
num_images_per_prompt=num_images_per_prompt,
|
1125 |
+
eta=eta,
|
1126 |
+
generator=generator,
|
1127 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1128 |
+
output_type=output_type,
|
1129 |
+
return_dict=return_dict,
|
1130 |
+
callback=callback,
|
1131 |
+
is_cancelled_callback=is_cancelled_callback,
|
1132 |
+
callback_steps=callback_steps,
|
1133 |
+
)
|
1134 |
+
|
1135 |
+
def inpaint(
|
1136 |
+
self,
|
1137 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1138 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
1139 |
+
prompt: Union[str, List[str]],
|
1140 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1141 |
+
strength: float = 0.8,
|
1142 |
+
num_inference_steps: Optional[int] = 50,
|
1143 |
+
guidance_scale: Optional[float] = 7.5,
|
1144 |
+
num_images_per_prompt: Optional[int] = 1,
|
1145 |
+
eta: Optional[float] = 0.0,
|
1146 |
+
generator: Optional[torch.Generator] = None,
|
1147 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1148 |
+
output_type: Optional[str] = "pil",
|
1149 |
+
return_dict: bool = True,
|
1150 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1151 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1152 |
+
callback_steps: int = 1,
|
1153 |
+
):
|
1154 |
+
r"""
|
1155 |
+
Function for inpaint.
|
1156 |
+
Args:
|
1157 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1158 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1159 |
+
process. This is the image whose masked region will be inpainted.
|
1160 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1161 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
1162 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
1163 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
1164 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
1165 |
+
prompt (`str` or `List[str]`):
|
1166 |
+
The prompt or prompts to guide the image generation.
|
1167 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1168 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1169 |
+
if `guidance_scale` is less than `1`).
|
1170 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1171 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
1172 |
+
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
1173 |
+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
|
1174 |
+
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
|
1175 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1176 |
+
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
|
1177 |
+
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
|
1178 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1179 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1180 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1181 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1182 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1183 |
+
usually at the expense of lower image quality.
|
1184 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1185 |
+
The number of images to generate per prompt.
|
1186 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1187 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1188 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1189 |
+
generator (`torch.Generator`, *optional*):
|
1190 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1191 |
+
deterministic.
|
1192 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1193 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1194 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1195 |
+
The output format of the generate image. Choose between
|
1196 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1197 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1198 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1199 |
+
plain tuple.
|
1200 |
+
callback (`Callable`, *optional*):
|
1201 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1202 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1203 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1204 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1205 |
+
`True`, the inference will be cancelled.
|
1206 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1207 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1208 |
+
called at every step.
|
1209 |
+
Returns:
|
1210 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1211 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1212 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1213 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1214 |
+
(nsfw) content, according to the `safety_checker`.
|
1215 |
+
"""
|
1216 |
+
return self.__call__(
|
1217 |
+
prompt=prompt,
|
1218 |
+
negative_prompt=negative_prompt,
|
1219 |
+
image=image,
|
1220 |
+
mask_image=mask_image,
|
1221 |
+
num_inference_steps=num_inference_steps,
|
1222 |
+
guidance_scale=guidance_scale,
|
1223 |
+
strength=strength,
|
1224 |
+
num_images_per_prompt=num_images_per_prompt,
|
1225 |
+
eta=eta,
|
1226 |
+
generator=generator,
|
1227 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1228 |
+
output_type=output_type,
|
1229 |
+
return_dict=return_dict,
|
1230 |
+
callback=callback,
|
1231 |
+
is_cancelled_callback=is_cancelled_callback,
|
1232 |
+
callback_steps=callback_steps,
|
1233 |
+
)
|
library/model_util.py
ADDED
@@ -0,0 +1,1356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# v1: split from train_db_fixed.py.
|
2 |
+
# v2: support safetensors
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from library.device_utils import init_ipex
|
9 |
+
init_ipex()
|
10 |
+
|
11 |
+
import diffusers
|
12 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
13 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
14 |
+
from safetensors.torch import load_file, save_file
|
15 |
+
from library.original_unet import UNet2DConditionModel
|
16 |
+
from library.utils import setup_logging
|
17 |
+
setup_logging()
|
18 |
+
import logging
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
22 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
23 |
+
BETA_START = 0.00085
|
24 |
+
BETA_END = 0.0120
|
25 |
+
|
26 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
27 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
28 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
29 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
30 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
31 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
32 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
33 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
34 |
+
UNET_PARAMS_NUM_HEADS = 8
|
35 |
+
# UNET_PARAMS_USE_LINEAR_PROJECTION = False
|
36 |
+
|
37 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
38 |
+
VAE_PARAMS_RESOLUTION = 256
|
39 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
40 |
+
VAE_PARAMS_OUT_CH = 3
|
41 |
+
VAE_PARAMS_CH = 128
|
42 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
43 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
44 |
+
|
45 |
+
# V2
|
46 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
47 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
48 |
+
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
|
49 |
+
|
50 |
+
# Diffusersの設定を読み込むための参照モデル
|
51 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
52 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
53 |
+
|
54 |
+
|
55 |
+
# region StableDiffusion->Diffusersの変換コード
|
56 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
57 |
+
|
58 |
+
|
59 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
60 |
+
"""
|
61 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
62 |
+
"""
|
63 |
+
if n_shave_prefix_segments >= 0:
|
64 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
65 |
+
else:
|
66 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
67 |
+
|
68 |
+
|
69 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
70 |
+
"""
|
71 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
72 |
+
"""
|
73 |
+
mapping = []
|
74 |
+
for old_item in old_list:
|
75 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
76 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
77 |
+
|
78 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
79 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
80 |
+
|
81 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
82 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
83 |
+
|
84 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
85 |
+
|
86 |
+
mapping.append({"old": old_item, "new": new_item})
|
87 |
+
|
88 |
+
return mapping
|
89 |
+
|
90 |
+
|
91 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
92 |
+
"""
|
93 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
94 |
+
"""
|
95 |
+
mapping = []
|
96 |
+
for old_item in old_list:
|
97 |
+
new_item = old_item
|
98 |
+
|
99 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
100 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
101 |
+
|
102 |
+
mapping.append({"old": old_item, "new": new_item})
|
103 |
+
|
104 |
+
return mapping
|
105 |
+
|
106 |
+
|
107 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
108 |
+
"""
|
109 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
110 |
+
"""
|
111 |
+
mapping = []
|
112 |
+
for old_item in old_list:
|
113 |
+
new_item = old_item
|
114 |
+
|
115 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
116 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
117 |
+
|
118 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
119 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
120 |
+
|
121 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
122 |
+
|
123 |
+
mapping.append({"old": old_item, "new": new_item})
|
124 |
+
|
125 |
+
return mapping
|
126 |
+
|
127 |
+
|
128 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
129 |
+
"""
|
130 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
131 |
+
"""
|
132 |
+
mapping = []
|
133 |
+
for old_item in old_list:
|
134 |
+
new_item = old_item
|
135 |
+
|
136 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
137 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
138 |
+
|
139 |
+
if diffusers.__version__ < "0.17.0":
|
140 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
141 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
142 |
+
|
143 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
144 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
145 |
+
|
146 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
147 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
148 |
+
|
149 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
150 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
151 |
+
else:
|
152 |
+
new_item = new_item.replace("q.weight", "to_q.weight")
|
153 |
+
new_item = new_item.replace("q.bias", "to_q.bias")
|
154 |
+
|
155 |
+
new_item = new_item.replace("k.weight", "to_k.weight")
|
156 |
+
new_item = new_item.replace("k.bias", "to_k.bias")
|
157 |
+
|
158 |
+
new_item = new_item.replace("v.weight", "to_v.weight")
|
159 |
+
new_item = new_item.replace("v.bias", "to_v.bias")
|
160 |
+
|
161 |
+
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
162 |
+
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
163 |
+
|
164 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
165 |
+
|
166 |
+
mapping.append({"old": old_item, "new": new_item})
|
167 |
+
|
168 |
+
return mapping
|
169 |
+
|
170 |
+
|
171 |
+
def assign_to_checkpoint(
|
172 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
173 |
+
):
|
174 |
+
"""
|
175 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
176 |
+
to them. It splits attention layers, and takes into account additional replacements
|
177 |
+
that may arise.
|
178 |
+
|
179 |
+
Assigns the weights to the new checkpoint.
|
180 |
+
"""
|
181 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
182 |
+
|
183 |
+
# Splits the attention layers into three variables.
|
184 |
+
if attention_paths_to_split is not None:
|
185 |
+
for path, path_map in attention_paths_to_split.items():
|
186 |
+
old_tensor = old_checkpoint[path]
|
187 |
+
channels = old_tensor.shape[0] // 3
|
188 |
+
|
189 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
190 |
+
|
191 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
192 |
+
|
193 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
194 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
195 |
+
|
196 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
197 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
198 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
199 |
+
|
200 |
+
for path in paths:
|
201 |
+
new_path = path["new"]
|
202 |
+
|
203 |
+
# These have already been assigned
|
204 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
205 |
+
continue
|
206 |
+
|
207 |
+
# Global renaming happens here
|
208 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
209 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
210 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
211 |
+
|
212 |
+
if additional_replacements is not None:
|
213 |
+
for replacement in additional_replacements:
|
214 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
215 |
+
|
216 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
217 |
+
reshaping = False
|
218 |
+
if diffusers.__version__ < "0.17.0":
|
219 |
+
if "proj_attn.weight" in new_path:
|
220 |
+
reshaping = True
|
221 |
+
else:
|
222 |
+
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
|
223 |
+
reshaping = True
|
224 |
+
|
225 |
+
if reshaping:
|
226 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
227 |
+
else:
|
228 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
229 |
+
|
230 |
+
|
231 |
+
def conv_attn_to_linear(checkpoint):
|
232 |
+
keys = list(checkpoint.keys())
|
233 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
234 |
+
for key in keys:
|
235 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
236 |
+
if checkpoint[key].ndim > 2:
|
237 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
238 |
+
elif "proj_attn.weight" in key:
|
239 |
+
if checkpoint[key].ndim > 2:
|
240 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
241 |
+
|
242 |
+
|
243 |
+
def linear_transformer_to_conv(checkpoint):
|
244 |
+
keys = list(checkpoint.keys())
|
245 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
246 |
+
for key in keys:
|
247 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
248 |
+
if checkpoint[key].ndim == 2:
|
249 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
250 |
+
|
251 |
+
|
252 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
253 |
+
"""
|
254 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
255 |
+
"""
|
256 |
+
|
257 |
+
# extract state_dict for UNet
|
258 |
+
unet_state_dict = {}
|
259 |
+
unet_key = "model.diffusion_model."
|
260 |
+
keys = list(checkpoint.keys())
|
261 |
+
for key in keys:
|
262 |
+
if key.startswith(unet_key):
|
263 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
264 |
+
|
265 |
+
new_checkpoint = {}
|
266 |
+
|
267 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
268 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
269 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
270 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
271 |
+
|
272 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
273 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
274 |
+
|
275 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
276 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
277 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
278 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
279 |
+
|
280 |
+
# Retrieves the keys for the input blocks only
|
281 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
282 |
+
input_blocks = {
|
283 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
|
284 |
+
}
|
285 |
+
|
286 |
+
# Retrieves the keys for the middle blocks only
|
287 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
288 |
+
middle_blocks = {
|
289 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
|
290 |
+
}
|
291 |
+
|
292 |
+
# Retrieves the keys for the output blocks only
|
293 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
294 |
+
output_blocks = {
|
295 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
|
296 |
+
}
|
297 |
+
|
298 |
+
for i in range(1, num_input_blocks):
|
299 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
300 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
301 |
+
|
302 |
+
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
|
303 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
304 |
+
|
305 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
306 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
307 |
+
f"input_blocks.{i}.0.op.weight"
|
308 |
+
)
|
309 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
|
310 |
+
|
311 |
+
paths = renew_resnet_paths(resnets)
|
312 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
313 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
314 |
+
|
315 |
+
if len(attentions):
|
316 |
+
paths = renew_attention_paths(attentions)
|
317 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
318 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
319 |
+
|
320 |
+
resnet_0 = middle_blocks[0]
|
321 |
+
attentions = middle_blocks[1]
|
322 |
+
resnet_1 = middle_blocks[2]
|
323 |
+
|
324 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
325 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
326 |
+
|
327 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
328 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
329 |
+
|
330 |
+
attentions_paths = renew_attention_paths(attentions)
|
331 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
332 |
+
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
333 |
+
|
334 |
+
for i in range(num_output_blocks):
|
335 |
+
block_id = i // (config["layers_per_block"] + 1)
|
336 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
337 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
338 |
+
output_block_list = {}
|
339 |
+
|
340 |
+
for layer in output_block_layers:
|
341 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
342 |
+
if layer_id in output_block_list:
|
343 |
+
output_block_list[layer_id].append(layer_name)
|
344 |
+
else:
|
345 |
+
output_block_list[layer_id] = [layer_name]
|
346 |
+
|
347 |
+
if len(output_block_list) > 1:
|
348 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
349 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
350 |
+
|
351 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
352 |
+
paths = renew_resnet_paths(resnets)
|
353 |
+
|
354 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
355 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
356 |
+
|
357 |
+
# オリジナル:
|
358 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
359 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
360 |
+
|
361 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
362 |
+
for l in output_block_list.values():
|
363 |
+
l.sort()
|
364 |
+
|
365 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
366 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
367 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
368 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
369 |
+
]
|
370 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
371 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
372 |
+
]
|
373 |
+
|
374 |
+
# Clear attentions as they have been attributed above.
|
375 |
+
if len(attentions) == 2:
|
376 |
+
attentions = []
|
377 |
+
|
378 |
+
if len(attentions):
|
379 |
+
paths = renew_attention_paths(attentions)
|
380 |
+
meta_path = {
|
381 |
+
"old": f"output_blocks.{i}.1",
|
382 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
383 |
+
}
|
384 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
385 |
+
else:
|
386 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
387 |
+
for path in resnet_0_paths:
|
388 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
389 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
390 |
+
|
391 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
392 |
+
|
393 |
+
# SDのv2では1*1のconv2dがlinearに変わっている
|
394 |
+
# 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
|
395 |
+
if v2 and not config.get("use_linear_projection", False):
|
396 |
+
linear_transformer_to_conv(new_checkpoint)
|
397 |
+
|
398 |
+
return new_checkpoint
|
399 |
+
|
400 |
+
|
401 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
402 |
+
# extract state dict for VAE
|
403 |
+
vae_state_dict = {}
|
404 |
+
vae_key = "first_stage_model."
|
405 |
+
keys = list(checkpoint.keys())
|
406 |
+
for key in keys:
|
407 |
+
if key.startswith(vae_key):
|
408 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
409 |
+
# if len(vae_state_dict) == 0:
|
410 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
411 |
+
# vae_state_dict = checkpoint
|
412 |
+
|
413 |
+
new_checkpoint = {}
|
414 |
+
|
415 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
416 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
417 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
418 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
419 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
420 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
421 |
+
|
422 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
423 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
424 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
425 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
426 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
427 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
428 |
+
|
429 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
430 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
431 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
432 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
433 |
+
|
434 |
+
# Retrieves the keys for the encoder down blocks only
|
435 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
436 |
+
down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
|
437 |
+
|
438 |
+
# Retrieves the keys for the decoder up blocks only
|
439 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
440 |
+
up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
441 |
+
|
442 |
+
for i in range(num_down_blocks):
|
443 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
444 |
+
|
445 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
446 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
447 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
448 |
+
)
|
449 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
450 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
451 |
+
)
|
452 |
+
|
453 |
+
paths = renew_vae_resnet_paths(resnets)
|
454 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
455 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
456 |
+
|
457 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
458 |
+
num_mid_res_blocks = 2
|
459 |
+
for i in range(1, num_mid_res_blocks + 1):
|
460 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
461 |
+
|
462 |
+
paths = renew_vae_resnet_paths(resnets)
|
463 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
464 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
465 |
+
|
466 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
467 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
468 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
469 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
470 |
+
conv_attn_to_linear(new_checkpoint)
|
471 |
+
|
472 |
+
for i in range(num_up_blocks):
|
473 |
+
block_id = num_up_blocks - 1 - i
|
474 |
+
resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
|
475 |
+
|
476 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
477 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
478 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
479 |
+
]
|
480 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
481 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
482 |
+
]
|
483 |
+
|
484 |
+
paths = renew_vae_resnet_paths(resnets)
|
485 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
486 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
487 |
+
|
488 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
489 |
+
num_mid_res_blocks = 2
|
490 |
+
for i in range(1, num_mid_res_blocks + 1):
|
491 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
492 |
+
|
493 |
+
paths = renew_vae_resnet_paths(resnets)
|
494 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
495 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
496 |
+
|
497 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
498 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
499 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
500 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
501 |
+
conv_attn_to_linear(new_checkpoint)
|
502 |
+
return new_checkpoint
|
503 |
+
|
504 |
+
|
505 |
+
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
|
506 |
+
"""
|
507 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
508 |
+
"""
|
509 |
+
# unet_params = original_config.model.params.unet_config.params
|
510 |
+
|
511 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
512 |
+
|
513 |
+
down_block_types = []
|
514 |
+
resolution = 1
|
515 |
+
for i in range(len(block_out_channels)):
|
516 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
517 |
+
down_block_types.append(block_type)
|
518 |
+
if i != len(block_out_channels) - 1:
|
519 |
+
resolution *= 2
|
520 |
+
|
521 |
+
up_block_types = []
|
522 |
+
for i in range(len(block_out_channels)):
|
523 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
524 |
+
up_block_types.append(block_type)
|
525 |
+
resolution //= 2
|
526 |
+
|
527 |
+
config = dict(
|
528 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
529 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
530 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
531 |
+
down_block_types=tuple(down_block_types),
|
532 |
+
up_block_types=tuple(up_block_types),
|
533 |
+
block_out_channels=tuple(block_out_channels),
|
534 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
535 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
536 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
537 |
+
# use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
|
538 |
+
)
|
539 |
+
if v2 and use_linear_projection_in_v2:
|
540 |
+
config["use_linear_projection"] = True
|
541 |
+
|
542 |
+
return config
|
543 |
+
|
544 |
+
|
545 |
+
def create_vae_diffusers_config():
|
546 |
+
"""
|
547 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
548 |
+
"""
|
549 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
550 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
551 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
552 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
553 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
554 |
+
|
555 |
+
config = dict(
|
556 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
557 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
558 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
559 |
+
down_block_types=tuple(down_block_types),
|
560 |
+
up_block_types=tuple(up_block_types),
|
561 |
+
block_out_channels=tuple(block_out_channels),
|
562 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
563 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
564 |
+
)
|
565 |
+
return config
|
566 |
+
|
567 |
+
|
568 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
569 |
+
keys = list(checkpoint.keys())
|
570 |
+
text_model_dict = {}
|
571 |
+
for key in keys:
|
572 |
+
if key.startswith("cond_stage_model.transformer"):
|
573 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
574 |
+
|
575 |
+
# remove position_ids for newer transformer, which causes error :(
|
576 |
+
if "text_model.embeddings.position_ids" in text_model_dict:
|
577 |
+
text_model_dict.pop("text_model.embeddings.position_ids")
|
578 |
+
|
579 |
+
return text_model_dict
|
580 |
+
|
581 |
+
|
582 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
583 |
+
# 嫌になるくらい違うぞ!
|
584 |
+
def convert_key(key):
|
585 |
+
if not key.startswith("cond_stage_model"):
|
586 |
+
return None
|
587 |
+
|
588 |
+
# common conversion
|
589 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
590 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
591 |
+
|
592 |
+
if "resblocks" in key:
|
593 |
+
# resblocks conversion
|
594 |
+
key = key.replace(".resblocks.", ".layers.")
|
595 |
+
if ".ln_" in key:
|
596 |
+
key = key.replace(".ln_", ".layer_norm")
|
597 |
+
elif ".mlp." in key:
|
598 |
+
key = key.replace(".c_fc.", ".fc1.")
|
599 |
+
key = key.replace(".c_proj.", ".fc2.")
|
600 |
+
elif ".attn.out_proj" in key:
|
601 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
602 |
+
elif ".attn.in_proj" in key:
|
603 |
+
key = None # 特殊なので後で処理する
|
604 |
+
else:
|
605 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
606 |
+
elif ".positional_embedding" in key:
|
607 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
608 |
+
elif ".text_projection" in key:
|
609 |
+
key = None # 使われない???
|
610 |
+
elif ".logit_scale" in key:
|
611 |
+
key = None # 使われない???
|
612 |
+
elif ".token_embedding" in key:
|
613 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
614 |
+
elif ".ln_final" in key:
|
615 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
616 |
+
return key
|
617 |
+
|
618 |
+
keys = list(checkpoint.keys())
|
619 |
+
new_sd = {}
|
620 |
+
for key in keys:
|
621 |
+
# remove resblocks 23
|
622 |
+
if ".resblocks.23." in key:
|
623 |
+
continue
|
624 |
+
new_key = convert_key(key)
|
625 |
+
if new_key is None:
|
626 |
+
continue
|
627 |
+
new_sd[new_key] = checkpoint[key]
|
628 |
+
|
629 |
+
# attnの変換
|
630 |
+
for key in keys:
|
631 |
+
if ".resblocks.23." in key:
|
632 |
+
continue
|
633 |
+
if ".resblocks" in key and ".attn.in_proj_" in key:
|
634 |
+
# 三つに分割
|
635 |
+
values = torch.chunk(checkpoint[key], 3)
|
636 |
+
|
637 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
638 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
639 |
+
key_pfx = key_pfx.replace("_weight", "")
|
640 |
+
key_pfx = key_pfx.replace("_bias", "")
|
641 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
642 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
643 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
644 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
645 |
+
|
646 |
+
# rename or add position_ids
|
647 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
648 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
649 |
+
# waifu diffusion v1.4
|
650 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
651 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
652 |
+
else:
|
653 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
654 |
+
|
655 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
656 |
+
return new_sd
|
657 |
+
|
658 |
+
|
659 |
+
# endregion
|
660 |
+
|
661 |
+
|
662 |
+
# region Diffusers->StableDiffusion の変換コード
|
663 |
+
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
664 |
+
|
665 |
+
|
666 |
+
def conv_transformer_to_linear(checkpoint):
|
667 |
+
keys = list(checkpoint.keys())
|
668 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
669 |
+
for key in keys:
|
670 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
671 |
+
if checkpoint[key].ndim > 2:
|
672 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
673 |
+
|
674 |
+
|
675 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
676 |
+
unet_conversion_map = [
|
677 |
+
# (stable-diffusion, HF Diffusers)
|
678 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
679 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
680 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
681 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
682 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
683 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
684 |
+
("out.0.weight", "conv_norm_out.weight"),
|
685 |
+
("out.0.bias", "conv_norm_out.bias"),
|
686 |
+
("out.2.weight", "conv_out.weight"),
|
687 |
+
("out.2.bias", "conv_out.bias"),
|
688 |
+
]
|
689 |
+
|
690 |
+
unet_conversion_map_resnet = [
|
691 |
+
# (stable-diffusion, HF Diffusers)
|
692 |
+
("in_layers.0", "norm1"),
|
693 |
+
("in_layers.2", "conv1"),
|
694 |
+
("out_layers.0", "norm2"),
|
695 |
+
("out_layers.3", "conv2"),
|
696 |
+
("emb_layers.1", "time_emb_proj"),
|
697 |
+
("skip_connection", "conv_shortcut"),
|
698 |
+
]
|
699 |
+
|
700 |
+
unet_conversion_map_layer = []
|
701 |
+
for i in range(4):
|
702 |
+
# loop over downblocks/upblocks
|
703 |
+
|
704 |
+
for j in range(2):
|
705 |
+
# loop over resnets/attentions for downblocks
|
706 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
707 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
708 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
709 |
+
|
710 |
+
if i < 3:
|
711 |
+
# no attention layers in down_blocks.3
|
712 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
713 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
714 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
715 |
+
|
716 |
+
for j in range(3):
|
717 |
+
# loop over resnets/attentions for upblocks
|
718 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
719 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
720 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
721 |
+
|
722 |
+
if i > 0:
|
723 |
+
# no attention layers in up_blocks.0
|
724 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
725 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
726 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
727 |
+
|
728 |
+
if i < 3:
|
729 |
+
# no downsample in down_blocks.3
|
730 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
731 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
732 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
733 |
+
|
734 |
+
# no upsample in up_blocks.3
|
735 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
736 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
737 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
738 |
+
|
739 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
740 |
+
sd_mid_atn_prefix = "middle_block.1."
|
741 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
742 |
+
|
743 |
+
for j in range(2):
|
744 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
745 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
746 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
747 |
+
|
748 |
+
# buyer beware: this is a *brittle* function,
|
749 |
+
# and correct output requires that all of these pieces interact in
|
750 |
+
# the exact order in which I have arranged them.
|
751 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
752 |
+
for sd_name, hf_name in unet_conversion_map:
|
753 |
+
mapping[hf_name] = sd_name
|
754 |
+
for k, v in mapping.items():
|
755 |
+
if "resnets" in k:
|
756 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
757 |
+
v = v.replace(hf_part, sd_part)
|
758 |
+
mapping[k] = v
|
759 |
+
for k, v in mapping.items():
|
760 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
761 |
+
v = v.replace(hf_part, sd_part)
|
762 |
+
mapping[k] = v
|
763 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
764 |
+
|
765 |
+
if v2:
|
766 |
+
conv_transformer_to_linear(new_state_dict)
|
767 |
+
|
768 |
+
return new_state_dict
|
769 |
+
|
770 |
+
|
771 |
+
def controlnet_conversion_map():
|
772 |
+
unet_conversion_map = [
|
773 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
774 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
775 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
776 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
777 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
778 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
779 |
+
("middle_block_out.0.weight", "controlnet_mid_block.weight"),
|
780 |
+
("middle_block_out.0.bias", "controlnet_mid_block.bias"),
|
781 |
+
]
|
782 |
+
|
783 |
+
unet_conversion_map_resnet = [
|
784 |
+
("in_layers.0", "norm1"),
|
785 |
+
("in_layers.2", "conv1"),
|
786 |
+
("out_layers.0", "norm2"),
|
787 |
+
("out_layers.3", "conv2"),
|
788 |
+
("emb_layers.1", "time_emb_proj"),
|
789 |
+
("skip_connection", "conv_shortcut"),
|
790 |
+
]
|
791 |
+
|
792 |
+
unet_conversion_map_layer = []
|
793 |
+
for i in range(4):
|
794 |
+
for j in range(2):
|
795 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
796 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
797 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
798 |
+
|
799 |
+
if i < 3:
|
800 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
801 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
802 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
803 |
+
|
804 |
+
if i < 3:
|
805 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
806 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
807 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
808 |
+
|
809 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
810 |
+
sd_mid_atn_prefix = "middle_block.1."
|
811 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
812 |
+
|
813 |
+
for j in range(2):
|
814 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
815 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
816 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
817 |
+
|
818 |
+
controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
|
819 |
+
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
|
820 |
+
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
|
821 |
+
sd_prefix = f"input_hint_block.{i*2}."
|
822 |
+
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
823 |
+
|
824 |
+
for i in range(12):
|
825 |
+
hf_prefix = f"controlnet_down_blocks.{i}."
|
826 |
+
sd_prefix = f"zero_convs.{i}.0."
|
827 |
+
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
828 |
+
|
829 |
+
return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
|
830 |
+
|
831 |
+
|
832 |
+
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
833 |
+
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
834 |
+
|
835 |
+
mapping = {k: k for k in controlnet_state_dict.keys()}
|
836 |
+
for sd_name, diffusers_name in unet_conversion_map:
|
837 |
+
mapping[diffusers_name] = sd_name
|
838 |
+
for k, v in mapping.items():
|
839 |
+
if "resnets" in k:
|
840 |
+
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
841 |
+
v = v.replace(diffusers_part, sd_part)
|
842 |
+
mapping[k] = v
|
843 |
+
for k, v in mapping.items():
|
844 |
+
for sd_part, diffusers_part in unet_conversion_map_layer:
|
845 |
+
v = v.replace(diffusers_part, sd_part)
|
846 |
+
mapping[k] = v
|
847 |
+
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
848 |
+
return new_state_dict
|
849 |
+
|
850 |
+
|
851 |
+
def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
|
852 |
+
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
853 |
+
|
854 |
+
mapping = {k: k for k in controlnet_state_dict.keys()}
|
855 |
+
for sd_name, diffusers_name in unet_conversion_map:
|
856 |
+
mapping[sd_name] = diffusers_name
|
857 |
+
for k, v in mapping.items():
|
858 |
+
for sd_part, diffusers_part in unet_conversion_map_layer:
|
859 |
+
v = v.replace(sd_part, diffusers_part)
|
860 |
+
mapping[k] = v
|
861 |
+
for k, v in mapping.items():
|
862 |
+
if "resnets" in v:
|
863 |
+
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
864 |
+
v = v.replace(sd_part, diffusers_part)
|
865 |
+
mapping[k] = v
|
866 |
+
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
867 |
+
return new_state_dict
|
868 |
+
|
869 |
+
|
870 |
+
# ================#
|
871 |
+
# VAE Conversion #
|
872 |
+
# ================#
|
873 |
+
|
874 |
+
|
875 |
+
def reshape_weight_for_sd(w):
|
876 |
+
# convert HF linear weights to SD conv2d weights
|
877 |
+
return w.reshape(*w.shape, 1, 1)
|
878 |
+
|
879 |
+
|
880 |
+
def convert_vae_state_dict(vae_state_dict):
|
881 |
+
vae_conversion_map = [
|
882 |
+
# (stable-diffusion, HF Diffusers)
|
883 |
+
("nin_shortcut", "conv_shortcut"),
|
884 |
+
("norm_out", "conv_norm_out"),
|
885 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
886 |
+
]
|
887 |
+
|
888 |
+
for i in range(4):
|
889 |
+
# down_blocks have two resnets
|
890 |
+
for j in range(2):
|
891 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
892 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
893 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
894 |
+
|
895 |
+
if i < 3:
|
896 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
897 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
898 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
899 |
+
|
900 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
901 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
902 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
903 |
+
|
904 |
+
# up_blocks have three resnets
|
905 |
+
# also, up blocks in hf are numbered in reverse from sd
|
906 |
+
for j in range(3):
|
907 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
908 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
909 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
910 |
+
|
911 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
912 |
+
for i in range(2):
|
913 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
914 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
915 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
916 |
+
|
917 |
+
if diffusers.__version__ < "0.17.0":
|
918 |
+
vae_conversion_map_attn = [
|
919 |
+
# (stable-diffusion, HF Diffusers)
|
920 |
+
("norm.", "group_norm."),
|
921 |
+
("q.", "query."),
|
922 |
+
("k.", "key."),
|
923 |
+
("v.", "value."),
|
924 |
+
("proj_out.", "proj_attn."),
|
925 |
+
]
|
926 |
+
else:
|
927 |
+
vae_conversion_map_attn = [
|
928 |
+
# (stable-diffusion, HF Diffusers)
|
929 |
+
("norm.", "group_norm."),
|
930 |
+
("q.", "to_q."),
|
931 |
+
("k.", "to_k."),
|
932 |
+
("v.", "to_v."),
|
933 |
+
("proj_out.", "to_out.0."),
|
934 |
+
]
|
935 |
+
|
936 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
937 |
+
for k, v in mapping.items():
|
938 |
+
for sd_part, hf_part in vae_conversion_map:
|
939 |
+
v = v.replace(hf_part, sd_part)
|
940 |
+
mapping[k] = v
|
941 |
+
for k, v in mapping.items():
|
942 |
+
if "attentions" in k:
|
943 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
944 |
+
v = v.replace(hf_part, sd_part)
|
945 |
+
mapping[k] = v
|
946 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
947 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
948 |
+
for k, v in new_state_dict.items():
|
949 |
+
for weight_name in weights_to_convert:
|
950 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
951 |
+
# logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
952 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
953 |
+
|
954 |
+
return new_state_dict
|
955 |
+
|
956 |
+
|
957 |
+
# endregion
|
958 |
+
|
959 |
+
# region 自作のモデル読み書きなど
|
960 |
+
|
961 |
+
|
962 |
+
def is_safetensors(path):
|
963 |
+
return os.path.splitext(path)[1].lower() == ".safetensors"
|
964 |
+
|
965 |
+
|
966 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
967 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
968 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
969 |
+
("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
|
970 |
+
("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
|
971 |
+
("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
|
972 |
+
]
|
973 |
+
|
974 |
+
if is_safetensors(ckpt_path):
|
975 |
+
checkpoint = None
|
976 |
+
state_dict = load_file(ckpt_path) # , device) # may causes error
|
977 |
+
else:
|
978 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
979 |
+
if "state_dict" in checkpoint:
|
980 |
+
state_dict = checkpoint["state_dict"]
|
981 |
+
else:
|
982 |
+
state_dict = checkpoint
|
983 |
+
checkpoint = None
|
984 |
+
|
985 |
+
key_reps = []
|
986 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
987 |
+
for key in state_dict.keys():
|
988 |
+
if key.startswith(rep_from):
|
989 |
+
new_key = rep_to + key[len(rep_from) :]
|
990 |
+
key_reps.append((key, new_key))
|
991 |
+
|
992 |
+
for key, new_key in key_reps:
|
993 |
+
state_dict[new_key] = state_dict[key]
|
994 |
+
del state_dict[key]
|
995 |
+
|
996 |
+
return checkpoint, state_dict
|
997 |
+
|
998 |
+
|
999 |
+
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
1000 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
|
1001 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
1002 |
+
|
1003 |
+
# Convert the UNet2DConditionModel model.
|
1004 |
+
unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
|
1005 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
1006 |
+
|
1007 |
+
unet = UNet2DConditionModel(**unet_config).to(device)
|
1008 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
1009 |
+
logger.info(f"loading u-net: {info}")
|
1010 |
+
|
1011 |
+
# Convert the VAE model.
|
1012 |
+
vae_config = create_vae_diffusers_config()
|
1013 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
1014 |
+
|
1015 |
+
vae = AutoencoderKL(**vae_config).to(device)
|
1016 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
1017 |
+
logger.info(f"loading vae: {info}")
|
1018 |
+
|
1019 |
+
# convert text_model
|
1020 |
+
if v2:
|
1021 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
1022 |
+
cfg = CLIPTextConfig(
|
1023 |
+
vocab_size=49408,
|
1024 |
+
hidden_size=1024,
|
1025 |
+
intermediate_size=4096,
|
1026 |
+
num_hidden_layers=23,
|
1027 |
+
num_attention_heads=16,
|
1028 |
+
max_position_embeddings=77,
|
1029 |
+
hidden_act="gelu",
|
1030 |
+
layer_norm_eps=1e-05,
|
1031 |
+
dropout=0.0,
|
1032 |
+
attention_dropout=0.0,
|
1033 |
+
initializer_range=0.02,
|
1034 |
+
initializer_factor=1.0,
|
1035 |
+
pad_token_id=1,
|
1036 |
+
bos_token_id=0,
|
1037 |
+
eos_token_id=2,
|
1038 |
+
model_type="clip_text_model",
|
1039 |
+
projection_dim=512,
|
1040 |
+
torch_dtype="float32",
|
1041 |
+
transformers_version="4.25.0.dev0",
|
1042 |
+
)
|
1043 |
+
text_model = CLIPTextModel._from_config(cfg)
|
1044 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
1045 |
+
else:
|
1046 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
1047 |
+
|
1048 |
+
# logging.set_verbosity_error() # don't show annoying warning
|
1049 |
+
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
1050 |
+
# logging.set_verbosity_warning()
|
1051 |
+
# logger.info(f"config: {text_model.config}")
|
1052 |
+
cfg = CLIPTextConfig(
|
1053 |
+
vocab_size=49408,
|
1054 |
+
hidden_size=768,
|
1055 |
+
intermediate_size=3072,
|
1056 |
+
num_hidden_layers=12,
|
1057 |
+
num_attention_heads=12,
|
1058 |
+
max_position_embeddings=77,
|
1059 |
+
hidden_act="quick_gelu",
|
1060 |
+
layer_norm_eps=1e-05,
|
1061 |
+
dropout=0.0,
|
1062 |
+
attention_dropout=0.0,
|
1063 |
+
initializer_range=0.02,
|
1064 |
+
initializer_factor=1.0,
|
1065 |
+
pad_token_id=1,
|
1066 |
+
bos_token_id=0,
|
1067 |
+
eos_token_id=2,
|
1068 |
+
model_type="clip_text_model",
|
1069 |
+
projection_dim=768,
|
1070 |
+
torch_dtype="float32",
|
1071 |
+
)
|
1072 |
+
text_model = CLIPTextModel._from_config(cfg)
|
1073 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
1074 |
+
logger.info(f"loading text encoder: {info}")
|
1075 |
+
|
1076 |
+
return text_model, vae, unet
|
1077 |
+
|
1078 |
+
|
1079 |
+
def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
|
1080 |
+
# only for reference
|
1081 |
+
version_str = "sd"
|
1082 |
+
if v2:
|
1083 |
+
version_str += "_v2"
|
1084 |
+
else:
|
1085 |
+
version_str += "_v1"
|
1086 |
+
if v_parameterization:
|
1087 |
+
version_str += "_v"
|
1088 |
+
return version_str
|
1089 |
+
|
1090 |
+
|
1091 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
1092 |
+
def convert_key(key):
|
1093 |
+
# position_idsの除去
|
1094 |
+
if ".position_ids" in key:
|
1095 |
+
return None
|
1096 |
+
|
1097 |
+
# common
|
1098 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
1099 |
+
key = key.replace("text_model.", "")
|
1100 |
+
if "layers" in key:
|
1101 |
+
# resblocks conversion
|
1102 |
+
key = key.replace(".layers.", ".resblocks.")
|
1103 |
+
if ".layer_norm" in key:
|
1104 |
+
key = key.replace(".layer_norm", ".ln_")
|
1105 |
+
elif ".mlp." in key:
|
1106 |
+
key = key.replace(".fc1.", ".c_fc.")
|
1107 |
+
key = key.replace(".fc2.", ".c_proj.")
|
1108 |
+
elif ".self_attn.out_proj" in key:
|
1109 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
1110 |
+
elif ".self_attn." in key:
|
1111 |
+
key = None # 特殊なので後で処理する
|
1112 |
+
else:
|
1113 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
1114 |
+
elif ".position_embedding" in key:
|
1115 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
1116 |
+
elif ".token_embedding" in key:
|
1117 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
1118 |
+
elif "final_layer_norm" in key:
|
1119 |
+
key = key.replace("final_layer_norm", "ln_final")
|
1120 |
+
return key
|
1121 |
+
|
1122 |
+
keys = list(checkpoint.keys())
|
1123 |
+
new_sd = {}
|
1124 |
+
for key in keys:
|
1125 |
+
new_key = convert_key(key)
|
1126 |
+
if new_key is None:
|
1127 |
+
continue
|
1128 |
+
new_sd[new_key] = checkpoint[key]
|
1129 |
+
|
1130 |
+
# attnの変換
|
1131 |
+
for key in keys:
|
1132 |
+
if "layers" in key and "q_proj" in key:
|
1133 |
+
# 三つを結合
|
1134 |
+
key_q = key
|
1135 |
+
key_k = key.replace("q_proj", "k_proj")
|
1136 |
+
key_v = key.replace("q_proj", "v_proj")
|
1137 |
+
|
1138 |
+
value_q = checkpoint[key_q]
|
1139 |
+
value_k = checkpoint[key_k]
|
1140 |
+
value_v = checkpoint[key_v]
|
1141 |
+
value = torch.cat([value_q, value_k, value_v])
|
1142 |
+
|
1143 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
1144 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
1145 |
+
new_sd[new_key] = value
|
1146 |
+
|
1147 |
+
# 最後の層などを捏造するか
|
1148 |
+
if make_dummy_weights:
|
1149 |
+
logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
|
1150 |
+
keys = list(new_sd.keys())
|
1151 |
+
for key in keys:
|
1152 |
+
if key.startswith("transformer.resblocks.22."):
|
1153 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
1154 |
+
|
1155 |
+
# Diffusersに含まれない重みを作っておく
|
1156 |
+
new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
1157 |
+
new_sd["logit_scale"] = torch.tensor(1)
|
1158 |
+
|
1159 |
+
return new_sd
|
1160 |
+
|
1161 |
+
|
1162 |
+
def save_stable_diffusion_checkpoint(
|
1163 |
+
v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
|
1164 |
+
):
|
1165 |
+
if ckpt_path is not None:
|
1166 |
+
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
1167 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1168 |
+
if checkpoint is None: # safetensors または state_dictのckpt
|
1169 |
+
checkpoint = {}
|
1170 |
+
strict = False
|
1171 |
+
else:
|
1172 |
+
strict = True
|
1173 |
+
if "state_dict" in state_dict:
|
1174 |
+
del state_dict["state_dict"]
|
1175 |
+
else:
|
1176 |
+
# 新しく作る
|
1177 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
1178 |
+
checkpoint = {}
|
1179 |
+
state_dict = {}
|
1180 |
+
strict = False
|
1181 |
+
|
1182 |
+
def update_sd(prefix, sd):
|
1183 |
+
for k, v in sd.items():
|
1184 |
+
key = prefix + k
|
1185 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1186 |
+
if save_dtype is not None:
|
1187 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
1188 |
+
state_dict[key] = v
|
1189 |
+
|
1190 |
+
# Convert the UNet model
|
1191 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1192 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
1193 |
+
|
1194 |
+
# Convert the text encoder model
|
1195 |
+
if v2:
|
1196 |
+
make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる
|
1197 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1198 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
1199 |
+
else:
|
1200 |
+
text_enc_dict = text_encoder.state_dict()
|
1201 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1202 |
+
|
1203 |
+
# Convert the VAE
|
1204 |
+
if vae is not None:
|
1205 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1206 |
+
update_sd("first_stage_model.", vae_dict)
|
1207 |
+
|
1208 |
+
# Put together new checkpoint
|
1209 |
+
key_count = len(state_dict.keys())
|
1210 |
+
new_ckpt = {"state_dict": state_dict}
|
1211 |
+
|
1212 |
+
# epoch and global_step are sometimes not int
|
1213 |
+
try:
|
1214 |
+
if "epoch" in checkpoint:
|
1215 |
+
epochs += checkpoint["epoch"]
|
1216 |
+
if "global_step" in checkpoint:
|
1217 |
+
steps += checkpoint["global_step"]
|
1218 |
+
except:
|
1219 |
+
pass
|
1220 |
+
|
1221 |
+
new_ckpt["epoch"] = epochs
|
1222 |
+
new_ckpt["global_step"] = steps
|
1223 |
+
|
1224 |
+
if is_safetensors(output_file):
|
1225 |
+
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1226 |
+
save_file(state_dict, output_file, metadata)
|
1227 |
+
else:
|
1228 |
+
torch.save(new_ckpt, output_file)
|
1229 |
+
|
1230 |
+
return key_count
|
1231 |
+
|
1232 |
+
|
1233 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1234 |
+
if pretrained_model_name_or_path is None:
|
1235 |
+
# load default settings for v1/v2
|
1236 |
+
if v2:
|
1237 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1238 |
+
else:
|
1239 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1240 |
+
|
1241 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1242 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1243 |
+
if vae is None:
|
1244 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1245 |
+
|
1246 |
+
# original U-Net cannot be saved, so we need to convert it to the Diffusers version
|
1247 |
+
# TODO this consumes a lot of memory
|
1248 |
+
diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
|
1249 |
+
diffusers_unet.load_state_dict(unet.state_dict())
|
1250 |
+
|
1251 |
+
pipeline = StableDiffusionPipeline(
|
1252 |
+
unet=diffusers_unet,
|
1253 |
+
text_encoder=text_encoder,
|
1254 |
+
vae=vae,
|
1255 |
+
scheduler=scheduler,
|
1256 |
+
tokenizer=tokenizer,
|
1257 |
+
safety_checker=None,
|
1258 |
+
feature_extractor=None,
|
1259 |
+
requires_safety_checker=None,
|
1260 |
+
)
|
1261 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1262 |
+
|
1263 |
+
|
1264 |
+
VAE_PREFIX = "first_stage_model."
|
1265 |
+
|
1266 |
+
|
1267 |
+
def load_vae(vae_id, dtype):
|
1268 |
+
logger.info(f"load VAE: {vae_id}")
|
1269 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1270 |
+
# Diffusers local/remote
|
1271 |
+
try:
|
1272 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1273 |
+
except EnvironmentError as e:
|
1274 |
+
logger.error(f"exception occurs in loading vae: {e}")
|
1275 |
+
logger.error("retry with subfolder='vae'")
|
1276 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1277 |
+
return vae
|
1278 |
+
|
1279 |
+
# local
|
1280 |
+
vae_config = create_vae_diffusers_config()
|
1281 |
+
|
1282 |
+
if vae_id.endswith(".bin"):
|
1283 |
+
# SD 1.5 VAE on Huggingface
|
1284 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
1285 |
+
else:
|
1286 |
+
# StableDiffusion
|
1287 |
+
vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
|
1288 |
+
vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
|
1289 |
+
|
1290 |
+
# vae only or full model
|
1291 |
+
full_model = False
|
1292 |
+
for vae_key in vae_sd:
|
1293 |
+
if vae_key.startswith(VAE_PREFIX):
|
1294 |
+
full_model = True
|
1295 |
+
break
|
1296 |
+
if not full_model:
|
1297 |
+
sd = {}
|
1298 |
+
for key, value in vae_sd.items():
|
1299 |
+
sd[VAE_PREFIX + key] = value
|
1300 |
+
vae_sd = sd
|
1301 |
+
del sd
|
1302 |
+
|
1303 |
+
# Convert the VAE model.
|
1304 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
1305 |
+
|
1306 |
+
vae = AutoencoderKL(**vae_config)
|
1307 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
1308 |
+
return vae
|
1309 |
+
|
1310 |
+
|
1311 |
+
# endregion
|
1312 |
+
|
1313 |
+
|
1314 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
1315 |
+
max_width, max_height = max_reso
|
1316 |
+
max_area = max_width * max_height
|
1317 |
+
|
1318 |
+
resos = set()
|
1319 |
+
|
1320 |
+
width = int(math.sqrt(max_area) // divisible) * divisible
|
1321 |
+
resos.add((width, width))
|
1322 |
+
|
1323 |
+
width = min_size
|
1324 |
+
while width <= max_size:
|
1325 |
+
height = min(max_size, int((max_area // width) // divisible) * divisible)
|
1326 |
+
if height >= min_size:
|
1327 |
+
resos.add((width, height))
|
1328 |
+
resos.add((height, width))
|
1329 |
+
|
1330 |
+
# # make additional resos
|
1331 |
+
# if width >= height and width - divisible >= min_size:
|
1332 |
+
# resos.add((width - divisible, height))
|
1333 |
+
# resos.add((height, width - divisible))
|
1334 |
+
# if height >= width and height - divisible >= min_size:
|
1335 |
+
# resos.add((width, height - divisible))
|
1336 |
+
# resos.add((height - divisible, width))
|
1337 |
+
|
1338 |
+
width += divisible
|
1339 |
+
|
1340 |
+
resos = list(resos)
|
1341 |
+
resos.sort()
|
1342 |
+
return resos
|
1343 |
+
|
1344 |
+
|
1345 |
+
if __name__ == "__main__":
|
1346 |
+
resos = make_bucket_resolutions((512, 768))
|
1347 |
+
logger.info(f"{len(resos)}")
|
1348 |
+
logger.info(f"{resos}")
|
1349 |
+
aspect_ratios = [w / h for w, h in resos]
|
1350 |
+
logger.info(f"{aspect_ratios}")
|
1351 |
+
|
1352 |
+
ars = set()
|
1353 |
+
for ar in aspect_ratios:
|
1354 |
+
if ar in ars:
|
1355 |
+
logger.error(f"error! duplicate ar: {ar}")
|
1356 |
+
ars.add(ar)
|
library/original_unet.py
ADDED
@@ -0,0 +1,1919 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる
|
2 |
+
# 条件分岐等で不要な部分は削除している
|
3 |
+
# コードの多くはDiffusersからコピーしている
|
4 |
+
# 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある
|
5 |
+
|
6 |
+
# Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers.
|
7 |
+
# Unnecessary parts are deleted by condition branching.
|
8 |
+
# As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2
|
9 |
+
|
10 |
+
"""
|
11 |
+
v1.5とv2.1の相違点は
|
12 |
+
- attention_head_dimがintかlist[int]か
|
13 |
+
- cross_attention_dimが768か1024か
|
14 |
+
- use_linear_projection: trueがない(=False, 1.5)かあるか
|
15 |
+
- upcast_attentionがFalse(1.5)かTrue(2.1)か
|
16 |
+
- (以下は多分無視していい)
|
17 |
+
- sample_sizeが64か96か
|
18 |
+
- dual_cross_attentionがあるかないか
|
19 |
+
- num_class_embedsがあるかないか
|
20 |
+
- only_cross_attentionがあるかないか
|
21 |
+
|
22 |
+
v1.5
|
23 |
+
{
|
24 |
+
"_class_name": "UNet2DConditionModel",
|
25 |
+
"_diffusers_version": "0.6.0",
|
26 |
+
"act_fn": "silu",
|
27 |
+
"attention_head_dim": 8,
|
28 |
+
"block_out_channels": [
|
29 |
+
320,
|
30 |
+
640,
|
31 |
+
1280,
|
32 |
+
1280
|
33 |
+
],
|
34 |
+
"center_input_sample": false,
|
35 |
+
"cross_attention_dim": 768,
|
36 |
+
"down_block_types": [
|
37 |
+
"CrossAttnDownBlock2D",
|
38 |
+
"CrossAttnDownBlock2D",
|
39 |
+
"CrossAttnDownBlock2D",
|
40 |
+
"DownBlock2D"
|
41 |
+
],
|
42 |
+
"downsample_padding": 1,
|
43 |
+
"flip_sin_to_cos": true,
|
44 |
+
"freq_shift": 0,
|
45 |
+
"in_channels": 4,
|
46 |
+
"layers_per_block": 2,
|
47 |
+
"mid_block_scale_factor": 1,
|
48 |
+
"norm_eps": 1e-05,
|
49 |
+
"norm_num_groups": 32,
|
50 |
+
"out_channels": 4,
|
51 |
+
"sample_size": 64,
|
52 |
+
"up_block_types": [
|
53 |
+
"UpBlock2D",
|
54 |
+
"CrossAttnUpBlock2D",
|
55 |
+
"CrossAttnUpBlock2D",
|
56 |
+
"CrossAttnUpBlock2D"
|
57 |
+
]
|
58 |
+
}
|
59 |
+
|
60 |
+
v2.1
|
61 |
+
{
|
62 |
+
"_class_name": "UNet2DConditionModel",
|
63 |
+
"_diffusers_version": "0.10.0.dev0",
|
64 |
+
"act_fn": "silu",
|
65 |
+
"attention_head_dim": [
|
66 |
+
5,
|
67 |
+
10,
|
68 |
+
20,
|
69 |
+
20
|
70 |
+
],
|
71 |
+
"block_out_channels": [
|
72 |
+
320,
|
73 |
+
640,
|
74 |
+
1280,
|
75 |
+
1280
|
76 |
+
],
|
77 |
+
"center_input_sample": false,
|
78 |
+
"cross_attention_dim": 1024,
|
79 |
+
"down_block_types": [
|
80 |
+
"CrossAttnDownBlock2D",
|
81 |
+
"CrossAttnDownBlock2D",
|
82 |
+
"CrossAttnDownBlock2D",
|
83 |
+
"DownBlock2D"
|
84 |
+
],
|
85 |
+
"downsample_padding": 1,
|
86 |
+
"dual_cross_attention": false,
|
87 |
+
"flip_sin_to_cos": true,
|
88 |
+
"freq_shift": 0,
|
89 |
+
"in_channels": 4,
|
90 |
+
"layers_per_block": 2,
|
91 |
+
"mid_block_scale_factor": 1,
|
92 |
+
"norm_eps": 1e-05,
|
93 |
+
"norm_num_groups": 32,
|
94 |
+
"num_class_embeds": null,
|
95 |
+
"only_cross_attention": false,
|
96 |
+
"out_channels": 4,
|
97 |
+
"sample_size": 96,
|
98 |
+
"up_block_types": [
|
99 |
+
"UpBlock2D",
|
100 |
+
"CrossAttnUpBlock2D",
|
101 |
+
"CrossAttnUpBlock2D",
|
102 |
+
"CrossAttnUpBlock2D"
|
103 |
+
],
|
104 |
+
"use_linear_projection": true,
|
105 |
+
"upcast_attention": true
|
106 |
+
}
|
107 |
+
"""
|
108 |
+
|
109 |
+
import math
|
110 |
+
from types import SimpleNamespace
|
111 |
+
from typing import Dict, Optional, Tuple, Union
|
112 |
+
import torch
|
113 |
+
from torch import nn
|
114 |
+
from torch.nn import functional as F
|
115 |
+
from einops import rearrange
|
116 |
+
from library.utils import setup_logging
|
117 |
+
setup_logging()
|
118 |
+
import logging
|
119 |
+
logger = logging.getLogger(__name__)
|
120 |
+
|
121 |
+
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
|
122 |
+
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
|
123 |
+
TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4
|
124 |
+
IN_CHANNELS: int = 4
|
125 |
+
OUT_CHANNELS: int = 4
|
126 |
+
LAYERS_PER_BLOCK: int = 2
|
127 |
+
LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
|
128 |
+
TIME_EMBED_FLIP_SIN_TO_COS: bool = True
|
129 |
+
TIME_EMBED_FREQ_SHIFT: int = 0
|
130 |
+
NORM_GROUPS: int = 32
|
131 |
+
NORM_EPS: float = 1e-5
|
132 |
+
TRANSFORMER_NORM_NUM_GROUPS = 32
|
133 |
+
|
134 |
+
DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
|
135 |
+
UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
|
136 |
+
|
137 |
+
|
138 |
+
# region memory efficient attention
|
139 |
+
|
140 |
+
# FlashAttentionを使うCrossAttention
|
141 |
+
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
142 |
+
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
143 |
+
|
144 |
+
# constants
|
145 |
+
|
146 |
+
EPSILON = 1e-6
|
147 |
+
|
148 |
+
# helper functions
|
149 |
+
|
150 |
+
|
151 |
+
def exists(val):
|
152 |
+
return val is not None
|
153 |
+
|
154 |
+
|
155 |
+
def default(val, d):
|
156 |
+
return val if exists(val) else d
|
157 |
+
|
158 |
+
|
159 |
+
# flash attention forwards and backwards
|
160 |
+
|
161 |
+
# https://arxiv.org/abs/2205.14135
|
162 |
+
|
163 |
+
|
164 |
+
class FlashAttentionFunction(torch.autograd.Function):
|
165 |
+
@staticmethod
|
166 |
+
@torch.no_grad()
|
167 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
168 |
+
"""Algorithm 2 in the paper"""
|
169 |
+
|
170 |
+
device = q.device
|
171 |
+
dtype = q.dtype
|
172 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
173 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
174 |
+
|
175 |
+
o = torch.zeros_like(q)
|
176 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
177 |
+
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
178 |
+
|
179 |
+
scale = q.shape[-1] ** -0.5
|
180 |
+
|
181 |
+
if not exists(mask):
|
182 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
183 |
+
else:
|
184 |
+
mask = rearrange(mask, "b n -> b 1 1 n")
|
185 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
186 |
+
|
187 |
+
row_splits = zip(
|
188 |
+
q.split(q_bucket_size, dim=-2),
|
189 |
+
o.split(q_bucket_size, dim=-2),
|
190 |
+
mask,
|
191 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
192 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
193 |
+
)
|
194 |
+
|
195 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
196 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
197 |
+
|
198 |
+
col_splits = zip(
|
199 |
+
k.split(k_bucket_size, dim=-2),
|
200 |
+
v.split(k_bucket_size, dim=-2),
|
201 |
+
)
|
202 |
+
|
203 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
204 |
+
k_start_index = k_ind * k_bucket_size
|
205 |
+
|
206 |
+
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
207 |
+
|
208 |
+
if exists(row_mask):
|
209 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
210 |
+
|
211 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
212 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
213 |
+
q_start_index - k_start_index + 1
|
214 |
+
)
|
215 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
216 |
+
|
217 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
218 |
+
attn_weights -= block_row_maxes
|
219 |
+
exp_weights = torch.exp(attn_weights)
|
220 |
+
|
221 |
+
if exists(row_mask):
|
222 |
+
exp_weights.masked_fill_(~row_mask, 0.0)
|
223 |
+
|
224 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
225 |
+
|
226 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
227 |
+
|
228 |
+
exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
|
229 |
+
|
230 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
231 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
232 |
+
|
233 |
+
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
234 |
+
|
235 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
236 |
+
|
237 |
+
row_maxes.copy_(new_row_maxes)
|
238 |
+
row_sums.copy_(new_row_sums)
|
239 |
+
|
240 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
241 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
242 |
+
|
243 |
+
return o
|
244 |
+
|
245 |
+
@staticmethod
|
246 |
+
@torch.no_grad()
|
247 |
+
def backward(ctx, do):
|
248 |
+
"""Algorithm 4 in the paper"""
|
249 |
+
|
250 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
251 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
252 |
+
|
253 |
+
device = q.device
|
254 |
+
|
255 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
256 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
257 |
+
|
258 |
+
dq = torch.zeros_like(q)
|
259 |
+
dk = torch.zeros_like(k)
|
260 |
+
dv = torch.zeros_like(v)
|
261 |
+
|
262 |
+
row_splits = zip(
|
263 |
+
q.split(q_bucket_size, dim=-2),
|
264 |
+
o.split(q_bucket_size, dim=-2),
|
265 |
+
do.split(q_bucket_size, dim=-2),
|
266 |
+
mask,
|
267 |
+
l.split(q_bucket_size, dim=-2),
|
268 |
+
m.split(q_bucket_size, dim=-2),
|
269 |
+
dq.split(q_bucket_size, dim=-2),
|
270 |
+
)
|
271 |
+
|
272 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
273 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
274 |
+
|
275 |
+
col_splits = zip(
|
276 |
+
k.split(k_bucket_size, dim=-2),
|
277 |
+
v.split(k_bucket_size, dim=-2),
|
278 |
+
dk.split(k_bucket_size, dim=-2),
|
279 |
+
dv.split(k_bucket_size, dim=-2),
|
280 |
+
)
|
281 |
+
|
282 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
283 |
+
k_start_index = k_ind * k_bucket_size
|
284 |
+
|
285 |
+
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
286 |
+
|
287 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
288 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
289 |
+
q_start_index - k_start_index + 1
|
290 |
+
)
|
291 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
292 |
+
|
293 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
294 |
+
|
295 |
+
if exists(row_mask):
|
296 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
297 |
+
|
298 |
+
p = exp_attn_weights / lc
|
299 |
+
|
300 |
+
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
|
301 |
+
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
|
302 |
+
|
303 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
304 |
+
ds = p * scale * (dp - D)
|
305 |
+
|
306 |
+
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
|
307 |
+
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
|
308 |
+
|
309 |
+
dqc.add_(dq_chunk)
|
310 |
+
dkc.add_(dk_chunk)
|
311 |
+
dvc.add_(dv_chunk)
|
312 |
+
|
313 |
+
return dq, dk, dv, None, None, None, None
|
314 |
+
|
315 |
+
|
316 |
+
# endregion
|
317 |
+
|
318 |
+
|
319 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
320 |
+
return next(parameter.parameters()).dtype
|
321 |
+
|
322 |
+
|
323 |
+
def get_parameter_device(parameter: torch.nn.Module):
|
324 |
+
return next(parameter.parameters()).device
|
325 |
+
|
326 |
+
|
327 |
+
def get_timestep_embedding(
|
328 |
+
timesteps: torch.Tensor,
|
329 |
+
embedding_dim: int,
|
330 |
+
flip_sin_to_cos: bool = False,
|
331 |
+
downscale_freq_shift: float = 1,
|
332 |
+
scale: float = 1,
|
333 |
+
max_period: int = 10000,
|
334 |
+
):
|
335 |
+
"""
|
336 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
337 |
+
|
338 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
339 |
+
These may be fractional.
|
340 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
341 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
342 |
+
"""
|
343 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
344 |
+
|
345 |
+
half_dim = embedding_dim // 2
|
346 |
+
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
347 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
348 |
+
|
349 |
+
emb = torch.exp(exponent)
|
350 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
351 |
+
|
352 |
+
# scale embeddings
|
353 |
+
emb = scale * emb
|
354 |
+
|
355 |
+
# concat sine and cosine embeddings
|
356 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
357 |
+
|
358 |
+
# flip sine and cosine embeddings
|
359 |
+
if flip_sin_to_cos:
|
360 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
361 |
+
|
362 |
+
# zero pad
|
363 |
+
if embedding_dim % 2 == 1:
|
364 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
365 |
+
return emb
|
366 |
+
|
367 |
+
|
368 |
+
# Deep Shrink: We do not common this function, because minimize dependencies.
|
369 |
+
def resize_like(x, target, mode="bicubic", align_corners=False):
|
370 |
+
org_dtype = x.dtype
|
371 |
+
if org_dtype == torch.bfloat16:
|
372 |
+
x = x.to(torch.float32)
|
373 |
+
|
374 |
+
if x.shape[-2:] != target.shape[-2:]:
|
375 |
+
if mode == "nearest":
|
376 |
+
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
|
377 |
+
else:
|
378 |
+
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
|
379 |
+
|
380 |
+
if org_dtype == torch.bfloat16:
|
381 |
+
x = x.to(org_dtype)
|
382 |
+
return x
|
383 |
+
|
384 |
+
|
385 |
+
class SampleOutput:
|
386 |
+
def __init__(self, sample):
|
387 |
+
self.sample = sample
|
388 |
+
|
389 |
+
|
390 |
+
class TimestepEmbedding(nn.Module):
|
391 |
+
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
|
392 |
+
super().__init__()
|
393 |
+
|
394 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
395 |
+
self.act = None
|
396 |
+
if act_fn == "silu":
|
397 |
+
self.act = nn.SiLU()
|
398 |
+
elif act_fn == "mish":
|
399 |
+
self.act = nn.Mish()
|
400 |
+
|
401 |
+
if out_dim is not None:
|
402 |
+
time_embed_dim_out = out_dim
|
403 |
+
else:
|
404 |
+
time_embed_dim_out = time_embed_dim
|
405 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
406 |
+
|
407 |
+
def forward(self, sample):
|
408 |
+
sample = self.linear_1(sample)
|
409 |
+
|
410 |
+
if self.act is not None:
|
411 |
+
sample = self.act(sample)
|
412 |
+
|
413 |
+
sample = self.linear_2(sample)
|
414 |
+
return sample
|
415 |
+
|
416 |
+
|
417 |
+
class Timesteps(nn.Module):
|
418 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
419 |
+
super().__init__()
|
420 |
+
self.num_channels = num_channels
|
421 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
422 |
+
self.downscale_freq_shift = downscale_freq_shift
|
423 |
+
|
424 |
+
def forward(self, timesteps):
|
425 |
+
t_emb = get_timestep_embedding(
|
426 |
+
timesteps,
|
427 |
+
self.num_channels,
|
428 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
429 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
430 |
+
)
|
431 |
+
return t_emb
|
432 |
+
|
433 |
+
|
434 |
+
class ResnetBlock2D(nn.Module):
|
435 |
+
def __init__(
|
436 |
+
self,
|
437 |
+
in_channels,
|
438 |
+
out_channels,
|
439 |
+
):
|
440 |
+
super().__init__()
|
441 |
+
self.in_channels = in_channels
|
442 |
+
self.out_channels = out_channels
|
443 |
+
|
444 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)
|
445 |
+
|
446 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
447 |
+
|
448 |
+
self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)
|
449 |
+
|
450 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
|
451 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
452 |
+
|
453 |
+
# if non_linearity == "swish":
|
454 |
+
self.nonlinearity = lambda x: F.silu(x)
|
455 |
+
|
456 |
+
self.use_in_shortcut = self.in_channels != self.out_channels
|
457 |
+
|
458 |
+
self.conv_shortcut = None
|
459 |
+
if self.use_in_shortcut:
|
460 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
461 |
+
|
462 |
+
def forward(self, input_tensor, temb):
|
463 |
+
hidden_states = input_tensor
|
464 |
+
|
465 |
+
hidden_states = self.norm1(hidden_states)
|
466 |
+
hidden_states = self.nonlinearity(hidden_states)
|
467 |
+
|
468 |
+
hidden_states = self.conv1(hidden_states)
|
469 |
+
|
470 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
471 |
+
hidden_states = hidden_states + temb
|
472 |
+
|
473 |
+
hidden_states = self.norm2(hidden_states)
|
474 |
+
hidden_states = self.nonlinearity(hidden_states)
|
475 |
+
|
476 |
+
hidden_states = self.conv2(hidden_states)
|
477 |
+
|
478 |
+
if self.conv_shortcut is not None:
|
479 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
480 |
+
|
481 |
+
output_tensor = input_tensor + hidden_states
|
482 |
+
|
483 |
+
return output_tensor
|
484 |
+
|
485 |
+
|
486 |
+
class DownBlock2D(nn.Module):
|
487 |
+
def __init__(
|
488 |
+
self,
|
489 |
+
in_channels: int,
|
490 |
+
out_channels: int,
|
491 |
+
add_downsample=True,
|
492 |
+
):
|
493 |
+
super().__init__()
|
494 |
+
|
495 |
+
self.has_cross_attention = False
|
496 |
+
resnets = []
|
497 |
+
|
498 |
+
for i in range(LAYERS_PER_BLOCK):
|
499 |
+
in_channels = in_channels if i == 0 else out_channels
|
500 |
+
resnets.append(
|
501 |
+
ResnetBlock2D(
|
502 |
+
in_channels=in_channels,
|
503 |
+
out_channels=out_channels,
|
504 |
+
)
|
505 |
+
)
|
506 |
+
self.resnets = nn.ModuleList(resnets)
|
507 |
+
|
508 |
+
if add_downsample:
|
509 |
+
self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)]
|
510 |
+
else:
|
511 |
+
self.downsamplers = None
|
512 |
+
|
513 |
+
self.gradient_checkpointing = False
|
514 |
+
|
515 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
516 |
+
pass
|
517 |
+
|
518 |
+
def set_use_sdpa(self, sdpa):
|
519 |
+
pass
|
520 |
+
|
521 |
+
def forward(self, hidden_states, temb=None):
|
522 |
+
output_states = ()
|
523 |
+
|
524 |
+
for resnet in self.resnets:
|
525 |
+
if self.training and self.gradient_checkpointing:
|
526 |
+
|
527 |
+
def create_custom_forward(module):
|
528 |
+
def custom_forward(*inputs):
|
529 |
+
return module(*inputs)
|
530 |
+
|
531 |
+
return custom_forward
|
532 |
+
|
533 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
534 |
+
else:
|
535 |
+
hidden_states = resnet(hidden_states, temb)
|
536 |
+
|
537 |
+
output_states += (hidden_states,)
|
538 |
+
|
539 |
+
if self.downsamplers is not None:
|
540 |
+
for downsampler in self.downsamplers:
|
541 |
+
hidden_states = downsampler(hidden_states)
|
542 |
+
|
543 |
+
output_states += (hidden_states,)
|
544 |
+
|
545 |
+
return hidden_states, output_states
|
546 |
+
|
547 |
+
|
548 |
+
class Downsample2D(nn.Module):
|
549 |
+
def __init__(self, channels, out_channels):
|
550 |
+
super().__init__()
|
551 |
+
|
552 |
+
self.channels = channels
|
553 |
+
self.out_channels = out_channels
|
554 |
+
|
555 |
+
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
|
556 |
+
|
557 |
+
def forward(self, hidden_states):
|
558 |
+
assert hidden_states.shape[1] == self.channels
|
559 |
+
hidden_states = self.conv(hidden_states)
|
560 |
+
|
561 |
+
return hidden_states
|
562 |
+
|
563 |
+
|
564 |
+
class CrossAttention(nn.Module):
|
565 |
+
def __init__(
|
566 |
+
self,
|
567 |
+
query_dim: int,
|
568 |
+
cross_attention_dim: Optional[int] = None,
|
569 |
+
heads: int = 8,
|
570 |
+
dim_head: int = 64,
|
571 |
+
upcast_attention: bool = False,
|
572 |
+
):
|
573 |
+
super().__init__()
|
574 |
+
inner_dim = dim_head * heads
|
575 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
576 |
+
self.upcast_attention = upcast_attention
|
577 |
+
|
578 |
+
self.scale = dim_head**-0.5
|
579 |
+
self.heads = heads
|
580 |
+
|
581 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
582 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
583 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
584 |
+
|
585 |
+
self.to_out = nn.ModuleList([])
|
586 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
587 |
+
# no dropout here
|
588 |
+
|
589 |
+
self.use_memory_efficient_attention_xformers = False
|
590 |
+
self.use_memory_efficient_attention_mem_eff = False
|
591 |
+
self.use_sdpa = False
|
592 |
+
|
593 |
+
# Attention processor
|
594 |
+
self.processor = None
|
595 |
+
|
596 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
597 |
+
self.use_memory_efficient_attention_xformers = xformers
|
598 |
+
self.use_memory_efficient_attention_mem_eff = mem_eff
|
599 |
+
|
600 |
+
def set_use_sdpa(self, sdpa):
|
601 |
+
self.use_sdpa = sdpa
|
602 |
+
|
603 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
604 |
+
batch_size, seq_len, dim = tensor.shape
|
605 |
+
head_size = self.heads
|
606 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
607 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
608 |
+
return tensor
|
609 |
+
|
610 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
611 |
+
batch_size, seq_len, dim = tensor.shape
|
612 |
+
head_size = self.heads
|
613 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
614 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
615 |
+
return tensor
|
616 |
+
|
617 |
+
def set_processor(self):
|
618 |
+
return self.processor
|
619 |
+
|
620 |
+
def get_processor(self):
|
621 |
+
return self.processor
|
622 |
+
|
623 |
+
def forward(self, hidden_states, context=None, mask=None, **kwargs):
|
624 |
+
if self.processor is not None:
|
625 |
+
(
|
626 |
+
hidden_states,
|
627 |
+
encoder_hidden_states,
|
628 |
+
attention_mask,
|
629 |
+
) = translate_attention_names_from_diffusers(
|
630 |
+
hidden_states=hidden_states, context=context, mask=mask, **kwargs
|
631 |
+
)
|
632 |
+
return self.processor(
|
633 |
+
attn=self,
|
634 |
+
hidden_states=hidden_states,
|
635 |
+
encoder_hidden_states=context,
|
636 |
+
attention_mask=mask,
|
637 |
+
**kwargs
|
638 |
+
)
|
639 |
+
if self.use_memory_efficient_attention_xformers:
|
640 |
+
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
641 |
+
if self.use_memory_efficient_attention_mem_eff:
|
642 |
+
return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
|
643 |
+
if self.use_sdpa:
|
644 |
+
return self.forward_sdpa(hidden_states, context, mask)
|
645 |
+
|
646 |
+
query = self.to_q(hidden_states)
|
647 |
+
context = context if context is not None else hidden_states
|
648 |
+
key = self.to_k(context)
|
649 |
+
value = self.to_v(context)
|
650 |
+
|
651 |
+
query = self.reshape_heads_to_batch_dim(query)
|
652 |
+
key = self.reshape_heads_to_batch_dim(key)
|
653 |
+
value = self.reshape_heads_to_batch_dim(value)
|
654 |
+
|
655 |
+
hidden_states = self._attention(query, key, value)
|
656 |
+
|
657 |
+
# linear proj
|
658 |
+
hidden_states = self.to_out[0](hidden_states)
|
659 |
+
# hidden_states = self.to_out[1](hidden_states) # no dropout
|
660 |
+
return hidden_states
|
661 |
+
|
662 |
+
def _attention(self, query, key, value):
|
663 |
+
if self.upcast_attention:
|
664 |
+
query = query.float()
|
665 |
+
key = key.float()
|
666 |
+
|
667 |
+
attention_scores = torch.baddbmm(
|
668 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
669 |
+
query,
|
670 |
+
key.transpose(-1, -2),
|
671 |
+
beta=0,
|
672 |
+
alpha=self.scale,
|
673 |
+
)
|
674 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
675 |
+
|
676 |
+
# cast back to the original dtype
|
677 |
+
attention_probs = attention_probs.to(value.dtype)
|
678 |
+
|
679 |
+
# compute attention output
|
680 |
+
hidden_states = torch.bmm(attention_probs, value)
|
681 |
+
|
682 |
+
# reshape hidden_states
|
683 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
684 |
+
return hidden_states
|
685 |
+
|
686 |
+
# TODO support Hypernetworks
|
687 |
+
def forward_memory_efficient_xformers(self, x, context=None, mask=None):
|
688 |
+
import xformers.ops
|
689 |
+
|
690 |
+
h = self.heads
|
691 |
+
q_in = self.to_q(x)
|
692 |
+
context = context if context is not None else x
|
693 |
+
context = context.to(x.dtype)
|
694 |
+
k_in = self.to_k(context)
|
695 |
+
v_in = self.to_v(context)
|
696 |
+
|
697 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
698 |
+
del q_in, k_in, v_in
|
699 |
+
|
700 |
+
q = q.contiguous()
|
701 |
+
k = k.contiguous()
|
702 |
+
v = v.contiguous()
|
703 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
704 |
+
|
705 |
+
out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
706 |
+
|
707 |
+
out = self.to_out[0](out)
|
708 |
+
return out
|
709 |
+
|
710 |
+
def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
|
711 |
+
flash_func = FlashAttentionFunction
|
712 |
+
|
713 |
+
q_bucket_size = 512
|
714 |
+
k_bucket_size = 1024
|
715 |
+
|
716 |
+
h = self.heads
|
717 |
+
q = self.to_q(x)
|
718 |
+
context = context if context is not None else x
|
719 |
+
context = context.to(x.dtype)
|
720 |
+
k = self.to_k(context)
|
721 |
+
v = self.to_v(context)
|
722 |
+
del context, x
|
723 |
+
|
724 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
725 |
+
|
726 |
+
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
727 |
+
|
728 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
729 |
+
|
730 |
+
out = self.to_out[0](out)
|
731 |
+
return out
|
732 |
+
|
733 |
+
def forward_sdpa(self, x, context=None, mask=None):
|
734 |
+
h = self.heads
|
735 |
+
q_in = self.to_q(x)
|
736 |
+
context = context if context is not None else x
|
737 |
+
context = context.to(x.dtype)
|
738 |
+
k_in = self.to_k(context)
|
739 |
+
v_in = self.to_v(context)
|
740 |
+
|
741 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
|
742 |
+
del q_in, k_in, v_in
|
743 |
+
|
744 |
+
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
745 |
+
|
746 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
747 |
+
|
748 |
+
out = self.to_out[0](out)
|
749 |
+
return out
|
750 |
+
|
751 |
+
def translate_attention_names_from_diffusers(
|
752 |
+
hidden_states: torch.FloatTensor,
|
753 |
+
context: Optional[torch.FloatTensor] = None,
|
754 |
+
mask: Optional[torch.FloatTensor] = None,
|
755 |
+
# HF naming
|
756 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
757 |
+
attention_mask: Optional[torch.FloatTensor] = None
|
758 |
+
):
|
759 |
+
# translate from hugging face diffusers
|
760 |
+
context = context if context is not None else encoder_hidden_states
|
761 |
+
|
762 |
+
# translate from hugging face diffusers
|
763 |
+
mask = mask if mask is not None else attention_mask
|
764 |
+
|
765 |
+
return hidden_states, context, mask
|
766 |
+
|
767 |
+
# feedforward
|
768 |
+
class GEGLU(nn.Module):
|
769 |
+
r"""
|
770 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
771 |
+
|
772 |
+
Parameters:
|
773 |
+
dim_in (`int`): The number of channels in the input.
|
774 |
+
dim_out (`int`): The number of channels in the output.
|
775 |
+
"""
|
776 |
+
|
777 |
+
def __init__(self, dim_in: int, dim_out: int):
|
778 |
+
super().__init__()
|
779 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
780 |
+
|
781 |
+
def gelu(self, gate):
|
782 |
+
if gate.device.type != "mps":
|
783 |
+
return F.gelu(gate)
|
784 |
+
# mps: gelu is not implemented for float16
|
785 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
786 |
+
|
787 |
+
def forward(self, hidden_states):
|
788 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
789 |
+
return hidden_states * self.gelu(gate)
|
790 |
+
|
791 |
+
|
792 |
+
class FeedForward(nn.Module):
|
793 |
+
def __init__(
|
794 |
+
self,
|
795 |
+
dim: int,
|
796 |
+
):
|
797 |
+
super().__init__()
|
798 |
+
inner_dim = int(dim * 4) # mult is always 4
|
799 |
+
|
800 |
+
self.net = nn.ModuleList([])
|
801 |
+
# project in
|
802 |
+
self.net.append(GEGLU(dim, inner_dim))
|
803 |
+
# project dropout
|
804 |
+
self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0
|
805 |
+
# project out
|
806 |
+
self.net.append(nn.Linear(inner_dim, dim))
|
807 |
+
|
808 |
+
def forward(self, hidden_states):
|
809 |
+
for module in self.net:
|
810 |
+
hidden_states = module(hidden_states)
|
811 |
+
return hidden_states
|
812 |
+
|
813 |
+
|
814 |
+
class BasicTransformerBlock(nn.Module):
|
815 |
+
def __init__(
|
816 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
|
817 |
+
):
|
818 |
+
super().__init__()
|
819 |
+
|
820 |
+
# 1. Self-Attn
|
821 |
+
self.attn1 = CrossAttention(
|
822 |
+
query_dim=dim,
|
823 |
+
cross_attention_dim=None,
|
824 |
+
heads=num_attention_heads,
|
825 |
+
dim_head=attention_head_dim,
|
826 |
+
upcast_attention=upcast_attention,
|
827 |
+
)
|
828 |
+
self.ff = FeedForward(dim)
|
829 |
+
|
830 |
+
# 2. Cross-Attn
|
831 |
+
self.attn2 = CrossAttention(
|
832 |
+
query_dim=dim,
|
833 |
+
cross_attention_dim=cross_attention_dim,
|
834 |
+
heads=num_attention_heads,
|
835 |
+
dim_head=attention_head_dim,
|
836 |
+
upcast_attention=upcast_attention,
|
837 |
+
)
|
838 |
+
|
839 |
+
self.norm1 = nn.LayerNorm(dim)
|
840 |
+
self.norm2 = nn.LayerNorm(dim)
|
841 |
+
|
842 |
+
# 3. Feed-forward
|
843 |
+
self.norm3 = nn.LayerNorm(dim)
|
844 |
+
|
845 |
+
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
|
846 |
+
self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
|
847 |
+
self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
|
848 |
+
|
849 |
+
def set_use_sdpa(self, sdpa: bool):
|
850 |
+
self.attn1.set_use_sdpa(sdpa)
|
851 |
+
self.attn2.set_use_sdpa(sdpa)
|
852 |
+
|
853 |
+
def forward(self, hidden_states, context=None, timestep=None):
|
854 |
+
# 1. Self-Attention
|
855 |
+
norm_hidden_states = self.norm1(hidden_states)
|
856 |
+
|
857 |
+
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
858 |
+
|
859 |
+
# 2. Cross-Attention
|
860 |
+
norm_hidden_states = self.norm2(hidden_states)
|
861 |
+
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
|
862 |
+
|
863 |
+
# 3. Feed-forward
|
864 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
865 |
+
|
866 |
+
return hidden_states
|
867 |
+
|
868 |
+
|
869 |
+
class Transformer2DModel(nn.Module):
|
870 |
+
def __init__(
|
871 |
+
self,
|
872 |
+
num_attention_heads: int = 16,
|
873 |
+
attention_head_dim: int = 88,
|
874 |
+
in_channels: Optional[int] = None,
|
875 |
+
cross_attention_dim: Optional[int] = None,
|
876 |
+
use_linear_projection: bool = False,
|
877 |
+
upcast_attention: bool = False,
|
878 |
+
):
|
879 |
+
super().__init__()
|
880 |
+
self.in_channels = in_channels
|
881 |
+
self.num_attention_heads = num_attention_heads
|
882 |
+
self.attention_head_dim = attention_head_dim
|
883 |
+
inner_dim = num_attention_heads * attention_head_dim
|
884 |
+
self.use_linear_projection = use_linear_projection
|
885 |
+
|
886 |
+
self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)
|
887 |
+
|
888 |
+
if use_linear_projection:
|
889 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
890 |
+
else:
|
891 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
892 |
+
|
893 |
+
self.transformer_blocks = nn.ModuleList(
|
894 |
+
[
|
895 |
+
BasicTransformerBlock(
|
896 |
+
inner_dim,
|
897 |
+
num_attention_heads,
|
898 |
+
attention_head_dim,
|
899 |
+
cross_attention_dim=cross_attention_dim,
|
900 |
+
upcast_attention=upcast_attention,
|
901 |
+
)
|
902 |
+
]
|
903 |
+
)
|
904 |
+
|
905 |
+
if use_linear_projection:
|
906 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
907 |
+
else:
|
908 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
909 |
+
|
910 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
911 |
+
for transformer in self.transformer_blocks:
|
912 |
+
transformer.set_use_memory_efficient_attention(xformers, mem_eff)
|
913 |
+
|
914 |
+
def set_use_sdpa(self, sdpa):
|
915 |
+
for transformer in self.transformer_blocks:
|
916 |
+
transformer.set_use_sdpa(sdpa)
|
917 |
+
|
918 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
919 |
+
# 1. Input
|
920 |
+
batch, _, height, weight = hidden_states.shape
|
921 |
+
residual = hidden_states
|
922 |
+
|
923 |
+
hidden_states = self.norm(hidden_states)
|
924 |
+
if not self.use_linear_projection:
|
925 |
+
hidden_states = self.proj_in(hidden_states)
|
926 |
+
inner_dim = hidden_states.shape[1]
|
927 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
928 |
+
else:
|
929 |
+
inner_dim = hidden_states.shape[1]
|
930 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
931 |
+
hidden_states = self.proj_in(hidden_states)
|
932 |
+
|
933 |
+
# 2. Blocks
|
934 |
+
for block in self.transformer_blocks:
|
935 |
+
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
|
936 |
+
|
937 |
+
# 3. Output
|
938 |
+
if not self.use_linear_projection:
|
939 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
940 |
+
hidden_states = self.proj_out(hidden_states)
|
941 |
+
else:
|
942 |
+
hidden_states = self.proj_out(hidden_states)
|
943 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
944 |
+
|
945 |
+
output = hidden_states + residual
|
946 |
+
|
947 |
+
if not return_dict:
|
948 |
+
return (output,)
|
949 |
+
|
950 |
+
return SampleOutput(sample=output)
|
951 |
+
|
952 |
+
|
953 |
+
class CrossAttnDownBlock2D(nn.Module):
|
954 |
+
def __init__(
|
955 |
+
self,
|
956 |
+
in_channels: int,
|
957 |
+
out_channels: int,
|
958 |
+
add_downsample=True,
|
959 |
+
cross_attention_dim=1280,
|
960 |
+
attn_num_head_channels=1,
|
961 |
+
use_linear_projection=False,
|
962 |
+
upcast_attention=False,
|
963 |
+
):
|
964 |
+
super().__init__()
|
965 |
+
self.has_cross_attention = True
|
966 |
+
resnets = []
|
967 |
+
attentions = []
|
968 |
+
|
969 |
+
self.attn_num_head_channels = attn_num_head_channels
|
970 |
+
|
971 |
+
for i in range(LAYERS_PER_BLOCK):
|
972 |
+
in_channels = in_channels if i == 0 else out_channels
|
973 |
+
|
974 |
+
resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels))
|
975 |
+
attentions.append(
|
976 |
+
Transformer2DModel(
|
977 |
+
attn_num_head_channels,
|
978 |
+
out_channels // attn_num_head_channels,
|
979 |
+
in_channels=out_channels,
|
980 |
+
cross_attention_dim=cross_attention_dim,
|
981 |
+
use_linear_projection=use_linear_projection,
|
982 |
+
upcast_attention=upcast_attention,
|
983 |
+
)
|
984 |
+
)
|
985 |
+
self.attentions = nn.ModuleList(attentions)
|
986 |
+
self.resnets = nn.ModuleList(resnets)
|
987 |
+
|
988 |
+
if add_downsample:
|
989 |
+
self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
|
990 |
+
else:
|
991 |
+
self.downsamplers = None
|
992 |
+
|
993 |
+
self.gradient_checkpointing = False
|
994 |
+
|
995 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
996 |
+
for attn in self.attentions:
|
997 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
998 |
+
|
999 |
+
def set_use_sdpa(self, sdpa):
|
1000 |
+
for attn in self.attentions:
|
1001 |
+
attn.set_use_sdpa(sdpa)
|
1002 |
+
|
1003 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
1004 |
+
output_states = ()
|
1005 |
+
|
1006 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1007 |
+
if self.training and self.gradient_checkpointing:
|
1008 |
+
|
1009 |
+
def create_custom_forward(module, return_dict=None):
|
1010 |
+
def custom_forward(*inputs):
|
1011 |
+
if return_dict is not None:
|
1012 |
+
return module(*inputs, return_dict=return_dict)
|
1013 |
+
else:
|
1014 |
+
return module(*inputs)
|
1015 |
+
|
1016 |
+
return custom_forward
|
1017 |
+
|
1018 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1019 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1020 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
1021 |
+
)[0]
|
1022 |
+
else:
|
1023 |
+
hidden_states = resnet(hidden_states, temb)
|
1024 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
1025 |
+
|
1026 |
+
output_states += (hidden_states,)
|
1027 |
+
|
1028 |
+
if self.downsamplers is not None:
|
1029 |
+
for downsampler in self.downsamplers:
|
1030 |
+
hidden_states = downsampler(hidden_states)
|
1031 |
+
|
1032 |
+
output_states += (hidden_states,)
|
1033 |
+
|
1034 |
+
return hidden_states, output_states
|
1035 |
+
|
1036 |
+
|
1037 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
1038 |
+
def __init__(
|
1039 |
+
self,
|
1040 |
+
in_channels: int,
|
1041 |
+
attn_num_head_channels=1,
|
1042 |
+
cross_attention_dim=1280,
|
1043 |
+
use_linear_projection=False,
|
1044 |
+
):
|
1045 |
+
super().__init__()
|
1046 |
+
|
1047 |
+
self.has_cross_attention = True
|
1048 |
+
self.attn_num_head_channels = attn_num_head_channels
|
1049 |
+
|
1050 |
+
# Middle block has two resnets and one attention
|
1051 |
+
resnets = [
|
1052 |
+
ResnetBlock2D(
|
1053 |
+
in_channels=in_channels,
|
1054 |
+
out_channels=in_channels,
|
1055 |
+
),
|
1056 |
+
ResnetBlock2D(
|
1057 |
+
in_channels=in_channels,
|
1058 |
+
out_channels=in_channels,
|
1059 |
+
),
|
1060 |
+
]
|
1061 |
+
attentions = [
|
1062 |
+
Transformer2DModel(
|
1063 |
+
attn_num_head_channels,
|
1064 |
+
in_channels // attn_num_head_channels,
|
1065 |
+
in_channels=in_channels,
|
1066 |
+
cross_attention_dim=cross_attention_dim,
|
1067 |
+
use_linear_projection=use_linear_projection,
|
1068 |
+
)
|
1069 |
+
]
|
1070 |
+
|
1071 |
+
self.attentions = nn.ModuleList(attentions)
|
1072 |
+
self.resnets = nn.ModuleList(resnets)
|
1073 |
+
|
1074 |
+
self.gradient_checkpointing = False
|
1075 |
+
|
1076 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
1077 |
+
for attn in self.attentions:
|
1078 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
1079 |
+
|
1080 |
+
def set_use_sdpa(self, sdpa):
|
1081 |
+
for attn in self.attentions:
|
1082 |
+
attn.set_use_sdpa(sdpa)
|
1083 |
+
|
1084 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
1085 |
+
for i, resnet in enumerate(self.resnets):
|
1086 |
+
attn = None if i == 0 else self.attentions[i - 1]
|
1087 |
+
|
1088 |
+
if self.training and self.gradient_checkpointing:
|
1089 |
+
|
1090 |
+
def create_custom_forward(module, return_dict=None):
|
1091 |
+
def custom_forward(*inputs):
|
1092 |
+
if return_dict is not None:
|
1093 |
+
return module(*inputs, return_dict=return_dict)
|
1094 |
+
else:
|
1095 |
+
return module(*inputs)
|
1096 |
+
|
1097 |
+
return custom_forward
|
1098 |
+
|
1099 |
+
if attn is not None:
|
1100 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1101 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
1102 |
+
)[0]
|
1103 |
+
|
1104 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1105 |
+
else:
|
1106 |
+
if attn is not None:
|
1107 |
+
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
1108 |
+
hidden_states = resnet(hidden_states, temb)
|
1109 |
+
|
1110 |
+
return hidden_states
|
1111 |
+
|
1112 |
+
|
1113 |
+
class Upsample2D(nn.Module):
|
1114 |
+
def __init__(self, channels, out_channels):
|
1115 |
+
super().__init__()
|
1116 |
+
self.channels = channels
|
1117 |
+
self.out_channels = out_channels
|
1118 |
+
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
|
1119 |
+
|
1120 |
+
def forward(self, hidden_states, output_size):
|
1121 |
+
assert hidden_states.shape[1] == self.channels
|
1122 |
+
|
1123 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
1124 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
1125 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
1126 |
+
dtype = hidden_states.dtype
|
1127 |
+
if dtype == torch.bfloat16:
|
1128 |
+
hidden_states = hidden_states.to(torch.float32)
|
1129 |
+
|
1130 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
1131 |
+
if hidden_states.shape[0] >= 64:
|
1132 |
+
hidden_states = hidden_states.contiguous()
|
1133 |
+
|
1134 |
+
# if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
|
1135 |
+
if output_size is None:
|
1136 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
1137 |
+
else:
|
1138 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
1139 |
+
|
1140 |
+
# If the input is bfloat16, we cast back to bfloat16
|
1141 |
+
if dtype == torch.bfloat16:
|
1142 |
+
hidden_states = hidden_states.to(dtype)
|
1143 |
+
|
1144 |
+
hidden_states = self.conv(hidden_states)
|
1145 |
+
|
1146 |
+
return hidden_states
|
1147 |
+
|
1148 |
+
|
1149 |
+
class UpBlock2D(nn.Module):
|
1150 |
+
def __init__(
|
1151 |
+
self,
|
1152 |
+
in_channels: int,
|
1153 |
+
prev_output_channel: int,
|
1154 |
+
out_channels: int,
|
1155 |
+
add_upsample=True,
|
1156 |
+
):
|
1157 |
+
super().__init__()
|
1158 |
+
|
1159 |
+
self.has_cross_attention = False
|
1160 |
+
resnets = []
|
1161 |
+
|
1162 |
+
for i in range(LAYERS_PER_BLOCK_UP):
|
1163 |
+
res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
|
1164 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1165 |
+
|
1166 |
+
resnets.append(
|
1167 |
+
ResnetBlock2D(
|
1168 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1169 |
+
out_channels=out_channels,
|
1170 |
+
)
|
1171 |
+
)
|
1172 |
+
|
1173 |
+
self.resnets = nn.ModuleList(resnets)
|
1174 |
+
|
1175 |
+
if add_upsample:
|
1176 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
|
1177 |
+
else:
|
1178 |
+
self.upsamplers = None
|
1179 |
+
|
1180 |
+
self.gradient_checkpointing = False
|
1181 |
+
|
1182 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
1183 |
+
pass
|
1184 |
+
|
1185 |
+
def set_use_sdpa(self, sdpa):
|
1186 |
+
pass
|
1187 |
+
|
1188 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
1189 |
+
for resnet in self.resnets:
|
1190 |
+
# pop res hidden states
|
1191 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1192 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1193 |
+
|
1194 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1195 |
+
|
1196 |
+
if self.training and self.gradient_checkpointing:
|
1197 |
+
|
1198 |
+
def create_custom_forward(module):
|
1199 |
+
def custom_forward(*inputs):
|
1200 |
+
return module(*inputs)
|
1201 |
+
|
1202 |
+
return custom_forward
|
1203 |
+
|
1204 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1205 |
+
else:
|
1206 |
+
hidden_states = resnet(hidden_states, temb)
|
1207 |
+
|
1208 |
+
if self.upsamplers is not None:
|
1209 |
+
for upsampler in self.upsamplers:
|
1210 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1211 |
+
|
1212 |
+
return hidden_states
|
1213 |
+
|
1214 |
+
|
1215 |
+
class CrossAttnUpBlock2D(nn.Module):
|
1216 |
+
def __init__(
|
1217 |
+
self,
|
1218 |
+
in_channels: int,
|
1219 |
+
out_channels: int,
|
1220 |
+
prev_output_channel: int,
|
1221 |
+
attn_num_head_channels=1,
|
1222 |
+
cross_attention_dim=1280,
|
1223 |
+
add_upsample=True,
|
1224 |
+
use_linear_projection=False,
|
1225 |
+
upcast_attention=False,
|
1226 |
+
):
|
1227 |
+
super().__init__()
|
1228 |
+
resnets = []
|
1229 |
+
attentions = []
|
1230 |
+
|
1231 |
+
self.has_cross_attention = True
|
1232 |
+
self.attn_num_head_channels = attn_num_head_channels
|
1233 |
+
|
1234 |
+
for i in range(LAYERS_PER_BLOCK_UP):
|
1235 |
+
res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
|
1236 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1237 |
+
|
1238 |
+
resnets.append(
|
1239 |
+
ResnetBlock2D(
|
1240 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1241 |
+
out_channels=out_channels,
|
1242 |
+
)
|
1243 |
+
)
|
1244 |
+
attentions.append(
|
1245 |
+
Transformer2DModel(
|
1246 |
+
attn_num_head_channels,
|
1247 |
+
out_channels // attn_num_head_channels,
|
1248 |
+
in_channels=out_channels,
|
1249 |
+
cross_attention_dim=cross_attention_dim,
|
1250 |
+
use_linear_projection=use_linear_projection,
|
1251 |
+
upcast_attention=upcast_attention,
|
1252 |
+
)
|
1253 |
+
)
|
1254 |
+
|
1255 |
+
self.attentions = nn.ModuleList(attentions)
|
1256 |
+
self.resnets = nn.ModuleList(resnets)
|
1257 |
+
|
1258 |
+
if add_upsample:
|
1259 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
|
1260 |
+
else:
|
1261 |
+
self.upsamplers = None
|
1262 |
+
|
1263 |
+
self.gradient_checkpointing = False
|
1264 |
+
|
1265 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
1266 |
+
for attn in self.attentions:
|
1267 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
1268 |
+
|
1269 |
+
def set_use_sdpa(self, sdpa):
|
1270 |
+
for attn in self.attentions:
|
1271 |
+
attn.set_use_sdpa(sdpa)
|
1272 |
+
|
1273 |
+
def forward(
|
1274 |
+
self,
|
1275 |
+
hidden_states,
|
1276 |
+
res_hidden_states_tuple,
|
1277 |
+
temb=None,
|
1278 |
+
encoder_hidden_states=None,
|
1279 |
+
upsample_size=None,
|
1280 |
+
):
|
1281 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1282 |
+
# pop res hidden states
|
1283 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1284 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1285 |
+
|
1286 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1287 |
+
|
1288 |
+
if self.training and self.gradient_checkpointing:
|
1289 |
+
|
1290 |
+
def create_custom_forward(module, return_dict=None):
|
1291 |
+
def custom_forward(*inputs):
|
1292 |
+
if return_dict is not None:
|
1293 |
+
return module(*inputs, return_dict=return_dict)
|
1294 |
+
else:
|
1295 |
+
return module(*inputs)
|
1296 |
+
|
1297 |
+
return custom_forward
|
1298 |
+
|
1299 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1300 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1301 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
1302 |
+
)[0]
|
1303 |
+
else:
|
1304 |
+
hidden_states = resnet(hidden_states, temb)
|
1305 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
1306 |
+
|
1307 |
+
if self.upsamplers is not None:
|
1308 |
+
for upsampler in self.upsamplers:
|
1309 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1310 |
+
|
1311 |
+
return hidden_states
|
1312 |
+
|
1313 |
+
|
1314 |
+
def get_down_block(
|
1315 |
+
down_block_type,
|
1316 |
+
in_channels,
|
1317 |
+
out_channels,
|
1318 |
+
add_downsample,
|
1319 |
+
attn_num_head_channels,
|
1320 |
+
cross_attention_dim,
|
1321 |
+
use_linear_projection,
|
1322 |
+
upcast_attention,
|
1323 |
+
):
|
1324 |
+
if down_block_type == "DownBlock2D":
|
1325 |
+
return DownBlock2D(
|
1326 |
+
in_channels=in_channels,
|
1327 |
+
out_channels=out_channels,
|
1328 |
+
add_downsample=add_downsample,
|
1329 |
+
)
|
1330 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
1331 |
+
return CrossAttnDownBlock2D(
|
1332 |
+
in_channels=in_channels,
|
1333 |
+
out_channels=out_channels,
|
1334 |
+
add_downsample=add_downsample,
|
1335 |
+
cross_attention_dim=cross_attention_dim,
|
1336 |
+
attn_num_head_channels=attn_num_head_channels,
|
1337 |
+
use_linear_projection=use_linear_projection,
|
1338 |
+
upcast_attention=upcast_attention,
|
1339 |
+
)
|
1340 |
+
|
1341 |
+
|
1342 |
+
def get_up_block(
|
1343 |
+
up_block_type,
|
1344 |
+
in_channels,
|
1345 |
+
out_channels,
|
1346 |
+
prev_output_channel,
|
1347 |
+
add_upsample,
|
1348 |
+
attn_num_head_channels,
|
1349 |
+
cross_attention_dim=None,
|
1350 |
+
use_linear_projection=False,
|
1351 |
+
upcast_attention=False,
|
1352 |
+
):
|
1353 |
+
if up_block_type == "UpBlock2D":
|
1354 |
+
return UpBlock2D(
|
1355 |
+
in_channels=in_channels,
|
1356 |
+
prev_output_channel=prev_output_channel,
|
1357 |
+
out_channels=out_channels,
|
1358 |
+
add_upsample=add_upsample,
|
1359 |
+
)
|
1360 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
1361 |
+
return CrossAttnUpBlock2D(
|
1362 |
+
in_channels=in_channels,
|
1363 |
+
out_channels=out_channels,
|
1364 |
+
prev_output_channel=prev_output_channel,
|
1365 |
+
attn_num_head_channels=attn_num_head_channels,
|
1366 |
+
cross_attention_dim=cross_attention_dim,
|
1367 |
+
add_upsample=add_upsample,
|
1368 |
+
use_linear_projection=use_linear_projection,
|
1369 |
+
upcast_attention=upcast_attention,
|
1370 |
+
)
|
1371 |
+
|
1372 |
+
|
1373 |
+
class UNet2DConditionModel(nn.Module):
|
1374 |
+
_supports_gradient_checkpointing = True
|
1375 |
+
|
1376 |
+
def __init__(
|
1377 |
+
self,
|
1378 |
+
sample_size: Optional[int] = None,
|
1379 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
1380 |
+
cross_attention_dim: int = 1280,
|
1381 |
+
use_linear_projection: bool = False,
|
1382 |
+
upcast_attention: bool = False,
|
1383 |
+
**kwargs,
|
1384 |
+
):
|
1385 |
+
super().__init__()
|
1386 |
+
assert sample_size is not None, "sample_size must be specified"
|
1387 |
+
logger.info(
|
1388 |
+
f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
|
1389 |
+
)
|
1390 |
+
|
1391 |
+
# 外部からの参照用に定義しておく
|
1392 |
+
self.in_channels = IN_CHANNELS
|
1393 |
+
self.out_channels = OUT_CHANNELS
|
1394 |
+
|
1395 |
+
self.sample_size = sample_size
|
1396 |
+
self.prepare_config(sample_size=sample_size)
|
1397 |
+
|
1398 |
+
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
|
1399 |
+
|
1400 |
+
# input
|
1401 |
+
self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))
|
1402 |
+
|
1403 |
+
# time
|
1404 |
+
self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
|
1405 |
+
|
1406 |
+
self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)
|
1407 |
+
|
1408 |
+
self.down_blocks = nn.ModuleList([])
|
1409 |
+
self.mid_block = None
|
1410 |
+
self.up_blocks = nn.ModuleList([])
|
1411 |
+
|
1412 |
+
if isinstance(attention_head_dim, int):
|
1413 |
+
attention_head_dim = (attention_head_dim,) * 4
|
1414 |
+
|
1415 |
+
# down
|
1416 |
+
output_channel = BLOCK_OUT_CHANNELS[0]
|
1417 |
+
for i, down_block_type in enumerate(DOWN_BLOCK_TYPES):
|
1418 |
+
input_channel = output_channel
|
1419 |
+
output_channel = BLOCK_OUT_CHANNELS[i]
|
1420 |
+
is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
|
1421 |
+
|
1422 |
+
down_block = get_down_block(
|
1423 |
+
down_block_type,
|
1424 |
+
in_channels=input_channel,
|
1425 |
+
out_channels=output_channel,
|
1426 |
+
add_downsample=not is_final_block,
|
1427 |
+
attn_num_head_channels=attention_head_dim[i],
|
1428 |
+
cross_attention_dim=cross_attention_dim,
|
1429 |
+
use_linear_projection=use_linear_projection,
|
1430 |
+
upcast_attention=upcast_attention,
|
1431 |
+
)
|
1432 |
+
self.down_blocks.append(down_block)
|
1433 |
+
|
1434 |
+
# mid
|
1435 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
1436 |
+
in_channels=BLOCK_OUT_CHANNELS[-1],
|
1437 |
+
attn_num_head_channels=attention_head_dim[-1],
|
1438 |
+
cross_attention_dim=cross_attention_dim,
|
1439 |
+
use_linear_projection=use_linear_projection,
|
1440 |
+
)
|
1441 |
+
|
1442 |
+
# count how many layers upsample the images
|
1443 |
+
self.num_upsamplers = 0
|
1444 |
+
|
1445 |
+
# up
|
1446 |
+
reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS))
|
1447 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
1448 |
+
output_channel = reversed_block_out_channels[0]
|
1449 |
+
for i, up_block_type in enumerate(UP_BLOCK_TYPES):
|
1450 |
+
is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
|
1451 |
+
|
1452 |
+
prev_output_channel = output_channel
|
1453 |
+
output_channel = reversed_block_out_channels[i]
|
1454 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)]
|
1455 |
+
|
1456 |
+
# add upsample block for all BUT final layer
|
1457 |
+
if not is_final_block:
|
1458 |
+
add_upsample = True
|
1459 |
+
self.num_upsamplers += 1
|
1460 |
+
else:
|
1461 |
+
add_upsample = False
|
1462 |
+
|
1463 |
+
up_block = get_up_block(
|
1464 |
+
up_block_type,
|
1465 |
+
in_channels=input_channel,
|
1466 |
+
out_channels=output_channel,
|
1467 |
+
prev_output_channel=prev_output_channel,
|
1468 |
+
add_upsample=add_upsample,
|
1469 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
1470 |
+
cross_attention_dim=cross_attention_dim,
|
1471 |
+
use_linear_projection=use_linear_projection,
|
1472 |
+
upcast_attention=upcast_attention,
|
1473 |
+
)
|
1474 |
+
self.up_blocks.append(up_block)
|
1475 |
+
prev_output_channel = output_channel
|
1476 |
+
|
1477 |
+
# out
|
1478 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
|
1479 |
+
self.conv_act = nn.SiLU()
|
1480 |
+
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
1481 |
+
|
1482 |
+
# region diffusers compatibility
|
1483 |
+
def prepare_config(self, *args, **kwargs):
|
1484 |
+
self.config = SimpleNamespace(**kwargs)
|
1485 |
+
|
1486 |
+
@property
|
1487 |
+
def dtype(self) -> torch.dtype:
|
1488 |
+
# `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
1489 |
+
return get_parameter_dtype(self)
|
1490 |
+
|
1491 |
+
@property
|
1492 |
+
def device(self) -> torch.device:
|
1493 |
+
# `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
|
1494 |
+
return get_parameter_device(self)
|
1495 |
+
|
1496 |
+
def set_attention_slice(self, slice_size):
|
1497 |
+
raise NotImplementedError("Attention slicing is not supported for this model.")
|
1498 |
+
|
1499 |
+
def is_gradient_checkpointing(self) -> bool:
|
1500 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
1501 |
+
|
1502 |
+
def enable_gradient_checkpointing(self):
|
1503 |
+
self.set_gradient_checkpointing(value=True)
|
1504 |
+
|
1505 |
+
def disable_gradient_checkpointing(self):
|
1506 |
+
self.set_gradient_checkpointing(value=False)
|
1507 |
+
|
1508 |
+
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
|
1509 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
1510 |
+
for module in modules:
|
1511 |
+
module.set_use_memory_efficient_attention(xformers, mem_eff)
|
1512 |
+
|
1513 |
+
def set_use_sdpa(self, sdpa: bool) -> None:
|
1514 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
1515 |
+
for module in modules:
|
1516 |
+
module.set_use_sdpa(sdpa)
|
1517 |
+
|
1518 |
+
def set_gradient_checkpointing(self, value=False):
|
1519 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
1520 |
+
for module in modules:
|
1521 |
+
logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
|
1522 |
+
module.gradient_checkpointing = value
|
1523 |
+
|
1524 |
+
# endregion
|
1525 |
+
|
1526 |
+
def forward(
|
1527 |
+
self,
|
1528 |
+
sample: torch.FloatTensor,
|
1529 |
+
timestep: Union[torch.Tensor, float, int],
|
1530 |
+
encoder_hidden_states: torch.Tensor,
|
1531 |
+
class_labels: Optional[torch.Tensor] = None,
|
1532 |
+
return_dict: bool = True,
|
1533 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1534 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
1535 |
+
) -> Union[Dict, Tuple]:
|
1536 |
+
r"""
|
1537 |
+
Args:
|
1538 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
1539 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
1540 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
1541 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1542 |
+
Whether or not to return a dict instead of a plain tuple.
|
1543 |
+
|
1544 |
+
Returns:
|
1545 |
+
`SampleOutput` or `tuple`:
|
1546 |
+
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
1547 |
+
"""
|
1548 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
1549 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
1550 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
1551 |
+
# on the fly if necessary.
|
1552 |
+
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
|
1553 |
+
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
|
1554 |
+
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
|
1555 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
1556 |
+
|
1557 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
1558 |
+
# 64で割り切れないときはupsamplerにサイズを伝える
|
1559 |
+
forward_upsample_size = False
|
1560 |
+
upsample_size = None
|
1561 |
+
|
1562 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
1563 |
+
# logger.info("Forward upsample size to force interpolation output size.")
|
1564 |
+
forward_upsample_size = True
|
1565 |
+
|
1566 |
+
# 1. time
|
1567 |
+
timesteps = timestep
|
1568 |
+
timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
|
1569 |
+
|
1570 |
+
t_emb = self.time_proj(timesteps)
|
1571 |
+
|
1572 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
1573 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
1574 |
+
# there might be better ways to encapsulate this.
|
1575 |
+
# timestepsは重みを含まないので常にfloat32のテンソルを返す
|
1576 |
+
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
|
1577 |
+
# time_projでキャストしておけばいいんじゃね?
|
1578 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
1579 |
+
emb = self.time_embedding(t_emb)
|
1580 |
+
|
1581 |
+
# 2. pre-process
|
1582 |
+
sample = self.conv_in(sample)
|
1583 |
+
|
1584 |
+
down_block_res_samples = (sample,)
|
1585 |
+
for downsample_block in self.down_blocks:
|
1586 |
+
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
1587 |
+
# まあこちらのほうがわかりやすいかもしれない
|
1588 |
+
if downsample_block.has_cross_attention:
|
1589 |
+
sample, res_samples = downsample_block(
|
1590 |
+
hidden_states=sample,
|
1591 |
+
temb=emb,
|
1592 |
+
encoder_hidden_states=encoder_hidden_states,
|
1593 |
+
)
|
1594 |
+
else:
|
1595 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
1596 |
+
|
1597 |
+
down_block_res_samples += res_samples
|
1598 |
+
|
1599 |
+
# skip connectionにControlNetの出力を追加する
|
1600 |
+
if down_block_additional_residuals is not None:
|
1601 |
+
down_block_res_samples = list(down_block_res_samples)
|
1602 |
+
for i in range(len(down_block_res_samples)):
|
1603 |
+
down_block_res_samples[i] += down_block_additional_residuals[i]
|
1604 |
+
down_block_res_samples = tuple(down_block_res_samples)
|
1605 |
+
|
1606 |
+
# 4. mid
|
1607 |
+
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
1608 |
+
|
1609 |
+
# ControlNetの出力を追加する
|
1610 |
+
if mid_block_additional_residual is not None:
|
1611 |
+
sample += mid_block_additional_residual
|
1612 |
+
|
1613 |
+
# 5. up
|
1614 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
1615 |
+
is_final_block = i == len(self.up_blocks) - 1
|
1616 |
+
|
1617 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1618 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
1619 |
+
|
1620 |
+
# if we have not reached the final block and need to forward the upsample size, we do it here
|
1621 |
+
# 前述のように最後のブロック以外ではupsample_sizeを伝える
|
1622 |
+
if not is_final_block and forward_upsample_size:
|
1623 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1624 |
+
|
1625 |
+
if upsample_block.has_cross_attention:
|
1626 |
+
sample = upsample_block(
|
1627 |
+
hidden_states=sample,
|
1628 |
+
temb=emb,
|
1629 |
+
res_hidden_states_tuple=res_samples,
|
1630 |
+
encoder_hidden_states=encoder_hidden_states,
|
1631 |
+
upsample_size=upsample_size,
|
1632 |
+
)
|
1633 |
+
else:
|
1634 |
+
sample = upsample_block(
|
1635 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
1636 |
+
)
|
1637 |
+
|
1638 |
+
# 6. post-process
|
1639 |
+
sample = self.conv_norm_out(sample)
|
1640 |
+
sample = self.conv_act(sample)
|
1641 |
+
sample = self.conv_out(sample)
|
1642 |
+
|
1643 |
+
if not return_dict:
|
1644 |
+
return (sample,)
|
1645 |
+
|
1646 |
+
return SampleOutput(sample=sample)
|
1647 |
+
|
1648 |
+
def handle_unusual_timesteps(self, sample, timesteps):
|
1649 |
+
r"""
|
1650 |
+
timestampsがTensorでない場合、Tensorに変換する。またOnnx/Core MLと互換性のあるようにbatchサイズまでbroadcastする。
|
1651 |
+
"""
|
1652 |
+
if not torch.is_tensor(timesteps):
|
1653 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
1654 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
1655 |
+
is_mps = sample.device.type == "mps"
|
1656 |
+
if isinstance(timesteps, float):
|
1657 |
+
dtype = torch.float32 if is_mps else torch.float64
|
1658 |
+
else:
|
1659 |
+
dtype = torch.int32 if is_mps else torch.int64
|
1660 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
1661 |
+
elif len(timesteps.shape) == 0:
|
1662 |
+
timesteps = timesteps[None].to(sample.device)
|
1663 |
+
|
1664 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
1665 |
+
timesteps = timesteps.expand(sample.shape[0])
|
1666 |
+
|
1667 |
+
return timesteps
|
1668 |
+
|
1669 |
+
|
1670 |
+
class InferUNet2DConditionModel:
|
1671 |
+
def __init__(self, original_unet: UNet2DConditionModel):
|
1672 |
+
self.delegate = original_unet
|
1673 |
+
|
1674 |
+
# override original model's forward method: because forward is not called by `__call__`
|
1675 |
+
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
1676 |
+
self.delegate.forward = self.forward
|
1677 |
+
|
1678 |
+
# override original model's up blocks' forward method
|
1679 |
+
for up_block in self.delegate.up_blocks:
|
1680 |
+
if up_block.__class__.__name__ == "UpBlock2D":
|
1681 |
+
|
1682 |
+
def resnet_wrapper(func, block):
|
1683 |
+
def forward(*args, **kwargs):
|
1684 |
+
return func(block, *args, **kwargs)
|
1685 |
+
|
1686 |
+
return forward
|
1687 |
+
|
1688 |
+
up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
|
1689 |
+
|
1690 |
+
elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
1691 |
+
|
1692 |
+
def cross_attn_up_wrapper(func, block):
|
1693 |
+
def forward(*args, **kwargs):
|
1694 |
+
return func(block, *args, **kwargs)
|
1695 |
+
|
1696 |
+
return forward
|
1697 |
+
|
1698 |
+
up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
|
1699 |
+
|
1700 |
+
# Deep Shrink
|
1701 |
+
self.ds_depth_1 = None
|
1702 |
+
self.ds_depth_2 = None
|
1703 |
+
self.ds_timesteps_1 = None
|
1704 |
+
self.ds_timesteps_2 = None
|
1705 |
+
self.ds_ratio = None
|
1706 |
+
|
1707 |
+
# call original model's methods
|
1708 |
+
def __getattr__(self, name):
|
1709 |
+
return getattr(self.delegate, name)
|
1710 |
+
|
1711 |
+
def __call__(self, *args, **kwargs):
|
1712 |
+
return self.delegate(*args, **kwargs)
|
1713 |
+
|
1714 |
+
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
1715 |
+
if ds_depth_1 is None:
|
1716 |
+
logger.info("Deep Shrink is disabled.")
|
1717 |
+
self.ds_depth_1 = None
|
1718 |
+
self.ds_timesteps_1 = None
|
1719 |
+
self.ds_depth_2 = None
|
1720 |
+
self.ds_timesteps_2 = None
|
1721 |
+
self.ds_ratio = None
|
1722 |
+
else:
|
1723 |
+
logger.info(
|
1724 |
+
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
1725 |
+
)
|
1726 |
+
self.ds_depth_1 = ds_depth_1
|
1727 |
+
self.ds_timesteps_1 = ds_timesteps_1
|
1728 |
+
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
1729 |
+
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
1730 |
+
self.ds_ratio = ds_ratio
|
1731 |
+
|
1732 |
+
def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
1733 |
+
for resnet in _self.resnets:
|
1734 |
+
# pop res hidden states
|
1735 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1736 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1737 |
+
|
1738 |
+
# Deep Shrink
|
1739 |
+
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
1740 |
+
hidden_states = resize_like(hidden_states, res_hidden_states)
|
1741 |
+
|
1742 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1743 |
+
hidden_states = resnet(hidden_states, temb)
|
1744 |
+
|
1745 |
+
if _self.upsamplers is not None:
|
1746 |
+
for upsampler in _self.upsamplers:
|
1747 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1748 |
+
|
1749 |
+
return hidden_states
|
1750 |
+
|
1751 |
+
def cross_attn_up_block_forward(
|
1752 |
+
self,
|
1753 |
+
_self,
|
1754 |
+
hidden_states,
|
1755 |
+
res_hidden_states_tuple,
|
1756 |
+
temb=None,
|
1757 |
+
encoder_hidden_states=None,
|
1758 |
+
upsample_size=None,
|
1759 |
+
):
|
1760 |
+
for resnet, attn in zip(_self.resnets, _self.attentions):
|
1761 |
+
# pop res hidden states
|
1762 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1763 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1764 |
+
|
1765 |
+
# Deep Shrink
|
1766 |
+
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
1767 |
+
hidden_states = resize_like(hidden_states, res_hidden_states)
|
1768 |
+
|
1769 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1770 |
+
hidden_states = resnet(hidden_states, temb)
|
1771 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
1772 |
+
|
1773 |
+
if _self.upsamplers is not None:
|
1774 |
+
for upsampler in _self.upsamplers:
|
1775 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1776 |
+
|
1777 |
+
return hidden_states
|
1778 |
+
|
1779 |
+
def forward(
|
1780 |
+
self,
|
1781 |
+
sample: torch.FloatTensor,
|
1782 |
+
timestep: Union[torch.Tensor, float, int],
|
1783 |
+
encoder_hidden_states: torch.Tensor,
|
1784 |
+
class_labels: Optional[torch.Tensor] = None,
|
1785 |
+
return_dict: bool = True,
|
1786 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1787 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
1788 |
+
) -> Union[Dict, Tuple]:
|
1789 |
+
r"""
|
1790 |
+
current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.
|
1791 |
+
"""
|
1792 |
+
|
1793 |
+
r"""
|
1794 |
+
Args:
|
1795 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
1796 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
1797 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
1798 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1799 |
+
Whether or not to return a dict instead of a plain tuple.
|
1800 |
+
|
1801 |
+
Returns:
|
1802 |
+
`SampleOutput` or `tuple`:
|
1803 |
+
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
1804 |
+
"""
|
1805 |
+
|
1806 |
+
_self = self.delegate
|
1807 |
+
|
1808 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
1809 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
1810 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
1811 |
+
# on the fly if necessary.
|
1812 |
+
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
|
1813 |
+
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
|
1814 |
+
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
|
1815 |
+
default_overall_up_factor = 2**_self.num_upsamplers
|
1816 |
+
|
1817 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
1818 |
+
# 64で割り切れないときはupsamplerにサイズを伝える
|
1819 |
+
forward_upsample_size = False
|
1820 |
+
upsample_size = None
|
1821 |
+
|
1822 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
1823 |
+
# logger.info("Forward upsample size to force interpolation output size.")
|
1824 |
+
forward_upsample_size = True
|
1825 |
+
|
1826 |
+
# 1. time
|
1827 |
+
timesteps = timestep
|
1828 |
+
timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
|
1829 |
+
|
1830 |
+
t_emb = _self.time_proj(timesteps)
|
1831 |
+
|
1832 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
1833 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
1834 |
+
# there might be better ways to encapsulate this.
|
1835 |
+
# timestepsは重みを含まないので常にfloat32のテンソルを返す
|
1836 |
+
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
|
1837 |
+
# time_projでキャストしておけばいいんじゃね?
|
1838 |
+
t_emb = t_emb.to(dtype=_self.dtype)
|
1839 |
+
emb = _self.time_embedding(t_emb)
|
1840 |
+
|
1841 |
+
# 2. pre-process
|
1842 |
+
sample = _self.conv_in(sample)
|
1843 |
+
|
1844 |
+
down_block_res_samples = (sample,)
|
1845 |
+
for depth, downsample_block in enumerate(_self.down_blocks):
|
1846 |
+
# Deep Shrink
|
1847 |
+
if self.ds_depth_1 is not None:
|
1848 |
+
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
1849 |
+
self.ds_depth_2 is not None
|
1850 |
+
and depth == self.ds_depth_2
|
1851 |
+
and timesteps[0] < self.ds_timesteps_1
|
1852 |
+
and timesteps[0] >= self.ds_timesteps_2
|
1853 |
+
):
|
1854 |
+
org_dtype = sample.dtype
|
1855 |
+
if org_dtype == torch.bfloat16:
|
1856 |
+
sample = sample.to(torch.float32)
|
1857 |
+
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
1858 |
+
|
1859 |
+
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
1860 |
+
# まあこちらのほうがわかりやすいかもしれない
|
1861 |
+
if downsample_block.has_cross_attention:
|
1862 |
+
sample, res_samples = downsample_block(
|
1863 |
+
hidden_states=sample,
|
1864 |
+
temb=emb,
|
1865 |
+
encoder_hidden_states=encoder_hidden_states,
|
1866 |
+
)
|
1867 |
+
else:
|
1868 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
1869 |
+
|
1870 |
+
down_block_res_samples += res_samples
|
1871 |
+
|
1872 |
+
# skip connectionにControlNetの出力を追加する
|
1873 |
+
if down_block_additional_residuals is not None:
|
1874 |
+
down_block_res_samples = list(down_block_res_samples)
|
1875 |
+
for i in range(len(down_block_res_samples)):
|
1876 |
+
down_block_res_samples[i] += down_block_additional_residuals[i]
|
1877 |
+
down_block_res_samples = tuple(down_block_res_samples)
|
1878 |
+
|
1879 |
+
# 4. mid
|
1880 |
+
sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
1881 |
+
|
1882 |
+
# ControlNetの出力を追加する
|
1883 |
+
if mid_block_additional_residual is not None:
|
1884 |
+
sample += mid_block_additional_residual
|
1885 |
+
|
1886 |
+
# 5. up
|
1887 |
+
for i, upsample_block in enumerate(_self.up_blocks):
|
1888 |
+
is_final_block = i == len(_self.up_blocks) - 1
|
1889 |
+
|
1890 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1891 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
1892 |
+
|
1893 |
+
# if we have not reached the final block and need to forward the upsample size, we do it here
|
1894 |
+
# 前述のように最後のブロック以外ではupsample_sizeを伝える
|
1895 |
+
if not is_final_block and forward_upsample_size:
|
1896 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1897 |
+
|
1898 |
+
if upsample_block.has_cross_attention:
|
1899 |
+
sample = upsample_block(
|
1900 |
+
hidden_states=sample,
|
1901 |
+
temb=emb,
|
1902 |
+
res_hidden_states_tuple=res_samples,
|
1903 |
+
encoder_hidden_states=encoder_hidden_states,
|
1904 |
+
upsample_size=upsample_size,
|
1905 |
+
)
|
1906 |
+
else:
|
1907 |
+
sample = upsample_block(
|
1908 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
1909 |
+
)
|
1910 |
+
|
1911 |
+
# 6. post-process
|
1912 |
+
sample = _self.conv_norm_out(sample)
|
1913 |
+
sample = _self.conv_act(sample)
|
1914 |
+
sample = _self.conv_out(sample)
|
1915 |
+
|
1916 |
+
if not return_dict:
|
1917 |
+
return (sample,)
|
1918 |
+
|
1919 |
+
return SampleOutput(sample=sample)
|
library/sai_model_spec.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# based on https://github.com/Stability-AI/ModelSpec
|
2 |
+
import datetime
|
3 |
+
import hashlib
|
4 |
+
from io import BytesIO
|
5 |
+
import os
|
6 |
+
from typing import List, Optional, Tuple, Union
|
7 |
+
import safetensors
|
8 |
+
from library.utils import setup_logging
|
9 |
+
|
10 |
+
setup_logging()
|
11 |
+
import logging
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
r"""
|
16 |
+
# Metadata Example
|
17 |
+
metadata = {
|
18 |
+
# === Must ===
|
19 |
+
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
20 |
+
"modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
|
21 |
+
"modelspec.implementation": "sgm",
|
22 |
+
"modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
|
23 |
+
# === Should ===
|
24 |
+
"modelspec.author": "Example Corp", # Your name or company name
|
25 |
+
"modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
|
26 |
+
"modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
|
27 |
+
# === Can ===
|
28 |
+
"modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
|
29 |
+
"modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
|
30 |
+
}
|
31 |
+
"""
|
32 |
+
|
33 |
+
BASE_METADATA = {
|
34 |
+
# === Must ===
|
35 |
+
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
36 |
+
"modelspec.architecture": None,
|
37 |
+
"modelspec.implementation": None,
|
38 |
+
"modelspec.title": None,
|
39 |
+
"modelspec.resolution": None,
|
40 |
+
# === Should ===
|
41 |
+
"modelspec.description": None,
|
42 |
+
"modelspec.author": None,
|
43 |
+
"modelspec.date": None,
|
44 |
+
# === Can ===
|
45 |
+
"modelspec.license": None,
|
46 |
+
"modelspec.tags": None,
|
47 |
+
"modelspec.merged_from": None,
|
48 |
+
"modelspec.prediction_type": None,
|
49 |
+
"modelspec.timestep_range": None,
|
50 |
+
"modelspec.encoder_layer": None,
|
51 |
+
}
|
52 |
+
|
53 |
+
# 別に使うやつだけ定義
|
54 |
+
MODELSPEC_TITLE = "modelspec.title"
|
55 |
+
|
56 |
+
ARCH_SD_V1 = "stable-diffusion-v1"
|
57 |
+
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
|
58 |
+
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
|
59 |
+
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
60 |
+
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
|
61 |
+
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
|
62 |
+
ARCH_FLUX_1_DEV = "flux-1-dev"
|
63 |
+
ARCH_FLUX_1_UNKNOWN = "flux-1"
|
64 |
+
|
65 |
+
ADAPTER_LORA = "lora"
|
66 |
+
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
67 |
+
|
68 |
+
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
69 |
+
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
|
70 |
+
IMPL_DIFFUSERS = "diffusers"
|
71 |
+
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
|
72 |
+
|
73 |
+
PRED_TYPE_EPSILON = "epsilon"
|
74 |
+
PRED_TYPE_V = "v"
|
75 |
+
|
76 |
+
|
77 |
+
def load_bytes_in_safetensors(tensors):
|
78 |
+
bytes = safetensors.torch.save(tensors)
|
79 |
+
b = BytesIO(bytes)
|
80 |
+
|
81 |
+
b.seek(0)
|
82 |
+
header = b.read(8)
|
83 |
+
n = int.from_bytes(header, "little")
|
84 |
+
|
85 |
+
offset = n + 8
|
86 |
+
b.seek(offset)
|
87 |
+
|
88 |
+
return b.read()
|
89 |
+
|
90 |
+
|
91 |
+
def precalculate_safetensors_hashes(state_dict):
|
92 |
+
# calculate each tensor one by one to reduce memory usage
|
93 |
+
hash_sha256 = hashlib.sha256()
|
94 |
+
for tensor in state_dict.values():
|
95 |
+
single_tensor_sd = {"tensor": tensor}
|
96 |
+
bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
|
97 |
+
hash_sha256.update(bytes_for_tensor)
|
98 |
+
|
99 |
+
return f"0x{hash_sha256.hexdigest()}"
|
100 |
+
|
101 |
+
|
102 |
+
def update_hash_sha256(metadata: dict, state_dict: dict):
|
103 |
+
raise NotImplementedError
|
104 |
+
|
105 |
+
|
106 |
+
def build_metadata(
|
107 |
+
state_dict: Optional[dict],
|
108 |
+
v2: bool,
|
109 |
+
v_parameterization: bool,
|
110 |
+
sdxl: bool,
|
111 |
+
lora: bool,
|
112 |
+
textual_inversion: bool,
|
113 |
+
timestamp: float,
|
114 |
+
title: Optional[str] = None,
|
115 |
+
reso: Optional[Union[int, Tuple[int, int]]] = None,
|
116 |
+
is_stable_diffusion_ckpt: Optional[bool] = None,
|
117 |
+
author: Optional[str] = None,
|
118 |
+
description: Optional[str] = None,
|
119 |
+
license: Optional[str] = None,
|
120 |
+
tags: Optional[str] = None,
|
121 |
+
merged_from: Optional[str] = None,
|
122 |
+
timesteps: Optional[Tuple[int, int]] = None,
|
123 |
+
clip_skip: Optional[int] = None,
|
124 |
+
sd3: Optional[str] = None,
|
125 |
+
flux: Optional[str] = None,
|
126 |
+
):
|
127 |
+
"""
|
128 |
+
sd3: only supports "m", flux: only supports "dev"
|
129 |
+
"""
|
130 |
+
# if state_dict is None, hash is not calculated
|
131 |
+
|
132 |
+
metadata = {}
|
133 |
+
metadata.update(BASE_METADATA)
|
134 |
+
|
135 |
+
# TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する
|
136 |
+
# if state_dict is not None:
|
137 |
+
# hash = precalculate_safetensors_hashes(state_dict)
|
138 |
+
# metadata["modelspec.hash_sha256"] = hash
|
139 |
+
|
140 |
+
if sdxl:
|
141 |
+
arch = ARCH_SD_XL_V1_BASE
|
142 |
+
elif sd3 is not None:
|
143 |
+
arch = ARCH_SD3_M + "-" + sd3
|
144 |
+
elif flux is not None:
|
145 |
+
if flux == "dev":
|
146 |
+
arch = ARCH_FLUX_1_DEV
|
147 |
+
else:
|
148 |
+
arch = ARCH_FLUX_1_UNKNOWN
|
149 |
+
elif v2:
|
150 |
+
if v_parameterization:
|
151 |
+
arch = ARCH_SD_V2_768_V
|
152 |
+
else:
|
153 |
+
arch = ARCH_SD_V2_512
|
154 |
+
else:
|
155 |
+
arch = ARCH_SD_V1
|
156 |
+
|
157 |
+
if lora:
|
158 |
+
arch += f"/{ADAPTER_LORA}"
|
159 |
+
elif textual_inversion:
|
160 |
+
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
|
161 |
+
|
162 |
+
metadata["modelspec.architecture"] = arch
|
163 |
+
|
164 |
+
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
165 |
+
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
166 |
+
|
167 |
+
if flux is not None:
|
168 |
+
# Flux
|
169 |
+
impl = IMPL_FLUX
|
170 |
+
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
171 |
+
# Stable Diffusion ckpt, TI, SDXL LoRA
|
172 |
+
impl = IMPL_STABILITY_AI
|
173 |
+
else:
|
174 |
+
# v1/v2 LoRA or Diffusers
|
175 |
+
impl = IMPL_DIFFUSERS
|
176 |
+
metadata["modelspec.implementation"] = impl
|
177 |
+
|
178 |
+
if title is None:
|
179 |
+
if lora:
|
180 |
+
title = "LoRA"
|
181 |
+
elif textual_inversion:
|
182 |
+
title = "TextualInversion"
|
183 |
+
else:
|
184 |
+
title = "Checkpoint"
|
185 |
+
title += f"@{timestamp}"
|
186 |
+
metadata[MODELSPEC_TITLE] = title
|
187 |
+
|
188 |
+
if author is not None:
|
189 |
+
metadata["modelspec.author"] = author
|
190 |
+
else:
|
191 |
+
del metadata["modelspec.author"]
|
192 |
+
|
193 |
+
if description is not None:
|
194 |
+
metadata["modelspec.description"] = description
|
195 |
+
else:
|
196 |
+
del metadata["modelspec.description"]
|
197 |
+
|
198 |
+
if merged_from is not None:
|
199 |
+
metadata["modelspec.merged_from"] = merged_from
|
200 |
+
else:
|
201 |
+
del metadata["modelspec.merged_from"]
|
202 |
+
|
203 |
+
if license is not None:
|
204 |
+
metadata["modelspec.license"] = license
|
205 |
+
else:
|
206 |
+
del metadata["modelspec.license"]
|
207 |
+
|
208 |
+
if tags is not None:
|
209 |
+
metadata["modelspec.tags"] = tags
|
210 |
+
else:
|
211 |
+
del metadata["modelspec.tags"]
|
212 |
+
|
213 |
+
# remove microsecond from time
|
214 |
+
int_ts = int(timestamp)
|
215 |
+
|
216 |
+
# time to iso-8601 compliant date
|
217 |
+
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
|
218 |
+
metadata["modelspec.date"] = date
|
219 |
+
|
220 |
+
if reso is not None:
|
221 |
+
# comma separated to tuple
|
222 |
+
if isinstance(reso, str):
|
223 |
+
reso = tuple(map(int, reso.split(",")))
|
224 |
+
if len(reso) == 1:
|
225 |
+
reso = (reso[0], reso[0])
|
226 |
+
else:
|
227 |
+
# resolution is defined in dataset, so use default
|
228 |
+
if sdxl or sd3 is not None or flux is not None:
|
229 |
+
reso = 1024
|
230 |
+
elif v2 and v_parameterization:
|
231 |
+
reso = 768
|
232 |
+
else:
|
233 |
+
reso = 512
|
234 |
+
if isinstance(reso, int):
|
235 |
+
reso = (reso, reso)
|
236 |
+
|
237 |
+
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
|
238 |
+
|
239 |
+
if flux is not None:
|
240 |
+
del metadata["modelspec.prediction_type"]
|
241 |
+
elif v_parameterization:
|
242 |
+
metadata["modelspec.prediction_type"] = PRED_TYPE_V
|
243 |
+
else:
|
244 |
+
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
|
245 |
+
|
246 |
+
if timesteps is not None:
|
247 |
+
if isinstance(timesteps, str) or isinstance(timesteps, int):
|
248 |
+
timesteps = (timesteps, timesteps)
|
249 |
+
if len(timesteps) == 1:
|
250 |
+
timesteps = (timesteps[0], timesteps[0])
|
251 |
+
metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
|
252 |
+
else:
|
253 |
+
del metadata["modelspec.timestep_range"]
|
254 |
+
|
255 |
+
if clip_skip is not None:
|
256 |
+
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
|
257 |
+
else:
|
258 |
+
del metadata["modelspec.encoder_layer"]
|
259 |
+
|
260 |
+
# # assert all values are filled
|
261 |
+
# assert all([v is not None for v in metadata.values()]), metadata
|
262 |
+
if not all([v is not None for v in metadata.values()]):
|
263 |
+
logger.error(f"Internal error: some metadata values are None: {metadata}")
|
264 |
+
|
265 |
+
return metadata
|
266 |
+
|
267 |
+
|
268 |
+
# region utils
|
269 |
+
|
270 |
+
|
271 |
+
def get_title(metadata: dict) -> Optional[str]:
|
272 |
+
return metadata.get(MODELSPEC_TITLE, None)
|
273 |
+
|
274 |
+
|
275 |
+
def load_metadata_from_safetensors(model: str) -> dict:
|
276 |
+
if not model.endswith(".safetensors"):
|
277 |
+
return {}
|
278 |
+
|
279 |
+
with safetensors.safe_open(model, framework="pt") as f:
|
280 |
+
metadata = f.metadata()
|
281 |
+
if metadata is None:
|
282 |
+
metadata = {}
|
283 |
+
return metadata
|
284 |
+
|
285 |
+
|
286 |
+
def build_merged_from(models: List[str]) -> str:
|
287 |
+
def get_title(model: str):
|
288 |
+
metadata = load_metadata_from_safetensors(model)
|
289 |
+
title = metadata.get(MODELSPEC_TITLE, None)
|
290 |
+
if title is None:
|
291 |
+
title = os.path.splitext(os.path.basename(model))[0] # use filename
|
292 |
+
return title
|
293 |
+
|
294 |
+
titles = [get_title(model) for model in models]
|
295 |
+
return ", ".join(titles)
|
296 |
+
|
297 |
+
|
298 |
+
# endregion
|
299 |
+
|
300 |
+
|
301 |
+
r"""
|
302 |
+
if __name__ == "__main__":
|
303 |
+
import argparse
|
304 |
+
import torch
|
305 |
+
from safetensors.torch import load_file
|
306 |
+
from library import train_util
|
307 |
+
|
308 |
+
parser = argparse.ArgumentParser()
|
309 |
+
parser.add_argument("--ckpt", type=str, required=True)
|
310 |
+
args = parser.parse_args()
|
311 |
+
|
312 |
+
print(f"Loading {args.ckpt}")
|
313 |
+
state_dict = load_file(args.ckpt)
|
314 |
+
|
315 |
+
print(f"Calculating metadata")
|
316 |
+
metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
|
317 |
+
print(metadata)
|
318 |
+
del state_dict
|
319 |
+
|
320 |
+
# by reference implementation
|
321 |
+
with open(args.ckpt, mode="rb") as file_data:
|
322 |
+
file_hash = hashlib.sha256()
|
323 |
+
head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
|
324 |
+
header = json.loads(file_data.read(head_len[0])) # header itself, json string
|
325 |
+
content = (
|
326 |
+
file_data.read()
|
327 |
+
) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
|
328 |
+
file_hash.update(content)
|
329 |
+
# ===== Update the hash for modelspec =====
|
330 |
+
by_ref = f"0x{file_hash.hexdigest()}"
|
331 |
+
print(by_ref)
|
332 |
+
print("is same?", by_ref == metadata["modelspec.hash_sha256"])
|
333 |
+
|
334 |
+
"""
|
library/sd3_models.py
ADDED
@@ -0,0 +1,1413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref
|
2 |
+
# the original code is licensed under the MIT License
|
3 |
+
|
4 |
+
# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution!
|
5 |
+
|
6 |
+
from ast import Tuple
|
7 |
+
from concurrent.futures import ThreadPoolExecutor
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from functools import partial
|
10 |
+
import math
|
11 |
+
from types import SimpleNamespace
|
12 |
+
from typing import Dict, List, Optional, Union
|
13 |
+
import einops
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch.utils.checkpoint import checkpoint
|
19 |
+
from transformers import CLIPTokenizer, T5TokenizerFast
|
20 |
+
|
21 |
+
from library import custom_offloading_utils
|
22 |
+
from library.device_utils import clean_memory_on_device
|
23 |
+
|
24 |
+
from .utils import setup_logging
|
25 |
+
|
26 |
+
setup_logging()
|
27 |
+
import logging
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
memory_efficient_attention = None
|
33 |
+
try:
|
34 |
+
import xformers
|
35 |
+
except:
|
36 |
+
pass
|
37 |
+
|
38 |
+
try:
|
39 |
+
from xformers.ops import memory_efficient_attention
|
40 |
+
except:
|
41 |
+
memory_efficient_attention = None
|
42 |
+
|
43 |
+
|
44 |
+
# region mmdit
|
45 |
+
|
46 |
+
|
47 |
+
@dataclass
|
48 |
+
class SD3Params:
|
49 |
+
patch_size: int
|
50 |
+
depth: int
|
51 |
+
num_patches: int
|
52 |
+
pos_embed_max_size: int
|
53 |
+
adm_in_channels: int
|
54 |
+
qk_norm: Optional[str]
|
55 |
+
x_block_self_attn_layers: list[int]
|
56 |
+
context_embedder_in_features: int
|
57 |
+
context_embedder_out_features: int
|
58 |
+
model_type: str
|
59 |
+
|
60 |
+
|
61 |
+
def get_2d_sincos_pos_embed(
|
62 |
+
embed_dim,
|
63 |
+
grid_size,
|
64 |
+
scaling_factor=None,
|
65 |
+
offset=None,
|
66 |
+
):
|
67 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
68 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
69 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
70 |
+
grid = np.stack(grid, axis=0)
|
71 |
+
if scaling_factor is not None:
|
72 |
+
grid = grid / scaling_factor
|
73 |
+
if offset is not None:
|
74 |
+
grid = grid - offset
|
75 |
+
|
76 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
77 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
78 |
+
return pos_embed
|
79 |
+
|
80 |
+
|
81 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
82 |
+
assert embed_dim % 2 == 0
|
83 |
+
|
84 |
+
# use half of dimensions to encode grid_h
|
85 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
86 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
87 |
+
|
88 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
89 |
+
return emb
|
90 |
+
|
91 |
+
|
92 |
+
def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16):
|
93 |
+
"""
|
94 |
+
This function is contributed by KohakuBlueleaf. Thanks for the contribution!
|
95 |
+
|
96 |
+
Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions
|
97 |
+
when the resolution differs from the training resolution.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
embed_dim (int): Dimension of the positional embedding.
|
101 |
+
grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid.
|
102 |
+
cls_token (bool): Whether to include class token. Defaults to False.
|
103 |
+
extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0.
|
104 |
+
sample_size (int): Reference resolution (typically training resolution). Defaults to 64.
|
105 |
+
base_size (int): Base grid size used during training. Defaults to 16.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or
|
109 |
+
(H*W + extra_tokens, embed_dim) if cls_token is True.
|
110 |
+
"""
|
111 |
+
# Convert grid_size to tuple if it's an integer
|
112 |
+
if isinstance(grid_size, int):
|
113 |
+
grid_size = (grid_size, grid_size)
|
114 |
+
|
115 |
+
# Create normalized grid coordinates (0 to 1)
|
116 |
+
grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0]
|
117 |
+
grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1]
|
118 |
+
|
119 |
+
# Calculate scaling factors for height and width
|
120 |
+
# This ensures that the central region matches the original resolution's embeddings
|
121 |
+
scale_h = base_size * grid_size[0] / (sample_size)
|
122 |
+
scale_w = base_size * grid_size[1] / (sample_size)
|
123 |
+
|
124 |
+
# Calculate shift values to center the original resolution's embedding region
|
125 |
+
# This ensures that the central sample_size x sample_size region has similar
|
126 |
+
# positional embeddings to the original resolution
|
127 |
+
shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0])
|
128 |
+
shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1])
|
129 |
+
|
130 |
+
# Apply scaling and shifting to create the final grid coordinates
|
131 |
+
grid_h = grid_h * scale_h - shift_h
|
132 |
+
grid_w = grid_w * scale_w - shift_w
|
133 |
+
|
134 |
+
# Create 2D grid using meshgrid (note: w goes first)
|
135 |
+
grid = np.meshgrid(grid_w, grid_h)
|
136 |
+
grid = np.stack(grid, axis=0)
|
137 |
+
|
138 |
+
# # Calculate the starting indices for the central region
|
139 |
+
# # This is used for debugging/visualization of the central region
|
140 |
+
# st_h = (grid_size[0] - sample_size) // 2
|
141 |
+
# st_w = (grid_size[1] - sample_size) // 2
|
142 |
+
# print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size])
|
143 |
+
|
144 |
+
# Reshape grid for positional embedding calculation
|
145 |
+
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
146 |
+
|
147 |
+
# Generate the sinusoidal positional embeddings
|
148 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
149 |
+
|
150 |
+
# Add zeros for extra tokens (e.g., [CLS] token) if required
|
151 |
+
if cls_token and extra_tokens > 0:
|
152 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
153 |
+
|
154 |
+
return pos_embed
|
155 |
+
|
156 |
+
|
157 |
+
# if __name__ == "__main__":
|
158 |
+
# # This is what you get when you load SD3.5 state dict
|
159 |
+
# pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed(
|
160 |
+
# 1536, [384, 384], sample_size=64, base_size=16
|
161 |
+
# )).float().unsqueeze(0)
|
162 |
+
|
163 |
+
|
164 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
165 |
+
"""
|
166 |
+
embed_dim: output dimension for each position
|
167 |
+
pos: a list of positions to be encoded: size (M,)
|
168 |
+
out: (M, D)
|
169 |
+
"""
|
170 |
+
assert embed_dim % 2 == 0
|
171 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
172 |
+
omega /= embed_dim / 2.0
|
173 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
174 |
+
|
175 |
+
pos = pos.reshape(-1) # (M,)
|
176 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
177 |
+
|
178 |
+
emb_sin = np.sin(out) # (M, D/2)
|
179 |
+
emb_cos = np.cos(out) # (M, D/2)
|
180 |
+
|
181 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
182 |
+
return emb
|
183 |
+
|
184 |
+
|
185 |
+
def get_1d_sincos_pos_embed_from_grid_torch(
|
186 |
+
embed_dim,
|
187 |
+
pos,
|
188 |
+
device=None,
|
189 |
+
dtype=torch.float32,
|
190 |
+
):
|
191 |
+
omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
|
192 |
+
omega *= 2.0 / embed_dim
|
193 |
+
omega = 1.0 / 10000**omega
|
194 |
+
out = torch.outer(pos.reshape(-1), omega)
|
195 |
+
emb = torch.cat([out.sin(), out.cos()], dim=1)
|
196 |
+
return emb
|
197 |
+
|
198 |
+
|
199 |
+
def get_2d_sincos_pos_embed_torch(
|
200 |
+
embed_dim,
|
201 |
+
w,
|
202 |
+
h,
|
203 |
+
val_center=7.5,
|
204 |
+
val_magnitude=7.5,
|
205 |
+
device=None,
|
206 |
+
dtype=torch.float32,
|
207 |
+
):
|
208 |
+
small = min(h, w)
|
209 |
+
val_h = (h / small) * val_magnitude
|
210 |
+
val_w = (w / small) * val_magnitude
|
211 |
+
grid_h, grid_w = torch.meshgrid(
|
212 |
+
torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype),
|
213 |
+
torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype),
|
214 |
+
indexing="ij",
|
215 |
+
)
|
216 |
+
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
217 |
+
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
218 |
+
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
219 |
+
return emb
|
220 |
+
|
221 |
+
|
222 |
+
def modulate(x, shift, scale):
|
223 |
+
if shift is None:
|
224 |
+
shift = torch.zeros_like(scale)
|
225 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
226 |
+
|
227 |
+
|
228 |
+
def default(x, default_value):
|
229 |
+
if x is None:
|
230 |
+
return default_value
|
231 |
+
return x
|
232 |
+
|
233 |
+
|
234 |
+
def timestep_embedding(t, dim, max_period=10000):
|
235 |
+
half = dim // 2
|
236 |
+
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
237 |
+
# device=t.device, dtype=t.dtype
|
238 |
+
# )
|
239 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
240 |
+
args = t[:, None].float() * freqs[None]
|
241 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
242 |
+
if dim % 2:
|
243 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
244 |
+
if torch.is_floating_point(t):
|
245 |
+
embedding = embedding.to(dtype=t.dtype)
|
246 |
+
return embedding
|
247 |
+
|
248 |
+
|
249 |
+
class PatchEmbed(nn.Module):
|
250 |
+
def __init__(
|
251 |
+
self,
|
252 |
+
img_size=256,
|
253 |
+
patch_size=4,
|
254 |
+
in_channels=3,
|
255 |
+
embed_dim=512,
|
256 |
+
norm_layer=None,
|
257 |
+
flatten=True,
|
258 |
+
bias=True,
|
259 |
+
strict_img_size=True,
|
260 |
+
dynamic_img_pad=False,
|
261 |
+
):
|
262 |
+
# dynamic_img_pad and norm is omitted in SD3.5
|
263 |
+
super().__init__()
|
264 |
+
self.patch_size = patch_size
|
265 |
+
self.flatten = flatten
|
266 |
+
self.strict_img_size = strict_img_size
|
267 |
+
self.dynamic_img_pad = dynamic_img_pad
|
268 |
+
if img_size is not None:
|
269 |
+
self.img_size = img_size
|
270 |
+
self.grid_size = img_size // patch_size
|
271 |
+
self.num_patches = self.grid_size**2
|
272 |
+
else:
|
273 |
+
self.img_size = None
|
274 |
+
self.grid_size = None
|
275 |
+
self.num_patches = None
|
276 |
+
|
277 |
+
self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias)
|
278 |
+
self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim)
|
279 |
+
|
280 |
+
def forward(self, x):
|
281 |
+
B, C, H, W = x.shape
|
282 |
+
|
283 |
+
if self.dynamic_img_pad:
|
284 |
+
# Pad input so we won't have partial patch
|
285 |
+
pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
|
286 |
+
pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
|
287 |
+
x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
|
288 |
+
x = self.proj(x)
|
289 |
+
if self.flatten:
|
290 |
+
x = x.flatten(2).transpose(1, 2)
|
291 |
+
x = self.norm(x)
|
292 |
+
return x
|
293 |
+
|
294 |
+
|
295 |
+
# FinalLayer in mmdit.py
|
296 |
+
class UnPatch(nn.Module):
|
297 |
+
def __init__(self, hidden_size=512, patch_size=4, out_channels=3):
|
298 |
+
super().__init__()
|
299 |
+
self.patch_size = patch_size
|
300 |
+
self.c = out_channels
|
301 |
+
|
302 |
+
# eps is default in mmdit.py
|
303 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
304 |
+
self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels)
|
305 |
+
self.adaLN_modulation = nn.Sequential(
|
306 |
+
nn.SiLU(),
|
307 |
+
nn.Linear(hidden_size, 2 * hidden_size),
|
308 |
+
)
|
309 |
+
|
310 |
+
def forward(self, x: torch.Tensor, cmod, H=None, W=None):
|
311 |
+
b, n, _ = x.shape
|
312 |
+
p = self.patch_size
|
313 |
+
c = self.c
|
314 |
+
if H is None and W is None:
|
315 |
+
w = h = int(n**0.5)
|
316 |
+
assert h * w == n
|
317 |
+
else:
|
318 |
+
h = H // p if H else n // (W // p)
|
319 |
+
w = W // p if W else n // h
|
320 |
+
assert h * w == n
|
321 |
+
|
322 |
+
shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1)
|
323 |
+
x = modulate(self.norm_final(x), shift, scale)
|
324 |
+
x = self.linear(x)
|
325 |
+
|
326 |
+
x = x.view(b, h, w, p, p, c)
|
327 |
+
x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
|
328 |
+
x = x.view(b, c, h * p, w * p)
|
329 |
+
return x
|
330 |
+
|
331 |
+
|
332 |
+
class MLP(nn.Module):
|
333 |
+
def __init__(
|
334 |
+
self,
|
335 |
+
in_features,
|
336 |
+
hidden_features=None,
|
337 |
+
out_features=None,
|
338 |
+
act_layer=lambda: nn.GELU(),
|
339 |
+
norm_layer=None,
|
340 |
+
bias=True,
|
341 |
+
use_conv=False,
|
342 |
+
):
|
343 |
+
super().__init__()
|
344 |
+
out_features = out_features or in_features
|
345 |
+
hidden_features = hidden_features or in_features
|
346 |
+
self.use_conv = use_conv
|
347 |
+
|
348 |
+
layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear
|
349 |
+
|
350 |
+
self.fc1 = layer(in_features, hidden_features, bias=bias)
|
351 |
+
self.fc2 = layer(hidden_features, out_features, bias=bias)
|
352 |
+
self.act = act_layer()
|
353 |
+
self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
|
354 |
+
|
355 |
+
def forward(self, x):
|
356 |
+
x = self.fc1(x)
|
357 |
+
x = self.act(x)
|
358 |
+
x = self.norm(x)
|
359 |
+
x = self.fc2(x)
|
360 |
+
return x
|
361 |
+
|
362 |
+
|
363 |
+
class TimestepEmbedding(nn.Module):
|
364 |
+
def __init__(self, hidden_size, freq_embed_size=256):
|
365 |
+
super().__init__()
|
366 |
+
self.mlp = nn.Sequential(
|
367 |
+
nn.Linear(freq_embed_size, hidden_size),
|
368 |
+
nn.SiLU(),
|
369 |
+
nn.Linear(hidden_size, hidden_size),
|
370 |
+
)
|
371 |
+
self.freq_embed_size = freq_embed_size
|
372 |
+
|
373 |
+
def forward(self, t, dtype=None, **kwargs):
|
374 |
+
t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype)
|
375 |
+
t_emb = self.mlp(t_freq)
|
376 |
+
return t_emb
|
377 |
+
|
378 |
+
|
379 |
+
class Embedder(nn.Module):
|
380 |
+
def __init__(self, input_dim, hidden_size):
|
381 |
+
super().__init__()
|
382 |
+
self.mlp = nn.Sequential(
|
383 |
+
nn.Linear(input_dim, hidden_size),
|
384 |
+
nn.SiLU(),
|
385 |
+
nn.Linear(hidden_size, hidden_size),
|
386 |
+
)
|
387 |
+
|
388 |
+
def forward(self, x):
|
389 |
+
return self.mlp(x)
|
390 |
+
|
391 |
+
|
392 |
+
def rmsnorm(x, eps=1e-6):
|
393 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
394 |
+
|
395 |
+
|
396 |
+
class RMSNorm(torch.nn.Module):
|
397 |
+
def __init__(
|
398 |
+
self,
|
399 |
+
dim: int,
|
400 |
+
elementwise_affine: bool = False,
|
401 |
+
eps: float = 1e-6,
|
402 |
+
device=None,
|
403 |
+
dtype=None,
|
404 |
+
):
|
405 |
+
"""
|
406 |
+
Initialize the RMSNorm normalization layer.
|
407 |
+
Args:
|
408 |
+
dim (int): The dimension of the input tensor.
|
409 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
410 |
+
Attributes:
|
411 |
+
eps (float): A small value added to the denominator for numerical stability.
|
412 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
413 |
+
"""
|
414 |
+
super().__init__()
|
415 |
+
self.eps = eps
|
416 |
+
self.learnable_scale = elementwise_affine
|
417 |
+
if self.learnable_scale:
|
418 |
+
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
419 |
+
else:
|
420 |
+
self.register_parameter("weight", None)
|
421 |
+
|
422 |
+
def forward(self, x):
|
423 |
+
"""
|
424 |
+
Forward pass through the RMSNorm layer.
|
425 |
+
Args:
|
426 |
+
x (torch.Tensor): The input tensor.
|
427 |
+
Returns:
|
428 |
+
torch.Tensor: The output tensor after applying RMSNorm.
|
429 |
+
"""
|
430 |
+
x = rmsnorm(x, eps=self.eps)
|
431 |
+
if self.learnable_scale:
|
432 |
+
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
433 |
+
else:
|
434 |
+
return x
|
435 |
+
|
436 |
+
|
437 |
+
class SwiGLUFeedForward(nn.Module):
|
438 |
+
def __init__(
|
439 |
+
self,
|
440 |
+
dim: int,
|
441 |
+
hidden_dim: int,
|
442 |
+
multiple_of: int,
|
443 |
+
ffn_dim_multiplier: float = None,
|
444 |
+
):
|
445 |
+
super().__init__()
|
446 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
447 |
+
# custom dim factor multiplier
|
448 |
+
if ffn_dim_multiplier is not None:
|
449 |
+
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
450 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
451 |
+
|
452 |
+
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
453 |
+
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
454 |
+
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
455 |
+
|
456 |
+
def forward(self, x):
|
457 |
+
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
458 |
+
|
459 |
+
|
460 |
+
# Linears for SelfAttention in mmdit.py
|
461 |
+
class AttentionLinears(nn.Module):
|
462 |
+
def __init__(
|
463 |
+
self,
|
464 |
+
dim: int,
|
465 |
+
num_heads: int = 8,
|
466 |
+
qkv_bias: bool = False,
|
467 |
+
pre_only: bool = False,
|
468 |
+
qk_norm: Optional[str] = None,
|
469 |
+
):
|
470 |
+
super().__init__()
|
471 |
+
self.num_heads = num_heads
|
472 |
+
self.head_dim = dim // num_heads
|
473 |
+
|
474 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
475 |
+
if not pre_only:
|
476 |
+
self.proj = nn.Linear(dim, dim)
|
477 |
+
self.pre_only = pre_only
|
478 |
+
|
479 |
+
if qk_norm == "rms":
|
480 |
+
self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
|
481 |
+
self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
|
482 |
+
elif qk_norm == "ln":
|
483 |
+
self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
|
484 |
+
self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
|
485 |
+
elif qk_norm is None:
|
486 |
+
self.ln_q = nn.Identity()
|
487 |
+
self.ln_k = nn.Identity()
|
488 |
+
else:
|
489 |
+
raise ValueError(qk_norm)
|
490 |
+
|
491 |
+
def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
|
492 |
+
"""
|
493 |
+
output:
|
494 |
+
q, k, v: [B, L, D]
|
495 |
+
"""
|
496 |
+
B, L, C = x.shape
|
497 |
+
qkv: torch.Tensor = self.qkv(x)
|
498 |
+
q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2)
|
499 |
+
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
|
500 |
+
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
|
501 |
+
return (q, k, v)
|
502 |
+
|
503 |
+
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
|
504 |
+
assert not self.pre_only
|
505 |
+
x = self.proj(x)
|
506 |
+
return x
|
507 |
+
|
508 |
+
|
509 |
+
MEMORY_LAYOUTS = {
|
510 |
+
"torch": (
|
511 |
+
lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2),
|
512 |
+
lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1),
|
513 |
+
lambda x: (1, x, 1, 1),
|
514 |
+
),
|
515 |
+
"xformers": (
|
516 |
+
lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim),
|
517 |
+
lambda x: x.reshape(x.shape[0], x.shape[1], -1),
|
518 |
+
lambda x: (1, 1, x, 1),
|
519 |
+
),
|
520 |
+
"math": (
|
521 |
+
lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2),
|
522 |
+
lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1),
|
523 |
+
lambda x: (1, x, 1, 1),
|
524 |
+
),
|
525 |
+
}
|
526 |
+
# ATTN_FUNCTION = {
|
527 |
+
# "torch": F.scaled_dot_product_attention,
|
528 |
+
# "xformers": memory_efficient_attention,
|
529 |
+
# }
|
530 |
+
|
531 |
+
|
532 |
+
def vanilla_attention(q, k, v, mask, scale=None):
|
533 |
+
if scale is None:
|
534 |
+
scale = math.sqrt(q.size(-1))
|
535 |
+
scores = torch.bmm(q, k.transpose(-1, -2)) / scale
|
536 |
+
if mask is not None:
|
537 |
+
mask = einops.rearrange(mask, "b ... -> b (...)")
|
538 |
+
max_neg_value = -torch.finfo(scores.dtype).max
|
539 |
+
mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3))
|
540 |
+
scores = scores.masked_fill(~mask, max_neg_value)
|
541 |
+
p_attn = F.softmax(scores, dim=-1)
|
542 |
+
return torch.bmm(p_attn, v)
|
543 |
+
|
544 |
+
|
545 |
+
def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"):
|
546 |
+
"""
|
547 |
+
q, k, v: [B, L, D]
|
548 |
+
"""
|
549 |
+
pre_attn_layout = MEMORY_LAYOUTS[mode][0]
|
550 |
+
post_attn_layout = MEMORY_LAYOUTS[mode][1]
|
551 |
+
q = pre_attn_layout(q, head_dim)
|
552 |
+
k = pre_attn_layout(k, head_dim)
|
553 |
+
v = pre_attn_layout(v, head_dim)
|
554 |
+
|
555 |
+
# scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale)
|
556 |
+
if mode == "torch":
|
557 |
+
assert scale is None
|
558 |
+
scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask) # , scale=scale)
|
559 |
+
elif mode == "xformers":
|
560 |
+
scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale)
|
561 |
+
else:
|
562 |
+
scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale)
|
563 |
+
|
564 |
+
scores = post_attn_layout(scores)
|
565 |
+
return scores
|
566 |
+
|
567 |
+
|
568 |
+
# DismantledBlock in mmdit.py
|
569 |
+
class SingleDiTBlock(nn.Module):
|
570 |
+
"""
|
571 |
+
A DiT block with gated adaptive layer norm (adaLN) conditioning.
|
572 |
+
"""
|
573 |
+
|
574 |
+
def __init__(
|
575 |
+
self,
|
576 |
+
hidden_size: int,
|
577 |
+
num_heads: int,
|
578 |
+
mlp_ratio: float = 4.0,
|
579 |
+
attn_mode: str = "xformers",
|
580 |
+
qkv_bias: bool = False,
|
581 |
+
pre_only: bool = False,
|
582 |
+
rmsnorm: bool = False,
|
583 |
+
scale_mod_only: bool = False,
|
584 |
+
swiglu: bool = False,
|
585 |
+
qk_norm: Optional[str] = None,
|
586 |
+
x_block_self_attn: bool = False,
|
587 |
+
**block_kwargs,
|
588 |
+
):
|
589 |
+
super().__init__()
|
590 |
+
assert attn_mode in MEMORY_LAYOUTS
|
591 |
+
self.attn_mode = attn_mode
|
592 |
+
if not rmsnorm:
|
593 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
594 |
+
else:
|
595 |
+
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
596 |
+
self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm)
|
597 |
+
|
598 |
+
self.x_block_self_attn = x_block_self_attn
|
599 |
+
if self.x_block_self_attn:
|
600 |
+
assert not pre_only
|
601 |
+
assert not scale_mod_only
|
602 |
+
self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm)
|
603 |
+
|
604 |
+
if not pre_only:
|
605 |
+
if not rmsnorm:
|
606 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
607 |
+
else:
|
608 |
+
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
609 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
610 |
+
if not pre_only:
|
611 |
+
if not swiglu:
|
612 |
+
self.mlp = MLP(
|
613 |
+
in_features=hidden_size,
|
614 |
+
hidden_features=mlp_hidden_dim,
|
615 |
+
act_layer=lambda: nn.GELU(approximate="tanh"),
|
616 |
+
)
|
617 |
+
else:
|
618 |
+
self.mlp = SwiGLUFeedForward(
|
619 |
+
dim=hidden_size,
|
620 |
+
hidden_dim=mlp_hidden_dim,
|
621 |
+
multiple_of=256,
|
622 |
+
)
|
623 |
+
self.scale_mod_only = scale_mod_only
|
624 |
+
if self.x_block_self_attn:
|
625 |
+
n_mods = 9
|
626 |
+
elif not scale_mod_only:
|
627 |
+
n_mods = 6 if not pre_only else 2
|
628 |
+
else:
|
629 |
+
n_mods = 4 if not pre_only else 1
|
630 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size))
|
631 |
+
self.pre_only = pre_only
|
632 |
+
|
633 |
+
def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
634 |
+
if not self.pre_only:
|
635 |
+
if not self.scale_mod_only:
|
636 |
+
(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1)
|
637 |
+
else:
|
638 |
+
shift_msa = None
|
639 |
+
shift_mlp = None
|
640 |
+
(scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1)
|
641 |
+
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
642 |
+
return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
|
643 |
+
else:
|
644 |
+
if not self.scale_mod_only:
|
645 |
+
(shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1)
|
646 |
+
else:
|
647 |
+
shift_msa = None
|
648 |
+
scale_msa = self.adaLN_modulation(c)
|
649 |
+
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
650 |
+
return qkv, None
|
651 |
+
|
652 |
+
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
653 |
+
assert self.x_block_self_attn
|
654 |
+
(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation(
|
655 |
+
c
|
656 |
+
).chunk(9, dim=1)
|
657 |
+
x_norm = self.norm1(x)
|
658 |
+
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
659 |
+
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
660 |
+
return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2)
|
661 |
+
|
662 |
+
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
|
663 |
+
assert not self.pre_only
|
664 |
+
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
665 |
+
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
666 |
+
return x
|
667 |
+
|
668 |
+
def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0):
|
669 |
+
assert not self.pre_only
|
670 |
+
if attn1_dropout > 0.0:
|
671 |
+
# Use torch.bernoulli to implement dropout, only dropout the batch dimension
|
672 |
+
attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device))
|
673 |
+
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
|
674 |
+
else:
|
675 |
+
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
676 |
+
x = x + attn_
|
677 |
+
attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
|
678 |
+
x = x + attn2_
|
679 |
+
mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
680 |
+
x = x + mlp_
|
681 |
+
return x
|
682 |
+
|
683 |
+
|
684 |
+
# JointBlock + block_mixing in mmdit.py
|
685 |
+
class MMDiTBlock(nn.Module):
|
686 |
+
def __init__(self, *args, **kwargs):
|
687 |
+
super().__init__()
|
688 |
+
pre_only = kwargs.pop("pre_only")
|
689 |
+
x_block_self_attn = kwargs.pop("x_block_self_attn")
|
690 |
+
|
691 |
+
self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
|
692 |
+
self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs)
|
693 |
+
|
694 |
+
self.head_dim = self.x_block.attn.head_dim
|
695 |
+
self.mode = self.x_block.attn_mode
|
696 |
+
self.gradient_checkpointing = False
|
697 |
+
|
698 |
+
def enable_gradient_checkpointing(self):
|
699 |
+
self.gradient_checkpointing = True
|
700 |
+
|
701 |
+
def _forward(self, context, x, c):
|
702 |
+
ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c)
|
703 |
+
|
704 |
+
if self.x_block.x_block_self_attn:
|
705 |
+
x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c)
|
706 |
+
else:
|
707 |
+
x_qkv, x_intermediates = self.x_block.pre_attention(x, c)
|
708 |
+
|
709 |
+
ctx_len = ctx_qkv[0].size(1)
|
710 |
+
|
711 |
+
q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1)
|
712 |
+
k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1)
|
713 |
+
v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1)
|
714 |
+
|
715 |
+
attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode)
|
716 |
+
ctx_attn_out = attn[:, :ctx_len]
|
717 |
+
x_attn_out = attn[:, ctx_len:]
|
718 |
+
|
719 |
+
if self.x_block.x_block_self_attn:
|
720 |
+
x_q2, x_k2, x_v2 = x_qkv2
|
721 |
+
attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode)
|
722 |
+
x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates)
|
723 |
+
else:
|
724 |
+
x = self.x_block.post_attention(x_attn_out, *x_intermediates)
|
725 |
+
|
726 |
+
if not self.context_block.pre_only:
|
727 |
+
context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate)
|
728 |
+
else:
|
729 |
+
context = None
|
730 |
+
|
731 |
+
return context, x
|
732 |
+
|
733 |
+
def forward(self, *args, **kwargs):
|
734 |
+
if self.training and self.gradient_checkpointing:
|
735 |
+
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
736 |
+
else:
|
737 |
+
return self._forward(*args, **kwargs)
|
738 |
+
|
739 |
+
|
740 |
+
class MMDiT(nn.Module):
|
741 |
+
"""
|
742 |
+
Diffusion model with a Transformer backbone.
|
743 |
+
"""
|
744 |
+
|
745 |
+
# prepare pos_embed for latent size * 2
|
746 |
+
POS_EMBED_MAX_RATIO = 1.5
|
747 |
+
|
748 |
+
def __init__(
|
749 |
+
self,
|
750 |
+
input_size: int = 32,
|
751 |
+
patch_size: int = 2,
|
752 |
+
in_channels: int = 4,
|
753 |
+
depth: int = 28,
|
754 |
+
# hidden_size: Optional[int] = None,
|
755 |
+
# num_heads: Optional[int] = None,
|
756 |
+
mlp_ratio: float = 4.0,
|
757 |
+
learn_sigma: bool = False,
|
758 |
+
adm_in_channels: Optional[int] = None,
|
759 |
+
context_embedder_in_features: Optional[int] = None,
|
760 |
+
context_embedder_out_features: Optional[int] = None,
|
761 |
+
use_checkpoint: bool = False,
|
762 |
+
register_length: int = 0,
|
763 |
+
attn_mode: str = "torch",
|
764 |
+
rmsnorm: bool = False,
|
765 |
+
scale_mod_only: bool = False,
|
766 |
+
swiglu: bool = False,
|
767 |
+
out_channels: Optional[int] = None,
|
768 |
+
pos_embed_scaling_factor: Optional[float] = None,
|
769 |
+
pos_embed_offset: Optional[float] = None,
|
770 |
+
pos_embed_max_size: Optional[int] = None,
|
771 |
+
num_patches=None,
|
772 |
+
qk_norm: Optional[str] = None,
|
773 |
+
x_block_self_attn_layers: Optional[list[int]] = [],
|
774 |
+
qkv_bias: bool = True,
|
775 |
+
pos_emb_random_crop_rate: float = 0.0,
|
776 |
+
use_scaled_pos_embed: bool = False,
|
777 |
+
pos_embed_latent_sizes: Optional[list[int]] = None,
|
778 |
+
model_type: str = "sd3m",
|
779 |
+
):
|
780 |
+
super().__init__()
|
781 |
+
self._model_type = model_type
|
782 |
+
self.learn_sigma = learn_sigma
|
783 |
+
self.in_channels = in_channels
|
784 |
+
default_out_channels = in_channels * 2 if learn_sigma else in_channels
|
785 |
+
self.out_channels = default(out_channels, default_out_channels)
|
786 |
+
self.patch_size = patch_size
|
787 |
+
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
788 |
+
self.pos_embed_offset = pos_embed_offset
|
789 |
+
self.pos_embed_max_size = pos_embed_max_size
|
790 |
+
self.x_block_self_attn_layers = x_block_self_attn_layers
|
791 |
+
self.pos_emb_random_crop_rate = pos_emb_random_crop_rate
|
792 |
+
self.gradient_checkpointing = use_checkpoint
|
793 |
+
|
794 |
+
# hidden_size = default(hidden_size, 64 * depth)
|
795 |
+
# num_heads = default(num_heads, hidden_size // 64)
|
796 |
+
|
797 |
+
# apply magic --> this defines a head_size of 64
|
798 |
+
self.hidden_size = 64 * depth
|
799 |
+
num_heads = depth
|
800 |
+
|
801 |
+
self.num_heads = num_heads
|
802 |
+
|
803 |
+
self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes)
|
804 |
+
|
805 |
+
self.x_embedder = PatchEmbed(
|
806 |
+
input_size,
|
807 |
+
patch_size,
|
808 |
+
in_channels,
|
809 |
+
self.hidden_size,
|
810 |
+
bias=True,
|
811 |
+
strict_img_size=self.pos_embed_max_size is None,
|
812 |
+
)
|
813 |
+
self.t_embedder = TimestepEmbedding(self.hidden_size)
|
814 |
+
|
815 |
+
self.y_embedder = None
|
816 |
+
if adm_in_channels is not None:
|
817 |
+
assert isinstance(adm_in_channels, int)
|
818 |
+
self.y_embedder = Embedder(adm_in_channels, self.hidden_size)
|
819 |
+
|
820 |
+
if context_embedder_in_features is not None:
|
821 |
+
self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features)
|
822 |
+
else:
|
823 |
+
self.context_embedder = nn.Identity()
|
824 |
+
|
825 |
+
self.register_length = register_length
|
826 |
+
if self.register_length > 0:
|
827 |
+
self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size))
|
828 |
+
|
829 |
+
# num_patches = self.x_embedder.num_patches
|
830 |
+
# Will use fixed sin-cos embedding:
|
831 |
+
# just use a buffer already
|
832 |
+
if num_patches is not None:
|
833 |
+
self.register_buffer(
|
834 |
+
"pos_embed",
|
835 |
+
torch.empty(1, num_patches, self.hidden_size),
|
836 |
+
)
|
837 |
+
else:
|
838 |
+
self.pos_embed = None
|
839 |
+
|
840 |
+
self.use_checkpoint = use_checkpoint
|
841 |
+
self.joint_blocks = nn.ModuleList(
|
842 |
+
[
|
843 |
+
MMDiTBlock(
|
844 |
+
self.hidden_size,
|
845 |
+
num_heads,
|
846 |
+
mlp_ratio=mlp_ratio,
|
847 |
+
attn_mode=attn_mode,
|
848 |
+
qkv_bias=qkv_bias,
|
849 |
+
pre_only=i == depth - 1,
|
850 |
+
rmsnorm=rmsnorm,
|
851 |
+
scale_mod_only=scale_mod_only,
|
852 |
+
swiglu=swiglu,
|
853 |
+
qk_norm=qk_norm,
|
854 |
+
x_block_self_attn=(i in self.x_block_self_attn_layers),
|
855 |
+
)
|
856 |
+
for i in range(depth)
|
857 |
+
]
|
858 |
+
)
|
859 |
+
for block in self.joint_blocks:
|
860 |
+
block.gradient_checkpointing = use_checkpoint
|
861 |
+
|
862 |
+
self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels)
|
863 |
+
# self.initialize_weights()
|
864 |
+
|
865 |
+
self.blocks_to_swap = None
|
866 |
+
self.offloader = None
|
867 |
+
self.num_blocks = len(self.joint_blocks)
|
868 |
+
|
869 |
+
def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]):
|
870 |
+
self.use_scaled_pos_embed = use_scaled_pos_embed
|
871 |
+
|
872 |
+
if self.use_scaled_pos_embed:
|
873 |
+
# remove pos_embed to free up memory up to 0.4 GB
|
874 |
+
self.pos_embed = None
|
875 |
+
|
876 |
+
# remove duplicates and sort latent sizes in ascending order
|
877 |
+
latent_sizes = list(set(latent_sizes))
|
878 |
+
latent_sizes = sorted(latent_sizes)
|
879 |
+
|
880 |
+
patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes]
|
881 |
+
|
882 |
+
# calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape
|
883 |
+
max_areas = []
|
884 |
+
for i in range(1, len(patched_sizes)):
|
885 |
+
prev_area = patched_sizes[i - 1] ** 2
|
886 |
+
area = patched_sizes[i] ** 2
|
887 |
+
max_areas.append((prev_area + area) // 2)
|
888 |
+
|
889 |
+
# area of the last latent size, if the latent size exceeds this, error will be raised
|
890 |
+
max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2))
|
891 |
+
# print("max_areas", max_areas)
|
892 |
+
|
893 |
+
self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)]
|
894 |
+
|
895 |
+
self.resolution_pos_embeds = {}
|
896 |
+
for patched_size in patched_sizes:
|
897 |
+
grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
|
898 |
+
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
|
899 |
+
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
|
900 |
+
self.resolution_pos_embeds[patched_size] = pos_embed
|
901 |
+
# print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}")
|
902 |
+
|
903 |
+
else:
|
904 |
+
self.resolution_area_to_latent_size = None
|
905 |
+
self.resolution_pos_embeds = None
|
906 |
+
|
907 |
+
@property
|
908 |
+
def model_type(self):
|
909 |
+
return self._model_type
|
910 |
+
|
911 |
+
@property
|
912 |
+
def device(self):
|
913 |
+
return next(self.parameters()).device
|
914 |
+
|
915 |
+
@property
|
916 |
+
def dtype(self):
|
917 |
+
return next(self.parameters()).dtype
|
918 |
+
|
919 |
+
def enable_gradient_checkpointing(self):
|
920 |
+
self.gradient_checkpointing = True
|
921 |
+
for block in self.joint_blocks:
|
922 |
+
block.enable_gradient_checkpointing()
|
923 |
+
|
924 |
+
def disable_gradient_checkpointing(self):
|
925 |
+
self.gradient_checkpointing = False
|
926 |
+
for block in self.joint_blocks:
|
927 |
+
block.disable_gradient_checkpointing()
|
928 |
+
|
929 |
+
def initialize_weights(self):
|
930 |
+
# TODO: Init context_embedder?
|
931 |
+
# Initialize transformer layers:
|
932 |
+
def _basic_init(module):
|
933 |
+
if isinstance(module, nn.Linear):
|
934 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
935 |
+
if module.bias is not None:
|
936 |
+
nn.init.constant_(module.bias, 0)
|
937 |
+
|
938 |
+
self.apply(_basic_init)
|
939 |
+
|
940 |
+
# Initialize (and freeze) pos_embed by sin-cos embedding
|
941 |
+
if self.pos_embed is not None:
|
942 |
+
pos_embed = get_2d_sincos_pos_embed(
|
943 |
+
self.pos_embed.shape[-1],
|
944 |
+
int(self.pos_embed.shape[-2] ** 0.5),
|
945 |
+
scaling_factor=self.pos_embed_scaling_factor,
|
946 |
+
)
|
947 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
948 |
+
|
949 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
950 |
+
w = self.x_embedder.proj.weight.data
|
951 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
952 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
953 |
+
|
954 |
+
if getattr(self, "y_embedder", None) is not None:
|
955 |
+
nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
|
956 |
+
nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)
|
957 |
+
|
958 |
+
# Initialize timestep embedding MLP:
|
959 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
960 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
961 |
+
|
962 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
963 |
+
for block in self.joint_blocks:
|
964 |
+
nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
|
965 |
+
nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
|
966 |
+
nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
|
967 |
+
nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)
|
968 |
+
|
969 |
+
# Zero-out output layers:
|
970 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
971 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
972 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
973 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
974 |
+
|
975 |
+
def set_pos_emb_random_crop_rate(self, rate: float):
|
976 |
+
self.pos_emb_random_crop_rate = rate
|
977 |
+
|
978 |
+
def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False):
|
979 |
+
p = self.x_embedder.patch_size
|
980 |
+
# patched size
|
981 |
+
h = (h + 1) // p
|
982 |
+
w = (w + 1) // p
|
983 |
+
if self.pos_embed is None: # should not happen
|
984 |
+
return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
|
985 |
+
assert self.pos_embed_max_size is not None
|
986 |
+
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
987 |
+
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
|
988 |
+
|
989 |
+
if not random_crop:
|
990 |
+
top = (self.pos_embed_max_size - h) // 2
|
991 |
+
left = (self.pos_embed_max_size - w) // 2
|
992 |
+
else:
|
993 |
+
top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item()
|
994 |
+
left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item()
|
995 |
+
|
996 |
+
spatial_pos_embed = self.pos_embed.reshape(
|
997 |
+
1,
|
998 |
+
self.pos_embed_max_size,
|
999 |
+
self.pos_embed_max_size,
|
1000 |
+
self.pos_embed.shape[-1],
|
1001 |
+
)
|
1002 |
+
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
1003 |
+
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
1004 |
+
return spatial_pos_embed
|
1005 |
+
|
1006 |
+
def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False):
|
1007 |
+
p = self.x_embedder.patch_size
|
1008 |
+
# patched size
|
1009 |
+
h = (h + 1) // p
|
1010 |
+
w = (w + 1) // p
|
1011 |
+
|
1012 |
+
# select pos_embed size based on area
|
1013 |
+
area = h * w
|
1014 |
+
patched_size = None
|
1015 |
+
for area_, patched_size_ in self.resolution_area_to_latent_size:
|
1016 |
+
if area <= area_:
|
1017 |
+
patched_size = patched_size_
|
1018 |
+
break
|
1019 |
+
if patched_size is None:
|
1020 |
+
raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
|
1021 |
+
|
1022 |
+
pos_embed = self.resolution_pos_embeds[patched_size]
|
1023 |
+
pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
|
1024 |
+
if h > pos_embed_size or w > pos_embed_size:
|
1025 |
+
# # fallback to normal pos_embed
|
1026 |
+
# return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
|
1027 |
+
# extend pos_embed size
|
1028 |
+
logger.warning(
|
1029 |
+
f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
|
1030 |
+
)
|
1031 |
+
pos_embed_size = max(h, w)
|
1032 |
+
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
|
1033 |
+
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
|
1034 |
+
self.resolution_pos_embeds[patched_size] = pos_embed
|
1035 |
+
logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")
|
1036 |
+
|
1037 |
+
if not random_crop:
|
1038 |
+
top = (pos_embed_size - h) // 2
|
1039 |
+
left = (pos_embed_size - w) // 2
|
1040 |
+
else:
|
1041 |
+
top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
|
1042 |
+
left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()
|
1043 |
+
|
1044 |
+
if pos_embed.device != device:
|
1045 |
+
pos_embed = pos_embed.to(device)
|
1046 |
+
# which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.
|
1047 |
+
self.resolution_pos_embeds[patched_size] = pos_embed # update device
|
1048 |
+
if pos_embed.dtype != dtype:
|
1049 |
+
pos_embed = pos_embed.to(dtype)
|
1050 |
+
self.resolution_pos_embeds[patched_size] = pos_embed # update dtype
|
1051 |
+
|
1052 |
+
spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1])
|
1053 |
+
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
1054 |
+
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
1055 |
+
# print(
|
1056 |
+
# f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}"
|
1057 |
+
# )
|
1058 |
+
return spatial_pos_embed
|
1059 |
+
|
1060 |
+
def enable_block_swap(self, num_blocks: int, device: torch.device):
|
1061 |
+
self.blocks_to_swap = num_blocks
|
1062 |
+
|
1063 |
+
assert (
|
1064 |
+
self.blocks_to_swap <= self.num_blocks - 2
|
1065 |
+
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
|
1066 |
+
|
1067 |
+
self.offloader = custom_offloading_utils.ModelOffloader(
|
1068 |
+
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
|
1069 |
+
)
|
1070 |
+
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
|
1071 |
+
|
1072 |
+
def move_to_device_except_swap_blocks(self, device: torch.device):
|
1073 |
+
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
1074 |
+
if self.blocks_to_swap:
|
1075 |
+
save_blocks = self.joint_blocks
|
1076 |
+
self.joint_blocks = None
|
1077 |
+
|
1078 |
+
self.to(device)
|
1079 |
+
|
1080 |
+
if self.blocks_to_swap:
|
1081 |
+
self.joint_blocks = save_blocks
|
1082 |
+
|
1083 |
+
def prepare_block_swap_before_forward(self):
|
1084 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
1085 |
+
return
|
1086 |
+
self.offloader.prepare_block_devices_before_forward(self.joint_blocks)
|
1087 |
+
|
1088 |
+
def forward(
|
1089 |
+
self,
|
1090 |
+
x: torch.Tensor,
|
1091 |
+
t: torch.Tensor,
|
1092 |
+
y: Optional[torch.Tensor] = None,
|
1093 |
+
context: Optional[torch.Tensor] = None,
|
1094 |
+
) -> torch.Tensor:
|
1095 |
+
"""
|
1096 |
+
Forward pass of DiT.
|
1097 |
+
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
1098 |
+
t: (N,) tensor of diffusion timesteps
|
1099 |
+
y: (N, D) tensor of class labels
|
1100 |
+
"""
|
1101 |
+
pos_emb_random_crop = (
|
1102 |
+
False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate
|
1103 |
+
)
|
1104 |
+
|
1105 |
+
B, C, H, W = x.shape
|
1106 |
+
|
1107 |
+
# x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
|
1108 |
+
if not self.use_scaled_pos_embed:
|
1109 |
+
pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
|
1110 |
+
else:
|
1111 |
+
# print(f"Using scaled pos_embed for size {H}x{W}")
|
1112 |
+
pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop)
|
1113 |
+
x = self.x_embedder(x) + pos_embed
|
1114 |
+
del pos_embed
|
1115 |
+
|
1116 |
+
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
1117 |
+
if y is not None and self.y_embedder is not None:
|
1118 |
+
y = self.y_embedder(y) # (N, D)
|
1119 |
+
c = c + y # (N, D)
|
1120 |
+
|
1121 |
+
if context is not None:
|
1122 |
+
context = self.context_embedder(context)
|
1123 |
+
|
1124 |
+
if self.register_length > 0:
|
1125 |
+
context = torch.cat(
|
1126 |
+
(einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1
|
1127 |
+
)
|
1128 |
+
|
1129 |
+
if not self.blocks_to_swap:
|
1130 |
+
for block in self.joint_blocks:
|
1131 |
+
context, x = block(context, x, c)
|
1132 |
+
else:
|
1133 |
+
for block_idx, block in enumerate(self.joint_blocks):
|
1134 |
+
self.offloader.wait_for_block(block_idx)
|
1135 |
+
|
1136 |
+
context, x = block(context, x, c)
|
1137 |
+
|
1138 |
+
self.offloader.submit_move_blocks(self.joint_blocks, block_idx)
|
1139 |
+
|
1140 |
+
x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify
|
1141 |
+
return x[:, :, :H, :W]
|
1142 |
+
|
1143 |
+
|
1144 |
+
def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT:
|
1145 |
+
mmdit = MMDiT(
|
1146 |
+
input_size=None,
|
1147 |
+
pos_embed_max_size=params.pos_embed_max_size,
|
1148 |
+
patch_size=params.patch_size,
|
1149 |
+
in_channels=16,
|
1150 |
+
adm_in_channels=params.adm_in_channels,
|
1151 |
+
context_embedder_in_features=params.context_embedder_in_features,
|
1152 |
+
context_embedder_out_features=params.context_embedder_out_features,
|
1153 |
+
depth=params.depth,
|
1154 |
+
mlp_ratio=4,
|
1155 |
+
qk_norm=params.qk_norm,
|
1156 |
+
x_block_self_attn_layers=params.x_block_self_attn_layers,
|
1157 |
+
num_patches=params.num_patches,
|
1158 |
+
attn_mode=attn_mode,
|
1159 |
+
model_type=params.model_type,
|
1160 |
+
)
|
1161 |
+
return mmdit
|
1162 |
+
|
1163 |
+
|
1164 |
+
# endregion
|
1165 |
+
|
1166 |
+
# region VAE
|
1167 |
+
|
1168 |
+
VAE_SCALE_FACTOR = 1.5305
|
1169 |
+
VAE_SHIFT_FACTOR = 0.0609
|
1170 |
+
|
1171 |
+
|
1172 |
+
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
|
1173 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
1174 |
+
|
1175 |
+
|
1176 |
+
class ResnetBlock(torch.nn.Module):
|
1177 |
+
def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
|
1178 |
+
super().__init__()
|
1179 |
+
self.in_channels = in_channels
|
1180 |
+
out_channels = in_channels if out_channels is None else out_channels
|
1181 |
+
self.out_channels = out_channels
|
1182 |
+
|
1183 |
+
self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
|
1184 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1185 |
+
self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
|
1186 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1187 |
+
if self.in_channels != self.out_channels:
|
1188 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
1189 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device
|
1190 |
+
)
|
1191 |
+
else:
|
1192 |
+
self.nin_shortcut = None
|
1193 |
+
self.swish = torch.nn.SiLU(inplace=True)
|
1194 |
+
|
1195 |
+
def forward(self, x):
|
1196 |
+
hidden = x
|
1197 |
+
hidden = self.norm1(hidden)
|
1198 |
+
hidden = self.swish(hidden)
|
1199 |
+
hidden = self.conv1(hidden)
|
1200 |
+
hidden = self.norm2(hidden)
|
1201 |
+
hidden = self.swish(hidden)
|
1202 |
+
hidden = self.conv2(hidden)
|
1203 |
+
if self.in_channels != self.out_channels:
|
1204 |
+
x = self.nin_shortcut(x)
|
1205 |
+
return x + hidden
|
1206 |
+
|
1207 |
+
|
1208 |
+
class AttnBlock(torch.nn.Module):
|
1209 |
+
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
1210 |
+
super().__init__()
|
1211 |
+
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
1212 |
+
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
1213 |
+
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
1214 |
+
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
1215 |
+
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
1216 |
+
|
1217 |
+
def forward(self, x):
|
1218 |
+
hidden = self.norm(x)
|
1219 |
+
q = self.q(hidden)
|
1220 |
+
k = self.k(hidden)
|
1221 |
+
v = self.v(hidden)
|
1222 |
+
b, c, h, w = q.shape
|
1223 |
+
q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
|
1224 |
+
hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
|
1225 |
+
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
1226 |
+
hidden = self.proj_out(hidden)
|
1227 |
+
return x + hidden
|
1228 |
+
|
1229 |
+
|
1230 |
+
class Downsample(torch.nn.Module):
|
1231 |
+
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
1232 |
+
super().__init__()
|
1233 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)
|
1234 |
+
|
1235 |
+
def forward(self, x):
|
1236 |
+
pad = (0, 1, 0, 1)
|
1237 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
1238 |
+
x = self.conv(x)
|
1239 |
+
return x
|
1240 |
+
|
1241 |
+
|
1242 |
+
class Upsample(torch.nn.Module):
|
1243 |
+
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
1244 |
+
super().__init__()
|
1245 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1246 |
+
|
1247 |
+
def forward(self, x):
|
1248 |
+
org_dtype = x.dtype
|
1249 |
+
if x.dtype == torch.bfloat16:
|
1250 |
+
x = x.to(torch.float32)
|
1251 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
1252 |
+
if x.dtype != org_dtype:
|
1253 |
+
x = x.to(org_dtype)
|
1254 |
+
x = self.conv(x)
|
1255 |
+
return x
|
1256 |
+
|
1257 |
+
|
1258 |
+
class VAEEncoder(torch.nn.Module):
|
1259 |
+
def __init__(
|
1260 |
+
self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None
|
1261 |
+
):
|
1262 |
+
super().__init__()
|
1263 |
+
self.num_resolutions = len(ch_mult)
|
1264 |
+
self.num_res_blocks = num_res_blocks
|
1265 |
+
# downsampling
|
1266 |
+
self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1267 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
1268 |
+
self.in_ch_mult = in_ch_mult
|
1269 |
+
self.down = torch.nn.ModuleList()
|
1270 |
+
for i_level in range(self.num_resolutions):
|
1271 |
+
block = torch.nn.ModuleList()
|
1272 |
+
attn = torch.nn.ModuleList()
|
1273 |
+
block_in = ch * in_ch_mult[i_level]
|
1274 |
+
block_out = ch * ch_mult[i_level]
|
1275 |
+
for i_block in range(num_res_blocks):
|
1276 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
|
1277 |
+
block_in = block_out
|
1278 |
+
down = torch.nn.Module()
|
1279 |
+
down.block = block
|
1280 |
+
down.attn = attn
|
1281 |
+
if i_level != self.num_resolutions - 1:
|
1282 |
+
down.downsample = Downsample(block_in, dtype=dtype, device=device)
|
1283 |
+
self.down.append(down)
|
1284 |
+
# middle
|
1285 |
+
self.mid = torch.nn.Module()
|
1286 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
1287 |
+
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
1288 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
1289 |
+
# end
|
1290 |
+
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
1291 |
+
self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1292 |
+
self.swish = torch.nn.SiLU(inplace=True)
|
1293 |
+
|
1294 |
+
def forward(self, x):
|
1295 |
+
# downsampling
|
1296 |
+
hs = [self.conv_in(x)]
|
1297 |
+
for i_level in range(self.num_resolutions):
|
1298 |
+
for i_block in range(self.num_res_blocks):
|
1299 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
1300 |
+
hs.append(h)
|
1301 |
+
if i_level != self.num_resolutions - 1:
|
1302 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
1303 |
+
# middle
|
1304 |
+
h = hs[-1]
|
1305 |
+
h = self.mid.block_1(h)
|
1306 |
+
h = self.mid.attn_1(h)
|
1307 |
+
h = self.mid.block_2(h)
|
1308 |
+
# end
|
1309 |
+
h = self.norm_out(h)
|
1310 |
+
h = self.swish(h)
|
1311 |
+
h = self.conv_out(h)
|
1312 |
+
return h
|
1313 |
+
|
1314 |
+
|
1315 |
+
class VAEDecoder(torch.nn.Module):
|
1316 |
+
def __init__(
|
1317 |
+
self,
|
1318 |
+
ch=128,
|
1319 |
+
out_ch=3,
|
1320 |
+
ch_mult=(1, 2, 4, 4),
|
1321 |
+
num_res_blocks=2,
|
1322 |
+
resolution=256,
|
1323 |
+
z_channels=16,
|
1324 |
+
dtype=torch.float32,
|
1325 |
+
device=None,
|
1326 |
+
):
|
1327 |
+
super().__init__()
|
1328 |
+
self.num_resolutions = len(ch_mult)
|
1329 |
+
self.num_res_blocks = num_res_blocks
|
1330 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
1331 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
1332 |
+
# z to block_in
|
1333 |
+
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1334 |
+
# middle
|
1335 |
+
self.mid = torch.nn.Module()
|
1336 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
1337 |
+
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
1338 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
1339 |
+
# upsampling
|
1340 |
+
self.up = torch.nn.ModuleList()
|
1341 |
+
for i_level in reversed(range(self.num_resolutions)):
|
1342 |
+
block = torch.nn.ModuleList()
|
1343 |
+
block_out = ch * ch_mult[i_level]
|
1344 |
+
for i_block in range(self.num_res_blocks + 1):
|
1345 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
|
1346 |
+
block_in = block_out
|
1347 |
+
up = torch.nn.Module()
|
1348 |
+
up.block = block
|
1349 |
+
if i_level != 0:
|
1350 |
+
up.upsample = Upsample(block_in, dtype=dtype, device=device)
|
1351 |
+
curr_res = curr_res * 2
|
1352 |
+
self.up.insert(0, up) # prepend to get consistent order
|
1353 |
+
# end
|
1354 |
+
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
1355 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1356 |
+
self.swish = torch.nn.SiLU(inplace=True)
|
1357 |
+
|
1358 |
+
def forward(self, z):
|
1359 |
+
# z to block_in
|
1360 |
+
hidden = self.conv_in(z)
|
1361 |
+
# middle
|
1362 |
+
hidden = self.mid.block_1(hidden)
|
1363 |
+
hidden = self.mid.attn_1(hidden)
|
1364 |
+
hidden = self.mid.block_2(hidden)
|
1365 |
+
# upsampling
|
1366 |
+
for i_level in reversed(range(self.num_resolutions)):
|
1367 |
+
for i_block in range(self.num_res_blocks + 1):
|
1368 |
+
hidden = self.up[i_level].block[i_block](hidden)
|
1369 |
+
if i_level != 0:
|
1370 |
+
hidden = self.up[i_level].upsample(hidden)
|
1371 |
+
# end
|
1372 |
+
hidden = self.norm_out(hidden)
|
1373 |
+
hidden = self.swish(hidden)
|
1374 |
+
hidden = self.conv_out(hidden)
|
1375 |
+
return hidden
|
1376 |
+
|
1377 |
+
|
1378 |
+
class SDVAE(torch.nn.Module):
|
1379 |
+
def __init__(self, dtype=torch.float32, device=None):
|
1380 |
+
super().__init__()
|
1381 |
+
self.encoder = VAEEncoder(dtype=dtype, device=device)
|
1382 |
+
self.decoder = VAEDecoder(dtype=dtype, device=device)
|
1383 |
+
|
1384 |
+
@property
|
1385 |
+
def device(self):
|
1386 |
+
return next(self.parameters()).device
|
1387 |
+
|
1388 |
+
@property
|
1389 |
+
def dtype(self):
|
1390 |
+
return next(self.parameters()).dtype
|
1391 |
+
|
1392 |
+
# @torch.autocast("cuda", dtype=torch.float16)
|
1393 |
+
def decode(self, latent):
|
1394 |
+
return self.decoder(latent)
|
1395 |
+
|
1396 |
+
# @torch.autocast("cuda", dtype=torch.float16)
|
1397 |
+
def encode(self, image):
|
1398 |
+
hidden = self.encoder(image)
|
1399 |
+
mean, logvar = torch.chunk(hidden, 2, dim=1)
|
1400 |
+
logvar = torch.clamp(logvar, -30.0, 20.0)
|
1401 |
+
std = torch.exp(0.5 * logvar)
|
1402 |
+
return mean + std * torch.randn_like(mean)
|
1403 |
+
|
1404 |
+
@staticmethod
|
1405 |
+
def process_in(latent):
|
1406 |
+
return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR
|
1407 |
+
|
1408 |
+
@staticmethod
|
1409 |
+
def process_out(latent):
|
1410 |
+
return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR
|
1411 |
+
|
1412 |
+
|
1413 |
+
# endregion
|