neuralworm committed · verified
Commit 2495f32 · 1 Parent(s): e1c0f4b

Update app.py

Files changed (1):
  1. app.py +47 -43
app.py CHANGED
@@ -49,7 +49,7 @@ idx_to_word_global = None
 device_global = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model_load_status_global = "Model not loaded."
 
-CHECKPOINT_FILENAME = "swck_model_conceptual_app_fulldebug.pth.tar" # New checkpoint name
+CHECKPOINT_FILENAME = "swck_model_conceptual_app_fulldebug.pth.tar"
 
 MAIN_LOSS_WEIGHT_APP = 1.0
 BLOCK_TARGET_ENTROPY_LOSS_WEIGHT_APP = 0.02
@@ -84,7 +84,8 @@ def build_vocab_from_corpus_text_app(corpus_text):
     print(f"App: Built vocab of size {VOCAB_SIZE_APP}")
     return temp_word_to_idx, temp_idx_to_word
 
-def initialize_or_load_model_app():
+# CORRECTED FUNCTION DEFINITION
+def initialize_or_load_model_app(enable_initial_debug=True):
     global swck_model_global, optimizer_global, word_to_idx_global, idx_to_word_global, \
            VOCAB_SIZE_APP, model_load_status_global
 
@@ -103,17 +104,19 @@ def initialize_or_load_model_app():
         'num_sub_modules_per_block': NUM_SUB_MODULES_PER_BLOCK_APP
     }
 
-    print("App: Initializing SWCKModel with FULL DEBUG ON by default for init...")
+    if enable_initial_debug:
+        print("App: Initializing SWCKModel with FULL DEBUG ON by default for init...")
+
+    # Temporarily disable sub-component debug before SWCKModel init if enable_initial_debug is False,
+    # so SWCKModel's own init prints don't get mixed with sub-component init prints prematurely.
+    # SeedParser's internal debug_prints_enabled will control its own prints during its __init__.
+
     swck_model_global = SWCKModel(**model_args).to(device_global)
-    # Debug is on by default in SWCKModel and sub-components as per their class __init__
-    # We can use set_model_debug_prints to confirm or change it if needed later.
-    # For now, rely on their internal defaults being True.
-    # If SeedParser or AdaptiveBlock have their debug_prints_enabled=False by default in model.py,
-    # you would explicitly set them here:
-    if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = True
-    for block in swck_model_global.adaptive_blocks: block.debug_prints_enabled = True
-    swck_model_global.debug_prints_enabled = True
-    print("App: All model component debugs are intended to be ON by default from their init.")
+    # Now set the debug states for all components based on enable_initial_debug
+    set_model_debug_prints(swck_model_global,
+                           seed_parser_debug=enable_initial_debug,
+                           block_debug=enable_initial_debug,
+                           model_debug=enable_initial_debug)
 
 
     if os.path.exists(CHECKPOINT_FILENAME):
@@ -137,21 +140,29 @@ def initialize_or_load_model_app():
                 print("App: Checkpoint vocab seems invalid, using app's rebuilt vocab.")
             else:
                 print("App: word_to_idx not in checkpoint, using app's rebuilt vocab.")
+
+            # Ensure debug states are correctly set after loading
+            set_model_debug_prints(swck_model_global,
+                                   seed_parser_debug=enable_initial_debug,
+                                   block_debug=enable_initial_debug,
+                                   model_debug=enable_initial_debug)
 
             model_load_status_global = f"Model loaded successfully from {CHECKPOINT_FILENAME}."
             print(model_load_status_global)
         except Exception as e:
-            print(f"App: Error loading model from checkpoint: {e}. Re-initializing new model with debug ON.")
+            print(f"App: Error loading model from checkpoint: {e}. Re-initializing new model with debug state: {enable_initial_debug}.")
             swck_model_global = SWCKModel(**model_args).to(device_global)
-            if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = True
-            for block in swck_model_global.adaptive_blocks: block.debug_prints_enabled = True
-            swck_model_global.debug_prints_enabled = True
+            set_model_debug_prints(swck_model_global,
+                                   seed_parser_debug=enable_initial_debug,
+                                   block_debug=enable_initial_debug,
+                                   model_debug=enable_initial_debug)
             optimizer_global = optim.AdamW(swck_model_global.parameters(), lr=0.001)
-            model_load_status_global = "Error loading checkpoint. Using new (untrained) model with debug ON."
+            model_load_status_global = f"Error loading checkpoint. Using new (untrained) model with debug: {enable_initial_debug}."
     else:
-        print(f"App: Checkpoint {CHECKPOINT_FILENAME} not found. Initializing new model with debug ON.")
+        print(f"App: Checkpoint {CHECKPOINT_FILENAME} not found. Initializing new model with debug state: {enable_initial_debug}.")
+        # set_model_debug_prints was already called for a new model above
         optimizer_global = optim.AdamW(swck_model_global.parameters(), lr=0.001)
-        model_load_status_global = "Initialized a new (untrained) model with debug ON."
+        model_load_status_global = f"Initialized a new (untrained) model with debug: {enable_initial_debug}."
 
     swck_model_global.eval()
     return model_load_status_global
@@ -191,13 +202,12 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
     print("\n--- App: Starting Short Training Session (Full Debug ON for ALL batches/epochs) ---")
     progress(0, desc="Preparing training data...")
 
-    # Ensure debug prints are ON for the entire training session
-    set_model_debug_prints(swck_model_global, True, True, True)
+    set_model_debug_prints(swck_model_global, True, True, True) # DEBUG ALWAYS ON FOR TRAINING
 
     training_corpus = SEED_PHRASE_APP + " " + EXTENDED_TEXT_FOR_TRAINING_APP
     app_dataset = AppSWCKDataset(training_corpus, word_to_idx_global, SEQ_LEN_APP, SOS_TOKEN, EOS_TOKEN, PAD_TOKEN)
     if not app_dataset.samples:
-        set_model_debug_prints(swck_model_global, False, False, False) # Turn off if error
+        set_model_debug_prints(swck_model_global, False, False, False)
         return "App Training Error: No samples created from the corpus."
 
     app_dataloader = DataLoader(app_dataset, batch_size=int(batch_size_app), shuffle=True, collate_fn=app_swck_collate_fn)
@@ -219,8 +229,7 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
         print(f"\n>>> EPOCH {epoch+1} - Starting with Full Debug for all batches <<<")
 
         for batch_idx, (src_batch, tgt_batch) in enumerate(app_dataloader):
-            # Debug prints are already set for the whole session by set_model_debug_prints above
-            print(f"\n--- Training Batch {batch_idx+1}/{len(app_dataloader)} ---")
+            print(f"\n--- Training Batch {batch_idx+1}/{len(app_dataloader)} ---") # Explicit batch print
 
             src_batch, tgt_batch = src_batch.to(device_global), tgt_batch.to(device_global)
             decoder_input_tokens = src_batch[:, :-1]
@@ -268,7 +277,6 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
             epoch_loss += combined_loss.item()
 
             log_line = f" Epoch {epoch+1}, Batch {batch_idx+1}/{len(app_dataloader)}, Loss: {combined_loss.item():.4f}"
-            # Print every batch to console due to full debug, but maybe less often to UI
             print(log_line)
             if batch_idx % max(1, len(app_dataloader)//2) == 0 or batch_idx == len(app_dataloader)-1 :
                 training_log_output += log_line + "\n"
@@ -278,8 +286,6 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
         print(epoch_summary)
         training_log_output += epoch_summary
 
-    # Set debug prints OFF after the entire training session for subsequent operations (like generation)
-    # unless generation itself re-enables them.
     print("--- App: Training Session Finished. Setting debug prints OFF by default. ---")
     set_model_debug_prints(swck_model_global, False, False, False)
     swck_model_global.eval()
@@ -307,7 +313,7 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
 
     return training_log_output
 
-def generate_text_for_app(prompt_str, max_len_gen, temperature_gen): # Removed debug toggle, always ON
+def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
     global model_load_status_global
     if swck_model_global is None or word_to_idx_global is None or idx_to_word_global is None:
         return "Model not loaded. Please check server logs or try training.", "Model not available."
@@ -315,19 +321,18 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen): # Removed d
     swck_model_global.eval()
     swck_model_global.set_wiring_phase(False)
 
-    # FULL DEBUG ON for generation
     print("\n--- App: Generating Text (Full Debug ON) ---")
-    set_model_debug_prints(swck_model_global, True, True, True)
+    set_model_debug_prints(swck_model_global, True, True, True) # DEBUG ALWAYS ON FOR GENERATION
 
     print(f"App: Generating for prompt: '{prompt_str}', max_len: {max_len_gen}, temp: {temperature_gen}")
 
     tokens = [SOS_TOKEN] + [word_to_idx_global.get(w, UNK_TOKEN) for w in prompt_str.lower().split()]
    generated_ids_app = list(tokens)
-    debug_info_lines = [f"Prompt tokens: {generated_ids_app}"] # For UI
+    debug_info_lines = [f"Prompt tokens: {generated_ids_app}"]
 
     with torch.no_grad():
         for i in range(int(max_len_gen)):
-            print(f"\n--- Generation Step {i+1} ---") # Console log for each step
+            print(f"\n--- Generation Step {i+1} ---")
             context_start_idx = max(0, len(generated_ids_app) - SEQ_LEN_APP)
             current_context_ids = generated_ids_app[context_start_idx:]
 
@@ -353,9 +358,9 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen): # Removed d
             generated_ids_app.append(next_token_id)
 
             current_word = idx_to_word_global.get(next_token_id, UNK_TOKEN_STR)
-            print(f" ==> Generated token {i+1}: '{current_word}' (ID: {next_token_id})") # Console log
+            print(f" ==> Generated token {i+1}: '{current_word}' (ID: {next_token_id})")
 
-            if i < 10 : # UI debug info is still limited
+            if i < 10 :
                 overall_ent = entropy_report_infer['overall_output_entropy'].item()
                 if entropy_report_infer['block_output_entropies'] and len(entropy_report_infer['block_output_entropies']) > 0:
                     b0_ent = entropy_report_infer['block_output_entropies'][0].item()
@@ -377,12 +382,11 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen): # Removed d
     debug_output_str = "\n".join(debug_info_lines)
 
     print("--- App: Generation Finished. Setting debug prints OFF by default. ---")
-    set_model_debug_prints(swck_model_global, False, False, False) # Turn off after this call
+    set_model_debug_prints(swck_model_global, False, False, False)
     return final_text, debug_output_str
 
-# Initialize model with debug OFF for initial startup to keep logs clean,
-# will be turned ON by training/generation functions.
-initial_load_status = initialize_or_load_model_app()
+# Initialize model. Set enable_initial_debug=True for verbose init logs.
+initial_load_status = initialize_or_load_model_app(enable_initial_debug=True)
 
 with gr.Blocks(title="SWCK Conceptual Demo") as demo:
     model_status_md = gr.Markdown(value=f"**Model Status:** {initial_load_status}", elem_id="model_status_md_123")
@@ -398,9 +402,9 @@ with gr.Blocks(title="SWCK Conceptual Demo") as demo:
         with gr.TabItem("Generate Text"):
             with gr.Row():
                 prompt_input = gr.Textbox(label="Enter your prompt:", placeholder="e.g., the meaning of existence is", scale=3)
-                # Removed debug checkbox as it's on by default for console
+                # Removed debug checkbox from here
             with gr.Row():
-                generate_button = gr.Button("Generate", scale=1)
+                generate_button = gr.Button("Generate (Full Debug to Console)", scale=1) # Updated button label
             with gr.Row():
                 max_len_slider = gr.Slider(minimum=10, maximum=150, value=50, step=1, label="Max Generation Length")
                 temp_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.1, label="Temperature (0 for greedy)")
@@ -422,8 +426,8 @@ with gr.Blocks(title="SWCK Conceptual Demo") as demo:
         return f"**Model Status:** {model_load_status_global}"
 
     generate_button.click(
-        fn=generate_text_for_app, # Removed enable_gen_debug from inputs
-        inputs=[prompt_input, max_len_slider, temp_slider],
+        fn=generate_text_for_app,
+        inputs=[prompt_input, max_len_slider, temp_slider], # Removed checkbox from inputs
         outputs=[output_text, debug_text_area]
     )
@@ -435,4 +439,4 @@ with gr.Blocks(title="SWCK Conceptual Demo") as demo:
 
 
 if __name__ == "__main__":
-    demo.launch(debug=True) # Gradio server debug
+    demo.launch(debug=True)
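
Note: set_model_debug_prints itself is not part of this diff. The following is a minimal sketch of that helper, assuming it simply centralizes the per-attribute toggling that the removed inline code did by hand (setting debug_prints_enabled on the seed parser, each adaptive block, and the model itself). The signature is inferred from the call sites above, which use both the positional form (model, True, True, True) and the keyword form with seed_parser_debug, block_debug, and model_debug; the body is an assumption, not the repository's actual implementation:

def set_model_debug_prints(model, seed_parser_debug, block_debug, model_debug):
    # Hypothetical sketch: flip the debug_prints_enabled flag on the model and
    # its sub-components, mirroring the inline assignments this commit removed.
    if model is None:
        return
    model.debug_prints_enabled = model_debug
    if hasattr(model, 'seed_parser'):
        model.seed_parser.debug_prints_enabled = seed_parser_debug
    for block in getattr(model, 'adaptive_blocks', []):
        block.debug_prints_enabled = block_debug

Centralizing the toggling this way lets initialization, training, and generation switch every component's debug output on at entry and off at exit with a single call each, which is how the new code in this commit uses the helper.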