Training in progress, epoch 1
Browse files- pytorch_model.bin +1 -1
- training_args.bin +1 -1
- utils.py +9 -2
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 274752173
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a8d0b025faa8c33ff9e377a0ecb3a8066849218e66d79dabdb9c9e1289d3a014
|
3 |
size 274752173
|
training_args.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 4027
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b65d405be15213e29d93bcec6b150b9a9f69c52594a891222bfb4fa877f3ab74
|
3 |
size 4027
|
utils.py
CHANGED
@@ -57,11 +57,14 @@ class ConsistentSentenceRegressor(DistilBertForSequenceClassification):
|
|
57 |
# Replace the classifier with a single-neuron linear layer for regression
|
58 |
self.classifier = torch.nn.Linear(config.dim, config.num_labels)
|
59 |
|
|
|
|
|
60 |
if not freeze_bert:
|
61 |
return
|
62 |
|
63 |
for param in self.distilbert.parameters():
|
64 |
param.requires_grad = False
|
|
|
65 |
|
66 |
def forward(
|
67 |
self,
|
@@ -86,9 +89,13 @@ class ConsistentSentenceRegressor(DistilBertForSequenceClassification):
|
|
86 |
)
|
87 |
|
88 |
logits = outputs.logits.squeeze(-1) # Remove the last dimension to match target tensor shape
|
89 |
-
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
return
|
92 |
|
93 |
|
94 |
# Set up evaluation metridef get_metrics():
|
|
|
57 |
# Replace the classifier with a single-neuron linear layer for regression
|
58 |
self.classifier = torch.nn.Linear(config.dim, config.num_labels)
|
59 |
|
60 |
+
self.loss_fn = torch.nn.MSELoss()
|
61 |
+
|
62 |
if not freeze_bert:
|
63 |
return
|
64 |
|
65 |
for param in self.distilbert.parameters():
|
66 |
param.requires_grad = False
|
67 |
+
|
68 |
|
69 |
def forward(
|
70 |
self,
|
|
|
89 |
)
|
90 |
|
91 |
logits = outputs.logits.squeeze(-1) # Remove the last dimension to match target tensor shape
|
92 |
+
outputs.logits = logits
|
93 |
+
if labels is not None:
|
94 |
+
# Compute custom loss
|
95 |
+
loss = self.loss_fn(logits, labels)
|
96 |
+
outputs.loss = loss
|
97 |
|
98 |
+
return outputs
|
99 |
|
100 |
|
101 |
# Set up evaluation metridef get_metrics():
|