import torch
import torch.nn as nn

from mega.fairseq.modules.mega_layer import MegaEncoderLayer
class PS4_Conv(torch.nn.Module):
    def __init__(self):
        super(PS4_Conv, self).__init__()
        # The name "elmo_feature_extractor" is kept for historical reasons;
        # the CNN weights are trained on ProtT5 embeddings, not ELMo.
        self.elmo_feature_extractor = torch.nn.Sequential(
            torch.nn.Conv2d(1024, 512, kernel_size=(7, 1), padding=(3, 0)),  # 7x512
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Conv2d(512, 256, kernel_size=(7, 1), padding=(3, 0)),  # 7x256
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Conv2d(256, 128, kernel_size=(7, 1), padding=(3, 0)),  # 7x128
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Conv2d(128, 32, kernel_size=(7, 1), padding=(3, 0)),  # 7x32
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1)
        )
        n_final_in = 32
        # 8-state (DSSP8) secondary-structure classification head
        self.dssp8_classifier = torch.nn.Sequential(
            torch.nn.Conv2d(n_final_in, 8, kernel_size=(7, 1), padding=(3, 0))
        )
    def forward(self, x):
        # IN: x = (B x L x F); reshape to (B x F x L x 1) for the Conv2d stack
        x = x.permute(0, 2, 1).unsqueeze(dim=-1)
        x = self.elmo_feature_extractor(x)  # OUT: (B x 32 x L x 1)
        d8_yhat = self.dssp8_classifier(x).squeeze(dim=-1).permute(0, 2, 1)  # OUT: (B x L x 8)
        return d8_yhat
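
# Usage sketch (illustrative, not part of the original Space): PS4_Conv takes
# per-residue ProtT5 embeddings of shape (B, L, 1024) and returns per-residue
# 8-state (DSSP8) logits of shape (B, L, 8). The helper name below is ours.
def _demo_ps4_conv():
    model = PS4_Conv().eval()
    emb = torch.randn(2, 50, 1024)  # 2 dummy sequences, 50 residues each
    with torch.no_grad():
        logits = model(emb)
    assert logits.shape == (2, 50, 8)
    return logits
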
class PS4_Mega(nn.Module):
    def __init__(self, nb_layers=11, l_aux_dim=1024, model_parallel=False,
                 h_dim=1024, batch_size=1, seq_len=1, dropout=0.0):
        super(PS4_Mega, self).__init__()
        self.nb_layers = nb_layers
        self.h_dim = h_dim
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.dropout = dropout
        self.aux_emb_size_l = l_aux_dim
        self.input_size = l_aux_dim
        self.args = ArgHolder(emb_dim=self.input_size, dropout=dropout, hdim=h_dim)
        self.nb_tags = 8  # 8-state (DSSP8) secondary-structure classes
        self.model_parallel = model_parallel
        # build the actual NN
        self.__build_model()
    def __build_model(self):
        # sequence-processing module: a stack of MEGA encoder layers
        megas = [MegaEncoderLayer(self.args) for _ in range(self.nb_layers)]
        self.seq_unit = MegaSequence(*megas)
        self.dropout_i = nn.Dropout(max(0.0, self.dropout - 0.2))
        # output layer which projects back to tag space
        out_dim = self.input_size
        self.hidden_to_tag = nn.Linear(out_dim, self.nb_tags, bias=False)
    def init_hidden(self):
        # Hidden-state initializer kept from an earlier RNN variant; it is not
        # called in forward(). The weights are of the form
        # (nb_layers, batch_size, nb_units).
        hidden_a = torch.randn(self.nb_layers, self.batch_size, self.aux_emb_size_l)
        if torch.cuda.is_available():
            hidden_a = hidden_a.cuda()
        return hidden_a
    def forward(self, r):
        self.seq_len = r.shape[1]
        # residue encoding: MEGA expects time-first input, so reinterpret
        # (B x L x F) as (L x B x F); note this .view() is only a correct
        # transpose when batch_size == 1
        R = r.view(self.seq_len, self.batch_size, self.aux_emb_size_l)
        X = self.dropout_i(R)
        # run through the MEGA stack
        X = self.seq_unit(X, encoder_padding_mask=None)
        X = X.view(self.batch_size, self.seq_len, self.input_size)
        # run through linear layer to project back to tag space
        Y_hat = self.hidden_to_tag(X)  # (B x L x 8)
        return Y_hat
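
# Usage sketch (illustrative; assumes the `mega` fairseq fork is installed so
# MegaEncoderLayer imports and builds). Runs one dummy ProtT5-sized sequence
# through a small Mega stack; the helper name is ours.
def _demo_ps4_mega():
    model = PS4_Mega(nb_layers=2).eval()  # 2 layers keeps the demo light
    emb = torch.randn(1, 120, 1024)       # (B=1, L=120, F=1024)
    with torch.no_grad():
        logits = model(emb)               # (1, 120, 8) per-residue logits
    return logits
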
class MegaSequence(nn.Sequential):
    """nn.Sequential.forward drops keyword arguments, so this subclass
    forwards kwargs (e.g. encoder_padding_mask) to MegaEncoderLayer
    modules only, letting non-Mega modules share the same stack."""
    def forward(self, input, **kwargs):
        for module in self:
            options = kwargs if isinstance(module, MegaEncoderLayer) else {}
            input = module(input, **options)
        return input
class ArgHolder(object):
    """Minimal stand-in for the fairseq args namespace consumed by
    MegaEncoderLayer."""
    def __init__(self, hdim=512, dropout=0.1, emb_dim=1024):
        super(ArgHolder, self).__init__()
        self.encoder_embed_dim = emb_dim
        self.encoder_hidden_dim = hdim
        self.dropout = dropout
        self.encoder_ffn_embed_dim = 1024
        self.ffn_hidden_dim: int = 1024
        self.encoder_z_dim: int = 128
        self.encoder_n_dim: int = 16
        self.activation_fn: str = 'silu'
        self.attention_activation_fn: str = 'softmax'
        self.attention_dropout: float = 0.0
        self.activation_dropout: float = 0.0
        self.hidden_dropout: float = 0.0
        self.encoder_chunk_size: int = -1
        self.truncation_length = None  # no truncation
        self.rel_pos_bias: str = 'simple'
        self.max_source_positions: int = 2048
        self.normalization_type: str = 'layernorm'
        self.normalize_before: bool = False
        self.feature_dropout: bool = False
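
# Optional smoke test (our addition; only runnable where PyTorch and, for the
# Mega demo, the `mega` package are available).
if __name__ == "__main__":
    _demo_ps4_conv()
    _demo_ps4_mega()
    print("demo forward passes OK")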