# BulkRNABert / bulkrnabert.py
# Uploaded by mgelard in commit b3c76bb ("Upload BulkRNABert"), verified.
import logging
from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from transformers import PretrainedConfig, PreTrainedModel
class MultiHeadAttention(nn.Module):
    """Multi-head attention with separate query, key and value projections.

    Inputs are projected to ``num_heads`` heads of size ``key_size`` (queries
    and keys) or ``value_size`` (values). The per-head results are
    concatenated and projected back to ``model_size``.

    Args:
        num_heads: Number of attention heads.
        key_size: Per-head dimension of the query and key projections.
        add_bias_kv: Stored on the module; not used by this implementation.
        value_size: Per-head dimension of the value projection. Falls back to
            ``key_size`` when not provided (or falsy).
        model_size: Size of the last axis of the inputs and of the output
            embeddings. Falls back to ``key_size`` when not provided (or
            falsy).
        name: Optional label stored on the module; not used in computations.
    """

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        super().__init__()
        if not model_size:
            model_size = key_size
        if not value_size:
            value_size = key_size
        self.model_size = model_size
        self.key_size = key_size
        self.value_size = value_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads

        self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
        self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """Compute scaled dot-product attention over all heads.

        Args:
            query: Tensor of shape ``(..., t, model_size)``.
            key: Tensor of shape ``(..., T, model_size)``.
            value: Tensor of shape ``(..., T, model_size)``.
            attention_mask: Optional boolean tensor broadcastable to
                ``(..., num_heads, t, T)``. Positions where it is ``False``
                have their score replaced by ``-1e30`` before the softmax.
            attention_weight_bias: Optional tensor broadcastable to
                ``(..., num_heads, t, T)``, added to the scores right before
                the softmax.

        Returns:
            Dictionary with ``"attention_weights"`` (softmax-normalised
            scores of shape ``(..., num_heads, t, T)``) and ``"embeddings"``
            (attention output of shape ``(..., t, model_size)``).
        """
        # Split the projected features into (num_heads, head_dim) on the last axes.
        key_heads = self.w_k(key).reshape(
            (*key.shape[:-1], self.num_heads, self.key_size)
        )
        query_heads = self.w_q(query).reshape(
            (*query.shape[:-1], self.num_heads, self.key_size)
        )
        value_heads = self.w_v(value).reshape(
            (*value.shape[:-1], self.num_heads, self.value_size)
        )

        # Dot product of every query position t with every key position T, per head.
        attention_weights = torch.einsum(
            "...thd, ...Thd -> ...htT", query_heads, key_heads
        )
        sqrt_key_size = np.sqrt(self.key_size)
        attention_weights = attention_weights / sqrt_key_size

        if attention_mask is not None:
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)

        # Compare against None explicitly: the truth value of a tensor with
        # more than one element is ambiguous and raises a RuntimeError.
        if attention_weight_bias is not None:
            attention_weights = F.softmax(
                attention_weights + attention_weight_bias, dim=-1
            )
        else:
            attention_weights = F.softmax(attention_weights, dim=-1)

        # Weighted sum of the values, then merge the head axes back together.
        value_out = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        value_out = value_out.reshape((*value_out.shape[:-2], -1))
        embeddings = self.output(value_out)

        return {"attention_weights": attention_weights, "embeddings": embeddings}
class SelfAttentionBlock(nn.Module):
    """Transformer block: multi-head self-attention followed by a feed-forward MLP.

    Both sub-layers are wrapped with a residual connection and a layer norm,
    applied either before the sub-layer (``pre_layer_norm=True``) or after the
    residual sum (``pre_layer_norm=False``).

    Args:
        num_heads: Number of attention heads.
        embed_dim: Size of the token embeddings (last axis of the input).
        ffn_embed_dim: Hidden size of the feed-forward network.
        key_size: Per-head key size. When ``None`` it is set to
            ``embed_dim // num_heads``.
        add_bias_kv: Forwarded to ``MultiHeadAttention``.
        add_bias_fnn: Whether the two feed-forward linear layers have a bias.
        ffn_activation_name: ``"swish"``, ``"gelu-no-approx"``, the name of a
            function in ``torch.nn.functional`` (e.g. ``"relu"``) or the name
            of a module class in ``torch.nn`` that can be built without
            arguments (e.g. ``"ReLU"``).
        use_glu_in_ffn: When ``True`` the first linear layer outputs
            ``2 * ffn_embed_dim`` features that are split in two halves, one
            of which gates the other.
        layer_norm_eps: Epsilon used by both layer norms.
        pre_layer_norm: Selects pre- or post-layer-norm (see above).
        name: Accepted for interface compatibility; not used.

    Raises:
        ValueError: If ``key_size`` is ``None`` and ``embed_dim`` is not
            divisible by ``num_heads``.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_fnn: bool = True,
        ffn_activation_name: str = "gelu-no-approx",
        use_glu_in_ffn: bool = False,
        layer_norm_eps: float = 1e-5,  # this is the default haiku value
        pre_layer_norm: bool = True,
        name: Optional[str] = None,
    ):
        super().__init__()
        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"The embedding dimension should be divisible by the number of "
                    f"heads, however provided embedding dimension is {embed_dim} and "
                    f"the number of heads is {num_heads}."
                )
            key_size = embed_dim // num_heads

        self._pre_layer_norm = pre_layer_norm
        self._use_glu_in_fnn = use_glu_in_ffn

        # Define layers
        if use_glu_in_ffn:
            # user should multiply ffn_embed_dim by 2/3 when using GLU
            # to keep total number of parameters equal
            # see https://arxiv.org/pdf/2002.05202.pdf. for more details
            # we multiply by 2 here as the output will be split in 2 for GLU
            fc1_out_dim = int(2 * ffn_embed_dim)
        else:
            fc1_out_dim = ffn_embed_dim
        self.fc1 = nn.Linear(embed_dim, fc1_out_dim, bias=add_bias_fnn)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)

        # Honour the configured epsilon (it used to be silently ignored). The
        # default value matches nn.LayerNorm's own default.
        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)

        self._ffn_activation_fn = self._resolve_ffn_activation(ffn_activation_name)

        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
            name="self_attention",
        )

    @staticmethod
    def _resolve_ffn_activation(ffn_activation_name: str):
        """Return a callable applying the activation called ``ffn_activation_name``.

        Raises:
            AttributeError: If the name matches neither a special case, a
                ``torch.nn.functional`` callable nor a ``torch.nn`` attribute.
        """
        if ffn_activation_name == "swish":
            return nn.SiLU()
        if ffn_activation_name == "gelu-no-approx":
            # Module form of F.gelu(x, approximate="none"); unlike a lambda it
            # can be pickled together with the model.
            return nn.GELU(approximate="none")
        functional = getattr(F, ffn_activation_name, None)
        if callable(functional):
            return functional
        # torch.nn exposes activation *classes*: instantiate the class so the
        # stored object maps a tensor to a tensor when called.
        return getattr(torch.nn, ffn_activation_name)()

    def mlp(self, embed: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward sub-layer.

        With pre-layer-norm the input is normalised first and the residual is
        NOT added here (the caller adds it). With post-layer-norm the residual
        is added and normalised inside this method.

        Args:
            embed: Tensor whose last axis has size ``embed_dim``.

        Returns:
            Tensor with the same shape as ``embed``.
        """
        if self._pre_layer_norm:
            x = self.layer_norm_mlp(embed)
        else:
            x = embed

        if self._use_glu_in_fnn:
            x = self.fc1(x)
            # First half goes through the activation and is gated by the second half.
            x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
            x = self._ffn_activation_fn(x1) * x2
        else:
            x = self._ffn_activation_fn(self.fc1(x))

        x = self.fc2(x)

        if not self._pre_layer_norm:
            x = self.layer_norm_mlp(x + embed)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """Run self-attention followed by the MLP.

        Args:
            x: Input embeddings; the last axis has size ``embed_dim``.
            attention_mask: Optional mask forwarded to the attention layer.
            attention_weight_bias: Optional bias forwarded to the attention
                layer.

        Returns:
            The dictionary produced by the attention layer, with
            ``"embeddings"`` replaced by the output of the whole block and
            ``"attention_weights"`` left as returned by the attention layer.
        """
        res = x
        if self._pre_layer_norm:
            x = self.layer_norm_self_attention(x)

        output = self.mha(
            x,
            x,
            x,
            attention_mask=attention_mask,
            attention_weight_bias=attention_weight_bias,
        )

        if self._pre_layer_norm:
            # Residual around attention, then residual around the MLP.
            x = res + output["embeddings"]
            x = x + self.mlp(x)
        else:
            # Post-norm: normalise the residual sum; mlp() handles its own
            # residual and normalisation in this mode.
            x = self.layer_norm_self_attention(output["embeddings"] + res)
            x = self.mlp(x)

        output["embeddings"] = x
        return output
class BulkRNABertConfig(PretrainedConfig):
    """Configuration of the BulkRNABert model.

    All options are read from keyword arguments; missing ones take the
    defaults listed below.

    Keyword Args:
        n_genes: Size of the gene embedding table (default ``19_062``).
        n_expressions_bins: Number of expression bins, i.e. the size of the
            expression embedding table and of the output logits
            (default ``64``).
        embed_dim: Embedding dimension of the transformer (default ``256``).
        init_gene_embed_dim: Dimension of the gene embedding table
            (default ``200``).
        use_gene_embedding: Whether gene embeddings are added to the
            expression embeddings (default ``True``).
        project_gene_embedding: Whether gene embeddings go through a linear
            projection to ``embed_dim`` (default ``True``). Forced to ``True``
            when gene embeddings are used and ``init_gene_embed_dim`` differs
            from ``embed_dim``.
        num_attention_heads: Number of attention heads (default ``8``).
        key_size: Per-head key size. When ``None`` (default) it is set to
            ``embed_dim // num_attention_heads``.
        ffn_embed_dim: Hidden size of the feed-forward layers
            (default ``512``).
        num_layers: Number of transformer layers (default ``4``).
        embeddings_layers_to_save: Layer numbers whose output embeddings are
            returned by the model (default ``()``).
        attention_maps_to_save: ``(layer, map_number)`` pairs selecting the
            attention maps returned by the model (default ``[]``).

    Raises:
        ValueError: If ``key_size`` is ``None`` and ``embed_dim`` is not
            divisible by ``num_attention_heads``.
    """

    model_type = "BulkRNABert"

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.n_genes = kwargs.get("n_genes", 19_062)
        self.n_expressions_bins = kwargs.get("n_expressions_bins", 64)
        self.embed_dim = kwargs.get("embed_dim", 256)
        self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
        self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
        self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
        self.num_attention_heads = kwargs.get("num_attention_heads", 8)
        self.key_size = kwargs.get("key_size", None)
        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 512)
        self.num_layers = kwargs.get("num_layers", 4)

        # Optional extra outputs returned by the model.
        self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
            "embeddings_layers_to_save", ()
        )
        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
            "attention_maps_to_save", []
        )

        self.__post_init__()

    def __post_init__(self):
        """Validate the configuration and fill in derived values."""
        # Validate attention key size
        if self.key_size is None:
            embed_dim = self.embed_dim
            num_attention_heads = self.num_attention_heads
            if embed_dim % num_attention_heads != 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension should be "
                    f"divisible by the number of heads, however provided embedding "
                    f"dimension is {embed_dim} and the number of heads is "
                    f"{num_attention_heads}."
                )
            self.key_size = embed_dim // num_attention_heads

        # Validate gene embedding projection: gene embeddings of a different
        # width than the model embeddings can only be used after projection.
        if self.use_gene_embedding and self.init_gene_embed_dim != self.embed_dim:
            if not self.project_gene_embedding:
                # The message pieces were previously concatenated without
                # separating spaces, producing "(200)different ... (256).Setting".
                logging.warning(
                    "Init gene embedding dimension (%s) different than "
                    "embedding dimension (%s). "
                    "Setting `project_gene_embedding` to True",
                    self.init_gene_embed_dim,
                    self.embed_dim,
                )
                self.project_gene_embedding = True
class BulkRNABert(PreTrainedModel):
    """Transformer encoder over binned gene-expression tokens.

    Expression-bin tokens are embedded, optionally summed with (projected)
    gene embeddings, passed through a stack of ``SelfAttentionBlock`` layers
    and finally mapped to logits over the expression bins.
    """

    config_class = BulkRNABertConfig

    def __init__(self, config: BulkRNABertConfig):
        """Build all layers from ``config``.

        Raises:
            ValueError: If an attention map is requested for a layer number
                larger than ``config.num_layers``.
        """
        super().__init__(config=config)

        self.expression_embedding_layer = nn.Embedding(
            config.n_expressions_bins, config.embed_dim
        )
        self.gene_embedding_layer = nn.Embedding(
            config.n_genes, config.init_gene_embed_dim
        )
        self.fc_gene_embedding = nn.Linear(config.init_gene_embed_dim, config.embed_dim)

        # Group the requested (layer, map_number) pairs by layer.
        requested_maps = config.attention_maps_to_save
        self._attention_layers_to_save = list({entry[0] for entry in requested_maps})
        maps_by_layer = {layer: [] for layer in self._attention_layers_to_save}
        for entry in requested_maps:
            maps_by_layer[entry[0]].append(entry[1])
        self._attention_maps_per_layer_to_save = maps_by_layer

        # Layer numbers are compared against 1-based positions in forward().
        highest_layer = max([0, *self._attention_layers_to_save])
        if highest_layer > config.num_layers:
            raise ValueError(
                f"You are requiring attention maps for layer {highest_layer}, "
                f"while the model has {config.num_layers} layers only."
            )

        blocks = []
        for layer_idx in range(config.num_layers):
            blocks.append(
                SelfAttentionBlock(
                    num_heads=config.num_attention_heads,
                    embed_dim=config.embed_dim,
                    key_size=config.key_size,
                    ffn_embed_dim=config.ffn_embed_dim,
                    name=f"attention_layer_{layer_idx}",
                )
            )
        self.transformer_layers = nn.ModuleList(blocks)

        self.lm_head = nn.Linear(config.embed_dim, config.n_expressions_bins)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> dict[str, torch.Tensor]:
        """Run the model.

        Args:
            input_ids: Integer tensor of expression-bin indices. It is
                unpacked as ``(batch_size, seq_length)`` when no mask is
                given. When gene embeddings are enabled the sequence axis
                presumably has to match ``config.n_genes`` for the sum to
                broadcast -- TODO confirm.
            attention_mask: Optional boolean mask passed to every layer.
                Defaults to an all-``True`` mask of shape
                ``(batch_size, 1, seq_length, seq_length)``.

        Returns:
            Dictionary containing ``"logits"`` plus, when requested in the
            config, ``"embeddings_{layer}"`` and
            ``"attention_map_layer_{layer}_number_{map_number}"`` entries.
        """
        config = self.config

        x = self.expression_embedding_layer(input_ids)
        if config.use_gene_embedding:
            gene_ids = torch.arange(config.n_genes, device=x.device)
            gene_embedding = self.gene_embedding_layer(gene_ids)
            if config.project_gene_embedding:
                gene_embedding = self.fc_gene_embedding(gene_embedding)
            x = x + gene_embedding

        if attention_mask is None:
            batch_size, seq_length = input_ids.shape
            mask_shape = (batch_size, 1, seq_length, seq_length)
            attention_mask = torch.ones(  # noqa
                mask_shape, device=input_ids.device, dtype=bool
            )

        outs = {}
        # Layers are numbered from 1 in the config options and output keys.
        for layer_number, transformer in enumerate(self.transformer_layers, start=1):
            output = transformer(x, attention_mask=attention_mask)
            x = output["embeddings"]

            if layer_number in config.embeddings_layers_to_save:
                outs[f"embeddings_{layer_number}"] = x

            if layer_number not in self._attention_layers_to_save:
                continue
            attention_weights = output["attention_weights"]
            for map_number in self._attention_maps_per_layer_to_save[layer_number]:
                dkey = f"attention_map_layer_{layer_number}_number_{map_number}"
                # NOTE(review): axis 1 is indexed with map_number + 1, so map
                # number k returns slice k + 1 -- confirm this offset is intended.
                outs[dkey] = attention_weights[:, map_number + 1]

        outs["logits"] = self.lm_head(x)
        return outs