# Newspapers-OCR-Demo / order_text_blocks.py
# Last commit by pantatwiai: "Update order_text_blocks.py" (ce1d008)
import os
import cv2
import json
import numpy as np
# Maps the integer class id stored in each label (label[2]) to the layout
# class name used as a key in the generated JSON. A name's position in the
# tuple below is its class id.
class_id_to_name = dict(enumerate((
    "Articles",
    "Advertisement",
    "Headlines",
    "Sub-headlines",
    "Graphics",
    "Images",
    "Tables",
    "Text Block",
    "Header",
)))
def findIntersection(box1, box2):
    """Return the fraction of box1's area that overlaps box2.

    This is not IoU: the intersection area is divided by the area of
    ``box1`` only, so the result answers "how much of box1 lies inside
    box2" (1.0 means box1 is fully contained in box2).

    args:
        box1: [x1, y1, x2, y2] corner coordinates
        box2: [x1, y1, x2, y2] corner coordinates
    returns:
        overlap: intersection area / area of box1, in [0, 1] for well-formed
            boxes; 0 when box1 has zero or negative area
    """
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    # A degenerate (or inverted) box1 has no meaningful area to divide by;
    # without the <= guard a negative area would give a negative ratio.
    if area1 <= 0:
        return 0
    inter_w = min(box1[2], box2[2]) - max(box1[0], box2[0])
    inter_h = min(box1[3], box2[3]) - max(box1[1], box2[1])
    # Non-overlapping boxes produce a negative extent; clamp to zero.
    inter_area = max(0, inter_w) * max(0, inter_h)
    return inter_area / area1
def get_hierarchy(labels, iou_thresh=0.70):
    """Group detected layout objects under the article boxes that contain them.

    args:
        labels: iterable of detections, each indexed as
            label[0] -> box [x1, y1, x2, y2],
            label[1] -> presumably a confidence score (not used here),
            label[2] -> class id (a key of ``class_id_to_name``).
        iou_thresh: an object is assigned to an article when more than this
            fraction of the object's own area (see ``findIntersection``)
            lies inside the article box.
    returns:
        article_list: one dict per article box (class id 0), mapping class
            names to the boxes assigned to that article. An article box
            normally appears in its own 'Articles' list (it overlaps itself).
            An object may be assigned to several articles.
        object_dict: ``{obj_key: [assigned, box, class_id]}`` for every
            detection, where ``obj_key = class_id * 1000000 + index`` (index =
            position among detections of that class) and ``assigned`` is 1 if
            the object was placed in at least one article, else 0.
    """
    class_list = [[] for _ in range(len(class_id_to_name))]
    object_dict = {}
    for label in labels:
        box = [label[0][0], label[0][1], label[0][2], label[0][3]]
        class_id = int(label[2])
        obj_key = class_id * 1000000 + len(class_list[class_id])
        class_list[class_id].append(box)
        object_dict[obj_key] = [0, list(box), class_id]

    article_list = []
    for article in class_list[0]:
        article_dict = {
            'Articles': [],
            'Headlines': [],
            'Sub-headlines': [],
            'Graphics': [],
            'Images': [],
            'Text Block': [],
            "Advertisement": [],
            "Tables": [],
        }
        for class_id, boxes in enumerate(class_list):
            key = class_id_to_name[class_id]
            # Classes without a slot in the article schema (e.g. "Header")
            # stay unassigned and are reported as extras instead of raising
            # KeyError.
            if key not in article_dict:
                continue
            # enumerate gives each box its own key, so duplicate boxes are
            # all marked assigned (list.index would only find the first).
            for idx, obj in enumerate(boxes):
                if findIntersection(obj, article) > iou_thresh:
                    article_dict[key].append(obj)
                    object_dict[class_id * 1000000 + idx][0] = 1
        article_list.append(article_dict)
    return article_list, object_dict
def textblock_ordering(textblocks, img, bucket_ratio=0.55):
    """Return text blocks in reading order: column by column, top to bottom.

    Blocks are sorted by left edge and grouped into vertical columns. A block
    joins the first existing column whose left-most block starts strictly
    within ``int(bucket_ratio * average block width)`` of the block's left
    edge. Otherwise the block starts a new column. Columns are emitted left
    to right, and the blocks in each column top to bottom (by y1).

    args:
        textblocks: iterable of 4-element boxes [x1, y1, x2, y2].
        img: unused; kept so existing callers' signatures still match.
        bucket_ratio: column tolerance as a fraction of the average width.
    returns:
        list of [x1, y1, x2, y2] boxes (values converted to float) in reading
        order; an empty list when ``textblocks`` is empty.
    """
    coords = []
    for obj in textblocks:
        x1, y1, x2, y2 = map(float, obj)
        coords.append([x1, y1, x2, y2])
    # Nothing to order; also avoids dividing by zero for the average width.
    if not coords:
        return []

    avg_width = sum(x2 - x1 for x1, _, x2, _ in coords) / len(coords)
    bucket_size = int(bucket_ratio * avg_width)

    # Sorting by left edge first means each column's first entry is its
    # left-most block, which is what new blocks are compared against.
    coords.sort(key=lambda box: box[0])
    columns = []
    for box in coords:
        for column in columns:
            if abs(column[0][0] - box[0]) < bucket_size:
                column.append(box)
                break
        else:
            columns.append([box])

    ordered = []
    for column in columns:
        column.sort(key=lambda box: box[1])
        ordered.extend(column)
    return ordered
def get_ordered_data(labels, img):
    """Build the reading-order JSON structure for one page.

    args:
        labels: detections in the format accepted by ``get_hierarchy``
            (each indexed as box, conf, class_id).
        img: forwarded to ``textblock_ordering``; not otherwise used here.
    returns:
        dict with two keys:
            'Articles': list of per-article dicts (class name -> boxes), each
                with its 'Text Block' list put into reading order.
            'Extra': list of ``{class_name: [box]}`` entries, one per
                detection that was not assigned to any article.
    """
    import logging
    logger = logging.getLogger(__name__)

    article_list, object_dict = get_hierarchy(labels)
    for article in article_list:
        # Articles with no text blocks have nothing to order; skipping them
        # also avoids calling textblock_ordering on an empty list.
        if article['Text Block']:
            article['Text Block'] = textblock_ordering(article['Text Block'], img)

    # Data structure:
    # {'Articles': [{Headlines: [box, ...], Sub-headlines: [...], ...}, ...],
    #  'Extra': [{class_name: [box]}, ...]}
    json_dict = {'Articles': article_list, 'Extra': []}
    for key, (assigned, box, class_id) in object_dict.items():
        if assigned == 0:
            logger.debug("Extra: %s", key)
            json_dict['Extra'].append({class_id_to_name[class_id]: [box]})
    return json_dict
# if __name__ == '__main__':
# label_path = '/Users/deveshpant/Work/WadhwaniAI/IDSP/eNewspaperPDFs/Language_wise/Results/pred/Hindi2/labels/_Jansatta-Delhi 15-11_5.txt'
# img_path = '/Users/deveshpant/Work/WadhwaniAI/IDSP/eNewspaperPDFs/Language_wise/Language_wise_imgs/Hindi/_Jansatta-Delhi 15-11_5.png'
# json_dict = get_ordered_data(label_path, img_path)
# # dump the json
# with open('json_dict.json', 'w') as f:
# json.dump(json_dict, f)
# # read the json
# with open('json_dict.json', 'r') as f:
# json_dict = json.load(f)
# visualize the results
img = cv2.imread(img_path)