# File size: 5,132 Bytes (revision 14c9181)
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import os.path as osp
import numpy as np
from shapely.geometry import Polygon
from mmocr.utils import dump_ocr_data
def collect_level_info(annotation):
    """Build one detection annotation from a HierText entry of any level.

    Args:
        annotation (dict): Entry providing the keys ``'legible'`` and
            ``'vertices'``. Presumably ``'vertices'`` is a list of
            ``[x, y]`` points - confirm against the dataset.

    Return:
        anno (dict): dict with the keys ``iscrowd``, ``category_id``,
            ``bbox`` (``[x, y, w, h]`` taken from the polygon bounds),
            ``area`` and ``segmentation`` (a one-element list holding the
            flattened vertex coordinates)
    """
    # Entries that are not legible are flagged as crowd regions.
    crowd_flag = 1 if not annotation['legible'] else 0

    points = np.array(annotation['vertices'])
    shape = Polygon(points)
    shape_area = shape.area
    left, top, right, bottom = shape.bounds

    # Flatten [[x1, y1], [x2, y2], ...] into [x1, y1, x2, y2, ...].
    flat_coords = []
    for point in points:
        flat_coords.extend(point)

    return dict(
        iscrowd=crowd_flag,
        category_id=1,
        bbox=[left, top, right - left, bottom - top],
        area=shape_area,
        segmentation=[flat_coords])
def collect_hiertext_info(root_path, level, split, print_every=1000):
    """Collect the annotation information.

    The annotation format is as the following:
    {
        "info": {
            "date": "release date",
            "version": "current version"
        },
        "annotations": [  // List of dictionaries, one for each image.
            {
                "image_id": "the filename of corresponding image.",
                "image_width": image_width,  // (int) The image width.
                "image_height": image_height,  // (int) The image height.
                "paragraphs": [  // List of paragraphs.
                    {
                        "vertices": [[x1, y1], [x2, y2],...,[xn, yn]],
                        "legible": true,
                        "lines": [
                            {
                                "vertices": [[x1, y1], [x2, y2],...,[x4, y4]],
                                "text": "the text content of this line",
                                "legible": true,
                                "handwritten": false,
                                "vertical": false,
                                "words": [
                                    {
                                        "vertices": [[x1, y1], [x2, y2],...,[xm, ym]],
                                        "text": "the text content of this word",
                                        "legible": true,
                                        "handwritten": false,
                                        "vertical": false
                                    }, ...
                                ]
                            }, ...
                        ]
                    }, ...
                ]
            }, ...
        ]
    }

    Args:
        root_path (str): Root path to the dataset
        level (str): Level of annotations, which should be 'word', 'line',
            or 'paragraph'. Any other value yields images with an empty
            ``anno_info`` list.
        split (str): Dataset split, e.g. 'train' or 'validation'. The
            annotations are read from ``annotations/<split>.jsonl`` under
            ``root_path``.
        print_every (int): Print log information per iter

    Returns:
        img_infos (list[dict]): One dict per image with the keys
            ``file_name``, ``height``, ``width``, ``segm_file`` and
            ``anno_info`` (the list of annotations at the chosen level)

    Raises:
        FileNotFoundError: If the annotation file does not exist.
    """
    annotation_path = osp.join(root_path, 'annotations/' + split + '.jsonl')
    if not osp.exists(annotation_path):
        raise FileNotFoundError(
            f'{annotation_path} not exists, please check and try again.')

    # Despite the ``.jsonl`` extension, the file is parsed as one JSON
    # document. The context manager closes the handle deterministically.
    with open(annotation_path, encoding='utf-8') as f:
        annotation = json.load(f)['annotations']

    img_infos = []
    for i, img_annos in enumerate(annotation):
        if i > 0 and i % print_every == 0:
            print(f'{i}/{len(annotation)}')

        anno_info = []
        for paragraph in img_annos['paragraphs']:
            if level == 'paragraph':
                anno_info.append(collect_level_info(paragraph))
            elif level == 'line':
                anno_info.extend(
                    collect_level_info(line) for line in paragraph['lines'])
            elif level == 'word':
                # Use the word entry itself; passing the enclosing line
                # would duplicate the line polygon once per word.
                anno_info.extend(
                    collect_level_info(word)
                    for line in paragraph['lines']
                    for word in line['words'])

        img_infos.append(
            dict(
                file_name=img_annos['image_id'] + '.jpg',
                height=img_annos['image_height'],
                width=img_annos['image_width'],
                segm_file=annotation_path,
                anno_info=anno_info))
    return img_infos
def parse_args():
    """Parse the command-line arguments of the converter.

    Returns:
        argparse.Namespace: Namespace with ``root_path`` (positional) and
            ``level`` (one of 'word', 'line', 'paragraph'; defaults to
            'word').
    """
    arg_parser = argparse.ArgumentParser(
        description='Generate training and validation set of HierText ')
    arg_parser.add_argument('root_path', help='Root dir path of HierText')

    level_options = dict(
        default='word',
        choices=['word', 'line', 'paragraph'],
        help='HierText provides three levels of annotation')
    arg_parser.add_argument('--level', **level_options)

    return arg_parser.parse_args()
def main():
    """Convert the training and validation annotations of HierText.

    For each split the annotations are collected at the level chosen on
    the command line and dumped as a 'textdet' file directly under the
    dataset root.
    """
    args = parse_args()
    root_path = args.root_path

    # (log banner, split name, output file name), processed in order.
    # NOTE(review): 'val' resolves to ``annotations/val.jsonl`` - confirm
    # the annotation file on disk is actually named that way.
    jobs = (
        ('Processing training set...', 'train', 'instances_training.json'),
        ('Processing validation set...', 'val', 'instances_val.json'),
    )
    for banner, split, out_name in jobs:
        print(banner)
        infos = collect_hiertext_info(root_path, args.level, split)
        dump_ocr_data(infos, osp.join(root_path, out_name), 'textdet')

    print('Finish')
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|