# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import math
import os.path as osp
from functools import partial

import mmcv
import mmengine
import numpy as np
from shapely.geometry import Polygon

from mmocr.utils import dump_ocr_data


def seg2bbox(seg):
    """Convert segmentation to bbox.

    Args:
        seg (list(int | float)): A set of coordinates
    """
    if len(seg) == 8:
        # Quadrilateral: take the axis-aligned extremes of the four vertices
        min_x = min(seg[0], seg[2], seg[4], seg[6])
        max_x = max(seg[0], seg[2], seg[4], seg[6])
        min_y = min(seg[1], seg[3], seg[5], seg[7])
        max_y = max(seg[1], seg[3], seg[5], seg[7])
    else:
        # General polygon: use shapely to compute the bounds
        seg = np.array(seg).reshape(-1, 2)
        polygon = Polygon(seg)
        min_x, min_y, max_x, max_y = polygon.bounds
    bbox = [min_x, min_y, max_x - min_x, max_y - min_y]
    return bbox


def process_level(
    src_img,
    annotation,
    dst_image_root,
    ignore_image_root,
    preserve_vertical,
    split,
    para_idx,
    img_idx,
    line_idx,
    word_idx=None,
):
    """Crop one line- or word-level instance and build its label."""
    vertices = annotation['vertices']
    text_label = annotation['text']
    segmentation = [i for j in vertices for i in j]
    x, y, w, h = seg2bbox(segmentation)
    x, y = max(0, math.floor(x)), max(0, math.floor(y))
    w, h = math.ceil(w), math.ceil(h)
    dst_img = src_img[y:y + h, x:x + w]

    # word_idx may be 0, so compare against None explicitly
    if word_idx is not None:
        dst_img_name = f'img_{img_idx}_{para_idx}_{line_idx}_{word_idx}.jpg'
    else:
        dst_img_name = f'img_{img_idx}_{para_idx}_{line_idx}.jpg'

    if not preserve_vertical and h / w > 2 and split == 'train':
        # Move vertical training samples to the ignore directory
        dst_img_path = osp.join(ignore_image_root, dst_img_name)
        mmcv.imwrite(dst_img, dst_img_path)
        return None

    dst_img_path = osp.join(dst_image_root, dst_img_name)
    mmcv.imwrite(dst_img, dst_img_path)

    label = {'file_name': dst_img_name, 'anno_info': [{'text': text_label}]}

    return label


def process_img(args, src_image_root, dst_image_root, ignore_image_root,
                level, preserve_vertical, split):
    """Crop all legible instances of one image and collect their labels."""
    # Dirty hack for multi-processing
    img_idx, img_annos = args
    src_img = mmcv.imread(
        osp.join(src_image_root, img_annos['image_id'] + '.jpg'))
    labels = []
    for para_idx, paragraph in enumerate(img_annos['paragraphs']):
        for line_idx, line in enumerate(paragraph['lines']):
            if level == 'line':
                # Ignore illegible lines
                if line['legible']:
                    label = process_level(src_img, line, dst_image_root,
                                          ignore_image_root,
                                          preserve_vertical, split, para_idx,
                                          img_idx, line_idx)
                    if label is not None:
                        labels.append(label)
            elif level == 'word':
                # Ignore illegible words
                for word_idx, word in enumerate(line['words']):
                    if not word['legible']:
                        continue
                    label = process_level(src_img, word, dst_image_root,
                                          ignore_image_root,
                                          preserve_vertical, split, para_idx,
                                          img_idx, line_idx, word_idx)
                    if label is not None:
                        labels.append(label)
    return labels


def convert_hiertext(
    root_path,
    split,
    level,
    preserve_vertical,
    nproc,
):
    """Collect the annotation information and crop the images.

    The annotation format is as the following:
    {
        "info": {
            "date": "release date",
            "version": "current version"
        },
        "annotations": [  // List of dictionaries, one for each image.
            {
                "image_id": "the filename of corresponding image.",
                "image_width": image_width,  // (int) The image width.
                "image_height": image_height,  // (int) The image height.
                "paragraphs": [  // List of paragraphs.
                    {
                        "vertices": [[x1, y1], [x2, y2], ..., [xn, yn]],
                        "legible": true,
                        "lines": [
                            {
                                "vertices": [[x1, y1], [x2, y2], ..., [x4, y4]],
                                "text": "the text content of this line",
                                "legible": true,
                                "handwritten": false,
                                "vertical": false,
                                "words": [
                                    {
                                        "vertices": [[x1, y1], [x2, y2],
                                                     ..., [xm, ym]],
                                        "text": "the text content of this word",
                                        "legible": true,
                                        "handwritten": false,
                                        "vertical": false
                                    },
                                    ...
                                ]
                            },
                            ...
                        ]
                    },
                    ...
                ]
            },
            ...
        ]
    }

    Args:
        root_path (str): Root path to the dataset
        split (str): Dataset split, which should be 'train' or 'val'
        level (str): Crop word or line level instances
        preserve_vertical (bool): Whether to preserve vertical texts
        nproc (int): Number of processes

    Returns:
        None: The cropped images and the ``{split}_label.json`` label file
            are written under ``root_path``.
    """
    annotation_path = osp.join(root_path, 'annotations/' + split + '.jsonl')
    if not osp.exists(annotation_path):
        raise Exception(
            f'{annotation_path} does not exist, please check and try again.')

    with open(annotation_path) as f:
        annotation = json.load(f)['annotations']

    # outputs
    dst_label_file = osp.join(root_path, f'{split}_label.json')
    dst_image_root = osp.join(root_path, 'crops', split)
    ignore_image_root = osp.join(root_path, 'ignores', split)
    src_image_root = osp.join(root_path, 'imgs', split)
    mmengine.mkdir_or_exist(dst_image_root)
    mmengine.mkdir_or_exist(ignore_image_root)

    process_img_with_path = partial(
        process_img,
        src_image_root=src_image_root,
        dst_image_root=dst_image_root,
        ignore_image_root=ignore_image_root,
        level=level,
        preserve_vertical=preserve_vertical,
        split=split)

    tasks = []
    for img_idx, img_info in enumerate(annotation):
        tasks.append((img_idx, img_info))

    labels_list = mmengine.track_parallel_progress(
        process_img_with_path, tasks, keep_order=True, nproc=nproc)

    final_labels = []
    for label_list in labels_list:
        final_labels += label_list

    dump_ocr_data(final_labels, dst_label_file, 'textrecog')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Generate training and validation set of HierText')
    parser.add_argument('root_path', help='Root dir path of HierText')
    parser.add_argument(
        '--nproc', default=1, type=int, help='Number of processes')
    parser.add_argument(
        '--preserve-vertical',
        help='Preserve samples containing vertical texts',
        action='store_true')
    parser.add_argument(
        '--level',
        default='word',
        help='Crop word or line level instances',
        choices=['word', 'line'])
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    root_path = args.root_path
    print('Processing training set...')
    convert_hiertext(
        root_path=root_path,
        split='train',
        level=args.level,
        preserve_vertical=args.preserve_vertical,
        nproc=args.nproc)
    print('Processing validation set...')
    convert_hiertext(
        root_path=root_path,
        split='val',
        level=args.level,
        preserve_vertical=args.preserve_vertical,
        nproc=args.nproc)
    print('Finished')


if __name__ == '__main__':
    main()
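# A minimal usage sketch. The script filename and the example root directory
# below are assumptions, not fixed by this file; the imgs/, annotations/,
# crops/ and ignores/ layout comes from convert_hiertext above:
#
#   python hiertext_converter.py data/hiertext --level word --nproc 8
#
# This crops word-level instances into data/hiertext/crops/{train,val} and
# writes train_label.json / val_label.json under data/hiertext.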