Update app.py
app.py
CHANGED
@@ -1,92 +1,100 @@
(previous 92-line revision of app.py, shown as removed in the diff; the updated 100-line file follows)
#!/usr/bin/env python3
# "What Comes Next": a CPU-only Space where a background thread extends a stored
# prompt one token at a time while visitors submit guesses about the continuation.

import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch; torch.set_num_threads(os.cpu_count()); torch.set_num_interop_threads(os.cpu_count())
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
PROMPTS_PATH = "full_prompts.json"    # list of prompts to continue
STATE_PATH = "current_state.json"     # persisted generation state
DATA_PATH = "data.json"               # JSON-Lines log of visitor guesses
TOKENS_PER_PROMPT = 2048              # tokens to generate before moving to a new prompt
SECS_PER_TOKEN = 15                   # pause between single-token generation steps
TEMP = 0.9; TOP_P = 0.95; MAX_CTX = 8192  # sampling temperature, nucleus cutoff, context limit
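# At one token every SECS_PER_TOKEN seconds, a full run works out to
# TOKENS_PER_PROMPT * SECS_PER_TOKEN = 2048 * 15 s = 30,720 s, roughly 8.5 hours per prompt.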

logging.basicConfig(level=logging.INFO)
log = logging.getLogger()

# json helpers: best-effort read, atomic write

def _rj(p, d):
    try: return json.load(open(p, encoding="utf-8"))
    except Exception: return d

def _aw(p, o):
    # write to a temp file, then swap it into place so readers never see a half-written file
    t = p + ".tmp"; open(t, "w", encoding="utf-8").write(json.dumps(o, ensure_ascii=False, indent=2)); os.replace(t, p)

# load prompts
tmp = _rj(PROMPTS_PATH, [])
if not tmp: raise Exception("no prompts")
prompts = tmp

# load model (fp32 on CPU; HF_READ_TOKEN is the read token for the gated Llama weights)
tok = os.environ.get("HF_READ_TOKEN")
log.info("loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=False, token=tok)
model.to("cpu"); model.eval()
log.info("model up")

lock = threading.Lock()

def _init():
    # reuse the persisted state, or start a fresh run on a randomly chosen prompt
    s = _rj(STATE_PATH, {})
    if not s or s.get("finished"):
        i = random.randrange(len(prompts))
        s = {"i": i, "p": prompts[i], "g": "", "c": 0, "t": time.time(), "finished": False}
        _aw(STATE_PATH, s)
    return s
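# Illustrative mid-run contents of current_state.json (example values; keys as written by _init):
# {"i": 12, "p": "<prompt text>", "g": "<text generated so far>", "c": 431, "t": 1715000000.0, "finished": false}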

# elapsed time

def _es(st):
    d = int(time.time() - st); h, r = divmod(d, 3600); m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"

# oracle loop: every SECS_PER_TOKEN seconds, sample exactly one more token and persist the state

def _loop():
    while True:
        with lock: s = _init()
        if s["finished"]: time.sleep(SECS_PER_TOKEN); continue
        c = s["p"] + s["g"]
        ids = tokenizer(c, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
        with torch.no_grad(): out = model.generate(ids, max_new_tokens=1, do_sample=True, temperature=TEMP, top_p=TOP_P)
        nt = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        with lock:
            s["g"] += nt; s["c"] += 1
            if s["c"] >= TOKENS_PER_PROMPT: s["finished"] = True
            _aw(STATE_PATH, s)
        time.sleep(SECS_PER_TOKEN)

threading.Thread(target=_loop, daemon=True).start()

# ui

def _fetch():
    # current prompt, generated text so far, and elapsed time for the UI
    s = _rj(STATE_PATH, {})
    if not s: return "...", "", "0h 0m 0s"
    return s["p"], s["g"], _es(s["t"])

def _sg(f, i):
    # log a guess: f comes from the "full" textbox, i from the "idea" textbox
    f1, f2 = f.strip(), i.strip()
    if not (f1 or f2): return gr.update(value="eh?"), gr.update(), gr.update()
    p, g, e = _fetch(); guess = f1 or f2; gt = "full" if f1 else "idea"
    r = {"ts": datetime.now(timezone.utc).isoformat(), "p": p, "time": e, "resp": g, "guess": guess, "type": gt}
    with lock: open(DATA_PATH, "a", encoding="utf-8").write(json.dumps(r, ensure_ascii=False) + "\n")
    return gr.update(value="ok logged"), gr.update(value=""), gr.update(value="")

# gradio ships no built-in "Dark" theme class; use the default theme here
# (dark mode can be requested with the ?__theme=dark URL parameter)
with gr.Blocks() as demo:
    gr.Markdown("# What Comes Next")
    prm = gr.Markdown(); txt = gr.Textbox(lines=10, interactive=False, label="oracle"); tme = gr.Textbox(interactive=False, label="time")
    rbtn = gr.Button("refresh"); full = gr.Textbox(label="full"); idea = gr.Textbox(label="idea"); send = gr.Button("send"); st = gr.Textbox(interactive=False, label="st")
    demo.load(_fetch, outputs=[prm, txt, tme]); rbtn.click(_fetch, outputs=[prm, txt, tme]); send.click(_sg, inputs=[full, idea], outputs=[st, full, idea])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
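Each submitted guess is appended to data.json as one JSON object per line. A minimal sketch for reading the log back offline (not part of app.py; the path and field names are the ones _sg writes above):

import json

def load_guesses(path="data.json"):
    # data.json is JSON-Lines: one record per guess, appended by _sg
    rows = []
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows

# each record looks like:
# {"ts": "<iso timestamp>", "p": "<prompt>", "time": "0h 5m 12s", "resp": "<model text>", "guess": "<visitor guess>", "type": "full"}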