# SWCK / model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import hashlib # For generating deterministic values from seed
# --- Helper: Entropy Estimator ---
class EntropyEstimator(nn.Module):
def __init__(self, d_model, hidden_dim=32, name=""):
super().__init__()
self.fc1 = nn.Linear(d_model, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, 1)
self.name = name
        self.debug_prints_enabled = True  # Verbose output on by default; callers may disable it.
def forward(self, x, active_mask=None): # x: (batch, seq_len, d_model)
# Simplified masking logic for robustness
if x.numel() == 0:
return torch.tensor(0.0, device=x.device)
if active_mask is not None:
# Ensure active_mask is boolean and compatible shape for broadcasting/indexing
if active_mask.dtype != torch.bool:
active_mask = active_mask.bool()
if x.dim() == 3 and active_mask.dim() == 2 and x.shape[:2] == active_mask.shape:
# typical case: x is (B,S,D), active_mask is (B,S)
x_masked = x[active_mask] # This flattens to (N_active, D)
elif x.dim() == 2 and active_mask.dim() == 1 and x.shape[0] == active_mask.shape[0]:
# x is (S,D) or (B,D) - less common here, but handle
x_masked = x[active_mask]
else: # Fallback if mask shapes are unexpected, process all elements
# if self.debug_prints_enabled:
# print(f"Warning [{self.name}]: Mask shape mismatch (x: {x.shape}, mask: {active_mask.shape}). Processing all elements.")
x_masked = x.reshape(-1, x.size(-1))
else:
x_masked = x.reshape(-1, x.size(-1))
if x_masked.numel() == 0:
return torch.tensor(0.0, device=x.device)
h = F.relu(self.fc1(x_masked))
# Sigmoid output, then mean. Represents average "activity" or "confidence" as a proxy for entropy.
estimated_entropy = torch.sigmoid(self.fc2(h)).mean()
return estimated_entropy
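# A minimal usage sketch (hypothetical shapes, not part of the module): the estimator maps a
# (batch, seq_len, d_model) tensor, optionally restricted by a boolean active mask, to a scalar in (0, 1):
#   est = EntropyEstimator(d_model=64)
#   proxy = est(torch.randn(2, 10, 64), active_mask=torch.ones(2, 10, dtype=torch.bool))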
# --- Helper: Seed Parser ---
class SeedParser:
def __init__(self, seed_phrase, seed_number_str, d_model, num_adaptive_blocks, num_sub_modules_per_block):
self.seed_phrase = seed_phrase
self.seed_number_str = seed_number_str
self.d_model = d_model
self.num_adaptive_blocks = num_adaptive_blocks
self.num_sub_modules_per_block = num_sub_modules_per_block
self.debug_prints_enabled = True
if self.debug_prints_enabled:
print(f"--- SeedParser Initialization ---")
print(f" Seed Phrase (start): '{self.seed_phrase[:50]}...'")
print(f" Seed Number: {self.seed_number_str}")
phrase_hash = hashlib.sha256(seed_phrase.encode()).hexdigest()
self.phrase_base_val = int(phrase_hash[:16], 16)
if self.debug_prints_enabled: print(f" Phrase Base Value (from hash): {self.phrase_base_val}")
self.num_sequence = [int(d) for d in seed_number_str if d.isdigit()]
if not self.num_sequence: self.num_sequence = [sum(bytearray(seed_number_str.encode())) % 10]
if self.debug_prints_enabled: print(f" Numerical Sequence (from seed number): {self.num_sequence}")
self.init_map = self._generate_init_map()
if self.debug_prints_enabled:
print(f" SeedParser: Generated InitMap:")
for i, block_config in enumerate(self.init_map["block_configs"]):
gate_inits_str = [f'{g:.3f}' for g in block_config['initial_gate_proportions']]
print(f" Block {i}: Target Entropy: {block_config['target_entropy']:.4f}, Initial Gate Proportions: {gate_inits_str}")
if self.debug_prints_enabled: print(f"--- SeedParser Initialized ---")
def _get_deterministic_value(self, key_name, min_val, max_val, sequence_idx_offset=0):
key_specific_hash = int(hashlib.sha256(key_name.encode() + self.seed_phrase.encode()).hexdigest()[:8], 16)
        num_seq_val = 0
        for digit in self.num_sequence:
            num_seq_val = (num_seq_val * 10 + digit) % 1000003
combined_seed_val = self.phrase_base_val + key_specific_hash + num_seq_val + sequence_idx_offset
if max_val == min_val: return min_val
val_range = max_val - min_val + 1
return min_val + int(abs(math.sin(float(combined_seed_val)) * 1e5)) % val_range
def _get_deterministic_float(self, key_name, min_val=0.0, max_val=1.0, sequence_idx_offset=0):
key_specific_hash = int(hashlib.sha256(key_name.encode() + self.seed_phrase.encode()).hexdigest()[:8], 16)
        num_seq_val = 0
        for digit in self.num_sequence:
            num_seq_val = (num_seq_val * 10 + digit) % 1000003
combined_seed_val = self.phrase_base_val + key_specific_hash + num_seq_val + sequence_idx_offset
norm_float = (math.sin(float(combined_seed_val) * 0.1) + 1.0) / 2.0
scaled_val = min_val + norm_float * (max_val - min_val)
return scaled_val
def _generate_init_map(self):
init_map = {"block_configs": []}
for i in range(self.num_adaptive_blocks):
gate_raw_scores = [
self._get_deterministic_float(f"block_{i}_gate_{j}_raw_score", -1.0, 1.0, sequence_idx_offset=i*10 + j)
for j in range(self.num_sub_modules_per_block)
]
if self.num_sub_modules_per_block > 0:
gate_initial_proportions = F.softmax(torch.tensor(gate_raw_scores), dim=0).tolist()
else:
gate_initial_proportions = []
target_entropy = self._get_deterministic_float(
f"block_{i}_target_entropy", 0.05, 0.35, sequence_idx_offset=i
)
init_map["block_configs"].append({
"initial_gate_proportions": gate_initial_proportions,
"raw_gate_scores_for_param_init": gate_raw_scores,
"target_entropy": target_entropy
})
return init_map
def get_block_config(self, block_idx):
if 0 <= block_idx < len(self.init_map["block_configs"]):
return self.init_map["block_configs"][block_idx]
return None
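# A minimal usage sketch (hypothetical seed values): the parser maps the two seeds deterministically
# to per-block configs, so the same seeds always yield the same init map:
#   parser = SeedParser("example seed phrase", "54321", d_model=64,
#                       num_adaptive_blocks=3, num_sub_modules_per_block=3)
#   cfg = parser.get_block_config(0)  # dict with 'initial_gate_proportions', 'target_entropy', ...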
# --- Adaptive Block ---
class AdaptiveBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout, seed_parser_config_for_block, block_idx, num_sub_modules=3):
super().__init__()
self.d_model = d_model
self.block_idx = block_idx
self.num_sub_modules = num_sub_modules
self.config_from_seed = seed_parser_config_for_block
self.debug_prints_enabled = True
if self.debug_prints_enabled:
print(f" Initializing AdaptiveBlock {self.block_idx} with seed config: TargetEntropy={self.config_from_seed['target_entropy']:.3f}, InitialGateProportions={[f'{g:.3f}' for g in self.config_from_seed['initial_gate_proportions']]}")
self.sub_module_0 = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.sub_module_1 = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model))
self.sub_module_2 = nn.Sequential(nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model // 2, d_model))
self.sub_modules = nn.ModuleList([self.sub_module_0, self.sub_module_1, self.sub_module_2])
if self.num_sub_modules > len(self.sub_modules):
print(f"Warning: block {self.block_idx} requested {self.num_sub_modules} sub_modules, but only {len(self.sub_modules)} defined. Using defined count.")
self.num_sub_modules = len(self.sub_modules)
        raw_gate_param_inits = self.config_from_seed.get("raw_gate_scores_for_param_init", [0.0] * self.num_sub_modules)
        if len(raw_gate_param_inits) != self.num_sub_modules:
            print(f"Warning: Block {self.block_idx} raw_gate_scores length mismatch. Re-initializing to zeros.")
            raw_gate_param_inits = [0.0] * self.num_sub_modules
self.gates_params = nn.Parameter(torch.tensor(raw_gate_param_inits, dtype=torch.float32))
        # Plain tensor (not a registered buffer); it is moved to the gate parameter's device in forward().
        self.initial_gate_proportions_tensor = torch.tensor(self.config_from_seed['initial_gate_proportions'], dtype=torch.float32)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.output_entropy_estimator = EntropyEstimator(d_model, name=f"Block{block_idx}_OutEntropy")
self.wiring_phase_active = False
def set_wiring_phase(self, active):
self.wiring_phase_active = active
# if self.debug_prints_enabled:
# phase_status = "ACTIVATED" if active else "DEACTIVATED"
# print(f" AdaptiveBlock {self.block_idx}: WIRING PHASE {phase_status}") # Made less verbose
def forward(self, x, key_padding_mask=None, attn_mask=None):
current_gates_softmax = F.softmax(self.gates_params, dim=0)
# if self.debug_prints_enabled: # Made less verbose
# print(f" AdaptiveBlock {self.block_idx} Input x: {x.shape}, Current Gates (softmax): {[f'{g.item():.3f}' for g in current_gates_softmax]}")
x_norm = self.norm1(x)
outputs = []
for i, module in enumerate(self.sub_modules):
if i >= self.num_sub_modules: break
if i == 0:
module_out, _ = module(x_norm, x_norm, x_norm, key_padding_mask=key_padding_mask, attn_mask=attn_mask, need_weights=False)
else:
module_out = module(x_norm)
outputs.append(module_out)
if not outputs:
if self.debug_prints_enabled: print(f" AdaptiveBlock {self.block_idx}: No sub_modules processed. Passing input through.")
final_out_unnorm = x
else:
stacked_outputs = torch.stack(outputs, dim=0)
weighted_sum = torch.sum(stacked_outputs * current_gates_softmax.view(-1, 1, 1, 1), dim=0)
final_out_unnorm = x + self.dropout(weighted_sum)
final_out_norm = self.norm2(final_out_unnorm)
current_output_entropy = self.output_entropy_estimator(final_out_norm, active_mask=~key_padding_mask if key_padding_mask is not None else None)
target_entropy_for_block = self.config_from_seed.get("target_entropy", 0.1)
if self.wiring_phase_active and self.training:
with torch.no_grad():
entropy_diff = current_output_entropy - target_entropy_for_block
adjustment_strength = 0.01
if entropy_diff > 0.05:
self.gates_params.data[1] += adjustment_strength
if self.num_sub_modules > 2: self.gates_params.data[2] += adjustment_strength
self.gates_params.data[0] -= adjustment_strength * 0.5
elif entropy_diff < -0.05:
self.gates_params.data[0] += adjustment_strength
self.gates_params.data[1] -= adjustment_strength * 0.5
if self.num_sub_modules > 2: self.gates_params.data[2] -= adjustment_strength * 0.5
self.gates_params.data.clamp_(-2.5, 2.5)
if self.debug_prints_enabled:
print(f" AdaptiveBlock {self.block_idx} WIRING: OutEnt={current_output_entropy.item():.4f}, TgtEnt={target_entropy_for_block:.4f}, Δ={entropy_diff.item():.4f} -> New Gate Params (raw): {[f'{g.item():.3f}' for g in self.gates_params.data]}")
initial_gate_targets_on_device = self.initial_gate_proportions_tensor.to(self.gates_params.device)
return final_out_norm, current_output_entropy, current_gates_softmax, self.gates_params, initial_gate_targets_on_device
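# A minimal usage sketch (hypothetical dimensions): a block consumes and returns a (batch, seq_len, d_model)
# tensor plus its entropy and gate diagnostics; the gate-adjustment heuristic only runs while the wiring
# phase is active and the module is in training mode:
#   block = AdaptiveBlock(d_model=64, n_heads=4, d_ff=128, dropout=0.1,
#                         seed_parser_config_for_block=parser.get_block_config(0), block_idx=0)
#   out, ent, gates_softmax, gates_raw, init_targets = block(torch.randn(2, 10, 64))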
# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=512):  # max_len caps the longest sequence the buffer can serve
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x):
        # x: (batch, seq_len, d_model); self.pe: (1, max_len, d_model).
        # Add the slice of pe that matches x's sequence length.
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
# --- Main SWCK Model ---
class SWCKModel(nn.Module):
def __init__(self, vocab_size, d_model, n_heads, d_ff, num_adaptive_blocks,
dropout, seed_phrase, seed_number_str, num_sub_modules_per_block=3):
super().__init__()
self.d_model = d_model
self.seed_phrase = seed_phrase
self.seed_number_str = seed_number_str
self.debug_prints_enabled = True
if self.debug_prints_enabled: print(f"--- Initializing SWCKModel ---")
self.seed_parser = SeedParser(seed_phrase, seed_number_str, d_model, num_adaptive_blocks, num_sub_modules_per_block)
self.seed_parser.debug_prints_enabled = self.debug_prints_enabled
self.embedding = nn.Embedding(vocab_size, d_model)
        # PositionalEncoding uses its own default max_len; it does not depend on SEQ_LEN_APP from app.py.
        self.pos_encoder = PositionalEncoding(d_model, dropout)
self.adaptive_blocks = nn.ModuleList()
for i in range(num_adaptive_blocks):
block_config = self.seed_parser.get_block_config(i)
if block_config is None:
raise ValueError(f"Could not get seed config for block {i}")
new_block = AdaptiveBlock(d_model, n_heads, d_ff, dropout, block_config, block_idx=i, num_sub_modules=num_sub_modules_per_block)
new_block.debug_prints_enabled = self.debug_prints_enabled
self.adaptive_blocks.append(new_block)
if self.debug_prints_enabled: print(f" SWCKModel: Added AdaptiveBlock {i}")
self.fc_out = nn.Linear(d_model, vocab_size)
self.overall_output_entropy_estimator = EntropyEstimator(d_model, name="OverallOutEntropy")
self.overall_output_entropy_estimator.debug_prints_enabled = self.debug_prints_enabled
self._init_weights()
if self.debug_prints_enabled: print(f"--- SWCKModel Initialized (Vocab: {vocab_size}, d_model: {d_model}) ---")
def _init_weights(self):
initrange = 0.1
self.embedding.weight.data.uniform_(-initrange, initrange)
self.fc_out.bias.data.zero_()
self.fc_out.weight.data.uniform_(-initrange, initrange)
    def set_wiring_phase(self, active):
        # Propagate the wiring-phase flag to every adaptive block.
        for block in self.adaptive_blocks:
            block.set_wiring_phase(active)
def forward(self, src_tokens, src_key_padding_mask=None):
# if self.debug_prints_enabled: # Made less verbose
# print(f"\n--- SWCKModel Forward Pass ---")
# print(f" Input src_tokens: {src_tokens.shape}")
# if src_key_padding_mask is not None: print(f" Input src_key_padding_mask: {src_key_padding_mask.shape} (True means pad)")
x = self.embedding(src_tokens) * math.sqrt(self.d_model)
x = self.pos_encoder(x)
# if self.debug_prints_enabled: print(f" After Embedding & PosEnc, x: {x.shape}") # Made less verbose
block_output_entropies = []
current_block_gate_softmaxes = []
current_block_gate_params = []
initial_block_gate_targets = []
for i, block in enumerate(self.adaptive_blocks):
# if self.debug_prints_enabled: print(f" Processing AdaptiveBlock {i}...") # Made less verbose
x, block_entropy, current_gate_softmax, current_gate_param, initial_gate_target = block(x, key_padding_mask=src_key_padding_mask, attn_mask=None)
block_output_entropies.append(block_entropy)
current_block_gate_softmaxes.append(current_gate_softmax)
current_block_gate_params.append(current_gate_param)
initial_block_gate_targets.append(initial_gate_target)
# if self.debug_prints_enabled: print(f" Output x from AdaptiveBlock {i}: {x.shape}, Entropy: {block_entropy.item():.4f}") # Made less verbose
logits = self.fc_out(x)
# if self.debug_prints_enabled: print(f" Output logits: {logits.shape}") # Made less verbose
final_active_mask = ~src_key_padding_mask if src_key_padding_mask is not None else None
overall_entropy = self.overall_output_entropy_estimator(x, active_mask=final_active_mask)
# if self.debug_prints_enabled: print(f" Overall Final Representation Entropy: {overall_entropy.item():.4f}") # Made less verbose
entropy_report = {
"block_output_entropies": block_output_entropies,
"overall_output_entropy": overall_entropy,
"current_block_gate_softmaxes": current_block_gate_softmaxes,
"current_block_gate_params": current_block_gate_params,
"initial_block_gate_targets": initial_block_gate_targets
}
return logits, entropy_report
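if __name__ == "__main__":
    # A minimal smoke test, not part of the library: the hyperparameters and seeds below are
    # illustrative assumptions, chosen small so the forward pass runs quickly on CPU.
    vocab_size, d_model, n_heads, d_ff, num_blocks = 100, 64, 4, 128, 2
    model = SWCKModel(vocab_size, d_model, n_heads, d_ff, num_blocks,
                      dropout=0.1, seed_phrase="example seed phrase", seed_number_str="54321")
    model.set_wiring_phase(True)  # enable the self-wiring heuristic for this pass
    tokens = torch.randint(0, vocab_size, (2, 16))       # (batch, seq_len)
    pad_mask = torch.zeros(2, 16, dtype=torch.bool)      # True marks padding positions
    logits, report = model(tokens, src_key_padding_mask=pad_mask)
    print(logits.shape)                                  # expected: torch.Size([2, 16, 100])
    print({k: (len(v) if isinstance(v, list) else v.item()) for k, v in report.items()})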