ProCreations committed on
Commit b7304c4 · verified · 1 Parent(s): a210442

Update app.py

Files changed (1)
  1. app.py +226 -135
app.py CHANGED
@@ -1,151 +1,242 @@
 import os
 import json
 import random
 import threading
 import logging
-import sqlite3
-from datetime import datetime
 
-import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from sentence_transformers import SentenceTransformer, util
 
-# Logging setup
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
 
-# Load Oracle model (FP32, CPU-only)
-logger.info("Loading Oracle model...")
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
 model = AutoModelForCausalLM.from_pretrained(
-    "meta-llama/Llama-3.1-8B-Instruct",
     torch_dtype=torch.float32,
-    device_map="cpu"
 )
 model.eval()
-
-# Load SentenceTransformer for semantic similarity
-logger.info("Loading SentenceTransformer model...")
-st_model = SentenceTransformer('all-MiniLM-L6-v2')
-
-# Database setup (SQLite)
-DB_PATH = "game_data.db"
-conn = sqlite3.connect(DB_PATH, check_same_thread=False)
-c = conn.cursor()
-c.execute("""
-    CREATE TABLE IF NOT EXISTS rounds (
-        id INTEGER PRIMARY KEY AUTOINCREMENT,
-        timestamp TEXT,
-        prompt TEXT,
-        full_guess TEXT,
-        idea_guess TEXT,
-        completion TEXT,
-        score_full INTEGER,
-        score_idea INTEGER,
-        round_points INTEGER
-    )
-""")
-conn.commit()
-
-# Load prompts from JSON
-PROMPTS_PATH = "oracle_prompts.json"
-with open(PROMPTS_PATH, 'r') as f:
-    PROMPTS = json.load(f)
-
-# Helper functions
-def get_next_prompt(state):
-    if not state["prompts"]:
-        prompts = PROMPTS.copy()
-        random.shuffle(prompts)
-        state["prompts"] = prompts
-        state["used"] = []
-    prompt = state["prompts"].pop(0)
-    state["used"].append(prompt)
-    state["round"] += 1
-    return prompt
-
-
-def compute_score(guess, completion):
-    if not guess.strip():
-        return 0
-    emb_guess = st_model.encode(guess, convert_to_tensor=True)
-    emb_comp = st_model.encode(completion, convert_to_tensor=True)
-    cos_sim = util.pytorch_cos_sim(emb_guess, emb_comp).item()
-    if cos_sim > 0.9:
-        return 5
-    elif cos_sim > 0.7:
-        return 3
-    elif cos_sim > 0.5:
-        return 1
-    else:
-        return 0
-
-
-def log_round(prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points):
-    ts = datetime.utcnow().isoformat()
-    c.execute(
-        "INSERT INTO rounds (timestamp, prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
-        (ts, prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points)
-    )
-    conn.commit()
-    logger.info(f"Round logged at {ts}")
-
-
-def play_round(full_guess, idea_guess, state):
-    prompt = state.get("current_prompt", "")
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
-    def generate():
-        model.generate(
-            input_ids=input_ids,
-            max_new_tokens=200,
-            do_sample=True,
-            temperature=0.8,
-            streamer=streamer
-        )
-    thread = threading.Thread(target=generate)
-    thread.start()
-    completion = ""
-    for token in streamer:
-        completion += token
-        yield completion, "", ""
-    score_full = compute_score(full_guess, completion)
-    score_idea = compute_score(idea_guess, completion)
-    round_points = score_full + score_idea
-    state["score"] += round_points
-    log_round(prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points)
-    score_text = f"Full Guess: {score_full} pts | Idea Guess: {score_idea} pts | Round Total: {round_points} pts"
-    reflection = "🔮 The Oracle ponders your insights..."
-    if state["round"] >= 5 and state["score"] >= 15:
-        secret = random.choice([p for p in PROMPTS if p not in state["used"]])
-        reflection += f"\n\n✨ **Secret Oracle Prompt:** {secret}"
-    yield completion, score_text, reflection, state["score"]
-
-
-def next_round_fn(state):
-    prompt = get_next_prompt(state)
-    state["current_prompt"] = prompt
-    return prompt, "", "", "", "", "", state["score"]
-
-# Gradio UI
-demo = gr.Blocks()
-with demo:
-    state = gr.State({"prompts": [], "used": [], "round": 0, "score": 0, "current_prompt": ""})
-    gr.Markdown("⚠️ **Your input and the Oracle’s response will be stored for AI training and research. By playing, you consent to this.**")
-    prompt_display = gr.Markdown("", elem_id="prompt_display")
     with gr.Row():
-        full_guess = gr.Textbox(label="🧠 Exact Full Completion Guess")
-        idea_guess = gr.Textbox(label="💡 General Idea Guess")
-    submit = gr.Button("Submit Guess")
-    completion_box = gr.Textbox(label="Oracle's Completion", interactive=False)
-    score_box = gr.Textbox(label="Score", interactive=False)
-    reflection_box = gr.Textbox(label="Mystical Reflection", interactive=False)
-    next_btn = gr.Button("Next Round")
-    total_score_display = gr.Textbox(label="Total Score", interactive=False)
-
-    next_btn.click(next_round_fn, inputs=state, outputs=[prompt_display, full_guess, idea_guess, completion_box, score_box, reflection_box, total_score_display])
-    submit.click(play_round, inputs=[full_guess, idea_guess, state], outputs=[completion_box, score_box, reflection_box, total_score_display])
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+#!/usr/bin/env python3
+"""
+what_comes_next.py – Hugging Face Space implementation of **What Comes Next**
+A slow, contemplative global guessing game.
+
+🔮 HOW IT WORKS 🔮
+• A single Llama-3.1-8B-Instruct model (FP32 on CPU) generates one very long
+  completion for a chosen mystical prompt. It runs continuously in the
+  background for everyone.
+• Any visitor sees the same prompt and the Oracle’s current partial response.
+• Players may submit *one* of two kinds of guesses:
+    1. 🧠 **Exact Completion** – the full sentence/paragraph they think the
+       Oracle will eventually write.
+    2. 💡 **General Idea** – a short summary of the direction or theme they
+       expect.
+• Each guess is recorded immediately (with timestamp, Oracle progress, etc.) to
+  `data.json` (JSON Lines). When the Oracle finally finishes, offline
+  evaluation can score the guesses against the final text.
+
+The game then moves on to the next prompt and the cycle repeats.
+"""
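+
+# For reference, one guess record in data.json (a single JSON object per line)
+# looks roughly like this; the values below are purely illustrative, and
+# record_guess() further down defines the authoritative fields:
+#   {"timestamp": "2025-01-01T12:00:00+00:00", "prompt": "...",
+#    "point-in-time": "2h 14m 9s", "response-point": "...",
+#    "user-guess": "it fades into dawn", "guess-type": "idea"}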
+
 import os
 import json
+import time
 import random
 import threading
 import logging
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Dict, Any
 
 import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+import gradio as gr
+
+###############################################################################
+#                                  Settings                                   #
+###############################################################################
+MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"  # FP32, CPU-only
+PROMPTS_PATH = "oracle_prompts.json"             # 100 unfinished lines
+STATE_PATH = "current_state.json"                # persistent Oracle state
+DATA_PATH = "data.json"                          # JSONL of user guesses
+TOKENS_PER_PROMPT = 2048                         # stop after N tokens
+SECS_BETWEEN_TOKENS = 15                         # pacing (≈8.5 h / prompt)
+TEMPERATURE = 0.8
+TOP_P = 0.95
+MAX_CONTEXT_TOKENS = 8192
+###############################################################################
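+# Pacing arithmetic: TOKENS_PER_PROMPT * SECS_BETWEEN_TOKENS = 2048 * 15 s
+# = 30,720 s, i.e. roughly 8.5 hours of wall-clock time per prompt.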
+
+logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
+log = logging.getLogger("what-comes-next")
+
+lock = threading.Lock()  # global file/variable lock
+
+# --------------------------------------------------------------------------- #
+#                               Helper functions                              #
+# --------------------------------------------------------------------------- #
+
+def _read_json(path: str, default: Any):
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            return json.load(f)
+    except FileNotFoundError:
+        return default
+
+
+def _write_json(path: str, obj: Any):
+    tmp = f"{path}.tmp"
+    with open(tmp, "w", encoding="utf-8") as f:
+        json.dump(obj, f, ensure_ascii=False, indent=2)
+    os.replace(tmp, path)
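+    # Writing to a temp file and then os.replace()-ing it keeps the update
+    # atomic on POSIX, so readers never see a half-written state file.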
 
 
+def load_prompts() -> list[str]:
+    if not os.path.exists(PROMPTS_PATH):
+        raise FileNotFoundError(f"Missing {PROMPTS_PATH}. Please add 100 prompts.")
+    with open(PROMPTS_PATH, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+
+prompts = load_prompts()
+
+# --------------------------------------------------------------------------- #
+#                          Model loading (FP32, CPU)                          #
+# --------------------------------------------------------------------------- #
+log.info("Loading Llama-3.1-8B-Instruct in FP32 on CPU (this is *slow*) …")
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
     torch_dtype=torch.float32,
+    device_map={"": "cpu"},  # force CPU
 )
 model.eval()
+log.info("Model loaded.")
+
+# --------------------------------------------------------------------------- #
+#                           Oracle generation thread                          #
+# --------------------------------------------------------------------------- #
+
+def init_state() -> Dict[str, Any]:
+    """Return existing state or create a new one."""
+    state = _read_json(STATE_PATH, {})
+    if state.get("finished", False):
+        state = {}  # finished, start new prompt
+    if not state:
+        prompt_idx = random.randrange(len(prompts))
+        prompt = prompts[prompt_idx]
+        state = {
+            "prompt_idx": prompt_idx,
+            "prompt": prompt,
+            "generated": "",  # Oracle’s text so far (string)
+            "start_time": time.time(),
+            "finished": False,
+            "tokens_done": 0
+        }
+        _write_json(STATE_PATH, state)
+        log.info(f"Starting new Oracle prompt #{prompt_idx}: {prompt[:60]}…")
+    return state
+
+
+def oracle_loop():
+    """Continuously extend the Oracle’s text by one token every SECS_BETWEEN_TOKENS."""
+    while True:
+        with lock:
+            state = init_state()
+            if state["finished"]:
+                # Should not happen, but guard anyway
+                time.sleep(SECS_BETWEEN_TOKENS)
+                continue
+            prompt_text = state["prompt"]
+            generated_text = state["generated"]
+            tokens_done = state["tokens_done"]
+
+        # Build input_ids (prompt + generated so far)
+        full_input = prompt_text + generated_text
+        input_ids = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=MAX_CONTEXT_TOKENS).input_ids
+
+        # Generate ONE token
+        with torch.no_grad():
+            outputs = model.generate(
+                input_ids,
+                max_new_tokens=1,
+                do_sample=True,
+                temperature=TEMPERATURE,
+                top_p=TOP_P,
+            )
+        next_token_id = outputs[0, -1].unsqueeze(0)
+        next_token_text = tokenizer.decode(next_token_id, skip_special_tokens=True, clean_up_tokenization_spaces=False)
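+        # Note: every iteration re-encodes the full prompt+generated text and
+        # calls generate() without a KV cache, so per-token cost grows with
+        # length; the 15-second pacing dominates, so this stays tolerable.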
+
+        with lock:
+            # Update state
+            state["generated"] += next_token_text
+            state["tokens_done"] += 1
+            if state["tokens_done"] >= TOKENS_PER_PROMPT:
+                state["finished"] = True
+                log.info("Prompt complete. Oracle will pick a new one next cycle.")
+            _write_json(STATE_PATH, state)
+        time.sleep(SECS_BETWEEN_TOKENS)  # pacing
+
+
+threading.Thread(target=oracle_loop, daemon=True).start()
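+# daemon=True ties the Oracle thread's lifetime to the main process; progress
+# is not lost on restart because each token is persisted to STATE_PATH above.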
+
+# --------------------------------------------------------------------------- #
+#                               Gradio Interface                              #
+# --------------------------------------------------------------------------- #
+
+def human_readable_elapsed(start: float) -> str:
+    delta = int(time.time() - start)
+    h, rem = divmod(delta, 3600)
+    m, s = divmod(rem, 60)
+    return f"{h}h {m}m {s}s"
+
+
+def get_current_state() -> Dict[str, Any]:
+    with lock:
+        state = _read_json(STATE_PATH, {})
+        if not state:
+            return {"prompt": "…loading…", "generated": "", "elapsed": "0h 0m 0s"}
+        return {
+            "prompt": state["prompt"],
+            "generated": state["generated"],
+            "elapsed": human_readable_elapsed(state["start_time"])
+        }
+
+
+def record_guess(full_guess: str, idea_guess: str):
+    state = get_current_state()
+    guess_text = full_guess.strip() or idea_guess.strip()
+    if not guess_text:
+        return gr.update(value="⚠️ Please enter a guess in one of the boxes …"), gr.update()
+    guess_type = "full" if full_guess.strip() else "idea"
+    record = {
+        "timestamp": datetime.now(timezone.utc).isoformat(),
+        "prompt": state["prompt"],
+        "point-in-time": state["elapsed"],
+        "response-point": state["generated"],
+        "user-guess": guess_text,
+        "guess-type": guess_type
+    }
+    # Append to JSONL (data.json)
+    with lock:
+        with open(DATA_PATH, "a", encoding="utf-8") as f:
+            f.write(json.dumps(record, ensure_ascii=False) + "\n")
+    log.info(f"Recorded {guess_type} guess ({len(guess_text)} chars).")
+    return gr.update(value="✅ Guess recorded – check back when the Oracle finishes!"), gr.update(value="")
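+
+# Offline evaluation (outside this app) can later replay data.json line by
+# line with json.loads() and score each "user-guess" against the finished
+# text kept in current_state.json.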
+
+
+with gr.Blocks(title="What Comes Next", theme="gradio/soft") as demo:
+    gr.Markdown("""# ✨ What Comes Next
+A global, slow-burn guessing game. The Oracle is continuously writing its story.
+Read the prompt, see the Oracle’s progress, and predict **what comes next**!
+*(FP32 CPU inference – deliberately unhurried.)*""")
+
+    ### Live Oracle view
+    prompt_box = gr.Markdown(label="🔮 Current Oracle Prompt")
+    oracle_box = gr.Textbox(label="📜 Oracle’s current text", lines=10, interactive=False)
+    elapsed_box = gr.Textbox(label="⏱️ Elapsed", interactive=False)
+
+    ### Guess inputs
+    gr.Markdown("**Make your prediction:** Fill **either** the exact continuation *or* a general idea.")
     with gr.Row():
+        full_guess = gr.Textbox(label="🧠 Exact continuation (full)")
+        idea_guess = gr.Textbox(label="💡 General idea")
+    submit_btn = gr.Button("Submit Guess")
+    status_msg = gr.Textbox(label="Status", interactive=False)
+
+    ### Refresh button
+    refresh_btn = gr.Button("🔄 Refresh Oracle progress")
+
+    def refresh():
+        st = get_current_state()
+        return st["prompt"], st["generated"], st["elapsed"]
+
+    refresh_btn.click(refresh, outputs=[prompt_box, oracle_box, elapsed_box])
+    demo.load(refresh, outputs=[prompt_box, oracle_box, elapsed_box])  # auto-load on launch
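+    # The view updates only on page load or when the refresh button is
+    # clicked; there is no background polling, so visitors pull new tokens
+    # manually.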
+
+    submit_btn.click(record_guess,
+                     inputs=[full_guess, idea_guess],
+                     outputs=[status_msg, full_guess])  # clear full_guess box on success
 
 if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)