kleervoyans committed on
Commit 3755b73 · verified · 1 Parent(s): dc2f97b

Update app.py

Files changed (1):
  1. app.py +279 -146
app.py CHANGED

app.py BEFORE (old side of the diff; deletions marked "-"):

@@ -4,12 +4,14 @@ import streamlit as st
  import streamlit.components.v1 as components
  import logging
  import torch

  import pandas as pd
  import plotly.express as px
  import time
  import difflib
- from typing import Union, List

  from langdetect import detect, LangDetectException
  from transformers import (
      AutoTokenizer,
@@ -18,73 +20,73 @@ from transformers import (
      BitsAndBytesConfig,
  )
  import evaluate

  # ────────── Global CSS ──────────
- st.markdown(
-     """
-     <style>
-     /* Page */
-     .main .block-container { max-width: 900px; padding: 1rem 2rem; }
-     /* Buttons */
-     .stButton>button { background-color: #4A90E2; color: white; border-radius: 4px; }
-     .stButton>button:hover { background-color: #357ABD; }
-     /* Text areas */
-     textarea { border-radius: 4px; }
-     /* Tables */
-     .stTable table { border-radius: 4px; overflow: hidden; }
-     </style>
-     """,
-     unsafe_allow_html=True
- )

  # ────────── Logging ──────────
  logging.basicConfig(
      format="%(asctime)s %(levelname)s %(name)s: %(message)s",
      datefmt="%Y-%m-%d %H:%M:%S",
-     level=logging.INFO
  )
  logger = logging.getLogger(__name__)

  # ────────── Model Manager ──────────
  class ModelManager:
      """
-     Selects & loads NLLB-200 or M2M100 (8-bit if GPU available).
-     Exposes `translate()` with auto-lang detection + dynamic tgt_lang.
      """
-     def __init__(
-         self,
-         candidates: List[str] = None,
-         quantize: bool = True,
-         default_tgt: str = None,
-     ):
          if quantize and not torch.cuda.is_available():
              logger.warning("CUDA unavailable; disabling 8-bit quantization")
              quantize = False
-         self.quantize = quantize
-         self.candidates = candidates or [
              "facebook/nllb-200-distilled-600M",
-             "facebook/m2m100_418M"
          ]
          self.default_tgt = default_tgt
-         self.model_name = None
-         self.tokenizer = None
-         self.model = None
-         self.pipeline = None
-         self.lang_codes = []
          self._load_best()

      def _load_best(self):
          last_err = None
          for name in self.candidates:
              try:
-                 # 1) Tokenizer
-                 logger.info(f"Loading tokenizer for {name}")
                  tok = AutoTokenizer.from_pretrained(name, use_fast=True)
                  if not hasattr(tok, "lang_code_to_id"):
                      raise AttributeError("no lang_code_to_id")
-                 # 2) Model (8-bit if configured)
-                 logger.info(f"Loading model {name} (8-bit={self.quantize})")
                  if self.quantize:
                      bnb = BitsAndBytesConfig(load_in_8bit=True)
                      mdl = AutoModelForSeq2SeqLM.from_pretrained(
@@ -94,21 +96,19 @@ class ModelManager:
                  mdl = AutoModelForSeq2SeqLM.from_pretrained(
                      name, device_map="auto"
                  )
-                 # 3) Pipeline
                  pipe = pipeline("translation", model=mdl, tokenizer=tok)
-                 # Store
                  self.model_name = name
                  self.tokenizer = tok
                  self.model = mdl
                  self.pipeline = pipe
                  self.lang_codes = list(tok.lang_code_to_id.keys())
-                 # Auto-pick Turkish if needed
                  if not self.default_tgt:
                      tur = [c for c in self.lang_codes if c.lower().startswith("tr")]
                      if not tur:
                          raise ValueError("No Turkish code found")
                      self.default_tgt = tur[0]
-                 logger.info(f"Default target = {self.default_tgt}")
                  return
              except Exception as e:
                  logger.warning(f"Failed to load {name}: {e}")
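Note: `lang_code_to_id` keys differ across model families, so the "tr"-prefix lookup above resolves to a different default target per model. A small self-contained sketch (the language codes here are assumptions from memory about the NLLB-200 and M2M100 tokenizers, not taken from this commit):

```python
# Hedged sketch: how the "tr"-prefix lookup is expected to resolve per model family.
# The code names below are assumptions from memory, not taken from this diff.
nllb_codes = ["eng_Latn", "tur_Latn"]   # NLLB-200-style codes
m2m_codes = ["en", "tr"]                # M2M100-style codes

def pick_turkish(lang_codes):
    # Mirrors ModelManager._load_best: first code starting with "tr".
    tur = [c for c in lang_codes if c.lower().startswith("tr")]
    if not tur:
        raise ValueError("No Turkish code found")
    return tur[0]

print(pick_turkish(nllb_codes))  # -> "tur_Latn"
print(pick_turkish(m2m_codes))   # -> "tr"
```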
@@ -116,30 +116,26 @@
          raise RuntimeError(f"No model loaded: {last_err}")

      def translate(
-         self,
-         text: Union[str, List[str]],
-         src_lang: str = None,
-         tgt_lang: str = None,
      ):
          tgt = tgt_lang or self.default_tgt
-         # auto-detect source if missing
          if not src_lang:
              sample = text[0] if isinstance(text, list) else text
              try:
                  iso = detect(sample).lower()
-                 cands = [c for c in self.lang_codes if c.lower().startswith(iso)]
-                 if not cands: raise LangDetectException()
-                 exact = [c for c in cands if c.lower() == iso]
-                 src = exact[0] if exact else cands[0]
                  logger.info(f"Detected src_lang={src}")
              except Exception:
-                 # fallback to English
                  eng = [c for c in self.lang_codes if c.lower().startswith("en")]
                  src = eng[0] if eng else self.lang_codes[0]
                  logger.warning(f"Falling back src_lang={src}")
          else:
              src = src_lang
- 
          return self.pipeline(text, src_lang=src, tgt_lang=tgt)
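Note: a minimal usage sketch for `translate()` as defined above; output shape follows the transformers translation pipeline:

```python
# Hedged usage sketch; not part of the commit. Downloads a model on first run.
mgr = ModelManager(quantize=False)    # quantize=True silently degrades without CUDA
out = mgr.translate("Hello, world!")  # src auto-detected; tgt defaults to the Turkish code
print(out[0]["translation_text"])     # the translation pipeline returns a list of dicts
```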

      def get_info(self):
@@ -152,79 +148,185 @@ class ModelManager:
          "quantized": self.quantize,
          "device": dev,
          "default_tgt": self.default_tgt,
          }

- 
  # ────────── Evaluator ──────────
  class TranslationEvaluator:
      def __init__(self):
-         self.bleu = evaluate.load("bleu")
          self.bertscore = evaluate.load("bertscore")
-         self.comet = evaluate.load("comet", model_id="unbabel/wmt22-comet-da")
-         logger.info("Loaded BLEU, BERTScore, COMET")

-     def evaluate(
          self,
-         srcs: List[str],
-         refs: List[str],
-         hyps: List[str],
-     ):
          out = {}
-         # BLEU
-         b = self.bleu.compute(predictions=hyps, references=[[r] for r in refs])
-         out["BLEU"] = float(b.get("bleu", 0.0))
-         # BERTScore xx
-         bs = self.bertscore.compute(predictions=hyps, references=refs, lang="xx")
-         f1 = bs.get("f1", [])
-         out["BERTScore"] = float(sum(f1)/len(f1)) if f1 else 0.0
-         # BERTurk tr
-         bt = self.bertscore.compute(predictions=hyps, references=refs, lang="tr")
-         f2 = bt.get("f1", [])
-         out["BERTurk"] = float(sum(f2)/len(f2)) if f2 else 0.0
-         # COMET
-         cm = self.comet.compute(srcs=srcs, hyps=hyps, refs=refs)
-         sc = cm.get("scores")
-         out["COMET"] = float(sc[0] if isinstance(sc, list) else sc or 0.0)
          return out

  # ────────── Streamlit App ──────────
  @st.cache_resource
  def load_resources():
      mgr = ModelManager(quantize=True)
      ev = TranslationEvaluator()
-     return mgr, ev
- 

  def display_model_info(info: dict):
      st.sidebar.markdown("### Model Info")
-     st.sidebar.write(f"• Model: **{info['model']}**")
-     st.sidebar.write(f"• Quantized: **{info['quantized']}**")
-     st.sidebar.write(f"• Device: **{info['device']}**")
- 

- def process_and_stream(src, ref, tgt, mgr, ev, metrics):
-     # 1) call pipeline
-     out = mgr.translate(src, tgt_lang=tgt)
-     hyp = out[0]["translation_text"]
- 
-     # 2) pseudo-stream: reveal word by word
-     placeholder = st.empty()
-     text_acc = ""
-     for w in hyp.split():
-         text_acc += w + " "
-         placeholder.markdown(f"**Hypothesis ({tgt}):** {text_acc}")
-         time.sleep(0.05)
- 
-     # 3) metrics (only if ref given)
-     scores = {}
-     if ref and ref.strip():
-         scores = ev.evaluate([src], [ref], [hyp])
-     return hyp, scores
- 
- 
- def show_diff(ref, hyp):
-     # side-by-side HTML diff
      differ = difflib.HtmlDiff(tabsize=4, wrapcolumn=60)
      html = differ.make_table(
          ref.split(), hyp.split(),
@@ -233,84 +335,115 @@ def show_diff(ref, hyp):
      )
      components.html(html, height=200, scrolling=True)

- 
  def main():
-     st.set_page_config(page_title="🔀 Multi-Lang ↑TR + Eval", layout="wide")
-     st.title("🌐 Translate → 🔠 Turkish & Evaluate")
-     st.write("Choose target, translate from any language, and (optionally) eval against a reference.")

-     # Sidebar: load models & then dynamic tgt dropdown
      with st.sidebar:
          st.header("Settings")
-         mgr, ev = load_resources()
          info = mgr.get_info()
          display_model_info(info)

          tgt = st.selectbox(
-             "Target language code",
-             options=mgr.lang_codes,
-             index=mgr.lang_codes.index(info["default_tgt"])
-         )
-         metrics = st.multiselect(
-             "Metrics",
-             ["BLEU","BERTScore","BERTurk","COMET"],
-             default=["BLEU","BERTScore","COMET"]
          )
          batch_size = st.slider("Batch size", 1, 32, 8)

-     tab1, tab2 = st.tabs(["Single sentence","Batch CSV"])

      with tab1:
-         src = st.text_area("Source sentence:", height=120)
-         ref = st.text_area("Turkish reference (optional):", height=80)
          if st.button("Translate & Eval"):
-             with st.spinner("Working…"):
-                 hyp, scores = process_and_stream(src, ref, tgt, mgr, ev, metrics)
-             # show scores
-             df = {m: (scores.get(m) if ref.strip() else None) for m in metrics}
              st.markdown("### Scores")
-             st.table(pd.DataFrame([df]).replace({None:"N/A"}))
              # diff
              if ref.strip():
-                 st.markdown("### Diff view")
                  show_diff(ref, hyp)

      with tab2:
          uploaded = st.file_uploader("Upload CSV with `src`,`ref_tr`", type=["csv"])
          if uploaded:
              df = pd.read_csv(uploaded)
-             if not {"src","ref_tr"}.issubset(df):
-                 st.error("CSV needs `src` and `ref_tr` columns.")
              else:
-                 with st.spinner("Batch translating…"):
-                     out_rows = []
                      prog = st.progress(0)
-                     for i in range(0, len(df), batch_size):
                          batch = df.iloc[i : i+batch_size]
                          srcs, refs = batch["src"].tolist(), batch["ref_tr"].tolist()
                          outs = mgr.translate(srcs, tgt_lang=tgt)
                          hyps = [o["translation_text"] for o in outs]
                          for s, r, h in zip(srcs, refs, hyps):
-                             row = {"src":s, "ref_tr":r, "hyp_tr":h}
                              if r.strip():
-                                 sc = ev.evaluate([s],[r],[h])
-                                 for m in metrics: row[m] = sc[m]
                              else:
-                                 for m in metrics: row[m] = None
-                             out_rows.append(row)
-                         prog.progress(min(i+batch_size,len(df))/len(df))
-                 res_df = pd.DataFrame(out_rows)
-                 st.markdown("### Batch Results")
                  st.dataframe(res_df, use_container_width=True)
-                 # viz
                  for m in metrics:
-                     st.markdown(f"#### {m} Histogram")
-                     col = res_df[m].dropna()
                      if col.empty:
-                         st.write("No valid refs → metric N/A.")
                      else:
-                         fig = px.histogram(res_df, x=m)
                          st.plotly_chart(fig, use_container_width=True)

                  st.download_button("Download CSV", res_df.to_csv(index=False), "results.csv")

  if __name__ == "__main__":
 
app.py AFTER (new side of the same diff; additions marked "+"):

  import streamlit.components.v1 as components
  import logging
  import torch
+ import random
+ import numpy as np
  import pandas as pd
  import plotly.express as px
  import time
  import difflib

+ from typing import List, Union
  from langdetect import detect, LangDetectException
  from transformers import (
      AutoTokenizer,

      BitsAndBytesConfig,
  )
  import evaluate
+ from sacrebleu import corpus_bleu, sentence_bleu  # Doc vs. segment BLEU
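Note: the new sacrebleu import backs two different BLEU views — corpus-level BLEU pools n-gram counts before scoring, while segment-level BLEU scores each sentence and averages, and the two can diverge. A minimal sketch of both calls (sacrebleu call shapes as I know them: hypotheses as a flat list, references as a list of reference streams):

```python
from sacrebleu import corpus_bleu, sentence_bleu

hyps = ["the cat sat on the mat", "hello world"]
refs = ["the cat sat on the mat", "hello there world"]

# Corpus BLEU: n-gram statistics pooled over the whole set, then one score.
doc = corpus_bleu(hyps, [refs]).score

# Segment BLEU: one score per sentence, averaged afterwards.
seg = sum(sentence_bleu(h, [r]).score for h, r in zip(hyps, refs)) / len(hyps)

print(f"doc-level {doc:.1f} vs. mean segment-level {seg:.1f}")  # 0-100 scale
```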

  # ────────── Global CSS ──────────
+ st.markdown("""
+ <style>
+ .main .block-container { max-width: 900px; padding: 1rem 2rem; }
+ .stButton>button { background-color: #4A90E2; color: white; border-radius: 4px; }
+ .stButton>button:hover { background-color: #357ABD; }
+ textarea { border-radius: 4px; }
+ .stTable table { border-radius: 4px; overflow: hidden; }
+ </style>
+ """, unsafe_allow_html=True)

  # ────────── Logging ──────────
  logging.basicConfig(
      format="%(asctime)s %(levelname)s %(name)s: %(message)s",
      datefmt="%Y-%m-%d %H:%M:%S",
+     level=logging.INFO,
  )
  logger = logging.getLogger(__name__)

+ # ────────── Utilities ──────────
+ def bootstrap(
+     fn, predictions: List[str], references: List[str], sources: List[str] = None,
+     n_resamples: int = 200, seed: int = 42
+ ) -> List[float]:
+     """Bootstrap metric fn over (predictions, references, [sources])."""
+     random.seed(seed)
+     scores = []
+     N = len(predictions)
+     for _ in range(n_resamples):
+         idxs = [random.randrange(N) for _ in range(N)]
+         ps = [predictions[i] for i in idxs]
+         rs = [references[i] for i in idxs]
+         if sources:
+             ss = [sources[i] for i in idxs]
+             scores.append(fn(ps, rs, ss))
+         else:
+             scores.append(fn(ps, rs))
+     return scores
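Note: downstream, `compute_metrics` turns these bootstrap samples into a 95% confidence interval via percentiles. A tiny self-contained sketch of that step, using a toy metric in place of the heavyweight ones:

```python
import random
import numpy as np

# Toy stand-in metric (mean exact match) so the sketch runs without model downloads.
def exact_match(ps, rs):
    return sum(p == r for p, r in zip(ps, rs)) / len(ps)

preds = ["a", "b", "c", "d"] * 25
refs = ["a", "b", "x", "d"] * 25

random.seed(42)
N = len(preds)
samples = []
for _ in range(200):  # same resampling scheme as bootstrap() above
    idx = [random.randrange(N) for _ in range(N)]
    samples.append(exact_match([preds[i] for i in idx], [refs[i] for i in idx]))

lo, hi = np.percentile(samples, 2.5), np.percentile(samples, 97.5)
print(f"point estimate {exact_match(preds, refs):.2f}, 95% CI [{lo:.2f}, {hi:.2f}]")
```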

  # ────────── Model Manager ──────────
  class ModelManager:
      """
+     Loads the best translation model (NLLB-200 or M2M100),
+     8-bit if GPU available; auto-detects src_lang; dynamic tgt_lang.
      """
+     def __init__(self, candidates=None, quantize=True, default_tgt=None):
          if quantize and not torch.cuda.is_available():
              logger.warning("CUDA unavailable; disabling 8-bit quantization")
              quantize = False
+         self.quantize = quantize
+         self.candidates = candidates or [
              "facebook/nllb-200-distilled-600M",
+             "facebook/m2m100_418M",
          ]
          self.default_tgt = default_tgt
          self._load_best()

      def _load_best(self):
          last_err = None
          for name in self.candidates:
              try:
                  tok = AutoTokenizer.from_pretrained(name, use_fast=True)
                  if not hasattr(tok, "lang_code_to_id"):
                      raise AttributeError("no lang_code_to_id")
+                 logger.info(f"Loading {name} (8-bit={self.quantize})")
                  if self.quantize:
                      bnb = BitsAndBytesConfig(load_in_8bit=True)
                      mdl = AutoModelForSeq2SeqLM.from_pretrained(

                  mdl = AutoModelForSeq2SeqLM.from_pretrained(
                      name, device_map="auto"
                  )
                  pipe = pipeline("translation", model=mdl, tokenizer=tok)
                  self.model_name = name
                  self.tokenizer = tok
                  self.model = mdl
                  self.pipeline = pipe
                  self.lang_codes = list(tok.lang_code_to_id.keys())
+                 # pick default target if none
                  if not self.default_tgt:
                      tur = [c for c in self.lang_codes if c.lower().startswith("tr")]
                      if not tur:
                          raise ValueError("No Turkish code found")
                      self.default_tgt = tur[0]
+                 logger.info(f"default_tgt = {self.default_tgt}")
                  return
              except Exception as e:
                  logger.warning(f"Failed to load {name}: {e}")

          raise RuntimeError(f"No model loaded: {last_err}")

      def translate(
+         self, text: Union[str, List[str]],
+         src_lang: str = None, tgt_lang: str = None
      ):
          tgt = tgt_lang or self.default_tgt
+         # auto-detect src
          if not src_lang:
              sample = text[0] if isinstance(text, list) else text
              try:
                  iso = detect(sample).lower()
+                 cand = [c for c in self.lang_codes if c.lower().startswith(iso)]
+                 if not cand: raise LangDetectException()
+                 exact = [c for c in cand if c.lower() == iso]
+                 src = exact[0] if exact else cand[0]
                  logger.info(f"Detected src_lang={src}")
              except Exception:
                  eng = [c for c in self.lang_codes if c.lower().startswith("en")]
                  src = eng[0] if eng else self.lang_codes[0]
                  logger.warning(f"Falling back src_lang={src}")
          else:
              src = src_lang
          return self.pipeline(text, src_lang=src, tgt_lang=tgt)

      def get_info(self):

          "quantized": self.quantize,
          "device": dev,
          "default_tgt": self.default_tgt,
+         "langs": self.lang_codes,
          }
 
 
  # ────────── Evaluator ──────────
  class TranslationEvaluator:
+     """
+     Wraps BLEU (corpus), ChrF, TER, BERTScore, COMET (ref & ref-free), and provides CIs.
+     """
      def __init__(self):
+         # BLEU (corpus)
+         self.bleu = evaluate.load("bleu")
+         # ChrF
+         self.chrf = evaluate.load("chrf")
+         # TER
+         self.ter = evaluate.load("ter")
+         # BERTScore
          self.bertscore = evaluate.load("bertscore")
+         # COMET (ref-based)
+         self.comet_ref = evaluate.load("comet", model_id="unbabel/wmt22-comet-da")
+         # COMET QE (ref-free)
+         self.comet_qe = evaluate.load("comet", model_id="unbabel/wmt20-comet-qe-da")
+         logger.info("Loaded BLEU, ChrF, TER, BERTScore, COMET (ref & QE)")

+     def compute_metrics(
          self,
+         sources: List[str],
+         references: List[str],
+         predictions: List[str],
+         metrics: List[str],
+         ci: bool = True
+     ) -> dict:
          out = {}
+
+         # -- BLEU (document-level)
+         if "BLEU_doc" in metrics:
+             doc_bleu = self.bleu.compute(
+                 predictions=predictions,
+                 references=[[r] for r in references]
+             )["bleu"]
+             out["BLEU_doc"] = float(doc_bleu)
+
+         # -- BLEU (segment-level avg)
+         if "BLEU_seg" in metrics:
+             seg_scores = [
+                 sentence_bleu(p, [r]).score
+                 for p, r in zip(predictions, references)
+             ]
+             out["BLEU_seg"] = float(sum(seg_scores) / len(seg_scores))
+
+         # -- ChrF
+         if "ChrF" in metrics:
+             cf = self.chrf.compute(
+                 predictions=predictions,
+                 references=[[r] for r in references]
+             )["score"]
+             out["ChrF"] = float(cf)
+
+         # -- TER
+         if "TER" in metrics:
+             tr = self.ter.compute(
+                 predictions=predictions,
+                 references=[[r] for r in references],
+                 normalized=True
+             )["score"]
+             out["TER"] = float(tr)
+
+         # -- BERTScore
+         if "BERTScore" in metrics:
+             bs = self.bertscore.compute(
+                 predictions=predictions,
+                 references=references,
+                 lang="xx"
+             )["f1"]
+             out["BERTScore"] = float(sum(bs) / len(bs)) if bs else 0.0
+
+         # -- BERTurk
+         if "BERTurk" in metrics:
+             bt = self.bertscore.compute(
+                 predictions=predictions,
+                 references=references,
+                 lang="tr"
+             )["f1"]
+             out["BERTurk"] = float(sum(bt) / len(bt)) if bt else 0.0
+
+         # -- COMET (ref-based)
+         if "COMET" in metrics:
+             cr = self.comet_ref.compute(
+                 sources=sources, predictions=predictions, references=references
+             ).get("scores", 0.0)
+             out["COMET"] = float(cr[0] if isinstance(cr, list) else cr)
+
+         # -- QE (ref-free)
+         if "QE" in metrics:
+             cq = self.comet_qe.compute(
+                 sources=sources, predictions=predictions
+             ).get("scores", 0.0)
+             out["QE"] = float(cq[0] if isinstance(cq, list) else cq)
+
+         # -- Bootstrap CIs
+         if ci:
+             # BLEU_doc CI
+             if "CI_BLEU_doc" in metrics:
+                 bsamp = bootstrap(
+                     lambda ps, rs: self.bleu.compute(
+                         predictions=ps,
+                         references=[[r] for r in rs]
+                     )["bleu"],
+                     predictions, references
+                 )
+                 out["CI_BLEU_doc"] = (
+                     float(np.percentile(bsamp, 2.5)),
+                     float(np.percentile(bsamp, 97.5))
+                 )
+             # BERTScore CI
+             if "CI_BERTScore" in metrics:
+                 bsamp = bootstrap(
+                     lambda ps, rs: sum(
+                         self.bertscore.compute(
+                             predictions=ps, references=rs, lang="xx"
+                         )["f1"]
+                     ) / len(ps),
+                     predictions, references
+                 )
+                 out["CI_BERTScore"] = (
+                     float(np.percentile(bsamp, 2.5)),
+                     float(np.percentile(bsamp, 97.5))
+                 )
+             # COMET CI
+             if "CI_COMET" in metrics:
+                 bsamp = bootstrap(
+                     lambda ps, rs, ss: float(
+                         self.comet_ref.compute(
+                             sources=ss, predictions=ps, references=rs
+                         ).get("scores", [0.0])[0]
+                     ),
+                     predictions, references, sources
+                 )
+                 out["CI_COMET"] = (
+                     float(np.percentile(bsamp, 2.5)),
+                     float(np.percentile(bsamp, 97.5))
+                 )
+
          return out
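Note: a minimal call sketch for `compute_metrics` (the metric names are the keys checked above):

```python
# Hedged usage sketch; not part of the commit. Metric modules download on first use.
ev = TranslationEvaluator()
scores = ev.compute_metrics(
    sources=["Hello world"],
    references=["Merhaba dünya"],
    predictions=["Merhaba dünya"],
    metrics=["BLEU_doc", "BLEU_seg", "ChrF"],
)
# Beware the scale mismatch: evaluate's "bleu" is 0-1 while sacrebleu is 0-100,
# so a perfect match yields roughly {"BLEU_doc": 1.0, "BLEU_seg": 100.0, "ChrF": 100.0}.
print(scores)
```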

+ # ────────── Error Categorizer ──────────
+ class ErrorCategorizer:
+     """
+     Optional: classify error types via a fine-tuned text-classification model.
+     Supply your own HF model name for real categories.
+     """
+     def __init__(self, model_name: str = None):
+         if model_name:
+             self.pipe = pipeline("text-classification", model=model_name, device=0 if torch.cuda.is_available() else -1)
+         else:
+             self.pipe = None
+
+     def categorize(self, src: str, hyp: str):
+         if not self.pipe:
+             return []
+         inp = f"SRC: {src}\nHYP: {hyp}\nError types (pick from taxonomy):"
+         return self.pipe(inp, top_k=None)
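Note: `categorize` frames error classification as plain text classification over a prompt-style input. A hedged sketch of the call and the expected output shape:

```python
# Hedged sketch; the output shape is assumed from transformers' text-classification pipeline.
err = ErrorCategorizer(model_name=None)    # no classifier configured: categorize() returns []
print(err.categorize("Hello", "Merhaba"))  # -> []
# With a real classifier one would expect something like:
# [{"label": "mistranslation", "score": 0.71}, {"label": "omission", "score": 0.12}]
```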

  # ────────── Streamlit App ──────────
  @st.cache_resource
  def load_resources():
      mgr = ModelManager(quantize=True)
      ev = TranslationEvaluator()
+     # set your error-classifier HF model here, or None to disable
+     err = ErrorCategorizer(model_name="your-org/translation-error-categorizer")
+     return mgr, ev, err

  def display_model_info(info: dict):
      st.sidebar.markdown("### Model Info")
+     st.sidebar.write(f"• **Model:** {info['model']}")
+     st.sidebar.write(f"• **Quantized:** {info['quantized']}")
+     st.sidebar.write(f"• **Device:** {info['device']}")
+     st.sidebar.write(f"• **Default tgt:** {info['default_tgt']}")

+ def show_diff(ref: str, hyp: str):
      differ = difflib.HtmlDiff(tabsize=4, wrapcolumn=60)
      html = differ.make_table(
          ref.split(), hyp.split(),

      )
      components.html(html, height=200, scrolling=True)

  def main():
+     st.set_page_config(page_title="🔀 Translate→Eval+", layout="wide")
+     st.title("🌐 Translate → 🔠 Evaluate & Analyze")
+     st.write("Translate from any language, choose target, eval with advanced metrics, and inspect errors.")

+     # Sidebar
      with st.sidebar:
          st.header("Settings")
+         mgr, ev, err = load_resources()
          info = mgr.get_info()
          display_model_info(info)

          tgt = st.selectbox(
+             "Target language", info["langs"],
+             index=info["langs"].index(info["default_tgt"])
          )
+
+         metric_opts = [
+             "BLEU_doc","BLEU_seg","ChrF","TER",
+             "BERTScore","BERTurk","COMET","QE",
+             "CI_BLEU_doc","CI_BERTScore","CI_COMET"
+         ]
+         metrics = st.multiselect("Metrics & CIs", metric_opts, default=["BLEU_doc","BERTScore","COMET"])
          batch_size = st.slider("Batch size", 1, 32, 8)

+     tab1, tab2 = st.tabs(["Single","Batch CSV"])

+     # ────────── Single Sentence ──────────
      with tab1:
+         src = st.text_area("Source text:", height=120)
+         ref = st.text_area("Gold reference (optional):", height=80)
          if st.button("Translate & Eval"):
+             with st.spinner("⏳ Translating…"):
+                 out = mgr.translate(src, tgt_lang=tgt)
+                 hyp = out[0]["translation_text"]
+             st.markdown(f"**Hypothesis ({tgt}):** {hyp}")
+
+             # metrics
+             scores = ev.compute_metrics([src], [ref or ""], [hyp], metrics)
+             # display
+             sd = {}
+             for m in metrics:
+                 v = scores.get(m)
+                 if m.startswith("CI_"):
+                     low, high = v
+                     sd[m] = f"{low:.3f} – {high:.3f}"
+                 else:
+                     sd[m] = f"{v:.4f}" if v is not None else "N/A"
              st.markdown("### Scores")
+             st.table(pd.DataFrame([sd]))
+
              # diff
              if ref.strip():
+                 st.markdown("### Diff View")
                  show_diff(ref, hyp)

+             # error categories
+             cats = err.categorize(src, hyp)
+             if cats:
+                 st.markdown("### Error Categories")
+                 st.json(cats)
+
+     # ────────── Batch CSV ──────────
      with tab2:
          uploaded = st.file_uploader("Upload CSV with `src`,`ref_tr`", type=["csv"])
          if uploaded:
              df = pd.read_csv(uploaded)
+             if not {"src","ref_tr"}.issubset(df.columns):
+                 st.error("CSV must have `src` and `ref_tr` columns.")
              else:
+                 with st.spinner("⏳ Batch processing…"):
+                     all_rows = []
                      prog = st.progress(0)
+                     N = len(df)
+                     for i in range(0, N, batch_size):
                          batch = df.iloc[i : i+batch_size]
                          srcs, refs = batch["src"].tolist(), batch["ref_tr"].tolist()
                          outs = mgr.translate(srcs, tgt_lang=tgt)
                          hyps = [o["translation_text"] for o in outs]
                          for s, r, h in zip(srcs, refs, hyps):
+                             base = {"src":s, "ref_tr":r, "hyp_tr":h}
                              if r.strip():
+                                 sc = ev.compute_metrics([s], [r], [h], metrics)
+                                 for m in metrics:
+                                     if m.startswith("CI_"):
+                                         low, high = sc[m]
+                                         base[m] = f"{low:.3f}–{high:.3f}"
+                                     else:
+                                         base[m] = sc[m]
                              else:
+                                 for m in metrics:
+                                     base[m] = None
+                             all_rows.append(base)
+                         prog.progress(min(i+batch_size, N)/N)
+                 res_df = pd.DataFrame(all_rows)
+
+                 st.markdown("### Results")
                  st.dataframe(res_df, use_container_width=True)
+
+                 # histograms
                  for m in metrics:
+                     st.markdown(f"#### {m} Distribution")
+                     col = pd.to_numeric(res_df[m], errors="coerce").dropna()
                      if col.empty:
+                         st.write("No valid data for this metric.")
                      else:
+                         fig = px.histogram(x=col)
                          st.plotly_chart(fig, use_container_width=True)
+
                  st.download_button("Download CSV", res_df.to_csv(index=False), "results.csv")

  if __name__ == "__main__":