# query2osm/hydra.py
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertConfig, BertModel
from transformers.modeling_outputs import SequenceClassifierOutput
class HydraConfig(BertConfig):
    """A BertConfig carrying `label_groups`: one {label_name: index} dict per
    classification head."""

    model_type = "hydra"
    label_groups = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def num_labels(self):
        # Total logit count across all groups. NOTE: this is a plain method,
        # so it shadows the `num_labels` property on PretrainedConfig.
        return sum(len(group) for group in self.label_groups)

    def distilbert_config(self):
        # Name kept from the earlier DistilBERT version of this model; it
        # strips the Hydra-specific fields down to a plain BertConfig.
        return BertConfig(**self.__dict__)
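# The expected shape of `label_groups` (an assumption inferred from the
# inversion logic in Hydra.forward below): a list with one
# {label_name: local_index} dict per classification head, e.g.
#
#     HydraConfig(label_groups=[
#         {"amenity": 0, "shop": 1},  # head 1
#         {"yes": 0, "no": 1},        # head 2
#     ])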
@dataclass
class HydraSequenceClassifierOutput(SequenceClassifierOutput):
    # Per-group lists of (label_name, probability), sorted by descending
    # probability; only populated for single-example batches.
    classifications: Optional[List[List[Tuple[str, float]]]] = None
class Hydra(BertModel):
    """A BERT encoder with one shared linear classifier whose logits are
    partitioned into an independent softmax per label group."""

    config_class = HydraConfig

    def __init__(self, config: HydraConfig):
        super().__init__(config)
        self.config = config
        self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
        # One logit per label across every group; forward() slices per group.
        self.classifier = nn.Linear(
            config.hidden_size, sum(len(group) for group in config.label_groups)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Keep the token/position embeddings frozen; only the encoder and the
        # classification head are trained.
        self.embeddings.requires_grad_(False)
        self.post_init()
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[HydraSequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        bert_output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_state = bert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim): the [CLS] token
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)
        loss = None
        if labels is not None:
            # NOTE: the concatenated logits are treated as one flat label
            # space, so each label id must index into the combined range.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + bert_output[1:]
            return ((loss,) + output) if loss is not None else output
        # Decode ranked (label, probability) lists, one per group; this path
        # only runs for single-example batches.
        classifications = []
        if logits.shape[0] == 1:
            softmax = nn.Softmax(dim=1)
            offset = 0
            for group in self.config.label_groups:
                # Each group maps label name -> logit index; invert it so the
                # softmax output can be decoded back into label names.
                inverted = {group[label]: label for label in group}
                probs = softmax(logits[:, offset:offset + len(group)])
                classification = [(inverted[i], val.item()) for i, val in enumerate(probs[0])]
                classification.sort(key=lambda x: x[1], reverse=True)
                classifications.append(classification)
                offset += len(group)
        return HydraSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
            classifications=classifications,
        )
    def to(self, device):
        # Registered submodules already move with the base module, so the
        # explicit calls below are redundant but harmless.
        super().to(device)
        self.pre_classifier.to(device)
        self.classifier.to(device)
        self.dropout.to(device)
        return self
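# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the model). The config sizes and
# label names below are made up for illustration; a real checkpoint would be
# loaded with Hydra.from_pretrained(...) instead of random initialization.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = HydraConfig(
        vocab_size=128,
        hidden_size=32,
        num_hidden_layers=1,
        num_attention_heads=2,
        intermediate_size=64,
        label_groups=[
            {"amenity": 0, "shop": 1, "tourism": 2},  # hypothetical group 1
            {"yes": 0, "no": 1},                      # hypothetical group 2
        ],
    )
    model = Hydra(config)
    model.eval()

    input_ids = torch.tensor([[2, 17, 45, 3]])  # (bs=1, seq_len=4), arbitrary ids
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))

    # With a batch of one, `classifications` holds one ranked
    # (label, probability) list per label group.
    for group in out.classifications:
        print(group)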