# D_Nikud / handler.py
# Uploaded by NadavShaked ("Upload 7 files", commit 91da6cc, verified).
from typing import Dict, List, Any
from transformers import AutoConfig, AutoTokenizer
from src.models import DNikudModel, ModelConfig
from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
from src.utiles_data import Nikud, NikudDataset
from src.models_utils import predict_single, predict
import torch
import os
from tqdm import tqdm
class EndpointHandler:
    """Inference-endpoint wrapper around the D-Nikud diacritics model.

    On construction it loads the ``tau/tavbert-he`` tokenizer, the model
    config and the ``DNikudModel`` weights from ``<path>/models/``. Requests
    are then served through ``__call__``.
    """

    def __init__(self, path=""):
        """Load tokenizer, model config and trained weights.

        Args:
            path: directory that contains the ``models/`` folder
                (``config.yml`` and ``Dnikud_best_model.pth``). The default
                ``""`` resolves to ``models/...`` relative to the current
                working directory, which is what this handler always used.
        """
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("tau/tavbert-he")
        # Honour the directory the endpoint runtime passes in instead of
        # relying on the process CWD.
        models_dir = os.path.join(path, "models")
        self.config = ModelConfig.load_from_file(
            os.path.join(models_dir, "config.yml")
        )
        self.model = DNikudModel(
            self.config,
            len(Nikud.label_2_id["nikud"]),
            len(Nikud.label_2_id["dagesh"]),
            len(Nikud.label_2_id["sin"]),
            device=self.DEVICE,
        ).to(self.DEVICE)
        # Overlay the saved weights on the freshly built state dict so keys
        # missing from the checkpoint keep their initial values.
        state_dict_model = self.model.state_dict()
        state_dict_model.update(
            torch.load(
                os.path.join(models_dir, "Dnikud_best_model.pth"),
                # Without map_location, a checkpoint saved on GPU fails to
                # load on a CPU-only host.
                map_location=self.DEVICE,
            )
        )
        self.model.load_state_dict(state_dict_model)
        # This handler only runs inference; put layers such as dropout in
        # evaluation mode.
        self.model.eval()
        self.max_length = MAX_LENGTH_SEN

    def back_2_text(self, labels, text):
        """Re-attach predicted marks to each character of ``text``.

        For character ``text[i]`` the ids in ``labels[i][1]`` are decoded with
        ``Nikud.id_2_char``. They are read as ``[nikud, dagesh, sin]`` (the
        same order ``prepare_data`` uses) and appended after the character
        in the order dagesh, sin, nikud.

        Args:
            labels: per-character predictions indexable as
                ``labels[i][1][k]``. NOTE(review): exact structure comes from
                the caller; confirm against ``predict``'s output.
            text: the undiacritized input string.

        Returns:
            The text with the decoded marks inserted after each character.
        """
        nikud = Nikud()
        pieces = []
        for indx_char, c in enumerate(text):
            ids = labels[indx_char][1]
            pieces.append(
                c
                + nikud.id_2_char(ids[1], "dagesh")
                + nikud.id_2_char(ids[2], "sin")
                + nikud.id_2_char(ids[0], "nikud")
            )
        return "".join(pieces)

    def prepare_data(self, data, name="train"):
        """Tokenize sentences and build fixed-length label tensors.

        Args:
            data: iterable of ``(sentence, letters)`` pairs. Each element of
                ``letters`` must expose ``.nikud``, ``.dagesh`` and ``.sin``.
            name: label shown in the progress bar.

        Side effects:
            Sets ``self.prepered_data`` to a list of
            ``(input_ids, attention_mask, labels)`` tuples. Each ``labels``
            tensor has ``self.max_length`` rows of ``[nikud, dagesh, sin]``.
            It starts with one padding row (aligned with the leading special
            token) and is padded or truncated to length.
        """
        pad_row = [
            Nikud.PAD_OR_IRRELEVANT,
            Nikud.PAD_OR_IRRELEVANT,
            Nikud.PAD_OR_IRRELEVANT,
        ]
        dataset = []
        for sentence, letters in tqdm(data, desc=f"Prepare data {name}"):
            encoded_sequence = self.tokenizer.encode_plus(
                sentence,
                add_special_tokens=True,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            label_rows = [
                [letter.nikud, letter.dagesh, letter.sin] for letter in letters
            ]
            # A negative count yields no padding (sequence already truncated).
            n_pad = max(self.max_length - len(letters) - 1, 0)
            # torch.tensor copies its input, so reusing pad_row is safe.
            labels_tensor = torch.tensor(
                [pad_row]
                + label_rows[: (self.max_length - 1)]
                + [pad_row] * n_pad
            )
            dataset.append(
                (
                    encoded_sequence["input_ids"][0],
                    encoded_sequence["attention_mask"][0],
                    labels_tensor,
                )
            )
        self.prepered_data = dataset

    def predict_single_text(
        self,
        text,
    ):
        """Run the model on one text and return it with predicted marks.

        Args:
            text: the input text to process.

        Returns:
            The result of ``NikudDataset.back_2_text`` for the predicted
            labels.
        """
        dataset = NikudDataset(tokenizer=self.tokenizer, max_length=self.max_length)
        # Presumably stores the parsed text on ``dataset`` for the following
        # prepare_data call; its return values are not needed here.
        dataset.read_single_text(text)
        dataset.prepare_data(name="inference")
        mtb_prediction_dl = torch.utils.data.DataLoader(
            dataset.prepered_data, batch_size=BATCH_SIZE
        )
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            all_labels = predict(self.model, mtb_prediction_dl, self.DEVICE)
        return dataset.back_2_text(labels=all_labels)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: request payload. The text is read from the ``"text"`` key,
                falling back to ``"inputs"`` (the key Hugging Face Inference
                Endpoints send). If neither key is present, or ``data`` is not
                a dict, ``data`` itself is used as the text. The caller's dict
                is not modified.

        Returns:
            Whatever ``predict_single_text`` returns. NOTE(review): the
            ``List[Dict[str, Any]]`` annotation has not been verified against
            that value.
        """
        inputs = data
        if isinstance(data, dict):
            if "text" in data:
                inputs = data["text"]
            elif "inputs" in data:
                inputs = data["inputs"]
        return self.predict_single_text(inputs)