import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import hashlib # For generating deterministic values from seed
# --- Helper: Entropy Estimator ---
class EntropyEstimator(nn.Module):
def __init__(self, d_model, hidden_dim=32, name=""): # Smaller hidden_dim for simplicity
super().__init__()
self.fc1 = nn.Linear(d_model, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, 1)
self.name = name
    def forward(self, x, active_mask=None):  # x: (batch, seq_len, d_model); active_mask: (batch, seq_len) bool, True = active
        if x.numel() == 0:
            return torch.tensor(0.0, device=x.device)  # Empty input: nothing to estimate
        if active_mask is not None and active_mask.shape != x.shape[:-1]:
            # Mask does not line up with x; warn and fall back to averaging over all elements.
            # A more robust solution would ensure masks are always correct upstream.
            print(f"Warning [{self.name}]: x shape {x.shape[:-1]} and active_mask shape {active_mask.shape} mismatch. Entropy may be inaccurate.")
            active_mask = None
        if active_mask is not None:
            if active_mask.sum() == 0:
                return torch.tensor(0.0, device=x.device)  # Everything masked out
            x_flat = x[active_mask]  # (num_active, d_model)
        else:
            x_flat = x.reshape(-1, x.size(-1))
        h = F.relu(self.fc1(x_flat))
        return torch.sigmoid(self.fc2(h)).mean()  # Mean learned entropy proxy over the selected elements
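# Illustrative usage sketch for EntropyEstimator (assumed shapes; not part of the module).
# The estimator maps a (batch, seq_len, d_model) activation tensor to one scalar in (0, 1)
# that the rest of the model treats as a learned entropy proxy:
#
#   est = EntropyEstimator(d_model=64, name="demo")
#   x = torch.randn(2, 10, 64)
#   mask = torch.tensor([[True] * 10, [True] * 6 + [False] * 4])  # True = active (non-padded) position
#   score = est(x, active_mask=mask)  # scalar tensor, averaged over the 16 active positions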
# --- Helper: Seed Parser ---
class SeedParser:
def __init__(self, seed_phrase, seed_number_str, d_model, num_adaptive_blocks, num_sub_modules_per_block):
self.seed_phrase = seed_phrase
self.seed_number_str = seed_number_str
self.d_model = d_model
self.num_adaptive_blocks = num_adaptive_blocks
self.num_sub_modules_per_block = num_sub_modules_per_block
self.debug_prints_enabled = True
print(f"--- SeedParser Initialization ---")
print(f" Seed Phrase: '{self.seed_phrase}'")
print(f" Seed Number: {self.seed_number_str}")
# 1. Process Seed Phrase (e.g., to get a base vector)
# For simplicity, hash it to get a deterministic starting point for numerical derivation
phrase_hash = hashlib.sha256(seed_phrase.encode()).hexdigest()
self.phrase_base_val = int(phrase_hash[:8], 16) # Use first 8 hex chars
if self.debug_prints_enabled: print(f" Phrase Base Value (from hash): {self.phrase_base_val}")
# 2. Process Seed Number (more direct influence on structure)
self.num_sequence = [int(d) for d in seed_number_str if d.isdigit()]
if not self.num_sequence: self.num_sequence = [0] # Fallback
if self.debug_prints_enabled: print(f" Numerical Sequence (from seed number): {self.num_sequence}")
self.init_map = self._generate_init_map()
if self.debug_prints_enabled:
print(f" Generated InitMap:")
for i, block_config in enumerate(self.init_map["block_configs"]):
print(f" Block {i}: Active Module Index: {block_config['active_module_idx']}, Target Entropy: {block_config['target_entropy']:.4f}, Gate Inits: {[f'{g:.2f}' for g in block_config['gate_inits']]}")
print(f"--- SeedParser Initialized ---")
def _get_deterministic_value(self, key_name, min_val, max_val, sequence_idx_offset=0):
# Combine phrase base and numerical sequence for more variation
combined_seed_val = self.phrase_base_val
for i, num in enumerate(self.num_sequence):
combined_seed_val += num * (10**(i + sequence_idx_offset))
# Hash the key_name to make it specific to the parameter
key_hash = int(hashlib.sha256(key_name.encode()).hexdigest()[:8], 16)
final_seed = combined_seed_val + key_hash
# Simple mapping to range (not cryptographically strong, but deterministic)
        if max_val == min_val: return min_val  # Degenerate range: only one possible value
val = min_val + (final_seed % (max_val - min_val + 1))
return val
def _get_deterministic_float(self, key_name, min_val=0.0, max_val=1.0, sequence_idx_offset=0):
combined_seed_val = self.phrase_base_val
for i, num in enumerate(self.num_sequence):
combined_seed_val += num * (10**(i + sequence_idx_offset))
key_hash = int(hashlib.sha256(key_name.encode()).hexdigest()[:8], 16)
final_seed = combined_seed_val + key_hash
# Map to [0,1] float then scale
        float_val = (final_seed % 1000001) / 1000000.0  # Quantize to [0, 1] with ~1e-6 resolution
scaled_val = min_val + float_val * (max_val - min_val)
return scaled_val
def _generate_init_map(self):
init_map = {"block_configs": []}
for i in range(self.num_adaptive_blocks):
# Determine which sub-module is initially "more" active
active_module_idx = self._get_deterministic_value(
f"block_{i}_active_module", 0, self.num_sub_modules_per_block - 1, sequence_idx_offset=i
)
# Determine initial gating values (summing to 1 for softmax-like behavior later)
gate_inits_raw = [
self._get_deterministic_float(f"block_{i}_gate_{j}_init_raw", 0.1, 1.0, sequence_idx_offset=i*10 + j)
for j in range(self.num_sub_modules_per_block)
]
            # Boost the gate of the seed-selected module, then normalize so the gates sum to 1
            if self.num_sub_modules_per_block > 0:
gate_inits_raw[active_module_idx] *= 2.0 # Boost the 'active' one
sum_raw = sum(gate_inits_raw)
gate_inits_normalized = [g / sum_raw for g in gate_inits_raw] if sum_raw > 0 else [1.0/self.num_sub_modules_per_block]*self.num_sub_modules_per_block
else:
gate_inits_normalized = []
# Determine a target entropy for this block's output
target_entropy = self._get_deterministic_float(
f"block_{i}_target_entropy", 0.05, 0.3, sequence_idx_offset=i # Target a moderate, non-zero entropy
)
init_map["block_configs"].append({
"active_module_idx": active_module_idx, # For initial bias
"gate_inits": gate_inits_normalized, # Initial values for learnable gates
"target_entropy": target_entropy
})
return init_map
def get_block_config(self, block_idx):
if 0 <= block_idx < len(self.init_map["block_configs"]):
return self.init_map["block_configs"][block_idx]
return None
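# Illustrative usage sketch for SeedParser (hypothetical seed values; not part of the module).
# The same phrase/number pair always yields the same init map, so the derived structure is
# reproducible across runs:
#
#   parser = SeedParser("hello world", "12345", d_model=64,
#                       num_adaptive_blocks=2, num_sub_modules_per_block=3)
#   cfg = parser.get_block_config(0)
#   # cfg == {"active_module_idx": <0..2>, "gate_inits": <3 floats summing to ~1.0>,
#   #         "target_entropy": <float in [0.05, 0.3]>}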
# --- Adaptive Block ---
class AdaptiveBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout, seed_parser_config, block_idx, num_sub_modules=3):
super().__init__()
self.d_model = d_model
self.block_idx = block_idx
self.num_sub_modules = num_sub_modules
self.config_from_seed = seed_parser_config # dict for this block
self.debug_prints_enabled = True
if self.debug_prints_enabled:
print(f" Initializing AdaptiveBlock {self.block_idx} with seed config: {self.config_from_seed}")
# Define potential sub-modules
self.sub_module_0 = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.sub_module_1 = nn.Sequential(
nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model)
)
# Sub-module 2: A simpler FFN or even a near identity (residual + small transform)
self.sub_module_2 = nn.Sequential(
nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model // 2, d_model)
)
# Add more diverse sub-modules if needed for `num_sub_modules_per_block`
self.sub_modules = nn.ModuleList([self.sub_module_0, self.sub_module_1, self.sub_module_2])
if self.num_sub_modules > len(self.sub_modules):
print(f"Warning: block {self.block_idx} requested {self.num_sub_modules} sub_modules, but only {len(self.sub_modules)} are defined. Using defined ones.")
self.num_sub_modules = len(self.sub_modules)
# Learnable gates for combining/selecting sub-modules
# Initialize gates based on seed_parser_config
gate_initial_values = self.config_from_seed.get("gate_inits", [1.0/self.num_sub_modules]*self.num_sub_modules if self.num_sub_modules > 0 else [])
if len(gate_initial_values) != self.num_sub_modules: # Fallback if seed parser gave wrong number
print(f"Warning: Block {self.block_idx} gate_inits length mismatch. Re-initializing uniformly.")
gate_initial_values = [1.0/self.num_sub_modules]*self.num_sub_modules if self.num_sub_modules > 0 else []
self.gates = nn.Parameter(torch.tensor(gate_initial_values, dtype=torch.float32))
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model) # For output of block
self.dropout = nn.Dropout(dropout)
self.output_entropy_estimator = EntropyEstimator(d_model, name=f"Block{block_idx}_OutEntropy")
self.wiring_phase_active = False # To be set by the main model
def set_wiring_phase(self, active):
self.wiring_phase_active = active
if self.debug_prints_enabled and active:
print(f" AdaptiveBlock {self.block_idx}: WIRING PHASE ACTIVATED")
elif self.debug_prints_enabled and not active:
print(f" AdaptiveBlock {self.block_idx}: WIRING PHASE DEACTIVATED")
def forward(self, x, key_padding_mask=None, attn_mask=None): # attn_mask is for MHA, key_padding_mask for MHA keys
if self.debug_prints_enabled:
current_gates_softmax = F.softmax(self.gates, dim=0)
print(f" AdaptiveBlock {self.block_idx} Input x: {x.shape}, Gates (softmax): {[f'{g.item():.3f}' for g in current_gates_softmax]}")
x_norm = self.norm1(x)
outputs = []
active_module_found = False
for i, module in enumerate(self.sub_modules):
if i >= self.num_sub_modules: break # Only use configured number
if i == 0: # MHA
# MHA expects key_padding_mask (N, S) bool: True if padded.
# attn_mask (L,S) or (N*H,L,S) float/bool: True if masked / -inf.
# For self-attention, L=S. If attn_mask is causal (L,L), it's fine.
# If key_padding_mask is (N,S), it's fine.
module_out, _ = module(x_norm, x_norm, x_norm,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
need_weights=False) # Don't need weights for this sim
active_module_found = True
elif hasattr(module, 'fc1') or isinstance(module, nn.Sequential): # FFN-like
module_out = module(x_norm)
active_module_found = True
else: # Fallback for undefined module types in this simple sketch
module_out = x_norm # Pass through
outputs.append(module_out)
        # Compute normalized gate weights up front so they are always defined for the return value below.
        gate_weights = F.softmax(self.gates, dim=0)  # Ensure they sum to 1
        if not active_module_found or not outputs:  # Should not happen if num_sub_modules > 0
            print(f"    AdaptiveBlock {self.block_idx}: No active sub_modules processed. Passing input through.")
            final_out_unnorm = x  # Pass through
        else:
            # Gated combination: weighted sum of the sub-module outputs, each shaped (B, S, D)
            stacked_outputs = torch.stack(outputs, dim=0)  # (num_sub_modules, B, S, D)
            # gate_weights (num_sub_modules,) -> (num_sub_modules, 1, 1, 1) for broadcasting
            weighted_sum = torch.sum(stacked_outputs * gate_weights.view(-1, 1, 1, 1), dim=0)
            final_out_unnorm = x + self.dropout(weighted_sum)  # Residual connection
final_out_norm = self.norm2(final_out_unnorm)
# During wiring phase, we might adjust gates based on local entropy vs target
# This is a very simplified "self-wiring" heuristic
current_output_entropy = self.output_entropy_estimator(final_out_norm, active_mask=~key_padding_mask if key_padding_mask is not None else None)
target_entropy_for_block = self.config_from_seed.get("target_entropy", 0.1) # Default target
        if self.wiring_phase_active and self.training:  # Only adjust gates during wiring AND training
with torch.no_grad(): # Don't track gradients for this heuristic adjustment
entropy_diff = current_output_entropy - target_entropy_for_block
# If current entropy is too high, slightly boost gates of modules that might reduce it (heuristic)
# If too low, slightly boost gates of modules that might increase it (heuristic)
# This is extremely heuristic. A true self-wiring mechanism would be more complex.
# For this sketch, let's say MHA (module 0) might increase complexity/entropy if it was low,
# and FFNs (module 1, 2) might refine/stabilize if entropy was high.
adjustment_strength = 0.01 # Small adjustment
if entropy_diff > 0.05: # Current entropy significantly higher than target
self.gates.data[1] += adjustment_strength
self.gates.data[2] += adjustment_strength
self.gates.data[0] -= adjustment_strength * 0.5 # Slightly decrease MHA
elif entropy_diff < -0.05: # Current entropy significantly lower
self.gates.data[0] += adjustment_strength
self.gates.data[1] -= adjustment_strength * 0.5
self.gates.data[2] -= adjustment_strength * 0.5
# Clamp gates to avoid extreme values before softmax (optional)
self.gates.data.clamp_(-2.0, 2.0)
if self.debug_prints_enabled:
print(f" AdaptiveBlock {self.block_idx} WIRING: OutEnt={current_output_entropy.item():.4f}, TgtEnt={target_entropy_for_block:.4f}, Δ={entropy_diff.item():.4f} -> New Gates (raw): {[f'{g.item():.3f}' for g in self.gates.data]}")
elif self.debug_prints_enabled:
print(f" AdaptiveBlock {self.block_idx} EXEC: OutEnt={current_output_entropy.item():.4f}, TgtEnt={target_entropy_for_block:.4f}")
# Return the block's output and its current estimated output entropy
return final_out_norm, current_output_entropy, gate_weights
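# Illustrative usage sketch for a standalone AdaptiveBlock (assumed hyperparameters; not part of
# the module). The block returns its output, a scalar entropy estimate, and the softmaxed gates:
#
#   cfg = {"active_module_idx": 0, "gate_inits": [0.5, 0.3, 0.2], "target_entropy": 0.15}
#   block = AdaptiveBlock(d_model=64, n_heads=4, d_ff=128, dropout=0.1,
#                         seed_parser_config=cfg, block_idx=0, num_sub_modules=3)
#   out, ent, gates = block(torch.randn(2, 10, 64))
#   # out: (2, 10, 64); ent: scalar tensor; gates: (3,) summing to 1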
# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=512):  # Reduced max_len for this sketch
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):  # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
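# Illustrative note (assumed shapes; not part of the module): for x of shape (batch, seq_len, d_model)
# with seq_len <= max_len, the buffer slice self.pe[:, :seq_len, :] broadcasts over the batch dimension:
#
#   pe = PositionalEncoding(d_model=64, max_len=512)
#   y = pe(torch.randn(2, 10, 64))  # still (2, 10, 64), positions 0..9 encoded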
# --- Main SWCK Model ---
class SWCKModel(nn.Module):
def __init__(self, vocab_size, d_model, n_heads, d_ff, num_adaptive_blocks,
dropout, seed_phrase, seed_number_str, num_sub_modules_per_block=3):
super().__init__()
self.d_model = d_model
self.seed_phrase = seed_phrase
self.seed_number_str = seed_number_str
self.debug_prints_enabled = True
print(f"--- Initializing SWCKModel ---")
self.seed_parser = SeedParser(seed_phrase, seed_number_str, d_model, num_adaptive_blocks, num_sub_modules_per_block)
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, dropout)
self.adaptive_blocks = nn.ModuleList()
for i in range(num_adaptive_blocks):
block_config = self.seed_parser.get_block_config(i)
if block_config is None:
raise ValueError(f"Could not get seed config for block {i}")
self.adaptive_blocks.append(
AdaptiveBlock(d_model, n_heads, d_ff, dropout, block_config, block_idx=i, num_sub_modules=num_sub_modules_per_block)
)
if self.debug_prints_enabled:
print(f" SWCKModel: Added AdaptiveBlock {i}")
self.fc_out = nn.Linear(d_model, vocab_size)
self.overall_output_entropy_estimator = EntropyEstimator(d_model, name="OverallOutEntropy")
self._init_weights()
print(f"--- SWCKModel Initialized ---")
def _init_weights(self):
initrange = 0.1
self.embedding.weight.data.uniform_(-initrange, initrange)
self.fc_out.bias.data.zero_()
self.fc_out.weight.data.uniform_(-initrange, initrange)
def set_wiring_phase(self, active):
if self.debug_prints_enabled:
print(f"SWCKModel: Setting wiring phase to {active} for all blocks.")
for block in self.adaptive_blocks:
block.set_wiring_phase(active)
def forward(self, src_tokens, src_key_padding_mask=None):
# src_tokens: (batch, seq_len)
# src_key_padding_mask: (batch, seq_len), True for padded positions
if self.debug_prints_enabled:
print(f"\n--- SWCKModel Forward Pass ---")
print(f" Input src_tokens: {src_tokens.shape}")
if src_key_padding_mask is not None: print(f" Input src_key_padding_mask: {src_key_padding_mask.shape}")
x = self.embedding(src_tokens) * math.sqrt(self.d_model)
x = self.pos_encoder(x)
if self.debug_prints_enabled: print(f" After Embedding & PosEnc, x: {x.shape}")
block_output_entropies = []
block_gate_weights = []
        # This sketch treats the core as an encoder-style processor with full self-attention,
        # so only src_key_padding_mask is passed to the blocks. A decoder-style variant would
        # additionally build a causal attn_mask here and pass it through to each block's MHA.
for i, block in enumerate(self.adaptive_blocks):
if self.debug_prints_enabled: print(f" Processing AdaptiveBlock {i}...")
            # key_padding_mask applies to the MHA keys/values; no causal attn_mask in this encoder-style sketch.
x, block_entropy, gates = block(x, key_padding_mask=src_key_padding_mask, attn_mask=None)
block_output_entropies.append(block_entropy)
block_gate_weights.append(gates)
if self.debug_prints_enabled: print(f" Output x from AdaptiveBlock {i}: {x.shape}, Entropy: {block_entropy.item():.4f}")
logits = self.fc_out(x)
if self.debug_prints_enabled: print(f" Output logits: {logits.shape}")
# Overall output entropy (of the final representation before fc_out)
# Masking for entropy calculation
final_active_mask = ~src_key_padding_mask if src_key_padding_mask is not None else None
overall_entropy = self.overall_output_entropy_estimator(x, active_mask=final_active_mask)
if self.debug_prints_enabled: print(f" Overall Final Representation Entropy: {overall_entropy.item():.4f}")
# Entropies from each block, overall output entropy, and gate weights for regularization/logging
entropy_report = {
"block_output_entropies": block_output_entropies, # List of tensors
"overall_output_entropy": overall_entropy, # Tensor
"block_gate_weights": block_gate_weights # List of tensors
}
        return logits, entropy_report
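
# --- Minimal smoke test ---
# Illustrative end-to-end sketch: the vocabulary size, seeds, and hyperparameters below are
# arbitrary assumptions chosen for demonstration, not values prescribed by the SWCK design.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = SWCKModel(
        vocab_size=100, d_model=64, n_heads=4, d_ff=128, num_adaptive_blocks=2,
        dropout=0.1, seed_phrase="the kernel wires itself", seed_number_str="31415",
        num_sub_modules_per_block=3,
    )
    model.set_wiring_phase(True)  # Enable the heuristic gate adjustment for this pass
    tokens = torch.randint(0, 100, (2, 12))  # (batch, seq_len) of token ids
    pad_mask = torch.zeros(2, 12, dtype=torch.bool)  # True marks padded positions
    pad_mask[1, 9:] = True  # Pretend the second sequence is only 9 tokens long
    logits, report = model(tokens, src_key_padding_mask=pad_mask)
    print("logits:", logits.shape)  # (2, 12, 100)
    print("overall entropy:", report["overall_output_entropy"].item())
    print("block entropies:", [e.item() for e in report["block_output_entropies"]])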