import torch.nn as nn
from transformers import AlbertModel, AutoConfig, PreTrainedModel
from transformers.models.albert.configuration_albert import AlbertConfig


class ProtAlBertModel(PreTrainedModel):

    config_class = AlbertConfig

    def __init__(self, num_labels: int = 10, hidden_size: int = 24, model_name: str = "Rostlab/prot_albert", *args, **kwargs):
        """
        Initialise the model.

        :param num_labels: the number of output labels.
        :param hidden_size: size of the hidden layer applied on top of the CLS token.
        :param model_name: the name of the pretrained model to load.
        """
        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        super().__init__(self.config)
        # Pretrained ALBERT backbone producing the token embeddings.
        self.protbert = AlbertModel.from_pretrained(
            model_name, trust_remote_code=True, config=self.config
        )
        # Fall back to the default of 10 labels if num_labels is not an integer.
        num_labels = num_labels if isinstance(num_labels, int) else 10
        # Classification head applied to the CLS-token embedding.
        self.last_layer = nn.Sequential(
            nn.Dropout(0.1),
            nn.LayerNorm(self.config.hidden_size),
            nn.Linear(self.config.hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, x):
        """
        :param x: a dict of tokenizer outputs (input_ids, attention_mask, ...).
        :return: a dict containing the classification logits.
        """
        # Use the embedding of the first ([CLS]) token as the sequence representation.
        z = self.protbert(**x).last_hidden_state[:, 0, :]
        output = self.last_layer(z)
        return {"logits": output}