Mountchicken's picture
Upload 704 files
9bf4bd7
raw
history blame
5.74 kB
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import math
import os.path as osp
from functools import partial
import mmcv
import mmengine
from mmocr.utils import dump_ocr_data
def parse_args():
    """Define the converter's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: with attributes ``root_path`` (str),
        ``val_ratio`` (float, default 0.0), ``nproc`` (int, default 1) and
        ``preserve_vertical`` (bool flag, default False).
    """
    cli = argparse.ArgumentParser(
        description='Generate training and validation set of LSVT ')
    cli.add_argument('root_path', help='Root dir path of LSVT')
    cli.add_argument(
        '--val-ratio', type=float, default=0.0, help='Split ratio for val set')
    cli.add_argument(
        '--nproc', type=int, default=1, help='Number of processes')
    cli.add_argument(
        '--preserve-vertical',
        action='store_true',
        help='Preserve samples containing vertical texts')
    return cli.parse_args()
def process_img(args, dst_image_root, ignore_image_root, preserve_vertical,
                split):
    """Crop every annotated text instance out of one source image.

    Args:
        args (tuple): ``(img_idx, img_info, anns)`` packed into a single
            argument so the function can be mapped in parallel.
            ``img_idx`` is used to name the crops, ``img_info['file_name']``
            is the path of the source image and ``anns`` is the list of
            annotation dicts for that image (each with ``'points'`` as a list
            of ``[x, y]`` pairs and ``'transcription'``).
        dst_image_root (str): Directory the kept crops are written to.
        ignore_image_root (str): Directory vertical crops are written to when
            they are excluded from the labels.
        preserve_vertical (bool): If True, vertical crops are kept in the
            labels instead of being set aside.
        split (str): Dataset split; vertical crops are only set aside for
            ``'train'``.

    Returns:
        list[dict]: One ``{'file_name': ..., 'anno_info': [{'text': ...}]}``
        entry per crop kept in the labels. Instances whose crop would be empty
        (zero width/height, or a box outside the image) are skipped.
    """
    # Dirty hack for multi-processing: the worker receives one positional
    # argument, so the per-image data arrives packed in a tuple.
    img_idx, src_info, anns = args
    src_img = mmcv.imread(src_info['file_name'])
    if src_img is None:
        # NOTE(review): mmcv.imread may return None when decoding fails;
        # with no pixels there is nothing to crop.
        return []

    labels = []
    for ann_idx, ann in enumerate(anns):
        # Clamp coordinates left of / above the image to 0; int() keeps them
        # usable as slice indices even if the annotation stores floats.
        xs = [max(0, int(px)) for px, _ in ann['points']]
        ys = [max(0, int(py)) for _, py in ann['points']]
        x, y = min(xs), min(ys)
        w, h = max(xs) - x, max(ys) - y
        text_label = ann['transcription']
        dst_img = src_img[y:y + h, x:x + w]
        # A zero-width/height box would make ``h / w`` below divide by zero,
        # and an empty crop cannot be written as an image, so skip both.
        if w == 0 or h == 0 or dst_img.size == 0:
            continue
        dst_img_name = f'img_{img_idx}_{ann_idx}.jpg'
        # Crops more than twice as tall as wide are treated as vertical text:
        # in the train split they are saved aside and left out of the labels.
        if not preserve_vertical and h / w > 2 and split == 'train':
            mmcv.imwrite(dst_img, osp.join(ignore_image_root, dst_img_name))
            continue
        mmcv.imwrite(dst_img, osp.join(dst_image_root, dst_img_name))
        labels.append({
            'file_name': dst_img_name,
            'anno_info': [{
                'text': text_label
            }]
        })
    return labels
def convert_lsvt(root_path,
                 split,
                 ratio,
                 preserve_vertical,
                 nproc,
                 img_start_idx=0):
    """Collect the annotation information and crop the images.

    The annotation file (``annotations/train_full_labels.json``) maps each
    image file prefix to its list of text instances:
    {
        'gt_1234':  # 'gt_1234' is the image file name without '.jpg'
        [
            {
                'transcription': '一站式购物中心',
                'points': [[45, 272], [215, 273], [212, 296], [45, 290]],
                'illegibility': False
            }, ...
        ],
        ...
    }

    Args:
        root_path (str): The root path of the dataset.
        split (str): The split of dataset, either ``'train'`` or ``'val'``.
        ratio (float): Split ratio for the val set. When > 0, every
            ``floor(1 / ratio)``-th image (starting with the first) goes to
            val and the rest to train; when <= 0 everything is train.
        preserve_vertical (bool): Whether to preserve vertical texts.
        nproc (int): The number of processes used to crop the images.
        img_start_idx (int): Index assigned to the first processed image;
            images are numbered contiguously from here.

    Returns:
        int: The number of images of this split that were found on disk and
        processed, so ``img_start_idx + return value`` is a safe start index
        for the next split.

    Raises:
        NotImplementedError: If ``split`` is not ``'train'`` or ``'val'``.
        ValueError: If ``ratio`` is greater than 1 (the split step
            ``floor(1 / ratio)`` would be 0).
        FileNotFoundError: If the annotation file does not exist.
    """
    # Validate the arguments before loading anything or creating directories.
    if split not in ('train', 'val'):
        raise NotImplementedError(
            f"Unsupported split '{split}', expected 'train' or 'val'.")
    if ratio > 1:
        raise ValueError(f'val ratio must not exceed 1, got {ratio}.')

    annotation_path = osp.join(root_path, 'annotations/train_full_labels.json')
    if not osp.exists(annotation_path):
        raise FileNotFoundError(
            f'{annotation_path} not exists, please check and try again.')
    annotation = mmengine.load(annotation_path)

    # outputs
    dst_label_file = osp.join(root_path, f'{split}_label.json')
    dst_image_root = osp.join(root_path, 'crops', split)
    ignore_image_root = osp.join(root_path, 'ignores', split)
    src_image_root = osp.join(root_path, 'imgs')
    mmengine.mkdir_or_exist(dst_image_root)
    mmengine.mkdir_or_exist(ignore_image_root)

    process_img_with_path = partial(
        process_img,
        dst_image_root=dst_image_root,
        ignore_image_root=ignore_image_root,
        preserve_vertical=preserve_vertical,
        split=split)

    prefixes = list(annotation)
    if ratio > 0:
        # ratio is in (0, 1] here, so step >= 1.
        step = math.floor(1 / ratio)
        trn_files = [p for i, p in enumerate(prefixes) if i % step]
        val_files = [p for i, p in enumerate(prefixes) if not i % step]
    else:
        trn_files, val_files = prefixes, []
    print(f'training #{len(trn_files)}, val #{len(val_files)}')
    img_prefixes = trn_files if split == 'train' else val_files

    tasks = []
    for prefix in img_prefixes:
        img_file = osp.join(src_image_root, prefix + '.jpg')
        # Skip not exist images
        if not osp.exists(img_file):
            continue
        # Number images contiguously (ignoring skipped ones) so the returned
        # count can seed the next split without index collisions.
        tasks.append((img_start_idx + len(tasks), {
            'file_name': img_file
        }, annotation[prefix]))

    labels_list = mmengine.track_parallel_progress(
        process_img_with_path, tasks, keep_order=True, nproc=nproc)
    final_labels = [label for labels in labels_list for label in labels]
    dump_ocr_data(final_labels, dst_label_file, 'textrecog')
    return len(tasks)
def main():
    """Entry point: convert the train split and, if requested, the val split.

    The val split is only produced when ``--val-ratio`` is positive; its image
    indices start right after the count returned for the train split.
    """
    cli = parse_args()
    shared = dict(
        root_path=cli.root_path,
        ratio=cli.val_ratio,
        preserve_vertical=cli.preserve_vertical,
        nproc=cli.nproc)

    print('Processing training set...')
    train_count = convert_lsvt(split='train', **shared)

    if cli.val_ratio > 0:
        print('Processing validation set...')
        convert_lsvt(split='val', img_start_idx=train_count, **shared)

    print('Finish')
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main()