OneStarDao committed on
Commit 4082288 · verified · 1 Parent(s): 65751f6

Update app.py

Files changed (1)
  1. app.py +79 -31
app.py CHANGED
@@ -1,47 +1,81 @@
-import io, traceback, numpy as np, gradio as gr, matplotlib
 matplotlib.use("Agg")
 from PIL import Image
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from wfgy_sdk import get_engine
 from wfgy_sdk.evaluator import compare_logits, plot_histogram
-from tabulate import tabulate

-MODEL = "sshleifer/tiny-gpt2"
-tok = AutoTokenizer.from_pretrained(MODEL)
-mdl = AutoModelForCausalLM.from_pretrained(MODEL)
-eng = get_engine()

 def run(prompt: str):
     prompt = prompt.strip()
     if not prompt:
-        return "", "", "", None, None
     try:
         ids = tok(prompt, return_tensors="pt").input_ids
-        rawL = mdl(ids).logits[0, -1].detach().cpu().numpy()
         G = np.random.randn(256).astype(np.float32)
         I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
-        modL = eng.run(I, G, rawL)
-
-        m = compare_logits(rawL, modL)
-        tbl = tabulate(
-            [[f"{m['std_ratio']:.3f}",
-              f"{m['var_drop']*100:4.1f} %",
-              f"{m['kl']:.3f}",
- "" if m['top1'] else ""]],
31
- headers=["std_ratio", "▼ var", "KL", "top-1"],
32
- tablefmt="github")
33
- headline = f"▼ var {m['var_drop']*100:4.1f} % | KL {m['kl']:.3f}"
34
-
35
- fig = plot_histogram(rawL, modL)
 
 
 
 
 
 
 
36
  buf = io.BytesIO(); fig.savefig(buf, format="png"); buf.seek(0)
37
  img = Image.open(buf)
38
 
39
- raw_txt = prompt + tok.decode(int(rawL.argmax()))
40
- mod_txt = prompt + tok.decode(int(modL.argmax()))
41
- return raw_txt, mod_txt, headline, tbl, img
42
  except Exception:
43
  tb = traceback.format_exc()
44
- return "runtime error", tb, "runtime error", "", None
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  with gr.Blocks(title="WFGY variance gate") as demo:
47
  gr.Markdown("# 🧠 WFGY simulation demo")
@@ -53,15 +87,29 @@ with gr.Blocks(title="WFGY variance gate") as demo:
53
  mod_box = gr.Textbox(label="After WFGY")
54
 
55
  headline = gr.Markdown()
56
-    metrics = gr.Markdown()  # ← added metrics table
     img = gr.Image(label="Logit histogram", type="pil")

-    btn.click(run, prompt,
-              [raw_box, mod_box, headline, metrics, img])

-    gr.Markdown("---\n"
-                "### ⭐ 10 000 stars → unlock **WFGY 2.0** by 2025-08-01")

 if __name__ == "__main__":
     demo.queue().launch()
-
+import io, json, traceback, numpy as np, matplotlib
 matplotlib.use("Agg")
+
 from PIL import Image
+import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from wfgy_sdk import get_engine
 from wfgy_sdk.evaluator import compare_logits, plot_histogram

+# model + engine
+MODEL_ID = "sshleifer/tiny-gpt2"
+tok = AutoTokenizer.from_pretrained(MODEL_ID)
+mdl = AutoModelForCausalLM.from_pretrained(MODEL_ID)
+eng = get_engine()
+
+# runtime history
+hist = {"step": [], "var": [], "kl": []}
+
+# paper benchmark numbers
+paper = {
+    "benchmark": [
+        "MMLU", "GSM8K", "BBH", "MathBench",
+        "TruthfulQA", "XNLI", "MLQA", "LongBench",
+        "VQAv2", "OK-VQA"
+    ],
+    "baseline": [61.0, 78.0, 79.3, 72.2, 62.4, 59.5, 78.1, 51.4, 69.1, 65.7],
+    "wfgy": [89.8, 98.7, 100.7, 87.4, 90.4, 77.3, 106.6, 69.6, 86.6, 86.8]
+}

 def run(prompt: str):
     prompt = prompt.strip()
     if not prompt:
+        return "", "", "", "", None, update_history()
     try:
         ids = tok(prompt, return_tensors="pt").input_ids
+        raw = mdl(ids).logits[0, -1].detach().cpu().numpy()
         G = np.random.randn(256).astype(np.float32)
         I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
+        mod = eng.run(I, G, raw)
+        m = compare_logits(raw, mod)
+
+        # update history
+        step = len(hist["step"]) + 1
+        hist["step"].append(step)
+        hist["var"].append(m["var_drop"] * 100)
+        hist["kl"].append(m["kl"])
+
+        headline = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f}"
+        metrics_md = (
+            "| metric | value |\n"
+            "|--------|-------|\n"
+            f"| std_ratio | {m['std_ratio']:.3f} |\n"
+            f"| var_drop | {m['var_drop']*100:.1f}% |\n"
+            f"| KL | {m['kl']:.3f} |\n"
+            f"| top-1 same | {'yes' if m['top1'] else 'no'} |"
+        )
+
+        fig = plot_histogram(raw, mod)
         buf = io.BytesIO(); fig.savefig(buf, format="png"); buf.seek(0)
         img = Image.open(buf)

+        raw_txt = prompt + tok.decode(int(raw.argmax()))
+        mod_txt = prompt + tok.decode(int(mod.argmax()))
+        return raw_txt, mod_txt, headline, metrics_md, img, update_history()
     except Exception:
         tb = traceback.format_exc()
+        return "runtime error", tb, "runtime error", "", None, update_history()
+
+def update_history():
+    return {
+        "step": hist["step"],
+        "var": hist["var"],
+        "kl": hist["kl"]
+    }
+
+def clear_history():
+    hist["step"].clear(); hist["var"].clear(); hist["kl"].clear()
+    return update_history()

 with gr.Blocks(title="WFGY variance gate") as demo:
     gr.Markdown("# 🧠 WFGY simulation demo")
@@ -53,15 +87,29 @@ with gr.Blocks(title="WFGY variance gate") as demo:
     mod_box = gr.Textbox(label="After WFGY")

     headline = gr.Markdown()
+    metrics = gr.Markdown()
     img = gr.Image(label="Logit histogram", type="pil")

+    hist_plot = gr.LinePlot(
+        label="History (var% & KL)",
+        x="step", y=["var", "kl"],
+        overlay=True, height=250
+    )
+    clear_btn = gr.Button("Clear history")

+    with gr.Accordion("Paper benchmarks", open=False):
+        bench_df = gr.DataFrame(paper, interactive=False)
+        bench_bar = gr.BarPlot(paper, x="benchmark", y=["baseline", "wfgy"],
+                               overlay=False, height=300)
+
+    gr.Markdown(
+        "---\n"
+        "### ⭐ 10 000 GitHub stars before **2025-08-01** unlock **WFGY 2.0**"
+    )
+
+    btn.click(run, prompt,
+              [raw_box, mod_box, headline, metrics, img, hist_plot])
+    clear_btn.click(fn=clear_history, inputs=None, outputs=hist_plot)

 if __name__ == "__main__":
     demo.queue().launch()
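
For reference, a minimal smoke-test sketch of the pipeline this commit wires into the UI. It reuses only calls that already appear in the diff (AutoTokenizer / AutoModelForCausalLM, get_engine().run, compare_logits); the prompt string and the printed summary line are illustrative additions, not part of the commit.

# Smoke test for the WFGY gating step, bypassing the Gradio UI.
# Assumes wfgy_sdk and transformers are installed.
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from wfgy_sdk import get_engine
from wfgy_sdk.evaluator import compare_logits

# Same tiny model and engine the app loads at startup.
tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
mdl = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
eng = get_engine()

# Raw next-token logits for a sample prompt (prompt text is arbitrary).
ids = tok("hello world", return_tensors="pt").input_ids
raw = mdl(ids).logits[0, -1].detach().cpu().numpy()

# Synthetic ground/input vectors, mirroring the app's run() function.
G = np.random.randn(256).astype(np.float32)
I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)

# Gate the logits and compare before vs. after.
mod = eng.run(I, G, raw)
m = compare_logits(raw, mod)
print(f"var drop {m['var_drop']*100:.1f}% | KL {m['kl']:.3f} | top-1 same: {m['top1']}")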