OneStarDao committed on
Commit
24bd8e1
·
verified ·
1 Parent(s): 4082288

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -102
app.py CHANGED
@@ -1,115 +1,68 @@
1
- import io, json, traceback, numpy as np, matplotlib
2
  matplotlib.use("Agg")
3
 
4
  from PIL import Image
5
- import gradio as gr
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from wfgy_sdk import get_engine
8
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
9
 
10
- # model + engine
11
- MODEL_ID = "sshleifer/tiny-gpt2"
12
- tok = AutoTokenizer.from_pretrained(MODEL_ID)
13
- mdl = AutoModelForCausalLM.from_pretrained(MODEL_ID)
14
- eng = get_engine()
15
-
16
- # runtime history
17
- hist = {"step": [], "var": [], "kl": []}
18
-
19
- # paper benchmark numbers
20
- paper = {
21
- "benchmark": [
22
- "MMLU", "GSM8K", "BBH", "MathBench",
23
- "TruthfulQA", "XNLI", "MLQA", "LongBench",
24
- "VQAv2", "OK-VQA"
25
- ],
26
- "baseline": [61.0, 78.0, 79.3, 72.2, 62.4, 59.5, 78.1, 51.4, 69.1, 65.7],
27
- "wfgy": [89.8, 98.7, 100.7, 87.4, 90.4, 77.3, 106.6, 69.6, 86.6, 86.8]
28
- }
29
-
30
- def run(prompt: str):
31
- prompt = prompt.strip()
32
- if not prompt:
33
- return "", "", "", "", None, update_history()
34
- try:
35
- ids = tok(prompt, return_tensors="pt").input_ids
36
- raw = mdl(ids).logits[0, -1].detach().cpu().numpy()
37
- G = np.random.randn(256).astype(np.float32)
38
- I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
39
- mod = eng.run(I, G, raw)
40
- m = compare_logits(raw, mod)
41
-
42
- # update history
43
- step = len(hist["step"]) + 1
44
- hist["step"].append(step)
45
- hist["var"].append(m["var_drop"] * 100)
46
- hist["kl"].append(m["kl"])
47
-
48
- headline = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f}"
49
- metrics_md = (
50
- "| metric | value |\n"
51
- "|--------|-------|\n"
52
- f"| std_ratio | {m['std_ratio']:.3f} |\n"
53
- f"| var_drop | {m['var_drop']*100:.1f}% |\n"
54
- f"| KL | {m['kl']:.3f} |\n"
55
- f"| top-1 same| {'yes' if m['top1'] else 'no'} |"
56
- )
57
-
58
- fig = plot_histogram(raw, mod)
59
- buf = io.BytesIO(); fig.savefig(buf, format="png"); buf.seek(0)
60
- img = Image.open(buf)
61
-
62
- raw_txt = prompt + tok.decode(int(raw.argmax()))
63
- mod_txt = prompt + tok.decode(int(mod.argmax()))
64
- return raw_txt, mod_txt, headline, metrics_md, img, update_history()
65
- except Exception:
66
- tb = traceback.format_exc()
67
- return "runtime error", tb, "runtime error", "", None, update_history()
68
-
69
- def update_history():
70
- return {
71
- "step": hist["step"],
72
- "var": hist["var"],
73
- "kl": hist["kl"]
74
- }
75
-
76
- def clear_history():
77
- hist["step"].clear(); hist["var"].clear(); hist["kl"].clear()
78
- return update_history()
79
 
80
  with gr.Blocks(title="WFGY variance gate") as demo:
81
  gr.Markdown("# 🧠 WFGY simulation demo")
82
- prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
83
- btn = gr.Button("🚀 Run")
84
-
85
- with gr.Row():
86
- raw_box = gr.Textbox(label="Raw GPT-2")
87
- mod_box = gr.Textbox(label="After WFGY")
88
-
89
- headline = gr.Markdown()
90
- metrics = gr.Markdown()
91
- img = gr.Image(label="Logit histogram", type="pil")
92
-
93
- hist_plot = gr.LinePlot(
94
- label="History (var% ↓ & KL)",
95
- x="step", y=["var", "kl"],
96
- overlay=True, height=250
97
- )
98
- clear_btn = gr.Button("Clear history")
99
-
100
- with gr.Accordion("Paper benchmarks", open=False):
101
- bench_df = gr.DataFrame(paper, interactive=False)
102
- bench_bar = gr.BarPlot(paper, x="benchmark", y=["baseline", "wfgy"],
103
- overlay=False, height=300)
104
-
105
- gr.Markdown(
106
- "---\n"
107
- "### ⭐ 10 000 GitHub stars before **2025-08-01** unlock **WFGY 2.0**"
108
- )
109
-
110
- btn.click(run, prompt,
111
- [raw_box, mod_box, headline, metrics, img, hist_plot])
112
- clear_btn.click(fn=clear_history, inputs=None, outputs=hist_plot)
113
 
114
  if __name__ == "__main__":
115
  demo.queue().launch()
 
1
+ import io, numpy as np, matplotlib
2
  matplotlib.use("Agg")
3
 
4
  from PIL import Image
5
+ import pandas as pd, plotly.express as px, gradio as gr
6
+
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
  from wfgy_sdk import get_engine
9
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
10
 
11
# Tiny GPT-2 checkpoint keeps the demo fast on CPU-only Spaces.
MODEL_ID = "sshleifer/tiny-gpt2"
tok = AutoTokenizer.from_pretrained(MODEL_ID)
mdl = AutoModelForCausalLM.from_pretrained(MODEL_ID)
eng = get_engine()

# Per-session metric history; run() appends one entry per call.
history = {"step": [], "var": [], "kl": []}

# Benchmark numbers quoted from the WFGY paper (baseline vs. WFGY scores).
paper = pd.DataFrame({
    "Benchmark": ["MMLU", "GSM8K", "BBH", "MathBench", "TruthfulQA",
                  "XNLI", "MLQA", "LongBench", "VQAv2", "OK-VQA"],
    "Baseline": [61, 78, 79.3, 72.2, 62.4, 59.5, 78.1, 51.4, 69.1, 65.7],
    "WFGY": [89.8, 98.7, 100.7, 87.4, 90.4, 77.3, 106.6, 69.6, 86.6, 86.8],
})
23
+
24
def run(prompt):
    """Run one WFGY demo step on *prompt*.

    Feeds the prompt through tiny GPT-2, gates the last-token logits with
    the WFGY engine, records variance-drop / KL metrics in the module-level
    ``history`` dict, and renders a logit histogram.

    Returns a 5-tuple matching the Gradio outputs:
    (raw continuation, gated continuation, headline markdown,
     histogram PIL image, history plotly figure).
    Empty/whitespace prompts yield blank outputs and no history entry.
    """
    p = prompt.strip()
    if not p:
        return "", "", "", None, plot_history()

    ids = tok(p, return_tensors="pt").input_ids
    # Last-position logits as a plain numpy vector.
    raw = mdl(ids).logits[0, -1].detach().cpu().numpy()

    # Synthetic semantic vectors for the engine: G is the "ground" vector,
    # I a noisy copy (size 256 is what the engine demo expects — TODO confirm).
    G = np.random.randn(256).astype(np.float32)
    I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
    mod = eng.run(I, G, raw)
    m = compare_logits(raw, mod)

    # Append this call's metrics to the running history.
    step = len(history["step"]) + 1
    history["step"].append(step)
    history["var"].append(m["var_drop"] * 100)
    history["kl"].append(m["kl"])

    fig = plot_histogram(raw, mod)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now, so the image no longer depends on buf
    # BUG FIX: close the figure — with the Agg backend, un-closed figures
    # accumulate on every click and leak memory for the process lifetime.
    import matplotlib.pyplot as plt
    plt.close(fig)

    head = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f}"
    # One-token greedy continuation for each variant of the logits.
    raw_t = p + tok.decode(int(raw.argmax()))
    mod_t = p + tok.decode(int(mod.argmax()))
    return raw_t, mod_t, head, img, plot_history()
45
+
46
def plot_history():
    """Build a Plotly line chart of the per-call history (var% and KL).

    Returns an empty placeholder chart before the first run() call.
    """
    if not history["step"]:
        placeholder = px.line(title="history")
        return placeholder.update_layout(height=250)

    frame = pd.DataFrame(history)
    chart = px.line(
        frame,
        x="step",
        y=["var", "kl"],
        labels={"value": "metric", "step": "call"},
    )
    return chart.update_layout(height=250)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
# ---- Gradio UI ------------------------------------------------------------
with gr.Blocks(title="WFGY variance gate") as demo:
    gr.Markdown("# 🧠 WFGY simulation demo")

    # Input controls.
    inp = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
    btn = gr.Button("🚀 Run")

    # Outputs: raw vs. gated continuation, headline metrics, histogram,
    # and the running history chart.
    raw = gr.Textbox(label="Raw GPT-2")
    mod = gr.Textbox(label="After WFGY")
    head = gr.Markdown()
    img = gr.Image(type="pil")
    line = gr.Plot()

    # One click drives all five outputs at once.
    btn.click(run, inp, [raw, mod, head, img, line])

    with gr.Accordion("Paper benchmark", open=False):
        gr.DataFrame(paper, interactive=False)

    gr.Markdown("---\n⭐ 10 000 stars → unlock **WFGY 2.0** by 2025-08-01")


if __name__ == "__main__":
    demo.queue().launch()