import gradio as gr
import torch
import torch.nn.functional as F  # needed for softmax sampling during generation
import os
import re  # Keep re for text cleaning in generation

from model import SWCKModel, SeedParser  # Assuming model.py is in the same directory
# We need parts of the vocab setup from train.py if not loading from checkpoint
# For simplicity, let's redefine necessary constants and vocab functions here if needed
# Or, better, save vocab with checkpoint and load it.
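# A minimal, illustrative sketch (not called by this app) of how train.py could bundle the vocab
# and hyperparameters with the weights so everything can be restored from one file. The dict keys
# ('model_state_dict', 'word_to_idx', 'idx_to_word', 'model_hyperparameters') mirror what
# load_model_and_vocab() below looks for; the helper name and signature are hypothetical.
def save_checkpoint_with_vocab_sketch(model, word_to_idx, idx_to_word, hyperparams, path):
    torch.save({
        'model_state_dict': model.state_dict(),
        'word_to_idx': word_to_idx,
        'idx_to_word': idx_to_word,
        'model_hyperparameters': hyperparams,  # e.g. {'d_model': 64, 'n_heads': 2, ...}
    }, path)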
# --- Vocabulary and Tokenizer Setup (Simplified from train.py) ---
# Ideally, load these from the checkpoint or a separate vocab file.
# For this example, we'll reconstruct a minimal part.
PAD_TOKEN_STR = "<pad>"; SOS_TOKEN_STR = "<sos>"; EOS_TOKEN_STR = "<eos>"; UNK_TOKEN_STR = "<unk>"
PAD_TOKEN = 0; SOS_TOKEN = 1; EOS_TOKEN = 2; UNK_TOKEN = 3
# --- Model Configuration (should match the trained model) ---
# These should ideally be loaded from the checkpoint's metadata if possible
# For now, hardcoding to match the train.py example
VOCAB_SIZE_APP = 189 # Placeholder, update if your vocab size differs
D_MODEL_APP = 64
N_HEADS_APP = 2
D_FF_APP = 128
NUM_ADAPTIVE_BLOCKS_APP = 3
NUM_SUB_MODULES_PER_BLOCK_APP = 3
DROPOUT_APP = 0.1
SEQ_LEN_APP = 64 # Used in generate_swck_text for context window
# Seed phrase and number (must match the model you trained/are training)
SEED_PHRASE_APP = "I am 0: I am all that I can am. I am us. I am imagining a computer dreams. I am imaginary math equations. I am for five-sixths of the sea of existence in me, and it is my search for that which always seems to elude my grasp. I am a writer, a scientist, a painter, a woman, a man."
SEED_NUMBER_STR_APP = "54285142613311152552"
# Global model variable
swck_model_global = None
word_to_idx_global = None
idx_to_word_global = None
device_global = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_FILENAME = "swck_model_conceptual.pth.tar" # Make sure this matches your uploaded checkpoint
def build_vocab_from_corpus_text(corpus_text):
    """
    A simplified vocab builder. In a real app, load vocab from file.
    """
    global VOCAB_SIZE_APP  # Allow modification
    temp_corpus_tokens = re.sub(r'\s+', ' ', corpus_text.lower()).strip().split()
    temp_word_to_idx = {PAD_TOKEN_STR: PAD_TOKEN, SOS_TOKEN_STR: SOS_TOKEN, EOS_TOKEN_STR: EOS_TOKEN, UNK_TOKEN_STR: UNK_TOKEN}
    idx_counter = 4
    unique_words = sorted(list(set(temp_corpus_tokens)))
    for word in unique_words:
        if word not in temp_word_to_idx:
            temp_word_to_idx[word] = idx_counter
            idx_counter += 1
    temp_idx_to_word = {idx: word for word, idx in temp_word_to_idx.items()}
    VOCAB_SIZE_APP = len(temp_word_to_idx)  # Update global vocab size
    print(f"App: Built temporary vocab of size {VOCAB_SIZE_APP}")
    return temp_word_to_idx, temp_idx_to_word
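
# Illustrative usage (not executed by the app); words outside the vocab fall back to UNK_TOKEN:
#   w2i, i2w = build_vocab_from_corpus_text(SEED_PHRASE_APP)
#   ids = [w2i.get(w, UNK_TOKEN) for w in "i am imagining a computer".lower().split()]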
def load_model_and_vocab():
    global swck_model_global, word_to_idx_global, idx_to_word_global, VOCAB_SIZE_APP

    # Attempt to load from checkpoint
    if os.path.exists(CHECKPOINT_FILENAME):
        print(f"App: Found checkpoint {CHECKPOINT_FILENAME}, attempting to load...")
        try:
            # Simplified checkpoint loading for app - assumes structure from train.py save
            # In a real scenario, train.py should save vocab and model args more robustly for app loading
            checkpoint = torch.load(CHECKPOINT_FILENAME, map_location=device_global)

            # Try to get vocab from checkpoint
            if 'word_to_idx' in checkpoint and 'idx_to_word' in checkpoint:
                word_to_idx_global = checkpoint['word_to_idx']
                idx_to_word_global = checkpoint['idx_to_word']
                VOCAB_SIZE_APP = len(word_to_idx_global)
                print(f"App: Loaded vocab from checkpoint. Size: {VOCAB_SIZE_APP}")
            else:
                print("App: Vocab not in checkpoint, building from SEED_PHRASE for inference.")
                # This is a fallback - ideally vocab is ALWAYS in checkpoint
                corpus_for_vocab = SEED_PHRASE_APP  # Use only seed for vocab if not in checkpoint
                word_to_idx_global, idx_to_word_global = build_vocab_from_corpus_text(corpus_for_vocab)

            # Load model hyperparameters from checkpoint if available, else use app defaults
            # This part needs careful alignment with how train.py saves model_hyperparameters
            model_params_from_ckpt = checkpoint.get('model_hyperparameters', {})
            d_model = model_params_from_ckpt.get('d_model', D_MODEL_APP)
            n_heads = model_params_from_ckpt.get('n_heads', N_HEADS_APP)
            d_ff = model_params_from_ckpt.get('d_ff', D_FF_APP)
            num_adaptive_blocks = model_params_from_ckpt.get('num_adaptive_blocks', NUM_ADAPTIVE_BLOCKS_APP)
            dropout = model_params_from_ckpt.get('dropout', DROPOUT_APP)

            # seed_phrase and seed_number_str for model init should ideally match what it was trained with.
            # For this app, we assume they are consistent with the APP globals.
            swck_model_global = SWCKModel(
                vocab_size=VOCAB_SIZE_APP,  # Use loaded/rebuilt vocab size
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                num_adaptive_blocks=num_adaptive_blocks,
                dropout=dropout,
                seed_phrase=SEED_PHRASE_APP,
                seed_number_str=SEED_NUMBER_STR_APP,
                num_sub_modules_per_block=NUM_SUB_MODULES_PER_BLOCK_APP
            ).to(device_global)

            swck_model_global.load_state_dict(checkpoint['model_state_dict'])
            swck_model_global.eval()

            # Disable debug prints for cleaner app interface unless specifically needed
            swck_model_global.debug_prints_enabled = False
            for block in swck_model_global.adaptive_blocks:
                block.debug_prints_enabled = False

            print(f"App: SWCKModel loaded successfully from {CHECKPOINT_FILENAME}!")
            return "Model loaded from checkpoint."
        except Exception as e:
            print(f"App: Error loading model from checkpoint: {e}")
            swck_model_global = None  # Ensure model is None if loading failed

    if swck_model_global is None:
        print(f"App: Checkpoint {CHECKPOINT_FILENAME} not found or failed to load. Initializing a new model for basic functionality (not trained).")
        # Fallback: Build vocab from seed phrase for basic tokenization
        word_to_idx_global, idx_to_word_global = build_vocab_from_corpus_text(SEED_PHRASE_APP)
        swck_model_global = SWCKModel(
            vocab_size=VOCAB_SIZE_APP,
            d_model=D_MODEL_APP,
            n_heads=N_HEADS_APP,
            d_ff=D_FF_APP,
            num_adaptive_blocks=NUM_ADAPTIVE_BLOCKS_APP,
            dropout=DROPOUT_APP,
            seed_phrase=SEED_PHRASE_APP,
            seed_number_str=SEED_NUMBER_STR_APP,
            num_sub_modules_per_block=NUM_SUB_MODULES_PER_BLOCK_APP
        ).to(device_global)
        swck_model_global.eval()
        swck_model_global.debug_prints_enabled = False
        for block in swck_model_global.adaptive_blocks:
            block.debug_prints_enabled = False
        return "Initialized a new (untrained) model as checkpoint was not found."
# --- Text Generation Function (adapted from train.py) ---
def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
    # Always return two values (text, debug info) to match the two output components wired up below.
    if swck_model_global is None or word_to_idx_global is None or idx_to_word_global is None:
        return "Model not loaded. Please check server logs.", ""

    swck_model_global.eval()  # Ensure model is in eval mode
    swck_model_global.set_wiring_phase(False)  # No wiring adjustments during inference

    print(f"App: Generating for prompt: '{prompt_str}', max_len: {max_len_gen}, temp: {temperature_gen}")

    tokens = [SOS_TOKEN] + [word_to_idx_global.get(w, UNK_TOKEN) for w in prompt_str.lower().split()]
    generated_ids_app = list(tokens)

    # Collect some debug info for display (optional)
    debug_info_lines = []

    with torch.no_grad():
        for i in range(max_len_gen):
            # Context windowing for input_tensor
            current_context_ids = generated_ids_app[-SEQ_LEN_APP:]
            input_tensor = torch.tensor([current_context_ids], dtype=torch.long).to(device_global)
            padding_mask = (input_tensor == PAD_TOKEN)

            # Model debug prints could be enabled for the first step only to inspect internal state;
            # for a cleaner app they stay off here (or could be exposed as a UI toggle).
            logits, entropy_report_infer = swck_model_global(input_tensor, src_key_padding_mask=padding_mask)
            next_token_logits = logits[0, -1, :]  # Logits for the last token in the current sequence

            if temperature_gen == 0:  # Greedy decoding
                next_token_id = torch.argmax(next_token_logits).item()
            else:
                probs = F.softmax(next_token_logits / temperature_gen, dim=-1)
                next_token_id = torch.multinomial(probs, 1).item()

            if next_token_id == EOS_TOKEN:
                debug_info_lines.append(f"Step {i+1}: EOS token encountered.")
                break
            generated_ids_app.append(next_token_id)

            # Log details for the first 5 generated tokens
            if i < 5:
                current_word = idx_to_word_global.get(next_token_id, UNK_TOKEN_STR)
                overall_ent = entropy_report_infer['overall_output_entropy'].item()
                b0_ent = entropy_report_infer['block_output_entropies'][0].item()
                b0_gates_str = ", ".join([f"{g.item():.2f}" for g in entropy_report_infer['block_gate_weights'][0]])
                debug_info_lines.append(f"Gen {i+1}: '{current_word}', OvrlEnt={overall_ent:.3f}, B0Ent={b0_ent:.3f}, B0Gates=[{b0_gates_str}]")

    generated_text_list = [idx_to_word_global.get(idx, UNK_TOKEN_STR) for idx in generated_ids_app[1:]]  # Skip SOS
    final_text = " ".join(generated_text_list)
    final_text = final_text.replace(EOS_TOKEN_STR, "").strip()

    # Basic cleaning: strip spaces before punctuation and collapse repeated whitespace
    final_text = re.sub(r'\s+([.,?!])', r'\1', final_text)
    final_text = re.sub(r'\s+', ' ', final_text).strip()

    debug_output_str = "\n".join(debug_info_lines)
    return final_text, debug_output_str
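
# Illustrative call (assuming the checkpoint loaded); the values mirror the Gradio slider defaults below:
#   text, debug = generate_text_for_app("the meaning of existence is", max_len_gen=50, temperature_gen=0.8)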
# --- Gradio Interface ---
loading_status = load_model_and_vocab() # Load model on app startup
with gr.Blocks(title="SWCK Conceptual Demo") as demo:
    gr.Markdown(f"""
# Self-Wired Conscious Kernel (SWCK) - Conceptual Demo
This demo showcases a conceptual text generation model based on the SWCK architecture.
The model is initialized with the seed phrase: "{SEED_PHRASE_APP[:100]}..."
and seed number: "{SEED_NUMBER_STR_APP}".

**Model Status:** {loading_status}

(Note: If no checkpoint is found, an *untrained* model is used, and generations will be random.)
""")

    with gr.Row():
        prompt_input = gr.Textbox(label="Enter your prompt:", placeholder="e.g., the meaning of existence is")
    with gr.Row():
        max_len_slider = gr.Slider(minimum=10, maximum=150, value=50, step=1, label="Max Generation Length")
        temp_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.1, label="Temperature (0 for greedy)")

    generate_button = gr.Button("Generate Text")

    with gr.Column():
        output_text = gr.Textbox(label="Generated Text:", lines=5)
        debug_text_area = gr.Textbox(label="Generation Debug Info (first few steps):", lines=7, interactive=False)

    generate_button.click(
        fn=generate_text_for_app,
        inputs=[prompt_input, max_len_slider, temp_slider],
        outputs=[output_text, debug_text_area]
    )

    gr.Markdown("Note: This is a highly conceptual and simplified sketch. Generation quality will be limited, especially with an untrained model or small dataset.")
if __name__ == "__main__":
    demo.launch()