Spaces:
Sleeping
Sleeping
Delete models/translation_loader.py
Browse files- models/translation_loader.py +0 -114
models/translation_loader.py
DELETED
@@ -1,114 +0,0 @@
|
|
1 |
-
# models/translation_loader.py
|
2 |
-
|
3 |
-
import logging
|
4 |
-
from typing import Union, List
|
5 |
-
from langdetect import detect, LangDetectException
|
6 |
-
from transformers import pipeline, AutoTokenizer, BitsAndBytesConfig
|
7 |
-
|
8 |
-
class TranslationLoader:
|
9 |
-
def __init__(
|
10 |
-
self,
|
11 |
-
model_name: str = "facebook/nllb-200-distilled-600M",
|
12 |
-
quantize: bool = True,
|
13 |
-
tgt_lang: str = None, # if None, weβll pick the Turkish code automatically
|
14 |
-
):
|
15 |
-
self.model_name = model_name
|
16 |
-
self.quantize = quantize
|
17 |
-
self.default_tgt = tgt_lang # may be None
|
18 |
-
|
19 |
-
# βββ Load the translation pipeline βββββββββββββββββββββββββββββββ
|
20 |
-
try:
|
21 |
-
bnb_cfg = BitsAndBytesConfig(load_in_8bit=self.quantize)
|
22 |
-
self.pipeline = pipeline(
|
23 |
-
"translation",
|
24 |
-
model=self.model_name,
|
25 |
-
tokenizer=self.model_name,
|
26 |
-
device_map="auto",
|
27 |
-
quantization_config=bnb_cfg,
|
28 |
-
)
|
29 |
-
logging.info(f"Loaded `{self.model_name}` with 8-bit={self.quantize}")
|
30 |
-
except Exception as e:
|
31 |
-
logging.warning(f"8-bit load failed ({e}); falling back to full-precision")
|
32 |
-
self.pipeline = pipeline(
|
33 |
-
"translation",
|
34 |
-
model=self.model_name,
|
35 |
-
tokenizer=self.model_name,
|
36 |
-
device_map="auto",
|
37 |
-
)
|
38 |
-
logging.info(f"Loaded `{self.model_name}` in full precision")
|
39 |
-
|
40 |
-
# βββ Load tokenizer & grab the lang_code_to_id mapping ββββββββββββ
|
41 |
-
try:
|
42 |
-
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
|
43 |
-
logging.info(f"Tokenizer loaded for {self.model_name}")
|
44 |
-
except Exception as e:
|
45 |
-
logging.error(f"Cannot load tokenizer for {self.model_name}: {e}")
|
46 |
-
raise ValueError(f"Failed to load tokenizer: {e}")
|
47 |
-
|
48 |
-
if hasattr(self.tokenizer, "lang_code_to_id"):
|
49 |
-
self.lang_code_to_id = self.tokenizer.lang_code_to_id
|
50 |
-
logging.info("Using tokenizer.lang_code_to_id mapping")
|
51 |
-
else:
|
52 |
-
allowed = ", ".join(list(self.tokenizer.config.to_dict().keys())[:5])
|
53 |
-
raise AttributeError(
|
54 |
-
f"Model `{self.model_name}`βs tokenizer has no `lang_code_to_id`. "
|
55 |
-
"Use a model like NLLB-200 or M2M100 that supports language codes. "
|
56 |
-
f"(available config keys: {allowed}β¦)"
|
57 |
-
)
|
58 |
-
|
59 |
-
# βββ Auto-pick the Turkish target code if none was provided βββββββ
|
60 |
-
if self.default_tgt is None:
|
61 |
-
tur = [c for c in self.lang_code_to_id if c.lower().startswith("tr")]
|
62 |
-
if not tur:
|
63 |
-
raise ValueError(f"No Turkish code found in mapping for {self.model_name}")
|
64 |
-
self.default_tgt = tur[0]
|
65 |
-
logging.info(f"Default target set to `{self.default_tgt}`")
|
66 |
-
|
67 |
-
def translate(
|
68 |
-
self,
|
69 |
-
text: Union[str, List[str]],
|
70 |
-
src_lang: str = None,
|
71 |
-
tgt_lang: str = None,
|
72 |
-
):
|
73 |
-
"""
|
74 |
-
- Auto-detects src_lang via langdetect if not given
|
75 |
-
- Uses default_tgt if tgt_lang is not passed
|
76 |
-
- Returns pipeline output (list of dicts with 'translation_text')
|
77 |
-
"""
|
78 |
-
tgt = tgt_lang or self.default_tgt
|
79 |
-
|
80 |
-
# βββ Source-language auto-detection βββββββββββββββββββββββββββββ
|
81 |
-
if src_lang:
|
82 |
-
src = src_lang
|
83 |
-
else:
|
84 |
-
sample = text[0] if isinstance(text, list) else text
|
85 |
-
try:
|
86 |
-
iso = detect(sample).lower()
|
87 |
-
# find codes starting with that ISO (e.g. "en"β["en","eng_Latn",β¦])
|
88 |
-
cand = [c for c in self.lang_code_to_id if c.lower().startswith(iso)]
|
89 |
-
if not cand:
|
90 |
-
raise LangDetectException(f"No mapping for ISO '{iso}'")
|
91 |
-
# prefer exact match, else first
|
92 |
-
exact = [c for c in cand if c.lower() == iso]
|
93 |
-
src = exact[0] if exact else cand[0]
|
94 |
-
logging.info(f"Detected src_lang={src} from ISO='{iso}'")
|
95 |
-
except Exception as e:
|
96 |
-
logging.warning(f"Language auto-detect failed ({e}); defaulting to English")
|
97 |
-
eng = [c for c in self.lang_code_to_id if c.lower().startswith("en")]
|
98 |
-
src = eng[0] if eng else list(self.lang_code_to_id)[0]
|
99 |
-
logging.info(f"Fallback src_lang={src}")
|
100 |
-
|
101 |
-
# βββ Perform translation call ββββββββββββββββββββββββββββββββββββ
|
102 |
-
return self.pipeline(text, src_lang=src, tgt_lang=tgt)
|
103 |
-
|
104 |
-
def get_info(self):
|
105 |
-
"""Return model metadata for display in your sidebar."""
|
106 |
-
mdl = getattr(self.pipeline, "model", None)
|
107 |
-
q = getattr(mdl, "is_loaded_in_8bit", False)
|
108 |
-
device = getattr(mdl, "device", "auto")
|
109 |
-
return {
|
110 |
-
"model_name": self.model_name,
|
111 |
-
"quantized": q,
|
112 |
-
"device": str(device),
|
113 |
-
"default_target": self.default_tgt,
|
114 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|