import torch
from transformers import RobertaTokenizerFast
from utils import RobertaConfig
from model import RobertaForQuestionAnswering
from safetensors.torch import load_file
from datasets import load_dataset
from pprintpp import pprint
# import os
# # Set the HF_HOME environment variable (Hugging Face cache directory)
# os.environ["HF_HOME"] = "/tmp/hf_cache"

class InferenceModel:
    """
    Quick inference wrapper that works with the models we have trained!
    """

    def __init__(self, path_to_weights, huggingface_model=True):
        ### Init Config with either Huggingface Backbone or our own ###
        self.config = RobertaConfig(pretrained_backbone="pretrained_huggingface" if huggingface_model else "random")
        ### Load Tokenizer ###
        self.tokenizer = RobertaTokenizerFast.from_pretrained(self.config.hf_model_name)

        ### Load Model and safetensors Weights ###
        self.model = RobertaForQuestionAnswering(self.config)
        weights = load_file(path_to_weights)
        self.model.load_state_dict(weights)
        self.model.eval()
    def inference_model(self, question, context):

        ### Tokenize Question/Context Pair (truncate only the context if it is too long) ###
        inputs = self.tokenizer(text=question,
                                text_pair=context,
                                max_length=self.config.context_length,
                                truncation="only_second",
                                return_tensors="pt")
        ### Pass through Model ###
        with torch.no_grad():
            start_token_logits, end_token_logits = self.model(**inputs)

        ### Grab Most Likely Start and End Token Idx ###
        ### (note: independent argmaxes can pick end < start, yielding an empty answer; fine for a quick demo) ###
        start_token_idx = start_token_logits.squeeze().argmax().item()
        end_token_idx = end_token_logits.squeeze().argmax().item()

        ### Slice Tokens and then Decode with Tokenizer (+1 because slicing is not right-inclusive) ###
        tokens = inputs["input_ids"].squeeze()[start_token_idx:end_token_idx + 1]
        answer = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()

        prediction = {"start_token_idx": start_token_idx,
                      "end_token_idx": end_token_idx,
                      "answer": answer}

        return prediction
# test model
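# A minimal sketch of how this class might be exercised, using the `load_dataset`
# and `pprint` imports above. The checkpoint path "model.safetensors" is a
# hypothetical placeholder; point it at your actual trained weights file.
if __name__ == "__main__":
    model = InferenceModel(path_to_weights="model.safetensors")

    # Pull one question/context pair from the SQuAD v2 validation split
    dataset = load_dataset("squad_v2", split="validation")
    sample = dataset[0]

    prediction = model.inference_model(question=sample["question"],
                                       context=sample["context"])
    pprint(prediction)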