from transformers import PreTrainedModel
# from genomics_research.biobrain_p2.huggingface.modeling_enformer import Enformer
from genomics_research.biobrain_p2.huggingface.modeling_esm import NTForMaskedLM, MultiHeadAttention
from genomics_research.biobrain_p2.huggingface.isoformer_config import IsoformerConfig
# from genomics_research.biobrain_p2.huggingface.enformer_config import EnformerConfig
from genomics_research.biobrain_p2.huggingface.esm_config import NTConfig
from genomics_research.biobrain_p2.huggingface.modeling_esm_original import EsmForMaskedLM
from transformers.models.esm.configuration_esm import EsmConfig
from enformer_pytorch import Enformer, str_to_one_hot, EnformerConfig
import torch
from torch import nn


class Isoformer(PreTrainedModel):
    """Multi-omics model: a DNA encoder (Enformer), an RNA encoder (NT) and a
    protein encoder (ESM) are fused via cross-attention into the DNA track,
    followed by a two-layer head that outputs gene expression predictions."""

    config_class = IsoformerConfig

    def __init__(self, config):
        super().__init__(config)

        # Protein encoder configuration (ESM, rotary position embeddings).
        self.esm_config = EsmConfig(
            vocab_size=config.esm_vocab_size,
            mask_token_id=config.esm_mask_token_id,
            pad_token_id=config.esm_pad_token_id,
            hidden_size=config.esm_hidden_size,
            num_hidden_layers=config.esm_num_hidden_layers,
            num_attention_heads=config.esm_num_attention_heads,
            intermediate_size=config.esm_intermediate_size,
            max_position_embeddings=config.esm_max_position_embeddings,
            token_dropout=config.esm_token_dropout,
            emb_layer_norm_before=config.esm_emb_layer_norm_before,
            attention_probs_dropout_prob=0.0,
            hidden_dropout_prob=0.0,
            use_cache=False,
            add_bias_fnn=config.esm_add_bias_fnn,
            position_embedding_type="rotary",
            tie_word_embeddings=False,
        )

        # RNA encoder configuration (Nucleotide Transformer, rotary position embeddings).
        self.nt_config = NTConfig(
            vocab_size=config.nt_vocab_size,
            mask_token_id=config.nt_mask_token_id,
            pad_token_id=config.nt_pad_token_id,
            hidden_size=config.nt_hidden_size,
            num_hidden_layers=config.nt_num_hidden_layers,
            num_attention_heads=config.nt_num_attention_heads,
            intermediate_size=config.nt_intermediate_size,
            max_position_embeddings=config.nt_max_position_embeddings,
            token_dropout=config.nt_token_dropout,
            emb_layer_norm_before=config.nt_emb_layer_norm_before,
            attention_probs_dropout_prob=0.0,
            hidden_dropout_prob=0.0,
            use_cache=False,
            add_bias_fnn=config.nt_add_bias_fnn,
            position_embedding_type="rotary",
            tie_word_embeddings=False,
        )
        self.config = config

        # self.enformer_config = EnformerConfig(
        #     dim=config.enformer_dim,
        #     depth=config.enformer_depth,
        #     heads=config.enformer_heads,
        #     output_heads=dict(
        #         human=1,
        #         mouse=1  # TODO CHANGE
        #     ),
        #     target_length=config.enformer_target_length,  # 896
        #     attn_dim_key=config.enformer_attn_dim_key,
        #     dropout_rate=0.4,
        #     attn_dropout=0.05,
        #     pos_dropout=0.01,
        #     use_checkpointing=config.enformer_use_checkpointing,
        #     use_convnext=config.enformer_use_convnext,
        #     num_downsamples=config.enformer_num_downsamples,
        #     # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer -
        #     # can be changed for higher resolution
        #     dim_divisible_by=config.enformer_dim_divisible_by,
        #     use_tf_gamma=False,
        # )

        self.esm_model = EsmForMaskedLM(self.esm_config)  # protein encoder
        self.nt_model = NTForMaskedLM(self.nt_config)  # RNA encoder
        # self.enformer_model = Enformer(self.enformer_config)  # dna encoder
        self.enformer_model = Enformer.from_pretrained("EleutherAI/enformer-official-rough")  # DNA encoder

        # Cross-attention from the DNA track (3072-dim queries) onto the RNA
        # encoder outputs (768-dim keys/values).
        self.cross_attention_layer_rna = MultiHeadAttention(
            config=EsmConfig(
                num_attention_heads=config.num_heads_omics_cross_attention,
                attention_head_size=3072 // config.num_heads_omics_cross_attention,
                hidden_size=3072,
                attention_probs_dropout_prob=0,
                max_position_embeddings=0,
            ),
            omics_of_interest_size=3072,
            other_omic_size=768,
        )
        # Cross-attention from the DNA track onto the protein encoder outputs
        # (640-dim keys/values).
        self.cross_attention_layer_protein = MultiHeadAttention(
            config=EsmConfig(
                num_attention_heads=config.num_heads_omics_cross_attention,
                attention_head_size=3072 // config.num_heads_omics_cross_attention,
                hidden_size=3072,
                attention_probs_dropout_prob=0,
                max_position_embeddings=0,
            ),
            omics_of_interest_size=3072,
            other_omic_size=640,
        )

        # Prediction head: 3072 -> 6144 -> 30 gene expression outputs.
        self.head_layer_1 = nn.Linear(3072, 2 * 3072)
        self.head_layer_2 = nn.Linear(2 * 3072, 30)

    def forward(
        self,
        tensor_dna,
        tensor_rna,
        tensor_protein,
        attention_mask_rna,
        attention_mask_protein,
    ):
        tensor_dna = tensor_dna[:, 1:]  # remove CLS token before the DNA encoder

        # Encode each omic separately.
        dna_embedding = self.enformer_model(
            tensor_dna,
            return_only_embeddings=True,
            # attention_mask=attention_mask_dna,
            # encoder_attention_mask=attention_mask_dna,
            # output_hidden_states=True
        )
        protein_embedding = self.esm_model(
            tensor_protein,
            attention_mask=attention_mask_protein,
            encoder_attention_mask=attention_mask_protein,
            output_hidden_states=True,
        )
        rna_embedding = self.nt_model(
            tensor_rna,
            attention_mask=attention_mask_rna,
            encoder_attention_mask=attention_mask_rna,
            output_hidden_states=True,
        )

        # Key-padding mask over RNA positions (pad token id 1), broadcast to a
        # (1, 1, dna_length, rna_length) attention mask; as written this
        # construction assumes a batch size of 1.
        encoder_attention_mask = torch.unsqueeze(
            torch.unsqueeze(tensor_rna != 1, 0), 0
        ).repeat(1, 1, dna_embedding.shape[1], 1)

        # Fuse RNA information into the DNA embeddings, then protein information.
        rna_to_dna = self.cross_attention_layer_rna.forward(
            hidden_states=dna_embedding,
            encoder_hidden_states=rna_embedding["hidden_states"][-1],
            encoder_attention_mask=encoder_attention_mask,
        )
        final_dna_embeddings = self.cross_attention_layer_protein.forward(
            hidden_states=rna_to_dna["embeddings"],
            encoder_hidden_states=protein_embedding["hidden_states"][-1],
        )["embeddings"]

        # Mean-pool the fused DNA embeddings over the configured window; build
        # the mask on the embeddings' device/dtype to avoid a CPU/GPU mismatch.
        sequence_mask = torch.zeros(
            final_dna_embeddings.shape[1],
            device=final_dna_embeddings.device,
            dtype=final_dna_embeddings.dtype,
        )
        sequence_mask[self.config.pool_window_start:self.config.pool_window_end] = 1
        x = torch.sum(
            torch.einsum("ijk,j->ijk", final_dna_embeddings, sequence_mask), dim=1
        ) / torch.sum(sequence_mask)

        # Two-layer head with a softplus non-linearity.
        x = self.head_layer_1(x)
        x = torch.nn.functional.softplus(x)
        x = self.head_layer_2(x)

        return {
            "gene_expression_predictions": x,
            "final_dna_embeddings": final_dna_embeddings,
        }
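

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes that IsoformerConfig
# provides defaults for every esm_*/nt_*, pooling-window and cross-attention
# field referenced above, that the pretrained Enformer weights can be
# downloaded, and that the placeholder sequence lengths and token id ranges
# below roughly match the DNA/RNA/protein tokenizers actually used; none of
# this is guaranteed by this file.
if __name__ == "__main__":
    config = IsoformerConfig()  # assumption: defaults exist for all fields
    model = Isoformer(config).eval()

    batch_size = 1  # the RNA attention mask construction assumes batch size 1
    dna_len, rna_len, protein_len = 196_609, 512, 256  # placeholder lengths

    # Placeholder inputs; in practice these come from the respective
    # tokenizers (the leading DNA token is stripped as a CLS token in forward,
    # leaving 196,608 bp for the default Enformer receptive field).
    tensor_dna = torch.randint(0, 4, (batch_size, dna_len))
    tensor_rna = torch.randint(2, config.nt_vocab_size, (batch_size, rna_len))
    tensor_protein = torch.randint(2, config.esm_vocab_size, (batch_size, protein_len))
    attention_mask_rna = torch.ones_like(tensor_rna)
    attention_mask_protein = torch.ones_like(tensor_protein)

    with torch.no_grad():
        outputs = model(
            tensor_dna=tensor_dna,
            tensor_rna=tensor_rna,
            tensor_protein=tensor_protein,
            attention_mask_rna=attention_mask_rna,
            attention_mask_protein=attention_mask_protein,
        )
    print(outputs["gene_expression_predictions"].shape)  # expected: (1, 30)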