Spaces:
Build error
Build error
# -*- coding:utf-8 -*- | |
""" | |
@Author : Bao | |
@Date : 2021/9/7 | |
@Desc : | |
@Last modified by : Bao | |
@Last modified date : 2021/9/7 | |
""" | |
import json | |
import numpy as np | |
import argparse | |
from collections import defaultdict | |
from sklearn.metrics import precision_recall_fscore_support | |
# ref --> label 1, nei & sup --> label 0 | |
idx2label = {0: 1, 1: 0, 2: 0} | |
def read_json_lines(filename, mode='r', encoding='utf-8', skip=0): | |
with open(filename, mode, encoding=encoding) as fin: | |
for line in fin: | |
if skip > 0: | |
skip -= 1 | |
continue | |
yield json.loads(line) | |
def process(filein): | |
id2info = defaultdict(dict) | |
for line in read_json_lines('eval.human.ref.merged.json'): | |
labels = [0] * len(line['questions']) | |
for cul in line['culprit']: | |
labels[cul] = 1 | |
id2info[line['id']].update({'id': line['id'], 'labels': labels}) | |
for line in read_json_lines(filein): | |
if line['id'] not in id2info: continue | |
predicted = [idx2label[_] for _ in np.argmax(line['z_prob'], axis=-1)] | |
id2info[line['id']]['predicted'] = predicted | |
ps, rs, fs = [], [], [] | |
for info in id2info.values(): | |
p, r, f, _ = precision_recall_fscore_support(info['labels'], info['predicted'], average='binary') | |
ps.append(p) | |
rs.append(r) | |
fs.append(f) | |
print(filein) | |
print('Precision: {}'.format(sum(ps) / len(ps))) | |
print('Recall: {}'.format(sum(rs) / len(rs))) | |
print('F1: {}'.format(sum(fs) / len(fs))) | |
return sum(ps) / len(ps), sum(rs) / len(rs), sum(fs) / len(fs) | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser() | |
parser.add_argument('-i', type=str, help='predicted jsonl file with phrasal veracity predictions.') | |
args = parser.parse_args() | |
process(args.i) | |