OneStarDao commited on
Commit
54b95bc
·
verified ·
1 Parent(s): b42e326

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -34
app.py CHANGED
@@ -3,66 +3,87 @@ matplotlib.use("Agg")
3
 
4
  from PIL import Image
5
  import pandas as pd, plotly.express as px, gradio as gr
6
-
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
  from wfgy_sdk import get_engine
9
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
10
 
 
11
  tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
12
  mdl = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
13
  eng = get_engine()
14
 
15
- history = {"step": [], "var": [], "kl": []}
 
16
 
17
- paper = pd.DataFrame({
 
18
  "Benchmark": ["MMLU","GSM8K","BBH","MathBench","TruthfulQA",
19
  "XNLI","MLQA","LongBench","VQAv2","OK-VQA"],
20
- "Baseline": [61,78,79.3,72.2,62.4,59.5,78.1,51.4,69.1,65.7],
21
  "WFGY": [89.8,98.7,100.7,87.4,90.4,77.3,106.6,69.6,86.6,86.8]
22
  })
 
 
23
 
24
- def run(prompt):
25
- p = prompt.strip()
26
- if not p:
27
  return "", "", "", None, plot_history()
28
- ids = tok(p, return_tensors="pt").input_ids
29
- raw = mdl(ids).logits[0,-1].detach().cpu().numpy()
30
- G = np.random.randn(256).astype(np.float32)
31
- I = G + np.random.normal(scale=0.05,size=256).astype(np.float32)
32
- mod = eng.run(I,G,raw)
33
- m = compare_logits(raw,mod)
34
- step = len(history["step"])+1
 
35
  history["step"].append(step)
36
  history["var"].append(m["var_drop"]*100)
37
  history["kl"].append(m["kl"])
38
- fig = plot_histogram(raw,mod)
39
- buf = io.BytesIO(); fig.savefig(buf,format="png"); buf.seek(0)
 
40
  img = Image.open(buf)
41
- head = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f}"
42
- raw_t = p + tok.decode(int(raw.argmax()))
43
- mod_t = p + tok.decode(int(mod.argmax()))
44
- return raw_t, mod_t, head, img, plot_history()
 
 
 
 
45
 
46
  def plot_history():
47
- if not history["step"]:
48
- return px.line(title="history").update_layout(height=250)
49
  df = pd.DataFrame(history)
50
  return px.line(df, x="step", y=["var","kl"],
51
- labels={"value":"metric","step":"call"}).update_layout(height=250)
 
 
 
 
 
 
 
52
 
53
  with gr.Blocks(title="WFGY variance gate") as demo:
54
  gr.Markdown("# 🧠 WFGY simulation demo")
55
- inp = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
56
- btn = gr.Button("🚀 Run")
57
- raw = gr.Textbox(label="Raw GPT-2")
58
- mod = gr.Textbox(label="After WFGY")
59
- head= gr.Markdown()
60
- img = gr.Image(type="pil")
61
- line= gr.Plot()
62
- btn.click(run, inp, [raw, mod, head, img, line])
63
- with gr.Accordion("Paper benchmark", open=False):
64
- gr.DataFrame(paper, interactive=False)
65
- gr.Markdown("---\n⭐ 10 000 stars → unlock **WFGY 2.0** by 2025-08-01")
 
 
 
 
 
 
66
 
67
  if __name__ == "__main__":
68
  demo.queue().launch()
 
3
 
4
  from PIL import Image
5
  import pandas as pd, plotly.express as px, gradio as gr
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from wfgy_sdk import get_engine
8
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
9
 
10
+ # tiny model for demo
11
  tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
12
  mdl = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
13
  eng = get_engine()
14
 
15
+ # runtime history (start with a dummy zero so the plot is never empty)
16
+ history = {"step": [0], "var": [0.0], "kl": [0.0]}
17
 
18
+ # paper benchmark absolute numbers
19
+ paper_df = pd.DataFrame({
20
  "Benchmark": ["MMLU","GSM8K","BBH","MathBench","TruthfulQA",
21
  "XNLI","MLQA","LongBench","VQAv2","OK-VQA"],
22
+ "Baseline": [61.0,78.0,79.3,72.2,62.4,59.5,78.1,51.4,69.1,65.7],
23
  "WFGY": [89.8,98.7,100.7,87.4,90.4,77.3,106.6,69.6,86.6,86.8]
24
  })
25
+ paper_df["Abs_gain"] = (paper_df["WFGY"] - paper_df["Baseline"]).round(1)
26
+ paper_df["Rel_gain%"] = ((paper_df["Abs_gain"] / paper_df["Baseline"])*100).round(0)
27
 
28
+ def run(prompt: str):
29
+ prompt = prompt.strip()
30
+ if not prompt:
31
  return "", "", "", None, plot_history()
32
+ ids = tok(prompt, return_tensors="pt").input_ids
33
+ rawL = mdl(ids).logits[0,-1].detach().cpu().numpy()
34
+ G = np.random.randn(256).astype(np.float32)
35
+ I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
36
+ modL = eng.run(I, G, rawL)
37
+
38
+ m = compare_logits(rawL, modL)
39
+ step = len(history["step"])
40
  history["step"].append(step)
41
  history["var"].append(m["var_drop"]*100)
42
  history["kl"].append(m["kl"])
43
+
44
+ fig = plot_histogram(rawL, modL)
45
+ buf = io.BytesIO(); fig.savefig(buf, format="png"); buf.seek(0)
46
  img = Image.open(buf)
47
+
48
+ headline = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f}"
49
+ note = f"*top-1 token {'changed' if not m['top1'] else 'kept'}*"
50
+
51
+ raw_text = prompt + tok.decode(int(rawL.argmax()))
52
+ mod_text = prompt + tok.decode(int(modL.argmax()))
53
+
54
+ return raw_text, mod_text, headline + " " + note, img, plot_history()
55
 
56
  def plot_history():
 
 
57
  df = pd.DataFrame(history)
58
  return px.line(df, x="step", y=["var","kl"],
59
+ labels={"value":"metric","step":"call"},
60
+ title="history (var% ↓ & KL)").update_layout(height=260)
61
+
62
+ def clear_hist():
63
+ history["step"][:] = [0]
64
+ history["var"][:] = [0.0]
65
+ history["kl"][:] = [0.0]
66
+ return plot_history()
67
 
68
  with gr.Blocks(title="WFGY variance gate") as demo:
69
  gr.Markdown("# 🧠 WFGY simulation demo")
70
+ prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
71
+ run_btn = gr.Button("🚀 Run")
72
+ with gr.Row():
73
+ raw_box = gr.Textbox(label="Raw GPT-2")
74
+ mod_box = gr.Textbox(label="After WFGY")
75
+ headline = gr.Markdown()
76
+ hist_img = gr.Image(type="pil", label="Logit histogram")
77
+ hist_plot = gr.Plot(label="History")
78
+ clr_btn = gr.Button("Clear history")
79
+
80
+ with gr.Accordion("Paper benchmarks", open=False):
81
+ gr.DataFrame(paper_df, interactive=False, wrap=True)
82
+
83
+ gr.Markdown("---\n⭐ **10 000 GitHub stars before 2025-08-01 unlock WFGY 2.0**")
84
+
85
+ run_btn.click(run, prompt, [raw_box, mod_box, headline, hist_img, hist_plot])
86
+ clr_btn.click(clear_hist, None, hist_plot)
87
 
88
  if __name__ == "__main__":
89
  demo.queue().launch()