# coding=utf-8
# Copyleft 2019 project LXRT.

import os
import collections

import torch
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader

from param import args
from pretrain.qa_answer_table import load_lxmert_qa
from tasks.gqa_model import GQAModel
from tasks.gqa_data import GQADataset, GQATorchDataset, GQAEvaluator

DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
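

# Build the dataset, its torch Dataset wrapper, an evaluator, and a DataLoader for the given split string.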
def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False) -> DataTuple:
    dset = GQADataset(splits)
    tset = GQATorchDataset(dset)
    evaluator = GQAEvaluator(dset)
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        drop_last=drop_last, pin_memory=True
    )

    return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
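

# GQA ties together the data tuples, model, losses, optimizer, and the train/predict/evaluate/save/load logic.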
class GQA:
    def __init__(self):
        # Data
        self.train_tuple = get_tuple(
            args.train, bs=args.batch_size, shuffle=True, drop_last=True
        )
        if args.valid != "":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(
                args.valid, bs=valid_bsize,
                shuffle=False, drop_last=False
            )
        else:
            self.valid_tuple = None

        # Model
        self.model = GQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)
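
    # Train for args.epochs, log train/valid accuracy each epoch, and keep the best validation checkpoint.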
    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                if args.mce_loss:
                    max_value, target = target.max(1)
                    loss = self.mce_loss(logit, target) * logit.size(1)
                else:
                    loss = self.bce_loss(logit, target)
                    loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                # Record the predicted answer for the training-accuracy log.
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")
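
    # Run inference on an evaluation split; optionally dump the {question_id: answer} predictions via the evaluator.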
    def predict(self, eval_tuple: DataTuple, dump=None):
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]   # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        dset, loader, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        # Upper bound on accuracy: answer every question with its highest-scoring ground-truth label.
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid] = ans
        return evaluator.evaluate(quesid2ans)
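
    # Checkpoints are saved as <output>/<name>.pth; load() expects the path without the .pth suffix.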
    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        for key in list(state_dict.keys()):
            # Strip the '.module' inserted by nn.DataParallel so keys match the single-GPU model.
            if '.module' in key:
                state_dict[key.replace('.module', '')] = state_dict.pop(key)
        self.model.load_state_dict(state_dict, strict=False)
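

# Entry point: build the trainer, optionally restore a checkpoint, then either predict/evaluate on a test split or train.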
if __name__ == "__main__":
    # Build Class
    gqa = GQA()

    # Load Model
    if args.load is not None:
        gqa.load(args.load)

    # Test or Train
    if args.test is not None:
        args.fast = args.tiny = False       # Always loading all data in test
        if 'submit' in args.test:
            gqa.predict(
                get_tuple(args.test, bs=args.batch_size,
                          shuffle=False, drop_last=False),
                dump=os.path.join(args.output, 'submit_predict.json')
            )
        if 'testdev' in args.test:
            result = gqa.evaluate(
                get_tuple('testdev', bs=args.batch_size,
                          shuffle=False, drop_last=False),
                dump=os.path.join(args.output, 'testdev_predict.json')
            )
            print(result)
    else:
        # print("Train Oracle: %0.2f" % (gqa.oracle_score(gqa.train_tuple) * 100))
        print('Splits in Train data:', gqa.train_tuple.dataset.splits)
        if gqa.valid_tuple is not None:
            print('Splits in Valid data:', gqa.valid_tuple.dataset.splits)
            print("Valid Oracle: %0.2f" % (gqa.oracle_score(gqa.valid_tuple) * 100))
        else:
            print("DO NOT USE VALIDATION")
        gqa.train(gqa.train_tuple, gqa.valid_tuple)