File size: 2,019 Bytes
7f7285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# -*- coding: utf-8 -*-

'''
@Author     : Jiangjie Chen
@Time       : 2020/9/20 11:42
@Contact    : [email protected]
@Description: Selects the top-ranked evidence sentences for each claim using a BERT-based ranking model.
'''

import torch
from transformers import BertTokenizer
from .retrieval_model.bert_model import BertForSequenceEncoder
from .retrieval_model.models import inference_model
from .retrieval_model.data_loader import DataLoaderTest


class SentSelector:
    '''
    Scores candidate evidence sentences with a BERT-based ranking model and
    keeps the best-scoring ones per id.

    The tokenizer is always the stock ``bert-base-cased`` one; the encoder is
    loaded from ``pretrained_bert_path`` and the ranking weights come from the
    ``'model'`` entry of the checkpoint stored at ``select_model_path``.
    '''

    def __init__(self, pretrained_bert_path, select_model_path, args):
        '''
        :param pretrained_bert_path: location passed to
            ``BertForSequenceEncoder.from_pretrained``
        :param select_model_path: checkpoint readable by ``torch.load`` that
            holds the ranking model's state dict under the key ``'model'``
        :param args: options object; this class reads ``use_cuda`` and
            ``evi_num`` from it and forwards it unchanged to
            ``inference_model`` and ``DataLoaderTest``
        '''
        self.args = args
        # Run on GPU only when it is both requested and actually present.
        self.use_cuda = self.args.use_cuda and torch.cuda.is_available()

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        self.bert_model = BertForSequenceEncoder.from_pretrained(pretrained_bert_path)

        self.rank_model = inference_model(self.bert_model, self.args)
        # Deserialize onto CPU first: without map_location a checkpoint that
        # was saved from a GPU run cannot be restored on a CPU-only host.
        # The modules are moved to the GPU just below when CUDA is enabled.
        checkpoint = torch.load(select_model_path, map_location='cpu')
        self.rank_model.load_state_dict(checkpoint['model'])

        if self.use_cuda:
            self.bert_model = self.bert_model.cuda()
            self.rank_model.cuda()

    def rank_sentences(self, js: list) -> dict:
        '''
        Score every evidence item in ``js`` and keep the top ones per id.

        :param js: [{'claim': xxx, 'id': xx, 'evidence': xxx}]
        :return: {id: [evidence_fields + (prob,), ...]} where each list is
            sorted by its last element (the model output) in descending order
            and truncated to ``self.args.evi_num`` entries. The evidence
            fields are whatever ``DataLoaderTest`` yields per item
            (presumably ``(ent, num, sent)`` -- confirm against the loader).
        '''
        data_reader = DataLoaderTest(js, self.tokenizer, self.args, self.use_cuda)
        self.rank_model.eval()

        all_predict = {}
        # Pure inference: gradients are never used, so skip autograd tracking.
        with torch.no_grad():
            for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in data_reader:
                probs = self.rank_model(inp_tensor, msk_tensor, seg_tensor).tolist()
                assert len(probs) == len(evi_list)
                for i, (evidence, prob) in enumerate(zip(evi_list, probs)):
                    # Group the scored evidence under the id it belongs to.
                    all_predict.setdefault(ids[i], []).append(tuple(evidence) + (prob,))

        # Best score first, then keep at most evi_num entries per id.
        return {
            key: sorted(scored, key=lambda x: x[-1], reverse=True)[:self.args.evi_num]
            for key, scored in all_predict.items()
        }