import logging
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from transformers import PretrainedConfig, PreTrainedModel


@dataclass
class RotaryEmbeddingConfig:
    """
    Parameters to initialize the RotaryEmbedding layer.

    The rescaling factor allows the rotary embeddings to be adapted to sequence
    lengths larger than those used during training. One such strategy is
    presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf.

    Args:
        rescaling_factor: Optional factor used to enlarge the rotary frequency
            base. When ``None`` the unscaled base (10000) is used.
    """

    rescaling_factor: Optional[float]


class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).

    Query and keys are transformed by rotation matrices which depend on their
    relative positions.
    """

    def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfig):
        """
        Args:
            dim: Size of the last (per-head feature) dimension to rotate.
            rotary_embedding_config: Holds the optional base rescaling factor.
        """
        super().__init__()

        # Extract argument from the config
        self.rescaling_factor = rotary_embedding_config.rescaling_factor
        self.upper_freq = 10000
        self.dim = dim

        # Last computed tables; they are recomputed on every forward call.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _apply_rotary_pos_emb(
        self,
        heads: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """
        Rotate ``heads`` using the "rotate-half" convention.

        The last dimension is split into a first half ``x1`` and a second half
        ``x2`` and mapped to ``(x1 * cos - x2 * sin, x2 * cos + x1 * sin)``.

        Args:
            heads: Tensor whose last dimension holds the features to rotate.
            cos: Cosine table, broadcastable against each half of ``heads``.
            sin: Sine table, broadcastable against each half of ``heads``.

        Returns:
            Rotated tensor with the same shape as ``heads``.
        """
        half = heads.shape[-1] // 2
        x_first, x_second = heads[..., :half], heads[..., half:]

        first_part = x_first * cos - x_second * sin
        second_part = x_second * cos + x_first * sin

        return torch.cat((first_part, second_part), dim=-1)

    def _compute_cos_sin_tables(
        self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Build cos/sin tables for every position along ``seq_dimension`` of ``x``.

        Args:
            x: Tensor whose size along ``seq_dimension`` is the sequence length.
            inv_freq: 1-D tensor of inverse frequencies (length ``dim // 2``).
            seq_dimension: Axis of ``x`` holding the sequence positions.

        Returns:
            ``(cos, sin)`` each of shape ``(1, seq_len, 1, len(inv_freq))``.
        """
        seq_len = x.shape[seq_dimension]

        # The tables are recomputed on every call (no cache lookup); the
        # attributes only keep the most recent result.
        self._seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=x.device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i, j -> ij", positions, inv_freq)

        self._cos_cached = torch.cos(freqs)[None, :, None, :]
        self._sin_cached = torch.sin(freqs)[None, :, None, :]

        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to query and key heads.

        The cos/sin tables are sized from axis -3 of ``q`` (the axis just before
        heads and features in ``(..., seq, heads, dim)`` layout) and applied to
        both ``q`` and ``k``; presumably both share that sequence length
        (self-attention) — confirm for other uses.

        Args:
            q: Query heads, last dimension of size ``dim``.
            k: Key heads, last dimension of size ``dim``.

        Returns:
            The rotated ``(q, k)`` pair, in ``q``'s dtype.
        """
        if self.rescaling_factor is None:
            base = self.upper_freq
        else:
            base = self.upper_freq * (
                self.rescaling_factor ** (self.dim / (self.dim - 2))
            )

        # Build the frequencies on the inputs' device: the position tensor is
        # cast with ``inv_freq``'s dtype/device, so a CPU ``inv_freq`` would
        # push the tables to the CPU and break CUDA inputs.
        exponents = torch.arange(0, self.dim, 2, device=q.device).float() / self.dim
        inv_freq = 1.0 / (base**exponents)

        self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
            q,
            inv_freq,
            seq_dimension=-3,
        )

        # Tables are computed in float32; match the heads' dtype so that
        # half/bfloat16 heads are not silently promoted to float32.
        cos = self._cos_cached.to(dtype=q.dtype)
        sin = self._sin_cached.to(dtype=q.dtype)

        return (
            self._apply_rotary_pos_emb(q, cos, sin),
            self._apply_rotary_pos_emb(k, cos, sin),
        )


class ResidualConvBlock(nn.Module):
    """
    Conv Block with Residual connection.
    """

    def __init__(
        self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1
    ):
        super().__init__()
        self.conv_block = ConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=layer_norm_shape,
            kernel_size=kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``conv_block(x)`` plus ``x`` reshaped to the block output shape."""
        y = self.conv_block(x)
        return x.reshape(y.shape) + y


class ConvBlock(nn.Module):
    """
    Conv Block: LayerNorm over channels, then Conv1d ("same" padding), then GELU.

    Input and output use the ``nn.Conv1d`` layout ``(batch, channels, length)``.
    """

    def __init__(
        self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            padding="same",
        )
        self.layer_norm = nn.LayerNorm(layer_norm_shape, eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # LayerNorm normalizes the last axis, so move channels last and back.
        x = x.permute(0, 2, 1)
        x = self.layer_norm(x)
        x = x.permute(0, 2, 1)
        x = self.conv(x)
        x = F.gelu(x, approximate="tanh")
        return x


class ConvTowerBlock(nn.Module):
    """Downsampling stage: ConvBlock, residual 1x1 ConvBlock, then 2x avg-pool."""

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        conv_layer_norm_shape: int,
        resconv_layer_norm_shape: int,
        kernel_size: int,
    ) -> None:
        super().__init__()
        self.conv_layer = ConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=conv_layer_norm_shape,
            kernel_size=kernel_size,
        )
        self.res_conv = ResidualConvBlock(
            dim_in=dim_out,
            dim_out=dim_out,
            layer_norm_shape=resconv_layer_norm_shape,
            kernel_size=1,
        )
        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            ``(pooled, residual)`` where ``residual`` is the unmodified input,
            kept for the matching skip connection in the deconvolution tower.
        """
        residual = x
        x = self.conv_layer(x)
        x = self.res_conv(x)
        x = self.avg_pool(x)
        return x, residual


class ResidualDeConvBlock(nn.Module):
    """
    DeConv Block with Residual connection.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        layer_norm_shape: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=layer_norm_shape,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``deconv_block(x)`` plus ``x`` reshaped to the block output shape."""
        y = self.deconv_block(x)
        return x.reshape(y.shape) + y


class DeConvBlock(nn.Module):
    """
    DeConv Block: LayerNorm over channels, then ConvTranspose1d, then GELU.

    Input and output use the ``(batch, channels, length)`` layout.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        layer_norm_shape: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )
        self.layer_norm = nn.LayerNorm(layer_norm_shape)
        self.kernel_size = kernel_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 1)
        x = self.layer_norm(x)
        x = x.permute(0, 2, 1)
        x = self.deconv(x)
        if self.kernel_size == 5:
            # handle the special case where haiku
            # deconv removes padding automatically
            x = x[:, :, 1:-2]
        x = F.gelu(x, approximate="tanh")
        return x


class DeConvTowerBlock(nn.Module):
    """Upsampling stage: DeConvBlock, residual 1x1 DeConvBlock, plus skip input."""

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        kernel_size: int,
        conv_layer_norm_shape: int,
        resconv_layer_norm_shape: int,
        stride: int = 2,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=conv_layer_norm_shape,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.res_deconv_block = ResidualDeConvBlock(
            dim_in=dim_out,
            dim_out=dim_out,
            layer_norm_shape=resconv_layer_norm_shape,
            kernel_size=1,
        )

    def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and add the skip tensor ``res`` (must match the output shape)."""
        x = self.deconv_block(x)
        x = self.res_deconv_block(x)
        x = x + res
        return x


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional rotary embeddings."""

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        """
        Args:
            num_heads: Number of attention heads.
            key_size: Per-head query/key size.
            rotary_embedding_config: If set, rotary embeddings are applied to
                queries and keys.
            add_bias_kv: Stored only; not used by this implementation.
            value_size: Per-head value size (defaults to ``key_size``).
            model_size: Input/output feature size (defaults to ``key_size``).
            name: Optional label, stored only.
        """
        super().__init__()
        if not model_size:
            model_size = key_size
        if not value_size:
            value_size = key_size
        self.model_size = model_size
        self.key_size = key_size
        self.value_size = value_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads
        self._rotary_embedding_config = rotary_embedding_config

        self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
        self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
        if self._rotary_embedding_config:
            self._rotary_embedding = RotaryEmbedding(
                self.key_size, self._rotary_embedding_config
            )

    def apply_rotary_embeddings(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate per-head queries and keys with the rotary embedding layer."""
        query, key = self._rotary_embedding(query, key)
        return query, key

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            query, key, value: Tensors whose last dimension is ``model_size``;
                the axis before it is treated as the sequence axis.
            attention_mask: Optional boolean tensor broadcastable to the
                ``(..., heads, t, T)`` attention logits; ``False`` entries are
                replaced by -1e30 before the softmax.
            attention_weight_bias: Optional tensor added to the logits before
                the softmax.

        Returns:
            dictionary containing attention weights (``(..., heads, t, T)``) and
            output embeddings (``(..., t, model_size)``).
        """
        key_heads = self.w_k(key).reshape(
            (*key.shape[:-1], self.num_heads, self.key_size)
        )
        query_heads = self.w_q(query).reshape(
            (*query.shape[:-1], self.num_heads, self.key_size)
        )
        value_heads = self.w_v(value).reshape(
            (*value.shape[:-1], self.num_heads, self.value_size)
        )
        if self._rotary_embedding_config:
            query_heads, key_heads = self.apply_rotary_embeddings(
                query_heads, key_heads
            )

        attention_weights = torch.einsum(
            "...thd, ...Thd -> ...htT", query_heads, key_heads
        )
        attention_weights = attention_weights / np.sqrt(self.key_size)

        # Explicit ``is not None`` checks: the truth value of a multi-element
        # tensor is ambiguous and raises.
        if attention_mask is not None:
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)
        if attention_weight_bias is not None:
            attention_weights = attention_weights + attention_weight_bias
        attention_weights = F.softmax(attention_weights, dim=-1)

        value_out = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        # Merge the heads back into a single feature axis.
        value_out = value_out.reshape((*value_out.shape[:-2], -1))
        embeddings = self.output(value_out)

        return {"attention_weights": attention_weights, "embeddings": embeddings}


class SelfAttentionBlock(nn.Module):
    """Transformer layer: self-attention followed by a (optionally GLU) MLP."""

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_fnn: bool = True,
        ffn_activation_name: str = "gelu-no-approx",
        use_glu_in_ffn: bool = False,
        layer_norm_eps: float = 1e-5,  # this is the default haiku value
        pre_layer_norm: bool = True,
        name: Optional[str] = None,
        rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None,
    ):
        """
        Args:
            num_heads: Number of attention heads.
            embed_dim: Model width.
            ffn_embed_dim: Hidden width of the MLP.
            key_size: Per-head key size; defaults to ``embed_dim // num_heads``.
            add_bias_kv: Forwarded to MultiHeadAttention (stored only there).
            add_bias_fnn: Whether the MLP linear layers use a bias.
            ffn_activation_name: ``"swish"``, ``"gelu-no-approx"`` or the name
                of an activation class in ``torch.nn`` (e.g. ``"ReLU"``).
            use_glu_in_ffn: Use a gated linear unit in the MLP.
            layer_norm_eps: Epsilon of both layer norms.
            pre_layer_norm: Pre-norm if True, post-norm otherwise.
            name: Unused label.
            rotary_embedding_config: Forwarded to MultiHeadAttention.

        Raises:
            ValueError: If ``key_size`` is None and ``embed_dim`` is not
                divisible by ``num_heads``.
        """
        super().__init__()
        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"The embedding dimension should be divisible by the number of "
                    f"heads, however provided embedding dimension is {embed_dim} and "
                    f"the number of heads is {num_heads}."
                )
            key_size = embed_dim // num_heads

        self._pre_layer_norm = pre_layer_norm
        self._use_glu_in_fnn = use_glu_in_ffn

        # Define layers
        if use_glu_in_ffn:
            # user should multiply ffn_embed_dim by 2/3 when using GLU
            # to keep total number of parameters equal
            # see https://arxiv.org/pdf/2002.05202.pdf. for more details
            # we multiply by 2 here as the output will be split in 2 for GLU
            self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
        else:
            self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)

        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)

        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)

        # Get ffn activation function
        if ffn_activation_name == "swish":
            self._ffn_activation_fn = nn.SiLU()
        elif ffn_activation_name == "gelu-no-approx":
            self._ffn_activation_fn = lambda x: F.gelu(x, approximate="none")
        else:
            # torch.nn exposes activation *classes* (e.g. nn.ReLU); instantiate
            # so the stored callable maps tensors to tensors.
            self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name)()

        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
            name="self_attention",
            rotary_embedding_config=rotary_embedding_config,
        )

    def mlp(self, embed: torch.Tensor) -> torch.Tensor:
        """
        Feed-forward sub-layer.

        In pre-norm mode the input is normalized first and the residual is
        added by the caller; in post-norm mode the residual is added here,
        inside the final layer norm.
        """
        if self._pre_layer_norm:
            x = self.layer_norm_mlp(embed)
        else:
            x = embed

        if self._use_glu_in_fnn:
            x = self.fc1(x)
            x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
            x = self._ffn_activation_fn(x1) * x2
        else:
            x = self._ffn_activation_fn(self.fc1(x))
        x = self.fc2(x)

        if not self._pre_layer_norm:
            x = self.layer_norm_mlp(x + embed)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Returns:
            The attention output dict, with ``"embeddings"`` replaced by the
            layer output and ``"attention_weights"`` left untouched.
        """
        res = x
        if self._pre_layer_norm:
            x = self.layer_norm_self_attention(x)

        output = self.mha(
            x,
            x,
            x,
            attention_mask=attention_mask,
            attention_weight_bias=attention_weight_bias,
        )

        if not self._pre_layer_norm:
            # Post-norm: the residual is added inside the layer norm.
            output["embeddings"] = self.layer_norm_self_attention(
                output["embeddings"] + res
            )
            x = output["embeddings"]
        else:
            x = res + output["embeddings"]

        # MLP
        if not self._pre_layer_norm:
            x = self.mlp(x)
        else:
            x = x + self.mlp(x)

        output["embeddings"] = x
        return output


class LMHead(nn.Module):
    """GELU-activated stack of linear layers producing per-token logits."""

    def __init__(
        self, dim_in: int, embed_dim: int, dim_out: int, num_hidden_layers: int
    ) -> None:
        """
        Args:
            dim_in: Input feature size.
            embed_dim: Hidden size of the intermediate linear layers.
            dim_out: Number of output logits.
            num_hidden_layers: Number of hidden linear layers (the first maps
                ``dim_in -> embed_dim``, the rest ``embed_dim -> embed_dim``).
        """
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.linear_layers = nn.ModuleList(
            [nn.Linear(dim_in, embed_dim)]
            + [nn.Linear(embed_dim, embed_dim) for _ in range(num_hidden_layers - 1)]
        )
        self.linear_out = nn.Linear(embed_dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A GELU is applied to the input before the first layer as well.
        x = F.gelu(x, approximate="tanh")
        for layer in self.linear_layers:
            x = F.gelu(layer(x), approximate="tanh")
        return self.linear_out(x)


class MOJOConfig(PretrainedConfig):  # noqa: N801
    """Configuration of the MOJO model; derived fields are filled in __post_init__."""

    model_type = "MOJO"

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        self.alphabet_size = kwargs.get(
            "alphabet_size", {"rnaseq": 66, "methylation": 66}
        )
        self.token_embed_dim = kwargs.get("token_embed_dim", 256)
        self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
        self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
        self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
        self.sequence_length = kwargs.get("sequence_length", 17_116)  # n_genes
        # Always recomputed in __post_init__ from sequence_length.
        self.fixed_sequence_length = kwargs.get("fixed_sequence_length", None)
        self.num_downsamples = kwargs.get("num_downsamples", 8)
        self.conv_init_embed_dim = kwargs.get("conv_init_embed_dim", 512)
        self.stem_kernel_shape = kwargs.get("stem_kernel_shape", 15)
        self.embed_dim = kwargs.get("embed_dim", 512)
        # Always recomputed in __post_init__.
        self.filter_list = kwargs.get("filter_list", [])
        self.num_attention_heads = kwargs.get("num_attention_heads", 16)
        self.key_size = kwargs.get("key_size", None)
        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 1_024)
        self.num_layers = kwargs.get("num_layers", 8)
        self.num_hidden_layers_head = kwargs.get("num_hidden_layers_head", 1)

        # return
        self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
            "embeddings_layers_to_save", ()
        )
        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
            "attention_maps_to_save", []
        )

        self.__post_init__()

    def __post_init__(self):
        """
        Validate the configuration and compute derived fields.

        Sets ``key_size`` (if missing), may force ``project_gene_embedding``,
        and computes ``fixed_sequence_length`` and ``filter_list``.

        Raises:
            ValueError: If ``key_size`` is None and ``embed_dim`` is not
                divisible by ``num_attention_heads``.
        """
        # Validate attention key size
        if self.key_size is None:
            embed_dim = self.embed_dim
            num_attention_heads = self.num_attention_heads
            if embed_dim % num_attention_heads != 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension should be "
                    f"divisible by the number of heads, however provided embedding "
                    f"dimension is {embed_dim} and the number of heads is "
                    f"{num_attention_heads}."
                )
            self.key_size = embed_dim // num_attention_heads

        # Validate gene embedding projection: a projection is required when
        # the gene and token embedding sizes differ.
        if self.use_gene_embedding:
            init_gene_embed_dim = self.init_gene_embed_dim
            token_embed_dim = self.token_embed_dim
            if init_gene_embed_dim != token_embed_dim and not self.project_gene_embedding:
                logging.warning(
                    f"Init gene embedding dimension ({init_gene_embed_dim}) "
                    f"different than token embedding dimension ({token_embed_dim}). "
                    f"Setting `project_gene_embedding` to True"
                )
                self.project_gene_embedding = True

        # Pad the sequence length up to a multiple of the total downsampling
        # factor so every pooling stage halves an even length.
        downsample_factor = 2**self.num_downsamples
        self.fixed_sequence_length = (
            math.ceil(self.sequence_length / downsample_factor) * downsample_factor
        )

        # Channel widths of the conv tower: num_downsamples + 1 values linearly
        # spaced from conv_init_embed_dim to embed_dim.
        self.filter_list = (
            np.linspace(
                self.conv_init_embed_dim,
                self.embed_dim,
                self.num_downsamples + 1,
            )
            .astype(int)
            .tolist()
        )  # noqa


class MOJO(PreTrainedModel):  # noqa: N801
    """
    Multi-omic U-Net style model.

    Pipeline: per-omic token embeddings (summed) + optional gene embeddings ->
    stem conv -> downsampling conv tower -> transformer layers -> upsampling
    deconv tower with skip connections -> one LM head per omic.
    """

    config_class = MOJOConfig

    def __init__(self, config: MOJOConfig):
        """
        Raises:
            ValueError: If ``config.attention_maps_to_save`` references a layer
                beyond ``config.num_layers``.
        """
        super().__init__(config=config)

        # Embeddings
        self.embedding_layers = nn.ModuleDict(
            {
                omic: nn.Embedding(config.alphabet_size[omic], config.token_embed_dim)
                for omic in config.alphabet_size
            }
        )

        self.gene_embedding_layer = nn.Embedding(
            config.fixed_sequence_length,
            config.init_gene_embed_dim,
        )
        self.fc_gene_embedding = nn.Linear(
            config.init_gene_embed_dim, config.token_embed_dim
        )

        # Convolutions
        self.stem_conv = nn.Sequential(
            nn.Conv1d(
                in_channels=config.token_embed_dim,
                out_channels=config.conv_init_embed_dim,
                kernel_size=config.stem_kernel_shape,
                padding="same",
            ),
            nn.GELU(approximate="tanh"),
        )

        self.conv_tower = nn.ModuleList(
            [
                ConvTowerBlock(
                    dim_in=config.filter_list[i],
                    dim_out=config.filter_list[i + 1],
                    kernel_size=5,
                    conv_layer_norm_shape=config.filter_list[i],
                    resconv_layer_norm_shape=config.filter_list[i + 1],
                )
                for i in range(len(config.filter_list) - 1)
            ]
        )

        # Transformer: attention maps are requested as (layer, map) pairs with
        # 1-based layer numbers.
        attention_maps_to_save = config.attention_maps_to_save
        self._attention_layers_to_save = list({t[0] for t in attention_maps_to_save})

        self._attention_maps_per_layer_to_save = {
            layer: [t[1] for t in attention_maps_to_save if t[0] == layer]
            for layer in self._attention_layers_to_save
        }

        max_layer = max(self._attention_layers_to_save + [0])
        if max_layer > config.num_layers:
            raise ValueError(
                f"You are requiring attention maps for layer {max_layer}, "
                f"while the model has {config.num_layers} layers only."
            )
        self._rotary_embedding_config = RotaryEmbeddingConfig(rescaling_factor=None)
        self.transformer_layers = nn.ModuleList(
            [
                SelfAttentionBlock(
                    num_heads=config.num_attention_heads,
                    embed_dim=config.embed_dim,
                    ffn_embed_dim=config.ffn_embed_dim,
                    key_size=config.key_size,
                    add_bias_kv=False,
                    add_bias_fnn=False,
                    ffn_activation_name="swish",
                    use_glu_in_ffn=True,
                    layer_norm_eps=1e-5,  # this is the default haiku value
                    pre_layer_norm=True,
                    name=f"attention_layer_{layer_idx}",
                    rotary_embedding_config=self._rotary_embedding_config,
                )
                for layer_idx in range(config.num_layers)
            ]
        )

        # Deconvolutions: mirror of the conv tower (filter_list walked backwards).
        self.deconv_tower = nn.ModuleList(
            [
                DeConvTowerBlock(
                    dim_in=config.filter_list[-1 - i],
                    dim_out=config.filter_list[-1 - i - 1],
                    kernel_size=5,
                    stride=2,
                    conv_layer_norm_shape=config.filter_list[-1 - i],
                    resconv_layer_norm_shape=config.filter_list[-1 - i - 1],
                )
                for i in range(len(config.filter_list) - 1)
            ]
        )

        # Language Modeling heads
        self.omic_lm_heads = nn.ModuleDict(
            {
                omic: LMHead(
                    dim_in=config.conv_init_embed_dim,
                    embed_dim=config.embed_dim,
                    dim_out=config.alphabet_size[omic],
                    num_hidden_layers=config.num_hidden_layers_head,
                )
                for omic in self.config.alphabet_size
            }
        )

    def get_embeddings(
        self,
        input_ids: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Embed each omic's token ids with its own embedding table."""
        return {
            omic: self.embedding_layers[omic](omic_tokens)
            for omic, omic_tokens in input_ids.items()
        }

    def forward(self, input_ids: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Args:
            input_ids: Mapping omic name -> token ids. All tensors must share a
                shape, since their embeddings are stacked and summed. With gene
                embeddings enabled the sequence axis presumably must equal
                ``config.fixed_sequence_length`` — confirm with callers.

        Returns:
            Dict of intermediate activations plus ``"logits"``, itself a dict
            omic name -> logits.
        """
        outs = {}
        embeddings = self.get_embeddings(input_ids)
        outs["omic_embeddings"] = embeddings
        x = torch.stack(list(embeddings.values()), dim=0).sum(dim=0)  # [B, T, C]
        outs["embeddings"] = x

        if self.config.use_gene_embedding:
            gene_indices = torch.arange(
                self.config.fixed_sequence_length, device=x.device
            )
            gene_embedding = self.gene_embedding_layer(gene_indices)
            if self.config.project_gene_embedding:
                gene_embedding = self.fc_gene_embedding(gene_embedding)
            # Broadcast the per-position gene embedding over the batch.
            x = x + gene_embedding
            outs["embeddings_with_gene_embedding"] = x

        # Convolutions operate in (batch, channels, length) layout.
        x = x.permute(0, 2, 1)
        x = self.stem_conv(x)
        outs["stem"] = x

        residuals = []
        for conv_block in self.conv_tower:
            x, res = conv_block(x)
            residuals.append(res)
        x = x.permute(0, 2, 1)
        outs["conv_tower"] = x
        outs["conv_tower_residuals"] = residuals  # type: ignore
        # Deepest skip connection first for the deconvolution tower.
        residuals = residuals[::-1]

        for layer_idx, transformer in enumerate(self.transformer_layers):
            output = transformer(x)
            x = output["embeddings"]
            layer_number = layer_idx + 1
            if layer_number in self.config.embeddings_layers_to_save:
                outs[f"embeddings_{layer_number}"] = output["embeddings"]
            if layer_number in self._attention_layers_to_save:
                for map_number in self._attention_maps_per_layer_to_save[layer_number]:
                    dkey = f"attention_map_layer_{layer_number}_number_{map_number}"
                    # NOTE(review): the head axis is indexed with
                    # ``map_number + 1``; confirm whether map numbers are meant
                    # to be 0- or 1-based (map_number == num_heads - 1 would
                    # index out of range as written).
                    outs[dkey] = output["attention_weights"][:, map_number + 1]

        outs["after_transformer_embedding"] = x

        x = x.permute(0, 2, 1)
        for deconv_block, res in zip(self.deconv_tower, residuals):
            x = deconv_block(x, res)

        x = x.permute(0, 2, 1)
        outs["deconv_tower"] = x

        outs["logits"] = {
            omic: self.omic_lm_heads[omic](x) for omic in self.config.alphabet_size
        }

        return outs