import pandas as pd
import numpy as np
import os
import requests
import torch
from torch import nn
from torch.nn import functional as F
import sentencepiece as spm
import random
from collections import OrderedDict
from matplotlib import pyplot as plt
import time
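# Pick the best available device: CUDA GPU, Apple MPS, or CPU.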
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
VOCAB_SIZE = 130
BATCH_SIZE = 32
CONTEXT_WINDOW = 16
EPOCHS = 1000
DIM = 128
LOG_INTERVAL = 10
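EVAL_ITERS = 10  # batches averaged per evaluation pass (a small value chosen for speed)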
HEADS = 8
LAYERS = 4
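# Download the TinyShakespeare corpus and keep its non-empty lines for tokenizer training.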
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = requests.get(url)
if response.status_code == 200:
tinyshakespeare = response.text
else:
    raise RuntimeError(f"Failed to download TinyShakespeare (HTTP {response.status_code})")
tinyshakespeare_list = tinyshakespeare.split("\n")
tinyshakespeare_list = [i for i in tinyshakespeare_list if i != ""]
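# Train a small BPE tokenizer on the corpus with SentencePiece.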
spm.SentencePieceTrainer.Train(
sentence_iterator = iter(tinyshakespeare_list),
model_prefix = "tinyshakespeare_model",
vocab_size = VOCAB_SIZE,
character_coverage = 1.0,
model_type = "bpe",
pad_id = 0,
unk_id = 1,
bos_id = 2,
eos_id = 3,
)
sp = spm.SentencePieceProcessor(model_file = "tinyshakespeare_model.model")
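# Encode the whole corpus as one long sequence of BPE token ids.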
dataset_tensor = torch.tensor(sp.Encode(tinyshakespeare))
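# 70/15/15 train/val/test split; each batch is a random set of context_window-long sequences
# with targets shifted one token to the right.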
def get_batch_train(dataset, batch_size, context_window):
train_data = dataset[:int(.7 * len(dataset))]
ix = torch.randint(0, train_data.size(0) - context_window - 1, (batch_size,))
x = torch.stack([train_data[i:i+context_window] for i in ix]).long()
y = torch.stack([train_data[i+1:i+context_window+1] for i in ix]).long()
    return x.to(device), y.to(device)
def get_batch_val(dataset, batch_size, context_window):
val_data = dataset[int(.7 * len(dataset)): int(.85 * len(dataset))]
ix = torch.randint(0, val_data.size(0) - context_window - 1, (batch_size,))
x = torch.stack([val_data[i:i+context_window] for i in ix]).long()
y = torch.stack([val_data[i+1:i+context_window+1] for i in ix]).long()
    return x.to(device), y.to(device)
def get_batch_test(dataset, batch_size, context_window):
test_data = dataset[int(.85 * len(dataset)): len(dataset)]
ix = torch.randint(0, test_data.size(0) - context_window - 1, (batch_size,))
x = torch.stack([test_data[i:i+context_window] for i in ix]).long()
y = torch.stack([test_data[i+1:i+context_window+1] for i in ix]).long()
    return x.to(device), y.to(device)
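# Evaluation helpers: average loss, next-token accuracy, and perplexity on random batches.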
@torch.no_grad()
def calculate_loss(model):
model.eval()
train_losses = []
val_losses = []
    for _ in range(EVAL_ITERS):
#train evaluation
x_train, y_train = get_batch_train(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
_, train_loss = model(x_train, y_train)
train_losses.append(train_loss.item())
#val evaluation
x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
_, val_loss = model(x_val, y_val)
val_losses.append(val_loss.item())
losses_dict = {"train": np.mean(train_losses), "val": np.mean(val_losses)}
return losses_dict
@torch.no_grad()
def calculate_accuracy(model):
model.eval()
correct_predictions = 0
total_predictions = 0
    for _ in range(EVAL_ITERS):
# Get a batch of validation data
x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
# Get model predictions
logits = model(x_val)
# Convert predictions to class labels
predicted_labels = torch.argmax(logits, dim=-1)
# Compare with true labels
correct_predictions += (predicted_labels == y_val).sum().item()
total_predictions += y_val.numel()
accuracy = correct_predictions / total_predictions
return accuracy
@torch.no_grad()
def calculate_perplexity(model):
model.eval()
val_losses = []
    for _ in range(EVAL_ITERS):
# Get a batch of validation data
x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
# Get model predictions and loss
_, val_loss = model(x_val, y_val)
val_losses.append(val_loss.item())
# Calculate the mean validation loss
mean_val_loss = np.mean(val_losses)
# Perplexity is the exponential of the cross-entropy loss
perplexity = np.exp(mean_val_loss)
return perplexity
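# Training loop: one random batch per epoch, with periodic evaluation and checkpointing.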
def train(model, optimizer, checkpoint_path="checkpoints"):
    os.makedirs(checkpoint_path, exist_ok=True)  # make sure the checkpoint directory exists
    losses = []
accs = []
perps = []
    for epoch in range(EPOCHS):
        model.train()  # the evaluation helpers switch the model to eval mode
        optimizer.zero_grad()
x_train, y_train = get_batch_train(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
logits, loss = model(x_train, y_train)
loss.backward()
optimizer.step()
if epoch % LOG_INTERVAL == 0:
current_loss = calculate_loss(model)
current_accuracy = calculate_accuracy(model)
current_perplexity = calculate_perplexity(model)
losses.append(current_loss)
accs.append(current_accuracy)
perps.append(current_perplexity)
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': current_loss,
'accuracy': current_accuracy,
'perplexity': current_perplexity
}, f"{checkpoint_path}/checkpoint_epoch_{epoch}.pth")
print(f"Epoch {epoch}: Loss - {current_loss['val']}, Accuracy - {current_accuracy}, Perplexity - {current_perplexity}")
print("validation Loss: ", losses[-1]['val'])
print("validation Accuracy: ", accs[-1])
print("validation Perplexity: ", perps[-1])
return pd.DataFrame(losses).plot()
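# RMSNorm: scale each example by the inverse of its root-mean-square activation,
# then apply a learned per-position gain.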
class RMSNorm(torch.nn.Module):
    def __init__(self, layer_shape, eps=1e-8, bias=False):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.register_parameter("scale", torch.nn.Parameter(torch.ones(layer_shape)))
    def forward(self, x):
        # RMS over each example's (seq, dim) block, then rescale x by it and apply the learned gain
        ff_rms = torch.linalg.norm(x, dim=(1, 2)) * x[0].numel() ** -.5
        raw = x / (ff_rms.unsqueeze(-1).unsqueeze(-1) + self.eps)
        return self.scale[:x.shape[1], :].unsqueeze(0) * raw
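# Precompute RoPE rotation matrices: one block-diagonal matrix of 2x2 rotations per position.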
def get_rotary_matrix(context_window, embedding_dim):
R = torch.zeros((context_window, embedding_dim, embedding_dim), requires_grad=False)
for position in range(context_window):
for i in range(embedding_dim//2):
theta = 10000. ** (-2.*(i - 1) / embedding_dim)
m_theta = position * theta
R[position, 2*i,2*i] = np.cos(m_theta)
R[position, 2*i,2*i+1] = - np.sin(m_theta)
R[position, 2*i+1,2*i] = np.sin(m_theta)
R[position, 2*i+1,2*i+1] = np.cos(m_theta)
return R
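# A single self-attention head that rotates queries and keys with the RoPE matrices.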
class RoPEAttentionHead(nn.Module):
def __init__(self):
super().__init__()
self.w_q = nn.Linear(DIM, DIM, bias=False)
self.w_k = nn.Linear(DIM, DIM, bias=False)
self.w_v = nn.Linear(DIM, DIM, bias=False)
        self.register_buffer("R", get_rotary_matrix(CONTEXT_WINDOW, DIM))  # buffer so it moves with .to(device)
def forward(self, x, return_attn_weights=False):
b,m,d = x.shape
q = self.w_q(x)
k = self.w_k(x)
v = self.w_v(x)
q_rotated = (torch.bmm(q.transpose(0,1), self.R[:m])).transpose(0,1)
k_rotated = (torch.bmm(k.transpose(0,1), self.R[:m])).transpose(0,1)
activations = F.scaled_dot_product_attention(
q_rotated,k_rotated,v,dropout_p =.1, is_causal=True
)
if return_attn_weights:
            attn_weights = torch.bmm(q_rotated, k_rotated.transpose(1, 2)) / np.sqrt(d)
            causal_mask = torch.tril(torch.ones((m, m), device=x.device)).bool()
            attn_weights = attn_weights.masked_fill(~causal_mask, float("-inf"))
            attn_weights = F.softmax(attn_weights, dim=-1)
return activations, attn_weights
return activations
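# Multi-head attention: run several RoPE heads in parallel, concatenate, and project back to DIM.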
class RoPEMultiheadAttention(nn.Module):
def __init__(self):
super().__init__()
self.heads = nn.ModuleList([
RoPEAttentionHead() for _ in range(HEADS)
])
self.linear = nn.Linear(HEADS * DIM, DIM)
self.dropout = nn.Dropout(.1)
def forward(self, x):
heads = [h(x) for h in self.heads]
x = torch.cat(heads, dim=-1)
x = self.linear(x)
x = self.dropout(x)
return x
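# SwiGLU activation: a Swish-gated linear unit, as used in Llama's feed-forward layers.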
class SwiGLU(nn.Module):
def __init__(self, size):
super().__init__()
self.linear_gate = nn.Linear(size, size)
self.linear = nn.Linear(size, size)
        self.beta = nn.Parameter(torch.ones(1))  # learnable Swish beta
def forward(self, x):
swish_gate = self.linear_gate(x) * torch.sigmoid(self.beta * self.linear_gate(x))
out = swish_gate * self.linear(x)
return out
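# One transformer block: RMSNorm, multi-head RoPE attention with a residual connection,
# RMSNorm again, then the SwiGLU feed-forward with a residual connection.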
class LlamaBlock(nn.Module):
def __init__(self):
super().__init__()
self.rms = RMSNorm((CONTEXT_WINDOW, DIM))
self.attention = RoPEMultiheadAttention()
self.feedforward = nn.Sequential(
nn.Linear(DIM, DIM),
SwiGLU(DIM),
)
def forward(self, x):
x = self.rms(x) #RMS NORMALIZATION
x = x + self.attention(x) #Self attention
x = self.rms(x) #RMS NORMALIZATION
        x = x + self.feedforward(x)  # Feed-forward: SwiGLU
return x
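# The full model: token embeddings, a stack of LlamaBlocks, and a SwiGLU projection head
# over the vocabulary.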
class Llama(nn.Module):
def __init__(self):
super().__init__()
self.embeddings = nn.Embedding(VOCAB_SIZE, DIM)
self.llama_blocks = nn.Sequential(
OrderedDict([(f"llama_{i}", LlamaBlock()) for i in range(LAYERS)])
)
self.ffn = nn.Sequential(
nn.Linear(DIM, DIM),
SwiGLU(DIM),
nn.Linear(DIM, VOCAB_SIZE),
)
print("model params:", sum([m.numel() for m in self.parameters()]))
def forward(self, idx, targets=None):
x = self.embeddings(idx)
x = self.llama_blocks(x)
logits = self.ffn(x)
if targets is None:
return logits
else:
loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), targets.view(-1))
return logits, loss
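# Build the model and train it with Adam at its default learning rate.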
llama = Llama().to(device)
optimizer = torch.optim.Adam(llama.parameters())
train(llama, optimizer)
plt.show()  # display the loss curve returned by train()