Spaces:
Running
on
L40S
Running
on
L40S
from typing import Optional | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from cube3d.model.transformers.cache import Cache | |
from cube3d.model.transformers.norm import LayerNorm, RMSNorm | |
from cube3d.model.transformers.rope import scaled_dot_product_attention_with_rotary_emb | |
class SwiGLUMLP(nn.Module):
    """
    Gated feed-forward block using the SwiGLU activation.

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``: the input is projected
    twice to ``hidden_dim``, one branch goes through SiLU and gates the other
    elementwise, and the product is projected back to ``embed_dim``.
    """

    def __init__(self, embed_dim, hidden_dim, bias=True, **kwargs):
        """
        Args:
            embed_dim (int): Size of the last dimension of both input and output.
            hidden_dim (int): Width of the gated hidden layer.
            bias (bool, optional): Whether each of the three linear layers carries a
                bias term. Defaults to True.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()

        def _linear(in_features, out_features):
            return nn.Linear(in_features, out_features, bias=bias)

        # Keep the gate -> up -> down creation order so parameter registration
        # (state_dict key order) and random initialization stay the same.
        self.gate_proj = _linear(embed_dim, hidden_dim)
        self.up_proj = _linear(embed_dim, hidden_dim)
        self.down_proj = _linear(hidden_dim, embed_dim)

    def forward(self, x):
        """
        Apply the gated MLP.

        Args:
            x (torch.Tensor): Input whose last dimension is ``embed_dim``.

        Returns:
            torch.Tensor: Tensor with the same leading dimensions as ``x`` and a
            last dimension of ``embed_dim``.
        """
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
class SelfAttentionWithRotaryEmbedding(nn.Module):
    """
    Multi-head self-attention with per-head RMS normalization of queries/keys and
    rotary position embeddings, with optional key/value caching.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        eps: float = 1e-6,
    ):
        """
        A PyTorch module implementing self-attention with rotary embeddings.

        Args:
            embed_dim (int): The dimensionality of the input embeddings. Must be
                divisible by ``num_heads``.
            num_heads (int): The number of attention heads. Must be positive.
            bias (bool, optional): Whether the value projection (``c_v``) and the
                output projection (``c_proj``) include bias terms. The fused
                query/key projection (``c_qk``) never has a bias. Defaults to True.
            eps (float, optional): Accepted for interface compatibility. It is
                currently NOT forwarded to the per-head ``RMSNorm`` layers, which are
                constructed with only ``head_dim``. Defaults to 1e-6.

        Raises:
            ValueError: If ``num_heads`` is not positive or ``embed_dim`` is not
                divisible by ``num_heads``.
        """
        super().__init__()
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        # Fused query/key projection (no bias), plus a separate value projection.
        self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
        self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        head_dim = embed_dim // num_heads
        # QK-norm: queries and keys are normalized per head over head_dim.
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)

    def forward(
        self,
        x,
        freqs_cis: torch.Tensor,
        attn_mask=None,
        is_causal: bool = False,
        kv_cache: Optional[Cache] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ):
        """
        Forward pass for the SelfAttentionWithRotaryEmbedding instance.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, seq_len, embed_dim).
            freqs_cis (torch.Tensor): Precomputed rotary positional embeddings.
            attn_mask (Optional[torch.Tensor], optional): Attention mask to apply
                during self-attention. Defaults to None.
            is_causal (bool, optional): Whether to apply causal masking.
                Defaults to False.
            kv_cache (Optional[Cache], optional): Cache object whose
                ``key_states``/``value_states`` buffers are updated in place with
                this call's keys and values. Defaults to None.
            curr_pos_id (Optional[torch.Tensor], optional): Position indices along
                the sequence axis at which to write into the cache when ``decode``
                is True; also forwarded to the attention helper in decode mode.
                Required if ``decode`` is True and ``kv_cache`` is given.
                Defaults to None.
            decode (bool, optional): Whether the model is in decoding mode.
                Defaults to False.

        Returns:
            torch.Tensor: Output tensor of shape (batch, seq_len, embed_dim).

        Raises:
            ValueError: If ``decode`` is True, ``kv_cache`` is given and
                ``curr_pos_id`` is None.
        """
        # batch size, sequence length, embedding dim
        b, l, d = x.shape

        # compute q, k (fused) and v
        q, k = self.c_qk(x).chunk(2, dim=-1)
        v = self.c_v(x)

        # split per head: (B, T, C) -> (B, nh, T, hs)
        q = q.view(b, l, self.num_heads, -1).transpose(1, 2)
        k = k.view(b, l, self.num_heads, -1).transpose(1, 2)
        v = v.view(b, l, self.num_heads, -1).transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if kv_cache is not None:
            if not decode:
                # Non-decode pass: write this call's keys/values into the first
                # seq_len slots of the cache buffers.
                kv_cache.key_states[:, :, : k.shape[2], :].copy_(k)
                kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
            else:
                if curr_pos_id is None:
                    raise ValueError(
                        "curr_pos_id is required when decode=True and kv_cache is provided"
                    )
                # Decode pass: scatter the new keys/values into the cache at
                # curr_pos_id along the sequence axis (dim 2).
                kv_cache.key_states.index_copy_(2, curr_pos_id, k)
                kv_cache.value_states.index_copy_(2, curr_pos_id, v)
            # Attend over the whole cache buffer, not just this call's keys/values.
            # NOTE(review): masking of not-yet-filled cache slots is presumably
            # handled via attn_mask / is_causal / curr_pos_id in the helper — verify.
            k = kv_cache.key_states
            v = kv_cache.value_states

        # self-attention: (B, nh, T, hs) x (B, nh, hs, S) -> (B, nh, T, S).
        # The helper (from the rope module) receives freqs_cis and presumably
        # applies the rotary embedding to q/k before attention — see its definition.
        y = scaled_dot_product_attention_with_rotary_emb(
            q,
            k,
            v,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            curr_pos_id=curr_pos_id if decode else None,
            is_causal=is_causal,
        )
        # re-assemble all head outputs side by side: (B, nh, T, hs) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(b, l, d)
        # output projection
        y = self.c_proj(y)
        return y
class DecoderLayerWithRotaryEmbedding(nn.Module):
    """
    A single pre-norm transformer decoder layer.

    Computes ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``, using
    self-attention with rotary embeddings and a SwiGLU MLP of hidden width
    ``4 * embed_dim``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        eps: float = 1e-6,
    ) -> None:
        """
        Initializes the decoder layer.

        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            bias (bool, optional): Whether to include bias terms; forwarded to the
                attention module and the MLP. Defaults to True.
            eps (float, optional): Epsilon for both LayerNorms; also passed to the
                attention module. Defaults to 1e-6.
        """
        super().__init__()
        # Non-affine LayerNorms (no learned scale/shift) before each sub-layer.
        self.ln_1 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
        self.attn = SelfAttentionWithRotaryEmbedding(
            embed_dim, num_heads=num_heads, bias=bias, eps=eps
        )
        self.ln_2 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
        self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias)

    @classmethod
    def from_config(cls, cfg):
        """
        Create an instance of the class using the provided configuration.

        Args:
            cfg: A configuration object containing the following attributes:
                - n_embd (int): The size of the embedding dimension.
                - n_head (int): The number of attention heads.
                - bias (bool): Whether to include a bias term.
                - eps (float): A small value added for numerical stability.

        Returns:
            An instance of the class initialized with the specified configuration.
        """
        return cls(
            cfg.n_embd,
            num_heads=cfg.n_head,
            bias=cfg.bias,
            eps=cfg.eps,
        )

    def forward(
        self,
        x,
        freqs_cis: torch.Tensor,
        attn_mask=None,
        is_causal: bool = True,
        kv_cache: Optional[Cache] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ):
        """
        Forward pass for the decoder layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, seq_len, embed_dim).
            freqs_cis (torch.Tensor): Precomputed rotary positional embeddings.
            attn_mask (Optional[torch.Tensor], optional): Attention mask to apply
                during self-attention. Defaults to None.
            is_causal (bool, optional): Whether to apply causal masking.
                Defaults to True.
            kv_cache (Optional[Cache], optional): Key-value cache forwarded to the
                attention module. Defaults to None.
            curr_pos_id (Optional[torch.Tensor], optional): Current position IDs for
                decoding. Defaults to None.
            decode (bool, optional): Whether the model is in decoding mode.
                Defaults to False.

        Returns:
            torch.Tensor: Output tensor with the same shape as ``x``.
        """
        # Pre-norm residual attention sub-layer.
        attn_out = self.attn(
            self.ln_1(x),
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            is_causal=is_causal,
            kv_cache=kv_cache,
            curr_pos_id=curr_pos_id,
            decode=decode,
        )
        x = x + attn_out
        # Pre-norm residual MLP sub-layer.
        x = x + self.mlp(self.ln_2(x))
        return x