import io
import re
import unicodedata

import lmdb
from PIL import Image
from torch.utils.data import Dataset

from openrec.preprocess import create_operators, transform


class CharsetAdapter:
    """Transforms labels according to the target charset."""

    def __init__(self, target_charset) -> None:
        super().__init__()
        # If the charset has no uppercase letters, labels are lowercased
        # (and vice versa); lowercase wins when the charset has no cased letters.
        self.lowercase_only = target_charset == target_charset.lower()
        self.uppercase_only = target_charset == target_charset.upper()
        if target_charset:
            self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')
        else:
            # '[^]' is not a valid pattern; with an empty charset every
            # character is unsupported.
            self.unsupported = re.compile(r'[\s\S]')

    def __call__(self, label):
        """Return ``label`` case-folded and stripped of unsupported characters."""
        if self.lowercase_only:
            label = label.lower()
        elif self.uppercase_only:
            label = label.upper()
        # Remove unsupported characters
        label = self.unsupported.sub('', label)
        return label


class LMDBDataSetTest(Dataset):
    """Dataset interface to an LMDB database.

    It supports both labelled and unlabelled datasets. For unlabelled datasets,
    the image index itself is returned as the label. Unicode characters are
    normalized by default. Case-sensitivity is inferred from the charset.
    Labels are transformed according to the charset.

    Args:
        config: Full config dict; ``config[mode]['dataset']`` and
            ``config['Global']`` are read.
        mode: Config section holding the dataset settings.
        logger, seed, epoch, gpu_i: Accepted for interface compatibility;
            not used by this class.
        max_label_len: Fallback maximum label length, used only when
            ``Global.max_text_length`` is absent from the config.
        min_image_dim: When > 0, samples whose image width or height is
            smaller are dropped, as are samples the transform pipeline
            rejects (returns None for).
        remove_whitespace: Remove all whitespace from labels.
        normalize_unicode: NFKD-normalize labels to ASCII. Only applied when
            the dataset option ``filter_label`` is enabled (the default).
        unlabelled: Skip label loading and return the index as the label.
        transform: Stored as ``self.transform``; this class itself applies
            the operators built from ``dataset_config['transforms']``.
    """

    def __init__(self,
                 config,
                 mode,
                 logger,
                 seed=None,
                 epoch=1,
                 gpu_i=0,
                 max_label_len: int = 25,
                 min_image_dim: int = 0,
                 remove_whitespace: bool = True,
                 normalize_unicode: bool = True,
                 unlabelled: bool = False,
                 transform=None):
        # Set first so __del__ is safe even if a config lookup below raises.
        self._env = None
        dataset_config = config[mode]['dataset']
        global_config = config['Global']
        # The config value takes precedence; the argument is only a fallback.
        max_label_len = global_config.get('max_text_length', max_label_len)
        self.root = dataset_config['data_dir']
        self.unlabelled = unlabelled
        # NOTE: the `transform` parameter shadows the module-level `transform`
        # function inside __init__ only; other methods use the module function.
        self.transform = transform
        self.labels = []
        self.filtered_index_list = []
        self.min_image_dim = min_image_dim
        self.filter_label = dataset_config.get('filter_label', True)
        char_test = self._load_charset(global_config)
        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)
        self.num_samples = self._preprocess_labels(char_test,
                                                   remove_whitespace,
                                                   normalize_unicode,
                                                   max_label_len,
                                                   min_image_dim)

    @staticmethod
    def _load_charset(global_config):
        """Return the recognition charset as one string.

        Concatenates the lines of ``character_dict_path`` (UTF-8) or, when no
        path is configured, uses lowercase alphanumerics. A space is appended
        when ``use_space_char`` is enabled.
        """
        character_dict_path = global_config.get('character_dict_path', None)
        if character_dict_path is None:
            charset = '0123456789abcdefghijklmnopqrstuvwxyz'
        else:
            with open(character_dict_path, 'rb') as fin:
                charset = ''.join(
                    line.decode('utf-8').strip('\n').strip('\r\n')
                    for line in fin)
        if global_config.get('use_space_char', False):
            charset += ' '
        return charset

    def __del__(self):
        # getattr: the instance may be partially constructed.
        env = getattr(self, '_env', None)
        if env is not None:
            env.close()
            self._env = None

    def _create_env(self):
        return lmdb.open(self.root,
                         max_readers=1,
                         readonly=True,
                         create=False,
                         readahead=False,
                         meminit=False,
                         lock=False)

    @property
    def env(self):
        """Lazily opened LMDB environment used by ``__getitem__``."""
        if self._env is None:
            self._env = self._create_env()
        return self._env

    def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode,
                           max_label_len, min_image_dim):
        """Scan the database once and keep the samples that pass every filter.

        Fills ``self.labels`` and ``self.filtered_index_list`` (1-based LMDB
        key indices). Returns the number of retained samples, or the raw
        ``num-samples`` value for unlabelled datasets.
        """
        charset_adapter = CharsetAdapter(charset)
        with self._create_env() as env, env.begin() as txn:
            num_samples = int(txn.get('num-samples'.encode()))
            if self.unlabelled:
                return num_samples
            for index in range(1, num_samples + 1):  # lmdb starts with 1
                label_key = f'label-{index:09d}'.encode()
                label = txn.get(label_key).decode()
                # Normally, whitespace is removed from the labels.
                if remove_whitespace:
                    label = ''.join(label.split())
                # Normalize unicode composites (if any) and convert to
                # compatible ASCII characters.
                if self.filter_label and normalize_unicode:
                    label = unicodedata.normalize('NFKD', label).encode(
                        'ascii', 'ignore').decode()
                # Filter by length before removing unsupported characters.
                # The original label might be too long.
                if len(label) > max_label_len:
                    continue
                if self.filter_label:
                    label = charset_adapter(label)
                # We filter out samples which don't contain any supported
                # characters.
                if not label:
                    continue
                # Filter images that are too small (or rejected by the ops).
                if min_image_dim > 0 and not self._image_is_usable(
                        txn, index, label):
                    continue
                self.labels.append(label)
                self.filtered_index_list.append(index)
        return len(self.labels)

    def _image_is_usable(self, txn, index, label):
        """Return True if sample ``index`` survives the transform pipeline and
        its image is at least ``self.min_image_dim`` on both sides."""
        img = txn.get(f'image-{index:09d}'.encode())
        if transform({'image': img, 'label': label}, self.ops) is None:
            return False
        with Image.open(io.BytesIO(img)) as im:
            w, h = im.size
        return w >= self.min_image_dim and h >= self.min_image_dim

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        """Load sample ``index`` (0-based) and run it through ``self.ops``."""
        if self.unlabelled:
            label = index
            key_index = index + 1  # lmdb keys start at 1
        else:
            label = self.labels[index]
            key_index = self.filtered_index_list[index]
        img_key = f'image-{key_index:09d}'.encode()
        with self.env.begin() as txn:
            img = txn.get(img_key)
        data = {'image': img, 'label': label}
        return transform(data, self.ops)