from typing import Optional

import torch
import torch.nn as nn

from mega.fairseq.modules.mega_layer import MegaEncoderLayer


class PS4_Conv(torch.nn.Module):
    def __init__(self):
        super(PS4_Conv, self).__init__()
        # This is only called "elmo_feature_extractor" for historic reasons;
        # the CNN weights are trained on ProtT5 embeddings.
        self.elmo_feature_extractor = torch.nn.Sequential(
            torch.nn.Conv2d(1024, 512, kernel_size=(7, 1), padding=(3, 0)),  # 7x512
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Conv2d(512, 256, kernel_size=(7, 1), padding=(3, 0)),  # 7x256
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Conv2d(256, 128, kernel_size=(7, 1), padding=(3, 0)),  # 7x128
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Conv2d(128, 32, kernel_size=(7, 1), padding=(3, 0)),  # 7x32
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1)
        )
        n_final_in = 32
        self.dssp8_classifier = torch.nn.Sequential(
            torch.nn.Conv2d(n_final_in, 8, kernel_size=(7, 1), padding=(3, 0))
        )

    def forward(self, x):
        # IN: x = (B x L x F); OUT: (B x L x 8)
        x = x.permute(0, 2, 1).unsqueeze(dim=-1)  # (B x F x L x 1)
        x = self.elmo_feature_extractor(x)  # (B x 32 x L x 1)
        d8_yhat = self.dssp8_classifier(x).squeeze(dim=-1).permute(0, 2, 1)  # (B x L x 8)
        return d8_yhat
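
# Usage sketch (illustrative, not part of the training code): with per-residue
# ProtT5 embeddings of width 1024, the CNN maps a (B x L x 1024) input to
# per-residue logits over the 8 DSSP states.
#
# >>> model = PS4_Conv()
# >>> x = torch.randn(2, 100, 1024)  # B=2 sequences, L=100 residues
# >>> model(x).shape
# torch.Size([2, 100, 8])
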
class PS4_Mega(nn.Module):
    def __init__(self, nb_layers=11, l_aux_dim=1024, model_parallel=False,
                 h_dim=1024, batch_size=1, seq_len=1, dropout=0.0):
        super(PS4_Mega, self).__init__()
        self.nb_layers = nb_layers
        self.h_dim = h_dim
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.dropout = dropout
        self.aux_emb_size_l = l_aux_dim
        self.input_size = l_aux_dim
        self.args = ArgHolder(emb_dim=self.input_size, dropout=dropout, hdim=h_dim)
        self.nb_tags = 8
        self.model_parallel = model_parallel

        # build the actual NN
        self.__build_model()

    def __build_model(self):
        # sequence-processing module: a stack of MEGA encoder layers
        megas = []
        for i in range(self.nb_layers):
            mega = MegaEncoderLayer(self.args)
            megas.append(mega)
        self.seq_unit = MegaSequence(*megas)

        self.dropout_i = nn.Dropout(max(0.0, self.dropout - 0.2))

        # output layer which projects back to tag space
        out_dim = self.input_size
        self.hidden_to_tag = nn.Linear(out_dim, self.nb_tags, bias=False)

    def init_hidden(self):
        # the weights are of the form (nb_layers, batch_size, nb_units)
        hidden_a = torch.randn(self.nb_layers, self.batch_size, self.aux_emb_size_l)

        if torch.cuda.is_available():
            hidden_a = hidden_a.cuda()

        return hidden_a

    def forward(self, r):
        self.seq_len = r.shape[1]

        # residue encoding: transpose (B x L x F) -> (L x B x F) as MEGA expects
        R = r.permute(1, 0, 2)
        X = self.dropout_i(R)

        # run through MEGA
        X = self.seq_unit(X, encoder_padding_mask=None)
        X = X.permute(1, 0, 2)  # back to (B x L x F)

        # run through the linear layer to project to tag space
        Y_hat = self.hidden_to_tag(X)  # (B x L x nb_tags)

        return Y_hat
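
# Usage sketch (illustrative; requires the `mega` fork to be importable). With
# the default batch_size=1, a (1 x L x 1024) tensor of residue embeddings
# yields per-residue logits over the 8 secondary-structure tags:
#
# >>> model = PS4_Mega(nb_layers=2)
# >>> r = torch.randn(1, 100, 1024)
# >>> model(r).shape
# torch.Size([1, 100, 8])
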
class MegaSequence(nn.Sequential):
    # nn.Sequential variant that forwards keyword arguments (e.g.
    # encoder_padding_mask) only to MegaEncoderLayer modules, so other module
    # types can be mixed into the same container.
    def forward(self, input, **kwargs):
        for module in self:
            options = kwargs if isinstance(module, MegaEncoderLayer) else {}
            input = module(input, **options)
        return input
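
# Usage sketch (illustrative): keyword arguments flow only into the Mega
# layers, so a plain module such as nn.Dropout can sit alongside them.
#
# >>> args = ArgHolder()
# >>> seq = MegaSequence(MegaEncoderLayer(args), nn.Dropout(0.1))
# >>> y = seq(torch.randn(16, 1, 1024), encoder_padding_mask=None)
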
class ArgHolder(object):
    # plain namespace standing in for the argparse args MegaEncoderLayer reads
    def __init__(self, hdim=512, dropout=0.1, emb_dim=1024):
        super(ArgHolder, self).__init__()
        self.encoder_embed_dim = emb_dim
        self.encoder_hidden_dim = hdim
        self.dropout = dropout
        self.encoder_ffn_embed_dim = 1024
        self.ffn_hidden_dim: int = 1024
        self.encoder_z_dim: int = 128
        self.encoder_n_dim: int = 16
        self.activation_fn: str = 'silu'
        self.attention_activation_fn: str = 'softmax'
        self.attention_dropout: float = 0.0
        self.activation_dropout: float = 0.0
        self.hidden_dropout: float = 0.0
        self.encoder_chunk_size: int = -1
        self.truncation_length: Optional[int] = None
        self.rel_pos_bias: str = 'simple'
        self.max_source_positions: int = 2048
        self.normalization_type: str = 'layernorm'
        self.normalize_before: bool = False
        self.feature_dropout: bool = False
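
# Illustrative check (assumes the `mega` fairseq fork is on the path): since
# ArgHolder mimics the expected argparse namespace, a single encoder layer can
# be constructed straight from it.
#
# >>> args = ArgHolder(hdim=1024, dropout=0.0, emb_dim=1024)
# >>> layer = MegaEncoderLayer(args)
# >>> layer(torch.randn(32, 1, 1024), encoder_padding_mask=None).shape
# torch.Size([32, 1, 1024])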