import logging
import argparse
import os
import json
import torch
from tqdm import tqdm
from transformers import BertTokenizer
from .models import inference_model
from .data_loader import DataLoaderTest
from .bert_model import BertForSequenceEncoder
logger = logging.getLogger(__name__)
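
# Write one JSON line per claim: {"id": ..., "evidence": [...]}, keeping only the
# top `evi_num` evidence entries, sorted by predicted relevance score (the last
# element appended to each evidence record).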
def save_to_file(all_predict, outpath, evi_num):
    with open(outpath, "w") as out:
        for key, values in all_predict.items():
            sorted_values = sorted(values, key=lambda x: x[-1], reverse=True)
            data = json.dumps({"id": key, "evidence": sorted_values[:evi_num]})
            out.write(data + "\n")
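
# Score every (claim, evidence) pair yielded by the test loader and group the
# resulting probabilities by claim id.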
def eval_model(model, validset_reader):
    model.eval()
    all_predict = dict()
    # Disable gradient tracking during inference.
    with torch.no_grad():
        for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in tqdm(validset_reader):
            probs = model(inp_tensor, msk_tensor, seg_tensor)
            probs = probs.tolist()
            assert len(probs) == len(evi_list)
            for i in range(len(probs)):
                if ids[i] not in all_predict:
                    all_predict[ids[i]] = []
                # if probs[i][1] >= probs[i][0]:
                all_predict[ids[i]].append(evi_list[i] + [probs[i]])
    return all_predict
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', help='path to the test set')
    parser.add_argument('--name', help='name of the output prediction file')
    parser.add_argument("--batch_size", default=32, type=int, help="Total batch size for evaluation.")
    parser.add_argument('--outdir', required=True, help='path to output directory')
    parser.add_argument('--bert_pretrain', required=True)
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--dropout', type=float, default=0.6, help='Dropout.')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA.')
    parser.add_argument("--bert_hidden_dim", default=768, type=int, help="Hidden dimension of the BERT encoder.")
    parser.add_argument("--layer", type=int, default=1, help='Graph layer.')
    parser.add_argument("--num_labels", type=int, default=3)
    parser.add_argument("--evi_num", type=int, default=5, help='Evidence num.')
    parser.add_argument("--threshold", type=float, default=0.0, help='Evidence score threshold.')
    parser.add_argument("--max_len", default=120, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    args = parser.parse_args()
    if not os.path.exists(args.outdir):
        os.mkdir(args.outdir)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    handlers = [logging.FileHandler(os.path.abspath(args.outdir) + '/train_log.txt'), logging.StreamHandler()]
    logging.basicConfig(format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.DEBUG,
                        datefmt='%d-%m-%Y %H:%M:%S', handlers=handlers)
    logger.info(args)
    logger.info('Start testing!')
    tokenizer = BertTokenizer.from_pretrained(args.bert_pretrain, do_lower_case=False)
    logger.info("loading test set")
    validset_reader = DataLoaderTest(args.test_path, tokenizer, args, batch_size=args.batch_size)
    logger.info('initializing estimator model')
    bert_model = BertForSequenceEncoder.from_pretrained(args.bert_pretrain)
    if args.cuda:
        bert_model = bert_model.cuda()
    model = inference_model(bert_model, args)
    # Load the checkpoint onto CPU when CUDA is disabled so --no-cuda actually works.
    checkpoint = torch.load(args.checkpoint, map_location='cuda' if args.cuda else 'cpu')
    model.load_state_dict(checkpoint['model'])
    if args.cuda:
        model = model.cuda()
    logger.info('Start eval!')
    save_path = args.outdir + "/" + args.name
    predict_dict = eval_model(model, validset_reader)
    save_to_file(predict_dict, save_path, args.evi_num)
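
# A hypothetical invocation for reference; the module path and the file/checkpoint
# names below are illustrative, not taken from the repository:
#   python -m evidence_retrieval.test \
#       --test_path data/test.json --name evidence_predictions.json \
#       --outdir output/ --bert_pretrain bert_base --checkpoint checkpoint/best.pth.tar \
#       --batch_size 32 --evi_num 5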