MultiMatrix committed on
Commit 5d60839 · verified · 1 Parent(s): 5318034

Upload 19 files

.gitattributes CHANGED
@@ -6,3 +6,4 @@ assets/visual_results/bsr6.png filter=lfs diff=lfs merge=lfs -text
  assets/visual_results/tiled_sampling.png filter=lfs diff=lfs merge=lfs -text
  assets/visual_results/whole_image1.png filter=lfs diff=lfs merge=lfs -text
  assets/visual_results/whole_image2.png filter=lfs diff=lfs merge=lfs -text
+ model/open_clip/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
model/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from . import config
2
+
3
+ from .controlnet import ControlledUnetModel, ControlNet
4
+ from .vae import AutoencoderKL
5
+ from .clip import FrozenOpenCLIPEmbedder
6
+
7
+ from .cldm import ControlLDM
8
+ from .gaussian_diffusion import Diffusion
9
+
10
+ from .swinir import SwinIR
11
+ from .bsrnet import RRDBNet
12
+ from .scunet import SCUNet
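
These re-exports let downstream scripts pull every building block from the package root instead of the individual submodules; a minimal import sketch (assuming the repository root is on PYTHONPATH):

# Sketch only: relies on the re-exports in model/__init__.py above.
from model import ControlLDM, Diffusion, SwinIR, RRDBNet, SCUNet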
model/attention.py ADDED
@@ -0,0 +1,298 @@
1
+ from packaging import version
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn, einsum
5
+ from einops import rearrange, repeat
6
+ from typing import Optional, Any
7
+
8
+ from model.util import (
9
+ checkpoint, zero_module, exists, default
10
+ )
11
+ from model.config import Config, AttnMode
12
+
13
+
14
+ # CrossAttn precision handling
15
+ import os
16
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
17
+
18
+
19
+ # feedforward
20
+ class GEGLU(nn.Module):
21
+ def __init__(self, dim_in, dim_out):
22
+ super().__init__()
23
+ self.proj = nn.Linear(dim_in, dim_out * 2)
24
+
25
+ def forward(self, x):
26
+ x, gate = self.proj(x).chunk(2, dim=-1)
27
+ return x * F.gelu(gate)
28
+
29
+
30
+ class FeedForward(nn.Module):
31
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
32
+ super().__init__()
33
+ inner_dim = int(dim * mult)
34
+ dim_out = default(dim_out, dim)
35
+ project_in = nn.Sequential(
36
+ nn.Linear(dim, inner_dim),
37
+ nn.GELU()
38
+ ) if not glu else GEGLU(dim, inner_dim)
39
+
40
+ self.net = nn.Sequential(
41
+ project_in,
42
+ nn.Dropout(dropout),
43
+ nn.Linear(inner_dim, dim_out)
44
+ )
45
+
46
+ def forward(self, x):
47
+ return self.net(x)
48
+
49
+
50
+ def Normalize(in_channels):
51
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
52
+
53
+
54
+ class CrossAttention(nn.Module):
55
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
56
+ super().__init__()
57
+ print(f"Setting up {self.__class__.__name__} (vanilla). Query dim is {query_dim}, context_dim is {context_dim} and using "
58
+ f"{heads} heads.")
59
+ inner_dim = dim_head * heads
60
+ context_dim = default(context_dim, query_dim)
61
+
62
+ self.scale = dim_head ** -0.5
63
+ self.heads = heads
64
+
65
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
66
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
67
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
68
+
69
+ self.to_out = nn.Sequential(
70
+ nn.Linear(inner_dim, query_dim),
71
+ nn.Dropout(dropout)
72
+ )
73
+
74
+ def forward(self, x, context=None, mask=None):
75
+ h = self.heads
76
+
77
+ q = self.to_q(x)
78
+ context = default(context, x)
79
+ k = self.to_k(context)
80
+ v = self.to_v(context)
81
+
82
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
83
+
84
+ # force cast to fp32 to avoid overflowing
85
+ if _ATTN_PRECISION =="fp32":
86
+ # with torch.autocast(enabled=False, device_type = 'cuda'):
87
+ with torch.autocast(enabled=False, device_type="cuda" if str(x.device).startswith("cuda") else "cpu"):
88
+ q, k = q.float(), k.float()
89
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
90
+ else:
91
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
92
+
93
+ del q, k
94
+
95
+ if exists(mask):
96
+ mask = rearrange(mask, 'b ... -> b (...)')
97
+ max_neg_value = -torch.finfo(sim.dtype).max
98
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
99
+ sim.masked_fill_(~mask, max_neg_value)
100
+
101
+ # attention, what we cannot get enough of
102
+ sim = sim.softmax(dim=-1)
103
+
104
+ out = einsum('b i j, b j d -> b i d', sim, v)
105
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
106
+ return self.to_out(out)
107
+
108
+
109
+ class MemoryEfficientCrossAttention(nn.Module):
110
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
111
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
112
+ super().__init__()
113
+ print(f"Setting up {self.__class__.__name__} (xformers). Query dim is {query_dim}, context_dim is {context_dim} and using "
114
+ f"{heads} heads.")
115
+ inner_dim = dim_head * heads
116
+ context_dim = default(context_dim, query_dim)
117
+
118
+ self.heads = heads
119
+ self.dim_head = dim_head
120
+
121
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
122
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
123
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
124
+
125
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
126
+ self.attention_op: Optional[Any] = None
127
+
128
+ def forward(self, x, context=None, mask=None):
129
+ q = self.to_q(x)
130
+ context = default(context, x)
131
+ k = self.to_k(context)
132
+ v = self.to_v(context)
133
+
134
+ b, _, _ = q.shape
135
+ q, k, v = map(
136
+ lambda t: t.unsqueeze(3)
137
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
138
+ .permute(0, 2, 1, 3)
139
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
140
+ .contiguous(),
141
+ (q, k, v),
142
+ )
143
+
144
+ # actually compute the attention, what we cannot get enough of
145
+ out = Config.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
146
+
147
+ if exists(mask):
148
+ raise NotImplementedError
149
+ out = (
150
+ out.unsqueeze(0)
151
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
152
+ .permute(0, 2, 1, 3)
153
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
154
+ )
155
+ return self.to_out(out)
156
+
157
+
158
+ class SDPCrossAttention(nn.Module):
159
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
160
+ super().__init__()
161
+ print(f"Setting up {self.__class__.__name__} (sdp). Query dim is {query_dim}, context_dim is {context_dim} and using "
162
+ f"{heads} heads.")
163
+ inner_dim = dim_head * heads
164
+ context_dim = default(context_dim, query_dim)
165
+
166
+ self.heads = heads
167
+ self.dim_head = dim_head
168
+
169
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
170
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
171
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
172
+
173
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
174
+
175
+ def forward(self, x, context=None, mask=None):
176
+ q = self.to_q(x)
177
+ context = default(context, x)
178
+ k = self.to_k(context)
179
+ v = self.to_v(context)
180
+
181
+ b, _, _ = q.shape
182
+ q, k, v = map(
183
+ lambda t: t.unsqueeze(3)
184
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
185
+ .permute(0, 2, 1, 3)
186
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
187
+ .contiguous(),
188
+ (q, k, v),
189
+ )
190
+
191
+ # actually compute the attention, what we cannot get enough of
192
+ out = F.scaled_dot_product_attention(q, k, v)
193
+
194
+ if exists(mask):
195
+ raise NotImplementedError
196
+ out = (
197
+ out.unsqueeze(0)
198
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
199
+ .permute(0, 2, 1, 3)
200
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
201
+ )
202
+ return self.to_out(out)
203
+
204
+
205
+ class BasicTransformerBlock(nn.Module):
206
+ ATTENTION_MODES = {
207
+ AttnMode.VANILLA: CrossAttention, # vanilla attention
208
+ AttnMode.XFORMERS: MemoryEfficientCrossAttention,
209
+ AttnMode.SDP: SDPCrossAttention
210
+ }
211
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
212
+ disable_self_attn=False):
213
+ super().__init__()
214
+ attn_cls = self.ATTENTION_MODES[Config.attn_mode]
215
+ self.disable_self_attn = disable_self_attn
216
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
217
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
218
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
219
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
220
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
221
+ self.norm1 = nn.LayerNorm(dim)
222
+ self.norm2 = nn.LayerNorm(dim)
223
+ self.norm3 = nn.LayerNorm(dim)
224
+ self.checkpoint = checkpoint
225
+
226
+ def forward(self, x, context=None):
227
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
228
+
229
+ def _forward(self, x, context=None):
230
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
231
+ x = self.attn2(self.norm2(x), context=context) + x
232
+ x = self.ff(self.norm3(x)) + x
233
+ return x
234
+
235
+
236
+ class SpatialTransformer(nn.Module):
237
+ """
238
+ Transformer block for image-like data.
239
+ First, project the input (aka embedding)
240
+ and reshape to b, t, d.
241
+ Then apply standard transformer action.
242
+ Finally, reshape to image
243
+ NEW: use_linear for more efficiency instead of the 1x1 convs
244
+ """
245
+ def __init__(self, in_channels, n_heads, d_head,
246
+ depth=1, dropout=0., context_dim=None,
247
+ disable_self_attn=False, use_linear=False,
248
+ use_checkpoint=True):
249
+ super().__init__()
250
+ if exists(context_dim) and not isinstance(context_dim, list):
251
+ context_dim = [context_dim]
252
+ self.in_channels = in_channels
253
+ inner_dim = n_heads * d_head
254
+ self.norm = Normalize(in_channels)
255
+ if not use_linear:
256
+ self.proj_in = nn.Conv2d(in_channels,
257
+ inner_dim,
258
+ kernel_size=1,
259
+ stride=1,
260
+ padding=0)
261
+ else:
262
+ self.proj_in = nn.Linear(in_channels, inner_dim)
263
+
264
+ self.transformer_blocks = nn.ModuleList(
265
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
266
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
267
+ for d in range(depth)]
268
+ )
269
+ if not use_linear:
270
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
271
+ in_channels,
272
+ kernel_size=1,
273
+ stride=1,
274
+ padding=0))
275
+ else:
276
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
277
+ self.use_linear = use_linear
278
+
279
+ def forward(self, x, context=None):
280
+ # note: if no context is given, cross-attention defaults to self-attention
281
+ if not isinstance(context, list):
282
+ context = [context]
283
+ b, c, h, w = x.shape
284
+ x_in = x
285
+ x = self.norm(x)
286
+ if not self.use_linear:
287
+ x = self.proj_in(x)
288
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
289
+ if self.use_linear:
290
+ x = self.proj_in(x)
291
+ for i, block in enumerate(self.transformer_blocks):
292
+ x = block(x, context=context[i])
293
+ if self.use_linear:
294
+ x = self.proj_out(x)
295
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
296
+ if not self.use_linear:
297
+ x = self.proj_out(x)
298
+ return x + x_in
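
A shape-level sanity check for SpatialTransformer; the dimensions below are illustrative, and the sketch assumes the repo's model/util.py and model/config.py are importable:

# Sketch: SpatialTransformer returns a residual of the same shape as its input.
import torch
from model.attention import SpatialTransformer

st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1,
                        context_dim=1024, use_linear=True, use_checkpoint=False)
x = torch.randn(1, 320, 32, 32)     # (B, C, H, W) feature map
ctx = torch.randn(1, 77, 1024)      # (B, tokens, context_dim) text conditioning
out = st(x, context=ctx)            # residual output, same shape as the input
assert out.shape == x.shape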
model/bsrnet.py ADDED
@@ -0,0 +1,104 @@
1
+ # From BSRGAN
2
+ import functools
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.nn.init as init
7
+
8
+
9
+ def initialize_weights(net_l, scale=1):
10
+ if not isinstance(net_l, list):
11
+ net_l = [net_l]
12
+ for net in net_l:
13
+ for m in net.modules():
14
+ if isinstance(m, nn.Conv2d):
15
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
16
+ m.weight.data *= scale # for residual block
17
+ if m.bias is not None:
18
+ m.bias.data.zero_()
19
+ elif isinstance(m, nn.Linear):
20
+ init.kaiming_normal_(m.weight, a=0, mode='fan_in')
21
+ m.weight.data *= scale
22
+ if m.bias is not None:
23
+ m.bias.data.zero_()
24
+ elif isinstance(m, nn.BatchNorm2d):
25
+ init.constant_(m.weight, 1)
26
+ init.constant_(m.bias.data, 0.0)
27
+
28
+
29
+ def make_layer(block, n_layers):
30
+ layers = []
31
+ for _ in range(n_layers):
32
+ layers.append(block())
33
+ return nn.Sequential(*layers)
34
+
35
+
36
+ class ResidualDenseBlock_5C(nn.Module):
37
+ def __init__(self, nf=64, gc=32, bias=True):
38
+ super(ResidualDenseBlock_5C, self).__init__()
39
+ # gc: growth channel, i.e. intermediate channels
40
+ self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
41
+ self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
42
+ self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
43
+ self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
44
+ self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
45
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
46
+
47
+ # initialization
48
+ initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
49
+
50
+ def forward(self, x):
51
+ x1 = self.lrelu(self.conv1(x))
52
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
53
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
54
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
55
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
56
+ return x5 * 0.2 + x
57
+
58
+
59
+ class RRDB(nn.Module):
60
+ '''Residual in Residual Dense Block'''
61
+
62
+ def __init__(self, nf, gc=32):
63
+ super(RRDB, self).__init__()
64
+ self.RDB1 = ResidualDenseBlock_5C(nf, gc)
65
+ self.RDB2 = ResidualDenseBlock_5C(nf, gc)
66
+ self.RDB3 = ResidualDenseBlock_5C(nf, gc)
67
+
68
+ def forward(self, x):
69
+ out = self.RDB1(x)
70
+ out = self.RDB2(out)
71
+ out = self.RDB3(out)
72
+ return out * 0.2 + x
73
+
74
+
75
+ class RRDBNet(nn.Module):
76
+ def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
77
+ super(RRDBNet, self).__init__()
78
+ RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
79
+ self.sf = sf
80
+ print([in_nc, out_nc, nf, nb, gc, sf])
81
+
82
+ self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
83
+ self.RRDB_trunk = make_layer(RRDB_block_f, nb)
84
+ self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
85
+ #### upsampling
86
+ self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
87
+ if self.sf==4:
88
+ self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
89
+ self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
90
+ self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
91
+
92
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
93
+
94
+ def forward(self, x):
95
+ fea = self.conv_first(x)
96
+ trunk = self.trunk_conv(self.RRDB_trunk(fea))
97
+ fea = fea + trunk
98
+
99
+ fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
100
+ if self.sf==4:
101
+ fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
102
+ out = self.conv_last(self.lrelu(self.HRconv(fea)))
103
+
104
+ return out
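
Usage sketch for this stage-1 cleaning network (no pretrained weights loaded here; the input is a random tensor chosen only to show the shapes):

import torch
from model.bsrnet import RRDBNet

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4).eval()
lq = torch.rand(1, 3, 64, 64)       # low-quality RGB in [0, 1]
with torch.no_grad():
    hq = net(lq)                    # sf=4 upscaling -> (1, 3, 256, 256)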
model/cldm.py ADDED
@@ -0,0 +1,155 @@
1
+ from typing import Tuple, Set, List, Dict
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from model import (
7
+ ControlledUnetModel, ControlNet,
8
+ AutoencoderKL, FrozenOpenCLIPEmbedder
9
+ )
10
+ from utils.common import sliding_windows, count_vram_usage, gaussian_weights
11
+
12
+
13
+ def disabled_train(self: nn.Module) -> nn.Module:
14
+ """Overwrite model.train with this function to make sure train/eval mode
15
+ does not change anymore."""
16
+ return self
17
+
18
+
19
+ class ControlLDM(nn.Module):
20
+
21
+ def __init__(
22
+ self,
23
+ unet_cfg,
24
+ vae_cfg,
25
+ clip_cfg,
26
+ controlnet_cfg,
27
+ latent_scale_factor
28
+ ):
29
+ super().__init__()
30
+ self.unet = ControlledUnetModel(**unet_cfg)
31
+ self.vae = AutoencoderKL(**vae_cfg)
32
+ self.clip = FrozenOpenCLIPEmbedder(**clip_cfg)
33
+ self.controlnet = ControlNet(**controlnet_cfg)
34
+ self.scale_factor = latent_scale_factor
35
+ self.control_scales = [1.0] * 13
36
+
37
+ @torch.no_grad()
38
+ def load_pretrained_sd(self, sd: Dict[str, torch.Tensor]) -> Set[str]:
39
+ module_map = {
40
+ "unet": "model.diffusion_model",
41
+ "vae": "first_stage_model",
42
+ "clip": "cond_stage_model",
43
+ }
44
+ modules = [("unet", self.unet), ("vae", self.vae), ("clip", self.clip)]
45
+ used = set()
46
+ for name, module in modules:
47
+ init_sd = {}
48
+ scratch_sd = module.state_dict()
49
+ for key in scratch_sd:
50
+ target_key = ".".join([module_map[name], key])
51
+ init_sd[key] = sd[target_key].clone()
52
+ used.add(target_key)
53
+ module.load_state_dict(init_sd, strict=True)
54
+ unused = set(sd.keys()) - used
55
+ # NOTE: this is slightly different from the previous version, which didn't switch
56
+ # the UNet to eval mode or disable the requires_grad flag.
57
+ for module in [self.vae, self.clip, self.unet]:
58
+ module.eval()
59
+ module.train = disabled_train
60
+ for p in module.parameters():
61
+ p.requires_grad = False
62
+ return unused
63
+
64
+ @torch.no_grad()
65
+ def load_controlnet_from_ckpt(self, sd: Dict[str, torch.Tensor]) -> None:
66
+ self.controlnet.load_state_dict(sd, strict=True)
67
+
68
+ @torch.no_grad()
69
+ def load_controlnet_from_unet(self) -> Tuple[Set[str], Set[str]]:
70
+ unet_sd = self.unet.state_dict()
71
+ scratch_sd = self.controlnet.state_dict()
72
+ init_sd = {}
73
+ init_with_new_zero = set()
74
+ init_with_scratch = set()
75
+ for key in scratch_sd:
76
+ if key in unet_sd:
77
+ this, target = scratch_sd[key], unet_sd[key]
78
+ if this.size() == target.size():
79
+ init_sd[key] = target.clone()
80
+ else:
81
+ d_ic = this.size(1) - target.size(1)
82
+ oc, _, h, w = this.size()
83
+ zeros = torch.zeros((oc, d_ic, h, w), dtype=target.dtype)
84
+ init_sd[key] = torch.cat((target, zeros), dim=1)
85
+ init_with_new_zero.add(key)
86
+ else:
87
+ init_sd[key] = scratch_sd[key].clone()
88
+ init_with_scratch.add(key)
89
+ self.controlnet.load_state_dict(init_sd, strict=True)
90
+ return init_with_new_zero, init_with_scratch
91
+
92
+ def vae_encode(self, image: torch.Tensor, sample: bool=True) -> torch.Tensor:
93
+ if sample:
94
+ return self.vae.encode(image).sample() * self.scale_factor
95
+ else:
96
+ return self.vae.encode(image).mode() * self.scale_factor
97
+
98
+ def vae_encode_tiled(self, image: torch.Tensor, tile_size: int, tile_stride: int, sample: bool=True) -> torch.Tensor:
99
+ bs, _, h, w = image.shape
100
+ z = torch.zeros((bs, 4, h // 8, w // 8), dtype=torch.float32, device=image.device)
101
+ count = torch.zeros_like(z, dtype=torch.float32)
102
+ weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None]
103
+ weights = torch.tensor(weights, dtype=torch.float32, device=image.device)
104
+ tiles = sliding_windows(h // 8, w // 8, tile_size // 8, tile_stride // 8)
105
+ for hi, hi_end, wi, wi_end in tiles:
106
+ tile_image = image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8]
107
+ z[:, :, hi:hi_end, wi:wi_end] += self.vae_encode(tile_image, sample=sample) * weights
108
+ count[:, :, hi:hi_end, wi:wi_end] += weights
109
+ z.div_(count)
110
+ return z
111
+
112
+ def vae_decode(self, z: torch.Tensor) -> torch.Tensor:
113
+ return self.vae.decode(z / self.scale_factor)
114
+
115
+ @count_vram_usage
116
+ def vae_decode_tiled(self, z: torch.Tensor, tile_size: int, tile_stride: int) -> torch.Tensor:
117
+ bs, _, h, w = z.shape
118
+ image = torch.zeros((bs, 3, h * 8, w * 8), dtype=torch.float32, device=z.device)
119
+ count = torch.zeros_like(image, dtype=torch.float32)
120
+ weights = gaussian_weights(tile_size * 8, tile_size * 8)[None, None]
121
+ weights = torch.tensor(weights, dtype=torch.float32, device=z.device)
122
+ tiles = sliding_windows(h, w, tile_size, tile_stride)
123
+ for hi, hi_end, wi, wi_end in tiles:
124
+ tile_z = z[:, :, hi:hi_end, wi:wi_end]
125
+ image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += self.vae_decode(tile_z) * weights
126
+ count[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += weights
127
+ image.div_(count)
128
+ return image
129
+
130
+ def prepare_condition(self, clean: torch.Tensor, txt: List[str]) -> Dict[str, torch.Tensor]:
131
+ return dict(
132
+ c_txt=self.clip.encode(txt),
133
+ c_img=self.vae_encode(clean * 2 - 1, sample=False)
134
+ )
135
+
136
+ @count_vram_usage
137
+ def prepare_condition_tiled(self, clean: torch.Tensor, txt: List[str], tile_size: int, tile_stride: int) -> Dict[str, torch.Tensor]:
138
+ return dict(
139
+ c_txt=self.clip.encode(txt),
140
+ c_img=self.vae_encode_tiled(clean * 2 - 1, tile_size, tile_stride, sample=False)
141
+ )
142
+
143
+ def forward(self, x_noisy, t, cond):
144
+ c_txt = cond["c_txt"]
145
+ c_img = cond["c_img"]
146
+ control = self.controlnet(
147
+ x=x_noisy, hint=c_img,
148
+ timesteps=t, context=c_txt
149
+ )
150
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
151
+ eps = self.unet(
152
+ x=x_noisy, timesteps=t,
153
+ context=c_txt, control=control, only_mid_control=False
154
+ )
155
+ return eps
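
The pieces above are used in a fairly fixed order: load the frozen SD weights, optionally initialise the ControlNet from the UNet, then encode conditions and query the controlled UNet. A commented sketch; the config objects, checkpoint path and 0.18215 scale are placeholders typical of SD-style setups, not part of this file:

# cldm = ControlLDM(unet_cfg, vae_cfg, clip_cfg, controlnet_cfg, latent_scale_factor=0.18215)
# sd = torch.load("sd_checkpoint.ckpt", map_location="cpu")["state_dict"]   # placeholder path
# unused = cldm.load_pretrained_sd(sd)            # freezes UNet/VAE/CLIP, returns unused keys
# cond = cldm.prepare_condition(clean, txt=[""])  # clean: (B, 3, H, W) in [0, 1]
# t = torch.randint(0, 1000, (clean.shape[0],), device=clean.device)
# eps = cldm(x_noisy, t, cond)                    # epsilon prediction with ControlNet guidance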
model/clip.py ADDED
@@ -0,0 +1,65 @@
1
+ from typing import List
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.utils.checkpoint import checkpoint
5
+ from model.open_clip import CLIP, tokenize
6
+
7
+ ### pretrained model path
8
+ # _VITH14 = dict(
9
+ # laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
10
+ # )
11
+
12
+ class FrozenOpenCLIPEmbedder(nn.Module):
13
+ """
14
+ Uses the OpenCLIP transformer encoder for text
15
+ """
16
+ LAYERS = [
17
+ #"pooled",
18
+ "last",
19
+ "penultimate"
20
+ ]
21
+ def __init__(self, embed_dim, vision_cfg, text_cfg, layer="last"):
22
+ super().__init__()
23
+ assert layer in self.LAYERS
24
+ # model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
25
+ model = CLIP(embed_dim, dict(vision_cfg), dict(text_cfg))
26
+ del model.visual
27
+ self.model = model
28
+
29
+ self.layer = layer
30
+ if self.layer == "last":
31
+ self.layer_idx = 0
32
+ elif self.layer == "penultimate":
33
+ self.layer_idx = 1
34
+ else:
35
+ raise NotImplementedError()
36
+
37
+ def forward(self, tokens):
38
+ z = self.encode_with_transformer(tokens)
39
+ return z
40
+
41
+ def encode_with_transformer(self, text):
42
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
43
+ x = x + self.model.positional_embedding
44
+ x = x.permute(1, 0, 2) # NLD -> LND
45
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
46
+ x = x.permute(1, 0, 2) # LND -> NLD
47
+ x = self.model.ln_final(x)
48
+ return x
49
+
50
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
51
+ for i, r in enumerate(self.model.transformer.resblocks):
52
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
53
+ break
54
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
55
+ x = checkpoint(r, x, attn_mask)
56
+ else:
57
+ x = r(x, attn_mask=attn_mask)
58
+ return x
59
+
60
+ def encode(self, text: List[str]) -> torch.Tensor:
61
+ # convert a batch of text to tensor
62
+ tokens = tokenize(text)
63
+ # move tensor to model device
64
+ tokens = tokens.to(next(self.model.parameters()).device)
65
+ return self(tokens)
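
Construction sketch with OpenCLIP ViT-H/14-style text settings as used by SD 2.x; the exact numbers are assumptions here (in this repo they normally come from the yaml config), and building the model is heavyweight since the vision tower is instantiated before being deleted:

from model.clip import FrozenOpenCLIPEmbedder

clip = FrozenOpenCLIPEmbedder(
    embed_dim=1024,
    vision_cfg=dict(layers=32, width=1280, head_width=80, patch_size=14),
    text_cfg=dict(context_length=77, vocab_size=49408, width=1024, heads=16, layers=24),
    layer="penultimate",
)
z = clip.encode(["a photo of a cat"])   # -> (1, 77, 1024) per-token text embeddings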
model/config.py ADDED
@@ -0,0 +1,62 @@
1
+ import os
2
+ from typing import Optional, Literal
3
+ from types import ModuleType
4
+ import enum
5
+ from packaging import version
6
+
7
+ import torch
8
+
9
+ # collect system information
10
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
11
+ SDP_IS_AVAILABLE = True
12
+ else:
13
+ SDP_IS_AVAILABLE = False
14
+
15
+ try:
16
+ import xformers
17
+ import xformers.ops
18
+ XFORMERS_IS_AVAILBLE = True
19
+ except ImportError:
20
+ XFORMERS_IS_AVAILBLE = False
21
+
22
+
23
+ class AttnMode(enum.Enum):
24
+ SDP = 0
25
+ XFORMERS = 1
26
+ VANILLA = 2
27
+
28
+
29
+ class Config:
30
+ xformers: Optional[ModuleType] = None
31
+ attn_mode: AttnMode = AttnMode.VANILLA
32
+
33
+
34
+ # initialize attention mode
35
+ if SDP_IS_AVAILABLE:
36
+ Config.attn_mode = AttnMode.SDP
37
+ print(f"use sdp attention as default")
38
+ elif XFORMERS_IS_AVAILBLE:
39
+ Config.attn_mode = AttnMode.XFORMERS
40
+ print(f"use xformers attention as default")
41
+ else:
42
+ print(f"both sdp attention and xformers are not available, use vanilla attention (very expensive) as default")
43
+
44
+ if XFORMERS_IS_AVAILBLE:
45
+ Config.xformers = xformers
46
+
47
+
48
+ # user-specified attention mode
49
+ ATTN_MODE = os.environ.get("ATTN_MODE", None)
50
+ if ATTN_MODE is not None:
51
+ assert ATTN_MODE in ["vanilla", "sdp", "xformers"]
52
+ if ATTN_MODE == "sdp":
53
+ assert SDP_IS_AVAILABLE
54
+ Config.attn_mode = AttnMode.SDP
55
+ elif ATTN_MODE == "xformers":
56
+ assert XFORMERS_IS_AVAILBLE
57
+ Config.attn_mode = AttnMode.XFORMERS
58
+ else:
59
+ Config.attn_mode = AttnMode.VANILLA
60
+ print(f"set attention mode to {ATTN_MODE}")
61
+ else:
62
+ print("keep default attention mode")
model/controlnet.py ADDED
@@ -0,0 +1,277 @@
1
+ import torch
2
+ import torch as th
3
+ import torch.nn as nn
4
+
5
+ from model.util import (
6
+ conv_nd,
7
+ linear,
8
+ zero_module,
9
+ timestep_embedding,
10
+ exists
11
+ )
12
+ from model.attention import SpatialTransformer
13
+ from model.unet import (
14
+ TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock, UNetModel
15
+ )
16
+
17
+
18
+ class ControlledUnetModel(UNetModel):
19
+
20
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
21
+ hs = []
22
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
23
+ emb = self.time_embed(t_emb)
24
+ h = x.type(self.dtype)
25
+ for module in self.input_blocks:
26
+ h = module(h, emb, context)
27
+ hs.append(h)
28
+ h = self.middle_block(h, emb, context)
29
+
30
+ if control is not None:
31
+ h += control.pop()
32
+
33
+ for i, module in enumerate(self.output_blocks):
34
+ if only_mid_control or control is None:
35
+ h = torch.cat([h, hs.pop()], dim=1)
36
+ else:
37
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
38
+ h = module(h, emb, context)
39
+
40
+ h = h.type(x.dtype)
41
+ return self.out(h)
42
+
43
+
44
+ class ControlNet(nn.Module):
45
+
46
+ def __init__(
47
+ self,
48
+ image_size,
49
+ in_channels,
50
+ model_channels,
51
+ hint_channels,
52
+ num_res_blocks,
53
+ attention_resolutions,
54
+ dropout=0,
55
+ channel_mult=(1, 2, 4, 8),
56
+ conv_resample=True,
57
+ dims=2,
58
+ use_checkpoint=False,
59
+ use_fp16=False,
60
+ num_heads=-1,
61
+ num_head_channels=-1,
62
+ num_heads_upsample=-1,
63
+ use_scale_shift_norm=False,
64
+ resblock_updown=False,
65
+ use_new_attention_order=False,
66
+ use_spatial_transformer=False, # custom transformer support
67
+ transformer_depth=1, # custom transformer support
68
+ context_dim=None, # custom transformer support
69
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
70
+ legacy=True,
71
+ disable_self_attentions=None,
72
+ num_attention_blocks=None,
73
+ disable_middle_self_attn=False,
74
+ use_linear_in_transformer=False,
75
+ ):
76
+ super().__init__()
77
+ if use_spatial_transformer:
78
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
79
+
80
+ if context_dim is not None:
81
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
82
+ from omegaconf.listconfig import ListConfig
83
+ if type(context_dim) == ListConfig:
84
+ context_dim = list(context_dim)
85
+
86
+ if num_heads_upsample == -1:
87
+ num_heads_upsample = num_heads
88
+
89
+ if num_heads == -1:
90
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
91
+
92
+ if num_head_channels == -1:
93
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
94
+
95
+ self.dims = dims
96
+ self.image_size = image_size
97
+ self.in_channels = in_channels
98
+ self.model_channels = model_channels
99
+ if isinstance(num_res_blocks, int):
100
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
101
+ else:
102
+ if len(num_res_blocks) != len(channel_mult):
103
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
104
+ "as a list/tuple (per-level) with the same length as channel_mult")
105
+ self.num_res_blocks = num_res_blocks
106
+ if disable_self_attentions is not None:
107
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
108
+ assert len(disable_self_attentions) == len(channel_mult)
109
+ if num_attention_blocks is not None:
110
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
111
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
112
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
113
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
114
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
115
+ f"attention will still not be set.")
116
+
117
+ self.attention_resolutions = attention_resolutions
118
+ self.dropout = dropout
119
+ self.channel_mult = channel_mult
120
+ self.conv_resample = conv_resample
121
+ self.use_checkpoint = use_checkpoint
122
+ self.dtype = th.float16 if use_fp16 else th.float32
123
+ self.num_heads = num_heads
124
+ self.num_head_channels = num_head_channels
125
+ self.num_heads_upsample = num_heads_upsample
126
+ self.predict_codebook_ids = n_embed is not None
127
+
128
+ time_embed_dim = model_channels * 4
129
+ self.time_embed = nn.Sequential(
130
+ linear(model_channels, time_embed_dim),
131
+ nn.SiLU(),
132
+ linear(time_embed_dim, time_embed_dim),
133
+ )
134
+
135
+ self.input_blocks = nn.ModuleList(
136
+ [
137
+ TimestepEmbedSequential(
138
+ conv_nd(dims, in_channels + hint_channels, model_channels, 3, padding=1)
139
+ )
140
+ ]
141
+ )
142
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
143
+
144
+ self._feature_size = model_channels
145
+ input_block_chans = [model_channels]
146
+ ch = model_channels
147
+ ds = 1
148
+ for level, mult in enumerate(channel_mult):
149
+ for nr in range(self.num_res_blocks[level]):
150
+ layers = [
151
+ ResBlock(
152
+ ch,
153
+ time_embed_dim,
154
+ dropout,
155
+ out_channels=mult * model_channels,
156
+ dims=dims,
157
+ use_checkpoint=use_checkpoint,
158
+ use_scale_shift_norm=use_scale_shift_norm,
159
+ )
160
+ ]
161
+ ch = mult * model_channels
162
+ if ds in attention_resolutions:
163
+ if num_head_channels == -1:
164
+ dim_head = ch // num_heads
165
+ else:
166
+ num_heads = ch // num_head_channels
167
+ dim_head = num_head_channels
168
+ if legacy:
169
+ # num_heads = 1
170
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
171
+ if exists(disable_self_attentions):
172
+ disabled_sa = disable_self_attentions[level]
173
+ else:
174
+ disabled_sa = False
175
+
176
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
177
+ layers.append(
178
+ AttentionBlock(
179
+ ch,
180
+ use_checkpoint=use_checkpoint,
181
+ num_heads=num_heads,
182
+ num_head_channels=dim_head,
183
+ use_new_attention_order=use_new_attention_order,
184
+ ) if not use_spatial_transformer else SpatialTransformer(
185
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
186
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
187
+ use_checkpoint=use_checkpoint
188
+ )
189
+ )
190
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
191
+ self.zero_convs.append(self.make_zero_conv(ch))
192
+ self._feature_size += ch
193
+ input_block_chans.append(ch)
194
+ if level != len(channel_mult) - 1:
195
+ out_ch = ch
196
+ self.input_blocks.append(
197
+ TimestepEmbedSequential(
198
+ ResBlock(
199
+ ch,
200
+ time_embed_dim,
201
+ dropout,
202
+ out_channels=out_ch,
203
+ dims=dims,
204
+ use_checkpoint=use_checkpoint,
205
+ use_scale_shift_norm=use_scale_shift_norm,
206
+ down=True,
207
+ )
208
+ if resblock_updown
209
+ else Downsample(
210
+ ch, conv_resample, dims=dims, out_channels=out_ch
211
+ )
212
+ )
213
+ )
214
+ ch = out_ch
215
+ input_block_chans.append(ch)
216
+ self.zero_convs.append(self.make_zero_conv(ch))
217
+ ds *= 2
218
+ self._feature_size += ch
219
+
220
+ if num_head_channels == -1:
221
+ dim_head = ch // num_heads
222
+ else:
223
+ num_heads = ch // num_head_channels
224
+ dim_head = num_head_channels
225
+ if legacy:
226
+ # num_heads = 1
227
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
228
+ self.middle_block = TimestepEmbedSequential(
229
+ ResBlock(
230
+ ch,
231
+ time_embed_dim,
232
+ dropout,
233
+ dims=dims,
234
+ use_checkpoint=use_checkpoint,
235
+ use_scale_shift_norm=use_scale_shift_norm,
236
+ ),
237
+ AttentionBlock(
238
+ ch,
239
+ use_checkpoint=use_checkpoint,
240
+ num_heads=num_heads,
241
+ num_head_channels=dim_head,
242
+ use_new_attention_order=use_new_attention_order,
243
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
244
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
245
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
246
+ use_checkpoint=use_checkpoint
247
+ ),
248
+ ResBlock(
249
+ ch,
250
+ time_embed_dim,
251
+ dropout,
252
+ dims=dims,
253
+ use_checkpoint=use_checkpoint,
254
+ use_scale_shift_norm=use_scale_shift_norm,
255
+ ),
256
+ )
257
+ self.middle_block_out = self.make_zero_conv(ch)
258
+ self._feature_size += ch
259
+
260
+ def make_zero_conv(self, channels):
261
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
262
+
263
+ def forward(self, x, hint, timesteps, context, **kwargs):
264
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
265
+ emb = self.time_embed(t_emb)
266
+ x = torch.cat((x, hint), dim=1)
267
+ outs = []
268
+
269
+ h = x.type(self.dtype)
270
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
271
+ h = module(h, emb, context)
272
+ outs.append(zero_conv(h, emb, context))
273
+
274
+ h = self.middle_block(h, emb, context)
275
+ outs.append(self.middle_block_out(h, emb, context))
276
+
277
+ return outs
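
How the two classes interact: ControlNet emits one residual per zero conv (the input blocks plus the middle block, 13 tensors for the SD 2.1 UNet layout, matching the 13 control_scales in cldm.py), and ControlledUnetModel.forward pops them off the end of the list as it decodes. A commented sketch with placeholder tensors:

# control = controlnet(x=x_noisy, hint=c_img, timesteps=t, context=c_txt)
# assert len(control) == 13                 # one tensor per zero conv + middle block
# eps = unet(x=x_noisy, timesteps=t, context=c_txt,
#            control=control, only_mid_control=False)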
model/distributions.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
45
+ + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2, 3])
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
51
+ dim=[1, 2, 3])
52
+
53
+ def nll(self, sample, dims=[1,2,3]):
54
+ if self.deterministic:
55
+ return torch.Tensor([0.])
56
+ logtwopi = np.log(2.0 * np.pi)
57
+ return 0.5 * torch.sum(
58
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59
+ dim=dims)
60
+
61
+ def mode(self):
62
+ return self.mean
63
+
64
+
65
+ def normal_kl(mean1, logvar1, mean2, logvar2):
66
+ """
67
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68
+ Compute the KL divergence between two gaussians.
69
+ Shapes are automatically broadcasted, so batches can be compared to
70
+ scalars, among other use cases.
71
+ """
72
+ tensor = None
73
+ for obj in (mean1, logvar1, mean2, logvar2):
74
+ if isinstance(obj, torch.Tensor):
75
+ tensor = obj
76
+ break
77
+ assert tensor is not None, "at least one argument must be a Tensor"
78
+
79
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
80
+ # Tensors, but it does not work for torch.exp().
81
+ logvar1, logvar2 = [
82
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83
+ for x in (logvar1, logvar2)
84
+ ]
85
+
86
+ return 0.5 * (
87
+ -1.0
88
+ + logvar2
89
+ - logvar1
90
+ + torch.exp(logvar1 - logvar2)
91
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92
+ )
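
Usage sketch: the VAE encoder emits 2*C channels that DiagonalGaussianDistribution splits into mean and log-variance for sampling, mode extraction and KL terms (shapes below chosen only for illustration):

import torch
from model.distributions import DiagonalGaussianDistribution

params = torch.randn(2, 8, 32, 32)    # (B, 2*C, H, W): channels 0-3 mean, 4-7 log-variance
dist = DiagonalGaussianDistribution(params)
z = dist.sample()                     # (2, 4, 32, 32) reparameterised sample
kl = dist.kl()                        # (2,) KL against a standard normal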
model/gaussian_diffusion.py ADDED
@@ -0,0 +1,118 @@
1
+ from functools import partial
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from torch import nn
6
+ import numpy as np
7
+
8
+
9
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
10
+ if schedule == "linear":
11
+ betas = (
12
+ np.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=np.float64) ** 2
13
+ )
14
+
15
+ elif schedule == "cosine":
16
+ timesteps = (
17
+ np.arange(n_timestep + 1, dtype=np.float64) / n_timestep + cosine_s
18
+ )
19
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
20
+ alphas = np.cos(alphas) ** 2
21
+ alphas = alphas / alphas[0]
22
+ betas = 1 - alphas[1:] / alphas[:-1]
23
+ betas = np.clip(betas, a_min=0, a_max=0.999)
24
+
25
+ elif schedule == "sqrt_linear":
26
+ betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
27
+ elif schedule == "sqrt":
28
+ betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) ** 0.5
29
+ else:
30
+ raise ValueError(f"schedule '{schedule}' unknown.")
31
+ return betas
32
+
33
+
34
+ def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int]) -> torch.Tensor:
35
+ b, *_ = t.shape
36
+ out = a.gather(-1, t)
37
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
38
+
39
+
40
+ class Diffusion(nn.Module):
41
+
42
+ def __init__(
43
+ self,
44
+ timesteps=1000,
45
+ beta_schedule="linear",
46
+ loss_type="l2",
47
+ linear_start=1e-4,
48
+ linear_end=2e-2,
49
+ cosine_s=8e-3,
50
+ parameterization="eps"
51
+ ):
52
+ super().__init__()
53
+ self.num_timesteps = timesteps
54
+ self.beta_schedule = beta_schedule
55
+ self.linear_start = linear_start
56
+ self.linear_end = linear_end
57
+ self.cosine_s = cosine_s
58
+ assert parameterization in ["eps", "x0", "v"], "currently only supporting 'eps' and 'x0' and 'v'"
59
+ self.parameterization = parameterization
60
+ self.loss_type = loss_type
61
+
62
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
63
+ cosine_s=cosine_s)
64
+ alphas = 1. - betas
65
+ alphas_cumprod = np.cumprod(alphas, axis=0)
66
+ sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
67
+ sqrt_one_minus_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
68
+
69
+ self.betas = betas
70
+ self.register("sqrt_alphas_cumprod", sqrt_alphas_cumprod)
71
+ self.register("sqrt_one_minus_alphas_cumprod", sqrt_one_minus_alphas_cumprod)
72
+
73
+ def register(self, name: str, value: np.ndarray) -> None:
74
+ self.register_buffer(name, torch.tensor(value, dtype=torch.float32))
75
+
76
+ def q_sample(self, x_start, t, noise):
77
+ return (
78
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
79
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
80
+ )
81
+
82
+ def get_v(self, x, noise, t):
83
+ return (
84
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
85
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
86
+ )
87
+
88
+ def get_loss(self, pred, target, mean=True):
89
+ if self.loss_type == 'l1':
90
+ loss = (target - pred).abs()
91
+ if mean:
92
+ loss = loss.mean()
93
+ elif self.loss_type == 'l2':
94
+ if mean:
95
+ loss = torch.nn.functional.mse_loss(target, pred)
96
+ else:
97
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
98
+ else:
99
+ raise NotImplementedError("unknown loss type '{loss_type}'")
100
+
101
+ return loss
102
+
103
+ def p_losses(self, model, x_start, t, cond):
104
+ noise = torch.randn_like(x_start)
105
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
106
+ model_output = model(x_noisy, t, cond)
107
+
108
+ if self.parameterization == "x0":
109
+ target = x_start
110
+ elif self.parameterization == "eps":
111
+ target = noise
112
+ elif self.parameterization == "v":
113
+ target = self.get_v(x_start, noise, t)
114
+ else:
115
+ raise NotImplementedError()
116
+
117
+ loss_simple = self.get_loss(model_output, target, mean=False).mean()
118
+ return loss_simple
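
A minimal forward-diffusion sketch using the class above (shapes are illustrative): q_sample mixes a clean latent with noise according to the cumulative alphas registered in __init__.

import torch
from model.gaussian_diffusion import Diffusion

diffusion = Diffusion(timesteps=1000, beta_schedule="linear", parameterization="eps")
x0 = torch.randn(4, 4, 64, 64)                      # clean latents
t = torch.randint(0, diffusion.num_timesteps, (4,))
noise = torch.randn_like(x0)
x_t = diffusion.q_sample(x0, t, noise)              # noisy latents used as the model input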
model/open_clip/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .model import CLIP
2
+ from .tokenizer import tokenize
3
+
4
+ __all__ = ["CLIP", "tokenize"]
model/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
model/open_clip/model.py ADDED
@@ -0,0 +1,206 @@
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
+ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, VisionTransformer, TextTransformer
14
+
15
+
16
+ @dataclass
17
+ class CLIPVisionCfg:
18
+ layers: Union[Tuple[int, int, int, int], int] = 12
19
+ width: int = 768
20
+ head_width: int = 64
21
+ mlp_ratio: float = 4.0
22
+ patch_size: int = 16
23
+ image_size: Union[Tuple[int, int], int] = 224
24
+
25
+ ls_init_value: Optional[float] = None # layer scale initial value
26
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
27
+ input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
28
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
29
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
30
+ n_queries: int = 256 # n_queries for attentional pooler
31
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
32
+ output_tokens: bool = False
33
+
34
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
35
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
36
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
37
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
38
+ timm_proj_bias: bool = False # enable bias final projection
39
+ timm_drop: float = 0. # head dropout
40
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
41
+
42
+
43
+ @dataclass
44
+ class CLIPTextCfg:
45
+ context_length: int = 77
46
+ vocab_size: int = 49408
47
+ width: int = 512
48
+ heads: int = 8
49
+ layers: int = 12
50
+ ls_init_value: Optional[float] = None # layer scale initial value
51
+ hf_model_name: str = None
52
+ hf_tokenizer_name: str = None
53
+ hf_model_pretrained: bool = True
54
+ proj: str = 'mlp'
55
+ pooler_type: str = 'mean_pooler'
56
+ embed_cls: bool = False
57
+ pad_id: int = 0
58
+ output_tokens: bool = False
59
+
60
+
61
+ def get_cast_dtype(precision: str):
62
+ cast_dtype = None
63
+ if precision == 'bf16':
64
+ cast_dtype = torch.bfloat16
65
+ elif precision == 'fp16':
66
+ cast_dtype = torch.float16
67
+ return cast_dtype
68
+
69
+
70
+ def _build_vision_tower(
71
+ embed_dim: int,
72
+ vision_cfg: CLIPVisionCfg,
73
+ quick_gelu: bool = False,
74
+ cast_dtype: Optional[torch.dtype] = None
75
+ ):
76
+ if isinstance(vision_cfg, dict):
77
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
78
+
79
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
80
+ # memory efficient in recent PyTorch releases (>= 1.10).
81
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
82
+ act_layer = QuickGELU if quick_gelu else nn.GELU
83
+
84
+ vision_heads = vision_cfg.width // vision_cfg.head_width
85
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
86
+ visual = VisionTransformer(
87
+ image_size=vision_cfg.image_size,
88
+ patch_size=vision_cfg.patch_size,
89
+ width=vision_cfg.width,
90
+ layers=vision_cfg.layers,
91
+ heads=vision_heads,
92
+ mlp_ratio=vision_cfg.mlp_ratio,
93
+ ls_init_value=vision_cfg.ls_init_value,
94
+ patch_dropout=vision_cfg.patch_dropout,
95
+ input_patchnorm=vision_cfg.input_patchnorm,
96
+ global_average_pool=vision_cfg.global_average_pool,
97
+ attentional_pool=vision_cfg.attentional_pool,
98
+ n_queries=vision_cfg.n_queries,
99
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
100
+ output_tokens=vision_cfg.output_tokens,
101
+ output_dim=embed_dim,
102
+ act_layer=act_layer,
103
+ norm_layer=norm_layer,
104
+ )
105
+
106
+ return visual
107
+
108
+
109
+ def _build_text_tower(
110
+ embed_dim: int,
111
+ text_cfg: CLIPTextCfg,
112
+ quick_gelu: bool = False,
113
+ cast_dtype: Optional[torch.dtype] = None,
114
+ ):
115
+ if isinstance(text_cfg, dict):
116
+ text_cfg = CLIPTextCfg(**text_cfg)
117
+
118
+ act_layer = QuickGELU if quick_gelu else nn.GELU
119
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
120
+
121
+ text = TextTransformer(
122
+ context_length=text_cfg.context_length,
123
+ vocab_size=text_cfg.vocab_size,
124
+ width=text_cfg.width,
125
+ heads=text_cfg.heads,
126
+ layers=text_cfg.layers,
127
+ ls_init_value=text_cfg.ls_init_value,
128
+ output_dim=embed_dim,
129
+ embed_cls=text_cfg.embed_cls,
130
+ output_tokens=text_cfg.output_tokens,
131
+ pad_id=text_cfg.pad_id,
132
+ act_layer=act_layer,
133
+ norm_layer=norm_layer,
134
+ )
135
+ return text
136
+
137
+
138
+ class CLIP(nn.Module):
139
+ output_dict: torch.jit.Final[bool]
140
+
141
+ def __init__(
142
+ self,
143
+ embed_dim: int,
144
+ vision_cfg: CLIPVisionCfg,
145
+ text_cfg: CLIPTextCfg,
146
+ quick_gelu: bool = False,
147
+ cast_dtype: Optional[torch.dtype] = None,
148
+ output_dict: bool = False,
149
+ ):
150
+ super().__init__()
151
+ self.output_dict = output_dict
152
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
153
+
154
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
155
+ self.transformer = text.transformer
156
+ self.context_length = text.context_length
157
+ self.vocab_size = text.vocab_size
158
+ self.token_embedding = text.token_embedding
159
+ self.positional_embedding = text.positional_embedding
160
+ self.ln_final = text.ln_final
161
+ self.text_projection = text.text_projection
162
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
163
+
164
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
165
+
166
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
167
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
168
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
169
+
170
+ @torch.jit.ignore
171
+ def set_grad_checkpointing(self, enable=True):
172
+ self.visual.set_grad_checkpointing(enable)
173
+ self.transformer.grad_checkpointing = enable
174
+
175
+ def encode_image(self, image, normalize: bool = False):
176
+ features = self.visual(image)
177
+ return F.normalize(features, dim=-1) if normalize else features
178
+
179
+ def encode_text(self, text, normalize: bool = False):
180
+ cast_dtype = self.transformer.get_cast_dtype()
181
+
182
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
183
+
184
+ x = x + self.positional_embedding.to(cast_dtype)
185
+ x = x.permute(1, 0, 2) # NLD -> LND
186
+ x = self.transformer(x, attn_mask=self.attn_mask)
187
+ x = x.permute(1, 0, 2) # LND -> NLD
188
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
189
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
190
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
191
+ return F.normalize(x, dim=-1) if normalize else x
192
+
193
+ def forward(
194
+ self,
195
+ image: Optional[torch.Tensor] = None,
196
+ text: Optional[torch.Tensor] = None,
197
+ ):
198
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
199
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
200
+ if self.output_dict:
201
+ return {
202
+ "image_features": image_features,
203
+ "text_features": text_features,
204
+ "logit_scale": self.logit_scale.exp()
205
+ }
206
+ return image_features, text_features, self.logit_scale.exp()
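
FrozenOpenCLIPEmbedder in model/clip.py passes plain dicts into this constructor; _build_vision_tower / _build_text_tower turn them back into the dataclasses above. A small sketch with ViT-H/14-like values (assumed here, not taken from this commit):

from model.open_clip.model import CLIPVisionCfg, CLIPTextCfg

vision_cfg = CLIPVisionCfg(layers=32, width=1280, head_width=80, patch_size=14)
text_cfg = CLIPTextCfg(width=1024, heads=16, layers=24)
# CLIP(embed_dim=1024, vision_cfg=vision_cfg, text_cfg=text_cfg) then assembles both towers;
# only the text tower survives in FrozenOpenCLIPEmbedder, which deletes model.visual.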
model/open_clip/tokenizer.py ADDED
@@ -0,0 +1,214 @@
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ from functools import lru_cache
9
+ from typing import Union, List
10
+
11
+ import ftfy
12
+ import regex as re
13
+ import torch
14
+
15
+ # https://stackoverflow.com/q/62691279
16
+ import os
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+
19
+
20
+ @lru_cache()
21
+ def default_bpe():
22
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23
+
24
+
25
+ @lru_cache()
26
+ def bytes_to_unicode():
27
+ """
28
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
29
+ The reversible bpe codes work on unicode strings.
30
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32
+ This is a significant percentage of your normal, say, 32K bpe vocab.
33
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
35
+ """
36
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37
+ cs = bs[:]
38
+ n = 0
39
+ for b in range(2**8):
40
+ if b not in bs:
41
+ bs.append(b)
42
+ cs.append(2**8+n)
43
+ n += 1
44
+ cs = [chr(n) for n in cs]
45
+ return dict(zip(bs, cs))
46
+
47
+
48
+ def get_pairs(word):
49
+ """Return set of symbol pairs in a word.
50
+ Word is represented as tuple of symbols (symbols being variable-length strings).
51
+ """
52
+ pairs = set()
53
+ prev_char = word[0]
54
+ for char in word[1:]:
55
+ pairs.add((prev_char, char))
56
+ prev_char = char
57
+ return pairs
58
+
59
+
60
+ def basic_clean(text):
61
+ text = ftfy.fix_text(text)
62
+ text = html.unescape(html.unescape(text))
63
+ return text.strip()
64
+
65
+
66
+ def whitespace_clean(text):
67
+ text = re.sub(r'\s+', ' ', text)
68
+ text = text.strip()
69
+ return text
70
+
71
+
72
+ class SimpleTokenizer(object):
73
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74
+ self.byte_encoder = bytes_to_unicode()
75
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77
+ merges = merges[1:49152-256-2+1]
78
+ merges = [tuple(merge.split()) for merge in merges]
79
+ vocab = list(bytes_to_unicode().values())
80
+ vocab = vocab + [v+'</w>' for v in vocab]
81
+ for merge in merges:
82
+ vocab.append(''.join(merge))
83
+ if not special_tokens:
84
+ special_tokens = ['<start_of_text>', '<end_of_text>']
85
+ else:
86
+ special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87
+ vocab.extend(special_tokens)
88
+ self.encoder = dict(zip(vocab, range(len(vocab))))
89
+ self.decoder = {v: k for k, v in self.encoder.items()}
90
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
91
+ self.cache = {t:t for t in special_tokens}
92
+ special = "|".join(special_tokens)
93
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94
+
95
+ self.vocab_size = len(self.encoder)
96
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
97
+
98
+ def bpe(self, token):
99
+ if token in self.cache:
100
+ return self.cache[token]
101
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102
+ pairs = get_pairs(word)
103
+
104
+ if not pairs:
105
+ return token+'</w>'
106
+
107
+ while True:
108
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109
+ if bigram not in self.bpe_ranks:
110
+ break
111
+ first, second = bigram
112
+ new_word = []
113
+ i = 0
114
+ while i < len(word):
115
+ try:
116
+ j = word.index(first, i)
117
+ new_word.extend(word[i:j])
118
+ i = j
119
+ except ValueError:
120
+ new_word.extend(word[i:])
121
+ break
122
+
123
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
124
+ new_word.append(first+second)
125
+ i += 2
126
+ else:
127
+ new_word.append(word[i])
128
+ i += 1
129
+ new_word = tuple(new_word)
130
+ word = new_word
131
+ if len(word) == 1:
132
+ break
133
+ else:
134
+ pairs = get_pairs(word)
135
+ word = ' '.join(word)
136
+ self.cache[token] = word
137
+ return word
138
+
139
+ def encode(self, text):
140
+ bpe_tokens = []
141
+ text = whitespace_clean(basic_clean(text)).lower()
142
+ for token in re.findall(self.pat, text):
143
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145
+ return bpe_tokens
146
+
147
+ def decode(self, tokens):
148
+ text = ''.join([self.decoder[token] for token in tokens])
149
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150
+ return text
151
+
152
+
153
+ _tokenizer = SimpleTokenizer()
154
+
155
+ def decode(output_ids: torch.Tensor):
156
+ output_ids = output_ids.cpu().numpy()
157
+ return _tokenizer.decode(output_ids)
158
+
159
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
160
+ """
161
+ Returns the tokenized representation of given input string(s)
162
+
163
+ Parameters
164
+ ----------
165
+ texts : Union[str, List[str]]
166
+ An input string or a list of input strings to tokenize
167
+ context_length : int
168
+ The context length to use; all CLIP models use 77 as the context length
169
+
170
+ Returns
171
+ -------
172
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
173
+ """
174
+ if isinstance(texts, str):
175
+ texts = [texts]
176
+
177
+ sot_token = _tokenizer.encoder["<start_of_text>"]
178
+ eot_token = _tokenizer.encoder["<end_of_text>"]
179
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
180
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
181
+
182
+ for i, tokens in enumerate(all_tokens):
183
+ if len(tokens) > context_length:
184
+ tokens = tokens[:context_length] # Truncate
185
+ tokens[-1] = eot_token
186
+ result[i, :len(tokens)] = torch.tensor(tokens)
187
+
188
+ return result
189
+
190
+
191
+ class HFTokenizer:
192
+ """HuggingFace tokenizer wrapper"""
193
+
194
+ def __init__(self, tokenizer_name: str):
195
+ from transformers import AutoTokenizer
196
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
197
+
198
+ def save_pretrained(self, dest):
199
+ self.tokenizer.save_pretrained(dest)
200
+
201
+ def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
202
+ # same cleaning as for default tokenizer, except lowercasing
203
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
204
+ if isinstance(texts, str):
205
+ texts = [texts]
206
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
207
+ input_ids = self.tokenizer(
208
+ texts,
209
+ return_tensors='pt',
210
+ max_length=context_length,
211
+ padding='max_length',
212
+ truncation=True,
213
+ ).input_ids
214
+ return input_ids
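A minimal usage sketch for the tokenizer above (not part of the uploaded files); it assumes the module is importable as model.open_clip.tokenizer and that bpe_simple_vocab_16e6.txt.gz sits next to it, as in this repository layout.

from model.open_clip.tokenizer import tokenize, decode

# Two captions -> a [2, 77] LongTensor, padded with zeros after <end_of_text>.
tokens = tokenize(["a photo of a cat", "a blurry photo of a dog"])
print(tokens.shape)  # torch.Size([2, 77])

# decode() maps the ids back to text (special tokens and id-0 padding included).
print(decode(tokens[0]))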
model/open_clip/transformer.py ADDED
@@ -0,0 +1,736 @@
1
+ import collections
2
+ from collections import OrderedDict
3
+ import math
4
+ from typing import Callable, Optional, Sequence, Tuple
5
+ from itertools import repeat
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.utils.checkpoint import checkpoint
11
+
12
+ # From PyTorch internals
13
+ def _ntuple(n):
14
+ def parse(x):
15
+ if isinstance(x, collections.abc.Iterable):
16
+ return x
17
+ return tuple(repeat(x, n))
18
+ return parse
19
+
20
+ to_2tuple = _ntuple(2)
21
+
22
+
23
+ class LayerNormFp32(nn.LayerNorm):
24
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
25
+
26
+ def forward(self, x: torch.Tensor):
27
+ orig_type = x.dtype
28
+ x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
29
+ return x.to(orig_type)
30
+
31
+
32
+ class LayerNorm(nn.LayerNorm):
33
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
34
+
35
+ def forward(self, x: torch.Tensor):
36
+ orig_type = x.dtype
37
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
38
+ return x.to(orig_type)
39
+
40
+
41
+ class QuickGELU(nn.Module):
42
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
43
+ def forward(self, x: torch.Tensor):
44
+ return x * torch.sigmoid(1.702 * x)
45
+
46
+
47
+ class LayerScale(nn.Module):
48
+ def __init__(self, dim, init_values=1e-5, inplace=False):
49
+ super().__init__()
50
+ self.inplace = inplace
51
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
52
+
53
+ def forward(self, x):
54
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
55
+
56
+
57
+ class PatchDropout(nn.Module):
58
+ """
59
+ https://arxiv.org/abs/2212.00794
60
+ """
61
+
62
+ def __init__(self, prob, exclude_first_token=True):
63
+ super().__init__()
64
+ assert 0 <= prob < 1.
65
+ self.prob = prob
66
+ self.exclude_first_token = exclude_first_token # exclude CLS token
67
+
68
+ def forward(self, x):
69
+ if not self.training or self.prob == 0.:
70
+ return x
71
+
72
+ if self.exclude_first_token:
73
+ cls_tokens, x = x[:, :1], x[:, 1:]
74
+ else:
75
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
76
+
77
+ batch = x.size()[0]
78
+ num_tokens = x.size()[1]
79
+
80
+ batch_indices = torch.arange(batch)
81
+ batch_indices = batch_indices[..., None]
82
+
83
+ keep_prob = 1 - self.prob
84
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
85
+
86
+ rand = torch.randn(batch, num_tokens)
87
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
88
+
89
+ x = x[batch_indices, patch_indices_keep]
90
+
91
+ if self.exclude_first_token:
92
+ x = torch.cat((cls_tokens, x), dim=1)
93
+
94
+ return x
95
+
96
+
97
+ class Attention(nn.Module):
98
+ def __init__(
99
+ self,
100
+ dim,
101
+ num_heads=8,
102
+ qkv_bias=True,
103
+ scaled_cosine=False,
104
+ scale_heads=False,
105
+ logit_scale_max=math.log(1. / 0.01),
106
+ attn_drop=0.,
107
+ proj_drop=0.
108
+ ):
109
+ super().__init__()
110
+ self.scaled_cosine = scaled_cosine
111
+ self.scale_heads = scale_heads
112
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
113
+ self.num_heads = num_heads
114
+ self.head_dim = dim // num_heads
115
+ self.scale = self.head_dim ** -0.5
116
+ self.logit_scale_max = logit_scale_max
117
+
118
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
119
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
120
+ if qkv_bias:
121
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
122
+ else:
123
+ self.in_proj_bias = None
124
+
125
+ if self.scaled_cosine:
126
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
127
+ else:
128
+ self.logit_scale = None
129
+ self.attn_drop = nn.Dropout(attn_drop)
130
+ if self.scale_heads:
131
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
132
+ else:
133
+ self.head_scale = None
134
+ self.out_proj = nn.Linear(dim, dim)
135
+ self.out_drop = nn.Dropout(proj_drop)
136
+
137
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
138
+ L, N, C = x.shape
139
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
140
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
141
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
142
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
143
+
144
+ if self.logit_scale is not None:
145
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
146
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
147
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
148
+ attn = attn.view(-1, L, L)
149
+ else:
150
+ q = q * self.scale
151
+ attn = torch.bmm(q, k.transpose(-1, -2))
152
+
153
+ if attn_mask is not None:
154
+ if attn_mask.dtype == torch.bool:
155
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
156
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
157
+ attn_mask = new_attn_mask
158
+ attn += attn_mask
159
+
160
+ attn = attn.softmax(dim=-1)
161
+ attn = self.attn_drop(attn)
162
+
163
+ x = torch.bmm(attn, v)
164
+ if self.head_scale is not None:
165
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
166
+ x = x.view(-1, L, C)
167
+ x = x.transpose(0, 1).reshape(L, N, C)
168
+ x = self.out_proj(x)
169
+ x = self.out_drop(x)
170
+ return x
171
+
172
+
173
+ class AttentionalPooler(nn.Module):
174
+ def __init__(
175
+ self,
176
+ d_model: int,
177
+ context_dim: int,
178
+ n_head: int = 8,
179
+ n_queries: int = 256,
180
+ norm_layer: Callable = LayerNorm
181
+ ):
182
+ super().__init__()
183
+ self.query = nn.Parameter(torch.randn(n_queries, d_model))
184
+ self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
185
+ self.ln_q = norm_layer(d_model)
186
+ self.ln_k = norm_layer(context_dim)
187
+
188
+ def forward(self, x: torch.Tensor):
189
+ x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
190
+ N = x.shape[1]
191
+ q = self.ln_q(self.query)
192
+ out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
193
+ return out.permute(1, 0, 2) # LND -> NLD
194
+
195
+ def _repeat(self, query, N: int):
196
+ return query.unsqueeze(1).repeat(1, N, 1)
197
+
198
+
199
+ class ResidualAttentionBlock(nn.Module):
200
+ def __init__(
201
+ self,
202
+ d_model: int,
203
+ n_head: int,
204
+ mlp_ratio: float = 4.0,
205
+ ls_init_value: float = None,
206
+ act_layer: Callable = nn.GELU,
207
+ norm_layer: Callable = LayerNorm,
208
+ is_cross_attention: bool = False,
209
+ ):
210
+ super().__init__()
211
+
212
+ self.ln_1 = norm_layer(d_model)
213
+ self.attn = nn.MultiheadAttention(d_model, n_head)
214
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
215
+ if is_cross_attention:
216
+ self.ln_1_kv = norm_layer(d_model)
217
+
218
+ self.ln_2 = norm_layer(d_model)
219
+ mlp_width = int(d_model * mlp_ratio)
220
+ self.mlp = nn.Sequential(OrderedDict([
221
+ ("c_fc", nn.Linear(d_model, mlp_width)),
222
+ ("gelu", act_layer()),
223
+ ("c_proj", nn.Linear(mlp_width, d_model))
224
+ ]))
225
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
226
+
227
+ def attention(
228
+ self,
229
+ q_x: torch.Tensor,
230
+ k_x: Optional[torch.Tensor] = None,
231
+ v_x: Optional[torch.Tensor] = None,
232
+ attn_mask: Optional[torch.Tensor] = None,
233
+ ):
234
+ k_x = k_x if k_x is not None else q_x
235
+ v_x = v_x if v_x is not None else q_x
236
+
237
+ attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
238
+ return self.attn(
239
+ q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
240
+ )[0]
241
+
242
+ def forward(
243
+ self,
244
+ q_x: torch.Tensor,
245
+ k_x: Optional[torch.Tensor] = None,
246
+ v_x: Optional[torch.Tensor] = None,
247
+ attn_mask: Optional[torch.Tensor] = None,
248
+ ):
249
+ k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
250
+ v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
251
+
252
+ x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
253
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
254
+ return x
255
+
256
+
257
+ class CustomResidualAttentionBlock(nn.Module):
258
+ def __init__(
259
+ self,
260
+ d_model: int,
261
+ n_head: int,
262
+ mlp_ratio: float = 4.0,
263
+ ls_init_value: float = None,
264
+ act_layer: Callable = nn.GELU,
265
+ norm_layer: Callable = LayerNorm,
266
+ scale_cosine_attn: bool = False,
267
+ scale_heads: bool = False,
268
+ scale_attn: bool = False,
269
+ scale_fc: bool = False,
270
+ ):
271
+ super().__init__()
272
+
273
+ self.ln_1 = norm_layer(d_model)
274
+ self.attn = Attention(
275
+ d_model, n_head,
276
+ scaled_cosine=scale_cosine_attn,
277
+ scale_heads=scale_heads,
278
+ )
279
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
280
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
281
+
282
+ self.ln_2 = norm_layer(d_model)
283
+ mlp_width = int(d_model * mlp_ratio)
284
+ self.mlp = nn.Sequential(OrderedDict([
285
+ ("c_fc", nn.Linear(d_model, mlp_width)),
286
+ ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
287
+ ("gelu", act_layer()),
288
+ ("c_proj", nn.Linear(mlp_width, d_model))
289
+ ]))
290
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
291
+
292
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
293
+ x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
294
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
295
+ return x
296
+
297
+
298
+ class Transformer(nn.Module):
299
+ def __init__(
300
+ self,
301
+ width: int,
302
+ layers: int,
303
+ heads: int,
304
+ mlp_ratio: float = 4.0,
305
+ ls_init_value: float = None,
306
+ act_layer: Callable = nn.GELU,
307
+ norm_layer: Callable = LayerNorm,
308
+ ):
309
+ super().__init__()
310
+ self.width = width
311
+ self.layers = layers
312
+ self.grad_checkpointing = False
313
+
314
+ self.resblocks = nn.ModuleList([
315
+ ResidualAttentionBlock(
316
+ width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
317
+ for _ in range(layers)
318
+ ])
319
+
320
+ def get_cast_dtype(self) -> torch.dtype:
321
+ if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
322
+ return self.resblocks[0].mlp.c_fc.int8_original_dtype
323
+ return self.resblocks[0].mlp.c_fc.weight.dtype
324
+
325
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
326
+ for r in self.resblocks:
327
+ if self.grad_checkpointing and not torch.jit.is_scripting():
328
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
329
+ x = checkpoint(r, x, None, None, attn_mask)
330
+ else:
331
+ x = r(x, attn_mask=attn_mask)
332
+ return x
333
+
334
+
335
+ class VisionTransformer(nn.Module):
336
+ output_tokens: torch.jit.Final[bool]
337
+
338
+ def __init__(
339
+ self,
340
+ image_size: int,
341
+ patch_size: int,
342
+ width: int,
343
+ layers: int,
344
+ heads: int,
345
+ mlp_ratio: float,
346
+ ls_init_value: float = None,
347
+ global_average_pool: bool = False,
348
+ attentional_pool: bool = False,
349
+ n_queries: int = 256,
350
+ attn_pooler_heads: int = 8,
351
+ output_dim: int = 512,
352
+ patch_dropout: float = 0.,
353
+ input_patchnorm: bool = False,
354
+ act_layer: Callable = nn.GELU,
355
+ norm_layer: Callable = LayerNorm,
356
+ output_tokens: bool = False
357
+ ):
358
+ super().__init__()
359
+ self.output_tokens = output_tokens
360
+ image_height, image_width = self.image_size = to_2tuple(image_size)
361
+ patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
362
+ self.grid_size = (image_height // patch_height, image_width // patch_width)
363
+ self.output_dim = output_dim
364
+
365
+ # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
366
+ self.input_patchnorm = input_patchnorm
367
+
368
+ if input_patchnorm:
369
+ patch_input_dim = patch_height * patch_width * 3
370
+ self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
371
+ self.conv1 = nn.Linear(patch_input_dim, width)
372
+ else:
373
+ self.patchnorm_pre_ln = nn.Identity()
374
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
375
+
376
+ # class embeddings and positional embeddings
377
+ scale = width ** -0.5
378
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
379
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
380
+
381
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
382
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
383
+
384
+ self.ln_pre = norm_layer(width)
385
+ self.transformer = Transformer(
386
+ width,
387
+ layers,
388
+ heads,
389
+ mlp_ratio,
390
+ ls_init_value=ls_init_value,
391
+ act_layer=act_layer,
392
+ norm_layer=norm_layer,
393
+ )
394
+
395
+ self.global_average_pool = global_average_pool
396
+ if attentional_pool:
397
+ self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
398
+ self.ln_post = norm_layer(output_dim)
399
+ self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
400
+ else:
401
+ self.attn_pool = None
402
+ self.ln_post = norm_layer(width)
403
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
404
+
405
+ self.init_parameters()
406
+
407
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
408
+ for param in self.parameters():
409
+ param.requires_grad = False
410
+
411
+ if unlocked_groups != 0:
412
+ groups = [
413
+ [
414
+ self.conv1,
415
+ self.class_embedding,
416
+ self.positional_embedding,
417
+ self.ln_pre,
418
+ ],
419
+ *self.transformer.resblocks[:-1],
420
+ [
421
+ self.transformer.resblocks[-1],
422
+ self.ln_post,
423
+ ],
424
+ self.proj,
425
+ ]
426
+
427
+ def _unlock(x):
428
+ if isinstance(x, Sequence):
429
+ for g in x:
430
+ _unlock(g)
431
+ else:
432
+ if isinstance(x, torch.nn.Parameter):
433
+ x.requires_grad = True
434
+ else:
435
+ for p in x.parameters():
436
+ p.requires_grad = True
437
+
438
+ _unlock(groups[-unlocked_groups:])
439
+
440
+ def init_parameters(self):
441
+ # FIXME OpenAI CLIP did not define an init for the VisualTransformer
442
+ # TODO experiment if default PyTorch init, below, or alternate init is best.
443
+
444
+ # nn.init.normal_(self.class_embedding, std=self.scale)
445
+ # nn.init.normal_(self.positional_embedding, std=self.scale)
446
+ #
447
+ # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
448
+ # attn_std = self.transformer.width ** -0.5
449
+ # fc_std = (2 * self.transformer.width) ** -0.5
450
+ # for block in self.transformer.resblocks:
451
+ # nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
452
+ # nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
453
+ # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
454
+ # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
455
+ #
456
+ # if self.text_projection is not None:
457
+ # nn.init.normal_(self.text_projection, std=self.scale)
458
+ pass
459
+
460
+ @torch.jit.ignore
461
+ def set_grad_checkpointing(self, enable=True):
462
+ self.transformer.grad_checkpointing = enable
463
+
464
+ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
465
+ if self.global_average_pool:
466
+ return x.mean(dim=1), x
467
+ else:
468
+ return x[:, 0], x[:, 1:]
469
+
470
+ def forward(self, x: torch.Tensor):
471
+
472
+ # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
473
+ if self.input_patchnorm:
474
+ # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
475
+ x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1])
476
+ x = x.permute(0, 2, 4, 1, 3, 5)
477
+ x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
478
+ x = self.patchnorm_pre_ln(x)
479
+ x = self.conv1(x)
480
+ else:
481
+ x = self.conv1(x) # shape = [*, width, grid, grid]
482
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
483
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
484
+
485
+ # class embeddings and positional embeddings
486
+ x = torch.cat(
487
+ [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
488
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
489
+ x = x + self.positional_embedding.to(x.dtype)
490
+
491
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
492
+ x = self.patch_dropout(x)
493
+ x = self.ln_pre(x)
494
+
495
+ x = x.permute(1, 0, 2) # NLD -> LND
496
+ x = self.transformer(x)
497
+ x = x.permute(1, 0, 2) # LND -> NLD
498
+
499
+ if self.attn_pool is not None:
500
+ x = self.attn_pool(x)
501
+ x = self.ln_post(x)
502
+ pooled, tokens = self._global_pool(x)
503
+ else:
504
+ pooled, tokens = self._global_pool(x)
505
+ pooled = self.ln_post(pooled)
506
+
507
+ if self.proj is not None:
508
+ pooled = pooled @ self.proj
509
+
510
+ if self.output_tokens:
511
+ return pooled, tokens
512
+
513
+ return pooled
514
+
515
+
516
+ class TextTransformer(nn.Module):
517
+ output_tokens: torch.jit.Final[bool]
518
+
519
+ def __init__(
520
+ self,
521
+ context_length: int = 77,
522
+ vocab_size: int = 49408,
523
+ width: int = 512,
524
+ heads: int = 8,
525
+ layers: int = 12,
526
+ ls_init_value: float = None,
527
+ output_dim: int = 512,
528
+ act_layer: Callable = nn.GELU,
529
+ norm_layer: Callable = LayerNorm,
530
+ embed_cls: bool = False,
531
+ pad_id: int = 0,
532
+ output_tokens: bool = False,
533
+ ):
534
+ super().__init__()
535
+ self.output_tokens = output_tokens
536
+ self.num_pos = self.context_length = context_length
537
+ self.vocab_size = vocab_size
538
+ self.width = width
539
+ self.output_dim = output_dim
540
+ self.heads = heads
541
+ self.pad_id = pad_id
542
+
543
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
544
+
545
+ if embed_cls:
546
+ self.cls_emb = nn.Parameter(torch.empty(width))
547
+ self.num_pos += 1
548
+ else:
549
+ self.cls_emb = None
550
+
551
+ self.token_embedding = nn.Embedding(vocab_size, width)
552
+ self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
553
+ self.transformer = Transformer(
554
+ width=width,
555
+ layers=layers,
556
+ heads=heads,
557
+ ls_init_value=ls_init_value,
558
+ act_layer=act_layer,
559
+ norm_layer=norm_layer,
560
+ )
561
+ self.ln_final = norm_layer(width)
562
+
563
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
564
+
565
+ self.init_parameters()
566
+
567
+ def init_parameters(self):
568
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
569
+ nn.init.normal_(self.positional_embedding, std=0.01)
570
+ if self.cls_emb is not None:
571
+ nn.init.normal_(self.cls_emb, std=0.01)
572
+
573
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
574
+ attn_std = self.transformer.width ** -0.5
575
+ fc_std = (2 * self.transformer.width) ** -0.5
576
+ for block in self.transformer.resblocks:
577
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
578
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
579
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
580
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
581
+
582
+ if self.text_projection is not None:
583
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
584
+
585
+ @torch.jit.ignore
586
+ def set_grad_checkpointing(self, enable=True):
587
+ self.transformer.grad_checkpointing = enable
588
+
589
+ def build_attention_mask(self):
590
+ # lazily create causal attention mask, with full attention between the tokens
591
+ # pytorch uses additive attention mask; fill with -inf
592
+ mask = torch.empty(self.num_pos, self.num_pos)
593
+ mask.fill_(float("-inf"))
594
+ mask.triu_(1) # zero out the lower diagonal
595
+ return mask
596
+
597
+ def build_cls_mask(self, text, cast_dtype: torch.dtype):
598
+ cls_mask = (text != self.pad_id).unsqueeze(1)
599
+ cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
600
+ additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
601
+ additive_mask.fill_(0)
602
+ additive_mask.masked_fill_(~cls_mask, float("-inf"))
603
+ additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
604
+ return additive_mask
605
+
606
+ def _repeat(self, t, N: int):
607
+ return t.reshape(1, 1, -1).repeat(N, 1, 1)
608
+
609
+ def forward(self, text):
610
+ cast_dtype = self.transformer.get_cast_dtype()
611
+ seq_len = text.shape[1]
612
+
613
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
614
+ attn_mask = self.attn_mask
615
+ if self.cls_emb is not None:
616
+ seq_len += 1
617
+ x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
618
+ cls_mask = self.build_cls_mask(text, cast_dtype)
619
+ attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
620
+
621
+ x = x + self.positional_embedding[:seq_len].to(cast_dtype)
622
+ x = x.permute(1, 0, 2) # NLD -> LND
623
+ x = self.transformer(x, attn_mask=attn_mask)
624
+ x = x.permute(1, 0, 2) # LND -> NLD
625
+
626
+ # x.shape = [batch_size, n_ctx, transformer.width]
627
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
628
+ if self.cls_emb is not None:
629
+ pooled, tokens = x[:, -1], x[:, :-1]
630
+ pooled = self.ln_final(pooled)
631
+ else:
632
+ x = self.ln_final(x)
633
+ pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
634
+
635
+ if self.text_projection is not None:
636
+ pooled = pooled @ self.text_projection
637
+
638
+ if self.output_tokens:
639
+ return pooled, tokens
640
+
641
+ return pooled
642
+
643
+
644
+ class MultimodalTransformer(Transformer):
645
+ def __init__(
646
+ self,
647
+ width: int,
648
+ layers: int,
649
+ heads: int,
650
+ context_length: int = 77,
651
+ mlp_ratio: float = 4.0,
652
+ ls_init_value: float = None,
653
+ act_layer: Callable = nn.GELU,
654
+ norm_layer: Callable = LayerNorm,
655
+ output_dim: int = 512,
656
+ ):
657
+
658
+ super().__init__(
659
+ width=width,
660
+ layers=layers,
661
+ heads=heads,
662
+ mlp_ratio=mlp_ratio,
663
+ ls_init_value=ls_init_value,
664
+ act_layer=act_layer,
665
+ norm_layer=norm_layer,
666
+ )
667
+ self.context_length = context_length
668
+ self.cross_attn = nn.ModuleList([
669
+ ResidualAttentionBlock(
670
+ width,
671
+ heads,
672
+ mlp_ratio,
673
+ ls_init_value=ls_init_value,
674
+ act_layer=act_layer,
675
+ norm_layer=norm_layer,
676
+ is_cross_attention=True,
677
+ )
678
+ for _ in range(layers)
679
+ ])
680
+
681
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
682
+
683
+ self.ln_final = norm_layer(width)
684
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
685
+
686
+ def init_parameters(self):
687
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
688
+ attn_std = self.transformer.width ** -0.5
689
+ fc_std = (2 * self.transformer.width) ** -0.5
690
+ for block in self.transformer.resblocks:
691
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
692
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
693
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
694
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
695
+ for block in self.transformer.cross_attn:
696
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
697
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
698
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
699
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
700
+
701
+ if self.text_projection is not None:
702
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
703
+
704
+ def build_attention_mask(self):
705
+ # lazily create causal attention mask, with full attention between the tokens
706
+ # pytorch uses additive attention mask; fill with -inf
707
+ mask = torch.empty(self.context_length, self.context_length)
708
+ mask.fill_(float("-inf"))
709
+ mask.triu_(1) # zero out the lower diagonal
710
+ return mask
711
+
712
+ def forward(self, image_embs, text_embs):
713
+ text_embs = text_embs.permute(1, 0, 2) # NLD -> LND
714
+ image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
715
+ seq_len = text_embs.shape[0]
716
+
717
+ for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
718
+ if self.grad_checkpointing and not torch.jit.is_scripting():
719
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
720
+ text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
721
+ text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
722
+ else:
723
+ text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
724
+ text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
725
+
726
+ x = text_embs.permute(1, 0, 2) # LND -> NLD
727
+ x = self.ln_final(x)
728
+
729
+ if self.text_projection is not None:
730
+ x = x @ self.text_projection
731
+
732
+ return x
733
+
734
+ @torch.jit.ignore
735
+ def set_grad_checkpointing(self, enable=True):
736
+ self.grad_checkpointing = enable
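A small sketch (not part of the upload) of the additive causal mask built by build_attention_mask in TextTransformer and MultimodalTransformer above: entries strictly above the diagonal are -inf, so after softmax each position attends only to itself and earlier tokens.

import torch

num_pos = 5
mask = torch.empty(num_pos, num_pos).fill_(float("-inf")).triu_(1)  # -inf strictly above the diagonal
scores = torch.zeros(num_pos, num_pos) + mask  # pretend all raw attention logits are equal
print(scores.softmax(dim=-1)[2])  # tensor([0.3333, 0.3333, 0.3333, 0.0000, 0.0000])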
model/scunet.py ADDED
@@ -0,0 +1,264 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from einops import rearrange
5
+ from einops.layers.torch import Rearrange
6
+ from timm.models.layers import trunc_normal_, DropPath
7
+
8
+
9
+ class WMSA(nn.Module):
10
+ """ Self-attention module in Swin Transformer
11
+ """
12
+
13
+ def __init__(self, input_dim, output_dim, head_dim, window_size, type):
14
+ super(WMSA, self).__init__()
15
+ self.input_dim = input_dim
16
+ self.output_dim = output_dim
17
+ self.head_dim = head_dim
18
+ self.scale = self.head_dim ** -0.5
19
+ self.n_heads = input_dim//head_dim
20
+ self.window_size = window_size
21
+ self.type=type
22
+ self.embedding_layer = nn.Linear(self.input_dim, 3*self.input_dim, bias=True)
23
+
24
+ # TODO recover
25
+ # self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size -1))
26
+ self.relative_position_params = nn.Parameter(torch.zeros((2 * window_size - 1)*(2 * window_size -1), self.n_heads))
27
+
28
+ self.linear = nn.Linear(self.input_dim, self.output_dim)
29
+
30
+ trunc_normal_(self.relative_position_params, std=.02)
31
+ self.relative_position_params = torch.nn.Parameter(self.relative_position_params.view(2*window_size-1, 2*window_size-1, self.n_heads).transpose(1,2).transpose(0,1))
32
+
33
+ def generate_mask(self, h, w, p, shift):
34
+ """ generating the mask of SW-MSA
35
+ Args:
36
+ shift: shift parameters in CyclicShift.
37
+ Returns:
38
+ attn_mask: should be (1 1 w p p),
39
+ """
40
+ # supports square windows only.
41
+ attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
42
+ if self.type == 'W':
43
+ return attn_mask
44
+
45
+ s = p - shift
46
+ attn_mask[-1, :, :s, :, s:, :] = True
47
+ attn_mask[-1, :, s:, :, :s, :] = True
48
+ attn_mask[:, -1, :, :s, :, s:] = True
49
+ attn_mask[:, -1, :, s:, :, :s] = True
50
+ attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
51
+ return attn_mask
52
+
53
+ def forward(self, x):
54
+ """ Forward pass of Window Multi-head Self-attention module.
55
+ Args:
56
+ x: input tensor with shape of [b h w c];
57
+ attn_mask: attention mask, fill -inf where the value is True;
58
+ Returns:
59
+ output: tensor shape [b h w c]
60
+ """
61
+ if self.type!='W': x = torch.roll(x, shifts=(-(self.window_size//2), -(self.window_size//2)), dims=(1,2))
62
+ x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
63
+ h_windows = x.size(1)
64
+ w_windows = x.size(2)
65
+ # square validation
66
+ # assert h_windows == w_windows
67
+
68
+ x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
69
+ qkv = self.embedding_layer(x)
70
+ q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
71
+ sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
72
+ # Adding learnable relative embedding
73
+ sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
74
+ # Using Attn Mask to distinguish different subwindows.
75
+ if self.type != 'W':
76
+ attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size//2)
77
+ sim = sim.masked_fill_(attn_mask, float("-inf"))
78
+
79
+ probs = nn.functional.softmax(sim, dim=-1)
80
+ output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
81
+ output = rearrange(output, 'h b w p c -> b w p (h c)')
82
+ output = self.linear(output)
83
+ output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
84
+
85
+ if self.type!='W': output = torch.roll(output, shifts=(self.window_size//2, self.window_size//2), dims=(1,2))
86
+ return output
87
+
88
+ def relative_embedding(self):
89
+ cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
90
+ relation = cord[:, None, :] - cord[None, :, :] + self.window_size -1
91
+ # negative is allowed
92
+ return self.relative_position_params[:, relation[:,:,0].long(), relation[:,:,1].long()]
93
+
94
+
95
+ class Block(nn.Module):
96
+ def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
97
+ """ SwinTransformer Block
98
+ """
99
+ super(Block, self).__init__()
100
+ self.input_dim = input_dim
101
+ self.output_dim = output_dim
102
+ assert type in ['W', 'SW']
103
+ self.type = type
104
+ if input_resolution <= window_size:
105
+ self.type = 'W'
106
+
107
+ print("Block Initial Type: {}, drop_path_rate:{:.6f}".format(self.type, drop_path))
108
+ self.ln1 = nn.LayerNorm(input_dim)
109
+ self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
110
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
111
+ self.ln2 = nn.LayerNorm(input_dim)
112
+ self.mlp = nn.Sequential(
113
+ nn.Linear(input_dim, 4 * input_dim),
114
+ nn.GELU(),
115
+ nn.Linear(4 * input_dim, output_dim),
116
+ )
117
+
118
+ def forward(self, x):
119
+ x = x + self.drop_path(self.msa(self.ln1(x)))
120
+ x = x + self.drop_path(self.mlp(self.ln2(x)))
121
+ return x
122
+
123
+
124
+ class ConvTransBlock(nn.Module):
125
+ def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
126
+ """ SwinTransformer and Conv Block
127
+ """
128
+ super(ConvTransBlock, self).__init__()
129
+ self.conv_dim = conv_dim
130
+ self.trans_dim = trans_dim
131
+ self.head_dim = head_dim
132
+ self.window_size = window_size
133
+ self.drop_path = drop_path
134
+ self.type = type
135
+ self.input_resolution = input_resolution
136
+
137
+ assert self.type in ['W', 'SW']
138
+ if self.input_resolution <= self.window_size:
139
+ self.type = 'W'
140
+
141
+ self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, self.type, self.input_resolution)
142
+ self.conv1_1 = nn.Conv2d(self.conv_dim+self.trans_dim, self.conv_dim+self.trans_dim, 1, 1, 0, bias=True)
143
+ self.conv1_2 = nn.Conv2d(self.conv_dim+self.trans_dim, self.conv_dim+self.trans_dim, 1, 1, 0, bias=True)
144
+
145
+ self.conv_block = nn.Sequential(
146
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
147
+ nn.ReLU(True),
148
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
149
+ )
150
+
151
+ def forward(self, x):
152
+ conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
153
+ conv_x = self.conv_block(conv_x) + conv_x
154
+ trans_x = Rearrange('b c h w -> b h w c')(trans_x)
155
+ trans_x = self.trans_block(trans_x)
156
+ trans_x = Rearrange('b h w c -> b c h w')(trans_x)
157
+ res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
158
+ x = x + res
159
+
160
+ return x
161
+
162
+
163
+ class SCUNet(nn.Module):
164
+
165
+ def __init__(self, in_nc=3, config=[2,2,2,2,2,2,2], dim=64, drop_path_rate=0.0, input_resolution=256):
166
+ super(SCUNet, self).__init__()
167
+ self.config = config
168
+ self.dim = dim
169
+ self.head_dim = 32
170
+ self.window_size = 8
171
+
172
+ # drop path rate for each layer
173
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
174
+
175
+ self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
176
+
177
+ begin = 0
178
+ self.m_down1 = [ConvTransBlock(dim//2, dim//2, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution)
179
+ for i in range(config[0])] + \
180
+ [nn.Conv2d(dim, 2*dim, 2, 2, 0, bias=False)]
181
+
182
+ begin += config[0]
183
+ self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//2)
184
+ for i in range(config[1])] + \
185
+ [nn.Conv2d(2*dim, 4*dim, 2, 2, 0, bias=False)]
186
+
187
+ begin += config[1]
188
+ self.m_down3 = [ConvTransBlock(2*dim, 2*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW',input_resolution//4)
189
+ for i in range(config[2])] + \
190
+ [nn.Conv2d(4*dim, 8*dim, 2, 2, 0, bias=False)]
191
+
192
+ begin += config[2]
193
+ self.m_body = [ConvTransBlock(4*dim, 4*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//8)
194
+ for i in range(config[3])]
195
+
196
+ begin += config[3]
197
+ self.m_up3 = [nn.ConvTranspose2d(8*dim, 4*dim, 2, 2, 0, bias=False),] + \
198
+ [ConvTransBlock(2*dim, 2*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW',input_resolution//4)
199
+ for i in range(config[4])]
200
+
201
+ begin += config[4]
202
+ self.m_up2 = [nn.ConvTranspose2d(4*dim, 2*dim, 2, 2, 0, bias=False),] + \
203
+ [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//2)
204
+ for i in range(config[5])]
205
+
206
+ begin += config[5]
207
+ self.m_up1 = [nn.ConvTranspose2d(2*dim, dim, 2, 2, 0, bias=False),] + \
208
+ [ConvTransBlock(dim//2, dim//2, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution)
209
+ for i in range(config[6])]
210
+
211
+ self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
212
+
213
+ self.m_head = nn.Sequential(*self.m_head)
214
+ self.m_down1 = nn.Sequential(*self.m_down1)
215
+ self.m_down2 = nn.Sequential(*self.m_down2)
216
+ self.m_down3 = nn.Sequential(*self.m_down3)
217
+ self.m_body = nn.Sequential(*self.m_body)
218
+ self.m_up3 = nn.Sequential(*self.m_up3)
219
+ self.m_up2 = nn.Sequential(*self.m_up2)
220
+ self.m_up1 = nn.Sequential(*self.m_up1)
221
+ self.m_tail = nn.Sequential(*self.m_tail)
222
+ #self.apply(self._init_weights)
223
+
224
+ def forward(self, x0):
225
+
226
+ h, w = x0.size()[-2:]
227
+ paddingBottom = int(np.ceil(h/64)*64-h)
228
+ paddingRight = int(np.ceil(w/64)*64-w)
229
+ x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
230
+
231
+ x1 = self.m_head(x0)
232
+ x2 = self.m_down1(x1)
233
+ x3 = self.m_down2(x2)
234
+ x4 = self.m_down3(x3)
235
+ x = self.m_body(x4)
236
+ x = self.m_up3(x+x4)
237
+ x = self.m_up2(x+x3)
238
+ x = self.m_up1(x+x2)
239
+ x = self.m_tail(x+x1)
240
+
241
+ x = x[..., :h, :w]
242
+
243
+ return x
244
+
245
+
246
+ def _init_weights(self, m):
247
+ if isinstance(m, nn.Linear):
248
+ trunc_normal_(m.weight, std=.02)
249
+ if m.bias is not None:
250
+ nn.init.constant_(m.bias, 0)
251
+ elif isinstance(m, nn.LayerNorm):
252
+ nn.init.constant_(m.bias, 0)
253
+ nn.init.constant_(m.weight, 1.0)
254
+
255
+
256
+
257
+ if __name__ == '__main__':
258
+
259
+ # torch.cuda.empty_cache()
260
+ net = SCUNet()
261
+
262
+ x = torch.randn((2, 3, 64, 128))
263
+ x = net(x)
264
+ print(x.shape)
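A quick sketch (not part of the upload) of the padding arithmetic in SCUNet.forward above: height and width are replication-padded up to multiples of 64 (three stride-2 downsamplings times the window size of 8), and the output is cropped back, so arbitrary input sizes are accepted.

import numpy as np

h, w = 100, 130
padding_bottom = int(np.ceil(h / 64) * 64 - h)  # 28 -> padded height 128
padding_right = int(np.ceil(w / 64) * 64 - w)   # 62 -> padded width 192
print(padding_bottom, padding_right)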
model/swinir.py ADDED
@@ -0,0 +1,905 @@
1
+ # -----------------------------------------------------------------------------------
2
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
+ # -----------------------------------------------------------------------------------
5
+
6
+ # Originally borrowed from DifFace (https://github.com/zsyOAOA/DifFace/blob/master/models/swinir.py)
7
+
8
+ import math
9
+ from typing import Set
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
16
+
17
+
18
+ class Mlp(nn.Module):
19
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
20
+ super().__init__()
21
+ out_features = out_features or in_features
22
+ hidden_features = hidden_features or in_features
23
+ self.fc1 = nn.Linear(in_features, hidden_features)
24
+ self.act = act_layer()
25
+ self.fc2 = nn.Linear(hidden_features, out_features)
26
+ self.drop = nn.Dropout(drop)
27
+
28
+ def forward(self, x):
29
+ x = self.fc1(x)
30
+ x = self.act(x)
31
+ x = self.drop(x)
32
+ x = self.fc2(x)
33
+ x = self.drop(x)
34
+ return x
35
+
36
+
37
+ def window_partition(x, window_size):
38
+ """
39
+ Args:
40
+ x: (B, H, W, C)
41
+ window_size (int): window size
42
+
43
+ Returns:
44
+ windows: (num_windows*B, window_size, window_size, C)
45
+ """
46
+ B, H, W, C = x.shape
47
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
48
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
49
+ return windows
50
+
51
+
52
+ def window_reverse(windows, window_size, H, W):
53
+ """
54
+ Args:
55
+ windows: (num_windows*B, window_size, window_size, C)
56
+ window_size (int): Window size
57
+ H (int): Height of image
58
+ W (int): Width of image
59
+
60
+ Returns:
61
+ x: (B, H, W, C)
62
+ """
63
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
64
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
65
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
66
+ return x
67
+
68
+
69
+ class WindowAttention(nn.Module):
70
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
71
+ It supports both of shifted and non-shifted window.
72
+
73
+ Args:
74
+ dim (int): Number of input channels.
75
+ window_size (tuple[int]): The height and width of the window.
76
+ num_heads (int): Number of attention heads.
77
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
78
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
79
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
80
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
81
+ """
82
+
83
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
84
+
85
+ super().__init__()
86
+ self.dim = dim
87
+ self.window_size = window_size # Wh, Ww
88
+ self.num_heads = num_heads
89
+ head_dim = dim // num_heads
90
+ self.scale = qk_scale or head_dim ** -0.5
91
+
92
+ # define a parameter table of relative position bias
93
+ self.relative_position_bias_table = nn.Parameter(
94
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
95
+
96
+ # get pair-wise relative position index for each token inside the window
97
+ coords_h = torch.arange(self.window_size[0])
98
+ coords_w = torch.arange(self.window_size[1])
99
+ # coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
100
+ # Fix: Pass indexing="ij" to avoid warning
101
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
102
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
103
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
104
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
105
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
106
+ relative_coords[:, :, 1] += self.window_size[1] - 1
107
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
108
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
109
+ self.register_buffer("relative_position_index", relative_position_index)
110
+
111
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
112
+ self.attn_drop = nn.Dropout(attn_drop)
113
+ self.proj = nn.Linear(dim, dim)
114
+
115
+ self.proj_drop = nn.Dropout(proj_drop)
116
+
117
+ trunc_normal_(self.relative_position_bias_table, std=.02)
118
+ self.softmax = nn.Softmax(dim=-1)
119
+
120
+ def forward(self, x, mask=None):
121
+ """
122
+ Args:
123
+ x: input features with shape of (num_windows*B, N, C)
124
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
125
+ """
126
+ B_, N, C = x.shape
127
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
128
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
129
+
130
+ q = q * self.scale
131
+ attn = (q @ k.transpose(-2, -1))
132
+
133
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
134
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
135
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
136
+ attn = attn + relative_position_bias.unsqueeze(0)
137
+
138
+ if mask is not None:
139
+ nW = mask.shape[0]
140
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
141
+ attn = attn.view(-1, self.num_heads, N, N)
142
+ attn = self.softmax(attn)
143
+ else:
144
+ attn = self.softmax(attn)
145
+
146
+ attn = self.attn_drop(attn)
147
+
148
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
149
+ x = self.proj(x)
150
+ x = self.proj_drop(x)
151
+ return x
152
+
153
+ def extra_repr(self) -> str:
154
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
155
+
156
+ def flops(self, N):
157
+ # calculate flops for 1 window with token length of N
158
+ flops = 0
159
+ # qkv = self.qkv(x)
160
+ flops += N * self.dim * 3 * self.dim
161
+ # attn = (q @ k.transpose(-2, -1))
162
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
163
+ # x = (attn @ v)
164
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
165
+ # x = self.proj(x)
166
+ flops += N * self.dim * self.dim
167
+ return flops
168
+
169
+
170
+ class SwinTransformerBlock(nn.Module):
171
+ r""" Swin Transformer Block.
172
+
173
+ Args:
174
+ dim (int): Number of input channels.
175
+ input_resolution (tuple[int]): Input resolution.
176
+ num_heads (int): Number of attention heads.
177
+ window_size (int): Window size.
178
+ shift_size (int): Shift size for SW-MSA.
179
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
180
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
181
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
182
+ drop (float, optional): Dropout rate. Default: 0.0
183
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
184
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
185
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
186
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
187
+ """
188
+
189
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
190
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
191
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
192
+ super().__init__()
193
+ self.dim = dim
194
+ self.input_resolution = input_resolution
195
+ self.num_heads = num_heads
196
+ self.window_size = window_size
197
+ self.shift_size = shift_size
198
+ self.mlp_ratio = mlp_ratio
199
+ if min(self.input_resolution) <= self.window_size:
200
+ # if window size is larger than input resolution, we don't partition windows
201
+ self.shift_size = 0
202
+ self.window_size = min(self.input_resolution)
203
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
204
+
205
+ self.norm1 = norm_layer(dim)
206
+ self.attn = WindowAttention(
207
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
208
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
209
+
210
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
211
+ self.norm2 = norm_layer(dim)
212
+ mlp_hidden_dim = int(dim * mlp_ratio)
213
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
214
+
215
+ if self.shift_size > 0:
216
+ attn_mask = self.calculate_mask(self.input_resolution)
217
+ else:
218
+ attn_mask = None
219
+
220
+ self.register_buffer("attn_mask", attn_mask)
221
+
222
+ def calculate_mask(self, x_size):
223
+ # calculate attention mask for SW-MSA
224
+ H, W = x_size
225
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
226
+ h_slices = (slice(0, -self.window_size),
227
+ slice(-self.window_size, -self.shift_size),
228
+ slice(-self.shift_size, None))
229
+ w_slices = (slice(0, -self.window_size),
230
+ slice(-self.window_size, -self.shift_size),
231
+ slice(-self.shift_size, None))
232
+ cnt = 0
233
+ for h in h_slices:
234
+ for w in w_slices:
235
+ img_mask[:, h, w, :] = cnt
236
+ cnt += 1
237
+
238
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
239
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
240
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
241
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
242
+
243
+ return attn_mask
244
+
245
+ def forward(self, x, x_size):
246
+ H, W = x_size
247
+ B, L, C = x.shape
248
+ # assert L == H * W, "input feature has wrong size"
249
+
250
+ shortcut = x
251
+ x = self.norm1(x)
252
+ x = x.view(B, H, W, C)
253
+
254
+ # cyclic shift
255
+ if self.shift_size > 0:
256
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
257
+ else:
258
+ shifted_x = x
259
+
260
+ # partition windows
261
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
262
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
263
+
264
+ # W-MSA/SW-MSA (recompute the mask when the test image size differs from the training input_resolution)
265
+ if self.input_resolution == x_size:
266
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
267
+ else:
268
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
269
+
270
+ # merge windows
271
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
272
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
273
+
274
+ # reverse cyclic shift
275
+ if self.shift_size > 0:
276
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
277
+ else:
278
+ x = shifted_x
279
+ x = x.view(B, H * W, C)
280
+
281
+ # FFN
282
+ x = shortcut + self.drop_path(x)
283
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
284
+
285
+ return x
286
+
287
+ def extra_repr(self) -> str:
288
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
289
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
290
+
291
+ def flops(self):
292
+ flops = 0
293
+ H, W = self.input_resolution
294
+ # norm1
295
+ flops += self.dim * H * W
296
+ # W-MSA/SW-MSA
297
+ nW = H * W / self.window_size / self.window_size
298
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
299
+ # mlp
300
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
301
+ # norm2
302
+ flops += self.dim * H * W
303
+ return flops
304
+
305
+
306
+ class PatchMerging(nn.Module):
307
+ r""" Patch Merging Layer.
308
+
309
+ Args:
310
+ input_resolution (tuple[int]): Resolution of input feature.
311
+ dim (int): Number of input channels.
312
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
313
+ """
314
+
315
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
316
+ super().__init__()
317
+ self.input_resolution = input_resolution
318
+ self.dim = dim
319
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
320
+ self.norm = norm_layer(4 * dim)
321
+
322
+ def forward(self, x):
323
+ """
324
+ x: B, H*W, C
325
+ """
326
+ H, W = self.input_resolution
327
+ B, L, C = x.shape
328
+ assert L == H * W, "input feature has wrong size"
329
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
330
+
331
+ x = x.view(B, H, W, C)
332
+
333
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
334
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
335
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
336
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
337
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
338
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
339
+
340
+ x = self.norm(x)
341
+ x = self.reduction(x)
342
+
343
+ return x
344
+
345
+ def extra_repr(self) -> str:
346
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
347
+
348
+ def flops(self):
349
+ H, W = self.input_resolution
350
+ flops = H * W * self.dim
351
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
352
+ return flops
353
+
354
+
355
+ class BasicLayer(nn.Module):
356
+ """ A basic Swin Transformer layer for one stage.
357
+
358
+ Args:
359
+ dim (int): Number of input channels.
360
+ input_resolution (tuple[int]): Input resolution.
361
+ depth (int): Number of blocks.
362
+ num_heads (int): Number of attention heads.
363
+ window_size (int): Local window size.
364
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
365
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
366
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
367
+ drop (float, optional): Dropout rate. Default: 0.0
368
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
369
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
370
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
371
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
372
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
373
+ """
374
+
375
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
376
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
377
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
378
+
379
+ super().__init__()
380
+ self.dim = dim
381
+ self.input_resolution = input_resolution
382
+ self.depth = depth
383
+ self.use_checkpoint = use_checkpoint
384
+
385
+ # build blocks
386
+ self.blocks = nn.ModuleList([
387
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
388
+ num_heads=num_heads, window_size=window_size,
389
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
390
+ mlp_ratio=mlp_ratio,
391
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
392
+ drop=drop, attn_drop=attn_drop,
393
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
394
+ norm_layer=norm_layer)
395
+ for i in range(depth)])
396
+
397
+ # patch merging layer
398
+ if downsample is not None:
399
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
400
+ else:
401
+ self.downsample = None
402
+
403
+ def forward(self, x, x_size):
404
+ for blk in self.blocks:
405
+ if self.use_checkpoint:
406
+ x = checkpoint.checkpoint(blk, x, x_size)
407
+ else:
408
+ x = blk(x, x_size)
409
+ if self.downsample is not None:
410
+ x = self.downsample(x)
411
+ return x
412
+
413
+ def extra_repr(self) -> str:
414
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
415
+
416
+ def flops(self):
417
+ flops = 0
418
+ for blk in self.blocks:
419
+ flops += blk.flops()
420
+ if self.downsample is not None:
421
+ flops += self.downsample.flops()
422
+ return flops
423
+
424
+
425
+ class RSTB(nn.Module):
426
+ """Residual Swin Transformer Block (RSTB).
427
+
428
+ Args:
429
+ dim (int): Number of input channels.
430
+ input_resolution (tuple[int]): Input resolution.
431
+ depth (int): Number of blocks.
432
+ num_heads (int): Number of attention heads.
433
+ window_size (int): Local window size.
434
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
435
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
436
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
437
+ drop (float, optional): Dropout rate. Default: 0.0
438
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
439
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
440
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
441
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
442
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
443
+ img_size: Input image size.
444
+ patch_size: Patch size.
445
+ resi_connection: The convolutional block before residual connection.
446
+ """
447
+
448
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
449
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
450
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
451
+ img_size=224, patch_size=4, resi_connection='1conv'):
452
+ super(RSTB, self).__init__()
453
+
454
+ self.dim = dim
455
+ self.input_resolution = input_resolution
456
+
457
+ self.residual_group = BasicLayer(dim=dim,
458
+ input_resolution=input_resolution,
459
+ depth=depth,
460
+ num_heads=num_heads,
461
+ window_size=window_size,
462
+ mlp_ratio=mlp_ratio,
463
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
464
+ drop=drop, attn_drop=attn_drop,
465
+ drop_path=drop_path,
466
+ norm_layer=norm_layer,
467
+ downsample=downsample,
468
+ use_checkpoint=use_checkpoint)
469
+
470
+ if resi_connection == '1conv':
471
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
472
+ elif resi_connection == '3conv':
473
+ # to save parameters and memory
474
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
475
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
476
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
477
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
478
+
479
+ self.patch_embed = PatchEmbed(
480
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
481
+ norm_layer=None)
482
+
483
+ self.patch_unembed = PatchUnEmbed(
484
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
485
+ norm_layer=None)
486
+
487
+ def forward(self, x, x_size):
488
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
489
+
490
+ def flops(self):
491
+ flops = 0
492
+ flops += self.residual_group.flops()
493
+ H, W = self.input_resolution
494
+ flops += H * W * self.dim * self.dim * 9
495
+ flops += self.patch_embed.flops()
496
+ flops += self.patch_unembed.flops()
497
+
498
+ return flops
499
+
500
+
501
+ class PatchEmbed(nn.Module):
502
+ r""" Image to Patch Embedding
503
+
504
+ Args:
505
+ img_size (int): Image size. Default: 224.
506
+ patch_size (int): Patch token size. Default: 4.
507
+ in_chans (int): Number of input image channels. Default: 3.
508
+ embed_dim (int): Number of linear projection output channels. Default: 96.
509
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
510
+ """
511
+
512
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
513
+ super().__init__()
514
+ img_size = to_2tuple(img_size)
515
+ patch_size = to_2tuple(patch_size)
516
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
517
+ self.img_size = img_size
518
+ self.patch_size = patch_size
519
+ self.patches_resolution = patches_resolution
520
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
521
+
522
+ self.in_chans = in_chans
523
+ self.embed_dim = embed_dim
524
+
525
+ if norm_layer is not None:
526
+ self.norm = norm_layer(embed_dim)
527
+ else:
528
+ self.norm = None
529
+
530
+ def forward(self, x):
531
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
532
+ if self.norm is not None:
533
+ x = self.norm(x)
534
+ return x
535
+
536
+ def flops(self):
537
+ flops = 0
538
+ H, W = self.img_size
539
+ if self.norm is not None:
540
+ flops += H * W * self.embed_dim
541
+ return flops
542
+
543
+
544
+ class PatchUnEmbed(nn.Module):
545
+ r""" Image to Patch Unembedding
546
+
547
+ Args:
548
+ img_size (int): Image size. Default: 224.
549
+ patch_size (int): Patch token size. Default: 4.
550
+ in_chans (int): Number of input image channels. Default: 3.
551
+ embed_dim (int): Number of linear projection output channels. Default: 96.
552
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
553
+ """
554
+
555
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
556
+ super().__init__()
557
+ img_size = to_2tuple(img_size)
558
+ patch_size = to_2tuple(patch_size)
559
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
560
+ self.img_size = img_size
561
+ self.patch_size = patch_size
562
+ self.patches_resolution = patches_resolution
563
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
564
+
565
+ self.in_chans = in_chans
566
+ self.embed_dim = embed_dim
567
+
568
+ def forward(self, x, x_size):
569
+ B, HW, C = x.shape
570
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B C Ph Pw
571
+ return x
572
+
573
+ def flops(self):
574
+ flops = 0
575
+ return flops
576
+
577
+
578
+ class Upsample(nn.Sequential):
579
+ """Upsample module.
580
+
581
+ Args:
582
+ scale (int): Scale factor. Supported scales: 2^n and 3.
583
+ num_feat (int): Channel number of intermediate features.
584
+ """
585
+
586
+ def __init__(self, scale, num_feat):
587
+ m = []
588
+ if (scale & (scale - 1)) == 0: # scale = 2^n
589
+ for _ in range(int(math.log(scale, 2))):
590
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
591
+ m.append(nn.PixelShuffle(2))
592
+ elif scale == 3:
593
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
594
+ m.append(nn.PixelShuffle(3))
595
+ else:
596
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
597
+ super(Upsample, self).__init__(*m)
598
+
599
+
600
+ class UpsampleOneStep(nn.Sequential):
601
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
602
+ Used in lightweight SR to save parameters.
603
+
604
+ Args:
605
+ scale (int): Scale factor. Supported scales: 2^n and 3.
606
+ num_feat (int): Channel number of intermediate features.
607
+
608
+ """
609
+
610
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
611
+ self.num_feat = num_feat
612
+ self.input_resolution = input_resolution
613
+ m = []
614
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
615
+ m.append(nn.PixelShuffle(scale))
616
+ super(UpsampleOneStep, self).__init__(*m)
617
+
618
+ def flops(self):
619
+ H, W = self.input_resolution
620
+ flops = H * W * self.num_feat * 3 * 9
621
+ return flops
622
+
623
+
624
+ class SwinIR(nn.Module):
625
+ r""" SwinIR
626
+ A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on the Swin Transformer.
627
+
628
+ Args:
629
+ img_size (int | tuple(int)): Input image size. Default 64
630
+ patch_size (int | tuple(int)): Patch size. Default: 1
631
+ in_chans (int): Number of input image channels. Default: 3
632
+ embed_dim (int): Patch embedding dimension. Default: 96
633
+ depths (tuple(int)): Depth of each Swin Transformer layer.
634
+ num_heads (tuple(int)): Number of attention heads in different layers.
635
+ window_size (int): Window size. Default: 7
636
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
637
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
638
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
639
+ drop_rate (float): Dropout rate. Default: 0
640
+ attn_drop_rate (float): Attention dropout rate. Default: 0
641
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
642
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
643
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
644
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
645
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
646
+ sf: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
647
+ img_range: Image range. 1. or 255.
648
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
649
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
650
+ """
651
+
652
+ def __init__(
653
+ self,
654
+ img_size=64,
655
+ patch_size=1,
656
+ in_chans=3,
657
+ embed_dim=96,
658
+ depths=[6, 6, 6, 6],
659
+ num_heads=[6, 6, 6, 6],
660
+ window_size=7,
661
+ mlp_ratio=4.,
662
+ qkv_bias=True,
663
+ qk_scale=None,
664
+ drop_rate=0.,
665
+ attn_drop_rate=0.,
666
+ drop_path_rate=0.1,
667
+ norm_layer=nn.LayerNorm,
668
+ ape=False,
669
+ patch_norm=True,
670
+ use_checkpoint=False,
671
+ sf=4,
672
+ img_range=1.,
673
+ upsampler='',
674
+ resi_connection='1conv',
675
+ unshuffle=False,
676
+ unshuffle_scale=None,
677
+ hq_key: str="jpg",
678
+ lq_key: str="hint",
679
+ learning_rate: float=None,
680
+ weight_decay: float=None
681
+ ) -> "SwinIR":
682
+ super(SwinIR, self).__init__()
683
+ num_in_ch = in_chans * (unshuffle_scale**2) if unshuffle else in_chans
684
+ num_out_ch = in_chans
685
+ num_feat = 64
686
+ self.img_range = img_range
687
+ if in_chans == 3:
688
+ rgb_mean = (0.4488, 0.4371, 0.4040)
689
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
690
+ else:
691
+ self.mean = torch.zeros(1, 1, 1, 1)
692
+ self.upscale = sf
693
+ self.upsampler = upsampler
694
+ self.window_size = window_size
695
+ self.unshuffle_scale = unshuffle_scale
696
+ self.unshuffle = unshuffle
697
+
698
+ #####################################################################################################
699
+ ################################### 1, shallow feature extraction ###################################
700
+ if unshuffle:
701
+ assert unshuffle_scale is not None
702
+ self.conv_first = nn.Sequential(
703
+ nn.PixelUnshuffle(sf),
704
+ nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1),
705
+ )
706
+ else:
707
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
708
+
709
+ #####################################################################################################
710
+ ################################### 2, deep feature extraction ######################################
711
+ self.num_layers = len(depths)
712
+ self.embed_dim = embed_dim
713
+ self.ape = ape
714
+ self.patch_norm = patch_norm
715
+ self.num_features = embed_dim
716
+ self.mlp_ratio = mlp_ratio
717
+
718
+ # split image into non-overlapping patches
719
+ self.patch_embed = PatchEmbed(
720
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
721
+ norm_layer=norm_layer if self.patch_norm else None
722
+ )
723
+ num_patches = self.patch_embed.num_patches
724
+ patches_resolution = self.patch_embed.patches_resolution
725
+ self.patches_resolution = patches_resolution
726
+
727
+ # merge non-overlapping patches into image
728
+ self.patch_unembed = PatchUnEmbed(
729
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
730
+ norm_layer=norm_layer if self.patch_norm else None
731
+ )
732
+
733
+ # absolute position embedding
734
+ if self.ape:
735
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
736
+ trunc_normal_(self.absolute_pos_embed, std=.02)
737
+
738
+ self.pos_drop = nn.Dropout(p=drop_rate)
739
+
740
+ # stochastic depth
741
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
742
+
743
+ # build Residual Swin Transformer blocks (RSTB)
744
+ self.layers = nn.ModuleList()
745
+ for i_layer in range(self.num_layers):
746
+ layer = RSTB(
747
+ dim=embed_dim,
748
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
749
+ depth=depths[i_layer],
750
+ num_heads=num_heads[i_layer],
751
+ window_size=window_size,
752
+ mlp_ratio=self.mlp_ratio,
753
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
754
+ drop=drop_rate, attn_drop=attn_drop_rate,
755
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
756
+ norm_layer=norm_layer,
757
+ downsample=None,
758
+ use_checkpoint=use_checkpoint,
759
+ img_size=img_size,
760
+ patch_size=patch_size,
761
+ resi_connection=resi_connection
762
+ )
763
+ self.layers.append(layer)
764
+ self.norm = norm_layer(self.num_features)
765
+
766
+ # build the last conv layer in deep feature extraction
767
+ if resi_connection == '1conv':
768
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
769
+ elif resi_connection == '3conv':
770
+ # to save parameters and memory
771
+ self.conv_after_body = nn.Sequential(
772
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
773
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
774
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
775
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
776
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)
777
+ )
778
+
779
+ #####################################################################################################
780
+ ################################ 3, high quality image reconstruction ################################
781
+ if self.upsampler == 'pixelshuffle':
782
+ # for classical SR
783
+ self.conv_before_upsample = nn.Sequential(
784
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
785
+ nn.LeakyReLU(inplace=True)
786
+ )
787
+ self.upsample = Upsample(sf, num_feat)
788
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
789
+ elif self.upsampler == 'pixelshuffledirect':
790
+ # for lightweight SR (to save parameters)
791
+ self.upsample = UpsampleOneStep(
792
+ sf, embed_dim, num_out_ch,
793
+ (patches_resolution[0], patches_resolution[1])
794
+ )
795
+ elif self.upsampler == 'nearest+conv':
796
+ # for real-world SR (less artifacts)
797
+ self.conv_before_upsample = nn.Sequential(
798
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
799
+ nn.LeakyReLU(inplace=True)
800
+ )
801
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
802
+ if self.upscale == 4:
803
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
804
+ elif self.upscale == 8:
805
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
806
+ self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
807
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
808
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
809
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
810
+ else:
811
+ # for image denoising and JPEG compression artifact reduction
812
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
813
+
814
+ self.apply(self._init_weights)
815
+
816
+ def _init_weights(self, m: nn.Module) -> None:
817
+ if isinstance(m, nn.Linear):
818
+ trunc_normal_(m.weight, std=.02)
819
+ if isinstance(m, nn.Linear) and m.bias is not None:
820
+ nn.init.constant_(m.bias, 0)
821
+ elif isinstance(m, nn.LayerNorm):
822
+ nn.init.constant_(m.bias, 0)
823
+ nn.init.constant_(m.weight, 1.0)
824
+
825
+ # Parameter names listed here are meant to be excluded from weight decay (Swin Transformer / timm convention).
826
+ @torch.jit.ignore
827
+ def no_weight_decay(self) -> Set[str]:
828
+ return {'absolute_pos_embed'}
829
+
830
+ @torch.jit.ignore
831
+ def no_weight_decay_keywords(self) -> Set[str]:
832
+ return {'relative_position_bias_table'}
833
+
834
+ def check_image_size(self, x: torch.Tensor) -> torch.Tensor:
835
+ _, _, h, w = x.size()
836
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
837
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
838
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
839
+ return x
840
+
841
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
842
+ x_size = (x.shape[2], x.shape[3])
843
+ x = self.patch_embed(x)
844
+ if self.ape:
845
+ x = x + self.absolute_pos_embed
846
+ x = self.pos_drop(x)
847
+
848
+ for layer in self.layers:
849
+ x = layer(x, x_size)
850
+
851
+ x = self.norm(x) # B L C
852
+ x = self.patch_unembed(x, x_size)
853
+
854
+ return x
855
+
856
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
857
+ H, W = x.shape[2:]
858
+ x = self.check_image_size(x)
859
+
860
+ self.mean = self.mean.type_as(x)
861
+ x = (x - self.mean) * self.img_range
862
+
863
+ if self.upsampler == 'pixelshuffle':
864
+ # for classical SR
865
+ x = self.conv_first(x)
866
+ x = self.conv_after_body(self.forward_features(x)) + x
867
+ x = self.conv_before_upsample(x)
868
+ x = self.conv_last(self.upsample(x))
869
+ elif self.upsampler == 'pixelshuffledirect':
870
+ # for lightweight SR
871
+ x = self.conv_first(x)
872
+ x = self.conv_after_body(self.forward_features(x)) + x
873
+ x = self.upsample(x)
874
+ elif self.upsampler == 'nearest+conv':
875
+ # for real-world SR
876
+ x = self.conv_first(x)
877
+ x = self.conv_after_body(self.forward_features(x)) + x
878
+ x = self.conv_before_upsample(x)
879
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
880
+ if self.upscale == 4:
881
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
882
+ elif self.upscale == 8:
883
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
884
+ x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
885
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
886
+ else:
887
+ # for image denoising and JPEG compression artifact reduction
888
+ x_first = self.conv_first(x)
889
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
890
+ x = x + self.conv_last(res)
891
+
892
+ x = x / self.img_range + self.mean
893
+
894
+ return x[:, :, :H*self.upscale, :W*self.upscale]
895
+
896
+ def flops(self) -> int:
897
+ flops = 0
898
+ H, W = self.patches_resolution
899
+ flops += H * W * 3 * self.embed_dim * 9
900
+ flops += self.patch_embed.flops()
901
+ for i, layer in enumerate(self.layers):
902
+ flops += layer.flops()
903
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
904
+ flops += self.upsample.flops()
905
+ return flops
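
A minimal usage sketch for the SwinIR restorer defined above (illustrative only, not part of this commit; the constructor values are assumptions mirroring the defaults in this file, and sf=1 with an empty upsampler selects the plain restoration branch):

# Hypothetical example -- assumes the package layout from model/__init__.py
import torch
from model.swinir import SwinIR

model = SwinIR(
    img_size=64, embed_dim=96,
    depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
    window_size=8,
    sf=1,            # no upscaling: denoising / artifact-removal style usage
    upsampler="",    # plain conv reconstruction head
).eval()

lq = torch.rand(1, 3, 120, 180)      # arbitrary size; check_image_size pads to a window multiple
with torch.no_grad():
    hq = model(lq)
print(hq.shape)                      # cropped back to (1, 3, 120, 180)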
model/unet.py ADDED
@@ -0,0 +1,719 @@
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from model.util import (
10
+ checkpoint,
11
+ conv_nd,
12
+ linear,
13
+ avg_pool_nd,
14
+ zero_module,
15
+ normalization,
16
+ timestep_embedding,
17
+ exists
18
+ )
19
+ from model.attention import SpatialTransformer
20
+
21
+
22
+ class TimestepBlock(nn.Module):
23
+ """
24
+ Any module where forward() takes timestep embeddings as a second argument.
25
+ """
26
+
27
+ @abstractmethod
28
+ def forward(self, x, emb):
29
+ """
30
+ Apply the module to `x` given `emb` timestep embeddings.
31
+ """
32
+
33
+
34
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
35
+ """
36
+ A sequential module that passes timestep embeddings to the children that
37
+ support it as an extra input.
38
+ """
39
+
40
+ def forward(self, x, emb, context=None):
41
+ for layer in self:
42
+ if isinstance(layer, TimestepBlock):
43
+ x = layer(x, emb)
44
+ elif isinstance(layer, SpatialTransformer):
45
+ x = layer(x, context)
46
+ else:
47
+ x = layer(x)
48
+ return x
49
+
50
+
51
+ class Upsample(nn.Module):
52
+ """
53
+ An upsampling layer with an optional convolution.
54
+ :param channels: channels in the inputs and outputs.
55
+ :param use_conv: a bool determining if a convolution is applied.
56
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
57
+ upsampling occurs in the inner-two dimensions.
58
+ """
59
+
60
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
61
+ super().__init__()
62
+ self.channels = channels
63
+ self.out_channels = out_channels or channels
64
+ self.use_conv = use_conv
65
+ self.dims = dims
66
+ if use_conv:
67
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
68
+
69
+ def forward(self, x):
70
+ assert x.shape[1] == self.channels
71
+ if self.dims == 3:
72
+ x = F.interpolate(
73
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
74
+ )
75
+ else:
76
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
77
+ if self.use_conv:
78
+ x = self.conv(x)
79
+ return x
80
+
81
+
82
+ class Downsample(nn.Module):
83
+ """
84
+ A downsampling layer with an optional convolution.
85
+ :param channels: channels in the inputs and outputs.
86
+ :param use_conv: a bool determining if a convolution is applied.
87
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
88
+ downsampling occurs in the inner-two dimensions.
89
+ """
90
+
91
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
92
+ super().__init__()
93
+ self.channels = channels
94
+ self.out_channels = out_channels or channels
95
+ self.use_conv = use_conv
96
+ self.dims = dims
97
+ stride = 2 if dims != 3 else (1, 2, 2)
98
+ if use_conv:
99
+ self.op = conv_nd(
100
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
101
+ )
102
+ else:
103
+ assert self.channels == self.out_channels
104
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
105
+
106
+ def forward(self, x):
107
+ assert x.shape[1] == self.channels
108
+ return self.op(x)
109
+
110
+
111
+ class ResBlock(TimestepBlock):
112
+ """
113
+ A residual block that can optionally change the number of channels.
114
+ :param channels: the number of input channels.
115
+ :param emb_channels: the number of timestep embedding channels.
116
+ :param dropout: the rate of dropout.
117
+ :param out_channels: if specified, the number of out channels.
118
+ :param use_conv: if True and out_channels is specified, use a spatial
119
+ convolution instead of a smaller 1x1 convolution to change the
120
+ channels in the skip connection.
121
+ :param dims: determines if the signal is 1D, 2D, or 3D.
122
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
123
+ :param up: if True, use this block for upsampling.
124
+ :param down: if True, use this block for downsampling.
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ channels,
130
+ emb_channels,
131
+ dropout,
132
+ out_channels=None,
133
+ use_conv=False,
134
+ use_scale_shift_norm=False,
135
+ dims=2,
136
+ use_checkpoint=False,
137
+ up=False,
138
+ down=False,
139
+ ):
140
+ super().__init__()
141
+ self.channels = channels
142
+ self.emb_channels = emb_channels
143
+ self.dropout = dropout
144
+ self.out_channels = out_channels or channels
145
+ self.use_conv = use_conv
146
+ self.use_checkpoint = use_checkpoint
147
+ self.use_scale_shift_norm = use_scale_shift_norm
148
+
149
+ self.in_layers = nn.Sequential(
150
+ normalization(channels),
151
+ nn.SiLU(),
152
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
153
+ )
154
+
155
+ self.updown = up or down
156
+
157
+ if up:
158
+ self.h_upd = Upsample(channels, False, dims)
159
+ self.x_upd = Upsample(channels, False, dims)
160
+ elif down:
161
+ self.h_upd = Downsample(channels, False, dims)
162
+ self.x_upd = Downsample(channels, False, dims)
163
+ else:
164
+ self.h_upd = self.x_upd = nn.Identity()
165
+
166
+ self.emb_layers = nn.Sequential(
167
+ nn.SiLU(),
168
+ linear(
169
+ emb_channels,
170
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
171
+ ),
172
+ )
173
+ self.out_layers = nn.Sequential(
174
+ normalization(self.out_channels),
175
+ nn.SiLU(),
176
+ nn.Dropout(p=dropout),
177
+ zero_module(
178
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
179
+ ),
180
+ )
181
+
182
+ if self.out_channels == channels:
183
+ self.skip_connection = nn.Identity()
184
+ elif use_conv:
185
+ self.skip_connection = conv_nd(
186
+ dims, channels, self.out_channels, 3, padding=1
187
+ )
188
+ else:
189
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
190
+
191
+ def forward(self, x, emb):
192
+ """
193
+ Apply the block to a Tensor, conditioned on a timestep embedding.
194
+ :param x: an [N x C x ...] Tensor of features.
195
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
196
+ :return: an [N x C x ...] Tensor of outputs.
197
+ """
198
+ return checkpoint(
199
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
200
+ )
201
+
202
+
203
+ def _forward(self, x, emb):
204
+ if self.updown:
205
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
206
+ h = in_rest(x)
207
+ h = self.h_upd(h)
208
+ x = self.x_upd(x)
209
+ h = in_conv(h)
210
+ else:
211
+ h = self.in_layers(x)
212
+ emb_out = self.emb_layers(emb).type(h.dtype)
213
+ while len(emb_out.shape) < len(h.shape):
214
+ emb_out = emb_out[..., None]
215
+ if self.use_scale_shift_norm:
216
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
217
+ scale, shift = th.chunk(emb_out, 2, dim=1)
218
+ h = out_norm(h) * (1 + scale) + shift
219
+ h = out_rest(h)
220
+ else:
221
+ h = h + emb_out
222
+ h = self.out_layers(h)
223
+ return self.skip_connection(x) + h
224
+
225
+
226
+ class AttentionBlock(nn.Module):
227
+ """
228
+ An attention block that allows spatial positions to attend to each other.
229
+ Originally ported from here, but adapted to the N-d case.
230
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
231
+ """
232
+
233
+ def __init__(
234
+ self,
235
+ channels,
236
+ num_heads=1,
237
+ num_head_channels=-1,
238
+ use_checkpoint=False,
239
+ use_new_attention_order=False,
240
+ ):
241
+ super().__init__()
242
+ self.channels = channels
243
+ if num_head_channels == -1:
244
+ self.num_heads = num_heads
245
+ else:
246
+ assert (
247
+ channels % num_head_channels == 0
248
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
249
+ self.num_heads = channels // num_head_channels
250
+ self.use_checkpoint = use_checkpoint
251
+ self.norm = normalization(channels)
252
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
253
+ if use_new_attention_order:
254
+ # split qkv before split heads
255
+ self.attention = QKVAttention(self.num_heads)
256
+ else:
257
+ # split heads before split qkv
258
+ self.attention = QKVAttentionLegacy(self.num_heads)
259
+
260
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
261
+
262
+ def forward(self, x):
263
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
264
+ #return pt_checkpoint(self._forward, x) # pytorch
265
+
266
+ def _forward(self, x):
267
+ b, c, *spatial = x.shape
268
+ x = x.reshape(b, c, -1)
269
+ qkv = self.qkv(self.norm(x))
270
+ h = self.attention(qkv)
271
+ h = self.proj_out(h)
272
+ return (x + h).reshape(b, c, *spatial)
273
+
274
+
275
+ def count_flops_attn(model, _x, y):
276
+ """
277
+ A counter for the `thop` package to count the operations in an
278
+ attention operation.
279
+ Meant to be used like:
280
+ macs, params = thop.profile(
281
+ model,
282
+ inputs=(inputs, timestamps),
283
+ custom_ops={QKVAttention: QKVAttention.count_flops},
284
+ )
285
+ """
286
+ b, c, *spatial = y[0].shape
287
+ num_spatial = int(np.prod(spatial))
288
+ # We perform two matmuls with the same number of ops.
289
+ # The first computes the weight matrix, the second computes
290
+ # the combination of the value vectors.
291
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
292
+ model.total_ops += th.DoubleTensor([matmul_ops])
293
+
294
+
295
+ class QKVAttentionLegacy(nn.Module):
296
+ """
297
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
298
+ """
299
+
300
+ def __init__(self, n_heads):
301
+ super().__init__()
302
+ self.n_heads = n_heads
303
+
304
+ def forward(self, qkv):
305
+ """
306
+ Apply QKV attention.
307
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
308
+ :return: an [N x (H * C) x T] tensor after attention.
309
+ """
310
+ bs, width, length = qkv.shape
311
+ assert width % (3 * self.n_heads) == 0
312
+ ch = width // (3 * self.n_heads)
313
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
314
+ scale = 1 / math.sqrt(math.sqrt(ch))
315
+ weight = th.einsum(
316
+ "bct,bcs->bts", q * scale, k * scale
317
+ ) # More stable with f16 than dividing afterwards
318
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
319
+ a = th.einsum("bts,bcs->bct", weight, v)
320
+ return a.reshape(bs, -1, length)
321
+
322
+ @staticmethod
323
+ def count_flops(model, _x, y):
324
+ return count_flops_attn(model, _x, y)
325
+
326
+
327
+ class QKVAttention(nn.Module):
328
+ """
329
+ A module which performs QKV attention and splits in a different order.
330
+ """
331
+
332
+ def __init__(self, n_heads):
333
+ super().__init__()
334
+ self.n_heads = n_heads
335
+
336
+ def forward(self, qkv):
337
+ """
338
+ Apply QKV attention.
339
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
340
+ :return: an [N x (H * C) x T] tensor after attention.
341
+ """
342
+ bs, width, length = qkv.shape
343
+ assert width % (3 * self.n_heads) == 0
344
+ ch = width // (3 * self.n_heads)
345
+ q, k, v = qkv.chunk(3, dim=1)
346
+ scale = 1 / math.sqrt(math.sqrt(ch))
347
+ weight = th.einsum(
348
+ "bct,bcs->bts",
349
+ (q * scale).view(bs * self.n_heads, ch, length),
350
+ (k * scale).view(bs * self.n_heads, ch, length),
351
+ ) # More stable with f16 than dividing afterwards
352
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
353
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
354
+ return a.reshape(bs, -1, length)
355
+
356
+ @staticmethod
357
+ def count_flops(model, _x, y):
358
+ return count_flops_attn(model, _x, y)
359
+
360
+
361
+ class UNetModel(nn.Module):
362
+ """
363
+ The full UNet model with attention and timestep embedding.
364
+ :param in_channels: channels in the input Tensor.
365
+ :param model_channels: base channel count for the model.
366
+ :param out_channels: channels in the output Tensor.
367
+ :param num_res_blocks: number of residual blocks per downsample.
368
+ :param attention_resolutions: a collection of downsample rates at which
369
+ attention will take place. May be a set, list, or tuple.
370
+ For example, if this contains 4, then at 4x downsampling, attention
371
+ will be used.
372
+ :param dropout: the dropout probability.
373
+ :param channel_mult: channel multiplier for each level of the UNet.
374
+ :param conv_resample: if True, use learned convolutions for upsampling and
375
+ downsampling.
376
+ :param dims: determines if the signal is 1D, 2D, or 3D.
377
+ :param num_classes: if specified (as an int), then this model will be
378
+ class-conditional with `num_classes` classes.
379
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
380
+ :param num_heads: the number of attention heads in each attention layer.
381
+ :param num_heads_channels: if specified, ignore num_heads and instead use
382
+ a fixed channel width per attention head.
383
+ :param num_heads_upsample: works with num_heads to set a different number
384
+ of heads for upsampling. Deprecated.
385
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
386
+ :param resblock_updown: use residual blocks for up/downsampling.
387
+ :param use_new_attention_order: use a different attention pattern for potentially
388
+ increased efficiency.
389
+ """
390
+
391
+ def __init__(
392
+ self,
393
+ image_size,
394
+ in_channels,
395
+ model_channels,
396
+ out_channels,
397
+ num_res_blocks,
398
+ attention_resolutions,
399
+ dropout=0,
400
+ channel_mult=(1, 2, 4, 8),
401
+ conv_resample=True,
402
+ dims=2,
403
+ num_classes=None,
404
+ use_checkpoint=False,
405
+ use_fp16=False,
406
+ num_heads=-1,
407
+ num_head_channels=-1,
408
+ num_heads_upsample=-1,
409
+ use_scale_shift_norm=False,
410
+ resblock_updown=False,
411
+ use_new_attention_order=False,
412
+ use_spatial_transformer=False, # custom transformer support
413
+ transformer_depth=1, # custom transformer support
414
+ context_dim=None, # custom transformer support
415
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
416
+ legacy=True,
417
+ disable_self_attentions=None,
418
+ num_attention_blocks=None,
419
+ disable_middle_self_attn=False,
420
+ use_linear_in_transformer=False,
421
+ ):
422
+ super().__init__()
423
+ if use_spatial_transformer:
424
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
425
+
426
+ if context_dim is not None:
427
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
428
+ from omegaconf.listconfig import ListConfig
429
+ if type(context_dim) == ListConfig:
430
+ context_dim = list(context_dim)
431
+
432
+ if num_heads_upsample == -1:
433
+ num_heads_upsample = num_heads
434
+
435
+ if num_heads == -1:
436
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
437
+
438
+ if num_head_channels == -1:
439
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
440
+
441
+ self.image_size = image_size
442
+ self.in_channels = in_channels
443
+ self.model_channels = model_channels
444
+ self.out_channels = out_channels
445
+ if isinstance(num_res_blocks, int):
446
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
447
+ else:
448
+ if len(num_res_blocks) != len(channel_mult):
449
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
450
+ "as a list/tuple (per-level) with the same length as channel_mult")
451
+ self.num_res_blocks = num_res_blocks
452
+ if disable_self_attentions is not None:
453
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
454
+ assert len(disable_self_attentions) == len(channel_mult)
455
+ if num_attention_blocks is not None:
456
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
457
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
458
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
459
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
460
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
461
+ f"attention will still not be set.")
462
+
463
+ self.attention_resolutions = attention_resolutions
464
+ self.dropout = dropout
465
+ self.channel_mult = channel_mult
466
+ self.conv_resample = conv_resample
467
+ self.num_classes = num_classes
468
+ self.use_checkpoint = use_checkpoint
469
+ self.dtype = th.float16 if use_fp16 else th.float32
470
+ self.num_heads = num_heads
471
+ self.num_head_channels = num_head_channels
472
+ self.num_heads_upsample = num_heads_upsample
473
+ self.predict_codebook_ids = n_embed is not None
474
+
475
+ time_embed_dim = model_channels * 4
476
+ self.time_embed = nn.Sequential(
477
+ linear(model_channels, time_embed_dim),
478
+ nn.SiLU(),
479
+ linear(time_embed_dim, time_embed_dim),
480
+ )
481
+
482
+ if self.num_classes is not None:
483
+ if isinstance(self.num_classes, int):
484
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
485
+ elif self.num_classes == "continuous":
486
+ print("setting up linear c_adm embedding layer")
487
+ self.label_emb = nn.Linear(1, time_embed_dim)
488
+ else:
489
+ raise ValueError()
490
+
491
+ self.input_blocks = nn.ModuleList(
492
+ [
493
+ TimestepEmbedSequential(
494
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
495
+ )
496
+ ]
497
+ )
498
+ self._feature_size = model_channels
499
+ input_block_chans = [model_channels]
500
+ ch = model_channels
501
+ ds = 1
502
+ for level, mult in enumerate(channel_mult):
503
+ for nr in range(self.num_res_blocks[level]):
504
+ layers = [
505
+ ResBlock(
506
+ ch,
507
+ time_embed_dim,
508
+ dropout,
509
+ out_channels=mult * model_channels,
510
+ dims=dims,
511
+ use_checkpoint=use_checkpoint,
512
+ use_scale_shift_norm=use_scale_shift_norm,
513
+ )
514
+ ]
515
+ ch = mult * model_channels
516
+ if ds in attention_resolutions:
517
+ if num_head_channels == -1:
518
+ dim_head = ch // num_heads
519
+ else:
520
+ num_heads = ch // num_head_channels
521
+ dim_head = num_head_channels
522
+ if legacy:
523
+ #num_heads = 1
524
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
525
+ if exists(disable_self_attentions):
526
+ disabled_sa = disable_self_attentions[level]
527
+ else:
528
+ disabled_sa = False
529
+
530
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
531
+ layers.append(
532
+ AttentionBlock(
533
+ ch,
534
+ use_checkpoint=use_checkpoint,
535
+ num_heads=num_heads,
536
+ num_head_channels=dim_head,
537
+ use_new_attention_order=use_new_attention_order,
538
+ ) if not use_spatial_transformer else SpatialTransformer(
539
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
540
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
541
+ use_checkpoint=use_checkpoint
542
+ )
543
+ )
544
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
545
+ self._feature_size += ch
546
+ input_block_chans.append(ch)
547
+ if level != len(channel_mult) - 1:
548
+ out_ch = ch
549
+ self.input_blocks.append(
550
+ TimestepEmbedSequential(
551
+ ResBlock(
552
+ ch,
553
+ time_embed_dim,
554
+ dropout,
555
+ out_channels=out_ch,
556
+ dims=dims,
557
+ use_checkpoint=use_checkpoint,
558
+ use_scale_shift_norm=use_scale_shift_norm,
559
+ down=True,
560
+ )
561
+ if resblock_updown
562
+ else Downsample(
563
+ ch, conv_resample, dims=dims, out_channels=out_ch
564
+ )
565
+ )
566
+ )
567
+ ch = out_ch
568
+ input_block_chans.append(ch)
569
+ ds *= 2
570
+ self._feature_size += ch
571
+
572
+ if num_head_channels == -1:
573
+ dim_head = ch // num_heads
574
+ else:
575
+ num_heads = ch // num_head_channels
576
+ dim_head = num_head_channels
577
+ if legacy:
578
+ #num_heads = 1
579
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
580
+ self.middle_block = TimestepEmbedSequential(
581
+ ResBlock(
582
+ ch,
583
+ time_embed_dim,
584
+ dropout,
585
+ dims=dims,
586
+ use_checkpoint=use_checkpoint,
587
+ use_scale_shift_norm=use_scale_shift_norm,
588
+ ),
589
+ AttentionBlock(
590
+ ch,
591
+ use_checkpoint=use_checkpoint,
592
+ num_heads=num_heads,
593
+ num_head_channels=dim_head,
594
+ use_new_attention_order=use_new_attention_order,
595
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
596
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
597
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
598
+ use_checkpoint=use_checkpoint
599
+ ),
600
+ ResBlock(
601
+ ch,
602
+ time_embed_dim,
603
+ dropout,
604
+ dims=dims,
605
+ use_checkpoint=use_checkpoint,
606
+ use_scale_shift_norm=use_scale_shift_norm,
607
+ ),
608
+ )
609
+ self._feature_size += ch
610
+
611
+ self.output_blocks = nn.ModuleList([])
612
+ for level, mult in list(enumerate(channel_mult))[::-1]:
613
+ for i in range(self.num_res_blocks[level] + 1):
614
+ ich = input_block_chans.pop()
615
+ layers = [
616
+ ResBlock(
617
+ ch + ich,
618
+ time_embed_dim,
619
+ dropout,
620
+ out_channels=model_channels * mult,
621
+ dims=dims,
622
+ use_checkpoint=use_checkpoint,
623
+ use_scale_shift_norm=use_scale_shift_norm,
624
+ )
625
+ ]
626
+ ch = model_channels * mult
627
+ if ds in attention_resolutions:
628
+ if num_head_channels == -1:
629
+ dim_head = ch // num_heads
630
+ else:
631
+ num_heads = ch // num_head_channels
632
+ dim_head = num_head_channels
633
+ if legacy:
634
+ #num_heads = 1
635
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
636
+ if exists(disable_self_attentions):
637
+ disabled_sa = disable_self_attentions[level]
638
+ else:
639
+ disabled_sa = False
640
+
641
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
642
+ layers.append(
643
+ AttentionBlock(
644
+ ch,
645
+ use_checkpoint=use_checkpoint,
646
+ num_heads=num_heads_upsample,
647
+ num_head_channels=dim_head,
648
+ use_new_attention_order=use_new_attention_order,
649
+ ) if not use_spatial_transformer else SpatialTransformer(
650
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
651
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
652
+ use_checkpoint=use_checkpoint
653
+ )
654
+ )
655
+ if level and i == self.num_res_blocks[level]:
656
+ out_ch = ch
657
+ layers.append(
658
+ ResBlock(
659
+ ch,
660
+ time_embed_dim,
661
+ dropout,
662
+ out_channels=out_ch,
663
+ dims=dims,
664
+ use_checkpoint=use_checkpoint,
665
+ use_scale_shift_norm=use_scale_shift_norm,
666
+ up=True,
667
+ )
668
+ if resblock_updown
669
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
670
+ )
671
+ ds //= 2
672
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
673
+ self._feature_size += ch
674
+
675
+ self.out = nn.Sequential(
676
+ normalization(ch),
677
+ nn.SiLU(),
678
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
679
+ )
680
+ if self.predict_codebook_ids:
681
+ self.id_predictor = nn.Sequential(
682
+ normalization(ch),
683
+ conv_nd(dims, model_channels, n_embed, 1),
684
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
685
+ )
686
+
687
+ def forward(self, x, timesteps, context=None, y=None,**kwargs):
688
+ """
689
+ Apply the model to an input batch.
690
+ :param x: an [N x C x ...] Tensor of inputs.
691
+ :param timesteps: a 1-D batch of timesteps.
692
+ :param context: conditioning plugged in via crossattn
693
+ :param y: an [N] Tensor of labels, if class-conditional.
694
+ :return: an [N x C x ...] Tensor of outputs.
695
+ """
696
+ assert (y is not None) == (
697
+ self.num_classes is not None
698
+ ), "must specify y if and only if the model is class-conditional"
699
+ hs = []
700
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
701
+ emb = self.time_embed(t_emb)
702
+
703
+ if self.num_classes is not None:
704
+ assert y.shape[0] == x.shape[0]
705
+ emb = emb + self.label_emb(y)
706
+
707
+ h = x.type(self.dtype)
708
+ for module in self.input_blocks:
709
+ h = module(h, emb, context)
710
+ hs.append(h)
711
+ h = self.middle_block(h, emb, context)
712
+ for module in self.output_blocks:
713
+ h = th.cat([h, hs.pop()], dim=1)
714
+ h = module(h, emb, context)
715
+ h = h.type(x.dtype)
716
+ if self.predict_codebook_ids:
717
+ return self.id_predictor(h)
718
+ else:
719
+ return self.out(h)
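
A rough sketch of how this UNetModel is driven with noisy latents, integer timesteps, and cross-attention context (illustrative only; the hyper-parameters below are small assumed values rather than a configuration from this repository, and passing context_dim makes the constructor import omegaconf):

# Hypothetical example with deliberately small channel counts so it runs quickly
import torch as th
from model.unet import UNetModel

unet = UNetModel(
    image_size=32, in_channels=4, model_channels=64, out_channels=4,
    num_res_blocks=1,
    attention_resolutions=(4,),   # attention only at 4x downsampling
    channel_mult=(1, 2, 4),
    num_head_channels=32,
    use_spatial_transformer=True, transformer_depth=1,
    context_dim=768,              # e.g. the width of a text-encoder embedding
    legacy=False,
)

x = th.randn(2, 4, 32, 32)        # noisy latents
t = th.randint(0, 1000, (2,))     # diffusion timesteps
context = th.randn(2, 77, 768)    # cross-attention conditioning tokens
eps = unet(x, t, context=context)
print(eps.shape)                  # same shape as the input latents: (2, 4, 32, 32)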
model/util.py ADDED
@@ -0,0 +1,225 @@
1
+ # adapted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ from inspect import isfunction
14
+ import torch
15
+ import torch.nn as nn
16
+ import numpy as np
17
+ from einops import repeat
18
+
19
+
20
+ def exists(val):
21
+ return val is not None
22
+
23
+
24
+ def default(val, d):
25
+ if exists(val):
26
+ return val
27
+ return d() if isfunction(d) else d
28
+
29
+
30
+ def checkpoint(func, inputs, params, flag):
31
+ """
32
+ Evaluate a function without caching intermediate activations, allowing for
33
+ reduced memory at the expense of extra compute in the backward pass.
34
+ :param func: the function to evaluate.
35
+ :param inputs: the argument sequence to pass to `func`.
36
+ :param params: a sequence of parameters `func` depends on but does not
37
+ explicitly take as arguments.
38
+ :param flag: if False, disable gradient checkpointing.
39
+ """
40
+ if flag:
41
+ args = tuple(inputs) + tuple(params)
42
+ return CheckpointFunction.apply(func, len(inputs), *args)
43
+ else:
44
+ return func(*inputs)
45
+
46
+
47
+ # class CheckpointFunction(torch.autograd.Function):
48
+ # @staticmethod
49
+ # def forward(ctx, run_function, length, *args):
50
+ # ctx.run_function = run_function
51
+ # ctx.input_tensors = list(args[:length])
52
+ # ctx.input_params = list(args[length:])
53
+ # ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
54
+ # "dtype": torch.get_autocast_gpu_dtype(),
55
+ # "cache_enabled": torch.is_autocast_cache_enabled()}
56
+ # with torch.no_grad():
57
+ # output_tensors = ctx.run_function(*ctx.input_tensors)
58
+ # return output_tensors
59
+
60
+ # @staticmethod
61
+ # def backward(ctx, *output_grads):
62
+ # ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
63
+ # with torch.enable_grad(), \
64
+ # torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
65
+ # # Fixes a bug where the first op in run_function modifies the
66
+ # # Tensor storage in place, which is not allowed for detach()'d
67
+ # # Tensors.
68
+ # shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
69
+ # output_tensors = ctx.run_function(*shallow_copies)
70
+ # input_grads = torch.autograd.grad(
71
+ # output_tensors,
72
+ # ctx.input_tensors + ctx.input_params,
73
+ # output_grads,
74
+ # allow_unused=True,
75
+ # )
76
+ # del ctx.input_tensors
77
+ # del ctx.input_params
78
+ # del output_tensors
79
+ # return (None, None) + input_grads
80
+
81
+
82
+ # Fixes: When we set unet parameters with requires_grad=False, the original CheckpointFunction
83
+ # still tries to compute gradient for unet parameters.
84
+ # https://discuss.pytorch.org/t/get-runtimeerror-one-of-the-differentiated-tensors-does-not-require-grad-in-pytorch-lightning/179738/6
85
+ class CheckpointFunction(torch.autograd.Function):
86
+ @staticmethod
87
+ def forward(ctx, run_function, length, *args):
88
+ ctx.run_function = run_function
89
+ ctx.input_tensors = list(args[:length])
90
+ ctx.input_params = list(args[length:])
91
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
92
+ "dtype": torch.get_autocast_gpu_dtype(),
93
+ "cache_enabled": torch.is_autocast_cache_enabled()}
94
+ with torch.no_grad():
95
+ output_tensors = ctx.run_function(*ctx.input_tensors)
96
+ return output_tensors
97
+
98
+ @staticmethod
99
+ def backward(ctx, *output_grads):
100
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
101
+ with torch.enable_grad(), \
102
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
103
+ # Fixes a bug where the first op in run_function modifies the
104
+ # Tensor storage in place, which is not allowed for detach()'d
105
+ # Tensors.
106
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
107
+ output_tensors = ctx.run_function(*shallow_copies)
108
+ grads = torch.autograd.grad(
109
+ output_tensors,
110
+ ctx.input_tensors + [x for x in ctx.input_params if x.requires_grad],
111
+ output_grads,
112
+ allow_unused=True,
113
+ )
114
+ grads = list(grads)
115
+ # Assign gradients to the correct positions, matching None for those that do not require gradients
116
+ input_grads = []
117
+ for tensor in ctx.input_tensors + ctx.input_params:
118
+ if tensor.requires_grad:
119
+ input_grads.append(grads.pop(0)) # Get the next computed gradient
120
+ else:
121
+ input_grads.append(None) # No gradient required for this tensor
122
+ del ctx.input_tensors
123
+ del ctx.input_params
124
+ del output_tensors
125
+ return (None, None) + tuple(input_grads)
126
+
127
+
128
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
129
+ """
130
+ Create sinusoidal timestep embeddings.
131
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
132
+ These may be fractional.
133
+ :param dim: the dimension of the output.
134
+ :param max_period: controls the minimum frequency of the embeddings.
135
+ :return: an [N x dim] Tensor of positional embeddings.
136
+ """
137
+ if not repeat_only:
138
+ half = dim // 2
139
+ freqs = torch.exp(
140
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
141
+ ).to(device=timesteps.device)
142
+ args = timesteps[:, None].float() * freqs[None]
143
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
144
+ if dim % 2:
145
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
146
+ else:
147
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
148
+ return embedding
149
+
150
+
151
+ def zero_module(module):
152
+ """
153
+ Zero out the parameters of a module and return it.
154
+ """
155
+ for p in module.parameters():
156
+ p.detach().zero_()
157
+ return module
158
+
159
+
160
+ def scale_module(module, scale):
161
+ """
162
+ Scale the parameters of a module and return it.
163
+ """
164
+ for p in module.parameters():
165
+ p.detach().mul_(scale)
166
+ return module
167
+
168
+
169
+ def mean_flat(tensor):
170
+ """
171
+ Take the mean over all non-batch dimensions.
172
+ """
173
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
174
+
175
+
176
+ def normalization(channels):
177
+ """
178
+ Make a standard normalization layer.
179
+ :param channels: number of input channels.
180
+ :return: an nn.Module for normalization.
181
+ """
182
+ return GroupNorm32(32, channels)
183
+
184
+
185
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
186
+ class SiLU(nn.Module):
187
+ def forward(self, x):
188
+ return x * torch.sigmoid(x)
189
+
190
+
191
+ class GroupNorm32(nn.GroupNorm):
192
+ def forward(self, x):
193
+ return super().forward(x.float()).type(x.dtype)
194
+
195
+ def conv_nd(dims, *args, **kwargs):
196
+ """
197
+ Create a 1D, 2D, or 3D convolution module.
198
+ """
199
+ if dims == 1:
200
+ return nn.Conv1d(*args, **kwargs)
201
+ elif dims == 2:
202
+ return nn.Conv2d(*args, **kwargs)
203
+ elif dims == 3:
204
+ return nn.Conv3d(*args, **kwargs)
205
+ raise ValueError(f"unsupported dimensions: {dims}")
206
+
207
+
208
+ def linear(*args, **kwargs):
209
+ """
210
+ Create a linear module.
211
+ """
212
+ return nn.Linear(*args, **kwargs)
213
+
214
+
215
+ def avg_pool_nd(dims, *args, **kwargs):
216
+ """
217
+ Create a 1D, 2D, or 3D average pooling module.
218
+ """
219
+ if dims == 1:
220
+ return nn.AvgPool1d(*args, **kwargs)
221
+ elif dims == 2:
222
+ return nn.AvgPool2d(*args, **kwargs)
223
+ elif dims == 3:
224
+ return nn.AvgPool3d(*args, **kwargs)
225
+ raise ValueError(f"unsupported dimensions: {dims}")
model/vae.py ADDED
@@ -0,0 +1,567 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from typing import Optional, Any
8
+
9
+ from model.distributions import DiagonalGaussianDistribution
10
+ from model.config import Config, AttnMode
11
+
12
+
13
+ def nonlinearity(x):
14
+ # swish
15
+ return x*torch.sigmoid(x)
16
+
17
+
18
+ def Normalize(in_channels, num_groups=32):
19
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
20
+
21
+
22
+ class Upsample(nn.Module):
23
+ def __init__(self, in_channels, with_conv):
24
+ super().__init__()
25
+ self.with_conv = with_conv
26
+ if self.with_conv:
27
+ self.conv = torch.nn.Conv2d(in_channels,
28
+ in_channels,
29
+ kernel_size=3,
30
+ stride=1,
31
+ padding=1)
32
+
33
+ def forward(self, x):
34
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
35
+ if self.with_conv:
36
+ x = self.conv(x)
37
+ return x
38
+
39
+
40
+ class Downsample(nn.Module):
41
+ def __init__(self, in_channels, with_conv):
42
+ super().__init__()
43
+ self.with_conv = with_conv
44
+ if self.with_conv:
45
+ # no asymmetric padding in torch conv, must do it ourselves
46
+ self.conv = torch.nn.Conv2d(in_channels,
47
+ in_channels,
48
+ kernel_size=3,
49
+ stride=2,
50
+ padding=0)
51
+
52
+ def forward(self, x):
53
+ if self.with_conv:
54
+ pad = (0,1,0,1)
55
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
56
+ x = self.conv(x)
57
+ else:
58
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
59
+ return x
60
+
61
+
62
+ class ResnetBlock(nn.Module):
63
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
64
+ dropout, temb_channels=512):
65
+ super().__init__()
66
+ self.in_channels = in_channels
67
+ out_channels = in_channels if out_channels is None else out_channels
68
+ self.out_channels = out_channels
69
+ self.use_conv_shortcut = conv_shortcut
70
+
71
+ self.norm1 = Normalize(in_channels)
72
+ self.conv1 = torch.nn.Conv2d(in_channels,
73
+ out_channels,
74
+ kernel_size=3,
75
+ stride=1,
76
+ padding=1)
77
+ if temb_channels > 0:
78
+ self.temb_proj = torch.nn.Linear(temb_channels,
79
+ out_channels)
80
+ self.norm2 = Normalize(out_channels)
81
+ self.dropout = torch.nn.Dropout(dropout)
82
+ self.conv2 = torch.nn.Conv2d(out_channels,
83
+ out_channels,
84
+ kernel_size=3,
85
+ stride=1,
86
+ padding=1)
87
+ if self.in_channels != self.out_channels:
88
+ if self.use_conv_shortcut:
89
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
90
+ out_channels,
91
+ kernel_size=3,
92
+ stride=1,
93
+ padding=1)
94
+ else:
95
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
96
+ out_channels,
97
+ kernel_size=1,
98
+ stride=1,
99
+ padding=0)
100
+
101
+ def forward(self, x, temb):
102
+ h = x
103
+ h = self.norm1(h)
104
+ h = nonlinearity(h)
105
+ h = self.conv1(h)
106
+
107
+ if temb is not None:
108
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
109
+
110
+ h = self.norm2(h)
111
+ h = nonlinearity(h)
112
+ h = self.dropout(h)
113
+ h = self.conv2(h)
114
+
115
+ if self.in_channels != self.out_channels:
116
+ if self.use_conv_shortcut:
117
+ x = self.conv_shortcut(x)
118
+ else:
119
+ x = self.nin_shortcut(x)
120
+
121
+ return x+h
122
+
123
+
124
+ class AttnBlock(nn.Module):
125
+ def __init__(self, in_channels):
126
+ super().__init__()
127
+ print(f"building AttnBlock (vanilla) with {in_channels} in_channels")
128
+
129
+ self.in_channels = in_channels
130
+
131
+ self.norm = Normalize(in_channels)
132
+ self.q = torch.nn.Conv2d(in_channels,
133
+ in_channels,
134
+ kernel_size=1,
135
+ stride=1,
136
+ padding=0)
137
+ self.k = torch.nn.Conv2d(in_channels,
138
+ in_channels,
139
+ kernel_size=1,
140
+ stride=1,
141
+ padding=0)
142
+ self.v = torch.nn.Conv2d(in_channels,
143
+ in_channels,
144
+ kernel_size=1,
145
+ stride=1,
146
+ padding=0)
147
+ self.proj_out = torch.nn.Conv2d(in_channels,
148
+ in_channels,
149
+ kernel_size=1,
150
+ stride=1,
151
+ padding=0)
152
+
153
+ def forward(self, x):
154
+ h_ = x
155
+ h_ = self.norm(h_)
156
+ q = self.q(h_)
157
+ k = self.k(h_)
158
+ v = self.v(h_)
159
+
160
+ # compute attention
161
+ b,c,h,w = q.shape
162
+ q = q.reshape(b,c,h*w)
163
+ q = q.permute(0,2,1) # b,hw,c
164
+ k = k.reshape(b,c,h*w) # b,c,hw
165
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
166
+ w_ = w_ * (int(c)**(-0.5))
167
+ w_ = torch.nn.functional.softmax(w_, dim=2)
168
+
169
+ # attend to values
170
+ v = v.reshape(b,c,h*w)
171
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
172
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
173
+ h_ = h_.reshape(b,c,h,w)
174
+
175
+ h_ = self.proj_out(h_)
176
+
177
+ return x+h_
178
+
179
+
180
+ class MemoryEfficientAttnBlock(nn.Module):
181
+ """
182
+ Uses xformers efficient implementation,
183
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
184
+ Note: this is a single-head self-attention operation
185
+ """
186
+ #
187
+ def __init__(self, in_channels):
188
+ super().__init__()
189
+ print(f"building MemoryEfficientAttnBlock (xformers) with {in_channels} in_channels")
190
+ self.in_channels = in_channels
191
+
192
+ self.norm = Normalize(in_channels)
193
+ self.q = torch.nn.Conv2d(in_channels,
194
+ in_channels,
195
+ kernel_size=1,
196
+ stride=1,
197
+ padding=0)
198
+ self.k = torch.nn.Conv2d(in_channels,
199
+ in_channels,
200
+ kernel_size=1,
201
+ stride=1,
202
+ padding=0)
203
+ self.v = torch.nn.Conv2d(in_channels,
204
+ in_channels,
205
+ kernel_size=1,
206
+ stride=1,
207
+ padding=0)
208
+ self.proj_out = torch.nn.Conv2d(in_channels,
209
+ in_channels,
210
+ kernel_size=1,
211
+ stride=1,
212
+ padding=0)
213
+ self.attention_op: Optional[Any] = None
214
+
215
+ def forward(self, x):
216
+ h_ = x
217
+ h_ = self.norm(h_)
218
+ q = self.q(h_)
219
+ k = self.k(h_)
220
+ v = self.v(h_)
221
+
222
+ # compute attention
223
+ B, C, H, W = q.shape
224
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
225
+
226
+ q, k, v = map(
227
+ lambda t: t.unsqueeze(3)
228
+ .reshape(B, t.shape[1], 1, C)
229
+ .permute(0, 2, 1, 3)
230
+ .reshape(B * 1, t.shape[1], C)
231
+ .contiguous(),
232
+ (q, k, v),
233
+ )
234
+ out = Config.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
235
+
236
+ out = (
237
+ out.unsqueeze(0)
238
+ .reshape(B, 1, out.shape[1], C)
239
+ .permute(0, 2, 1, 3)
240
+ .reshape(B, out.shape[1], C)
241
+ )
242
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
243
+ out = self.proj_out(out)
244
+ return x+out
245
+
246
+
247
+ class SDPAttnBlock(nn.Module):
248
+
249
+ def __init__(self, in_channels):
250
+ super().__init__()
251
+ print(f"building SDPAttnBlock (sdp) with {in_channels} in_channels")
252
+ self.in_channels = in_channels
253
+
254
+ self.norm = Normalize(in_channels)
255
+ self.q = torch.nn.Conv2d(in_channels,
256
+ in_channels,
257
+ kernel_size=1,
258
+ stride=1,
259
+ padding=0)
260
+ self.k = torch.nn.Conv2d(in_channels,
261
+ in_channels,
262
+ kernel_size=1,
263
+ stride=1,
264
+ padding=0)
265
+ self.v = torch.nn.Conv2d(in_channels,
266
+ in_channels,
267
+ kernel_size=1,
268
+ stride=1,
269
+ padding=0)
270
+ self.proj_out = torch.nn.Conv2d(in_channels,
271
+ in_channels,
272
+ kernel_size=1,
273
+ stride=1,
274
+ padding=0)
275
+
276
+ def forward(self, x):
277
+ h_ = x
278
+ h_ = self.norm(h_)
279
+ q = self.q(h_)
280
+ k = self.k(h_)
281
+ v = self.v(h_)
282
+
283
+ # compute attention
284
+ B, C, H, W = q.shape
285
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
286
+
287
+ q, k, v = map(
288
+ lambda t: t.unsqueeze(3)
289
+ .reshape(B, t.shape[1], 1, C)
290
+ .permute(0, 2, 1, 3)
291
+ .reshape(B * 1, t.shape[1], C)
292
+ .contiguous(),
293
+ (q, k, v),
294
+ )
295
+ out = F.scaled_dot_product_attention(q, k, v)
296
+
297
+ out = (
298
+ out.unsqueeze(0)
299
+ .reshape(B, 1, out.shape[1], C)
300
+ .permute(0, 2, 1, 3)
301
+ .reshape(B, out.shape[1], C)
302
+ )
303
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
304
+ out = self.proj_out(out)
305
+ return x+out
306
+
307
+
308
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
309
+ assert attn_type in ["vanilla", "sdp", "xformers", "linear", "none"], f'attn_type {attn_type} unknown'
310
+ if attn_type == "vanilla":
311
+ assert attn_kwargs is None
312
+ return AttnBlock(in_channels)
313
+ elif attn_type == "sdp":
314
+ return SDPAttnBlock(in_channels)
315
+ elif attn_type == "xformers":
316
+ return MemoryEfficientAttnBlock(in_channels)
317
+ elif attn_type == "none":
318
+ return nn.Identity(in_channels)
319
+ else:
320
+ raise NotImplementedError()
321
+
322
+
323
+ class Encoder(nn.Module):
324
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
325
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
326
+ resolution, z_channels, double_z=True, use_linear_attn=False,
327
+ **ignore_kwargs):
328
+ super().__init__()
329
+ ### setup attention type
330
+ if Config.attn_mode == AttnMode.SDP:
331
+ attn_type = "sdp"
332
+ elif Config.attn_mode == AttnMode.XFORMERS:
333
+ attn_type = "xformers"
334
+ else:
335
+ attn_type = "vanilla"
336
+ if use_linear_attn: attn_type = "linear"
337
+ self.ch = ch
338
+ self.temb_ch = 0
339
+ self.num_resolutions = len(ch_mult)
340
+ self.num_res_blocks = num_res_blocks
341
+ self.resolution = resolution
342
+ self.in_channels = in_channels
343
+
344
+ # downsampling
345
+ self.conv_in = torch.nn.Conv2d(in_channels,
346
+ self.ch,
347
+ kernel_size=3,
348
+ stride=1,
349
+ padding=1)
350
+
351
+ curr_res = resolution
352
+ in_ch_mult = (1,)+tuple(ch_mult)
353
+ self.in_ch_mult = in_ch_mult
354
+ self.down = nn.ModuleList()
355
+ for i_level in range(self.num_resolutions):
356
+ block = nn.ModuleList()
357
+ attn = nn.ModuleList()
358
+ block_in = ch*in_ch_mult[i_level]
359
+ block_out = ch*ch_mult[i_level]
360
+ for i_block in range(self.num_res_blocks):
361
+ block.append(ResnetBlock(in_channels=block_in,
362
+ out_channels=block_out,
363
+ temb_channels=self.temb_ch,
364
+ dropout=dropout))
365
+ block_in = block_out
366
+ if curr_res in attn_resolutions:
367
+ attn.append(make_attn(block_in, attn_type=attn_type))
368
+ down = nn.Module()
369
+ down.block = block
370
+ down.attn = attn
371
+ if i_level != self.num_resolutions-1:
372
+ down.downsample = Downsample(block_in, resamp_with_conv)
373
+ curr_res = curr_res // 2
374
+ self.down.append(down)
375
+
376
+ # middle
377
+ self.mid = nn.Module()
378
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
379
+ out_channels=block_in,
380
+ temb_channels=self.temb_ch,
381
+ dropout=dropout)
382
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
383
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
384
+ out_channels=block_in,
385
+ temb_channels=self.temb_ch,
386
+ dropout=dropout)
387
+
388
+ # end
389
+ self.norm_out = Normalize(block_in)
390
+ self.conv_out = torch.nn.Conv2d(block_in,
391
+ 2*z_channels if double_z else z_channels,
392
+ kernel_size=3,
393
+ stride=1,
394
+ padding=1)
395
+
396
+ def forward(self, x):
397
+ # timestep embedding
398
+ temb = None
399
+
400
+ # downsampling
401
+ hs = [self.conv_in(x)]
402
+ for i_level in range(self.num_resolutions):
403
+ for i_block in range(self.num_res_blocks):
404
+ h = self.down[i_level].block[i_block](hs[-1], temb)
405
+ if len(self.down[i_level].attn) > 0:
406
+ h = self.down[i_level].attn[i_block](h)
407
+ hs.append(h)
408
+ if i_level != self.num_resolutions-1:
409
+ hs.append(self.down[i_level].downsample(hs[-1]))
410
+
411
+ # middle
412
+ h = hs[-1]
413
+ h = self.mid.block_1(h, temb)
414
+ h = self.mid.attn_1(h)
415
+ h = self.mid.block_2(h, temb)
416
+
417
+ # end
418
+ h = self.norm_out(h)
419
+ h = nonlinearity(h)
420
+ h = self.conv_out(h)
421
+ return h
422
+
423
+
424
+ class Decoder(nn.Module):
425
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
426
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
427
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
428
+ **ignorekwargs):
429
+ super().__init__()
430
+ ### setup attention type
431
+ if Config.attn_mode == AttnMode.SDP:
432
+ attn_type = "sdp"
433
+ elif Config.attn_mode == AttnMode.XFORMERS:
434
+ attn_type = "xformers"
435
+ else:
436
+ attn_type = "vanilla"
437
+ if use_linear_attn: attn_type = "linear"
438
+ self.ch = ch
439
+ self.temb_ch = 0
440
+ self.num_resolutions = len(ch_mult)
441
+ self.num_res_blocks = num_res_blocks
442
+ self.resolution = resolution
443
+ self.in_channels = in_channels
444
+ self.give_pre_end = give_pre_end
445
+ self.tanh_out = tanh_out
446
+
447
+ # compute in_ch_mult, block_in and curr_res at lowest res
448
+ in_ch_mult = (1,)+tuple(ch_mult)
449
+ block_in = ch*ch_mult[self.num_resolutions-1]
450
+ curr_res = resolution // 2**(self.num_resolutions-1)
451
+ self.z_shape = (1,z_channels,curr_res,curr_res)
452
+
453
+ # z to block_in
454
+ self.conv_in = torch.nn.Conv2d(z_channels,
455
+ block_in,
456
+ kernel_size=3,
457
+ stride=1,
458
+ padding=1)
459
+
460
+ # middle
461
+ self.mid = nn.Module()
462
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
463
+ out_channels=block_in,
464
+ temb_channels=self.temb_ch,
465
+ dropout=dropout)
466
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
467
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
468
+ out_channels=block_in,
469
+ temb_channels=self.temb_ch,
470
+ dropout=dropout)
471
+
472
+ # upsampling
473
+ self.up = nn.ModuleList()
474
+ for i_level in reversed(range(self.num_resolutions)):
475
+ block = nn.ModuleList()
476
+ attn = nn.ModuleList()
477
+ block_out = ch*ch_mult[i_level]
478
+ for i_block in range(self.num_res_blocks+1):
479
+ block.append(ResnetBlock(in_channels=block_in,
480
+ out_channels=block_out,
481
+ temb_channels=self.temb_ch,
482
+ dropout=dropout))
483
+ block_in = block_out
484
+ if curr_res in attn_resolutions:
485
+ attn.append(make_attn(block_in, attn_type=attn_type))
486
+ up = nn.Module()
487
+ up.block = block
488
+ up.attn = attn
489
+ if i_level != 0:
490
+ up.upsample = Upsample(block_in, resamp_with_conv)
491
+ curr_res = curr_res * 2
492
+ self.up.insert(0, up) # prepend to get consistent order
493
+
494
+ # end
495
+ self.norm_out = Normalize(block_in)
496
+ self.conv_out = torch.nn.Conv2d(block_in,
497
+ out_ch,
498
+ kernel_size=3,
499
+ stride=1,
500
+ padding=1)
501
+
502
+ def forward(self, z):
503
+ #assert z.shape[1:] == self.z_shape[1:]
504
+ self.last_z_shape = z.shape
505
+
506
+ # timestep embedding
507
+ temb = None
508
+
509
+ # z to block_in
510
+ h = self.conv_in(z)
511
+
512
+ # middle
513
+ h = self.mid.block_1(h, temb)
514
+ h = self.mid.attn_1(h)
515
+ h = self.mid.block_2(h, temb)
516
+
517
+ # upsampling
518
+ for i_level in reversed(range(self.num_resolutions)):
519
+ for i_block in range(self.num_res_blocks+1):
520
+ h = self.up[i_level].block[i_block](h, temb)
521
+ if len(self.up[i_level].attn) > 0:
522
+ h = self.up[i_level].attn[i_block](h)
523
+ if i_level != 0:
524
+ h = self.up[i_level].upsample(h)
525
+
526
+ # end
527
+ if self.give_pre_end:
528
+ return h
529
+
530
+ h = self.norm_out(h)
531
+ h = nonlinearity(h)
532
+ h = self.conv_out(h)
533
+ if self.tanh_out:
534
+ h = torch.tanh(h)
535
+ return h
536
+
537
+
538
+ class AutoencoderKL(nn.Module):
539
+
540
+ def __init__(self, ddconfig, embed_dim):
541
+ super().__init__()
542
+ self.encoder = Encoder(**ddconfig)
543
+ self.decoder = Decoder(**ddconfig)
544
+ assert ddconfig["double_z"]
545
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
546
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
547
+ self.embed_dim = embed_dim
548
+
549
+ def encode(self, x):
550
+ h = self.encoder(x)
551
+ moments = self.quant_conv(h)
552
+ posterior = DiagonalGaussianDistribution(moments)
553
+ return posterior
554
+
555
+ def decode(self, z):
556
+ z = self.post_quant_conv(z)
557
+ dec = self.decoder(z)
558
+ return dec
559
+
560
+ def forward(self, input, sample_posterior=True):
561
+ posterior = self.encode(input)
562
+ if sample_posterior:
563
+ z = posterior.sample()
564
+ else:
565
+ z = posterior.mode()
566
+ dec = self.decode(z)
567
+ return dec, posterior
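
As a quick orientation for the file above, AutoencoderKL wraps the Encoder/Decoder pair: encode() returns a DiagonalGaussianDistribution over the latent, and decode() maps a sampled (or mode) latent back to image space. The sketch below shows a round trip, assuming the file is importable as model.vae; the ddconfig values are typical SD-style settings chosen for illustration, not read from this repository's config files.

    import torch
    from model.vae import AutoencoderKL

    # Illustrative ddconfig (assumed values, not taken from this repo's configs).
    ddconfig = dict(
        double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
        ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, attn_resolutions=[],
        dropout=0.0,
    )
    vae = AutoencoderKL(ddconfig, embed_dim=4).eval()

    x = torch.randn(1, 3, 256, 256)      # image tensor in [-1, 1], NCHW
    with torch.no_grad():
        posterior = vae.encode(x)        # DiagonalGaussianDistribution
        z = posterior.mode()             # [1, 4, 32, 32] latent (8x spatial downsample)
        recon = vae.decode(z)            # [1, 3, 256, 256] reconstruction

With ch_mult of length 4 the encoder downsamples three times, so a 256x256 input yields a 32x32 latent with embed_dim channels; sample_posterior in forward() switches between posterior.sample() and posterior.mode().
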