import os
import sys
import warnings
from typing import Any, Dict, List

import torch
from tokenizers import Tokenizer

# Add the current directory to the system path to locate the model module
sys.path.append(os.path.dirname(__file__))

from model import build_transformer

# Hide FutureWarnings (e.g. torch.load on recent PyTorch warns that the
# weights_only default is changing) so they do not clutter endpoint logs
warnings.simplefilter("ignore", category=FutureWarning)


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the handler. Load the model and tokenizer.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Paths for weights and tokenizers
        self.model_weights_path = os.path.join(path, "EN-IT.pt")
        self.tokenizer_src_path = os.path.join(path, "tokenizer_en.json")
        self.tokenizer_tgt_path = os.path.join(path, "tokenizer_it.json")

        # Load tokenizers
        self.tokenizer_src = Tokenizer.from_file(self.tokenizer_src_path)
        self.tokenizer_tgt = Tokenizer.from_file(self.tokenizer_tgt_path)

        # Sequence length must match the value used when the model was trained
        self.seq_len = 350

        # Build the transformer model
        self.model = build_transformer(
            src_vocab_size=self.tokenizer_src.get_vocab_size(),
            tgt_vocab_size=self.tokenizer_tgt.get_vocab_size(),
            src_seq_len=self.seq_len,
            tgt_seq_len=self.seq_len,
            d_model=512,
            num_layers=6,
            num_heads=8,
            dropout=0.1,
            d_ff=2048
        ).to(self.device)

        # Load the pretrained weights
        print(f"Loading weights from: {self.model_weights_path}")
        checkpoint = torch.load(self.model_weights_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the incoming request and return the translation.

        Expected payload:  {"inputs": "Hello, world!"}
        Success response:  [{"translation": "..."}]  (Italian text)
        Failure response:  [{"error": "<message>"}]
        """
        try:
            inputs = data.get("inputs", "")
            if not inputs:
                return [{"error": "No 'inputs' provided in request"}]

            # Tokenize the source sentence and reject inputs that cannot fit
            # in the fixed sequence length alongside [SOS] and [EOS]
            enc = self.tokenizer_src.encode(inputs)
            if len(enc.ids) > self.seq_len - 2:
                return [{"error": f"Input too long ({len(enc.ids)} tokens); "
                                  f"maximum is {self.seq_len - 2}"}]

            sos_id = self.tokenizer_src.token_to_id("[SOS]")
            eos_id = self.tokenizer_src.token_to_id("[EOS]")
            pad_id = self.tokenizer_src.token_to_id("[PAD]")

            # Encoder input: [SOS] <tokens> [EOS], padded out to seq_len
            source = torch.cat([
                torch.tensor([sos_id], dtype=torch.int64),
                torch.tensor(enc.ids, dtype=torch.int64),
                torch.tensor([eos_id], dtype=torch.int64),
                torch.tensor([pad_id] * (self.seq_len - len(enc.ids) - 2), dtype=torch.int64)
            ], dim=0).to(self.device)

            # Encoder mask hides only the padding positions
            source_mask = (source != pad_id).unsqueeze(0).unsqueeze(1).int().to(self.device)

            # Inference only: no gradients needed
            with torch.no_grad():
                encoder_output = self.model.encode(source, source_mask)

                # Greedy decoding: start from [SOS] and repeatedly append the
                # most probable next token until [EOS] or the length limit
                decoder_input = torch.empty(1, 1).fill_(
                    self.tokenizer_tgt.token_to_id("[SOS]")).type_as(source).to(self.device)
                predicted_words = []
                eos_tgt_id = self.tokenizer_tgt.token_to_id("[EOS]")

                while decoder_input.size(1) < self.seq_len:
                    # Causal mask: 1 where a position may attend (itself and
                    # earlier positions), matching the source_mask convention
                    decoder_mask = (torch.triu(
                        torch.ones((1, decoder_input.size(1), decoder_input.size(1))),
                        diagonal=1
                    ).type(torch.int) == 0).type_as(source_mask).to(self.device)

                    out = self.model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
                    prob = self.model.project(out[:, -1])
                    _, next_word = torch.max(prob, dim=1)

                    decoder_input = torch.cat(
                        [decoder_input,
                         torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(self.device)],
                        dim=1)

                    # Stop before emitting the end-of-sequence token
                    if next_word.item() == eos_tgt_id:
                        break

                    predicted_words.append(self.tokenizer_tgt.decode([next_word.item()]))

            return [{"translation": " ".join(predicted_words).strip()}]
        except Exception as e:
            return [{"error": str(e)}]