Spaces:
Runtime error
Runtime error
update
Browse files- loss_metric.py +7 -4
loss_metric.py
CHANGED
@@ -15,8 +15,10 @@
|
|
15 |
|
16 |
import evaluate
|
17 |
import datasets
|
|
|
|
|
18 |
|
19 |
-
from torch import nn, Tensor
|
20 |
|
21 |
_CITATION = """\
|
22 |
@InProceedings{huggingface:module,
|
@@ -75,11 +77,12 @@ class LossMetric(evaluate.Metric):
|
|
75 |
|
76 |
def _compute(self, predictions, references):
|
77 |
"""Returns the scores"""
|
78 |
-
|
|
|
79 |
loss_func = nn.CrossEntropyLoss()
|
80 |
-
loss = loss_func(
|
81 |
|
82 |
-
mean_loss = loss.item() /
|
83 |
return {
|
84 |
"loss": mean_loss,
|
85 |
}
|
|
|
15 |
|
16 |
import evaluate
|
17 |
import datasets
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
|
21 |
+
from torch import nn, Tensor, tensor
|
22 |
|
23 |
_CITATION = """\
|
24 |
@InProceedings{huggingface:module,
|
|
|
77 |
|
78 |
def _compute(self, predictions, references):
|
79 |
"""Returns the scores"""
|
80 |
+
pred = tensor(np.array(predictions), dtype=torch.float16)
|
81 |
+
label = tensor(np.array(references), dtype=torch.float16)
|
82 |
loss_func = nn.CrossEntropyLoss()
|
83 |
+
loss = loss_func(pred, label)
|
84 |
|
85 |
+
mean_loss = loss.item() / label.shape[0]
|
86 |
return {
|
87 |
"loss": mean_loss,
|
88 |
}
|