import os
import sys
import warnings
from typing import Any, Dict, List

import torch
from tokenizers import Tokenizer

# Add the current directory to the system path so the model module is found
sys.path.append(os.path.dirname(__file__))
from model import build_transformer

warnings.simplefilter("ignore", category=FutureWarning)


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the handler: load the tokenizers, build the transformer,
        and restore the trained EN->IT weights.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Paths for the weights and the source/target tokenizers
        self.model_weights_path = os.path.join(path, "EN-IT.pt")
        self.tokenizer_src_path = os.path.join(path, "tokenizer_en.json")
        self.tokenizer_tgt_path = os.path.join(path, "tokenizer_it.json")

        # Load tokenizers
        self.tokenizer_src = Tokenizer.from_file(self.tokenizer_src_path)
        self.tokenizer_tgt = Tokenizer.from_file(self.tokenizer_tgt_path)

        # Build the transformer with the same hyperparameters used in training
        self.model = build_transformer(
            src_vocab_size=self.tokenizer_src.get_vocab_size(),
            tgt_vocab_size=self.tokenizer_tgt.get_vocab_size(),
            src_seq_len=350,  # Match the trained model's sequence length
            tgt_seq_len=350,
            d_model=512,
            num_layers=6,
            num_heads=8,
            dropout=0.1,
            d_ff=2048,
        ).to(self.device)

        # Load the pretrained weights and switch to inference mode
        print(f"Loading weights from: {self.model_weights_path}")
        checkpoint = torch.load(self.model_weights_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Translate the request payload {"inputs": "<English text>"} and return
        [{"translation": "<Italian text>"}], or [{"error": ...}] on failure.
        """
        try:
            inputs = data.get("inputs", "")
            if not inputs:
                return [{"error": "No 'inputs' provided in request"}]

            # Truncate overly long inputs so [SOS] + tokens + [EOS] still fit
            # in the fixed 350-token sequence, then pad to exactly 350.
            source_ids = self.tokenizer_src.encode(inputs).ids[: 350 - 2]
            source = torch.cat([
                torch.tensor([self.tokenizer_src.token_to_id("[SOS]")], dtype=torch.int64),
                torch.tensor(source_ids, dtype=torch.int64),
                torch.tensor([self.tokenizer_src.token_to_id("[EOS]")], dtype=torch.int64),
                torch.tensor([self.tokenizer_src.token_to_id("[PAD]")] * (350 - len(source_ids) - 2), dtype=torch.int64),
            ], dim=0).to(self.device)

            # Mask convention: 1 = attend, 0 = masked out
            source_mask = (source != self.tokenizer_src.token_to_id("[PAD]")).unsqueeze(0).unsqueeze(1).int().to(self.device)

            with torch.no_grad():
                encoder_output = self.model.encode(source, source_mask)

                # Greedy decoding: start from [SOS] and append the most
                # probable next token until [EOS] or the maximum length.
                decoder_input = torch.empty(1, 1).fill_(self.tokenizer_tgt.token_to_id("[SOS]")).type_as(source).to(self.device)
                predicted_words = []
                while decoder_input.size(1) < 350:
                    # Causal mask: each position may attend only to itself and
                    # earlier positions. The `== 0` inversion keeps zeros above
                    # the diagonal, assuming the model masks positions where
                    # the mask is 0 (the same convention as source_mask).
                    decoder_mask = (torch.triu(
                        torch.ones((1, decoder_input.size(1), decoder_input.size(1))),
                        diagonal=1,
                    ) == 0).type_as(source_mask).to(self.device)
                    out = self.model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
                    prob = self.model.project(out[:, -1])
                    _, next_word = torch.max(prob, dim=1)
                    decoder_input = torch.cat(
                        [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(self.device)],
                        dim=1,
                    )
                    # Stop before appending the end-of-sequence token
                    if next_word.item() == self.tokenizer_tgt.token_to_id("[EOS]"):
                        break
                    predicted_words.append(self.tokenizer_tgt.decode([next_word.item()]))

            predicted_translation = " ".join(predicted_words).strip()
            return [{"translation": predicted_translation}]
        except Exception as e:
            return [{"error": str(e)}]