# quad_match_score / quad_match_score.py
# Author: yuyijiong
# Commit 120df80: 修复多个refer时f1不正常的bug (fix abnormal F1 scores when there are multiple references)
# File size on the Hub: 31.5 kB
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TODO: Add a description here."""
import copy
import re
from typing import List, Dict, Union, Callable
import numpy as np
import datasets
import evaluate
from rouge_chinese import Rouge
from scipy.optimize import linear_sum_assignment
# TODO: Add BibTeX citation
_CITATION = """\
@InProceedings{huggingface:module,
title = {quad match score},
authors={huggingface, Inc.},
year={2020}
}
"""
# TODO: Add description of the module here
_DESCRIPTION = """\
evaluate sentiment quadruples.
评估生成模型的情感四元组
"""
# TODO: Add description of the arguments of the module here
_KWARGS_DESCRIPTION = """
Calculates how good are predictions given some references, using certain scores
Args:
predictions: list of predictions to score. Each predictions
should be a string with tokens separated by spaces.
references: list of reference for each prediction. Each
reference should be a string with tokens separated by spaces.
Returns:
score: sentiment quadruple match score
Examples:
Examples should be written in doctest format, and should illustrate how
to use the function.
>>> import evaluate
>>> module = evaluate.load("yuyijiong/quad_match_score")
>>> predictions=["food | good | food#taste | pos"]
>>> references=["food | good | food#taste | pos & service | bad | service#general | neg"]
>>> result=module.compute(predictions=predictions, references=references)
>>> print(result)
result={'ave match score of weight (1, 1, 1, 1)': 0.375,
'f1 score of exact match': 0.0,
'f1 score of optimal match of weight (1, 1, 1, 1)': 0.5}
"""
# 计算rougel的f1值
def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
    """Return the ROUGE-L F1 between predicted and reference texts, averaged over pairs.

    If any text contains Chinese characters, every text is split into single
    characters joined by spaces, so ROUGE (which tokenizes on whitespace) scores
    at character level. Otherwise texts are scored on their whitespace tokens.

    Args:
        text_pred_list: predicted texts.
        text_true_list: reference texts, aligned index-by-index with ``text_pred_list``.

    Returns:
        The average ROUGE-L F1, or ``0.0`` when the lists are empty or any
        prediction/reference is blank (the rouge library cannot score blank text).
    """
    assert len(text_pred_list) == len(text_true_list), "文本数量不一致"
    all_texts = [*text_pred_list, *text_true_list]
    # Blank inputs make rouge raise (or divide by zero), so score them as 0 up front.
    if not all_texts or any(not text.strip() for text in all_texts):
        return 0.0
    rouge = Rouge()
    # Check both sides for Chinese: a Chinese reference paired with a non-Chinese
    # prediction must also be split into characters to be compared fairly.
    if any(re.search(u"[\u4e00-\u9fa5]", text) for text in all_texts):
        text_pred_list = [' '.join(text_pred) for text_pred in text_pred_list]
        text_true_list = [' '.join(text_true) for text_true in text_true_list]
    return rouge.get_scores(text_pred_list, text_true_list, avg=True)['rouge-l']['f']
# 记录四元组的函数
class CommentUnitsSim:
    """Container for the sentiment quadruples extracted from one text.

    Each quadruple is a dict with the keys ``target_text``, ``opinion_text``,
    ``aspect`` and ``polarity``. The ``from_*`` constructors fill fields that are
    absent from the input with the string ``'None'``.
    """

    def __init__(self, data: List[Dict[str, str]], data_source: any = None, abnormal=False, language=None):
        """Store a deep copy of ``data`` and detect its language.

        Args:
            data: list of quadruple dicts. Keys ``target``/``opinion`` are renamed to
                ``target_text``/``opinion_text``. The caller's list is not modified.
            data_source: the raw object the quadruples were built from (kept for reference).
            abnormal: True if parsing found malformed input.
            language: ``'zh'`` or ``'en'``. When None, it is ``'zh'`` if any target or
                opinion text contains Chinese characters, otherwise ``'en'``.
        """
        self.data_source = data_source
        self.abnormal = abnormal
        data = copy.deepcopy(data)
        for quad_dict in data:
            self._rename_short_keys(quad_dict)
        self.data = data
        self.polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性', 'pos': '积极', 'neg': '消极',
                               'neu': '中性', '积极': '积极', '消极': '消极', '中性': '中性'}
        self.polarity_zh2en = {'积极': 'pos', '消极': 'neg', '中性': 'neu', 'pos': 'pos', 'neg': 'neg', 'neu': 'neu',
                               'positive': 'pos', 'negative': 'neg', 'neutral': 'neu'}
        self.language = language if language is not None else 'zh' if self.check_zh() else 'en'
        self.none_sign = 'null'

    @staticmethod
    def _rename_short_keys(quad_dict: dict) -> dict:
        """Rename ``target``/``opinion`` keys of ``quad_dict`` in place to ``target_text``/``opinion_text``."""
        if 'target' in quad_dict:
            quad_dict['target_text'] = quad_dict.pop('target')
        if 'opinion' in quad_dict:
            quad_dict['opinion_text'] = quad_dict.pop('opinion')
        return quad_dict

    @staticmethod
    def _parse_tuple_len(tuple_len: Union[int, list, str]) -> list:
        """Convert ``tuple_len`` into a list of field indices.

        An int ``n`` selects the first ``n`` fields. A list is used as-is. A digit
        string such as ``'013'`` becomes ``[0, 1, 3]``. Field indices are
        0=target, 1=opinion, 2=aspect, 3=polarity.
        """
        if isinstance(tuple_len, int):
            return list(range(tuple_len))
        if isinstance(tuple_len, list):
            return tuple_len
        if isinstance(tuple_len, str):
            return [int(i) for i in tuple_len]
        raise Exception('tuple_len参数错误')

    @property
    def num(self):
        """Number of quadruples held."""
        return len(self.data)

    def check_zh(self):
        """Return True if any target or opinion text contains a Chinese character (U+4E00..U+9FA5)."""
        return any(re.search('[\u4e00-\u9fa5]', quad_dict['target_text'])
                   or re.search('[\u4e00-\u9fa5]', quad_dict['opinion_text'])
                   for quad_dict in self.data)

    def check_polarity(self):
        """Return True if every polarity is a recognised label.

        On the first unrecognised label, ``abnormal`` is set and False is returned.
        """
        valid_labels = ('positive', 'negative', 'neutral', 'pos', 'neg', 'neu', '积极', '消极', '中性')
        for quad_dict in self.data:
            if quad_dict['polarity'] not in valid_labels:
                self.abnormal = True
                return False
        return True

    def convert_polarity_en2zh(self):
        """Rewrite every polarity as its Chinese label in place; raises KeyError on unknown labels."""
        for quad_dict in self.data:
            quad_dict['polarity'] = self.polarity_en2zh[quad_dict['polarity']]
        return self

    def convert_polarity_zh2en(self):
        """Rewrite every polarity as its short English label in place; raises KeyError on unknown labels."""
        for quad_dict in self.data:
            quad_dict['polarity'] = self.polarity_zh2en[quad_dict['polarity']]
        return self

    def del_duplicate(self):
        """Drop repeated quadruples, keeping the first occurrence and the original order."""
        unique = []
        for quad_dict in self.data:
            if quad_dict not in unique:
                unique.append(quad_dict)
        self.data = unique
        return self

    def check_target_opinion_null(self):
        """Return True if some quadruple has both target and opinion equal to 'null'."""
        return any(quad_dict['target_text'] == 'null' and quad_dict['opinion_text'] == 'null'
                   for quad_dict in self.data)

    def check_any_null(self):
        """Return True if some quadruple has target or opinion equal to 'null'."""
        return any(quad_dict['target_text'] == 'null' or quad_dict['opinion_text'] == 'null'
                   for quad_dict in self.data)

    @classmethod
    def from_str(cls, quadruple_str: str, tuple_len: Union[int, list, str] = 4, format_code=0, sep_token1=' & ',
                 sep_token2=' | '):
        """Parse quadruples from a string such as ``"a | b | c | pos & d | e | f | neg"``.

        Args:
            quadruple_str: quadruples joined by ``sep_token1``; fields joined by ``sep_token2``.
            tuple_len: which fields the string contains (see ``_parse_tuple_len``).
            format_code: only 0 is supported.
            sep_token1: separator between quadruples.
            sep_token2: separator between the fields of one quadruple.

        Returns:
            An instance of ``cls`` whose absent fields are ``'None'``. ``abnormal`` is
            set when a quadruple had too many fields (truncated) or too few (padded
            with ``'None'`` at the front).

        Raises:
            Exception: on an invalid ``tuple_len`` or ``format_code``.
        """
        data = []
        abnormal = False
        # Make sure each separator character is followed by a space (e.g. "a |b" -> "a | b").
        # A regex handles every occurrence. The previous index loop shifted positions
        # as it inserted spaces and so missed separators near the end of the string.
        for sep in (sep_token1.strip(), sep_token2.strip()):
            if sep:
                quadruple_str = re.sub(re.escape(sep) + r'(?=[^ ])', lambda m: m.group(0) + ' ', quadruple_str)
        tuple_index = cls._parse_tuple_len(tuple_len)
        field_names = ["target_text", "opinion_text", "aspect", "polarity"]
        for quadruple in quadruple_str.split(sep_token1):
            if format_code != 0:
                raise Exception('answer_format参数错误')
            units = [unit.strip() for unit in quadruple.split(sep_token2)]
            if len(units) > len(tuple_index):
                print('quadruple格式错误,过多元素', quadruple_str)
                abnormal = True
                units = units[:len(tuple_index)]  # too many fields: truncate
            elif len(units) < len(tuple_index):
                print('quadruple格式错误,过少元素', quadruple_str)
                abnormal = True
                units = ["None"] * (len(tuple_index) - len(units)) + units  # too few: pad at the front
            q = {"target_text": 'None', "opinion_text": 'None', "aspect": 'None', "polarity": 'None'}
            q.update(zip([field_names[i] for i in tuple_index], units))
            # Only reports an unexpected polarity; the quadruple is still kept.
            if q['polarity'] not in ['pos', 'neg', 'neu', 'None', '积极', '消极', '中性']:
                print('quadruple格式错误,极性格式不对', quadruple_str)
            data.append(q)
        return cls(data, quadruple_str, abnormal)

    @classmethod
    def from_list(cls, quadruple_list: List[List[str]], **kwargs):
        """Build from ``[[target, opinion, aspect, polarity], ...]``; extra kwargs go to ``__init__``."""
        data = [{"target_text": quadruple[0], "opinion_text": quadruple[1], "aspect": quadruple[2],
                 "polarity": quadruple[3]} for quadruple in quadruple_list]
        return cls(data, quadruple_list, **kwargs)

    @classmethod
    def from_list_dict(cls, quadruple_list: List[dict], **kwargs):
        """Build from a list of quadruple dicts, filling missing fields with ``'None'``.

        Short keys ``target``/``opinion`` are accepted. The input dicts are copied,
        never modified. Extra kwargs go to ``__init__``.
        """
        data = []
        for quadruple in quadruple_list:
            q = {"target_text": 'None', "opinion_text": 'None', "aspect": 'None', "polarity": 'None'}
            q.update(cls._rename_short_keys(dict(quadruple)))
            data.append(q)
        return cls(data, quadruple_list, **kwargs)

    def to_list(self):
        """Return the quadruples as ``[[target, opinion, aspect, polarity], ...]``."""
        return [[quad_dict['target_text'], quad_dict['opinion_text'], quad_dict['aspect'], quad_dict['polarity']]
                for quad_dict in self.data]

    def get_quadruple_str(self, format_code=0, tuple_len: Union[int, list, str] = 4, sep_token1=' & ',
                          sep_token2=' | '):
        """Serialise the quadruples back into a string of n-tuples.

        Polarities are first converted in place to the labels of ``self.language``.
        With ``tuple_len`` selecting only polarity (``[3]``), the merged overall
        polarity is returned. With aspect+polarity (``[2, 3]``), duplicate pairs are
        removed while keeping order.

        Raises:
            Exception: on an unknown polarity label, invalid ``tuple_len`` or ``format_code``.
        """
        tuple_index = self._parse_tuple_len(tuple_len)
        try:
            if self.language == 'zh':
                self.convert_polarity_en2zh()
            else:
                self.convert_polarity_zh2en()
        except KeyError as err:
            print('语言参数错误', self.data)
            print(self.language)
            raise Exception('语言参数错误') from err
        if tuple_index == [3]:
            return self.merge_polarity()
        if format_code != 0:
            raise Exception('answer_format参数错误')
        new_text_list = []
        for quad_dict in self.data:
            fields = [quad_dict['target_text'], quad_dict['opinion_text'], quad_dict['aspect'], quad_dict['polarity']]
            new_text_list.append(sep_token2.join([fields[i] for i in tuple_index]))
        if tuple_index == [2, 3]:
            # The same aspect can appear in several quadruples: deduplicate, preserving order.
            new_text_list = list(dict.fromkeys(new_text_list))
        return sep_token1.join(new_text_list)

    def compare_same(self, other) -> int:
        """Count quadruples of ``self`` that also appear in ``other``.

        Each quadruple of ``other`` can be matched only once, so duplicated
        predictions cannot be counted more often than the reference contains them.
        This keeps exact-match recall at most 1.
        """
        remaining = list(other.data)
        count = 0
        for quad_dict in self.data:
            if quad_dict in remaining:
                remaining.remove(quad_dict)
                count += 1
        return count

    def check_target_repeat(self):
        """Return True if any target text occurs in more than one quadruple."""
        targets = self.get_target_list()
        return len(targets) != len(set(targets))

    def check_opinion_repeat(self):
        """Return True if any opinion text occurs in more than one quadruple."""
        opinions = self.get_opinion_list()
        return len(opinions) != len(set(opinions))

    def check_aspect_repeat(self):
        """Return True if any aspect occurs in more than one quadruple."""
        aspects = self.get_aspect_list()
        return len(aspects) != len(set(aspects))

    def get_aspect_list(self):
        """Return the aspect of every quadruple, in order."""
        return [quad_dict['aspect'] for quad_dict in self.data]

    def get_target_list(self):
        """Return the target text of every quadruple, in order."""
        return [quad_dict['target_text'] for quad_dict in self.data]

    def get_opinion_list(self):
        """Return the opinion text of every quadruple, in order."""
        return [quad_dict['opinion_text'] for quad_dict in self.data]

    def get_polarity_list(self):
        """Return the polarity of every quadruple, in order."""
        return [quad_dict['polarity'] for quad_dict in self.data]

    def merge_polarity(self):
        """Combine all polarities into one overall label.

        Positive-only gives positive, negative-only gives negative, and anything
        else (both present or neither present) gives neutral. Labels are English
        when ``self.language == 'en'``, Chinese otherwise.
        """
        polarity_list = self.get_polarity_list()
        if self.language == 'en':
            pos, neg, neu = 'pos', 'neg', 'neu'
        else:
            pos, neg, neu = '积极', '消极', '中性'
        has_pos = pos in polarity_list
        has_neg = neg in polarity_list
        if has_pos and not has_neg:
            return pos
        if has_neg and not has_pos:
            return neg
        return neu

    def check_opinion_in_comment(self, comment_text):
        """Return False if some opinion (other than '*') is not a substring of ``comment_text``."""
        return all(quad_dict['opinion_text'] == '*' or quad_dict['opinion_text'] in comment_text
                   for quad_dict in self.data)

    def check_target_in_comment(self, comment_text):
        """Return False if some target (other than '*') is not a substring of ``comment_text``."""
        return all(quad_dict['target_text'] == '*' or quad_dict['target_text'] in comment_text
                   for quad_dict in self.data)

    @staticmethod
    def get_similarity(units1, units2: 'CommentUnitsSim'):
        """Placeholder for a similarity measure between two quadruple sets; not implemented (returns None)."""
        pass

    def apply(self, func: Callable, field: str):
        """Replace ``field`` of every quadruple with ``func(value)`` in place and return self."""
        for quad_dict in self.data:
            quad_dict[field] = func(quad_dict[field])
        return self
# 四元组匹配函数
class CommentUnitsMatch:
    """Match predicted quadruples to reference quadruples and score the match.

    The cost of pairing two quadruples is a weighted sum of per-field costs:
    - aspect and polarity use a 0/1 exact-match cost;
    - target and opinion text use ``1 - ROUGE-L F1``.

    The weights are normalised to sum to 1.
    """

    def __init__(self, target_weight=0.5, opinion_weight=0.5, aspect_weight=0.5, polarity_weight=0.5, one_match=True):
        """
        Args:
            target_weight, opinion_weight, aspect_weight, polarity_weight: relative field
                weights, normalised so that they sum to 1.
            one_match: True for one-to-one (Hungarian) matching. False lets every
                quadruple on the larger side take its cheapest partner.
        """
        weight_sum = target_weight + opinion_weight + aspect_weight + polarity_weight
        self.target_weight = target_weight / weight_sum
        self.opinion_weight = opinion_weight / weight_sum
        self.aspect_weight = aspect_weight / weight_sum
        self.polarity_weight = polarity_weight / weight_sum
        self.one_match = one_match

    def set_zero(self, feature: str = 'polarity'):
        """Set the weight of ``feature`` to 0.

        ``feature`` is 'polarity', 'aspect', or any name containing 'opinion' or
        'target'. Any other value raises Exception.
        """
        if feature == 'polarity':
            self.polarity_weight = 0
        elif feature == 'aspect':
            self.aspect_weight = 0
        elif 'opinion' in feature:
            self.opinion_weight = 0
        elif 'target' in feature:
            self.target_weight = 0
        else:
            raise Exception('feature参数错误')

    def re_normalize(self):
        """Rescale the current weights so that they sum to 1."""
        weight_sum = self.target_weight + self.opinion_weight + self.aspect_weight + self.polarity_weight
        self.target_weight = self.target_weight / weight_sum
        self.opinion_weight = self.opinion_weight / weight_sum
        self.aspect_weight = self.aspect_weight / weight_sum
        self.polarity_weight = self.polarity_weight / weight_sum

    @staticmethod
    def _feature_absent(units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str) -> bool:
        """Return True if ``feature`` is missing (absent key or the string 'None') in the first quadruple of either side."""
        values = (units1.data[0].get(feature), units2.data[0].get(feature))
        return any(value is None or value == 'None' for value in values)

    def _drop_feature(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str):
        """Give ``feature`` zero weight (re-normalising the rest) and return an all-zero cost matrix."""
        cost_matrix = np.zeros((len(units1.data), len(units2.data)))
        self.set_zero(feature)
        self.re_normalize()
        return cost_matrix

    def get_cost_matrix(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'polarity'):
        """Exact-match cost matrix for ``feature``: 0 where the values are equal, 1 otherwise.

        Rows index ``units1`` quadruples and columns index ``units2`` quadruples.
        If the feature is missing, an all-zero matrix is returned. In that case the
        feature's weight is also set to 0 on this matcher and the remaining weights
        are re-normalised.
        """
        if self._feature_absent(units1, units2, feature):
            return self._drop_feature(units1, units2, feature)
        return np.array([[0 if quad_dict1[feature] == quad_dict2[feature] else 1 for quad_dict2 in units2.data]
                         for quad_dict1 in units1.data])

    def get_cost_matrix_rouge(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'target_text'):
        """Text cost matrix for ``feature``: 0 for identical values, otherwise ``1 - ROUGE-L F1``.

        Missing features are handled as in ``get_cost_matrix``.
        """
        if self._feature_absent(units1, units2, feature):
            return self._drop_feature(units1, units2, feature)
        cost_matrix = []
        for quad_dict1 in units1.data:
            cost_list = []
            for quad_dict2 in units2.data:
                if quad_dict1[feature] == quad_dict2[feature]:
                    cost_list.append(0)
                else:
                    cost_list.append(1 - get_rougel_f1([quad_dict1[feature]], [quad_dict2[feature]]))
            cost_matrix.append(cost_list)
        return np.array(cost_matrix)

    def match_units(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim') -> tuple:
        """Match ``units1`` (predictions) against ``units2`` (references).

        Returns:
            ``(cost_per_quadruple, TP, FP, FN)``. The cost lies in [0, 1].
            TP/FP/FN are soft counts: TP is the sum of the scores (1 - cost) of the
            matched pairs.

        Fields missing in either input are excluded from the weighting for this
        call only. The configured weights are restored afterwards, so one
        comparison cannot change how later comparisons are scored.
        """
        saved_weights = (self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight)
        try:
            return self._match_units(units1, units2)
        finally:
            self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight = saved_weights

    def _match_units(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim') -> tuple:
        """Compute the match for ``match_units``; may zero and re-normalise weights of missing fields."""
        cost_matrix_polarity = self.get_cost_matrix(units1, units2, feature='polarity')
        cost_matrix_aspect = self.get_cost_matrix(units1, units2, feature='aspect')
        cost_matrix_target = self.get_cost_matrix_rouge(units1, units2, feature='target_text')
        cost_matrix_opinion = self.get_cost_matrix_rouge(units1, units2, feature='opinion_text')
        # Combined cost in [0, 1]. Rows are units1 (predictions), columns are units2 (references).
        # The weights are read only after every matrix is built, so all zeroing has happened.
        cost_matrix = (self.target_weight * cost_matrix_target + self.opinion_weight * cost_matrix_opinion
                       + self.aspect_weight * cost_matrix_aspect + self.polarity_weight * cost_matrix_polarity)
        score_matrix = 1 - cost_matrix
        if self.one_match:
            # One-to-one: min(len(units1), len(units2)) pairs.
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
        elif units1.num > units2.num:
            # One-to-many: every quadruple on the larger side takes its cheapest partner.
            row_ind = np.arange(units1.num)
            col_ind = np.argmin(cost_matrix, axis=1)
        else:
            row_ind = np.argmin(cost_matrix, axis=0)
            col_ind = np.arange(units2.num)
        cost = sum(cost_matrix[r][c] for r, c in zip(row_ind, col_ind))
        TP = sum(score_matrix[r][c] for r, c in zip(row_ind, col_ind))
        FP = units1.num - TP
        FN = units2.num - TP
        max_units_num = max(units1.num, units2.num)
        if self.one_match:
            # Quadruples left unmatched by the one-to-one assignment cost 1 each.
            cost += (max_units_num - len(row_ind))
        cost_per_quadruple = cost / max_units_num
        if cost_per_quadruple > 1 or cost_per_quadruple < 0:
            print('cost错误', cost_per_quadruple, 'pred:', units1.data, 'true:', units2.data)
            print(self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight)
        return cost_per_quadruple, TP, FP, FN
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class QuadMatch(evaluate.Metric):
    """Score generated sentiment quadruples against references.

    Reports two things:
    - exact-match F1;
    - a soft F1 and an average matching score, both from optimal (Hungarian)
      matching with per-field weights.
    """

    def _info(self):
        # Module metadata. Two input schemas are accepted: one reference string
        # per prediction, or a list of reference strings per prediction.
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence")),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            # The links below are still placeholders from the module template.
            homepage="http://module.homepage",
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"]
        )

    def _download_and_prepare(self, dl_manager):
        """No external resources are required."""
        pass

    def _compute(self,
                 predictions: List[str],
                 references: Union[List[str], List[List[str]]],
                 quad_weights: tuple = (1, 1, 1, 1),
                 **kwargs) -> dict:
        '''
        :param predictions: list of predictions of sentiment quads
        :param references: list of references of sentiment quads (a string or a list of strings per prediction)
        :param quad_weights: weight of target,opinion,aspect,polarity for cost compute
        :param kwargs:
            :param tuple_len: indicate the format of the quad, see the following mapping
            :param sep_token1: the token to seperate quads
            :param sep_token2: the token to seperate units of one quad
            :param one_match: whether optimal matching is one-to-one (default True)
        :return: dict with the optimal-match score, the optimal-match F1 and the exact-match F1
        #mapping
        id2prompt={'0123':"quadruples (target | opinion | aspect | polarity)",
                   '':"quadruples (target | opinion | aspect | polarity)",
                   '01':'pairs (target | opinion)',
                   '012':'triples (target | opinion | aspect)',
                   '013':'triples (target | opinion | polarity)',
                   '023':'triples (target | aspect | polarity)',
                   '23':'pairs (aspect | polarity)',
                   '03':'pairs (target | polarity)',
                   '13':'pairs (opinion | polarity)',
                   '3':'single (polarity)'}
        # Chinese version of the mapping
        id2prompt_zh={'0123': "四元组(对象 | 观点 | 方面 | 极性)",
                      '':"四元组(对象 | 观点 | 方面 | 极性)",
                      '01':'二元组(对象 | 观点)',
                      '012':'三元组(对象 | 观点 | 方面)',
                      '013':'三元组(对象 | 观点 | 极性)',
                      '023':'三元组(对象 | 方面 | 极性)',
                      '23':'二元组(方面 | 极性)',
                      '03':'二元组(对象 | 极性)',
                      '13':'二元组(观点 | 极性)',
                      '3':'单元素(极性)'}
        '''
        # one_match only configures optimal matching. Forwarding it to the exact-match
        # path would pass it on to CommentUnitsSim.from_str, which rejects it.
        one_match = kwargs.pop('one_match', True)
        f1_of_optimal_match, score_of_optimal_match = self.quad_f1_of_optimal_match(
            predictions, references, quad_weights, one_match=one_match, **kwargs)
        f1 = self.quad_f1_of_exact_match(predictions=predictions, references=references, **kwargs)
        # The optimal-match score is 1 - average matching cost.
        return {'score of optimal match of weight ' + str(quad_weights): score_of_optimal_match,
                'f1 of optimal match of weight ' + str(quad_weights): f1_of_optimal_match,
                'f1 of exact match': f1}

    @staticmethod
    def quad_f1_of_exact_match(predictions: List[str], references: Union[List[str], List[List[str]]],
                               return_dict=False, **kwargs) -> Union[Dict[str, float], float]:
        """Micro-averaged F1 over quadruples that exactly match a reference quadruple.

        When a prediction has several references, the one giving the highest
        per-sample F1 is used. Extra kwargs are forwarded to ``CommentUnitsSim.from_str``.

        Returns:
            The F1 rounded to 4 decimals. When ``return_dict`` is True, a dict with
            ``precision``, ``recall`` and ``f1`` (each rounded) is returned instead.
            Zero denominators yield 0.0.
        """
        assert len(predictions) == len(references), "文本数量不一致"
        correct, pred_num, true_num = 0, 0, 0
        for pred, refer in zip(predictions, references):
            pred_units = CommentUnitsSim.from_str(pred, **kwargs)
            # A single reference string is treated as a one-element list of references.
            refer_texts = [refer] if isinstance(refer, str) else refer
            refer_units = [CommentUnitsSim.from_str(t, **kwargs) for t in refer_texts]
            # Count exactly matching quadruples and the per-sample F1 for each reference.
            correct_list = [pred_units.compare_same(t) for t in refer_units]
            f1_list = [2 * c / (pred_units.num + t.num) for c, t in zip(correct_list, refer_units)]
            best_index = f1_list.index(max(f1_list))
            pred_num += pred_units.num
            true_num += refer_units[best_index].num
            correct += correct_list[best_index]
        precision = correct / pred_num if pred_num else 0.0
        recall = correct / true_num if true_num else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precision, recall, f1 = round(precision, 4), round(recall, 4), round(f1, 4)
        if return_dict:
            return {"precision": precision, "recall": recall, "f1": f1}
        return f1

    @staticmethod
    def quad_f1_of_optimal_match(
            predictions: List[str],
            references: Union[List[str], List[List[str]]],
            quad_weights: tuple = (1, 1, 1, 1),
            one_match=True,
            **kwargs):
        """Soft F1 and average score from optimal matching of predicted and reference quadruples.

        For each prediction, the reference with the lowest matching cost is used.
        Extra kwargs are forwarded to ``CommentUnitsSim.from_str``.

        Returns:
            ``(f1, 1 - average cost)`` as floats. The F1 is 0.0 when nothing matches at all.
        """
        if isinstance(predictions, str):
            predictions = [predictions]
            references = [references]
        assert len(predictions) == len(references)
        cost = 0
        TP, FP, FN = 0, 0, 0
        for pred, refer in zip(predictions, references):
            pred_units = CommentUnitsSim.from_str(pred, **kwargs)
            refer_texts = [refer] if isinstance(refer, str) else refer
            refer_units = [CommentUnitsSim.from_str(t, **kwargs) for t in refer_texts]
            # Use a fresh matcher for every comparison. The cost-matrix methods zero and
            # re-normalise weights in place when a field is missing, so a shared matcher
            # would carry one sample's (or one reference's) weights into the next.
            results = [CommentUnitsMatch(*quad_weights, one_match=one_match).match_units(pred_units, t)
                       for t in refer_units]
            # Keep the reference with the lowest cost (first one on ties).
            cost_, TP_, FP_, FN_ = min(results, key=lambda result: result[0])
            cost += cost_
            TP += TP_
            FP += FP_
            FN += FN_
        cost = cost / len(predictions)
        precision_match = TP / (TP + FP) if TP + FP else 0.0
        recall_match = TP / (TP + FN) if TP + FN else 0.0
        if precision_match + recall_match:
            f1_match = 2 * precision_match * recall_match / (precision_match + recall_match)
        else:
            f1_match = 0.0
        return float(f1_match), float(1 - cost)