Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
import argparse | |
import os | |
import os.path as osp | |
import re | |
from functools import partial | |
import mmcv | |
import numpy as np | |
from mmocr.utils.fileio import list_to_file | |
from PIL import Image | |
def parse_args():
    """Parse the command-line arguments of the LSVT converter.

    Returns:
        argparse.Namespace: Namespace with ``root_path`` (str) and
        ``n_proc`` (int, 1 when omitted on the command line).
    """
    parser = argparse.ArgumentParser(
        description='Generate training set of LSVT '
        'by cropping box image.')
    parser.add_argument('root_path', help='Root dir path of LSVT')
    # argparse ignores ``default`` on a required positional argument, so
    # without ``nargs='?'`` the process count could never be left out.
    # Making the positional optional keeps existing two-argument
    # invocations working while letting the declared default apply.
    parser.add_argument(
        'n_proc',
        nargs='?',
        default=1,
        type=int,
        help='Number of processes to run')
    return parser.parse_args()
# Matches any run of characters in the CJK Unified Ideographs range;
# compiled once instead of being looked up for every annotation.
_CJK_PATTERN = re.compile(r'[\u4e00-\u9fff]+')
# Transcriptions that are always dropped.
_LABEL_BLACKLIST = ('LOFTINESS*',)
# The only transcriptions containing '#' that are kept.
_LABEL_WHITELIST = ('#Find YOUR Fun#', 'Story #', '*0#')


def process_img(args, src_image_root, dst_image_root):
    """Crop the usable text instances of one image and save them as JPEGs.

    Args:
        args (tuple): ``(img_idx, img_info, anns)`` packed into a single
            argument so the function can be mapped over a task list by a
            process pool. ``img_idx`` is used in the crop filenames,
            ``img_info`` is the image filename without the ``.jpg``
            extension and ``anns`` is an iterable of dicts with the keys
            ``transcription``, ``illegibility`` and ``points``
            (``points`` is presumably a list of (x, y) vertices).
        src_image_root (str): Directory holding the
            ``train_full_images_0`` and ``train_full_images_1`` folders.
        dst_image_root (str): Existing directory the crops are written to.

    Returns:
        list[str]: One ``'<basename of dst_image_root>/<crop name> <text>'``
        entry per saved crop.

    Raises:
        IOError: If the image cannot be opened from either source folder.
    """
    # Dirty hack for multiprocessing
    img_idx, img_info, anns = args
    try:
        src_img = Image.open(
            osp.join(src_image_root, f'train_full_images_0/{img_info}.jpg'))
    except IOError:
        # The image is looked up in the second folder when the first fails.
        src_img = Image.open(
            osp.join(src_image_root, f'train_full_images_1/{img_info}.jpg'))

    dst_dir_name = osp.basename(dst_image_root)
    labels = []
    # try/finally guarantees the source file handle is released even when
    # cropping or saving one of the instances raises.
    try:
        for ann_idx, ann in enumerate(anns):
            text_label = ann['transcription']
            # Ignore illegible or words with non-Latin characters
            if (ann['illegibility']
                    or _CJK_PATTERN.search(text_label)
                    or text_label in _LABEL_BLACKLIST
                    or ('#' in text_label
                        and text_label not in _LABEL_WHITELIST)):
                continue
            # Axis-aligned bounding box enclosing all annotated points.
            points = np.asarray(ann['points'])
            x1, y1 = points.min(axis=0)
            x2, y2 = points.max(axis=0)
            dst_img = src_img.crop((x1, y1, x2, y2))
            dst_img_name = f'img_{img_idx}_{ann_idx}.jpg'
            dst_img_path = osp.join(dst_image_root, dst_img_name)
            # Preserve JPEG quality
            dst_img.save(dst_img_path, qtables=src_img.quantization)
            labels.append(f'{dst_dir_name}/{dst_img_name} {text_label}')
    finally:
        src_img.close()
    return labels
def convert_lsvt(root_path,
                 dst_image_path,
                 dst_label_filename,
                 annotation_filename,
                 img_start_idx=0,
                 nproc=1):
    """Crop all annotated text instances and write the label file.

    Args:
        root_path (str): Dataset root; the annotation file is read from it
            and every output is written below it.
        dst_image_path (str): Directory (relative to ``root_path``) that
            receives the cropped images; created when missing.
        dst_label_filename (str): Label file name, relative to
            ``root_path``.
        annotation_filename (str): Annotation file name, relative to
            ``root_path``.
        img_start_idx (int): Offset added to the image index used in the
            crop filenames.
        nproc (int): Number of processes handed to
            ``mmcv.track_parallel_progress``.

    Returns:
        int: Number of entries (images) in the loaded annotation mapping.

    Raises:
        FileNotFoundError: If the annotation file does not exist.
    """
    annotation_path = osp.join(root_path, annotation_filename)
    if not osp.exists(annotation_path):
        # A specific exception type instead of a bare ``Exception``; it is
        # still caught by callers that handle ``Exception``.
        raise FileNotFoundError(
            f'{annotation_path} not exists, please check and try again.')

    src_image_root = root_path

    # outputs
    dst_label_file = osp.join(root_path, dst_label_filename)
    dst_image_root = osp.join(root_path, dst_image_path)
    os.makedirs(dst_image_root, exist_ok=True)

    # The loaded annotation is used as a mapping of image name -> list of
    # instance annotations.
    annotation = mmcv.load(annotation_path)

    process_img_with_path = partial(
        process_img,
        src_image_root=src_image_root,
        dst_image_root=dst_image_root)
    tasks = [(img_idx, img_info, anns)
             for img_idx, (img_info, anns) in enumerate(
                 annotation.items(), start=img_start_idx)]
    labels_list = mmcv.track_parallel_progress(
        process_img_with_path, tasks, keep_order=True, nproc=nproc)

    # Flatten the per-image label lists, preserving task order.
    final_labels = [
        label for label_list in labels_list for label in label_list
    ]
    list_to_file(dst_label_file, final_labels)
    return len(annotation)
def main():
    """Convert the LSVT training set using the command-line arguments."""
    cli_args = parse_args()
    print('Processing training set...')
    # Positional order: root_path, dst_image_path, dst_label_filename,
    # annotation_filename.
    convert_lsvt(
        cli_args.root_path,
        'image_train',
        'train_label.txt',
        'train_full_labels.json',
        nproc=cli_args.n_proc)
    print('Finish')
# Run the conversion only when the file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()