# -*- coding: utf-8 -*-

"""
@Author             : Bao
@Date               : 2020/8/12
@Desc               : Run a trained PLM fact checker over a dataset file or an
                      in-memory batch of claims.
@Last modified by   : Bao
@Last modified date : 2020/8/20
"""

import os
import sys
import logging
import torch
import numpy as np
from tqdm import tqdm
import tensorflow as tf
import ujson as json
import argparse
import cjjpy as cjj
from itertools import repeat
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
    BertConfig,
    BertTokenizer,
    AutoTokenizer,
    RobertaConfig,
    RobertaTokenizer,
)

try:
    from .modules.data_processor import DataProcessor
    from .plm_checkers import BertChecker, RobertaChecker
    from .utils import read_json_lines, compute_metrics, set_seed
    from .train import do_evaluate
    from ..eval_client.fever_scorer import FeverScorer
except ImportError:
    # Relative imports fail when this file is executed directly (no parent
    # package). Fall back to absolute imports after putting the directories
    # returned by cjj.AbsParentDir (presumably this file's dir and its
    # parent) on sys.path. Catch only ImportError so that genuine errors
    # raised while importing these modules are not silently retried.
    sys.path.append(cjj.AbsParentDir(__file__, '.'))
    sys.path.append(cjj.AbsParentDir(__file__, '..'))
    from eval_client.fever_scorer import FeverScorer
    from modules.data_processor import DataProcessor
    from plm_checkers import BertChecker, RobertaChecker
    from utils import read_json_lines, compute_metrics, set_seed
    from train import do_evaluate


# model_type -> (config class, tokenizer class, checker model class)
MODEL_MAPPING = {
    'bert': (BertConfig, BertTokenizer, BertChecker),
    'roberta': (RobertaConfig, RobertaTokenizer, RobertaChecker),
}

logger = logging.getLogger(__name__)

label2id = {"SUPPORTS": 2, "REFUTES": 0, 'NOT ENOUGH INFO': 1}
id2label = {v: k for k, v in label2id.items()}


class FactChecker:
    """Load a PLM-based fact-checking model from a checkpoint and run inference.

    Args:
        args: argparse-style namespace. Fields read here: ``fc_dir`` (only
            when ``fc_ckpt_dir`` is None), ``model_name_or_path``,
            ``model_type``, ``max_seq1_length``, ``max_seq2_length``,
            ``max_num_questions``, ``cand_k``, ``do_lower_case``,
            ``logic_lambda`` and ``prior``; the whole namespace is also
            forwarded to ``do_evaluate``. ``args.device`` is overwritten with
            CUDA when available, otherwise CPU.
        fc_ckpt_dir: checkpoint directory to load from; defaults to
            ``args.fc_dir``.
        mask_rate: forwarded unchanged to ``DataProcessor``.
    """

    def __init__(self, args, fc_ckpt_dir=None, mask_rate=0.):
        self.data_processor = None
        self.tokenizer = None
        self.model = None
        self.args = args
        self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.ckpt = args.fc_dir if fc_ckpt_dir is None else fc_ckpt_dir
        self.mask_rate = mask_rate

        logger.info('Initializing fact checker.')
        self._prepare_ckpt(self.args.model_name_or_path, self.ckpt)
        self.load_model()

    def _prepare_ckpt(self, model_name_or_path, ckpt_dir):
        """Save the tokenizer of ``model_name_or_path`` into ``ckpt_dir``.

        ``load_model`` loads the tokenizer from the checkpoint directory, so
        the tokenizer files are written there first. Note this writes to
        ``ckpt_dir`` on every construction.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        tokenizer.save_pretrained(ckpt_dir)

    def load_model(self):
        """Build the data processor, tokenizer and model once (no-op if loaded).

        The model is wrapped in ``torch.nn.DataParallel``.
        """
        if self.model is None:
            self.data_processor = DataProcessor(
                self.args.model_name_or_path,
                self.args.max_seq1_length,
                self.args.max_seq2_length,
                self.args.max_num_questions,
                self.args.cand_k,
                mask_rate=self.mask_rate,
            )
            _, tokenizer_class, model_class = MODEL_MAPPING[self.args.model_type]
            self.tokenizer = tokenizer_class.from_pretrained(
                self.ckpt,
                do_lower_case=self.args.do_lower_case,
            )
            self.model = model_class.from_pretrained(
                self.ckpt,
                # A path containing ".ckpt" is treated as a TensorFlow checkpoint.
                from_tf=bool(".ckpt" in self.ckpt),
                logic_lambda=self.args.logic_lambda,
                prior=self.args.prior,
            )
            self.model = torch.nn.DataParallel(self.model)

    def _check(self, inputs: list, batch_size=32, verbose=True):
        """Run the model over ``inputs`` in order and return raw predictions.

        Args:
            inputs: records accepted by
                ``DataProcessor.convert_inputs_to_dataset``.
            batch_size: inference batch size.
            verbose: forwarded to the dataset conversion; also shows a tqdm
                progress bar.

        Returns:
            ``(y_predicted, z_predicted, m_attn, mask)`` as returned by
            ``do_evaluate(..., with_label=False)``.
        """
        dataset = self.data_processor.convert_inputs_to_dataset(inputs, self.tokenizer, verbose=verbose)
        # Sequential sampling keeps predictions aligned with ``inputs``.
        sampler = SequentialSampler(dataset)
        dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

        with torch.no_grad():
            self.model.to(self.args.device)
            self.model.eval()
            batches = tqdm(dataloader, desc="Fact Checking") if verbose else dataloader
            _, y_predicted, z_predicted, m_attn, mask = do_evaluate(
                batches, self.model, self.args, during_training=False, with_label=False
            )

        return y_predicted, z_predicted, m_attn, mask

    def check_from_file(self, in_filename, out_filename, batch_size, verbose=False):
        """Predict labels for every record in a JSON-lines file and write them out.

        Each output line is ``{'id', 'predicted_label', 'predicted_evidence'}``;
        with ``verbose`` it additionally carries ``z_prob`` / ``m_attn``
        truncated to the number of unmasked positions, and the first five
        predictions are printed.

        The split is inferred from the *basename* of ``in_filename``:
        a name containing ``test`` makes the output follow the record order
        of ``$PJ_HOME/data/fever/shared_task_test.jsonl``; a name containing
        ``dev`` or ``val`` (with labelled inputs) triggers FEVER scoring.

        Args:
            in_filename: input file read with ``read_json_lines``; records
                need ``id`` (and ``claim`` when verbose), optionally ``label``
                and ``predicted_evidence``.
            out_filename: destination JSON-lines file; its directory is
                created if needed.
            batch_size: inference batch size.
            verbose: include phrasal veracity outputs and print samples.

        Returns:
            FEVER scores merged with accuracy metrics for labelled dev/val
            files, otherwise ``None``.
        """
        # Inspect only the file name: substrings in parent directories
        # (e.g. a checkout under ``~/dev`` or ``/mnt/test``) must not change
        # how the split is handled.
        in_name = os.path.basename(in_filename)
        if 'test' in in_name:
            raw_inp = f'{os.environ["PJ_HOME"]}/data/fever/shared_task_test.jsonl'
        else:
            raw_inp = None

        out_dir = os.path.dirname(out_filename)
        if out_dir:  # a bare file name has no directory to create
            tf.io.gfile.makedirs(out_dir)

        inputs = list(read_json_lines(in_filename))
        y_predicted, z_predicted, m_attn, mask = self._check(inputs, batch_size)
        # Substitute endless None streams so the zip below still works when
        # the model does not produce phrasal outputs.
        z_predicted = repeat(None) if z_predicted is None else z_predicted
        m_attn = repeat(None) if m_attn is None else m_attn

        ordered_results = {}
        with_label = inputs[0].get('label') is not None

        if with_label:
            label_truth = [label2id[x['label']] for x in inputs]
            # NOTE(review): z_predicted may already be repeat(None) here;
            # confirm compute_metrics does not try to exhaust it.
            _, acc_results = compute_metrics(label_truth, y_predicted, z_predicted, mask)
        else:
            acc_results = {}

        for i, (inp, y, z, attn, _mask) in \
                enumerate(zip(inputs, y_predicted, z_predicted, m_attn, mask)):
            result = {'id': inp['id'],
                      'predicted_label': id2label[y],
                      'predicted_evidence': inp.get('predicted_evidence', [])}
            if verbose:
                if i < 5:
                    print("{}\t{}\t{}".format(inp.get("id", i), inp["claim"], y))
                if z is not None and attn is not None:
                    # Keep only the positions that the mask marks as valid.
                    n_valid = torch.tensor(_mask).sum()
                    result.update({
                        'z_prob': z[:n_valid],
                        'm_attn': attn[:n_valid],
                    })
            ordered_results[inp['id']] = result

        with tf.io.gfile.GFile(out_filename, 'w') as fout:
            if raw_inp:
                # Emit results in the order of the official test file.
                with tf.io.gfile.GFile(raw_inp) as f:
                    for line in f:
                        raw_js = json.loads(line)
                        fout.write(json.dumps(ordered_results[raw_js['id']]) + '\n')
            else:
                for k in ordered_results:
                    fout.write(json.dumps(ordered_results[k]) + '\n')

        if ('dev' in in_name or 'val' in in_name) and with_label:
            scorer = FeverScorer()
            fever_results = scorer.get_scores(out_filename)
            fever_results.update(acc_results)

            print(fever_results)
            return fever_results

    def check_from_batch(self, inputs: list, verbose=False):
        """Check an in-memory batch of records in a single forward batch.

        Returns:
            ``(y_predicted, z_predicted, m_attn)``; the mask from ``_check``
            is not returned.
        """
        y_predicted, z_predicted, m_attn, mask = self._check(inputs, len(inputs), verbose)
        return y_predicted, z_predicted, m_attn


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', '-i', required=True, type=str,
                        choices=['val', 'eval', 'test', 'demo'])
    parser.add_argument('--output', '-o', default='none', type=str)
    parser.add_argument('--ckpt', '-c', required=True, type=str)
    parser.add_argument('--model_type', default='roberta', type=str,
                        choices=['roberta', 'bert'])
    parser.add_argument('--model_name_or_path', default='roberta-large', type=str)
    parser.add_argument('--verbose', '-v', action='store_true', default=False,
                        help='whether output phrasal veracity or not')
    parser.add_argument('--logic_lambda', '-l', required=True, type=float)
    parser.add_argument('--prior', default='random', type=str,
                        choices=['nli', 'uniform', 'logic', 'random'],
                        help='type of prior distribution')
    parser.add_argument('--mask_rate', '-m', default=0., type=float)
    parser.add_argument('--cand_k', '-k', default=3, type=int)
    parser.add_argument('--max_seq1_length', default=256, type=int)
    parser.add_argument('--max_seq2_length', default=128, type=int)
    parser.add_argument('--max_num_questions', default=8, type=int)
    parser.add_argument('--do_lower_case', action='store_true', default=False)
    parser.add_argument('--batch_size', '-b', default=64, type=int)
    # type=int so CLI-supplied values match the int defaults.
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--n_gpu', default=4, type=int)
    args = parser.parse_args()

    set_seed(args)

    if args.output == 'none':
        # Drop all trailing slashes so basename() yields the checkpoint name.
        args.ckpt = args.ckpt.rstrip('/')
        base_name = os.path.basename(args.ckpt)
        args.output = f'{os.environ["PJ_HOME"]}/results/fact_checking/AAAI22/{args.input}.{args.model_name_or_path}_m{args.mask_rate}_l{args.logic_lambda}_{base_name}_{args.prior}.predictions.jsonl'

    # Explicit check instead of ``assert`` so it also runs under ``python -O``.
    if not args.output.endswith('predictions.jsonl'):
        raise ValueError(f"{args.output} must end with predictions.jsonl")

    args.input = f'{os.environ["PJ_HOME"]}/data/fact_checking/v5/{args.input}.json'

    checker = FactChecker(args, args.ckpt, args.mask_rate)
    fever_results = checker.check_from_file(args.input, args.output, args.batch_size, args.verbose)
    cjj.lark(f"{args.output}: {fever_results}")