Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import io, traceback, numpy as np, gradio as gr, matplotlib
|
2 |
-
matplotlib.use("Agg")
|
3 |
|
|
|
4 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
from wfgy_sdk import get_engine
|
6 |
from wfgy_sdk.evaluator import compare_logits, plot_histogram
|
@@ -10,6 +11,7 @@ tok = AutoTokenizer.from_pretrained(MODEL)
|
|
10 |
mdl = AutoModelForCausalLM.from_pretrained(MODEL)
|
11 |
eng = get_engine()
|
12 |
|
|
|
13 |
def run(prompt: str):
|
14 |
prompt = prompt.strip()
|
15 |
if not prompt:
|
@@ -22,15 +24,21 @@ def run(prompt: str):
|
|
22 |
mod = eng.run(I, G, raw)
|
23 |
m = compare_logits(raw, mod)
|
24 |
headline = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f}"
|
|
|
25 |
fig = plot_histogram(raw, mod)
|
26 |
-
buf = io.BytesIO()
|
|
|
|
|
|
|
|
|
27 |
raw_txt = prompt + tok.decode(int(raw.argmax()))
|
28 |
mod_txt = prompt + tok.decode(int(mod.argmax()))
|
29 |
-
return raw_txt, mod_txt, headline,
|
30 |
-
except Exception
|
31 |
tb = traceback.format_exc()
|
32 |
return "runtime error", tb, "runtime error", None
|
33 |
|
|
|
34 |
with gr.Blocks(title="WFGY variance gate") as demo:
|
35 |
gr.Markdown("# 🧠 WFGY simulation demo")
|
36 |
prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
|
@@ -40,7 +48,7 @@ with gr.Blocks(title="WFGY variance gate") as demo:
|
|
40 |
raw_box = gr.Textbox(label="Raw GPT-2")
|
41 |
mod_box = gr.Textbox(label="After WFGY")
|
42 |
headline = gr.Markdown()
|
43 |
-
img = gr.Image(label="Logit histogram")
|
44 |
|
45 |
btn.click(run, prompt, [raw_box, mod_box, headline, img])
|
46 |
|
|
|
1 |
import io, traceback, numpy as np, gradio as gr, matplotlib
|
2 |
+
matplotlib.use("Agg")
|
3 |
|
4 |
+
from PIL import Image
|
5 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
6 |
from wfgy_sdk import get_engine
|
7 |
from wfgy_sdk.evaluator import compare_logits, plot_histogram
|
|
|
11 |
mdl = AutoModelForCausalLM.from_pretrained(MODEL)
|
12 |
eng = get_engine()
|
13 |
|
14 |
+
|
15 |
def run(prompt: str):
|
16 |
prompt = prompt.strip()
|
17 |
if not prompt:
|
|
|
24 |
mod = eng.run(I, G, raw)
|
25 |
m = compare_logits(raw, mod)
|
26 |
headline = f"▼ var {m['var_drop']*100:4.1f}% | KL {m['kl']:.3f}"
|
27 |
+
|
28 |
fig = plot_histogram(raw, mod)
|
29 |
+
buf = io.BytesIO()
|
30 |
+
fig.savefig(buf, format="png")
|
31 |
+
buf.seek(0)
|
32 |
+
img = Image.open(buf)
|
33 |
+
|
34 |
raw_txt = prompt + tok.decode(int(raw.argmax()))
|
35 |
mod_txt = prompt + tok.decode(int(mod.argmax()))
|
36 |
+
return raw_txt, mod_txt, headline, img
|
37 |
+
except Exception:
|
38 |
tb = traceback.format_exc()
|
39 |
return "runtime error", tb, "runtime error", None
|
40 |
|
41 |
+
|
42 |
with gr.Blocks(title="WFGY variance gate") as demo:
|
43 |
gr.Markdown("# 🧠 WFGY simulation demo")
|
44 |
prompt = gr.Textbox(label="Prompt", value="Explain Schrödinger's cat")
|
|
|
48 |
raw_box = gr.Textbox(label="Raw GPT-2")
|
49 |
mod_box = gr.Textbox(label="After WFGY")
|
50 |
headline = gr.Markdown()
|
51 |
+
img = gr.Image(label="Logit histogram", type="pil")
|
52 |
|
53 |
btn.click(run, prompt, [raw_box, mod_box, headline, img])
|
54 |
|