OneStarDao committed on
Commit
ef37700
·
verified ·
1 Parent(s): e672caa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -32
app.py CHANGED
@@ -1,50 +1,68 @@
1
- """
2
- HF Space · WFGY 1-click Variance Gate (貼上就能部署)
3
- """
4
 
5
- import io, numpy as np, gradio as gr, matplotlib.pyplot as plt
 
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from wfgy_sdk import get_engine
8
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
9
 
10
- MODEL = "sshleifer/tiny-gpt2"
11
- tok = AutoTokenizer.from_pretrained(MODEL)
12
- mdl = AutoModelForCausalLM.from_pretrained(MODEL)
13
- ENG = get_engine()
14
 
15
- def run(prompt:str):
16
- if not prompt.strip():
17
- return "-", "-", "-", None
18
 
19
- inp = tok(prompt, return_tensors="pt")
20
- rawL = mdl(**inp).logits[0, -1].detach().cpu().numpy()
21
- I, G = np.random.randn(2, 256).astype(np.float32)
22
- modL = ENG.run(I, G, rawL)
23
 
24
- mets = compare_logits(rawL, modL)
25
- head = f"▼ Var {mets['var_drop']*100:.1f}% | KL {mets['kl']:.2f}"
26
 
27
- # ── 圖表轉成 PNG buffer ──
28
- fig = plot_histogram(rawL, modL)
29
- buf = io.BytesIO(); fig.savefig(buf, format="png"); buf.seek(0)
30
 
31
- raw_txt = prompt + tok.decode(int(rawL.argmax()))
32
- mod_txt = prompt + tok.decode(int(modL.argmax()))
33
- return raw_txt, mod_txt, head, buf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- with gr.Blocks(title="WFGY 1-Click Variance Gate") as demo:
36
- gr.Markdown("# 🧠 WFGY 模擬實驗\n*輸入任意 Prompt,立刻觀看 Logit 直方圖*")
37
  prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
38
- run_b = gr.Button("🚀 Run")
39
 
40
  with gr.Row():
41
- raw = gr.Textbox(label="Raw GPT-2")
42
- mod = gr.Textbox(label="After WFGY")
 
 
 
43
 
44
- head = gr.Markdown()
45
- img = gr.Image(label="Logit Histogram")
46
 
47
- run_b.click(run, prompt, [raw, mod, head, img])
 
 
 
 
48
 
49
  if __name__ == "__main__":
50
- demo.queue(default_concurrency_limit=2).launch()
 
1
+ # HF Space · WFGY demo (all-English, no comments in other languages)
 
 
2
 
3
+ import io
4
+ import numpy as np
5
+ import gradio as gr
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from wfgy_sdk import get_engine
8
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
9
 
10
# Tiny GPT-2 checkpoint — keeps the Space CPU-friendly and quick to cold-start.
MODEL_ID = "sshleifer/tiny-gpt2"

# Loaded once at import time and shared by every request.
tok = AutoTokenizer.from_pretrained(MODEL_ID)
mdl = AutoModelForCausalLM.from_pretrained(MODEL_ID)
eng = get_engine()  # WFGY engine from wfgy_sdk — presumably stateless per call; verify
15
 
16
def run(prompt: str):
    """Compare raw GPT-2 next-token logits with WFGY-modulated logits.

    Parameters
    ----------
    prompt : str
        User text; leading/trailing whitespace is ignored.

    Returns
    -------
    tuple
        (raw_txt, mod_txt, headline, buf) for the four Gradio outputs:
        raw_txt  – prompt plus the raw model's argmax next token
        mod_txt  – prompt plus the WFGY-modulated argmax next token
        headline – markdown line with variance drop and KL divergence
        buf      – PNG of the logit histogram as a BytesIO, or None
    """
    prompt = prompt.strip()
    if not prompt:
        # Empty input: return placeholders instead of raising in the UI.
        return "", "", "No prompt — nothing to show", None

    ids = tok(prompt, return_tensors="pt").input_ids
    # Logits at the last position → next-token distribution (1-D numpy array).
    logits_raw = mdl(ids).logits[0, -1].detach().cpu().numpy()

    # Toy semantic fingerprints just for the demo: I is a lightly perturbed
    # copy of G, so the gate sees a near-match every time.
    G = np.random.randn(256).astype(np.float32)
    I = G + np.random.normal(scale=0.05, size=256).astype(np.float32)

    logits_mod = eng.run(I, G, logits_raw)
    m = compare_logits(logits_raw, logits_mod)

    headline = f"▼ var {m['var_drop']*100:.1f} % | KL {m['kl']:.3f}"

    fig = plot_histogram(logits_raw, logits_mod)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    # Fix: close the figure so repeated requests don't leak matplotlib memory
    # in the long-running Space process. Local import because the module
    # header no longer imports matplotlib directly.
    import matplotlib.pyplot as plt
    plt.close(fig)
    # NOTE(review): gr.Image accepting a raw BytesIO is Gradio-version
    # dependent — confirm, or return a filepath/PIL image if it breaks.

    raw_txt = prompt + tok.decode(int(logits_raw.argmax()))
    mod_txt = prompt + tok.decode(int(logits_mod.argmax()))
    return raw_txt, mod_txt, headline, buf
41
+
42
+
43
# UI layout: prompt in, two completions + metrics headline + histogram out.
with gr.Blocks(title="WFGY Variance Gate") as demo:
    gr.Markdown(
        "# 🧠 WFGY simulation demo\nType any prompt – watch variance shrink in real time."
    )

    prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
    btn = gr.Button("🚀 Run")

    # Side-by-side: untouched GPT-2 vs. the WFGY-gated completion.
    with gr.Row():
        raw_box = gr.Textbox(label="Raw GPT-2")
        mod_box = gr.Textbox(label="After WFGY")

    headline = gr.Markdown()
    img = gr.Image(label="Logit histogram")

    btn.click(fn=run, inputs=prompt, outputs=[raw_box, mod_box, headline, img])

    gr.Markdown(
        "---\n### ⭐ Help unlock **WFGY 2.0**\n10 000 stars on GitHub by **2025-08-01** → next-gen release."
    )
66
 
67
if __name__ == "__main__":
    # Gradio 3.x spells this queue(concurrency_count=...); Gradio 4.x removed
    # that kwarg in favor of default_concurrency_limit. Try the spelling this
    # commit shipped with, then fall back, so the Space boots on either major.
    try:
        app = demo.queue(concurrency_count=2)
    except TypeError:
        app = demo.queue(default_concurrency_limit=2)
    app.launch()