# Spaces: Sleeping
# translation_model.py
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
class Translator:
    """Translate text between English and Traditional Chinese with a seq2seq model.

    Language codes are the ones this class passes to the tokenizer
    (``eng_Latn`` and ``zho_Hant``).
    """

    # Class-level defaults kept for backward compatibility; __init__ sets real
    # per-instance values before the tokenizer is loaded.
    src_lang, tgt_lang = "", ""

    # Supported direction strings -> (source language code, target language code).
    _LANG_CODES = {
        "English -> Chinese": ("eng_Latn", "zho_Hant"),
        "Chinese -> English": ("zho_Hant", "eng_Latn"),
    }

    def set_lang_codes(self, direction):
        """Set ``src_lang`` and ``tgt_lang`` from a direction string.

        Args:
            direction: ``"English -> Chinese"`` or ``"Chinese -> English"``.

        Raises:
            ValueError: If *direction* is not a supported direction.
        """
        try:
            self.src_lang, self.tgt_lang = self._LANG_CODES[direction]
        except (KeyError, TypeError):
            # TypeError covers unhashable inputs, which the original
            # equality-based chain also rejected with ValueError.
            raise ValueError("Unsupported translation direction") from None

    def __init__(self, model_name='6yuru99/medical-nllb-200-en2zh_hant'):
        """Load the tokenizer and model.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``.
                The default id suggests a fine-tune for English -> Traditional
                Chinese (inferred from the name only; verify on the model card).
        """
        # Start from a real language code. Previously the class default ""
        # was handed to the tokenizer as its source language.
        self.set_lang_codes("English -> Chinese")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=self.src_lang)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def translate(self, text, direction):
        """Translate *text* in the given direction.

        Args:
            text: The text to translate.
            direction: ``"English -> Chinese"`` or ``"Chinese -> English"``.

        Returns:
            The first decoded output sequence, with special tokens removed.

        Raises:
            ValueError: If *direction* is not supported.
        """
        self.set_lang_codes(direction)
        # The tokenizer's source-language setting must follow the requested
        # direction. Otherwise it keeps the value given at construction, and
        # "Chinese -> English" input would be tagged as the wrong language.
        self.tokenizer.src_lang = self.src_lang
        # Keep the inputs on the same device as the model, in case it was moved
        # off the CPU.
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        # Force the first generated token to be the target-language code.
        target_id = self.tokenizer.convert_tokens_to_ids(self.tgt_lang)
        translated_tokens = self.model.generate(
            **inputs, forced_bos_token_id=target_id, max_length=1024
        )
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
# Shared module-level instance. Constructing it calls from_pretrained for the
# tokenizer and the model, so importing this module loads both.
translator: Translator = Translator()