ProCreations committed on
Commit e18ee0c · verified · 1 Parent(s): 77ac99e

Update app.py

Files changed (1)
  1. app.py +66 -58
app.py CHANGED
@@ -1,92 +1,100 @@
  1    #!/usr/bin/env python3
  2 -  
  3    
  4    import os, json, time, random, threading, logging
  5    from datetime import datetime, timezone
  6 -  import torch, gradio as gr
  7 -  
  8    from transformers import AutoTokenizer, AutoModelForCausalLM
  9    
 10    MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
 11    PROMPTS_PATH = "full_prompts.json"
 12    STATE_PATH = "current_state.json"
 13    DATA_PATH = "data.json"
 14    TOKENS_PER_PROMPT = 2048
 15 -  SECS_BETWEEN_TOKENS = 15
 16 -  TEMPERATURE = 0.9
 17 -  TOP_P = 0.95
 18 -  MAX_CONTEXT_TOKENS = 8192
 19    
 20    logging.basicConfig(level=logging.INFO)
 21    log = logging.getLogger()
 22    
 23 -  def _read_json(p, d):
 24        try: return json.load(open(p, encoding="utf-8"))
 25        except: return d
 26    
 27 -  def _atomic_write(p, o):
 28 -      t = p + ".tmp"; open(t, "w", encoding="utf-8").write(json.dumps(o, ensure_ascii=False, indent=2)); os.replace(t,p)
 29    
 30 -  def load_prompts():
 31 -      l = _read_json(PROMPTS_PATH, [])
 32 -      if not l: raise FileNotFoundError
 33 -      return l
 34    
 35 -  # load model (uses HF_READ_TOKEN)
 36    tok = os.environ.get("HF_READ_TOKEN")
 37    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
 38 -  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=True, token=tok)
 39 -  model.to(torch.device("cpu")); model.eval()
 40    
 41 -  prompts = load_prompts(); lock = threading.Lock()
 42    
 43 -  # main loop: oracle gen
 44 -  def _init_state():
 45 -      s = _read_json(STATE_PATH, {})
 46        if not s or s.get("finished"):
 47            i = random.randrange(len(prompts))
 48 -          s = {"prompt_idx":i, "prompt":prompts[i], "generated":"", "tokens_done":0, "start_time":time.time(), "finished":False}
 49 -          _atomic_write(STATE_PATH, s)
 50        return s
 51    
 52 -  def _elapsed_str(st):
 53 -      d=int(time.time()-st);h,r=divmod(d,3600);m,s=divmod(r,60);return f"{h}h {m}m {s}s"
 54    
 55 -  def oracle_loop():
 56        while True:
 57 -          with lock: s=_init_state()
 58 -          if s["finished"]: time.sleep(SECS_BETWEEN_TOKENS); continue
 59 -          c=s["prompt"]+s["generated"]
 60 -          ids=tokenizer(c, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS).input_ids
 61 -          with torch.no_grad(): out=model.generate(ids, max_new_tokens=1, do_sample=True, temperature=TEMPERATURE, top_p=TOP_P)
 62 -          nt=tokenizer.decode(out[0,-1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
 63            with lock:
 64 -              s["generated"]+=nt; s["tokens_done"]+=1
 65 -              if s["tokens_done"]>=TOKENS_PER_PROMPT: s["finished"]=True
 66 -              _atomic_write(STATE_PATH, s)
 67 -          time.sleep(SECS_BETWEEN_TOKENS)
 68 -  threading.Thread(target=oracle_loop, daemon=True).start()
 69    
 70    # ui
 71    
 72 -  def fetch_state():
 73 -      s=_read_json(STATE_PATH,{})
 74 -      if not s: return "Loading...","","0h 0m 0s"
 75 -      return s["prompt"], s["generated"], _elapsed_str(s["start_time"])
 76 -  
 77 -  def submit_guess(full, idea):
 78 -      f=full.strip(); i=idea.strip()
 79 -      if not (f or i): return gr.update(value="enter guess!"),gr.update(),gr.update()
 80 -      p,g,e=fetch_state(); guess=f or i; gt="full" if f else "idea"
 81 -      r={"timestamp":datetime.now(timezone.utc).isoformat(),"prompt":p,"point-in-time":e,"response-point":g,"user-guess":guess,"guess-type":gt}
 82 -      with lock: open(DATA_PATH,"a",encoding="utf-8").write(json.dumps(r,ensure_ascii=False)+"\n")
 83 -      return gr.update(value="logged!"),gr.update(value=""),gr.update(value="")
 84 -  
 85 -  with gr.Blocks(title="What Comes Next") as demo:
 86 -      gr.Markdown("# What Comes Next - sloppy")
 87 -      prm=gr.Markdown(); txt=gr.Textbox(lines=10,interactive=False,label="oracle"); elt=gr.Textbox(interactive=False,label="time")
 88 -      r=gr.Button("refresh"); f=gr.Textbox(label="full guess"); i=gr.Textbox(label="idea"); sbtn=gr.Button("send"); st=gr.Textbox(interactive=False,label="st")
 89 -      demo.load(fetch_state,outputs=[prm,txt,elt])
 90 -      r.click(fetch_state,outputs=[prm,txt,elt]); sbtn.click(submit_guess,inputs=[f,i],outputs=[st,f,i])
 91 -  
 92 -  if __name__=="__main__": demo.launch(server_name="0.0.0.0",server_port=7860)
 
  1    #!/usr/bin/env python3
  2 +  # sloppy header
  3    
  4    import os, json, time, random, threading, logging
  5    from datetime import datetime, timezone
  6 +  import torch; torch.set_num_threads(os.cpu_count()); torch.set_num_interop_threads(os.cpu_count())
  7 +  import gradio as gr
  8    from transformers import AutoTokenizer, AutoModelForCausalLM
  9 +  from gradio.themes import Dark
 10    
 11    MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
 12    PROMPTS_PATH = "full_prompts.json"
 13    STATE_PATH = "current_state.json"
 14    DATA_PATH = "data.json"
 15    TOKENS_PER_PROMPT = 2048
 16 +  SECS_PER_TOKEN = 15
 17 +  TEMP = 0.9; TOP_P = 0.95; MAX_CTX = 8192
 18    
 19    logging.basicConfig(level=logging.INFO)
 20    log = logging.getLogger()
 21    
 22 +  # read or write json
 23 +  
 24 +  def _rj(p, d):
 25        try: return json.load(open(p, encoding="utf-8"))
 26        except: return d
 27    
 28 +  def _aw(p, o):
 29 +      t = p + ".tmp"; open(t, "w", encoding="utf-8").write(json.dumps(o, ensure_ascii=False, indent=2)); os.replace(t, p)
 30    
 31 +  # load prompts
 32 +  tmp = _rj(PROMPTS_PATH, [])
 33 +  if not tmp: raise Exception("no prompts")
 34 +  prompts = tmp
 35    
 36 +  # load model
 37    tok = os.environ.get("HF_READ_TOKEN")
 38 +  log.info("loading model...")
 39    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
 40 +  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=False, token=tok)
 41 +  model.to("cpu"); model.eval()
 42 +  log.info("model up")
 43    
 44 +  lock = threading.Lock()
 45    
 46 +  def _init():
 47 +      s = _rj(STATE_PATH, {})
 48        if not s or s.get("finished"):
 49            i = random.randrange(len(prompts))
 50 +          s = {"i": i, "p": prompts[i], "g": "", "c": 0, "t": time.time(), "finished": False}
 51 +          _aw(STATE_PATH, s)
 52        return s
 53    
 54 +  # elapsed time
 55    
 56 +  def _es(st):
 57 +      d = int(time.time() - st); h, r = divmod(d, 3600); m, s = divmod(r, 60)
 58 +      return f"{h}h {m}m {s}s"
 59 +  
 60 +  # oracle loop
 61 +  
 62 +  def _loop():
 63        while True:
 64 +          with lock: s = _init()
 65 +          if s["finished"]: time.sleep(SECS_PER_TOKEN); continue
 66 +          c = s["p"] + s["g"]
 67 +          ids = tokenizer(c, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
 68 +          with torch.no_grad(): out = model.generate(ids, max_new_tokens=1, do_sample=True, temperature=TEMP, top_p=TOP_P)
 69 +          nt = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
 70            with lock:
 71 +              s["g"] += nt; s["c"] += 1
 72 +              if s["c"] >= TOKENS_PER_PROMPT: s["finished"] = True
 73 +              _aw(STATE_PATH, s)
 74 +          time.sleep(SECS_PER_TOKEN)
 75 +  
 76 +  threading.Thread(target=_loop, daemon=True).start()
 77    
 78    # ui
 79    
 80 +  def _fetch():
 81 +      s = _rj(STATE_PATH, {})
 82 +      if not s: return "...", "", "0h 0m 0s"
 83 +      return s["p"], s["g"], _es(s["t"])
 84 +  
 85 +  def _sg(f, i):
 86 +      f1, f2 = f.strip(), i.strip()
 87 +      if not (f1 or f2): return gr.update(value="eh?"), gr.update(), gr.update()
 88 +      p, g, e = _fetch(); guess = f1 or f2; gt = "full" if f1 else "idea"
 89 +      r = {"ts": datetime.now(timezone.utc).isoformat(), "p": p, "time": e, "resp": g, "guess": guess, "type": gt}
 90 +      with lock: open(DATA_PATH, "a", encoding="utf-8").write(json.dumps(r, ensure_ascii=False) + "\n")
 91 +      return gr.update(value="ok logged"), gr.update(value=""), gr.update(value="")
 92 +  
 93 +  with gr.Blocks(theme=Dark()) as demo:
 94 +      gr.Markdown("# What Comes Next")
 95 +      prm = gr.Markdown(); txt = gr.Textbox(lines=10, interactive=False, label="oracle"); tme = gr.Textbox(interactive=False, label="time")
 96 +      rbtn = gr.Button("refresh"); full = gr.Textbox(label="full"); idea = gr.Textbox(label="idea"); send = gr.Button("send"); st = gr.Textbox(interactive=False, label="st")
 97 +      demo.load(_fetch, outputs=[prm, txt, tme]); rbtn.click(_fetch, outputs=[prm, txt, tme]); send.click(_sg, inputs=[full, idea], outputs=[st, full, idea])
 98 +  
 99 +  if __name__ == "__main__":
100 +      demo.launch(server_name="0.0.0.0", server_port=7860)
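For reference, the rewritten _loop keeps the same generation strategy as the old oracle_loop: on each tick it re-encodes the prompt plus everything generated so far, asks generate() for exactly one new token, and then persists the grown state with _aw (write to a .tmp file, then os.replace) so the UI can poll current_state.json without sharing memory with the generator thread. A minimal standalone sketch of that one-token-per-tick pattern, using the ungated gpt2 checkpoint as a stand-in for the gated Llama model and a short sleep in place of SECS_PER_TOKEN, might look like:

# Sketch only: "gpt2", the 10-token budget and the 0.1 s sleep are stand-ins,
# not the app's real MODEL_NAME, TOKENS_PER_PROMPT, or SECS_PER_TOKEN values.
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

prompt, generated = "Once upon a time", ""
for _ in range(10):
    ids = tok(prompt + generated, return_tensors="pt").input_ids
    with torch.no_grad():
        # sample exactly one new token per pass, as _loop does
        out = model.generate(ids, max_new_tokens=1, do_sample=True,
                             temperature=0.9, top_p=0.95,
                             pad_token_id=tok.eos_token_id)
    generated += tok.decode(out[0, -1], skip_special_tokens=True)
    time.sleep(0.1)
print(prompt + generated)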