File size: 1,177 Bytes
ba3c963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8542393
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        
class Translator:
    """Translate text between English and Traditional Chinese with a seq2seq model.

    The translation direction is chosen per call of ``translate``; the
    matching language codes are stored on ``src_lang`` / ``tgt_lang``.
    """

    # Class-level defaults kept for backward compatibility; set_lang_codes()
    # overwrites them on the instance before every translation.
    src_lang, tgt_lang = "", ""

    # Direction label -> (source language code, target language code).
    _LANG_CODES = {
        "English -> Chinese": ("eng_Latn", "zho_Hant"),
        "Chinese -> English": ("zho_Hant", "eng_Latn"),
    }

    def __init__(self, model_name='6yuru99/medical-nllb-200-en2zh_hant'):
        """Load the tokenizer and seq2seq model.

        Args:
            model_name: Model identifier or local path passed to
                ``from_pretrained``.
        """
        # Do not pass src_lang here: at construction time it is still the
        # empty class default, which is not a valid language code.
        # translate() sets the tokenizer's source language on every call.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def set_lang_codes(self, direction):
        """Set ``src_lang`` / ``tgt_lang`` for a direction label.

        Args:
            direction: Either ``"English -> Chinese"`` or
                ``"Chinese -> English"``.

        Raises:
            ValueError: If *direction* is not one of the supported labels.
        """
        try:
            self.src_lang, self.tgt_lang = self._LANG_CODES[direction]
        except (KeyError, TypeError):
            # TypeError covers unhashable inputs, which the original
            # equality-based check also rejected with ValueError.
            raise ValueError("Unsupported translation direction") from None

    def translate(self, text, direction, max_length=1024):
        """Translate *text* in the given *direction*.

        Args:
            text: Input text to translate.
            direction: Direction label accepted by ``set_lang_codes``.
            max_length: Maximum generated sequence length passed to
                ``generate`` (default 1024, as before).

        Returns:
            The decoded translation of the first (only) sequence in the batch.

        Raises:
            ValueError: If *direction* is unsupported.
        """
        self.set_lang_codes(direction)
        # The tokenizer's source-language setting must follow the requested
        # direction. Before this fix it was never updated after construction,
        # so Chinese input was not tagged as zho_Hant.
        self.tokenizer.src_lang = self.src_lang
        inputs = self.tokenizer(text, return_tensors="pt")
        # Force the decoder to start with the target-language token.
        forced_bos = self.tokenizer.convert_tokens_to_ids(self.tgt_lang)
        translated_tokens = self.model.generate(
            **inputs,
            forced_bos_token_id=forced_bos,
            max_length=max_length,
        )
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

# Module-level shared instance. Constructing it calls `from_pretrained` for
# both the tokenizer and the model, so the model is loaded as soon as this
# module is imported (and may be downloaded on first use — TODO confirm
# whether eager loading at import time is intended).
translator = Translator()