# ua-thesis-absa/models/BERT_LSTM_CRF.py
#from transformers import BertPreTrainedModel, BertForSequenceClassification, BertModel
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from torch import nn
from torch.nn import CrossEntropyLoss
import torch
from .layers import CRF
from itertools import islice
# Number of named parameter tensors in one standard BERT encoder layer
# (used to translate "freeze N encoder layers" into a parameter count).
NUM_PER_LAYER = 16
class BERTLstmCRF(PreTrainedModel):
    """BERT encoder followed by a BiLSTM and a CRF layer for token classification."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Load the pretrained encoder named in the config; the pooler is not
        # needed for token-level classification.
        self.bert = AutoModel.from_pretrained(config._name_or_path, config=config, add_pooling_layer=False)

        classifier_dropout = (config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)
        self.dropout = nn.Dropout(classifier_dropout)

        # The BiLSTM keeps the overall hidden size unchanged (hidden_size // 2 per direction).
        self.bilstm = nn.LSTM(config.hidden_size, config.hidden_size // 2, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)

        if self.config.freeze:
            self.manage_freezing()
    def manage_freezing(self):
        # Freeze the embedding layer entirely.
        for _, param in self.bert.embeddings.named_parameters():
            param.requires_grad = False

        # Freeze the first `num_frozen_encoder` encoder layers; each BERT encoder
        # layer contributes NUM_PER_LAYER tensors to encoder.named_parameters(),
        # which iterates layers in order from the bottom up.
        num_encoders_to_freeze = self.config.num_frozen_encoder
        if num_encoders_to_freeze > 0:
            for _, param in islice(self.bert.encoder.named_parameters(), num_encoders_to_freeze * NUM_PER_LAYER):
                param.requires_grad = False
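
    # NUM_PER_LAYER assumes a standard BERT encoder layer, which exposes 16 named
    # parameter tensors (Q/K/V and attention-output dense weights+biases, two
    # LayerNorms, and the intermediate/output dense layers). A minimal sketch to
    # confirm this for a given checkpoint (the checkpoint name is an assumption,
    # not something fixed by this repository):
    #
    #     from transformers import AutoModel
    #     encoder = AutoModel.from_pretrained("bert-base-cased")
    #     print(sum(1 for _ in encoder.encoder.layer[0].named_parameters()))  # -> 16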
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Default `model.config.use_return_dict` is `True`.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Token representations -> dropout -> BiLSTM -> emission scores for the CRF.
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        lstm_output, _ = self.bilstm(sequence_output)
        logits = self.classifier(lstm_output)

        loss = None
        if labels is not None:
            # Labels are only passed during training/evaluation; the CRF returns the
            # log-likelihood, so it is negated to obtain the loss.
            loss = -1 * self.crf(logits, labels)

        # Viterbi decoding of the most likely tag sequence for each item in the batch.
        tags = torch.Tensor(self.crf.decode(logits))
        return loss, tags
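

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original training pipeline. The
    # checkpoint name, label count, and the values of the custom config fields
    # (`freeze`, `num_frozen_encoder`) are assumptions for illustration; the
    # original project sets them elsewhere. Because of the relative import of
    # CRF above, run this as a module, e.g. `python -m models.BERT_LSTM_CRF`.
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("bert-base-cased", num_labels=5)
    config.freeze = True             # read in __init__ to trigger manage_freezing
    config.num_frozen_encoder = 4    # number of encoder layers to freeze

    model = BERTLstmCRF(config)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    batch = tokenizer(["The battery life is great"], return_tensors="pt")

    loss, tags = model(**batch)      # loss is None because no labels are passed
    print(tags.shape)                # -> torch.Size([1, sequence_length])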