Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
import argparse | |
import math | |
import os.path as osp | |
from functools import partial | |
import mmcv | |
import mmengine | |
from mmocr.utils import dump_ocr_data | |
def parse_args():
    """Read the converter's command-line options.

    Returns:
        argparse.Namespace: with attributes ``root_path`` (str),
        ``val_ratio`` (float, default 0.0), ``nproc`` (int, default 1) and
        ``preserve_vertical`` (bool flag).
    """
    arg_parser = argparse.ArgumentParser(
        description='Generate training and validation set of LSVT ')
    arg_parser.add_argument('root_path', help='Root dir path of LSVT')
    arg_parser.add_argument(
        '--val-ratio', type=float, default=0.0, help='Split ratio for val set')
    arg_parser.add_argument(
        '--nproc', type=int, default=1, help='Number of processes')
    arg_parser.add_argument(
        '--preserve-vertical',
        action='store_true',
        help='Preserve samples containing vertical texts')
    return arg_parser.parse_args()
def process_img(args, dst_image_root, ignore_image_root, preserve_vertical,
                split):
    """Crop every annotated text region of one image and save the crops.

    Args:
        args (tuple): ``(img_idx, img_info, anns)`` packed into one argument
            so the function can be mapped by
            ``mmengine.track_parallel_progress``. ``img_idx`` is used in the
            crop file names, ``img_info['file_name']`` is the path of the
            source image and ``anns`` is a list of dicts with ``'points'``
            (list of ``[x, y]`` pairs) and ``'transcription'`` keys.
        dst_image_root (str): Directory receiving the kept crops.
        ignore_image_root (str): Directory receiving vertical crops that are
            filtered out of the training labels.
        preserve_vertical (bool): If True, tall crops (height more than twice
            the width) are kept in the training split as well.
        split (str): ``'train'`` or ``'val'``. Vertical filtering is only
            applied to ``'train'``.

    Returns:
        list[dict]: One ``{'file_name': ..., 'anno_info': [{'text': ...}]}``
        entry per kept crop. Regions with zero width or height, or whose
        crop is empty, are skipped.
    """
    # Dirty hack for multi-processing: the worker receives a single tuple.
    img_idx, img_info, anns = args

    # First collect the bounding boxes of all usable regions.
    boxes = []
    for ann_idx, ann in enumerate(anns):
        # Negative coordinates are clamped to 0; int() makes the values
        # usable as slice bounds even if the annotation stores floats.
        xs = [max(0, int(px)) for px, _ in ann['points']]
        ys = [max(0, int(py)) for _, py in ann['points']]
        x, y = min(xs), min(ys)
        w, h = max(xs) - x, max(ys) - y
        if w <= 0 or h <= 0:
            # Degenerate polygon: the crop would be empty (imwrite fails)
            # and the aspect-ratio test would divide by zero.
            continue
        boxes.append((ann_idx, x, y, w, h, ann['transcription']))

    if not boxes:
        # Nothing to crop, so don't bother decoding the source image.
        return []

    src_img = mmcv.imread(img_info['file_name'])
    labels = []
    for ann_idx, x, y, w, h, text_label in boxes:
        dst_img = src_img[y:y + h, x:x + w]
        if dst_img.size == 0:
            # Region lies completely outside the decoded image.
            continue
        dst_img_name = f'img_{img_idx}_{ann_idx}.jpg'
        # Tall crops (h / w > 2) in the training split are written aside
        # and left out of the labels unless explicitly preserved.
        is_ignored = (not preserve_vertical and h > 2 * w
                      and split == 'train')
        out_root = ignore_image_root if is_ignored else dst_image_root
        mmcv.imwrite(dst_img, osp.join(out_root, dst_img_name))
        if is_ignored:
            continue
        labels.append({
            'file_name': dst_img_name,
            'anno_info': [{
                'text': text_label
            }]
        })
    return labels
def _split_prefixes(img_prefixes, ratio):
    """Split image prefixes into ``(train, val)`` lists.

    With a positive ``ratio`` (at most 1), every ``floor(1 / ratio)``-th
    prefix, starting with the first one, goes to the validation list and the
    rest go to training. Otherwise everything goes to training.
    """
    trn_files, val_files = [], []
    # Written as ``not ratio > 0`` so that NaN also means "no val split".
    if not ratio > 0:
        return list(img_prefixes), val_files
    step = math.floor(1 / ratio)
    for i, prefix in enumerate(img_prefixes):
        if i % step:
            trn_files.append(prefix)
        else:
            val_files.append(prefix)
    return trn_files, val_files


def convert_lsvt(root_path,
                 split,
                 ratio,
                 preserve_vertical,
                 nproc,
                 img_start_idx=0):
    """Collect the annotation information and crop the images.

    The annotation file ``annotations/train_full_labels.json`` maps each
    image name (without extension) to its list of text instances:

        {
            'gt_1234': [  # 'gt_1234' is the file name
                {
                    'transcription': '一站式购物中心',
                    'points': [[45, 272], [215, 273], [212, 296], [45, 290]],
                    'illegibility': False
                }, ...
            ],
            ...
        }

    Args:
        root_path (str): The root path of the dataset.
        split (str): The split of dataset, either ``'train'`` or ``'val'``.
        ratio (float): Split ratio for the val set. Must not exceed 1;
            values <= 0 put every image in the training split.
        preserve_vertical (bool): Whether to preserve vertical texts.
        nproc (int): The number of processes used to crop images.
        img_start_idx (int): Index given to the first processed image; used
            to keep crop names of different splits apart.

    Returns:
        int: The number of source images found on disk and processed. Those
        images received indices ``img_start_idx`` to
        ``img_start_idx + count - 1``.

    Raises:
        NotImplementedError: If ``split`` is not ``'train'`` or ``'val'``.
        ValueError: If ``ratio`` is greater than 1.
        FileNotFoundError: If the annotation file does not exist.
    """
    # Validate arguments before touching the file system.
    if split not in ('train', 'val'):
        raise NotImplementedError(f'Unsupported split: {split}')
    if ratio > 1:
        # floor(1 / ratio) would be 0 and the modulo split would raise
        # ZeroDivisionError.
        raise ValueError(f'val ratio must not exceed 1, got {ratio}')

    annotation_path = osp.join(root_path, 'annotations/train_full_labels.json')
    if not osp.exists(annotation_path):
        raise FileNotFoundError(
            f'{annotation_path} not exists, please check and try again.')
    annotation = mmengine.load(annotation_path)

    # outputs
    dst_label_file = osp.join(root_path, f'{split}_label.json')
    dst_image_root = osp.join(root_path, 'crops', split)
    ignore_image_root = osp.join(root_path, 'ignores', split)
    src_image_root = osp.join(root_path, 'imgs')
    mmengine.mkdir_or_exist(dst_image_root)
    mmengine.mkdir_or_exist(ignore_image_root)

    trn_files, val_files = _split_prefixes(annotation.keys(), ratio)
    print(f'training #{len(trn_files)}, val #{len(val_files)}')
    img_prefixes = trn_files if split == 'train' else val_files

    tasks = []
    for prefix in img_prefixes:
        img_file = osp.join(src_image_root, prefix + '.jpg')
        # Skip not exist images
        if not osp.exists(img_file):
            continue
        # Number kept images contiguously so that the returned count plus
        # img_start_idx is the first index still free for the next split.
        tasks.append((img_start_idx + len(tasks), {
            'file_name': img_file
        }, annotation[prefix]))

    process_img_with_path = partial(
        process_img,
        dst_image_root=dst_image_root,
        ignore_image_root=ignore_image_root,
        preserve_vertical=preserve_vertical,
        split=split)
    labels_list = mmengine.track_parallel_progress(
        process_img_with_path, tasks, keep_order=True, nproc=nproc)
    final_labels = [label for labels in labels_list for label in labels]
    dump_ocr_data(final_labels, dst_label_file, 'textrecog')
    return len(tasks)
def main():
    """Convert the LSVT training split, then the validation split if needed.

    The validation crops are numbered after the training crops by passing
    the number of processed training images as ``img_start_idx``.
    """
    opts = parse_args()
    shared_kwargs = dict(
        root_path=opts.root_path,
        ratio=opts.val_ratio,
        preserve_vertical=opts.preserve_vertical,
        nproc=opts.nproc)

    print('Processing training set...')
    train_count = convert_lsvt(split='train', **shared_kwargs)

    # A validation split only exists when a positive ratio was requested.
    if opts.val_ratio > 0:
        print('Processing validation set...')
        convert_lsvt(split='val', img_start_idx=train_count, **shared_kwargs)
    print('Finish')
if __name__ == '__main__':
    # Run the conversion only when this file is executed as a script,
    # not when it is imported as a module.
    main()