Commit b8efd7e · overhaul by Gemini
Parent(s): d82b2bb

app.py CHANGED
@@ -68,7 +68,7 @@ BLOCK_TARGET_ENTROPY_LOSS_WEIGHT_APP = 0.02
 OVERALL_OUTPUT_ENTROPY_REG_WEIGHT_APP = 0.01
 GATE_SPARSITY_LOSS_WEIGHT_APP = 0.001
 GATE_ALIGNMENT_LOSS_WEIGHT_APP = 0.005 # For ObserverTime Sync during wiring phase
-WIRING_PHASE_EPOCHS_APP = 5
+WIRING_PHASE_EPOCHS_APP = 5

 def set_model_debug_prints(model, seed_parser_debug, block_debug, model_debug):
     if model:
@@ -228,7 +228,6 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
         swck_model_global.set_wiring_phase(epoch < WIRING_PHASE_EPOCHS_APP)
         epoch_loss = 0.0; print(f"\n>>> EPOCH {epoch+1} <<<")
         for batch_idx, (src_batch, tgt_batch) in enumerate(app_dataloader):
-            # print(f"\n--- Training Batch {batch_idx+1}/{len(app_dataloader)} (Epoch {epoch+1}) ---") # Verbose
             src_batch, tgt_batch = src_batch.to(device_global), tgt_batch.to(device_global)
             src_key_padding_mask = (src_batch == PAD_TOKEN)
             optimizer_global.zero_grad()
@@ -248,11 +247,11 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
             gate_sparsity_loss = torch.tensor(0.0, device=device_global)
             if entropy_report["current_block_gate_softmaxes"]:
                 num_valid_gates_sparsity = 0
-                for gates_tensor in entropy_report["current_block_gate_softmaxes"]:
+                for gates_tensor in entropy_report["current_block_gate_softmaxes"]:
                     if torch.is_tensor(gates_tensor) and gates_tensor.numel() > 0:
-                        gate_sparsity_loss += torch.mean(gates_tensor * torch.log(gates_tensor + 1e-9))
+                        gate_sparsity_loss += torch.mean(gates_tensor * torch.log(gates_tensor + 1e-9))
                         num_valid_gates_sparsity +=1
-                if num_valid_gates_sparsity > 0 : gate_sparsity_loss = -(gate_sparsity_loss / num_valid_gates_sparsity)
+                if num_valid_gates_sparsity > 0 : gate_sparsity_loss = -(gate_sparsity_loss / num_valid_gates_sparsity)

             gate_alignment_loss = torch.tensor(0.0, device=device_global)
             if entropy_report["current_block_gate_softmaxes"] and entropy_report["initial_block_gate_targets"]:
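The hunk above accumulates mean(p · log p) over each block's gate softmax and then flips the sign, so the term it adds to the objective is an entropy-style penalty: evenly spread gates cost more than peaked ones. A minimal standalone sketch of that computation, assuming each entry of the report list is a 1-D tensor of gate probabilities; only the 1e-9 epsilon and the averaging follow the diff, the function name and example tensors are illustrative:

```python
import torch

def gate_sparsity_loss(gate_softmaxes, eps: float = 1e-9) -> torch.Tensor:
    """Entropy-style sparsity penalty over per-block gate softmaxes.

    The loop accumulates mean(p * log p) (a negative quantity) per block;
    the final sign flip turns it into a positive entropy-like value, so
    minimizing the result pushes each block's gates toward a peaked
    (sparse) distribution over its sub-modules.
    """
    loss = torch.tensor(0.0)
    num_valid = 0
    for gates in gate_softmaxes:
        if torch.is_tensor(gates) and gates.numel() > 0:
            loss = loss + torch.mean(gates * torch.log(gates + eps))
            num_valid += 1
    return -(loss / num_valid) if num_valid > 0 else loss

# A peaked gate distribution incurs a smaller penalty than a uniform one.
peaked  = torch.tensor([0.97, 0.01, 0.01, 0.01])
uniform = torch.tensor([0.25, 0.25, 0.25, 0.25])
assert gate_sparsity_loss([peaked]) < gate_sparsity_loss([uniform])
```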
@@ -265,7 +264,8 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
                         num_valid_align_gates +=1
                 if num_valid_align_gates > 0: gate_alignment_loss /= num_valid_align_gates

-
+            # CORRECTED VARIABLE NAME HERE
+            current_gate_alignment_weight = GATE_ALIGNMENT_LOSS_WEIGHT_APP if epoch < WIRING_PHASE_EPOCHS_APP else GATE_ALIGNMENT_LOSS_WEIGHT_APP * 0.1

             combined_loss = (MAIN_LOSS_WEIGHT_APP * main_loss + BLOCK_TARGET_ENTROPY_LOSS_WEIGHT_APP * block_entropy_loss +
                              OVERALL_OUTPUT_ENTROPY_REG_WEIGHT_APP * overall_entropy_loss + GATE_SPARSITY_LOSS_WEIGHT_APP * gate_sparsity_loss +
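The corrected `current_gate_alignment_weight` keeps the gate-alignment term at full strength only during the wiring phase and drops it to 10% afterwards; `combined_loss` then sums the weighted objectives (the alignment term itself is outside the visible part of this hunk). A hedged sketch of that schedule and weighting, using the constants shown in this diff; `MAIN_LOSS_WEIGHT_APP = 1.0` is an assumption, since its value does not appear here:

```python
# Weights as they appear in app.py (MAIN_LOSS_WEIGHT_APP assumed to be 1.0).
MAIN_LOSS_WEIGHT_APP = 1.0
BLOCK_TARGET_ENTROPY_LOSS_WEIGHT_APP = 0.02
OVERALL_OUTPUT_ENTROPY_REG_WEIGHT_APP = 0.01
GATE_SPARSITY_LOSS_WEIGHT_APP = 0.001
GATE_ALIGNMENT_LOSS_WEIGHT_APP = 0.005
WIRING_PHASE_EPOCHS_APP = 5

def alignment_weight(epoch: int) -> float:
    """Full alignment weight during the wiring phase, 10% of it afterwards."""
    return (GATE_ALIGNMENT_LOSS_WEIGHT_APP
            if epoch < WIRING_PHASE_EPOCHS_APP
            else GATE_ALIGNMENT_LOSS_WEIGHT_APP * 0.1)

def combine_losses(main_loss, block_entropy_loss, overall_entropy_loss,
                   gate_sparsity_loss, gate_alignment_loss, epoch: int):
    """Weighted sum of the individual objectives, mirroring combined_loss."""
    return (MAIN_LOSS_WEIGHT_APP * main_loss
            + BLOCK_TARGET_ENTROPY_LOSS_WEIGHT_APP * block_entropy_loss
            + OVERALL_OUTPUT_ENTROPY_REG_WEIGHT_APP * overall_entropy_loss
            + GATE_SPARSITY_LOSS_WEIGHT_APP * gate_sparsity_loss
            + alignment_weight(epoch) * gate_alignment_loss)

print(alignment_weight(0))  # 0.005 during the wiring phase
print(alignment_weight(7))  # 10% of the full weight afterwards
```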
@@ -285,7 +285,7 @@ def run_short_training_session(num_epochs_app, batch_size_app, learning_rate_app
         'num_adaptive_blocks': len(swck_model_global.adaptive_blocks), 'dropout': current_dropout,
         'seed_phrase': seed_phrase_ui, 'seed_number_str': seed_number_ui,
         'num_sub_modules_per_block': swck_model_global.adaptive_blocks[0].num_sub_modules if swck_model_global.adaptive_blocks else current_num_sub_modules_pb,
-        'seq_len_trained_on': SEQ_LEN_APP
+        'seq_len_trained_on': SEQ_LEN_APP
     }
     torch.save({'model_state_dict': swck_model_global.state_dict(), 'optimizer_state_dict': optimizer_global.state_dict(),
                 'word_to_idx': word_to_idx_global, 'idx_to_word': idx_to_word_global, 'model_hyperparameters': hyperparams
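This hunk (and the matching one in `prepare_model_for_download` further down) persists the model together with its vocabulary maps and hyperparameters, now including `seq_len_trained_on`. A minimal sketch of writing and re-reading a checkpoint with this layout; the dictionary keys follow the diff, while the helper names and reload logic are illustrative:

```python
import torch

def save_checkpoint(path, model, optimizer, word_to_idx, idx_to_word, hyperparams):
    """Bundle weights, optimizer state, vocab and hyperparameters in one file."""
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'word_to_idx': word_to_idx,
        'idx_to_word': idx_to_word,
        'model_hyperparameters': hyperparams,  # e.g. includes 'seq_len_trained_on'
    }, path)

def load_checkpoint(path, model, optimizer=None, device='cpu'):
    """Restore weights (and optionally optimizer state); return vocab and hyperparameters."""
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    return ckpt['word_to_idx'], ckpt['idx_to_word'], ckpt['model_hyperparameters']
```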
@@ -312,9 +312,7 @@ def generate_text_for_app(current_interaction_text, max_len_gen, temperature_gen
     newly_generated_tokens_list = []
     with torch.no_grad():
         for i in range(int(max_len_gen)):
-            # print(f"\n--- Gen Step {i+1}/{max_len_gen} ---") # Verbose
             context_for_model = generated_ids_app[-SEQ_LEN_APP:]
-            # print(f" Context for model (len {len(context_for_model)}): {[idx_to_word_global.get(t, UNK_TOKEN_STR) for t in context_for_model[-20:]]}...") # Verbose
             if not context_for_model: print("Warning: Empty context_for_model!"); break
             input_tensor = torch.tensor([context_for_model], dtype=torch.long).to(device_global)
             padding_mask = (input_tensor == PAD_TOKEN)
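During generation the app trims the running id list to the last SEQ_LEN_APP tokens before each forward pass, so the model never sees a context longer than the window it was trained on. A hedged sketch of such a sliding-window sampling loop with temperature; the model call signature and output shape are assumptions (the real SWCK model also returns an entropy report, which this sketch ignores):

```python
import torch
import torch.nn.functional as F

def generate(model, prompt_ids, max_len, seq_len, pad_token, eos_token,
             temperature=1.0, device='cpu'):
    """Autoregressive sampling over a sliding context window of `seq_len` tokens."""
    generated = list(prompt_ids)
    model.eval()
    with torch.no_grad():
        for _ in range(int(max_len)):
            context = generated[-seq_len:]            # keep only the trained-on window
            if not context:
                break
            inp = torch.tensor([context], dtype=torch.long, device=device)
            padding_mask = (inp == pad_token)         # True where the input is padding
            # Assumed interface: logits of shape (batch, seq_len, vocab_size).
            logits = model(inp, src_key_padding_mask=padding_mask)
            next_logits = logits[0, -1, :] / max(temperature, 1e-6)
            probs = F.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_id)
            if next_id == eos_token:
                break
    return generated
```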
@@ -344,13 +342,12 @@ def generate_text_for_app(current_interaction_text, max_len_gen, temperature_gen
             generated_ids_app.append(next_token_id)
             current_word = idx_to_word_global.get(next_token_id, UNK_TOKEN_STR)
             newly_generated_tokens_list.append(current_word)
-            # print(f" ==> Generated token {i+1}: '{current_word}' (ID: {next_token_id})") # Verbose
             if i < 10:
                 overall_ent = entropy_report_infer['overall_output_entropy'].item() if torch.is_tensor(entropy_report_infer['overall_output_entropy']) else 0.0
                 b0_ent_str, b0_gates_str = "N/A", "N/A"
                 if entropy_report_infer['block_output_entropies'] and len(entropy_report_infer['block_output_entropies']) > 0 and torch.is_tensor(entropy_report_infer['block_output_entropies'][0]):
                     b0_ent_str = f"{entropy_report_infer['block_output_entropies'][0].item():.3f}"
-                if entropy_report_infer['current_block_gate_softmaxes'] and len(entropy_report_infer['current_block_gate_softmaxes']) > 0 and torch.is_tensor(entropy_report_infer['current_block_gate_softmaxes'][0]):
+                if entropy_report_infer['current_block_gate_softmaxes'] and len(entropy_report_infer['current_block_gate_softmaxes']) > 0 and torch.is_tensor(entropy_report_infer['current_block_gate_softmaxes'][0]):
                     b0_gates_str = ", ".join([f"{g.item():.2f}" for g in entropy_report_infer['current_block_gate_softmaxes'][0]])
                 debug_info_lines.append(f"Gen {i+1}: '{current_word}', OvrlEnt={overall_ent:.3f}, B0Ent={b0_ent_str}, B0Gates=[{b0_gates_str}]")

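For the first ten generated tokens the app condenses the per-step entropy report into a single debug line. A small sketch of that formatting with a hand-built report dictionary; the keys match the diff, the numeric values are made up for illustration:

```python
import torch

def format_debug_line(step, word, report):
    """One-line summary of overall entropy, block-0 entropy and block-0 gates."""
    overall = report['overall_output_entropy'].item() if torch.is_tensor(report['overall_output_entropy']) else 0.0
    b0_ent, b0_gates = "N/A", "N/A"
    ents = report['block_output_entropies']
    if ents and torch.is_tensor(ents[0]):
        b0_ent = f"{ents[0].item():.3f}"
    gates = report['current_block_gate_softmaxes']
    if gates and torch.is_tensor(gates[0]):
        b0_gates = ", ".join(f"{g.item():.2f}" for g in gates[0])
    return f"Gen {step}: '{word}', OvrlEnt={overall:.3f}, B0Ent={b0_ent}, B0Gates=[{b0_gates}]"

# Illustrative values only:
report = {
    'overall_output_entropy': torch.tensor(2.314),
    'block_output_entropies': [torch.tensor(1.871)],
    'current_block_gate_softmaxes': [torch.tensor([0.70, 0.20, 0.10])],
}
print(format_debug_line(1, "kernel", report))
```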
@@ -382,7 +379,7 @@ def prepare_model_for_download():
         'num_adaptive_blocks': len(swck_model_global.adaptive_blocks), 'dropout': current_dropout,
         'seed_phrase': swck_model_global.seed_parser.seed_phrase, 'seed_number_str': swck_model_global.seed_parser.seed_number_str,
         'num_sub_modules_per_block': swck_model_global.adaptive_blocks[0].num_sub_modules if swck_model_global.adaptive_blocks else current_num_sub_modules_pb,
-        'seq_len_trained_on': SEQ_LEN_APP
+        'seq_len_trained_on': SEQ_LEN_APP
     }
     torch.save({'model_state_dict': swck_model_global.state_dict(), 'optimizer_state_dict': optimizer_global.state_dict(),
                 'word_to_idx': word_to_idx_global, 'idx_to_word': idx_to_word_global, 'model_hyperparameters': hyperparams