ikeno-ada's picture
Update handler.py
09f6aea verified
import os
os.environ["CT2_USE_EXPERIMENTAL_PACKED_GEMM"] = "1"
import ctranslate2
from typing import Dict, List, Any
from transformers import T5TokenizerFast
from huggingface_hub import snapshot_download
class EndpointHandler():
    """Inference Endpoints handler wrapping the 8-bit CTranslate2 build of
    google/madlad400-3b-mt for text translation."""

    def __init__(self, path=""):
        # Load the quantized CTranslate2 model weights from the Hub
        # (the `path` given by the endpoint runtime is ignored; the model is
        # always pulled from this fixed repo).
        model_path = snapshot_download('ikeno-ada/madlad400-3b-mt-8bit-ct2')
        self.translator = ctranslate2.Translator(model_path)
        # Tokenizer comes from the original (non-quantized) model repo.
        self.tokenizer = T5TokenizerFast.from_pretrained('google/madlad400-3b-mt')

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Translate one text.

        Args:
            data: request payload; expects
                ``{"inputs": {"text": <source text>, "langId": <target lang code>}}``.

        Returns:
            ``{"translated": <translated text>}``.
        """
        # Guard against a missing/None "inputs" key instead of raising
        # AttributeError on None.
        inputs = data.get("inputs") or {}
        text = inputs.get("text")
        langId = inputs.get("langId")
        # MADLAD-400 selects the target language via a "<2xx>" prefix token.
        input_text = f"<2{langId}>{text}"
        # BUG FIX: the original referenced bare `tokenizer`/`translator`,
        # which are instance attributes — bare names raise NameError.
        input_tokens = self.tokenizer.convert_ids_to_tokens(
            self.tokenizer.encode(input_text)
        )
        results = self.translator.translate_batch([input_tokens], batch_type="tokens")
        output_tokens = results[0].hypotheses[0]
        output_text = self.tokenizer.decode(
            self.tokenizer.convert_tokens_to_ids(output_tokens)
        )
        return {"translated": output_text}