bdsaglam committed
Commit e31d84c · 1 Parent(s): 4097b95

update jer metric to add equality operation argument

Files changed (1)
  1. jer.py +22 -16
jer.py CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 """TODO: Add a description here."""
 
+from operator import eq
 from typing import Iterable
 
 import evaluate
@@ -43,17 +44,16 @@ Args:
         should be a string with tokens separated by spaces.
     references: list of reference for each prediction. Each
         reference should be a string with tokens separated by spaces.
+    eq_fn: function to compare two items. Defaults to the equality operator.
 Returns:
-    accuracy: description of the first score,
-    another_score: description of the second score,
+    recall:
+    precision:
+    f1:
 Examples:
-    Examples should be written in doctest format, and should illustrate how
-    to use the function.
-
-    >>> my_new_module = evaluate.load("my_new_module")
-    >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
+    >>> jer = evaluate.load("jer")
+    >>> results = jer.compute(references=[["Baris | play | tennis", "Deniz | travel | London"]], predictions=[["Baris | play | tennis"]])
     >>> print(results)
-    {'accuracy': 1.0}
+    {'recall': 0.5, 'precision': 1.0, 'f1': 0.6666666666666666}
 """
 
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
@@ -85,7 +85,7 @@ class jer(evaluate.Metric):
         # TODO: Download external resources if needed
         pass
 
-    def _compute(self, predictions, references):
+    def _compute(self, predictions, references, eq_fn=eq):
         """Returns the scores"""
         score_dicts = [
             self._compute_single(prediction=prediction, reference=reference)
@@ -93,22 +93,28 @@ class jer(evaluate.Metric):
         ]
         return {('mean_' + key): np.mean([scores[key] for scores in score_dicts]) for key in score_dicts[0].keys()}
 
-    def _compute_single(self, *, prediction: Iterable[str | tuple | int], reference: Iterable[str | tuple | int]):
+    def _compute_single(self, *, prediction: Iterable[str | tuple | int], reference: Iterable[str | tuple | int], eq_fn=eq):
         reference_set = set(reference)
         assert len(reference) == len(reference_set), f"Duplicates found in the reference list {reference}"
         prediction_set = set(prediction)
 
-        TP = len(reference_set & prediction_set)
-        FP = len(prediction_set - reference_set)
-        FN = len(reference_set - prediction_set)
+        tp = sum(int(is_in(item, prediction, eq_fn=eq_fn)) for item in reference)
+        fp = len(prediction_set) - tp
+        fn = len(reference_set) - tp
 
         # Calculate metrics
-        precision = TP / (TP + FP) if TP + FP > 0 else 0
-        recall = TP / (TP + FN) if TP + FN > 0 else 0
+        precision = tp / (tp + fp) if tp + fp > 0 else 0
+        recall = tp / (tp + fn) if tp + fn > 0 else 0
         f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
 
         return {
             'precision': precision,
             'recall': recall,
             'f1': f1_score
-        }
+        }
+
+def is_in(target, collection: Iterable, eq_fn=eq) -> bool:
+    for item in collection:
+        if eq_fn(item, target):
+            return True
+    return False
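The doctest above only exercises the default comparison, so here is a minimal, hypothetical sketch of how the new eq_fn argument could be used with a looser comparison of "subject | relation | object" triples. The helper normalized_eq is invented for illustration, and the sketch assumes that extra keyword arguments to compute() reach _compute and that eq_fn is then forwarded on to _compute_single (the hunk above changes the _compute signature but leaves the _compute_single call shown as unchanged context).

    # Hypothetical usage sketch, not part of this commit.
    import evaluate

    def normalized_eq(a: str, b: str) -> bool:
        # Compare "subject | relation | object" triples field by field,
        # ignoring case and surrounding whitespace.
        def parts(s: str) -> list[str]:
            return [p.strip().lower() for p in s.split("|")]
        return parts(a) == parts(b)

    jer = evaluate.load("jer")
    # Assumes eq_fn is threaded through from _compute to _compute_single.
    results = jer.compute(
        references=[["Baris | play | tennis", "Deniz | travel | London"]],
        predictions=[["baris | PLAY | Tennis"]],
        eq_fn=normalized_eq,
    )
    print(results)

With the default eq, the numbers in the doctest follow directly from the new counting: the single predicted triple matches one of the two references, so tp = 1, fp = 0, fn = 1, giving precision = 1 / (1 + 0) = 1.0, recall = 1 / (1 + 1) = 0.5 and f1 = 2 * (1.0 * 0.5) / (1.0 + 0.5) ≈ 0.667.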