File size: 5,531 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Dict, List

import cv2
import lmdb
import mmengine
import numpy as np

from mmocr.registry import DATA_DUMPERS
from .base import BaseDumper


@DATA_DUMPERS.register_module()
class TextRecogLMDBDumper(BaseDumper):
    """Text recognition LMDB format dataset dumper.

    Each sample is stored under two keys, ``image-%09d`` (raw image file
    bytes) and ``label-%09d`` (encoded text), indexed from 1. The total number
    of stored samples is written under the ``num-samples`` key.

    Args:
        task (str): Task type. Options are 'textdet', 'textrecog',
            'textspotter', and 'kie'. It is usually set automatically and users
            do not need to set it manually in config file in most cases.
            Only 'textrecog' is accepted by this dumper.
        split (str): It's the partition of the datasets. Options are 'train',
            'val' or 'test'. It is usually set automatically and users do not
            need to set it manually in config file in most cases.
        data_root (str): The root directory of the image and
            annotation. It is usually set automatically and users do not need
            to set it manually in config file in most cases.
        batch_size (int): Number of files written to the cache each time.
            Defaults to 1000.
        encoding (str): Label encoding method. Defaults to 'utf-8'.
        lmdb_map_size (int): Maximum size database may grow to. Defaults to
            1099511627776.
        verify (bool): Whether to check the validity of every image. Defaults
            to True.
    """

    def __init__(self,
                 task: str,
                 split: str,
                 data_root: str,
                 batch_size: int = 1000,
                 encoding: str = 'utf-8',
                 lmdb_map_size: int = 1099511627776,
                 verify: bool = True) -> None:
        assert task == 'textrecog', \
            f'TextRecogLMDBDumper only works with textrecog, but got {task}'
        super().__init__(task=task, split=split, data_root=data_root)
        self.batch_size = batch_size
        self.encoding = encoding
        self.lmdb_map_size = lmdb_map_size
        self.verify = verify

    def check_image_is_valid(self, imageBin) -> bool:
        """Check whether raw image bytes can be decoded to a non-empty image.

        Args:
            imageBin (bytes | None): Raw content of an image file.

        Returns:
            bool: True if the buffer decodes to an image with non-zero height
            and width, False otherwise.
        """
        # An empty buffer cannot be decoded; cv2.imdecode would raise on it.
        if imageBin is None or len(imageBin) == 0:
            return False
        image_buf = np.frombuffer(imageBin, dtype=np.uint8)
        img = cv2.imdecode(image_buf, cv2.IMREAD_GRAYSCALE)
        # cv2.imdecode signals undecodable data by returning None rather than
        # raising, so it must be checked before touching ``img.shape``.
        if img is None:
            return False
        img_h, img_w = img.shape[0], img.shape[1]
        return img_h * img_w != 0

    def write_cache(self, env, cache) -> None:
        """Write a batch of key-value pairs to the LMDB environment.

        Args:
            env: An opened LMDB environment.
            cache (list[tuple[bytes, bytes]]): ``(key, value)`` pairs to be
                written in a single write transaction. Existing keys are
                overwritten.
        """
        with env.begin(write=True) as txn:
            cursor = txn.cursor()
            cursor.putmulti(cache, dupdata=False, overwrite=True)

    def parser_pack_instance(self, instance: Dict):
        """Parse a packed MMOCR format textrecog instance.

        Args:
            instance (Dict): A packed MMOCR format textrecog instance.
                For example,
                {
                    "instances": [
                        {
                            "text": "Hello"
                        }
                    ],
                    "img_path": "img1.jpg"
                }

        Returns:
            tuple: ``(img_path, text)`` taken from the instance.
        """
        assert isinstance(instance,
                          Dict), 'Element of data_list must be a dict'
        assert 'img_path' in instance and 'instances' in instance, \
            'Element of data_list must have the following keys: ' \
            f'img_path and instances, but got {instance.keys()}'
        assert isinstance(instance['instances'], List) and len(
            instance['instances']) == 1, \
            'instances must be a list containing exactly one element'
        assert 'text' in instance['instances'][0], \
            'The element of instances must have the key: text'

        img_path = instance['img_path']
        text = instance['instances'][0]['text']
        return img_path, text

    def dump(self, data: Dict) -> None:
        """Dump data to LMDB format.

        The database is created at
        ``{data_root}/{task}_{split}.lmdb``. Images that do not exist, or
        that fail validation when ``verify`` is True, are skipped with a
        warning and do not consume an index.

        Args:
            data (Dict): Packed annotations. Must contain the key
                ``data_list`` holding the textrecog instances.

        Raises:
            ValueError: If ``data`` has no ``data_list`` key.
        """
        # Validate the input first so that a bad call does not leave an empty
        # database directory and an open environment behind.
        if 'data_list' not in data:
            raise ValueError('Dump data must have data_list key')
        data_list = data['data_list']

        # create lmdb env
        output_dirname = f'{self.task}_{self.split}.lmdb'
        output = osp.join(self.data_root, output_dirname)
        mmengine.mkdir_or_exist(output)
        env = lmdb.open(output, map_size=self.lmdb_map_size)
        try:
            cache = []
            # index start from 1
            cnt = 1
            n_samples = len(data_list)
            for d in data_list:
                # convert both images and labels to lmdb
                img_name, text = self.parser_pack_instance(d)
                img_path = osp.join(self.data_root, img_name)
                if not osp.exists(img_path):
                    warnings.warn('%s does not exist' % img_path)
                    continue
                with open(img_path, 'rb') as f:
                    image_bin = f.read()
                if self.verify and not self.check_image_is_valid(image_bin):
                    warnings.warn('%s is not a valid image' % img_path)
                    continue
                # Format the index before encoding so that key construction
                # does not depend on bytes %-formatting of an encoded
                # template.
                image_key = f'image-{cnt:09d}'.encode(self.encoding)
                label_key = f'label-{cnt:09d}'.encode(self.encoding)
                cache.append((image_key, image_bin))
                cache.append((label_key, text.encode(self.encoding)))

                if cnt % self.batch_size == 0:
                    self.write_cache(env, cache)
                    cache = []
                    print('Written %d / %d' % (cnt, n_samples))
                cnt += 1
            # Skipped images do not advance ``cnt``, so this is the number of
            # samples actually stored.
            n_samples = cnt - 1
            cache.append(('num-samples'.encode(self.encoding),
                          str(n_samples).encode(self.encoding)))
            self.write_cache(env, cache)
        finally:
            # Always release the environment, even if dumping fails midway.
            env.close()
        print('Created lmdb dataset with %d samples' % n_samples)