import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import hashlib # For generating deterministic values from seed

# --- Helper: Entropy Estimator ---
# Learns a scalar in [0, 1] per position and averages it over the active positions;
# a lightweight trainable proxy for "entropy", not a true information-theoretic estimate.
class EntropyEstimator(nn.Module):
    def __init__(self, d_model, hidden_dim=32, name=""):  # small hidden_dim keeps the proxy cheap
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.name = name

    def forward(self, x, active_mask=None):  # x: (batch, seq_len, d_model); active_mask: (batch, seq_len) bool, True = keep
        if x.numel() == 0:
            return torch.tensor(0.0, device=x.device)  # empty input: report zero entropy

        if active_mask is not None:
            if active_mask.shape != x.shape[:-1]:
                print(f"Warning [{self.name}]: x shape {x.shape[:-1]} and active_mask shape {active_mask.shape} mismatch. Averaging over all positions; entropy may be inaccurate.")
                active_mask = None  # fall back to using every position
            elif active_mask.sum() == 0:
                return torch.tensor(0.0, device=x.device)  # every position is masked out

        if active_mask is not None:
            x_flat = x[active_mask]             # (num_active, d_model)
        else:
            x_flat = x.reshape(-1, x.size(-1))  # (batch * seq_len, d_model)

        if x_flat.numel() == 0:
            return torch.tensor(0.0, device=x.device)

        h = F.relu(self.fc1(x_flat))
        return torch.sigmoid(self.fc2(h)).mean()  # mean estimated "entropy" over the selected positions

# --- Helper: Seed Parser ---
class SeedParser:
    def __init__(self, seed_phrase, seed_number_str, d_model, num_adaptive_blocks, num_sub_modules_per_block):
        self.seed_phrase = seed_phrase
        self.seed_number_str = seed_number_str
        self.d_model = d_model
        self.num_adaptive_blocks = num_adaptive_blocks
        self.num_sub_modules_per_block = num_sub_modules_per_block
        self.debug_prints_enabled = True

        print(f"--- SeedParser Initialization ---")
        print(f"  Seed Phrase: '{self.seed_phrase}'")
        print(f"  Seed Number: {self.seed_number_str}")

        # 1. Process Seed Phrase (e.g., to get a base vector)
        # For simplicity, hash it to get a deterministic starting point for numerical derivation
        phrase_hash = hashlib.sha256(seed_phrase.encode()).hexdigest()
        self.phrase_base_val = int(phrase_hash[:8], 16) # Use first 8 hex chars
        if self.debug_prints_enabled: print(f"  Phrase Base Value (from hash): {self.phrase_base_val}")

        # 2. Process Seed Number (more direct influence on structure)
        self.num_sequence = [int(d) for d in seed_number_str if d.isdigit()]
        if not self.num_sequence: self.num_sequence = [0] # Fallback
        if self.debug_prints_enabled: print(f"  Numerical Sequence (from seed number): {self.num_sequence}")

        self.init_map = self._generate_init_map()
        if self.debug_prints_enabled:
            print(f"  Generated InitMap:")
            for i, block_config in enumerate(self.init_map["block_configs"]):
                print(f"    Block {i}: Active Module Index: {block_config['active_module_idx']}, Target Entropy: {block_config['target_entropy']:.4f}, Gate Inits: {[f'{g:.2f}' for g in block_config['gate_inits']]}")
        print(f"--- SeedParser Initialized ---")

    def _get_deterministic_value(self, key_name, min_val, max_val, sequence_idx_offset=0):
        # Combine phrase base and numerical sequence for more variation
        combined_seed_val = self.phrase_base_val
        for i, num in enumerate(self.num_sequence):
            combined_seed_val += num * (10**(i + sequence_idx_offset))
        
        # Hash the key_name to make it specific to the parameter
        key_hash = int(hashlib.sha256(key_name.encode()).hexdigest()[:8], 16)
        final_seed = combined_seed_val + key_hash
        
        # Simple deterministic mapping into [min_val, max_val] (not cryptographically strong)
        if max_val == min_val: return min_val # degenerate range: only one possible value
        val = min_val + (final_seed % (max_val - min_val + 1))
        return val

    def _get_deterministic_float(self, key_name, min_val=0.0, max_val=1.0, sequence_idx_offset=0):
        combined_seed_val = self.phrase_base_val
        for i, num in enumerate(self.num_sequence):
            combined_seed_val += num * (10**(i + sequence_idx_offset))
        
        key_hash = int(hashlib.sha256(key_name.encode()).hexdigest()[:8], 16)
        final_seed = combined_seed_val + key_hash
        
        # Map deterministically to a float in [0, 1], then scale to [min_val, max_val]
        float_val = (final_seed % 1000001) / 1000000.0
        scaled_val = min_val + float_val * (max_val - min_val)
        return scaled_val

    def _generate_init_map(self):
        init_map = {"block_configs": []}
        
        for i in range(self.num_adaptive_blocks):
            # Determine which sub-module is initially "more" active
            active_module_idx = self._get_deterministic_value(
                f"block_{i}_active_module", 0, self.num_sub_modules_per_block - 1, sequence_idx_offset=i
            )
            
            # Determine initial gate values: each gate gets a deterministic value in [0.1, 1.0],
            # the seed-chosen "active" module is boosted, and the result is normalized to sum to 1.
            # (The block re-normalizes with softmax at runtime, so these act as soft biases.)
            gate_inits_raw = [
                self._get_deterministic_float(f"block_{i}_gate_{j}_init_raw", 0.1, 1.0, sequence_idx_offset=i*10 + j)
                for j in range(self.num_sub_modules_per_block)
            ]
            if self.num_sub_modules_per_block > 0:
                gate_inits_raw[active_module_idx] *= 2.0  # boost the seed-selected 'active' module
                sum_raw = sum(gate_inits_raw)
                gate_inits_normalized = [g / sum_raw for g in gate_inits_raw] if sum_raw > 0 else [1.0/self.num_sub_modules_per_block]*self.num_sub_modules_per_block
            else:
                gate_inits_normalized = []

            # Determine a target entropy for this block's output
            target_entropy = self._get_deterministic_float(
                f"block_{i}_target_entropy", 0.05, 0.3, sequence_idx_offset=i # Target a moderate, non-zero entropy
            )

            init_map["block_configs"].append({
                "active_module_idx": active_module_idx, # For initial bias
                "gate_inits": gate_inits_normalized,    # Initial values for learnable gates
                "target_entropy": target_entropy
            })
        return init_map

    def get_block_config(self, block_idx):
        if 0 <= block_idx < len(self.init_map["block_configs"]):
            return self.init_map["block_configs"][block_idx]
        return None

# --- Adaptive Block ---
class AdaptiveBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout, seed_parser_config, block_idx, num_sub_modules=3):
        super().__init__()
        self.d_model = d_model
        self.block_idx = block_idx
        self.num_sub_modules = num_sub_modules
        self.config_from_seed = seed_parser_config # dict for this block
        self.debug_prints_enabled = True

        if self.debug_prints_enabled:
            print(f"  Initializing AdaptiveBlock {self.block_idx} with seed config: {self.config_from_seed}")

        # Define potential sub-modules
        self.sub_module_0 = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.sub_module_1 = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model)
        )
        # Sub-module 2: a lighter bottleneck FFN (d_model -> d_model/2 -> d_model)
        self.sub_module_2 = nn.Sequential(
            nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model // 2, d_model)
        ) 
        # Add more diverse sub-modules if needed for `num_sub_modules_per_block`

        self.sub_modules = nn.ModuleList([self.sub_module_0, self.sub_module_1, self.sub_module_2])
        
        if self.num_sub_modules > len(self.sub_modules):
            print(f"Warning: block {self.block_idx} requested {self.num_sub_modules} sub_modules, but only {len(self.sub_modules)} are defined. Using defined ones.")
            self.num_sub_modules = len(self.sub_modules)


        # Learnable gates for combining/selecting sub-modules
        # Initialize gates based on seed_parser_config
        gate_initial_values = self.config_from_seed.get("gate_inits", [1.0/self.num_sub_modules]*self.num_sub_modules if self.num_sub_modules > 0 else [])
        if len(gate_initial_values) != self.num_sub_modules: # Fallback if seed parser gave wrong number
            print(f"Warning: Block {self.block_idx} gate_inits length mismatch. Re-initializing uniformly.")
            gate_initial_values = [1.0/self.num_sub_modules]*self.num_sub_modules if self.num_sub_modules > 0 else []

        self.gates = nn.Parameter(torch.tensor(gate_initial_values, dtype=torch.float32))

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model) # For output of block
        self.dropout = nn.Dropout(dropout)
        self.output_entropy_estimator = EntropyEstimator(d_model, name=f"Block{block_idx}_OutEntropy")
        self.wiring_phase_active = False # To be set by the main model

    def set_wiring_phase(self, active):
        self.wiring_phase_active = active
        if self.debug_prints_enabled:
            state = "ACTIVATED" if active else "DEACTIVATED"
            print(f"    AdaptiveBlock {self.block_idx}: WIRING PHASE {state}")


    def forward(self, x, key_padding_mask=None, attn_mask=None): # attn_mask is for MHA, key_padding_mask for MHA keys
        if self.debug_prints_enabled:
            current_gates_softmax = F.softmax(self.gates, dim=0)
            print(f"    AdaptiveBlock {self.block_idx} Input x: {x.shape}, Gates (softmax): {[f'{g.item():.3f}' for g in current_gates_softmax]}")

        x_norm = self.norm1(x)
        
        outputs = []
        active_module_found = False
        for i, module in enumerate(self.sub_modules):
            if i >= self.num_sub_modules: break # Only use configured number

            if i == 0: # MHA
                # MHA expects key_padding_mask (N, S) bool: True if padded.
                # attn_mask (L,S) or (N*H,L,S) float/bool: True if masked / -inf.
                # For self-attention, L=S. If attn_mask is causal (L,L), it's fine.
                # If key_padding_mask is (N,S), it's fine.
                module_out, _ = module(x_norm, x_norm, x_norm, 
                                       key_padding_mask=key_padding_mask, 
                                       attn_mask=attn_mask,
                                       need_weights=False) # Don't need weights for this sim
                active_module_found = True
            elif hasattr(module, 'fc1') or isinstance(module, nn.Sequential): # FFN-like
                module_out = module(x_norm)
                active_module_found = True
            else: # Fallback for undefined module types in this simple sketch
                module_out = x_norm # Pass through
            outputs.append(module_out)
        
        gate_weights = F.softmax(self.gates, dim=0)  # normalize the learnable gates to sum to 1

        if not active_module_found or not outputs:  # should not happen if num_sub_modules > 0
            print(f"    AdaptiveBlock {self.block_idx}: No active sub_modules processed. Passing input through.")
            final_out_unnorm = x  # pass through unchanged
        else:
            # Gated combination: every sub-module outputs (B, S, D), so the outputs stack cleanly
            stacked_outputs = torch.stack(outputs, dim=0)  # (num_sub_modules, B, S, D)
            # gate_weights: (num_sub_modules,) -> (num_sub_modules, 1, 1, 1) for broadcasting
            weighted_sum = torch.sum(stacked_outputs * gate_weights.view(-1, 1, 1, 1), dim=0)
            final_out_unnorm = x + self.dropout(weighted_sum)  # residual connection

        final_out_norm = self.norm2(final_out_unnorm)
        
        # During wiring phase, we might adjust gates based on local entropy vs target
        # This is a very simplified "self-wiring" heuristic
        current_output_entropy = self.output_entropy_estimator(final_out_norm, active_mask=~key_padding_mask if key_padding_mask is not None else None)
        target_entropy_for_block = self.config_from_seed.get("target_entropy", 0.1) # Default target

        if self.wiring_phase_active and self.training : # Only adjust gates during wiring AND training
            with torch.no_grad(): # Don't track gradients for this heuristic adjustment
                entropy_diff = current_output_entropy - target_entropy_for_block
                # If current entropy is too high, slightly boost gates of modules that might reduce it (heuristic)
                # If too low, slightly boost gates of modules that might increase it (heuristic)
                # This is extremely heuristic. A true self-wiring mechanism would be more complex.
                # For this sketch, let's say MHA (module 0) might increase complexity/entropy if it was low,
                # and FFNs (module 1, 2) might refine/stabilize if entropy was high.
                adjustment_strength = 0.01  # small nudge per forward pass
                if self.gates.numel() >= 3:  # heuristic assumes the 3 default sub-modules: MHA, FFN, small FFN
                    if entropy_diff > 0.05:  # entropy well above target: favor the FFN paths
                        self.gates.data[1] += adjustment_strength
                        self.gates.data[2] += adjustment_strength
                        self.gates.data[0] -= adjustment_strength * 0.5  # slightly de-emphasize MHA
                    elif entropy_diff < -0.05:  # entropy well below target: favor MHA
                        self.gates.data[0] += adjustment_strength
                        self.gates.data[1] -= adjustment_strength * 0.5
                        self.gates.data[2] -= adjustment_strength * 0.5
                # Clamp raw gate values so the softmax stays well-conditioned
                self.gates.data.clamp_(-2.0, 2.0)
            if self.debug_prints_enabled:
                print(f"    AdaptiveBlock {self.block_idx} WIRING: OutEnt={current_output_entropy.item():.4f}, TgtEnt={target_entropy_for_block:.4f}, Δ={entropy_diff.item():.4f} -> New Gates (raw): {[f'{g.item():.3f}' for g in self.gates.data]}")
        
        elif self.debug_prints_enabled:
             print(f"    AdaptiveBlock {self.block_idx} EXEC: OutEnt={current_output_entropy.item():.4f}, TgtEnt={target_entropy_for_block:.4f}")


        # Return the block's output and its current estimated output entropy
        return final_out_norm, current_output_entropy, gate_weights


# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=512):  # reduced max_len for this sketch
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):  # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# --- Main SWCK Model ---
class SWCKModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, d_ff, num_adaptive_blocks, 
                 dropout, seed_phrase, seed_number_str, num_sub_modules_per_block=3):
        super().__init__()
        self.d_model = d_model
        self.seed_phrase = seed_phrase
        self.seed_number_str = seed_number_str
        self.debug_prints_enabled = True
        
        print(f"--- Initializing SWCKModel ---")
        self.seed_parser = SeedParser(seed_phrase, seed_number_str, d_model, num_adaptive_blocks, num_sub_modules_per_block)
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        
        self.adaptive_blocks = nn.ModuleList()
        for i in range(num_adaptive_blocks):
            block_config = self.seed_parser.get_block_config(i)
            if block_config is None:
                raise ValueError(f"Could not get seed config for block {i}")
            self.adaptive_blocks.append(
                AdaptiveBlock(d_model, n_heads, d_ff, dropout, block_config, block_idx=i, num_sub_modules=num_sub_modules_per_block)
            )
            if self.debug_prints_enabled:
                print(f"  SWCKModel: Added AdaptiveBlock {i}")

        self.fc_out = nn.Linear(d_model, vocab_size)
        self.overall_output_entropy_estimator = EntropyEstimator(d_model, name="OverallOutEntropy")
        
        self._init_weights()
        print(f"--- SWCKModel Initialized ---")

    def _init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def set_wiring_phase(self, active):
        if self.debug_prints_enabled:
            print(f"SWCKModel: Setting wiring phase to {active} for all blocks.")
        for block in self.adaptive_blocks:
            block.set_wiring_phase(active)

    def forward(self, src_tokens, src_key_padding_mask=None):
        # src_tokens: (batch, seq_len)
        # src_key_padding_mask: (batch, seq_len), True for padded positions
        if self.debug_prints_enabled:
            print(f"\n--- SWCKModel Forward Pass ---")
            print(f"  Input src_tokens: {src_tokens.shape}")
            if src_key_padding_mask is not None: print(f"  Input src_key_padding_mask: {src_key_padding_mask.shape}")

        x = self.embedding(src_tokens) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)
        if self.debug_prints_enabled: print(f"  After Embedding & PosEnc, x: {x.shape}")

        block_output_entropies = []
        block_gate_weights = []

        # No top-level causal mask is built here: this "processing core" sketch uses full
        # self-attention over the source. For decoder-style use, generate a causal attn_mask
        # here and pass it to each block. src_key_padding_mask is forwarded so MHA ignores
        # padded positions.

        for i, block in enumerate(self.adaptive_blocks):
            if self.debug_prints_enabled: print(f"  Processing AdaptiveBlock {i}...")
            # For self-attention in blocks, key_padding_mask applies to keys/values.
            # No separate attention mask for now unless it's a decoder block.
            x, block_entropy, gates = block(x, key_padding_mask=src_key_padding_mask, attn_mask=None) 
            block_output_entropies.append(block_entropy)
            block_gate_weights.append(gates)
            if self.debug_prints_enabled: print(f"  Output x from AdaptiveBlock {i}: {x.shape}, Entropy: {block_entropy.item():.4f}")

        logits = self.fc_out(x)
        if self.debug_prints_enabled: print(f"  Output logits: {logits.shape}")

        # Overall output entropy (of the final representation before fc_out)
        # Masking for entropy calculation
        final_active_mask = ~src_key_padding_mask if src_key_padding_mask is not None else None
        overall_entropy = self.overall_output_entropy_estimator(x, active_mask=final_active_mask)
        if self.debug_prints_enabled: print(f"  Overall Final Representation Entropy: {overall_entropy.item():.4f}")
        
        # Entropies from each block, overall output entropy, and gate weights for regularization/logging
        entropy_report = {
            "block_output_entropies": block_output_entropies, # List of tensors
            "overall_output_entropy": overall_entropy,       # Tensor
            "block_gate_weights": block_gate_weights         # List of tensors
        }
        
        return logits, entropy_report
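

# --- Minimal usage sketch (illustrative, not part of the original training setup) ---
# A hedged smoke test showing how the pieces above fit together: build a SWCKModel from a
# seed phrase/number, enable the wiring phase, and run one forward pass on random token IDs.
# All hyperparameters and the seed strings below are assumed placeholder values.
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size, d_model, n_heads, d_ff = 100, 64, 4, 128
    num_adaptive_blocks, dropout = 2, 0.1
    seed_phrase = "example seed phrase"  # placeholder
    seed_number_str = "314159"           # placeholder

    model = SWCKModel(vocab_size, d_model, n_heads, d_ff, num_adaptive_blocks,
                      dropout, seed_phrase, seed_number_str, num_sub_modules_per_block=3)

    batch_size, seq_len = 2, 10
    src_tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
    # Padding mask: True marks padded positions (last two tokens of the second sequence)
    src_key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    src_key_padding_mask[1, -2:] = True

    model.set_wiring_phase(True)   # enable the entropy-driven gate adjustments
    model.train()
    logits, entropy_report = model(src_tokens, src_key_padding_mask=src_key_padding_mask)
    print("logits:", logits.shape)  # expected: torch.Size([2, 10, 100])
    print("overall output entropy:", f"{entropy_report['overall_output_entropy'].item():.4f}")
    model.set_wiring_phase(False)  # switch back to normal execution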