kleervoyans committed on
Commit
5ccd1db
·
verified ·
1 Parent(s): 08c5787

Create model_manager.py

Browse files
Files changed (1) hide show
  1. models/model_manager.py +85 -0
models/model_manager.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Union, List
3
+ from langdetect import detect, LangDetectException
4
+
5
+ from models.model_loader import ModelLoader
6
+ from models.model_selector import ModelSelector
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
class ModelManager:
    """
    Orchestrates model selection, loading, and auto-language detection.

    Exposes:
      - translate(text, src_lang=None, tgt_lang=None)
      - get_info()

    Language codes are taken from the loaded tokenizer and stored in
    ``self.lang_codes``. ISO 639-1 codes (as returned by ``langdetect``) are
    resolved against that list by ``_match_lang_code``.
    """

    # langdetect output (ISO 639-1, plus "zh-cn"/"zh-tw") -> FLORES-200 style
    # code ("<iso639-3>_<Script>"). Needed because a plain prefix match is
    # wrong for these codes: "tr" never matches "tur_Latn", and "es"/"ru"/"ja"
    # would pick "est_Latn"/"run_Latn"/"jav_Latn".
    _ISO1_TO_FLORES = {
        "af": "afr_Latn", "ar": "arb_Arab", "bg": "bul_Cyrl",
        "bn": "ben_Beng", "ca": "cat_Latn", "cs": "ces_Latn",
        "cy": "cym_Latn", "da": "dan_Latn", "de": "deu_Latn",
        "el": "ell_Grek", "en": "eng_Latn", "es": "spa_Latn",
        "et": "est_Latn", "fa": "pes_Arab", "fi": "fin_Latn",
        "fr": "fra_Latn", "gu": "guj_Gujr", "he": "heb_Hebr",
        "hi": "hin_Deva", "hr": "hrv_Latn", "hu": "hun_Latn",
        "id": "ind_Latn", "it": "ita_Latn", "ja": "jpn_Jpan",
        "kn": "kan_Knda", "ko": "kor_Hang", "lt": "lit_Latn",
        "lv": "lvs_Latn", "mk": "mkd_Cyrl", "ml": "mal_Mlym",
        "mr": "mar_Deva", "ne": "npi_Deva", "nl": "nld_Latn",
        "no": "nob_Latn", "pa": "pan_Guru", "pl": "pol_Latn",
        "pt": "por_Latn", "ro": "ron_Latn", "ru": "rus_Cyrl",
        "sk": "slk_Latn", "sl": "slv_Latn", "so": "som_Latn",
        "sq": "als_Latn", "sv": "swe_Latn", "sw": "swh_Latn",
        "ta": "tam_Taml", "te": "tel_Telu", "th": "tha_Thai",
        "tl": "tgl_Latn", "tr": "tur_Latn", "uk": "ukr_Cyrl",
        "ur": "urd_Arab", "vi": "vie_Latn",
        "zh": "zho_Hans", "zh-cn": "zho_Hans", "zh-tw": "zho_Hant",
    }

    def __init__(
        self,
        candidates: List[str] = None,
        quantize: bool = True,
        default_tgt: str = None,
    ):
        """
        Build the selector and loader, then load a model immediately.

        Args:
            candidates: Passed through to ``ModelSelector``.
            quantize: Passed through to ``ModelSelector`` and ``ModelLoader``.
            default_tgt: Target language code used when ``translate`` is
                called without ``tgt_lang`` (e.g. "tur_Latn"). When falsy, a
                Turkish code is looked up in the tokenizer's language codes.

        Raises:
            ValueError: If ``default_tgt`` is not given and the loaded
                tokenizer exposes no Turkish code.
        """
        self.selector = ModelSelector(candidates, quantize)
        self.loader = ModelLoader(quantize)
        self.tokenizer = None
        self.pipeline = None
        self.lang_codes = []
        self.default_tgt = default_tgt  # e.g. "tur_Latn"
        self._load_best_model()

    @classmethod
    def _match_lang_code(cls, lang_codes, iso):
        """
        Return the entry of ``lang_codes`` that denotes ISO code ``iso``.

        Matching is case-insensitive and tried in this order:
          1. exact code ("tr", or "zh_CN" for "zh-cn"),
          2. the FLORES-200 code from ``_ISO1_TO_FLORES`` ("tur_Latn"),
          3. codes of the form "<iso>_..." ("tr_TR"),
          4. codes sharing the FLORES three-letter prefix ("tur_...").

        Returns ``None`` when nothing matches.
        """
        iso = iso.lower().replace("_", "-")
        base = iso.split("-", 1)[0]
        flores = cls._ISO1_TO_FLORES.get(iso) or cls._ISO1_TO_FLORES.get(base)

        # Lower-cased lookup that keeps the first spelling seen for each code.
        by_lower = {}
        for code in lang_codes:
            by_lower.setdefault(code.lower(), code)

        exact_keys = [iso, iso.replace("-", "_"), base]
        if flores:
            exact_keys.append(flores.lower())
        for key in exact_keys:
            if key in by_lower:
                return by_lower[key]

        # The trailing "_" keeps "es" from matching "est_Latn" and so on.
        prefixes = [base + "_"]
        if flores:
            prefixes.append(flores.split("_", 1)[0].lower() + "_")
        for prefix in prefixes:
            for code in lang_codes:
                if code.lower().startswith(prefix):
                    return code
        return None

    def _load_best_model(self):
        """
        Select and load a model, then populate tokenizer, pipeline and codes.

        Raises:
            ValueError: If no default target is set and the tokenizer's
                language codes contain no Turkish entry.
        """
        model_name = self.selector.select()
        tok, pipe = self.loader.load(model_name)
        self.tokenizer = tok
        self.pipeline = pipe

        code_map = getattr(tok, "lang_code_to_id", None)
        if code_map is not None:
            self.lang_codes = list(code_map)
        else:
            # NOTE(review): `lang_code_to_id` is deprecated on NLLB tokenizers;
            # this assumes the language codes are then exposed through
            # `additional_special_tokens` -- confirm for the models in use.
            self.lang_codes = list(
                getattr(tok, "additional_special_tokens", None) or []
            )

        # Pick a Turkish code if not explicitly set
        if not self.default_tgt:
            turkish = self._match_lang_code(self.lang_codes, "tr")
            if turkish is None:
                raise ValueError(f"No Turkish code found in {model_name}")
            self.default_tgt = turkish
            logger.info("Default target language: %s", self.default_tgt)

    def _detect_src_lang(self, text):
        """
        Guess the source language code for ``text``.

        Only the first item is sampled when ``text`` is a list or tuple. If
        detection fails or the detected language has no matching tokenizer
        code, falls back to English, then to the first known code.

        Raises:
            ValueError: If no language codes are available to fall back on.
        """
        if isinstance(text, (list, tuple)):
            # An empty batch has nothing to sample; "" makes detection fail
            # cleanly and take the fallback path instead of raising IndexError.
            sample = text[0] if text else ""
        else:
            sample = text

        try:
            iso = detect(sample).lower()
        except Exception as exc:  # best-effort: any detector failure -> fallback
            reason = exc
        else:
            src = self._match_lang_code(self.lang_codes, iso)
            if src is not None:
                logger.info("Detected src_lang=%s", src)
                return src
            reason = f"No mapping for ISO '{iso}'"

        logger.warning("Auto-detect failed (%s); defaulting to English", reason)
        english = self._match_lang_code(self.lang_codes, "en")
        if english is not None:
            return english
        if not self.lang_codes:
            raise ValueError("No language codes available for source fallback")
        return self.lang_codes[0]

    def translate(
        self,
        text: Union[str, List[str]],
        src_lang: str = None,
        tgt_lang: str = None,
    ):
        """
        Translate ``text`` with the loaded pipeline.

        Args:
            text: A string or a list of strings.
            src_lang: Source language code; auto-detected when falsy.
            tgt_lang: Target language code; ``self.default_tgt`` when falsy.

        Returns:
            Whatever ``self.pipeline(text, src_lang=..., tgt_lang=...)``
            returns.
        """
        tgt = tgt_lang or self.default_tgt
        src = src_lang or self._detect_src_lang(text)
        return self.pipeline(text, src_lang=src, tgt_lang=tgt)

    def get_info(self):
        """
        Returns a dict for your sidebar:
          { model, quantized, device, default_tgt }

        Values are read from ``self.pipeline.model`` with ``getattr``; missing
        attributes yield ``None``, ``False`` and ``"auto"`` respectively.
        """
        mdl = getattr(self.pipeline, "model", None)
        return {
            "model": getattr(mdl, "name_or_path", None),
            "quantized": getattr(mdl, "is_loaded_in_8bit", False),
            "device": str(getattr(mdl, "device", "auto")),
            "default_tgt": self.default_tgt,
        }