ofirab committed
Commit 89a069d (verified)
1 Parent(s): 9cf3625

Upload model

Files changed (4):
  1. config.json +2 -4
  2. model.safetensors +1 -1
  3. modeling_vilmaswin.py +797 -0
  4. modeling_visfocus.py +810 -0
config.json CHANGED
@@ -1,12 +1,10 @@
 {
   "architectures": [
-    "VisFocusModel",
-    "VisFocusForLocalizedMaskedLanguageModeling",
-    "VisFocusForImageTextToText"
+    "VisFocusModelForImageTextToText"
   ],
   "auto_map": {
     "AutoConfig": "configuration_visfocus.VisFocusConfig",
-    "AutoModel": "configuration_visfocus.VisFocusPreTrainedModel",
+    "AutoModel": "modeling_visfocus.VisFocusModelForImageTextToText",
     "AutoModelForConditionalGeneration": "configuration_visfocus.VisFocusForImageTextToText",
     "AutoModelForImageTextToText": "configuration_visfocus.VisFocusForImageTextToText"
   },
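The updated auto_map resolves AutoModel to the VisFocusModelForImageTextToText class implemented in the newly uploaded modeling_visfocus.py, so the checkpoint can be loaded through the Auto classes with remote code enabled. A minimal sketch, using a placeholder repo id (substitute the id of this repository):

from transformers import AutoConfig, AutoModel

repo_id = "ofirab/visfocus-base"  # hypothetical repo id -- replace with this repository's id

# trust_remote_code is required because auto_map points at the custom
# configuration_visfocus / modeling_visfocus modules shipped with the checkpoint.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()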
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f9a3839fb77abc8c559e4ecf3c972f0592b76efc24632687bf949bde4ea5d3e9
+oid sha256:142b3fbf1d72be9681a77e47453f047bdac3f5c9649c354d84bd3621f479427d
 size 1047109288
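Only the LFS object id changed (the payload size is identical), so a downloaded model.safetensors can be checked against the pointer above. A minimal sketch; Git LFS object ids are the SHA-256 of the file contents:

import hashlib

expected = "142b3fbf1d72be9681a77e47453f047bdac3f5c9649c354d84bd3621f479427d"
sha = hashlib.sha256()
with open("model.safetensors", "rb") as f:            # path to the downloaded weights
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        sha.update(chunk)
print(sha.hexdigest() == expected)                     # True if the file matches this commit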
modeling_vilmaswin.py ADDED
@@ -0,0 +1,797 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer V2
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+
7
+ # Modifications Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
8
+ # --------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
15
+ import numpy as np
16
+
17
+
18
+ class Mlp(nn.Module):
19
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
20
+ super().__init__()
21
+ out_features = out_features or in_features
22
+ hidden_features = hidden_features or in_features
23
+ self.fc1 = nn.Linear(in_features, hidden_features)
24
+ self.act = act_layer()
25
+ self.fc2 = nn.Linear(hidden_features, out_features)
26
+ self.drop = nn.Dropout(drop)
27
+
28
+ def forward(self, x):
29
+ x = self.fc1(x)
30
+ x = self.act(x)
31
+ x = self.drop(x)
32
+ x = self.fc2(x)
33
+ x = self.drop(x)
34
+ return x
35
+
36
+ class PositionalEncoding(nn.Module):
37
+
38
+ def __init__(self, d_hid, n_position=200):
39
+ super(PositionalEncoding, self).__init__()
40
+
41
+ # Not a parameter
42
+ self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
43
+
44
+ def _get_sinusoid_encoding_table(self, n_position, d_hid):
45
+ ''' Sinusoid position encoding table '''
46
+
47
+ def get_position_angle_vec(position):
48
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
49
+
50
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
51
+ sinusoid_table[0::2] = np.sin(sinusoid_table[0::2]) # dim 2i
52
+ sinusoid_table[1::2] = np.cos(sinusoid_table[1::2]) # dim 2i+1
53
+
54
+ return torch.FloatTensor(sinusoid_table).unsqueeze(1) # -> [L,B,dim]
55
+
56
+ def forward(self, x):
57
+ return x + self.pos_table[:, :x.size(1)].clone().detach()
58
+
59
+ class CrossAttention(nn.Module):
60
+ """
61
+ borrowed from https://github.com/openai/CLIP/blob/main/clip/model.py (AttentionPool2d)
62
+ """
63
+ def __init__(self,
64
+ dim: int,
65
+ kv_dim: int,
66
+ output_dim: int = None,
67
+ num_heads: int = None,
68
+ context_length: int = None,
69
+ norm_layer=nn.LayerNorm,
70
+ learned_ape=True,
71
+ **kwargs):
72
+ super().__init__()
73
+ embed_dim = dim
74
+ output_dim = output_dim
75
+ self.learned_ape = learned_ape
76
+ if learned_ape:
77
+ self.positional_embedding = nn.Parameter(torch.randn(context_length, embed_dim) / embed_dim ** 0.5)
78
+ else:
79
+ self.positional_embedding = PositionalEncoding(embed_dim, context_length)
80
+ self.context_length = context_length
81
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
82
+ self.k_proj = nn.Linear(kv_dim, embed_dim)
83
+ self.v_proj = nn.Linear(kv_dim, embed_dim)
84
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
85
+ self.num_heads = num_heads
86
+ self.norm = norm_layer(dim)
87
+
88
+ def forward(self, x_q, x_kv, print_maps=False):
89
+ x_q = x_q.permute(1, 0, 2) # NLW -> LNC
90
+ x_kv = x_kv.permute(1, 0, 2) # NCS -> SNC
91
+ # x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
92
+ if self.learned_ape:
93
+ x_q = x_q + self.positional_embedding[:x_q.shape[0], None, :].to(x_q.dtype) # (HW+1)NC
94
+ else:
95
+ x_q = self.positional_embedding(x_q)
96
+ x, _ = F.multi_head_attention_forward(
97
+ query=x_q, key=x_kv, value=x_kv,
98
+ embed_dim_to_check=x_q.shape[-1],
99
+ num_heads=self.num_heads,
100
+ q_proj_weight=self.q_proj.weight,
101
+ k_proj_weight=self.k_proj.weight,
102
+ v_proj_weight=self.v_proj.weight,
103
+ in_proj_weight=None,
104
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
105
+ bias_k=None,
106
+ bias_v=None,
107
+ add_zero_attn=False,
108
+ dropout_p=0,
109
+ out_proj_weight=self.c_proj.weight,
110
+ out_proj_bias=self.c_proj.bias,
111
+ use_separate_proj_weight=True,
112
+ training=self.training,
113
+ need_weights=False,
114
+ # print_maps=print_maps
115
+ )
116
+ if self.norm:
117
+ x = self.norm(x)
118
+ x = x.permute(1, 0, 2) # LNC -> NLW
119
+ return x
120
+
121
+
122
+ def window_partition(x, window_size):
123
+ """
124
+ Args:
125
+ x: (B, H, W, C)
126
+ window_size (int): window size
127
+ Returns:
128
+ windows: (num_windows*B, window_size, window_size, C)
129
+ """
130
+ B, H, W, C = x.shape
131
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
132
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
133
+ return windows
134
+
135
+
136
+ def window_reverse(windows, window_size, H, W):
137
+ """
138
+ Args:
139
+ windows: (num_windows*B, window_size, window_size, C)
140
+ window_size (int): Window size
141
+ H (int): Height of image
142
+ W (int): Width of image
143
+ Returns:
144
+ x: (B, H, W, C)
145
+ """
146
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
147
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
148
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
149
+ return x
150
+
151
+
152
+ class WindowAttention(nn.Module):
153
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
154
+ It supports both of shifted and non-shifted window.
155
+ Args:
156
+ dim (int): Number of input channels.
157
+ window_size (tuple[int]): The height and width of the window.
158
+ num_heads (int): Number of attention heads.
159
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
160
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
161
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
162
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
163
+ """
164
+
165
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
166
+ pretrained_window_size=[0, 0]):
167
+
168
+ super().__init__()
169
+ self.dim = dim
170
+ self.window_size = window_size # Wh, Ww
171
+ self.pretrained_window_size = pretrained_window_size
172
+ self.num_heads = num_heads
173
+
174
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
175
+
176
+ # mlp to generate continuous relative position bias
177
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
178
+ nn.ReLU(inplace=True),
179
+ nn.Linear(512, num_heads, bias=False))
180
+
181
+ # get relative_coords_table
182
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
183
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
184
+ relative_coords_table = torch.stack(
185
+ torch.meshgrid([relative_coords_h,
186
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
187
+ if pretrained_window_size[0] > 0:
188
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
189
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
190
+ else:
191
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
192
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
193
+ relative_coords_table *= 8 # normalize to -8, 8
194
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
195
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
196
+
197
+ self.register_buffer("relative_coords_table", relative_coords_table)
198
+
199
+ # get pair-wise relative position index for each token inside the window
200
+ coords_h = torch.arange(self.window_size[0])
201
+ coords_w = torch.arange(self.window_size[1])
202
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
203
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
204
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
205
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
206
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
207
+ relative_coords[:, :, 1] += self.window_size[1] - 1
208
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
209
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
210
+ self.register_buffer("relative_position_index", relative_position_index)
211
+
212
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
213
+ if qkv_bias:
214
+ self.q_bias = nn.Parameter(torch.zeros(dim))
215
+ self.v_bias = nn.Parameter(torch.zeros(dim))
216
+ else:
217
+ self.q_bias = None
218
+ self.v_bias = None
219
+ self.attn_drop = nn.Dropout(attn_drop)
220
+ self.proj = nn.Linear(dim, dim)
221
+ self.proj_drop = nn.Dropout(proj_drop)
222
+ self.softmax = nn.Softmax(dim=-1)
223
+
224
+ def forward(self, x, mask=None, v_length=None):
225
+ """
226
+ Args:
227
+ x: input features with shape of (num_windows*B, N, C)
228
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
229
+ """
230
+ B_, N, C = x.shape
231
+ qkv_bias = None
232
+ if self.q_bias is not None:
233
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
234
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
235
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
236
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
237
+
238
+ # cosine attention
239
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
240
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
241
+ attn = attn * logit_scale
242
+
243
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
244
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
245
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
246
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
247
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
248
+ attn[..., :v_length, :v_length] = attn[..., :v_length, :v_length] + relative_position_bias.unsqueeze(0)
249
+
250
+ if mask is not None:
251
+ nW = mask.shape[0]
252
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
253
+ attn = attn.view(-1, self.num_heads, N, N)
254
+ attn = self.softmax(attn)
255
+ else:
256
+ attn = self.softmax(attn)
257
+
258
+ attn = self.attn_drop(attn)
259
+
260
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
261
+ x = self.proj(x)
262
+ x = self.proj_drop(x)
263
+ return x
264
+
265
+ def extra_repr(self) -> str:
266
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
267
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
268
+
269
+ def flops(self, N):
270
+ # calculate flops for 1 window with token length of N
271
+ flops = 0
272
+ # qkv = self.qkv(x)
273
+ flops += N * self.dim * 3 * self.dim
274
+ # attn = (q @ k.transpose(-2, -1))
275
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
276
+ # x = (attn @ v)
277
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
278
+ # x = self.proj(x)
279
+ flops += N * self.dim * self.dim
280
+ return flops
281
+
282
+
283
+ class SwinTransformerBlock(nn.Module):
284
+ r""" Swin Transformer Block.
285
+ Args:
286
+ dim (int): Number of input channels.
287
+ input_resolution (tuple[int]): Input resulotion.
288
+ num_heads (int): Number of attention heads.
289
+ window_size (int): Window size.
290
+ shift_size (int): Shift size for SW-MSA.
291
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
292
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
293
+ drop (float, optional): Dropout rate. Default: 0.0
294
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
295
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
296
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
297
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
298
+ pretrained_window_size (int): Window size in pre-training.
299
+ """
300
+
301
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
302
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
303
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0, lm_d_model=None):
304
+ super().__init__()
305
+ self.dim = dim
306
+ self.input_resolution = input_resolution
307
+ self.num_heads = num_heads
308
+ self.window_size = window_size
309
+ self.shift_size = shift_size
310
+ self.mlp_ratio = mlp_ratio
311
+ if min(self.input_resolution) <= self.window_size:
312
+ # if window size is larger than input resolution, we don't partition windows
313
+ self.shift_size = 0
314
+ self.window_size = min(self.input_resolution)
315
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
316
+
317
+ self.norm1 = norm_layer(dim)
318
+ self.attn = WindowAttention(
319
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
320
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
321
+ pretrained_window_size=to_2tuple(pretrained_window_size))
322
+
323
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
324
+ self.norm2 = norm_layer(dim)
325
+ mlp_hidden_dim = int(dim * mlp_ratio)
326
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
327
+
328
+ if self.shift_size > 0:
329
+ # calculate attention mask for SW-MSA
330
+ H, W = self.input_resolution
331
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
332
+ h_slices = (slice(0, -self.window_size),
333
+ slice(-self.window_size, -self.shift_size),
334
+ slice(-self.shift_size, None))
335
+ w_slices = (slice(0, -self.window_size),
336
+ slice(-self.window_size, -self.shift_size),
337
+ slice(-self.shift_size, None))
338
+ cnt = 0
339
+ for h in h_slices:
340
+ for w in w_slices:
341
+ img_mask[:, h, w, :] = cnt
342
+ cnt += 1
343
+
344
+ # mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
345
+ # mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
346
+ # attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
347
+ # attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
348
+ # else:
349
+ # attn_mask = None
350
+
351
+ # self.register_buffer("attn_mask", attn_mask)
352
+
353
+ def forward(self, x, context_prompts=None):
354
+ # H, W = self.input_resolution
355
+ # B, L, C = x.shape
356
+ # assert L == H * W, "input feature has wrong size"
357
+
358
+ # shortcut = x
359
+ # x = x.view(B, H, W, C)
360
+
361
+ # # cyclic shift
362
+ # if self.shift_size > 0:
363
+ # shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
364
+ # else:
365
+ # shifted_x = x
366
+
367
+ B, L, C = x.shape
368
+ H, W = self.input_resolution
369
+ assert L == H * W, "input feature has wrong size"
370
+
371
+ shortcut = x
372
+ # x = self.norm1(x)
373
+ x = x.view(B, H, W, C)
374
+
375
+ # pad feature maps to multiples of window size
376
+ pad_l = pad_t = 0
377
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
378
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
379
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
380
+ _, Hp, Wp, _ = x.shape
381
+
382
+ # cyclic shift
383
+ if self.shift_size > 0:
384
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
385
+ # attn_mask = mask_matrix
386
+ else:
387
+ shifted_x = x
388
+ # attn_mask = None
389
+
390
+
391
+ # partition windows
392
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
393
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
394
+
395
+ # W-MSA/SW-MSA
396
+ attn_windows = self.attn(x_windows, v_length=self.window_size * self.window_size) # , mask=self.attn_mask) # nW*B, window_size*window_size, C
397
+
398
+ # merge windows
399
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
400
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
401
+
402
+ # reverse cyclic shift
403
+ if self.shift_size > 0:
404
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
405
+ else:
406
+ x = shifted_x
407
+
408
+ if pad_r > 0 or pad_b > 0:
409
+ x = x[:, :H, :W, :].contiguous()
410
+
411
+ x = x.view(B, H * W, C)
412
+
413
+ x = shortcut + self.drop_path(self.norm1(x))
414
+
415
+ # FFN
416
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
417
+
418
+ return x
419
+
420
+ def extra_repr(self) -> str:
421
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
422
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
423
+
424
+ def flops(self):
425
+ flops = 0
426
+ H, W = self.input_resolution
427
+ # norm1
428
+ flops += self.dim * H * W
429
+ # W-MSA/SW-MSA
430
+ nW = H * W / self.window_size / self.window_size
431
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
432
+ # mlp
433
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
434
+ # norm2
435
+ flops += self.dim * H * W
436
+ return flops
437
+
438
+
439
+ class Vilma(nn.Module):
440
+ r""" Vision-Language Marge Attention layer.
441
+ """
442
+
443
+ def __init__(self,
444
+ input_resolution,
445
+ dim,
446
+ num_heads,
447
+ lm_d_model,
448
+ vl_learned_ape=True,
449
+ norm_layer=nn.LayerNorm,
450
+ reduce=True,
451
+ **kwargs):
452
+ super().__init__()
453
+ self.input_resolution = input_resolution
454
+ self.dim = dim
455
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) if reduce else nn.Linear(4 * dim, 4 * dim, bias=False)
456
+ self.norm = norm_layer(2 * dim) if reduce else norm_layer(4 * dim)
457
+ self.cross_attn = CrossAttention(dim=dim * 4,
458
+ kv_dim=lm_d_model,
459
+ context_length=self.input_resolution[0] // 2 * self.input_resolution[1] // 2,
460
+ output_dim=dim * 4,
461
+ num_heads=num_heads,
462
+ learned_ape=vl_learned_ape
463
+ )
464
+ nn.init.eye_(self.cross_attn.q_proj.weight)
465
+ nn.init.constant_(self.cross_attn.q_proj.bias, 0)
466
+ self.cross_attn.q_proj.requires_grad_(False)
467
+ self.vl_alpha = 0.5
468
+
469
+ def forward(self, x, context_prompts, **kwargs):
470
+ """
471
+ x: B, H*W, C
472
+ """
473
+ H, W = self.input_resolution
474
+ B, L, C = x.shape
475
+ assert L == H * W, "input feature has wrong size"
476
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
477
+
478
+ x = x.view(B, H, W, C)
479
+
480
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
481
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
482
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
483
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
484
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
485
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
486
+
487
+ x_vl = self.cross_attn(x, context_prompts)
488
+ x = self.vl_alpha * x_vl + (1 - self.vl_alpha) * x
489
+
490
+ x = self.reduction(x)
491
+ x = self.norm(x)
492
+
493
+ return x
494
+
495
+ def extra_repr(self) -> str:
496
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
497
+
498
+ def flops(self):
499
+ H, W = self.input_resolution
500
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
501
+ flops += H * W * self.dim // 2
502
+ return flops
503
+
504
+
505
+ class BasicLayer(nn.Module):
506
+ """ A basic Swin Transformer layer for one stage.
507
+ Args:
508
+ dim (int): Number of input channels.
509
+ input_resolution (tuple[int]): Input resolution.
510
+ depth (int): Number of blocks.
511
+ num_heads (int): Number of attention heads.
512
+ window_size (int): Local window size.
513
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
514
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
515
+ drop (float, optional): Dropout rate. Default: 0.0
516
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
517
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
518
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
519
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
520
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
521
+ pretrained_window_size (int): Local window size in pre-training.
522
+ """
523
+
524
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
525
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
526
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
527
+ pretrained_window_size=0, do_shift=True, lm_d_model=None):
528
+
529
+ super().__init__()
530
+ self.dim = dim
531
+ self.input_resolution = input_resolution
532
+ self.depth = depth if do_shift else 1 # do not add SWA layers
533
+ self.use_checkpoint = use_checkpoint
534
+ # build blocks
535
+ self.blocks = nn.ModuleList([
536
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
537
+ num_heads=num_heads, window_size=window_size,
538
+ shift_size=0 if ((i % 2 == 0) or (not do_shift)) else window_size // 2,
539
+ mlp_ratio=mlp_ratio,
540
+ qkv_bias=qkv_bias,
541
+ drop=drop, attn_drop=attn_drop,
542
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
543
+ norm_layer=norm_layer,
544
+ pretrained_window_size=pretrained_window_size,
545
+ lm_d_model=lm_d_model)
546
+ for i in range(self.depth)])
547
+
548
+ # patch merging layer
549
+ if downsample is not None:
550
+ self.downsample = downsample(input_resolution=input_resolution,
551
+ dim=dim,
552
+ norm_layer=norm_layer,
553
+ num_heads=num_heads,
554
+ lm_d_model=lm_d_model
555
+ )
556
+ else:
557
+ self.downsample = None
558
+
559
+ def forward(self, x, context_prompts=None):
560
+ for blk in self.blocks:
561
+ if self.use_checkpoint:
562
+ x = checkpoint.checkpoint(blk, x)
563
+ else:
564
+ x = blk(x, context_prompts=context_prompts)
565
+ if self.downsample is not None:
566
+ x = self.downsample(x, context_prompts=context_prompts)
567
+ return x
568
+
569
+ def extra_repr(self) -> str:
570
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
571
+
572
+ def flops(self):
573
+ flops = 0
574
+ for blk in self.blocks:
575
+ flops += blk.flops()
576
+ if self.downsample is not None:
577
+ flops += self.downsample.flops()
578
+ return flops
579
+
580
+ def _init_respostnorm(self):
581
+ for blk in self.blocks:
582
+ nn.init.constant_(blk.norm1.bias, 0)
583
+ nn.init.constant_(blk.norm1.weight, 0)
584
+ nn.init.constant_(blk.norm2.bias, 0)
585
+ nn.init.constant_(blk.norm2.weight, 0)
586
+
587
+
588
+ class PatchEmbed(nn.Module):
589
+ r""" Image to Patch Embedding
590
+ Args:
591
+ img_size (int or tuple): Image size. Default: 224.
592
+ patch_size (int): Patch token size. Default: 4.
593
+ in_chans (int): Number of input image channels. Default: 3.
594
+ embed_dim (int): Number of linear projection output channels. Default: 96.
595
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
596
+ """
597
+
598
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
599
+ super().__init__()
600
+ img_size = to_2tuple(img_size)
601
+ patch_size = to_2tuple(patch_size)
602
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
603
+ self.img_size = img_size
604
+ self.patch_size = patch_size
605
+ self.patches_resolution = patches_resolution
606
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
607
+
608
+ self.in_chans = in_chans
609
+ self.embed_dim = embed_dim
610
+
611
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
612
+ if norm_layer is not None:
613
+ self.norm = norm_layer(embed_dim)
614
+ else:
615
+ self.norm = None
616
+
617
+ def forward(self, x):
618
+ B, C, H, W = x.shape
619
+ # FIXME look at relaxing size constraints
620
+ assert H == self.img_size[0] and W == self.img_size[1], \
621
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
622
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
623
+ if self.norm is not None:
624
+ x = self.norm(x)
625
+ return x
626
+
627
+ def flops(self):
628
+ Ho, Wo = self.patches_resolution
629
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
630
+ if self.norm is not None:
631
+ flops += Ho * Wo * self.embed_dim
632
+ return flops
633
+
634
+
635
+ class PatchEmbed1D(nn.Module):
636
+ r""" 1D Image to Patch Embedding (if for example patches are prextracted)
637
+ Args:
638
+ img_size (int or tuple): Image size. Default: 224.
639
+ patch_size (int): Patch token size. Default: 4.
640
+ in_chans (int): Number of input image channels. Default: 3.
641
+ embed_dim (int): Number of linear projection output channels. Default: 96.
642
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
643
+ """
644
+
645
+ def __init__(self, in_chans=3, embed_dim=96, norm_layer=None, img_size=-1, patch_size=-1, **kwargs):
646
+ super().__init__()
647
+ patch_size = to_2tuple(patch_size)
648
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
649
+ self.img_size = img_size
650
+ self.patch_size = patch_size
651
+ self.patches_resolution = patches_resolution
652
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
653
+
654
+ self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=1, stride=1)
655
+ if norm_layer is not None:
656
+ self.norm = norm_layer(embed_dim)
657
+ else:
658
+ self.norm = None
659
+
660
+ def forward(self, x):
661
+ B, L, C = x.shape # [batch, num_patches, numof_patch_pixels]
662
+ x = x.permute(0, 2, 1)
663
+ x = self.proj(x).flatten(2).permute(0, 2, 1) # B Ph*Pw C
664
+ if self.norm is not None:
665
+ x = self.norm(x)
666
+ return x
667
+
668
+
669
+ class VilmaSwinTransformerV2(nn.Module):
670
+ r""" Swin Transformer with Vilma downsampling and cross attention layers
671
+ borrowed from https://github.com/microsoft/Swin-Transformer-V2/blob/main/models/swin_transformer_v2.py
672
+ """
673
+
674
+ def __init__(self, img_size=224, patch_size=4, in_chans=3,
675
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
676
+ window_size=7, mlp_ratio=4., qkv_bias=True,
677
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
678
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
679
+ use_checkpoint=False, pretrained_window_sizes=[0, 0, 0, 0],
680
+ embedd_matcher_dim=512, do_shift=True,
681
+ vl_cross_attn_layers=[], vl_alpha=0.5, lm_d_model=512,
682
+ input_type='rgb', vl_learned_ape=True):
683
+ super().__init__()
684
+ self.model_name = 'swin_v2'
685
+
686
+ self.num_layers = len(depths)
687
+ self.embed_dim = embed_dim
688
+ self.ape = ape
689
+ self.patch_norm = patch_norm
690
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
691
+ self.mlp_ratio = mlp_ratio
692
+ self.input_type = input_type
693
+
694
+ # split image into non-overlapping patches
695
+ self.patch_embed = PatchEmbed(
696
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
697
+ norm_layer=norm_layer if self.patch_norm else None)
698
+
699
+ num_patches = self.patch_embed.num_patches
700
+ patches_resolution = self.patch_embed.patches_resolution
701
+ self.patches_resolution = patches_resolution
702
+
703
+ # absolute position embedding
704
+ if self.ape:
705
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
706
+ trunc_normal_(self.absolute_pos_embed, std=.02)
707
+
708
+ self.pos_drop = nn.Dropout(p=drop_rate)
709
+
710
+ # stochastic depth
711
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
712
+
713
+ self.vl_cross_attn_layers = nn.ModuleDict({str(i): None for i in vl_cross_attn_layers})
714
+ self.vl_alpha = vl_alpha
715
+
716
+ # build layers
717
+ self.layers = nn.ModuleList()
718
+ for i_layer in range(self.num_layers):
719
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
720
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
721
+ patches_resolution[1] // (2 ** i_layer)),
722
+ depth=depths[i_layer],
723
+ num_heads=num_heads[i_layer],
724
+ window_size=window_size,
725
+ mlp_ratio=self.mlp_ratio,
726
+ qkv_bias=qkv_bias,
727
+ drop=drop_rate, attn_drop=attn_drop_rate,
728
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
729
+ norm_layer=norm_layer,
730
+ downsample=Vilma if (i_layer < self.num_layers - 1) else None,
731
+ use_checkpoint=use_checkpoint,
732
+ pretrained_window_size=pretrained_window_sizes[i_layer],
733
+ do_shift=do_shift,
734
+ lm_d_model=lm_d_model)
735
+ self.layers.append(layer)
736
+ if str(i_layer) in self.vl_cross_attn_layers:
737
+ layer_factor = i_layer + int(i_layer < self.num_layers - 1)
738
+ self.vl_cross_attn_layers.update({
739
+ str(i_layer): CrossAttention(
740
+ dim=int(embed_dim * 2 ** layer_factor),
741
+ kv_dim=lm_d_model,
742
+ context_length=patches_resolution[0] // (2 ** layer_factor) * patches_resolution[1] // (2 ** layer_factor),
743
+ num_heads=num_heads[i_layer],
744
+ vl_learned_ape=vl_learned_ape)
745
+ })
746
+
747
+ self.norm = norm_layer(self.num_features)
748
+
749
+ self.embedd_matcher_dim = embedd_matcher_dim
750
+
751
+ self.apply(self._init_weights)
752
+ for bly in self.layers:
753
+ bly._init_respostnorm()
754
+
755
+
756
+ def _init_weights(self, m):
757
+ if isinstance(m, nn.Linear):
758
+ trunc_normal_(m.weight, std=.02)
759
+ if isinstance(m, nn.Linear) and m.bias is not None:
760
+ nn.init.constant_(m.bias, 0)
761
+ elif isinstance(m, nn.LayerNorm):
762
+ nn.init.constant_(m.bias, 0)
763
+ nn.init.constant_(m.weight, 1.0)
764
+
765
+ @torch.jit.ignore
766
+ def no_weight_decay(self):
767
+ return {'absolute_pos_embed'}
768
+
769
+ @torch.jit.ignore
770
+ def no_weight_decay_keywords(self):
771
+ return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'}
772
+
773
+ def forward_features(self, x, context_prompts=None):
774
+ x = self.patch_embed(x)
775
+ if self.ape:
776
+ x = x + self.absolute_pos_embed
777
+ x = self.pos_drop(x)
778
+
779
+ for i, layer in enumerate(self.layers):
780
+ assert context_prompts is not None, 'Context prompt is None'
781
+ x = layer(x, context_prompts)
782
+ x_vl = self.vl_cross_attn_layers[str(i)](x, context_prompts)
783
+ x = self.vl_alpha * x_vl + (1 - self.vl_alpha) * x
784
+ x = self.norm(x) # B L C
785
+ return x
786
+
787
+ def forward(self, x, **kwargs):
788
+ x = self.forward_features(x, **kwargs)
789
+ return x
790
+
791
+ def flops(self):
792
+ flops = 0
793
+ flops += self.patch_embed.flops()
794
+ for i, layer in enumerate(self.layers):
795
+ flops += layer.flops()
796
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
797
+ return flops
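For reference, the backbone added above can be exercised on its own. A minimal sketch, assuming modeling_vilmaswin.py has been downloaded locally, that a ViLMA cross-attention block is configured for every stage (forward_features indexes vl_cross_attn_layers at each stage), and using random tensors in place of real pixel values and T5 prompt embeddings:

import torch
from modeling_vilmaswin import VilmaSwinTransformerV2

# Illustrative configuration only; the released checkpoint uses the values in config.json.
backbone = VilmaSwinTransformerV2(
    img_size=224, patch_size=4, in_chans=3, embed_dim=96,
    depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7,
    vl_cross_attn_layers=[0, 1, 2, 3],   # one cross-attention block per stage
    lm_d_model=512)

image = torch.randn(1, 3, 224, 224)      # dummy document image
prompts = torch.randn(1, 16, 512)        # stand-in for T5 encoder outputs (lm_d_model = 512)
with torch.no_grad():
    feats = backbone(image, context_prompts=prompts)
print(feats.shape)                       # torch.Size([1, 49, 768]): 7x7 tokens, 96 * 2**3 channels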
modeling_visfocus.py ADDED
@@ -0,0 +1,810 @@
1
+
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import LayerNorm, CrossEntropyLoss, L1Loss
5
+ from torch.nn import functional as F
6
+
7
+ from transformers import PreTrainedModel, T5Tokenizer, T5Model, logging
8
+ from transformers.models.t5.modeling_t5 import T5Stack
9
+ from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
10
+ from transformers.file_utils import ModelOutput
11
+
12
+ from timm.models.layers import trunc_normal_
13
+ from typing import Any, Dict, Optional, Tuple
14
+ import warnings
15
+ import random
16
+ import yaml
17
+ import copy
18
+ from easydict import EasyDict
19
+
20
+ from .configuration_visfocus import VisFocusConfig
21
+ from .modeling_vilmaswin import VilmaSwinTransformerV2
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ def get_vision_model(config):
27
+ vision_model = VilmaSwinTransformerV2(
28
+ img_size=config.image_size,
29
+ patch_size=config.patch_size,
30
+ in_chans=config.in_chans,
31
+ embed_dim=config.embed_dim,
32
+ depths=config.depths,
33
+ num_heads=config.num_heads,
34
+ window_size=config.window_size,
35
+ mlp_ratio=config.mlp_ratio,
36
+ qkv_bias=config.qkv_bias,
37
+ drop_rate=config.drop_rate,
38
+ drop_path_rate=config.drop_path_rate,
39
+ ape=config.ape,
40
+ patch_norm=config.patch_norm,
41
+ use_checkpoint=config.use_checkpoint,
42
+ pretrained_window_sizes=config.pretrained_window_sizes,
43
+ do_shift=config.do_shift,
44
+ vl_cross_attn_layers=config.vl_cross_attn_layers,
45
+ vl_alpha=config.vl_alpha,
46
+ lm_d_model=config.lm_d_model,
47
+ input_type=config.input_type,
48
+ vl_learned_ape=config.vl_learned_ape)
49
+ return vision_model
50
+
51
+
52
+ def load_vision_pretrained(configs, model):
53
+ logger.info("Loading vision model from %s", configs.model.vision_resume_from)
54
+ if configs.model.vision_resume_from.startswith("https"):
55
+ checkpoint = torch.hub.load_state_dict_from_url(
56
+ configs.model.vision_resume_from, map_location="cpu", check_hash=True
57
+ )
58
+ else:
59
+ checkpoint = torch.load(configs.model.vision_resume_from, map_location="cpu")
60
+
61
+ state_dict = checkpoint["model"]
62
+
63
+ if "swin" in configs.model.type:
64
+ # delete relative_position_index since we always re-init it
65
+ relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
66
+ for k in relative_position_index_keys:
67
+ del state_dict[k]
68
+
69
+ # delete relative_coords_table since we always re-init it
70
+ relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
71
+ for k in relative_position_index_keys:
72
+ del state_dict[k]
73
+
74
+ # delete attn_mask since we always re-init it
75
+ attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
76
+ for k in attn_mask_keys:
77
+ del state_dict[k]
78
+
79
+ # bicubic interpolate relative_position_bias_table if not match
80
+ relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
81
+ for k in relative_position_bias_table_keys:
82
+ relative_position_bias_table_pretrained = state_dict[k]
83
+ relative_position_bias_table_current = model.vision_model.state_dict()[k]
84
+ L1, nH1 = relative_position_bias_table_pretrained.size()
85
+ L2, nH2 = relative_position_bias_table_current.size()
86
+ if nH1 != nH2:
87
+ logger.warning(f"Error in loading {k}, passing......")
88
+ else:
89
+ if L1 != L2:
90
+ # bicubic interpolate relative_position_bias_table if not match
91
+ S1 = int(L1 ** 0.5)
92
+ S2 = int(L2 ** 0.5)
93
+ relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
94
+ relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
95
+ mode='bicubic')
96
+ state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
97
+
98
+ # bicubic interpolate absolute_pos_embed if not match
99
+ absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
100
+ for k in absolute_pos_embed_keys:
101
+ # dpe
102
+ absolute_pos_embed_pretrained = state_dict[k]
103
+ absolute_pos_embed_current = model.vision_model.state_dict()[k]
104
+ _, L1, C1 = absolute_pos_embed_pretrained.size()
105
+ _, L2, C2 = absolute_pos_embed_current.size()
106
+ if C1 != C2:
107
+ logger.warning(f"Error in loading {k}, passing......")
108
+ else:
109
+ if L1 != L2:
110
+ S1 = int(L1 ** 0.5)
111
+ S2 = int(L2 ** 0.5)
112
+ absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
113
+ absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
114
+ absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
115
+ absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
116
+ absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
117
+ absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
118
+ state_dict[k] = absolute_pos_embed_pretrained_resized
119
+
120
+ if model.vision_model.patch_embed.proj.weight.shape != state_dict['patch_embed.proj.weight'].shape:
121
+ model.vision_model.input_type = 'flattened_patches'
122
+ logger.warning(f"PatchEmbed (patch_embed) was not loaded, because input_type is falttened_patches.")
123
+ del state_dict['patch_embed.proj.weight']
124
+
125
+
126
+ # import pdb;pdb.set_trace()
127
+ msg = model.vision_model.load_state_dict(state_dict, strict=False)
128
+
129
+ # do not print unnecessary (vl attn is not loaded now)
130
+ filtered_missing_keys = {k for k in msg.missing_keys
131
+ if 'vl_cross_attn_layers' not in k
132
+ or 'relative_position' not in k}
133
+ filtered_missing_keys.union({'relative_position' for k in msg.missing_keys
134
+ if 'relative_position' not in k})
135
+ # if len({k for k in msg.missing_keys if 'relative_' in k}) > 0:
136
+ # logger.warning(f'Relative position were not loaded')
137
+ # filtered_missing_keys.union()
138
+ logger.warning(f'Missing keys: {set(msg.missing_keys) - filtered_missing_keys}')
139
+ logger.warning(f'Unexpected keys: {msg.unexpected_keys}')
140
+
141
+ # logger.warning(msg)
142
+
143
+ logger.info("Loaded model successfully from %s", configs.model.vision_resume_from)
144
+
145
+ del checkpoint
146
+ torch.cuda.empty_cache()
147
+
148
+
149
+ class T5_Encoder(nn.Module):
150
+ def __init__(self, t5_variant='base', freeze=True):
151
+ super().__init__()
152
+ self.tokenizer = T5Tokenizer.from_pretrained(f'{t5_variant}')
153
+ model = T5Model.from_pretrained(f'{t5_variant}')
154
+ del model.decoder
155
+ self.encoder = model.encoder
156
+ if freeze:
157
+ for p in self.encoder.parameters():
158
+ p.requires_grad = False
159
+
160
+ def forward(self, input_ids):
161
+ encoder_outputs = self.encoder(
162
+ input_ids=input_ids,
163
+ return_dict=True,
164
+ )
165
+ return encoder_outputs[0]
166
+
167
+
168
+ class SpatialEmbeddings(nn.Module):
169
+ def __init__(self, config):
170
+ super().__init__()
171
+
172
+ self.x_position_embeddings = nn.Embedding(
173
+ config.max_2d_position_embeddings, config.hidden_size
174
+ )
175
+ self.y_position_embeddings = nn.Embedding(
176
+ config.max_2d_position_embeddings, config.hidden_size
177
+ )
178
+ self.h_position_embeddings = nn.Embedding(
179
+ config.max_2d_position_embeddings, config.hidden_size
180
+ )
181
+ self.w_position_embeddings = nn.Embedding(
182
+ config.max_2d_position_embeddings, config.hidden_size
183
+ )
184
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
185
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
186
+
187
+ self.config = config
188
+
189
+ def forward(
190
+ self,
191
+ bbox,
192
+ ):
193
+ seq_length = bbox.size(1)
194
+
195
+ left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
196
+ upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
197
+ right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
198
+ lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
199
+ h_position_embeddings = self.h_position_embeddings(
200
+ bbox[:, :, 3] - bbox[:, :, 1]
201
+ )
202
+ w_position_embeddings = self.w_position_embeddings(
203
+ bbox[:, :, 2] - bbox[:, :, 0]
204
+ )
205
+ embeddings = (
206
+ left_position_embeddings
207
+ + upper_position_embeddings
208
+ + right_position_embeddings
209
+ + lower_position_embeddings
210
+ + h_position_embeddings
211
+ + w_position_embeddings
212
+ )
213
+
214
+ embeddings = self.LayerNorm(embeddings)
215
+ embeddings = self.dropout(embeddings)
216
+ return embeddings
217
+
218
+
219
+ class EmbedMatcher(nn.Module):
220
+ def __init__(self, input_dim, inner_dim, output_dim, dropout_rate=0.1):
221
+ super().__init__()
222
+ self.embedd_matcher = nn.Sequential(
223
+ nn.Linear(input_dim, inner_dim, bias=True),
224
+ nn.ReLU(inplace=True),
225
+ nn.Dropout(dropout_rate),
226
+ nn.Linear(inner_dim, output_dim, bias=False),
227
+ nn.Dropout(dropout_rate)
228
+ )
229
+
230
+ self.apply(self._init_weights)
231
+
232
+ def _init_weights(self, m):
233
+ if isinstance(m, nn.Linear):
234
+ trunc_normal_(m.weight, std=.02)
235
+ if isinstance(m, nn.Linear) and m.bias is not None:
236
+ nn.init.constant_(m.bias, 0)
237
+
238
+ def forward(self, x):
239
+ x = self.embedd_matcher(x)
240
+ return x
241
+
242
+
243
+ class MLP(nn.Module):
244
+ """ Very simple multi-layer perceptron (also called FFN)"""
245
+
246
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
247
+ super().__init__()
248
+ self.num_layers = num_layers
249
+ h = [hidden_dim] * (num_layers - 1)
250
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
251
+
252
+ def forward(self, x):
253
+ for i, layer in enumerate(self.layers):
254
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
255
+ return x
256
+
257
+
258
+ class VisFocusModel(PreTrainedModel):
259
+ config_class = VisFocusConfig
260
+
261
+ def __init__(self, config):
262
+ super().__init__(config.lm_config)
263
+ self.set_task_name('ocr')
264
+ self.model_arch = 'visfocus'
265
+ self.config = config
266
+ self.lm_config = config.lm_config
267
+ self.vision_config = config.vision_config
268
+
269
+ self.vision_model = get_vision_model(self.vision_config)
270
+
271
+ input_dim = self.vision_model.num_features
272
+ matcher = MATCHER_MAP[self.config.matcher_type]
273
+
274
+ # load T5 encoder and decoder
275
+ encoder_config = copy.deepcopy(self.lm_config)
276
+ encoder_config.is_decoder = False
277
+ encoder_config.use_cache = False
278
+ encoder_config.is_encoder_decoder = False
279
+ self.encoder = T5Stack(encoder_config)
280
+
281
+ decoder_config = copy.deepcopy(self.lm_config)
282
+ decoder_config.is_decoder = True
283
+ decoder_config.is_encoder_decoder = False
284
+ decoder_config.num_layers = self.lm_config.num_decoder_layers
285
+ self.decoder = T5Stack(decoder_config)
286
+ self.lm_head = nn.Linear(self.lm_config.d_model, self.lm_config.vocab_size, bias=False)
287
+
288
+ if hasattr(self.vision_model, 'last_ds'):
289
+ input_dim = self.vision_model.last_ds.norm.normalized_shape[0]
290
+
291
+ self.vision_embed_matcher = matcher(
292
+ input_dim,
293
+ config.lm_config.hidden_size,
294
+ config.lm_config.hidden_size,
295
+ config.hidden_dropout_prob
296
+ )
297
+
298
+ # losses
299
+ self.loss_fct = CrossEntropyLoss(ignore_index=-100)
300
+
301
+ self.init_weights()
302
+
303
+ if self.config.lora is not None:
304
+ self.apply_lora()
305
+
306
+ if self.config.vl_l1_loss:
307
+ self.vl_l1_loss_fct = L1Loss()
308
+
309
+ def encoder_decoder_forward(
310
+ self,
311
+ input_ids=None,
312
+ attention_mask=None,
313
+ decoder_input_ids=None,
314
+ decoder_attention_mask=None,
315
+ head_mask=None,
316
+ decoder_head_mask=None,
317
+ encoder_outputs=None,
318
+ past_key_values=None,
319
+ inputs_embeds=None,
320
+ decoder_inputs_embeds=None,
321
+ labels=None,
322
+ use_cache=None,
323
+ output_attentions=None,
324
+ output_hidden_states=None,
325
+ return_dict=None,
326
+ **kwargs,
327
+ ):
328
+ r"""
329
+ https://huggingface.co/transformers/v4.5.1/_modules/transformers/modeling_t5.html#T5ForConditionalGeneration.forward
330
+ or https://huggingface.co/transformers/_modules/transformers/modeling_t5.html#T5ForConditionalGeneration.forward
331
+ """
332
+
333
+ if "lm_labels" in kwargs:
334
+ warnings.warn(
335
+ "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
336
+ FutureWarning,
337
+ )
338
+ labels = kwargs.pop("lm_labels")
339
+ if "decoder_past_key_value_states" in kwargs:
340
+ warnings.warn(
341
+ "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
342
+ FutureWarning,
343
+ )
344
+ past_key_values = kwargs.pop("decoder_past_key_value_states")
345
+ if "decoder_past_key_values" in kwargs:
346
+ warnings.warn(
347
+ "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
348
+ FutureWarning,
349
+ )
350
+ past_key_values = kwargs.pop("decoder_past_key_values")
351
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
352
+
353
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
354
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
355
+
356
+ # Encode if needed (training, first prediction pass)
357
+ if encoder_outputs is None:
358
+ # Convert encoder inputs in embeddings if needed
359
+ encoder_outputs = self.encoder(
360
+ input_ids=input_ids,
361
+ attention_mask=attention_mask,
362
+ inputs_embeds=inputs_embeds,
363
+ head_mask=head_mask,
364
+ output_attentions=output_attentions,
365
+ output_hidden_states=output_hidden_states,
366
+ return_dict=return_dict,
367
+ )
368
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
369
+ encoder_outputs = BaseModelOutput(
370
+ last_hidden_state=encoder_outputs[0],
371
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
372
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
373
+ )
374
+
375
+ hidden_states = encoder_outputs[0]
376
+
377
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
378
+ # get decoder inputs from shifting lm labels to the right
379
+ decoder_input_ids = self._shift_right(labels)
380
+
381
+ # If decoding with past key value states, only the last tokens
382
+ # should be given as an input
383
+ if past_key_values is not None:
384
+ assert labels is None, "Decoder should not use cached key value states when training."
385
+ if decoder_input_ids is not None:
386
+ decoder_input_ids = decoder_input_ids[:, -1:]
387
+ if decoder_inputs_embeds is not None:
388
+ decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
389
+
390
+ # Decode
391
+ decoder_outputs = self.decoder(
392
+ input_ids=decoder_input_ids,
393
+ attention_mask=decoder_attention_mask,
394
+ inputs_embeds=decoder_inputs_embeds,
395
+ past_key_values=past_key_values,
396
+ encoder_hidden_states=hidden_states,
397
+ encoder_attention_mask=attention_mask,
398
+ head_mask=head_mask,
399
+ use_cache=use_cache,
400
+ output_attentions=output_attentions,
401
+ output_hidden_states=output_hidden_states,
402
+ return_dict=return_dict,
403
+ )
404
+ sequence_output = decoder_outputs[0]
405
+ # Rescale output before projecting on vocab
406
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
407
+ sequence_output = sequence_output * (self.model_dim ** -0.5)
408
+ lm_logits = self.lm_head(sequence_output)
409
+
410
+ loss = None
411
+ if labels is not None:
412
+ loss = self.loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
413
+
414
+ if self.config.vl_l1_loss:
415
+ labels_ = labels.clone()
416
+ labels_[labels_ == -100] = self.input_tokenizer.pad_token_id # -> replace the ignore_index with the pad_token id to calculate the text target for the vl loss
417
+ with torch.no_grad():
418
+ target = self.encoder(input_ids=labels_).last_hidden_state
419
+ if target.shape[1] != hidden_states.shape[1]:
420
+ v_encoder_intrp = F.interpolate(hidden_states.permute(0,2,1), size=target.shape[1], mode='linear').permute(0,2,1)
421
+ vl_loss = (50 * self.vl_l1_loss_fct(v_encoder_intrp, target))
422
+ loss += vl_loss
423
+
424
+ if not return_dict:
425
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
426
+ if loss is not None:
427
+ output = ((loss,) + output)
428
+
429
+ return output
430
+
431
+ seq2seq_output = Seq2SeqLMOutput(
432
+ loss=loss,
433
+ logits=lm_logits,
434
+ past_key_values=decoder_outputs.past_key_values,
435
+ decoder_hidden_states=decoder_outputs.hidden_states,
436
+ decoder_attentions=decoder_outputs.attentions,
437
+ cross_attentions=decoder_outputs.cross_attentions,
438
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
439
+ encoder_hidden_states=encoder_outputs.hidden_states,
440
+ encoder_attentions=encoder_outputs.attentions,
441
+ )
442
+
443
+ return seq2seq_output
444
+
445
+ def forward(self,
446
+ input_ids=None,
447
+ bbox=None,
448
+ image=None,
449
+ attention_mask=None,
450
+ head_mask=None,
451
+ inputs_embeds=None,
452
+ encoder_hidden_states=None,
453
+ encoder_attention_mask=None,
454
+ labels=None,
455
+ **kwargs):
456
+ # see https://huggingface.co/transformers/v2.10.0/_modules/transformers/modeling_t5.html#T5Model.forward
457
+
458
+ if not kwargs.get('encoder_outputs'):
459
+ _, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=None, image=image)
460
+ else:
461
+ # for generation mode
462
+ assert kwargs.get('decoder_input_ids') is not None
463
+ _ = vision_embeds = attention_mask = None
464
+
465
+ return self.encoder_decoder_forward(input_ids=None,
466
+ attention_mask=attention_mask,
467
+ encoder_outputs=kwargs.get('encoder_outputs'),
468
+ decoder_input_ids=kwargs.get('decoder_input_ids'),
469
+ decoder_attention_mask=None,
470
+ head_mask=head_mask,
471
+ decoder_head_mask=None,
472
+ past_key_values=kwargs.get('past_key_values'),
473
+ inputs_embeds=vision_embeds,
474
+ decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'),
475
+ labels=labels,
476
+ use_cache=True,
477
+ output_attentions=kwargs.get('output_attentions'),
478
+ output_hidden_states=kwargs.get('output_hidden_states'),
479
+ return_dict=kwargs.get('return_dict')
480
+ )
481
+
482
+
483
+ def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
484
+ if kwargs.get('encoder_outputs') is not None:
485
+ return {'attention_mask': kwargs.get('attention_mask'),
486
+ 'encoder_outputs': kwargs.get('encoder_outputs'),
487
+ 'decoder_input_ids': input_ids,
488
+ 'past_key_values': kwargs.get('past'),
489
+ }
490
+ else:
491
+ raise ValueError(
492
+ "Make sure that encoder_outputs is already computed when preapring inputs for generation. --y.x.")
493
+
494
+ def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None):
495
+ # text embedding
496
+ batch_size = image.shape[0]
497
+
498
+ if input_ids is not None:
499
+ text_embeds = self.shared(input_ids)
500
+ text_seq_length = text_embeds.shape[1]
501
+ else:
502
+ text_embeds = None
503
+ text_seq_length = 0
504
+
505
+ assert self.config.vision is not None
506
+ # vision embedding
507
+ vision_embeds = self.vision_model(image)
508
+ vision_embeds = self.vision_embed_matcher(vision_embeds)
509
+ vision_seq_length = vision_embeds.shape[1]
510
+ # add task token (e.g <OCR> for ocr)
511
+ vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length)
512
+ attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device)
513
+ return text_embeds, vision_embeds, attention_mask
514
+
515
+ def concat_task_token(self, embeds, text_seq_length=0):
516
+ # add task token (e.g <OCR> for ocr)
517
+ if self.task_name in self.task_token_ids.keys():
518
+ B = embeds.shape[0]
519
+ task_embeds = self.shared(self.task_token_ids[self.task_name])
520
+ text_seq_length += task_embeds.shape[0]
521
+ return torch.cat((embeds, task_embeds.repeat((B, 1, 1))), dim=1), text_seq_length
522
+ else:
523
+ # no such task token exists
524
+ return embeds, text_seq_length
525
+
526
+ def _prepare_model_inputs(
527
+ self,
528
+ inputs: Optional[torch.Tensor] = None,
529
+ bos_token_id: Optional[int] = None,
530
+ model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
531
+ ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
532
+ """
533
+ This function extracts the model-specific `inputs` for generation.
534
+ """
535
+ input_name = 'inputs_embeds'
536
+ _, vision_embeds, attention_mask = self._prepare_encoder_inputs(image=model_kwargs['image'])
537
+ model_kwargs['attention_mask'] = attention_mask
538
+
539
+ inputs = vision_embeds
540
+
541
+ # 4. if `inputs` is still None, try to create `input_ids` from BOS token
542
+ inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
543
+ return inputs, input_name, model_kwargs
544
+
545
+ def _prepare_encoder_decoder_kwargs_for_generation(
546
+ self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
547
+ ) -> Dict[str, Any]:
548
+ assert "encoder_outputs" not in model_kwargs
549
+
550
+ # 1. get encoder
551
+ encoder = self.get_encoder()
552
+
553
+ # 2. prepare encoder args and encoder kwargs from model kwargs
554
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
555
+ irrelevant_fields = ['input_ids', 'attention_mask', 'inputs_embeds', 'image', 'bbox', 'line_coordinates',
556
+ 'adj', 'lm_labels', 'banned_token_ids', 'questions', 'answers', 'labels', 'task_name']
557
+ encoder_kwargs = {
558
+ argument: value
559
+ for argument, value in model_kwargs.items()
560
+ if not any(argument.startswith(p) for p in irrelevant_prefix) and argument not in irrelevant_fields
561
+ }
562
+
563
+ # 3. make sure that encoder returns `ModelOutput`
564
+ encoder_kwargs["return_dict"] = True
565
+ model_kwargs["encoder_outputs"]: ModelOutput = encoder(
566
+ input_ids=None, attention_mask=model_kwargs['attention_mask'],
567
+ inputs_embeds=inputs_tensor, **encoder_kwargs)
568
+
569
+ return model_kwargs
570
+
571
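+ # Task tokens (e.g. <OCR>) are stored as frozen parameters in a ParameterDict so they travel with the checkpoint.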
+ def add_task_tokens(self):
572
+ self.input_tokenizer.add_tokens('<OCR>', special_tokens=True)
573
+ self.task_token_ids = torch.nn.ParameterDict([['ocr', self.register_token('<OCR>')]])
574
+
575
+ def register_token(self, token: str):
576
+ self.input_tokenizer.add_tokens(token, special_tokens=True)
577
+ token_ids = self.input_tokenizer.encode(token)
578
+ return torch.nn.Parameter(torch.tensor(token_ids), requires_grad=False)
579
+
580
+ def set_task_name(self, task_name):
581
+ if task_name:
582
+ self.task_name = task_name
583
+
584
+ def get_trivial_mask(self, inp):
585
+ return torch.ones((inp.shape[:2]), dtype=torch.int32).to(self.device)
586
+
587
+
588
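+ # Pre-training head for the 'mpm' task: a frozen T5 encoder embeds the prompt, and those embeddings condition the vision model (context_prompts).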
+ class VisFocusModelForLocalizedMaskedLanguageModeling(VisFocusModel):
589
+ def __init__(self, config):
590
+ super().__init__(config)
591
+ self.set_task_name('mpm')
592
+ self.text_embedder = T5_Encoder(self.vision_config.text_embedder, freeze=True)
593
+
594
+ def forward(self,
595
+ input_ids=None,
596
+ bbox=None,
597
+ image=None,
598
+ attention_mask=None,
599
+ head_mask=None,
600
+ inputs_embeds=None,
601
+ encoder_hidden_states=None,
602
+ encoder_attention_mask=None,
603
+ labels=None,
604
+ **kwargs):
605
+ if not kwargs.get('encoder_outputs'):
606
+ if self.task_name == 'ocr':
607
+ input_ids = None
608
+ if not hasattr(self, 'prompt_embeds'):
609
+ prompt = 'what is written in this document?'
610
+ prompt_ids = self.input_tokenizer.encode(prompt)
611
+ B = image.shape[0]
612
+ prompt_ids = torch.tensor(prompt_ids).expand(B, len(prompt_ids)).to(self.device)
613
+ setattr(self, 'prompt_embeds', self.text_embedder(prompt_ids).detach())
614
+ _, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=input_ids, image=image)
615
+ else:
616
+ # for generation mode
617
+ assert kwargs.get('decoder_input_ids') is not None
618
+ vision_embeds = attention_mask = None
619
+
620
+ return self.encoder_decoder_forward(input_ids=None,
621
+ attention_mask=attention_mask,
622
+ encoder_outputs=kwargs.get('encoder_outputs'),
623
+ decoder_input_ids=kwargs.get('decoder_input_ids'),
624
+ decoder_attention_mask=None,
625
+ head_mask=head_mask,
626
+ decoder_head_mask=None,
627
+ past_key_values=kwargs.get('past_key_values'),
628
+ inputs_embeds=vision_embeds,
629
+ decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'),
630
+ labels=labels,
631
+ use_cache=True,
632
+ output_attentions=kwargs.get('output_attentions'),
633
+ output_hidden_states=kwargs.get('output_hidden_states'),
634
+ return_dict=kwargs.get('return_dict')
635
+ )
636
+
637
+ def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None):
638
+ batch_size = image.shape[0]
639
+
640
+ # if the prompt is constant (OCR task), use the cached prompt embeddings
641
+ if self.task_name == 'ocr':
642
+ assert input_ids is None
643
+ text_embeds = self.prompt_embeds
644
+ else:
645
+ assert input_ids is not None
646
+ if self.text_embedder == self.encoder:
647
+ with torch.no_grad():
648
+ text_embeds = self.encoder(input_ids).last_hidden_state
649
+ else:
650
+ text_embeds = self.text_embedder(input_ids)
651
+
652
+ text_embeds = text_embeds.detach()
653
+
654
+ text_seq_length = text_embeds.shape[1] if self.task_name == 'pm_vqa_concat' else 0
655
+ assert self.config.vision is not None
656
+ # vision embedding
657
+ vision_embeds = self.vision_model(image, context_prompts=text_embeds)
658
+ if self.vision_model.model_name in ["swin_v2"]:
659
+ vision_embeds = self.vision_embed_matcher(vision_embeds)
660
+ vision_seq_length = vision_embeds.shape[1]
661
+ # add the task token (e.g. <OCR> for OCR)
662
+ vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length=text_seq_length)
663
+ attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device)
664
+ return text_embeds, vision_embeds, attention_mask
665
+
666
+ def _prepare_model_inputs(
667
+ self,
668
+ inputs: Optional[torch.Tensor] = None,
669
+ bos_token_id: Optional[int] = None,
670
+ model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
671
+ ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
672
+ """
673
+ This function extracts the model-specific `inputs` for generation.
674
+ """
675
+
676
+ input_name = 'inputs_embeds'
677
+ _, vision_embeds, attention_mask = self._prepare_encoder_inputs(image=model_kwargs['image'], input_ids=model_kwargs['input_ids'])
678
+ model_kwargs['attention_mask'] = attention_mask
679
+ inputs = vision_embeds
680
+ # 4. if `inputs` is still None, try to create `input_ids` from BOS token
681
+ inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
682
+ return inputs, input_name, model_kwargs
683
+
684
+ def add_task_tokens(self):
685
+ super().add_task_tokens()
686
+ self.input_tokenizer.add_tokens('<MPM>', special_tokens=True)
687
+ self.task_token_ids.update({'mpm': self.register_token('<MPM>')})
688
+
689
+
690
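+ # VQA head ('pm_vqa_concat'): the question both conditions the vision model and is concatenated, as T5 embeddings, in front of the visual tokens before the encoder-decoder pass.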
+ class VisFocusModelForImageTextToText(VisFocusModelForLocalizedMaskedLanguageModeling):
691
+ def __init__(self, config):
692
+ super().__init__(config)
693
+ self.set_task_name('pm_vqa_concat')
694
+
695
+ def forward(self, questions=None, answers=None, image=None, labels=None, **kwargs):
696
+ if kwargs.get('encoder_outputs') is None:
697
+ text_embeds, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=questions['input_ids'], image=image)
698
+ inputs_embeds = torch.concat((text_embeds, vision_embeds), dim=1)
699
+ attention_mask = self.get_trivial_mask(inputs_embeds)  # when a different tokenizer is used for ViLMA/concat, the attention mask must be recomputed
700
+ else:
701
+ # generation mode (the image was already encoded and passed in via encoder_outputs)
702
+ assert kwargs.get('decoder_input_ids') is not None
703
+ assert kwargs.get('encoder_outputs') is not None
704
+ inputs_embeds = kwargs.get('encoder_outputs')
705
+ text_embeds = vision_embeds = attention_mask = None
706
+
707
+ return self.encoder_decoder_forward(input_ids=None,
708
+ attention_mask=attention_mask,
709
+ encoder_outputs=kwargs.get('encoder_outputs'),
710
+ decoder_input_ids=kwargs.get('decoder_input_ids'),
711
+ decoder_attention_mask=None,
712
+ head_mask=None,
713
+ decoder_head_mask=None,
714
+ past_key_values=kwargs.get('past_key_values'),
715
+ inputs_embeds=inputs_embeds,
716
+ decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'),
717
+ labels=labels,
718
+ use_cache=True,
719
+ output_attentions=kwargs.get('output_attentions'),
720
+ output_hidden_states=kwargs.get('output_hidden_states'),
721
+ return_dict=kwargs.get('return_dict')
722
+ )
723
+
724
+ def _prepare_model_inputs(self, inputs=None, bos_token_id=None, model_kwargs=None ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
725
+ """
726
+ This function extracts the model-specific `inputs` for generation.
727
+ """
728
+ input_name = 'inputs_embeds'
729
+ text_embeds, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=model_kwargs['questions']['input_ids'], image=model_kwargs['image'])
730
+ model_kwargs['attention_mask'] = attention_mask
731
+ inputs_embeds = torch.concat((text_embeds, vision_embeds), dim=1)
732
+ inputs = inputs_embeds
733
+ # 4. if `inputs` is still None, try to create `input_ids` from BOS token
734
+ inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
735
+ model_kwargs['attention_mask'] = self.get_trivial_mask(inputs)
736
+ return inputs, input_name, model_kwargs
737
+
738
+ def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None):
739
+ batch_size = image.shape[0]
740
+ assert input_ids is not None
741
+ if self.text_embedder == self.encoder:
742
+ with torch.no_grad():
743
+ text_embeds = self.encoder(input_ids).last_hidden_state
744
+ else:
745
+ text_embeds = self.text_embedder(input_ids)
746
+
747
+ text_embeds = text_embeds.detach()
748
+
749
+ text_seq_length = text_embeds.shape[1] if self.task_name == 'pm_vqa_concat' else 0
750
+ assert self.config.vision is not None
751
+ # vision embedding
752
+ vision_embeds = self.vision_model(image, context_prompts=text_embeds)
753
+ if self.vision_model.model_name in ["swin_v2"]:
754
+ vision_embeds = self.vision_embed_matcher(vision_embeds)
755
+ vision_seq_length = vision_embeds.shape[1]
756
+ # add the task token (e.g. <OCR> for OCR)
757
+ vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length=text_seq_length)
758
+ attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device)
759
+ text_embeds = self.shared(input_ids)  # for concat, use the T5 nn.Embedding directly
760
+ return text_embeds, vision_embeds, attention_mask
761
+
762
+ def add_task_tokens(self):
763
+ super().add_task_tokens()
764
+ self.input_tokenizer.add_tokens('<LMPM_VQA_CONCAT>', special_tokens=True)
765
+ self.task_token_ids.update({'pm_vqa_concat': self.register_token('<LMPM_VQA_CONCAT>')})
766
+
767
+
768
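+ # Recursively moves tensors (possibly nested inside dicts) to the target device; plain lists are returned unchanged.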
+ def _to_cuda(sample, device=torch.device('cuda')):
769
+ if isinstance(sample, torch.Tensor):
770
+ return sample.to(device)
771
+ elif isinstance(sample, list):
772
+ return sample
773
+ else:
774
+ for k in sample.keys():
775
+ sample[k] = _to_cuda(sample[k], device)
776
+ return sample
777
+
778
+
779
+ def fetch_sample(ds, ds_for_vis):
780
+ idx = random.randint(50, 100)
781
+ for i in range(idx):
782
+ inputs = next(ds)
783
+ inputs_to_vis = next(ds_for_vis)
784
+ return inputs, inputs_to_vis
785
+
786
+
787
+ MATCHER_MAP = {
788
+ 'default': EmbedMatcher,
789
+ }
790
+
791
+
792
+ # vqa
793
+ if __name__ == '__main__':
794
+ # load yaml
795
+ with open('configs/test_expts/vf_base_finetune_docvqa__v2_accum4_f32_V5__mpm_altConcat__vilma_concat_V1/vqa_model_args.yaml', 'r') as f:
796
+ model_args = EasyDict(yaml.safe_load(f))
797
+
798
+ DEVICE = 'cpu'
799
+
800
+ ## load pretrained if needed
801
+ last_ckpt = None # get_last_checkpoint(dirname(model_args.model_config_path))
802
+ ##
803
+
804
+ # model = get_model_class(model_args, last_ckpt=last_ckpt)
805
+
806
+ cfg = VisFocusConfig.from_pretrained('configs/config.json')
807
+ cfg.push_to_hub('ofirab/visfocus-base-docvqa')
808
+ model = VisFocusModelForImageTextToText(cfg)
809
+ model.push_to_hub('ofirab/visfocus-base-docvqa')
810
+ model.to(DEVICE)
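
After this upload, the checkpoint can be loaded back from the Hub. A minimal usage sketch, under stated assumptions: it presumes a recent transformers release that exposes AutoModelForImageTextToText, that the 'ofirab/visfocus-base-docvqa' repo is accessible with trust_remote_code enabled, that a T5-compatible tokenizer matches the model's input_tokenizer, and that the image tensor below is a placeholder rather than the project's real preprocessing.

import torch
from transformers import AutoModelForImageTextToText, AutoTokenizer

repo_id = 'ofirab/visfocus-base-docvqa'
model = AutoModelForImageTextToText.from_pretrained(repo_id, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained('t5-base')  # assumption: any T5-style tokenizer

image = torch.randn(1, 3, 1536, 768)  # placeholder document image tensor; real inputs come from the project's image processor
question = tokenizer('what is written in this document?', return_tensors='pt')

with torch.no_grad():
    # generate() forwards extra kwargs to _prepare_model_inputs, which expects 'image' and 'questions'
    out_ids = model.generate(image=image, questions={'input_ids': question.input_ids}, max_new_tokens=32)
print(tokenizer.batch_decode(out_ids, skip_special_tokens=True))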