Spaces:
Runtime error
Runtime error
DarrenChensformer
commited on
Commit
Β·
749b801
1
Parent(s):
ba6a59b
Fix nothing return
Browse files- relation_extraction.py +9 -7
relation_extraction.py
CHANGED
@@ -87,7 +87,7 @@ class relation_extraction(evaluate.Metric):
|
|
87 |
# TODO: Download external resources if needed
|
88 |
pass
|
89 |
|
90 |
-
def _compute(self,
|
91 |
"""Returns the scores"""
|
92 |
# TODO: Compute the different scores of the module
|
93 |
|
@@ -95,7 +95,7 @@ class relation_extraction(evaluate.Metric):
|
|
95 |
|
96 |
# construct relation_types from ground truth if not given
|
97 |
if len(relation_types) == 0:
|
98 |
-
for triplets in
|
99 |
for triplet in triplets:
|
100 |
relation = triplet["type"]
|
101 |
if relation not in relation_types:
|
@@ -104,12 +104,12 @@ class relation_extraction(evaluate.Metric):
|
|
104 |
scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
|
105 |
|
106 |
# Count GT relations and Predicted relations
|
107 |
-
n_sents = len(
|
108 |
-
n_rels = sum([len([rel for rel in sent]) for sent in
|
109 |
-
n_found = sum([len([rel for rel in sent]) for sent in
|
110 |
|
111 |
# Count TP, FP and FN per type
|
112 |
-
for pred_sent, gt_sent in zip(
|
113 |
for rel_type in relation_types:
|
114 |
# strict mode takes argument types into account
|
115 |
if mode == "strict":
|
@@ -164,4 +164,6 @@ class relation_extraction(evaluate.Metric):
|
|
164 |
# Compute Macro F1 Scores
|
165 |
scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in relation_types])
|
166 |
scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in relation_types])
|
167 |
-
scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in relation_types])
|
|
|
|
|
|
87 |
# TODO: Download external resources if needed
|
88 |
pass
|
89 |
|
90 |
+
def _compute(self, predictions, references, mode="strict", relation_types=[]):
|
91 |
"""Returns the scores"""
|
92 |
# TODO: Compute the different scores of the module
|
93 |
|
|
|
95 |
|
96 |
# construct relation_types from ground truth if not given
|
97 |
if len(relation_types) == 0:
|
98 |
+
for triplets in references:
|
99 |
for triplet in triplets:
|
100 |
relation = triplet["type"]
|
101 |
if relation not in relation_types:
|
|
|
104 |
scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
|
105 |
|
106 |
# Count GT relations and Predicted relations
|
107 |
+
n_sents = len(references)
|
108 |
+
n_rels = sum([len([rel for rel in sent]) for sent in references])
|
109 |
+
n_found = sum([len([rel for rel in sent]) for sent in predictions])
|
110 |
|
111 |
# Count TP, FP and FN per type
|
112 |
+
for pred_sent, gt_sent in zip(predictions, references):
|
113 |
for rel_type in relation_types:
|
114 |
# strict mode takes argument types into account
|
115 |
if mode == "strict":
|
|
|
164 |
# Compute Macro F1 Scores
|
165 |
scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in relation_types])
|
166 |
scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in relation_types])
|
167 |
+
scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in relation_types])
|
168 |
+
|
169 |
+
return scores
|