what-comes-next / app.py
ProCreations's picture
Update app.py
2dd44a8 verified
raw
history blame
3.81 kB
#!/usr/bin/env python3
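# What Comes Next ("sloppy" version): a background LLM "oracle" extends a random
# prompt one token at a time while users submit guesses about what comes next.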
import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch, gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- configuration -------------------------------------------------------
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"  # Hugging Face Hub model id
PROMPTS_PATH = "full_prompts.json"  # JSON array of seed prompts
STATE_PATH = "current_state.json"  # persisted oracle progress (atomically rewritten)
DATA_PATH = "data.json"  # user guesses: one JSON object per line (JSON Lines despite the name)
TOKENS_PER_PROMPT = 2048  # tokens generated before a prompt is marked finished
SECS_BETWEEN_TOKENS = 15  # pause between single-token generation steps
TEMPERATURE = 0.9  # sampling temperature passed to model.generate
TOP_P = 0.95  # nucleus-sampling cutoff passed to model.generate
MAX_CONTEXT_TOKENS = 8192  # cap on prompt+generated tokens fed to the model
logging.basicConfig(level=logging.INFO)
# Module-scoped logger (not the root logger) so records are attributable to this app.
log = logging.getLogger(__name__)
def _read_json(p, d):
try: return json.load(open(p, encoding="utf-8"))
except: return d
def _atomic_write(p, o):
t = p + ".tmp"; open(t, "w", encoding="utf-8").write(json.dumps(o, ensure_ascii=False, indent=2)); os.replace(t,p)
def load_prompts():
    """Return the non-empty collection of seed prompts stored at ``PROMPTS_PATH``.

    Raises:
        FileNotFoundError: if nothing usable was loaded.  ``_read_json``
            collapses a missing/unreadable file and invalid JSON into the
            ``[]`` default, so those cases and a genuinely empty file are
            indistinguishable here; the message names the path to check.
    """
    prompt_list = _read_json(PROMPTS_PATH, [])
    if not prompt_list:
        raise FileNotFoundError(
            f"no prompts loaded from {PROMPTS_PATH!r} "
            "(file missing, unreadable, invalid JSON, or empty)"
        )
    return prompt_list
# --- model setup --------------------------------------------------------
# Hub access token read from the HF_READ_TOKEN env var (None when unset).
tok = os.environ.get("HF_READ_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=tok,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)
# Inference only, on CPU.
# NOTE(review): an 8B-parameter model in float32 needs roughly 32 GB of RAM
# for the weights alone.
model.eval()
model.to(torch.device("cpu"))
prompts = load_prompts()
# Serializes state-file initialisation/updates and appends to the guesses file.
lock = threading.Lock()
# --- oracle: background token generation ---------------------------------
def _init_state():
    """Return the oracle state to continue from, starting a new prompt if needed.

    The state persisted at ``STATE_PATH`` is reused while its prompt is still
    in progress.  In these cases a prompt is picked uniformly at random from
    ``prompts`` and a fresh state is written with ``_atomic_write``:
    - the file is missing or unreadable;
    - it is not a JSON object;
    - it lacks a field that the generator or UI reads;
    - it is already marked ``finished``.

    Because a finished state is replaced immediately, the returned state's
    ``finished`` flag is always falsy.  The only caller in this file
    (``oracle_loop``) holds ``lock`` around the call.
    """
    required = ("prompt", "generated", "tokens_done", "start_time", "finished")
    s = _read_json(STATE_PATH, {})
    # Treat a corrupt/partial file as "no state" rather than letting a
    # KeyError/AttributeError escape into the generator thread.
    usable = isinstance(s, dict) and all(key in s for key in required)
    if not usable or s["finished"]:
        i = random.randrange(len(prompts))
        s = {
            "prompt_idx": i,
            "prompt": prompts[i],
            "generated": "",
            "tokens_done": 0,
            "start_time": time.time(),
            "finished": False,
        }
        _atomic_write(STATE_PATH, s)
    return s
def _elapsed_str(st):
d=int(time.time()-st);h,r=divmod(d,3600);m,s=divmod(r,60);return f"{h}h {m}m {s}s"
def _next_token_text(context):
    """Sample one continuation token for the string *context* and return it decoded.

    *context* is encoded without truncation, and only the LAST
    ``MAX_CONTEXT_TOKENS`` ids are kept.  Asking the tokenizer to truncate
    would cut from the right by default (``truncation_side="right"`` unless
    configured otherwise).  That would discard the newest generated text
    once the context outgrows the limit, so the model would keep
    continuing from a stale position.
    """
    ids = tokenizer(context, return_tensors="pt").input_ids
    ids = ids[:, -MAX_CONTEXT_TOKENS:]
    with torch.no_grad():
        out = model.generate(
            ids,
            attention_mask=torch.ones_like(ids),  # single unpadded sequence
            max_new_tokens=1,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
        )
    return tokenizer.decode(
        out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )


def oracle_loop():
    """Grow the oracle's text by one sampled token per iteration, forever.

    Each iteration:
    1. loads (or starts) the current state under ``lock``;
    2. samples one token continuing ``prompt + generated``;
    3. appends it and marks the state finished after ``TOKENS_PER_PROMPT``
       tokens;
    4. persists the state to ``STATE_PATH``;
    5. sleeps ``SECS_BETWEEN_TOKENS``.

    An exception inside an iteration is logged and the loop continues after
    the usual pause.  Without this, an uncaught error would silently end
    this daemon thread, and the UI would keep serving a frozen state.

    NOTE(review): decoding a single byte-level BPE token can yield a partial
    UTF-8 sequence (shown as U+FFFD), which then gets baked into
    ``generated``.  Fixing that would require persisting token ids in the
    state.
    """
    while True:
        try:
            with lock:
                s = _init_state()
            # Defensive: _init_state already replaces finished states.
            if not s["finished"]:
                token_text = _next_token_text(s["prompt"] + s["generated"])
                with lock:
                    s["generated"] += token_text
                    s["tokens_done"] += 1
                    if s["tokens_done"] >= TOKENS_PER_PROMPT:
                        s["finished"] = True
                    _atomic_write(STATE_PATH, s)
        except Exception:
            log.exception("oracle iteration failed; retrying after pause")
        time.sleep(SECS_BETWEEN_TOKENS)
# Run generation in the background. daemon=True means this thread does not
# keep the process alive once the main thread exits.
_oracle_thread = threading.Thread(target=oracle_loop, daemon=True)
_oracle_thread.start()
# --- UI ------------------------------------------------------------------
def fetch_state():
    """Return ``(prompt, generated_text, elapsed)`` for the three output widgets.

    ``STATE_PATH`` is read without taking ``lock``.  Writers replace that
    file via ``os.replace``, so a reader sees a complete old or new version.
    A ``"Loading..."`` placeholder is returned instead of raising while no
    usable state exists, i.e. when the file is missing/unreadable, is not a
    JSON object, or lacks a field shown here.  This stops a malformed file
    from turning page loads and refreshes into KeyErrors.
    """
    s = _read_json(STATE_PATH, {})
    shown = ("prompt", "generated", "start_time")
    if not isinstance(s, dict) or not all(key in s for key in shown):
        return "Loading...", "", "0h 0m 0s"
    return s["prompt"], s["generated"], _elapsed_str(s["start_time"])
def submit_guess(full, idea):
    """Append the user's guess, tagged with the current oracle state, to ``DATA_PATH``.

    Args:
        full: text of the "full guess" box.
        idea: text of the "idea" box; used only when *full* is blank.
            Either may be ``None``, which is treated as empty.

    Returns:
        Three ``gr.update`` objects for (status box, full box, idea box).
        When both inputs are blank, the status asks for a guess and the
        inputs are left unchanged.  Otherwise the status says ``"logged!"``
        and both inputs are cleared.

    Each record is written as one JSON object per line.

    NOTE(review): the oracle state is read at submit time, while the page only
    refreshes on load or on the refresh button.  ``response-point`` may
    therefore include tokens generated after the text the user was looking at.
    """
    f = (full or "").strip()
    i = (idea or "").strip()
    if not (f or i):
        return gr.update(value="enter guess!"), gr.update(), gr.update()
    p, g, e = fetch_state()
    guess, gt = (f, "full") if f else (i, "idea")
    r = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "prompt": p,
        "point-in-time": e,
        "response-point": g,
        "user-guess": guess,
        "guess-type": gt,
    }
    line = json.dumps(r, ensure_ascii=False) + "\n"
    # Close (and flush) the file before releasing the lock so concurrent
    # submissions cannot interleave partial lines.
    with lock, open(DATA_PATH, "a", encoding="utf-8") as fh:
        fh.write(line)
    return gr.update(value="logged!"), gr.update(value=""), gr.update(value="")
with gr.Blocks(title="What Comes Next") as demo:
    gr.Markdown("# What Comes Next - sloppy")

    # Oracle output (read-only); filled on page load and by the refresh button.
    prompt_md = gr.Markdown()
    oracle_box = gr.Textbox(lines=10, interactive=False, label="oracle")
    elapsed_box = gr.Textbox(interactive=False, label="time")
    refresh_btn = gr.Button("refresh")

    # Guess entry and submission status.
    full_box = gr.Textbox(label="full guess")
    idea_box = gr.Textbox(label="idea")
    send_btn = gr.Button("send")
    status_box = gr.Textbox(interactive=False, label="st")

    demo.load(fetch_state, outputs=[prompt_md, oracle_box, elapsed_box])
    refresh_btn.click(fetch_state, outputs=[prompt_md, oracle_box, elapsed_box])
    send_btn.click(
        submit_guess,
        inputs=[full_box, idea_box],
        outputs=[status_box, full_box, idea_box],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)