gsar78 commited on
Commit
528fc07
·
verified ·
1 Parent(s): 3588c34

Update custom_model_package/custom_model.py

Browse files
custom_model_package/custom_model.py CHANGED
@@ -42,13 +42,13 @@ class CustomModel(XLMRobertaForSequenceClassification):
42
  with torch.no_grad():
43
  cls_token_state = sequence_output[:, 0, :].unsqueeze(1)
44
  sentiment_logits = self.classifier(cls_token_state).squeeze(1)
45
- logits = torch.cat([sentiment_logits, emotion_logits], dim=-1)
46
  if labels is not None:
47
  class_weights = torch.tensor([1.0] * self.num_emotion_labels).to(labels.device)
48
  loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
49
  loss = loss_fct(emotion_logits, labels)
50
- return {"loss": loss, "logits": logits}
51
- return {"logits": logits}
52
 
53
  @classmethod
54
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
 
42
  with torch.no_grad():
43
  cls_token_state = sequence_output[:, 0, :].unsqueeze(1)
44
  sentiment_logits = self.classifier(cls_token_state).squeeze(1)
45
+ #logits = torch.cat([sentiment_logits, emotion_logits], dim=-1)
46
  if labels is not None:
47
  class_weights = torch.tensor([1.0] * self.num_emotion_labels).to(labels.device)
48
  loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
49
  loss = loss_fct(emotion_logits, labels)
50
+ return {"loss": loss, "emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
51
+ return {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
52
 
53
  @classmethod
54
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):