import torch
import torch.nn as nn
from mega.fairseq.modules.mega_layer import MegaEncoderLayer

class PS4_Conv(torch.nn.Module):
    def __init__(self):
        super(PS4_Conv, self).__init__()
        # This is only called "elmo_feature_extractor" for historic reasons;
        # the CNN weights are trained on ProtT5 embeddings
        self.elmo_feature_extractor = torch.nn.Sequential(
            torch.nn.Conv2d(1024, 512, kernel_size=(7, 1), padding=(3, 0)),  # 7x512
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Conv2d(512, 256, kernel_size=(7, 1), padding=(3, 0)),  # 7x256
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Conv2d(256, 128, kernel_size=(7, 1), padding=(3, 0)),  # 7x128
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Conv2d(128, 32, kernel_size=(7, 1), padding=(3, 0)),  # 7x32
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1)
        )

        n_final_in = 32
        self.dssp8_classifier = torch.nn.Sequential(
            torch.nn.Conv2d(n_final_in, 8, kernel_size=(7, 1), padding=(3, 0))
        )

    def forward(self, x):
        # IN: x = (B x L x F); reshaped to (B x F x L x 1) for the 2D convolutions
        x = x.permute(0, 2, 1).unsqueeze(dim=-1)
        x = self.elmo_feature_extractor(x)  # OUT: (B x 32 x L x 1)
        d8_yhat = self.dssp8_classifier(x).squeeze(dim=-1).permute(0, 2, 1)  # OUT: (B x L x 8)
        return d8_yhat
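
# A hypothetical usage sketch for PS4_Conv (not part of the original file),
# inferred from the shape comments above: per-residue ProtT5 embeddings in,
# per-residue DSSP8 logits out.
#   conv = PS4_Conv()
#   emb = torch.randn(1, 200, 1024)  # (B x L x F) dummy embeddings
#   d8 = conv(emb)                   # -> torch.Size([1, 200, 8])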

class PS4_Mega(nn.Module):
    def __init__(self, nb_layers=11, l_aux_dim=1024, model_parallel=False,
                 h_dim=1024, batch_size=1, seq_len=1, dropout=0.0):
        super(PS4_Mega, self).__init__()
        self.nb_layers = nb_layers
        self.h_dim = h_dim
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.dropout = dropout
        self.aux_emb_size_l = l_aux_dim
        self.input_size = l_aux_dim
        self.args = ArgHolder(emb_dim=self.input_size, dropout=dropout, hdim=h_dim)
        self.nb_tags = 8
        self.model_parallel = model_parallel

        # build actual NN
        self.__build_model()
    def __build_model(self):
        # sequence-processing module: a stack of MEGA encoder layers
        megas = []
        for i in range(self.nb_layers):
            mega = MegaEncoderLayer(self.args)
            megas.append(mega)
        self.seq_unit = MegaSequence(*megas)

        self.dropout_i = nn.Dropout(max(0.0, self.dropout - 0.2))

        # output layer which projects back to tag space
        out_dim = self.input_size
        self.hidden_to_tag = nn.Linear(out_dim, self.nb_tags, bias=False)
    def init_hidden(self):
        # the hidden state is of the form (nb_layers, batch_size, emb_dim)
        hidden_a = torch.randn(self.nb_layers, self.batch_size, self.aux_emb_size_l)
        if torch.cuda.is_available():
            hidden_a = hidden_a.cuda()
        return hidden_a
    def forward(self, r):
        self.seq_len = r.shape[1]

        # residue encoding: reshape the batch-first input (B x L x F) to the
        # length-first layout (L x B x F) that MEGA expects (a plain view is
        # only equivalent to a transpose here when batch_size == 1)
        R = r.view(self.seq_len, self.batch_size, self.aux_emb_size_l)
        X = self.dropout_i(R)

        # run through the MEGA stack
        X = self.seq_unit(X, encoder_padding_mask=None)
        X = X.view(self.batch_size, self.seq_len, self.input_size)

        # run through the linear layer to project to tag space
        X = self.hidden_to_tag(X)
        Y_hat = X
        return Y_hat

class MegaSequence(nn.Sequential):
    def forward(self, input, **kwargs):
        # nn.Sequential does not pass extra kwargs through, so forward them
        # only to the modules that accept them (the MEGA encoder layers)
        for module in self:
            options = kwargs if isinstance(module, MegaEncoderLayer) else {}
            input = module(input, **options)
        return input

class ArgHolder(object):
    # minimal stand-in for the fairseq argument namespace that
    # MegaEncoderLayer expects
    def __init__(self, hdim=512, dropout=0.1, emb_dim=1024):
        super(ArgHolder, self).__init__()
        self.encoder_embed_dim = emb_dim
        self.encoder_hidden_dim = hdim
        self.dropout = dropout
        self.encoder_ffn_embed_dim = 1024
        self.ffn_hidden_dim: int = 1024
        self.encoder_z_dim: int = 128
        self.encoder_n_dim: int = 16
        self.activation_fn: str = 'silu'
        self.attention_activation_fn: str = 'softmax'
        self.attention_dropout: float = 0.0
        self.activation_dropout: float = 0.0
        self.hidden_dropout: float = 0.0
        self.encoder_chunk_size: int = -1
        self.truncation_length = None
        self.rel_pos_bias: str = 'simple'
        self.max_source_positions: int = 2048
        self.normalization_type: str = 'layernorm'
        self.normalize_before: bool = False
        self.feature_dropout: bool = False
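
# A minimal smoke test, assuming the `mega` package (the fairseq MEGA fork)
# is importable; the sequence length below is illustrative, not from the
# original file.
if __name__ == "__main__":
    model = PS4_Mega(nb_layers=11, l_aux_dim=1024, batch_size=1)
    model.eval()
    embeddings = torch.randn(1, 350, 1024)  # (B x L x F) dummy ProtT5 embeddings
    with torch.no_grad():
        logits = model(embeddings)  # -> (1, 350, 8) per-residue DSSP8 logits
    print(logits.shape)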