Upload 19 files
- .gitattributes +1 -0
- model/__init__.py +12 -0
- model/attention.py +298 -0
- model/bsrnet.py +104 -0
- model/cldm.py +155 -0
- model/clip.py +65 -0
- model/config.py +62 -0
- model/controlnet.py +277 -0
- model/distributions.py +92 -0
- model/gaussian_diffusion.py +118 -0
- model/open_clip/__init__.py +4 -0
- model/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- model/open_clip/model.py +206 -0
- model/open_clip/tokenizer.py +214 -0
- model/open_clip/transformer.py +736 -0
- model/scunet.py +264 -0
- model/swinir.py +905 -0
- model/unet.py +719 -0
- model/util.py +225 -0
- model/vae.py +567 -0
.gitattributes
CHANGED
@@ -6,3 +6,4 @@ assets/visual_results/bsr6.png filter=lfs diff=lfs merge=lfs -text
 assets/visual_results/tiled_sampling.png filter=lfs diff=lfs merge=lfs -text
 assets/visual_results/whole_image1.png filter=lfs diff=lfs merge=lfs -text
 assets/visual_results/whole_image2.png filter=lfs diff=lfs merge=lfs -text
+model/open_clip/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
model/__init__.py
ADDED
@@ -0,0 +1,12 @@
+from . import config
+
+from .controlnet import ControlledUnetModel, ControlNet
+from .vae import AutoencoderKL
+from .clip import FrozenOpenCLIPEmbedder
+
+from .cldm import ControlLDM
+from .gaussian_diffusion import Diffusion
+
+from .swinir import SwinIR
+from .bsrnet import RRDBNet
+from .scunet import SCUNet
model/attention.py
ADDED
@@ -0,0 +1,298 @@
+from packaging import version
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from typing import Optional, Any
+
+from model.util import (
+    checkpoint, zero_module, exists, default
+)
+from model.config import Config, AttnMode
+
+
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = nn.Sequential(
+            nn.Linear(dim, inner_dim),
+            nn.GELU()
+        ) if not glu else GEGLU(dim, inner_dim)
+
+        self.net = nn.Sequential(
+            project_in,
+            nn.Dropout(dropout),
+            nn.Linear(inner_dim, dim_out)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        print(f"Setting up {self.__class__.__name__} (vanilla). Query dim is {query_dim}, context_dim is {context_dim} and using "
+              f"{heads} heads.")
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+        # force cast to fp32 to avoid overflowing
+        if _ATTN_PRECISION =="fp32":
+            # with torch.autocast(enabled=False, device_type = 'cuda'):
+            with torch.autocast(enabled=False, device_type="cuda" if str(x.device).startswith("cuda") else "cpu"):
+                q, k = q.float(), k.float()
+                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        else:
+            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+        del q, k
+
+        if exists(mask):
+            mask = rearrange(mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        sim = sim.softmax(dim=-1)
+
+        out = einsum('b i j, b j d -> b i d', sim, v)
+        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+        return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        print(f"Setting up {self.__class__.__name__} (xformers). Query dim is {query_dim}, context_dim is {context_dim} and using "
+              f"{heads} heads.")
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        # actually compute the attention, what we cannot get enough of
+        out = Config.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+        return self.to_out(out)
+
+
+class SDPCrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        print(f"Setting up {self.__class__.__name__} (sdp). Query dim is {query_dim}, context_dim is {context_dim} and using "
+              f"{heads} heads.")
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        # actually compute the attention, what we cannot get enough of
+        out = F.scaled_dot_product_attention(q, k, v)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+        return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+    ATTENTION_MODES = {
+        AttnMode.VANILLA: CrossAttention,  # vanilla attention
+        AttnMode.XFORMERS: MemoryEfficientCrossAttention,
+        AttnMode.SDP: SDPCrossAttention
+    }
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+                 disable_self_attn=False):
+        super().__init__()
+        attn_cls = self.ATTENTION_MODES[Config.attn_mode]
+        self.disable_self_attn = disable_self_attn
+        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+                              context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None):
+        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+    def _forward(self, x, context=None):
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data.
+    First, project the input (aka embedding)
+    and reshape to b, t, d.
+    Then apply standard transformer action.
+    Finally, reshape to image
+    NEW: use_linear for more efficiency instead of the 1x1 convs
+    """
+    def __init__(self, in_channels, n_heads, d_head,
+                 depth=1, dropout=0., context_dim=None,
+                 disable_self_attn=False, use_linear=False,
+                 use_checkpoint=True):
+        super().__init__()
+        if exists(context_dim) and not isinstance(context_dim, list):
+            context_dim = [context_dim]
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+        if not use_linear:
+            self.proj_in = nn.Conv2d(in_channels,
+                                     inner_dim,
+                                     kernel_size=1,
+                                     stride=1,
+                                     padding=0)
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+                for d in range(depth)]
+        )
+        if not use_linear:
+            self.proj_out = zero_module(nn.Conv2d(inner_dim,
+                                                  in_channels,
+                                                  kernel_size=1,
+                                                  stride=1,
+                                                  padding=0))
+        else:
+            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+        self.use_linear = use_linear
+
+    def forward(self, x, context=None):
+        # note: if no context is given, cross-attention defaults to self-attention
+        if not isinstance(context, list):
+            context = [context]
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        if self.use_linear:
+            x = self.proj_in(x)
+        for i, block in enumerate(self.transformer_blocks):
+            x = block(x, context=context[i])
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        if not self.use_linear:
+            x = self.proj_out(x)
+        return x + x_in
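(Illustrative note, not part of the upload.) A minimal sketch of exercising a SpatialTransformer from the file above on its own; the channel, head, and context sizes are assumptions chosen to mirror common Stable Diffusion settings, and gradient checkpointing is disabled for this forward-only example.

import torch
from model.attention import SpatialTransformer

# 320 channels split into 8 heads of 40 dims each (inner_dim = 320), cross-attending
# to an assumed 1024-dim text context of 77 tokens.
block = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1,
                           context_dim=1024, use_checkpoint=False)
x = torch.randn(1, 320, 64, 64)    # image-like feature map
ctx = torch.randn(1, 77, 1024)     # e.g. OpenCLIP token embeddings (assumed)
y = block(x, context=ctx)          # output keeps the input shape (1, 320, 64, 64)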
model/bsrnet.py
ADDED
@@ -0,0 +1,104 @@
+# From BSRGAN
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+
+
+def initialize_weights(net_l, scale=1):
+    if not isinstance(net_l, list):
+        net_l = [net_l]
+    for net in net_l:
+        for m in net.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                m.weight.data *= scale  # for residual block
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                init.constant_(m.weight, 1)
+                init.constant_(m.bias.data, 0.0)
+
+
+def make_layer(block, n_layers):
+    layers = []
+    for _ in range(n_layers):
+        layers.append(block())
+    return nn.Sequential(*layers)
+
+
+class ResidualDenseBlock_5C(nn.Module):
+    def __init__(self, nf=64, gc=32, bias=True):
+        super(ResidualDenseBlock_5C, self).__init__()
+        # gc: growth channel, i.e. intermediate channels
+        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
+        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
+        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
+        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
+        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        # initialization
+        initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+    def forward(self, x):
+        x1 = self.lrelu(self.conv1(x))
+        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+    '''Residual in Residual Dense Block'''
+
+    def __init__(self, nf, gc=32):
+        super(RRDB, self).__init__()
+        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
+        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
+        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+
+    def forward(self, x):
+        out = self.RDB1(x)
+        out = self.RDB2(out)
+        out = self.RDB3(out)
+        return out * 0.2 + x
+
+
+class RRDBNet(nn.Module):
+    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
+        super(RRDBNet, self).__init__()
+        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+        self.sf = sf
+        print([in_nc, out_nc, nf, nb, gc, sf])
+
+        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
+        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        #### upsampling
+        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        if self.sf==4:
+            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        fea = self.conv_first(x)
+        trunk = self.trunk_conv(self.RRDB_trunk(fea))
+        fea = fea + trunk
+
+        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
+        if self.sf==4:
+            fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
+        out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+        return out
model/cldm.py
ADDED
@@ -0,0 +1,155 @@
+from typing import Tuple, Set, List, Dict
+
+import torch
+from torch import nn
+
+from model import (
+    ControlledUnetModel, ControlNet,
+    AutoencoderKL, FrozenOpenCLIPEmbedder
+)
+from utils.common import sliding_windows, count_vram_usage, gaussian_weights
+
+
+def disabled_train(self: nn.Module) -> nn.Module:
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class ControlLDM(nn.Module):
+
+    def __init__(
+        self,
+        unet_cfg,
+        vae_cfg,
+        clip_cfg,
+        controlnet_cfg,
+        latent_scale_factor
+    ):
+        super().__init__()
+        self.unet = ControlledUnetModel(**unet_cfg)
+        self.vae = AutoencoderKL(**vae_cfg)
+        self.clip = FrozenOpenCLIPEmbedder(**clip_cfg)
+        self.controlnet = ControlNet(**controlnet_cfg)
+        self.scale_factor = latent_scale_factor
+        self.control_scales = [1.0] * 13
+
+    @torch.no_grad()
+    def load_pretrained_sd(self, sd: Dict[str, torch.Tensor]) -> Set[str]:
+        module_map = {
+            "unet": "model.diffusion_model",
+            "vae": "first_stage_model",
+            "clip": "cond_stage_model",
+        }
+        modules = [("unet", self.unet), ("vae", self.vae), ("clip", self.clip)]
+        used = set()
+        for name, module in modules:
+            init_sd = {}
+            scratch_sd = module.state_dict()
+            for key in scratch_sd:
+                target_key = ".".join([module_map[name], key])
+                init_sd[key] = sd[target_key].clone()
+                used.add(target_key)
+            module.load_state_dict(init_sd, strict=True)
+        unused = set(sd.keys()) - used
+        # NOTE: this is slightly different from previous version, which haven't switched
+        # the UNet to eval mode and disabled the requires_grad flag.
+        for module in [self.vae, self.clip, self.unet]:
+            module.eval()
+            module.train = disabled_train
+            for p in module.parameters():
+                p.requires_grad = False
+        return unused
+
+    @torch.no_grad()
+    def load_controlnet_from_ckpt(self, sd: Dict[str, torch.Tensor]) -> None:
+        self.controlnet.load_state_dict(sd, strict=True)
+
+    @torch.no_grad()
+    def load_controlnet_from_unet(self) -> Tuple[Set[str]]:
+        unet_sd = self.unet.state_dict()
+        scratch_sd = self.controlnet.state_dict()
+        init_sd = {}
+        init_with_new_zero = set()
+        init_with_scratch = set()
+        for key in scratch_sd:
+            if key in unet_sd:
+                this, target = scratch_sd[key], unet_sd[key]
+                if this.size() == target.size():
+                    init_sd[key] = target.clone()
+                else:
+                    d_ic = this.size(1) - target.size(1)
+                    oc, _, h, w = this.size()
+                    zeros = torch.zeros((oc, d_ic, h, w), dtype=target.dtype)
+                    init_sd[key] = torch.cat((target, zeros), dim=1)
+                    init_with_new_zero.add(key)
+            else:
+                init_sd[key] = scratch_sd[key].clone()
+                init_with_scratch.add(key)
+        self.controlnet.load_state_dict(init_sd, strict=True)
+        return init_with_new_zero, init_with_scratch
+
+    def vae_encode(self, image: torch.Tensor, sample: bool=True) -> torch.Tensor:
+        if sample:
+            return self.vae.encode(image).sample() * self.scale_factor
+        else:
+            return self.vae.encode(image).mode() * self.scale_factor
+
+    def vae_encode_tiled(self, image: torch.Tensor, tile_size: int, tile_stride: int, sample: bool=True) -> torch.Tensor:
+        bs, _, h, w = image.shape
+        z = torch.zeros((bs, 4, h // 8, w // 8), dtype=torch.float32, device=image.device)
+        count = torch.zeros_like(z, dtype=torch.float32)
+        weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None]
+        weights = torch.tensor(weights, dtype=torch.float32, device=image.device)
+        tiles = sliding_windows(h // 8, w // 8, tile_size // 8, tile_stride // 8)
+        for hi, hi_end, wi, wi_end in tiles:
+            tile_image = image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8]
+            z[:, :, hi:hi_end, wi:wi_end] += self.vae_encode(tile_image, sample=sample) * weights
+            count[:, :, hi:hi_end, wi:wi_end] += weights
+        z.div_(count)
+        return z
+
+    def vae_decode(self, z: torch.Tensor) -> torch.Tensor:
+        return self.vae.decode(z / self.scale_factor)
+
+    @count_vram_usage
+    def vae_decode_tiled(self, z: torch.Tensor, tile_size: int, tile_stride: int) -> torch.Tensor:
+        bs, _, h, w = z.shape
+        image = torch.zeros((bs, 3, h * 8, w * 8), dtype=torch.float32, device=z.device)
+        count = torch.zeros_like(image, dtype=torch.float32)
+        weights = gaussian_weights(tile_size * 8, tile_size * 8)[None, None]
+        weights = torch.tensor(weights, dtype=torch.float32, device=z.device)
+        tiles = sliding_windows(h, w, tile_size, tile_stride)
+        for hi, hi_end, wi, wi_end in tiles:
+            tile_z = z[:, :, hi:hi_end, wi:wi_end]
+            image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += self.vae_decode(tile_z) * weights
+            count[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += weights
+        image.div_(count)
+        return image
+
+    def prepare_condition(self, clean: torch.Tensor, txt: List[str]) -> Dict[str, torch.Tensor]:
+        return dict(
+            c_txt=self.clip.encode(txt),
+            c_img=self.vae_encode(clean * 2 - 1, sample=False)
+        )
+
+    @count_vram_usage
+    def prepare_condition_tiled(self, clean: torch.Tensor, txt: List[str], tile_size: int, tile_stride: int) -> Dict[str, torch.Tensor]:
+        return dict(
+            c_txt=self.clip.encode(txt),
+            c_img=self.vae_encode_tiled(clean * 2 - 1, tile_size, tile_stride, sample=False)
+        )
+
+    def forward(self, x_noisy, t, cond):
+        c_txt = cond["c_txt"]
+        c_img = cond["c_img"]
+        control = self.controlnet(
+            x=x_noisy, hint=c_img,
+            timesteps=t, context=c_txt
+        )
+        control = [c * scale for c, scale in zip(control, self.control_scales)]
+        eps = self.unet(
+            x=x_noisy, timesteps=t,
+            context=c_txt, control=control, only_mid_control=False
+        )
+        return eps
model/clip.py
ADDED
@@ -0,0 +1,65 @@
+from typing import List
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from model.open_clip import CLIP, tokenize
+
+### pretrained model path
+# _VITH14 = dict(
+#     laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
+# )
+
+class FrozenOpenCLIPEmbedder(nn.Module):
+    """
+    Uses the OpenCLIP transformer encoder for text
+    """
+    LAYERS = [
+        #"pooled",
+        "last",
+        "penultimate"
+    ]
+    def __init__(self, embed_dim, vision_cfg, text_cfg, layer="last"):
+        super().__init__()
+        assert layer in self.LAYERS
+        # model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+        model = CLIP(embed_dim, dict(vision_cfg), dict(text_cfg))
+        del model.visual
+        self.model = model
+
+        self.layer = layer
+        if self.layer == "last":
+            self.layer_idx = 0
+        elif self.layer == "penultimate":
+            self.layer_idx = 1
+        else:
+            raise NotImplementedError()
+
+    def forward(self, tokens):
+        z = self.encode_with_transformer(tokens)
+        return z
+
+    def encode_with_transformer(self, text):
+        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
+        x = x + self.model.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.model.ln_final(x)
+        return x
+
+    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+        for i, r in enumerate(self.model.transformer.resblocks):
+            if i == len(self.model.transformer.resblocks) - self.layer_idx:
+                break
+            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(r, x, attn_mask)
+            else:
+                x = r(x, attn_mask=attn_mask)
+        return x
+
+    def encode(self, text: List[str]) -> torch.Tensor:
+        # convert a batch of text to tensor
+        tokens = tokenize(text)
+        # move tensor to model device
+        tokens = tokens.to(next(self.model.parameters()).device)
+        return self(tokens)
model/config.py
ADDED
@@ -0,0 +1,62 @@
+import os
+from typing import Optional, Literal
+from types import ModuleType
+import enum
+from packaging import version
+
+import torch
+
+# collect system information
+if version.parse(torch.__version__) >= version.parse("2.0.0"):
+    SDP_IS_AVAILABLE = True
+else:
+    SDP_IS_AVAILABLE = False
+
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except:
+    XFORMERS_IS_AVAILBLE = False
+
+
+class AttnMode(enum.Enum):
+    SDP = 0
+    XFORMERS = 1
+    VANILLA = 2
+
+
+class Config:
+    xformers: Optional[ModuleType] = None
+    attn_mode: AttnMode = AttnMode.VANILLA
+
+
+# initialize attention mode
+if SDP_IS_AVAILABLE:
+    Config.attn_mode = AttnMode.SDP
+    print(f"use sdp attention as default")
+elif XFORMERS_IS_AVAILBLE:
+    Config.attn_mode = AttnMode.XFORMERS
+    print(f"use xformers attention as default")
+else:
+    print(f"both sdp attention and xformers are not available, use vanilla attention (very expensive) as default")
+
+if XFORMERS_IS_AVAILBLE:
+    Config.xformers = xformers
+
+
+# user-specified attention mode
+ATTN_MODE = os.environ.get("ATTN_MODE", None)
+if ATTN_MODE is not None:
+    assert ATTN_MODE in ["vanilla", "sdp", "xformers"]
+    if ATTN_MODE == "sdp":
+        assert SDP_IS_AVAILABLE
+        Config.attn_mode = AttnMode.SDP
+    elif ATTN_MODE == "xformers":
+        assert XFORMERS_IS_AVAILBLE
+        Config.attn_mode = AttnMode.XFORMERS
+    else:
+        Config.attn_mode = AttnMode.VANILLA
+    print(f"set attention mode to {ATTN_MODE}")
+else:
+    print("keep default attention mode")
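(Illustrative note, not part of the upload.) Because the module above selects the backend at import time, the ATTN_MODE override has to be placed in the environment before model.config is first imported; a minimal sketch:

import os
os.environ["ATTN_MODE"] = "vanilla"   # one of "vanilla", "sdp", "xformers"; set before the import below

from model.config import Config, AttnMode
assert Config.attn_mode is AttnMode.VANILLA   # illustrative check only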
model/controlnet.py
ADDED
@@ -0,0 +1,277 @@
+import torch
+import torch as th
+import torch.nn as nn
+
+from model.util import (
+    conv_nd,
+    linear,
+    zero_module,
+    timestep_embedding,
+    exists
+)
+from model.attention import SpatialTransformer
+from model.unet import (
+    TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock, UNetModel
+)
+
+
+class ControlledUnetModel(UNetModel):
+
+    def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
+        hs = []
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb, context)
+            hs.append(h)
+        h = self.middle_block(h, emb, context)
+
+        if control is not None:
+            h += control.pop()
+
+        for i, module in enumerate(self.output_blocks):
+            if only_mid_control or control is None:
+                h = torch.cat([h, hs.pop()], dim=1)
+            else:
+                h = torch.cat([h, hs.pop() + control.pop()], dim=1)
+            h = module(h, emb, context)
+
+        h = h.type(x.dtype)
+        return self.out(h)
+
+
+class ControlNet(nn.Module):
+
+    def __init__(
+        self,
+        image_size,
+        in_channels,
+        model_channels,
+        hint_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=-1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+        use_spatial_transformer=False,  # custom transformer support
+        transformer_depth=1,  # custom transformer support
+        context_dim=None,  # custom transformer support
+        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
+        legacy=True,
+        disable_self_attentions=None,
+        num_attention_blocks=None,
+        disable_middle_self_attn=False,
+        use_linear_in_transformer=False,
+    ):
+        super().__init__()
+        if use_spatial_transformer:
+            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+        if context_dim is not None:
+            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+            from omegaconf.listconfig import ListConfig
+            if type(context_dim) == ListConfig:
+                context_dim = list(context_dim)
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        if num_heads == -1:
+            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+        if num_head_channels == -1:
+            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+        self.dims = dims
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")
+
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels + hint_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        # num_heads = 1
+                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self.zero_convs.append(self.make_zero_conv(ch))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                self.zero_convs.append(self.make_zero_conv(ch))
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            # num_heads = 1
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=dim_head,
+                use_new_attention_order=use_new_attention_order,
+            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                use_checkpoint=use_checkpoint
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self.middle_block_out = self.make_zero_conv(ch)
+        self._feature_size += ch
+
+    def make_zero_conv(self, channels):
+        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+
+    def forward(self, x, hint, timesteps, context, **kwargs):
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+        x = torch.cat((x, hint), dim=1)
+        outs = []
+
+        h = x.type(self.dtype)
+        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+            h = module(h, emb, context)
+            outs.append(zero_conv(h, emb, context))
+
+        h = self.middle_block(h, emb, context)
+        outs.append(self.middle_block_out(h, emb, context))
+
+        return outs
model/distributions.py
ADDED
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+    def sample(self):
+        raise NotImplementedError()
+
+    def mode(self):
+        raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+    def __init__(self, value):
+        self.value = value
+
+    def sample(self):
+        return self.value
+
+    def mode(self):
+        return self.value
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters, deterministic=False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+    def sample(self):
+        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+        return x
+
+    def kl(self, other=None):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        else:
+            if other is None:
+                return 0.5 * torch.sum(torch.pow(self.mean, 2)
+                                       + self.var - 1.0 - self.logvar,
+                                       dim=[1, 2, 3])
+            else:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
+                    dim=[1, 2, 3])
+
+    def nll(self, sample, dims=[1,2,3]):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims)
+
+    def mode(self):
+        return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    """
+    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+    Compute the KL divergence between two gaussians.
+    Shapes are automatically broadcasted, so batches can be compared to
+    scalars, among other use cases.
+    """
+    tensor = None
+    for obj in (mean1, logvar1, mean2, logvar2):
+        if isinstance(obj, torch.Tensor):
+            tensor = obj
+            break
+    assert tensor is not None, "at least one argument must be a Tensor"
+
+    # Force variances to be Tensors. Broadcasting helps convert scalars to
+    # Tensors, but it does not work for torch.exp().
+    logvar1, logvar2 = [
+        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+        for x in (logvar1, logvar2)
+    ]
+
+    return 0.5 * (
+        -1.0
+        + logvar2
+        - logvar1
+        + torch.exp(logvar1 - logvar2)
+        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+    )
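(Illustrative note, not part of the upload.) A short sketch of DiagonalGaussianDistribution: the parameter tensor is split in half along dim=1 into mean and log-variance, so the latent has half the input channels; the shapes below are assumptions.

import torch
from model.distributions import DiagonalGaussianDistribution

params = torch.randn(1, 8, 32, 32)           # [mean (4 ch) | logvar (4 ch)] along dim=1
dist = DiagonalGaussianDistribution(params)
z = dist.sample()                            # stochastic latent, shape (1, 4, 32, 32)
z_mode = dist.mode()                         # deterministic latent (the mean)
kl = dist.kl()                               # KL to a standard normal, shape (1,)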
model/gaussian_diffusion.py
ADDED
@@ -0,0 +1,118 @@
+from functools import partial
+from typing import Tuple
+
+import torch
+from torch import nn
+import numpy as np
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+    if schedule == "linear":
+        betas = (
+            np.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=np.float64) ** 2
+        )
+
+    elif schedule == "cosine":
+        timesteps = (
+            np.arange(n_timestep + 1, dtype=np.float64) / n_timestep + cosine_s
+        )
+        alphas = timesteps / (1 + cosine_s) * np.pi / 2
+        alphas = np.cos(alphas).pow(2)
+        alphas = alphas / alphas[0]
+        betas = 1 - alphas[1:] / alphas[:-1]
+        betas = np.clip(betas, a_min=0, a_max=0.999)
+
+    elif schedule == "sqrt_linear":
+        betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
+    elif schedule == "sqrt":
+        betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) ** 0.5
+    else:
+        raise ValueError(f"schedule '{schedule}' unknown.")
+    return betas
+
+
+def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int]) -> torch.Tensor:
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class Diffusion(nn.Module):
+
+    def __init__(
+        self,
+        timesteps=1000,
+        beta_schedule="linear",
+        loss_type="l2",
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+        parameterization="eps"
+    ):
+        super().__init__()
+        self.num_timesteps = timesteps
+        self.beta_schedule = beta_schedule
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        self.cosine_s = cosine_s
+        assert parameterization in ["eps", "x0", "v"], "currently only supporting 'eps' and 'x0' and 'v'"
+        self.parameterization = parameterization
+        self.loss_type = loss_type
+
+        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                   cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
+        sqrt_one_minus_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
+
+        self.betas = betas
+        self.register("sqrt_alphas_cumprod", sqrt_alphas_cumprod)
+        self.register("sqrt_one_minus_alphas_cumprod", sqrt_one_minus_alphas_cumprod)
+
+    def register(self, name: str, value: np.ndarray) -> None:
+        self.register_buffer(name, torch.tensor(value, dtype=torch.float32))
+
+    def q_sample(self, x_start, t, noise):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def get_v(self, x, noise, t):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+        )
+
+    def get_loss(self, pred, target, mean=True):
+        if self.loss_type == 'l1':
+            loss = (target - pred).abs()
+            if mean:
+                loss = loss.mean()
+        elif self.loss_type == 'l2':
+            if mean:
+                loss = torch.nn.functional.mse_loss(target, pred)
+            else:
+                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+        else:
+            raise NotImplementedError("unknown loss type '{loss_type}'")
+
+        return loss
+
+    def p_losses(self, model, x_start, t, cond):
+        noise = torch.randn_like(x_start)
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_output = model(x_noisy, t, cond)
+
+        if self.parameterization == "x0":
+            target = x_start
+        elif self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "v":
+            target = self.get_v(x_start, noise, t)
+        else:
+            raise NotImplementedError()
+
+        loss_simple = self.get_loss(model_output, target, mean=False).mean()
+        return loss_simple
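(Illustrative note, not part of the upload.) A rough sketch of the training path through Diffusion.q_sample and p_losses with the "eps" parameterization; the latent shape and the stand-in denoiser are assumptions (in this repo the model argument is presumably a ControlLDM, whose forward has the matching (x_noisy, t, cond) signature).

import torch
from model.gaussian_diffusion import Diffusion

diffusion = Diffusion(timesteps=1000, parameterization="eps")
x_start = torch.randn(2, 4, 64, 64)                     # clean latents (assumed shape)
t = torch.randint(0, diffusion.num_timesteps, (2,))     # one random timestep per sample

def dummy_model(x_noisy, t, cond):
    # stand-in for ControlLDM(x_noisy, t, cond); predicts zeros instead of the added noise
    return torch.zeros_like(x_noisy)

loss = diffusion.p_losses(dummy_model, x_start, t, cond=None)   # scalar L2 loss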
model/open_clip/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .model import CLIP
+from .tokenizer import tokenize
+
+__all__ = ["CLIP", "tokenize"]
model/open_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
model/open_clip/model.py
ADDED
@@ -0,0 +1,206 @@
+""" CLIP Model
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .transformer import LayerNormFp32, LayerNorm, QuickGELU, VisionTransformer, TextTransformer
+
+
+@dataclass
+class CLIPVisionCfg:
+    layers: Union[Tuple[int, int, int, int], int] = 12
+    width: int = 768
+    head_width: int = 64
+    mlp_ratio: float = 4.0
+    patch_size: int = 16
+    image_size: Union[Tuple[int, int], int] = 224
+
+    ls_init_value: Optional[float] = None  # layer scale initial value
+    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
+    input_patchnorm: bool = False  # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
+    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
+    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer
+    n_queries: int = 256  # n_queries for attentional pooler
+    attn_pooler_heads: int = 8  # n heads for attentional_pooling
+    output_tokens: bool = False
+
+    timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
+    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
+    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
+    timm_proj_bias: bool = False  # enable bias final projection
+    timm_drop: float = 0.  # head dropout
+    timm_drop_path: Optional[float] = None  # backbone stochastic depth
+
+
+@dataclass
+class CLIPTextCfg:
+    context_length: int = 77
+    vocab_size: int = 49408
+    width: int = 512
+    heads: int = 8
+    layers: int = 12
+    ls_init_value: Optional[float] = None  # layer scale initial value
+    hf_model_name: str = None
+    hf_tokenizer_name: str = None
+    hf_model_pretrained: bool = True
+    proj: str = 'mlp'
+    pooler_type: str = 'mean_pooler'
+    embed_cls: bool = False
+    pad_id: int = 0
+    output_tokens: bool = False
+
+
+def get_cast_dtype(precision: str):
+    cast_dtype = None
+    if precision == 'bf16':
+        cast_dtype = torch.bfloat16
+    elif precision == 'fp16':
+        cast_dtype = torch.float16
+    return cast_dtype
+
+
+def _build_vision_tower(
+    embed_dim: int,
+    vision_cfg: CLIPVisionCfg,
+    quick_gelu: bool = False,
+    cast_dtype: Optional[torch.dtype] = None
+):
+    if isinstance(vision_cfg, dict):
+        vision_cfg = CLIPVisionCfg(**vision_cfg)
+
+    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
+    # memory efficient in recent PyTorch releases (>= 1.10).
+    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
+    act_layer = QuickGELU if quick_gelu else nn.GELU
+
+    vision_heads = vision_cfg.width // vision_cfg.head_width
+    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+    visual = VisionTransformer(
+        image_size=vision_cfg.image_size,
+        patch_size=vision_cfg.patch_size,
+        width=vision_cfg.width,
+        layers=vision_cfg.layers,
+        heads=vision_heads,
+        mlp_ratio=vision_cfg.mlp_ratio,
+        ls_init_value=vision_cfg.ls_init_value,
+        patch_dropout=vision_cfg.patch_dropout,
+        input_patchnorm=vision_cfg.input_patchnorm,
+        global_average_pool=vision_cfg.global_average_pool,
+        attentional_pool=vision_cfg.attentional_pool,
+        n_queries=vision_cfg.n_queries,
+        attn_pooler_heads=vision_cfg.attn_pooler_heads,
+        output_tokens=vision_cfg.output_tokens,
+        output_dim=embed_dim,
+        act_layer=act_layer,
+        norm_layer=norm_layer,
+    )
+
+    return visual
+
+
+def _build_text_tower(
+    embed_dim: int,
+    text_cfg: CLIPTextCfg,
+    quick_gelu: bool = False,
+    cast_dtype: Optional[torch.dtype] = None,
+):
+    if isinstance(text_cfg, dict):
+        text_cfg = CLIPTextCfg(**text_cfg)
+
+    act_layer = QuickGELU if quick_gelu else nn.GELU
+    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+
+    text = TextTransformer(
+        context_length=text_cfg.context_length,
+        vocab_size=text_cfg.vocab_size,
+        width=text_cfg.width,
+        heads=text_cfg.heads,
+        layers=text_cfg.layers,
+        ls_init_value=text_cfg.ls_init_value,
+        output_dim=embed_dim,
+        embed_cls=text_cfg.embed_cls,
+        output_tokens=text_cfg.output_tokens,
+        pad_id=text_cfg.pad_id,
+        act_layer=act_layer,
+        norm_layer=norm_layer,
+    )
+    return text
+
+
+class CLIP(nn.Module):
+    output_dict: torch.jit.Final[bool]
+
+    def __init__(
+        self,
+        embed_dim: int,
+        vision_cfg: CLIPVisionCfg,
+        text_cfg: CLIPTextCfg,
+        quick_gelu: bool = False,
+        cast_dtype: Optional[torch.dtype] = None,
+        output_dict: bool = False,
+    ):
+        super().__init__()
+        self.output_dict = output_dict
+        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+
+        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+        self.transformer = text.transformer
+        self.context_length = text.context_length
+        self.vocab_size = text.vocab_size
+        self.token_embedding = text.token_embedding
+        self.positional_embedding = text.positional_embedding
+        self.ln_final = text.ln_final
+        self.text_projection = text.text_projection
+        self.register_buffer('attn_mask', text.attn_mask, persistent=False)
+
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.visual.set_grad_checkpointing(enable)
+        self.transformer.grad_checkpointing = enable
+
+    def encode_image(self, image, normalize: bool = False):
|
176 |
+
features = self.visual(image)
|
177 |
+
return F.normalize(features, dim=-1) if normalize else features
|
178 |
+
|
179 |
+
def encode_text(self, text, normalize: bool = False):
|
180 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
181 |
+
|
182 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
183 |
+
|
184 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
185 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
186 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
187 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
188 |
+
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
189 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
190 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
191 |
+
return F.normalize(x, dim=-1) if normalize else x
|
192 |
+
|
193 |
+
def forward(
|
194 |
+
self,
|
195 |
+
image: Optional[torch.Tensor] = None,
|
196 |
+
text: Optional[torch.Tensor] = None,
|
197 |
+
):
|
198 |
+
image_features = self.encode_image(image, normalize=True) if image is not None else None
|
199 |
+
text_features = self.encode_text(text, normalize=True) if text is not None else None
|
200 |
+
if self.output_dict:
|
201 |
+
return {
|
202 |
+
"image_features": image_features,
|
203 |
+
"text_features": text_features,
|
204 |
+
"logit_scale": self.logit_scale.exp()
|
205 |
+
}
|
206 |
+
return image_features, text_features, self.logit_scale.exp()
|
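Usage sketch for the `CLIP` class above, which wires a `VisionTransformer` and a `TextTransformer` together from the two config dataclasses. This is a minimal shape check with tiny, randomly initialized towers; the import path `model.open_clip.model` is assumed from this repo's layout and is not shown in the commit itself.

```python
import torch
from model.open_clip.model import CLIP, CLIPVisionCfg, CLIPTextCfg  # import path assumed from this repo layout

# Deliberately tiny configs so the check runs quickly on CPU with random weights.
vision_cfg = CLIPVisionCfg(layers=2, width=64, head_width=32, patch_size=16, image_size=64)
text_cfg = CLIPTextCfg(context_length=16, vocab_size=1000, width=64, heads=4, layers=2)
clip = CLIP(embed_dim=32, vision_cfg=vision_cfg, text_cfg=text_cfg).eval()

image = torch.randn(1, 3, 64, 64)        # [batch, channels, H, W]
text = torch.randint(0, 1000, (1, 16))   # [batch, context_length] token ids

with torch.no_grad():
    image_features, text_features, logit_scale = clip(image, text)
print(image_features.shape, text_features.shape)   # torch.Size([1, 32]) torch.Size([1, 32])
```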
model/open_clip/tokenizer.py
ADDED
@@ -0,0 +1,214 @@
1 |
+
""" CLIP tokenizer
|
2 |
+
|
3 |
+
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
import gzip
|
6 |
+
import html
|
7 |
+
import os
|
8 |
+
from functools import lru_cache
|
9 |
+
from typing import Union, List
|
10 |
+
|
11 |
+
import ftfy
|
12 |
+
import regex as re
|
13 |
+
import torch
|
14 |
+
|
15 |
+
# https://stackoverflow.com/q/62691279
|
16 |
+
import os
|
17 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
18 |
+
|
19 |
+
|
20 |
+
@lru_cache()
|
21 |
+
def default_bpe():
|
22 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
23 |
+
|
24 |
+
|
25 |
+
@lru_cache()
|
26 |
+
def bytes_to_unicode():
|
27 |
+
"""
|
28 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
29 |
+
The reversible bpe codes work on unicode strings.
|
30 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
31 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
32 |
+
This is a significant percentage of your normal, say, 32K bpe vocab.
|
33 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
34 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
35 |
+
"""
|
36 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
37 |
+
cs = bs[:]
|
38 |
+
n = 0
|
39 |
+
for b in range(2**8):
|
40 |
+
if b not in bs:
|
41 |
+
bs.append(b)
|
42 |
+
cs.append(2**8+n)
|
43 |
+
n += 1
|
44 |
+
cs = [chr(n) for n in cs]
|
45 |
+
return dict(zip(bs, cs))
|
46 |
+
|
47 |
+
|
48 |
+
def get_pairs(word):
|
49 |
+
"""Return set of symbol pairs in a word.
|
50 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
51 |
+
"""
|
52 |
+
pairs = set()
|
53 |
+
prev_char = word[0]
|
54 |
+
for char in word[1:]:
|
55 |
+
pairs.add((prev_char, char))
|
56 |
+
prev_char = char
|
57 |
+
return pairs
|
58 |
+
|
59 |
+
|
60 |
+
def basic_clean(text):
|
61 |
+
text = ftfy.fix_text(text)
|
62 |
+
text = html.unescape(html.unescape(text))
|
63 |
+
return text.strip()
|
64 |
+
|
65 |
+
|
66 |
+
def whitespace_clean(text):
|
67 |
+
text = re.sub(r'\s+', ' ', text)
|
68 |
+
text = text.strip()
|
69 |
+
return text
|
70 |
+
|
71 |
+
|
72 |
+
class SimpleTokenizer(object):
|
73 |
+
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
|
74 |
+
self.byte_encoder = bytes_to_unicode()
|
75 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
76 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
77 |
+
merges = merges[1:49152-256-2+1]
|
78 |
+
merges = [tuple(merge.split()) for merge in merges]
|
79 |
+
vocab = list(bytes_to_unicode().values())
|
80 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
81 |
+
for merge in merges:
|
82 |
+
vocab.append(''.join(merge))
|
83 |
+
if not special_tokens:
|
84 |
+
special_tokens = ['<start_of_text>', '<end_of_text>']
|
85 |
+
else:
|
86 |
+
special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
|
87 |
+
vocab.extend(special_tokens)
|
88 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
89 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
90 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
91 |
+
self.cache = {t:t for t in special_tokens}
|
92 |
+
special = "|".join(special_tokens)
|
93 |
+
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
94 |
+
|
95 |
+
self.vocab_size = len(self.encoder)
|
96 |
+
self.all_special_ids = [self.encoder[t] for t in special_tokens]
|
97 |
+
|
98 |
+
def bpe(self, token):
|
99 |
+
if token in self.cache:
|
100 |
+
return self.cache[token]
|
101 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
102 |
+
pairs = get_pairs(word)
|
103 |
+
|
104 |
+
if not pairs:
|
105 |
+
return token+'</w>'
|
106 |
+
|
107 |
+
while True:
|
108 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
109 |
+
if bigram not in self.bpe_ranks:
|
110 |
+
break
|
111 |
+
first, second = bigram
|
112 |
+
new_word = []
|
113 |
+
i = 0
|
114 |
+
while i < len(word):
|
115 |
+
try:
|
116 |
+
j = word.index(first, i)
|
117 |
+
new_word.extend(word[i:j])
|
118 |
+
i = j
|
119 |
+
except ValueError:
|
120 |
+
new_word.extend(word[i:])
|
121 |
+
break
|
122 |
+
|
123 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
124 |
+
new_word.append(first+second)
|
125 |
+
i += 2
|
126 |
+
else:
|
127 |
+
new_word.append(word[i])
|
128 |
+
i += 1
|
129 |
+
new_word = tuple(new_word)
|
130 |
+
word = new_word
|
131 |
+
if len(word) == 1:
|
132 |
+
break
|
133 |
+
else:
|
134 |
+
pairs = get_pairs(word)
|
135 |
+
word = ' '.join(word)
|
136 |
+
self.cache[token] = word
|
137 |
+
return word
|
138 |
+
|
139 |
+
def encode(self, text):
|
140 |
+
bpe_tokens = []
|
141 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
142 |
+
for token in re.findall(self.pat, text):
|
143 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
144 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
145 |
+
return bpe_tokens
|
146 |
+
|
147 |
+
def decode(self, tokens):
|
148 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
149 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
150 |
+
return text
|
151 |
+
|
152 |
+
|
153 |
+
_tokenizer = SimpleTokenizer()
|
154 |
+
|
155 |
+
def decode(output_ids: torch.Tensor):
|
156 |
+
output_ids = output_ids.cpu().numpy()
|
157 |
+
return _tokenizer.decode(output_ids)
|
158 |
+
|
159 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
160 |
+
"""
|
161 |
+
Returns the tokenized representation of given input string(s)
|
162 |
+
|
163 |
+
Parameters
|
164 |
+
----------
|
165 |
+
texts : Union[str, List[str]]
|
166 |
+
An input string or a list of input strings to tokenize
|
167 |
+
context_length : int
|
168 |
+
The context length to use; all CLIP models use 77 as the context length
|
169 |
+
|
170 |
+
Returns
|
171 |
+
-------
|
172 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
173 |
+
"""
|
174 |
+
if isinstance(texts, str):
|
175 |
+
texts = [texts]
|
176 |
+
|
177 |
+
sot_token = _tokenizer.encoder["<start_of_text>"]
|
178 |
+
eot_token = _tokenizer.encoder["<end_of_text>"]
|
179 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
180 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
181 |
+
|
182 |
+
for i, tokens in enumerate(all_tokens):
|
183 |
+
if len(tokens) > context_length:
|
184 |
+
tokens = tokens[:context_length] # Truncate
|
185 |
+
tokens[-1] = eot_token
|
186 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
187 |
+
|
188 |
+
return result
|
189 |
+
|
190 |
+
|
191 |
+
class HFTokenizer:
|
192 |
+
"""HuggingFace tokenizer wrapper"""
|
193 |
+
|
194 |
+
def __init__(self, tokenizer_name: str):
|
195 |
+
from transformers import AutoTokenizer
|
196 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
197 |
+
|
198 |
+
def save_pretrained(self, dest):
|
199 |
+
self.tokenizer.save_pretrained(dest)
|
200 |
+
|
201 |
+
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
|
202 |
+
# same cleaning as for default tokenizer, except lowercasing
|
203 |
+
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
|
204 |
+
if isinstance(texts, str):
|
205 |
+
texts = [texts]
|
206 |
+
texts = [whitespace_clean(basic_clean(text)) for text in texts]
|
207 |
+
input_ids = self.tokenizer(
|
208 |
+
texts,
|
209 |
+
return_tensors='pt',
|
210 |
+
max_length=context_length,
|
211 |
+
padding='max_length',
|
212 |
+
truncation=True,
|
213 |
+
).input_ids
|
214 |
+
return input_ids
|
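Usage sketch for the tokenizer above: `tokenize` brackets each caption with `<start_of_text>`/`<end_of_text>`, zero-pads to `context_length`, and truncates if needed. A minimal round trip through the module-level `_tokenizer`, assuming `ftfy` and `regex` are installed and the gzipped BPE vocab sits next to this file; the import path is assumed from this repo's layout.

```python
from model.open_clip.tokenizer import tokenize, _tokenizer  # import path assumed from this repo layout

ids = tokenize(["a photo of a cat", "a diagram"], context_length=77)
print(ids.shape)   # torch.Size([2, 77]); rows are zero-padded after <end_of_text>

# Decode one row, dropping padding (id 0) and the special start/end tokens.
row = [t for t in ids[0].tolist() if t and t not in _tokenizer.all_special_ids]
print(_tokenizer.decode(row))   # 'a photo of a cat ' (each </w> marker becomes a space)
```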
model/open_clip/transformer.py
ADDED
@@ -0,0 +1,736 @@
1 |
+
import collections
|
2 |
+
from collections import OrderedDict
|
3 |
+
import math
|
4 |
+
from typing import Callable, Optional, Sequence, Tuple
|
5 |
+
from itertools import repeat
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torch.utils.checkpoint import checkpoint
|
11 |
+
|
12 |
+
# From PyTorch internals
|
13 |
+
def _ntuple(n):
|
14 |
+
def parse(x):
|
15 |
+
if isinstance(x, collections.abc.Iterable):
|
16 |
+
return x
|
17 |
+
return tuple(repeat(x, n))
|
18 |
+
return parse
|
19 |
+
|
20 |
+
to_2tuple = _ntuple(2)
|
21 |
+
|
22 |
+
|
23 |
+
class LayerNormFp32(nn.LayerNorm):
|
24 |
+
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
25 |
+
|
26 |
+
def forward(self, x: torch.Tensor):
|
27 |
+
orig_type = x.dtype
|
28 |
+
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
|
29 |
+
return x.to(orig_type)
|
30 |
+
|
31 |
+
|
32 |
+
class LayerNorm(nn.LayerNorm):
|
33 |
+
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
34 |
+
|
35 |
+
def forward(self, x: torch.Tensor):
|
36 |
+
orig_type = x.dtype
|
37 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
38 |
+
return x.to(orig_type)
|
39 |
+
|
40 |
+
|
41 |
+
class QuickGELU(nn.Module):
|
42 |
+
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
43 |
+
def forward(self, x: torch.Tensor):
|
44 |
+
return x * torch.sigmoid(1.702 * x)
|
45 |
+
|
46 |
+
|
47 |
+
class LayerScale(nn.Module):
|
48 |
+
def __init__(self, dim, init_values=1e-5, inplace=False):
|
49 |
+
super().__init__()
|
50 |
+
self.inplace = inplace
|
51 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
55 |
+
|
56 |
+
|
57 |
+
class PatchDropout(nn.Module):
|
58 |
+
"""
|
59 |
+
https://arxiv.org/abs/2212.00794
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, prob, exclude_first_token=True):
|
63 |
+
super().__init__()
|
64 |
+
assert 0 <= prob < 1.
|
65 |
+
self.prob = prob
|
66 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
if not self.training or self.prob == 0.:
|
70 |
+
return x
|
71 |
+
|
72 |
+
if self.exclude_first_token:
|
73 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
74 |
+
else:
|
75 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
76 |
+
|
77 |
+
batch = x.size()[0]
|
78 |
+
num_tokens = x.size()[1]
|
79 |
+
|
80 |
+
batch_indices = torch.arange(batch)
|
81 |
+
batch_indices = batch_indices[..., None]
|
82 |
+
|
83 |
+
keep_prob = 1 - self.prob
|
84 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
85 |
+
|
86 |
+
rand = torch.randn(batch, num_tokens)
|
87 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
88 |
+
|
89 |
+
x = x[batch_indices, patch_indices_keep]
|
90 |
+
|
91 |
+
if self.exclude_first_token:
|
92 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
93 |
+
|
94 |
+
return x
|
95 |
+
|
96 |
+
|
97 |
+
class Attention(nn.Module):
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
dim,
|
101 |
+
num_heads=8,
|
102 |
+
qkv_bias=True,
|
103 |
+
scaled_cosine=False,
|
104 |
+
scale_heads=False,
|
105 |
+
logit_scale_max=math.log(1. / 0.01),
|
106 |
+
attn_drop=0.,
|
107 |
+
proj_drop=0.
|
108 |
+
):
|
109 |
+
super().__init__()
|
110 |
+
self.scaled_cosine = scaled_cosine
|
111 |
+
self.scale_heads = scale_heads
|
112 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
113 |
+
self.num_heads = num_heads
|
114 |
+
self.head_dim = dim // num_heads
|
115 |
+
self.scale = self.head_dim ** -0.5
|
116 |
+
self.logit_scale_max = logit_scale_max
|
117 |
+
|
118 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
119 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
120 |
+
if qkv_bias:
|
121 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
122 |
+
else:
|
123 |
+
self.in_proj_bias = None
|
124 |
+
|
125 |
+
if self.scaled_cosine:
|
126 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
127 |
+
else:
|
128 |
+
self.logit_scale = None
|
129 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
130 |
+
if self.scale_heads:
|
131 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
132 |
+
else:
|
133 |
+
self.head_scale = None
|
134 |
+
self.out_proj = nn.Linear(dim, dim)
|
135 |
+
self.out_drop = nn.Dropout(proj_drop)
|
136 |
+
|
137 |
+
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
138 |
+
L, N, C = x.shape
|
139 |
+
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
|
140 |
+
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
141 |
+
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
142 |
+
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
143 |
+
|
144 |
+
if self.logit_scale is not None:
|
145 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
146 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
147 |
+
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
148 |
+
attn = attn.view(-1, L, L)
|
149 |
+
else:
|
150 |
+
q = q * self.scale
|
151 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
152 |
+
|
153 |
+
if attn_mask is not None:
|
154 |
+
if attn_mask.dtype == torch.bool:
|
155 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
156 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
157 |
+
attn_mask = new_attn_mask
|
158 |
+
attn += attn_mask
|
159 |
+
|
160 |
+
attn = attn.softmax(dim=-1)
|
161 |
+
attn = self.attn_drop(attn)
|
162 |
+
|
163 |
+
x = torch.bmm(attn, v)
|
164 |
+
if self.head_scale is not None:
|
165 |
+
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
166 |
+
x = x.view(-1, L, C)
|
167 |
+
x = x.transpose(0, 1).reshape(L, N, C)
|
168 |
+
x = self.out_proj(x)
|
169 |
+
x = self.out_drop(x)
|
170 |
+
return x
|
171 |
+
|
172 |
+
|
173 |
+
class AttentionalPooler(nn.Module):
|
174 |
+
def __init__(
|
175 |
+
self,
|
176 |
+
d_model: int,
|
177 |
+
context_dim: int,
|
178 |
+
n_head: int = 8,
|
179 |
+
n_queries: int = 256,
|
180 |
+
norm_layer: Callable = LayerNorm
|
181 |
+
):
|
182 |
+
super().__init__()
|
183 |
+
self.query = nn.Parameter(torch.randn(n_queries, d_model))
|
184 |
+
self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
|
185 |
+
self.ln_q = norm_layer(d_model)
|
186 |
+
self.ln_k = norm_layer(context_dim)
|
187 |
+
|
188 |
+
def forward(self, x: torch.Tensor):
|
189 |
+
x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
|
190 |
+
N = x.shape[1]
|
191 |
+
q = self.ln_q(self.query)
|
192 |
+
out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
|
193 |
+
return out.permute(1, 0, 2) # LND -> NLD
|
194 |
+
|
195 |
+
def _repeat(self, query, N: int):
|
196 |
+
return query.unsqueeze(1).repeat(1, N, 1)
|
197 |
+
|
198 |
+
|
199 |
+
class ResidualAttentionBlock(nn.Module):
|
200 |
+
def __init__(
|
201 |
+
self,
|
202 |
+
d_model: int,
|
203 |
+
n_head: int,
|
204 |
+
mlp_ratio: float = 4.0,
|
205 |
+
ls_init_value: float = None,
|
206 |
+
act_layer: Callable = nn.GELU,
|
207 |
+
norm_layer: Callable = LayerNorm,
|
208 |
+
is_cross_attention: bool = False,
|
209 |
+
):
|
210 |
+
super().__init__()
|
211 |
+
|
212 |
+
self.ln_1 = norm_layer(d_model)
|
213 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
214 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
215 |
+
if is_cross_attention:
|
216 |
+
self.ln_1_kv = norm_layer(d_model)
|
217 |
+
|
218 |
+
self.ln_2 = norm_layer(d_model)
|
219 |
+
mlp_width = int(d_model * mlp_ratio)
|
220 |
+
self.mlp = nn.Sequential(OrderedDict([
|
221 |
+
("c_fc", nn.Linear(d_model, mlp_width)),
|
222 |
+
("gelu", act_layer()),
|
223 |
+
("c_proj", nn.Linear(mlp_width, d_model))
|
224 |
+
]))
|
225 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
226 |
+
|
227 |
+
def attention(
|
228 |
+
self,
|
229 |
+
q_x: torch.Tensor,
|
230 |
+
k_x: Optional[torch.Tensor] = None,
|
231 |
+
v_x: Optional[torch.Tensor] = None,
|
232 |
+
attn_mask: Optional[torch.Tensor] = None,
|
233 |
+
):
|
234 |
+
k_x = k_x if k_x is not None else q_x
|
235 |
+
v_x = v_x if v_x is not None else q_x
|
236 |
+
|
237 |
+
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
|
238 |
+
return self.attn(
|
239 |
+
q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
|
240 |
+
)[0]
|
241 |
+
|
242 |
+
def forward(
|
243 |
+
self,
|
244 |
+
q_x: torch.Tensor,
|
245 |
+
k_x: Optional[torch.Tensor] = None,
|
246 |
+
v_x: Optional[torch.Tensor] = None,
|
247 |
+
attn_mask: Optional[torch.Tensor] = None,
|
248 |
+
):
|
249 |
+
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
|
250 |
+
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
|
251 |
+
|
252 |
+
x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
|
253 |
+
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
254 |
+
return x
|
255 |
+
|
256 |
+
|
257 |
+
class CustomResidualAttentionBlock(nn.Module):
|
258 |
+
def __init__(
|
259 |
+
self,
|
260 |
+
d_model: int,
|
261 |
+
n_head: int,
|
262 |
+
mlp_ratio: float = 4.0,
|
263 |
+
ls_init_value: float = None,
|
264 |
+
act_layer: Callable = nn.GELU,
|
265 |
+
norm_layer: Callable = LayerNorm,
|
266 |
+
scale_cosine_attn: bool = False,
|
267 |
+
scale_heads: bool = False,
|
268 |
+
scale_attn: bool = False,
|
269 |
+
scale_fc: bool = False,
|
270 |
+
):
|
271 |
+
super().__init__()
|
272 |
+
|
273 |
+
self.ln_1 = norm_layer(d_model)
|
274 |
+
self.attn = Attention(
|
275 |
+
d_model, n_head,
|
276 |
+
scaled_cosine=scale_cosine_attn,
|
277 |
+
scale_heads=scale_heads,
|
278 |
+
)
|
279 |
+
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
|
280 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
281 |
+
|
282 |
+
self.ln_2 = norm_layer(d_model)
|
283 |
+
mlp_width = int(d_model * mlp_ratio)
|
284 |
+
self.mlp = nn.Sequential(OrderedDict([
|
285 |
+
("c_fc", nn.Linear(d_model, mlp_width)),
|
286 |
+
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
|
287 |
+
("gelu", act_layer()),
|
288 |
+
("c_proj", nn.Linear(mlp_width, d_model))
|
289 |
+
]))
|
290 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
291 |
+
|
292 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
293 |
+
x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
|
294 |
+
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
295 |
+
return x
|
296 |
+
|
297 |
+
|
298 |
+
class Transformer(nn.Module):
|
299 |
+
def __init__(
|
300 |
+
self,
|
301 |
+
width: int,
|
302 |
+
layers: int,
|
303 |
+
heads: int,
|
304 |
+
mlp_ratio: float = 4.0,
|
305 |
+
ls_init_value: float = None,
|
306 |
+
act_layer: Callable = nn.GELU,
|
307 |
+
norm_layer: Callable = LayerNorm,
|
308 |
+
):
|
309 |
+
super().__init__()
|
310 |
+
self.width = width
|
311 |
+
self.layers = layers
|
312 |
+
self.grad_checkpointing = False
|
313 |
+
|
314 |
+
self.resblocks = nn.ModuleList([
|
315 |
+
ResidualAttentionBlock(
|
316 |
+
width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
|
317 |
+
for _ in range(layers)
|
318 |
+
])
|
319 |
+
|
320 |
+
def get_cast_dtype(self) -> torch.dtype:
|
321 |
+
if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
|
322 |
+
return self.resblocks[0].mlp.c_fc.int8_original_dtype
|
323 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
324 |
+
|
325 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
326 |
+
for r in self.resblocks:
|
327 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
328 |
+
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
|
329 |
+
x = checkpoint(r, x, None, None, attn_mask)
|
330 |
+
else:
|
331 |
+
x = r(x, attn_mask=attn_mask)
|
332 |
+
return x
|
333 |
+
|
334 |
+
|
335 |
+
class VisionTransformer(nn.Module):
|
336 |
+
output_tokens: torch.jit.Final[bool]
|
337 |
+
|
338 |
+
def __init__(
|
339 |
+
self,
|
340 |
+
image_size: int,
|
341 |
+
patch_size: int,
|
342 |
+
width: int,
|
343 |
+
layers: int,
|
344 |
+
heads: int,
|
345 |
+
mlp_ratio: float,
|
346 |
+
ls_init_value: float = None,
|
347 |
+
global_average_pool: bool = False,
|
348 |
+
attentional_pool: bool = False,
|
349 |
+
n_queries: int = 256,
|
350 |
+
attn_pooler_heads: int = 8,
|
351 |
+
output_dim: int = 512,
|
352 |
+
patch_dropout: float = 0.,
|
353 |
+
input_patchnorm: bool = False,
|
354 |
+
act_layer: Callable = nn.GELU,
|
355 |
+
norm_layer: Callable = LayerNorm,
|
356 |
+
output_tokens: bool = False
|
357 |
+
):
|
358 |
+
super().__init__()
|
359 |
+
self.output_tokens = output_tokens
|
360 |
+
image_height, image_width = self.image_size = to_2tuple(image_size)
|
361 |
+
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
|
362 |
+
self.grid_size = (image_height // patch_height, image_width // patch_width)
|
363 |
+
self.output_dim = output_dim
|
364 |
+
|
365 |
+
# whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
|
366 |
+
self.input_patchnorm = input_patchnorm
|
367 |
+
|
368 |
+
if input_patchnorm:
|
369 |
+
patch_input_dim = patch_height * patch_width * 3
|
370 |
+
self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
|
371 |
+
self.conv1 = nn.Linear(patch_input_dim, width)
|
372 |
+
else:
|
373 |
+
self.patchnorm_pre_ln = nn.Identity()
|
374 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
375 |
+
|
376 |
+
# class embeddings and positional embeddings
|
377 |
+
scale = width ** -0.5
|
378 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
379 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
|
380 |
+
|
381 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
382 |
+
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
383 |
+
|
384 |
+
self.ln_pre = norm_layer(width)
|
385 |
+
self.transformer = Transformer(
|
386 |
+
width,
|
387 |
+
layers,
|
388 |
+
heads,
|
389 |
+
mlp_ratio,
|
390 |
+
ls_init_value=ls_init_value,
|
391 |
+
act_layer=act_layer,
|
392 |
+
norm_layer=norm_layer,
|
393 |
+
)
|
394 |
+
|
395 |
+
self.global_average_pool = global_average_pool
|
396 |
+
if attentional_pool:
|
397 |
+
self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
|
398 |
+
self.ln_post = norm_layer(output_dim)
|
399 |
+
self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
|
400 |
+
else:
|
401 |
+
self.attn_pool = None
|
402 |
+
self.ln_post = norm_layer(width)
|
403 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
404 |
+
|
405 |
+
self.init_parameters()
|
406 |
+
|
407 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
408 |
+
for param in self.parameters():
|
409 |
+
param.requires_grad = False
|
410 |
+
|
411 |
+
if unlocked_groups != 0:
|
412 |
+
groups = [
|
413 |
+
[
|
414 |
+
self.conv1,
|
415 |
+
self.class_embedding,
|
416 |
+
self.positional_embedding,
|
417 |
+
self.ln_pre,
|
418 |
+
],
|
419 |
+
*self.transformer.resblocks[:-1],
|
420 |
+
[
|
421 |
+
self.transformer.resblocks[-1],
|
422 |
+
self.ln_post,
|
423 |
+
],
|
424 |
+
self.proj,
|
425 |
+
]
|
426 |
+
|
427 |
+
def _unlock(x):
|
428 |
+
if isinstance(x, Sequence):
|
429 |
+
for g in x:
|
430 |
+
_unlock(g)
|
431 |
+
else:
|
432 |
+
if isinstance(x, torch.nn.Parameter):
|
433 |
+
x.requires_grad = True
|
434 |
+
else:
|
435 |
+
for p in x.parameters():
|
436 |
+
p.requires_grad = True
|
437 |
+
|
438 |
+
_unlock(groups[-unlocked_groups:])
|
439 |
+
|
440 |
+
def init_parameters(self):
|
441 |
+
# FIXME OpenAI CLIP did not define an init for the VisualTransformer
|
442 |
+
# TODO experiment if default PyTorch init, below, or alternate init is best.
|
443 |
+
|
444 |
+
# nn.init.normal_(self.class_embedding, std=self.scale)
|
445 |
+
# nn.init.normal_(self.positional_embedding, std=self.scale)
|
446 |
+
#
|
447 |
+
# proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
448 |
+
# attn_std = self.transformer.width ** -0.5
|
449 |
+
# fc_std = (2 * self.transformer.width) ** -0.5
|
450 |
+
# for block in self.transformer.resblocks:
|
451 |
+
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
452 |
+
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
453 |
+
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
454 |
+
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
455 |
+
#
|
456 |
+
# if self.text_projection is not None:
|
457 |
+
# nn.init.normal_(self.text_projection, std=self.scale)
|
458 |
+
pass
|
459 |
+
|
460 |
+
@torch.jit.ignore
|
461 |
+
def set_grad_checkpointing(self, enable=True):
|
462 |
+
self.transformer.grad_checkpointing = enable
|
463 |
+
|
464 |
+
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
465 |
+
if self.global_average_pool:
|
466 |
+
return x.mean(dim=1), x
|
467 |
+
else:
|
468 |
+
return x[:, 0], x[:, 1:]
|
469 |
+
|
470 |
+
def forward(self, x: torch.Tensor):
|
471 |
+
|
472 |
+
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
|
473 |
+
if self.input_patchnorm:
|
474 |
+
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
|
475 |
+
x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1])
|
476 |
+
x = x.permute(0, 2, 4, 1, 3, 5)
|
477 |
+
x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
|
478 |
+
x = self.patchnorm_pre_ln(x)
|
479 |
+
x = self.conv1(x)
|
480 |
+
else:
|
481 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
482 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
483 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
484 |
+
|
485 |
+
# class embeddings and positional embeddings
|
486 |
+
x = torch.cat(
|
487 |
+
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
488 |
+
x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
489 |
+
x = x + self.positional_embedding.to(x.dtype)
|
490 |
+
|
491 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
492 |
+
x = self.patch_dropout(x)
|
493 |
+
x = self.ln_pre(x)
|
494 |
+
|
495 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
496 |
+
x = self.transformer(x)
|
497 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
498 |
+
|
499 |
+
if self.attn_pool is not None:
|
500 |
+
x = self.attn_pool(x)
|
501 |
+
x = self.ln_post(x)
|
502 |
+
pooled, tokens = self._global_pool(x)
|
503 |
+
else:
|
504 |
+
pooled, tokens = self._global_pool(x)
|
505 |
+
pooled = self.ln_post(pooled)
|
506 |
+
|
507 |
+
if self.proj is not None:
|
508 |
+
pooled = pooled @ self.proj
|
509 |
+
|
510 |
+
if self.output_tokens:
|
511 |
+
return pooled, tokens
|
512 |
+
|
513 |
+
return pooled
|
514 |
+
|
515 |
+
|
516 |
+
class TextTransformer(nn.Module):
|
517 |
+
output_tokens: torch.jit.Final[bool]
|
518 |
+
|
519 |
+
def __init__(
|
520 |
+
self,
|
521 |
+
context_length: int = 77,
|
522 |
+
vocab_size: int = 49408,
|
523 |
+
width: int = 512,
|
524 |
+
heads: int = 8,
|
525 |
+
layers: int = 12,
|
526 |
+
ls_init_value: float = None,
|
527 |
+
output_dim: int = 512,
|
528 |
+
act_layer: Callable = nn.GELU,
|
529 |
+
norm_layer: Callable = LayerNorm,
|
530 |
+
embed_cls: bool = False,
|
531 |
+
pad_id: int = 0,
|
532 |
+
output_tokens: bool = False,
|
533 |
+
):
|
534 |
+
super().__init__()
|
535 |
+
self.output_tokens = output_tokens
|
536 |
+
self.num_pos = self.context_length = context_length
|
537 |
+
self.vocab_size = vocab_size
|
538 |
+
self.width = width
|
539 |
+
self.output_dim = output_dim
|
540 |
+
self.heads = heads
|
541 |
+
self.pad_id = pad_id
|
542 |
+
|
543 |
+
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
544 |
+
|
545 |
+
if embed_cls:
|
546 |
+
self.cls_emb = nn.Parameter(torch.empty(width))
|
547 |
+
self.num_pos += 1
|
548 |
+
else:
|
549 |
+
self.cls_emb = None
|
550 |
+
|
551 |
+
self.token_embedding = nn.Embedding(vocab_size, width)
|
552 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
|
553 |
+
self.transformer = Transformer(
|
554 |
+
width=width,
|
555 |
+
layers=layers,
|
556 |
+
heads=heads,
|
557 |
+
ls_init_value=ls_init_value,
|
558 |
+
act_layer=act_layer,
|
559 |
+
norm_layer=norm_layer,
|
560 |
+
)
|
561 |
+
self.ln_final = norm_layer(width)
|
562 |
+
|
563 |
+
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
564 |
+
|
565 |
+
self.init_parameters()
|
566 |
+
|
567 |
+
def init_parameters(self):
|
568 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
569 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
570 |
+
if self.cls_emb is not None:
|
571 |
+
nn.init.normal_(self.cls_emb, std=0.01)
|
572 |
+
|
573 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
574 |
+
attn_std = self.transformer.width ** -0.5
|
575 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
576 |
+
for block in self.transformer.resblocks:
|
577 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
578 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
579 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
580 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
581 |
+
|
582 |
+
if self.text_projection is not None:
|
583 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
584 |
+
|
585 |
+
@torch.jit.ignore
|
586 |
+
def set_grad_checkpointing(self, enable=True):
|
587 |
+
self.transformer.grad_checkpointing = enable
|
588 |
+
|
589 |
+
def build_attention_mask(self):
|
590 |
+
# lazily create causal attention mask, with full attention between the tokens
|
591 |
+
# pytorch uses additive attention mask; fill with -inf
|
592 |
+
mask = torch.empty(self.num_pos, self.num_pos)
|
593 |
+
mask.fill_(float("-inf"))
|
594 |
+
mask.triu_(1) # zero out the lower diagonal
|
595 |
+
return mask
|
596 |
+
|
597 |
+
def build_cls_mask(self, text, cast_dtype: torch.dtype):
|
598 |
+
cls_mask = (text != self.pad_id).unsqueeze(1)
|
599 |
+
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
|
600 |
+
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
|
601 |
+
additive_mask.fill_(0)
|
602 |
+
additive_mask.masked_fill_(~cls_mask, float("-inf"))
|
603 |
+
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
|
604 |
+
return additive_mask
|
605 |
+
|
606 |
+
def _repeat(self, t, N: int):
|
607 |
+
return t.reshape(1, 1, -1).repeat(N, 1, 1)
|
608 |
+
|
609 |
+
def forward(self, text):
|
610 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
611 |
+
seq_len = text.shape[1]
|
612 |
+
|
613 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
614 |
+
attn_mask = self.attn_mask
|
615 |
+
if self.cls_emb is not None:
|
616 |
+
seq_len += 1
|
617 |
+
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
|
618 |
+
cls_mask = self.build_cls_mask(text, cast_dtype)
|
619 |
+
attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
|
620 |
+
|
621 |
+
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
|
622 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
623 |
+
x = self.transformer(x, attn_mask=attn_mask)
|
624 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
625 |
+
|
626 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
627 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
628 |
+
if self.cls_emb is not None:
|
629 |
+
pooled, tokens = x[:, -1], x[:, :-1]
|
630 |
+
pooled = self.ln_final(pooled)
|
631 |
+
else:
|
632 |
+
x = self.ln_final(x)
|
633 |
+
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
|
634 |
+
|
635 |
+
if self.text_projection is not None:
|
636 |
+
pooled = pooled @ self.text_projection
|
637 |
+
|
638 |
+
if self.output_tokens:
|
639 |
+
return pooled, tokens
|
640 |
+
|
641 |
+
return pooled
|
642 |
+
|
643 |
+
|
644 |
+
class MultimodalTransformer(Transformer):
|
645 |
+
def __init__(
|
646 |
+
self,
|
647 |
+
width: int,
|
648 |
+
layers: int,
|
649 |
+
heads: int,
|
650 |
+
context_length: int = 77,
|
651 |
+
mlp_ratio: float = 4.0,
|
652 |
+
ls_init_value: float = None,
|
653 |
+
act_layer: Callable = nn.GELU,
|
654 |
+
norm_layer: Callable = LayerNorm,
|
655 |
+
output_dim: int = 512,
|
656 |
+
):
|
657 |
+
|
658 |
+
super().__init__(
|
659 |
+
width=width,
|
660 |
+
layers=layers,
|
661 |
+
heads=heads,
|
662 |
+
mlp_ratio=mlp_ratio,
|
663 |
+
ls_init_value=ls_init_value,
|
664 |
+
act_layer=act_layer,
|
665 |
+
norm_layer=norm_layer,
|
666 |
+
)
|
667 |
+
self.context_length = context_length
|
668 |
+
self.cross_attn = nn.ModuleList([
|
669 |
+
ResidualAttentionBlock(
|
670 |
+
width,
|
671 |
+
heads,
|
672 |
+
mlp_ratio,
|
673 |
+
ls_init_value=ls_init_value,
|
674 |
+
act_layer=act_layer,
|
675 |
+
norm_layer=norm_layer,
|
676 |
+
is_cross_attention=True,
|
677 |
+
)
|
678 |
+
for _ in range(layers)
|
679 |
+
])
|
680 |
+
|
681 |
+
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
682 |
+
|
683 |
+
self.ln_final = norm_layer(width)
|
684 |
+
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
685 |
+
|
686 |
+
def init_parameters(self):
|
687 |
+
proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
|
688 |
+
attn_std = self.width ** -0.5
|
689 |
+
fc_std = (2 * self.width) ** -0.5
|
690 |
+
for block in self.resblocks:
|
691 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
692 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
693 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
694 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
695 |
+
for block in self.cross_attn:
|
696 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
697 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
698 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
699 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
700 |
+
|
701 |
+
if self.text_projection is not None:
|
702 |
+
nn.init.normal_(self.text_projection, std=self.width ** -0.5)
|
703 |
+
|
704 |
+
def build_attention_mask(self):
|
705 |
+
# lazily create causal attention mask, with full attention between the tokens
|
706 |
+
# pytorch uses additive attention mask; fill with -inf
|
707 |
+
mask = torch.empty(self.context_length, self.context_length)
|
708 |
+
mask.fill_(float("-inf"))
|
709 |
+
mask.triu_(1) # zero out the lower diagonal
|
710 |
+
return mask
|
711 |
+
|
712 |
+
def forward(self, image_embs, text_embs):
|
713 |
+
text_embs = text_embs.permute(1, 0, 2)  # NLD -> LND
|
714 |
+
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
|
715 |
+
seq_len = text_embs.shape[0]
|
716 |
+
|
717 |
+
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
|
718 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
719 |
+
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
|
720 |
+
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
|
721 |
+
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
|
722 |
+
else:
|
723 |
+
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
|
724 |
+
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
|
725 |
+
|
726 |
+
x = text_embs.permute(1, 0, 2) # LND -> NLD
|
727 |
+
x = self.ln_final(x)
|
728 |
+
|
729 |
+
if self.text_projection is not None:
|
730 |
+
x = x @ self.text_projection
|
731 |
+
|
732 |
+
return x
|
733 |
+
|
734 |
+
@torch.jit.ignore
|
735 |
+
def set_grad_checkpointing(self, enable=True):
|
736 |
+
self.grad_checkpointing = enable
|
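Usage sketch for the transformer stack above: `TextTransformer` builds an additive causal mask in `build_attention_mask` (zeros on and below the diagonal, `-inf` above) and pools the feature at the position of the highest token id, i.e. the end-of-text token. A small shape check with random weights; the import path is assumed from this repo's layout.

```python
import torch
from model.open_clip.transformer import TextTransformer  # import path assumed from this repo layout

text_model = TextTransformer(
    context_length=16, vocab_size=1000, width=64, heads=4, layers=2, output_dim=32
).eval()

# The causal mask is additive: 0 on/below the diagonal, -inf above it.
print(text_model.attn_mask.shape)    # torch.Size([16, 16])
print(text_model.attn_mask[0, :3])   # tensor([0., -inf, -inf])

tokens = torch.randint(1, 1000, (2, 16))
with torch.no_grad():
    pooled = text_model(tokens)      # pooled text embedding after ln_final and text_projection
print(pooled.shape)                  # torch.Size([2, 32])
```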
model/scunet.py
ADDED
@@ -0,0 +1,264 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from einops import rearrange
|
5 |
+
from einops.layers.torch import Rearrange
|
6 |
+
from timm.models.layers import trunc_normal_, DropPath
|
7 |
+
|
8 |
+
|
9 |
+
class WMSA(nn.Module):
|
10 |
+
""" Self-attention module in Swin Transformer
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
|
14 |
+
super(WMSA, self).__init__()
|
15 |
+
self.input_dim = input_dim
|
16 |
+
self.output_dim = output_dim
|
17 |
+
self.head_dim = head_dim
|
18 |
+
self.scale = self.head_dim ** -0.5
|
19 |
+
self.n_heads = input_dim//head_dim
|
20 |
+
self.window_size = window_size
|
21 |
+
self.type=type
|
22 |
+
self.embedding_layer = nn.Linear(self.input_dim, 3*self.input_dim, bias=True)
|
23 |
+
|
24 |
+
# TODO recover
|
25 |
+
# self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size -1))
|
26 |
+
self.relative_position_params = nn.Parameter(torch.zeros((2 * window_size - 1)*(2 * window_size -1), self.n_heads))
|
27 |
+
|
28 |
+
self.linear = nn.Linear(self.input_dim, self.output_dim)
|
29 |
+
|
30 |
+
trunc_normal_(self.relative_position_params, std=.02)
|
31 |
+
self.relative_position_params = torch.nn.Parameter(self.relative_position_params.view(2*window_size-1, 2*window_size-1, self.n_heads).transpose(1,2).transpose(0,1))
|
32 |
+
|
33 |
+
def generate_mask(self, h, w, p, shift):
|
34 |
+
""" generating the mask of SW-MSA
|
35 |
+
Args:
|
36 |
+
shift: shift parameters in CyclicShift.
|
37 |
+
Returns:
|
38 |
+
attn_mask: should be (1 1 w p p),
|
39 |
+
"""
|
40 |
+
# supporting square.
|
41 |
+
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
|
42 |
+
if self.type == 'W':
|
43 |
+
return attn_mask
|
44 |
+
|
45 |
+
s = p - shift
|
46 |
+
attn_mask[-1, :, :s, :, s:, :] = True
|
47 |
+
attn_mask[-1, :, s:, :, :s, :] = True
|
48 |
+
attn_mask[:, -1, :, :s, :, s:] = True
|
49 |
+
attn_mask[:, -1, :, s:, :, :s] = True
|
50 |
+
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
|
51 |
+
return attn_mask
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
""" Forward pass of Window Multi-head Self-attention module.
|
55 |
+
Args:
|
56 |
+
x: input tensor with shape of [b h w c];
|
57 |
+
attn_mask: attention mask, fill -inf where the value is True;
|
58 |
+
Returns:
|
59 |
+
output: tensor shape [b h w c]
|
60 |
+
"""
|
61 |
+
if self.type!='W': x = torch.roll(x, shifts=(-(self.window_size//2), -(self.window_size//2)), dims=(1,2))
|
62 |
+
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
63 |
+
h_windows = x.size(1)
|
64 |
+
w_windows = x.size(2)
|
65 |
+
# square validation
|
66 |
+
# assert h_windows == w_windows
|
67 |
+
|
68 |
+
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
|
69 |
+
qkv = self.embedding_layer(x)
|
70 |
+
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
|
71 |
+
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
|
72 |
+
# Adding learnable relative embedding
|
73 |
+
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
|
74 |
+
# Using Attn Mask to distinguish different subwindows.
|
75 |
+
if self.type != 'W':
|
76 |
+
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size//2)
|
77 |
+
sim = sim.masked_fill_(attn_mask, float("-inf"))
|
78 |
+
|
79 |
+
probs = nn.functional.softmax(sim, dim=-1)
|
80 |
+
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
|
81 |
+
output = rearrange(output, 'h b w p c -> b w p (h c)')
|
82 |
+
output = self.linear(output)
|
83 |
+
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
84 |
+
|
85 |
+
if self.type!='W': output = torch.roll(output, shifts=(self.window_size//2, self.window_size//2), dims=(1,2))
|
86 |
+
return output
|
87 |
+
|
88 |
+
def relative_embedding(self):
|
89 |
+
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
|
90 |
+
relation = cord[:, None, :] - cord[None, :, :] + self.window_size -1
|
91 |
+
# negative is allowed
|
92 |
+
return self.relative_position_params[:, relation[:,:,0].long(), relation[:,:,1].long()]
|
93 |
+
|
94 |
+
|
95 |
+
class Block(nn.Module):
|
96 |
+
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
97 |
+
""" SwinTransformer Block
|
98 |
+
"""
|
99 |
+
super(Block, self).__init__()
|
100 |
+
self.input_dim = input_dim
|
101 |
+
self.output_dim = output_dim
|
102 |
+
assert type in ['W', 'SW']
|
103 |
+
self.type = type
|
104 |
+
if input_resolution <= window_size:
|
105 |
+
self.type = 'W'
|
106 |
+
|
107 |
+
print("Block Initial Type: {}, drop_path_rate:{:.6f}".format(self.type, drop_path))
|
108 |
+
self.ln1 = nn.LayerNorm(input_dim)
|
109 |
+
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
|
110 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
111 |
+
self.ln2 = nn.LayerNorm(input_dim)
|
112 |
+
self.mlp = nn.Sequential(
|
113 |
+
nn.Linear(input_dim, 4 * input_dim),
|
114 |
+
nn.GELU(),
|
115 |
+
nn.Linear(4 * input_dim, output_dim),
|
116 |
+
)
|
117 |
+
|
118 |
+
def forward(self, x):
|
119 |
+
x = x + self.drop_path(self.msa(self.ln1(x)))
|
120 |
+
x = x + self.drop_path(self.mlp(self.ln2(x)))
|
121 |
+
return x
|
122 |
+
|
123 |
+
|
124 |
+
class ConvTransBlock(nn.Module):
|
125 |
+
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
126 |
+
""" SwinTransformer and Conv Block
|
127 |
+
"""
|
128 |
+
super(ConvTransBlock, self).__init__()
|
129 |
+
self.conv_dim = conv_dim
|
130 |
+
self.trans_dim = trans_dim
|
131 |
+
self.head_dim = head_dim
|
132 |
+
self.window_size = window_size
|
133 |
+
self.drop_path = drop_path
|
134 |
+
self.type = type
|
135 |
+
self.input_resolution = input_resolution
|
136 |
+
|
137 |
+
assert self.type in ['W', 'SW']
|
138 |
+
if self.input_resolution <= self.window_size:
|
139 |
+
            self.type = 'W'

        self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, self.type, self.input_resolution)
        self.conv1_1 = nn.Conv2d(self.conv_dim+self.trans_dim, self.conv_dim+self.trans_dim, 1, 1, 0, bias=True)
        self.conv1_2 = nn.Conv2d(self.conv_dim+self.trans_dim, self.conv_dim+self.trans_dim, 1, 1, 0, bias=True)

        self.conv_block = nn.Sequential(
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
        )

    def forward(self, x):
        conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
        conv_x = self.conv_block(conv_x) + conv_x
        trans_x = Rearrange('b c h w -> b h w c')(trans_x)
        trans_x = self.trans_block(trans_x)
        trans_x = Rearrange('b h w c -> b c h w')(trans_x)
        res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
        x = x + res

        return x


class SCUNet(nn.Module):

    def __init__(self, in_nc=3, config=[2,2,2,2,2,2,2], dim=64, drop_path_rate=0.0, input_resolution=256):
        super(SCUNet, self).__init__()
        self.config = config
        self.dim = dim
        self.head_dim = 32
        self.window_size = 8

        # drop path rate for each layer
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]

        self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]

        begin = 0
        self.m_down1 = [ConvTransBlock(dim//2, dim//2, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution)
                        for i in range(config[0])] + \
                       [nn.Conv2d(dim, 2*dim, 2, 2, 0, bias=False)]

        begin += config[0]
        self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//2)
                        for i in range(config[1])] + \
                       [nn.Conv2d(2*dim, 4*dim, 2, 2, 0, bias=False)]

        begin += config[1]
        self.m_down3 = [ConvTransBlock(2*dim, 2*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//4)
                        for i in range(config[2])] + \
                       [nn.Conv2d(4*dim, 8*dim, 2, 2, 0, bias=False)]

        begin += config[2]
        self.m_body = [ConvTransBlock(4*dim, 4*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//8)
                       for i in range(config[3])]

        begin += config[3]
        self.m_up3 = [nn.ConvTranspose2d(8*dim, 4*dim, 2, 2, 0, bias=False),] + \
                     [ConvTransBlock(2*dim, 2*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//4)
                      for i in range(config[4])]

        begin += config[4]
        self.m_up2 = [nn.ConvTranspose2d(4*dim, 2*dim, 2, 2, 0, bias=False),] + \
                     [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//2)
                      for i in range(config[5])]

        begin += config[5]
        self.m_up1 = [nn.ConvTranspose2d(2*dim, dim, 2, 2, 0, bias=False),] + \
                     [ConvTransBlock(dim//2, dim//2, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution)
                      for i in range(config[6])]

        self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]

        self.m_head = nn.Sequential(*self.m_head)
        self.m_down1 = nn.Sequential(*self.m_down1)
        self.m_down2 = nn.Sequential(*self.m_down2)
        self.m_down3 = nn.Sequential(*self.m_down3)
        self.m_body = nn.Sequential(*self.m_body)
        self.m_up3 = nn.Sequential(*self.m_up3)
        self.m_up2 = nn.Sequential(*self.m_up2)
        self.m_up1 = nn.Sequential(*self.m_up1)
        self.m_tail = nn.Sequential(*self.m_tail)
        #self.apply(self._init_weights)

    def forward(self, x0):

        h, w = x0.size()[-2:]
        paddingBottom = int(np.ceil(h/64)*64-h)
        paddingRight = int(np.ceil(w/64)*64-w)
        x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)

        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        x = self.m_up3(x+x4)
        x = self.m_up2(x+x3)
        x = self.m_up1(x+x2)
        x = self.m_tail(x+x1)

        x = x[..., :h, :w]

        return x


    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)



if __name__ == '__main__':

    # torch.cuda.empty_cache()
    net = SCUNet()

    x = torch.randn((2, 3, 64, 128))
    x = net(x)
    print(x.shape)
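Not part of the commit itself: a minimal usage sketch for the SCUNet denoiser above, assuming the `model.scunet` import path that this commit's layout implies. It illustrates that `SCUNet.forward` replication-pads H and W up to the next multiple of 64 and crops back afterwards, so inputs do not need to be window-aligned.

    import torch
    from model.scunet import SCUNet  # path assumed from this repository layout

    # SCUNet pads to a multiple of 64, runs the conv/transformer U-Net, then crops back.
    net = SCUNet(in_nc=3, dim=64).eval()
    with torch.no_grad():
        out = net(torch.randn(1, 3, 70, 130))  # deliberately not a multiple of 64
    print(out.shape)  # expected: torch.Size([1, 3, 70, 130])
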
model/swinir.py
ADDED
@@ -0,0 +1,905 @@
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------

# Originally borrowed from DifFace (https://github.com/zsyOAOA/DifFace/blob/master/models/swinir.py)

import math
from typing import Set

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


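# Shape round-trip of the two helpers above, for reference: with x of shape
# (B, H, W, C) and H, W divisible by window_size ws,
#     windows = window_partition(x, ws)            # (B * H//ws * W//ws, ws, ws, C)
#     x_back  = window_reverse(windows, ws, H, W)  # (B, H, W, C), identical to x
# WindowAttention below consumes the flattened (num_windows*B, ws*ws, C) view.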
class WindowAttention(nn.Module):
|
70 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
71 |
+
It supports both of shifted and non-shifted window.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
dim (int): Number of input channels.
|
75 |
+
window_size (tuple[int]): The height and width of the window.
|
76 |
+
num_heads (int): Number of attention heads.
|
77 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
78 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
79 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
80 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
81 |
+
"""
|
82 |
+
|
83 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
84 |
+
|
85 |
+
super().__init__()
|
86 |
+
self.dim = dim
|
87 |
+
self.window_size = window_size # Wh, Ww
|
88 |
+
self.num_heads = num_heads
|
89 |
+
head_dim = dim // num_heads
|
90 |
+
self.scale = qk_scale or head_dim ** -0.5
|
91 |
+
|
92 |
+
# define a parameter table of relative position bias
|
93 |
+
self.relative_position_bias_table = nn.Parameter(
|
94 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
95 |
+
|
96 |
+
# get pair-wise relative position index for each token inside the window
|
97 |
+
coords_h = torch.arange(self.window_size[0])
|
98 |
+
coords_w = torch.arange(self.window_size[1])
|
99 |
+
# coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
100 |
+
# Fix: Pass indexing="ij" to avoid warning
|
101 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
|
102 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
103 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
104 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
105 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
106 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
107 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
108 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
109 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
110 |
+
|
111 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
112 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
113 |
+
self.proj = nn.Linear(dim, dim)
|
114 |
+
|
115 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
116 |
+
|
117 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
118 |
+
self.softmax = nn.Softmax(dim=-1)
|
119 |
+
|
120 |
+
def forward(self, x, mask=None):
|
121 |
+
"""
|
122 |
+
Args:
|
123 |
+
x: input features with shape of (num_windows*B, N, C)
|
124 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
125 |
+
"""
|
126 |
+
B_, N, C = x.shape
|
127 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
128 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
129 |
+
|
130 |
+
q = q * self.scale
|
131 |
+
attn = (q @ k.transpose(-2, -1))
|
132 |
+
|
133 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
134 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
135 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
136 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
137 |
+
|
138 |
+
if mask is not None:
|
139 |
+
nW = mask.shape[0]
|
140 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
141 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
142 |
+
attn = self.softmax(attn)
|
143 |
+
else:
|
144 |
+
attn = self.softmax(attn)
|
145 |
+
|
146 |
+
attn = self.attn_drop(attn)
|
147 |
+
|
148 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
149 |
+
x = self.proj(x)
|
150 |
+
x = self.proj_drop(x)
|
151 |
+
return x
|
152 |
+
|
153 |
+
def extra_repr(self) -> str:
|
154 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
155 |
+
|
156 |
+
def flops(self, N):
|
157 |
+
# calculate flops for 1 window with token length of N
|
158 |
+
flops = 0
|
159 |
+
# qkv = self.qkv(x)
|
160 |
+
flops += N * self.dim * 3 * self.dim
|
161 |
+
# attn = (q @ k.transpose(-2, -1))
|
162 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
163 |
+
# x = (attn @ v)
|
164 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
165 |
+
# x = self.proj(x)
|
166 |
+
flops += N * self.dim * self.dim
|
167 |
+
return flops
|
168 |
+
|
169 |
+
|
170 |
+
class SwinTransformerBlock(nn.Module):
|
171 |
+
r""" Swin Transformer Block.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
dim (int): Number of input channels.
|
175 |
+
input_resolution (tuple[int]): Input resulotion.
|
176 |
+
num_heads (int): Number of attention heads.
|
177 |
+
window_size (int): Window size.
|
178 |
+
shift_size (int): Shift size for SW-MSA.
|
179 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
180 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
181 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
182 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
183 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
184 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
185 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
186 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
187 |
+
"""
|
188 |
+
|
189 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
190 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
191 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
192 |
+
super().__init__()
|
193 |
+
self.dim = dim
|
194 |
+
self.input_resolution = input_resolution
|
195 |
+
self.num_heads = num_heads
|
196 |
+
self.window_size = window_size
|
197 |
+
self.shift_size = shift_size
|
198 |
+
self.mlp_ratio = mlp_ratio
|
199 |
+
if min(self.input_resolution) <= self.window_size:
|
200 |
+
# if window size is larger than input resolution, we don't partition windows
|
201 |
+
self.shift_size = 0
|
202 |
+
self.window_size = min(self.input_resolution)
|
203 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
204 |
+
|
205 |
+
self.norm1 = norm_layer(dim)
|
206 |
+
self.attn = WindowAttention(
|
207 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
208 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
209 |
+
|
210 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
211 |
+
self.norm2 = norm_layer(dim)
|
212 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
213 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
214 |
+
|
215 |
+
if self.shift_size > 0:
|
216 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
217 |
+
else:
|
218 |
+
attn_mask = None
|
219 |
+
|
220 |
+
self.register_buffer("attn_mask", attn_mask)
|
221 |
+
|
222 |
+
def calculate_mask(self, x_size):
|
223 |
+
# calculate attention mask for SW-MSA
|
224 |
+
H, W = x_size
|
225 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
226 |
+
h_slices = (slice(0, -self.window_size),
|
227 |
+
slice(-self.window_size, -self.shift_size),
|
228 |
+
slice(-self.shift_size, None))
|
229 |
+
w_slices = (slice(0, -self.window_size),
|
230 |
+
slice(-self.window_size, -self.shift_size),
|
231 |
+
slice(-self.shift_size, None))
|
232 |
+
cnt = 0
|
233 |
+
for h in h_slices:
|
234 |
+
for w in w_slices:
|
235 |
+
img_mask[:, h, w, :] = cnt
|
236 |
+
cnt += 1
|
237 |
+
|
238 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
239 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
240 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
241 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
242 |
+
|
243 |
+
return attn_mask
|
244 |
+
|
245 |
+
def forward(self, x, x_size):
|
246 |
+
H, W = x_size
|
247 |
+
B, L, C = x.shape
|
248 |
+
# assert L == H * W, "input feature has wrong size"
|
249 |
+
|
250 |
+
shortcut = x
|
251 |
+
x = self.norm1(x)
|
252 |
+
x = x.view(B, H, W, C)
|
253 |
+
|
254 |
+
# cyclic shift
|
255 |
+
if self.shift_size > 0:
|
256 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
257 |
+
else:
|
258 |
+
shifted_x = x
|
259 |
+
|
260 |
+
# partition windows
|
261 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
262 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
263 |
+
|
264 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
265 |
+
if self.input_resolution == x_size:
|
266 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
267 |
+
else:
|
268 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
269 |
+
|
270 |
+
# merge windows
|
271 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
272 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
273 |
+
|
274 |
+
# reverse cyclic shift
|
275 |
+
if self.shift_size > 0:
|
276 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
277 |
+
else:
|
278 |
+
x = shifted_x
|
279 |
+
x = x.view(B, H * W, C)
|
280 |
+
|
281 |
+
# FFN
|
282 |
+
x = shortcut + self.drop_path(x)
|
283 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
284 |
+
|
285 |
+
return x
|
286 |
+
|
287 |
+
def extra_repr(self) -> str:
|
288 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
289 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
290 |
+
|
291 |
+
def flops(self):
|
292 |
+
flops = 0
|
293 |
+
H, W = self.input_resolution
|
294 |
+
# norm1
|
295 |
+
flops += self.dim * H * W
|
296 |
+
# W-MSA/SW-MSA
|
297 |
+
nW = H * W / self.window_size / self.window_size
|
298 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
299 |
+
# mlp
|
300 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
301 |
+
# norm2
|
302 |
+
flops += self.dim * H * W
|
303 |
+
return flops
|
304 |
+
|
305 |
+
|
306 |
+
class PatchMerging(nn.Module):
|
307 |
+
r""" Patch Merging Layer.
|
308 |
+
|
309 |
+
Args:
|
310 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
311 |
+
dim (int): Number of input channels.
|
312 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
313 |
+
"""
|
314 |
+
|
315 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
316 |
+
super().__init__()
|
317 |
+
self.input_resolution = input_resolution
|
318 |
+
self.dim = dim
|
319 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
320 |
+
self.norm = norm_layer(4 * dim)
|
321 |
+
|
322 |
+
def forward(self, x):
|
323 |
+
"""
|
324 |
+
x: B, H*W, C
|
325 |
+
"""
|
326 |
+
H, W = self.input_resolution
|
327 |
+
B, L, C = x.shape
|
328 |
+
assert L == H * W, "input feature has wrong size"
|
329 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
330 |
+
|
331 |
+
x = x.view(B, H, W, C)
|
332 |
+
|
333 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
334 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
335 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
336 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
337 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
338 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
339 |
+
|
340 |
+
x = self.norm(x)
|
341 |
+
x = self.reduction(x)
|
342 |
+
|
343 |
+
return x
|
344 |
+
|
345 |
+
def extra_repr(self) -> str:
|
346 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
347 |
+
|
348 |
+
def flops(self):
|
349 |
+
H, W = self.input_resolution
|
350 |
+
flops = H * W * self.dim
|
351 |
+
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
352 |
+
return flops
|
353 |
+
|
354 |
+
|
355 |
+
class BasicLayer(nn.Module):
|
356 |
+
""" A basic Swin Transformer layer for one stage.
|
357 |
+
|
358 |
+
Args:
|
359 |
+
dim (int): Number of input channels.
|
360 |
+
input_resolution (tuple[int]): Input resolution.
|
361 |
+
depth (int): Number of blocks.
|
362 |
+
num_heads (int): Number of attention heads.
|
363 |
+
window_size (int): Local window size.
|
364 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
365 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
366 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
367 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
368 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
369 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
370 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
371 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
372 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
373 |
+
"""
|
374 |
+
|
375 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
376 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
377 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
378 |
+
|
379 |
+
super().__init__()
|
380 |
+
self.dim = dim
|
381 |
+
self.input_resolution = input_resolution
|
382 |
+
self.depth = depth
|
383 |
+
self.use_checkpoint = use_checkpoint
|
384 |
+
|
385 |
+
# build blocks
|
386 |
+
self.blocks = nn.ModuleList([
|
387 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
388 |
+
num_heads=num_heads, window_size=window_size,
|
389 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
390 |
+
mlp_ratio=mlp_ratio,
|
391 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
392 |
+
drop=drop, attn_drop=attn_drop,
|
393 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
394 |
+
norm_layer=norm_layer)
|
395 |
+
for i in range(depth)])
|
396 |
+
|
397 |
+
# patch merging layer
|
398 |
+
if downsample is not None:
|
399 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
400 |
+
else:
|
401 |
+
self.downsample = None
|
402 |
+
|
403 |
+
def forward(self, x, x_size):
|
404 |
+
for blk in self.blocks:
|
405 |
+
if self.use_checkpoint:
|
406 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
407 |
+
else:
|
408 |
+
x = blk(x, x_size)
|
409 |
+
if self.downsample is not None:
|
410 |
+
x = self.downsample(x)
|
411 |
+
return x
|
412 |
+
|
413 |
+
def extra_repr(self) -> str:
|
414 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
415 |
+
|
416 |
+
def flops(self):
|
417 |
+
flops = 0
|
418 |
+
for blk in self.blocks:
|
419 |
+
flops += blk.flops()
|
420 |
+
if self.downsample is not None:
|
421 |
+
flops += self.downsample.flops()
|
422 |
+
return flops
|
423 |
+
|
424 |
+
|
425 |
+
class RSTB(nn.Module):
|
426 |
+
"""Residual Swin Transformer Block (RSTB).
|
427 |
+
|
428 |
+
Args:
|
429 |
+
dim (int): Number of input channels.
|
430 |
+
input_resolution (tuple[int]): Input resolution.
|
431 |
+
depth (int): Number of blocks.
|
432 |
+
num_heads (int): Number of attention heads.
|
433 |
+
window_size (int): Local window size.
|
434 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
435 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
436 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
437 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
438 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
439 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
440 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
441 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
442 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
443 |
+
img_size: Input image size.
|
444 |
+
patch_size: Patch size.
|
445 |
+
resi_connection: The convolutional block before residual connection.
|
446 |
+
"""
|
447 |
+
|
448 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
449 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
450 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
451 |
+
img_size=224, patch_size=4, resi_connection='1conv'):
|
452 |
+
super(RSTB, self).__init__()
|
453 |
+
|
454 |
+
self.dim = dim
|
455 |
+
self.input_resolution = input_resolution
|
456 |
+
|
457 |
+
self.residual_group = BasicLayer(dim=dim,
|
458 |
+
input_resolution=input_resolution,
|
459 |
+
depth=depth,
|
460 |
+
num_heads=num_heads,
|
461 |
+
window_size=window_size,
|
462 |
+
mlp_ratio=mlp_ratio,
|
463 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
464 |
+
drop=drop, attn_drop=attn_drop,
|
465 |
+
drop_path=drop_path,
|
466 |
+
norm_layer=norm_layer,
|
467 |
+
downsample=downsample,
|
468 |
+
use_checkpoint=use_checkpoint)
|
469 |
+
|
470 |
+
if resi_connection == '1conv':
|
471 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
472 |
+
elif resi_connection == '3conv':
|
473 |
+
# to save parameters and memory
|
474 |
+
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
475 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
476 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
477 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
478 |
+
|
479 |
+
self.patch_embed = PatchEmbed(
|
480 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
481 |
+
norm_layer=None)
|
482 |
+
|
483 |
+
self.patch_unembed = PatchUnEmbed(
|
484 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
485 |
+
norm_layer=None)
|
486 |
+
|
487 |
+
def forward(self, x, x_size):
|
488 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
489 |
+
|
490 |
+
def flops(self):
|
491 |
+
flops = 0
|
492 |
+
flops += self.residual_group.flops()
|
493 |
+
H, W = self.input_resolution
|
494 |
+
flops += H * W * self.dim * self.dim * 9
|
495 |
+
flops += self.patch_embed.flops()
|
496 |
+
flops += self.patch_unembed.flops()
|
497 |
+
|
498 |
+
return flops
|
499 |
+
|
500 |
+
|
501 |
+
class PatchEmbed(nn.Module):
|
502 |
+
r""" Image to Patch Embedding
|
503 |
+
|
504 |
+
Args:
|
505 |
+
img_size (int): Image size. Default: 224.
|
506 |
+
patch_size (int): Patch token size. Default: 4.
|
507 |
+
in_chans (int): Number of input image channels. Default: 3.
|
508 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
509 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
510 |
+
"""
|
511 |
+
|
512 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
513 |
+
super().__init__()
|
514 |
+
img_size = to_2tuple(img_size)
|
515 |
+
patch_size = to_2tuple(patch_size)
|
516 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
517 |
+
self.img_size = img_size
|
518 |
+
self.patch_size = patch_size
|
519 |
+
self.patches_resolution = patches_resolution
|
520 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
521 |
+
|
522 |
+
self.in_chans = in_chans
|
523 |
+
self.embed_dim = embed_dim
|
524 |
+
|
525 |
+
if norm_layer is not None:
|
526 |
+
self.norm = norm_layer(embed_dim)
|
527 |
+
else:
|
528 |
+
self.norm = None
|
529 |
+
|
530 |
+
def forward(self, x):
|
531 |
+
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
|
532 |
+
if self.norm is not None:
|
533 |
+
x = self.norm(x)
|
534 |
+
return x
|
535 |
+
|
536 |
+
def flops(self):
|
537 |
+
flops = 0
|
538 |
+
H, W = self.img_size
|
539 |
+
if self.norm is not None:
|
540 |
+
flops += H * W * self.embed_dim
|
541 |
+
return flops
|
542 |
+
|
543 |
+
|
544 |
+
class PatchUnEmbed(nn.Module):
|
545 |
+
r""" Image to Patch Unembedding
|
546 |
+
|
547 |
+
Args:
|
548 |
+
img_size (int): Image size. Default: 224.
|
549 |
+
patch_size (int): Patch token size. Default: 4.
|
550 |
+
in_chans (int): Number of input image channels. Default: 3.
|
551 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
552 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
553 |
+
"""
|
554 |
+
|
555 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
556 |
+
super().__init__()
|
557 |
+
img_size = to_2tuple(img_size)
|
558 |
+
patch_size = to_2tuple(patch_size)
|
559 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
560 |
+
self.img_size = img_size
|
561 |
+
self.patch_size = patch_size
|
562 |
+
self.patches_resolution = patches_resolution
|
563 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
564 |
+
|
565 |
+
self.in_chans = in_chans
|
566 |
+
self.embed_dim = embed_dim
|
567 |
+
|
568 |
+
def forward(self, x, x_size):
|
569 |
+
B, HW, C = x.shape
|
570 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
|
571 |
+
return x
|
572 |
+
|
573 |
+
def flops(self):
|
574 |
+
flops = 0
|
575 |
+
return flops
|
576 |
+
|
577 |
+
|
578 |
+
class Upsample(nn.Sequential):
|
579 |
+
"""Upsample module.
|
580 |
+
|
581 |
+
Args:
|
582 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
583 |
+
num_feat (int): Channel number of intermediate features.
|
584 |
+
"""
|
585 |
+
|
586 |
+
def __init__(self, scale, num_feat):
|
587 |
+
m = []
|
588 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
589 |
+
for _ in range(int(math.log(scale, 2))):
|
590 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
591 |
+
m.append(nn.PixelShuffle(2))
|
592 |
+
elif scale == 3:
|
593 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
594 |
+
m.append(nn.PixelShuffle(3))
|
595 |
+
else:
|
596 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
597 |
+
super(Upsample, self).__init__(*m)
|
598 |
+
|
599 |
+
|
600 |
+
class UpsampleOneStep(nn.Sequential):
|
601 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
602 |
+
Used in lightweight SR to save parameters.
|
603 |
+
|
604 |
+
Args:
|
605 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
606 |
+
num_feat (int): Channel number of intermediate features.
|
607 |
+
|
608 |
+
"""
|
609 |
+
|
610 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
611 |
+
self.num_feat = num_feat
|
612 |
+
self.input_resolution = input_resolution
|
613 |
+
m = []
|
614 |
+
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
615 |
+
m.append(nn.PixelShuffle(scale))
|
616 |
+
super(UpsampleOneStep, self).__init__(*m)
|
617 |
+
|
618 |
+
def flops(self):
|
619 |
+
H, W = self.input_resolution
|
620 |
+
flops = H * W * self.num_feat * 3 * 9
|
621 |
+
return flops
|
622 |
+
|
623 |
+
|
624 |
+
class SwinIR(nn.Module):
|
625 |
+
r""" SwinIR
|
626 |
+
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
|
627 |
+
|
628 |
+
Args:
|
629 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
630 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
631 |
+
in_chans (int): Number of input image channels. Default: 3
|
632 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
633 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
634 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
635 |
+
window_size (int): Window size. Default: 7
|
636 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
637 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
638 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
639 |
+
drop_rate (float): Dropout rate. Default: 0
|
640 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
641 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
642 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
643 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
644 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
645 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
646 |
+
sf: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
|
647 |
+
img_range: Image range. 1. or 255.
|
648 |
+
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
649 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
650 |
+
"""
|
651 |
+
|
652 |
+
def __init__(
|
653 |
+
self,
|
654 |
+
img_size=64,
|
655 |
+
patch_size=1,
|
656 |
+
in_chans=3,
|
657 |
+
embed_dim=96,
|
658 |
+
depths=[6, 6, 6, 6],
|
659 |
+
num_heads=[6, 6, 6, 6],
|
660 |
+
window_size=7,
|
661 |
+
mlp_ratio=4.,
|
662 |
+
qkv_bias=True,
|
663 |
+
qk_scale=None,
|
664 |
+
drop_rate=0.,
|
665 |
+
attn_drop_rate=0.,
|
666 |
+
drop_path_rate=0.1,
|
667 |
+
norm_layer=nn.LayerNorm,
|
668 |
+
ape=False,
|
669 |
+
patch_norm=True,
|
670 |
+
use_checkpoint=False,
|
671 |
+
sf=4,
|
672 |
+
img_range=1.,
|
673 |
+
upsampler='',
|
674 |
+
resi_connection='1conv',
|
675 |
+
unshuffle=False,
|
676 |
+
unshuffle_scale=None,
|
677 |
+
hq_key: str="jpg",
|
678 |
+
lq_key: str="hint",
|
679 |
+
learning_rate: float=None,
|
680 |
+
weight_decay: float=None
|
681 |
+
) -> "SwinIR":
|
682 |
+
super(SwinIR, self).__init__()
|
683 |
+
num_in_ch = in_chans * (unshuffle_scale**2) if unshuffle else in_chans
|
684 |
+
num_out_ch = in_chans
|
685 |
+
num_feat = 64
|
686 |
+
self.img_range = img_range
|
687 |
+
if in_chans == 3:
|
688 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
689 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
690 |
+
else:
|
691 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
692 |
+
self.upscale = sf
|
693 |
+
self.upsampler = upsampler
|
694 |
+
self.window_size = window_size
|
695 |
+
self.unshuffle_scale = unshuffle_scale
|
696 |
+
self.unshuffle = unshuffle
|
697 |
+
|
698 |
+
#####################################################################################################
|
699 |
+
################################### 1, shallow feature extraction ###################################
|
700 |
+
if unshuffle:
|
701 |
+
assert unshuffle_scale is not None
|
702 |
+
self.conv_first = nn.Sequential(
|
703 |
+
nn.PixelUnshuffle(sf),
|
704 |
+
nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1),
|
705 |
+
)
|
706 |
+
else:
|
707 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
708 |
+
|
709 |
+
#####################################################################################################
|
710 |
+
################################### 2, deep feature extraction ######################################
|
711 |
+
self.num_layers = len(depths)
|
712 |
+
self.embed_dim = embed_dim
|
713 |
+
self.ape = ape
|
714 |
+
self.patch_norm = patch_norm
|
715 |
+
self.num_features = embed_dim
|
716 |
+
self.mlp_ratio = mlp_ratio
|
717 |
+
|
718 |
+
# split image into non-overlapping patches
|
719 |
+
self.patch_embed = PatchEmbed(
|
720 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
721 |
+
norm_layer=norm_layer if self.patch_norm else None
|
722 |
+
)
|
723 |
+
num_patches = self.patch_embed.num_patches
|
724 |
+
patches_resolution = self.patch_embed.patches_resolution
|
725 |
+
self.patches_resolution = patches_resolution
|
726 |
+
|
727 |
+
# merge non-overlapping patches into image
|
728 |
+
self.patch_unembed = PatchUnEmbed(
|
729 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
730 |
+
norm_layer=norm_layer if self.patch_norm else None
|
731 |
+
)
|
732 |
+
|
733 |
+
# absolute position embedding
|
734 |
+
if self.ape:
|
735 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
736 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
737 |
+
|
738 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
739 |
+
|
740 |
+
# stochastic depth
|
741 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
742 |
+
|
743 |
+
# build Residual Swin Transformer blocks (RSTB)
|
744 |
+
self.layers = nn.ModuleList()
|
745 |
+
for i_layer in range(self.num_layers):
|
746 |
+
layer = RSTB(
|
747 |
+
dim=embed_dim,
|
748 |
+
input_resolution=(patches_resolution[0], patches_resolution[1]),
|
749 |
+
depth=depths[i_layer],
|
750 |
+
num_heads=num_heads[i_layer],
|
751 |
+
window_size=window_size,
|
752 |
+
mlp_ratio=self.mlp_ratio,
|
753 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
754 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
755 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
756 |
+
norm_layer=norm_layer,
|
757 |
+
downsample=None,
|
758 |
+
use_checkpoint=use_checkpoint,
|
759 |
+
img_size=img_size,
|
760 |
+
patch_size=patch_size,
|
761 |
+
resi_connection=resi_connection
|
762 |
+
)
|
763 |
+
self.layers.append(layer)
|
764 |
+
self.norm = norm_layer(self.num_features)
|
765 |
+
|
766 |
+
# build the last conv layer in deep feature extraction
|
767 |
+
if resi_connection == '1conv':
|
768 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
769 |
+
elif resi_connection == '3conv':
|
770 |
+
# to save parameters and memory
|
771 |
+
self.conv_after_body = nn.Sequential(
|
772 |
+
nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
773 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
774 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
775 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
776 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)
|
777 |
+
)
|
778 |
+
|
779 |
+
#####################################################################################################
|
780 |
+
################################ 3, high quality image reconstruction ################################
|
781 |
+
if self.upsampler == 'pixelshuffle':
|
782 |
+
# for classical SR
|
783 |
+
self.conv_before_upsample = nn.Sequential(
|
784 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
785 |
+
nn.LeakyReLU(inplace=True)
|
786 |
+
)
|
787 |
+
self.upsample = Upsample(sf, num_feat)
|
788 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
789 |
+
elif self.upsampler == 'pixelshuffledirect':
|
790 |
+
# for lightweight SR (to save parameters)
|
791 |
+
self.upsample = UpsampleOneStep(
|
792 |
+
sf, embed_dim, num_out_ch,
|
793 |
+
(patches_resolution[0], patches_resolution[1])
|
794 |
+
)
|
795 |
+
elif self.upsampler == 'nearest+conv':
|
796 |
+
# for real-world SR (less artifacts)
|
797 |
+
self.conv_before_upsample = nn.Sequential(
|
798 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
799 |
+
nn.LeakyReLU(inplace=True)
|
800 |
+
)
|
801 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
802 |
+
if self.upscale == 4:
|
803 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
804 |
+
elif self.upscale == 8:
|
805 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
806 |
+
self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
807 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
808 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
809 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
810 |
+
else:
|
811 |
+
# for image denoising and JPEG compression artifact reduction
|
812 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
813 |
+
|
814 |
+
self.apply(self._init_weights)
|
815 |
+
|
816 |
+
def _init_weights(self, m: nn.Module) -> None:
|
817 |
+
if isinstance(m, nn.Linear):
|
818 |
+
trunc_normal_(m.weight, std=.02)
|
819 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
820 |
+
nn.init.constant_(m.bias, 0)
|
821 |
+
elif isinstance(m, nn.LayerNorm):
|
822 |
+
nn.init.constant_(m.bias, 0)
|
823 |
+
nn.init.constant_(m.weight, 1.0)
|
824 |
+
|
825 |
+
# TODO: What's this ?
|
826 |
+
@torch.jit.ignore
|
827 |
+
def no_weight_decay(self) -> Set[str]:
|
828 |
+
return {'absolute_pos_embed'}
|
829 |
+
|
830 |
+
@torch.jit.ignore
|
831 |
+
def no_weight_decay_keywords(self) -> Set[str]:
|
832 |
+
return {'relative_position_bias_table'}
|
833 |
+
|
834 |
+
def check_image_size(self, x: torch.Tensor) -> torch.Tensor:
|
835 |
+
_, _, h, w = x.size()
|
836 |
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
837 |
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
838 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
839 |
+
return x
|
840 |
+
|
841 |
+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
842 |
+
x_size = (x.shape[2], x.shape[3])
|
843 |
+
x = self.patch_embed(x)
|
844 |
+
if self.ape:
|
845 |
+
x = x + self.absolute_pos_embed
|
846 |
+
x = self.pos_drop(x)
|
847 |
+
|
848 |
+
for layer in self.layers:
|
849 |
+
x = layer(x, x_size)
|
850 |
+
|
851 |
+
x = self.norm(x) # B L C
|
852 |
+
x = self.patch_unembed(x, x_size)
|
853 |
+
|
854 |
+
return x
|
855 |
+
|
856 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
857 |
+
H, W = x.shape[2:]
|
858 |
+
x = self.check_image_size(x)
|
859 |
+
|
860 |
+
self.mean = self.mean.type_as(x)
|
861 |
+
x = (x - self.mean) * self.img_range
|
862 |
+
|
863 |
+
if self.upsampler == 'pixelshuffle':
|
864 |
+
# for classical SR
|
865 |
+
x = self.conv_first(x)
|
866 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
867 |
+
x = self.conv_before_upsample(x)
|
868 |
+
x = self.conv_last(self.upsample(x))
|
869 |
+
elif self.upsampler == 'pixelshuffledirect':
|
870 |
+
# for lightweight SR
|
871 |
+
x = self.conv_first(x)
|
872 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
873 |
+
x = self.upsample(x)
|
874 |
+
elif self.upsampler == 'nearest+conv':
|
875 |
+
# for real-world SR
|
876 |
+
x = self.conv_first(x)
|
877 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
878 |
+
x = self.conv_before_upsample(x)
|
879 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
880 |
+
if self.upscale == 4:
|
881 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
882 |
+
elif self.upscale == 8:
|
883 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
884 |
+
x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
885 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
886 |
+
else:
|
887 |
+
# for image denoising and JPEG compression artifact reduction
|
888 |
+
x_first = self.conv_first(x)
|
889 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
890 |
+
x = x + self.conv_last(res)
|
891 |
+
|
892 |
+
x = x / self.img_range + self.mean
|
893 |
+
|
894 |
+
return x[:, :, :H*self.upscale, :W*self.upscale]
|
895 |
+
|
896 |
+
def flops(self) -> int:
|
897 |
+
flops = 0
|
898 |
+
H, W = self.patches_resolution
|
899 |
+
flops += H * W * 3 * self.embed_dim * 9
|
900 |
+
flops += self.patch_embed.flops()
|
901 |
+
for i, layer in enumerate(self.layers):
|
902 |
+
flops += layer.flops()
|
903 |
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
904 |
+
flops += self.upsample.flops()
|
905 |
+
return flops
|
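A minimal construction sketch for the SwinIR restorer defined above; it is not taken from the commit, and the argument values are illustrative (a denoising-style configuration with sf=1 and the identity upsampler branch), not necessarily the configuration of the shipped checkpoints. `check_image_size` reflect-pads H and W to multiples of `window_size`, and `forward` crops the result back to `H*sf` by `W*sf`.

    import torch
    from model.swinir import SwinIR  # module path assumed from this commit

    # Illustrative settings only; project checkpoints may differ.
    net = SwinIR(img_size=64, patch_size=1, in_chans=3, embed_dim=96,
                 depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], window_size=8,
                 sf=1, upsampler='').eval()
    with torch.no_grad():
        out = net(torch.randn(1, 3, 64, 64))  # H, W already multiples of window_size
    print(out.shape)  # torch.Size([1, 3, 64, 64])
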
model/unet.py
ADDED
@@ -0,0 +1,719 @@
from abc import abstractmethod
import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from model.util import (
    checkpoint,
    conv_nd,
    linear,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
    exists
)
from model.attention import SpatialTransformer


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x


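# Dispatch summary for TimestepEmbedSequential, for reference:
#   - TimestepBlock children (e.g. ResBlock) receive (x, emb)
#   - SpatialTransformer children receive (x, context) for cross-attention
#   - every other child (convs, up/downsampling layers) receives x only
# so a single container can hold the mixed layer types used throughout the UNet.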
class Upsample(nn.Module):
|
52 |
+
"""
|
53 |
+
An upsampling layer with an optional convolution.
|
54 |
+
:param channels: channels in the inputs and outputs.
|
55 |
+
:param use_conv: a bool determining if a convolution is applied.
|
56 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
57 |
+
upsampling occurs in the inner-two dimensions.
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
61 |
+
super().__init__()
|
62 |
+
self.channels = channels
|
63 |
+
self.out_channels = out_channels or channels
|
64 |
+
self.use_conv = use_conv
|
65 |
+
self.dims = dims
|
66 |
+
if use_conv:
|
67 |
+
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
assert x.shape[1] == self.channels
|
71 |
+
if self.dims == 3:
|
72 |
+
x = F.interpolate(
|
73 |
+
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
77 |
+
if self.use_conv:
|
78 |
+
x = self.conv(x)
|
79 |
+
return x
|
80 |
+
|
81 |
+
|
82 |
+
class Downsample(nn.Module):
|
83 |
+
"""
|
84 |
+
A downsampling layer with an optional convolution.
|
85 |
+
:param channels: channels in the inputs and outputs.
|
86 |
+
:param use_conv: a bool determining if a convolution is applied.
|
87 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
88 |
+
downsampling occurs in the inner-two dimensions.
|
89 |
+
"""
|
90 |
+
|
91 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
92 |
+
super().__init__()
|
93 |
+
self.channels = channels
|
94 |
+
self.out_channels = out_channels or channels
|
95 |
+
self.use_conv = use_conv
|
96 |
+
self.dims = dims
|
97 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
98 |
+
if use_conv:
|
99 |
+
self.op = conv_nd(
|
100 |
+
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
101 |
+
)
|
102 |
+
else:
|
103 |
+
assert self.channels == self.out_channels
|
104 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
assert x.shape[1] == self.channels
|
108 |
+
return self.op(x)
|
109 |
+
|
110 |
+
|
111 |
+
class ResBlock(TimestepBlock):
|
112 |
+
"""
|
113 |
+
A residual block that can optionally change the number of channels.
|
114 |
+
:param channels: the number of input channels.
|
115 |
+
:param emb_channels: the number of timestep embedding channels.
|
116 |
+
:param dropout: the rate of dropout.
|
117 |
+
:param out_channels: if specified, the number of out channels.
|
118 |
+
:param use_conv: if True and out_channels is specified, use a spatial
|
119 |
+
convolution instead of a smaller 1x1 convolution to change the
|
120 |
+
channels in the skip connection.
|
121 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
122 |
+
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
123 |
+
:param up: if True, use this block for upsampling.
|
124 |
+
:param down: if True, use this block for downsampling.
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
channels,
|
130 |
+
emb_channels,
|
131 |
+
dropout,
|
132 |
+
out_channels=None,
|
133 |
+
use_conv=False,
|
134 |
+
use_scale_shift_norm=False,
|
135 |
+
dims=2,
|
136 |
+
use_checkpoint=False,
|
137 |
+
up=False,
|
138 |
+
down=False,
|
139 |
+
):
|
140 |
+
super().__init__()
|
141 |
+
self.channels = channels
|
142 |
+
self.emb_channels = emb_channels
|
143 |
+
self.dropout = dropout
|
144 |
+
self.out_channels = out_channels or channels
|
145 |
+
self.use_conv = use_conv
|
146 |
+
self.use_checkpoint = use_checkpoint
|
147 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
148 |
+
|
149 |
+
self.in_layers = nn.Sequential(
|
150 |
+
normalization(channels),
|
151 |
+
nn.SiLU(),
|
152 |
+
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
153 |
+
)
|
154 |
+
|
155 |
+
self.updown = up or down
|
156 |
+
|
157 |
+
if up:
|
158 |
+
self.h_upd = Upsample(channels, False, dims)
|
159 |
+
self.x_upd = Upsample(channels, False, dims)
|
160 |
+
elif down:
|
161 |
+
self.h_upd = Downsample(channels, False, dims)
|
162 |
+
self.x_upd = Downsample(channels, False, dims)
|
163 |
+
else:
|
164 |
+
self.h_upd = self.x_upd = nn.Identity()
|
165 |
+
|
166 |
+
self.emb_layers = nn.Sequential(
|
167 |
+
nn.SiLU(),
|
168 |
+
linear(
|
169 |
+
emb_channels,
|
170 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
171 |
+
),
|
172 |
+
)
|
173 |
+
self.out_layers = nn.Sequential(
|
174 |
+
normalization(self.out_channels),
|
175 |
+
nn.SiLU(),
|
176 |
+
nn.Dropout(p=dropout),
|
177 |
+
zero_module(
|
178 |
+
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
179 |
+
),
|
180 |
+
)
|
181 |
+
|
182 |
+
if self.out_channels == channels:
|
183 |
+
self.skip_connection = nn.Identity()
|
184 |
+
elif use_conv:
|
185 |
+
self.skip_connection = conv_nd(
|
186 |
+
dims, channels, self.out_channels, 3, padding=1
|
187 |
+
)
|
188 |
+
else:
|
189 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
190 |
+
|
191 |
+
def forward(self, x, emb):
|
192 |
+
"""
|
193 |
+
Apply the block to a Tensor, conditioned on a timestep embedding.
|
194 |
+
:param x: an [N x C x ...] Tensor of features.
|
195 |
+
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
196 |
+
:return: an [N x C x ...] Tensor of outputs.
|
197 |
+
"""
|
198 |
+
return checkpoint(
|
199 |
+
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
200 |
+
)
|
201 |
+
|
202 |
+
|
203 |
+
def _forward(self, x, emb):
|
204 |
+
if self.updown:
|
205 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
206 |
+
h = in_rest(x)
|
207 |
+
h = self.h_upd(h)
|
208 |
+
x = self.x_upd(x)
|
209 |
+
h = in_conv(h)
|
210 |
+
else:
|
211 |
+
h = self.in_layers(x)
|
212 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
213 |
+
while len(emb_out.shape) < len(h.shape):
|
214 |
+
emb_out = emb_out[..., None]
|
215 |
+
if self.use_scale_shift_norm:
|
216 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
217 |
+
scale, shift = th.chunk(emb_out, 2, dim=1)
|
218 |
+
h = out_norm(h) * (1 + scale) + shift
|
219 |
+
h = out_rest(h)
|
220 |
+
else:
|
221 |
+
h = h + emb_out
|
222 |
+
h = self.out_layers(h)
|
223 |
+
return self.skip_connection(x) + h
|
224 |
+
|
225 |
+
|
226 |
+
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), True)  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        # return pt_checkpoint(self._forward, x)  # pytorch

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial ** 2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


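# Illustrative note (added for clarity, not in the original file): both attention variants
# consume the fused qkv tensor of shape [N, 3 * heads * head_dim, T] produced by the 1x1
# conv in AttentionBlock. A minimal shape check, assuming 4 heads of 32 channels:
#   attn = QKVAttention(n_heads=4)
#   qkv = th.randn(2, 3 * 4 * 32, 64 * 64)   # batch 2, 4096 spatial tokens
#   out = attn(qkv)                          # -> [2, 4 * 32, 4096]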
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads_upsample,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def forward(self, x, timesteps, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
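A minimal usage sketch for the diffusion UNet above (illustrative only, not part of the upload; the channel, head, and resolution settings are assumptions, not values read from this repository's configs):

import torch as th
from model.unet import UNetModel

# small class-unconditional UNet using the plain AttentionBlock path
unet = UNetModel(
    image_size=64, in_channels=3, out_channels=3, model_channels=64,
    num_res_blocks=1, attention_resolutions=[4], channel_mult=(1, 2, 4),
    num_head_channels=32, use_spatial_transformer=False,
)
x = th.randn(2, 3, 64, 64)        # input batch
t = th.randint(0, 1000, (2,))     # one diffusion timestep per sample
out = unet(x, t)                  # output has the same shape as x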
model/util.py
ADDED
@@ -0,0 +1,225 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!


import os
import math
from inspect import isfunction
import torch
import torch.nn as nn
import numpy as np
from einops import repeat


def exists(val):
    return val is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
        explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


# class CheckpointFunction(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, run_function, length, *args):
#         ctx.run_function = run_function
#         ctx.input_tensors = list(args[:length])
#         ctx.input_params = list(args[length:])
#         ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
#                                    "dtype": torch.get_autocast_gpu_dtype(),
#                                    "cache_enabled": torch.is_autocast_cache_enabled()}
#         with torch.no_grad():
#             output_tensors = ctx.run_function(*ctx.input_tensors)
#         return output_tensors

#     @staticmethod
#     def backward(ctx, *output_grads):
#         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
#         with torch.enable_grad(), \
#                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
#             # Fixes a bug where the first op in run_function modifies the
#             # Tensor storage in place, which is not allowed for detach()'d
#             # Tensors.
#             shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
#             output_tensors = ctx.run_function(*shallow_copies)
#         input_grads = torch.autograd.grad(
#             output_tensors,
#             ctx.input_tensors + ctx.input_params,
#             output_grads,
#             allow_unused=True,
#         )
#         del ctx.input_tensors
#         del ctx.input_params
#         del output_tensors
#         return (None, None) + input_grads


# Fixes: When we set unet parameters with requires_grad=False, the original CheckpointFunction
# still tries to compute gradient for unet parameters.
# https://discuss.pytorch.org/t/get-runtimeerror-one-of-the-differentiated-tensors-does-not-require-grad-in-pytorch-lightning/179738/6
class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + [x for x in ctx.input_params if x.requires_grad],
            output_grads,
            allow_unused=True,
        )
        grads = list(grads)
        # Assign gradients to the correct positions, matching None for those that do not require gradients
        input_grads = []
        for tensor in ctx.input_tensors + ctx.input_params:
            if tensor.requires_grad:
                input_grads.append(grads.pop(0))  # Get the next computed gradient
            else:
                input_grads.append(None)  # No gradient required for this tensor
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + tuple(input_grads)


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
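A quick sketch of how these helpers compose (illustrative only, not part of the upload; the layer sizes are arbitrary examples):

import torch
from model.util import timestep_embedding, checkpoint, zero_module, conv_nd

t = torch.tensor([0, 250, 999])
emb = timestep_embedding(t, dim=320)                   # [3, 320] sinusoidal embeddings

layer = zero_module(conv_nd(2, 8, 8, 3, padding=1))    # conv whose weights start at zero
x = torch.randn(1, 8, 16, 16, requires_grad=True)
# gradient checkpointing: the forward pass is recomputed during backward to save memory
y = checkpoint(lambda inp: layer(inp) + inp, (x,), layer.parameters(), True)
y.sum().backward()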
model/vae.py
ADDED
@@ -0,0 +1,567 @@
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from einops import rearrange
from typing import Optional, Any

from model.distributions import DiagonalGaussianDistribution
from model.config import Config, AttnMode


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        print(f"building AttnBlock (vanilla) with {in_channels} in_channels")

        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)      # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)        # b,hw,hw   w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)    # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)       # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class MemoryEfficientAttnBlock(nn.Module):
    """
    Uses xformers efficient implementation,
    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    Note: this is a single-head self-attention operation
    """

    def __init__(self, in_channels):
        super().__init__()
        print(f"building MemoryEfficientAttnBlock (xformers) with {in_channels} in_channels")
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(B, t.shape[1], 1, C)
            .permute(0, 2, 1, 3)
            .reshape(B * 1, t.shape[1], C)
            .contiguous(),
            (q, k, v),
        )
        out = Config.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        out = (
            out.unsqueeze(0)
            .reshape(B, 1, out.shape[1], C)
            .permute(0, 2, 1, 3)
            .reshape(B, out.shape[1], C)
        )
        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
        out = self.proj_out(out)
        return x + out


class SDPAttnBlock(nn.Module):

    def __init__(self, in_channels):
        super().__init__()
        print(f"building SDPAttnBlock (sdp) with {in_channels} in_channels")
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(B, t.shape[1], 1, C)
            .permute(0, 2, 1, 3)
            .reshape(B * 1, t.shape[1], C)
            .contiguous(),
            (q, k, v),
        )
        out = F.scaled_dot_product_attention(q, k, v)

        out = (
            out.unsqueeze(0)
            .reshape(B, 1, out.shape[1], C)
            .permute(0, 2, 1, 3)
            .reshape(B, out.shape[1], C)
        )
        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
        out = self.proj_out(out)
        return x + out


def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in ["vanilla", "sdp", "xformers", "linear", "none"], f'attn_type {attn_type} unknown'
    if attn_type == "vanilla":
        assert attn_kwargs is None
        return AttnBlock(in_channels)
    elif attn_type == "sdp":
        return SDPAttnBlock(in_channels)
    elif attn_type == "xformers":
        return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        raise NotImplementedError()


class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False,
                 **ignore_kwargs):
        super().__init__()
        ### setup attention type
        if Config.attn_mode == AttnMode.SDP:
            attn_type = "sdp"
        elif Config.attn_mode == AttnMode.XFORMERS:
            attn_type = "xformers"
        else:
            attn_type = "vanilla"
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2 * z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 **ignorekwargs):
        super().__init__()
        ### setup attention type
        if Config.attn_mode == AttnMode.SDP:
            attn_type = "sdp"
        elif Config.attn_mode == AttnMode.XFORMERS:
            attn_type = "xformers"
        else:
            attn_type = "vanilla"
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class AutoencoderKL(nn.Module):

    def __init__(self, ddconfig, embed_dim):
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior
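A usage sketch for the autoencoder (illustrative only, not part of the upload; the ddconfig values are assumptions in the style of the Stable Diffusion KL-f8 autoencoder, not values read from this repository's configs):

import torch
from model.vae import AutoencoderKL

ddconfig = dict(
    double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
    ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0,
)
vae = AutoencoderKL(ddconfig, embed_dim=4)

img = torch.randn(1, 3, 256, 256)     # image tensor, typically scaled to [-1, 1]
posterior = vae.encode(img)           # DiagonalGaussianDistribution over latents
z = posterior.sample()                # [1, 4, 32, 32] latent (8x spatial downsampling)
rec = vae.decode(z)                   # reconstruction with the same shape as img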