Spaces:
Sleeping
Sleeping
Added handling of multi-dimensional inputs
Browse files
- jaccard_similarity.py +18 -12
jaccard_similarity.py
CHANGED
@@ -77,17 +77,10 @@ class JaccardSimilarity(evaluate.Metric):
|
|
77 |
description=_DESCRIPTION,
|
78 |
citation=_CITATION,
|
79 |
inputs_description=_KWARGS_DESCRIPTION,
|
80 |
-
features=datasets.Features(
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
}
|
85 |
-
if self.config_name == "multilabel"
|
86 |
-
else {
|
87 |
-
"predictions": datasets.Value("int32"),
|
88 |
-
"references": datasets.Value("int32"),
|
89 |
-
}
|
90 |
-
),
|
91 |
reference_urls=[
|
92 |
"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html",
|
93 |
"https://en.wikipedia.org/wiki/Jaccard_index"
|
@@ -95,7 +88,20 @@ class JaccardSimilarity(evaluate.Metric):
|
|
95 |
)
|
96 |
|
97 |
def _compute(self, predictions, references, labels=None, pos_label=1, average='binary', sample_weight=None, zero_division='warn'):
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
return {
|
100 |
"jaccard_similarity": jaccard_score(
|
101 |
references,
|
|
|
77 |
description=_DESCRIPTION,
|
78 |
citation=_CITATION,
|
79 |
inputs_description=_KWARGS_DESCRIPTION,
|
80 |
+
features=datasets.Features({
|
81 |
+
"predictions": datasets.Sequence(datasets.Value("int32")),
|
82 |
+
"references": datasets.Sequence(datasets.Value("int32")),
|
83 |
+
}),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
reference_urls=[
|
85 |
"https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html",
|
86 |
"https://en.wikipedia.org/wiki/Jaccard_index"
|
|
|
88 |
)
|
89 |
|
90 |
def _compute(self, predictions, references, labels=None, pos_label=1, average='binary', sample_weight=None, zero_division='warn'):
|
91 |
+
predictions = np.array(predictions)
|
92 |
+
references = np.array(references)
|
93 |
+
|
94 |
+
# Handle different input shapes
|
95 |
+
if predictions.ndim == 1 and references.ndim == 1:
|
96 |
+
# Binary or multiclass case
|
97 |
+
pass
|
98 |
+
elif predictions.ndim == 2 and references.ndim == 2:
|
99 |
+
# Multilabel case
|
100 |
+
if average == 'binary':
|
101 |
+
average = 'micro' # 'binary' doesn't make sense for multilabel
|
102 |
+
else:
|
103 |
+
raise ValueError("Predictions and references should have the same shape")
|
104 |
+
|
105 |
return {
|
106 |
"jaccard_similarity": jaccard_score(
|
107 |
references,
|