AbstractPhil committed (verified)
Commit 98ec4ab · 1 Parent(s): d08afa7

Update beeper_model.py

Files changed (1): beeper_model.py (+126, −222)
beeper_model.py CHANGED
@@ -1,38 +1,27 @@
  # beeper.py
- # --------------------------------------------------------------------------------------------------
- # Beeper Full Penta Controller — Rose-based tiny GPT (inference module with runtime pentachora influence)
- # - Decoder-only GPT with SDPA (FlashAttention path on Ampere/Hopper)
- # - Runtime "vertex pull" uses config["runtime_pentachora"] to bias hidden states toward
- #   pentachora vertices (coarse/topic/mood) exactly like training-time behavior, but non-destructive
- #   and fully toggleable.
- # --------------------------------------------------------------------------------------------------
  from __future__ import annotations

- import math
- import re
- import inspect
  from contextlib import nullcontext
- from typing import Optional, Tuple, Dict, Any

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

- # --- Prefer high-throughput matmul where possible (Ampere/Hopper) ---
  torch.set_float32_matmul_precision("high")
  torch.backends.cuda.matmul.allow_tf32 = True
  torch.backends.cudnn.allow_tf32 = True

- # ---- Version-safe SDPA (FlashAttention) selection -------------------------------------------------
  try:
-     # PyTorch 2.3+ modern API
      from torch.nn.attention import sdpa_kernel as _sdpa_kernel_modern
      from torch.nn.attention import SDPBackend as _SDPBackend
      _SDPA_SIG = inspect.signature(_sdpa_kernel_modern)
      _sdpa_kernel = _sdpa_kernel_modern
  except Exception:
      try:
-         # Legacy API
          from torch.backends.cuda import sdp_kernel as _sdpa_kernel_legacy
          _SDPA_SIG = inspect.signature(_sdpa_kernel_legacy)
          _SDPBackend = None
@@ -42,39 +31,31 @@ except Exception:
          _SDPBackend = None
          _sdpa_kernel = None

-
  def sdpa_ctx_prefer_flash():
-     """Bias SDPA toward FlashAttention where possible; otherwise no-op."""
      if _sdpa_kernel is None or _SDPA_SIG is None:
          return nullcontext()
      params = {p.name for p in _SDPA_SIG.parameters.values()}
      try:
          if "backends" in params and _SDPBackend is not None:
-             return _sdpa_kernel(backends=[
-                 _SDPBackend.FLASH_ATTENTION,
-                 _SDPBackend.EFFICIENT_ATTENTION,
-                 _SDPBackend.MATH
-             ])
          if "backend" in params and _SDPBackend is not None:
              return _sdpa_kernel(backend=_SDPBackend.FLASH_ATTENTION)
-         if {"enable_flash", "enable_math", "enable_mem_efficient"} <= params:
              return _sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True)
-         if {"use_flash", "use_math", "use_mem_efficient"} <= params:
              return _sdpa_kernel(use_flash=True, use_math=False, use_mem_efficient=True)
      except Exception:
          pass
      return nullcontext()

-
- # --------------------------------- Core blocks ------------------------------------------------------
  class CausalSelfAttention(nn.Module):
-     """Multi-head causal self-attention using PyTorch SDPA."""
      def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
          super().__init__()
-         assert dim % n_heads == 0, "dim must be divisible by n_heads"
-         self.nh = int(n_heads)
-         self.hd = dim // self.nh
-         self.qkv = nn.Linear(dim, 3 * dim, bias=False)
          self.proj = nn.Linear(dim, dim, bias=False)
          self.attn_dropout = float(attn_dropout)

@@ -82,78 +63,59 @@ class CausalSelfAttention(nn.Module):
          B, T, C = x.shape
          qkv = self.qkv(x)
          q, k, v = qkv.chunk(3, dim=-1)
-         q = q.view(B, T, self.nh, self.hd).transpose(1, 2)  # [B,H,T,D]
          k = k.view(B, T, self.nh, self.hd).transpose(1, 2)
          v = v.view(B, T, self.nh, self.hd).transpose(1, 2)
-
          if x.is_cuda:
              with sdpa_ctx_prefer_flash():
-                 y = F.scaled_dot_product_attention(
-                     q, k, v,
-                     is_causal=True,
-                     dropout_p=self.attn_dropout if self.training else 0.0,
-                 )
          else:
              scale = 1.0 / math.sqrt(self.hd)
              att = (q @ k.transpose(-2, -1)) * scale
-             mask = torch.full((1, 1, T, T), float("-inf"), device=x.device)
-             mask = torch.triu(mask, diagonal=1)
-             att = (att + mask).softmax(dim=-1)
-             y = att @ v
-
          y = y.transpose(1, 2).contiguous().view(B, T, C)
          return self.proj(y)

-
  class MLP(nn.Module):
-     """GELU MLP with dropout, sized by mlp_ratio."""
      def __init__(self, dim: int, mlp_ratio: float = 4.0, dropout: float = 0.1):
          super().__init__()
-         hidden = int(dim * mlp_ratio)
-         self.fc1 = nn.Linear(dim, hidden)
-         self.fc2 = nn.Linear(hidden, dim)
          self.drop = nn.Dropout(dropout)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = self.fc1(x)
-         x = F.gelu(x, approximate="tanh")
          x = self.drop(x)
          x = self.fc2(x)
          x = self.drop(x)
          return x

-
- # --------------------------------- Beeper Model -----------------------------------------------------
  class BeeperRoseGPT(nn.Module):
      """
-     Decoder-only GPT used by Beeper during training and inference.
-
-     Config keys used:
-       - vocab_size, dim, context, n_heads, n_layers, mlp_ratio
-       - resid_dropout, dropout, grad_checkpoint
-       - runtime_pentachora: {
-             "enable": bool,
-             "pool": "mean" | "last",
-             "temp": float,          # similarity temperature (default: 0.10)
-             "coarse_alpha": float,  # hidden blend strength for coarse bank
-             "topic_alpha": float,   # hidden blend strength for topic bank
-             "mood_alpha": float     # hidden blend strength for mood bank
-         }
-     Notes:
-       - Shares token embedding with LM head (tied weights).
-       - Includes Rose anchors and pentachora banks; at runtime we can apply a *non-destructive*
-         vertex pull to hidden states before the LM head using the above config.
      """
      def __init__(self, cfg: dict):
          super().__init__()
          V, D, Ctx = cfg["vocab_size"], cfg["dim"], cfg["context"]
-         H, L, MR = cfg["n_heads"], cfg["n_layers"], cfg["mlp_ratio"]
-         RD, AD = cfg.get("resid_dropout", 0.1), cfg.get("dropout", 0.0)
          self.grad_checkpoint = bool(cfg.get("grad_checkpoint", False))
          self.runtime_cfg: Dict[str, Any] = dict(cfg.get("runtime_pentachora", {}) or {})

          self.vocab_size, self.context = int(V), int(Ctx)
-
          self.token_emb = nn.Embedding(V, D)
          self.pos_emb = nn.Parameter(torch.zeros(1, Ctx, D))
          self.drop = nn.Dropout(RD)
@@ -164,142 +126,123 @@ class BeeperRoseGPT(nn.Module):
                  "attn": CausalSelfAttention(D, H, attn_dropout=AD),
                  "norm2": nn.LayerNorm(D),
                  "mlp": MLP(D, mlp_ratio=MR, dropout=RD),
-             })
-             for _ in range(L)
          ])
-
-         self.norm = nn.LayerNorm(D)
          self.lm_head = nn.Linear(D, V, bias=False)
-         self.lm_head.weight = self.token_emb.weight  # weight tying

-         # Rose projection + anchors (present in checkpoints)
          self.rose_proj = nn.Linear(D, D, bias=False)
-         self.rose_anchors = nn.Parameter(torch.randn(3, D) / (D ** 0.5))

-         # Pentachora banks (created lazily to match state dict)
          self.register_buffer("pent_inited", torch.tensor(0, dtype=torch.uint8), persistent=False)
          self.penta_coarse: Optional[nn.Parameter] = None  # [C,5,D]
          self.penta_medium: Optional[nn.Parameter] = None  # [T,5,D]
          self.penta_fine: Optional[nn.Parameter] = None    # [M,5,D]

-         self.apply(self._init_weights)

      @staticmethod
-     def _init_weights(m: nn.Module):
          if isinstance(m, nn.Linear):
              nn.init.normal_(m.weight, mean=0.0, std=0.02)
-             if m.bias is not None:
-                 nn.init.zeros_(m.bias)
          elif isinstance(m, nn.Embedding):
              nn.init.normal_(m.weight, mean=0.0, std=0.02)

-     # ---- Pentachora creation (must match sizes in checkpoint before strict load) -------------------
      def ensure_pentachora(self, coarse_C: int, medium_C: int, fine_C: int, dim: int, device: torch.device):
-         """Initialize pentachora banks if not already present."""
          if self.pent_inited.item() == 1:
              return
-
          def bank(C: int) -> nn.Parameter:
-             if C <= 0:
-                 return nn.Parameter(torch.zeros((0, 5, dim), device=device))
              pts = torch.randn(C, 5, dim, device=device)
              pts = F.normalize(pts - pts.mean(dim=1, keepdim=True), dim=-1)
              return nn.Parameter(pts)
-
          self.penta_coarse = bank(int(coarse_C))
          self.penta_medium = bank(int(medium_C))
          self.penta_fine = bank(int(fine_C))
          self.pent_inited.fill_(1)

-     # ---- Runtime configuration helpers -------------------------------------------------------------
      def set_runtime_pentachora(self, cfg: Dict[str, Any]) -> None:
-         """Update runtime pentachora behavior (enable/alphas/temp/pool)."""
          self.runtime_cfg.update(cfg or {})

      def _pool_hidden(self, h: torch.Tensor, mode: str) -> torch.Tensor:
          return h.mean(dim=1) if mode == "mean" else h[:, -1, :]
      @staticmethod
      def _weighted_nearest_vertex_target(
-         pooled: torch.Tensor,  # [B,D]
-         bank: torch.Tensor,    # [C,5,D]
-         temp: float
      ) -> torch.Tensor:
          """
-         For each class (simplex) pick its nearest vertex to the pooled latent,
-         then compute a softmax over classes of -min_dists/temp and take the
-         weighted average of those nearest vertices => [B,D] target.
          """
          B, D = pooled.shape
-         C = bank.size(0)
-         if C == 0:
              return pooled
-
-         # distances to each vertex
-         diffs = pooled[:, None, None, :] - bank[None, :, :, :]   # [B,C,5,D]
-         dists = torch.norm(diffs, dim=-1)                        # [B,C,5]
-
-         min_dists, min_idx = dists.min(dim=2)                    # [B,C], [B,C]
-         sims = -min_dists / max(1e-8, float(temp))               # [B,C]
-         weights = F.softmax(sims, dim=-1)                        # [B,C]
-
-         # gather nearest vertex vectors: [B,C,D]
-         bank_exp = bank.unsqueeze(0).expand(B, -1, -1, -1)       # [B,C,5,D]
-         gather_idx = min_idx.unsqueeze(-1).unsqueeze(-1).expand(B, C, 1, D)
-         nearest = torch.gather(bank_exp, 2, gather_idx).squeeze(2)  # [B,C,D]
-
-         target = (weights.unsqueeze(-1) * nearest).sum(dim=1)    # [B,D]
          return target

-     def _apply_runtime_vertex_pull(
-         self,
-         h: torch.Tensor,               # [B,T,D]
-         runtime_cfg: Dict[str, Any]
-     ) -> torch.Tensor:
-         """
-         Apply non-destructive vertex pull to hidden states using banks selected by runtime_cfg.
-         We compute a pooled latent, a per-bank target vector, form a delta, and blend it back into h.
-         """
          if not runtime_cfg or not runtime_cfg.get("enable", False):
              return h
-
          pool_mode = str(runtime_cfg.get("pool", "mean"))
          temp = float(runtime_cfg.get("temp", 0.10))
-
-         # Strengths per bank
-         alpha_coarse = float(runtime_cfg.get("coarse_alpha", 0.0))
-         alpha_topic = float(runtime_cfg.get("topic_alpha", 0.0))
-         alpha_mood = float(runtime_cfg.get("mood_alpha", 0.0))
-
-         if (alpha_coarse <= 0 and alpha_topic <= 0 and alpha_mood <= 0):
              return h

-         pooled = self._pool_hidden(h, pool_mode)  # [B,D]
-
-         total_delta = None
-         if alpha_coarse > 0 and getattr(self, "penta_coarse", None) is not None:
-             tgt = self._weighted_nearest_vertex_target(pooled, self.penta_coarse, temp)
-             delta = tgt - pooled
-             total_delta = (alpha_coarse * delta) if total_delta is None else total_delta + alpha_coarse * delta
-
-         if alpha_topic > 0 and getattr(self, "penta_medium", None) is not None:
-             tgt = self._weighted_nearest_vertex_target(pooled, self.penta_medium, temp)
-             delta = tgt - pooled
-             total_delta = delta * alpha_topic if total_delta is None else total_delta + alpha_topic * delta
-
-         if alpha_mood > 0 and getattr(self, "penta_fine", None) is not None:
-             tgt = self._weighted_nearest_vertex_target(pooled, self.penta_fine, temp)
-             delta = tgt - pooled
-             total_delta = delta * alpha_mood if total_delta is None else total_delta + alpha_mood * delta
-
-         if total_delta is None:
              return h

-         # Broadcast same delta to all time steps (global conditioning shift)
-         h = h + total_delta.unsqueeze(1)  # [B,T,D]
-         return h
-
-     # ---- Backbone / forward -----------------------------------------------------------------------
      def _block_forward(self, blk: nn.ModuleDict, x: torch.Tensor) -> torch.Tensor:
          x = x + blk["attn"](blk["norm1"](x))
          x = x + blk["mlp"](blk["norm2"](x))
@@ -312,87 +255,53 @@ class BeeperRoseGPT(nn.Module):
          if self.grad_checkpoint and self.training:
              from torch.utils.checkpoint import checkpoint
              for blk in self.blocks:
-                 x = checkpoint(lambda _x: self._block_forward(blk, _x), x)  # type: ignore[arg-type]
          else:
              for blk in self.blocks:
                  x = self._block_forward(blk, x)
          return self.norm(x)

      def forward(self, idx: torch.Tensor, runtime_cfg: Optional[Dict[str, Any]] = None) -> torch.Tensor:
-         """
-         Forward pass with optional runtime pentachora influence.
-         If runtime_cfg is None, falls back to self.runtime_cfg set at init or via set_runtime_pentachora().
-         """
          h = self.backbone(idx)
          cfg = self.runtime_cfg if runtime_cfg is None else {**self.runtime_cfg, **(runtime_cfg or {})}
          h = self._apply_runtime_vertex_pull(h, cfg)
          return self.lm_head(h)

-     # ---- Utilities ---------------------------------------------------------------------------------
      def hidden_states(self, idx: torch.Tensor) -> torch.Tensor:
-         """Return final hidden states (pre-LM head)."""
          return self.backbone(idx)
-
      def rose_hidden_pool(self, h: torch.Tensor, mode: str = "mean") -> torch.Tensor:
-         """Pool hidden states for Rose-related terms."""
-         return h.mean(dim=1) if mode == "mean" else h[:, -1, :]

-
- # --------------------------------- Loader helpers ---------------------------------------------------
- def prepare_model_for_state_dict(
-     model: BeeperRoseGPT,
-     state_dict: "dict[str, torch.Tensor]",
-     device: Optional[torch.device] = None,
- ) -> None:
-     """
-     Ensure model has pentachora parameters sized to match the incoming state_dict,
-     so we can load with strict=True. No-op if checkpoint lacks penta_* keys.
-     """
      device = device or next(model.parameters()).device
-     need = all(k in state_dict for k in ("penta_coarse", "penta_medium", "penta_fine"))
-     if not need:
-         return
-
-     pc, pt, pm = state_dict["penta_coarse"], state_dict["penta_medium"], state_dict["penta_fine"]
-
-     def dims_ok(t: torch.Tensor, D: int) -> bool:
-         return t.ndim == 3 and t.size(1) == 5 and t.size(2) == D
-
      D = model.token_emb.embedding_dim
-     if not (dims_ok(pc, D) and dims_ok(pt, D) and dims_ok(pm, D)):
-         return
-
      model.ensure_pentachora(pc.size(0), pt.size(0), pm.size(0), dim=D, device=device)

-
- # --------------------------------- Generation -------------------------------------------------------
  def _detok(text: str) -> str:
      text = re.sub(r"\s+([,.;:!?%])", r"\1", text)
      text = re.sub(r"\s+([\)\]\}])", r"\1", text)
      text = re.sub(r"([\(\[\{])\s+", r"\1", text)
      return text

-
  @torch.no_grad()
- def generate(
-     model: BeeperRoseGPT,
-     tok,  # Hugging Face Tokenizers `Tokenizer`
-     cfg: dict,
-     prompt: str,
-     max_new_tokens: int = 120,
-     temperature: Optional[float] = None,
-     top_k: Optional[int] = None,
-     top_p: Optional[float] = None,
-     repetition_penalty: Optional[float] = None,
-     presence_penalty: Optional[float] = None,
-     frequency_penalty: Optional[float] = None,
-     device: Optional[torch.device] = None,
-     detokenize: bool = True,
-     runtime_cfg: Optional[Dict[str, Any]] = None,  # <— NEW: pass-through to forward()
- ) -> str:
-     """
-     Penalized nucleus sampling with optional runtime pentachora influence.
-     """
      temperature = cfg.get("temperature", 0.9) if temperature is None else float(temperature)
      top_k = cfg.get("top_k", 40) if top_k is None else int(top_k)
      top_p = cfg.get("top_p", 0.9) if top_p is None else float(top_p)
@@ -408,14 +317,12 @@ def generate(
      V = int(cfg["vocab_size"])
      counts = torch.zeros(V, dtype=torch.int32, device=device)
      for t in ids:
-         if 0 <= t < V:
-             counts[t] += 1

      for _ in range(int(max_new_tokens)):
          logits = model(x[:, -cfg["context"]:], runtime_cfg=runtime_cfg)
          logits = logits[:, -1, :]

-         # Repetition penalty
          if repetition_penalty and repetition_penalty != 1.0:
              mask = counts > 0
              if mask.any():
@@ -423,9 +330,8 @@ def generate(
                  logits[:, mask][pos] /= repetition_penalty
                  logits[:, mask][~pos] *= repetition_penalty

-         # Presence/frequency penalties
          if presence_penalty or frequency_penalty:
-             pen = counts.float() * (frequency_penalty or 0.0) + (counts > 0).float() * (presence_penalty or 0.0)
              logits = logits - pen.unsqueeze(0)

          logits = logits / max(1e-8, temperature)
@@ -433,8 +339,7 @@ def generate(
          if top_k and top_k > 0:
              k = min(top_k, logits.size(-1))
              v, ix = torch.topk(logits, k, dim=-1)
-             filt = torch.full_like(logits, float("-inf"))
-             logits = filt.scatter_(-1, ix, v)

          if top_p and top_p < 1.0:
              sl, si = torch.sort(logits, descending=True)
@@ -449,8 +354,7 @@ def generate(
          next_id = torch.multinomial(probs, num_samples=1)
          x = torch.cat([x, next_id], dim=1)
          nid = next_id.item()
-         if 0 <= nid < V:
-             counts[nid] += 1

      out = tok.decode(x[0].tolist())
      return _detok(out) if detokenize else out
 
  # beeper.py
+ # Beeper — Rose-based tiny GPT (inference, with runtime pentachora influence + class/topic/mood selection)
  from __future__ import annotations

+ import math, re, inspect
  from contextlib import nullcontext
+ from typing import Optional, Dict, Any, Iterable

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

  torch.set_float32_matmul_precision("high")
  torch.backends.cuda.matmul.allow_tf32 = True
  torch.backends.cudnn.allow_tf32 = True

+ # ---- SDPA (FlashAttention) selection ----
  try:
      from torch.nn.attention import sdpa_kernel as _sdpa_kernel_modern
      from torch.nn.attention import SDPBackend as _SDPBackend
      _SDPA_SIG = inspect.signature(_sdpa_kernel_modern)
      _sdpa_kernel = _sdpa_kernel_modern
  except Exception:
      try:
          from torch.backends.cuda import sdp_kernel as _sdpa_kernel_legacy
          _SDPA_SIG = inspect.signature(_sdpa_kernel_legacy)
          _SDPBackend = None
          _SDPBackend = None
          _sdpa_kernel = None

  def sdpa_ctx_prefer_flash():
      if _sdpa_kernel is None or _SDPA_SIG is None:
          return nullcontext()
      params = {p.name for p in _SDPA_SIG.parameters.values()}
      try:
          if "backends" in params and _SDPBackend is not None:
+             return _sdpa_kernel(backends=[_SDPBackend.FLASH_ATTENTION, _SDPBackend.EFFICIENT_ATTENTION, _SDPBackend.MATH])
          if "backend" in params and _SDPBackend is not None:
              return _sdpa_kernel(backend=_SDPBackend.FLASH_ATTENTION)
+         if {"enable_flash", "enable_math", "enable_mem_efficient"} <= params:
              return _sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True)
+         if {"use_flash", "use_math", "use_mem_efficient"} <= params:
              return _sdpa_kernel(use_flash=True, use_math=False, use_mem_efficient=True)
      except Exception:
          pass
      return nullcontext()

+ # ---------------- Blocks ----------------
  class CausalSelfAttention(nn.Module):
      def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
          super().__init__()
+         assert dim % n_heads == 0
+         self.nh = n_heads
+         self.hd = dim // n_heads
+         self.qkv = nn.Linear(dim, 3 * dim, bias=False)
          self.proj = nn.Linear(dim, dim, bias=False)
          self.attn_dropout = float(attn_dropout)

          B, T, C = x.shape
          qkv = self.qkv(x)
          q, k, v = qkv.chunk(3, dim=-1)
+         q = q.view(B, T, self.nh, self.hd).transpose(1, 2)
          k = k.view(B, T, self.nh, self.hd).transpose(1, 2)
          v = v.view(B, T, self.nh, self.hd).transpose(1, 2)
          if x.is_cuda:
              with sdpa_ctx_prefer_flash():
+                 y = F.scaled_dot_product_attention(q, k, v, is_causal=True,
+                                                    dropout_p=self.attn_dropout if self.training else 0.0)
          else:
              scale = 1.0 / math.sqrt(self.hd)
              att = (q @ k.transpose(-2, -1)) * scale
+             mask = torch.triu(torch.full((1, 1, T, T), float("-inf"), device=x.device), diagonal=1)
+             y = (att + mask).softmax(dim=-1) @ v
          y = y.transpose(1, 2).contiguous().view(B, T, C)
          return self.proj(y)

  class MLP(nn.Module):
      def __init__(self, dim: int, mlp_ratio: float = 4.0, dropout: float = 0.1):
          super().__init__()
+         h = int(dim * mlp_ratio)
+         self.fc1 = nn.Linear(dim, h)
+         self.fc2 = nn.Linear(h, dim)
          self.drop = nn.Dropout(dropout)
+     def forward(self, x):
+         x = F.gelu(self.fc1(x), approximate="tanh")
          x = self.drop(x)
          x = self.fc2(x)
          x = self.drop(x)
          return x

+ # --------------- Model ---------------
  class BeeperRoseGPT(nn.Module):
      """
+     Runtime pentachora control via self.runtime_cfg:
+       {
+         "enable": bool,
+         "pool": "mean" | "last",
+         "temp": 0.10,
+         "coarse_alpha": float, "topic_alpha": float, "mood_alpha": float,
+         # NEW: selection masks (ints or lists of ints)
+         "coarse_select": Optional[Iterable[int]],
+         "topic_select": Optional[Iterable[int]],
+         "mood_select": Optional[Iterable[int]],
+       }
      """
      def __init__(self, cfg: dict):
          super().__init__()
          V, D, Ctx = cfg["vocab_size"], cfg["dim"], cfg["context"]
+         H, L, MR = cfg["n_heads"], cfg["n_layers"], cfg["mlp_ratio"]
+         RD, AD = cfg.get("resid_dropout", 0.1), cfg.get("dropout", 0.0)
          self.grad_checkpoint = bool(cfg.get("grad_checkpoint", False))
          self.runtime_cfg: Dict[str, Any] = dict(cfg.get("runtime_pentachora", {}) or {})

          self.vocab_size, self.context = int(V), int(Ctx)
          self.token_emb = nn.Embedding(V, D)
          self.pos_emb = nn.Parameter(torch.zeros(1, Ctx, D))
          self.drop = nn.Dropout(RD)
                  "attn": CausalSelfAttention(D, H, attn_dropout=AD),
                  "norm2": nn.LayerNorm(D),
                  "mlp": MLP(D, mlp_ratio=MR, dropout=RD),
+             }) for _ in range(L)
          ])
+         self.norm = nn.LayerNorm(D)
          self.lm_head = nn.Linear(D, V, bias=False)
+         self.lm_head.weight = self.token_emb.weight

+         # Rose anchors (kept for compatibility)
          self.rose_proj = nn.Linear(D, D, bias=False)
+         self.rose_anchors = nn.Parameter(torch.randn(3, D) / (D ** 0.5))

+         # Pentachora banks
          self.register_buffer("pent_inited", torch.tensor(0, dtype=torch.uint8), persistent=False)
          self.penta_coarse: Optional[nn.Parameter] = None  # [C,5,D]
          self.penta_medium: Optional[nn.Parameter] = None  # [T,5,D]
          self.penta_fine: Optional[nn.Parameter] = None    # [M,5,D]

+         self.apply(self._init)

      @staticmethod
+     def _init(m):
          if isinstance(m, nn.Linear):
              nn.init.normal_(m.weight, mean=0.0, std=0.02)
+             if m.bias is not None: nn.init.zeros_(m.bias)
          elif isinstance(m, nn.Embedding):
              nn.init.normal_(m.weight, mean=0.0, std=0.02)

      def ensure_pentachora(self, coarse_C: int, medium_C: int, fine_C: int, dim: int, device: torch.device):
          if self.pent_inited.item() == 1:
              return
          def bank(C: int) -> nn.Parameter:
+             if C <= 0: return nn.Parameter(torch.zeros((0, 5, dim), device=device))
              pts = torch.randn(C, 5, dim, device=device)
              pts = F.normalize(pts - pts.mean(dim=1, keepdim=True), dim=-1)
              return nn.Parameter(pts)
          self.penta_coarse = bank(int(coarse_C))
          self.penta_medium = bank(int(medium_C))
          self.penta_fine = bank(int(fine_C))
          self.pent_inited.fill_(1)

      def set_runtime_pentachora(self, cfg: Dict[str, Any]) -> None:
          self.runtime_cfg.update(cfg or {})

      def _pool_hidden(self, h: torch.Tensor, mode: str) -> torch.Tensor:
          return h.mean(dim=1) if mode == "mean" else h[:, -1, :]

+     @staticmethod
+     def _normalize_indices(sel: Optional[Iterable[int]], C: int) -> Optional[torch.Tensor]:
+         if sel is None: return None
+         if isinstance(sel, int): sel = [sel]
+         sel = [int(x) for x in sel if 0 <= int(x) < C]
+         if not sel: return None
+         return torch.as_tensor(sel, dtype=torch.long)
+
      @staticmethod
      def _weighted_nearest_vertex_target(
+         pooled: torch.Tensor,                        # [B,D]
+         bank: torch.Tensor,                          # [C,5,D]
+         temp: float,
+         restrict_idx: Optional[torch.Tensor] = None  # [K] or None
      ) -> torch.Tensor:
          """
+         If restrict_idx is given, compute the target within the selected classes only.
          """
          B, D = pooled.shape
+         if bank.size(0) == 0:
              return pooled
+         if restrict_idx is not None:
+             bank = bank.index_select(0, restrict_idx.to(bank.device))  # [K,5,D]
+         diffs = pooled[:, None, None, :] - bank[None, :, :, :]         # [B,C|K,5,D]
+         dists = torch.norm(diffs, dim=-1)                              # [B,C|K,5]
+         min_dists, min_idx = dists.min(dim=2)                          # [B,C|K]
+         sims = -min_dists / max(1e-8, float(temp))                     # [B,C|K]
+         weights = F.softmax(sims, dim=-1)                              # [B,C|K]
+         # Expand the bank to the batch before gather (torch.gather needs matching batch dim for B > 1).
+         bank_exp = bank.unsqueeze(0).expand(B, -1, -1, -1)             # [B,C|K,5,D]
+         gather_idx = min_idx[..., None, None].expand(B, weights.size(1), 1, D)
+         nearest = bank_exp.gather(2, gather_idx).squeeze(2)            # [B,C|K,D]
+         target = (weights.unsqueeze(-1) * nearest).sum(dim=1)          # [B,D]
          return target
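    # Sketch of what the method computes: with d_c = min_v ||pooled - bank[c, v]|| over the five
    # vertices of class c, the target is sum_c softmax_c(-d_c / temp) * bank[c, argmin_v],
    # i.e. each (selected) class contributes its nearest vertex, weighted by how close it is.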

+     def _apply_runtime_vertex_pull(self, h: torch.Tensor, runtime_cfg: Dict[str, Any]) -> torch.Tensor:
          if not runtime_cfg or not runtime_cfg.get("enable", False):
              return h
          pool_mode = str(runtime_cfg.get("pool", "mean"))
          temp = float(runtime_cfg.get("temp", 0.10))
+         a_coarse = float(runtime_cfg.get("coarse_alpha", 0.0))
+         a_topic = float(runtime_cfg.get("topic_alpha", 0.0))
+         a_mood = float(runtime_cfg.get("mood_alpha", 0.0))
+         if a_coarse <= 0 and a_topic <= 0 and a_mood <= 0:
              return h

+         pooled = self._pool_hidden(h, pool_mode)  # [B,D]
+         delta = None
+
+         if a_coarse > 0 and getattr(self, "penta_coarse", None) is not None:
+             C = self.penta_coarse.size(0)
+             r = self._normalize_indices(runtime_cfg.get("coarse_select"), C)
+             tgt = self._weighted_nearest_vertex_target(pooled, self.penta_coarse, temp, r)
+             d = tgt - pooled
+             delta = a_coarse * d if delta is None else delta + a_coarse * d
+
+         if a_topic > 0 and getattr(self, "penta_medium", None) is not None:
+             C = self.penta_medium.size(0)
+             r = self._normalize_indices(runtime_cfg.get("topic_select"), C)
+             tgt = self._weighted_nearest_vertex_target(pooled, self.penta_medium, temp, r)
+             d = tgt - pooled
+             delta = a_topic * d if delta is None else delta + a_topic * d
+
+         if a_mood > 0 and getattr(self, "penta_fine", None) is not None:
+             C = self.penta_fine.size(0)
+             r = self._normalize_indices(runtime_cfg.get("mood_select"), C)
+             tgt = self._weighted_nearest_vertex_target(pooled, self.penta_fine, temp, r)
+             d = tgt - pooled
+             delta = a_mood * d if delta is None else delta + a_mood * d
+
+         if delta is None:
              return h
+         return h + delta.unsqueeze(1)  # broadcast across time
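        # Net effect (sketch): h' = h + sum_b alpha_b * (target_b - pooled), with the same delta
        # added at every time step, so the pull acts as a global conditioning shift rather than a
        # per-token edit.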

+     # ---- forward ----
      def _block_forward(self, blk: nn.ModuleDict, x: torch.Tensor) -> torch.Tensor:
          x = x + blk["attn"](blk["norm1"](x))
          x = x + blk["mlp"](blk["norm2"](x))
          if self.grad_checkpoint and self.training:
              from torch.utils.checkpoint import checkpoint
              for blk in self.blocks:
+                 x = checkpoint(lambda _x: self._block_forward(blk, _x), x)  # type: ignore
          else:
              for blk in self.blocks:
                  x = self._block_forward(blk, x)
          return self.norm(x)

      def forward(self, idx: torch.Tensor, runtime_cfg: Optional[Dict[str, Any]] = None) -> torch.Tensor:
          h = self.backbone(idx)
          cfg = self.runtime_cfg if runtime_cfg is None else {**self.runtime_cfg, **(runtime_cfg or {})}
          h = self._apply_runtime_vertex_pull(h, cfg)
          return self.lm_head(h)

+     # Utilities
      def hidden_states(self, idx: torch.Tensor) -> torch.Tensor:
          return self.backbone(idx)
      def rose_hidden_pool(self, h: torch.Tensor, mode: str = "mean") -> torch.Tensor:
+         return h.mean(dim=1) if mode == "mean" else h[:, -1, :]

+ # ---- Loader helper ----
+ def prepare_model_for_state_dict(model: BeeperRoseGPT, state_dict: Dict[str, torch.Tensor], device: Optional[torch.device] = None) -> None:
      device = device or next(model.parameters()).device
+     need = all(k in state_dict for k in ("penta_coarse", "penta_medium", "penta_fine"))
+     if not need: return
      D = model.token_emb.embedding_dim
+     pc, pt, pm = state_dict["penta_coarse"], state_dict["penta_medium"], state_dict["penta_fine"]
+     ok = lambda t: (t.ndim == 3 and t.size(1) == 5 and t.size(2) == D)
+     if not (ok(pc) and ok(pt) and ok(pm)): return
      model.ensure_pentachora(pc.size(0), pt.size(0), pm.size(0), dim=D, device=device)
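# Example (sketch): size the pentachora banks from an incoming checkpoint, then load strictly.
# The checkpoint filename is hypothetical.
#   sd = torch.load("beeper_checkpoint.pt", map_location="cpu")
#   model = BeeperRoseGPT(cfg)
#   prepare_model_for_state_dict(model, sd)
#   model.load_state_dict(sd, strict=True)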

+ # ---- Generation ----
  def _detok(text: str) -> str:
      text = re.sub(r"\s+([,.;:!?%])", r"\1", text)
      text = re.sub(r"\s+([\)\]\}])", r"\1", text)
      text = re.sub(r"([\(\[\{])\s+", r"\1", text)
      return text

  @torch.no_grad()
+ def generate(model: BeeperRoseGPT, tok, cfg: dict, prompt: str,
+              max_new_tokens: int = 120, temperature: float | None = None,
+              top_k: int | None = None, top_p: float | None = None,
+              repetition_penalty: float | None = None,
+              presence_penalty: float | None = None,
+              frequency_penalty: float | None = None,
+              device: Optional[torch.device] = None,
+              detokenize: bool = True,
+              runtime_cfg: Optional[Dict[str, Any]] = None) -> str:
+
      temperature = cfg.get("temperature", 0.9) if temperature is None else float(temperature)
      top_k = cfg.get("top_k", 40) if top_k is None else int(top_k)
      top_p = cfg.get("top_p", 0.9) if top_p is None else float(top_p)
      V = int(cfg["vocab_size"])
      counts = torch.zeros(V, dtype=torch.int32, device=device)
      for t in ids:
+         if 0 <= t < V: counts[t] += 1

      for _ in range(int(max_new_tokens)):
          logits = model(x[:, -cfg["context"]:], runtime_cfg=runtime_cfg)
          logits = logits[:, -1, :]

          if repetition_penalty and repetition_penalty != 1.0:
              mask = counts > 0
              if mask.any():
                  logits[:, mask][pos] /= repetition_penalty
                  logits[:, mask][~pos] *= repetition_penalty

          if presence_penalty or frequency_penalty:
+             pen = counts.float() * (frequency_penalty or 0.0) + (counts > 0).float() * (presence_penalty or 0.0)
              logits = logits - pen.unsqueeze(0)

          logits = logits / max(1e-8, temperature)
          if top_k and top_k > 0:
              k = min(top_k, logits.size(-1))
              v, ix = torch.topk(logits, k, dim=-1)
+             logits = torch.full_like(logits, float("-inf")).scatter(-1, ix, v)

          if top_p and top_p < 1.0:
              sl, si = torch.sort(logits, descending=True)
          next_id = torch.multinomial(probs, num_samples=1)
          x = torch.cat([x, next_id], dim=1)
          nid = next_id.item()
+         if 0 <= nid < V: counts[nid] += 1

      out = tok.decode(x[0].tolist())
      return _detok(out) if detokenize else out
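A minimal end-to-end sketch of driving the updated module (the tokenizer path, checkpoint name, config sizes, and selection indices below are illustrative assumptions, not part of this commit):

from tokenizers import Tokenizer

cfg = {"vocab_size": 8192, "dim": 256, "context": 512, "n_heads": 8, "n_layers": 8,
       "mlp_ratio": 4.0, "temperature": 0.9, "top_k": 40, "top_p": 0.9}  # illustrative sizes
model = BeeperRoseGPT(cfg).eval()
tok = Tokenizer.from_file("beeper_tokenizer.json")    # hypothetical path

sd = torch.load("beeper.pt", map_location="cpu")      # hypothetical checkpoint
prepare_model_for_state_dict(model, sd)               # sizes the penta_* banks before a strict load
model.load_state_dict(sd, strict=True)

text = generate(model, tok, cfg, "Hello Beeper,",
                max_new_tokens=60,
                runtime_cfg={"enable": True, "coarse_alpha": 0.1, "coarse_select": [1]})
print(text)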