Feature Extraction
Transformers
Safetensors
vision-encoder-decoder
custom_code
anicolson committed on
Commit 8c83227
1 Parent(s): 7377cb5

Upload modelling_uniformer.py with huggingface_hub
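For context, an upload like this is typically performed with the huggingface_hub client; a minimal sketch (the repo id below is a placeholder, not taken from this page):

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="modelling_uniformer.py",   # local file to push
    path_in_repo="modelling_uniformer.py",      # destination path in the repo
    repo_id="<namespace>/<model-repo>",         # placeholder repo id
    commit_message="Upload modelling_uniformer.py with huggingface_hub",
)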

Files changed (1)
  1. modelling_uniformer.py +412 -0
modelling_uniformer.py ADDED
@@ -0,0 +1,412 @@
+ from collections import OrderedDict
+ from functools import partial
+ from typing import Optional, Tuple, Union
+ from math import isqrt
+
+ import torch
+ import torch.nn as nn
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+ from transformers import ViTConfig
+ from transformers.modeling_outputs import ModelOutput
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ layer_scale = False
+ init_value = 1e-6
+
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class CMlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+         self.act = act_layer()
+         self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class CBlock(nn.Module):
+     def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU):
+         super().__init__()
+         self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+         self.norm1 = nn.BatchNorm2d(dim)
+         self.conv1 = nn.Conv2d(dim, dim, 1)
+         self.conv2 = nn.Conv2d(dim, dim, 1)
+         self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = nn.BatchNorm2d(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward(self, x):
+         x = x + self.pos_embed(x)
+         x = x + self.module_1(x)
+         x = x + self.module_2(x)
+         return x
+
+     def module_1(self, x):
+         x = self.norm1(x.to(dtype=self.norm1.weight.dtype))  # Won't autocast to the dtype of the parameters of nn.BatchNorm2d.
+         x = self.conv1(x)
+         x = self.attn(x)
+         x = self.conv2(x)
+         x = self.drop_path(x)
+         return x
+
+     def module_2(self, x):
+         x = self.norm2(x.to(dtype=self.norm2.weight.dtype))  # Won't autocast to the dtype of the parameters of nn.BatchNorm2d.
+         x = self.mlp(x)
+         x = self.drop_path(x)
+         return x
+
+
+ class SABlock(nn.Module):
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim,
+             num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+             attn_drop=attn_drop, proj_drop=drop)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+         global layer_scale
+         self.ls = layer_scale
+         if self.ls:
+             global init_value
+             print(f"Use layer_scale: {layer_scale}, init_values: {init_value}")
+             self.gamma_1 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
+             self.gamma_2 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
+
+     def forward(self, x):
+         x = x + self.pos_embed(x)
+         B, C, H, W = x.shape  # C is the channel/feature dimension.
+         x = x.flatten(2).transpose(1, 2)
+         if self.ls:
+             x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
+             x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+         else:
+             x = x + self.drop_path(self.attn(self.norm1(x)))
+             x = x + self.drop_path(self.mlp(self.norm2(x)))
+         x = x.transpose(1, 2).reshape(B, C, H, W)
+         return x
+
+
+ class HeadEmbedding(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(HeadEmbedding, self).__init__()
+
+         self.proj = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+             nn.BatchNorm2d(out_channels // 2),
+             nn.GELU(),
+             nn.Conv2d(out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+             nn.BatchNorm2d(out_channels),
+         )
+
+     def forward(self, x):
+         x = self.proj(x)
+         return x
+
+
+ class MiddleEmbedding(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(MiddleEmbedding, self).__init__()
+
+         self.proj = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+             nn.BatchNorm2d(out_channels),
+         )
+
+     def forward(self, x):
+         x = self.proj(x)
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     def __init__(self, image_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         image_size = to_2tuple(image_size)
+         patch_size = to_2tuple(patch_size)
+         num_patches_height = image_size[0] // patch_size[0]
+         num_patches_width = image_size[1] // patch_size[1]
+         num_patches = num_patches_height * num_patches_width
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+         self.norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, x):
+         _, _, H, W = x.shape
+         assert H == self.image_size[0] and W == self.image_size[1], \
+             f"Input image size ({H}*{W}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+         x = self.proj(x)
+         B, _, H, W = x.shape
+         x = x.flatten(2).transpose(1, 2)
+         x = self.norm(x)
+         x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+         return x
+
+
+ class UniFormer(nn.Module):
+     def __init__(self, depth=[3, 4, 8, 3], image_size=224, in_chans=3, num_classes=1000, embed_dim=[64, 128, 320, 512],
+                  head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, patch_size=[4, 2, 2, 2],
+                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., conv_stem=False, layer_norm_eps=1e-6, **kwargs):
+         super().__init__()
+         self.num_classes = num_classes
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+         norm_layer = partial(nn.LayerNorm, eps=layer_norm_eps)
+         if conv_stem:
+             self.patch_embed1 = HeadEmbedding(in_channels=in_chans, out_channels=embed_dim[0])
+             self.patch_embed2 = MiddleEmbedding(in_channels=embed_dim[0], out_channels=embed_dim[1])
+             self.patch_embed3 = MiddleEmbedding(in_channels=embed_dim[1], out_channels=embed_dim[2])
+             self.patch_embed4 = MiddleEmbedding(in_channels=embed_dim[2], out_channels=embed_dim[3])
+         else:
+             self.patch_embed1 = PatchEmbed(
+                 image_size=image_size, patch_size=patch_size[0], in_chans=in_chans, embed_dim=embed_dim[0])
+             self.patch_embed2 = PatchEmbed(
+                 image_size=image_size // patch_size[0], patch_size=patch_size[1], in_chans=embed_dim[0], embed_dim=embed_dim[1])
+             self.patch_embed3 = PatchEmbed(
+                 image_size=image_size // (patch_size[0] * patch_size[1]), patch_size=patch_size[2], in_chans=embed_dim[1], embed_dim=embed_dim[2])
+             self.patch_embed4 = PatchEmbed(
+                 image_size=image_size // (patch_size[0] * patch_size[1] * patch_size[2]), patch_size=patch_size[3], in_chans=embed_dim[2], embed_dim=embed_dim[3])
+
+         self.pos_drop = nn.Dropout(p=drop_rate)
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]  # stochastic depth decay rule
+         num_heads = [dim // head_dim for dim in embed_dim]
+         self.blocks1 = nn.ModuleList([
+             CBlock(dim=embed_dim[0], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i])
+             for i in range(depth[0])])
+         self.blocks2 = nn.ModuleList([
+             CBlock(dim=embed_dim[1], mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i + depth[0]])
+             for i in range(depth[1])])
+         self.blocks3 = nn.ModuleList([
+             SABlock(
+                 dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i + depth[0] + depth[1]], norm_layer=norm_layer)
+             for i in range(depth[2])])
+         self.blocks4 = nn.ModuleList([
+             SABlock(
+                 dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i + depth[0] + depth[1] + depth[2]], norm_layer=norm_layer)
+             for i in range(depth[3])])
+         self.norm = nn.BatchNorm2d(embed_dim[-1])
+
+         # Representation layer:
+         if representation_size:
+             self.num_features = representation_size
+             self.pre_logits = nn.Sequential(OrderedDict([
+                 ('fc', nn.Linear(embed_dim[-1], representation_size)),  # embed_dim is a list; use the final stage width.
+                 ('act', nn.Tanh())
+             ]))
+         else:
+             self.pre_logits = nn.Identity()
+
+     def forward_features(self, x):
+         x = self.patch_embed1(x)
+         x = self.pos_drop(x)
+         for blk in self.blocks1:
+             x = blk(x)
+         x = self.patch_embed2(x)
+         for blk in self.blocks2:
+             x = blk(x)
+         x = self.patch_embed3(x)
+         for blk in self.blocks3:
+             x = blk(x)
+         x = self.patch_embed4(x)
+         for blk in self.blocks4:
+             x = blk(x)
+         x = self.norm(x.to(dtype=self.norm.weight.dtype))  # Won't autocast to the dtype of the parameters of nn.BatchNorm2d.
+         x = self.pre_logits(x)
+         return x
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         return x
+
+
+ class UniFormerPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = ViTConfig
+     base_model_prefix = "vit"
+     main_input_name = "pixel_values"
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+
+ class UniFormerProjectionHead(torch.nn.Module):
+
+     def __init__(self, config) -> None:
+         super().__init__()
+
+         # Layer normalisation before projection:
+         self.layer_norm = torch.nn.LayerNorm(config.embed_dim[-1], eps=config.layer_norm_eps)
+
+         # No bias, as this follows layer normalisation (which has a bias):
+         self.projection = torch.nn.Linear(config.embed_dim[-1], config.projection_size, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.layer_norm(x)
+         x = self.projection(x)
+         return x
+
+
+ class UniFormerModel(UniFormerPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.uniformer = UniFormer(**vars(config))
+
+         # Initialize weights and apply final processing:
+         self.post_init()
+
+     def forward(
+         self,
+         pixel_values: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, ModelOutput]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         last_hidden_state = self.uniformer(pixel_values)
+
+         # Flatten h x w:
+         last_hidden_state = torch.flatten(last_hidden_state, 2)
+
+         # Permute last hidden state to (batch, positions, features):
+         last_hidden_state = torch.permute(last_hidden_state, [0, 2, 1])
+
+         if not return_dict:
+             return last_hidden_state
+
+         return ModelOutput(last_hidden_state=last_hidden_state)
+
+
+ class MultiUniFormerWithProjectionHead(UniFormerPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.uniformer = UniFormer(**vars(config))
+         self.projection_head = UniFormerProjectionHead(config)
+
+         # Initialize weights and apply final processing:
+         self.post_init()
+
+     def forward(
+         self,
+         pixel_values: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, ModelOutput]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Flatten the batch and study_id dimensions:
+         assert len(pixel_values.shape) == 5, 'pixel_values must be B, S, C, H, W, where S is the max number of images for a study in the batch.'
+         last_hidden_state = self.uniformer(pixel_values.view(-1, *pixel_values.shape[2:]))
+         # last_hidden_state = self.uniformer(pixel_values.flatten(start_dim=0, end_dim=1))
+
+         # Flatten h x w:
+         last_hidden_state = torch.flatten(last_hidden_state, 2)
+
+         # Project the features for each spatial position to the decoder's hidden size:
+         projection = self.projection_head(torch.permute(last_hidden_state, [0, 2, 1]))
+
+         # Concatenate the features for each chest X-ray:
+         projection = projection.view(pixel_values.shape[0], -1, projection.shape[-1])
+
+         # Derive the attention mask from the pixel values (an image whose first pixel value is 0.0 is treated as padding):
+         mask = (pixel_values[:, :, 0, 0, 0] != 0.0)[:, :, None]
+         attention_mask = torch.ones(
+             [projection.shape[0], pixel_values.shape[1], projection.shape[1] // pixel_values.shape[1]],
+             dtype=torch.long,
+             device=mask.device,
+         )
+         attention_mask = attention_mask * mask
+         attention_mask = attention_mask.view(attention_mask.shape[0], -1)
+
+         if not return_dict:
+             return projection
+
+         return ModelOutput(last_hidden_state=projection, attention_mask=attention_mask)
+
+
+ if __name__ == '__main__':
+     y = PatchEmbed()
+     y(torch.randn(2, 3, 224, 224))
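As a quick sanity check beyond the PatchEmbed smoke test in the file's __main__ block, the UniFormer backbone can be run directly on a dummy batch; a minimal sketch, assuming the file is importable as modelling_uniformer:

import torch
from modelling_uniformer import UniFormer

model = UniFormer()  # defaults: depth=[3, 4, 8, 3], embed_dim=[64, 128, 320, 512], image_size=224
model.eval()
with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))
print(features.shape)  # the four stages downsample by 4*2*2*2 = 32, so this should be torch.Size([1, 512, 7, 7])

Since the repository is tagged custom_code, downstream use through transformers would load this module from the Hub with trust_remote_code=True.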