File size: 3,685 Bytes
b7304c4
c08680c
 
a210442
2dd44a8
 
e18ee0c
 
c08680c
421d392
e18ee0c
b7304c4
2dd44a8
 
 
 
c08680c
2dd44a8
e18ee0c
 
2dd44a8
 
 
 
e18ee0c
 
 
2dd44a8
 
 
e18ee0c
 
2dd44a8
e18ee0c
 
 
 
2dd44a8
e18ee0c
2dd44a8
e18ee0c
2dd44a8
e18ee0c
 
 
2dd44a8
e18ee0c
2dd44a8
e18ee0c
 
2dd44a8
 
e18ee0c
 
2dd44a8
 
e18ee0c
421d392
e18ee0c
 
 
 
 
 
 
b7304c4
e18ee0c
 
 
 
 
 
b7304c4
e18ee0c
 
 
 
 
 
b7304c4
2dd44a8
 
e18ee0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3



import os, json, time, random, threading, logging
from datetime import datetime, timezone
import torch; torch.set_num_threads(os.cpu_count()); torch.set_num_interop_threads(os.cpu_count())
import gradio as gr

from transformers import AutoTokenizer, AutoModelForCausalLM
from gradio.themes import Dark

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

# File locations, relative to the working directory.
PROMPTS_PATH = "full_prompts.json"  # prompt pool; expected to be a JSON list of strings
STATE_PATH = "current_state.json"   # in-progress generation state, rewritten after every token
DATA_PATH = "data.json"             # visitor guesses, appended one JSON object per line

# Pacing: the worker samples one token, then sleeps SECS_PER_TOKEN seconds;
# a prompt is marked finished after TOKENS_PER_PROMPT tokens.
TOKENS_PER_PROMPT = 2048
SECS_PER_TOKEN = 15

# Sampling parameters and the token budget for the context fed to the model.
TEMP = 0.9
TOP_P = 0.95
MAX_CTX = 8192

logging.basicConfig(level=logging.INFO)
log = logging.getLogger()  # the root logger

# read or write json

def _rj(p, d):
    try: return json.load(open(p, encoding="utf-8"))
    except: return d

def _aw(p, o):
    t = p + ".tmp"; open(t, "w", encoding="utf-8").write(json.dumps(o, ensure_ascii=False, indent=2)); os.replace(t, p)

# load prompts
# Load the prompt pool once at startup. Fail fast here, in the main thread,
# if the file is missing, unreadable, empty, or not a JSON list. The generation
# thread indexes into this list with random.randrange(len(prompts)), and a
# failure there would only kill that background thread.
prompts = _rj(PROMPTS_PATH, [])
if not isinstance(prompts, list) or not prompts:
    raise RuntimeError("no prompts")

# load model
# Load the tokenizer and model once, at import time. HF_READ_TOKEN (None when
# unset) is passed through as the Hugging Face access token.
tok = os.environ.get("HF_READ_TOKEN")
log.info("loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=tok,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=False,
)
model.to("cpu")
model.eval()  # inference mode (e.g. disables dropout)
log.info("model up")

# Serializes the generation loop's state-file updates and the UI's appends
# to DATA_PATH.
lock = threading.Lock()

def _init():
    """Return the current generation state, starting a new one when needed.

    If STATE_PATH holds an unfinished state, it is returned unchanged.
    Otherwise (no file, empty/unreadable file, or a finished state), a prompt
    is picked at random. A fresh state dict is then built with keys
    i (prompt index), p (prompt text), g (generated text), c (token count),
    t (start time) and finished. That dict is persisted and returned.
    Callers hold ``lock`` around this call.
    """
    state = _rj(STATE_PATH, {})
    if state and not state.get("finished"):
        return state

    idx = random.randrange(len(prompts))
    state = {
        "i": idx,
        "p": prompts[idx],
        "g": "",
        "c": 0,
        "t": time.time(),
        "finished": False,
    }
    _aw(STATE_PATH, state)
    return state

# elapsed time

def _es(st):
    d = int(time.time() - st); h, r = divmod(d, 3600); m, s = divmod(r, 60)
    return f"{h}h {m}m {s}s"

# oracle loop

def _loop():
    """Background worker: append one sampled token to the current completion,
    then sleep SECS_PER_TOKEN seconds, forever.

    Each iteration does the following:
    - Loads (or starts) the state via _init().
    - Feeds prompt + generated text to the model and samples a single new
      token with TEMP / TOP_P.
    - Appends the decoded token text, bumps the counter, and marks the state
      finished after TOKENS_PER_PROMPT tokens.
    - Persists the state with _aw, so a restart resumes from the saved text.

    Any exception in one iteration is logged and the loop retries after the
    usual sleep. Previously a single failure silently ended this daemon
    thread and froze the app.
    """
    while True:
        try:
            with lock:
                s = _init()
            if s["finished"]:
                # _init() already replaces finished states; kept as a safety net.
                time.sleep(SECS_PER_TOKEN)
                continue
            c = s["p"] + s["g"]
            enc = tokenizer(c, return_tensors="pt")
            # Keep the *last* MAX_CTX tokens. truncation=True/max_length would
            # cut from the right (the tokenizer default), so the model would
            # continue from the middle of the text instead of its end.
            ids = enc.input_ids[:, -MAX_CTX:]
            mask = enc.attention_mask[:, -MAX_CTX:]
            with torch.no_grad():
                out = model.generate(
                    ids,
                    attention_mask=mask,
                    max_new_tokens=1,
                    do_sample=True,
                    temperature=TEMP,
                    top_p=TOP_P,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # NOTE(review): decoding token-by-token may yield U+FFFD when a
            # character spans several byte-level tokens — confirm with the tokenizer.
            nt = tokenizer.decode(
                out[0, ids.shape[-1]:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            with lock:
                s["g"] += nt
                s["c"] += 1
                if s["c"] >= TOKENS_PER_PROMPT:
                    s["finished"] = True
                _aw(STATE_PATH, s)
        except Exception:
            # Top-level boundary of the worker thread: log and keep going.
            log.exception("oracle step failed; retrying")
        time.sleep(SECS_PER_TOKEN)

threading.Thread(target=_loop, daemon=True).start()

# ui

def _fetch():
    """Return ``(prompt, generated_text, elapsed)`` for the UI.

    Reads STATE_PATH without taking the lock. That is safe because _aw
    replaces the file atomically. When no state exists yet, placeholder
    values are returned instead.
    """
    state = _rj(STATE_PATH, {})
    if state:
        return state["p"], state["g"], _es(state["t"])
    return "...", "", "0h 0m 0s"

def _sg(f, i):
    """Record a visitor's guess and reset the input boxes.

    Args:
        f: text from the "full" textbox (takes precedence when non-blank).
        i: text from the "idea" textbox. None is treated as empty.

    Returns:
        Three gr.update objects for (status box, full box, idea box).
        Blank input yields "eh?" and leaves both inputs untouched. Otherwise
        one JSON line is appended to DATA_PATH and both inputs are cleared.
        That line holds the current prompt, the generated text, the elapsed
        time, the guess, and its type.
    """
    f1, f2 = (f or "").strip(), (i or "").strip()
    if not (f1 or f2):
        return gr.update(value="eh?"), gr.update(), gr.update()
    p, g, e = _fetch()
    guess, gt = (f1, "full") if f1 else (f2, "idea")
    r = {"ts": datetime.now(timezone.utc).isoformat(), "p": p, "time": e, "resp": g, "guess": guess, "type": gt}
    line = json.dumps(r, ensure_ascii=False) + "\n"
    # Close the file deterministically while still holding the lock, so
    # concurrent submissions never interleave partial lines.
    with lock, open(DATA_PATH, "a", encoding="utf-8") as fh:
        fh.write(line)
    return gr.update(value="ok logged"), gr.update(value=""), gr.update(value="")

with gr.Blocks(theme=Dark()) as demo:
    gr.Markdown("# What Comes Next")

    # Read-only view: current prompt, generated continuation, elapsed time.
    prm = gr.Markdown()
    txt = gr.Textbox(lines=10, interactive=False, label="oracle")
    tme = gr.Textbox(interactive=False, label="time")
    rbtn = gr.Button("refresh")

    # Guess entry: a "full" continuation or a looser "idea", plus a status box.
    full = gr.Textbox(label="full")
    idea = gr.Textbox(label="idea")
    send = gr.Button("send")
    st = gr.Textbox(interactive=False, label="st")

    # Populate the view on page load and on refresh; log guesses on send.
    demo.load(_fetch, outputs=[prm, txt, tme])
    rbtn.click(_fetch, outputs=[prm, txt, tme])
    send.click(_sg, inputs=[full, idea], outputs=[st, full, idea])

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)