import logging
import argparse
import os
import json

import torch
from tqdm import tqdm
from transformers import BertTokenizer

from .models import inference_model
from .data_loader import DataLoaderTest
from .bert_model import BertForSequenceEncoder

logger = logging.getLogger(__name__)


def save_to_file(all_predict, outpath, evi_num):
    """Write the highest-scoring evidence of every id to a JSON-lines file.

    Args:
        all_predict: Mapping from an id to a list of evidence entries. Each
            entry is a sequence whose last element is the score used for
            ranking.
        outpath: Path of the file to (over)write.
        evi_num: Maximum number of entries kept per id.

    Each output line is a JSON object of the form
    ``{"id": <id>, "evidence": [<entry>, ...]}`` with the entries sorted by
    descending score.
    """
    with open(outpath, "w") as out:
        for key, values in all_predict.items():
            ranked = sorted(values, key=lambda entry: entry[-1], reverse=True)
            record = json.dumps({"id": key, "evidence": ranked[:evi_num]})
            out.write(record + "\n")


def eval_model(model, validset_reader):
    """Run ``model`` over every batch and collect scored evidence per id.

    Args:
        model: Callable module invoked as ``model(inp, msk, seg)``; its result
            must provide ``.tolist()`` and yield one item per evidence entry.
        validset_reader: Iterable of 5-tuples
            ``(inp_tensor, msk_tensor, seg_tensor, ids, evi_list)``.

    Returns:
        dict mapping each id to a list of ``evidence + [prob]`` lists, in the
        order the entries were encountered.
    """
    model.eval()
    all_predict = dict()
    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in tqdm(validset_reader):
            probs = model(inp_tensor, msk_tensor, seg_tensor).tolist()
            assert len(probs) == len(evi_list)
            for example_id, evidence, prob in zip(ids, evi_list, probs):
                # The score is appended last so save_to_file can rank on [-1].
                all_predict.setdefault(example_id, []).append(evidence + [prob])
    return all_predict


def _build_arg_parser():
    """Create the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', help='path to the test data')
    # Required: the name is concatenated into the output path after inference,
    # so a missing value used to fail only at the very end with a TypeError.
    parser.add_argument('--name', required=True,
                        help='name of the output file written inside --outdir')
    parser.add_argument("--batch_size", default=32, type=int,
                        help="Batch size for evaluation.")
    parser.add_argument('--outdir', required=True, help='path to output directory')
    parser.add_argument('--bert_pretrain', required=True)
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--dropout', type=float, default=0.6, help='Dropout.')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='Disables CUDA training.')
    parser.add_argument("--bert_hidden_dim", default=768, type=int,
                        help="BERT hidden dimension.")
    parser.add_argument("--layer", type=int, default=1, help='Graph Layer.')
    parser.add_argument("--num_labels", type=int, default=3)
    parser.add_argument("--evi_num", type=int, default=5, help='Evidence num.')
    # NOTE(review): this help text duplicates --evi_num; the option is not read
    # in this file, so its exact meaning should be confirmed before rewording.
    parser.add_argument("--threshold", type=float, default=0.0, help='Evidence num.')
    parser.add_argument("--max_len", default=120, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    return parser


def main():
    """Load the checkpoint, score the test set and write the ranked evidence."""
    args = _build_arg_parser().parse_args()

    # makedirs also creates missing parent directories and tolerates reruns.
    os.makedirs(args.outdir, exist_ok=True)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # Honour --no-cuda / missing GPU instead of calling .cuda() unconditionally.
    device = torch.device("cuda" if args.cuda else "cpu")

    handlers = [logging.FileHandler(os.path.abspath(args.outdir) + '/train_log.txt'),
                logging.StreamHandler()]
    logging.basicConfig(format='[%(asctime)s] %(levelname)s: %(message)s',
                        level=logging.DEBUG,
                        datefmt='%d-%m-%Y %H:%M:%S',
                        handlers=handlers)
    logger.info(args)
    logger.info('Start training!')

    tokenizer = BertTokenizer.from_pretrained(args.bert_pretrain, do_lower_case=False)
    logger.info("loading training set")
    # NOTE(review): the reader receives ``args`` (including ``args.cuda``);
    # confirm it places its tensors on the same device as the model.
    validset_reader = DataLoaderTest(args.test_path, tokenizer, args,
                                     batch_size=args.batch_size)

    logger.info('initializing estimator model')
    bert_model = BertForSequenceEncoder.from_pretrained(args.bert_pretrain)
    bert_model = bert_model.to(device)
    model = inference_model(bert_model, args)
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    # map_location lets GPU-saved checkpoints load on CPU-only hosts.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)

    logger.info('Start eval!')
    save_path = args.outdir + "/" + args.name
    predict_dict = eval_model(model, validset_reader)
    save_to_file(predict_dict, save_path, args.evi_num)


if __name__ == "__main__":
    main()