Spaces:
Running
Running
import string | |
import numpy as np | |
from rapidfuzz.distance import Levenshtein | |
from .rec_metric import stream_match | |
# f_pred = open('pred_focal_subs_rand1_h2_bi_first.txt', 'w') | |
class RecMetricLong:
    """Accumulate text-recognition metrics, overall and per target length.

    Calling the instance with a ``(preds, labels)`` batch updates the running
    totals and returns that batch's accuracy and normalized-edit-distance
    score. ``get_metric`` reports the accumulated totals and then resets
    them.

    Per-length statistics live in ``max_len`` buckets indexed by the length
    of the processed target text; targets with ``max_len - 1`` characters or
    more share the last bucket.
    """

    # Characters kept by ``_normalize_text``. Built once here instead of
    # concatenating the two strings again for every character.
    _ALNUM = frozenset(string.digits + string.ascii_letters)

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 ignore_space=True,
                 stream=False,
                 **kwargs):
        """Store the options and start from empty accumulators.

        Args:
            main_indicator: Stored on the instance; this class does not read
                it itself.
            is_filter: If true, predictions and targets are reduced to
                lower-cased ASCII letters and digits before comparison.
            ignore_space: If true, spaces are removed from predictions and
                targets before comparison.
            stream: If true, the prediction is taken from
                ``stream_match(preds)`` and a batch must hold exactly one
                label.
            **kwargs: Accepted and ignored.
        """
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        self.stream = stream
        self.eps = 1e-5  # keeps the divisions defined when nothing was seen
        self.max_len = 201  # number of per-length buckets (lengths 0..200)
        self.reset()

    def _normalize_text(self, text):
        """Return ``text`` lower-cased with non-ASCII-alphanumerics removed."""
        return ''.join(ch for ch in text if ch in self._ALNUM).lower()

    def _len_bucket(self, length):
        """Map a text length to its per-length bucket index.

        Lengths past the end of the table are folded into the last bucket so
        an unusually long target cannot raise ``IndexError``.
        """
        return min(length, self.max_len - 1)

    def __call__(self, pred_label, *args, **kwargs):
        """Score one batch and add it to the running totals.

        Args:
            pred_label: Pair ``(preds, labels)``. Every entry of both
                sequences is a 2-item sequence whose first item is the text;
                the second item is not used here.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            dict with the batch-level ``'acc'`` (exact-match ratio) and
            ``'norm_edit_dis'`` (one minus the mean normalized Levenshtein
            distance).
        """
        preds, labels = pred_label
        correct_num = 0
        correct_num_slice = 0
        f_l_acc = 0
        all_num = 0
        norm_edit_dis = 0.0
        len_acc = 0
        each_len_num = [0] * self.max_len
        each_len_correct_num = [0] * self.max_len
        each_len_norm_edit_dis = [0] * self.max_len

        for (pred, _), (target, _) in zip(preds, labels):
            if self.stream:
                assert len(labels) == 1
                pred, _ = stream_match(preds)
            if self.ignore_space:
                pred = pred.replace(' ', '')
                target = target.replace(' ', '')
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)

            dis = Levenshtein.normalized_distance(pred, target)
            norm_edit_dis += dis

            # Per-length counters are updated for every sample; only the
            # "correct" counter depends on an exact match.
            bucket = self._len_bucket(len(target))
            each_len_num[bucket] += 1
            each_len_norm_edit_dis[bucket] += dis
            if pred == target:
                correct_num += 1
                each_len_correct_num[bucket] += 1

            # First and last characters agree (both strings non-empty).
            if pred and target:
                if pred[0] == target[0] and pred[-1] == target[-1]:
                    f_l_acc += 1
            # Predicted length equals target length.
            if len(pred) == len(target):
                len_acc += 1
            # Prediction equals the target prefix of the same length.
            if pred == target[:len(pred)]:
                correct_num_slice += 1
            all_num += 1

        self.correct_num += correct_num
        self.correct_num_slice += correct_num_slice
        self.f_l_acc += f_l_acc
        self.all_num += all_num
        self.len_acc += len_acc
        self.each_len_num = self.each_len_num + np.array(each_len_num)
        self.each_len_correct_num = (self.each_len_correct_num +
                                     np.array(each_len_correct_num))
        self.each_len_norm_edit_dis = (self.each_len_norm_edit_dis +
                                       np.array(each_len_norm_edit_dis))
        self.norm_edit_dis += norm_edit_dis
        return {
            'acc': correct_num / (all_num + self.eps),
            'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
        }

    def get_metric(self):
        """Return the accumulated metrics and reset the accumulators.

        Returns:
            dict with the keys
            ``'acc'``: exact-match ratio,
            ``'norm_edit_dis'``: one minus the mean normalized edit distance,
            ``'acc_slice'``: ratio of predictions equal to the target prefix
            of the same length,
            ``'f_l_acc'``: ratio of samples whose first and last characters
            match,
            ``'len_acc'``: ratio of samples with matching length,
            ``'each_len_num'``: sample count per length bucket,
            ``'each_len_acc'``: exact-match ratio per length bucket,
            ``'each_len_norm_edit_dis'``: one minus the mean normalized edit
            distance per length bucket,
            ``'all_num'``: total number of samples.
            The three ``each_len_*`` values are lists of ``max_len`` entries.
        """
        total = self.all_num + self.eps
        per_len_total = self.each_len_num + self.eps

        result = {
            'acc': self.correct_num / total,
            'norm_edit_dis': 1 - self.norm_edit_dis / total,
            'acc_slice': self.correct_num_slice / total,
            'f_l_acc': self.f_l_acc / total,
            'len_acc': self.len_acc / total,
            'each_len_num': self.each_len_num.tolist(),
            'each_len_acc':
            (self.each_len_correct_num / per_len_total).tolist(),
            'each_len_norm_edit_dis':
            (1 - self.each_len_norm_edit_dis / per_len_total).tolist(),
            'all_num': self.all_num,
        }
        self.reset()
        return result

    def reset(self):
        """Zero every running total and per-length table."""
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0
        self.correct_num_slice = 0
        self.each_len_num = np.zeros(self.max_len, dtype=int)
        self.each_len_correct_num = np.zeros(self.max_len, dtype=int)
        self.each_len_norm_edit_dis = np.zeros(self.max_len, dtype=float)
        self.f_l_acc = 0
        self.len_acc = 0