# Seq-TransfoRNA: transforna/src/model/model_components.py
import logging
import math
import random
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from torch.nn.modules.normalization import LayerNorm
logger = logging.getLogger(__name__)
def circulant_mask(n: int, window: int) -> torch.Tensor:
"""Calculate the relative attention mask, calculated once when model instatiated, as a subset of this matrix
will be used for a input length less than max.
i,j represent relative token positions in this matrix and in the attention scores matrix,
this mask enables attention scores to be set to 0 if further than the specified window length
:param n: a fixed parameter set to be larger than largest max sequence length across batches
:param window: [window length],
:return relative attention mask
"""
    if window >= n:
        return torch.ones(n, n)
    circulant_t = torch.zeros(n, n)
    # fill the diagonals at offsets 0, ±1, ..., ±window with ones
    for offset in range(-window, window + 1):
        # the length of each diagonal is n - |offset|
        circulant_t.diagonal(offset=offset).copy_(torch.ones(n - abs(offset)))
    return circulant_t
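# Illustrative only (not part of the original file): for n=5 and window=1 the mask is a
# tridiagonal band, so each position may only attend to itself and its immediate neighbours:
#   >>> circulant_mask(5, 1)
#   tensor([[1., 1., 0., 0., 0.],
#           [1., 1., 1., 0., 0.],
#           [0., 1., 1., 1., 0.],
#           [0., 0., 1., 1., 1.],
#           [0., 0., 0., 1., 1.]])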
class SelfAttention(nn.Module):
"""normal query, key, value based self attention but with relative attention functionality
and a learnable bias encoding relative token position which is added to the attention scores before the softmax"""
def __init__(self, config: DictConfig, relative_attention: int):
"""init self attention weight of each key, query, value and output projection layer.
:param config: model config
:type config: ConveRTModelConfig
"""
super().__init__()
self.config = config
self.query = nn.Linear(config.num_embed_hidden, config.num_attention_project)
self.key = nn.Linear(config.num_embed_hidden, config.num_attention_project)
self.value = nn.Linear(config.num_embed_hidden, config.num_attention_project)
self.softmax = nn.Softmax(dim=-1)
self.output_projection = nn.Linear(
config.num_attention_project, config.num_embed_hidden
)
        # learnable relative-position bias of size n = 2 * tokens_len + 1, uniformly initialised
        self.bias = torch.nn.Parameter(torch.randn(config.n), requires_grad=True)
        stdv = 1.0 / math.sqrt(self.bias.data.size(0))
        self.bias.data.uniform_(-stdv, stdv)
self.relative_attention = relative_attention
self.n = self.config.n
self.half_n = self.n // 2
self.register_buffer(
"relative_mask",
circulant_mask(config.tokens_len, self.relative_attention),
)
    def forward(
        self, attn_input: torch.Tensor, attention_mask: torch.Tensor
    ):
        """Compute self-attention over the input.

        The input is first projected to queries, keys and values. The attention mask is
        applied so that pad positions are ignored, the relative-attention mask zeroes out
        scores beyond the window, and a learnable relative-position bias is added before
        the softmax. The weighted values are then passed through the output projection.

        :param attn_input: input embeddings of shape (B, T, num_embed_hidden)
        :type attn_input: torch.Tensor
        :param attention_mask: mask of shape (B, T), 1 for real tokens, 0 for padding
        :type attention_mask: torch.Tensor
        :return: [output, attention_scores], output of shape (B, T, num_embed_hidden)
        :rtype: list
        """
        self.T = attn_input.size()[1]
        # input is B x max_seq_len x num_embed_hidden
        _query = self.query(attn_input)
        _key = self.key(attn_input)
        _value = self.value(attn_input)
# scaled dot product
attention_scores = torch.matmul(_query, _key.transpose(1, 2))
attention_scores = attention_scores / math.sqrt(
self.config.num_attention_project
)
        # padding mask: shape (B, T, 1), broadcast over the last dimension so that rows
        # belonging to pad positions receive a large negative score
        extended_attention_mask = (1.0 - attention_mask.unsqueeze(-1)) * -10000.0
        attention_scores = attention_scores + extended_attention_mask
        # relative attention: the circulant mask is built once at the maximum token length and
        # registered as a buffer, so masks of different sizes are not recreated per batch;
        # here its top-left T x T submatrix is used
        attention_scores = attention_scores.masked_fill(
            self.relative_mask.unsqueeze(0)[:, : self.T, : self.T] == 0, float("-inf")
        )
        # Learnable bias vector of length n = 2 * max_token_len + 1: one parameter for every
        # possible relative position (i - j) in {-max_token_len, ..., max_token_len}. For each
        # query position i, a different (shifted) slice of the vector is added to its row of
        # scores; indexing with (n // 2 - i + j) builds the whole T x T bias matrix without loops.
        ii, jj = torch.meshgrid(torch.arange(self.T), torch.arange(self.T))
        B_matrix = self.bias[self.n // 2 - ii + jj]
        attention_scores = attention_scores + B_matrix.unsqueeze(0)
attention_scores = self.softmax(attention_scores)
output = torch.matmul(attention_scores, _value)
output = self.output_projection(output)
        return [output, attention_scores]  # output: B x T x num_embed_hidden
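# Illustrative only (values are arbitrary, not taken from the original configs): with
# num_embed_hidden=32, num_attention_project=16, tokens_len=12 and n=25, a single
# SelfAttention layer maps a (B, 12, 32) input and a (B, 12) padding mask to a
# (B, 12, 32) output plus a (B, 12, 12) matrix of attention weights. The bias lookup
# self.bias[self.n // 2 - ii + jj] picks index 12 for j == i, 11 for j == i - 1,
# 13 for j == i + 1, and so on, i.e. one shared parameter per relative offset.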
class FeedForward1(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear."""
def __init__(
self, input_hidden: int, intermediate_hidden: int, dropout_rate: float = 0.0
):
        # e.g. input_hidden = 512, intermediate_hidden = 2048
super().__init__()
self.linear_1 = nn.Linear(input_hidden, intermediate_hidden)
self.dropout = nn.Dropout(dropout_rate)
self.linear_2 = nn.Linear(intermediate_hidden, input_hidden)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.gelu(self.linear_1(x))
return self.linear_2(self.dropout(x))
class SharedInnerBlock(nn.Module):
def __init__(self, config: DictConfig, relative_attn: int):
super().__init__()
self.config = config
self.self_attention = SelfAttention(config, relative_attn)
self.norm1 = LayerNorm(config.num_embed_hidden) # 512
self.dropout = nn.Dropout(config.dropout)
self.ff1 = FeedForward1(
config.num_embed_hidden, config.feed_forward1_hidden, config.dropout
)
self.norm2 = LayerNorm(config.num_embed_hidden)
    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor):
        new_values_x, attn_scores = self.self_attention(x, attention_mask=attention_mask)
        x = x + new_values_x
        x = self.norm1(x)
        x = x + self.ff1(x)
        return self.norm2(x), attn_scores
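# Illustrative only: a SharedInnerBlock is a post-norm encoder layer; given hidden states of
# shape (B, T, num_embed_hidden) and a (B, T) padding mask, it returns hidden states of the
# same shape together with the (B, T, T) attention weights of its single attention head.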
# Multi-head attention: the single-head computation done num_heads times. Q, K and V are
# reshaped so the heads form an extra dimension, giving stacks of (B, nh, T, T) attention
# scores that are multiplied with the (B, nh, T, hs) values to give (B, nh, T, hs) outputs.
# The head size hs is chosen so that nh * hs equals the final projection size, so the head
# outputs are simply transposed back and concatenated. An optional extra output projection
# could be added, but is not used here.
class MultiheadAttention(nn.Module):
def __init__(self, config: DictConfig):
super().__init__()
        self.num_attention_heads = config.num_attention_heads
        # each head has size num_embed_hidden, so the concatenated projection is
        # num_embed_hidden * num_attention_heads wide
        self.num_attn_proj = config.num_embed_hidden * config.num_attention_heads
        self.attention_head_size = int(self.num_attn_proj / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.num_embed_hidden, self.num_attn_proj)
self.key = nn.Linear(config.num_embed_hidden, self.num_attn_proj)
self.value = nn.Linear(config.num_embed_hidden, self.num_attn_proj)
self.dropout = nn.Dropout(config.dropout)
def forward(
self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
B, T, _ = hidden_states.size()
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
k = (
self.key(hidden_states)
.view(B, T, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
) # (B, nh, T, hs)
q = (
self.query(hidden_states)
.view(B, T, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
) # (B, nh, T, hs)
v = (
self.value(hidden_states)
.view(B, T, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
) # (B, nh, T, hs)
attention_scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :]
attention_mask = (1.0 - attention_mask) * -10000.0
attention_scores = attention_scores + attention_mask
attention_scores = F.softmax(attention_scores, dim=-1)
attention_scores = self.dropout(attention_scores)
y = attention_scores @ v
y = y.transpose(1, 2).contiguous().view(B, T, self.num_attn_proj)
return y
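# Illustrative only (arbitrary sizes): with num_embed_hidden=32 and num_attention_heads=2,
# hidden states of shape (B, T, 32) are projected to q, k, v of shape (B, 2, T, 32), the
# scores are (B, 2, T, T), and the re-concatenated output has shape (B, T, 64), i.e.
# (B, T, num_embed_hidden * num_attention_heads). There is no output projection here; the
# RNAFFrwd block below consumes the concatenated heads directly.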
class PositionalEncoding(nn.Module):
def __init__(self, model_config: DictConfig,):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=model_config.dropout)
self.num_embed_hidden = model_config.num_embed_hidden
pe = torch.zeros(model_config.tokens_len, self.num_embed_hidden)
position = torch.arange(
0, model_config.tokens_len, dtype=torch.float
).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.num_embed_hidden, 2).float()
* (-math.log(10000.0) / self.num_embed_hidden)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
    def forward(self, x):
        # x is batch-first (B, T, num_embed_hidden); add the first T rows of the positional table
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)
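# The table above is the standard sinusoidal encoding: for position pos and channel index i,
#   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
#   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
# with d_model = num_embed_hidden. It is registered as a buffer, so it is saved with the
# model state but never trained.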
class RNAFFrwd(nn.Module):
    """Fully-connected 3-layer feed-forward head: length-normalized sum pooling over tokens,
    two residual GELU layers with layer normalization and orthogonal initialization, and a
    final L2-normalized projection. Parameters are not shared between the two parallel
    transformers, so each gets its own set of weights."""

    def __init__(self, model_config: DictConfig):
        """
        :param model_config: model config providing num_embed_hidden and num_attention_heads
        :type model_config: DictConfig
        """
        # the paper specifies skip connections, layer normalization and orthogonal initialization
super().__init__()
# 3,679,744 x2 params
self.rna_ffwd_input_dim = (
model_config.num_embed_hidden * model_config.num_attention_heads
)
self.linear_1 = nn.Linear(self.rna_ffwd_input_dim, self.rna_ffwd_input_dim)
self.linear_2 = nn.Linear(self.rna_ffwd_input_dim, self.rna_ffwd_input_dim)
self.norm1 = LayerNorm(self.rna_ffwd_input_dim)
self.norm2 = LayerNorm(self.rna_ffwd_input_dim)
self.final = nn.Linear(self.rna_ffwd_input_dim, model_config.num_embed_hidden)
self.orthogonal_initialization() # torch implementation works perfectly out the box,
def orthogonal_initialization(self):
for l in [
self.linear_1,
self.linear_2,
]:
torch.nn.init.orthogonal_(l.weight)
    def forward(self, x: torch.Tensor, attn_msk: torch.Tensor) -> torch.Tensor:
        # x has shape B x T x (num_embed_hidden * num_attention_heads)
        sentence_lengths = attn_msk.sum(1)
        # square-root length reduction: pool the token embeddings by summing over T and
        # dividing by sqrt(sequence length)
        norms = 1 / torch.sqrt(sentence_lengths.double()).float()  # (B,)
        # TODO: aggregation is done over all positions, including the padded ones
        x = norms.unsqueeze(1) * torch.sum(x, dim=1)  # B x (num_embed_hidden * num_attention_heads)
x = x + F.gelu(self.linear_1(self.norm1(x)))
x = x + F.gelu(self.linear_2(self.norm2(x)))
        return F.normalize(self.final(x), dim=1, p=2)  # B x num_embed_hidden, L2-normalized
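# Illustrative only: for a batch of two sequences with 4 and 9 real tokens respectively, the
# pooled vectors are sum(x, dim=1) scaled by 1/sqrt(4) = 0.5 and 1/sqrt(9) = 1/3; the head then
# applies two pre-norm residual GELU layers and a final projection, and returns unit-length
# (L2-normalized) embeddings of size num_embed_hidden.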
class RNATransformer(nn.Module):
def __init__(self, model_config: DictConfig):
super().__init__()
self.num_embedd_hidden = model_config.num_embed_hidden
self.encoder = nn.Embedding(
model_config.vocab_size, model_config.num_embed_hidden
)
self.model_input = model_config.model_input
if 'baseline' not in self.model_input:
# positional encoder
self.pos_encoder = PositionalEncoding(model_config)
self.transformer_layers = nn.ModuleList(
[
SharedInnerBlock(model_config, int(window/model_config.window))
for window in model_config.relative_attns[
: model_config.num_encoder_layers
]
]
)
self.MHA = MultiheadAttention(model_config)
# self.concatenate = FeedForward2(model_config)
self.rna_ffrwd = RNAFFrwd(model_config)
self.pad_id = 0
    def forward(self, x: torch.Tensor):
        embedds = self.encoder(x)
        if 'baseline' not in self.model_input:
            output = self.pos_encoder(embedds)
            attention_mask = (x != self.pad_id).int()
            for l in self.transformer_layers:
                output, attn_scores = l(output, attention_mask)
            output = self.MHA(output)
            output = self.rna_ffrwd(output, attention_mask)
            # attn_scores are those of the last SharedInnerBlock
            return output, attn_scores
        else:
            # baseline: no transformer layers, just the flattened token embeddings
            embedds = torch.flatten(embedds, start_dim=1)
            return embedds, None
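# Illustrative usage sketch (field values are arbitrary assumptions, not the original configs):
#   >>> from omegaconf import OmegaConf
#   >>> cfg = OmegaConf.create({
#   ...     "model_input": "seq", "num_embed_hidden": 32, "num_attention_project": 16,
#   ...     "feed_forward1_hidden": 64, "num_attention_heads": 2, "num_encoder_layers": 2,
#   ...     "relative_attns": [10, 20], "window": 1, "dropout": 0.1,
#   ...     "tokens_len": 12, "n": 25, "vocab_size": 30})
#   >>> model = RNATransformer(cfg)
#   >>> emb, attn = model(torch.randint(1, 30, (4, 12)))
#   >>> emb.shape, attn.shape
#   (torch.Size([4, 32]), torch.Size([4, 12, 12]))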
class GeneEmbeddModel(nn.Module):
def __init__(
self, main_config: DictConfig,
):
super().__init__()
self.train_config = main_config["train_config"]
self.model_config = main_config["model_config"]
self.device = self.train_config.device
self.model_input = self.model_config["model_input"]
self.false_input_perc = self.model_config["false_input_perc"]
        # adjust n (size of the learnable relative bias added to the attention scores)
        self.model_config.n = self.model_config.tokens_len * 2 + 1
        self.transformer_layers = RNATransformer(self.model_config)
        # save tokens_len of the sequence input; it is used to split the ids between the two transformers
        self.tokens_len = self.model_config.tokens_len
        # reassign tokens_len and vocab_size to initialise the second transformer
        # (a cleaner solution would be a flag telling RNATransformer which input it handles)
        self.model_config.tokens_len = self.model_config.second_input_token_len
        self.model_config.n = self.model_config.tokens_len * 2 + 1
        self.seq_vocab_size = self.model_config.vocab_size
        # it is the vocab size that differs between the two models, not the token lengths
        self.model_config.vocab_size = self.model_config.second_input_vocab_size
self.second_input_model = RNATransformer(self.model_config)
#num_transformers refers to using either one model or two in parallel
self.num_transformers = 2
if self.model_input == 'seq':
self.num_transformers = 1
# could be moved to model
self.weight_decay = self.train_config.l2_weight_decay
if 'baseline' in self.model_input:
self.num_transformers = 1
num_nodes = self.model_config.num_embed_hidden*self.tokens_len
self.final_clf_1 = nn.Linear(num_nodes,self.model_config.num_classes)
else:
#setting classification layer
num_nodes = self.num_transformers*self.model_config.num_embed_hidden
if self.num_transformers == 1:
self.final_clf_1 = nn.Linear(num_nodes,self.model_config.num_classes)
else:
self.final_clf_1 = nn.Linear(num_nodes,num_nodes)
self.final_clf_2 = nn.Linear(num_nodes,self.model_config.num_classes)
self.relu = nn.ReLU()
self.BN = nn.BatchNorm1d(num_nodes)
self.dropout = nn.Dropout(0.6)
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
    def distort_input(self, x):
        """Randomly replace a contiguous stretch of false_input_perc * seq_length tokens in both
        the sequence segment and the second-input segment of every sample (used during training
        only). The last column of x holds the true sequence length."""
        for sample_idx in range(x.shape[0]):
            seq_length = x[sample_idx, -1]
            num_tokens_flipped = int(self.false_input_perc * seq_length)
            max_start_flip_idx = seq_length - num_tokens_flipped
            random_feat_idx = random.randint(0, max_start_flip_idx - 1)
            # flip tokens in the sequence segment ...
            x[sample_idx, random_feat_idx:random_feat_idx + num_tokens_flipped] = \
                torch.tensor(np.random.choice(range(1, self.seq_vocab_size - 1), size=num_tokens_flipped, replace=True))
            # ... and at the aligned positions in the second-input segment
            x[sample_idx, random_feat_idx + self.tokens_len:random_feat_idx + self.tokens_len + num_tokens_flipped] = \
                torch.tensor(np.random.choice(range(1, self.model_config.second_input_vocab_size - 1), size=num_tokens_flipped, replace=True))
        return x
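    # Illustrative only: with false_input_perc = 0.1 and a sample of true length 20, two
    # contiguous token positions are overwritten with random vocabulary ids, once in the
    # sequence segment (columns [0, tokens_len)) and once at the aligned offset in the
    # second-input segment (columns [tokens_len, -1)).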
    def forward(self, x, train=False):
        # the last column of x holds the true sequence length; the rest are token ids
        if self.device == 'cuda':
            long_tensor = torch.cuda.LongTensor
        else:
            long_tensor = torch.LongTensor
        if train and self.false_input_perc > 0:
            x = self.distort_input(x)
        gene_embedd, attn_scores_first = self.transformer_layers(
            x[:, : self.tokens_len].type(long_tensor)
        )
        second_input_embedd, attn_scores_second = self.second_input_model(
            x[:, self.tokens_len:-1].type(long_tensor)
        )
        # for tcga: single-transformer case ('seq' or baseline), classify the sequence embedding directly
        if self.num_transformers == 1:
            activations = self.final_clf_1(gene_embedd)
        else:
            # two transformers: concatenate both embeddings and pass them through the two-layer head
            out_clf_1 = self.final_clf_1(torch.cat((gene_embedd, second_input_embedd), 1))
            out = self.BN(out_clf_1)
            out = self.relu(out)
            out = self.dropout(out)
            activations = self.final_clf_2(out)
        # create dummy attention scores for the baseline (its transformer returns None)
        if 'baseline' in self.model_input:
            attn_scores_first = torch.ones((1, 2, 2), device=x.device)
        return [gene_embedd, second_input_embedd, activations, attn_scores_first, attn_scores_second]
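
# A minimal end-to-end usage sketch. The config below is a hypothetical toy configuration
# assembled from the fields accessed above; the real TransfoRNA configs live elsewhere in the
# repository and use different values.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    toy_config = OmegaConf.create({
        "train_config": {"device": "cpu", "l2_weight_decay": 0.0},
        "model_config": {
            "model_input": "seq",  # single-transformer setup
            "false_input_perc": 0.0,
            "num_embed_hidden": 32,
            "num_attention_project": 16,
            "feed_forward1_hidden": 64,
            "num_attention_heads": 2,
            "num_encoder_layers": 2,
            "relative_attns": [10, 20],
            "window": 1,
            "dropout": 0.1,
            "tokens_len": 12,
            "second_input_token_len": 12,
            "vocab_size": 30,
            "second_input_vocab_size": 20,
            "num_classes": 5,
        },
    })
    model = GeneEmbeddModel(toy_config)
    # a batch of 4 samples: sequence tokens, then second-input tokens, then the true length
    seq_tokens = torch.randint(1, 30, (4, 12))
    second_tokens = torch.randint(1, 20, (4, 12))
    lengths = torch.full((4, 1), 12, dtype=torch.long)
    batch = torch.cat([seq_tokens, second_tokens, lengths], dim=1)
    gene_embedd, second_embedd, activations, attn_first, attn_second = model(batch)
    print(gene_embedd.shape, activations.shape)  # torch.Size([4, 32]) torch.Size([4, 5])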