Upload 61 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the full set of changes.
- lxmert/.gitignore +3 -0
- lxmert/.gitmodules +3 -0
- lxmert/.ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
- lxmert/LICENSE +21 -0
- lxmert/__init__.py +0 -0
- lxmert/__pycache__/__init__.cpython-38.pyc +0 -0
- lxmert/experiments/paper/COCO_val2014_000000127510/COCO_val2014_000000127510.jpg +0 -0
- lxmert/experiments/paper/COCO_val2014_000000185590/COCO_val2014_000000185590.jpg +0 -0
- lxmert/experiments/paper/COCO_val2014_000000200717/COCO_val2014_000000200717.jpg +0 -0
- lxmert/experiments/paper/COCO_val2014_000000324266/COCO_val2014_000000324266.jpg +0 -0
- lxmert/experiments/paper/new.jpg +0 -0
- lxmert/perturbation.py +254 -0
- lxmert/requirements.txt +107 -0
- lxmert/run/README.md +49 -0
- lxmert/run/gqa_finetune.bash +17 -0
- lxmert/run/gqa_test.bash +15 -0
- lxmert/run/lxmert_pretrain.bash +21 -0
- lxmert/run/nlvr2_finetune.bash +18 -0
- lxmert/run/nlvr2_test.bash +14 -0
- lxmert/run/vqa_finetune.bash +17 -0
- lxmert/run/vqa_test.bash +16 -0
- lxmert/src/.ipynb_checkpoints/Untitled-checkpoint.ipynb +81 -0
- lxmert/src/ExplanationGenerator.py +665 -0
- lxmert/src/__init__.py +0 -0
- lxmert/src/__pycache__/ExplanationGenerator.cpython-38.pyc +0 -0
- lxmert/src/__pycache__/__init__.cpython-38.pyc +0 -0
- lxmert/src/__pycache__/huggingface_lxmert.cpython-38.pyc +0 -0
- lxmert/src/__pycache__/layers.cpython-38.pyc +0 -0
- lxmert/src/__pycache__/lxmert_lrp.cpython-38.pyc +0 -0
- lxmert/src/__pycache__/modeling_frcnn.cpython-38.pyc +0 -0
- lxmert/src/__pycache__/processing_image.cpython-38.pyc +0 -0
- lxmert/src/__pycache__/vqa_utils.cpython-38.pyc +0 -0
- lxmert/src/huggingface_lxmert.py +1472 -0
- lxmert/src/layers.py +292 -0
- lxmert/src/lxmert_lrp.py +1693 -0
- lxmert/src/lxrt/__init__.py +0 -0
- lxmert/src/lxrt/entry.py +156 -0
- lxmert/src/lxrt/file_utils.py +247 -0
- lxmert/src/lxrt/modeling.py +1018 -0
- lxmert/src/lxrt/optimization.py +180 -0
- lxmert/src/lxrt/tokenization.py +388 -0
- lxmert/src/modeling_frcnn.py +1922 -0
- lxmert/src/param.py +126 -0
- lxmert/src/pretrain/__init__.py +0 -0
- lxmert/src/pretrain/lxmert_data.py +255 -0
- lxmert/src/pretrain/lxmert_pretrain.py +435 -0
- lxmert/src/pretrain/qa_answer_table.py +158 -0
- lxmert/src/processing_image.py +147 -0
- lxmert/src/tasks/__init__.py +0 -0
- lxmert/src/tasks/gqa.py +210 -0
lxmert/.gitignore
ADDED
@@ -0,0 +1,3 @@
*.caffemodel
*.tsv
/snap
lxmert/.gitmodules
ADDED
@@ -0,0 +1,3 @@
[submodule "data/nlvr2/nlvr"]
    path = data/nlvr2/nlvr
    url = https://github.com/lil-lab/nlvr.git
lxmert/.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
@@ -0,0 +1,6 @@
{
 "cells": [],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
lxmert/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Hao Tan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
lxmert/__init__.py
ADDED
File without changes
lxmert/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (144 Bytes)
lxmert/experiments/paper/COCO_val2014_000000127510/COCO_val2014_000000127510.jpg
ADDED
lxmert/experiments/paper/COCO_val2014_000000185590/COCO_val2014_000000185590.jpg
ADDED
lxmert/experiments/paper/COCO_val2014_000000200717/COCO_val2014_000000200717.jpg
ADDED
lxmert/experiments/paper/COCO_val2014_000000324266/COCO_val2014_000000324266.jpg
ADDED
lxmert/experiments/paper/new.jpg
ADDED
lxmert/perturbation.py
ADDED
@@ -0,0 +1,254 @@
from lxmert.lxmert.src.tasks import vqa_data
from lxmert.lxmert.src.modeling_frcnn import GeneralizedRCNN
import lxmert.lxmert.src.vqa_utils as utils
from lxmert.lxmert.src.processing_image import Preprocess
from transformers import LxmertTokenizer
from lxmert.lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
from lxmert.lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
from tqdm import tqdm
from lxmert.lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
import random
from lxmert.lxmert.src.param import args

OBJ_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt"
ATTR_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt"
VQA_URL = "https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json"

class ModelPert:
    def __init__(self, COCO_val_path, use_lrp=False):
        self.COCO_VAL_PATH = COCO_val_path
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
        self.frcnn_cfg.MODEL.DEVICE = "cuda"

        self.frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)

        self.image_preprocess = Preprocess(self.frcnn_cfg)

        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")

        if use_lrp:
            self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda")
        else:
            self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda")

        self.lxmert_vqa.eval()
        self.model = self.lxmert_vqa

        self.vqa_dataset = vqa_data.VQADataset(splits="valid")

        self.pert_steps = [0, 0.25, 0.5, 0.75, 0.8, 0.85, 0.9, 0.95, 1]
        self.pert_acc = [0] * len(self.pert_steps)

    def forward(self, item):
        image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
        self.image_file_path = image_file_path
        self.image_id = item['img_id']
        # run frcnn
        images, sizes, scales_yx = self.image_preprocess(image_file_path)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )
        inputs = self.lxmert_tokenizer(
            item['sent'],
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)
        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")
        self.output = self.lxmert_vqa(
            input_ids=inputs.input_ids.to("cuda"),
            attention_mask=inputs.attention_mask.to("cuda"),
            visual_feats=features.to("cuda"),
            visual_pos=normalized_boxes.to("cuda"),
            token_type_ids=inputs.token_type_ids.to("cuda"),
            return_dict=True,
            output_attentions=False,
        )
        return self.output

    def perturbation_image(self, item, cam_image, cam_text, is_positive_pert=False):
        if is_positive_pert:
            cam_image = cam_image * (-1)
        image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
        # run frcnn
        images, sizes, scales_yx = self.image_preprocess(image_file_path)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )
        inputs = self.lxmert_tokenizer(
            item['sent'],
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        for step_idx, step in enumerate(self.pert_steps):
            # find top step boxes
            curr_num_boxes = int((1 - step) * self.image_boxes_len)
            _, top_bboxes_indices = cam_image.topk(k=curr_num_boxes, dim=-1)
            top_bboxes_indices = top_bboxes_indices.cpu().data.numpy()

            curr_features = features[:, top_bboxes_indices, :]
            curr_pos = normalized_boxes[:, top_bboxes_indices, :]

            output = self.lxmert_vqa(
                input_ids=inputs.input_ids.to("cuda"),
                attention_mask=inputs.attention_mask.to("cuda"),
                visual_feats=curr_features.to("cuda"),
                visual_pos=curr_pos.to("cuda"),
                token_type_ids=inputs.token_type_ids.to("cuda"),
                return_dict=True,
                output_attentions=False,
            )

            answer = self.vqa_answers[output.question_answering_score.argmax()]
            accuracy = item["label"].get(answer, 0)
            self.pert_acc[step_idx] += accuracy

        return self.pert_acc

    def perturbation_text(self, item, cam_image, cam_text, is_positive_pert=False):
        if is_positive_pert:
            cam_text = cam_text * (-1)
        image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
        # run frcnn
        images, sizes, scales_yx = self.image_preprocess(image_file_path)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )
        inputs = self.lxmert_tokenizer(
            item['sent'],
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        for step_idx, step in enumerate(self.pert_steps):
            # we must keep the [CLS] token in order to have the classification
            # we also keep the [SEP] token
            cam_pure_text = cam_text[1:-1]
            text_len = cam_pure_text.shape[0]
            # find top step tokens, without the [CLS] token and the [SEP] token
            curr_num_tokens = int((1 - step) * text_len)
            _, top_bboxes_indices = cam_pure_text.topk(k=curr_num_tokens, dim=-1)
            top_bboxes_indices = top_bboxes_indices.cpu().data.numpy()

            # add back [CLS], [SEP] tokens
            top_bboxes_indices = [0, cam_text.shape[0] - 1] +\
                                 [top_bboxes_indices[i] + 1 for i in range(len(top_bboxes_indices))]
            # text tokens must be sorted for positional embedding to work
            top_bboxes_indices = sorted(top_bboxes_indices)

            curr_input_ids = inputs.input_ids[:, top_bboxes_indices]
            curr_attention_mask = inputs.attention_mask[:, top_bboxes_indices]
            curr_token_ids = inputs.token_type_ids[:, top_bboxes_indices]

            output = self.lxmert_vqa(
                input_ids=curr_input_ids.to("cuda"),
                attention_mask=curr_attention_mask.to("cuda"),
                visual_feats=features.to("cuda"),
                visual_pos=normalized_boxes.to("cuda"),
                token_type_ids=curr_token_ids.to("cuda"),
                return_dict=True,
                output_attentions=False,
            )

            answer = self.vqa_answers[output.question_answering_score.argmax()]
            accuracy = item["label"].get(answer, 0)
            self.pert_acc[step_idx] += accuracy

        return self.pert_acc

def main(args):
    model_pert = ModelPert(args.COCO_path, use_lrp=True)
    ours = GeneratorOurs(model_pert)
    baselines = GeneratorBaselines(model_pert)
    oursNoAggAblation = GeneratorOursAblationNoAggregation(model_pert)
    vqa_dataset = vqa_data.VQADataset(splits="valid")
    vqa_answers = utils.get_data(VQA_URL)
    method_name = args.method

    items = vqa_dataset.data
    random.seed(1234)
    r = list(range(len(items)))
    random.shuffle(r)
    pert_samples_indices = r[:args.num_samples]
    iterator = tqdm([vqa_dataset.data[i] for i in pert_samples_indices])

    test_type = "positive" if args.is_positive_pert else "negative"
    modality = "text" if args.is_text_pert else "image"
    print("running {0} pert test for {1} modality with method {2}".format(test_type, modality, args.method))

    for index, item in enumerate(iterator):
        if method_name == 'transformer_att':
            R_t_t, R_t_i = baselines.generate_transformer_attr(item)
        elif method_name == 'attn_gradcam':
            R_t_t, R_t_i = baselines.generate_attn_gradcam(item)
        elif method_name == 'partial_lrp':
            R_t_t, R_t_i = baselines.generate_partial_lrp(item)
        elif method_name == 'raw_attn':
            R_t_t, R_t_i = baselines.generate_raw_attn(item)
        elif method_name == 'rollout':
            R_t_t, R_t_i = baselines.generate_rollout(item)
        elif method_name == "ours_with_lrp_no_normalization":
            R_t_t, R_t_i = ours.generate_ours(item, normalize_self_attention=False)
        elif method_name == "ours_no_lrp":
            R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False)
        elif method_name == "ours_no_lrp_no_norm":
            R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False, normalize_self_attention=False)
        elif method_name == "ours_with_lrp":
            R_t_t, R_t_i = ours.generate_ours(item, use_lrp=True)
        elif method_name == "ablation_no_self_in_10":
            R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False, apply_self_in_rule_10=False)
        elif method_name == "ablation_no_aggregation":
            R_t_t, R_t_i = oursNoAggAblation.generate_ours_no_agg(item, use_lrp=False, normalize_self_attention=False)
        else:
            print("Please enter a valid method name")
            return
        cam_image = R_t_i[0]
        cam_text = R_t_t[0]
        cam_image = (cam_image - cam_image.min()) / (cam_image.max() - cam_image.min())
        cam_text = (cam_text - cam_text.min()) / (cam_text.max() - cam_text.min())
        if args.is_text_pert:
            curr_pert_result = model_pert.perturbation_text(item, cam_image, cam_text, args.is_positive_pert)
        else:
            curr_pert_result = model_pert.perturbation_image(item, cam_image, cam_text, args.is_positive_pert)
        curr_pert_result = [round(res / (index+1) * 100, 2) for res in curr_pert_result]
        iterator.set_description("Acc: {}".format(curr_pert_result))

if __name__ == "__main__":
    main(args)
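perturbation.py above accumulates, for every removal step in `pert_steps`, the summed VQA accuracy over the evaluated samples and shows the running per-step mean in percent. A common way to summarize such a positive/negative perturbation test is the area under that accuracy curve. The sketch below is illustrative only and is not one of the uploaded files; `perturbation_auc` and its defaults are hypothetical, and it assumes `pert_acc` holds the sums returned by `perturbation_image`/`perturbation_text` after `num_samples` items.

```
import numpy as np

def perturbation_auc(pert_acc, num_samples,
                     pert_steps=(0, 0.25, 0.5, 0.75, 0.8, 0.85, 0.9, 0.95, 1)):
    # mean accuracy (in %) at each fraction of removed tokens/boxes
    acc_curve = [100.0 * acc / num_samples for acc in pert_acc]
    # area under the accuracy-vs-removal curve (trapezoidal rule)
    return np.trapz(acc_curve, x=list(pert_steps))
```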
lxmert/requirements.txt
ADDED
@@ -0,0 +1,107 @@
argon2-cffi==20.1.0
async-generator==1.10
attrs==20.3.0
backcall==0.2.0
bleach==3.3.0
certifi==2020.12.5
cffi==1.14.5
chardet==3.0.4
click==7.1.2
cycler==0.10.0
Cython==0.29.22
dataclasses==0.6
decorator==4.4.2
defusedxml==0.6.0
demjson==2.2.4
editdistance==0.5.3
einops==0.3.0
entrypoints==0.3
fasttext==0.9.1
filelock==3.0.12
future==0.18.2
gitdb==4.0.5
GitPython==3.1.0
idna==2.10
imageio==2.9.0
importlib-metadata==3.4.0
ipykernel==5.4.3
ipython==7.20.0
ipython-genutils==0.2.0
ipywidgets==7.6.3
jedi==0.18.0
Jinja2==2.11.3
joblib==0.17.0
jsonschema==3.2.0
jupyter-client==6.1.11
jupyter-console==6.2.0
jupyter-core==4.7.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.0
kiwisolver==1.3.1
lmdb==0.98
MarkupSafe==1.1.1
matplotlib==3.3.4
mistune==0.8.4
nbclient==0.5.2
nbconvert==6.0.7
nbformat==5.1.2
nest-asyncio==1.5.1
networkx==2.4
nltk==3.4.5
notebook==6.2.0
numpy==1.19.2
omegaconf==2.0.1rc4
opencv-python==4.5.1.48
packaging==20.9
pandocfilters==1.4.3
parso==0.8.1
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.1.2
prometheus-client==0.9.0
prompt-toolkit==3.0.16
protobuf==3.15.6
ptyprocess==0.7.0
pybind11==2.6.2
pycocotools==2.0.2
pycparser==2.20
pyparsing==2.4.7
pyrsistent==0.17.3
python-dateutil==2.8.1
PyWavelets==1.1.1
PyYAML==5.4.1
pyzmq==22.0.3
qtconsole==5.0.2
QtPy==1.9.0
regex==2020.11.13
requests==2.23.0
sacremoses==0.0.43
scikit-image==0.17.2
scikit-learn==0.23.2
scipy==1.6.1
Send2Trash==1.5.0
sentencepiece==0.1.91
six==1.15.0
sklearn==0.0
smmap==3.0.5
termcolor==1.1.0
terminado==0.9.2
testpath==0.4.4
threadpoolctl==2.1.0
tifffile==2021.2.1
tokenizers==0.9.3
torch==1.7.1
torchtext==0.5.0
torchvision==0.8.2
tornado==6.1
tqdm==4.51.0
traitlets==5.0.5
transformers==3.5.1
typing-extensions==3.7.4.3
urllib3==1.25.11
utils==1.0.1
wcwidth==0.2.5
webencodings==0.5.1
wget==3.2
widgetsnbextension==3.5.1
zipp==3.4.0
lxmert/run/README.md
ADDED
@@ -0,0 +1,49 @@
# Running Script Arguments

```
Data Splits:
    --train [str,str,...]: use the splits (separated by comma) in training.
    --valid [str,str,...]: use the splits (separated by comma) in validation.
    --test [str,str,...]: use the splits (separated by comma) in testing.
Model Architecture:
    --llayers [int]: number of layers in language encoder.
    --xlayers [int]: number of layers in cross-modality encoder.
    --rlayers [int]: number of layers in object relationship encoder.
Load Weights:
    --load [str='path/to/saved_model']: load fine-tuned model path/to/saved_model.pth.
    --loadLXMERT [str='path/to/saved_model']: load pre-trained model without answer heads from path/to/saved_model_LXRT.pth.
    --loadLXMERTQA [str='path/to/saved_model']: load pre-trained model with answer head path/to/saved_model_LXRT.pth.
    --fromScratch: If none of the above loading parameters are set, the default mode would
        load the pre-trained BERT weights.
        As we promised to EMNLP reviewers, the language encoder would be re-initialized with this one-line argument to test the performance without BERT weights.
Training Hyper Parameters:
    --batchSize [int]: batch size.
    --optim [str]: optimizers.
    --lr [float]: peak learning rate.
    --epochs [int]: training epochs.
Debugging:
    --tiny: Load 512 images for each data split. (Note: number of images might be changed due to dataset specification)
    --fast: Load 5000 images for each data split. (Note: number of images might be changed due to dataset specification)
```

# Pre-training-Specific Arguments
```
Pre-training Tasks:
    --taskMaskLM: use the masked language model task.
    --taskObjPredict: use the masked object prediction task.
    --taskMatched: use the cross-modality matched task.
    --taskQA: use the image QA task.
Visual Pre-training Losses (Tasks):
    --visualLosses [str,str,...]: The sub-tasks in pre-training visual modality. Each one is from 'obj,attr,feat'.
        obj: detected-object-label classification.
        attr: detected-object-attribute classification.
        feat: RoI-feature regression.
Mask Rate in Pre-training:
    --wordMaskRate [float]: The prob of masking a word.
    --objMaskRate [float]: The prob of masking an object.
Initialization:
    --fromScratch: The default mode would load the pre-trained BERT weights into the model.
        As we promised to EMNLP reviewers, this option would re-initialize the language encoder.
```
lxmert/run/gqa_finetune.bash
ADDED
@@ -0,0 +1,17 @@
# The name of this experiment.
name=$2

# Save logs and models under snap/gqa; make backup.
output=snap/gqa/$name
mkdir -p $output/src
cp -r src/* $output/src/
cp $0 $output/run.bash

# See Readme.md for option details.
CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
    python src/tasks/gqa.py \
    --train train,valid --valid testdev \
    --llayers 9 --xlayers 5 --rlayers 5 \
    --loadLXMERTQA snap/pretrained/model \
    --batchSize 32 --optim bert --lr 1e-5 --epochs 4 \
    --tqdm --output $output ${@:3}
lxmert/run/gqa_test.bash
ADDED
@@ -0,0 +1,15 @@
# The name of this experiment.
name=$2

# Save logs and models under snap/gqa; make backup.
output=snap/gqa/$name
mkdir -p $output/src
cp -r src/* $output/src/
cp $0 $output/run.bash

# See Readme.md for option details.
CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
    python src/tasks/gqa.py \
    --tiny --train train --valid "" \
    --llayers 9 --xlayers 5 --rlayers 5 \
    --tqdm --output $output ${@:3}
lxmert/run/lxmert_pretrain.bash
ADDED
@@ -0,0 +1,21 @@
# The name of experiment
name=lxmert

# Create dirs and make backup
output=snap/pretrain/$name
mkdir -p $output/src
cp -r src/* $output/src/
cp $0 $output/run.bash

# Pre-training
CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
    python src/pretrain/lxmert_pretrain.py \
    --taskMaskLM --taskObjPredict --taskMatched --taskQA \
    --visualLosses obj,attr,feat \
    --wordMaskRate 0.15 --objMaskRate 0.15 \
    --train mscoco_train,mscoco_nominival,vgnococo --valid mscoco_minival \
    --llayers 9 --xlayers 5 --rlayers 5 \
    --fromScratch \
    --batchSize 256 --optim bert --lr 1e-4 --epochs 20 \
    --tqdm --output $output ${@:2}
lxmert/run/nlvr2_finetune.bash
ADDED
@@ -0,0 +1,18 @@
# The name of this experiment.
name=$2

# Save logs and models under snap/nlvr2; Make backup.
output=snap/nlvr2/$name
mkdir -p $output/src
cp -r src/* $output/src/
cp $0 $output/run.bash

# See run/Readme.md for option details.
CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
    python src/tasks/nlvr2.py \
    --train train --valid valid \
    --llayers 9 --xlayers 5 --rlayers 5 \
    --loadLXMERT snap/pretrained/model \
    --batchSize 32 --optim bert --lr 5e-5 --epochs 4 \
    --tqdm --output $output ${@:3}
lxmert/run/nlvr2_test.bash
ADDED
@@ -0,0 +1,14 @@
# The name of this experiment.
name=$2

# Save logs and models under snap/nlvr2; make backup.
output=snap/nlvr2/$name
mkdir -p $output/src
cp -r src/* $output/src/
cp $0 $output/run.bash

# See Readme.md for option details.
CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
    python src/tasks/nlvr2.py \
    --tiny --llayers 9 --xlayers 5 --rlayers 5 \
    --tqdm --output $output ${@:3}
lxmert/run/vqa_finetune.bash
ADDED
@@ -0,0 +1,17 @@
# The name of this experiment.
name=$2

# Save logs and models under snap/vqa; make backup.
output=snap/vqa/$name
mkdir -p $output/src
cp -r src/* $output/src/
cp $0 $output/run.bash

# See Readme.md for option details.
CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
    python src/tasks/vqa.py \
    --train train,nominival --valid minival \
    --llayers 9 --xlayers 5 --rlayers 5 \
    --loadLXMERTQA snap/pretrained/model \
    --batchSize 32 --optim bert --lr 5e-5 --epochs 4 \
    --tqdm --output $output ${@:3}
lxmert/run/vqa_test.bash
ADDED
@@ -0,0 +1,16 @@
# The name of this experiment.
name=$2

# Save logs and models under snap/vqa; make backup.
output=snap/vqa/$name
mkdir -p $output/src
cp -r src/* $output/src/
cp $0 $output/run.bash

# See Readme.md for option details.
CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
    python src/tasks/vqa.py \
    --tiny --train train --valid "" \
    --llayers 9 --xlayers 5 --rlayers 5 \
    --batchSize 32 --optim bert --lr 5e-5 --epochs 4 \
    --tqdm --output $output ${@:3}
lxmert/src/.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
@@ -0,0 +1,81 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "loose-wrong",
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'src'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-7-b03239bcd702>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mlxmert_lrp\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLxmertForQuestionAnswering\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mLxmertForQuestionAnsweringLRP\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtasks\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mvqa_data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodeling_frcnn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mGeneralizedRCNN\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvqa_utils\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocessing_image\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPreprocess\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/media/data2/hila_chefer/lxmert/lxmert/src/lxmert_lrp.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCrossEntropyLoss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSmoothL1Loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 28\u001b[0m from transformers.file_utils import (\n\u001b[1;32m 29\u001b[0m \u001b[0mModelOutput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'src'"
     ]
    }
   ],
   "source": [
    "from lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP\n",
    "from src.tasks import vqa_data\n",
    "from src.modeling_frcnn import GeneralizedRCNN\n",
    "import src.vqa_utils as utils\n",
    "from src.processing_image import Preprocess\n",
    "from transformers import LxmertTokenizer\n",
    "from src.huggingface_lxmert import LxmertForQuestionAnswering\n",
    "\n",
    "from tqdm import tqdm\n",
    "from src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines\n",
    "import random\n",
    "import cv2\n",
    "\n",
    "COCO_VAL_PATH = '/media/data2/hila_chefer/env_MMF/datasets/coco/subset_val/images/val2014/'\n",
    "\n",
    "OBJ_URL = \"https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt\"\n",
    "ATTR_URL = \"https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt\"\n",
    "VQA_URL = \"https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "emerging-trace",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "royal-small",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
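The error cell captured in this checkpoint fails with `ModuleNotFoundError: No module named 'src'` because `lxmert_lrp.py` itself imports `from src.layers import *`, so the directory that contains `src/` must be importable from the notebook. A minimal sketch of one way to make those imports resolve is shown below; it is not part of the upload, and the path is only an example taken from the traceback on that particular machine.

```
import sys

# hypothetical repository root, copied from the traceback above
LXMERT_ROOT = "/media/data2/hila_chefer/lxmert/lxmert"
if LXMERT_ROOT not in sys.path:
    sys.path.insert(0, LXMERT_ROOT)

# the imports from the failing cell should now resolve
from src.processing_image import Preprocess
```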
lxmert/src/ExplanationGenerator.py
ADDED
@@ -0,0 +1,665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import copy
|
4 |
+
|
5 |
+
|
6 |
+
def compute_rollout_attention(all_layer_matrices, start_layer=0):
|
7 |
+
# adding residual consideration
|
8 |
+
num_tokens = all_layer_matrices[0].shape[1]
|
9 |
+
eye = torch.eye(num_tokens).to(all_layer_matrices[0].device)
|
10 |
+
all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
|
11 |
+
matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
|
12 |
+
for i in range(len(all_layer_matrices))]
|
13 |
+
joint_attention = matrices_aug[start_layer]
|
14 |
+
for i in range(start_layer + 1, len(matrices_aug)):
|
15 |
+
joint_attention = matrices_aug[i].matmul(joint_attention)
|
16 |
+
return joint_attention
|
17 |
+
|
18 |
+
|
19 |
+
# rule 5 from paper
|
20 |
+
def avg_heads(cam, grad):
|
21 |
+
cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
|
22 |
+
grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
|
23 |
+
cam = grad * cam
|
24 |
+
cam = cam.clamp(min=0).mean(dim=0)
|
25 |
+
return cam
|
26 |
+
|
27 |
+
|
28 |
+
# rules 6 + 7 from paper
|
29 |
+
def apply_self_attention_rules(R_ss, R_sq, cam_ss):
|
30 |
+
R_sq_addition = torch.matmul(cam_ss, R_sq)
|
31 |
+
R_ss_addition = torch.matmul(cam_ss, R_ss)
|
32 |
+
return R_ss_addition, R_sq_addition
|
33 |
+
|
34 |
+
|
35 |
+
# rules 10 + 11 from paper
|
36 |
+
def apply_mm_attention_rules(R_ss, R_qq, R_qs, cam_sq, apply_normalization=True, apply_self_in_rule_10=True):
|
37 |
+
R_ss_normalized = R_ss
|
38 |
+
R_qq_normalized = R_qq
|
39 |
+
if apply_normalization:
|
40 |
+
R_ss_normalized = handle_residual(R_ss)
|
41 |
+
R_qq_normalized = handle_residual(R_qq)
|
42 |
+
R_sq_addition = torch.matmul(R_ss_normalized.t(), torch.matmul(cam_sq, R_qq_normalized))
|
43 |
+
if not apply_self_in_rule_10:
|
44 |
+
R_sq_addition = cam_sq
|
45 |
+
R_ss_addition = torch.matmul(cam_sq, R_qs)
|
46 |
+
return R_sq_addition, R_ss_addition
|
47 |
+
|
48 |
+
|
49 |
+
# normalization- eq. 8+9
|
50 |
+
def handle_residual(orig_self_attention):
|
51 |
+
self_attention = orig_self_attention.clone()
|
52 |
+
diag_idx = range(self_attention.shape[-1])
|
53 |
+
# computing R hat
|
54 |
+
self_attention -= torch.eye(self_attention.shape[-1]).to(self_attention.device)
|
55 |
+
assert self_attention[diag_idx, diag_idx].min() >= 0
|
56 |
+
# normalizing R hat
|
57 |
+
self_attention = self_attention / self_attention.sum(dim=-1, keepdim=True)
|
58 |
+
self_attention += torch.eye(self_attention.shape[-1]).to(self_attention.device)
|
59 |
+
return self_attention
|
60 |
+
|
61 |
+
|
62 |
+
class GeneratorOurs:
|
63 |
+
def __init__(self, model_usage, save_visualization=False):
|
64 |
+
self.model_usage = model_usage
|
65 |
+
self.save_visualization = save_visualization
|
66 |
+
|
67 |
+
def handle_self_attention_lang(self, blocks):
|
68 |
+
for blk in blocks:
|
69 |
+
grad = blk.attention.self.get_attn_gradients().detach()
|
70 |
+
if self.use_lrp:
|
71 |
+
cam = blk.attention.self.get_attn_cam().detach()
|
72 |
+
else:
|
73 |
+
cam = blk.attention.self.get_attn().detach()
|
74 |
+
cam = avg_heads(cam, grad)
|
75 |
+
R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam)
|
76 |
+
self.R_t_t += R_t_t_add
|
77 |
+
self.R_t_i += R_t_i_add
|
78 |
+
|
79 |
+
def handle_self_attention_image(self, blocks):
|
80 |
+
for blk in blocks:
|
81 |
+
grad = blk.attention.self.get_attn_gradients().detach()
|
82 |
+
if self.use_lrp:
|
83 |
+
cam = blk.attention.self.get_attn_cam().detach()
|
84 |
+
else:
|
85 |
+
cam = blk.attention.self.get_attn().detach()
|
86 |
+
cam = avg_heads(cam, grad)
|
87 |
+
R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam)
|
88 |
+
self.R_i_i += R_i_i_add
|
89 |
+
self.R_i_t += R_i_t_add
|
90 |
+
|
91 |
+
def handle_co_attn_self_lang(self, block):
|
92 |
+
grad = block.lang_self_att.self.get_attn_gradients().detach()
|
93 |
+
if self.use_lrp:
|
94 |
+
cam = block.lang_self_att.self.get_attn_cam().detach()
|
95 |
+
else:
|
96 |
+
cam = block.lang_self_att.self.get_attn().detach()
|
97 |
+
cam = avg_heads(cam, grad)
|
98 |
+
R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam)
|
99 |
+
self.R_t_t += R_t_t_add
|
100 |
+
self.R_t_i += R_t_i_add
|
101 |
+
|
102 |
+
def handle_co_attn_self_image(self, block):
|
103 |
+
grad = block.visn_self_att.self.get_attn_gradients().detach()
|
104 |
+
if self.use_lrp:
|
105 |
+
cam = block.visn_self_att.self.get_attn_cam().detach()
|
106 |
+
else:
|
107 |
+
cam = block.visn_self_att.self.get_attn().detach()
|
108 |
+
cam = avg_heads(cam, grad)
|
109 |
+
R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam)
|
110 |
+
self.R_i_i += R_i_i_add
|
111 |
+
self.R_i_t += R_i_t_add
|
112 |
+
|
113 |
+
def handle_co_attn_lang(self, block):
|
114 |
+
if self.use_lrp:
|
115 |
+
cam_t_i = block.visual_attention.att.get_attn_cam().detach()
|
116 |
+
else:
|
117 |
+
cam_t_i = block.visual_attention.att.get_attn().detach()
|
118 |
+
grad_t_i = block.visual_attention.att.get_attn_gradients().detach()
|
119 |
+
cam_t_i = avg_heads(cam_t_i, grad_t_i)
|
120 |
+
R_t_i_addition, R_t_t_addition = apply_mm_attention_rules(self.R_t_t, self.R_i_i, self.R_i_t, cam_t_i,
|
121 |
+
apply_normalization=self.normalize_self_attention,
|
122 |
+
apply_self_in_rule_10=self.apply_self_in_rule_10)
|
123 |
+
return R_t_i_addition, R_t_t_addition
|
124 |
+
|
125 |
+
def handle_co_attn_image(self, block):
|
126 |
+
if self.use_lrp:
|
127 |
+
cam_i_t = block.visual_attention_copy.att.get_attn_cam().detach()
|
128 |
+
else:
|
129 |
+
cam_i_t = block.visual_attention_copy.att.get_attn().detach()
|
130 |
+
grad_i_t = block.visual_attention_copy.att.get_attn_gradients().detach()
|
131 |
+
cam_i_t = avg_heads(cam_i_t, grad_i_t)
|
132 |
+
R_i_t_addition, R_i_i_addition = apply_mm_attention_rules(self.R_i_i, self.R_t_t, self.R_t_i, cam_i_t,
|
133 |
+
apply_normalization=self.normalize_self_attention,
|
134 |
+
apply_self_in_rule_10=self.apply_self_in_rule_10)
|
135 |
+
return R_i_t_addition, R_i_i_addition
|
136 |
+
|
137 |
+
def generate_ours(self, input, index=None, use_lrp=True, normalize_self_attention=True, apply_self_in_rule_10=True,
|
138 |
+
method_name="ours"):
|
139 |
+
self.use_lrp = use_lrp
|
140 |
+
self.normalize_self_attention = normalize_self_attention
|
141 |
+
self.apply_self_in_rule_10 = apply_self_in_rule_10
|
142 |
+
kwargs = {"alpha": 1}
|
143 |
+
output = self.model_usage.forward(input).question_answering_score
|
144 |
+
model = self.model_usage.model
|
145 |
+
|
146 |
+
# initialize relevancy matrices
|
147 |
+
text_tokens = self.model_usage.text_len
|
148 |
+
image_bboxes = self.model_usage.image_boxes_len
|
149 |
+
|
150 |
+
# text self attention matrix
|
151 |
+
self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
|
152 |
+
# image self attention matrix
|
153 |
+
self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
|
154 |
+
# impact of images on text
|
155 |
+
self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
|
156 |
+
# impact of text on images
|
157 |
+
self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
|
158 |
+
|
159 |
+
if index is None:
|
160 |
+
index = np.argmax(output.cpu().data.numpy(), axis=-1)
|
161 |
+
|
162 |
+
one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
|
163 |
+
one_hot[0, index] = 1
|
164 |
+
one_hot_vector = one_hot
|
165 |
+
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
166 |
+
one_hot = torch.sum(one_hot.cuda() * output)
|
167 |
+
|
168 |
+
model.zero_grad()
|
169 |
+
one_hot.backward(retain_graph=True)
|
170 |
+
if self.use_lrp:
|
171 |
+
model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs)
|
172 |
+
|
173 |
+
# language self attention
|
174 |
+
blocks = model.lxmert.encoder.layer
|
175 |
+
self.handle_self_attention_lang(blocks)
|
176 |
+
|
177 |
+
# image self attention
|
178 |
+
blocks = model.lxmert.encoder.r_layers
|
179 |
+
self.handle_self_attention_image(blocks)
|
180 |
+
|
181 |
+
# cross attn layers
|
182 |
+
blocks = model.lxmert.encoder.x_layers
|
183 |
+
for i, blk in enumerate(blocks):
|
184 |
+
# in the last cross attention module, only the text cross modal
|
185 |
+
# attention has an impact on the CLS token, since it's the first
|
186 |
+
# token in the language tokens
|
187 |
+
if i == len(blocks) - 1:
|
188 |
+
break
|
189 |
+
# cross attn- first for language then for image
|
190 |
+
R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk)
|
191 |
+
R_i_t_addition, R_i_i_addition = self.handle_co_attn_image(blk)
|
192 |
+
|
193 |
+
self.R_t_i += R_t_i_addition
|
194 |
+
self.R_t_t += R_t_t_addition
|
195 |
+
self.R_i_t += R_i_t_addition
|
196 |
+
self.R_i_i += R_i_i_addition
|
197 |
+
|
198 |
+
# language self attention
|
199 |
+
self.handle_co_attn_self_lang(blk)
|
200 |
+
|
201 |
+
# image self attention
|
202 |
+
self.handle_co_attn_self_image(blk)
|
203 |
+
|
204 |
+
# take care of last cross attention layer- only text
|
205 |
+
blk = model.lxmert.encoder.x_layers[-1]
|
206 |
+
# cross attn- first for language then for image
|
207 |
+
R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk)
|
208 |
+
self.R_t_i += R_t_i_addition
|
209 |
+
self.R_t_t += R_t_t_addition
|
210 |
+
|
211 |
+
# language self attention
|
212 |
+
self.handle_co_attn_self_lang(blk)
|
213 |
+
|
214 |
+
# disregard the [CLS] token itself
|
215 |
+
self.R_t_t[0, 0] = 0
|
216 |
+
return self.R_t_t, self.R_t_i
|
217 |
+
|
218 |
+
|
219 |
+
class GeneratorOursAblationNoAggregation:
|
220 |
+
def __init__(self, model_usage, save_visualization=False):
|
221 |
+
self.model_usage = model_usage
|
222 |
+
self.save_visualization = save_visualization
|
223 |
+
|
224 |
+
def handle_self_attention_lang(self, blocks):
|
225 |
+
for blk in blocks:
|
226 |
+
grad = blk.attention.self.get_attn_gradients().detach()
|
227 |
+
if self.use_lrp:
|
228 |
+
cam = blk.attention.self.get_attn_cam().detach()
|
229 |
+
else:
|
230 |
+
cam = blk.attention.self.get_attn().detach()
|
231 |
+
cam = avg_heads(cam, grad)
|
232 |
+
R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam)
|
233 |
+
self.R_t_t = R_t_t_add
|
234 |
+
self.R_t_i = R_t_i_add
|
235 |
+
|
236 |
+
def handle_self_attention_image(self, blocks):
|
237 |
+
for blk in blocks:
|
238 |
+
grad = blk.attention.self.get_attn_gradients().detach()
|
239 |
+
if self.use_lrp:
|
240 |
+
cam = blk.attention.self.get_attn_cam().detach()
|
241 |
+
else:
|
242 |
+
cam = blk.attention.self.get_attn().detach()
|
243 |
+
cam = avg_heads(cam, grad)
|
244 |
+
R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam)
|
245 |
+
self.R_i_i = R_i_i_add
|
246 |
+
self.R_i_t = R_i_t_add
|
247 |
+
|
248 |
+
def handle_co_attn_self_lang(self, block):
|
249 |
+
grad = block.lang_self_att.self.get_attn_gradients().detach()
|
250 |
+
if self.use_lrp:
|
251 |
+
cam = block.lang_self_att.self.get_attn_cam().detach()
|
252 |
+
else:
|
253 |
+
cam = block.lang_self_att.self.get_attn().detach()
|
254 |
+
cam = avg_heads(cam, grad)
|
255 |
+
R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam)
|
256 |
+
self.R_t_t = R_t_t_add
|
257 |
+
self.R_t_i = R_t_i_add
|
258 |
+
|
259 |
+
def handle_co_attn_self_image(self, block):
|
260 |
+
grad = block.visn_self_att.self.get_attn_gradients().detach()
|
261 |
+
if self.use_lrp:
|
262 |
+
cam = block.visn_self_att.self.get_attn_cam().detach()
|
263 |
+
else:
|
264 |
+
cam = block.visn_self_att.self.get_attn().detach()
|
265 |
+
cam = avg_heads(cam, grad)
|
266 |
+
R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam)
|
267 |
+
self.R_i_i = R_i_i_add
|
268 |
+
self.R_i_t = R_i_t_add
|
269 |
+
|
270 |
+
def handle_co_attn_lang(self, block):
|
271 |
+
if self.use_lrp:
|
272 |
+
cam_t_i = block.visual_attention.att.get_attn_cam().detach()
|
273 |
+
else:
|
274 |
+
cam_t_i = block.visual_attention.att.get_attn().detach()
|
275 |
+
grad_t_i = block.visual_attention.att.get_attn_gradients().detach()
|
276 |
+
cam_t_i = avg_heads(cam_t_i, grad_t_i)
|
277 |
+
R_t_i_addition, R_t_t_addition = apply_mm_attention_rules(self.R_t_t, self.R_i_i, self.R_i_t, cam_t_i,
|
278 |
+
apply_normalization=self.normalize_self_attention)
|
279 |
+
return R_t_i_addition, R_t_t_addition
|
280 |
+
|
281 |
+
def handle_co_attn_image(self, block):
|
282 |
+
if self.use_lrp:
|
283 |
+
cam_i_t = block.visual_attention_copy.att.get_attn_cam().detach()
|
284 |
+
else:
|
285 |
+
cam_i_t = block.visual_attention_copy.att.get_attn().detach()
|
286 |
+
grad_i_t = block.visual_attention_copy.att.get_attn_gradients().detach()
|
287 |
+
cam_i_t = avg_heads(cam_i_t, grad_i_t)
|
288 |
+
R_i_t_addition, R_i_i_addition = apply_mm_attention_rules(self.R_i_i, self.R_t_t, self.R_t_i, cam_i_t,
|
289 |
+
apply_normalization=self.normalize_self_attention)
|
290 |
+
return R_i_t_addition, R_i_i_addition
|
291 |
+
|
292 |
+
def generate_ours_no_agg(self, input, index=None, use_lrp=False, normalize_self_attention=True,
|
293 |
+
method_name="ours_no_agg"):
|
294 |
+
self.use_lrp = use_lrp
|
295 |
+
self.normalize_self_attention = normalize_self_attention
|
296 |
+
kwargs = {"alpha": 1}
|
297 |
+
output = self.model_usage.forward(input).question_answering_score
|
298 |
+
model = self.model_usage.model
|
299 |
+
|
300 |
+
# initialize relevancy matrices
|
301 |
+
text_tokens = self.model_usage.text_len
|
302 |
+
image_bboxes = self.model_usage.image_boxes_len
|
303 |
+
|
304 |
+
# text self attention matrix
|
305 |
+
self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
|
306 |
+
# image self attention matrix
|
307 |
+
self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
|
308 |
+
# impact of images on text
|
309 |
+
self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
|
310 |
+
# impact of text on images
|
311 |
+
self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
|
312 |
+
|
313 |
+
if index is None:
|
314 |
+
index = np.argmax(output.cpu().data.numpy(), axis=-1)
|
315 |
+
|
316 |
+
one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
|
317 |
+
one_hot[0, index] = 1
|
318 |
+
one_hot_vector = one_hot
|
319 |
+
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
320 |
+
one_hot = torch.sum(one_hot.cuda() * output)
|
321 |
+
|
322 |
+
model.zero_grad()
|
323 |
+
one_hot.backward(retain_graph=True)
|
324 |
+
if self.use_lrp:
|
325 |
+
model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs)
|
326 |
+
|
327 |
+
# language self attention
|
328 |
+
blocks = model.lxmert.encoder.layer
|
329 |
+
self.handle_self_attention_lang(blocks)
|
330 |
+
|
331 |
+
# image self attention
|
332 |
+
blocks = model.lxmert.encoder.r_layers
|
333 |
+
self.handle_self_attention_image(blocks)
|
334 |
+
|
335 |
+
# cross attn layers
|
336 |
+
blocks = model.lxmert.encoder.x_layers
|
337 |
+
for i, blk in enumerate(blocks):
|
338 |
+
# in the last cross attention module, only the text cross modal
|
339 |
+
# attention has an impact on the CLS token, since it's the first
|
340 |
+
# token in the language tokens
|
341 |
+
if i == len(blocks) - 1:
|
342 |
+
break
|
343 |
+
# cross attn- first for language then for image
|
344 |
+
R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk)
|
345 |
+
R_i_t_addition, R_i_i_addition = self.handle_co_attn_image(blk)
|
346 |
+
|
347 |
+
self.R_t_i = R_t_i_addition
|
348 |
+
self.R_t_t = R_t_t_addition
|
349 |
+
self.R_i_t = R_i_t_addition
|
350 |
+
self.R_i_i = R_i_i_addition
|
351 |
+
|
352 |
+
# language self attention
|
353 |
+
self.handle_co_attn_self_lang(blk)
|
354 |
+
|
355 |
+
# image self attention
|
356 |
+
self.handle_co_attn_self_image(blk)
|
357 |
+
|
358 |
+
# take care of last cross attention layer- only text
|
359 |
+
blk = model.lxmert.encoder.x_layers[-1]
|
360 |
+
# cross attn- first for language then for image
|
361 |
+
R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk)
|
362 |
+
self.R_t_i = R_t_i_addition
|
363 |
+
self.R_t_t = R_t_t_addition
|
364 |
+
|
365 |
+
# language self attention
|
366 |
+
self.handle_co_attn_self_lang(blk)
|
367 |
+
|
368 |
+
# disregard the [CLS] token itself
|
369 |
+
self.R_t_t[0, 0] = 0
|
370 |
+
return self.R_t_t, self.R_t_i
|
371 |
+
|
372 |
+
|
373 |
+
class GeneratorBaselines:
|
374 |
+
def __init__(self, model_usage, save_visualization=False):
|
375 |
+
self.model_usage = model_usage
|
376 |
+
self.save_visualization = save_visualization
|
377 |
+
|
378 |
+
def generate_transformer_attr(self, input, index=None, method_name="transformer_attr"):
|
379 |
+
kwargs = {"alpha": 1}
|
380 |
+
output = self.model_usage.forward(input).question_answering_score
|
381 |
+
model = self.model_usage.model
|
382 |
+
|
383 |
+
# initialize relevancy matrices
|
384 |
+
text_tokens = self.model_usage.text_len
|
385 |
+
image_bboxes = self.model_usage.image_boxes_len
|
386 |
+
|
387 |
+
# text self attention matrix
|
388 |
+
self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
|
389 |
+
# image self attention matrix
|
390 |
+
self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
|
391 |
+
# impact of images on text
|
392 |
+
self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
|
393 |
+
# impact of text on images
|
394 |
+
self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
|
395 |
+
|
396 |
+
if index == None:
|
397 |
+
index = np.argmax(output.cpu().data.numpy(), axis=-1)
|
398 |
+
|
399 |
+
one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
|
400 |
+
one_hot[0, index] = 1
|
401 |
+
one_hot_vector = one_hot
|
402 |
+
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
403 |
+
one_hot = torch.sum(one_hot.cuda() * output)
|
404 |
+
|
405 |
+
model.zero_grad()
|
406 |
+
one_hot.backward(retain_graph=True)
|
407 |
+
model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs)
|
408 |
+
|
409 |
+
# language self attention
|
410 |
+
blocks = model.lxmert.encoder.layer
|
411 |
+
for blk in blocks:
|
412 |
+
grad = blk.attention.self.get_attn_gradients().detach()
|
413 |
+
cam = blk.attention.self.get_attn_cam().detach()
|
414 |
+
cam = avg_heads(cam, grad)
|
415 |
+
self.R_t_t += torch.matmul(cam, self.R_t_t)
|
416 |
+
|
417 |
+
# image self attention
|
418 |
+
blocks = model.lxmert.encoder.r_layers
|
419 |
+
for blk in blocks:
|
420 |
+
grad = blk.attention.self.get_attn_gradients().detach()
|
421 |
+
cam = blk.attention.self.get_attn_cam().detach()
|
422 |
+
cam = avg_heads(cam, grad)
|
423 |
+
self.R_i_i += torch.matmul(cam, self.R_i_i)
|
424 |
+
|
425 |
+
# cross attn layers
|
426 |
+
blocks = model.lxmert.encoder.x_layers
|
427 |
+
for i, blk in enumerate(blocks):
|
428 |
+
# in the last cross attention module, only the text cross modal
|
429 |
+
# attention has an impact on the CLS token, since it's the first
|
430 |
+
# token in the language tokens
|
431 |
+
if i == len(blocks) - 1:
|
432 |
+
break
|
433 |
+
|
434 |
+
# language self attention
|
435 |
+
grad = blk.lang_self_att.self.get_attn_gradients().detach()
|
436 |
+
cam = blk.lang_self_att.self.get_attn_cam().detach()
|
437 |
+
cam = avg_heads(cam, grad)
|
438 |
+
self.R_t_t += torch.matmul(cam, self.R_t_t)
|
439 |
+
|
440 |
+
# image self attention
|
441 |
+
grad = blk.visn_self_att.self.get_attn_gradients().detach()
|
442 |
+
cam = blk.visn_self_att.self.get_attn_cam().detach()
|
443 |
+
cam = avg_heads(cam, grad)
|
444 |
+
self.R_i_i += torch.matmul(cam, self.R_i_i)
|
445 |
+
|
446 |
+
# take care of last cross attention layer - only text
|
447 |
+
blk = model.lxmert.encoder.x_layers[-1]
|
448 |
+
# cross attn cam will be the one used for the R_t_i matrix
|
449 |
+
cam_t_i = blk.visual_attention.att.get_attn_cam().detach()
|
450 |
+
grad_t_i = blk.visual_attention.att.get_attn_gradients().detach()
|
451 |
+
cam_t_i = avg_heads(cam_t_i, grad_t_i)
|
452 |
+
# self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i))
|
453 |
+
self.R_t_i = cam_t_i
|
454 |
+
|
455 |
+
# language self attention
|
456 |
+
grad = blk.lang_self_att.self.get_attn_gradients().detach()
|
457 |
+
cam = blk.lang_self_att.self.get_attn_cam().detach()
|
458 |
+
cam = avg_heads(cam, grad)
|
459 |
+
self.R_t_t += torch.matmul(cam, self.R_t_t)
|
460 |
+
|
461 |
+
self.R_t_t[0, 0] = 0
|
462 |
+
return self.R_t_t, self.R_t_i
|
463 |
+
|
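The loops above implement the Transformer-Attribution update: each layer's attention map is weighted by its gradients, head-averaged, and multiplied into the running relevancy matrix. A minimal, self-contained sketch of that rule on dummy tensors follows; `avg_heads_sketch` is restated here only for illustration and is assumed to behave like the `avg_heads` helper defined earlier in this file.

import torch

def avg_heads_sketch(cam, grad):
    # Gradient-weight each head's attention map, clamp negatives, then average over heads.
    cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
    grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
    return (grad * cam).clamp(min=0).mean(dim=0)

num_heads, seq_len = 12, 20                     # dummy sizes: 12 heads, 20 text tokens
R_t_t = torch.eye(seq_len)                      # relevancy starts as the identity
for _ in range(9):                              # e.g. the 9 language self-attention layers
    cam = torch.rand(num_heads, seq_len, seq_len)
    grad = torch.rand(num_heads, seq_len, seq_len)
    R_t_t = R_t_t + avg_heads_sketch(cam, grad) @ R_t_t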
464 |
+
def generate_partial_lrp(self, input, index=None, method_name="partial_lrp"):
|
465 |
+
kwargs = {"alpha": 1}
|
466 |
+
output = self.model_usage.forward(input).question_answering_score
|
467 |
+
model = self.model_usage.model
|
468 |
+
|
469 |
+
# initialize relevancy matrices
|
470 |
+
text_tokens = self.model_usage.text_len
|
471 |
+
image_bboxes = self.model_usage.image_boxes_len
|
472 |
+
|
473 |
+
# text self attention matrix
|
474 |
+
self.R_t_t = torch.zeros(text_tokens, text_tokens).to(model.device)
|
475 |
+
# image self attention matrix
|
476 |
+
self.R_i_i = torch.zeros(image_bboxes, image_bboxes).to(model.device)
|
477 |
+
# impact of images on text
|
478 |
+
self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
|
479 |
+
# impact of text on images
|
480 |
+
self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
|
481 |
+
|
482 |
+
if index is None:
|
483 |
+
index = np.argmax(output.cpu().data.numpy(), axis=-1)
|
484 |
+
|
485 |
+
one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
|
486 |
+
one_hot[0, index] = 1
|
487 |
+
one_hot_vector = one_hot
|
488 |
+
model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs)
|
489 |
+
|
490 |
+
# last cross attention + self-attention layer
|
491 |
+
blk = model.lxmert.encoder.x_layers[-1]
|
492 |
+
# cross attn cam will be the one used for the R_t_i matrix
|
493 |
+
cam_t_i = blk.visual_attention.att.get_attn_cam().detach()
|
494 |
+
cam_t_i = cam_t_i.reshape(-1, cam_t_i.shape[-2], cam_t_i.shape[-1]).mean(dim=0)
|
495 |
+
self.R_t_i = cam_t_i
|
496 |
+
|
497 |
+
# language self attention
|
498 |
+
cam = blk.lang_self_att.self.get_attn_cam().detach()
|
499 |
+
cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
|
500 |
+
self.R_t_t = cam
|
501 |
+
|
502 |
+
# normalize to get non-negative cams
|
503 |
+
self.R_t_t = (self.R_t_t - self.R_t_t.min()) / (self.R_t_t.max() - self.R_t_t.min())
|
504 |
+
self.R_t_i = (self.R_t_i - self.R_t_i.min()) / (self.R_t_i.max() - self.R_t_i.min())
|
505 |
+
# disregard the [CLS] token itself
|
506 |
+
self.R_t_t[0, 0] = 0
|
507 |
+
return self.R_t_t, self.R_t_i
|
508 |
+
|
509 |
+
def generate_raw_attn(self, input, method_name="raw_attention"):
|
510 |
+
output = self.model_usage.forward(input).question_answering_score
|
511 |
+
model = self.model_usage.model
|
512 |
+
|
513 |
+
# initialize relevancy matrices
|
514 |
+
text_tokens = self.model_usage.text_len
|
515 |
+
image_bboxes = self.model_usage.image_boxes_len
|
516 |
+
|
517 |
+
# text self attention matrix
|
518 |
+
self.R_t_t = torch.zeros(text_tokens, text_tokens).to(model.device)
|
519 |
+
# image self attention matrix
|
520 |
+
self.R_i_i = torch.zeros(image_bboxes, image_bboxes).to(model.device)
|
521 |
+
# impact of images on text
|
522 |
+
self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
|
523 |
+
# impact of text on images
|
524 |
+
self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
|
525 |
+
|
526 |
+
# last cross attention + self-attention layer
|
527 |
+
blk = model.lxmert.encoder.x_layers[-1]
|
528 |
+
# cross attn cam will be the one used for the R_t_i matrix
|
529 |
+
cam_t_i = blk.visual_attention.att.get_attn().detach()
|
530 |
+
cam_t_i = cam_t_i.reshape(-1, cam_t_i.shape[-2], cam_t_i.shape[-1]).mean(dim=0)
|
531 |
+
# self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i))
|
532 |
+
self.R_t_i = cam_t_i
|
533 |
+
|
534 |
+
# language self attention
|
535 |
+
cam = blk.lang_self_att.self.get_attn().detach()
|
536 |
+
cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
|
537 |
+
self.R_t_t = cam
|
538 |
+
|
539 |
+
# disregard the [CLS] token itself
|
540 |
+
self.R_t_t[0, 0] = 0
|
541 |
+
return self.R_t_t, self.R_t_i
|
542 |
+
|
543 |
+
def gradcam(self, cam, grad):
|
544 |
+
cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
|
545 |
+
grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
|
546 |
+
grad = grad.mean(dim=[1, 2], keepdim=True)
|
547 |
+
cam = (cam * grad).mean(0).clamp(min=0)
|
548 |
+
return cam
|
549 |
+
|
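The `gradcam` helper above collapses a multi-head attention map into a single matrix: it pools the gradients into one scalar weight per head, multiplies that weight into the corresponding attention map, averages over heads, and clamps negative values. A quick shape check on dummy tensors (illustration only, not part of the original file):

import torch

num_heads, rows, cols = 12, 20, 36            # e.g. text tokens x image boxes
cam = torch.rand(num_heads, rows, cols)       # attention probabilities
grad = torch.rand(num_heads, rows, cols)      # gradients of the answer logit w.r.t. the attention

head_weight = grad.mean(dim=[1, 2], keepdim=True)       # one scalar weight per head
reduced = (cam * head_weight).mean(0).clamp(min=0)      # head-averaged, non-negative map
assert reduced.shape == (rows, cols)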
550 |
+
def generate_attn_gradcam(self, input, index=None, method_name="gradcam"):
|
551 |
+
output = self.model_usage.forward(input).question_answering_score
|
552 |
+
model = self.model_usage.model
|
553 |
+
|
554 |
+
# initialize relevancy matrices
|
555 |
+
text_tokens = self.model_usage.text_len
|
556 |
+
image_bboxes = self.model_usage.image_boxes_len
|
557 |
+
|
558 |
+
# text self attention matrix
|
559 |
+
self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
|
560 |
+
# image self attention matrix
|
561 |
+
self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
|
562 |
+
# impact of images on text
|
563 |
+
self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
|
564 |
+
# impact of text on images
|
565 |
+
self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
|
566 |
+
|
567 |
+
if index is None:
|
568 |
+
index = np.argmax(output.cpu().data.numpy(), axis=-1)
|
569 |
+
|
570 |
+
one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
|
571 |
+
one_hot[0, index] = 1
|
572 |
+
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
573 |
+
one_hot = torch.sum(one_hot.cuda() * output)
|
574 |
+
|
575 |
+
model.zero_grad()
|
576 |
+
one_hot.backward(retain_graph=True)
|
577 |
+
|
578 |
+
# last cross attention + self-attention layer
|
579 |
+
blk = model.lxmert.encoder.x_layers[-1]
|
580 |
+
# cross attn cam will be the one used for the R_t_i matrix
|
581 |
+
grad_t_i = blk.visual_attention.att.get_attn_gradients().detach()
|
582 |
+
cam_t_i = blk.visual_attention.att.get_attn().detach()
|
583 |
+
cam_t_i = self.gradcam(cam_t_i, grad_t_i)
|
584 |
+
# self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i))
|
585 |
+
self.R_t_i = cam_t_i
|
586 |
+
|
587 |
+
# language self attention
|
588 |
+
grad = blk.lang_self_att.self.get_attn_gradients().detach()
|
589 |
+
cam = blk.lang_self_att.self.get_attn().detach()
|
590 |
+
self.R_t_t = self.gradcam(cam, grad)
|
591 |
+
|
592 |
+
# disregard the [CLS] token itself
|
593 |
+
self.R_t_t[0, 0] = 0
|
594 |
+
return self.R_t_t, self.R_t_i
|
595 |
+
|
596 |
+
def generate_rollout(self, input, method_name="rollout"):
|
597 |
+
output = self.model_usage.forward(input).question_answering_score
|
598 |
+
model = self.model_usage.model
|
599 |
+
|
600 |
+
# initialize relevancy matrices
|
601 |
+
text_tokens = self.model_usage.text_len
|
602 |
+
image_bboxes = self.model_usage.image_boxes_len
|
603 |
+
|
604 |
+
# text self attention matrix
|
605 |
+
self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
|
606 |
+
# image self attention matrix
|
607 |
+
self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
|
608 |
+
# impact of images on text
|
609 |
+
self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
|
610 |
+
# impact of text on images
|
611 |
+
self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
|
612 |
+
|
613 |
+
cams_text = []
|
614 |
+
cams_image = []
|
615 |
+
# language self attention
|
616 |
+
blocks = model.lxmert.encoder.layer
|
617 |
+
for blk in blocks:
|
618 |
+
cam = blk.attention.self.get_attn().detach()
|
619 |
+
cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
|
620 |
+
cams_text.append(cam)
|
621 |
+
|
622 |
+
# image self attention
|
623 |
+
blocks = model.lxmert.encoder.r_layers
|
624 |
+
for blk in blocks:
|
625 |
+
cam = blk.attention.self.get_attn().detach()
|
626 |
+
cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
|
627 |
+
cams_image.append(cam)
|
628 |
+
|
629 |
+
# cross attn layers
|
630 |
+
blocks = model.lxmert.encoder.x_layers
|
631 |
+
for i, blk in enumerate(blocks):
|
632 |
+
# in the last cross attention module, only the text cross modal
|
633 |
+
# attention has an impact on the CLS token, since it's the first
|
634 |
+
# token in the language tokens
|
635 |
+
if i == len(blocks) - 1:
|
636 |
+
break
|
637 |
+
|
638 |
+
# language self attention
|
639 |
+
cam = blk.lang_self_att.self.get_attn().detach()
|
640 |
+
cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
|
641 |
+
cams_text.append(cam)
|
642 |
+
|
643 |
+
# image self attention
|
644 |
+
cam = blk.visn_self_att.self.get_attn().detach()
|
645 |
+
cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
|
646 |
+
cams_image.append(cam)
|
647 |
+
|
648 |
+
# take care of last cross attention layer - only text
|
649 |
+
blk = model.lxmert.encoder.x_layers[-1]
|
650 |
+
# cross attn cam will be the one used for the R_t_i matrix
|
651 |
+
cam_t_i = blk.visual_attention.att.get_attn().detach()
|
652 |
+
cam_t_i = cam_t_i.reshape(-1, cam_t_i.shape[-2], cam_t_i.shape[-1]).mean(dim=0)
|
653 |
+
self.R_t_t = compute_rollout_attention(copy.deepcopy(cams_text))
|
654 |
+
self.R_i_i = compute_rollout_attention(cams_image)
|
655 |
+
self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i))
|
656 |
+
# language self attention
|
657 |
+
cam = blk.lang_self_att.self.get_attn().detach()
|
658 |
+
cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
|
659 |
+
cams_text.append(cam)
|
660 |
+
|
661 |
+
self.R_t_t = compute_rollout_attention(cams_text)
|
662 |
+
|
663 |
+
# disregard the [CLS] token itself
|
664 |
+
self.R_t_t[0, 0] = 0
|
665 |
+
return self.R_t_t, self.R_t_i
|
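generate_rollout relies on `compute_rollout_attention`, which is defined earlier in this file. For reference, the sketch below shows a common formulation of attention rollout, assumed here to match that helper: each layer's head-averaged attention is mixed with the identity (to model the residual connection), row-normalized, and multiplied across layers.

import torch

def rollout_sketch(attn_maps):
    # attn_maps: list of (seq_len, seq_len) head-averaged attention matrices, one per layer.
    eye = torch.eye(attn_maps[0].shape[-1])
    joint = eye
    for a in attn_maps:
        a = 0.5 * a + 0.5 * eye                 # account for the residual connection
        a = a / a.sum(dim=-1, keepdim=True)     # re-normalize each row
        joint = a @ joint
    return joint

maps = [torch.softmax(torch.rand(20, 20), dim=-1) for _ in range(5)]
print(rollout_sketch(maps).shape)               # torch.Size([20, 20])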
lxmert/src/__init__.py
ADDED
File without changes
|
lxmert/src/__pycache__/ExplanationGenerator.cpython-38.pyc
ADDED
Binary file (15.3 kB)
|
lxmert/src/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (148 Bytes)
|
lxmert/src/__pycache__/huggingface_lxmert.cpython-38.pyc
ADDED
Binary file (47.6 kB)
|
lxmert/src/__pycache__/layers.cpython-38.pyc
ADDED
Binary file (10.5 kB)
|
lxmert/src/__pycache__/lxmert_lrp.cpython-38.pyc
ADDED
Binary file (53.6 kB)
|
lxmert/src/__pycache__/modeling_frcnn.cpython-38.pyc
ADDED
Binary file (56.8 kB)
|
lxmert/src/__pycache__/processing_image.cpython-38.pyc
ADDED
Binary file (5.73 kB)
|
lxmert/src/__pycache__/vqa_utils.cpython-38.pyc
ADDED
Binary file (14.4 kB)
|
lxmert/src/huggingface_lxmert.py
ADDED
@@ -0,0 +1,1472 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 Hao Tan, Mohit Bansal, and the HuggingFace team
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch lxmert model. """
|
16 |
+
|
17 |
+
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import warnings
|
21 |
+
from dataclasses import dataclass
|
22 |
+
from typing import Optional, Tuple
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import CrossEntropyLoss, SmoothL1Loss
|
27 |
+
|
28 |
+
from transformers.activations import ACT2FN, gelu
|
29 |
+
from transformers.file_utils import (
|
30 |
+
ModelOutput,
|
31 |
+
add_code_sample_docstrings,
|
32 |
+
add_start_docstrings,
|
33 |
+
add_start_docstrings_to_model_forward,
|
34 |
+
replace_return_docstrings,
|
35 |
+
)
|
36 |
+
from transformers.modeling_utils import PreTrainedModel
|
37 |
+
from transformers.utils import logging
|
38 |
+
from transformers.configuration_lxmert import LxmertConfig
|
39 |
+
|
40 |
+
|
41 |
+
logger = logging.get_logger(__name__)
|
42 |
+
|
43 |
+
_CONFIG_FOR_DOC = "LxmertConfig"
|
44 |
+
_TOKENIZER_FOR_DOC = "LxmertTokenizer"
|
45 |
+
|
46 |
+
LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
47 |
+
"unc-nlp/lxmert-base-uncased",
|
48 |
+
]
|
49 |
+
|
50 |
+
|
51 |
+
class GeLU(nn.Module):
|
52 |
+
def __init__(self):
|
53 |
+
super().__init__()
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
return gelu(x)
|
57 |
+
|
58 |
+
|
59 |
+
@dataclass
|
60 |
+
class LxmertModelOutput(ModelOutput):
|
61 |
+
"""
|
62 |
+
Lxmert's outputs that contain the last hidden states, pooled outputs, and attention probabilities for the language,
|
63 |
+
visual, and cross-modality encoders. (note: the visual encoder in Lxmert is referred to as the "relation-ship"
|
64 |
+
encoder")
|
65 |
+
|
66 |
+
|
67 |
+
Args:
|
68 |
+
language_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
69 |
+
Sequence of hidden-states at the output of the last layer of the language encoder.
|
70 |
+
vision_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
71 |
+
Sequence of hidden-states at the output of the last layer of the visual encoder.
|
72 |
+
pooled_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
|
73 |
+
Last layer hidden-state of the first token of the sequence (classification, CLS, token) further processed
|
74 |
+
by a Linear layer and a Tanh activation function.
|
75 |
+
language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
76 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
77 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
78 |
+
vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
79 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
80 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
81 |
+
language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
82 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
83 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
84 |
+
weighted average in the self-attention heads.
|
85 |
+
vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
86 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
87 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
88 |
+
weighted average in the self-attention heads.
|
89 |
+
cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
90 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
91 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
92 |
+
weighted average in the self-attention heads.
|
93 |
+
"""
|
94 |
+
|
95 |
+
language_output: Optional[torch.FloatTensor] = None
|
96 |
+
vision_output: Optional[torch.FloatTensor] = None
|
97 |
+
pooled_output: Optional[torch.FloatTensor] = None
|
98 |
+
language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
99 |
+
vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
100 |
+
language_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
101 |
+
vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
102 |
+
cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
103 |
+
|
104 |
+
|
105 |
+
@dataclass
|
106 |
+
class LxmertForQuestionAnsweringOutput(ModelOutput):
|
107 |
+
"""
|
108 |
+
Output type of :class:`~transformers.LxmertForQuestionAnswering`.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
112 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction
|
113 |
+
(classification) loss.
|
114 |
+
question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`, `optional`):
|
115 |
+
Prediction scores of question answering objective (classification).
|
116 |
+
language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
117 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
118 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
119 |
+
vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
120 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
121 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
122 |
+
language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
123 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
124 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
125 |
+
weighted average in the self-attention heads.
|
126 |
+
vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
127 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
128 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
129 |
+
weighted average in the self-attention heads.
|
130 |
+
cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
131 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
132 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
133 |
+
weighted average in the self-attention heads.
|
134 |
+
"""
|
135 |
+
|
136 |
+
loss: Optional[torch.FloatTensor] = None
|
137 |
+
question_answering_score: Optional[torch.FloatTensor] = None
|
138 |
+
language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
139 |
+
vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
140 |
+
language_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
141 |
+
vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
142 |
+
cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
143 |
+
|
144 |
+
|
145 |
+
@dataclass
|
146 |
+
class LxmertForPreTrainingOutput(ModelOutput):
|
147 |
+
"""
|
148 |
+
Output type of :class:`~transformers.LxmertForPreTraining`.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
152 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction
|
153 |
+
(classification) loss.
|
154 |
+
prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
155 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
156 |
+
cross_relationship_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
157 |
+
Prediction scores of the textual matching objective (classification) head (scores of True/False
|
158 |
+
continuation before SoftMax).
|
159 |
+
question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`):
|
160 |
+
Prediction scores of question answering objective (classification).
|
161 |
+
language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
162 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
163 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
164 |
+
vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
165 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
166 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
167 |
+
language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
168 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
169 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
170 |
+
weighted average in the self-attention heads.
|
171 |
+
vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
172 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
173 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
174 |
+
weighted average in the self-attention heads.
|
175 |
+
cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
176 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
177 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
178 |
+
weighted average in the self-attention heads.
|
179 |
+
|
180 |
+
"""
|
181 |
+
|
182 |
+
loss: Optional[torch.FloatTensor] = None
|
183 |
+
prediction_logits: Optional[torch.FloatTensor] = None
|
184 |
+
cross_relationship_score: Optional[torch.FloatTensor] = None
|
185 |
+
question_answering_score: Optional[torch.FloatTensor] = None
|
186 |
+
language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
187 |
+
vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
188 |
+
language_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
189 |
+
vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
190 |
+
cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
191 |
+
|
192 |
+
|
193 |
+
def load_tf_weights_in_lxmert(model, config, tf_checkpoint_path):
|
194 |
+
"""Load tf checkpoints in a pytorch model."""
|
195 |
+
try:
|
196 |
+
import re
|
197 |
+
|
198 |
+
import numpy as np
|
199 |
+
import tensorflow as tf
|
200 |
+
except ImportError:
|
201 |
+
logger.error(
|
202 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
203 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
204 |
+
)
|
205 |
+
raise
|
206 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
207 |
+
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
208 |
+
# Load weights from TF model
|
209 |
+
init_vars = tf.train.list_variables(tf_path)
|
210 |
+
names = []
|
211 |
+
arrays = []
|
212 |
+
for name, shape in init_vars:
|
213 |
+
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
214 |
+
array = tf.train.load_variable(tf_path, name)
|
215 |
+
names.append(name)
|
216 |
+
arrays.append(array)
|
217 |
+
|
218 |
+
for name, array in zip(names, arrays):
|
219 |
+
name = name.split("/")
|
220 |
+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
221 |
+
# which are not required for using pretrained model
|
222 |
+
if any(
|
223 |
+
n
|
224 |
+
in [
|
225 |
+
"adam_v",
|
226 |
+
"adam_m",
|
227 |
+
"AdamWeightDecayOptimizer",
|
228 |
+
"AdamWeightDecayOptimizer_1",
|
229 |
+
"global_step",
|
230 |
+
]
|
231 |
+
for n in name
|
232 |
+
):
|
233 |
+
logger.info("Skipping {}".format("/".join(name)))
|
234 |
+
continue
|
235 |
+
pointer = model
|
236 |
+
for m_name in name:
|
237 |
+
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
238 |
+
scope_names = re.split(r"_(\d+)", m_name)
|
239 |
+
else:
|
240 |
+
scope_names = [m_name]
|
241 |
+
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
242 |
+
pointer = getattr(pointer, "weight")
|
243 |
+
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
244 |
+
pointer = getattr(pointer, "bias")
|
245 |
+
elif scope_names[0] == "output_weights":
|
246 |
+
pointer = getattr(pointer, "weight")
|
247 |
+
elif scope_names[0] == "squad":
|
248 |
+
pointer = getattr(pointer, "classifier")
|
249 |
+
else:
|
250 |
+
try:
|
251 |
+
pointer = getattr(pointer, scope_names[0])
|
252 |
+
except AttributeError:
|
253 |
+
logger.info("Skipping {}".format("/".join(name)))
|
254 |
+
continue
|
255 |
+
if len(scope_names) >= 2:
|
256 |
+
num = int(scope_names[1])
|
257 |
+
pointer = pointer[num]
|
258 |
+
if m_name[-11:] == "_embeddings":
|
259 |
+
pointer = getattr(pointer, "weight")
|
260 |
+
elif m_name == "kernel":
|
261 |
+
array = np.transpose(array)
|
262 |
+
try:
|
263 |
+
assert pointer.shape == array.shape
|
264 |
+
except AssertionError as e:
|
265 |
+
e.args += (pointer.shape, array.shape)
|
266 |
+
raise
|
267 |
+
logger.info("Initialize PyTorch weight {}".format(name))
|
268 |
+
pointer.data = torch.from_numpy(array)
|
269 |
+
return model
|
270 |
+
|
271 |
+
|
272 |
+
class LxmertEmbeddings(nn.Module):
|
273 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
274 |
+
|
275 |
+
def __init__(self, config):
|
276 |
+
super().__init__()
|
277 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
278 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
|
279 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0)
|
280 |
+
|
281 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
282 |
+
# any TensorFlow checkpoint file
|
283 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
|
284 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
285 |
+
|
286 |
+
def forward(self, input_ids, token_type_ids=None, inputs_embeds=None):
|
287 |
+
if input_ids is not None:
|
288 |
+
input_shape = input_ids.size()
|
289 |
+
device = input_ids.device
|
290 |
+
else:
|
291 |
+
input_shape = inputs_embeds.size()[:-1]
|
292 |
+
device = inputs_embeds.device
|
293 |
+
seq_length = input_shape[1]
|
294 |
+
|
295 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
296 |
+
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
297 |
+
|
298 |
+
if token_type_ids is None:
|
299 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
300 |
+
|
301 |
+
if inputs_embeds is None:
|
302 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
303 |
+
position_embeddings = self.position_embeddings(position_ids)
|
304 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
305 |
+
|
306 |
+
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
307 |
+
embeddings = self.LayerNorm(embeddings)
|
308 |
+
embeddings = self.dropout(embeddings)
|
309 |
+
return embeddings
|
310 |
+
|
311 |
+
|
312 |
+
class LxmertAttention(nn.Module):
|
313 |
+
def __init__(self, config, ctx_dim=None, save_cams=False):
|
314 |
+
super().__init__()
|
315 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
316 |
+
raise ValueError(
|
317 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
318 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
319 |
+
)
|
320 |
+
self.num_attention_heads = config.num_attention_heads
|
321 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
322 |
+
self.head_size = self.num_attention_heads * self.attention_head_size
|
323 |
+
|
324 |
+
# visual_dim = 2048
|
325 |
+
if ctx_dim is None:
|
326 |
+
ctx_dim = config.hidden_size
|
327 |
+
self.query = nn.Linear(config.hidden_size, self.head_size)
|
328 |
+
self.key = nn.Linear(ctx_dim, self.head_size)
|
329 |
+
self.value = nn.Linear(ctx_dim, self.head_size)
|
330 |
+
|
331 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
332 |
+
|
333 |
+
self.save_cams = save_cams
|
334 |
+
self.attn = None
|
335 |
+
self.attn_gradients = None
|
336 |
+
|
337 |
+
def get_attn(self):
|
338 |
+
ret = self.attn
|
339 |
+
self.attn = None
|
340 |
+
return ret
|
341 |
+
|
342 |
+
def save_attn(self, attn):
|
343 |
+
if self.attn is not None:
|
344 |
+
self.attn = [self.attn, attn]
|
345 |
+
else:
|
346 |
+
self.attn = attn
|
347 |
+
|
348 |
+
def save_attn_gradients(self, attn_gradients):
|
349 |
+
if self.attn_gradients is not None:
|
350 |
+
self.attn_gradients = [self.attn_gradients, attn_gradients]
|
351 |
+
else:
|
352 |
+
self.attn_gradients = attn_gradients
|
353 |
+
|
354 |
+
def get_attn_gradients(self):
|
355 |
+
ret = self.attn_gradients
|
356 |
+
self.attn_gradients = None
|
357 |
+
return ret
|
358 |
+
|
359 |
+
def reset(self):
|
360 |
+
self.attn = None
|
361 |
+
self.attn_gradients = None
|
362 |
+
|
363 |
+
def transpose_for_scores(self, x):
|
364 |
+
new_x_shape = x.size()[:-1] + (
|
365 |
+
self.num_attention_heads,
|
366 |
+
self.attention_head_size,
|
367 |
+
)
|
368 |
+
x = x.view(*new_x_shape)
|
369 |
+
return x.permute(0, 2, 1, 3)
|
370 |
+
|
371 |
+
def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
|
372 |
+
mixed_query_layer = self.query(hidden_states)
|
373 |
+
mixed_key_layer = self.key(context)
|
374 |
+
mixed_value_layer = self.value(context)
|
375 |
+
|
376 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
377 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
378 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
379 |
+
|
380 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
381 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
382 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
383 |
+
# Apply the attention mask (precomputed for all layers in the LxmertModel forward() function)
|
384 |
+
if attention_mask is not None:
|
385 |
+
attention_scores = attention_scores + attention_mask
|
386 |
+
|
387 |
+
# Normalize the attention scores to probabilities.
|
388 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
389 |
+
|
390 |
+
# if self.save_cams:
|
391 |
+
self.save_attn(attention_probs)
|
392 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
393 |
+
|
394 |
+
# This is actually dropping out entire tokens to attend to, which might
|
395 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
396 |
+
attention_probs = self.dropout(attention_probs)
|
397 |
+
|
398 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
399 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
400 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
|
401 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
402 |
+
|
403 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
404 |
+
return outputs
|
405 |
+
|
406 |
+
|
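LxmertAttention stores each forward pass's softmax attention via `save_attn` and registers a backward hook so the matching gradients land in `save_attn_gradients`; the explanation generators later read both through `get_attn()` / `get_attn_gradients()`. A minimal, standalone sketch of that capture pattern (a toy module, not the Lxmert class itself):

import torch
import torch.nn as nn

class TinySelfAttention(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn = None
        self.attn_grad = None

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        probs = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
        self.attn = probs                                              # save the attention map
        probs.register_hook(lambda g: setattr(self, "attn_grad", g))   # save its gradient on backward
        return probs @ v

layer = TinySelfAttention()
out = layer(torch.rand(1, 5, 8))
out.sum().backward()
print(layer.attn.shape, layer.attn_grad.shape)   # both torch.Size([1, 5, 5])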
407 |
+
class LxmertAttentionOutput(nn.Module):
|
408 |
+
def __init__(self, config):
|
409 |
+
super().__init__()
|
410 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
411 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
|
412 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
413 |
+
|
414 |
+
def forward(self, hidden_states, input_tensor):
|
415 |
+
hidden_states = self.dense(hidden_states)
|
416 |
+
hidden_states = self.dropout(hidden_states)
|
417 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
418 |
+
return hidden_states
|
419 |
+
|
420 |
+
|
421 |
+
class LxmertCrossAttentionLayer(nn.Module):
|
422 |
+
def __init__(self, config, save_cams=False):
|
423 |
+
super().__init__()
|
424 |
+
self.att = LxmertAttention(config, save_cams=save_cams)
|
425 |
+
self.output = LxmertAttentionOutput(config)
|
426 |
+
|
427 |
+
def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None, output_attentions=False):
|
428 |
+
output = self.att(input_tensor, ctx_tensor, ctx_att_mask, output_attentions=output_attentions)
|
429 |
+
if output_attentions:
|
430 |
+
attention_probs = output[1]
|
431 |
+
attention_output = self.output(output[0], input_tensor)
|
432 |
+
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
433 |
+
return outputs
|
434 |
+
|
435 |
+
|
436 |
+
class LxmertSelfAttentionLayer(nn.Module):
|
437 |
+
def __init__(self, config, save_cams=False):
|
438 |
+
super().__init__()
|
439 |
+
self.self = LxmertAttention(config, save_cams=save_cams)
|
440 |
+
self.output = LxmertAttentionOutput(config)
|
441 |
+
|
442 |
+
def forward(self, input_tensor, attention_mask, output_attentions=False):
|
443 |
+
# Self attention attends to itself, thus keys and queries are the same (input_tensor).
|
444 |
+
output = self.self(
|
445 |
+
input_tensor,
|
446 |
+
input_tensor,
|
447 |
+
attention_mask,
|
448 |
+
output_attentions=output_attentions,
|
449 |
+
)
|
450 |
+
if output_attentions:
|
451 |
+
attention_probs = output[1]
|
452 |
+
attention_output = self.output(output[0], input_tensor)
|
453 |
+
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
454 |
+
return outputs
|
455 |
+
|
456 |
+
|
457 |
+
class LxmertIntermediate(nn.Module):
|
458 |
+
def __init__(self, config):
|
459 |
+
super().__init__()
|
460 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
461 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
462 |
+
|
463 |
+
def forward(self, hidden_states):
|
464 |
+
hidden_states = self.dense(hidden_states)
|
465 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
466 |
+
return hidden_states
|
467 |
+
|
468 |
+
|
469 |
+
class LxmertOutput(nn.Module):
|
470 |
+
def __init__(self, config):
|
471 |
+
super().__init__()
|
472 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
473 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
|
474 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
475 |
+
|
476 |
+
def forward(self, hidden_states, input_tensor):
|
477 |
+
hidden_states = self.dense(hidden_states)
|
478 |
+
hidden_states = self.dropout(hidden_states)
|
479 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
480 |
+
return hidden_states
|
481 |
+
|
482 |
+
|
483 |
+
class LxmertLayer(nn.Module):
|
484 |
+
def __init__(self, config, save_cams=False):
|
485 |
+
super().__init__()
|
486 |
+
self.attention = LxmertSelfAttentionLayer(config, save_cams=save_cams)
|
487 |
+
self.intermediate = LxmertIntermediate(config)
|
488 |
+
self.output = LxmertOutput(config)
|
489 |
+
|
490 |
+
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
491 |
+
outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
|
492 |
+
attention_output = outputs[0]
|
493 |
+
intermediate_output = self.intermediate(attention_output)
|
494 |
+
layer_output = self.output(intermediate_output, attention_output)
|
495 |
+
outputs = (layer_output,) + outputs[1:] # add attentions if we output them
|
496 |
+
return outputs
|
497 |
+
|
498 |
+
|
499 |
+
class LxmertXLayer(nn.Module):
|
500 |
+
def __init__(self, config, save_cams=False):
|
501 |
+
super().__init__()
|
502 |
+
# The cross-attention Layer
|
503 |
+
self.visual_attention = LxmertCrossAttentionLayer(config, save_cams=save_cams)
|
504 |
+
|
505 |
+
# Self-attention Layers
|
506 |
+
self.lang_self_att = LxmertSelfAttentionLayer(config)
|
507 |
+
self.visn_self_att = LxmertSelfAttentionLayer(config)
|
508 |
+
|
509 |
+
# Intermediate and Output Layers (FFNs)
|
510 |
+
self.lang_inter = LxmertIntermediate(config)
|
511 |
+
self.lang_output = LxmertOutput(config)
|
512 |
+
self.visn_inter = LxmertIntermediate(config)
|
513 |
+
self.visn_output = LxmertOutput(config)
|
514 |
+
|
515 |
+
def cross_att(
|
516 |
+
self,
|
517 |
+
lang_input,
|
518 |
+
lang_attention_mask,
|
519 |
+
visual_input,
|
520 |
+
visual_attention_mask,
|
521 |
+
output_x_attentions=False,
|
522 |
+
):
|
523 |
+
# Cross Attention
|
524 |
+
lang_att_output = self.visual_attention(
|
525 |
+
lang_input,
|
526 |
+
visual_input,
|
527 |
+
ctx_att_mask=visual_attention_mask,
|
528 |
+
output_attentions=output_x_attentions,
|
529 |
+
)
|
530 |
+
visual_att_output = self.visual_attention(
|
531 |
+
visual_input,
|
532 |
+
lang_input,
|
533 |
+
ctx_att_mask=lang_attention_mask,
|
534 |
+
output_attentions=False,
|
535 |
+
)
|
536 |
+
return lang_att_output, visual_att_output
|
537 |
+
|
538 |
+
def self_att(self, lang_input, lang_attention_mask, visual_input, visual_attention_mask):
|
539 |
+
# Self Attention
|
540 |
+
lang_att_output = self.lang_self_att(lang_input, lang_attention_mask, output_attentions=False)
|
541 |
+
visual_att_output = self.visn_self_att(visual_input, visual_attention_mask, output_attentions=False)
|
542 |
+
return lang_att_output[0], visual_att_output[0]
|
543 |
+
|
544 |
+
def output_fc(self, lang_input, visual_input):
|
545 |
+
# FC layers
|
546 |
+
lang_inter_output = self.lang_inter(lang_input)
|
547 |
+
visual_inter_output = self.visn_inter(visual_input)
|
548 |
+
|
549 |
+
# Layer output
|
550 |
+
lang_output = self.lang_output(lang_inter_output, lang_input)
|
551 |
+
visual_output = self.visn_output(visual_inter_output, visual_input)
|
552 |
+
|
553 |
+
return lang_output, visual_output
|
554 |
+
|
555 |
+
def forward(
|
556 |
+
self,
|
557 |
+
lang_feats,
|
558 |
+
lang_attention_mask,
|
559 |
+
visual_feats,
|
560 |
+
visual_attention_mask,
|
561 |
+
output_attentions=False,
|
562 |
+
):
|
563 |
+
|
564 |
+
lang_att_output, visual_att_output = self.cross_att(
|
565 |
+
lang_input=lang_feats,
|
566 |
+
lang_attention_mask=lang_attention_mask,
|
567 |
+
visual_input=visual_feats,
|
568 |
+
visual_attention_mask=visual_attention_mask,
|
569 |
+
output_x_attentions=output_attentions,
|
570 |
+
)
|
571 |
+
attention_probs = lang_att_output[1:]
|
572 |
+
lang_att_output, visual_att_output = self.self_att(
|
573 |
+
lang_att_output[0],
|
574 |
+
lang_attention_mask,
|
575 |
+
visual_att_output[0],
|
576 |
+
visual_attention_mask,
|
577 |
+
)
|
578 |
+
|
579 |
+
lang_output, visual_output = self.output_fc(lang_att_output, visual_att_output)
|
580 |
+
return (
|
581 |
+
(
|
582 |
+
lang_output,
|
583 |
+
visual_output,
|
584 |
+
attention_probs[0],
|
585 |
+
)
|
586 |
+
if output_attentions
|
587 |
+
else (lang_output, visual_output)
|
588 |
+
)
|
589 |
+
|
590 |
+
|
591 |
+
class LxmertVisualFeatureEncoder(nn.Module):
|
592 |
+
def __init__(self, config):
|
593 |
+
super().__init__()
|
594 |
+
feat_dim = config.visual_feat_dim
|
595 |
+
pos_dim = config.visual_pos_dim
|
596 |
+
|
597 |
+
# Object feature encoding
|
598 |
+
self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
|
599 |
+
self.visn_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
|
600 |
+
|
601 |
+
# Box position encoding
|
602 |
+
self.box_fc = nn.Linear(pos_dim, config.hidden_size)
|
603 |
+
self.box_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
|
604 |
+
|
605 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
606 |
+
|
607 |
+
def forward(self, visual_feats, visual_pos):
|
608 |
+
x = self.visn_fc(visual_feats)
|
609 |
+
x = self.visn_layer_norm(x)
|
610 |
+
y = self.box_fc(visual_pos)
|
611 |
+
y = self.box_layer_norm(y)
|
612 |
+
output = (x + y) / 2
|
613 |
+
|
614 |
+
output = self.dropout(output)
|
615 |
+
return output
|
616 |
+
|
617 |
+
|
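LxmertVisualFeatureEncoder projects the ROI features and their normalized box coordinates into the hidden size separately, layer-normalizes each projection, and averages the two. A standalone sketch of that computation with assumed dimensions (2048-d region features, 4-d boxes, 768-d hidden size, 36 regions):

import torch
import torch.nn as nn

hidden, feat_dim, pos_dim = 768, 2048, 4
visn_fc, box_fc = nn.Linear(feat_dim, hidden), nn.Linear(pos_dim, hidden)
visn_ln, box_ln = nn.LayerNorm(hidden, eps=1e-12), nn.LayerNorm(hidden, eps=1e-12)

feats = torch.rand(1, 36, feat_dim)   # 36 region proposals (assumed count)
boxes = torch.rand(1, 36, pos_dim)    # normalized box coordinates

embedding = (visn_ln(visn_fc(feats)) + box_ln(box_fc(boxes))) / 2   # average of the two projections
print(embedding.shape)   # torch.Size([1, 36, 768])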
618 |
+
class LxmertEncoder(nn.Module):
|
619 |
+
def __init__(self, config, save_cams=False):
|
620 |
+
super().__init__()
|
621 |
+
|
622 |
+
# Obj-level image embedding layer
|
623 |
+
self.visn_fc = LxmertVisualFeatureEncoder(config)
|
624 |
+
self.config = config
|
625 |
+
|
626 |
+
# Number of layers
|
627 |
+
self.num_l_layers = config.l_layers
|
628 |
+
self.num_x_layers = config.x_layers
|
629 |
+
self.num_r_layers = config.r_layers
|
630 |
+
|
631 |
+
# Layers
|
632 |
+
# Using self.layer instead of self.l_layer to support loading BERT weights.
|
633 |
+
self.layer = nn.ModuleList([LxmertLayer(config, save_cams=save_cams) for _ in range(self.num_l_layers)])
|
634 |
+
self.x_layers = nn.ModuleList([LxmertXLayer(config) for _ in range(self.num_x_layers)])
|
635 |
+
self.r_layers = nn.ModuleList([LxmertLayer(config, save_cams=save_cams) for _ in range(self.num_r_layers)])
|
636 |
+
|
637 |
+
def forward(
|
638 |
+
self,
|
639 |
+
lang_feats,
|
640 |
+
lang_attention_mask,
|
641 |
+
visual_feats,
|
642 |
+
visual_pos,
|
643 |
+
visual_attention_mask=None,
|
644 |
+
output_attentions=None,
|
645 |
+
):
|
646 |
+
|
647 |
+
vision_hidden_states = ()
|
648 |
+
language_hidden_states = ()
|
649 |
+
vision_attentions = () if output_attentions or self.config.output_attentions else None
|
650 |
+
language_attentions = () if output_attentions or self.config.output_attentions else None
|
651 |
+
cross_encoder_attentions = () if output_attentions or self.config.output_attentions else None
|
652 |
+
|
653 |
+
visual_feats = self.visn_fc(visual_feats, visual_pos)
|
654 |
+
|
655 |
+
# Run language layers
|
656 |
+
for layer_module in self.layer:
|
657 |
+
l_outputs = layer_module(lang_feats, lang_attention_mask, output_attentions=output_attentions)
|
658 |
+
lang_feats = l_outputs[0]
|
659 |
+
language_hidden_states = language_hidden_states + (lang_feats,)
|
660 |
+
if language_attentions is not None:
|
661 |
+
language_attentions = language_attentions + (l_outputs[1],)
|
662 |
+
|
663 |
+
# Run relational layers
|
664 |
+
for layer_module in self.r_layers:
|
665 |
+
v_outputs = layer_module(visual_feats, visual_attention_mask, output_attentions=output_attentions)
|
666 |
+
visual_feats = v_outputs[0]
|
667 |
+
vision_hidden_states = vision_hidden_states + (visual_feats,)
|
668 |
+
if vision_attentions is not None:
|
669 |
+
vision_attentions = vision_attentions + (v_outputs[1],)
|
670 |
+
|
671 |
+
# Run cross-modality layers
|
672 |
+
for layer_module in self.x_layers:
|
673 |
+
x_outputs = layer_module(
|
674 |
+
lang_feats,
|
675 |
+
lang_attention_mask,
|
676 |
+
visual_feats,
|
677 |
+
visual_attention_mask,
|
678 |
+
output_attentions=output_attentions,
|
679 |
+
)
|
680 |
+
lang_feats, visual_feats = x_outputs[:2]
|
681 |
+
vision_hidden_states = vision_hidden_states + (visual_feats,)
|
682 |
+
language_hidden_states = language_hidden_states + (lang_feats,)
|
683 |
+
if cross_encoder_attentions is not None:
|
684 |
+
cross_encoder_attentions = cross_encoder_attentions + (x_outputs[2],)
|
685 |
+
visual_encoder_outputs = (
|
686 |
+
vision_hidden_states,
|
687 |
+
vision_attentions if output_attentions else None,
|
688 |
+
)
|
689 |
+
lang_encoder_outputs = (
|
690 |
+
language_hidden_states,
|
691 |
+
language_attentions if output_attentions else None,
|
692 |
+
)
|
693 |
+
return (
|
694 |
+
visual_encoder_outputs,
|
695 |
+
lang_encoder_outputs,
|
696 |
+
cross_encoder_attentions if output_attentions else None,
|
697 |
+
)
|
698 |
+
|
699 |
+
|
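LxmertEncoder.forward returns a plain nested tuple rather than a ModelOutput: visual encoder outputs, language encoder outputs, and (optionally) the cross-attention maps. The sketch below shows one way to drive it directly with random tensors and unpack the result; the import path and default config values are assumptions based on this repository's layout, not something this file guarantees.

import torch
from transformers import LxmertConfig
from lxmert.src.huggingface_lxmert import LxmertEncoder   # path assumed from the repo layout

config = LxmertConfig()                       # defaults: 9 language, 5 cross-modality, 5 visual layers
encoder = LxmertEncoder(config)

lang_feats = torch.rand(1, 20, config.hidden_size)
lang_mask = torch.zeros(1, 1, 1, 20)          # additive mask: 0 keeps a token, large negative drops it
visual_feats = torch.rand(1, 36, config.visual_feat_dim)
visual_pos = torch.rand(1, 36, config.visual_pos_dim)

visual_out, lang_out, cross_attentions = encoder(
    lang_feats, lang_mask, visual_feats, visual_pos,
    visual_attention_mask=None, output_attentions=True,
)
vision_hidden_states, vision_attentions = visual_out
language_hidden_states, language_attentions = lang_out
print(language_hidden_states[-1].shape, vision_hidden_states[-1].shape)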
700 |
+
class LxmertPooler(nn.Module):
|
701 |
+
def __init__(self, config):
|
702 |
+
super(LxmertPooler, self).__init__()
|
703 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
704 |
+
self.activation = nn.Tanh()
|
705 |
+
|
706 |
+
def forward(self, hidden_states):
|
707 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
708 |
+
# to the first token.
|
709 |
+
first_token_tensor = hidden_states[:, 0]
|
710 |
+
pooled_output = self.dense(first_token_tensor)
|
711 |
+
pooled_output = self.activation(pooled_output)
|
712 |
+
return pooled_output
|
713 |
+
|
714 |
+
|
715 |
+
class LxmertPredictionHeadTransform(nn.Module):
|
716 |
+
def __init__(self, config):
|
717 |
+
super(LxmertPredictionHeadTransform, self).__init__()
|
718 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
719 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
720 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
|
721 |
+
|
722 |
+
def forward(self, hidden_states):
|
723 |
+
hidden_states = self.dense(hidden_states)
|
724 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
725 |
+
hidden_states = self.LayerNorm(hidden_states)
|
726 |
+
return hidden_states
|
727 |
+
|
728 |
+
|
729 |
+
class LxmertLMPredictionHead(nn.Module):
|
730 |
+
def __init__(self, config, lxmert_model_embedding_weights):
|
731 |
+
super(LxmertLMPredictionHead, self).__init__()
|
732 |
+
self.transform = LxmertPredictionHeadTransform(config)
|
733 |
+
|
734 |
+
# The output weights are the same as the input embeddings, but there is
|
735 |
+
# an output-only bias for each token.
|
736 |
+
self.decoder = nn.Linear(
|
737 |
+
lxmert_model_embedding_weights.size(1),
|
738 |
+
lxmert_model_embedding_weights.size(0),
|
739 |
+
bias=False,
|
740 |
+
)
|
741 |
+
self.decoder.weight = lxmert_model_embedding_weights
|
742 |
+
self.bias = nn.Parameter(torch.zeros(lxmert_model_embedding_weights.size(0)))
|
743 |
+
|
744 |
+
def forward(self, hidden_states):
|
745 |
+
hidden_states = self.transform(hidden_states)
|
746 |
+
hidden_states = self.decoder(hidden_states) + self.bias
|
747 |
+
return hidden_states
|
748 |
+
|
749 |
+
|
750 |
+
class LxmertVisualAnswerHead(nn.Module):
|
751 |
+
def __init__(self, config, num_labels):
|
752 |
+
super().__init__()
|
753 |
+
hid_dim = config.hidden_size
|
754 |
+
self.logit_fc = nn.Sequential(
|
755 |
+
nn.Linear(hid_dim, hid_dim * 2),
|
756 |
+
GeLU(),
|
757 |
+
nn.LayerNorm(hid_dim * 2, eps=1e-12),
|
758 |
+
nn.Linear(hid_dim * 2, num_labels),
|
759 |
+
)
|
760 |
+
|
761 |
+
def forward(self, hidden_states):
|
762 |
+
return self.logit_fc(hidden_states)
|
763 |
+
|
764 |
+
|
765 |
+
class LxmertVisualObjHead(nn.Module):
|
766 |
+
def __init__(self, config):
|
767 |
+
super().__init__()
|
768 |
+
self.transform = LxmertPredictionHeadTransform(config)
|
769 |
+
# Decide the use of visual losses
|
770 |
+
visual_losses = {}
|
771 |
+
if config.visual_obj_loss:
|
772 |
+
visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels}
|
773 |
+
if config.visual_attr_loss:
|
774 |
+
visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels}
|
775 |
+
if config.visual_feat_loss:
|
776 |
+
visual_losses["feat"] = {
|
777 |
+
"shape": (-1, config.visual_feat_dim),
|
778 |
+
"num": config.visual_feat_dim,
|
779 |
+
}
|
780 |
+
self.visual_losses = visual_losses
|
781 |
+
|
782 |
+
# The output weights are the same as the input embeddings, but there is
|
783 |
+
# an output-only bias for each token.
|
784 |
+
self.decoder_dict = nn.ModuleDict(
|
785 |
+
{key: nn.Linear(config.hidden_size, self.visual_losses[key]["num"]) for key in self.visual_losses}
|
786 |
+
)
|
787 |
+
|
788 |
+
def forward(self, hidden_states):
|
789 |
+
hidden_states = self.transform(hidden_states)
|
790 |
+
output = {}
|
791 |
+
for key in self.visual_losses:
|
792 |
+
output[key] = self.decoder_dict[key](hidden_states)
|
793 |
+
return output
|
794 |
+
|
795 |
+
|
796 |
+
class LxmertPreTrainingHeads(nn.Module):
|
797 |
+
def __init__(self, config, lxmert_model_embedding_weights):
|
798 |
+
super(LxmertPreTrainingHeads, self).__init__()
|
799 |
+
self.predictions = LxmertLMPredictionHead(config, lxmert_model_embedding_weights)
|
800 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
801 |
+
|
802 |
+
def forward(self, sequence_output, pooled_output):
|
803 |
+
prediction_scores = self.predictions(sequence_output)
|
804 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
805 |
+
return prediction_scores, seq_relationship_score
|
806 |
+
|
807 |
+
|
808 |
+
class LxmertPreTrainedModel(PreTrainedModel):
|
809 |
+
"""
|
810 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
811 |
+
models.
|
812 |
+
"""
|
813 |
+
|
814 |
+
config_class = LxmertConfig
|
815 |
+
load_tf_weights = load_tf_weights_in_lxmert
|
816 |
+
base_model_prefix = "lxmert"
|
817 |
+
|
818 |
+
def _init_weights(self, module):
|
819 |
+
""" Initialize the weights """
|
820 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
821 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
822 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
823 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
824 |
+
elif isinstance(module, nn.LayerNorm):
|
825 |
+
module.bias.data.zero_()
|
826 |
+
module.weight.data.fill_(1.0)
|
827 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
828 |
+
module.bias.data.zero_()
|
829 |
+
|
830 |
+
|
831 |
+
LXMERT_START_DOCSTRING = r"""
|
832 |
+
|
833 |
+
The lxmert model was proposed in `lxmert: Learning Cross-Modality Encoder Representations from Transformers
|
834 |
+
<https://arxiv.org/abs/1908.07490>`__ by Hao Tan and Mohit Bansal. It's a vision and language transformer model,
|
835 |
+
pretrained on a variety of multi-modal datasets comprising GQA, VQA v2.0, MSCOCO captions, and Visual Genome,
|
836 |
+
using a combination of masked language modeling, region of interest feature regression, cross entropy loss for
|
837 |
+
question answering, attribute prediction, and object tag prediction.
|
838 |
+
|
839 |
+
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
840 |
+
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
|
841 |
+
pruning heads etc.)
|
842 |
+
|
843 |
+
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
844 |
+
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
|
845 |
+
general usage and behavior.
|
846 |
+
|
847 |
+
Parameters:
|
848 |
+
config (:class:`~transformers.LxmertConfig`): Model configuration class with all the parameters of the model.
|
849 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
850 |
+
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
851 |
+
weights.
|
852 |
+
"""
|
853 |
+
|
854 |
+
LXMERT_INPUTS_DOCSTRING = r"""
|
855 |
+
|
856 |
+
Args:
|
857 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
|
858 |
+
Indices of input sequence tokens in the vocabulary.
|
859 |
+
|
860 |
+
Indices can be obtained using :class:`~transformers.LxmertTokenizer`. See
|
861 |
+
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
862 |
+
details.
|
863 |
+
|
864 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
865 |
+
visual_feats: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_visual_features, visual_feat_dim)`):
|
866 |
+
This input represents visual features: ROI-pooled object features from bounding boxes, extracted with a
|
867 |
+
Faster R-CNN model.
|
868 |
+
|
869 |
+
These are currently not provided by the transformers library.
|
870 |
+
visual_pos: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_visual_features, visual_pos_dim)`):
|
871 |
+
This input represents spatial features corresponding (by index) to their respective visual features. The
|
872 |
+
pre-trained lxmert model expects these spatial features to be normalized bounding boxes on a scale of 0 to
|
873 |
+
1.
|
874 |
+
|
875 |
+
These are currently not provided by the transformers library.
|
876 |
+
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
877 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
878 |
+
|
879 |
+
- 1 for tokens that are **not masked**,
|
880 |
+
- 0 for tokens that are **masked**.
|
881 |
+
|
882 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
883 |
+
visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
884 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
885 |
+
|
886 |
+
- 1 for tokens that are **not masked**,
|
887 |
+
- 0 for tokens that are **masked**.
|
888 |
+
|
889 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
890 |
+
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
891 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
892 |
+
1]``:
|
893 |
+
|
894 |
+
- 0 corresponds to a `sentence A` token,
|
895 |
+
- 1 corresponds to a `sentence B` token.
|
896 |
+
|
897 |
+
`What are token type IDs? <../glossary.html#token-type-ids>`__
|
898 |
+
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
899 |
+
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
900 |
+
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
901 |
+
vectors than the model's internal embedding lookup matrix.
|
902 |
+
output_attentions (:obj:`bool`, `optional`):
|
903 |
+
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
904 |
+
tensors for more detail.
|
905 |
+
output_hidden_states (:obj:`bool`, `optional`):
|
906 |
+
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
907 |
+
more detail.
|
908 |
+
return_dict (:obj:`bool`, `optional`):
|
909 |
+
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
910 |
+
"""
|
911 |
+
|
912 |
+
|
913 |
+
@add_start_docstrings(
|
914 |
+
"The bare Lxmert Model transformer outputting raw hidden-states without any specific head on top.",
|
915 |
+
LXMERT_START_DOCSTRING,
|
916 |
+
)
|
917 |
+
class LxmertModel(LxmertPreTrainedModel):
|
918 |
+
def __init__(self, config, save_cams=False):
|
919 |
+
super().__init__(config)
|
920 |
+
self.embeddings = LxmertEmbeddings(config)
|
921 |
+
self.encoder = LxmertEncoder(config, save_cams=save_cams)
|
922 |
+
self.pooler = LxmertPooler(config)
|
923 |
+
self.init_weights()
|
924 |
+
|
925 |
+
def get_input_embeddings(self):
|
926 |
+
return self.embeddings.word_embeddings
|
927 |
+
|
928 |
+
def set_input_embeddings(self, new_embeddings):
|
929 |
+
self.embeddings.word_embeddings = new_embeddings
|
930 |
+
|
931 |
+
@add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
932 |
+
@add_code_sample_docstrings(
|
933 |
+
tokenizer_class=_TOKENIZER_FOR_DOC,
|
934 |
+
checkpoint="unc-nlp/lxmert-base-uncased",
|
935 |
+
output_type=LxmertModelOutput,
|
936 |
+
config_class=_CONFIG_FOR_DOC,
|
937 |
+
)
|
938 |
+
def forward(
|
939 |
+
self,
|
940 |
+
input_ids=None,
|
941 |
+
visual_feats=None,
|
942 |
+
visual_pos=None,
|
943 |
+
attention_mask=None,
|
944 |
+
visual_attention_mask=None,
|
945 |
+
token_type_ids=None,
|
946 |
+
inputs_embeds=None,
|
947 |
+
output_attentions=None,
|
948 |
+
output_hidden_states=None,
|
949 |
+
return_dict=None,
|
950 |
+
):
|
951 |
+
|
952 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
953 |
+
output_hidden_states = (
|
954 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
955 |
+
)
|
956 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
957 |
+
|
958 |
+
if input_ids is not None and inputs_embeds is not None:
|
959 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
960 |
+
elif input_ids is not None:
|
961 |
+
input_shape = input_ids.size()
|
962 |
+
elif inputs_embeds is not None:
|
963 |
+
input_shape = inputs_embeds.size()[:-1]
|
964 |
+
else:
|
965 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
966 |
+
|
967 |
+
assert visual_feats is not None, "`visual_feats` cannot be `None`"
|
968 |
+
assert visual_pos is not None, "`visual_pos` cannot be `None`"
|
969 |
+
|
970 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
971 |
+
|
972 |
+
if attention_mask is None:
|
973 |
+
attention_mask = torch.ones(input_shape, device=device)
|
974 |
+
if token_type_ids is None:
|
975 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
976 |
+
|
977 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
978 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
979 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
980 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
981 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
982 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
983 |
+
|
984 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
985 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
986 |
+
# positions we want to attend and -10000.0 for masked positions.
|
987 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
988 |
+
# effectively the same as removing these entirely.
|
989 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
|
990 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
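        # Worked example of the two steps above: a 2D mask [[1, 1, 0]] is unsqueezed to
        # shape (1, 1, 1, 3) and mapped to [[[[0., 0., -10000.]]]], so the padded third
        # position receives a large negative bias on every attention score and is ignored.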
|
991 |
+
|
992 |
+
# Process the visual attention mask
|
993 |
+
if visual_attention_mask is not None:
|
994 |
+
extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
|
995 |
+
extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
|
996 |
+
extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
|
997 |
+
else:
|
998 |
+
extended_visual_attention_mask = None
|
999 |
+
|
1000 |
+
# Positional Word Embeddings
|
1001 |
+
embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds)
|
1002 |
+
|
1003 |
+
# Run Lxmert encoder
|
1004 |
+
encoder_outputs = self.encoder(
|
1005 |
+
embedding_output,
|
1006 |
+
extended_attention_mask,
|
1007 |
+
visual_feats=visual_feats,
|
1008 |
+
visual_pos=visual_pos,
|
1009 |
+
visual_attention_mask=extended_visual_attention_mask,
|
1010 |
+
output_attentions=output_attentions,
|
1011 |
+
)
|
1012 |
+
|
1013 |
+
visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
|
1014 |
+
vision_hidden_states = visual_encoder_outputs[0]
|
1015 |
+
language_hidden_states = lang_encoder_outputs[0]
|
1016 |
+
|
1017 |
+
all_attentions = ()
|
1018 |
+
if output_attentions:
|
1019 |
+
language_attentions = lang_encoder_outputs[1]
|
1020 |
+
vision_attentions = visual_encoder_outputs[1]
|
1021 |
+
cross_encoder_attentions = encoder_outputs[2]
|
1022 |
+
all_attentions = (
|
1023 |
+
language_attentions,
|
1024 |
+
vision_attentions,
|
1025 |
+
cross_encoder_attentions,
|
1026 |
+
)
|
1027 |
+
|
1028 |
+
hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else ()
|
1029 |
+
|
1030 |
+
visual_output = vision_hidden_states[-1]
|
1031 |
+
lang_output = language_hidden_states[-1]
|
1032 |
+
pooled_output = self.pooler(lang_output)
|
1033 |
+
|
1034 |
+
if not return_dict:
|
1035 |
+
return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions
|
1036 |
+
|
1037 |
+
return LxmertModelOutput(
|
1038 |
+
pooled_output=pooled_output,
|
1039 |
+
language_output=lang_output,
|
1040 |
+
vision_output=visual_output,
|
1041 |
+
language_hidden_states=language_hidden_states if output_hidden_states else None,
|
1042 |
+
vision_hidden_states=vision_hidden_states if output_hidden_states else None,
|
1043 |
+
language_attentions=language_attentions if output_attentions else None,
|
1044 |
+
vision_attentions=vision_attentions if output_attentions else None,
|
1045 |
+
cross_encoder_attentions=cross_encoder_attentions if output_attentions else None,
|
1046 |
+
)
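A hedged end-to-end sketch for the LxmertModel defined above, instantiated from a default LxmertConfig (untrained weights, random visual inputs), just to show the output shapes; with return_dict=True the call returns the LxmertModelOutput described later in this upload:

import torch
from transformers import LxmertConfig

config = LxmertConfig()
model = LxmertModel(config)                  # the class defined above
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 12))
visual_feats = torch.randn(1, 36, config.visual_feat_dim)
visual_pos = torch.rand(1, 36, config.visual_pos_dim)

with torch.no_grad():
    out = model(input_ids=input_ids, visual_feats=visual_feats,
                visual_pos=visual_pos, return_dict=True)
print(out.language_output.shape)             # (1, 12, hidden_size)
print(out.vision_output.shape)               # (1, 36, hidden_size)
print(out.pooled_output.shape)               # (1, hidden_size)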
|
1047 |
+
|
1048 |
+
|
1049 |
+
@add_start_docstrings(
|
1050 |
+
"""Lxmert Model with a specified pretraining head on top. """,
|
1051 |
+
LXMERT_START_DOCSTRING,
|
1052 |
+
)
|
1053 |
+
class LxmertForPreTraining(LxmertPreTrainedModel):
|
1054 |
+
def __init__(self, config, save_cams=False):
|
1055 |
+
super().__init__(config)
|
1056 |
+
# Configuration
|
1057 |
+
self.config = config
|
1058 |
+
self.num_qa_labels = config.num_qa_labels
|
1059 |
+
self.visual_loss_normalizer = config.visual_loss_normalizer
|
1060 |
+
|
1061 |
+
# Use of pretraining tasks
|
1062 |
+
self.task_mask_lm = config.task_mask_lm
|
1063 |
+
self.task_obj_predict = config.task_obj_predict
|
1064 |
+
self.task_matched = config.task_matched
|
1065 |
+
self.task_qa = config.task_qa
|
1066 |
+
|
1067 |
+
# Lxmert backbone
|
1068 |
+
self.lxmert = LxmertModel(config, save_cams=save_cams)
|
1069 |
+
|
1070 |
+
# Pre-training heads
|
1071 |
+
self.cls = LxmertPreTrainingHeads(config, self.lxmert.embeddings.word_embeddings.weight)
|
1072 |
+
if self.task_obj_predict:
|
1073 |
+
self.obj_predict_head = LxmertVisualObjHead(config)
|
1074 |
+
if self.task_qa:
|
1075 |
+
self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
|
1076 |
+
|
1077 |
+
# Weight initialization
|
1078 |
+
self.init_weights()
|
1079 |
+
|
1080 |
+
# Loss functions
|
1081 |
+
self.loss_fcts = {
|
1082 |
+
"l2": SmoothL1Loss(reduction="none"),
|
1083 |
+
"visual_ce": CrossEntropyLoss(reduction="none"),
|
1084 |
+
"ce": CrossEntropyLoss(),
|
1085 |
+
}
|
1086 |
+
|
1087 |
+
visual_losses = {}
|
1088 |
+
if config.visual_obj_loss:
|
1089 |
+
visual_losses["obj"] = {
|
1090 |
+
"shape": (-1,),
|
1091 |
+
"num": config.num_object_labels,
|
1092 |
+
"loss": "visual_ce",
|
1093 |
+
}
|
1094 |
+
if config.visual_attr_loss:
|
1095 |
+
visual_losses["attr"] = {
|
1096 |
+
"shape": (-1,),
|
1097 |
+
"num": config.num_attr_labels,
|
1098 |
+
"loss": "visual_ce",
|
1099 |
+
}
|
1100 |
+
if config.visual_feat_loss:
|
1101 |
+
visual_losses["feat"] = {
|
1102 |
+
"shape": (-1, config.visual_feat_dim),
|
1103 |
+
"num": config.visual_feat_dim,
|
1104 |
+
"loss": "l2",
|
1105 |
+
}
|
1106 |
+
self.visual_losses = visual_losses
|
1107 |
+
|
1108 |
+
def resize_num_qa_labels(self, num_labels):
|
1109 |
+
"""
|
1110 |
+
Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
|
1111 |
+
will add newly initialized weights. Reducing the size will remove weights from the end
|
1112 |
+
|
1113 |
+
Args:
|
1114 |
+
num_labels (:obj:`int`, `optional`):
|
1115 |
+
New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
|
1116 |
+
weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`,
|
1117 |
+
just returns a pointer to the qa labels :obj:`torch.nn.Linear` module of the model without doing
|
1118 |
+
anything.
|
1119 |
+
|
1120 |
+
Return:
|
1121 |
+
:obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
|
1122 |
+
"""
|
1123 |
+
|
1124 |
+
cur_qa_logit_layer = self.get_qa_logit_layer()
|
1125 |
+
if num_labels is None or cur_qa_logit_layer is None:
|
1126 |
+
return
|
1127 |
+
new_qa_logit_layer = self._resize_qa_labels(num_labels)
|
1128 |
+
self.config.num_qa_labels = num_labels
|
1129 |
+
self.num_qa_labels = num_labels
|
1130 |
+
|
1131 |
+
return new_qa_logit_layer
|
1132 |
+
|
1133 |
+
def _resize_qa_labels(self, num_labels):
|
1134 |
+
cur_qa_logit_layer = self.get_qa_logit_layer()
|
1135 |
+
new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
|
1136 |
+
self._set_qa_logit_layer(new_qa_logit_layer)
|
1137 |
+
return self.get_qa_logit_layer()
|
1138 |
+
|
1139 |
+
def get_qa_logit_layer(self) -> nn.Module:
|
1140 |
+
"""
|
1141 |
+
Returns the linear layer that produces question answering logits.
|
1142 |
+
|
1143 |
+
Returns:
|
1144 |
+
:obj:`nn.Module`: A torch module mapping the question answering prediction hidden states or :obj:`None` if
|
1145 |
+
lxmert does not have a visual answering head.
|
1146 |
+
"""
|
1147 |
+
if hasattr(self, "answer_head"):
|
1148 |
+
return self.answer_head.logit_fc[-1]
|
1149 |
+
|
1150 |
+
def _set_qa_logit_layer(self, qa_logit_layer):
|
1151 |
+
self.answer_head.logit_fc[-1] = qa_logit_layer
|
1152 |
+
|
1153 |
+
def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
|
1154 |
+
|
1155 |
+
if num_labels is None:
|
1156 |
+
return cur_qa_logit_layer
|
1157 |
+
|
1158 |
+
cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
|
1159 |
+
if cur_qa_labels == num_labels:
|
1160 |
+
return cur_qa_logit_layer
|
1161 |
+
|
1162 |
+
# Build new linear output
|
1163 |
+
if getattr(cur_qa_logit_layer, "bias", None) is not None:
|
1164 |
+
new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
|
1165 |
+
else:
|
1166 |
+
new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
|
1167 |
+
|
1168 |
+
new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
|
1169 |
+
|
1170 |
+
# initialize all new labels
|
1171 |
+
self._init_weights(new_qa_logit_layer)
|
1172 |
+
|
1173 |
+
# Copy labels from the previous weights
|
1174 |
+
num_labels_to_copy = min(cur_qa_labels, num_labels)
|
1175 |
+
new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
|
1176 |
+
if getattr(cur_qa_logit_layer, "bias", None) is not None:
|
1177 |
+
new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
|
1178 |
+
|
1179 |
+
return new_qa_logit_layer
|
1180 |
+
|
1181 |
+
@add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1182 |
+
@replace_return_docstrings(output_type=LxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
1183 |
+
def forward(
|
1184 |
+
self,
|
1185 |
+
input_ids=None,
|
1186 |
+
visual_feats=None,
|
1187 |
+
visual_pos=None,
|
1188 |
+
attention_mask=None,
|
1189 |
+
visual_attention_mask=None,
|
1190 |
+
token_type_ids=None,
|
1191 |
+
inputs_embeds=None,
|
1192 |
+
labels=None,
|
1193 |
+
obj_labels=None,
|
1194 |
+
matched_label=None,
|
1195 |
+
ans=None,
|
1196 |
+
output_attentions=None,
|
1197 |
+
output_hidden_states=None,
|
1198 |
+
return_dict=None,
|
1199 |
+
**kwargs,
|
1200 |
+
):
|
1201 |
+
r"""
|
1202 |
+
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
|
1203 |
+
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
1204 |
+
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
1205 |
+
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
1206 |
+
obj_labels (``Dict[str, Tuple[torch.FloatTensor, torch.FloatTensor]]``, `optional`):
|
1207 |
+
each key is named after one of the visual losses, and the two elements of the tuple are of the shapes
|
1208 |
+
``(batch_size, num_features)`` and ``(batch_size, num_features, visual_feature_dim)``, holding the label ids
|
1209 |
+
and the label scores, respectively.
|
1210 |
+
matched_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
|
1211 |
+
Labels for computing whether or not the text input matches the image (classification) loss. Input
|
1212 |
+
should be a sequence pair (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
|
1213 |
+
|
1214 |
+
- 0 indicates that the sentence does not match the image,
|
1215 |
+
- 1 indicates that the sentence does match the image.
|
1216 |
+
ans (``torch.Tensor`` of shape ``(batch_size,)``, `optional`):
|
1217 |
+
A one-hot representation of the correct answer.
|
1218 |
+
|
1219 |
+
Returns:
|
1220 |
+
"""
|
1221 |
+
|
1222 |
+
if "masked_lm_labels" in kwargs:
|
1223 |
+
warnings.warn(
|
1224 |
+
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
1225 |
+
FutureWarning,
|
1226 |
+
)
|
1227 |
+
labels = kwargs.pop("masked_lm_labels")
|
1228 |
+
|
1229 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1230 |
+
|
1231 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
1232 |
+
lxmert_output = self.lxmert(
|
1233 |
+
input_ids=input_ids,
|
1234 |
+
visual_feats=visual_feats,
|
1235 |
+
visual_pos=visual_pos,
|
1236 |
+
token_type_ids=token_type_ids,
|
1237 |
+
attention_mask=attention_mask,
|
1238 |
+
visual_attention_mask=visual_attention_mask,
|
1239 |
+
inputs_embeds=inputs_embeds,
|
1240 |
+
output_hidden_states=output_hidden_states,
|
1241 |
+
output_attentions=output_attentions,
|
1242 |
+
return_dict=return_dict,
|
1243 |
+
)
|
1244 |
+
|
1245 |
+
lang_output, visual_output, pooled_output = (
|
1246 |
+
lxmert_output[0],
|
1247 |
+
lxmert_output[1],
|
1248 |
+
lxmert_output[2],
|
1249 |
+
)
|
1250 |
+
lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output)
|
1251 |
+
if self.task_qa:
|
1252 |
+
answer_score = self.answer_head(pooled_output)
|
1253 |
+
else:
|
1254 |
+
answer_score = pooled_output[0][0]
|
1255 |
+
|
1256 |
+
total_loss = (
|
1257 |
+
None
|
1258 |
+
if (labels is None and matched_label is None and obj_labels is None and ans is None)
|
1259 |
+
else torch.tensor(0.0, device=device)
|
1260 |
+
)
|
1261 |
+
if labels is not None and self.task_mask_lm:
|
1262 |
+
masked_lm_loss = self.loss_fcts["ce"](
|
1263 |
+
lang_prediction_scores.view(-1, self.config.vocab_size),
|
1264 |
+
labels.view(-1),
|
1265 |
+
)
|
1266 |
+
total_loss += masked_lm_loss
|
1267 |
+
if matched_label is not None and self.task_matched:
|
1268 |
+
matched_loss = self.loss_fcts["ce"](cross_relationship_score.view(-1, 2), matched_label.view(-1))
|
1269 |
+
total_loss += matched_loss
|
1270 |
+
if obj_labels is not None and self.task_obj_predict:
|
1271 |
+
total_visual_loss = torch.tensor(0.0, device=input_ids.device)
|
1272 |
+
visual_prediction_scores_dict = self.obj_predict_head(visual_output)
|
1273 |
+
for key, key_info in self.visual_losses.items():
|
1274 |
+
label, mask_conf = obj_labels[key]
|
1275 |
+
output_dim = key_info["num"]
|
1276 |
+
loss_fct_name = key_info["loss"]
|
1277 |
+
label_shape = key_info["shape"]
|
1278 |
+
weight = self.visual_loss_normalizer
|
1279 |
+
visual_loss_fct = self.loss_fcts[loss_fct_name]
|
1280 |
+
visual_prediction_scores = visual_prediction_scores_dict[key]
|
1281 |
+
visual_loss = visual_loss_fct(
|
1282 |
+
visual_prediction_scores.view(-1, output_dim),
|
1283 |
+
label.view(*label_shape),
|
1284 |
+
)
|
1285 |
+
if visual_loss.dim() > 1: # Regression Losses
|
1286 |
+
visual_loss = visual_loss.mean(1)
|
1287 |
+
visual_loss = (visual_loss * mask_conf.view(-1)).mean() * weight
|
1288 |
+
total_visual_loss += visual_loss
|
1289 |
+
total_loss += total_visual_loss
|
1290 |
+
if ans is not None and self.task_qa:
|
1291 |
+
answer_loss = self.loss_fcts["ce"](answer_score.view(-1, self.num_qa_labels), ans.view(-1))
|
1292 |
+
total_loss += answer_loss
|
1293 |
+
|
1294 |
+
if not return_dict:
|
1295 |
+
output = (
|
1296 |
+
lang_prediction_scores,
|
1297 |
+
cross_relationship_score,
|
1298 |
+
answer_score,
|
1299 |
+
) + lxmert_output[3:]
|
1300 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
1301 |
+
|
1302 |
+
return LxmertForPreTrainingOutput(
|
1303 |
+
loss=total_loss,
|
1304 |
+
prediction_logits=lang_prediction_scores,
|
1305 |
+
cross_relationship_score=cross_relationship_score,
|
1306 |
+
question_answering_score=answer_score,
|
1307 |
+
language_hidden_states=lxmert_output.language_hidden_states,
|
1308 |
+
vision_hidden_states=lxmert_output.vision_hidden_states,
|
1309 |
+
language_attentions=lxmert_output.language_attentions,
|
1310 |
+
vision_attentions=lxmert_output.vision_attentions,
|
1311 |
+
cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
|
1312 |
+
)
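A hedged sketch of the QA-label resizing machinery above: it swaps the final Linear of the answer head for one with the requested number of rows, copying the overlapping label weights and re-initializing the rest (3129 here is just an illustrative VQA-sized answer count):

from transformers import LxmertConfig

pretrain_model = LxmertForPreTraining(LxmertConfig(task_qa=True))
print(pretrain_model.get_qa_logit_layer().out_features)   # default num_qa_labels (9500)
pretrain_model.resize_num_qa_labels(3129)
print(pretrain_model.get_qa_logit_layer().out_features)   # 3129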
|
1313 |
+
|
1314 |
+
|
1315 |
+
@add_start_docstrings(
|
1316 |
+
"""Lxmert Model with a visual-answering head on top for downstream QA tasks""",
|
1317 |
+
LXMERT_START_DOCSTRING,
|
1318 |
+
)
|
1319 |
+
class LxmertForQuestionAnswering(LxmertPreTrainedModel):
|
1320 |
+
def __init__(self, config):
|
1321 |
+
super().__init__(config)
|
1322 |
+
# Configuration
|
1323 |
+
self.config = config
|
1324 |
+
self.num_qa_labels = config.num_qa_labels
|
1325 |
+
self.visual_loss_normalizer = config.visual_loss_normalizer
|
1326 |
+
|
1327 |
+
# Lxmert backbone
|
1328 |
+
self.lxmert = LxmertModel(config)
|
1329 |
+
|
1330 |
+
self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
|
1331 |
+
|
1332 |
+
# Weight initialization
|
1333 |
+
self.init_weights()
|
1334 |
+
|
1335 |
+
# Loss function
|
1336 |
+
self.loss = CrossEntropyLoss()
|
1337 |
+
|
1338 |
+
def resize_num_qa_labels(self, num_labels):
|
1339 |
+
"""
|
1340 |
+
Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
|
1341 |
+
will add newly initialized weights. Reducing the size will remove weights from the end
|
1342 |
+
|
1343 |
+
Args:
|
1344 |
+
num_labels (:obj:`int`, `optional`):
|
1345 |
+
New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
|
1346 |
+
weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`,
|
1347 |
+
just returns a pointer to the qa labels :obj:`torch.nn.Linear` module of the model without doing
|
1348 |
+
anything.
|
1349 |
+
|
1350 |
+
Return:
|
1351 |
+
:obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
|
1352 |
+
"""
|
1353 |
+
|
1354 |
+
cur_qa_logit_layer = self.get_qa_logit_layer()
|
1355 |
+
if num_labels is None or cur_qa_logit_layer is None:
|
1356 |
+
return
|
1357 |
+
new_qa_logit_layer = self._resize_qa_labels(num_labels)
|
1358 |
+
self.config.num_qa_labels = num_labels
|
1359 |
+
self.num_qa_labels = num_labels
|
1360 |
+
|
1361 |
+
return new_qa_logit_layer
|
1362 |
+
|
1363 |
+
def _resize_qa_labels(self, num_labels):
|
1364 |
+
cur_qa_logit_layer = self.get_qa_logit_layer()
|
1365 |
+
new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
|
1366 |
+
self._set_qa_logit_layer(new_qa_logit_layer)
|
1367 |
+
return self.get_qa_logit_layer()
|
1368 |
+
|
1369 |
+
def get_qa_logit_layer(self) -> nn.Module:
|
1370 |
+
"""
|
1371 |
+
Returns the linear layer that produces question answering logits.
|
1372 |
+
|
1373 |
+
Returns:
|
1374 |
+
:obj:`nn.Module`: A torch module mapping the question answering prediction hidden states, or :obj:`None` if
|
1375 |
+
Lxmert does not have the visual answering head.
|
1376 |
+
"""
|
1377 |
+
|
1378 |
+
if hasattr(self, "answer_head"):
|
1379 |
+
return self.answer_head.logit_fc[-1]
|
1380 |
+
|
1381 |
+
def _set_qa_logit_layer(self, qa_logit_layer):
|
1382 |
+
self.answer_head.logit_fc[-1] = qa_logit_layer
|
1383 |
+
|
1384 |
+
def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
|
1385 |
+
|
1386 |
+
if num_labels is None:
|
1387 |
+
return cur_qa_logit_layer
|
1388 |
+
|
1389 |
+
cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
|
1390 |
+
if cur_qa_labels == num_labels:
|
1391 |
+
return cur_qa_logit_layer
|
1392 |
+
|
1393 |
+
# Build new linear output
|
1394 |
+
if getattr(cur_qa_logit_layer, "bias", None) is not None:
|
1395 |
+
new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
|
1396 |
+
else:
|
1397 |
+
new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
|
1398 |
+
|
1399 |
+
new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
|
1400 |
+
|
1401 |
+
# initialize all new labels
|
1402 |
+
self._init_weights(new_qa_logit_layer)
|
1403 |
+
|
1404 |
+
# Copy labels from the previous weights
|
1405 |
+
num_labels_to_copy = min(cur_qa_labels, num_labels)
|
1406 |
+
new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
|
1407 |
+
if getattr(cur_qa_logit_layer, "bias", None) is not None:
|
1408 |
+
new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
|
1409 |
+
|
1410 |
+
return new_qa_logit_layer
|
1411 |
+
|
1412 |
+
@add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1413 |
+
@add_code_sample_docstrings(
|
1414 |
+
tokenizer_class=_TOKENIZER_FOR_DOC,
|
1415 |
+
checkpoint="unc-nlp/lxmert-base-uncased",
|
1416 |
+
output_type=LxmertForQuestionAnsweringOutput,
|
1417 |
+
config_class=_CONFIG_FOR_DOC,
|
1418 |
+
)
|
1419 |
+
def forward(
|
1420 |
+
self,
|
1421 |
+
input_ids=None,
|
1422 |
+
visual_feats=None,
|
1423 |
+
visual_pos=None,
|
1424 |
+
attention_mask=None,
|
1425 |
+
visual_attention_mask=None,
|
1426 |
+
token_type_ids=None,
|
1427 |
+
inputs_embeds=None,
|
1428 |
+
labels=None,
|
1429 |
+
output_attentions=None,
|
1430 |
+
output_hidden_states=None,
|
1431 |
+
return_dict=None,
|
1432 |
+
):
|
1433 |
+
r"""
|
1434 |
+
labels (``torch.Tensor`` of shape ``(batch_size,)``, `optional`):
|
1435 |
+
A one-hot representation of the correct answer
|
1436 |
+
|
1437 |
+
Returns:
|
1438 |
+
"""
|
1439 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1440 |
+
|
1441 |
+
lxmert_output = self.lxmert(
|
1442 |
+
input_ids=input_ids,
|
1443 |
+
visual_feats=visual_feats,
|
1444 |
+
visual_pos=visual_pos,
|
1445 |
+
token_type_ids=token_type_ids,
|
1446 |
+
attention_mask=attention_mask,
|
1447 |
+
visual_attention_mask=visual_attention_mask,
|
1448 |
+
inputs_embeds=inputs_embeds,
|
1449 |
+
output_hidden_states=output_hidden_states,
|
1450 |
+
output_attentions=output_attentions,
|
1451 |
+
return_dict=return_dict,
|
1452 |
+
)
|
1453 |
+
|
1454 |
+
pooled_output = lxmert_output[2]
|
1455 |
+
answer_score = self.answer_head(pooled_output)
|
1456 |
+
loss = None
|
1457 |
+
if labels is not None:
|
1458 |
+
loss = self.loss(answer_score.view(-1, self.num_qa_labels), labels.view(-1))
|
1459 |
+
|
1460 |
+
if not return_dict:
|
1461 |
+
output = (answer_score,) + lxmert_output[3:]
|
1462 |
+
return (loss,) + output if loss is not None else output
|
1463 |
+
|
1464 |
+
return LxmertForQuestionAnsweringOutput(
|
1465 |
+
loss=loss,
|
1466 |
+
question_answering_score=answer_score,
|
1467 |
+
language_hidden_states=lxmert_output.language_hidden_states,
|
1468 |
+
vision_hidden_states=lxmert_output.vision_hidden_states,
|
1469 |
+
language_attentions=lxmert_output.language_attentions,
|
1470 |
+
vision_attentions=lxmert_output.vision_attentions,
|
1471 |
+
cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
|
1472 |
+
)
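A hedged usage sketch for the question-answering head above, again with an untrained config and random visual features; the argmax over question_answering_score would index into the task's answer vocabulary once the model is fine-tuned:

import torch
from transformers import LxmertConfig

config = LxmertConfig(num_qa_labels=3129)        # e.g. a VQA v2.0-sized answer set
qa_model = LxmertForQuestionAnswering(config)
qa_model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 12))
visual_feats = torch.randn(1, 36, config.visual_feat_dim)
visual_pos = torch.rand(1, 36, config.visual_pos_dim)

with torch.no_grad():
    out = qa_model(input_ids=input_ids, visual_feats=visual_feats,
                   visual_pos=visual_pos, return_dict=True)
print(out.question_answering_score.shape)        # (1, 3129)
print(out.question_answering_score.argmax(-1))   # predicted answer id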
|
lxmert/src/layers.py
ADDED
@@ -0,0 +1,292 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
__all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
|
6 |
+
'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
|
7 |
+
'LayerNorm', 'AddEye', 'Tanh', 'MatMul', 'Mul']
|
8 |
+
|
9 |
+
|
10 |
+
def safe_divide(a, b):
|
11 |
+
den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
|
12 |
+
den = den + den.eq(0).type(den.type()) * 1e-9
|
13 |
+
return a / den * b.ne(0).type(b.type())
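# Illustration (hedged): safe_divide behaves like elementwise a / b, but nudges the
# denominator away from zero and forces the output to 0 wherever b is exactly 0, so the
# relevance-propagation rules below never produce inf or NaN. For example, with
# a = torch.tensor([1., 2., 3.]) and b = torch.tensor([2., 0., -4.]) the result is
# approximately tensor([0.5000, 0.0000, -0.7500]).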
|
14 |
+
|
15 |
+
|
16 |
+
def forward_hook(self, input, output):
|
17 |
+
if type(input[0]) in (list, tuple):
|
18 |
+
self.X = []
|
19 |
+
for i in input[0]:
|
20 |
+
x = i.detach()
|
21 |
+
x.requires_grad = True
|
22 |
+
self.X.append(x)
|
23 |
+
else:
|
24 |
+
self.X = input[0].detach()
|
25 |
+
self.X.requires_grad = True
|
26 |
+
|
27 |
+
self.Y = output
|
28 |
+
|
29 |
+
|
30 |
+
def backward_hook(self, grad_input, grad_output):
|
31 |
+
self.grad_input = grad_input
|
32 |
+
self.grad_output = grad_output
|
33 |
+
|
34 |
+
|
35 |
+
class RelProp(nn.Module):
|
36 |
+
def __init__(self):
|
37 |
+
super(RelProp, self).__init__()
|
38 |
+
# if not self.training:
|
39 |
+
self.register_forward_hook(forward_hook)
|
40 |
+
|
41 |
+
def gradprop(self, Z, X, S):
|
42 |
+
C = torch.autograd.grad(Z, X, S, retain_graph=True)
|
43 |
+
return C
|
44 |
+
|
45 |
+
def relprop(self, R, alpha):
|
46 |
+
return R
|
47 |
+
|
48 |
+
|
49 |
+
class RelPropSimple(RelProp):
|
50 |
+
def relprop(self, R, alpha):
|
51 |
+
Z = self.forward(self.X)
|
52 |
+
S = safe_divide(R, Z)
|
53 |
+
C = self.gradprop(Z, self.X, S)
|
54 |
+
|
55 |
+
if not torch.is_tensor(self.X):
|
56 |
+
outputs = []
|
57 |
+
outputs.append(self.X[0] * C[0])
|
58 |
+
outputs.append(self.X[1] * C[1])
|
59 |
+
else:
|
60 |
+
outputs = self.X * (C[0])
|
61 |
+
return outputs
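A standalone sketch (hedged) of the rule RelPropSimple implements: the relevance R arriving at a module's output is redistributed to its inputs as X * grad(Z, X, R / Z). For an operation that is linear in its inputs, such as a sum, the total relevance is conserved:

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
Z = x.sum()                           # toy forward pass
R = torch.tensor(6.0)                 # relevance arriving at the output
S = R / Z                             # safe_divide(R, Z) without the stabiliser
(C,) = torch.autograd.grad(Z, x, S, retain_graph=True)
R_in = x * C                          # tensor([1., 2., 3.]), which sums back to 6.0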
|
62 |
+
|
63 |
+
class AddEye(RelPropSimple):
|
64 |
+
# input of shape B, C, seq_len, seq_len
|
65 |
+
def forward(self, input):
|
66 |
+
return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
|
67 |
+
|
68 |
+
class ReLU(nn.ReLU, RelProp):
|
69 |
+
pass
|
70 |
+
|
71 |
+
class GELU(nn.GELU, RelProp):
|
72 |
+
pass
|
73 |
+
|
74 |
+
class Softmax(nn.Softmax, RelProp):
|
75 |
+
pass
|
76 |
+
|
77 |
+
class Mul(RelPropSimple):
|
78 |
+
def forward(self, inputs):
|
79 |
+
return torch.mul(*inputs)
|
80 |
+
|
81 |
+
class Tanh(nn.Tanh, RelProp):
|
82 |
+
pass
|
83 |
+
class LayerNorm(nn.LayerNorm, RelProp):
|
84 |
+
pass
|
85 |
+
|
86 |
+
class Dropout(nn.Dropout, RelProp):
|
87 |
+
pass
|
88 |
+
|
89 |
+
class MatMul(RelPropSimple):
|
90 |
+
def forward(self, inputs):
|
91 |
+
return torch.matmul(*inputs)
|
92 |
+
|
93 |
+
class MaxPool2d(nn.MaxPool2d, RelPropSimple):
|
94 |
+
pass
|
95 |
+
|
|
99 |
+
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
|
100 |
+
pass
|
101 |
+
|
102 |
+
|
103 |
+
class AvgPool2d(nn.AvgPool2d, RelPropSimple):
|
104 |
+
pass
|
105 |
+
|
106 |
+
|
107 |
+
class Add(RelPropSimple):
|
108 |
+
def forward(self, inputs):
|
109 |
+
return torch.add(*inputs)
|
110 |
+
|
111 |
+
def relprop(self, R, alpha):
|
112 |
+
Z = self.forward(self.X)
|
113 |
+
S = safe_divide(R, Z)
|
114 |
+
C = self.gradprop(Z, self.X, S)
|
115 |
+
|
116 |
+
a = self.X[0] * C[0]
|
117 |
+
b = self.X[1] * C[1]
|
118 |
+
|
119 |
+
a_sum = a.sum()
|
120 |
+
b_sum = b.sum()
|
121 |
+
|
122 |
+
a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
|
123 |
+
b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
|
124 |
+
|
125 |
+
a = a * safe_divide(a_fact, a.sum())
|
126 |
+
b = b * safe_divide(b_fact, b.sum())
|
127 |
+
|
128 |
+
outputs = [a, b]
|
129 |
+
|
130 |
+
return outputs
|
131 |
+
|
132 |
+
class einsum(RelPropSimple):
|
133 |
+
def __init__(self, equation):
|
134 |
+
super().__init__()
|
135 |
+
self.equation = equation
|
136 |
+
def forward(self, *operands):
|
137 |
+
return torch.einsum(self.equation, *operands)
|
138 |
+
|
139 |
+
class IndexSelect(RelProp):
|
140 |
+
def forward(self, inputs, dim, indices):
|
141 |
+
self.__setattr__('dim', dim)
|
142 |
+
self.__setattr__('indices', indices)
|
143 |
+
|
144 |
+
return torch.index_select(inputs, dim, indices)
|
145 |
+
|
146 |
+
def relprop(self, R, alpha):
|
147 |
+
Z = self.forward(self.X, self.dim, self.indices)
|
148 |
+
S = safe_divide(R, Z)
|
149 |
+
C = self.gradprop(Z, self.X, S)
|
150 |
+
|
151 |
+
if not torch.is_tensor(self.X):
|
152 |
+
outputs = []
|
153 |
+
outputs.append(self.X[0] * C[0])
|
154 |
+
outputs.append(self.X[1] * C[1])
|
155 |
+
else:
|
156 |
+
outputs = self.X * (C[0])
|
157 |
+
return outputs
|
158 |
+
|
159 |
+
|
160 |
+
|
161 |
+
class Clone(RelProp):
|
162 |
+
def forward(self, input, num):
|
163 |
+
self.__setattr__('num', num)
|
164 |
+
outputs = []
|
165 |
+
for _ in range(num):
|
166 |
+
outputs.append(input)
|
167 |
+
|
168 |
+
return outputs
|
169 |
+
|
170 |
+
def relprop(self, R, alpha):
|
171 |
+
Z = []
|
172 |
+
for _ in range(self.num):
|
173 |
+
Z.append(self.X)
|
174 |
+
S = [safe_divide(r, z) for r, z in zip(R, Z)]
|
175 |
+
C = self.gradprop(Z, self.X, S)[0]
|
176 |
+
|
177 |
+
R = self.X * C
|
178 |
+
|
179 |
+
return R
|
180 |
+
|
181 |
+
|
182 |
+
class Cat(RelProp):
|
183 |
+
def forward(self, inputs, dim):
|
184 |
+
self.__setattr__('dim', dim)
|
185 |
+
return torch.cat(inputs, dim)
|
186 |
+
|
187 |
+
def relprop(self, R, alpha):
|
188 |
+
Z = self.forward(self.X, self.dim)
|
189 |
+
S = safe_divide(R, Z)
|
190 |
+
C = self.gradprop(Z, self.X, S)
|
191 |
+
|
192 |
+
outputs = []
|
193 |
+
for x, c in zip(self.X, C):
|
194 |
+
outputs.append(x * c)
|
195 |
+
|
196 |
+
return outputs
|
197 |
+
|
198 |
+
|
199 |
+
class Sequential(nn.Sequential):
|
200 |
+
def relprop(self, R, alpha):
|
201 |
+
for m in reversed(self._modules.values()):
|
202 |
+
R = m.relprop(R, alpha)
|
203 |
+
return R
|
204 |
+
|
205 |
+
|
206 |
+
class BatchNorm2d(nn.BatchNorm2d, RelProp):
|
207 |
+
def relprop(self, R, alpha):
|
208 |
+
X = self.X
|
209 |
+
beta = 1 - alpha
|
210 |
+
weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
|
211 |
+
(self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
|
212 |
+
Z = X * weight + 1e-9
|
213 |
+
S = R / Z
|
214 |
+
Ca = S * weight
|
215 |
+
R = self.X * (Ca)
|
216 |
+
return R
|
217 |
+
|
218 |
+
|
219 |
+
class Linear(nn.Linear, RelProp):
|
220 |
+
def relprop(self, R, alpha):
|
221 |
+
beta = alpha - 1
|
222 |
+
pw = torch.clamp(self.weight, min=0)
|
223 |
+
nw = torch.clamp(self.weight, max=0)
|
224 |
+
px = torch.clamp(self.X, min=0)
|
225 |
+
nx = torch.clamp(self.X, max=0)
|
226 |
+
|
227 |
+
def f(w1, w2, x1, x2):
|
228 |
+
Z1 = F.linear(x1, w1)
|
229 |
+
Z2 = F.linear(x2, w2)
|
230 |
+
S1 = safe_divide(R, Z1 + Z2)
|
231 |
+
S2 = safe_divide(R, Z1 + Z2)
|
232 |
+
C1 = x1 * self.gradprop(Z1, x1, S1)[0]
|
233 |
+
C2 = x2 * self.gradprop(Z2, x2, S2)[0]
|
234 |
+
|
235 |
+
return C1 + C2
|
236 |
+
|
237 |
+
activator_relevances = f(pw, nw, px, nx)
|
238 |
+
inhibitor_relevances = f(nw, pw, px, nx)
|
239 |
+
|
240 |
+
R = alpha * activator_relevances - beta * inhibitor_relevances
|
241 |
+
|
242 |
+
return R
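A hedged worked example of the alpha-beta rule that Linear.relprop applies: positive and negative weight/input contributions are propagated separately and recombined as alpha * activators - beta * inhibitors with beta = alpha - 1 (alpha = 1 gives the LRP-alpha1-beta0 rule). For a bias-free linear layer the total relevance is conserved:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 4)
w = torch.randn(3, 4)
R = torch.ones(1, 3)                       # relevance at the layer output
pw, nw = w.clamp(min=0), w.clamp(max=0)
px, nx = x.clamp(min=0), x.clamp(max=0)

def propagate(w1, w2, x1, x2):
    Z = F.linear(x1, w1) + F.linear(x2, w2)
    S = R / Z                              # the module uses safe_divide here
    return x1 * (S @ w1) + x2 * (S @ w2)   # closed form of gradprop for a linear map

alpha, beta = 1.0, 0.0
R_in = alpha * propagate(pw, nw, px, nx) - beta * propagate(nw, pw, px, nx)
print(R.sum().item(), R_in.sum().item())   # both (approximately) 3.0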
|
243 |
+
|
244 |
+
|
245 |
+
class Conv2d(nn.Conv2d, RelProp):
|
246 |
+
def gradprop2(self, DY, weight):
|
247 |
+
Z = self.forward(self.X)
|
248 |
+
|
249 |
+
output_padding = self.X.size()[2] - (
|
250 |
+
(Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
|
251 |
+
|
252 |
+
return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
|
253 |
+
|
254 |
+
def relprop(self, R, alpha):
|
255 |
+
if self.X.shape[1] == 3:
|
256 |
+
pw = torch.clamp(self.weight, min=0)
|
257 |
+
nw = torch.clamp(self.weight, max=0)
|
258 |
+
X = self.X
|
259 |
+
L = self.X * 0 + \
|
260 |
+
torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
|
261 |
+
keepdim=True)[0]
|
262 |
+
H = self.X * 0 + \
|
263 |
+
torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
|
264 |
+
keepdim=True)[0]
|
265 |
+
Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
|
266 |
+
torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
|
267 |
+
torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
|
268 |
+
|
269 |
+
S = R / Za
|
270 |
+
C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
|
271 |
+
R = C
|
272 |
+
else:
|
273 |
+
beta = alpha - 1
|
274 |
+
pw = torch.clamp(self.weight, min=0)
|
275 |
+
nw = torch.clamp(self.weight, max=0)
|
276 |
+
px = torch.clamp(self.X, min=0)
|
277 |
+
nx = torch.clamp(self.X, max=0)
|
278 |
+
|
279 |
+
def f(w1, w2, x1, x2):
|
280 |
+
Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
|
281 |
+
Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
|
282 |
+
S1 = safe_divide(R, Z1)
|
283 |
+
S2 = safe_divide(R, Z2)
|
284 |
+
C1 = x1 * self.gradprop(Z1, x1, S1)[0]
|
285 |
+
C2 = x2 * self.gradprop(Z2, x2, S2)[0]
|
286 |
+
return C1 + C2
|
287 |
+
|
288 |
+
activator_relevances = f(pw, nw, px, nx)
|
289 |
+
inhibitor_relevances = f(nw, pw, px, nx)
|
290 |
+
|
291 |
+
R = alpha * activator_relevances - beta * inhibitor_relevances
|
292 |
+
return R
|
lxmert/src/lxmert_lrp.py
ADDED
@@ -0,0 +1,1693 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 Hao Tan, Mohit Bansal, and the HuggingFace team
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch lxmert model. """
|
16 |
+
|
17 |
+
import math
|
18 |
+
import os
|
19 |
+
import warnings
|
20 |
+
import copy
|
21 |
+
from dataclasses import dataclass
|
22 |
+
from typing import Optional, Tuple
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import CrossEntropyLoss, SmoothL1Loss
|
27 |
+
from lxmert.lxmert.src.layers import *
|
28 |
+
from transformers.file_utils import (
|
29 |
+
ModelOutput,
|
30 |
+
add_code_sample_docstrings,
|
31 |
+
add_start_docstrings,
|
32 |
+
add_start_docstrings_to_model_forward,
|
33 |
+
replace_return_docstrings,
|
34 |
+
)
|
35 |
+
from transformers.modeling_utils import PreTrainedModel
|
36 |
+
from transformers.utils import logging
|
37 |
+
from transformers.configuration_lxmert import LxmertConfig
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__)
|
40 |
+
|
41 |
+
_CONFIG_FOR_DOC = "LxmertConfig"
|
42 |
+
_TOKENIZER_FOR_DOC = "LxmertTokenizer"
|
43 |
+
|
44 |
+
LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
45 |
+
"unc-nlp/lxmert-base-uncased",
|
46 |
+
]
|
47 |
+
|
48 |
+
ACT2FN = {
|
49 |
+
"relu": ReLU,
|
50 |
+
"tanh": Tanh,
|
51 |
+
"gelu": GELU,
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
@dataclass
|
56 |
+
class LxmertModelOutput(ModelOutput):
|
57 |
+
"""
|
58 |
+
Lxmert's outputs that contain the last hidden states, pooled outputs, and attention probabilities for the language,
|
59 |
+
visual, and cross-modality encoders. (note: the visual encoder in Lxmert is referred to as the "relation-ship"
|
60 |
+
encoder")
|
61 |
+
|
62 |
+
|
63 |
+
Args:
|
64 |
+
language_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
65 |
+
Sequence of hidden-states at the output of the last layer of the language encoder.
|
66 |
+
vision_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
67 |
+
Sequence of hidden-states at the output of the last layer of the visual encoder.
|
68 |
+
pooled_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
|
69 |
+
Last layer hidden-state of the first token of the sequence (classification, CLS, token) further processed
|
70 |
+
by a Linear layer and a Tanh activation function.
|
71 |
+
language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
72 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
73 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
74 |
+
vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
75 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
76 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
77 |
+
language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
78 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
79 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
80 |
+
weighted average in the self-attention heads.
|
81 |
+
vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
82 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
83 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
84 |
+
weighted average in the self-attention heads.
|
85 |
+
cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
86 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
87 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
88 |
+
weighted average in the self-attention heads.
|
89 |
+
"""
|
90 |
+
|
91 |
+
language_output: Optional[torch.FloatTensor] = None
|
92 |
+
vision_output: Optional[torch.FloatTensor] = None
|
93 |
+
pooled_output: Optional[torch.FloatTensor] = None
|
94 |
+
language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
95 |
+
vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
96 |
+
language_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
97 |
+
vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
98 |
+
cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
99 |
+
|
100 |
+
|
101 |
+
@dataclass
|
102 |
+
class LxmertForQuestionAnsweringOutput(ModelOutput):
|
103 |
+
"""
|
104 |
+
Output type of :class:`~transformers.LxmertForQuestionAnswering`.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
108 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction
|
109 |
+
(classification) loss.
|
110 |
+
question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`, `optional`):
|
111 |
+
Prediction scores of question answering objective (classification).
|
112 |
+
language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
113 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
114 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
115 |
+
vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
116 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
117 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
118 |
+
language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
119 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
120 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
121 |
+
weighted average in the self-attention heads.
|
122 |
+
vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
123 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
124 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
125 |
+
weighted average in the self-attention heads.
|
126 |
+
cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
127 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
128 |
+
sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
|
129 |
+
weighted average in the self-attention heads.
|
130 |
+
"""
|
131 |
+
|
132 |
+
loss: Optional[torch.FloatTensor] = None
|
133 |
+
question_answering_score: Optional[torch.FloatTensor] = None
|
134 |
+
language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
135 |
+
vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
136 |
+
language_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
137 |
+
vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
138 |
+
cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
139 |
+
|
140 |
+
|
141 |
+
@dataclass
|
142 |
+
class LxmertForPreTrainingOutput(ModelOutput):
|
143 |
+
"""
|
144 |
+
Output type of :class:`~transformers.LxmertForPreTraining`.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
148 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction
|
149 |
+
(classification) loss.
|
150 |
+
prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
151 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
152 |
+
cross_relationship_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
153 |
+
Prediction scores of the textual matching objective (classification) head (scores of True/False
|
154 |
+
continuation before SoftMax).
|
155 |
+
question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`):
|
156 |
+
Prediction scores of question answering objective (classification).
|
157 |
+
language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
158 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
159 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
160 |
+
vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
161 |
+
Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
|
162 |
+
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
163 |
+
language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
164 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
165 |
+
sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the
|
166 |
+
weighted average in the self-attention heads.
|
167 |
+
vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
168 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
169 |
+
sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the
|
170 |
+
weighted average in the self-attention heads.
|
171 |
+
cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
172 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
173 |
+
sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the
|
174 |
+
weighted average in the self-attention heads.
|
175 |
+
|
176 |
+
"""
|
177 |
+
|
178 |
+
loss: Optional[torch.FloatTensor] = None
|
179 |
+
prediction_logits: Optional[torch.FloatTensor] = None
|
180 |
+
cross_relationship_score: Optional[torch.FloatTensor] = None
|
181 |
+
question_answering_score: Optional[torch.FloatTensor] = None
|
182 |
+
language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
183 |
+
vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
184 |
+
language_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
185 |
+
vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
186 |
+
cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
187 |
+
|
188 |
+
|
189 |
+
def load_tf_weights_in_lxmert(model, config, tf_checkpoint_path):
|
190 |
+
"""Load tf checkpoints in a pytorch model."""
|
191 |
+
try:
|
192 |
+
import re
|
193 |
+
|
194 |
+
import numpy as np
|
195 |
+
import tensorflow as tf
|
196 |
+
except ImportError:
|
197 |
+
logger.error(
|
198 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
199 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
200 |
+
)
|
201 |
+
raise
|
202 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
203 |
+
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
204 |
+
# Load weights from TF model
|
205 |
+
init_vars = tf.train.list_variables(tf_path)
|
206 |
+
names = []
|
207 |
+
arrays = []
|
208 |
+
for name, shape in init_vars:
|
209 |
+
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
210 |
+
array = tf.train.load_variable(tf_path, name)
|
211 |
+
names.append(name)
|
212 |
+
arrays.append(array)
|
213 |
+
|
214 |
+
for name, array in zip(names, arrays):
|
215 |
+
name = name.split("/")
|
216 |
+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
|
217 |
+
# which are not required for using pretrained model
|
218 |
+
if any(
|
219 |
+
n
|
220 |
+
in [
|
221 |
+
"adam_v",
|
222 |
+
"adam_m",
|
223 |
+
"AdamWeightDecayOptimizer",
|
224 |
+
"AdamWeightDecayOptimizer_1",
|
225 |
+
"global_step",
|
226 |
+
]
|
227 |
+
for n in name
|
228 |
+
):
|
229 |
+
logger.info("Skipping {}".format("/".join(name)))
|
230 |
+
continue
|
231 |
+
pointer = model
|
232 |
+
for m_name in name:
|
233 |
+
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
234 |
+
scope_names = re.split(r"_(\d+)", m_name)
|
235 |
+
else:
|
236 |
+
scope_names = [m_name]
|
237 |
+
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
238 |
+
pointer = getattr(pointer, "weight")
|
239 |
+
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
240 |
+
pointer = getattr(pointer, "bias")
|
241 |
+
elif scope_names[0] == "output_weights":
|
242 |
+
pointer = getattr(pointer, "weight")
|
243 |
+
elif scope_names[0] == "squad":
|
244 |
+
pointer = getattr(pointer, "classifier")
|
245 |
+
else:
|
246 |
+
try:
|
247 |
+
pointer = getattr(pointer, scope_names[0])
|
248 |
+
except AttributeError:
|
249 |
+
logger.info("Skipping {}".format("/".join(name)))
|
250 |
+
continue
|
251 |
+
if len(scope_names) >= 2:
|
252 |
+
num = int(scope_names[1])
|
253 |
+
pointer = pointer[num]
|
254 |
+
if m_name[-11:] == "_embeddings":
|
255 |
+
pointer = getattr(pointer, "weight")
|
256 |
+
elif m_name == "kernel":
|
257 |
+
array = np.transpose(array)
|
258 |
+
try:
|
259 |
+
assert pointer.shape == array.shape
|
260 |
+
except AssertionError as e:
|
261 |
+
e.args += (pointer.shape, array.shape)
|
262 |
+
raise
|
263 |
+
logger.info("Initialize PyTorch weight {}".format(name))
|
264 |
+
pointer.data = torch.from_numpy(array)
|
265 |
+
return model
|
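# Illustrative usage sketch (added note, not part of the original file): converting a
# TensorFlow LXMERT checkpoint into this PyTorch model. The checkpoint path below is a
# placeholder, and `LxmertForPreTraining` is defined further down in this module.
#
#     config = LxmertConfig.from_pretrained("unc-nlp/lxmert-base-uncased")
#     model = LxmertForPreTraining(config)
#     model = load_tf_weights_in_lxmert(model, config, "/path/to/tf_checkpoint")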
266 |
+
|
267 |
+
|
268 |
+
class LxmertEmbeddings(nn.Module):
|
269 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
270 |
+
|
271 |
+
def __init__(self, config):
|
272 |
+
super().__init__()
|
273 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
274 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
|
275 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0)
|
276 |
+
|
277 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
278 |
+
# any TensorFlow checkpoint file
|
279 |
+
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
|
280 |
+
self.dropout = Dropout(config.hidden_dropout_prob)
|
281 |
+
|
282 |
+
self.add1 = Add()
|
283 |
+
self.add2 = Add()
|
284 |
+
|
285 |
+
def forward(self, input_ids, token_type_ids=None, inputs_embeds=None):
|
286 |
+
if input_ids is not None:
|
287 |
+
input_shape = input_ids.size()
|
288 |
+
device = input_ids.device
|
289 |
+
else:
|
290 |
+
input_shape = inputs_embeds.size()[:-1]
|
291 |
+
device = inputs_embeds.device
|
292 |
+
seq_length = input_shape[1]
|
293 |
+
|
294 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
295 |
+
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
296 |
+
|
297 |
+
if token_type_ids is None:
|
298 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
299 |
+
|
300 |
+
if inputs_embeds is None:
|
301 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
302 |
+
position_embeddings = self.position_embeddings(position_ids)
|
303 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
304 |
+
|
305 |
+
# embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
306 |
+
embeddings = self.add1([token_type_embeddings, position_embeddings])
|
307 |
+
embeddings = self.add2([embeddings, inputs_embeds])
|
308 |
+
embeddings = self.LayerNorm(embeddings)
|
309 |
+
embeddings = self.dropout(embeddings)
|
310 |
+
return embeddings
|
311 |
+
|
312 |
+
def relprop(self, cam, **kwargs):
|
313 |
+
cam = self.dropout.relprop(cam, **kwargs)
|
314 |
+
cam = self.LayerNorm.relprop(cam, **kwargs)
|
315 |
+
|
316 |
+
# [inputs_embeds, position_embeddings, token_type_embeddings]
|
317 |
+
(cam) = self.add2.relprop(cam, **kwargs)
|
318 |
+
|
319 |
+
return cam
|
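# Note on the wrapper modules used in this class (comment added for clarity): `Add`,
# `Clone`, `MatMul`, `IndexSelect`, etc. come from the accompanying layers.py and act
# like the plain tensor ops in the forward pass while recording their inputs, so that
# relprop() can redistribute relevance backwards. A rough sketch, assuming that API:
#
#     add = Add()
#     out = add([a, b])                          # same value as a + b
#     cam_a, cam_b = add.relprop(cam, alpha=1)   # relevance split between a and b
#
# This is why the sum of embeddings above is written with self.add1/self.add2 instead
# of the usual `+`.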
320 |
+
|
321 |
+
|
322 |
+
class LxmertAttention(nn.Module):
|
323 |
+
def __init__(self, config, ctx_dim=None):
|
324 |
+
super().__init__()
|
325 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
326 |
+
raise ValueError(
|
327 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
328 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
329 |
+
)
|
330 |
+
self.num_attention_heads = config.num_attention_heads
|
331 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
332 |
+
self.head_size = self.num_attention_heads * self.attention_head_size
|
333 |
+
|
334 |
+
# visual_dim = 2048
|
335 |
+
if ctx_dim is None:
|
336 |
+
ctx_dim = config.hidden_size
|
337 |
+
self.query = Linear(config.hidden_size, self.head_size)
|
338 |
+
self.key = Linear(ctx_dim, self.head_size)
|
339 |
+
self.value = Linear(ctx_dim, self.head_size)
|
340 |
+
|
341 |
+
self.dropout = Dropout(config.attention_probs_dropout_prob)
|
342 |
+
|
343 |
+
self.matmul1 = MatMul()
|
344 |
+
self.matmul2 = MatMul()
|
345 |
+
self.softmax = Softmax(dim=-1)
|
346 |
+
self.add = Add()
|
347 |
+
self.mul = Mul()
|
348 |
+
self.head_mask = None
|
349 |
+
self.attention_mask = None
|
350 |
+
self.clone = Clone()
|
351 |
+
|
352 |
+
self.attn = None
|
353 |
+
self.attn_gradients = None
|
354 |
+
self.attn_cam = None
|
355 |
+
|
356 |
+
def get_attn(self):
|
357 |
+
return self.attn
|
358 |
+
|
359 |
+
def save_attn(self, attn):
|
360 |
+
self.attn = attn
|
361 |
+
|
362 |
+
def get_attn_cam(self):
|
363 |
+
return self.attn_cam
|
364 |
+
|
365 |
+
def save_attn_cam(self, attn_cam):
|
366 |
+
self.attn_cam = attn_cam
|
367 |
+
|
368 |
+
def save_attn_gradients(self, attn_gradients):
|
369 |
+
self.attn_gradients = attn_gradients
|
370 |
+
|
371 |
+
def get_attn_gradients(self):
|
372 |
+
return self.attn_gradients
|
373 |
+
|
374 |
+
def transpose_for_scores(self, x):
|
375 |
+
new_x_shape = x.size()[:-1] + (
|
376 |
+
self.num_attention_heads,
|
377 |
+
self.attention_head_size,
|
378 |
+
)
|
379 |
+
x = x.view(*new_x_shape)
|
380 |
+
return x.permute(0, 2, 1, 3)
|
381 |
+
|
382 |
+
def transpose_for_scores_relprop(self, x):
|
383 |
+
return x.permute(0, 2, 1, 3).flatten(2)
|
384 |
+
|
385 |
+
def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
|
386 |
+
key, value = self.clone(context, 2)
|
387 |
+
mixed_query_layer = self.query(hidden_states)
|
388 |
+
# mixed_key_layer = self.key(context)
|
389 |
+
# mixed_value_layer = self.value(context)
|
390 |
+
mixed_key_layer = self.key(key)
|
391 |
+
mixed_value_layer = self.value(value)
|
392 |
+
|
393 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
394 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
395 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
396 |
+
|
397 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
398 |
+
attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
|
399 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
400 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
401 |
+
if attention_mask is not None:
|
402 |
+
attention_scores = self.add([attention_scores, attention_mask])
|
403 |
+
|
404 |
+
# Normalize the attention scores to probabilities.
|
405 |
+
attention_probs = self.softmax(attention_scores)
|
406 |
+
|
407 |
+
self.save_attn(attention_probs)
|
408 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
409 |
+
|
410 |
+
# This is actually dropping out entire tokens to attend to, which might
|
411 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
412 |
+
attention_probs = self.dropout(attention_probs)
|
413 |
+
|
414 |
+
context_layer = self.matmul2([attention_probs, value_layer])
|
415 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
416 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
|
417 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
418 |
+
|
419 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
420 |
+
return outputs
|
421 |
+
|
422 |
+
def relprop(self, cam, **kwargs):
|
423 |
+
# Assume output_attentions == False
|
424 |
+
cam = self.transpose_for_scores(cam)
|
425 |
+
|
426 |
+
# [attention_probs, value_layer]
|
427 |
+
(cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
|
428 |
+
cam1 /= 2
|
429 |
+
cam2 /= 2
|
430 |
+
|
431 |
+
self.save_attn_cam(cam1)
|
432 |
+
|
433 |
+
cam1 = self.dropout.relprop(cam1, **kwargs)
|
434 |
+
|
435 |
+
cam1 = self.softmax.relprop(cam1, **kwargs)
|
436 |
+
|
437 |
+
if self.attention_mask is not None:
|
438 |
+
# [attention_scores, attention_mask]
|
439 |
+
(cam1, _) = self.add.relprop(cam1, **kwargs)
|
440 |
+
|
441 |
+
# [query_layer, key_layer.transpose(-1, -2)]
|
442 |
+
(cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
|
443 |
+
cam1_1 /= 2
|
444 |
+
cam1_2 /= 2
|
445 |
+
|
446 |
+
# query
|
447 |
+
cam1_1 = self.transpose_for_scores_relprop(cam1_1)
|
448 |
+
cam1_1 = self.query.relprop(cam1_1, **kwargs)
|
449 |
+
|
450 |
+
# key
|
451 |
+
cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
|
452 |
+
cam1_2 = self.key.relprop(cam1_2, **kwargs)
|
453 |
+
|
454 |
+
# value
|
455 |
+
cam2 = self.transpose_for_scores_relprop(cam2)
|
456 |
+
cam2 = self.value.relprop(cam2, **kwargs)
|
457 |
+
|
458 |
+
cam = self.clone.relprop((cam1_2, cam2), **kwargs)
|
459 |
+
|
460 |
+
# returning two cams- one for the hidden state and one for the context
|
461 |
+
return (cam1_1, cam)
|
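# Sketch of how the tensors saved by this module are typically consumed downstream
# (illustrative; the actual logic lives in src/ExplanationGenerator.py and may differ):
#
#     grad = attn_module.get_attn_gradients()       # (batch, heads, q_len, k_len)
#     cam = attn_module.get_attn_cam()              # relevance stored by relprop()
#     cam = (grad * cam).clamp(min=0).mean(dim=1)   # gradient-weighted, head-averaged map
#
# These per-layer maps are then aggregated into token/region relevancy matrices for the
# explanations.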
462 |
+
|
463 |
+
|
464 |
+
class LxmertAttentionOutput(nn.Module):
|
465 |
+
def __init__(self, config):
|
466 |
+
super().__init__()
|
467 |
+
self.dense = Linear(config.hidden_size, config.hidden_size)
|
468 |
+
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
|
469 |
+
self.dropout = Dropout(config.hidden_dropout_prob)
|
470 |
+
self.add = Add()
|
471 |
+
|
472 |
+
def forward(self, hidden_states, input_tensor):
|
473 |
+
hidden_states = self.dense(hidden_states)
|
474 |
+
hidden_states = self.dropout(hidden_states)
|
475 |
+
add = self.add([hidden_states, input_tensor])
|
476 |
+
hidden_states = self.LayerNorm(add)
|
477 |
+
return hidden_states
|
478 |
+
|
479 |
+
def relprop(self, cam, **kwargs):
|
480 |
+
cam = self.LayerNorm.relprop(cam, **kwargs)
|
481 |
+
# [hidden_states, input_tensor]
|
482 |
+
(cam1, cam2) = self.add.relprop(cam, **kwargs)
|
483 |
+
cam1 = self.dropout.relprop(cam1, **kwargs)
|
484 |
+
cam1 = self.dense.relprop(cam1, **kwargs)
|
485 |
+
|
486 |
+
return (cam1, cam2)
|
487 |
+
|
488 |
+
|
489 |
+
class LxmertCrossAttentionLayer(nn.Module):
|
490 |
+
def __init__(self, config):
|
491 |
+
super().__init__()
|
492 |
+
self.att = LxmertAttention(config)
|
493 |
+
self.output = LxmertAttentionOutput(config)
|
494 |
+
self.clone = Clone()
|
495 |
+
|
496 |
+
def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None, output_attentions=False):
|
497 |
+
inp1, inp2 = self.clone(input_tensor, 2)
|
498 |
+
output = self.att(inp1, ctx_tensor, ctx_att_mask, output_attentions=output_attentions)
|
499 |
+
if output_attentions:
|
500 |
+
attention_probs = output[1]
|
501 |
+
attention_output = self.output(output[0], inp2)
|
502 |
+
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
503 |
+
return outputs
|
504 |
+
|
505 |
+
def relprop(self, cam, **kwargs):
|
506 |
+
cam_output, cam_inp2 = self.output.relprop(cam, **kwargs)
|
507 |
+
cam_inp1, cam_ctx = self.att.relprop(cam_output, **kwargs)
|
508 |
+
cam_inp = self.clone.relprop((cam_inp1, cam_inp2), **kwargs)
|
509 |
+
|
510 |
+
return (cam_inp, cam_ctx)
|
511 |
+
|
512 |
+
|
513 |
+
class LxmertSelfAttentionLayer(nn.Module):
|
514 |
+
def __init__(self, config):
|
515 |
+
super().__init__()
|
516 |
+
self.self = LxmertAttention(config)
|
517 |
+
self.output = LxmertAttentionOutput(config)
|
518 |
+
self.clone = Clone()
|
519 |
+
|
520 |
+
def forward(self, input_tensor, attention_mask, output_attentions=False):
|
521 |
+
inp1, inp2, inp3 = self.clone(input_tensor, 3)
|
522 |
+
# Self attention attends to itself, thus keys and queries are the same (input_tensor).
|
523 |
+
output = self.self(
|
524 |
+
inp1,
|
525 |
+
inp2,
|
526 |
+
attention_mask,
|
527 |
+
output_attentions=output_attentions,
|
528 |
+
)
|
529 |
+
if output_attentions:
|
530 |
+
attention_probs = output[1]
|
531 |
+
attention_output = self.output(output[0], inp3)
|
532 |
+
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
533 |
+
return outputs
|
534 |
+
|
535 |
+
def relprop(self, cam, **kwargs):
|
536 |
+
cam_output, cam_inp3 = self.output.relprop(cam, **kwargs)
|
537 |
+
cam_inp1, cam_inp2 = self.self.relprop(cam_output, **kwargs)
|
538 |
+
cam_inp = self.clone.relprop((cam_inp1, cam_inp2, cam_inp3), **kwargs)
|
539 |
+
|
540 |
+
return cam_inp
|
541 |
+
|
542 |
+
|
543 |
+
class LxmertIntermediate(nn.Module):
|
544 |
+
def __init__(self, config):
|
545 |
+
super().__init__()
|
546 |
+
self.dense = Linear(config.hidden_size, config.intermediate_size)
|
547 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]()
|
548 |
+
|
549 |
+
def forward(self, hidden_states):
|
550 |
+
hidden_states = self.dense(hidden_states)
|
551 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
552 |
+
return hidden_states
|
553 |
+
|
554 |
+
def relprop(self, cam, **kwargs):
|
555 |
+
cam = self.intermediate_act_fn.relprop(cam, **kwargs)
|
556 |
+
cam = self.dense.relprop(cam, **kwargs)
|
557 |
+
return cam
|
558 |
+
|
559 |
+
|
560 |
+
class LxmertOutput(nn.Module):
|
561 |
+
def __init__(self, config):
|
562 |
+
super().__init__()
|
563 |
+
self.dense = Linear(config.intermediate_size, config.hidden_size)
|
564 |
+
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
|
565 |
+
self.dropout = Dropout(config.hidden_dropout_prob)
|
566 |
+
self.add = Add()
|
567 |
+
|
568 |
+
def forward(self, hidden_states, input_tensor):
|
569 |
+
hidden_states = self.dense(hidden_states)
|
570 |
+
hidden_states = self.dropout(hidden_states)
|
571 |
+
add = self.add([hidden_states, input_tensor])
|
572 |
+
hidden_states = self.LayerNorm(add)
|
573 |
+
return hidden_states
|
574 |
+
|
575 |
+
def relprop(self, cam, **kwargs):
|
576 |
+
cam = self.LayerNorm.relprop(cam, **kwargs)
|
577 |
+
# [hidden_states, input_tensor]
|
578 |
+
(cam1, cam2)= self.add.relprop(cam, **kwargs)
|
579 |
+
cam1 = self.dropout.relprop(cam1, **kwargs)
|
580 |
+
cam1 = self.dense.relprop(cam1, **kwargs)
|
581 |
+
return (cam1, cam2)
|
582 |
+
|
583 |
+
|
584 |
+
class LxmertLayer(nn.Module):
|
585 |
+
def __init__(self, config):
|
586 |
+
super().__init__()
|
587 |
+
self.attention = LxmertSelfAttentionLayer(config)
|
588 |
+
self.intermediate = LxmertIntermediate(config)
|
589 |
+
self.output = LxmertOutput(config)
|
590 |
+
self.clone = Clone()
|
591 |
+
|
592 |
+
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
593 |
+
outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
|
594 |
+
attention_output = outputs[0]
|
595 |
+
ao1, ao2 = self.clone(attention_output, 2)
|
596 |
+
intermediate_output = self.intermediate(ao1)
|
597 |
+
layer_output = self.output(intermediate_output, ao2)
|
598 |
+
outputs = (layer_output,) + outputs[1:] # add attentions if we output them
|
599 |
+
return outputs
|
600 |
+
|
601 |
+
def relprop(self, cam, **kwargs):
|
602 |
+
(cam1, cam2) = self.output.relprop(cam, **kwargs)
|
603 |
+
cam1 = self.intermediate.relprop(cam1, **kwargs)
|
604 |
+
cam = self.clone.relprop((cam1, cam2), **kwargs)
|
605 |
+
cam = self.attention.relprop(cam, **kwargs)
|
606 |
+
return cam
|
607 |
+
|
608 |
+
|
609 |
+
class LxmertXLayer(nn.Module):
|
610 |
+
def __init__(self, config):
|
611 |
+
super().__init__()
|
612 |
+
# The cross-attention Layer
|
613 |
+
self.visual_attention = LxmertCrossAttentionLayer(config)
|
614 |
+
|
615 |
+
# Self-attention Layers
|
616 |
+
self.lang_self_att = LxmertSelfAttentionLayer(config)
|
617 |
+
self.visn_self_att = LxmertSelfAttentionLayer(config)
|
618 |
+
|
619 |
+
# Intermediate and Output Layers (FFNs)
|
620 |
+
self.lang_inter = LxmertIntermediate(config)
|
621 |
+
self.lang_output = LxmertOutput(config)
|
622 |
+
self.visn_inter = LxmertIntermediate(config)
|
623 |
+
self.visn_output = LxmertOutput(config)
|
624 |
+
|
625 |
+
self.clone1 = Clone()
|
626 |
+
self.clone2 = Clone()
|
627 |
+
self.clone3 = Clone()
|
628 |
+
self.clone4 = Clone()
|
629 |
+
|
630 |
+
def cross_att(
|
631 |
+
self,
|
632 |
+
lang_input,
|
633 |
+
lang_attention_mask,
|
634 |
+
visual_input,
|
635 |
+
visual_attention_mask,
|
636 |
+
output_x_attentions=False,
|
637 |
+
):
|
638 |
+
lang_input1, lang_input2 = self.clone1(lang_input, 2)
|
639 |
+
visual_input1, visual_input2 = self.clone2(visual_input, 2)
|
640 |
+
if not hasattr(self, 'visual_attention_copy'):
|
641 |
+
self.visual_attention_copy = copy.deepcopy(self.visual_attention)
|
642 |
+
# Cross Attention
|
643 |
+
lang_att_output = self.visual_attention(
|
644 |
+
lang_input1,
|
645 |
+
visual_input1,
|
646 |
+
ctx_att_mask=visual_attention_mask,
|
647 |
+
output_attentions=output_x_attentions,
|
648 |
+
)
|
649 |
+
visual_att_output = self.visual_attention_copy(
|
650 |
+
visual_input2,
|
651 |
+
lang_input2,
|
652 |
+
ctx_att_mask=lang_attention_mask,
|
653 |
+
output_attentions=False,
|
654 |
+
)
|
655 |
+
return lang_att_output, visual_att_output
|
656 |
+
|
657 |
+
def relprop_cross(self, cam, **kwargs):
|
658 |
+
cam_lang, cam_vis = cam
|
659 |
+
cam_vis2, cam_lang2 = self.visual_attention_copy.relprop(cam_vis, **kwargs)
|
660 |
+
cam_lang1, cam_vis1 = self.visual_attention.relprop(cam_lang, **kwargs)
|
661 |
+
cam_vis = self.clone2.relprop((cam_vis1, cam_vis2), **kwargs)
|
662 |
+
cam_lang = self.clone1.relprop((cam_lang1, cam_lang2), **kwargs)
|
663 |
+
return cam_lang, cam_vis
|
664 |
+
|
665 |
+
|
666 |
+
def self_att(self, lang_input, lang_attention_mask, visual_input, visual_attention_mask):
|
667 |
+
# Self Attention
|
668 |
+
lang_att_output = self.lang_self_att(lang_input, lang_attention_mask, output_attentions=False)
|
669 |
+
visual_att_output = self.visn_self_att(visual_input, visual_attention_mask, output_attentions=False)
|
670 |
+
return lang_att_output[0], visual_att_output[0]
|
671 |
+
|
672 |
+
def relprop_self(self, cam, **kwargs):
|
673 |
+
cam_lang, cam_vis = cam
|
674 |
+
cam_vis = self.visn_self_att.relprop(cam_vis, **kwargs)
|
675 |
+
cam_lang = self.lang_self_att.relprop(cam_lang, **kwargs)
|
676 |
+
return cam_lang, cam_vis
|
677 |
+
|
678 |
+
def output_fc(self, lang_input, visual_input):
|
679 |
+
lang_input1, lang_input2 = self.clone3(lang_input, 2)
|
680 |
+
visual_input1, visual_input2 = self.clone4(visual_input, 2)
|
681 |
+
# FC layers
|
682 |
+
lang_inter_output = self.lang_inter(lang_input1)
|
683 |
+
visual_inter_output = self.visn_inter(visual_input1)
|
684 |
+
|
685 |
+
# Layer output
|
686 |
+
lang_output = self.lang_output(lang_inter_output, lang_input2)
|
687 |
+
visual_output = self.visn_output(visual_inter_output, visual_input2)
|
688 |
+
|
689 |
+
return lang_output, visual_output
|
690 |
+
|
691 |
+
def relprop_output(self, cam, **kwargs):
|
692 |
+
cam_lang, cam_vis = cam
|
693 |
+
cam_vis_inter, cam_vis2 = self.visn_output.relprop(cam_vis, **kwargs)
|
694 |
+
cam_lang_inter, cam_lang2 = self.lang_output.relprop(cam_lang, **kwargs)
|
695 |
+
cam_vis1 = self.visn_inter.relprop(cam_vis_inter, **kwargs)
|
696 |
+
cam_lang1 = self.lang_inter.relprop(cam_lang_inter, **kwargs)
|
697 |
+
cam_vis = self.clone4.relprop((cam_vis1, cam_vis2), **kwargs)
|
698 |
+
cam_lang = self.clone3.relprop((cam_lang1, cam_lang2), **kwargs)
|
699 |
+
return cam_lang, cam_vis
|
700 |
+
|
701 |
+
def forward(
|
702 |
+
self,
|
703 |
+
lang_feats,
|
704 |
+
lang_attention_mask,
|
705 |
+
visual_feats,
|
706 |
+
visual_attention_mask,
|
707 |
+
output_attentions=False,
|
708 |
+
):
|
709 |
+
lang_att_output, visual_att_output = self.cross_att(
|
710 |
+
lang_input=lang_feats,
|
711 |
+
lang_attention_mask=lang_attention_mask,
|
712 |
+
visual_input=visual_feats,
|
713 |
+
visual_attention_mask=visual_attention_mask,
|
714 |
+
output_x_attentions=output_attentions,
|
715 |
+
)
|
716 |
+
attention_probs = lang_att_output[1:]
|
717 |
+
lang_att_output, visual_att_output = self.self_att(
|
718 |
+
lang_att_output[0],
|
719 |
+
lang_attention_mask,
|
720 |
+
visual_att_output[0],
|
721 |
+
visual_attention_mask,
|
722 |
+
)
|
723 |
+
|
724 |
+
lang_output, visual_output = self.output_fc(lang_att_output, visual_att_output)
|
725 |
+
return (
|
726 |
+
(
|
727 |
+
lang_output,
|
728 |
+
visual_output,
|
729 |
+
attention_probs[0],
|
730 |
+
)
|
731 |
+
if output_attentions
|
732 |
+
else (lang_output, visual_output)
|
733 |
+
)
|
734 |
+
|
735 |
+
def relprop(self, cam, **kwargs):
|
736 |
+
cam_lang, cam_vis = cam
|
737 |
+
cam_lang, cam_vis = self.relprop_output((cam_lang, cam_vis), **kwargs)
|
738 |
+
cam_lang, cam_vis = self.relprop_self((cam_lang, cam_vis), **kwargs)
|
739 |
+
cam_lang, cam_vis = self.relprop_cross((cam_lang, cam_vis), **kwargs)
|
740 |
+
return cam_lang, cam_vis
|
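# Data flow of one cross-modality layer (comment added for clarity):
#
#     (lang, vis) --cross_att--> --self_att--> --output_fc--> (lang_out, vis_out)
#
# relprop() above walks the same pipeline in reverse (relprop_output -> relprop_self ->
# relprop_cross), so relevance leaves the layer through the same paths the activations
# entered it.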
741 |
+
|
742 |
+
class LxmertVisualFeatureEncoder(nn.Module):
|
743 |
+
def __init__(self, config):
|
744 |
+
super().__init__()
|
745 |
+
feat_dim = config.visual_feat_dim
|
746 |
+
pos_dim = config.visual_pos_dim
|
747 |
+
|
748 |
+
# Object feature encoding
|
749 |
+
self.visn_fc = Linear(feat_dim, config.hidden_size)
|
750 |
+
self.visn_layer_norm = LayerNorm(config.hidden_size, eps=1e-12)
|
751 |
+
|
752 |
+
# Box position encoding
|
753 |
+
self.box_fc = Linear(pos_dim, config.hidden_size)
|
754 |
+
self.box_layer_norm = LayerNorm(config.hidden_size, eps=1e-12)
|
755 |
+
|
756 |
+
self.dropout = Dropout(config.hidden_dropout_prob)
|
757 |
+
|
758 |
+
def forward(self, visual_feats, visual_pos):
|
759 |
+
x = self.visn_fc(visual_feats)
|
760 |
+
x = self.visn_layer_norm(x)
|
761 |
+
y = self.box_fc(visual_pos)
|
762 |
+
y = self.box_layer_norm(y)
|
763 |
+
output = (x + y) / 2
|
764 |
+
|
765 |
+
output = self.dropout(output)
|
766 |
+
return output
|
767 |
+
|
768 |
+
def relprop(self, cam, **kwargs):
|
769 |
+
cam = self.dropout.relprop(cam, **kwargs)
|
770 |
+
cam = self.visn_layer_norm.relprop(cam, **kwargs)
|
771 |
+
cam = self.visn_fc.relprop(cam, **kwargs)
|
772 |
+
return cam
|
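# Note (added comment): the visual embedding is the average of the projected object
# features and the projected box coordinates, (x + y) / 2. The relprop() above only
# follows the feature branch (visn_layer_norm / visn_fc); relevance assigned to the box
# branch is dropped, which is a simplification in this LRP variant.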
773 |
+
|
774 |
+
class LxmertEncoder(nn.Module):
|
775 |
+
def __init__(self, config):
|
776 |
+
super().__init__()
|
777 |
+
|
778 |
+
# Obj-level image embedding layer
|
779 |
+
self.visn_fc = LxmertVisualFeatureEncoder(config)
|
780 |
+
self.config = config
|
781 |
+
|
782 |
+
# Number of layers
|
783 |
+
self.num_l_layers = config.l_layers
|
784 |
+
self.num_x_layers = config.x_layers
|
785 |
+
self.num_r_layers = config.r_layers
|
786 |
+
|
787 |
+
# Layers
|
788 |
+
# Using self.layer instead of self.l_layer to support loading BERT weights.
|
789 |
+
self.layer = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_l_layers)])
|
790 |
+
self.x_layers = nn.ModuleList([LxmertXLayer(config) for _ in range(self.num_x_layers)])
|
791 |
+
self.r_layers = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_r_layers)])
|
792 |
+
|
793 |
+
def forward(
|
794 |
+
self,
|
795 |
+
lang_feats,
|
796 |
+
lang_attention_mask,
|
797 |
+
visual_feats,
|
798 |
+
visual_pos,
|
799 |
+
visual_attention_mask=None,
|
800 |
+
output_attentions=None,
|
801 |
+
):
|
802 |
+
|
803 |
+
vision_hidden_states = ()
|
804 |
+
language_hidden_states = ()
|
805 |
+
vision_attentions = () if output_attentions or self.config.output_attentions else None
|
806 |
+
language_attentions = () if output_attentions or self.config.output_attentions else None
|
807 |
+
cross_encoder_attentions = () if output_attentions or self.config.output_attentions else None
|
808 |
+
|
809 |
+
visual_feats = self.visn_fc(visual_feats, visual_pos)
|
810 |
+
|
811 |
+
# Run language layers
|
812 |
+
for layer_module in self.layer:
|
813 |
+
l_outputs = layer_module(lang_feats, lang_attention_mask, output_attentions=output_attentions)
|
814 |
+
lang_feats = l_outputs[0]
|
815 |
+
language_hidden_states = language_hidden_states + (lang_feats,)
|
816 |
+
if language_attentions is not None:
|
817 |
+
language_attentions = language_attentions + (l_outputs[1],)
|
818 |
+
|
819 |
+
# Run relational layers
|
820 |
+
for layer_module in self.r_layers:
|
821 |
+
v_outputs = layer_module(visual_feats, visual_attention_mask, output_attentions=output_attentions)
|
822 |
+
visual_feats = v_outputs[0]
|
823 |
+
vision_hidden_states = vision_hidden_states + (visual_feats,)
|
824 |
+
if vision_attentions is not None:
|
825 |
+
vision_attentions = vision_attentions + (v_outputs[1],)
|
826 |
+
|
827 |
+
# Run cross-modality layers
|
828 |
+
for layer_module in self.x_layers:
|
829 |
+
x_outputs = layer_module(
|
830 |
+
lang_feats,
|
831 |
+
lang_attention_mask,
|
832 |
+
visual_feats,
|
833 |
+
visual_attention_mask,
|
834 |
+
output_attentions=output_attentions,
|
835 |
+
)
|
836 |
+
lang_feats, visual_feats = x_outputs[:2]
|
837 |
+
vision_hidden_states = vision_hidden_states + (visual_feats,)
|
838 |
+
language_hidden_states = language_hidden_states + (lang_feats,)
|
839 |
+
if cross_encoder_attentions is not None:
|
840 |
+
cross_encoder_attentions = cross_encoder_attentions + (x_outputs[2],)
|
841 |
+
visual_encoder_outputs = (
|
842 |
+
vision_hidden_states,
|
843 |
+
vision_attentions if output_attentions else None,
|
844 |
+
)
|
845 |
+
lang_encoder_outputs = (
|
846 |
+
language_hidden_states,
|
847 |
+
language_attentions if output_attentions else None,
|
848 |
+
)
|
849 |
+
return (
|
850 |
+
visual_encoder_outputs,
|
851 |
+
lang_encoder_outputs,
|
852 |
+
cross_encoder_attentions if output_attentions else None,
|
853 |
+
)
|
854 |
+
|
855 |
+
def relprop(self, cam, **kwargs):
|
856 |
+
cam_lang, cam_vis = cam
|
857 |
+
for layer_module in reversed(self.x_layers):
|
858 |
+
cam_lang, cam_vis = layer_module.relprop((cam_lang, cam_vis), **kwargs)
|
859 |
+
|
860 |
+
for layer_module in reversed(self.r_layers):
|
861 |
+
cam_vis = layer_module.relprop(cam_vis, **kwargs)
|
862 |
+
|
863 |
+
for layer_module in reversed(self.layer):
|
864 |
+
cam_lang = layer_module.relprop(cam_lang, **kwargs)
|
865 |
+
return cam_lang, cam_vis
|
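# Encoder layout recap (comment added for clarity): the language stream runs through
# self.layer (num_l_layers), the visual stream through self.r_layers (num_r_layers),
# and only then do the x_layers mix the two modalities. relprop() above reverses that
# order: cross-modality layers first, then the visual-only and language-only stacks.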
866 |
+
|
867 |
+
|
868 |
+
class LxmertPooler(nn.Module):
|
869 |
+
def __init__(self, config):
|
870 |
+
super(LxmertPooler, self).__init__()
|
871 |
+
self.dense = Linear(config.hidden_size, config.hidden_size)
|
872 |
+
self.activation = Tanh()
|
873 |
+
|
874 |
+
self.pool = IndexSelect()
|
875 |
+
|
876 |
+
def forward(self, hidden_states):
|
877 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
878 |
+
# to the first token.
|
879 |
+
# first_token_tensor = hidden_states[:, 0]
|
880 |
+
first_token_tensor = self.pool(hidden_states, 1, torch.tensor(0, device=hidden_states.device))
|
881 |
+
first_token_tensor = first_token_tensor.squeeze(1)
|
882 |
+
pooled_output = self.dense(first_token_tensor)
|
883 |
+
pooled_output = self.activation(pooled_output)
|
884 |
+
return pooled_output
|
885 |
+
|
886 |
+
def relprop(self, cam, **kwargs):
|
887 |
+
cam = self.activation.relprop(cam, **kwargs)
|
888 |
+
cam = self.dense.relprop(cam, **kwargs)
|
889 |
+
cam = cam.unsqueeze(1)
|
890 |
+
cam = self.pool.relprop(cam, **kwargs)
|
891 |
+
|
892 |
+
return cam
|
893 |
+
|
894 |
+
|
895 |
+
class LxmertPredictionHeadTransform(nn.Module):
|
896 |
+
def __init__(self, config):
|
897 |
+
super(LxmertPredictionHeadTransform, self).__init__()
|
898 |
+
self.dense = Linear(config.hidden_size, config.hidden_size)
|
899 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]()
|
900 |
+
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
|
901 |
+
|
902 |
+
def forward(self, hidden_states):
|
903 |
+
hidden_states = self.dense(hidden_states)
|
904 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
905 |
+
hidden_states = self.LayerNorm(hidden_states)
|
906 |
+
return hidden_states
|
907 |
+
|
908 |
+
def relprop(self, cam, **kwargs):
|
909 |
+
cam = self.LayerNorm.relprop(cam, **kwargs)
|
910 |
+
cam = self.transform_act_fn.relprop(cam, **kwargs)
|
911 |
+
cam = self.dense.relprop(cam, **kwargs)
|
912 |
+
return cam
|
913 |
+
|
914 |
+
|
915 |
+
class LxmertLMPredictionHead(nn.Module):
|
916 |
+
def __init__(self, config, lxmert_model_embedding_weights):
|
917 |
+
super(LxmertLMPredictionHead, self).__init__()
|
918 |
+
self.transform = LxmertPredictionHeadTransform(config)
|
919 |
+
|
920 |
+
# The output weights are the same as the input embeddings, but there is
|
921 |
+
# an output-only bias for each token.
|
922 |
+
self.decoder = Linear(
|
923 |
+
lxmert_model_embedding_weights.size(1),
|
924 |
+
lxmert_model_embedding_weights.size(0),
|
925 |
+
bias=False,
|
926 |
+
)
|
927 |
+
self.decoder.weight = lxmert_model_embedding_weights
|
928 |
+
self.bias = nn.Parameter(torch.zeros(lxmert_model_embedding_weights.size(0)))
|
929 |
+
|
930 |
+
def forward(self, hidden_states):
|
931 |
+
hidden_states = self.transform(hidden_states)
|
932 |
+
hidden_states = self.decoder(hidden_states) + self.bias
|
933 |
+
return hidden_states
|
934 |
+
|
935 |
+
def relprop(self, cam, **kwargs):
|
936 |
+
cam = self.decoder.relprop(cam, **kwargs)
|
937 |
+
cam = self.transform.relprop(cam, **kwargs)
|
938 |
+
return cam
|
939 |
+
|
940 |
+
|
941 |
+
class LxmertVisualAnswerHead(nn.Module):
|
942 |
+
def __init__(self, config, num_labels):
|
943 |
+
super().__init__()
|
944 |
+
hid_dim = config.hidden_size
|
945 |
+
self.logit_fc = nn.Sequential(
|
946 |
+
Linear(hid_dim, hid_dim * 2),
|
947 |
+
GELU(),
|
948 |
+
LayerNorm(hid_dim * 2, eps=1e-12),
|
949 |
+
Linear(hid_dim * 2, num_labels),
|
950 |
+
)
|
951 |
+
|
952 |
+
def forward(self, hidden_states):
|
953 |
+
return self.logit_fc(hidden_states)
|
954 |
+
|
955 |
+
def relprop(self, cam, **kwargs):
|
956 |
+
for m in reversed(self.logit_fc._modules.values()):
|
957 |
+
cam = m.relprop(cam, **kwargs)
|
958 |
+
return cam
|
959 |
+
|
960 |
+
|
961 |
+
class LxmertVisualObjHead(nn.Module):
|
962 |
+
def __init__(self, config):
|
963 |
+
super().__init__()
|
964 |
+
self.transform = LxmertPredictionHeadTransform(config)
|
965 |
+
# Decide the use of visual losses
|
966 |
+
visual_losses = {}
|
967 |
+
if config.visual_obj_loss:
|
968 |
+
visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels}
|
969 |
+
if config.visual_attr_loss:
|
970 |
+
visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels}
|
971 |
+
if config.visual_feat_loss:
|
972 |
+
visual_losses["feat"] = {
|
973 |
+
"shape": (-1, config.visual_feat_dim),
|
974 |
+
"num": config.visual_feat_dim,
|
975 |
+
}
|
976 |
+
self.visual_losses = visual_losses
|
977 |
+
|
978 |
+
# The output weights are the same as the input embeddings, but there is
|
979 |
+
# an output-only bias for each token.
|
980 |
+
self.decoder_dict = nn.ModuleDict(
|
981 |
+
{key: nn.Linear(config.hidden_size, self.visual_losses[key]["num"]) for key in self.visual_losses}
|
982 |
+
)
|
983 |
+
|
984 |
+
def forward(self, hidden_states):
|
985 |
+
hidden_states = self.transform(hidden_states)
|
986 |
+
output = {}
|
987 |
+
for key in self.visual_losses:
|
988 |
+
output[key] = self.decoder_dict[key](hidden_states)
|
989 |
+
return output
|
990 |
+
|
991 |
+
def relprop(self, cam, **kwargs):
|
992 |
+
return self.transform.relprop(cam, **kwargs)
|
993 |
+
|
994 |
+
|
995 |
+
class LxmertPreTrainingHeads(nn.Module):
|
996 |
+
def __init__(self, config, lxmert_model_embedding_weights):
|
997 |
+
super(LxmertPreTrainingHeads, self).__init__()
|
998 |
+
self.predictions = LxmertLMPredictionHead(config, lxmert_model_embedding_weights)
|
999 |
+
self.seq_relationship = Linear(config.hidden_size, 2)
|
1000 |
+
|
1001 |
+
def forward(self, sequence_output, pooled_output):
|
1002 |
+
prediction_scores = self.predictions(sequence_output)
|
1003 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
1004 |
+
return prediction_scores, seq_relationship_score
|
1005 |
+
|
1006 |
+
def relprop(self, cam, **kwargs):
|
1007 |
+
cam_seq, cam_pooled = cam
|
1008 |
+
cam_pooled = self.seq_relationship.relprop(cam_pooled, **kwargs)
|
1009 |
+
cam_seq = self.predictions.relprop(cam_seq, **kwargs)
|
1010 |
+
return cam_seq, cam_pooled
|
1011 |
+
|
1012 |
+
|
1013 |
+
class LxmertPreTrainedModel(PreTrainedModel):
|
1014 |
+
"""
|
1015 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
1016 |
+
models.
|
1017 |
+
"""
|
1018 |
+
|
1019 |
+
config_class = LxmertConfig
|
1020 |
+
load_tf_weights = load_tf_weights_in_lxmert
|
1021 |
+
base_model_prefix = "lxmert"
|
1022 |
+
|
1023 |
+
def _init_weights(self, module):
|
1024 |
+
""" Initialize the weights """
|
1025 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
1026 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
1027 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
1028 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
1029 |
+
elif isinstance(module, nn.LayerNorm):
|
1030 |
+
module.bias.data.zero_()
|
1031 |
+
module.weight.data.fill_(1.0)
|
1032 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
1033 |
+
module.bias.data.zero_()
|
1034 |
+
|
1035 |
+
|
1036 |
+
LXMERT_START_DOCSTRING = r"""
|
1037 |
+
|
1038 |
+
The LXMERT model was proposed in `LXMERT: Learning Cross-Modality Encoder Representations from Transformers
|
1039 |
+
<https://arxiv.org/abs/1908.07490>`__ by Hao Tan and Mohit Bansal. It is a vision-and-language transformer model,
|
1040 |
+
pretrained on a variety of multi-modal datasets comprising GQA, VQA v2.0, MS COCO captions, and Visual Genome,
|
1041 |
+
using a combination of masked language modeling, region-of-interest feature regression, cross-entropy loss for
|
1042 |
+
question answering, attribute prediction, and object tag prediction.
|
1043 |
+
|
1044 |
+
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
1045 |
+
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
|
1046 |
+
pruning heads etc.)
|
1047 |
+
|
1048 |
+
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
1049 |
+
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
1050 |
+
general usage and behavior.
|
1051 |
+
|
1052 |
+
Parameters:
|
1053 |
+
config (:class:`~transformers.LxmertConfig`): Model configuration class with all the parameters of the model.
|
1054 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
1055 |
+
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
1056 |
+
weights.
|
1057 |
+
"""
|
1058 |
+
|
1059 |
+
LXMERT_INPUTS_DOCSTRING = r"""
|
1060 |
+
|
1061 |
+
Args:
|
1062 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
|
1063 |
+
Indices of input sequence tokens in the vocabulary.
|
1064 |
+
|
1065 |
+
Indices can be obtained using :class:`~transformers.LxmertTokenizer`. See
|
1066 |
+
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
1067 |
+
details.
|
1068 |
+
|
1069 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
1070 |
+
visual_feats: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_visual_features, visual_feat_dim)`):
|
1071 |
+
This input represents visual features: ROI-pooled object features extracted from bounding boxes by a
|
1072 |
+
Faster R-CNN model.
|
1073 |
+
|
1074 |
+
These are currently not provided by the transformers library.
|
1075 |
+
visual_pos: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_visual_features, visual_pos_dim)`):
|
1076 |
+
This input represents spatial features corresponding (by index) to the visual features above. The
|
1077 |
+
pre-trained LXMERT model expects these spatial features to be bounding boxes normalized to a scale of 0 to
|
1078 |
+
1.
|
1079 |
+
|
1080 |
+
These are currently not provided by the transformers library.
|
1081 |
+
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
1082 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
1083 |
+
|
1084 |
+
- 1 for tokens that are **not masked**,
|
1085 |
+
- 0 for tokens that are **masked**.
|
1086 |
+
|
1087 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
1088 |
+
visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
1089 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
1090 |
+
|
1091 |
+
- 1 for tokens that are **not masked**,
|
1092 |
+
- 0 for tokens that are **masked**.
|
1093 |
+
|
1094 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
1095 |
+
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
1096 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
1097 |
+
1]``:
|
1098 |
+
|
1099 |
+
- 0 corresponds to a `sentence A` token,
|
1100 |
+
- 1 corresponds to a `sentence B` token.
|
1101 |
+
|
1102 |
+
`What are token type IDs? <../glossary.html#token-type-ids>`__
|
1103 |
+
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
1104 |
+
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
1105 |
+
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
1106 |
+
vectors than the model's internal embedding lookup matrix.
|
1107 |
+
output_attentions (:obj:`bool`, `optional`):
|
1108 |
+
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
1109 |
+
tensors for more detail.
|
1110 |
+
output_hidden_states (:obj:`bool`, `optional`):
|
1111 |
+
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
1112 |
+
more detail.
|
1113 |
+
return_dict (:obj:`bool`, `optional`):
|
1114 |
+
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
1115 |
+
"""
|
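# Illustrative call matching the inputs documented above (added sketch; the tensor sizes
# follow the default LxmertConfig and the usual 36 Faster R-CNN regions, which are
# assumptions for this example only):
#
#     config = LxmertConfig()                                   # visual_feat_dim=2048, visual_pos_dim=4
#     model = LxmertModel(config)                               # defined just below
#     input_ids = torch.randint(0, config.vocab_size, (1, 16))
#     visual_feats = torch.randn(1, 36, config.visual_feat_dim)
#     visual_pos = torch.rand(1, 36, config.visual_pos_dim)     # normalized boxes in [0, 1]
#     outputs = model(input_ids, visual_feats=visual_feats, visual_pos=visual_pos, return_dict=True)
#     outputs.pooled_output.shape                               # (1, config.hidden_size)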
1116 |
+
|
1117 |
+
|
1118 |
+
@add_start_docstrings(
|
1119 |
+
"The bare Lxmert Model transformer outputting raw hidden-states without any specific head on top.",
|
1120 |
+
LXMERT_START_DOCSTRING,
|
1121 |
+
)
|
1122 |
+
class LxmertModel(LxmertPreTrainedModel):
|
1123 |
+
def __init__(self, config):
|
1124 |
+
super().__init__(config)
|
1125 |
+
self.embeddings = LxmertEmbeddings(config)
|
1126 |
+
self.encoder = LxmertEncoder(config)
|
1127 |
+
self.pooler = LxmertPooler(config)
|
1128 |
+
self.init_weights()
|
1129 |
+
|
1130 |
+
def get_input_embeddings(self):
|
1131 |
+
return self.embeddings.word_embeddings
|
1132 |
+
|
1133 |
+
def set_input_embeddings(self, new_embeddings):
|
1134 |
+
self.embeddings.word_embeddings = new_embeddings
|
1135 |
+
|
1136 |
+
@add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1137 |
+
@add_code_sample_docstrings(
|
1138 |
+
tokenizer_class=_TOKENIZER_FOR_DOC,
|
1139 |
+
checkpoint="unc-nlp/lxmert-base-uncased",
|
1140 |
+
output_type=LxmertModelOutput,
|
1141 |
+
config_class=_CONFIG_FOR_DOC,
|
1142 |
+
)
|
1143 |
+
def forward(
|
1144 |
+
self,
|
1145 |
+
input_ids=None,
|
1146 |
+
visual_feats=None,
|
1147 |
+
visual_pos=None,
|
1148 |
+
attention_mask=None,
|
1149 |
+
visual_attention_mask=None,
|
1150 |
+
token_type_ids=None,
|
1151 |
+
inputs_embeds=None,
|
1152 |
+
output_attentions=None,
|
1153 |
+
output_hidden_states=None,
|
1154 |
+
return_dict=None,
|
1155 |
+
):
|
1156 |
+
|
1157 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1158 |
+
output_hidden_states = (
|
1159 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1160 |
+
)
|
1161 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1162 |
+
|
1163 |
+
if input_ids is not None and inputs_embeds is not None:
|
1164 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
1165 |
+
elif input_ids is not None:
|
1166 |
+
input_shape = input_ids.size()
|
1167 |
+
elif inputs_embeds is not None:
|
1168 |
+
input_shape = inputs_embeds.size()[:-1]
|
1169 |
+
else:
|
1170 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
1171 |
+
|
1172 |
+
assert visual_feats is not None, "`visual_feats` cannot be `None`"
|
1173 |
+
assert visual_pos is not None, "`visual_pos` cannot be `None`"
|
1174 |
+
|
1175 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
1176 |
+
|
1177 |
+
if attention_mask is None:
|
1178 |
+
attention_mask = torch.ones(input_shape, device=device)
|
1179 |
+
if token_type_ids is None:
|
1180 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
1181 |
+
|
1182 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
1183 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
1184 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
1185 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
1186 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
1187 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
1188 |
+
|
1189 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
1190 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
1191 |
+
# positions we want to attend and -10000.0 for masked positions.
|
1192 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
1193 |
+
# effectively the same as removing these entirely.
|
1194 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
|
1195 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
1196 |
+
|
1197 |
+
# Process the visual attention mask
|
1198 |
+
if visual_attention_mask is not None:
|
1199 |
+
extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
|
1200 |
+
extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
|
1201 |
+
extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
|
1202 |
+
else:
|
1203 |
+
extended_visual_attention_mask = None
|
1204 |
+
|
1205 |
+
# Positional Word Embeddings
|
1206 |
+
embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds)
|
1207 |
+
|
1208 |
+
# Run Lxmert encoder
|
1209 |
+
encoder_outputs = self.encoder(
|
1210 |
+
embedding_output,
|
1211 |
+
extended_attention_mask,
|
1212 |
+
visual_feats=visual_feats,
|
1213 |
+
visual_pos=visual_pos,
|
1214 |
+
visual_attention_mask=extended_visual_attention_mask,
|
1215 |
+
output_attentions=output_attentions,
|
1216 |
+
)
|
1217 |
+
|
1218 |
+
visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
|
1219 |
+
vision_hidden_states = visual_encoder_outputs[0]
|
1220 |
+
language_hidden_states = lang_encoder_outputs[0]
|
1221 |
+
|
1222 |
+
all_attentions = ()
|
1223 |
+
if output_attentions:
|
1224 |
+
language_attentions = lang_encoder_outputs[1]
|
1225 |
+
vision_attentions = visual_encoder_outputs[1]
|
1226 |
+
cross_encoder_attentions = encoder_outputs[2]
|
1227 |
+
all_attentions = (
|
1228 |
+
language_attentions,
|
1229 |
+
vision_attentions,
|
1230 |
+
cross_encoder_attentions,
|
1231 |
+
)
|
1232 |
+
|
1233 |
+
hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else ()
|
1234 |
+
|
1235 |
+
visual_output = vision_hidden_states[-1]
|
1236 |
+
lang_output = language_hidden_states[-1]
|
1237 |
+
pooled_output = self.pooler(lang_output)
|
1238 |
+
|
1239 |
+
if not return_dict:
|
1240 |
+
return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions
|
1241 |
+
|
1242 |
+
return LxmertModelOutput(
|
1243 |
+
pooled_output=pooled_output,
|
1244 |
+
language_output=lang_output,
|
1245 |
+
vision_output=visual_output,
|
1246 |
+
language_hidden_states=language_hidden_states if output_hidden_states else None,
|
1247 |
+
vision_hidden_states=vision_hidden_states if output_hidden_states else None,
|
1248 |
+
language_attentions=language_attentions if output_attentions else None,
|
1249 |
+
vision_attentions=vision_attentions if output_attentions else None,
|
1250 |
+
cross_encoder_attentions=cross_encoder_attentions if output_attentions else None,
|
1251 |
+
)
|
1252 |
+
|
1253 |
+
def relprop(self, cam, **kwargs):
|
1254 |
+
cam_lang, cam_vis = cam
|
1255 |
+
cam_lang = self.pooler.relprop(cam_lang, **kwargs)
|
1256 |
+
cam_lang, cam_vis = self.encoder.relprop((cam_lang, cam_vis), **kwargs)
|
1257 |
+
return cam_lang, cam_vis
|
1258 |
+
|
1259 |
+
|
1260 |
+
|
1261 |
+
@add_start_docstrings(
|
1262 |
+
"""Lxmert Model with a specified pretraining head on top. """,
|
1263 |
+
LXMERT_START_DOCSTRING,
|
1264 |
+
)
|
1265 |
+
class LxmertForPreTraining(LxmertPreTrainedModel):
|
1266 |
+
def __init__(self, config):
|
1267 |
+
super().__init__(config)
|
1268 |
+
# Configuration
|
1269 |
+
self.config = config
|
1270 |
+
self.num_qa_labels = config.num_qa_labels
|
1271 |
+
self.visual_loss_normalizer = config.visual_loss_normalizer
|
1272 |
+
|
1273 |
+
# Use of pretraining tasks
|
1274 |
+
self.task_mask_lm = config.task_mask_lm
|
1275 |
+
self.task_obj_predict = config.task_obj_predict
|
1276 |
+
self.task_matched = config.task_matched
|
1277 |
+
self.task_qa = config.task_qa
|
1278 |
+
|
1279 |
+
# Lxmert backbone
|
1280 |
+
self.lxmert = LxmertModel(config)
|
1281 |
+
|
1282 |
+
# Pre-training heads
|
1283 |
+
self.cls = LxmertPreTrainingHeads(config, self.lxmert.embeddings.word_embeddings.weight)
|
1284 |
+
if self.task_obj_predict:
|
1285 |
+
self.obj_predict_head = LxmertVisualObjHead(config)
|
1286 |
+
if self.task_qa:
|
1287 |
+
self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
|
1288 |
+
|
1289 |
+
# Weight initialization
|
1290 |
+
self.init_weights()
|
1291 |
+
|
1292 |
+
# Loss functions
|
1293 |
+
self.loss_fcts = {
|
1294 |
+
"l2": SmoothL1Loss(reduction="none"),
|
1295 |
+
"visual_ce": CrossEntropyLoss(reduction="none"),
|
1296 |
+
"ce": CrossEntropyLoss(),
|
1297 |
+
}
|
1298 |
+
|
1299 |
+
visual_losses = {}
|
1300 |
+
if config.visual_obj_loss:
|
1301 |
+
visual_losses["obj"] = {
|
1302 |
+
"shape": (-1,),
|
1303 |
+
"num": config.num_object_labels,
|
1304 |
+
"loss": "visual_ce",
|
1305 |
+
}
|
1306 |
+
if config.visual_attr_loss:
|
1307 |
+
visual_losses["attr"] = {
|
1308 |
+
"shape": (-1,),
|
1309 |
+
"num": config.num_attr_labels,
|
1310 |
+
"loss": "visual_ce",
|
1311 |
+
}
|
1312 |
+
if config.visual_feat_loss:
|
1313 |
+
visual_losses["feat"] = {
|
1314 |
+
"shape": (-1, config.visual_feat_dim),
|
1315 |
+
"num": config.visual_feat_dim,
|
1316 |
+
"loss": "l2",
|
1317 |
+
}
|
1318 |
+
self.visual_losses = visual_losses
|
1319 |
+
|
1320 |
+
def resize_num_qa_labels(self, num_labels):
|
1321 |
+
"""
|
1322 |
+
Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
|
1323 |
+
will add newly initialized weights. Reducing the size will remove weights from the end.
|
1324 |
+
|
1325 |
+
Args:
|
1326 |
+
num_labels (:obj:`int`, `optional`):
|
1327 |
+
New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
|
1328 |
+
weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`,
|
1329 |
+
just returns a pointer to the qa labels :obj:`torch.nn.Linear`` module of the model without doing
|
1330 |
+
anything.
|
1331 |
+
|
1332 |
+
Return:
|
1333 |
+
:obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
|
1334 |
+
"""
|
1335 |
+
|
1336 |
+
cur_qa_logit_layer = self.get_qa_logit_layer()
|
1337 |
+
if num_labels is None or cur_qa_logit_layer is None:
|
1338 |
+
return
|
1339 |
+
new_qa_logit_layer = self._resize_qa_labels(num_labels)
|
1340 |
+
self.config.num_qa_labels = num_labels
|
1341 |
+
self.num_qa_labels = num_labels
|
1342 |
+
|
1343 |
+
return new_qa_logit_layer
|
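# Example (added comment; 3129 is the answer-vocabulary size commonly used for VQA v2
# fine-tuning and is an assumption here, not something this file defines):
#
#     model = LxmertForPreTraining(LxmertConfig(num_qa_labels=2))
#     model.resize_num_qa_labels(3129)   # grows the answer head, copying the existing rows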
1344 |
+
|
1345 |
+
def _resize_qa_labels(self, num_labels):
|
1346 |
+
cur_qa_logit_layer = self.get_qa_logit_layer()
|
1347 |
+
new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
|
1348 |
+
self._set_qa_logit_layer(new_qa_logit_layer)
|
1349 |
+
return self.get_qa_logit_layer()
|
1350 |
+
|
1351 |
+
def get_qa_logit_layer(self) -> nn.Module:
|
1352 |
+
"""
|
1353 |
+
Returns the linear layer that produces question answering logits.
|
1354 |
+
|
1355 |
+
Returns:
|
1356 |
+
:obj:`nn.Module`: The linear layer that produces the question answering logits, or :obj:`None` if the
|
1357 |
+
model does not have a visual answering head.
|
1358 |
+
"""
|
1359 |
+
if hasattr(self, "answer_head"):
|
1360 |
+
return self.answer_head.logit_fc[-1]
|
1361 |
+
|
1362 |
+
def _set_qa_logit_layer(self, qa_logit_layer):
|
1363 |
+
self.answer_head.logit_fc[-1] = qa_logit_layer
|
1364 |
+
|
1365 |
+
def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
|
1366 |
+
|
1367 |
+
if num_labels is None:
|
1368 |
+
return cur_qa_logit_layer
|
1369 |
+
|
1370 |
+
cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
|
1371 |
+
if cur_qa_labels == num_labels:
|
1372 |
+
return cur_qa_logit_layer
|
1373 |
+
|
1374 |
+
# Build new linear output
|
1375 |
+
if getattr(cur_qa_logit_layer, "bias", None) is not None:
|
1376 |
+
new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
|
1377 |
+
else:
|
1378 |
+
new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
|
1379 |
+
|
1380 |
+
new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
|
1381 |
+
|
1382 |
+
# initialize all new labels
|
1383 |
+
self._init_weights(new_qa_logit_layer)
|
1384 |
+
|
1385 |
+
# Copy labels from the previous weights
|
1386 |
+
num_labels_to_copy = min(cur_qa_labels, num_labels)
|
1387 |
+
new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
|
1388 |
+
if getattr(cur_qa_logit_layer, "bias", None) is not None:
|
1389 |
+
new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
|
1390 |
+
|
1391 |
+
return new_qa_logit_layer
|
1392 |
+
|
1393 |
+
@add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1394 |
+
@replace_return_docstrings(output_type=LxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
1395 |
+
def forward(
|
1396 |
+
self,
|
1397 |
+
input_ids=None,
|
1398 |
+
visual_feats=None,
|
1399 |
+
visual_pos=None,
|
1400 |
+
attention_mask=None,
|
1401 |
+
visual_attention_mask=None,
|
1402 |
+
token_type_ids=None,
|
1403 |
+
inputs_embeds=None,
|
1404 |
+
labels=None,
|
1405 |
+
obj_labels=None,
|
1406 |
+
matched_label=None,
|
1407 |
+
ans=None,
|
1408 |
+
output_attentions=None,
|
1409 |
+
output_hidden_states=None,
|
1410 |
+
return_dict=None,
|
1411 |
+
**kwargs,
|
1412 |
+
):
|
1413 |
+
r"""
|
1414 |
+
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
|
1415 |
+
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
1416 |
+
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
1417 |
+
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
1418 |
+
obj_labels: (``Dict[Str: Tuple[Torch.FloatTensor, Torch.FloatTensor]]``, `optional`):
|
1419 |
+
each key is named after each one of the visual losses and each element of the tuple is of the shape
|
1420 |
+
``(batch_size, num_features)`` and ``(batch_size, num_features, visual_feature_dim)`` for the label id
|
1421 |
+
and the label score respectively
|
1422 |
+
matched_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
|
1423 |
+
Labels for computing whether or not the text input matches the image (classification) loss. Input
|
1424 |
+
should be a sequence pair (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
|
1425 |
+
|
1426 |
+
- 0 indicates that the sentence does not match the image,
|
1427 |
+
- 1 indicates that the sentence does match the image.
|
1428 |
+
ans: (``Torch.Tensor`` of shape ``(batch_size)``, `optional`):
|
1429 |
+
a one-hot representation of the correct answer
|
1430 |
+
|
1431 |
+
Returns:
|
1432 |
+
"""
|
1433 |
+
|
1434 |
+
if "masked_lm_labels" in kwargs:
|
1435 |
+
warnings.warn(
|
1436 |
+
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
1437 |
+
FutureWarning,
|
1438 |
+
)
|
1439 |
+
labels = kwargs.pop("masked_lm_labels")
|
1440 |
+
|
1441 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1442 |
+
|
1443 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
1444 |
+
lxmert_output = self.lxmert(
|
1445 |
+
input_ids=input_ids,
|
1446 |
+
visual_feats=visual_feats,
|
1447 |
+
visual_pos=visual_pos,
|
1448 |
+
token_type_ids=token_type_ids,
|
1449 |
+
attention_mask=attention_mask,
|
1450 |
+
visual_attention_mask=visual_attention_mask,
|
1451 |
+
inputs_embeds=inputs_embeds,
|
1452 |
+
output_hidden_states=output_hidden_states,
|
1453 |
+
output_attentions=output_attentions,
|
1454 |
+
return_dict=return_dict,
|
1455 |
+
)
|
1456 |
+
|
1457 |
+
lang_output, visual_output, pooled_output = (
|
1458 |
+
lxmert_output[0],
|
1459 |
+
lxmert_output[1],
|
1460 |
+
lxmert_output[2],
|
1461 |
+
)
|
1462 |
+
lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output)
|
1463 |
+
if self.task_qa:
|
1464 |
+
answer_score = self.answer_head(pooled_output)
|
1465 |
+
else:
|
1466 |
+
answer_score = pooled_output[0][0]
|
1467 |
+
|
1468 |
+
total_loss = (
|
1469 |
+
None
|
1470 |
+
if (labels is None and matched_label is None and obj_labels is None and ans is None)
|
1471 |
+
else torch.tensor(0.0, device=device)
|
1472 |
+
)
|
1473 |
+
if labels is not None and self.task_mask_lm:
|
1474 |
+
masked_lm_loss = self.loss_fcts["ce"](
|
1475 |
+
lang_prediction_scores.view(-1, self.config.vocab_size),
|
1476 |
+
labels.view(-1),
|
1477 |
+
)
|
1478 |
+
total_loss += masked_lm_loss
|
1479 |
+
if matched_label is not None and self.task_matched:
|
1480 |
+
matched_loss = self.loss_fcts["ce"](cross_relationship_score.view(-1, 2), matched_label.view(-1))
|
1481 |
+
total_loss += matched_loss
|
1482 |
+
if obj_labels is not None and self.task_obj_predict:
|
1483 |
+
total_visual_loss = torch.tensor(0.0, device=input_ids.device)
|
1484 |
+
visual_prediction_scores_dict = self.obj_predict_head(visual_output)
|
1485 |
+
for key, key_info in self.visual_losses.items():
|
1486 |
+
label, mask_conf = obj_labels[key]
|
1487 |
+
output_dim = key_info["num"]
|
1488 |
+
loss_fct_name = key_info["loss"]
|
1489 |
+
label_shape = key_info["shape"]
|
1490 |
+
weight = self.visual_loss_normalizer
|
1491 |
+
visual_loss_fct = self.loss_fcts[loss_fct_name]
|
1492 |
+
visual_prediction_scores = visual_prediction_scores_dict[key]
|
1493 |
+
visual_loss = visual_loss_fct(
|
1494 |
+
visual_prediction_scores.view(-1, output_dim),
|
1495 |
+
label.view(*label_shape),
|
1496 |
+
)
|
1497 |
+
if visual_loss.dim() > 1: # Regression Losses
|
1498 |
+
visual_loss = visual_loss.mean(1)
|
1499 |
+
visual_loss = (visual_loss * mask_conf.view(-1)).mean() * weight
|
1500 |
+
total_visual_loss += visual_loss
|
1501 |
+
total_loss += total_visual_loss
|
1502 |
+
if ans is not None and self.task_qa:
|
1503 |
+
answer_loss = self.loss_fcts["ce"](answer_score.view(-1, self.num_qa_labels), ans.view(-1))
|
1504 |
+
total_loss += answer_loss
|
1505 |
+
|
1506 |
+
if not return_dict:
|
1507 |
+
output = (
|
1508 |
+
lang_prediction_scores,
|
1509 |
+
cross_relationship_score,
|
1510 |
+
answer_score,
|
1511 |
+
) + lxmert_output[3:]
|
1512 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
1513 |
+
|
1514 |
+
return LxmertForPreTrainingOutput(
|
1515 |
+
loss=total_loss,
|
1516 |
+
prediction_logits=lang_prediction_scores,
|
1517 |
+
cross_relationship_score=cross_relationship_score,
|
1518 |
+
question_answering_score=answer_score,
|
1519 |
+
language_hidden_states=lxmert_output.language_hidden_states,
|
1520 |
+
vision_hidden_states=lxmert_output.vision_hidden_states,
|
1521 |
+
language_attentions=lxmert_output.language_attentions,
|
1522 |
+
vision_attentions=lxmert_output.vision_attentions,
|
1523 |
+
cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
|
1524 |
+
)
|
1525 |
+
|
1526 |
+
|
1527 |
+
|
1528 |
+
@add_start_docstrings(
|
1529 |
+
"""Lxmert Model with a visual-answering head on top for downstream QA tasks""",
|
1530 |
+
LXMERT_START_DOCSTRING,
|
1531 |
+
)
|
1532 |
+
class LxmertForQuestionAnswering(LxmertPreTrainedModel):
|
1533 |
+
def __init__(self, config):
|
1534 |
+
super().__init__(config)
|
1535 |
+
# Configuration
|
1536 |
+
self.config = config
|
1537 |
+
self.num_qa_labels = config.num_qa_labels
|
1538 |
+
self.visual_loss_normalizer = config.visual_loss_normalizer
|
1539 |
+
|
1540 |
+
# Lxmert backbone
|
1541 |
+
self.lxmert = LxmertModel(config)
|
1542 |
+
|
1543 |
+
self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
|
1544 |
+
|
1545 |
+
# Weight initialization
|
1546 |
+
self.init_weights()
|
1547 |
+
|
1548 |
+
# Loss function
|
1549 |
+
self.loss = CrossEntropyLoss()
|
1550 |
+
|
1551 |
+
def resize_num_qa_labels(self, num_labels):
|
1552 |
+
"""
|
1553 |
+
Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
|
1554 |
+
will add newly initialized weights. Reducing the size will remove weights from the end
|
1555 |
+
|
1556 |
+
Args:
|
1557 |
+
num_labels (:obj:`int`, `optional`):
|
1558 |
+
New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
|
1559 |
+
weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`,
|
1560 |
+
just returns a pointer to the qa labels :obj:`torch.nn.Linear` module of the model without doing
|
1561 |
+
anything.
|
1562 |
+
|
1563 |
+
Return:
|
1564 |
+
:obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
|
1565 |
+
"""
|
1566 |
+
|
1567 |
+
cur_qa_logit_layer = self.get_qa_logit_layer()
|
1568 |
+
if num_labels is None or cur_qa_logit_layer is None:
|
1569 |
+
return
|
1570 |
+
new_qa_logit_layer = self._resize_qa_labels(num_labels)
|
1571 |
+
self.config.num_qa_labels = num_labels
|
1572 |
+
self.num_qa_labels = num_labels
|
1573 |
+
|
1574 |
+
return new_qa_logit_layer
|
1575 |
+
|
1576 |
+
def _resize_qa_labels(self, num_labels):
|
1577 |
+
cur_qa_logit_layer = self.get_qa_logit_layer()
|
1578 |
+
new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
|
1579 |
+
self._set_qa_logit_layer(new_qa_logit_layer)
|
1580 |
+
return self.get_qa_logit_layer()
|
1581 |
+
|
1582 |
+
def get_qa_logit_layer(self) -> nn.Module:
|
1583 |
+
"""
|
1584 |
+
Returns the linear layer that produces question answering logits
|
1585 |
+
|
1586 |
+
Returns:
|
1587 |
+
:obj:`nn.Module`: A torch module mapping the question answering prediction hidden states. :obj:`None`: A
|
1588 |
+
NoneType object if Lxmert does not have the visual answering head.
|
1589 |
+
"""
|
1590 |
+
|
1591 |
+
if hasattr(self, "answer_head"):
|
1592 |
+
return self.answer_head.logit_fc[-1]
|
1593 |
+
|
1594 |
+
def _set_qa_logit_layer(self, qa_logit_layer):
|
1595 |
+
self.answer_head.logit_fc[-1] = qa_logit_layer
|
1596 |
+
|
1597 |
+
def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
|
1598 |
+
|
1599 |
+
if num_labels is None:
|
1600 |
+
return cur_qa_logit_layer
|
1601 |
+
|
1602 |
+
cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
|
1603 |
+
if cur_qa_labels == num_labels:
|
1604 |
+
return cur_qa_logit_layer
|
1605 |
+
|
1606 |
+
# Build new linear output
|
1607 |
+
if getattr(cur_qa_logit_layer, "bias", None) is not None:
|
1608 |
+
new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
|
1609 |
+
else:
|
1610 |
+
new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
|
1611 |
+
|
1612 |
+
new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
|
1613 |
+
|
1614 |
+
# initialize all new labels
|
1615 |
+
self._init_weights(new_qa_logit_layer)
|
1616 |
+
|
1617 |
+
# Copy labels from the previous weights
|
1618 |
+
num_labels_to_copy = min(cur_qa_labels, num_labels)
|
1619 |
+
new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
|
1620 |
+
if getattr(cur_qa_logit_layer, "bias", None) is not None:
|
1621 |
+
new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
|
1622 |
+
|
1623 |
+
return new_qa_logit_layer
|
1624 |
+
|
1625 |
+
@add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1626 |
+
@add_code_sample_docstrings(
|
1627 |
+
tokenizer_class=_TOKENIZER_FOR_DOC,
|
1628 |
+
checkpoint="unc-nlp/lxmert-base-uncased",
|
1629 |
+
output_type=LxmertForQuestionAnsweringOutput,
|
1630 |
+
config_class=_CONFIG_FOR_DOC,
|
1631 |
+
)
|
1632 |
+
def forward(
|
1633 |
+
self,
|
1634 |
+
input_ids=None,
|
1635 |
+
visual_feats=None,
|
1636 |
+
visual_pos=None,
|
1637 |
+
attention_mask=None,
|
1638 |
+
visual_attention_mask=None,
|
1639 |
+
token_type_ids=None,
|
1640 |
+
inputs_embeds=None,
|
1641 |
+
labels=None,
|
1642 |
+
output_attentions=None,
|
1643 |
+
output_hidden_states=None,
|
1644 |
+
return_dict=None,
|
1645 |
+
):
|
1646 |
+
r"""
|
1647 |
+
labels: (``Torch.Tensor`` of shape ``(batch_size)``, `optional`):
|
1648 |
+
A one-hot representation of the correct answer
|
1649 |
+
|
1650 |
+
Returns:
|
1651 |
+
"""
|
1652 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1653 |
+
|
1654 |
+
lxmert_output = self.lxmert(
|
1655 |
+
input_ids=input_ids,
|
1656 |
+
visual_feats=visual_feats,
|
1657 |
+
visual_pos=visual_pos,
|
1658 |
+
token_type_ids=token_type_ids,
|
1659 |
+
attention_mask=attention_mask,
|
1660 |
+
visual_attention_mask=visual_attention_mask,
|
1661 |
+
inputs_embeds=inputs_embeds,
|
1662 |
+
output_hidden_states=output_hidden_states,
|
1663 |
+
output_attentions=output_attentions,
|
1664 |
+
return_dict=return_dict,
|
1665 |
+
)
|
1666 |
+
|
1667 |
+
pooled_output = lxmert_output[2]
|
1668 |
+
answer_score = self.answer_head(pooled_output)
|
1669 |
+
loss = None
|
1670 |
+
if labels is not None:
|
1671 |
+
loss = self.loss(answer_score.view(-1, self.num_qa_labels), labels.view(-1))
|
1672 |
+
|
1673 |
+
if not return_dict:
|
1674 |
+
output = (answer_score,) + lxmert_output[3:]
|
1675 |
+
return (loss,) + output if loss is not None else output
|
1676 |
+
|
1677 |
+
self.vis_shape = lxmert_output.vision_output.shape
|
1678 |
+
|
1679 |
+
return LxmertForQuestionAnsweringOutput(
|
1680 |
+
loss=loss,
|
1681 |
+
question_answering_score=answer_score,
|
1682 |
+
language_hidden_states=lxmert_output.language_hidden_states,
|
1683 |
+
vision_hidden_states=lxmert_output.vision_hidden_states,
|
1684 |
+
language_attentions=lxmert_output.language_attentions,
|
1685 |
+
vision_attentions=lxmert_output.vision_attentions,
|
1686 |
+
cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
|
1687 |
+
)
|
1688 |
+
|
1689 |
+
def relprop(self, cam, **kwargs):
|
1690 |
+
cam_lang = self.answer_head.relprop(cam, **kwargs)
|
1691 |
+
cam_vis = torch.zeros(self.vis_shape).to(cam_lang.device)
|
1692 |
+
cam_lang, cam_vis = self.lxmert.relprop((cam_lang, cam_vis), **kwargs)
|
1693 |
+
return cam_lang, cam_vis
|
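A minimal usage sketch of the QA-head resizing helpers defined above (not part of this diff; the module path and the label count 3129 are illustrative assumptions, and the checkpoint name is the one referenced in the docstrings):
# Hypothetical sketch: load the pretrained checkpoint, then grow/shrink answer_head.logit_fc[-1].
from lxmert.src.lxmert_lrp import LxmertForQuestionAnswering  # assumed import path
model = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-base-uncased")
model.resize_num_qa_labels(3129)          # illustrative label count (e.g. a VQA answer vocabulary)
assert model.get_qa_logit_layer().out_features == 3129
assert model.config.num_qa_labels == 3129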
lxmert/src/lxrt/__init__.py
ADDED
File without changes
|
lxmert/src/lxrt/entry.py
ADDED
@@ -0,0 +1,156 @@
# coding=utf-8
# Copyright 2019 project LXRT.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
import torch.nn as nn

from ..lxrt.tokenization import BertTokenizer
from ..lxrt.modeling import LXRTFeatureExtraction as VisualBertForLXRFeature, VISUAL_CONFIG


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids


def convert_sents_to_features(sents, max_seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""

    features = []
    for (i, sent) in enumerate(sents):
        tokens_a = tokenizer.tokenize(sent.strip())

        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[:(max_seq_length - 2)]

        # Keep segment id which allows loading BERT-weights.
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        segment_ids = [0] * len(tokens)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids))
    return features


def set_visual_config(args):
    VISUAL_CONFIG.l_layers = args.llayers
    VISUAL_CONFIG.x_layers = args.xlayers
    VISUAL_CONFIG.r_layers = args.rlayers


class LXRTEncoder(nn.Module):
    def __init__(self, args, max_seq_length, mode='x'):
        super().__init__()
        self.max_seq_length = max_seq_length
        set_visual_config(args)

        # Using the bert tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        # Build LXRT Model
        self.model = VisualBertForLXRFeature.from_pretrained(
            "bert-base-uncased",
            mode=mode
        )

        if args.from_scratch:
            print("initializing all the weights")
            self.model.apply(self.model.init_bert_weights)

    def multi_gpu(self):
        self.model = nn.DataParallel(self.model)

    @property
    def dim(self):
        return 768

    def forward(self, sents, feats, visual_attention_mask=None):
        train_features = convert_sents_to_features(
            sents, self.max_seq_length, self.tokenizer)

        input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
        input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
        segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()

        output = self.model(input_ids, segment_ids, input_mask,
                            visual_feats=feats,
                            visual_attention_mask=visual_attention_mask)
        return output

    def save(self, path):
        torch.save(self.model.state_dict(),
                   os.path.join("%s_LXRT.pth" % path))

    def load(self, path):
        # Load state_dict from snapshot file
        print("Load lxmert pre-trained model from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
            else:
                new_state_dict[key] = value
        state_dict = new_state_dict

        # Print out the differences of pre-trained and model weights.
        load_keys = set(state_dict.keys())
        model_keys = set(self.model.state_dict().keys())
        print()
        print("Weights in loaded but not in model:")
        for key in sorted(load_keys.difference(model_keys)):
            print(key)
        print()
        print("Weights in model but not in loaded:")
        for key in sorted(model_keys.difference(load_keys)):
            print(key)
        print()

        # Load weights to model
        self.model.load_state_dict(state_dict, strict=False)
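A minimal sketch of how `convert_sents_to_features` is meant to be called (not part of this diff; the sentence, `max_seq_length=20`, and the absolute import paths are illustrative assumptions about the package layout):
from lxmert.src.lxrt.tokenization import BertTokenizer          # assumed import path
from lxmert.src.lxrt.entry import convert_sents_to_features     # assumed import path

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
features = convert_sents_to_features(["a dog on a couch"], 20, tokenizer)
# Each InputFeatures holds [CLS]/[SEP]-wrapped word-piece ids, a 1/0 attention mask,
# and all-zero segment ids, each zero-padded to max_seq_length (20 here).
print(len(features[0].input_ids))   # 20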
lxmert/src/lxrt/file_utils.py
ADDED
@@ -0,0 +1,247 @@
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
import json
import logging
import os
import shutil
import tempfile
from functools import wraps
from hashlib import sha256
import sys
from io import open

import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                                   Path.home() / '.pytorch_pretrained_bert'))
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                              os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename


def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag


def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url, temp_file):
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w', encoding="utf-8") as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path


def read_set_from_file(filename):
    '''
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    '''
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection


def get_file_extension(path, dot=True, lower=True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext
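A minimal sketch of the caching helpers above (not part of this diff; the URL is one of the archive URLs listed in modeling.py below, the ETag string is made up, and the import path is an assumption):
from lxmert.src.lxrt.file_utils import url_to_filename, cached_path   # assumed import path

url = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"
print(url_to_filename(url, etag='"abc123"'))  # sha256(url) hex digest + "." + sha256(etag) hex digest
# cached_path(url) downloads into PYTORCH_PRETRAINED_BERT_CACHE on first use and returns
# the local path; passing an existing local path returns it unchanged.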
lxmert/src/lxrt/modeling.py
ADDED
@@ -0,0 +1,1018 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2019 project LXRT.
|
3 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
4 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
"""PyTorch LXRT model."""
|
18 |
+
|
19 |
+
import copy
|
20 |
+
import json
|
21 |
+
import logging
|
22 |
+
import math
|
23 |
+
import os
|
24 |
+
import shutil
|
25 |
+
import tarfile
|
26 |
+
import tempfile
|
27 |
+
import sys
|
28 |
+
from io import open
|
29 |
+
|
30 |
+
import torch
|
31 |
+
from torch import nn
|
32 |
+
from torch.nn import CrossEntropyLoss, SmoothL1Loss
|
33 |
+
|
34 |
+
from .file_utils import cached_path
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
|
38 |
+
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
39 |
+
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
|
40 |
+
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
|
41 |
+
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
|
42 |
+
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
|
43 |
+
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
|
44 |
+
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
|
45 |
+
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
|
46 |
+
}
|
47 |
+
CONFIG_NAME = 'bert_config.json'
|
48 |
+
WEIGHTS_NAME = 'pytorch_model.bin'
|
49 |
+
TF_WEIGHTS_NAME = 'model.ckpt'
|
50 |
+
|
51 |
+
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
52 |
+
""" Load tf checkpoints in a pytorch model
|
53 |
+
"""
|
54 |
+
try:
|
55 |
+
import re
|
56 |
+
import numpy as np
|
57 |
+
import tensorflow as tf
|
58 |
+
except ImportError:
|
59 |
+
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
60 |
+
"https://www.tensorflow.org/install/ for installation instructions.")
|
61 |
+
raise
|
62 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
63 |
+
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
64 |
+
# Load weights from TF model
|
65 |
+
init_vars = tf.train.list_variables(tf_path)
|
66 |
+
names = []
|
67 |
+
arrays = []
|
68 |
+
for name, shape in init_vars:
|
69 |
+
print("Loading TF weight {} with shape {}".format(name, shape))
|
70 |
+
array = tf.train.load_variable(tf_path, name)
|
71 |
+
names.append(name)
|
72 |
+
arrays.append(array)
|
73 |
+
|
74 |
+
for name, array in zip(names, arrays):
|
75 |
+
name = name.split('/')
|
76 |
+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
|
77 |
+
# which are not required for using pretrained model
|
78 |
+
if any(n in ["adam_v", "adam_m"] for n in name):
|
79 |
+
print("Skipping {}".format("/".join(name)))
|
80 |
+
continue
|
81 |
+
pointer = model
|
82 |
+
for m_name in name:
|
83 |
+
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
84 |
+
l = re.split(r'_(\d+)', m_name)
|
85 |
+
else:
|
86 |
+
l = [m_name]
|
87 |
+
if l[0] == 'kernel' or l[0] == 'gamma':
|
88 |
+
pointer = getattr(pointer, 'weight')
|
89 |
+
elif l[0] == 'output_bias' or l[0] == 'beta':
|
90 |
+
pointer = getattr(pointer, 'bias')
|
91 |
+
elif l[0] == 'output_weights':
|
92 |
+
pointer = getattr(pointer, 'weight')
|
93 |
+
else:
|
94 |
+
pointer = getattr(pointer, l[0])
|
95 |
+
if len(l) >= 2:
|
96 |
+
num = int(l[1])
|
97 |
+
pointer = pointer[num]
|
98 |
+
if m_name[-11:] == '_embeddings':
|
99 |
+
pointer = getattr(pointer, 'weight')
|
100 |
+
elif m_name == 'kernel':
|
101 |
+
array = np.transpose(array)
|
102 |
+
try:
|
103 |
+
assert pointer.shape == array.shape
|
104 |
+
except AssertionError as e:
|
105 |
+
e.args += (pointer.shape, array.shape)
|
106 |
+
raise
|
107 |
+
print("Initialize PyTorch weight {}".format(name))
|
108 |
+
pointer.data = torch.from_numpy(array)
|
109 |
+
return model
|
110 |
+
|
111 |
+
|
112 |
+
def gelu(x):
|
113 |
+
"""Implementation of the gelu activation function.
|
114 |
+
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
115 |
+
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
116 |
+
Also see https://arxiv.org/abs/1606.08415
|
117 |
+
"""
|
118 |
+
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
119 |
+
|
120 |
+
|
121 |
+
class GeLU(nn.Module):
|
122 |
+
"""Implementation of the gelu activation function.
|
123 |
+
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
124 |
+
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
125 |
+
Also see https://arxiv.org/abs/1606.08415
|
126 |
+
"""
|
127 |
+
def __init__(self):
|
128 |
+
super().__init__()
|
129 |
+
|
130 |
+
def forward(self, x):
|
131 |
+
return gelu(x)
|
132 |
+
|
133 |
+
|
134 |
+
def swish(x):
|
135 |
+
return x * torch.sigmoid(x)
|
136 |
+
|
137 |
+
|
138 |
+
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
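# Illustrative check (an addition, not upstream code), assuming the torch and math imports
# at the top of this file: the erf-based gelu above and the tanh approximation quoted in the
# docstring agree closely over a typical activation range.
#     x = torch.linspace(-3.0, 3.0, steps=7)
#     exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
#     approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
#     print((exact - approx).abs().max())   # well under 1e-3 on this grid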
139 |
+
|
140 |
+
|
141 |
+
class VisualConfig(object):
|
142 |
+
VISUAL_LOSSES = ['obj', 'attr', 'feat']
|
143 |
+
def __init__(self,
|
144 |
+
l_layers=12,
|
145 |
+
x_layers=5,
|
146 |
+
r_layers=0):
|
147 |
+
self.l_layers = l_layers
|
148 |
+
self.x_layers = x_layers
|
149 |
+
self.r_layers = r_layers
|
150 |
+
|
151 |
+
self.visual_feat_dim = 2048
|
152 |
+
self.visual_pos_dim = 4
|
153 |
+
|
154 |
+
self.obj_id_num = 1600
|
155 |
+
self.attr_id_num = 400
|
156 |
+
|
157 |
+
self.visual_losses = self.VISUAL_LOSSES
|
158 |
+
self.visual_loss_config = {
|
159 |
+
'obj': (self.obj_id_num, 'ce', (-1,), 1/0.15),
|
160 |
+
'attr': (self.attr_id_num, 'ce', (-1,), 1/0.15),
|
161 |
+
'feat': (2048, 'l2', (-1, 2048), 1/0.15),
|
162 |
+
}
|
163 |
+
|
164 |
+
def set_visual_dims(self, feat_dim, pos_dim):
|
165 |
+
self.visual_feat_dim = feat_dim
|
166 |
+
self.visual_pos_dim = pos_dim
|
167 |
+
|
168 |
+
|
169 |
+
VISUAL_CONFIG = VisualConfig()
|
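# Illustrative note (an addition, not upstream code): each visual_loss_config entry is a
# (output_dim, loss_name, label_shape, weight) tuple, e.g.
#     VISUAL_CONFIG.visual_loss_config['obj'] == (1600, 'ce', (-1,), 1 / 0.15)
# i.e. 1600 object classes scored with cross-entropy over flattened labels, scaled by a
# 1/0.15 normalizer (presumably matching the 15% visual-feature masking rate in pretraining).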
170 |
+
|
171 |
+
|
172 |
+
class BertConfig(object):
|
173 |
+
"""Configuration class to store the configuration of a `BertModel`.
|
174 |
+
"""
|
175 |
+
def __init__(self,
|
176 |
+
vocab_size_or_config_json_file,
|
177 |
+
hidden_size=768,
|
178 |
+
num_hidden_layers=12,
|
179 |
+
num_attention_heads=12,
|
180 |
+
intermediate_size=3072,
|
181 |
+
hidden_act="gelu",
|
182 |
+
hidden_dropout_prob=0.1,
|
183 |
+
attention_probs_dropout_prob=0.1,
|
184 |
+
max_position_embeddings=512,
|
185 |
+
type_vocab_size=2,
|
186 |
+
initializer_range=0.02):
|
187 |
+
"""Constructs BertConfig.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
|
191 |
+
hidden_size: Size of the encoder layers and the pooler layer.
|
192 |
+
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
193 |
+
num_attention_heads: Number of attention heads for each attention layer in
|
194 |
+
the Transformer encoder.
|
195 |
+
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
196 |
+
layer in the Transformer encoder.
|
197 |
+
hidden_act: The non-linear activation function (function or string) in the
|
198 |
+
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
|
199 |
+
hidden_dropout_prob: The dropout probability for all fully connected
|
200 |
+
layers in the embeddings, encoder, and pooler.
|
201 |
+
attention_probs_dropout_prob: The dropout ratio for the attention
|
202 |
+
probabilities.
|
203 |
+
max_position_embeddings: The maximum sequence length that this model might
|
204 |
+
ever be used with. Typically set this to something large just in case
|
205 |
+
(e.g., 512 or 1024 or 2048).
|
206 |
+
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
207 |
+
`BertModel`.
|
208 |
+
initializer_range: The standard deviation of the truncated_normal_initializer for
|
209 |
+
initializing all weight matrices.
|
210 |
+
"""
|
211 |
+
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
212 |
+
and isinstance(vocab_size_or_config_json_file, unicode)):
|
213 |
+
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
214 |
+
json_config = json.loads(reader.read())
|
215 |
+
for key, value in json_config.items():
|
216 |
+
self.__dict__[key] = value
|
217 |
+
elif isinstance(vocab_size_or_config_json_file, int):
|
218 |
+
self.vocab_size = vocab_size_or_config_json_file
|
219 |
+
self.hidden_size = hidden_size
|
220 |
+
self.num_hidden_layers = num_hidden_layers
|
221 |
+
self.num_attention_heads = num_attention_heads
|
222 |
+
self.hidden_act = hidden_act
|
223 |
+
self.intermediate_size = intermediate_size
|
224 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
225 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
226 |
+
self.max_position_embeddings = max_position_embeddings
|
227 |
+
self.type_vocab_size = type_vocab_size
|
228 |
+
self.initializer_range = initializer_range
|
229 |
+
else:
|
230 |
+
raise ValueError("First argument must be either a vocabulary size (int)"
|
231 |
+
"or the path to a pretrained model config file (str)")
|
232 |
+
|
233 |
+
@classmethod
|
234 |
+
def from_dict(cls, json_object):
|
235 |
+
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
|
236 |
+
config = BertConfig(vocab_size_or_config_json_file=-1)
|
237 |
+
for key, value in json_object.items():
|
238 |
+
config.__dict__[key] = value
|
239 |
+
return config
|
240 |
+
|
241 |
+
@classmethod
|
242 |
+
def from_json_file(cls, json_file):
|
243 |
+
"""Constructs a `BertConfig` from a json file of parameters."""
|
244 |
+
with open(json_file, "r", encoding='utf-8') as reader:
|
245 |
+
text = reader.read()
|
246 |
+
return cls.from_dict(json.loads(text))
|
247 |
+
|
248 |
+
def __repr__(self):
|
249 |
+
return str(self.to_json_string())
|
250 |
+
|
251 |
+
def to_dict(self):
|
252 |
+
"""Serializes this instance to a Python dictionary."""
|
253 |
+
output = copy.deepcopy(self.__dict__)
|
254 |
+
return output
|
255 |
+
|
256 |
+
def to_json_string(self):
|
257 |
+
"""Serializes this instance to a JSON string."""
|
258 |
+
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
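# Illustrative sketch (an addition, not upstream code): an int first argument is treated as
# the vocabulary size, while a str is treated as a path to a JSON config file, e.g.
#     config = BertConfig(vocab_size_or_config_json_file=30522, num_hidden_layers=12)
#     print(config.to_json_string())   # JSON dump of vocab_size, hidden_size, ... (sorted keys)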
259 |
+
|
260 |
+
|
261 |
+
BertLayerNorm = torch.nn.LayerNorm
|
262 |
+
|
263 |
+
|
264 |
+
class BertEmbeddings(nn.Module):
|
265 |
+
"""Construct the embeddings from word, position and token_type embeddings.
|
266 |
+
"""
|
267 |
+
def __init__(self, config):
|
268 |
+
super(BertEmbeddings, self).__init__()
|
269 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
270 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
|
271 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0)
|
272 |
+
|
273 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
274 |
+
# any TensorFlow checkpoint file
|
275 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
276 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
277 |
+
|
278 |
+
def forward(self, input_ids, token_type_ids=None):
|
279 |
+
seq_length = input_ids.size(1)
|
280 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
281 |
+
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
282 |
+
if token_type_ids is None:
|
283 |
+
token_type_ids = torch.zeros_like(input_ids)
|
284 |
+
|
285 |
+
words_embeddings = self.word_embeddings(input_ids)
|
286 |
+
position_embeddings = self.position_embeddings(position_ids)
|
287 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
288 |
+
|
289 |
+
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
290 |
+
embeddings = self.LayerNorm(embeddings)
|
291 |
+
embeddings = self.dropout(embeddings)
|
292 |
+
return embeddings
|
293 |
+
|
294 |
+
|
295 |
+
class BertAttention(nn.Module):
|
296 |
+
def __init__(self, config, ctx_dim=None):
|
297 |
+
super().__init__()
|
298 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
299 |
+
raise ValueError(
|
300 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
301 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
302 |
+
self.num_attention_heads = config.num_attention_heads
|
303 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
304 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
305 |
+
|
306 |
+
# visual_dim = 2048
|
307 |
+
if ctx_dim is None:
|
308 |
+
ctx_dim = config.hidden_size
|
309 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
310 |
+
self.key = nn.Linear(ctx_dim, self.all_head_size)
|
311 |
+
self.value = nn.Linear(ctx_dim, self.all_head_size)
|
312 |
+
|
313 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
314 |
+
|
315 |
+
def transpose_for_scores(self, x):
|
316 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
317 |
+
x = x.view(*new_x_shape)
|
318 |
+
return x.permute(0, 2, 1, 3)
|
319 |
+
|
320 |
+
def forward(self, hidden_states, context, attention_mask=None):
|
321 |
+
mixed_query_layer = self.query(hidden_states)
|
322 |
+
mixed_key_layer = self.key(context)
|
323 |
+
mixed_value_layer = self.value(context)
|
324 |
+
|
325 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
326 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
327 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
328 |
+
|
329 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
330 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
331 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
332 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
333 |
+
if attention_mask is not None:
|
334 |
+
attention_scores = attention_scores + attention_mask
|
335 |
+
|
336 |
+
# Normalize the attention scores to probabilities.
|
337 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
338 |
+
|
339 |
+
# This is actually dropping out entire tokens to attend to, which might
|
340 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
341 |
+
attention_probs = self.dropout(attention_probs)
|
342 |
+
|
343 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
344 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
345 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
346 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
347 |
+
return context_layer
|
348 |
+
|
349 |
+
|
350 |
+
class BertAttOutput(nn.Module):
|
351 |
+
def __init__(self, config):
|
352 |
+
super(BertAttOutput, self).__init__()
|
353 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
354 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
355 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
356 |
+
|
357 |
+
def forward(self, hidden_states, input_tensor):
|
358 |
+
hidden_states = self.dense(hidden_states)
|
359 |
+
hidden_states = self.dropout(hidden_states)
|
360 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
361 |
+
return hidden_states
|
362 |
+
|
363 |
+
|
364 |
+
class BertCrossattLayer(nn.Module):
|
365 |
+
def __init__(self, config):
|
366 |
+
super().__init__()
|
367 |
+
self.att = BertAttention(config)
|
368 |
+
self.output = BertAttOutput(config)
|
369 |
+
|
370 |
+
def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None):
|
371 |
+
output = self.att(input_tensor, ctx_tensor, ctx_att_mask)
|
372 |
+
attention_output = self.output(output, input_tensor)
|
373 |
+
return attention_output
|
374 |
+
|
375 |
+
|
376 |
+
class BertSelfattLayer(nn.Module):
|
377 |
+
def __init__(self, config):
|
378 |
+
super(BertSelfattLayer, self).__init__()
|
379 |
+
self.self = BertAttention(config)
|
380 |
+
self.output = BertAttOutput(config)
|
381 |
+
|
382 |
+
def forward(self, input_tensor, attention_mask):
|
383 |
+
# Self attention attends to itself, thus keys and queries are the same (input_tensor).
|
384 |
+
self_output = self.self(input_tensor, input_tensor, attention_mask)
|
385 |
+
attention_output = self.output(self_output, input_tensor)
|
386 |
+
return attention_output
|
387 |
+
|
388 |
+
|
389 |
+
class BertIntermediate(nn.Module):
|
390 |
+
def __init__(self, config):
|
391 |
+
super(BertIntermediate, self).__init__()
|
392 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
393 |
+
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
394 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
395 |
+
else:
|
396 |
+
self.intermediate_act_fn = config.hidden_act
|
397 |
+
|
398 |
+
def forward(self, hidden_states):
|
399 |
+
hidden_states = self.dense(hidden_states)
|
400 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
401 |
+
return hidden_states
|
402 |
+
|
403 |
+
|
404 |
+
class BertOutput(nn.Module):
|
405 |
+
def __init__(self, config):
|
406 |
+
super(BertOutput, self).__init__()
|
407 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
408 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
409 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
410 |
+
|
411 |
+
def forward(self, hidden_states, input_tensor):
|
412 |
+
hidden_states = self.dense(hidden_states)
|
413 |
+
hidden_states = self.dropout(hidden_states)
|
414 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
415 |
+
return hidden_states
|
416 |
+
|
417 |
+
|
418 |
+
class BertLayer(nn.Module):
|
419 |
+
def __init__(self, config):
|
420 |
+
super(BertLayer, self).__init__()
|
421 |
+
self.attention = BertSelfattLayer(config)
|
422 |
+
self.intermediate = BertIntermediate(config)
|
423 |
+
self.output = BertOutput(config)
|
424 |
+
|
425 |
+
def forward(self, hidden_states, attention_mask):
|
426 |
+
attention_output = self.attention(hidden_states, attention_mask)
|
427 |
+
intermediate_output = self.intermediate(attention_output)
|
428 |
+
layer_output = self.output(intermediate_output, attention_output)
|
429 |
+
return layer_output
|
430 |
+
|
431 |
+
|
432 |
+
"""
|
433 |
+
---------------------------------------------------------------------------------------
|
434 |
+
Above modules are copied from BERT (pytorch-transformer) with modifications.
|
435 |
+
---------------------------------------------------------------------------------------
|
436 |
+
"""
|
437 |
+
|
438 |
+
|
439 |
+
class LXRTXLayer(nn.Module):
|
440 |
+
def __init__(self, config):
|
441 |
+
super().__init__()
|
442 |
+
# The cross-attention Layer
|
443 |
+
self.visual_attention = BertCrossattLayer(config)
|
444 |
+
|
445 |
+
# Self-attention Layers
|
446 |
+
self.lang_self_att = BertSelfattLayer(config)
|
447 |
+
self.visn_self_att = BertSelfattLayer(config)
|
448 |
+
|
449 |
+
# Intermediate and Output Layers (FFNs)
|
450 |
+
self.lang_inter = BertIntermediate(config)
|
451 |
+
self.lang_output = BertOutput(config)
|
452 |
+
self.visn_inter = BertIntermediate(config)
|
453 |
+
self.visn_output = BertOutput(config)
|
454 |
+
|
455 |
+
def cross_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask):
|
456 |
+
# Cross Attention
|
457 |
+
lang_att_output = self.visual_attention(lang_input, visn_input, ctx_att_mask=visn_attention_mask)
|
458 |
+
visn_att_output = self.visual_attention(visn_input, lang_input, ctx_att_mask=lang_attention_mask)
|
459 |
+
return lang_att_output, visn_att_output
|
460 |
+
|
461 |
+
def self_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask):
|
462 |
+
# Self Attention
|
463 |
+
lang_att_output = self.lang_self_att(lang_input, lang_attention_mask)
|
464 |
+
visn_att_output = self.visn_self_att(visn_input, visn_attention_mask)
|
465 |
+
return lang_att_output, visn_att_output
|
466 |
+
|
467 |
+
def output_fc(self, lang_input, visn_input):
|
468 |
+
# FC layers
|
469 |
+
lang_inter_output = self.lang_inter(lang_input)
|
470 |
+
visn_inter_output = self.visn_inter(visn_input)
|
471 |
+
|
472 |
+
# Layer output
|
473 |
+
lang_output = self.lang_output(lang_inter_output, lang_input)
|
474 |
+
visn_output = self.visn_output(visn_inter_output, visn_input)
|
475 |
+
return lang_output, visn_output
|
476 |
+
|
477 |
+
def forward(self, lang_feats, lang_attention_mask,
|
478 |
+
visn_feats, visn_attention_mask):
|
479 |
+
lang_att_output = lang_feats
|
480 |
+
visn_att_output = visn_feats
|
481 |
+
|
482 |
+
lang_att_output, visn_att_output = self.cross_att(lang_att_output, lang_attention_mask,
|
483 |
+
visn_att_output, visn_attention_mask)
|
484 |
+
lang_att_output, visn_att_output = self.self_att(lang_att_output, lang_attention_mask,
|
485 |
+
visn_att_output, visn_attention_mask)
|
486 |
+
lang_output, visn_output = self.output_fc(lang_att_output, visn_att_output)
|
487 |
+
|
488 |
+
return lang_output, visn_output
|
489 |
+
|
490 |
+
|
491 |
+
class VisualFeatEncoder(nn.Module):
|
492 |
+
def __init__(self, config):
|
493 |
+
super().__init__()
|
494 |
+
feat_dim = VISUAL_CONFIG.visual_feat_dim
|
495 |
+
pos_dim = VISUAL_CONFIG.visual_pos_dim
|
496 |
+
|
497 |
+
# Object feature encoding
|
498 |
+
self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
|
499 |
+
self.visn_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
500 |
+
|
501 |
+
# Box position encoding
|
502 |
+
self.box_fc = nn.Linear(pos_dim, config.hidden_size)
|
503 |
+
self.box_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
504 |
+
|
505 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
506 |
+
|
507 |
+
def forward(self, visn_input):
|
508 |
+
feats, boxes = visn_input
|
509 |
+
|
510 |
+
x = self.visn_fc(feats)
|
511 |
+
x = self.visn_layer_norm(x)
|
512 |
+
y = self.box_fc(boxes)
|
513 |
+
y = self.box_layer_norm(y)
|
514 |
+
output = (x + y) / 2
|
515 |
+
|
516 |
+
output = self.dropout(output)
|
517 |
+
return output
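A self-contained illustration of what VisualFeatEncoder computes: project RoI features and normalized box coordinates to the hidden size, LayerNorm each, and average the two. The 2048-d feature and 4-d box dimensions are assumed typical Faster R-CNN values for illustration; the real values come from VISUAL_CONFIG.

```python
import torch
import torch.nn as nn

hidden, feat_dim, pos_dim = 768, 2048, 4   # assumed defaults; the real values come from VISUAL_CONFIG

visn_fc, visn_ln = nn.Linear(feat_dim, hidden), nn.LayerNorm(hidden, eps=1e-12)
box_fc, box_ln = nn.Linear(pos_dim, hidden), nn.LayerNorm(hidden, eps=1e-12)

feats = torch.randn(2, 36, feat_dim)   # [batch, num_boxes, feat_dim] RoI features
boxes = torch.rand(2, 36, pos_dim)     # [batch, num_boxes, 4] normalized box coordinates

x = visn_ln(visn_fc(feats))
y = box_ln(box_fc(boxes))
output = (x + y) / 2                   # average the two embeddings, as in VisualFeatEncoder.forward
print(output.shape)                    # torch.Size([2, 36, 768])
```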
|
518 |
+
|
519 |
+
|
520 |
+
class LXRTEncoder(nn.Module):
|
521 |
+
def __init__(self, config):
|
522 |
+
super().__init__()
|
523 |
+
|
524 |
+
# Obj-level image embedding layer
|
525 |
+
self.visn_fc = VisualFeatEncoder(config)
|
526 |
+
|
527 |
+
# Number of layers
|
528 |
+
self.num_l_layers = VISUAL_CONFIG.l_layers
|
529 |
+
self.num_x_layers = VISUAL_CONFIG.x_layers
|
530 |
+
self.num_r_layers = VISUAL_CONFIG.r_layers
|
531 |
+
print("LXRT encoder with %d l_layers, %d x_layers, and %d r_layers." %
|
532 |
+
(self.num_l_layers, self.num_x_layers, self.num_r_layers))
|
533 |
+
|
534 |
+
# Layers
|
535 |
+
# Using self.layer instead of self.l_layer to support loading BERT weights.
|
536 |
+
self.layer = nn.ModuleList(
|
537 |
+
[BertLayer(config) for _ in range(self.num_l_layers)]
|
538 |
+
)
|
539 |
+
self.x_layers = nn.ModuleList(
|
540 |
+
[LXRTXLayer(config) for _ in range(self.num_x_layers)]
|
541 |
+
)
|
542 |
+
self.r_layers = nn.ModuleList(
|
543 |
+
[BertLayer(config) for _ in range(self.num_r_layers)]
|
544 |
+
)
|
545 |
+
|
546 |
+
def forward(self, lang_feats, lang_attention_mask,
|
547 |
+
visn_feats, visn_attention_mask=None):
|
548 |
+
# Run visual embedding layer
|
549 |
+
# Note: the word embedding layer is executed outside this module.
|
550 |
+
# Keep this design to allow loading BERT weights.
|
551 |
+
visn_feats = self.visn_fc(visn_feats)
|
552 |
+
|
553 |
+
# Run language layers
|
554 |
+
for layer_module in self.layer:
|
555 |
+
lang_feats = layer_module(lang_feats, lang_attention_mask)
|
556 |
+
|
557 |
+
# Run relational layers
|
558 |
+
for layer_module in self.r_layers:
|
559 |
+
visn_feats = layer_module(visn_feats, visn_attention_mask)
|
560 |
+
|
561 |
+
# Run cross-modality layers
|
562 |
+
for layer_module in self.x_layers:
|
563 |
+
lang_feats, visn_feats = layer_module(lang_feats, lang_attention_mask,
|
564 |
+
visn_feats, visn_attention_mask)
|
565 |
+
|
566 |
+
return lang_feats, visn_feats
|
567 |
+
|
568 |
+
|
569 |
+
class BertPooler(nn.Module):
|
570 |
+
def __init__(self, config):
|
571 |
+
super(BertPooler, self).__init__()
|
572 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
573 |
+
self.activation = nn.Tanh()
|
574 |
+
|
575 |
+
def forward(self, hidden_states):
|
576 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
577 |
+
# to the first token.
|
578 |
+
first_token_tensor = hidden_states[:, 0]
|
579 |
+
pooled_output = self.dense(first_token_tensor)
|
580 |
+
pooled_output = self.activation(pooled_output)
|
581 |
+
return pooled_output
|
582 |
+
|
583 |
+
|
584 |
+
class BertPredictionHeadTransform(nn.Module):
|
585 |
+
def __init__(self, config):
|
586 |
+
super(BertPredictionHeadTransform, self).__init__()
|
587 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
588 |
+
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
589 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
590 |
+
else:
|
591 |
+
self.transform_act_fn = config.hidden_act
|
592 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
593 |
+
|
594 |
+
def forward(self, hidden_states):
|
595 |
+
hidden_states = self.dense(hidden_states)
|
596 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
597 |
+
hidden_states = self.LayerNorm(hidden_states)
|
598 |
+
return hidden_states
|
599 |
+
|
600 |
+
|
601 |
+
class BertLMPredictionHead(nn.Module):
|
602 |
+
def __init__(self, config, bert_model_embedding_weights):
|
603 |
+
super(BertLMPredictionHead, self).__init__()
|
604 |
+
self.transform = BertPredictionHeadTransform(config)
|
605 |
+
|
606 |
+
# The output weights are the same as the input embeddings, but there is
|
607 |
+
# an output-only bias for each token.
|
608 |
+
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
|
609 |
+
bert_model_embedding_weights.size(0),
|
610 |
+
bias=False)
|
611 |
+
self.decoder.weight = bert_model_embedding_weights
|
612 |
+
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
|
613 |
+
|
614 |
+
def forward(self, hidden_states):
|
615 |
+
hidden_states = self.transform(hidden_states)
|
616 |
+
hidden_states = self.decoder(hidden_states) + self.bias
|
617 |
+
return hidden_states
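The decoder above reuses the input word-embedding matrix as its output projection and only adds a fresh bias. A compact sketch of that weight tying (the 30522 x 768 sizes are the usual bert-base-uncased values, assumed here for illustration):

```python
import torch
import torch.nn as nn

vocab_size, hidden = 30522, 768                      # assumed bert-base-uncased sizes
embeddings = nn.Embedding(vocab_size, hidden)

decoder = nn.Linear(hidden, vocab_size, bias=False)  # weight shape [vocab_size, hidden], same as the embedding
decoder.weight = embeddings.weight                   # tie: same Parameter object, no copy
bias = nn.Parameter(torch.zeros(vocab_size))         # output-only bias, as in BertLMPredictionHead

hidden_states = torch.randn(2, 20, hidden)
logits = decoder(hidden_states) + bias
print(logits.shape)                                                  # torch.Size([2, 20, 30522])
print(decoder.weight.data_ptr() == embeddings.weight.data_ptr())     # True: the weights are shared
```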
|
618 |
+
|
619 |
+
|
620 |
+
class BertVisualAnswerHead(nn.Module):
|
621 |
+
def __init__(self, config, num_answers):
|
622 |
+
super().__init__()
|
623 |
+
hid_dim = config.hidden_size
|
624 |
+
self.logit_fc = nn.Sequential(
|
625 |
+
nn.Linear(hid_dim, hid_dim * 2),
|
626 |
+
GeLU(),
|
627 |
+
BertLayerNorm(hid_dim * 2, eps=1e-12),
|
628 |
+
nn.Linear(hid_dim * 2, num_answers)
|
629 |
+
)
|
630 |
+
|
631 |
+
def forward(self, hidden_states):
|
632 |
+
return self.logit_fc(hidden_states)
|
633 |
+
|
634 |
+
|
635 |
+
class BertVisualObjHead(nn.Module):
|
636 |
+
def __init__(self, config, visual_losses):
|
637 |
+
super().__init__()
|
638 |
+
self.transform = BertPredictionHeadTransform(config)
|
639 |
+
|
640 |
+
# Decide the use of visual losses
|
641 |
+
visual_losses = visual_losses.split(",")
|
642 |
+
for loss in visual_losses:
|
643 |
+
assert loss in VISUAL_CONFIG.VISUAL_LOSSES
|
644 |
+
self.visual_losses = visual_losses
|
645 |
+
|
646 |
+
# The output weights are the same as the input embeddings, but there is
|
647 |
+
# an output-only bias for each token.
|
648 |
+
self.decoder_dict = nn.ModuleDict({
|
649 |
+
key: nn.Linear(config.hidden_size, VISUAL_CONFIG.visual_loss_config[key][0])
|
650 |
+
for key in self.visual_losses
|
651 |
+
})
|
652 |
+
|
653 |
+
def forward(self, hidden_states):
|
654 |
+
hidden_states = self.transform(hidden_states)
|
655 |
+
output = {}
|
656 |
+
for key in self.visual_losses:
|
657 |
+
output[key] = self.decoder_dict[key](hidden_states)
|
658 |
+
return output
|
659 |
+
|
660 |
+
|
661 |
+
class BertPreTrainingHeads(nn.Module):
|
662 |
+
def __init__(self, config, bert_model_embedding_weights):
|
663 |
+
super(BertPreTrainingHeads, self).__init__()
|
664 |
+
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
|
665 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
666 |
+
|
667 |
+
def forward(self, sequence_output, pooled_output):
|
668 |
+
prediction_scores = self.predictions(sequence_output)
|
669 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
670 |
+
return prediction_scores, seq_relationship_score
|
671 |
+
|
672 |
+
|
673 |
+
class BertPreTrainedModel(nn.Module):
|
674 |
+
""" An abstract class to handle weights initialization and
|
675 |
+
a simple interface for downloading and loading pretrained models.
|
676 |
+
"""
|
677 |
+
def __init__(self, config, *inputs, **kwargs):
|
678 |
+
super(BertPreTrainedModel, self).__init__()
|
679 |
+
if not isinstance(config, BertConfig):
|
680 |
+
raise ValueError(
|
681 |
+
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
|
682 |
+
"To create a model from a Google pretrained model use "
|
683 |
+
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
684 |
+
self.__class__.__name__, self.__class__.__name__
|
685 |
+
))
|
686 |
+
self.config = config
|
687 |
+
|
688 |
+
def init_bert_weights(self, module):
|
689 |
+
""" Initialize the weights.
|
690 |
+
"""
|
691 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
692 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
693 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
694 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
695 |
+
elif isinstance(module, BertLayerNorm):
|
696 |
+
module.bias.data.zero_()
|
697 |
+
module.weight.data.fill_(1.0)
|
698 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
699 |
+
module.bias.data.zero_()
|
700 |
+
|
701 |
+
@classmethod
|
702 |
+
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
|
703 |
+
from_tf=False, *inputs, **kwargs):
|
704 |
+
"""
|
705 |
+
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
706 |
+
Download and cache the pre-trained model file if needed.
|
707 |
+
|
708 |
+
Params:
|
709 |
+
pretrained_model_name_or_path: either:
|
710 |
+
- a str with the name of a pre-trained model to load selected in the list of:
|
711 |
+
. `bert-base-uncased`
|
712 |
+
. `bert-large-uncased`
|
713 |
+
. `bert-base-cased`
|
714 |
+
. `bert-large-cased`
|
715 |
+
. `bert-base-multilingual-uncased`
|
716 |
+
. `bert-base-multilingual-cased`
|
717 |
+
. `bert-base-chinese`
|
718 |
+
- a path or url to a pretrained model archive containing:
|
719 |
+
. `bert_config.json` a configuration file for the model
|
720 |
+
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
721 |
+
- a path or url to a pretrained model archive containing:
|
722 |
+
. `bert_config.json` a configuration file for the model
|
723 |
+
. `model.chkpt` a TensorFlow checkpoint
|
724 |
+
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
725 |
+
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
726 |
+
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
|
727 |
+
*inputs, **kwargs: additional input for the specific Bert class
|
728 |
+
(ex: num_labels for BertForSequenceClassification)
|
729 |
+
"""
|
730 |
+
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
731 |
+
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
732 |
+
else:
|
733 |
+
archive_file = pretrained_model_name_or_path
|
734 |
+
# redirect to the cache, if necessary
|
735 |
+
try:
|
736 |
+
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
737 |
+
except EnvironmentError:
|
738 |
+
if pretrained_model_name_or_path == 'bert-base-uncased':
|
739 |
+
try:
|
740 |
+
print("The BERT-weight-downloading query to AWS was time-out;"
|
741 |
+
"trying to download from UNC servers")
|
742 |
+
archive_file = "https://nlp.cs.unc.edu/data/bert/bert-base-uncased.tar.gz"
|
743 |
+
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
744 |
+
except EnvironmentError:
|
745 |
+
print("The weight-downloading still crashed with link: %s, "
|
746 |
+
"please check your network connection" % archive_file)
|
747 |
+
return None
|
748 |
+
else:
|
749 |
+
logger.error(
|
750 |
+
"Model name '{}' was not found in model name list ({}). "
|
751 |
+
"We assumed '{}' was a path or url but couldn't find any file "
|
752 |
+
"associated to this path or url.".format(
|
753 |
+
pretrained_model_name_or_path,
|
754 |
+
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
755 |
+
archive_file))
|
756 |
+
if resolved_archive_file == archive_file:
|
757 |
+
logger.info("loading archive file {}".format(archive_file))
|
758 |
+
else:
|
759 |
+
logger.info("loading archive file {} from cache at {}".format(
|
760 |
+
archive_file, resolved_archive_file))
|
761 |
+
tempdir = None
|
762 |
+
if os.path.isdir(resolved_archive_file) or from_tf:
|
763 |
+
serialization_dir = resolved_archive_file
|
764 |
+
else:
|
765 |
+
# Extract archive to temp dir
|
766 |
+
tempdir = tempfile.mkdtemp()
|
767 |
+
logger.info("extracting archive file {} to temp dir {}".format(
|
768 |
+
resolved_archive_file, tempdir))
|
769 |
+
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
770 |
+
archive.extractall(tempdir)
|
771 |
+
serialization_dir = tempdir
|
772 |
+
# Load config
|
773 |
+
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
774 |
+
config = BertConfig.from_json_file(config_file)
|
775 |
+
logger.info("Model config {}".format(config))
|
776 |
+
# Instantiate model.
|
777 |
+
model = cls(config, *inputs, **kwargs)
|
778 |
+
if state_dict is None and not from_tf:
|
779 |
+
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
780 |
+
state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
|
781 |
+
if tempdir:
|
782 |
+
# Clean up temp dir
|
783 |
+
shutil.rmtree(tempdir)
|
784 |
+
if from_tf:
|
785 |
+
# Directly load from a TensorFlow checkpoint
|
786 |
+
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
787 |
+
return load_tf_weights_in_bert(model, weights_path)
|
788 |
+
# Load from a PyTorch state_dict
|
789 |
+
old_keys = []
|
790 |
+
new_keys = []
|
791 |
+
for key in state_dict.keys():
|
792 |
+
new_key = None
|
793 |
+
if 'gamma' in key:
|
794 |
+
new_key = key.replace('gamma', 'weight')
|
795 |
+
if 'beta' in key:
|
796 |
+
new_key = key.replace('beta', 'bias')
|
797 |
+
if new_key:
|
798 |
+
old_keys.append(key)
|
799 |
+
new_keys.append(new_key)
|
800 |
+
for old_key, new_key in zip(old_keys, new_keys):
|
801 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
802 |
+
|
803 |
+
missing_keys = []
|
804 |
+
unexpected_keys = []
|
805 |
+
error_msgs = []
|
806 |
+
# copy state_dict so _load_from_state_dict can modify it
|
807 |
+
metadata = getattr(state_dict, '_metadata', None)
|
808 |
+
state_dict = state_dict.copy()
|
809 |
+
if metadata is not None:
|
810 |
+
state_dict._metadata = metadata
|
811 |
+
|
812 |
+
def load(module, prefix=''):
|
813 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
814 |
+
module._load_from_state_dict(
|
815 |
+
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
816 |
+
for name, child in module._modules.items():
|
817 |
+
if child is not None:
|
818 |
+
load(child, prefix + name + '.')
|
819 |
+
start_prefix = ''
|
820 |
+
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
|
821 |
+
start_prefix = 'bert.'
|
822 |
+
load(model, prefix=start_prefix)
|
823 |
+
# if len(missing_keys) > 0:
|
824 |
+
# logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
825 |
+
# model.__class__.__name__, missing_keys))
|
826 |
+
# if len(unexpected_keys) > 0:
|
827 |
+
# logger.info("Weights from pretrained model not used in {}: {}".format(
|
828 |
+
# model.__class__.__name__, unexpected_keys))
|
829 |
+
if len(error_msgs) > 0:
|
830 |
+
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
831 |
+
model.__class__.__name__, "\n\t".join(error_msgs)))
|
832 |
+
return model
|
833 |
+
|
834 |
+
|
835 |
+
class LXRTModel(BertPreTrainedModel):
|
836 |
+
"""LXRT Model."""
|
837 |
+
|
838 |
+
def __init__(self, config):
|
839 |
+
super().__init__(config)
|
840 |
+
self.embeddings = BertEmbeddings(config)
|
841 |
+
self.encoder = LXRTEncoder(config)
|
842 |
+
self.pooler = BertPooler(config)
|
843 |
+
self.apply(self.init_bert_weights)
|
844 |
+
|
845 |
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None,
|
846 |
+
visual_feats=None, visual_attention_mask=None):
|
847 |
+
if attention_mask is None:
|
848 |
+
attention_mask = torch.ones_like(input_ids)
|
849 |
+
if token_type_ids is None:
|
850 |
+
token_type_ids = torch.zeros_like(input_ids)
|
851 |
+
|
852 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
853 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
854 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
855 |
+
# This attention mask is simpler than the triangular masking of causal attention
|
856 |
+
# used in OpenAI GPT; we just need to prepare the broadcast dimension here.
|
857 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
858 |
+
|
859 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
860 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
861 |
+
# positions we want to attend and -10000.0 for masked positions.
|
862 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
863 |
+
# effectively the same as removing these entirely.
|
864 |
+
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
865 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
866 |
+
|
867 |
+
# Process the visual attention mask
|
868 |
+
if visual_attention_mask is not None:
|
869 |
+
extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
|
870 |
+
extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
871 |
+
extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
|
872 |
+
else:
|
873 |
+
extended_visual_attention_mask = None
|
874 |
+
|
875 |
+
# Positional Word Embeddings
|
876 |
+
embedding_output = self.embeddings(input_ids, token_type_ids)
|
877 |
+
|
878 |
+
# Run LXRT backbone
|
879 |
+
lang_feats, visn_feats = self.encoder(
|
880 |
+
embedding_output,
|
881 |
+
extended_attention_mask,
|
882 |
+
visn_feats=visual_feats,
|
883 |
+
visn_attention_mask=extended_visual_attention_mask)
|
884 |
+
pooled_output = self.pooler(lang_feats)
|
885 |
+
|
886 |
+
return (lang_feats, visn_feats), pooled_output
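The mask arithmetic used in forward above, in isolation: unsqueeze the 2-D mask to [batch, 1, 1, seq] so it broadcasts over heads and query positions, then map keep→0.0 and masked→-10000.0 so masked keys effectively vanish after the softmax. A runnable check with made-up scores:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])               # 1 = real token, 0 = padding
extended = attention_mask.unsqueeze(1).unsqueeze(2).float()    # [batch, 1, 1, seq] for broadcasting over heads
extended = (1.0 - extended) * -10000.0                         # keep -> 0.0, masked -> -10000.0

scores = torch.zeros(1, 1, 1, 5)                               # pretend raw attention scores
probs = torch.softmax(scores + extended, dim=-1)
print(probs)  # padded positions get (near-)zero attention probability
```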
|
887 |
+
|
888 |
+
|
889 |
+
class LXRTPretraining(BertPreTrainedModel):
|
890 |
+
def __init__(self,
|
891 |
+
config,
|
892 |
+
task_mask_lm=True,
|
893 |
+
task_matched=True,
|
894 |
+
task_obj_predict=True,
|
895 |
+
visual_losses='',
|
896 |
+
task_qa=True,
|
897 |
+
num_answers=2):
|
898 |
+
super().__init__(config)
|
899 |
+
# Configuration
|
900 |
+
self.config = config
|
901 |
+
self.num_answers = num_answers
|
902 |
+
|
903 |
+
# Use of pre-training tasks
|
904 |
+
self.task_mask_lm = task_mask_lm
|
905 |
+
self.task_obj_predict = task_obj_predict
|
906 |
+
self.task_matched = task_matched
|
907 |
+
self.task_qa = task_qa
|
908 |
+
|
909 |
+
# LXRT backbone
|
910 |
+
self.bert = LXRTModel(config)
|
911 |
+
|
912 |
+
# Pre-training heads
|
913 |
+
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
|
914 |
+
if self.task_obj_predict:
|
915 |
+
self.obj_predict_head = BertVisualObjHead(config, visual_losses)
|
916 |
+
if self.task_qa:
|
917 |
+
self.answer_head = BertVisualAnswerHead(config, self.num_answers)
|
918 |
+
|
919 |
+
# Weight initialization
|
920 |
+
self.apply(self.init_bert_weights)
|
921 |
+
|
922 |
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
|
923 |
+
visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None):
|
924 |
+
(lang_output, visn_output), pooled_output = self.bert(
|
925 |
+
input_ids, token_type_ids, attention_mask,
|
926 |
+
visual_feats=(visual_feats, pos),
|
927 |
+
)
|
928 |
+
|
929 |
+
lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output)
|
930 |
+
if self.task_qa:
|
931 |
+
answer_score = self.answer_head(pooled_output)
|
932 |
+
else:
|
933 |
+
# This answer_score is not used anywhere;
|
934 |
+
# it only keeps the return signature constant.
|
935 |
+
answer_score = pooled_output[0][0]
|
936 |
+
|
937 |
+
total_loss = 0.
|
938 |
+
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
939 |
+
losses = ()
|
940 |
+
if masked_lm_labels is not None and self.task_mask_lm:
|
941 |
+
masked_lm_loss = loss_fct(
|
942 |
+
lang_prediction_scores.view(-1, self.config.vocab_size),
|
943 |
+
masked_lm_labels.view(-1)
|
944 |
+
)
|
945 |
+
total_loss += masked_lm_loss
|
946 |
+
losses += (masked_lm_loss.detach(),)
|
947 |
+
if matched_label is not None and self.task_matched:
|
948 |
+
matched_loss = loss_fct(
|
949 |
+
cross_relationship_score.view(-1, 2),
|
950 |
+
matched_label.view(-1)
|
951 |
+
)
|
952 |
+
total_loss += matched_loss
|
953 |
+
losses += (matched_loss.detach(),)
|
954 |
+
if obj_labels is not None and self.task_obj_predict:
|
955 |
+
loss_fcts = {
|
956 |
+
'l2': SmoothL1Loss(reduction='none'),
|
957 |
+
'ce': CrossEntropyLoss(ignore_index=-1, reduction='none')
|
958 |
+
}
|
959 |
+
total_visn_loss = 0.
|
960 |
+
visn_prediction_scores_dict = self.obj_predict_head(visn_output)
|
961 |
+
for key in VISUAL_CONFIG.visual_losses:
|
962 |
+
label, mask_conf = obj_labels[key]
|
963 |
+
output_dim, loss_fct_name, label_shape, weight = VISUAL_CONFIG.visual_loss_config[key]
|
964 |
+
visn_loss_fct = loss_fcts[loss_fct_name]
|
965 |
+
visn_prediction_scores = visn_prediction_scores_dict[key]
|
966 |
+
visn_loss = visn_loss_fct(
|
967 |
+
visn_prediction_scores.view(-1, output_dim),
|
968 |
+
label.view(*label_shape),
|
969 |
+
)
|
970 |
+
if visn_loss.dim() > 1: # Regression Losses
|
971 |
+
visn_loss = visn_loss.mean(1)
|
972 |
+
visn_loss = (visn_loss * mask_conf.view(-1)).mean() * weight
|
973 |
+
total_visn_loss += visn_loss
|
974 |
+
losses += (visn_loss.detach(),)
|
975 |
+
total_loss += total_visn_loss
|
976 |
+
if ans is not None and self.task_qa:
|
977 |
+
answer_loss = loss_fct(
|
978 |
+
answer_score.view(-1, self.num_answers),
|
979 |
+
ans.view(-1)
|
980 |
+
)
|
981 |
+
# Since this Github version pre-trains with QA loss from the beginning,
|
982 |
+
# I exclude "*2" here to match the effect of QA losses.
|
983 |
+
# Previous: (loss *0) for 6 epochs, (loss *2) for 6 epochs. (Used 10 instead of 6 in EMNLP paper)
|
984 |
+
# Now : (loss *1) for 12 epochs
|
985 |
+
#
|
986 |
+
# * 2 # Multiply by 2 because > half of the data will not have label
|
987 |
+
total_loss += answer_loss
|
988 |
+
losses += (answer_loss.detach(),)
|
989 |
+
return total_loss, torch.stack(losses).unsqueeze(0), answer_score.detach()
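The pre-training forward simply adds each active task loss into total_loss and keeps a detached copy of every component for logging. The same accumulation pattern in miniature (the task names and loss values below are placeholders):

```python
import torch

# Hypothetical per-task losses (stand-ins for the masked LM, matched, visual, and QA losses).
active = {"mask_lm": torch.tensor(2.31), "matched": torch.tensor(0.69), "qa": torch.tensor(1.10)}

total_loss = 0.
losses = ()
for name, loss in active.items():
    total_loss = total_loss + loss          # contributes to the backward pass
    losses += (loss.detach(),)              # detached copy, kept only for logging

print(total_loss, torch.stack(losses).unsqueeze(0))   # scalar total + [1, num_tasks] log tensor
```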
|
990 |
+
|
991 |
+
|
992 |
+
class LXRTFeatureExtraction(BertPreTrainedModel):
|
993 |
+
"""
|
994 |
+
BERT model for classification.
|
995 |
+
"""
|
996 |
+
def __init__(self, config, mode='lxr'):
|
997 |
+
"""
|
998 |
+
|
999 |
+
:param config:
|
1000 |
+
:param mode: Output mode; a string over 'l', 'x', 'r' selecting which outputs forward() returns (e.g. the default 'lxr')
|
1001 |
+
"""
|
1002 |
+
super().__init__(config)
|
1003 |
+
self.bert = LXRTModel(config)
|
1004 |
+
self.mode = mode
|
1005 |
+
self.apply(self.init_bert_weights)
|
1006 |
+
|
1007 |
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, visual_feats=None,
|
1008 |
+
visual_attention_mask=None):
|
1009 |
+
feat_seq, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
|
1010 |
+
visual_feats=visual_feats,
|
1011 |
+
visual_attention_mask=visual_attention_mask)
|
1012 |
+
if 'x' == self.mode:
|
1013 |
+
return pooled_output
|
1014 |
+
elif 'x' in self.mode and ('l' in self.mode or 'r' in self.mode):
|
1015 |
+
return feat_seq, pooled_output
|
1016 |
+
elif 'l' in self.mode or 'r' in self.mode:
|
1017 |
+
return feat_seq
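The mode string controls what forward returns: 'x' alone yields only the pooled output, 'l'/'r' without 'x' yields only the feature sequences, and a mix such as the default 'lxr' yields both. A throwaway reproduction of that branching for quick sanity checks:

```python
def select_output(mode, feat_seq="FEATS", pooled_output="POOLED"):
    # Mirrors the if/elif chain in LXRTFeatureExtraction.forward.
    if 'x' == mode:
        return pooled_output
    elif 'x' in mode and ('l' in mode or 'r' in mode):
        return feat_seq, pooled_output
    elif 'l' in mode or 'r' in mode:
        return feat_seq

print(select_output('x'))      # POOLED
print(select_output('lxr'))    # ('FEATS', 'POOLED')
print(select_output('lr'))     # FEATS
```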
|
1018 |
+
|
lxmert/src/lxrt/optimization.py
ADDED
@@ -0,0 +1,180 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2019 project LXRT
|
3 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch optimization for BERT model."""
|
17 |
+
|
18 |
+
import math
|
19 |
+
import torch
|
20 |
+
from torch.optim import Optimizer
|
21 |
+
from torch.optim.optimizer import required
|
22 |
+
import logging
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
def warmup_cosine(x, warmup=0.002):
|
27 |
+
if x < warmup:
|
28 |
+
return x/warmup
|
29 |
+
return 0.5 * (1.0 + math.cos(math.pi * x))  # x is a plain Python float, so use math.cos rather than torch.cos
|
30 |
+
|
31 |
+
def warmup_constant(x, warmup=0.002):
|
32 |
+
""" Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
|
33 |
+
Learning rate is 1. afterwards. """
|
34 |
+
if x < warmup:
|
35 |
+
return x/warmup
|
36 |
+
return 1.0
|
37 |
+
|
38 |
+
def warmup_linear(x, warmup=0.002):
|
39 |
+
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
|
40 |
+
After `t_total`-th training step, learning rate is zero. """
|
41 |
+
if x < warmup:
|
42 |
+
return x/warmup
|
43 |
+
return max((x-1.)/(warmup-1.), 0)
|
44 |
+
|
45 |
+
SCHEDULES = {
|
46 |
+
'warmup_cosine': warmup_cosine,
|
47 |
+
'warmup_constant': warmup_constant,
|
48 |
+
'warmup_linear': warmup_linear,
|
49 |
+
}
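Each schedule maps training progress x = step / t_total to a learning-rate multiplier: ramp up linearly during warmup, then stay constant, decay linearly, or decay with a cosine. A standalone copy with a few sample values (math.cos is used because x is a plain Python float; warmup=0.1 in the printout is only for illustration):

```python
import math

def warmup_cosine(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 0.5 * (1.0 + math.cos(math.pi * x))

def warmup_constant(x, warmup=0.002):
    return x / warmup if x < warmup else 1.0

def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return max((x - 1.) / (warmup - 1.), 0)

for x in (0.0, 0.05, 0.5, 1.0):
    print(x, round(warmup_linear(x, warmup=0.1), 3))   # 0.0, 0.5, ~0.556, 0.0: triangular ramp then decay
```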
|
50 |
+
|
51 |
+
|
52 |
+
class BertAdam(Optimizer):
|
53 |
+
"""Implements BERT version of Adam algorithm with weight decay fix.
|
54 |
+
Params:
|
55 |
+
lr: learning rate
|
56 |
+
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
|
57 |
+
t_total: total number of training steps for the learning
|
58 |
+
rate schedule, -1 means constant learning rate. Default: -1
|
59 |
+
schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
|
60 |
+
b1: Adams b1. Default: 0.9
|
61 |
+
b2: Adams b2. Default: 0.999
|
62 |
+
e: Adams epsilon. Default: 1e-6
|
63 |
+
weight_decay: Weight decay. Default: 0.01
|
64 |
+
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
|
65 |
+
"""
|
66 |
+
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
|
67 |
+
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
|
68 |
+
max_grad_norm=1.0):
|
69 |
+
if lr is not required and lr < 0.0:
|
70 |
+
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
71 |
+
if schedule not in SCHEDULES:
|
72 |
+
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
73 |
+
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
74 |
+
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
75 |
+
if not 0.0 <= b1 < 1.0:
|
76 |
+
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
77 |
+
if not 0.0 <= b2 < 1.0:
|
78 |
+
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
|
79 |
+
if not e >= 0.0:
|
80 |
+
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
81 |
+
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
|
82 |
+
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
|
83 |
+
max_grad_norm=max_grad_norm)
|
84 |
+
super(BertAdam, self).__init__(params, defaults)
|
85 |
+
|
86 |
+
def get_lr(self):
|
87 |
+
lr = []
|
88 |
+
for group in self.param_groups:
|
89 |
+
for p in group['params']:
|
90 |
+
state = self.state[p]
|
91 |
+
if len(state) == 0:
|
92 |
+
return [0]
|
93 |
+
if group['t_total'] != -1:
|
94 |
+
schedule_fct = SCHEDULES[group['schedule']]
|
95 |
+
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
|
96 |
+
else:
|
97 |
+
lr_scheduled = group['lr']
|
98 |
+
lr.append(lr_scheduled)
|
99 |
+
return lr
|
100 |
+
|
101 |
+
def step(self, closure=None):
|
102 |
+
"""Performs a single optimization step.
|
103 |
+
|
104 |
+
Arguments:
|
105 |
+
closure (callable, optional): A closure that reevaluates the model
|
106 |
+
and returns the loss.
|
107 |
+
"""
|
108 |
+
loss = None
|
109 |
+
if closure is not None:
|
110 |
+
loss = closure()
|
111 |
+
|
112 |
+
warned_for_t_total = False
|
113 |
+
|
114 |
+
for group in self.param_groups:
|
115 |
+
for p in group['params']:
|
116 |
+
if p.grad is None:
|
117 |
+
continue
|
118 |
+
grad = p.grad.data
|
119 |
+
if grad.is_sparse:
|
120 |
+
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
121 |
+
|
122 |
+
state = self.state[p]
|
123 |
+
|
124 |
+
# State initialization
|
125 |
+
if len(state) == 0:
|
126 |
+
state['step'] = 0
|
127 |
+
# Exponential moving average of gradient values
|
128 |
+
state['next_m'] = torch.zeros_like(p.data)
|
129 |
+
# Exponential moving average of squared gradient values
|
130 |
+
state['next_v'] = torch.zeros_like(p.data)
|
131 |
+
|
132 |
+
next_m, next_v = state['next_m'], state['next_v']
|
133 |
+
beta1, beta2 = group['b1'], group['b2']
|
134 |
+
|
135 |
+
# LXRT: grad is clipped outside.
|
136 |
+
# Add grad clipping
|
137 |
+
# if group['max_grad_norm'] > 0:
|
138 |
+
# clip_grad_norm_(p, group['max_grad_norm'])
|
139 |
+
|
140 |
+
# Decay the first and second moment running average coefficient
|
141 |
+
# In-place operations to update the averages at the same time
|
142 |
+
next_m.mul_(beta1).add_(1 - beta1, grad)
|
143 |
+
next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
144 |
+
update = next_m / (next_v.sqrt() + group['e'])
|
145 |
+
|
146 |
+
# Just adding the square of the weights to the loss function is *not*
|
147 |
+
# the correct way of using L2 regularization/weight decay with Adam,
|
148 |
+
# since that will interact with the m and v parameters in strange ways.
|
149 |
+
#
|
150 |
+
# Instead we want to decay the weights in a manner that doesn't interact
|
151 |
+
# with the m/v parameters. This is equivalent to adding the square
|
152 |
+
# of the weights to the loss with plain (non-momentum) SGD.
|
153 |
+
if group['weight_decay'] > 0.0:
|
154 |
+
update += group['weight_decay'] * p.data
|
155 |
+
|
156 |
+
if group['t_total'] != -1:
|
157 |
+
schedule_fct = SCHEDULES[group['schedule']]
|
158 |
+
progress = state['step']/group['t_total']
|
159 |
+
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
|
160 |
+
# warning for exceeding t_total (only active with warmup_linear)
|
161 |
+
if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
|
162 |
+
logger.warning(
|
163 |
+
"Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
|
164 |
+
"Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
|
165 |
+
warned_for_t_total = True
|
166 |
+
# end warning
|
167 |
+
else:
|
168 |
+
lr_scheduled = group['lr']
|
169 |
+
|
170 |
+
update_with_lr = lr_scheduled * update
|
171 |
+
p.data.add_(-update_with_lr)
|
172 |
+
|
173 |
+
state['step'] += 1
|
174 |
+
|
175 |
+
# step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
|
176 |
+
# No bias correction
|
177 |
+
# bias_correction1 = 1 - beta1 ** state['step']
|
178 |
+
# bias_correction2 = 1 - beta2 ** state['step']
|
179 |
+
|
180 |
+
return loss
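BertAdam applies weight decay directly to the parameter update (decoupled from the Adam moments) and skips bias correction. The core of that update rule, stripped of the schedule, clipping, and state bookkeeping (all tensor values below are random placeholders):

```python
import torch

p = torch.randn(4)          # a parameter
grad = torch.randn(4)       # its gradient
m = torch.zeros(4)          # first-moment EMA  (state['next_m'])
v = torch.zeros(4)          # second-moment EMA (state['next_v'])
beta1, beta2, eps, wd, lr = 0.9, 0.999, 1e-6, 0.01, 1e-4

m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * grad * grad
update = m / (v.sqrt() + eps)
update = update + wd * p    # decay the weights themselves, decoupled from m/v
p = p - lr * update         # note: no bias correction, as in BertAdam
print(p)
```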
|
lxmert/src/lxrt/tokenization.py
ADDED
@@ -0,0 +1,388 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Tokenization classes."""
|
16 |
+
|
17 |
+
import collections
|
18 |
+
import logging
|
19 |
+
import os
|
20 |
+
import unicodedata
|
21 |
+
from io import open
|
22 |
+
|
23 |
+
from .file_utils import cached_path
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
28 |
+
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
|
29 |
+
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
|
30 |
+
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
|
31 |
+
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
|
32 |
+
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
|
33 |
+
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
|
34 |
+
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
|
35 |
+
}
|
36 |
+
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
37 |
+
'bert-base-uncased': 512,
|
38 |
+
'bert-large-uncased': 512,
|
39 |
+
'bert-base-cased': 512,
|
40 |
+
'bert-large-cased': 512,
|
41 |
+
'bert-base-multilingual-uncased': 512,
|
42 |
+
'bert-base-multilingual-cased': 512,
|
43 |
+
'bert-base-chinese': 512,
|
44 |
+
}
|
45 |
+
VOCAB_NAME = 'vocab.txt'
|
46 |
+
|
47 |
+
|
48 |
+
def load_vocab(vocab_file):
|
49 |
+
"""Loads a vocabulary file into a dictionary."""
|
50 |
+
vocab = collections.OrderedDict()
|
51 |
+
index = 0
|
52 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
53 |
+
while True:
|
54 |
+
token = reader.readline()
|
55 |
+
if not token:
|
56 |
+
break
|
57 |
+
token = token.strip()
|
58 |
+
vocab[token] = index
|
59 |
+
index += 1
|
60 |
+
return vocab
|
61 |
+
|
62 |
+
|
63 |
+
def whitespace_tokenize(text):
|
64 |
+
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
65 |
+
text = text.strip()
|
66 |
+
if not text:
|
67 |
+
return []
|
68 |
+
tokens = text.split()
|
69 |
+
return tokens
|
70 |
+
|
71 |
+
|
72 |
+
class BertTokenizer(object):
|
73 |
+
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
|
74 |
+
|
75 |
+
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
|
76 |
+
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
|
77 |
+
"""Constructs a BertTokenizer.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
vocab_file: Path to a one-wordpiece-per-line vocabulary file
|
81 |
+
do_lower_case: Whether to lower case the input
|
82 |
+
Only has an effect when do_basic_tokenize=True
|
83 |
+
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
|
84 |
+
max_len: An artificial maximum length to truncate tokenized sequences to;
|
85 |
+
Effective maximum length is always the minimum of this
|
86 |
+
value (if specified) and the underlying BERT model's
|
87 |
+
sequence length.
|
88 |
+
never_split: List of tokens which will never be split during tokenization.
|
89 |
+
Only has an effect when do_basic_tokenize=True
|
90 |
+
"""
|
91 |
+
if not os.path.isfile(vocab_file):
|
92 |
+
raise ValueError(
|
93 |
+
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
94 |
+
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
|
95 |
+
self.vocab = load_vocab(vocab_file)
|
96 |
+
self.ids_to_tokens = collections.OrderedDict(
|
97 |
+
[(ids, tok) for tok, ids in self.vocab.items()])
|
98 |
+
self.do_basic_tokenize = do_basic_tokenize
|
99 |
+
if do_basic_tokenize:
|
100 |
+
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
|
101 |
+
never_split=never_split)
|
102 |
+
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
103 |
+
self.max_len = max_len if max_len is not None else int(1e12)
|
104 |
+
|
105 |
+
def tokenize(self, text):
|
106 |
+
if self.do_basic_tokenize:
|
107 |
+
split_tokens = []
|
108 |
+
for token in self.basic_tokenizer.tokenize(text):
|
109 |
+
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
110 |
+
split_tokens.append(sub_token)
|
111 |
+
else:
|
112 |
+
split_tokens = self.wordpiece_tokenizer.tokenize(text)
|
113 |
+
return split_tokens
|
114 |
+
|
115 |
+
def convert_tokens_to_ids(self, tokens):
|
116 |
+
"""Converts a sequence of tokens into ids using the vocab."""
|
117 |
+
ids = []
|
118 |
+
for token in tokens:
|
119 |
+
ids.append(self.vocab[token])
|
120 |
+
if len(ids) > self.max_len:
|
121 |
+
logger.warning(
|
122 |
+
"Token indices sequence length is longer than the specified maximum "
|
123 |
+
" sequence length for this BERT model ({} > {}). Running this"
|
124 |
+
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
|
125 |
+
)
|
126 |
+
return ids
|
127 |
+
|
128 |
+
def convert_ids_to_tokens(self, ids):
|
129 |
+
"""Converts a sequence of ids in wordpiece tokens using the vocab."""
|
130 |
+
tokens = []
|
131 |
+
for i in ids:
|
132 |
+
tokens.append(self.ids_to_tokens[i])
|
133 |
+
return tokens
|
134 |
+
|
135 |
+
@classmethod
|
136 |
+
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
137 |
+
"""
|
138 |
+
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
139 |
+
Download and cache the pre-trained model file if needed.
|
140 |
+
"""
|
141 |
+
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
142 |
+
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
143 |
+
else:
|
144 |
+
vocab_file = pretrained_model_name_or_path
|
145 |
+
if os.path.isdir(vocab_file):
|
146 |
+
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
|
147 |
+
# redirect to the cache, if necessary
|
148 |
+
try:
|
149 |
+
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
150 |
+
except EnvironmentError:
|
151 |
+
logger.error(
|
152 |
+
"Model name '{}' was not found in model name list ({}). "
|
153 |
+
"We assumed '{}' was a path or url but couldn't find any file "
|
154 |
+
"associated to this path or url.".format(
|
155 |
+
pretrained_model_name_or_path,
|
156 |
+
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
157 |
+
vocab_file))
|
158 |
+
return None
|
159 |
+
if resolved_vocab_file == vocab_file:
|
160 |
+
logger.info("loading vocabulary file {}".format(vocab_file))
|
161 |
+
else:
|
162 |
+
logger.info("loading vocabulary file {} from cache at {}".format(
|
163 |
+
vocab_file, resolved_vocab_file))
|
164 |
+
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
165 |
+
# If we're using a pretrained model, ensure the tokenizer won't index sequences longer
|
166 |
+
# than the number of positional embeddings
|
167 |
+
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
|
168 |
+
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
169 |
+
# Instantiate tokenizer.
|
170 |
+
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
|
171 |
+
return tokenizer
|
172 |
+
|
173 |
+
|
174 |
+
class BasicTokenizer(object):
|
175 |
+
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
176 |
+
|
177 |
+
def __init__(self,
|
178 |
+
do_lower_case=True,
|
179 |
+
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
|
180 |
+
"""Constructs a BasicTokenizer.
|
181 |
+
|
182 |
+
Args:
|
183 |
+
do_lower_case: Whether to lower case the input.
|
184 |
+
"""
|
185 |
+
self.do_lower_case = do_lower_case
|
186 |
+
self.never_split = never_split
|
187 |
+
|
188 |
+
def tokenize(self, text):
|
189 |
+
"""Tokenizes a piece of text."""
|
190 |
+
text = self._clean_text(text)
|
191 |
+
# This was added on November 1st, 2018 for the multilingual and Chinese
|
192 |
+
# models. This is also applied to the English models now, but it doesn't
|
193 |
+
# matter since the English models were not trained on any Chinese data
|
194 |
+
# and generally don't have any Chinese data in them (there are Chinese
|
195 |
+
# characters in the vocabulary because Wikipedia does have some Chinese
|
196 |
+
# words in the English Wikipedia.).
|
197 |
+
text = self._tokenize_chinese_chars(text)
|
198 |
+
orig_tokens = whitespace_tokenize(text)
|
199 |
+
split_tokens = []
|
200 |
+
for token in orig_tokens:
|
201 |
+
if self.do_lower_case and token not in self.never_split:
|
202 |
+
token = token.lower()
|
203 |
+
token = self._run_strip_accents(token)
|
204 |
+
split_tokens.extend(self._run_split_on_punc(token))
|
205 |
+
|
206 |
+
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
207 |
+
return output_tokens
|
208 |
+
|
209 |
+
def _run_strip_accents(self, text):
|
210 |
+
"""Strips accents from a piece of text."""
|
211 |
+
text = unicodedata.normalize("NFD", text)
|
212 |
+
output = []
|
213 |
+
for char in text:
|
214 |
+
cat = unicodedata.category(char)
|
215 |
+
if cat == "Mn":
|
216 |
+
continue
|
217 |
+
output.append(char)
|
218 |
+
return "".join(output)
|
219 |
+
|
220 |
+
def _run_split_on_punc(self, text):
|
221 |
+
"""Splits punctuation on a piece of text."""
|
222 |
+
if text in self.never_split:
|
223 |
+
return [text]
|
224 |
+
chars = list(text)
|
225 |
+
i = 0
|
226 |
+
start_new_word = True
|
227 |
+
output = []
|
228 |
+
while i < len(chars):
|
229 |
+
char = chars[i]
|
230 |
+
if _is_punctuation(char):
|
231 |
+
output.append([char])
|
232 |
+
start_new_word = True
|
233 |
+
else:
|
234 |
+
if start_new_word:
|
235 |
+
output.append([])
|
236 |
+
start_new_word = False
|
237 |
+
output[-1].append(char)
|
238 |
+
i += 1
|
239 |
+
|
240 |
+
return ["".join(x) for x in output]
|
241 |
+
|
242 |
+
def _tokenize_chinese_chars(self, text):
|
243 |
+
"""Adds whitespace around any CJK character."""
|
244 |
+
output = []
|
245 |
+
for char in text:
|
246 |
+
cp = ord(char)
|
247 |
+
if self._is_chinese_char(cp):
|
248 |
+
output.append(" ")
|
249 |
+
output.append(char)
|
250 |
+
output.append(" ")
|
251 |
+
else:
|
252 |
+
output.append(char)
|
253 |
+
return "".join(output)
|
254 |
+
|
255 |
+
def _is_chinese_char(self, cp):
|
256 |
+
"""Checks whether CP is the codepoint of a CJK character."""
|
257 |
+
# This defines a "chinese character" as anything in the CJK Unicode block:
|
258 |
+
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
259 |
+
#
|
260 |
+
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
261 |
+
# despite its name. The modern Korean Hangul alphabet is a different block,
|
262 |
+
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
263 |
+
# space-separated words, so they are not treated specially and handled
|
264 |
+
# like all of the other languages.
|
265 |
+
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
266 |
+
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
267 |
+
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
268 |
+
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
269 |
+
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
270 |
+
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
271 |
+
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
272 |
+
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
273 |
+
return True
|
274 |
+
|
275 |
+
return False
|
276 |
+
|
277 |
+
def _clean_text(self, text):
|
278 |
+
"""Performs invalid character removal and whitespace cleanup on text."""
|
279 |
+
output = []
|
280 |
+
for char in text:
|
281 |
+
cp = ord(char)
|
282 |
+
if cp == 0 or cp == 0xfffd or _is_control(char):
|
283 |
+
continue
|
284 |
+
if _is_whitespace(char):
|
285 |
+
output.append(" ")
|
286 |
+
else:
|
287 |
+
output.append(char)
|
288 |
+
return "".join(output)
|
289 |
+
|
290 |
+
|
291 |
+
class WordpieceTokenizer(object):
|
292 |
+
"""Runs WordPiece tokenization."""
|
293 |
+
|
294 |
+
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
|
295 |
+
self.vocab = vocab
|
296 |
+
self.unk_token = unk_token
|
297 |
+
self.max_input_chars_per_word = max_input_chars_per_word
|
298 |
+
|
299 |
+
def tokenize(self, text):
|
300 |
+
"""Tokenizes a piece of text into its word pieces.
|
301 |
+
|
302 |
+
This uses a greedy longest-match-first algorithm to perform tokenization
|
303 |
+
using the given vocabulary.
|
304 |
+
|
305 |
+
For example:
|
306 |
+
input = "unaffable"
|
307 |
+
output = ["un", "##aff", "##able"]
|
308 |
+
|
309 |
+
Args:
|
310 |
+
text: A single token or whitespace separated tokens. This should have
|
311 |
+
already been passed through `BasicTokenizer`.
|
312 |
+
|
313 |
+
Returns:
|
314 |
+
A list of wordpiece tokens.
|
315 |
+
"""
|
316 |
+
|
317 |
+
output_tokens = []
|
318 |
+
for token in whitespace_tokenize(text):
|
319 |
+
chars = list(token)
|
320 |
+
if len(chars) > self.max_input_chars_per_word:
|
321 |
+
output_tokens.append(self.unk_token)
|
322 |
+
continue
|
323 |
+
|
324 |
+
is_bad = False
|
325 |
+
start = 0
|
326 |
+
sub_tokens = []
|
327 |
+
while start < len(chars):
|
328 |
+
end = len(chars)
|
329 |
+
cur_substr = None
|
330 |
+
while start < end:
|
331 |
+
substr = "".join(chars[start:end])
|
332 |
+
if start > 0:
|
333 |
+
substr = "##" + substr
|
334 |
+
if substr in self.vocab:
|
335 |
+
cur_substr = substr
|
336 |
+
break
|
337 |
+
end -= 1
|
338 |
+
if cur_substr is None:
|
339 |
+
is_bad = True
|
340 |
+
break
|
341 |
+
sub_tokens.append(cur_substr)
|
342 |
+
start = end
|
343 |
+
|
344 |
+
if is_bad:
|
345 |
+
output_tokens.append(self.unk_token)
|
346 |
+
else:
|
347 |
+
output_tokens.extend(sub_tokens)
|
348 |
+
return output_tokens
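To make greedy longest-match-first concrete, here is the same inner loop run against a three-entry toy vocabulary; the "unaffable" example comes from the docstring above, while the vocabulary itself is made up for illustration:

```python
def wordpiece(token, vocab, unk="[UNK]"):
    # Greedy longest-match-first, as in WordpieceTokenizer.tokenize.
    chars, start, pieces = list(token), 0, []
    while start < len(chars):
        end, cur = len(chars), None
        while start < end:
            sub = "".join(chars[start:end])
            if start > 0:
                sub = "##" + sub          # continuation pieces carry the "##" prefix
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:
            return [unk]                  # no piece matched: the whole token becomes [UNK]
        pieces.append(cur)
        start = end
    return pieces

vocab = {"un", "##aff", "##able"}         # toy vocabulary
print(wordpiece("unaffable", vocab))      # ['un', '##aff', '##able']
```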
|
349 |
+
|
350 |
+
|
351 |
+
def _is_whitespace(char):
|
352 |
+
"""Checks whether `chars` is a whitespace character."""
|
353 |
+
# \t, \n, and \r are technically control characters but we treat them
|
354 |
+
# as whitespace since they are generally considered as such.
|
355 |
+
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
356 |
+
return True
|
357 |
+
cat = unicodedata.category(char)
|
358 |
+
if cat == "Zs":
|
359 |
+
return True
|
360 |
+
return False
|
361 |
+
|
362 |
+
|
363 |
+
def _is_control(char):
|
364 |
+
"""Checks whether `chars` is a control character."""
|
365 |
+
# These are technically control characters but we count them as whitespace
|
366 |
+
# characters.
|
367 |
+
if char == "\t" or char == "\n" or char == "\r":
|
368 |
+
return False
|
369 |
+
cat = unicodedata.category(char)
|
370 |
+
if cat.startswith("C"):
|
371 |
+
return True
|
372 |
+
return False
|
373 |
+
|
374 |
+
|
375 |
+
def _is_punctuation(char):
|
376 |
+
"""Checks whether `chars` is a punctuation character."""
|
377 |
+
cp = ord(char)
|
378 |
+
# We treat all non-letter/number ASCII as punctuation.
|
379 |
+
# Characters such as "^", "$", and "`" are not in the Unicode
|
380 |
+
# Punctuation class but we treat them as punctuation anyways, for
|
381 |
+
# consistency.
|
382 |
+
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
383 |
+
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
384 |
+
return True
|
385 |
+
cat = unicodedata.category(char)
|
386 |
+
if cat.startswith("P"):
|
387 |
+
return True
|
388 |
+
return False
|
lxmert/src/modeling_frcnn.py
ADDED
@@ -0,0 +1,1922 @@
1 |
+
"""
|
2 |
+
coding=utf-8
|
3 |
+
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
|
4 |
+
Adapted From Facebook Inc, Detectron2 && Huggingface Co.
|
5 |
+
|
6 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
you may not use this file except in compliance with the License.
|
8 |
+
You may obtain a copy of the License at
|
9 |
+
|
10 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
|
12 |
+
Unless required by applicable law or agreed to in writing, software
|
13 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
See the License for the specific language governing permissions and
|
16 |
+
limitations under the License.
|
17 |
+
"""
|
18 |
+
import itertools
|
19 |
+
import math
|
20 |
+
import os
|
21 |
+
from abc import ABCMeta, abstractmethod
|
22 |
+
from collections import OrderedDict, namedtuple
|
23 |
+
from typing import Dict, List, Tuple
|
24 |
+
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
from torch import nn
|
28 |
+
from torch.nn import functional as F
|
29 |
+
from torch.nn.modules.batchnorm import BatchNorm2d
|
30 |
+
from torchvision.ops import RoIPool
|
31 |
+
from torchvision.ops.boxes import batched_nms, nms
|
32 |
+
|
33 |
+
from lxmert.lxmert.src.vqa_utils import WEIGHTS_NAME, Config, cached_path, hf_bucket_url, is_remote_url, load_checkpoint
|
34 |
+
|
35 |
+
|
36 |
+
# other:
|
37 |
+
def norm_box(boxes, raw_sizes):
|
38 |
+
if not isinstance(boxes, torch.Tensor):
|
39 |
+
normalized_boxes = boxes.copy()
|
40 |
+
else:
|
41 |
+
normalized_boxes = boxes.clone()
|
42 |
+
normalized_boxes[:, :, (0, 2)] /= raw_sizes[:, 1]
|
43 |
+
normalized_boxes[:, :, (1, 3)] /= raw_sizes[:, 0]
|
44 |
+
return normalized_boxes
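A minimal usage sketch of norm_box (illustrative, not part of the uploaded file), assuming boxes in (x1, y1, x2, y2) order and raw_sizes given as (height, width), which is how the indices above are used:

import torch
boxes = torch.tensor([[[10.0, 20.0, 110.0, 220.0],
                       [ 0.0,  0.0, 640.0, 480.0]]])   # (N=1, num_boxes=2, 4)
raw_sizes = torch.tensor([[480.0, 640.0]])             # (N=1, 2) as (height, width)
normed = norm_box(boxes, raw_sizes)
# x coordinates are divided by the width (640) and y coordinates by the height (480),
# so the second box becomes (0, 0, 1, 1).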
|
45 |
+
|
46 |
+
|
47 |
+
def pad_list_tensors(
|
48 |
+
list_tensors,
|
49 |
+
preds_per_image,
|
50 |
+
max_detections=None,
|
51 |
+
return_tensors=None,
|
52 |
+
padding=None,
|
53 |
+
pad_value=0,
|
54 |
+
location=None,
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
location will always be cpu for np tensors
|
58 |
+
"""
|
59 |
+
if location is None:
|
60 |
+
location = "cpu"
|
61 |
+
assert return_tensors in {"pt", "np", None}
|
62 |
+
assert padding in {"max_detections", "max_batch", None}
|
63 |
+
new = []
|
64 |
+
if padding is None:
|
65 |
+
if return_tensors is None:
|
66 |
+
return list_tensors
|
67 |
+
elif return_tensors == "pt":
|
68 |
+
if not isinstance(list_tensors, torch.Tensor):
|
69 |
+
return torch.stack(list_tensors).to(location)
|
70 |
+
else:
|
71 |
+
return list_tensors.to(location)
|
72 |
+
else:
|
73 |
+
if not isinstance(list_tensors, list):
|
74 |
+
return np.array(list_tensors.to(location))
|
75 |
+
else:
|
76 |
+
return list_tensors.to(location)
|
77 |
+
if padding == "max_detections":
|
78 |
+
assert max_detections is not None, "specify max number of detections per batch"
|
79 |
+
elif padding == "max_batch":
|
80 |
+
max_detections = max(preds_per_image)
|
81 |
+
for i in range(len(list_tensors)):
|
82 |
+
too_small = False
|
83 |
+
tensor_i = list_tensors.pop(0)
|
84 |
+
if tensor_i.ndim < 2:
|
85 |
+
too_small = True
|
86 |
+
tensor_i = tensor_i.unsqueeze(-1)
|
87 |
+
assert isinstance(tensor_i, torch.Tensor)
|
88 |
+
tensor_i = F.pad(
|
89 |
+
input=tensor_i,
|
90 |
+
pad=(0, 0, 0, max_detections - preds_per_image[i]),
|
91 |
+
mode="constant",
|
92 |
+
value=pad_value,
|
93 |
+
)
|
94 |
+
if too_small:
|
95 |
+
tensor_i = tensor_i.squeeze(-1)
|
96 |
+
if return_tensors is None:
|
97 |
+
if location == "cpu":
|
98 |
+
tensor_i = tensor_i.cpu()
|
99 |
+
tensor_i = tensor_i.tolist()
|
100 |
+
if return_tensors == "np":
|
101 |
+
if location == "cpu":
|
102 |
+
tensor_i = tensor_i.cpu()
|
103 |
+
tensor_i = tensor_i.numpy()
|
104 |
+
else:
|
105 |
+
if location == "cpu":
|
106 |
+
tensor_i = tensor_i.cpu()
|
107 |
+
new.append(tensor_i)
|
108 |
+
if return_tensors == "np":
|
109 |
+
return np.stack(new, axis=0)
|
110 |
+
elif return_tensors == "pt" and not isinstance(new, torch.Tensor):
|
111 |
+
return torch.stack(new, dim=0)
|
112 |
+
else:
|
113 |
+
return list_tensors
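An illustrative call of pad_list_tensors with the "max_batch" padding mode (a sketch, not part of the uploaded file); note that the helper consumes the input list in place via pop(0):

import torch
feats = [torch.ones(3, 5), torch.ones(1, 5)]       # per-image features: 3 and 1 detections
padded = pad_list_tensors(feats, preds_per_image=[3, 1],
                          padding="max_batch", return_tensors="pt", pad_value=0)
# padded.shape == torch.Size([2, 3, 5]); the second image is zero-padded from 1 to 3 rows.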
|
114 |
+
|
115 |
+
|
116 |
+
def do_nms(boxes, scores, image_shape, score_thresh, nms_thresh, mind, maxd):
|
117 |
+
scores = scores[:, :-1]
|
118 |
+
num_bbox_reg_classes = boxes.shape[1] // 4
|
119 |
+
# Convert to Boxes to use the `clip` function ...
|
120 |
+
boxes = boxes.reshape(-1, 4)
|
121 |
+
_clip_box(boxes, image_shape)
|
122 |
+
boxes = boxes.view(-1, num_bbox_reg_classes, 4) # R x C x 4
|
123 |
+
|
124 |
+
# Select max scores
|
125 |
+
max_scores, max_classes = scores.max(1) # R x C --> R
|
126 |
+
num_objs = boxes.size(0)
|
127 |
+
boxes = boxes.view(-1, 4)
|
128 |
+
idxs = torch.arange(num_objs).to(boxes.device) * num_bbox_reg_classes + max_classes
|
129 |
+
max_boxes = boxes[idxs] # Select max boxes according to the max scores.
|
130 |
+
|
131 |
+
# Apply NMS
|
132 |
+
keep = nms(max_boxes, max_scores, nms_thresh)
|
133 |
+
keep = keep[:maxd]
|
134 |
+
if keep.shape[-1] >= mind and keep.shape[-1] <= maxd:
|
135 |
+
max_boxes, max_scores = max_boxes[keep], max_scores[keep]
|
136 |
+
classes = max_classes[keep]
|
137 |
+
return max_boxes, max_scores, classes, keep
|
138 |
+
else:
|
139 |
+
return None
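A toy call of do_nms (illustrative sketch; the box and score values are made up). Each proposal carries one box per class plus a background column in scores; the function keeps the best class per proposal, clips the boxes, applies NMS, and returns None if the number of survivors falls outside [mind, maxd]:

import torch
R, C = 5, 3                                         # 5 proposals, 3 foreground classes
base = torch.tensor([10.0, 10.0, 60.0, 60.0])
boxes = base.repeat(R, C)                           # (R, C*4): one (x1, y1, x2, y2) box per class
boxes = boxes + torch.arange(R).view(-1, 1) * 20.0  # spread proposals apart so NMS keeps several
scores = torch.rand(R, C + 1)                       # last column is treated as background
out = do_nms(boxes, scores, image_shape=(480, 640),
             score_thresh=0.05, nms_thresh=0.5, mind=1, maxd=R)
if out is not None:
    max_boxes, max_scores, classes, keep = out      # best box, score, class id, kept indices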
|
140 |
+
|
141 |
+
|
142 |
+
# Helper Functions
|
143 |
+
def _clip_box(tensor, box_size: Tuple[int, int]):
|
144 |
+
assert torch.isfinite(tensor).all(), "Box tensor contains infinite or NaN!"
|
145 |
+
h, w = box_size
|
146 |
+
tensor[:, 0].clamp_(min=0, max=w)
|
147 |
+
tensor[:, 1].clamp_(min=0, max=h)
|
148 |
+
tensor[:, 2].clamp_(min=0, max=w)
|
149 |
+
tensor[:, 3].clamp_(min=0, max=h)
|
150 |
+
|
151 |
+
|
152 |
+
def _nonempty_boxes(box, threshold: float = 0.0) -> torch.Tensor:
|
153 |
+
widths = box[:, 2] - box[:, 0]
|
154 |
+
heights = box[:, 3] - box[:, 1]
|
155 |
+
keep = (widths > threshold) & (heights > threshold)
|
156 |
+
return keep
|
157 |
+
|
158 |
+
|
159 |
+
def get_norm(norm, out_channels):
|
160 |
+
if isinstance(norm, str):
|
161 |
+
if len(norm) == 0:
|
162 |
+
return None
|
163 |
+
norm = {
|
164 |
+
"BN": BatchNorm2d,
|
165 |
+
"GN": lambda channels: nn.GroupNorm(32, channels),
|
166 |
+
"nnSyncBN": nn.SyncBatchNorm, # keep for debugging
|
167 |
+
"": lambda x: x,
|
168 |
+
}[norm]
|
169 |
+
return norm(out_channels)
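For reference, the mapping implemented by get_norm (illustrative):

bn = get_norm("BN", 64)   # torch.nn.BatchNorm2d(64)
gn = get_norm("GN", 64)   # torch.nn.GroupNorm(32, 64)
none = get_norm("", 64)   # empty string -> None (no normalization layer)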
|
170 |
+
|
171 |
+
|
172 |
+
def _create_grid_offsets(size: List[int], stride: int, offset: float, device):
|
173 |
+
|
174 |
+
grid_height, grid_width = size
|
175 |
+
shifts_x = torch.arange(
|
176 |
+
offset * stride,
|
177 |
+
grid_width * stride,
|
178 |
+
step=stride,
|
179 |
+
dtype=torch.float32,
|
180 |
+
device=device,
|
181 |
+
)
|
182 |
+
shifts_y = torch.arange(
|
183 |
+
offset * stride,
|
184 |
+
grid_height * stride,
|
185 |
+
step=stride,
|
186 |
+
dtype=torch.float32,
|
187 |
+
device=device,
|
188 |
+
)
|
189 |
+
|
190 |
+
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
|
191 |
+
shift_x = shift_x.reshape(-1)
|
192 |
+
shift_y = shift_y.reshape(-1)
|
193 |
+
return shift_x, shift_y
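A small worked example of _create_grid_offsets (sketch, not part of the uploaded file): for a 2 x 3 feature map with stride 16 and offset 0.5, the anchor centers land in the middle of each stride-16 cell:

shift_x, shift_y = _create_grid_offsets(size=[2, 3], stride=16, offset=0.5, device="cpu")
# shift_x -> tensor([ 8., 24., 40.,  8., 24., 40.])
# shift_y -> tensor([ 8.,  8.,  8., 24., 24., 24.])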
|
194 |
+
|
195 |
+
|
196 |
+
def build_backbone(cfg):
|
197 |
+
input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))
|
198 |
+
norm = cfg.RESNETS.NORM
|
199 |
+
stem = BasicStem(
|
200 |
+
in_channels=input_shape.channels,
|
201 |
+
out_channels=cfg.RESNETS.STEM_OUT_CHANNELS,
|
202 |
+
norm=norm,
|
203 |
+
caffe_maxpool=cfg.MODEL.MAX_POOL,
|
204 |
+
)
|
205 |
+
freeze_at = cfg.BACKBONE.FREEZE_AT
|
206 |
+
|
207 |
+
if freeze_at >= 1:
|
208 |
+
for p in stem.parameters():
|
209 |
+
p.requires_grad = False
|
210 |
+
|
211 |
+
out_features = cfg.RESNETS.OUT_FEATURES
|
212 |
+
depth = cfg.RESNETS.DEPTH
|
213 |
+
num_groups = cfg.RESNETS.NUM_GROUPS
|
214 |
+
width_per_group = cfg.RESNETS.WIDTH_PER_GROUP
|
215 |
+
bottleneck_channels = num_groups * width_per_group
|
216 |
+
in_channels = cfg.RESNETS.STEM_OUT_CHANNELS
|
217 |
+
out_channels = cfg.RESNETS.RES2_OUT_CHANNELS
|
218 |
+
stride_in_1x1 = cfg.RESNETS.STRIDE_IN_1X1
|
219 |
+
res5_dilation = cfg.RESNETS.RES5_DILATION
|
220 |
+
assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
|
221 |
+
|
222 |
+
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
|
223 |
+
|
224 |
+
stages = []
|
225 |
+
out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features]
|
226 |
+
max_stage_idx = max(out_stage_idx)
|
227 |
+
for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
|
228 |
+
dilation = res5_dilation if stage_idx == 5 else 1
|
229 |
+
first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
|
230 |
+
stage_kargs = {
|
231 |
+
"num_blocks": num_blocks_per_stage[idx],
|
232 |
+
"first_stride": first_stride,
|
233 |
+
"in_channels": in_channels,
|
234 |
+
"bottleneck_channels": bottleneck_channels,
|
235 |
+
"out_channels": out_channels,
|
236 |
+
"num_groups": num_groups,
|
237 |
+
"norm": norm,
|
238 |
+
"stride_in_1x1": stride_in_1x1,
|
239 |
+
"dilation": dilation,
|
240 |
+
}
|
241 |
+
|
242 |
+
stage_kargs["block_class"] = BottleneckBlock
|
243 |
+
blocks = ResNet.make_stage(**stage_kargs)
|
244 |
+
in_channels = out_channels
|
245 |
+
out_channels *= 2
|
246 |
+
bottleneck_channels *= 2
|
247 |
+
|
248 |
+
if freeze_at >= stage_idx:
|
249 |
+
for block in blocks:
|
250 |
+
block.freeze()
|
251 |
+
stages.append(blocks)
|
252 |
+
|
253 |
+
return ResNet(stem, stages, out_features=out_features)
|
254 |
+
|
255 |
+
|
256 |
+
def find_top_rpn_proposals(
|
257 |
+
proposals,
|
258 |
+
pred_objectness_logits,
|
259 |
+
images,
|
260 |
+
image_sizes,
|
261 |
+
nms_thresh,
|
262 |
+
pre_nms_topk,
|
263 |
+
post_nms_topk,
|
264 |
+
min_box_side_len,
|
265 |
+
training,
|
266 |
+
):
|
267 |
+
"""Args:
|
268 |
+
proposals (list[Tensor]): (L, N, Hi*Wi*A, 4).
|
269 |
+
pred_objectness_logits: tensors of length L.
|
270 |
+
nms_thresh (float): IoU threshold to use for NMS
|
271 |
+
pre_nms_topk (int): before nms
|
272 |
+
post_nms_topk (int): after nms
|
273 |
+
min_box_side_len (float): minimum proposal box side
|
274 |
+
training (bool): True if proposals are to be used in training,
|
275 |
+
Returns:
|
276 |
+
results (List[Dict]): stores post_nms_topk object proposals for image i.
|
277 |
+
"""
|
278 |
+
num_images = len(images)
|
279 |
+
device = proposals[0].device
|
280 |
+
|
281 |
+
# 1. Select top-k anchor for every level and every image
|
282 |
+
topk_scores = [] # #lvl Tensor, each of shape N x topk
|
283 |
+
topk_proposals = []
|
284 |
+
level_ids = [] # #lvl Tensor, each of shape (topk,)
|
285 |
+
batch_idx = torch.arange(num_images, device=device)
|
286 |
+
for level_id, proposals_i, logits_i in zip(itertools.count(), proposals, pred_objectness_logits):
|
287 |
+
Hi_Wi_A = logits_i.shape[1]
|
288 |
+
num_proposals_i = min(pre_nms_topk, Hi_Wi_A)
|
289 |
+
|
290 |
+
# sort is faster than topk (https://github.com/pytorch/pytorch/issues/22812)
|
291 |
+
# topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
|
292 |
+
logits_i, idx = logits_i.sort(descending=True, dim=1)
|
293 |
+
topk_scores_i = logits_i[batch_idx, :num_proposals_i]
|
294 |
+
topk_idx = idx[batch_idx, :num_proposals_i]
|
295 |
+
|
296 |
+
# each is N x topk
|
297 |
+
topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx] # N x topk x 4
|
298 |
+
|
299 |
+
topk_proposals.append(topk_proposals_i)
|
300 |
+
topk_scores.append(topk_scores_i)
|
301 |
+
level_ids.append(torch.full((num_proposals_i,), level_id, dtype=torch.int64, device=device))
|
302 |
+
|
303 |
+
# 2. Concat all levels together
|
304 |
+
topk_scores = torch.cat(topk_scores, dim=1)
|
305 |
+
topk_proposals = torch.cat(topk_proposals, dim=1)
|
306 |
+
level_ids = torch.cat(level_ids, dim=0)
|
307 |
+
|
308 |
+
# if I change to batched_nms, I wonder if this will make a difference
|
309 |
+
# 3. For each image, run a per-level NMS, and choose topk results.
|
310 |
+
results = []
|
311 |
+
for n, image_size in enumerate(image_sizes):
|
312 |
+
boxes = topk_proposals[n]
|
313 |
+
scores_per_img = topk_scores[n]
|
314 |
+
# I will have to take a look at the boxes clip method
|
315 |
+
_clip_box(boxes, image_size)
|
316 |
+
# filter empty boxes
|
317 |
+
keep = _nonempty_boxes(boxes, threshold=min_box_side_len)
|
318 |
+
lvl = level_ids
|
319 |
+
if keep.sum().item() != len(boxes):
|
320 |
+
boxes, scores_per_img, lvl = (
|
321 |
+
boxes[keep],
|
322 |
+
scores_per_img[keep],
|
323 |
+
level_ids[keep],
|
324 |
+
)
|
325 |
+
|
326 |
+
keep = batched_nms(boxes, scores_per_img, lvl, nms_thresh)
|
327 |
+
keep = keep[:post_nms_topk]
|
328 |
+
|
329 |
+
res = (boxes[keep], scores_per_img[keep])
|
330 |
+
results.append(res)
|
331 |
+
|
332 |
+
# I wonder if it would be possible for me to pad all these things.
|
333 |
+
return results
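An illustrative single-level, single-image call of find_top_rpn_proposals (a sketch with made-up inputs; `images` is only used for its length here):

import torch
N, HWA = 1, 12
proposals = [torch.rand(N, HWA, 4) * 100]           # (x1, y1, x2, y2) guesses
proposals[0][..., 2:] += proposals[0][..., :2]      # ensure x2 > x1 and y2 > y1
logits = [torch.randn(N, HWA)]
results = find_top_rpn_proposals(
    proposals, logits, images=[None] * N, image_sizes=[(480, 640)],
    nms_thresh=0.7, pre_nms_topk=10, post_nms_topk=5, min_box_side_len=0, training=False,
)
# results[0] is a (boxes, scores) pair holding at most post_nms_topk proposals for the image.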
|
334 |
+
|
335 |
+
|
336 |
+
def subsample_labels(labels, num_samples, positive_fraction, bg_label):
|
337 |
+
"""
|
338 |
+
Returns:
|
339 |
+
pos_idx, neg_idx (Tensor):
|
340 |
+
1D vector of indices. The total length of both is `num_samples` or fewer.
|
341 |
+
"""
|
342 |
+
positive = torch.nonzero((labels != -1) & (labels != bg_label)).squeeze(1)
|
343 |
+
negative = torch.nonzero(labels == bg_label).squeeze(1)
|
344 |
+
|
345 |
+
num_pos = int(num_samples * positive_fraction)
|
346 |
+
# protect against not enough positive examples
|
347 |
+
num_pos = min(positive.numel(), num_pos)
|
348 |
+
num_neg = num_samples - num_pos
|
349 |
+
# protect against not enough negative examples
|
350 |
+
num_neg = min(negative.numel(), num_neg)
|
351 |
+
|
352 |
+
# randomly select positive and negative examples
|
353 |
+
perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
|
354 |
+
perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
|
355 |
+
|
356 |
+
pos_idx = positive[perm1]
|
357 |
+
neg_idx = negative[perm2]
|
358 |
+
return pos_idx, neg_idx
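A small example of the sampling behaviour (sketch): with num_samples=4 and positive_fraction=0.5, at most two positives are drawn and the remainder is filled from the background indices:

import torch
labels = torch.tensor([1, 5, 0, 0, 0, -1, 3, 0])    # 0 = background, -1 = ignore
pos_idx, neg_idx = subsample_labels(labels, num_samples=4, positive_fraction=0.5, bg_label=0)
# pos_idx holds 2 indices drawn from {0, 1, 6}; neg_idx holds 2 indices drawn from {2, 3, 4, 7}.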
|
359 |
+
|
360 |
+
|
361 |
+
def add_ground_truth_to_proposals(gt_boxes, proposals):
|
362 |
+
raise NotImplementedError()
|
363 |
+
|
364 |
+
|
365 |
+
def add_ground_truth_to_proposals_single_image(gt_boxes, proposals):
|
366 |
+
raise NotImplementedError()
|
367 |
+
|
368 |
+
|
369 |
+
def _fmt_box_list(box_tensor, batch_index: int):
|
370 |
+
repeated_index = torch.full(
|
371 |
+
(len(box_tensor), 1),
|
372 |
+
batch_index,
|
373 |
+
dtype=box_tensor.dtype,
|
374 |
+
device=box_tensor.device,
|
375 |
+
)
|
376 |
+
return torch.cat((repeated_index, box_tensor), dim=1)
|
377 |
+
|
378 |
+
|
379 |
+
def convert_boxes_to_pooler_format(box_lists: List[torch.Tensor]):
|
380 |
+
pooler_fmt_boxes = torch.cat(
|
381 |
+
[_fmt_box_list(box_list, i) for i, box_list in enumerate(box_lists)],
|
382 |
+
dim=0,
|
383 |
+
)
|
384 |
+
return pooler_fmt_boxes
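Illustrative input/output of convert_boxes_to_pooler_format (sketch): per-image box tensors are concatenated into the (batch_index, x1, y1, x2, y2) layout that torchvision's RoIPool expects:

import torch
boxes_img0 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
boxes_img1 = torch.tensor([[5.0, 5.0, 20.0, 30.0],
                           [1.0, 2.0, 3.0, 4.0]])
pooler_fmt = convert_boxes_to_pooler_format([boxes_img0, boxes_img1])
# pooler_fmt.shape == (3, 5); column 0 is the image index (0, 1, 1).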
|
385 |
+
|
386 |
+
|
387 |
+
def assign_boxes_to_levels(
|
388 |
+
box_lists: List[torch.Tensor],
|
389 |
+
min_level: int,
|
390 |
+
max_level: int,
|
391 |
+
canonical_box_size: int,
|
392 |
+
canonical_level: int,
|
393 |
+
):
|
394 |
+
|
395 |
+
box_sizes = torch.sqrt(torch.cat([boxes.area() for boxes in box_lists]))
|
396 |
+
# Eqn.(1) in FPN paper
|
397 |
+
level_assignments = torch.floor(canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8))
|
398 |
+
# clamp level to (min, max), in case the box size is too large or too small
|
399 |
+
# for the available feature maps
|
400 |
+
level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
|
401 |
+
return level_assignments.to(torch.int64) - min_level
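The level assignment follows Eqn. (1) of the FPN paper; the sketch below re-derives it with plain tensors (it does not call assign_boxes_to_levels, whose box_lists entries are expected to expose an .area() method): a box whose sqrt-area equals canonical_box_size maps to canonical_level, each doubling of the box size moves it one level up, and the result is clamped to [min_level, max_level].

import torch
box_sizes = torch.tensor([112.0, 224.0, 448.0])               # sqrt of the box areas
levels = torch.floor(4 + torch.log2(box_sizes / 224 + 1e-8))  # canonical_level=4, canonical_box_size=224
levels = torch.clamp(levels, min=2, max=5)                    # -> tensor([3., 4., 5.])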
|
402 |
+
|
403 |
+
|
404 |
+
# Helper Classes
|
405 |
+
class _NewEmptyTensorOp(torch.autograd.Function):
|
406 |
+
@staticmethod
|
407 |
+
def forward(ctx, x, new_shape):
|
408 |
+
ctx.shape = x.shape
|
409 |
+
return x.new_empty(new_shape)
|
410 |
+
|
411 |
+
@staticmethod
|
412 |
+
def backward(ctx, grad):
|
413 |
+
shape = ctx.shape
|
414 |
+
return _NewEmptyTensorOp.apply(grad, shape), None
|
415 |
+
|
416 |
+
|
417 |
+
class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])):
|
418 |
+
def __new__(cls, *, channels=None, height=None, width=None, stride=None):
|
419 |
+
return super().__new__(cls, channels, height, width, stride)
|
420 |
+
|
421 |
+
|
422 |
+
class Box2BoxTransform(object):
|
423 |
+
"""
|
424 |
+
This R-CNN transformation scales the box's width and height
|
425 |
+
by exp(dw), exp(dh) and shifts a box's center by the offset
|
426 |
+
(dx * width, dy * height).
|
427 |
+
"""
|
428 |
+
|
429 |
+
def __init__(self, weights: Tuple[float, float, float, float], scale_clamp: float = None):
|
430 |
+
"""
|
431 |
+
Args:
|
432 |
+
weights (4-element tuple): Scaling factors that are applied to the
|
433 |
+
(dx, dy, dw, dh) deltas. In Fast R-CNN, these were originally set
|
434 |
+
such that the deltas have unit variance; now they are treated as
|
435 |
+
hyperparameters of the system.
|
436 |
+
scale_clamp (float): When predicting deltas, the predicted box scaling
|
437 |
+
factors (dw and dh) are clamped such that they are <= scale_clamp.
|
438 |
+
"""
|
439 |
+
self.weights = weights
|
440 |
+
if scale_clamp is not None:
|
441 |
+
self.scale_clamp = scale_clamp
|
442 |
+
else:
|
443 |
+
"""
|
444 |
+
Value for clamping large dw and dh predictions.
|
445 |
+
The heuristic is that we clamp such that dw and dh are no larger
|
446 |
+
than what would transform a 16px box into a 1000px box
|
447 |
+
(based on a small anchor, 16px, and a typical image size, 1000px).
|
448 |
+
"""
|
449 |
+
self.scale_clamp = math.log(1000.0 / 16)
|
450 |
+
|
451 |
+
def get_deltas(self, src_boxes, target_boxes):
|
452 |
+
"""
|
453 |
+
Get box regression transformation deltas (dx, dy, dw, dh) that can be used
|
454 |
+
to transform the `src_boxes` into the `target_boxes`. That is, the relation
|
455 |
+
``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless
|
456 |
+
any delta is too large and is clamped).
|
457 |
+
Args:
|
458 |
+
src_boxes (Tensor): source boxes, e.g., object proposals
|
459 |
+
target_boxes (Tensor): target of the transformation, e.g., ground-truth
|
460 |
+
boxes.
|
461 |
+
"""
|
462 |
+
assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
|
463 |
+
assert isinstance(target_boxes, torch.Tensor), type(target_boxes)
|
464 |
+
|
465 |
+
src_widths = src_boxes[:, 2] - src_boxes[:, 0]
|
466 |
+
src_heights = src_boxes[:, 3] - src_boxes[:, 1]
|
467 |
+
src_ctr_x = src_boxes[:, 0] + 0.5 * src_widths
|
468 |
+
src_ctr_y = src_boxes[:, 1] + 0.5 * src_heights
|
469 |
+
|
470 |
+
target_widths = target_boxes[:, 2] - target_boxes[:, 0]
|
471 |
+
target_heights = target_boxes[:, 3] - target_boxes[:, 1]
|
472 |
+
target_ctr_x = target_boxes[:, 0] + 0.5 * target_widths
|
473 |
+
target_ctr_y = target_boxes[:, 1] + 0.5 * target_heights
|
474 |
+
|
475 |
+
wx, wy, ww, wh = self.weights
|
476 |
+
dx = wx * (target_ctr_x - src_ctr_x) / src_widths
|
477 |
+
dy = wy * (target_ctr_y - src_ctr_y) / src_heights
|
478 |
+
dw = ww * torch.log(target_widths / src_widths)
|
479 |
+
dh = wh * torch.log(target_heights / src_heights)
|
480 |
+
|
481 |
+
deltas = torch.stack((dx, dy, dw, dh), dim=1)
|
482 |
+
assert (src_widths > 0).all().item(), "Input boxes to Box2BoxTransform are not valid!"
|
483 |
+
return deltas
|
484 |
+
|
485 |
+
def apply_deltas(self, deltas, boxes):
|
486 |
+
"""
|
487 |
+
Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.
|
488 |
+
Args:
|
489 |
+
deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
|
490 |
+
deltas[i] represents k potentially different class-specific
|
491 |
+
box transformations for the single box boxes[i].
|
492 |
+
boxes (Tensor): boxes to transform, of shape (N, 4)
|
493 |
+
"""
|
494 |
+
boxes = boxes.to(deltas.dtype)
|
495 |
+
|
496 |
+
widths = boxes[:, 2] - boxes[:, 0]
|
497 |
+
heights = boxes[:, 3] - boxes[:, 1]
|
498 |
+
ctr_x = boxes[:, 0] + 0.5 * widths
|
499 |
+
ctr_y = boxes[:, 1] + 0.5 * heights
|
500 |
+
|
501 |
+
wx, wy, ww, wh = self.weights
|
502 |
+
dx = deltas[:, 0::4] / wx
|
503 |
+
dy = deltas[:, 1::4] / wy
|
504 |
+
dw = deltas[:, 2::4] / ww
|
505 |
+
dh = deltas[:, 3::4] / wh
|
506 |
+
|
507 |
+
# Prevent sending too large values into torch.exp()
|
508 |
+
dw = torch.clamp(dw, max=self.scale_clamp)
|
509 |
+
dh = torch.clamp(dh, max=self.scale_clamp)
|
510 |
+
|
511 |
+
pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
|
512 |
+
pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
|
513 |
+
pred_w = torch.exp(dw) * widths[:, None]
|
514 |
+
pred_h = torch.exp(dh) * heights[:, None]
|
515 |
+
|
516 |
+
pred_boxes = torch.zeros_like(deltas)
|
517 |
+
pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w # x1
|
518 |
+
pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h # y1
|
519 |
+
pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w # x2
|
520 |
+
pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h # y2
|
521 |
+
return pred_boxes
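A round-trip sketch for Box2BoxTransform (illustrative; the weights are values commonly used for the box head):

import torch
transform = Box2BoxTransform(weights=(10.0, 10.0, 5.0, 5.0))
src = torch.tensor([[10.0, 10.0, 60.0, 50.0]])     # proposal
tgt = torch.tensor([[12.0, 14.0, 70.0, 58.0]])     # ground truth
deltas = transform.get_deltas(src, tgt)            # (1, 4) = (dx, dy, dw, dh)
recovered = transform.apply_deltas(deltas, src)    # equals tgt up to the scale clamp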
|
522 |
+
|
523 |
+
|
524 |
+
class Matcher(object):
|
525 |
+
"""
|
526 |
+
This class assigns to each predicted "element" (e.g., a box) a ground-truth
|
527 |
+
element. Each predicted element will have exactly zero or one matches; each
|
528 |
+
ground-truth element may be matched to zero or more predicted elements.
|
529 |
+
The matching is determined by the MxN match_quality_matrix, that characterizes
|
530 |
+
how well each (ground-truth, prediction)-pair match each other. For example,
|
531 |
+
if the elements are boxes, this matrix may contain box intersection-over-union
|
532 |
+
overlap values.
|
533 |
+
The matcher returns (a) a vector of length N containing the index of the
|
534 |
+
ground-truth element m in [0, M) that matches to prediction n in [0, N).
|
535 |
+
(b) a vector of length N containing the labels for each prediction.
|
536 |
+
"""
|
537 |
+
|
538 |
+
def __init__(
|
539 |
+
self,
|
540 |
+
thresholds: List[float],
|
541 |
+
labels: List[int],
|
542 |
+
allow_low_quality_matches: bool = False,
|
543 |
+
):
|
544 |
+
"""
|
545 |
+
Args:
|
546 |
+
thresholds (list): a list of thresholds used to stratify predictions
|
547 |
+
into levels.
|
548 |
+
labels (list): a list of values to label predictions belonging at
|
549 |
+
each level. A label can be one of {-1, 0, 1} signifying
|
550 |
+
{ignore, negative class, positive class}, respectively.
|
551 |
+
allow_low_quality_matches (bool): if True, produce additional matches for predictions with maximum match quality lower than high_threshold.
|
552 |
+
For example, thresholds = [0.3, 0.5] labels = [0, -1, 1] All predictions with iou < 0.3 will be marked with 0 and
|
553 |
+
thus will be considered as false positives while training. All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
|
554 |
+
thus will be ignored. All predictions with 0.5 <= iou will be marked with 1 and thus will be considered as true positives.
|
555 |
+
"""
|
556 |
+
thresholds = thresholds[:]
|
557 |
+
assert thresholds[0] > 0
|
558 |
+
thresholds.insert(0, -float("inf"))
|
559 |
+
thresholds.append(float("inf"))
|
560 |
+
assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])])
|
561 |
+
assert all([label_i in [-1, 0, 1] for label_i in labels])
|
562 |
+
assert len(labels) == len(thresholds) - 1
|
563 |
+
self.thresholds = thresholds
|
564 |
+
self.labels = labels
|
565 |
+
self.allow_low_quality_matches = allow_low_quality_matches
|
566 |
+
|
567 |
+
def __call__(self, match_quality_matrix):
|
568 |
+
"""
|
569 |
+
Args:
|
570 |
+
match_quality_matrix (Tensor[float]): an MxN tensor, containing the pairwise quality between M ground-truth elements and N predicted
|
571 |
+
elements. All elements must be >= 0 (due to the use of `torch.nonzero` for selecting indices in :meth:`set_low_quality_matches_`).
|
572 |
+
Returns:
|
573 |
+
matches (Tensor[int64]): a vector of length N, where matches[i] is a matched ground-truth index in [0, M)
|
574 |
+
match_labels (Tensor[int8]): a vector of length N, where pred_labels[i] indicates true or false positive or ignored
|
575 |
+
"""
|
576 |
+
assert match_quality_matrix.dim() == 2
|
577 |
+
if match_quality_matrix.numel() == 0:
|
578 |
+
default_matches = match_quality_matrix.new_full((match_quality_matrix.size(1),), 0, dtype=torch.int64)
|
579 |
+
# When no gt boxes exist, we define IOU = 0 and therefore set labels
|
580 |
+
# to `self.labels[0]`, which usually defaults to background class 0
|
581 |
+
# To choose to ignore instead,
|
582 |
+
# can make labels=[-1,0,-1,1] + set appropriate thresholds
|
583 |
+
default_match_labels = match_quality_matrix.new_full(
|
584 |
+
(match_quality_matrix.size(1),), self.labels[0], dtype=torch.int8
|
585 |
+
)
|
586 |
+
return default_matches, default_match_labels
|
587 |
+
|
588 |
+
assert torch.all(match_quality_matrix >= 0)
|
589 |
+
|
590 |
+
# match_quality_matrix is M (gt) x N (predicted)
|
591 |
+
# Max over gt elements (dim 0) to find best gt candidate for each prediction
|
592 |
+
matched_vals, matches = match_quality_matrix.max(dim=0)
|
593 |
+
|
594 |
+
match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
|
595 |
+
|
596 |
+
for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
|
597 |
+
low_high = (matched_vals >= low) & (matched_vals < high)
|
598 |
+
match_labels[low_high] = l
|
599 |
+
|
600 |
+
if self.allow_low_quality_matches:
|
601 |
+
self.set_low_quality_matches_(match_labels, match_quality_matrix)
|
602 |
+
|
603 |
+
return matches, match_labels
|
604 |
+
|
605 |
+
def set_low_quality_matches_(self, match_labels, match_quality_matrix):
|
606 |
+
"""
|
607 |
+
Produce additional matches for predictions that have only low-quality matches.
|
608 |
+
Specifically, for each ground-truth G find the set of predictions that have
|
609 |
+
maximum overlap with it (including ties); for each prediction in that set, if
|
610 |
+
it is unmatched, then match it to the ground-truth G.
|
611 |
+
This function implements the RPN assignment case (i)
|
612 |
+
in Sec. 3.1.2 of Faster R-CNN.
|
613 |
+
"""
|
614 |
+
# For each gt, find the prediction with which it has highest quality
|
615 |
+
highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
|
616 |
+
# Find the highest quality match available, even if it is low, including ties.
|
617 |
+
# Note that the matches qualities must be positive due to the use of
|
618 |
+
# `torch.nonzero`.
|
619 |
+
of_quality_inds = match_quality_matrix == highest_quality_foreach_gt[:, None]
|
620 |
+
if of_quality_inds.dim() == 0:
|
621 |
+
(_, pred_inds_with_highest_quality) = of_quality_inds.unsqueeze(0).nonzero().unbind(1)
|
622 |
+
else:
|
623 |
+
(_, pred_inds_with_highest_quality) = of_quality_inds.nonzero().unbind(1)
|
624 |
+
match_labels[pred_inds_with_highest_quality] = 1
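A worked example of the matching semantics (sketch): rows are ground-truth boxes, columns are predictions, entries are IoU. With thresholds=[0.3, 0.5] and labels=[0, -1, 1], predictions below 0.3 become negatives, those in [0.3, 0.5) are ignored, and those at or above 0.5 become positives:

import torch
matcher = Matcher(thresholds=[0.3, 0.5], labels=[0, -1, 1], allow_low_quality_matches=False)
iou = torch.tensor([[0.10, 0.40, 0.80, 0.05],
                    [0.60, 0.20, 0.10, 0.00]])
matches, match_labels = matcher(iou)
# matches      -> tensor([1, 0, 0, 0])   best ground-truth index per prediction
# match_labels -> tensor([ 1, -1,  1,  0], dtype=torch.int8)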
|
625 |
+
|
626 |
+
|
627 |
+
class RPNOutputs(object):
|
628 |
+
def __init__(
|
629 |
+
self,
|
630 |
+
box2box_transform,
|
631 |
+
anchor_matcher,
|
632 |
+
batch_size_per_image,
|
633 |
+
positive_fraction,
|
634 |
+
images,
|
635 |
+
pred_objectness_logits,
|
636 |
+
pred_anchor_deltas,
|
637 |
+
anchors,
|
638 |
+
boundary_threshold=0,
|
639 |
+
gt_boxes=None,
|
640 |
+
smooth_l1_beta=0.0,
|
641 |
+
):
|
642 |
+
"""
|
643 |
+
Args:
|
644 |
+
box2box_transform (Box2BoxTransform): :class:`Box2BoxTransform` instance for anchor-proposal transformations.
|
645 |
+
anchor_matcher (Matcher): :class:`Matcher` instance for matching anchors to ground-truth boxes; used to determine training labels.
|
646 |
+
batch_size_per_image (int): number of proposals to sample when training
|
647 |
+
positive_fraction (float): target fraction of sampled proposals that should be positive
|
648 |
+
images (ImageList): :class:`ImageList` instance representing N input images
|
649 |
+
pred_objectness_logits (list[Tensor]): A list of L elements. Element i is a tensor of shape (N, A, Hi, Wi)
|
650 |
+
pred_anchor_deltas (list[Tensor]): A list of L elements. Element i is a tensor of shape (N, A*4, Hi, Wi)
|
651 |
+
anchors (list[torch.Tensor]): nested list of boxes. anchors[i][j] at (n, l) stores anchor array for feature map l
|
652 |
+
boundary_threshold (int): if >= 0, then anchors that extend beyond the image boundary by more than boundary_thresh are not used in training.
|
653 |
+
gt_boxes (list[Boxes], optional): A list of N elements.
|
654 |
+
smooth_l1_beta (float): The transition point between L1 and L2 loss. When set to 0, the loss becomes L1. When +inf, it is ignored
|
655 |
+
"""
|
656 |
+
self.box2box_transform = box2box_transform
|
657 |
+
self.anchor_matcher = anchor_matcher
|
658 |
+
self.batch_size_per_image = batch_size_per_image
|
659 |
+
self.positive_fraction = positive_fraction
|
660 |
+
self.pred_objectness_logits = pred_objectness_logits
|
661 |
+
self.pred_anchor_deltas = pred_anchor_deltas
|
662 |
+
|
663 |
+
self.anchors = anchors
|
664 |
+
self.gt_boxes = gt_boxes
|
665 |
+
self.num_feature_maps = len(pred_objectness_logits)
|
666 |
+
self.num_images = len(images)
|
667 |
+
self.boundary_threshold = boundary_threshold
|
668 |
+
self.smooth_l1_beta = smooth_l1_beta
|
669 |
+
|
670 |
+
def _get_ground_truth(self):
|
671 |
+
raise NotImplementedError()
|
672 |
+
|
673 |
+
def predict_proposals(self):
|
674 |
+
# pred_anchor_deltas: (L, N, ? Hi, Wi)
|
675 |
+
# anchors:(N, L, -1, B)
|
676 |
+
# here we loop over specific feature map, NOT images
|
677 |
+
proposals = []
|
678 |
+
anchors = self.anchors.transpose(0, 1)
|
679 |
+
for anchors_i, pred_anchor_deltas_i in zip(anchors, self.pred_anchor_deltas):
|
680 |
+
B = anchors_i.size(-1)
|
681 |
+
N, _, Hi, Wi = pred_anchor_deltas_i.shape
|
682 |
+
anchors_i = anchors_i.flatten(start_dim=0, end_dim=1)
|
683 |
+
pred_anchor_deltas_i = pred_anchor_deltas_i.view(N, -1, B, Hi, Wi).permute(0, 3, 4, 1, 2).reshape(-1, B)
|
684 |
+
proposals_i = self.box2box_transform.apply_deltas(pred_anchor_deltas_i, anchors_i)
|
685 |
+
# Append feature map proposals with shape (N, Hi*Wi*A, B)
|
686 |
+
proposals.append(proposals_i.view(N, -1, B))
|
687 |
+
proposals = torch.stack(proposals)
|
688 |
+
return proposals
|
689 |
+
|
690 |
+
def predict_objectness_logits(self):
|
691 |
+
"""
|
692 |
+
Returns:
|
693 |
+
pred_objectness_logits (list[Tensor]) -> (N, Hi*Wi*A).
|
694 |
+
"""
|
695 |
+
pred_objectness_logits = [
|
696 |
+
# Reshape: (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A)
|
697 |
+
score.permute(0, 2, 3, 1).reshape(self.num_images, -1)
|
698 |
+
for score in self.pred_objectness_logits
|
699 |
+
]
|
700 |
+
return pred_objectness_logits
|
701 |
+
|
702 |
+
|
703 |
+
# Main Classes
|
704 |
+
class Conv2d(torch.nn.Conv2d):
|
705 |
+
def __init__(self, *args, **kwargs):
|
706 |
+
norm = kwargs.pop("norm", None)
|
707 |
+
activation = kwargs.pop("activation", None)
|
708 |
+
super().__init__(*args, **kwargs)
|
709 |
+
|
710 |
+
self.norm = norm
|
711 |
+
self.activation = activation
|
712 |
+
|
713 |
+
def forward(self, x):
|
714 |
+
if x.numel() == 0 and self.training:
|
715 |
+
assert not isinstance(self.norm, torch.nn.SyncBatchNorm)
|
716 |
+
if x.numel() == 0:
|
717 |
+
assert not isinstance(self.norm, torch.nn.GroupNorm)
|
718 |
+
output_shape = [
|
719 |
+
(i + 2 * p - (di * (k - 1) + 1)) // s + 1
|
720 |
+
for i, p, di, k, s in zip(
|
721 |
+
x.shape[-2:],
|
722 |
+
self.padding,
|
723 |
+
self.dilation,
|
724 |
+
self.kernel_size,
|
725 |
+
self.stride,
|
726 |
+
)
|
727 |
+
]
|
728 |
+
output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
|
729 |
+
empty = _NewEmptyTensorOp.apply(x, output_shape)
|
730 |
+
if self.training:
|
731 |
+
_dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
|
732 |
+
return empty + _dummy
|
733 |
+
else:
|
734 |
+
return empty
|
735 |
+
|
736 |
+
x = super().forward(x)
|
737 |
+
if self.norm is not None:
|
738 |
+
x = self.norm(x)
|
739 |
+
if self.activation is not None:
|
740 |
+
x = self.activation(x)
|
741 |
+
return x
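A usage sketch of this Conv2d wrapper (illustrative): norm and activation are attached to the convolution and applied in forward, which is how the backbone layers below are built:

import torch
from torch.nn import functional as F
conv = Conv2d(3, 8, kernel_size=3, padding=1, bias=False,
              norm=get_norm("BN", 8), activation=F.relu)
y = conv(torch.randn(2, 3, 16, 16))     # (2, 8, 16, 16): conv -> BatchNorm -> ReLU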
|
742 |
+
|
743 |
+
|
744 |
+
class LastLevelMaxPool(nn.Module):
|
745 |
+
"""
|
746 |
+
This module is used in the original FPN to generate a downsampled P6 feature from P5.
|
747 |
+
"""
|
748 |
+
|
749 |
+
def __init__(self):
|
750 |
+
super().__init__()
|
751 |
+
self.num_levels = 1
|
752 |
+
self.in_feature = "p5"
|
753 |
+
|
754 |
+
def forward(self, x):
|
755 |
+
return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
|
756 |
+
|
757 |
+
|
758 |
+
class LastLevelP6P7(nn.Module):
|
759 |
+
"""
|
760 |
+
This module is used in RetinaNet to generate extra layers, P6 and P7 from C5 feature.
|
761 |
+
"""
|
762 |
+
|
763 |
+
def __init__(self, in_channels, out_channels):
|
764 |
+
super().__init__()
|
765 |
+
self.num_levels = 2
|
766 |
+
self.in_feature = "res5"
|
767 |
+
self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
|
768 |
+
self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
|
769 |
+
|
770 |
+
def forward(self, c5):
|
771 |
+
p6 = self.p6(c5)
|
772 |
+
p7 = self.p7(F.relu(p6))
|
773 |
+
return [p6, p7]
|
774 |
+
|
775 |
+
|
776 |
+
class BasicStem(nn.Module):
|
777 |
+
def __init__(self, in_channels=3, out_channels=64, norm="BN", caffe_maxpool=False):
|
778 |
+
super().__init__()
|
779 |
+
self.conv1 = Conv2d(
|
780 |
+
in_channels,
|
781 |
+
out_channels,
|
782 |
+
kernel_size=7,
|
783 |
+
stride=2,
|
784 |
+
padding=3,
|
785 |
+
bias=False,
|
786 |
+
norm=get_norm(norm, out_channels),
|
787 |
+
)
|
788 |
+
self.caffe_maxpool = caffe_maxpool
|
789 |
+
# use pad 1 instead of pad zero
|
790 |
+
|
791 |
+
def forward(self, x):
|
792 |
+
x = self.conv1(x)
|
793 |
+
x = F.relu_(x)
|
794 |
+
if self.caffe_maxpool:
|
795 |
+
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
|
796 |
+
else:
|
797 |
+
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
|
798 |
+
return x
|
799 |
+
|
800 |
+
@property
|
801 |
+
def out_channels(self):
|
802 |
+
return self.conv1.out_channels
|
803 |
+
|
804 |
+
@property
|
805 |
+
def stride(self):
|
806 |
+
return 4 # = stride 2 conv -> stride 2 max pool
|
807 |
+
|
808 |
+
|
809 |
+
class ResNetBlockBase(nn.Module):
|
810 |
+
def __init__(self, in_channels, out_channels, stride):
|
811 |
+
super().__init__()
|
812 |
+
self.in_channels = in_channels
|
813 |
+
self.out_channels = out_channels
|
814 |
+
self.stride = stride
|
815 |
+
|
816 |
+
def freeze(self):
|
817 |
+
for p in self.parameters():
|
818 |
+
p.requires_grad = False
|
819 |
+
return self
|
820 |
+
|
821 |
+
|
822 |
+
class BottleneckBlock(ResNetBlockBase):
|
823 |
+
def __init__(
|
824 |
+
self,
|
825 |
+
in_channels,
|
826 |
+
out_channels,
|
827 |
+
bottleneck_channels,
|
828 |
+
stride=1,
|
829 |
+
num_groups=1,
|
830 |
+
norm="BN",
|
831 |
+
stride_in_1x1=False,
|
832 |
+
dilation=1,
|
833 |
+
):
|
834 |
+
super().__init__(in_channels, out_channels, stride)
|
835 |
+
|
836 |
+
if in_channels != out_channels:
|
837 |
+
self.shortcut = Conv2d(
|
838 |
+
in_channels,
|
839 |
+
out_channels,
|
840 |
+
kernel_size=1,
|
841 |
+
stride=stride,
|
842 |
+
bias=False,
|
843 |
+
norm=get_norm(norm, out_channels),
|
844 |
+
)
|
845 |
+
else:
|
846 |
+
self.shortcut = None
|
847 |
+
|
848 |
+
# The original MSRA ResNet models have stride in the first 1x1 conv
|
849 |
+
# The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
|
850 |
+
# stride in the 3x3 conv
|
851 |
+
stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
|
852 |
+
|
853 |
+
self.conv1 = Conv2d(
|
854 |
+
in_channels,
|
855 |
+
bottleneck_channels,
|
856 |
+
kernel_size=1,
|
857 |
+
stride=stride_1x1,
|
858 |
+
bias=False,
|
859 |
+
norm=get_norm(norm, bottleneck_channels),
|
860 |
+
)
|
861 |
+
|
862 |
+
self.conv2 = Conv2d(
|
863 |
+
bottleneck_channels,
|
864 |
+
bottleneck_channels,
|
865 |
+
kernel_size=3,
|
866 |
+
stride=stride_3x3,
|
867 |
+
padding=1 * dilation,
|
868 |
+
bias=False,
|
869 |
+
groups=num_groups,
|
870 |
+
dilation=dilation,
|
871 |
+
norm=get_norm(norm, bottleneck_channels),
|
872 |
+
)
|
873 |
+
|
874 |
+
self.conv3 = Conv2d(
|
875 |
+
bottleneck_channels,
|
876 |
+
out_channels,
|
877 |
+
kernel_size=1,
|
878 |
+
bias=False,
|
879 |
+
norm=get_norm(norm, out_channels),
|
880 |
+
)
|
881 |
+
|
882 |
+
def forward(self, x):
|
883 |
+
out = self.conv1(x)
|
884 |
+
out = F.relu_(out)
|
885 |
+
|
886 |
+
out = self.conv2(out)
|
887 |
+
out = F.relu_(out)
|
888 |
+
|
889 |
+
out = self.conv3(out)
|
890 |
+
|
891 |
+
if self.shortcut is not None:
|
892 |
+
shortcut = self.shortcut(x)
|
893 |
+
else:
|
894 |
+
shortcut = x
|
895 |
+
|
896 |
+
out += shortcut
|
897 |
+
out = F.relu_(out)
|
898 |
+
return out
|
899 |
+
|
900 |
+
|
901 |
+
class Backbone(nn.Module, metaclass=ABCMeta):
|
902 |
+
def __init__(self):
|
903 |
+
super().__init__()
|
904 |
+
|
905 |
+
@abstractmethod
|
906 |
+
def forward(self):
|
907 |
+
pass
|
908 |
+
|
909 |
+
@property
|
910 |
+
def size_divisibility(self):
|
911 |
+
"""
|
912 |
+
Some backbones require the input height and width to be divisible by a specific integer. This is
|
913 |
+
typically true for encoder / decoder type networks with lateral connection (e.g., FPN) for which feature maps need to match
|
914 |
+
dimension in the "bottom up" and "top down" paths. Set to 0 if no specific input size divisibility is required.
|
915 |
+
"""
|
916 |
+
return 0
|
917 |
+
|
918 |
+
def output_shape(self):
|
919 |
+
return {
|
920 |
+
name: ShapeSpec(
|
921 |
+
channels=self._out_feature_channels[name],
|
922 |
+
stride=self._out_feature_strides[name],
|
923 |
+
)
|
924 |
+
for name in self._out_features
|
925 |
+
}
|
926 |
+
|
927 |
+
@property
|
928 |
+
def out_features(self):
|
929 |
+
"""deprecated"""
|
930 |
+
return self._out_features
|
931 |
+
|
932 |
+
@property
|
933 |
+
def out_feature_strides(self):
|
934 |
+
"""deprecated"""
|
935 |
+
return {f: self._out_feature_strides[f] for f in self._out_features}
|
936 |
+
|
937 |
+
@property
|
938 |
+
def out_feature_channels(self):
|
939 |
+
"""deprecated"""
|
940 |
+
return {f: self._out_feature_channels[f] for f in self._out_features}
|
941 |
+
|
942 |
+
|
943 |
+
class ResNet(Backbone):
|
944 |
+
def __init__(self, stem, stages, num_classes=None, out_features=None):
|
945 |
+
"""
|
946 |
+
Args:
|
947 |
+
stem (nn.Module): a stem module
|
948 |
+
stages (list[list[ResNetBlock]]): several (typically 4) stages, each contains multiple :class:`ResNetBlockBase`.
|
949 |
+
num_classes (None or int): if None, will not perform classification.
|
950 |
+
out_features (list[str]): name of the layers whose outputs should be returned in forward. Can be anything in:
|
951 |
+
"stem", "linear", or "res2" ... If None, will return the output of the last layer.
|
952 |
+
"""
|
953 |
+
super(ResNet, self).__init__()
|
954 |
+
self.stem = stem
|
955 |
+
self.num_classes = num_classes
|
956 |
+
|
957 |
+
current_stride = self.stem.stride
|
958 |
+
self._out_feature_strides = {"stem": current_stride}
|
959 |
+
self._out_feature_channels = {"stem": self.stem.out_channels}
|
960 |
+
|
961 |
+
self.stages_and_names = []
|
962 |
+
for i, blocks in enumerate(stages):
|
963 |
+
for block in blocks:
|
964 |
+
assert isinstance(block, ResNetBlockBase), block
|
965 |
+
curr_channels = block.out_channels
|
966 |
+
stage = nn.Sequential(*blocks)
|
967 |
+
name = "res" + str(i + 2)
|
968 |
+
self.add_module(name, stage)
|
969 |
+
self.stages_and_names.append((stage, name))
|
970 |
+
self._out_feature_strides[name] = current_stride = int(
|
971 |
+
current_stride * np.prod([k.stride for k in blocks])
|
972 |
+
)
|
973 |
+
self._out_feature_channels[name] = blocks[-1].out_channels
|
974 |
+
|
975 |
+
if num_classes is not None:
|
976 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
977 |
+
self.linear = nn.Linear(curr_channels, num_classes)
|
978 |
+
|
979 |
+
# Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
|
980 |
+
# "The 1000-way fully-connected layer is initialized by
|
981 |
+
# drawing weights from a zero-mean Gaussian with std of 0.01."
|
982 |
+
nn.init.normal_(self.linear.weight, std=0.01)
|
983 |
+
name = "linear"
|
984 |
+
|
985 |
+
if out_features is None:
|
986 |
+
out_features = [name]
|
987 |
+
self._out_features = out_features
|
988 |
+
assert len(self._out_features)
|
989 |
+
children = [x[0] for x in self.named_children()]
|
990 |
+
for out_feature in self._out_features:
|
991 |
+
assert out_feature in children, "Available children: {}".format(", ".join(children))
|
992 |
+
|
993 |
+
def forward(self, x):
|
994 |
+
outputs = {}
|
995 |
+
x = self.stem(x)
|
996 |
+
if "stem" in self._out_features:
|
997 |
+
outputs["stem"] = x
|
998 |
+
for stage, name in self.stages_and_names:
|
999 |
+
x = stage(x)
|
1000 |
+
if name in self._out_features:
|
1001 |
+
outputs[name] = x
|
1002 |
+
if self.num_classes is not None:
|
1003 |
+
x = self.avgpool(x)
|
1004 |
+
x = self.linear(x)
|
1005 |
+
if "linear" in self._out_features:
|
1006 |
+
outputs["linear"] = x
|
1007 |
+
return outputs
|
1008 |
+
|
1009 |
+
def output_shape(self):
|
1010 |
+
return {
|
1011 |
+
name: ShapeSpec(
|
1012 |
+
channels=self._out_feature_channels[name],
|
1013 |
+
stride=self._out_feature_strides[name],
|
1014 |
+
)
|
1015 |
+
for name in self._out_features
|
1016 |
+
}
|
1017 |
+
|
1018 |
+
@staticmethod
|
1019 |
+
def make_stage(
|
1020 |
+
block_class,
|
1021 |
+
num_blocks,
|
1022 |
+
first_stride=None,
|
1023 |
+
*,
|
1024 |
+
in_channels,
|
1025 |
+
out_channels,
|
1026 |
+
**kwargs,
|
1027 |
+
):
|
1028 |
+
"""
|
1029 |
+
Usually, layers that produce the same feature map spatial size
|
1030 |
+
are defined as one "stage".
|
1031 |
+
Under such definition, stride_per_block[1:] should all be 1.
|
1032 |
+
"""
|
1033 |
+
if first_stride is not None:
|
1034 |
+
assert "stride" not in kwargs and "stride_per_block" not in kwargs
|
1035 |
+
kwargs["stride_per_block"] = [first_stride] + [1] * (num_blocks - 1)
|
1036 |
+
blocks = []
|
1037 |
+
for i in range(num_blocks):
|
1038 |
+
curr_kwargs = {}
|
1039 |
+
for k, v in kwargs.items():
|
1040 |
+
if k.endswith("_per_block"):
|
1041 |
+
assert len(v) == num_blocks, (
|
1042 |
+
f"Argument '{k}' of make_stage should have the " f"same length as num_blocks={num_blocks}."
|
1043 |
+
)
|
1044 |
+
newk = k[: -len("_per_block")]
|
1045 |
+
assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
|
1046 |
+
curr_kwargs[newk] = v[i]
|
1047 |
+
else:
|
1048 |
+
curr_kwargs[k] = v
|
1049 |
+
|
1050 |
+
blocks.append(block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs))
|
1051 |
+
in_channels = out_channels
|
1052 |
+
|
1053 |
+
return blocks
|
1054 |
+
|
1055 |
+
|
1056 |
+
class ROIPooler(nn.Module):
|
1057 |
+
"""
|
1058 |
+
Region of interest feature map pooler that supports pooling from one or more
|
1059 |
+
feature maps.
|
1060 |
+
"""
|
1061 |
+
|
1062 |
+
def __init__(
|
1063 |
+
self,
|
1064 |
+
output_size,
|
1065 |
+
scales,
|
1066 |
+
sampling_ratio,
|
1067 |
+
canonical_box_size=224,
|
1068 |
+
canonical_level=4,
|
1069 |
+
):
|
1070 |
+
super().__init__()
|
1071 |
+
# assumption that stride is a power of 2.
|
1072 |
+
min_level = -math.log2(scales[0])
|
1073 |
+
max_level = -math.log2(scales[-1])
|
1074 |
+
|
1075 |
+
# a bunch of testing
|
1076 |
+
assert math.isclose(min_level, int(min_level)) and math.isclose(max_level, int(max_level))
|
1077 |
+
assert len(scales) == max_level - min_level + 1, "not pyramid"
|
1078 |
+
assert 0 < min_level and min_level <= max_level
|
1079 |
+
if isinstance(output_size, int):
|
1080 |
+
output_size = (output_size, output_size)
|
1081 |
+
assert len(output_size) == 2 and isinstance(output_size[0], int) and isinstance(output_size[1], int)
|
1082 |
+
if len(scales) > 1:
|
1083 |
+
assert min_level <= canonical_level and canonical_level <= max_level
|
1084 |
+
assert canonical_box_size > 0
|
1085 |
+
|
1086 |
+
self.output_size = output_size
|
1087 |
+
self.min_level = int(min_level)
|
1088 |
+
self.max_level = int(max_level)
|
1089 |
+
self.level_poolers = nn.ModuleList(RoIPool(output_size, spatial_scale=scale) for scale in scales)
|
1090 |
+
self.canonical_level = canonical_level
|
1091 |
+
self.canonical_box_size = canonical_box_size
|
1092 |
+
|
1093 |
+
def forward(self, feature_maps, boxes):
|
1094 |
+
"""
|
1095 |
+
Args:
|
1096 |
+
feature_maps: List[torch.Tensor(N,C,W,H)]
|
1097 |
+
box_lists: list[torch.Tensor])
|
1098 |
+
Returns:
|
1099 |
+
A tensor of shape(N*B, Channels, output_size, output_size)
|
1100 |
+
"""
|
1101 |
+
x = [v for v in feature_maps.values()]
|
1102 |
+
num_level_assignments = len(self.level_poolers)
|
1103 |
+
assert len(x) == num_level_assignments and len(boxes) == x[0].size(0)
|
1104 |
+
|
1105 |
+
pooler_fmt_boxes = convert_boxes_to_pooler_format(boxes)
|
1106 |
+
|
1107 |
+
if num_level_assignments == 1:
|
1108 |
+
return self.level_poolers[0](x[0], pooler_fmt_boxes)
|
1109 |
+
|
1110 |
+
level_assignments = assign_boxes_to_levels(
|
1111 |
+
boxes,
|
1112 |
+
self.min_level,
|
1113 |
+
self.max_level,
|
1114 |
+
self.canonical_box_size,
|
1115 |
+
self.canonical_level,
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
num_boxes = len(pooler_fmt_boxes)
|
1119 |
+
num_channels = x[0].shape[1]
|
1120 |
+
output_size = self.output_size[0]
|
1121 |
+
|
1122 |
+
dtype, device = x[0].dtype, x[0].device
|
1123 |
+
output = torch.zeros(
|
1124 |
+
(num_boxes, num_channels, output_size, output_size),
|
1125 |
+
dtype=dtype,
|
1126 |
+
device=device,
|
1127 |
+
)
|
1128 |
+
|
1129 |
+
for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)):
|
1130 |
+
inds = torch.nonzero(level_assignments == level).squeeze(1)
|
1131 |
+
pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
|
1132 |
+
output[inds] = pooler(x_level, pooler_fmt_boxes_level)
|
1133 |
+
|
1134 |
+
return output
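A single-scale usage sketch of ROIPooler (illustrative; shapes and boxes are made up). With one stride-16 feature map, the pooler reduces to a plain RoIPool over boxes given in image coordinates:

import torch
pooler = ROIPooler(output_size=7, scales=(1.0 / 16,), sampling_ratio=0)
features = {"res4": torch.randn(2, 256, 32, 32)}            # two images, stride-16 feature map
boxes = [torch.tensor([[0.0, 0.0, 64.0, 64.0]]),            # one RoI per image
         torch.tensor([[16.0, 16.0, 128.0, 96.0]])]
crops = pooler(features, boxes)                             # (2, 256, 7, 7)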
|
1135 |
+
|
1136 |
+
|
1137 |
+
class ROIOutputs(object):
|
1138 |
+
def __init__(self, cfg, training=False):
|
1139 |
+
self.smooth_l1_beta = cfg.ROI_BOX_HEAD.SMOOTH_L1_BETA
|
1140 |
+
self.box2box_transform = Box2BoxTransform(weights=cfg.ROI_BOX_HEAD.BBOX_REG_WEIGHTS)
|
1141 |
+
self.training = training
|
1142 |
+
self.score_thresh = cfg.ROI_HEADS.SCORE_THRESH_TEST
|
1143 |
+
self.min_detections = cfg.MIN_DETECTIONS
|
1144 |
+
self.max_detections = cfg.MAX_DETECTIONS
|
1145 |
+
|
1146 |
+
nms_thresh = cfg.ROI_HEADS.NMS_THRESH_TEST
|
1147 |
+
if not isinstance(nms_thresh, list):
|
1148 |
+
nms_thresh = [nms_thresh]
|
1149 |
+
self.nms_thresh = nms_thresh
|
1150 |
+
|
1151 |
+
def _predict_boxes(self, proposals, box_deltas, preds_per_image):
|
1152 |
+
num_pred = box_deltas.size(0)
|
1153 |
+
B = proposals[0].size(-1)
|
1154 |
+
K = box_deltas.size(-1) // B
|
1155 |
+
box_deltas = box_deltas.view(num_pred * K, B)
|
1156 |
+
proposals = torch.cat(proposals, dim=0).unsqueeze(-2).expand(num_pred, K, B)
|
1157 |
+
proposals = proposals.reshape(-1, B)
|
1158 |
+
boxes = self.box2box_transform.apply_deltas(box_deltas, proposals)
|
1159 |
+
return boxes.view(num_pred, K * B).split(preds_per_image, dim=0)
|
1160 |
+
|
1161 |
+
def _predict_objs(self, obj_logits, preds_per_image):
|
1162 |
+
probs = F.softmax(obj_logits, dim=-1)
|
1163 |
+
probs = probs.split(preds_per_image, dim=0)
|
1164 |
+
return probs
|
1165 |
+
|
1166 |
+
def _predict_attrs(self, attr_logits, preds_per_image):
|
1167 |
+
attr_logits = attr_logits[..., :-1].softmax(-1)
|
1168 |
+
attr_probs, attrs = attr_logits.max(-1)
|
1169 |
+
return attr_probs.split(preds_per_image, dim=0), attrs.split(preds_per_image, dim=0)
|
1170 |
+
|
1171 |
+
@torch.no_grad()
|
1172 |
+
def inference(
|
1173 |
+
self,
|
1174 |
+
obj_logits,
|
1175 |
+
attr_logits,
|
1176 |
+
box_deltas,
|
1177 |
+
pred_boxes,
|
1178 |
+
features,
|
1179 |
+
sizes,
|
1180 |
+
scales=None,
|
1181 |
+
):
|
1182 |
+
# only the pred boxes are needed here to get the number of predictions per image
|
1183 |
+
preds_per_image = [p.size(0) for p in pred_boxes]
|
1184 |
+
boxes_all = self._predict_boxes(pred_boxes, box_deltas, preds_per_image)
|
1185 |
+
obj_scores_all = self._predict_objs(obj_logits, preds_per_image) # list of length N
|
1186 |
+
attr_probs_all, attrs_all = self._predict_attrs(attr_logits, preds_per_image)
|
1187 |
+
features = features.split(preds_per_image, dim=0)
|
1188 |
+
|
1189 |
+
# fun for each image too, also I can experiment and do multiple images
|
1190 |
+
final_results = []
|
1191 |
+
zipped = zip(boxes_all, obj_scores_all, attr_probs_all, attrs_all, sizes)
|
1192 |
+
for i, (boxes, obj_scores, attr_probs, attrs, size) in enumerate(zipped):
|
1193 |
+
for nms_t in self.nms_thresh:
|
1194 |
+
outputs = do_nms(
|
1195 |
+
boxes,
|
1196 |
+
obj_scores,
|
1197 |
+
size,
|
1198 |
+
self.score_thresh,
|
1199 |
+
nms_t,
|
1200 |
+
self.min_detections,
|
1201 |
+
self.max_detections,
|
1202 |
+
)
|
1203 |
+
if outputs is not None:
|
1204 |
+
max_boxes, max_scores, classes, ids = outputs
|
1205 |
+
break
|
1206 |
+
|
1207 |
+
if scales is not None:
|
1208 |
+
scale_yx = scales[i]
|
1209 |
+
max_boxes[:, 0::2] *= scale_yx[1]
|
1210 |
+
max_boxes[:, 1::2] *= scale_yx[0]
|
1211 |
+
|
1212 |
+
final_results.append(
|
1213 |
+
(
|
1214 |
+
max_boxes,
|
1215 |
+
classes,
|
1216 |
+
max_scores,
|
1217 |
+
attrs[ids],
|
1218 |
+
attr_probs[ids],
|
1219 |
+
features[i][ids],
|
1220 |
+
)
|
1221 |
+
)
|
1222 |
+
boxes, classes, class_probs, attrs, attr_probs, roi_features = map(list, zip(*final_results))
|
1223 |
+
return boxes, classes, class_probs, attrs, attr_probs, roi_features
|
1224 |
+
|
1225 |
+
def training(self, obj_logits, attr_logits, box_deltas, pred_boxes, features, sizes):
|
1226 |
+
pass
|
1227 |
+
|
1228 |
+
def __call__(
|
1229 |
+
self,
|
1230 |
+
obj_logits,
|
1231 |
+
attr_logits,
|
1232 |
+
box_deltas,
|
1233 |
+
pred_boxes,
|
1234 |
+
features,
|
1235 |
+
sizes,
|
1236 |
+
scales=None,
|
1237 |
+
):
|
1238 |
+
if self.training:
|
1239 |
+
raise NotImplementedError()
|
1240 |
+
return self.inference(
|
1241 |
+
obj_logits,
|
1242 |
+
attr_logits,
|
1243 |
+
box_deltas,
|
1244 |
+
pred_boxes,
|
1245 |
+
features,
|
1246 |
+
sizes,
|
1247 |
+
scales=scales,
|
1248 |
+
)
|
1249 |
+
|
1250 |
+
|
1251 |
+
class Res5ROIHeads(nn.Module):
|
1252 |
+
"""
|
1253 |
+
ROIHeads perform all per-region computation in an R-CNN.
|
1254 |
+
It contains logic of cropping the regions, extract per-region features
|
1255 |
+
(by the res-5 block in this case), and make per-region predictions.
|
1256 |
+
"""
|
1257 |
+
|
1258 |
+
def __init__(self, cfg, input_shape):
|
1259 |
+
super().__init__()
|
1260 |
+
self.batch_size_per_image = cfg.RPN.BATCH_SIZE_PER_IMAGE
|
1261 |
+
self.positive_sample_fraction = cfg.ROI_HEADS.POSITIVE_FRACTION
|
1262 |
+
self.in_features = cfg.ROI_HEADS.IN_FEATURES
|
1263 |
+
self.num_classes = cfg.ROI_HEADS.NUM_CLASSES
|
1264 |
+
self.proposal_append_gt = cfg.ROI_HEADS.PROPOSAL_APPEND_GT
|
1265 |
+
self.feature_strides = {k: v.stride for k, v in input_shape.items()}
|
1266 |
+
self.feature_channels = {k: v.channels for k, v in input_shape.items()}
|
1267 |
+
self.cls_agnostic_bbox_reg = cfg.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG
|
1268 |
+
self.stage_channel_factor = 2 ** 3 # res5 is 8x res2
|
1269 |
+
self.out_channels = cfg.RESNETS.RES2_OUT_CHANNELS * self.stage_channel_factor
|
1270 |
+
|
1271 |
+
# self.proposal_matcher = Matcher(
|
1272 |
+
# cfg.ROI_HEADS.IOU_THRESHOLDS,
|
1273 |
+
# cfg.ROI_HEADS.IOU_LABELS,
|
1274 |
+
# allow_low_quality_matches=False,
|
1275 |
+
# )
|
1276 |
+
|
1277 |
+
pooler_resolution = cfg.ROI_BOX_HEAD.POOLER_RESOLUTION
|
1278 |
+
pooler_scales = (1.0 / self.feature_strides[self.in_features[0]],)
|
1279 |
+
sampling_ratio = cfg.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
|
1280 |
+
res5_halve = cfg.ROI_BOX_HEAD.RES5HALVE
|
1281 |
+
use_attr = cfg.ROI_BOX_HEAD.ATTR
|
1282 |
+
num_attrs = cfg.ROI_BOX_HEAD.NUM_ATTRS
|
1283 |
+
|
1284 |
+
self.pooler = ROIPooler(
|
1285 |
+
output_size=pooler_resolution,
|
1286 |
+
scales=pooler_scales,
|
1287 |
+
sampling_ratio=sampling_ratio,
|
1288 |
+
)
|
1289 |
+
|
1290 |
+
self.res5 = self._build_res5_block(cfg)
|
1291 |
+
if not res5_halve:
|
1292 |
+
"""
|
1293 |
+
Modifications for VG in RoI heads:
|
1294 |
+
1. Change the stride of conv1 and shortcut in Res5.Block1 from 2 to 1
|
1295 |
+
2. Modify all conv2 with (padding: 1 --> 2) and (dilation: 1 --> 2)
|
1296 |
+
"""
|
1297 |
+
self.res5[0].conv1.stride = (1, 1)
|
1298 |
+
self.res5[0].shortcut.stride = (1, 1)
|
1299 |
+
for i in range(3):
|
1300 |
+
self.res5[i].conv2.padding = (2, 2)
|
1301 |
+
self.res5[i].conv2.dilation = (2, 2)
|
1302 |
+
|
1303 |
+
self.box_predictor = FastRCNNOutputLayers(
|
1304 |
+
self.out_channels,
|
1305 |
+
self.num_classes,
|
1306 |
+
self.cls_agnostic_bbox_reg,
|
1307 |
+
use_attr=use_attr,
|
1308 |
+
num_attrs=num_attrs,
|
1309 |
+
)
|
1310 |
+
|
1311 |
+
def _build_res5_block(self, cfg):
|
1312 |
+
stage_channel_factor = self.stage_channel_factor # res5 is 8x res2
|
1313 |
+
num_groups = cfg.RESNETS.NUM_GROUPS
|
1314 |
+
width_per_group = cfg.RESNETS.WIDTH_PER_GROUP
|
1315 |
+
bottleneck_channels = num_groups * width_per_group * stage_channel_factor
|
1316 |
+
out_channels = self.out_channels
|
1317 |
+
stride_in_1x1 = cfg.RESNETS.STRIDE_IN_1X1
|
1318 |
+
norm = cfg.RESNETS.NORM
|
1319 |
+
|
1320 |
+
blocks = ResNet.make_stage(
|
1321 |
+
BottleneckBlock,
|
1322 |
+
3,
|
1323 |
+
first_stride=2,
|
1324 |
+
in_channels=out_channels // 2,
|
1325 |
+
bottleneck_channels=bottleneck_channels,
|
1326 |
+
out_channels=out_channels,
|
1327 |
+
num_groups=num_groups,
|
1328 |
+
norm=norm,
|
1329 |
+
stride_in_1x1=stride_in_1x1,
|
1330 |
+
)
|
1331 |
+
return nn.Sequential(*blocks)
|
1332 |
+
|
1333 |
+
def _shared_roi_transform(self, features, boxes):
|
1334 |
+
x = self.pooler(features, boxes)
|
1335 |
+
return self.res5(x)
|
1336 |
+
|
1337 |
+
    def forward(self, features, proposal_boxes, gt_boxes=None):
        if self.training:
            """
            see https://github.com/airsplay/py-bottom-up-attention/\
            blob/master/detectron2/modeling/roi_heads/roi_heads.py
            """
            raise NotImplementedError()

        assert not proposal_boxes[0].requires_grad
        box_features = self._shared_roi_transform(features, proposal_boxes)
        feature_pooled = box_features.mean(dim=[2, 3])  # pooled to 1x1
        obj_logits, attr_logits, pred_proposal_deltas = self.box_predictor(feature_pooled)
        return obj_logits, attr_logits, pred_proposal_deltas, feature_pooled


class AnchorGenerator(nn.Module):
    """
    For a set of image sizes and feature maps, computes a set of anchors.
    """

    def __init__(self, cfg, input_shape: List[ShapeSpec]):
        super().__init__()
        sizes = cfg.ANCHOR_GENERATOR.SIZES
        aspect_ratios = cfg.ANCHOR_GENERATOR.ASPECT_RATIOS
        self.strides = [x.stride for x in input_shape]
        self.offset = cfg.ANCHOR_GENERATOR.OFFSET
        assert 0.0 <= self.offset < 1.0, self.offset

        """
        sizes (list[list[int]]): sizes[i] is the list of anchor sizes for feat map i
            1. given in absolute lengths in units of the input image;
            2. they do not dynamically scale if the input image size changes.
        aspect_ratios (list[list[float]])
        strides (list[int]): stride of each input feature.
        """

        self.num_features = len(self.strides)
        self.cell_anchors = nn.ParameterList(self._calculate_anchors(sizes, aspect_ratios))
        self._spacial_feat_dim = 4

    def _calculate_anchors(self, sizes, aspect_ratios):
        # If one size (or aspect ratio) is specified and there are multiple feature
        # maps, then we "broadcast" anchors of that single size (or aspect ratio)
        if len(sizes) == 1:
            sizes *= self.num_features
        if len(aspect_ratios) == 1:
            aspect_ratios *= self.num_features
        assert self.num_features == len(sizes)
        assert self.num_features == len(aspect_ratios)

        cell_anchors = [self.generate_cell_anchors(s, a).float() for s, a in zip(sizes, aspect_ratios)]

        return cell_anchors

    @property
    def box_dim(self):
        return self._spacial_feat_dim

    @property
    def num_cell_anchors(self):
        """
        Returns:
            list[int]: Each int is the number of anchors at every pixel location, on that feature map.
        """
        return [len(cell_anchors) for cell_anchors in self.cell_anchors]

    def grid_anchors(self, grid_sizes):
        anchors = []
        for (size, stride, base_anchors) in zip(grid_sizes, self.strides, self.cell_anchors):
            shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))

        return anchors

    def generate_cell_anchors(self, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)):
        """
        Anchors are continuous geometric rectangles
        centered on one feature map point sample.
        We can later build the set of anchors
        for the entire feature map by tiling these tensors.
        """

        anchors = []
        for size in sizes:
            area = size ** 2.0
            for aspect_ratio in aspect_ratios:
                w = math.sqrt(area / aspect_ratio)
                h = aspect_ratio * w
                x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0
                anchors.append([x0, y0, x1, y1])
        return nn.Parameter(torch.Tensor(anchors))

    def forward(self, features):
        """
        Args:
            features List[torch.Tensor]: list of feature maps on which to generate anchors.
        Returns:
            torch.Tensor: a list of #image elements.
        """
        num_images = features[0].size(0)
        grid_sizes = [feature_map.shape[-2:] for feature_map in features]
        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
        anchors_over_all_feature_maps = torch.stack(anchors_over_all_feature_maps)
        return anchors_over_all_feature_maps.unsqueeze(0).repeat_interleave(num_images, dim=0)


class RPNHead(nn.Module):
    """
    RPN classification and regression heads. Uses a 3x3 conv to produce a shared
    hidden state from which one 1x1 conv predicts objectness logits for each anchor
    and a second 1x1 conv predicts bounding-box deltas specifying how to deform
    each anchor into an object proposal.
    """

    def __init__(self, cfg, input_shape: List[ShapeSpec]):
        super().__init__()

        # Standard RPN is shared across levels:
        in_channels = [s.channels for s in input_shape]
        assert len(set(in_channels)) == 1, "Each level must have the same channel!"
        in_channels = in_channels[0]

        anchor_generator = AnchorGenerator(cfg, input_shape)
        num_cell_anchors = anchor_generator.num_cell_anchors
        box_dim = anchor_generator.box_dim
        assert len(set(num_cell_anchors)) == 1, "Each level must have the same number of cell anchors"
        num_cell_anchors = num_cell_anchors[0]

        if cfg.PROPOSAL_GENERATOR.HIDDEN_CHANNELS == -1:
            hid_channels = in_channels
        else:
            hid_channels = cfg.PROPOSAL_GENERATOR.HIDDEN_CHANNELS
        # Modifications for VG in RPN (modeling/proposal_generator/rpn.py)
        # Use hidden dim instead of the same dim as Res4 (in_channels)

        # 3x3 conv for the hidden representation
        self.conv = nn.Conv2d(in_channels, hid_channels, kernel_size=3, stride=1, padding=1)
        # 1x1 conv for predicting objectness logits
        self.objectness_logits = nn.Conv2d(hid_channels, num_cell_anchors, kernel_size=1, stride=1)
        # 1x1 conv for predicting box2box transform deltas
        self.anchor_deltas = nn.Conv2d(hid_channels, num_cell_anchors * box_dim, kernel_size=1, stride=1)

        for layer in [self.conv, self.objectness_logits, self.anchor_deltas]:
            nn.init.normal_(layer.weight, std=0.01)
            nn.init.constant_(layer.bias, 0)

    def forward(self, features):
        """
        Args:
            features (list[Tensor]): list of feature maps
        """
        pred_objectness_logits = []
        pred_anchor_deltas = []
        for x in features:
            t = F.relu(self.conv(x))
            pred_objectness_logits.append(self.objectness_logits(t))
            pred_anchor_deltas.append(self.anchor_deltas(t))
        return pred_objectness_logits, pred_anchor_deltas


class RPN(nn.Module):
    """
    Region Proposal Network, introduced by the Faster R-CNN paper.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()

        self.min_box_side_len = cfg.PROPOSAL_GENERATOR.MIN_SIZE
        self.in_features = cfg.RPN.IN_FEATURES
        self.nms_thresh = cfg.RPN.NMS_THRESH
        self.batch_size_per_image = cfg.RPN.BATCH_SIZE_PER_IMAGE
        self.positive_fraction = cfg.RPN.POSITIVE_FRACTION
        self.smooth_l1_beta = cfg.RPN.SMOOTH_L1_BETA
        self.loss_weight = cfg.RPN.LOSS_WEIGHT

        self.pre_nms_topk = {
            True: cfg.RPN.PRE_NMS_TOPK_TRAIN,
            False: cfg.RPN.PRE_NMS_TOPK_TEST,
        }
        self.post_nms_topk = {
            True: cfg.RPN.POST_NMS_TOPK_TRAIN,
            False: cfg.RPN.POST_NMS_TOPK_TEST,
        }
        self.boundary_threshold = cfg.RPN.BOUNDARY_THRESH

        self.anchor_generator = AnchorGenerator(cfg, [input_shape[f] for f in self.in_features])
        self.box2box_transform = Box2BoxTransform(weights=cfg.RPN.BBOX_REG_WEIGHTS)
        self.anchor_matcher = Matcher(
            cfg.RPN.IOU_THRESHOLDS,
            cfg.RPN.IOU_LABELS,
            allow_low_quality_matches=True,
        )
        self.rpn_head = RPNHead(cfg, [input_shape[f] for f in self.in_features])

    def training(self, images, image_shapes, features, gt_boxes):
        pass

    def inference(self, outputs, images, image_shapes, features, gt_boxes=None):
        outputs = find_top_rpn_proposals(
            outputs.predict_proposals(),
            outputs.predict_objectness_logits(),
            images,
            image_shapes,
            self.nms_thresh,
            self.pre_nms_topk[self.training],
            self.post_nms_topk[self.training],
            self.min_box_side_len,
            self.training,
        )

        results = []
        for img in outputs:
            im_boxes, img_box_logits = img
            img_box_logits, inds = img_box_logits.sort(descending=True)
            im_boxes = im_boxes[inds]
            results.append((im_boxes, img_box_logits))

        (proposal_boxes, logits) = tuple(map(list, zip(*results)))
        return proposal_boxes, logits

    def forward(self, images, image_shapes, features, gt_boxes=None):
        """
        Args:
            images (torch.Tensor): input images of length `N`
            features (dict[str: Tensor])
            gt_instances
        """
        # features is dict, key = block level, v = feature_map
        features = [features[f] for f in self.in_features]
        pred_objectness_logits, pred_anchor_deltas = self.rpn_head(features)
        anchors = self.anchor_generator(features)
        outputs = RPNOutputs(
            self.box2box_transform,
            self.anchor_matcher,
            self.batch_size_per_image,
            self.positive_fraction,
            images,
            pred_objectness_logits,
            pred_anchor_deltas,
            anchors,
            self.boundary_threshold,
            gt_boxes,
            self.smooth_l1_beta,
        )
        # For RPN-only models, the proposals are the final output

        if self.training:
            raise NotImplementedError()
            return self.training(outputs, images, image_shapes, features, gt_boxes)
        else:
            return self.inference(outputs, images, image_shapes, features, gt_boxes)


class FastRCNNOutputLayers(nn.Module):
    """
    Two linear layers for predicting Fast R-CNN outputs:
    (1) proposal-to-detection box regression deltas
    (2) classification scores
    """

    def __init__(
        self,
        input_size,
        num_classes,
        cls_agnostic_bbox_reg,
        box_dim=4,
        use_attr=False,
        num_attrs=-1,
    ):
        """
        Args:
            input_size (int): channels, or (channels, height, width)
            num_classes (int)
            cls_agnostic_bbox_reg (bool)
            box_dim (int)
        """
        super().__init__()

        if not isinstance(input_size, int):
            input_size = np.prod(input_size)

        # (do + 1 for background class)
        self.cls_score = nn.Linear(input_size, num_classes + 1)
        num_bbox_reg_classes = 1 if cls_agnostic_bbox_reg else num_classes
        self.bbox_pred = nn.Linear(input_size, num_bbox_reg_classes * box_dim)

        self.use_attr = use_attr
        if use_attr:
            """
            Modifications for VG in RoI heads
            Embedding: {num_classes + 1} --> {input_size // 8}
            Linear: {input_size + input_size // 8} --> {input_size // 4}
            Linear: {input_size // 4} --> {num_attrs + 1}
            """
            self.cls_embedding = nn.Embedding(num_classes + 1, input_size // 8)
            self.fc_attr = nn.Linear(input_size + input_size // 8, input_size // 4)
            self.attr_score = nn.Linear(input_size // 4, num_attrs + 1)

        nn.init.normal_(self.cls_score.weight, std=0.01)
        nn.init.normal_(self.bbox_pred.weight, std=0.001)
        for item in [self.cls_score, self.bbox_pred]:
            nn.init.constant_(item.bias, 0)

    def forward(self, roi_features):
        if roi_features.dim() > 2:
            roi_features = torch.flatten(roi_features, start_dim=1)
        scores = self.cls_score(roi_features)
        proposal_deltas = self.bbox_pred(roi_features)
        if self.use_attr:
            _, max_class = scores.max(-1)  # [b, c] --> [b]
            cls_emb = self.cls_embedding(max_class)  # [b] --> [b, 256]
            roi_features = torch.cat([roi_features, cls_emb], -1)  # [b, 2048] + [b, 256] --> [b, 2304]
            roi_features = self.fc_attr(roi_features)
            roi_features = F.relu(roi_features)
            attr_scores = self.attr_score(roi_features)
            return scores, attr_scores, proposal_deltas
        else:
            return scores, proposal_deltas


class GeneralizedRCNN(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.device = torch.device(cfg.MODEL.DEVICE)
        self.backbone = build_backbone(cfg)
        self.proposal_generator = RPN(cfg, self.backbone.output_shape())
        self.roi_heads = Res5ROIHeads(cfg, self.backbone.output_shape())
        self.roi_outputs = ROIOutputs(cfg)
        self.to(self.device)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_cdn = kwargs.pop("use_cdn", True)

        # Load config if we don't provide a configuration
        if not isinstance(config, Config):
            config_path = config if config is not None else pretrained_model_name_or_path
            # try:
            config = Config.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
            )

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} ".format(
                            WEIGHTS_NAME,
                            pretrained_model_name_or_path,
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert (
                    from_tf
                ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
                    pretrained_model_name_or_path + ".index"
                )
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME,
                    use_cdn=use_cdn,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError
            except EnvironmentError:
                msg = f"Can't load weights for '{pretrained_model_name_or_path}'."
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                print("loading weights file {}".format(archive_file))
            else:
                print("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config)

        if state_dict is None:
            try:
                try:
                    state_dict = torch.load(resolved_archive_file, map_location="cpu")
                except Exception:
                    state_dict = load_checkpoint(resolved_archive_file)

            except Exception:
                raise OSError(
                    "Unable to load weights from pytorch checkpoint file. "
                    "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
                )

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        # Convert old format to new format if needed from a PyTorch state_dict
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        model_to_load = model
        model_to_load.load_state_dict(state_dict)

        if model.__class__.__name__ != model_to_load.__class__.__name__:
            base_model_state_dict = model_to_load.state_dict().keys()
            head_model_state_dict_without_base_prefix = [
                key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
            ]
            missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

        if len(unexpected_keys) > 0:
            print(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            print(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            print(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            print(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)
                )
            )
        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        return model

    def forward(
        self,
        images,
        image_shapes,
        gt_boxes=None,
        proposals=None,
        scales_yx=None,
        **kwargs,
    ):
        """
        kwargs:
            max_detections (int), return_tensors {"np", "pt", None}, padding {None,
            "max_detections"}, pad_value (int), location = {"cuda", "cpu"}
        """
        if self.training:
            raise NotImplementedError()
        return self.inference(
            images=images,
            image_shapes=image_shapes,
            gt_boxes=gt_boxes,
            proposals=proposals,
            scales_yx=scales_yx,
            **kwargs,
        )

    @torch.no_grad()
    def inference(
        self,
        images,
        image_shapes,
        gt_boxes=None,
        proposals=None,
        scales_yx=None,
        **kwargs,
    ):
        # run images through backbone
        original_sizes = image_shapes * scales_yx
        features = self.backbone(images)

        # generate proposals if none are available
        if proposals is None:
            proposal_boxes, _ = self.proposal_generator(images, image_shapes, features, gt_boxes)
        else:
            assert proposals is not None

        # pool object features from either gt_boxes, or from proposals
        obj_logits, attr_logits, box_deltas, feature_pooled = self.roi_heads(features, proposal_boxes, gt_boxes)

        # prepare FRCNN Outputs and select top proposals
        boxes, classes, class_probs, attrs, attr_probs, roi_features = self.roi_outputs(
            obj_logits=obj_logits,
            attr_logits=attr_logits,
            box_deltas=box_deltas,
            pred_boxes=proposal_boxes,
            features=feature_pooled,
            sizes=image_shapes,
            scales=scales_yx,
        )

        # will we pad???
        subset_kwargs = {
            "max_detections": kwargs.get("max_detections", None),
            "return_tensors": kwargs.get("return_tensors", None),
            "pad_value": kwargs.get("pad_value", 0),
            "padding": kwargs.get("padding", None),
        }
        preds_per_image = torch.tensor([p.size(0) for p in boxes])
        boxes = pad_list_tensors(boxes, preds_per_image, **subset_kwargs)
        classes = pad_list_tensors(classes, preds_per_image, **subset_kwargs)
        class_probs = pad_list_tensors(class_probs, preds_per_image, **subset_kwargs)
        attrs = pad_list_tensors(attrs, preds_per_image, **subset_kwargs)
        attr_probs = pad_list_tensors(attr_probs, preds_per_image, **subset_kwargs)
        roi_features = pad_list_tensors(roi_features, preds_per_image, **subset_kwargs)
        subset_kwargs["padding"] = None
        preds_per_image = pad_list_tensors(preds_per_image, None, **subset_kwargs)
        sizes = pad_list_tensors(image_shapes, None, **subset_kwargs)
        normalized_boxes = norm_box(boxes, original_sizes)
        return OrderedDict(
            {
                "obj_ids": classes,
                "obj_probs": class_probs,
                "attr_ids": attrs,
                "attr_probs": attr_probs,
                "boxes": boxes,
                "sizes": sizes,
                "preds_per_image": preds_per_image,
                "roi_features": roi_features,
                "normalized_boxes": normalized_boxes,
            }
        )
lxmert/src/param.py
ADDED
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyleft 2019 project LXRT.

import argparse
import random

import numpy as np
import torch


def get_optimizer(optim):
    # Bind the optimizer
    if optim == 'rms':
        print("Optimizer: Using RMSProp")
        optimizer = torch.optim.RMSprop
    elif optim == 'adam':
        print("Optimizer: Using Adam")
        optimizer = torch.optim.Adam
    elif optim == 'adamax':
        print("Optimizer: Using Adamax")
        optimizer = torch.optim.Adamax
    elif optim == 'sgd':
        print("Optimizer: sgd")
        optimizer = torch.optim.SGD
    elif 'bert' in optim:
        optimizer = 'bert'  # The bert optimizer will be bound later.
    else:
        assert False, "Please add your optimizer %s in the list." % optim

    return optimizer


def parse_args():
    parser = argparse.ArgumentParser()

    # Data Splits
    parser.add_argument("--train", default='train')
    parser.add_argument("--valid", default='valid')
    parser.add_argument("--test", default=None)

    # Training Hyper-parameters
    parser.add_argument('--batchSize', dest='batch_size', type=int, default=256)
    parser.add_argument('--optim', default='bert')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--seed', type=int, default=9595, help='random seed')

    # Debugging
    parser.add_argument('--output', type=str, default='snap/test')
    parser.add_argument("--fast", action='store_const', default=False, const=True)
    parser.add_argument("--tiny", action='store_const', default=False, const=True)
    parser.add_argument("--tqdm", action='store_const', default=False, const=True)

    # Model Loading
    parser.add_argument('--load', type=str, default=None,
                        help='Load the model (usually the fine-tuned model).')
    parser.add_argument('--loadLXMERT', dest='load_lxmert', type=str, default=None,
                        help='Load the pre-trained lxmert model.')
    parser.add_argument('--loadLXMERTQA', dest='load_lxmert_qa', type=str, default=None,
                        help='Load the pre-trained lxmert model with QA answer head.')
    parser.add_argument("--fromScratch", dest='from_scratch', action='store_const', default=False, const=True,
                        help='If none of the --load, --loadLXMERT, --loadLXMERTQA is set, '
                             'the model would be trained from scratch. If --fromScratch is'
                             ' not specified, the model would load BERT-pre-trained weights by'
                             ' default. ')

    # Optimization
    parser.add_argument("--mceLoss", dest='mce_loss', action='store_const', default=False, const=True)

    # LXRT Model Config
    # Note: LXRT = L, X, R (three encoders), Transformer
    parser.add_argument("--llayers", default=9, type=int, help='Number of Language layers')
    parser.add_argument("--xlayers", default=5, type=int, help='Number of CROSS-modality layers.')
    parser.add_argument("--rlayers", default=5, type=int, help='Number of object Relationship layers.')

    # lxmert Pre-training Config
    parser.add_argument("--taskMatched", dest='task_matched', action='store_const', default=False, const=True)
    parser.add_argument("--taskMaskLM", dest='task_mask_lm', action='store_const', default=False, const=True)
    parser.add_argument("--taskObjPredict", dest='task_obj_predict', action='store_const', default=False, const=True)
    parser.add_argument("--taskQA", dest='task_qa', action='store_const', default=False, const=True)
    parser.add_argument("--visualLosses", dest='visual_losses', default='obj,attr,feat', type=str)
    parser.add_argument("--qaSets", dest='qa_sets', default=None, type=str)
    parser.add_argument("--wordMaskRate", dest='word_mask_rate', default=0.15, type=float)
    parser.add_argument("--objMaskRate", dest='obj_mask_rate', default=0.15, type=float)

    # Training configuration
    parser.add_argument("--multiGPU", action='store_const', default=False, const=True)
    parser.add_argument("--numWorkers", dest='num_workers', default=0)

    # perturbation configuration
    parser.add_argument('--method', type=str,
                        default='ours_no_lrp',
                        choices=['ours_with_lrp', 'rollout', 'partial_lrp', 'transformer_att',
                                 'raw_attn', 'attn_gradcam', 'ours_with_lrp_no_normalization', 'ours_no_lrp',
                                 'ours_no_lrp_no_norm', 'ablation_no_aggregation', 'ablation_no_self_in_10'],
                        help='')
    parser.add_argument('--num-samples', type=int,
                        default=10000,
                        help='')
    parser.add_argument('--is-positive-pert', type=bool,
                        default=False,
                        help='')
    parser.add_argument('--is-text-pert', type=bool,
                        default=False,
                        help='')
    parser.add_argument('--COCO_path', type=str,
                        default='',
                        help='path to COCO 2014 validation set')

    # Parse the arguments.
    args = parser.parse_args()

    # Bind optimizer class.
    args.optimizer = get_optimizer(args.optim)

    # Set seeds
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    return args


args = parse_args()
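Because args = parse_args() runs at module import time, every script under src/ that does "from param import args" shares one parsed namespace, and the random seeds are set as a side effect of that import. The snippet below is only an illustration of what the defaults resolve to when the module is imported with src/ on the path and no extra CLI flags; it is not part of the repository.

# Hypothetical check, e.g.:  python -c "from param import args; print(args.batch_size)"
from param import args

print(args.batch_size)                            # 256
print(args.lr)                                    # 0.0001
print(args.llayers, args.xlayers, args.rlayers)   # 9 5 5
print(args.word_mask_rate, args.obj_mask_rate)    # 0.15 0.15
print(args.optimizer)                             # 'bert' (bound to BertAdam later by the training code)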
lxmert/src/pretrain/__init__.py
ADDED
File without changes
lxmert/src/pretrain/lxmert_data.py
ADDED
@@ -0,0 +1,255 @@
# coding=utf-8
# Copyleft 2019 project LXRT.

from collections import defaultdict
import json
import random

import numpy as np
from torch.utils.data import Dataset

from param import args
from pretrain.qa_answer_table import AnswerTable
from utils import load_obj_tsv

TINY_IMG_NUM = 500
FAST_IMG_NUM = 5000

Split2ImgFeatPath = {
    'mscoco_train': 'data/mscoco_imgfeat/train2014_obj36.tsv',
    'mscoco_minival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
    'mscoco_nominival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
    'vgnococo': 'data/vg_gqa_imgfeat/vg_gqa_obj36.tsv',
}


class InputExample(object):
    """A single training/test example for the language model."""
    def __init__(self, uid, sent, visual_feats=None,
                 obj_labels=None, attr_labels=None,
                 is_matched=None, label=None):
        self.uid = uid
        self.sent = sent
        self.visual_feats = visual_feats
        self.obj_labels = obj_labels
        self.attr_labels = attr_labels
        self.is_matched = is_matched  # whether the visual and textual inputs are matched
        self.label = label


class LXMERTDataset:
    def __init__(self, splits: str, qa_sets=None):
        """
        :param splits: The data sources to be loaded
        :param qa_sets: if None, no action
            o.w., only takes the answers appearing in these dsets
            and remove all unlabeled data (MSCOCO captions)
        """
        self.name = splits
        self.sources = splits.split(',')

        # Loading datasets to data
        self.data = []
        for source in self.sources:
            self.data.extend(json.load(open("data/lxmert/%s.json" % source)))
        print("Load %d data from %s" % (len(self.data), self.name))

        # Create answer table according to the qa_sets
        self.answer_table = AnswerTable(qa_sets)
        print("Load an answer table of size %d." % (len(self.answer_table.ans2id_map())))

        # Modify the answers
        for datum in self.data:
            labelf = datum['labelf']
            for cat, labels in labelf.items():
                for label in labels:
                    for ans in list(label.keys()):
                        new_ans = self.answer_table.convert_ans(ans)
                        if self.answer_table.used(new_ans):
                            if ans != new_ans:
                                label[new_ans] = label.pop(ans)
                        else:
                            label.pop(ans)

    def __len__(self):
        return len(self.data)


def make_uid(img_id, dset, sent_idx):
    return "%s_%s_%03d" % (img_id, dset, sent_idx),


"""
Example in obj tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
              "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
"""
class LXMERTTorchDataset(Dataset):
    def __init__(self, dataset: LXMERTDataset, topk=-1):
        super().__init__()
        self.raw_dataset = dataset
        self.task_matched = args.task_matched

        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM

        # Load the dataset
        img_data = []
        for source in self.raw_dataset.sources:
            img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk))

        self.imgid2img = {}
        for img_datum in img_data:
            self.imgid2img[img_datum['img_id']] = img_datum

        # Filter out the dataset
        used_data = []
        for datum in self.raw_dataset.data:
            if datum['img_id'] in self.imgid2img:
                used_data.append(datum)

        # Flatten the dataset (into one sent + one image entries)
        self.data = []
        for datum in used_data:
            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                if sents_cat in datum['labelf']:
                    labels = datum['labelf'][sents_cat]
                else:
                    labels = None
                for sent_idx, sent in enumerate(sents):
                    new_datum = {
                        'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
                        'img_id': datum['img_id'],
                        'sent': sent
                    }
                    if labels is not None:
                        new_datum['label'] = labels[sent_idx]
                    self.data.append(new_datum)
        print("Use %d data in torch dataset" % (len(self.data)))

    def __len__(self):
        return len(self.data)

    def random_feat(self):
        """Get a random obj feat from the dataset."""
        datum = self.data[random.randint(0, len(self.data)-1)]
        img_id = datum['img_id']
        img_info = self.imgid2img[img_id]
        feat = img_info['features'][random.randint(0, 35)]
        return feat

    def __getitem__(self, item: int):
        datum = self.data[item]

        uid = datum['uid']
        img_id = datum['img_id']

        # Get image info
        img_info = self.imgid2img[img_id]
        obj_num = img_info['num_boxes']
        feats = img_info['features'].copy()
        boxes = img_info['boxes'].copy()
        obj_labels = img_info['objects_id'].copy()
        obj_confs = img_info['objects_conf'].copy()
        attr_labels = img_info['attrs_id'].copy()
        attr_confs = img_info['attrs_conf'].copy()
        assert obj_num == len(boxes) == len(feats)

        # Normalize the boxes (to 0 ~ 1)
        img_h, img_w = img_info['img_h'], img_info['img_w']
        boxes = boxes.copy()
        boxes[:, (0, 2)] /= img_w
        boxes[:, (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1+1e-5)
        np.testing.assert_array_less(-boxes, 0+1e-5)

        # If calculating the matched loss, replace the sentence with a sentence
        # corresponding to another image.
        is_matched = 1
        sent = datum['sent']
        if self.task_matched:
            if random.random() < 0.5:
                is_matched = 0
                other_datum = self.data[random.randint(0, len(self.data)-1)]
                while other_datum['img_id'] == img_id:
                    other_datum = self.data[random.randint(0, len(self.data)-1)]
                sent = other_datum['sent']

        # Label, convert answer to id
        if 'label' in datum:
            label = datum['label'].copy()
            for ans in list(label.keys()):
                label[self.raw_dataset.answer_table.ans2id(ans)] = label.pop(ans)
        else:
            label = None

        # Create target
        example = InputExample(
            uid, sent, (feats, boxes),
            (obj_labels, obj_confs), (attr_labels, attr_confs),
            is_matched, label
        )
        return example


class LXMERTEvaluator:
    def __init__(self, dataset: LXMERTDataset):
        self.raw_dataset = dataset

        # Create QA Eval Data
        self.data = []
        for datum in self.raw_dataset.data:
            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                if sents_cat in datum['labelf']:  # A labeled dataset
                    labels = datum['labelf'][sents_cat]
                    for sent_idx, sent in enumerate(sents):
                        new_datum = {
                            'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
                            'img_id': datum['img_id'],
                            'sent': sent,
                            'dset': sents_cat,
                            'label': labels[sent_idx]
                        }
                        self.data.append(new_datum)

        # uid2datum
        self.uid2datum = {}
        for datum in self.data:
            self.uid2datum[datum['uid']] = datum

    def evaluate(self, uid2ans: dict, pprint=False):
        score = 0.
        cnt = 0
        dset2score = defaultdict(lambda: 0.)
        dset2cnt = defaultdict(lambda: 0)
        for uid, ans in uid2ans.items():
            if uid not in self.uid2datum:  # Not a labeled data
                continue
            datum = self.uid2datum[uid]
            label = datum['label']
            dset = datum['dset']
            if ans in label:
                score += label[ans]
                dset2score[dset] += label[ans]
            cnt += 1
            dset2cnt[dset] += 1
        accu = score / cnt
        dset2accu = {}
        for dset in dset2cnt:
            dset2accu[dset] = dset2score[dset] / dset2cnt[dset]

        if pprint:
            accu_str = "Overall Accu %0.4f, " % (accu)
            sorted_keys = sorted(dset2accu.keys())
            for key in sorted_keys:
                accu_str += "%s Accu %0.4f, " % (key, dset2accu[key])
            print(accu_str)

        return accu, dset2accu

    def dump_result(self, uid2ans: dict, path):
        raise NotImplementedError
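LXMERTTorchDataset.__getitem__ rescales the 36 region boxes from pixel XYXY coordinates into the 0-1 range that LXMERT's position features expect, and asserts that every coordinate stays inside the image. A minimal numpy sketch of that normalization step, using toy values rather than repository data:

import numpy as np

# Toy stand-ins for img_info['boxes'], img_info['img_w'], img_info['img_h'].
boxes = np.array([[ 50., 100., 250., 300.],
                  [  0.,   0., 640., 480.]], dtype=np.float32)
img_w, img_h = 640, 480

boxes = boxes.copy()
boxes[:, (0, 2)] /= img_w   # x0, x1 scaled by image width
boxes[:, (1, 3)] /= img_h   # y0, y1 scaled by image height

# Same sanity checks as the dataset: every coordinate must lie in [0, 1].
np.testing.assert_array_less(boxes, 1 + 1e-5)
np.testing.assert_array_less(-boxes, 0 + 1e-5)
print(boxes)   # approx [[0.078 0.208 0.391 0.625], [0. 0. 1. 1.]]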
lxmert/src/pretrain/lxmert_pretrain.py
ADDED
@@ -0,0 +1,435 @@
# coding=utf-8
# Copyleft 2019 project LXRT.

import collections
import os
import random

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from param import args
from pretrain.lxmert_data import InputExample, LXMERTDataset, LXMERTTorchDataset, LXMERTEvaluator
from lxrt.entry import set_visual_config
from lxrt.tokenization import BertTokenizer
from lxrt.modeling import LXRTPretraining

DataTuple = collections.namedtuple("DataTuple", 'dataset torchdset loader evaluator')


def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple:
    # Decide which QA datasets would be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is a part of vgqa, we take the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Build dataset, data loader, and evaluator.
    dset = LXMERTDataset(splits, qa_sets=qa_sets)
    tset = LXMERTTorchDataset(dset, topk)
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        collate_fn=lambda x: x,
        drop_last=drop_last, pin_memory=True
    )
    evaluator = LXMERTEvaluator(dset)
    print()

    return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator)


train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True)
valid_batch_size = 2048 if args.multiGPU else 512
valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False, topk=5000)


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 input_ids, input_mask, segment_ids, lm_label_ids,
                 visual_feats, obj_labels,
                 is_matched, ans):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.lm_label_ids = lm_label_ids

        self.visual_feats = visual_feats
        self.obj_labels = obj_labels

        self.is_matched = is_matched

        self.ans = ans


def random_word(tokens, tokenizer):
    """
    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):
        prob = random.random()
        # mask token with probability
        ratio = args.word_mask_rate
        if prob < ratio:
            prob /= ratio

            # 80% randomly change token to mask token
            if prob < 0.8:
                tokens[i] = "[MASK]"

            # 10% randomly change token to random token
            elif prob < 0.9:
                tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]

            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            try:
                output_label.append(tokenizer.vocab[token])
            except KeyError:
                # For unknown words (should not occur with BPE vocab)
                output_label.append(tokenizer.vocab["[UNK]"])
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)

    return tokens, output_label


def random_feat(feats):
    mask_feats = feats.copy()
    feat_mask = np.zeros(len(feats), dtype=np.float32)
    for i in range(len(feats)):
        prob = random.random()
        # mask token with probability
        if prob < args.obj_mask_rate:
            prob /= args.obj_mask_rate

            # 80% randomly change token to zero feat
            if prob < 0.8:
                mask_feats[i, :] = 0.

            # 10% randomly change token to random feat
            elif prob < 0.9:
                mask_feats[i, :] = train_tuple.torchdset.random_feat()
            # -> rest 10% randomly keep current feat

            # Need to predict this feat
            feat_mask[i] = 1.

    return mask_feats, feat_mask


def convert_example_to_features(example: InputExample, max_seq_length, tokenizer) -> InputFeatures:
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as strings and is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    tokens = tokenizer.tokenize(example.sent.strip())

    # Account for [CLS] and [SEP] with "- 2"
    if len(tokens) > max_seq_length - 2:
        tokens = tokens[:(max_seq_length - 2)]

    # Get random words
    masked_tokens, masked_label = random_word(tokens, tokenizer)

    # concatenate lm labels and account for CLS, SEP, SEP
    masked_tokens = ['[CLS]'] + masked_tokens + ['[SEP]']
    input_ids = tokenizer.convert_tokens_to_ids(masked_tokens)

    # Mask & Segment Word
    lm_label_ids = ([-1] + masked_label + [-1])
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    feat, boxes = example.visual_feats
    obj_labels, obj_confs = example.obj_labels
    attr_labels, attr_confs = example.attr_labels

    # Mask Image Features:
    masked_feat, feat_mask = random_feat(feat)

    # QA answer label
    if example.label is None or len(example.label) == 0 or example.is_matched != 1:
        # 1. No label 2. Label is pruned 3. unmatched visual + language pair
        ans = -1
    else:
        keys, values = zip(*example.label.items())
        if len(keys) == 1:
            ans = keys[0]
        else:
            value_sum = sum(values)
            prob = [value / value_sum for value in values]
            choice = np.random.multinomial(1, prob).argmax()
            ans = keys[choice]

    features = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        lm_label_ids=lm_label_ids,
        visual_feats=(masked_feat, boxes),
        obj_labels={
            'obj': (obj_labels, obj_confs),
            'attr': (attr_labels, attr_confs),
            'feat': (feat, feat_mask),
        },
        is_matched=example.is_matched,
        ans=ans,
    )
    return features


LOSSES_NAME = ('Mask_LM', 'Matched', 'Obj', 'Attr', 'Feat', 'QA')


class LXMERT:
    def __init__(self, max_seq_length):
        super().__init__()
        self.max_seq_length = max_seq_length

        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        # Build model
        set_visual_config(args)
        self.model = LXRTPretraining.from_pretrained(
            "bert-base-uncased",
            task_mask_lm=args.task_mask_lm,
            task_obj_predict=args.task_obj_predict,
            task_matched=args.task_matched,
            task_qa=args.task_qa,
            visual_losses=args.visual_losses,
            num_answers=train_tuple.dataset.answer_table.num_answers
        )

        # Weight initialization and loading
        if args.from_scratch:
            print("Train from Scratch: re-initialize all BERT weights.")
            self.model.apply(self.model.init_bert_weights)
        if args.load is not None:
            self.load(args.load)
        if args.load_lxmert is not None:
            # Load lxmert would not load the answer head.
            self.load_lxmert(args.load_lxmert)

        # GPU Options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model = nn.DataParallel(self.model)

    def forward(self, examples):
        train_features = [convert_example_to_features(example, self.max_seq_length, self.tokenizer)
                          for example in examples]

        # language Inputs
        input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
        input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
        segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()

        # Visual Inputs
        feats = torch.from_numpy(np.stack([f.visual_feats[0] for f in train_features])).cuda()
        pos = torch.from_numpy(np.stack([f.visual_feats[1] for f in train_features])).cuda()

        # Language Prediction
        lm_labels = torch.tensor([f.lm_label_ids for f in train_features], dtype=torch.long).cuda()

        # Visual Prediction
        obj_labels = {}
        for key in ('obj', 'attr', 'feat'):
            visn_labels = torch.from_numpy(np.stack([f.obj_labels[key][0] for f in train_features])).cuda()
            visn_mask = torch.from_numpy(np.stack([f.obj_labels[key][1] for f in train_features])).cuda()
            assert visn_labels.size(0) == visn_mask.size(0) and visn_labels.size(1) == visn_mask.size(1)
            obj_labels[key] = (visn_labels, visn_mask)

        # Joint Prediction
        matched_labels = torch.tensor([f.is_matched for f in train_features], dtype=torch.long).cuda()
        ans = torch.from_numpy(np.stack([f.ans for f in train_features])).cuda()

        """
        forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
                visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None):
        """
        loss, losses, ans_logit = self.model(
            input_ids, segment_ids, input_mask, lm_labels,
            feats, pos, obj_labels, matched_labels, ans
        )
        return loss, losses.detach().cpu(), ans_logit

    def train_batch(self, optim, batch):
        optim.zero_grad()
        loss, losses, ans_logit = self.forward(batch)
        if args.multiGPU:
            loss = loss.mean()
            losses = losses.mean(0)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
        optim.step()

        return loss.item(), losses.cpu().numpy(), ans_logit

    def valid_batch(self, batch):
        with torch.no_grad():
            loss, losses, ans_logit = self.forward(batch)
            if args.multiGPU:
                loss = loss.mean()
                losses = losses.mean(0)
        return loss.item(), losses.cpu().numpy(), ans_logit

    def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
        train_ld = train_tuple.loader

        # Optimizer
        from lxrt.optimization import BertAdam
        batch_per_epoch = len(train_ld)
        t_total = int(batch_per_epoch * args.epochs)
        warmup_ratio = 0.05
        warmup_iters = int(t_total * warmup_ratio)
        print("Batch per epoch: %d" % batch_per_epoch)
        print("Total Iters: %d" % t_total)
        print("Warm up Iters: %d" % warmup_iters)
        optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total)

        # Train
        best_eval_loss = 9595.
        for epoch in range(args.epochs):
            # Train
            self.model.train()
            total_loss = 0.
            total_losses = 0.
            uid2ans = {}
            for batch in tqdm(train_ld, total=len(train_ld)):
                loss, losses, logit = self.train_batch(optim, batch)
                total_loss += loss
                total_losses += losses

                if args.task_qa:
                    score, label = logit.max(1)
                    for datum, l in zip(batch, label.cpu().numpy()):
                        uid = datum.uid
                        ans = train_tuple.dataset.answer_table.id2ans(l)
                        uid2ans[uid] = ans

            print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch))
            losses_str = "The losses are "
            for name, loss in zip(LOSSES_NAME, total_losses):
                losses_str += "%s: %0.4f " % (name, loss / batch_per_epoch)
            print(losses_str)
            if args.task_qa:
                train_tuple.evaluator.evaluate(uid2ans, pprint=True)

            # Eval
            avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)

            # Save
            if avg_eval_loss < best_eval_loss:
                best_eval_loss = avg_eval_loss
                self.save("BEST_EVAL_LOSS")
            self.save("Epoch%02d" % (epoch+1))

    def evaluate_epoch(self, eval_tuple: DataTuple, iters: int = -1):
        self.model.eval()
        eval_ld = eval_tuple.loader
        total_loss = 0.
        total_losses = 0.
        uid2ans = {}
        for i, batch in enumerate(eval_ld):
            loss, losses, logit = self.valid_batch(batch)
            total_loss += loss
            total_losses += losses
            if args.task_qa:
                score, label = logit.max(1)
                for datum, l in zip(batch, label.cpu().numpy()):
                    uid = datum.uid
                    ans = train_tuple.dataset.answer_table.id2ans(l)
                    uid2ans[uid] = ans
            if i == iters:
                break

        print("The valid loss is %0.4f" % (total_loss / len(eval_ld)))
        losses_str = "The losses are "
        for name, loss in zip(LOSSES_NAME, total_losses / len(eval_ld)):
            losses_str += "%s: %0.4f " % (name, loss)
        print(losses_str)

        if args.task_qa:
            eval_tuple.evaluator.evaluate(uid2ans, pprint=True)

        return total_loss / len(eval_ld)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(args.output, "%s_LXRT.pth" % name))

    def load(self, path):
        print("Load BERT extractor from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)
        self.model.load_state_dict(state_dict)

    def load_lxmert(self, path):
        print("Load lxmert model from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)

        # Do not load any answer head
        for key in list(state_dict.keys()):
            if 'answer' in key:
                state_dict.pop(key)

        # Change Multi GPU to single GPU
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
        state_dict = new_state_dict

        load_keys = set(state_dict.keys())
        model_keys = set(self.model.state_dict().keys())
        print()
        print("Keys in loaded but not in model:")
        for key in sorted(load_keys.difference(model_keys)):
            print(key)
        print()
        print("Keys in model but not in loaded:")
        for key in sorted(model_keys.difference(load_keys)):
            print(key)
        print()

        self.model.load_state_dict(state_dict, strict=False)


if __name__ == "__main__":

    lxmert = LXMERT(max_seq_length=20)

    lxmert.train(train_tuple, valid_tuple)
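The 15% / 80-10-10 masking scheme used by random_word (and mirrored for region features in random_feat) can be exercised with a toy vocabulary. The sketch below is a self-contained re-implementation for illustration only; it is not the repository's tokenizer, and the helper name mask_tokens is hypothetical.

import random

def mask_tokens(tokens, vocab, mask_rate=0.15):
    """BERT-style masking: with prob mask_rate a position is selected; selected
    tokens are 80% replaced by [MASK], 10% by a random vocab token, 10% kept.
    Returns the (possibly modified) tokens and per-position labels (-1 = not predicted)."""
    tokens = list(tokens)
    labels = []
    for i, tok in enumerate(tokens):
        prob = random.random()
        if prob < mask_rate:
            prob /= mask_rate              # rescale to [0, 1) for the 80/10/10 split
            if prob < 0.8:
                tokens[i] = "[MASK]"
            elif prob < 0.9:
                tokens[i] = random.choice(list(vocab))
            labels.append(vocab.get(tok, vocab["[UNK]"]))   # predict the original token id
        else:
            labels.append(-1)              # ignored by the LM loss
    return tokens, labels

vocab = {"[UNK]": 0, "[MASK]": 1, "a": 2, "dog": 3, "on": 4, "the": 5, "grass": 6}
print(mask_tokens(["a", "dog", "on", "the", "grass"], vocab))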
lxmert/src/pretrain/qa_answer_table.py
ADDED
@@ -0,0 +1,158 @@
# coding=utf-8
# Copyleft 2019 project LXRT.

import json
import torch


class AnswerTable:
    ANS_CONVERT = {
        "a man": "man",
        "the man": "man",
        "a woman": "woman",
        "the woman": "woman",
        'one': '1',
        'two': '2',
        'three': '3',
        'four': '4',
        'five': '5',
        'six': '6',
        'seven': '7',
        'eight': '8',
        'nine': '9',
        'ten': '10',
        'grey': 'gray',
    }

    def __init__(self, dsets=None):
        self.all_ans = json.load(open("data/lxmert/all_ans.json"))
        if dsets is not None:
            dsets = set(dsets)
            # If the answer is used in the dsets
            self.anss = [ans['ans'] for ans in self.all_ans if
                         len(set(ans['dsets']) & dsets) > 0]
        else:
            self.anss = [ans['ans'] for ans in self.all_ans]
        self.ans_set = set(self.anss)

        self._id2ans_map = self.anss
        self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)}

        assert len(self._id2ans_map) == len(self._ans2id_map)
        for ans_id, ans in enumerate(self._id2ans_map):
            assert self._ans2id_map[ans] == ans_id

    def convert_ans(self, ans):
        if len(ans) == 0:
            return ""
        ans = ans.lower()
        if ans[-1] == '.':
            ans = ans[:-1].strip()
        if ans.startswith("a "):
            ans = ans[2:].strip()
        if ans.startswith("an "):
            ans = ans[3:].strip()
        if ans.startswith("the "):
            ans = ans[4:].strip()
        if ans in self.ANS_CONVERT:
            ans = self.ANS_CONVERT[ans]
        return ans

    def ans2id(self, ans):
        return self._ans2id_map[ans]

    def id2ans(self, ans_id):
        return self._id2ans_map[ans_id]

    def ans2id_map(self):
        return self._ans2id_map.copy()

    def id2ans_map(self):
        return self._id2ans_map.copy()

    def used(self, ans):
        return ans in self.ans_set

    def all_answers(self):
        return self.anss.copy()

    @property
    def num_answers(self):
        return len(self.anss)


def load_lxmert_qa(path, model, label2ans):
    """
    Load model weights from lxmert pre-training.
    The answers in the fine-tuned QA task (indicated by label2ans)
    would also be properly initialized with lxmert pre-trained
    QA heads.

    :param path: Path to lxmert snapshot.
    :param model: LXRT model instance.
    :param label2ans: The label2ans dict of fine-tuned QA datasets, like
        {0: 'cat', 1: 'dog', ...}
    :return:
    """
    print("Load QA pre-trained lxmert from %s " % path)
    loaded_state_dict = torch.load("%s_LXRT.pth" % path)
    model_state_dict = model.state_dict()

    # Handle Multi-GPU pre-training --> Single GPU fine-tuning
    for key in list(loaded_state_dict.keys()):
        loaded_state_dict[key.replace("module.", '')] = loaded_state_dict.pop(key)

    # Isolate bert model
    bert_state_dict = {}
    for key, value in loaded_state_dict.items():
        if key.startswith('bert.'):
            bert_state_dict[key] = value

    # Isolate answer head
    answer_state_dict = {}
    for key, value in loaded_state_dict.items():
        if key.startswith("answer_head."):
            answer_state_dict[key.replace('answer_head.', '')] = value

    # Do surgery on answer state dict
    ans_weight = answer_state_dict['logit_fc.3.weight']
    ans_bias = answer_state_dict['logit_fc.3.bias']
    import copy
    new_answer_weight = copy.deepcopy(model_state_dict['logit_fc.3.weight'])
    new_answer_bias = copy.deepcopy(model_state_dict['logit_fc.3.bias'])
    answer_table = AnswerTable()
    loaded = 0
    unload = 0
    if type(label2ans) is list:
        label2ans = {label: ans for label, ans in enumerate(label2ans)}
    for label, ans in label2ans.items():
        new_ans = answer_table.convert_ans(ans)
        if answer_table.used(new_ans):
            ans_id_9500 = answer_table.ans2id(new_ans)
            new_answer_weight[label] = ans_weight[ans_id_9500]
            new_answer_bias[label] = ans_bias[ans_id_9500]
            loaded += 1
        else:
            new_answer_weight[label] = 0.
            new_answer_bias[label] = 0.
            unload += 1
    print("Loaded %d answers from LXRTQA pre-training and %d not" % (loaded, unload))
    print()
    answer_state_dict['logit_fc.3.weight'] = new_answer_weight
    answer_state_dict['logit_fc.3.bias'] = new_answer_bias

    # Load Bert Weights
    bert_model_keys = set(model.lxrt_encoder.model.state_dict().keys())
    bert_loaded_keys = set(bert_state_dict.keys())
    assert len(bert_model_keys - bert_loaded_keys) == 0
    model.lxrt_encoder.model.load_state_dict(bert_state_dict, strict=False)

    # Load Answer Logic FC Weights
    model_keys = set(model.state_dict().keys())
    ans_loaded_keys = set(answer_state_dict.keys())
    assert len(ans_loaded_keys - model_keys) == 0

    model.load_state_dict(answer_state_dict, strict=False)
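A short, hedged sketch of how the answer normalization above behaves. It assumes data/lxmert/all_ans.json is present (AnswerTable.__init__ reads it) and that 'vqa' is one of the dataset tags stored in that file; both are assumptions about the data layout, not part of this file.

# Hypothetical usage sketch for AnswerTable.
table = AnswerTable(dsets=['vqa'])            # keep only answers tagged with the assumed 'vqa' dataset
print(table.num_answers)                      # size of the retained answer vocabulary
print(table.convert_ans("The man."))          # -> "man"  (article and trailing period stripped)
print(table.convert_ans("grey"))              # -> "gray" (ANS_CONVERT spelling normalization)
if table.used("man"):
    assert table.id2ans(table.ans2id("man")) == "man"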
lxmert/src/processing_image.py
ADDED
@@ -0,0 +1,147 @@
"""
coding=utf-8
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
Adapted From Facebook Inc, Detectron2

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.import copy
"""
import sys
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from lxmert.lxmert.src.vqa_utils import img_tensorize


class ResizeShortestEdge:
    def __init__(self, short_edge_length, max_size=sys.maxsize):
        """
        Args:
            short_edge_length (list[min, max])
            max_size (int): maximum allowed longest edge length.
        """
        self.interp_method = "bilinear"
        self.max_size = max_size
        self.short_edge_length = short_edge_length

    def __call__(self, imgs):
        img_augs = []
        for img in imgs:
            h, w = img.shape[:2]
            # later: provide list and randomly choose index for resize
            size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
            if size == 0:
                return img
            scale = size * 1.0 / min(h, w)
            if h < w:
                newh, neww = size, scale * w
            else:
                newh, neww = scale * h, size
            if max(newh, neww) > self.max_size:
                scale = self.max_size * 1.0 / max(newh, neww)
                newh = newh * scale
                neww = neww * scale
            neww = int(neww + 0.5)
            newh = int(newh + 0.5)

            if img.dtype == np.uint8:
                pil_image = Image.fromarray(img)
                pil_image = pil_image.resize((neww, newh), Image.BILINEAR)
                img = np.asarray(pil_image)
            else:
                img = img.permute(2, 0, 1).unsqueeze(0)  # 3, 0, 1)  # hw(c) -> nchw
                img = F.interpolate(img, (newh, neww), mode=self.interp_method, align_corners=False).squeeze(0)
            img_augs.append(img)

        return img_augs


class Preprocess:
    def __init__(self, cfg):
        self.aug = ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST)
        self.input_format = cfg.INPUT.FORMAT
        self.size_divisibility = cfg.SIZE_DIVISIBILITY
        self.pad_value = cfg.PAD_VALUE
        self.max_image_size = cfg.INPUT.MAX_SIZE_TEST
        self.device = cfg.MODEL.DEVICE
        self.pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
        self.pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
        self.normalizer = lambda x: (x - self.pixel_mean) / self.pixel_std

    def pad(self, images):
        max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
        image_sizes = [im.shape[-2:] for im in images]
        images = [
            F.pad(
                im,
                [0, max_size[-1] - size[1], 0, max_size[-2] - size[0]],
                value=self.pad_value,
            )
            for size, im in zip(image_sizes, images)
        ]

        return torch.stack(images), torch.tensor(image_sizes)

    def __call__(self, images, single_image=False):
        with torch.no_grad():
            if not isinstance(images, list):
                images = [images]
            if single_image:
                assert len(images) == 1
            for i in range(len(images)):
                if isinstance(images[i], torch.Tensor):
                    images.insert(i, images.pop(i).to(self.device).float())
                elif not isinstance(images[i], torch.Tensor):
                    images.insert(
                        i,
                        torch.as_tensor(img_tensorize(images.pop(i), input_format=self.input_format))
                        .to(self.device)
                        .float(),
                    )
            # resize smallest edge
            raw_sizes = torch.tensor([im.shape[:2] for im in images])
            images = self.aug(images)
            # transpose images and convert to torch tensors
            # images = [torch.as_tensor(i.astype("float32")).permute(2, 0, 1).to(self.device) for i in images]
            # now normalize before pad to avoid useless arithmetic
            images = [self.normalizer(x) for x in images]
            # now pad them to do the following operations
            images, sizes = self.pad(images)
            # Normalize

            if self.size_divisibility > 0:
                raise NotImplementedError()
            # pad
            scales_yx = torch.true_divide(raw_sizes, sizes)
            if single_image:
                return images[0], sizes[0], scales_yx[0]
            else:
                return images, sizes, scales_yx


def _scale_box(boxes, scale_yx):
    boxes[:, 0::2] *= scale_yx[:, 1]
    boxes[:, 1::2] *= scale_yx[:, 0]
    return boxes


def _clip_box(tensor, box_size: Tuple[int, int]):
    assert torch.isfinite(tensor).all(), "Box tensor contains infinite or NaN!"
    h, w = box_size
    tensor[:, 0].clamp_(min=0, max=w)
    tensor[:, 1].clamp_(min=0, max=h)
    tensor[:, 2].clamp_(min=0, max=w)
    tensor[:, 3].clamp_(min=0, max=h)
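A hedged sketch of how Preprocess is typically driven. frcnn_cfg is an assumed config object exposing the attributes read in __init__ (e.g. INPUT.MIN_SIZE_TEST, MODEL.PIXEL_MEAN), such as the Faster R-CNN config used by the detector code elsewhere in this repo, and the image path is only an example.

# Hypothetical usage sketch, assuming frcnn_cfg is a compatible detector config.
preprocess = Preprocess(frcnn_cfg)
images, sizes, scales_yx = preprocess("lxmert/experiments/paper/new.jpg", single_image=True)
# images: normalized, resized, padded tensor; sizes: per-image (h, w) after resizing;
# scales_yx: original/resized ratios. Boxes predicted on the resized image could then be
# mapped back to original coordinates with _scale_box, e.g.:
# boxes = _scale_box(boxes, scales_yx.unsqueeze(0))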
lxmert/src/tasks/__init__.py
ADDED
File without changes
lxmert/src/tasks/gqa.py
ADDED
@@ -0,0 +1,210 @@
# coding=utf-8
# Copyleft 2019 project LXRT.

import os
import collections

import torch
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader

from param import args
from pretrain.qa_answer_table import load_lxmert_qa
from tasks.gqa_model import GQAModel
from tasks.gqa_data import GQADataset, GQATorchDataset, GQAEvaluator


DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')


def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
    dset = GQADataset(splits)
    tset = GQATorchDataset(dset)
    evaluator = GQAEvaluator(dset)
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        drop_last=drop_last, pin_memory=True
    )

    return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)


class GQA:
    def __init__(self):
        self.train_tuple = get_tuple(
            args.train, bs=args.batch_size, shuffle=True, drop_last=True
        )
        if args.valid != "":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(
                args.valid, bs=valid_bsize,
                shuffle=False, drop_last=False
            )
        else:
            self.valid_tuple = None

        self.model = GQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                if args.mce_loss:
                    max_value, target = target.max(1)
                    loss = self.mce_loss(logit, target) * logit.size(1)
                else:
                    loss = self.bce_loss(logit, target)
                    loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        dset, loader, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        for key in list(state_dict.keys()):
            if '.module' in key:
                state_dict[key.replace('.module', '')] = state_dict.pop(key)
        self.model.load_state_dict(state_dict, strict=False)


if __name__ == "__main__":
    # Build Class
    gqa = GQA()

    # Load Model
    if args.load is not None:
        gqa.load(args.load)

    # Test or Train
    if args.test is not None:
        args.fast = args.tiny = False  # Always loading all data in test
        if 'submit' in args.test:
            gqa.predict(
                get_tuple(args.test, bs=args.batch_size,
                          shuffle=False, drop_last=False),
                dump=os.path.join(args.output, 'submit_predict.json')
            )
        if 'testdev' in args.test:
            result = gqa.evaluate(
                get_tuple('testdev', bs=args.batch_size,
                          shuffle=False, drop_last=False),
                dump=os.path.join(args.output, 'testdev_predict.json')
            )
            print(result)
    else:
        # print("Train Oracle: %0.2f" % (gqa.oracle_score(gqa.train_tuple) * 100))
        print('Splits in Train data:', gqa.train_tuple.dataset.splits)
        if gqa.valid_tuple is not None:
            print('Splits in Valid data:', gqa.valid_tuple.dataset.splits)
            print("Valid Oracle: %0.2f" % (gqa.oracle_score(gqa.valid_tuple) * 100))
        else:
            print("DO NOT USE VALIDATION")
        gqa.train(gqa.train_tuple, gqa.valid_tuple)
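A hedged sketch of the test-time path wired up in the __main__ block above, written as straight Python rather than the accompanying run scripts. The snapshot prefix is a hypothetical example, and args is assumed to have already been parsed by param.py, as in the rest of this file; load() expects the prefix of a ".pth" file written by GQA.save.

# Hypothetical equivalent of launching this script with --load and --test testdev.
gqa = GQA()
gqa.load("snap/gqa/BEST")                     # reads "snap/gqa/BEST.pth", as written by GQA.save("BEST")
quesid2ans = gqa.predict(
    get_tuple("testdev", bs=args.batch_size, shuffle=False, drop_last=False),
    dump=os.path.join(args.output, "testdev_predict.json"),
)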