ProCreations committed
Commit 421d392 · verified · 1 Parent(s): 6760dbb

Update app.py

Files changed (1):
  1. app.py +130 -148
app.py CHANGED
@@ -1,23 +1,21 @@
  #!/usr/bin/env python3
  """
- what_comes_next.py – HuggingFace Space implementation of **What Comes Next**
- A slow, contemplative global guessing game.
-
- 🔮 HOW IT WORKS 🔮
- • A single Llama‑3.1‑8B‑Instruct model (FP32 on CPU) is generating one very long completion
- for a chosen mystical prompt. It runs continuously in the background for everyone.
- Any visitor sees the same prompt and the Oracle’s current partial response.
- Players may submit *one* of two kinds of guesses:
- 1. 🧠 **Exact Completion** the full sentence/paragraph they think the Oracle will
- eventually write.
- 2. 💡 **General Idea** – a short summary of the direction or theme they expect.
- Each guess is recorded immediately (with timestamp, Oracle progress, etc.) to
- `data.json` (JSON‑Lines). When the Oracle finally finishes, offline evaluation can
- score the guesses against the final text.
-
- The game then moves on to the next prompt and the cycle repeats.
  """

  import os
  import json
  import time
@@ -25,47 +23,45 @@ import random
  import threading
  import logging
  from datetime import datetime, timezone
- from pathlib import Path
  from typing import Dict, Any

  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
  import gradio as gr

  ###############################################################################
- # Settings #
  ###############################################################################
- MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"  # FP32, CPU‑only
- PROMPTS_PATH = "oracle_prompts.json"  # 100 unfinished lines
- STATE_PATH = "current_state.json"  # persistent Oracle state
- DATA_PATH = "data.json"  # JSONL of user guesses
- TOKENS_PER_PROMPT = 2048  # stop after N tokens
- SECS_BETWEEN_TOKENS = 15  # pacing (≈10h / prompt)
- TEMPERATURE = 0.8
- TOP_P = 0.95
- MAX_CONTEXT_TOKENS = 8192
  ###############################################################################

  logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
- log = logging.getLogger("whatcomesnext")
-
- lock = threading.Lock()  # global file/variable lock

- # --------------------------------------------------------------------------- #
- # Helper functions #
- # --------------------------------------------------------------------------- #

  def _read_json(path: str, default: Any):
      try:
-         with open(path, "r", encoding="utf8") as f:
              return json.load(f)
      except FileNotFoundError:
          return default


- def _write_json(path: str, obj: Any):
      tmp = f"{path}.tmp"
-     with open(tmp, "w", encoding="utf8") as f:
          json.dump(obj, f, ensure_ascii=False, indent=2)
      os.replace(tmp, path)

@@ -73,170 +69,156 @@ def _write_json(path: str, obj: Any):
  def load_prompts() -> list[str]:
      if not os.path.exists(PROMPTS_PATH):
          raise FileNotFoundError(f"Missing {PROMPTS_PATH}. Please add 100 prompts.")
-     with open(PROMPTS_PATH, "r", encoding="utf8") as f:
-         return json.load(f)

-
- prompts = load_prompts()
-
- # --------------------------------------------------------------------------- #
- # Model loading (FP32 ‑ CPU) #
- # --------------------------------------------------------------------------- #
- log.info("Loading Llama‑3.1‑8B‑Instruct in FP32 on CPU (this is *slow*) …")

  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
  model = AutoModelForCausalLM.from_pretrained(
      MODEL_NAME,
      torch_dtype=torch.float32,
-     device_map={"": "cpu"},  # force CPU
  )
  model.eval()
- log.info("Model loaded.")

- # --------------------------------------------------------------------------- #
  # Oracle generation thread #
- # --------------------------------------------------------------------------- #

- def init_state() -> Dict[str, Any]:
-     """Return existing state or create a new one."""
      state = _read_json(STATE_PATH, {})
-     if state.get("finished", False):
-         state = {}  # finished, start new prompt
-     if not state:
          prompt_idx = random.randrange(len(prompts))
-         prompt = prompts[prompt_idx]
          state = {
              "prompt_idx": prompt_idx,
-             "prompt": prompt,
-             "generated": "",  # Oracle’s text so far (string)
              "start_time": time.time(),
-             "finished": False,
-             "tokens_done": 0
          }
-         _write_json(STATE_PATH, state)
-         log.info(f"Starting new Oracle prompt #{prompt_idx}: {prompt[:60]}…")
      return state


  def oracle_loop():
-     """Continuously extend the Oracle’s text by one token every SECS_BETWEEN_TOKENS."""
      while True:
          with lock:
-             state = init_state()
-             if state["finished"]:
-                 # Should not happen, but guard anyway
-                 time.sleep(SECS_BETWEEN_TOKENS)
-                 continue
-             prompt_text = state["prompt"]
-             generated_text = state["generated"]
-             tokens_done = state["tokens_done"]
-
-             # Build input_ids (prompt + generated so far)
-             full_input = prompt_text + generated_text
-             input_ids = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS).input_ids
-
-         # Generate ONE token
          with torch.no_grad():
-             outputs = model.generate(
                  input_ids,
                  max_new_tokens=1,
                  do_sample=True,
                  temperature=TEMPERATURE,
                  top_p=TOP_P,
              )
-         next_token_id = outputs[0, -1].unsqueeze(0)
-         next_token_text = tokenizer.decode(next_token_id, skip_special_tokens=True, clean_up_tokenization_spaces=False)

          with lock:
-             # Update state
-             state["generated"] += next_token_text
              state["tokens_done"] += 1
              if state["tokens_done"] >= TOKENS_PER_PROMPT:
                  state["finished"] = True
-                 log.info("Prompt complete. Oracle will pick a new one next cycle.")
-             _write_json(STATE_PATH, state)
-         time.sleep(SECS_BETWEEN_TOKENS)  # pacing
-

  threading.Thread(target=oracle_loop, daemon=True).start()

- # --------------------------------------------------------------------------- #
- # Gradio Interface #
- # --------------------------------------------------------------------------- #

- def human_readable_elapsed(start: float) -> str:
-     delta = int(time.time() - start)
-     h, rem = divmod(delta, 3600)
-     m, s = divmod(rem, 60)
-     return f"{h}h {m}m {s}s"


- def get_current_state() -> Dict[str, Any]:
-     with lock:
-         state = _read_json(STATE_PATH, {})
-         if not state:
-             return {"prompt": "…loading…", "generated": "", "elapsed": "0h 0m 0s"}
-         return {
-             "prompt": state["prompt"],
-             "generated": state["generated"],
-             "elapsed": human_readable_elapsed(state["start_time"])
-         }


- def record_guess(full_guess: str, idea_guess: str):
-     state = get_current_state()
-     guess_text = full_guess.strip() or idea_guess.strip()
-     if not guess_text:
-         return gr.update(value="⚠️ Please enter a guess in one of the boxes …"), gr.update()
-     guess_type = "full" if full_guess.strip() else "idea"
      record = {
          "timestamp": datetime.now(timezone.utc).isoformat(),
-         "prompt": state["prompt"],
-         "pointintime": state["elapsed"],
-         "responsepoint": state["generated"],
-         "userguess": guess_text,
-         "guesstype": guess_type
      }
-     # Append to JSONL (data.json)
      with lock:
-         with open(DATA_PATH, "a", encoding="utf8") as f:
              f.write(json.dumps(record, ensure_ascii=False) + "\n")
-     log.info(f"Recorded {guess_type} guess ({len(guess_text)} chars).")
-     return gr.update(value="✅ Guess recorded – check back when the Oracle finishes!"), gr.update(value="")


- with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
-     gr.Markdown("""# ✨ What Comes Next
- A global, slow‑burn guessing game. The Oracle is continuously writing its story.
- Read the prompt, see the Oracle’s progress, and predict **what comes next**!
- *(FP32 CPU inference deliberately unhurried.)*""")

-     ### Live Oracle view
-     prompt_box = gr.Markdown(label="🔮 Current Oracle Prompt")
-     oracle_box = gr.Textbox(label="📜 Oracle’s current text", lines=10, interactive=False)
-     elapsed_box = gr.Textbox(label="⏱️ Elapsed", interactive=False)
-
-     ### Guess inputs
-     gr.Markdown("**Make your prediction:** Fill **either** the exact continuation *or* a general idea.")
-     with gr.Row():
-         full_guess = gr.Textbox(label="🧠 Exact continuation (full)")
-         idea_guess = gr.Textbox(label="💡 General idea")
-     submit_btn = gr.Button("Submit Guess")
-     status_msg = gr.Textbox(label="Status", interactive=False)

-     ### Refresh button
-     refresh_btn = gr.Button("🔄 Refresh Oracle progress")

-     def refresh():
-         st = get_current_state()
-         return st["prompt"], st["generated"], st["elapsed"]
-
-     refresh_btn.click(refresh, outputs=[prompt_box, oracle_box, elapsed_box])
-     demo.load(refresh, outputs=[prompt_box, oracle_box, elapsed_box])  # auto‑load on launch
-
-     submit_btn.click(record_guess,
-                      inputs=[full_guess, idea_guess],
-                      outputs=[status_msg, full_guess])  # clear full_guess box on success

  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
  #!/usr/bin/env python3
  """
+ what_comes_next.py – Hugging Face Space implementation of **What Comes Next**
+ A global, slow-burn guessing game powered by Llama-3.1-8B-Instruct (FP32, CPU-only).
+
+ HOW IT WORKS
+ ============
+ One shared model generates a single, very long completion (≈2 k tokens) for a chosen
+ prompt in *full precision* on CPU. One token is sampled every ~15 s, so a prompt
+ unfolds for roughly 10 hours. All visitors see the same progress in real-time.
+ Players read the partial output and may submit **either**
+ 🧠 Exact continuation (full guess) **or** 💡 General idea (summary guess).
+ Each guess is appended to `data.json` with prompt, Oracle progress, timestamp & type.
+ Offline scoring (not included here) can later measure similarity vs the final text.
  """

+ from __future__ import annotations
+
  import os
  import json
  import time
  import random
  import threading
  import logging
  from datetime import datetime, timezone
  from typing import Dict, Any

  import torch
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM

  ###############################################################################
+ # Configuration #
  ###############################################################################
+ MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"  # full-precision model
+ PROMPTS_PATH = "full_prompts.json"  # 100 full prompts
+ STATE_PATH = "current_state.json"  # persistent Oracle state
+ DATA_PATH = "data.json"  # JSONL log of guesses
+
+ TOKENS_PER_PROMPT = 2048  # stop after N generated tokens
+ SECS_BETWEEN_TOKENS = 15  # ~10 h per prompt
+ TEMPERATURE = 0.9  # higher creativity, as requested
+ TOP_P = 0.95  # nucleus sampling
+ MAX_CONTEXT_TOKENS = 8192  # safety cap
  ###############################################################################

  logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
+ log = logging.getLogger("what-comes-next")

+ ###############################################################################
+ # Utility helpers #
+ ###############################################################################

  def _read_json(path: str, default: Any):
      try:
+         with open(path, "r", encoding="utf-8") as f:
              return json.load(f)
      except FileNotFoundError:
          return default


+ def _atomic_write(path: str, obj: Any):
      tmp = f"{path}.tmp"
+     with open(tmp, "w", encoding="utf-8") as f:
          json.dump(obj, f, ensure_ascii=False, indent=2)
      os.replace(tmp, path)


  def load_prompts() -> list[str]:
      if not os.path.exists(PROMPTS_PATH):
          raise FileNotFoundError(f"Missing {PROMPTS_PATH}. Please add 100 prompts.")
+     with open(PROMPTS_PATH, "r", encoding="utf-8") as f:
+         prompts = json.load(f)
+     if not isinstance(prompts, list) or not prompts:
+         raise ValueError("full_prompts.json must be a non-empty JSON array of strings")
+     return prompts

+ ###############################################################################
+ # Model loading #
+ ###############################################################################
+ log.info("Loading Llama-3.1-8B-Instruct (FP32 CPU-only)… this can take a while.")

  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
  model = AutoModelForCausalLM.from_pretrained(
      MODEL_NAME,
      torch_dtype=torch.float32,
+     device_map={"": "cpu"},  # force CPU placement
  )
  model.eval()
+ log.info("Model ready – Oracle awakened.")

+ ###############################################################################
+ # Global state #
+ ###############################################################################
+ lock = threading.Lock()  # guard state + files
+ prompts = load_prompts()  # list of 100 strings
+
+ ###############################################################################
  # Oracle generation thread #
+ ###############################################################################

+ def _init_state() -> Dict[str, Any]:
+     """Return existing state or create a fresh one if none/finished."""
      state = _read_json(STATE_PATH, {})
+     if not state or state.get("finished"):
          prompt_idx = random.randrange(len(prompts))
          state = {
              "prompt_idx": prompt_idx,
+             "prompt": prompts[prompt_idx],
+             "generated": "",  # text so far
+             "tokens_done": 0,
              "start_time": time.time(),
+             "finished": False
          }
+         _atomic_write(STATE_PATH, state)
+         log.info(f"New Oracle prompt #{prompt_idx}: {state['prompt'][:80]}…")
      return state


+ def _elapsed_str(start: float) -> str:
+     d = int(time.time() - start)
+     h, r = divmod(d, 3600)
+     m, s = divmod(r, 60)
+     return f"{h}h {m}m {s}s"
+
+
  def oracle_loop():
      while True:
          with lock:
+             state = _init_state()
+             if state["finished"]:
+                 time.sleep(SECS_BETWEEN_TOKENS)
+                 continue
+
+             # Build context: prompt + generated so far
+             context = state["prompt"] + state["generated"]
+             input_ids = tokenizer(context, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS).input_ids
+
+         # Sample one token
          with torch.no_grad():
+             out = model.generate(
                  input_ids,
                  max_new_tokens=1,
                  do_sample=True,
                  temperature=TEMPERATURE,
                  top_p=TOP_P,
              )
+         next_token = tokenizer.decode(out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False)

          with lock:
+             state["generated"] += next_token
              state["tokens_done"] += 1
              if state["tokens_done"] >= TOKENS_PER_PROMPT:
                  state["finished"] = True
+                 log.info("Prompt completed; Oracle will select a new one shortly.")
+             _atomic_write(STATE_PATH, state)
+         time.sleep(SECS_BETWEEN_TOKENS)

  threading.Thread(target=oracle_loop, daemon=True).start()

+ ###############################################################################
+ # Gradio interface #
+ ###############################################################################

+ def fetch_state() -> tuple[str, str, str]:
+     state = _read_json(STATE_PATH, {})
+     if not state:
+         return "Loading…", "", "0h 0m 0s"
+     return state["prompt"], state["generated"], _elapsed_str(state["start_time"])


+ def submit_guess(full: str, idea: str):
+     full = full.strip()
+     idea = idea.strip()
+     if not full and not idea:
+         return gr.update(value="⚠️ Enter a guess in one of the fields."), gr.update(), gr.update()

+     prompt, generated, elapsed = fetch_state()
+     guess_text = full or idea
+     guess_type = "full" if full else "idea"

      record = {
          "timestamp": datetime.now(timezone.utc).isoformat(),
+         "prompt": prompt,
+         "point-in-time": elapsed,
+         "response-point": generated,
+         "user-guess": guess_text,
+         "guess-type": guess_type
      }

      with lock:
+         with open(DATA_PATH, "a", encoding="utf-8") as f:
              f.write(json.dumps(record, ensure_ascii=False) + "\n")
+     log.info(f"Logged {guess_type} guess ({len(guess_text)} chars).")
+     return gr.update(value="✅ Guess recorded – thanks!"), gr.update(value=""), gr.update(value="")


+ with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
+     gr.Markdown("""# 🌌 What Comes Next
+ Watch the Oracle craft an extended response **one token at a time**. Predict its
+ next words or general direction and see how close you were when the tale concludes.
+ (All inputs are stored in `data.json` for research.)""")

+     prompt_md = gr.Markdown()
+     oracle_box = gr.Textbox(lines=10, interactive=False, label="📜 Oracle text so far")
+     elapsed_tb = gr.Textbox(interactive=False, label="⏱️ Elapsed time")

+     refresh_btn = gr.Button("🔄 Refresh")

+     with gr.Row():
+         exact_tb = gr.Textbox(label="🧠 Exact continuation (full)")
+         idea_tb = gr.Textbox(label="💡 General idea")
+     submit_btn = gr.Button("Submit Guess")
+     status_tb = gr.Textbox(interactive=False, label="Status")
+
+     # Actions
+     refresh_btn.click(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])
+     demo.load(fetch_state, outputs=[prompt_md, oracle_box, elapsed_tb])
+
+     submit_btn.click(submit_guess,
+                      inputs=[exact_tb, idea_tb],
+                      outputs=[status_tb, exact_tb, idea_tb])

  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
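
The new docstring notes that offline scoring is not part of this commit. Below is a minimal sketch of what such a scorer could look like, assuming the JSONL fields written by `submit_guess()` above and a `current_state.json` whose `finished` flag is true; the file name `score_guesses.py` and the use of `difflib.SequenceMatcher` as the similarity metric are illustrative choices, not something shipped in this Space.

# score_guesses.py – hypothetical offline scorer, not included in this commit.
# Reads the JSONL guesses from data.json and compares each one with the
# Oracle's finished text stored in current_state.json.
import json
import difflib

STATE_PATH = "current_state.json"
DATA_PATH = "data.json"


def similarity(guess: str, final_text: str) -> float:
    # Character-level ratio in [0, 1]; swap in any other text-similarity metric.
    return difflib.SequenceMatcher(None, guess, final_text).ratio()


if __name__ == "__main__":
    with open(STATE_PATH, "r", encoding="utf-8") as f:
        state = json.load(f)
    if not state.get("finished"):
        raise SystemExit("Oracle has not finished this prompt yet.")
    final_text = state["generated"]

    with open(DATA_PATH, "r", encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            if rec["prompt"] != state["prompt"]:
                continue  # only score guesses made against the finished prompt
            score = similarity(rec["user-guess"], final_text)
            print(f'{rec["timestamp"]}  {rec["guess-type"]:<4}  {score:.3f}')

Filtering on the `prompt` field keeps guesses recorded during earlier prompts out of the comparison, since data.json accumulates across cycles.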