Spaces:
Sleeping
Sleeping
File size: 5,743 Bytes
9bf4bd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import math
import os.path as osp
from functools import partial
import mmcv
import mmengine
from mmocr.utils import dump_ocr_data
def parse_args():
    """Parse the command-line options of the LSVT conversion script.

    Returns:
        argparse.Namespace: Parsed options exposing ``root_path``,
        ``val_ratio``, ``nproc`` and ``preserve_vertical``.
    """
    arg_parser = argparse.ArgumentParser(
        description='Generate training and validation set of LSVT ')
    arg_parser.add_argument('root_path', help='Root dir path of LSVT')

    # Optional flags, registered in the order they appear in ``--help``.
    optional_flags = (
        ('--val-ratio',
         dict(help='Split ratio for val set', default=0.0, type=float)),
        ('--nproc',
         dict(default=1, type=int, help='Number of processes')),
        ('--preserve-vertical',
         dict(help='Preserve samples containing vertical texts',
              action='store_true')),
    )
    for flag, settings in optional_flags:
        arg_parser.add_argument(flag, **settings)

    return arg_parser.parse_args()
def process_img(args, dst_image_root, ignore_image_root, preserve_vertical,
                split):
    """Crop every annotated text instance out of one source image.

    For each annotation the axis-aligned bounding box of its polygon is cut
    out of the source image and written as ``img_{img_idx}_{ann_idx}.jpg``.

    Args:
        args (tuple): ``(img_idx, img_info, anns)`` packed into a single
            argument so the function can be mapped over a task list.
            ``img_info['file_name']`` is the path of the source image and
            ``anns`` is an iterable of dicts with the keys ``'points'``
            and ``'transcription'``.
        dst_image_root (str): Directory that receives the kept crops.
        ignore_image_root (str): Directory that receives the crops that are
            filtered out as vertical text.
        preserve_vertical (bool): If False, crops whose height is more than
            twice their width are filtered out of the ``'train'`` split.
        split (str): Name of the split being processed. The vertical-text
            filter is only applied when it equals ``'train'``.

    Returns:
        list[dict]: One entry per kept crop, of the form
        ``{'file_name': <crop name>, 'anno_info': [{'text': <label>}]}``.
    """
    # Dirty hack for multi-processing
    img_idx, img_info, anns = args
    src_img = mmcv.imread(img_info['file_name'])

    # Collected in a separate list instead of rebinding ``img_info``.
    crop_infos = []
    for ann_idx, ann in enumerate(anns):
        # Clamp negative coordinates so the slice below cannot wrap around.
        xs = [max(0, px) for px, _ in ann['points']]
        ys = [max(0, py) for _, py in ann['points']]
        x, y = min(xs), min(ys)
        w, h = max(xs) - x, max(ys) - y

        dst_img = src_img[y:y + h, x:x + w]
        # A degenerate polygon (zero width or height) or a box lying outside
        # the image yields an empty crop, which cannot be written as an image.
        if dst_img.size == 0:
            continue

        dst_img_name = f'img_{img_idx}_{ann_idx}.jpg'

        # ``h > 2 * w`` is the division-free form of ``h / w > 2``; it cannot
        # raise ZeroDivisionError.
        if not preserve_vertical and h > 2 * w and split == 'train':
            mmcv.imwrite(dst_img, osp.join(ignore_image_root, dst_img_name))
            continue

        mmcv.imwrite(dst_img, osp.join(dst_image_root, dst_img_name))
        crop_infos.append({
            'file_name': dst_img_name,
            'anno_info': [{
                'text': ann['transcription']
            }]
        })

    return crop_infos
def convert_lsvt(root_path,
                 split,
                 ratio,
                 preserve_vertical,
                 nproc,
                 img_start_idx=0):
    """Collect the annotation information and crop the images.

    The annotation file ``annotations/train_full_labels.json`` is accessed
    as a mapping from image prefix to its list of text instances:

        {
            'gt_1234':  # 'gt_1234' is the file name without extension
                [
                    {
                        'transcription': '一站式购物中心',
                        'points': [[45, 272], [215, 273], [212, 296], [45, 290]],
                        'illegibility': False
                    }, ...
                ]
        }

    Crops are written to ``<root_path>/crops/<split>``, filtered crops to
    ``<root_path>/ignores/<split>`` and the labels to
    ``<root_path>/<split>_label.json``.

    Args:
        root_path (str): The root path of the dataset
        split (str): The split of dataset. Either ``'train'`` or ``'val'``
        ratio (float): Split ratio for val set. When it is greater than 0,
            every ``floor(1 / ratio)``-th image goes to the val set; when it
            is not greater than 0, all images go to the training set.
        preserve_vertical (bool): Whether to preserve vertical texts
        nproc (int): The number of process to collect annotations
        img_start_idx (int): Index of start image

    Returns:
        int: The number of source images of this split that exist on disk
        and were submitted for cropping.

    Raises:
        NotImplementedError: If ``split`` is neither ``'train'`` nor
            ``'val'``.
        ValueError: If ``ratio`` is greater than 1.
        FileNotFoundError: If the annotation file does not exist.
    """
    # Validate the arguments before touching the file system.
    if split not in ('train', 'val'):
        raise NotImplementedError
    if ratio > 1:
        # floor(1 / ratio) would be 0 and the modulo below would raise
        # ZeroDivisionError.
        raise ValueError(
            f'ratio must not be greater than 1, but got {ratio}.')

    annotation_path = osp.join(root_path, 'annotations/train_full_labels.json')
    if not osp.exists(annotation_path):
        raise FileNotFoundError(
            f'{annotation_path} not exists, please check and try again.')

    annotation = mmengine.load(annotation_path)

    # outputs
    dst_label_file = osp.join(root_path, f'{split}_label.json')
    dst_image_root = osp.join(root_path, 'crops', split)
    ignore_image_root = osp.join(root_path, 'ignores', split)
    src_image_root = osp.join(root_path, 'imgs')
    mmengine.mkdir_or_exist(dst_image_root)
    mmengine.mkdir_or_exist(ignore_image_root)

    process_img_with_path = partial(
        process_img,
        dst_image_root=dst_image_root,
        ignore_image_root=ignore_image_root,
        preserve_vertical=preserve_vertical,
        split=split)

    trn_files, val_files = [], []
    if ratio > 0:
        # Every ``interval``-th image (starting with the first) is held out.
        interval = math.floor(1 / ratio)
        for i, file in enumerate(annotation):
            if i % interval:
                trn_files.append(file)
            else:
                val_files.append(file)
    else:
        trn_files = list(annotation)
    print(f'training #{len(trn_files)}, val #{len(val_files)}')

    img_prefixes = trn_files if split == 'train' else val_files

    tasks = []
    for img_idx, prefix in enumerate(img_prefixes):
        img_file = osp.join(src_image_root, prefix + '.jpg')
        # Skip not exist images
        if not osp.exists(img_file):
            continue
        tasks.append((img_idx + img_start_idx, {'file_name': img_file},
                      annotation[prefix]))

    labels_list = mmengine.track_parallel_progress(
        process_img_with_path, tasks, keep_order=True, nproc=nproc)

    # Flatten the per-image label lists into one list for dumping.
    final_labels = []
    for label_list in labels_list:
        final_labels.extend(label_list)
    dump_ocr_data(final_labels, dst_label_file, 'textrecog')

    return len(tasks)
def main():
    """Convert the training split and, if requested, the validation split.

    The validation split is only generated when ``--val-ratio`` is greater
    than 0; its image indices start at the count returned for the training
    split.
    """
    args = parse_args()

    # Arguments shared by both conversion passes.
    shared_kwargs = dict(
        root_path=args.root_path,
        ratio=args.val_ratio,
        preserve_vertical=args.preserve_vertical,
        nproc=args.nproc)

    print('Processing training set...')
    num_train_imgs = convert_lsvt(split='train', **shared_kwargs)

    if args.val_ratio > 0:
        print('Processing validation set...')
        convert_lsvt(
            split='val', img_start_idx=num_train_imgs, **shared_kwargs)

    print('Finish')
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|