Spaces:
Sleeping
Sleeping
import torch | |
from transformers import RobertaTokenizerFast | |
from utils import RobertaConfig | |
from model import RobertaForQuestionAnswering | |
from safetensors.torch import load_file | |
from datasets import load_dataset | |
from pprintpp import pprint | |
# import os | |
# # Set the HF_HOME environment variable
# os.environ["HF_HOME"] = "/tmp/hf_cache" | |
class InferenceModel:
    """
    Quick inference wrapper that works with the question-answering models we have trained.

    The constructor builds the config, tokenizer and model once and loads the
    trained weights from a safetensors file; ``inference_model`` then answers
    a question against a context passage.
    """

    def __init__(self, path_to_weights: str, huggingface_model: bool = True):
        """
        Build the config, tokenizer and model, then load the trained weights.

        Args:
            path_to_weights: Path to the safetensors file holding the weights
                that are loaded into the model's state dict.
            huggingface_model: When True the config is created with the
                "pretrained_huggingface" backbone, otherwise with "random".
        """
        ### Init Config with either Huggingface Backbone or our own ###
        backbone = "pretrained_huggingface" if huggingface_model else "random"
        self.config = RobertaConfig(pretrained_backbone=backbone)

        ### Load Tokenizer (the checkpoint name comes from the config) ###
        self.tokenizer = RobertaTokenizerFast.from_pretrained(self.config.hf_model_name)

        ### Load Model ###
        self.model = RobertaForQuestionAnswering(self.config)
        weights = load_file(path_to_weights)
        self.model.load_state_dict(weights)

        # This class is only used for prediction, so switch to evaluation mode.
        self.model.eval()

    def inference_model(self, question: str, context: str) -> dict:
        """
        Predict the answer span for ``question`` inside ``context``.

        Args:
            question: The question text (first sequence given to the tokenizer).
            context: The passage to search (second sequence; the only one
                truncated when the pair exceeds ``config.context_length``).

        Returns:
            A dict with:
              * "start_token_idx": argmax position of the start logits,
              * "end_token_idx": argmax position of the end logits,
              * "answer": the decoded, stripped text of the input tokens from
                start to end (inclusive).
            Both indices refer to positions in the tokenized question+context
            input. When the predicted end lies before the start, the token
            slice is empty.
        """
        ### Tokenize Text ###
        inputs = self.tokenizer(text=question,
                                text_pair=context,
                                max_length=self.config.context_length,
                                truncation="only_second",
                                return_tensors="pt")

        ### Pass through Model (no gradients needed for inference) ###
        with torch.no_grad():
            start_token_logits, end_token_logits = self.model(**inputs)

        ### Grab Start and End Token Idx ###
        start_token_idx = start_token_logits.squeeze().argmax().item()
        end_token_idx = end_token_logits.squeeze().argmax().item()

        ### Slice Tokens and then Decode with Tokenizer (+1 because slice is not right inclusive) ###
        tokens = inputs["input_ids"].squeeze()[start_token_idx:end_token_idx + 1]
        answer = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()

        prediction = {"start_token_idx": start_token_idx,
                      "end_token_idx": end_token_idx,
                      "answer": answer}
        return prediction
# test model |