"""Hydra: a BERT-based sequence classifier with multiple independent label groups,
each scored by its own slice of a shared logits layer."""

from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertConfig, BertModel
from transformers.modeling_outputs import SequenceClassifierOutput


class HydraConfig(BertConfig):
    """BertConfig extended with ``label_groups``: a list of dicts, each mapping
    a label name to its index within that group."""

    model_type = "hydra"

    def __init__(self, label_groups=None, **kwargs):
        super().__init__(**kwargs)
        # Default to an empty list so num_labels() works before groups are set.
        self.label_groups = label_groups if label_groups is not None else []

    def num_labels(self):
        # Total number of output logits across all label groups.
        return sum(len(group) for group in self.label_groups)

    def distilbert_config(self):
        # Plain BertConfig copy carrying the same hyperparameters.
        return BertConfig(**self.__dict__)


class HydraSequenceClassifierOutput(SequenceClassifierOutput):
    """SequenceClassifierOutput extended with per-group, human-readable
    classifications: one sorted list of (label, probability) pairs per group."""

    classifications: List[List[Tuple[str, float]]]

    def __init__(self, classifications=None, **kwargs):
        super().__init__(**kwargs)
        self.classifications = classifications


class Hydra(BertModel):
    """BERT encoder with a single linear head whose logits are split into the
    label groups declared in ``HydraConfig.label_groups``."""

    config_class = HydraConfig

    def __init__(self, config: HydraConfig):
        super().__init__(config)
        self.config = config
        self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
        # One shared linear layer producing the concatenated logits of all groups.
        self.classifier = nn.Linear(
            config.hidden_size, sum(len(group) for group in config.label_groups))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Freeze the embedding weights; they are not updated during training.
        self.embeddings.requires_grad_(False)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        # Tuple return path; this must run whether or not labels were supplied,
        # otherwise the raw tuple from BertModel.forward is accessed below as
        # if it were a ModelOutput.
        if not return_dict:
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        # For single-example inference, also build human-readable per-group
        # results: each entry is a list of (label, probability) pairs sorted
        # by descending probability.
        classifications = []
        if logits.shape[0] == 1:
            offset = 0
            softmax = nn.Softmax(dim=1)
            for group in self.config.label_groups:
                # Each group maps label name -> local index; invert it to map
                # a logit position back to its label name.
                inverted = {group[label]: label for label in group}
                probs = softmax(logits[:, offset:offset + len(group)])
                classification = [(inverted[i], val.item()) for i, val in enumerate(probs[0])]
                classification.sort(key=lambda x: x[1], reverse=True)
                classifications.append(classification)
                offset += len(group)

        return HydraSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
            classifications=classifications
        )

    def to(self, device):
        super().to(device)
        self.pre_classifier.to(device)
        self.classifier.to(device)
        self.dropout.to(device)
        return self
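

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original training
# pipeline): the label groups and the "bert-base-uncased" tokenizer below are
# hypothetical stand-ins chosen only to show the API shape.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # Each group maps label name -> local index within that group.
    config = HydraConfig(
        label_groups=[
            {"negative": 0, "positive": 1},              # hypothetical group 1
            {"sports": 0, "politics": 1, "science": 2},  # hypothetical group 2
        ]
    )
    model = Hydra(config)  # randomly initialised; load real weights in practice
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    inputs = tokenizer("A quick smoke-test sentence.", return_tensors="pt")
    model.eval()
    with torch.no_grad():
        output = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

    # With batch size 1, `classifications` holds one sorted list of
    # (label, probability) pairs per label group.
    for group_result in output.classifications:
        print(group_result)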