import io
import math
import random
import re
import unicodedata

import cv2
import lmdb
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T
from torchvision.transforms import functional as F

from openrec.preprocess import create_operators, transform


class CharsetAdapter:
    """Transforms labels according to the target charset."""

    def __init__(self, target_charset) -> None:
        super().__init__()
        self.lowercase_only = target_charset == target_charset.lower()
        self.uppercase_only = target_charset == target_charset.upper()
        self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')

    def __call__(self, label):
        if self.lowercase_only:
            label = label.lower()
        elif self.uppercase_only:
            label = label.upper()
        # Remove unsupported characters
        label = self.unsupported.sub('', label)
        return label


class RatioDataSetTVResizeTest(Dataset):

    def __init__(self, config, mode, logger, seed=None, epoch=1):
        super(RatioDataSetTVResizeTest, self).__init__()
        self.ds_width = config[mode]['dataset'].get('ds_width', True)
        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        max_ratio = loader_config.get('max_ratio', 10)
        min_ratio = loader_config.get('min_ratio', 1)
        data_dir_list = dataset_config['data_dir_list']
        self.do_shuffle = loader_config['shuffle']
        self.seed = epoch
        self.max_text_length = global_config['max_text_length']
        data_source_num = len(data_dir_list)
        ratio_list = dataset_config.get('ratio_list', 1.0)
        if isinstance(ratio_list, (float, int)):
            ratio_list = [float(ratio_list)] * int(data_source_num)

        assert len(
            ratio_list
        ) == data_source_num, 'The length of ratio_list should be the same as the file_list.'

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(
            data_dir_list, ratio_list)
        for data_dir in data_dir_list:
            logger.info('Initialize indices of dataset: %s' % data_dir)
        self.logger = logger
        data_idx_order_list = self.dataset_traversal()

        # Build the evaluation charset: default to lowercase alphanumerics,
        # otherwise read it from the character dict file.
        character_dict_path = global_config.get('character_dict_path', None)
        use_space_char = global_config.get('use_space_char', False)
        if character_dict_path is None:
            char_test = '0123456789abcdefghijklmnopqrstuvwxyz'
        else:
            char_test = ''
            with open(character_dict_path, 'rb') as fin:
                lines = fin.readlines()
                for line in lines:
                    line = line.decode('utf-8').strip('\n').strip('\r\n')
                    char_test += line
            if use_space_char:
                char_test += ' '

        wh_ratio, data_idx_order_list = self.get_wh_ratio(
            data_idx_order_list, char_test)
        self.data_idx_order_list = np.array(data_idx_order_list)
        # Round each sample's aspect ratio and clip it into [min_ratio, max_ratio].
        wh_ratio = np.around(np.array(wh_ratio))
        self.wh_ratio = np.clip(wh_ratio, a_min=min_ratio, a_max=max_ratio)
        # Log how many samples fall into each ratio bucket.
        for i in range(max_ratio + 1):
            logger.info((1 * (self.wh_ratio == i)).sum())
        self.wh_ratio_sort = np.argsort(self.wh_ratio)
        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)

        self.need_reset = any(x < 1 for x in ratio_list)
        self.error = 0
        # Preset target shapes for ratio buckets 1-4; wider buckets use
        # (base_h * ratio, base_h) instead.
        self.base_shape = dataset_config.get(
            'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
        self.base_h = dataset_config.get('base_h', 32)
        self.interpolation = T.InterpolationMode.BICUBIC
        transforms = []
        transforms.extend([
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])
        self.transforms = T.Compose(transforms)

    def get_wh_ratio(self, data_idx_order_list, char_test):
        """Collect width/height ratios and drop samples whose label is empty
        or too long after charset filtering."""
        wh_ratio = []
        # Histogram over (aspect-ratio bucket, label length), used only for logging.
        wh_ratio_len = [[0 for _ in range(26)] for _ in range(11)]
        data_idx_order_list_filter = []
        charset_adapter = CharsetAdapter(char_test)
        for idx in range(data_idx_order_list.shape[0]):
            lmdb_idx, file_idx = data_idx_order_list[idx]
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            # Prefer the cached 'wh-xxxxxxxxx' entry; fall back to decoding
            # the image to obtain its size.
            wh_key = 'wh-%09d'.encode() % file_idx
            wh = self.lmdb_sets[lmdb_idx]['txn'].get(wh_key)
            if wh is None:
                img_key = f'image-{file_idx:09d}'.encode()
                img = self.lmdb_sets[lmdb_idx]['txn'].get(img_key)
                buf = io.BytesIO(img)
                w, h = Image.open(buf).size
            else:
                wh = wh.decode('utf-8')
                w, h = wh.split('_')
            label_key = 'label-%09d'.encode() % file_idx
            label = self.lmdb_sets[lmdb_idx]['txn'].get(label_key)
            if label is not None:
                label = label.decode('utf-8')
                # Remove all whitespace from the label.
                label = ''.join(label.split())
                # Normalize unicode composites (if any) and convert to
                # compatible ASCII characters.
                label = unicodedata.normalize('NFKD', label).encode(
                    'ascii', 'ignore').decode()
                # Filter by length before removing unsupported characters;
                # the original label might be too long.
                if len(label) > self.max_text_length:
                    continue
                label = charset_adapter(label)
                if not label:
                    continue
                wh_ratio.append(float(w) / float(h))
                ratio_bin = min(int(float(w) / float(h)), 10)
                len_bin = min(len(label), 25)
                wh_ratio_len[ratio_bin][len_bin] += 1
                data_idx_order_list_filter.append([lmdb_idx, file_idx])
        self.logger.info(wh_ratio_len)
        return wh_ratio, data_idx_order_list_filter

    def load_hierarchical_lmdb_dataset(self, data_dir_list, ratio_list):
        lmdb_sets = {}
        dataset_idx = 0
        for dirpath, ratio in zip(data_dir_list, ratio_list):
            env = lmdb.open(dirpath,
                            max_readers=32,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            txn = env.begin(write=False)
            num_samples = int(txn.get('num-samples'.encode()))
            lmdb_sets[dataset_idx] = {
                'dirpath': dirpath,
                'env': env,
                'txn': txn,
                'num_samples': num_samples,
                'ratio_num_samples': int(ratio * num_samples),
            }
            dataset_idx += 1
        return lmdb_sets

    def dataset_traversal(self):
        # Build an (N, 2) array of (lmdb_idx, file_idx) pairs, sampling
        # ratio_num_samples indices without replacement from each LMDB set.
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = 0
        for lno in range(lmdb_num):
            total_sample_num += self.lmdb_sets[lno]['ratio_num_samples']
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]['ratio_num_samples']
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            data_idx_order_list[beg_idx:end_idx, 1] = list(
                random.sample(range(1, self.lmdb_sets[lno]['num_samples'] + 1),
                              self.lmdb_sets[lno]['ratio_num_samples']))
            beg_idx = beg_idx + tmp_sample_num
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode raw image bytes into a BGR numpy array."""
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        if imgdata is None:
            return None
        imgori = cv2.imdecode(imgdata, 1)
        if imgori is None:
            return None
        return imgori

    def resize_norm_img(self, data, gen_ratio, padding=True):
        img = data['image']
        w, h = img.size
        # Ratio buckets 1-4 use the preset base shapes; wider buckets stretch
        # the width to base_h * ratio at a fixed height of base_h.
        imgW, imgH = self.base_shape[gen_ratio - 1] if gen_ratio <= 4 else [
            self.base_h * gen_ratio, self.base_h
        ]
        use_ratio = imgW // imgH
        if use_ratio >= (w // h) + 2:
            # Target ratio is far wider than the actual image; skip the sample.
            self.error += 1
            return None
        if not padding:
            resized_w = imgW
        else:
            ratio = w / float(h)
            if math.ceil(imgH * ratio) > imgW:
                resized_w = imgW
            else:
                resized_w = int(
                    math.ceil(imgH * ratio * (random.random() + 0.5)))
                resized_w = min(imgW, resized_w)
        resized_image = F.resize(img, (imgH, resized_w),
                                 interpolation=self.interpolation)
        img = self.transforms(resized_image)
        if resized_w < imgW and padding:
            # Pad on the right so every image in the bucket has width imgW.
            img = F.pad(img, [0, 0, imgW - resized_w, 0], fill=0.)
        valid_ratio = min(1.0, float(resized_w / imgW))
        data['image'] = img
        data['valid_ratio'] = valid_ratio
        data['gen_ratio'] = imgW // imgH
        r = float(w) / float(h)
        data['real_ratio'] = max(1, round(r))
        return data

    def get_lmdb_sample_info(self, txn, index):
        label_key = 'label-%09d'.encode() % index
        label = txn.get(label_key)
        if label is None:
            return None
        label = label.decode('utf-8')
        img_key = 'image-%09d'.encode() % index
        imgbuf = txn.get(img_key)
        return imgbuf, label

    def __getitem__(self, properties):
        img_width = properties[0]
        img_height = properties[1]
        idx = properties[2]
        ratio = properties[3]
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[lmdb_idx]['txn'], file_idx)
        if sample_info is None:
            ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
            ids = random.sample(ratio_ids, 1)
            return self.__getitem__([img_width, img_height, ids[0], ratio])
        img, label = sample_info
        data = {'image': img, 'label': label}
        outs = transform(data, self.ops[:-1])
        if outs is not None:
            outs = self.resize_norm_img(outs, ratio, padding=False)
            if outs is None:
                ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
                ids = random.sample(ratio_ids, 1)
                return self.__getitem__([img_width, img_height, ids[0], ratio])
            outs = transform(outs, self.ops[-1:])
        if outs is None:
            ratio_ids = np.where(self.wh_ratio == ratio)[0].tolist()
            ids = random.sample(ratio_ids, 1)
            return self.__getitem__([img_width, img_height, ids[0], ratio])
        return outs

    def __len__(self):
        return self.data_idx_order_list.shape[0]
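

if __name__ == '__main__':
    # Minimal sanity-check sketch of the target-shape rule used by
    # resize_norm_img: ratio buckets 1-4 map to the base_shape presets, wider
    # buckets fall back to (base_h * ratio, base_h). The values below mirror
    # the dataset defaults and are illustrative only; they are not read from
    # any config. In the real loader, __getitem__ expects a 4-element index
    # (target_w, target_h, sample_idx, ratio_bucket), presumably supplied by a
    # ratio-aware batch sampler.
    base_shape = [[64, 64], [96, 48], [112, 40], [128, 32]]
    base_h = 32
    for gen_ratio in range(1, 9):
        imgW, imgH = base_shape[gen_ratio - 1] if gen_ratio <= 4 else [
            base_h * gen_ratio, base_h
        ]
        print(f'ratio bucket {gen_ratio}: target (W, H) = ({imgW}, {imgH})')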