Spaces:
Running
Running
import string | |
import numpy as np | |
from rapidfuzz.distance import Levenshtein | |
def match_ss(ss1, ss2):
    """Return what remains of ``ss2`` after removing its overlap with ``ss1``.

    The overlap is the longest non-empty prefix of ``ss2`` that equals a
    suffix of ``ss1``; longer overlaps are tried first. When no such overlap
    exists, ``ss2`` is returned unchanged.
    """
    total = len(ss1)
    # Walk overlap lengths from longest (all of ss1) down to a single char.
    for overlap in range(total, 0, -1):
        if ss1[total - overlap:] == ss2[:overlap]:
            return ss2[overlap:]
    return ss2
def stream_match(text):
    """Merge the recognized chunks of one stream into a single string.

    Args:
        text: sized sequence of ``(string, confidence)`` pairs, in stream order.

    The first chunk drops its last character, middle chunks drop their first
    and last characters, and the final chunk drops its first character. Each
    trimmed chunk is passed to ``match_ss`` together with the remainder that
    was appended on the previous step, and only the non-overlapping tail is
    appended.

    Returns:
        ``(merged_string, mean_confidence)``.

    NOTE(review): with a single chunk its last character is still dropped, and
    overlap is checked only against the previously appended remainder (not the
    whole previous chunk) — confirm both are intended.
    """
    count = len(text)
    strings = [item[0] for item in text]
    confs = [item[1] for item in text]
    tail = strings[0][:-1]
    merged = tail
    for idx in range(1, count):
        if idx == count - 1:
            chunk = strings[idx][1:]
        else:
            chunk = strings[idx][1:-1]
        tail = match_ss(tail, chunk)
        merged += tail
    return merged, sum(confs) / count
class RecMetric(object):
    """Accumulate accuracy and normalized edit-distance metrics for text recognition.

    Each call scores one batch and returns that batch's metrics. Running
    totals are kept on the instance, and ``get_metric`` reports them and then
    clears them.

    Args:
        main_indicator: name of the headline metric; stored on the instance
            (not used inside this class).
        is_filter: if True, keep only ASCII letters and digits before comparing.
        is_lower: if True, lower-case prediction and target before comparing.
        ignore_space: if True, remove ``' '`` characters before comparing.
        stream: if True, a batch holds the chunks of a single text line. The
            chunks are merged with ``stream_match`` before scoring.
        with_ratio: if True (and not training), additionally bucket results
            by target length and by the per-sample value in ``batch[-1]``.
        max_len: number of length buckets; longer targets share the last one.
        max_ratio: number of ratio buckets; larger ratios share the last one.
    """

    # Characters kept by ``_normalize_text``. Built once instead of per char.
    _ALNUM = frozenset(string.digits + string.ascii_letters)

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 is_lower=True,
                 ignore_space=True,
                 stream=False,
                 with_ratio=False,
                 max_len=25,
                 max_ratio=4,
                 **kwargs):
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.is_lower = is_lower
        self.ignore_space = ignore_space
        self.stream = stream
        # Added to every denominator so empty accumulators do not divide by 0.
        self.eps = 1e-5
        self.with_ratio = with_ratio
        self.max_len = max_len
        self.max_ratio = max_ratio
        self.reset()

    def _normalize_text(self, text):
        """Return ``text`` with every non-ASCII-alphanumeric character removed."""
        return ''.join(ch for ch in text if ch in self._ALNUM)

    def __call__(self,
                 pred_label,
                 batch=None,
                 training=False,
                 *args,
                 **kwargs):
        """Score one batch.

        Uses the bucketed ``eval_all_metric`` when ``with_ratio`` is set and
        not training, otherwise ``eval_metric``.
        """
        if self.with_ratio and not training:
            return self.eval_all_metric(pred_label, batch)
        else:
            return self.eval_metric(pred_label)

    def eval_metric(self, pred_label, *args, **kwargs):
        """Score one batch and add the results to the running totals.

        Args:
            pred_label: ``(preds, labels)``. Each pred is ``(text, confidence)``,
                and each label is ``(text, _)``.

        Returns:
            dict with this batch's ``acc`` (exact-match rate) and
            ``norm_edit_dis`` (1 minus the mean normalized Levenshtein distance).
        """
        preds, labels = pred_label
        correct_num = 0
        all_num = 0
        norm_edit_dis = 0.0
        for (pred, pred_conf), (target, _) in zip(preds, labels):
            if self.stream:
                # All of ``preds`` are chunks of the single label's text line.
                assert len(labels) == 1
                pred, _ = stream_match(preds)
            if self.ignore_space:
                pred = pred.replace(' ', '')
                target = target.replace(' ', '')
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
            if self.is_lower:
                pred = pred.lower()
                target = target.lower()
            norm_edit_dis += Levenshtein.normalized_distance(pred, target)
            if pred == target:
                correct_num += 1
            all_num += 1
        self.correct_num += correct_num
        self.all_num += all_num
        self.norm_edit_dis += norm_edit_dis
        return {
            'acc': correct_num / (all_num + self.eps),
            'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
        }

    def eval_all_metric(self, pred_label, batch=None, *args, **kwargs):
        """Score one batch, tracking accuracy variants and per-bucket statistics.

        Besides the final ``acc``, counts exact matches at each normalization
        stage (raw, lower-cased, space-stripped, space-stripped + lower-cased,
        filtered). It also buckets each sample by target length and, when
        ``with_ratio`` is set, by ``batch[-1][i]``.

        Args:
            pred_label: ``(preds, labels)``, same layout as in ``eval_metric``.
            batch: when ``with_ratio`` is set, ``batch[-1]`` must be indexable
                by sample position and hold one value per sample (presumably an
                integer width/height ratio — confirm against the data loader).

        Returns:
            dict with this batch's ``acc`` and ``norm_edit_dis``.
        """
        # Without ``with_ratio`` there is no per-sample ratio; the original
        # code hit a NameError here, so ratio bucketing is skipped instead.
        ratio = batch[-1] if self.with_ratio else None
        preds, labels = pred_label
        correct_num = 0
        correct_num_real = 0
        correct_num_lower = 0
        correct_num_ignore_space = 0
        correct_num_ignore_space_lower = 0
        correct_num_ignore_space_symbol = 0
        all_num = 0
        norm_edit_dis = 0.0
        each_len_num = [0] * self.max_len
        each_len_correct_num = [0] * self.max_len
        each_len_norm_edit_dis = [0] * self.max_len
        each_ratio_num = [0] * self.max_ratio
        each_ratio_correct_num = [0] * self.max_ratio
        each_ratio_norm_edit_dis = [0] * self.max_ratio
        for (pred, pred_conf), (target, _) in zip(preds, labels):
            if self.stream:
                assert len(labels) == 1
                pred, _ = stream_match(preds)
            if pred == target:
                correct_num_real += 1
            if pred.lower() == target.lower():
                correct_num_lower += 1
            if self.ignore_space:
                pred = pred.replace(' ', '')
                target = target.replace(' ', '')
                if pred == target:
                    correct_num_ignore_space += 1
                if pred.lower() == target.lower():
                    correct_num_ignore_space_lower += 1
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
                if pred == target:
                    correct_num_ignore_space_symbol += 1
            if self.is_lower:
                pred = pred.lower()
                target = target.lower()
            dis = Levenshtein.normalized_distance(pred, target)
            norm_edit_dis += dis
            # Length bucket of the normalized target, clamped to [0, max_len-1].
            len_i = max(0, min(self.max_len, len(target)) - 1)
            is_correct = pred == target
            if is_correct:
                correct_num += 1
                each_len_correct_num[len_i] += 1
            each_len_num[len_i] += 1
            each_len_norm_edit_dis[len_i] += dis
            if ratio is not None:
                # Clamp into [0, max_ratio-1] like len_i, so a ratio <= 0 cannot
                # become a negative index that wraps to the last bucket.
                ratio_i = min(max(ratio[all_num], 1), self.max_ratio) - 1
                if is_correct:
                    each_ratio_correct_num[ratio_i] += 1
                each_ratio_num[ratio_i] += 1
                each_ratio_norm_edit_dis[ratio_i] += dis
            all_num += 1
        self.correct_num += correct_num
        self.correct_num_real += correct_num_real
        self.correct_num_lower += correct_num_lower
        self.correct_num_ignore_space += correct_num_ignore_space
        self.correct_num_ignore_space_lower += correct_num_ignore_space_lower
        self.correct_num_ignore_space_symbol += correct_num_ignore_space_symbol
        self.all_num += all_num
        self.norm_edit_dis += norm_edit_dis
        self.each_len_num = self.each_len_num + np.array(each_len_num)
        self.each_len_correct_num = self.each_len_correct_num + np.array(
            each_len_correct_num)
        self.each_len_norm_edit_dis = self.each_len_norm_edit_dis + np.array(
            each_len_norm_edit_dis)
        self.each_ratio_num = self.each_ratio_num + np.array(each_ratio_num)
        self.each_ratio_correct_num = self.each_ratio_correct_num + np.array(
            each_ratio_correct_num)
        self.each_ratio_norm_edit_dis = self.each_ratio_norm_edit_dis + np.array(
            each_ratio_norm_edit_dis)
        return {
            'acc': correct_num / (all_num + self.eps),
            'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
        }

    def get_metric(self, training=False):
        """Return the accumulated metrics and reset the accumulators.

        When ``with_ratio`` is set and not training, this delegates to
        ``get_all_metric``. Otherwise it returns::

            {'acc': float, 'norm_edit_dis': float, 'num_samples': int}
        """
        if self.with_ratio and not training:
            return self.get_all_metric()
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
        num_samples = self.all_num
        self.reset()
        return {
            'acc': acc,
            'norm_edit_dis': norm_edit_dis,
            'num_samples': num_samples
        }

    def get_all_metric(self):
        """Return all accumulated metrics, including per-bucket lists, then reset.

        Keys: ``acc``, ``acc_real``, ``acc_lower``, ``acc_ignore_space``,
        ``acc_ignore_space_lower``, ``acc_ignore_space_symbol``,
        ``acc_ignore_space_lower_symbol`` (same value as ``acc``), and the
        list-valued ``each_len_num``, ``each_len_acc``,
        ``each_len_norm_edit_dis``, ``each_ratio_num``, ``each_ratio_acc`` and
        ``each_ratio_norm_edit_dis``. Also ``norm_edit_dis`` and
        ``num_samples``.
        """
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        acc_real = 1.0 * self.correct_num_real / (self.all_num + self.eps)
        acc_lower = 1.0 * self.correct_num_lower / (self.all_num + self.eps)
        acc_ignore_space = 1.0 * self.correct_num_ignore_space / (
            self.all_num + self.eps)
        acc_ignore_space_lower = 1.0 * self.correct_num_ignore_space_lower / (
            self.all_num + self.eps)
        acc_ignore_space_symbol = 1.0 * self.correct_num_ignore_space_symbol / (
            self.all_num + self.eps)
        norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
        num_samples = self.all_num
        each_len_acc = (self.each_len_correct_num /
                        (self.each_len_num + self.eps)).tolist()
        each_len_norm_edit_dis = (
            1 - self.each_len_norm_edit_dis /
            (self.each_len_num + self.eps)).tolist()
        each_len_num = self.each_len_num.tolist()
        each_ratio_acc = (self.each_ratio_correct_num /
                          (self.each_ratio_num + self.eps)).tolist()
        each_ratio_norm_edit_dis = (
            1 - self.each_ratio_norm_edit_dis /
            (self.each_ratio_num + self.eps)).tolist()
        each_ratio_num = self.each_ratio_num.tolist()
        self.reset()
        return {
            'acc': acc,
            'acc_real': acc_real,
            'acc_lower': acc_lower,
            'acc_ignore_space': acc_ignore_space,
            'acc_ignore_space_lower': acc_ignore_space_lower,
            'acc_ignore_space_symbol': acc_ignore_space_symbol,
            'acc_ignore_space_lower_symbol': acc,
            'each_len_num': each_len_num,
            'each_len_acc': each_len_acc,
            'each_len_norm_edit_dis': each_len_norm_edit_dis,
            'each_ratio_num': each_ratio_num,
            'each_ratio_acc': each_ratio_acc,
            'each_ratio_norm_edit_dis': each_ratio_norm_edit_dis,
            'norm_edit_dis': norm_edit_dis,
            'num_samples': num_samples
        }

    def reset(self):
        """Zero every running counter and per-bucket accumulator."""
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0
        self.correct_num_real = 0
        self.correct_num_lower = 0
        self.correct_num_ignore_space = 0
        self.correct_num_ignore_space_lower = 0
        self.correct_num_ignore_space_symbol = 0
        self.each_len_num = np.zeros(self.max_len, dtype=int)
        self.each_len_correct_num = np.zeros(self.max_len, dtype=int)
        self.each_len_norm_edit_dis = np.zeros(self.max_len, dtype=float)
        self.each_ratio_num = np.zeros(self.max_ratio, dtype=int)
        self.each_ratio_correct_num = np.zeros(self.max_ratio, dtype=int)
        self.each_ratio_norm_edit_dis = np.zeros(self.max_ratio, dtype=float)