OneStarDao commited on
Commit
e12152f
Β·
verified Β·
1 Parent(s): 26f293f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -64
app.py CHANGED
@@ -8,35 +8,28 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
8
  from wfgy_sdk import get_engine
9
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
10
 
11
- # ────────────────────────────
12
- # tiny model + engine
13
- # ────────────────────────────
14
  tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
15
  mdl = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
16
  eng = get_engine()
17
 
18
- # ────────────────────────────
19
- # runtime history (dummy zero so line is never empty)
20
- # ────────────────────────────
21
  history = {"step": [0], "var": [0.0], "kl": [0.0]}
22
 
23
- # ────────────────────────────
24
- # paper benchmark table
25
- # ────────────────────────────
26
  paper_df = pd.DataFrame({
27
  "Benchmark": ["MMLU","GSM8K","BBH","MathBench","TruthfulQA",
28
  "XNLI","MLQA","LongBench","VQAv2","OK-VQA"],
29
  "Baseline": [61.0,78.0,79.3,72.2,62.4,59.5,78.1,51.4,69.1,65.7],
30
  "WFGY": [89.8,98.7,100.7,87.4,90.4,77.3,106.6,69.6,86.6,86.8]
31
  })
32
- paper_df["Abs_gain"] = (paper_df["WFGY"] - paper_df["Baseline"]).round(1)
33
- paper_df["Rel_gain%"] = ((paper_df["Abs_gain"] / paper_df["Baseline"]) * 100).round(0)
34
 
35
  styled_df = (
36
  paper_df.style
37
- .background_gradient(subset=["Abs_gain"], cmap="Greens")
38
- .background_gradient(subset=["Rel_gain%"], cmap="Greens")
39
- .format({"Abs_gain": "{:.1f}", "Rel_gain%": "{:.0f}"})
40
  )
41
 
42
  paper_bar = px.bar(
@@ -45,68 +38,47 @@ paper_bar = px.bar(
45
  color_continuous_scale="Greens", height=300
46
  )
47
 
48
- # ────────────────────────────
49
  # helpers
50
- # ────────────────────────────
51
- def top5_tokens(logits: np.ndarray):
52
- """return list of (token, prob) sorted desc"""
53
- probs = torch.softmax(torch.tensor(logits), dim=0).numpy()
54
- idx = probs.argsort()[-5:][::-1]
55
- items = []
56
- for i in idx:
57
- token = tok.decode(int(i)).replace("\n", "\\n")
58
- prob = probs[i]
59
- items.append(f"{token!r}: {prob:.3f}")
60
- return "\n".join(items)
61
-
62
- def plot_history():
63
  df = pd.DataFrame(history)
64
- return px.line(df, x="step", y=["var", "kl"],
65
  labels={"value":"metric","step":"call"},
66
  title="history (var% ↓ & KL)").update_layout(height=260)
67
 
68
- def clear_history():
69
  history["step"][:] = [0]; history["var"][:]=[0.0]; history["kl"][:]=[0.0]
70
- return plot_history()
71
 
72
- # ────────────────────────────
73
- # main run
74
- # ────────────────────────────
75
  def run(prompt: str):
76
  p = prompt.strip()
77
  if not p:
78
- return "", "", "", "", None, plot_history()
79
 
80
- ids = tok(p, return_tensors="pt").input_ids
81
- rawL = mdl(ids).logits[0, -1].detach().cpu().numpy()
82
  G = np.random.randn(256).astype(np.float32)
83
- I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
84
- modL = eng.run(I, G, rawL)
85
-
86
- m = compare_logits(rawL, modL)
87
- step = len(history["step"])
88
- history["step"].append(step)
89
- history["var"].append(m["var_drop"] * 100)
90
- history["kl"].append(m["kl"])
91
-
92
- fig = plot_histogram(rawL, modL)
93
- buf = io.BytesIO(); fig.savefig(buf, format="png"); buf.seek(0)
94
- img = Image.open(buf)
95
 
96
- headline = f"β–Ό var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f} | top-1 {'kept' if m['top1'] else 'changed'}"
 
 
97
 
98
- raw_top5 = top5_tokens(rawL)
99
- mod_top5 = top5_tokens(modL)
100
 
101
- return raw_top5, mod_top5, headline, img, plot_history()
 
102
 
103
- # ────────────────────────────
104
- # UI layout
105
- # ────────────────────────────
106
  with gr.Blocks(title="WFGY variance gate demo") as demo:
107
  gr.Markdown("# 🧠 WFGY simulation demo")
108
  prompt = gr.Textbox(label="Prompt", value="Explain SchrΓΆdinger's cat")
109
- run_btn = gr.Button("πŸš€ Run")
110
 
111
  with gr.Row():
112
  raw_box = gr.Textbox(label="Raw top-5 tokens", lines=6)
@@ -114,18 +86,17 @@ with gr.Blocks(title="WFGY variance gate demo") as demo:
114
 
115
  headline = gr.Markdown()
116
  hist_img = gr.Image(type="pil", label="Logit histogram")
117
- hist_plot = gr.Plot()
118
- clr_btn = gr.Button("Clear history")
119
 
120
  with gr.Accordion("Paper benchmarks", open=False):
121
- gr.DataFrame(value=styled_df, interactive=False, wrap=True)
122
  gr.Plot(paper_bar)
123
 
124
- gr.Markdown("---\n⭐ **10 000 GitHub stars before 2025-08-01 unlock WFGY 2.0**")
125
 
126
- run_btn.click(run, prompt,
127
- [raw_box, mod_box, headline, hist_img, hist_plot])
128
- clr_btn.click(clear_history, None, hist_plot)
129
 
130
  if __name__ == "__main__":
131
  demo.queue().launch()
 
8
  from wfgy_sdk import get_engine
9
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
10
 
11
+ # tiny model (CPU-friendly demo)
 
 
12
  tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
13
  mdl = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
14
  eng = get_engine()
15
 
16
+ # history buffer
 
 
17
  history = {"step": [0], "var": [0.0], "kl": [0.0]}
18
 
19
+ # paper table
 
 
20
  paper_df = pd.DataFrame({
21
  "Benchmark": ["MMLU","GSM8K","BBH","MathBench","TruthfulQA",
22
  "XNLI","MLQA","LongBench","VQAv2","OK-VQA"],
23
  "Baseline": [61.0,78.0,79.3,72.2,62.4,59.5,78.1,51.4,69.1,65.7],
24
  "WFGY": [89.8,98.7,100.7,87.4,90.4,77.3,106.6,69.6,86.6,86.8]
25
  })
26
+ paper_df["Abs_gain"] = (paper_df["WFGY"]-paper_df["Baseline"]).round(1)
27
+ paper_df["Rel_gain%"] = ((paper_df["Abs_gain"]/paper_df["Baseline"])*100).round(0)
28
 
29
  styled_df = (
30
  paper_df.style
31
+ .background_gradient(cmap="Greens", subset=["Abs_gain","Rel_gain%"])
32
+ .format({"Abs_gain":"{:.1f}","Rel_gain%":"{:.0f}"})
 
33
  )
34
 
35
  paper_bar = px.bar(
 
38
  color_continuous_scale="Greens", height=300
39
  )
40
 
 
41
  # helpers
42
+ def top5(logits: np.ndarray):
43
+ p = torch.softmax(torch.tensor(logits), dim=0).numpy()
44
+ idx = p.argsort()[-5:][::-1]
45
+ return "\n".join([f"{tok.decode(int(i))!r}: {p[i]:.2e}" for i in idx])
46
+
47
+ def hist_plot():
 
 
 
 
 
 
 
48
  df = pd.DataFrame(history)
49
+ return px.line(df, x="step", y=["var","kl"],
50
  labels={"value":"metric","step":"call"},
51
  title="history (var% ↓ & KL)").update_layout(height=260)
52
 
53
+ def clear_hist():
54
  history["step"][:] = [0]; history["var"][:]=[0.0]; history["kl"][:]=[0.0]
55
+ return hist_plot()
56
 
 
 
 
57
  def run(prompt: str):
58
  p = prompt.strip()
59
  if not p:
60
+ return "", "", "", "", None, hist_plot()
61
 
62
+ ids = tok(p, return_tensors="pt").input_ids
63
+ rawL = mdl(ids).logits[0,-1].detach().cpu().numpy()
64
  G = np.random.randn(256).astype(np.float32)
65
+ I = G + np.random.normal(scale=0.05,size=256).astype(np.float32)
66
+ modL = eng.run(I,G,rawL)
 
 
 
 
 
 
 
 
 
 
67
 
68
+ m = compare_logits(rawL,modL)
69
+ n = len(history["step"])
70
+ history["step"].append(n); history["var"].append(m["var_drop"]*100); history["kl"].append(m["kl"])
71
 
72
+ fig = plot_histogram(rawL,modL); buf=io.BytesIO(); fig.savefig(buf,format="png"); buf.seek(0)
 
73
 
74
+ head = f"β–Ό var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f} | top-1 {'kept' if m['top1'] else 'changed'}"
75
+ return top5(rawL), top5(modL), head, Image.open(buf), hist_plot()
76
 
77
+ # UI
 
 
78
  with gr.Blocks(title="WFGY variance gate demo") as demo:
79
  gr.Markdown("# 🧠 WFGY simulation demo")
80
  prompt = gr.Textbox(label="Prompt", value="Explain SchrΓΆdinger's cat")
81
+ run_b = gr.Button("πŸš€ Run")
82
 
83
  with gr.Row():
84
  raw_box = gr.Textbox(label="Raw top-5 tokens", lines=6)
 
86
 
87
  headline = gr.Markdown()
88
  hist_img = gr.Image(type="pil", label="Logit histogram")
89
+ hist_p = gr.Plot()
90
+ clr_b = gr.Button("Clear history")
91
 
92
  with gr.Accordion("Paper benchmarks", open=False):
93
+ gr.DataFrame(styled_df, interactive=False, wrap=True)
94
  gr.Plot(paper_bar)
95
 
96
+ gr.Markdown("---\n⭐ **10 k GitHub stars before 2025-08-01 unlock WFGY 2.0**")
97
 
98
+ run_b.click(run, prompt, [raw_box,mod_box,headline,hist_img,hist_p])
99
+ clr_b.click(clear_hist, None, hist_p)
 
100
 
101
  if __name__ == "__main__":
102
  demo.queue().launch()