Spaces:
Sleeping
Sleeping
# coding=utf-8
# Copyleft 2019 project LXRT.
import os | |
import collections | |
import torch | |
from tqdm import tqdm | |
import torch.nn as nn | |
from torch.utils.data.dataloader import DataLoader | |
from param import args | |
from pretrain.qa_answer_table import load_lxmert_qa | |
from tasks.gqa_model import GQAModel | |
from tasks.gqa_data import GQADataset, GQATorchDataset, GQAEvaluator | |
# Bundle of the three objects that travel together for one data split:
# the raw dataset, the DataLoader built on top of it, and its evaluator.
DataTuple = collections.namedtuple(
    "DataTuple", ["dataset", "loader", "evaluator"]
)
def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
    """Build the dataset, loader and evaluator for the given split(s).

    Args:
        splits: Split specification forwarded to ``GQADataset``.
        bs: Batch size used by the ``DataLoader``.
        shuffle: Whether the loader shuffles the data.
        drop_last: Whether the loader drops the last incomplete batch.

    Returns:
        A ``DataTuple`` holding the ``GQADataset``, a ``DataLoader`` over the
        corresponding ``GQATorchDataset``, and a ``GQAEvaluator``.
    """
    raw_dataset = GQADataset(splits)
    torch_dataset = GQATorchDataset(raw_dataset)
    split_evaluator = GQAEvaluator(raw_dataset)

    # Worker count comes from the global args; memory is always pinned.
    loader_options = {
        "batch_size": bs,
        "shuffle": shuffle,
        "num_workers": args.num_workers,
        "drop_last": drop_last,
        "pin_memory": True,
    }
    batch_loader = DataLoader(torch_dataset, **loader_options)

    return DataTuple(
        dataset=raw_dataset,
        loader=batch_loader,
        evaluator=split_evaluator,
    )
class GQA:
    """Training / evaluation driver for ``GQAModel`` on the GQA task.

    Construction reads the global ``args`` to build the training (and
    optional validation) data tuples, create the model, optionally load
    pre-trained weights, move the model to the GPU, and set up the losses
    and the optimizer.
    """

    def __init__(self):
        # Training data is shuffled and incomplete last batches are dropped.
        self.train_tuple = get_tuple(
            args.train, bs=args.batch_size, shuffle=True, drop_last=True
        )
        if args.valid != "":
            # Larger evaluation batches when running on multiple GPUs.
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(
                args.valid, bs=valid_bsize,
                shuffle=False, drop_last=False
            )
        else:
            self.valid_tuple = None

        self.model = GQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        """Train for ``args.epochs`` epochs, validating after each epoch.

        Args:
            train_tuple: ``(dataset, loader, evaluator)`` used for training.
            eval_tuple: ``(dataset, loader, evaluator)`` used for validation,
                or ``None`` to skip validation.

        After every epoch the training score (and, when validating, the
        validation and best scores) are printed and appended to
        ``<output>/log.log``. The weights are saved as "BEST" whenever the
        validation score strictly improves, and as "LAST" once training ends.
        """
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            # evaluate() leaves the model in eval mode, so switch back to
            # training mode at the start of every epoch.
            self.model.train()
            quesid2ans = {}
            for _, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                if args.mce_loss:
                    # Reduce the per-answer target to its arg-max index.
                    _, target = target.max(1)
                    loss = self.mce_loss(logit, target) * logit.size(1)
                else:
                    loss = self.bce_loss(logit, target)
                    loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                # Record the highest-scoring answer for the training score.
                _, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    quesid2ans[qid] = dset.label2ans[l]

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            # Gate on the tuple that is actually evaluated, not on
            # self.valid_tuple, so the argument alone decides validation.
            if eval_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            # The context manager flushes and closes the file on exit.
            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)

        self.save("LAST")

    def predict(self, eval_tuple: "DataTuple", dump=None):
        """Predict an answer for every question yielded by the loader.

        Args:
            eval_tuple: ``(dataset, loader, evaluator)`` to run on.
            dump: Optional path; when given, the predictions are passed to
                ``evaluator.dump_result`` together with this path.

        Returns:
            Dict mapping each question id to the predicted answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for datum_tuple in loader:
            ques_id, feats, boxes, sent = datum_tuple[:4]   # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                _, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    quesid2ans[qid] = dset.label2ans[l]
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: "DataTuple", dump=None):
        """Run ``predict`` and return ``evaluator.evaluate`` of the result.

        Args:
            eval_tuple: ``(dataset, loader, evaluator)`` to run on.
            dump: Optional path forwarded to ``predict``.
        """
        _, _, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        """Score obtained by always answering with the arg-max of the target.

        Declared as a static method because it uses no instance state; this
        lets it be called both as ``GQA.oracle_score(t)`` and on an instance.

        Args:
            data_tuple: ``(dataset, loader, evaluator)`` whose loader yields
                ``(ques_id, feats, boxes, sent, target)`` batches.

        Returns:
            The value of ``evaluator.evaluate`` on the oracle answers.
        """
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for ques_id, _feats, _boxes, _sent, target in loader:
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                quesid2ans[qid] = dset.label2ans[l]
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        """Save the model's state dict to ``<output>/<name>.pth``."""
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        """Load model weights from ``<path>.pth``.

        Keys containing ``.module`` are renamed with that fragment removed
        before loading. Loading is non-strict, so missing or unexpected keys
        do not raise.
        """
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        # Iterate over a snapshot of the keys because the dict is mutated.
        for key in list(state_dict.keys()):
            if '.module' in key:
                state_dict[key.replace('.module', '')] = state_dict.pop(key)
        self.model.load_state_dict(state_dict, strict=False)
if __name__ == "__main__":
    # Build Class
    gqa = GQA()

    # Load Model
    if args.load is not None:
        gqa.load(args.load)

    # Test or Train
    if args.test is not None:
        args.fast = args.tiny = False       # Always loading all data in test
        if 'submit' in args.test:
            # Dump predictions for the requested test split(s).
            gqa.predict(
                get_tuple(args.test, bs=args.batch_size,
                          shuffle=False, drop_last=False),
                dump=os.path.join(args.output, 'submit_predict.json')
            )
        if 'testdev' in args.test:
            # Evaluate on 'testdev', dump the predictions and print the score.
            result = gqa.evaluate(
                get_tuple('testdev', bs=args.batch_size,
                          shuffle=False, drop_last=False),
                dump=os.path.join(args.output, 'testdev_predict.json')
            )
            print(result)
    else:
        # print("Train Oracle: %0.2f" % (GQA.oracle_score(gqa.train_tuple) * 100))
        print('Splits in Train data:', gqa.train_tuple.dataset.splits)
        if gqa.valid_tuple is not None:
            print('Splits in Valid data:', gqa.valid_tuple.dataset.splits)
            # oracle_score takes only the data tuple (no instance), so it is
            # called through the class; calling it on the instance would pass
            # the instance as an extra positional argument.
            print("Valid Oracle: %0.2f" % (GQA.oracle_score(gqa.valid_tuple) * 100))
        else:
            print("DO NOT USE VALIDATION")
        gqa.train(gqa.train_tuple, gqa.valid_tuple)