6yuru99 commited on
Commit
ba3c963
1 Parent(s): 9a950fb

Update translation_model.py

Browse files
Files changed (1) hide show
  1. translation_model.py +26 -28
translation_model.py CHANGED
@@ -1,29 +1,27 @@
1
- # translation_model.py
2
-
3
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
-
5
- class Translator:
6
- src_lang, tgt_lang = "", ""
7
-
8
- def set_lang_codes(self, direction):
9
- if direction == "English -> Chinese":
10
- self.src_lang = "eng_Latn"
11
- self.tgt_lang = "zho_Hant"
12
- elif direction == "Chinese -> English":
13
- self.src_lang = "zho_Hant"
14
- self.tgt_lang = "eng_Latn"
15
- else:
16
- raise ValueError("Unsupported translation direction")
17
-
18
- def __init__(self, model_name='6yuru99/medical-nllb-200-en2zh_hant'):
19
- self.tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=self.src_lang)
20
- self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
21
-
22
- def translate(self, text, direction):
23
- self.set_lang_codes(direction)
24
- inputs = self.tokenizer(text, return_tensors="pt")
25
- translated_tokens = self.model.generate(**inputs, forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(self.tgt_lang), max_length=1024)
26
- outputs = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
27
- return outputs
28
-
29
  translator = Translator()
 
1
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
2
+
3
+ class Translator:
4
+ src_lang, tgt_lang = "", ""
5
+
6
+ def set_lang_codes(self, direction):
7
+ if direction == "English -> Chinese":
8
+ self.src_lang = "eng_Latn"
9
+ self.tgt_lang = "zho_Hant"
10
+ elif direction == "Chinese -> English":
11
+ self.src_lang = "zho_Hant"
12
+ self.tgt_lang = "eng_Latn"
13
+ else:
14
+ raise ValueError("Unsupported translation direction")
15
+
16
+ def __init__(self, model_name='6yuru99/medical-nllb-200-en2zh_hant'):
17
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=self.src_lang)
18
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
19
+
20
+ def translate(self, text, direction):
21
+ self.set_lang_codes(direction)
22
+ inputs = self.tokenizer(text, return_tensors="pt")
23
+ translated_tokens = self.model.generate(**inputs, forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(self.tgt_lang), max_length=1024)
24
+ outputs = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
25
+ return outputs
26
+
 
 
27
  translator = Translator()