# SWCK / model.py (V5)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import hashlib
# --- Future Entropy Predictor (FEP) ---
# (No changes from V4)
class FutureEntropyPredictor(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=16, output_dim=1, name=""):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.name = name
        self.debug_prints_enabled = False

    def forward(self, current_block_entropy, current_static_target_diff):
        if not torch.is_tensor(current_block_entropy):
            current_block_entropy = torch.tensor([current_block_entropy], device=self.fc1.weight.device, dtype=torch.float32)
        if not torch.is_tensor(current_static_target_diff):
            current_static_target_diff = torch.tensor([current_static_target_diff], device=self.fc1.weight.device, dtype=torch.float32)
        current_block_entropy = current_block_entropy.view(-1, 1)
        current_static_target_diff = current_static_target_diff.view(-1, 1)
        x_in = torch.cat((current_block_entropy, current_static_target_diff), dim=1)
        h = F.relu(self.fc1(x_in))
        predicted_delta_factor_raw = self.fc2(h)
        return predicted_delta_factor_raw.squeeze(-1)
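
# A minimal usage sketch (illustrative only, not called by the model itself):
# the FEP maps a block's measured entropy and its deviation from the static
# seed target to a raw delta factor; downstream code squashes it with tanh.
# The input values below are assumptions chosen for demonstration.
def _demo_fep():
    fep = FutureEntropyPredictor(input_dim=2, hidden_dim=16, output_dim=1, name="demo")
    entropy = torch.tensor([0.30])  # assumed current block entropy
    diff = torch.tensor([0.05])     # assumed entropy-minus-target difference
    return torch.tanh(fep(entropy, diff))  # shape (1,), values in (-1, 1)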
# --- Helper: Entropy Estimator ---
# (No changes from V4)
class EntropyEstimator(nn.Module):
    def __init__(self, d_model, hidden_dim=32, name=""):
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.name = name
        self.debug_prints_enabled = False

    def forward(self, x, active_mask=None):
        if x.numel() == 0: return torch.tensor(0.0, device=x.device)
        if active_mask is not None:
            if active_mask.dtype != torch.bool: active_mask = active_mask.bool()
            if x.dim() == 3 and active_mask.dim() == 2 and x.shape[:2] == active_mask.shape: x_masked = x[active_mask]
            elif x.dim() == 2 and active_mask.dim() == 1 and x.shape[0] == active_mask.shape[0]: x_masked = x[active_mask]
            else: x_masked = x.reshape(-1, x.size(-1))
        else: x_masked = x.reshape(-1, x.size(-1))
        if x_masked.numel() == 0: return torch.tensor(0.0, device=x.device)
        h = F.relu(self.fc1(x_masked)); return torch.sigmoid(self.fc2(h)).mean()
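
# A minimal usage sketch (shapes below are illustrative assumptions): the
# estimator maps hidden states to a scalar in (0, 1); a boolean active_mask
# (True = keep) restricts the mean to non-padded positions.
def _demo_entropy_estimator():
    est = EntropyEstimator(d_model=32, name="demo")
    x = torch.randn(2, 5, 32)                    # (batch, seq, d_model)
    active = torch.ones(2, 5, dtype=torch.bool)  # True marks active positions
    active[1, 3:] = False                        # pretend sample 1 ends in two pad tokens
    return est(x, active_mask=active)            # 0-dim tensor in (0, 1)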
# --- Helper: Seed Parser ---
# (No changes from V4)
class SeedParser:
    def __init__(self, seed_phrase, seed_number_str, d_model, num_adaptive_blocks, num_sub_modules_per_block):
        self.seed_phrase = seed_phrase; self.seed_number_str = seed_number_str; self.d_model = d_model
        self.num_adaptive_blocks = num_adaptive_blocks; self.num_sub_modules_per_block = num_sub_modules_per_block
        self.debug_prints_enabled = True
        if self.debug_prints_enabled: print(f"--- SeedParser Initialization ---\n Seed Phrase (start): '{self.seed_phrase[:50]}...'\n Seed Number: {self.seed_number_str}")
        phrase_hash = hashlib.sha256(seed_phrase.encode()).hexdigest(); self.phrase_base_val = int(phrase_hash[:16], 16)
        if self.debug_prints_enabled: print(f" Phrase Base Value (from hash): {self.phrase_base_val}")
        self.num_sequence = [int(d) for d in seed_number_str if d.isdigit()]
        if not self.num_sequence: self.num_sequence = [sum(bytearray(seed_number_str.encode())) % 10]
        if self.debug_prints_enabled: print(f" Numerical Sequence (from seed number): {self.num_sequence}")
        self.init_map = self._generate_init_map()
        if self.debug_prints_enabled:
            print(f" SeedParser: Generated InitMap:")
            for i, block_config in enumerate(self.init_map["block_configs"]):
                gate_inits_str = [f'{g:.3f}' for g in block_config['initial_gate_proportions']]
                raw_gate_scores_str = [f'{g:.3f}' for g in block_config['raw_gate_scores_for_param_init']]
                print(f"   Block {i}: Target Entropy: {block_config['target_entropy']:.4f}, RawGateScores: {raw_gate_scores_str}, InitialGateProps (softmax): {gate_inits_str}")
        if self.debug_prints_enabled: print(f"--- SeedParser Initialized ---")

    def _get_deterministic_value(self, key_name, min_val, max_val, sequence_idx_offset=0):  # ... (same as V4)
        key_specific_hash = int(hashlib.sha256(key_name.encode() + self.seed_phrase.encode()).hexdigest()[:8], 16); num_seq_val = 0
        if self.num_sequence:
            for digit in self.num_sequence: num_seq_val = (num_seq_val * 10 + digit) % 1000003
        combined_seed_val = self.phrase_base_val + key_specific_hash + num_seq_val + sequence_idx_offset
        if max_val == min_val: return min_val
        val_range = max_val - min_val + 1
        return min_val + int(abs(math.sin(float(combined_seed_val)) * 1e5)) % int(val_range)

    def _get_deterministic_float(self, key_name, min_val=0.0, max_val=1.0, sequence_idx_offset=0):  # ... (same as V4)
        key_specific_hash = int(hashlib.sha256(key_name.encode() + self.seed_phrase.encode()).hexdigest()[:8], 16); num_seq_val = 0
        if self.num_sequence:
            for digit in self.num_sequence: num_seq_val = (num_seq_val * 10 + digit) % 1000003
        combined_seed_val = self.phrase_base_val + key_specific_hash + num_seq_val + sequence_idx_offset
        norm_float = (math.sin(float(combined_seed_val) * 0.1) + 1.0) / 2.0
        return min_val + norm_float * (max_val - min_val)

    def _generate_init_map(self):  # ... (same as V4; note initial_gate_proportions are softmax-based)
        init_map = {"block_configs": []}
        for i in range(self.num_adaptive_blocks):
            gate_raw_scores = [self._get_deterministic_float(f"block_{i}_gate_{j}_raw_score", -1.5, 1.5, sequence_idx_offset=i * 10 + j) for j in range(self.num_sub_modules_per_block)]
            gate_initial_proportions = F.softmax(torch.tensor(gate_raw_scores), dim=0).tolist() if self.num_sub_modules_per_block > 0 else []
            target_entropy = self._get_deterministic_float(f"block_{i}_target_entropy", 0.15, 0.45, sequence_idx_offset=i)
            init_map["block_configs"].append({"initial_gate_proportions": gate_initial_proportions, "raw_gate_scores_for_param_init": gate_raw_scores, "target_entropy": target_entropy})
        return init_map

    def get_block_config(self, block_idx):  # ... (same as V4)
        if 0 <= block_idx < len(self.init_map["block_configs"]): return self.init_map["block_configs"][block_idx]
        return None
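
# A minimal usage sketch (the seed strings and sizes below are assumptions for
# illustration): the parser is fully deterministic, so identical seeds always
# reproduce the same per-block target entropies and raw gate scores.
def _demo_seed_parser():
    parser = SeedParser("example seed phrase", "12345", d_model=64,
                        num_adaptive_blocks=2, num_sub_modules_per_block=3)
    cfg = parser.get_block_config(0)
    return cfg["target_entropy"], cfg["raw_gate_scores_for_param_init"]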
# --- Adaptive Block (V5 changes) ---
class AdaptiveBlock(nn.Module):
    MAX_DYNAMIC_ENTROPY_ADJUSTMENT_RANGE = 0.05
    INITIAL_HEURISTIC_STRENGTH = 0.025  # V5: starting strength for the wiring heuristic
    FINAL_HEURISTIC_STRENGTH = 0.005    # V5: ending strength for the wiring heuristic

    def __init__(self, d_model, n_heads, d_ff, dropout, seed_parser_config_for_block, block_idx, num_sub_modules=3):
        super().__init__()
        self.d_model = d_model; self.block_idx = block_idx; self.num_sub_modules = num_sub_modules
        self.config_from_seed = seed_parser_config_for_block; self.debug_prints_enabled = True
        raw_gate_param_inits_list = self.config_from_seed.get("raw_gate_scores_for_param_init", [0.0] * self.num_sub_modules)
        if len(raw_gate_param_inits_list) != self.num_sub_modules:
            raw_gate_param_inits_list = [0.0] * self.num_sub_modules
        self.gates_params = nn.Parameter(torch.tensor(raw_gate_param_inits_list, dtype=torch.float32))
        # V5: store the initial raw scores as a buffer for the gate-alignment loss
        self.register_buffer('initial_raw_gate_scores_buffer', torch.tensor(raw_gate_param_inits_list, dtype=torch.float32))
        if self.debug_prints_enabled:
            raw_gate_scores_str = [f'{g:.3f}' for g in raw_gate_param_inits_list]
            print(f" Initializing AdaptiveBlock {self.block_idx} with seed config: StaticSeedTgtEnt={self.config_from_seed['target_entropy']:.3f}, InitialRawGateScores={raw_gate_scores_str}")
        self.sub_module_0 = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.sub_module_1 = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model))
        self.sub_module_2 = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Dropout(dropout))
        self.sub_modules = nn.ModuleList([self.sub_module_0, self.sub_module_1, self.sub_module_2])
        if self.num_sub_modules > len(self.sub_modules): self.num_sub_modules = len(self.sub_modules)
        elif self.num_sub_modules <= 0: raise ValueError(f"AdaptiveBlock {self.block_idx} must have at least one sub_module.")
        self.norm1 = nn.LayerNorm(d_model); self.norm2 = nn.LayerNorm(d_model)
        self.dropout_layer = nn.Dropout(dropout)  # V5: renamed from self.dropout to avoid a name clash
        self.output_entropy_estimator = EntropyEstimator(d_model, name=f"Block{block_idx}_OutEntropy")
        self.fep = FutureEntropyPredictor(input_dim=2, hidden_dim=16, output_dim=1, name=f"Block{block_idx}_FEP")
        self.wiring_phase_active = False
        self.static_seed_target_entropy = self.config_from_seed.get("target_entropy", 0.25)
        self.current_epoch_in_wiring = 0  # V5
        self.total_wiring_epochs = 1      # V5: default to 1 to prevent division by zero if never set
    # V5: set_wiring_phase now takes epoch info so the heuristic strength can decay
    def set_wiring_phase(self, active, current_epoch_num=0, total_wiring_epochs=1):
        self.wiring_phase_active = active
        if active:
            self.current_epoch_in_wiring = current_epoch_num
            self.total_wiring_epochs = total_wiring_epochs if total_wiring_epochs > 0 else 1
    def _get_current_heuristic_strength(self):
        if not self.wiring_phase_active or self.total_wiring_epochs <= 1:
            return self.INITIAL_HEURISTIC_STRENGTH  # default when not wiring, or when wiring lasts a single epoch
        # Linear decay from INITIAL to FINAL strength over the wiring epochs
        progress = min(self.current_epoch_in_wiring / (self.total_wiring_epochs - 1), 1.0)
        return self.INITIAL_HEURISTIC_STRENGTH - progress * (self.INITIAL_HEURISTIC_STRENGTH - self.FINAL_HEURISTIC_STRENGTH)
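    # Worked example (illustrative, not from the original source): with
    # INITIAL_HEURISTIC_STRENGTH=0.025, FINAL_HEURISTIC_STRENGTH=0.005 and
    # total_wiring_epochs=5, progress runs 0.00, 0.25, 0.50, 0.75, 1.00 over
    # epochs 0..4, so the strength decays 0.025 -> 0.020 -> 0.015 -> 0.010 -> 0.005.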
    def forward(self, x, key_padding_mask=None, attn_mask=None):
        # V5: sigmoid gate activations
        current_gates_activations = torch.sigmoid(self.gates_params)
        if self.debug_prints_enabled and self.wiring_phase_active:
            print(f" AdaptiveBlock {self.block_idx} (Wiring ON, Epoch {self.current_epoch_in_wiring + 1}/{self.total_wiring_epochs}) Input x: {x.shape}, RawG: {[f'{g.item():.3f}' for g in self.gates_params.data]}, SigmoidG: {[f'{s.item():.3f}' for s in current_gates_activations.data]}")
        x_norm_submodules = self.norm1(x)
        outputs = []
        for i, module_instance in enumerate(self.sub_modules):
            if i >= self.num_sub_modules: break
            if i == 0: module_out, _ = module_instance(x_norm_submodules, x_norm_submodules, x_norm_submodules, key_padding_mask=key_padding_mask, attn_mask=attn_mask, need_weights=False)
            else: module_out = module_instance(x_norm_submodules)
            outputs.append(module_out * current_gates_activations[i])  # V5: apply the sigmoid activation here
        if not outputs: final_out_unnorm = x
        else:
            # V5: sum the already-gated outputs (the gates were applied per sub-module above)
            weighted_sum = torch.sum(torch.stack(outputs, dim=0), dim=0)
            final_out_unnorm = x + self.dropout_layer(weighted_sum)
        final_out_norm = self.norm2(final_out_unnorm)
        current_output_entropy = self.output_entropy_estimator(final_out_norm, active_mask=(~key_padding_mask) if key_padding_mask is not None else None)
        current_static_target_diff = current_output_entropy - self.static_seed_target_entropy
        dynamic_target_entropy_for_heuristic = self.static_seed_target_entropy
        predicted_delta_factor_for_report = torch.tensor(0.0, device=x.device)
        if self.wiring_phase_active and self.training:
            predicted_delta_factor_raw = self.fep(current_output_entropy.detach(), current_static_target_diff.detach())
            predicted_delta_factor_tanh = torch.tanh(predicted_delta_factor_raw)
            dynamic_adjustment = predicted_delta_factor_tanh * self.MAX_DYNAMIC_ENTROPY_ADJUSTMENT_RANGE
            dynamic_target_entropy_for_heuristic = self.static_seed_target_entropy + dynamic_adjustment.item()
            dynamic_target_entropy_for_heuristic = max(0.01, min(0.99, dynamic_target_entropy_for_heuristic))
            predicted_delta_factor_for_report = predicted_delta_factor_tanh
            with torch.no_grad():
                entropy_diff_for_heuristic = current_output_entropy - dynamic_target_entropy_for_heuristic
                # V5: decaying heuristic strength
                base_adjustment_strength = self._get_current_heuristic_strength()
                adaptive_strength_factor = min(max(abs(entropy_diff_for_heuristic.item()) * 7.0, 0.3), 2.5)
                adjustment_strength = base_adjustment_strength * adaptive_strength_factor
                if self.debug_prints_enabled:
                    print(f"  AdaptiveBlock {self.block_idx} WIRING PRE-ADJUST: RawG={[f'{g.item():.3f}' for g in self.gates_params.data]}, SigmoidG={[f'{s.item():.3f}' for s in current_gates_activations.data]}")
                    print(f"   OutEnt={current_output_entropy.item():.4f}, StaticTgtEnt={self.static_seed_target_entropy:.4f}, FEPΔFactor={predicted_delta_factor_tanh.item():.4f}, DynTgtEnt={dynamic_target_entropy_for_heuristic:.4f}, ED_Dyn={entropy_diff_for_heuristic.item():.4f}, BaseHeurStr={base_adjustment_strength:.4f}, AdjStr={adjustment_strength:.4f}")
                if entropy_diff_for_heuristic.item() > 1e-4:
                    self.gates_params.data[0] -= adjustment_strength
                    self.gates_params.data[1] += adjustment_strength * 0.6
                    if self.num_sub_modules > 2: self.gates_params.data[2] += adjustment_strength * 0.4
                elif entropy_diff_for_heuristic.item() < -1e-4:
                    self.gates_params.data[0] += adjustment_strength
                    self.gates_params.data[1] -= adjustment_strength * 0.6
                    if self.num_sub_modules > 2: self.gates_params.data[2] -= adjustment_strength * 0.4
                self.gates_params.data.clamp_(-3.5, 3.5)
                if self.debug_prints_enabled:
                    print(f"  AdaptiveBlock {self.block_idx} WIRING POST-ADJUST: RawG={[f'{g.item():.3f}' for g in self.gates_params.data]}, SigmoidG={[f'{s.item():.3f}' for s in torch.sigmoid(self.gates_params.data)]}")
        # V5: return the sigmoid activations alongside the raw gate parameters
        return final_out_norm, current_output_entropy, current_gates_activations, self.gates_params.data.clone(), predicted_delta_factor_for_report, torch.tensor(dynamic_target_entropy_for_heuristic, device=x.device)
# --- Positional Encoding ---
# (No changes from V4)
class PositionalEncoding(nn.Module):  # ... (same as V4)
    def __init__(self, d_model, dropout=0.1, max_len=512):
        super().__init__(); self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model); pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div); pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]; return self.dropout(x)
# --- Main SWCK Model (V5 changes) ---
class SWCKModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, d_ff, num_adaptive_blocks,
                 dropout, seed_phrase, seed_number_str, num_sub_modules_per_block=3):
        super().__init__()
        self.d_model = d_model; self.seed_phrase = seed_phrase; self.seed_number_str = seed_number_str
        self.debug_prints_enabled = True
        if self.debug_prints_enabled: print(f"--- Initializing SWCKModel (V5) ---")
        self.seed_parser = SeedParser(seed_phrase, seed_number_str, d_model, num_adaptive_blocks, num_sub_modules_per_block)
        self.seed_parser.debug_prints_enabled = self.debug_prints_enabled
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.adaptive_blocks = nn.ModuleList()
        for i in range(num_adaptive_blocks):
            block_config = self.seed_parser.get_block_config(i)
            if block_config is None: raise ValueError(f"SWCKModel Error: Could not get seed config for block {i}")
            new_block = AdaptiveBlock(d_model, n_heads, d_ff, dropout, block_config, block_idx=i, num_sub_modules=num_sub_modules_per_block)
            new_block.debug_prints_enabled = self.debug_prints_enabled
            self.adaptive_blocks.append(new_block)
            if self.debug_prints_enabled: print(f" SWCKModel: Added AdaptiveBlock {i} (V5 with sigmoid gates and decaying heuristic)")
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.overall_output_entropy_estimator = EntropyEstimator(d_model, name="OverallOutEntropy")
        self.overall_output_entropy_estimator.debug_prints_enabled = False
        self._init_weights()
        if self.debug_prints_enabled: print(f"--- SWCKModel V5 Initialized (Vocab: {vocab_size}, d_model: {d_model}, Blocks: {num_adaptive_blocks}x{num_sub_modules_per_block}sub) ---")

    def _init_weights(self):  # ... (same as V4)
        initrange = 0.1; self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_(); self.fc_out.weight.data.uniform_(-initrange, initrange)

    # V5: set_wiring_phase now takes epoch info
    def set_wiring_phase(self, active, current_epoch_num=0, total_wiring_epochs=1):
        if self.debug_prints_enabled:
            print(f"SWCKModel: Setting wiring phase to {active} for all blocks (epoch {current_epoch_num + 1}/{total_wiring_epochs} of wiring, if active).")
        for block in self.adaptive_blocks:
            block.set_wiring_phase(active, current_epoch_num, total_wiring_epochs)
    def forward(self, src_tokens, src_key_padding_mask=None):
        if self.debug_prints_enabled:
            print(f"\n--- SWCKModel Forward Pass (Training: {self.training}) ---")
            print(f" Input src_tokens: {src_tokens.shape}")
            if src_key_padding_mask is not None: print(f" Input src_key_padding_mask: {src_key_padding_mask.shape} (True means pad)")
        x = self.embedding(src_tokens) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)
        if self.debug_prints_enabled: print(f" After Embedding & PosEnc, x: {x.shape}")
        block_output_entropies = []
        current_block_gate_activations = []  # V5: changed from softmaxes
        current_block_gate_raw_params = []
        fep_predicted_delta_factors = []
        dynamic_target_entropies_used = []
        for i, block in enumerate(self.adaptive_blocks):
            if self.debug_prints_enabled: print(f" Processing AdaptiveBlock {i}...")
            # V5: AdaptiveBlock returns sigmoid activations
            x, block_entropy, current_gate_acts, raw_gate_params, fep_delta, dyn_target_ent = block(x, key_padding_mask=src_key_padding_mask, attn_mask=None)
            block_output_entropies.append(block_entropy)
            current_block_gate_activations.append(current_gate_acts)  # V5
            current_block_gate_raw_params.append(raw_gate_params)
            fep_predicted_delta_factors.append(fep_delta)
            dynamic_target_entropies_used.append(dyn_target_ent)
            if self.debug_prints_enabled:
                acts_str = [f'{act.item():.3f}' for act in current_gate_acts]  # V5
                raw_str = [f'{rp.item():.3f}' for rp in raw_gate_params]
                fep_delta_str = f"{fep_delta.item():.3f}" if torch.is_tensor(fep_delta) else "N/A"
                dyn_target_str = f"{dyn_target_ent.item():.3f}" if torch.is_tensor(dyn_target_ent) else "N/A"
                print(f"  Output x from Block {i}: {x.shape}, MeasEnt: {block_entropy.item():.4f}, FEPΔFactor: {fep_delta_str}, DynTgtUsed: {dyn_target_str}, SigmoidG: {acts_str}, RawG: {raw_str}")  # V5
        logits = self.fc_out(x)
        if self.debug_prints_enabled: print(f" Output logits: {logits.shape}")
        final_active_mask = (~src_key_padding_mask) if src_key_padding_mask is not None else None
        overall_entropy = self.overall_output_entropy_estimator(x, active_mask=final_active_mask)
        if self.debug_prints_enabled: print(f" Overall Final Representation Entropy: {overall_entropy.item():.4f}")
        entropy_report = {
            "block_output_entropies": block_output_entropies,
            "overall_output_entropy": overall_entropy,
            "current_block_gate_activations": current_block_gate_activations,  # V5
            "current_block_gate_params": current_block_gate_raw_params,
            # "initial_block_gate_targets" (softmax-based) is dropped from the report, since it is
            # less relevant with sigmoid gates; the alignment loss reads each block's
            # initial_raw_gate_scores_buffer directly instead.
            "fep_predicted_delta_factors": fep_predicted_delta_factors,
            "dynamic_target_entropies_used": dynamic_target_entropies_used
        }
        if self.debug_prints_enabled: print(f"--- SWCKModel Forward Pass Complete ---")
        return logits, entropy_report
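
# A minimal end-to-end sketch, guarded so it only runs when this file is executed
# directly. Every hyperparameter and seed string below is an illustrative
# assumption; the real training script supplies its own values.
if __name__ == "__main__":
    model = SWCKModel(vocab_size=100, d_model=64, n_heads=4, d_ff=128,
                      num_adaptive_blocks=2, dropout=0.1,
                      seed_phrase="example seed phrase", seed_number_str="12345",
                      num_sub_modules_per_block=3)
    model.set_wiring_phase(True, current_epoch_num=0, total_wiring_epochs=3)
    tokens = torch.randint(0, 100, (2, 10))          # (batch, seq)
    pad_mask = torch.zeros(2, 10, dtype=torch.bool)  # True would mean pad
    logits, report = model(tokens, src_key_padding_mask=pad_mask)
    print(logits.shape)            # expected: torch.Size([2, 10, 100])
    print(sorted(report.keys()))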