import torch.nn as nn
from transformers import AlbertModel, AutoConfig, PreTrainedModel
from transformers.models.albert.configuration_albert import AlbertConfig


class ProtAlBertModel(PreTrainedModel):
    config_class = AlbertConfig

    def __init__(self, num_labels: int = 10, hidden_size: int = 24,
                 model_name: str = "Rostlab/prot_albert", *args, **kwargs):
        """
        Initialise the model.

        :param num_labels: the number of output labels.
        :param hidden_size: size of the hidden layer applied after the [CLS] embedding.
        :param model_name: the name of the pretrained model to load.
        """
        config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        super().__init__(config)

        # Pretrained ProtAlbert encoder.
        self.protbert = AlbertModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            config=config,
        )

        # Fall back to the default of 10 labels if a non-integer value is passed.
        if not isinstance(num_labels, int):
            num_labels = 10

        # Classification head applied to the [CLS] embedding.
        self.last_layer = nn.Sequential(
            nn.Dropout(0.1),
            nn.LayerNorm(self.config.hidden_size),
            nn.Linear(self.config.hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_labels),
        )

    def forward(self, x):
        # Take the last-layer embedding of the [CLS] token (first position).
        z = self.protbert(**x).last_hidden_state[:, 0, :]
        output = self.last_layer(z)
        return {"logits": output}
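

# A minimal usage sketch, not part of the original module: it assumes the
# "Rostlab/prot_albert" checkpoint and its tokenizer can be downloaded from the
# Hugging Face Hub, and that input sequences are space-separated amino acids as
# expected by the ProtTrans models. The example sequence is purely illustrative.
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_albert", do_lower_case=False)
    model = ProtAlBertModel(num_labels=10)
    model.eval()

    # ProtAlbert expects amino acids separated by spaces.
    batch = tokenizer(["M K T A Y I A K Q R"], return_tensors="pt", padding=True)

    with torch.no_grad():
        out = model(batch)

    print(out["logits"].shape)  # expected: torch.Size([1, 10])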