kleervoyans committed on
Commit
b27cfa2
·
verified ·
1 Parent(s): 8ec855b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -65
app.py CHANGED
@@ -2,19 +2,18 @@
2
 
3
  import streamlit as st
4
  import logging
 
5
  import pandas as pd
6
  import plotly.express as px
7
- import torch
8
- from typing import Union, List
9
-
10
- from langdetect import detect, LangDetectException
11
  from transformers import (
12
  AutoTokenizer,
13
  AutoModelForSeq2SeqLM,
14
  pipeline,
15
  BitsAndBytesConfig,
16
  )
 
17
  import evaluate
 
18
 
19
  # ────────── Logging ──────────
20
  logging.basicConfig(
@@ -38,7 +37,7 @@ class ModelManager:
38
  quantize: bool = True,
39
  default_tgt: str = None,
40
  ):
41
- # If user requested quantization but CUDA isn't available, disable it
42
  if quantize and not torch.cuda.is_available():
43
  logger.warning("CUDA unavailable; disabling 8-bit quantization")
44
  quantize = False
@@ -50,7 +49,7 @@ class ModelManager:
50
  ]
51
  self.default_tgt = default_tgt # will auto-pick if None
52
 
53
- self.selected_model_name: str = None
54
  self.tokenizer = None
55
  self.model = None
56
  self.pipeline = None
@@ -66,14 +65,10 @@ class ModelManager:
66
  logger.info(f"Loading tokenizer for {model_name}")
67
  tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
68
  if not hasattr(tok, "lang_code_to_id"):
69
- raise AttributeError(
70
- f"Tokenizer for {model_name} missing lang_code_to_id"
71
- )
72
 
73
  # Load model (with or without 8-bit)
74
- logger.info(
75
- f"Loading model {model_name} (8-bit={self.quantize})"
76
- )
77
  if self.quantize:
78
  bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
79
  mdl = AutoModelForSeq2SeqLM.from_pretrained(
@@ -96,7 +91,7 @@ class ModelManager:
96
  )
97
 
98
  # Store and break
99
- self.selected_model_name = model_name
100
  self.tokenizer = tok
101
  self.model = mdl
102
  self.pipeline = pipe
@@ -104,22 +99,19 @@ class ModelManager:
104
 
105
  # Auto-pick Turkish target code if none specified
106
  if not self.default_tgt:
107
- tur_codes = [
108
- c for c in self.lang_codes if c.lower().startswith("tr")
109
- ]
110
  if not tur_codes:
111
  raise ValueError(f"No Turkish code found in {model_name}")
112
  self.default_tgt = tur_codes[0]
113
  logger.info(f"Default target language: {self.default_tgt}")
114
 
115
  return
 
116
  except Exception as e:
117
  logger.warning(f"Failed to load {model_name}: {e}")
118
  last_err = e
119
 
120
- raise RuntimeError(
121
- f"Could not load any model from {self.candidates}: {last_err}"
122
- )
123
 
124
  def translate(
125
  self,
@@ -138,9 +130,7 @@ class ModelManager:
138
  sample = text[0] if isinstance(text, list) else text
139
  try:
140
  iso = detect(sample).lower()
141
- candidates = [
142
- c for c in self.lang_codes if c.lower().startswith(iso)
143
- ]
144
  if not candidates:
145
  raise LangDetectException(f"No code for ISO '{iso}'")
146
  exact = [c for c in candidates if c.lower() == iso]
@@ -148,9 +138,7 @@ class ModelManager:
148
  logger.info(f"Auto-detected src_lang={src}")
149
  except Exception as e:
150
  logger.warning(f"langdetect failed ({e}); defaulting to English")
151
- eng_codes = [
152
- c for c in self.lang_codes if c.lower().startswith("en")
153
- ]
154
  src = eng_codes[0] if eng_codes else self.lang_codes[0]
155
  else:
156
  src = src_lang
@@ -159,13 +147,14 @@ class ModelManager:
159
 
160
  def get_info(self):
161
  """Return metadata for the sidebar display."""
 
162
  device = "cpu"
163
  if torch.cuda.is_available() and hasattr(self.model, "device"):
164
- idx = self.model.device.index if hasattr(self.model.device, "index") else None
165
- device = f"cuda:{idx}" if idx is not None else "cuda"
166
  return {
167
- "model": self.selected_model_name,
168
- "quantized": self.quantize,
169
  "device": device,
170
  "default_tgt": self.default_tgt,
171
  }
@@ -175,9 +164,16 @@ class ModelManager:
175
  class TranslationEvaluator:
176
  def __init__(self):
177
  self.bleu = evaluate.load("bleu")
178
- self.bertscore = evaluate.load("bertscore")
179
- self.comet = evaluate.load("comet", model_id="unbabel/comet-mqm-qe-da")
180
- logger.info("Loaded BLEU, BERTScore, COMET metrics")
 
 
 
 
 
 
 
181
 
182
  def evaluate(
183
  self,
@@ -186,36 +182,71 @@ class TranslationEvaluator:
186
  predictions: List[str],
187
  ):
188
  results = {}
 
189
  # BLEU
190
- results["BLEU"] = self.bleu.compute(
191
- predictions=predictions,
192
- references=[[r] for r in references],
193
- )["bleu"]
 
 
 
 
 
 
194
  # BERTScore (general)
195
- bs = self.bertscore.compute(
196
- predictions=predictions, references=references, lang="xx"
197
- )
198
- results["BERTScore"] = sum(bs["f1"]) / len(bs["f1"]) if bs["f1"] else 0.0
 
 
 
 
 
 
 
 
 
199
  # BERTurk (Turkish)
200
- bs_tr = self.bertscore.compute(
201
- predictions=predictions, references=references, lang="tr"
202
- )
203
- results["BERTurk"] = sum(bs_tr["f1"]) / len(bs_tr["f1"]) if bs_tr["f1"] else 0.0
 
 
 
 
 
 
 
 
 
204
  # COMET
205
- cm = self.comet.compute(
206
- srcs=sources, hyps=predictions, refs=references
207
- )
208
- scores = cm.get("scores", None)
209
- results["COMET"] = float(scores[0] if isinstance(scores, list) else scores) or 0.0
 
 
 
 
 
 
 
 
 
 
 
210
  return results
211
 
212
 
213
  # ────────── Streamlit App ──────────
214
-
215
  @st.cache_resource
216
  def load_resources():
217
  mgr = ModelManager(quantize=True)
218
- ev = TranslationEvaluator()
219
  return mgr, ev
220
 
221
 
@@ -235,17 +266,17 @@ def process_text(
235
  metrics: List[str],
236
  ):
237
  out = mgr.translate(src)
238
- hyp = out[0]["translation_text"]
239
  scores = ev.evaluate([src], [ref or ""], [hyp])
240
  return {
241
  "source": src,
242
  "reference": ref,
243
  "hypothesis": hyp,
244
- **{m: scores[m] for m in metrics},
245
  }
246
 
247
 
248
- def _show_single_results(res: dict):
249
  left, right = st.columns(2)
250
  with left:
251
  st.markdown("**Source:**")
@@ -257,7 +288,7 @@ def _show_single_results(res: dict):
257
  st.write(res["reference"])
258
  with right:
259
  st.markdown("### Scores")
260
- df = pd.DataFrame([{k: v for k, v in res.items() if k in metrics}])
261
  st.table(df)
262
 
263
 
@@ -283,13 +314,13 @@ def process_file(
283
  for s, r, h in zip(srcs, refs, hyps):
284
  sc = ev.evaluate([s], [r], [h])
285
  entry = {"src": s, "ref_tr": r, "hyp_tr": h}
286
- entry.update({m: sc[m] for m in metrics})
287
  results.append(entry)
288
  prog.progress(min(i + batch_size, total) / total)
289
  return pd.DataFrame(results)
290
 
291
 
292
- def _show_batch_viz(df: pd.DataFrame, metrics: List[str]):
293
  for m in metrics:
294
  st.markdown(f"#### {m} Distribution")
295
  fig = px.histogram(df, x=m)
@@ -297,9 +328,7 @@ def _show_batch_viz(df: pd.DataFrame, metrics: List[str]):
297
 
298
 
299
  def main():
300
- st.set_page_config(
301
  page_title="🔀 Translation→Turkish Quality", layout="wide"
302
- )
303
  st.title("🔀 Translation → TR Quality & COMET")
304
  st.markdown(
305
  "Translate any language into Turkish and evaluate with BLEU, BERTScore, BERTurk & COMET."
@@ -326,7 +355,7 @@ def main():
326
  if st.button("Evaluate"):
327
  with st.spinner("Translating & evaluating…"):
328
  res = process_text(src, ref, mgr, ev, metrics)
329
- _show_single_results(res)
330
 
331
  with tab2:
332
  uploaded = st.file_uploader(
@@ -337,10 +366,8 @@ def main():
337
  df_res = process_file(uploaded, mgr, ev, metrics, batch_size)
338
  st.markdown("### Batch Results")
339
  st.dataframe(df_res, use_container_width=True)
340
- _show_batch_viz(df_res, metrics)
341
- st.download_button(
342
- "Download CSV", df_res.to_csv(index=False), "results.csv"
343
- )
344
 
345
 
346
  if __name__ == "__main__":
 
2
 
3
  import streamlit as st
4
  import logging
5
+ import torch
6
  import pandas as pd
7
  import plotly.express as px
 
 
 
 
8
  from transformers import (
9
  AutoTokenizer,
10
  AutoModelForSeq2SeqLM,
11
  pipeline,
12
  BitsAndBytesConfig,
13
  )
14
+ from langdetect import detect, LangDetectException
15
  import evaluate
16
+ from typing import Union, List
17
 
18
  # ────────── Logging ──────────
19
  logging.basicConfig(
 
37
  quantize: bool = True,
38
  default_tgt: str = None,
39
  ):
40
+ # Disable 8-bit if CUDA isn't available
41
  if quantize and not torch.cuda.is_available():
42
  logger.warning("CUDA unavailable; disabling 8-bit quantization")
43
  quantize = False
 
49
  ]
50
  self.default_tgt = default_tgt # will auto-pick if None
51
 
52
+ self.model_name: str = None
53
  self.tokenizer = None
54
  self.model = None
55
  self.pipeline = None
 
65
  logger.info(f"Loading tokenizer for {model_name}")
66
  tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
67
  if not hasattr(tok, "lang_code_to_id"):
68
+ raise AttributeError(f"Tokenizer for {model_name} missing lang_code_to_id")
 
 
69
 
70
  # Load model (with or without 8-bit)
71
+ logger.info(f"Loading model {model_name} (8-bit={self.quantize})")
 
 
72
  if self.quantize:
73
  bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
74
  mdl = AutoModelForSeq2SeqLM.from_pretrained(
 
91
  )
92
 
93
  # Store and break
94
+ self.model_name = model_name
95
  self.tokenizer = tok
96
  self.model = mdl
97
  self.pipeline = pipe
 
99
 
100
  # Auto-pick Turkish target code if none specified
101
  if not self.default_tgt:
102
+ tur_codes = [c for c in self.lang_codes if c.lower().startswith("tr")]
 
 
103
  if not tur_codes:
104
  raise ValueError(f"No Turkish code found in {model_name}")
105
  self.default_tgt = tur_codes[0]
106
  logger.info(f"Default target language: {self.default_tgt}")
107
 
108
  return
109
+
110
  except Exception as e:
111
  logger.warning(f"Failed to load {model_name}: {e}")
112
  last_err = e
113
 
114
+ raise RuntimeError(f"Could not load any model from {self.candidates}: {last_err}")
 
 
115
 
116
  def translate(
117
  self,
 
130
  sample = text[0] if isinstance(text, list) else text
131
  try:
132
  iso = detect(sample).lower()
133
+ candidates = [c for c in self.lang_codes if c.lower().startswith(iso)]
 
 
134
  if not candidates:
135
  raise LangDetectException(f"No code for ISO '{iso}'")
136
  exact = [c for c in candidates if c.lower() == iso]
 
138
  logger.info(f"Auto-detected src_lang={src}")
139
  except Exception as e:
140
  logger.warning(f"langdetect failed ({e}); defaulting to English")
141
+ eng_codes = [c for c in self.lang_codes if c.lower().startswith("en")]
 
 
142
  src = eng_codes[0] if eng_codes else self.lang_codes[0]
143
  else:
144
  src = src_lang
 
147
 
148
  def get_info(self):
149
  """Return metadata for the sidebar display."""
150
+ quantized = getattr(self.model, "is_loaded_in_8bit", False)
151
  device = "cpu"
152
  if torch.cuda.is_available() and hasattr(self.model, "device"):
153
+ dev = self.model.device
154
+ device = str(dev) if isinstance(dev, torch.device) else f"cuda:{getattr(dev, 'index', '')}"
155
  return {
156
+ "model": self.model_name,
157
+ "quantized": quantized,
158
  "device": device,
159
  "default_tgt": self.default_tgt,
160
  }
 
164
  class TranslationEvaluator:
165
  def __init__(self):
166
  self.bleu = evaluate.load("bleu")
167
+ try:
168
+ self.bertscore = evaluate.load("bertscore")
169
+ except Exception as e:
170
+ logger.error("BERTScore load error: %s", e)
171
+ self.bertscore = None
172
+ try:
173
+ self.comet = evaluate.load("comet", model_id="unbabel/comet-mqm-qe-da")
174
+ except Exception as e:
175
+ logger.error("COMET load error: %s", e)
176
+ self.comet = None
177
 
178
  def evaluate(
179
  self,
 
182
  predictions: List[str],
183
  ):
184
  results = {}
185
+
186
  # BLEU
187
+ try:
188
+ bleu_res = self.bleu.compute(
189
+ predictions=predictions,
190
+ references=[[r] for r in references],
191
+ )
192
+ results["BLEU"] = float(bleu_res.get("bleu", 0.0))
193
+ except Exception as e:
194
+ logger.error("BLEU compute error: %s", e)
195
+ results["BLEU"] = 0.0
196
+
197
  # BERTScore (general)
198
+ if self.bertscore:
199
+ try:
200
+ bs = self.bertscore.compute(
201
+ predictions=predictions, references=references, lang="xx"
202
+ )
203
+ f1 = bs.get("f1", [])
204
+ results["BERTScore"] = float(sum(f1)) / max(len(f1), 1)
205
+ except Exception as e:
206
+ logger.error("BERTScore compute error: %s", e)
207
+ results["BERTScore"] = 0.0
208
+ else:
209
+ results["BERTScore"] = 0.0
210
+
211
  # BERTurk (Turkish)
212
+ if self.bertscore:
213
+ try:
214
+ bs_tr = self.bertscore.compute(
215
+ predictions=predictions, references=references, lang="tr"
216
+ )
217
+ f1_tr = bs_tr.get("f1", [])
218
+ results["BERTurk"] = float(sum(f1_tr)) / max(len(f1_tr), 1)
219
+ except Exception as e:
220
+ logger.error("BERTurk compute error: %s", e)
221
+ results["BERTurk"] = 0.0
222
+ else:
223
+ results["BERTurk"] = 0.0
224
+
225
  # COMET
226
+ if self.comet:
227
+ try:
228
+ cm = self.comet.compute(
229
+ srcs=sources, hyps=predictions, refs=references
230
+ )
231
+ sc = cm.get("scores", None)
232
+ if isinstance(sc, list):
233
+ results["COMET"] = float(sc[0]) if sc else 0.0
234
+ else:
235
+ results["COMET"] = float(sc or 0.0)
236
+ except Exception as e:
237
+ logger.error("COMET compute error: %s", e)
238
+ results["COMET"] = 0.0
239
+ else:
240
+ results["COMET"] = 0.0
241
+
242
  return results
243
 
244
 
245
  # ────────── Streamlit App ──────────
 
246
  @st.cache_resource
247
  def load_resources():
248
  mgr = ModelManager(quantize=True)
249
+ ev = TranslationEvaluator()
250
  return mgr, ev
251
 
252
 
 
266
  metrics: List[str],
267
  ):
268
  out = mgr.translate(src)
269
+ hyp = out[0]["translation_text"] if isinstance(out, list) else out["translation_text"]
270
  scores = ev.evaluate([src], [ref or ""], [hyp])
271
  return {
272
  "source": src,
273
  "reference": ref,
274
  "hypothesis": hyp,
275
+ **{m: scores.get(m, 0.0) for m in metrics},
276
  }
277
 
278
 
279
+ def show_single_results(res: dict, metrics: List[str]):
280
  left, right = st.columns(2)
281
  with left:
282
  st.markdown("**Source:**")
 
288
  st.write(res["reference"])
289
  with right:
290
  st.markdown("### Scores")
291
+ df = pd.DataFrame([{m: res[m] for m in metrics}])
292
  st.table(df)
293
 
294
 
 
314
  for s, r, h in zip(srcs, refs, hyps):
315
  sc = ev.evaluate([s], [r], [h])
316
  entry = {"src": s, "ref_tr": r, "hyp_tr": h}
317
+ entry.update({m: sc.get(m, 0.0) for m in metrics})
318
  results.append(entry)
319
  prog.progress(min(i + batch_size, total) / total)
320
  return pd.DataFrame(results)
321
 
322
 
323
+ def show_batch_viz(df: pd.DataFrame, metrics: List[str]):
324
  for m in metrics:
325
  st.markdown(f"#### {m} Distribution")
326
  fig = px.histogram(df, x=m)
 
328
 
329
 
330
  def main():
331
+ st.set_page_config(page_title="🔀 Translation→Turkish Quality", layout="wide")
 
 
332
  st.title("🔀 Translation → TR Quality & COMET")
333
  st.markdown(
334
  "Translate any language into Turkish and evaluate with BLEU, BERTScore, BERTurk & COMET."
 
355
  if st.button("Evaluate"):
356
  with st.spinner("Translating & evaluating…"):
357
  res = process_text(src, ref, mgr, ev, metrics)
358
+ show_single_results(res, metrics)
359
 
360
  with tab2:
361
  uploaded = st.file_uploader(
 
366
  df_res = process_file(uploaded, mgr, ev, metrics, batch_size)
367
  st.markdown("### Batch Results")
368
  st.dataframe(df_res, use_container_width=True)
369
+ show_batch_viz(df_res, metrics)
370
+ st.download_button("Download CSV", df_res.to_csv(index=False), "results.csv")
 
 
371
 
372
 
373
  if __name__ == "__main__":