gsar78 committed on
Commit
6388076
·
verified ·
1 Parent(s): 7f88877

Update modeling_custom.py

Browse files
Files changed (1) hide show
  1. modeling_custom.py +15 -19
modeling_custom.py CHANGED
@@ -1,26 +1,27 @@
1
- from transformers import XLMRobertaForSequenceClassification
2
  import torch.nn as nn
3
  import torch
4
 
5
- class CustomModel(XLMRobertaForSequenceClassification):
6
- def __init__(self, config, num_emotion_labels):
7
- super(CustomModel, self).__init__(config)
 
 
8
  self.num_emotion_labels = num_emotion_labels
9
 
10
- # Freeze sentiment classifier parameters
11
- for param in self.classifier.parameters():
12
- param.requires_grad = False
13
-
14
- # Define emotion classifier
 
15
  self.dropout_emotion = nn.Dropout(config.hidden_dropout_prob)
16
  self.emotion_classifier = nn.Sequential(
17
  nn.Linear(config.hidden_size, 512),
18
  nn.Mish(),
19
  nn.Dropout(0.3),
20
- nn.Linear(512, num_emotion_labels)
21
  )
22
-
23
- # Initialize the weights of the new layers
24
  self._init_weights(self.emotion_classifier[0])
25
  self._init_weights(self.emotion_classifier[3])
26
 
@@ -33,22 +34,17 @@ class CustomModel(XLMRobertaForSequenceClassification):
def forward(self, input_ids=None, attention_mask=None, sentiment=None, labels=None):
    """Compute emotion logits and sentiment logits for a batch.

    Args:
        input_ids: Passed through to ``self.roberta``.
        attention_mask: Passed through to ``self.roberta``.
        sentiment: Accepted but not read by this method.
        labels: Optional targets handed to ``BCEWithLogitsLoss`` together
            with the emotion logits.

    Returns:
        A dict with ``"emotion_logits"`` and ``"sentiment_logits"``; when
        ``labels`` is given, a ``"loss"`` entry comes first.
    """
    encoder_out = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
    hidden = encoder_out[0]

    # Position 0 along the second axis feeds both heads.
    first_token = hidden[:, 0, :]
    emotion_logits = self.emotion_classifier(self.dropout_emotion(first_token))

    # The sentiment head runs without gradient tracking.
    with torch.no_grad():
        sentiment_logits = self.classifier(first_token.unsqueeze(1)).squeeze(1)

    if labels is None:
        return {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}

    # Uniform per-label weights, placed on the same device as the labels.
    weights = torch.tensor([1.0] * self.num_emotion_labels).to(labels.device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=weights)
    loss = criterion(emotion_logits, labels)
    return {"loss": loss, "emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
 
1
+ from transformers import PretrainedConfig, XLMRobertaForSequenceClassification
2
  import torch.nn as nn
3
  import torch
4
 
5
+ class CustomConfig(PretrainedConfig):
6
+ model_type = "custom_model"
7
+
8
+ def __init__(self, num_emotion_labels=18, **kwargs):
9
+ super().__init__(**kwargs)
10
  self.num_emotion_labels = num_emotion_labels
11
 
12
+ class CustomModel(XLMRobertaForSequenceClassification):
13
+ config_class = CustomConfig
14
+
15
+ def __init__(self, config):
16
+ super(CustomModel, self).__init__(config)
17
+ self.num_emotion_labels = config.num_emotion_labels
18
  self.dropout_emotion = nn.Dropout(config.hidden_dropout_prob)
19
  self.emotion_classifier = nn.Sequential(
20
  nn.Linear(config.hidden_size, 512),
21
  nn.Mish(),
22
  nn.Dropout(0.3),
23
+ nn.Linear(512, self.num_emotion_labels)
24
  )
 
 
25
  self._init_weights(self.emotion_classifier[0])
26
  self._init_weights(self.emotion_classifier[3])
27
 
 
def forward(self, input_ids=None, attention_mask=None, sentiment=None, labels=None):
    """Compute emotion logits and sentiment logits for a batch.

    Args:
        input_ids: Passed through to ``self.roberta``.
        attention_mask: Passed through to ``self.roberta``.
        sentiment: Accepted for interface compatibility; not read here.
        labels: Optional multi-label targets with the same shape as the
            emotion logits. Any numeric or boolean dtype is accepted; the
            values are cast to the logits' dtype and device before the
            loss is computed.

    Returns:
        A dict with ``"emotion_logits"`` and ``"sentiment_logits"``; when
        ``labels`` is given, a ``"loss"`` entry comes first.

    Raises:
        ValueError: If the first encoder output is not 3-dimensional.
    """
    outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
    sequence_output = outputs[0]
    if sequence_output.dim() != 3:
        raise ValueError(f"Expected sequence_output to have 3 dimensions, got {sequence_output.shape}")

    # Position 0 along the second axis feeds both heads.
    cls_state = sequence_output[:, 0, :]
    emotion_logits = self.emotion_classifier(self.dropout_emotion(cls_state))

    # The sentiment head runs without gradient tracking. The length-1 axis
    # is re-inserted before the call and squeezed afterwards.
    # NOTE(review): presumably the head indexes position 0 itself - confirm
    # against the classification head implementation.
    with torch.no_grad():
        sentiment_logits = self.classifier(cls_state.unsqueeze(1)).squeeze(1)

    if labels is None:
        return {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}

    # BCEWithLogitsLoss needs floating-point targets on the logits' device,
    # so integer/bool multi-hot labels are cast instead of raising.
    targets = labels.to(device=emotion_logits.device, dtype=emotion_logits.dtype)
    # Uniform per-label weights, created directly with the logits' dtype
    # and device so no host-to-device copy or dtype mismatch occurs.
    pos_weight = torch.ones(
        self.num_emotion_labels,
        dtype=emotion_logits.dtype,
        device=emotion_logits.device,
    )
    loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(emotion_logits, targets)
    return {"loss": loss, "emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}