Spaces:
Runtime error
Runtime error
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
import os
import re
import sys
from collections import defaultdict

import editdistance
import numpy as np
import shapely
import shapely.errors
from shapely.geometry import Polygon
def strQ2B(ustring):
    """Convert full-width characters in ``ustring`` to half-width.

    The ideographic space (U+3000) becomes an ASCII space, and the
    full-width forms U+FF01..U+FF5E are shifted down by 0xFEE0 onto the
    ASCII range ``!``..``~``. Every other character is kept as is.

    Args:
        ustring: Text to normalise.

    Returns:
        The normalised string.
    """
    half_width = []
    for uchar in ustring:
        code = ord(uchar)
        if code == 0x3000:
            # Ideographic (full-width) space -> ASCII space.
            code = 0x20
        elif 0xFF01 <= code <= 0xFF5E:
            # Full-width ASCII variants sit at a fixed offset from ASCII.
            code -= 0xFEE0
        half_width.append(chr(code))
    # Join once instead of growing a string with "+=" per character.
    return "".join(half_width)
def polygon_from_str(polygon_points):
    """
    Build a shapely polygon from the coordinates of a gt or dt line.

    The flat sequence of eight values is reshaped into four (x, y) rows,
    and the convex hull of the resulting polygon is returned.
    """
    corners = np.array(polygon_points).reshape(4, 2)
    return Polygon(corners).convex_hull
def polygon_iou(poly1, poly2):
    """
    Intersection over union between two shapely polygons.

    Args:
        poly1: First polygon; must provide ``intersects``, ``intersection``
            and ``area``.
        poly2: Second polygon, same requirements.

    Returns:
        The IoU of the two polygons. ``0`` is returned when they do not
        intersect, when their union has no area, or when shapely raises a
        ``TopologicalError`` while computing the intersection.
    """
    # this test is fast and can accelerate calculation
    if not poly1.intersects(poly2):
        return 0
    try:
        inter_area = poly1.intersection(poly2).area
        union_area = poly1.area + poly2.area - inter_area
    # The exception is looked up in shapely.errors; the shapely.geos alias
    # used previously is deprecated in newer Shapely releases.
    except shapely.errors.TopologicalError:
        print('shapely.geos.TopologicalError occurred, iou set to 0')
        return 0
    # Degenerate (zero-area) geometries can intersect and still have an
    # empty union; report no overlap instead of dividing by zero.
    if union_area <= 0:
        return 0
    return float(inter_area) / union_area
def ed(str1, str2):
    """Return the edit distance between ``str1`` and ``str2`` as computed by ``editdistance.eval``."""
    distance = editdistance.eval(str1, str2)
    return distance
def e2e_eval(gt_dir, res_dir, ignore_blank=False):
    """Evaluate end-to-end OCR results against ground truth and print metrics.

    Every file listed in ``gt_dir`` is compared with the file of the same
    name in ``res_dir``; a missing result file is treated as "no
    detections". Both kinds of file hold one box per line, tab-separated:

    * ground truth: eight polygon coordinates, an ignore flag, and an
      optional transcription (lines with fewer than nine fields are skipped);
    * results: eight polygon coordinates and an optional transcription
      (lines with fewer than eight fields are skipped).

    Within a file, ground-truth and result boxes whose IoU is at least 0.5
    are paired greedily in order of decreasing IoU. A matched pair is only
    scored when the ground-truth ignore flag is ``'0'``; unmatched results
    are always counted, unmatched ground truths only when their flag is
    ``'0'``. Character accuracy, average edit distances, precision, recall
    and F-measure are printed to stdout.

    Args:
        gt_dir: Directory containing the ground-truth files.
        res_dir: Directory containing the result files.
        ignore_blank: If true, spaces are removed from the transcriptions of
            matched pairs before they are compared.

    Returns:
        None. The metrics are only printed.
    """
    print('start testing...')
    iou_thresh = 0.5
    val_names = os.listdir(gt_dir)
    num_gt_chars = 0
    gt_count = 0
    dt_count = 0
    hit = 0
    ed_sum = 0

    for val_name in val_names:
        with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
            gt_lines = [o.strip() for o in f.readlines()]
        gts = []
        ignore_masks = []
        for line in gt_lines:
            parts = line.strip().split('\t')
            # ignore illegal data
            if len(parts) < 9:
                continue
            assert (len(parts) < 11)
            if len(parts) == 9:
                # No transcription field: treat the text as empty.
                gts.append(parts[:8] + [''])
            else:
                gts.append(parts[:8] + [parts[-1]])
            ignore_masks.append(parts[8])

        val_path = os.path.join(res_dir, val_name)
        if not os.path.exists(val_path):
            dt_lines = []
        else:
            with open(val_path, encoding='utf-8') as f:
                dt_lines = [o.strip() for o in f.readlines()]
        dts = []
        for line in dt_lines:
            parts = line.strip().split("\t")
            # Skip blank or truncated lines, mirroring the ground-truth
            # handling above, instead of failing later in float()/indexing.
            if len(parts) < 8:
                continue
            assert (len(parts) < 10), "line error: {}".format(line)
            if len(parts) == 8:
                dts.append(parts + [''])
            else:
                dts.append(parts)

        dt_match = [False] * len(dts)
        gt_match = [False] * len(gts)

        # IoU of every (gt, dt) pair that reaches the threshold.
        all_ious = {}
        for index_gt, gt in enumerate(gts):
            gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
            gt_poly = polygon_from_str(gt_coors)
            for index_dt, dt in enumerate(dts):
                dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
                dt_poly = polygon_from_str(dt_coors)
                iou = polygon_iou(dt_poly, gt_poly)
                if iou >= iou_thresh:
                    all_ious[(index_gt, index_dt)] = iou
        sorted_ious = sorted(
            all_ious.items(), key=operator.itemgetter(1), reverse=True)
        sorted_gt_dt_pairs = [item[0] for item in sorted_ious]

        # matched gt and dt: greedy assignment, best IoU first
        for index_gt, index_dt in sorted_gt_dt_pairs:
            if gt_match[index_gt] or dt_match[index_dt]:
                continue
            gt_match[index_gt] = True
            dt_match[index_dt] = True
            gt_str = strQ2B(gts[index_gt][8])
            dt_str = strQ2B(dts[index_dt][8])
            if ignore_blank:
                gt_str = gt_str.replace(" ", "")
                dt_str = dt_str.replace(" ", "")
            # Pairs whose ground truth is flagged as ignored are marked as
            # matched but contribute to none of the counters.
            if ignore_masks[index_gt] == '0':
                ed_sum += ed(gt_str, dt_str)
                num_gt_chars += len(gt_str)
                if gt_str == dt_str:
                    hit += 1
                gt_count += 1
                dt_count += 1

        # unmatched dt
        for tindex, dt_match_flag in enumerate(dt_match):
            if not dt_match_flag:
                dt_str = dts[tindex][8]
                gt_str = ''
                ed_sum += ed(dt_str, gt_str)
                dt_count += 1

        # unmatched gt
        for tindex, gt_match_flag in enumerate(gt_match):
            if not gt_match_flag and ignore_masks[tindex] == '0':
                dt_str = ''
                gt_str = gts[tindex][8]
                ed_sum += ed(gt_str, dt_str)
                num_gt_chars += len(gt_str)
                gt_count += 1

    eps = 1e-9
    print('hit, dt_count, gt_count', hit, dt_count, gt_count)
    precision = hit / (dt_count + eps)
    recall = hit / (gt_count + eps)
    fmeasure = 2.0 * precision * recall / (precision + recall + eps)
    # An empty ground-truth directory must not raise ZeroDivisionError.
    avg_edit_dist_img = ed_sum / len(val_names) if val_names else 0.0
    avg_edit_dist_field = ed_sum / (gt_count + eps)
    character_acc = 1 - ed_sum / (num_gt_chars + eps)

    print('character_acc: %.2f' % (character_acc * 100) + "%")
    print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
    print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
    print('precision: %.2f' % (precision * 100) + "%")
    print('recall: %.2f' % (recall * 100) + "%")
    print('fmeasure: %.2f' % (fmeasure * 100) + "%")
if __name__ == '__main__':
    # Both directories are required; print a usage line rather than failing
    # with an IndexError when one is missing. Extra arguments are ignored.
    if len(sys.argv) < 3:
        print("usage: python3 {} gt_dir res_dir".format(sys.argv[0]))
        sys.exit(1)
    gt_folder = sys.argv[1]
    pred_folder = sys.argv[2]
    e2e_eval(gt_folder, pred_folder)