#!/usr/bin/env python3
import os
import json
import time
import random
import threading
import logging
from datetime import datetime, timezone

import torch

# Use all available CPU cores for intra- and inter-op parallelism.
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
PROMPTS_PATH = "full_prompts.json"
STATE_PATH = "current_state.json"
DATA_PATH = "data.json"

TOKENS_PER_PROMPT = 2048   # tokens to generate before a prompt is considered finished
SECS_PER_TOKEN = 15        # pause between single-token generation steps
TEMP = 0.9
TOP_P = 0.95
MAX_CTX = 8192

logging.basicConfig(level=logging.INFO)
log = logging.getLogger()


def _rj(path, default):
    """Read JSON from `path`, returning `default` on any failure."""
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return default


def _aw(path, obj):
    """Atomically write `obj` as JSON to `path` via a temp file."""
    tmp = path + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False, indent=2))
    os.replace(tmp, path)


prompts = _rj(PROMPTS_PATH, [])
if not prompts:
    raise RuntimeError(f"No prompts found in {PROMPTS_PATH}")

hf_token = os.environ.get("HF_READ_TOKEN")

log.info("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=False,
    token=hf_token,
)
model.to("cpu")
model.eval()
log.info("Model is ready.")

lock = threading.Lock()


def _init():
    """Load the current generation state, starting a new prompt if needed."""
    state = _rj(STATE_PATH, {})
    if not state or state.get("finished"):
        idx = random.randrange(len(prompts))
        state = {
            "i": idx,            # prompt index
            "p": prompts[idx],   # prompt text
            "g": "",             # generated continuation so far
            "c": 0,              # tokens generated so far
            "t": time.time(),    # start timestamp
            "finished": False,
        }
        _aw(STATE_PATH, state)
    return state


def _es(start_time):
    """Format the time elapsed since `start_time` as 'Xh Ym Zs'."""
    elapsed = int(time.time() - start_time)
    h, rem = divmod(elapsed, 3600)
    m, s = divmod(rem, 60)
    return f"{h}h {m}m {s}s"


def _loop():
    """Background worker: generate one token every SECS_PER_TOKEN seconds."""
    while True:
        with lock:
            st = _init()
        if st["finished"]:
            time.sleep(SECS_PER_TOKEN)
            continue

        # Re-encode the prompt plus everything generated so far, truncated to the context window.
        context = st["p"] + st["g"]
        ids = tokenizer(
            context, return_tensors="pt", truncation=True, max_length=MAX_CTX
        ).input_ids
        with torch.no_grad():
            out = model.generate(
                ids,
                max_new_tokens=1,
                do_sample=True,
                temperature=TEMP,
                top_p=TOP_P,
            )
        next_token = tokenizer.decode(
            out[0, -1], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        with lock:
            st["g"] += next_token
            st["c"] += 1
            if st["c"] >= TOKENS_PER_PROMPT:
                st["finished"] = True
            _aw(STATE_PATH, st)

        time.sleep(SECS_PER_TOKEN)


threading.Thread(target=_loop, daemon=True).start()


def _fetch():
    """Return the current prompt, the oracle's partial response, and elapsed time."""
    state = _rj(STATE_PATH, {})
    if not state:
        return "...", "", "0h 0m 0s"
    return state["p"], state["g"], _es(state["t"])


def _submit_prediction(detailed, summary):
    """Append the user's prediction, together with the oracle's current state, to DATA_PATH."""
    det = detailed.strip()
    if not det:
        return (
            gr.update(value="Please enter at least a detailed prediction."),
            gr.update(value=""),
            gr.update(value=""),
        )
    prompt_text, oracle_resp, elapsed = _fetch()
    record = {
        "ts": datetime.now(timezone.utc).isoformat(),
        "prompt": prompt_text,
        "time": elapsed,
        "resp": oracle_resp,
        "prediction": det,
        "summary": summary.strip(),
    }
    with lock:
        with open(DATA_PATH, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return gr.update(value="Prediction logged!"), gr.update(value=""), gr.update(value="")


with gr.Blocks(theme="darkdefault") as demo:
    gr.Markdown(
        "# What Comes Next\n"
        "Enter what you think will come next in the text.\n"
        "Provide a detailed continuation and optionally a brief summary for context."
    )
    # Components must be created inside the Row to be laid out in it;
    # merely naming already-created components there is a no-op.
    with gr.Row():
        prompt_md = gr.Markdown()
        oracle_output = gr.Textbox(lines=10, interactive=False, label="Oracle Response")
        time_info = gr.Textbox(interactive=False, label="Elapsed Time")
    detailed = gr.Textbox(
        label="Your Detailed Prediction",
        placeholder="Enter the full text continuation you expect...",
        lines=3,
    )
    summary = gr.Textbox(
        label="Prediction Summary (Optional)",
        placeholder="Optionally, summarize your prediction in a few words...",
        lines=2,
    )
    status = gr.Textbox(interactive=False, label="Status")
    submit_btn = gr.Button("Submit Prediction")
    refresh_btn = gr.Button("Refresh Oracle")

    demo.load(_fetch, outputs=[prompt_md, oracle_output, time_info])
    refresh_btn.click(_fetch, outputs=[prompt_md, oracle_output, time_info])
    submit_btn.click(
        _submit_prediction,
        inputs=[detailed, summary],
        outputs=[status, detailed, summary],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)