Spaces:
Running
Running
import io | |
import math | |
import random | |
import re | |
import unicodedata | |
import cv2 | |
import lmdb | |
import numpy as np | |
from PIL import Image | |
from torch.utils.data import Dataset | |
from torchvision import transforms as T | |
from torchvision.transforms import functional as F | |
from openrec.preprocess import create_operators, transform | |
class CharsetAdapter:
    """Map raw labels onto the characters allowed by a target charset.

    The adapter folds case when the charset itself is single-case and then
    drops every character that is not part of the charset.
    """

    def __init__(self, target_charset) -> None:
        super().__init__()
        # Pattern matching any single character outside the charset; the
        # charset is escaped so '-', ']' and '^' are taken literally.
        escaped_charset = re.escape(target_charset)
        self.unsupported = re.compile(f'[^{escaped_charset}]')
        # A charset without cased characters satisfies both flags; the
        # lowercase branch wins in __call__ in that case.
        self.lowercase_only = target_charset.lower() == target_charset
        self.uppercase_only = target_charset.upper() == target_charset

    def __call__(self, label):
        """Return ``label`` case-folded (if applicable) and filtered."""
        if self.lowercase_only:
            folded = label.lower()
        elif self.uppercase_only:
            folded = label.upper()
        else:
            folded = label
        # Strip everything the charset does not contain.
        return self.unsupported.sub('', folded)
class RatioDataSetTVResizeTest(Dataset):
    """LMDB-backed dataset whose samples are indexed by rounded w/h ratio.

    At construction time every selected sample is scanned once: its label is
    normalised and filtered against the configured charset, and its
    width/height ratio is recorded. ``__getitem__`` is then driven by a
    ``[img_width, img_height, idx, ratio]`` request (presumably produced by a
    ratio-aware sampler -- confirm against the caller) and resizes the image
    to a shape derived from ``ratio``.
    """

    def __init__(self, config, mode, logger, seed=None, epoch=1):
        """Open the LMDB sources and build the filtered sample index.

        Args:
            config: Nested mapping with a ``'Global'`` section and a
                ``config[mode]`` section holding ``'dataset'`` and
                ``'loader'`` sub-sections.
            mode: Key selecting the section of ``config`` to use.
            logger: Object exposing ``info``; also stored as ``self.logger``.
            seed: Accepted for interface compatibility; not read here.
            epoch: Stored as ``self.seed``.
        """
        super(RatioDataSetTVResizeTest, self).__init__()
        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']

        self.ds_width = dataset_config.get('ds_width', True)
        max_ratio = loader_config.get('max_ratio', 10)
        min_ratio = loader_config.get('min_ratio', 1)
        data_dir_list = dataset_config['data_dir_list']
        self.do_shuffle = loader_config['shuffle']
        # NOTE(review): the `seed` argument is never used; the epoch number is
        # stored as the seed instead. Kept as-is -- confirm this is intended.
        self.seed = epoch
        self.max_text_length = global_config['max_text_length']

        # A scalar ratio applies to every data source.
        data_source_num = len(data_dir_list)
        ratio_list = dataset_config.get('ratio_list', 1.0)
        if isinstance(ratio_list, (float, int)):
            ratio_list = [float(ratio_list)] * int(data_source_num)
        assert len(
            ratio_list
        ) == data_source_num, 'The length of ratio_list should be the same as the file_list.'

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
            data_dir_list, ratio_list)
        for data_dir in data_dir_list:
            logger.info('Initialize indexs of datasets:%s' % data_dir)
        self.logger = logger
        data_idx_order_list = self.dataset_traversal()

        char_test = self._load_charset(
            global_config.get('character_dict_path', None),
            global_config.get('use_space_char', False))

        # Keep only samples whose filtered label is usable, together with
        # their aspect ratios (the two lists stay index-aligned).
        wh_ratio, data_idx_order_list = self.get_wh_ratio(
            data_idx_order_list, char_test)
        self.data_idx_order_list = np.array(data_idx_order_list)
        wh_ratio = np.around(np.array(wh_ratio))
        self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
        # Log how many samples fall into each integer ratio bucket.
        for i in range(max_ratio + 1):
            logger.info((1 * (self.wh_ratio == i)).sum())
        self.wh_ratio_sort = np.argsort(self.wh_ratio)

        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)
        self.need_reset = any(x < 1 for x in ratio_list)
        # Number of samples rejected by resize_norm_img so far.
        self.error = 0
        # base_shape[r - 1] is the [width, height] used for ratio r <= 4.
        self.base_shape = dataset_config.get(
            'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
        self.base_h = dataset_config.get('base_h', 32)
        self.interpolation = T.InterpolationMode.BICUBIC
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])

    @staticmethod
    def _load_charset(character_dict_path, use_space_char):
        """Return the charset string used to filter labels.

        Without a dictionary path the charset is digits plus lowercase
        letters. Otherwise the lines of the UTF-8 dictionary file are
        concatenated (line terminators stripped) and a space is appended
        when ``use_space_char`` is truthy.
        """
        if character_dict_path is None:
            return '0123456789abcdefghijklmnopqrstuvwxyz'
        with open(character_dict_path, 'rb') as fin:
            char_test = ''.join(
                line.decode('utf-8').strip('\r\n') for line in fin)
        if use_space_char:
            char_test += ' '
        return char_test

    def get_wh_ratio(self, data_idx_order_list, char_test):
        """Filter samples by label and collect their width/height ratios.

        Args:
            data_idx_order_list: Array of ``(lmdb_idx, file_idx)`` rows.
            char_test: Charset string handed to ``CharsetAdapter``.

        Returns:
            ``(wh_ratio, kept)`` where ``wh_ratio`` is a list of float ratios
            and ``kept`` the matching ``[lmdb_idx, file_idx]`` pairs. Samples
            with no label, a label longer than ``self.max_text_length`` or a
            label that is empty after charset filtering are dropped.
        """
        max_ratio_bin = 10
        max_len_bin = 25
        wh_ratio = []
        # Histogram of (integer ratio, label length), both clamped to the
        # last bin; it is only logged, never returned.
        wh_ratio_len = [[0 for _ in range(max_len_bin + 1)]
                        for _ in range(max_ratio_bin + 1)]
        data_idx_order_list_filter = []
        charset_adapter = CharsetAdapter(char_test)

        for lmdb_idx, file_idx in data_idx_order_list:
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            txn = self.lmdb_sets[lmdb_idx]['txn']

            # Check the label first so that rejected samples never pay for
            # the size lookup / image header decode below.
            label = txn.get(b'label-%09d' % file_idx)
            if label is None:
                continue
            # Remove all whitespace.
            label = ''.join(label.decode('utf-8').split())
            # Normalize unicode composites (if any) and convert to
            # compatible ASCII characters.
            label = unicodedata.normalize('NFKD', label).encode(
                'ascii', 'ignore').decode()
            # Filter by length before removing unsupported characters. The
            # original label might be too long.
            if len(label) > self.max_text_length:
                continue
            label = charset_adapter(label)
            if not label:
                continue

            # Prefer the stored 'w_h' record; otherwise read the size from
            # the encoded image itself.
            wh = txn.get(b'wh-%09d' % file_idx)
            if wh is None:
                img = txn.get(f'image-{file_idx:09d}'.encode())
                w, h = Image.open(io.BytesIO(img)).size
            else:
                w, h = wh.decode('utf-8').split('_')

            ratio = float(w) / float(h)
            wh_ratio.append(ratio)
            ratio_bin = min(int(ratio), max_ratio_bin)
            len_bin = min(len(label), max_len_bin)
            wh_ratio_len[ratio_bin][len_bin] += 1
            data_idx_order_list_filter.append([lmdb_idx, file_idx])

        self.logger.info(wh_ratio_len)
        return wh_ratio, data_idx_order_list_filter

    def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
        """Open each LMDB directory read-only and describe it.

        Returns:
            Dict mapping the source index to a dict with ``'dirpath'``,
            ``'env'``, ``'txn'``, ``'num_samples'`` and
            ``'ratio_num_samples'`` (``int(ratio * num_samples)``).
        """
        lmdb_sets = {}
        for dataset_idx, (dirpath, ratio) in enumerate(
                zip(data_dir_list, ratio_list)):
            env = lmdb.open(dirpath,
                            max_readers=32,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            txn = env.begin(write=False)
            num_samples = int(txn.get('num-samples'.encode()))
            lmdb_sets[dataset_idx] = {
                'dirpath': dirpath,
                'env': env,
                'txn': txn,
                'num_samples': num_samples,
                'ratio_num_samples': int(ratio * num_samples),
            }
        return lmdb_sets

    def dataset_traversal(self):
        """Draw the sample indices to use from every LMDB source.

        Returns:
            Float array of shape ``(total, 2)`` whose rows are
            ``(lmdb_idx, file_idx)``; ``file_idx`` values are drawn without
            replacement from ``1..num_samples`` of the respective source.
        """
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = sum(self.lmdb_sets[lno]['ratio_num_samples']
                               for lno in range(lmdb_num))
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            # Record indices start at 1, hence the shifted range.
            data_idx_order_list[beg_idx:end_idx, 1] = random.sample(
                range(1, self.lmdb_sets[lno]['num_samples'] + 1),
                tmp_sample_num)
            beg_idx = end_idx
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode an encoded image buffer with OpenCV.

        Returns ``None`` for an empty/missing buffer or when decoding fails.
        """
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        # cv2.imdecode itself yields None when the buffer cannot be decoded.
        return cv2.imdecode(imgdata, 1)

    def resize_norm_img(self, data, gen_ratio, padding=True):
        """Resize ``data['image']`` to the shape chosen for ``gen_ratio``.

        Args:
            data: Dict whose ``'image'`` entry exposes ``.size`` as
                ``(width, height)`` (presumably a PIL image -- confirm).
            gen_ratio: Integer ratio bucket; values up to 4 index
                ``self.base_shape``, larger ones give
                ``[base_h * gen_ratio, base_h]``.
            padding: When true, keep the aspect ratio (with a random width
                jitter) and pad up to the target width; otherwise stretch to
                the full target width.

        Returns:
            ``data`` updated with ``'image'``, ``'valid_ratio'``,
            ``'gen_ratio'`` and ``'real_ratio'``, or ``None`` (and
            ``self.error`` incremented) when the target ratio exceeds the
            image's integer ratio by two or more.
        """
        img = data['image']
        w, h = img.size
        if gen_ratio <= 4:
            imgW, imgH = self.base_shape[gen_ratio - 1]
        else:
            imgW, imgH = self.base_h * gen_ratio, self.base_h

        use_ratio = imgW // imgH
        if use_ratio >= (w // h) + 2:
            self.error += 1
            return None

        if not padding:
            resized_w = imgW
        else:
            ratio = w / float(h)
            if math.ceil(imgH * ratio) > imgW:
                resized_w = imgW
            else:
                # Jitter the width by a random factor in [0.5, 1.5).
                resized_w = int(
                    math.ceil(imgH * ratio * (random.random() + 0.5)))
                resized_w = min(imgW, resized_w)

        resized_image = F.resize(img, (imgH, resized_w),
                                 interpolation=self.interpolation)
        img = self.transforms(resized_image)
        if resized_w < imgW and padding:
            img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.)

        data['image'] = img
        data['valid_ratio'] = min(1.0, float(resized_w / imgW))
        data['gen_ratio'] = use_ratio
        data['real_ratio'] = max(1, round(float(w) / float(h)))
        return data

    def get_lmdb_sample_info(self, txn, index):
        """Fetch ``(image_bytes, label)`` for ``index`` from ``txn``.

        Returns ``None`` when no label record exists for ``index``.
        """
        label = txn.get(b'label-%09d' % index)
        if label is None:
            return None
        label = label.decode('utf-8')
        imgbuf = txn.get(b'image-%09d' % index)
        return imgbuf, label

    def _resample(self, img_width, img_height, ratio):
        """Load a randomly chosen other sample from the same ratio bucket."""
        ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
        ids = random.sample(ratio_ids, 1)
        return self.__getitem__([img_width, img_height, ids[0], ratio])

    def __getitem__(self, properties):
        """Load, transform and resize one sample.

        Args:
            properties: Sequence ``[img_width, img_height, idx, ratio]``;
                ``idx`` indexes ``self.data_idx_order_list`` and ``ratio``
                selects the target shape.

        Returns:
            The output of the final operator in ``self.ops``. If the sample
            cannot be loaded or any stage returns ``None``, another sample
            with the same ratio is drawn at random and returned instead.
        """
        img_width = properties[0]
        img_height = properties[1]
        idx = properties[2]
        ratio = properties[3]

        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[int(lmdb_idx)]['txn'], int(file_idx))
        if sample_info is None:
            return self._resample(img_width, img_height, ratio)

        img, label = sample_info
        data = {'image': img, 'label': label}
        # All operators but the last run before the resize, the last after.
        outs = transform(data, self.ops[:-1])
        if outs is not None:
            outs = self.resize_norm_img(outs, ratio, padding=False)
        if outs is None:
            return self._resample(img_width, img_height, ratio)

        outs = transform(outs, self.ops[-1:])
        if outs is None:
            return self._resample(img_width, img_height, ratio)
        return outs

    def __len__(self):
        """Return the number of samples that survived label filtering."""
        return self.data_idx_order_list.shape[0]