csaybar commited on
Commit
476803e
·
verified ·
1 Parent(s): 0475d44

Upload 5 files

Browse files
swin2_mose/model.py ADDED
@@ -0,0 +1,1157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Source code: https://github.com/mv-lab/swin2sr
3
+ #
4
+ # -----------------------------------------------------------------------------------
5
+ # Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345
6
+ # Written by Conde and Choi et al.
7
+ # -----------------------------------------------------------------------------------
8
+
9
+ import math
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
16
+
17
+ from utils import window_reverse, Mlp, window_partition
18
+ from moe import MoE
19
+
20
+
21
+ class WindowAttention(nn.Module):
22
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
23
+ It supports both of shifted and non-shifted window.
24
+ Args:
25
+ dim (int): Number of input channels.
26
+ window_size (tuple[int]): The height and width of the window.
27
+ num_heads (int): Number of attention heads.
28
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
29
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
30
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
31
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
32
+ """
33
+
34
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
35
+ pretrained_window_size=[0, 0],
36
+ use_lepe=False,
37
+ use_cpb_bias=True,
38
+ use_rpe_bias=False):
39
+
40
+ super().__init__()
41
+ self.dim = dim
42
+ self.window_size = window_size # Wh, Ww
43
+ self.pretrained_window_size = pretrained_window_size
44
+ self.num_heads = num_heads
45
+
46
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
47
+
48
+ self.use_cpb_bias = use_cpb_bias
49
+
50
+ if self.use_cpb_bias:
51
+ print('positional encoder: CPB')
52
+ # mlp to generate continuous relative position bias
53
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
54
+ nn.ReLU(inplace=True),
55
+ nn.Linear(512, num_heads, bias=False))
56
+
57
+ # get relative_coords_table
58
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
59
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
60
+ relative_coords_table = torch.stack(
61
+ torch.meshgrid([relative_coords_h,
62
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
63
+ if pretrained_window_size[0] > 0:
64
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
65
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
66
+ else:
67
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
68
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
69
+ relative_coords_table *= 8 # normalize to -8, 8
70
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
71
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
72
+
73
+ self.register_buffer("relative_coords_table", relative_coords_table)
74
+
75
+ # get pair-wise relative position index for each token inside the window
76
+ coords_h = torch.arange(self.window_size[0])
77
+ coords_w = torch.arange(self.window_size[1])
78
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
79
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
80
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
81
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
82
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
83
+ relative_coords[:, :, 1] += self.window_size[1] - 1
84
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
85
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
86
+ self.register_buffer("relative_position_index", relative_position_index)
87
+
88
+ self.use_rpe_bias = use_rpe_bias
89
+ if self.use_rpe_bias:
90
+ print('positional encoder: RPE')
91
+ # define a parameter table of relative position bias
92
+ self.relative_position_bias_table = nn.Parameter(
93
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
94
+
95
+ # get pair-wise relative position index for each token inside the window
96
+ coords_h = torch.arange(self.window_size[0])
97
+ coords_w = torch.arange(self.window_size[1])
98
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
99
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
100
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
101
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
102
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
103
+ relative_coords[:, :, 1] += self.window_size[1] - 1
104
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
105
+ rpe_relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
106
+ self.register_buffer("rpe_relative_position_index", rpe_relative_position_index)
107
+
108
+ trunc_normal_(self.relative_position_bias_table, std=.02)
109
+
110
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
111
+ if qkv_bias:
112
+ self.q_bias = nn.Parameter(torch.zeros(dim))
113
+ self.v_bias = nn.Parameter(torch.zeros(dim))
114
+ else:
115
+ self.q_bias = None
116
+ self.v_bias = None
117
+ self.attn_drop = nn.Dropout(attn_drop)
118
+ self.proj = nn.Linear(dim, dim)
119
+ self.proj_drop = nn.Dropout(proj_drop)
120
+ self.softmax = nn.Softmax(dim=-1)
121
+
122
+ self.use_lepe = use_lepe
123
+ if self.use_lepe:
124
+ print('positional encoder: LEPE')
125
+ self.get_v = nn.Conv2d(
126
+ dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
127
+
128
+ def forward(self, x, mask=None):
129
+ """
130
+ Args:
131
+ x: input features with shape of (num_windows*B, N, C)
132
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
133
+ """
134
+ B_, N, C = x.shape
135
+ qkv_bias = None
136
+ if self.q_bias is not None:
137
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
138
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
139
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
140
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
141
+
142
+ if self.use_lepe:
143
+ lepe = self.lepe_pos(v)
144
+
145
+ # cosine attention
146
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
147
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
148
+ attn = attn * logit_scale
149
+
150
+ if self.use_cpb_bias:
151
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
152
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
153
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
154
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
155
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
156
+ attn = attn + relative_position_bias.unsqueeze(0)
157
+
158
+ if self.use_rpe_bias:
159
+ relative_position_bias = self.relative_position_bias_table[self.rpe_relative_position_index.view(-1)].view(
160
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
161
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
162
+ attn = attn + relative_position_bias.unsqueeze(0)
163
+
164
+ if mask is not None:
165
+ nW = mask.shape[0]
166
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
167
+ attn = attn.view(-1, self.num_heads, N, N)
168
+ attn = self.softmax(attn)
169
+ else:
170
+ attn = self.softmax(attn)
171
+
172
+ attn = self.attn_drop(attn)
173
+
174
+ x = (attn @ v)
175
+
176
+ if self.use_lepe:
177
+ x = x + lepe
178
+
179
+ x = x.transpose(1, 2).reshape(B_, N, C)
180
+ x = self.proj(x)
181
+ x = self.proj_drop(x)
182
+ return x
183
+
184
+ def lepe_pos(self, v):
185
+ B, NH, HW, NW = v.shape
186
+ C = NH * NW
187
+ H = W = int(math.sqrt(HW))
188
+ v = v.transpose(-2, -1).contiguous().view(B, C, H, W)
189
+ lepe = self.get_v(v)
190
+ lepe = lepe.reshape(-1, self.num_heads, NW, HW)
191
+ lepe = lepe.permute(0, 1, 3, 2).contiguous()
192
+ return lepe
193
+
194
+ def extra_repr(self) -> str:
195
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
196
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
197
+
198
+ def flops(self, N):
199
+ # calculate flops for 1 window with token length of N
200
+ flops = 0
201
+ # qkv = self.qkv(x)
202
+ flops += N * self.dim * 3 * self.dim
203
+ # attn = (q @ k.transpose(-2, -1))
204
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
205
+ # x = (attn @ v)
206
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
207
+ # x = self.proj(x)
208
+ flops += N * self.dim * self.dim
209
+ return flops
210
+
211
+
212
+ class SwinTransformerBlock(nn.Module):
213
+ r""" Swin Transformer Block.
214
+ Args:
215
+ dim (int): Number of input channels.
216
+ input_resolution (tuple[int]): Input resulotion.
217
+ num_heads (int): Number of attention heads.
218
+ window_size (int): Window size.
219
+ shift_size (int): Shift size for SW-MSA.
220
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
221
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
222
+ drop (float, optional): Dropout rate. Default: 0.0
223
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
224
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
225
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
226
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
227
+ pretrained_window_size (int): Window size in pre-training.
228
+ """
229
+
230
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
231
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
232
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0,
233
+ use_lepe=False,
234
+ use_cpb_bias=True,
235
+ MoE_config=None,
236
+ use_rpe_bias=False):
237
+ super().__init__()
238
+ self.dim = dim
239
+ self.input_resolution = input_resolution
240
+ self.num_heads = num_heads
241
+ self.window_size = window_size
242
+ self.shift_size = shift_size
243
+ self.mlp_ratio = mlp_ratio
244
+ if min(self.input_resolution) <= self.window_size:
245
+ # if window size is larger than input resolution, we don't partition windows
246
+ self.shift_size = 0
247
+ self.window_size = min(self.input_resolution)
248
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
249
+
250
+ self.norm1 = norm_layer(dim)
251
+ self.attn = WindowAttention(
252
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
253
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
254
+ pretrained_window_size=to_2tuple(pretrained_window_size),
255
+ use_lepe=use_lepe,
256
+ use_cpb_bias=use_cpb_bias,
257
+ use_rpe_bias=use_rpe_bias)
258
+
259
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
260
+ self.norm2 = norm_layer(dim)
261
+ mlp_hidden_dim = int(dim * mlp_ratio)
262
+
263
+ if MoE_config is None:
264
+ print('-->>> MLP')
265
+ self.mlp = Mlp(
266
+ in_features=dim, hidden_features=mlp_hidden_dim,
267
+ act_layer=act_layer, drop=drop)
268
+ else:
269
+ print('-->>> MOE')
270
+ print(MoE_config)
271
+ self.mlp = MoE(
272
+ input_size=dim, output_size=dim, hidden_size=mlp_hidden_dim,
273
+ **MoE_config)
274
+
275
+ if self.shift_size > 0:
276
+ attn_mask = self.calculate_mask(self.input_resolution)
277
+ else:
278
+ attn_mask = None
279
+
280
+ self.register_buffer("attn_mask", attn_mask)
281
+
282
+ def calculate_mask(self, x_size):
283
+ # calculate attention mask for SW-MSA
284
+ H, W = x_size
285
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
286
+ h_slices = (slice(0, -self.window_size),
287
+ slice(-self.window_size, -self.shift_size),
288
+ slice(-self.shift_size, None))
289
+ w_slices = (slice(0, -self.window_size),
290
+ slice(-self.window_size, -self.shift_size),
291
+ slice(-self.shift_size, None))
292
+ cnt = 0
293
+ for h in h_slices:
294
+ for w in w_slices:
295
+ img_mask[:, h, w, :] = cnt
296
+ cnt += 1
297
+
298
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
299
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
300
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
301
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
302
+
303
+ return attn_mask
304
+
305
+ def forward(self, x, x_size):
306
+ H, W = x_size
307
+ B, L, C = x.shape
308
+
309
+ shortcut = x
310
+ x = x.view(B, H, W, C)
311
+
312
+ # cyclic shift
313
+ if self.shift_size > 0:
314
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
315
+ else:
316
+ shifted_x = x
317
+
318
+ # partition windows
319
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
320
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
321
+
322
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
323
+ if self.input_resolution == x_size:
324
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
325
+ else:
326
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
327
+
328
+ # merge windows
329
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
330
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
331
+
332
+ # reverse cyclic shift
333
+ if self.shift_size > 0:
334
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
335
+ else:
336
+ x = shifted_x
337
+ x = x.view(B, H * W, C)
338
+ x = shortcut + self.drop_path(self.norm1(x))
339
+
340
+ # FFN
341
+
342
+ loss_moe = None
343
+ res = self.mlp(x)
344
+ if not torch.is_tensor(res):
345
+ res, loss_moe = res
346
+
347
+ x = x + self.drop_path(self.norm2(res))
348
+
349
+ return x, loss_moe
350
+
351
+ def extra_repr(self) -> str:
352
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
353
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
354
+
355
+ def flops(self):
356
+ flops = 0
357
+ H, W = self.input_resolution
358
+ # norm1
359
+ flops += self.dim * H * W
360
+ # W-MSA/SW-MSA
361
+ nW = H * W / self.window_size / self.window_size
362
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
363
+ # mlp
364
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
365
+ # norm2
366
+ flops += self.dim * H * W
367
+ return flops
368
+
369
+
370
+ class PatchMerging(nn.Module):
371
+ r""" Patch Merging Layer.
372
+ Args:
373
+ input_resolution (tuple[int]): Resolution of input feature.
374
+ dim (int): Number of input channels.
375
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
376
+ """
377
+
378
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
379
+ super().__init__()
380
+ self.input_resolution = input_resolution
381
+ self.dim = dim
382
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
383
+ self.norm = norm_layer(2 * dim)
384
+
385
+ def forward(self, x):
386
+ """
387
+ x: B, H*W, C
388
+ """
389
+ H, W = self.input_resolution
390
+ B, L, C = x.shape
391
+ assert L == H * W, "input feature has wrong size"
392
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
393
+
394
+ x = x.view(B, H, W, C)
395
+
396
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
397
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
398
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
399
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
400
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
401
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
402
+
403
+ x = self.reduction(x)
404
+ x = self.norm(x)
405
+
406
+ return x
407
+
408
+ def extra_repr(self) -> str:
409
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
410
+
411
+ def flops(self):
412
+ H, W = self.input_resolution
413
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
414
+ flops += H * W * self.dim // 2
415
+ return flops
416
+
417
+
418
+ class BasicLayer(nn.Module):
419
+ """ A basic Swin Transformer layer for one stage.
420
+ Args:
421
+ dim (int): Number of input channels.
422
+ input_resolution (tuple[int]): Input resolution.
423
+ depth (int): Number of blocks.
424
+ num_heads (int): Number of attention heads.
425
+ window_size (int): Local window size.
426
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
427
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
428
+ drop (float, optional): Dropout rate. Default: 0.0
429
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
430
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
431
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
432
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
433
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
434
+ pretrained_window_size (int): Local window size in pre-training.
435
+ """
436
+
437
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
438
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
439
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
440
+ pretrained_window_size=0,
441
+ use_lepe=False,
442
+ use_cpb_bias=True,
443
+ MoE_config=None,
444
+ use_rpe_bias=False):
445
+
446
+ super().__init__()
447
+ self.dim = dim
448
+ self.input_resolution = input_resolution
449
+ self.depth = depth
450
+ self.use_checkpoint = use_checkpoint
451
+
452
+ # build blocks
453
+ self.blocks = nn.ModuleList([
454
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
455
+ num_heads=num_heads, window_size=window_size,
456
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
457
+ mlp_ratio=mlp_ratio,
458
+ qkv_bias=qkv_bias,
459
+ drop=drop, attn_drop=attn_drop,
460
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
461
+ norm_layer=norm_layer,
462
+ pretrained_window_size=pretrained_window_size,
463
+ use_lepe=use_lepe,
464
+ use_cpb_bias=use_cpb_bias,
465
+ MoE_config=MoE_config,
466
+ use_rpe_bias=use_rpe_bias)
467
+ for i in range(depth)])
468
+
469
+ # patch merging layer
470
+ if downsample is not None:
471
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
472
+ else:
473
+ self.downsample = None
474
+
475
+ def forward(self, x, x_size):
476
+ loss_moe_all = 0
477
+ for blk in self.blocks:
478
+ if self.use_checkpoint:
479
+ x = checkpoint.checkpoint(blk, x, x_size)
480
+ else:
481
+ x = blk(x, x_size)
482
+
483
+ if not torch.is_tensor(x):
484
+ x, loss_moe = x
485
+ loss_moe_all += loss_moe or 0
486
+
487
+ if self.downsample is not None:
488
+ x = self.downsample(x)
489
+ return x, loss_moe_all
490
+
491
+ def extra_repr(self) -> str:
492
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
493
+
494
+ def flops(self):
495
+ flops = 0
496
+ for blk in self.blocks:
497
+ flops += blk.flops()
498
+ if self.downsample is not None:
499
+ flops += self.downsample.flops()
500
+ return flops
501
+
502
+ def _init_respostnorm(self):
503
+ for blk in self.blocks:
504
+ nn.init.constant_(blk.norm1.bias, 0)
505
+ nn.init.constant_(blk.norm1.weight, 0)
506
+ nn.init.constant_(blk.norm2.bias, 0)
507
+ nn.init.constant_(blk.norm2.weight, 0)
508
+
509
+ class PatchEmbed(nn.Module):
510
+ r""" Image to Patch Embedding
511
+ Args:
512
+ img_size (int): Image size. Default: 224.
513
+ patch_size (int): Patch token size. Default: 4.
514
+ in_chans (int): Number of input image channels. Default: 3.
515
+ embed_dim (int): Number of linear projection output channels. Default: 96.
516
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
517
+ """
518
+
519
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
520
+ super().__init__()
521
+ img_size = to_2tuple(img_size)
522
+ patch_size = to_2tuple(patch_size)
523
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
524
+ self.img_size = img_size
525
+ self.patch_size = patch_size
526
+ self.patches_resolution = patches_resolution
527
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
528
+
529
+ self.in_chans = in_chans
530
+ self.embed_dim = embed_dim
531
+
532
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
533
+ if norm_layer is not None:
534
+ self.norm = norm_layer(embed_dim)
535
+ else:
536
+ self.norm = None
537
+
538
+ def forward(self, x):
539
+ B, C, H, W = x.shape
540
+ # FIXME look at relaxing size constraints
541
+ # assert H == self.img_size[0] and W == self.img_size[1],
542
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
543
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
544
+ if self.norm is not None:
545
+ x = self.norm(x)
546
+ return x
547
+
548
+ def flops(self):
549
+ Ho, Wo = self.patches_resolution
550
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
551
+ if self.norm is not None:
552
+ flops += Ho * Wo * self.embed_dim
553
+ return flops
554
+
555
+
556
+ class RSTB(nn.Module):
557
+ """Residual Swin Transformer Block (RSTB).
558
+
559
+ Args:
560
+ dim (int): Number of input channels.
561
+ input_resolution (tuple[int]): Input resolution.
562
+ depth (int): Number of blocks.
563
+ num_heads (int): Number of attention heads.
564
+ window_size (int): Local window size.
565
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
566
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
567
+ drop (float, optional): Dropout rate. Default: 0.0
568
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
569
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
570
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
571
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
572
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
573
+ img_size: Input image size.
574
+ patch_size: Patch size.
575
+ resi_connection: The convolutional block before residual connection.
576
+ """
577
+
578
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
579
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
580
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
581
+ img_size=224, patch_size=4, resi_connection='1conv',
582
+ use_lepe=False,
583
+ use_cpb_bias=True,
584
+ MoE_config=None,
585
+ use_rpe_bias=False):
586
+ super(RSTB, self).__init__()
587
+
588
+ self.dim = dim
589
+ self.input_resolution = input_resolution
590
+
591
+ self.residual_group = BasicLayer(dim=dim,
592
+ input_resolution=input_resolution,
593
+ depth=depth,
594
+ num_heads=num_heads,
595
+ window_size=window_size,
596
+ mlp_ratio=mlp_ratio,
597
+ qkv_bias=qkv_bias,
598
+ drop=drop, attn_drop=attn_drop,
599
+ drop_path=drop_path,
600
+ norm_layer=norm_layer,
601
+ downsample=downsample,
602
+ use_checkpoint=use_checkpoint,
603
+ use_lepe=use_lepe,
604
+ use_cpb_bias=use_cpb_bias,
605
+ MoE_config=MoE_config,
606
+ use_rpe_bias=use_rpe_bias
607
+ )
608
+
609
+ if resi_connection == '1conv':
610
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
611
+ elif resi_connection == '3conv':
612
+ # to save parameters and memory
613
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
614
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
615
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
616
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
617
+
618
+ self.patch_embed = PatchEmbed(
619
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
620
+ norm_layer=None)
621
+
622
+ self.patch_unembed = PatchUnEmbed(
623
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
624
+ norm_layer=None)
625
+
626
+ def forward(self, x, x_size):
627
+ loss_moe = None
628
+ res = self.residual_group(x, x_size)
629
+
630
+ if not torch.is_tensor(res):
631
+ res, loss_moe = res
632
+
633
+ res = self.patch_embed(self.conv(self.patch_unembed(res, x_size)))
634
+ return res + x, loss_moe
635
+
636
+ def flops(self):
637
+ flops = 0
638
+ flops += self.residual_group.flops()
639
+ H, W = self.input_resolution
640
+ flops += H * W * self.dim * self.dim * 9
641
+ flops += self.patch_embed.flops()
642
+ flops += self.patch_unembed.flops()
643
+
644
+ return flops
645
+
646
+
647
+ class PatchUnEmbed(nn.Module):
648
+ r""" Image to Patch Unembedding
649
+
650
+ Args:
651
+ img_size (int): Image size. Default: 224.
652
+ patch_size (int): Patch token size. Default: 4.
653
+ in_chans (int): Number of input image channels. Default: 3.
654
+ embed_dim (int): Number of linear projection output channels. Default: 96.
655
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
656
+ """
657
+
658
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
659
+ super().__init__()
660
+ img_size = to_2tuple(img_size)
661
+ patch_size = to_2tuple(patch_size)
662
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
663
+ self.img_size = img_size
664
+ self.patch_size = patch_size
665
+ self.patches_resolution = patches_resolution
666
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
667
+
668
+ self.in_chans = in_chans
669
+ self.embed_dim = embed_dim
670
+
671
+ def forward(self, x, x_size):
672
+ B, HW, C = x.shape
673
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
674
+ return x
675
+
676
+ def flops(self):
677
+ flops = 0
678
+ return flops
679
+
680
+
681
+ class Upsample(nn.Sequential):
682
+ """Upsample module.
683
+
684
+ Args:
685
+ scale (int): Scale factor. Supported scales: 2^n and 3.
686
+ num_feat (int): Channel number of intermediate features.
687
+ """
688
+
689
+ def __init__(self, scale, num_feat):
690
+ m = []
691
+ if (scale & (scale - 1)) == 0: # scale = 2^n
692
+ for _ in range(int(math.log(scale, 2))):
693
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
694
+ m.append(nn.PixelShuffle(2))
695
+ elif scale == 3:
696
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
697
+ m.append(nn.PixelShuffle(3))
698
+ else:
699
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
700
+ super(Upsample, self).__init__(*m)
701
+
702
+ class Upsample_hf(nn.Sequential):
703
+ """Upsample module.
704
+
705
+ Args:
706
+ scale (int): Scale factor. Supported scales: 2^n and 3.
707
+ num_feat (int): Channel number of intermediate features.
708
+ """
709
+
710
+ def __init__(self, scale, num_feat):
711
+ m = []
712
+ if (scale & (scale - 1)) == 0: # scale = 2^n
713
+ for _ in range(int(math.log(scale, 2))):
714
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
715
+ m.append(nn.PixelShuffle(2))
716
+ elif scale == 3:
717
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
718
+ m.append(nn.PixelShuffle(3))
719
+ else:
720
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
721
+ super(Upsample_hf, self).__init__(*m)
722
+
723
+
724
+ class UpsampleOneStep(nn.Sequential):
725
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
726
+ Used in lightweight SR to save parameters.
727
+
728
+ Args:
729
+ scale (int): Scale factor. Supported scales: 2^n and 3.
730
+ num_feat (int): Channel number of intermediate features.
731
+
732
+ """
733
+
734
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
735
+ self.num_feat = num_feat
736
+ self.input_resolution = input_resolution
737
+ m = []
738
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
739
+ m.append(nn.PixelShuffle(scale))
740
+ super(UpsampleOneStep, self).__init__(*m)
741
+
742
+ def flops(self):
743
+ H, W = self.input_resolution
744
+ flops = H * W * self.num_feat * 3 * 9
745
+ return flops
746
+
747
+
748
+
749
+ class Swin2SR(nn.Module):
750
+ r""" Swin2SR
751
+ A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
752
+
753
+ Args:
754
+ img_size (int | tuple(int)): Input image size. Default 64
755
+ patch_size (int | tuple(int)): Patch size. Default: 1
756
+ in_chans (int): Number of input image channels. Default: 3
757
+ embed_dim (int): Patch embedding dimension. Default: 96
758
+ depths (tuple(int)): Depth of each Swin Transformer layer.
759
+ num_heads (tuple(int)): Number of attention heads in different layers.
760
+ window_size (int): Window size. Default: 7
761
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
762
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
763
+ drop_rate (float): Dropout rate. Default: 0
764
+ attn_drop_rate (float): Attention dropout rate. Default: 0
765
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
766
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
767
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
768
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
769
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
770
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
771
+ img_range: Image range. 1. or 255.
772
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
773
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
774
+ """
775
+
776
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
777
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
778
+ window_size=7, mlp_ratio=4., qkv_bias=True,
779
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
780
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
781
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
782
+ use_lepe=False,
783
+ use_cpb_bias=True,
784
+ MoE_config=None,
785
+ use_rpe_bias=False,
786
+ **kwargs):
787
+ super(Swin2SR, self).__init__()
788
+ print('==== SWIN 2SR')
789
+ num_in_ch = in_chans
790
+ num_out_ch = in_chans
791
+ num_feat = 64
792
+ self.img_range = img_range
793
+ if in_chans == 3:
794
+ rgb_mean = (0.4488, 0.4371, 0.4040)
795
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
796
+ else:
797
+ self.mean = torch.zeros(1, 1, 1, 1)
798
+ self.upscale = upscale
799
+ self.upsampler = upsampler
800
+ self.window_size = window_size
801
+
802
+ #####################################################################################################
803
+ ################################### 1, shallow feature extraction ###################################
804
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
805
+
806
+ #####################################################################################################
807
+ ################################### 2, deep feature extraction ######################################
808
+ self.num_layers = len(depths)
809
+ self.embed_dim = embed_dim
810
+ self.ape = ape
811
+ self.patch_norm = patch_norm
812
+ self.num_features = embed_dim
813
+ self.mlp_ratio = mlp_ratio
814
+
815
+ # split image into non-overlapping patches
816
+ self.patch_embed = PatchEmbed(
817
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
818
+ norm_layer=norm_layer if self.patch_norm else None)
819
+ num_patches = self.patch_embed.num_patches
820
+ patches_resolution = self.patch_embed.patches_resolution
821
+ self.patches_resolution = patches_resolution
822
+
823
+ # merge non-overlapping patches into image
824
+ self.patch_unembed = PatchUnEmbed(
825
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
826
+ norm_layer=norm_layer if self.patch_norm else None)
827
+
828
+ # absolute position embedding
829
+ if self.ape:
830
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
831
+ trunc_normal_(self.absolute_pos_embed, std=.02)
832
+
833
+ self.pos_drop = nn.Dropout(p=drop_rate)
834
+
835
+ # stochastic depth
836
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
837
+
838
+ # build Residual Swin Transformer blocks (RSTB)
839
+ self.layers = nn.ModuleList()
840
+ for i_layer in range(self.num_layers):
841
+ layer = RSTB(dim=embed_dim,
842
+ input_resolution=(patches_resolution[0],
843
+ patches_resolution[1]),
844
+ depth=depths[i_layer],
845
+ num_heads=num_heads[i_layer],
846
+ window_size=window_size,
847
+ mlp_ratio=self.mlp_ratio,
848
+ qkv_bias=qkv_bias,
849
+ drop=drop_rate, attn_drop=attn_drop_rate,
850
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
851
+ norm_layer=norm_layer,
852
+ downsample=None,
853
+ use_checkpoint=use_checkpoint,
854
+ img_size=img_size,
855
+ patch_size=patch_size,
856
+ resi_connection=resi_connection,
857
+ use_lepe=use_lepe,
858
+ use_cpb_bias=use_cpb_bias,
859
+ MoE_config=MoE_config,
860
+ use_rpe_bias=use_rpe_bias,
861
+ )
862
+ self.layers.append(layer)
863
+
864
+ if self.upsampler == 'pixelshuffle_hf':
865
+ self.layers_hf = nn.ModuleList()
866
+ for i_layer in range(self.num_layers):
867
+ layer = RSTB(dim=embed_dim,
868
+ input_resolution=(patches_resolution[0],
869
+ patches_resolution[1]),
870
+ depth=depths[i_layer],
871
+ num_heads=num_heads[i_layer],
872
+ window_size=window_size,
873
+ mlp_ratio=self.mlp_ratio,
874
+ qkv_bias=qkv_bias,
875
+ drop=drop_rate, attn_drop=attn_drop_rate,
876
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
877
+ norm_layer=norm_layer,
878
+ downsample=None,
879
+ use_checkpoint=use_checkpoint,
880
+ img_size=img_size,
881
+ patch_size=patch_size,
882
+ resi_connection=resi_connection,
883
+ use_lepe=use_lepe,
884
+ use_cpb_bias=use_cpb_bias,
885
+ MoE_config=MoE_config,
886
+ use_rpe_bias=use_rpe_bias
887
+ )
888
+ self.layers_hf.append(layer)
889
+
890
+ self.norm = norm_layer(self.num_features)
891
+
892
+ # build the last conv layer in deep feature extraction
893
+ if resi_connection == '1conv':
894
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
895
+ elif resi_connection == '3conv':
896
+ # to save parameters and memory
897
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
898
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
899
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
900
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
901
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
902
+
903
+ #####################################################################################################
904
+ ################################ 3, high quality image reconstruction ################################
905
+ if self.upsampler == 'pixelshuffle':
906
+ # for classical SR
907
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
908
+ nn.LeakyReLU(inplace=True))
909
+ self.upsample = Upsample(upscale, num_feat)
910
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
911
+ elif self.upsampler == 'pixelshuffle_aux':
912
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
913
+ self.conv_before_upsample = nn.Sequential(
914
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
915
+ nn.LeakyReLU(inplace=True))
916
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
917
+ self.conv_after_aux = nn.Sequential(
918
+ nn.Conv2d(3, num_feat, 3, 1, 1),
919
+ nn.LeakyReLU(inplace=True))
920
+ self.upsample = Upsample(upscale, num_feat)
921
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
922
+
923
+ elif self.upsampler == 'pixelshuffle_hf':
924
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
925
+ nn.LeakyReLU(inplace=True))
926
+ self.upsample = Upsample(upscale, num_feat)
927
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
928
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
929
+ self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
930
+ nn.LeakyReLU(inplace=True))
931
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
932
+ self.conv_before_upsample_hf = nn.Sequential(
933
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
934
+ nn.LeakyReLU(inplace=True))
935
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
936
+
937
+ elif self.upsampler == 'pixelshuffledirect':
938
+ # for lightweight SR (to save parameters)
939
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
940
+ (patches_resolution[0], patches_resolution[1]))
941
+ elif self.upsampler == 'nearest+conv':
942
+ # for real-world SR (less artifacts)
943
+ assert self.upscale == 4, 'only support x4 now.'
944
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
945
+ nn.LeakyReLU(inplace=True))
946
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
947
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
948
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
949
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
950
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
951
+ else:
952
+ # for image denoising and JPEG compression artifact reduction
953
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
954
+
955
+ self.apply(self._init_weights)
956
+
957
+ def _init_weights(self, m):
958
+ if isinstance(m, nn.Linear):
959
+ trunc_normal_(m.weight, std=.02)
960
+ if isinstance(m, nn.Linear) and m.bias is not None:
961
+ nn.init.constant_(m.bias, 0)
962
+ elif isinstance(m, nn.LayerNorm):
963
+ nn.init.constant_(m.bias, 0)
964
+ nn.init.constant_(m.weight, 1.0)
965
+
966
+ @torch.jit.ignore
967
+ def no_weight_decay(self):
968
+ return {'absolute_pos_embed'}
969
+
970
+ @torch.jit.ignore
971
+ def no_weight_decay_keywords(self):
972
+ return {'relative_position_bias_table'}
973
+
974
+ def check_image_size(self, x):
975
+ _, _, h, w = x.size()
976
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
977
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
978
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
979
+ return x
980
+
981
+ def forward_features(self, x):
982
+ x_size = (x.shape[2], x.shape[3])
983
+ x = self.patch_embed(x)
984
+ if self.ape:
985
+ x = x + self.absolute_pos_embed
986
+ x = self.pos_drop(x)
987
+
988
+ loss_moe_all = 0
989
+ for layer in self.layers:
990
+ x = layer(x, x_size)
991
+
992
+ if not torch.is_tensor(x):
993
+ x, loss_moe = x
994
+ loss_moe_all += loss_moe or 0
995
+
996
+ x = self.norm(x) # B L C
997
+ x = self.patch_unembed(x, x_size)
998
+
999
+ return x, loss_moe_all
1000
+
1001
+ def forward_features_hf(self, x):
1002
+ x_size = (x.shape[2], x.shape[3])
1003
+ x = self.patch_embed(x)
1004
+ if self.ape:
1005
+ x = x + self.absolute_pos_embed
1006
+ x = self.pos_drop(x)
1007
+
1008
+ loss_moe_all = 0
1009
+ for layer in self.layers_hf:
1010
+ x = layer(x, x_size)
1011
+
1012
+ if not torch.is_tensor(x):
1013
+ x, loss_moe = x
1014
+ loss_moe_all += loss_moe or 0
1015
+
1016
+ x = self.norm(x) # B L C
1017
+ x = self.patch_unembed(x, x_size)
1018
+
1019
+ return x, loss_moe_all
1020
+
1021
+ def forward_backbone(self, x):
1022
+ H, W = x.shape[2:]
1023
+ x = self.check_image_size(x)
1024
+
1025
+ self.mean = self.mean.type_as(x)
1026
+ x = (x - self.mean) * self.img_range
1027
+
1028
+ if self.upsampler == 'pixelshuffledirect':
1029
+ # for lightweight SR
1030
+ x = self.conv_first(x)
1031
+
1032
+ res = self.forward_features(x)
1033
+ if not torch.is_tensor(res):
1034
+ res, loss_moe = res
1035
+
1036
+ x = self.conv_after_body(res) + x
1037
+ else:
1038
+ raise Exception('not implemented yet')
1039
+
1040
+ x = x / self.img_range + self.mean
1041
+ return x
1042
+
1043
+ def forward(self, x):
1044
+ H, W = x.shape[2:]
1045
+ x = self.check_image_size(x)
1046
+
1047
+ self.mean = self.mean.type_as(x)
1048
+ x = (x - self.mean) * self.img_range
1049
+
1050
+ loss_moe = 0
1051
+ if self.upsampler == 'pixelshuffle':
1052
+ # for classical SR
1053
+ x = self.conv_first(x)
1054
+
1055
+ res = self.forward_features(x)
1056
+ if not torch.is_tensor(res):
1057
+ res, loss_moe = res
1058
+
1059
+ x = self.conv_after_body(res) + x
1060
+ x = self.conv_before_upsample(x)
1061
+ x = self.conv_last(self.upsample(x))
1062
+ elif self.upsampler == 'pixelshuffle_aux':
1063
+ bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
1064
+ bicubic = self.conv_bicubic(bicubic)
1065
+ x = self.conv_first(x)
1066
+
1067
+ res = self.forward_features(x)
1068
+ if not torch.is_tensor(res):
1069
+ res, loss_moe = res
1070
+
1071
+ x = self.conv_after_body(res) + x
1072
+ x = self.conv_before_upsample(x)
1073
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
1074
+ x = self.conv_after_aux(aux)
1075
+ x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
1076
+ x = self.conv_last(x)
1077
+ aux = aux / self.img_range + self.mean
1078
+ elif self.upsampler == 'pixelshuffle_hf':
1079
+ # for classical SR with HF
1080
+ x = self.conv_first(x)
1081
+
1082
+ res = self.forward_features(x)
1083
+ if not torch.is_tensor(res):
1084
+ res, loss_moe = res
1085
+
1086
+ x = self.conv_after_body(res) + x
1087
+ x_before = self.conv_before_upsample(x)
1088
+ x_out = self.conv_last(self.upsample(x_before))
1089
+
1090
+ x_hf = self.conv_first_hf(x_before)
1091
+
1092
+ res_hf = self.forward_features_hf(x_hf)
1093
+ if not torch.is_tensor(res_hf):
1094
+ res_hf, loss_moe_hf = res_hf
1095
+ loss_moe += loss_moe_hf
1096
+
1097
+ x_hf = self.conv_after_body_hf(res_hf) + x_hf
1098
+ x_hf = self.conv_before_upsample_hf(x_hf)
1099
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
1100
+ x = x_out + x_hf
1101
+ x_hf = x_hf / self.img_range + self.mean
1102
+
1103
+ elif self.upsampler == 'pixelshuffledirect':
1104
+ # for lightweight SR
1105
+ x = self.conv_first(x)
1106
+
1107
+ res = self.forward_features(x)
1108
+ if not torch.is_tensor(res):
1109
+ res, loss_moe = res
1110
+
1111
+ x = self.conv_after_body(res) + x
1112
+ x = self.upsample(x)
1113
+ elif self.upsampler == 'nearest+conv':
1114
+ # for real-world SR
1115
+ x = self.conv_first(x)
1116
+
1117
+ res = self.forward_features(x)
1118
+ if not torch.is_tensor(res):
1119
+ res, loss_moe = res
1120
+
1121
+ x = self.conv_after_body(res) + x
1122
+ x = self.conv_before_upsample(x)
1123
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
1124
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
1125
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
1126
+ else:
1127
+ # for image denoising and JPEG compression artifact reduction
1128
+ x_first = self.conv_first(x)
1129
+
1130
+ res = self.forward_features(x_first)
1131
+ if not torch.is_tensor(res):
1132
+ res, loss_moe = res
1133
+
1134
+ res = self.conv_after_body(res) + x_first
1135
+ x = x + self.conv_last(res)
1136
+
1137
+ x = x / self.img_range + self.mean
1138
+ if self.upsampler == "pixelshuffle_aux":
1139
+ return x[:, :, :H*self.upscale, :W*self.upscale], aux, loss_moe
1140
+
1141
+ elif self.upsampler == "pixelshuffle_hf":
1142
+ x_out = x_out / self.img_range + self.mean
1143
+ return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale], loss_moe
1144
+
1145
+ else:
1146
+ return x[:, :, :H*self.upscale, :W*self.upscale], loss_moe
1147
+
1148
+ def flops(self):
1149
+ flops = 0
1150
+ H, W = self.patches_resolution
1151
+ flops += H * W * 3 * self.embed_dim * 9
1152
+ flops += self.patch_embed.flops()
1153
+ for i, layer in enumerate(self.layers):
1154
+ flops += layer.flops()
1155
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
1156
+ flops += self.upsample.flops()
1157
+ return flops
swin2_mose/moe.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Source code: https://github.com/davidmrau/mixture-of-experts
3
+ #
4
+
5
+ # Sparsely-Gated Mixture-of-Experts Layers.
6
+ # See "Outrageously Large Neural Networks"
7
+ # https://arxiv.org/abs/1701.06538
8
+ #
9
+ # Author: David Rau
10
+ #
11
+ # The code is based on the TensorFlow implementation:
12
+ # https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py
13
+
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.distributions.normal import Normal
18
+ from copy import deepcopy
19
+ import numpy as np
20
+
21
+ from utils import Mlp as MLP
22
+
23
+ class SparseDispatcher(object):
24
+ """Helper for implementing a mixture of experts.
25
+ The purpose of this class is to create input minibatches for the
26
+ experts and to combine the results of the experts to form a unified
27
+ output tensor.
28
+ There are two functions:
29
+ dispatch - take an input Tensor and create input Tensors for each expert.
30
+ combine - take output Tensors from each expert and form a combined output
31
+ Tensor. Outputs from different experts for the same batch element are
32
+ summed together, weighted by the provided "gates".
33
+ The class is initialized with a "gates" Tensor, which specifies which
34
+ batch elements go to which experts, and the weights to use when combining
35
+ the outputs. Batch element b is sent to expert e iff gates[b, e] != 0.
36
+ The inputs and outputs are all two-dimensional [batch, depth].
37
+ Caller is responsible for collapsing additional dimensions prior to
38
+ calling this class and reshaping the output to the original shape.
39
+ See common_layers.reshape_like().
40
+ Example use:
41
+ gates: a float32 `Tensor` with shape `[batch_size, num_experts]`
42
+ inputs: a float32 `Tensor` with shape `[batch_size, input_size]`
43
+ experts: a list of length `num_experts` containing sub-networks.
44
+ dispatcher = SparseDispatcher(num_experts, gates)
45
+ expert_inputs = dispatcher.dispatch(inputs)
46
+ expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
47
+ outputs = dispatcher.combine(expert_outputs)
48
+ The preceding code sets the output for a particular example b to:
49
+ output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))
50
+ This class takes advantage of sparsity in the gate matrix by including in the
51
+ `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`.
52
+ """
53
+
54
+ def __init__(self, num_experts, gates):
55
+ """Create a SparseDispatcher."""
56
+
57
+ self._gates = gates
58
+ self._num_experts = num_experts
59
+ # sort experts
60
+ sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
61
+ # drop indices
62
+ _, self._expert_index = sorted_experts.split(1, dim=1)
63
+ # get according batch index for each expert
64
+ self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
65
+ # calculate num samples that each expert gets
66
+ self._part_sizes = (gates > 0).sum(0).tolist()
67
+ # expand gates to match with self._batch_index
68
+ gates_exp = gates[self._batch_index.flatten()]
69
+ self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)
70
+
71
+ def dispatch(self, inp):
72
+ """Create one input Tensor for each expert.
73
+ The `Tensor` for a expert `i` contains the slices of `inp` corresponding
74
+ to the batch elements `b` where `gates[b, i] > 0`.
75
+ Args:
76
+ inp: a `Tensor` of shape "[batch_size, <extra_input_dims>]`
77
+ Returns:
78
+ a list of `num_experts` `Tensor`s with shapes
79
+ `[expert_batch_size_i, <extra_input_dims>]`.
80
+ """
81
+
82
+ # assigns samples to experts whose gate is nonzero
83
+
84
+ # expand according to batch index so we can just split by _part_sizes
85
+ inp_exp = inp[self._batch_index].squeeze(1)
86
+ return torch.split(inp_exp, self._part_sizes, dim=0)
87
+
88
+ def combine(self, expert_out, multiply_by_gates=True, cnn_combine=None):
89
+ """Sum together the expert output, weighted by the gates.
90
+ The slice corresponding to a particular batch element `b` is computed
91
+ as the sum over all experts `i` of the expert output, weighted by the
92
+ corresponding gate values. If `multiply_by_gates` is set to False, the
93
+ gate values are ignored.
94
+ Args:
95
+ expert_out: a list of `num_experts` `Tensor`s, each with shape
96
+ `[expert_batch_size_i, <extra_output_dims>]`.
97
+ multiply_by_gates: a boolean
98
+ Returns:
99
+ a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
100
+ """
101
+ # apply exp to expert outputs, so we are not longer in log space
102
+ stitched = torch.cat(expert_out, 0)
103
+
104
+ if multiply_by_gates:
105
+ stitched = stitched.mul(self._nonzero_gates.unsqueeze(1))
106
+ zeros = torch.zeros((self._gates.size(0),) + expert_out[-1].shape[1:],
107
+ requires_grad=True, device=stitched.device)
108
+ # combine samples that have been processed by the same k experts
109
+
110
+ if cnn_combine is not None:
111
+ return self.smartly_combine(stitched, cnn_combine)
112
+
113
+ combined = zeros.index_add(0, self._batch_index, stitched.float())
114
+ return combined
115
+
116
+ def smartly_combine(self, stitched, cnn_combine):
117
+ idxes = []
118
+ for i in self._batch_index.unique():
119
+ idx = (self._batch_index == i).nonzero().squeeze(1)
120
+ idxes.append(idx)
121
+ idxes = torch.stack(idxes)
122
+ return cnn_combine(stitched[idxes]).squeeze(1)
123
+
124
+ def expert_to_gates(self):
125
+ """Gate values corresponding to the examples in the per-expert `Tensor`s.
126
+ Returns:
127
+ a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32`
128
+ and shapes `[expert_batch_size_i]`
129
+ """
130
+ # split nonzero gates for each expert
131
+ return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
132
+
133
+
134
+ def build_experts(experts_cfg, default_cfg, num_experts):
135
+ experts_cfg = deepcopy(experts_cfg)
136
+ if experts_cfg is None:
137
+ # old build way
138
+ return nn.ModuleList([
139
+ MLP(*default_cfg)
140
+ for i in range(num_experts)])
141
+ # new build way: mix mlp with leff
142
+ experts = []
143
+ for e_cfg in experts_cfg:
144
+ type_ = e_cfg.pop('type')
145
+ if type_ == 'mlp':
146
+ experts.append(MLP(*default_cfg))
147
+ return nn.ModuleList(experts)
148
+
149
+
150
+ class MoE(nn.Module):
151
+ """Call a Sparsely gated mixture of experts layer with 1-layer
152
+ Feed-Forward networks as experts.
153
+
154
+ Args:
155
+ input_size: integer - size of the input
156
+ output_size: integer - size of the input
157
+ num_experts: an integer - number of experts
158
+ hidden_size: an integer - hidden size of the experts
159
+ noisy_gating: a boolean
160
+ k: an integer - how many experts to use for each batch element
161
+ """
162
+
163
+ def __init__(self, input_size, output_size, num_experts, hidden_size,
164
+ experts=None, noisy_gating=True, k=4,
165
+ x_gating=None, with_noise=True, with_smart_merger=None):
166
+ super(MoE, self).__init__()
167
+ self.noisy_gating = noisy_gating
168
+ self.num_experts = num_experts
169
+ self.output_size = output_size
170
+ self.input_size = input_size
171
+ self.hidden_size = hidden_size
172
+ self.k = k
173
+ self.with_noise = with_noise
174
+ # instantiate experts
175
+ self.experts = build_experts(
176
+ experts,
177
+ (self.input_size, self.hidden_size, self.output_size),
178
+ num_experts)
179
+ self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
180
+ self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
181
+
182
+ self.x_gating = x_gating
183
+ if self.x_gating == 'conv1d':
184
+ self.x_gate = nn.Conv1d(4096, 1, kernel_size=3, padding=1)
185
+
186
+ self.softplus = nn.Softplus()
187
+ self.softmax = nn.Softmax(1)
188
+ self.register_buffer("mean", torch.tensor([0.0]))
189
+ self.register_buffer("std", torch.tensor([1.0]))
190
+ assert(self.k <= self.num_experts)
191
+
192
+ self.cnn_combine = None
193
+ if with_smart_merger == 'v1':
194
+ print('with SMART MERGER')
195
+ self.cnn_combine = nn.Conv2d(self.k, 1, kernel_size=3, padding=1)
196
+
197
+ def cv_squared(self, x):
198
+ """The squared coefficient of variation of a sample.
199
+ Useful as a loss to encourage a positive distribution to be more uniform.
200
+ Epsilons added for numerical stability.
201
+ Returns 0 for an empty Tensor.
202
+ Args:
203
+ x: a `Tensor`.
204
+ Returns:
205
+ a `Scalar`.
206
+ """
207
+ eps = 1e-10
208
+ # if only num_experts = 1
209
+
210
+ if x.shape[0] == 1:
211
+ return torch.tensor([0], device=x.device, dtype=x.dtype)
212
+ return x.float().var() / (x.float().mean()**2 + eps)
213
+
214
+ def _gates_to_load(self, gates):
215
+ """Compute the true load per expert, given the gates.
216
+ The load is the number of examples for which the corresponding gate is >0.
217
+ Args:
218
+ gates: a `Tensor` of shape [batch_size, n]
219
+ Returns:
220
+ a float32 `Tensor` of shape [n]
221
+ """
222
+ return (gates > 0).sum(0)
223
+
224
+ def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
225
+ """Helper function to NoisyTopKGating.
226
+ Computes the probability that value is in top k, given different random noise.
227
+ This gives us a way of backpropagating from a loss that balances the number
228
+ of times each expert is in the top k experts per example.
229
+ In the case of no noise, pass in None for noise_stddev, and the result will
230
+ not be differentiable.
231
+ Args:
232
+ clean_values: a `Tensor` of shape [batch, n].
233
+ noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
234
+ normally distributed noise with standard deviation noise_stddev.
235
+ noise_stddev: a `Tensor` of shape [batch, n], or None
236
+ noisy_top_values: a `Tensor` of shape [batch, m].
237
+ "values" Output of tf.top_k(noisy_top_values, m). m >= k+1
238
+ Returns:
239
+ a `Tensor` of shape [batch, n].
240
+ """
241
+ batch = clean_values.size(0)
242
+ m = noisy_top_values.size(1)
243
+ top_values_flat = noisy_top_values.flatten()
244
+
245
+ threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k
246
+ threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
247
+ is_in = torch.gt(noisy_values, threshold_if_in)
248
+ threshold_positions_if_out = threshold_positions_if_in - 1
249
+ threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
250
+ # is each value currently in the top k.
251
+ normal = Normal(self.mean, self.std)
252
+ prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev)
253
+ prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev)
254
+ prob = torch.where(is_in, prob_if_in, prob_if_out)
255
+ return prob
256
+
257
+ def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
258
+ """Noisy top-k gating.
259
+ See paper: https://arxiv.org/abs/1701.06538.
260
+ Args:
261
+ x: input Tensor with shape [batch_size, input_size]
262
+ train: a boolean - we only add noise at training time.
263
+ noise_epsilon: a float
264
+ Returns:
265
+ gates: a Tensor with shape [batch_size, num_experts]
266
+ load: a Tensor with shape [num_experts]
267
+ """
268
+ clean_logits = x @ self.w_gate
269
+ if self.noisy_gating and train:
270
+ raw_noise_stddev = x @ self.w_noise
271
+ noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon))
272
+ noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
273
+ logits = noisy_logits
274
+ else:
275
+ logits = clean_logits
276
+
277
+ # calculate topk + 1 that will be needed for the noisy gates
278
+ top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
279
+ top_k_logits = top_logits[:, :self.k]
280
+ top_k_indices = top_indices[:, :self.k]
281
+ top_k_gates = self.softmax(top_k_logits)
282
+
283
+ zeros = torch.zeros_like(logits, requires_grad=True)
284
+ gates = zeros.scatter(1, top_k_indices, top_k_gates)
285
+
286
+ if self.noisy_gating and self.k < self.num_experts and train:
287
+ load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0)
288
+ else:
289
+ load = self._gates_to_load(gates)
290
+ return gates, load
291
+
292
+ def forward(self, x, loss_coef=1e-2):
293
+ """Args:
294
+ x: tensor shape [batch_size, input_size]
295
+ train: a boolean scalar.
296
+ loss_coef: a scalar - multiplier on load-balancing losses
297
+
298
+ Returns:
299
+ y: a tensor with shape [batch_size, output_size].
300
+ extra_training_loss: a scalar. This should be added into the overall
301
+ training loss of the model. The backpropagation of this loss
302
+ encourages all experts to be approximately equally used across a batch.
303
+ """
304
+ if self.x_gating is not None:
305
+ xg = self.x_gate(x).squeeze(1)
306
+ else:
307
+ xg = x.mean(1)
308
+
309
+ gates, load = self.noisy_top_k_gating(
310
+ xg, self.training and self.with_noise)
311
+ # calculate importance loss
312
+ importance = gates.sum(0)
313
+ #
314
+ loss = self.cv_squared(importance) + self.cv_squared(load)
315
+ loss *= loss_coef
316
+
317
+ dispatcher = SparseDispatcher(self.num_experts, gates)
318
+ expert_inputs = dispatcher.dispatch(x)
319
+ gates = dispatcher.expert_to_gates()
320
+ expert_outputs = [self.experts[i](expert_inputs[i])
321
+ for i in range(self.num_experts)]
322
+ y = dispatcher.combine(expert_outputs, cnn_combine=self.cnn_combine)
323
+ return y, loss
swin2_mose/run.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model import Swin2SR
3
+
4
+ model_weights = "model-70.pt"
5
+ model_params = {
6
+ "upscale": 2,
7
+ "in_chans": 4,
8
+ "img_size": 64,
9
+ "window_size": 16,
10
+ "img_range": 1.,
11
+ "depths": [6, 6, 6, 6],
12
+ "embed_dim": 90,
13
+ "num_heads": [6, 6, 6, 6],
14
+ "mlp_ratio": 2,
15
+ "upsampler": "pixelshuffledirect",
16
+ "resi_connection": "1conv"
17
+ }
18
+
19
+ sr_model = Swin2SR(**model_params)
20
+ sr_model.load_state_dict(torch.load(model_weights))
swin2_mose/utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
+ def window_reverse(windows, window_size, H, W):
5
+ """
6
+ Args:
7
+ windows: (num_windows*B, window_size, window_size, C)
8
+ window_size (int): Window size
9
+ H (int): Height of image
10
+ W (int): Width of image
11
+
12
+ Returns:
13
+ x: (B, H, W, C)
14
+ """
15
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
16
+ x = windows.view(B, H // window_size, W // window_size, window_size,
17
+ window_size, -1)
18
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
19
+ return x
20
+
21
+
22
+ class Mlp(nn.Module):
23
+ def __init__(self, in_features, hidden_features=None, out_features=None,
24
+ act_layer=nn.GELU, drop=0.):
25
+ super().__init__()
26
+ out_features = out_features or in_features
27
+ hidden_features = hidden_features or in_features
28
+ self.fc1 = nn.Linear(in_features, hidden_features)
29
+ self.act = act_layer()
30
+ self.fc2 = nn.Linear(hidden_features, out_features)
31
+ self.drop = nn.Dropout(drop)
32
+
33
+ def forward(self, x):
34
+ x = self.fc1(x)
35
+ x = self.act(x)
36
+ x = self.drop(x)
37
+ x = self.fc2(x)
38
+ x = self.drop(x)
39
+ return x
40
+
41
+
42
+ def window_partition(x, window_size):
43
+ """
44
+ Args:
45
+ x: (B, H, W, C)
46
+ window_size (int): window size
47
+
48
+ Returns:
49
+ windows: (num_windows*B, window_size, window_size, C)
50
+ """
51
+ B, H, W, C = x.shape
52
+ x = x.view(B, H // window_size, window_size,
53
+ W // window_size, window_size, C)
54
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(
55
+ -1, window_size, window_size, C)
56
+ return windows
swin2_mose/weights/model-70.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9f1229521879af2c8162f7a32fe278e487d0bc0826dddccc87a4e22294aa067
3
+ size 118890958