# -*- coding: utf-8 -*-

'''
@Author     : Jiangjie Chen
@Time       : 2020/9/17 15:55
@Contact    : jjchen19@fudan.edu.cn
@Description:
'''

import os
import sys
import json
import logging
import cjjpy as cjj

try:
    from .qg_client.question_generator import QuestionGenerator
    from .mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one
    from .parsing_client.sentence_parser import SentenceParser, deal_bracket
    from .check_client.fact_checker import FactChecker, id2label
    from .er_client import EvidenceRetrieval
except ImportError:
    # Fall back to absolute imports when this file is run as a script rather than a package.
    sys.path.append(cjj.AbsParentDir(__file__, '.'))
    from qg_client.question_generator import QuestionGenerator
    from mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one
    from parsing_client.sentence_parser import SentenceParser, deal_bracket
    from check_client.fact_checker import FactChecker, id2label
    from er_client import EvidenceRetrieval


def load_config(config):
    if isinstance(config, str):
        with open(config) as f:
            config = json.load(f)
    cfg = cjj.AttrDict(config)
    return cfg


class Loren:
    def __init__(self, config_file, verbose=True):
        self.verbose = verbose
        self.args = load_config(config_file)
        self.sent_client = SentenceParser()
        self.qg_client = QuestionGenerator('t5', verbose=False)
        self.ag_client = AnswerGenerator(self.args.mrc_dir)
        self.fc_client = FactChecker(self.args, self.args.fc_dir)
        self.er_client = EvidenceRetrieval(self.args.er_dir)
        self.logger = cjj.init_logger(f'{os.environ["PJ_HOME"]}/results/loren_dev.log',
                                      log_file_level=logging.INFO if self.verbose else logging.WARNING)
        self.logger.info('*** Loren initialized. ***')

    def check(self, claim, evidence=None):
        self.logger.info('*** Verifying "%s"... ***' % claim)
        js = self.prep(claim, evidence)
        js['id'] = 0
        y_predicted, z_predicted, m_attn = self.fc_client.check_from_batch([js], verbose=self.verbose)
        label = id2label[y_predicted[0]]

        # Update js with de-tokenized, human-readable intermediate results.
        js['local_premises'] = assemble_answers_to_one(js, k=3)
        js['evidence'] = [self.fc_client.tokenizer.clean_up_tokenization(e[2]) for e in js['evidence']]
        js['questions'] = [self.fc_client.tokenizer.clean_up_tokenization(q) for q in js['questions']]
        js['claim_phrases'] = [self.fc_client.tokenizer.clean_up_tokenization(a[0]) for a in js['answers']]
        js['local_premises'] = [[self.fc_client.tokenizer.clean_up_tokenization(a) for a in aa]
                                for aa in js['local_premises']]
        # js['m_attn'] = m_attn[0][:len(js['claim_phrases'])]
        js['phrase_veracity'] = z_predicted[0][:len(js['claim_phrases'])]
        js['claim_veracity'] = label

        self.logger.info(" * Intermediary: %s *" % str(js))
        self.logger.info('*** Verification completed: "%s" ***' % label)
        return js

    def prep(self, claim, evidence=None):
        '''
        Prepare everything the fact checker needs: evidence, claim phrases,
        probing questions and evidential phrases.
        :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)] if not None
        '''
        evidence = self._prep_evidence(claim, evidence)
        self.logger.info(' * Evidence prepared. *')
        assert isinstance(evidence, list)

        js = {'claim': claim, 'evidence': evidence}
        js = self._prep_claim_phrases(js)
        self.logger.info(' * Claim phrases prepared. *')
        js = self._prep_questions(js)
        self.logger.info(' * Probing questions prepared. *')
        js = self._prep_evidential_phrases(js)
        self.logger.info(' * Evidential phrases prepared. *')
        return js

    def _prep_claim_phrases(self, js):
        results = self.sent_client.identify_NPs(deal_bracket(js['claim'], True),
                                                candidate_NPs=[x[0] for x in js['evidence']])
        NPs = results['NPs']
        claim = results['text']
        verbs = results['verbs']
        adjs = results['adjs']
        _cache = {'claim': claim,
                  'evidence': js['evidence'],
                  'answers': NPs + verbs + adjs,
                  'answer_roles': ['noun'] * len(NPs) + ['verb'] * len(verbs) + ['adj'] * len(adjs)}
        if len(_cache['answers']) == 0:
            # Fall back to the first token of the claim if no phrase was identified.
            _cache['answers'] = [js['claim'].split()[0]]
            _cache['answer_roles'] = ['noun']
        return _cache

    def _prep_questions(self, js):
        _cache = []
        for answer in js['answers']:
            _cache.append((js['claim'], [answer]))
        qa_pairs = self.qg_client.generate([(x, y) for x, y in _cache])
        for q, clz_q, a in qa_pairs:
            if 'questions' in js:
                js['regular_qs'].append(q)
                js['cloze_qs'].append(clz_q)
                js['questions'].append(self.qg_client.assemble_question(q, clz_q))
            else:
                js['regular_qs'] = [q]
                js['cloze_qs'] = [clz_q]
                js['questions'] = [self.qg_client.assemble_question(q, clz_q)]
        return js

    def _prep_evidential_phrases(self, js):
        examples = []
        for q in js['questions']:
            ex = self.ag_client.assemble(q, " ".join([x[2] for x in js['evidence']]))
            examples.append(ex)
        predicted = self.ag_client.generate(examples, num_beams=self.args['cand_k'],
                                            num_return_sequences=self.args['cand_k'],
                                            batch_size=2, verbose=False)
        for answers in predicted:
            if 'evidential' in js:
                js['evidential'].append(answers)
            else:
                js['evidential'] = [answers]
        return js

    def _prep_evidence(self, claim, evidence=None):
        '''
        :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)]
        :return: [entity, num, evidence, (prob)]
        '''
        if evidence in [None, '', 'null', 'NULL', 'Null']:
            evidence = self.er_client.retrieve(claim)
            evidence = [(ev[0], ev[1], deal_bracket(ev[2], True, ev[0])) for ev in evidence]
        else:
            if isinstance(evidence, str):
                # TODO: magic sentence number
                evidence = [("None", i, ev.strip()) for i, ev in enumerate(evidence.split('||')[:5])]
        return evidence


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True,
                        default='available_models/aaai22_roberta.json',
                        help='Config json file with hyper-parameters')
    args = parser.parse_args()

    loren = Loren(args.config)
    while True:
        claim = input('> ')
        # check() returns a single dict; the predicted label lives in js['claim_veracity'].
        js = loren.check(claim)
        print(js['claim_veracity'])
        print(js)
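# Example (a sketch, not part of the original script): programmatic use with
# caller-supplied evidence. The '||'-separated string format follows the
# `prep()` docstring; the config path mirrors the argparse default above, and
# the claim/evidence strings are illustrative placeholders.
#
#   loren = Loren('available_models/aaai22_roberta.json', verbose=False)
#   js = loren.check('CLAIM_TO_VERIFY',
#                    'evidence sentence one||evidence sentence two')
#   print(js['claim_veracity'])   # a label from `id2label`
#   print(js['phrase_veracity'])  # per-phrase predictions, aligned with js['claim_phrases']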