Spaces:
Runtime error
Runtime error
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import numpy as np | |
import paddle | |
__all__ = ['KIEMetric'] | |
class VQAReTokenMetric(object): | |
def __init__(self, main_indicator='hmean', **kwargs): | |
self.main_indicator = main_indicator | |
self.reset() | |
def __call__(self, preds, batch, **kwargs): | |
pred_relations, relations, entities = preds | |
self.pred_relations_list.extend(pred_relations) | |
self.relations_list.extend(relations) | |
self.entities_list.extend(entities) | |
def get_metric(self): | |
gt_relations = [] | |
for b in range(len(self.relations_list)): | |
rel_sent = [] | |
relation_list = self.relations_list[b] | |
entitie_list = self.entities_list[b] | |
head_len = relation_list[0, 0] | |
if head_len > 0: | |
entitie_start_list = entitie_list[1:entitie_list[0, 0] + 1, 0] | |
entitie_end_list = entitie_list[1:entitie_list[0, 1] + 1, 1] | |
entitie_label_list = entitie_list[1:entitie_list[0, 2] + 1, 2] | |
for head, tail in zip(relation_list[1:head_len + 1, 0], | |
relation_list[1:head_len + 1, 1]): | |
rel = {} | |
rel["head_id"] = head | |
rel["head"] = (entitie_start_list[head], | |
entitie_end_list[head]) | |
rel["head_type"] = entitie_label_list[head] | |
rel["tail_id"] = tail | |
rel["tail"] = (entitie_start_list[tail], | |
entitie_end_list[tail]) | |
rel["tail_type"] = entitie_label_list[tail] | |
rel["type"] = 1 | |
rel_sent.append(rel) | |
gt_relations.append(rel_sent) | |
re_metrics = self.re_score( | |
self.pred_relations_list, gt_relations, mode="boundaries") | |
metrics = { | |
"precision": re_metrics["ALL"]["p"], | |
"recall": re_metrics["ALL"]["r"], | |
"hmean": re_metrics["ALL"]["f1"], | |
} | |
self.reset() | |
return metrics | |
def reset(self): | |
self.pred_relations_list = [] | |
self.relations_list = [] | |
self.entities_list = [] | |
def re_score(self, pred_relations, gt_relations, mode="strict"): | |
"""Evaluate RE predictions | |
Args: | |
pred_relations (list) : list of list of predicted relations (several relations in each sentence) | |
gt_relations (list) : list of list of ground truth relations | |
rel = { "head": (start_idx (inclusive), end_idx (exclusive)), | |
"tail": (start_idx (inclusive), end_idx (exclusive)), | |
"head_type": ent_type, | |
"tail_type": ent_type, | |
"type": rel_type} | |
vocab (Vocab) : dataset vocabulary | |
mode (str) : in 'strict' or 'boundaries'""" | |
assert mode in ["strict", "boundaries"] | |
relation_types = [v for v in [0, 1] if not v == 0] | |
scores = { | |
rel: { | |
"tp": 0, | |
"fp": 0, | |
"fn": 0 | |
} | |
for rel in relation_types + ["ALL"] | |
} | |
# Count GT relations and Predicted relations | |
n_sents = len(gt_relations) | |
n_rels = sum([len([rel for rel in sent]) for sent in gt_relations]) | |
n_found = sum([len([rel for rel in sent]) for sent in pred_relations]) | |
# Count TP, FP and FN per type | |
for pred_sent, gt_sent in zip(pred_relations, gt_relations): | |
for rel_type in relation_types: | |
# strict mode takes argument types into account | |
if mode == "strict": | |
pred_rels = {(rel["head"], rel["head_type"], rel["tail"], | |
rel["tail_type"]) | |
for rel in pred_sent | |
if rel["type"] == rel_type} | |
gt_rels = {(rel["head"], rel["head_type"], rel["tail"], | |
rel["tail_type"]) | |
for rel in gt_sent if rel["type"] == rel_type} | |
# boundaries mode only takes argument spans into account | |
elif mode == "boundaries": | |
pred_rels = {(rel["head"], rel["tail"]) | |
for rel in pred_sent | |
if rel["type"] == rel_type} | |
gt_rels = {(rel["head"], rel["tail"]) | |
for rel in gt_sent if rel["type"] == rel_type} | |
scores[rel_type]["tp"] += len(pred_rels & gt_rels) | |
scores[rel_type]["fp"] += len(pred_rels - gt_rels) | |
scores[rel_type]["fn"] += len(gt_rels - pred_rels) | |
# Compute per entity Precision / Recall / F1 | |
for rel_type in scores.keys(): | |
if scores[rel_type]["tp"]: | |
scores[rel_type]["p"] = scores[rel_type]["tp"] / ( | |
scores[rel_type]["fp"] + scores[rel_type]["tp"]) | |
scores[rel_type]["r"] = scores[rel_type]["tp"] / ( | |
scores[rel_type]["fn"] + scores[rel_type]["tp"]) | |
else: | |
scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0 | |
if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0: | |
scores[rel_type]["f1"] = ( | |
2 * scores[rel_type]["p"] * scores[rel_type]["r"] / | |
(scores[rel_type]["p"] + scores[rel_type]["r"])) | |
else: | |
scores[rel_type]["f1"] = 0 | |
# Compute micro F1 Scores | |
tp = sum([scores[rel_type]["tp"] for rel_type in relation_types]) | |
fp = sum([scores[rel_type]["fp"] for rel_type in relation_types]) | |
fn = sum([scores[rel_type]["fn"] for rel_type in relation_types]) | |
if tp: | |
precision = tp / (tp + fp) | |
recall = tp / (tp + fn) | |
f1 = 2 * precision * recall / (precision + recall) | |
else: | |
precision, recall, f1 = 0, 0, 0 | |
scores["ALL"]["p"] = precision | |
scores["ALL"]["r"] = recall | |
scores["ALL"]["f1"] = f1 | |
scores["ALL"]["tp"] = tp | |
scores["ALL"]["fp"] = fp | |
scores["ALL"]["fn"] = fn | |
# Compute Macro F1 Scores | |
scores["ALL"]["Macro_f1"] = np.mean( | |
[scores[ent_type]["f1"] for ent_type in relation_types]) | |
scores["ALL"]["Macro_p"] = np.mean( | |
[scores[ent_type]["p"] for ent_type in relation_types]) | |
scores["ALL"]["Macro_r"] = np.mean( | |
[scores[ent_type]["r"] for ent_type in relation_types]) | |
return scores | |