liwii committed on
Commit
9cc51df
·
verified ·
1 Parent(s): 21982ea

Training in progress, epoch 1

Browse files
Files changed (3) hide show
  1. pytorch_model.bin +1 -1
  2. training_args.bin +1 -1
  3. utils.py +9 -2
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d736705fafe165d048f591207c7e984739e7412080ec89b015a457b5084baba6
3
  size 274752173
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8d0b025faa8c33ff9e377a0ecb3a8066849218e66d79dabdb9c9e1289d3a014
3
  size 274752173
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cc65dbe5e03af7e214c614c19b0fda5a22554b261cb74cde5f58af6a97861002
3
  size 4027
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b65d405be15213e29d93bcec6b150b9a9f69c52594a891222bfb4fa877f3ab74
3
  size 4027
utils.py CHANGED
@@ -57,11 +57,14 @@ class ConsistentSentenceRegressor(DistilBertForSequenceClassification):
57
  # Replace the classifier with a single-neuron linear layer for regression
58
  self.classifier = torch.nn.Linear(config.dim, config.num_labels)
59
 
 
 
60
  if not freeze_bert:
61
  return
62
 
63
  for param in self.distilbert.parameters():
64
  param.requires_grad = False
 
65
 
66
  def forward(
67
  self,
@@ -86,9 +89,13 @@ class ConsistentSentenceRegressor(DistilBertForSequenceClassification):
86
  )
87
 
88
  logits = outputs.logits.squeeze(-1) # Remove the last dimension to match target tensor shape
89
-
 
 
 
 
90
 
91
- return logits
92
 
93
 
94
  # Set up evaluation metridef get_metrics():
 
57
  # Replace the classifier with a single-neuron linear layer for regression
58
  self.classifier = torch.nn.Linear(config.dim, config.num_labels)
59
 
60
+ self.loss_fn = torch.nn.MSELoss()
61
+
62
  if not freeze_bert:
63
  return
64
 
65
  for param in self.distilbert.parameters():
66
  param.requires_grad = False
67
+
68
 
69
  def forward(
70
  self,
 
89
  )
90
 
91
  logits = outputs.logits.squeeze(-1) # Remove the last dimension to match target tensor shape
92
+ outputs.logits = logits
93
+ if labels is not None:
94
+ # Compute custom loss
95
+ loss = self.loss_fn(logits, labels)
96
+ outputs.loss = loss
97
 
98
+ return outputs
99
 
100
 
101
  # Set up evaluation metridef get_metrics():