|
import os |
|
# Opt in to CTranslate2's experimental packed-GEMM kernels.
# NOTE(review): set before `import ctranslate2` below — presumably the
# library reads this at import/load time; confirm against CTranslate2 docs.
os.environ["CT2_USE_EXPERIMENTAL_PACKED_GEMM"] = "1"
|
import ctranslate2 |
|
from typing import Dict, List, Any |
|
from transformers import T5TokenizerFast |
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
class EndpointHandler():
    """Inference-endpoint handler translating text with MADLAD-400 (3B, 8-bit CT2).

    Loads the quantized CTranslate2 model from the Hugging Face Hub and the
    matching T5 tokenizer, then serves translation requests via ``__call__``.
    """

    def __init__(self, path=""):
        # `path` is unused; kept for compatibility with the HF endpoint
        # handler convention, which passes the local model directory here.
        model_path = snapshot_download('ikeno-ada/madlad400-3b-mt-8bit-ct2')
        self.translator = ctranslate2.Translator(model_path)
        self.tokenizer = T5TokenizerFast.from_pretrained('google/madlad400-3b-mt')

    def __call__(self, data: Any) -> Dict[str, str]:
        """Translate ``text`` into the language ``langId``.

        Args:
            data: request payload shaped like
                ``{"inputs": {"text": <str>, "langId": <str>}}``.

        Returns:
            ``{"translated": <translated text>}``.
        """
        # Guard against a payload with no "inputs" key: the original
        # chained .get().get() would raise AttributeError on None.
        inputs = data.get("inputs") or {}
        text = inputs.get("text")
        langId = inputs.get("langId")

        # MADLAD-400 encodes the target language as a "<2xx>" prefix token.
        input_text = f"<2{langId}>{text}"

        # BUG FIX: the original referenced bare `tokenizer` / `translator`,
        # which do not exist in this scope (NameError on every request);
        # they are the instance attributes set in __init__.
        input_tokens = self.tokenizer.convert_ids_to_tokens(
            self.tokenizer.encode(input_text)
        )
        results = self.translator.translate_batch(
            [input_tokens], batch_type="tokens"
        )

        output_tokens = results[0].hypotheses[0]
        output_text = self.tokenizer.decode(
            self.tokenizer.convert_tokens_to_ids(output_tokens)
        )
        return {"translated": output_text}