ProCreations committed on
Commit 2dd44a8 · verified · 1 parent: e03a150

Update app.py

Files changed (1)
  1. app.py +80 -217
app.py CHANGED
@@ -1,228 +1,91 @@
  #!/usr/bin/env python3
- """
- what_comes_next.py – Hugging Face Space implementation of **What Comes Next**
- A global, slow-burn guessing game powered by Llama-3.1-8B-Instruct (FP32, CPU-only).
-
- HOW IT WORKS
- ============
- • One shared model generates a single, very long completion (≈2 k tokens) for a chosen
-   prompt in *full precision* on CPU. One token is sampled every ~15 s, so a prompt
-   unfolds for roughly 10 hours. All visitors see the same progress in real-time.
- • Players read the partial output and may submit **either**
-   🧠 Exact continuation (full guess) **or** 💡 General idea (summary guess).
- • Each guess is appended to `data.json` with prompt, Oracle progress, timestamp & type.
- • Offline scoring (not included here) can later measure similarity vs the final text.
- """
-
- from __future__ import annotations
-
- import os
- import json
- import time
- import random
- import threading
- import logging
- from datetime import datetime, timezone
- from typing import Dict, Any

- import torch
- import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM

- ###############################################################################
- # Configuration                                                               #
- ###############################################################################
- MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"   # gated, requires HF_READ_TOKEN
- PROMPTS_PATH = "full_prompts.json"                # 100 full prompts
- STATE_PATH = "current_state.json"                 # persistent Oracle state
- DATA_PATH = "data.json"                           # JSONL log of guesses
-
- TOKENS_PER_PROMPT = 2048      # stop after N generated tokens
- SECS_BETWEEN_TOKENS = 15      # ~10 h per prompt
- TEMPERATURE = 0.9             # higher creativity, as requested
- TOP_P = 0.95                  # nucleus sampling
- MAX_CONTEXT_TOKENS = 8192     # safety cap
- ###############################################################################
-
- logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
- log = logging.getLogger("what-comes-next")
-
- ###############################################################################
- # Utility helpers                                                             #
- ###############################################################################
-
- def _read_json(path: str, default: Any):
-     try:
-         with open(path, "r", encoding="utf-8") as f:
-             return json.load(f)
-     except FileNotFoundError:
-         return default
-
-
- def _atomic_write(path: str, obj: Any):
-     tmp = f"{path}.tmp"
-     with open(tmp, "w", encoding="utf-8") as f:
-         json.dump(obj, f, ensure_ascii=False, indent=2)
-     os.replace(tmp, path)
-
-
- def load_prompts() -> list[str]:
-     if not os.path.exists(PROMPTS_PATH):
-         raise FileNotFoundError(f"Missing {PROMPTS_PATH}. Please add 100 prompts.")
-     with open(PROMPTS_PATH, "r", encoding="utf-8") as f:
-         prompts = json.load(f)
-     if not isinstance(prompts, list) or not prompts:
-         raise ValueError("full_prompts.json must be a non-empty JSON array of strings")
-     return prompts
-
- ###############################################################################
- # Model loading                                                               #
- ###############################################################################
- log.info("Loading Llama-3.1-8B-Instruct (FP32 CPU-only) using secret token…")
-
- tokenizer = AutoTokenizer.from_pretrained(
-     MODEL_NAME,
-     use_auth_token=os.environ.get("HF_READ_TOKEN")
- )
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_NAME,
-     torch_dtype=torch.float32,
-     device_map={"": "cpu"},
-     use_auth_token=os.environ.get("HF_READ_TOKEN")
- )
- model.eval()
- log.info("Model ready – Oracle awakened.")
-
- ###############################################################################
- # Global state                                                                #
- ###############################################################################
- lock = threading.Lock()      # guard state + files
- prompts = load_prompts()     # list of 100 strings
-
- ###############################################################################
- # Oracle generation thread                                                    #
- ###############################################################################
-
- def _init_state() -> Dict[str, Any]:
-     """Return existing state or create a fresh one if none/finished."""
-     state = _read_json(STATE_PATH, {})
-     if not state or state.get("finished"):
-         prompt_idx = random.randrange(len(prompts))
-         state = {
-             "prompt_idx": prompt_idx,
-             "prompt": prompts[prompt_idx],
-             "generated": "",            # text so far
-             "tokens_done": 0,
-             "start_time": time.time(),
-             "finished": False
-         }
-         _atomic_write(STATE_PATH, state)
-         log.info(f"New Oracle prompt #{prompt_idx}: {state['prompt'][:80]}…")
-     return state
-
-
- def _elapsed_str(start: float) -> str:
-     d = int(time.time() - start)
-     h, r = divmod(d, 3600)
-     m, s = divmod(r, 60)
-     return f"{h}h {m}m {s}s"
-

  def oracle_loop():
      while True:
          with lock:
-             state = _init_state()
-         if state["finished"]:
-             time.sleep(SECS_BETWEEN_TOKENS)
-             continue
-
-         # Build context: prompt + generated so far
-         context = state["prompt"] + state["generated"]
-         input_ids = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS).input_ids
-
-         # Sample one token
-         with torch.no_grad():
-             out = model.generate(
-                 input_ids,
-                 max_new_tokens=1,
-                 do_sample=True,
-                 temperature=TEMPERATURE,
-                 top_p=TOP_P,
-             )
-         next_token = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
-
-         with lock:
-             state["generated"] += next_token
-             state["tokens_done"] += 1
-             if state["tokens_done"] >= TOKENS_PER_PROMPT:
-                 state["finished"] = True
-                 log.info("Prompt completed – Oracle will select a new one shortly.")
-             _atomic_write(STATE_PATH, state)
          time.sleep(SECS_BETWEEN_TOKENS)
-
  threading.Thread(target=oracle_loop, daemon=True).start()

- ###############################################################################
- # Gradio interface                                                            #
- ###############################################################################
-
- def fetch_state() -> tuple[str, str, str]:
-     state = _read_json(STATE_PATH, {})
-     if not state:
-         return "Loading…", "", "0h 0m 0s"
-     return state["prompt"], state["generated"], _elapsed_str(state["start_time"])
-
-
- def submit_guess(full: str, idea: str):
-     full = full.strip()
-     idea = idea.strip()
-     if not full and not idea:
-         return gr.update(value="⚠️ Enter a guess in one of the fields."), gr.update(), gr.update()
-
-     prompt, generated, elapsed = fetch_state()
-     guess_text = full or idea
-     guess_type = "full" if full else "idea"
-
-     record = {
-         "timestamp": datetime.now(timezone.utc).isoformat(),
-         "prompt": prompt,
-         "point-in-time": elapsed,
-         "response-point": generated,
-         "user-guess": guess_text,
-         "guess-type": guess_type
-     }
-     with lock:
-         with open(DATA_PATH, "a", encoding="utf-8") as f:
-             f.write(json.dumps(record, ensure_ascii=False) + "\n")
-     log.info(f"Logged {guess_type} guess ({len(guess_text)} chars).")
-     return gr.update(value="✅ Guess recorded – thanks!"), gr.update(value=""), gr.update(value="")
-
-
- with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
-     gr.Markdown("""# 🌌 What Comes Next
- Watch the Oracle craft an extended response – **one token at a time**. Predict its
- next words or general direction and see how close you were when the tale concludes.
- (All inputs are stored in `data.json` for research.)""")
-
-     prompt_md = gr.Markdown()
-     oracle_box = gr.Textbox(lines=10, interactive=False, label="📜 Oracle text so far")
-     elapsed_tb = gr.Textbox(interactive=False, label="⏱ Elapsed time")
-
-     refresh_btn = gr.Button("🔄 Refresh")
-
-     with gr.Row():
-         exact_tb = gr.Textbox(label="🧠 Exact continuation (full)")
-         idea_tb = gr.Textbox(label="💡 General idea")
-     submit_btn = gr.Button("Submit Guess")
-     status_tb = gr.Textbox(interactive=False, label="Status")
-
-     # Actions
-     refresh_btn.click(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])
-     demo.load(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])
-
-     submit_btn.click(submit_guess,
-                      inputs=[exact_tb, idea_tb],
-                      outputs=[status_tb, exact_tb, idea_tb])
-
- if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
  #!/usr/bin/env python3
+ # what comes next sloppy version

+ import os, json, time, random, threading, logging
+ from datetime import datetime, timezone
+ import torch, gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM

+ MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
+ PROMPTS_PATH = "full_prompts.json"
+ STATE_PATH = "current_state.json"
+ DATA_PATH = "data.json"
+ TOKENS_PER_PROMPT = 2048
+ SECS_BETWEEN_TOKENS = 15
+ TEMPERATURE = 0.9
+ TOP_P = 0.95
+ MAX_CONTEXT_TOKENS = 8192
+
+ logging.basicConfig(level=logging.INFO)
+ log = logging.getLogger()
+
+ def _read_json(p, d):
+     try: return json.load(open(p, encoding="utf-8"))
+     except: return d
+
+ def _atomic_write(p, o):
+     t = p + ".tmp"; open(t, "w", encoding="utf-8").write(json.dumps(o, ensure_ascii=False, indent=2)); os.replace(t,p)
+
+ def load_prompts():
+     l = _read_json(PROMPTS_PATH, [])
+     if not l: raise FileNotFoundError
+     return l
+
+ # load model (uses HF_READ_TOKEN)
+ tok = os.environ.get("HF_READ_TOKEN")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=tok)
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32, low_cpu_mem_usage=True, token=tok)
+ model.to(torch.device("cpu")); model.eval()
+
+ prompts = load_prompts(); lock = threading.Lock()
+
+ # main loop: oracle gen
+ def _init_state():
+     s = _read_json(STATE_PATH, {})
+     if not s or s.get("finished"):
+         i = random.randrange(len(prompts))
+         s = {"prompt_idx":i, "prompt":prompts[i], "generated":"", "tokens_done":0, "start_time":time.time(), "finished":False}
+         _atomic_write(STATE_PATH, s)
+     return s
+
+ def _elapsed_str(st):
+     d=int(time.time()-st);h,r=divmod(d,3600);m,s=divmod(r,60);return f"{h}h {m}m {s}s"

  def oracle_loop():
      while True:
+         with lock: s=_init_state()
+         if s["finished"]: time.sleep(SECS_BETWEEN_TOKENS); continue
+         c=s["prompt"]+s["generated"]
+         ids=tokenizer(c, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS).input_ids
+         with torch.no_grad(): out=model.generate(ids, max_new_tokens=1, do_sample=True, temperature=TEMPERATURE, top_p=TOP_P)
+         nt=tokenizer.decode(out[0,-1], skip_special_tokens=True, clean_up_tokenization_spaces=False)
          with lock:
+             s["generated"]+=nt; s["tokens_done"]+=1
+             if s["tokens_done"]>=TOKENS_PER_PROMPT: s["finished"]=True
+             _atomic_write(STATE_PATH, s)
          time.sleep(SECS_BETWEEN_TOKENS)
  threading.Thread(target=oracle_loop, daemon=True).start()

+ # ui
+
+ def fetch_state():
+     s=_read_json(STATE_PATH,{})
+     if not s: return "Loading...","","0h 0m 0s"
+     return s["prompt"], s["generated"], _elapsed_str(s["start_time"])
+
+ def submit_guess(full, idea):
+     f=full.strip(); i=idea.strip()
+     if not (f or i): return gr.update(value="enter guess!"),gr.update(),gr.update()
+     p,g,e=fetch_state(); guess=f or i; gt="full" if f else "idea"
+     r={"timestamp":datetime.now(timezone.utc).isoformat(),"prompt":p,"point-in-time":e,"response-point":g,"user-guess":guess,"guess-type":gt}
+     with lock: open(DATA_PATH,"a",encoding="utf-8").write(json.dumps(r,ensure_ascii=False)+"\n")
+     return gr.update(value="logged!"),gr.update(value=""),gr.update(value="")
+
+ with gr.Blocks(title="What Comes Next") as demo:
+     gr.Markdown("# What Comes Next - sloppy")
+     prm=gr.Markdown(); txt=gr.Textbox(lines=10,interactive=False,label="oracle"); elt=gr.Textbox(interactive=False,label="time")
+     r=gr.Button("refresh"); f=gr.Textbox(label="full guess"); i=gr.Textbox(label="idea"); sbtn=gr.Button("send"); st=gr.Textbox(interactive=False,label="st")
+     demo.load(fetch_state,outputs=[prm,txt,elt])
+     r.click(fetch_state,outputs=[prm,txt,elt]); sbtn.click(submit_guess,inputs=[f,i],outputs=[st,f,i])
+
+ if __name__=="__main__": demo.launch(server_name="0.0.0.0",server_port=7860)
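
The removed module docstring notes that offline scoring of guesses against the final text is deliberately left out of app.py. For readers wondering what that step could look like, here is a minimal sketch (not part of this commit) that reads the JSONL records `submit_guess` appends to `data.json` and compares each guess with the portion of the finished Oracle text that had not yet been revealed. The `score_guesses` helper name and the use of `difflib` similarity are assumptions for illustration only.

# Hypothetical offline scorer -- illustration only, not part of this commit.
# Assumes `final_text` is the Oracle's completed generation for the same prompt
# and that data.json holds the JSONL records written by submit_guess above.
import json
from difflib import SequenceMatcher

def score_guesses(final_text: str, data_path: str = "data.json") -> list[dict]:
    scored = []
    with open(data_path, encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            # "response-point" is the Oracle text visible when the guess was made,
            # so the remainder is what the player was actually trying to predict.
            remainder = final_text[len(rec["response-point"]):]
            similarity = SequenceMatcher(None, rec["user-guess"], remainder).ratio()
            scored.append({**rec, "score": round(similarity, 3)})
    return scored

A character-level ratio like this only makes sense for the "full" guess type; "idea" guesses would more plausibly be scored with a summary- or embedding-level metric, which is beyond this sketch.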