File size: 1,942 Bytes
ccba2d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from transformers import BertTokenizer, BertForQuestionAnswering
import torch

class QA:
    """Extractive question answering on top of a BERT question-answering model.

    The constructor loads ``BertForQuestionAnswering`` and ``BertTokenizer``
    weights for ``model_name``; ``answer_question`` then returns the span of
    the passage that the model scores as the answer to a query.
    """

    def __init__(self,
                 model_name='bert-large-uncased-whole-word-masking-finetuned-squad'):
        """Remember ``model_name`` and load its model and tokenizer.

        Args:
            model_name: Identifier handed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        self.model_name = model_name

        self.__load_model_and_tokenizer()

    def __load_model_and_tokenizer(self):
        """Populate ``self.model`` and ``self.tokenizer`` from ``self.model_name``."""
        self.model = BertForQuestionAnswering.from_pretrained(self.model_name)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name)

    def __get_segment_ids(self, input_ids):
        """Build the segment (token type) ids for an encoded query/passage pair.

        Args:
            input_ids: List of token ids containing at least one ``[SEP]``.

        Returns:
            A list as long as ``input_ids`` holding 0 for every token up to and
            including the first ``[SEP]`` (segment A) and 1 for the remainder
            (segment B).

        Raises:
            ValueError: If ``input_ids`` contains no ``[SEP]`` token id.
        """
        # Segment A ends at the first `[SEP]` token, and includes it.
        num_seg_a = input_ids.index(self.tokenizer.sep_token_id) + 1

        # Everything after that belongs to segment B.
        num_seg_b = len(input_ids) - num_seg_a

        return [0] * num_seg_a + [1] * num_seg_b

    @staticmethod
    def _best_valid_span(start_logits, end_logits):
        """Return the ``(start, end)`` pair with ``start <= end`` and maximal summed score.

        Args:
            start_logits: 1-D tensor of per-token start scores.
            end_logits: 1-D tensor of per-token end scores, same length.

        Returns:
            A tuple of two Python ints indexing into the token sequence.
        """
        num_tokens = start_logits.size(0)

        # span_scores[i, j] is the score of the span starting at i, ending at j.
        span_scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)

        # Rule out every span whose end precedes its start.
        positions = torch.arange(num_tokens, device=span_scores.device)
        is_valid = positions.unsqueeze(1) <= positions.unsqueeze(0)
        span_scores = span_scores.masked_fill(~is_valid, float('-inf'))

        # argmax works on the flattened matrix; recover (row, column).
        return divmod(int(torch.argmax(span_scores)), num_tokens)

    def answer_question(self, query: str, passage: str) -> str:
        """Return the text span of ``passage`` that best answers ``query``.

        Args:
            query: The question to answer.
            passage: The text in which to look for the answer.

        Returns:
            The decoded tokens between the selected start and end positions,
            inclusive.

        NOTE(review): the encoded pair is not truncated here; presumably inputs
        longer than the model's maximum sequence length will fail inside the
        model -- confirm and decide on a truncation policy with callers.
        """
        input_ids = self.tokenizer.encode(query, passage)
        segment_ids = self.__get_segment_ids(input_ids)

        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            outputs = self.model(
                torch.tensor([input_ids]),  # The tokens representing our input text.
                token_type_ids=torch.tensor([segment_ids]),  # Question vs. passage.
                return_dict=True)

        # A single example was submitted, so take row 0 of the batch.
        start_scores = outputs.start_logits[0]
        end_scores = outputs.end_logits[0]

        # Find the tokens with the highest `start` and `end` scores.
        answer_start = int(torch.argmax(start_scores))
        answer_end = int(torch.argmax(end_scores))

        # The two argmaxes are independent and may cross, which would yield an
        # empty answer; in that case pick the best span that is well-formed.
        if answer_end < answer_start:
            answer_start, answer_end = self._best_valid_span(start_scores,
                                                             end_scores)

        return self.tokenizer.decode(input_ids[answer_start:answer_end + 1])