OneStarDao commited on
Commit
c1eda44
·
verified ·
1 Parent(s): 14903f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import io, traceback, numpy as np, gradio as gr, matplotlib
2
- matplotlib.use("Agg") # headless backend
3
 
 
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from wfgy_sdk import get_engine
6
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
@@ -10,6 +11,7 @@ tok = AutoTokenizer.from_pretrained(MODEL)
10
  mdl = AutoModelForCausalLM.from_pretrained(MODEL)
11
  eng = get_engine()
12
 
 
13
  def run(prompt: str):
14
  prompt = prompt.strip()
15
  if not prompt:
@@ -22,15 +24,21 @@ def run(prompt: str):
22
  mod = eng.run(I, G, raw)
23
  m = compare_logits(raw, mod)
24
  headline = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f}"
 
25
  fig = plot_histogram(raw, mod)
26
- buf = io.BytesIO(); fig.savefig(buf, format="png"); buf.seek(0)
 
 
 
 
27
  raw_txt = prompt + tok.decode(int(raw.argmax()))
28
  mod_txt = prompt + tok.decode(int(mod.argmax()))
29
- return raw_txt, mod_txt, headline, buf
30
- except Exception as e:
31
  tb = traceback.format_exc()
32
  return "runtime error", tb, "runtime error", None
33
 
 
34
  with gr.Blocks(title="WFGY variance gate") as demo:
35
  gr.Markdown("# 🧠 WFGY simulation demo")
36
  prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
@@ -40,7 +48,7 @@ with gr.Blocks(title="WFGY variance gate") as demo:
40
  raw_box = gr.Textbox(label="Raw GPT-2")
41
  mod_box = gr.Textbox(label="After WFGY")
42
  headline = gr.Markdown()
43
- img = gr.Image(label="Logit histogram")
44
 
45
  btn.click(run, prompt, [raw_box, mod_box, headline, img])
46
 
 
1
  import io, traceback, numpy as np, gradio as gr, matplotlib
2
+ matplotlib.use("Agg")
3
 
4
+ from PIL import Image
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
  from wfgy_sdk import get_engine
7
  from wfgy_sdk.evaluator import compare_logits, plot_histogram
 
11
  mdl = AutoModelForCausalLM.from_pretrained(MODEL)
12
  eng = get_engine()
13
 
14
+
15
  def run(prompt: str):
16
  prompt = prompt.strip()
17
  if not prompt:
 
24
  mod = eng.run(I, G, raw)
25
  m = compare_logits(raw, mod)
26
  headline = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f}"
27
+
28
  fig = plot_histogram(raw, mod)
29
+ buf = io.BytesIO()
30
+ fig.savefig(buf, format="png")
31
+ buf.seek(0)
32
+ img = Image.open(buf)
33
+
34
  raw_txt = prompt + tok.decode(int(raw.argmax()))
35
  mod_txt = prompt + tok.decode(int(mod.argmax()))
36
+ return raw_txt, mod_txt, headline, img
37
+ except Exception:
38
  tb = traceback.format_exc()
39
  return "runtime error", tb, "runtime error", None
40
 
41
+
42
  with gr.Blocks(title="WFGY variance gate") as demo:
43
  gr.Markdown("# 🧠 WFGY simulation demo")
44
  prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
 
48
  raw_box = gr.Textbox(label="Raw GPT-2")
49
  mod_box = gr.Textbox(label="After WFGY")
50
  headline = gr.Markdown()
51
+ img = gr.Image(label="Logit histogram", type="pil")
52
 
53
  btn.click(run, prompt, [raw_box, mod_box, headline, img])
54