"""
a modified version of the CRNN torch repository
https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py
"""
import io
import os

import cv2
import lmdb
import numpy as np
from PIL import Image
from tqdm import tqdm


def get_datalist(data_dir, data_path, max_len):
    """
    Build the list of (image path, label) pairs for training/validation.
    :param data_dir: dataset root directory
    :param data_path: annotation file (or list of files); each line is stored
        in the format 'path/to/img\tlabel'
    :param max_len: labels longer than this are skipped
    :return: list of [img_path, label]
    """
    train_data = []
    if isinstance(data_path, list):
        for p in data_path:
            train_data.extend(get_datalist(data_dir, p, max_len))
    else:
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc=f'load data from {data_path}'):
                # Some annotation files separate path and label with a space
                # instead of a tab; normalize before splitting.
                line = (line.strip('\n').replace('.jpg ', '.jpg\t').replace(
                    '.png ', '.png\t').split('\t'))
                if len(line) > 1:
                    img_path = os.path.join(data_dir, line[0].strip(' '))
                    label = line[1]
                    if len(label) > max_len:
                        continue
                    if os.path.exists(
                            img_path) and os.path.getsize(img_path) > 0:
                        train_data.append([str(img_path), label])
    return train_data


def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:  # bytes could not be decoded as an image
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True


def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            txn.put(k, v)


def createDataset(data_list, outputPath, checkValid=True):
    """
    Create an LMDB dataset for training and evaluation.
    ARGS:
        data_list  : list of (image path, label) pairs
        outputPath : LMDB output path
        checkValid : if True, check the validity of every image
    """
    os.makedirs(outputPath, exist_ok=True)
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1
    for imagePath, label in tqdm(
            data_list, desc=f'make dataset, save to {outputPath}'):
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            try:
                if not checkImageIsValid(imageBin):
                    print('%s is not a valid image' % imagePath)
                    continue
            except Exception:
                continue
        # Read the size only after the image has passed validation, so a
        # corrupt file cannot crash PIL here.
        w, h = Image.open(io.BytesIO(imageBin)).size

        imageKey = 'image-%09d'.encode() % cnt
        labelKey = 'label-%09d'.encode() % cnt
        whKey = 'wh-%09d'.encode() % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()
        cache[whKey] = (str(w) + '_' + str(h)).encode()

        # Flush to LMDB in batches of 1000 entries to keep memory bounded.
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'.encode()] = str(nSamples).encode()
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)


if __name__ == '__main__':
    data_dir = './Union14M-L/'
    label_file_list = [
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_challenging.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_easy.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_hard.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_medium.jsonl.txt',
        './Union14M-L/train_annos/filter_jsonl_mmocr0.x/filter_train_normal.jsonl.txt'
    ]
    save_path_root = './Union14M-L-LMDB-Filtered/'

    for data_list in label_file_list:
        save_path = save_path_root + data_list.split('/')[-1].split(
            '.')[0] + '/'
        os.makedirs(save_path, exist_ok=True)
        print(save_path)
        train_data_list = get_datalist(data_dir, data_list, 800)
        createDataset(train_data_list, save_path)
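

# Minimal reader sketch for LMDBs produced by createDataset. The function
# name readSample is illustrative (not part of the original repo); it relies
# only on the key layout written above: 'num-samples', 'image-%09d',
# 'label-%09d' and 'wh-%09d', with 1-based sample indices.
def readSample(lmdb_path, index):
    env = lmdb.open(lmdb_path, readonly=True, lock=False)
    with env.begin(write=False) as txn:
        nSamples = int(txn.get('num-samples'.encode()))
        assert 1 <= index <= nSamples, 'index is 1-based, up to num-samples'
        imageBin = txn.get('image-%09d'.encode() % index)
        label = txn.get('label-%09d'.encode() % index).decode()
        w, h = txn.get('wh-%09d'.encode() % index).decode().split('_')
    img = Image.open(io.BytesIO(imageBin))
    return img, label, (int(w), int(h))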