WwYc's picture
Upload 61 files
08d7644 verified
# coding=utf-8
# Copyright 2019 project LXRT.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import torch.nn as nn
from ..lxrt.tokenization import BertTokenizer
from ..lxrt.modeling import LXRTFeatureExtraction as VisualBertForLXRFeature, VISUAL_CONFIG
class InputFeatures:
    """Plain container for one encoded sentence.

    Attributes:
        input_ids: token ids for the sentence (convert_sents_to_features stores
            a zero-padded list here).
        input_mask: per-position mask matching ``input_ids``.
        segment_ids: per-position segment ids matching ``input_ids``.
    """

    def __init__(self, input_ids, input_mask, segment_ids):
        # Store the three parallel sequences exactly as given (no copying).
        self.input_ids, self.input_mask, self.segment_ids = (
            input_ids,
            input_mask,
            segment_ids,
        )
def convert_sents_to_features(sents, max_seq_length, tokenizer):
    """Tokenize raw sentences into fixed-length, zero-padded BERT inputs.

    Each sentence is stripped and tokenized with ``tokenizer.tokenize``, then
    truncated so that ``[CLS] + tokens + [SEP]`` fits in ``max_seq_length``.
    The tokens are mapped to ids with ``tokenizer.convert_tokens_to_ids`` and
    right-padded with 0 up to ``max_seq_length``.

    Args:
        sents: iterable of sentence strings.
        max_seq_length: exact length of every output sequence; must be >= 2
            so that ``[CLS]`` and ``[SEP]`` fit.
        tokenizer: object providing ``tokenize(str)`` and
            ``convert_tokens_to_ids(list)`` (LXRTEncoder passes a BertTokenizer).

    Returns:
        A list of InputFeatures, one per sentence and in input order, where
        ``input_ids`` holds the token ids followed by 0 padding, ``input_mask``
        is 1 for real tokens and 0 for padding, and ``segment_ids`` is all 0.

    Raises:
        ValueError: if ``max_seq_length`` is smaller than 2.
    """
    if max_seq_length < 2:
        # Without room for [CLS] and [SEP] no sequence can be padded to length.
        raise ValueError(
            "max_seq_length must be >= 2 to fit [CLS] and [SEP], got %s"
            % (max_seq_length,))

    features = []
    for sent in sents:
        # Reserve two positions for the [CLS] and [SEP] markers.
        tokens_a = tokenizer.tokenize(sent.strip())[:max_seq_length - 2]
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        # Single-segment input: every segment id is 0, which keeps the inputs
        # compatible with the pretrained BERT weights.
        segment_ids = [0] * len(tokens)
        input_ids = list(tokenizer.convert_tokens_to_ids(tokens))
        # 1 marks real tokens (attended to); padding positions get 0.
        input_mask = [1] * len(input_ids)
        # Zero-pad up to exactly max_seq_length (never negative after truncation).
        padding = [0] * (max_seq_length - len(input_ids))
        features.append(
            InputFeatures(input_ids=input_ids + padding,
                          input_mask=input_mask + padding,
                          segment_ids=segment_ids + padding))
    return features
def set_visual_config(args):
    """Copy the layer counts from ``args`` onto the module-level VISUAL_CONFIG.

    ``args.llayers``, ``args.xlayers`` and ``args.rlayers`` are written to
    ``VISUAL_CONFIG.l_layers``, ``x_layers`` and ``r_layers`` respectively
    (presumably the language / cross-modal / visual encoder depths -- confirm
    in lxrt.modeling). Mutates the shared config object in place.
    """
    attribute_map = (
        ("l_layers", "llayers"),
        ("x_layers", "xlayers"),
        ("r_layers", "rlayers"),
    )
    for config_name, arg_name in attribute_map:
        setattr(VISUAL_CONFIG, config_name, getattr(args, arg_name))
class LXRTEncoder(nn.Module):
    """Wrap an LXRT feature-extraction model together with its BERT tokenizer.

    ``forward`` takes raw sentences plus visual features, tokenizes the
    sentences to ``max_seq_length`` and runs the wrapped model on both.
    """

    def __init__(self, args, max_seq_length, mode='x'):
        """Build the tokenizer and the LXRT model.

        Args:
            args: namespace read for ``llayers``, ``xlayers``, ``rlayers``
                (via set_visual_config) and ``from_scratch``.
            max_seq_length: padded/truncated token length for every sentence.
            mode: forwarded unchanged to ``from_pretrained``; its meaning is
                defined in lxrt.modeling (NOTE(review): confirm there).
        """
        super().__init__()
        self.max_seq_length = max_seq_length
        # The layer counts are written to the global VISUAL_CONFIG before the
        # model is constructed (presumably so from_pretrained picks them up).
        set_visual_config(args)

        # Using the bert tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        # Build LXRT Model
        self.model = VisualBertForLXRFeature.from_pretrained(
            "bert-base-uncased",
            mode=mode
        )

        if args.from_scratch:
            print("initializing all the weights")
            # Overwrite the pretrained weights with BERT's initialization.
            self.model.apply(self.model.init_bert_weights)

    def multi_gpu(self):
        """Wrap the model in ``nn.DataParallel``."""
        self.model = nn.DataParallel(self.model)

    @property
    def dim(self):
        """Hard-coded output width 768 (presumably bert-base-uncased's hidden size)."""
        return 768

    def forward(self, sents, feats, visual_attention_mask=None):
        """Encode ``sents`` together with visual features.

        Args:
            sents: list of sentence strings.
            feats: passed through as ``visual_feats``; its expected structure
                is defined by the wrapped model (NOTE(review): confirm).
            visual_attention_mask: optional mask passed through unchanged.

        Returns:
            Whatever the wrapped model returns for these inputs.
        """
        train_features = convert_sents_to_features(
            sents, self.max_seq_length, self.tokenizer)

        # Create token tensors on the model's own device rather than assuming
        # CUDA, so the encoder also runs on CPU or on a non-default GPU.
        # (Under DataParallel this is device_ids[0]; inputs are scattered.)
        device = next(self.model.parameters()).device
        input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long, device=device)
        input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long, device=device)
        segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long, device=device)

        output = self.model(input_ids, segment_ids, input_mask,
                            visual_feats=feats,
                            visual_attention_mask=visual_attention_mask)
        return output

    def save(self, path):
        """Save the model's ``state_dict`` to ``"<path>_LXRT.pth"``.

        Keys are saved as-is; after ``multi_gpu()`` they carry a ``"module."``
        prefix, which ``load`` strips again.
        """
        torch.save(self.model.state_dict(), "%s_LXRT.pth" % path)

    def load(self, path):
        """Load weights from ``"<path>_LXRT.pth"``.

        A leading ``"module."`` is stripped from every checkpoint key, and the
        weights are loaded into the unwrapped model, so this works whether or
        not either side was wrapped in ``nn.DataParallel``. Key differences
        are printed, and missing/unexpected keys are tolerated
        (``strict=False``).
        """
        # Load state_dict from snapshot file
        print("Load lxmert pre-trained model from %s" % path)
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints. map_location="cpu" lets GPU-saved checkpoints load on
        # CPU-only hosts; load_state_dict copies onto the model's own device.
        state_dict = torch.load("%s_LXRT.pth" % path, map_location="cpu")
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        # Compare against and load into the bare model: after multi_gpu() the
        # wrapper's keys are "module."-prefixed and would never match the
        # stripped checkpoint keys, silently loading nothing with strict=False.
        if isinstance(self.model, nn.DataParallel):
            target = self.model.module
        else:
            target = self.model

        # Print out the differences of pre-trained and model weights.
        load_keys = set(state_dict.keys())
        model_keys = set(target.state_dict().keys())
        print()
        print("Weights in loaded but not in model:")
        for key in sorted(load_keys.difference(model_keys)):
            print(key)
        print()
        print("Weights in model but not in loaded:")
        for key in sorted(model_keys.difference(load_keys)):
            print(key)
        print()

        # Load weights to model
        target.load_state_dict(state_dict, strict=False)