Spaces:
Runtime error
Runtime error
update
Browse files- loss_metric.py +5 -8
loss_metric.py
CHANGED
@@ -16,7 +16,7 @@
|
|
16 |
import evaluate
|
17 |
import datasets
|
18 |
|
19 |
-
from torch import nn
|
20 |
|
21 |
_CITATION = """\
|
22 |
@InProceedings{huggingface:module,
|
@@ -55,7 +55,6 @@ class loss_metric(evaluate.Metric):
|
|
55 |
"""Calculation of the cross-entropy loss function using the huggingface evaluate module."""
|
56 |
|
57 |
def _info(self):
|
58 |
-
# TODO: Specifies the evaluate.EvaluationModuleInfo object
|
59 |
return evaluate.MetricInfo(
|
60 |
# This is the description that will appear on the modules page.
|
61 |
module_type="metric",
|
@@ -74,14 +73,12 @@ class loss_metric(evaluate.Metric):
|
|
74 |
reference_urls=["http://path.to.reference.url/new_module"]
|
75 |
)
|
76 |
|
77 |
-
def
|
78 |
-
"""Optional: download external resources useful to compute the scores"""
|
79 |
-
pass
|
80 |
-
|
81 |
-
def _compute(self, predictions, references):
|
82 |
"""Returns the scores"""
|
83 |
loss_func = nn.CrossEntropyLoss()
|
84 |
loss = loss_func(predictions, references)
|
|
|
|
|
85 |
return {
|
86 |
-
"loss":
|
87 |
}
|
|
|
16 |
import evaluate
|
17 |
import datasets
|
18 |
|
19 |
+
from torch import nn, Tensor
|
20 |
|
21 |
_CITATION = """\
|
22 |
@InProceedings{huggingface:module,
|
|
|
55 |
"""Calculation of the cross-entropy loss function using the huggingface evaluate module."""
|
56 |
|
57 |
def _info(self):
|
|
|
58 |
return evaluate.MetricInfo(
|
59 |
# This is the description that will appear on the modules page.
|
60 |
module_type="metric",
|
|
|
73 |
reference_urls=["http://path.to.reference.url/new_module"]
|
74 |
)
|
75 |
|
76 |
+
def _compute(self, predictions: Tensor, references: Tensor):
|
|
|
|
|
|
|
|
|
77 |
"""Returns the scores"""
|
78 |
loss_func = nn.CrossEntropyLoss()
|
79 |
loss = loss_func(predictions, references)
|
80 |
+
|
81 |
+
mean_loss = loss.item() / references.shape[0]
|
82 |
return {
|
83 |
+
"loss": mean_loss,
|
84 |
}
|