zzzzzeee committed · Commit ae2a01b · verified · 1 Parent(s): 1c73d88

Create submodel.py

Files changed (1): submodel.py (+833, -0)
submodel.py ADDED

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, trunc_normal_

from functools import reduce, lru_cache
from operator import mul
from einops import rearrange

from model.submodules import ResidualBlock

class residual_feature_generator(nn.Module):
    def __init__(self, dim):
        super(residual_feature_generator, self).__init__()
        self.dim = dim
        self.resblock1 = ResidualBlock(dim, dim, 1, norm='BN')
        self.resblock2 = ResidualBlock(dim, dim, 1, norm='BN')
        self.resblock3 = ResidualBlock(dim, dim, 1, norm='BN')
        self.resblock4 = ResidualBlock(dim, dim, 1, norm='BN')

    def forward(self, x):
        out = self.resblock1(x)
        out = self.resblock2(out)
        out = self.resblock3(out)
        out = self.resblock4(out)
        return out


class feature_generator(nn.Module):
    def __init__(self, dim, kernel_size=3):
        super(feature_generator, self).__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.conv1 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=kernel_size,
                               stride=1, padding=(kernel_size - 1) // 2)
        self.conv2 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=kernel_size,
                               stride=1, padding=(kernel_size - 1) // 2)
        self.conv3 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=kernel_size,
                               stride=1, padding=(kernel_size - 1) // 2)
        self.conv4 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=kernel_size,
                               stride=1, padding=(kernel_size - 1) // 2)
        self.bn1 = nn.BatchNorm2d(dim)
        self.bn2 = nn.BatchNorm2d(dim)
        self.bn3 = nn.BatchNorm2d(dim)
        self.bn4 = nn.BatchNorm2d(dim)

    def forward(self, x):
        out = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.01, inplace=False)
        out = F.leaky_relu(self.bn2(self.conv2(out)), negative_slope=0.01, inplace=False)
        out = F.leaky_relu(self.bn3(self.conv3(out)), negative_slope=0.01, inplace=False)
        out = F.leaky_relu(self.bn4(self.conv4(out)), negative_slope=0.01, inplace=False)
        return out


class PatchEmbedLocalGlobal(nn.Module):
    def __init__(self, patch_size=(2, 4, 4), in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.num_blocks = self.in_chans // patch_size[0]

        self.head = nn.Conv2d(in_chans // self.num_blocks, embed_dim // 2, kernel_size=3, stride=1, padding=1)

        self.global_head = nn.Conv2d(in_chans, embed_dim // 2, kernel_size=3, stride=1, padding=1)

        self.residual_encoding = residual_feature_generator(embed_dim // 2)
        self.global_residual_encoding = residual_feature_generator(embed_dim // 2)

        self.proj = nn.Conv2d(embed_dim // 2, embed_dim // 2, kernel_size=3, stride=patch_size[1:], padding=1)
        self.global_proj = nn.Conv2d(embed_dim // 2, embed_dim // 2, kernel_size=3, stride=patch_size[1:], padding=1)

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

        # patches_resolution = [224 // patch_size[1], 224 // patch_size[2]]
        # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.num_blocks, patches_resolution[0], patches_resolution[1]))
        # trunc_normal_(self.absolute_pos_embed, std=.02)

    def forward(self, x):
        """Forward function."""
        # padding
        B, C, H, W = x.size()
        # if W % self.patch_size[2] != 0:
        #     x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        # if H % self.patch_size[1] != 0:
        #     x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        # if D % self.patch_size[0] != 0:
        #     x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
        xs = x.chunk(self.num_blocks, 1)
        outs = []
        outi_global = self.global_head(x)
        outi_global = self.global_residual_encoding(outi_global)
        outi_global = self.global_proj(outi_global)

        for i in range(self.num_blocks):
            outi_local = self.head(xs[i])
            outi_local = self.residual_encoding(outi_local)
            outi_local = self.proj(outi_local)
            outi = torch.cat([outi_local, outi_global], dim=1)
            outi = outi.unsqueeze(2)
            outs.append(outi)

        out = torch.cat(outs, dim=2)  # B, 96, 4, H, W

        # x = self.proj(x)  # B C D Wh Ww
        if self.norm is not None:
            D, Wh, Ww = out.size(2), out.size(3), out.size(4)
            out = out.flatten(2).transpose(1, 2)
            out = self.norm(out)
            out = out.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)

        return out


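# --- Illustrative shape check (editorial addition, not part of the original commit) ---
# A minimal sketch of how PatchEmbedLocalGlobal is expected to be driven, assuming the
# channel count is divisible into num_blocks = in_chans // patch_size[0] groups and that
# ResidualBlock from model.submodules preserves spatial size (as the code above assumes).
# The chosen numbers (in_chans=8, embed_dim=96, 64x64 input) are hypothetical.
def _demo_patch_embed_local_global():
    embed = PatchEmbedLocalGlobal(patch_size=(2, 4, 4), in_chans=8, embed_dim=96)
    x = torch.randn(1, 8, 64, 64)   # B, C, H, W
    out = embed(x)
    # Each of the 4 channel groups becomes one "temporal" slot; H and W are divided
    # by patch_size[1] and patch_size[2].
    assert out.shape == (1, 96, 4, 16, 16)
    return out.shape

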
class PatchEmbedConv(nn.Module):
    def __init__(self, patch_size=(2, 4, 4), in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.num_blocks = self.in_chans // patch_size[0]

        self.head = nn.Conv2d(in_chans // self.num_blocks, embed_dim, kernel_size=3, stride=1, padding=1)

        self.residual_encoding = residual_feature_generator(embed_dim)

        self.proj = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, stride=patch_size[1:], padding=1)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        B, C, H, W = x.size()
        # if W % self.patch_size[2] != 0:
        #     x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        # if H % self.patch_size[1] != 0:
        #     x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        # if D % self.patch_size[0] != 0:
        #     x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
        xs = x.chunk(self.num_blocks, 1)
        outs = []

        for i in range(self.num_blocks):
            outi = self.head(xs[i])
            outi = self.residual_encoding(outi)
            outi = self.proj(outi)
            outi = outi.unsqueeze(2)
            outs.append(outi)

        out = torch.cat(outs, dim=2)  # B, 96, 4, H, W

        # x = self.proj(x)  # B C D Wh Ww
        if self.norm is not None:
            D, Wh, Ww = out.size(2), out.size(3), out.size(4)
            out = out.flatten(2).transpose(1, 2)
            out = self.norm(out)
            out = out.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)

        return out


class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, D, H, W, C)
        window_size (tuple[int]): window size
    Returns:
        windows: (B*num_windows, window_size[0]*window_size[1]*window_size[2], C)
    """
    B, D, H, W, C = x.shape
    x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1],
               W // window_size[2], window_size[2], C)
    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
    return windows


def window_reverse(windows, window_size, B, D, H, W):
    """
    Args:
        windows: (B*num_windows, window_size[0]*window_size[1]*window_size[2], C)
        window_size (tuple[int]): Window size
        B (int): Batch size
        D (int): Temporal length
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, D, H, W, C)
    """
    x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2],
                     window_size[0], window_size[1], window_size[2], -1)
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)
    return x


def get_window_size(x_size, window_size, shift_size=None):
    # Clamp the window (and shift) to the actual feature size along each axis.
    use_window_size = list(window_size)
    if shift_size is not None:
        use_shift_size = list(shift_size)
    for i in range(len(x_size)):
        if x_size[i] <= window_size[i]:
            use_window_size[i] = x_size[i]
            if shift_size is not None:
                use_shift_size[i] = 0

    if shift_size is None:
        return tuple(use_window_size)
    else:
        return tuple(use_window_size), tuple(use_shift_size)


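# --- Illustrative round-trip check (editorial addition, not part of the original commit) ---
# A minimal sketch showing that window_partition followed by window_reverse recovers the
# input tensor when D, H and W are exact multiples of the window size. The tensor sizes
# below are hypothetical.
def _demo_window_roundtrip():
    B, D, H, W, C = 2, 4, 14, 14, 8
    window_size = (2, 7, 7)
    x = torch.randn(B, D, H, W, C)
    windows = window_partition(x, window_size)     # (B*nW, 2*7*7, C)
    assert windows.shape == (B * (D // 2) * (H // 7) * (W // 7), 2 * 7 * 7, C)
    x_back = window_reverse(windows, window_size, B, D, H, W)
    assert torch.equal(x, x_back)
    return windows.shape

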
class WindowAttention3D(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The temporal length, height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wd, Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1),
                        num_heads))  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_d = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))  # 3, Wd, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1

        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, N, N) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B_, nH, N, C

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index[:N, :N].reshape(-1)].reshape(N, N, -1)  # Wd*Wh*Ww, Wd*Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wd*Wh*Ww, Wd*Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)  # B_, nH, N, N

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        # print('attn: ', attn.shape, ', v: ', v.shape, ', x: ', x.shape)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


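# --- Illustrative shape check (editorial addition, not part of the original commit) ---
# A minimal sketch of WindowAttention3D on a batch of flattened windows: attention is
# computed per window, so the module maps (num_windows*B, N, C) to the same shape, where
# N is the number of tokens in one window. The sizes below are hypothetical.
def _demo_window_attention():
    attn = WindowAttention3D(dim=96, window_size=(2, 7, 7), num_heads=3)
    tokens = torch.randn(4, 2 * 7 * 7, 96)   # 4 windows, 98 tokens each, 96 channels
    out = attn(tokens)                        # no shift mask
    assert out.shape == tokens.shape
    return out.shape

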
class SwinTransformerBlock3D(nn.Module):
    """ Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Window size.
        shift_size (tuple[int]): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=(2, 7, 7), shift_size=(0, 0, 0),
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint

        assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must be in [0, window_size)"
        assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must be in [0, window_size)"
        assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention3D(
            dim, window_size=self.window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward_part1(self, x, mask_matrix):
        B, D, H, W, C = x.shape
        window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
        # print('window_size: ', window_size, ', shift_size: ', shift_size)
        x = self.norm1(x)
        # pad feature maps to multiples of window size
        pad_l = pad_t = pad_d0 = 0
        pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0]
        pad_b = (window_size[1] - H % window_size[1]) % window_size[1]
        pad_r = (window_size[2] - W % window_size[2]) % window_size[2]
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
        _, Dp, Hp, Wp, _ = x.shape
        # cyclic shift
        if any(i > 0 for i in shift_size):
            shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None
        # partition windows
        x_windows = window_partition(shifted_x, window_size)  # B*nW, Wd*Wh*Ww, C
        # print('shifted_x: ', shifted_x.shape, 'x_windows: ', x_windows.shape)
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # B*nW, Wd*Wh*Ww, C
        # merge windows
        attn_windows = attn_windows.view(-1, *(window_size + (C,)))
        # print('attn_windows: ', attn_windows.shape)
        shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp)  # B D' H' W' C
        # reverse cyclic shift
        if any(i > 0 for i in shift_size):
            x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
        else:
            x = shifted_x

        if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
            x = x[:, :D, :H, :W, :].contiguous()
        return x

    def forward_part2(self, x):
        return self.drop_path(self.mlp(self.norm2(x)))

    def forward(self, x, mask_matrix):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, D, H, W, C).
            mask_matrix: Attention mask for cyclic shift.
        """
        shortcut = x
        if self.use_checkpoint:
            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
        else:
            x = self.forward_part1(x, mask_matrix)
        x = shortcut + self.drop_path(x)

        if self.use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part2, x)
        else:
            x = x + self.forward_part2(x)

        return x


class PatchMerging(nn.Module):
    """ Patch Merging Layer
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, D, H, W, C).
        """
        B, D, H, W, C = x.shape

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, :, 0::2, 0::2, :]  # B D H/2 W/2 C
        x1 = x[:, :, 1::2, 0::2, :]  # B D H/2 W/2 C
        x2 = x[:, :, 0::2, 1::2, :]  # B D H/2 W/2 C
        x3 = x[:, :, 1::2, 1::2, :]  # B D H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B D H/2 W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x


# cache each stage results
@lru_cache()
def compute_mask(D, H, W, window_size, shift_size, device):
    img_mask = torch.zeros((1, D, H, W, 1), device=device)  # 1 Dp Hp Wp 1
    cnt = 0
    for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
        for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
            for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
                img_mask[:, d, h, w, :] = cnt
                cnt += 1
    mask_windows = window_partition(img_mask, window_size)  # nW, ws[0]*ws[1]*ws[2], 1
    mask_windows = mask_windows.squeeze(-1)  # nW, ws[0]*ws[1]*ws[2]
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask


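# --- Illustrative mask check (editorial addition, not part of the original commit) ---
# A minimal sketch of the SW-MSA attention mask: for a padded feature map of size
# (Dp, Hp, Wp) the mask has one row per window, and entries are 0 for token pairs in the
# same shifted region and -100 otherwise. The sizes below are hypothetical.
def _demo_compute_mask():
    window_size, shift_size = (1, 7, 7), (0, 3, 3)
    attn_mask = compute_mask(14, 21, 21, window_size, shift_size, torch.device('cpu'))
    # 14/1 * 21/7 * 21/7 = 126 windows, each with 1*7*7 = 49 tokens
    assert attn_mask.shape == (126, 49, 49)
    assert set(attn_mask.unique().tolist()) <= {0.0, -100.0}
    return attn_mask.shape

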
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of feature channels.
        depth (int): Depth of this stage.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Local window size. Default: (1,7,7).
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=(1, 7, 7),
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = tuple(i // 2 for i in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock3D(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                use_checkpoint=use_checkpoint,
            )
            for i in range(depth)])

        self.downsample = downsample
        if self.downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)

    def forward(self, x):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, C, D, H, W).
        """
        # calculate attention mask for SW-MSA
        B, C, D, H, W = x.shape
        window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
        x = rearrange(x, 'b c d h w -> b d h w c')
        Dp = int(np.ceil(D / window_size[0])) * window_size[0]
        Hp = int(np.ceil(H / window_size[1])) * window_size[1]
        Wp = int(np.ceil(W / window_size[2])) * window_size[2]
        attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device)
        for blk in self.blocks:
            x = blk(x, attn_mask)
            # print(x.shape)
        x = x.view(B, D, H, W, -1)

        if self.downsample is not None:
            x_out = self.downsample(x)
        else:
            x_out = x
        x_out = rearrange(x_out, 'b d h w c -> b c d h w')
        return x_out, x


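# --- Illustrative stage check (editorial addition, not part of the original commit) ---
# A minimal sketch of one BasicLayer stage: the input is (B, C, D, H, W); the first
# return value is the (optionally patch-merged) feature ready for the next stage, the
# second is the pre-downsampling feature in (B, D, H, W, C) layout. Sizes are hypothetical.
def _demo_basic_layer():
    stage = BasicLayer(dim=96, depth=2, num_heads=3, window_size=(2, 7, 7),
                       downsample=PatchMerging)
    x = torch.randn(1, 96, 2, 28, 28)
    x_out, x_pre = stage(x)
    assert x_out.shape == (1, 192, 2, 14, 14)   # channels doubled, H and W halved
    assert x_pre.shape == (1, 2, 28, 28, 96)
    return x_out.shape, x_pre.shape

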
class PatchEmbed3D(nn.Module):
    """ Video to Patch Embedding.
    Args:
        patch_size (int): Patch token size. Default: (2,4,4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=(2, 4, 4), in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # NOTE: the input is unsqueezed to a single channel in forward(), so the 3D
        # projection always sees one input channel regardless of in_chans.
        self.proj = nn.Conv3d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        x = x.unsqueeze(1)
        _, _, D, H, W = x.size()
        if W % self.patch_size[2] != 0:
            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        if H % self.patch_size[1] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        if D % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))

        x = self.proj(x)  # B C D Wh Ww
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)

        return x


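# --- Illustrative shape check (editorial addition, not part of the original commit) ---
# A minimal sketch of PatchEmbed3D: the (B, D, H, W) input is treated as a single-channel
# 3D volume and projected into non-overlapping (2, 4, 4) patches. Sizes are hypothetical.
def _demo_patch_embed_3d():
    embed = PatchEmbed3D(patch_size=(2, 4, 4), embed_dim=96)
    x = torch.randn(1, 4, 64, 64)              # B, D, H, W
    out = embed(x)
    assert out.shape == (1, 96, 2, 16, 16)     # B, embed_dim, D/2, H/4, W/4
    return out.shape

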
class SwinTransformer3D(nn.Module):
    """ Swin Transformer backbone.
    Args:
        patch_size (int | tuple(int)): Patch size. Default: (4,4,4).
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention heads of each stage.
        window_size (tuple[int]): Window size. Default: (2,7,7).
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer: Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
    """

    def __init__(self,
                 pretrained=None,
                 pretrained2d=True,
                 patch_size=(4, 4, 4),
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=(2, 7, 7),
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 patch_norm=False,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False,
                 new_version=0):
        super().__init__()

        self.pretrained = pretrained
        self.pretrained2d = pretrained2d
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.frozen_stages = frozen_stages
        self.window_size = window_size
        self.patch_size = patch_size
        self.out_indices = out_indices

        # split image into non-overlapping patches
        if new_version == 3:
            print("---- new version 3 ----")
            self.patch_embed = PatchEmbedConv(
                patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
                norm_layer=norm_layer if self.patch_norm else None)
        elif new_version == 4:
            print("---- new version 4 ----")
            self.patch_embed = PatchEmbedLocalGlobal(
                patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
                norm_layer=norm_layer if self.patch_norm else None)
        else:
            print("---- old version ----")
            self.patch_embed = PatchEmbed3D(
                patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
                norm_layer=norm_layer if self.patch_norm else None)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if i_layer < self.num_layers - 1 else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        # self.num_features = int(embed_dim * 2**(self.num_layers-1))
        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # self.norm = norm_layer(self.num_features)

        # add a norm layer for each output
        for i_layer in self.out_indices:
            layer = norm_layer(self.num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

    def inflate_weights(self, logger):
        """Inflate the swin2d parameters to swin3d.
        The differences between swin3d and swin2d mainly lie in an extra
        axis. To utilize the pretrained parameters of the 2d model,
        the weights of swin2d models should be inflated to fit the shapes of
        the 3d counterpart.
        Args:
            logger (logging.Logger): The logger used to print
                debugging information.
        """
        checkpoint = torch.load(self.pretrained, map_location='cpu')
        state_dict = checkpoint['model']

        # delete relative_position_index since we always re-init it
        relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
        for k in relative_position_index_keys:
            del state_dict[k]

        # delete attn_mask since we always re-init it
        attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
        for k in attn_mask_keys:
            del state_dict[k]

        state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).repeat(
            1, 1, self.patch_size[0], 1, 1) / self.patch_size[0]

        # bicubic interpolate relative_position_bias_table if not match
        relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
        for k in relative_position_bias_table_keys:
            relative_position_bias_table_pretrained = state_dict[k]
            relative_position_bias_table_current = self.state_dict()[k]
            L1, nH1 = relative_position_bias_table_pretrained.size()
            L2, nH2 = relative_position_bias_table_current.size()
            L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
            wd = self.window_size[0]
            if nH1 != nH2:
                logger.warning(f"Error in loading {k}, passing")
            else:
                if L1 != L2:
                    S1 = int(L1 ** 0.5)
                    relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                        relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                        size=(2 * self.window_size[1] - 1, 2 * self.window_size[2] - 1),
                        mode='bicubic')
                    relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view(
                        nH2, L2).permute(1, 0)
                state_dict[k] = relative_position_bias_table_pretrained.repeat(2 * wd - 1, 1)

        msg = self.load_state_dict(state_dict, strict=False)
        logger.info(msg)
        logger.info(f"=> loaded successfully '{self.pretrained}'")
        del checkpoint
        torch.cuda.empty_cache()

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if pretrained:
            self.pretrained = pretrained
        if isinstance(self.pretrained, str):
            self.apply(_init_weights)
            # NOTE: get_root_logger and load_checkpoint are not imported in this file;
            # in the original Video Swin code they come from the mmaction/mmcv utilities.
            logger = get_root_logger()
            logger.info(f'load model from: {self.pretrained}')

            if self.pretrained2d:
                # Inflate 2D model into 3D model.
                self.inflate_weights(logger)
            else:
                # Directly load 3D model.
                load_checkpoint(self, self.pretrained, strict=False, logger=logger)
        elif self.pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        x = self.patch_embed(x)
        # print(x.shape)
        x = self.pos_drop(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x, out_x = layer(x.contiguous())
            # print('---- ', out_x.shape)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out_x = norm_layer(out_x)
                _, Ti, Hi, Wi, Ci = out_x.shape
                out = rearrange(out_x, 'n d h w c -> n c d h w')
                outs.append(out)

        return tuple(outs)
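

# --- Illustrative smoke test (editorial addition, not part of the original commit) ---
# A minimal sketch of how the backbone is expected to be called, assuming the
# model.submodules import above resolves. The configuration below (two stages, a
# (1, 4, 64, 64) dummy input, out_indices=(0, 1)) is hypothetical and chosen only so
# the example runs quickly on CPU; it is not the training configuration.
if __name__ == "__main__":
    model = SwinTransformer3D(patch_size=(4, 4, 4),
                              embed_dim=96,
                              depths=[2, 2],
                              num_heads=[3, 6],
                              window_size=(2, 7, 7),
                              out_indices=(0, 1))
    model.init_weights(pretrained=None)
    model.eval()

    # With the default ("old version") PatchEmbed3D, the input is a (B, D, H, W)
    # volume that gets unsqueezed to a single channel inside the patch embedding.
    dummy = torch.randn(1, 4, 64, 64)
    with torch.no_grad():
        feats = model(dummy)
    for i, f in enumerate(feats):
        print(f"stage {i}: {tuple(f.shape)}")
    # Expected: stage 0 -> (1, 96, 1, 16, 16), stage 1 -> (1, 192, 1, 8, 8)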