Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn.functional as F | |
from typing import Optional | |
from torch import nn | |
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization.

    Computes ``x * rsqrt(mean(x ** 2, dim=-1) + eps) * weight`` over the last
    dimension.

    Args:
        dim (int): Size of the last dimension; length of the learnable ``weight``.
        eps (float): Added inside the square root for numerical stability.
            Defaults to 1e-6.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` to unit root-mean-square along the last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Compute the statistic in float32: in float16, x.pow(2) overflows to inf
        # once |x| >= 256, which drives rsqrt to 0 and zeroes the whole row.
        # For float32 inputs the upcast/downcast are no-ops.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) in the "rotate-half" layout.

    For every vector, feature ``i`` and feature ``i + dim // 2`` form a pair that
    is rotated by the angle ``pos * theta ** (-2i / dim)``. Because each pair is
    rotated by a single angle, vector norms are preserved and the dot product of
    a rotated query at position m with a rotated key at position n depends only
    on ``m - n``.

    The buffers ``freqs`` (ceil(dim/2),), ``cos`` and ``sin``
    (max_seq_len, ceil(dim/2)) keep the same names and shapes as before, so
    existing state dicts still load.

    NOTE(review): the previous implementation rotated only the first half of
    the features and gave the two members of each pair *different* angles,
    which is not a rotation. Models trained with that mapping see different
    position encodings under this fix and should be re-evaluated/re-trained.

    Args:
        dim (int): Feature size to rotate (the attention head dim).
        max_seq_len (int): Number of precomputed positions.
        theta (float): RoPE base frequency.
    """

    def __init__(self, dim, max_seq_len=2048, theta=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("freqs", freqs)
        positions = torch.arange(max_seq_len, dtype=self.freqs.dtype)
        angles = torch.outer(positions, self.freqs)  # (max_seq_len, ceil(dim/2))
        self.register_buffer('cos', angles.cos())
        self.register_buffer('sin', angles.sin())

    def rotate_half(self, x):
        """Map the halves ``(x1, x2)`` of the last dim to ``(-x2, x1)``."""
        rot_dim = x.shape[-1]
        x1 = x[..., :rot_dim // 2]
        x2 = x[..., rot_dim // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_emb(self, t, x):
        """Rotate the first ``2 * (dim // 2)`` features of ``x`` for positions ``t``.

        Args:
            t (torch.Tensor): Integer positions, one per entry of ``x``'s
                second-to-last dimension.
            x (torch.Tensor): Tensor of shape (..., len(t), features). Features
                beyond ``2 * (dim // 2)`` are passed through unchanged.
        """
        half = self.dim // 2
        rot_dim = 2 * half
        # Each frequency is used for both members of its pair: (i, i + half).
        cos = self.cos[t, :half].to(x.dtype)
        sin = self.sin[t, :half].to(x.dtype)
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
        x_rot = x[..., :rot_dim]
        rotated_x = x_rot * cos + self.rotate_half(x_rot) * sin
        if x.shape[-1] > rot_dim:
            rotated_x = torch.cat((rotated_x, x[..., rot_dim:]), dim=-1)
        return rotated_x

    def forward(self, x, seq_dim=-2):
        """Apply RoPE to ``x`` using positions ``0 .. x.shape[seq_dim] - 1``.

        Raises:
            ValueError: If the sequence is longer than ``max_seq_len``.
        """
        seq_len = x.shape[seq_dim]
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds RotaryEmbedding.max_seq_len={self.max_seq_len}"
            )
        t = torch.arange(seq_len, device=x.device)
        return self.apply_rotary_emb(t, x)
class Attention(nn.Module):
    """Causal multi-head self-attention with grouped-query attention and RoPE.

    Reads ``args.dim``, ``args.n_heads``, ``args.n_kv_heads`` and
    ``args.dropout``. Keys/values use ``n_kv_heads`` heads; each is shared by
    ``n_heads // n_kv_heads`` query heads (``enable_gqa=True`` in SDPA).
    """

    def __init__(self, args):
        super().__init__()
        self.dim = args.dim
        self.num_heads = args.n_heads
        self.kv_heads = args.n_kv_heads
        self.head_dim = args.dim // args.n_heads
        # Kept for backward compatibility; not used by the computation below.
        self.kv_head_dim = args.dim // args.n_kv_heads
        assert self.head_dim * args.n_heads == args.dim, "args.dim must be divisible by args.n_heads"
        # GQA groups query heads evenly over K/V heads; this also implies
        # args.dim is divisible by args.n_kv_heads.
        assert args.n_heads % args.n_kv_heads == 0, "args.n_heads must be divisible by args.n_kv_heads"
        self.query_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.key_proj = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.value_proj = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.rope = RotaryEmbedding(self.head_dim)
        self.out_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.dropout = nn.Dropout(args.dropout)
        # Cache slots for the last keys/values. Non-persistent so that a cache
        # left over from generation never leaks into state_dict() checkpoints.
        self.register_buffer('cached_keys', None, persistent=False)
        self.register_buffer('cached_values', None, persistent=False)

    def forward(self, x, mask=None, use_cache=False):
        """Attend over ``x`` of shape (batch, seq_len, dim); returns the same shape.

        Args:
            x (torch.Tensor): Input hidden states (batch, seq_len, dim).
            mask: Accepted for interface compatibility but ignored; causality is
                enforced by ``is_causal=True``.
            use_cache (bool): Store this call's (RoPE-rotated) keys/values in
                ``cached_keys``/``cached_values``; otherwise clear them.
                NOTE: the cache is only written here, never read.
        """
        batch_size, seq_len, _ = x.size()
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query = self.query_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.key_proj(x).view(batch_size, seq_len, self.kv_heads, self.head_dim).transpose(1, 2)
        value = self.value_proj(x).view(batch_size, seq_len, self.kv_heads, self.head_dim).transpose(1, 2)
        query = self.rope(query)
        key = self.rope(key)
        # SDPA applies dropout_p regardless of module mode (unlike nn.Dropout),
        # so it must be zeroed explicitly outside training.
        dropout_p = self.dropout.p if self.training else 0.0
        output = F.scaled_dot_product_attention(
            query, key, value, is_causal=True, dropout_p=dropout_p, enable_gqa=True
        )
        if use_cache:
            self.cached_keys = key
            self.cached_values = value
        else:
            # Drop any stale cache so it does not accumulate during training.
            self.cached_keys = None
            self.cached_values = None
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)  # [batch, seq_len, num_heads * head_dim]
        return self.out_proj(output)
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        args: Configuration object; reads ``args.dim`` (model width) and
            ``args.intermediate_dim`` (hidden width).

    Attributes:
        w1 (nn.Linear): Gate projection, dim -> intermediate_dim, no bias.
        w2 (nn.Linear): Output projection, intermediate_dim -> dim, no bias.
        w3 (nn.Linear): Up projection, dim -> intermediate_dim, no bias.
    """

    def __init__(self, args):
        super().__init__()
        d_model, d_hidden = args.dim, args.intermediate_dim
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)
        self.w2 = nn.Linear(d_hidden, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_hidden, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: ``h = x + attn(norm(x))``, then ``h + ffn(norm(h))``.

    Args:
        layer_id (int): Index of this block within the model's layer stack.
        args (ModelArgs): Configuration; ``n_heads``, ``dim`` and ``norm_eps``
            are read here and ``args`` is forwarded to the sub-modules.

    Attributes:
        n_heads (int): Number of attention heads.
        dim (int): Model width.
        head_dim (int): ``dim // n_heads``.
        attention (Attention): Self-attention sub-layer.
        feed_forward (FeedForward): Feed-forward sub-layer.
        layer_id (int): Index of this block.
        attention_norm (RMSNorm): Normalization applied to the attention *input*.
        ffn_norm (RMSNorm): Normalization applied to the feed-forward *input*.
    """

    def __init__(self, layer_id: int, args):
        super().__init__()
        width = args.dim
        heads = args.n_heads
        self.n_heads = heads
        self.dim = width
        self.head_dim = width // heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.layer_id = layer_id
        eps = args.norm_eps
        self.attention_norm = RMSNorm(width, eps=eps)
        self.ffn_norm = RMSNorm(width, eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        use_cache: bool
    ):
        """Run attention then feed-forward, each wrapped in a residual connection.

        Args:
            x (torch.Tensor): Hidden states.
            mask (torch.Tensor, optional): Passed through to ``self.attention``.
            use_cache (bool): Passed through to ``self.attention``.

        Returns:
            torch.Tensor: Hidden states with the same shape as ``x``.
        """
        attn_out = self.attention(self.attention_norm(x), mask=mask, use_cache=use_cache)
        hidden = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + ffn_out
class Transformer(nn.Module):
    """Decoder-only language model.

    Pipeline: token embedding -> ``n_layers`` TransformerBlocks -> RMSNorm ->
    projection onto the vocabulary with the (tied) token-embedding matrix.

    Args:
        args (ModelArgs): Configuration; reads ``vocab_size``, ``n_layers``,
            ``dim``, ``norm_eps`` and ``init_scale`` here and passes ``args`` to
            every TransformerBlock.

    Attributes:
        args (ModelArgs): Model configuration parameters.
        vocab_size (int): Vocabulary size.
        n_layers (int): Number of layers in the model.
        tok_embeddings (nn.Embedding): Token embeddings, also used as the output
            projection (weight tying) in ``forward``.
        layers (torch.nn.ModuleList): List of Transformer blocks.
        norm (RMSNorm): Normalization applied to the final hidden states.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(TransformerBlock(layer_id, args))
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        # No separate output head: forward() projects with tok_embeddings.weight.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear and Embedding weights from N(0, args.init_scale ** 2)."""
        std = self.args.init_scale
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=std)

    def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = False):
        """Compute next-token logits for every position.

        Args:
            tokens (torch.Tensor): Token ids of shape (batch, seqlen).
            mask (torch.Tensor, optional): Attention mask passed to every layer.
                When None, an additive causal mask of shape (1, 1, seqlen, seqlen)
                with -1e4 above the diagonal is built.
                NOTE(review): Attention in this file enforces causality via SDPA's
                ``is_causal`` and does not read ``mask``; confirm before relying
                on custom masks.
            use_cache (bool): Passed to every layer.

        Returns:
            torch.Tensor: Logits of shape (batch, seqlen, vocab_size).
        """
        _, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        if mask is None:
            mask = torch.triu(torch.ones((seqlen, seqlen),
                                         dtype=torch.bool,
                                         device=tokens.device),
                              diagonal=1)
            mask = mask.unsqueeze(0).unsqueeze(0)
            mask = mask * -1e4
        for layer in self.layers:
            h = layer(h, mask, use_cache)
        h = self.norm(h)
        # Tied output projection: logits = h @ tok_embeddings.weight.T
        return F.linear(h, self.tok_embeddings.weight)

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Return ``logits`` with entries outside the top-k / nucleus set set to -inf.

        ``top_k`` is clamped to the vocabulary size; ``0``/``None`` disables it.
        ``top_p >= 1.0`` (or None) disables nucleus filtering. The highest logit
        is always kept.
        """
        if top_k is not None and top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth_best, float('-inf'))
        if top_p is not None and top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float('-inf'))
        return logits

    def generate(self,
                 input_ids,
                 max_length,
                 min_length=None,
                 num_return_sequences=1,
                 pad_token_id=None,
                 do_sample=True,
                 temperature=0.8,
                 top_k=50,
                 top_p=0.95
                 ):
        """Autoregressively extend ``input_ids`` up to ``max_length`` tokens.

        Args:
            input_ids (torch.Tensor): Prompt token ids of shape (batch, prompt_len).
            max_length (int): Total length (prompt + generated) at which to stop.
            min_length (int, optional): While the sequence is shorter than this,
                ``pad_token_id`` cannot be generated. Defaults to the prompt
                length (no constraint).
            num_return_sequences (int): Continuations per prompt. Rows are grouped
                per prompt: all continuations of prompt 0, then of prompt 1, ...
            pad_token_id (int, optional): End-of-sequence token. Once a row emits
                it, all later positions of that row are filled with it; generation
                stops early when every row has emitted it.
            do_sample (bool): Sample from the filtered distribution; otherwise
                greedy argmax. Sampling with ``temperature <= 0`` falls back to
                greedy decoding.
            temperature (float): Softmax temperature used when sampling.
            top_k (int): Keep only the k highest logits when sampling (clamped to
                the vocabulary size); ``0`` disables.
            top_p (float): Nucleus-sampling threshold; ``1.0`` disables.

        Returns:
            torch.Tensor: Token ids of shape (batch * num_return_sequences, L)
            with L <= max_length (shorter only when every row hit pad_token_id).

        The model runs in eval mode and its previous train/eval mode is
        restored afterwards. NOTE(review): positions beyond the RoPE table
        (2048 by default in this file) are not supported.
        """
        was_training = self.training
        self.eval()
        min_length = min_length if min_length is not None else input_ids.shape[1]
        # Temperature (> 0) and top-k/top-p never change the arg-max, so greedy
        # decoding ignores them; temperature <= 0 would yield NaN/inverted probs.
        sample = do_sample and temperature > 0
        if num_return_sequences > 1:
            input_ids = input_ids.repeat_interleave(num_return_sequences, dim=0)
        finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
        try:
            with torch.no_grad():
                for _ in range(max_length - input_ids.shape[1]):
                    # Attention only writes its KV cache and never reads it, so the
                    # full sequence is re-encoded each step anyway; don't keep it.
                    logits = self(input_ids, use_cache=False)[:, -1, :]
                    if pad_token_id is not None and input_ids.shape[1] < min_length:
                        logits[:, pad_token_id] = float('-inf')
                    if sample:
                        logits = self._filter_logits(logits / temperature, top_k, top_p)
                        probs = torch.softmax(logits, dim=-1)
                        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                    else:
                        next_tokens = torch.argmax(logits, dim=-1)
                    if pad_token_id is not None:
                        # Rows that already ended keep emitting padding.
                        next_tokens = next_tokens.masked_fill(finished, pad_token_id)
                        finished |= next_tokens == pad_token_id
                    input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
                    if pad_token_id is not None and finished.all():
                        break
        finally:
            self.train(was_training)
        return input_ids