neuralworm committed
Commit afb3e05 · verified · 1 parent: d2d8270

Update app.py

Files changed (1): app.py +70 -40
app.py CHANGED
@@ -6,7 +6,7 @@ from torch.utils.data import Dataset, DataLoader
 import os
 import re
 import time
-import torch.nn.functional as F # <<<<<<<<<<<< ADDED THIS IMPORT
+import torch.nn.functional as F
 from model import SWCKModel, SeedParser, EntropyEstimator

 # --- Vocabulary and Tokenizer Setup ---
@@ -57,6 +57,17 @@ OVERALL_OUTPUT_ENTROPY_REG_WEIGHT_APP = 0.01
 GATE_SPARSITY_LOSS_WEIGHT_APP = 0.001
 WIRING_PHASE_EPOCHS_APP = 1

+# --- Helper to toggle all debug prints in the model ---
+def set_model_debug_prints(model, seed_parser_debug, block_debug, model_debug):
+    if model:
+        model.debug_prints_enabled = model_debug
+        if hasattr(model, 'seed_parser'):
+            model.seed_parser.debug_prints_enabled = seed_parser_debug
+        if hasattr(model, 'adaptive_blocks'):
+            for block in model.adaptive_blocks:
+                block.debug_prints_enabled = block_debug
+        print(f"App: Model debug prints set - SeedParser: {seed_parser_debug}, Blocks: {block_debug}, SWCKModel: {model_debug}")
+

 def build_vocab_from_corpus_text_app(corpus_text):
     global VOCAB_SIZE_APP
@@ -74,7 +85,7 @@ def build_vocab_from_corpus_text_app(corpus_text):
     print(f"App: Built vocab of size {VOCAB_SIZE_APP}")
     return temp_word_to_idx, temp_idx_to_word

-def initialize_or_load_model_app():
+def initialize_or_load_model_app(enable_initial_debug=True): # Control initial debug prints
     global swck_model_global, optimizer_global, word_to_idx_global, idx_to_word_global, \
            VOCAB_SIZE_APP, model_load_status_global

@@ -92,13 +103,16 @@ def initialize_or_load_model_app():
         'seed_number_str': SEED_NUMBER_STR_APP,
         'num_sub_modules_per_block': NUM_SUB_MODULES_PER_BLOCK_APP
     }
-
+
+    # Temporarily disable debug during model init to avoid clutter if enable_initial_debug is False
+    # The SeedParser within SWCKModel will print if its own flag is True
+
     swck_model_global = SWCKModel(**model_args).to(device_global)
-    swck_model_global.debug_prints_enabled = True
-    if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = True
-    for i, block in enumerate(swck_model_global.adaptive_blocks):
-        block.debug_prints_enabled = True
-        # print(f"App: Debug prints explicitly enabled for AdaptiveBlock {i}")
+    # Set debug prints AFTER full model initialization
+    set_model_debug_prints(swck_model_global,
+                           seed_parser_debug=enable_initial_debug,
+                           block_debug=enable_initial_debug,
+                           model_debug=enable_initial_debug)

     if os.path.exists(CHECKPOINT_FILENAME):
@@ -126,8 +140,13 @@ def initialize_or_load_model_app():
             model_load_status_global = f"Model loaded successfully from {CHECKPOINT_FILENAME}."
             print(model_load_status_global)
         except Exception as e:
-            print(f"App: Error loading model from checkpoint: {e}. Initializing new model.")
-            swck_model_global = SWCKModel(**model_args).to(device_global) # Re-init
+            print(f"App: Error loading model from checkpoint: {e}. Re-initializing new model.")
+            # Re-initialize if loading failed, ensuring debug flags are set again
+            swck_model_global = SWCKModel(**model_args).to(device_global)
+            set_model_debug_prints(swck_model_global,
+                                   seed_parser_debug=enable_initial_debug,
+                                   block_debug=enable_initial_debug,
+                                   model_debug=enable_initial_debug)
             optimizer_global = optim.AdamW(swck_model_global.parameters(), lr=0.001)
             model_load_status_global = "Error loading checkpoint. Using new (untrained) model."
     else:
@@ -135,7 +154,13 @@ def initialize_or_load_model_app():
         optimizer_global = optim.AdamW(swck_model_global.parameters(), lr=0.001)
         model_load_status_global = "Initialized a new (untrained) model."

-    swck_model_global.eval()
+    swck_model_global.eval() # Default to eval mode
+    # After loading or initializing, ensure debug prints are set based on desire for startup logs
+    # If enable_initial_debug was False, they are off. If True, they were on during init.
+    # For operations like training/generation, we'll toggle them explicitly.
+    if not enable_initial_debug: # Turn them off if they weren't meant to be on for init
+        set_model_debug_prints(swck_model_global, False, False, False)
+
     return model_load_status_global

@@ -195,17 +220,16 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
         swck_model_global.set_wiring_phase(epoch < WIRING_PHASE_EPOCHS_APP)
         epoch_loss = 0.0

-        first_batch_debug = (epoch == 0)
+        # Enable full debug for the first batch of the first "wiring" epoch
+        # This will give detailed insight into the "self-wiring roll" on the first piece of data
+        is_first_wiring_batch = (epoch < WIRING_PHASE_EPOCHS_APP and epoch == 0)

         for batch_idx, (src_batch, tgt_batch) in enumerate(app_dataloader):
-            if first_batch_debug and batch_idx == 0:
-                swck_model_global.debug_prints_enabled = True
-                if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = True
-                for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = True
-            elif not (first_batch_debug and batch_idx == 0) :
-                swck_model_global.debug_prints_enabled = False
-                if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = False
-                for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
+            if is_first_wiring_batch and batch_idx == 0:
+                print(">>> Enabling FULL DEBUG for first wiring batch <<<")
+                set_model_debug_prints(swck_model_global, True, True, True)
+            else: # Otherwise, keep debug prints minimal or off for speed
+                set_model_debug_prints(swck_model_global, False, False, False)

             src_batch, tgt_batch = src_batch.to(device_global), tgt_batch.to(device_global)
@@ -231,7 +255,7 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
             if entropy_report["block_output_entropies"]:
                 for i, block_entropy_tensor in enumerate(entropy_report["block_output_entropies"]):
                     target_entropy_val = swck_model_global.seed_parser.get_block_config(i)["target_entropy"]
-                    block_entropy_loss += F.mse_loss(block_entropy_tensor, torch.tensor(target_entropy_val, device=device_global)) # Used F here
+                    block_entropy_loss += F.mse_loss(block_entropy_tensor, torch.tensor(target_entropy_val, device=device_global))
             if entropy_report["block_output_entropies"]:
                 block_entropy_loss = block_entropy_loss / len(entropy_report["block_output_entropies"])

@@ -254,23 +278,19 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
             epoch_loss += combined_loss.item()

             log_line = f"  Epoch {epoch+1}, Batch {batch_idx+1}/{len(app_dataloader)}, Loss: {combined_loss.item():.4f}"
-            if batch_idx % max(1, len(app_dataloader)//5) == 0 or batch_idx == len(app_dataloader)-1 :
+            if batch_idx % max(1, len(app_dataloader)//2) == 0 or batch_idx == len(app_dataloader)-1 :
                 print(log_line)
                 training_log_output += log_line + "\n"

-        swck_model_global.debug_prints_enabled = False
-        if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = False
-        for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
-
+            # Ensure debug is off after the first special batch
+            set_model_debug_prints(swck_model_global, False, False, False)

         avg_epoch_loss = epoch_loss / len(app_dataloader) if len(app_dataloader) > 0 else epoch_loss
         epoch_summary = f"Epoch {epoch+1}/{num_epochs_app} - Avg Loss: {avg_epoch_loss:.4f}\n"
         print(epoch_summary)
         training_log_output += epoch_summary

-        swck_model_global.debug_prints_enabled = False
-        if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = False
-        for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
+    set_model_debug_prints(swck_model_global, False, False, False) # Ensure off after all training
     swck_model_global.eval()

     try:
@@ -296,7 +316,7 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app

     return training_log_output

-def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
+def generate_text_for_app(prompt_str, max_len_gen, temperature_gen, enable_gen_debug: bool): # Add debug toggle
     global model_load_status_global
     if swck_model_global is None or word_to_idx_global is None or idx_to_word_global is None:
         return "Model not loaded. Please check server logs or try training.", "Model not available."
@@ -304,7 +324,10 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
     swck_model_global.eval()
     swck_model_global.set_wiring_phase(False)

-    print(f"App: Generating for prompt: '{prompt_str}', max_len: {max_len_gen}, temp: {temperature_gen}")
+    # Set debug prints based on UI toggle for this generation call
+    set_model_debug_prints(swck_model_global, enable_gen_debug, enable_gen_debug, enable_gen_debug)
+
+    print(f"App: Generating for prompt: '{prompt_str}', max_len: {max_len_gen}, temp: {temperature_gen}, Debug: {enable_gen_debug}")

     tokens = [SOS_TOKEN] + [word_to_idx_global.get(w, UNK_TOKEN) for w in prompt_str.lower().split()]
     generated_ids_app = list(tokens)
@@ -324,7 +347,7 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
         if temperature_gen == 0:
             next_token_id = torch.argmax(next_token_logits).item()
         else:
-            probs = F.softmax(next_token_logits / temperature_gen, dim=-1) # Used F here
+            probs = F.softmax(next_token_logits / temperature_gen, dim=-1)
             if probs.isnan().any() or probs.isinf().any() or torch.sum(probs).item() < 1e-9 :
                 print(f"Warning: Invalid probabilities at step {i}. Using uniform.")
                 probs = torch.ones_like(next_token_logits) / next_token_logits.size(-1)
@@ -335,7 +358,7 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
             break
         generated_ids_app.append(next_token_id)

-        if i < 10 :
+        if i < 10 : # UI debug info is still limited to first 10 new tokens for brevity
             current_word = idx_to_word_global.get(next_token_id, UNK_TOKEN_STR)
             overall_ent = entropy_report_infer['overall_output_entropy'].item()
             if entropy_report_infer['block_output_entropies'] and len(entropy_report_infer['block_output_entropies']) > 0:
@@ -357,9 +380,12 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):

     debug_output_str = "\n".join(debug_info_lines)

+    # Important: Turn off debug prints after generation if they were turned on
+    set_model_debug_prints(swck_model_global, False, False, False)
     return final_text, debug_output_str

-initial_load_status = initialize_or_load_model_app()
+# Load model once on app startup. Set enable_initial_debug=False for cleaner startup logs.
+initial_load_status = initialize_or_load_model_app(enable_initial_debug=False)

 with gr.Blocks(title="SWCK Conceptual Demo") as demo:
     model_status_md = gr.Markdown(value=f"**Model Status:** {initial_load_status}", elem_id="model_status_md_123")
@@ -375,30 +401,32 @@ with gr.Blocks(title="SWCK Conceptual Demo") as demo:
         with gr.TabItem("Generate Text"):
             with gr.Row():
                 prompt_input = gr.Textbox(label="Enter your prompt:", placeholder="e.g., the meaning of existence is", scale=3)
+                enable_generation_debug_checkbox = gr.Checkbox(label="Enable Full Kernel Debug (to Console Logs)", value=False)
+            with gr.Row():
                 generate_button = gr.Button("Generate", scale=1)
             with gr.Row():
                 max_len_slider = gr.Slider(minimum=10, maximum=150, value=50, step=1, label="Max Generation Length")
                 temp_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.1, label="Temperature (0 for greedy)")

             output_text = gr.Textbox(label="Generated Text:", lines=6, interactive=False)
-            debug_text_area = gr.Textbox(label="Generation Debug Info (first few steps):", lines=8, interactive=False)
+            debug_text_area = gr.Textbox(label="Generation Debug Info (first few steps to UI):", lines=8, interactive=False)

         with gr.TabItem("In-App Training (Conceptual Test)"):
-            gr.Markdown("WARNING: In-app training is EXTREMELY slow and only for basic conceptual testing on Spaces free tier. Uses a small internal corpus. Model state persists only for this session unless saved manually via code modification.")
+            gr.Markdown("WARNING: In-app training is EXTREMELY slow and only for basic conceptual testing on Spaces free tier. Uses a small internal corpus. Model state persists only for this session unless saved manually via code modification. Full Kernel Debug will be printed to console for the FIRST BATCH of the FIRST WIRING EPOCH ONLY.")
             with gr.Row():
-                train_epochs_slider = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Training Epochs")
-                train_batch_size_slider = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="Training Batch Size")
+                train_epochs_slider = gr.Slider(minimum=1, maximum=3, value=1, step=1, label="Number of Training Epochs (1-3 for demo)") # Reduced max
+                train_batch_size_slider = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Training Batch Size (1-4 for demo)") # Reduced max
                 train_lr_slider = gr.Slider(minimum=1e-5, maximum=1e-3, value=5e-4, step=1e-5, label="Learning Rate")

             start_training_button = gr.Button("Start Short Training Session")
-            training_status_output = gr.Textbox(label="Training Log / Status:", lines=10, interactive=False, show_label=True)
+            training_status_output = gr.Textbox(label="Training Log / Status (summary):", lines=10, interactive=False, show_label=True)

     def update_status_text_for_ui():
         return f"**Model Status:** {model_load_status_global}"

     generate_button.click(
         fn=generate_text_for_app,
-        inputs=[prompt_input, max_len_slider, temp_slider],
+        inputs=[prompt_input, max_len_slider, temp_slider, enable_generation_debug_checkbox], # Added checkbox
         outputs=[output_text, debug_text_area]
     )
@@ -410,4 +438,6 @@ with gr.Blocks(title="SWCK Conceptual Demo") as demo:

 if __name__ == "__main__":
+    # For local testing, you can launch with debug=True for Gradio's server debug.
+    # The model's internal debug prints are controlled by set_model_debug_prints().
     demo.launch(debug=True)
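
The net effect of this commit: all model debug output now flows through the new set_model_debug_prints() helper, which is enabled for the first batch of the first wiring epoch during in-app training, enabled per generation call via the new UI checkbox, and switched off everywhere else. A minimal sketch of that intended runtime flow, assuming the globals and functions defined in app.py above (the prompt string and parameter values below are illustrative, not from the commit):

    # Sketch of the new debug flow (illustrative values; assumes app.py's globals).
    initialize_or_load_model_app(enable_initial_debug=False)  # quiet startup, model left in eval mode

    # One verbose generation: the checkbox value arrives as enable_gen_debug, so
    # set_model_debug_prints(model, True, True, True) runs for this call only.
    text, ui_debug = generate_text_for_app(
        "the meaning of existence is",  # prompt_str
        50,                             # max_len_gen
        0.8,                            # temperature_gen
        True,                           # enable_gen_debug (value of the new checkbox)
    )

    # generate_text_for_app() ends by calling set_model_debug_prints(model, False, False, False),
    # so the next call stays quiet unless the checkbox is ticked again.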