Spaces:
Running
Running
File size: 5,476 Bytes
29f689c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import string
import numpy as np
from rapidfuzz.distance import Levenshtein
from .rec_metric import stream_match
# f_pred = open('pred_focal_subs_rand1_h2_bi_first.txt', 'w')
class RecMetricLong(object):
    """Accumulate text-recognition metrics, including per-label-length stats.

    Call the instance once per batch with ``(preds, labels)``; it returns the
    batch-level accuracy and edit-distance score and adds the batch to the
    running totals.  ``get_metric`` reports the accumulated totals and then
    resets them.
    """

    # Characters kept by ``_normalize_text``.  Built once here instead of
    # re-concatenating the two strings for every character that is filtered.
    _ALNUM = frozenset(string.digits + string.ascii_letters)

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 ignore_space=True,
                 stream=False,
                 **kwargs):
        """Configure the metric.

        Args:
            main_indicator: Name of the headline metric; stored as-is.
            is_filter: If true, predictions and targets are reduced to
                lower-cased ASCII letters and digits before comparison.
            ignore_space: If true, spaces are removed from predictions and
                targets before comparison.
            stream: If true, the prediction is taken from
                ``stream_match(preds)`` and exactly one label is expected.
            **kwargs: Accepted for config compatibility and ignored.
        """
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        self.stream = stream
        # Added to every denominator so empty accumulators do not divide by 0.
        self.eps = 1e-5
        # Number of per-length buckets (target lengths 0 .. max_len - 1).
        self.max_len = 201
        self.reset()

    def _normalize_text(self, text):
        """Return ``text`` lower-cased with non-ASCII-alphanumerics removed."""
        allowed = self._ALNUM
        return ''.join(ch for ch in text if ch in allowed).lower()

    def _length_bucket(self, text):
        """Return the per-length bucket index for ``text``.

        Lengths of ``max_len`` or more are clamped into the last bucket so
        that an unusually long label cannot raise ``IndexError``.
        """
        return min(len(text), self.max_len - 1)

    def __call__(self, pred_label, *args, **kwargs):
        """Score one batch and fold it into the running totals.

        Args:
            pred_label: A ``(preds, labels)`` pair.  Each element of both
                sequences is unpacked as a 2-tuple whose first item is the
                text; the second item is not used here.

        Returns:
            dict with the batch-level ``'acc'`` and ``'norm_edit_dis'``
            (reported as ``1 - mean normalized Levenshtein distance``).
        """
        preds, labels = pred_label
        correct_num = 0
        correct_num_slice = 0
        f_l_acc = 0
        all_num = 0
        norm_edit_dis = 0.0
        len_acc = 0
        each_len_num = np.zeros(self.max_len, dtype=int)
        each_len_correct_num = np.zeros(self.max_len, dtype=int)
        each_len_norm_edit_dis = np.zeros(self.max_len)
        for (pred, _), (target, _) in zip(preds, labels):
            if self.stream:
                assert len(labels) == 1
                # NOTE(review): stream_match is defined in .rec_metric;
                # presumably it merges the streamed predictions into a
                # single string -- confirm against that helper.
                pred, _ = stream_match(preds)
            if self.ignore_space:
                pred = pred.replace(' ', '')
                target = target.replace(' ', '')
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
            dis = Levenshtein.normalized_distance(pred, target)
            norm_edit_dis += dis

            bucket = self._length_bucket(target)
            if pred == target:
                correct_num += 1
                each_len_correct_num[bucket] += 1
            # Counted for every sample (not only exact matches) so that the
            # per-length accuracy and edit distance are meaningful.
            each_len_num[bucket] += 1
            each_len_norm_edit_dis[bucket] += dis

            # First and last characters both agree.
            if len(pred) >= 1 and len(target) >= 1:
                if pred[0] == target[0] and pred[-1] == target[-1]:
                    f_l_acc += 1
            # Predicted length equals the target length.
            if len(pred) == len(target):
                len_acc += 1
            # Prediction is an exact prefix of the target.
            if pred == target[:len(pred)]:
                correct_num_slice += 1
            all_num += 1

        self.correct_num += correct_num
        self.correct_num_slice += correct_num_slice
        self.f_l_acc += f_l_acc
        self.all_num += all_num
        self.len_acc += len_acc
        self.each_len_num = self.each_len_num + each_len_num
        self.each_len_correct_num = (self.each_len_correct_num +
                                     each_len_correct_num)
        self.each_len_norm_edit_dis = (self.each_len_norm_edit_dis +
                                       each_len_norm_edit_dis)
        self.norm_edit_dis += norm_edit_dis
        return {
            'acc': correct_num / (all_num + self.eps),
            'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
        }

    def get_metric(self):
        """Return the accumulated metrics and reset the accumulators.

        Returns:
            dict with scalar entries ``'acc'``, ``'norm_edit_dis'``,
            ``'acc_slice'``, ``'f_l_acc'``, ``'len_acc'`` and ``'all_num'``,
            plus the per-length lists ``'each_len_num'``, ``'each_len_acc'``
            and ``'each_len_norm_edit_dis'`` (each of length ``max_len``).
        """
        total = self.all_num + self.eps
        per_len_total = self.each_len_num + self.eps

        acc = 1.0 * self.correct_num / total
        acc_slice = 1.0 * self.correct_num_slice / total
        f_l_acc = 1.0 * self.f_l_acc / total
        len_acc = 1.0 * self.len_acc / total
        norm_edit_dis = 1 - self.norm_edit_dis / total
        each_len_acc = (self.each_len_correct_num / per_len_total).tolist()
        each_len_norm_edit_dis = (
            1 - self.each_len_norm_edit_dis / per_len_total).tolist()
        each_len_num = self.each_len_num.tolist()
        all_num = self.all_num

        self.reset()
        return {
            'acc': acc,
            'norm_edit_dis': norm_edit_dis,
            'acc_slice': acc_slice,
            'f_l_acc': f_l_acc,
            'len_acc': len_acc,
            'each_len_num': each_len_num,
            'each_len_acc': each_len_acc,
            'each_len_norm_edit_dis': each_len_norm_edit_dis,
            'all_num': all_num
        }

    def reset(self):
        """Zero every running total and per-length table."""
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0
        self.correct_num_slice = 0
        self.each_len_num = np.zeros(self.max_len, dtype=int)
        self.each_len_correct_num = np.zeros(self.max_len, dtype=int)
        self.each_len_norm_edit_dis = np.zeros(self.max_len)
        self.f_l_acc = 0
        self.len_acc = 0
|