Spaces:
Build error
Build error
File size: 4,508 Bytes
26dff99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
__author__ = "Yifan Zhang ([email protected])"
__copyright__ = "Copyright (C) 2021, Qatar Computing Research Institute, HBKU, Doha"
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn.functional import sigmoid
from transformers import BertPreTrainedModel, BertModel
from transformers.file_utils import ModelOutput
# Token-level tag inventory. Position 0 is the padding tag and position 1 is
# the "outside" tag; the remaining entries are the technique names.
TOKEN_TAGS = (
    "<PAD>",
    "O",
    "Name_Calling,Labeling",
    "Repetition",
    "Slogans",
    "Appeal_to_fear-prejudice",
    "Doubt",
    "Exaggeration,Minimisation",
    "Flag-Waving",
    "Loaded_Language",
    "Reductio_ad_hitlerum",
    "Bandwagon",
    "Causal_Oversimplification",
    "Obfuscation,Intentional_Vagueness,Confusion",
    "Appeal_to_Authority",
    "Black-and-White_Fallacy",
    "Thought-terminating_Cliches",
    "Red_Herring",
    "Straw_Men",
    "Whataboutism",
)

# Sequence-level tag inventory (two entries).
SEQUENCE_TAGS = (
    "Non-prop",
    "Prop",
)
@dataclass
class TokenAndSequenceJointClassifierOutput(ModelOutput):
    """Output container holding token-level and sequence-level logits.

    Every field defaults to ``None``.
    """

    # Scalar training loss, when one was computed.
    loss: Optional[torch.FloatTensor] = None
    # Per-token classification logits.
    token_logits: Optional[torch.FloatTensor] = None
    # Per-sequence classification logits.
    sequence_logits: Optional[torch.FloatTensor] = None
    # Optional tuple of hidden-state tensors.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Optional tuple of attention tensors.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
class BertForTokenAndSequenceJointClassification(BertPreTrainedModel):
    """BERT encoder with a token-level head and a sequence-level head.

    The token head scores every position against ``TOKEN_TAGS``; the sequence
    head scores the pooled representation against ``SEQUENCE_TAGS``. A scalar
    gate computed from the sequence logits rescales all token logits of the
    same sequence before they are returned.
    """

    # Positive-class weight used by the sequence-level BCE loss.
    # NOTE(review): presumably a class-frequency ratio from the training
    # data -- confirm against the training setup.
    _SEQUENCE_POS_WEIGHT = 3932 / 14263

    def __init__(self, config):
        """Build the encoder, both classification heads and the gate.

        Args:
            config: BERT configuration; ``hidden_size`` and
                ``hidden_dropout_prob`` are read from it.
        """
        super().__init__(config)
        self.token_tags = TOKEN_TAGS
        self.sequence_tags = SEQUENCE_TAGS
        # Derived from the tag tuples so the head sizes cannot drift from them.
        self.num_token_labels = len(self.token_tags)
        self.num_sequence_labels = len(self.sequence_tags)

        # Weight of the token loss in the joint loss; the sequence loss gets
        # the remaining (1 - alpha).
        self.alpha = 0.9

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.ModuleList([
            nn.Linear(config.hidden_size, self.num_token_labels),
            nn.Linear(config.hidden_size, self.num_sequence_labels),
        ])
        # Maps the sequence logits to a single gate value per sequence.
        self.masking_gate = nn.Linear(self.num_sequence_labels, 1)
        self.init_weights()
        # Not used by forward(). Left in place, and still created after
        # init_weights(), so the set of parameters is unchanged.
        # NOTE(review): presumably required so existing checkpoints keep
        # loading with the same state-dict keys -- confirm before removing.
        self.merge_classifier_1 = nn.Linear(
            self.num_token_labels + self.num_sequence_labels,
            self.num_token_labels,
        )

    def _joint_loss(self, token_logits, sequence_logits, labels):
        """Return ``alpha * token_loss + (1 - alpha) * sequence_loss``.

        Args:
            token_logits: gated token logits, ``(batch, seq_len, num_token_labels)``.
            sequence_logits: sequence logits, ``(batch, num_sequence_labels)``.
            labels: token label ids indexed like ``TOKEN_TAGS``, either
                ``(batch, seq_len)`` or ``(seq_len,)`` for a single sequence.
        """
        batch_size = sequence_logits.size(0)
        labels = labels.to(token_logits.device).long().reshape(batch_size, -1)

        # Token loss: padding positions (label id 0) are ignored.
        token_criterion = nn.CrossEntropyLoss(ignore_index=0)
        token_loss = token_criterion(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels.reshape(-1),
        )

        # A sequence counts as propaganda when at least one token carries a
        # technique tag, i.e. an id above "O". Comparing against 0 would mark
        # every non-empty sequence as positive, because "O" itself is id 1.
        outside_id = self.token_tags.index("O")
        is_prop = (labels > outside_id).any(dim=-1).to(sequence_logits.dtype)
        # BCEWithLogitsLoss needs a float target shaped like its input, so the
        # binary sequence label is laid out over the two sequence-logit columns.
        # NOTE(review): target layout chosen to match the two-logit head --
        # confirm it reflects the intended training objective.
        sequence_target = torch.stack((1.0 - is_prop, is_prop), dim=-1)
        pos_weight = torch.tensor(
            [self._SEQUENCE_POS_WEIGHT],
            dtype=sequence_logits.dtype,
            device=sequence_logits.device,
        )
        sequence_criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        sequence_loss = sequence_criterion(sequence_logits, sequence_target)

        return self.alpha * token_loss + (1 - self.alpha) * sequence_loss

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
    ):
        """Run the encoder and both heads, optionally computing the joint loss.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds, output_attentions,
            output_hidden_states: forwarded unchanged to the BERT encoder.
            labels: optional token label ids indexed like ``TOKEN_TAGS``;
                when given, a joint token/sequence loss is computed.
            return_dict: when true (the default) return a
                ``TokenAndSequenceJointClassifierOutput``; when ``None`` fall
                back to ``config.use_return_dict``; otherwise return a tuple.

        Returns:
            Either the output dataclass, or a tuple
            ``([token_logits, sequence_logits], *extra_encoder_outputs)``
            prefixed with the loss when ``labels`` was given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward return_dict so the encoder output type always matches the
        # way it is consumed below (attribute access vs. tuple slicing).
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])
        pooler_output = self.dropout(outputs[1])

        token_logits = self.classifier[0](sequence_output)
        sequence_logits = self.classifier[1](pooler_output)

        # One gate value per sequence, broadcast over every position and tag.
        gate = torch.sigmoid(self.masking_gate(sequence_logits))
        weighted_token_logits = token_logits * gate.unsqueeze(1)

        logits = [weighted_token_logits, sequence_logits]

        loss = None
        if labels is not None:
            # The helper flattens internally, so the logits returned below
            # keep their original shapes whether or not labels are supplied.
            loss = self._joint_loss(weighted_token_logits, sequence_logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenAndSequenceJointClassifierOutput(
            loss=loss,
            token_logits=weighted_token_logits,
            sequence_logits=sequence_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|