kleervoyans committed

Commit dc2f97b · verified · 1 Parent(s): 24c7801

Update app.py

Files changed (1):
  1. app.py +169 -196
app.py CHANGED
@@ -1,10 +1,13 @@
 # app.py
 
 import streamlit as st
+import streamlit.components.v1 as components
 import logging
 import torch
 import pandas as pd
 import plotly.express as px
+import time
+import difflib
 from typing import Union, List
 
 from langdetect import detect, LangDetectException
@@ -16,11 +19,29 @@ from transformers import (
 )
 import evaluate
 
+# ────────── Global CSS ──────────
+# Kept in a constant and injected inside main(): Streamlit requires
+# st.set_page_config() to be the first st.* call in the script.
+GLOBAL_CSS = """
+<style>
+/* Page */
+.main .block-container { max-width: 900px; padding: 1rem 2rem; }
+/* Buttons */
+.stButton>button { background-color: #4A90E2; color: white; border-radius: 4px; }
+.stButton>button:hover { background-color: #357ABD; }
+/* Text areas */
+textarea { border-radius: 4px; }
+/* Tables */
+.stTable table { border-radius: 4px; overflow: hidden; }
+</style>
+"""
+
 # ────────── Logging ──────────
 logging.basicConfig(
     format="%(asctime)s %(levelname)s %(name)s: %(message)s",
     datefmt="%Y-%m-%d %H:%M:%S",
-    level=logging.INFO,
+    level=logging.INFO
 )
 logger = logging.getLogger(__name__)
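
A note for reviewers on the CSS hunk: Streamlit raises an exception if any other `st.*` call runs before `st.set_page_config()`, so injecting the stylesheet at import time would crash the app once `main()` calls `set_page_config`. Holding the CSS in a constant and injecting it inside `main()` avoids that. A minimal standalone sketch of the required ordering (file name and strings are illustrative, not from this commit):

```python
# demo_order.py - run with: streamlit run demo_order.py
import streamlit as st

st.set_page_config(page_title="demo", layout="wide")  # must be the first st.* call
st.markdown("<style>textarea { border-radius: 4px; }</style>", unsafe_allow_html=True)
st.write("CSS injected after set_page_config")
```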
 
@@ -28,9 +49,8 @@ logger = logging.getLogger(__name__)
 # ────────── Model Manager ──────────
 class ModelManager:
     """
-    Select & load the best translation model from a candidate list,
-    using 8-bit quant if CUDA is available, else full-precision.
-    Auto-picks Turkish target code.
+    Selects & loads NLLB-200 or M2M100 (8-bit if GPU available).
+    Exposes `translate()` with auto-lang detection + dynamic tgt_lang.
     """
     def __init__(
         self,
@@ -38,75 +58,62 @@ class ModelManager:
         quantize: bool = True,
         default_tgt: str = None,
     ):
-        # disable 8-bit if no GPU
         if quantize and not torch.cuda.is_available():
             logger.warning("CUDA unavailable; disabling 8-bit quantization")
             quantize = False
         self.quantize = quantize
-
         self.candidates = candidates or [
             "facebook/nllb-200-distilled-600M",
-            "facebook/m2m100_418M",
+            "facebook/m2m100_418M"
         ]
-        self.default_tgt = default_tgt  # will auto-pick if None
+        self.default_tgt = default_tgt
         self.model_name = None
         self.tokenizer = None
         self.model = None
         self.pipeline = None
         self.lang_codes = []
-
-        self._select_and_load()
+        self._load_best()
 
-    def _select_and_load(self):
+    def _load_best(self):
         last_err = None
         for name in self.candidates:
             try:
-                # 1) tokenizer
+                # 1) Tokenizer
                 logger.info(f"Loading tokenizer for {name}")
                 tok = AutoTokenizer.from_pretrained(name, use_fast=True)
                 if not hasattr(tok, "lang_code_to_id"):
-                    raise AttributeError("no lang_code_to_id on tokenizer")
-
-                # 2) model
+                    raise AttributeError("no lang_code_to_id")
+                # 2) Model (8-bit if configured)
                 logger.info(f"Loading model {name} (8-bit={self.quantize})")
                 if self.quantize:
-                    bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
+                    bnb = BitsAndBytesConfig(load_in_8bit=True)
                     mdl = AutoModelForSeq2SeqLM.from_pretrained(
-                        name,
-                        device_map="auto",
-                        quantization_config=bnb_cfg,
+                        name, device_map="auto", quantization_config=bnb
                     )
                 else:
                     mdl = AutoModelForSeq2SeqLM.from_pretrained(
-                        name,
-                        device_map="auto",
+                        name, device_map="auto"
                     )
-                logger.info(f"Loaded {name}")
-
-                # 3) pipeline
+                # 3) Pipeline
                 pipe = pipeline("translation", model=mdl, tokenizer=tok)
-
-                # store
+                # Store
                 self.model_name = name
                 self.tokenizer = tok
                 self.model = mdl
                 self.pipeline = pipe
                 self.lang_codes = list(tok.lang_code_to_id.keys())
-
-                # pick Turkish code if needed
+                # Auto-pick Turkish if needed
                 if not self.default_tgt:
                     tur = [c for c in self.lang_codes if c.lower().startswith("tr")]
                     if not tur:
-                        raise ValueError("No Turkish code available")
+                        raise ValueError("No Turkish code found")
                     self.default_tgt = tur[0]
-                    logger.info(f"default_tgt = {self.default_tgt}")
+                    logger.info(f"Default target = {self.default_tgt}")
                 return
-
             except Exception as e:
-                logger.warning(f"failed to load {name}: {e}")
+                logger.warning(f"Failed to load {name}: {e}")
                 last_err = e
-
-        raise RuntimeError(f"no model loaded: {last_err}")
+        raise RuntimeError(f"No model loaded: {last_err}")
 
     def translate(
         self,
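
One behavior worth checking in the hunk above: the Turkish scan uses `startswith("tr")`, which matches M2M100's bare ISO code `tr` but not NLLB-200's script-qualified `tur_Latn`, since that string begins with `tu`. A quick sketch, assuming those are the two code styles:

```python
# How the Turkish-code scan behaves on each candidate's (assumed) inventory.
nllb_codes = ["eng_Latn", "tur_Latn"]  # NLLB-200: script-qualified codes
m2m_codes = ["en", "tr"]               # M2M100: bare ISO 639-1 codes

for codes in (nllb_codes, m2m_codes):
    tur = [c for c in codes if c.lower().startswith("tr")]
    print(codes, "->", tur)
# ['eng_Latn', 'tur_Latn'] -> []  ("tur_Latn" starts with "tu")
# ['en', 'tr'] -> ['tr']
```

If that reading is right, NLLB-200 would be rejected with `No Turkish code found` and the loader would fall through to M2M100.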
@@ -115,30 +122,27 @@ class ModelManager:
         tgt_lang: str = None,
     ):
         tgt = tgt_lang or self.default_tgt
-
-        # auto-detect source
+        # auto-detect source if missing
         if not src_lang:
             sample = text[0] if isinstance(text, list) else text
             try:
                 iso = detect(sample).lower()
-                cand = [c for c in self.lang_codes if c.lower().startswith(iso)]
-                if not cand:
-                    raise LangDetectException(f"No code for ISO '{iso}'")
-                # exact or first
-                exact = [c for c in cand if c.lower() == iso]
-                src = exact[0] if exact else cand[0]
-                logger.info(f"src_lang = {src}")
+                cands = [c for c in self.lang_codes if c.lower().startswith(iso)]
+                if not cands:
+                    # LangDetectException takes (code, message)
+                    raise LangDetectException(None, f"no code for ISO '{iso}'")
+                exact = [c for c in cands if c.lower() == iso]
+                src = exact[0] if exact else cands[0]
+                logger.info(f"Detected src_lang={src}")
             except Exception:
+                # fall back to English
                 eng = [c for c in self.lang_codes if c.lower().startswith("en")]
                 src = eng[0] if eng else self.lang_codes[0]
-                logger.warning(f"defaulting src_lang = {src}")
+                logger.warning(f"Falling back src_lang={src}")
         else:
             src = src_lang
 
         return self.pipeline(text, src_lang=src, tgt_lang=tgt)
 
     def get_info(self):
-        # figure out device for display
         dev = "cpu"
         if torch.cuda.is_available() and hasattr(self.model, "device"):
             d = self.model.device
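
The detection path maps langdetect's ISO 639-1 result onto the tokenizer's code list by prefix. A small sketch of that mapping (the code inventory is an assumption for illustration; langdetect can be non-deterministic on very short inputs):

```python
# Prefix-mapping a detected ISO code onto model language codes.
from langdetect import detect

lang_codes = ["eng_Latn", "deu_Latn", "fra_Latn"]  # illustrative inventory
iso = detect("The quick brown fox jumps over the lazy dog.").lower()  # typically "en"
cands = [c for c in lang_codes if c.lower().startswith(iso)]
print(iso, "->", cands)  # en -> ['eng_Latn']
```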
@@ -161,39 +165,30 @@ class TranslationEvaluator:
 
     def evaluate(
         self,
-        sources: List[str],
-        references: List[str],
-        predictions: List[str],
+        srcs: List[str],
+        refs: List[str],
+        hyps: List[str],
     ):
-        results = {}
-
+        out = {}
         # BLEU
-        bleu_r = self.bleu.compute(predictions=predictions, references=[[r] for r in references])
-        results["BLEU"] = float(bleu_r.get("bleu", 0.0))
-
-        # BERTScore (xx)
-        bs = self.bertscore.compute(predictions=predictions, references=references, lang="xx")
+        b = self.bleu.compute(predictions=hyps, references=[[r] for r in refs])
+        out["BLEU"] = float(b.get("bleu", 0.0))
+        # BERTScore (multilingual)
+        bs = self.bertscore.compute(predictions=hyps, references=refs, lang="xx")
         f1 = bs.get("f1", [])
-        results["BERTScore"] = float(sum(f1) / len(f1)) if f1 else 0.0
-
-        # BERTurk (tr)
-        bs_tr = self.bertscore.compute(predictions=predictions, references=references, lang="tr")
-        f1t = bs_tr.get("f1", [])
-        results["BERTurk"] = float(sum(f1t) / len(f1t)) if f1t else 0.0
-
+        out["BERTScore"] = float(sum(f1) / len(f1)) if f1 else 0.0
+        # BERTurk (Turkish BERTScore)
+        bt = self.bertscore.compute(predictions=hyps, references=refs, lang="tr")
+        f2 = bt.get("f1", [])
+        out["BERTurk"] = float(sum(f2) / len(f2)) if f2 else 0.0
         # COMET
-        cm = self.comet.compute(srcs=sources, hyps=predictions, refs=references)
-        sc = cm.get("scores", None)
-        if isinstance(sc, list):
-            results["COMET"] = float(sc[0]) if sc else 0.0
-        else:
-            results["COMET"] = float(sc or 0.0)
-
-        return results
+        # evaluate's COMET metric expects sources/predictions/references
+        cm = self.comet.compute(sources=srcs, predictions=hyps, references=refs)
+        sc = cm.get("scores")
+        out["COMET"] = float(sc[0]) if isinstance(sc, list) and sc else float(sc or 0.0)
+        return out
 
 
 # ────────── Streamlit App ──────────
-
 @st.cache_resource
 def load_resources():
     mgr = ModelManager(quantize=True)
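
The renamed `evaluate()` keeps the same four metrics. The BLEU path can be sanity-checked in isolation; BERTScore and COMET download model weights on first use, so only BLEU is shown here (sample strings invented):

```python
# Toy run of the BLEU computation used above.
import evaluate

bleu = evaluate.load("bleu")
refs = ["kedi masanın üstünde uyuyor bugün"]
hyps = ["kedi masanın üstünde uyuyor bugün"]
out = bleu.compute(predictions=hyps, references=[[r] for r in refs])
print(float(out.get("bleu", 0.0)))  # 1.0: exact match of a 4+ token sentence
```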
@@ -203,142 +198,120 @@ def load_resources():
 
 def display_model_info(info: dict):
     st.sidebar.markdown("### Model Info")
-    st.sidebar.write(f"**Model:** {info['model']}")
-    st.sidebar.write(f"**8-bit Quantized:** {info['quantized']}")
-    st.sidebar.write(f"**Device:** {info['device']}")
-    st.sidebar.write(f"**Default target:** {info['default_tgt']}")
+    st.sidebar.write(f"• Model: **{info['model']}**")
+    st.sidebar.write(f"• Quantized: **{info['quantized']}**")
+    st.sidebar.write(f"• Device: **{info['device']}**")
 
 
-def process_text(
-    src: str,
-    ref: str,
-    mgr: ModelManager,
-    ev: TranslationEvaluator,
-    metrics: List[str],
-):
-    # 1) translate
-    out = mgr.translate(src)  # list of dicts
+def process_and_stream(src, ref, tgt, mgr, ev, metrics):
+    # 1) call pipeline
+    out = mgr.translate(src, tgt_lang=tgt)
     hyp = out[0]["translation_text"]
 
-    # 2) if we have a non-blank reference → compute metrics; else all Nones
-    result = {
-        "source": src,
-        "reference": ref or None,
-        "hypothesis": hyp,
-    }
+    # 2) pseudo-stream: reveal word by word
+    placeholder = st.empty()
+    text_acc = ""
+    for w in hyp.split():
+        text_acc += w + " "
+        placeholder.markdown(f"**Hypothesis ({tgt}):** {text_acc}")
+        time.sleep(0.05)
+
+    # 3) metrics (only if ref given)
+    scores = {}
     if ref and ref.strip():
         scores = ev.evaluate([src], [ref], [hyp])
-        for m in metrics:
-            result[m] = scores.get(m, 0.0)
-    else:
-        for m in metrics:
-            result[m] = None
-
-    return result
-
-
-def show_single_results(res: dict, metrics: List[str]):
-    left, right = st.columns(2)
-    with left:
-        st.markdown("**Source:**"); st.write(res["source"])
-        st.markdown("**Hypothesis (TR):**"); st.write(res["hypothesis"])
-        if res["reference"]:
-            st.markdown("**Reference (TR):**"); st.write(res["reference"])
-    with right:
-        st.markdown("### Scores")
-        df = pd.DataFrame([{m: res[m] for m in metrics}])
-        df = df.replace({None: "N/A"})
-        st.table(df)
+    return hyp, scores
 
 
-def process_file(
-    uploaded,
-    mgr: ModelManager,
-    ev: TranslationEvaluator,
-    metrics: List[str],
-    batch_size: int,
-):
-    df = pd.read_csv(uploaded)
-    if not {"src", "ref_tr"}.issubset(df.columns):
-        raise ValueError("CSV must have `src` and `ref_tr` columns")
-    prog = st.progress(0)
-    results = []
-    total = len(df)
-
-    for i in range(0, total, batch_size):
-        batch = df.iloc[i : i + batch_size]
-        srcs, refs = batch["src"].tolist(), batch["ref_tr"].tolist()
-        outs = mgr.translate(srcs)
-        hyps = [o["translation_text"] for o in outs]
-
-        for s, r, h in zip(srcs, refs, hyps):
-            entry = {"src": s, "ref_tr": r, "hyp_tr": h}
-            if r and str(r).strip():
-                sc = ev.evaluate([s], [r], [h])
-                for m in metrics:
-                    entry[m] = sc.get(m, 0.0)
-            else:
-                for m in metrics:
-                    entry[m] = None
-            results.append(entry)
-
-        prog.progress(min(i + batch_size, total) / total)
-
-    return pd.DataFrame(results)
-
-
-def show_batch_viz(df: pd.DataFrame, metrics: List[str]):
-    for m in metrics:
-        st.markdown(f"#### {m} Distribution")
-        if df[m].dropna().empty:
-            st.write("No reference provided, so this metric is N/A.")
-            continue
-        fig = px.histogram(df, x=m)
-        st.plotly_chart(fig, use_container_width=True)
+def show_diff(ref, hyp):
+    # side-by-side HTML diff
+    differ = difflib.HtmlDiff(tabsize=4, wrapcolumn=60)
+    html = differ.make_table(
+        ref.split(), hyp.split(),
+        fromdesc="Reference", todesc="Hypothesis",
+        context=True, numlines=1
+    )
+    components.html(html, height=200, scrolling=True)
 
 
 def main():
-    st.set_page_config(page_title="🔀 Translation→Turkish Quality", layout="wide")
-    st.title("🔀 Translation → TR Quality & COMET")
-    st.markdown("Translate any language into Turkish and evaluate (optional) with BLEU, BERTScore, BERTurk & COMET.")
+    st.set_page_config(page_title="🔀 Multi-Lang → TR + Eval", layout="wide")
+    st.markdown(GLOBAL_CSS, unsafe_allow_html=True)  # inject CSS after set_page_config
+    st.title("🌐 Translate → 🔠 Turkish & Evaluate")
+    st.write("Choose target, translate from any language, and (optionally) eval against a reference.")
 
-    # Sidebar
+    # Sidebar: load models & then dynamic tgt dropdown
    with st.sidebar:
         st.header("Settings")
-        metrics = st.multiselect(
-            "Select metrics",
-            ["BLEU", "BERTScore", "BERTurk", "COMET"],
-            default=["BLEU", "BERTScore", "COMET"]
+        mgr, ev = load_resources()
+        info = mgr.get_info()
+        display_model_info(info)
+
+        tgt = st.selectbox(
+            "Target language code",
+            options=mgr.lang_codes,
+            index=mgr.lang_codes.index(info["default_tgt"])
+        )
+        metrics = st.multiselect(
+            "Metrics",
+            ["BLEU", "BERTScore", "BERTurk", "COMET"],
+            default=["BLEU", "BERTScore", "COMET"]
         )
         batch_size = st.slider("Batch size", 1, 32, 8)
-    mgr, ev = load_resources()
-    display_model_info(mgr.get_info())
 
-    # Tabs
-    tab1, tab2 = st.tabs(["Single Sentence", "Batch CSV"])
+    tab1, tab2 = st.tabs(["Single sentence", "Batch CSV"])
 
     with tab1:
-        src = st.text_area("Source sentence (any language):", height=150)
-        ref = st.text_area("Turkish reference (optional):", height=100)
-        if st.button("Evaluate"):
-            with st.spinner("Translating & evaluating…"):
-                res = process_text(src, ref, mgr, ev, metrics)
-            show_single_results(res, metrics)
+        src = st.text_area("Source sentence:", height=120)
+        ref = st.text_area("Turkish reference (optional):", height=80)
+        if st.button("Translate & Eval"):
+            with st.spinner("Working…"):
+                hyp, scores = process_and_stream(src, ref, tgt, mgr, ev, metrics)
+            # show scores
+            df = {m: (scores.get(m) if ref.strip() else None) for m in metrics}
+            st.markdown("### Scores")
+            st.table(pd.DataFrame([df]).replace({None: "N/A"}))
+            # diff
+            if ref.strip():
+                st.markdown("### Diff view")
+                show_diff(ref, hyp)
 
     with tab2:
-        uploaded = st.file_uploader("Upload CSV with `src` & `ref_tr` columns", type=["csv"])
+        uploaded = st.file_uploader("Upload CSV with `src`,`ref_tr`", type=["csv"])
         if uploaded:
-            with st.spinner("Processing file…"):
-                df_res = process_file(uploaded, mgr, ev, metrics, batch_size)
-            st.markdown("### Batch Results")
-            st.dataframe(df_res, use_container_width=True)
-            show_batch_viz(df_res, metrics)
-            st.download_button("Download results as CSV", df_res.to_csv(index=False), "results.csv")
-
-
-if __name__ == "__main__":
-    try:
-        main()
-    except Exception as e:
-        st.error(f"Unexpected error: {e}")
-        logger.exception("Unhandled exception")
+            df = pd.read_csv(uploaded)
+            if not {"src", "ref_tr"}.issubset(df.columns):
+                st.error("CSV needs `src` and `ref_tr` columns.")
+            else:
+                with st.spinner("Batch translating…"):
+                    out_rows = []
+                    prog = st.progress(0)
+                    for i in range(0, len(df), batch_size):
+                        batch = df.iloc[i : i + batch_size]
+                        srcs, refs = batch["src"].tolist(), batch["ref_tr"].tolist()
+                        outs = mgr.translate(srcs, tgt_lang=tgt)
+                        hyps = [o["translation_text"] for o in outs]
+                        for s, r, h in zip(srcs, refs, hyps):
+                            row = {"src": s, "ref_tr": r, "hyp_tr": h}
+                            if isinstance(r, str) and r.strip():  # guard NaN refs
+                                sc = ev.evaluate([s], [r], [h])
+                                for m in metrics:
+                                    row[m] = sc[m]
+                            else:
+                                for m in metrics:
+                                    row[m] = None
+                            out_rows.append(row)
+                        prog.progress(min(i + batch_size, len(df)) / len(df))
+                res_df = pd.DataFrame(out_rows)
+                st.markdown("### Batch Results")
+                st.dataframe(res_df, use_container_width=True)
+                # viz
+                for m in metrics:
+                    st.markdown(f"#### {m} Histogram")
+                    if res_df[m].dropna().empty:
+                        st.write("No valid refs → metric N/A.")
+                    else:
+                        fig = px.histogram(res_df, x=m)
+                        st.plotly_chart(fig, use_container_width=True)
+                st.download_button("Download CSV", res_df.to_csv(index=False), "results.csv")
+
+
+if __name__ == "__main__":
+    main()
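
In the new single-sentence tab, the "streaming" effect is cosmetic: the full hypothesis already exists, and `process_and_stream()` reveals it word by word through a single `st.empty()` placeholder. A standalone sketch of the same pattern (the sentence is made up):

```python
# demo_stream.py - run with: streamlit run demo_stream.py
import time
import streamlit as st

hyp = "merhaba dünya bugün hava çok güzel"  # pretend this came from the model
placeholder = st.empty()
acc = ""
for w in hyp.split():
    acc += w + " "
    placeholder.markdown(f"**Hypothesis:** {acc}")  # rewrite the same slot each pass
    time.sleep(0.05)
```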
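
Finally, `show_diff()` leans on the stdlib: `difflib.HtmlDiff.make_table()` returns an HTML `<table>` string that `components.html()` embeds verbatim. The same call works outside Streamlit (example strings invented):

```python
# Standalone sketch of show_diff()'s HTML generation.
import difflib

ref = "kedi masanın üstünde uyuyor".split()
hyp = "kedi masada uyuyor".split()
html = difflib.HtmlDiff(tabsize=4, wrapcolumn=60).make_table(
    ref, hyp, fromdesc="Reference", todesc="Hypothesis", context=True, numlines=1
)
print(html[:40])  # an HTML <table class="diff" ...> ready for components.html(html)
```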