Max Meyer committed on
Commit
143915f
·
verified ·
1 Parent(s): acd779e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +950 -0
model.py CHANGED
@@ -0,0 +1,950 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint as checkpoint
7
+ from einops import rearrange
8
+ from PIL import Image, ImageFilter, ImageOps
9
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
10
+ from torchvision import transforms
11
+
12
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features (int): Size of the last input dimension.
        hidden_features (int, optional): Width of the hidden layer. Falls back to
            ``in_features`` when falsy.
        out_features (int, optional): Size of the last output dimension. Falls back
            to ``in_features`` when falsy.
        act_layer (callable): Zero-argument factory for the activation module.
        drop (float): Dropout probability applied after each linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        # A single Dropout instance is reused at both positions.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
31
+
32
+
33
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C)
        window_size (int): side length of each window
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    rows, cols = height // window_size, width // window_size
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Put the two window-grid axes next to each other so they can be folded into the batch axis.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
45
+
46
+
47
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by window partitioning into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    windows_per_col = H // window_size
    windows_per_row = W // window_size
    # Integer arithmetic only: the batch size is an exact quotient, so there is
    # no need to go through float division and truncate the result.
    B = windows.shape[0] // (windows_per_col * windows_per_row)
    x = windows.view(B, windows_per_col, windows_per_row, window_size, window_size, -1)
    # Interleave window-grid and in-window axes back into plain (H, W) order.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
61
+
62
+
63
class WindowAttention(nn.Module):
    """Multi-head self-attention inside a local window, with a learned relative position bias.

    The same module serves shifted and non-shifted windows; for the shifted
    case the caller passes an additive mask to ``forward``.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        win_h, win_w = self.window_size[0], self.window_size[1]

        # One learnable bias per head for every possible (dy, dx) offset between two tokens of a window.
        num_offsets = (2 * win_h - 1) * (2 * win_w - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_offsets, num_heads))

        # For every (query token, key token) pair, precompute its row in the bias table.
        grid = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        offsets[:, :, 0] += win_h - 1  # row offsets now start at 0
        offsets[:, :, 1] += win_w - 1  # column offsets now start at 0
        offsets[:, :, 0] *= 2 * win_w - 1  # row-major flattening of the (dy, dx) pair
        self.register_buffer("relative_position_index", offsets.sum(-1))  # Wh*Ww, Wh*Ww

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Run windowed attention.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: additive mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        heads = self.num_heads
        qkv = self.qkv(x).reshape(B_, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # indexed rather than unpacked, as in the original

        scores = (q * self.scale) @ k.transpose(-2, -1)

        tokens = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(tokens, tokens, -1).permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        scores = scores + bias.unsqueeze(0)

        if mask is not None:
            # The mask is shared by all heads and all batch items; it differs per window.
            nW = mask.shape[0]
            scores = scores.view(B_ // nW, nW, heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            scores = scores.view(-1, heads, N, N)

        weights = self.attn_drop(self.softmax(scores))
        out = (weights @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj_drop(self.proj(out))
142
+
143
+
144
class SwinTransformerBlock(nn.Module):
    """One Swin Transformer block: (shifted-)window attention followed by an MLP, both residual.

    The spatial size of the token grid is not passed to ``forward``; the caller
    must assign ``self.H`` and ``self.W`` before each call.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        # Token-grid size; set by the caller before forward().
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """Apply the block.

        Args:
            x: Input feature, tensor size (B, H*W, C), with H and W read from ``self.H`` / ``self.W``.
            mask_matrix: Attention mask for cyclic shift; used only when ``shift_size > 0``.
        Returns:
            Tensor of size (B, H*W, C).
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"
        ws, shift = self.window_size, self.shift_size

        residual = x
        feat = self.norm1(x).view(B, H, W, C)

        # Pad on the right/bottom so both spatial sizes are multiples of the window size.
        pad_r = (ws - W % ws) % ws
        pad_b = (ws - H % ws) % ws
        feat = F.pad(feat, (0, 0, 0, pad_r, 0, pad_b))
        Hp, Wp = feat.shape[1], feat.shape[2]

        shifted = shift > 0
        if shifted:
            feat = torch.roll(feat, shifts=(-shift, -shift), dims=(1, 2))
        attn_mask = mask_matrix if shifted else None

        # Windows become the batch axis for attention: (nW*B, ws*ws, C).
        windows = window_partition(feat, ws).view(-1, ws * ws, C)
        windows = self.attn(windows, mask=attn_mask)

        # Stitch windows back into a (B, Hp, Wp, C) map.
        feat = window_reverse(windows.view(-1, ws, ws, C), ws, Hp, Wp)

        if shifted:
            feat = torch.roll(feat, shifts=(shift, shift), dims=(1, 2))
        if pad_r > 0 or pad_b > 0:
            feat = feat[:, :H, :W, :].contiguous()
        feat = feat.view(B, H * W, C)

        # Residual attention branch, then residual MLP branch.
        x = residual + self.drop_path(feat)
        return x + self.drop_path(self.mlp(self.norm2(x)))
242
+
243
+
244
class PatchMerging(nn.Module):
    """Downsample a token grid by 2x: concatenate each 2x2 neighbourhood, normalize, project 4*dim -> 2*dim.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Merge patches.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        Returns:
            Tensor of size (B, ceil(H/2)*ceil(W/2), 2*C).
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        grid = x.view(B, H, W, C)
        # Odd sizes get one row/column of zero padding so the grid splits into 2x2 cells.
        if H % 2 == 1 or W % 2 == 1:
            grid = F.pad(grid, (0, 0, 0, W % 2, 0, H % 2))

        # Channel order: (even row, even col), (odd row, even col), (even row, odd col), (odd row, odd col).
        neighbours = [grid[:, r::2, c::2, :] for c in (0, 1) for r in (0, 1)]
        merged = torch.cat(neighbours, -1).view(B, -1, 4 * C)
        return self.reduction(self.norm(merged))
283
+
284
+
285
class BasicLayer(nn.Module):
    """One Swin Transformer stage: ``depth`` blocks alternating plain and shifted windows, plus optional downsampling.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate; a list gives one rate per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Even-indexed blocks use regular windows, odd-indexed blocks use shifted windows.
        per_block_drop_path = isinstance(drop_path, list)
        blocks = []
        for i in range(depth):
            blocks.append(SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=window_size // 2 if i % 2 else 0,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if per_block_drop_path else drop_path,
                norm_layer=norm_layer))
        self.blocks = nn.ModuleList(blocks)

        # Optional patch-merging layer applied after the blocks.
        self.downsample = None if downsample is None else downsample(dim=dim, norm_layer=norm_layer)

    def forward(self, x, H, W):
        """Run the stage.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        Returns:
            Tuple ``(x, H, W, x_next, H_next, W_next)``: the stage output at the
            input resolution, followed by the (possibly downsampled) tensor and
            resolution to feed to the next stage. Without a downsample layer
            both halves are the same.
        """
        ws, shift = self.window_size, self.shift_size

        # Build the SW-MSA mask on the padded grid: label the 3x3 regions created by the
        # cyclic shift, then forbid attention between tokens carrying different labels.
        Hp = int(np.ceil(H / ws)) * ws
        Wp = int(np.ceil(W / ws)) * ws
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        bands = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
        region_id = 0
        for h_band in bands:
            for w_band in bands:
                img_mask[:, h_band, w_band, :] = region_id
                region_id += 1

        mask_windows = window_partition(img_mask, ws).view(-1, ws * ws)  # nW, ws*ws
        diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = diff.masked_fill(diff != 0, float(-100.0)).masked_fill(diff == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W  # blocks read the grid size from attributes
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)

        if self.downsample is None:
            return x, H, W, x, H, W
        x_down = self.downsample(x, H, W)
        return x, H, W, x_down, (H + 1) // 2, (W + 1) // 2
385
+
386
+
387
class PatchEmbed(nn.Module):
    """Embed an image as a grid of patch tokens using a strided convolution.

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Return the patch embedding as a (B, embed_dim, Wh, Ww) map."""
        patch_h, patch_w = self.patch_size[0], self.patch_size[1]

        # Pad right/bottom so the image size is a multiple of the patch size.
        _, _, H, W = x.size()
        if W % patch_w != 0:
            x = F.pad(x, (0, patch_w - W % patch_w))
        if H % patch_h != 0:
            x = F.pad(x, (0, 0, 0, patch_h - H % patch_h))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is None:
            return x

        # The norm acts on the channel axis, so go through the token layout and back.
        Wh, Ww = x.size(2), x.size(3)
        tokens = self.norm(x.flatten(2).transpose(1, 2))
        return tokens.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
427
+
428
+
429
class SwinTransformer(nn.Module):
    """ Swin Transformer backbone.
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate. Default: 0.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False):
        super().__init__()

        # NOTE: the list defaults for `depths` / `num_heads` are only read, never mutated.
        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]

            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth: the drop-path rate grows linearly from 0 to drop_path_rate over all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers; channel width doubles at every stage, the last stage has no downsampling
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put the frozen parts in eval mode and stop their gradients.

        ``frozen_stages >= 0`` freezes the patch embedding, ``>= 1`` also the
        absolute position embedding (if any), and ``>= 2`` also the position
        dropout and the first ``frozen_stages - 1`` layers.
        """
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen stages in eval mode.

        Without this override a call to ``.train()`` would put the frozen
        stages back into training mode and re-enable their dropout and
        drop-path, even though their parameters no longer receive gradients.
        """
        super().train(mode)
        self._freeze_stages()
        return self

    def forward(self, x):
        """Extract multi-scale features.

        Returns:
            tuple: the patch-embedding map (B, C, Wh, Ww) first, followed by one
            normalized (B, C_i, H_i, W_i) map for every stage listed in
            ``out_indices``.
        """
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed)  # still B C Wh Ww at this point

        # The raw patch embedding is exposed as the first output.
        outs = [x.contiguous()]
        x = x.flatten(2).transpose(1, 2)  # B Wh*Ww C
        x = self.pos_drop(x)

        for i, layer in enumerate(self.layers):
            # x_out/H/W describe this stage's output; x/Wh/Ww feed the next stage.
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)
584
+
585
+
586
+
587
+
588
+
589
+
590
+
591
def get_activation_fn(activation):
    """Look up the functional activation registered under the name ``activation``.

    Only ``"gelu"`` is supported; any other value raises ``RuntimeError``.
    """
    if activation != "gelu":
        raise RuntimeError(f"activation should be gelu, not {activation}.")
    return F.gelu
597
+
598
+
599
def make_cbr(in_dim, out_dim):
    """Build a 3x3 convolution (padding 1) followed by InstanceNorm2d and GELU."""
    conv = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
    norm = nn.InstanceNorm2d(out_dim)
    return nn.Sequential(conv, norm, nn.GELU())
601
+
602
+
603
def make_cbg(in_dim, out_dim):
    """Build a 3x3 convolution (padding 1) followed by InstanceNorm2d and GELU."""
    layers = [
        nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
        nn.InstanceNorm2d(out_dim),
        nn.GELU(),
    ]
    return nn.Sequential(*layers)
605
+
606
+
607
def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
    """Resample ``x`` by ``scale_factor`` with ``F.interpolate`` using the given mode."""
    resampled = F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
    return resampled
609
+
610
+
611
def resize_as(x, y, interpolation='bilinear'):
    """Resample ``x`` so its last two dimensions match those of ``y``."""
    target_hw = y.shape[-2:]
    return F.interpolate(x, size=target_hw, mode=interpolation)
613
+
614
+
615
def image2patches(x):
    """Cut each image into a 2x2 grid of tiles and stack the tiles along the batch axis.

    Layout: b c (hg h) (wg w) -> (hg wg b) c h w, with hg = wg = 2.
    """
    return rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
619
+
620
+
621
def patches2image(x):
    """Reassemble tiles stacked along the batch axis into 2x2-tiled images.

    Layout: (hg wg b) c h w -> b c (hg h) (wg w), with hg = wg = 2.
    """
    return rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
625
class PositionEmbeddingSine:
    """Callable producing fixed 2-D sine/cosine position encodings.

    Args:
        num_pos_feats (int): Number of channels per axis; the output has twice as many.
        temperature (float): Base of the geometric frequency progression.
        normalize (bool): If True, rescale coordinates to (0, scale) before encoding.
        scale (float, optional): Coordinate range used when normalizing; 2*pi if None.

    Raises:
        ValueError: if ``scale`` is given while ``normalize`` is False.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale
        self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)

    def __call__(self, b, h, w):
        """Return encodings of shape (b, 2*num_pos_feats, h, w): y channels first, then x channels."""
        device = self.dim_t.device
        ones = ~torch.zeros([b, h, w], dtype=torch.bool, device=device)
        # Cumulative sums of an all-ones grid give 1-based row / column coordinates.
        y_embed = ones.cumsum(dim=1, dtype=torch.float32)
        x_embed = ones.cumsum(dim=2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

        # Consecutive channel pairs share one frequency (sin on the even, cos on the odd channel).
        freqs = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)

        def encode(coord):
            angles = coord[:, :, :, None] / freqs
            return torch.stack((angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()), dim=4).flatten(3)

        return torch.cat((encode(y_embed), encode(x_embed)), dim=3).permute(0, 3, 1, 2)
658
+
659
+
660
class MCLM(nn.Module):
    """Cross-attention module exchanging information between four local patches and one global view.

    First the global tokens attend to multi-scale pooled versions of the
    re-assembled local patches; then every local patch attends to the matching
    quadrant of the refreshed global tokens.

    Args:
        d_model (int): Channel width of the tokens.
        num_heads (int): Number of heads of every attention layer.
        pool_ratios (list[int]): Pooling ratios used to build the multi-scale keys/values.
    """

    def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
        super(MCLM, self).__init__()
        # attention[0]: global <- pooled locals; attention[1..4]: one per local patch.
        self.attention = nn.ModuleList([
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1) for _ in range(5)
        ])

        self.linear1 = nn.Linear(d_model, d_model * 2)
        self.linear2 = nn.Linear(d_model * 2, d_model)
        self.linear3 = nn.Linear(d_model, d_model * 2)
        self.linear4 = nn.Linear(d_model * 2, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.activation = get_activation_fn('gelu')
        self.pool_ratios = pool_ratios
        # Cached positional encodings (plain attributes, not buffers) and the
        # shapes they were computed for.
        self.p_poses = []
        self.g_pos = None
        self._pos_shapes = None
        self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)

    def forward(self, l, g):
        """
        l: 4,c,h,w
        g: 1,c,h,w

        Returns a tensor whose batch axis holds the refreshed local patches
        followed by the refreshed global view (5,c,h,w for the shapes above).
        """
        b, c, h, w = l.size()
        # 4,c,h,w -> 1,c,2h,2w
        concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)

        pools = []
        pool_shapes = []
        for pool_ratio in self.pool_ratios:
            # b,c,h,w
            tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
            pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
            pool_shapes.append((pool.shape[0], pool.shape[2], pool.shape[3]))
            pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
        pools = torch.cat(pools, 0)

        # Positional encodings depend only on shapes, so they are cached. The cache
        # is rebuilt whenever the batch size or spatial size differs from the one it
        # was computed for; otherwise stale encodings would no longer match the tokens.
        shapes = (tuple(pool_shapes), (g.shape[0], g.shape[2], g.shape[3]))
        if self.g_pos is None or getattr(self, '_pos_shapes', None) != shapes:
            p_poses = [
                rearrange(self.positional_encoding(*shape), 'b c h w -> (h w) b c')
                for shape in pool_shapes
            ]
            self.p_poses = torch.cat(p_poses, dim=0)
            pos_emb = self.positional_encoding(*shapes[1])
            self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
            self._pos_shapes = shapes

        device = pools.device
        self.p_poses = self.p_poses.to(device)
        self.g_pos = self.g_pos.to(device)

        # attention between glb (q) & multisensory concated-locs (k,v)
        g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')

        g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
        g_hw_b_c = self.norm1(g_hw_b_c)
        g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
        g_hw_b_c = self.norm2(g_hw_b_c)

        # attention between origin locs (q) & refreshed glb (k,v)
        l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
        _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
        # Split the global token grid into the 2x2 quadrants that line up with the local patches.
        _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
        outputs_re = []
        for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
            outputs_re.append(self.attention[i + 1](_l, _g, _g)[0])  # (h w) 1 c
        outputs_re = torch.cat(outputs_re, 1)  # (h w) 4 c

        # norm1/norm2/dropout1/dropout2 are shared with the global branch above.
        l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
        l_hw_b_c = self.norm1(l_hw_b_c)
        l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
        l_hw_b_c = self.norm2(l_hw_b_c)

        l = torch.cat((l_hw_b_c, g_hw_b_c), 1)  # hw,b(5),c
        return rearrange(l, "(h w) b c -> b c h w", h=h, w=w)
741
+
742
+
743
+
744
+
745
+
746
+
747
+
748
+
749
+
750
class MCRM(nn.Module):
    """Refinement module: four local patches attend to multi-scale pooled quadrants of the global view.

    Args:
        d_model (int): Channel width of the features.
        num_heads (int): Number of heads of every attention layer.
        pool_ratios (list[int]): Pooling ratios used to build the keys/values.
        h: Accepted for interface compatibility; not used.
    """

    def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
        super(MCRM, self).__init__()
        # One attention layer per local patch.
        self.attention = nn.ModuleList([
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1) for _ in range(4)
        ])
        self.linear3 = nn.Linear(d_model, d_model * 2)
        self.linear4 = nn.Linear(d_model * 2, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.sigmoid = nn.Sigmoid()
        self.activation = get_activation_fn('gelu')
        self.sal_conv = nn.Conv2d(d_model, 1, 1)
        self.pool_ratios = pool_ratios

    def forward(self, x):
        """Refine local and global features.

        Args:
            x: tensor whose batch axis holds 4 local patches followed by 1 global view (5, c, h, w).
        Returns:
            Tuple of the refreshed (5, c, h, w) tensor and the sigmoid attention
            map predicted from the global view, resized to the 2x2-tiled local grid.
        """
        b, c, h, w = x.size()
        loc, glb = x.split([4, 1], dim=0)  # 4,c,h,w; 1,c,h,w

        # Global view cut into the four quadrants that line up with the local patches.
        patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)

        # Gate the local features with a saliency map predicted from the global view.
        token_attention_map = self.sigmoid(self.sal_conv(glb))
        token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
        loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)

        pooled = [
            rearrange(
                F.adaptive_avg_pool2d(patched_glb, (round(h / ratio), round(w / ratio))),
                'nl c h w -> nl c (h w)')  # nl(4),c,hw
            for ratio in self.pool_ratios
        ]
        pools = rearrange(torch.cat(pooled, 2), "nl c nphw -> nl nphw 1 c")
        loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')

        # Each local patch queries the pooled tokens of its own global quadrant.
        outputs = [
            self.attention[i](q, pools[i], pools[i])[0]
            for i, q in enumerate(loc_.unbind(dim=0))
        ]
        outputs = torch.cat(outputs, 1)

        src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
        src = self.norm1(src)
        src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
        src = self.norm2(src)
        src = src.permute(1, 2, 0).reshape(4, c, h, w)  # refreshed loc
        glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest')  # refreshed glb

        return torch.cat((src, glb), 0), token_attention_map
806
+
807
+
808
class BEN_Base(nn.Module):
    """Segmentation network: Swin backbone, MCLM/MCRM decoder and a mask head.

    ``forward`` maps an image tensor to a one-channel map passed through a
    sigmoid; ``inference`` wraps it for PIL images and returns a soft mask
    plus an RGBA cut-out.
    """

    def __init__(self):
        """Build the backbone, the decoder blocks and the output heads."""
        super().__init__()

        self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
        emb_dim = 128
        self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))

        # Project every backbone stage to the shared decoder width.
        self.output5 = make_cbr(1024, emb_dim)
        self.output4 = make_cbr(512, emb_dim)
        self.output3 = make_cbr(256, emb_dim)
        self.output2 = make_cbr(128, emb_dim)
        self.output1 = make_cbr(128, emb_dim)

        self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
        self.conv1 = make_cbr(emb_dim, emb_dim)
        self.conv2 = make_cbr(emb_dim, emb_dim)
        self.conv3 = make_cbr(emb_dim, emb_dim)
        self.conv4 = make_cbr(emb_dim, emb_dim)
        self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
        self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
        self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
        self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])

        self.insmask_head = nn.Sequential(
            nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
            nn.InstanceNorm2d(384),
            nn.GELU(),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.InstanceNorm2d(384),
            nn.GELU(),
            nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
        )

        self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
        self.upsample1 = make_cbg(emb_dim, emb_dim)
        self.upsample2 = make_cbg(emb_dim, emb_dim)
        self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))

        # Flag every GELU/Dropout submodule as in-place.
        # NOTE(review): nn.Dropout honours ``inplace``; nn.GELU does not appear
        # to read such an attribute, so this is likely a no-op there -- confirm.
        for m in self.modules():
            if isinstance(m, (nn.GELU, nn.Dropout)):
                m.inplace = True

    def forward(self, x):
        """Predict a one-channel map with values in (0, 1) for an image tensor.

        Args:
            x: Image tensor with 3 channels (it feeds a ``Conv2d(3, ...)``).
                The shape notes below are the original author's and presumably
                assume a single 1024x1024 image -- confirm.

        Returns:
            The output of the final 3x3 convolution passed through a sigmoid.
        """
        shallow = self.shallow(x)
        # Half-resolution global view plus local patches, stacked on the batch dim.
        glb = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
        loc = image2patches(x)
        stacked = torch.cat((loc, glb), dim=0)
        feature = self.backbone(stacked)
        e5 = self.output5(feature[4])  # (5,128,16,16)
        e4 = self.output4(feature[3])  # (5,128,32,32)
        e3 = self.output3(feature[2])  # (5,128,64,64)
        e2 = self.output2(feature[1])  # (5,128,128,128)
        e1 = self.output1(feature[0])  # (5,128,128,128)
        loc_e5, glb_e5 = e5.split([4, 1], dim=0)
        # The cross-attention output keeps all five entries (locals + global).
        e5 = self.multifieldcrossatt(loc_e5, glb_e5)  # (5,128,16,16)

        # Top-down decoding; the per-stage token attention maps are not used here.
        e4, _ = self.dec_blk4(e4 + resize_as(e5, e4))
        e4 = self.conv4(e4)
        e3, _ = self.dec_blk3(e3 + resize_as(e4, e3))
        e3 = self.conv3(e3)
        e2, _ = self.dec_blk2(e2 + resize_as(e3, e2))
        e2 = self.conv2(e2)
        e1, _ = self.dec_blk1(e1 + resize_as(e2, e1))
        e1 = self.conv1(e1)
        loc_e1, glb_e1 = e1.split([4, 1], dim=0)
        output1_cat = patches2image(loc_e1)  # (1,128,256,256)
        output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
        final_output = self.insmask_head(output1_cat)  # (1,128,256,256)
        final_output = final_output + resize_as(shallow, final_output)
        final_output = self.upsample1(rescale_to(final_output))
        final_output = rescale_to(final_output + resize_as(shallow, final_output))
        final_output = self.upsample2(final_output)
        final_output = self.output(final_output)

        return final_output.sigmoid()

    def inference(self, image):
        """Segment a PIL image and return ``(mask, foreground)``.

        Args:
            image: PIL image to segment.

        Returns:
            ``blurred_mask``: mode ``'L'`` mask after a radius-1 Gaussian blur.
            ``foreground``: RGBA copy of the image returned by
            ``rgb_loader_refiner`` with the mask as its alpha channel.
        """
        # rgb_loader_refiner returns PIL's ``size`` tuple, i.e. (width, height).
        image, width, height, original_image = rgb_loader_refiner(image)

        img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)

        # Only the prediction is needed; skip building the autograd graph.
        with torch.no_grad():
            res = self.forward(img_tensor)

        # F.interpolate expects (height, width).
        pred_array = postprocess_image(res, im_size=[height, width])

        mask_image = Image.fromarray(pred_array, mode='L')
        blurred_mask = mask_image.filter(ImageFilter.GaussianBlur(radius=1))

        foreground = original_image.convert("RGBA").copy()
        foreground.putalpha(blurred_mask)

        return blurred_mask, foreground

    def loadcheckpoints(self, model_path):
        """Load weights from a checkpoint holding a ``'model_state_dict'`` entry.

        Args:
            model_path: Path (or file-like object) accepted by ``torch.load``.
        """
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from trusted sources (consider ``weights_only=True`` where available).
        model_dict = torch.load(model_path, map_location="cpu")
        self.load_state_dict(model_dict['model_state_dict'], strict=True)
915
+
916
+
917
+
918
+
919
def rgb_loader_refiner(original_image):
    """Prepare a PIL image for the network.

    The EXIF orientation is applied before anything else so that the reported
    size and the returned reference image describe the same pixels the network
    sees; otherwise the predicted mask would not line up with rotated photos.

    Args:
        original_image: PIL image to prepare.

    Returns:
        A 4-tuple ``(model_input, width, height, upright_image)``:
        ``model_input`` is the upright image in RGB, resized to 1024x1024 with
        LANCZOS resampling; ``width`` and ``height`` are the two entries of the
        upright image's ``size`` in PIL order; ``upright_image`` is the
        EXIF-transposed image at its own resolution and mode.
    """
    # Apply the EXIF orientation first; everything below refers to this image.
    upright_image = ImageOps.exif_transpose(original_image)
    width, height = upright_image.size

    model_input = upright_image
    # Convert to RGB if necessary.
    if model_input.mode != 'RGB':
        model_input = model_input.convert('RGB')

    # Resize to the fixed network input resolution.
    model_input = model_input.resize((1024, 1024), resample=Image.LANCZOS)

    return model_input, width, height, upright_image
931
+
932
# Preprocessing applied to the resized image before it is fed to the network:
# convert to a tensor, cast to float32, then normalise each channel (the
# constants are the widely used ImageNet mean/std values).
img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
938
+
939
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    """Resize a prediction and convert it to a min-max scaled uint8 array.

    Args:
        result: 4-D tensor ``(1, C, H, W)`` as accepted by ``F.interpolate``.
        im_size: Target ``[height, width]`` for bilinear resizing.

    Returns:
        ``np.uint8`` array with values in ``[0, 255]``; singleton axes are
        squeezed away, so a one-channel input yields a 2-D array. A constant
        prediction (no value range to stretch) yields all zeros.
    """
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    value_range = ma - mi
    if value_range > 0:
        result = (result - mi) / value_range
    else:
        # Flat prediction: dividing by zero would fill the array with NaN.
        result = torch.zeros_like(result)
    # (C, H, W) -> (H, W, C), then move to host memory as uint8.
    im_array = (result * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
    im_array = np.squeeze(im_array)
    return im_array
947
+
948
+
949
+
950
+