Update app.py
app.py CHANGED
@@ -1,8 +1,5 @@
 #!/usr/bin/env python3
 
-
-
-
 import os, json, time, random, threading, logging
 from datetime import datetime, timezone
 import torch; torch.set_num_threads(os.cpu_count()); torch.set_num_interop_threads(os.cpu_count())
@@ -18,77 +15,139 @@ TOKENS_PER_PROMPT = 2048
 SECS_PER_TOKEN = 15
 TEMP = 0.9; TOP_P = 0.95; MAX_CTX = 8192
 
-
 logging.basicConfig(level=logging.INFO)
 log = logging.getLogger()
 
-def _rj(p,d):
-
-
-    except:
-
+def _rj(p, d):
+    try:
+        return json.load(open(p, encoding="utf-8"))
+    except:
+        return d
 
-def _aw(p,o):
-    t=p+".tmp"; open(t,"w",encoding="utf-8").write(json.dumps(o,ensure_ascii=False,indent=2)); os.replace(t,p)
+def _aw(p, o):
+    t = p + ".tmp"
+    open(t, "w", encoding="utf-8").write(json.dumps(o, ensure_ascii=False, indent=2))
+    os.replace(t, p)
 
-
-
+prompts = _rj(PROMPTS_PATH, [])
+if not prompts:
+    raise Exception("No prompts found in full_prompts.json")
 
-
-
-
-model=AutoModelForCausalLM.from_pretrained(MODEL_NAME,torch_dtype=torch.float32,low_cpu_mem_usage=False,token=tok)
-model.to("cpu");model.eval()
-log.info("model up")
-
+tok = os.environ.get("HF_READ_TOKEN")
+log.info("Loading model...")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    torch_dtype=torch.float32,
+    low_cpu_mem_usage=False,
+    token=tok
+)
+model.to("cpu"); model.eval()
+log.info("Model is ready.")
+
+lock = threading.Lock()
 
 def _init():
-
-    if not
-
-
-    _aw(STATE_PATH,
-    return
+    state = _rj(STATE_PATH, {})
+    if not state or state.get("finished"):
+        idx = random.randrange(len(prompts))
+        state = {"i": idx, "p": prompts[idx], "g": "", "c": 0, "t": time.time(), "finished": False}
+        _aw(STATE_PATH, state)
+    return state
 
-def _es(
-
+def _es(start_time):
+    elapsed = int(time.time() - start_time)
+    h, rem = divmod(elapsed, 3600)
+    m, s = divmod(rem, 60)
     return f"{h}h {m}m {s}s"
 
 def _loop():
     while True:
-        with lock: s=_init()
-        if s["finished"]: time.sleep(SECS_PER_TOKEN); continue
-        c=s["p"]+s["g"]
-        ids=tokenizer(c,return_tensors="pt",truncation=True,max_length=MAX_CTX).input_ids
-        with torch.no_grad(): out=model.generate(ids,max_new_tokens=1,do_sample=True,temperature=TEMP,top_p=TOP_P)
-        nt=tokenizer.decode(out[0,-1],skip_special_tokens=True,clean_up_tokenization_spaces=False)
         with lock:
-
-
-
+            st = _init()
+        if st["finished"]:
+            time.sleep(SECS_PER_TOKEN)
+            continue
+        context = st["p"] + st["g"]
+        ids = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CTX).input_ids
+        with torch.no_grad():
+            out = model.generate(
+                ids,
+                max_new_tokens=1,
+                do_sample=True,
+                temperature=TEMP,
+                top_p=TOP_P
+            )
+        next_token = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        with lock:
+            st["g"] += next_token
+            st["c"] += 1
+            if st["c"] >= TOKENS_PER_PROMPT:
+                st["finished"] = True
+            _aw(STATE_PATH, st)
         time.sleep(SECS_PER_TOKEN)
-
+
+threading.Thread(target=_loop, daemon=True).start()
 
 def _fetch():
-
-    if not
-
-
-
-
-
-
-
-
-
+    state = _rj(STATE_PATH, {})
+    if not state:
+        return "...", "", "0h 0m 0s"
+    return state["p"], state["g"], _es(state["t"])
+
+def _submit_prediction(detailed, summary):
+    det = detailed.strip()
+    if not det:
+        return gr.update(value="Please enter at least a detailed prediction."), gr.update(value=""), gr.update(value="")
+    prompt_text, oracle_resp, elapsed = _fetch()
+    record = {
+        "ts": datetime.now(timezone.utc).isoformat(),
+        "prompt": prompt_text,
+        "time": elapsed,
+        "resp": oracle_resp,
+        "prediction": det,
+        "summary": summary.strip()
+    }
+    with lock:
+        open(DATA_PATH, "a", encoding="utf-8").write(json.dumps(record, ensure_ascii=False) + "\n")
+    return gr.update(value="Prediction logged!"), gr.update(value=""), gr.update(value="")
 
 with gr.Blocks(theme="darkdefault") as demo:
-    gr.Markdown(
-
-
-
-
-
-
-
+    gr.Markdown(
+        "# What Comes Next\n"
+        "Enter what you think will come next in the text.\n"
+        "Provide a detailed continuation and optionally a brief summary for context."
+    )
+    with gr.Row():
+        prompt_md = gr.Markdown()
+        oracle_output = gr.Textbox(lines=10, interactive=False, label="Oracle Response")
+        time_info = gr.Textbox(interactive=False, label="Elapsed Time")
+
+    detailed = gr.Textbox(
+        label="Your Detailed Prediction",
+        placeholder="Enter the full text continuation you expect...",
+        lines=3
+    )
+    summary = gr.Textbox(
+        label="Prediction Summary (Optional)",
+        placeholder="Optionally, summarize your prediction in a few words...",
+        lines=2
+    )
+    status = gr.Textbox(interactive=False, label="Status")
+    submit_btn = gr.Button("Submit Prediction")
+    refresh_btn = gr.Button("Refresh Oracle")
+
+    demo.load(_fetch, outputs=[prompt_md, oracle_output, time_info])
+    refresh_btn.click(_fetch, outputs=[prompt_md, oracle_output, time_info])
+    submit_btn.click(
+        _submit_prediction,
+        inputs=[detailed, summary],
+        outputs=[status, detailed, summary]
+    )
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860)
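
For reference, `_submit_prediction` appends one JSON object per line to `DATA_PATH`. Below is a minimal sketch of reading those records back for analysis; the filename `predictions.jsonl` is an assumption, since the actual `DATA_PATH` value is defined outside this diff.

```python
import json

# Assumption: DATA_PATH points at a local JSONL file; the real value is set outside this diff.
DATA_PATH = "predictions.jsonl"

with open(DATA_PATH, encoding="utf-8") as f:
    records = [json.loads(line) for line in f if line.strip()]

# Keys written by _submit_prediction: ts, prompt, time, resp, prediction, summary.
for rec in records:
    print(rec["ts"], rec["time"], rec["prediction"][:60])
```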