#!/usr/bin/env python3
import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
PROMPTS_PATH = "full_prompts.json"
STATE_PATH = "current_state.json"
DATA_PATH = "data.json"
TOKENS_PER_PROMPT = 2048
SECS_BETWEEN_TOKENS = 15
TEMPERATURE = 0.9
TOP_P = 0.95
MAX_CONTEXT_TOKENS = 8192
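# pacing: one token every SECS_BETWEEN_TOKENS seconds, so a full prompt of
# TOKENS_PER_PROMPT tokens takes about 2048 * 15 s ≈ 8.5 hours to finish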
logging.basicConfig(level=logging.INFO)
log = logging.getLogger()
def _read_json(path, default):
    try:
        with open(path, encoding="utf-8") as fh:
            return json.load(fh)
    except (OSError, json.JSONDecodeError):
        return default

def _atomic_write(path, obj):
    # write to a temp file, then rename so readers never see a partial file
    tmp = path + ".tmp"
    with open(tmp, "w", encoding="utf-8") as fh:
        json.dump(obj, fh, ensure_ascii=False, indent=2)
    os.replace(tmp, path)

def load_prompts():
    prompts = _read_json(PROMPTS_PATH, [])
    if not prompts:
        raise FileNotFoundError(f"{PROMPTS_PATH} is missing or empty")
    return prompts
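# full_prompts.json is assumed to be a flat JSON array of prompt strings,
# since prompts[i] is concatenated directly with the generated text below.
# Hypothetical example: ["Once upon a time,", "Explain why the sky is blue."]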
# load model (gated repo: HF_READ_TOKEN must have access to Llama 3.1)
hf_token = os.environ.get("HF_READ_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=True, token=hf_token
)
model.to(torch.device("cpu"))
model.eval()
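# NOTE: float32 weights for an ~8B-parameter model need roughly
# 8e9 * 4 bytes ≈ 32 GB of RAM, so this expects a large-memory CPU instance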
prompts = load_prompts()
lock = threading.Lock()  # serialises writes to the state and data files across threads
# oracle generation (runs in a background thread)
def _init_state():
    # reuse an unfinished run from disk; otherwise start a fresh random prompt
    s = _read_json(STATE_PATH, {})
    if not s or s.get("finished"):
        i = random.randrange(len(prompts))
        s = {"prompt_idx": i, "prompt": prompts[i], "generated": "",
             "tokens_done": 0, "start_time": time.time(), "finished": False}
        _atomic_write(STATE_PATH, s)
    return s

def _elapsed_str(start_time):
    # format seconds since start_time as "Hh Mm Ss"
    secs = int(time.time() - start_time)
    hours, rem = divmod(secs, 3600)
    mins, secs = divmod(rem, 60)
    return f"{hours}h {mins}m {secs}s"
def oracle_loop():
    while True:
        with lock:
            s = _init_state()
        if s["finished"]:
            time.sleep(SECS_BETWEEN_TOKENS)
            continue
        # re-tokenize prompt + everything generated so far and sample one more token
        context = s["prompt"] + s["generated"]
        enc = tokenizer(context, return_tensors="pt", truncation=True,
                        max_length=MAX_CONTEXT_TOKENS)
        with torch.no_grad():
            out = model.generate(**enc, max_new_tokens=1, do_sample=True,
                                 temperature=TEMPERATURE, top_p=TOP_P,
                                 pad_token_id=tokenizer.eos_token_id)
        new_token = tokenizer.decode(out[0, -1], skip_special_tokens=True,
                                     clean_up_tokenization_spaces=False)
        with lock:
            s["generated"] += new_token
            s["tokens_done"] += 1
            if s["tokens_done"] >= TOKENS_PER_PROMPT:
                s["finished"] = True
            _atomic_write(STATE_PATH, s)
        time.sleep(SECS_BETWEEN_TOKENS)

threading.Thread(target=oracle_loop, daemon=True).start()
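# because progress is persisted to current_state.json after every token, a
# process restart resumes the unfinished prompt; a new random prompt is only
# drawn once the previous one finishes (or the state file is missing)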
# ui callbacks
def fetch_state():
    s = _read_json(STATE_PATH, {})
    if not s:
        return "Loading...", "", "0h 0m 0s"
    return s["prompt"], s["generated"], _elapsed_str(s["start_time"])

def submit_guess(full, idea):
    full, idea = full.strip(), idea.strip()
    if not (full or idea):
        return gr.update(value="enter a guess first"), gr.update(), gr.update()
    prompt, generated, elapsed = fetch_state()
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "prompt": prompt,
        "point-in-time": elapsed,
        "response-point": generated,
        "user-guess": full or idea,
        "guess-type": "full" if full else "idea",
    }
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as fh:
            fh.write(json.dumps(record, ensure_ascii=False) + "\n")
    return gr.update(value="logged!"), gr.update(value=""), gr.update(value="")
with gr.Blocks(title="What Comes Next") as demo:
    gr.Markdown("# What Comes Next")
    prompt_md = gr.Markdown()
    oracle_box = gr.Textbox(lines=10, interactive=False, label="oracle output so far")
    elapsed_box = gr.Textbox(interactive=False, label="elapsed time")
    refresh_btn = gr.Button("refresh")
    full_box = gr.Textbox(label="full guess"); idea_box = gr.Textbox(label="idea")
    send_btn = gr.Button("send"); status_box = gr.Textbox(interactive=False, label="status")
    demo.load(fetch_state, outputs=[prompt_md, oracle_box, elapsed_box])
    refresh_btn.click(fetch_state, outputs=[prompt_md, oracle_box, elapsed_box])
    send_btn.click(submit_guess, inputs=[full_box, idea_box], outputs=[status_box, full_box, idea_box])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
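# to try this locally (assuming the script is saved as app.py, a placeholder
# name, and your account has access to the gated meta-llama repo):
#   HF_READ_TOKEN=hf_... python app.py
# then open http://localhost:7860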