Spaces:
Runtime error
Runtime error
davebulaval
commited on
Commit
Β·
46c63a3
1
Parent(s):
008aa62
add fix for matching elements
Browse files- meaningbert.py +6 -0
meaningbert.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14 |
""" MeaningBERT metric. """
|
15 |
|
16 |
from contextlib import contextmanager
|
|
|
17 |
from typing import List, Dict
|
18 |
|
19 |
import datasets
|
@@ -118,6 +119,7 @@ class MeaningBERT(evaluate.Metric):
|
|
118 |
), "The number of document is different of the number of simplifications."
|
119 |
hashcode = _HASH
|
120 |
|
|
|
121 |
matching_index = [
|
122 |
i for i, item in enumerate(documents) if item in simplifications
|
123 |
]
|
@@ -146,6 +148,10 @@ class MeaningBERT(evaluate.Metric):
|
|
146 |
|
147 |
scores = scores.logits.tolist()
|
148 |
|
|
|
|
|
|
|
|
|
149 |
if len(matching_index) > 0:
|
150 |
for matching_element_index in matching_index:
|
151 |
scores[matching_element_index] = 100
|
|
|
14 |
""" MeaningBERT metric. """
|
15 |
|
16 |
from contextlib import contextmanager
|
17 |
+
from itertools import chain
|
18 |
from typing import List, Dict
|
19 |
|
20 |
import datasets
|
|
|
119 |
), "The number of document is different of the number of simplifications."
|
120 |
hashcode = _HASH
|
121 |
|
122 |
+
# Index of sentence with perfect match between two sentences
|
123 |
matching_index = [
|
124 |
i for i, item in enumerate(documents) if item in simplifications
|
125 |
]
|
|
|
148 |
|
149 |
scores = scores.logits.tolist()
|
150 |
|
151 |
+
# Flatten the list of list of logits
|
152 |
+
scores = list(chain(*scores))
|
153 |
+
|
154 |
+
# Handle case of perfect match
|
155 |
if len(matching_index) > 0:
|
156 |
for matching_element_index in matching_index:
|
157 |
scores[matching_element_index] = 100
|