jiangjiechen's picture
init loren for spaces
7f7285f
raw
history blame
6.87 kB
# -*- coding: utf-8 -*-
'''
@Author : Jiangjie Chen
@Time : 2020/9/17 15:55
@Contact : [email protected]
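@Description: Loren claim verification pipeline (evidence retrieval, claim parsing, question generation, answer generation, fact checking).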
'''
import os
import sys
import json
import logging
import cjjpy as cjj
# Import the pipeline clients. Relative imports work when this file is loaded
# as part of its package; when it is run directly as a script they raise
# ImportError, in which case the clients are imported as top-level modules.
try:
    from .qg_client.question_generator import QuestionGenerator
    from .mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one
    from .parsing_client.sentence_parser import SentenceParser, deal_bracket
    from .check_client.fact_checker import FactChecker, id2label
    from .er_client import EvidenceRetrieval
except ImportError:
    # Only a failed import should trigger the fallback; a bare `except:` would
    # also swallow KeyboardInterrupt / SystemExit and hide unrelated errors.
    # NOTE(review): cjj.AbsParentDir(__file__, '.') presumably resolves to the
    # directory containing this file - confirm against cjjpy.
    sys.path.append(cjj.AbsParentDir(__file__, '.'))
    from qg_client.question_generator import QuestionGenerator
    from mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one
    from parsing_client.sentence_parser import SentenceParser, deal_bracket
    from check_client.fact_checker import FactChecker, id2label
    from er_client import EvidenceRetrieval
def load_config(config):
    '''
    Load the hyper-parameter configuration.

    :param config: path to a JSON config file (str), or an already-loaded
                   mapping, which is used as is
    :return: the configuration wrapped in cjj.AttrDict
    '''
    if isinstance(config, str):
        # JSON files are UTF-8; do not depend on the platform's locale default.
        with open(config, encoding='utf-8') as f:
            config = json.load(f)
    return cjj.AttrDict(config)
class Loren:
    '''
    Claim verification pipeline built from five clients: sentence parsing,
    question generation, answer generation, fact checking and evidence
    retrieval. `check` is the entry point; `prep` builds the intermediate
    example dict that is handed to the fact checker.
    '''

    # Values of `evidence` meaning "nothing supplied, retrieve it instead".
    _NO_EVIDENCE = (None, '', 'null', 'NULL', 'Null')
    # Maximum number of '||'-separated sentences kept from string evidence.
    MAX_EVIDENCE_SENTENCES = 5

    def __init__(self, config_file, verbose=True):
        '''
        :param config_file: JSON config path or already-loaded mapping,
                            passed to load_config
        :param verbose: log to file at INFO level when True, WARNING otherwise
        :raises KeyError: if the PJ_HOME environment variable is not set
        '''
        self.verbose = verbose
        self.args = load_config(config_file)
        self.sent_client = SentenceParser()
        self.qg_client = QuestionGenerator('t5', verbose=False)
        self.ag_client = AnswerGenerator(self.args.mrc_dir)
        self.fc_client = FactChecker(self.args, self.args.fc_dir)
        self.er_client = EvidenceRetrieval(self.args.er_dir)
        self.logger = cjj.init_logger(f'{os.environ["PJ_HOME"]}/results/loren_dev.log',
                                      log_file_level=logging.INFO if self.verbose else logging.WARNING)
        self.logger.info('*** Loren initialized. ***')

    def check(self, claim, evidence=None):
        '''
        Verify a single claim.

        :param claim: claim text
        :param evidence: optional evidence, same formats as accepted by `prep`
        :return: the example dict built by `prep`, extended with 'id',
                 'claim_phrases', 'local_premises', 'phrase_veracity' and
                 'claim_veracity'; 'evidence' is reduced to its text field and
                 all text fields are passed through the fact checker
                 tokenizer's clean_up_tokenization
        '''
        self.logger.info('*** Verifying "%s"... ***' % claim)
        js = self.prep(claim, evidence)
        js['id'] = 0
        # The third return value is currently not surfaced in the result.
        y_predicted, z_predicted, _ = self.fc_client.check_from_batch([js], verbose=self.verbose)
        label = id2label[y_predicted[0]]

        # Update js. Local premises are assembled first, while js still holds
        # the raw (uncleaned) fields.
        clean = self.fc_client.tokenizer.clean_up_tokenization
        js['local_premises'] = assemble_answers_to_one(js, k=3)
        js['evidence'] = [clean(e[2]) for e in js['evidence']]
        js['questions'] = [clean(q) for q in js['questions']]
        js['claim_phrases'] = [clean(a[0]) for a in js['answers']]
        js['local_premises'] = [[clean(a) for a in aa] for aa in js['local_premises']]
        # Predictions may cover more slots than there are claim phrases;
        # keep one entry per phrase.
        js['phrase_veracity'] = z_predicted[0][:len(js['claim_phrases'])]
        js['claim_veracity'] = label

        self.logger.info(" * Intermediary: %s *" % str(js))
        self.logger.info('*** Verification completed: "%s" ***' % label)
        return js

    def prep(self, claim, evidence=None):
        '''
        Build the example dict for one claim: evidence, claim phrases,
        probing questions and evidential phrases, in that order.

        :param claim: claim text
        :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)] if not None
        :return: dict with 'claim', 'evidence', 'answers', 'answer_roles',
                 'regular_qs', 'cloze_qs', 'questions' and 'evidential'
        '''
        evidence = self._prep_evidence(claim, evidence)
        self.logger.info(' * Evidence prepared. *')
        assert isinstance(evidence, list)

        js = self._prep_claim_phrases({'claim': claim, 'evidence': evidence})
        self.logger.info(' * Claim phrases prepared. *')
        js = self._prep_questions(js)
        self.logger.info(' * Probing questions prepared. *')
        js = self._prep_evidential_phrases(js)
        self.logger.info(' * Evidential phrases prepared. *')
        return js

    def _prep_claim_phrases(self, js):
        '''
        Parse the claim into the phrases that will be verified.

        :param js: dict with 'claim' and 'evidence'
        :return: new dict with the parsed 'claim', the given 'evidence',
                 'answers' (noun phrases, then verbs, then adjectives) and a
                 parallel 'answer_roles' list
        :raises IndexError: if the parser found no phrase and the parsed claim
                            contains no token to fall back on
        '''
        results = self.sent_client.identify_NPs(
            deal_bracket(js['claim'], True),
            candidate_NPs=[ev[0] for ev in js['evidence']])
        claim = results['text']
        noun_phrases = results['NPs']
        verbs = results['verbs']
        adjs = results['adjs']

        answers = noun_phrases + verbs + adjs
        roles = ['noun'] * len(noun_phrases) + ['verb'] * len(verbs) + ['adj'] * len(adjs)
        if not answers:
            # Fallback: probe the first token of the claim. 'answers' must stay
            # a list of phrase entries (entry[0] is the phrase text) so that it
            # lines up with 'answer_roles'; assigning a bare string here would
            # be iterated character by character downstream.
            # NOTE(review): the (text, start, end) layout is assumed to mirror
            # the entries produced by identify_NPs - confirm.
            first_word = claim.split()[0]
            start = claim.index(first_word)
            answers = [(first_word, start, start + len(first_word))]
            roles = ['noun']

        return {'claim': claim,
                'evidence': js['evidence'],
                'answers': answers,
                'answer_roles': roles}

    def _prep_questions(self, js):
        '''
        Generate one probing question per claim phrase.

        :param js: dict with 'claim' and 'answers'
        :return: the same dict with 'regular_qs', 'cloze_qs' and 'questions'
                 extended by one entry per generated pair; the three keys
                 always exist afterwards, even when nothing was generated
        '''
        qg_inputs = [(js['claim'], [answer]) for answer in js['answers']]
        qa_pairs = self.qg_client.generate(qg_inputs)

        regular_qs = js.setdefault('regular_qs', [])
        cloze_qs = js.setdefault('cloze_qs', [])
        questions = js.setdefault('questions', [])
        for q, clz_q, _ in qa_pairs:
            regular_qs.append(q)
            cloze_qs.append(clz_q)
            questions.append(self.qg_client.assemble_question(q, clz_q))
        return js

    def _prep_evidential_phrases(self, js):
        '''
        Generate candidate answers for every question from the evidence.

        :param js: dict with 'questions' and 'evidence'
        :return: the same dict with 'evidential' extended by one entry per
                 prediction; the key always exists afterwards
        '''
        # Every question is answered against the same concatenated evidence.
        context = " ".join([ev[2] for ev in js['evidence']])
        examples = [self.ag_client.assemble(q, context) for q in js['questions']]
        predicted = self.ag_client.generate(examples, num_beams=self.args['cand_k'],
                                            num_return_sequences=self.args['cand_k'],
                                            batch_size=2, verbose=False)
        js.setdefault('evidential', []).extend(predicted)
        return js

    def _prep_evidence(self, claim, evidence=None):
        '''
        Normalise the evidence, retrieving it when none is given.

        :param claim: claim text, used as the retrieval query
        :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)]
        :return: [entity, num, evidence, (prob)]
        '''
        if evidence in self._NO_EVIDENCE:
            retrieved = self.er_client.retrieve(claim)
            return [(ev[0], ev[1], deal_bracket(ev[2], True, ev[0])) for ev in retrieved]
        if isinstance(evidence, str):
            # User-supplied text has no source entity, hence the "None" label.
            sentences = evidence.split('||')[:self.MAX_EVIDENCE_SENTENCES]
            return [("None", i, sent.strip()) for i, sent in enumerate(sentences)]
        # Anything else is returned unchanged.
        return evidence
# Interactive entry point: read claims from stdin and print the verdict
# followed by the full intermediate result.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # Not `required=True`: that would make the default below unreachable.
    parser.add_argument('--config', '-c', type=str,
                        default='available_models/aaai22_roberta.json',
                        help='Config json file with hyper-parameters')
    args = parser.parse_args()

    loren = Loren(args.config)
    while True:
        try:
            claim = input('> ')
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C: leave the prompt loop without a traceback.
            break
        # check() returns a single dict; the label is stored under 'claim_veracity'.
        js = loren.check(claim)
        print(js['claim_veracity'])
        print(js)