from typing import Optional, Tuple

import torch
import torch.nn as nn

from cube3d.model.transformers.cache import Cache
from cube3d.model.transformers.norm import LayerNorm, RMSNorm
from cube3d.model.transformers.roformer import SwiGLUMLP
from cube3d.model.transformers.rope import scaled_dot_product_attention_with_rotary_emb
class DismantledPreAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        query: bool = True,
        bias: bool = True,
    ) -> None:
        """
        Initializes the DismantledPreAttention module.

        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            query (bool, optional): Whether to include the query-key projection. Defaults to True.
            bias (bool, optional): Whether to include bias in linear layers. Defaults to True.

        Raises:
            AssertionError: If `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        assert embed_dim % num_heads == 0
        self.query = query
        head_dim = embed_dim // num_heads
        # key, query, value projections for all heads, but in a batch
        if query:
            self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
            self.q_norm = RMSNorm(head_dim)
        else:
            self.c_k = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_norm = RMSNorm(head_dim)
        self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)
        # (B, T, C) -> (B, nh, T, hs)
        self.to_mha = lambda x: x.view(*x.shape[:2], num_heads, -1).transpose(1, 2)

    def forward(self, x):
        """
        Forward pass for the dismantled pre-attention mechanism.

        Args:
            x (torch.Tensor): Input tensor of shape (..., input_dim).

        Returns:
            tuple: A tuple containing:
                - q (torch.Tensor or None): Query tensor after normalization and transformation,
                  or None if `self.query` is False.
                - k (torch.Tensor): Key tensor after normalization and transformation.
                - v (torch.Tensor): Value tensor after transformation.
        """
        if self.query:
            q, k = self.c_qk(x).chunk(2, dim=-1)
            q = self.q_norm(self.to_mha(q))
        else:
            q = None
            k = self.c_k(x)
        k = self.k_norm(self.to_mha(k))
        v = self.to_mha(self.c_v(x))
        return (q, k, v)
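

# Hedged usage sketch (illustrative, not part of the original module): shows the
# tensor shapes DismantledPreAttention produces in both the query and key/value-only
# configurations. The sizes (embed_dim=64, num_heads=4, batch 2, sequence 10) are
# toy values chosen for the example.
def _demo_dismantled_pre_attention():
    pre_q = DismantledPreAttention(embed_dim=64, num_heads=4, query=True)
    pre_kv = DismantledPreAttention(embed_dim=64, num_heads=4, query=False)
    x = torch.randn(2, 10, 64)  # (B, T, C)
    q, k, v = pre_q(x)
    # heads are split out and moved ahead of the sequence dimension
    assert q.shape == k.shape == v.shape == (2, 4, 10, 16)  # (B, nh, T, hs)
    q, k, v = pre_kv(x)  # key/value-only stream: no query projection
    assert q is None and k.shape == v.shape == (2, 4, 10, 16)
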
class DismantledPostAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        bias: bool = True,
        eps: float = 1e-6,
    ) -> None:
        """
        Initializes the DismantledPostAttention module.

        Args:
            embed_dim (int): The dimensionality of the embedding space.
            bias (bool, optional): Whether to include a bias term in the linear projection. Defaults to True.
            eps (float, optional): A small value added to the denominator for numerical stability in layer normalization. Defaults to 1e-6.
        """
        super().__init__()
        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ln_3 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
        self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias)
    def forward(self, x, a):
        """
        Forward pass of the dismantled post-attention block.

        Args:
            x (torch.Tensor): The residual-stream input tensor.
            a (torch.Tensor): The merged attention output to be folded back into the input.

        Returns:
            torch.Tensor: The output tensor after applying the output projection,
            layer normalization, and MLP transformations.
        """
        x = x + self.c_proj(a)
        x = x + self.mlp(self.ln_3(x))
        return x
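

# Hedged usage sketch (illustrative, not part of the original module): demonstrates
# the two residual updates applied by DismantledPostAttention, x + c_proj(a) followed
# by x + mlp(ln_3(x)). It assumes SwiGLUMLP maps the embedding dimension back to
# itself, as the residual addition above requires; embed_dim=64 is a toy value.
def _demo_dismantled_post_attention():
    post = DismantledPostAttention(embed_dim=64)
    x = torch.randn(2, 10, 64)  # residual stream (B, T, C)
    a = torch.randn(2, 10, 64)  # merged attention heads (B, T, C)
    out = post(x, a)
    assert out.shape == x.shape  # shape is preserved by both residual updates
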
class DualStreamAttentionWithRotaryEmbedding(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        cond_pre_only: bool = False,
        bias: bool = True,
    ):
        """
        Initializes the DualStreamAttention module.

        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            cond_pre_only (bool, optional): If True, the conditional pre-attention
                will only process the key and value, not the query. Defaults to False.
            bias (bool, optional): Whether to include a bias term in the attention layers.
                Defaults to True.
        """
        super().__init__()
        self.cond_pre_only = cond_pre_only
        self.pre_x = DismantledPreAttention(
            embed_dim=embed_dim, num_heads=num_heads, query=True, bias=bias
        )
        self.pre_c = DismantledPreAttention(
            embed_dim=embed_dim, num_heads=num_heads, query=not cond_pre_only, bias=bias
        )
    def forward(
        self,
        x,
        c: Optional[torch.Tensor],
        freqs_cis,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        kv_cache: Optional[Cache] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass for dual stream multi-head attention.

        Each stream computes its projections with a single batched weight matrix
        multiplication whose result is split into query, key, and value.

        Parameters
        ----------
        x : torch.Tensor
            Hidden states [B, L, D]
        c : torch.Tensor, optional
            Condition [B, S, D]; may be None when decoding with a KV cache.
        freqs_cis : torch.Tensor
            Precomputed RoPE matrix from precompute_freqs_cis [B, S+L, Hd]
        attn_mask : torch.Tensor, optional
            Attention mask [B, S+L, S+L], by default None
        is_causal : bool, optional
            Whether to apply causal masking inside the attention kernel, by default False
        kv_cache : Cache, optional
            Key-value cache holding keys and values from all previous steps;
            None means caching is disabled.
        curr_pos_id : torch.Tensor, optional
            Position index of the current token; required when `decode` is True.
        decode : bool, optional
            If True, run a single-token decoding step against the KV cache;
            otherwise run training or prefill.

        Returns
        -------
        torch.Tensor
            Hidden state output [B, L, D]
        torch.Tensor or None
            Condition stream output [B, S, D], or None if the condition stream
            is not carried forward.
        """
        if kv_cache is None or not decode:
            # Either training or prefill
            qkv_c = self.pre_c(c)
            qkv_x = self.pre_x(x)
            # prepend condition stream
            # (B, nh, Tc, hs) + (B, nh, Tx, hs) -> (B, nh, Tc+Tx, hs)
            if self.cond_pre_only:
                q = qkv_x[0]
            else:
                q = torch.cat([qkv_c[0], qkv_x[0]], dim=2)
            k = torch.cat([qkv_c[1], qkv_x[1]], dim=2)
            v = torch.cat([qkv_c[2], qkv_x[2]], dim=2)
        else:
            # if using kv cache, query would only be the last token in the sequence, hence is_causal is False
            assert x.shape[1] == 1
            is_causal = False
            q, k, v = self.pre_x(x)

        if kv_cache is not None:
            if not decode:
                kv_cache.key_states[:, :, : k.shape[2], :].copy_(k)
                kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
            else:
                assert curr_pos_id is not None
                kv_cache.key_states.index_copy_(2, curr_pos_id, k)
                kv_cache.value_states.index_copy_(2, curr_pos_id, v)
            k = kv_cache.key_states
            v = kv_cache.value_states

        if attn_mask is not None:
            # trim attention mask to length
            if decode:
                assert curr_pos_id is not None
                attn_mask = attn_mask[..., curr_pos_id, :]
            else:
                attn_mask = attn_mask[..., -q.shape[2] :, :]

        # causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        # efficient attention using Flash Attention CUDA kernels
        y = scaled_dot_product_attention_with_rotary_emb(
            q,
            k,
            v,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            curr_pos_id=curr_pos_id if decode else None,
            is_causal=is_causal,
        )
        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(x.shape[0], -1, x.shape[2])

        if y.shape[1] == x.shape[1]:
            y_c = None
            y_x = y
        else:
            assert c is not None, "Conditioning is required for dual stream attention"
            y_c, y_x = torch.split(y, [c.shape[1], x.shape[1]], dim=1)
        return y_x, y_c
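

# Hedged shape sketch (illustrative, not part of the original module): mirrors how
# the attention above prepends the condition stream before attention and splits it
# back off afterwards (see the torch.cat / torch.split calls in forward). The sizes
# (B=2, S=3 condition tokens, L=5 hidden tokens, nh=4 heads, hs=16) are toy values.
def _demo_dual_stream_concat_split():
    B, S, L, nh, hs = 2, 3, 5, 4, 16
    k_c = torch.randn(B, nh, S, hs)   # keys from the condition stream
    k_x = torch.randn(B, nh, L, hs)   # keys from the hidden-state stream
    k = torch.cat([k_c, k_x], dim=2)  # (B, nh, S+L, hs), as in forward()
    assert k.shape == (B, nh, S + L, hs)
    y = torch.randn(B, S + L, nh * hs)        # merged attention output (B, S+L, D)
    y_c, y_x = torch.split(y, [S, L], dim=1)  # condition / hidden outputs
    assert y_c.shape == (B, S, nh * hs) and y_x.shape == (B, L, nh * hs)
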
class DualStreamDecoderLayerWithRotaryEmbedding(nn.Module):
    """Nicely wrapped decoder layer block for dual stream GPT model"""

    def __init__(
        self,
        embed_dim,
        num_heads: int,
        cond_pre_only: bool = False,
        bias: bool = True,
        eps: float = 1.0e-6,
    ) -> None:
        """
        Initializes the DualStreamDecoderLayerWithRotaryEmbedding module with optional conditional pre-only mode.

        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            cond_pre_only (bool, optional): If True, applies conditional processing only before attention. Defaults to False.
            bias (bool, optional): If True, includes bias terms in the attention and post-attention layers. Defaults to True.
            eps (float, optional): A small value added for numerical stability in layer normalization. Defaults to 1.0e-6.
        """
        super().__init__()
        self.ln_1 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
        self.ln_2 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
        self.attn = DualStreamAttentionWithRotaryEmbedding(
            embed_dim=embed_dim,
            num_heads=num_heads,
            cond_pre_only=cond_pre_only,
            bias=bias,
        )
        self.post_1 = DismantledPostAttention(embed_dim, bias=bias, eps=eps)
        if not cond_pre_only:
            self.post_2 = DismantledPostAttention(embed_dim, bias=bias, eps=eps)
    @classmethod
    def from_config(cls, cfg, cond_pre_only: bool = False):
        """
        Create an instance of the class using the provided configuration.

        Args:
            cfg: A configuration object containing the necessary parameters:
                - n_embd (int): The size of the embedding dimension.
                - n_head (int): The number of attention heads.
                - bias (bool): Whether to include a bias term.
                - eps (float): A small value added for numerical stability.
            cond_pre_only (bool, optional): If True, applies conditioning only in the pre-processing step.
                Defaults to False.

        Returns:
            An instance of the class initialized with the specified configuration.
        """
        return cls(
            cfg.n_embd,
            num_heads=cfg.n_head,
            cond_pre_only=cond_pre_only,
            bias=cfg.bias,
            eps=cfg.eps,
        )
    def forward(
        self,
        x,
        c,
        freqs_cis: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = True,
        kv_cache: Optional[Cache] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ):
        """
        Forward pass for DualStreamDecoderLayerWithRotaryEmbedding.

        Parameters
        ----------
        x : torch.Tensor
            Hidden states [B, L, D]
        c : torch.Tensor, optional
            Condition [B, S, D]; may be None when decoding with a KV cache.
        freqs_cis : torch.Tensor
            Positional embedding from RoPE [B, S+L, hd]
        attn_mask : torch.Tensor, optional
            Attention mask [B, S+L, S+L], by default None
        is_causal : bool, optional
            Whether to apply causal masking, by default True
        kv_cache : Cache, optional
            Key-value cache for the attention block, by default None
        curr_pos_id : torch.Tensor, optional
            Position index of the current token when decoding, by default None
        decode : bool, optional
            If True, run a single-token decoding step, by default False

        Returns
        -------
        torch.Tensor
            Hidden state output [B, L, D]
        torch.Tensor or None
            Condition stream output [B, S, D], or None if the condition stream
            is not carried past this layer.
        """
        a_x, a_c = self.attn(
            self.ln_1(x),
            # NOTE: condition could be None when using the KV cache
            self.ln_2(c) if c is not None else None,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            is_causal=is_causal,
            kv_cache=kv_cache,
            curr_pos_id=curr_pos_id,
            decode=decode,
        )
        x = self.post_1(x, a_x)
        if a_c is not None:
            c = self.post_2(c, a_c)
        else:
            c = None
        return x, c
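

# Hedged usage sketch (illustrative, not part of the original module): builds decoder
# layers via from_config using a minimal stand-in config exposing the fields the
# classmethod reads (n_embd, n_head, bias, eps). The values are toy numbers, not the
# model's real configuration.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(n_embd=64, n_head=4, bias=True, eps=1e-6)
    layer = DualStreamDecoderLayerWithRotaryEmbedding.from_config(cfg)
    # A regular layer updates both streams, so it owns a second post-attention block.
    assert hasattr(layer, "post_2")
    # A cond_pre_only layer only reads the condition as keys/values and does not
    # carry a condition output, so it has no second post-attention block.
    last = DualStreamDecoderLayerWithRotaryEmbedding.from_config(cfg, cond_pre_only=True)
    assert not hasattr(last, "post_2")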