#!/usr/bin/env python3
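"""What Comes Next: a slow-reveal guessing game.

A background "oracle" thread samples one token every SECS_BETWEEN_TOKENS
seconds from meta-llama/Llama-3.1-8B-Instruct (CPU, float32), persisting
progress to current_state.json. A Gradio UI displays the text generated so
far and appends user guesses about the continuation to data.json.
"""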


import json
import logging
import os
import random
import threading
import time
from datetime import datetime, timezone

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
PROMPTS_PATH = "full_prompts.json"   # list of prompt strings
STATE_PATH = "current_state.json"    # persisted oracle progress
DATA_PATH = "data.json"              # guess log, one JSON object per line
TOKENS_PER_PROMPT = 2048             # tokens generated before drawing a new prompt
SECS_BETWEEN_TOKENS = 15             # seconds between single-token steps
TEMPERATURE = 0.9
TOP_P = 0.95
MAX_CONTEXT_TOKENS = 8192            # truncate prompt + generation to this length
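# At this cadence each prompt takes TOKENS_PER_PROMPT * SECS_BETWEEN_TOKENS
# = 2048 * 15 s = 30,720 s, i.e. about 8.5 hours to finish.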

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

def _read_json(path, default):
    """Read JSON from path, returning default on any read or parse failure."""
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return default

def _atomic_write(path, obj):
    """Write obj as JSON via temp file + rename so readers never see a partial file."""
    tmp = path + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp, path)

def load_prompts():
    prompt_list = _read_json(PROMPTS_PATH, [])
    if not prompt_list:
        raise FileNotFoundError(f"No prompts found in {PROMPTS_PATH}")
    return prompt_list

# Load tokenizer and model on CPU (HF_READ_TOKEN grants gated-repo access).
hf_token = os.environ.get("HF_READ_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=True, token=hf_token
)
model.to(torch.device("cpu"))
model.eval()

prompts = load_prompts()
lock = threading.Lock()

# background oracle: draw a prompt and generate it one token at a time

def _init_state():
    """Return the current state, drawing a fresh random prompt if none exists or the last one finished."""
    state = _read_json(STATE_PATH, {})
    if not state or state.get("finished"):
        idx = random.randrange(len(prompts))
        state = {"prompt_idx": idx, "prompt": prompts[idx], "generated": "",
                 "tokens_done": 0, "start_time": time.time(), "finished": False}
        _atomic_write(STATE_PATH, state)
    return state

def _elapsed_str(start_time):
    """Format the seconds elapsed since start_time as 'Hh Mm Ss'."""
    delta = int(time.time() - start_time)
    hours, rem = divmod(delta, 3600)
    minutes, seconds = divmod(rem, 60)
    return f"{hours}h {minutes}m {seconds}s"

def oracle_loop():
    """Generate one token at a time, persisting state after every step."""
    while True:
        with lock:
            state = _init_state()
        if state["finished"]:
            time.sleep(SECS_BETWEEN_TOKENS)
            continue
        context = state["prompt"] + state["generated"]
        ids = tokenizer(context, return_tensors="pt", truncation=True,
                        max_length=MAX_CONTEXT_TOKENS).input_ids
        with torch.no_grad():
            out = model.generate(ids, max_new_tokens=1, do_sample=True,
                                 temperature=TEMPERATURE, top_p=TOP_P)
        next_text = tokenizer.decode(out[0, -1], skip_special_tokens=True,
                                     clean_up_tokenization_spaces=False)
        with lock:
            state["generated"] += next_text
            state["tokens_done"] += 1
            if state["tokens_done"] >= TOKENS_PER_PROMPT:
                state["finished"] = True
                log.info("Prompt %d finished.", state["prompt_idx"])
            _atomic_write(STATE_PATH, state)
        time.sleep(SECS_BETWEEN_TOKENS)
threading.Thread(target=oracle_loop, daemon=True).start()
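# daemon=True lets the oracle die with the server process; since state is
# persisted after every token, a restart resumes the same prompt mid-stream.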

# Gradio UI

def fetch_state():
    """Return (prompt, generated text, elapsed time) for display."""
    state = _read_json(STATE_PATH, {})
    if not state:
        return "Loading...", "", "0h 0m 0s"
    return state["prompt"], state["generated"], _elapsed_str(state["start_time"])

def submit_guess(full, idea):
    """Append a user guess (verbatim continuation or rough idea) to DATA_PATH."""
    full, idea = full.strip(), idea.strip()
    if not (full or idea):
        return gr.update(value="Enter a guess first!"), gr.update(), gr.update()
    prompt, generated, elapsed = fetch_state()
    record = {"timestamp": datetime.now(timezone.utc).isoformat(), "prompt": prompt,
              "point-in-time": elapsed, "response-point": generated,
              "user-guess": full or idea, "guess-type": "full" if full else "idea"}
    with lock, open(DATA_PATH, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return gr.update(value="Guess logged!"), gr.update(value=""), gr.update(value="")
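
# Despite the .json extension, DATA_PATH grows as JSON Lines (one record per
# append), so it can be tailed without reparsing the whole file.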

with gr.Blocks(title="What Comes Next") as demo:
    gr.Markdown("# What Comes Next")
    prompt_md = gr.Markdown()
    oracle_box = gr.Textbox(lines=10, interactive=False, label="Oracle text so far")
    elapsed_box = gr.Textbox(interactive=False, label="Elapsed time")
    refresh_btn = gr.Button("Refresh")
    full_box = gr.Textbox(label="Full guess (exact continuation)")
    idea_box = gr.Textbox(label="General idea")
    send_btn = gr.Button("Send")
    status_box = gr.Textbox(interactive=False, label="Status")
    demo.load(fetch_state, outputs=[prompt_md, oracle_box, elapsed_box])
    refresh_btn.click(fetch_state, outputs=[prompt_md, oracle_box, elapsed_box])
    send_btn.click(submit_guess, inputs=[full_box, idea_box], outputs=[status_box, full_box, idea_box])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)