# reposage-chatbot/assist/transformer_demo.py
import torch
import torch.nn as nn
import numpy as np
from .chat import TOKEN2IDX, IDX2TOKEN # reuse your vocab maps
from .chat import WEIGHTS # reuse your embedding weights


class SingleTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=2):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Self-attention with residual connection, then layer norm (post-norm)
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        # Position-wise feed-forward with residual connection, then layer norm
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)
        return x
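
# Quick sanity check of the block's shape contract (illustrative sketch only;
# the 16-dim embedding and 5-token sequence are made-up numbers, not the real
# dimensions of the vocab in .chat):
#
#     demo_block = SingleTransformerBlock(embed_dim=16, num_heads=2)
#     assert demo_block(torch.zeros(1, 5, 16)).shape == (1, 5, 16)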

# Instantiate once at import time: the embedding table plus a single
# (untrained, randomly initialized) transformer block sized to match it.
_EMB = torch.tensor(WEIGHTS, dtype=torch.float32)  # V × D embedding matrix
_block = SingleTransformerBlock(embed_dim=_EMB.size(1), num_heads=2)
_block.eval()  # inference only


def transformer_next(prompt: str) -> str:
    """
    Given a prompt, tokenize it, embed each known token, run the sequence
    through one transformer block, then use the last position's output vector
    to pick the nearest vocab token (by cosine similarity) as the "next token."
    """
    tokens = prompt.lower().split()
    idxs = [TOKEN2IDX[t] for t in tokens if t in TOKEN2IDX]
    if not idxs:
        return "🤔 No known tokens to predict from."

    # Build batch of shape 1 × seq_len × D
    x = _EMB[idxs].unsqueeze(0)

    # Forward pass (no gradients needed at inference time)
    with torch.no_grad():
        out = _block(x)  # 1 × seq_len × D

    last = out[0, -1].unsqueeze(0)  # 1 × D

    # Cosine similarity of the last position's output against every vocab embedding
    sims = nn.functional.cosine_similarity(last, _EMB)  # V
    best = int(torch.argmax(sims))
    return f"🔮 Next-token prediction: **{IDX2TOKEN[best]}**"