from transformers import PretrainedConfig, PreTrainedModel, BertModel, BertConfig
from torch import nn


class SimBertModel(PreTrainedModel):
    """SimBert model: a BERT encoder with a binary similarity head."""

    config_class = BertConfig

    def __init__(
        self,
        config: PretrainedConfig
    ) -> None:
        super().__init__(config)
        # BERT encoder; keep the pooling layer so pooler_output is available.
        self.bert = BertModel(config=config, add_pooling_layer=True)
        # Two-way classification head over the pooled [CLS] representation.
        self.fc = nn.Linear(config.hidden_size, 2)
        # MSE on the positive-class probability; CrossEntropyLoss over the raw
        # logits is the usual alternative for hard 0/1 labels.
        # self.loss_fct = nn.CrossEntropyLoss()
        self.loss_fct = nn.MSELoss()
        self.softmax = nn.Softmax(dim=1)

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        labels=None
    ):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_output = outputs.pooler_output
        logits = self.fc(pooled_output)
        # Keep only the probability of the "similar" class as the score.
        logits = self.softmax(logits)[:, 1]
        if labels is not None:
            # MSELoss expects float targets in [0, 1], so cast the labels.
            loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            return loss, logits
        return None, logits
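

# A minimal smoke-test sketch, not part of the model itself. Assumptions: a
# small randomly initialized BertConfig and dummy token ids purely to show the
# expected input/output shapes; in practice the encoder weights would come
# from a pretrained BERT checkpoint and real tokenized sentence pairs.
if __name__ == "__main__":
    import torch

    config = BertConfig(
        vocab_size=100,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
    )
    model = SimBertModel(config)

    batch, seq_len = 2, 8
    input_ids = torch.randint(0, config.vocab_size, (batch, seq_len))
    token_type_ids = torch.zeros(batch, seq_len, dtype=torch.long)
    attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
    labels = torch.tensor([1.0, 0.0])  # 1.0 = similar pair, 0.0 = dissimilar

    loss, scores = model(input_ids, token_type_ids, attention_mask, labels)
    print(loss.item(), scores.shape)  # scalar MSE loss, similarity scores of shape (2,)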