# SmolLM-2-135M / model.py
# Last commit: 87347e3 by tranquilkd — "added requirements.txt"
import torch
import torch.nn.functional as F
from typing import Optional
from torch import nn
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias).

    Computes ``x * rsqrt(mean(x ** 2, dim=-1) + eps) * weight``.

    Args:
        dim (int): Size of the last (normalized) dimension; length of ``weight``.
        eps (float): Constant added to the mean square for numerical stability.
    """

    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        # Learnable per-feature gain, initialised to 1.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute the statistic in at least float32: squaring fp16 activations
        # overflows once |x| exceeds ~256 (fp16 max is 65504), which turns the
        # rsqrt into 0 and silently zeroes the whole row. float32/float64
        # inputs are computed exactly as before.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_f = x.to(compute_dtype)
        rms = torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_f * rms).type_as(x) * self.weight
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with precomputed cos/sin tables.

    Feature ``i`` of the rotated span is paired with feature ``i + half``
    (the ``rotate_half`` convention). Both members of a pair are rotated by
    the same angle ``position * freqs[i]``, so the transform is a true
    rotation: norms are preserved and query/key dot products depend only on
    relative position.

    NOTE: a previous revision rotated only the first ``dim // 2`` features and
    gave the two members of each pair different angles. Weights trained with
    that revision will not reproduce their outputs with this one.

    Args:
        dim (int): Feature width to rotate (the per-head dimension).
        max_seq_len (int): Number of positions precomputed in ``cos``/``sin``;
            positions beyond the table cannot be indexed.
        theta (float): RoPE frequency base.
    """

    def __init__(self, dim, max_seq_len=2048, theta=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        # One inverse frequency per rotated pair: shape (ceil(dim / 2),).
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("freqs", freqs)
        t = torch.arange(max_seq_len, dtype=self.freqs.dtype)
        angles = torch.outer(t, self.freqs)  # (max_seq_len, ceil(dim / 2))
        # Buffer names and shapes are unchanged so existing state_dicts load.
        self.register_buffer('cos', angles.cos())
        self.register_buffer('sin', angles.sin())

    def rotate_half(self, x):
        """Map last-dim halves (x1, x2) to (-x2, x1)."""
        rot_dim = x.shape[-1]
        x1 = x[..., :rot_dim // 2]
        x2 = x[..., rot_dim // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_emb(self, t, x):
        """Rotate the leading features of ``x`` by the angles of positions ``t``.

        Args:
            t: 1-D integer tensor of positions, one per entry of ``x``'s
                second-to-last axis (required for the broadcast below).
            x: Tensor whose last axis holds features.

        Returns:
            Tensor shaped like ``x``. The first ``2 * half`` features are
            rotated; any features beyond that pass through unchanged.
        """
        # Number of rotated pairs, bounded by the table and by x's own width
        # (an odd width leaves the last feature unrotated).
        half = min(self.freqs.shape[-1], x.shape[-1] // 2)
        rot_dim = 2 * half
        cos = self.cos[t, :half]
        sin = self.sin[t, :half]
        # Duplicate so that features i and i + half (a rotate_half pair)
        # share one angle.
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
        x_rot = x[..., :rot_dim]
        rotated_x = x_rot * cos + self.rotate_half(x_rot) * sin
        if x.shape[-1] > rot_dim:
            rotated_x = torch.cat((rotated_x, x[..., rot_dim:]), dim=-1)
        return rotated_x

    def forward(self, x, seq_dim=-2):
        """Apply RoPE using positions ``0 .. x.shape[seq_dim] - 1``."""
        seq_len = x.shape[seq_dim]
        t = torch.arange(seq_len, device=x.device)
        return self.apply_rotary_emb(t, x)
class Attention(nn.Module):
    """Causal self-attention with grouped-query KV heads and rotary embeddings.

    Queries use ``n_heads`` heads and keys/values use ``n_kv_heads`` heads,
    all ``dim // n_heads`` wide. ``F.scaled_dot_product_attention`` with
    ``enable_gqa=True`` shares each KV head across ``n_heads // n_kv_heads``
    query heads.

    Args:
        args: Config object; reads ``dim``, ``n_heads``, ``n_kv_heads`` and
            ``dropout``.

    Raises:
        ValueError: If ``dim`` is not divisible by ``n_heads``, or ``n_heads``
            is not divisible by ``n_kv_heads`` (required for GQA grouping).
    """

    def __init__(self, args):
        super().__init__()
        if args.dim % args.n_heads != 0:
            raise ValueError("args.dim must be divisible by args.n_heads")
        if args.n_heads % args.n_kv_heads != 0:
            raise ValueError("args.n_heads must be divisible by args.n_kv_heads")
        self.dim = args.dim
        self.num_heads = args.n_heads
        self.kv_heads = args.n_kv_heads
        self.head_dim = args.dim // args.n_heads
        # Kept as a public attribute for compatibility. It is NOT the width of
        # a KV head (each KV head is head_dim wide), and this class never reads it.
        self.kv_head_dim = args.dim // args.n_kv_heads
        self.query_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.key_proj = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.value_proj = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.rope = RotaryEmbedding(self.head_dim)
        self.out_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.dropout = nn.Dropout(args.dropout)
        # Non-persistent: once use_cache=True fills these, they must not leak
        # into state_dict(). A freshly built model holds None here and would
        # reject them as unexpected keys on a strict load.
        self.register_buffer('cached_keys', None, persistent=False)
        self.register_buffer('cached_values', None, persistent=False)

    def _split_heads(self, t, n_heads):
        """Reshape (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)."""
        batch_size, seq_len, _ = t.shape
        return t.view(batch_size, seq_len, n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None, use_cache=False):
        """Attend over ``x`` causally.

        Args:
            x: Tensor of shape (batch, seq_len, dim).
            mask: Accepted for interface compatibility but NOT used. Causal
                masking comes from ``is_causal=True``, so any padding mask
                passed here is ignored.
            use_cache: If True, store this call's post-RoPE keys/values in
                ``cached_keys``/``cached_values``; otherwise reset them to None.
                Nothing in this class reads them back.

        Returns:
            Tensor of shape (batch, seq_len, dim).
        """
        batch_size, seq_len, _ = x.size()
        query = self._split_heads(self.query_proj(x), self.num_heads)  # (B, H, T, hd)
        key = self._split_heads(self.key_proj(x), self.kv_heads)       # (B, Hkv, T, hd)
        value = self._split_heads(self.value_proj(x), self.kv_heads)   # (B, Hkv, T, hd)
        query = self.rope(query)
        key = self.rope(key)
        # SDPA applies dropout whenever dropout_p > 0, regardless of module
        # mode, so it must be disabled explicitly outside training.
        dropout_p = self.dropout.p if self.training else 0.0
        output = F.scaled_dot_product_attention(
            query, key, value, is_causal=True, dropout_p=dropout_p, enable_gqa=True
        )
        if use_cache:
            self.cached_keys = key
            self.cached_values = value
        else:
            # Drop any stale cache so it is not carried across calls.
            self.cached_keys = None
            self.cached_values = None
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.out_proj(output)
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        args: Config object; reads ``args.dim`` (model width) and
            ``args.intermediate_dim`` (hidden width of the gated projection).

    Attributes:
        w1 (nn.Linear): Gate projection, dim -> intermediate_dim, no bias.
        w2 (nn.Linear): Output projection, intermediate_dim -> dim, no bias.
        w3 (nn.Linear): Value projection, dim -> intermediate_dim, no bias.
    """

    def __init__(self, args):
        super().__init__()
        model_dim, hidden_dim = args.dim, args.intermediate_dim
        # Creation order (w1, w2, w3) fixes RNG consumption and state_dict order.
        self.w1 = nn.Linear(model_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, model_dim, bias=False)
        self.w3 = nn.Linear(model_dim, hidden_dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class TransformerBlock(nn.Module):
    """One pre-norm decoder layer.

    Computes::

        h   = x + attention(attention_norm(x))
        out = h + feed_forward(ffn_norm(h))

    Args:
        layer_id (int): Index of this layer in the stack (stored only).
        args (ModelArgs): Config; reads ``n_heads``, ``dim`` and ``norm_eps``
            here and is passed on to ``Attention`` and ``FeedForward``.

    Attributes:
        n_heads (int): Number of attention heads.
        dim (int): Model width.
        head_dim (int): ``dim // n_heads``.
        attention (Attention): Self-attention sub-layer.
        feed_forward (FeedForward): Feed-forward sub-layer.
        layer_id (int): Index of this layer.
        attention_norm (RMSNorm): Normalizes the input of the attention sub-layer.
        ffn_norm (RMSNorm): Normalizes the input of the feed-forward sub-layer.
    """

    def __init__(self, layer_id: int, args):
        super().__init__()
        self.layer_id = layer_id
        self.n_heads, self.dim = args.n_heads, args.dim
        self.head_dim = self.dim // self.n_heads
        # Sub-modules are created in this order; keep it so RNG-dependent
        # initialization and state_dict key order stay stable.
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        use_cache: bool
    ):
        """Run attention then feed-forward, each with a residual connection.

        Args:
            x (torch.Tensor): Layer input.
            mask (torch.Tensor, optional): Forwarded unchanged to the attention sub-layer.
            use_cache (bool): Forwarded unchanged to the attention sub-layer.

        Returns:
            torch.Tensor: Layer output, same shape as ``x``.
        """
        attn_out = self.attention(self.attention_norm(x), mask=mask, use_cache=use_cache)
        residual = x + attn_out
        return residual + self.feed_forward(self.ffn_norm(residual))
class Transformer(nn.Module):
    """Decoder-only language model.

    Token embedding -> stack of ``TransformerBlock`` -> final ``RMSNorm`` ->
    projection onto the vocabulary using the embedding matrix (tied weights).

    Args:
        args (ModelArgs): Configuration. Reads ``vocab_size``, ``dim``,
            ``n_layers``, ``norm_eps`` and ``init_scale`` here, and is
            forwarded to every block.

    Attributes:
        args (ModelArgs): Configuration parameters.
        vocab_size (int): Vocabulary size.
        n_layers (int): Number of blocks.
        tok_embeddings (nn.Embedding): Token embeddings, also the output projection.
        layers (torch.nn.ModuleList): The transformer blocks.
        norm (RMSNorm): Normalization applied before the output projection.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(TransformerBlock(layer_id, args))
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        # No separate output layer: forward() projects with
        # tok_embeddings.weight (weight tying).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear and Embedding weights from N(0, init_scale**2); leave other modules as-is."""
        std = self.args.init_scale
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=std)

    def forward(self, tokens: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        """Compute next-token logits for every position.

        Args:
            tokens (torch.Tensor): Token ids of shape (batch, seq_len).
            mask (torch.Tensor, optional): Forwarded to every block. When None,
                a (1, 1, seq_len, seq_len) mask with -1e4 above the diagonal
                is built. NOTE(review): the Attention module in this file masks
                causally via ``is_causal=True`` and does not consume this mask.
            use_cache (bool): Forwarded to every block.

        Returns:
            torch.Tensor: Logits of shape (batch, seq_len, vocab_size).
        """
        _, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        if mask is None:
            mask = torch.triu(torch.ones((seqlen, seqlen),
                                         dtype=torch.bool,
                                         device=tokens.device),
                              diagonal=1)
            mask = mask.unsqueeze(0).unsqueeze(0)
            mask = mask * -1e4
        for layer in self.layers:
            h = layer(h, mask, use_cache)
        h = self.norm(h)
        # Tied output projection: logits = h @ tok_embeddings.weight.T
        output = F.linear(h, self.tok_embeddings.weight)
        return output

    def generate(self,
                 input_ids,
                 max_length,
                 min_length=None,
                 num_return_sequences=1,
                 pad_token_id=None,
                 do_sample=True,
                 temperature=0.8,
                 top_k=50,
                 top_p=0.95
                 ):
        """Autoregressively extend ``input_ids`` up to ``max_length`` tokens.

        The whole sequence is fed through the model at every step, so each
        step costs a full forward pass.

        Args:
            input_ids: Prompt token ids of shape (batch, prompt_len).
            max_length: Total length (prompt included) to generate up to.
            min_length: While the sequence is shorter than this, ``pad_token_id``
                cannot be produced. Defaults to the prompt length (no constraint).
            num_return_sequences: Independent continuations per prompt row.
                Output rows are grouped per prompt: rows ``i * n`` through
                ``i * n + n - 1`` continue prompt row ``i``.
            pad_token_id: Stop token. Once a row emits it, the rest of that row
                is filled with it. Generation ends early when every row has emitted it.
            do_sample: Sample from the filtered distribution if True, else argmax.
            temperature: Logit divisor applied when sampling.
            top_k: Keep only the ``top_k`` largest logits (0 disables; clamped
                to the vocabulary size).
            top_p: Nucleus-sampling threshold; 1.0 disables.

        Returns:
            Tensor of shape (batch * num_return_sequences, L). L is
            ``max_length`` unless every row stopped early or the prompt was
            already at least that long.

        Raises:
            ValueError: If ``num_return_sequences < 1``, or if ``do_sample``
                is set and ``temperature <= 0``.

        The module's train/eval mode is restored before returning.
        """
        if num_return_sequences < 1:
            raise ValueError("num_return_sequences must be >= 1")
        if do_sample and temperature <= 0:
            raise ValueError("temperature must be > 0 when do_sample=True")
        min_length = min_length if min_length is not None else input_ids.shape[1]

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Generate all requested sequences as one batch, each starting
                # from its own copy of the prompt.
                input_ids = input_ids.repeat_interleave(num_return_sequences, dim=0)
                finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
                for _ in range(max_length - input_ids.shape[1]):
                    logits = self(input_ids, use_cache=True)[:, -1, :]
                    # Enforce min_length by forbidding the stop token.
                    if pad_token_id is not None and input_ids.shape[1] < min_length:
                        logits[:, pad_token_id] = float('-inf')
                    if do_sample:
                        logits = logits / temperature
                    logits = self._filter_logits(logits, top_k, top_p)
                    if do_sample:
                        probs = torch.softmax(logits, dim=-1)
                        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                    else:
                        next_tokens = torch.argmax(logits, dim=-1)
                    if pad_token_id is not None:
                        # Rows that already stopped keep emitting the pad token.
                        next_tokens = next_tokens.masked_fill(finished, pad_token_id)
                        finished |= next_tokens == pad_token_id
                    input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
                    if pad_token_id is not None and finished.all():
                        break
        finally:
            self.train(was_training)
        return input_ids

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Return ``logits`` with entries outside the top-k / nucleus set set to -inf."""
        if top_k > 0:
            k = min(top_k, logits.size(-1))  # torch.topk rejects k > vocab size
            kth_largest = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_largest, float('-inf'))
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept;
            # the most likely token is always kept.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float('-inf'))
        return logits