# OpenOCR-Demo / openrec/metrics/rec_metric.py
# topdu - "openocr demo" - 29f689c - 10.3 kB
import string
import numpy as np
from rapidfuzz.distance import Levenshtein
def match_ss(ss1, ss2):
    """Return the part of ``ss2`` that follows its overlap with the end of ``ss1``.

    The longest suffix of ``ss1`` that is also a prefix of ``ss2`` is located
    and everything in ``ss2`` after that shared part is returned. When no
    non-empty overlap exists, ``ss2`` is returned unchanged.
    """
    # Try the largest possible overlap first so the longest match wins.
    for overlap in range(len(ss1), 0, -1):
        if ss1[-overlap:] == ss2[:overlap]:
            return ss2[overlap:]
    return ss2
def stream_match(text):
    """Merge a sequence of ``(string, confidence)`` chunks into one string.

    The first chunk loses its last character, the final chunk loses its first
    character, and every chunk in between loses both. Each trimmed chunk is
    then appended after removing the part that overlaps (per ``match_ss``)
    with the piece appended just before it.

    Returns:
        A ``(merged_string, mean_confidence)`` tuple, where the confidence is
        the arithmetic mean of the per-chunk confidences.
    """
    chunk_count = len(text)
    strings = [item[0] for item in text]
    confidences = [item[1] for item in text]

    last_index = chunk_count - 1
    # ``tail`` is always the most recently appended piece, not the whole
    # merged string; the next chunk is matched against it only.
    tail = strings[0][:-1]
    merged = tail
    for idx in range(1, chunk_count):
        if idx < last_index:
            window = strings[idx][1:-1]
        else:
            window = strings[idx][1:]
        tail = match_ss(tail, window)
        merged += tail
    return merged, sum(confidences) / chunk_count
class RecMetric(object):
    """Accumulate accuracy and normalized edit distance for recognized text.

    Each call scores one batch of ``(prediction, label)`` pairs, returns the
    batch-level metrics and adds the raw counts to running totals. The totals
    are turned into final metrics (and cleared) by ``get_metric``.

    Args:
        main_indicator: Name of the primary metric; only stored on the
            instance.
        is_filter: When true, strip every character that is not an ASCII
            letter or digit before comparing.
        is_lower: When true, compare lower-cased strings.
        ignore_space: When true, remove ``' '`` characters before comparing.
        stream: When true, the predictions of a batch are merged into one
            string with ``stream_match`` and compared against a single label.
        with_ratio: When true (and not training), use the extended metrics of
            ``eval_all_metric`` / ``get_all_metric``.
        max_len: Number of per-label-length buckets; longer labels share the
            last bucket.
        max_ratio: Number of per-ratio buckets; larger ratios share the last
            bucket.
        **kwargs: Accepted and ignored.
    """

    # Characters kept by ``_normalize_text``. Built once here instead of being
    # re-concatenated for every character that is filtered.
    _ALNUM_CHARS = string.digits + string.ascii_letters

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 is_lower=True,
                 ignore_space=True,
                 stream=False,
                 with_ratio=False,
                 max_len=25,
                 max_ratio=4,
                 **kwargs):
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.is_lower = is_lower
        self.ignore_space = ignore_space
        self.stream = stream
        # Added to denominators so empty batches do not divide by zero.
        self.eps = 1e-5
        self.with_ratio = with_ratio
        self.max_len = max_len
        self.max_ratio = max_ratio
        self.reset()

    def _normalize_text(self, text):
        """Return ``text`` with everything but ASCII letters and digits removed."""
        allowed = self._ALNUM_CHARS
        return ''.join(ch for ch in text if ch in allowed)

    def _canonicalize(self, pred, target):
        """Apply the configured space/filter/case normalization to both strings."""
        if self.ignore_space:
            pred = pred.replace(' ', '')
            target = target.replace(' ', '')
        if self.is_filter:
            pred = self._normalize_text(pred)
            target = self._normalize_text(target)
        if self.is_lower:
            pred = pred.lower()
            target = target.lower()
        return pred, target

    def __call__(self,
                 pred_label,
                 batch=None,
                 training=False,
                 *args,
                 **kwargs):
        """Score one batch.

        Uses ``eval_all_metric`` when ``with_ratio`` is set and ``training``
        is false, otherwise ``eval_metric``.
        """
        if self.with_ratio and not training:
            return self.eval_all_metric(pred_label, batch)
        return self.eval_metric(pred_label)

    def eval_metric(self, pred_label, *args, **kwargs):
        """Score a batch and add its counts to the running totals.

        Args:
            pred_label: ``(preds, labels)`` where ``preds`` is a sequence of
                ``(text, confidence)`` pairs and ``labels`` a sequence of
                ``(text, _)`` pairs.

        Returns:
            ``{'acc': ..., 'norm_edit_dis': ...}`` for this batch only, where
            ``norm_edit_dis`` is one minus the mean normalized Levenshtein
            distance.
        """
        preds, labels = pred_label
        correct_num = 0
        all_num = 0
        norm_edit_dis = 0.0
        for (pred, _), (target, _) in zip(preds, labels):
            if self.stream:
                # Stream mode merges all predictions into one string, so
                # there must be exactly one label to compare it with.
                assert len(labels) == 1
                pred, _ = stream_match(preds)
            pred, target = self._canonicalize(pred, target)
            norm_edit_dis += Levenshtein.normalized_distance(pred, target)
            if pred == target:
                correct_num += 1
            all_num += 1
        self.correct_num += correct_num
        self.all_num += all_num
        self.norm_edit_dis += norm_edit_dis
        return {
            'acc': correct_num / (all_num + self.eps),
            'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
        }

    def eval_all_metric(self, pred_label, batch=None, *args, **kwargs):
        """Score a batch with the extended statistics.

        In addition to what ``eval_metric`` tracks, this counts exact matches
        at each normalization stage and buckets results by label length and,
        when ``with_ratio`` is set, by the per-sample ratio in ``batch[-1]``.

        Args:
            pred_label: ``(preds, labels)`` as for ``eval_metric``.
            batch: Batch whose last element holds one ratio per sample. Only
                read when ``with_ratio`` is true.

        Returns:
            ``{'acc': ..., 'norm_edit_dis': ...}`` for this batch only.
        """
        # Without ratios there is nothing to bucket; skip the per-ratio
        # statistics instead of failing on an unbound ``ratio``.
        ratio = batch[-1] if self.with_ratio else None
        preds, labels = pred_label
        correct_num = 0
        correct_num_real = 0
        correct_num_lower = 0
        correct_num_ignore_space = 0
        correct_num_ignore_space_lower = 0
        correct_num_ignore_space_symbol = 0
        all_num = 0
        norm_edit_dis = 0.0
        each_len_num = [0] * self.max_len
        each_len_correct_num = [0] * self.max_len
        each_len_norm_edit_dis = [0] * self.max_len
        each_ratio_num = [0] * self.max_ratio
        each_ratio_correct_num = [0] * self.max_ratio
        each_ratio_norm_edit_dis = [0] * self.max_ratio
        for (pred, _), (target, _) in zip(preds, labels):
            if self.stream:
                # Stream mode merges all predictions into one string, so
                # there must be exactly one label to compare it with.
                assert len(labels) == 1
                pred, _ = stream_match(preds)
            # Matches are counted after each successive normalization step.
            if pred == target:
                correct_num_real += 1
            if pred.lower() == target.lower():
                correct_num_lower += 1
            if self.ignore_space:
                pred = pred.replace(' ', '')
                target = target.replace(' ', '')
            if pred == target:
                correct_num_ignore_space += 1
            if pred.lower() == target.lower():
                correct_num_ignore_space_lower += 1
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
            if pred == target:
                correct_num_ignore_space_symbol += 1
            if self.is_lower:
                pred = pred.lower()
                target = target.lower()
            dis = Levenshtein.normalized_distance(pred, target)
            norm_edit_dis += dis
            is_correct = pred == target

            # Length bucket: 1-based length mapped to a 0-based index, with
            # empty labels in bucket 0 and long labels in the last bucket.
            len_i = max(0, min(self.max_len, len(target)) - 1)
            each_len_num[len_i] += 1
            each_len_norm_edit_dis[len_i] += dis
            if is_correct:
                correct_num += 1
                each_len_correct_num[len_i] += 1

            if ratio is not None:
                # Clamp on both sides so a ratio <= 0 cannot turn into a
                # negative index that silently lands in the last bucket.
                ratio_i = min(
                    max(int(ratio[all_num]) - 1, 0), self.max_ratio - 1)
                each_ratio_num[ratio_i] += 1
                each_ratio_norm_edit_dis[ratio_i] += dis
                if is_correct:
                    each_ratio_correct_num[ratio_i] += 1
            all_num += 1
        self.correct_num += correct_num
        self.correct_num_real += correct_num_real
        self.correct_num_lower += correct_num_lower
        self.correct_num_ignore_space += correct_num_ignore_space
        self.correct_num_ignore_space_lower += correct_num_ignore_space_lower
        self.correct_num_ignore_space_symbol += correct_num_ignore_space_symbol
        self.all_num += all_num
        self.norm_edit_dis += norm_edit_dis
        self.each_len_num = self.each_len_num + np.array(each_len_num)
        self.each_len_correct_num = self.each_len_correct_num + np.array(
            each_len_correct_num)
        self.each_len_norm_edit_dis = self.each_len_norm_edit_dis + np.array(
            each_len_norm_edit_dis)
        self.each_ratio_num = self.each_ratio_num + np.array(each_ratio_num)
        self.each_ratio_correct_num = self.each_ratio_correct_num + np.array(
            each_ratio_correct_num)
        self.each_ratio_norm_edit_dis = (
            self.each_ratio_norm_edit_dis + np.array(each_ratio_norm_edit_dis))
        return {
            'acc': correct_num / (all_num + self.eps),
            'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps),
        }

    def get_metric(self, training=False):
        """Return the metrics accumulated so far and reset the totals.

        Delegates to ``get_all_metric`` when ``with_ratio`` is set and
        ``training`` is false. Otherwise returns::

            {
                'acc': ...,
                'norm_edit_dis': ...,
                'num_samples': ...,
            }
        """
        if self.with_ratio and not training:
            return self.get_all_metric()
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
        num_samples = self.all_num
        self.reset()
        return {
            'acc': acc,
            'norm_edit_dis': norm_edit_dis,
            'num_samples': num_samples
        }

    def get_all_metric(self):
        """Return the extended accumulated metrics and reset the totals.

        The result holds the overall accuracy, the accuracies counted at each
        normalization stage, the per-length and per-ratio sample counts,
        accuracies and edit-distance scores (as plain lists), the overall
        ``norm_edit_dis`` and ``num_samples``.
        """
        denom = self.all_num + self.eps
        acc = 1.0 * self.correct_num / denom
        acc_real = 1.0 * self.correct_num_real / denom
        acc_lower = 1.0 * self.correct_num_lower / denom
        acc_ignore_space = 1.0 * self.correct_num_ignore_space / denom
        acc_ignore_space_lower = (
            1.0 * self.correct_num_ignore_space_lower / denom)
        acc_ignore_space_symbol = (
            1.0 * self.correct_num_ignore_space_symbol / denom)
        norm_edit_dis = 1 - self.norm_edit_dis / denom
        num_samples = self.all_num

        len_denom = self.each_len_num + self.eps
        each_len_acc = (self.each_len_correct_num / len_denom).tolist()
        each_len_norm_edit_dis = (
            1 - self.each_len_norm_edit_dis / len_denom).tolist()
        each_len_num = self.each_len_num.tolist()

        ratio_denom = self.each_ratio_num + self.eps
        each_ratio_acc = (self.each_ratio_correct_num / ratio_denom).tolist()
        each_ratio_norm_edit_dis = (
            1 - self.each_ratio_norm_edit_dis / ratio_denom).tolist()
        each_ratio_num = self.each_ratio_num.tolist()

        self.reset()
        return {
            'acc': acc,
            'acc_real': acc_real,
            'acc_lower': acc_lower,
            'acc_ignore_space': acc_ignore_space,
            'acc_ignore_space_lower': acc_ignore_space_lower,
            'acc_ignore_space_symbol': acc_ignore_space_symbol,
            # Reported under this key as well; it is the same value as 'acc'.
            'acc_ignore_space_lower_symbol': acc,
            'each_len_num': each_len_num,
            'each_len_acc': each_len_acc,
            'each_len_norm_edit_dis': each_len_norm_edit_dis,
            'each_ratio_num': each_ratio_num,
            'each_ratio_acc': each_ratio_acc,
            'each_ratio_norm_edit_dis': each_ratio_norm_edit_dis,
            'norm_edit_dis': norm_edit_dis,
            'num_samples': num_samples
        }

    def reset(self):
        """Clear every running total, including the per-bucket arrays."""
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0
        self.correct_num_real = 0
        self.correct_num_lower = 0
        self.correct_num_ignore_space = 0
        self.correct_num_ignore_space_lower = 0
        self.correct_num_ignore_space_symbol = 0
        self.each_len_num = np.zeros(self.max_len, dtype=int)
        self.each_len_correct_num = np.zeros(self.max_len, dtype=int)
        self.each_len_norm_edit_dis = np.zeros(self.max_len, dtype=float)
        self.each_ratio_num = np.zeros(self.max_ratio, dtype=int)
        self.each_ratio_correct_num = np.zeros(self.max_ratio, dtype=int)
        self.each_ratio_norm_edit_dis = np.zeros(self.max_ratio, dtype=float)