mtzig commited on
Commit
9b45d91
1 Parent(s): 61bf5a6

updates format

Browse files
Files changed (1) hide show
  1. cross_entropy_loss.py +2 -2
cross_entropy_loss.py CHANGED
@@ -63,8 +63,8 @@ class cross_entropy_loss(evaluate.Metric):
63
  def _compute(self, prediction_scores, references):
64
  """Returns the scores"""
65
 
66
- loss = F.cross_entropy(input=torch.from_numpy(prediction_scores).flatten(start_dim=0, end_dim=1),
67
- target=torch.from_numpy(references).flatten(),
68
  ignore_index=-100).item()
69
  return {
70
  "cross_entropy_loss": loss
 
63
  def _compute(self, prediction_scores, references):
64
  """Returns the scores"""
65
 
66
+ loss = F.cross_entropy(input=torch.from_numpy(prediction_scores),
67
+ target=torch.from_numpy(references),
68
  ignore_index=-100).item()
69
  return {
70
  "cross_entropy_loss": loss