File size: 1,837 Bytes
d1a29d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn
import numpy as np
from .chat import TOKEN2IDX, IDX2TOKEN  # reuse your vocab maps
from .chat import WEIGHTS              # reuse your embedding weights

class SingleTransformerBlock(nn.Module):
    """A single post-norm transformer layer.

    The input first goes through multi-head self-attention, is added back to
    itself (residual) and layer-normalised; the result then goes through a
    two-layer ReLU feed-forward network, again with a residual connection
    followed by layer normalisation.

    Args:
        embed_dim: width D of each input/output vector. nn.MultiheadAttention
            requires it to be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, num_heads=2):
        super().__init__()
        hidden_dim = 4 * embed_dim  # feed-forward inner width: 4x the model width

        # Submodules are created attn -> norm1 -> ff -> norm2; this order fixes
        # both the state_dict layout and the order random initialisation runs.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers to ``x``.

        ``x`` is batch-first (batch, seq_len, embed_dim), as configured on the
        attention layer; the output has the same shape.
        """
        attended = self.attn(x, x, x)[0]  # keep only the attention output
        mixed = self.norm1(x + attended)
        return self.norm2(mixed + self.ff(mixed))

# Instantiate once at import time so every transformer_next call reuses them.
# Note: _block is randomly initialised here — no trained weights are loaded.
_EMB = torch.tensor(WEIGHTS, dtype=torch.float32)              # V×D
# nn.MultiheadAttention requires embed_dim % num_heads == 0; fall back to a
# single head for odd embedding widths instead of failing at import.
_NUM_HEADS = 2 if _EMB.size(1) % 2 == 0 else 1
_block = SingleTransformerBlock(embed_dim=_EMB.size(1), num_heads=_NUM_HEADS)
# The block is only used for inference: put it in eval mode (so any dropout /
# training-only behaviour is disabled) and stop tracking gradients.
_block.eval()
_block.requires_grad_(False)

def transformer_next(prompt: str) -> str:
    """
    Predict a "next token" for *prompt* using the single transformer block.

    The prompt is lower-cased and split on whitespace; words that are not in
    TOKEN2IDX are silently dropped. The remaining token ids are embedded,
    passed through ``_block`` as a batch of one sequence, and the output
    vector at the last position is compared by cosine similarity against
    every row of the embedding table. The most similar vocabulary token is
    reported.

    Args:
        prompt: free-form input text.

    Returns:
        A user-facing message naming the predicted token, or a notice when
        none of the prompt's words are in the vocabulary.
    """
    tokens = prompt.lower().split()
    idxs = [TOKEN2IDX[t] for t in tokens if t in TOKEN2IDX]
    if not idxs:
        return "🤔 No known tokens to predict from."
    # Pure inference: skip building an autograd graph through _block.
    with torch.no_grad():
        x = _EMB[idxs].unsqueeze(0)          # 1×seq_len×D batch of one
        out = _block(x)                      # 1×seq_len×D
        last = out[0, -1].unsqueeze(0)       # 1×D, last position's output
        # 1×D against V×D broadcasts to one similarity per vocab row (V).
        sims = nn.functional.cosine_similarity(last, _EMB)
        best = int(torch.argmax(sims))
    return f"🔮 Next‐token prediction: **{IDX2TOKEN[best]}**"