File size: 4,564 Bytes
8e0bbdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6914f78
8e0bbdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
import torch.nn as nn
from mega.fairseq.modules.mega_layer import MegaEncoderLayer


class PS4_Conv(torch.nn.Module):
    """CNN head that maps per-residue ProtT5 embeddings to 8-state (DSSP8) logits."""

    def __init__(self):
        super(PS4_Conv, self).__init__()
        # This is only called "elmo_feature_extractor" for historic reason.
        # CNN weights are trained on ProtT5 embeddings.
        # Channel schedule 1024 -> 512 -> 256 -> 128 -> 32; every stage is
        # Conv2d(7x1, pad 3x0) + ReLU + Dropout(0.1), so sequence length is preserved.
        channels = (1024, 512, 256, 128, 32)
        stages = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stages.append(torch.nn.Conv2d(c_in, c_out, kernel_size=(7, 1), padding=(3, 0)))
            stages.append(torch.nn.ReLU())
            stages.append(torch.nn.Dropout(0.1))
        self.elmo_feature_extractor = torch.nn.Sequential(*stages)

        # Final 1x-width conv projects the 32 feature maps to 8 class logits.
        self.dssp8_classifier = torch.nn.Sequential(
            torch.nn.Conv2d(channels[-1], 8, kernel_size=(7, 1), padding=(3, 0))
        )

    def forward(self, x):
        """(B, L, 1024) -> (B, L, 8) per-residue DSSP8 logits."""
        # Rearrange to NCHW with a singleton width so the (7, 1) convs slide along L.
        features = self.elmo_feature_extractor(x.permute(0, 2, 1).unsqueeze(dim=-1))
        logits = self.dssp8_classifier(features)  # (B, 8, L, 1)
        return logits.squeeze(dim=-1).permute(0, 2, 1)  # (B, L, 8)


class PS4_Mega(nn.Module):
    """Stack of MEGA encoder layers followed by a linear projection to 8 tag classes.

    Consumes per-residue embeddings of size ``l_aux_dim`` (ProtT5 by default)
    and emits per-residue logits over ``self.nb_tags`` (8) classes.
    """

    def __init__(self, nb_layers=11, l_aux_dim=1024, model_parallel=False,
                 h_dim=1024, batch_size=1, seq_len=1, dropout=0.0):
        super(PS4_Mega, self).__init__()

        self.nb_layers = nb_layers          # number of stacked MegaEncoderLayer blocks
        self.h_dim = h_dim                  # MEGA hidden dim (passed via ArgHolder.hdim)
        self.batch_size = batch_size
        self.seq_len = seq_len              # updated per forward() from the input tensor
        self.dropout = dropout
        self.aux_emb_size_l = l_aux_dim     # input embedding size per residue
        self.input_size = l_aux_dim

        self.args = ArgHolder(emb_dim=self.input_size, dropout=dropout, hdim=h_dim)

        self.nb_tags = 8                    # 8-state secondary structure (DSSP8)

        self.model_parallel = model_parallel

        # build actual NN
        self.__build_model()

    def __build_model(self):
        """Instantiate the MEGA stack, input dropout, and the output projection."""

        megas = [MegaEncoderLayer(self.args) for _ in range(self.nb_layers)]
        self.seq_unit = MegaSequence(*megas)

        # Input dropout is deliberately 0.2 lower than the layer dropout
        # (clamped at 0) — TODO confirm this offset is intentional.
        self.dropout_i = nn.Dropout(max(0.0, self.dropout - 0.2))

        # output layer which projects back to tag space
        out_dim = self.input_size
        self.hidden_to_tag = nn.Linear(out_dim, self.nb_tags, bias=False)

    def init_hidden(self):
        """Return a randomly initialized hidden state of shape
        (nb_layers, batch_size, aux_emb_size_l), on GPU when available.
        """
        # BUGFIX: original read self.nb_rnn_layers, which is never assigned
        # anywhere in this class and raised AttributeError; the intended
        # attribute is self.nb_layers.
        hidden_a = torch.randn(self.nb_layers, self.batch_size, self.aux_emb_size_l)

        if torch.cuda.is_available():
            hidden_a = hidden_a.cuda()

        return hidden_a

    def forward(self, r):
        """(B, L, F) residue embeddings -> (B, L, nb_tags) logits."""

        self.seq_len = r.shape[1]

        # Residue encoding: MEGA expects (L, B, F).
        # NOTE(review): view() only matches a true transpose when batch_size == 1
        # — presumably the supported case here; verify before using B > 1.
        R = r.view(self.seq_len, self.batch_size, self.aux_emb_size_l)

        X = self.dropout_i(R)

        # Run through MEGA
        X = self.seq_unit(X, encoder_padding_mask=None)
        X = X.view(self.batch_size, self.seq_len, self.input_size)

        # run through linear layer
        X = self.hidden_to_tag(X)

        Y_hat = X
        return Y_hat


class MegaSequence(nn.Sequential):
    """A Sequential whose forward passes extra keyword arguments only to
    MegaEncoderLayer children; all other modules are called with the tensor alone.
    """

    def forward(self, input, **kwargs):
        output = input
        for layer in self:
            if isinstance(layer, MegaEncoderLayer):
                output = layer(output, **kwargs)
            else:
                output = layer(output)
        return output


class ArgHolder(object):
    """Plain namespace standing in for the fairseq ``args`` object that
    MegaEncoderLayer reads its hyper-parameters from.

    Only ``hdim``, ``dropout`` and ``emb_dim`` are configurable; every other
    field is a fixed default.
    """

    def __init__(self, hdim=512, dropout=0.1, emb_dim=1024):
        # BUGFIX: original called super(object, self).__init__(), which searches
        # the MRO *past* object (an incorrect super invocation); the plain
        # zero-argument form is the correct spelling.
        super().__init__()

        self.encoder_embed_dim = emb_dim
        self.encoder_hidden_dim = hdim
        self.dropout = dropout
        self.encoder_ffn_embed_dim = 1024
        self.ffn_hidden_dim: int = 1024
        self.encoder_z_dim: int = 128
        self.encoder_n_dim: int = 16
        self.activation_fn: str = 'silu'
        self.attention_activation_fn: str = 'softmax'
        self.attention_dropout: float = 0.0
        self.activation_dropout: float = 0.0
        self.hidden_dropout: float = 0.0
        self.encoder_chunk_size: int = -1  # -1 disables chunked attention
        self.truncation_length: int = None
        self.rel_pos_bias: str = 'simple'
        self.max_source_positions: int = 2048
        self.normalization_type: str = 'layernorm'
        self.normalize_before: bool = False
        self.feature_dropout: bool = False