OneStarDao commited on
Commit
048a5a0
·
verified ·
1 Parent(s): c1eda44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -27
app.py CHANGED
@@ -1,58 +1,66 @@
1
  import io, traceback, numpy as np, gradio as gr, matplotlib
2
  matplotlib.use("Agg")
3
-
4
  from PIL import Image
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
  from wfgy_sdk import get_engine
7
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
 
8
 
9
  MODEL = "sshleifer/tiny-gpt2"
10
- tok = AutoTokenizer.from_pretrained(MODEL)
11
- mdl = AutoModelForCausalLM.from_pretrained(MODEL)
12
- eng = get_engine()
13
-
14
 
15
  def run(prompt: str):
16
  prompt = prompt.strip()
17
  if not prompt:
18
- return "", "", "no prompt – nothing to show", None
19
  try:
20
- ids = tok(prompt, return_tensors="pt").input_ids
21
- raw = mdl(ids).logits[0, -1].detach().cpu().numpy()
22
- G = np.random.randn(256).astype(np.float32)
23
- I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
24
- mod = eng.run(I, G, raw)
25
- m = compare_logits(raw, mod)
26
- headline = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f}"
27
-
28
- fig = plot_histogram(raw, mod)
29
- buf = io.BytesIO()
30
- fig.savefig(buf, format="png")
31
- buf.seek(0)
 
 
 
 
 
 
32
  img = Image.open(buf)
33
 
34
- raw_txt = prompt + tok.decode(int(raw.argmax()))
35
- mod_txt = prompt + tok.decode(int(mod.argmax()))
36
- return raw_txt, mod_txt, headline, img
37
  except Exception:
38
  tb = traceback.format_exc()
39
- return "runtime error", tb, "runtime error", None
40
-
41
 
42
  with gr.Blocks(title="WFGY variance gate") as demo:
43
  gr.Markdown("# 🧠 WFGY simulation demo")
44
  prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
45
- btn = gr.Button("🚀 Run")
46
 
47
  with gr.Row():
48
  raw_box = gr.Textbox(label="Raw GPT-2")
49
  mod_box = gr.Textbox(label="After WFGY")
 
50
  headline = gr.Markdown()
51
- img = gr.Image(label="Logit histogram", type="pil")
 
52
 
53
- btn.click(run, prompt, [raw_box, mod_box, headline, img])
 
54
 
55
- gr.Markdown("---\n### ⭐ 10 000 stars → unlock **WFGY 2.0** by 2025-08-01")
 
56
 
57
  if __name__ == "__main__":
58
  demo.queue().launch()
 
1
  import io, traceback, numpy as np, gradio as gr, matplotlib
2
  matplotlib.use("Agg")
 
3
  from PIL import Image
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from wfgy_sdk import get_engine
6
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
7
+ from tabulate import tabulate
8
 
9
  MODEL = "sshleifer/tiny-gpt2"
10
+ tok = AutoTokenizer.from_pretrained(MODEL)
11
+ mdl = AutoModelForCausalLM.from_pretrained(MODEL)
12
+ eng = get_engine()
 
13
 
14
  def run(prompt: str):
15
  prompt = prompt.strip()
16
  if not prompt:
17
+ return "", "", "", None, None
18
  try:
19
+ ids = tok(prompt, return_tensors="pt").input_ids
20
+ rawL = mdl(ids).logits[0, -1].detach().cpu().numpy()
21
+ G = np.random.randn(256).astype(np.float32)
22
+ I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)
23
+ modL = eng.run(I, G, rawL)
24
+
25
+ m = compare_logits(rawL, modL)
26
+ tbl = tabulate(
27
+ [[f"{m['std_ratio']:.3f}",
28
+ f"{m['var_drop']*100:4.1f} %",
29
+ f"{m['kl']:.3f}",
30
+ "✔" if m['top1'] else "✘"]],
31
+ headers=["std_ratio", "▼ var", "KL", "top-1"],
32
+ tablefmt="github")
33
+ headline = f"▼ var {m['var_drop']*100:4.1f} % | KL {m['kl']:.3f}"
34
+
35
+ fig = plot_histogram(rawL, modL)
36
+ buf = io.BytesIO(); fig.savefig(buf, format="png"); buf.seek(0)
37
  img = Image.open(buf)
38
 
39
+ raw_txt = prompt + tok.decode(int(rawL.argmax()))
40
+ mod_txt = prompt + tok.decode(int(modL.argmax()))
41
+ return raw_txt, mod_txt, headline, tbl, img
42
  except Exception:
43
  tb = traceback.format_exc()
44
+ return "runtime error", tb, "runtime error", "", None
 
45
 
46
  with gr.Blocks(title="WFGY variance gate") as demo:
47
  gr.Markdown("# 🧠 WFGY simulation demo")
48
  prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
49
+ btn = gr.Button("🚀 Run")
50
 
51
  with gr.Row():
52
  raw_box = gr.Textbox(label="Raw GPT-2")
53
  mod_box = gr.Textbox(label="After WFGY")
54
+
55
  headline = gr.Markdown()
56
+ metrics = gr.Markdown() # ← 新增數值表
57
+ img = gr.Image(label="Logit histogram", type="pil")
58
 
59
+ btn.click(run, prompt,
60
+ [raw_box, mod_box, headline, metrics, img])
61
 
62
+ gr.Markdown("---\n"
63
+ "### ⭐ 10 000 stars → unlock **WFGY 2.0** by 2025-08-01")
64
 
65
  if __name__ == "__main__":
66
  demo.queue().launch()