File size: 6,868 Bytes
7f7285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# -*- coding: utf-8 -*-

'''
@Author     : Jiangjie Chen
@Time       : 2020/9/17 15:55
@Contact    : [email protected]
@Description: 
'''

import os
import sys
import json
import logging
import cjjpy as cjj

try:
    from .qg_client.question_generator import QuestionGenerator
    from .mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one
    from .parsing_client.sentence_parser import SentenceParser, deal_bracket
    from .check_client.fact_checker import FactChecker, id2label
    from .er_client import EvidenceRetrieval
except:
    sys.path.append(cjj.AbsParentDir(__file__, '.'))
    from qg_client.question_generator import QuestionGenerator
    from mrc_client.answer_generator import AnswerGenerator, chunks, assemble_answers_to_one
    from parsing_client.sentence_parser import SentenceParser, deal_bracket
    from check_client.fact_checker import FactChecker, id2label
    from er_client import EvidenceRetrieval


def load_config(config):
    '''
    Build an attribute-style config object from a JSON file or a mapping.

    :param config: path to a JSON file (str or os.PathLike), or an
                   already-loaded mapping that is wrapped as-is
    :return: cjj.AttrDict wrapping the configuration
    '''
    if isinstance(config, (str, os.PathLike)):
        # Read as UTF-8 explicitly so the result does not depend on the
        # platform's locale encoding.
        with open(config, encoding='utf-8') as f:
            config = json.load(f)
    return cjj.AttrDict(config)


class Loren:
    '''
    Claim verification pipeline.

    Chains the sentence parser, question generator, answer generator,
    evidence retriever and fact checker clients: a claim (plus optional
    evidence) is turned into claim phrases, probing questions and evidential
    phrases, which are then handed to the fact checker.
    '''

    # Evidence values that mean "nothing supplied, retrieve it instead".
    _NO_EVIDENCE = ('', 'null', 'NULL', 'Null')
    # Maximum number of '||'-separated sentences kept from string evidence.
    MAX_EVIDENCE_SENTENCES = 5

    def __init__(self, config_file, verbose=True):
        '''
        :param config_file: anything accepted by load_config (JSON path or mapping);
                            must provide mrc_dir, fc_dir, er_dir and cand_k
        :param verbose: log at INFO level to the log file if True, WARNING otherwise
        :raises KeyError: if the PJ_HOME environment variable is not set
        '''
        self.verbose = verbose
        self.args = load_config(config_file)
        self.sent_client = SentenceParser()
        self.qg_client = QuestionGenerator('t5', verbose=False)
        self.ag_client = AnswerGenerator(self.args.mrc_dir)
        self.fc_client = FactChecker(self.args, self.args.fc_dir)
        self.er_client = EvidenceRetrieval(self.args.er_dir)
        self.logger = cjj.init_logger(f'{os.environ["PJ_HOME"]}/results/loren_dev.log',
                                      log_file_level=logging.INFO if self.verbose else logging.WARNING)
        self.logger.info('*** Loren initialized. ***')

    def check(self, claim, evidence=None):
        '''
        Verify a single claim.

        :param claim: claim text
        :param evidence: see _prep_evidence; retrieved automatically when empty
        :return: dict produced by prep(), extended with 'id', 'local_premises',
                 'phrase_veracity' and 'claim_veracity'; 'evidence', 'questions'
                 and 'claim_phrases' hold detokenized strings
        '''
        self.logger.info('*** Verifying "%s"... ***' % claim)
        js = self.prep(claim, evidence)
        js['id'] = 0
        y_predicted, z_predicted, m_attn = self.fc_client.check_from_batch([js], verbose=self.verbose)
        label = id2label[y_predicted[0]]

        # Update js. Premises are assembled before 'evidence' is flattened to
        # plain strings below, exactly as in the original statement order.
        clean = self.fc_client.tokenizer.clean_up_tokenization
        js['local_premises'] = assemble_answers_to_one(js, k=3)
        js['evidence'] = [clean(e[2]) for e in js['evidence']]
        js['questions'] = [clean(q) for q in js['questions']]
        js['claim_phrases'] = [clean(a[0]) for a in js['answers']]
        js['local_premises'] = [[clean(a) for a in aa] for aa in js['local_premises']]
        # js['m_attn'] = m_attn[0][:len(js['claim_phrases'])]
        # Keep one phrase-level prediction per claim phrase.
        js['phrase_veracity'] = z_predicted[0][:len(js['claim_phrases'])]
        js['claim_veracity'] = label

        self.logger.info("  * Intermediary: %s *" % str(js))
        self.logger.info('*** Verification completed: "%s" ***' % label)
        return js

    def prep(self, claim, evidence=None):
        '''
        Build the fact checker input for a claim.

        :param claim: claim text
        :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)] if not None
        :return: dict with 'claim', 'evidence', 'answers', 'answer_roles',
                 'regular_qs', 'cloze_qs', 'questions' and 'evidential'
        '''
        evidence = self._prep_evidence(claim, evidence)
        self.logger.info('  * Evidence prepared. *')
        assert isinstance(evidence, list)

        js = {'claim': claim, 'evidence': evidence}
        js = self._prep_claim_phrases(js)
        self.logger.info('  * Claim phrases prepared. *')
        js = self._prep_questions(js)
        self.logger.info('  * Probing questions prepared. *')
        js = self._prep_evidential_phrases(js)
        self.logger.info('  * Evidential phrases prepared. *')
        return js

    @staticmethod
    def _fallback_answer(raw_claim, parsed_claim):
        '''
        Build the single claim phrase used when the parser found none.

        :param raw_claim: claim text as given by the caller
        :param parsed_claim: claim text as returned by the sentence parser
        :return: (first whitespace token of raw_claim, start, end), offsets
                 taken in parsed_claim (start is 0 if the token is not found)
        :raises IndexError: if raw_claim contains no tokens
        '''
        first_word = raw_claim.split()[0]
        start = max(parsed_claim.find(first_word), 0)
        return (first_word, start, start + len(first_word))

    def _prep_claim_phrases(self, js):
        '''
        Extract claim phrases (noun phrases, verbs, adjectives) from the claim.

        :param js: dict with 'claim' and 'evidence'
        :return: new dict with the parser's claim text, the evidence, and the
                 parallel lists 'answers' / 'answer_roles'
        '''
        results = self.sent_client.identify_NPs(deal_bracket(js['claim'], True),
                                                candidate_NPs=[x[0] for x in js['evidence']])
        claim = results['text']
        NPs, verbs, adjs = results['NPs'], results['verbs'], results['adjs']

        answers = NPs + verbs + adjs
        roles = ['noun'] * len(NPs) + ['verb'] * len(verbs) + ['adj'] * len(adjs)
        if not answers:
            # 'answers' must stay a list with one entry per role: a bare string
            # here would be iterated character by character downstream.
            # NOTE(review): check() reads a[0] as the phrase text, so entries
            # are indexable; the (text, start, end) layout of the fallback is
            # assumed -- confirm against SentenceParser.identify_NPs.
            answers = [self._fallback_answer(js['claim'], claim)]
            roles = ['noun']

        return {'claim': claim,
                'evidence': js['evidence'],
                'answers': answers,
                'answer_roles': roles}

    def _prep_questions(self, js):
        '''
        Generate one probing question per claim phrase.

        :param js: dict with 'claim' and 'answers'
        :return: js, with 'regular_qs', 'cloze_qs' and 'questions' extended
                 (the keys are created as empty lists if absent)
        '''
        qg_inputs = [(js['claim'], [answer]) for answer in js['answers']]
        qa_pairs = self.qg_client.generate(qg_inputs)

        regular_qs = js.setdefault('regular_qs', [])
        cloze_qs = js.setdefault('cloze_qs', [])
        questions = js.setdefault('questions', [])
        for q, clz_q, a in qa_pairs:
            regular_qs.append(q)
            cloze_qs.append(clz_q)
            questions.append(self.qg_client.assemble_question(q, clz_q))
        return js

    def _prep_evidential_phrases(self, js):
        '''
        Answer every probing question against the concatenated evidence.

        :param js: dict with 'questions' and 'evidence'
        :return: js, with one entry appended to 'evidential' per prediction
        '''
        # The context is the same for every question, so build it once.
        context = " ".join([x[2] for x in js['evidence']])
        examples = [self.ag_client.assemble(q, context) for q in js['questions']]
        predicted = self.ag_client.generate(examples, num_beams=self.args['cand_k'],
                                            num_return_sequences=self.args['cand_k'],
                                            batch_size=2, verbose=False)
        js.setdefault('evidential', []).extend(predicted)
        return js

    def _prep_evidence(self, claim, evidence=None):
        '''
        Normalize evidence, retrieving it when none is supplied.

        :param claim: claim text, used as the retrieval query
        :param evidence: 'aaa||bbb||ccc' / [entity, num, evidence, (prob)];
                         None, '' and 'null' (any of the listed casings) trigger retrieval
        :return: [entity, num, evidence, (prob)]; string evidence is split on
                 '||', capped at MAX_EVIDENCE_SENTENCES and tagged with entity "None";
                 any other value is returned unchanged
        '''
        if evidence is None or (isinstance(evidence, str) and evidence in self._NO_EVIDENCE):
            retrieved = self.er_client.retrieve(claim)
            return [(ev[0], ev[1], deal_bracket(ev[2], True, ev[0])) for ev in retrieved]

        if isinstance(evidence, str):
            sentences = evidence.split('||')[:self.MAX_EVIDENCE_SENTENCES]
            return [("None", i, ev.strip()) for i, ev in enumerate(sentences)]
        return evidence


if __name__ == '__main__':
    # Interactive demo: read claims from stdin and print the verdict together
    # with the full intermediate result.
    import argparse

    parser = argparse.ArgumentParser()
    # Not marked `required`: argparse never uses `default` for a required
    # option, which made the default config path dead.
    parser.add_argument('--config', '-c', type=str,
                        default='available_models/aaai22_roberta.json',
                        help='Config json file with hyper-parameters')
    args = parser.parse_args()

    loren = Loren(args.config)
    while True:
        try:
            claim = input('> ')
        except (EOFError, KeyboardInterrupt):
            # Leave the prompt loop cleanly on Ctrl-D / Ctrl-C.
            break
        if not claim.strip():
            continue
        # check() returns a single dict; the label is stored under 'claim_veracity'.
        js = loren.check(claim)
        print(js['claim_veracity'])
        print(js)