neuralworm committed
Commit b8efd7e · Parent: d82b2bb

overhaul by Gemini

Files changed (1): app.py (+9, -12)
app.py CHANGED
@@ -68,7 +68,7 @@ BLOCK_TARGET_ENTROPY_LOSS_WEIGHT_APP = 0.02
 OVERALL_OUTPUT_ENTROPY_REG_WEIGHT_APP = 0.01
 GATE_SPARSITY_LOSS_WEIGHT_APP = 0.001
 GATE_ALIGNMENT_LOSS_WEIGHT_APP = 0.005 # For ObserverTime Sync during wiring phase
-WIRING_PHASE_EPOCHS_APP = 5 # Slightly increased for gate alignment to take effect
+WIRING_PHASE_EPOCHS_APP = 5
 
 def set_model_debug_prints(model, seed_parser_debug, block_debug, model_debug):
     if model:
@@ -228,7 +228,6 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
         swck_model_global.set_wiring_phase(epoch < WIRING_PHASE_EPOCHS_APP)
         epoch_loss = 0.0; print(f"\n>>> EPOCH {epoch+1} <<<")
         for batch_idx, (src_batch, tgt_batch) in enumerate(app_dataloader):
-            # print(f"\n--- Training Batch {batch_idx+1}/{len(app_dataloader)} (Epoch {epoch+1}) ---") # Verbose
             src_batch, tgt_batch = src_batch.to(device_global), tgt_batch.to(device_global)
             src_key_padding_mask = (src_batch == PAD_TOKEN)
             optimizer_global.zero_grad()
@@ -248,11 +247,11 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
             gate_sparsity_loss = torch.tensor(0.0, device=device_global)
             if entropy_report["current_block_gate_softmaxes"]:
                 num_valid_gates_sparsity = 0
-                for gates_tensor in entropy_report["current_block_gate_softmaxes"]: # These are already softmaxed
+                for gates_tensor in entropy_report["current_block_gate_softmaxes"]:
                     if torch.is_tensor(gates_tensor) and gates_tensor.numel() > 0:
-                        gate_sparsity_loss += torch.mean(gates_tensor * torch.log(gates_tensor + 1e-9)) # Negative Entropy
+                        gate_sparsity_loss += torch.mean(gates_tensor * torch.log(gates_tensor + 1e-9))
                         num_valid_gates_sparsity +=1
-                if num_valid_gates_sparsity > 0 : gate_sparsity_loss = -(gate_sparsity_loss / num_valid_gates_sparsity) # Minimize entropy
+                if num_valid_gates_sparsity > 0 : gate_sparsity_loss = -(gate_sparsity_loss / num_valid_gates_sparsity)
 
             gate_alignment_loss = torch.tensor(0.0, device=device_global)
             if entropy_report["current_block_gate_softmaxes"] and entropy_report["initial_block_gate_targets"]:
@@ -265,7 +264,8 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
                         num_valid_align_gates +=1
                 if num_valid_align_gates > 0: gate_alignment_loss /= num_valid_align_gates
 
-            current_gate_alignment_weight = GATE_ALIGNMENT_LOSS_WEIGHT if epoch < WIRING_PHASE_EPOCHS_APP else GATE_ALIGNMENT_LOSS_WEIGHT * 0.1
+            # CORRECTED VARIABLE NAME HERE
+            current_gate_alignment_weight = GATE_ALIGNMENT_LOSS_WEIGHT_APP if epoch < WIRING_PHASE_EPOCHS_APP else GATE_ALIGNMENT_LOSS_WEIGHT_APP * 0.1
 
             combined_loss = (MAIN_LOSS_WEIGHT_APP * main_loss + BLOCK_TARGET_ENTROPY_LOSS_WEIGHT_APP * block_entropy_loss +
                              OVERALL_OUTPUT_ENTROPY_REG_WEIGHT_APP * overall_entropy_loss + GATE_SPARSITY_LOSS_WEIGHT_APP * gate_sparsity_loss +
@@ -285,7 +285,7 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
         'num_adaptive_blocks': len(swck_model_global.adaptive_blocks), 'dropout': current_dropout,
         'seed_phrase': seed_phrase_ui, 'seed_number_str': seed_number_ui,
         'num_sub_modules_per_block': swck_model_global.adaptive_blocks[0].num_sub_modules if swck_model_global.adaptive_blocks else current_num_sub_modules_pb,
-        'seq_len_trained_on': SEQ_LEN_APP # Store the sequence length it was trained with
+        'seq_len_trained_on': SEQ_LEN_APP
     }
     torch.save({'model_state_dict': swck_model_global.state_dict(), 'optimizer_state_dict': optimizer_global.state_dict(),
                 'word_to_idx': word_to_idx_global, 'idx_to_word': idx_to_word_global, 'model_hyperparameters': hyperparams
@@ -312,9 +312,7 @@ def generate_text_for_app(current_interaction_text, max_len_gen, temperature_gen
     newly_generated_tokens_list = []
     with torch.no_grad():
         for i in range(int(max_len_gen)):
-            # print(f"\n--- Gen Step {i+1}/{max_len_gen} ---") # Verbose
             context_for_model = generated_ids_app[-SEQ_LEN_APP:]
-            # print(f" Context for model (len {len(context_for_model)}): {[idx_to_word_global.get(t, UNK_TOKEN_STR) for t in context_for_model[-20:]]}...") # Verbose
             if not context_for_model: print("Warning: Empty context_for_model!"); break
             input_tensor = torch.tensor([context_for_model], dtype=torch.long).to(device_global)
             padding_mask = (input_tensor == PAD_TOKEN)
@@ -344,13 +342,12 @@ def generate_text_for_app(current_interaction_text, max_len_gen, temperature_gen
             generated_ids_app.append(next_token_id)
             current_word = idx_to_word_global.get(next_token_id, UNK_TOKEN_STR)
             newly_generated_tokens_list.append(current_word)
-            # print(f" ==> Generated token {i+1}: '{current_word}' (ID: {next_token_id})") # Verbose
             if i < 10:
                 overall_ent = entropy_report_infer['overall_output_entropy'].item() if torch.is_tensor(entropy_report_infer['overall_output_entropy']) else 0.0
                 b0_ent_str, b0_gates_str = "N/A", "N/A"
                 if entropy_report_infer['block_output_entropies'] and len(entropy_report_infer['block_output_entropies']) > 0 and torch.is_tensor(entropy_report_infer['block_output_entropies'][0]):
                     b0_ent_str = f"{entropy_report_infer['block_output_entropies'][0].item():.3f}"
-                if entropy_report_infer['current_block_gate_softmaxes'] and len(entropy_report_infer['current_block_gate_softmaxes']) > 0 and torch.is_tensor(entropy_report_infer['current_block_gate_softmaxes'][0]): # Use softmaxes for debug
+                if entropy_report_infer['current_block_gate_softmaxes'] and len(entropy_report_infer['current_block_gate_softmaxes']) > 0 and torch.is_tensor(entropy_report_infer['current_block_gate_softmaxes'][0]):
                     b0_gates_str = ", ".join([f"{g.item():.2f}" for g in entropy_report_infer['current_block_gate_softmaxes'][0]])
                 debug_info_lines.append(f"Gen {i+1}: '{current_word}', OvrlEnt={overall_ent:.3f}, B0Ent={b0_ent_str}, B0Gates=[{b0_gates_str}]")
 
@@ -382,7 +379,7 @@ def prepare_model_for_download():
         'num_adaptive_blocks': len(swck_model_global.adaptive_blocks), 'dropout': current_dropout,
         'seed_phrase': swck_model_global.seed_parser.seed_phrase, 'seed_number_str': swck_model_global.seed_parser.seed_number_str,
         'num_sub_modules_per_block': swck_model_global.adaptive_blocks[0].num_sub_modules if swck_model_global.adaptive_blocks else current_num_sub_modules_pb,
-        'seq_len_trained_on': SEQ_LEN_APP # Store SEQ_LEN_APP as it's used for dataset in-app
+        'seq_len_trained_on': SEQ_LEN_APP
     }
     torch.save({'model_state_dict': swck_model_global.state_dict(), 'optimizer_state_dict': optimizer_global.state_dict(),
                 'word_to_idx': word_to_idx_global, 'idx_to_word': idx_to_word_global, 'model_hyperparameters': hyperparams
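
The hunks above mostly strip trailing comments, but two pieces of logic are worth restating. The sparsity term accumulates torch.mean(g * log(g + 1e-9)) over each block's softmaxed gate activations (a negative entropy) and then negates the average, so minimizing it drives the gates toward one-hot selections; and the alignment weight stays at GATE_ALIGNMENT_LOSS_WEIGHT_APP while epoch < WIRING_PHASE_EPOCHS_APP and drops to 10% of it afterwards (the commit fixes the previously misspelled constant name). The sketch below is not app.py's code: it isolates that logic with illustrative names such as gate_softmaxes standing in for entropy_report["current_block_gate_softmaxes"], and keeps tensors on CPU rather than device_global.

import torch

def gate_sparsity_loss(gate_softmaxes, eps=1e-9):
    """Average gate entropy (up to per-gate scaling); minimizing it sharpens the gates."""
    total = torch.tensor(0.0)  # app.py allocates this on device_global; CPU suffices for the sketch
    valid = 0
    for gates in gate_softmaxes:
        if torch.is_tensor(gates) and gates.numel() > 0:
            total = total + torch.mean(gates * torch.log(gates + eps))  # negative entropy of one block's gates
            valid += 1
    return -(total / valid) if valid > 0 else total

def gate_alignment_weight(epoch, base=0.005, wiring_epochs=5):
    """Full weight during the wiring phase, 10% of it once wiring is over."""
    return base if epoch < wiring_epochs else base * 0.1

# Near-uniform gates incur a larger sparsity penalty than sharp (near one-hot) gates.
uniform = torch.full((4,), 0.25)
sharp = torch.tensor([0.94, 0.02, 0.02, 0.02])
print(gate_sparsity_loss([uniform]).item() > gate_sparsity_loss([sharp]).item())  # True
print(gate_alignment_weight(2), gate_alignment_weight(7))  # full weight vs. 10% after wiring

The schedule means the pull toward the seed-derived initial_block_gate_targets dominates only during the wiring phase, while the sparsity pressure applies throughout training.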
 
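The two 'seq_len_trained_on' edits above only drop comments, but the value they store matters at load time: the generation loop feeds the model at most the last SEQ_LEN_APP ids (context_for_model = generated_ids_app[-SEQ_LEN_APP:]), so a saved or downloaded checkpoint records the window its weights were trained with. The following is a minimal sketch of that relationship, assuming an illustrative SEQ_LEN_APP of 64 and placeholder model/optimizer/vocab objects rather than the app's globals.

SEQ_LEN_APP = 64  # illustrative; app.py defines its own value

def checkpoint_payload(model, optimizer, word_to_idx, idx_to_word, hyperparams):
    # Mirrors the torch.save payload in the diff: weights, optimizer state, vocab, hyperparameters.
    return {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'word_to_idx': word_to_idx,
        'idx_to_word': idx_to_word,
        'model_hyperparameters': {**hyperparams, 'seq_len_trained_on': SEQ_LEN_APP},
    }

def context_window(generated_ids, seq_len=SEQ_LEN_APP):
    # The generation loop keeps only the most recent seq_len ids as model input.
    return generated_ids[-seq_len:]

print(len(context_window(list(range(100)))))  # a 100-token history is clipped to 64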