kleervoyans committed on
Commit
24c7801
·
verified ·
1 Parent(s): b27cfa2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -157
app.py CHANGED
@@ -5,15 +5,16 @@ import logging
5
  import torch
6
  import pandas as pd
7
  import plotly.express as px
 
 
 
8
  from transformers import (
9
  AutoTokenizer,
10
  AutoModelForSeq2SeqLM,
11
  pipeline,
12
  BitsAndBytesConfig,
13
  )
14
- from langdetect import detect, LangDetectException
15
  import evaluate
16
- from typing import Union, List
17
 
18
  # ────────── Logging ──────────
19
  logging.basicConfig(
@@ -27,9 +28,9 @@ logger = logging.getLogger(__name__)
27
  # ────────── Model Manager ──────────
28
  class ModelManager:
29
  """
30
- Selects and loads a translation model (NLLB-200 or M2M100),
31
- using 8-bit quantization only if CUDA is available.
32
- Auto-detects source language and defaults target to Turkish.
33
  """
34
  def __init__(
35
  self,
@@ -37,81 +38,75 @@ class ModelManager:
37
  quantize: bool = True,
38
  default_tgt: str = None,
39
  ):
40
- # Disable 8-bit if CUDA isn't available
41
  if quantize and not torch.cuda.is_available():
42
  logger.warning("CUDA unavailable; disabling 8-bit quantization")
43
  quantize = False
44
  self.quantize = quantize
45
 
46
- self.candidates = candidates or [
47
  "facebook/nllb-200-distilled-600M",
48
  "facebook/m2m100_418M",
49
  ]
50
- self.default_tgt = default_tgt # will auto-pick if None
51
-
52
- self.model_name: str = None
53
- self.tokenizer = None
54
- self.model = None
55
- self.pipeline = None
56
- self.lang_codes: List[str] = []
57
 
58
  self._select_and_load()
59
 
60
  def _select_and_load(self):
61
  last_err = None
62
- for model_name in self.candidates:
63
  try:
64
- # Load tokenizer
65
- logger.info(f"Loading tokenizer for {model_name}")
66
- tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
67
  if not hasattr(tok, "lang_code_to_id"):
68
- raise AttributeError(f"Tokenizer for {model_name} missing lang_code_to_id")
69
 
70
- # Load model (with or without 8-bit)
71
- logger.info(f"Loading model {model_name} (8-bit={self.quantize})")
72
  if self.quantize:
73
  bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
74
  mdl = AutoModelForSeq2SeqLM.from_pretrained(
75
- model_name,
76
  device_map="auto",
77
  quantization_config=bnb_cfg,
78
  )
79
  else:
80
  mdl = AutoModelForSeq2SeqLM.from_pretrained(
81
- model_name,
82
  device_map="auto",
83
  )
84
- logger.info(f"Model {model_name} loaded successfully")
85
-
86
- # Wrap in a translation pipeline
87
- pipe = pipeline(
88
- "translation",
89
- model=mdl,
90
- tokenizer=tok,
91
- )
92
-
93
- # Store and break
94
- self.model_name = model_name
95
- self.tokenizer = tok
96
- self.model = mdl
97
- self.pipeline = pipe
98
  self.lang_codes = list(tok.lang_code_to_id.keys())
99
 
100
- # Auto-pick Turkish target code if none specified
101
  if not self.default_tgt:
102
- tur_codes = [c for c in self.lang_codes if c.lower().startswith("tr")]
103
- if not tur_codes:
104
- raise ValueError(f"No Turkish code found in {model_name}")
105
- self.default_tgt = tur_codes[0]
106
- logger.info(f"Default target language: {self.default_tgt}")
107
-
108
  return
109
 
110
  except Exception as e:
111
- logger.warning(f"Failed to load {model_name}: {e}")
112
  last_err = e
113
 
114
- raise RuntimeError(f"Could not load any model from {self.candidates}: {last_err}")
115
 
116
  def translate(
117
  self,
@@ -119,43 +114,39 @@ class ModelManager:
119
  src_lang: str = None,
120
  tgt_lang: str = None,
121
  ):
122
- """
123
- Translate `text` from src_lang → tgt_lang.
124
- Auto-detects src_lang if not given.
125
- """
126
  tgt = tgt_lang or self.default_tgt
127
 
128
- # Auto-detect source language if missing
129
  if not src_lang:
130
  sample = text[0] if isinstance(text, list) else text
131
  try:
132
  iso = detect(sample).lower()
133
- candidates = [c for c in self.lang_codes if c.lower().startswith(iso)]
134
- if not candidates:
135
  raise LangDetectException(f"No code for ISO '{iso}'")
136
- exact = [c for c in candidates if c.lower() == iso]
137
- src = exact[0] if exact else candidates[0]
138
- logger.info(f"Auto-detected src_lang={src}")
139
- except Exception as e:
140
- logger.warning(f"langdetect failed ({e}); defaulting to English")
141
- eng_codes = [c for c in self.lang_codes if c.lower().startswith("en")]
142
- src = eng_codes[0] if eng_codes else self.lang_codes[0]
 
143
  else:
144
  src = src_lang
145
 
146
  return self.pipeline(text, src_lang=src, tgt_lang=tgt)
147
 
148
  def get_info(self):
149
- """Return metadata for the sidebar display."""
150
- quantized = getattr(self.model, "is_loaded_in_8bit", False)
151
- device = "cpu"
152
  if torch.cuda.is_available() and hasattr(self.model, "device"):
153
- dev = self.model.device
154
- device = str(dev) if isinstance(dev, torch.device) else f"cuda:{getattr(dev, 'index', '')}"
155
  return {
156
  "model": self.model_name,
157
- "quantized": quantized,
158
- "device": device,
159
  "default_tgt": self.default_tgt,
160
  }
161
 
@@ -163,17 +154,10 @@ class ModelManager:
163
  # ────────── Evaluator ──────────
164
  class TranslationEvaluator:
165
  def __init__(self):
166
- self.bleu = evaluate.load("bleu")
167
- try:
168
- self.bertscore = evaluate.load("bertscore")
169
- except Exception as e:
170
- logger.error("BERTScore load error: %s", e)
171
- self.bertscore = None
172
- try:
173
- self.comet = evaluate.load("comet", model_id="unbabel/comet-mqm-qe-da")
174
- except Exception as e:
175
- logger.error("COMET load error: %s", e)
176
- self.comet = None
177
 
178
  def evaluate(
179
  self,
@@ -184,69 +168,36 @@ class TranslationEvaluator:
184
  results = {}
185
 
186
  # BLEU
187
- try:
188
- bleu_res = self.bleu.compute(
189
- predictions=predictions,
190
- references=[[r] for r in references],
191
- )
192
- results["BLEU"] = float(bleu_res.get("bleu", 0.0))
193
- except Exception as e:
194
- logger.error("BLEU compute error: %s", e)
195
- results["BLEU"] = 0.0
196
-
197
- # BERTScore (general)
198
- if self.bertscore:
199
- try:
200
- bs = self.bertscore.compute(
201
- predictions=predictions, references=references, lang="xx"
202
- )
203
- f1 = bs.get("f1", [])
204
- results["BERTScore"] = float(sum(f1)) / max(len(f1), 1)
205
- except Exception as e:
206
- logger.error("BERTScore compute error: %s", e)
207
- results["BERTScore"] = 0.0
208
- else:
209
- results["BERTScore"] = 0.0
210
 
211
- # BERTurk (Turkish)
212
- if self.bertscore:
213
- try:
214
- bs_tr = self.bertscore.compute(
215
- predictions=predictions, references=references, lang="tr"
216
- )
217
- f1_tr = bs_tr.get("f1", [])
218
- results["BERTurk"] = float(sum(f1_tr)) / max(len(f1_tr), 1)
219
- except Exception as e:
220
- logger.error("BERTurk compute error: %s", e)
221
- results["BERTurk"] = 0.0
222
- else:
223
- results["BERTurk"] = 0.0
224
 
225
  # COMET
226
- if self.comet:
227
- try:
228
- cm = self.comet.compute(
229
- srcs=sources, hyps=predictions, refs=references
230
- )
231
- sc = cm.get("scores", None)
232
- if isinstance(sc, list):
233
- results["COMET"] = float(sc[0]) if sc else 0.0
234
- else:
235
- results["COMET"] = float(sc or 0.0)
236
- except Exception as e:
237
- logger.error("COMET compute error: %s", e)
238
- results["COMET"] = 0.0
239
  else:
240
- results["COMET"] = 0.0
241
 
242
  return results
243
 
244
 
245
  # ────────── Streamlit App ──────────
 
246
  @st.cache_resource
247
  def load_resources():
248
  mgr = ModelManager(quantize=True)
249
- ev = TranslationEvaluator()
250
  return mgr, ev
251
 
252
 
@@ -265,30 +216,38 @@ def process_text(
265
  ev: TranslationEvaluator,
266
  metrics: List[str],
267
  ):
268
- out = mgr.translate(src)
269
- hyp = out[0]["translation_text"] if isinstance(out, list) else out["translation_text"]
270
- scores = ev.evaluate([src], [ref or ""], [hyp])
271
- return {
272
- "source": src,
273
- "reference": ref,
 
 
274
  "hypothesis": hyp,
275
- **{m: scores.get(m, 0.0) for m in metrics},
276
  }
 
 
 
 
 
 
 
 
 
277
 
278
 
279
  def show_single_results(res: dict, metrics: List[str]):
280
  left, right = st.columns(2)
281
  with left:
282
- st.markdown("**Source:**")
283
- st.write(res["source"])
284
- st.markdown("**Hypothesis (TR):**")
285
- st.write(res["hypothesis"])
286
  if res["reference"]:
287
- st.markdown("**Reference (TR):**")
288
- st.write(res["reference"])
289
  with right:
290
  st.markdown("### Scores")
291
  df = pd.DataFrame([{m: res[m] for m in metrics}])
 
292
  st.table(df)
293
 
294
 
@@ -305,24 +264,35 @@ def process_file(
305
  prog = st.progress(0)
306
  results = []
307
  total = len(df)
 
308
  for i in range(0, total, batch_size):
309
  batch = df.iloc[i : i + batch_size]
310
- srcs = batch["src"].tolist()
311
- refs = batch["ref_tr"].tolist()
312
  outs = mgr.translate(srcs)
313
  hyps = [o["translation_text"] for o in outs]
 
314
  for s, r, h in zip(srcs, refs, hyps):
315
- sc = ev.evaluate([s], [r], [h])
316
  entry = {"src": s, "ref_tr": r, "hyp_tr": h}
317
- entry.update({m: sc.get(m, 0.0) for m in metrics})
 
 
 
 
 
 
318
  results.append(entry)
 
319
  prog.progress(min(i + batch_size, total) / total)
 
320
  return pd.DataFrame(results)
321
 
322
 
323
  def show_batch_viz(df: pd.DataFrame, metrics: List[str]):
324
  for m in metrics:
325
  st.markdown(f"#### {m} Distribution")
 
 
 
326
  fig = px.histogram(df, x=m)
327
  st.plotly_chart(fig, use_container_width=True)
328
 
@@ -330,20 +300,18 @@ def show_batch_viz(df: pd.DataFrame, metrics: List[str]):
330
  def main():
331
  st.set_page_config(page_title="🔀 Translation→Turkish Quality", layout="wide")
332
  st.title("🔀 Translation → TR Quality & COMET")
333
- st.markdown(
334
- "Translate any language into Turkish and evaluate with BLEU, BERTScore, BERTurk & COMET."
335
- )
336
 
337
  # Sidebar
338
  with st.sidebar:
339
  st.header("Settings")
340
- metrics = st.multiselect(
341
- "Select metrics",
342
  ["BLEU", "BERTScore", "BERTurk", "COMET"],
343
- default=["BLEU", "BERTScore", "COMET"],
344
  )
345
  batch_size = st.slider("Batch size", 1, 32, 8)
346
- mgr, ev = load_resources()
347
  display_model_info(mgr.get_info())
348
 
349
  # Tabs
@@ -358,16 +326,14 @@ def main():
358
  show_single_results(res, metrics)
359
 
360
  with tab2:
361
- uploaded = st.file_uploader(
362
- "Upload CSV with `src` & `ref_tr` columns", type=["csv"]
363
- )
364
  if uploaded:
365
  with st.spinner("Processing file…"):
366
  df_res = process_file(uploaded, mgr, ev, metrics, batch_size)
367
  st.markdown("### Batch Results")
368
  st.dataframe(df_res, use_container_width=True)
369
  show_batch_viz(df_res, metrics)
370
- st.download_button("Download CSV", df_res.to_csv(index=False), "results.csv")
371
 
372
 
373
  if __name__ == "__main__":
 
5
  import torch
6
  import pandas as pd
7
  import plotly.express as px
8
+ from typing import Union, List
9
+
10
+ from langdetect import detect, LangDetectException
11
  from transformers import (
12
  AutoTokenizer,
13
  AutoModelForSeq2SeqLM,
14
  pipeline,
15
  BitsAndBytesConfig,
16
  )
 
17
  import evaluate
 
18
 
19
  # ────────── Logging ──────────
20
  logging.basicConfig(
 
28
  # ────────── Model Manager ──────────
29
  class ModelManager:
30
  """
31
+ Select & load the best translation model from a candidate list,
32
+ using 8-bit quant if CUDA is available, else full-precision.
33
+ Auto-picks Turkish target code.
34
  """
35
  def __init__(
36
  self,
 
38
  quantize: bool = True,
39
  default_tgt: str = None,
40
  ):
41
+ # disable 8-bit if no GPU
42
  if quantize and not torch.cuda.is_available():
43
  logger.warning("CUDA unavailable; disabling 8-bit quantization")
44
  quantize = False
45
  self.quantize = quantize
46
 
47
+ self.candidates = candidates or [
48
  "facebook/nllb-200-distilled-600M",
49
  "facebook/m2m100_418M",
50
  ]
51
+ self.default_tgt = default_tgt # will auto-pick if None
52
+ self.model_name = None
53
+ self.tokenizer = None
54
+ self.model = None
55
+ self.pipeline = None
56
+ self.lang_codes = []
 
57
 
58
  self._select_and_load()
59
 
60
  def _select_and_load(self):
61
  last_err = None
62
+ for name in self.candidates:
63
  try:
64
+ # 1) tokenizer
65
+ logger.info(f"Loading tokenizer for {name}")
66
+ tok = AutoTokenizer.from_pretrained(name, use_fast=True)
67
  if not hasattr(tok, "lang_code_to_id"):
68
+ raise AttributeError("no lang_code_to_id on tokenizer")
69
 
70
+ # 2) model
71
+ logger.info(f"Loading model {name} (8-bit={self.quantize})")
72
  if self.quantize:
73
  bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
74
  mdl = AutoModelForSeq2SeqLM.from_pretrained(
75
+ name,
76
  device_map="auto",
77
  quantization_config=bnb_cfg,
78
  )
79
  else:
80
  mdl = AutoModelForSeq2SeqLM.from_pretrained(
81
+ name,
82
  device_map="auto",
83
  )
84
+ logger.info(f"Loaded {name}")
85
+
86
+ # 3) pipeline
87
+ pipe = pipeline("translation", model=mdl, tokenizer=tok)
88
+
89
+ # store
90
+ self.model_name = name
91
+ self.tokenizer = tok
92
+ self.model = mdl
93
+ self.pipeline = pipe
 
 
 
 
94
  self.lang_codes = list(tok.lang_code_to_id.keys())
95
 
96
+ # pick Turkish code if needed
97
  if not self.default_tgt:
98
+ tur = [c for c in self.lang_codes if c.lower().startswith("tr")]
99
+ if not tur:
100
+ raise ValueError("No Turkish code available")
101
+ self.default_tgt = tur[0]
102
+ logger.info(f"default_tgt = {self.default_tgt}")
 
103
  return
104
 
105
  except Exception as e:
106
+ logger.warning(f"failed to load {name}: {e}")
107
  last_err = e
108
 
109
+ raise RuntimeError(f"no model loaded: {last_err}")
110
 
111
  def translate(
112
  self,
 
114
  src_lang: str = None,
115
  tgt_lang: str = None,
116
  ):
 
 
 
 
117
  tgt = tgt_lang or self.default_tgt
118
 
119
+ # auto-detect source
120
  if not src_lang:
121
  sample = text[0] if isinstance(text, list) else text
122
  try:
123
  iso = detect(sample).lower()
124
+ cand = [c for c in self.lang_codes if c.lower().startswith(iso)]
125
+ if not cand:
126
  raise LangDetectException(f"No code for ISO '{iso}'")
127
+ # exact or first
128
+ exact = [c for c in cand if c.lower() == iso]
129
+ src = exact[0] if exact else cand[0]
130
+ logger.info(f"src_lang = {src}")
131
+ except Exception:
132
+ eng = [c for c in self.lang_codes if c.lower().startswith("en")]
133
+ src = eng[0] if eng else self.lang_codes[0]
134
+ logger.warning(f"defaulting src_lang = {src}")
135
  else:
136
  src = src_lang
137
 
138
  return self.pipeline(text, src_lang=src, tgt_lang=tgt)
139
 
140
  def get_info(self):
141
+ # figure out device for display
142
+ dev = "cpu"
 
143
  if torch.cuda.is_available() and hasattr(self.model, "device"):
144
+ d = self.model.device
145
+ dev = str(d) if isinstance(d, torch.device) else f"cuda:{getattr(d,'index','')}"
146
  return {
147
  "model": self.model_name,
148
+ "quantized": self.quantize,
149
+ "device": dev,
150
  "default_tgt": self.default_tgt,
151
  }
152
 
 
154
  # ────────── Evaluator ──────────
155
  class TranslationEvaluator:
156
  def __init__(self):
157
+ self.bleu = evaluate.load("bleu")
158
+ self.bertscore = evaluate.load("bertscore")
159
+ self.comet = evaluate.load("comet", model_id="unbabel/wmt22-comet-da")
160
+ logger.info("Loaded BLEU, BERTScore, COMET")
 
 
 
 
 
 
 
161
 
162
  def evaluate(
163
  self,
 
168
  results = {}
169
 
170
  # BLEU
171
+ bleu_r = self.bleu.compute(predictions=predictions, references=[[r] for r in references])
172
+ results["BLEU"] = float(bleu_r.get("bleu", 0.0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ # BERTScore (xx)
175
+ bs = self.bertscore.compute(predictions=predictions, references=references, lang="xx")
176
+ f1 = bs.get("f1", [])
177
+ results["BERTScore"] = float(sum(f1) / len(f1)) if f1 else 0.0
178
+
179
+ # BERTurk (tr)
180
+ bs_tr = self.bertscore.compute(predictions=predictions, references=references, lang="tr")
181
+ f1t = bs_tr.get("f1", [])
182
+ results["BERTurk"] = float(sum(f1t) / len(f1t)) if f1t else 0.0
 
 
 
 
183
 
184
  # COMET
185
+ cm = self.comet.compute(srcs=sources, hyps=predictions, refs=references)
186
+ sc = cm.get("scores", None)
187
+ if isinstance(sc, list):
188
+ results["COMET"] = float(sc[0]) if sc else 0.0
 
 
 
 
 
 
 
 
 
189
  else:
190
+ results["COMET"] = float(sc or 0.0)
191
 
192
  return results
193
 
194
 
195
  # ────────── Streamlit App ──────────
196
+
197
  @st.cache_resource
198
  def load_resources():
199
  mgr = ModelManager(quantize=True)
200
+ ev = TranslationEvaluator()
201
  return mgr, ev
202
 
203
 
 
216
  ev: TranslationEvaluator,
217
  metrics: List[str],
218
  ):
219
+ # 1) translate
220
+ out = mgr.translate(src) # list of dicts
221
+ hyp = out[0]["translation_text"]
222
+
223
+ # 2) if we have a non-blank reference → compute metrics; else all Nones
224
+ result = {
225
+ "source": src,
226
+ "reference": ref or None,
227
  "hypothesis": hyp,
 
228
  }
229
+ if ref and ref.strip():
230
+ scores = ev.evaluate([src], [ref], [hyp])
231
+ for m in metrics:
232
+ result[m] = scores.get(m, 0.0)
233
+ else:
234
+ for m in metrics:
235
+ result[m] = None
236
+
237
+ return result
238
 
239
 
240
  def show_single_results(res: dict, metrics: List[str]):
241
  left, right = st.columns(2)
242
  with left:
243
+ st.markdown("**Source:**"); st.write(res["source"])
244
+ st.markdown("**Hypothesis (TR):**"); st.write(res["hypothesis"])
 
 
245
  if res["reference"]:
246
+ st.markdown("**Reference (TR):**"); st.write(res["reference"])
 
247
  with right:
248
  st.markdown("### Scores")
249
  df = pd.DataFrame([{m: res[m] for m in metrics}])
250
+ df = df.replace({None: "N/A"})
251
  st.table(df)
252
 
253
 
 
264
  prog = st.progress(0)
265
  results = []
266
  total = len(df)
267
+
268
  for i in range(0, total, batch_size):
269
  batch = df.iloc[i : i + batch_size]
270
+ srcs, refs = batch["src"].tolist(), batch["ref_tr"].tolist()
 
271
  outs = mgr.translate(srcs)
272
  hyps = [o["translation_text"] for o in outs]
273
+
274
  for s, r, h in zip(srcs, refs, hyps):
 
275
  entry = {"src": s, "ref_tr": r, "hyp_tr": h}
276
+ if r and str(r).strip():
277
+ sc = ev.evaluate([s], [r], [h])
278
+ for m in metrics:
279
+ entry[m] = sc.get(m, 0.0)
280
+ else:
281
+ for m in metrics:
282
+ entry[m] = None
283
  results.append(entry)
284
+
285
  prog.progress(min(i + batch_size, total) / total)
286
+
287
  return pd.DataFrame(results)
288
 
289
 
290
  def show_batch_viz(df: pd.DataFrame, metrics: List[str]):
291
  for m in metrics:
292
  st.markdown(f"#### {m} Distribution")
293
+ if df[m].dropna().empty:
294
+ st.write("No reference provided, so this metric is N/A.")
295
+ continue
296
  fig = px.histogram(df, x=m)
297
  st.plotly_chart(fig, use_container_width=True)
298
 
 
300
  def main():
301
  st.set_page_config(page_title="🔀 Translation→Turkish Quality", layout="wide")
302
  st.title("🔀 Translation → TR Quality & COMET")
303
+ st.markdown("Translate any language into Turkish and evaluate (optional) with BLEU, BERTScore, BERTurk & COMET.")
 
 
304
 
305
  # Sidebar
306
  with st.sidebar:
307
  st.header("Settings")
308
+ metrics = st.multiselect(
309
+ "Select metrics",
310
  ["BLEU", "BERTScore", "BERTurk", "COMET"],
311
+ default=["BLEU", "BERTScore", "COMET"]
312
  )
313
  batch_size = st.slider("Batch size", 1, 32, 8)
314
+ mgr, ev = load_resources()
315
  display_model_info(mgr.get_info())
316
 
317
  # Tabs
 
326
  show_single_results(res, metrics)
327
 
328
  with tab2:
329
+ uploaded = st.file_uploader("Upload CSV with `src` & `ref_tr` columns", type=["csv"])
 
 
330
  if uploaded:
331
  with st.spinner("Processing file…"):
332
  df_res = process_file(uploaded, mgr, ev, metrics, batch_size)
333
  st.markdown("### Batch Results")
334
  st.dataframe(df_res, use_container_width=True)
335
  show_batch_viz(df_res, metrics)
336
+ st.download_button("Download results as CSV", df_res.to_csv(index=False), "results.csv")
337
 
338
 
339
  if __name__ == "__main__":