import torch.nn as nn
from transformers import AlbertModel, AutoConfig, PreTrainedModel
from transformers.models.albert.configuration_albert import AlbertConfig



class ProtAlBertModel(PreTrainedModel):
    config_class = AlbertConfig

    def __init__(self, num_labels: int = 10, hidden_size: int = 24, model_name: str = "Rostlab/prot_albert", *args, **kwargs):
        """
        Initialise the model.
        :param num_labels: the number of output labels.
        :param hidden_size: size of the hidden layer applied on top of the [CLS] embedding.
        :param model_name: the name of the pretrained backbone on the Hugging Face Hub.
        """
        config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        super().__init__(config)
        # Pretrained ProtAlbert backbone.
        self.protbert = AlbertModel.from_pretrained(
            model_name, trust_remote_code=True, config=self.config
        )
        # Fall back to 10 labels if an invalid value is passed.
        num_labels = num_labels if isinstance(num_labels, int) else 10
        # Classification head applied to the [CLS] embedding.
        self.last_layer = nn.Sequential(
            nn.Dropout(0.1),
            nn.LayerNorm(self.config.hidden_size),
            nn.Linear(self.config.hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, x):
        # x is the tokenizer output (input_ids, attention_mask, ...).
        # Take the final-layer embedding of the [CLS] token (position 0).
        z = self.protbert(**x).last_hidden_state[:, 0, :]
        output = self.last_layer(z)
        return {"logits": output}
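

# Minimal usage sketch (an illustrative assumption, not part of the model code):
# it assumes the "Rostlab/prot_albert" tokenizer is available locally or can be
# downloaded, and that protein sequences are given with spaces between residues,
# as is conventional for Rostlab models.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_albert", do_lower_case=False)
    model = ProtAlBertModel(num_labels=10)

    # Tokenize a toy protein sequence and run a forward pass.
    batch = tokenizer(["M K T A Y I A K Q R"], return_tensors="pt", padding=True)
    out = model(batch)
    print(out["logits"].shape)  # (batch_size, num_labels)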