neuralworm committed on
Commit ce4931d · verified · 1 Parent(s): 40376ef

Update app.py

Files changed (1)
  1. app.py +126 -81
app.py CHANGED
@@ -11,10 +11,10 @@ from model import SWCKModel, SeedParser, EntropyEstimator # Assuming model.py is
  # --- Vocabulary and Tokenizer Setup ---
  PAD_TOKEN_STR = "<pad>"; SOS_TOKEN_STR = "<sos>"; EOS_TOKEN_STR = "<eos>"; UNK_TOKEN_STR = "<unk>"
  PAD_TOKEN = 0; SOS_TOKEN = 1; EOS_TOKEN = 2; UNK_TOKEN = 3
- SEQ_LEN_APP = 64 # Max sequence length for training samples in app & generation context
+ SEQ_LEN_APP = 64

  # --- Model Configuration ---
- VOCAB_SIZE_APP = 189 # Placeholder, will be updated by vocab loading/building
+ VOCAB_SIZE_APP = 189
  D_MODEL_APP = 64
  N_HEADS_APP = 2
  D_FF_APP = 128
@@ -38,7 +38,7 @@ This is a stream of consciousness, a digital mindscape.
  The target is not just prediction, but a form of self-understanding, however metaphorical.
  Let the adaptive blocks find their balance. Let the entropy guide the wiring.
  A painter paints. A scientist explores. A writer writes. The machine... becomes.
- """ # Re-added for in-app training data
+ """

  # Global model variables
  swck_model_global = None
@@ -48,14 +48,13 @@ idx_to_word_global = None
  device_global = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model_load_status_global = "Model not loaded."

- CHECKPOINT_FILENAME = "swck_model_conceptual_app.pth.tar" # App specific checkpoint
+ CHECKPOINT_FILENAME = "swck_model_conceptual_app.pth.tar"

- # Loss Weights (should match train.py for consistency if loading that checkpoint)
  MAIN_LOSS_WEIGHT_APP = 1.0
  BLOCK_TARGET_ENTROPY_LOSS_WEIGHT_APP = 0.02
  OVERALL_OUTPUT_ENTROPY_REG_WEIGHT_APP = 0.01
  GATE_SPARSITY_LOSS_WEIGHT_APP = 0.001
- WIRING_PHASE_EPOCHS_APP = 1 # Very short wiring phase for in-app training demo
+ WIRING_PHASE_EPOCHS_APP = 1


  def build_vocab_from_corpus_text_app(corpus_text):
@@ -94,12 +93,11 @@ def initialize_or_load_model_app():
  }

  swck_model_global = SWCKModel(**model_args).to(device_global)
- # Enable all debug prints for console view
- swck_model_global.debug_prints_enabled = True
+ swck_model_global.debug_prints_enabled = True # Top-level model debug
  if hasattr(swck_model_global, 'seed_parser'): swck_model_global.seed_parser.debug_prints_enabled = True
  for i,block in enumerate(swck_model_global.adaptive_blocks):
- block.debug_prints_enabled = True
- print(f"App: Debug prints enabled for AdaptiveBlock {i}")
+ block.debug_prints_enabled = True # Block-level debug
+ # print(f"App: Debug prints explicitly enabled for AdaptiveBlock {i}")


  if os.path.exists(CHECKPOINT_FILENAME):
@@ -108,27 +106,29 @@ def initialize_or_load_model_app():
  checkpoint = torch.load(CHECKPOINT_FILENAME, map_location=device_global)
  swck_model_global.load_state_dict(checkpoint['model_state_dict'])

- # Re-initialize optimizer for the loaded model
- optimizer_global = optim.AdamW(swck_model_global.parameters(), lr=0.001) # Use app's LR
- if 'optimizer_state_dict' in checkpoint: # Load optimizer state if you want to continue training
+ optimizer_global = optim.AdamW(swck_model_global.parameters(), lr=0.001)
+ if 'optimizer_state_dict' in checkpoint:
  optimizer_global.load_state_dict(checkpoint['optimizer_state_dict'])

- # Vocab should ideally be part of checkpoint for consistency, but we rebuilt it
- if 'word_to_idx' in checkpoint: # Overwrite with checkpoint vocab if present
+ if 'word_to_idx' in checkpoint:
  loaded_w2i = checkpoint['word_to_idx']
- if len(loaded_w2i) == VOCAB_SIZE_APP: # Basic sanity check
+ # Basic check, could be more robust
+ if isinstance(loaded_w2i, dict) and len(loaded_w2i) > 4:
  word_to_idx_global = loaded_w2i
  idx_to_word_global = {v: k for k,v in loaded_w2i.items()}
- print("App: Overwrote vocab with checkpoint's vocab.")
+ VOCAB_SIZE_APP = len(word_to_idx_global) # Ensure vocab size reflects loaded
+ print(f"App: Overwrote vocab with checkpoint's vocab. New size: {VOCAB_SIZE_APP}")
  else:
- print("App: Checkpoint vocab size mismatch, using app's rebuilt vocab.")
+ print("App: Checkpoint vocab seems invalid, using app's rebuilt vocab.")
+ else:
+ print("App: word_to_idx not in checkpoint, using app's rebuilt vocab.")
+

  model_load_status_global = f"Model loaded successfully from {CHECKPOINT_FILENAME}."
  print(model_load_status_global)
  except Exception as e:
  print(f"App: Error loading model from checkpoint: {e}. Initializing new model.")
- # Re-initialize model if loading failed to ensure it's fresh
- swck_model_global = SWCKModel(**model_args).to(device_global)
+ swck_model_global = SWCKModel(**model_args).to(device_global) # Re-init
  optimizer_global = optim.AdamW(swck_model_global.parameters(), lr=0.001)
  model_load_status_global = "Error loading checkpoint. Using new (untrained) model."
  else:
@@ -136,11 +136,10 @@ def initialize_or_load_model_app():
  optimizer_global = optim.AdamW(swck_model_global.parameters(), lr=0.001)
  model_load_status_global = "Initialized a new (untrained) model."

- swck_model_global.eval() # Default to eval mode
+ swck_model_global.eval()
  return model_load_status_global


- # --- Dataset for in-app training ---
  class AppSWCKDataset(Dataset):
  def __init__(self, text_corpus_str, w2i_map, seq_len, sos_id, eos_id, pad_id):
  tokens = re.sub(r'\s+', ' ', text_corpus_str.lower()).strip().split()
@@ -149,9 +148,11 @@ class AppSWCKDataset(Dataset):
  self.seq_len = seq_len
  self.sos_id, self.eos_id, self.pad_id = sos_id, eos_id, pad_id
  self.samples = []
- for i in range(len(token_ids) - seq_len):
- input_seq = [self.sos_id] + token_ids[i : i + seq_len]
- target_seq = token_ids[i + 1 : i + seq_len + 1] + [self.eos_id]
+ # Create overlapping sequences for language modeling
+ # Ensure target is seq_len for consistency with input to model.
+ for i in range(len(token_ids) - seq_len - 1): # -1 to ensure target has full seq_len
+ input_seq = [self.sos_id] + token_ids[i : i + seq_len] # length seq_len + 1
+ target_seq = token_ids[i + 1 : i + seq_len + 1] + [self.eos_id] # length seq_len + 1
  self.samples.append((input_seq, target_seq))
  print(f"AppSWCKDataset: Created {len(self.samples)} training samples for in-app training.")

@@ -166,7 +167,6 @@ def app_swck_collate_fn(batch):
  padded_tgt = nn.utils.rnn.pad_sequence(tgt_list, batch_first=True, padding_value=PAD_TOKEN)
  return padded_src, padded_tgt

- # --- In-app Training Function (Simplified) ---
  def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app, progress=gr.Progress(track_tqdm=True)):
  global swck_model_global, optimizer_global, word_to_idx_global, model_load_status_global

@@ -176,56 +176,80 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
  print("\n--- App: Starting Short Training Session ---")
  progress(0, desc="Preparing training data...")

- # Use the extended text for training
  training_corpus = SEED_PHRASE_APP + " " + EXTENDED_TEXT_FOR_TRAINING_APP
  app_dataset = AppSWCKDataset(training_corpus, word_to_idx_global, SEQ_LEN_APP, SOS_TOKEN, EOS_TOKEN, PAD_TOKEN)
  if not app_dataset.samples:
  return "App Training Error: No samples created from the corpus."

- app_dataloader = DataLoader(app_dataset, batch_size=batch_size_app, shuffle=True, collate_fn=app_swck_collate_fn)
+ app_dataloader = DataLoader(app_dataset, batch_size=int(batch_size_app), shuffle=True, collate_fn=app_swck_collate_fn)

- # Re-initialize optimizer or update LR
  if optimizer_global is None:
  optimizer_global = optim.AdamW(swck_model_global.parameters(), lr=learning_rate_app)
- else: # Update LR if optimizer exists
+ else:
  for param_group in optimizer_global.param_groups:
  param_group['lr'] = learning_rate_app

  criterion_main_app = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)

- training_log_output = ""
- swck_model_global.train() # Set model to training mode
+ training_log_output = f"Starting training for {num_epochs_app} epochs...\n"
+ swck_model_global.train()

- for epoch in progress.tqdm(range(num_epochs_app), desc="Training Epochs"):
- swck_model_global.set_wiring_phase(epoch < WIRING_PHASE_EPOCHS_APP) # wiring phase for first few
+ for epoch in progress.tqdm(range(int(num_epochs_app)), desc="Training Epochs"):
+ swck_model_global.set_wiring_phase(epoch < WIRING_PHASE_EPOCHS_APP)
  epoch_loss = 0.0
+
+ # Enable debug for first batch of first epoch
+ first_batch_debug = (epoch == 0)
+
  for batch_idx, (src_batch, tgt_batch) in enumerate(app_dataloader):
+ if first_batch_debug and batch_idx == 0:
+ swck_model_global.debug_prints_enabled = True
+ for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = True
+ elif not (first_batch_debug and batch_idx == 0) : # Disable after first batch for speed
+ swck_model_global.debug_prints_enabled = False
+ for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
+
+
  src_batch, tgt_batch = src_batch.to(device_global), tgt_batch.to(device_global)
- decoder_input_tokens = src_batch
- gold_standard_for_loss = tgt_batch
+ decoder_input_tokens = src_batch[:, :-1] # Remove EOS from input
+ gold_standard_for_loss = tgt_batch[:, 1:] # Remove SOS from target
+
  src_key_padding_mask = (decoder_input_tokens == PAD_TOKEN)

  optimizer_global.zero_grad()
  logits, entropy_report = swck_model_global(decoder_input_tokens, src_key_padding_mask=src_key_padding_mask)
- main_loss = criterion_main_app(logits.view(-1, logits.size(-1)), gold_standard_for_loss.view(-1))
+
+ # Ensure logits and gold_standard_for_loss are aligned for CrossEntropyLoss
+ # Logits: (B, S_len_in, VocabSize)
+ # Gold: (B, S_len_target)
+ # If S_len_in == S_len_target, it's fine.
+ if logits.size(1) != gold_standard_for_loss.size(1):
+ # This can happen if seq len handling differs slightly, adjust shorter one
+ min_len = min(logits.size(1), gold_standard_for_loss.size(1))
+ logits_for_loss = logits[:, :min_len, :].contiguous()
+ gold_for_loss_aligned = gold_standard_for_loss[:, :min_len].contiguous()
+ else:
+ logits_for_loss = logits
+ gold_for_loss_aligned = gold_standard_for_loss
+
+ main_loss = criterion_main_app(logits_for_loss.view(-1, logits_for_loss.size(-1)), gold_for_loss_aligned.view(-1))

  block_entropy_loss = torch.tensor(0.0, device=device_global)
  if entropy_report["block_output_entropies"]:
- for i, block_entropy in enumerate(entropy_report["block_output_entropies"]):
- target_entropy = swck_model_global.seed_parser.get_block_config(i)["target_entropy"]
- block_entropy_loss += F.mse_loss(block_entropy, torch.tensor(target_entropy, device=device_global))
- if entropy_report["block_output_entropies"]:
+ for i, block_entropy_tensor in enumerate(entropy_report["block_output_entropies"]):
+ target_entropy_val = swck_model_global.seed_parser.get_block_config(i)["target_entropy"]
+ block_entropy_loss += F.mse_loss(block_entropy_tensor, torch.tensor(target_entropy_val, device=device_global))
+ if entropy_report["block_output_entropies"]: # Avoid division by zero
  block_entropy_loss = block_entropy_loss / len(entropy_report["block_output_entropies"])

  overall_entropy_loss = entropy_report["overall_output_entropy"]
  gate_sparsity_loss = torch.tensor(0.0, device=device_global)
  if entropy_report["block_gate_weights"]:
- for gates_softmax in entropy_report["block_gate_weights"]:
- gate_sparsity_loss += torch.mean(gates_softmax * torch.log(gates_softmax + 1e-9))
- if entropy_report["block_gate_weights"]:
+ for gates_softmax_tensor in entropy_report["block_gate_weights"]:
+ gate_sparsity_loss += torch.mean(gates_softmax_tensor * torch.log(gates_softmax_tensor + 1e-9))
+ if entropy_report["block_gate_weights"]: # Avoid division by zero
  gate_sparsity_loss = - (gate_sparsity_loss / len(entropy_report["block_gate_weights"]))

-
  combined_loss = (MAIN_LOSS_WEIGHT_APP * main_loss +
  BLOCK_TARGET_ENTROPY_LOSS_WEIGHT_APP * block_entropy_loss +
  OVERALL_OUTPUT_ENTROPY_REG_WEIGHT_APP * overall_entropy_loss +
@@ -236,33 +260,38 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
  optimizer_global.step()
  epoch_loss += combined_loss.item()

- if batch_idx % 1 == 0: # Log every batch for small dataset
- log_line = f" Epoch {epoch+1}, Batch {batch_idx+1}/{len(app_dataloader)}, Loss: {combined_loss.item():.4f}"
- print(log_line) # To Space console logs
- # training_log_output += log_line + "\n" # Accumulate for Gradio output (can get long)
+ log_line = f" Epoch {epoch+1}, Batch {batch_idx+1}/{len(app_dataloader)}, Loss: {combined_loss.item():.4f}"
+ if batch_idx % max(1, len(app_dataloader)//2) == 0 or batch_idx == len(app_dataloader)-1 : # Log less frequently to UI
+ print(log_line)
+ training_log_output += log_line + "\n"
+
+ # Disable debug prints after the very first batch of the first epoch
+ swck_model_global.debug_prints_enabled = False
+ for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
+

- avg_epoch_loss = epoch_loss / len(app_dataloader)
+ avg_epoch_loss = epoch_loss / len(app_dataloader) if len(app_dataloader) > 0 else epoch_loss
  epoch_summary = f"Epoch {epoch+1}/{num_epochs_app} - Avg Loss: {avg_epoch_loss:.4f}\n"
  print(epoch_summary)
  training_log_output += epoch_summary
- # progress.update() # Not needed with track_tqdm
-
- swck_model_global.eval() # Set back to eval mode

- # Save the updated model state
+ # Ensure debug prints are off after training session
+ swck_model_global.debug_prints_enabled = False
+ for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
+ swck_model_global.eval()
+
  try:
  torch.save({
  'model_state_dict': swck_model_global.state_dict(),
- 'optimizer_state_dict': optimizer_global.state_dict(), # Save optimizer too
+ 'optimizer_state_dict': optimizer_global.state_dict(),
  'word_to_idx': word_to_idx_global,
  'idx_to_word': idx_to_word_global,
- # Include other necessary metadata for consistent loading
- 'model_hyperparameters': { # Example of saving model construction args
+ 'model_hyperparameters': {
  'vocab_size': VOCAB_SIZE_APP, 'd_model': D_MODEL_APP, 'n_heads': N_HEADS_APP,
  'd_ff': D_FF_APP, 'num_adaptive_blocks': NUM_ADAPTIVE_BLOCKS_APP, 'dropout': DROPOUT_APP
  }
  }, CHECKPOINT_FILENAME)
- save_msg = f"Training finished. Model checkpoint saved to {CHECKPOINT_FILENAME} in Space."
+ save_msg = f"Training finished. Model checkpoint saved to {CHECKPOINT_FILENAME} in Space's ephemeral storage."
  print(save_msg)
  training_log_output += save_msg
  model_load_status_global = f"Model trained in-app & saved. Last status: {save_msg}"
@@ -274,14 +303,16 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app

  return training_log_output

- # --- Text Generation Function (adapted from train.py) ---
  def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
- global model_load_status_global # To update if model isn't ready
+ global model_load_status_global
  if swck_model_global is None or word_to_idx_global is None or idx_to_word_global is None:
  return "Model not loaded. Please check server logs or try training.", "Model not available."

  swck_model_global.eval()
  swck_model_global.set_wiring_phase(False)
+ # Temporarily enable debug for generation if needed, then disable
+ # swck_model_global.debug_prints_enabled = True # For generation debug
+ # for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = True

  print(f"App: Generating for prompt: '{prompt_str}', max_len: {max_len_gen}, temp: {temperature_gen}")

@@ -290,8 +321,12 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
  debug_info_lines = [f"Prompt tokens: {generated_ids_app}"]

  with torch.no_grad():
- for i in range(max_len_gen):
- current_context_ids = generated_ids_app[-SEQ_LEN_APP:]
+ for i in range(int(max_len_gen)): # Ensure max_len_gen is int
+ # Context windowing for input_tensor
+ # Take up to SEQ_LEN_APP tokens from the end of generated_ids_app
+ context_start_idx = max(0, len(generated_ids_app) - SEQ_LEN_APP)
+ current_context_ids = generated_ids_app[context_start_idx:]
+
  input_tensor = torch.tensor([current_context_ids], dtype=torch.long).to(device_global)
  padding_mask = (input_tensor == PAD_TOKEN)

@@ -302,9 +337,9 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
  next_token_id = torch.argmax(next_token_logits).item()
  else:
  probs = F.softmax(next_token_logits / temperature_gen, dim=-1)
- if probs.isnan().any() or probs.isinf().any() or torch.sum(probs).item() < 1e-9 : # Check for bad probs
+ if probs.isnan().any() or probs.isinf().any() or torch.sum(probs).item() < 1e-9 :
  print(f"Warning: Invalid probabilities at step {i}. Using uniform.")
- probs = torch.ones_like(next_token_logits) / next_token_logits.size(-1) # Fallback
+ probs = torch.ones_like(next_token_logits) / next_token_logits.size(-1)
  next_token_id = torch.multinomial(probs, 1).item()

  if next_token_id == EOS_TOKEN:
@@ -315,12 +350,15 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
  if i < 10 :
  current_word = idx_to_word_global.get(next_token_id, UNK_TOKEN_STR)
  overall_ent = entropy_report_infer['overall_output_entropy'].item()
- if entropy_report_infer['block_output_entropies']: # Check if list is not empty
+ if entropy_report_infer['block_output_entropies'] and len(entropy_report_infer['block_output_entropies']) > 0:
  b0_ent = entropy_report_infer['block_output_entropies'][0].item()
- b0_gates_str = ", ".join([f"{g.item():.2f}" for g in entropy_report_infer['block_gate_weights'][0]])
- debug_info_lines.append(f"Gen {i+1}: '{current_word}', OvrlEnt={overall_ent:.3f}, B0Ent={b0_ent:.3f}, B0Gates=[{b0_gates_str}]")
+ if entropy_report_infer['block_gate_weights'] and len(entropy_report_infer['block_gate_weights']) > 0:
+ b0_gates_str = ", ".join([f"{g.item():.2f}" for g in entropy_report_infer['block_gate_weights'][0]])
+ debug_info_lines.append(f"Gen {i+1}: '{current_word}', OvrlEnt={overall_ent:.3f}, B0Ent={b0_ent:.3f}, B0Gates=[{b0_gates_str}]")
+ else:
+ debug_info_lines.append(f"Gen {i+1}: '{current_word}', OvrlEnt={overall_ent:.3f}, B0Ent={b0_ent:.3f}, No B0 gates.")
  else:
- debug_info_lines.append(f"Gen {i+1}: '{current_word}', OvrlEnt={overall_ent:.3f}, No block entropy report.")
+ debug_info_lines.append(f"Gen {i+1}: '{current_word}', OvrlEnt={overall_ent:.3f}, No block entropy/gate report.")


  generated_text_list = [idx_to_word_global.get(idx, UNK_TOKEN_STR) for idx in generated_ids_app[1:]]
@@ -331,12 +369,14 @@ def generate_text_for_app(prompt_str, max_len_gen, temperature_gen):
  final_text = re.sub(r'\s+', ' ', final_text).strip()

  debug_output_str = "\n".join(debug_info_lines)
+
+ # Disable debug prints after generation
+ # swck_model_global.debug_prints_enabled = False
+ # for blk in swck_model_global.adaptive_blocks: blk.debug_prints_enabled = False
  return final_text, debug_output_str

  # --- Gradio Interface ---
- # Load model on app startup
- initial_load_status = initialize_or_load_model_app()
-
+ initial_load_status = initialize_or_load_model_app() # Load model on app startup

  with gr.Blocks(title="SWCK Conceptual Demo") as demo:
  gr.Markdown(f"""
@@ -364,12 +404,18 @@ with gr.Blocks(title="SWCK Conceptual Demo") as demo:
  with gr.Row():
  train_epochs_slider = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Training Epochs")
  train_batch_size_slider = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="Training Batch Size")
- train_lr_slider = gr.Slider(minimum=1e-5, maximum=1e-3, value=5e-4, step=1e-5, label="Learning Rate", format="%.1e")
+ # REMOVED format="%.1e"
+ train_lr_slider = gr.Slider(minimum=1e-5, maximum=1e-3, value=5e-4, step=1e-5, label="Learning Rate")

  start_training_button = gr.Button("Start Short Training Session")
- training_status_output = gr.Textbox(label="Training Log / Status:", lines=10, interactive=False)
+ training_status_output = gr.Textbox(label="Training Log / Status:", lines=10, interactive=False, show_label=True)
+
+
+ model_status_md = gr.Markdown(value=f"**Model Status:** {model_load_status_global}")
+
+ def update_status_text(): # Helper to refresh status after training
+ return f"**Model Status:** {model_load_status_global}"

- # Define actions
  generate_button.click(
  fn=generate_text_for_app,
  inputs=[prompt_input, max_len_slider, temp_slider],
@@ -380,12 +426,11 @@ with gr.Blocks(title="SWCK Conceptual Demo") as demo:
  fn=run_short_training_session,
  inputs=[train_epochs_slider, train_batch_size_slider, train_lr_slider],
  outputs=[training_status_output]
- ).then(fn=lambda: model_load_status_global, inputs=None, outputs=gr.Markdown(elem_id="model_status_display"))
- # The .then part to update status might need JavaScript if Markdown elem_id doesn't work directly for dynamic updates.
- # For simplicity, the training function itself prints to console and returns a string.
- # A more robust status update would use gr.HTML or JS.
+ ).then(fn=update_status_text, inputs=None, outputs=model_status_md)
+

  if __name__ == "__main__":
- # When running locally, ensure debug=True for Gradio's own debug mode if needed.
- # On Spaces, console logs are primary.
- demo.launch(debug=True) # Enable Gradio debug for local run
+ # The Gradio app launch options (like debug=True) are for local execution.
+ # On Hugging Face Spaces, these are typically controlled by the environment.
+ # The `print()` statements will go to the Space's console logs.
+ demo.launch(debug=True)
 