from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from cube3d.model.transformers.cache import Cache
from cube3d.model.transformers.dual_stream_attention import (
    DualStreamDecoderLayerWithRotaryEmbedding,
)
from cube3d.model.transformers.norm import LayerNorm
from cube3d.model.transformers.roformer import DecoderLayerWithRotaryEmbedding
from cube3d.model.transformers.rope import precompute_freqs_cis


class DualStreamRoformer(nn.Module):
    @dataclass
    class Config:
        checkpoint_path: str = ""
        n_layer: int = 12
        n_single_layer: int = 0
        rope_theta: float = 1000
        n_head: int = 16
        n_embd: int = 2048
        bias: bool = False  # bias in Linears and LayerNorms
        eps: float = 1e-6  # Norm eps
        shape_model_vocab_size: int = 4096
        shape_model_embed_dim: int = 16
        text_model_embed_dim: int = 512
        use_pooled_text_embed: bool = False
        encoder_with_cls_token: bool = True

    def __init__(self, cfg: Config) -> None:
        """
        Initializes the DualStreamRoFormer model.

        Args:
            cfg (Config): Configuration object containing model parameters.

        Attributes:
            cfg (Config): Stores the configuration object.
            text_proj (nn.Linear): Linear layer to project text model embeddings to the
                desired embedding dimension.
            shape_proj (nn.Linear): Linear layer to project shape model embeddings to the
                desired embedding dimension.
            vocab_size (int): Vocabulary size for the shape model, including special tokens.
            shape_bos_id (int): Token ID for the beginning-of-sequence (BOS) token for the shape model.
            shape_eos_id (int): Token ID for the end-of-sequence (EOS) token for the shape model.
            padding_id (int): Token ID for the padding token.
            transformer (nn.ModuleDict): Dictionary containing the following components:
                - wte (nn.Embedding): Embedding layer for the vocabulary.
                - dual_blocks (nn.ModuleList): List of dual-stream decoder layers with rotary embeddings.
                - single_blocks (nn.ModuleList): List of single-stream decoder layers with rotary embeddings.
                - ln_f (LayerNorm): Layer normalization applied to the final output.
            lm_head (nn.Linear): Linear layer mapping the final embeddings to the vocabulary size
                for language modeling.
        """
        super().__init__()
        self.cfg = cfg

        self.text_proj = nn.Linear(
            in_features=self.cfg.text_model_embed_dim,
            out_features=self.cfg.n_embd,
            bias=self.cfg.bias,
        )
        self.shape_proj = nn.Linear(self.cfg.shape_model_embed_dim, self.cfg.n_embd)

        self.vocab_size = self.cfg.shape_model_vocab_size

        def add_special_token():
            # Appends a new token id at the end of the current vocabulary.
            token_id = self.vocab_size
            self.vocab_size += 1
            return token_id

        self.shape_bos_id = add_special_token()
        self.shape_eos_id = add_special_token()
        self.padding_id = add_special_token()

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(
                    self.vocab_size,
                    self.cfg.n_embd,
                    padding_idx=self.padding_id,
                ),
                dual_blocks=nn.ModuleList(
                    [
                        DualStreamDecoderLayerWithRotaryEmbedding.from_config(
                            self.cfg, cond_pre_only=(i == self.cfg.n_layer - 1)
                        )
                        for i in range(self.cfg.n_layer)
                    ]
                ),
                single_blocks=nn.ModuleList(
                    [
                        DecoderLayerWithRotaryEmbedding.from_config(self.cfg)
                        for _ in range(self.cfg.n_single_layer)
                    ]
                ),
                ln_f=LayerNorm(
                    self.cfg.n_embd, elementwise_affine=False, eps=self.cfg.eps
                ),
            )
        )
        self.lm_head = nn.Linear(self.cfg.n_embd, self.vocab_size, bias=False)

    def encode_text(self, text_embed):
        """
        Encodes the given text embeddings by projecting them through a linear transformation.

        Args:
            text_embed (torch.Tensor): A tensor representing the text embeddings to be encoded.

        Returns:
            torch.Tensor: The projected text embeddings after applying the linear transformation.
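
        Example (illustrative sketch; shapes assume the default ``Config`` values
        ``text_model_embed_dim=512`` and ``n_embd=2048``, and ``model`` denotes a
        constructed ``DualStreamRoformer``):

            text_embed = torch.randn(2, 77, 512)  # (batch, seq, text_model_embed_dim)
            cond = model.encode_text(text_embed)  # -> (2, 77, 2048) == (batch, seq, n_embd)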
""" return self.text_proj(text_embed) def encode_token(self, tokens): """ Encodes the input tokens using the word token embedding layer of the transformer model. Args: tokens (torch.Tensor): A tensor containing the input tokens to be encoded. Returns: torch.Tensor: A tensor containing the encoded token embeddings. """ return self.transformer.wte(tokens) def init_kv_cache( self, batch_size: int, cond_len: int, max_shape_tokens: int, dtype: torch.dtype, device: torch.device, ) -> list[Cache]: """ Initializes the key-value cache for the transformer model. This method creates a list of `Cache` objects to store the key and value states for both dual-stream and single-stream transformer blocks. The cache is pre-allocated with zeros and is used to optimize the computation of attention mechanisms during model inference. Args: batch_size (int): The batch size for the input data. cond_len (int): The length of the conditioning sequence. max_shape_tokens (int): The maximum number of tokens in the shape sequence. dtype (torch.dtype): The data type for the tensors (e.g., torch.float32). device (torch.device): The device on which the tensors will be allocated (e.g., torch.device('cuda') or torch.device('cpu')). Returns: list[Cache]: A list of `Cache` objects containing pre-allocated key and value states for each transformer block. """ num_heads = self.cfg.n_head max_all_tokens = cond_len + max_shape_tokens per_head_dim = self.cfg.n_embd // num_heads kv_cache = [ Cache( key_states=torch.zeros( (batch_size, num_heads, max_all_tokens, per_head_dim), dtype=dtype, device=device, ), value_states=torch.zeros( (batch_size, num_heads, max_all_tokens, per_head_dim), dtype=dtype, device=device, ), ) for _ in range(len(self.transformer.dual_blocks)) ] kv_cache += [ Cache( key_states=torch.zeros( (batch_size, num_heads, max_shape_tokens, per_head_dim), dtype=dtype, device=device, ), value_states=torch.zeros( (batch_size, num_heads, max_shape_tokens, per_head_dim), dtype=dtype, device=device, ), ) for _ in range(len(self.transformer.single_blocks)) ] return kv_cache def forward( self, embed: torch.Tensor, cond: torch.Tensor, kv_cache: Optional[list[Cache]] = None, curr_pos_id: Optional[torch.Tensor] = None, decode: bool = False, ): """ Forward pass for the dual-stream RoFormer model. Args: embed (torch.Tensor): The input embedding tensor. cond (torch.Tensor): The conditioning tensor. kv_cache (Optional[list[Cache]]): A list of key-value caches for each layer, used for decoding. Default is None. curr_pos_id (Optional[torch.Tensor]): The current position ID tensor of shape (batch_size,). Required if `decode` is True. Default is None. decode (bool): Whether the model is in decoding mode. Default is False. Returns: torch.Tensor: The output logits tensor. 
""" b, l = embed.shape[:2] s = cond.shape[1] device = embed.device attn_mask = torch.tril( torch.ones(s + l, s + l, dtype=torch.bool, device=device) ) position_ids = torch.arange(l, dtype=torch.long, device=device) # shape (t) position_ids = position_ids.unsqueeze_(0).expand(b, -1) s_freqs_cis = precompute_freqs_cis( dim=self.cfg.n_embd // self.cfg.n_head, t=position_ids, theta=self.cfg.rope_theta, ) position_ids = torch.cat( [ torch.zeros([b, s], dtype=torch.long, device=position_ids.device), position_ids, ], dim=1, ) d_freqs_cis = precompute_freqs_cis( dim=self.cfg.n_embd // self.cfg.n_head, t=position_ids, theta=self.cfg.rope_theta, ) if kv_cache is not None and decode: assert curr_pos_id is not None embed = embed[:, curr_pos_id, :] h = embed c = cond layer_idx = 0 for block in self.transformer.dual_blocks: h, c = block( h, c=c, freqs_cis=d_freqs_cis, attn_mask=attn_mask, is_causal=True, kv_cache=kv_cache[layer_idx] if kv_cache is not None else None, curr_pos_id=curr_pos_id + s if curr_pos_id is not None else None, decode=decode, ) layer_idx += 1 for block in self.transformer.single_blocks: h = block( h, freqs_cis=s_freqs_cis, attn_mask=None, is_causal=True, kv_cache=kv_cache[layer_idx] if kv_cache is not None else None, curr_pos_id=curr_pos_id, decode=decode, ) layer_idx += 1 # Normalization h = self.transformer.ln_f(h) logits = self.lm_head(h) return logits