from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from cube3d.model.transformers.cache import Cache
from cube3d.model.transformers.dual_stream_attention import (
    DualStreamDecoderLayerWithRotaryEmbedding,
)
from cube3d.model.transformers.norm import LayerNorm
from cube3d.model.transformers.roformer import DecoderLayerWithRotaryEmbedding
from cube3d.model.transformers.rope import precompute_freqs_cis


class DualStreamRoformer(nn.Module):
    @dataclass
    class Config:
        checkpoint_path: str = ""
        n_layer: int = 12
        n_single_layer: int = 0
        rope_theta: float = 1000
        n_head: int = 16
        n_embd: int = 2048
        bias: bool = False  # bias in Linears and LayerNorms
        eps: float = 1e-6  # Norm eps
        shape_model_vocab_size: int = 4096
        shape_model_embed_dim: int = 16
        text_model_embed_dim: int = 512
        use_pooled_text_embed: bool = False
        encoder_with_cls_token: bool = True

    def __init__(self, cfg: Config) -> None:
        """
        Initializes the DualStreamRoformer model.

        Args:
            cfg (Config): Configuration object containing model parameters.

        Attributes:
            cfg (Config): Stores the configuration object.
            text_proj (nn.Linear): Linear layer to project text model embeddings to the desired embedding dimension.
            shape_proj (nn.Linear): Linear layer to project shape model embeddings to the desired embedding dimension.
            vocab_size (int): Vocabulary size for the shape model, including special tokens.
            shape_bos_id (int): Token ID for the beginning-of-sequence (BOS) token for the shape model.
            shape_eos_id (int): Token ID for the end-of-sequence (EOS) token for the shape model.
            padding_id (int): Token ID for the padding token.
            transformer (nn.ModuleDict): Dictionary containing the following components:
                - wte (nn.Embedding): Embedding layer for the vocabulary.
                - dual_blocks (nn.ModuleList): List of dual-stream decoder layers with rotary embeddings.
                - single_blocks (nn.ModuleList): List of single-stream decoder layers with rotary embeddings.
                - ln_f (LayerNorm): Layer normalization applied to the final output.
            lm_head (nn.Linear): Linear layer mapping the final embeddings to the vocabulary size for language modeling.
        """
        super().__init__()
        self.cfg = cfg

        self.text_proj = nn.Linear(
            in_features=self.cfg.text_model_embed_dim,
            out_features=self.cfg.n_embd,
            bias=self.cfg.bias,
        )
        self.shape_proj = nn.Linear(self.cfg.shape_model_embed_dim, self.cfg.n_embd)

        self.vocab_size = self.cfg.shape_model_vocab_size

        def add_special_token():
            token_id = self.vocab_size
            self.vocab_size += 1
            return token_id

        self.shape_bos_id = add_special_token()
        self.shape_eos_id = add_special_token()
        self.padding_id = add_special_token()
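        # The three special tokens extend the shape codebook, so the final
        # vocabulary is shape_model_vocab_size + 3 and padding_id is the last
        # index (reused as padding_idx for the wte embedding below).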

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(
                    self.vocab_size,
                    self.cfg.n_embd,
                    padding_idx=self.padding_id,
                ),
                dual_blocks=nn.ModuleList(
                    [
                        DualStreamDecoderLayerWithRotaryEmbedding.from_config(
                            self.cfg, cond_pre_only=(i == self.cfg.n_layer - 1)
                        )
                        for i in range(self.cfg.n_layer)
                    ]
                ),
                single_blocks=nn.ModuleList(
                    [
                        DecoderLayerWithRotaryEmbedding.from_config(self.cfg)
                        for _ in range(self.cfg.n_single_layer)
                    ]
                ),
                ln_f=LayerNorm(
                    self.cfg.n_embd, elementwise_affine=False, eps=self.cfg.eps
                ),
            )
        )

        self.lm_head = nn.Linear(self.cfg.n_embd, self.vocab_size, bias=False)

    def encode_text(self, text_embed):
        """
        Encodes the given text embeddings by projecting them through a linear transformation.

        Args:
            text_embed (torch.Tensor): A tensor representing the text embeddings to be encoded.

        Returns:
            torch.Tensor: The projected text embeddings after applying the linear transformation.
        """
        return self.text_proj(text_embed)

    def encode_token(self, tokens):
        """
        Encodes the input tokens using the word token embedding layer of the transformer model.

        Args:
            tokens (torch.Tensor): A tensor containing the input tokens to be encoded.

        Returns:
            torch.Tensor: A tensor containing the encoded token embeddings.
        """
        return self.transformer.wte(tokens)

    def init_kv_cache(
        self,
        batch_size: int,
        cond_len: int,
        max_shape_tokens: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> list[Cache]:
        """
        Initializes the key-value cache for the transformer model.

        This method creates a list of `Cache` objects to store the key and value
        states for both dual-stream and single-stream transformer blocks. The
        cache is pre-allocated with zeros and is used to optimize the computation
        of attention mechanisms during model inference.

        Args:
            batch_size (int): The batch size for the input data.
            cond_len (int): The length of the conditioning sequence.
            max_shape_tokens (int): The maximum number of tokens in the shape sequence.
            dtype (torch.dtype): The data type for the tensors (e.g., torch.float32).
            device (torch.device): The device on which the tensors will be allocated
                (e.g., torch.device('cuda') or torch.device('cpu')).

        Returns:
            list[Cache]: A list of `Cache` objects containing pre-allocated key and
                value states for each transformer block.
        """
        num_heads = self.cfg.n_head
        max_all_tokens = cond_len + max_shape_tokens
        per_head_dim = self.cfg.n_embd // num_heads
        kv_cache = [
            Cache(
                key_states=torch.zeros(
                    (batch_size, num_heads, max_all_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
                value_states=torch.zeros(
                    (batch_size, num_heads, max_all_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(len(self.transformer.dual_blocks))
        ]
        kv_cache += [
            Cache(
                key_states=torch.zeros(
                    (batch_size, num_heads, max_shape_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
                value_states=torch.zeros(
                    (batch_size, num_heads, max_shape_tokens, per_head_dim),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(len(self.transformer.single_blocks))
        ]
        return kv_cache
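
    # Cache layout note: with the default Config (n_embd=2048, n_head=16, so a
    # per-head dim of 128), each dual-stream Cache holds key/value buffers of
    # shape (batch, 16, cond_len + max_shape_tokens, 128), while the
    # single-stream caches only span max_shape_tokens because the single
    # blocks never attend to the conditioning tokens (see forward below).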

    def forward(
        self,
        embed: torch.Tensor,
        cond: torch.Tensor,
        kv_cache: Optional[list[Cache]] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ):
        """
        Forward pass for the dual-stream RoFormer model.

        Args:
            embed (torch.Tensor): The input embedding tensor.
            cond (torch.Tensor): The conditioning tensor.
            kv_cache (Optional[list[Cache]]): A list of key-value caches for each layer, used for decoding. Default is None.
            curr_pos_id (Optional[torch.Tensor]): The current position ID tensor of shape (batch_size,). Required if `decode` is True. Default is None.
            decode (bool): Whether the model is in decoding mode. Default is False.

        Returns:
            torch.Tensor: The output logits tensor.
        """
        b, l = embed.shape[:2]
        s = cond.shape[1]
        device = embed.device

        attn_mask = torch.tril(
            torch.ones(s + l, s + l, dtype=torch.bool, device=device)
        )

        position_ids = torch.arange(l, dtype=torch.long, device=device)  # shape (l,)
        position_ids = position_ids.unsqueeze_(0).expand(b, -1)
        s_freqs_cis = precompute_freqs_cis(
            dim=self.cfg.n_embd // self.cfg.n_head,
            t=position_ids,
            theta=self.cfg.rope_theta,
        )
        position_ids = torch.cat(
            [
                torch.zeros([b, s], dtype=torch.long, device=position_ids.device),
                position_ids,
            ],
            dim=1,
        )
        d_freqs_cis = precompute_freqs_cis(
            dim=self.cfg.n_embd // self.cfg.n_head,
            t=position_ids,
            theta=self.cfg.rope_theta,
        )

        if kv_cache is not None and decode:
            assert curr_pos_id is not None
            embed = embed[:, curr_pos_id, :]

        h = embed
        c = cond

        layer_idx = 0
        for block in self.transformer.dual_blocks:
            h, c = block(
                h,
                c=c,
                freqs_cis=d_freqs_cis,
                attn_mask=attn_mask,
                is_causal=True,
                kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
                curr_pos_id=curr_pos_id + s if curr_pos_id is not None else None,
                decode=decode,
            )
            layer_idx += 1

        for block in self.transformer.single_blocks:
            h = block(
                h,
                freqs_cis=s_freqs_cis,
                attn_mask=None,
                is_causal=True,
                kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
                curr_pos_id=curr_pos_id,
                decode=decode,
            )
            layer_idx += 1

        # Normalization
        h = self.transformer.ln_f(h)
        logits = self.lm_head(h)
        return logits
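

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original module). The tiny Config
# values, sequence lengths, and random inputs below are arbitrary assumptions;
# in the real pipeline `cond` comes from a text encoder via `encode_text` and
# `embed` from shape-tokenizer ids via `encode_token`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = DualStreamRoformer.Config(
        n_layer=2,
        n_single_layer=1,
        n_head=4,
        n_embd=256,
        shape_model_vocab_size=128,
        shape_model_embed_dim=8,
        text_model_embed_dim=32,
    )
    model = DualStreamRoformer(cfg)

    batch, cond_len, shape_len = 2, 7, 5
    # Conditioning stream: project dummy text features to n_embd.
    cond = model.encode_text(torch.randn(batch, cond_len, cfg.text_model_embed_dim))
    # Shape stream: embed dummy shape-token ids.
    tokens = torch.randint(0, cfg.shape_model_vocab_size, (batch, shape_len))
    embed = model.encode_token(tokens)

    # Non-cached forward pass over the full shape sequence.
    logits = model(embed, cond)
    # Expected shape: (batch, shape_len, shape_model_vocab_size + 3 special tokens).
    print(logits.shape)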