Spaces:
Build error
Build error
File size: 6,868 Bytes
7f7285f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# -*- coding: utf-8 -*-
'''
@Author : Jiangjie Chen
@Time : 2020/9/17 15:55
@Contact : [email protected]
@Description:
'''
import os
import sys
import json
import logging
import cjjpy as cjj

try:
    # Package context: resolve the pipeline clients relative to this module.
    from .qg_client.question_generator import QuestionGenerator
    from .mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one
    from .parsing_client.sentence_parser import SentenceParser, deal_bracket
    from .check_client.fact_checker import FactChecker, id2label
    from .er_client import EvidenceRetrieval
except ImportError:
    # Fallback for when the relative imports cannot be resolved (e.g. the file
    # is executed directly as a script). Only ImportError is caught so that
    # KeyboardInterrupt, SystemExit and genuine runtime failures raised while
    # loading the clients are not silently retried.
    # NOTE(review): cjj.AbsParentDir(__file__, '.') presumably resolves to the
    # directory containing this file - confirm against cjjpy.
    sys.path.append(cjj.AbsParentDir(__file__, '.'))
    from qg_client.question_generator import QuestionGenerator
    from mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one
    from parsing_client.sentence_parser import SentenceParser, deal_bracket
    from check_client.fact_checker import FactChecker, id2label
    from er_client import EvidenceRetrieval
def load_config(config):
    '''
    Build the configuration object used by the pipeline.

    :param config: path to a JSON file (str or os.PathLike), or an already
        parsed configuration (e.g. a dict), which is wrapped unchanged
    :return: cjj.AttrDict created from the configuration
    '''
    if isinstance(config, (str, os.PathLike)):
        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(config, encoding='utf-8') as f:
            config = json.load(f)
    return cjj.AttrDict(config)
class Loren:
    '''
    Claim verification pipeline chaining sentence parsing, question
    generation, answer generation, evidence retrieval and fact checking.
    '''

    # Evidence values treated as "not provided"; evidence is retrieved instead.
    NULL_EVIDENCE = (None, '', 'null', 'NULL', 'Null')
    # Maximum number of '||'-separated sentences kept from string evidence.
    MAX_EVIDENCE_SENTENCES = 5

    def __init__(self, config_file, verbose=True):
        '''
        :param config_file: JSON config path or parsed config, passed to load_config
        :param verbose: log at INFO level when True, WARNING otherwise; also
            forwarded to the fact checker in check()
        :raises KeyError: if the PJ_HOME environment variable is not set
        '''
        self.verbose = verbose
        self.args = load_config(config_file)
        self.sent_client = SentenceParser()
        self.qg_client = QuestionGenerator('t5', verbose=False)
        self.ag_client = AnswerGenerator(self.args.mrc_dir)
        self.fc_client = FactChecker(self.args, self.args.fc_dir)
        self.er_client = EvidenceRetrieval(self.args.er_dir)
        log_level = logging.INFO if self.verbose else logging.WARNING
        self.logger = cjj.init_logger(f'{os.environ["PJ_HOME"]}/results/loren_dev.log',
                                      log_file_level=log_level)
        self.logger.info('*** Loren initialized. ***')

    def check(self, claim, evidence=None):
        '''
        Verify a claim and return the record with all intermediate results.

        :param claim: claim sentence to verify
        :param evidence: optional evidence, see prep()
        :return: dict holding, among others, 'claim_veracity' (label looked up
            in id2label), 'phrase_veracity', 'claim_phrases', 'questions',
            'local_premises' and the cleaned 'evidence' sentences
        '''
        self.logger.info('*** Verifying "%s"... ***', claim)
        js = self.prep(claim, evidence)
        js['id'] = 0
        y_predicted, z_predicted, m_attn = self.fc_client.check_from_batch([js], verbose=self.verbose)
        label = id2label[y_predicted[0]]

        # Update js. The premises are assembled first, while the record still
        # holds the fields exactly as produced by prep().
        clean = self.fc_client.tokenizer.clean_up_tokenization
        js['local_premises'] = assemble_answers_to_one(js, k=3)
        js['evidence'] = [clean(e[2]) for e in js['evidence']]
        js['questions'] = [clean(q) for q in js['questions']]
        js['claim_phrases'] = [clean(a[0]) for a in js['answers']]
        js['local_premises'] = [[clean(a) for a in aa] for aa in js['local_premises']]
        # The attention weights (m_attn) are currently not exported:
        # js['m_attn'] = m_attn[0][:len(js['claim_phrases'])]
        js['phrase_veracity'] = z_predicted[0][:len(js['claim_phrases'])]
        js['claim_veracity'] = label

        self.logger.info(" * Intermediary: %s *", str(js))
        self.logger.info('*** Verification completed: "%s" ***', label)
        return js

    def prep(self, claim, evidence=None):
        '''
        Run every preparation stage that precedes fact checking.

        :param claim: claim sentence
        :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)] if not None
        :return: dict with the claim, evidence, claim phrases ('answers'),
            probing questions and evidential phrases
        '''
        evidence = self._prep_evidence(claim, evidence)
        self.logger.info(' * Evidence prepared. *')
        assert isinstance(evidence, list)
        js = {'claim': claim, 'evidence': evidence}
        js = self._prep_claim_phrases(js)
        self.logger.info(' * Claim phrases prepared. *')
        js = self._prep_questions(js)
        self.logger.info(' * Probing questions prepared. *')
        js = self._prep_evidential_phrases(js)
        self.logger.info(' * Evidential phrases prepared. *')
        return js

    def _prep_claim_phrases(self, js):
        '''
        Extract the claim phrases (noun phrases, verbs, adjectives).

        :param js: dict with 'claim' and 'evidence'
        :return: new dict with the parsed 'claim', the 'evidence', the
            extracted 'answers' and one entry in 'answer_roles' per answer
        :raises ValueError: if nothing was extracted and the claim has no tokens
        '''
        results = self.sent_client.identify_NPs(deal_bracket(js['claim'], True),
                                                candidate_NPs=[x[0] for x in js['evidence']])
        NPs = results['NPs']
        claim = results['text']
        verbs = results['verbs']
        adjs = results['adjs']
        answers = NPs + verbs + adjs
        answer_roles = ['noun'] * len(NPs) + ['verb'] * len(verbs) + ['adj'] * len(adjs)

        if not answers:
            # Fall back to the first word of the claim as the only phrase.
            tokens = js['claim'].split()
            if not tokens:
                raise ValueError('Cannot extract claim phrases from an empty claim.')
            first = tokens[0]
            # The fallback has to be a one-element list so that it lines up
            # with answer_roles; a bare string would be iterated per character.
            # NOTE(review): check() reads a[0] as the phrase text; the
            # (text, start, end) span layout used here is an assumption about
            # what identify_NPs returns - confirm.
            start = max(claim.find(first), 0)
            answers = [(first, start, start + len(first))]
            answer_roles = ['noun']

        return {'claim': claim,
                'evidence': js['evidence'],
                'answers': answers,
                'answer_roles': answer_roles}

    def _prep_questions(self, js):
        '''
        Generate one probing question per claim phrase.

        :param js: dict with 'claim' and 'answers'
        :return: js extended with 'regular_qs', 'cloze_qs' and the assembled
            'questions'; the lists exist (possibly empty) even when the
            generator returns nothing
        '''
        samples = [(js['claim'], [answer]) for answer in js['answers']]
        qa_pairs = self.qg_client.generate(samples)
        if 'questions' not in js:
            # Create the lists up front so later stages never hit a KeyError.
            js['regular_qs'] = []
            js['cloze_qs'] = []
            js['questions'] = []
        for q, clz_q, _answer in qa_pairs:
            js['regular_qs'].append(q)
            js['cloze_qs'].append(clz_q)
            js['questions'].append(self.qg_client.assemble_question(q, clz_q))
        return js

    def _prep_evidential_phrases(self, js):
        '''
        Generate candidate answers to every question from the evidence.

        :param js: dict with 'questions' and 'evidence'
        :return: js extended with 'evidential', one entry per prediction
            returned by the answer generator (empty list if there are none)
        '''
        # The context is identical for all questions: every evidence sentence joined.
        context = " ".join([x[2] for x in js['evidence']])
        examples = [self.ag_client.assemble(q, context) for q in js['questions']]
        predicted = self.ag_client.generate(examples, num_beams=self.args['cand_k'],
                                            num_return_sequences=self.args['cand_k'],
                                            batch_size=2, verbose=False)
        js.setdefault('evidential', []).extend(predicted)
        return js

    def _prep_evidence(self, claim, evidence=None):
        '''
        Normalise the evidence, retrieving it when none is given.

        :param claim: claim sentence, used for retrieval only
        :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)]
        :return: [entity, num, evidence, (prob)]
        '''
        if evidence in self.NULL_EVIDENCE:
            retrieved = self.er_client.retrieve(claim)
            return [(ev[0], ev[1], deal_bracket(ev[2], True, ev[0])) for ev in retrieved]
        if isinstance(evidence, str):
            # String evidence carries no entity, hence the "None" placeholder.
            sentences = evidence.split('||')[:self.MAX_EVIDENCE_SENTENCES]
            return [("None", i, ev.strip()) for i, ev in enumerate(sentences)]
        # Anything else is returned as given; prep() asserts it is a list.
        return evidence
if __name__ == '__main__':
    # Interactive loop: read a claim, verify it, print the label and the record.
    import argparse

    parser = argparse.ArgumentParser()
    # NOTE(review): with required=True the default below is never used; drop
    # one of the two once the intended behaviour is decided.
    parser.add_argument('--config', '-c', type=str, required=True,
                        default='available_models/aaai22_roberta.json',
                        help='Config json file with hyper-parameters')
    args = parser.parse_args()

    loren = Loren(args.config)
    while True:
        try:
            claim = input('> ')
        except (EOFError, KeyboardInterrupt):
            # Leave the loop on Ctrl-D / Ctrl-C instead of printing a traceback.
            break
        if not claim.strip():
            # Nothing to verify on an empty line.
            continue
        # check() returns a single dict; the label is stored under 'claim_veracity'.
        js = loren.check(claim)
        print(js['claim_veracity'])
        print(js)
|