File size: 1,096 Bytes
5bc9a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import MarianMTModel, MarianTokenizer
from typing import Any, List, Dict

class EndpointHandler:
    """Translate text with a MarianMT model/tokenizer loaded from ``path``.

    NOTE(review): the ``EndpointHandler(path)`` constructor plus
    ``__call__(data)`` shape looks like the Hugging Face Inference Endpoints
    custom-handler convention; confirm against the deployment.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                ``MarianMTModel`` and the ``MarianTokenizer``.
        """
        self.model = MarianMTModel.from_pretrained(path)
        self.tokenizer = MarianTokenizer.from_pretrained(path)

    def __call__(self, data: Any) -> List[Dict[str, str]]:
        """Translate the text stored under ``data["inputs"]``.

        Args:
            data (dict): The request payload. ``"inputs"`` holds either a
                single string or a list of strings to translate. A missing
                key is treated as the empty string ``""``.
        Returns:
            List[Dict]: One ``{"translation_text": ...}`` dict per generated
            sequence, in input order. A single string yields a one-element
            list. An empty list of inputs yields ``[]``.
        """
        text = data.get("inputs", "")

        # An empty batch has nothing to translate. Return early instead of
        # handing empty tensors to generate().
        if isinstance(text, (list, tuple)) and not text:
            return []

        # padding=True lets a list of differently sized strings form one batch.
        inputs = self.tokenizer(text, return_tensors="pt", padding=True)

        translated = self.model.generate(**inputs)

        # Decode every generated sequence, not just row 0. Otherwise batched
        # inputs would silently lose all translations after the first.
        translated_texts = self.tokenizer.batch_decode(
            translated, skip_special_tokens=True
        )

        return [{"translation_text": t} for t in translated_texts]