wanadzhar913 committed
Commit d751d99 (verified) · Parent: 3d03da6

Add `classifier.py`

Files changed (1):
  1. classifier.py +118 -0
classifier.py ADDED
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import DebertaV2ForSequenceClassification, DebertaV2Model
from transformers.modeling_outputs import SequenceClassifierOutput


class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Operates on the [CLS] embedding concatenated with `num_extra_dims`
    additional numeric features per example.
    """

    def __init__(self, config, num_extra_dims):
        super().__init__()
        total_dims = config.hidden_size + num_extra_dims
        self.dense = nn.Linear(total_dims, total_dims)
        classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(total_dims, config.num_labels)

    def forward(self, features, **kwargs):
        x = self.dropout(features)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class CustomSequenceClassification(DebertaV2ForSequenceClassification):
    """DeBERTa-v2 sequence classifier that also consumes extra per-example features."""

    def __init__(self, config, num_extra_dims):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # The attribute name must match the base model's weight prefix so that
        # pretrained weights load; rename if adapting to another architecture.
        self.deberta = DebertaV2Model(config)
        self.classifier = ClassificationHead(config, num_extra_dims)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        extra_data: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # sequence_output has shape (batch_size, seq_length, hidden_size)
        sequence_output = outputs[0]

        # Use the [CLS] token embedding as the pooled sentence representation.
        cls_embedding = sequence_output[:, 0, :]

        # extra_data must be (batch_size, num_extra_dims); although typed
        # Optional for signature compatibility, it is required here.
        combined_features = torch.cat((cls_embedding, extra_data), dim=-1)

        logits = self.classifier(combined_features)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
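
For reference, a minimal usage sketch (not part of this commit). The checkpoint name `microsoft/deberta-v3-base`, the feature width `num_extra_dims=3`, and `num_labels=2` are illustrative assumptions, not values from the repo; `from_pretrained` consumes `num_labels` via the config and forwards the unrecognised `num_extra_dims` keyword to `__init__` above:

import torch
from transformers import AutoTokenizer

from classifier import CustomSequenceClassification

# Illustrative checkpoint and dimensions -- adjust to your task.
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = CustomSequenceClassification.from_pretrained(
    "microsoft/deberta-v3-base", num_extra_dims=3, num_labels=2
)

inputs = tokenizer(["an example sentence"], return_tensors="pt")
extra_data = torch.randn(1, 3)  # (batch_size, num_extra_dims)
labels = torch.tensor([1])      # (batch_size,)

outputs = model(**inputs, extra_data=extra_data, labels=labels)
print(outputs.loss, outputs.logits.shape)  # scalar loss, torch.Size([1, 2])

The classification head is freshly initialized, so the logits are meaningless until the model is fine-tuned; only the DeBERTa encoder weights come from the checkpoint.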