#!/usr/bin/env python3
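"""What Comes Next: a CPU-only Gradio app that samples one Llama-3.1-8B-Instruct token every
SECS_PER_TOKEN seconds, persists generation state to disk, and lets visitors log guesses about
how the text will continue."""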
import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch; torch.set_num_threads(os.cpu_count()); torch.set_num_interop_threads(os.cpu_count())
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from gradio.themes import Base  # gradio.themes ships no Dark theme; dark mode is requested with the ?__theme=dark URL flag
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
PROMPTS_PATH = "full_prompts.json"
STATE_PATH = "current_state.json"
DATA_PATH = "data.json"
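# pacing: one sampled token every SECS_PER_TOKEN seconds, up to TOKENS_PER_PROMPT tokens per prompt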
TOKENS_PER_PROMPT = 2048
SECS_PER_TOKEN = 15
TEMP = 0.9; TOP_P = 0.95; MAX_CTX = 8192
logging.basicConfig(level=logging.INFO)
log = logging.getLogger()
# read or write json: _rj returns default d on any read/parse failure, _aw writes atomically via a temp file
def _rj(p, d):
    try:
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return d
def _aw(p, o):
    t = p + ".tmp"
    with open(t, "w", encoding="utf-8") as f:
        json.dump(o, f, ensure_ascii=False, indent=2)
    os.replace(t, p)
# load prompts (fail fast if the file is missing or empty)
prompts = _rj(PROMPTS_PATH, [])
if not prompts:
    raise RuntimeError(f"no prompts found in {PROMPTS_PATH}")
# load model
tok = os.environ.get("HF_READ_TOKEN")
log.info("loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=False, token=tok)
model.to("cpu"); model.eval()
log.info("model up")
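# single lock serializes state updates from the generation thread and guess logging from the UI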
lock = threading.Lock()
def _init():
    # load current state; pick a fresh random prompt if none exists or the last one finished
    s = _rj(STATE_PATH, {})
    if not s or s.get("finished"):
        i = random.randrange(len(prompts))
        s = {"i": i, "p": prompts[i], "g": "", "c": 0, "t": time.time(), "finished": False}
        _aw(STATE_PATH, s)
    return s
# elapsed time as "Hh Mm Ss"
def _es(st):
    d = int(time.time() - st)
    h, r = divmod(d, 3600)
    m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"
# oracle loop: one sampled token per tick, state persisted after every step
def _loop():
    while True:
        with lock:
            s = _init()
        if s["finished"]:
            time.sleep(SECS_PER_TOKEN)
            continue
        c = s["p"] + s["g"]
        ids = tokenizer(c, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
        with torch.no_grad():
            out = model.generate(ids, max_new_tokens=1, do_sample=True, temperature=TEMP, top_p=TOP_P)
        nt = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        with lock:
            s["g"] += nt
            s["c"] += 1
            if s["c"] >= TOKENS_PER_PROMPT:
                s["finished"] = True
            _aw(STATE_PATH, s)
        time.sleep(SECS_PER_TOKEN)
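# run the oracle in a daemon thread so generation continues while Gradio serves requests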
threading.Thread(target=_loop, daemon=True).start()
# ui
def _fetch():
    s = _rj(STATE_PATH, {})
    if not s:
        return "...", "", "0h 0m 0s"
    return s["p"], s["g"], _es(s["t"])
def _sg(f, i):
    # record a visitor guess ("full" continuation or rough "idea") along with the oracle's progress
    f1, f2 = f.strip(), i.strip()
    if not (f1 or f2):
        return gr.update(value="eh?"), gr.update(), gr.update()
    p, g, e = _fetch()
    guess = f1 or f2
    gt = "full" if f1 else "idea"
    r = {"ts": datetime.now(timezone.utc).isoformat(), "p": p, "time": e, "resp": g, "guess": guess, "type": gt}
    with lock, open(DATA_PATH, "a", encoding="utf-8") as fh:
        fh.write(json.dumps(r, ensure_ascii=False) + "\n")
    return gr.update(value="ok logged"), gr.update(value=""), gr.update(value="")
with gr.Blocks(theme=Base()) as demo:
    gr.Markdown("# What Comes Next")
    prm = gr.Markdown()
    txt = gr.Textbox(lines=10, interactive=False, label="oracle")
    tme = gr.Textbox(interactive=False, label="time")
    rbtn = gr.Button("refresh"); full = gr.Textbox(label="full"); idea = gr.Textbox(label="idea")
    send = gr.Button("send"); st = gr.Textbox(interactive=False, label="status")
    demo.load(_fetch, outputs=[prm, txt, tme])
    rbtn.click(_fetch, outputs=[prm, txt, tme])
    send.click(_sg, inputs=[full, idea], outputs=[st, full, idea])
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)