Update translation_models.py
translation_models.py  CHANGED  (+18 -2)
@@ -1,8 +1,10 @@
+# translation_models.py
+
 from functools import lru_cache
 import torch
 from transformers import pipeline
 
-#
+# Mapping for 10 languages:
 LANGUAGE_CODES = {
     "English": "en",
     "Spanish": "es",
@@ -15,4 +17,18 @@ LANGUAGE_CODES = {
     "Russian": "ru",
     "Portuguese": "pt"
 }
-
+
+@lru_cache(maxsize=8)
+def get_local_translator(src_lang: str, tgt_lang: str):
+    model_id = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
+    device = 0 if torch.cuda.is_available() else -1
+    return pipeline("translation", model=model_id, device=device)
+
+def translate(text: str, tgt_lang_name: str, src_lang_name: str = "English") -> str:
+    if not text or tgt_lang_name == src_lang_name:
+        return text
+    src_code = LANGUAGE_CODES.get(src_lang_name, "en")
+    tgt_code = LANGUAGE_CODES.get(tgt_lang_name, "en")
+    translator = get_local_translator(src_code, tgt_code)
+    result = translator(text, max_length=512, truncation=True)
+    return result[0]["translation_text"]
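For reference, a minimal usage sketch of the new helpers (not part of the commit; it assumes the module is importable as translation_models and that the relevant Helsinki-NLP/opus-mt checkpoint can be downloaded from the Hub):

# usage_sketch.py (hypothetical)
from translation_models import translate

# First call loads the en-es pipeline via get_local_translator and caches it
# through lru_cache; later English-to-Spanish calls reuse the same pipeline.
print(translate("Hello, world!", tgt_lang_name="Spanish"))

# Same source and target language: the text is returned unchanged, no model is loaded.
print(translate("Hello, world!", tgt_lang_name="English"))

The @lru_cache(maxsize=8) on get_local_translator keeps up to eight loaded translation pipelines in memory, so switching between a handful of language pairs does not reload a model on every request.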