neuralworm committed
Commit b41e522 · verified · 1 Parent(s): 30f4d64

Update app.py

Files changed (1)
  1. app.py +33 -47
app.py CHANGED
@@ -93,10 +93,10 @@ def initialize_or_load_model_app():
     }
 
     swck_model_global = SWCKModel(**model_args).to(device_global)
-    swck_model_global.debug_prints_enabled = True # Top-level model debug
+    swck_model_global.debug_prints_enabled = True
     if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = True
     for i,block in enumerate(swck_model_global.adaptive_blocks):
-        block.debug_prints_enabled = True # Block-level debug
+        block.debug_prints_enabled = True
         # print(f"App: Debug prints explicitly enabled for AdaptiveBlock {i}")
 
 
@@ -112,18 +112,16 @@ def initialize_or_load_model_app():
 
         if 'word_to_idx' in checkpoint:
             loaded_w2i = checkpoint['word_to_idx']
-            # Basic check, could be more robust
             if isinstance(loaded_w2i, dict) and len(loaded_w2i) > 4:
                 word_to_idx_global = loaded_w2i
                 idx_to_word_global = {v: k for k,v in loaded_w2i.items()}
-                VOCAB_SIZE_APP = len(word_to_idx_global) # Ensure vocab size reflects loaded
+                VOCAB_SIZE_APP = len(word_to_idx_global)
                 print(f"App: Overwrote vocab with checkpoint's vocab. New size: {VOCAB_SIZE_APP}")
             else:
                 print("App: Checkpoint vocab seems invalid, using app's rebuilt vocab.")
         else:
             print("App: word_to_idx not in checkpoint, using app's rebuilt vocab.")
 
-
         model_load_status_global = f"Model loaded successfully from {CHECKPOINT_FILENAME}."
         print(model_load_status_global)
     except Exception as e:
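As an aside on the hunk above: the `len(loaded_w2i) > 4` guard presumably checks that the checkpoint vocab contains more than just the special tokens. A minimal standalone sketch of the override logic, with an invented checkpoint dict:

```python
# Toy sketch of the vocab-override path; the checkpoint contents here are made up,
# and the ">4" threshold is assumed to mean "more than the four special tokens".
checkpoint = {"word_to_idx": {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3, "hello": 4, "world": 5}}

if "word_to_idx" in checkpoint:
    loaded_w2i = checkpoint["word_to_idx"]
    if isinstance(loaded_w2i, dict) and len(loaded_w2i) > 4:
        word_to_idx = loaded_w2i
        idx_to_word = {v: k for k, v in loaded_w2i.items()}  # invert the mapping
        vocab_size = len(word_to_idx)
        print(f"Vocab overridden from checkpoint. New size: {vocab_size}")
```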
@@ -148,11 +146,9 @@ class AppSWCKDataset(Dataset):
         self.seq_len = seq_len
         self.sos_id, self.eos_id, self.pad_id = sos_id, eos_id, pad_id
         self.samples = []
-        # Create overlapping sequences for language modeling
-        # Ensure target is seq_len for consistency with input to model.
-        for i in range(len(token_ids) - seq_len -1): # -1 to ensure target has full seq_len
-            input_seq = [self.sos_id] + token_ids[i : i + seq_len] # length seq_len + 1
-            target_seq = token_ids[i + 1 : i + seq_len + 1] + [self.eos_id] # length seq_len + 1
+        for i in range(len(token_ids) - seq_len -1):
+            input_seq = [self.sos_id] + token_ids[i : i + seq_len]
+            target_seq = token_ids[i + 1 : i + seq_len + 1] + [self.eos_id]
             self.samples.append((input_seq, target_seq))
         print(f"AppSWCKDataset: Created {len(self.samples)} training samples for in-app training.")
 
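A minimal sketch of the windowing this dataset builds, with toy token ids and placeholder SOS/EOS values (not the app's real vocab): each sample pairs a `seq_len + 1` input with a same-length, one-step-shifted target.

```python
# Hedged sketch of AppSWCKDataset's overlapping windows; all ids are invented.
SOS, EOS = 1, 2
token_ids = [10, 11, 12, 13, 14, 15]
seq_len = 3

samples = []
for i in range(len(token_ids) - seq_len - 1):
    input_seq = [SOS] + token_ids[i : i + seq_len]           # length seq_len + 1
    target_seq = token_ids[i + 1 : i + seq_len + 1] + [EOS]  # length seq_len + 1
    samples.append((input_seq, target_seq))

print(samples)
# [([1, 10, 11, 12], [11, 12, 13, 2]), ([1, 11, 12, 13], [12, 13, 14, 2])]
```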
@@ -198,39 +194,35 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
         swck_model_global.set_wiring_phase(epoch < WIRING_PHASE_EPOCHS_APP)
         epoch_loss = 0.0
 
-        # Enable debug for first batch of first epoch
         first_batch_debug = (epoch == 0)
 
         for batch_idx, (src_batch, tgt_batch) in enumerate(app_dataloader):
             if first_batch_debug and batch_idx == 0:
                 swck_model_global.debug_prints_enabled = True
+                if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = True
                 for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = True
-            elif not (first_batch_debug and batch_idx == 0) : # Disable after first batch for speed
+            elif not (first_batch_debug and batch_idx == 0) :
                 swck_model_global.debug_prints_enabled = False
+                if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = False
                 for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
 
 
             src_batch, tgt_batch = src_batch.to(device_global), tgt_batch.to(device_global)
-            decoder_input_tokens = src_batch[:, :-1] # Remove EOS from input
-            gold_standard_for_loss = tgt_batch[:, 1:] # Remove SOS from target
+            decoder_input_tokens = src_batch[:, :-1]
+            gold_standard_for_loss = tgt_batch[:, 1:]
 
             src_key_padding_mask = (decoder_input_tokens == PAD_TOKEN)
 
             optimizer_global.zero_grad()
             logits, entropy_report = swck_model_global(decoder_input_tokens, src_key_padding_mask=src_key_padding_mask)
 
-            # Ensure logits and gold_standard_for_loss are aligned for CrossEntropyLoss
-            # Logits: (B, S_len_in, VocabSize)
-            # Gold: (B, S_len_target)
-            # If S_len_in == S_len_target, it's fine.
             if logits.size(1) != gold_standard_for_loss.size(1):
-                # This can happen if seq len handling differs slightly, adjust shorter one
                 min_len = min(logits.size(1), gold_standard_for_loss.size(1))
-                logits_for_loss = logits[:, :min_len, :].contiguous()
+                logits_for_loss = logits[:, :min_len, :].contiguous() # ADDED .contiguous()
                 gold_for_loss_aligned = gold_standard_for_loss[:, :min_len].contiguous()
             else:
-                logits_for_loss = logits
-                gold_for_loss_aligned = gold_standard_for_loss
+                logits_for_loss = logits.contiguous() # ADDED .contiguous()
+                gold_for_loss_aligned = gold_standard_for_loss.contiguous()
 
             main_loss = criterion_main_app(logits_for_loss.view(-1, logits_for_loss.size(-1)), gold_for_loss_aligned.view(-1))
 
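Why the `.contiguous()` calls matter here: slicing along the sequence dimension can leave a tensor non-contiguous, and `Tensor.view` raises on non-contiguous storage. A self-contained sketch with assumed toy shapes:

```python
import torch
import torch.nn as nn

B, S, V = 2, 4, 10                             # toy batch, seq len, vocab size
logits = torch.randn(B, S + 1, V)[:, :S, :]    # sliced along dim 1 => non-contiguous
gold = torch.randint(0, V, (B, S))

criterion = nn.CrossEntropyLoss()
# CrossEntropyLoss expects (N, V) logits vs (N,) targets, hence the flatten;
# .contiguous() guarantees the .view() succeeds on the sliced tensor.
loss = criterion(logits.contiguous().view(-1, V), gold.contiguous().view(-1))
print(loss.item())
```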
@@ -239,7 +231,7 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
             for i, block_entropy_tensor in enumerate(entropy_report["block_output_entropies"]):
                 target_entropy_val = swck_model_global.seed_parser.get_block_config(i)["target_entropy"]
                 block_entropy_loss += F.mse_loss(block_entropy_tensor, torch.tensor(target_entropy_val, device=device_global))
-            if entropy_report["block_output_entropies"]: # Avoid division by zero
+            if entropy_report["block_output_entropies"]:
                 block_entropy_loss = block_entropy_loss / len(entropy_report["block_output_entropies"])
 
             overall_entropy_loss = entropy_report["overall_output_entropy"]
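For context, the per-block entropy term above is an MSE pull toward each block's seed-derived target. A toy sketch, with invented entropies and targets (the real ones come from `seed_parser.get_block_config`):

```python
import torch
import torch.nn.functional as F

block_entropies = [torch.tensor(1.2), torch.tensor(0.7)]  # stand-ins for entropy_report values
target_entropies = [1.0, 0.5]                             # stand-ins for seed-derived targets

block_entropy_loss = torch.tensor(0.0)
for ent, tgt in zip(block_entropies, target_entropies):
    block_entropy_loss += F.mse_loss(ent, torch.tensor(tgt))
if block_entropies:  # same guard as above: avoids dividing by zero blocks
    block_entropy_loss = block_entropy_loss / len(block_entropies)
print(block_entropy_loss)  # mean squared gap between actual and target entropy
```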
@@ -247,7 +239,7 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
             if entropy_report["block_gate_weights"]:
                 for gates_softmax_tensor in entropy_report["block_gate_weights"]:
                     gate_sparsity_loss += torch.mean(gates_softmax_tensor * torch.log(gates_softmax_tensor + 1e-9))
-            if entropy_report["block_gate_weights"]: # Avoid division by zero
+            if entropy_report["block_gate_weights"]:
                 gate_sparsity_loss = - (gate_sparsity_loss / len(entropy_report["block_gate_weights"]))
 
             combined_loss = (MAIN_LOSS_WEIGHT_APP * main_loss +
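The gate-sparsity term is, up to a 1/num_gates factor, the average entropy of the gate softmax distributions, so minimizing it pushes each block toward one dominant gate. A toy sketch with made-up gate tensors:

```python
import torch

uniform_gates = torch.full((3,), 1 / 3)          # maximally spread
peaked_gates = torch.tensor([0.98, 0.01, 0.01])  # nearly one-hot

def sparsity_term(gate_list):
    loss = torch.tensor(0.0)
    for g in gate_list:
        loss += torch.mean(g * torch.log(g + 1e-9))  # = -entropy(g) / num_gates
    return -(loss / len(gate_list))

print(sparsity_term([uniform_gates]))  # larger (~0.366)
print(sparsity_term([peaked_gates]))   # smaller (~0.037)
```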
@@ -261,12 +253,12 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
             epoch_loss += combined_loss.item()
 
             log_line = f" Epoch {epoch+1}, Batch {batch_idx+1}/{len(app_dataloader)}, Loss: {combined_loss.item():.4f}"
-            if batch_idx % max(1, len(app_dataloader)//2) == 0 or batch_idx == len(app_dataloader)-1 : # Log less frequently to UI
+            if batch_idx % max(1, len(app_dataloader)//5) == 0 or batch_idx == len(app_dataloader)-1 :
                 print(log_line)
                 training_log_output += log_line + "\n"
 
-            # Disable debug prints after the very first batch of the first epoch
             swck_model_global.debug_prints_enabled = False
+            if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = False
             for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
 
 
@@ -275,8 +267,8 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
         print(epoch_summary)
         training_log_output += epoch_summary
 
-    # Ensure debug prints are off after training session
     swck_model_global.debug_prints_enabled = False
+    if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = False
     for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
     swck_model_global.eval()
 
@@ -310,8 +302,10 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
 
     swck_model_global.eval()
     swck_model_global.set_wiring_phase(False)
-    # Temporarily enable debug for generation if needed, then disable
-    # swck_model_global.debug_prints_enabled = True # For generation debug
+
+    # Temporarily re-enable debug for generation if you want to inspect Space logs
+    # swck_model_global.debug_prints_enabled = True
+    # if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = True
     # for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = True
 
     print(f"App: Generating for prompt: '{prompt_str}', max_len: {max_len_gen}, temp: {temperature_gen}")
@@ -321,9 +315,7 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
     debug_info_lines = [f"Prompt tokens: {generated_ids_app}"]
 
     with torch.no_grad():
-        for i in range(int(max_len_gen)): # Ensure max_len_gen is int
-            # Context windowing for input_tensor
-            # Take up to SEQ_LEN_APP tokens from the end of generated_ids_app
+        for i in range(int(max_len_gen)):
            context_start_idx = max(0, len(generated_ids_app) - SEQ_LEN_APP)
            current_context_ids = generated_ids_app[context_start_idx:]
 
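The context windowing kept above simply feeds the model the most recent `SEQ_LEN_APP` ids at each step. A toy illustration (values invented):

```python
# Sketch of the generation-time context window; ids are placeholders.
SEQ_LEN_APP = 4
generated_ids_app = [1, 10, 11, 12, 13, 14]  # SOS + tokens generated so far

context_start_idx = max(0, len(generated_ids_app) - SEQ_LEN_APP)
current_context_ids = generated_ids_app[context_start_idx:]
print(current_context_ids)  # [11, 12, 13, 14] -- only the last SEQ_LEN_APP ids
```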
@@ -360,7 +352,6 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
             else:
                 debug_info_lines.append(f"Gen {i+1}: '{current_word}', OvrlEnt={overall_ent:.3f}, No block entropy/gate report.")
 
-
     generated_text_list = [idx_to_word_global.get(idx, UNK_TOKEN_STR) for idx in generated_ids_app[1:]]
     final_text = " ".join(generated_text_list)
     final_text = final_text.replace(EOS_TOKEN_STR, "").strip()
@@ -370,20 +361,21 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
 
     debug_output_str = "\n".join(debug_info_lines)
 
-    # Disable debug prints after generation
-    # swck_model_global.debug_prints_enabled = False
+    # swck_model_global.debug_prints_enabled = False # Disable after generation
+    # if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = False
     # for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
     return final_text, debug_output_str
 
-# --- Gradio Interface ---
-initial_load_status = initialize_or_load_model_app() # Load model on app startup
+initial_load_status = initialize_or_load_model_app()
 
 with gr.Blocks(title="SWCK Conceptual Demo") as demo:
+    # Using a unique elem_id for the status Markdown
+    model_status_md = gr.Markdown(value=f"**Model Status:** {initial_load_status}", elem_id="model_status_md_123")
+
     gr.Markdown(f"""
     # Self-Wired Conscious Kernel (SWCK) - Conceptual Demo
     This demo showcases a conceptual text generation model.
     Seed Phrase: "{SEED_PHRASE_APP[:100]}..." | Seed Number: "{SEED_NUMBER_STR_APP}".
-    **Model Status:** <span id="model_status_display">{initial_load_status}</span>
     (Note: If checkpoint is not found or fails to load, an *untrained* model is used.)
     """)
 
@@ -404,16 +396,13 @@ with gr.Blocks(title="SWCK Conceptual Demo") as demo:
         with gr.Row():
             train_epochs_slider = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Training Epochs")
             train_batch_size_slider = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="Training Batch Size")
-            # REMOVED format="%.1e"
             train_lr_slider = gr.Slider(minimum=1e-5, maximum=1e-3, value=5e-4, step=1e-5, label="Learning Rate")
 
         start_training_button = gr.Button("Start Short Training Session")
         training_status_output = gr.Textbox(label="Training Log / Status:", lines=10, interactive=False, show_label=True)
 
-
-    model_status_md = gr.Markdown(value=f"**Model Status:** {model_load_status_global}")
-
-    def update_status_text(): # Helper to refresh status after training
+    def update_status_text_for_ui():
+        # This function will be called by .then() to get the new status string
         return f"**Model Status:** {model_load_status_global}"
 
     generate_button.click(
@@ -426,11 +415,8 @@ with gr.Blocks(title="SWCK Conceptual Demo") as demo:
         fn=run_short_training_session,
         inputs=[train_epochs_slider, train_batch_size_slider, train_lr_slider],
         outputs=[training_status_output]
-    ).then(fn=update_status_text, inputs=None, outputs=model_status_md)
+    ).then(fn=update_status_text_for_ui, inputs=None, outputs=model_status_md) # Update the Markdown component
 
 
 if __name__ == "__main__":
-    # The Gradio app launch options (like debug=True) are for local execution.
-    # On Hugging Face Spaces, these are typically controlled by the environment.
-    # The `print()` statements will go to the Space's console logs.
     demo.launch(debug=True)
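The `.click(...).then(...)` chain added above is standard Gradio event chaining: the second callback runs only after the first completes, which is how the status Markdown gets refreshed once training finishes. A minimal self-contained sketch (component and function names here are illustrative, not the app's):

```python
import gradio as gr

state = {"status": "idle"}

def long_job():
    state["status"] = "trained"
    return "training log..."

def refresh_status():
    return f"**Model Status:** {state['status']}"

with gr.Blocks() as demo:
    status_md = gr.Markdown(refresh_status())
    log_box = gr.Textbox(label="Log")
    run_btn = gr.Button("Run")
    # .then() fires after the click handler has finished,
    # so the Markdown shows the post-job status.
    run_btn.click(fn=long_job, inputs=None, outputs=log_box).then(
        fn=refresh_status, inputs=None, outputs=status_md
    )

if __name__ == "__main__":
    demo.launch()
```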