Spaces:
Build error
Build error
# -*- coding: utf-8 -*- | |
''' | |
@Author : Jiangjie Chen | |
@Time : 2020/9/21 16:13 | |
@Contact : [email protected] | |
@Description: | |
''' | |
import cjjpy as cjj | |
import os | |
# from .document_retrieval import DocRetrieval | |
from .doc_retrieval_by_api import DocRetrieval | |
from .sentence_selection import SentSelector | |
# Hyper-parameters handed to ``SentSelector`` as its third argument
# (see ``EvidenceRetrieval.__init__``). Insertion order matches the keys below.
arg_values = dict(
    batch_size=32,
    dropout=0.6,
    use_cuda=True,
    bert_hidden_dim=768,
    layer=1,
    num_labels=3,
    evi_num=5,
    threshold=0.0,
    max_len=120,
)
# Wrap the plain dict in ``cjj.AttrDict``; presumably this allows attribute
# access such as ``args.batch_size`` inside the selector -- TODO confirm.
args = cjj.AttrDict(arg_values)
class EvidenceRetrieval:
    '''
    Two-stage evidence retriever for a textual claim.

    Candidate evidence is fetched with ``DocRetrieval`` (configured with
    ``link_type='tagme'``) and then re-ranked by a ``SentSelector`` whose
    weights are loaded from ``er_model_dir``.
    '''

    def __init__(self, er_model_dir=cjj.AbsParentDir(__file__, '...', 'models/evidence_retrieval/')):
        '''
        Build the document retriever and the sentence selector.

        :param er_model_dir: directory expected to contain ``bert_base/`` and
            ``retrieval_model/model.best.pt``. The default path is computed
            once, when the class is defined, via ``cjj.AbsParentDir``.
        '''
        # NOTE: an earlier variant used a local FEVER wiki database
        # (data/fever.db) as the document source; this one uses the
        # ``DocRetrieval`` imported from ``doc_retrieval_by_api``.
        self.doc_retrieval = DocRetrieval(link_type='tagme')
        bert_dir = os.path.join(er_model_dir, 'bert_base/')
        checkpoint_path = os.path.join(er_model_dir, 'retrieval_model/model.best.pt')
        # ``args`` is the module-level hyper-parameter AttrDict.
        self.sent_selector = SentSelector(bert_dir, checkpoint_path, args)

    def retrieve(self, claim):
        '''
        Fetch candidate evidence for ``claim`` and return it ranked.

        :param claim: str
        :return: whatever ``rank_sentences`` returns for the retrieved candidates
        '''
        candidates = self.doc_retrieval.retrieve_docs(claim)
        return self.rank_sentences(claim, candidates)

    def rank_sentences(self, claim, sentences, id=None):
        '''
        Rank candidate sentences for a single claim with the sentence selector.

        :param claim: str
        :param sentences: [(ent, num, sent) * N]
        :param id: key used both to tag this claim in the one-item batch sent to
            the selector and to look up its entry in the selector's result;
            defaults to ``len(claim)`` when omitted.
        :return: [(ent, num, sent) * k]; falls back to [] when the selector's
            result has no entry for the key.
        '''
        key = len(claim) if id is None else id
        batch = [{'claim': claim, 'evidence': sentences, 'id': key}]
        ranked_by_id = self.sent_selector.rank_sentences(batch)
        return ranked_by_id.get(key, [])