# Prome-LLM-II / model.py
# Uploaded by Neu256 ("Upload 3 files", revision 35db907, 12.2 kB).
import torch
import torch.nn as nn
from torch.nn import functional as F
from utils import DEVICE
class PromeLayerNorm(nn.Module):
    """
    Layer normalization over the last dimension of the input.

    Each vector along the last axis is shifted to zero mean and scaled to unit
    (population) variance, then an optional affine transform ``x * weight + bias``
    is applied.

    Args:
        epsilon (float): Added to the variance before the inverse square root
            to avoid division by zero.
        normalized_shape (int, optional): Size of the last dimension. When
            given, ``weight`` (init 1) and ``bias`` (init 0) are registered as
            learnable parameters. When omitted, no affine transform is applied
            (gain fixed at 1, bias fixed at 0) and the module has no parameters.
    """

    def __init__(self, epsilon=1e-5, normalized_shape=None):
        super().__init__()
        self.epsilon = epsilon
        if normalized_shape is not None:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
        else:
            # Register as None so the attributes exist but are excluded from
            # parameters() and state_dict().
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor; normalized along its last dimension.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.
        """
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) * torch.rsqrt(s + self.epsilon)
        if self.weight is not None:
            x = x * self.weight + self.bias
        return x
class PromeStand(nn.Module):
    """
    Standardize a tensor with a single mean and standard deviation computed
    over all of its elements (every batch item, position and feature).

    Args:
        epsilon (float): Added to the standard deviation to avoid division by
            zero for constant inputs.
    """

    def __init__(self, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: ``(x - x.mean()) / (x.std() + epsilon)`` with the same
            shape as ``x`` (``x.std()`` is PyTorch's default unbiased estimate).
        """
        # epsilon only guards the division; adding it to the mean as well would
        # shift every output by -epsilon / std.
        mean = x.mean()
        std = x.std() + self.epsilon
        return (x - mean) / std
class PromeEmbedding(nn.Module):
    """
    Prome embedding layer: the sum of two learnable lookup tables.

    Every id is looked up in both ``weight`` and ``context_matrix`` and the two
    rows are added.

    Args:
        vocab_size (int): Number of rows in each lookup table.
        embedding_dim (int): The dimension of the embedding.
        padding_idx (int, optional): If not None, positions whose id equals
            ``padding_idx`` produce an all-zero vector (and therefore receive
            no gradient). The padding id does not need to be a valid row index.

    Returns:
        torch.Tensor: A tensor of shape ``input_ids.shape + (embedding_dim,)``.
    """

    def __init__(self, vocab_size, embedding_dim, padding_idx=None):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.weight = torch.nn.Parameter(torch.randn(vocab_size, embedding_dim))
        self.padding_idx = padding_idx
        self.context_matrix = torch.nn.Parameter(torch.randn(vocab_size, embedding_dim))

    def forward(self, input_ids):
        """
        Calculates the embedding for the given input IDs.

        Args:
            input_ids (torch.Tensor): Integer ids of any shape, e.g.
                (batch_size, sequence_length) or (sequence_length,).

        Returns:
            torch.Tensor: A tensor of shape ``input_ids.shape + (embedding_dim,)``.
        """
        input_ids = input_ids.long()
        pad_mask = None
        if self.padding_idx is not None:
            pad_mask = input_ids == self.padding_idx
            # Remap padding ids to a valid row so the lookup never goes out of
            # range (e.g. padding_idx=-100); those rows are zeroed below.
            input_ids = input_ids.masked_fill(pad_mask, 0)
        # Symbol vector plus a second, independently learned vector per id.
        output = self.weight[input_ids] + self.context_matrix[input_ids]
        if pad_mask is not None:
            output = output.masked_fill(pad_mask.unsqueeze(-1), 0.0)
        return output
class AttentionHead(nn.Module):
    """
    One head of causal (masked) self-attention.

    Args:
        head_size (int): Output width of the key/query/value projections.
        num_embed (int): Width of the input's last dimension.
        block_size (int): Maximum sequence length; sizes the causal mask.
        dropout (float): Dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, num_embed, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(num_embed, head_size, bias=False)
        self.query = nn.Linear(num_embed, head_size, bias=False)
        self.value = nn.Linear(num_embed, head_size, bias=False)
        # tril is a lower triangular matrix. It is not a parameter of the model,
        # so it is attached with register_buffer (moves with .to(), saved in state_dict).
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input of shape (B, T, num_embed) with T <= block_size.

        Returns:
            torch.Tensor: Attention output of shape (B, T, head_size).
        """
        B, T, C = x.shape
        key = self.key(x)      # (B, T, head_size)
        query = self.query(x)  # (B, T, head_size)
        # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        # NOTE(review): scaled by the input width C (num_embed) rather than the
        # usual head_size; kept as-is so trained checkpoints behave the same.
        wei = (query @ key.transpose(-2, -1)) * C ** -0.5
        # Mask future positions with -inf so each position attends only to
        # itself and earlier positions.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, -float("inf"))
        wei = F.silu(F.softmax(wei, dim=-1))
        # Multiplicative rescale by -1 / C**-0.5 == -sqrt(C) (model-specific).
        wei = wei * (-1 / (C ** -0.5))
        # Dropout on the attention weights (previously constructed but unused).
        wei = self.dropout(wei)
        value = self.value(x)  # (B, T, head_size)
        return wei @ value     # (B, T, T) @ (B, T, hs) -> (B, T, hs)
class MultiHeadAttention(nn.Module):
    """
    Multiple heads of self-attention in parallel.

    The head outputs are concatenated, added to the input (residual),
    standardized with PromeStand, projected and passed through dropout.

    Args:
        num_heads (int): Number of attention heads.
        head_size (int): Output width of each head.
        num_embed (int): Width of the input's last dimension.
        block_size (int): Maximum sequence length.
        dropout (float): Dropout probability.

    Raises:
        ValueError: If ``num_heads * head_size != num_embed``. The concatenated
            heads are added to the input and fed to a num_embed-wide projection,
            so any other combination would fail in forward().
    """

    def __init__(self, num_heads, head_size, num_embed, block_size, dropout):
        super().__init__()
        if num_heads * head_size != num_embed:
            raise ValueError(
                f"num_heads * head_size ({num_heads} * {head_size} = "
                f"{num_heads * head_size}) must equal num_embed ({num_embed})"
            )
        self.heads = nn.ModuleList(
            [
                AttentionHead(
                    head_size=head_size,
                    num_embed=num_embed,
                    block_size=block_size,
                    dropout=dropout,
                )
                for _ in range(num_heads)
            ]
        )
        self.proj = nn.Linear(num_embed, num_embed)
        self.dropout = nn.Dropout(dropout)
        self.norm = PromeStand()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input of shape (B, T, num_embed).

        Returns:
            torch.Tensor: Output of shape (B, T, num_embed).
        """
        # output of the self-attention, heads concatenated along the last dim
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        # residual connection followed by standardization
        out = self.norm(out + x)
        # apply the linear projection layer
        return self.dropout(self.proj(out))
class MLP(nn.Module):
    """
    Three-layer feed-forward network with SiLU activations.

    Shape flow: num_embed -> hidden_dim -> hidden_dim -> num_embed. Dropout is
    applied to the output of the second linear layer, before its activation.

    Args:
        num_embed (int): Input and output width.
        hidden_dim (int): Width of the two hidden layers.
        dropout (float): Dropout probability.
    """

    def __init__(self, num_embed, hidden_dim, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(num_embed, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_embed)

    def forward(self, x):
        first = F.silu(self.fc1(x))
        second = F.silu(self.dropout(self.fc2(first)))
        return self.fc3(second)
class TransformerBlock(nn.Module):
    """
    Groups multi-head attention and a feed-forward MLP into one repeatable
    pre-norm block: each sub-layer sees a standardized input and its output
    (after dropout) is added back to the residual stream.

    Args:
        num_heads (int): Number of attention heads.
        block_size (int): Maximum sequence length.
        num_embed (int): Embedding width.
        hidden_dim (int): Hidden width of the MLP.
        dropout (float): Dropout probability.
    """

    def __init__(self, num_heads, block_size, num_embed, hidden_dim, dropout):
        super().__init__()
        head_size = num_embed // num_heads
        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            head_size=head_size,
            num_embed=num_embed,
            block_size=block_size,
            dropout=dropout,
        )
        self.mlp = MLP(num_embed=num_embed, hidden_dim=hidden_dim, dropout=dropout)
        # PromeStand's only argument is epsilon; passing num_embed here would
        # silently set epsilon to the embedding width, so use the default.
        self.ln = PromeStand()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): A tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns:
            torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim).
        """
        # Attention sub-layer with residual connection.
        x = x + self.dropout(self.mha(self.ln(x)))
        # Feed-forward sub-layer with residual connection. Dropout applies to
        # the branch only, not to the residual stream.
        x = x + self.dropout(self.mlp(self.ln(x)))
        return x
class TransformerDecoder(nn.Module):
    """
    Stack of Transformer blocks with an additional learned positional embedding.

    Args:
        num_heads (int): The number of attention heads.
        block_size (int): Maximum sequence length (size of the positional table).
        num_embed (int): The dimension of the embedding.
        hidden_dim (int): Hidden width of each block's MLP.
        num_layers (int): The number of decoder blocks.
        dropout (float): The dropout rate.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim),
        softmax-normalized along the embedding dimension.
    """

    def __init__(self, num_heads, block_size, num_embed, hidden_dim, num_layers, dropout):
        super().__init__()
        # Learned positional embedding, one row per position 0..block_size-1.
        self.pemb = PromeEmbedding(block_size, num_embed)
        # Create a sequential block of Transformer blocks.
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(
                    num_heads=num_heads,
                    block_size=block_size,
                    num_embed=num_embed,
                    hidden_dim=hidden_dim,
                    dropout=dropout,
                )
                for _ in range(num_layers)
            ]
        )
        # Softmax over the last (embedding) dimension.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        Decodes the input sequence.

        Args:
            x (torch.Tensor): Embedded input of shape
                (batch_size, sequence_length, embedding_dim), with
                sequence_length <= block_size.

        Returns:
            torch.Tensor: A tensor of shape (batch_size, sequence_length, embedding_dim).
        """
        # Build positions on the input's device to avoid a host->device copy.
        positions = torch.arange(x.size(1), device=x.device)
        # (T, C) positional embedding broadcasts over the batch dimension.
        x = x + self.pemb(positions)
        x = self.blocks(x)
        # Apply a softmax to the output of the last Transformer block.
        return self.softmax(x)
class Transformer(nn.Module):
    """
    Decoder-only language model built from the Prome components.

    Keyword Args:
        vocab_size (int): Vocabulary size (default 100).
        num_embed (int): Embedding width (default 32).
        block_size (int): Maximum context length (default 8).
        num_heads (int): Attention heads per block (default 4).
        num_layers (int): Number of Transformer blocks (default 4).
        hidden_dim (int): MLP hidden width (default 768).
        dropout (float): Dropout probability (default 0.2).
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.vocab_size = kwargs.get("vocab_size", 100)
        self.num_embed = kwargs.get("num_embed", 32)
        self.block_size = kwargs.get("block_size", 8)
        self.num_heads = kwargs.get("num_heads", 4)
        self.num_layers = kwargs.get("num_layers", 4)
        self.hidden_dim = kwargs.get("hidden_dim", 768)
        dropout = kwargs.get("dropout", 0.2)
        # token embeddings: one learned vector per vocabulary id
        self.token_embedding_table = PromeEmbedding(self.vocab_size, self.num_embed)
        # each position from 0 to block_size-1 will get its embedding
        self.position_embedding_table = PromeEmbedding(self.block_size, self.num_embed)
        self.decoder = TransformerDecoder(
            self.num_heads,
            self.block_size,
            self.num_embed,
            self.hidden_dim,
            self.num_layers,
            dropout,
        )
        self.dropout = nn.Dropout(dropout)
        # Final layer norm before the output projection. PromeLayerNorm's first
        # positional argument is epsilon, so passing num_embed would set
        # epsilon to the model width; use the default epsilon instead.
        self.ln_f = PromeLayerNorm()
        self.lm_head = nn.Linear(self.num_embed, self.vocab_size)

    def forward(self, idx, targets=None):
        """
        Args:
            idx (torch.Tensor): Token ids of shape (B, T), T <= block_size.
            targets (torch.Tensor, optional): Target ids of shape (B, T).

        Returns:
            tuple: ``(logits, loss)``. Without targets, logits are
            (B, T, vocab_size) and loss is None. With targets, logits are
            flattened to (B * T, vocab_size) and loss is the cross-entropy.
        """
        B, T = idx.shape
        # (B, T, C), C = num_embed
        token_emb = self.token_embedding_table(idx)
        # (T, C); positions are created where the position table lives.
        positions = torch.arange(T, device=self.position_embedding_table.weight.device)
        posit_emb = self.position_embedding_table(positions)
        x = self.dropout(token_emb + posit_emb)
        x = self.decoder(x)
        x = self.ln_f(x)
        # (B, T, vocab_size)
        logits = self.lm_head(x)
        if targets is None:
            return logits, None
        # cross_entropy expects (N, num_classes) inputs and (N,) targets.
        B, T, C = logits.shape
        logits = logits.reshape(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx: torch.Tensor, max_new_tokens: int, block_size: "int | None" = None):
        """
        Autoregressively sample ``max_new_tokens`` tokens.

        Args:
            idx (torch.Tensor): (B, T) ids of the current context.
            max_new_tokens (int): Number of tokens to append.
            block_size (int, optional): Context window to feed the model;
                defaults to, and is capped at, the model's ``block_size``
                (the size of its positional table).

        Returns:
            torch.Tensor: (B, T + max_new_tokens) ids.
        """
        if block_size is None:
            block_size = self.block_size
        block_size = min(block_size, self.block_size)
        # Sampling needs no gradients; avoid building autograd graphs.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # crop the context to the last block_size tokens
                idx_crop = idx[:, -block_size:]
                logits, _ = self(idx_crop)
                # focus only on the last time step -> (B, vocab_size)
                probs = F.softmax(logits[:, -1, :], dim=-1)
                # sample from the distribution -> (B, 1)
                idx_next = torch.multinomial(probs, num_samples=1)
                # append sampled index to the running sequence -> (B, T+1)
                idx = torch.cat((idx, idx_next), dim=1)
        return idx