neuralworm committed
Commit 026247e · verified · 1 parent: 144d8b4

Update app.py

Files changed (1): app.py (+41 -37)
app.py CHANGED
@@ -57,6 +57,17 @@ OVERALL_OUTPUT_ENTROPY_REG_WEIGHT_APP = 0.01
 GATE_SPARSITY_LOSS_WEIGHT_APP = 0.001
 WIRING_PHASE_EPOCHS_APP = 1
 
+def set_model_debug_prints(model, seed_parser_debug, block_debug, model_debug):
+    if model:
+        model.debug_prints_enabled = model_debug
+        if hasattr(model, 'seed_parser'):
+            model.seed_parser.debug_prints_enabled = seed_parser_debug
+        if hasattr(model, 'adaptive_blocks'):
+            for block_component in model.adaptive_blocks:  # Renamed to avoid conflict
+                block_component.debug_prints_enabled = block_debug
+    print(f"App: Model debug prints set - SeedParser: {seed_parser_debug}, Blocks: {block_debug}, SWCKModel: {model_debug}")
+
+
 def build_vocab_from_corpus_text_app(corpus_text):
     global VOCAB_SIZE_APP
     print("App: Building vocabulary...")
@@ -73,7 +84,8 @@ def build_vocab_from_corpus_text_app(corpus_text):
     print(f"App: Built vocab of size {VOCAB_SIZE_APP}")
     return temp_word_to_idx, temp_idx_to_word
 
-def initialize_or_load_model_app():
+# CORRECTED FUNCTION DEFINITION: Added enable_initial_debug parameter
+def initialize_or_load_model_app(enable_initial_debug=True):
     global swck_model_global, optimizer_global, word_to_idx_global, idx_to_word_global, \
         VOCAB_SIZE_APP, model_load_status_global
 
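For illustration, both call styles the corrected signature allows; only the debug-on call actually appears in this commit, the quiet variant is hypothetical:

# Debug-on start-up, as wired at the bottom of this commit:
initial_load_status = initialize_or_load_model_app(enable_initial_debug=True)

# Hypothetical quiet start-up via the same parameter:
initial_load_status = initialize_or_load_model_app(enable_initial_debug=False)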
@@ -92,19 +104,14 @@ def initialize_or_load_model_app():
         'num_sub_modules_per_block': NUM_SUB_MODULES_PER_BLOCK_APP
     }
 
-    print("App: Initializing SWCKModel. Debug prints are ON by default in model components.")
+    if enable_initial_debug:  # This print will now work correctly
+        print("App: Initializing SWCKModel with FULL DEBUG ON by default for init...")
 
     swck_model_global = SWCKModel(**model_args).to(device_global)
-    # Ensure debug flags are True on all components after initialization
-    # (assuming model.py might have them False by default, this makes them True)
-    if swck_model_global:
-        swck_model_global.debug_prints_enabled = True
-        if hasattr(swck_model_global, 'seed_parser'):
-            swck_model_global.seed_parser.debug_prints_enabled = True
-        if hasattr(swck_model_global, 'adaptive_blocks'):
-            for block in swck_model_global.adaptive_blocks:
-                block.debug_prints_enabled = True
-        print("App: Confirmed debug prints ON for SWCKModel and its components.")
+    set_model_debug_prints(swck_model_global,
+                           seed_parser_debug=enable_initial_debug,
+                           block_debug=enable_initial_debug,
+                           model_debug=enable_initial_debug)
 
 
     if os.path.exists(CHECKPOINT_FILENAME):
@@ -129,30 +136,26 @@ def initialize_or_load_model_app():
             else:
                 print("App: word_to_idx not in checkpoint, using app's rebuilt vocab.")
 
-            # After loading, ensure debug flags are still True
-            if swck_model_global:
-                swck_model_global.debug_prints_enabled = True
-                if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = True
-                for block in swck_model_global.adaptive_blocks: block.debug_prints_enabled = True
-                print("App: Re-confirmed debug prints ON after loading checkpoint.")
-
+            set_model_debug_prints(swck_model_global,
+                                   seed_parser_debug=enable_initial_debug,
+                                   block_debug=enable_initial_debug,
+                                   model_debug=enable_initial_debug)
 
             model_load_status_global = f"Model loaded successfully from {CHECKPOINT_FILENAME}."
             print(model_load_status_global)
         except Exception as e:
             print(f"App: Error loading model from checkpoint: {e}. Re-initializing new model.")
             swck_model_global = SWCKModel(**model_args).to(device_global)
-            if swck_model_global:  # Ensure debug is on for the new instance too
-                swck_model_global.debug_prints_enabled = True
-                if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = True
-                for block in swck_model_global.adaptive_blocks: block.debug_prints_enabled = True
+            set_model_debug_prints(swck_model_global,
+                                   seed_parser_debug=enable_initial_debug,
+                                   block_debug=enable_initial_debug,
+                                   model_debug=enable_initial_debug)
             optimizer_global = optim.AdamW(swck_model_global.parameters(), lr=0.001)
-            model_load_status_global = "Error loading checkpoint. Using new (untrained) model."
+            model_load_status_global = f"Error loading checkpoint. Using new (untrained) model with debug: {enable_initial_debug}."
     else:
-        print(f"App: Checkpoint {CHECKPOINT_FILENAME} not found. Initializing new model.")
-        # Debug flags already set for a new model instance above
+        print(f"App: Checkpoint {CHECKPOINT_FILENAME} not found. Initializing new model with debug state: {enable_initial_debug}.")
         optimizer_global = optim.AdamW(swck_model_global.parameters(), lr=0.001)
-        model_load_status_global = "Initialized a new (untrained) model."
+        model_load_status_global = f"Initialized a new (untrained) model with debug: {enable_initial_debug}."
 
     swck_model_global.eval()
     return model_load_status_global
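
The hunk above follows the usual load-or-reinitialize pattern: try to restore a checkpoint, fall back to a fresh model on any failure, and report which path was taken. A condensed, self-contained sketch of that control flow, with nn.Linear and "demo.pt" as placeholders for SWCKModel and CHECKPOINT_FILENAME, and a hypothetical "model_state_dict" checkpoint key:

import os
import torch
import torch.nn as nn
import torch.optim as optim

CKPT = "demo.pt"  # placeholder for CHECKPOINT_FILENAME

model = nn.Linear(4, 4)  # placeholder for SWCKModel(**model_args)
if os.path.exists(CKPT):
    try:
        state = torch.load(CKPT, map_location="cpu")
        model.load_state_dict(state["model_state_dict"])  # hypothetical key
        status = f"Model loaded successfully from {CKPT}."
    except Exception as e:
        model = nn.Linear(4, 4)  # fall back to a fresh, untrained instance
        status = f"Error loading checkpoint ({e}). Using new (untrained) model."
else:
    status = "Initialized a new (untrained) model."

optimizer = optim.AdamW(model.parameters(), lr=0.001)
model.eval()
print(status)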
@@ -192,11 +195,13 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
     print("\n--- App: Starting Short Training Session (Full Debug ON for ALL batches/epochs by default) ---")
     progress(0, desc="Preparing training data...")
 
-    # Model debug flags are assumed to be already ON from initialize_or_load_model_app()
+    # Ensure debug prints are ON for the entire training session
+    set_model_debug_prints(swck_model_global, True, True, True)
 
     training_corpus = SEED_PHRASE_APP + " " + EXTENDED_TEXT_FOR_TRAINING_APP
     app_dataset = AppSWCKDataset(training_corpus, word_to_idx_global, SEQ_LEN_APP, SOS_TOKEN, EOS_TOKEN, PAD_TOKEN)
     if not app_dataset.samples:
+        set_model_debug_prints(swck_model_global, False, False, False)  # Turn off if error before training starts
         return "App Training Error: No samples created from the corpus."
 
     app_dataloader = DataLoader(app_dataset, batch_size=int(batch_size_app), shuffle=True, collate_fn=app_swck_collate_fn)
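
The error path above now has to remember to flip the flags back off before the early return. An alternative this commit does not take is a context manager that restores the flags automatically on any exit; a sketch assuming the set_model_debug_prints helper defined earlier in this diff:

from contextlib import contextmanager

@contextmanager
def model_debug(model, enabled=True):
    # Enable all debug flags on entry; force them off on exit,
    # even if the wrapped block raises or returns early.
    set_model_debug_prints(model, enabled, enabled, enabled)
    try:
        yield model
    finally:
        set_model_debug_prints(model, False, False, False)

# Usage sketch: flags are cleaned up even on the "no samples" early return.
# with model_debug(swck_model_global):
#     ...build dataset and train...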
@@ -215,10 +220,9 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
     for epoch in progress.tqdm(range(int(num_epochs_app)), desc="Training Epochs"):
         swck_model_global.set_wiring_phase(epoch < WIRING_PHASE_EPOCHS_APP)
         epoch_loss = 0.0
-        # No need to toggle debug here; it's globally on for the model instance
+        print(f"\n>>> EPOCH {epoch+1} - Starting with Full Debug for all batches <<<")
 
         for batch_idx, (src_batch, tgt_batch) in enumerate(app_dataloader):
-            # Print statements within model.py's forward methods will now trigger automatically
             print(f"\n--- Training Batch {batch_idx+1}/{len(app_dataloader)} (Epoch {epoch+1}) ---")
 
             src_batch, tgt_batch = src_batch.to(device_global), tgt_batch.to(device_global)
@@ -267,17 +271,18 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
             epoch_loss += combined_loss.item()
 
             log_line = f" Epoch {epoch+1}, Batch {batch_idx+1}/{len(app_dataloader)}, Loss: {combined_loss.item():.4f}"
-            print(log_line)  # This will go to console
+            print(log_line)
             if batch_idx % max(1, len(app_dataloader)//2) == 0 or batch_idx == len(app_dataloader)-1:
-                training_log_output += log_line + "\n"  # Summary to UI
+                training_log_output += log_line + "\n"
 
         avg_epoch_loss = epoch_loss / len(app_dataloader) if len(app_dataloader) > 0 else epoch_loss
         epoch_summary = f"Epoch {epoch+1}/{num_epochs_app} - Avg Loss: {avg_epoch_loss:.4f}\n"
         print(epoch_summary)
         training_log_output += epoch_summary
 
+    # After training, leave debug ON as per request for "default ON" for the app instance.
+    # If you wanted it off after training, you'd call set_model_debug_prints(..., False, False, False)
     print("--- App: Training Session Finished. Debug prints remain ON for the model instance. ---")
-    # No need to turn off debugs here if they are meant to be globally on for the app session
    swck_model_global.eval()
 
     try:
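
The modulo gate above prints every batch loss to the console but forwards only a few lines to the UI log. A quick illustration of which batch indices pass the gate, for a hypothetical 10-batch dataloader:

n_batches = 10  # hypothetical len(app_dataloader)
ui_logged = [b for b in range(n_batches)
             if b % max(1, n_batches // 2) == 0 or b == n_batches - 1]
print(ui_logged)  # [0, 5, 9]: first, middle, and last batch reach the UI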
@@ -311,7 +316,7 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
     swck_model_global.eval()
     swck_model_global.set_wiring_phase(False)
 
-    # Model debug flags are assumed to be already ON globally from initialize_or_load_model_app()
+    # Debug is assumed to be ON from initialization for the model instance
     print("\n--- App: Generating Text (Full Debug ON by default) ---")
     print(f"App: Generating for prompt: '{prompt_str}', max_len: {max_len_gen}, temp: {temperature_gen}")
 
@@ -321,7 +326,6 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
 
     with torch.no_grad():
         for i in range(int(max_len_gen)):
-            # Print statements inside SWCKModel's forward and AdaptiveBlock's forward will trigger
             print(f"\n--- Generation Step {i+1} ---")
             context_start_idx = max(0, len(generated_ids_app) - SEQ_LEN_APP)
             current_context_ids = generated_ids_app[context_start_idx:]
@@ -371,11 +375,11 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
 
     debug_output_str = "\n".join(debug_info_lines)
 
-    # Debug flags remain ON for the model instance for subsequent calls unless changed elsewhere
     print("--- App: Generation Finished. Debug prints remain ON for the model instance. ---")
+    # No need to turn off debugs if they are globally ON for the app session
     return final_text, debug_output_str
 
-# Initialize model with debug ON by default for the whole app session
+# Initialize model with debug ON by default for the entire app session
 initial_load_status = initialize_or_load_model_app(enable_initial_debug=True)
 
 with gr.Blocks(title="SWCK Conceptual Demo") as demo:
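
Because the flag handling is now centralized in set_model_debug_prints, a runtime toggle would be a small addition to the UI. A hypothetical extension, not part of this commit, assuming the globals defined above:

# Hypothetical debug toggle inside a Gradio layout:
with gr.Blocks(title="SWCK Conceptual Demo") as demo_with_toggle:
    debug_box = gr.Checkbox(value=True, label="Model debug prints")
    debug_box.change(
        fn=lambda on: set_model_debug_prints(swck_model_global, on, on, on),
        inputs=debug_box,
    )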
 