# query2osm/hydra.py
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertConfig, BertModel
from transformers.modeling_outputs import SequenceClassifierOutput


class HydraConfig(BertConfig):
    """BertConfig variant that also carries the Hydra label groups."""

    model_type = "hydra"
    label_groups = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def num_labels(self):
        # One logit per label, summed across all label groups.
        return sum(len(group) for group in self.label_groups)

    def distilbert_config(self):
        return BertConfig(**self.__dict__)
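

# Note on `label_groups` (an inference from the code below, not stated in the
# original file): `Hydra.forward` inverts each group with
# `{group[name]: name for name in group}` and then looks up indices
# 0..len(group)-1, so each group is expected to be a mapping from label name to
# that label's index *within its group*, e.g.
#
#     label_groups = [
#         {"amenity": 0, "shop": 1},           # hypothetical group of 2 labels
#         {"yes": 0, "no": 1, "unknown": 2},   # hypothetical group of 3 labels
#     ]
#
# A runnable sketch using this shape appears at the bottom of the file.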


class HydraSequenceClassifierOutput(SequenceClassifierOutput):
    classifications: List[dict]

    def __init__(self, classifications=None, **kwargs):
        super().__init__(**kwargs)
        self.classifications = classifications


class Hydra(BertModel):
    config_class = HydraConfig

    def __init__(self, config: HydraConfig):
        super().__init__(config)
        self.config = config
        self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
        # One shared linear head; its logits are split per label group in forward().
        self.classifier = nn.Linear(
            config.hidden_size, sum(len(group) for group in config.label_groups)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Keep the pre-trained token embeddings frozen during fine-tuning.
        self.embeddings.requires_grad_(False)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim): first-token ([CLS]) representation
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        # For single-example batches, also build human-readable results: each group's
        # logits are softmaxed independently and reported as (label, probability)
        # pairs sorted by descending probability.
        classifications = []
        if logits.shape[0] == 1:
            offset = 0
            for group in self.config.label_groups:
                inverted = {group[pair]: pair for pair in group}
                softmax = nn.Softmax(dim=1)
                output = softmax(logits[:, offset:offset + len(group)])
                classification = []
                for i, val in enumerate(output[0]):
                    classification.append((inverted[i], val.item()))
                classification.sort(key=lambda x: x[1], reverse=True)
                classifications.append(classification)
                offset += len(group)

        return HydraSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
            classifications=classifications,
        )

    def to(self, device):
        # Move the backbone and the classification head to the requested device.
        super().to(device)
        self.pre_classifier.to(device)
        self.classifier.to(device)
        self.dropout.to(device)
        return self
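

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It builds a tiny,
    # randomly initialised Hydra purely to exercise the forward pass; the label
    # groups and model dimensions below are made up for illustration.
    config = HydraConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
    )
    config.label_groups = [
        {"amenity": 0, "shop": 1},
        {"yes": 0, "no": 1, "unknown": 2},
    ]

    model = Hydra(config)
    model.eval()

    # A single fake token sequence; with batch size 1, `classifications` is populated.
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask)

    for group in output.classifications:
        # Each entry is a (label, probability) list sorted by descending probability.
        print(group[0])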