import string
import numpy as np
from rapidfuzz.distance import Levenshtein

# Characters kept by RecMetric._normalize_text when ``is_filter`` is enabled.
# Built once at import time instead of re-concatenating for every character.
_ALNUM_CHARS = frozenset(string.digits + string.ascii_letters)


def match_ss(ss1, ss2):
    """Return ``ss2`` minus its longest prefix that is also a suffix of ``ss1``.

    Used to stitch overlapping text chunks: the overlap between the end of
    ``ss1`` and the start of ``ss2`` is removed from ``ss2``. If no non-empty
    overlap exists, ``ss2`` is returned unchanged.
    """
    s1_len = len(ss1)
    # c_i == 0 tests the longest candidate (all of ss1) first, so the first
    # hit is the longest overlap.
    for c_i in range(s1_len):
        if ss1[c_i:] == ss2[:s1_len - c_i]:
            return ss2[s1_len - c_i:]
    return ss2


def stream_match(text):
    """Merge per-chunk predictions of one streamed text line.

    Args:
        text: sequence of ``(string, confidence)`` pairs, one per chunk, in
            order. Must be non-empty (an empty sequence raises IndexError).

    Returns:
        ``(merged_string, mean_confidence)``.

    The last character of every chunk except the final one and the first
    character of every chunk except the first one are discarded before
    merging; overlap between consecutive pieces is removed with
    :func:`match_ss`.
    """
    s_list = [s_conf[0] for s_conf in text]
    conf_list = [s_conf[1] for s_conf in text]
    s_n = len(s_list)
    mean_conf_denominator = s_n
    if s_n == 1:
        # A single chunk is both the first and the last chunk, so neither of
        # its edges is trimmed (the multi-chunk path keeps the last chunk's
        # final character, and the first chunk's first character).
        return s_list[0], sum(conf_list) / mean_conf_denominator
    s_start = s_list[0][:-1]
    s_new = s_start
    for s_i in range(1, s_n):
        chunk = s_list[s_i][1:-1] if s_i < s_n - 1 else s_list[s_i][1:]
        # NOTE(review): the next chunk is matched against the newly appended
        # piece only, not the whole previous chunk; kept as-is.
        s_start = match_ss(s_start, chunk)
        s_new += s_start
    return s_new, sum(conf_list) / mean_conf_denominator


class RecMetric(object):
    """Accumulating accuracy / normalized-edit-distance metric for text recognition.

    Call the instance once per batch. Totals accumulate until
    :meth:`get_metric` (or :meth:`get_all_metric`) returns them and resets.

    Args:
        main_indicator: name of the primary metric (stored; not read in this
            class).
        is_filter: keep only ASCII letters and digits before comparing.
        is_lower: compare case-insensitively.
        ignore_space: remove ``' '`` characters before comparing.
        stream: treat ``preds`` as chunks of a single label and merge them
            with :func:`stream_match` (requires exactly one label).
        with_ratio: when not training, compute the extended breakdown via
            :meth:`eval_all_metric`, reading per-sample ratios from
            ``batch[-1]``.
        max_len: number of label-length buckets; longer labels share the last.
        max_ratio: number of ratio buckets; larger ratios share the last.
    """

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 is_lower=True,
                 ignore_space=True,
                 stream=False,
                 with_ratio=False,
                 max_len=25,
                 max_ratio=4,
                 **kwargs):
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.is_lower = is_lower
        self.ignore_space = ignore_space
        self.stream = stream
        # Added to denominators so empty accumulators do not divide by zero.
        self.eps = 1e-5
        self.with_ratio = with_ratio
        self.max_len = max_len
        self.max_ratio = max_ratio
        self.reset()

    def _normalize_text(self, text):
        """Drop every character that is not an ASCII letter or digit."""
        return ''.join(ch for ch in text if ch in _ALNUM_CHARS)

    def __call__(self,
                 pred_label,
                 batch=None,
                 training=False,
                 *args,
                 **kwargs):
        """Accumulate one batch; dispatch to the extended metric at eval time."""
        if self.with_ratio and not training:
            return self.eval_all_metric(pred_label, batch)
        else:
            return self.eval_metric(pred_label)

    def eval_metric(self, pred_label, *args, **kwargs):
        """Accumulate accuracy and normalized edit distance for one batch.

        Args:
            pred_label: ``(preds, labels)``, each a sequence of
                ``(text, confidence)`` pairs.

        Returns:
            dict with this batch's ``acc`` and ``norm_edit_dis``
            (``1 - mean normalized Levenshtein distance``).
        """
        preds, labels = pred_label
        correct_num = 0
        all_num = 0
        norm_edit_dis = 0.0
        for (pred, _pred_conf), (target, _target_conf) in zip(preds, labels):
            if self.stream:
                assert len(labels) == 1
                pred, _ = stream_match(preds)
            if self.ignore_space:
                pred = pred.replace(' ', '')
                target = target.replace(' ', '')
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
            if self.is_lower:
                pred = pred.lower()
                target = target.lower()
            norm_edit_dis += Levenshtein.normalized_distance(pred, target)
            if pred == target:
                correct_num += 1
            all_num += 1
        self.correct_num += correct_num
        self.all_num += all_num
        self.norm_edit_dis += norm_edit_dis
        return {
            'acc': correct_num / (all_num + self.eps),
            'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
        }

    def eval_all_metric(self, pred_label, batch=None, *args, **kwargs):
        """Accumulate the extended metric breakdown for one batch.

        Besides the main accuracy, counts exact matches at intermediate
        normalization stages (raw, lower-cased, space-stripped,
        space-stripped + lower-cased, symbol-filtered) and buckets samples by
        label length and by ratio.

        Args:
            pred_label: ``(preds, labels)``, each a sequence of
                ``(text, confidence)`` pairs.
            batch: when ``with_ratio`` is set, ``batch[-1]`` must hold one
                integer ratio per sample; ratio ``r`` is counted in bucket
                ``r - 1`` clamped to ``[0, max_ratio - 1]``. If ``with_ratio``
                is off or ``batch`` is None, ratio buckets are not updated.

        Returns:
            dict with this batch's ``acc`` and ``norm_edit_dis``.
        """
        ratio = batch[-1] if (self.with_ratio and batch is not None) else None
        preds, labels = pred_label
        correct_num = 0
        correct_num_real = 0
        correct_num_lower = 0
        correct_num_ignore_space = 0
        correct_num_ignore_space_lower = 0
        correct_num_ignore_space_symbol = 0
        all_num = 0
        norm_edit_dis = 0.0
        each_len_num = [0] * self.max_len
        each_len_correct_num = [0] * self.max_len
        each_len_norm_edit_dis = [0] * self.max_len
        each_ratio_num = [0] * self.max_ratio
        each_ratio_correct_num = [0] * self.max_ratio
        each_ratio_norm_edit_dis = [0] * self.max_ratio
        for (pred, _pred_conf), (target, _target_conf) in zip(preds, labels):
            if self.stream:
                assert len(labels) == 1
                pred, _ = stream_match(preds)
            # Matches before any configured normalization is applied.
            if pred == target:
                correct_num_real += 1
            if pred.lower() == target.lower():
                correct_num_lower += 1
            if self.ignore_space:
                pred = pred.replace(' ', '')
                target = target.replace(' ', '')
                if pred == target:
                    correct_num_ignore_space += 1
                if pred.lower() == target.lower():
                    correct_num_ignore_space_lower += 1
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
                if pred == target:
                    correct_num_ignore_space_symbol += 1
            if self.is_lower:
                pred = pred.lower()
                target = target.lower()
            dis = Levenshtein.normalized_distance(pred, target)
            norm_edit_dis += dis
            is_correct = pred == target

            # Empty labels go to bucket 0; labels longer than max_len share
            # the last bucket.
            len_i = max(0, min(self.max_len, len(target)) - 1)
            each_len_num[len_i] += 1
            each_len_norm_edit_dis[len_i] += dis
            if is_correct:
                correct_num += 1
                each_len_correct_num[len_i] += 1

            if ratio is not None:
                # Clamp so non-positive ratios cannot wrap to negative indices.
                ratio_i = min(max(int(ratio[all_num]) - 1, 0),
                              self.max_ratio - 1)
                each_ratio_num[ratio_i] += 1
                each_ratio_norm_edit_dis[ratio_i] += dis
                if is_correct:
                    each_ratio_correct_num[ratio_i] += 1
            all_num += 1

        self.correct_num += correct_num
        self.correct_num_real += correct_num_real
        self.correct_num_lower += correct_num_lower
        self.correct_num_ignore_space += correct_num_ignore_space
        self.correct_num_ignore_space_lower += correct_num_ignore_space_lower
        self.correct_num_ignore_space_symbol += correct_num_ignore_space_symbol
        self.all_num += all_num
        self.norm_edit_dis += norm_edit_dis
        self.each_len_num = self.each_len_num + np.array(each_len_num)
        self.each_len_correct_num = self.each_len_correct_num + np.array(
            each_len_correct_num)
        self.each_len_norm_edit_dis = self.each_len_norm_edit_dis + np.array(
            each_len_norm_edit_dis)
        self.each_ratio_num = self.each_ratio_num + np.array(each_ratio_num)
        self.each_ratio_correct_num = self.each_ratio_correct_num + np.array(
            each_ratio_correct_num)
        self.each_ratio_norm_edit_dis = self.each_ratio_norm_edit_dis + np.array(
            each_ratio_norm_edit_dis)
        return {
            'acc': correct_num / (all_num + self.eps),
            'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
        }

    def get_metric(self, training=False):
        """Return accumulated metrics and reset the accumulators.

        Returns:
            {
                'acc': fraction of exact matches,
                'norm_edit_dis': 1 - mean normalized edit distance,
                'num_samples': number of accumulated samples,
            }
            or the :meth:`get_all_metric` dict when ``with_ratio`` is set and
            not training.
        """
        if self.with_ratio and not training:
            return self.get_all_metric()
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
        num_samples = self.all_num
        self.reset()
        return {
            'acc': acc,
            'norm_edit_dis': norm_edit_dis,
            'num_samples': num_samples
        }

    def get_all_metric(self):
        """Return the extended accumulated metrics and reset the accumulators.

        Besides ``acc``/``norm_edit_dis``/``num_samples``, returns accuracies
        at each normalization stage and per-length / per-ratio sample counts,
        accuracies and ``1 - mean normalized edit distance`` as lists.
        """
        denom = self.all_num + self.eps
        acc = 1.0 * self.correct_num / denom
        acc_real = 1.0 * self.correct_num_real / denom
        acc_lower = 1.0 * self.correct_num_lower / denom
        acc_ignore_space = 1.0 * self.correct_num_ignore_space / denom
        acc_ignore_space_lower = 1.0 * self.correct_num_ignore_space_lower / denom
        acc_ignore_space_symbol = 1.0 * self.correct_num_ignore_space_symbol / denom
        norm_edit_dis = 1 - self.norm_edit_dis / denom
        num_samples = self.all_num

        each_len_denom = self.each_len_num + self.eps
        each_len_acc = (self.each_len_correct_num / each_len_denom).tolist()
        each_len_norm_edit_dis = (
            1 - self.each_len_norm_edit_dis / each_len_denom).tolist()
        each_len_num = self.each_len_num.tolist()

        each_ratio_denom = self.each_ratio_num + self.eps
        each_ratio_acc = (self.each_ratio_correct_num /
                          each_ratio_denom).tolist()
        each_ratio_norm_edit_dis = (
            1 - self.each_ratio_norm_edit_dis / each_ratio_denom).tolist()
        each_ratio_num = self.each_ratio_num.tolist()
        self.reset()
        return {
            'acc': acc,
            'acc_real': acc_real,
            'acc_lower': acc_lower,
            'acc_ignore_space': acc_ignore_space,
            'acc_ignore_space_lower': acc_ignore_space_lower,
            'acc_ignore_space_symbol': acc_ignore_space_symbol,
            # Same value as 'acc': the fully normalized comparison.
            'acc_ignore_space_lower_symbol': acc,
            'each_len_num': each_len_num,
            'each_len_acc': each_len_acc,
            'each_len_norm_edit_dis': each_len_norm_edit_dis,
            'each_ratio_num': each_ratio_num,
            'each_ratio_acc': each_ratio_acc,
            'each_ratio_norm_edit_dis': each_ratio_norm_edit_dis,
            'norm_edit_dis': norm_edit_dis,
            'num_samples': num_samples
        }

    def reset(self):
        """Zero every accumulator (scalar counters and per-bucket arrays)."""
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0
        self.correct_num_real = 0
        self.correct_num_lower = 0
        self.correct_num_ignore_space = 0
        self.correct_num_ignore_space_lower = 0
        self.correct_num_ignore_space_symbol = 0
        self.each_len_num = np.zeros(self.max_len, dtype=int)
        self.each_len_correct_num = np.zeros(self.max_len, dtype=int)
        self.each_len_norm_edit_dis = np.zeros(self.max_len, dtype=float)
        self.each_ratio_num = np.zeros(self.max_ratio, dtype=int)
        self.each_ratio_correct_num = np.zeros(self.max_ratio, dtype=int)
        self.each_ratio_norm_edit_dis = np.zeros(self.max_ratio, dtype=float)