# PromeMobile / model.py
# Author: Neu256 -- commit 488f6f3 ("Update model.py"), 8.17 kB
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the complex rotary-position (RoPE) lookup table.

    Entry ``[t, k]`` is ``exp(i * t * theta ** (-2k / dim))``, i.e. a unit
    complex number whose angle grows linearly with position ``t``.

    Args:
        dim: rotary width (the per-head size); ``dim // 2`` frequencies are made.
        end: number of positions to precompute.
        theta: base of the geometric frequency progression.

    Returns:
        complex64 tensor of shape ``(end, dim // 2)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    # Unit magnitude, angle = position * frequency.
    return torch.polar(torch.ones_like(angles), angles)
def reshape_for_broadcast(freqs_cis, x):
    """Trim and reshape a RoPE table so it broadcasts against ``x``.

    Args:
        freqs_cis: table with at least ``seq_len`` rows and ``head_size // 2``
            columns.
        x: 4-D tensor shaped (batch, heads, seq_len, head_size).

    Returns:
        View of the first ``seq_len`` rows with shape
        ``(1, 1, seq_len, head_size // 2)``.
    """
    _, _, seq_len, head_size = x.shape
    target = (1, 1, seq_len, head_size // 2)
    return freqs_cis[:seq_len].view(target)
def apply_rope(x, position, freqs_cis):
    """Apply rotary position embeddings (RoPE) to ``x``.

    Consecutive pairs along the last dimension are viewed as complex numbers
    and multiplied by the matching entries of ``freqs_cis``.

    Args:
        x: tensor shaped (batch, heads, seq_len, head_size); head_size must be
            even.
        position: unused; kept so existing call sites keep working.
        freqs_cis: complex table with at least ``seq_len`` rows and
            ``head_size // 2`` columns (as built by ``precompute_freqs_cis``).

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # The table may have been precomputed on a different device (e.g. CPU)
    # than the activations; align it so the complex multiply doesn't fail.
    # .to() is a no-op when the device already matches.
    freqs = reshape_for_broadcast(freqs_cis, x).to(x_complex.device)
    x_out = torch.view_as_real(x_complex * freqs).flatten(3)
    return x_out.type_as(x)
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learned per-feature scale.

    Normalizes over the last dimension in float32, casts back to the input
    dtype, then multiplies by ``weight`` (initialized to ones).

    Args:
        dim: size of the last dimension (length of ``weight``).
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), written with rsqrt.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
class Attention(nn.Module):
    """
    Multi-head causal self-attention with rotary position embeddings (RoPE).

    Args:
        num_heads: number of attention heads.
        head_size: per-head width; q/k/v project num_embed -> num_heads * head_size.
        num_embed: model (input and output) width.
    """

    def __init__(self, num_heads, head_size, num_embed):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        inner_dim = num_heads * head_size
        self.wq = nn.Linear(num_embed, inner_dim, bias=False)
        self.wk = nn.Linear(num_embed, inner_dim, bias=False)
        self.wv = nn.Linear(num_embed, inner_dim, bias=False)
        self.wo = nn.Linear(inner_dim, num_embed, bias=False)

    def _split_heads(self, t, B, T):
        # (B, T, H*D) -> (B, H, T, D)
        return t.view(B, T, self.num_heads, self.head_size).transpose(1, 2)

    def forward(self, x, freqs_cis):
        """Attend causally over ``x``.

        Args:
            x: (B, T, num_embed) input.
            freqs_cis: complex RoPE table with at least T rows and
                head_size // 2 columns.

        Returns:
            (B, T, num_embed) tensor.
        """
        B, T, _ = x.shape
        # Additive causal mask: -inf strictly above the diagonal, so position
        # t only attends to positions <= t.
        mask = torch.triu(torch.full((T, T), float("-inf"), device=x.device), diagonal=1)

        xq = self._split_heads(self.wq(x), B, T)
        xk = self._split_heads(self.wk(x), B, T)
        xv = self._split_heads(self.wv(x), B, T)

        xq = apply_rope(xq, T, freqs_cis)
        xk = apply_rope(xk, T, freqs_cis)

        attn_weights = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_size)
        attn_weights += mask
        # Softmax in float32 for numerical stability, then back to q's dtype.
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
        output = torch.matmul(attn_weights, xv)
        # (B, H, T, D) -> (B, T, H*D). Use H*D rather than the input width C:
        # they differ whenever num_heads * head_size != num_embed.
        output = output.transpose(1, 2).contiguous().view(B, T, self.num_heads * self.head_size)
        return self.wo(output)
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> SiLU -> Linear -> Dropout.

    The hidden width is ``3 * int(num_embed * 2 / 3)``, roughly twice
    ``num_embed``.
    """

    def __init__(self, num_embed, dropout):
        super().__init__()
        self.num_embed = num_embed
        hidden_dim = 3 * int(num_embed * 2 / 3)
        self.linear1 = nn.Linear(num_embed, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, num_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.silu(self.linear1(x))
        return self.dropout(self.linear2(hidden))
class TransformerBlock(nn.Module):
    """
    One pre-norm transformer layer. It groups multi-head self-attention and a
    feed-forward network so that the layer can be stacked repeatedly.

    Each sub-layer sees an RMS-normalized input and is wrapped in a residual
    connection (normalizing before, rather than after, each sub-layer is a
    reordering of the original Transformer paper).
    """

    def __init__(self, num_heads, num_embed, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.num_embed = num_embed
        self.sa = Attention(
            num_heads=num_heads,
            head_size=num_embed // num_heads,
            num_embed=num_embed,
        )
        self.ffwd = MLP(num_embed=num_embed, dropout=dropout)
        self.ln1 = RMSNorm(num_embed)
        self.ln2 = RMSNorm(num_embed)

    def forward(self, x, freqs_cis):
        # Residual around attention, then residual around the MLP.
        after_attn = x + self.sa(self.ln1(x), freqs_cis)
        return after_attn + self.ffwd(self.ln2(after_attn))
class Transformer(nn.Module):
    """Decoder-only language model.

    Pipeline: token embedding, a stack of ``TransformerBlock`` layers (RoPE
    attention + MLP), a final RMSNorm, and a linear head producing logits.

    Keyword args (all optional):
        vocab_size (default 100): number of token ids.
        num_embed (default 32): model width.
        num_heads (default 4): attention heads per block.
        num_layers (default 4): number of blocks.
        max_seq_len (default 1024): longest sequence ``forward`` accepts;
            ``generate`` crops its context to this length.
        dropout (default 0.2): passed to every block's MLP.
        rope_theta (default 500000): RoPE base frequency.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.vocab_size = kwargs.get("vocab_size", 100)
        self.num_embed = kwargs.get("num_embed", 32)
        self.num_heads = kwargs.get("num_heads", 4)
        self.num_layers = kwargs.get("num_layers", 4)
        self.max_seq_len = kwargs.get("max_seq_len", 1024)
        self.dropout = kwargs.get("dropout", 0.2)
        self.rope_theta = kwargs.get("rope_theta", 500000)
        # Token ids -> num_embed-dim vectors. There is no learned position
        # embedding: position information comes from RoPE inside attention.
        self.token_embedding_table = nn.Embedding(self.vocab_size, self.num_embed)
        self.blocks = nn.ModuleList([
            TransformerBlock(
                num_heads=self.num_heads,
                num_embed=self.num_embed,
                dropout=self.dropout,
            )
            for _ in range(self.num_layers)
        ])
        self.lm_head = nn.Linear(self.num_embed, self.vocab_size)
        # Final norm, applied to the last block's output before lm_head.
        self.norm = RMSNorm(self.num_embed)
        # Complex RoPE table. It is a plain attribute, not a registered buffer,
        # so model.to(...) does not move it; forward() moves it to the input's
        # device on demand.
        self.freqs_cis = precompute_freqs_cis(
            self.num_embed // self.num_heads,
            self.max_seq_len * 2,
            self.rope_theta,
        )

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) integer token ids with T <= max_seq_len.
            targets: optional (B, T) integer ids; entries equal to -1 are
                ignored by the loss.

        Returns:
            ``(logits, loss)``: logits is (B, T, vocab_size); loss is a scalar
            tensor, or None when ``targets`` is None.

        Raises:
            ValueError: if T exceeds max_seq_len.
        """
        _, T = idx.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.max_seq_len}"
            )
        # Keep the RoPE table on the same device as the inputs (e.g. after
        # model.cuda()); done once, then cached.
        if self.freqs_cis.device != idx.device:
            self.freqs_cis = self.freqs_cis.to(idx.device)

        x = self.token_embedding_table(idx)  # (B, T, num_embed)
        freq = self.freqs_cis[:self.max_seq_len]
        for block in self.blocks:
            x = block(x, freq)
        x = self.norm(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            return logits, None
        # cross_entropy expects (N, num_classes), so flatten batch and time.
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=-1,
        )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 0.7, top_p: float = 0.9):
        """Autoregressively append ``max_new_tokens`` tokens to ``idx``.

        Args:
            idx: (B, T) prompt token ids.
            max_new_tokens: number of tokens to generate.
            temperature: if > 0, sample from softmax(logits / temperature)
                with top-p filtering; if <= 0, decode greedily (argmax).
            top_p: nucleus-sampling mass, used only when temperature > 0.

        Returns:
            1-D tensor: the first batch row of the prompt followed by the
            generated tokens.

        Note:
            The model's train/eval mode is not changed here; call
            ``model.eval()`` first so dropout is disabled.
        """
        for _ in range(max_new_tokens):
            # Condition on at most the last max_seq_len tokens.
            idx_crop = idx[:, -self.max_seq_len:]
            logits, _ = self(idx_crop)
            logits = logits[:, -1, :]  # (B, vocab_size) for the last position
            if temperature > 0:
                probs = F.softmax(logits / temperature, dim=-1)
                idx_next = self.sample_top_p(probs, top_p)
            else:
                # Temperature 0 means deterministic, greedy decoding.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx[0]

    def sample_top_p(self, probs: torch.Tensor, top_p: float) -> torch.Tensor:
        """Nucleus (top-p) sampling of one token per row.

        Tokens are sorted by descending probability. A token is kept while the
        total probability of the tokens ranked before it is <= ``top_p``, so
        the top token is always kept. The kept probabilities are renormalized
        and sampled.

        Args:
            probs: (..., vocab_size) probability rows.
            top_p: cumulative-probability threshold.

        Returns:
            (..., 1) tensor of sampled token ids.
        """
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        # Shift right by one: the mask at position i tests the mass *before* i.
        top_p_mask = cumulative_probs <= top_p
        top_p_mask[..., 1:] = top_p_mask[..., :-1].clone()
        top_p_mask[..., 0] = 1
        filtered_probs = sorted_probs * top_p_mask
        filtered_probs /= filtered_probs.sum(dim=-1, keepdim=True)
        next_token = torch.multinomial(filtered_probs, num_samples=1)
        # Map positions in the sorted order back to vocabulary ids.
        return torch.gather(sorted_indices, -1, next_token)