Spaces:
Sleeping
Sleeping
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""TODO: Add a description here.""" | |
import copy | |
import re | |
from typing import List, Dict, Union,Callable | |
import numpy as np | |
import datasets | |
import evaluate | |
from rouge_chinese import Rouge | |
from scipy.optimize import linear_sum_assignment | |
# TODO: Add BibTeX citation | |
_CITATION = """\ | |
@InProceedings{huggingface:module, | |
title = {A great new module}, | |
authors={huggingface, Inc.}, | |
year={2020} | |
} | |
""" | |
# TODO: Add description of the module here | |
_DESCRIPTION = """\ | |
evaluate sentiment quadruples. | |
评估生成模型的情感四元组 | |
""" | |
# TODO: Add description of the arguments of the module here | |
_KWARGS_DESCRIPTION = """ | |
Calculates how good are predictions given some references, using certain scores | |
Args: | |
predictions: list of predictions to score. Each predictions | |
should be a string with tokens separated by spaces. | |
references: list of reference for each prediction. Each | |
reference should be a string with tokens separated by spaces. | |
Returns: | |
score: sentiment quadruple match score | |
Examples: | |
Examples should be written in doctest format, and should illustrate how | |
to use the function. | |
>>> my_new_module = evaluate.load("my_new_module") | |
>>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1]) | |
>>> print(results) | |
{'accuracy': 1.0} | |
""" | |
def compute_quadruple_f1(y_pred: List[str], y_true: Union[List[str], List[List[str]]],
                         return_rp=False, **kwargs) -> Dict[str, float]:
    """Exact-match F1 between predicted and reference quadruple strings.

    A predicted tuple counts as correct only when an identical tuple (all parsed
    elements equal) is present in the reference. When a reference is a list of
    alternative answers, the alternative with the most correct tuples is used
    (the first one on ties), and its tuple count is used for recall.

    Args:
        y_pred: prediction strings, parsed with ``CommentUnitsSim.from_str``.
        y_true: one reference per prediction: a string, or a list of strings.
        return_rp: when True, also return precision and recall.
        **kwargs: parsing options forwarded to ``CommentUnitsSim.from_str``.

    Returns:
        The F1 rounded to 4 decimals, or, when ``return_rp`` is True, a dict with
        ``precision``, ``recall`` and ``f1`` (each rounded to 4 decimals).
        A zero denominator yields 0.0 instead of raising.
    """
    assert len(y_pred) == len(y_true)
    correct, pred_num, true_num = 0, 0, 0
    for pred_str, true in zip(y_pred, y_true):
        pred = CommentUnitsSim.from_str(pred_str, **kwargs)
        # A single reference string is treated as a one-element list of alternatives.
        alternatives = [true] if isinstance(true, str) else true
        candidates = [CommentUnitsSim.from_str(t, **kwargs) for t in alternatives]
        hits = [pred.compare_same(candidate) for candidate in candidates]
        best = hits.index(max(hits))  # first alternative with the highest count
        correct += hits[best]
        pred_num += pred.num
        true_num += candidates[best].num
    # Guard every division; round only the final values (rounding before
    # adding an epsilon previously produced values such as 1.00000001).
    precision = correct / pred_num if pred_num else 0.0
    recall = correct / true_num if true_num else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    if return_rp:
        return {"precision": round(precision, 4), "recall": round(recall, 4), "f1": round(f1, 4)}
    return round(f1, 4)
# ROUGE-L F1 between predicted and reference texts.
def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
    """Return the average ROUGE-L F1 of ``text_pred_list`` against ``text_true_list``.

    Scores come from ``rouge_chinese.Rouge().get_scores(..., avg=True)``. When the
    first prediction or the first reference contains a CJK character, every text
    is rewritten as space-separated characters so that ROUGE compares characters
    rather than whitespace-delimited words.

    Returns:
        0 if the lists are empty or the first prediction or first reference is
        blank; otherwise the averaged ROUGE-L F1.

    Raises:
        AssertionError: if the two lists differ in length.
    """
    assert len(text_pred_list) == len(text_true_list), "文本数量不一致"
    # Blank text is scored as a complete miss without calling ROUGE.
    if not text_pred_list or not text_pred_list[0].strip() or not text_true_list[0].strip():
        return 0
    # Check both sides: a non-Chinese prediction (e.g. "null") compared with a
    # Chinese reference must still be tokenised per character.
    cjk = re.compile(u"[\u4e00-\u9fa5]+")
    if cjk.search(text_pred_list[0]) or cjk.search(text_true_list[0]):
        text_pred_list = [' '.join(text_pred) for text_pred in text_pred_list]
        text_true_list = [' '.join(text_true) for text_true in text_true_list]
    rouge = Rouge()
    return rouge.get_scores(text_pred_list, text_true_list, avg=True)['rouge-l']['f']
# Container for the sentiment quadruples of one text.
class CommentUnitsSim:
    """A list of sentiment quadruples parsed from one prediction or reference.

    Each quadruple is a dict with the keys ``target_text``, ``opinion_text``,
    ``aspect`` and ``polarity``. The ``from_str`` / ``from_list_dict``
    constructors fill elements that are absent with the string ``'None'``.
    """

    # Any accepted polarity label -> Chinese label.
    polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性',
                      'pos': '积极', 'neg': '消极', 'neu': '中性',
                      '积极': '积极', '消极': '消极', '中性': '中性'}
    # Any accepted polarity label -> short English label.
    polarity_zh2en = {'积极': 'pos', '消极': 'neg', '中性': 'neu',
                      'pos': 'pos', 'neg': 'neg', 'neu': 'neu',
                      'positive': 'pos', 'negative': 'neg', 'neutral': 'neu'}
    # Labels accepted by check_polarity.
    _VALID_POLARITIES = ('positive', 'negative', 'neutral', 'pos', 'neg', 'neu', '积极', '消极', '中性')
    # Labels from_str accepts without printing a warning ('None' = polarity absent).
    _PARSED_POLARITIES = ('pos', 'neg', 'neu', 'None', '积极', '消极', '中性')
    # Element order of a quadruple; tuple_len indices refer to positions here.
    _QUAD_KEYS = ('target_text', 'opinion_text', 'aspect', 'polarity')
    # One character of the basic CJK Unified Ideographs block.
    _ZH_PATTERN = re.compile('[\u4e00-\u9fa5]')

    def __init__(self, data: List[Dict[str, str]], data_source: object = None, abnormal=False, language=None):
        """Store a deep copy of ``data``.

        Args:
            data: list of quadruple dicts. Keys 'target' / 'opinion' are renamed
                to 'target_text' / 'opinion_text'; the caller's dicts are not modified.
            data_source: the raw object the quadruples were built from (kept as-is).
            abnormal: whether a formatting problem was detected while parsing.
            language: 'zh' or 'en'; detected from target/opinion text when None.
        """
        self.data_source = data_source
        self.abnormal = abnormal
        data = copy.deepcopy(data)
        for quad_dict in data:
            self._rename_legacy_keys(quad_dict)
        self.data = data
        if language is None:
            language = 'zh' if self.check_zh() else 'en'
        self.language = language
        self.none_sign = 'null'

    @staticmethod
    def _rename_legacy_keys(quad_dict):
        """Rename 'target'/'opinion' to 'target_text'/'opinion_text' in place and return the dict."""
        for old_key, new_key in (('target', 'target_text'), ('opinion', 'opinion_text')):
            if old_key in quad_dict:
                quad_dict[new_key] = quad_dict.pop(old_key)
        return quad_dict

    @staticmethod
    def _parse_tuple_len(tuple_len):
        """Turn ``tuple_len`` into a list of element indices.

        An int n gives [0, ..., n-1], a list is used as-is and a digit string such
        as '013' gives [0, 1, 3]. Any other type raises Exception('tuple_len参数错误').
        """
        if isinstance(tuple_len, int):
            return list(range(tuple_len))
        if isinstance(tuple_len, list):
            return tuple_len
        if isinstance(tuple_len, str):
            return [int(digit) for digit in tuple_len]
        raise Exception('tuple_len参数错误')

    @staticmethod
    def _space_after_separators(text, sep_token1, sep_token2):
        """Insert a space after each separator character not already followed by one.

        Only separators whose stripped form is a single character are affected.
        The result is built in one pass, so inserted spaces never shift
        unprocessed characters out of the scan.
        """
        marks = {sep_token1.strip(), sep_token2.strip()}
        pieces = []
        for i, ch in enumerate(text):
            pieces.append(ch)
            if ch in marks and i + 1 < len(text) and text[i + 1] != ' ':
                pieces.append(' ')
        return ''.join(pieces)

    def _values(self, key):
        """Return the value stored under ``key`` for every quadruple, in order."""
        return [quad_dict[key] for quad_dict in self.data]

    @property
    def num(self):
        """Number of quadruples held."""
        return len(self.data)

    def check_zh(self):
        """Return True if any target or opinion text contains a CJK character."""
        for quad_dict in self.data:
            for key in ('target_text', 'opinion_text'):
                value = quad_dict.get(key)
                if isinstance(value, str) and self._ZH_PATTERN.search(value):
                    return True
        return False

    def check_polarity(self):
        """Return True if every polarity is a recognised label.

        On the first unrecognised label, sets ``self.abnormal`` to True and returns False.
        """
        for quad_dict in self.data:
            if quad_dict['polarity'] not in self._VALID_POLARITIES:
                self.abnormal = True
                return False
        return True

    def convert_polarity_en2zh(self):
        """Rewrite every polarity as its Chinese label in place and return self.

        Labels without a mapping (e.g. the 'None' placeholder) are left unchanged.
        """
        for quad_dict in self.data:
            polarity = quad_dict['polarity']
            quad_dict['polarity'] = self.polarity_en2zh.get(polarity, polarity)
        return self

    def convert_polarity_zh2en(self):
        """Rewrite every polarity as its short English label in place and return self.

        Labels without a mapping (e.g. the 'None' placeholder) are left unchanged.
        """
        for quad_dict in self.data:
            polarity = quad_dict['polarity']
            quad_dict['polarity'] = self.polarity_zh2en.get(polarity, polarity)
        return self

    def del_duplicate(self):
        """Drop repeated quadruples, keeping the first occurrence and the order; return self."""
        new_data = []
        for quad_dict in self.data:
            if quad_dict not in new_data:
                new_data.append(quad_dict)
        self.data = new_data
        return self

    def check_target_opinion_null(self):
        """Return True if some quadruple has both target and opinion equal to 'null'."""
        return any(q['target_text'] == 'null' and q['opinion_text'] == 'null' for q in self.data)

    def check_any_null(self):
        """Return True if some quadruple has target or opinion equal to 'null'."""
        return any(q['target_text'] == 'null' or q['opinion_text'] == 'null' for q in self.data)

    @classmethod
    def from_str(cls, quadruple_str: str, tuple_len: Union[int, list, str] = 4, format_code=0,
                 sep_token1=' & ', sep_token2=' | '):
        """Parse a string such as "t1 | o1 | a1 | pos & t2 | o2 | a2 | neg".

        Args:
            quadruple_str: tuples separated by ``sep_token1``; elements of a tuple
                separated by ``sep_token2``.
            tuple_len: which quadruple elements the string contains, as a count,
                a list of indices or a digit string such as '013'
                (0=target, 1=opinion, 2=aspect, 3=polarity).
            format_code: only 0 is supported.

        Returns:
            A new instance. ``abnormal`` is True if a tuple had too many elements
            (truncated) or too few (left-padded with 'None').

        Raises:
            Exception: for an unsupported ``tuple_len`` type or ``format_code``.
        """
        quadruple_str = cls._space_after_separators(quadruple_str, sep_token1, sep_token2)
        tuple_index = cls._parse_tuple_len(tuple_len)
        if format_code != 0:
            raise Exception('answer_format参数错误')
        keys = [cls._QUAD_KEYS[i] for i in tuple_index]
        data = []
        abnormal = False
        for quadruple in quadruple_str.split(sep_token1):
            parts = [unit.strip() for unit in quadruple.split(sep_token2)]
            if len(parts) > len(keys):
                print('quadruple格式错误,过多元素', quadruple_str)
                abnormal = True
                parts = parts[:len(keys)]  # too long: truncate
            elif len(parts) < len(keys):
                print('quadruple格式错误,过少元素', quadruple_str)
                abnormal = True
                parts = ['None'] * (len(keys) - len(parts)) + parts  # too short: left-pad
            quad = dict.fromkeys(cls._QUAD_KEYS, 'None')
            quad.update(zip(keys, parts))
            # Only warns; the tuple is kept either way.
            if quad['polarity'] not in cls._PARSED_POLARITIES:
                print('quadruple格式错误,极性格式不对', quadruple_str)
            data.append(quad)
        return cls(data, quadruple_str, abnormal)

    @classmethod
    def from_list(cls, quadruple_list: List[List[str]], **kwargs):
        """Build from [target, opinion, aspect, polarity] lists; kwargs go to __init__."""
        data = [
            {"target_text": quadruple[0], "opinion_text": quadruple[1],
             "aspect": quadruple[2], "polarity": quadruple[3]}
            for quadruple in quadruple_list
        ]
        return cls(data, quadruple_list, **kwargs)

    @classmethod
    def from_list_dict(cls, quadruple_list: List[dict], **kwargs):
        """Build from quadruple dicts, filling missing elements with 'None'.

        'target'/'opinion' keys are accepted as aliases. The input dicts are copied
        and never modified. kwargs are forwarded to __init__.
        """
        data = []
        for quad_dict in quadruple_list:
            quad = dict.fromkeys(cls._QUAD_KEYS, 'None')
            quad.update(cls._rename_legacy_keys(dict(quad_dict)))
            data.append(quad)
        return cls(data, quadruple_list, **kwargs)

    def to_list(self):
        """Return the quadruples as [target, opinion, aspect, polarity] lists."""
        return [[q['target_text'], q['opinion_text'], q['aspect'], q['polarity']] for q in self.data]

    def get_quadruple_str(self, format_code=0, tuple_len: Union[int, list, str] = 4, sep_token1=' & ', sep_token2=' | '):
        """Serialise the selected elements of every quadruple into one string.

        Side effect: polarities are first rewritten in place to Chinese labels when
        ``self.language == 'zh'``, otherwise to short English labels.
        ``tuple_len`` selecting only polarity ([3]) returns ``merge_polarity()``.
        For [2, 3] (aspect | polarity), repeated entries are removed, keeping order.

        Raises:
            Exception: for an unsupported ``tuple_len`` type or ``format_code``, or
                when a quadruple has no usable 'polarity' entry.
        """
        tuple_index = self._parse_tuple_len(tuple_len)
        try:
            if self.language == 'zh':
                self.convert_polarity_en2zh()
            else:
                self.convert_polarity_zh2en()
        except (KeyError, TypeError) as err:
            print('语言参数错误', self.data)
            print(self.language)
            raise Exception('语言参数错误') from err
        if tuple_index == [3]:
            return self.merge_polarity()
        if format_code != 0:
            raise Exception('answer_format参数错误')
        new_text_list = []
        for quad_dict in self.data:
            fields = [quad_dict['target_text'], quad_dict['opinion_text'], quad_dict['aspect'], quad_dict['polarity']]
            new_text = sep_token2.join([fields[i] for i in tuple_index])
            # The same aspect can occur in several quadruples; list it once.
            if tuple_index == [2, 3] and new_text in new_text_list:
                continue
            new_text_list.append(new_text)
        return sep_token1.join(new_text_list)

    def compare_same(self, other) -> int:
        """Count quadruples shared with ``other``, matching each of other's quadruples at most once.

        Duplicated quadruples in self therefore cannot be credited more often than
        they occur in ``other``.
        """
        remaining = list(other.data)
        count = 0
        for quad_dict in self.data:
            if quad_dict in remaining:
                remaining.remove(quad_dict)
                count += 1
        return count

    def check_target_repeat(self):
        """Return True if some target text occurs more than once."""
        values = self._values('target_text')
        return len(values) != len(set(values))

    def check_opinion_repeat(self):
        """Return True if some opinion text occurs more than once."""
        values = self._values('opinion_text')
        return len(values) != len(set(values))

    def check_aspect_repeat(self):
        """Return True if some aspect occurs more than once."""
        values = self._values('aspect')
        return len(values) != len(set(values))

    def get_aspect_list(self):
        """Return all aspects, in order."""
        return self._values('aspect')

    def get_target_list(self):
        """Return all target texts, in order."""
        return self._values('target_text')

    def get_opinion_list(self):
        """Return all opinion texts, in order."""
        return self._values('opinion_text')

    def get_polarity_list(self):
        """Return all polarities, in order."""
        return self._values('polarity')

    def merge_polarity(self):
        """Summarise all polarities into one label.

        Positive and negative both present -> neutral; only positive -> positive;
        only negative -> negative; otherwise neutral. Labels in any accepted
        spelling are recognised. The result is 'pos'/'neg'/'neu' when
        ``self.language == 'en'`` and the Chinese label otherwise.
        """
        labels = {self.polarity_zh2en.get(p, p) for p in self.get_polarity_list()}
        has_pos, has_neg = 'pos' in labels, 'neg' in labels
        if has_pos and not has_neg:
            merged = 'pos'
        elif has_neg and not has_pos:
            merged = 'neg'
        else:
            merged = 'neu'
        return merged if self.language == 'en' else self.polarity_en2zh[merged]

    def check_opinion_in_comment(self, comment_text):
        """Return True if every opinion other than '*' is a substring of ``comment_text``."""
        return all(q['opinion_text'] == '*' or q['opinion_text'] in comment_text for q in self.data)

    def check_target_in_comment(self, comment_text):
        """Return True if every target other than '*' is a substring of ``comment_text``."""
        return all(q['target_text'] == '*' or q['target_text'] in comment_text for q in self.data)

    @staticmethod
    def get_similarity(units1, units2: 'CommentUnitsSim'):
        """Placeholder for a similarity between two quadruple sets; not implemented, returns None."""
        return None

    def apply(self, func: Callable, field: str):
        """Replace ``field`` of every quadruple with ``func(value)`` in place; return self."""
        for quad_dict in self.data:
            quad_dict[field] = func(quad_dict[field])
        return self
# Weighted optimal matching between predicted and reference quadruples.
class CommentUnitsMatch:
    """Match two quadruple sets by a weighted per-element cost.

    Pair cost = weighted sum of: 0/1 mismatch for polarity and aspect, and
    ``1 - ROUGE-L F1`` (via ``get_rougel_f1``) for target and opinion text.
    Weights are normalised to sum to 1. A feature whose value is missing or
    'None' in the first quadruple of either side is dropped (weight 0) for that
    comparison only.
    """

    def __init__(self, target_weight=0.5, opinion_weight=0.5, aspect_weight=0.5, polarity_weight=0.5):
        """Store the four weights normalised to sum to 1."""
        weight_sum = target_weight + opinion_weight + aspect_weight + polarity_weight
        self.target_weight = target_weight / weight_sum
        self.opinion_weight = opinion_weight / weight_sum
        self.aspect_weight = aspect_weight / weight_sum
        self.polarity_weight = polarity_weight / weight_sum

    def set_zero(self, feature: str = 'polarity'):
        """Set the weight of ``feature`` to 0 (names containing 'opinion'/'target' are accepted)."""
        if feature == 'polarity':
            self.polarity_weight = 0
        elif feature == 'aspect':
            self.aspect_weight = 0
        elif 'opinion' in feature:
            self.opinion_weight = 0
        elif 'target' in feature:
            self.target_weight = 0
        else:
            raise Exception('feature参数错误')

    def re_normalize(self):
        """Rescale the current weights so they sum to 1."""
        weight_sum = self.target_weight + self.opinion_weight + self.aspect_weight + self.polarity_weight
        self.target_weight = self.target_weight / weight_sum
        self.opinion_weight = self.opinion_weight / weight_sum
        self.aspect_weight = self.aspect_weight / weight_sum
        self.polarity_weight = self.polarity_weight / weight_sum

    @staticmethod
    def _feature_missing(units1, units2, feature):
        """True when ``feature`` is absent or 'None' in the first quadruple of either side."""
        first1 = units1.data[0].get(feature)
        first2 = units2.data[0].get(feature)
        return first1 is None or first2 is None or first1 == 'None' or first2 == 'None'

    def _drop_feature(self, units1, units2, feature):
        """Zero the weight of ``feature``, renormalise, and return an all-zero cost matrix."""
        self.set_zero(feature)
        self.re_normalize()
        return np.zeros((len(units1.data), len(units2.data)))

    def get_cost_matrix(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'polarity'):
        """Cost matrix (rows: units1, columns: units2): 0 if ``feature`` is equal, else 1.

        If the feature is missing (see class docstring), its weight is zeroed on
        this matcher and an all-zero matrix is returned.
        """
        if self._feature_missing(units1, units2, feature):
            return self._drop_feature(units1, units2, feature)
        return np.array([[0 if quad1[feature] == quad2[feature] else 1 for quad2 in units2.data]
                         for quad1 in units1.data])

    def get_cost_matrix_rouge(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'target_text'):
        """Cost matrix: 0 for equal text, otherwise 1 - ROUGE-L F1 of the two texts.

        If the feature is missing (see class docstring), its weight is zeroed on
        this matcher and an all-zero matrix is returned.
        """
        if self._feature_missing(units1, units2, feature):
            return self._drop_feature(units1, units2, feature)
        cost_matrix = []
        for quad1 in units1.data:
            row = []
            for quad2 in units2.data:
                if quad1[feature] == quad2[feature]:
                    row.append(0)
                else:
                    row.append(1 - get_rougel_f1([quad1[feature]], [quad2[feature]]))
            cost_matrix.append(row)
        return np.array(cost_matrix)

    def match_units(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', one_match=True) -> tuple:
        """Match predicted quadruples (units1) to reference quadruples (units2).

        With ``one_match`` an optimal one-to-one assignment is used
        (``scipy.optimize.linear_sum_assignment``); otherwise each reference takes
        its cheapest prediction (a prediction may be reused).

        Returns:
            (cost_per_quadruple, TP, FP, FN). Unmatched quadruples cost 1 each and
            the cost is divided by the larger set size. TP is the sum of matched
            scores (1 - cost); FP = len(units1) - TP; FN = len(units2) - TP.

        The weights are restored before returning, so a feature dropped for this
        comparison does not affect later comparisons made with the same matcher.
        """
        saved_weights = (self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight)
        try:
            cost_matrix_polarity = self.get_cost_matrix(units1, units2, feature='polarity')
            cost_matrix_aspect = self.get_cost_matrix(units1, units2, feature='aspect')
            cost_matrix_target = self.get_cost_matrix_rouge(units1, units2, feature='target_text')
            cost_matrix_opinion = self.get_cost_matrix_rouge(units1, units2, feature='opinion_text')
            # Rows: units1 (predictions); columns: units2 (references); entries in [0, 1].
            cost_matrix = self.target_weight * cost_matrix_target + self.opinion_weight * cost_matrix_opinion + \
                self.aspect_weight * cost_matrix_aspect + self.polarity_weight * cost_matrix_polarity
            score_matrix = 1 - cost_matrix
            if one_match:
                row_ind, col_ind = linear_sum_assignment(cost_matrix)
            else:
                # One-to-many: every reference picks its cheapest prediction.
                row_ind = np.argmin(cost_matrix, axis=0)
                col_ind = np.arange(len(units2.data))
            pairs = list(zip(row_ind, col_ind))
            max_units_num = max(units1.num, units2.num)
            cost = sum(cost_matrix[r][c] for r, c in pairs)
            TP = sum(score_matrix[r][c] for r, c in pairs)
            FP = units1.num - TP
            FN = units2.num - TP
            # Each quadruple left unmatched costs 1.
            cost += (max_units_num - len(pairs))
            cost_per_quadruple = cost / max_units_num
            if cost_per_quadruple > 1 or cost_per_quadruple < 0:
                print('cost错误', cost_per_quadruple, 'pred:', units1.data, 'true:', units2.data)
                print(self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight)
            return cost_per_quadruple, TP, FP, FN
        finally:
            (self.target_weight, self.opinion_weight,
             self.aspect_weight, self.polarity_weight) = saved_weights
class QuadMatch(evaluate.Metric):
    """Evaluation module scoring predicted sentiment quadruples against references.

    Reports an average optimal-matching score, a soft F1 derived from that
    matching, and an exact-match F1.
    """

    def _info(self):
        # Module metadata and the accepted input schemas.
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Either a list of alternative reference strings per prediction,
            # or a single reference string per prediction.
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence")),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            homepage="http://module.homepage",
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"]
        )

    def _download_and_prepare(self, dl_manager):
        """No external resources are needed."""
        pass

    def _compute(self,
                 predictions: List[str],
                 references: Union[List[str], List[List[str]]],
                 quad_weights: tuple = (1, 1, 1, 1),
                 **kwargs) -> dict:
        '''
        :param predictions: list of predictions of sentiment quads
        :param references: list of references of sentiment quads; an item may be a
            list of alternative answers, in which case the best-matching one is used
        :param quad_weights: weight of target,opinion,aspect,polarity for cost compute
        :param kwargs: parsing options forwarded to CommentUnitsSim.from_str:
        :param tuple_len: indicate the format of the quad, see the following mapping
        :param sep_token1: the token to separate quads
        :param sep_token2: the token to separate units of one quad
        :return: dict with the average matching score (1 - average cost), the soft
            F1 of the optimal matching and the exact-match F1
        #mapping
        id2prompt={'0123':"quadruples (target | opinion | aspect | polarity)",
                   '':"quadruples (target | opinion | aspect | polarity)",
                   '01':'pairs (target | opinion)',
                   '012':'triples (target | opinion | aspect)',
                   '013':'triples (target | opinion | polarity)',
                   '023':'triples (target | aspect | polarity)',
                   '23':'pairs (aspect | polarity)',
                   '03':'pairs (target | polarity)',
                   '13':'pairs (opinion | polarity)',
                   '3':'single (polarity)'}
        #Chinese version of the mapping
        id2prompt_zh={'0123': "四元组(对象 | 观点 | 方面 | 极性)",
                      '':"四元组(对象 | 观点 | 方面 | 极性)",
                      '01':'二元组(对象 | 观点)',
                      '012':'三元组(对象 | 观点 | 方面)',
                      '013':'三元组(对象 | 观点 | 极性)',
                      '023':'三元组(对象 | 方面 | 极性)',
                      '23':'二元组(方面 | 极性)',
                      '03':'二元组(对象 | 极性)',
                      '13':'二元组(观点 | 极性)',
                      '3':'单元素(极性)'}
        '''
        # Wrap single strings first, so the length check compares sample counts
        # rather than string lengths.
        if isinstance(predictions, str):
            predictions = [predictions]
            references = [references]
        assert len(predictions) == len(references)
        total_cost = 0
        n_scored = 0
        TP, FP, FN = 0, 0, 0
        for pred_str, true in zip(predictions, references):
            pred = CommentUnitsSim.from_str(pred_str, **kwargs)
            if isinstance(true, str):
                candidates = [CommentUnitsSim.from_str(true, **kwargs)]
            elif isinstance(true, list):
                candidates = [CommentUnitsSim.from_str(t, **kwargs) for t in true]
            else:
                print("true的类型不对", true)
                continue
            # A fresh matcher per comparison: match_units may drop features
            # (zero their weights), which must not leak into other comparisons.
            results = [CommentUnitsMatch(*quad_weights).match_units(pred, candidate, one_match=True)
                       for candidate in candidates]
            # Lowest cost = best-matching reference (first one on ties).
            cost_, TP_, FP_, FN_ = min(results, key=lambda result: result[0])
            total_cost += cost_
            TP += TP_
            FP += FP_
            FN += FN_
            n_scored += 1
        # Average over the samples actually scored; skipped samples must not
        # count as perfect matches.
        match_score = 1 - total_cost / n_scored if n_scored else 0.0
        # Soft F1 of the optimal matching, guarded against empty totals.
        precision_match = TP / (TP + FP) if TP + FP else 0.0
        recall_match = TP / (TP + FN) if TP + FN else 0.0
        denominator = precision_match + recall_match
        f1_match = 2 * precision_match * recall_match / denominator if denominator else 0.0
        f1 = compute_quadruple_f1(y_pred=predictions, y_true=references, **kwargs)
        return {'ave match score of weight ' + str(quad_weights): match_score,
                'f1 score of optimal match of weight ' + str(quad_weights): f1_match,
                'f1 score of exact match': f1}