import os

# Opt in to CTranslate2's experimental packed GEMM kernels before importing the library.
os.environ["CT2_USE_EXPERIMENTAL_PACKED_GEMM"] = "1"

import ctranslate2
from typing import Any, Dict
from transformers import T5TokenizerFast
from huggingface_hub import snapshot_download


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Download the 8-bit CTranslate2 conversion of MADLAD-400 3B and
        # load the tokenizer of the original base model.
        model_path = snapshot_download('ikeno-ada/madlad400-3b-mt-8bit-ct2')
        self.translator = ctranslate2.Translator(model_path)
        self.tokenizer = T5TokenizerFast.from_pretrained('google/madlad400-3b-mt')

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """
        Args:
            data (:obj:`dict`):
                Includes the input data and the parameters for the inference,
                e.g. {"inputs": {"text": "...", "langId": "fr"}}.
        Returns:
            A dict with the translation under the "translated" key.
        """
        inputs = data.get("inputs", {})
        text = inputs.get("text")
        langId = inputs.get("langId")

        # MADLAD-400 expects the target language token (e.g. <2fr>) prepended to the source text.
        input_text = f"<2{langId}>{text}"
        input_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(input_text))
        results = self.translator.translate_batch([input_tokens], batch_type="tokens")

        output_tokens = results[0].hypotheses[0]
        output_text = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(output_tokens))
        return {"translated": output_text}
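

# Minimal local smoke test, not part of the Inference Endpoints contract.
# The payload below is a sketch: the sample sentence and the target language
# id ("fr") are illustrative, assuming the deployed handler receives the same
# {"inputs": {"text": ..., "langId": ...}} shape it parses above.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {"inputs": {"text": "Hello, how are you?", "langId": "fr"}}
    print(handler(payload))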