File size: 2,363 Bytes
8516514
 
 
 
 
 
 
 
85c308f
8516514
85c308f
 
8516514
 
 
 
 
 
 
cc370e5
 
 
 
8516514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from transformers import RobertaTokenizerFast
from utils import RobertaConfig
from model import RobertaForQuestionAnswering
from safetensors.torch import load_file
from datasets import load_dataset
from pprintpp import pprint

# import os

# # Set the HF_HOME environment variable
# os.environ["HF_HOME"] = "/tmp/hf_cache" 

class InferenceModel:
    """
    Quick inference wrapper that works with the question-answering models we have trained!

    Builds a tokenizer and a ``RobertaForQuestionAnswering`` model from a
    ``RobertaConfig``, loads trained weights from a safetensors file, and
    exposes :meth:`inference_model` to answer one question/context pair.
    """

    def __init__(self, path_to_weights, huggingface_model=True):
        """
        Build the tokenizer and model, then load the trained weights.

        Args:
            path_to_weights: Path to the safetensors file holding the model
                state dict.
            huggingface_model: When true the config is created with
                ``pretrained_backbone="pretrained_huggingface"``, otherwise
                with ``"random"``.
        """

        ### Init Config with either Huggingface Backbone or our own ###
        backbone = "pretrained_huggingface" if huggingface_model else "random"
        self.config = RobertaConfig(pretrained_backbone=backbone)

        ### Load Tokenizer (name of the tokenizer comes from the config) ###
        self.tokenizer = RobertaTokenizerFast.from_pretrained(self.config.hf_model_name)

        ### Load Model ###
        self.model = RobertaForQuestionAnswering(self.config)

        weights = load_file(path_to_weights)
        self.model.load_state_dict(weights)

        ### Inference only: switch the module to eval mode ###
        self.model.eval()

    def inference_model(self,
                        question,
                        context):
        """
        Predict the answer span for a single question/context pair.

        Args:
            question: Question text, passed to the tokenizer as ``text``.
            context: Passage text, passed to the tokenizer as ``text_pair``.
                With ``truncation="only_second"`` only this segment is
                truncated when the pair exceeds ``config.context_length``.

        Returns:
            dict with keys ``"start_token_idx"`` and ``"end_token_idx"``
            (argmax positions of the start/end logits) and ``"answer"`` (the
            decoded tokens of that inclusive range, special tokens skipped
            and surrounding whitespace stripped). When the predicted end
            lies before the predicted start the token slice is empty.
        """
        ### Tokenize Text
        inputs = self.tokenizer(text=question,
                                text_pair=context,
                                max_length=self.config.context_length,
                                truncation="only_second",
                                return_tensors="pt")

        ### Pass through Model ####
        with torch.no_grad():
            start_token_logits, end_token_logits = self.model(**inputs)

        ### Grab Start and End Token Idx ###
        # argmax() without a dim works on the flattened tensor, which for the
        # single example handled here is the position inside the sequence.
        start_token_idx = start_token_logits.argmax().item()
        end_token_idx = end_token_logits.argmax().item()

        ### Slice Tokens and then Decode with Tokenizer (+1 because slice is not right inclusive) ###
        # Index the batch dimension explicitly instead of squeeze(): squeeze()
        # would also drop the sequence dimension of a one-token input and make
        # the tensor unsliceable.
        token_ids = inputs["input_ids"][0]
        tokens = token_ids[start_token_idx:end_token_idx + 1]
        answer = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()

        prediction = {"start_token_idx": start_token_idx,
                      "end_token_idx": end_token_idx,
                      "answer": answer}

        return prediction

    # test model