File size: 4,508 Bytes
26dff99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Module metadata.
# NOTE(review): the address in __author__ looks like e-mail-obfuscation
# residue from a web scrape; confirm the real contact address.
__author__, __copyright__ = (
    "Yifan Zhang ([email protected])",
    "Copyright (C) 2021, Qatar Computing Research Institute, HBKU, Doha",
)


from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn.functional import sigmoid
from transformers import BertPreTrainedModel, BertModel
from transformers.file_utils import ModelOutput


# Token-level tag inventory; a tag's position in the tuple is its id.
# Index 0 ("<PAD>") is the id the joint model's token loss ignores, and
# index 1 ("O") presumably marks tokens outside any tagged span.
# NOTE(review): the remaining names look like propaganda-technique labels
# (cf. the "Prop"/"Non-prop" sequence tags) -- confirm against the dataset.
TOKEN_TAGS = (
    "<PAD>",
    "O",
    "Name_Calling,Labeling",
    "Repetition",
    "Slogans",
    "Appeal_to_fear-prejudice",
    "Doubt",
    "Exaggeration,Minimisation",
    "Flag-Waving",
    "Loaded_Language",
    "Reductio_ad_hitlerum",
    "Bandwagon",
    "Causal_Oversimplification",
    "Obfuscation,Intentional_Vagueness,Confusion",
    "Appeal_to_Authority",
    "Black-and-White_Fallacy",
    "Thought-terminating_Cliches",
    "Red_Herring",
    "Straw_Men",
    "Whataboutism",
)


# Sequence-level tag inventory (two classes).
SEQUENCE_TAGS = (
    "Non-prop",
    "Prop",
)


@dataclass
class TokenAndSequenceJointClassifierOutput(ModelOutput):
    """Result of the joint token/sequence classifier's forward pass.

    Every field defaults to ``None``.
    """

    # Joint training loss; only populated when labels were supplied.
    loss: Optional[torch.FloatTensor] = None
    # Per-token tag logits, already scaled by the sequence-level gate.
    token_logits: Optional[torch.FloatTensor] = None
    # Per-sequence class logits computed from the pooled encoder output.
    sequence_logits: Optional[torch.FloatTensor] = None
    # Encoder hidden states, when the caller asked for them.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Encoder attention weights, when the caller asked for them.
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class BertForTokenAndSequenceJointClassification(BertPreTrainedModel):
    """BERT encoder with jointly trained token-level and sequence-level heads.

    The token head predicts one of ``TOKEN_TAGS`` per token from the encoder's
    last hidden states. The sequence head predicts one of ``SEQUENCE_TAGS``
    from the pooled output. A sigmoid gate computed from the sequence logits
    rescales all token logits of the same example. When labels are given,
    the loss is ``alpha * token_loss + (1 - alpha) * sequence_loss``.
    """

    # pos_weight of the sequence-level BCEWithLogitsLoss, kept verbatim from
    # the original code.
    # NOTE(review): presumably a ratio of class counts in the training data;
    # confirm which count belongs to which class.
    SEQUENCE_POS_WEIGHT = 3932 / 14263

    def __init__(self, config):
        super().__init__(config)
        # Derive the head sizes from the tag inventories so they cannot drift.
        self.num_token_labels = len(TOKEN_TAGS)
        self.num_sequence_labels = len(SEQUENCE_TAGS)

        self.token_tags = TOKEN_TAGS
        self.sequence_tags = SEQUENCE_TAGS

        # Weight of the token loss in the joint loss; the sequence loss gets 1 - alpha.
        self.alpha = 0.9

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # classifier[0]: token head, classifier[1]: sequence head.
        self.classifier = nn.ModuleList([
            nn.Linear(config.hidden_size, self.num_token_labels),
            nn.Linear(config.hidden_size, self.num_sequence_labels),
        ])
        # Maps the sequence logits of an example to a single gate value.
        self.masking_gate = nn.Linear(self.num_sequence_labels, 1)

        self.init_weights()
        # NOTE: not used by forward(). It is created after init_weights(), so
        # it keeps PyTorch's default initialisation. It is left in place so
        # the module's state_dict keys stay unchanged.
        self.merge_classifier_1 = nn.Linear(
            self.num_token_labels + self.num_sequence_labels, self.num_token_labels
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
    ):
        """Run the encoder and both heads, optionally computing the joint loss.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states:
                forwarded unchanged to ``BertModel``.
            labels: token-level tag ids indexing ``TOKEN_TAGS``. The
                ``"<PAD>"`` id is ignored by the token loss. Assumed to be
                ``(batch, seq_len)``; a 1-D vector is accepted when the batch
                size is 1 (TODO confirm against the data pipeline).
            return_dict: if true, return a
                ``TokenAndSequenceJointClassifierOutput``; otherwise a tuple.

        Returns:
            With ``return_dict``: an output whose ``token_logits`` are the
            gated token logits (unflattened) and whose ``sequence_logits`` come
            from the sequence head. Otherwise the tuple
            ``([token_logits, sequence_logits], *extra)`` is returned, with
            ``loss`` prepended when it was computed. Here ``extra`` are the
            encoder outputs after its first two (e.g. hidden states and
            attentions, when requested).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the encoder for a ModelOutput so that .hidden_states and
        # .attentions are available regardless of config.use_return_dict.
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = self.dropout(outputs[0])
        token_logits = self.classifier[0](sequence_output)

        pooler_output = self.dropout(outputs[1])
        sequence_logits = self.classifier[1](pooler_output)

        # One gate per example: (batch, 1) -> (batch, 1, 1), broadcast over
        # every token and tag. This is numerically identical to repeating the
        # gate to the token-logit shape, without materialising the copy.
        gate = torch.sigmoid(self.masking_gate(sequence_logits))
        weighted_token_logits = gate.unsqueeze(1) * token_logits

        loss = None
        if labels is not None:
            loss = self._joint_loss(weighted_token_logits, sequence_logits, labels)

        if not return_dict:
            # Tuple layout preserved for existing callers: both logits grouped in a list.
            output = ([weighted_token_logits, sequence_logits],) + tuple(outputs[2:])
            return ((loss,) + output) if loss is not None else output

        return TokenAndSequenceJointClassifierOutput(
            loss=loss,
            token_logits=weighted_token_logits,
            sequence_logits=sequence_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _joint_loss(self, token_logits, sequence_logits, labels):
        """Return ``alpha * token_loss + (1 - alpha) * sequence_loss``.

        Args:
            token_logits: gated token logits, last dim = ``num_token_labels``.
            sequence_logits: ``(batch, num_sequence_labels)`` sequence logits.
            labels: token tag ids (see ``forward``).

        The token loss is cross-entropy over all tokens, ignoring ``"<PAD>"``.
        The sequence target is ``"Prop"`` when any token of the example
        carries a tag listed after ``"O"`` in ``TOKEN_TAGS``; otherwise it is
        ``"Non-prop"``.
        """
        pad_id = self.token_tags.index("<PAD>")
        outside_id = self.token_tags.index("O")
        prop_id = self.sequence_tags.index("Prop")
        non_prop_id = self.sequence_tags.index("Non-prop")

        # Group labels per example (also handles a 1-D vector for batch size 1).
        labels = labels.reshape(sequence_logits.size(0), -1)

        token_criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
        token_loss = token_criterion(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels.reshape(-1),
        )

        # Only tags after "O" count as evidence for "Prop". Neither padding nor
        # "O" tokens make an example positive.
        is_prop = (labels > outside_id).any(dim=1).to(sequence_logits.dtype)

        # sigmoid(l_prop - l_non_prop) equals the softmax probability of "Prop"
        # over the two sequence logits. BCE on this difference is therefore a
        # two-class cross-entropy in which pos_weight scales the "Prop" term.
        prop_log_odds = sequence_logits[:, prop_id] - sequence_logits[:, non_prop_id]
        pos_weight = torch.tensor(
            [self.SEQUENCE_POS_WEIGHT],
            dtype=sequence_logits.dtype,
            device=sequence_logits.device,
        )
        sequence_criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        sequence_loss = sequence_criterion(prop_log_odds, is_prop)

        return self.alpha * token_loss + (1 - self.alpha) * sequence_loss