Aye10032 commited on
Commit
c81d7c7
·
1 Parent(s): 629cd1c
Files changed (1) hide show
  1. loss_metric.py +7 -4
loss_metric.py CHANGED
@@ -15,8 +15,10 @@
15
 
16
  import evaluate
17
  import datasets
 
 
18
 
19
- from torch import nn, Tensor
20
 
21
  _CITATION = """\
22
  @InProceedings{huggingface:module,
@@ -75,11 +77,12 @@ class LossMetric(evaluate.Metric):
75
 
76
  def _compute(self, predictions, references):
77
  """Returns the scores"""
78
- print(predictions)
 
79
  loss_func = nn.CrossEntropyLoss()
80
- loss = loss_func(predictions, references)
81
 
82
- mean_loss = loss.item() / references.shape[0]
83
  return {
84
  "loss": mean_loss,
85
  }
 
15
 
16
  import evaluate
17
  import datasets
18
+ import numpy as np
19
+ import torch
20
 
21
+ from torch import nn, Tensor, tensor
22
 
23
  _CITATION = """\
24
  @InProceedings{huggingface:module,
 
77
 
78
  def _compute(self, predictions, references):
79
  """Returns the scores"""
80
+ pred = tensor(np.array(predictions), dtype=torch.float16)
81
+ label = tensor(np.array(references), dtype=torch.float16)
82
  loss_func = nn.CrossEntropyLoss()
83
+ loss = loss_func(pred, label)
84
 
85
+ mean_loss = loss.item() / label.shape[0]
86
  return {
87
  "loss": mean_loss,
88
  }