kleervoyans committed on
Commit
7fc686c
·
verified ·
1 Parent(s): 5ccd1db

Delete models/translation_loader.py

Browse files
Files changed (1) hide show
  1. models/translation_loader.py +0 -114
models/translation_loader.py DELETED
@@ -1,114 +0,0 @@
1
- # models/translation_loader.py
2
-
3
- import logging
4
- from typing import Union, List
5
- from langdetect import detect, LangDetectException
6
- from transformers import pipeline, AutoTokenizer, BitsAndBytesConfig
7
-
8
class TranslationLoader:
    """Wrap a Hugging Face translation pipeline with language-code handling.

    On construction the loader:

    1. builds a ``transformers`` translation pipeline (8-bit quantized when
       ``quantize`` is true, with a full-precision fallback),
    2. loads the tokenizer and keeps its ``lang_code_to_id`` mapping, and
    3. picks a Turkish target code from that mapping when no ``tgt_lang``
       was supplied.

    Attributes:
        model_name: Hub id of the model that was requested.
        quantize: Whether 8-bit loading was requested.
        default_tgt: Target language code used when ``translate`` gets none.
        pipeline: The loaded translation pipeline (callable).
        tokenizer: The tokenizer loaded for ``model_name``.
        lang_code_to_id: The tokenizer's language-code mapping.
    """

    # Turkish is spelled differently per code scheme: "tr" (two-letter
    # schemes), "tr_TR" (locale-style schemes) and "tur_Latn" (NLLB-200).
    # Matching only "tr" never finds the NLLB code.
    _TURKISH_PREFIXES = ("tr", "tur")
    _ENGLISH_PREFIXES = ("en",)

    # langdetect reports ISO 639-1 codes, while NLLB-style mappings use
    # ISO 639-3 stems (optionally followed by "_Script").  A plain prefix
    # match is unsafe there ("es" is a prefix of "est", "ru" of "run",
    # "fi" of "fij"), so each detected code is mapped to the stems that
    # actually denote that language.  Stems are tried in order.
    _ISO1_TO_ISO3 = {
        "af": ("afr",), "ar": ("arb", "ara"), "bg": ("bul",),
        "bn": ("ben",), "ca": ("cat",), "cs": ("ces",), "cy": ("cym",),
        "da": ("dan",), "de": ("deu",), "el": ("ell",), "en": ("eng",),
        "es": ("spa",), "et": ("est",), "fa": ("pes", "fas"),
        "fi": ("fin",), "fr": ("fra",), "gu": ("guj",), "he": ("heb",),
        "hi": ("hin",), "hr": ("hrv",), "hu": ("hun",), "id": ("ind",),
        "it": ("ita",), "ja": ("jpn",), "kn": ("kan",), "ko": ("kor",),
        "lt": ("lit",), "lv": ("lvs", "lav"), "mk": ("mkd",),
        "ml": ("mal",), "mr": ("mar",), "ne": ("npi", "nep"),
        "nl": ("nld",), "no": ("nob", "nor"), "pa": ("pan",),
        "pl": ("pol",), "pt": ("por",), "ro": ("ron",), "ru": ("rus",),
        "sk": ("slk",), "sl": ("slv",), "so": ("som",),
        "sq": ("als", "sqi"), "sv": ("swe",), "sw": ("swh", "swa"),
        "ta": ("tam",), "te": ("tel",), "th": ("tha",), "tl": ("tgl",),
        "tr": ("tur",), "uk": ("ukr",), "ur": ("urd",), "vi": ("vie",),
        "zh": ("zho",),
    }

    def __init__(
        self,
        model_name: str = "facebook/nllb-200-distilled-600M",
        quantize: bool = True,
        tgt_lang: str = None,  # if None, the Turkish code is picked automatically
    ):
        """Load the pipeline and tokenizer and resolve the default target.

        Args:
            model_name: Hub id of a translation model whose tokenizer
                exposes ``lang_code_to_id``.
            quantize: Try to load the model in 8-bit first.
            tgt_lang: Default target language code; when ``None`` a Turkish
                code is looked up in the tokenizer's mapping.

        Raises:
            ValueError: If the tokenizer cannot be loaded, or no Turkish
                code exists in the mapping while ``tgt_lang`` is ``None``.
            AttributeError: If the tokenizer has no ``lang_code_to_id``.
        """
        self.model_name = model_name
        self.quantize = quantize
        self.default_tgt = tgt_lang  # may be None until resolved below

        # --- Load the translation pipeline ---------------------------------
        self.pipeline = self._load_pipeline()

        # --- Load tokenizer & grab the lang_code_to_id mapping -------------
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
        except Exception as e:
            logging.error("Cannot load tokenizer for %s: %s", self.model_name, e)
            raise ValueError(f"Failed to load tokenizer: {e}") from e
        logging.info("Tokenizer loaded for %s", self.model_name)

        # NOTE(review): some transformers releases may no longer expose
        # `lang_code_to_id` on every multilingual tokenizer -- confirm
        # against the version this project pins.
        if not hasattr(self.tokenizer, "lang_code_to_id"):
            # Tokenizers carry no `.config`; report the tokenizer class so
            # the explanatory message is what the caller actually sees.
            raise AttributeError(
                f"Model `{self.model_name}`'s tokenizer "
                f"({type(self.tokenizer).__name__}) has no `lang_code_to_id`. "
                "Use a model like NLLB-200 or M2M100 that supports language codes."
            )
        self.lang_code_to_id = self.tokenizer.lang_code_to_id
        logging.info("Using tokenizer.lang_code_to_id mapping")

        # --- Auto-pick the Turkish target code if none was provided --------
        if self.default_tgt is None:
            turkish = self._codes_with_prefix(self._TURKISH_PREFIXES)
            if not turkish:
                raise ValueError(f"No Turkish code found in mapping for {self.model_name}")
            self.default_tgt = turkish[0]
            logging.info("Default target set to `%s`", self.default_tgt)

    def _load_pipeline(self):
        """Build the translation pipeline, preferring 8-bit when requested."""
        if self.quantize:
            try:
                # Model-loading options must travel through `model_kwargs`;
                # other top-level kwargs go to the pipeline object instead
                # of to `from_pretrained`.
                pipe = pipeline(
                    "translation",
                    model=self.model_name,
                    tokenizer=self.model_name,
                    device_map="auto",
                    model_kwargs={
                        "quantization_config": BitsAndBytesConfig(load_in_8bit=True)
                    },
                )
                logging.info("Loaded `%s` with 8-bit=%s", self.model_name, self.quantize)
                return pipe
            except Exception as e:
                # Deliberate best-effort: any failure of the quantized load
                # (missing bitsandbytes, no GPU, ...) degrades to full precision.
                logging.warning("8-bit load failed (%s); falling back to full-precision", e)

        pipe = pipeline(
            "translation",
            model=self.model_name,
            tokenizer=self.model_name,
            device_map="auto",
        )
        logging.info("Loaded `%s` in full precision", self.model_name)
        return pipe

    def _codes_with_prefix(self, prefixes):
        """Return mapping codes whose lowercase form starts with any prefix."""
        return [c for c in self.lang_code_to_id if c.lower().startswith(prefixes)]

    def _resolve_source_code(self, iso):
        """Map a detected ISO code (e.g. ``"es"``, ``"zh-cn"``) to a model code.

        Returns the matching key of ``lang_code_to_id`` or ``None``.
        """
        iso = iso.lower()
        base = iso.split("-", 1)[0]  # "zh-cn" -> "zh"
        codes = list(self.lang_code_to_id)

        # 1) The model uses the detected code verbatim.
        for wanted in (iso, base):
            for code in codes:
                if code.lower() == wanted:
                    return code

        # 2) Compare the language part before any "_"/"-" suffix against the
        #    two-letter code and its known three-letter stems.
        for stem in (base,) + self._ISO1_TO_ISO3.get(base, ()):
            for code in codes:
                if code.lower().replace("-", "_").split("_", 1)[0] == stem:
                    return code

        # 3) Legacy prefix match, kept only for languages absent from the
        #    table: for known languages a bare prefix hit would be a
        #    different language.
        if base not in self._ISO1_TO_ISO3:
            for code in codes:
                if code.lower().startswith(iso):
                    return code
        return None

    def _fallback_source(self):
        """Return the English code, or the first mapping code if none exists."""
        english = self._codes_with_prefix(self._ENGLISH_PREFIXES)
        src = english[0] if english else list(self.lang_code_to_id)[0]
        logging.info("Fallback src_lang=%s", src)
        return src

    def _detect_source(self, sample):
        """Detect the language of ``sample`` and return a model source code."""
        try:
            iso = detect(sample)
        except Exception as e:
            # Best-effort: any detection failure falls back to English.
            logging.warning("Language auto-detect failed (%s); defaulting to English", e)
            return self._fallback_source()

        src = self._resolve_source_code(iso)
        if src is None:
            logging.warning(
                "Language auto-detect failed (No mapping for ISO '%s'); defaulting to English",
                iso.lower(),
            )
            return self._fallback_source()

        logging.info("Detected src_lang=%s from ISO='%s'", src, iso.lower())
        return src

    def translate(
        self,
        text: Union[str, List[str]],
        src_lang: str = None,
        tgt_lang: str = None,
    ):
        """Translate ``text`` and return the raw pipeline output.

        - Auto-detects ``src_lang`` via langdetect if it is not given
          (for a list, the first element is used as the sample).
        - Uses ``default_tgt`` if ``tgt_lang`` is not passed.
        - Returns the pipeline output (list of dicts with
          ``'translation_text'``); an empty input list yields ``[]``
          without invoking the model.
        """
        if isinstance(text, list) and not text:
            # Nothing to translate, and no first element to sample.
            return []

        tgt = tgt_lang or self.default_tgt

        if src_lang:
            src = src_lang
        else:
            sample = text[0] if isinstance(text, list) else text
            src = self._detect_source(sample)

        return self.pipeline(text, src_lang=src, tgt_lang=tgt)

    def get_info(self):
        """Return model metadata for display in a sidebar.

        The dict has the keys ``model_name``, ``quantized``, ``device``
        and ``default_target``.  Missing pipeline/model attributes fall
        back to ``False`` / ``"auto"``.
        """
        mdl = getattr(self.pipeline, "model", None)
        return {
            "model_name": self.model_name,
            "quantized": getattr(mdl, "is_loaded_in_8bit", False),
            "device": str(getattr(mdl, "device", "auto")),
            "default_target": self.default_tgt,
        }