|
import logging |
|
from typing import Any, Optional |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected by separate linear layers into
    ``num_heads`` heads. Per head, the attention logits are
    ``q . k / sqrt(key_size)``. An optional boolean mask and an optional
    additive bias are applied before the softmax over key positions. The
    weighted values of all heads are concatenated and projected back to
    ``model_size``.
    """

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        """Build the query/key/value/output projections.

        Args:
            num_heads: Number of attention heads.
            key_size: Per-head dimension of queries and keys.
            add_bias_kv: Stored as an attribute; not used by this module's
                computation.
            value_size: Per-head dimension of values. Falls back to
                ``key_size`` when ``None`` (or any falsy value).
            model_size: Input and output feature dimension. Falls back to
                ``key_size`` when ``None`` (or any falsy value).
            name: Optional label, stored as an attribute only.
        """
        super().__init__()
        if not model_size:
            model_size = key_size
        if not value_size:
            value_size = key_size
        self.model_size = model_size
        self.key_size = key_size
        self.value_size = value_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads

        self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
        self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor whose last axis has size ``model_size``; its
                second-to-last axis is the query-position axis.
            key: Like ``query``; its second-to-last axis is the key-position
                axis.
            value: Like ``key``, with the same number of positions.
            attention_mask: Optional boolean tensor broadcastable to the
                logits of shape ``(..., num_heads, t, T)``. Where it is
                ``False`` the logit is replaced by ``-1e30`` before softmax.
            attention_weight_bias: Optional tensor broadcastable to the
                logits. It is added to them (after masking) before softmax.

        Returns:
            Dictionary with ``"attention_weights"`` (post-softmax, shape
            ``(..., num_heads, t, T)``) and ``"embeddings"`` (shape
            ``(..., t, model_size)``).
        """
        key_heads = self.w_k(key).reshape(
            (*key.shape[:-1], self.num_heads, self.key_size)
        )
        query_heads = self.w_q(query).reshape(
            (*query.shape[:-1], self.num_heads, self.key_size)
        )
        value_heads = self.w_v(value).reshape(
            (*value.shape[:-1], self.num_heads, self.value_size)
        )

        attention_weights = torch.einsum(
            "...thd, ...Thd -> ...htT", query_heads, key_heads
        )
        attention_weights = attention_weights / np.sqrt(self.key_size)

        if attention_mask is not None:
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)
        # Compare against None explicitly: the truthiness of a multi-element
        # tensor is ambiguous and raises a RuntimeError.
        if attention_weight_bias is not None:
            attention_weights = attention_weights + attention_weight_bias
        attention_weights = F.softmax(attention_weights, dim=-1)

        value_out = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        # Merge the head axis and the per-head value axis into one feature axis.
        value_out = value_out.reshape((*value_out.shape[:-2], -1))
        embeddings = self.output(value_out)

        return {"attention_weights": attention_weights, "embeddings": embeddings}
|
|
|
|
|
class SelfAttentionBlock(nn.Module):
    """Transformer encoder block: multi-head self-attention followed by an MLP.

    With ``pre_layer_norm=True`` each sub-block computes ``x + f(LN(x))``.
    Otherwise it computes ``LN(x + f(x))``. The MLP can optionally be gated
    (GLU).
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_fnn: bool = True,
        ffn_activation_name: str = "gelu-no-approx",
        use_glu_in_ffn: bool = False,
        layer_norm_eps: float = 1e-5,
        pre_layer_norm: bool = True,
        name: Optional[str] = None,
    ):
        """Build the attention, layer-norm and feed-forward layers.

        Args:
            num_heads: Number of attention heads.
            embed_dim: Feature dimension of the block's input and output.
            ffn_embed_dim: Hidden width of the feed-forward sub-block.
            key_size: Per-head query/key size. Defaults to
                ``embed_dim // num_heads``.
            add_bias_kv: Forwarded to :class:`MultiHeadAttention`.
            add_bias_fnn: Whether ``fc1``/``fc2`` carry bias terms.
            ffn_activation_name: ``"swish"`` (SiLU), ``"gelu-no-approx"``
                (exact GELU), the name of a ``torch.nn.functional`` function
                (e.g. ``"relu"``), or a ``torch.nn`` module class (e.g.
                ``"ReLU"``).
            use_glu_in_ffn: Use a gated MLP. ``fc1`` then outputs
                ``2 * ffn_embed_dim`` features.
            layer_norm_eps: ``eps`` of both LayerNorms.
            pre_layer_norm: Pre-LN (True) or post-LN (False) residual layout.
            name: Optional label; not stored.

        Raises:
            ValueError: If ``key_size`` is None and ``embed_dim`` is not
                divisible by ``num_heads``.
        """
        super().__init__()
        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"The embedding dimension should be divisible by the number of "
                    f"heads, however provided embedding dimension is {embed_dim} and "
                    f"the number of heads is {num_heads}."
                )
            key_size = embed_dim // num_heads

        self._pre_layer_norm = pre_layer_norm
        self._use_glu_in_fnn = use_glu_in_ffn

        # With GLU, fc1 emits both the activated half and the gating half.
        fc1_out_dim = int(2 * ffn_embed_dim) if use_glu_in_ffn else ffn_embed_dim
        self.fc1 = nn.Linear(embed_dim, fc1_out_dim, bias=add_bias_fnn)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)

        # `layer_norm_eps` was previously accepted but never applied.
        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)

        self._ffn_activation_fn = self._resolve_activation(ffn_activation_name)

        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
            name="self_attention",
        )

    @staticmethod
    def _resolve_activation(activation_name: str):
        """Return an element-wise activation callable for ``activation_name``.

        Raises:
            AttributeError: If the name is in neither ``torch.nn.functional``
                nor ``torch.nn``.
            ValueError: If the ``torch.nn`` attribute is not a module class.
        """
        if activation_name == "swish":
            return nn.SiLU()
        if activation_name == "gelu-no-approx":
            # Parameter-free module (no state_dict entries); unlike a lambda
            # it keeps the block picklable.
            return nn.GELU(approximate="none")
        # Functional names, e.g. "relu", "gelu", "silu".
        functional = getattr(F, activation_name, None)
        if callable(functional):
            return functional
        # Module class names, e.g. "ReLU". They must be instantiated: calling
        # the class on a tensor would construct a module, not apply it.
        module_cls = getattr(nn, activation_name)
        if isinstance(module_cls, type) and issubclass(module_cls, nn.Module):
            return module_cls()
        raise ValueError(f"Unsupported activation function: {activation_name!r}")

    def mlp(self, embed: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward sub-block.

        In pre-LN mode the input is normalized first and the raw MLP output is
        returned; the caller adds the residual. In post-LN mode the residual is
        added here and the sum is normalized.
        """
        x = self.layer_norm_mlp(embed) if self._pre_layer_norm else embed

        if self._use_glu_in_fnn:
            x = self.fc1(x)
            # The first half goes through the activation; the second half gates it.
            x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
            x = self._ffn_activation_fn(x1) * x2
        else:
            x = self._ffn_activation_fn(self.fc1(x))
        x = self.fc2(x)

        if not self._pre_layer_norm:
            x = self.layer_norm_mlp(x + embed)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run self-attention and the MLP with residual connections.

        Args:
            x: Input of feature size ``embed_dim`` on the last axis.
            attention_mask: Forwarded to :meth:`MultiHeadAttention.forward`.
            attention_weight_bias: Forwarded to
                :meth:`MultiHeadAttention.forward`.

        Returns:
            The attention output dict, with ``"embeddings"`` replaced by the
            block's final output and ``"attention_weights"`` kept from the
            attention layer. (Despite the annotation, a dict is returned.)
        """
        res = x
        if self._pre_layer_norm:
            x = self.layer_norm_self_attention(x)

        output = self.mha(
            x,
            x,
            x,
            attention_mask=attention_mask,
            attention_weight_bias=attention_weight_bias,
        )

        if self._pre_layer_norm:
            x = res + output["embeddings"]
            x = x + self.mlp(x)
        else:
            x = self.layer_norm_self_attention(output["embeddings"] + res)
            x = self.mlp(x)

        output["embeddings"] = x
        return output
|
|
|
|
|
class BulkRNABertConfig(PretrainedConfig):
    """Configuration for the ``BulkRNABert`` model.

    All fields are read from keyword arguments, with these defaults:
    ``n_genes=19_062``, ``n_expressions_bins=64``, ``embed_dim=256``,
    ``init_gene_embed_dim=200``, ``use_gene_embedding=True``,
    ``project_gene_embedding=True``, ``num_attention_heads=8``,
    ``key_size=None`` (derived, see ``__post_init__``), ``ffn_embed_dim=512``,
    ``num_layers=4``, ``embeddings_layers_to_save=()`` and
    ``attention_maps_to_save=[]`` (a list of ``(layer, map)`` pairs).
    """

    model_type = "BulkRNABert"

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.n_genes = kwargs.get("n_genes", 19_062)
        self.n_expressions_bins = kwargs.get("n_expressions_bins", 64)
        self.embed_dim = kwargs.get("embed_dim", 256)
        self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
        self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
        self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
        self.num_attention_heads = kwargs.get("num_attention_heads", 8)
        self.key_size = kwargs.get("key_size", None)
        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 512)
        self.num_layers = kwargs.get("num_layers", 4)

        # 1-based layer numbers whose output embeddings the model returns.
        self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
            "embeddings_layers_to_save", ()
        )
        # (layer number, map number) pairs whose attention maps are returned.
        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
            "attention_maps_to_save", []
        )

        self.__post_init__()

    def __post_init__(self):
        """Derive ``key_size`` and reconcile the gene-embedding projection.

        Raises:
            ValueError: If ``key_size`` is None and ``embed_dim`` is not
                divisible by ``num_attention_heads``.
        """
        if self.key_size is None:
            embed_dim = self.embed_dim
            num_attention_heads = self.num_attention_heads
            if embed_dim % num_attention_heads != 0:
                raise ValueError(
                    "When no key size is provided, the embedding dimension should be "
                    "divisible by the number of heads, however provided embedding "
                    f"dimension is {embed_dim} and the number of heads is "
                    f"{num_attention_heads}."
                )
            self.key_size = embed_dim // num_attention_heads

        # Gene embeddings are added to embed_dim-sized expression embeddings,
        # so a size mismatch can only be resolved by projecting them.
        needs_projection = (
            self.use_gene_embedding
            and self.init_gene_embed_dim != self.embed_dim
            and not self.project_gene_embedding
        )
        if needs_projection:
            logging.warning(
                f"Init gene embedding dimension ({self.init_gene_embed_dim}) "
                f"different than embedding dimension ({self.embed_dim}). "
                "Setting `project_gene_embedding` to True"
            )
            self.project_gene_embedding = True
|
|
|
|
|
class BulkRNABert(PreTrainedModel):
    """Transformer encoder over expression-bin tokens with a bin-logit head.

    Each input id is embedded with ``expression_embedding_layer``. When
    ``config.use_gene_embedding`` is set, a learned per-gene embedding
    (optionally projected to ``embed_dim``) is added. The result goes through
    ``config.num_layers`` :class:`SelfAttentionBlock` layers, and ``lm_head``
    maps it to ``n_expressions_bins`` logits.
    """

    config_class = BulkRNABertConfig

    def __init__(self, config: BulkRNABertConfig):
        """Build embeddings, transformer layers and the output head.

        Raises:
            ValueError: If ``config.attention_maps_to_save`` requests a layer
                number greater than ``config.num_layers``.
        """
        super().__init__(config=config)

        self.expression_embedding_layer = nn.Embedding(
            config.n_expressions_bins, config.embed_dim
        )
        self.gene_embedding_layer = nn.Embedding(
            config.n_genes,
            config.init_gene_embed_dim,
        )
        self.fc_gene_embedding = nn.Linear(config.init_gene_embed_dim, config.embed_dim)

        attention_maps_to_save = config.attention_maps_to_save
        # Distinct (1-based) layer numbers with at least one map to save.
        self._attention_layers_to_save = list({t[0] for t in attention_maps_to_save})
        self._attention_maps_per_layer_to_save = {
            layer: [t[1] for t in attention_maps_to_save if t[0] == layer]
            for layer in self._attention_layers_to_save
        }
        max_layer = max(self._attention_layers_to_save + [0])
        if max_layer > config.num_layers:
            raise ValueError(
                f"You are requiring attention maps for layer {max_layer}, "
                f"while the model has {config.num_layers} layers only."
            )

        self.transformer_layers = nn.ModuleList(
            [
                SelfAttentionBlock(
                    num_heads=config.num_attention_heads,
                    embed_dim=config.embed_dim,
                    key_size=config.key_size,
                    ffn_embed_dim=config.ffn_embed_dim,
                    name=f"attention_layer_{layer_idx}",
                )
                for layer_idx in range(config.num_layers)
            ]
        )

        self.lm_head = nn.Linear(config.embed_dim, config.n_expressions_bins)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> dict[str, torch.Tensor]:
        """Encode expression-bin ids and predict per-position bin logits.

        Args:
            input_ids: Integer expression-bin ids. When
                ``config.use_gene_embedding`` is set, the gene embedding has one
                row per gene (``config.n_genes`` rows) and is added along the
                position axis, so that axis must have length ``n_genes``.
                Presumably shaped (batch, n_genes); confirm with callers.
            attention_mask: Optional boolean mask forwarded to every attention
                layer. It must broadcast against the logits
                ``(batch, heads, seq, seq)``; ``False`` means "do not attend".
                ``None`` applies no masking.

        Returns:
            Dict with ``"logits"`` plus any requested
            ``"embeddings_{layer}"`` and
            ``"attention_map_layer_{layer}_number_{map}"`` entries.
        """
        outs: dict[str, torch.Tensor] = {}
        x = self.expression_embedding_layer(input_ids)

        if self.config.use_gene_embedding:
            gene_indices = torch.arange(self.config.n_genes, device=x.device)
            gene_embedding = self.gene_embedding_layer(gene_indices)
            if self.config.project_gene_embedding:
                gene_embedding = self.fc_gene_embedding(gene_embedding)
            x = x + gene_embedding

        # No default mask is materialised. An all-True (batch, 1, seq, seq)
        # mask is a no-op in `torch.where`, yet it costs seq**2 bytes per
        # sample (~363 MB for 19,062 genes) plus a full-size `torch.where` in
        # every layer. Passing None skips masking and gives identical results.
        for layer_idx, transformer in enumerate(self.transformer_layers):
            layer_number = layer_idx + 1  # layers are numbered from 1
            output = transformer(x, attention_mask=attention_mask)
            x = output["embeddings"]
            if layer_number in self.config.embeddings_layers_to_save:
                outs[f"embeddings_{layer_number}"] = output["embeddings"]
            if layer_number in self._attention_layers_to_save:
                for map_number in self._attention_maps_per_layer_to_save[layer_number]:
                    dkey = f"attention_map_layer_{layer_number}_number_{map_number}"
                    # NOTE(review): this selects head index `map_number + 1`,
                    # so the last head can never be requested. The offset
                    # looks like an off-by-one but is kept because callers may
                    # rely on it. Confirm the intended map numbering.
                    outs[dkey] = output["attention_weights"][:, map_number + 1]

        outs["logits"] = self.lm_head(x)
        return outs
|
|