Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- README.md +6 -6
- quad_match_score.py +261 -237
README.md
CHANGED
@@ -47,9 +47,9 @@ references=["food | good | food#taste | pos & service | bad | service#general |
|
|
47 |
result=module.compute(predictions=predictions, references=references)
|
48 |
print(result)
|
49 |
|
50 |
-
result={'
|
51 |
-
'f1
|
52 |
-
'
|
53 |
```
|
54 |
|
55 |
### Inputs
|
@@ -78,9 +78,9 @@ result={'ave match score of weight (1, 1, 1, 1)': 0.375,
|
|
78 |
|
79 |
*最优匹配 f1值、最优匹配样本平均得分、完全匹配 f1值(传统评估) 组成的dict,f1值均在[0,1]之间*
|
80 |
|
81 |
-
*例如:
|
82 |
-
'f1
|
83 |
-
'
|
84 |
|
85 |
|
86 |
## Limitations and Bias
|
|
|
47 |
result=module.compute(predictions=predictions, references=references)
|
48 |
print(result)
|
49 |
|
50 |
+
result={'f1 of exact match': 0.6667,
|
51 |
+
'f1 of optimal match of weight (1, 1, 1, 1)': 0.6666666666666666,
|
52 |
+
'score of optimal match of weight (1, 1, 1, 1)': 0.5}
|
53 |
```
|
54 |
|
55 |
### Inputs
|
|
|
78 |
|
79 |
*最优匹配 f1值、最优匹配样本平均得分、完全匹配 f1值(传统评估) 组成的dict,f1值均在[0,1]之间*
|
80 |
|
81 |
+
*例如:{'f1 of exact match': 0.6667,
|
82 |
+
'f1 of optimal match of weight (1, 1, 1, 1)': 0.6666666666666666,
|
83 |
+
'score of optimal match of weight (1, 1, 1, 1)': 0.5}*
|
84 |
|
85 |
|
86 |
## Limitations and Bias
|
quad_match_score.py
CHANGED
@@ -15,10 +15,9 @@
|
|
15 |
|
16 |
import copy
|
17 |
import re
|
18 |
-
from typing import List, Dict, Union,Callable
|
19 |
import numpy as np
|
20 |
|
21 |
-
|
22 |
import datasets
|
23 |
import evaluate
|
24 |
from rouge_chinese import Rouge
|
@@ -27,7 +26,7 @@ from scipy.optimize import linear_sum_assignment
|
|
27 |
# TODO: Add BibTeX citation
|
28 |
_CITATION = """\
|
29 |
@InProceedings{huggingface:module,
|
30 |
-
title = {
|
31 |
authors={huggingface, Inc.},
|
32 |
year={2020}
|
33 |
}
|
@@ -39,7 +38,6 @@ evaluate sentiment quadruples.
|
|
39 |
评估生成模型的情感四元组
|
40 |
"""
|
41 |
|
42 |
-
|
43 |
# TODO: Add description of the arguments of the module here
|
44 |
_KWARGS_DESCRIPTION = """
|
45 |
Calculates how good are predictions given some references, using certain scores
|
@@ -55,53 +53,22 @@ Examples:
|
|
55 |
Examples should be written in doctest format, and should illustrate how
|
56 |
to use the function.
|
57 |
|
58 |
-
>>>
|
59 |
-
>>>
|
60 |
-
>>>
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
62 |
"""
|
63 |
|
64 |
|
65 |
-
def compute_quadruple_f1(y_pred: List[str], y_true: Union[List[str], List[List[str]]],
|
66 |
-
return_rp=False, **kwargs) -> Dict[str, float]:
|
67 |
-
assert len(y_pred) == len(y_true)
|
68 |
-
correct, pred_num, true_num = 0, 0, 0
|
69 |
-
|
70 |
-
for pred, true in zip(y_pred, y_true):
|
71 |
-
|
72 |
-
pred = CommentUnitsSim.from_str(pred, **kwargs)
|
73 |
-
# 如果true是list,说明有多个正确答案
|
74 |
-
if isinstance(true, str):
|
75 |
-
true = CommentUnitsSim.from_str(true, **kwargs)
|
76 |
-
else:
|
77 |
-
true = [CommentUnitsSim.from_str(t,**kwargs) for t in true]
|
78 |
-
|
79 |
-
# 如果true是list,说明有多个正确答案,取最高分
|
80 |
-
if isinstance(true, list):
|
81 |
-
correct_list = [pred.compare_same(t) for t in true]
|
82 |
-
correct += max(correct_list) # 获取得分最高的值
|
83 |
-
correct_index = correct_list.index(max(correct_list)) # 获取得分最高的索引
|
84 |
-
pred_num += pred.num
|
85 |
-
true_num += true[correct_index].num
|
86 |
-
else:
|
87 |
-
correct += pred.compare_same(true)
|
88 |
-
pred_num += pred.num
|
89 |
-
true_num += true.num
|
90 |
-
|
91 |
-
# 以下结果保留4位小数
|
92 |
-
precision = round(correct / pred_num, 4) + 1e-8
|
93 |
-
recall = round(correct / true_num, 4) + 1e-8
|
94 |
-
f1 = round(2 * precision * recall / (precision + recall), 4)
|
95 |
-
|
96 |
-
if return_rp:
|
97 |
-
return {"precision": precision, "recall": recall, "f1": f1}
|
98 |
-
else:
|
99 |
-
return f1
|
100 |
-
|
101 |
# 计算rougel的f1值
|
102 |
def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
|
103 |
assert len(text_pred_list) == len(text_true_list), "文本数量不一致"
|
104 |
-
|
105 |
if not text_pred_list[0].strip():
|
106 |
return 0
|
107 |
|
@@ -115,12 +82,13 @@ def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float
|
|
115 |
|
116 |
return rouge_l_f1
|
117 |
|
|
|
118 |
# 记录四元组的函数
|
119 |
class CommentUnitsSim:
|
120 |
-
def __init__(self, data: List[Dict[str, str]],data_source:any=None,abnormal=False,language=None):
|
121 |
-
self.data_source=data_source
|
122 |
-
self.abnormal=abnormal
|
123 |
-
data=copy.deepcopy(data)
|
124 |
# 如果字典有target,则改名为target_text
|
125 |
for quad_dict in data:
|
126 |
if 'target' in quad_dict:
|
@@ -131,73 +99,79 @@ class CommentUnitsSim:
|
|
131 |
del quad_dict['opinion']
|
132 |
|
133 |
self.data = data
|
134 |
-
self.polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性','pos':'积极','neg':'消极',
|
135 |
-
|
|
|
|
|
136 |
|
137 |
-
self.language=language if language is not None else 'zh' if self.check_zh() else 'en'
|
138 |
-
self.none_sign='null'
|
139 |
|
140 |
@property
|
141 |
def num(self):
|
142 |
return len(self.data)
|
143 |
|
144 |
-
|
145 |
def check_zh(self):
|
146 |
for quad_dict in self.data:
|
147 |
-
if re.search('[\u4e00-\u9fa5]',quad_dict['target_text']) or re.search('[\u4e00-\u9fa5]',
|
|
|
148 |
return True
|
149 |
return False
|
150 |
|
151 |
# 检测极性是否正确
|
152 |
def check_polarity(self):
|
153 |
-
|
154 |
for quad_dict in self.data:
|
155 |
-
if quad_dict['polarity'] not in ['positive', 'negative', 'neutral','pos','neg','neu','积极','消极',
|
156 |
-
|
|
|
157 |
return False
|
158 |
|
159 |
-
|
160 |
def convert_polarity_en2zh(self):
|
161 |
for quad_dict in self.data:
|
162 |
-
quad_dict['polarity']=self.polarity_en2zh[quad_dict['polarity']]
|
163 |
return self
|
164 |
|
165 |
-
|
166 |
def convert_polarity_zh2en(self):
|
167 |
for quad_dict in self.data:
|
168 |
-
quad_dict['polarity']=self.polarity_zh2en[quad_dict['polarity']]
|
169 |
return self
|
170 |
|
171 |
-
|
172 |
def del_duplicate(self):
|
173 |
-
new_data=[]
|
174 |
for quad_dict in self.data:
|
175 |
if quad_dict not in new_data:
|
176 |
new_data.append(quad_dict)
|
177 |
-
self.data=new_data
|
178 |
return self
|
179 |
|
180 |
-
|
181 |
def check_target_opinion_null(self):
|
182 |
for quad_dict in self.data:
|
183 |
-
if quad_dict['target_text']=='null' and quad_dict['opinion_text']=='null':
|
184 |
return True
|
185 |
return False
|
186 |
|
187 |
-
|
188 |
def check_any_null(self):
|
189 |
for quad_dict in self.data:
|
190 |
-
if quad_dict['target_text']=='null' or quad_dict['opinion_text']=='null':
|
191 |
return True
|
192 |
return False
|
193 |
|
194 |
@classmethod
|
195 |
-
def from_str(cls, quadruple_str: str, tuple_len:Union[int,list,str]=4, format_code=0, sep_token1=' & ',
|
|
|
196 |
data = []
|
197 |
-
abnormal=False
|
198 |
-
|
199 |
-
for i in range(len(quadruple_str)-1):
|
200 |
-
if (quadruple_str[i] == sep_token1.strip() or quadruple_str[i] == sep_token2.strip()) and quadruple_str[
|
|
|
201 |
quadruple_str = quadruple_str[:i + 1] + ' ' + quadruple_str[i + 1:]
|
202 |
|
203 |
# 选择几元组,即创建列表索引,从四元组中抽出n元
|
@@ -211,27 +185,27 @@ class CommentUnitsSim:
|
|
211 |
else:
|
212 |
raise Exception('tuple_len参数错误')
|
213 |
|
214 |
-
|
215 |
for quadruple in quadruple_str.split(sep_token1):
|
216 |
if format_code == 0:
|
217 |
# quadruple可能是target|opinion|aspect|polarity,也可能是target|opinion|aspect,也可能是target|opinion,若没有则为“None”
|
218 |
-
quadruple_split=[unit.strip() for unit in quadruple.split(sep_token2)]
|
219 |
-
if len(quadruple_split)>len(tuple_index):
|
220 |
print('quadruple格式错误,过多元素', quadruple_str)
|
221 |
-
abnormal=True
|
222 |
-
quadruple_split=quadruple_split[0:len(tuple_index)]
|
223 |
-
elif len(quadruple_split)<len(tuple_index):
|
224 |
print('quadruple格式错误,过少元素', quadruple_str)
|
225 |
-
abnormal=True
|
226 |
-
quadruple_split=["None"]*(
|
|
|
227 |
|
228 |
-
quadruple_keys=[["target_text","opinion_text","aspect","polarity"][i] for i in tuple_index]
|
229 |
-
quadruple_dict=dict(zip(quadruple_keys,quadruple_split))
|
230 |
|
231 |
q = {"target_text": 'None', "opinion_text": 'None', "aspect": 'None', "polarity": 'None'}
|
232 |
q.update(quadruple_dict)
|
233 |
-
|
234 |
-
if q['polarity'] not in ['pos','neg','neu','None','积极','消极','中性']:
|
235 |
print('quadruple格式错误,极性格式不对', quadruple_str)
|
236 |
|
237 |
else:
|
@@ -239,10 +213,10 @@ class CommentUnitsSim:
|
|
239 |
|
240 |
data.append(q)
|
241 |
|
242 |
-
return CommentUnitsSim(data,quadruple_str,abnormal)
|
243 |
|
244 |
@classmethod
|
245 |
-
def from_list(cls, quadruple_list: List[List[str]]
|
246 |
data = []
|
247 |
for quadruple in quadruple_list:
|
248 |
# #format_code='013'代表list只有四元组的第0、1、3个元素,需要扩充为4元组,空缺位置补上None
|
@@ -253,10 +227,10 @@ class CommentUnitsSim:
|
|
253 |
{"target_text": quadruple[0], "opinion_text": quadruple[1], "aspect": quadruple[2],
|
254 |
"polarity": quadruple[3]})
|
255 |
|
256 |
-
return CommentUnitsSim(data,quadruple_list
|
257 |
|
258 |
@classmethod
|
259 |
-
def from_list_dict(cls, quadruple_list: List[dict]
|
260 |
for quad_dict in quadruple_list:
|
261 |
if 'target' in quad_dict:
|
262 |
quad_dict['target_text'] = quad_dict['target']
|
@@ -267,22 +241,24 @@ class CommentUnitsSim:
|
|
267 |
|
268 |
data = []
|
269 |
for quadruple in quadruple_list:
|
270 |
-
|
271 |
-
q={"target_text":'None',"opinion_text":'None',"aspect":'None',"polarity":'None'}
|
272 |
q.update(quadruple)
|
273 |
data.append(q)
|
274 |
|
275 |
-
return CommentUnitsSim(data,quadruple_list
|
276 |
|
277 |
-
|
278 |
def to_list(self):
|
279 |
data = []
|
280 |
for quad_dict in self.data:
|
281 |
-
data.append(
|
|
|
282 |
return data
|
283 |
|
284 |
# 将data转换为n元组字符串
|
285 |
-
def get_quadruple_str(self, format_code=0, tuple_len:Union[int,list,str]=4,sep_token1=' & ',
|
|
|
286 |
new_text_list = []
|
287 |
# 选择几元组,即创建列表索引,从四元组中抽出n元
|
288 |
if isinstance(tuple_len, int):
|
@@ -296,18 +272,18 @@ class CommentUnitsSim:
|
|
296 |
raise Exception('tuple_len参数错误')
|
297 |
|
298 |
try:
|
299 |
-
|
300 |
-
if self.language=='zh':
|
301 |
self.convert_polarity_en2zh()
|
302 |
else:
|
303 |
self.convert_polarity_zh2en()
|
304 |
except:
|
305 |
-
print('
|
306 |
print(self.language)
|
307 |
raise Exception('语言参数错误')
|
308 |
|
309 |
-
|
310 |
-
if tuple_index==[3]:
|
311 |
return self.merge_polarity()
|
312 |
|
313 |
for quad_dict in self.data:
|
@@ -320,7 +296,6 @@ class CommentUnitsSim:
|
|
320 |
# 提取polarity
|
321 |
polarity = quad_dict['polarity']
|
322 |
|
323 |
-
|
324 |
# 拼接,‘|’分割
|
325 |
if format_code == 0:
|
326 |
# 根据tuple_len拼接
|
@@ -330,24 +305,24 @@ class CommentUnitsSim:
|
|
330 |
|
331 |
new_text_list.append(new_text)
|
332 |
|
333 |
-
|
334 |
-
if tuple_index==[2,3]:
|
335 |
res = []
|
336 |
for t in new_text_list:
|
337 |
if t not in res:
|
338 |
res.append(t)
|
339 |
-
new_text_list=res
|
340 |
|
341 |
-
|
342 |
-
elif tuple_index==[3]:
|
343 |
-
new_text_list=new_text_list[:1]
|
344 |
|
345 |
if format_code == 0:
|
346 |
# 根据tuple_len拼接
|
347 |
return sep_token1.join(new_text_list)
|
348 |
|
349 |
# 与另一个CommentUnits对象对比,检测有几个相同的四元组
|
350 |
-
def compare_same(self, other)->int:
|
351 |
count = 0
|
352 |
for quad_dict in self.data:
|
353 |
if quad_dict in other.data:
|
@@ -403,10 +378,10 @@ class CommentUnitsSim:
|
|
403 |
polarity_list.append(quad_dict['polarity'])
|
404 |
return polarity_list
|
405 |
|
406 |
-
|
407 |
def merge_polarity(self):
|
408 |
polarity_list = self.get_polarity_list()
|
409 |
-
|
410 |
if self.language == 'en':
|
411 |
if 'pos' in polarity_list and 'neg' in polarity_list:
|
412 |
return 'neu'
|
@@ -426,44 +401,47 @@ class CommentUnitsSim:
|
|
426 |
else:
|
427 |
return '中性'
|
428 |
|
429 |
-
|
430 |
def check_opinion_in_comment(self, comment_text):
|
431 |
for quad_dict in self.data:
|
432 |
-
if quad_dict['opinion_text'] !='*'
|
433 |
return False
|
434 |
return True
|
435 |
|
436 |
-
|
437 |
-
def check_target_in_comment(self,comment_text):
|
438 |
for quad_dict in self.data:
|
439 |
-
if quad_dict['target_text'] !='*'
|
440 |
return False
|
441 |
return True
|
442 |
|
443 |
-
|
444 |
@staticmethod
|
445 |
def get_similarity(units1, units2: 'CommentUnitsSim'):
|
446 |
pass
|
447 |
|
448 |
-
|
449 |
-
def apply(self,func:Callable,field:str):
|
450 |
for quad_dict in self.data:
|
451 |
quad_dict[field] = func(quad_dict[field])
|
452 |
return self
|
453 |
|
454 |
|
455 |
-
|
456 |
class CommentUnitsMatch:
|
457 |
-
def __init__(self,target_weight=0.5,opinion_weight=0.5,aspect_weight=0.5,polarity_weight=0.5):
|
458 |
-
|
459 |
-
weight_sum = target_weight+opinion_weight+aspect_weight+polarity_weight
|
460 |
-
self.target_weight = target_weight/weight_sum
|
461 |
-
self.opinion_weight = opinion_weight/weight_sum
|
462 |
-
self.aspect_weight = aspect_weight/weight_sum
|
463 |
-
self.polarity_weight = polarity_weight/weight_sum
|
464 |
-
|
465 |
-
|
466 |
-
|
|
|
|
|
|
|
467 |
if feature == 'polarity':
|
468 |
self.polarity_weight = 0
|
469 |
elif feature == 'aspect':
|
@@ -476,21 +454,20 @@ class CommentUnitsMatch:
|
|
476 |
raise Exception('feature参数错误')
|
477 |
|
478 |
def re_normalize(self):
|
479 |
-
weight_sum = self.target_weight+self.opinion_weight+self.aspect_weight+self.polarity_weight
|
480 |
-
self.target_weight = self.target_weight/weight_sum
|
481 |
-
self.opinion_weight = self.opinion_weight/weight_sum
|
482 |
-
self.aspect_weight = self.aspect_weight/weight_sum
|
483 |
-
self.polarity_weight = self.polarity_weight/weight_sum
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
def get_cost_matrix(self,units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',feature:str='polarity'):
|
488 |
pass
|
489 |
-
|
490 |
-
if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None\
|
491 |
-
or units1.data[0].get(feature)=='None' or units2.data[0].get(feature)=='None':
|
492 |
-
cost_matrix = np.zeros((len(units1.data),len(units2.data)))
|
493 |
-
|
494 |
self.set_zero(feature)
|
495 |
|
496 |
# 并再次归一化
|
@@ -498,7 +475,7 @@ class CommentUnitsMatch:
|
|
498 |
|
499 |
return cost_matrix
|
500 |
|
501 |
-
|
502 |
cost_matrix = []
|
503 |
for quad_dict1 in units1.data:
|
504 |
cost_list = []
|
@@ -509,23 +486,23 @@ class CommentUnitsMatch:
|
|
509 |
cost_list.append(1)
|
510 |
cost_matrix.append(cost_list)
|
511 |
|
512 |
-
#cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
|
513 |
cost_matrix = np.array(cost_matrix)
|
514 |
return cost_matrix
|
515 |
|
516 |
-
|
517 |
-
def get_cost_matrix_rouge(self,units1: 'CommentUnitsSim', units2: 'CommentUnitsSim',feature:str='target_text'):
|
518 |
-
|
519 |
-
if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None\
|
520 |
-
or units1.data[0].get(feature)=='None' or units2.data[0].get(feature)=='None':
|
521 |
-
cost_matrix = np.zeros((len(units1.data),len(units2.data)))
|
522 |
-
|
523 |
self.set_zero(feature)
|
524 |
# 并再次归一化
|
525 |
self.re_normalize()
|
526 |
return cost_matrix
|
527 |
|
528 |
-
|
529 |
cost_matrix = []
|
530 |
for quad_dict1 in units1.data:
|
531 |
cost_list = []
|
@@ -533,63 +510,71 @@ class CommentUnitsMatch:
|
|
533 |
if quad_dict1[feature] == quad_dict2[feature]:
|
534 |
cost_list.append(0)
|
535 |
else:
|
536 |
-
cost_list.append(1-get_rougel_f1([quad_dict1[feature]],[quad_dict2[feature]]))
|
537 |
cost_matrix.append(cost_list)
|
538 |
|
539 |
-
#cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
|
540 |
cost_matrix = np.array(cost_matrix)
|
541 |
return cost_matrix
|
542 |
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
|
|
|
|
|
|
|
|
559 |
row_ind, col_ind = linear_sum_assignment(cost_matrix)
|
|
|
560 |
else:
|
561 |
-
|
562 |
-
|
563 |
-
|
|
|
564 |
|
565 |
-
|
|
|
|
|
566 |
|
567 |
-
|
568 |
-
cost = 0
|
569 |
for i in range(len(row_ind)):
|
570 |
cost += cost_matrix[row_ind[i]][col_ind[i]]
|
571 |
|
572 |
-
|
573 |
TP = 0
|
574 |
for i in range(len(row_ind)):
|
575 |
TP += score_matrix[row_ind[i]][col_ind[i]]
|
576 |
|
577 |
-
#len(row_ind)为pred的数量,TP为匹配上的数量
|
578 |
-
FP = units1.num-TP
|
579 |
-
FN = units2.num-TP
|
580 |
-
|
581 |
|
582 |
-
|
583 |
-
|
|
|
|
|
584 |
|
585 |
-
|
586 |
-
|
|
|
|
|
|
|
587 |
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
#返回的cost在0-1之间
|
592 |
-
return cost_per_quadruple,TP,FP,FN
|
593 |
|
594 |
|
595 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
@@ -632,9 +617,9 @@ class QuadMatch(evaluate.Metric):
|
|
632 |
pass
|
633 |
|
634 |
def _compute(self,
|
635 |
-
predictions:List[str],
|
636 |
-
references: Union[List[str],List[List[str]]],
|
637 |
-
quad_weights:tuple=(1,1,1,1),
|
638 |
**kwargs) -> dict:
|
639 |
'''
|
640 |
|
@@ -673,55 +658,94 @@ class QuadMatch(evaluate.Metric):
|
|
673 |
'13':'二元组(观点 | 极性)',
|
674 |
'3':'单元素(极性)'}
|
675 |
'''
|
|
|
|
|
|
|
676 |
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
|
688 |
-
|
|
|
689 |
# 如果true是list,说明有多个正确答案
|
690 |
if isinstance(true, str):
|
691 |
true = CommentUnitsSim.from_str(true, **kwargs)
|
692 |
-
elif isinstance(true, list):
|
693 |
-
true=[CommentUnitsSim.from_str(t, **kwargs) for t in true]
|
694 |
else:
|
695 |
-
|
696 |
-
continue
|
697 |
|
698 |
-
|
699 |
if isinstance(true, list):
|
700 |
-
|
701 |
-
#
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
FP+=FP_
|
706 |
-
FN+=FN_
|
707 |
-
|
708 |
else:
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
import copy
|
17 |
import re
|
18 |
+
from typing import List, Dict, Union, Callable
|
19 |
import numpy as np
|
20 |
|
|
|
21 |
import datasets
|
22 |
import evaluate
|
23 |
from rouge_chinese import Rouge
|
|
|
26 |
# TODO: Add BibTeX citation
|
27 |
_CITATION = """\
|
28 |
@InProceedings{huggingface:module,
|
29 |
+
title = {quad match score},
|
30 |
authors={huggingface, Inc.},
|
31 |
year={2020}
|
32 |
}
|
|
|
38 |
评估生成模型的情感四元组
|
39 |
"""
|
40 |
|
|
|
41 |
# TODO: Add description of the arguments of the module here
|
42 |
_KWARGS_DESCRIPTION = """
|
43 |
Calculates how good are predictions given some references, using certain scores
|
|
|
53 |
Examples should be written in doctest format, and should illustrate how
|
54 |
to use the function.
|
55 |
|
56 |
+
>>> import evaluate
|
57 |
+
>>> module = evaluate.load("yuyijiong/quad_match_score")
|
58 |
+
>>> predictions=["food | good | food#taste | pos"]
|
59 |
+
>>> references=["food | good | food#taste | pos & service | bad | service#general | neg"]
|
60 |
+
>>> result=module.compute(predictions=predictions, references=references)
|
61 |
+
>>> print(result)
|
62 |
+
result={'ave match score of weight (1, 1, 1, 1)': 0.375,
|
63 |
+
'f1 score of exact match': 0.0,
|
64 |
+
'f1 score of optimal match of weight (1, 1, 1, 1)': 0.5}
|
65 |
"""
|
66 |
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
# 计算rougel的f1值
|
69 |
def get_rougel_f1(text_pred_list: List[str], text_true_list: List[str]) -> float:
|
70 |
assert len(text_pred_list) == len(text_true_list), "文本数量不一致"
|
71 |
+
# 如果text_pred_list[0]为空字符串或空格,则返回0
|
72 |
if not text_pred_list[0].strip():
|
73 |
return 0
|
74 |
|
|
|
82 |
|
83 |
return rouge_l_f1
|
84 |
|
85 |
+
|
86 |
# 记录四元组的函数
|
87 |
class CommentUnitsSim:
|
88 |
+
def __init__(self, data: List[Dict[str, str]], data_source: any = None, abnormal=False, language=None):
|
89 |
+
self.data_source = data_source
|
90 |
+
self.abnormal = abnormal
|
91 |
+
data = copy.deepcopy(data)
|
92 |
# 如果字典有target,则改名为target_text
|
93 |
for quad_dict in data:
|
94 |
if 'target' in quad_dict:
|
|
|
99 |
del quad_dict['opinion']
|
100 |
|
101 |
self.data = data
|
102 |
+
self.polarity_en2zh = {'positive': '积极', 'negative': '消极', 'neutral': '中性', 'pos': '积极', 'neg': '消极',
|
103 |
+
'neu': '中性', '积极': '积极', '消极': '消极', '中性': '中性'}
|
104 |
+
self.polarity_zh2en = {'积极': 'pos', '消极': 'neg', '中性': 'neu', 'pos': 'pos', 'neg': 'neg', 'neu': 'neu',
|
105 |
+
'positive': 'pos', 'negative': 'neg', 'neutral': 'neu'}
|
106 |
|
107 |
+
self.language = language if language is not None else 'zh' if self.check_zh() else 'en'
|
108 |
+
self.none_sign = 'null'
|
109 |
|
110 |
@property
|
111 |
def num(self):
|
112 |
return len(self.data)
|
113 |
|
114 |
+
# 检查四元组中是否有中文
|
115 |
def check_zh(self):
|
116 |
for quad_dict in self.data:
|
117 |
+
if re.search('[\u4e00-\u9fa5]', quad_dict['target_text']) or re.search('[\u4e00-\u9fa5]',
|
118 |
+
quad_dict['opinion_text']):
|
119 |
return True
|
120 |
return False
|
121 |
|
122 |
# 检测极性是否正确
|
123 |
def check_polarity(self):
|
124 |
+
# 若有某个四元组的极性不是positive、negative、neutral,则返回False
|
125 |
for quad_dict in self.data:
|
126 |
+
if quad_dict['polarity'] not in ['positive', 'negative', 'neutral', 'pos', 'neg', 'neu', '积极', '消极',
|
127 |
+
'中性']:
|
128 |
+
self.abnormal = True
|
129 |
return False
|
130 |
|
131 |
+
# 将极性由英文转为中文
|
132 |
def convert_polarity_en2zh(self):
|
133 |
for quad_dict in self.data:
|
134 |
+
quad_dict['polarity'] = self.polarity_en2zh[quad_dict['polarity']]
|
135 |
return self
|
136 |
|
137 |
+
# 将极性由中文转为英文
|
138 |
def convert_polarity_zh2en(self):
|
139 |
for quad_dict in self.data:
|
140 |
+
quad_dict['polarity'] = self.polarity_zh2en[quad_dict['polarity']]
|
141 |
return self
|
142 |
|
143 |
+
# 检查是否有重复的四元组,若有则删除重复的
|
144 |
def del_duplicate(self):
|
145 |
+
new_data = []
|
146 |
for quad_dict in self.data:
|
147 |
if quad_dict not in new_data:
|
148 |
new_data.append(quad_dict)
|
149 |
+
self.data = new_data
|
150 |
return self
|
151 |
|
152 |
+
# 检查是否有target和opinion都为null的四元组,若有则返回True
|
153 |
def check_target_opinion_null(self):
|
154 |
for quad_dict in self.data:
|
155 |
+
if quad_dict['target_text'] == 'null' and quad_dict['opinion_text'] == 'null':
|
156 |
return True
|
157 |
return False
|
158 |
|
159 |
+
# 检查是否有target或opinion为null的四元组,若有则返回True
|
160 |
def check_any_null(self):
|
161 |
for quad_dict in self.data:
|
162 |
+
if quad_dict['target_text'] == 'null' or quad_dict['opinion_text'] == 'null':
|
163 |
return True
|
164 |
return False
|
165 |
|
166 |
@classmethod
|
167 |
+
def from_str(cls, quadruple_str: str, tuple_len: Union[int, list, str] = 4, format_code=0, sep_token1=' & ',
|
168 |
+
sep_token2=' | '):
|
169 |
data = []
|
170 |
+
abnormal = False
|
171 |
+
# 确保分隔符后面一定是空格
|
172 |
+
for i in range(len(quadruple_str) - 1):
|
173 |
+
if (quadruple_str[i] == sep_token1.strip() or quadruple_str[i] == sep_token2.strip()) and quadruple_str[
|
174 |
+
i + 1] != ' ':
|
175 |
quadruple_str = quadruple_str[:i + 1] + ' ' + quadruple_str[i + 1:]
|
176 |
|
177 |
# 选择几元组,即创建列表索引,从四元组中抽出n元
|
|
|
185 |
else:
|
186 |
raise Exception('tuple_len参数错误')
|
187 |
|
|
|
188 |
for quadruple in quadruple_str.split(sep_token1):
|
189 |
if format_code == 0:
|
190 |
# quadruple可能是target|opinion|aspect|polarity,也可能是target|opinion|aspect,也可能是target|opinion,若没有则为“None”
|
191 |
+
quadruple_split = [unit.strip() for unit in quadruple.split(sep_token2)]
|
192 |
+
if len(quadruple_split) > len(tuple_index):
|
193 |
print('quadruple格式错误,过多元素', quadruple_str)
|
194 |
+
abnormal = True
|
195 |
+
quadruple_split = quadruple_split[0:len(tuple_index)] # 过长则截断
|
196 |
+
elif len(quadruple_split) < len(tuple_index):
|
197 |
print('quadruple格式错误,过少元素', quadruple_str)
|
198 |
+
abnormal = True
|
199 |
+
quadruple_split = ["None"] * (
|
200 |
+
len(tuple_index) - len(quadruple_split)) + quadruple_split # 过短则补'None'
|
201 |
|
202 |
+
quadruple_keys = [["target_text", "opinion_text", "aspect", "polarity"][i] for i in tuple_index]
|
203 |
+
quadruple_dict = dict(zip(quadruple_keys, quadruple_split))
|
204 |
|
205 |
q = {"target_text": 'None', "opinion_text": 'None', "aspect": 'None', "polarity": 'None'}
|
206 |
q.update(quadruple_dict)
|
207 |
+
# 检查极性是否合法
|
208 |
+
if q['polarity'] not in ['pos', 'neg', 'neu', 'None', '积极', '消极', '中性']:
|
209 |
print('quadruple格式错误,极性格式不对', quadruple_str)
|
210 |
|
211 |
else:
|
|
|
213 |
|
214 |
data.append(q)
|
215 |
|
216 |
+
return CommentUnitsSim(data, quadruple_str, abnormal)
|
217 |
|
218 |
@classmethod
|
219 |
+
def from_list(cls, quadruple_list: List[List[str]], **kwargs):
|
220 |
data = []
|
221 |
for quadruple in quadruple_list:
|
222 |
# #format_code='013'代表list只有四元组的第0、1、3个元素,需要扩充为4元组,空缺位置补上None
|
|
|
227 |
{"target_text": quadruple[0], "opinion_text": quadruple[1], "aspect": quadruple[2],
|
228 |
"polarity": quadruple[3]})
|
229 |
|
230 |
+
return CommentUnitsSim(data, quadruple_list, **kwargs)
|
231 |
|
232 |
@classmethod
|
233 |
+
def from_list_dict(cls, quadruple_list: List[dict], **kwargs):
|
234 |
for quad_dict in quadruple_list:
|
235 |
if 'target' in quad_dict:
|
236 |
quad_dict['target_text'] = quad_dict['target']
|
|
|
241 |
|
242 |
data = []
|
243 |
for quadruple in quadruple_list:
|
244 |
+
# 如果quadruple缺少某个key,则补上None
|
245 |
+
q = {"target_text": 'None', "opinion_text": 'None', "aspect": 'None', "polarity": 'None'}
|
246 |
q.update(quadruple)
|
247 |
data.append(q)
|
248 |
|
249 |
+
return CommentUnitsSim(data, quadruple_list, **kwargs)
|
250 |
|
251 |
+
# 转化为list,即只保留字典的value
|
252 |
def to_list(self):
|
253 |
data = []
|
254 |
for quad_dict in self.data:
|
255 |
+
data.append(
|
256 |
+
[quad_dict['target_text'], quad_dict['opinion_text'], quad_dict['aspect'], quad_dict['polarity']])
|
257 |
return data
|
258 |
|
259 |
# 将data转换为n元组字符串
|
260 |
+
def get_quadruple_str(self, format_code=0, tuple_len: Union[int, list, str] = 4, sep_token1=' & ',
|
261 |
+
sep_token2=' | '):
|
262 |
new_text_list = []
|
263 |
# 选择几元组,即创建列表索引,从四元组中抽出n元
|
264 |
if isinstance(tuple_len, int):
|
|
|
272 |
raise Exception('tuple_len参数错误')
|
273 |
|
274 |
try:
|
275 |
+
# 若语言为中文,则使用中文极性
|
276 |
+
if self.language == 'zh':
|
277 |
self.convert_polarity_en2zh()
|
278 |
else:
|
279 |
self.convert_polarity_zh2en()
|
280 |
except:
|
281 |
+
print('语言参数错误', self.data)
|
282 |
print(self.language)
|
283 |
raise Exception('语言参数错误')
|
284 |
|
285 |
+
# 若tuple_index==[3],则返回综合情感极性
|
286 |
+
if tuple_index == [3]:
|
287 |
return self.merge_polarity()
|
288 |
|
289 |
for quad_dict in self.data:
|
|
|
296 |
# 提取polarity
|
297 |
polarity = quad_dict['polarity']
|
298 |
|
|
|
299 |
# 拼接,‘|’分割
|
300 |
if format_code == 0:
|
301 |
# 根据tuple_len拼接
|
|
|
305 |
|
306 |
new_text_list.append(new_text)
|
307 |
|
308 |
+
# 如果tuple_index为[2,3],则需要去除new_text_list中重复的元素,不要改变顺序。因为可能有重复的方面
|
309 |
+
if tuple_index == [2, 3]:
|
310 |
res = []
|
311 |
for t in new_text_list:
|
312 |
if t not in res:
|
313 |
res.append(t)
|
314 |
+
new_text_list = res
|
315 |
|
316 |
+
# 如果tuple_index为[3],则只保留new_text_list的第一个元素。因为只有一个情感极性
|
317 |
+
elif tuple_index == [3]:
|
318 |
+
new_text_list = new_text_list[:1]
|
319 |
|
320 |
if format_code == 0:
|
321 |
# 根据tuple_len拼接
|
322 |
return sep_token1.join(new_text_list)
|
323 |
|
324 |
# 与另一个CommentUnits对象对比,检测有几个相同的四元组
|
325 |
+
def compare_same(self, other) -> int:
|
326 |
count = 0
|
327 |
for quad_dict in self.data:
|
328 |
if quad_dict in other.data:
|
|
|
378 |
polarity_list.append(quad_dict['polarity'])
|
379 |
return polarity_list
|
380 |
|
381 |
+
# 对所有polarity进行综合
|
382 |
def merge_polarity(self):
|
383 |
polarity_list = self.get_polarity_list()
|
384 |
+
# 判断是英文还是中文
|
385 |
if self.language == 'en':
|
386 |
if 'pos' in polarity_list and 'neg' in polarity_list:
|
387 |
return 'neu'
|
|
|
401 |
else:
|
402 |
return '中性'
|
403 |
|
404 |
+
# 检测是否有不合法opinion
|
405 |
def check_opinion_in_comment(self, comment_text):
|
406 |
for quad_dict in self.data:
|
407 |
+
if quad_dict['opinion_text'] != '*' and (not quad_dict['opinion_text'] in comment_text):
|
408 |
return False
|
409 |
return True
|
410 |
|
411 |
+
# 检测是否有不合法target
|
412 |
+
def check_target_in_comment(self, comment_text):
|
413 |
for quad_dict in self.data:
|
414 |
+
if quad_dict['target_text'] != '*' and (not quad_dict['target_text'] in comment_text):
|
415 |
return False
|
416 |
return True
|
417 |
|
418 |
+
# 计算两个四元组的相似度
|
419 |
@staticmethod
|
420 |
def get_similarity(units1, units2: 'CommentUnitsSim'):
|
421 |
pass
|
422 |
|
423 |
+
# 对自身数据进行操作
|
424 |
+
def apply(self, func: Callable, field: str):
|
425 |
for quad_dict in self.data:
|
426 |
quad_dict[field] = func(quad_dict[field])
|
427 |
return self
|
428 |
|
429 |
|
430 |
+
# 四元组匹配函数
|
431 |
class CommentUnitsMatch:
|
432 |
+
def __init__(self, target_weight=0.5, opinion_weight=0.5, aspect_weight=0.5, polarity_weight=0.5, one_match=True):
|
433 |
+
# 归一化权重
|
434 |
+
weight_sum = target_weight + opinion_weight + aspect_weight + polarity_weight
|
435 |
+
self.target_weight = target_weight / weight_sum
|
436 |
+
self.opinion_weight = opinion_weight / weight_sum
|
437 |
+
self.aspect_weight = aspect_weight / weight_sum
|
438 |
+
self.polarity_weight = polarity_weight / weight_sum
|
439 |
+
|
440 |
+
# 是否一对一匹配
|
441 |
+
self.one_match = one_match
|
442 |
+
|
443 |
+
# 特定feature置零
|
444 |
+
def set_zero(self, feature: str = 'polarity'):
|
445 |
if feature == 'polarity':
|
446 |
self.polarity_weight = 0
|
447 |
elif feature == 'aspect':
|
|
|
454 |
raise Exception('feature参数错误')
|
455 |
|
456 |
def re_normalize(self):
|
457 |
+
weight_sum = self.target_weight + self.opinion_weight + self.aspect_weight + self.polarity_weight
|
458 |
+
self.target_weight = self.target_weight / weight_sum
|
459 |
+
self.opinion_weight = self.opinion_weight / weight_sum
|
460 |
+
self.aspect_weight = self.aspect_weight / weight_sum
|
461 |
+
self.polarity_weight = self.polarity_weight / weight_sum
|
462 |
+
|
463 |
+
# 计算cost矩阵,完全匹配为0,不匹配为1
|
464 |
+
def get_cost_matrix(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'polarity'):
|
|
|
465 |
pass
|
466 |
+
# 检查此feature是否存在,不存在则返回全0矩阵
|
467 |
+
if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None \
|
468 |
+
or units1.data[0].get(feature) == 'None' or units2.data[0].get(feature) == 'None':
|
469 |
+
cost_matrix = np.zeros((len(units1.data), len(units2.data)))
|
470 |
+
# 对应feature的weight也为0
|
471 |
self.set_zero(feature)
|
472 |
|
473 |
# 并再次归一化
|
|
|
475 |
|
476 |
return cost_matrix
|
477 |
|
478 |
+
# 检查两个四元组的极性是否相同,生成cost矩阵,用于匈牙利算法。不相同则cost为1,相同则cost为0
|
479 |
cost_matrix = []
|
480 |
for quad_dict1 in units1.data:
|
481 |
cost_list = []
|
|
|
486 |
cost_list.append(1)
|
487 |
cost_matrix.append(cost_list)
|
488 |
|
489 |
+
# cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
|
490 |
cost_matrix = np.array(cost_matrix)
|
491 |
return cost_matrix
|
492 |
|
493 |
+
# 计算cost矩阵,使用rougel指标
|
494 |
+
def get_cost_matrix_rouge(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim', feature: str = 'target_text'):
|
495 |
+
# 检查此feature是否存在,不存在则返回全0矩阵
|
496 |
+
if units1.data[0].get(feature) is None or units2.data[0].get(feature) is None \
|
497 |
+
or units1.data[0].get(feature) == 'None' or units2.data[0].get(feature) == 'None':
|
498 |
+
cost_matrix = np.zeros((len(units1.data), len(units2.data)))
|
499 |
+
# 对应feature的weight也为0
|
500 |
self.set_zero(feature)
|
501 |
# 并再次归一化
|
502 |
self.re_normalize()
|
503 |
return cost_matrix
|
504 |
|
505 |
+
# 检查两个四元组的极性是否相同,生成cost矩阵,用于匈牙利算法。相同则cost为0,不相同则cost为1-rougel
|
506 |
cost_matrix = []
|
507 |
for quad_dict1 in units1.data:
|
508 |
cost_list = []
|
|
|
510 |
if quad_dict1[feature] == quad_dict2[feature]:
|
511 |
cost_list.append(0)
|
512 |
else:
|
513 |
+
cost_list.append(1 - get_rougel_f1([quad_dict1[feature]], [quad_dict2[feature]]))
|
514 |
cost_matrix.append(cost_list)
|
515 |
|
516 |
+
# cost矩阵转换为numpy数组,大小为(len(units1.data),len(units2.data))
|
517 |
cost_matrix = np.array(cost_matrix)
|
518 |
return cost_matrix
|
519 |
|
520 |
+
# 匹配四元组并计算cost
|
521 |
+
def match_units(self, units1: 'CommentUnitsSim', units2: 'CommentUnitsSim') -> tuple:
    """Match the quadruples of ``units1`` (predictions) against ``units2`` (references).

    The four per-feature cost matrices are combined with the matcher's
    weights into one cost matrix (rows: ``units1``, columns: ``units2``).
    Pairs are then selected either one-to-one with the Hungarian algorithm
    (``self.one_match`` is true) or by letting every quadruple of the larger
    side pick its cheapest partner.

    Args:
        units1: Predicted units; ``num`` is read as the number of quadruples.
        units2: Reference units; ``num`` is read as the number of quadruples.

    Returns:
        ``(cost_per_quadruple, TP, FP, FN)`` where ``cost_per_quadruple`` is
        the matching cost divided by ``max(units1.num, units2.num)``, ``TP``
        is the summed score (``1 - cost``) of the selected pairs, and
        ``FP`` / ``FN`` are ``units1.num - TP`` / ``units2.num - TP`` clamped
        at zero.
    """
    # The call order is kept as-is: the cost helpers may zero and re-normalise
    # the weights, so the weights are only read after all four calls.
    cost_matrix_polarity = self.get_cost_matrix(units1, units2, feature='polarity')
    cost_matrix_aspect = self.get_cost_matrix(units1, units2, feature='aspect')
    cost_matrix_target = self.get_cost_matrix_rouge(units1, units2, feature='target_text')
    cost_matrix_opinion = self.get_cost_matrix_rouge(units1, units2, feature='opinion_text')

    # Weighted total cost; rows are predictions, columns are references.
    cost_matrix = self.target_weight * cost_matrix_target + self.opinion_weight * cost_matrix_opinion + \
        self.aspect_weight * cost_matrix_aspect + self.polarity_weight * cost_matrix_polarity
    score_matrix = 1 - cost_matrix

    if self.one_match:
        # One-to-one assignment: min(units1.num, units2.num) pairs are selected.
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
    elif units1.num > units2.num:
        # One-to-many: every prediction picks its cheapest reference.
        row_ind = np.arange(units1.num)
        col_ind = np.argmin(cost_matrix, axis=1)
    else:
        # One-to-many: every reference picks its cheapest prediction.
        row_ind = np.argmin(cost_matrix, axis=0)
        col_ind = np.arange(units2.num)

    # Accumulate cost and score over the selected pairs in a single pass.
    cost = 0
    TP = 0
    for row, col in zip(row_ind, col_ind):
        cost += cost_matrix[row][col]
        TP += score_matrix[row][col]

    # In one-to-many mode the same quadruple can be selected several times, so
    # TP may exceed the size of the smaller side; clamp so that FP/FN never go
    # negative (which would push precision/recall above 1 downstream).
    FP = max(units1.num - TP, 0)
    FN = max(units2.num - TP, 0)

    # With one-to-one matching the surplus quadruples stay unmatched and each
    # contributes a cost of 1.
    max_units_num = max(units1.num, units2.num)
    if self.one_match:
        cost += (max_units_num - len(row_ind))

    # Normalise to a per-quadruple cost; two empty sides cost nothing.
    cost_per_quadruple = cost / max_units_num if max_units_num else 0.0
    if cost_per_quadruple > 1 or cost_per_quadruple < 0:
        # Diagnostic output for an out-of-range cost (e.g. weights not normalised).
        print('cost错误', cost_per_quadruple, 'pred:', units1.data, 'true:', units2.data)
        print(self.target_weight, self.opinion_weight, self.aspect_weight, self.polarity_weight)

    return cost_per_quadruple, TP, FP, FN
|
|
|
|
|
|
|
578 |
|
579 |
|
580 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
|
|
617 |
pass
|
618 |
|
619 |
def _compute(self,
|
620 |
+
predictions: List[str],
|
621 |
+
references: Union[List[str], List[List[str]]],
|
622 |
+
quad_weights: tuple = (1, 1, 1, 1),
|
623 |
**kwargs) -> dict:
|
624 |
'''
|
625 |
|
|
|
658 |
'13':'二元组(观点 | 极性)',
|
659 |
'3':'单元素(极性)'}
|
660 |
'''
|
661 |
+
f1_of_optimal_match, score_of_optimal_match = self.quad_f1_of_optimal_match(predictions, references,
|
662 |
+
quad_weights, **kwargs)
|
663 |
+
f1 = self.quad_f1_of_exact_match(y_pred=predictions, y_true=references, **kwargs)
|
664 |
|
665 |
+
# 取1-cost为得分
|
666 |
+
return {'score of optimal match of weight ' + str(quad_weights): score_of_optimal_match,
|
667 |
+
'f1 of optimal match of weight ' + str(quad_weights): f1_of_optimal_match,
|
668 |
+
'f1 of exact match': f1}
|
669 |
|
670 |
+
@staticmethod
|
671 |
+
def quad_f1_of_exact_match(y_pred: List[str], y_true: Union[List[str], List[List[str]]],
                           return_dict=False, **kwargs) -> Union[Dict[str, float], float]:
    """Compute the exact-match F1 over a corpus of predictions.

    Args:
        y_pred: One prediction string per sample.
        y_true: One reference per sample; either a single string or a
            collection of alternative reference strings. With alternatives,
            the one with the highest ``compare_same`` count is used.
        return_dict: When true, return precision, recall and F1 in a dict
            instead of the F1 value alone.
        **kwargs: Forwarded to ``CommentUnitsSim.from_str``.

    Returns:
        The F1 value rounded to 4 decimals, or a dict with the keys
        ``"precision"``, ``"recall"`` and ``"f1"`` (all rounded to 4 decimals).
        A ratio whose denominator is zero is reported as 0.0.

    Raises:
        AssertionError: If ``y_pred`` and ``y_true`` differ in length.
    """
    assert len(y_pred) == len(y_true), "文本数量不一致"
    correct, pred_num, true_num = 0, 0, 0

    for pred, true in zip(y_pred, y_true):
        pred = CommentUnitsSim.from_str(pred, **kwargs)
        if isinstance(true, str):
            # Single reference.
            true = CommentUnitsSim.from_str(true, **kwargs)
            correct += pred.compare_same(true)
            true_num += true.num
        else:
            # Several acceptable references: keep the best-scoring one and
            # count the quadruples of that same reference.
            candidates = [CommentUnitsSim.from_str(t, **kwargs) for t in true]
            correct_list = [pred.compare_same(t) for t in candidates]
            best_index = correct_list.index(max(correct_list))
            correct += correct_list[best_index]
            true_num += candidates[best_index].num
        pred_num += pred.num

    # Results are rounded to 4 decimals. Zero denominators are guarded
    # explicitly, so no epsilon is added to the reported precision/recall.
    precision = round(correct / pred_num, 4) if pred_num else 0.0
    recall = round(correct / true_num, 4) if true_num else 0.0
    denominator = precision + recall
    f1 = round(2 * precision * recall / denominator, 4) if denominator else 0.0

    if return_dict:
        return {"precision": precision, "recall": recall, "f1": f1}
    return f1
|
705 |
+
|
706 |
+
# 计算最优匹配f1
|
707 |
+
@staticmethod
|
708 |
+
def quad_f1_of_optimal_match(
        predictions: List[str],
        references: Union[List[str], List[List[str]]],
        quad_weights: tuple = (1, 1, 1, 1),
        one_match=True,
        **kwargs):
    """Compute the optimal-match F1 and the average optimal-match score.

    Args:
        predictions: One prediction string per sample. A bare string is
            treated as a single sample.
        references: One reference per sample; either a single string or a
            list of alternative reference strings. With alternatives, the
            one with the lowest matching cost is used.
        quad_weights: Weights passed positionally to ``CommentUnitsMatch``.
        one_match: Passed to ``CommentUnitsMatch``; selects one-to-one matching.
        **kwargs: Forwarded to ``CommentUnitsSim.from_str``.

    Returns:
        ``(f1_match, score)`` where ``f1_match`` is derived from the summed
        TP/FP/FN of the chosen matches and ``score`` is ``1 - mean cost``.
        Empty input yields ``(0.0, 0.0)``; a zero denominator in precision,
        recall or F1 yields 0.0 for that ratio.

    Raises:
        AssertionError: If predictions and references differ in length.
    """
    # Normalise a single sample first, so the length check compares sample
    # counts rather than string lengths.
    if isinstance(predictions, str):
        predictions = [predictions]
        references = [references]
    assert len(predictions) == len(references)

    # Nothing to evaluate: avoid dividing by zero below.
    if len(predictions) == 0:
        return 0.0, 0.0

    cost = 0
    TP, FP, FN = 0, 0, 0
    matcher = CommentUnitsMatch(*quad_weights, one_match=one_match)

    for pred, refer in zip(predictions, references):
        pred = CommentUnitsSim.from_str(pred, **kwargs)
        # Treat a single reference as a one-element list of alternatives.
        if isinstance(refer, str):
            refer = [refer]
        refer = [CommentUnitsSim.from_str(t, **kwargs) for t in refer]

        # Each entry is (cost, TP, FP, FN). The best alternative is the one
        # with the LOWEST cost, i.e. the highest score; ties keep the first.
        cost_list = [matcher.match_units(pred, t) for t in refer]
        cost_, TP_, FP_, FN_ = min(cost_list, key=lambda result: result[0])
        cost += cost_
        TP += TP_
        FP += FP_
        FN += FN_

    # Mean cost per sample.
    cost = cost / len(predictions)

    # Optimal-match F1 from the accumulated counts, guarding zero denominators
    # (otherwise numpy floats silently turn the result into NaN).
    precision_match = TP / (TP + FP) if (TP + FP) > 0 else 0.0
    recall_match = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    pr_sum = precision_match + recall_match
    f1_match = 2 * precision_match * recall_match / pr_sum if pr_sum > 0 else 0.0

    return f1_match, 1 - cost
|