diff --git a/lxmert/.gitignore b/lxmert/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..69a114a5bd5c5d0fdc922f1b1feadf7f9edcbfa0 --- /dev/null +++ b/lxmert/.gitignore @@ -0,0 +1,3 @@ +*.caffemodel +*.tsv +/snap diff --git a/lxmert/.gitmodules b/lxmert/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..ebb287a88740c4f6878cfbbe5106bfdb10b6e678 --- /dev/null +++ b/lxmert/.gitmodules @@ -0,0 +1,3 @@ +[submodule "data/nlvr2/nlvr"] + path = data/nlvr2/nlvr + url = https://github.com/lil-lab/nlvr.git diff --git a/lxmert/.ipynb_checkpoints/Untitled-checkpoint.ipynb b/lxmert/.ipynb_checkpoints/Untitled-checkpoint.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..363fcab7ed6e9634e198cf5555ceb88932c9a245 --- /dev/null +++ b/lxmert/.ipynb_checkpoints/Untitled-checkpoint.ipynb @@ -0,0 +1,6 @@ +{ + "cells": [], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/lxmert/LICENSE b/lxmert/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..52df82d3566108ace6f716c15338075da4dc1482 --- /dev/null +++ b/lxmert/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Hao Tan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/lxmert/__init__.py b/lxmert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lxmert/__pycache__/__init__.cpython-38.pyc b/lxmert/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a3e0f861f8b3628633530bf47ef86b9bd7df6fc Binary files /dev/null and b/lxmert/__pycache__/__init__.cpython-38.pyc differ diff --git a/lxmert/experiments/paper/COCO_val2014_000000127510/COCO_val2014_000000127510.jpg b/lxmert/experiments/paper/COCO_val2014_000000127510/COCO_val2014_000000127510.jpg new file mode 100644 index 0000000000000000000000000000000000000000..364e37b55a51287d4fc696d3bf0b26ec26ff35c4 Binary files /dev/null and b/lxmert/experiments/paper/COCO_val2014_000000127510/COCO_val2014_000000127510.jpg differ diff --git a/lxmert/experiments/paper/COCO_val2014_000000185590/COCO_val2014_000000185590.jpg b/lxmert/experiments/paper/COCO_val2014_000000185590/COCO_val2014_000000185590.jpg new file mode 100644 index 0000000000000000000000000000000000000000..72478535097ed8a48d03dc46d30fc9750c3b4df9 Binary files /dev/null and b/lxmert/experiments/paper/COCO_val2014_000000185590/COCO_val2014_000000185590.jpg differ diff --git a/lxmert/experiments/paper/COCO_val2014_000000200717/COCO_val2014_000000200717.jpg b/lxmert/experiments/paper/COCO_val2014_000000200717/COCO_val2014_000000200717.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ae9884bbb367a0ad3ec655c414d55edae69553c2 Binary files /dev/null and b/lxmert/experiments/paper/COCO_val2014_000000200717/COCO_val2014_000000200717.jpg differ diff --git a/lxmert/experiments/paper/COCO_val2014_000000324266/COCO_val2014_000000324266.jpg b/lxmert/experiments/paper/COCO_val2014_000000324266/COCO_val2014_000000324266.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e4444850db8e27eade69ce847504bf025c08d119 Binary files /dev/null and b/lxmert/experiments/paper/COCO_val2014_000000324266/COCO_val2014_000000324266.jpg differ diff --git a/lxmert/experiments/paper/new.jpg b/lxmert/experiments/paper/new.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6e727b94fcbe1a1379ab4fa60dba00290ec83090 Binary files /dev/null and b/lxmert/experiments/paper/new.jpg differ diff --git a/lxmert/perturbation.py b/lxmert/perturbation.py new file mode 100644 index 0000000000000000000000000000000000000000..62e0dc4e62f88656beb45f5cafa431fcbebc8d2e --- /dev/null +++ b/lxmert/perturbation.py @@ -0,0 +1,254 @@ +from lxmert.lxmert.src.tasks import vqa_data +from lxmert.lxmert.src.modeling_frcnn import GeneralizedRCNN +import lxmert.lxmert.src.vqa_utils as utils +from lxmert.lxmert.src.processing_image import Preprocess +from transformers import LxmertTokenizer +from lxmert.lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering +from lxmert.lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP +from tqdm import tqdm +from lxmert.lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation +import random +from lxmert.lxmert.src.param import args + +OBJ_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt" +ATTR_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt" +VQA_URL = 
"https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json" + +class ModelPert: + def __init__(self, COCO_val_path, use_lrp=False): + self.COCO_VAL_PATH = COCO_val_path + self.vqa_answers = utils.get_data(VQA_URL) + + # load models and model components + self.frcnn_cfg = utils.Config.from_pretrained("unc-nlp/frcnn-vg-finetuned") + self.frcnn_cfg.MODEL.DEVICE = "cuda" + + self.frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg) + + self.image_preprocess = Preprocess(self.frcnn_cfg) + + self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased") + + if use_lrp: + self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda") + else: + self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda") + + self.lxmert_vqa.eval() + self.model = self.lxmert_vqa + + self.vqa_dataset = vqa_data.VQADataset(splits="valid") + + self.pert_steps = [0, 0.25, 0.5, 0.75, 0.8, 0.85, 0.9, 0.95, 1] + self.pert_acc = [0] * len(self.pert_steps) + + def forward(self, item): + image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg' + self.image_file_path = image_file_path + self.image_id = item['img_id'] + # run frcnn + images, sizes, scales_yx = self.image_preprocess(image_file_path) + output_dict = self.frcnn( + images, + sizes, + scales_yx=scales_yx, + padding="max_detections", + max_detections= self.frcnn_cfg.max_detections, + return_tensors="pt" + ) + inputs = self.lxmert_tokenizer( + item['sent'], + truncation=True, + return_token_type_ids=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt" + ) + self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten()) + self.text_len = len(self.question_tokens) + # Very important that the boxes are normalized + normalized_boxes = output_dict.get("normalized_boxes") + features = output_dict.get("roi_features") + self.image_boxes_len = features.shape[1] + self.bboxes = output_dict.get("boxes") + self.output = self.lxmert_vqa( + input_ids=inputs.input_ids.to("cuda"), + attention_mask=inputs.attention_mask.to("cuda"), + visual_feats=features.to("cuda"), + visual_pos=normalized_boxes.to("cuda"), + token_type_ids=inputs.token_type_ids.to("cuda"), + return_dict=True, + output_attentions=False, + ) + return self.output + + def perturbation_image(self, item, cam_image, cam_text, is_positive_pert=False): + if is_positive_pert: + cam_image = cam_image * (-1) + image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg' + # run frcnn + images, sizes, scales_yx = self.image_preprocess(image_file_path) + output_dict = self.frcnn( + images, + sizes, + scales_yx=scales_yx, + padding="max_detections", + max_detections=self.frcnn_cfg.max_detections, + return_tensors="pt" + ) + inputs = self.lxmert_tokenizer( + item['sent'], + truncation=True, + return_token_type_ids=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt" + ) + # Very important that the boxes are normalized + normalized_boxes = output_dict.get("normalized_boxes") + features = output_dict.get("roi_features") + for step_idx, step in enumerate(self.pert_steps): + # find top step boxes + curr_num_boxes = int((1 - step) * self.image_boxes_len) + _, top_bboxes_indices = cam_image.topk(k=curr_num_boxes, dim=-1) + top_bboxes_indices = top_bboxes_indices.cpu().data.numpy() + + curr_features = features[:, top_bboxes_indices, :] + curr_pos = 
normalized_boxes[:, top_bboxes_indices, :] + + output = self.lxmert_vqa( + input_ids=inputs.input_ids.to("cuda"), + attention_mask=inputs.attention_mask.to("cuda"), + visual_feats=curr_features.to("cuda"), + visual_pos=curr_pos.to("cuda"), + token_type_ids=inputs.token_type_ids.to("cuda"), + return_dict=True, + output_attentions=False, + ) + + answer = self.vqa_answers[output.question_answering_score.argmax()] + accuracy = item["label"].get(answer, 0) + self.pert_acc[step_idx] += accuracy + + return self.pert_acc + + def perturbation_text(self, item, cam_image, cam_text, is_positive_pert=False): + if is_positive_pert: + cam_text = cam_text * (-1) + image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg' + # run frcnn + images, sizes, scales_yx = self.image_preprocess(image_file_path) + output_dict = self.frcnn( + images, + sizes, + scales_yx=scales_yx, + padding="max_detections", + max_detections=self.frcnn_cfg.max_detections, + return_tensors="pt" + ) + inputs = self.lxmert_tokenizer( + item['sent'], + truncation=True, + return_token_type_ids=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt" + ) + # Very important that the boxes are normalized + normalized_boxes = output_dict.get("normalized_boxes") + features = output_dict.get("roi_features") + for step_idx, step in enumerate(self.pert_steps): + # we must keep the [CLS] token in order to have the classification + # we also keep the [SEP] token + cam_pure_text = cam_text[1:-1] + text_len = cam_pure_text.shape[0] + # find top step tokens, without the [CLS] token and the [SEP] token + curr_num_tokens = int((1 - step) * text_len) + _, top_bboxes_indices = cam_pure_text.topk(k=curr_num_tokens, dim=-1) + top_bboxes_indices = top_bboxes_indices.cpu().data.numpy() + + # add back [CLS], [SEP] tokens + top_bboxes_indices = [0, cam_text.shape[0] - 1] +\ + [top_bboxes_indices[i] + 1 for i in range(len(top_bboxes_indices))] + # text tokens must be sorted for positional embedding to work + top_bboxes_indices = sorted(top_bboxes_indices) + + curr_input_ids = inputs.input_ids[:, top_bboxes_indices] + curr_attention_mask = inputs.attention_mask[:, top_bboxes_indices] + curr_token_ids = inputs.token_type_ids[:, top_bboxes_indices] + + output = self.lxmert_vqa( + input_ids=curr_input_ids.to("cuda"), + attention_mask=curr_attention_mask.to("cuda"), + visual_feats=features.to("cuda"), + visual_pos=normalized_boxes.to("cuda"), + token_type_ids=curr_token_ids.to("cuda"), + return_dict=True, + output_attentions=False, + ) + + answer = self.vqa_answers[output.question_answering_score.argmax()] + accuracy = item["label"].get(answer, 0) + self.pert_acc[step_idx] += accuracy + + return self.pert_acc + +def main(args): + model_pert = ModelPert(args.COCO_path, use_lrp=True) + ours = GeneratorOurs(model_pert) + baselines = GeneratorBaselines(model_pert) + oursNoAggAblation = GeneratorOursAblationNoAggregation(model_pert) + vqa_dataset = vqa_data.VQADataset(splits="valid") + vqa_answers = utils.get_data(VQA_URL) + method_name = args.method + + items = vqa_dataset.data + random.seed(1234) + r = list(range(len(items))) + random.shuffle(r) + pert_samples_indices = r[:args.num_samples] + iterator = tqdm([vqa_dataset.data[i] for i in pert_samples_indices]) + + test_type = "positive" if args.is_positive_pert else "negative" + modality = "text" if args.is_text_pert else "image" + print("running {0} pert test for {1} modality with method {2}".format(test_type, modality, args.method)) + + for index, item in enumerate(iterator): + if
method_name == 'transformer_att': + R_t_t, R_t_i = baselines.generate_transformer_attr(item) + elif method_name == 'attn_gradcam': + R_t_t, R_t_i = baselines.generate_attn_gradcam(item) + elif method_name == 'partial_lrp': + R_t_t, R_t_i = baselines.generate_partial_lrp(item) + elif method_name == 'raw_attn': + R_t_t, R_t_i = baselines.generate_raw_attn(item) + elif method_name == 'rollout': + R_t_t, R_t_i = baselines.generate_rollout(item) + elif method_name == "ours_with_lrp_no_normalization": + R_t_t, R_t_i = ours.generate_ours(item, normalize_self_attention=False) + elif method_name == "ours_no_lrp": + R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False) + elif method_name == "ours_no_lrp_no_norm": + R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False, normalize_self_attention=False) + elif method_name == "ours_with_lrp": + R_t_t, R_t_i = ours.generate_ours(item, use_lrp=True) + elif method_name == "ablation_no_self_in_10": + R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False, apply_self_in_rule_10=False) + elif method_name == "ablation_no_aggregation": + R_t_t, R_t_i = oursNoAggAblation.generate_ours_no_agg(item, use_lrp=False, normalize_self_attention=False) + else: + print("Please enter a valid method name") + return + cam_image = R_t_i[0] + cam_text = R_t_t[0] + cam_image = (cam_image - cam_image.min()) / (cam_image.max() - cam_image.min()) + cam_text = (cam_text - cam_text.min()) / (cam_text.max() - cam_text.min()) + if args.is_text_pert: + curr_pert_result = model_pert.perturbation_text(item, cam_image, cam_text, args.is_positive_pert) + else: + curr_pert_result = model_pert.perturbation_image(item, cam_image, cam_text, args.is_positive_pert) + curr_pert_result = [round(res / (index+1) * 100, 2) for res in curr_pert_result] + iterator.set_description("Acc: {}".format(curr_pert_result)) + +if __name__ == "__main__": + main(args) \ No newline at end of file diff --git a/lxmert/requirements.txt b/lxmert/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..355a1d1212ff506b45fd8c9e5f1cd6bcdf93591a --- /dev/null +++ b/lxmert/requirements.txt @@ -0,0 +1,107 @@ +argon2-cffi==20.1.0 +async-generator==1.10 +attrs==20.3.0 +backcall==0.2.0 +bleach==3.3.0 +certifi==2020.12.5 +cffi==1.14.5 +chardet==3.0.4 +click==7.1.2 +cycler==0.10.0 +Cython==0.29.22 +dataclasses==0.6 +decorator==4.4.2 +defusedxml==0.6.0 +demjson==2.2.4 +editdistance==0.5.3 +einops==0.3.0 +entrypoints==0.3 +fasttext==0.9.1 +filelock==3.0.12 +future==0.18.2 +gitdb==4.0.5 +GitPython==3.1.0 +idna==2.10 +imageio==2.9.0 +importlib-metadata==3.4.0 +ipykernel==5.4.3 +ipython==7.20.0 +ipython-genutils==0.2.0 +ipywidgets==7.6.3 +jedi==0.18.0 +Jinja2==2.11.3 +joblib==0.17.0 +jsonschema==3.2.0 +jupyter-client==6.1.11 +jupyter-console==6.2.0 +jupyter-core==4.7.1 +jupyterlab-pygments==0.1.2 +jupyterlab-widgets==1.0.0 +kiwisolver==1.3.1 +lmdb==0.98 +MarkupSafe==1.1.1 +matplotlib==3.3.4 +mistune==0.8.4 +nbclient==0.5.2 +nbconvert==6.0.7 +nbformat==5.1.2 +nest-asyncio==1.5.1 +networkx==2.4 +nltk==3.4.5 +notebook==6.2.0 +numpy==1.19.2 +omegaconf==2.0.1rc4 +opencv-python==4.5.1.48 +packaging==20.9 +pandocfilters==1.4.3 +parso==0.8.1 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==8.1.2 +prometheus-client==0.9.0 +prompt-toolkit==3.0.16 +protobuf==3.15.6 +ptyprocess==0.7.0 +pybind11==2.6.2 +pycocotools==2.0.2 +pycparser==2.20 +pyparsing==2.4.7 +pyrsistent==0.17.3 +python-dateutil==2.8.1 +PyWavelets==1.1.1 +PyYAML==5.4.1 +pyzmq==22.0.3 +qtconsole==5.0.2 +QtPy==1.9.0 +regex==2020.11.13 +requests==2.23.0 
+sacremoses==0.0.43 +scikit-image==0.17.2 +scikit-learn==0.23.2 +scipy==1.6.1 +Send2Trash==1.5.0 +sentencepiece==0.1.91 +six==1.15.0 +sklearn==0.0 +smmap==3.0.5 +termcolor==1.1.0 +terminado==0.9.2 +testpath==0.4.4 +threadpoolctl==2.1.0 +tifffile==2021.2.1 +tokenizers==0.9.3 +torch==1.7.1 +torchtext==0.5.0 +torchvision==0.8.2 +tornado==6.1 +tqdm==4.51.0 +traitlets==5.0.5 +transformers==3.5.1 +typing-extensions==3.7.4.3 +urllib3==1.25.11 +utils==1.0.1 +wcwidth==0.2.5 +webencodings==0.5.1 +wget==3.2 +widgetsnbextension==3.5.1 +zipp==3.4.0 diff --git a/lxmert/run/README.md b/lxmert/run/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fdbbebae148e14274ac988fbaf6d60c67d8acfed --- /dev/null +++ b/lxmert/run/README.md @@ -0,0 +1,49 @@ +# Running Script Arguments + +``` +Data Splits: + --train [str,str,...]: use the splits (separated by comma) in training. + --valid [str,str,...]: use the splits (separated by comma) in validation. + --test [str,str,...]: use the splits (separated by comma) in testing. +Model Architecture: + --llayers [int]: number of layers in language encoder. + --xlayers [int]: number of layers in cross-modality encoder. + --rlayers [int]: number of layers in object relationship encoder. +Load Weights: + --load [str='path/to/saved_model']: load fine-tuned model path/to/saved_model.pth. + --loadLXMERT [str='path/to/saved_model']: load pre-trained model without answer heads from path/to/saved_model_LXRT.pth. + --loadLXMERTQA [str='path/to/saved_model']: load pre-trained model with answer head path/to/saved_model_LXRT.pth. + --fromScratch: If none of the above loading parameters are set, the default mode would + load the pre-trained BERT weights. + As we promised to EMNLP reviewers, the language encoder would be re-initialized with this one-line argument to test the performance without BERT weights. +Training Hyper Parameters: + --batchSize [int]: batch size. + --optim [str]: optimizers. + --lr [float]: peak learning rate. + --epochs [int]: training epochs. +Debugging: + --tiny: Load 512 images for each data split. (Note: number of images might be changed due to dataset specification) + --fast: Load 5000 images for each data split. (Note: number of images might be changed due to dataset specification) +``` + +# Pre-training-Specific Arguments +``` +Pre-training Tasks: + --taskMaskLM: use the masked language model task. + --taskObjPredict: use the masked object prediction task. + --taskMatched: use the cross-modality matched task. + --taskQA: use the image QA task. +Visual Pre-training Losses (Tasks): + --visualLosses [str,str,...]: The sub-tasks in pre-training visual modality. Each one is from 'obj,attr,feat'. + obj: detected-object-label classification. + attr: detected-object-attribute classification. + feat: RoI-feature regression. +Mask Rate in Pre-training: + --wordMaskRate [float]: The prob of masking a word. + --objMaskRate [float]: The prob of masking an object. +Initialization: + --fromScratch: The default mode would load the pre-trained BERT weights into the model. + As we promised to EMNLP reviewers, this option would re-initialize the language encoder. +``` + + diff --git a/lxmert/run/gqa_finetune.bash b/lxmert/run/gqa_finetune.bash new file mode 100644 index 0000000000000000000000000000000000000000..bdcb2d3f5c3eabdebccdf88fc8a0acdef80f6db8 --- /dev/null +++ b/lxmert/run/gqa_finetune.bash @@ -0,0 +1,17 @@ +# The name of this experiment. +name=$2 + +# Save logs and models under snap/gqa; make backup. 
+output=snap/gqa/$name +mkdir -p $output/src +cp -r src/* $output/src/ +cp $0 $output/run.bash + +# See Readme.md for option details. +CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \ + python src/tasks/gqa.py \ + --train train,valid --valid testdev \ + --llayers 9 --xlayers 5 --rlayers 5 \ + --loadLXMERTQA snap/pretrained/model \ + --batchSize 32 --optim bert --lr 1e-5 --epochs 4 \ + --tqdm --output $output ${@:3} diff --git a/lxmert/run/gqa_test.bash b/lxmert/run/gqa_test.bash new file mode 100644 index 0000000000000000000000000000000000000000..b42dd8fcd196475a92a38074f770a3dfba08f82c --- /dev/null +++ b/lxmert/run/gqa_test.bash @@ -0,0 +1,15 @@ +# The name of this experiment. +name=$2 + +# Save logs and models under snap/gqa; make backup. +output=snap/gqa/$name +mkdir -p $output/src +cp -r src/* $output/src/ +cp $0 $output/run.bash + +# See Readme.md for option details. +CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \ + python src/tasks/gqa.py \ + --tiny --train train --valid "" \ + --llayers 9 --xlayers 5 --rlayers 5 \ + --tqdm --output $output ${@:3} diff --git a/lxmert/run/lxmert_pretrain.bash b/lxmert/run/lxmert_pretrain.bash new file mode 100644 index 0000000000000000000000000000000000000000..3147ff0e84cb3e8dc96f04df5e1cde81e7386e58 --- /dev/null +++ b/lxmert/run/lxmert_pretrain.bash @@ -0,0 +1,21 @@ +# The name of experiment +name=lxmert + +# Create dirs and make backup +output=snap/pretrain/$name +mkdir -p $output/src +cp -r src/* $output/src/ +cp $0 $output/run.bash + +# Pre-training +CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \ + python src/pretrain/lxmert_pretrain.py \ + --taskMaskLM --taskObjPredict --taskMatched --taskQA \ + --visualLosses obj,attr,feat \ + --wordMaskRate 0.15 --objMaskRate 0.15 \ + --train mscoco_train,mscoco_nominival,vgnococo --valid mscoco_minival \ + --llayers 9 --xlayers 5 --rlayers 5 \ + --fromScratch \ + --batchSize 256 --optim bert --lr 1e-4 --epochs 20 \ + --tqdm --output $output ${@:2} + diff --git a/lxmert/run/nlvr2_finetune.bash b/lxmert/run/nlvr2_finetune.bash new file mode 100644 index 0000000000000000000000000000000000000000..6372564aa55599554608b8438a99fe1024679a20 --- /dev/null +++ b/lxmert/run/nlvr2_finetune.bash @@ -0,0 +1,18 @@ +# The name of this experiment. +name=$2 + +# Save logs and models under snap/nlvr2; Make backup. +output=snap/nlvr2/$name +mkdir -p $output/src +cp -r src/* $output/src/ +cp $0 $output/run.bash + +# See run/Readme.md for option details. +CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \ + python src/tasks/nlvr2.py \ + --train train --valid valid \ + --llayers 9 --xlayers 5 --rlayers 5 \ + --loadLXMERT snap/pretrained/model \ + --batchSize 32 --optim bert --lr 5e-5 --epochs 4 \ + --tqdm --output $output ${@:3} + diff --git a/lxmert/run/nlvr2_test.bash b/lxmert/run/nlvr2_test.bash new file mode 100644 index 0000000000000000000000000000000000000000..7aa951d25278d7709937926e235a793ab6a9b525 --- /dev/null +++ b/lxmert/run/nlvr2_test.bash @@ -0,0 +1,14 @@ +# The name of this experiment. +name=$2 + +# Save logs and models under snap/nlvr2; make backup. +output=snap/nlvr2/$name +mkdir -p $output/src +cp -r src/* $output/src/ +cp $0 $output/run.bash + +# See Readme.md for option details. 
+CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \ + python src/tasks/nlvr2.py \ + --tiny --llayers 9 --xlayers 5 --rlayers 5 \ + --tqdm --output $output ${@:3} diff --git a/lxmert/run/vqa_finetune.bash b/lxmert/run/vqa_finetune.bash new file mode 100644 index 0000000000000000000000000000000000000000..09c47e800002cc4d2c81170c52d061dc40418e3f --- /dev/null +++ b/lxmert/run/vqa_finetune.bash @@ -0,0 +1,17 @@ +# The name of this experiment. +name=$2 + +# Save logs and models under snap/vqa; make backup. +output=snap/vqa/$name +mkdir -p $output/src +cp -r src/* $output/src/ +cp $0 $output/run.bash + +# See Readme.md for option details. +CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \ + python src/tasks/vqa.py \ + --train train,nominival --valid minival \ + --llayers 9 --xlayers 5 --rlayers 5 \ + --loadLXMERTQA snap/pretrained/model \ + --batchSize 32 --optim bert --lr 5e-5 --epochs 4 \ + --tqdm --output $output ${@:3} diff --git a/lxmert/run/vqa_test.bash b/lxmert/run/vqa_test.bash new file mode 100644 index 0000000000000000000000000000000000000000..f73642443d2f35d20a3a6422fa2db6169e872ac0 --- /dev/null +++ b/lxmert/run/vqa_test.bash @@ -0,0 +1,16 @@ +# The name of this experiment. +name=$2 + +# Save logs and models under snap/vqa; make backup. +output=snap/vqa/$name +mkdir -p $output/src +cp -r src/* $output/src/ +cp $0 $output/run.bash + +# See Readme.md for option details. +CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \ + python src/tasks/vqa.py \ + --tiny --train train --valid "" \ + --llayers 9 --xlayers 5 --rlayers 5 \ + --batchSize 32 --optim bert --lr 5e-5 --epochs 4 \ + --tqdm --output $output ${@:3} diff --git a/lxmert/src/.ipynb_checkpoints/Untitled-checkpoint.ipynb b/lxmert/src/.ipynb_checkpoints/Untitled-checkpoint.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..74b1c4a3db35dad85a3a705c35dd76b7652fe829 --- /dev/null +++ b/lxmert/src/.ipynb_checkpoints/Untitled-checkpoint.ipynb @@ -0,0 +1,81 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "id": "loose-wrong", + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'src'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mlxmert_lrp\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLxmertForQuestionAnswering\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mLxmertForQuestionAnsweringLRP\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtasks\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mvqa_data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodeling_frcnn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mGeneralizedRCNN\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvqa_utils\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m 
\u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocessing_image\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPreprocess\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/media/data2/hila_chefer/lxmert/lxmert/src/lxmert_lrp.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCrossEntropyLoss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSmoothL1Loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 28\u001b[0m from transformers.file_utils import (\n\u001b[1;32m 29\u001b[0m \u001b[0mModelOutput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'src'" + ] + } + ], + "source": [ + "from lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP\n", + "from src.tasks import vqa_data\n", + "from src.modeling_frcnn import GeneralizedRCNN\n", + "import src.vqa_utils as utils\n", + "from src.processing_image import Preprocess\n", + "from transformers import LxmertTokenizer\n", + "from src.huggingface_lxmert import LxmertForQuestionAnswering\n", + "\n", + "from tqdm import tqdm\n", + "from src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines\n", + "import random\n", + "import cv2\n", + "\n", + "COCO_VAL_PATH = '/media/data2/hila_chefer/env_MMF/datasets/coco/subset_val/images/val2014/'\n", + "\n", + "OBJ_URL = \"https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt\"\n", + "ATTR_URL = \"https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt\"\n", + "VQA_URL = \"https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "emerging-trace", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "royal-small", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/lxmert/src/ExplanationGenerator.py b/lxmert/src/ExplanationGenerator.py new file mode 100644 index 0000000000000000000000000000000000000000..9c7351a5e095a64abe39e7a22defe81fee770273 --- /dev/null +++ b/lxmert/src/ExplanationGenerator.py @@ -0,0 +1,665 @@ +import numpy as np +import torch +import copy + + +def compute_rollout_attention(all_layer_matrices, start_layer=0): + # adding residual consideration + num_tokens = all_layer_matrices[0].shape[1] + eye = 
torch.eye(num_tokens).to(all_layer_matrices[0].device) + all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))] + matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True) + for i in range(len(all_layer_matrices))] + joint_attention = matrices_aug[start_layer] + for i in range(start_layer + 1, len(matrices_aug)): + joint_attention = matrices_aug[i].matmul(joint_attention) + return joint_attention + + +# rule 5 from paper +def avg_heads(cam, grad): + cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]) + grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1]) + cam = grad * cam + cam = cam.clamp(min=0).mean(dim=0) + return cam + + +# rules 6 + 7 from paper +def apply_self_attention_rules(R_ss, R_sq, cam_ss): + R_sq_addition = torch.matmul(cam_ss, R_sq) + R_ss_addition = torch.matmul(cam_ss, R_ss) + return R_ss_addition, R_sq_addition + + +# rules 10 + 11 from paper +def apply_mm_attention_rules(R_ss, R_qq, R_qs, cam_sq, apply_normalization=True, apply_self_in_rule_10=True): + R_ss_normalized = R_ss + R_qq_normalized = R_qq + if apply_normalization: + R_ss_normalized = handle_residual(R_ss) + R_qq_normalized = handle_residual(R_qq) + R_sq_addition = torch.matmul(R_ss_normalized.t(), torch.matmul(cam_sq, R_qq_normalized)) + if not apply_self_in_rule_10: + R_sq_addition = cam_sq + R_ss_addition = torch.matmul(cam_sq, R_qs) + return R_sq_addition, R_ss_addition + + +# normalization- eq. 8+9 +def handle_residual(orig_self_attention): + self_attention = orig_self_attention.clone() + diag_idx = range(self_attention.shape[-1]) + # computing R hat + self_attention -= torch.eye(self_attention.shape[-1]).to(self_attention.device) + assert self_attention[diag_idx, diag_idx].min() >= 0 + # normalizing R hat + self_attention = self_attention / self_attention.sum(dim=-1, keepdim=True) + self_attention += torch.eye(self_attention.shape[-1]).to(self_attention.device) + return self_attention + + +class GeneratorOurs: + def __init__(self, model_usage, save_visualization=False): + self.model_usage = model_usage + self.save_visualization = save_visualization + + def handle_self_attention_lang(self, blocks): + for blk in blocks: + grad = blk.attention.self.get_attn_gradients().detach() + if self.use_lrp: + cam = blk.attention.self.get_attn_cam().detach() + else: + cam = blk.attention.self.get_attn().detach() + cam = avg_heads(cam, grad) + R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam) + self.R_t_t += R_t_t_add + self.R_t_i += R_t_i_add + + def handle_self_attention_image(self, blocks): + for blk in blocks: + grad = blk.attention.self.get_attn_gradients().detach() + if self.use_lrp: + cam = blk.attention.self.get_attn_cam().detach() + else: + cam = blk.attention.self.get_attn().detach() + cam = avg_heads(cam, grad) + R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam) + self.R_i_i += R_i_i_add + self.R_i_t += R_i_t_add + + def handle_co_attn_self_lang(self, block): + grad = block.lang_self_att.self.get_attn_gradients().detach() + if self.use_lrp: + cam = block.lang_self_att.self.get_attn_cam().detach() + else: + cam = block.lang_self_att.self.get_attn().detach() + cam = avg_heads(cam, grad) + R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam) + self.R_t_t += R_t_t_add + self.R_t_i += R_t_i_add + + def handle_co_attn_self_image(self, block): + grad = block.visn_self_att.self.get_attn_gradients().detach() + if self.use_lrp: + cam = 
block.visn_self_att.self.get_attn_cam().detach() + else: + cam = block.visn_self_att.self.get_attn().detach() + cam = avg_heads(cam, grad) + R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam) + self.R_i_i += R_i_i_add + self.R_i_t += R_i_t_add + + def handle_co_attn_lang(self, block): + if self.use_lrp: + cam_t_i = block.visual_attention.att.get_attn_cam().detach() + else: + cam_t_i = block.visual_attention.att.get_attn().detach() + grad_t_i = block.visual_attention.att.get_attn_gradients().detach() + cam_t_i = avg_heads(cam_t_i, grad_t_i) + R_t_i_addition, R_t_t_addition = apply_mm_attention_rules(self.R_t_t, self.R_i_i, self.R_i_t, cam_t_i, + apply_normalization=self.normalize_self_attention, + apply_self_in_rule_10=self.apply_self_in_rule_10) + return R_t_i_addition, R_t_t_addition + + def handle_co_attn_image(self, block): + if self.use_lrp: + cam_i_t = block.visual_attention_copy.att.get_attn_cam().detach() + else: + cam_i_t = block.visual_attention_copy.att.get_attn().detach() + grad_i_t = block.visual_attention_copy.att.get_attn_gradients().detach() + cam_i_t = avg_heads(cam_i_t, grad_i_t) + R_i_t_addition, R_i_i_addition = apply_mm_attention_rules(self.R_i_i, self.R_t_t, self.R_t_i, cam_i_t, + apply_normalization=self.normalize_self_attention, + apply_self_in_rule_10=self.apply_self_in_rule_10) + return R_i_t_addition, R_i_i_addition + + def generate_ours(self, input, index=None, use_lrp=True, normalize_self_attention=True, apply_self_in_rule_10=True, + method_name="ours"): + self.use_lrp = use_lrp + self.normalize_self_attention = normalize_self_attention + self.apply_self_in_rule_10 = apply_self_in_rule_10 + kwargs = {"alpha": 1} + output = self.model_usage.forward(input).question_answering_score + model = self.model_usage.model + + # initialize relevancy matrices + text_tokens = self.model_usage.text_len + image_bboxes = self.model_usage.image_boxes_len + + # text self attention matrix + self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device) + # image self attention matrix + self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device) + # impact of images on text + self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device) + # impact of text on images + self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device) + + if index is None: + index = np.argmax(output.cpu().data.numpy(), axis=-1) + + one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) + one_hot[0, index] = 1 + one_hot_vector = one_hot + one_hot = torch.from_numpy(one_hot).requires_grad_(True) + one_hot = torch.sum(one_hot.cuda() * output) + + model.zero_grad() + one_hot.backward(retain_graph=True) + if self.use_lrp: + model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs) + + # language self attention + blocks = model.lxmert.encoder.layer + self.handle_self_attention_lang(blocks) + + # image self attention + blocks = model.lxmert.encoder.r_layers + self.handle_self_attention_image(blocks) + + # cross attn layers + blocks = model.lxmert.encoder.x_layers + for i, blk in enumerate(blocks): + # in the last cross attention module, only the text cross modal + # attention has an impact on the CLS token, since it's the first + # token in the language tokens + if i == len(blocks) - 1: + break + # cross attn- first for language then for image + R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk) + R_i_t_addition, R_i_i_addition = self.handle_co_attn_image(blk) + + self.R_t_i += R_t_i_addition + self.R_t_t += R_t_t_addition + 
self.R_i_t += R_i_t_addition + self.R_i_i += R_i_i_addition + + # language self attention + self.handle_co_attn_self_lang(blk) + + # image self attention + self.handle_co_attn_self_image(blk) + + # take care of last cross attention layer- only text + blk = model.lxmert.encoder.x_layers[-1] + # cross attn- first for language then for image + R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk) + self.R_t_i += R_t_i_addition + self.R_t_t += R_t_t_addition + + # language self attention + self.handle_co_attn_self_lang(blk) + + # disregard the [CLS] token itself + self.R_t_t[0, 0] = 0 + return self.R_t_t, self.R_t_i + + +class GeneratorOursAblationNoAggregation: + def __init__(self, model_usage, save_visualization=False): + self.model_usage = model_usage + self.save_visualization = save_visualization + + def handle_self_attention_lang(self, blocks): + for blk in blocks: + grad = blk.attention.self.get_attn_gradients().detach() + if self.use_lrp: + cam = blk.attention.self.get_attn_cam().detach() + else: + cam = blk.attention.self.get_attn().detach() + cam = avg_heads(cam, grad) + R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam) + self.R_t_t = R_t_t_add + self.R_t_i = R_t_i_add + + def handle_self_attention_image(self, blocks): + for blk in blocks: + grad = blk.attention.self.get_attn_gradients().detach() + if self.use_lrp: + cam = blk.attention.self.get_attn_cam().detach() + else: + cam = blk.attention.self.get_attn().detach() + cam = avg_heads(cam, grad) + R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam) + self.R_i_i = R_i_i_add + self.R_i_t = R_i_t_add + + def handle_co_attn_self_lang(self, block): + grad = block.lang_self_att.self.get_attn_gradients().detach() + if self.use_lrp: + cam = block.lang_self_att.self.get_attn_cam().detach() + else: + cam = block.lang_self_att.self.get_attn().detach() + cam = avg_heads(cam, grad) + R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam) + self.R_t_t = R_t_t_add + self.R_t_i = R_t_i_add + + def handle_co_attn_self_image(self, block): + grad = block.visn_self_att.self.get_attn_gradients().detach() + if self.use_lrp: + cam = block.visn_self_att.self.get_attn_cam().detach() + else: + cam = block.visn_self_att.self.get_attn().detach() + cam = avg_heads(cam, grad) + R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam) + self.R_i_i = R_i_i_add + self.R_i_t = R_i_t_add + + def handle_co_attn_lang(self, block): + if self.use_lrp: + cam_t_i = block.visual_attention.att.get_attn_cam().detach() + else: + cam_t_i = block.visual_attention.att.get_attn().detach() + grad_t_i = block.visual_attention.att.get_attn_gradients().detach() + cam_t_i = avg_heads(cam_t_i, grad_t_i) + R_t_i_addition, R_t_t_addition = apply_mm_attention_rules(self.R_t_t, self.R_i_i, self.R_i_t, cam_t_i, + apply_normalization=self.normalize_self_attention) + return R_t_i_addition, R_t_t_addition + + def handle_co_attn_image(self, block): + if self.use_lrp: + cam_i_t = block.visual_attention_copy.att.get_attn_cam().detach() + else: + cam_i_t = block.visual_attention_copy.att.get_attn().detach() + grad_i_t = block.visual_attention_copy.att.get_attn_gradients().detach() + cam_i_t = avg_heads(cam_i_t, grad_i_t) + R_i_t_addition, R_i_i_addition = apply_mm_attention_rules(self.R_i_i, self.R_t_t, self.R_t_i, cam_i_t, + apply_normalization=self.normalize_self_attention) + return R_i_t_addition, R_i_i_addition + + def generate_ours_no_agg(self, input, index=None, use_lrp=False, 
normalize_self_attention=True, + method_name="ours_no_agg"): + self.use_lrp = use_lrp + self.normalize_self_attention = normalize_self_attention + kwargs = {"alpha": 1} + output = self.model_usage.forward(input).question_answering_score + model = self.model_usage.model + + # initialize relevancy matrices + text_tokens = self.model_usage.text_len + image_bboxes = self.model_usage.image_boxes_len + + # text self attention matrix + self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device) + # image self attention matrix + self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device) + # impact of images on text + self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device) + # impact of text on images + self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device) + + if index is None: + index = np.argmax(output.cpu().data.numpy(), axis=-1) + + one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) + one_hot[0, index] = 1 + one_hot_vector = one_hot + one_hot = torch.from_numpy(one_hot).requires_grad_(True) + one_hot = torch.sum(one_hot.cuda() * output) + + model.zero_grad() + one_hot.backward(retain_graph=True) + if self.use_lrp: + model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs) + + # language self attention + blocks = model.lxmert.encoder.layer + self.handle_self_attention_lang(blocks) + + # image self attention + blocks = model.lxmert.encoder.r_layers + self.handle_self_attention_image(blocks) + + # cross attn layers + blocks = model.lxmert.encoder.x_layers + for i, blk in enumerate(blocks): + # in the last cross attention module, only the text cross modal + # attention has an impact on the CLS token, since it's the first + # token in the language tokens + if i == len(blocks) - 1: + break + # cross attn- first for language then for image + R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk) + R_i_t_addition, R_i_i_addition = self.handle_co_attn_image(blk) + + self.R_t_i = R_t_i_addition + self.R_t_t = R_t_t_addition + self.R_i_t = R_i_t_addition + self.R_i_i = R_i_i_addition + + # language self attention + self.handle_co_attn_self_lang(blk) + + # image self attention + self.handle_co_attn_self_image(blk) + + # take care of last cross attention layer- only text + blk = model.lxmert.encoder.x_layers[-1] + # cross attn- first for language then for image + R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk) + self.R_t_i = R_t_i_addition + self.R_t_t = R_t_t_addition + + # language self attention + self.handle_co_attn_self_lang(blk) + + # disregard the [CLS] token itself + self.R_t_t[0, 0] = 0 + return self.R_t_t, self.R_t_i + + +class GeneratorBaselines: + def __init__(self, model_usage, save_visualization=False): + self.model_usage = model_usage + self.save_visualization = save_visualization + + def generate_transformer_attr(self, input, index=None, method_name="transformer_attr"): + kwargs = {"alpha": 1} + output = self.model_usage.forward(input).question_answering_score + model = self.model_usage.model + + # initialize relevancy matrices + text_tokens = self.model_usage.text_len + image_bboxes = self.model_usage.image_boxes_len + + # text self attention matrix + self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device) + # image self attention matrix + self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device) + # impact of images on text + self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device) + # impact of text on images + self.R_i_t = torch.zeros(image_bboxes, 
text_tokens).to(model.device) + + if index == None: + index = np.argmax(output.cpu().data.numpy(), axis=-1) + + one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) + one_hot[0, index] = 1 + one_hot_vector = one_hot + one_hot = torch.from_numpy(one_hot).requires_grad_(True) + one_hot = torch.sum(one_hot.cuda() * output) + + model.zero_grad() + one_hot.backward(retain_graph=True) + model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs) + + # language self attention + blocks = model.lxmert.encoder.layer + for blk in blocks: + grad = blk.attention.self.get_attn_gradients().detach() + cam = blk.attention.self.get_attn_cam().detach() + cam = avg_heads(cam, grad) + self.R_t_t += torch.matmul(cam, self.R_t_t) + + # image self attention + blocks = model.lxmert.encoder.r_layers + for blk in blocks: + grad = blk.attention.self.get_attn_gradients().detach() + cam = blk.attention.self.get_attn_cam().detach() + cam = avg_heads(cam, grad) + self.R_i_i += torch.matmul(cam, self.R_i_i) + + # cross attn layers + blocks = model.lxmert.encoder.x_layers + for i, blk in enumerate(blocks): + # in the last cross attention module, only the text cross modal + # attention has an impact on the CLS token, since it's the first + # token in the language tokens + if i == len(blocks) - 1: + break + + # language self attention + grad = blk.lang_self_att.self.get_attn_gradients().detach() + cam = blk.lang_self_att.self.get_attn_cam().detach() + cam = avg_heads(cam, grad) + self.R_t_t += torch.matmul(cam, self.R_t_t) + + # image self attention + grad = blk.visn_self_att.self.get_attn_gradients().detach() + cam = blk.visn_self_att.self.get_attn_cam().detach() + cam = avg_heads(cam, grad) + self.R_i_i += torch.matmul(cam, self.R_i_i) + + # take care of last cross attention layer- only text + blk = model.lxmert.encoder.x_layers[-1] + # cross attn cam will be the one used for the R_t_i matrix + cam_t_i = blk.visual_attention.att.get_attn_cam().detach() + grad_t_i = blk.visual_attention.att.get_attn_gradients().detach() + cam_t_i = avg_heads(cam_t_i, grad_t_i) + # self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i)) + self.R_t_i = cam_t_i + + # language self attention + grad = blk.lang_self_att.self.get_attn_gradients().detach() + cam = blk.lang_self_att.self.get_attn_cam().detach() + cam = avg_heads(cam, grad) + self.R_t_t += torch.matmul(cam, self.R_t_t) + + self.R_t_t[0, 0] = 0 + return self.R_t_t, self.R_t_i + + def generate_partial_lrp(self, input, index=None, method_name="partial_lrp"): + kwargs = {"alpha": 1} + output = self.model_usage.forward(input).question_answering_score + model = self.model_usage.model + + # initialize relevancy matrices + text_tokens = self.model_usage.text_len + image_bboxes = self.model_usage.image_boxes_len + + # text self attention matrix + self.R_t_t = torch.zeros(text_tokens, text_tokens).to(model.device) + # image self attention matrix + self.R_i_i = torch.zeros(image_bboxes, image_bboxes).to(model.device) + # impact of images on text + self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device) + # impact of text on images + self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device) + + if index == None: + index = np.argmax(output.cpu().data.numpy(), axis=-1) + + one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) + one_hot[0, index] = 1 + one_hot_vector = one_hot + model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs) + + # last cross attention + self- attention layer + blk = 
model.lxmert.encoder.x_layers[-1] + # cross attn cam will be the one used for the R_t_i matrix + cam_t_i = blk.visual_attention.att.get_attn_cam().detach() + cam_t_i = cam_t_i.reshape(-1, cam_t_i.shape[-2], cam_t_i.shape[-1]).mean(dim=0) + self.R_t_i = cam_t_i + + # language self attention + cam = blk.lang_self_att.self.get_attn_cam().detach() + cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0) + self.R_t_t = cam + + # normalize to get non-negative cams + self.R_t_t = (self.R_t_t - self.R_t_t.min()) / (self.R_t_t.max() - self.R_t_t.min()) + self.R_t_i = (self.R_t_i - self.R_t_i.min()) / (self.R_t_i.max() - self.R_t_i.min()) + # disregard the [CLS] token itself + self.R_t_t[0, 0] = 0 + return self.R_t_t, self.R_t_i + + def generate_raw_attn(self, input, method_name="raw_attention"): + output = self.model_usage.forward(input).question_answering_score + model = self.model_usage.model + + # initialize relevancy matrices + text_tokens = self.model_usage.text_len + image_bboxes = self.model_usage.image_boxes_len + + # text self attention matrix + self.R_t_t = torch.zeros(text_tokens, text_tokens).to(model.device) + # image self attention matrix + self.R_i_i = torch.zeros(image_bboxes, image_bboxes).to(model.device) + # impact of images on text + self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device) + # impact of text on images + self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device) + + # last cross attention + self- attention layer + blk = model.lxmert.encoder.x_layers[-1] + # cross attn cam will be the one used for the R_t_i matrix + cam_t_i = blk.visual_attention.att.get_attn().detach() + cam_t_i = cam_t_i.reshape(-1, cam_t_i.shape[-2], cam_t_i.shape[-1]).mean(dim=0) + # self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i)) + self.R_t_i = cam_t_i + + # language self attention + cam = blk.lang_self_att.self.get_attn().detach() + cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0) + self.R_t_t = cam + + # disregard the [CLS] token itself + self.R_t_t[0, 0] = 0 + return self.R_t_t, self.R_t_i + + def gradcam(self, cam, grad): + cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]) + grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1]) + grad = grad.mean(dim=[1, 2], keepdim=True) + cam = (cam * grad).mean(0).clamp(min=0) + return cam + + def generate_attn_gradcam(self, input, index=None, method_name="gradcam"): + output = self.model_usage.forward(input).question_answering_score + model = self.model_usage.model + + # initialize relevancy matrices + text_tokens = self.model_usage.text_len + image_bboxes = self.model_usage.image_boxes_len + + # text self attention matrix + self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device) + # image self attention matrix + self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device) + # impact of images on text + self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device) + # impact of text on images + self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device) + + if index == None: + index = np.argmax(output.cpu().data.numpy(), axis=-1) + + one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) + one_hot[0, index] = 1 + one_hot = torch.from_numpy(one_hot).requires_grad_(True) + one_hot = torch.sum(one_hot.cuda() * output) + + model.zero_grad() + one_hot.backward(retain_graph=True) + + # last cross attention + self- attention layer + blk = model.lxmert.encoder.x_layers[-1] + # cross attn cam will be the one used for the R_t_i matrix + grad_t_i 
= blk.visual_attention.att.get_attn_gradients().detach() + cam_t_i = blk.visual_attention.att.get_attn().detach() + cam_t_i = self.gradcam(cam_t_i, grad_t_i) + # self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i)) + self.R_t_i = cam_t_i + + # language self attention + grad = blk.lang_self_att.self.get_attn_gradients().detach() + cam = blk.lang_self_att.self.get_attn().detach() + self.R_t_t = self.gradcam(cam, grad) + + # disregard the [CLS] token itself + self.R_t_t[0, 0] = 0 + return self.R_t_t, self.R_t_i + + def generate_rollout(self, input, method_name="rollout"): + output = self.model_usage.forward(input).question_answering_score + model = self.model_usage.model + + # initialize relevancy matrices + text_tokens = self.model_usage.text_len + image_bboxes = self.model_usage.image_boxes_len + + # text self attention matrix + self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device) + # image self attention matrix + self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device) + # impact of images on text + self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device) + # impact of text on images + self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device) + + cams_text = [] + cams_image = [] + # language self attention + blocks = model.lxmert.encoder.layer + for blk in blocks: + cam = blk.attention.self.get_attn().detach() + cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0) + cams_text.append(cam) + + # image self attention + blocks = model.lxmert.encoder.r_layers + for blk in blocks: + cam = blk.attention.self.get_attn().detach() + cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0) + cams_image.append(cam) + + # cross attn layers + blocks = model.lxmert.encoder.x_layers + for i, blk in enumerate(blocks): + # in the last cross attention module, only the text cross modal + # attention has an impact on the CLS token, since it's the first + # token in the language tokens + if i == len(blocks) - 1: + break + + # language self attention + cam = blk.lang_self_att.self.get_attn().detach() + cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0) + cams_text.append(cam) + + # image self attention + cam = blk.visn_self_att.self.get_attn().detach() + cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0) + cams_image.append(cam) + + # take care of last cross attention layer- only text + blk = model.lxmert.encoder.x_layers[-1] + # cross attn cam will be the one used for the R_t_i matrix + cam_t_i = blk.visual_attention.att.get_attn().detach() + cam_t_i = cam_t_i.reshape(-1, cam_t_i.shape[-2], cam_t_i.shape[-1]).mean(dim=0) + self.R_t_t = compute_rollout_attention(copy.deepcopy(cams_text)) + self.R_i_i = compute_rollout_attention(cams_image) + self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i)) + # language self attention + cam = blk.lang_self_att.self.get_attn().detach() + cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0) + cams_text.append(cam) + + self.R_t_t = compute_rollout_attention(cams_text) + + # disregard the [CLS] token itself + self.R_t_t[0, 0] = 0 + return self.R_t_t, self.R_t_i diff --git a/lxmert/src/__init__.py b/lxmert/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lxmert/src/__pycache__/ExplanationGenerator.cpython-38.pyc b/lxmert/src/__pycache__/ExplanationGenerator.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e7a4589542faac769e5db59a56ce345470970840 Binary files /dev/null and b/lxmert/src/__pycache__/ExplanationGenerator.cpython-38.pyc differ diff --git a/lxmert/src/__pycache__/__init__.cpython-38.pyc b/lxmert/src/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a23129fc68a3332f3c46216f086ddee37c4d2b0 Binary files /dev/null and b/lxmert/src/__pycache__/__init__.cpython-38.pyc differ diff --git a/lxmert/src/__pycache__/huggingface_lxmert.cpython-38.pyc b/lxmert/src/__pycache__/huggingface_lxmert.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c19b5c5e87d39f3f03fd4dcf753d1af914baa251 Binary files /dev/null and b/lxmert/src/__pycache__/huggingface_lxmert.cpython-38.pyc differ diff --git a/lxmert/src/__pycache__/layers.cpython-38.pyc b/lxmert/src/__pycache__/layers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..194c74fa0d76acd64fb0d0de443fa2aa298cdb8e Binary files /dev/null and b/lxmert/src/__pycache__/layers.cpython-38.pyc differ diff --git a/lxmert/src/__pycache__/lxmert_lrp.cpython-38.pyc b/lxmert/src/__pycache__/lxmert_lrp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31b71fde0adc7fdf2c32eb0cc2635fd46a9083d7 Binary files /dev/null and b/lxmert/src/__pycache__/lxmert_lrp.cpython-38.pyc differ diff --git a/lxmert/src/__pycache__/modeling_frcnn.cpython-38.pyc b/lxmert/src/__pycache__/modeling_frcnn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a25adcc8e62984e8f8d4266c59b529a1b1bd8ca Binary files /dev/null and b/lxmert/src/__pycache__/modeling_frcnn.cpython-38.pyc differ diff --git a/lxmert/src/__pycache__/processing_image.cpython-38.pyc b/lxmert/src/__pycache__/processing_image.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c358228a4e673ec4e418edd33518c9641b6fd681 Binary files /dev/null and b/lxmert/src/__pycache__/processing_image.cpython-38.pyc differ diff --git a/lxmert/src/__pycache__/vqa_utils.cpython-38.pyc b/lxmert/src/__pycache__/vqa_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7813cf2b9f498ccb52e741a6812baca6352c3ee Binary files /dev/null and b/lxmert/src/__pycache__/vqa_utils.cpython-38.pyc differ diff --git a/lxmert/src/huggingface_lxmert.py b/lxmert/src/huggingface_lxmert.py new file mode 100644 index 0000000000000000000000000000000000000000..0461c7b86735ae2a2e9a79b8472617d29a9befda --- /dev/null +++ b/lxmert/src/huggingface_lxmert.py @@ -0,0 +1,1472 @@ +# coding=utf-8 +# Copyright 2018 Hao Tan, Mohit Bansal, and the HuggingFace team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch lxmert model. 
""" + + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss, SmoothL1Loss + +from transformers.activations import ACT2FN, gelu +from transformers.file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.configuration_lxmert import LxmertConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LxmertConfig" +_TOKENIZER_FOR_DOC = "LxmertTokenizer" + +LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "unc-nlp/lxmert-base-uncased", +] + + +class GeLU(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return gelu(x) + + +@dataclass +class LxmertModelOutput(ModelOutput): + """ + Lxmert's outputs that contain the last hidden states, pooled outputs, and attention probabilities for the language, + visual, and, cross-modality encoders. (note: the visual encoder in Lxmert is referred to as the "relation-ship" + encoder") + + + Args: + language_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the language encoder. + vision_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the visual encoder. + pooled_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification, CLS, token) further processed + by a Linear layer and a Tanh activation function. The Linear + language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. 
+ cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + language_output: Optional[torch.FloatTensor] = None + vision_output: Optional[torch.FloatTensor] = None + pooled_output: Optional[torch.FloatTensor] = None + language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + language_attentions: Optional[Tuple[torch.FloatTensor]] = None + vision_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class LxmertForQuestionAnsweringOutput(ModelOutput): + """ + Output type of :class:`~transformers.LxmertForQuestionAnswering`. + + Args: + loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss.k. + question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`, `optional`): + Prediction scores of question answering objective (classification). + language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. 
+ """ + + loss: Optional[torch.FloatTensor] = None + question_answering_score: Optional[torch.FloatTensor] = None + language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + language_attentions: Optional[Tuple[torch.FloatTensor]] = None + vision_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class LxmertForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.LxmertForPreTraining`. + + Args: + loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cross_relationship_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): + Prediction scores of the textual matching objective (classification) head (scores of True/False + continuation before SoftMax). + question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`): + Prediction scores of question answering objective (classification). + language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. 
+ + """ + + loss: [torch.FloatTensor] = None + prediction_logits: Optional[torch.FloatTensor] = None + cross_relationship_score: Optional[torch.FloatTensor] = None + question_answering_score: Optional[torch.FloatTensor] = None + language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + language_attentions: Optional[Tuple[torch.FloatTensor]] = None + vision_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +def load_tf_weights_in_lxmert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] + for n in name + ): + logger.info("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +class LxmertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + 
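+        # Note: the forward pass below sums the word, position and token-type embeddings
+        # and then applies this LayerNorm followed by dropout.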
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + device = input_ids.device + else: + input_shape = inputs_embeds.size()[:-1] + device = inputs_embeds.device + seq_length = input_shape[1] + + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class LxmertAttention(nn.Module): + def __init__(self, config, ctx_dim=None, save_cams=False): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.head_size = self.num_attention_heads * self.attention_head_size + + # visual_dim = 2048 + if ctx_dim is None: + ctx_dim = config.hidden_size + self.query = nn.Linear(config.hidden_size, self.head_size) + self.key = nn.Linear(ctx_dim, self.head_size) + self.value = nn.Linear(ctx_dim, self.head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.save_cams = save_cams + self.attn = None + self.attn_gradients = None + + def get_attn(self): + ret = self.attn + self.attn = None + return ret + + def save_attn(self, attn): + if self.attn is not None: + self.attn = [self.attn, attn] + else: + self.attn = attn + + def save_attn_gradients(self, attn_gradients): + if self.attn_gradients is not None: + self.attn_gradients = [self.attn_gradients, attn_gradients] + else: + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + ret = self.attn_gradients + self.attn_gradients = None + return ret + + def reset(self): + self.attn = None + self.attn_gradients = None + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, context, attention_mask=None, output_attentions=False): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(context) + mixed_value_layer = self.value(context) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
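+        # Per head this computes scores = Q @ K^T / sqrt(attention_head_size), adds the additive
+        # attention mask, and applies a softmax so each query row sums to 1 over the context tokens.
+        # For self-attention `hidden_states` and `context` are the same sequence; for cross-attention
+        # `context` comes from the other modality.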
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # if self.save_cams: + self.save_attn(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + +class LxmertAttentionOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class LxmertCrossAttentionLayer(nn.Module): + def __init__(self, config, save_cams=False): + super().__init__() + self.att = LxmertAttention(config, save_cams=save_cams) + self.output = LxmertAttentionOutput(config) + + def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None, output_attentions=False): + output = self.att(input_tensor, ctx_tensor, ctx_att_mask, output_attentions=output_attentions) + if output_attentions: + attention_probs = output[1] + attention_output = self.output(output[0], input_tensor) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + return outputs + + +class LxmertSelfAttentionLayer(nn.Module): + def __init__(self, config, save_cams=False): + super().__init__() + self.self = LxmertAttention(config, save_cams=save_cams) + self.output = LxmertAttentionOutput(config) + + def forward(self, input_tensor, attention_mask, output_attentions=False): + # Self attention attends to itself, thus keys and queries are the same (input_tensor). 
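+        # Reuses the generic attention module with context == input_tensor, i.e. queries, keys and
+        # values are all projections of the same sequence.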
+ output = self.self( + input_tensor, + input_tensor, + attention_mask, + output_attentions=output_attentions, + ) + if output_attentions: + attention_probs = output[1] + attention_output = self.output(output[0], input_tensor) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + return outputs + + +class LxmertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class LxmertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class LxmertLayer(nn.Module): + def __init__(self, config, save_cams=False): + super().__init__() + self.attention = LxmertSelfAttentionLayer(config, save_cams=save_cams) + self.intermediate = LxmertIntermediate(config) + self.output = LxmertOutput(config) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions) + attention_output = outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + outputs[1:] # add attentions if we output them + return outputs + + +class LxmertXLayer(nn.Module): + def __init__(self, config, save_cams=False): + super().__init__() + # The cross-attention Layer + self.visual_attention = LxmertCrossAttentionLayer(config, save_cams=save_cams) + + # Self-attention Layers + self.lang_self_att = LxmertSelfAttentionLayer(config) + self.visn_self_att = LxmertSelfAttentionLayer(config) + + # Intermediate and Output Layers (FFNs) + self.lang_inter = LxmertIntermediate(config) + self.lang_output = LxmertOutput(config) + self.visn_inter = LxmertIntermediate(config) + self.visn_output = LxmertOutput(config) + + def cross_att( + self, + lang_input, + lang_attention_mask, + visual_input, + visual_attention_mask, + output_x_attentions=False, + ): + # Cross Attention + lang_att_output = self.visual_attention( + lang_input, + visual_input, + ctx_att_mask=visual_attention_mask, + output_attentions=output_x_attentions, + ) + visual_att_output = self.visual_attention( + visual_input, + lang_input, + ctx_att_mask=lang_attention_mask, + output_attentions=False, + ) + return lang_att_output, visual_att_output + + def self_att(self, lang_input, lang_attention_mask, visual_input, visual_attention_mask): + # Self Attention + lang_att_output = self.lang_self_att(lang_input, lang_attention_mask, output_attentions=False) + visual_att_output = self.visn_self_att(visual_input, visual_attention_mask, output_attentions=False) + return lang_att_output[0], visual_att_output[0] + + def output_fc(self, lang_input, visual_input): + # FC layers + lang_inter_output = self.lang_inter(lang_input) + visual_inter_output = self.visn_inter(visual_input) + + # Layer output + 
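+        # Each *Output module below applies dense -> dropout -> LayerNorm(x + residual), so both the
+        # language and the visual branch keep a residual connection around their feed-forward block.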
lang_output = self.lang_output(lang_inter_output, lang_input) + visual_output = self.visn_output(visual_inter_output, visual_input) + + return lang_output, visual_output + + def forward( + self, + lang_feats, + lang_attention_mask, + visual_feats, + visual_attention_mask, + output_attentions=False, + ): + + lang_att_output, visual_att_output = self.cross_att( + lang_input=lang_feats, + lang_attention_mask=lang_attention_mask, + visual_input=visual_feats, + visual_attention_mask=visual_attention_mask, + output_x_attentions=output_attentions, + ) + attention_probs = lang_att_output[1:] + lang_att_output, visual_att_output = self.self_att( + lang_att_output[0], + lang_attention_mask, + visual_att_output[0], + visual_attention_mask, + ) + + lang_output, visual_output = self.output_fc(lang_att_output, visual_att_output) + return ( + ( + lang_output, + visual_output, + attention_probs[0], + ) + if output_attentions + else (lang_output, visual_output) + ) + + +class LxmertVisualFeatureEncoder(nn.Module): + def __init__(self, config): + super().__init__() + feat_dim = config.visual_feat_dim + pos_dim = config.visual_pos_dim + + # Object feature encoding + self.visn_fc = nn.Linear(feat_dim, config.hidden_size) + self.visn_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12) + + # Box position encoding + self.box_fc = nn.Linear(pos_dim, config.hidden_size) + self.box_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12) + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, visual_feats, visual_pos): + x = self.visn_fc(visual_feats) + x = self.visn_layer_norm(x) + y = self.box_fc(visual_pos) + y = self.box_layer_norm(y) + output = (x + y) / 2 + + output = self.dropout(output) + return output + + +class LxmertEncoder(nn.Module): + def __init__(self, config, save_cams=False): + super().__init__() + + # Obj-level image embedding layer + self.visn_fc = LxmertVisualFeatureEncoder(config) + self.config = config + + # Number of layers + self.num_l_layers = config.l_layers + self.num_x_layers = config.x_layers + self.num_r_layers = config.r_layers + + # Layers + # Using self.layer instead of self.l_layer to support loading BERT weights. 
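+        # Three stacks are built here: `layer` holds the language-only blocks (config.l_layers of them),
+        # `r_layers` the vision-only "relational" blocks, and `x_layers` the cross-modality blocks that
+        # exchange information between the two streams.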
+ self.layer = nn.ModuleList([LxmertLayer(config, save_cams=save_cams) for _ in range(self.num_l_layers)]) + self.x_layers = nn.ModuleList([LxmertXLayer(config) for _ in range(self.num_x_layers)]) + self.r_layers = nn.ModuleList([LxmertLayer(config, save_cams=save_cams) for _ in range(self.num_r_layers)]) + + def forward( + self, + lang_feats, + lang_attention_mask, + visual_feats, + visual_pos, + visual_attention_mask=None, + output_attentions=None, + ): + + vision_hidden_states = () + language_hidden_states = () + vision_attentions = () if output_attentions or self.config.output_attentions else None + language_attentions = () if output_attentions or self.config.output_attentions else None + cross_encoder_attentions = () if output_attentions or self.config.output_attentions else None + + visual_feats = self.visn_fc(visual_feats, visual_pos) + + # Run language layers + for layer_module in self.layer: + l_outputs = layer_module(lang_feats, lang_attention_mask, output_attentions=output_attentions) + lang_feats = l_outputs[0] + language_hidden_states = language_hidden_states + (lang_feats,) + if language_attentions is not None: + language_attentions = language_attentions + (l_outputs[1],) + + # Run relational layers + for layer_module in self.r_layers: + v_outputs = layer_module(visual_feats, visual_attention_mask, output_attentions=output_attentions) + visual_feats = v_outputs[0] + vision_hidden_states = vision_hidden_states + (visual_feats,) + if vision_attentions is not None: + vision_attentions = vision_attentions + (v_outputs[1],) + + # Run cross-modality layers + for layer_module in self.x_layers: + x_outputs = layer_module( + lang_feats, + lang_attention_mask, + visual_feats, + visual_attention_mask, + output_attentions=output_attentions, + ) + lang_feats, visual_feats = x_outputs[:2] + vision_hidden_states = vision_hidden_states + (visual_feats,) + language_hidden_states = language_hidden_states + (lang_feats,) + if cross_encoder_attentions is not None: + cross_encoder_attentions = cross_encoder_attentions + (x_outputs[2],) + visual_encoder_outputs = ( + vision_hidden_states, + vision_attentions if output_attentions else None, + ) + lang_encoder_outputs = ( + language_hidden_states, + language_attentions if output_attentions else None, + ) + return ( + visual_encoder_outputs, + lang_encoder_outputs, + cross_encoder_attentions if output_attentions else None, + ) + + +class LxmertPooler(nn.Module): + def __init__(self, config): + super(LxmertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
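+        # hidden_states has shape (batch_size, seq_len, hidden_size); slicing [:, 0] keeps the [CLS]
+        # position, which is then projected and passed through tanh to give the pooled output.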
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class LxmertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(LxmertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.transform_act_fn = ACT2FN[config.hidden_act] + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class LxmertLMPredictionHead(nn.Module): + def __init__(self, config, lxmert_model_embedding_weights): + super(LxmertLMPredictionHead, self).__init__() + self.transform = LxmertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + lxmert_model_embedding_weights.size(1), + lxmert_model_embedding_weights.size(0), + bias=False, + ) + self.decoder.weight = lxmert_model_embedding_weights + self.bias = nn.Parameter(torch.zeros(lxmert_model_embedding_weights.size(0))) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + + +class LxmertVisualAnswerHead(nn.Module): + def __init__(self, config, num_labels): + super().__init__() + hid_dim = config.hidden_size + self.logit_fc = nn.Sequential( + nn.Linear(hid_dim, hid_dim * 2), + GeLU(), + nn.LayerNorm(hid_dim * 2, eps=1e-12), + nn.Linear(hid_dim * 2, num_labels), + ) + + def forward(self, hidden_states): + return self.logit_fc(hidden_states) + + +class LxmertVisualObjHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = LxmertPredictionHeadTransform(config) + # Decide the use of visual losses + visual_losses = {} + if config.visual_obj_loss: + visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels} + if config.visual_attr_loss: + visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels} + if config.visual_obj_loss: + visual_losses["feat"] = { + "shape": (-1, config.visual_feat_dim), + "num": config.visual_feat_dim, + } + self.visual_losses = visual_losses + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
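+        # One linear decoder per enabled visual loss: an object-class head, an attribute-class head,
+        # and a regression head mapping back to the visual_feat_dim-dimensional ROI features.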
+ self.decoder_dict = nn.ModuleDict( + {key: nn.Linear(config.hidden_size, self.visual_losses[key]["num"]) for key in self.visual_losses} + ) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + output = {} + for key in self.visual_losses: + output[key] = self.decoder_dict[key](hidden_states) + return output + + +class LxmertPreTrainingHeads(nn.Module): + def __init__(self, config, lxmert_model_embedding_weights): + super(LxmertPreTrainingHeads, self).__init__() + self.predictions = LxmertLMPredictionHead(config, lxmert_model_embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class LxmertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LxmertConfig + load_tf_weights = load_tf_weights_in_lxmert + base_model_prefix = "lxmert" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +LXMERT_START_DOCSTRING = r""" + + The lxmert model was proposed in `lxmert: Learning Cross-Modality Encoder Representations from Transformers + `__ by Hao Tan and Mohit Bansal. It's a vision and language transformer model, + pretrained on a variety of multi-modal datasets comprising of GQA, VQAv2.0, MCSCOCO captions, and Visual genome, + using a combination of masked language modeling, region of interest feature regression, cross entropy loss for + question answering attribute prediction, and object tag prediction. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.LxmertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +LXMERT_INPUTS_DOCSTRING = r""" + + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.LxmertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? 
<../glossary.html#input-ids>`__
+        visual_feats: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_visual_features, visual_feat_dim)`):
+            This input represents visual features: ROI-pooled object features extracted from bounding boxes by a
+            Faster R-CNN model.
+
+            These are currently not provided by the transformers library.
+        visual_pos: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_visual_features, visual_pos_dim)`):
+            This input represents spatial features corresponding to their relative (via index) visual features. The
+            pre-trained lxmert model expects these spatial features to be normalized bounding boxes on a scale of 0 to
+            1.
+
+            These are currently not provided by the transformers library.
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            `What are attention masks? <../glossary.html#attention-mask>`__
+        visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            `What are attention masks? <../glossary.html#attention-mask>`__
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
+            1]``:
+
+            - 0 corresponds to a `sentence A` token,
+            - 1 corresponds to a `sentence B` token.
+
+            `What are token type IDs? <../glossary.html#token-type-ids>`__
+        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
+            vectors than the model's internal embedding lookup matrix.
+        output_attentions (:obj:`bool`, `optional`):
+            Whether or not to return the attention tensors of all attention layers. See ``attentions`` under returned
+            tensors for more detail.
+        output_hidden_states (:obj:`bool`, `optional`):
+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+            more detail.
+        return_dict (:obj:`bool`, `optional`):
+            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
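+
+        Example (a minimal usage sketch; assumes ``model`` is an already-instantiated ``LxmertModel`` or
+        ``LxmertForQuestionAnswering`` and fakes the Faster R-CNN region features with random tensors)::
+
+            import torch
+            from transformers import LxmertTokenizer
+
+            tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
+            inputs = tokenizer("What is the man riding?", return_tensors="pt")
+            feats = torch.rand(1, 36, 2048)  # (batch, num_boxes, visual_feat_dim)
+            boxes = torch.rand(1, 36, 4)     # (batch, num_boxes, visual_pos_dim), normalized to [0, 1]
+            outputs = model(input_ids=inputs.input_ids, visual_feats=feats, visual_pos=boxes,
+                            attention_mask=inputs.attention_mask)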
+""" + + +@add_start_docstrings( + "The bare Lxmert Model transformer outputting raw hidden-states without any specific head on top.", + LXMERT_START_DOCSTRING, +) +class LxmertModel(LxmertPreTrainedModel): + def __init__(self, config, save_cams=False): + super().__init__(config) + self.embeddings = LxmertEmbeddings(config) + self.encoder = LxmertEncoder(config, save_cams=save_cams) + self.pooler = LxmertPooler(config) + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="unc-nlp/lxmert-base-uncased", + output_type=LxmertModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + visual_feats=None, + visual_pos=None, + attention_mask=None, + visual_attention_mask=None, + token_type_ids=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + assert visual_feats is not None, "`visual_feats` cannot be `None`" + assert visual_pos is not None, "`visual_pos` cannot be `None`" + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
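+        # e.g. a padding mask row [1, 1, 1, 0] becomes [0.0, 0.0, 0.0, -10000.0] after the
+        # (1.0 - mask) * -10000.0 transform below, effectively zeroing the padded position after the softmax.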
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # Process the visual attention mask + if visual_attention_mask is not None: + extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2) + extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype) + extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0 + else: + extended_visual_attention_mask = None + + # Positional Word Embeddings + embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds) + + # Run Lxmert encoder + encoder_outputs = self.encoder( + embedding_output, + extended_attention_mask, + visual_feats=visual_feats, + visual_pos=visual_pos, + visual_attention_mask=extended_visual_attention_mask, + output_attentions=output_attentions, + ) + + visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2] + vision_hidden_states = visual_encoder_outputs[0] + language_hidden_states = lang_encoder_outputs[0] + + all_attentions = () + if output_attentions: + language_attentions = lang_encoder_outputs[1] + vision_attentions = visual_encoder_outputs[1] + cross_encoder_attentions = encoder_outputs[2] + all_attentions = ( + language_attentions, + vision_attentions, + cross_encoder_attentions, + ) + + hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else () + + visual_output = vision_hidden_states[-1] + lang_output = language_hidden_states[-1] + pooled_output = self.pooler(lang_output) + + if not return_dict: + return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions + + return LxmertModelOutput( + pooled_output=pooled_output, + language_output=lang_output, + vision_output=visual_output, + language_hidden_states=language_hidden_states if output_hidden_states else None, + vision_hidden_states=vision_hidden_states if output_hidden_states else None, + language_attentions=language_attentions if output_attentions else None, + vision_attentions=vision_attentions if output_attentions else None, + cross_encoder_attentions=cross_encoder_attentions if output_attentions else None, + ) + + +@add_start_docstrings( + """Lxmert Model with a specified pretraining head on top. 
""", + LXMERT_START_DOCSTRING, +) +class LxmertForPreTraining(LxmertPreTrainedModel): + def __init__(self, config, save_cams=False): + super().__init__(config) + # Configuration + self.config = config + self.num_qa_labels = config.num_qa_labels + self.visual_loss_normalizer = config.visual_loss_normalizer + + # Use of pretraining tasks + self.task_mask_lm = config.task_mask_lm + self.task_obj_predict = config.task_obj_predict + self.task_matched = config.task_matched + self.task_qa = config.task_qa + + # Lxmert backbone + self.lxmert = LxmertModel(config, save_cams=save_cams) + + # Pre-training heads + self.cls = LxmertPreTrainingHeads(config, self.lxmert.embeddings.word_embeddings.weight) + if self.task_obj_predict: + self.obj_predict_head = LxmertVisualObjHead(config) + if self.task_qa: + self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels) + + # Weight initialization + self.init_weights() + + # Loss functions + self.loss_fcts = { + "l2": SmoothL1Loss(reduction="none"), + "visual_ce": CrossEntropyLoss(reduction="none"), + "ce": CrossEntropyLoss(), + } + + visual_losses = {} + if config.visual_obj_loss: + visual_losses["obj"] = { + "shape": (-1,), + "num": config.num_object_labels, + "loss": "visual_ce", + } + if config.visual_attr_loss: + visual_losses["attr"] = { + "shape": (-1,), + "num": config.num_attr_labels, + "loss": "visual_ce", + } + if config.visual_obj_loss: + visual_losses["feat"] = { + "shape": (-1, config.visual_feat_dim), + "num": config.visual_feat_dim, + "loss": "l2", + } + self.visual_losses = visual_losses + + def resize_num_qa_labels(self, num_labels): + """ + Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size + will add newly initialized weights. Reducing the size will remove weights from the end + + Args: + num_labels (:obj:`int`, `optional`): + New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized + weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`, + just returns a pointer to the qa labels :obj:`torch.nn.Linear`` module of the model without doing + anything. + + Return: + :obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer + """ + + cur_qa_logit_layer = self.get_qa_logit_layer() + if num_labels is None or cur_qa_logit_layer is None: + return + new_qa_logit_layer = self._resize_qa_labels(num_labels) + self.config.num_qa_labels = num_labels + self.num_qa_labels = num_labels + + return new_qa_logit_layer + + def _resize_qa_labels(self, num_labels): + cur_qa_logit_layer = self.get_qa_logit_layer() + new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels) + self._set_qa_logit_layer(new_qa_logit_layer) + return self.get_qa_logit_layer() + + def get_qa_logit_layer(self) -> nn.Module: + """ + Returns the the linear layer that produces question answering logits. + + Returns: + :obj:`nn.Module`: A torch module mapping the question answering prediction hidden states or :obj:`None` if + lxmert does not have a visual answering head. 
+ """ + if hasattr(self, "answer_head"): + return self.answer_head.logit_fc[-1] + + def _set_qa_logit_layer(self, qa_logit_layer): + self.answer_head.logit_fc[-1] = qa_logit_layer + + def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels): + + if num_labels is None: + return cur_qa_logit_layer + + cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size() + if cur_qa_labels == num_labels: + return cur_qa_logit_layer + + # Build new linear output + if getattr(cur_qa_logit_layer, "bias", None) is not None: + new_qa_logit_layer = nn.Linear(hidden_dim, num_labels) + else: + new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False) + + new_qa_logit_layer.to(cur_qa_logit_layer.weight.device) + + # initialize all new labels + self._init_weights(new_qa_logit_layer) + + # Copy labels from the previous weights + num_labels_to_copy = min(cur_qa_labels, num_labels) + new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :] + if getattr(cur_qa_logit_layer, "bias", None) is not None: + new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy] + + return new_qa_logit_layer + + @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=LxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + visual_feats=None, + visual_pos=None, + attention_mask=None, + visual_attention_mask=None, + token_type_ids=None, + inputs_embeds=None, + labels=None, + obj_labels=None, + matched_label=None, + ans=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + r""" + labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + obj_labels: (``Dict[Str: Tuple[Torch.FloatTensor, Torch.FloatTensor]]``, `optional`): + each key is named after each one of the visual losses and each element of the tuple is of the shape + ``(batch_size, num_features)`` and ``(batch_size, num_features, visual_feature_dim)`` for each the label id + and the label score respectively + matched_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): + Labels for computing the whether or not the text input matches the image (classification) loss. Input + should be a sequence pair (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: + + - 0 indicates that the sentence does not match the image, + - 1 indicates that the sentence does match the image. 
+ ans: (``Torch.Tensor`` of shape ``(batch_size)``, `optional`): + a one hot representation hof the correct answer `optional` + + Returns: + """ + + if "masked_lm_labels" in kwargs: + warnings.warn( + "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("masked_lm_labels") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + device = input_ids.device if input_ids is not None else inputs_embeds.device + lxmert_output = self.lxmert( + input_ids=input_ids, + visual_feats=visual_feats, + visual_pos=visual_pos, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + visual_attention_mask=visual_attention_mask, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + lang_output, visual_output, pooled_output = ( + lxmert_output[0], + lxmert_output[1], + lxmert_output[2], + ) + lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output) + if self.task_qa: + answer_score = self.answer_head(pooled_output) + else: + answer_score = pooled_output[0][0] + + total_loss = ( + None + if (labels is None and matched_label is None and obj_labels is None and ans is None) + else torch.tensor(0.0, device=device) + ) + if labels is not None and self.task_mask_lm: + masked_lm_loss = self.loss_fcts["ce"]( + lang_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + total_loss += masked_lm_loss + if matched_label is not None and self.task_matched: + matched_loss = self.loss_fcts["ce"](cross_relationship_score.view(-1, 2), matched_label.view(-1)) + total_loss += matched_loss + if obj_labels is not None and self.task_obj_predict: + total_visual_loss = torch.tensor(0.0, device=input_ids.device) + visual_prediction_scores_dict = self.obj_predict_head(visual_output) + for key, key_info in self.visual_losses.items(): + label, mask_conf = obj_labels[key] + output_dim = key_info["num"] + loss_fct_name = key_info["loss"] + label_shape = key_info["shape"] + weight = self.visual_loss_normalizer + visual_loss_fct = self.loss_fcts[loss_fct_name] + visual_prediction_scores = visual_prediction_scores_dict[key] + visual_loss = visual_loss_fct( + visual_prediction_scores.view(-1, output_dim), + label.view(*label_shape), + ) + if visual_loss.dim() > 1: # Regression Losses + visual_loss = visual_loss.mean(1) + visual_loss = (visual_loss * mask_conf.view(-1)).mean() * weight + total_visual_loss += visual_loss + total_loss += total_visual_loss + if ans is not None and self.task_qa: + answer_loss = self.loss_fcts["ce"](answer_score.view(-1, self.num_qa_labels), ans.view(-1)) + total_loss += answer_loss + + if not return_dict: + output = ( + lang_prediction_scores, + cross_relationship_score, + answer_score, + ) + lxmert_output[3:] + return ((total_loss,) + output) if total_loss is not None else output + + return LxmertForPreTrainingOutput( + loss=total_loss, + prediction_logits=lang_prediction_scores, + cross_relationship_score=cross_relationship_score, + question_answering_score=answer_score, + language_hidden_states=lxmert_output.language_hidden_states, + vision_hidden_states=lxmert_output.vision_hidden_states, + language_attentions=lxmert_output.language_attentions, + vision_attentions=lxmert_output.vision_attentions, + cross_encoder_attentions=lxmert_output.cross_encoder_attentions, + ) + + +@add_start_docstrings( + """Lxmert Model with a 
visual-answering head on top for downstream QA tasks""", + LXMERT_START_DOCSTRING, +) +class LxmertForQuestionAnswering(LxmertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + # Configuration + self.config = config + self.num_qa_labels = config.num_qa_labels + self.visual_loss_normalizer = config.visual_loss_normalizer + + # Lxmert backbone + self.lxmert = LxmertModel(config) + + self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels) + + # Weight initialization + self.init_weights() + + # Loss function + self.loss = CrossEntropyLoss() + + def resize_num_qa_labels(self, num_labels): + """ + Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size + will add newly initialized weights. Reducing the size will remove weights from the end + + Args: + num_labels (:obj:`int`, `optional`): + New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized + weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`, + just returns a pointer to the qa labels :obj:`torch.nn.Linear`` module of the model without doing + anything. + + Return: + :obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer + """ + + cur_qa_logit_layer = self.get_qa_logit_layer() + if num_labels is None or cur_qa_logit_layer is None: + return + new_qa_logit_layer = self._resize_qa_labels(num_labels) + self.config.num_qa_labels = num_labels + self.num_qa_labels = num_labels + + return new_qa_logit_layer + + def _resize_qa_labels(self, num_labels): + cur_qa_logit_layer = self.get_qa_logit_layer() + new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels) + self._set_qa_logit_layer(new_qa_logit_layer) + return self.get_qa_logit_layer() + + def get_qa_logit_layer(self) -> nn.Module: + """ + Returns the the linear layer that produces question answering logits + + Returns: + :obj:`nn.Module`: A torch module mapping the question answering prediction hidden states. :obj:`None`: A + NoneType object if Lxmert does not have the visual answering head. 
+ """ + + if hasattr(self, "answer_head"): + return self.answer_head.logit_fc[-1] + + def _set_qa_logit_layer(self, qa_logit_layer): + self.answer_head.logit_fc[-1] = qa_logit_layer + + def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels): + + if num_labels is None: + return cur_qa_logit_layer + + cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size() + if cur_qa_labels == num_labels: + return cur_qa_logit_layer + + # Build new linear output + if getattr(cur_qa_logit_layer, "bias", None) is not None: + new_qa_logit_layer = nn.Linear(hidden_dim, num_labels) + else: + new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False) + + new_qa_logit_layer.to(cur_qa_logit_layer.weight.device) + + # initialize all new labels + self._init_weights(new_qa_logit_layer) + + # Copy labels from the previous weights + num_labels_to_copy = min(cur_qa_labels, num_labels) + new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :] + if getattr(cur_qa_logit_layer, "bias", None) is not None: + new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy] + + return new_qa_logit_layer + + @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="unc-nlp/lxmert-base-uncased", + output_type=LxmertForQuestionAnsweringOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + visual_feats=None, + visual_pos=None, + attention_mask=None, + visual_attention_mask=None, + token_type_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels: (``Torch.Tensor`` of shape ``(batch_size)``, `optional`): + A one-hot representation of the correct answer + + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + lxmert_output = self.lxmert( + input_ids=input_ids, + visual_feats=visual_feats, + visual_pos=visual_pos, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + visual_attention_mask=visual_attention_mask, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + pooled_output = lxmert_output[2] + answer_score = self.answer_head(pooled_output) + loss = None + if labels is not None: + loss = self.loss(answer_score.view(-1, self.num_qa_labels), labels.view(-1)) + + if not return_dict: + output = (answer_score,) + lxmert_output[3:] + return (loss,) + output if loss is not None else output + + return LxmertForQuestionAnsweringOutput( + loss=loss, + question_answering_score=answer_score, + language_hidden_states=lxmert_output.language_hidden_states, + vision_hidden_states=lxmert_output.vision_hidden_states, + language_attentions=lxmert_output.language_attentions, + vision_attentions=lxmert_output.vision_attentions, + cross_encoder_attentions=lxmert_output.cross_encoder_attentions, + ) \ No newline at end of file diff --git a/lxmert/src/layers.py b/lxmert/src/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..32ec05ee5b31dc4cc849c81ff63ebc2338f5054e --- /dev/null +++ b/lxmert/src/layers.py @@ -0,0 +1,292 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d', + 
'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect', + 'LayerNorm', 'AddEye', 'Tanh', 'MatMul', 'Mul'] + + +def safe_divide(a, b): + den = b.clamp(min=1e-9) + b.clamp(max=1e-9) + den = den + den.eq(0).type(den.type()) * 1e-9 + return a / den * b.ne(0).type(b.type()) + + +def forward_hook(self, input, output): + if type(input[0]) in (list, tuple): + self.X = [] + for i in input[0]: + x = i.detach() + x.requires_grad = True + self.X.append(x) + else: + self.X = input[0].detach() + self.X.requires_grad = True + + self.Y = output + + +def backward_hook(self, grad_input, grad_output): + self.grad_input = grad_input + self.grad_output = grad_output + + +class RelProp(nn.Module): + def __init__(self): + super(RelProp, self).__init__() + # if not self.training: + self.register_forward_hook(forward_hook) + + def gradprop(self, Z, X, S): + C = torch.autograd.grad(Z, X, S, retain_graph=True) + return C + + def relprop(self, R, alpha): + return R + + +class RelPropSimple(RelProp): + def relprop(self, R, alpha): + Z = self.forward(self.X) + S = safe_divide(R, Z) + C = self.gradprop(Z, self.X, S) + + if torch.is_tensor(self.X) == False: + outputs = [] + outputs.append(self.X[0] * C[0]) + outputs.append(self.X[1] * C[1]) + else: + outputs = self.X * (C[0]) + return outputs + +class AddEye(RelPropSimple): + # input of shape B, C, seq_len, seq_len + def forward(self, input): + return input + torch.eye(input.shape[2]).expand_as(input).to(input.device) + +class ReLU(nn.ReLU, RelProp): + pass + +class GELU(nn.GELU, RelProp): + pass + +class Softmax(nn.Softmax, RelProp): + pass + +class Mul(RelPropSimple): + def forward(self, inputs): + return torch.mul(*inputs) + +class Tanh(nn.Tanh, RelProp): + pass +class LayerNorm(nn.LayerNorm, RelProp): + pass + +class Dropout(nn.Dropout, RelProp): + pass + +class MatMul(RelPropSimple): + def forward(self, inputs): + return torch.matmul(*inputs) + +class MaxPool2d(nn.MaxPool2d, RelPropSimple): + pass + +class LayerNorm(nn.LayerNorm, RelProp): + pass + +class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple): + pass + + +class AvgPool2d(nn.AvgPool2d, RelPropSimple): + pass + + +class Add(RelPropSimple): + def forward(self, inputs): + return torch.add(*inputs) + + def relprop(self, R, alpha): + Z = self.forward(self.X) + S = safe_divide(R, Z) + C = self.gradprop(Z, self.X, S) + + a = self.X[0] * C[0] + b = self.X[1] * C[1] + + a_sum = a.sum() + b_sum = b.sum() + + a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() + b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() + + a = a * safe_divide(a_fact, a.sum()) + b = b * safe_divide(b_fact, b.sum()) + + outputs = [a, b] + + return outputs + +class einsum(RelPropSimple): + def __init__(self, equation): + super().__init__() + self.equation = equation + def forward(self, *operands): + return torch.einsum(self.equation, *operands) + +class IndexSelect(RelProp): + def forward(self, inputs, dim, indices): + self.__setattr__('dim', dim) + self.__setattr__('indices', indices) + + return torch.index_select(inputs, dim, indices) + + def relprop(self, R, alpha): + Z = self.forward(self.X, self.dim, self.indices) + S = safe_divide(R, Z) + C = self.gradprop(Z, self.X, S) + + if torch.is_tensor(self.X) == False: + outputs = [] + outputs.append(self.X[0] * C[0]) + outputs.append(self.X[1] * C[1]) + else: + outputs = self.X * (C[0]) + return outputs + + + +class Clone(RelProp): + def forward(self, input, num): + self.__setattr__('num', num) + 
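        # A Clone returns `num` references to the same tensor; during relprop the relevance arriving at
        # each copy is redistributed through gradprop and summed back onto the single shared input.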
outputs = [] + for _ in range(num): + outputs.append(input) + + return outputs + + def relprop(self, R, alpha): + Z = [] + for _ in range(self.num): + Z.append(self.X) + S = [safe_divide(r, z) for r, z in zip(R, Z)] + C = self.gradprop(Z, self.X, S)[0] + + R = self.X * C + + return R + + +class Cat(RelProp): + def forward(self, inputs, dim): + self.__setattr__('dim', dim) + return torch.cat(inputs, dim) + + def relprop(self, R, alpha): + Z = self.forward(self.X, self.dim) + S = safe_divide(R, Z) + C = self.gradprop(Z, self.X, S) + + outputs = [] + for x, c in zip(self.X, C): + outputs.append(x * c) + + return outputs + + +class Sequential(nn.Sequential): + def relprop(self, R, alpha): + for m in reversed(self._modules.values()): + R = m.relprop(R, alpha) + return R + + +class BatchNorm2d(nn.BatchNorm2d, RelProp): + def relprop(self, R, alpha): + X = self.X + beta = 1 - alpha + weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / ( + (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5)) + Z = X * weight + 1e-9 + S = R / Z + Ca = S * weight + R = self.X * (Ca) + return R + + +class Linear(nn.Linear, RelProp): + def relprop(self, R, alpha): + beta = alpha - 1 + pw = torch.clamp(self.weight, min=0) + nw = torch.clamp(self.weight, max=0) + px = torch.clamp(self.X, min=0) + nx = torch.clamp(self.X, max=0) + + def f(w1, w2, x1, x2): + Z1 = F.linear(x1, w1) + Z2 = F.linear(x2, w2) + S1 = safe_divide(R, Z1 + Z2) + S2 = safe_divide(R, Z1 + Z2) + C1 = x1 * self.gradprop(Z1, x1, S1)[0] + C2 = x2 * self.gradprop(Z2, x2, S2)[0] + + return C1 + C2 + + activator_relevances = f(pw, nw, px, nx) + inhibitor_relevances = f(nw, pw, px, nx) + + R = alpha * activator_relevances - beta * inhibitor_relevances + + return R + + +class Conv2d(nn.Conv2d, RelProp): + def gradprop2(self, DY, weight): + Z = self.forward(self.X) + + output_padding = self.X.size()[2] - ( + (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0]) + + return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding) + + def relprop(self, R, alpha): + if self.X.shape[1] == 3: + pw = torch.clamp(self.weight, min=0) + nw = torch.clamp(self.weight, max=0) + X = self.X + L = self.X * 0 + \ + torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, + keepdim=True)[0] + H = self.X * 0 + \ + torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, + keepdim=True)[0] + Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \ + torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \ + torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9 + + S = R / Za + C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw) + R = C + else: + beta = alpha - 1 + pw = torch.clamp(self.weight, min=0) + nw = torch.clamp(self.weight, max=0) + px = torch.clamp(self.X, min=0) + nx = torch.clamp(self.X, max=0) + + def f(w1, w2, x1, x2): + Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding) + Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding) + S1 = safe_divide(R, Z1) + S2 = safe_divide(R, Z2) + C1 = x1 * self.gradprop(Z1, x1, S1)[0] + C2 = x2 * self.gradprop(Z2, x2, S2)[0] + return C1 + C2 + + activator_relevances = f(pw, nw, px, nx) + inhibitor_relevances = f(nw, pw, px, nx) + + R = alpha * activator_relevances - beta * 
inhibitor_relevances + return R diff --git a/lxmert/src/lxmert_lrp.py b/lxmert/src/lxmert_lrp.py new file mode 100644 index 0000000000000000000000000000000000000000..8f38987f593d5344fe4cb2d5c615a990df6fa8d2 --- /dev/null +++ b/lxmert/src/lxmert_lrp.py @@ -0,0 +1,1693 @@ +# coding=utf-8 +# Copyright 2018 Hao Tan, Mohit Bansal, and the HuggingFace team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch lxmert model. """ + +import math +import os +import warnings +import copy +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss, SmoothL1Loss +from lxmert.lxmert.src.layers import * +from transformers.file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.configuration_lxmert import LxmertConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LxmertConfig" +_TOKENIZER_FOR_DOC = "LxmertTokenizer" + +LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "unc-nlp/lxmert-base-uncased", +] + +ACT2FN = { + "relu": ReLU, + "tanh": Tanh, + "gelu": GELU, +} + + +@dataclass +class LxmertModelOutput(ModelOutput): + """ + Lxmert's outputs that contain the last hidden states, pooled outputs, and attention probabilities for the language, + visual, and, cross-modality encoders. (note: the visual encoder in Lxmert is referred to as the "relation-ship" + encoder") + + + Args: + language_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the language encoder. + vision_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the visual encoder. + pooled_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification, CLS, token) further processed + by a Linear layer and a Tanh activation function. The Linear + language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. 
+ language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + language_output: Optional[torch.FloatTensor] = None + vision_output: Optional[torch.FloatTensor] = None + pooled_output: Optional[torch.FloatTensor] = None + language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + language_attentions: Optional[Tuple[torch.FloatTensor]] = None + vision_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class LxmertForQuestionAnsweringOutput(ModelOutput): + """ + Output type of :class:`~transformers.LxmertForQuestionAnswering`. + + Args: + loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss.k. + question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`, `optional`): + Prediction scores of question answering objective (classification). + language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. 
+ vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + question_answering_score: Optional[torch.FloatTensor] = None + language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + language_attentions: Optional[Tuple[torch.FloatTensor]] = None + vision_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class LxmertForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.LxmertForPreTraining`. + + Args: + loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cross_relationship_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`): + Prediction scores of the textual matching objective (classification) head (scores of True/False + continuation before SoftMax). + question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`): + Prediction scores of question answering objective (classification). + language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. 
+ vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + + """ + + loss: [torch.FloatTensor] = None + prediction_logits: Optional[torch.FloatTensor] = None + cross_relationship_score: Optional[torch.FloatTensor] = None + question_answering_score: Optional[torch.FloatTensor] = None + language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + language_attentions: Optional[Tuple[torch.FloatTensor]] = None + vision_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +def load_tf_weights_in_lxmert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] + for n in name + ): + logger.info("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert pointer.shape 
== array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +class LxmertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = Dropout(config.hidden_dropout_prob) + + self.add1 = Add() + self.add2 = Add() + + def forward(self, input_ids, token_type_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + device = input_ids.device + else: + input_shape = inputs_embeds.size()[:-1] + device = inputs_embeds.device + seq_length = input_shape[1] + + position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).expand(input_shape) + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + # embeddings = inputs_embeds + position_embeddings + token_type_embeddings + embeddings = self.add1([token_type_embeddings, position_embeddings]) + embeddings = self.add2([embeddings, inputs_embeds]) + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def relprop(self, cam, **kwargs): + cam = self.dropout.relprop(cam, **kwargs) + cam = self.LayerNorm.relprop(cam, **kwargs) + + # [inputs_embeds, position_embeddings, token_type_embeddings] + (cam) = self.add2.relprop(cam, **kwargs) + + return cam + + +class LxmertAttention(nn.Module): + def __init__(self, config, ctx_dim=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.head_size = self.num_attention_heads * self.attention_head_size + + # visual_dim = 2048 + if ctx_dim is None: + ctx_dim = config.hidden_size + self.query = Linear(config.hidden_size, self.head_size) + self.key = Linear(ctx_dim, self.head_size) + self.value = Linear(ctx_dim, self.head_size) + + self.dropout = Dropout(config.attention_probs_dropout_prob) + + self.matmul1 = MatMul() + self.matmul2 = MatMul() + self.softmax = Softmax(dim=-1) + self.add = Add() + self.mul = Mul() + self.head_mask = None + self.attention_mask = None + self.clone = Clone() + + self.attn = None + self.attn_gradients = None + self.attn_cam = None + + def get_attn(self): + return self.attn + + def save_attn(self, attn): + self.attn = attn + + def get_attn_cam(self): + return self.attn_cam + + def 
save_attn_cam(self, attn_cam): + self.attn_cam = attn_cam + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def transpose_for_scores_relprop(self, x): + return x.permute(0, 2, 1, 3).flatten(2) + + def forward(self, hidden_states, context, attention_mask=None, output_attentions=False): + key, value = self.clone(context, 2) + mixed_query_layer = self.query(hidden_states) + # mixed_key_layer = self.key(context) + # mixed_value_layer = self.value(context) + mixed_key_layer = self.key(key) + mixed_value_layer = self.value(value) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)]) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + if attention_mask is not None: + attention_scores = self.add([attention_scores, attention_mask]) + + # Normalize the attention scores to probabilities. + attention_probs = self.softmax(attention_scores) + + self.save_attn(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
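        # Note: the pre-dropout attention probabilities and their gradients are cached above via
        # save_attn and the registered hook, so attention/relevance maps can be read back out later.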
+ attention_probs = self.dropout(attention_probs) + + context_layer = self.matmul2([attention_probs, value_layer]) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + def relprop(self, cam, **kwargs): + # Assume output_attentions == False + cam = self.transpose_for_scores(cam) + + # [attention_probs, value_layer] + (cam1, cam2) = self.matmul2.relprop(cam, **kwargs) + cam1 /= 2 + cam2 /= 2 + + self.save_attn_cam(cam1) + + cam1 = self.dropout.relprop(cam1, **kwargs) + + cam1 = self.softmax.relprop(cam1, **kwargs) + + if self.attention_mask is not None: + # [attention_scores, attention_mask] + (cam1, _) = self.add.relprop(cam1, **kwargs) + + # [query_layer, key_layer.transpose(-1, -2)] + (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs) + cam1_1 /= 2 + cam1_2 /= 2 + + # query + cam1_1 = self.transpose_for_scores_relprop(cam1_1) + cam1_1 = self.query.relprop(cam1_1, **kwargs) + + # key + cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2)) + cam1_2 = self.key.relprop(cam1_2, **kwargs) + + # value + cam2 = self.transpose_for_scores_relprop(cam2) + cam2 = self.value.relprop(cam2, **kwargs) + + cam = self.clone.relprop((cam1_2, cam2), **kwargs) + + # returning two cams- one for the hidden state and one for the context + return (cam1_1, cam) + + +class LxmertAttentionOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = Dropout(config.hidden_dropout_prob) + self.add = Add() + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + add = self.add([hidden_states, input_tensor]) + hidden_states = self.LayerNorm(add) + return hidden_states + + def relprop(self, cam, **kwargs): + cam = self.LayerNorm.relprop(cam, **kwargs) + # [hidden_states, input_tensor] + (cam1, cam2) = self.add.relprop(cam, **kwargs) + cam1 = self.dropout.relprop(cam1, **kwargs) + cam1 = self.dense.relprop(cam1, **kwargs) + + return (cam1, cam2) + + +class LxmertCrossAttentionLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.att = LxmertAttention(config) + self.output = LxmertAttentionOutput(config) + self.clone = Clone() + + def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None, output_attentions=False): + inp1, inp2 = self.clone(input_tensor, 2) + output = self.att(inp1, ctx_tensor, ctx_att_mask, output_attentions=output_attentions) + if output_attentions: + attention_probs = output[1] + attention_output = self.output(output[0], inp2) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + return outputs + + def relprop(self, cam, **kwargs): + cam_output, cam_inp2 = self.output.relprop(cam, **kwargs) + cam_inp1, cam_ctx = self.att.relprop(cam_output, **kwargs) + cam_inp = self.clone.relprop((cam_inp1, cam_inp2), **kwargs) + + return (cam_inp, cam_ctx) + + +class LxmertSelfAttentionLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.self = LxmertAttention(config) + self.output = LxmertAttentionOutput(config) + self.clone = Clone() + + def forward(self, input_tensor, attention_mask, output_attentions=False): + inp1, inp2, 
inp3 = self.clone(input_tensor, 3) + # Self attention attends to itself, thus keys and queries are the same (input_tensor). + output = self.self( + inp1, + inp2, + attention_mask, + output_attentions=output_attentions, + ) + if output_attentions: + attention_probs = output[1] + attention_output = self.output(output[0], inp3) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + return outputs + + def relprop(self, cam, **kwargs): + cam_output, cam_inp3 = self.output.relprop(cam, **kwargs) + cam_inp1, cam_inp2 = self.self.relprop(cam_output, **kwargs) + cam_inp = self.clone.relprop((cam_inp1, cam_inp2, cam_inp3), **kwargs) + + return cam_inp + + +class LxmertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = ACT2FN[config.hidden_act]() + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + def relprop(self, cam, **kwargs): + cam = self.intermediate_act_fn.relprop(cam, **kwargs) + cam = self.dense.relprop(cam, **kwargs) + return cam + + +class LxmertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = Dropout(config.hidden_dropout_prob) + self.add = Add() + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + add = self.add([hidden_states, input_tensor]) + hidden_states = self.LayerNorm(add) + return hidden_states + + def relprop(self, cam, **kwargs): + cam = self.LayerNorm.relprop(cam, **kwargs) + # [hidden_states, input_tensor] + (cam1, cam2)= self.add.relprop(cam, **kwargs) + cam1 = self.dropout.relprop(cam1, **kwargs) + cam1 = self.dense.relprop(cam1, **kwargs) + return (cam1, cam2) + + +class LxmertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = LxmertSelfAttentionLayer(config) + self.intermediate = LxmertIntermediate(config) + self.output = LxmertOutput(config) + self.clone = Clone() + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions) + attention_output = outputs[0] + ao1, ao2 = self.clone(attention_output, 2) + intermediate_output = self.intermediate(ao1) + layer_output = self.output(intermediate_output, ao2) + outputs = (layer_output,) + outputs[1:] # add attentions if we output them + return outputs + + def relprop(self, cam, **kwargs): + (cam1, cam2) = self.output.relprop(cam, **kwargs) + cam1 = self.intermediate.relprop(cam1, **kwargs) + cam = self.clone.relprop((cam1, cam2), **kwargs) + cam = self.attention.relprop(cam, **kwargs) + return cam + + +class LxmertXLayer(nn.Module): + def __init__(self, config): + super().__init__() + # The cross-attention Layer + self.visual_attention = LxmertCrossAttentionLayer(config) + + # Self-attention Layers + self.lang_self_att = LxmertSelfAttentionLayer(config) + self.visn_self_att = LxmertSelfAttentionLayer(config) + + # Intermediate and Output Layers (FFNs) + self.lang_inter = LxmertIntermediate(config) + self.lang_output = LxmertOutput(config) + self.visn_inter = LxmertIntermediate(config) + self.visn_output = LxmertOutput(config) + + self.clone1 = Clone() + 
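        # clone1/clone2 duplicate the language and visual streams for the two cross-attention
        # directions, and clone3/clone4 duplicate them again around the FFN residuals, so relprop can
        # split and re-merge the relevance flowing through each branch.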
self.clone2 = Clone() + self.clone3 = Clone() + self.clone4 = Clone() + + def cross_att( + self, + lang_input, + lang_attention_mask, + visual_input, + visual_attention_mask, + output_x_attentions=False, + ): + lang_input1, lang_input2 = self.clone1(lang_input, 2) + visual_input1, visual_input2 = self.clone2(visual_input, 2) + if not hasattr(self, 'visual_attention_copy'): + self.visual_attention_copy = copy.deepcopy(self.visual_attention) + # Cross Attention + lang_att_output = self.visual_attention( + lang_input1, + visual_input1, + ctx_att_mask=visual_attention_mask, + output_attentions=output_x_attentions, + ) + visual_att_output = self.visual_attention_copy( + visual_input2, + lang_input2, + ctx_att_mask=lang_attention_mask, + output_attentions=False, + ) + return lang_att_output, visual_att_output + + def relprop_cross(self, cam, **kwargs): + cam_lang, cam_vis = cam + cam_vis2, cam_lang2 = self.visual_attention_copy.relprop(cam_vis, **kwargs) + cam_lang1, cam_vis1 = self.visual_attention.relprop(cam_lang, **kwargs) + cam_vis = self.clone2.relprop((cam_vis1, cam_vis2), **kwargs) + cam_lang = self.clone1.relprop((cam_lang1, cam_lang2), **kwargs) + return cam_lang, cam_vis + + + def self_att(self, lang_input, lang_attention_mask, visual_input, visual_attention_mask): + # Self Attention + lang_att_output = self.lang_self_att(lang_input, lang_attention_mask, output_attentions=False) + visual_att_output = self.visn_self_att(visual_input, visual_attention_mask, output_attentions=False) + return lang_att_output[0], visual_att_output[0] + + def relprop_self(self, cam, **kwargs): + cam_lang, cam_vis = cam + cam_vis = self.visn_self_att.relprop(cam_vis, **kwargs) + cam_lang = self.lang_self_att.relprop(cam_lang, **kwargs) + return cam_lang, cam_vis + + def output_fc(self, lang_input, visual_input): + lang_input1, lang_input2 = self.clone3(lang_input, 2) + visual_input1, visual_input2 = self.clone4(visual_input, 2) + # FC layers + lang_inter_output = self.lang_inter(lang_input1) + visual_inter_output = self.visn_inter(visual_input1) + + # Layer output + lang_output = self.lang_output(lang_inter_output, lang_input2) + visual_output = self.visn_output(visual_inter_output, visual_input2) + + return lang_output, visual_output + + def relprop_output(self, cam, **kwargs): + cam_lang, cam_vis = cam + cam_vis_inter, cam_vis2 = self.visn_output.relprop(cam_vis, **kwargs) + cam_lang_inter, cam_lang2 = self.lang_output.relprop(cam_lang, **kwargs) + cam_vis1 = self.visn_inter.relprop(cam_vis_inter, **kwargs) + cam_lang1 = self.lang_inter.relprop(cam_lang_inter, **kwargs) + cam_vis = self.clone4.relprop((cam_vis1, cam_vis2), **kwargs) + cam_lang = self.clone3.relprop((cam_lang1, cam_lang2), **kwargs) + return cam_lang, cam_vis + + def forward( + self, + lang_feats, + lang_attention_mask, + visual_feats, + visual_attention_mask, + output_attentions=False, + ): + lang_att_output, visual_att_output = self.cross_att( + lang_input=lang_feats, + lang_attention_mask=lang_attention_mask, + visual_input=visual_feats, + visual_attention_mask=visual_attention_mask, + output_x_attentions=output_attentions, + ) + attention_probs = lang_att_output[1:] + lang_att_output, visual_att_output = self.self_att( + lang_att_output[0], + lang_attention_mask, + visual_att_output[0], + visual_attention_mask, + ) + + lang_output, visual_output = self.output_fc(lang_att_output, visual_att_output) + return ( + ( + lang_output, + visual_output, + attention_probs[0], + ) + if output_attentions + else (lang_output, visual_output) + ) + + 
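    # relprop below retraces forward() in reverse: relevance first flows back through the output FFNs,
    # then through the two self-attention blocks, and finally through the cross-attention pair,
    # mirroring cross_att -> self_att -> output_fc above. A rough standalone sketch (the config
    # construction and tensor shapes here are illustrative, not taken from the training pipeline):
    #
    #     from transformers import LxmertConfig
    #     cfg = LxmertConfig()
    #     x_layer = LxmertXLayer(cfg)
    #     lang = torch.randn(1, 20, cfg.hidden_size)
    #     vis = torch.randn(1, 36, cfg.hidden_size)
    #     lang_out, vis_out = x_layer(lang, None, vis, None)
    #     cam_lang, cam_vis = x_layer.relprop(
    #         (torch.ones_like(lang_out), torch.ones_like(vis_out)), alpha=1
    #     )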
def relprop(self, cam, **kwargs): + cam_lang, cam_vis = cam + cam_lang, cam_vis = self.relprop_output((cam_lang, cam_vis), **kwargs) + cam_lang, cam_vis = self.relprop_self((cam_lang, cam_vis), **kwargs) + cam_lang, cam_vis = self.relprop_cross((cam_lang, cam_vis), **kwargs) + return cam_lang, cam_vis + +class LxmertVisualFeatureEncoder(nn.Module): + def __init__(self, config): + super().__init__() + feat_dim = config.visual_feat_dim + pos_dim = config.visual_pos_dim + + # Object feature encoding + self.visn_fc = Linear(feat_dim, config.hidden_size) + self.visn_layer_norm = LayerNorm(config.hidden_size, eps=1e-12) + + # Box position encoding + self.box_fc = Linear(pos_dim, config.hidden_size) + self.box_layer_norm = LayerNorm(config.hidden_size, eps=1e-12) + + self.dropout = Dropout(config.hidden_dropout_prob) + + def forward(self, visual_feats, visual_pos): + x = self.visn_fc(visual_feats) + x = self.visn_layer_norm(x) + y = self.box_fc(visual_pos) + y = self.box_layer_norm(y) + output = (x + y) / 2 + + output = self.dropout(output) + return output + + def relprop(self, cam, **kwargs): + cam = self.dropout.relprop(cam, **kwargs) + cam = self.visn_layer_norm.relprop(cam, **kwargs) + cam = self.visn_fc.relprop(cam, **kwargs) + return cam + +class LxmertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + # Obj-level image embedding layer + self.visn_fc = LxmertVisualFeatureEncoder(config) + self.config = config + + # Number of layers + self.num_l_layers = config.l_layers + self.num_x_layers = config.x_layers + self.num_r_layers = config.r_layers + + # Layers + # Using self.layer instead of self.l_layer to support loading BERT weights. + self.layer = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_l_layers)]) + self.x_layers = nn.ModuleList([LxmertXLayer(config) for _ in range(self.num_x_layers)]) + self.r_layers = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_r_layers)]) + + def forward( + self, + lang_feats, + lang_attention_mask, + visual_feats, + visual_pos, + visual_attention_mask=None, + output_attentions=None, + ): + + vision_hidden_states = () + language_hidden_states = () + vision_attentions = () if output_attentions or self.config.output_attentions else None + language_attentions = () if output_attentions or self.config.output_attentions else None + cross_encoder_attentions = () if output_attentions or self.config.output_attentions else None + + visual_feats = self.visn_fc(visual_feats, visual_pos) + + # Run language layers + for layer_module in self.layer: + l_outputs = layer_module(lang_feats, lang_attention_mask, output_attentions=output_attentions) + lang_feats = l_outputs[0] + language_hidden_states = language_hidden_states + (lang_feats,) + if language_attentions is not None: + language_attentions = language_attentions + (l_outputs[1],) + + # Run relational layers + for layer_module in self.r_layers: + v_outputs = layer_module(visual_feats, visual_attention_mask, output_attentions=output_attentions) + visual_feats = v_outputs[0] + vision_hidden_states = vision_hidden_states + (visual_feats,) + if vision_attentions is not None: + vision_attentions = vision_attentions + (v_outputs[1],) + + # Run cross-modality layers + for layer_module in self.x_layers: + x_outputs = layer_module( + lang_feats, + lang_attention_mask, + visual_feats, + visual_attention_mask, + output_attentions=output_attentions, + ) + lang_feats, visual_feats = x_outputs[:2] + vision_hidden_states = vision_hidden_states + (visual_feats,) + language_hidden_states = 
language_hidden_states + (lang_feats,) + if cross_encoder_attentions is not None: + cross_encoder_attentions = cross_encoder_attentions + (x_outputs[2],) + visual_encoder_outputs = ( + vision_hidden_states, + vision_attentions if output_attentions else None, + ) + lang_encoder_outputs = ( + language_hidden_states, + language_attentions if output_attentions else None, + ) + return ( + visual_encoder_outputs, + lang_encoder_outputs, + cross_encoder_attentions if output_attentions else None, + ) + + def relprop(self, cam, **kwargs): + cam_lang, cam_vis = cam + for layer_module in reversed(self.x_layers): + cam_lang, cam_vis = layer_module.relprop((cam_lang, cam_vis), **kwargs) + + for layer_module in reversed(self.r_layers): + cam_vis = layer_module.relprop(cam_vis, **kwargs) + + for layer_module in reversed(self.layer): + cam_lang = layer_module.relprop(cam_lang, **kwargs) + return cam_lang, cam_vis + + +class LxmertPooler(nn.Module): + def __init__(self, config): + super(LxmertPooler, self).__init__() + self.dense = Linear(config.hidden_size, config.hidden_size) + self.activation = Tanh() + + self.pool = IndexSelect() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + # first_token_tensor = hidden_states[:, 0] + first_token_tensor = self.pool(hidden_states, 1, torch.tensor(0, device=hidden_states.device)) + first_token_tensor = first_token_tensor.squeeze(1) + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + def relprop(self, cam, **kwargs): + cam = self.activation.relprop(cam, **kwargs) + cam = self.dense.relprop(cam, **kwargs) + cam = cam.unsqueeze(1) + cam = self.pool.relprop(cam, **kwargs) + + return cam + + +class LxmertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(LxmertPredictionHeadTransform, self).__init__() + self.dense = Linear(config.hidden_size, config.hidden_size) + self.transform_act_fn = ACT2FN[config.hidden_act] + self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + def relprop(self, cam, **kwargs): + cam = self.LayerNorm.relprop(cam, **kwargs) + cam = self.transform_act_fn.relprop(cam, **kwargs) + cam = self.dense.relprop(cam, **kwargs) + return cam + + +class LxmertLMPredictionHead(nn.Module): + def __init__(self, config, lxmert_model_embedding_weights): + super(LxmertLMPredictionHead, self).__init__() + self.transform = LxmertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
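        # The decoder is built from the LRP-aware Linear defined in layers.py, so relevance can be
        # propagated back through the tied embedding weights in relprop.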
+ self.decoder = Linear( + lxmert_model_embedding_weights.size(1), + lxmert_model_embedding_weights.size(0), + bias=False, + ) + self.decoder.weight = lxmert_model_embedding_weights + self.bias = nn.Parameter(torch.zeros(lxmert_model_embedding_weights.size(0))) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + + def relprop(self, cam, **kwargs): + cam = self.decoder.relprop(cam, **kwargs) + cam = self.transform.relprop(cam, **kwargs) + return cam + + +class LxmertVisualAnswerHead(nn.Module): + def __init__(self, config, num_labels): + super().__init__() + hid_dim = config.hidden_size + self.logit_fc = nn.Sequential( + Linear(hid_dim, hid_dim * 2), + GELU(), + LayerNorm(hid_dim * 2, eps=1e-12), + Linear(hid_dim * 2, num_labels), + ) + + def forward(self, hidden_states): + return self.logit_fc(hidden_states) + + def relprop(self, cam, **kwargs): + for m in reversed(self.logit_fc._modules.values()): + cam = m.relprop(cam, **kwargs) + return cam + + +class LxmertVisualObjHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = LxmertPredictionHeadTransform(config) + # Decide the use of visual losses + visual_losses = {} + if config.visual_obj_loss: + visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels} + if config.visual_attr_loss: + visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels} + if config.visual_obj_loss: + visual_losses["feat"] = { + "shape": (-1, config.visual_feat_dim), + "num": config.visual_feat_dim, + } + self.visual_losses = visual_losses + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder_dict = nn.ModuleDict( + {key: nn.Linear(config.hidden_size, self.visual_losses[key]["num"]) for key in self.visual_losses} + ) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + output = {} + for key in self.visual_losses: + output[key] = self.decoder_dict[key](hidden_states) + return output + + def relprop(self, cam, **kwargs): + return self.transform.relprop(cam, **kwargs) + + +class LxmertPreTrainingHeads(nn.Module): + def __init__(self, config, lxmert_model_embedding_weights): + super(LxmertPreTrainingHeads, self).__init__() + self.predictions = LxmertLMPredictionHead(config, lxmert_model_embedding_weights) + self.seq_relationship = Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + def relprop(self, cam, **kwargs): + cam_seq, cam_pooled = cam + cam_pooled = self.seq_relationship.relprop(cam_pooled, **kwargs) + cam_seq = self.predictions.relprop(cam_seq, **kwargs) + return cam_seq, cam_pooled + + +class LxmertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
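
    The subclasses below keep the original LXMERT parameter layout, so the standard checkpoint should
    load unchanged. Example (a minimal sketch; the checkpoint name is the one listed in
    LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST above)::

        model = LxmertModel.from_pretrained("unc-nlp/lxmert-base-uncased")
        model.eval()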
+ """ + + config_class = LxmertConfig + load_tf_weights = load_tf_weights_in_lxmert + base_model_prefix = "lxmert" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +LXMERT_START_DOCSTRING = r""" + + The lxmert model was proposed in `lxmert: Learning Cross-Modality Encoder Representations from Transformers + `__ by Hao Tan and Mohit Bansal. It's a vision and language transformer model, + pretrained on a variety of multi-modal datasets comprising of GQA, VQAv2.0, MCSCOCO captions, and Visual genome, + using a combination of masked language modeling, region of interest feature regression, cross entropy loss for + question answering attribute prediction, and object tag prediction. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.LxmertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +LXMERT_INPUTS_DOCSTRING = r""" + + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.LxmertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + visual_feats: (:obj:`torch.FloatTensor` of shape :obj:՝(batch_size, num_visual_features, visual_feat_dim)՝): + This input represents visual features. They ROI pooled object features from bounding boxes using a + faster-RCNN model) + + These are currently not provided by the transformers library. + visual_pos: (:obj:`torch.FloatTensor` of shape :obj:՝(batch_size, num_visual_features, visual_pos_dim)՝): + This input represents spacial features corresponding to their relative (via index) visual features. The + pre-trained lxmert model expects these spacial features to be normalized bounding boxes on a scale of 0 to + 1. + + These are currently not provided by the transformers library. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. 
Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`__ + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Lxmert Model transformer outputting raw hidden-states without any specific head on top.", + LXMERT_START_DOCSTRING, +) +class LxmertModel(LxmertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.embeddings = LxmertEmbeddings(config) + self.encoder = LxmertEncoder(config) + self.pooler = LxmertPooler(config) + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="unc-nlp/lxmert-base-uncased", + output_type=LxmertModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + visual_feats=None, + visual_pos=None, + attention_mask=None, + visual_attention_mask=None, + token_type_ids=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + assert visual_feats is not None, "`visual_feats` cannot be `None`" + assert visual_pos is not None, "`visual_pos` cannot be `None`" + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if 
token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # Process the visual attention mask + if visual_attention_mask is not None: + extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2) + extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype) + extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0 + else: + extended_visual_attention_mask = None + + # Positional Word Embeddings + embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds) + + # Run Lxmert encoder + encoder_outputs = self.encoder( + embedding_output, + extended_attention_mask, + visual_feats=visual_feats, + visual_pos=visual_pos, + visual_attention_mask=extended_visual_attention_mask, + output_attentions=output_attentions, + ) + + visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2] + vision_hidden_states = visual_encoder_outputs[0] + language_hidden_states = lang_encoder_outputs[0] + + all_attentions = () + if output_attentions: + language_attentions = lang_encoder_outputs[1] + vision_attentions = visual_encoder_outputs[1] + cross_encoder_attentions = encoder_outputs[2] + all_attentions = ( + language_attentions, + vision_attentions, + cross_encoder_attentions, + ) + + hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else () + + visual_output = vision_hidden_states[-1] + lang_output = language_hidden_states[-1] + pooled_output = self.pooler(lang_output) + + if not return_dict: + return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions + + return LxmertModelOutput( + pooled_output=pooled_output, + language_output=lang_output, + vision_output=visual_output, + language_hidden_states=language_hidden_states if output_hidden_states else None, + vision_hidden_states=vision_hidden_states if output_hidden_states else None, + language_attentions=language_attentions if output_attentions else None, + vision_attentions=vision_attentions if output_attentions else None, + cross_encoder_attentions=cross_encoder_attentions if output_attentions else None, + ) + + def relprop(self, cam, **kwargs): + cam_lang, cam_vis = cam + cam_lang = self.pooler.relprop(cam_lang, **kwargs) + cam_lang, cam_vis = self.encoder.relprop((cam_lang, cam_vis), **kwargs) + return cam_lang, cam_vis + + + +@add_start_docstrings( + """Lxmert Model with a specified pretraining head on top. 
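    The heads cover masked language modelling and cross-modality matching, plus visual-object
    prediction and question answering when the corresponding ``task_*`` flags are enabled in the
    configuration.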
""", + LXMERT_START_DOCSTRING, +) +class LxmertForPreTraining(LxmertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + # Configuration + self.config = config + self.num_qa_labels = config.num_qa_labels + self.visual_loss_normalizer = config.visual_loss_normalizer + + # Use of pretraining tasks + self.task_mask_lm = config.task_mask_lm + self.task_obj_predict = config.task_obj_predict + self.task_matched = config.task_matched + self.task_qa = config.task_qa + + # Lxmert backbone + self.lxmert = LxmertModel(config) + + # Pre-training heads + self.cls = LxmertPreTrainingHeads(config, self.lxmert.embeddings.word_embeddings.weight) + if self.task_obj_predict: + self.obj_predict_head = LxmertVisualObjHead(config) + if self.task_qa: + self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels) + + # Weight initialization + self.init_weights() + + # Loss functions + self.loss_fcts = { + "l2": SmoothL1Loss(reduction="none"), + "visual_ce": CrossEntropyLoss(reduction="none"), + "ce": CrossEntropyLoss(), + } + + visual_losses = {} + if config.visual_obj_loss: + visual_losses["obj"] = { + "shape": (-1,), + "num": config.num_object_labels, + "loss": "visual_ce", + } + if config.visual_attr_loss: + visual_losses["attr"] = { + "shape": (-1,), + "num": config.num_attr_labels, + "loss": "visual_ce", + } + if config.visual_obj_loss: + visual_losses["feat"] = { + "shape": (-1, config.visual_feat_dim), + "num": config.visual_feat_dim, + "loss": "l2", + } + self.visual_losses = visual_losses + + def resize_num_qa_labels(self, num_labels): + """ + Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size + will add newly initialized weights. Reducing the size will remove weights from the end + + Args: + num_labels (:obj:`int`, `optional`): + New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized + weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`, + just returns a pointer to the qa labels :obj:`torch.nn.Linear`` module of the model without doing + anything. + + Return: + :obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer + """ + + cur_qa_logit_layer = self.get_qa_logit_layer() + if num_labels is None or cur_qa_logit_layer is None: + return + new_qa_logit_layer = self._resize_qa_labels(num_labels) + self.config.num_qa_labels = num_labels + self.num_qa_labels = num_labels + + return new_qa_logit_layer + + def _resize_qa_labels(self, num_labels): + cur_qa_logit_layer = self.get_qa_logit_layer() + new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels) + self._set_qa_logit_layer(new_qa_logit_layer) + return self.get_qa_logit_layer() + + def get_qa_logit_layer(self) -> nn.Module: + """ + Returns the the linear layer that produces question answering logits. + + Returns: + :obj:`nn.Module`: A torch module mapping the question answering prediction hidden states or :obj:`None` if + lxmert does not have a visual answering head. 
+ """ + if hasattr(self, "answer_head"): + return self.answer_head.logit_fc[-1] + + def _set_qa_logit_layer(self, qa_logit_layer): + self.answer_head.logit_fc[-1] = qa_logit_layer + + def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels): + + if num_labels is None: + return cur_qa_logit_layer + + cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size() + if cur_qa_labels == num_labels: + return cur_qa_logit_layer + + # Build new linear output + if getattr(cur_qa_logit_layer, "bias", None) is not None: + new_qa_logit_layer = nn.Linear(hidden_dim, num_labels) + else: + new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False) + + new_qa_logit_layer.to(cur_qa_logit_layer.weight.device) + + # initialize all new labels + self._init_weights(new_qa_logit_layer) + + # Copy labels from the previous weights + num_labels_to_copy = min(cur_qa_labels, num_labels) + new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :] + if getattr(cur_qa_logit_layer, "bias", None) is not None: + new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy] + + return new_qa_logit_layer + + @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=LxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + visual_feats=None, + visual_pos=None, + attention_mask=None, + visual_attention_mask=None, + token_type_ids=None, + inputs_embeds=None, + labels=None, + obj_labels=None, + matched_label=None, + ans=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + r""" + labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + obj_labels: (``Dict[Str: Tuple[Torch.FloatTensor, Torch.FloatTensor]]``, `optional`): + each key is named after each one of the visual losses and each element of the tuple is of the shape + ``(batch_size, num_features)`` and ``(batch_size, num_features, visual_feature_dim)`` for each the label id + and the label score respectively + matched_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`): + Labels for computing the whether or not the text input matches the image (classification) loss. Input + should be a sequence pair (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``: + + - 0 indicates that the sentence does not match the image, + - 1 indicates that the sentence does match the image. 
+ ans: (``Torch.Tensor`` of shape ``(batch_size)``, `optional`): + a one hot representation hof the correct answer `optional` + + Returns: + """ + + if "masked_lm_labels" in kwargs: + warnings.warn( + "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("masked_lm_labels") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + device = input_ids.device if input_ids is not None else inputs_embeds.device + lxmert_output = self.lxmert( + input_ids=input_ids, + visual_feats=visual_feats, + visual_pos=visual_pos, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + visual_attention_mask=visual_attention_mask, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + lang_output, visual_output, pooled_output = ( + lxmert_output[0], + lxmert_output[1], + lxmert_output[2], + ) + lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output) + if self.task_qa: + answer_score = self.answer_head(pooled_output) + else: + answer_score = pooled_output[0][0] + + total_loss = ( + None + if (labels is None and matched_label is None and obj_labels is None and ans is None) + else torch.tensor(0.0, device=device) + ) + if labels is not None and self.task_mask_lm: + masked_lm_loss = self.loss_fcts["ce"]( + lang_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + total_loss += masked_lm_loss + if matched_label is not None and self.task_matched: + matched_loss = self.loss_fcts["ce"](cross_relationship_score.view(-1, 2), matched_label.view(-1)) + total_loss += matched_loss + if obj_labels is not None and self.task_obj_predict: + total_visual_loss = torch.tensor(0.0, device=input_ids.device) + visual_prediction_scores_dict = self.obj_predict_head(visual_output) + for key, key_info in self.visual_losses.items(): + label, mask_conf = obj_labels[key] + output_dim = key_info["num"] + loss_fct_name = key_info["loss"] + label_shape = key_info["shape"] + weight = self.visual_loss_normalizer + visual_loss_fct = self.loss_fcts[loss_fct_name] + visual_prediction_scores = visual_prediction_scores_dict[key] + visual_loss = visual_loss_fct( + visual_prediction_scores.view(-1, output_dim), + label.view(*label_shape), + ) + if visual_loss.dim() > 1: # Regression Losses + visual_loss = visual_loss.mean(1) + visual_loss = (visual_loss * mask_conf.view(-1)).mean() * weight + total_visual_loss += visual_loss + total_loss += total_visual_loss + if ans is not None and self.task_qa: + answer_loss = self.loss_fcts["ce"](answer_score.view(-1, self.num_qa_labels), ans.view(-1)) + total_loss += answer_loss + + if not return_dict: + output = ( + lang_prediction_scores, + cross_relationship_score, + answer_score, + ) + lxmert_output[3:] + return ((total_loss,) + output) if total_loss is not None else output + + return LxmertForPreTrainingOutput( + loss=total_loss, + prediction_logits=lang_prediction_scores, + cross_relationship_score=cross_relationship_score, + question_answering_score=answer_score, + language_hidden_states=lxmert_output.language_hidden_states, + vision_hidden_states=lxmert_output.vision_hidden_states, + language_attentions=lxmert_output.language_attentions, + vision_attentions=lxmert_output.vision_attentions, + cross_encoder_attentions=lxmert_output.cross_encoder_attentions, + ) + + + +@add_start_docstrings( + """Lxmert Model with 
a visual-answering head on top for downstream QA tasks""", + LXMERT_START_DOCSTRING, +) +class LxmertForQuestionAnswering(LxmertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + # Configuration + self.config = config + self.num_qa_labels = config.num_qa_labels + self.visual_loss_normalizer = config.visual_loss_normalizer + + # Lxmert backbone + self.lxmert = LxmertModel(config) + + self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels) + + # Weight initialization + self.init_weights() + + # Loss function + self.loss = CrossEntropyLoss() + + def resize_num_qa_labels(self, num_labels): + """ + Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size + will add newly initialized weights. Reducing the size will remove weights from the end + + Args: + num_labels (:obj:`int`, `optional`): + New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized + weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`, + just returns a pointer to the qa labels :obj:`torch.nn.Linear`` module of the model without doing + anything. + + Return: + :obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer + """ + + cur_qa_logit_layer = self.get_qa_logit_layer() + if num_labels is None or cur_qa_logit_layer is None: + return + new_qa_logit_layer = self._resize_qa_labels(num_labels) + self.config.num_qa_labels = num_labels + self.num_qa_labels = num_labels + + return new_qa_logit_layer + + def _resize_qa_labels(self, num_labels): + cur_qa_logit_layer = self.get_qa_logit_layer() + new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels) + self._set_qa_logit_layer(new_qa_logit_layer) + return self.get_qa_logit_layer() + + def get_qa_logit_layer(self) -> nn.Module: + """ + Returns the the linear layer that produces question answering logits + + Returns: + :obj:`nn.Module`: A torch module mapping the question answering prediction hidden states. :obj:`None`: A + NoneType object if Lxmert does not have the visual answering head. 
+ """ + + if hasattr(self, "answer_head"): + return self.answer_head.logit_fc[-1] + + def _set_qa_logit_layer(self, qa_logit_layer): + self.answer_head.logit_fc[-1] = qa_logit_layer + + def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels): + + if num_labels is None: + return cur_qa_logit_layer + + cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size() + if cur_qa_labels == num_labels: + return cur_qa_logit_layer + + # Build new linear output + if getattr(cur_qa_logit_layer, "bias", None) is not None: + new_qa_logit_layer = nn.Linear(hidden_dim, num_labels) + else: + new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False) + + new_qa_logit_layer.to(cur_qa_logit_layer.weight.device) + + # initialize all new labels + self._init_weights(new_qa_logit_layer) + + # Copy labels from the previous weights + num_labels_to_copy = min(cur_qa_labels, num_labels) + new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :] + if getattr(cur_qa_logit_layer, "bias", None) is not None: + new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy] + + return new_qa_logit_layer + + @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="unc-nlp/lxmert-base-uncased", + output_type=LxmertForQuestionAnsweringOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + visual_feats=None, + visual_pos=None, + attention_mask=None, + visual_attention_mask=None, + token_type_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels: (``Torch.Tensor`` of shape ``(batch_size)``, `optional`): + A one-hot representation of the correct answer + + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + lxmert_output = self.lxmert( + input_ids=input_ids, + visual_feats=visual_feats, + visual_pos=visual_pos, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + visual_attention_mask=visual_attention_mask, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + pooled_output = lxmert_output[2] + answer_score = self.answer_head(pooled_output) + loss = None + if labels is not None: + loss = self.loss(answer_score.view(-1, self.num_qa_labels), labels.view(-1)) + + if not return_dict: + output = (answer_score,) + lxmert_output[3:] + return (loss,) + output if loss is not None else output + + self.vis_shape = lxmert_output.vision_output.shape + + return LxmertForQuestionAnsweringOutput( + loss=loss, + question_answering_score=answer_score, + language_hidden_states=lxmert_output.language_hidden_states, + vision_hidden_states=lxmert_output.vision_hidden_states, + language_attentions=lxmert_output.language_attentions, + vision_attentions=lxmert_output.vision_attentions, + cross_encoder_attentions=lxmert_output.cross_encoder_attentions, + ) + + def relprop(self, cam, **kwargs): + cam_lang = self.answer_head.relprop(cam, **kwargs) + cam_vis = torch.zeros(self.vis_shape).to(cam_lang.device) + cam_lang, cam_vis = self.lxmert.relprop((cam_lang, cam_vis), **kwargs) + return cam_lang, cam_vis \ No newline at end of file diff --git a/lxmert/src/lxrt/__init__.py b/lxmert/src/lxrt/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lxmert/src/lxrt/entry.py b/lxmert/src/lxrt/entry.py new file mode 100644 index 0000000000000000000000000000000000000000..daf7cdf314c90222a6eae393fcbcd8a989d802ab --- /dev/null +++ b/lxmert/src/lxrt/entry.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2019 project LXRT. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +import torch.nn as nn + +from ..lxrt.tokenization import BertTokenizer +from ..lxrt.modeling import LXRTFeatureExtraction as VisualBertForLXRFeature, VISUAL_CONFIG + + +class InputFeatures(object): + """A single set of features of data.""" + + def __init__(self, input_ids, input_mask, segment_ids): + self.input_ids = input_ids + self.input_mask = input_mask + self.segment_ids = segment_ids + + +def convert_sents_to_features(sents, max_seq_length, tokenizer): + """Loads a data file into a list of `InputBatch`s.""" + + features = [] + for (i, sent) in enumerate(sents): + tokens_a = tokenizer.tokenize(sent.strip()) + + # Account for [CLS] and [SEP] with "- 2" + if len(tokens_a) > max_seq_length - 2: + tokens_a = tokens_a[:(max_seq_length - 2)] + + # Keep segment id which allows loading BERT-weights. + tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + segment_ids = [0] * len(tokens) + + input_ids = tokenizer.convert_tokens_to_ids(tokens) + + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + + # Zero-pad up to the sequence length. 
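+ # e.g. with max_seq_length = 8 and sent = "is the dog brown":
+ #   tokens      = ["[CLS]", "is", "the", "dog", "brown", "[SEP]"]
+ #   input_mask  = [1, 1, 1, 1, 1, 1, 0, 0]
+ #   segment_ids = [0, 0, 0, 0, 0, 0, 0, 0]
+ # i.e. input_ids, input_mask and segment_ids are all zero-padded to length 8 below.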
+ padding = [0] * (max_seq_length - len(input_ids)) + input_ids += padding + input_mask += padding + segment_ids += padding + + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + + features.append( + InputFeatures(input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids)) + return features + + +def set_visual_config(args): + VISUAL_CONFIG.l_layers = args.llayers + VISUAL_CONFIG.x_layers = args.xlayers + VISUAL_CONFIG.r_layers = args.rlayers + + +class LXRTEncoder(nn.Module): + def __init__(self, args, max_seq_length, mode='x'): + super().__init__() + self.max_seq_length = max_seq_length + set_visual_config(args) + + # Using the bert tokenizer + self.tokenizer = BertTokenizer.from_pretrained( + "bert-base-uncased", + do_lower_case=True + ) + + # Build LXRT Model + self.model = VisualBertForLXRFeature.from_pretrained( + "bert-base-uncased", + mode=mode + ) + + if args.from_scratch: + print("initializing all the weights") + self.model.apply(self.model.init_bert_weights) + + def multi_gpu(self): + self.model = nn.DataParallel(self.model) + + @property + def dim(self): + return 768 + + def forward(self, sents, feats, visual_attention_mask=None): + train_features = convert_sents_to_features( + sents, self.max_seq_length, self.tokenizer) + + input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda() + input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda() + segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda() + + output = self.model(input_ids, segment_ids, input_mask, + visual_feats=feats, + visual_attention_mask=visual_attention_mask) + return output + + def save(self, path): + torch.save(self.model.state_dict(), + os.path.join("%s_LXRT.pth" % path)) + + def load(self, path): + # Load state_dict from snapshot file + print("Load lxmert pre-trained model from %s" % path) + state_dict = torch.load("%s_LXRT.pth" % path) + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("module."): + new_state_dict[key[len("module."):]] = value + else: + new_state_dict[key] = value + state_dict = new_state_dict + + # Print out the differences of pre-trained and model weights. + load_keys = set(state_dict.keys()) + model_keys = set(self.model.state_dict().keys()) + print() + print("Weights in loaded but not in model:") + for key in sorted(load_keys.difference(model_keys)): + print(key) + print() + print("Weights in model but not in loaded:") + for key in sorted(model_keys.difference(load_keys)): + print(key) + print() + + # Load weights to model + self.model.load_state_dict(state_dict, strict=False) + + + + diff --git a/lxmert/src/lxrt/file_utils.py b/lxmert/src/lxrt/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..14919cb322b78302049c465046bb1ad437ec1d4e --- /dev/null +++ b/lxmert/src/lxrt/file_utils.py @@ -0,0 +1,247 @@ +""" +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. 
+""" +import json +import logging +import os +import shutil +import tempfile +from functools import wraps +from hashlib import sha256 +import sys +from io import open + +import boto3 +import requests +from botocore.exceptions import ClientError +from tqdm import tqdm + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + Path.home() / '.pytorch_pretrained_bert')) +except (AttributeError, ImportError): + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + '.json' + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata['url'] + etag = metadata['etag'] + + return url, etag + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. 
+ if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file): + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + raise IOError("HEAD request failed for url {} with status code {}" + .format(url, response.status_code)) + etag = response.headers.get("ETag") + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w', encoding="utf-8") as meta_file: + json.dump(meta, meta_file) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path + + +def read_set_from_file(filename): + ''' + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. 
+ ''' + collection = set() + with open(filename, 'r', encoding='utf-8') as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path, dot=True, lower=True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext diff --git a/lxmert/src/lxrt/modeling.py b/lxmert/src/lxrt/modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..666fd2ee43e0381cb6271324f82d85a88bd27782 --- /dev/null +++ b/lxmert/src/lxrt/modeling.py @@ -0,0 +1,1018 @@ +# coding=utf-8 +# Copyright 2019 project LXRT. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LXRT model.""" + +import copy +import json +import logging +import math +import os +import shutil +import tarfile +import tempfile +import sys +from io import open + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss, SmoothL1Loss + +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_MODEL_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", +} +CONFIG_NAME = 'bert_config.json' +WEIGHTS_NAME = 'pytorch_model.bin' +TF_WEIGHTS_NAME = 'model.ckpt' + +def load_tf_weights_in_bert(model, tf_checkpoint_path): + """ Load tf checkpoints in a pytorch model + """ + try: + import re + import numpy as np + import tensorflow as tf + except Importtokenization: + print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. 
Please see " + "https://www.tensorflow.org/install/ for installation instructions.") + raise + tf_path = os.path.abspath(tf_checkpoint_path) + print("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + print("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ["adam_v", "adam_m"] for n in name): + print("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + l = re.split(r'_(\d+)', m_name) + else: + l = [m_name] + if l[0] == 'kernel' or l[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif l[0] == 'output_bias' or l[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif l[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + else: + pointer = getattr(pointer, l[0]) + if len(l) >= 2: + num = int(l[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +class GeLU(nn.Module): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + def __init__(self): + super().__init__() + + def forward(self, x): + return gelu(x) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class VisualConfig(object): + VISUAL_LOSSES = ['obj', 'attr', 'feat'] + def __init__(self, + l_layers=12, + x_layers=5, + r_layers=0): + self.l_layers = l_layers + self.x_layers = x_layers + self.r_layers = r_layers + + self.visual_feat_dim = 2048 + self.visual_pos_dim = 4 + + self.obj_id_num = 1600 + self.attr_id_num = 400 + + self.visual_losses = self.VISUAL_LOSSES + self.visual_loss_config = { + 'obj': (self.obj_id_num, 'ce', (-1,), 1/0.15), + 'attr': (self.attr_id_num, 'ce', (-1,), 1/0.15), + 'feat': (2048, 'l2', (-1, 2048), 1/0.15), + } + + def set_visual_dims(self, feat_dim, pos_dim): + self.visual_feat_dim = feat_dim + self.visual_pos_dim = pos_dim + + +VISUAL_CONFIG = VisualConfig() + + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. 
+ """ + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02): + """Constructs BertConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + """ + if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + else: + raise ValueError("First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)") + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + +BertLayerNorm = torch.nn.LayerNorm + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, 
position and token_type embeddings. + """ + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertAttention(nn.Module): + def __init__(self, config, ctx_dim=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + # visual_dim = 2048 + if ctx_dim is None: + ctx_dim =config.hidden_size + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(ctx_dim, self.all_head_size) + self.value = nn.Linear(ctx_dim, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, context, attention_mask=None): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(context) + mixed_value_layer = self.value(context) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
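+ # attention_probs has shape [batch_size, num_heads, query_len, key_len], so dropping an
+ # entry removes one query's attention to one key/value position for that head.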
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertAttOutput(nn.Module): + def __init__(self, config): + super(BertAttOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertCrossattLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.att = BertAttention(config) + self.output = BertAttOutput(config) + + def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None): + output = self.att(input_tensor, ctx_tensor, ctx_att_mask) + attention_output = self.output(output, input_tensor) + return attention_output + + +class BertSelfattLayer(nn.Module): + def __init__(self, config): + super(BertSelfattLayer, self).__init__() + self.self = BertAttention(config) + self.output = BertAttOutput(config) + + def forward(self, input_tensor, attention_mask): + # Self attention attends to itself, thus keys and querys are the same (input_tensor). + self_output = self.self(input_tensor, input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertSelfattLayer(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +""" +--------------------------------------------------------------------------------------- + Above modules are copied from BERT (pytorch-transformer) with modifications. 
+--------------------------------------------------------------------------------------- +""" + + +class LXRTXLayer(nn.Module): + def __init__(self, config): + super().__init__() + # The cross-attention Layer + self.visual_attention = BertCrossattLayer(config) + + # Self-attention Layers + self.lang_self_att = BertSelfattLayer(config) + self.visn_self_att = BertSelfattLayer(config) + + # Intermediate and Output Layers (FFNs) + self.lang_inter = BertIntermediate(config) + self.lang_output = BertOutput(config) + self.visn_inter = BertIntermediate(config) + self.visn_output = BertOutput(config) + + def cross_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask): + # Cross Attention + lang_att_output = self.visual_attention(lang_input, visn_input, ctx_att_mask=visn_attention_mask) + visn_att_output = self.visual_attention(visn_input, lang_input, ctx_att_mask=lang_attention_mask) + return lang_att_output, visn_att_output + + def self_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask): + # Self Attention + lang_att_output = self.lang_self_att(lang_input, lang_attention_mask) + visn_att_output = self.visn_self_att(visn_input, visn_attention_mask) + return lang_att_output, visn_att_output + + def output_fc(self, lang_input, visn_input): + # FC layers + lang_inter_output = self.lang_inter(lang_input) + visn_inter_output = self.visn_inter(visn_input) + + # Layer output + lang_output = self.lang_output(lang_inter_output, lang_input) + visn_output = self.visn_output(visn_inter_output, visn_input) + return lang_output, visn_output + + def forward(self, lang_feats, lang_attention_mask, + visn_feats, visn_attention_mask): + lang_att_output = lang_feats + visn_att_output = visn_feats + + lang_att_output, visn_att_output = self.cross_att(lang_att_output, lang_attention_mask, + visn_att_output, visn_attention_mask) + lang_att_output, visn_att_output = self.self_att(lang_att_output, lang_attention_mask, + visn_att_output, visn_attention_mask) + lang_output, visn_output = self.output_fc(lang_att_output, visn_att_output) + + return lang_output, visn_output + + +class VisualFeatEncoder(nn.Module): + def __init__(self, config): + super().__init__() + feat_dim = VISUAL_CONFIG.visual_feat_dim + pos_dim = VISUAL_CONFIG.visual_pos_dim + + # Object feature encoding + self.visn_fc = nn.Linear(feat_dim, config.hidden_size) + self.visn_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12) + + # Box position encoding + self.box_fc = nn.Linear(pos_dim, config.hidden_size) + self.box_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12) + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, visn_input): + feats, boxes = visn_input + + x = self.visn_fc(feats) + x = self.visn_layer_norm(x) + y = self.box_fc(boxes) + y = self.box_layer_norm(y) + output = (x + y) / 2 + + output = self.dropout(output) + return output + + +class LXRTEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + # Obj-level image embedding layer + self.visn_fc = VisualFeatEncoder(config) + + # Number of layers + self.num_l_layers = VISUAL_CONFIG.l_layers + self.num_x_layers = VISUAL_CONFIG.x_layers + self.num_r_layers = VISUAL_CONFIG.r_layers + print("LXRT encoder with %d l_layers, %d x_layers, and %d r_layers." % + (self.num_l_layers, self.num_x_layers, self.num_r_layers)) + + # Layers + # Using self.layer instead of self.l_layer to support loading BERT weights. 
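+ # (BERT checkpoints name their text-encoder layers "encoder.layer.N", so reusing the
+ # attribute name "layer" lets from_pretrained map those weights onto the language layers.)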
+ self.layer = nn.ModuleList( + [BertLayer(config) for _ in range(self.num_l_layers)] + ) + self.x_layers = nn.ModuleList( + [LXRTXLayer(config) for _ in range(self.num_x_layers)] + ) + self.r_layers = nn.ModuleList( + [BertLayer(config) for _ in range(self.num_r_layers)] + ) + + def forward(self, lang_feats, lang_attention_mask, + visn_feats, visn_attention_mask=None): + # Run visual embedding layer + # Note: Word embedding layer was executed outside this module. + # Keep this design to allow loading BERT weights. + visn_feats = self.visn_fc(visn_feats) + + # Run language layers + for layer_module in self.layer: + lang_feats = layer_module(lang_feats, lang_attention_mask) + + # Run relational layers + for layer_module in self.r_layers: + visn_feats = layer_module(visn_feats, visn_attention_mask) + + # Run cross-modality layers + for layer_module in self.x_layers: + lang_feats, visn_feats = layer_module(lang_feats, lang_attention_mask, + visn_feats, visn_attention_mask) + + return lang_feats, visn_feats + + +class BertPooler(nn.Module): + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
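+ # bert_model_embedding_weights has shape (vocab_size, hidden_size), so the tied decoder
+ # below projects hidden_size -> vocab_size without adding a new weight matrix; only the
+ # per-vocabulary-token output bias is a new parameter.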
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1), + bert_model_embedding_weights.size(0), + bias=False) + self.decoder.weight = bert_model_embedding_weights + self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + + +class BertVisualAnswerHead(nn.Module): + def __init__(self, config, num_answers): + super().__init__() + hid_dim = config.hidden_size + self.logit_fc = nn.Sequential( + nn.Linear(hid_dim, hid_dim * 2), + GeLU(), + BertLayerNorm(hid_dim * 2, eps=1e-12), + nn.Linear(hid_dim * 2, num_answers) + ) + + def forward(self, hidden_states): + return self.logit_fc(hidden_states) + + +class BertVisualObjHead(nn.Module): + def __init__(self, config, visual_losses): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # Decide the use of visual losses + visual_losses = visual_losses.split(",") + for loss in visual_losses: + assert loss in VISUAL_CONFIG.VISUAL_LOSSES + self.visual_losses = visual_losses + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder_dict = nn.ModuleDict({ + key: nn.Linear(config.hidden_size, VISUAL_CONFIG.visual_loss_config[key][0]) + for key in self.visual_losses + }) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + output = {} + for key in self.visual_losses: + output[key] = self.decoder_dict[key](hidden_states) + return output + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super(BertPreTrainedModel, self).__init__() + if not isinstance(config, BertConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + self.config = config + + def init_bert_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, + from_tf=False, *inputs, **kwargs): + """ + Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. 
+ Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name_or_path: either: + - a str with the name of a pre-trained model to load selected in the list of: + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-large-cased` + . `bert-base-multilingual-uncased` + . `bert-base-multilingual-cased` + . `bert-base-chinese` + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `model.chkpt` a TensorFlow checkpoint + from_tf: should we load the weights from a locally saved TensorFlow checkpoint + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + archive_file = pretrained_model_name_or_path + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + except EnvironmentError: + if pretrained_model_name_or_path == 'bert-base-uncased': + try: + print("The BERT-weight-downloading query to AWS was time-out;" + "trying to download from UNC servers") + archive_file = "https://nlp.cs.unc.edu/data/bert/bert-base-uncased.tar.gz" + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + except EnvironmentError: + print("The weight-downloading still crashed with link: %s, " + "please check your network connection" % archive_file) + return None + else: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), + archive_file)) + if resolved_archive_file == archive_file: + logger.info("loading archive file {}".format(archive_file)) + else: + logger.info("loading archive file {} from cache at {}".format( + archive_file, resolved_archive_file)) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + else: + # Extract archive to temp dir + tempdir = tempfile.mkdtemp() + logger.info("extracting archive file {} to temp dir {}".format( + resolved_archive_file, tempdir)) + with tarfile.open(resolved_archive_file, 'r:gz') as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + # Load config + config_file = os.path.join(serialization_dir, CONFIG_NAME) + config = BertConfig.from_json_file(config_file) + logger.info("Model config {}".format(config)) + # Instantiate model. 
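+ # e.g. LXRTFeatureExtraction.from_pretrained("bert-base-uncased", mode='x'), as used by
+ # LXRTEncoder in entry.py, reaches this point with cls = LXRTFeatureExtraction and
+ # kwargs = {"mode": "x"}.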
+ model = cls(config, *inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None) + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + if from_tf: + # Directly load from a TensorFlow checkpoint + weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) + return load_tf_weights_in_bert(model, weights_path) + # Load from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + start_prefix = '' + if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): + start_prefix = 'bert.' + load(model, prefix=start_prefix) + # if len(missing_keys) > 0: + # logger.info("Weights of {} not initialized from pretrained model: {}".format( + # model.__class__.__name__, missing_keys)) + # if len(unexpected_keys) > 0: + # logger.info("Weights from pretrained model not used in {}: {}".format( + # model.__class__.__name__, unexpected_keys)) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) + return model + + +class LXRTModel(BertPreTrainedModel): + """LXRT Model.""" + + def __init__(self, config): + super().__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = LXRTEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, + visual_feats=None, visual_attention_mask=None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # Process the visual attention mask + if visual_attention_mask is not None: + extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2) + extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0 + else: + extended_visual_attention_mask = None + + # Positional Word Embeddings + embedding_output = self.embeddings(input_ids, token_type_ids) + + # Run LXRT backbone + lang_feats, visn_feats = self.encoder( + embedding_output, + extended_attention_mask, + visn_feats=visual_feats, + visn_attention_mask=extended_visual_attention_mask) + pooled_output = self.pooler(lang_feats) + + return (lang_feats, visn_feats), pooled_output + + +class LXRTPretraining(BertPreTrainedModel): + def __init__(self, + config, + task_mask_lm=True, + task_matched=True, + task_obj_predict=True, + visual_losses='', + task_qa=True, + num_answers=2): + super().__init__(config) + # Configuration + self.config = config + self.num_answers = num_answers + + # Use of pre-training tasks + self.task_mask_lm = task_mask_lm + self.task_obj_predict = task_obj_predict + self.task_matched = task_matched + self.task_qa = task_qa + + # LXRT backbone + self.bert = LXRTModel(config) + + # Pre-training heads + self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) + if self.task_obj_predict: + self.obj_predict_head = BertVisualObjHead(config, visual_losses) + if self.task_qa: + self.answer_head = BertVisualAnswerHead(config, self.num_answers) + + # Weight initialization + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, + visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None): + (lang_output, visn_output), pooled_output = self.bert( + input_ids, token_type_ids, attention_mask, + visual_feats=(visual_feats, pos), + ) + + lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output) + if self.task_qa: + answer_score = self.answer_head(pooled_output) + else: + # This answer_score would not be used anywhere, + # just to keep a constant return function signature. + answer_score = pooled_output[0][0] + + total_loss = 0. + loss_fct = CrossEntropyLoss(ignore_index=-1) + losses = () + if masked_lm_labels is not None and self.task_mask_lm: + masked_lm_loss = loss_fct( + lang_prediction_scores.view(-1, self.config.vocab_size), + masked_lm_labels.view(-1) + ) + total_loss += masked_lm_loss + losses += (masked_lm_loss.detach(),) + if matched_label is not None and self.task_matched: + matched_loss = loss_fct( + cross_relationship_score.view(-1, 2), + matched_label.view(-1) + ) + total_loss += matched_loss + losses += (matched_loss.detach(),) + if obj_labels is not None and self.task_obj_predict: + loss_fcts = { + 'l2': SmoothL1Loss(reduction='none'), + 'ce': CrossEntropyLoss(ignore_index=-1, reduction='none') + } + total_visn_loss = 0. 
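As an aside on the loss plumbing above: `CrossEntropyLoss(ignore_index=-1)` lets positions without a supervision signal (label -1) drop out of the masked-LM and matched losses. A toy illustration with an assumed vocabulary size of 5:

import torch
from torch.nn import CrossEntropyLoss

loss_fct = CrossEntropyLoss(ignore_index=-1)
logits = torch.randn(3, 5)          # 3 token positions, toy vocabulary of 5
labels = torch.tensor([2, -1, 4])   # the -1 position (an unmasked token) is skipped
loss = loss_fct(logits, labels)     # averaged over the 2 supervised positions only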
+ visn_prediction_scores_dict = self.obj_predict_head(visn_output) + for key in VISUAL_CONFIG.visual_losses: + label, mask_conf = obj_labels[key] + output_dim, loss_fct_name, label_shape, weight = VISUAL_CONFIG.visual_loss_config[key] + visn_loss_fct = loss_fcts[loss_fct_name] + visn_prediction_scores = visn_prediction_scores_dict[key] + visn_loss = visn_loss_fct( + visn_prediction_scores.view(-1, output_dim), + label.view(*label_shape), + ) + if visn_loss.dim() > 1: # Regression Losses + visn_loss = visn_loss.mean(1) + visn_loss = (visn_loss * mask_conf.view(-1)).mean() * weight + total_visn_loss += visn_loss + losses += (visn_loss.detach(),) + total_loss += total_visn_loss + if ans is not None and self.task_qa: + answer_loss = loss_fct( + answer_score.view(-1, self.num_answers), + ans.view(-1) + ) + # Since this Github version pre-trains with QA loss from the beginning, + # I exclude "*2" here to match the effect of QA losses. + # Previous: (loss *0) for 6 epochs, (loss *2) for 6 epochs. (Used 10 instead of 6 in EMNLP paper) + # Now : (loss *1) for 12 epochs + # + # * 2 # Multiply by 2 because > half of the data will not have label + total_loss += answer_loss + losses += (answer_loss.detach(),) + return total_loss, torch.stack(losses).unsqueeze(0), answer_score.detach() + + +class LXRTFeatureExtraction(BertPreTrainedModel): + """ + BERT model for classification. + """ + def __init__(self, config, mode='lxr'): + """ + + :param config: + :param mode: Number of visual layers + """ + super().__init__(config) + self.bert = LXRTModel(config) + self.mode = mode + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, visual_feats=None, + visual_attention_mask=None): + feat_seq, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, + visual_feats=visual_feats, + visual_attention_mask=visual_attention_mask) + if 'x' == self.mode: + return pooled_output + elif 'x' in self.mode and ('l' in self.mode or 'r' in self.mode): + return feat_seq, pooled_output + elif 'l' in self.mode or 'r' in self.mode: + return feat_seq + diff --git a/lxmert/src/lxrt/optimization.py b/lxmert/src/lxrt/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..99533ef33e497603fa9537b00da965fee91cbf8b --- /dev/null +++ b/lxmert/src/lxrt/optimization.py @@ -0,0 +1,180 @@ +# coding=utf-8 +# Copyright 2019 project LXRT +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch optimization for BERT model.""" + +import math +import torch +from torch.optim import Optimizer +from torch.optim.optimizer import required +import logging + +logger = logging.getLogger(__name__) + +def warmup_cosine(x, warmup=0.002): + if x < warmup: + return x/warmup + return 0.5 * (1.0 + torch.cos(math.pi * x)) + +def warmup_constant(x, warmup=0.002): + """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. + Learning rate is 1. 
afterwards. """ + if x < warmup: + return x/warmup + return 1.0 + +def warmup_linear(x, warmup=0.002): + """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. + After `t_total`-th training step, learning rate is zero. """ + if x < warmup: + return x/warmup + return max((x-1.)/(warmup-1.), 0) + +SCHEDULES = { + 'warmup_cosine': warmup_cosine, + 'warmup_constant': warmup_constant, + 'warmup_linear': warmup_linear, +} + + +class BertAdam(Optimizer): + """Implements BERT version of Adam algorithm with weight decay fix. + Params: + lr: learning rate + warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 + t_total: total number of training steps for the learning + rate schedule, -1 means constant learning rate. Default: -1 + schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' + b1: Adams b1. Default: 0.9 + b2: Adams b2. Default: 0.999 + e: Adams epsilon. Default: 1e-6 + weight_decay: Weight decay. Default: 0.01 + max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 + """ + def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', + b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, + max_grad_norm=1.0): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) + if schedule not in SCHEDULES: + raise ValueError("Invalid schedule parameter: {}".format(schedule)) + if not 0.0 <= warmup < 1.0 and not warmup == -1: + raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) + if not 0.0 <= b1 < 1.0: + raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) + if not 0.0 <= b2 < 1.0: + raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) + if not e >= 0.0: + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) + defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, + b1=b1, b2=b2, e=e, weight_decay=weight_decay, + max_grad_norm=max_grad_norm) + super(BertAdam, self).__init__(params, defaults) + + def get_lr(self): + lr = [] + for group in self.param_groups: + for p in group['params']: + state = self.state[p] + if len(state) == 0: + return [0] + if group['t_total'] != -1: + schedule_fct = SCHEDULES[group['schedule']] + lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) + else: + lr_scheduled = group['lr'] + lr.append(lr_scheduled) + return lr + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + warned_for_t_total = False + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['next_m'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['next_v'] = torch.zeros_like(p.data) + + next_m, next_v = state['next_m'], state['next_v'] + beta1, beta2 = group['b1'], group['b2'] + + # LXRT: grad is clipped outside. 
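As an aside, the schedule functions defined above map training progress `x` in [0, 1] to a learning-rate multiplier. A quick check, assuming `warmup=0.1` and that the functions above are in scope:

# x is the fraction of total steps completed; warmup=0.1 means 10% linear warmup.
for x in (0.0, 0.05, 0.1, 0.5, 1.0):
    print(x, warmup_linear(x, warmup=0.1), warmup_constant(x, warmup=0.1))
# warmup_linear rises 0.0 -> 1.0 during warmup, then decays linearly to 0.0 at x = 1.0;
# warmup_constant rises the same way and then stays at 1.0.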
+ # Add grad clipping + # if group['max_grad_norm'] > 0: + # clip_grad_norm_(p, group['max_grad_norm']) + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + next_m.mul_(beta1).add_(1 - beta1, grad) + next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) + update = next_m / (next_v.sqrt() + group['e']) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + if group['weight_decay'] > 0.0: + update += group['weight_decay'] * p.data + + if group['t_total'] != -1: + schedule_fct = SCHEDULES[group['schedule']] + progress = state['step']/group['t_total'] + lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) + # warning for exceeding t_total (only active with warmup_linear + if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: + logger.warning( + "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. " + "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__)) + warned_for_t_total = True + # end warning + else: + lr_scheduled = group['lr'] + + update_with_lr = lr_scheduled * update + p.data.add_(-update_with_lr) + + state['step'] += 1 + + # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 + # No bias correction + # bias_correction1 = 1 - beta1 ** state['step'] + # bias_correction2 = 1 - beta2 ** state['step'] + + return loss diff --git a/lxmert/src/lxrt/tokenization.py b/lxmert/src/lxrt/tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..9d986e597e76d493d6139abb4ae8d7ca2986460f --- /dev/null +++ b/lxmert/src/lxrt/tokenization.py @@ -0,0 +1,388 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
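Summing up the optimizer whose `step()` ends above: `BertAdam` uses no bias correction and applies decoupled weight decay (added to the update rather than folded into the loss as L2). A minimal per-tensor sketch, with made-up hyperparameters and the learning-rate schedule omitted:

import torch

def bert_adam_update(p, grad, m, v, lr=1e-4, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.01):
    # Exponential moving averages of the gradient and its square, updated in place.
    m.mul_(b1).add_(grad, alpha=1 - b1)
    v.mul_(b2).addcmul_(grad, grad, value=1 - b2)
    update = m / (v.sqrt() + eps)
    # Decoupled weight decay: applied to the update, not to the loss.
    if weight_decay > 0.0:
        update = update + weight_decay * p
    p.add_(update, alpha=-lr)
    return p

p = torch.randn(4); g = torch.randn(4)
m = torch.zeros_like(p); v = torch.zeros_like(p)
bert_adam_update(p, g, m, v)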
+"""Tokenization classes.""" + +import collections +import logging +import os +import unicodedata +from io import open + +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'bert-base-uncased': 512, + 'bert-large-uncased': 512, + 'bert-base-cased': 512, + 'bert-large-cased': 512, + 'bert-base-multilingual-uncased': 512, + 'bert-base-multilingual-cased': 512, + 'bert-base-chinese': 512, +} +VOCAB_NAME = 'vocab.txt' + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding="utf-8") as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BertTokenizer. + + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input + Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + max_len: An artificial maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of this + value (if specified) and the underlying BERT model's + sequence length. + never_split: List of tokens which will never be split during tokenization. + Only has an effect when do_wordpiece_only=False + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def tokenize(self, text): + if self.do_basic_tokenize: + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this BERT model ({} > {}). Running this" + " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + vocab_file = pretrained_model_name_or_path + if os.path.isdir(vocab_file): + vocab_file = os.path.join(vocab_file, VOCAB_NAME) + # redirect to the cache, if necessary + try: + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + vocab_file)) + return None + if resolved_vocab_file == vocab_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) + return tokenizer + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. 
+ """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. 
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. 
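Likewise, a toy run of the greedy longest-match-first `WordpieceTokenizer` above, with a hand-built vocabulary standing in for the real `vocab.txt`:

toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
wp = WordpieceTokenizer(vocab=toy_vocab)
print(wp.tokenize("unaffable"))    # -> ['un', '##aff', '##able']
print(wp.tokenize("unknownword"))  # -> ['[UNK]']  (no full segmentation exists in the toy vocab)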
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/lxmert/src/modeling_frcnn.py b/lxmert/src/modeling_frcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..332350cca0f3b10017cf59ee9a8e2449918a9e47 --- /dev/null +++ b/lxmert/src/modeling_frcnn.py @@ -0,0 +1,1922 @@ +""" + coding=utf-8 + Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal + Adapted From Facebook Inc, Detectron2 && Huggingface Co. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.import copy + """ +import itertools +import math +import os +from abc import ABCMeta, abstractmethod +from collections import OrderedDict, namedtuple +from typing import Dict, List, Tuple + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.batchnorm import BatchNorm2d +from torchvision.ops import RoIPool +from torchvision.ops.boxes import batched_nms, nms + +from lxmert.lxmert.src.vqa_utils import WEIGHTS_NAME, Config, cached_path, hf_bucket_url, is_remote_url, load_checkpoint + + +# other: +def norm_box(boxes, raw_sizes): + if not isinstance(boxes, torch.Tensor): + normalized_boxes = boxes.copy() + else: + normalized_boxes = boxes.clone() + normalized_boxes[:, :, (0, 2)] /= raw_sizes[:, 1] + normalized_boxes[:, :, (1, 3)] /= raw_sizes[:, 0] + return normalized_boxes + + +def pad_list_tensors( + list_tensors, + preds_per_image, + max_detections=None, + return_tensors=None, + padding=None, + pad_value=0, + location=None, +): + """ + location will always be cpu for np tensors + """ + if location is None: + location = "cpu" + assert return_tensors in {"pt", "np", None} + assert padding in {"max_detections", "max_batch", None} + new = [] + if padding is None: + if return_tensors is None: + return list_tensors + elif return_tensors == "pt": + if not isinstance(list_tensors, torch.Tensor): + return torch.stack(list_tensors).to(location) + else: + return list_tensors.to(location) + else: + if not isinstance(list_tensors, list): + return np.array(list_tensors.to(location)) + else: + return list_tensors.to(location) + if padding == "max_detections": + assert max_detections is not None, "specify max number of detections per batch" + elif padding == "max_batch": + max_detections = max(preds_per_image) + for i in range(len(list_tensors)): + too_small = False + tensor_i = list_tensors.pop(0) + if tensor_i.ndim < 2: + too_small = True + tensor_i = tensor_i.unsqueeze(-1) + assert isinstance(tensor_i, torch.Tensor) + tensor_i = F.pad( + input=tensor_i, + pad=(0, 0, 0, max_detections - preds_per_image[i]), + mode="constant", + value=pad_value, + ) + if too_small: + tensor_i = tensor_i.squeeze(-1) + if return_tensors is None: + if location == "cpu": + tensor_i = tensor_i.cpu() + tensor_i = tensor_i.tolist() + if return_tensors == "np": + if location == "cpu": + tensor_i = tensor_i.cpu() + tensor_i = tensor_i.numpy() 
+ else: + if location == "cpu": + tensor_i = tensor_i.cpu() + new.append(tensor_i) + if return_tensors == "np": + return np.stack(new, axis=0) + elif return_tensors == "pt" and not isinstance(new, torch.Tensor): + return torch.stack(new, dim=0) + else: + return list_tensors + + +def do_nms(boxes, scores, image_shape, score_thresh, nms_thresh, mind, maxd): + scores = scores[:, :-1] + num_bbox_reg_classes = boxes.shape[1] // 4 + # Convert to Boxes to use the `clip` function ... + boxes = boxes.reshape(-1, 4) + _clip_box(boxes, image_shape) + boxes = boxes.view(-1, num_bbox_reg_classes, 4) # R x C x 4 + + # Select max scores + max_scores, max_classes = scores.max(1) # R x C --> R + num_objs = boxes.size(0) + boxes = boxes.view(-1, 4) + idxs = torch.arange(num_objs).to(boxes.device) * num_bbox_reg_classes + max_classes + max_boxes = boxes[idxs] # Select max boxes according to the max scores. + + # Apply NMS + keep = nms(max_boxes, max_scores, nms_thresh) + keep = keep[:maxd] + if keep.shape[-1] >= mind and keep.shape[-1] <= maxd: + max_boxes, max_scores = max_boxes[keep], max_scores[keep] + classes = max_classes[keep] + return max_boxes, max_scores, classes, keep + else: + return None + + +# Helper Functions +def _clip_box(tensor, box_size: Tuple[int, int]): + assert torch.isfinite(tensor).all(), "Box tensor contains infinite or NaN!" + h, w = box_size + tensor[:, 0].clamp_(min=0, max=w) + tensor[:, 1].clamp_(min=0, max=h) + tensor[:, 2].clamp_(min=0, max=w) + tensor[:, 3].clamp_(min=0, max=h) + + +def _nonempty_boxes(box, threshold: float = 0.0) -> torch.Tensor: + widths = box[:, 2] - box[:, 0] + heights = box[:, 3] - box[:, 1] + keep = (widths > threshold) & (heights > threshold) + return keep + + +def get_norm(norm, out_channels): + if isinstance(norm, str): + if len(norm) == 0: + return None + norm = { + "BN": BatchNorm2d, + "GN": lambda channels: nn.GroupNorm(32, channels), + "nnSyncBN": nn.SyncBatchNorm, # keep for debugging + "": lambda x: x, + }[norm] + return norm(out_channels) + + +def _create_grid_offsets(size: List[int], stride: int, offset: float, device): + + grid_height, grid_width = size + shifts_x = torch.arange( + offset * stride, + grid_width * stride, + step=stride, + dtype=torch.float32, + device=device, + ) + shifts_y = torch.arange( + offset * stride, + grid_height * stride, + step=stride, + dtype=torch.float32, + device=device, + ) + + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + return shift_x, shift_y + + +def build_backbone(cfg): + input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN)) + norm = cfg.RESNETS.NORM + stem = BasicStem( + in_channels=input_shape.channels, + out_channels=cfg.RESNETS.STEM_OUT_CHANNELS, + norm=norm, + caffe_maxpool=cfg.MODEL.MAX_POOL, + ) + freeze_at = cfg.BACKBONE.FREEZE_AT + + if freeze_at >= 1: + for p in stem.parameters(): + p.requires_grad = False + + out_features = cfg.RESNETS.OUT_FEATURES + depth = cfg.RESNETS.DEPTH + num_groups = cfg.RESNETS.NUM_GROUPS + width_per_group = cfg.RESNETS.WIDTH_PER_GROUP + bottleneck_channels = num_groups * width_per_group + in_channels = cfg.RESNETS.STEM_OUT_CHANNELS + out_channels = cfg.RESNETS.RES2_OUT_CHANNELS + stride_in_1x1 = cfg.RESNETS.STRIDE_IN_1X1 + res5_dilation = cfg.RESNETS.RES5_DILATION + assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation) + + num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth] + + stages = [] + out_stage_idx = [{"res2": 2, 
"res3": 3, "res4": 4, "res5": 5}[f] for f in out_features] + max_stage_idx = max(out_stage_idx) + for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)): + dilation = res5_dilation if stage_idx == 5 else 1 + first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2 + stage_kargs = { + "num_blocks": num_blocks_per_stage[idx], + "first_stride": first_stride, + "in_channels": in_channels, + "bottleneck_channels": bottleneck_channels, + "out_channels": out_channels, + "num_groups": num_groups, + "norm": norm, + "stride_in_1x1": stride_in_1x1, + "dilation": dilation, + } + + stage_kargs["block_class"] = BottleneckBlock + blocks = ResNet.make_stage(**stage_kargs) + in_channels = out_channels + out_channels *= 2 + bottleneck_channels *= 2 + + if freeze_at >= stage_idx: + for block in blocks: + block.freeze() + stages.append(blocks) + + return ResNet(stem, stages, out_features=out_features) + + +def find_top_rpn_proposals( + proposals, + pred_objectness_logits, + images, + image_sizes, + nms_thresh, + pre_nms_topk, + post_nms_topk, + min_box_side_len, + training, +): + """Args: + proposals (list[Tensor]): (L, N, Hi*Wi*A, 4). + pred_objectness_logits: tensors of length L. + nms_thresh (float): IoU threshold to use for NMS + pre_nms_topk (int): before nms + post_nms_topk (int): after nms + min_box_side_len (float): minimum proposal box side + training (bool): True if proposals are to be used in training, + Returns: + results (List[Dict]): stores post_nms_topk object proposals for image i. + """ + num_images = len(images) + device = proposals[0].device + + # 1. Select top-k anchor for every level and every image + topk_scores = [] # #lvl Tensor, each of shape N x topk + topk_proposals = [] + level_ids = [] # #lvl Tensor, each of shape (topk,) + batch_idx = torch.arange(num_images, device=device) + for level_id, proposals_i, logits_i in zip(itertools.count(), proposals, pred_objectness_logits): + Hi_Wi_A = logits_i.shape[1] + num_proposals_i = min(pre_nms_topk, Hi_Wi_A) + + # sort is faster than topk (https://github.com/pytorch/pytorch/issues/22812) + # topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1) + logits_i, idx = logits_i.sort(descending=True, dim=1) + topk_scores_i = logits_i[batch_idx, :num_proposals_i] + topk_idx = idx[batch_idx, :num_proposals_i] + + # each is N x topk + topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx] # N x topk x 4 + + topk_proposals.append(topk_proposals_i) + topk_scores.append(topk_scores_i) + level_ids.append(torch.full((num_proposals_i,), level_id, dtype=torch.int64, device=device)) + + # 2. Concat all levels together + topk_scores = torch.cat(topk_scores, dim=1) + topk_proposals = torch.cat(topk_proposals, dim=1) + level_ids = torch.cat(level_ids, dim=0) + + # if I change to batched_nms, I wonder if this will make a difference + # 3. For each image, run a per-level NMS, and choose topk results. 
+ results = [] + for n, image_size in enumerate(image_sizes): + boxes = topk_proposals[n] + scores_per_img = topk_scores[n] + # I will have to take a look at the boxes clip method + _clip_box(boxes, image_size) + # filter empty boxes + keep = _nonempty_boxes(boxes, threshold=min_box_side_len) + lvl = level_ids + if keep.sum().item() != len(boxes): + boxes, scores_per_img, lvl = ( + boxes[keep], + scores_per_img[keep], + level_ids[keep], + ) + + keep = batched_nms(boxes, scores_per_img, lvl, nms_thresh) + keep = keep[:post_nms_topk] + + res = (boxes[keep], scores_per_img[keep]) + results.append(res) + + # I wonder if it would be possible for me to pad all these things. + return results + + +def subsample_labels(labels, num_samples, positive_fraction, bg_label): + """ + Returns: + pos_idx, neg_idx (Tensor): + 1D vector of indices. The total length of both is `num_samples` or fewer. + """ + positive = torch.nonzero((labels != -1) & (labels != bg_label)).squeeze(1) + negative = torch.nonzero(labels == bg_label).squeeze(1) + + num_pos = int(num_samples * positive_fraction) + # protect against not enough positive examples + num_pos = min(positive.numel(), num_pos) + num_neg = num_samples - num_pos + # protect against not enough negative examples + num_neg = min(negative.numel(), num_neg) + + # randomly select positive and negative examples + perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos] + perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg] + + pos_idx = positive[perm1] + neg_idx = negative[perm2] + return pos_idx, neg_idx + + +def add_ground_truth_to_proposals(gt_boxes, proposals): + raise NotImplementedError() + + +def add_ground_truth_to_proposals_single_image(gt_boxes, proposals): + raise NotImplementedError() + + +def _fmt_box_list(box_tensor, batch_index: int): + repeated_index = torch.full( + (len(box_tensor), 1), + batch_index, + dtype=box_tensor.dtype, + device=box_tensor.device, + ) + return torch.cat((repeated_index, box_tensor), dim=1) + + +def convert_boxes_to_pooler_format(box_lists: List[torch.Tensor]): + pooler_fmt_boxes = torch.cat( + [_fmt_box_list(box_list, i) for i, box_list in enumerate(box_lists)], + dim=0, + ) + return pooler_fmt_boxes + + +def assign_boxes_to_levels( + box_lists: List[torch.Tensor], + min_level: int, + max_level: int, + canonical_box_size: int, + canonical_level: int, +): + + box_sizes = torch.sqrt(torch.cat([boxes.area() for boxes in box_lists])) + # Eqn.(1) in FPN paper + level_assignments = torch.floor(canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8)) + # clamp level to (min, max), in case the box size is too large or too small + # for the available feature maps + level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level) + return level_assignments.to(torch.int64) - min_level + + +# Helper Classes +class _NewEmptyTensorOp(torch.autograd.Function): + @staticmethod + def forward(ctx, x, new_shape): + ctx.shape = x.shape + return x.new_empty(new_shape) + + @staticmethod + def backward(ctx, grad): + shape = ctx.shape + return _NewEmptyTensorOp.apply(grad, shape), None + + +class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])): + def __new__(cls, *, channels=None, height=None, width=None, stride=None): + return super().__new__(cls, channels, height, width, stride) + + +class Box2BoxTransform(object): + """ + This R-CNN transformation scales the box's width and height + by exp(dw), exp(dh) and shifts a box's center by the offset + (dx * 
width, dy * height). + """ + + def __init__(self, weights: Tuple[float, float, float, float], scale_clamp: float = None): + """ + Args: + weights (4-element tuple): Scaling factors that are applied to the + (dx, dy, dw, dh) deltas. In Fast R-CNN, these were originally set + such that the deltas have unit variance; now they are treated as + hyperparameters of the system. + scale_clamp (float): When predicting deltas, the predicted box scaling + factors (dw and dh) are clamped such that they are <= scale_clamp. + """ + self.weights = weights + if scale_clamp is not None: + self.scale_clamp = scale_clamp + else: + """ + Value for clamping large dw and dh predictions. + The heuristic is that we clamp such that dw and dh are no larger + than what would transform a 16px box into a 1000px box + (based on a small anchor, 16px, and a typical image size, 1000px). + """ + self.scale_clamp = math.log(1000.0 / 16) + + def get_deltas(self, src_boxes, target_boxes): + """ + Get box regression transformation deltas (dx, dy, dw, dh) that can be used + to transform the `src_boxes` into the `target_boxes`. That is, the relation + ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless + any delta is too large and is clamped). + Args: + src_boxes (Tensor): source boxes, e.g., object proposals + target_boxes (Tensor): target of the transformation, e.g., ground-truth + boxes. + """ + assert isinstance(src_boxes, torch.Tensor), type(src_boxes) + assert isinstance(target_boxes, torch.Tensor), type(target_boxes) + + src_widths = src_boxes[:, 2] - src_boxes[:, 0] + src_heights = src_boxes[:, 3] - src_boxes[:, 1] + src_ctr_x = src_boxes[:, 0] + 0.5 * src_widths + src_ctr_y = src_boxes[:, 1] + 0.5 * src_heights + + target_widths = target_boxes[:, 2] - target_boxes[:, 0] + target_heights = target_boxes[:, 3] - target_boxes[:, 1] + target_ctr_x = target_boxes[:, 0] + 0.5 * target_widths + target_ctr_y = target_boxes[:, 1] + 0.5 * target_heights + + wx, wy, ww, wh = self.weights + dx = wx * (target_ctr_x - src_ctr_x) / src_widths + dy = wy * (target_ctr_y - src_ctr_y) / src_heights + dw = ww * torch.log(target_widths / src_widths) + dh = wh * torch.log(target_heights / src_heights) + + deltas = torch.stack((dx, dy, dw, dh), dim=1) + assert (src_widths > 0).all().item(), "Input boxes to Box2BoxTransform are not valid!" + return deltas + + def apply_deltas(self, deltas, boxes): + """ + Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`. + Args: + deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1. + deltas[i] represents k potentially different class-specific + box transformations for the single box boxes[i]. 
+ boxes (Tensor): boxes to transform, of shape (N, 4) + """ + boxes = boxes.to(deltas.dtype) + + widths = boxes[:, 2] - boxes[:, 0] + heights = boxes[:, 3] - boxes[:, 1] + ctr_x = boxes[:, 0] + 0.5 * widths + ctr_y = boxes[:, 1] + 0.5 * heights + + wx, wy, ww, wh = self.weights + dx = deltas[:, 0::4] / wx + dy = deltas[:, 1::4] / wy + dw = deltas[:, 2::4] / ww + dh = deltas[:, 3::4] / wh + + # Prevent sending too large values into torch.exp() + dw = torch.clamp(dw, max=self.scale_clamp) + dh = torch.clamp(dh, max=self.scale_clamp) + + pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] + pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] + pred_w = torch.exp(dw) * widths[:, None] + pred_h = torch.exp(dh) * heights[:, None] + + pred_boxes = torch.zeros_like(deltas) + pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w # x1 + pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h # y1 + pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w # x2 + pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h # y2 + return pred_boxes + + +class Matcher(object): + """ + This class assigns to each predicted "element" (e.g., a box) a ground-truth + element. Each predicted element will have exactly zero or one matches; each + ground-truth element may be matched to zero or more predicted elements. + The matching is determined by the MxN match_quality_matrix, that characterizes + how well each (ground-truth, prediction)-pair match each other. For example, + if the elements are boxes, this matrix may contain box intersection-over-union + overlap values. + The matcher returns (a) a vector of length N containing the index of the + ground-truth element m in [0, M) that matches to prediction n in [0, N). + (b) a vector of length N containing the labels for each prediction. + """ + + def __init__( + self, + thresholds: List[float], + labels: List[int], + allow_low_quality_matches: bool = False, + ): + """ + Args: + thresholds (list): a list of thresholds used to stratify predictions + into levels. + labels (list): a list of values to label predictions belonging at + each level. A label can be one of {-1, 0, 1} signifying + {ignore, negative class, positive class}, respectively. + allow_low_quality_matches (bool): if True, produce additional matches or predictions with maximum match quality lower than high_threshold. + For example, thresholds = [0.3, 0.5] labels = [0, -1, 1] All predictions with iou < 0.3 will be marked with 0 and + thus will be considered as false positives while training. All predictions with 0.3 <= iou < 0.5 will be marked with -1 and + thus will be ignored. All predictions with 0.5 <= iou will be marked with 1 and thus will be considered as true positives. + """ + thresholds = thresholds[:] + assert thresholds[0] > 0 + thresholds.insert(0, -float("inf")) + thresholds.append(float("inf")) + assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])]) + assert all([label_i in [-1, 0, 1] for label_i in labels]) + assert len(labels) == len(thresholds) - 1 + self.thresholds = thresholds + self.labels = labels + self.allow_low_quality_matches = allow_low_quality_matches + + def __call__(self, match_quality_matrix): + """ + Args: + match_quality_matrix (Tensor[float]): an MxN tensor, containing the pairwise quality between M ground-truth elements and N predicted + elements. All elements must be >= 0 (due to the us of `torch.nonzero` for selecting indices in :meth:`set_low_quality_matches_`). 
+ Returns: + matches (Tensor[int64]): a vector of length N, where matches[i] is a matched ground-truth index in [0, M) + match_labels (Tensor[int8]): a vector of length N, where pred_labels[i] indicates true or false positive or ignored + """ + assert match_quality_matrix.dim() == 2 + if match_quality_matrix.numel() == 0: + default_matches = match_quality_matrix.new_full((match_quality_matrix.size(1),), 0, dtype=torch.int64) + # When no gt boxes exist, we define IOU = 0 and therefore set labels + # to `self.labels[0]`, which usually defaults to background class 0 + # To choose to ignore instead, + # can make labels=[-1,0,-1,1] + set appropriate thresholds + default_match_labels = match_quality_matrix.new_full( + (match_quality_matrix.size(1),), self.labels[0], dtype=torch.int8 + ) + return default_matches, default_match_labels + + assert torch.all(match_quality_matrix >= 0) + + # match_quality_matrix is M (gt) x N (predicted) + # Max over gt elements (dim 0) to find best gt candidate for each prediction + matched_vals, matches = match_quality_matrix.max(dim=0) + + match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8) + + for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]): + low_high = (matched_vals >= low) & (matched_vals < high) + match_labels[low_high] = l + + if self.allow_low_quality_matches: + self.set_low_quality_matches_(match_labels, match_quality_matrix) + + return matches, match_labels + + def set_low_quality_matches_(self, match_labels, match_quality_matrix): + """ + Produce additional matches for predictions that have only low-quality matches. + Specifically, for each ground-truth G find the set of predictions that have + maximum overlap with it (including ties); for each prediction in that set, if + it is unmatched, then match it to the ground-truth G. + This function implements the RPN assignment case (i) + in Sec. 3.1.2 of Faster R-CNN. + """ + # For each gt, find the prediction with which it has highest quality + highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) + # Find the highest quality match available, even if it is low, including ties. + # Note that the matches qualities must be positive due to the use of + # `torch.nonzero`. + of_quality_inds = match_quality_matrix == highest_quality_foreach_gt[:, None] + if of_quality_inds.dim() == 0: + (_, pred_inds_with_highest_quality) = of_quality_inds.unsqueeze(0).nonzero().unbind(1) + else: + (_, pred_inds_with_highest_quality) = of_quality_inds.nonzero().unbind(1) + match_labels[pred_inds_with_highest_quality] = 1 + + +class RPNOutputs(object): + def __init__( + self, + box2box_transform, + anchor_matcher, + batch_size_per_image, + positive_fraction, + images, + pred_objectness_logits, + pred_anchor_deltas, + anchors, + boundary_threshold=0, + gt_boxes=None, + smooth_l1_beta=0.0, + ): + """ + Args: + box2box_transform (Box2BoxTransform): :class:`Box2BoxTransform` instance for anchor-proposal transformations. + anchor_matcher (Matcher): :class:`Matcher` instance for matching anchors to ground-truth boxes; used to determine training labels. + batch_size_per_image (int): number of proposals to sample when training + positive_fraction (float): target fraction of sampled proposals that should be positive + images (ImageList): :class:`ImageList` instance representing N input images + pred_objectness_logits (list[Tensor]): A list of L elements. Element i is a tensor of shape (N, A, Hi, W) + pred_anchor_deltas (list[Tensor]): A list of L elements. 
Element i is a tensor of shape (N, A*4, Hi, Wi) + anchors (list[torch.Tensor]): nested list of boxes. anchors[i][j] at (n, l) stores anchor array for feature map l + boundary_threshold (int): if >= 0, then anchors that extend beyond the image boundary by more than boundary_thresh are not used in training. + gt_boxes (list[Boxes], optional): A list of N elements. + smooth_l1_beta (float): The transition point between L1 and L2 lossn. When set to 0, the loss becomes L1. When +inf, it is ignored + """ + self.box2box_transform = box2box_transform + self.anchor_matcher = anchor_matcher + self.batch_size_per_image = batch_size_per_image + self.positive_fraction = positive_fraction + self.pred_objectness_logits = pred_objectness_logits + self.pred_anchor_deltas = pred_anchor_deltas + + self.anchors = anchors + self.gt_boxes = gt_boxes + self.num_feature_maps = len(pred_objectness_logits) + self.num_images = len(images) + self.boundary_threshold = boundary_threshold + self.smooth_l1_beta = smooth_l1_beta + + def _get_ground_truth(self): + raise NotImplementedError() + + def predict_proposals(self): + # pred_anchor_deltas: (L, N, ? Hi, Wi) + # anchors:(N, L, -1, B) + # here we loop over specific feature map, NOT images + proposals = [] + anchors = self.anchors.transpose(0, 1) + for anchors_i, pred_anchor_deltas_i in zip(anchors, self.pred_anchor_deltas): + B = anchors_i.size(-1) + N, _, Hi, Wi = pred_anchor_deltas_i.shape + anchors_i = anchors_i.flatten(start_dim=0, end_dim=1) + pred_anchor_deltas_i = pred_anchor_deltas_i.view(N, -1, B, Hi, Wi).permute(0, 3, 4, 1, 2).reshape(-1, B) + proposals_i = self.box2box_transform.apply_deltas(pred_anchor_deltas_i, anchors_i) + # Append feature map proposals with shape (N, Hi*Wi*A, B) + proposals.append(proposals_i.view(N, -1, B)) + proposals = torch.stack(proposals) + return proposals + + def predict_objectness_logits(self): + """ + Returns: + pred_objectness_logits (list[Tensor]) -> (N, Hi*Wi*A). + """ + pred_objectness_logits = [ + # Reshape: (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A) + score.permute(0, 2, 3, 1).reshape(self.num_images, -1) + for score in self.pred_objectness_logits + ] + return pred_objectness_logits + + +# Main Classes +class Conv2d(torch.nn.Conv2d): + def __init__(self, *args, **kwargs): + norm = kwargs.pop("norm", None) + activation = kwargs.pop("activation", None) + super().__init__(*args, **kwargs) + + self.norm = norm + self.activation = activation + + def forward(self, x): + if x.numel() == 0 and self.training: + assert not isinstance(self.norm, torch.nn.SyncBatchNorm) + if x.numel() == 0: + assert not isinstance(self.norm, torch.nn.GroupNorm) + output_shape = [ + (i + 2 * p - (di * (k - 1) + 1)) // s + 1 + for i, p, di, k, s in zip( + x.shape[-2:], + self.padding, + self.dilation, + self.kernel_size, + self.stride, + ) + ] + output_shape = [x.shape[0], self.weight.shape[0]] + output_shape + empty = _NewEmptyTensorOp.apply(x, output_shape) + if self.training: + _dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0 + return empty + _dummy + else: + return empty + + x = super().forward(x) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + + +class LastLevelMaxPool(nn.Module): + """ + This module is used in the original FPN to generate a downsampled P6 feature from P5. 
+ """ + + def __init__(self): + super().__init__() + self.num_levels = 1 + self.in_feature = "p5" + + def forward(self, x): + return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)] + + +class LastLevelP6P7(nn.Module): + """ + This module is used in RetinaNet to generate extra layers, P6 and P7 from C5 feature. + """ + + def __init__(self, in_channels, out_channels): + super().__init__() + self.num_levels = 2 + self.in_feature = "res5" + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + + def forward(self, c5): + p6 = self.p6(c5) + p7 = self.p7(F.relu(p6)) + return [p6, p7] + + +class BasicStem(nn.Module): + def __init__(self, in_channels=3, out_channels=64, norm="BN", caffe_maxpool=False): + super().__init__() + self.conv1 = Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False, + norm=get_norm(norm, out_channels), + ) + self.caffe_maxpool = caffe_maxpool + # use pad 1 instead of pad zero + + def forward(self, x): + x = self.conv1(x) + x = F.relu_(x) + if self.caffe_maxpool: + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True) + else: + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + return x + + @property + def out_channels(self): + return self.conv1.out_channels + + @property + def stride(self): + return 4 # = stride 2 conv -> stride 2 max pool + + +class ResNetBlockBase(nn.Module): + def __init__(self, in_channels, out_channels, stride): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + return self + + +class BottleneckBlock(ResNetBlockBase): + def __init__( + self, + in_channels, + out_channels, + bottleneck_channels, + stride=1, + num_groups=1, + norm="BN", + stride_in_1x1=False, + dilation=1, + ): + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels), + ) + + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + norm=get_norm(norm, bottleneck_channels), + ) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + out = self.conv2(out) + out = F.relu_(out) + + out = self.conv3(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class Backbone(nn.Module, metaclass=ABCMeta): + def __init__(self): + super().__init__() + + @abstractmethod + def forward(self): + pass + + @property + def size_divisibility(self): + """ + Some backbones require the input height and width to be divisible by a specific integer. 
This is + typically true for encoder / decoder type networks with lateral connection (e.g., FPN) for which feature maps need to match + dimension in the "bottom up" and "top down" paths. Set to 0 if no specific input size divisibility is required. + """ + return 0 + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], + stride=self._out_feature_strides[name], + ) + for name in self._out_features + } + + @property + def out_features(self): + """deprecated""" + return self._out_features + + @property + def out_feature_strides(self): + """deprecated""" + return {f: self._out_feature_strides[f] for f in self._out_features} + + @property + def out_feature_channels(self): + """deprecated""" + return {f: self._out_feature_channels[f] for f in self._out_features} + + +class ResNet(Backbone): + def __init__(self, stem, stages, num_classes=None, out_features=None): + """ + Args: + stem (nn.Module): a stem module + stages (list[list[ResNetBlock]]): several (typically 4) stages, each contains multiple :class:`ResNetBlockBase`. + num_classes (None or int): if None, will not perform classification. + out_features (list[str]): name of the layers whose outputs should be returned in forward. Can be anything in: + "stem", "linear", or "res2" ... If None, will return the output of the last layer. + """ + super(ResNet, self).__init__() + self.stem = stem + self.num_classes = num_classes + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + + self.stages_and_names = [] + for i, blocks in enumerate(stages): + for block in blocks: + assert isinstance(block, ResNetBlockBase), block + curr_channels = block.out_channels + stage = nn.Sequential(*blocks) + name = "res" + str(i + 2) + self.add_module(name, stage) + self.stages_and_names.append((stage, name)) + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in blocks]) + ) + self._out_feature_channels[name] = blocks[-1].out_channels + + if num_classes is not None: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(curr_channels, num_classes) + + # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "The 1000-way fully-connected layer is initialized by + # drawing weights from a zero-mean Gaussian with std of 0.01." 
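Stepping back from the classifier head for a moment: the stride/channel bookkeeping in `ResNet.__init__` above reproduces the familiar per-stage output shapes. A quick sanity check for a standard ResNet-50-style configuration (the concrete numbers are assumptions, not read from `cfg`):

# Standard ResNet-50-style settings: stem stride 4, one stride-2 block per later stage,
# output channels doubling from 256 at res2.
current_stride, channels = 4, 256
strides, widths = {"stem": 4}, {"stem": 64}
for name, first_stride in zip(["res2", "res3", "res4", "res5"], [1, 2, 2, 2]):
    current_stride *= first_stride
    strides[name], widths[name] = current_stride, channels
    channels *= 2
print(strides)  # {'stem': 4, 'res2': 4, 'res3': 8, 'res4': 16, 'res5': 32}
print(widths)   # {'stem': 64, 'res2': 256, 'res3': 512, 'res4': 1024, 'res5': 2048}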
+            nn.init.normal_(self.linear.weight, std=0.01)
+            name = "linear"
+
+        if out_features is None:
+            out_features = [name]
+        self._out_features = out_features
+        assert len(self._out_features)
+        children = [x[0] for x in self.named_children()]
+        for out_feature in self._out_features:
+            assert out_feature in children, "Available children: {}".format(", ".join(children))
+
+    def forward(self, x):
+        outputs = {}
+        x = self.stem(x)
+        if "stem" in self._out_features:
+            outputs["stem"] = x
+        for stage, name in self.stages_and_names:
+            x = stage(x)
+            if name in self._out_features:
+                outputs[name] = x
+        if self.num_classes is not None:
+            x = self.avgpool(x)
+            x = self.linear(x)
+            if "linear" in self._out_features:
+                outputs["linear"] = x
+        return outputs
+
+    def output_shape(self):
+        return {
+            name: ShapeSpec(
+                channels=self._out_feature_channels[name],
+                stride=self._out_feature_strides[name],
+            )
+            for name in self._out_features
+        }
+
+    @staticmethod
+    def make_stage(
+        block_class,
+        num_blocks,
+        first_stride=None,
+        *,
+        in_channels,
+        out_channels,
+        **kwargs,
+    ):
+        """
+        Usually, layers that produce the same feature map spatial size
+        are defined as one "stage".
+        Under such definition, stride_per_block[1:] should all be 1.
+        """
+        if first_stride is not None:
+            assert "stride" not in kwargs and "stride_per_block" not in kwargs
+            kwargs["stride_per_block"] = [first_stride] + [1] * (num_blocks - 1)
+        blocks = []
+        for i in range(num_blocks):
+            curr_kwargs = {}
+            for k, v in kwargs.items():
+                if k.endswith("_per_block"):
+                    assert len(v) == num_blocks, (
+                        f"Argument '{k}' of make_stage should have the " f"same length as num_blocks={num_blocks}."
+                    )
+                    newk = k[: -len("_per_block")]
+                    assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
+                    curr_kwargs[newk] = v[i]
+                else:
+                    curr_kwargs[k] = v
+
+            blocks.append(block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs))
+            in_channels = out_channels
+
+        return blocks
+
+
+class ROIPooler(nn.Module):
+    """
+    Region of interest feature map pooler that supports pooling from one or more
+    feature maps.
+    """
+
+    def __init__(
+        self,
+        output_size,
+        scales,
+        sampling_ratio,
+        canonical_box_size=224,
+        canonical_level=4,
+    ):
+        super().__init__()
+        # assumption that stride is a power of 2.
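+        # Illustrative example (not values taken from this config): scales = (1/4, 1/8, 1/16, 1/32)
+        # would correspond to FPN levels p2..p5, i.e. min_level = 2 and max_level = 5 below;
+        # the Res5ROIHeads in this file passes a single scale, so only one level pooler is built.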
+ min_level = -math.log2(scales[0]) + max_level = -math.log2(scales[-1]) + + # a bunch of testing + assert math.isclose(min_level, int(min_level)) and math.isclose(max_level, int(max_level)) + assert len(scales) == max_level - min_level + 1, "not pyramid" + assert 0 < min_level and min_level <= max_level + if isinstance(output_size, int): + output_size = (output_size, output_size) + assert len(output_size) == 2 and isinstance(output_size[0], int) and isinstance(output_size[1], int) + if len(scales) > 1: + assert min_level <= canonical_level and canonical_level <= max_level + assert canonical_box_size > 0 + + self.output_size = output_size + self.min_level = int(min_level) + self.max_level = int(max_level) + self.level_poolers = nn.ModuleList(RoIPool(output_size, spatial_scale=scale) for scale in scales) + self.canonical_level = canonical_level + self.canonical_box_size = canonical_box_size + + def forward(self, feature_maps, boxes): + """ + Args: + feature_maps: List[torch.Tensor(N,C,W,H)] + box_lists: list[torch.Tensor]) + Returns: + A tensor of shape(N*B, Channels, output_size, output_size) + """ + x = [v for v in feature_maps.values()] + num_level_assignments = len(self.level_poolers) + assert len(x) == num_level_assignments and len(boxes) == x[0].size(0) + + pooler_fmt_boxes = convert_boxes_to_pooler_format(boxes) + + if num_level_assignments == 1: + return self.level_poolers[0](x[0], pooler_fmt_boxes) + + level_assignments = assign_boxes_to_levels( + boxes, + self.min_level, + self.max_level, + self.canonical_box_size, + self.canonical_level, + ) + + num_boxes = len(pooler_fmt_boxes) + num_channels = x[0].shape[1] + output_size = self.output_size[0] + + dtype, device = x[0].dtype, x[0].device + output = torch.zeros( + (num_boxes, num_channels, output_size, output_size), + dtype=dtype, + device=device, + ) + + for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)): + inds = torch.nonzero(level_assignments == level).squeeze(1) + pooler_fmt_boxes_level = pooler_fmt_boxes[inds] + output[inds] = pooler(x_level, pooler_fmt_boxes_level) + + return output + + +class ROIOutputs(object): + def __init__(self, cfg, training=False): + self.smooth_l1_beta = cfg.ROI_BOX_HEAD.SMOOTH_L1_BETA + self.box2box_transform = Box2BoxTransform(weights=cfg.ROI_BOX_HEAD.BBOX_REG_WEIGHTS) + self.training = training + self.score_thresh = cfg.ROI_HEADS.SCORE_THRESH_TEST + self.min_detections = cfg.MIN_DETECTIONS + self.max_detections = cfg.MAX_DETECTIONS + + nms_thresh = cfg.ROI_HEADS.NMS_THRESH_TEST + if not isinstance(nms_thresh, list): + nms_thresh = [nms_thresh] + self.nms_thresh = nms_thresh + + def _predict_boxes(self, proposals, box_deltas, preds_per_image): + num_pred = box_deltas.size(0) + B = proposals[0].size(-1) + K = box_deltas.size(-1) // B + box_deltas = box_deltas.view(num_pred * K, B) + proposals = torch.cat(proposals, dim=0).unsqueeze(-2).expand(num_pred, K, B) + proposals = proposals.reshape(-1, B) + boxes = self.box2box_transform.apply_deltas(box_deltas, proposals) + return boxes.view(num_pred, K * B).split(preds_per_image, dim=0) + + def _predict_objs(self, obj_logits, preds_per_image): + probs = F.softmax(obj_logits, dim=-1) + probs = probs.split(preds_per_image, dim=0) + return probs + + def _predict_attrs(self, attr_logits, preds_per_image): + attr_logits = attr_logits[..., :-1].softmax(-1) + attr_probs, attrs = attr_logits.max(-1) + return attr_probs.split(preds_per_image, dim=0), attrs.split(preds_per_image, dim=0) + + @torch.no_grad() + def inference( + self, + obj_logits, 
+ attr_logits, + box_deltas, + pred_boxes, + features, + sizes, + scales=None, + ): + # only the pred boxes is the + preds_per_image = [p.size(0) for p in pred_boxes] + boxes_all = self._predict_boxes(pred_boxes, box_deltas, preds_per_image) + obj_scores_all = self._predict_objs(obj_logits, preds_per_image) # list of length N + attr_probs_all, attrs_all = self._predict_attrs(attr_logits, preds_per_image) + features = features.split(preds_per_image, dim=0) + + # fun for each image too, also I can experiment and do multiple images + final_results = [] + zipped = zip(boxes_all, obj_scores_all, attr_probs_all, attrs_all, sizes) + for i, (boxes, obj_scores, attr_probs, attrs, size) in enumerate(zipped): + for nms_t in self.nms_thresh: + outputs = do_nms( + boxes, + obj_scores, + size, + self.score_thresh, + nms_t, + self.min_detections, + self.max_detections, + ) + if outputs is not None: + max_boxes, max_scores, classes, ids = outputs + break + + if scales is not None: + scale_yx = scales[i] + max_boxes[:, 0::2] *= scale_yx[1] + max_boxes[:, 1::2] *= scale_yx[0] + + final_results.append( + ( + max_boxes, + classes, + max_scores, + attrs[ids], + attr_probs[ids], + features[i][ids], + ) + ) + boxes, classes, class_probs, attrs, attr_probs, roi_features = map(list, zip(*final_results)) + return boxes, classes, class_probs, attrs, attr_probs, roi_features + + def training(self, obj_logits, attr_logits, box_deltas, pred_boxes, features, sizes): + pass + + def __call__( + self, + obj_logits, + attr_logits, + box_deltas, + pred_boxes, + features, + sizes, + scales=None, + ): + if self.training: + raise NotImplementedError() + return self.inference( + obj_logits, + attr_logits, + box_deltas, + pred_boxes, + features, + sizes, + scales=scales, + ) + + +class Res5ROIHeads(nn.Module): + """ + ROIHeads perform all per-region computation in an R-CNN. + It contains logic of cropping the regions, extract per-region features + (by the res-5 block in this case), and make per-region predictions. + """ + + def __init__(self, cfg, input_shape): + super().__init__() + self.batch_size_per_image = cfg.RPN.BATCH_SIZE_PER_IMAGE + self.positive_sample_fraction = cfg.ROI_HEADS.POSITIVE_FRACTION + self.in_features = cfg.ROI_HEADS.IN_FEATURES + self.num_classes = cfg.ROI_HEADS.NUM_CLASSES + self.proposal_append_gt = cfg.ROI_HEADS.PROPOSAL_APPEND_GT + self.feature_strides = {k: v.stride for k, v in input_shape.items()} + self.feature_channels = {k: v.channels for k, v in input_shape.items()} + self.cls_agnostic_bbox_reg = cfg.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG + self.stage_channel_factor = 2 ** 3 # res5 is 8x res2 + self.out_channels = cfg.RESNETS.RES2_OUT_CHANNELS * self.stage_channel_factor + + # self.proposal_matcher = Matcher( + # cfg.ROI_HEADS.IOU_THRESHOLDS, + # cfg.ROI_HEADS.IOU_LABELS, + # allow_low_quality_matches=False, + # ) + + pooler_resolution = cfg.ROI_BOX_HEAD.POOLER_RESOLUTION + pooler_scales = (1.0 / self.feature_strides[self.in_features[0]],) + sampling_ratio = cfg.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + res5_halve = cfg.ROI_BOX_HEAD.RES5HALVE + use_attr = cfg.ROI_BOX_HEAD.ATTR + num_attrs = cfg.ROI_BOX_HEAD.NUM_ATTRS + + self.pooler = ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + ) + + self.res5 = self._build_res5_block(cfg) + if not res5_halve: + """ + Modifications for VG in RoI heads: + 1. Change the stride of conv1 and shortcut in Res5.Block1 from 2 to 1 + 2. 
Modifying all conv2 with (padding: 1 --> 2) and (dilation: 1 --> 2) + """ + self.res5[0].conv1.stride = (1, 1) + self.res5[0].shortcut.stride = (1, 1) + for i in range(3): + self.res5[i].conv2.padding = (2, 2) + self.res5[i].conv2.dilation = (2, 2) + + self.box_predictor = FastRCNNOutputLayers( + self.out_channels, + self.num_classes, + self.cls_agnostic_bbox_reg, + use_attr=use_attr, + num_attrs=num_attrs, + ) + + def _build_res5_block(self, cfg): + stage_channel_factor = self.stage_channel_factor # res5 is 8x res2 + num_groups = cfg.RESNETS.NUM_GROUPS + width_per_group = cfg.RESNETS.WIDTH_PER_GROUP + bottleneck_channels = num_groups * width_per_group * stage_channel_factor + out_channels = self.out_channels + stride_in_1x1 = cfg.RESNETS.STRIDE_IN_1X1 + norm = cfg.RESNETS.NORM + + blocks = ResNet.make_stage( + BottleneckBlock, + 3, + first_stride=2, + in_channels=out_channels // 2, + bottleneck_channels=bottleneck_channels, + out_channels=out_channels, + num_groups=num_groups, + norm=norm, + stride_in_1x1=stride_in_1x1, + ) + return nn.Sequential(*blocks) + + def _shared_roi_transform(self, features, boxes): + x = self.pooler(features, boxes) + return self.res5(x) + + def forward(self, features, proposal_boxes, gt_boxes=None): + if self.training: + """ + see https://github.com/airsplay/py-bottom-up-attention/\ + blob/master/detectron2/modeling/roi_heads/roi_heads.py + """ + raise NotImplementedError() + + assert not proposal_boxes[0].requires_grad + box_features = self._shared_roi_transform(features, proposal_boxes) + feature_pooled = box_features.mean(dim=[2, 3]) # pooled to 1x1 + obj_logits, attr_logits, pred_proposal_deltas = self.box_predictor(feature_pooled) + return obj_logits, attr_logits, pred_proposal_deltas, feature_pooled + + +class AnchorGenerator(nn.Module): + """ + For a set of image sizes and feature maps, computes a set of anchors. + """ + + def __init__(self, cfg, input_shape: List[ShapeSpec]): + super().__init__() + sizes = cfg.ANCHOR_GENERATOR.SIZES + aspect_ratios = cfg.ANCHOR_GENERATOR.ASPECT_RATIOS + self.strides = [x.stride for x in input_shape] + self.offset = cfg.ANCHOR_GENERATOR.OFFSET + assert 0.0 <= self.offset < 1.0, self.offset + + """ + sizes (list[list[int]]): sizes[i] is the list of anchor sizes for feat map i + 1. given in absolute lengths in units of the input image; + 2. they do not dynamically scale if the input image size changes. + aspect_ratios (list[list[float]]) + strides (list[int]): stride of each input feature. + """ + + self.num_features = len(self.strides) + self.cell_anchors = nn.ParameterList(self._calculate_anchors(sizes, aspect_ratios)) + self._spacial_feat_dim = 4 + + def _calculate_anchors(self, sizes, aspect_ratios): + # If one size (or aspect ratio) is specified and there are multiple feature + # maps, then we "broadcast" anchors of that single size (or aspect ratio) + if len(sizes) == 1: + sizes *= self.num_features + if len(aspect_ratios) == 1: + aspect_ratios *= self.num_features + assert self.num_features == len(sizes) + assert self.num_features == len(aspect_ratios) + + cell_anchors = [self.generate_cell_anchors(s, a).float() for s, a in zip(sizes, aspect_ratios)] + + return cell_anchors + + @property + def box_dim(self): + return self._spacial_feat_dim + + @property + def num_cell_anchors(self): + """ + Returns: + list[int]: Each int is the number of anchors at every pixel location, on that feature map. 
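+            With this generator that equals len(sizes[i]) * len(aspect_ratios[i]) for feature map i
+            (e.g. 3 sizes x 3 aspect ratios -> 9 anchors per location; the exact counts come from the config).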
+ """ + return [len(cell_anchors) for cell_anchors in self.cell_anchors] + + def grid_anchors(self, grid_sizes): + anchors = [] + for (size, stride, base_anchors) in zip(grid_sizes, self.strides, self.cell_anchors): + shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device) + shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) + + anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)) + + return anchors + + def generate_cell_anchors(self, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)): + """ + anchors are continuous geometric rectangles + centered on one feature map point sample. + We can later build the set of anchors + for the entire feature map by tiling these tensors + """ + + anchors = [] + for size in sizes: + area = size ** 2.0 + for aspect_ratio in aspect_ratios: + w = math.sqrt(area / aspect_ratio) + h = aspect_ratio * w + x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0 + anchors.append([x0, y0, x1, y1]) + return nn.Parameter(torch.Tensor(anchors)) + + def forward(self, features): + """ + Args: + features List[torch.Tensor]: list of feature maps on which to generate anchors. + Returns: + torch.Tensor: a list of #image elements. + """ + num_images = features[0].size(0) + grid_sizes = [feature_map.shape[-2:] for feature_map in features] + anchors_over_all_feature_maps = self.grid_anchors(grid_sizes) + anchors_over_all_feature_maps = torch.stack(anchors_over_all_feature_maps) + return anchors_over_all_feature_maps.unsqueeze(0).repeat_interleave(num_images, dim=0) + + +class RPNHead(nn.Module): + """ + RPN classification and regression heads. Uses a 3x3 conv to produce a shared + hidden state from which one 1x1 conv predicts objectness logits for each anchor + and a second 1x1 conv predicts bounding-box deltas specifying how to deform + each anchor into an object proposal. + """ + + def __init__(self, cfg, input_shape: List[ShapeSpec]): + super().__init__() + + # Standard RPN is shared across levels: + in_channels = [s.channels for s in input_shape] + assert len(set(in_channels)) == 1, "Each level must have the same channel!" 
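+        # The same head is applied at every pyramid level, so a single channel count is used below.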
+ in_channels = in_channels[0] + + anchor_generator = AnchorGenerator(cfg, input_shape) + num_cell_anchors = anchor_generator.num_cell_anchors + box_dim = anchor_generator.box_dim + assert len(set(num_cell_anchors)) == 1, "Each level must have the same number of cell anchors" + num_cell_anchors = num_cell_anchors[0] + + if cfg.PROPOSAL_GENERATOR.HIDDEN_CHANNELS == -1: + hid_channels = in_channels + else: + hid_channels = cfg.PROPOSAL_GENERATOR.HIDDEN_CHANNELS + # Modifications for VG in RPN (modeling/proposal_generator/rpn.py) + # Use hidden dim instead fo the same dim as Res4 (in_channels) + + # 3x3 conv for the hidden representation + self.conv = nn.Conv2d(in_channels, hid_channels, kernel_size=3, stride=1, padding=1) + # 1x1 conv for predicting objectness logits + self.objectness_logits = nn.Conv2d(hid_channels, num_cell_anchors, kernel_size=1, stride=1) + # 1x1 conv for predicting box2box transform deltas + self.anchor_deltas = nn.Conv2d(hid_channels, num_cell_anchors * box_dim, kernel_size=1, stride=1) + + for layer in [self.conv, self.objectness_logits, self.anchor_deltas]: + nn.init.normal_(layer.weight, std=0.01) + nn.init.constant_(layer.bias, 0) + + def forward(self, features): + """ + Args: + features (list[Tensor]): list of feature maps + """ + pred_objectness_logits = [] + pred_anchor_deltas = [] + for x in features: + t = F.relu(self.conv(x)) + pred_objectness_logits.append(self.objectness_logits(t)) + pred_anchor_deltas.append(self.anchor_deltas(t)) + return pred_objectness_logits, pred_anchor_deltas + + +class RPN(nn.Module): + """ + Region Proposal Network, introduced by the Faster R-CNN paper. + """ + + def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]): + super().__init__() + + self.min_box_side_len = cfg.PROPOSAL_GENERATOR.MIN_SIZE + self.in_features = cfg.RPN.IN_FEATURES + self.nms_thresh = cfg.RPN.NMS_THRESH + self.batch_size_per_image = cfg.RPN.BATCH_SIZE_PER_IMAGE + self.positive_fraction = cfg.RPN.POSITIVE_FRACTION + self.smooth_l1_beta = cfg.RPN.SMOOTH_L1_BETA + self.loss_weight = cfg.RPN.LOSS_WEIGHT + + self.pre_nms_topk = { + True: cfg.RPN.PRE_NMS_TOPK_TRAIN, + False: cfg.RPN.PRE_NMS_TOPK_TEST, + } + self.post_nms_topk = { + True: cfg.RPN.POST_NMS_TOPK_TRAIN, + False: cfg.RPN.POST_NMS_TOPK_TEST, + } + self.boundary_threshold = cfg.RPN.BOUNDARY_THRESH + + self.anchor_generator = AnchorGenerator(cfg, [input_shape[f] for f in self.in_features]) + self.box2box_transform = Box2BoxTransform(weights=cfg.RPN.BBOX_REG_WEIGHTS) + self.anchor_matcher = Matcher( + cfg.RPN.IOU_THRESHOLDS, + cfg.RPN.IOU_LABELS, + allow_low_quality_matches=True, + ) + self.rpn_head = RPNHead(cfg, [input_shape[f] for f in self.in_features]) + + def training(self, images, image_shapes, features, gt_boxes): + pass + + def inference(self, outputs, images, image_shapes, features, gt_boxes=None): + outputs = find_top_rpn_proposals( + outputs.predict_proposals(), + outputs.predict_objectness_logits(), + images, + image_shapes, + self.nms_thresh, + self.pre_nms_topk[self.training], + self.post_nms_topk[self.training], + self.min_box_side_len, + self.training, + ) + + results = [] + for img in outputs: + im_boxes, img_box_logits = img + img_box_logits, inds = img_box_logits.sort(descending=True) + im_boxes = im_boxes[inds] + results.append((im_boxes, img_box_logits)) + + (proposal_boxes, logits) = tuple(map(list, zip(*results))) + return proposal_boxes, logits + + def forward(self, images, image_shapes, features, gt_boxes=None): + """ + Args: + images (torch.Tensor): input images of length `N` 
+ features (dict[str: Tensor]) + gt_instances + """ + # features is dict, key = block level, v = feature_map + features = [features[f] for f in self.in_features] + pred_objectness_logits, pred_anchor_deltas = self.rpn_head(features) + anchors = self.anchor_generator(features) + outputs = RPNOutputs( + self.box2box_transform, + self.anchor_matcher, + self.batch_size_per_image, + self.positive_fraction, + images, + pred_objectness_logits, + pred_anchor_deltas, + anchors, + self.boundary_threshold, + gt_boxes, + self.smooth_l1_beta, + ) + # For RPN-only models, the proposals are the final output + + if self.training: + raise NotImplementedError() + return self.training(outputs, images, image_shapes, features, gt_boxes) + else: + return self.inference(outputs, images, image_shapes, features, gt_boxes) + + +class FastRCNNOutputLayers(nn.Module): + """ + Two linear layers for predicting Fast R-CNN outputs: + (1) proposal-to-detection box regression deltas + (2) classification scores + """ + + def __init__( + self, + input_size, + num_classes, + cls_agnostic_bbox_reg, + box_dim=4, + use_attr=False, + num_attrs=-1, + ): + """ + Args: + input_size (int): channels, or (channels, height, width) + num_classes (int) + cls_agnostic_bbox_reg (bool) + box_dim (int) + """ + super().__init__() + + if not isinstance(input_size, int): + input_size = np.prod(input_size) + + # (do + 1 for background class) + self.cls_score = nn.Linear(input_size, num_classes + 1) + num_bbox_reg_classes = 1 if cls_agnostic_bbox_reg else num_classes + self.bbox_pred = nn.Linear(input_size, num_bbox_reg_classes * box_dim) + + self.use_attr = use_attr + if use_attr: + """ + Modifications for VG in RoI heads + Embedding: {num_classes + 1} --> {input_size // 8} + Linear: {input_size + input_size // 8} --> {input_size // 4} + Linear: {input_size // 4} --> {num_attrs + 1} + """ + self.cls_embedding = nn.Embedding(num_classes + 1, input_size // 8) + self.fc_attr = nn.Linear(input_size + input_size // 8, input_size // 4) + self.attr_score = nn.Linear(input_size // 4, num_attrs + 1) + + nn.init.normal_(self.cls_score.weight, std=0.01) + nn.init.normal_(self.bbox_pred.weight, std=0.001) + for item in [self.cls_score, self.bbox_pred]: + nn.init.constant_(item.bias, 0) + + def forward(self, roi_features): + if roi_features.dim() > 2: + roi_features = torch.flatten(roi_features, start_dim=1) + scores = self.cls_score(roi_features) + proposal_deltas = self.bbox_pred(roi_features) + if self.use_attr: + _, max_class = scores.max(-1) # [b, c] --> [b] + cls_emb = self.cls_embedding(max_class) # [b] --> [b, 256] + roi_features = torch.cat([roi_features, cls_emb], -1) # [b, 2048] + [b, 256] --> [b, 2304] + roi_features = self.fc_attr(roi_features) + roi_features = F.relu(roi_features) + attr_scores = self.attr_score(roi_features) + return scores, attr_scores, proposal_deltas + else: + return scores, proposal_deltas + + +class GeneralizedRCNN(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.device = torch.device(cfg.MODEL.DEVICE) + self.backbone = build_backbone(cfg) + self.proposal_generator = RPN(cfg, self.backbone.output_shape()) + self.roi_heads = Res5ROIHeads(cfg, self.backbone.output_shape()) + self.roi_outputs = ROIOutputs(cfg) + self.to(self.device) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + config = kwargs.pop("config", None) + state_dict = kwargs.pop("state_dict", None) + cache_dir = kwargs.pop("cache_dir", None) + from_tf = kwargs.pop("from_tf", False) + 
force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_cdn = kwargs.pop("use_cdn", True) + + # Load config if we don't provide a configuration + if not isinstance(config, Config): + config_path = config if config is not None else pretrained_model_name_or_path + # try: + config = Config.from_pretrained( + config_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + ) + + # Load model + if pretrained_model_name_or_path is not None: + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + else: + raise EnvironmentError( + "Error no file named {} found in directory {} ".format( + WEIGHTS_NAME, + pretrained_model_name_or_path, + ) + ) + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + archive_file = pretrained_model_name_or_path + elif os.path.isfile(pretrained_model_name_or_path + ".index"): + assert ( + from_tf + ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format( + pretrained_model_name_or_path + ".index" + ) + archive_file = pretrained_model_name_or_path + ".index" + else: + archive_file = hf_bucket_url( + pretrained_model_name_or_path, + filename=WEIGHTS_NAME, + use_cdn=use_cdn, + ) + + try: + # Load from URL or cache if already cached + resolved_archive_file = cached_path( + archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + ) + if resolved_archive_file is None: + raise EnvironmentError + except EnvironmentError: + msg = f"Can't load weights for '{pretrained_model_name_or_path}'." + raise EnvironmentError(msg) + + if resolved_archive_file == archive_file: + print("loading weights file {}".format(archive_file)) + else: + print("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file)) + else: + resolved_archive_file = None + + # Instantiate model. + model = cls(config) + + if state_dict is None: + try: + try: + state_dict = torch.load(resolved_archive_file, map_location="cpu") + except Exception: + state_dict = load_checkpoint(resolved_archive_file) + + except Exception: + raise OSError( + "Unable to load weights from pytorch checkpoint file. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. 
" + ) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + + # Convert old format to new format if needed from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + model_to_load = model + model_to_load.load_state_dict(state_dict) + + if model.__class__.__name__ != model_to_load.__class__.__name__: + base_model_state_dict = model_to_load.state_dict().keys() + head_model_state_dict_without_base_prefix = [ + key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys() + ] + missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " + f"initializing {model.__class__.__name__}: {unexpected_keys}\n" + f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" + f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " + f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + print(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + print( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " + f"and are newly initialized: {missing_keys}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + else: + print( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" + f"If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {model.__class__.__name__} for predictions without further training." 
+ ) + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + model.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + return model + + def forward( + self, + images, + image_shapes, + gt_boxes=None, + proposals=None, + scales_yx=None, + **kwargs, + ): + """ + kwargs: + max_detections (int), return_tensors {"np", "pt", None}, padding {None, + "max_detections"}, pad_value (int), location = {"cuda", "cpu"} + """ + if self.training: + raise NotImplementedError() + return self.inference( + images=images, + image_shapes=image_shapes, + gt_boxes=gt_boxes, + proposals=proposals, + scales_yx=scales_yx, + **kwargs, + ) + + @torch.no_grad() + def inference( + self, + images, + image_shapes, + gt_boxes=None, + proposals=None, + scales_yx=None, + **kwargs, + ): + # run images through backbone + original_sizes = image_shapes * scales_yx + features = self.backbone(images) + + # generate proposals if none are available + if proposals is None: + proposal_boxes, _ = self.proposal_generator(images, image_shapes, features, gt_boxes) + else: + assert proposals is not None + + # pool object features from either gt_boxes, or from proposals + obj_logits, attr_logits, box_deltas, feature_pooled = self.roi_heads(features, proposal_boxes, gt_boxes) + + # prepare FRCNN Outputs and select top proposals + boxes, classes, class_probs, attrs, attr_probs, roi_features = self.roi_outputs( + obj_logits=obj_logits, + attr_logits=attr_logits, + box_deltas=box_deltas, + pred_boxes=proposal_boxes, + features=feature_pooled, + sizes=image_shapes, + scales=scales_yx, + ) + + # will we pad??? + subset_kwargs = { + "max_detections": kwargs.get("max_detections", None), + "return_tensors": kwargs.get("return_tensors", None), + "pad_value": kwargs.get("pad_value", 0), + "padding": kwargs.get("padding", None), + } + preds_per_image = torch.tensor([p.size(0) for p in boxes]) + boxes = pad_list_tensors(boxes, preds_per_image, **subset_kwargs) + classes = pad_list_tensors(classes, preds_per_image, **subset_kwargs) + class_probs = pad_list_tensors(class_probs, preds_per_image, **subset_kwargs) + attrs = pad_list_tensors(attrs, preds_per_image, **subset_kwargs) + attr_probs = pad_list_tensors(attr_probs, preds_per_image, **subset_kwargs) + roi_features = pad_list_tensors(roi_features, preds_per_image, **subset_kwargs) + subset_kwargs["padding"] = None + preds_per_image = pad_list_tensors(preds_per_image, None, **subset_kwargs) + sizes = pad_list_tensors(image_shapes, None, **subset_kwargs) + normalized_boxes = norm_box(boxes, original_sizes) + return OrderedDict( + { + "obj_ids": classes, + "obj_probs": class_probs, + "attr_ids": attrs, + "attr_probs": attr_probs, + "boxes": boxes, + "sizes": sizes, + "preds_per_image": preds_per_image, + "roi_features": roi_features, + "normalized_boxes": normalized_boxes, + } + ) diff --git a/lxmert/src/param.py b/lxmert/src/param.py new file mode 100644 index 0000000000000000000000000000000000000000..c2666e7b55632fa06865d825ea057657b4349c05 --- /dev/null +++ b/lxmert/src/param.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. 
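+#
+# Command-line options shared by LXMERT training and the perturbation experiments.
+# The module parses them once at import time and exposes the result as `args`.
+# Illustrative invocation (the script name is an example; all flags are defined below):
+#   python lxmert_pretrain.py --taskMaskLM --taskObjPredict --taskMatched --taskQA \
+#       --train mscoco_train,vgnococo --valid mscoco_minival --batchSize 256 --optim bert --lr 1e-4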
+ +import argparse +import random + +import numpy as np +import torch + + +def get_optimizer(optim): + # Bind the optimizer + if optim == 'rms': + print("Optimizer: Using RMSProp") + optimizer = torch.optim.RMSprop + elif optim == 'adam': + print("Optimizer: Using Adam") + optimizer = torch.optim.Adam + elif optim == 'adamax': + print("Optimizer: Using Adamax") + optimizer = torch.optim.Adamax + elif optim == 'sgd': + print("Optimizer: sgd") + optimizer = torch.optim.SGD + elif 'bert' in optim: + optimizer = 'bert' # The bert optimizer will be bind later. + else: + assert False, "Please add your optimizer %s in the list." % optim + + return optimizer + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Data Splits + parser.add_argument("--train", default='train') + parser.add_argument("--valid", default='valid') + parser.add_argument("--test", default=None) + + # Training Hyper-parameters + parser.add_argument('--batchSize', dest='batch_size', type=int, default=256) + parser.add_argument('--optim', default='bert') + parser.add_argument('--lr', type=float, default=1e-4) + parser.add_argument('--epochs', type=int, default=10) + parser.add_argument('--dropout', type=float, default=0.1) + parser.add_argument('--seed', type=int, default=9595, help='random seed') + + # Debugging + parser.add_argument('--output', type=str, default='snap/test') + parser.add_argument("--fast", action='store_const', default=False, const=True) + parser.add_argument("--tiny", action='store_const', default=False, const=True) + parser.add_argument("--tqdm", action='store_const', default=False, const=True) + + # Model Loading + parser.add_argument('--load', type=str, default=None, + help='Load the model (usually the fine-tuned model).') + parser.add_argument('--loadLXMERT', dest='load_lxmert', type=str, default=None, + help='Load the pre-trained lxmert model.') + parser.add_argument('--loadLXMERTQA', dest='load_lxmert_qa', type=str, default=None, + help='Load the pre-trained lxmert model with QA answer head.') + parser.add_argument("--fromScratch", dest='from_scratch', action='store_const', default=False, const=True, + help='If none of the --load, --loadLXMERT, --loadLXMERTQA is set, ' + 'the model would be trained from scratch. If --fromScratch is' + ' not specified, the model would load BERT-pre-trained weights by' + ' default. 
') + + # Optimization + parser.add_argument("--mceLoss", dest='mce_loss', action='store_const', default=False, const=True) + + # LXRT Model Config + # Note: LXRT = L, X, R (three encoders), Transformer + parser.add_argument("--llayers", default=9, type=int, help='Number of Language layers') + parser.add_argument("--xlayers", default=5, type=int, help='Number of CROSS-modality layers.') + parser.add_argument("--rlayers", default=5, type=int, help='Number of object Relationship layers.') + + # lxmert Pre-training Config + parser.add_argument("--taskMatched", dest='task_matched', action='store_const', default=False, const=True) + parser.add_argument("--taskMaskLM", dest='task_mask_lm', action='store_const', default=False, const=True) + parser.add_argument("--taskObjPredict", dest='task_obj_predict', action='store_const', default=False, const=True) + parser.add_argument("--taskQA", dest='task_qa', action='store_const', default=False, const=True) + parser.add_argument("--visualLosses", dest='visual_losses', default='obj,attr,feat', type=str) + parser.add_argument("--qaSets", dest='qa_sets', default=None, type=str) + parser.add_argument("--wordMaskRate", dest='word_mask_rate', default=0.15, type=float) + parser.add_argument("--objMaskRate", dest='obj_mask_rate', default=0.15, type=float) + + # Training configuration + parser.add_argument("--multiGPU", action='store_const', default=False, const=True) + parser.add_argument("--numWorkers", dest='num_workers', default=0) + + + # perturbation configuration + parser.add_argument('--method', type=str, + default='ours_no_lrp', + choices=['ours_with_lrp', 'rollout', 'partial_lrp', 'transformer_att', + 'raw_attn', 'attn_gradcam', 'ours_with_lrp_no_normalization', 'ours_no_lrp', + 'ours_no_lrp_no_norm', 'ablation_no_aggregation', 'ablation_no_self_in_10'], + help='') + parser.add_argument('--num-samples', type=int, + default=10000, + help='') + parser.add_argument('--is-positive-pert', type=bool, + default=False, + help='') + parser.add_argument('--is-text-pert', type=bool, + default=False, + help='') + parser.add_argument('--COCO_path', type=str, + default='', + help='path to COCO 2014 validation set') + + # Parse the arguments. + args = parser.parse_args() + + # Bind optimizer class. + args.optimizer = get_optimizer(args.optim) + + # Set seeds + torch.manual_seed(args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + + return args + + +args = parse_args() diff --git a/lxmert/src/pretrain/__init__.py b/lxmert/src/pretrain/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lxmert/src/pretrain/lxmert_data.py b/lxmert/src/pretrain/lxmert_data.py new file mode 100644 index 0000000000000000000000000000000000000000..42d6621a194f66246fbebfb6b20e3b8839b0268d --- /dev/null +++ b/lxmert/src/pretrain/lxmert_data.py @@ -0,0 +1,255 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. 
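+#
+# Pre-training data pipeline: LXMERTDataset loads the json splits and answer table,
+# LXMERTTorchDataset pairs each sentence with its obj36 image features as InputExample
+# items, and LXMERTEvaluator computes QA accuracy from uid -> answer predictions.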
+ +from collections import defaultdict +import json +import random + +import numpy as np +from torch.utils.data import Dataset + +from param import args +from pretrain.qa_answer_table import AnswerTable +from utils import load_obj_tsv + +TINY_IMG_NUM = 500 +FAST_IMG_NUM = 5000 + +Split2ImgFeatPath = { + 'mscoco_train': 'data/mscoco_imgfeat/train2014_obj36.tsv', + 'mscoco_minival': 'data/mscoco_imgfeat/val2014_obj36.tsv', + 'mscoco_nominival': 'data/mscoco_imgfeat/val2014_obj36.tsv', + 'vgnococo': 'data/vg_gqa_imgfeat/vg_gqa_obj36.tsv', +} + + +class InputExample(object): + """A single training/test example for the language model.""" + def __init__(self, uid, sent, visual_feats=None, + obj_labels=None, attr_labels=None, + is_matched=None, label=None): + self.uid = uid + self.sent = sent + self.visual_feats = visual_feats + self.obj_labels = obj_labels + self.attr_labels = attr_labels + self.is_matched = is_matched # whether the visual and obj matched + self.label = label + + +class LXMERTDataset: + def __init__(self, splits: str, qa_sets=None): + """ + :param splits: The data sources to be loaded + :param qa_sets: if None, no action + o.w., only takes the answers appearing in these dsets + and remove all unlabeled data (MSCOCO captions) + """ + self.name = splits + self.sources = splits.split(',') + + # Loading datasets to data + self.data = [] + for source in self.sources: + self.data.extend(json.load(open("data/lxmert/%s.json" % source))) + print("Load %d data from %s" % (len(self.data), self.name)) + + # Create answer table according to the qa_sets + self.answer_table = AnswerTable(qa_sets) + print("Load an answer table of size %d." % (len(self.answer_table.ans2id_map()))) + + # Modify the answers + for datum in self.data: + labelf = datum['labelf'] + for cat, labels in labelf.items(): + for label in labels: + for ans in list(label.keys()): + new_ans = self.answer_table.convert_ans(ans) + if self.answer_table.used(new_ans): + if ans != new_ans: + label[new_ans] = label.pop(ans) + else: + label.pop(ans) + + def __len__(self): + return len(self.data) + + +def make_uid(img_id, dset, sent_idx): + return "%s_%s_%03d" % (img_id, dset, sent_idx), + + +""" +Example in obj tsv: +FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf", + "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"] +""" +class LXMERTTorchDataset(Dataset): + def __init__(self, dataset: LXMERTDataset, topk=-1): + super().__init__() + self.raw_dataset = dataset + self.task_matched = args.task_matched + + if args.tiny: + topk = TINY_IMG_NUM + elif args.fast: + topk = FAST_IMG_NUM + + # Load the dataset + img_data = [] + for source in self.raw_dataset.sources: + img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk)) + + self.imgid2img = {} + for img_datum in img_data: + self.imgid2img[img_datum['img_id']] = img_datum + + # Filter out the dataset + used_data = [] + for datum in self.raw_dataset.data: + if datum['img_id'] in self.imgid2img: + used_data.append(datum) + + # Flatten the dataset (into one sent + one image entries) + self.data = [] + for datum in used_data: + sentf = datum['sentf'] + for sents_cat, sents in sentf.items(): + if sents_cat in datum['labelf']: + labels = datum['labelf'][sents_cat] + else: + labels = None + for sent_idx, sent in enumerate(sents): + new_datum = { + 'uid': make_uid(datum['img_id'], sents_cat, sent_idx), + 'img_id': datum['img_id'], + 'sent': sent + } + if labels is not None: + new_datum['label'] = labels[sent_idx] + self.data.append(new_datum) + print("Use %d 
data in torch dataset" % (len(self.data))) + + def __len__(self): + return len(self.data) + + def random_feat(self): + """Get a random obj feat from the dataset.""" + datum = self.data[random.randint(0, len(self.data)-1)] + img_id = datum['img_id'] + img_info = self.imgid2img[img_id] + feat = img_info['features'][random.randint(0, 35)] + return feat + + def __getitem__(self, item: int): + datum = self.data[item] + + uid = datum['uid'] + img_id = datum['img_id'] + + # Get image info + img_info = self.imgid2img[img_id] + obj_num = img_info['num_boxes'] + feats = img_info['features'].copy() + boxes = img_info['boxes'].copy() + obj_labels = img_info['objects_id'].copy() + obj_confs = img_info['objects_conf'].copy() + attr_labels = img_info['attrs_id'].copy() + attr_confs = img_info['attrs_conf'].copy() + assert obj_num == len(boxes) == len(feats) + + # Normalize the boxes (to 0 ~ 1) + img_h, img_w = img_info['img_h'], img_info['img_w'] + boxes = boxes.copy() + boxes[:, (0, 2)] /= img_w + boxes[:, (1, 3)] /= img_h + np.testing.assert_array_less(boxes, 1+1e-5) + np.testing.assert_array_less(-boxes, 0+1e-5) + + # If calculating the matched loss, replace the sentence with an sentence + # corresponding to other image. + is_matched = 1 + sent = datum['sent'] + if self.task_matched: + if random.random() < 0.5: + is_matched = 0 + other_datum = self.data[random.randint(0, len(self.data)-1)] + while other_datum['img_id'] == img_id: + other_datum = self.data[random.randint(0, len(self.data)-1)] + sent = other_datum['sent'] + + # Label, convert answer to id + if 'label' in datum: + label = datum['label'].copy() + for ans in list(label.keys()): + label[self.raw_dataset.answer_table.ans2id(ans)] = label.pop(ans) + else: + label = None + + # Create target + example = InputExample( + uid, sent, (feats, boxes), + (obj_labels, obj_confs), (attr_labels, attr_confs), + is_matched, label + ) + return example + + +class LXMERTEvaluator: + def __init__(self, dataset: LXMERTDataset): + self.raw_dataset = dataset + + # Create QA Eval Data + self.data = [] + for datum in self.raw_dataset.data: + sentf = datum['sentf'] + for sents_cat, sents in sentf.items(): + if sents_cat in datum['labelf']: # A labeled dataset + labels = datum['labelf'][sents_cat] + for sent_idx, sent in enumerate(sents): + new_datum = { + 'uid': make_uid(datum['img_id'], sents_cat, sent_idx), + 'img_id': datum['img_id'], + 'sent': sent, + 'dset': sents_cat, + 'label': labels[sent_idx] + } + self.data.append(new_datum) + + # uid2datum + self.uid2datum = {} + for datum in self.data: + self.uid2datum[datum['uid']] = datum + + def evaluate(self, uid2ans: dict, pprint=False): + score = 0. + cnt = 0 + dset2score = defaultdict(lambda: 0.) 
+        dset2cnt = defaultdict(lambda: 0)
+        for uid, ans in uid2ans.items():
+            if uid not in self.uid2datum:   # Not a labeled datum
+                continue
+            datum = self.uid2datum[uid]
+            label = datum['label']
+            dset = datum['dset']
+            if ans in label:
+                score += label[ans]
+                dset2score[dset] += label[ans]
+            cnt += 1
+            dset2cnt[dset] += 1
+        accu = score / cnt
+        dset2accu = {}
+        for dset in dset2cnt:
+            dset2accu[dset] = dset2score[dset] / dset2cnt[dset]
+
+        if pprint:
+            accu_str = "Overall Accu %0.4f, " % (accu)
+            sorted_keys = sorted(dset2accu.keys())
+            for key in sorted_keys:
+                accu_str += "%s Accu %0.4f, " % (key, dset2accu[key])
+            print(accu_str)
+
+        return accu, dset2accu
+
+    def dump_result(self, uid2ans: dict, path):
+        raise NotImplementedError
diff --git a/lxmert/src/pretrain/lxmert_pretrain.py b/lxmert/src/pretrain/lxmert_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..d88c445d04bf7e680e01eac13a6c4b9256c65d52
--- /dev/null
+++ b/lxmert/src/pretrain/lxmert_pretrain.py
@@ -0,0 +1,435 @@
+# coding=utf-8
+# Copyleft 2019 project LXRT.
+
+import collections
+import os
+import random
+
+from tqdm import tqdm
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+from param import args
+from pretrain.lxmert_data import InputExample, LXMERTDataset, LXMERTTorchDataset, LXMERTEvaluator
+from lxrt.entry import set_visual_config
+from lxrt.tokenization import BertTokenizer
+from lxrt.modeling import LXRTPretraining
+
+DataTuple = collections.namedtuple("DataTuple", 'dataset torchdset loader evaluator')
+
+
+def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple:
+    # Decide which QA datasets would be used in pre-training.
+    # Options: vqa, gqa, visual7w
+    # Note: visual7w is a part of vgqa, we take the name here.
+    qa_sets = args.qa_sets
+    if qa_sets is not None:
+        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))
+
+    # Build dataset, data loader, and evaluator.
+    dset = LXMERTDataset(splits, qa_sets=qa_sets)
+    tset = LXMERTTorchDataset(dset, topk)
+    data_loader = DataLoader(
+        tset, batch_size=bs,
+        shuffle=shuffle, num_workers=args.num_workers,
+        collate_fn=lambda x: x,
+        drop_last=drop_last, pin_memory=True
+    )
+    evaluator = LXMERTEvaluator(dset)
+    print()
+
+    return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator)
+
+
+train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True)
+valid_batch_size = 2048 if args.multiGPU else 512
+valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False, topk=5000)
+
+
+class InputFeatures(object):
+    """A single set of features of data."""
+
+    def __init__(self,
+                 input_ids, input_mask, segment_ids, lm_label_ids,
+                 visual_feats, obj_labels,
+                 is_matched, ans):
+        self.input_ids = input_ids
+        self.input_mask = input_mask
+        self.segment_ids = segment_ids
+        self.lm_label_ids = lm_label_ids
+
+        self.visual_feats = visual_feats
+        self.obj_labels = obj_labels
+
+        self.is_matched = is_matched
+
+        self.ans = ans
+
+
+def random_word(tokens, tokenizer):
+    """
+    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
+    :param tokens: list of str, tokenized sentence.
+ :param tokenizer: Tokenizer, object used for tokenization (we need it's vocab here) + :return: (list of str, list of int), masked tokens and related labels for LM prediction + """ + output_label = [] + + for i, token in enumerate(tokens): + prob = random.random() + # mask token with probability + ratio = args.word_mask_rate + if prob < ratio: + prob /= ratio + + # 80% randomly change token to mask token + if prob < 0.8: + tokens[i] = "[MASK]" + + # 10% randomly change token to random token + elif prob < 0.9: + tokens[i] = random.choice(list(tokenizer.vocab.items()))[0] + + # -> rest 10% randomly keep current token + + # append current token to output (we will predict these later) + try: + output_label.append(tokenizer.vocab[token]) + except KeyError: + # For unknown words (should not occur with BPE vocab) + output_label.append(tokenizer.vocab["[UNK]"]) + else: + # no masking token (will be ignored by loss function later) + output_label.append(-1) + + return tokens, output_label + + +def random_feat(feats): + mask_feats = feats.copy() + feat_mask = np.zeros(len(feats), dtype=np.float32) + for i in range(len(feats)): + prob = random.random() + # mask token with probability + if prob < args.obj_mask_rate: + prob /= args.obj_mask_rate + + # 80% randomly change token to zero feat + if prob < 0.8: + mask_feats[i, :] = 0. + + # 10% randomly change token to random feat + elif prob < 0.9: + mask_feats[i, :] = train_tuple.torchdset.random_feat() + # -> rest 10% randomly keep current feat + + # Need to predict this feat + feat_mask[i] = 1. + + return mask_feats, feat_mask + + +def convert_example_to_features(example: InputExample, max_seq_length, tokenizer)->InputFeatures: + """ + Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with + IDs, LM labels, input_mask, CLS and SEP tokens etc. + :param example: InputExample, containing sentence input as strings and is_next label + :param max_seq_length: int, maximum length of sequence. + :param tokenizer: Tokenizer + :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training) + """ + tokens = tokenizer.tokenize(example.sent.strip()) + + # Account for [CLS] and [SEP] with "- 2" + if len(tokens) > max_seq_length - 2: + tokens = tokens[:(max_seq_length - 2)] + + # Ge random words + masked_tokens, masked_label = random_word(tokens, tokenizer) + + # concatenate lm labels and account for CLS, SEP, SEP + masked_tokens = ['[CLS]'] + masked_tokens + ['[SEP]'] + input_ids = tokenizer.convert_tokens_to_ids(masked_tokens) + + # Mask & Segment Word + lm_label_ids = ([-1] + masked_label + [-1]) + input_mask = [1] * len(input_ids) + segment_ids = [0] * len(input_ids) + + # Zero-pad up to the sequence length. + while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + lm_label_ids.append(-1) + + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + assert len(lm_label_ids) == max_seq_length + + feat, boxes = example.visual_feats + obj_labels, obj_confs = example.obj_labels + attr_labels, attr_confs = example.attr_labels + + # Mask Image Features: + masked_feat, feat_mask = random_feat(feat) + + # QA answer label + if example.label is None or len(example.label) == 0 or example.is_matched != 1: + # 1. No label 2. Label is pruned 3. 
unmatched visual + language pair + ans = -1 + else: + keys, values = zip(*example.label.items()) + if len(keys) == 1: + ans = keys[0] + else: + value_sum = sum(values) + prob = [value / value_sum for value in values] + choice = np.random.multinomial(1, prob).argmax() + ans = keys[choice] + + features = InputFeatures( + input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + lm_label_ids=lm_label_ids, + visual_feats=(masked_feat, boxes), + obj_labels={ + 'obj': (obj_labels, obj_confs), + 'attr': (attr_labels, attr_confs), + 'feat': (feat, feat_mask), + }, + is_matched=example.is_matched, + ans=ans, + ) + return features + + +LOSSES_NAME = ('Mask_LM', 'Matched', 'Obj', 'Attr', 'Feat', 'QA') + + +class LXMERT: + def __init__(self, max_seq_length): + super().__init__() + self.max_seq_length = max_seq_length + + self.tokenizer = BertTokenizer.from_pretrained( + "bert-base-uncased", + do_lower_case=True + ) + + # Build model + set_visual_config(args) + self.model = LXRTPretraining.from_pretrained( + "bert-base-uncased", + task_mask_lm=args.task_mask_lm, + task_obj_predict=args.task_obj_predict, + task_matched=args.task_matched, + task_qa=args.task_qa, + visual_losses=args.visual_losses, + num_answers=train_tuple.dataset.answer_table.num_answers + ) + + # Weight initialization and loading + if args.from_scratch: + print("Train from Scratch: re-initialize all BERT weights.") + self.model.apply(self.model.init_bert_weights) + if args.load is not None: + self.load(args.load) + if args.load_lxmert is not None: + # Load lxmert would not load the answer head. + self.load_lxmert(args.load_lxmert) + + # GPU Options + self.model = self.model.cuda() + if args.multiGPU: + self.model = nn.DataParallel(self.model) + + def forward(self, examples): + train_features = [convert_example_to_features(example, self.max_seq_length, self.tokenizer) + for example in examples] + + # language Inputs + input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda() + input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda() + segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda() + + # Visual Inputs + feats = torch.from_numpy(np.stack([f.visual_feats[0] for f in train_features])).cuda() + pos = torch.from_numpy(np.stack([f.visual_feats[1] for f in train_features])).cuda() + + # Language Prediction + lm_labels = torch.tensor([f.lm_label_ids for f in train_features], dtype=torch.long).cuda() + + # Visual Prediction + obj_labels = {} + for key in ('obj', 'attr', 'feat'): + visn_labels = torch.from_numpy(np.stack([f.obj_labels[key][0] for f in train_features])).cuda() + visn_mask = torch.from_numpy(np.stack([f.obj_labels[key][1] for f in train_features])).cuda() + assert visn_labels.size(0) == visn_mask.size(0) and visn_labels.size(1) == visn_mask.size(1) + obj_labels[key] = (visn_labels, visn_mask) + + # Joint Prediction + matched_labels = torch.tensor([f.is_matched for f in train_features], dtype=torch.long).cuda() + ans = torch.from_numpy(np.stack([f.ans for f in train_features])).cuda() + + """ + forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, + visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None): + """ + loss, losses, ans_logit = self.model( + input_ids, segment_ids, input_mask, lm_labels, + feats, pos, obj_labels, matched_labels, ans + ) + return loss, losses.detach().cpu(), ans_logit + + def train_batch(self, optim, batch): + 
optim.zero_grad() + loss, losses, ans_logit = self.forward(batch) + if args.multiGPU: + loss = loss.mean() + losses = losses.mean(0) + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), 1.) + optim.step() + + return loss.item(), losses.cpu().numpy(), ans_logit + + def valid_batch(self, batch): + with torch.no_grad(): + loss, losses, ans_logit = self.forward(batch) + if args.multiGPU: + loss = loss.mean() + losses = losses.mean(0) + return loss.item(), losses.cpu().numpy(), ans_logit + + def train(self, train_tuple: DataTuple, eval_tuple: DataTuple): + train_ld = train_tuple.loader + + # Optimizer + from lxrt.optimization import BertAdam + batch_per_epoch = len(train_ld) + t_total = int(batch_per_epoch * args.epochs) + warmup_ratio = 0.05 + warmup_iters = int(t_total * warmup_ratio) + print("Batch per epoch: %d" % batch_per_epoch) + print("Total Iters: %d" % t_total) + print("Warm up Iters: %d" % warmup_iters) + optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total) + + # Train + best_eval_loss = 9595. + for epoch in range(args.epochs): + # Train + self.model.train() + total_loss = 0. + total_losses = 0. + uid2ans = {} + for batch in tqdm(train_ld, total=len(train_ld)): + loss, losses, logit = self.train_batch(optim, batch) + total_loss += loss + total_losses += losses + + if args.task_qa: + score, label = logit.max(1) + for datum, l in zip(batch, label.cpu().numpy()): + uid = datum.uid + ans = train_tuple.dataset.answer_table.id2ans(l) + uid2ans[uid] = ans + + print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch)) + losses_str = "The losses are " + for name, loss in zip(LOSSES_NAME, total_losses): + losses_str += "%s: %0.4f " % (name, loss / batch_per_epoch) + print(losses_str) + if args.task_qa: + train_tuple.evaluator.evaluate(uid2ans, pprint=True) + + # Eval + avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1) + + # Save + if avg_eval_loss < best_eval_loss: + best_eval_loss = avg_eval_loss + self.save("BEST_EVAL_LOSS") + self.save("Epoch%02d" % (epoch+1)) + + def evaluate_epoch(self, eval_tuple: DataTuple, iters: int=-1): + self.model.eval() + eval_ld = eval_tuple.loader + total_loss = 0. + total_losses = 0. 
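+        # Collect uid -> predicted answer so the evaluator can report QA accuracy when --taskQA is set.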
+ uid2ans = {} + for i, batch in enumerate(eval_ld): + loss, losses, logit = self.valid_batch(batch) + total_loss += loss + total_losses += losses + if args.task_qa: + score, label = logit.max(1) + for datum, l in zip(batch, label.cpu().numpy()): + uid = datum.uid + ans = train_tuple.dataset.answer_table.id2ans(l) + uid2ans[uid] = ans + if i == iters: + break + + print("The valid loss is %0.4f" % (total_loss / len(eval_ld))) + losses_str = "The losses are " + for name, loss in zip(LOSSES_NAME, total_losses / len(eval_ld)): + losses_str += "%s: %0.4f " % (name, loss) + print(losses_str) + + if args.task_qa: + eval_tuple.evaluator.evaluate(uid2ans, pprint=True) + + return total_loss / len(eval_ld) + + def save(self, name): + torch.save(self.model.state_dict(), + os.path.join(args.output, "%s_LXRT.pth" % name)) + + def load(self, path): + print("Load BERT extractor from %s" % path) + state_dict = torch.load("%s_LXRT.pth" % path) + self.model.load_state_dict(state_dict) + + def load_lxmert(self, path): + print("Load lxmert model from %s" % path) + state_dict = torch.load("%s_LXRT.pth" % path) + + # Do not load any answer head + for key in list(state_dict.keys()): + if 'answer' in key: + state_dict.pop(key) + + # Change Multi GPU to single GPU + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("module."): + new_state_dict[key[len("module."):]] = value + state_dict = new_state_dict + + load_keys = set(state_dict.keys()) + model_keys = set(self.model.state_dict().keys()) + print() + print("Keys in loaded but not in model:") + for key in sorted(load_keys.difference(model_keys)): + print(key) + print() + print("Keys in model but not in loaded:") + for key in sorted(model_keys.difference(load_keys)): + print(key) + print() + + self.model.load_state_dict(state_dict, strict=False) + + +if __name__ == "__main__": + + lxmert = LXMERT(max_seq_length=20) + + + lxmert.train(train_tuple, valid_tuple) diff --git a/lxmert/src/pretrain/qa_answer_table.py b/lxmert/src/pretrain/qa_answer_table.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fd92e55f95f94ced4bd8f149a9bf504f26e556 --- /dev/null +++ b/lxmert/src/pretrain/qa_answer_table.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. 
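+#
+# Answer-vocabulary utilities: AnswerTable normalizes answers and maps them to ids,
+# and load_lxmert_qa copies pre-trained QA-head weights into a fine-tuning model.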
+ +import json +import torch + + +class AnswerTable: + ANS_CONVERT = { + "a man": "man", + "the man": "man", + "a woman": "woman", + "the woman": "woman", + 'one': '1', + 'two': '2', + 'three': '3', + 'four': '4', + 'five': '5', + 'six': '6', + 'seven': '7', + 'eight': '8', + 'nine': '9', + 'ten': '10', + 'grey': 'gray', + } + + def __init__(self, dsets=None): + self.all_ans = json.load(open("data/lxmert/all_ans.json")) + if dsets is not None: + dsets = set(dsets) + # If the answer is used in the dsets + self.anss = [ans['ans'] for ans in self.all_ans if + len(set(ans['dsets']) & dsets) > 0] + else: + self.anss = [ans['ans'] for ans in self.all_ans] + self.ans_set = set(self.anss) + + self._id2ans_map = self.anss + self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)} + + assert len(self._id2ans_map) == len(self._ans2id_map) + for ans_id, ans in enumerate(self._id2ans_map): + assert self._ans2id_map[ans] == ans_id + + def convert_ans(self, ans): + if len(ans) == 0: + return "" + ans = ans.lower() + if ans[-1] == '.': + ans = ans[:-1].strip() + if ans.startswith("a "): + ans = ans[2:].strip() + if ans.startswith("an "): + ans = ans[3:].strip() + if ans.startswith("the "): + ans = ans[4:].strip() + if ans in self.ANS_CONVERT: + ans = self.ANS_CONVERT[ans] + return ans + + def ans2id(self, ans): + return self._ans2id_map[ans] + + def id2ans(self, ans_id): + return self._id2ans_map[ans_id] + + def ans2id_map(self): + return self._ans2id_map.copy() + + def id2ans_map(self): + return self._id2ans_map.copy() + + def used(self, ans): + return ans in self.ans_set + + def all_answers(self): + return self.anss.copy() + + @property + def num_answers(self): + return len(self.anss) + + +def load_lxmert_qa(path, model, label2ans): + """ + Load model weights from lxmert pre-training. + The answers in the fine-tuned QA task (indicated by label2ans) + would also be properly initialized with lxmert pre-trained + QA heads. + + :param path: Path to lxmert snapshot. + :param model: LXRT model instance. 
+ :param label2ans: The label2ans dict of fine-tuned QA datasets, like + {0: 'cat', 1: 'dog', ...} + :return: + """ + print("Load QA pre-trained lxmert from %s " % path) + loaded_state_dict = torch.load("%s_LXRT.pth" % path) + model_state_dict = model.state_dict() + + # Handle Multi-GPU pre-training --> Single GPU fine-tuning + for key in list(loaded_state_dict.keys()): + loaded_state_dict[key.replace("module.", '')] = loaded_state_dict.pop(key) + + # Isolate bert model + bert_state_dict = {} + for key, value in loaded_state_dict.items(): + if key.startswith('bert.'): + bert_state_dict[key] = value + + # Isolate answer head + answer_state_dict = {} + for key, value in loaded_state_dict.items(): + if key.startswith("answer_head."): + answer_state_dict[key.replace('answer_head.', '')] = value + + # Do surgery on answer state dict + ans_weight = answer_state_dict['logit_fc.3.weight'] + ans_bias = answer_state_dict['logit_fc.3.bias'] + import copy + new_answer_weight = copy.deepcopy(model_state_dict['logit_fc.3.weight']) + new_answer_bias = copy.deepcopy(model_state_dict['logit_fc.3.bias']) + answer_table = AnswerTable() + loaded = 0 + unload = 0 + if type(label2ans) is list: + label2ans = {label: ans for label, ans in enumerate(label2ans)} + for label, ans in label2ans.items(): + new_ans = answer_table.convert_ans(ans) + if answer_table.used(new_ans): + ans_id_9500 = answer_table.ans2id(new_ans) + new_answer_weight[label] = ans_weight[ans_id_9500] + new_answer_bias[label] = ans_bias[ans_id_9500] + loaded += 1 + else: + new_answer_weight[label] = 0. + new_answer_bias[label] = 0. + unload += 1 + print("Loaded %d answers from LXRTQA pre-training and %d not" % (loaded, unload)) + print() + answer_state_dict['logit_fc.3.weight'] = new_answer_weight + answer_state_dict['logit_fc.3.bias'] = new_answer_bias + + # Load Bert Weights + bert_model_keys = set(model.lxrt_encoder.model.state_dict().keys()) + bert_loaded_keys = set(bert_state_dict.keys()) + assert len(bert_model_keys - bert_loaded_keys) == 0 + model.lxrt_encoder.model.load_state_dict(bert_state_dict, strict=False) + + # Load Answer Logic FC Weights + model_keys = set(model.state_dict().keys()) + ans_loaded_keys = set(answer_state_dict.keys()) + assert len(ans_loaded_keys - model_keys) == 0 + + model.load_state_dict(answer_state_dict, strict=False) + + + diff --git a/lxmert/src/processing_image.py b/lxmert/src/processing_image.py new file mode 100644 index 0000000000000000000000000000000000000000..2df5bdbabe158183d696fdc2c972a9668d72bc97 --- /dev/null +++ b/lxmert/src/processing_image.py @@ -0,0 +1,147 @@ +""" + coding=utf-8 + Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal + Adapted From Facebook Inc, Detectron2 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License.import copy + """ +import sys +from typing import Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + +from lxmert.lxmert.src.vqa_utils import img_tensorize + + +class ResizeShortestEdge: + def __init__(self, short_edge_length, max_size=sys.maxsize): + """ + Args: + short_edge_length (list[min, max]) + max_size (int): maximum allowed longest edge length. + """ + self.interp_method = "bilinear" + self.max_size = max_size + self.short_edge_length = short_edge_length + + def __call__(self, imgs): + img_augs = [] + for img in imgs: + h, w = img.shape[:2] + # later: provide list and randomly choose index for resize + size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1) + if size == 0: + return img + scale = size * 1.0 / min(h, w) + if h < w: + newh, neww = size, scale * w + else: + newh, neww = scale * h, size + if max(newh, neww) > self.max_size: + scale = self.max_size * 1.0 / max(newh, neww) + newh = newh * scale + neww = neww * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + + if img.dtype == np.uint8: + pil_image = Image.fromarray(img) + pil_image = pil_image.resize((neww, newh), Image.BILINEAR) + img = np.asarray(pil_image) + else: + img = img.permute(2, 0, 1).unsqueeze(0) # 3, 0, 1) # hw(c) -> nchw + img = F.interpolate(img, (newh, neww), mode=self.interp_method, align_corners=False).squeeze(0) + img_augs.append(img) + + return img_augs + + +class Preprocess: + def __init__(self, cfg): + self.aug = ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST) + self.input_format = cfg.INPUT.FORMAT + self.size_divisibility = cfg.SIZE_DIVISIBILITY + self.pad_value = cfg.PAD_VALUE + self.max_image_size = cfg.INPUT.MAX_SIZE_TEST + self.device = cfg.MODEL.DEVICE + self.pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1) + self.pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1) + self.normalizer = lambda x: (x - self.pixel_mean) / self.pixel_std + + def pad(self, images): + max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) + image_sizes = [im.shape[-2:] for im in images] + images = [ + F.pad( + im, + [0, max_size[-1] - size[1], 0, max_size[-2] - size[0]], + value=self.pad_value, + ) + for size, im in zip(image_sizes, images) + ] + + return torch.stack(images), torch.tensor(image_sizes) + + def __call__(self, images, single_image=False): + with torch.no_grad(): + if not isinstance(images, list): + images = [images] + if single_image: + assert len(images) == 1 + for i in range(len(images)): + if isinstance(images[i], torch.Tensor): + images.insert(i, images.pop(i).to(self.device).float()) + elif not isinstance(images[i], torch.Tensor): + images.insert( + i, + torch.as_tensor(img_tensorize(images.pop(i), input_format=self.input_format)) + .to(self.device) + .float(), + ) + # resize smallest edge + raw_sizes = torch.tensor([im.shape[:2] for im in images]) + images = self.aug(images) + # transpose images and convert to torch tensors + # images = [torch.as_tensor(i.astype("float32")).permute(2, 0, 1).to(self.device) for i in images] + # now normalize before pad to avoid useless arithmetic + images = [self.normalizer(x) for x in images] + # now pad them to do the following operations + images, sizes = self.pad(images) + # Normalize + + if self.size_divisibility > 0: 
+ raise NotImplementedError() + # pad + scales_yx = torch.true_divide(raw_sizes, sizes) + if single_image: + return images[0], sizes[0], scales_yx[0] + else: + return images, sizes, scales_yx + + +def _scale_box(boxes, scale_yx): + boxes[:, 0::2] *= scale_yx[:, 1] + boxes[:, 1::2] *= scale_yx[:, 0] + return boxes + + +def _clip_box(tensor, box_size: Tuple[int, int]): + assert torch.isfinite(tensor).all(), "Box tensor contains infinite or NaN!" + h, w = box_size + tensor[:, 0].clamp_(min=0, max=w) + tensor[:, 1].clamp_(min=0, max=h) + tensor[:, 2].clamp_(min=0, max=w) + tensor[:, 3].clamp_(min=0, max=h) diff --git a/lxmert/src/tasks/__init__.py b/lxmert/src/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lxmert/src/tasks/gqa.py b/lxmert/src/tasks/gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd9959e2699051f3c8119c6fae9be08ff6d6cb1 --- /dev/null +++ b/lxmert/src/tasks/gqa.py @@ -0,0 +1,210 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. + +import os +import collections + +import torch +from tqdm import tqdm +import torch.nn as nn +from torch.utils.data.dataloader import DataLoader + +from param import args +from pretrain.qa_answer_table import load_lxmert_qa +from tasks.gqa_model import GQAModel +from tasks.gqa_data import GQADataset, GQATorchDataset, GQAEvaluator + + +DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator') + + +def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple: + dset = GQADataset(splits) + tset = GQATorchDataset(dset) + evaluator = GQAEvaluator(dset) + data_loader = DataLoader( + tset, batch_size=bs, + shuffle=shuffle, num_workers=args.num_workers, + drop_last=drop_last, pin_memory=True + ) + + return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator) + + +class GQA: + def __init__(self): + self.train_tuple = get_tuple( + args.train, bs=args.batch_size, shuffle=True, drop_last=True + ) + if args.valid != "": + valid_bsize = 2048 if args.multiGPU else 512 + self.valid_tuple = get_tuple( + args.valid, bs=valid_bsize, + shuffle=False, drop_last=False + ) + else: + self.valid_tuple = None + + self.model = GQAModel(self.train_tuple.dataset.num_answers) + + # Load pre-trained weights + if args.load_lxmert is not None: + self.model.lxrt_encoder.load(args.load_lxmert) + if args.load_lxmert_qa is not None: + load_lxmert_qa(args.load_lxmert_qa, self.model, + label2ans=self.train_tuple.dataset.label2ans) + + # GPU options + self.model = self.model.cuda() + if args.multiGPU: + self.model.lxrt_encoder.multi_gpu() + + # Losses and optimizer + self.bce_loss = nn.BCEWithLogitsLoss() + self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1) + if 'bert' in args.optim: + batch_per_epoch = len(self.train_tuple.loader) + t_total = int(batch_per_epoch * args.epochs) + print("Total Iters: %d" % t_total) + from lxrt.optimization import BertAdam + self.optim = BertAdam(list(self.model.parameters()), + lr=args.lr, + warmup=0.1, + t_total=t_total) + else: + self.optim = args.optimizer(list(self.model.parameters()), args.lr) + + self.output = args.output + os.makedirs(self.output, exist_ok=True) + + def train(self, train_tuple, eval_tuple): + dset, loader, evaluator = train_tuple + iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x) + + best_valid = 0. 
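# --- Illustrative sketch added by the editor (not part of the original diff) ---
# Preprocess.__call__ in processing_image.py above returns scales_yx = raw_size /
# resized_size, so boxes predicted on the resized image can be mapped back to the
# original resolution exactly as _scale_box() does (toy numbers, hypothetical usage):
import torch

boxes = torch.tensor([[10.0, 20.0, 110.0, 220.0]])   # x0, y0, x1, y1 on the resized image
scales_yx = torch.tensor([[2.0, 1.5]])                # (scale_y, scale_x) for one image
boxes[:, 0::2] *= scales_yx[:, 1]                     # rescale x0, x1
boxes[:, 1::2] *= scales_yx[:, 0]                     # rescale y0, y1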
+ for epoch in range(args.epochs): + quesid2ans = {} + for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)): + + self.model.train() + self.optim.zero_grad() + + feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda() + logit = self.model(feats, boxes, sent) + assert logit.dim() == target.dim() == 2 + if args.mce_loss: + max_value, target = target.max(1) + loss = self.mce_loss(logit, target) * logit.size(1) + else: + loss = self.bce_loss(logit, target) + loss = loss * logit.size(1) + + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), 5.) + self.optim.step() + + score, label = logit.max(1) + for qid, l in zip(ques_id, label.cpu().numpy()): + ans = dset.label2ans[l] + quesid2ans[qid] = ans + + log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.) + + if self.valid_tuple is not None: # Do Validation + valid_score = self.evaluate(eval_tuple) + if valid_score > best_valid: + best_valid = valid_score + self.save("BEST") + + log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \ + "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.) + + print(log_str, end='') + + with open(self.output + "/log.log", 'a') as f: + f.write(log_str) + f.flush() + + self.save("LAST") + + def predict(self, eval_tuple: DataTuple, dump=None): + self.model.eval() + dset, loader, evaluator = eval_tuple + quesid2ans = {} + for i, datum_tuple in enumerate(loader): + ques_id, feats, boxes, sent = datum_tuple[:4] # avoid handling target + with torch.no_grad(): + feats, boxes = feats.cuda(), boxes.cuda() + logit = self.model(feats, boxes, sent) + score, label = logit.max(1) + for qid, l in zip(ques_id, label.cpu().numpy()): + ans = dset.label2ans[l] + quesid2ans[qid] = ans + if dump is not None: + evaluator.dump_result(quesid2ans, dump) + return quesid2ans + + def evaluate(self, eval_tuple: DataTuple, dump=None): + dset, loader, evaluator = eval_tuple + quesid2ans = self.predict(eval_tuple, dump) + return evaluator.evaluate(quesid2ans) + + @staticmethod + def oracle_score(data_tuple): + dset, loader, evaluator = data_tuple + quesid2ans = {} + for i, (ques_id, feats, boxes, sent, target) in enumerate(loader): + _, label = target.max(1) + for qid, l in zip(ques_id, label.cpu().numpy()): + ans = dset.label2ans[l] + quesid2ans[qid] = ans + return evaluator.evaluate(quesid2ans) + + def save(self, name): + torch.save(self.model.state_dict(), + os.path.join(self.output, "%s.pth" % name)) + + def load(self, path): + print("Load model from %s" % path) + state_dict = torch.load("%s.pth" % path) + for key in list(state_dict.keys()): + if '.module' in key: + state_dict[key.replace('.module', '')] = state_dict.pop(key) + self.model.load_state_dict(state_dict, strict=False) + + +if __name__ == "__main__": + # Build Class + gqa = GQA() + + # Load Model + if args.load is not None: + gqa.load(args.load) + + # Test or Train + if args.test is not None: + args.fast = args.tiny = False # Always loading all data in test + if 'submit' in args.test: + gqa.predict( + get_tuple(args.test, bs=args.batch_size, + shuffle=False, drop_last=False), + dump=os.path.join(args.output, 'submit_predict.json') + ) + if 'testdev' in args.test: + result = gqa.evaluate( + get_tuple('testdev', bs=args.batch_size, + shuffle=False, drop_last=False), + dump=os.path.join(args.output, 'testdev_predict.json') + ) + print(result) + else: + # print("Train Oracle: %0.2f" % (gqa.oracle_score(gqa.train_tuple) * 100)) + print('Splits in Train data:', 
gqa.train_tuple.dataset.splits) + if gqa.valid_tuple is not None: + print('Splits in Valid data:', gqa.valid_tuple.dataset.splits) + print("Valid Oracle: %0.2f" % (gqa.oracle_score(gqa.valid_tuple) * 100)) + else: + print("DO NOT USE VALIDATION") + gqa.train(gqa.train_tuple, gqa.valid_tuple) + + diff --git a/lxmert/src/tasks/gqa_data.py b/lxmert/src/tasks/gqa_data.py new file mode 100644 index 0000000000000000000000000000000000000000..bfcc6a449a66bff01ed84e5b0e833bf1d1cb8ddf --- /dev/null +++ b/lxmert/src/tasks/gqa_data.py @@ -0,0 +1,194 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. + +import json + +import numpy as np +import torch +from torch.utils.data import Dataset + +from param import args +from utils import load_obj_tsv + +# Load part of the dataset for fast checking. +# Notice that here is the number of images instead of the number of data, +# which means all related data to the images would be used. +TINY_IMG_NUM = 512 +FAST_IMG_NUM = 5000 + + +class GQADataset: + """ + A GQA data example in json file: + { + "img_id": "2375429", + "label": { + "pipe": 1.0 + }, + "question_id": "07333408", + "sent": "What is on the white wall?" + } + """ + def __init__(self, splits: str): + self.name = splits + self.splits = splits.split(',') + + # Loading datasets to data + self.data = [] + for split in self.splits: + self.data.extend(json.load(open("data/gqa/%s.json" % split))) + print("Load %d data from split(s) %s." % (len(self.data), self.name)) + + # List to dict (for evaluation and others) + self.id2datum = { + datum['question_id']: datum + for datum in self.data + } + + # Answers + self.ans2label = json.load(open("data/gqa/trainval_ans2label.json")) + self.label2ans = json.load(open("data/gqa/trainval_label2ans.json")) + assert len(self.ans2label) == len(self.label2ans) + for ans, label in self.ans2label.items(): + assert self.label2ans[label] == ans + + @property + def num_answers(self): + return len(self.ans2label) + + def __len__(self): + return len(self.data) + + +class GQABufferLoader(): + def __init__(self): + self.key2data = {} + + def load_data(self, name, number): + if name == 'testdev': + path = "data/vg_gqa_imgfeat/gqa_testdev_obj36.tsv" + else: + path = "data/vg_gqa_imgfeat/vg_gqa_obj36.tsv" + key = "%s_%d" % (path, number) + if key not in self.key2data: + self.key2data[key] = load_obj_tsv( + path, + topk=number + ) + return self.key2data[key] + + +gqa_buffer_loader = GQABufferLoader() + + +""" +Example in obj tsv: +FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf", + "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"] +""" +class GQATorchDataset(Dataset): + def __init__(self, dataset: GQADataset): + super().__init__() + self.raw_dataset = dataset + + if args.tiny: + topk = TINY_IMG_NUM + elif args.fast: + topk = FAST_IMG_NUM + else: + topk = -1 + + # Loading detection features to img_data + # Since images in train and valid both come from Visual Genome, + # buffer the image loading to save memory. 
+ img_data = [] + if 'testdev' in dataset.splits or 'testdev_all' in dataset.splits: # Always loading all the data in testdev + img_data.extend(gqa_buffer_loader.load_data('testdev', -1)) + else: + img_data.extend(gqa_buffer_loader.load_data('train', topk)) + self.imgid2img = {} + for img_datum in img_data: + self.imgid2img[img_datum['img_id']] = img_datum + + # Only kept the data with loaded image features + self.data = [] + for datum in self.raw_dataset.data: + if datum['img_id'] in self.imgid2img: + self.data.append(datum) + print("Use %d data in torch dataset" % (len(self.data))) + print() + + def __len__(self): + return len(self.data) + + def __getitem__(self, item: int): + datum = self.data[item] + + img_id = datum['img_id'] + ques_id = datum['question_id'] + ques = datum['sent'] + + # Get image info + img_info = self.imgid2img[img_id] + obj_num = img_info['num_boxes'] + boxes = img_info['boxes'].copy() + feats = img_info['features'].copy() + assert len(boxes) == len(feats) == obj_num + + # Normalize the boxes (to 0 ~ 1) + img_h, img_w = img_info['img_h'], img_info['img_w'] + boxes = boxes.copy() + boxes[:, (0, 2)] /= img_w + boxes[:, (1, 3)] /= img_h + np.testing.assert_array_less(boxes, 1+1e-5) + np.testing.assert_array_less(-boxes, 0+1e-5) + + # Create target + if 'label' in datum: + label = datum['label'] + target = torch.zeros(self.raw_dataset.num_answers) + for ans, score in label.items(): + if ans in self.raw_dataset.ans2label: + target[self.raw_dataset.ans2label[ans]] = score + return ques_id, feats, boxes, ques, target + else: + return ques_id, feats, boxes, ques + + +class GQAEvaluator: + def __init__(self, dataset: GQADataset): + self.dataset = dataset + + def evaluate(self, quesid2ans: dict): + score = 0. + for quesid, ans in quesid2ans.items(): + datum = self.dataset.id2datum[quesid] + label = datum['label'] + if ans in label: + score += label[ans] + return score / len(quesid2ans) + + def dump_result(self, quesid2ans: dict, path): + """ + Dump the result to a GQA-challenge submittable json file. + GQA json file submission requirement: + results = [result] + result = { + "questionId": str, # Note: it's a actually an int number but the server requires an str. + "prediction": str + } + + :param quesid2ans: A dict mapping question id to its predicted answer. + :param path: The file path to save the json file. + :return: + """ + with open(path, 'w') as f: + result = [] + for ques_id, ans in quesid2ans.items(): + result.append({ + 'questionId': ques_id, + 'prediction': ans + }) + json.dump(result, f, indent=4, sort_keys=True) + + diff --git a/lxmert/src/tasks/gqa_model.py b/lxmert/src/tasks/gqa_model.py new file mode 100644 index 0000000000000000000000000000000000000000..03726afa0deaf28b1670943b30190f14aac7ec17 --- /dev/null +++ b/lxmert/src/tasks/gqa_model.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. 
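# --- Illustrative sketch added by the editor (not part of the original diff) ---
# The obj36 TSV format documented above stores each per-image array base64-encoded
# in a single cell; load_obj_tsv() in src/utils.py decodes it with np.frombuffer.
# A toy two-box round-trip:
import base64
import numpy as np

boxes = np.arange(8, dtype=np.float32).reshape(2, 4)          # 2 boxes, (x0, y0, x1, y1)
cell = base64.b64encode(boxes.tobytes())                      # what would sit in the TSV cell
decoded = np.frombuffer(base64.b64decode(cell), dtype=np.float32).reshape(2, 4)
assert np.array_equal(boxes, decoded)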
+ +import torch.nn as nn + +from param import args +from lxrt.entry import LXRTEncoder +from lxrt.modeling import BertLayerNorm, GeLU + +# Max length including and +MAX_GQA_LENGTH = 20 + + +class GQAModel(nn.Module): + def __init__(self, num_answers): + super().__init__() + self.lxrt_encoder = LXRTEncoder( + args, + max_seq_length=MAX_GQA_LENGTH + ) + hid_dim = self.lxrt_encoder.dim + self.logit_fc = nn.Sequential( + nn.Linear(hid_dim, hid_dim * 2), + GeLU(), + BertLayerNorm(hid_dim * 2, eps=1e-12), + nn.Linear(hid_dim * 2, num_answers) + ) + self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights) + + def forward(self, feat, pos, sent): + """ + b -- batch_size, o -- object_number, f -- visual_feature_size + + :param feat: (b, o, f) + :param pos: (b, o, 4) + :param sent: (b,) Type -- list of string + :param leng: (b,) Type -- int numpy array + :return: (b, num_answer) The logit of each answers. + """ + x = self.lxrt_encoder(sent, (feat, pos)) + logit = self.logit_fc(x) + + return logit + + diff --git a/lxmert/src/tasks/nlvr2.py b/lxmert/src/tasks/nlvr2.py new file mode 100644 index 0000000000000000000000000000000000000000..74d654d286bce86820e651c489963dac1f23b10d --- /dev/null +++ b/lxmert/src/tasks/nlvr2.py @@ -0,0 +1,182 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. + +import os +import collections + +from tqdm import tqdm +import torch +import torch.nn as nn +from torch.utils.data.dataloader import DataLoader + +from param import args +from tasks.nlvr2_model import NLVR2Model +from tasks.nlvr2_data import NLVR2Dataset, NLVR2TorchDataset, NLVR2Evaluator + +DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator') + + +def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple: + dset = NLVR2Dataset(splits) + tset = NLVR2TorchDataset(dset) + evaluator = NLVR2Evaluator(dset) + data_loader = DataLoader( + tset, batch_size=bs, + shuffle=shuffle, num_workers=args.num_workers, + drop_last=drop_last, pin_memory=True + ) + + return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator) + + +class NLVR2: + def __init__(self): + self.train_tuple = get_tuple( + args.train, bs=args.batch_size, shuffle=True, drop_last=True + ) + if args.valid != "": + valid_bsize = 2048 if args.multiGPU else 512 + self.valid_tuple = get_tuple( + args.valid, bs=valid_bsize, + shuffle=False, drop_last=False + ) + else: + self.valid_tuple = None + + self.model = NLVR2Model() + + # Load pre-trained weights + if args.load_lxmert is not None: + self.model.lxrt_encoder.load(args.load_lxmert) + + # GPU options + if args.multiGPU: + self.model.lxrt_encoder.multi_gpu() + self.model = self.model.cuda() + + # Losses and optimizer + self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1) + if 'bert' in args.optim: + batch_per_epoch = len(self.train_tuple.loader) + t_total = int(batch_per_epoch * args.epochs) + print("Total Iters: %d" % t_total) + from lxrt.optimization import BertAdam + self.optim = BertAdam(list(self.model.parameters()), + lr=args.lr, + warmup=0.1, + t_total=t_total) + else: + self.optim = args.optimizer(list(self.model.parameters()), args.lr) + + self.output = args.output + os.makedirs(self.output, exist_ok=True) + + def train(self, train_tuple, eval_tuple): + dset, loader, evaluator = train_tuple + iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x) + + best_valid = 0. 
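# --- Illustrative sketch added by the editor (not part of the original diff) ---
# GQAModel above classifies the pooled cross-modal feature with a two-layer MLP
# (Linear -> GeLU -> LayerNorm -> Linear). An equivalent head in plain PyTorch,
# with nn.GELU / nn.LayerNorm standing in for the repo's GeLU / BertLayerNorm and
# hypothetical sizes:
import torch
import torch.nn as nn

hid_dim, num_answers = 768, 1842            # hypothetical; num_answers comes from the dataset
logit_fc = nn.Sequential(
    nn.Linear(hid_dim, hid_dim * 2),
    nn.GELU(),
    nn.LayerNorm(hid_dim * 2, eps=1e-12),
    nn.Linear(hid_dim * 2, num_answers),
)
logits = logit_fc(torch.randn(4, hid_dim))  # (batch, num_answers)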
+ for epoch in range(args.epochs): + quesid2ans = {} + for i, (ques_id, feats, boxes, sent, label) in iter_wrapper(enumerate(loader)): + self.model.train() + + self.optim.zero_grad() + feats, boxes, label = feats.cuda(), boxes.cuda(), label.cuda() + logit = self.model(feats, boxes, sent) + + loss = self.mce_loss(logit, label) + + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), 5.) + self.optim.step() + + score, predict = logit.max(1) + for qid, l in zip(ques_id, predict.cpu().numpy()): + quesid2ans[qid] = l + + log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.) + + if self.valid_tuple is not None: # Do Validation + valid_score = self.evaluate(eval_tuple) + if valid_score > best_valid: + best_valid = valid_score + self.save("BEST") + + log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \ + "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.) + + print(log_str, end='') + + with open(self.output + "/log.log", 'a') as f: + f.write(log_str) + f.flush() + + self.save("LAST") + + def predict(self, eval_tuple: DataTuple, dump=None): + self.model.eval() + dset, loader, evaluator = eval_tuple + quesid2ans = {} + for i, datum_tuple in enumerate(loader): + ques_id, feats, boxes, sent = datum_tuple[:4] # avoid handling target + with torch.no_grad(): + feats, boxes = feats.cuda(), boxes.cuda() + logit = self.model(feats, boxes, sent) + score, predict = logit.max(1) + for qid, l in zip(ques_id, predict.cpu().numpy()): + quesid2ans[qid] = l + if dump is not None: + evaluator.dump_result(quesid2ans, dump) + return quesid2ans + + def evaluate(self, eval_tuple: DataTuple, dump=None): + dset, loader, evaluator = eval_tuple + quesid2ans = self.predict(eval_tuple, dump) + return evaluator.evaluate(quesid2ans) + + def save(self, name): + torch.save(self.model.state_dict(), + os.path.join(self.output, "%s.pth" % name)) + + def load(self, path): + print("Load model from %s" % path) + state_dict = torch.load("%s.pth" % path) + self.model.load_state_dict(state_dict) + + +if __name__ == "__main__": + # Build Class + nlvr2 = NLVR2() + + # Load Model + if args.load is not None: + nlvr2.load(args.load) + + # Test or Train + if args.test is not None: + args.fast = args.tiny = False # Always loading all data in test + if 'hidden' in args.test: + nlvr2.predict( + get_tuple(args.test, bs=args.batch_size, + shuffle=False, drop_last=False), + dump=os.path.join(args.output, 'hidden_predict.csv') + ) + elif 'test' in args.test or 'valid' in args.test: + result = nlvr2.evaluate( + get_tuple(args.test, bs=args.batch_size, + shuffle=False, drop_last=False), + dump=os.path.join(args.output, '%s_predict.csv' % args.test) + ) + print(result) + else: + assert False, "No such test option for %s" % args.test + else: + print('Splits in Train data:', nlvr2.train_tuple.dataset.splits) + if nlvr2.valid_tuple is not None: + print('Splits in Valid data:', nlvr2.valid_tuple.dataset.splits) + else: + print("DO NOT USE VALIDATION") + nlvr2.train(nlvr2.train_tuple, nlvr2.valid_tuple) + + diff --git a/lxmert/src/tasks/nlvr2_data.py b/lxmert/src/tasks/nlvr2_data.py new file mode 100644 index 0000000000000000000000000000000000000000..bca3e604a051b04a090a23a151a2a131b190bab7 --- /dev/null +++ b/lxmert/src/tasks/nlvr2_data.py @@ -0,0 +1,157 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. + +import json + +import numpy as np +from torch.utils.data import Dataset + +from param import args +from utils import load_obj_tsv + +# Load part of the dataset for fast checking. 
+# Notice that here is the number of images instead of the number of data, +# which means all related data to the images would be used. +TINY_IMG_NUM = 512 +FAST_IMG_NUM = 5000 + + +class NLVR2Dataset: + """ + An NLVR2 data example in json file: + { + "identifier": "train-10171-0-0", + "img0": "train-10171-0-img0", + "img1": "train-10171-0-img1", + "label": 0, + "sent": "An image shows one leather pencil case, displayed open with writing implements tucked inside. + ", + "uid": "nlvr2_train_0" + } + """ + def __init__(self, splits: str): + self.name = splits + self.splits = splits.split(',') + + # Loading datasets to data + self.data = [] + for split in self.splits: + self.data.extend(json.load(open("data/nlvr2/%s.json" % split))) + print("Load %d data from split(s) %s." % (len(self.data), self.name)) + + # List to dict (for evaluation and others) + self.id2datum = { + datum['uid']: datum + for datum in self.data + } + + def __len__(self): + return len(self.data) + + +""" +An example in obj36 tsv: +FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf", + "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"] +FIELDNAMES would be keys in the dict returned by load_obj_tsv. +""" +class NLVR2TorchDataset(Dataset): + def __init__(self, dataset: NLVR2Dataset): + super().__init__() + self.raw_dataset = dataset + + if args.tiny: + topk = TINY_IMG_NUM + elif args.fast: + topk = FAST_IMG_NUM + else: + topk = -1 + + # Loading detection features to img_data + img_data = [] + if 'train' in dataset.splits: + img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/train_obj36.tsv', topk=topk)) + if 'valid' in dataset.splits: + img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/valid_obj36.tsv', topk=topk)) + if 'test' in dataset.name: + img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/test_obj36.tsv', topk=topk)) + self.imgid2img = {} + for img_datum in img_data: + self.imgid2img[img_datum['img_id']] = img_datum + + # Filter out the dataset + self.data = [] + for datum in self.raw_dataset.data: + if datum['img0'] in self.imgid2img and datum['img1'] in self.imgid2img: + self.data.append(datum) + print("Use %d data in torch dataset" % (len(self.data))) + print() + + def __len__(self): + return len(self.data) + + def __getitem__(self, item: int): + datum = self.data[item] + + ques_id = datum['uid'] + ques = datum['sent'] + + # Get image info + boxes2 = [] + feats2 = [] + for key in ['img0', 'img1']: + img_id = datum[key] + img_info = self.imgid2img[img_id] + boxes = img_info['boxes'].copy() + feats = img_info['features'].copy() + assert len(boxes) == len(feats) + + # Normalize the boxes (to 0 ~ 1) + img_h, img_w = img_info['img_h'], img_info['img_w'] + boxes[..., (0, 2)] /= img_w + boxes[..., (1, 3)] /= img_h + np.testing.assert_array_less(boxes, 1+1e-5) + np.testing.assert_array_less(-boxes, 0+1e-5) + + boxes2.append(boxes) + feats2.append(feats) + feats = np.stack(feats2) + boxes = np.stack(boxes2) + + # Create target + if 'label' in datum: + label = datum['label'] + return ques_id, feats, boxes, ques, label + else: + return ques_id, feats, boxes, ques + + +class NLVR2Evaluator: + def __init__(self, dataset: NLVR2Dataset): + self.dataset = dataset + + def evaluate(self, quesid2ans: dict): + score = 0. 
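# --- Illustrative sketch added by the editor (not part of the original diff) ---
# NLVR2TorchDataset.__getitem__ above stacks the two images of an example, so a
# single item carries feats of shape (2, 36, 2048) and boxes of shape (2, 36, 4);
# after batching, NLVR2Model flattens (b, 2, 36, 2048) into (b * 2, 36, 2048).
import numpy as np

feats_img0 = np.zeros((36, 2048), dtype=np.float32)   # hypothetical features for img0
feats_img1 = np.zeros((36, 2048), dtype=np.float32)   # hypothetical features for img1
feats = np.stack([feats_img0, feats_img1])
assert feats.shape == (2, 36, 2048)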
+ for quesid, ans in quesid2ans.items(): + datum = self.dataset.id2datum[quesid] + label = datum['label'] + if ans == label: + score += 1 + return score / len(quesid2ans) + + def dump_result(self, quesid2ans: dict, path): + """ + Dump result to a CSV file, which is compatible with NLVR2 evaluation system. + NLVR2 CSV file requirement: + Each line contains: identifier, answer + + :param quesid2ans: nlvr2 uid to ans (either "True" or "False") + :param path: The desired path of saved file. + :return: + """ + with open(path, 'w') as f: + for uid, ans in quesid2ans.items(): + idt = self.dataset.id2datum[uid]["identifier"] + ans = 'True' if ans == 1 else 'False' + f.write("%s,%s\n" % (idt, ans)) + diff --git a/lxmert/src/tasks/nlvr2_model.py b/lxmert/src/tasks/nlvr2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ef93474403461f18461d1da85fb8877b6f6b5364 --- /dev/null +++ b/lxmert/src/tasks/nlvr2_model.py @@ -0,0 +1,55 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. + +import torch.nn as nn +from lxrt.modeling import GeLU, BertLayerNorm +from lxrt.entry import LXRTEncoder +from param import args + + +class NLVR2Model(nn.Module): + def __init__(self): + super().__init__() + self.lxrt_encoder = LXRTEncoder( + args, + max_seq_length=20 + ) + self.hid_dim = hid_dim = self.lxrt_encoder.dim + self.logit_fc = nn.Sequential( + nn.Linear(hid_dim * 2, hid_dim * 2), + GeLU(), + BertLayerNorm(hid_dim * 2, eps=1e-12), + nn.Linear(hid_dim * 2, 2) + ) + self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights) + + def forward(self, feat, pos, sent): + """ + :param feat: b, 2, o, f + :param pos: b, 2, o, 4 + :param sent: b, (string) + :param leng: b, (numpy, int) + :return: + """ + # Pairing images and sentences: + # The input of NLVR2 is two images and one sentence. In batch level, they are saved as + # [ [img0_0, img0_1], [img1_0, img1_1], ...] and [sent0, sent1, ...] + # Here, we flat them to + # feat/pos = [ img0_0, img0_1, img1_0, img1_1, ...] + # sent = [ sent0, sent0, sent1, sent1, ...] + sent = sum(zip(sent, sent), ()) + batch_size, img_num, obj_num, feat_size = feat.size() + assert img_num == 2 and obj_num == 36 and feat_size == 2048 + feat = feat.view(batch_size * 2, obj_num, feat_size) + pos = pos.view(batch_size * 2, obj_num, 4) + + # Extract feature --> Concat + x = self.lxrt_encoder(sent, (feat, pos)) + x = x.view(-1, self.hid_dim*2) + + # Compute logit of answers + logit = self.logit_fc(x) + + return logit + + diff --git a/lxmert/src/tasks/vqa.py b/lxmert/src/tasks/vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..bd69fe6984383aa285a3bce3e32f3329da6d00e7 --- /dev/null +++ b/lxmert/src/tasks/vqa.py @@ -0,0 +1,214 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. 
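# --- Illustrative sketch added by the editor (not part of the original diff) ---
# NLVR2Model.forward above repeats each sentence once per image so the text stream
# lines up with the flattened (b * 2) image stream:
sent = ["there are two dogs", "the cup is empty"]      # hypothetical batch of statements
paired = sum(zip(sent, sent), ())                       # flatten [(s0, s0), (s1, s1)] into a tuple
assert list(paired) == ["there are two dogs", "there are two dogs",
                        "the cup is empty", "the cup is empty"]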
+ +import os +import collections + +import torch +import torch.nn as nn +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from ..param import args +from ..pretrain.qa_answer_table import load_lxmert_qa +from .vqa_model import VQAModel +from .vqa_data import VQADataset, VQATorchDataset, VQAEvaluator + +DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator') + + +def get_data_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple: + dset = VQADataset(splits) + tset = VQATorchDataset(dset) + evaluator = VQAEvaluator(dset) + data_loader = DataLoader( + tset, batch_size=bs, + shuffle=shuffle, num_workers=args.num_workers, + drop_last=drop_last, pin_memory=True + ) + + return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator) + + +class VQA: + def __init__(self): + # Datasets + self.train_tuple = get_data_tuple( + args.train, bs=args.batch_size, shuffle=True, drop_last=True + ) + if args.valid != "": + self.valid_tuple = get_data_tuple( + args.valid, bs=1024, + shuffle=False, drop_last=False + ) + else: + self.valid_tuple = None + + # Model + self.model = VQAModel(self.train_tuple.dataset.num_answers) + + # Load pre-trained weights + if args.load_lxmert is not None: + self.model.lxrt_encoder.load(args.load_lxmert) + if args.load_lxmert_qa is not None: + load_lxmert_qa(args.load_lxmert_qa, self.model, + label2ans=self.train_tuple.dataset.label2ans) + + # GPU options + self.model = self.model.cuda() + if args.multiGPU: + self.model.lxrt_encoder.multi_gpu() + + # Loss and Optimizer + self.bce_loss = nn.BCEWithLogitsLoss() + if 'bert' in args.optim: + batch_per_epoch = len(self.train_tuple.loader) + t_total = int(batch_per_epoch * args.epochs) + print("BertAdam Total Iters: %d" % t_total) + from ..lxrt.optimization import BertAdam + self.optim = BertAdam(list(self.model.parameters()), + lr=args.lr, + warmup=0.1, + t_total=t_total) + else: + self.optim = args.optimizer(self.model.parameters(), args.lr) + + # Output Directory + self.output = args.output + os.makedirs(self.output, exist_ok=True) + + def train(self, train_tuple, eval_tuple): + dset, loader, evaluator = train_tuple + iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x) + + best_valid = 0. + for epoch in range(args.epochs): + quesid2ans = {} + for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)): + + self.model.train() + self.optim.zero_grad() + + feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda() + logit = self.model(feats, boxes, sent) + assert logit.dim() == target.dim() == 2 + loss = self.bce_loss(logit, target) + loss = loss * logit.size(1) + + loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), 5.) + self.optim.step() + + score, label = logit.max(1) + for qid, l in zip(ques_id, label.cpu().numpy()): + ans = dset.label2ans[l] + quesid2ans[qid.item()] = ans + + log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.) + + if self.valid_tuple is not None: # Do Validation + valid_score = self.evaluate(eval_tuple) + if valid_score > best_valid: + best_valid = valid_score + self.save("BEST") + + log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \ + "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.) 
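# --- Illustrative sketch added by the editor (not part of the original diff) ---
# The loss a few lines above is multiplied by logit.size(1) because
# nn.BCEWithLogitsLoss averages over every element of the (batch, num_answers)
# matrix; rescaling by num_answers turns it into a batch-mean of the per-example
# sum over answer classes. A toy check with hypothetical shapes:
import torch
import torch.nn as nn

logit = torch.randn(4, 10)                              # hypothetical (batch, num_answers)
target = torch.zeros(4, 10)
target[torch.arange(4), torch.randint(0, 10, (4,))] = 1.0
scaled = nn.BCEWithLogitsLoss()(logit, target) * logit.size(1)
per_example_sum = nn.BCEWithLogitsLoss(reduction="none")(logit, target).sum(dim=1).mean()
assert torch.allclose(scaled, per_example_sum)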
+ + print(log_str, end='') + + with open(self.output + "/log.log", 'a') as f: + f.write(log_str) + f.flush() + + self.save("LAST") + + def predict(self, eval_tuple: DataTuple, dump=None): + """ + Predict the answers to questions in a data split. + + :param eval_tuple: The data tuple to be evaluated. + :param dump: The path of saved file to dump results. + :return: A dict of question_id to answer. + """ + self.model.eval() + dset, loader, evaluator = eval_tuple + quesid2ans = {} + for i, datum_tuple in enumerate(loader): + ques_id, feats, boxes, sent = datum_tuple[:4] # Avoid seeing ground truth + with torch.no_grad(): + feats, boxes = feats.cuda(), boxes.cuda() + logit = self.model(feats, boxes, sent) + score, label = logit.max(1) + for qid, l in zip(ques_id, label.cpu().numpy()): + ans = dset.label2ans[l] + quesid2ans[qid.item()] = ans + if dump is not None: + evaluator.dump_result(quesid2ans, dump) + return quesid2ans + + def evaluate(self, eval_tuple: DataTuple, dump=None): + """Evaluate all data in data_tuple.""" + quesid2ans = self.predict(eval_tuple, dump) + return eval_tuple.evaluator.evaluate(quesid2ans) + + @staticmethod + def oracle_score(data_tuple): + dset, loader, evaluator = data_tuple + quesid2ans = {} + for i, (ques_id, feats, boxes, sent, target) in enumerate(loader): + _, label = target.max(1) + for qid, l in zip(ques_id, label.cpu().numpy()): + ans = dset.label2ans[l] + quesid2ans[qid.item()] = ans + return evaluator.evaluate(quesid2ans) + + def save(self, name): + torch.save(self.model.state_dict(), + os.path.join(self.output, "%s.pth" % name)) + + def load(self, path): + print("Load model from %s" % path) + state_dict = torch.load("%s.pth" % path) + self.model.load_state_dict(state_dict) + + +if __name__ == "__main__": + # Build Class + vqa = VQA() + + # Load VQA model weights + # Note: It is different from loading lxmert pre-trained weights. + if args.load is not None: + vqa.load(args.load) + + # Test or Train + if args.test is not None: + args.fast = args.tiny = False # Always loading all data in test + if 'test' in args.test: + vqa.predict( + get_data_tuple(args.test, bs=950, + shuffle=False, drop_last=False), + dump=os.path.join(args.output, 'test_predict.json') + ) + elif 'val' in args.test: + # Since part of valididation data are used in pre-training/fine-tuning, + # only validate on the minival set. + result = vqa.evaluate( + get_data_tuple('minival', bs=950, + shuffle=False, drop_last=False), + dump=os.path.join(args.output, 'minival_predict.json') + ) + print(result) + else: + assert False, "No such test option for %s" % args.test + else: + print('Splits in Train data:', vqa.train_tuple.dataset.splits) + if vqa.valid_tuple is not None: + print('Splits in Valid data:', vqa.valid_tuple.dataset.splits) + print("Valid Oracle: %0.2f" % (vqa.oracle_score(vqa.valid_tuple) * 100)) + else: + print("DO NOT USE VALIDATION") + vqa.train(vqa.train_tuple, vqa.valid_tuple) + + diff --git a/lxmert/src/tasks/vqa_data.py b/lxmert/src/tasks/vqa_data.py new file mode 100644 index 0000000000000000000000000000000000000000..831c095c3c9bb2cffde2c4230fc278aba7ed0fac --- /dev/null +++ b/lxmert/src/tasks/vqa_data.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. + +import json +import os +import pickle + +import numpy as np +import torch +from torch.utils.data import Dataset + +from ..param import args +from ..utils import load_obj_tsv + +# Load part of the dataset for fast checking. 
+# Notice that here is the number of images instead of the number of data, +# which means all related data to the images would be used. +TINY_IMG_NUM = 512 +FAST_IMG_NUM = 5000 + +# The path to data and image features. +VQA_DATA_ROOT = 'data/vqa/' +MSCOCO_IMGFEAT_ROOT = 'data/mscoco_imgfeat/' +SPLIT2NAME = { + 'train': 'train2014', + 'valid': 'val2014', + 'minival': 'val2014', + 'nominival': 'val2014', + 'test': 'test2015', +} + + +class VQADataset: + """ + A VQA data example in json file: + { + "answer_type": "other", + "img_id": "COCO_train2014_000000458752", + "label": { + "net": 1 + }, + "question_id": 458752000, + "question_type": "what is this", + "sent": "What is this photo taken looking through?" + } + """ + def __init__(self, splits: str): + self.name = splits + self.splits = splits.split(',') + + # Loading datasets + self.data = [] + for split in self.splits: + self.data.extend(json.load(open("data/vqa/%s.json" % split))) + print("Load %d data from split(s) %s." % (len(self.data), self.name)) + + # Convert list to dict (for evaluation) + self.id2datum = { + datum['question_id']: datum + for datum in self.data + } + + # Answers + self.ans2label = json.load(open("data/vqa/trainval_ans2label.json")) + self.label2ans = json.load(open("data/vqa/trainval_label2ans.json")) + assert len(self.ans2label) == len(self.label2ans) + + @property + def num_answers(self): + return len(self.ans2label) + + def __len__(self): + return len(self.data) + + +""" +An example in obj36 tsv: +FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf", + "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"] +FIELDNAMES would be keys in the dict returned by load_obj_tsv. +""" +class VQATorchDataset(Dataset): + def __init__(self, dataset: VQADataset): + super().__init__() + self.raw_dataset = dataset + + if args.tiny: + topk = TINY_IMG_NUM + elif args.fast: + topk = FAST_IMG_NUM + else: + topk = None + + # Loading detection features to img_data + img_data = [] + for split in dataset.splits: + # Minival is 5K images in MS COCO, which is used in evaluating VQA/lxmert-pre-training. 
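# --- Illustrative sketch added by the editor (not part of the original diff) ---
# __getitem__ below normalizes box coordinates to [0, 1] by the image size before
# they are used as positional features; with toy numbers:
import numpy as np

boxes = np.array([[64.0, 32.0, 512.0, 256.0]], dtype=np.float32)   # x0, y0, x1, y1 in pixels
img_w, img_h = 640, 480
boxes[:, (0, 2)] /= img_w
boxes[:, (1, 3)] /= img_h
np.testing.assert_array_less(boxes, 1 + 1e-5)     # same sanity checks as the dataset
np.testing.assert_array_less(-boxes, 0 + 1e-5)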
+ # It is saved as the top 5K features in val2014_***.tsv + load_topk = 5000 if (split == 'minival' and topk is None) else topk + img_data.extend(load_obj_tsv( + os.path.join(MSCOCO_IMGFEAT_ROOT, '%s_obj36.tsv' % (SPLIT2NAME[split])), + topk=load_topk)) + + # Convert img list to dict + self.imgid2img = {} + for img_datum in img_data: + self.imgid2img[img_datum['img_id']] = img_datum + + # Only kept the data with loaded image features + self.data = [] + for datum in self.raw_dataset.data: + if datum['img_id'] in self.imgid2img: + self.data.append(datum) + print("Use %d data in torch dataset" % (len(self.data))) + print() + + def __len__(self): + return len(self.data) + + def __getitem__(self, item: int): + datum = self.data[item] + + img_id = datum['img_id'] + ques_id = datum['question_id'] + ques = datum['sent'] + + # Get image info + img_info = self.imgid2img[img_id] + obj_num = img_info['num_boxes'] + feats = img_info['features'].copy() + boxes = img_info['boxes'].copy() + assert obj_num == len(boxes) == len(feats) + + # Normalize the boxes (to 0 ~ 1) + img_h, img_w = img_info['img_h'], img_info['img_w'] + boxes = boxes.copy() + boxes[:, (0, 2)] /= img_w + boxes[:, (1, 3)] /= img_h + np.testing.assert_array_less(boxes, 1+1e-5) + np.testing.assert_array_less(-boxes, 0+1e-5) + + # Provide label (target) + if 'label' in datum: + label = datum['label'] + target = torch.zeros(self.raw_dataset.num_answers) + for ans, score in label.items(): + target[self.raw_dataset.ans2label[ans]] = score + return ques_id, feats, boxes, ques, target + else: + return ques_id, feats, boxes, ques + + +class VQAEvaluator: + def __init__(self, dataset: VQADataset): + self.dataset = dataset + + def evaluate(self, quesid2ans: dict): + score = 0. + for quesid, ans in quesid2ans.items(): + datum = self.dataset.id2datum[quesid] + label = datum['label'] + if ans in label: + score += label[ans] + return score / len(quesid2ans) + + def dump_result(self, quesid2ans: dict, path): + """ + Dump results to a json file, which could be submitted to the VQA online evaluation. + VQA json file submission requirement: + results = [result] + result = { + "question_id": int, + "answer": str + } + + :param quesid2ans: dict of quesid --> ans + :param path: The desired path of saved file. + """ + with open(path, 'w') as f: + result = [] + for ques_id, ans in quesid2ans.items(): + result.append({ + 'question_id': ques_id, + 'answer': ans + }) + json.dump(result, f, indent=4, sort_keys=True) + + diff --git a/lxmert/src/tasks/vqa_model.py b/lxmert/src/tasks/vqa_model.py new file mode 100644 index 0000000000000000000000000000000000000000..792e07344b9c2371edb51db86ef08b1e57711497 --- /dev/null +++ b/lxmert/src/tasks/vqa_model.py @@ -0,0 +1,50 @@ +# coding=utf-8 +# Copyleft 2019 project LXRT. 
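# --- Illustrative sketch added by the editor (not part of the original diff) ---
# VQAEvaluator.evaluate above (like its GQA counterpart) credits each prediction with
# the soft score attached to that answer in the datum's "label" dict and averages over
# questions. With toy data:
quesid2ans = {458752000: "net", 458752001: "racket"}
id2datum = {
    458752000: {"label": {"net": 1.0}},
    458752001: {"label": {"ball": 0.3}},     # "racket" absent -> contributes 0
}
score = sum(id2datum[q]["label"].get(a, 0.0) for q, a in quesid2ans.items())
accuracy = score / len(quesid2ans)            # 0.5 for this toy pair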
+ +import torch.nn as nn + +from ..param import args +from ..lxrt.entry import LXRTEncoder +from ..lxrt.modeling import BertLayerNorm, GeLU +from transformers import AutoTokenizer, AutoModelForQuestionAnswering + +# Max length including and +MAX_VQA_LENGTH = 20 + + +class VQAModel(nn.Module): + def __init__(self, num_answers): + super().__init__() + + # # Build LXRT encoder + # self.lxrt_encoder = LXRTEncoder( + # args, + # max_seq_length=MAX_VQA_LENGTH + # ) + # hid_dim = self.lxrt_encoder.dim + # + # # VQA Answer heads + # self.logit_fc = nn.Sequential( + # nn.Linear(hid_dim, hid_dim * 2), + # GeLU(), + # BertLayerNorm(hid_dim * 2, eps=1e-12), + # nn.Linear(hid_dim * 2, num_answers) + # ) + # self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights) + + self.tokenizer = AutoTokenizer.from_pretrained("unc-nlp/lxmert-vqa-uncased") + self.model = AutoModelForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased") + + def forward(self, feat, pos, sent): + """ + b -- batch_size, o -- object_number, f -- visual_feature_size + + :param feat: (b, o, f) + :param pos: (b, o, 4) + :param sent: (b,) Type -- list of string + :param leng: (b,) Type -- int numpy array + :return: (b, num_answer) The logit of each answers. + """ + return self.model(sent, feat, pos) + + diff --git a/lxmert/src/utils.py b/lxmert/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2346dddd78bcb67e5cd06ea8c5ca128eaf9b9faf --- /dev/null +++ b/lxmert/src/utils.py @@ -0,0 +1,55 @@ +# coding=utf-8 +# Copyleft 2019 Project LXRT + +import sys +import csv +import base64 +import time + +import numpy as np + +csv.field_size_limit(sys.maxsize) +FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf", + "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"] + + +def load_obj_tsv(fname, topk=None): + """Load object features from tsv file. + + :param fname: The path to the tsv file. + :param topk: Only load features for top K images (lines) in the tsv file. + Will load all the features if topk is either -1 or None. + :return: A list of image object features where each feature is a dict. + See FILENAMES above for the keys in the feature dict. + """ + data = [] + start_time = time.time() + print("Start to load Faster-RCNN detected objects from %s" % fname) + with open(fname) as f: + reader = csv.DictReader(f, FIELDNAMES, delimiter="\t") + for i, item in enumerate(reader): + + for key in ['img_h', 'img_w', 'num_boxes']: + item[key] = int(item[key]) + + boxes = item['num_boxes'] + decode_config = [ + ('objects_id', (boxes, ), np.int64), + ('objects_conf', (boxes, ), np.float32), + ('attrs_id', (boxes, ), np.int64), + ('attrs_conf', (boxes, ), np.float32), + ('boxes', (boxes, 4), np.float32), + ('features', (boxes, -1), np.float32), + ] + for key, shape, dtype in decode_config: + item[key] = np.frombuffer(base64.b64decode(item[key]), dtype=dtype) + item[key] = item[key].reshape(shape) + item[key].setflags(write=False) + + data.append(item) + if topk is not None and len(data) == topk: + break + elapsed_time = time.time() - start_time + print("Loaded %d images in file %s in %d seconds." 
% (len(data), fname, elapsed_time)) + return data + diff --git a/lxmert/src/visualizing_image.py b/lxmert/src/visualizing_image.py new file mode 100644 index 0000000000000000000000000000000000000000..b7508249c2db032f25a64cb811cd9d8696275570 --- /dev/null +++ b/lxmert/src/visualizing_image.py @@ -0,0 +1,499 @@ +""" + coding=utf-8 + Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal + Adapted From Facebook Inc, Detectron2 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.import copy + """ +import colorsys +import io + +import matplotlib as mpl +import matplotlib.colors as mplc +import matplotlib.figure as mplfigure +import numpy as np +import torch +from matplotlib.backends.backend_agg import FigureCanvasAgg + +import cv2 +from src.utils import img_tensorize + + +_SMALL_OBJ = 1000 + + +class SingleImageViz: + def __init__( + self, + img, + scale=1.2, + edgecolor="g", + alpha=0.5, + linestyle="-", + saveas="test_out.jpg", + rgb=True, + pynb=False, + id2obj=None, + id2attr=None, + pad=0.7, + ): + """ + img: an RGB image of shape (H, W, 3). + """ + if isinstance(img, torch.Tensor): + img = img.numpy().astype("np.uint8") + if isinstance(img, str): + img = img_tensorize(img) + assert isinstance(img, np.ndarray) + + width, height = img.shape[1], img.shape[0] + fig = mplfigure.Figure(frameon=False) + dpi = fig.get_dpi() + width_in = (width * scale + 1e-2) / dpi + height_in = (height * scale + 1e-2) / dpi + fig.set_size_inches(width_in, height_in) + ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) + ax.axis("off") + ax.set_xlim(0.0, width) + ax.set_ylim(height) + + self.saveas = saveas + self.rgb = rgb + self.pynb = pynb + self.img = img + self.edgecolor = edgecolor + self.alpha = 0.5 + self.linestyle = linestyle + self.font_size = int(np.sqrt(min(height, width)) * scale // 3) + self.width = width + self.height = height + self.scale = scale + self.fig = fig + self.ax = ax + self.pad = pad + self.id2obj = id2obj + self.id2attr = id2attr + self.canvas = FigureCanvasAgg(fig) + + def add_box(self, box, color=None): + if color is None: + color = self.edgecolor + (x0, y0, x1, y1) = box + width = x1 - x0 + height = y1 - y0 + self.ax.add_patch( + mpl.patches.Rectangle( + (x0, y0), + width, + height, + fill=False, + edgecolor=color, + linewidth=self.font_size // 3, + alpha=self.alpha, + linestyle=self.linestyle, + ) + ) + + def draw_boxes(self, boxes, obj_ids=None, obj_scores=None, attr_ids=None, attr_scores=None): + if len(boxes.shape) > 2: + boxes = boxes[0] + if len(obj_ids.shape) > 1: + obj_ids = obj_ids[0] + if len(obj_scores.shape) > 1: + obj_scores = obj_scores[0] + if len(attr_ids.shape) > 1: + attr_ids = attr_ids[0] + if len(attr_scores.shape) > 1: + attr_scores = attr_scores[0] + if isinstance(boxes, torch.Tensor): + boxes = boxes.numpy() + if isinstance(boxes, list): + boxes = np.array(boxes) + assert isinstance(boxes, np.ndarray) + areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1) + sorted_idxs = np.argsort(-areas).tolist() + boxes = boxes[sorted_idxs] if boxes is not None else None + obj_ids = obj_ids[sorted_idxs] if 
obj_ids is not None else None + obj_scores = obj_scores[sorted_idxs] if obj_scores is not None else None + attr_ids = attr_ids[sorted_idxs] if attr_ids is not None else None + attr_scores = attr_scores[sorted_idxs] if attr_scores is not None else None + + assigned_colors = [self._random_color(maximum=1) for _ in range(len(boxes))] + assigned_colors = [assigned_colors[idx] for idx in sorted_idxs] + if obj_ids is not None: + labels = self._create_text_labels_attr(obj_ids, obj_scores, attr_ids, attr_scores) + for i in range(len(boxes)): + color = assigned_colors[i] + self.add_box(boxes[i], color) + self.draw_labels(labels[i], boxes[i], color) + + def draw_labels(self, label, box, color): + x0, y0, x1, y1 = box + text_pos = (x0, y0) + instance_area = (y1 - y0) * (x1 - x0) + small = _SMALL_OBJ * self.scale + if instance_area < small or y1 - y0 < 40 * self.scale: + if y1 >= self.height - 5: + text_pos = (x1, y0) + else: + text_pos = (x0, y1) + + height_ratio = (y1 - y0) / np.sqrt(self.height * self.width) + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + font_size = np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) + font_size *= 0.75 * self.font_size + + self.draw_text( + text=label, + position=text_pos, + color=lighter_color, + ) + + def draw_text( + self, + text, + position, + color="g", + ha="left", + ): + rotation = 0 + font_size = self.font_size + color = np.maximum(list(mplc.to_rgb(color)), 0.2) + color[np.argmax(color)] = max(0.8, np.max(color)) + bbox = { + "facecolor": "black", + "alpha": self.alpha, + "pad": self.pad, + "edgecolor": "none", + } + x, y = position + self.ax.text( + x, + y, + text, + size=font_size * self.scale, + family="sans-serif", + bbox=bbox, + verticalalignment="top", + horizontalalignment=ha, + color=color, + zorder=10, + rotation=rotation, + ) + + def save(self, saveas=None): + if saveas is None: + saveas = self.saveas + if saveas.lower().endswith(".jpg") or saveas.lower().endswith(".png"): + cv2.imwrite( + saveas, + self._get_buffer()[:, :, ::-1], + ) + else: + self.fig.savefig(saveas) + + def _create_text_labels_attr(self, classes, scores, attr_classes, attr_scores): + labels = [self.id2obj[i] for i in classes] + attr_labels = [self.id2attr[i] for i in attr_classes] + labels = [ + f"{label} {score:.2f} {attr} {attr_score:.2f}" + for label, score, attr, attr_score in zip(labels, scores, attr_labels, attr_scores) + ] + return labels + + def _create_text_labels(self, classes, scores): + labels = [self.id2obj[i] for i in classes] + if scores is not None: + if labels is None: + labels = ["{:.0f}%".format(s * 100) for s in scores] + else: + labels = ["{} {:.0f}%".format(li, s * 100) for li, s in zip(labels, scores)] + return labels + + def _random_color(self, maximum=255): + idx = np.random.randint(0, len(_COLORS)) + ret = _COLORS[idx] * maximum + if not self.rgb: + ret = ret[::-1] + return ret + + def _get_buffer(self): + if not self.pynb: + s, (width, height) = self.canvas.print_to_buffer() + if (width, height) != (self.width, self.height): + img = cv2.resize(self.img, (width, height)) + else: + img = self.img + else: + buf = io.BytesIO() # works for cairo backend + self.canvas.print_rgba(buf) + width, height = self.width, self.height + s = buf.getvalue() + img = self.img + + buffer = np.frombuffer(s, dtype="uint8") + img_rgba = buffer.reshape(height, width, 4) + rgb, alpha = np.split(img_rgba, [3], axis=2) + + try: + import numexpr as ne # fuse them with numexpr + + visualized_image = ne.evaluate("img * (1 - alpha / 255.0) + rgb * 
(alpha / 255.0)") + except ImportError: + alpha = alpha.astype("float32") / 255.0 + visualized_image = img * (1 - alpha) + rgb * alpha + + return visualized_image.astype("uint8") + + def _change_color_brightness(self, color, brightness_factor): + assert brightness_factor >= -1.0 and brightness_factor <= 1.0 + color = mplc.to_rgb(color) + polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) + modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) + modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness + modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness + modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2]) + return modified_color + + +# Color map +_COLORS = ( + np.array( + [ + 0.000, + 0.447, + 0.741, + 0.850, + 0.325, + 0.098, + 0.929, + 0.694, + 0.125, + 0.494, + 0.184, + 0.556, + 0.466, + 0.674, + 0.188, + 0.301, + 0.745, + 0.933, + 0.635, + 0.078, + 0.184, + 0.300, + 0.300, + 0.300, + 0.600, + 0.600, + 0.600, + 1.000, + 0.000, + 0.000, + 1.000, + 0.500, + 0.000, + 0.749, + 0.749, + 0.000, + 0.000, + 1.000, + 0.000, + 0.000, + 0.000, + 1.000, + 0.667, + 0.000, + 1.000, + 0.333, + 0.333, + 0.000, + 0.333, + 0.667, + 0.000, + 0.333, + 1.000, + 0.000, + 0.667, + 0.333, + 0.000, + 0.667, + 0.667, + 0.000, + 0.667, + 1.000, + 0.000, + 1.000, + 0.333, + 0.000, + 1.000, + 0.667, + 0.000, + 1.000, + 1.000, + 0.000, + 0.000, + 0.333, + 0.500, + 0.000, + 0.667, + 0.500, + 0.000, + 1.000, + 0.500, + 0.333, + 0.000, + 0.500, + 0.333, + 0.333, + 0.500, + 0.333, + 0.667, + 0.500, + 0.333, + 1.000, + 0.500, + 0.667, + 0.000, + 0.500, + 0.667, + 0.333, + 0.500, + 0.667, + 0.667, + 0.500, + 0.667, + 1.000, + 0.500, + 1.000, + 0.000, + 0.500, + 1.000, + 0.333, + 0.500, + 1.000, + 0.667, + 0.500, + 1.000, + 1.000, + 0.500, + 0.000, + 0.333, + 1.000, + 0.000, + 0.667, + 1.000, + 0.000, + 1.000, + 1.000, + 0.333, + 0.000, + 1.000, + 0.333, + 0.333, + 1.000, + 0.333, + 0.667, + 1.000, + 0.333, + 1.000, + 1.000, + 0.667, + 0.000, + 1.000, + 0.667, + 0.333, + 1.000, + 0.667, + 0.667, + 1.000, + 0.667, + 1.000, + 1.000, + 1.000, + 0.000, + 1.000, + 1.000, + 0.333, + 1.000, + 1.000, + 0.667, + 1.000, + 0.333, + 0.000, + 0.000, + 0.500, + 0.000, + 0.000, + 0.667, + 0.000, + 0.000, + 0.833, + 0.000, + 0.000, + 1.000, + 0.000, + 0.000, + 0.000, + 0.167, + 0.000, + 0.000, + 0.333, + 0.000, + 0.000, + 0.500, + 0.000, + 0.000, + 0.667, + 0.000, + 0.000, + 0.833, + 0.000, + 0.000, + 1.000, + 0.000, + 0.000, + 0.000, + 0.167, + 0.000, + 0.000, + 0.333, + 0.000, + 0.000, + 0.500, + 0.000, + 0.000, + 0.667, + 0.000, + 0.000, + 0.833, + 0.000, + 0.000, + 1.000, + 0.000, + 0.000, + 0.000, + 0.143, + 0.143, + 0.143, + 0.857, + 0.857, + 0.857, + 1.000, + 1.000, + 1.000, + ] + ) + .astype(np.float32) + .reshape(-1, 3) +) \ No newline at end of file diff --git a/lxmert/src/vqa_utils.py b/lxmert/src/vqa_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1faf9feffa1d4b7dceeabc6c2758fc22463bb1b1 --- /dev/null +++ b/lxmert/src/vqa_utils.py @@ -0,0 +1,559 @@ +""" + coding=utf-8 + Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal, Huggingface team :) + Adapted From Facebook Inc, Detectron2 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import copy +import fnmatch +import json +import os +import pickle as pkl +import shutil +import sys +import tarfile +import tempfile +from collections import OrderedDict +from contextlib import contextmanager +from functools import partial +from hashlib import sha256 +from io import BytesIO +from pathlib import Path +from urllib.parse import urlparse +from zipfile import ZipFile, is_zipfile + +import numpy as np +from PIL import Image +from tqdm.auto import tqdm + +import cv2 +import requests +import wget +from filelock import FileLock +from yaml import Loader, dump, load + + +try: + import torch + + _torch_available = True +except ImportError: + _torch_available = False + + +try: + from torch.hub import _get_torch_home + + torch_cache_home = _get_torch_home() +except ImportError: + torch_cache_home = os.path.expanduser( + os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")) + ) + +default_cache_path = os.path.join(torch_cache_home, "transformers") + +CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co" +S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" +PATH = "/".join(str(Path(__file__).resolve()).split("/")[:-1]) +CONFIG = os.path.join(PATH, "config.yaml") +ATTRIBUTES = os.path.join(PATH, "attributes.txt") +OBJECTS = os.path.join(PATH, "objects.txt") +PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) +PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE) +TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) +WEIGHTS_NAME = "pytorch_model.bin" +CONFIG_NAME = "config.yaml" + + +def load_labels(objs=OBJECTS, attrs=ATTRIBUTES): + vg_classes = [] + with open(objs) as f: + for object in f.readlines(): + vg_classes.append(object.split(",")[0].lower().strip()) + + vg_attrs = [] + with open(attrs) as f: + for object in f.readlines(): + vg_attrs.append(object.split(",")[0].lower().strip()) + return vg_classes, vg_attrs + + +def load_checkpoint(ckp): + r = OrderedDict() + with open(ckp, "rb") as f: + ckp = pkl.load(f)["model"] + for k in copy.deepcopy(list(ckp.keys())): + v = ckp.pop(k) + if isinstance(v, np.ndarray): + v = torch.tensor(v) + else: + assert isinstance(v, torch.Tensor), type(v) + r[k] = v + return r + + +class Config: + _pointer = {} + + def __init__(self, dictionary: dict, name: str = "root", level=0): + self._name = name + self._level = level + d = {} + for k, v in dictionary.items(): + if v is None: + raise ValueError() + k = copy.deepcopy(k) + v = copy.deepcopy(v) + if isinstance(v, dict): + v = Config(v, name=k, level=level + 1) + d[k] = v + setattr(self, k, v) + + self._pointer = d + + def __repr__(self): + return str(list((self._pointer.keys()))) + + def __setattr__(self, key, val): + self.__dict__[key] = val + self.__dict__[key.upper()] = val + levels = key.split(".") + last_level = len(levels) - 1 + pointer = self._pointer + if len(levels) > 1: + for i, l in enumerate(levels): + if hasattr(self, l) and isinstance(getattr(self, l), Config): + setattr(getattr(self, l), 
".".join(levels[i:]), val) + if i == last_level: + pointer[l] = val + else: + pointer = pointer[l] + + def to_dict(self): + return self._pointer + + def dump_yaml(self, data, file_name): + with open(f"{file_name}", "w") as stream: + dump(data, stream) + + def dump_json(self, data, file_name): + with open(f"{file_name}", "w") as stream: + json.dump(data, stream) + + @staticmethod + def load_yaml(config): + with open(config) as stream: + data = load(stream, Loader=Loader) + return data + + def __str__(self): + t = " " + if self._name != "root": + r = f"{t * (self._level-1)}{self._name}:\n" + else: + r = "" + level = self._level + for i, (k, v) in enumerate(self._pointer.items()): + if isinstance(v, Config): + r += f"{t * (self._level)}{v}\n" + self._level += 1 + else: + r += f"{t * (self._level)}{k}: {v} ({type(v).__name__})\n" + self._level = level + return r[:-1] + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + return cls(config_dict) + + @classmethod + def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs): + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + + if os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + else: + config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False) + + try: + # Load from URL or cache if already cached + resolved_config_file = cached_path( + config_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + ) + # Load config dict + if resolved_config_file is None: + raise EnvironmentError + + config_file = Config.load_yaml(resolved_config_file) + + except EnvironmentError: + msg = "Can't load config for" + raise EnvironmentError(msg) + + if resolved_config_file == config_file: + print("loading configuration file from path") + else: + print("loading configuration file cache") + + return Config.load_yaml(resolved_config_file), kwargs + + +# quick compare tensors +def compare(in_tensor): + + out_tensor = torch.load("dump.pt", map_location=in_tensor.device) + n1 = in_tensor.numpy() + n2 = out_tensor.numpy()[0] + print(n1.shape, n1[0, 0, :5]) + print(n2.shape, n2[0, 0, :5]) + assert np.allclose( + n1, n2, rtol=0.01, atol=0.1 + ), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch" + raise Exception("tensors are all good") + + # Hugging face functions below + + +def is_remote_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + + +def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str: + endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX + legacy_format = "/" not in model_id + if legacy_format: + return f"{endpoint}/{model_id}-{filename}" + else: + return f"{endpoint}/{model_id}/{filename}" + + +def http_get( + url, + temp_file, + proxies=None, + resume_size=0, + user_agent=None, +): + ua = 
"python/{}".format(sys.version.split()[0]) + if _torch_available: + ua += "; torch/{}".format(torch.__version__) + if isinstance(user_agent, dict): + ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += "; " + user_agent + headers = {"user-agent": ua} + if resume_size > 0: + headers["Range"] = "bytes=%d-" % (resume_size,) + response = requests.get(url, stream=True, proxies=proxies, headers=headers) + if response.status_code == 416: # Range not satisfiable + return + content_length = response.headers.get("Content-Length") + total = resume_size + int(content_length) if content_length is not None else None + progress = tqdm( + unit="B", + unit_scale=True, + total=total, + initial=resume_size, + desc="Downloading", + ) + for chunk in response.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache( + url, + cache_dir=None, + force_download=False, + proxies=None, + etag_timeout=10, + resume_download=False, + user_agent=None, + local_files_only=False, +): + + if cache_dir is None: + cache_dir = TRANSFORMERS_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + os.makedirs(cache_dir, exist_ok=True) + + etag = None + if not local_files_only: + try: + response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) + if response.status_code == 200: + etag = response.headers.get("ETag") + except (EnvironmentError, requests.exceptions.Timeout): + # etag is already None + pass + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible. + # try to get the last downloaded one + if etag is None: + if os.path.exists(cache_path): + return cache_path + else: + matching_files = [ + file + for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*") + if not file.endswith(".json") and not file.endswith(".lock") + ] + if len(matching_files) > 0: + return os.path.join(cache_dir, matching_files[-1]) + else: + # If files cannot be found and local_files_only=True, + # the models might've been found if local_files_only=False + # Notify the user about that + if local_files_only: + raise ValueError( + "Cannot find the requested files in the cached path and outgoing traffic has been" + " disabled. To enable model look-ups and downloads online, set 'local_files_only'" + " to False." + ) + return None + + # From now on, etag is not None. + if os.path.exists(cache_path) and not force_download: + return cache_path + + # Prevent parallel downloads of the same file with a lock. + lock_path = cache_path + ".lock" + with FileLock(lock_path): + + # If the download just completed while the lock was activated. + if os.path.exists(cache_path) and not force_download: + # Even if returning early like here, the lock will be released. 
+ return cache_path + + if resume_download: + incomplete_path = cache_path + ".incomplete" + + @contextmanager + def _resumable_file_manager(): + with open(incomplete_path, "a+b") as f: + yield f + + temp_file_manager = _resumable_file_manager + if os.path.exists(incomplete_path): + resume_size = os.stat(incomplete_path).st_size + else: + resume_size = 0 + else: + temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False) + resume_size = 0 + + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with temp_file_manager() as temp_file: + print( + "%s not found in cache or force_download set to True, downloading to %s", + url, + temp_file.name, + ) + + http_get( + url, + temp_file, + proxies=proxies, + resume_size=resume_size, + user_agent=user_agent, + ) + + os.replace(temp_file.name, cache_path) + + meta = {"url": url, "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + json.dump(meta, meta_file) + + return cache_path + + +def url_to_filename(url, etag=None): + + url_bytes = url.encode("utf-8") + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode("utf-8") + etag_hash = sha256(etag_bytes) + filename += "." + etag_hash.hexdigest() + + if url.endswith(".h5"): + filename += ".h5" + + return filename + + +def cached_path( + url_or_filename, + cache_dir=None, + force_download=False, + proxies=None, + resume_download=False, + user_agent=None, + extract_compressed_file=False, + force_extract=False, + local_files_only=False, +): + if cache_dir is None: + cache_dir = TRANSFORMERS_CACHE + if isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if is_remote_url(url_or_filename): + # URL, so get it from the cache (downloading if necessary) + output_path = get_from_cache( + url_or_filename, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + user_agent=user_agent, + local_files_only=local_files_only, + ) + elif os.path.exists(url_or_filename): + # File, and it exists. + output_path = url_or_filename + elif urlparse(url_or_filename).scheme == "": + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + + if extract_compressed_file: + if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path): + return output_path + + # Path where we extract compressed archives + # We avoid '.' 
in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/" + output_dir, output_file = os.path.split(output_path) + output_extract_dir_name = output_file.replace(".", "-") + "-extracted" + output_path_extracted = os.path.join(output_dir, output_extract_dir_name) + + if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract: + return output_path_extracted + + # Prevent parallel extractions + lock_path = output_path + ".lock" + with FileLock(lock_path): + shutil.rmtree(output_path_extracted, ignore_errors=True) + os.makedirs(output_path_extracted) + if is_zipfile(output_path): + with ZipFile(output_path, "r") as zip_file: + zip_file.extractall(output_path_extracted) + zip_file.close() + elif tarfile.is_tarfile(output_path): + tar_file = tarfile.open(output_path) + tar_file.extractall(output_path_extracted) + tar_file.close() + else: + raise EnvironmentError("Archive format of {} could not be identified".format(output_path)) + + return output_path_extracted + + return output_path + + +def get_data(query, delim=","): + assert isinstance(query, str) + if os.path.isfile(query): + with open(query) as f: + data = eval(f.read()) + else: + req = requests.get(query) + try: + data = req.json() + except Exception: + data = req.content.decode() + assert data is not None, "could not connect" + try: + data = eval(data) + except Exception: + data = data.split("\n") + req.close() + return data + + +def get_image_from_url(url): + response = requests.get(url) + img = np.array(Image.open(BytesIO(response.content))) + return img + + +# to load legacy frcnn checkpoint from detectron +def load_frcnn_pkl_from_url(url): + fn = url.split("/")[-1] + if fn not in os.listdir(os.getcwd()): + wget.download(url) + with open(fn, "rb") as stream: + weights = pkl.load(stream) + model = weights.pop("model") + new = {} + for k, v in model.items(): + new[k] = torch.from_numpy(v) + if "running_var" in k: + zero = torch.Tensor([0]) + k2 = k.replace("running_var", "num_batches_tracked") + new[k2] = zero + return new + + +def get_demo_path(): + print(f"{os.path.abspath(os.path.join(PATH, os.pardir))}/demo.ipynb") + + +def img_tensorize(im, input_format="RGB"): + assert isinstance(im, str) + if os.path.isfile(im): + img = cv2.imread(im) + else: + img = get_image_from_url(im) + assert img is not None, f"could not connect to: {im}" + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if input_format == "RGB": + img = img[:, :, ::-1] + return img + + +def chunk(images, batch=1): + return (images[i : i + batch] for i in range(0, len(images), batch))
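
The helpers above are adapted from the Hugging Face / Detectron2 demo utilities (per the file header), and the patch itself ships no usage example. The short sketch below is only an illustration of how a few of them fit together; the import path, the image path, and the URLs are placeholder assumptions rather than values taken from this patch.

# Minimal usage sketch (illustrative only, not part of the patch).
# Assumes lxmert/src/vqa_utils.py is importable and that the placeholder
# paths/URLs below are replaced with real ones.
import vqa_utils as utils

# img_tensorize() accepts a local file path or an http(s) URL and returns the
# image as a numpy array decoded via OpenCV / PIL.
img = utils.img_tensorize("path/to/image.jpg")

# chunk() wraps a list of images into a generator of mini-batches.
for batch in utils.chunk([img], batch=1):
    print(len(batch), batch[0].shape)

# cached_path() resolves a remote file to a local copy under TRANSFORMERS_CACHE,
# downloading it on first use; get_data() fetches a small text file and falls
# back to splitting it into lines when it is neither JSON nor a Python literal.
local_copy = utils.cached_path("https://example.com/objects_vocab.txt")
vocab = utils.get_data("https://example.com/objects_vocab.txt")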