#!/usr/bin/env python3
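"""What Comes Next: a slow oracle demo.

A background thread samples one token from the Llama model every
SECS_PER_TOKEN seconds and persists its progress to disk; the Gradio UI
shows the current prompt and partial continuation and logs visitors'
guesses about what comes next.
"""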
import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from gradio.themes import Soft
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
PROMPTS_PATH = "full_prompts.json"
STATE_PATH = "current_state.json"
DATA_PATH = "data.json"
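# file layout (as used below): PROMPTS_PATH is a JSON array of prompt strings,
# STATE_PATH holds the oracle's current state as one JSON object, and
# DATA_PATH is an append-only JSON-Lines log of visitor guesses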
TOKENS_PER_PROMPT = 2048
SECS_PER_TOKEN = 15
TEMP = 0.9
TOP_P = 0.95
MAX_CTX = 8192
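# pacing: TOKENS_PER_PROMPT * SECS_PER_TOKEN = 2048 * 15 s ≈ 8.5 hours per prompt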
logging.basicConfig(level=logging.INFO)
log = logging.getLogger()
# read or write JSON state files
def _rj(p, d):
    try:
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return d
def _aw(p, o):
    t = p + ".tmp"  # atomic write: temp file, then os.replace
    with open(t, "w", encoding="utf-8") as f:
        f.write(json.dumps(o, ensure_ascii=False, indent=2))
    os.replace(t, p)
# load prompts
prompts = _rj(PROMPTS_PATH, [])
if not prompts:
    raise RuntimeError(f"no prompts found in {PROMPTS_PATH}")
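# note: an 8B-parameter model in float32 needs roughly 8e9 params * 4 bytes ≈ 32 GB of RAM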
# load model on CPU
hf_token = os.environ.get("HF_READ_TOKEN")
log.info("loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=False, token=hf_token)
model.to("cpu")
model.eval()
log.info("model up")
lock = threading.Lock()
def _init():
    # load the current oracle state, or start fresh on a random prompt
    # keys: i = prompt index, p = prompt text, g = generated text, c = token count, t = start time
    s = _rj(STATE_PATH, {})
    if not s or s.get("finished"):
        i = random.randrange(len(prompts))
        s = {"i": i, "p": prompts[i], "g": "", "c": 0, "t": time.time(), "finished": False}
        _aw(STATE_PATH, s)
    return s
# elapsed time since the state's start timestamp
def _es(st):
    d = int(time.time() - st)
    h, r = divmod(d, 3600)
    m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"
# oracle loop: emit one token every SECS_PER_TOKEN seconds and persist state
def _loop():
    while True:
        with lock:
            s = _init()
        if s["finished"]:
            time.sleep(SECS_PER_TOKEN)
            continue
        c = s["p"] + s["g"]
        ids = tokenizer(c, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
        with torch.no_grad():
            out = model.generate(ids, max_new_tokens=1, do_sample=True, temperature=TEMP, top_p=TOP_P)
        nt = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        with lock:
            s["g"] += nt
            s["c"] += 1
            if s["c"] >= TOKENS_PER_PROMPT:
                s["finished"] = True
            _aw(STATE_PATH, s)
        time.sleep(SECS_PER_TOKEN)

threading.Thread(target=_loop, daemon=True).start()
# ui callbacks
def _fetch():
    s = _rj(STATE_PATH, {})
    if not s:
        return "...", "", "0h 0m 0s"
    return s["p"], s["g"], _es(s["t"])

def _sg(f, i):
    # log a visitor's guess (full continuation or rough idea) plus the oracle's progress so far
    f1, f2 = f.strip(), i.strip()
    if not (f1 or f2):
        return gr.update(value="eh?"), gr.update(), gr.update()
    p, g, e = _fetch()
    guess = f1 or f2
    gt = "full" if f1 else "idea"
    r = {"ts": datetime.now(timezone.utc).isoformat(), "p": p, "time": e, "resp": g, "guess": guess, "type": gt}
    with lock, open(DATA_PATH, "a", encoding="utf-8") as fh:
        fh.write(json.dumps(r, ensure_ascii=False) + "\n")
    return gr.update(value="ok logged"), gr.update(value=""), gr.update(value="")
with gr.Blocks(theme=Soft()) as demo:
    gr.Markdown("# What Comes Next")
    prm = gr.Markdown()
    txt = gr.Textbox(lines=10, interactive=False, label="oracle")
    tme = gr.Textbox(interactive=False, label="time")
    rbtn = gr.Button("refresh"); full = gr.Textbox(label="full"); idea = gr.Textbox(label="idea")
    send = gr.Button("send"); st = gr.Textbox(interactive=False, label="status")
    demo.load(_fetch, outputs=[prm, txt, tme])
    rbtn.click(_fetch, outputs=[prm, txt, tme])
    send.click(_sg, inputs=[full, idea], outputs=[st, full, idea])
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
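# usage sketch (assumes HF_READ_TOKEN grants access to the gated Llama checkpoint):
#   HF_READ_TOKEN=hf_xxx python app.py   # then open http://localhost:7860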