zzzzzeee committed on
Commit c87565a · verified · 1 Parent(s): aef0448

Update submodules.py

Files changed (1):
  1. submodules.py +416 -770

submodules.py CHANGED
@@ -1,833 +1,479 @@
1
  import torch
2
  import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torch.utils.checkpoint as checkpoint
5
- import numpy as np
6
- from timm.models.layers import DropPath, trunc_normal_
7
-
8
- from functools import reduce, lru_cache
9
- from operator import mul
10
- from einops import rearrange
11
-
12
- from model.submodules import ResidualBlock
13
-
14
-
15
- class residual_feature_generator(nn.Module):
16
- def __init__(self, dim):
17
- super(residual_feature_generator, self).__init__()
18
- self.dim = dim
19
- self.resblock1 = ResidualBlock(dim, dim, 1, norm='BN')
20
- self.resblock2 = ResidualBlock(dim, dim, 1, norm='BN')
21
- self.resblock3 = ResidualBlock(dim, dim, 1, norm='BN')
22
- self.resblock4 = ResidualBlock(dim, dim, 1, norm='BN')
23
- def forward(self, x):
24
- out = self.resblock1(x)
25
- out = self.resblock2(out)
26
- out = self.resblock3(out)
27
- out = self.resblock4(out)
28
- return out
29
30
 
31
- class feature_generator(nn.Module):
32
- def __init__(self, dim, kernel_size=3):
33
- super(feature_generator, self).__init__()
34
- self.dim = dim
35
- self.kernel_size = kernel_size
36
- self.conv1 = nn.Conv2d(in_channels=dim,
37
- out_channels=dim,
38
- kernel_size=kernel_size,
39
- stride=1,
40
- padding=(kernel_size-1)//2)
41
- self.conv2 = nn.Conv2d(in_channels=dim,
42
- out_channels=dim,
43
- kernel_size=kernel_size,
44
- stride=1,
45
- padding=(kernel_size-1)//2)
46
- self.conv3 = nn.Conv2d(in_channels=dim,
47
- out_channels=dim,
48
- kernel_size=kernel_size,
49
- stride=1,
50
- padding=(kernel_size-1)//2)
51
- self.conv4 = nn.Conv2d(in_channels=dim,
52
- out_channels=dim,
53
- kernel_size=kernel_size,
54
- stride=1,
55
- padding=(kernel_size-1)//2)
56
- self.bn1 = nn.BatchNorm2d(dim)
57
- self.bn2 = nn.BatchNorm2d(dim)
58
- self.bn3 = nn.BatchNorm2d(dim)
59
- self.bn4 = nn.BatchNorm2d(dim)
60
  def forward(self, x):
61
- out = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.01, inplace=False)
62
- out = F.leaky_relu(self.bn2(self.conv2(out)), negative_slope=0.01, inplace=False)
63
- out = F.leaky_relu(self.bn3(self.conv3(out)), negative_slope=0.01, inplace=False)
64
- out = F.leaky_relu(self.bn4(self.conv4(out)), negative_slope=0.01, inplace=False)
65
  return out
66
 
67
 
68
- class PatchEmbedLocalGlobal(nn.Module):
69
- def __init__(self, patch_size=(2,4,4), in_chans=3, embed_dim=96, norm_layer=None):
70
- super().__init__()
71
- self.patch_size = patch_size
72
 
73
- self.in_chans = in_chans
74
- self.embed_dim = embed_dim
75
 
76
- self.num_blocks = self.in_chans // patch_size[0]
 
77
 
78
- self.head = nn.Conv2d(in_chans // self.num_blocks, embed_dim // 2, kernel_size=3, stride=1, padding=1)
79
 
80
- self.global_head = nn.Conv2d(in_chans, embed_dim // 2, kernel_size=3, stride=1, padding=1)
81
 
82
- self.residual_encoding = residual_feature_generator(embed_dim//2)
83
- self.global_residual_encoding = residual_feature_generator(embed_dim//2)
 
84
 
85
- self.proj = nn.Conv2d(embed_dim//2, embed_dim//2, kernel_size=3, stride=patch_size[1:], padding=1)
86
- self.global_proj = nn.Conv2d(embed_dim//2, embed_dim//2, kernel_size=3, stride=patch_size[1:], padding=1)
87
 
88
- if norm_layer is not None:
89
- self.norm = norm_layer(embed_dim)
90
  else:
91
- self.norm = None
92
 
93
- # patches_resolution = [224 // patch_size[1], 224 // patch_size[2]]
94
- # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.num_blocks, patches_resolution[0], patches_resolution[1]))
95
- # trunc_normal_(self.absolute_pos_embed, std=.02)
 
 
96
 
97
  def forward(self, x):
98
- """Forward function."""
99
- # padding
100
- B, C, H, W = x.size()
101
- # if W % self.patch_size[2] != 0:
102
- # x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
103
- # if H % self.patch_size[1] != 0:
104
- # x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
105
- # if D % self.patch_size[0] != 0:
106
- # x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
107
- xs = x.chunk(self.num_blocks, 1)
108
- outs = []
109
- outi_global = self.global_head(x)
110
- outi_global = self.global_residual_encoding(outi_global)
111
- outi_global = self.global_proj(outi_global)
112
-
113
- for i in range(self.num_blocks):
114
- outi_local = self.head(xs[i])
115
- outi_local = self.residual_encoding(outi_local)
116
- outi_local = self.proj(outi_local)
117
- outi = torch.cat([outi_local, outi_global], dim=1)
118
- outi = outi.unsqueeze(2)
119
- outs.append(outi)
120
-
121
- out = torch.cat(outs, dim=2) # B, 96, 4, H, W
122
-
123
- # x = self.proj(x) # B C D Wh Ww
124
- if self.norm is not None:
125
- D, Wh, Ww = out.size(2), out.size(3), out.size(4)
126
- out = out.flatten(2).transpose(1, 2)
127
- out = self.norm(out)
128
- out = out.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
129
 
130
  return out
131
 
132
 
133
- class PatchEmbedConv(nn.Module):
134
- def __init__(self, patch_size=(2,4,4), in_chans=3, embed_dim=96, norm_layer=None):
135
- super().__init__()
136
- self.patch_size = patch_size
137
 
138
- self.in_chans = in_chans
139
- self.embed_dim = embed_dim
140
 
141
- self.num_blocks = self.in_chans // patch_size[0]
142
 
143
- self.head = nn.Conv2d(in_chans // self.num_blocks, embed_dim, kernel_size=3, stride=1, padding=1)
144
 
145
- self.residual_encoding = residual_feature_generator(embed_dim)
 
 
146
 
147
- self.proj = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, stride=patch_size[1:], padding=1)
148
- if norm_layer is not None:
149
- self.norm = norm_layer(embed_dim)
150
  else:
151
- self.norm = None
152
 
153
  def forward(self, x):
154
- """Forward function."""
155
- # padding
156
- B, C, H, W = x.size()
157
- # if W % self.patch_size[2] != 0:
158
- # x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
159
- # if H % self.patch_size[1] != 0:
160
- # x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
161
- # if D % self.patch_size[0] != 0:
162
- # x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
163
- xs = x.chunk(self.num_blocks, 1)
164
- outs = []
165
-
166
- for i in range(self.num_blocks):
167
- outi = self.head(xs[i])
168
- outi = self.residual_encoding(outi)
169
- outi = self.proj(outi)
170
- outi = outi.unsqueeze(2)
171
- outs.append(outi)
172
-
173
- out = torch.cat(outs, dim=2) # B, 96, 4, H, W
174
-
175
- # x = self.proj(x) # B C D Wh Ww
176
- if self.norm is not None:
177
- D, Wh, Ww = out.size(2), out.size(3), out.size(4)
178
- out = out.flatten(2).transpose(1, 2)
179
- out = self.norm(out)
180
- out = out.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
181
-
182
  return out
183
 
184
 
185
- class Mlp(nn.Module):
186
- """ Multilayer perceptron."""
 
187
 
188
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
189
  super().__init__()
190
- out_features = out_features or in_features
191
- hidden_features = hidden_features or in_features
192
- self.fc1 = nn.Linear(in_features, hidden_features)
193
- self.act = act_layer()
194
- self.fc2 = nn.Linear(hidden_features, out_features)
195
- self.drop = nn.Dropout(drop)
196
 
197
- def forward(self, x):
198
- x = self.fc1(x)
199
- x = self.act(x)
200
- x = self.drop(x)
201
- x = self.fc2(x)
202
- x = self.drop(x)
203
- return x
204
205
 
206
- def window_partition(x, window_size):
207
- """
208
- Args:
209
- x: (B, D, H, W, C)
210
- window_size (tuple[int]): window size
211
- Returns:
212
- windows: (B*num_windows, window_size*window_size, C)
213
- """
214
- B, D, H, W, C = x.shape
215
- x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], C)
216
- windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C)
217
- return windows
218
219
 
220
- def window_reverse(windows, window_size, B, D, H, W):
221
- """
222
- Args:
223
- windows: (B*num_windows, window_size, window_size, C)
224
- window_size (tuple[int]): Window size
225
- H (int): Height of image
226
- W (int): Width of image
227
- Returns:
228
- x: (B, D, H, W, C)
229
- """
230
- x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1)
231
- x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)
232
- return x
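
For reference, the removed window_partition / window_reverse helpers above are exact inverses of each other; a shape round trip with illustrative sizes (not part of the commit, sizes assumed):

import torch
x = torch.randn(1, 4, 28, 28, 96)                 # B, D, H, W, C
ws = (2, 7, 7)                                    # window size (Wd, Wh, Ww)
wins = window_partition(x, ws)                    # -> (32, 98, 96): 1*2*4*4 windows of 2*7*7 tokens
x_back = window_reverse(wins, ws, 1, 4, 28, 28)   # -> (1, 4, 28, 28, 96), recovers x
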
233
-
234
-
235
- def get_window_size(x_size, window_size, shift_size=None):
236
- use_window_size = list(window_size)
237
- if shift_size is not None:
238
- use_shift_size = list(shift_size)
239
- for i in range(len(x_size)):
240
- if x_size[i] <= window_size[i]:
241
- use_window_size[i] = x_size[i]
242
- if shift_size is not None:
243
- use_shift_size[i] = 0
244
-
245
- if shift_size is None:
246
- return tuple(use_window_size)
247
- else:
248
- return tuple(use_window_size), tuple(use_shift_size)
249
-
250
-
251
- class WindowAttention3D(nn.Module):
252
- """ Window based multi-head self attention (W-MSA) module with relative position bias.
253
- It supports both of shifted and non-shifted window.
254
- Args:
255
- dim (int): Number of input channels.
256
- window_size (tuple[int]): The temporal length, height and width of the window.
257
- num_heads (int): Number of attention heads.
258
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
259
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
260
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
261
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
262
- """
263
 
264
- def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
 
 
265
 
266
- super().__init__()
267
- self.dim = dim
268
- self.window_size = window_size # Wd, Wh, Ww
269
- self.num_heads = num_heads
270
- head_dim = dim // num_heads
271
- self.scale = qk_scale or head_dim ** -0.5
272
-
273
- # define a parameter table of relative position bias
274
- self.relative_position_bias_table = nn.Parameter(
275
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
276
-
277
- # get pair-wise relative position index for each token inside the window
278
- coords_d = torch.arange(self.window_size[0])
279
- coords_h = torch.arange(self.window_size[1])
280
- coords_w = torch.arange(self.window_size[2])
281
- coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww
282
- coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww
283
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww
284
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3
285
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
286
- relative_coords[:, :, 1] += self.window_size[1] - 1
287
- relative_coords[:, :, 2] += self.window_size[2] - 1
288
-
289
- relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
290
- relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
291
- relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww
292
- self.register_buffer("relative_position_index", relative_position_index)
293
-
294
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
295
- self.attn_drop = nn.Dropout(attn_drop)
296
- self.proj = nn.Linear(dim, dim)
297
- self.proj_drop = nn.Dropout(proj_drop)
298
-
299
- trunc_normal_(self.relative_position_bias_table, std=.02)
300
- self.softmax = nn.Softmax(dim=-1)
301
-
302
- def forward(self, x, mask=None):
303
- """ Forward function.
304
- Args:
305
- x: input features with shape of (num_windows*B, N, C)
306
- mask: (0/-inf) mask with shape of (num_windows, N, N) or None
307
- """
308
- B_, N, C = x.shape
309
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
310
- q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C
311
-
312
- q = q * self.scale
313
- attn = q @ k.transpose(-2, -1)
314
-
315
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape(
316
- N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH
317
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww
318
- attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N
319
-
320
- if mask is not None:
321
- nW = mask.shape[0]
322
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
323
- attn = attn.view(-1, self.num_heads, N, N)
324
- attn = self.softmax(attn)
325
- else:
326
- attn = self.softmax(attn)
327
-
328
- attn = self.attn_drop(attn)
329
- # print('attn: ', attn.shape, ', v: ', v.shape, ', x: ', x.shape)
330
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
331
- x = self.proj(x)
332
- x = self.proj_drop(x)
333
- return x
334
-
335
-
336
- class SwinTransformerBlock3D(nn.Module):
337
- """ Swin Transformer Block.
338
- Args:
339
- dim (int): Number of input channels.
340
- num_heads (int): Number of attention heads.
341
- window_size (tuple[int]): Window size.
342
- shift_size (tuple[int]): Shift size for SW-MSA.
343
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
344
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
345
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
346
- drop (float, optional): Dropout rate. Default: 0.0
347
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
348
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
349
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
350
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
351
- """
352
 
353
- def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0),
354
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
355
- act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
356
- super().__init__()
357
- self.dim = dim
358
- self.num_heads = num_heads
359
- self.window_size = window_size
360
- self.shift_size = shift_size
361
- self.mlp_ratio = mlp_ratio
362
- self.use_checkpoint=use_checkpoint
363
-
364
- assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size"
365
- assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size"
366
- assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size"
367
-
368
- self.norm1 = norm_layer(dim)
369
- self.attn = WindowAttention3D(
370
- dim, window_size=self.window_size, num_heads=num_heads,
371
- qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
372
-
373
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
374
- self.norm2 = norm_layer(dim)
375
- mlp_hidden_dim = int(dim * mlp_ratio)
376
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
377
-
378
- def forward_part1(self, x, mask_matrix):
379
- B, D, H, W, C = x.shape
380
- window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
381
- # print('window_size: ', window_size, ', shift_size: ', shift_size)
382
- x = self.norm1(x)
383
- # pad feature maps to multiples of window size
384
- pad_l = pad_t = pad_d0 = 0
385
- pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0]
386
- pad_b = (window_size[1] - H % window_size[1]) % window_size[1]
387
- pad_r = (window_size[2] - W % window_size[2]) % window_size[2]
388
- x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
389
- _, Dp, Hp, Wp, _ = x.shape
390
- # cyclic shift
391
- if any(i > 0 for i in shift_size):
392
- shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
393
- attn_mask = mask_matrix
394
- else:
395
- shifted_x = x
396
- attn_mask = None
397
- # partition windows
398
- x_windows = window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C
399
- # print('shifted_x: ', shifted_x.shape, 'x_windows: ', x_windows.shape)
400
- # W-MSA/SW-MSA
401
- attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C
402
- # merge windows
403
- attn_windows = attn_windows.view(-1, *(window_size+(C,)))
404
- # print('attn_windows: ', attn_windows.shape)
405
- shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C
406
- # reverse cyclic shift
407
- if any(i > 0 for i in shift_size):
408
- x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
409
- else:
410
- x = shifted_x
411
 
412
- if pad_d1 >0 or pad_r > 0 or pad_b > 0:
413
- x = x[:, :D, :H, :W, :].contiguous()
414
- return x
 
 
415
 
416
- def forward_part2(self, x):
417
- return self.drop_path(self.mlp(self.norm2(x)))
418
 
419
- def forward(self, x, mask_matrix):
420
- """ Forward function.
421
- Args:
422
- x: Input feature, tensor size (B, D, H, W, C).
423
- mask_matrix: Attention mask for cyclic shift.
424
- """
425
 
426
- shortcut = x
427
- if self.use_checkpoint:
428
- x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
429
- else:
430
- x = self.forward_part1(x, mask_matrix)
431
- x = shortcut + self.drop_path(x)
432
 
433
- if self.use_checkpoint:
434
- x = x + checkpoint.checkpoint(self.forward_part2, x)
435
- else:
436
- x = x + self.forward_part2(x)
437
 
438
- return x
 
 
439
 
 
 
440
 
441
- class PatchMerging(nn.Module):
442
- """ Patch Merging Layer
443
- Args:
444
- dim (int): Number of input channels.
445
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
446
- """
447
- def __init__(self, dim, norm_layer=nn.LayerNorm):
448
- super().__init__()
449
- self.dim = dim
450
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
451
- self.norm = norm_layer(4 * dim)
452
 
453
- def forward(self, x):
454
- """ Forward function.
455
- Args:
456
- x: Input feature, tensor size (B, D, H, W, C).
457
- """
458
- B, D, H, W, C = x.shape
459
-
460
- # padding
461
- pad_input = (H % 2 == 1) or (W % 2 == 1)
462
- if pad_input:
463
- x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
464
-
465
- x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C
466
- x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C
467
- x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C
468
- x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C
469
- x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C
470
-
471
- x = self.norm(x)
472
- x = self.reduction(x)
473
-
474
- return x
475
-
476
-
477
- # cache each stage results
478
- @lru_cache()
479
- def compute_mask(D, H, W, window_size, shift_size, device):
480
- img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1
481
- cnt = 0
482
- for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None):
483
- for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None):
484
- for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None):
485
- img_mask[:, d, h, w, :] = cnt
486
- cnt += 1
487
- mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1
488
- mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2]
489
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
490
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
491
- return attn_mask
492
-
493
-
494
- class BasicLayer(nn.Module):
495
- """ A basic Swin Transformer layer for one stage.
496
- Args:
497
- dim (int): Number of feature channels
498
- depth (int): Depths of this stage.
499
- num_heads (int): Number of attention head.
500
- window_size (tuple[int]): Local window size. Default: (1,7,7).
501
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
502
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
503
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
504
- drop (float, optional): Dropout rate. Default: 0.0
505
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
506
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
507
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
508
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
509
- """
510
 
511
- def __init__(self,
512
- dim,
513
- depth,
514
- num_heads,
515
- window_size=(1,7,7),
516
- mlp_ratio=4.,
517
- qkv_bias=False,
518
- qk_scale=None,
519
- drop=0.,
520
- attn_drop=0.,
521
- drop_path=0.,
522
- norm_layer=nn.LayerNorm,
523
- downsample=None,
524
- use_checkpoint=False):
525
- super().__init__()
526
- self.window_size = window_size
527
- self.shift_size = tuple(i // 2 for i in window_size)
528
- self.depth = depth
529
- self.use_checkpoint = use_checkpoint
530
-
531
- # build blocks
532
- self.blocks = nn.ModuleList([
533
- SwinTransformerBlock3D(
534
- dim=dim,
535
- num_heads=num_heads,
536
- window_size=window_size,
537
- shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size,
538
- mlp_ratio=mlp_ratio,
539
- qkv_bias=qkv_bias,
540
- qk_scale=qk_scale,
541
- drop=drop,
542
- attn_drop=attn_drop,
543
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
544
- norm_layer=norm_layer,
545
- use_checkpoint=use_checkpoint,
546
- )
547
- for i in range(depth)])
548
-
549
- self.downsample = downsample
550
- if self.downsample is not None:
551
- self.downsample = downsample(dim=dim, norm_layer=norm_layer)
552
 
553
- def forward(self, x):
554
- """ Forward function.
555
- Args:
556
- x: Input feature, tensor size (B, C, D, H, W).
557
- """
558
- # calculate attention mask for SW-MSA
559
- B, C, D, H, W = x.shape
560
- window_size, shift_size = get_window_size((D,H,W), self.window_size, self.shift_size)
561
- x = rearrange(x, 'b c d h w -> b d h w c')
562
- Dp = int(np.ceil(D / window_size[0])) * window_size[0]
563
- Hp = int(np.ceil(H / window_size[1])) * window_size[1]
564
- Wp = int(np.ceil(W / window_size[2])) * window_size[2]
565
- attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device)
566
- for blk in self.blocks:
567
- x = blk(x, attn_mask)
568
- # print(x.shape)
569
- x = x.view(B, D, H, W, -1)
570
-
571
- if self.downsample is not None:
572
- x_out = self.downsample(x)
573
- else:
574
- x_out = x
575
- x_out = rearrange(x_out, 'b d h w c -> b c d h w')
576
- return x_out, x
577
-
578
-
579
- class PatchEmbed3D(nn.Module):
580
- """ Video to Patch Embedding.
581
- Args:
582
- patch_size (int): Patch token size. Default: (2,4,4).
583
- in_chans (int): Number of input video channels. Default: 3.
584
- embed_dim (int): Number of linear projection output channels. Default: 96.
585
- norm_layer (nn.Module, optional): Normalization layer. Default: None
586
- """
587
- def __init__(self, patch_size=(2,4,4), in_chans=3, embed_dim=96, norm_layer=None):
588
- super().__init__()
589
- self.patch_size = patch_size
590
 
591
- self.in_chans = in_chans
592
- self.embed_dim = embed_dim
 
593
 
594
- # self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
595
- self.proj = nn.Conv3d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
596
- if norm_layer is not None:
597
- self.norm = norm_layer(embed_dim)
598
  else:
599
- self.norm = None
600
 
601
- def forward(self, x):
602
- """Forward function."""
603
- # padding
604
- x = x.unsqueeze(1)
605
- _, _, D, H, W = x.size()
606
- if W % self.patch_size[2] != 0:
607
- x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
608
- if H % self.patch_size[1] != 0:
609
- x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
610
- if D % self.patch_size[0] != 0:
611
- x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
612
-
613
- x = self.proj(x) # B C D Wh Ww
614
- if self.norm is not None:
615
- D, Wh, Ww = x.size(2), x.size(3), x.size(4)
616
- x = x.flatten(2).transpose(1, 2)
617
- x = self.norm(x)
618
- x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
619
-
620
- return x
621
-
622
-
623
- class SwinTransformer3D(nn.Module):
624
- """ Swin Transformer backbone.
625
- Args:
626
- patch_size (int | tuple(int)): Patch size. Default: (4,4,4).
627
- in_chans (int): Number of input image channels. Default: 3.
628
- embed_dim (int): Number of linear projection output channels. Default: 96.
629
- depths (tuple[int]): Depths of each Swin Transformer stage.
630
- num_heads (tuple[int]): Number of attention head of each stage.
631
- window_size (int): Window size. Default: 7.
632
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
633
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee
634
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
635
- drop_rate (float): Dropout rate.
636
- attn_drop_rate (float): Attention dropout rate. Default: 0.
637
- drop_path_rate (float): Stochastic depth rate. Default: 0.2.
638
- norm_layer: Normalization layer. Default: nn.LayerNorm.
639
- patch_norm (bool): If True, add normalization after patch embedding. Default: False.
640
- frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
641
- -1 means not freezing any parameters.
642
  """
643
 
644
- def __init__(self,
645
- pretrained=None,
646
- pretrained2d=True,
647
- patch_size=(4,4,4),
648
- in_chans=3,
649
- embed_dim=96,
650
- depths=[2, 2, 6, 2],
651
- num_heads=[3, 6, 12, 24],
652
- window_size=(2,7,7),
653
- mlp_ratio=4.,
654
- qkv_bias=True,
655
- qk_scale=None,
656
- drop_rate=0.,
657
- attn_drop_rate=0.,
658
- drop_path_rate=0.2,
659
- norm_layer=nn.LayerNorm,
660
- patch_norm=False,
661
- out_indices=(0,1,2,3),
662
- frozen_stages=-1,
663
- use_checkpoint=False,
664
- new_version=0):
665
  super().__init__()
666
-
667
- self.pretrained = pretrained
668
- self.pretrained2d = pretrained2d
669
- self.num_layers = len(depths)
670
- self.embed_dim = embed_dim
671
- self.patch_norm = patch_norm
672
- self.frozen_stages = frozen_stages
673
- self.window_size = window_size
674
- self.patch_size = patch_size
675
- self.out_indices = out_indices
676
-
677
- # split image into non-overlapping patches
678
- if new_version==3:
679
- print("---- new version 3 ----")
680
- self.patch_embed = PatchEmbedConv(
681
- patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
682
- norm_layer=norm_layer if self.patch_norm else None)
683
- elif new_version==4:
684
- print("---- new version 4 ----")
685
- self.patch_embed = PatchEmbedLocalGlobal(
686
- patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
687
- norm_layer=norm_layer if self.patch_norm else None)
688
  else:
689
- print("---- old version ----")
690
- self.patch_embed = PatchEmbed3D(
691
- patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
692
- norm_layer=norm_layer if self.patch_norm else None)
693
-
694
- self.pos_drop = nn.Dropout(p=drop_rate)
695
-
696
- # stochastic depth
697
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
698
-
699
- # build layers
700
- self.layers = nn.ModuleList()
701
- for i_layer in range(self.num_layers):
702
- layer = BasicLayer(
703
- dim=int(embed_dim * 2**i_layer),
704
- depth=depths[i_layer],
705
- num_heads=num_heads[i_layer],
706
- window_size=window_size,
707
- mlp_ratio=mlp_ratio,
708
- qkv_bias=qkv_bias,
709
- qk_scale=qk_scale,
710
- drop=drop_rate,
711
- attn_drop=attn_drop_rate,
712
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
713
- norm_layer=norm_layer,
714
- downsample=PatchMerging if i_layer<self.num_layers-1 else None,
715
- use_checkpoint=use_checkpoint)
716
- self.layers.append(layer)
717
-
718
- # self.num_features = int(embed_dim * 2**(self.num_layers-1))
719
- num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
720
- self.num_features = num_features
721
-
722
- # add a norm layer for each output
723
- # self.norm = norm_layer(self.num_features)
724
-
725
- # add a norm layer for each output
726
- for i_layer in self.out_indices:
727
- layer = norm_layer(self.num_features[i_layer])
728
- layer_name = f'norm{i_layer}'
729
- self.add_module(layer_name, layer)
730
-
731
-
732
- def inflate_weights(self, logger):
733
- """Inflate the swin2d parameters to swin3d.
734
- The differences between swin3d and swin2d mainly lie in an extra
735
- axis. To utilize the pretrained parameters in 2d model,
736
- the weight of swin2d models should be inflated to fit in the shapes of
737
- the 3d counterpart.
738
- Args:
739
- logger (logging.Logger): The logger used to print
740
- debugging infomation.
741
- """
742
- checkpoint = torch.load(self.pretrained, map_location='cpu')
743
- state_dict = checkpoint['model']
744
-
745
- # delete relative_position_index since we always re-init it
746
- relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
747
- for k in relative_position_index_keys:
748
- del state_dict[k]
749
-
750
- # delete attn_mask since we always re-init it
751
- attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
752
- for k in attn_mask_keys:
753
- del state_dict[k]
754
-
755
- state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).repeat(1,1,self.patch_size[0],1,1) / self.patch_size[0]
756
-
757
- # bicubic interpolate relative_position_bias_table if not match
758
- relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
759
- for k in relative_position_bias_table_keys:
760
- relative_position_bias_table_pretrained = state_dict[k]
761
- relative_position_bias_table_current = self.state_dict()[k]
762
- L1, nH1 = relative_position_bias_table_pretrained.size()
763
- L2, nH2 = relative_position_bias_table_current.size()
764
- L2 = (2*self.window_size[1]-1) * (2*self.window_size[2]-1)
765
- wd = self.window_size[0]
766
- if nH1 != nH2:
767
- logger.warning(f"Error in loading {k}, passing")
768
- else:
769
- if L1 != L2:
770
- S1 = int(L1 ** 0.5)
771
- relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
772
- relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(2*self.window_size[1]-1, 2*self.window_size[2]-1),
773
- mode='bicubic')
774
- relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
775
- state_dict[k] = relative_position_bias_table_pretrained.repeat(2*wd-1,1)
776
-
777
- msg = self.load_state_dict(state_dict, strict=False)
778
- logger.info(msg)
779
- logger.info(f"=> loaded successfully '{self.pretrained}'")
780
- del checkpoint
781
- torch.cuda.empty_cache()
782
-
783
- def init_weights(self, pretrained=None):
784
- """Initialize the weights in backbone.
785
- Args:
786
- pretrained (str, optional): Path to pre-trained weights.
787
- Defaults to None.
788
- """
789
- def _init_weights(m):
790
- if isinstance(m, nn.Linear):
791
- trunc_normal_(m.weight, std=.02)
792
- if isinstance(m, nn.Linear) and m.bias is not None:
793
- nn.init.constant_(m.bias, 0)
794
- elif isinstance(m, nn.LayerNorm):
795
- nn.init.constant_(m.bias, 0)
796
- nn.init.constant_(m.weight, 1.0)
797
-
798
- if pretrained:
799
- self.pretrained = pretrained
800
- if isinstance(self.pretrained, str):
801
- self.apply(_init_weights)
802
- logger = get_root_logger()
803
- logger.info(f'load model from: {self.pretrained}')
804
-
805
- if self.pretrained2d:
806
- # Inflate 2D model into 3D model.
807
- self.inflate_weights(logger)
808
- else:
809
- # Directly load 3D model.
810
- load_checkpoint(self, self.pretrained, strict=False, logger=logger)
811
- elif self.pretrained is None:
812
- self.apply(_init_weights)
813
- else:
814
- raise TypeError('pretrained must be a str or None')
815
-
816
- def forward(self, x):
817
- """Forward function."""
818
- x = self.patch_embed(x)
819
- # print(x.shape)
820
- x = self.pos_drop(x)
821
-
822
- outs = []
823
- for i, layer in enumerate(self.layers):
824
- x, out_x = layer(x.contiguous())
825
- # print('---- ', out_x.shape)
826
- if i in self.out_indices:
827
- norm_layer = getattr(self, f'norm{i}')
828
- out_x = norm_layer(out_x)
829
- _, Ti, Hi, Wi, Ci = out_x.shape
830
- out = rearrange(out_x, 'n d h w c -> n c d h w')
831
- outs.append(out)
832
-
833
- return tuple(outs)
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.nn.functional as f
4
+ from torch.nn import init
5
+ import math
6
+
7
+
8
+ class ConvLayer(nn.Module):
9
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None,
10
+ BN_momentum=0.1):
11
+ super(ConvLayer, self).__init__()
12
+
13
+ bias = False if norm == 'BN' else True
14
+ self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
15
+ if activation is not None:
16
+ self.activation = getattr(torch, activation)
17
+ else:
18
+ self.activation = None
19
 
20
+ self.norm = norm
21
+ if norm == 'BN':
22
+ self.norm_layer = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
23
+ elif norm == 'IN':
24
+ self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
25
26
  def forward(self, x):
27
+ out = self.conv2d(x)
28
+
29
+ if self.norm in ['BN', 'IN']:
30
+ out = self.norm_layer(out)
31
+
32
+ if self.activation is not None:
33
+ out = self.activation(out)
34
+
35
  return out
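
A minimal usage sketch for the new ConvLayer (not part of the commit; sizes are illustrative). The activation string is resolved with getattr(torch, activation), so it must name a function in the torch namespace such as 'relu' or 'sigmoid', and norm='BN' drops the conv bias as above:

import torch
conv = ConvLayer(in_channels=3, out_channels=32, kernel_size=3,
                 stride=1, padding=1, activation='relu', norm='BN')
x = torch.randn(4, 3, 64, 64)      # B x C x H x W
y = conv(x)                        # -> 4 x 32 x 64 x 64
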
36
 
37
 
38
+ class TransposedConvLayer(nn.Module):
39
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
40
+ super(TransposedConvLayer, self).__init__()
41
+
42
+ bias = False if norm == 'BN' else True
43
+ self.transposed_conv2d = nn.ConvTranspose2d(
44
+ in_channels, out_channels, kernel_size, stride=2, padding=padding, output_padding=1, bias=bias)
45
+
46
+ if activation is not None:
47
+ self.activation = getattr(torch, activation)
48
+ else:
49
+ self.activation = None
50
+
51
+ self.norm = norm
52
+ if norm == 'BN':
53
+ self.norm_layer = nn.BatchNorm2d(out_channels)
54
+ elif norm == 'IN':
55
+ self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
56
+
57
+ def forward(self, x):
58
+ out = self.transposed_conv2d(x)
59
 
60
+ if self.norm in ['BN', 'IN']:
61
+ out = self.norm_layer(out)
62
 
63
+ if self.activation is not None:
64
+ out = self.activation(out)
65
 
66
+ return out
67
 
 
68
 
69
+ class UpsampleConvLayer(nn.Module):
70
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
71
+ super(UpsampleConvLayer, self).__init__()
72
 
73
+ bias = False if norm == 'BN' else True
74
+ self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
75
 
76
+ if activation is not None:
77
+ self.activation = getattr(torch, activation)
78
  else:
79
+ self.activation = None
80
 
81
+ self.norm = norm
82
+ if norm == 'BN':
83
+ self.norm_layer = nn.BatchNorm2d(out_channels)
84
+ elif norm == 'IN':
85
+ self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
86
 
87
  def forward(self, x):
88
+ x_upsampled = f.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
89
+ out = self.conv2d(x_upsampled)
90
+
91
+ if self.norm in ['BN', 'IN']:
92
+ out = self.norm_layer(out)
93
+
94
+ if self.activation is not None:
95
+ out = self.activation(out)
96
 
97
  return out
98
 
99
 
100
+ class RecurrentConvLayer(nn.Module):
101
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
102
+ recurrent_block_type='convlstm', activation='relu', norm=None, BN_momentum=0.1):
103
+ super(RecurrentConvLayer, self).__init__()
104
+
105
+ assert(recurrent_block_type in ['convlstm', 'convgru'])
106
+ self.recurrent_block_type = recurrent_block_type
107
+ if self.recurrent_block_type == 'convlstm':
108
+ RecurrentBlock = ConvLSTM
109
+ else:
110
+ RecurrentBlock = ConvGRU
111
+
112
+ # self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding, activation, norm,
113
+ # BN_momentum=BN_momentum)
114
+ self.recurrent_block = RecurrentBlock(input_size=out_channels, hidden_size=out_channels, kernel_size=3)
115
+
116
+ def forward(self, x, prev_state):
117
+ # x = self.conv(x)
118
+ state = self.recurrent_block(x, prev_state)
119
+ x = state[0] if self.recurrent_block_type == 'convlstm' else state
120
+ return x, state
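
A hedged sketch of driving RecurrentConvLayer over a sequence (not part of the commit). Note that with self.conv commented out in this version, the input goes straight to the recurrent block, so it must already have out_channels channels; the loop and sizes are assumptions:

import torch
layer = RecurrentConvLayer(in_channels=32, out_channels=32, recurrent_block_type='convlstm')
state = None                              # zero state is allocated on the first call
for _ in range(5):                        # 5 time steps
    x = torch.randn(2, 32, 64, 64)        # B x C x H x W
    out, state = layer(x, state)          # out: 2 x 32 x 64 x 64; state: (hidden, cell)
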
121
+
122
+ class Recurrent2ConvLayer(nn.Module):
123
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
124
+ recurrent_block_type='convlstm', activation='relu', norm=None, BN_momentum=0.1):
125
+ super(Recurrent2ConvLayer, self).__init__()
126
+
127
+ assert(recurrent_block_type in ['convlstm', 'convgru'])
128
+ self.recurrent_block_type = recurrent_block_type
129
+ if self.recurrent_block_type == 'convlstm':
130
+ RecurrentBlock = ConvLSTM
131
+ else:
132
+ RecurrentBlock = ConvGRU
133
+
134
+ self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding, activation, norm,
135
+ BN_momentum=BN_momentum)
136
+ self.recurrent_block = RecurrentBlock(input_size=out_channels, hidden_size=out_channels, kernel_size=3)
137
 
138
+ def forward(self, x, prev_state):
139
+ x = self.conv(x)
140
+ state = self.recurrent_block(x, prev_state)
141
+ x = state[0] if self.recurrent_block_type == 'convlstm' else state
142
+ return x, state
143
 
 
144
 
145
+ class RecurrentPhasedConvLayer(nn.Module):
146
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
147
+ activation='relu', norm=None, BN_momentum=0.1):
148
+ super(RecurrentPhasedConvLayer, self).__init__()
149
 
150
+ self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding, activation, norm,
151
+ BN_momentum=BN_momentum)
152
+ self.recurrent_block = PhasedConvLSTMCell(input_channels=out_channels, hidden_channels=out_channels, kernel_size=3)
153
 
154
+ def forward(self, x, times, prev_state):
155
+ x = self.conv(x)
156
+ x, state = self.recurrent_block(x, times, prev_state)
157
+ return x, state
158
+
159
+
160
+ class DownsampleRecurrentConvLayer(nn.Module):
161
+ def __init__(self, in_channels, out_channels, kernel_size=3, recurrent_block_type='convlstm', padding=0, activation='relu'):
162
+ super(DownsampleRecurrentConvLayer, self).__init__()
163
+
164
+ self.activation = getattr(torch, activation)
165
+
166
+ assert(recurrent_block_type in ['convlstm', 'convgru'])
167
+ self.recurrent_block_type = recurrent_block_type
168
+ if self.recurrent_block_type == 'convlstm':
169
+ RecurrentBlock = ConvLSTM
170
  else:
171
+ RecurrentBlock = ConvGRU
172
+ self.recurrent_block = RecurrentBlock(input_size=in_channels, hidden_size=out_channels, kernel_size=kernel_size)
173
+
174
+ def forward(self, x, prev_state):
175
+ state = self.recurrent_block(x, prev_state)
176
+ x = state[0] if self.recurrent_block_type == 'convlstm' else state
177
+ x = f.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
178
+ return self.activation(x), state
179
+
180
+
181
+ # Residual block
182
+ class ResidualBlock(nn.Module):
183
+ def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm=None,
184
+ BN_momentum=0.1):
185
+ super(ResidualBlock, self).__init__()
186
+ bias = False if norm == 'BN' else True
187
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=bias)
188
+ self.norm = norm
189
+ if norm == 'BN':
190
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
191
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
192
+ elif norm == 'IN':
193
+ self.bn1 = nn.InstanceNorm2d(out_channels)
194
+ self.bn2 = nn.InstanceNorm2d(out_channels)
195
+
196
+ self.relu = nn.ReLU(inplace=False)
197
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
198
+ self.downsample = downsample
199
 
200
  def forward(self, x):
201
+ residual = x
202
+ out = self.conv1(x)
203
+ if self.norm in ['BN', 'IN']:
204
+ out = self.bn1(out)
205
+ out = self.relu(out)
206
+ out = self.conv2(out)
207
+ if self.norm in ['BN', 'IN']:
208
+ out = self.bn2(out)
209
+
210
+ if self.downsample:
211
+ residual = self.downsample(x)
212
+
213
+ out += residual
214
+ out = self.relu(out)
215
  return out
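
A usage sketch for ResidualBlock (not part of the commit; sizes illustrative). When out_channels or the stride differs from the input, a downsample module has to be supplied so the shortcut addition matches shapes:

import torch
import torch.nn as nn
block = ResidualBlock(64, 64, norm='BN')                    # identity shortcut
x = torch.randn(1, 64, 32, 32)
y = block(x)                                                # -> 1 x 64 x 32 x 32
shortcut = nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False)
block2 = ResidualBlock(64, 128, stride=2, downsample=shortcut, norm='BN')
y2 = block2(x)                                              # -> 1 x 128 x 16 x 16
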
216
 
217
 
218
+ class PhasedLSTMCell(nn.Module):
219
+ """Phased LSTM recurrent network cell.
220
+ """
221
 
222
+ def __init__(
223
+ self,
224
+ hidden_size,
225
+ leak=0.001,
226
+ ratio_on=0.1,
227
+ period_init_min=0.02,
228
+ period_init_max=50.0
229
+ ):
230
+ """
231
+ Args:
232
+ hidden_size: int, The number of units in the Phased LSTM cell.
233
+ leak: float or scalar float Tensor with value in [0, 1]. Leak applied
234
+ during training.
235
+ ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
236
+ period during which the gates are open.
237
+ period_init_min: float or scalar float Tensor. With value > 0.
238
+ Minimum value of the initialized period.
239
+ The period values are initialized by drawing from the distribution:
240
+ e^U(log(period_init_min), log(period_init_max))
241
+ Where U(.,.) is the uniform distribution.
242
+ period_init_max: float or scalar float Tensor.
243
+ With value > period_init_min. Maximum value of the initialized period.
244
+ """
245
  super().__init__()
246
 
247
+ self.hidden_size = hidden_size
248
+ self.ratio_on = ratio_on
249
+ self.leak = leak
250
 
251
+ # initialize time-gating parameters
252
+ period = torch.exp(
253
+ torch.Tensor(hidden_size).uniform_(
254
+ math.log(period_init_min), math.log(period_init_max)
255
+ )
256
+ )
257
+ #self.tau = nn.Parameter(period)
258
+ self.register_parameter("tau", nn.Parameter(period))
259
 
260
+ phase = torch.Tensor(hidden_size).uniform_() * period
261
+ self.register_parameter("phase", nn.Parameter(phase))
262
+ self.phase.requires_grad = True
263
+ self.tau.requires_grad = True
264
+ #self.phase = nn.Parameter(phase)
265
 
266
+ def _compute_phi(self, t):
267
+ t_ = t.view(-1, 1).repeat(1, self.hidden_size)
268
+ phase_ = self.phase.view(1, -1).repeat(t.shape[0], 1)
269
+ tau_ = self.tau.view(1, -1).repeat(t.shape[0], 1)
270
+ tau_ = tau_.to(t_.device)
271
+ phase_ = phase_.to(t_.device)
272
+ phi = self._mod((t_ - phase_), tau_)
273
+ phi = torch.abs(phi) / tau_
274
+ return phi
275
 
276
+ def _mod(self, x, y):
277
+ """Modulo function that propagates x gradients."""
278
+ return x + (torch.fmod(x, y) - x).detach()
279
 
280
+ def set_state(self, c, h):
281
+ self.h0 = h
282
+ self.c0 = c
283
 
284
+ def forward(self, c_s, h_s, t):
285
+ # print(c_s.size(), h_s.size(), t.size())
286
+ phi = self._compute_phi(t)
287
 
288
+ # Phase-related augmentations
289
+ k_up = 2 * phi / self.ratio_on
290
+ k_down = 2 - k_up
291
+ k_closed = self.leak * phi
292
 
293
+ k = torch.where(phi < self.ratio_on, k_down, k_closed)
294
+ k = torch.where(phi < 0.5 * self.ratio_on, k_up, k)
295
+ k = k.view(c_s.shape[0], -1)
296
+ c_s_new = k * c_s + (1 - k) * self.c0
297
+ h_s_new = k * h_s + (1 - k) * self.h0
298
 
299
+ return h_s_new, c_s_new
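
A hedged sketch of calling PhasedLSTMCell directly on flattened states, following the docstring above (PhasedConvLSTMCell below does this reshaping automatically; sizes and values are illustrative):

import torch
B, hidden = 2, 128
cell = PhasedLSTMCell(hidden_size=hidden)
c_prev = torch.zeros(B, hidden)
h_prev = torch.zeros(B, hidden)
cell.set_state(c_prev, h_prev)               # state the time gate leaks back towards
t = torch.tensor([0.01, 0.02])               # one timestamp per batch element
c_cand = torch.randn(B, hidden)              # candidate states, e.g. from an LSTM step
h_cand = torch.randn(B, hidden)
h_new, c_new = cell(c_cand, h_cand, t)       # time-gated blend of candidate and previous state
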
 
300
301
 
302
+ class ConvLSTM(nn.Module):
303
+ """Adapted from: https://github.com/Atcold/pytorch-CortexNet/blob/master/model/ConvLSTMCell.py """
304
 
305
+ def __init__(self, input_size, hidden_size, kernel_size):
306
+ super(ConvLSTM, self).__init__()
 
 
307
 
308
+ self.input_size = input_size
309
+ self.hidden_size = hidden_size
310
+ pad = kernel_size // 2
311
 
312
+ # cache a tensor filled with zeros to avoid reallocating memory at each inference step if --no-recurrent is enabled
313
+ self.zero_tensors = {}
314
 
315
+ self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size, padding=pad)
316
 
317
+ def forward(self, input_, prev_state=None):
318
+ # get batch and spatial sizes
319
+ batch_size = input_.data.size()[0]
320
+ spatial_size = input_.data.size()[2:]
321
 
322
+ # generate empty prev_state, if None is provided
323
+ if prev_state is None:
324
 
325
+ # create the zero tensor if it has not been created already
326
+ state_size = tuple([batch_size, self.hidden_size] + list(spatial_size))
327
+ if state_size not in self.zero_tensors:
328
+ # allocate a tensor with size `spatial_size`, filled with zero (if it has not been allocated already)
329
+ self.zero_tensors[state_size] = (
330
+ torch.zeros(state_size, dtype=input_.dtype).to(input_.device),
331
+ torch.zeros(state_size, dtype=input_.dtype).to(input_.device)
332
+ )
333
+
334
+ prev_state = self.zero_tensors[tuple(state_size)]
335
+
336
+ prev_hidden, prev_cell = prev_state
337
 
338
+ # data size is [batch, channel, height, width]
339
+ stacked_inputs = torch.cat((input_, prev_hidden), 1)
340
+ gates = self.Gates(stacked_inputs)
341
 
342
+ # chunk across channel dimension
343
+ in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)
344
+
345
+ # apply sigmoid non linearity
346
+ in_gate = torch.sigmoid(in_gate)
347
+ remember_gate = torch.sigmoid(remember_gate)
348
+ out_gate = torch.sigmoid(out_gate)
349
+
350
+ # apply tanh non linearity
351
+ cell_gate = torch.tanh(cell_gate)
352
+
353
+ # compute current cell and hidden state
354
+ cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
355
+ hidden = out_gate * torch.tanh(cell)
356
+
357
+ return hidden, cell
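
A minimal unrolling sketch for the ConvLSTM cell (not part of the commit; sizes illustrative). Passing prev_state=None reuses a cached pair of zero tensors keyed by state size, as implemented above:

import torch
lstm = ConvLSTM(input_size=16, hidden_size=32, kernel_size=3)
state = None
for _ in range(10):                       # unroll over 10 steps
    frame = torch.randn(4, 16, 64, 64)    # B x C_in x H x W
    state = lstm(frame, state)            # (hidden, cell), each 4 x 32 x 64 x 64
hidden, cell = state
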
358
+
359
+
360
+ class PhasedConvLSTMCell(nn.Module):
361
+ def __init__(
362
+ self,
363
+ input_channels,
364
+ hidden_channels,
365
+ kernel_size=3
366
+ ):
367
+ super().__init__()
368
+ self.hidden_channels = hidden_channels
369
+
370
+ self.lstm = ConvLSTM(
371
+ input_size=input_channels,
372
+ hidden_size=hidden_channels,
373
+ kernel_size=kernel_size
374
+ )
375
+
376
+ # as soon as spatial dimension is known, phased lstm cell is instantiated
377
+ self.phased_cell = None
378
+ self.hidden_size = None
379
+
380
+ def forward(self, input, times, prev_state=None):
381
+ # input: B x C x H x W
382
+ # times: B
383
+ # returns: output: B x C_out x H x W, prev_state: (B x C_out x H x W, B x C_out x H x W)
384
+
385
+ B, C, H, W = input.shape
386
+
387
+ if self.hidden_size is None:
388
+ self.hidden_size = self.hidden_channels * W * H
389
+ self.phased_cell = PhasedLSTMCell(hidden_size=self.hidden_size)
390
+ self.phased_cell = self.phased_cell.to(input.device)
391
+ self.phased_cell.requires_grad = True
392
+
393
+ if prev_state is None:
394
+ h0 = input.new_zeros((B, self.hidden_channels, H, W))
395
+ c0 = input.new_zeros((B, self.hidden_channels, H, W))
396
  else:
397
+ c0, h0 = prev_state
398
 
399
+ self.phased_cell.set_state(c0.view(B, -1), h0.view(B, -1))
400
+
401
+ c_t, h_t = self.lstm(input, (c0, h0))
402
+
403
+ # reshape activation maps such that phased lstm can use them
404
+ (c_s, h_s) = self.phased_cell(c_t.view(B, -1), h_t.view(B, -1), times)
405
+
406
+ # reshape to feed to conv lstm
407
+ c_s = c_s.view(B, -1, H, W)
408
+ h_s = h_s.view(B, -1, H, W)
409
+
410
+ return h_t, (c_s, h_s)
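
A hedged sketch of stepping PhasedConvLSTMCell (not part of the commit). Because the internal PhasedLSTMCell is created lazily from the first input's H x W, every later input must keep the same spatial size; times holds one timestamp per batch element:

import torch
cell = PhasedConvLSTMCell(input_channels=16, hidden_channels=16)
state = None
for step in range(3):
    x = torch.randn(2, 16, 32, 32)                  # B x C x H x W, spatial size fixed after step 0
    times = torch.full((2,), step * 0.05)           # per-sample timestamps
    out, state = cell(x, times, state)              # out: 2 x 16 x 32 x 32
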
411
+
412
+
413
+ class ConvGRU(nn.Module):
414
+ """
415
+ Generate a convolutional GRU cell
416
+ Adapted from: https://github.com/jacobkimmel/pytorch_convgru/blob/master/convgru.py
417
  """
418
 
419
+ def __init__(self, input_size, hidden_size, kernel_size):
420
  super().__init__()
421
+ padding = kernel_size // 2
422
+ self.input_size = input_size
423
+ self.hidden_size = hidden_size
424
+ self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
425
+ self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
426
+ self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
427
+
428
+ init.orthogonal_(self.reset_gate.weight)
429
+ init.orthogonal_(self.update_gate.weight)
430
+ init.orthogonal_(self.out_gate.weight)
431
+ init.constant_(self.reset_gate.bias, 0.)
432
+ init.constant_(self.update_gate.bias, 0.)
433
+ init.constant_(self.out_gate.bias, 0.)
434
+
435
+ def forward(self, input_, prev_state):
436
+
437
+ # get batch and spatial sizes
438
+ batch_size = input_.data.size()[0]
439
+ spatial_size = input_.data.size()[2:]
440
+
441
+ # generate empty prev_state, if None is provided
442
+ if prev_state is None:
443
+ state_size = [batch_size, self.hidden_size] + list(spatial_size)
444
+ prev_state = torch.zeros(state_size, dtype=input_.dtype).to(input_.device)
445
+
446
+ # data size is [batch, channel, height, width]
447
+ stacked_inputs = torch.cat([input_, prev_state], dim=1)
448
+ update = torch.sigmoid(self.update_gate(stacked_inputs))
449
+ reset = torch.sigmoid(self.reset_gate(stacked_inputs))
450
+ out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
451
+ new_state = prev_state * (1 - update) + out_inputs * update
452
+
453
+ return new_state
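
Unlike ConvLSTM, ConvGRU returns a single hidden-state tensor, which is why the wrappers above only index state[0] for 'convlstm'. A minimal sketch (not part of the commit; sizes illustrative), noting that prev_state has no default and must be passed explicitly, None on the first step:

import torch
gru = ConvGRU(input_size=16, hidden_size=32, kernel_size=3)
state = None
for _ in range(10):
    frame = torch.randn(4, 16, 64, 64)    # B x C_in x H x W
    state = gru(frame, state)             # single tensor: 4 x 32 x 64 x 64
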
454
+
455
+
456
+ class RecurrentResidualLayer(nn.Module):
457
+ def __init__(self, in_channels, out_channels,
458
+ recurrent_block_type='convlstm', norm=None, BN_momentum=0.1):
459
+ super(RecurrentResidualLayer, self).__init__()
460
+
461
+ assert(recurrent_block_type in ['convlstm', 'convgru'])
462
+ self.recurrent_block_type = recurrent_block_type
463
+ if self.recurrent_block_type == 'convlstm':
464
+ RecurrentBlock = ConvLSTM
465
  else:
466
+ RecurrentBlock = ConvGRU
467
+ self.conv = ResidualBlock(in_channels=in_channels,
468
+ out_channels=out_channels,
469
+ norm=norm,
470
+ BN_momentum=BN_momentum)
471
+ self.recurrent_block = RecurrentBlock(input_size=out_channels,
472
+ hidden_size=out_channels,
473
+ kernel_size=3)
474
+
475
+ def forward(self, x, prev_state):
476
+ x = self.conv(x)
477
+ state = self.recurrent_block(x, prev_state)
478
+ x = state[0] if self.recurrent_block_type == 'convlstm' else state
479
+ return x, state