# prot_albert_pfam / prot_albert_model.py
# Uploaded by sayby: "Upload model" (commit ce6350b, verified).
import numbers

import torch.nn as nn
from transformers import AlbertModel, AutoConfig
from transformers import PreTrainedModel, AutoModel, AutoConfig
from transformers.models.albert.configuration_albert import AlbertConfig
class ProtAlBertModel(PreTrainedModel):
    """Sequence classifier: a pretrained ALBERT encoder followed by an MLP head.

    The encoder config and weights are loaded from ``model_name``. The head
    maps the encoder's final-layer embedding at token position 0 to
    ``num_labels`` logits.
    """

    config_class = AlbertConfig

    def __init__(
        self,
        num_labels: int = 10,
        hidden_size: int = 24,
        model_name: str = "Rostlab/prot_albert",
        *args,
        **kwargs,
    ):
        """
        Initialise the model.

        :param num_labels: the number of output labels. Any value that is not
            an integer (``bool`` excluded) falls back to 10.
            NOTE(review): presumably this fallback exists because
            ``PreTrainedModel.from_pretrained`` passes the config object as
            the first positional argument -- confirm before changing it.
        :param hidden_size: width of the intermediate layer of the
            classification head (not the encoder's own hidden size).
        :param model_name: Hugging Face model id or path whose config and
            ALBERT weights are loaded for the encoder.
        :param args: accepted and ignored.
        :param kwargs: accepted and ignored.
        """
        config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        super().__init__(config)
        # Build the encoder from the config now held by the model (self.config),
        # exactly as before.
        self.protbert = AlbertModel.from_pretrained(
            model_name, trust_remote_code=True, config=self.config
        )
        num_labels = self._resolve_num_labels(num_labels)
        encoder_width = self.config.hidden_size
        # Layer order is kept as-is: state-dict keys (last_layer.0 .. last_layer.4)
        # must keep matching saved checkpoints.
        self.last_layer = nn.Sequential(
            nn.Dropout(0.1),
            nn.LayerNorm(encoder_width),
            nn.Linear(encoder_width, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_labels),
        )

    @staticmethod
    def _resolve_num_labels(num_labels) -> int:
        """Return ``num_labels`` as an ``int``, or 10 if it is not an integer."""
        # numbers.Integral also admits integer-like values (e.g. numpy ints)
        # that the old ``type(...) is int`` check silently replaced with 10.
        # bool is an Integral subclass but never a meaningful label count.
        if isinstance(num_labels, numbers.Integral) and not isinstance(num_labels, bool):
            return int(num_labels)
        return 10

    def forward(self, x):
        """
        Compute classification logits.

        :param x: mapping unpacked as keyword arguments into the ALBERT encoder
            (presumably tokenizer output such as ``input_ids`` and
            ``attention_mask`` -- confirm against callers).
        :return: ``{"logits": tensor}`` whose last dimension is ``num_labels``
            (one row per batch element, assuming a batch-first encoder output).
        """
        # Final-layer hidden state of the token at position 0 (the [CLS] token).
        first_token = self.protbert(**x).last_hidden_state[:, 0, :]
        return {"logits": self.last_layer(first_token)}