Spaces:
Build error
Build error
File size: 3,735 Bytes
7f7285f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import logging
import argparse
import os
import json
import torch
from tqdm import tqdm
from transformers import BertTokenizer
from .models import inference_model
from .data_loader import DataLoaderTest
from .bert_model import BertForSequenceEncoder
# Module-level logger. When this file runs as a script, its handlers and format
# are configured by the logging.basicConfig call in the __main__ block.
logger = logging.getLogger(__name__)
def save_to_file(all_predict, outpath, evi_num):
    """Write the highest-scoring evidence for every id as JSON Lines.

    Args:
        all_predict: mapping from an id to a list of evidence entries; the
            last element of each entry is used as its ranking score.
        outpath: destination file path; it is created or overwritten.
        evi_num: maximum number of entries kept per id.

    Each output line is ``{"id": <id>, "evidence": [<top entries>]}``, with
    entries ordered from highest to lowest score.
    """
    def _score(entry):
        # The score is appended as the final element of each entry.
        return entry[-1]

    with open(outpath, "w") as out:
        out.writelines(
            json.dumps({"id": record_id,
                        "evidence": sorted(entries, key=_score, reverse=True)[:evi_num]}) + "\n"
            for record_id, entries in all_predict.items()
        )
def eval_model(model, validset_reader):
    """Score every evidence candidate yielded by ``validset_reader``.

    Args:
        model: a torch module called as ``model(inp_tensor, msk_tensor, seg_tensor)``.
            Its output is converted with ``.tolist()`` and must have one element
            per candidate in the batch. Presumably that element is one scalar
            score per candidate (TODO confirm); it is used as the sort key
            downstream.
        validset_reader: iterable of batches
            ``(inp_tensor, msk_tensor, seg_tensor, ids, evi_list)``.
            ``ids[i]`` is the key the i-th candidate is grouped under, and
            ``evi_list[i]`` is a list describing that candidate.

    Returns:
        dict mapping each id to a list of ``evi_list[i] + [score]`` entries, in
        the order the reader yielded them.

    Raises:
        ValueError: if the number of model outputs, ids, and evidence entries
            in a batch do not match.
    """
    model.eval()  # evaluation mode (e.g. dropout disabled)
    all_predict = {}
    # Pure inference: skip autograd bookkeeping so no computation graph is
    # built and kept alive for a backward pass that never happens.
    with torch.no_grad():
        for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in tqdm(validset_reader):
            probs = model(inp_tensor, msk_tensor, seg_tensor).tolist()
            # Explicit check instead of `assert`, which is stripped under -O.
            if not len(probs) == len(evi_list) == len(ids):
                raise ValueError(
                    "batch size mismatch: %d model outputs, %d evidence entries, %d ids"
                    % (len(probs), len(evi_list), len(ids))
                )
            for key, evidence, prob in zip(ids, evi_list, probs):
                all_predict.setdefault(key, []).append(evidence + [prob])
    return all_predict
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # --test_path and --name are used unconditionally below, so require them up
    # front rather than failing later with an opaque TypeError on None.
    parser.add_argument('--test_path', required=True, help='path to the test set')
    parser.add_argument('--name', required=True, help='name of the prediction file written inside --outdir')
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size for evaluation.")
    parser.add_argument('--outdir', required=True, help='path to output directory')
    parser.add_argument('--bert_pretrain', required=True)
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--dropout', type=float, default=0.6, help='Dropout.')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA.')
    parser.add_argument("--bert_hidden_dim", default=768, type=int, help="Hidden dimension of the BERT model.")
    parser.add_argument("--layer", type=int, default=1, help='Graph Layer.')
    parser.add_argument("--num_labels", type=int, default=3)
    parser.add_argument("--evi_num", type=int, default=5, help='Evidence num.')
    parser.add_argument("--threshold", type=float, default=0.0, help='Threshold.')
    parser.add_argument("--max_len", default=120, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    args = parser.parse_args()

    # makedirs also creates missing parent directories and tolerates an
    # already-existing output directory.
    os.makedirs(args.outdir, exist_ok=True)
    # Respect --no-cuda and CPU-only hosts instead of calling .cuda() blindly.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")

    handlers = [logging.FileHandler(os.path.join(os.path.abspath(args.outdir), 'train_log.txt')),
                logging.StreamHandler()]
    logging.basicConfig(format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.DEBUG,
                        datefmt='%d-%m-%Y %H:%M:%S', handlers=handlers)
    logger.info(args)
    logger.info('Start training!')
    tokenizer = BertTokenizer.from_pretrained(args.bert_pretrain, do_lower_case=False)
    logger.info("loading training set")
    # NOTE(review): DataLoaderTest receives args (including args.cuda); confirm it
    # places batches on the same device as the model.
    validset_reader = DataLoaderTest(args.test_path, tokenizer, args, batch_size=args.batch_size)
    logger.info('initializing estimator model')
    bert_model = BertForSequenceEncoder.from_pretrained(args.bert_pretrain).to(device)
    model = inference_model(bert_model, args)
    # Load the checkpoint onto the CPU first so a GPU-saved file also works on
    # CPU-only hosts; load_state_dict copies the tensors into the model's params.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    logger.info('Start eval!')
    save_path = os.path.join(args.outdir, args.name)
    predict_dict = eval_model(model, validset_reader)
    save_to_file(predict_dict, save_path, args.evi_num)