Spaces:
Build error
Build error
File size: 8,653 Bytes
7f7285f 20e9c0d 7f7285f 20e9c0d 7f7285f f18943d 7f7285f f18943d 7f7285f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
# -*- coding: utf-8 -*-
"""
@Author : Bao
@Date : 2020/8/12
@Desc :
@Last modified by : Bao
@Last modified date : 2020/8/20
"""
import os
import sys
import logging
import torch
import numpy as np
from tqdm import tqdm
import tensorflow as tf
import ujson as json
import argparse
import cjjpy as cjj
from itertools import repeat
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
BertConfig, BertTokenizer, AutoTokenizer,
RobertaConfig, RobertaTokenizer,
)
try:
from .modules.data_processor import DataProcessor
from .plm_checkers import BertChecker, RobertaChecker
from .utils import read_json_lines, compute_metrics, set_seed
from .train import do_evaluate
from ..eval_client.fever_scorer import FeverScorer
except:
sys.path.append(cjj.AbsParentDir(__file__, '.'))
sys.path.append(cjj.AbsParentDir(__file__, '..'))
from eval_client.fever_scorer import FeverScorer
from modules.data_processor import DataProcessor
from plm_checkers import BertChecker, RobertaChecker
from utils import read_json_lines, compute_metrics, set_seed
from train import do_evaluate
# --model_type choice -> (config class, tokenizer class, checker model class).
MODEL_MAPPING = dict(
    bert=(BertConfig, BertTokenizer, BertChecker),
    roberta=(RobertaConfig, RobertaTokenizer, RobertaChecker),
)

logger = logging.getLogger(__name__)

# Verdict label <-> class index used when decoding model predictions.
label2id = {
    "SUPPORTS": 2,
    "REFUTES": 0,
    "NOT ENOUGH INFO": 1,
}
# Inverse lookup: class index -> verdict label (same iteration order as label2id).
id2label = dict(zip(label2id.values(), label2id.keys()))
class FactChecker:
    """Batch front end around a pretrained-LM fact-checking model.

    On construction it copies the base tokenizer into the checkpoint
    directory, then loads the data processor, tokenizer and checker model
    selected by ``args.model_type`` from ``MODEL_MAPPING``.
    """

    def __init__(self, args, fc_ckpt_dir=None, mask_rate=0.):
        """Create the checker and load its model immediately.

        Args:
            args: argparse-style namespace. Fields read by this class include
                ``fc_dir`` (only when ``fc_ckpt_dir`` is None),
                ``model_name_or_path``, ``model_type``, ``do_lower_case``,
                ``max_seq1_length``, ``max_seq2_length``,
                ``max_num_questions``, ``cand_k``, ``logic_lambda`` and
                ``prior``. The namespace is also passed to ``do_evaluate``.
            fc_ckpt_dir: checkpoint directory; defaults to ``args.fc_dir``.
            mask_rate: forwarded to ``DataProcessor``.
        """
        self.data_processor = None
        self.tokenizer = None
        self.model = None
        self.args = args
        # Written onto the caller's namespace: _check reads it from here.
        self.args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.ckpt = args.fc_dir if fc_ckpt_dir is None else fc_ckpt_dir
        self.mask_rate = mask_rate

        logger.info('Initializing fact checker.')
        self._prepare_ckpt(self.args.model_name_or_path, self.ckpt)
        self.load_model()

    def _prepare_ckpt(self, model_name_or_path, ckpt_dir):
        """Save the tokenizer of ``model_name_or_path`` into ``ckpt_dir``.

        ``load_model`` loads the tokenizer from the checkpoint directory, so
        its files must be present there.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        tokenizer.save_pretrained(ckpt_dir)

    def load_model(self):
        """Load data processor, tokenizer and model; no-op if already loaded."""
        if self.model is not None:
            return

        self.data_processor = DataProcessor(
            self.args.model_name_or_path,
            self.args.max_seq1_length,
            self.args.max_seq2_length,
            self.args.max_num_questions,
            self.args.cand_k,
            mask_rate=self.mask_rate
        )
        _, tokenizer_class, model_class = MODEL_MAPPING[self.args.model_type]
        self.tokenizer = tokenizer_class.from_pretrained(
            self.ckpt,
            do_lower_case=self.args.do_lower_case
        )
        self.model = model_class.from_pretrained(
            self.ckpt,
            # A '.ckpt' path is treated as a TensorFlow checkpoint.
            from_tf=bool(".ckpt" in self.ckpt),
            logic_lambda=self.args.logic_lambda,
            prior=self.args.prior,
        )
        self.model = torch.nn.DataParallel(self.model)

    def _check(self, inputs: list, batch_size=32, verbose=True):
        """Run the model over ``inputs`` in order, without gradients.

        Args:
            inputs: records accepted by ``DataProcessor.convert_inputs_to_dataset``.
            batch_size: DataLoader batch size.
            verbose: show a tqdm progress bar and forward verbosity to the
                data processor.

        Returns:
            ``(y_predicted, z_predicted, m_attn, mask)`` as produced by
            ``do_evaluate``. ``z_predicted`` / ``m_attn`` may be None
            (``check_from_file`` handles that case).
        """
        dataset = self.data_processor.convert_inputs_to_dataset(inputs, self.tokenizer, verbose=verbose)
        dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)

        with torch.no_grad():
            self.model.to(self.args.device)
            self.model.eval()
            batches = tqdm(dataloader, desc="Fact Checking") if verbose else dataloader
            _, y_predicted, z_predicted, m_attn, mask = do_evaluate(
                batches, self.model, self.args, during_training=False, with_label=False)
        return y_predicted, z_predicted, m_attn, mask

    def check_from_file(self, in_filename, out_filename, batch_size, verbose=False):
        """Check every record of a JSON-lines file and write predictions.

        Each output line holds ``id``, ``predicted_label`` and
        ``predicted_evidence``. In verbose mode it also holds ``z_prob`` and
        ``m_attn`` when the model produced them.

        Args:
            in_filename: JSON-lines input. Records are keyed by ``id``, may
                carry ``label`` / ``predicted_evidence``, and ``claim`` is
                printed in verbose mode.
            out_filename: JSON-lines output path (parent dirs are created).
            batch_size: inference batch size.
            verbose: print the first 5 predictions and include phrase-level
                outputs in the written records.

        Returns:
            FEVER scores merged with ``compute_metrics`` accuracy results
            when ``in_filename`` contains 'dev' or 'val' and the inputs are
            labelled; otherwise None.

        Raises:
            KeyError: for 'test' files, if an id in the shared-task test
                file has no prediction.
        """
        if 'test' in in_filename:
            # Test predictions are emitted in the official test-file order.
            raw_inp = f'{os.environ["PJ_HOME"]}/data/fever/shared_task_test.jsonl'
        else:
            raw_inp = None

        tf.io.gfile.makedirs(os.path.dirname(out_filename))
        inputs = list(read_json_lines(in_filename))
        y_predicted, z_predicted, m_attn, mask = self._check(inputs, batch_size)
        # Endless None streams keep the zip below aligned when the model
        # produced no phrase-level outputs.
        z_predicted = repeat(None) if z_predicted is None else z_predicted
        m_attn = repeat(None) if m_attn is None else m_attn

        # Guard the inputs[0] peek so an empty input file is not an IndexError.
        with_label = bool(inputs) and inputs[0].get('label') is not None
        if with_label:
            label_truth = [label2id[x['label']] for x in inputs]
            # NOTE(review): when z was None, compute_metrics receives the
            # repeat(None) placeholder rather than None — confirm it copes.
            _, acc_results = compute_metrics(label_truth, y_predicted, z_predicted, mask)
        else:
            acc_results = {}

        ordered_results = {}
        for i, (inp, y, z, attn, _mask) in enumerate(
                zip(inputs, y_predicted, z_predicted, m_attn, mask)):
            result = {'id': inp['id'],
                      'predicted_label': id2label[y],
                      'predicted_evidence': inp.get('predicted_evidence', [])}
            if verbose:
                if i < 5:
                    print("{}\t{}\t{}".format(inp.get("id", i), inp["claim"], y))
                if z is not None and attn is not None:
                    # Truncate to sum(_mask) entries — presumably the number
                    # of real (unpadded) phrases; computed once per record.
                    n_valid = torch.tensor(_mask).sum()
                    result.update({
                        'z_prob': z[:n_valid],
                        'm_attn': attn[:n_valid],
                    })
            ordered_results[inp['id']] = result

        with tf.io.gfile.GFile(out_filename, 'w') as fout:
            if raw_inp:
                with tf.io.gfile.GFile(raw_inp) as f:
                    for line in f:
                        raw_js = json.loads(line)
                        fout.write(json.dumps(ordered_results[raw_js['id']]) + '\n')
            else:
                for result in ordered_results.values():
                    fout.write(json.dumps(result) + '\n')

        if ('dev' in in_filename or 'val' in in_filename) and with_label:
            scorer = FeverScorer()
            fever_results = scorer.get_scores(out_filename)
            fever_results.update(acc_results)
            print(fever_results)
            return fever_results
        return None

    def check_from_batch(self, inputs: list, verbose=False):
        """Check ``inputs`` as a single batch.

        Returns:
            ``(y_predicted, z_predicted, m_attn)`` from ``_check`` (the mask
            is dropped).
        """
        # DataLoader rejects batch_size=0, so clamp for an empty batch.
        y_predicted, z_predicted, m_attn, _ = self._check(inputs, max(len(inputs), 1), verbose)
        return y_predicted, z_predicted, m_attn
if __name__ == '__main__':
    # CLI: run the fact checker over one of the prepared v5 splits.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', '-i', required=True, type=str,
                        choices=['val', 'eval', 'test', 'demo'])
    parser.add_argument('--output', '-o', default='none', type=str)
    parser.add_argument('--ckpt', '-c', required=True, type=str)
    parser.add_argument('--model_type', default='roberta', type=str,
                        choices=['roberta', 'bert'])
    parser.add_argument('--model_name_or_path', default='roberta-large', type=str)
    parser.add_argument('--verbose', '-v', action='store_true', default=False,
                        help='whether output phrasal veracity or not')
    parser.add_argument('--logic_lambda', '-l', required=True, type=float)
    parser.add_argument('--prior', default='random', type=str, choices=['nli', 'uniform', 'logic', 'random'],
                        help='type of prior distribution')
    parser.add_argument('--mask_rate', '-m', default=0., type=float)
    parser.add_argument('--cand_k', '-k', default=3, type=int)
    parser.add_argument('--max_seq1_length', default=256, type=int)
    parser.add_argument('--max_seq2_length', default=128, type=int)
    parser.add_argument('--max_num_questions', default=8, type=int)
    parser.add_argument('--do_lower_case', action='store_true', default=False)
    parser.add_argument('--batch_size', '-b', default=64, type=int)
    # type=int: without it, values given on the command line arrive as str
    # (argparse only leaves the non-string default untouched).
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--n_gpu', default=4, type=int)
    args = parser.parse_args()
    set_seed(args)

    if args.output == 'none':
        # Derive the output path from the run configuration; a single
        # trailing '/' on the checkpoint is dropped so basename is non-empty.
        args.ckpt = args.ckpt[:-1] if args.ckpt.endswith('/') else args.ckpt
        base_name = os.path.basename(args.ckpt)
        args.output = f'{os.environ["PJ_HOME"]}/results/fact_checking/AAAI22/{args.input}.{args.model_name_or_path}_m{args.mask_rate}_l{args.logic_lambda}_{base_name}_{args.prior}.predictions.jsonl'

    # Validate with parser.error rather than assert, which `python -O` strips.
    if not args.output.endswith('predictions.jsonl'):
        parser.error(f"{args.output} must end with predictions.jsonl")

    args.input = f'{os.environ["PJ_HOME"]}/data/fact_checking/v5/{args.input}.json'
    checker = FactChecker(args, args.ckpt, args.mask_rate)
    fever_results = checker.check_from_file(args.input, args.output, args.batch_size, args.verbose)
    cjj.lark(f"{args.output}: {fever_results}")
|