ParisNeo commited on
Commit
5110eb7
·
1 Parent(s): be93fc8

first working

Browse files
Files changed (10) hide show
  1. .gitignore +2 -0
  2. README.md +3 -0
  3. app.py +62 -0
  4. blip_vqa.pth +3 -0
  5. blip_vqa.py +246 -0
  6. configs/med_config.json +22 -0
  7. med.py +956 -0
  8. plot.py +68 -0
  9. requirements.txt +6 -0
  10. vit.py +305 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ *.pyc
README.md CHANGED
@@ -9,5 +9,8 @@ app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
 
 
 
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
9
  pinned: false
10
  license: mit
11
  ---
12
+ BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
13
+ This space shows how easy it is to use the BLIP model for image querrying.
14
+ [https://arxiv.org/abs/2201.12086](https://arxiv.org/abs/2201.12086)
15
 
16
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import urllib.request
6
+ import io
7
+ from pathlib import Path
8
+
9
+ from blip_vqa import blip_vqa
10
+
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ image_size = 384
13
+
14
+ class App():
15
+ def __init__(self):
16
+ self.selected_model=0
17
+
18
+ # Load blip for question answer
19
+ print("Loading Blip for question answering")
20
+ model_url = str(Path(__file__).parent/'blip_vqa.pth')
21
+ self.qa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit='base')
22
+ self.qa_model.eval()
23
+ self.qa_model = self.qa_model.to(device)
24
+
25
+
26
+
27
+ with gr.Blocks() as demo:
28
+ with gr.Row():
29
+ self.image_source = gr.inputs.Image(shape=(224, 224))
30
+ with gr.Tabs():
31
+ with gr.Tab("Question/Answer"):
32
+ self.question = gr.inputs.Textbox(label="Custom question (if applicable)", default="where is the right hand?")
33
+ self.answer = gr.Button("Ask")
34
+ self.lbl_caption = gr.outputs.Label(label="Caption")
35
+ self.answer.click(self.answer_question_image, [self.image_source, self.question], self.lbl_caption)
36
+ # Launch the interface
37
+ demo.launch()
38
+
39
+
40
+
41
+ def answer_question_image(self, img, custom_question="Describe this image"):
42
+ # Load the selected PyTorch model
43
+
44
+ # Preprocess the image
45
+ preprocess = transforms.Compose([
46
+ transforms.Resize((image_size,image_size),interpolation=transforms.InterpolationMode.BICUBIC),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
49
+ ])
50
+ img = preprocess(Image.fromarray(img.astype('uint8'), 'RGB'))
51
+
52
+ # Make a prediction with the model
53
+ with torch.no_grad():
54
+ output = self.qa_model(img.unsqueeze(0).to(device), custom_question, train=False, inference='generate')
55
+ answer = output
56
+
57
+ # Return the predicted label as a string
58
+ return answer[0]
59
+
60
+ app = App()
61
+
62
+
blip_vqa.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a7d546209f1ccfa8b3cd3a0138c53e0d1e95e4a4bc280bef8f67e20fe4925ae
3
+ size 1446244375
blip_vqa.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from med import BertConfig, BertModel, BertLMHeadModel
2
+ from vit import VisionTransformer, interpolate_pos_embed
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from transformers import BertTokenizer
8
+ import numpy as np
9
+ from pathlib import Path
10
+ from urllib.parse import urlparse
11
+ from timm.models.hub import download_cached_file
12
+ import os
13
+
14
+ # General helpers
15
+
16
+ def is_url(url_or_filename):
17
+ parsed = urlparse(url_or_filename)
18
+ return parsed.scheme in ("http", "https")
19
+
20
+ def load_checkpoint(model,url_or_filename):
21
+ if is_url(url_or_filename):
22
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
23
+ checkpoint = torch.load(cached_file, map_location='cpu')
24
+ elif os.path.isfile(url_or_filename):
25
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
26
+ else:
27
+ raise RuntimeError('checkpoint url or path is invalid')
28
+
29
+ state_dict = checkpoint['model']
30
+
31
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
32
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
33
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
34
+ model.visual_encoder_m)
35
+ for key in model.state_dict().keys():
36
+ if key in state_dict.keys():
37
+ if state_dict[key].shape!=model.state_dict()[key].shape:
38
+ del state_dict[key]
39
+
40
+ msg = model.load_state_dict(state_dict,strict=False)
41
+ print('load checkpoint from %s'%url_or_filename)
42
+ return model,msg
43
+
44
+ def init_tokenizer():
45
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
46
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
47
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
48
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
49
+ return tokenizer
50
+
51
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
52
+
53
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
54
+ if vit=='base':
55
+ vision_width = 768
56
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
57
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
58
+ drop_path_rate=0 or drop_path_rate
59
+ )
60
+ elif vit=='large':
61
+ vision_width = 1024
62
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
63
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
64
+ drop_path_rate=0.1 or drop_path_rate
65
+ )
66
+ return visual_encoder, vision_width
67
+
68
+
69
+
70
+ class BLIP_VQA(nn.Module):
71
+ def __init__(self,
72
+ med_config = str(Path(__file__).parent / 'configs/med_config.json'),
73
+ image_size = 480,
74
+ vit = 'base',
75
+ vit_grad_ckpt = False,
76
+ vit_ckpt_layer = 0,
77
+ ):
78
+ """
79
+ Args:
80
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
81
+ image_size (int): input image size
82
+ vit (str): model size of vision transformer
83
+ """
84
+ super().__init__()
85
+
86
+ self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
87
+ self.tokenizer = init_tokenizer()
88
+
89
+ encoder_config = BertConfig.from_json_file(med_config)
90
+ encoder_config.encoder_width = vision_width
91
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
92
+
93
+ decoder_config = BertConfig.from_json_file(med_config)
94
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
95
+
96
+
97
+ def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
98
+
99
+ image_embeds = self.visual_encoder(image)
100
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
101
+
102
+ question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
103
+ return_tensors="pt").to(image.device)
104
+ question.input_ids[:,0] = self.tokenizer.enc_token_id
105
+
106
+ if train:
107
+ '''
108
+ n: number of answers for each question
109
+ weights: weight for each answer
110
+ '''
111
+ answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
112
+ answer.input_ids[:,0] = self.tokenizer.bos_token_id
113
+ answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)
114
+
115
+ question_output = self.text_encoder(question.input_ids,
116
+ attention_mask = question.attention_mask,
117
+ encoder_hidden_states = image_embeds,
118
+ encoder_attention_mask = image_atts,
119
+ return_dict = True)
120
+
121
+ question_states = []
122
+ question_atts = []
123
+ for b, n in enumerate(n):
124
+ question_states += [question_output.last_hidden_state[b]]*n
125
+ question_atts += [question.attention_mask[b]]*n
126
+ question_states = torch.stack(question_states,0)
127
+ question_atts = torch.stack(question_atts,0)
128
+
129
+ answer_output = self.text_decoder(answer.input_ids,
130
+ attention_mask = answer.attention_mask,
131
+ encoder_hidden_states = question_states,
132
+ encoder_attention_mask = question_atts,
133
+ labels = answer_targets,
134
+ return_dict = True,
135
+ reduction = 'none',
136
+ )
137
+
138
+ loss = weights * answer_output.loss
139
+ loss = loss.sum()/image.size(0)
140
+
141
+ return loss
142
+
143
+
144
+ else:
145
+ question_output = self.text_encoder(question.input_ids,
146
+ attention_mask = question.attention_mask,
147
+ encoder_hidden_states = image_embeds,
148
+ encoder_attention_mask = image_atts,
149
+ return_dict = True)
150
+
151
+ if inference=='generate':
152
+ num_beams = 3
153
+ question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
154
+ question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
155
+ model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
156
+
157
+ bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
158
+
159
+ outputs = self.text_decoder.generate(input_ids=bos_ids,
160
+ max_length=10,
161
+ min_length=1,
162
+ num_beams=num_beams,
163
+ eos_token_id=self.tokenizer.sep_token_id,
164
+ pad_token_id=self.tokenizer.pad_token_id,
165
+ **model_kwargs)
166
+
167
+ answers = []
168
+ for output in outputs:
169
+ answer = self.tokenizer.decode(output, skip_special_tokens=True)
170
+ answers.append(answer)
171
+ return answers
172
+
173
+ elif inference=='rank':
174
+ max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
175
+ answer.input_ids, answer.attention_mask, k_test)
176
+ return max_ids
177
+
178
+
179
+
180
+ def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
181
+
182
+ num_ques = question_states.size(0)
183
+ start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
184
+
185
+ start_output = self.text_decoder(start_ids,
186
+ encoder_hidden_states = question_states,
187
+ encoder_attention_mask = question_atts,
188
+ return_dict = True,
189
+ reduction = 'none')
190
+ logits = start_output.logits[:,0,:] # first token's logit
191
+
192
+ # topk_probs: top-k probability
193
+ # topk_ids: [num_question, k]
194
+ answer_first_token = answer_ids[:,1]
195
+ prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
196
+ topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
197
+
198
+ # answer input: [num_question*k, answer_len]
199
+ input_ids = []
200
+ input_atts = []
201
+ for b, topk_id in enumerate(topk_ids):
202
+ input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
203
+ input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
204
+ input_ids = torch.cat(input_ids,dim=0)
205
+ input_atts = torch.cat(input_atts,dim=0)
206
+
207
+ targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
208
+
209
+ # repeat encoder's output for top-k answers
210
+ question_states = tile(question_states, 0, k)
211
+ question_atts = tile(question_atts, 0, k)
212
+
213
+ output = self.text_decoder(input_ids,
214
+ attention_mask = input_atts,
215
+ encoder_hidden_states = question_states,
216
+ encoder_attention_mask = question_atts,
217
+ labels = targets_ids,
218
+ return_dict = True,
219
+ reduction = 'none')
220
+
221
+ log_probs_sum = -output.loss
222
+ log_probs_sum = log_probs_sum.view(num_ques,k)
223
+
224
+ max_topk_ids = log_probs_sum.argmax(dim=1)
225
+ max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
226
+
227
+ return max_ids
228
+
229
+
230
+ def blip_vqa(pretrained='',**kwargs):
231
+ model = BLIP_VQA(**kwargs)
232
+ if pretrained:
233
+ model,msg = load_checkpoint(model,pretrained)
234
+ # assert(len(msg.missing_keys)==0)
235
+ return model
236
+
237
+
238
+ def tile(x, dim, n_tile):
239
+ init_dim = x.size(dim)
240
+ repeat_idx = [1] * x.dim()
241
+ repeat_idx[dim] = n_tile
242
+ x = x.repeat(*(repeat_idx))
243
+ order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
244
+ return torch.index_select(x, dim, order_index.to(x.device))
245
+
246
+
configs/med_config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ {
3
+ "architectures": [
4
+ "BertModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 768,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 3072,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 512,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 12,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "type_vocab_size": 2,
19
+ "vocab_size": 30524,
20
+ "encoder_width": 768,
21
+ "add_cross_attention": true
22
+ }
med.py ADDED
@@ -0,0 +1,956 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ '''
3
+ * Copyright (c) 2022, salesforce.com, inc.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ * By Junnan Li
8
+ * Based on huggingface code base
9
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
10
+ '''
11
+
12
+ import math
13
+ import os
14
+ import warnings
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+ from torch import Tensor, device, dtype, nn
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from torch.nn import CrossEntropyLoss
23
+ import torch.nn.functional as F
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.file_utils import (
27
+ ModelOutput,
28
+ )
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPoolingAndCrossAttentions,
32
+ CausalLMOutputWithCrossAttentions,
33
+ MaskedLMOutput,
34
+ MultipleChoiceModelOutput,
35
+ NextSentencePredictorOutput,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutput,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_utils import (
41
+ PreTrainedModel,
42
+ apply_chunking_to_forward,
43
+ find_pruneable_heads_and_indices,
44
+ prune_linear_layer,
45
+ )
46
+ from transformers.utils import logging
47
+ from transformers.models.bert.configuration_bert import BertConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ class BertEmbeddings(nn.Module):
54
+ """Construct the embeddings from word and position embeddings."""
55
+
56
+ def __init__(self, config):
57
+ super().__init__()
58
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
59
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
60
+
61
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
62
+ # any TensorFlow checkpoint file
63
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
64
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
65
+
66
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
67
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
68
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
69
+
70
+ self.config = config
71
+
72
+ def forward(
73
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
74
+ ):
75
+ if input_ids is not None:
76
+ input_shape = input_ids.size()
77
+ else:
78
+ input_shape = inputs_embeds.size()[:-1]
79
+
80
+ seq_length = input_shape[1]
81
+
82
+ if position_ids is None:
83
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
84
+
85
+ if inputs_embeds is None:
86
+ inputs_embeds = self.word_embeddings(input_ids)
87
+
88
+ embeddings = inputs_embeds
89
+
90
+ if self.position_embedding_type == "absolute":
91
+ position_embeddings = self.position_embeddings(position_ids)
92
+ embeddings += position_embeddings
93
+ embeddings = self.LayerNorm(embeddings)
94
+ embeddings = self.dropout(embeddings)
95
+ return embeddings
96
+
97
+
98
+ class BertSelfAttention(nn.Module):
99
+ def __init__(self, config, is_cross_attention):
100
+ super().__init__()
101
+ self.config = config
102
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
103
+ raise ValueError(
104
+ "The hidden size (%d) is not a multiple of the number of attention "
105
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
106
+ )
107
+
108
+ self.num_attention_heads = config.num_attention_heads
109
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
110
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
111
+
112
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
113
+ if is_cross_attention:
114
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
115
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
116
+ else:
117
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
118
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
119
+
120
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
121
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
122
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
123
+ self.max_position_embeddings = config.max_position_embeddings
124
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
125
+ self.save_attention = False
126
+
127
+ def save_attn_gradients(self, attn_gradients):
128
+ self.attn_gradients = attn_gradients
129
+
130
+ def get_attn_gradients(self):
131
+ return self.attn_gradients
132
+
133
+ def save_attention_map(self, attention_map):
134
+ self.attention_map = attention_map
135
+
136
+ def get_attention_map(self):
137
+ return self.attention_map
138
+
139
+ def transpose_for_scores(self, x):
140
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
141
+ x = x.view(*new_x_shape)
142
+ return x.permute(0, 2, 1, 3)
143
+
144
+ def forward(
145
+ self,
146
+ hidden_states,
147
+ attention_mask=None,
148
+ head_mask=None,
149
+ encoder_hidden_states=None,
150
+ encoder_attention_mask=None,
151
+ past_key_value=None,
152
+ output_attentions=False,
153
+ ):
154
+ mixed_query_layer = self.query(hidden_states)
155
+
156
+ # If this is instantiated as a cross-attention module, the keys
157
+ # and values come from an encoder; the attention mask needs to be
158
+ # such that the encoder's padding tokens are not attended to.
159
+ is_cross_attention = encoder_hidden_states is not None
160
+
161
+ if is_cross_attention:
162
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
163
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
164
+ attention_mask = encoder_attention_mask
165
+ elif past_key_value is not None:
166
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
167
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
168
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
169
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
170
+ else:
171
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
172
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
173
+
174
+ query_layer = self.transpose_for_scores(mixed_query_layer)
175
+
176
+ past_key_value = (key_layer, value_layer)
177
+
178
+ # Take the dot product between "query" and "key" to get the raw attention scores.
179
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
180
+
181
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
182
+ seq_length = hidden_states.size()[1]
183
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
184
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
185
+ distance = position_ids_l - position_ids_r
186
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
187
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
188
+
189
+ if self.position_embedding_type == "relative_key":
190
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
191
+ attention_scores = attention_scores + relative_position_scores
192
+ elif self.position_embedding_type == "relative_key_query":
193
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
194
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
195
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
196
+
197
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
198
+ if attention_mask is not None:
199
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
200
+ attention_scores = attention_scores + attention_mask
201
+
202
+ # Normalize the attention scores to probabilities.
203
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
204
+
205
+ if is_cross_attention and self.save_attention:
206
+ self.save_attention_map(attention_probs)
207
+ attention_probs.register_hook(self.save_attn_gradients)
208
+
209
+ # This is actually dropping out entire tokens to attend to, which might
210
+ # seem a bit unusual, but is taken from the original Transformer paper.
211
+ attention_probs_dropped = self.dropout(attention_probs)
212
+
213
+ # Mask heads if we want to
214
+ if head_mask is not None:
215
+ attention_probs_dropped = attention_probs_dropped * head_mask
216
+
217
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
218
+
219
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
220
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
221
+ context_layer = context_layer.view(*new_context_layer_shape)
222
+
223
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
224
+
225
+ outputs = outputs + (past_key_value,)
226
+ return outputs
227
+
228
+
229
+ class BertSelfOutput(nn.Module):
230
+ def __init__(self, config):
231
+ super().__init__()
232
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
233
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
234
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
235
+
236
+ def forward(self, hidden_states, input_tensor):
237
+ hidden_states = self.dense(hidden_states)
238
+ hidden_states = self.dropout(hidden_states)
239
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
240
+ return hidden_states
241
+
242
+
243
+ class BertAttention(nn.Module):
244
+ def __init__(self, config, is_cross_attention=False):
245
+ super().__init__()
246
+ self.self = BertSelfAttention(config, is_cross_attention)
247
+ self.output = BertSelfOutput(config)
248
+ self.pruned_heads = set()
249
+
250
+ def prune_heads(self, heads):
251
+ if len(heads) == 0:
252
+ return
253
+ heads, index = find_pruneable_heads_and_indices(
254
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
255
+ )
256
+
257
+ # Prune linear layers
258
+ self.self.query = prune_linear_layer(self.self.query, index)
259
+ self.self.key = prune_linear_layer(self.self.key, index)
260
+ self.self.value = prune_linear_layer(self.self.value, index)
261
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
262
+
263
+ # Update hyper params and store pruned heads
264
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
265
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
266
+ self.pruned_heads = self.pruned_heads.union(heads)
267
+
268
+ def forward(
269
+ self,
270
+ hidden_states,
271
+ attention_mask=None,
272
+ head_mask=None,
273
+ encoder_hidden_states=None,
274
+ encoder_attention_mask=None,
275
+ past_key_value=None,
276
+ output_attentions=False,
277
+ ):
278
+ self_outputs = self.self(
279
+ hidden_states,
280
+ attention_mask,
281
+ head_mask,
282
+ encoder_hidden_states,
283
+ encoder_attention_mask,
284
+ past_key_value,
285
+ output_attentions,
286
+ )
287
+ attention_output = self.output(self_outputs[0], hidden_states)
288
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
289
+ return outputs
290
+
291
+
292
+ class BertIntermediate(nn.Module):
293
+ def __init__(self, config):
294
+ super().__init__()
295
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
296
+ if isinstance(config.hidden_act, str):
297
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
298
+ else:
299
+ self.intermediate_act_fn = config.hidden_act
300
+
301
+ def forward(self, hidden_states):
302
+ hidden_states = self.dense(hidden_states)
303
+ hidden_states = self.intermediate_act_fn(hidden_states)
304
+ return hidden_states
305
+
306
+
307
+ class BertOutput(nn.Module):
308
+ def __init__(self, config):
309
+ super().__init__()
310
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
311
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
312
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
313
+
314
+ def forward(self, hidden_states, input_tensor):
315
+ hidden_states = self.dense(hidden_states)
316
+ hidden_states = self.dropout(hidden_states)
317
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
318
+ return hidden_states
319
+
320
+
321
+ class BertLayer(nn.Module):
322
+ def __init__(self, config, layer_num):
323
+ super().__init__()
324
+ self.config = config
325
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
326
+ self.seq_len_dim = 1
327
+ self.attention = BertAttention(config)
328
+ self.layer_num = layer_num
329
+ if self.config.add_cross_attention:
330
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
331
+ self.intermediate = BertIntermediate(config)
332
+ self.output = BertOutput(config)
333
+
334
+ def forward(
335
+ self,
336
+ hidden_states,
337
+ attention_mask=None,
338
+ head_mask=None,
339
+ encoder_hidden_states=None,
340
+ encoder_attention_mask=None,
341
+ past_key_value=None,
342
+ output_attentions=False,
343
+ mode=None,
344
+ ):
345
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
346
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
347
+ self_attention_outputs = self.attention(
348
+ hidden_states,
349
+ attention_mask,
350
+ head_mask,
351
+ output_attentions=output_attentions,
352
+ past_key_value=self_attn_past_key_value,
353
+ )
354
+ attention_output = self_attention_outputs[0]
355
+
356
+ outputs = self_attention_outputs[1:-1]
357
+ present_key_value = self_attention_outputs[-1]
358
+
359
+ if mode=='multimodal':
360
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
361
+
362
+ cross_attention_outputs = self.crossattention(
363
+ attention_output,
364
+ attention_mask,
365
+ head_mask,
366
+ encoder_hidden_states,
367
+ encoder_attention_mask,
368
+ output_attentions=output_attentions,
369
+ )
370
+ attention_output = cross_attention_outputs[0]
371
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
372
+ layer_output = apply_chunking_to_forward(
373
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
374
+ )
375
+ outputs = (layer_output,) + outputs
376
+
377
+ outputs = outputs + (present_key_value,)
378
+
379
+ return outputs
380
+
381
+ def feed_forward_chunk(self, attention_output):
382
+ intermediate_output = self.intermediate(attention_output)
383
+ layer_output = self.output(intermediate_output, attention_output)
384
+ return layer_output
385
+
386
+
387
+ class BertEncoder(nn.Module):
388
+ def __init__(self, config):
389
+ super().__init__()
390
+ self.config = config
391
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
392
+ self.gradient_checkpointing = False
393
+
394
+ def forward(
395
+ self,
396
+ hidden_states,
397
+ attention_mask=None,
398
+ head_mask=None,
399
+ encoder_hidden_states=None,
400
+ encoder_attention_mask=None,
401
+ past_key_values=None,
402
+ use_cache=None,
403
+ output_attentions=False,
404
+ output_hidden_states=False,
405
+ return_dict=True,
406
+ mode='multimodal',
407
+ ):
408
+ all_hidden_states = () if output_hidden_states else None
409
+ all_self_attentions = () if output_attentions else None
410
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
411
+
412
+ next_decoder_cache = () if use_cache else None
413
+
414
+ for i in range(self.config.num_hidden_layers):
415
+ layer_module = self.layer[i]
416
+ if output_hidden_states:
417
+ all_hidden_states = all_hidden_states + (hidden_states,)
418
+
419
+ layer_head_mask = head_mask[i] if head_mask is not None else None
420
+ past_key_value = past_key_values[i] if past_key_values is not None else None
421
+
422
+ if self.gradient_checkpointing and self.training:
423
+
424
+ if use_cache:
425
+ logger.warn(
426
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
427
+ )
428
+ use_cache = False
429
+
430
+ def create_custom_forward(module):
431
+ def custom_forward(*inputs):
432
+ return module(*inputs, past_key_value, output_attentions)
433
+
434
+ return custom_forward
435
+
436
+ layer_outputs = torch.utils.checkpoint.checkpoint(
437
+ create_custom_forward(layer_module),
438
+ hidden_states,
439
+ attention_mask,
440
+ layer_head_mask,
441
+ encoder_hidden_states,
442
+ encoder_attention_mask,
443
+ mode=mode,
444
+ )
445
+ else:
446
+ layer_outputs = layer_module(
447
+ hidden_states,
448
+ attention_mask,
449
+ layer_head_mask,
450
+ encoder_hidden_states,
451
+ encoder_attention_mask,
452
+ past_key_value,
453
+ output_attentions,
454
+ mode=mode,
455
+ )
456
+
457
+ hidden_states = layer_outputs[0]
458
+ if use_cache:
459
+ next_decoder_cache += (layer_outputs[-1],)
460
+ if output_attentions:
461
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
462
+
463
+ if output_hidden_states:
464
+ all_hidden_states = all_hidden_states + (hidden_states,)
465
+
466
+ if not return_dict:
467
+ return tuple(
468
+ v
469
+ for v in [
470
+ hidden_states,
471
+ next_decoder_cache,
472
+ all_hidden_states,
473
+ all_self_attentions,
474
+ all_cross_attentions,
475
+ ]
476
+ if v is not None
477
+ )
478
+ return BaseModelOutputWithPastAndCrossAttentions(
479
+ last_hidden_state=hidden_states,
480
+ past_key_values=next_decoder_cache,
481
+ hidden_states=all_hidden_states,
482
+ attentions=all_self_attentions,
483
+ cross_attentions=all_cross_attentions,
484
+ )
485
+
486
+
487
+ class BertPooler(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
491
+ self.activation = nn.Tanh()
492
+
493
+ def forward(self, hidden_states):
494
+ # We "pool" the model by simply taking the hidden state corresponding
495
+ # to the first token.
496
+ first_token_tensor = hidden_states[:, 0]
497
+ pooled_output = self.dense(first_token_tensor)
498
+ pooled_output = self.activation(pooled_output)
499
+ return pooled_output
500
+
501
+
502
+ class BertPredictionHeadTransform(nn.Module):
503
+ def __init__(self, config):
504
+ super().__init__()
505
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
506
+ if isinstance(config.hidden_act, str):
507
+ self.transform_act_fn = ACT2FN[config.hidden_act]
508
+ else:
509
+ self.transform_act_fn = config.hidden_act
510
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
511
+
512
+ def forward(self, hidden_states):
513
+ hidden_states = self.dense(hidden_states)
514
+ hidden_states = self.transform_act_fn(hidden_states)
515
+ hidden_states = self.LayerNorm(hidden_states)
516
+ return hidden_states
517
+
518
+
519
+ class BertLMPredictionHead(nn.Module):
520
+ def __init__(self, config):
521
+ super().__init__()
522
+ self.transform = BertPredictionHeadTransform(config)
523
+
524
+ # The output weights are the same as the input embeddings, but there is
525
+ # an output-only bias for each token.
526
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
527
+
528
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
529
+
530
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
531
+ self.decoder.bias = self.bias
532
+
533
+ def forward(self, hidden_states):
534
+ hidden_states = self.transform(hidden_states)
535
+ hidden_states = self.decoder(hidden_states)
536
+ return hidden_states
537
+
538
+
539
+ class BertOnlyMLMHead(nn.Module):
540
+ def __init__(self, config):
541
+ super().__init__()
542
+ self.predictions = BertLMPredictionHead(config)
543
+
544
+ def forward(self, sequence_output):
545
+ prediction_scores = self.predictions(sequence_output)
546
+ return prediction_scores
547
+
548
+
549
+ class BertPreTrainedModel(PreTrainedModel):
550
+ """
551
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
552
+ models.
553
+ """
554
+
555
+ config_class = BertConfig
556
+ base_model_prefix = "bert"
557
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
558
+
559
+ def _init_weights(self, module):
560
+ """ Initialize the weights """
561
+ if isinstance(module, (nn.Linear, nn.Embedding)):
562
+ # Slightly different from the TF version which uses truncated_normal for initialization
563
+ # cf https://github.com/pytorch/pytorch/pull/5617
564
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
565
+ elif isinstance(module, nn.LayerNorm):
566
+ module.bias.data.zero_()
567
+ module.weight.data.fill_(1.0)
568
+ if isinstance(module, nn.Linear) and module.bias is not None:
569
+ module.bias.data.zero_()
570
+
571
+
572
+ class BertModel(BertPreTrainedModel):
573
+ """
574
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
575
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
576
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
577
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
578
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
579
+ input to the forward pass.
580
+ """
581
+
582
+ def __init__(self, config, add_pooling_layer=True):
583
+ super().__init__(config)
584
+ self.config = config
585
+
586
+ self.embeddings = BertEmbeddings(config)
587
+
588
+ self.encoder = BertEncoder(config)
589
+
590
+ self.pooler = BertPooler(config) if add_pooling_layer else None
591
+
592
+ self.init_weights()
593
+
594
+
595
+ def get_input_embeddings(self):
596
+ return self.embeddings.word_embeddings
597
+
598
+ def set_input_embeddings(self, value):
599
+ self.embeddings.word_embeddings = value
600
+
601
+ def _prune_heads(self, heads_to_prune):
602
+ """
603
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
604
+ class PreTrainedModel
605
+ """
606
+ for layer, heads in heads_to_prune.items():
607
+ self.encoder.layer[layer].attention.prune_heads(heads)
608
+
609
+
610
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
611
+ """
612
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
613
+
614
+ Arguments:
615
+ attention_mask (:obj:`torch.Tensor`):
616
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
617
+ input_shape (:obj:`Tuple[int]`):
618
+ The shape of the input to the model.
619
+ device: (:obj:`torch.device`):
620
+ The device of the input to the model.
621
+
622
+ Returns:
623
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
624
+ """
625
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
626
+ # ourselves in which case we just need to make it broadcastable to all heads.
627
+ if attention_mask.dim() == 3:
628
+ extended_attention_mask = attention_mask[:, None, :, :]
629
+ elif attention_mask.dim() == 2:
630
+ # Provided a padding mask of dimensions [batch_size, seq_length]
631
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
632
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
633
+ if is_decoder:
634
+ batch_size, seq_length = input_shape
635
+
636
+ seq_ids = torch.arange(seq_length, device=device)
637
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
638
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
639
+ # causal and attention masks must have same type with pytorch version < 1.3
640
+ causal_mask = causal_mask.to(attention_mask.dtype)
641
+
642
+ if causal_mask.shape[1] < attention_mask.shape[1]:
643
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
644
+ causal_mask = torch.cat(
645
+ [
646
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
647
+ causal_mask,
648
+ ],
649
+ axis=-1,
650
+ )
651
+
652
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
653
+ else:
654
+ extended_attention_mask = attention_mask[:, None, None, :]
655
+ else:
656
+ raise ValueError(
657
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
658
+ input_shape, attention_mask.shape
659
+ )
660
+ )
661
+
662
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
663
+ # masked positions, this operation will create a tensor which is 0.0 for
664
+ # positions we want to attend and -10000.0 for masked positions.
665
+ # Since we are adding it to the raw scores before the softmax, this is
666
+ # effectively the same as removing these entirely.
667
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
668
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
669
+ return extended_attention_mask
670
+
671
+ def forward(
672
+ self,
673
+ input_ids=None,
674
+ attention_mask=None,
675
+ position_ids=None,
676
+ head_mask=None,
677
+ inputs_embeds=None,
678
+ encoder_embeds=None,
679
+ encoder_hidden_states=None,
680
+ encoder_attention_mask=None,
681
+ past_key_values=None,
682
+ use_cache=None,
683
+ output_attentions=None,
684
+ output_hidden_states=None,
685
+ return_dict=None,
686
+ is_decoder=False,
687
+ mode='multimodal',
688
+ ):
689
+ r"""
690
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
691
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
692
+ the model is configured as a decoder.
693
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
694
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
695
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
696
+ - 1 for tokens that are **not masked**,
697
+ - 0 for tokens that are **masked**.
698
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
699
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
700
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
701
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
702
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
703
+ use_cache (:obj:`bool`, `optional`):
704
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
705
+ decoding (see :obj:`past_key_values`).
706
+ """
707
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
708
+ output_hidden_states = (
709
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
710
+ )
711
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
712
+
713
+ if is_decoder:
714
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
715
+ else:
716
+ use_cache = False
717
+
718
+ if input_ids is not None and inputs_embeds is not None:
719
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
720
+ elif input_ids is not None:
721
+ input_shape = input_ids.size()
722
+ batch_size, seq_length = input_shape
723
+ device = input_ids.device
724
+ elif inputs_embeds is not None:
725
+ input_shape = inputs_embeds.size()[:-1]
726
+ batch_size, seq_length = input_shape
727
+ device = inputs_embeds.device
728
+ elif encoder_embeds is not None:
729
+ input_shape = encoder_embeds.size()[:-1]
730
+ batch_size, seq_length = input_shape
731
+ device = encoder_embeds.device
732
+ else:
733
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
734
+
735
+ # past_key_values_length
736
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
737
+
738
+ if attention_mask is None:
739
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
740
+
741
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
742
+ # ourselves in which case we just need to make it broadcastable to all heads.
743
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
744
+ device, is_decoder)
745
+
746
+ # If a 2D or 3D attention mask is provided for the cross-attention
747
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
748
+ if encoder_hidden_states is not None:
749
+ if type(encoder_hidden_states) == list:
750
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
751
+ else:
752
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
753
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
754
+
755
+ if type(encoder_attention_mask) == list:
756
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
757
+ elif encoder_attention_mask is None:
758
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
759
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
760
+ else:
761
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
762
+ else:
763
+ encoder_extended_attention_mask = None
764
+
765
+ # Prepare head mask if needed
766
+ # 1.0 in head_mask indicate we keep the head
767
+ # attention_probs has shape bsz x n_heads x N x N
768
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
769
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
770
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
771
+
772
+ if encoder_embeds is None:
773
+ embedding_output = self.embeddings(
774
+ input_ids=input_ids,
775
+ position_ids=position_ids,
776
+ inputs_embeds=inputs_embeds,
777
+ past_key_values_length=past_key_values_length,
778
+ )
779
+ else:
780
+ embedding_output = encoder_embeds
781
+
782
+ encoder_outputs = self.encoder(
783
+ embedding_output,
784
+ attention_mask=extended_attention_mask,
785
+ head_mask=head_mask,
786
+ encoder_hidden_states=encoder_hidden_states,
787
+ encoder_attention_mask=encoder_extended_attention_mask,
788
+ past_key_values=past_key_values,
789
+ use_cache=use_cache,
790
+ output_attentions=output_attentions,
791
+ output_hidden_states=output_hidden_states,
792
+ return_dict=return_dict,
793
+ mode=mode,
794
+ )
795
+ sequence_output = encoder_outputs[0]
796
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
797
+
798
+ if not return_dict:
799
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
800
+
801
+ return BaseModelOutputWithPoolingAndCrossAttentions(
802
+ last_hidden_state=sequence_output,
803
+ pooler_output=pooled_output,
804
+ past_key_values=encoder_outputs.past_key_values,
805
+ hidden_states=encoder_outputs.hidden_states,
806
+ attentions=encoder_outputs.attentions,
807
+ cross_attentions=encoder_outputs.cross_attentions,
808
+ )
809
+
810
+
811
+
812
+ class BertLMHeadModel(BertPreTrainedModel):
813
+
814
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
815
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
816
+
817
+ def __init__(self, config):
818
+ super().__init__(config)
819
+
820
+ self.bert = BertModel(config, add_pooling_layer=False)
821
+ self.cls = BertOnlyMLMHead(config)
822
+
823
+ self.init_weights()
824
+
825
+ def get_output_embeddings(self):
826
+ return self.cls.predictions.decoder
827
+
828
+ def set_output_embeddings(self, new_embeddings):
829
+ self.cls.predictions.decoder = new_embeddings
830
+
831
+ def forward(
832
+ self,
833
+ input_ids=None,
834
+ attention_mask=None,
835
+ position_ids=None,
836
+ head_mask=None,
837
+ inputs_embeds=None,
838
+ encoder_hidden_states=None,
839
+ encoder_attention_mask=None,
840
+ labels=None,
841
+ past_key_values=None,
842
+ use_cache=None,
843
+ output_attentions=None,
844
+ output_hidden_states=None,
845
+ return_dict=None,
846
+ return_logits=False,
847
+ is_decoder=True,
848
+ reduction='mean',
849
+ mode='multimodal',
850
+ ):
851
+ r"""
852
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
853
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
854
+ the model is configured as a decoder.
855
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
856
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
857
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
858
+ - 1 for tokens that are **not masked**,
859
+ - 0 for tokens that are **masked**.
860
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
861
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
862
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
863
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
864
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
865
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
866
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
867
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
868
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
869
+ use_cache (:obj:`bool`, `optional`):
870
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
871
+ decoding (see :obj:`past_key_values`).
872
+ Returns:
873
+ Example::
874
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
875
+ >>> import torch
876
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
877
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
878
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
879
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
880
+ >>> outputs = model(**inputs)
881
+ >>> prediction_logits = outputs.logits
882
+ """
883
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
884
+ if labels is not None:
885
+ use_cache = False
886
+
887
+ outputs = self.bert(
888
+ input_ids,
889
+ attention_mask=attention_mask,
890
+ position_ids=position_ids,
891
+ head_mask=head_mask,
892
+ inputs_embeds=inputs_embeds,
893
+ encoder_hidden_states=encoder_hidden_states,
894
+ encoder_attention_mask=encoder_attention_mask,
895
+ past_key_values=past_key_values,
896
+ use_cache=use_cache,
897
+ output_attentions=output_attentions,
898
+ output_hidden_states=output_hidden_states,
899
+ return_dict=return_dict,
900
+ is_decoder=is_decoder,
901
+ mode=mode,
902
+ )
903
+
904
+ sequence_output = outputs[0]
905
+ prediction_scores = self.cls(sequence_output)
906
+
907
+ if return_logits:
908
+ return prediction_scores[:, :-1, :].contiguous()
909
+
910
+ lm_loss = None
911
+ if labels is not None:
912
+ # we are doing next-token prediction; shift prediction scores and input ids by one
913
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
914
+ labels = labels[:, 1:].contiguous()
915
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
916
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
917
+ if reduction=='none':
918
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
919
+
920
+ if not return_dict:
921
+ output = (prediction_scores,) + outputs[2:]
922
+ return ((lm_loss,) + output) if lm_loss is not None else output
923
+
924
+ return CausalLMOutputWithCrossAttentions(
925
+ loss=lm_loss,
926
+ logits=prediction_scores,
927
+ past_key_values=outputs.past_key_values,
928
+ hidden_states=outputs.hidden_states,
929
+ attentions=outputs.attentions,
930
+ cross_attentions=outputs.cross_attentions,
931
+ )
932
+
933
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
934
+ input_shape = input_ids.shape
935
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
936
+ if attention_mask is None:
937
+ attention_mask = input_ids.new_ones(input_shape)
938
+
939
+ # cut decoder_input_ids if past is used
940
+ if past is not None:
941
+ input_ids = input_ids[:, -1:]
942
+
943
+ return {
944
+ "input_ids": input_ids,
945
+ "attention_mask": attention_mask,
946
+ "past_key_values": past,
947
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
948
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
949
+ "is_decoder": True,
950
+ }
951
+
952
+ def _reorder_cache(self, past, beam_idx):
953
+ reordered_past = ()
954
+ for layer_past in past:
955
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
956
+ return reordered_past
plot.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from bokeh.plotting import figure, output_file, show
3
+ from bokeh.models import Title, Div
4
+ from bokeh.palettes import Category10_10
5
+ from bokeh.plotting import figure, output_file, show, curdoc
6
+ from bokeh.models import Label, ColumnDataSource
7
+ from bokeh.palettes import Category10_10
8
+ from bokeh.layouts import column
9
+ from bokeh.models.widgets import CheckboxGroup, RadioGroup
10
+
11
+
12
+ # Create a sample dataframe
13
+ df = pd.DataFrame({
14
+ 'Training database size': [10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
15
+ 'Number of hands on steering wheel': [70,80,76,83,84,88,91,92,93,94],
16
+ 'Number of hands on tablet': [97,97,98,99,99,99,99,99,99,99],
17
+ #'Tablet position': [100, 100, 100, 100, 100, 100, 100, 100, 100, 100]
18
+ })
19
+
20
+ df.to_csv('output/db_comparison.csv', index=False)
21
+
22
+ df.index.name = 'Training database size'
23
+ df_title = 'Accuracy evolution as function of the database size'
24
+
25
+ # Define output file name and create a new Bokeh figure
26
+ output_file('output/db_comparison.html')
27
+ p = figure(title='Accuracy evolution as function of the database size', x_axis_label='X-axis', y_axis_label='Y-axis', width=800, height=400, sizing_mode='scale_width')#, toolbar_location=None)
28
+
29
+
30
+ # Add a title to the x-axis
31
+ xaxis_title = Label(text='<b>Category</b>', x=0.5, y=-0.2, text_align='center', text_baseline='middle')
32
+ p.xaxis.axis_label = "Training database size"
33
+ p.yaxis.axis_label = "Accuracy"
34
+ #p.add_layout(xaxis_title, 'below')
35
+
36
+ # Define a color palette and loop through all columns except the first one (x)
37
+ palette = Category10_10
38
+ data = {}
39
+ for i, col in enumerate(df.columns[1:]):
40
+ # Add a line glyph for each column with different color and thickness
41
+ p.line(df['Training database size'], df[col], legend_label=col, line_width=3, line_color=palette[i])
42
+ data[col] = df[col]
43
+
44
+ # Create a ColumnDataSource with the data
45
+ source = ColumnDataSource(data)
46
+
47
+ # Define a checkbox group to allow users to toggle the visibility of the data series
48
+ checkbox_group = CheckboxGroup(labels=list(data.keys()), active=list(range(len(data))), width=200)
49
+
50
+ # Define a radio group to allow users to switch between different x-axis values
51
+ radio_group = RadioGroup(labels=['Training database size', 'Category'], active=0, width=200)
52
+
53
+ # Define a callback function to update the data source when the checkbox or radio button is changed
54
+ def update():
55
+ selected_cols = [list(data.keys())[i] for i in checkbox_group.active]
56
+ x_axis = radio_group.labels[radio_group.active]
57
+ new_data = {x_axis: df[x_axis]}
58
+ for col in selected_cols:
59
+ new_data[col] = data[col]
60
+ source.data = new_data
61
+ # Add the controls to the layout and define a callback for when they are changed
62
+ controls = column(checkbox_group, radio_group)
63
+ checkbox_group.on_change('active', lambda attr, old, new: update())
64
+ radio_group.on_change('active', lambda attr, old, new: update())
65
+
66
+ # Add the controls and the figure to the layout and display it
67
+ layout = column(p, controls, sizing_mode ="stretch_both")
68
+ show(layout)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ pillow
3
+ torch
4
+ torchvision
5
+ timm
6
+ fairscale
vit.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22
+
23
+ class Mlp(nn.Module):
24
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25
+ """
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50
+ self.scale = qk_scale or head_dim ** -0.5
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+ self.attn_gradients = None
56
+ self.attention_map = None
57
+
58
+ def save_attn_gradients(self, attn_gradients):
59
+ self.attn_gradients = attn_gradients
60
+
61
+ def get_attn_gradients(self):
62
+ return self.attn_gradients
63
+
64
+ def save_attention_map(self, attention_map):
65
+ self.attention_map = attention_map
66
+
67
+ def get_attention_map(self):
68
+ return self.attention_map
69
+
70
+ def forward(self, x, register_hook=False):
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ if register_hook:
80
+ self.save_attention_map(attn)
81
+ attn.register_hook(self.save_attn_gradients)
82
+
83
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
87
+
88
+
89
+ class Block(nn.Module):
90
+
91
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93
+ super().__init__()
94
+ self.norm1 = norm_layer(dim)
95
+ self.attn = Attention(
96
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ if use_grad_checkpointing:
104
+ self.attn = checkpoint_wrapper(self.attn)
105
+ self.mlp = checkpoint_wrapper(self.mlp)
106
+
107
+ def forward(self, x, register_hook=False):
108
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
110
+ return x
111
+
112
+
113
+ class VisionTransformer(nn.Module):
114
+ """ Vision Transformer
115
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116
+ https://arxiv.org/abs/2010.11929
117
+ """
118
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121
+ use_grad_checkpointing=False, ckpt_layer=0):
122
+ """
123
+ Args:
124
+ img_size (int, tuple): input image size
125
+ patch_size (int, tuple): patch size
126
+ in_chans (int): number of input channels
127
+ num_classes (int): number of classes for classification head
128
+ embed_dim (int): embedding dimension
129
+ depth (int): depth of transformer
130
+ num_heads (int): number of attention heads
131
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132
+ qkv_bias (bool): enable bias for qkv if True
133
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135
+ drop_rate (float): dropout rate
136
+ attn_drop_rate (float): attention dropout rate
137
+ drop_path_rate (float): stochastic depth rate
138
+ norm_layer: (nn.Module): normalization layer
139
+ """
140
+ super().__init__()
141
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143
+
144
+ self.patch_embed = PatchEmbed(
145
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146
+
147
+ num_patches = self.patch_embed.num_patches
148
+
149
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151
+ self.pos_drop = nn.Dropout(p=drop_rate)
152
+
153
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154
+ self.blocks = nn.ModuleList([
155
+ Block(
156
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159
+ )
160
+ for i in range(depth)])
161
+ self.norm = norm_layer(embed_dim)
162
+
163
+ trunc_normal_(self.pos_embed, std=.02)
164
+ trunc_normal_(self.cls_token, std=.02)
165
+ self.apply(self._init_weights)
166
+
167
+ def _init_weights(self, m):
168
+ if isinstance(m, nn.Linear):
169
+ trunc_normal_(m.weight, std=.02)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ @torch.jit.ignore
177
+ def no_weight_decay(self):
178
+ return {'pos_embed', 'cls_token'}
179
+
180
+ def forward(self, x, register_blk=-1):
181
+ B = x.shape[0]
182
+ x = self.patch_embed(x)
183
+
184
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185
+ x = torch.cat((cls_tokens, x), dim=1)
186
+
187
+ x = x + self.pos_embed[:,:x.size(1),:]
188
+ x = self.pos_drop(x)
189
+
190
+ for i,blk in enumerate(self.blocks):
191
+ x = blk(x, register_blk==i)
192
+ x = self.norm(x)
193
+
194
+ return x
195
+
196
+ @torch.jit.ignore()
197
+ def load_pretrained(self, checkpoint_path, prefix=''):
198
+ _load_weights(self, checkpoint_path, prefix)
199
+
200
+
201
+ @torch.no_grad()
202
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204
+ """
205
+ import numpy as np
206
+
207
+ def _n2p(w, t=True):
208
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209
+ w = w.flatten()
210
+ if t:
211
+ if w.ndim == 4:
212
+ w = w.transpose([3, 2, 0, 1])
213
+ elif w.ndim == 3:
214
+ w = w.transpose([2, 0, 1])
215
+ elif w.ndim == 2:
216
+ w = w.transpose([1, 0])
217
+ return torch.from_numpy(w)
218
+
219
+ w = np.load(checkpoint_path)
220
+ if not prefix and 'opt/target/embedding/kernel' in w:
221
+ prefix = 'opt/target/'
222
+
223
+ if hasattr(model.patch_embed, 'backbone'):
224
+ # hybrid
225
+ backbone = model.patch_embed.backbone
226
+ stem_only = not hasattr(backbone, 'stem')
227
+ stem = backbone if stem_only else backbone.stem
228
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231
+ if not stem_only:
232
+ for i, stage in enumerate(backbone.stages):
233
+ for j, block in enumerate(stage.blocks):
234
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235
+ for r in range(3):
236
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239
+ if block.downsample is not None:
240
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244
+ else:
245
+ embed_conv_w = adapt_input_conv(
246
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
248
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251
+ if pos_embed_w.shape != model.pos_embed.shape:
252
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254
+ model.pos_embed.copy_(pos_embed_w)
255
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263
+ for i, block in enumerate(model.blocks.children()):
264
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268
+ block.attn.qkv.weight.copy_(torch.cat([
269
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270
+ block.attn.qkv.bias.copy_(torch.cat([
271
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274
+ for r in range(2):
275
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279
+
280
+
281
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282
+ # interpolate position embedding
283
+ embedding_size = pos_embed_checkpoint.shape[-1]
284
+ num_patches = visual_encoder.patch_embed.num_patches
285
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286
+ # height (== width) for the checkpoint position embedding
287
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288
+ # height (== width) for the new position embedding
289
+ new_size = int(num_patches ** 0.5)
290
+
291
+ if orig_size!=new_size:
292
+ # class_token and dist_token are kept unchanged
293
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294
+ # only the position tokens are interpolated
295
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297
+ pos_tokens = torch.nn.functional.interpolate(
298
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302
+
303
+ return new_pos_embed
304
+ else:
305
+ return pos_embed_checkpoint