Spaces:
Build error
Build error
# -*- coding: utf-8 -*- | |
''' | |
@Author : Jiangjie Chen | |
@Time : 2020/9/17 15:55 | |
@Contact : [email protected] | |
@Description: | |
''' | |
import os | |
import sys | |
import json | |
import logging | |
import cjjpy as cjj | |
try: | |
from .qg_client.question_generator import QuestionGenerator | |
from .mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one | |
from .parsing_client.sentence_parser import SentenceParser, deal_bracket | |
from .check_client.fact_checker import FactChecker, id2label | |
from .er_client import EvidenceRetrieval | |
except: | |
sys.path.append(cjj.AbsParentDir(__file__, '.')) | |
from qg_client.question_generator import QuestionGenerator | |
from mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one | |
from parsing_client.sentence_parser import SentenceParser, deal_bracket | |
from check_client.fact_checker import FactChecker, id2label | |
from er_client import EvidenceRetrieval | |
def load_config(config):
    '''
    Load a configuration into an attribute-accessible dict.

    :param config: path to a JSON file (str or os.PathLike), or an
        already-parsed mapping that is wrapped as-is.
    :return: cjj.AttrDict holding the configuration.
    '''
    if isinstance(config, (str, os.PathLike)):
        # JSON text is UTF-8 (RFC 8259); do not depend on the locale's default encoding.
        with open(config, encoding='utf-8') as f:
            config = json.load(f)
    return cjj.AttrDict(config)
class Loren:
    '''
    End-to-end claim verifier.

    Combines a SentenceParser, a QuestionGenerator, an AnswerGenerator, an
    EvidenceRetrieval client and a FactChecker, configured from a JSON file.
    '''

    # String values of ``evidence`` that mean "no evidence given, retrieve it".
    _NULL_EVIDENCE = frozenset({'', 'null', 'NULL', 'Null'})
    # Maximum number of '||'-separated sentences kept when evidence is a raw string.
    max_str_evidence = 5

    def __init__(self, config_file, verbose=True):
        '''
        :param config_file: JSON config path (or parsed mapping); must provide at
            least ``mrc_dir``, ``fc_dir``, ``er_dir`` and ``cand_k``.
        :param verbose: log file level is INFO when True, WARNING otherwise; also
            forwarded to the fact checker in :meth:`check`.
        '''
        self.verbose = verbose
        self.args = load_config(config_file)
        self.sent_client = SentenceParser()
        self.qg_client = QuestionGenerator('t5', verbose=False)
        self.ag_client = AnswerGenerator(self.args.mrc_dir)
        self.fc_client = FactChecker(self.args, self.args.fc_dir)
        self.er_client = EvidenceRetrieval(self.args.er_dir)
        # Requires the PJ_HOME environment variable (KeyError if unset).
        self.logger = cjj.init_logger(f'{os.environ["PJ_HOME"]}/results/loren_dev.log',
                                      log_file_level=logging.INFO if self.verbose else logging.WARNING)
        self.logger.info('*** Loren initialized. ***')

    def check(self, claim, evidence=None):
        '''
        Verify ``claim`` and return the enriched intermediate record.

        :param claim: claim text.
        :param evidence: see :meth:`prep`; retrieved automatically when None,
            '' or a 'null' spelling.
        :return: the dict built by :meth:`prep`, with ``evidence``, ``questions``
            and ``local_premises`` cleaned for display, plus ``claim_phrases``,
            ``phrase_veracity`` and ``claim_veracity`` (label from ``id2label``).
        '''
        self.logger.info('*** Verifying "%s"... ***' % claim)
        js = self.prep(claim, evidence)
        js['id'] = 0
        # The third output (``m_attn``) is not used here.
        y_predicted, z_predicted, _ = self.fc_client.check_from_batch([js], verbose=self.verbose)
        label = id2label[y_predicted[0]]

        # Post-process for display. assemble_answers_to_one must see the
        # record before its fields are rewritten below, so order matters.
        clean = self.fc_client.tokenizer.clean_up_tokenization
        js['local_premises'] = assemble_answers_to_one(js, k=3)
        js['evidence'] = [clean(e[2]) for e in js['evidence']]
        js['questions'] = [clean(q) for q in js['questions']]
        js['claim_phrases'] = [clean(a[0]) for a in js['answers']]
        js['local_premises'] = [[clean(a) for a in aa] for aa in js['local_premises']]
        js['phrase_veracity'] = z_predicted[0][:len(js['claim_phrases'])]
        js['claim_veracity'] = label
        self.logger.info(" * Intermediary: %s *" % str(js))
        self.logger.info('*** Verification completed: "%s" ***' % label)
        return js

    def prep(self, claim, evidence=None):
        '''
        Build the intermediate record: evidence, claim phrases, probing
        questions and evidential phrases.

        :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)] if not None
        :return: dict with keys claim (parser-normalized text), evidence,
            answers, answer_roles, regular_qs, cloze_qs, questions, evidential.
        '''
        evidence = self._prep_evidence(claim, evidence)
        self.logger.info(' * Evidence prepared. *')
        assert isinstance(evidence, list)
        js = {'claim': claim, 'evidence': evidence}
        js = self._prep_claim_phrases(js)
        self.logger.info(' * Claim phrases prepared. *')
        js = self._prep_questions(js)
        self.logger.info(' * Probing questions prepared. *')
        js = self._prep_evidential_phrases(js)
        self.logger.info(' * Evidential phrases prepared. *')
        return js

    def _prep_claim_phrases(self, js):
        '''
        Extract claim phrases (noun phrases, verbs, adjectives).

        Evidence entity names (``x[0]`` of each evidence item) are offered as
        candidate NPs.

        :return: a new dict with claim, evidence, answers and answer_roles,
            where answer_roles has one role per answer.
        '''
        results = self.sent_client.identify_NPs(deal_bracket(js['claim'], True),
                                                candidate_NPs=[x[0] for x in js['evidence']])
        NPs = results['NPs']
        verbs = results['verbs']
        adjs = results['adjs']
        claim = results['text']
        answers = NPs + verbs + adjs
        answer_roles = ['noun'] * len(NPs) + ['verb'] * len(verbs) + ['adj'] * len(adjs)
        if len(answers) == 0:
            # Fall back to the first word of the claim as a single noun phrase.
            # This must be a one-element list: a bare string would make the
            # downstream loops iterate over its characters.
            # NOTE(review): assumes entries are (text, start_char, end_char)
            # like identify_NPs' output (only ``a[0]`` = text is visible here)
            # - TODO confirm against SentenceParser.
            first = js['claim'].split()[0]
            start = max(claim.find(first), 0)
            answers = [(first, start, start + len(first))]
            answer_roles = ['noun']
        return {'claim': claim,
                'evidence': js['evidence'],
                'answers': answers,
                'answer_roles': answer_roles}

    def _prep_questions(self, js):
        '''
        Generate one probing question per entry of ``js['answers']``.

        Appends to ``regular_qs``, ``cloze_qs`` and the combined ``questions``
        (each created if absent) and returns ``js``.
        '''
        qa_pairs = self.qg_client.generate([(js['claim'], [answer]) for answer in js['answers']])
        regular_qs = js.setdefault('regular_qs', [])
        cloze_qs = js.setdefault('cloze_qs', [])
        questions = js.setdefault('questions', [])
        for q, clz_q, _ in qa_pairs:
            regular_qs.append(q)
            cloze_qs.append(clz_q)
            questions.append(self.qg_client.assemble_question(q, clz_q))
        return js

    def _prep_evidential_phrases(self, js):
        '''
        Answer each probing question against the concatenated evidence text.

        Appends the answer generator's predictions (``cand_k`` beams/returns per
        question) to ``js['evidential']`` (created if absent) and returns ``js``.
        '''
        # The context is identical for every question; build it once.
        context = " ".join([x[2] for x in js['evidence']])
        examples = [self.ag_client.assemble(q, context) for q in js['questions']]
        predicted = self.ag_client.generate(examples, num_beams=self.args['cand_k'],
                                            num_return_sequences=self.args['cand_k'],
                                            batch_size=2, verbose=False)
        js.setdefault('evidential', []).extend(predicted)
        return js

    def _prep_evidence(self, claim, evidence=None):
        '''
        :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)]
        :return: [entity, num, evidence, (prob)]
        '''
        # Compare against the null spellings only for str, so arbitrary
        # objects are never ==-compared with None/strings.
        if evidence is None or (isinstance(evidence, str) and evidence in self._NULL_EVIDENCE):
            retrieved = self.er_client.retrieve(claim)
            return [(ev[0], ev[1], deal_bracket(ev[2], True, ev[0])) for ev in retrieved]
        if isinstance(evidence, str):
            sentences = evidence.split('||')[:self.max_str_evidence]
            return [("None", i, ev.strip()) for i, ev in enumerate(sentences)]
        return evidence
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # Not required: the default config is used when -c is omitted.
    parser.add_argument('--config', '-c', type=str,
                        default='available_models/aaai22_roberta.json',
                        help='Config json file with hyper-parameters')
    args = parser.parse_args()
    loren = Loren(args.config)
    # Interactive loop: one claim per line; Ctrl-D / Ctrl-C exits.
    while True:
        try:
            claim = input('> ')
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if not claim.strip():
            continue
        # check() returns the enriched record dict; the label is stored in it.
        js = loren.check(claim)
        print(js['claim_veracity'])
        print(js)