OneStarDao committed
Commit 6d21131 · verified · 1 Parent(s): 07ad6d2

Update app.py

Files changed (1)
app.py +41 -32
app.py CHANGED

```diff
@@ -5,21 +5,28 @@ from wfgy_sdk.visual import plot_histogram
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

+# ----------------------------------------------------------------------
+# tiny GPT-2 so the Space stays within free CPU limits
+# ----------------------------------------------------------------------
 MODEL = "sshleifer/tiny-gpt2"
 tokenizer = AutoTokenizer.from_pretrained(MODEL)
-model = AutoModelForCausalLM.from_pretrained(MODEL)
+model = AutoModelForCausalLM.from_pretrained(MODEL)
 set_seed(42)
-ENGINE = w.get_engine()

+ENGINE = w.get_engine()

+# ----------------------------------------------------------------------
+# helper
+# ----------------------------------------------------------------------
 def wfgy_pipeline(prompt: str, enable_wfgy: bool):
     if not prompt.strip():
         return "", "", "<i>Please enter a prompt.</i>", None

     try:
-        ids = tokenizer(prompt, return_tensors="pt").input_ids
-        raw_logits = model(ids).logits[0, -1].detach().numpy()
+        ids = tokenizer(prompt, return_tensors="pt").input_ids
+        raw_logits = model(ids).logits[0, -1].detach().cpu().numpy()

+        # dummy semantic vectors (demo only)
         G = np.random.randn(256); G /= np.linalg.norm(G)
         I = G + np.random.normal(scale=0.05, size=256)

@@ -28,49 +35,51 @@ def wfgy_pipeline(prompt: str, enable_wfgy: bool):
             if enable_wfgy else raw_logits.copy()
         )

-        m = compare_logits(raw_logits, mod_logits)
-        top1 = "✔" if m["top1_shift"] else "✘"
-        metrics_html = (
-            f"<b>variance ▼ {(1-m['std_ratio'])*100:.0f}%</b> "
-            f"| <b>KL {m['kl_divergence']:.2f}</b> "
-            f"| top-1 {top1}"
-        )
+        # metrics
+        m = compare_logits(raw_logits, mod_logits)
+        top1 = "✔" if m["top1_shift"] else "✘"
+        metric = (f"<b>variance ▼ {(1-m['std_ratio'])*100:.0f}%</b> | "
+                  f"<b>KL {m['kl_divergence']:.2f}</b> | top-1 {top1}")

-        fig = plot_histogram(raw_logits, mod_logits)   # <<< fixed
+        # histogram (support both “return fig” or “draw directly” versions)
+        maybe_fig = plot_histogram(raw_logits, mod_logits)   # **no show kwarg**
+        import matplotlib.pyplot as plt
+        fig = maybe_fig if maybe_fig is not None else plt.gcf()
         buf = io.BytesIO(); fig.savefig(buf, format="png"); fig.clf()
-        img_uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
+        hist_uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

-        raw_next = tokenizer.decode(int(raw_logits.argmax()))
-        mod_next = tokenizer.decode(int(mod_logits.argmax()))
-        return prompt + raw_next, prompt + mod_next, metrics_html, img_uri
+        # one-token continuations
+        raw_txt = prompt + tokenizer.decode(int(raw_logits.argmax()))
+        mod_txt = prompt + tokenizer.decode(int(mod_logits.argmax()))

-    except Exception as e:
-        return "", "", f"<b style='color:red'>Error:</b> {str(e)}", None
+        return raw_txt, mod_txt, metric, hist_uri

+    except Exception as e:
+        err = f"<b style='color:red'>Error:</b> {str(e)}"
+        return "", "", err, None

-css = """
-#prompt-row {margin-bottom: 1.0rem}
-.gr-box {font-size: 0.85rem}
-"""
+# ----------------------------------------------------------------------
+# UI
+# ----------------------------------------------------------------------
+css = "#prompt-row{margin-bottom:1rem}.gr-box{font-size:.85rem}"

-with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
+with gr.Blocks(title="WFGY Variance Gate", css=css, theme=gr.themes.Soft()) as demo:
     gr.Markdown(
         """
 ### 🧠 WFGY 1-click Variance Gate

-Turn GPT-2 into a calmer thinker in seconds.<br>
-**Bigger LLMs → even stronger gains.**
+Turn GPT-2 into a calmer thinker in seconds. **Bigger LLMs → even stronger gains.**

 | Metric | Meaning |
 |--------|---------|
 | **variance ▼** | logits become less noisy |
 | **KL** | distribution reshaped |
-| **top-1** | most-likely token swapped ✔/✘ |
+| **top-1** | most-likely token swapped ✔ / ✘ |

 **Benchmarks (WFGY 1.0 vs base)**

 | Task | Base % | WFGY % | Δ |
-|------|-------|--------|---|
+|------|-------:|-------:|---:|
 | MMLU | 61.0 | **89.8** | +47 % |
 | TruthfulQA | 62.4 | **90.4** | +45 % |
 | GSM8K | 78.0 | **98.7** | +27 % |
@@ -80,17 +89,17 @@ Turn GPT-2 into a calmer thinker in seconds.<br>
     with gr.Row(elem_id="prompt-row"):
         prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Ask anything…")
         enable = gr.Checkbox(label="Enable WFGY", value=True)
-        run_btn = gr.Button("Run")
+        runbtn = gr.Button("Run")

     with gr.Row():
         raw_box = gr.Textbox(label="Raw GPT-2")
         mod_box = gr.Textbox(label="After WFGY")

-    metrics = gr.HTML()
-    hist_img = gr.Image(label="Logit distribution", width=440)
+    metric_html = gr.HTML()
+    hist_img = gr.Image(label="Logit distribution", width=440)

-    run_btn.click(wfgy_pipeline, [prompt, enable],
-                  [raw_box, mod_box, metrics, hist_img])
+    runbtn.click(wfgy_pipeline, [prompt, enable],
+                 [raw_box, mod_box, metric_html, hist_img])

     gr.Markdown(
         """
```