Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved.
import argparse | |
import math | |
import os.path as osp | |
import mmengine | |
from mmocr.utils import dump_ocr_data | |
def parse_args():
    """Parse the command-line options of the ArT conversion script.

    Returns:
        argparse.Namespace: Namespace holding ``root_path`` (str),
        ``val_ratio`` (float, defaults to 0.0) and ``nproc`` (int,
        defaults to 1).
    """
    arg_parser = argparse.ArgumentParser(
        description='Generate training and validation set of ArT ')
    # Positional: where the dataset lives on disk
    arg_parser.add_argument('root_path', help='Root dir path of ArT')
    arg_parser.add_argument(
        '--val-ratio',
        type=float,
        default=0.0,
        help='Split ratio for val set')
    arg_parser.add_argument(
        '--nproc',
        type=int,
        default=1,
        help='Number of processes')
    return arg_parser.parse_args()
def convert_art(root_path, split, ratio):
    """Collect the annotation information and dump a recognition label file.

    The annotation format is as the following:
    {
        "gt_2836_0": [
            {
                "transcription": "URDER",
                "points": [
                    [25, 51],
                    [0, 2],
                    [21, 0],
                    [42, 43]
                ],
                "language": "Latin",
                "illegibility": false
            }
        ], ...
    }

    Only the ``transcription`` of the first entry of every image is used.
    The collected labels are written to ``<root_path>/<split>_label.json``
    through ``dump_ocr_data``; the function itself returns nothing.

    Args:
        root_path (str): The root path of the dataset
        split (str): The split of dataset. Namely: train or val
        ratio (float): Split ratio for val set. When it is positive, every
            ``floor(1 / ratio)``-th image (starting with the first one) is
            assigned to the val set and the rest to the training set.
            Otherwise all images are used for training. Must not exceed 1.

    Raises:
        NotImplementedError: If ``split`` is neither ``'train'`` nor
            ``'val'``.
        ValueError: If ``ratio`` is greater than 1.
        FileNotFoundError: If the annotation file does not exist under
            ``root_path``.
    """
    # Validate the arguments before touching the file system so that a bad
    # call fails fast with a clear message.
    if split not in ('train', 'val'):
        raise NotImplementedError(
            f"split must be 'train' or 'val', but got {split!r}")
    if ratio > 1:
        # floor(1 / ratio) would be 0 and the modulo below would fail with
        # an opaque ZeroDivisionError.
        raise ValueError(f'ratio must not exceed 1, but got {ratio}')

    annotation_path = osp.join(root_path,
                               'annotations/train_task2_labels.json')
    if not osp.exists(annotation_path):
        raise FileNotFoundError(
            f'{annotation_path} not exists, please check and try again.')

    annotation = mmengine.load(annotation_path)
    img_prefixes = list(annotation)

    if ratio > 0:
        interval = math.floor(1 / ratio)
        # Indices that are multiples of ``interval`` form the val set
        val_files = img_prefixes[::interval]
        trn_files = [
            prefix for i, prefix in enumerate(img_prefixes) if i % interval
        ]
    else:
        trn_files, val_files = img_prefixes, []
    print(f'training #{len(trn_files)}, val #{len(val_files)}')

    selected = trn_files if split == 'train' else val_files
    img_info = [{
        'file_name': prefix + '.jpg',
        'anno_info': [{
            'text': annotation[prefix][0]['transcription']
        }]
    } for prefix in selected]

    # Keep non-ASCII transcriptions readable in the dumped json
    dump_ocr_data(
        img_info,
        osp.join(root_path, f'{split}_label.json'),
        'textrecog',
        ensure_ascii=False)
def main():
    """Convert the ArT annotations into train (and optionally val) labels."""
    cli_args = parse_args()
    dataset_root = cli_args.root_path
    val_ratio = cli_args.val_ratio

    print('Processing training set...')
    convert_art(root_path=dataset_root, split='train', ratio=val_ratio)

    # The val split is only converted when a positive ratio was requested
    if val_ratio > 0:
        print('Processing validation set...')
        convert_art(root_path=dataset_root, split='val', ratio=val_ratio)

    print('Finish')
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()