ProCreations commited on
Commit
a210442
·
verified ·
1 Parent(s): fdc7b1e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ import threading
5
+ import logging
6
+ import sqlite3
7
+ from datetime import datetime
8
+
9
+ import gradio as gr
10
+ import torch
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
12
+ from sentence_transformers import SentenceTransformer, util
13
+
14
+ # Logging setup
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Load Oracle model (FP32, CPU-only)
19
+ logger.info("Loading Oracle model...")
20
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ "meta-llama/Llama-3.1-8B-Instruct",
23
+ torch_dtype=torch.float32,
24
+ device_map="cpu"
25
+ )
26
+ model.eval()
27
+
28
+ # Load SentenceTransformer for semantic similarity
29
+ logger.info("Loading SentenceTransformer model...")
30
+ st_model = SentenceTransformer('all-MiniLM-L6-v2')
31
+
32
+ # Database setup (SQLite)
33
+ DB_PATH = "game_data.db"
34
+ conn = sqlite3.connect(DB_PATH, check_same_thread=False)
35
+ c = conn.cursor()
36
+ c.execute("""
37
+ CREATE TABLE IF NOT EXISTS rounds (
38
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
39
+ timestamp TEXT,
40
+ prompt TEXT,
41
+ full_guess TEXT,
42
+ idea_guess TEXT,
43
+ completion TEXT,
44
+ score_full INTEGER,
45
+ score_idea INTEGER,
46
+ round_points INTEGER
47
+ )
48
+ """)
49
+ conn.commit()
50
+
51
+ # Load prompts from JSON
52
+ PROMPTS_PATH = "oracle_prompts.json"
53
+ with open(PROMPTS_PATH, 'r') as f:
54
+ PROMPTS = json.load(f)
55
+
56
+ # Helper functions
57
+ def get_next_prompt(state):
58
+ if not state["prompts"]:
59
+ prompts = PROMPTS.copy()
60
+ random.shuffle(prompts)
61
+ state["prompts"] = prompts
62
+ state["used"] = []
63
+ prompt = state["prompts"].pop(0)
64
+ state["used"].append(prompt)
65
+ state["round"] += 1
66
+ return prompt
67
+
68
+
69
+ def compute_score(guess, completion):
70
+ if not guess.strip():
71
+ return 0
72
+ emb_guess = st_model.encode(guess, convert_to_tensor=True)
73
+ emb_comp = st_model.encode(completion, convert_to_tensor=True)
74
+ cos_sim = util.pytorch_cos_sim(emb_guess, emb_comp).item()
75
+ if cos_sim > 0.9:
76
+ return 5
77
+ elif cos_sim > 0.7:
78
+ return 3
79
+ elif cos_sim > 0.5:
80
+ return 1
81
+ else:
82
+ return 0
83
+
84
+
85
+ def log_round(prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points):
86
+ ts = datetime.utcnow().isoformat()
87
+ c.execute(
88
+ "INSERT INTO rounds (timestamp, prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
89
+ (ts, prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points)
90
+ )
91
+ conn.commit()
92
+ logger.info(f"Round logged at {ts}")
93
+
94
+
95
+ def play_round(full_guess, idea_guess, state):
96
+ prompt = state.get("current_prompt", "")
97
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
98
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
99
+ def generate():
100
+ model.generate(
101
+ input_ids=input_ids,
102
+ max_new_tokens=200,
103
+ do_sample=True,
104
+ temperature=0.8,
105
+ streamer=streamer
106
+ )
107
+ thread = threading.Thread(target=generate)
108
+ thread.start()
109
+ completion = ""
110
+ for token in streamer:
111
+ completion += token
112
+ yield completion, "", ""
113
+ score_full = compute_score(full_guess, completion)
114
+ score_idea = compute_score(idea_guess, completion)
115
+ round_points = score_full + score_idea
116
+ state["score"] += round_points
117
+ log_round(prompt, full_guess, idea_guess, completion, score_full, score_idea, round_points)
118
+ score_text = f"Full Guess: {score_full} pts | Idea Guess: {score_idea} pts | Round Total: {round_points} pts"
119
+ reflection = "🔮 The Oracle ponders your insights..."
120
+ if state["round"] >= 5 and state["score"] >= 15:
121
+ secret = random.choice([p for p in PROMPTS if p not in state["used"]])
122
+ reflection += f"\n\n✨ **Secret Oracle Prompt:** {secret}"
123
+ yield completion, score_text, reflection, state["score"]
124
+
125
+
126
+ def next_round_fn(state):
127
+ prompt = get_next_prompt(state)
128
+ state["current_prompt"] = prompt
129
+ return prompt, "", "", "", "", "", state["score"]
130
+
131
+ # Gradio UI
132
+ demo = gr.Blocks()
133
+ with demo:
134
+ state = gr.State({"prompts": [], "used": [], "round": 0, "score": 0, "current_prompt": ""})
135
+ gr.Markdown("⚠️ **Your input and the Oracle’s response will be stored for AI training and research. By playing, you consent to this.**")
136
+ prompt_display = gr.Markdown("", elem_id="prompt_display")
137
+ with gr.Row():
138
+ full_guess = gr.Textbox(label="🧠 Exact Full Completion Guess")
139
+ idea_guess = gr.Textbox(label="💡 General Idea Guess")
140
+ submit = gr.Button("Submit Guess")
141
+ completion_box = gr.Textbox(label="Oracle's Completion", interactive=False)
142
+ score_box = gr.Textbox(label="Score", interactive=False)
143
+ reflection_box = gr.Textbox(label="Mystical Reflection", interactive=False)
144
+ next_btn = gr.Button("Next Round")
145
+ total_score_display = gr.Textbox(label="Total Score", interactive=False)
146
+
147
+ next_btn.click(next_round_fn, inputs=state, outputs=[prompt_display, full_guess, idea_guess, completion_box, score_box, reflection_box, total_score_display])
148
+ submit.click(play_round, inputs=[full_guess, idea_guess, state], outputs=[completion_box, score_box, reflection_box, total_score_display])
149
+
150
+ if __name__ == "__main__":
151
+ demo.launch(server_name="0.0.0.0", server_port=7860)