# coding=utf-8 # Copyleft 2019 project LXRT. import os import collections import torch import torch.nn as nn from torch.utils.data.dataloader import DataLoader from tqdm import tqdm from ..param import args from ..pretrain.qa_answer_table import load_lxmert_qa from .vqa_model import VQAModel from .vqa_data import VQADataset, VQATorchDataset, VQAEvaluator DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator') def get_data_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple: dset = VQADataset(splits) tset = VQATorchDataset(dset) evaluator = VQAEvaluator(dset) data_loader = DataLoader( tset, batch_size=bs, shuffle=shuffle, num_workers=args.num_workers, drop_last=drop_last, pin_memory=True ) return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator) class VQA: def __init__(self): # Datasets self.train_tuple = get_data_tuple( args.train, bs=args.batch_size, shuffle=True, drop_last=True ) if args.valid != "": self.valid_tuple = get_data_tuple( args.valid, bs=1024, shuffle=False, drop_last=False ) else: self.valid_tuple = None # Model self.model = VQAModel(self.train_tuple.dataset.num_answers) # Load pre-trained weights if args.load_lxmert is not None: self.model.lxrt_encoder.load(args.load_lxmert) if args.load_lxmert_qa is not None: load_lxmert_qa(args.load_lxmert_qa, self.model, label2ans=self.train_tuple.dataset.label2ans) # GPU options self.model = self.model.cuda() if args.multiGPU: self.model.lxrt_encoder.multi_gpu() # Loss and Optimizer self.bce_loss = nn.BCEWithLogitsLoss() if 'bert' in args.optim: batch_per_epoch = len(self.train_tuple.loader) t_total = int(batch_per_epoch * args.epochs) print("BertAdam Total Iters: %d" % t_total) from ..lxrt.optimization import BertAdam self.optim = BertAdam(list(self.model.parameters()), lr=args.lr, warmup=0.1, t_total=t_total) else: self.optim = args.optimizer(self.model.parameters(), args.lr) # Output Directory self.output = args.output os.makedirs(self.output, exist_ok=True) def train(self, train_tuple, eval_tuple): dset, loader, evaluator = train_tuple iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x) best_valid = 0. for epoch in range(args.epochs): quesid2ans = {} for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)): self.model.train() self.optim.zero_grad() feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda() logit = self.model(feats, boxes, sent) assert logit.dim() == target.dim() == 2 loss = self.bce_loss(logit, target) loss = loss * logit.size(1) loss.backward() nn.utils.clip_grad_norm_(self.model.parameters(), 5.) self.optim.step() score, label = logit.max(1) for qid, l in zip(ques_id, label.cpu().numpy()): ans = dset.label2ans[l] quesid2ans[qid.item()] = ans log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.) if self.valid_tuple is not None: # Do Validation valid_score = self.evaluate(eval_tuple) if valid_score > best_valid: best_valid = valid_score self.save("BEST") log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \ "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.) print(log_str, end='') with open(self.output + "/log.log", 'a') as f: f.write(log_str) f.flush() self.save("LAST") def predict(self, eval_tuple: DataTuple, dump=None): """ Predict the answers to questions in a data split. :param eval_tuple: The data tuple to be evaluated. :param dump: The path of saved file to dump results. :return: A dict of question_id to answer. """ self.model.eval() dset, loader, evaluator = eval_tuple quesid2ans = {} for i, datum_tuple in enumerate(loader): ques_id, feats, boxes, sent = datum_tuple[:4] # Avoid seeing ground truth with torch.no_grad(): feats, boxes = feats.cuda(), boxes.cuda() logit = self.model(feats, boxes, sent) score, label = logit.max(1) for qid, l in zip(ques_id, label.cpu().numpy()): ans = dset.label2ans[l] quesid2ans[qid.item()] = ans if dump is not None: evaluator.dump_result(quesid2ans, dump) return quesid2ans def evaluate(self, eval_tuple: DataTuple, dump=None): """Evaluate all data in data_tuple.""" quesid2ans = self.predict(eval_tuple, dump) return eval_tuple.evaluator.evaluate(quesid2ans) @staticmethod def oracle_score(data_tuple): dset, loader, evaluator = data_tuple quesid2ans = {} for i, (ques_id, feats, boxes, sent, target) in enumerate(loader): _, label = target.max(1) for qid, l in zip(ques_id, label.cpu().numpy()): ans = dset.label2ans[l] quesid2ans[qid.item()] = ans return evaluator.evaluate(quesid2ans) def save(self, name): torch.save(self.model.state_dict(), os.path.join(self.output, "%s.pth" % name)) def load(self, path): print("Load model from %s" % path) state_dict = torch.load("%s.pth" % path) self.model.load_state_dict(state_dict) if __name__ == "__main__": # Build Class vqa = VQA() # Load VQA model weights # Note: It is different from loading lxmert pre-trained weights. if args.load is not None: vqa.load(args.load) # Test or Train if args.test is not None: args.fast = args.tiny = False # Always loading all data in test if 'test' in args.test: vqa.predict( get_data_tuple(args.test, bs=950, shuffle=False, drop_last=False), dump=os.path.join(args.output, 'test_predict.json') ) elif 'val' in args.test: # Since part of valididation data are used in pre-training/fine-tuning, # only validate on the minival set. result = vqa.evaluate( get_data_tuple('minival', bs=950, shuffle=False, drop_last=False), dump=os.path.join(args.output, 'minival_predict.json') ) print(result) else: assert False, "No such test option for %s" % args.test else: print('Splits in Train data:', vqa.train_tuple.dataset.splits) if vqa.valid_tuple is not None: print('Splits in Valid data:', vqa.valid_tuple.dataset.splits) print("Valid Oracle: %0.2f" % (vqa.oracle_score(vqa.valid_tuple) * 100)) else: print("DO NOT USE VALIDATION") vqa.train(vqa.train_tuple, vqa.valid_tuple)