ProCreations committed
Commit 7f977c5 · verified · 1 Parent(s): e4b0b00

Update app.py

Files changed (1): app.py +117 -58
app.py CHANGED
@@ -1,8 +1,5 @@
 #!/usr/bin/env python3
 
-
-
-
 import os, json, time, random, threading, logging
 from datetime import datetime, timezone
 import torch; torch.set_num_threads(os.cpu_count()); torch.set_num_interop_threads(os.cpu_count())
@@ -18,77 +15,139 @@ TOKENS_PER_PROMPT = 2048
 SECS_PER_TOKEN = 15
 TEMP = 0.9; TOP_P = 0.95; MAX_CTX = 8192
 
-
 logging.basicConfig(level=logging.INFO)
 log = logging.getLogger()
 
-def _rj(p,d):
-    try: return json.load(open(p,encoding="utf-8"))
-    except: return d
-
-def _aw(p,o):
-    t=p+".tmp"; open(t,"w",encoding="utf-8").write(json.dumps(o,ensure_ascii=False,indent=2)); os.replace(t,p)
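+# Helpers: _rj is a forgiving JSON read (returns the fallback d on any error);
+# _aw is an atomic JSON write (write to a temp file, then os.replace() swaps it
+# into place so readers never see a half-written file).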
+def _rj(p, d):
+    try:
+        return json.load(open(p, encoding="utf-8"))
+    except Exception:
+        return d
+
+def _aw(p, o):
+    t = p + ".tmp"
+    open(t, "w", encoding="utf-8").write(json.dumps(o, ensure_ascii=False, indent=2))
+    os.replace(t, p)
 
-prompts=_rj(PROMPTS_PATH,[])
-if not prompts: raise Exception("no prompts")
-
-tok=os.environ.get("HF_READ_TOKEN")
-log.info("loading model...")
-tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME,token=tok)
-model=AutoModelForCausalLM.from_pretrained(MODEL_NAME,torch_dtype=torch.float32,low_cpu_mem_usage=False,token=tok)
-model.to("cpu");model.eval()
-log.info("model up")
-
-lock=threading.Lock()
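+# Startup: load the prompt list once, then a float32 CPU copy of the model;
+# the single lock serializes the generator's state updates and guess logging.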
+prompts = _rj(PROMPTS_PATH, [])
+if not prompts:
+    raise Exception("No prompts found in full_prompts.json")
+
+tok = os.environ.get("HF_READ_TOKEN")
+log.info("Loading model...")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    torch_dtype=torch.float32,
+    low_cpu_mem_usage=False,
+    token=tok
+)
+model.to("cpu"); model.eval()
+log.info("Model is ready.")
+
+lock = threading.Lock()
 
 def _init():
-    s=_rj(STATE_PATH,{})
-    if not s or s.get("finished"):
-        i=random.randrange(len(prompts))
-        s={"i":i,"p":prompts[i],"g":"","c":0,"t":time.time(),"finished":False}
-        _aw(STATE_PATH,s)
-    return s
-
-def _es(st):
-    d=int(time.time()-st);h,r=divmod(d,3600);m,s=divmod(r,60)
+    state = _rj(STATE_PATH, {})
+    if not state or state.get("finished"):
+        idx = random.randrange(len(prompts))
+        state = {"i": idx, "p": prompts[idx], "g": "", "c": 0, "t": time.time(), "finished": False}
+        _aw(STATE_PATH, state)
+    return state
+
+def _es(start_time):
+    elapsed = int(time.time() - start_time)
+    h, rem = divmod(elapsed, 3600)
+    m, s = divmod(rem, 60)
     return f"{h}h {m}m {s}s"
 
 def _loop():
     while True:
-        with lock: s=_init()
-        if s["finished"]: time.sleep(SECS_PER_TOKEN); continue
-        c=s["p"]+s["g"]
-        ids=tokenizer(c,return_tensors="pt",truncation=True,max_length=MAX_CTX).input_ids
-        with torch.no_grad(): out=model.generate(ids,max_new_tokens=1,do_sample=True,temperature=TEMP,top_p=TOP_P)
-        nt=tokenizer.decode(out[0,-1],skip_special_tokens=True,clean_up_tokenization_spaces=False)
-        with lock:
-            s["g"]+=nt; s["c"]+=1
-            if s["c"]>=TOKENS_PER_PROMPT: s["finished"]=True
-            _aw(STATE_PATH,s)
+        with lock:
+            st = _init()
+        if st["finished"]:
+            time.sleep(SECS_PER_TOKEN)
+            continue
+        context = st["p"] + st["g"]
+        ids = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
+        with torch.no_grad():
+            out = model.generate(
+                ids,
+                max_new_tokens=1,
+                do_sample=True,
+                temperature=TEMP,
+                top_p=TOP_P
+            )
+        next_token = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        with lock:
+            st["g"] += next_token
+            st["c"] += 1
+            if st["c"] >= TOKENS_PER_PROMPT:
+                st["finished"] = True
+            _aw(STATE_PATH, st)
         time.sleep(SECS_PER_TOKEN)
-threading.Thread(target=_loop,daemon=True).start()
+
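+# Generation cadence: one sampled token every SECS_PER_TOKEN seconds, with the
+# state persisted after every token; once c reaches TOKENS_PER_PROMPT the round
+# is marked finished and _init() picks a fresh random prompt.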
+threading.Thread(target=_loop, daemon=True).start()
 
 def _fetch():
-    s=_rj(STATE_PATH,{})
-    if not s: return "...","","0h 0m 0s"
-    return s["p"],s["g"],_es(s["t"])
-
-def _sg(f,i):
-    f1,i1=f.strip(),i.strip()
-    if not(f1 or i1): return gr.update(value="eh?"),gr.update(),gr.update()
-    p,g,e=_fetch();guess=f1 or i1;gt="full" if f1 else "idea"
-    r={"ts":datetime.now(timezone.utc).isoformat(),"prompt":p,"time":e,"resp":g,"guess":guess,"type":gt}
-    with lock: open(DATA_PATH,"a",encoding="utf-8").write(json.dumps(r,ensure_ascii=False)+"\n")
-    return gr.update(value="ok logged"),gr.update(value=""),gr.update(value="")
+    state = _rj(STATE_PATH, {})
+    if not state:
+        return "...", "", "0h 0m 0s"
+    return state["p"], state["g"], _es(state["t"])
+
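+# Replaces _sg: requires a detailed guess (summary optional) and appends one
+# JSON record per submission to DATA_PATH under the shared lock.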
+def _submit_prediction(detailed, summary):
+    det = detailed.strip()
+    if not det:
+        return gr.update(value="Please enter at least a detailed prediction."), gr.update(value=""), gr.update(value="")
+    prompt_text, oracle_resp, elapsed = _fetch()
+    record = {
+        "ts": datetime.now(timezone.utc).isoformat(),
+        "prompt": prompt_text,
+        "time": elapsed,
+        "resp": oracle_resp,
+        "prediction": det,
+        "summary": summary.strip()
+    }
+    with lock:
+        open(DATA_PATH, "a", encoding="utf-8").write(json.dumps(record, ensure_ascii=False) + "\n")
+    return gr.update(value="Prediction logged!"), gr.update(value=""), gr.update(value="")
 
 with gr.Blocks(theme="darkdefault") as demo:
-    gr.Markdown("# What Comes Next")
-    prm=gr.Markdown();txt=gr.Textbox(lines=10,interactive=False,label="oracle");tme=gr.Textbox(interactive=False,label="time")
-    rbtn=gr.Button("refresh");full=gr.Textbox(label="full");idea=gr.Textbox(label="idea");send=gr.Button("send");st=gr.Textbox(interactive=False,label="status")
-    demo.load(_fetch,outputs=[prm,txt,tme])
-    rbtn.click(_fetch,outputs=[prm,txt,tme])
-    send.click(_sg,inputs=[full,idea],outputs=[st,full,idea])
-
-if __name__=="__main__": demo.launch(server_name="0.0.0.0",server_port=7860)
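+    # Layout: header and prompt on top, live oracle output and elapsed time
+    # side by side, then the prediction form and its event wiring.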
+    gr.Markdown(
+        "# What Comes Next\n"
+        "Enter what you think will come next in the text.\n"
+        "Provide a detailed continuation and optionally a brief summary for context."
+    )
+    prompt_md = gr.Markdown()
+    with gr.Row():
+        oracle_output = gr.Textbox(lines=10, interactive=False, label="Oracle Response")
+        time_info = gr.Textbox(interactive=False, label="Elapsed Time")
+
+    detailed = gr.Textbox(
+        label="Your Detailed Prediction",
+        placeholder="Enter the full text continuation you expect...",
+        lines=3
+    )
+    summary = gr.Textbox(
+        label="Prediction Summary (Optional)",
+        placeholder="Optionally, summarize your prediction in a few words...",
+        lines=2
+    )
+    status = gr.Textbox(interactive=False, label="Status")
+    submit_btn = gr.Button("Submit Prediction")
+    refresh_btn = gr.Button("Refresh Oracle")
+
+    demo.load(_fetch, outputs=[prompt_md, oracle_output, time_info])
+    refresh_btn.click(_fetch, outputs=[prompt_md, oracle_output, time_info])
+    submit_btn.click(
+        _submit_prediction,
+        inputs=[detailed, summary],
+        outputs=[status, detailed, summary]
+    )
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860)
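
A minimal sketch of reading the two files this app maintains. The filenames below are placeholders; the real values come from the STATE_PATH and DATA_PATH constants defined in the unchanged header of app.py:

    import json

    # Oracle state written atomically by _aw(): prompt index i, prompt p,
    # generated text g, token count c, start time t, finished flag.
    state = json.load(open("state.json", encoding="utf-8"))  # placeholder filename
    print(state["c"], "tokens so far; finished:", state["finished"])

    # Guesses appended by _submit_prediction(), one JSON object per line.
    with open("data.jsonl", encoding="utf-8") as f:  # placeholder filename
        for line in f:
            rec = json.loads(line)
            print(rec["ts"], "-", rec["prediction"][:60])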