Aye10032 commited on
Commit
1d6fdc1
·
1 Parent(s): bcb7c5c
Files changed (1) hide show
  1. loss_metric.py +5 -8
loss_metric.py CHANGED
@@ -16,7 +16,7 @@
16
  import evaluate
17
  import datasets
18
 
19
- from torch import nn
20
 
21
  _CITATION = """\
22
  @InProceedings{huggingface:module,
@@ -55,7 +55,6 @@ class loss_metric(evaluate.Metric):
55
  """Calculation of the cross-entropy loss function using the huggingface evaluate module."""
56
 
57
  def _info(self):
58
- # TODO: Specifies the evaluate.EvaluationModuleInfo object
59
  return evaluate.MetricInfo(
60
  # This is the description that will appear on the modules page.
61
  module_type="metric",
@@ -74,14 +73,12 @@ class loss_metric(evaluate.Metric):
74
  reference_urls=["http://path.to.reference.url/new_module"]
75
  )
76
 
77
- def _download_and_prepare(self, dl_manager):
78
- """Optional: download external resources useful to compute the scores"""
79
- pass
80
-
81
- def _compute(self, predictions, references):
82
  """Returns the scores"""
83
  loss_func = nn.CrossEntropyLoss()
84
  loss = loss_func(predictions, references)
 
 
85
  return {
86
- "loss": loss,
87
  }
 
16
  import evaluate
17
  import datasets
18
 
19
+ from torch import nn, Tensor
20
 
21
  _CITATION = """\
22
  @InProceedings{huggingface:module,
 
55
  """Calculation of the cross-entropy loss function using the huggingface evaluate module."""
56
 
57
  def _info(self):
 
58
  return evaluate.MetricInfo(
59
  # This is the description that will appear on the modules page.
60
  module_type="metric",
 
73
  reference_urls=["http://path.to.reference.url/new_module"]
74
  )
75
 
76
+ def _compute(self, predictions: Tensor, references: Tensor):
 
 
 
 
77
  """Returns the scores"""
78
  loss_func = nn.CrossEntropyLoss()
79
  loss = loss_func(predictions, references)
80
+
81
+ mean_loss = loss.item() / references.shape[0]
82
  return {
83
+ "loss": mean_loss,
84
  }