# OpenOCR-Demo/tools/data/ratio_dataset_tvresize_test.py
# topdu's picture
# openocr demo
# 29f689c
import io
import math
import random
import re
import unicodedata
import cv2
import lmdb
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T
from torchvision.transforms import functional as F
from openrec.preprocess import create_operators, transform
class CharsetAdapter:
    """Callable that rewrites a label so it only uses a target charset.

    The label is first case-folded when the charset contains a single case
    only, then every character that is not part of the charset is removed.
    """

    def __init__(self, target_charset) -> None:
        """Derive the case policy and the removal pattern from the charset.

        Args:
            target_charset: String holding every character that is allowed
                to remain in a label.
        """
        super().__init__()
        # A charset that is unchanged by lower() has no upper-case letters
        # (and vice versa); a charset without letters satisfies both tests.
        self.lowercase_only = target_charset.lower() == target_charset
        self.uppercase_only = target_charset.upper() == target_charset
        # Negated character class: matches anything outside the charset.
        pattern = '[^' + re.escape(target_charset) + ']'
        self.unsupported = re.compile(pattern)

    def __call__(self, label):
        """Return ``label`` case-folded as needed and stripped of foreign characters."""
        if self.lowercase_only:
            folded = label.lower()
        elif self.uppercase_only:
            folded = label.upper()
        else:
            # Mixed-case charset: keep the label's casing untouched.
            folded = label
        # Remove unsupported characters
        return self.unsupported.sub('', folded)
class RatioDataSetTVResizeTest(Dataset):
    """Text-recognition dataset over LMDB stores, indexed by aspect ratio.

    On construction every sampled LMDB record is scanned once: its label is
    normalised and filtered, and its width/height ratio is rounded and
    clipped to ``[min_ratio, max_ratio]``. The surviving
    ``(lmdb index, record index)`` pairs are kept in ``data_idx_order_list``
    with the matching ratios in ``wh_ratio``.

    Items are not requested with a plain integer: ``__getitem__`` expects a
    4-element ``properties`` sequence (see that method).
    """

    def __init__(self, config, mode, logger, seed=None, epoch=1):
        """Open the LMDB stores and build the ratio-annotated sample index.

        Args:
            config: Mapping with a ``'Global'`` section and a ``config[mode]``
                section that contains ``'dataset'`` and ``'loader'``.
            mode: Key of the section of ``config`` to use.
            logger: Object exposing ``info(...)``; used for progress output.
            seed: Accepted for interface compatibility; not read here.
            epoch: Stored as ``self.seed``.

        Raises:
            AssertionError: If ``ratio_list`` is a sequence whose length
                differs from ``data_dir_list``.
        """
        super(RatioDataSetTVResizeTest, self).__init__()
        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']

        self.ds_width = dataset_config.get('ds_width', True)
        max_ratio = loader_config.get('max_ratio', 10)
        min_ratio = loader_config.get('min_ratio', 1)
        data_dir_list = dataset_config['data_dir_list']
        self.do_shuffle = loader_config['shuffle']
        # NOTE: the ``seed`` argument is ignored; ``self.seed`` holds the
        # epoch and is not read anywhere else in this class.
        self.seed = epoch
        self.max_text_length = global_config['max_text_length']

        # A scalar ratio applies to every data source.
        data_source_num = len(data_dir_list)
        ratio_list = dataset_config.get('ratio_list', 1.0)
        if isinstance(ratio_list, (float, int)):
            ratio_list = [float(ratio_list)] * int(data_source_num)
        assert len(
            ratio_list
        ) == data_source_num, 'The length of ratio_list should be the same as the file_list.'

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
            data_dir_list, ratio_list)
        for data_dir in data_dir_list:
            logger.info('Initialize indexs of datasets:%s' % data_dir)
        self.logger = logger
        data_idx_order_list = self.dataset_traversal()

        # Build the charset that labels are filtered against.
        character_dict_path = global_config.get('character_dict_path', None)
        use_space_char = global_config.get('use_space_char', False)
        if character_dict_path is None:
            char_test = '0123456789abcdefghijklmnopqrstuvwxyz'
        else:
            with open(character_dict_path, 'rb') as fin:
                # One dictionary entry per line; drop the line terminators.
                char_test = ''.join(
                    raw.decode('utf-8').strip('\r\n') for raw in fin)
            # NOTE(review): the source's indentation was lost; the space
            # character is assumed to be appended only when a dictionary
            # file is used -- confirm against the upstream file.
            if use_space_char:
                char_test += ' '

        wh_ratio, data_idx_order_list = self.get_wh_ratio(
            data_idx_order_list, char_test)
        self.data_idx_order_list = np.array(data_idx_order_list)
        # Round each ratio to the nearest integer, then clamp to the
        # configured range.
        wh_ratio = np.around(np.array(wh_ratio))
        self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
        # Log how many samples have each integer ratio.
        for i in range(max_ratio + 1):
            logger.info((1 * (self.wh_ratio == i)).sum())
        self.wh_ratio_sort = np.argsort(self.wh_ratio)

        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)
        self.need_reset = any(x < 1 for x in ratio_list)
        # Number of samples rejected by resize_norm_img so far.
        self.error = 0
        # [width, height] targets for ratios 1..4; larger ratios use
        # base_h * ratio by base_h (see resize_norm_img).
        self.base_shape = dataset_config.get(
            'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
        self.base_h = dataset_config.get('base_h', 32)
        self.interpolation = T.InterpolationMode.BICUBIC
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])

    def get_wh_ratio(self, data_idx_order_list, char_test):
        """Filter the sampled records and collect their width/height ratios.

        A record is dropped when its label is missing, is longer than
        ``self.max_text_length`` after whitespace removal and ASCII folding,
        or is empty after charset filtering. It is also dropped when neither
        a stored ``wh-*`` entry nor an image blob is available.

        Args:
            data_idx_order_list: Rows of ``(lmdb index, record index)``.
            char_test: Charset string passed to ``CharsetAdapter``.

        Returns:
            ``(wh_ratio, kept)`` where ``wh_ratio`` is a list of float ratios
            and ``kept`` is the parallel list of ``[lmdb_idx, file_idx]``.
        """
        wh_ratio = []
        # Histogram for logging: row = floor(w / h) capped at 10,
        # column = label length capped at 25.
        wh_ratio_len = [[0] * 26 for _ in range(11)]
        data_idx_order_list_filter = []
        charset_adapter = CharsetAdapter(char_test)

        for lmdb_idx, file_idx in data_idx_order_list:
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            txn = self.lmdb_sets[lmdb_idx]['txn']

            # Check the label first so rejected records never touch image
            # data.
            label = txn.get('label-%09d'.encode() % file_idx)
            if label is None:
                continue
            # Remove all whitespace.
            label = ''.join(label.decode('utf-8').split())
            # Normalize unicode composites (if any) and convert to
            # compatible ASCII characters.
            label = unicodedata.normalize('NFKD', label).encode(
                'ascii', 'ignore').decode()
            # Filter by length before removing unsupported characters. The
            # original label might be too long.
            if len(label) > self.max_text_length:
                continue
            label = charset_adapter(label)
            if not label:
                continue

            # Prefer the stored "<w>_<h>" entry; otherwise read the size
            # from the image blob.
            wh = txn.get('wh-%09d'.encode() % file_idx)
            if wh is None:
                img = txn.get(f'image-{file_idx:09d}'.encode())
                if img is None:
                    # No way to determine the size: skip instead of failing
                    # inside Image.open.
                    continue
                w, h = Image.open(io.BytesIO(img)).size
            else:
                w, h = wh.decode('utf-8').split('_')

            ratio = float(w) / float(h)
            wh_ratio.append(ratio)
            wh_ratio_len[min(int(ratio), 10)][min(len(label), 25)] += 1
            data_idx_order_list_filter.append([lmdb_idx, file_idx])

        self.logger.info(wh_ratio_len)
        return wh_ratio, data_idx_order_list_filter

    def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
        """Open every LMDB directory read-only and describe it.

        Args:
            data_dir_list: LMDB directory paths.
            ratio_list: Per-directory sampling ratio, parallel to
                ``data_dir_list``.

        Returns:
            Dict keyed by the 0-based position of the directory. Each value
            holds ``dirpath``, ``env``, ``txn``, ``num_samples`` (read from
            the ``num-samples`` key) and ``ratio_num_samples``
            (``int(ratio * num_samples)``).
        """
        lmdb_sets = {}
        for dataset_idx, (dirpath, ratio) in enumerate(
                zip(data_dir_list, ratio_list)):
            env = lmdb.open(dirpath,
                            max_readers=32,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            txn = env.begin(write=False)
            num_samples = int(txn.get('num-samples'.encode()))
            lmdb_sets[dataset_idx] = {
                'dirpath': dirpath,
                'env': env,
                'txn': txn,
                'num_samples': num_samples,
                'ratio_num_samples': int(ratio * num_samples),
            }
        return lmdb_sets

    def dataset_traversal(self):
        """Draw the record indices to use from every LMDB store.

        For each store, ``ratio_num_samples`` distinct indices are sampled
        from ``1..num_samples`` with ``random.sample``.

        Returns:
            Float ``numpy`` array of shape ``(total, 2)``; column 0 is the
            store index and column 1 the 1-based record index.
        """
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = sum(self.lmdb_sets[lno]['ratio_num_samples']
                               for lno in range(lmdb_num))
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            info = self.lmdb_sets[lno]
            end_idx = beg_idx + info['ratio_num_samples']
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            data_idx_order_list[beg_idx:end_idx, 1] = random.sample(
                range(1, info['num_samples'] + 1),
                info['ratio_num_samples'])
            beg_idx = end_idx
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode an encoded image buffer with OpenCV.

        Args:
            value: Bytes of an encoded image, or a falsy value.

        Returns:
            The result of ``cv2.imdecode(..., 1)``, or ``None`` when
            ``value`` is empty or decoding fails.
        """
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        # cv2.imdecode already yields None for undecodable data.
        return cv2.imdecode(imgdata, 1)

    def resize_norm_img(self, data, gen_ratio, padding=True):
        """Resize ``data['image']`` to the shape of a ratio bucket and normalise it.

        Args:
            data: Sample dict; ``data['image']`` must expose ``.size`` as
                ``(width, height)`` (as PIL images do) and be accepted by
                ``torchvision.transforms.functional.resize``.
            gen_ratio: Integer ratio bucket. Buckets up to 4 take their
                ``[width, height]`` from ``self.base_shape``; larger ones use
                ``base_h * gen_ratio`` by ``base_h``.
            padding: When true, keep the aspect ratio (with a random width
                factor in ``[0.5, 1.5)`` when the image fits) and pad up to
                the target width. When false, stretch to the target width.

        Returns:
            ``data`` updated with ``image``, ``valid_ratio``, ``gen_ratio``
            and ``real_ratio``, or ``None`` (and ``self.error`` incremented)
            when the bucket ratio exceeds the image's integer ratio by 2 or
            more.
        """
        img = data['image']
        w, h = img.size
        if gen_ratio <= 4:
            imgW, imgH = self.base_shape[gen_ratio - 1]
        else:
            imgW, imgH = self.base_h * gen_ratio, self.base_h
        use_ratio = imgW // imgH
        # Reject images that are much narrower than the requested bucket.
        if use_ratio >= (w // h) + 2:
            self.error += 1
            return None

        if not padding:
            resized_w = imgW
        else:
            ratio = w / float(h)
            if math.ceil(imgH * ratio) > imgW:
                resized_w = imgW
            else:
                # Random horizontal scale, never wider than the target.
                resized_w = int(
                    math.ceil(imgH * ratio * (random.random() + 0.5)))
                resized_w = min(imgW, resized_w)

        resized_image = F.resize(img, (imgH, resized_w),
                                 interpolation=self.interpolation)
        img = self.transforms(resized_image)
        if resized_w < imgW and padding:
            # Padding list is [left, top, right, bottom]: pad the right edge.
            img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.)
        valid_ratio = min(1.0, float(resized_w / imgW))

        data['image'] = img
        data['valid_ratio'] = valid_ratio
        data['gen_ratio'] = use_ratio
        r = float(w) / float(h)
        data['real_ratio'] = max(1, round(r))
        return data

    def get_lmdb_sample_info(self, txn, index):
        """Read the label and image blob stored for record ``index``.

        Args:
            txn: Object with a ``get(key)`` method (an LMDB transaction).
            index: Record number used in the ``label-%09d`` and
                ``image-%09d`` keys.

        Returns:
            ``(image_bytes, label_str)``, or ``None`` when the label key is
            absent. ``image_bytes`` is whatever ``txn.get`` returns for the
            image key.
        """
        label_key = 'label-%09d'.encode() % index
        label = txn.get(label_key)
        if label is None:
            return None
        label = label.decode('utf-8')
        img_key = 'image-%09d'.encode() % index
        imgbuf = txn.get(img_key)
        return imgbuf, label

    def _getitem_same_ratio(self, img_width, img_height, ratio):
        """Return a randomly chosen other sample whose clipped ratio equals ``ratio``."""
        ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
        ids = random.sample(ratio_ids, 1)
        return self.__getitem__([img_width, img_height, ids[0], ratio])

    def __getitem__(self, properties):
        """Load, transform and resize one sample.

        Args:
            properties: Sequence ``[img_width, img_height, idx, ratio]``.
                ``idx`` indexes ``self.data_idx_order_list`` and ``ratio`` is
                the bucket passed to ``resize_norm_img``. ``img_width`` and
                ``img_height`` are only forwarded when a replacement sample
                is drawn.

        Returns:
            The output of the last configured operator. All operators but
            the last run before resizing and the last one runs afterwards.
            If any stage yields ``None``, a random sample with the same
            ratio is returned instead (retried recursively, without a
            bound).
        """
        img_width = properties[0]
        img_height = properties[1]
        idx = properties[2]
        ratio = properties[3]
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)

        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[lmdb_idx]['txn'], file_idx)
        if sample_info is None:
            return self._getitem_same_ratio(img_width, img_height, ratio)

        img, label = sample_info
        data = {'image': img, 'label': label}
        outs = transform(data, self.ops[:-1])
        if outs is not None:
            outs = self.resize_norm_img(outs, ratio, padding=False)
        if outs is not None:
            outs = transform(outs, self.ops[-1:])
        if outs is None:
            # Any stage failed: fall back to another sample of this ratio.
            return self._getitem_same_ratio(img_width, img_height, ratio)
        return outs

    def __len__(self):
        """Return the number of samples that survived filtering."""
        return self.data_idx_order_list.shape[0]