# Provenance (from the hosting page): uploaded by WwYc in commit 08d7644 ("Upload 61 files"), 7.79 kB.
# coding=utf-8
# Copyleft 2019 project LXRT.
import os
import collections
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from ..param import args
from ..pretrain.qa_answer_table import load_lxmert_qa
from .vqa_model import VQAModel
from .vqa_data import VQADataset, VQATorchDataset, VQAEvaluator
# Immutable record bundling everything needed to iterate over and score one split:
#   dataset   -- the VQADataset for the split (provides label2ans, num_answers, splits)
#   loader    -- a torch DataLoader over the VQATorchDataset wrapper of that dataset
#   evaluator -- a VQAEvaluator bound to the same dataset
DataTuple = collections.namedtuple("DataTuple", ["dataset", "loader", "evaluator"])
def get_data_tuple(splits: str, bs: int, shuffle=False, drop_last=False) -> DataTuple:
    """Build the dataset, DataLoader and evaluator for a VQA split specification.

    :param splits: Split name string forwarded unchanged to ``VQADataset``
        (presumably may name several splits -- confirm against VQADataset).
    :param bs: Batch size of the returned DataLoader.
    :param shuffle: Whether the DataLoader shuffles samples.
    :param drop_last: Whether an incomplete final batch is dropped.
    :return: ``DataTuple(dataset, loader, evaluator)`` sharing one VQADataset.
    """
    raw_dataset = VQADataset(splits)
    torch_dataset = VQATorchDataset(raw_dataset)
    scorer = VQAEvaluator(raw_dataset)
    # pin_memory=True keeps batches in page-locked host memory, which speeds up
    # the .cuda() copies done by the training/prediction loops.
    batches = DataLoader(
        torch_dataset,
        batch_size=bs,
        shuffle=shuffle,
        num_workers=args.num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )
    return DataTuple(raw_dataset, batches, scorer)
class VQA:
    """Fine-tunes and evaluates an LXMERT-based VQA model.

    All configuration is read from the global ``args`` namespace. Construction
    builds the training (and optional validation) data, the model, the loss and
    the optimizer, and creates ``args.output`` for checkpoints and logs.
    """

    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(
            args.train, bs=args.batch_size, shuffle=True, drop_last=True
        )
        if args.valid != "":
            # Validation keeps sample order and drops nothing, so every
            # question is scored.
            self.valid_tuple = get_data_tuple(
                args.valid, bs=1024,
                shuffle=False, drop_last=False
            )
        else:
            self.valid_tuple = None

        # Model: the answer head is sized by the training set's answer count.
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            # BertAdam's schedule needs the total number of optimisation steps;
            # warmup=0.1 is presumably the warmup fraction of t_total (confirm
            # in lxrt.optimization).
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from ..lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        """Train for ``args.epochs`` epochs, validating after each epoch.

        :param train_tuple: DataTuple whose loader yields
            ``(ques_id, feats, boxes, sent, target)`` batches.
        :param eval_tuple: DataTuple to validate on after every epoch, or
            ``None`` to skip validation.

        Saves ``BEST.pth`` whenever the validation score improves and
        ``LAST.pth`` after the final epoch. Per-epoch scores are printed and
        appended to ``<output>/log.log``.
        """
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            # evaluate() switches the model to eval mode, so re-enter train
            # mode at the start of every epoch.
            self.model.train()
            for ques_id, feats, boxes, sent, target in iter_wrapper(loader):
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                # BCEWithLogitsLoss averages over every (sample, answer) entry.
                # Multiplying by the number of answers sums over answers instead,
                # while still averaging over the batch.
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                # Record the arg-max answer for the training-accuracy estimate.
                _, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            # Validate on the tuple passed by the caller. The original checked
            # self.valid_tuple, which disagrees with eval_tuple when they differ.
            if eval_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(os.path.join(self.output, "log.log"), 'a') as f:
                f.write(log_str)

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        The model is put into (and left in) eval mode.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results (written via
            ``evaluator.dump_result``); ``None`` skips dumping.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for datum_tuple in loader:
            ques_id, feats, boxes, sent = datum_tuple[:4]   # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                _, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: Optional path passed through to ``predict``.
        :return: The score returned by ``eval_tuple.evaluator.evaluate``.
        """
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        """Score obtained by answering every question with its highest-scored target label.

        Presumably an upper bound on accuracy achievable with this answer
        vocabulary. Requires the loader to yield the ground-truth ``target``.
        """
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for ques_id, feats, boxes, sent, target in loader:
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        """Save the model's state_dict to ``<output>/<name>.pth``."""
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        """Load a state_dict from ``<path>.pth`` (path given without the suffix).

        Tensors are mapped to CPU first and then copied into the (CUDA) model by
        ``load_state_dict``. This way a checkpoint saved on a different device
        still loads.
        NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
        """
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path, map_location="cpu")
        self.model.load_state_dict(state_dict)
if __name__ == "__main__":
    # Script entry point: build the VQA trainer, optionally load fine-tuned
    # weights, then either run a test/minival pass or train.

    # Build Class
    vqa = VQA()

    # Load VQA model weights
    # Note: It is different from loading lxmert pre-trained weights.
    if args.load is not None:
        vqa.load(args.load)

    # Test or Train
    if args.test is not None:
        # Always loading all data in test. This runs after VQA() has already
        # built the train/valid tuples, so it can only affect the data tuples
        # built below (presumably VQADataset reads these flags -- confirm).
        args.fast = args.tiny = False
        if 'test' in args.test:
            vqa.predict(
                get_data_tuple(args.test, bs=950,
                               shuffle=False, drop_last=False),
                dump=os.path.join(args.output, 'test_predict.json')
            )
        elif 'val' in args.test:
            # Since part of the validation data is used in pre-training/fine-tuning,
            # only validate on the minival set.
            result = vqa.evaluate(
                get_data_tuple('minival', bs=950,
                               shuffle=False, drop_last=False),
                dump=os.path.join(args.output, 'minival_predict.json')
            )
            print(result)
        else:
            # An explicit raise, unlike `assert`, still fires under `python -O`.
            raise ValueError("No such test option for %s" % args.test)
    else:
        print('Splits in Train data:', vqa.train_tuple.dataset.splits)
        if vqa.valid_tuple is not None:
            print('Splits in Valid data:', vqa.valid_tuple.dataset.splits)
            print("Valid Oracle: %0.2f" % (vqa.oracle_score(vqa.valid_tuple) * 100))
        else:
            print("DO NOT USE VALIDATION")
        vqa.train(vqa.train_tuple, vqa.valid_tuple)