Ruchin committed on
Commit
3addf82
1 Parent(s): fd1c40b

Added handling of multi-dimensional (multilabel) inputs

Browse files
Files changed (1) hide show
  1. jaccard_similarity.py +18 -12
jaccard_similarity.py CHANGED
@@ -77,17 +77,10 @@ class JaccardSimilarity(evaluate.Metric):
77
  description=_DESCRIPTION,
78
  citation=_CITATION,
79
  inputs_description=_KWARGS_DESCRIPTION,
80
- features=datasets.Features(
81
- {
82
- "predictions": datasets.Sequence(datasets.Value("int32")),
83
- "references": datasets.Sequence(datasets.Value("int32")),
84
- }
85
- if self.config_name == "multilabel"
86
- else {
87
- "predictions": datasets.Value("int32"),
88
- "references": datasets.Value("int32"),
89
- }
90
- ),
91
  reference_urls=[
92
  "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html",
93
  "https://en.wikipedia.org/wiki/Jaccard_index"
@@ -95,7 +88,20 @@ class JaccardSimilarity(evaluate.Metric):
95
  )
96
 
97
  def _compute(self, predictions, references, labels=None, pos_label=1, average='binary', sample_weight=None, zero_division='warn'):
98
- """Returns the Jaccard similarity score using scikit-learn"""
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  return {
100
  "jaccard_similarity": jaccard_score(
101
  references,
 
77
  description=_DESCRIPTION,
78
  citation=_CITATION,
79
  inputs_description=_KWARGS_DESCRIPTION,
80
+ features=datasets.Features({
81
+ "predictions": datasets.Sequence(datasets.Value("int32")),
82
+ "references": datasets.Sequence(datasets.Value("int32")),
83
+ }),
 
 
 
 
 
 
 
84
  reference_urls=[
85
  "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html",
86
  "https://en.wikipedia.org/wiki/Jaccard_index"
 
88
  )
89
 
90
  def _compute(self, predictions, references, labels=None, pos_label=1, average='binary', sample_weight=None, zero_division='warn'):
91
+ predictions = np.array(predictions)
92
+ references = np.array(references)
93
+
94
+ # Handle different input shapes
95
+ if predictions.ndim == 1 and references.ndim == 1:
96
+ # Binary or multiclass case
97
+ pass
98
+ elif predictions.ndim == 2 and references.ndim == 2:
99
+ # Multilabel case
100
+ if average == 'binary':
101
+ average = 'micro' # 'binary' doesn't make sense for multilabel
102
+ else:
103
+ raise ValueError("Predictions and references should have the same shape")
104
+
105
  return {
106
  "jaccard_similarity": jaccard_score(
107
  references,