File size: 2,090 Bytes
7f7285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# -*- coding: utf-8 -*-

'''
@Author     : Jiangjie Chen
@Time       : 2020/9/21 16:13
@Contact    : [email protected]
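@Description: Evidence retrieval for a claim: fetch candidate sentences, then rank them with a sentence selector.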
'''

import cjjpy as cjj
import os
# from .document_retrieval import DocRetrieval
from .doc_retrieval_by_api import DocRetrieval
from .sentence_selection import SentSelector


# Settings handed to SentSelector (as `args`) when EvidenceRetrieval is built.
# NOTE(review): the meaning of each key is defined by the selector
# implementation, which is not visible in this file — confirm there before
# changing any value.
arg_values = dict(
    batch_size=32,
    dropout=0.6,
    use_cuda=True,
    bert_hidden_dim=768,
    layer=1,
    num_labels=3,
    evi_num=5,
    threshold=0.0,
    max_len=120,
)

# Attribute-style wrapper around the settings above.
args = cjj.AttrDict(arg_values)

class EvidenceRetrieval:
    '''
    Two-stage evidence retriever: a DocRetrieval instance produces candidate
    sentences for a claim and a SentSelector instance ranks them.
    '''

    def __init__(self, er_model_dir=cjj.AbsParentDir(__file__, '...', 'models/evidence_retrieval/')):
        '''
        :param er_model_dir: directory expected to contain `bert_base/` and
                             `retrieval_model/model.best.pt`
        '''
        # NOTE: an earlier variant built the retriever from a local
        # `data/fever.db` (with add_claim=True, k_wiki_results=7); it is
        # disabled in favour of the API-based DocRetrieval imported above.
        self.doc_retrieval = DocRetrieval(link_type='tagme')

        bert_dir = os.path.join(er_model_dir, 'bert_base/')
        checkpoint_path = os.path.join(er_model_dir, 'retrieval_model/model.best.pt')
        self.sent_selector = SentSelector(bert_dir, checkpoint_path, args)

    def retrieve(self, claim):
        '''
        Fetch candidate sentences for the claim and return them ranked.

        :param claim: str
        :return: whatever `rank_sentences` returns for the fetched candidates
        '''
        candidates = self.doc_retrieval.retrieve_docs(claim)
        return self.rank_sentences(claim, candidates)

    def rank_sentences(self, claim, sentences, id=None):
        '''
        Rank candidate sentences against a claim with the sentence selector.

        :param claim: str
        :param sentences: [(ent, num, sent) * N]
        :param id: key identifying this query in the selector's result;
                   defaults to len(claim) when None
        :return: [(ent, num, sent) * k], or [] if the selector's result has
                 no entry for the id
        '''
        query_id = len(claim) if id is None else id
        query = {'claim': claim,
                 'evidence': sentences,
                 'id': query_id}
        ranked = self.sent_selector.rank_sentences([query])
        return ranked.get(query_id, [])