gsar78 committed
Commit 07f7aae · verified · 1 Parent(s): 1d7a81c

Create custom_model.py

custom_model_package/custom_model.py ADDED
@@ -0,0 +1,58 @@
+from transformers import PretrainedConfig, XLMRobertaForSequenceClassification
+import torch.nn as nn
+import torch
+
+class CustomConfig(PretrainedConfig):
+    model_type = "custom_model"
+
+    def __init__(self, num_emotion_labels=18, **kwargs):
+        super().__init__(**kwargs)
+        self.num_emotion_labels = num_emotion_labels
+
+class CustomModel(XLMRobertaForSequenceClassification):
+    config_class = CustomConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_emotion_labels = config.num_emotion_labels
+        self.dropout_emotion = nn.Dropout(config.hidden_dropout_prob)
+        # Extra head for multi-label emotion classification on the CLS token
+        self.emotion_classifier = nn.Sequential(
+            nn.Linear(config.hidden_size, 512),
+            nn.Mish(),
+            nn.Dropout(0.3),
+            nn.Linear(512, self.num_emotion_labels)
+        )
+        self._init_weights(self.emotion_classifier[0])
+        self._init_weights(self.emotion_classifier[3])
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def forward(self, input_ids=None, attention_mask=None, sentiment=None, labels=None):
+        # `sentiment` is accepted for batch compatibility but is not used here
+        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
+        sequence_output = outputs[0]
+        if len(sequence_output.shape) != 3:
+            raise ValueError(f"Expected sequence_output to have 3 dimensions, got {sequence_output.shape}")
+        cls_hidden_states = sequence_output[:, 0, :]
+        cls_hidden_states = self.dropout_emotion(cls_hidden_states)
+        emotion_logits = self.emotion_classifier(cls_hidden_states)
+
+        # Sentiment logits from the original classification head, which expects the
+        # full sequence output and extracts the CLS token itself
+        sentiment_logits = self.classifier(sequence_output)
+
+        # Concatenate the sentiment and emotion logits
+        logits = torch.cat([sentiment_logits, emotion_logits], dim=-1)
+
+        if labels is not None:
+            # Uniform pos_weight (all ones) keeps BCE unweighted; adjust per class if needed
+            class_weights = torch.ones(self.num_emotion_labels, device=labels.device)
+            loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
+            loss = loss_fct(emotion_logits, labels.float())  # BCE expects float targets
+            return {"loss": loss, "logits": logits}
+        return {"logits": logits}
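
For context, a minimal usage sketch (not part of this commit), assuming the file lives at custom_model_package/custom_model.py as added above; the checkpoint name xlm-roberta-base and both label counts are illustrative placeholders:

# Minimal usage sketch; names below are assumptions, not part of the commit.
from transformers import AutoConfig, AutoModel, AutoTokenizer
from custom_model_package.custom_model import CustomConfig, CustomModel
import torch

# Register the custom classes so the Auto* factories can resolve model_type "custom_model"
AutoConfig.register("custom_model", CustomConfig)
AutoModel.register(CustomConfig, CustomModel)

# Start from an XLM-R checkpoint; the emotion head is freshly initialized
config = CustomConfig.from_pretrained("xlm-roberta-base", num_labels=3, num_emotion_labels=18)
model = CustomModel.from_pretrained("xlm-roberta-base", config=config)

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
batch = tokenizer(["I love this!"], return_tensors="pt")

# Multi-hot float targets for the 18 emotion labels, as BCEWithLogitsLoss expects
labels = torch.zeros(1, config.num_emotion_labels)
labels[0, 2] = 1.0

out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=labels)
print(out["logits"].shape)  # torch.Size([1, 21]) -> 3 sentiment + 18 emotion logits
print(out["loss"])

Returning a dict with "loss" and "logits" keeps the model compatible with the Hugging Face Trainer, which reads the "loss" key from the model output when labels are supplied.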