ikeno-ada committed on
Commit c45ae9b · verified · 1 Parent(s): 51b86e0

Create handler.py

Files changed (1):
handler.py +32 -0
handler.py ADDED
@@ -0,0 +1,32 @@
+ import os
+ os.environ["CT2_USE_EXPERIMENTAL_PACKED_GEMM"] = "1"
+ import ctranslate2
+ from typing import Dict, List, Any
+ from transformers import T5TokenizerFast
+ from huggingface_hub import snapshot_download
+ from sentencepiece import SentencePieceProcessor
+
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+         # download the 8-bit CTranslate2 conversion of MADLAD-400 3B and load it with its tokenizer
+         model_path = snapshot_download('ikeno-ada/madlad400-3b-mt-8bit-ct2')
+         self.translator = ctranslate2.Translator(model_path)
+         self.tokenizer = T5TokenizerFast.from_pretrained('google/madlad400-3b-mt')
+
+     def __call__(self, data: Any) -> Dict[str, str]:
+         """
+         Args:
+             data (:obj:`dict`):
+                 the request payload: {"inputs": {"text": ..., "langId": ...}}
+         """
+         text = data.get("inputs").get("text")
+         langId = data.get("inputs").get("langId")
+
+         input_text = f"<2{langId}>{text}"  # MADLAD-400 expects a <2xx> target-language prefix
+         input_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(input_text))
+         results = self.translator.translate_batch([input_tokens], batch_type="tokens")
+
+         output_tokens = results[0].hypotheses[0]  # best hypothesis for the single input
+         output_text = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(output_tokens))
+         return {"translated": output_text}