mgbam committed on
Commit
a73c69c
·
verified ·
1 Parent(s): de42686

Update translation_models.py

Browse files
Files changed (1) hide show
  1. translation_models.py +18 -2
translation_models.py CHANGED
@@ -1,8 +1,10 @@
 
 
1
  from functools import lru_cache
2
  import torch
3
  from transformers import pipeline
4
 
5
- # translation_models.py
6
  LANGUAGE_CODES = {
7
  "English": "en",
8
  "Spanish": "es",
@@ -15,4 +17,18 @@ LANGUAGE_CODES = {
15
  "Russian": "ru",
16
  "Portuguese": "pt"
17
  }
18
- # ... along with the translate() function ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # translation_models.py
2
+
3
  from functools import lru_cache
4
  import torch
5
  from transformers import pipeline
6
 
7
+ # Mapping for 10 languages:
8
  LANGUAGE_CODES = {
9
  "English": "en",
10
  "Spanish": "es",
 
17
  "Russian": "ru",
18
  "Portuguese": "pt"
19
  }
20
+
21
@lru_cache(maxsize=8)
def get_local_translator(src_lang: str, tgt_lang: str):
    """Return a translation pipeline for the given language-code pair.

    The model id is built as ``Helsinki-NLP/opus-mt-<src_lang>-<tgt_lang>``.
    Results are memoized with ``lru_cache`` (up to 8 distinct argument
    pairs), so repeated calls with the same codes reuse the same pipeline
    object instead of constructing it again.

    Args:
        src_lang: Source language code used in the model id (e.g. "en").
        tgt_lang: Target language code used in the model id (e.g. "es").

    Returns:
        The object returned by ``pipeline("translation", ...)``.
    """
    # Device index 0 when CUDA is available, otherwise -1.
    # NOTE(review): -1 is presumably the CPU sentinel for transformers
    # pipelines -- confirm against the installed transformers version.
    use_gpu = torch.cuda.is_available()
    return pipeline(
        "translation",
        model=f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}",
        device=0 if use_gpu else -1,
    )
26
+
27
def translate(text: str, tgt_lang_name: str, src_lang_name: str = "English") -> str:
    """Translate ``text`` from one named language to another.

    Language display names are resolved through ``LANGUAGE_CODES``; a name
    that is not present in the mapping falls back to the code ``"en"``.

    Args:
        text: The text to translate.
        tgt_lang_name: Display name of the target language (e.g. "Spanish").
        src_lang_name: Display name of the source language.

    Returns:
        The ``"translation_text"`` of the first pipeline result, or ``text``
        unchanged when it is empty/falsy or when source and target resolve
        to the same language.
    """
    if not text or tgt_lang_name == src_lang_name:
        return text

    src_code = LANGUAGE_CODES.get(src_lang_name, "en")
    tgt_code = LANGUAGE_CODES.get(tgt_lang_name, "en")

    # Two different names can still resolve to the same code (an unknown
    # name falls back to "en"). There is nothing to translate in that case,
    # and we must not request a same-language model id such as
    # "opus-mt-en-en".
    if src_code == tgt_code:
        return text

    translator = get_local_translator(src_code, tgt_code)
    # truncation=True asks the pipeline to cut over-long input instead of
    # failing; max_length=512 is the length cap passed to the pipeline call.
    result = translator(text, max_length=512, truncation=True)
    return result[0]["translation_text"]