Mountchicken's picture
Upload 704 files
9bf4bd7
raw
history blame
5.53 kB
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Dict, List
import cv2
import lmdb
import mmengine
import numpy as np
from mmocr.registry import DATA_DUMPERS
from .base import BaseDumper
@DATA_DUMPERS.register_module()
class TextRecogLMDBDumper(BaseDumper):
"""Text recognition LMDB format dataset dumper.
Args:
task (str): Task type. Options are 'textdet', 'textrecog',
'textspotter', and 'kie'. It is usually set automatically and users
do not need to set it manually in config file in most cases.
split (str): It' s the partition of the datasets. Options are 'train',
'val' or 'test'. It is usually set automatically and users do not
need to set it manually in config file in most cases. Defaults to
None.
data_root (str): The root directory of the image and
annotation. It is usually set automatically and users do not need
to set it manually in config file in most cases. Defaults to None.
batch_size (int): Number of files written to the cache each time.
Defaults to 1000.
encoding (str): Label encoding method. Defaults to 'utf-8'.
lmdb_map_size (int): Maximum size database may grow to. Defaults to
1099511627776.
verify (bool): Whether to check the validity of every image. Defaults
to True.
"""
def __init__(self,
task: str,
split: str,
data_root: str,
batch_size: int = 1000,
encoding: str = 'utf-8',
lmdb_map_size: int = 1099511627776,
verify: bool = True) -> None:
assert task == 'textrecog', \
f'TextRecogLMDBDumper only works with textrecog, but got {task}'
super().__init__(task=task, split=split, data_root=data_root)
self.batch_size = batch_size
self.encoding = encoding
self.lmdb_map_size = lmdb_map_size
self.verify = verify
def check_image_is_valid(self, imageBin):
if imageBin is None:
return False
imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
imgH, imgW = img.shape[0], img.shape[1]
if imgH * imgW == 0:
return False
return True
def write_cache(self, env, cache):
with env.begin(write=True) as txn:
cursor = txn.cursor()
cursor.putmulti(cache, dupdata=False, overwrite=True)
def parser_pack_instance(self, instance: Dict):
"""parser an packed MMOCR format textrecog instance.
Args:
instance (Dict): An packed MMOCR format textrecog instance.
For example,
{
"instance": [
{
"text": "Hello"
}
],
"img_path": "img1.jpg"
}
"""
assert isinstance(instance,
Dict), 'Element of data_list must be a dict'
assert 'img_path' in instance and 'instances' in instance, \
'Element of data_list must have the following keys: ' \
f'img_path and instances, but got {instance.keys()}'
assert isinstance(instance['instances'], List) and len(
instance['instances']) == 1
assert 'text' in instance['instances'][0]
img_path = instance['img_path']
text = instance['instances'][0]['text']
return img_path, text
def dump(self, data: Dict) -> None:
"""Dump data to LMDB format."""
# create lmdb env
output_dirname = f'{self.task}_{self.split}.lmdb'
output = osp.join(self.data_root, output_dirname)
mmengine.mkdir_or_exist(output)
env = lmdb.open(output, map_size=self.lmdb_map_size)
# load data
if 'data_list' not in data:
raise ValueError('Dump data must have data_list key')
data_list = data['data_list']
cache = []
# index start from 1
cnt = 1
n_samples = len(data_list)
for d in data_list:
# convert both images and labels to lmdb
label_key = 'label-%09d'.encode(self.encoding) % cnt
img_name, text = self.parser_pack_instance(d)
img_path = osp.join(self.data_root, img_name)
if not osp.exists(img_path):
warnings.warn('%s does not exist' % img_path)
continue
with open(img_path, 'rb') as f:
image_bin = f.read()
if self.verify:
if not self.check_image_is_valid(image_bin):
warnings.warn('%s is not a valid image' % img_path)
continue
image_key = 'image-%09d'.encode(self.encoding) % cnt
cache.append((image_key, image_bin))
cache.append((label_key, text.encode(self.encoding)))
if cnt % self.batch_size == 0:
self.write_cache(env, cache)
cache = []
print('Written %d / %d' % (cnt, n_samples))
cnt += 1
n_samples = cnt - 1
cache.append(('num-samples'.encode(self.encoding),
str(n_samples).encode(self.encoding)))
self.write_cache(env, cache)
print('Created lmdb dataset with %d samples' % n_samples)