Spaces:
Runtime error
Runtime error
import os | |
import cv2 | |
import json | |
import numpy as np | |
# Detector class id -> human-readable layout category.
# The position of each name in the tuple is its numeric class id.
class_id_to_name = dict(enumerate((
    "Articles",        # 0
    "Advertisement",   # 1
    "Headlines",       # 2
    "Sub-headlines",   # 3
    "Graphics",        # 4
    "Images",          # 5
    "Tables",          # 6
    "Text Block",      # 7
    "Header",          # 8
)))
def findIntersection(box1, box2):
    """
    Return the fraction of box1's area that is covered by box2.

    Note that this is NOT a symmetric intersection-over-union: the
    intersection area is divided by the area of box1 only, so a small box
    lying fully inside a large one scores 1.0.

    args:
        box1: [x1, y1, x2, y2] corner coordinates (top-left, bottom-right)
        box2: [x1, y1, x2, y2] corner coordinates (top-left, bottom-right)
    returns:
        float in [0, 1]; 0.0 when box1 has no positive area or the boxes
        do not overlap.
    """
    ax1, ay1, ax2, ay2 = box1[0], box1[1], box1[2], box1[3]
    bx1, by1, bx2, by2 = box2[0], box2[1], box2[2], box2[3]

    # A degenerate or inverted box1 has nothing to cover; bail out before
    # dividing by its (zero or negative) area.
    width1 = ax2 - ax1
    height1 = ay2 - ay1
    if width1 <= 0 or height1 <= 0:
        return 0.0

    # Corners of the overlap rectangle; clamped to 0 when boxes are disjoint.
    inter_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0, min(ay2, by2) - max(ay1, by1))

    # Divide by box1's area instead of the union.
    return (inter_w * inter_h) / (width1 * height1)
def get_hierarchy(labels):
    """
    Group detected layout objects under the article boxes that contain them.

    args:
        labels: iterable of detections, each indexable as
            (box, confidence, class_id) where box is [x1, y1, x2, y2] and
            class_id is a key of class_id_to_name.
    returns:
        article_list: one dict per detected article (class 0). Each dict
            maps a class name to the list of boxes of that class that have
            more than 70% of their own area inside the article box.
        object_dict: maps ``class_id * 1000000 + index_within_class`` to
            ``[assigned, box, class_id]`` where assigned is 1 if the object
            was placed in at least one article and 0 otherwise.
    """
    overlap_thresh = 0.70
    # Classes that get a slot inside every article, in output order.
    grouped_names = (
        'Articles', 'Headlines', 'Sub-headlines', 'Graphics',
        'Images', 'Text Block', 'Advertisement', 'Tables',
    )

    class_list = [[] for _ in range(len(class_id_to_name))]
    object_dict = {}
    for label in labels:
        x1, y1, x2, y2 = label[0][0], label[0][1], label[0][2], label[0][3]
        class_id = int(label[2])
        class_list[class_id].append([x1, y1, x2, y2])
        obj_key = class_id * 1000000 + len(class_list[class_id]) - 1
        object_dict[obj_key] = [0, [x1, y1, x2, y2], class_id]

    # For each article, find all the objects that belong to it.
    article_list = []
    for article in class_list[0]:
        article_dict = {name: [] for name in grouped_names}
        for class_id, boxes in enumerate(class_list):
            name = class_id_to_name[class_id]
            # Classes without a per-article slot (e.g. "Header") are skipped
            # rather than raising KeyError; they stay flagged as unassigned.
            if name not in article_dict:
                continue
            # enumerate() gives the object's true index, so two detections
            # with identical coordinates are each marked as assigned.
            for index, obj in enumerate(boxes):
                if findIntersection(obj, article) > overlap_thresh:
                    article_dict[name].append(obj)
                    object_dict[class_id * 1000000 + index][0] = 1
        article_list.append(article_dict)

    return article_list, object_dict
def textblock_ordering(textblocks, img):
    """
    Order text-block boxes column by column, top to bottom within a column.

    Boxes are sorted by their left edge and greedily grouped into vertical
    buckets: a box joins the first bucket whose first member's left edge is
    closer than 55% of the average box width, otherwise it starts a new
    bucket. Each bucket is then sorted by top edge and the buckets are
    concatenated in the order they were created.

    args:
        textblocks: iterable of [x1, y1, x2, y2] boxes whose values are
            convertible to float.
        img: currently unused; kept so existing callers keep working.
    returns:
        flat list of [x1, y1, x2, y2] lists of floats in reading order;
        an empty list when there are no text blocks.
    """
    coords = []
    total_width = 0
    for obj in textblocks:
        x1, y1, x2, y2 = map(float, obj)
        coords.append([x1, y1, x2, y2])
        total_width += x2 - x1

    # An article without text blocks has nothing to order (and would
    # otherwise divide by zero when averaging the widths).
    if not coords:
        return []

    avg_width = total_width / len(coords)
    bucket_size = int(0.55 * avg_width)

    # Sort the text blocks by left x coordinate.
    coords.sort(key=lambda c: c[0])

    # Create vertical buckets; the first box put in a bucket is its anchor.
    buckets = []
    for coord in coords:
        for bucket in buckets:
            if abs(bucket[0][0] - coord[0]) < bucket_size:
                bucket.append(coord)
                break
        else:
            # No existing bucket is close enough: start a new column.
            buckets.append([coord])

    # Sort each bucket by y1 (done only after bucketing so the anchors
    # used above are not disturbed).
    for bucket in buckets:
        bucket.sort(key=lambda c: c[1])

    # Merge all the buckets into one list.
    return [coord for bucket in buckets for coord in bucket]
def get_ordered_data(labels, img):
    """
    Build the nested article structure for a page of detections.

    args:
        labels: detections, passed straight to get_hierarchy.
        img: passed straight to textblock_ordering.
    returns:
        dict with two keys:
            'Articles': list of per-article dicts (class name -> boxes),
                with each article's 'Text Block' list re-ordered by
                textblock_ordering.
            'Extra': list of single-entry dicts {class name: [box]} for
                every object that was not assigned to any article.
    """
    articles, objects = get_hierarchy(labels)

    # Put the text blocks of every article into reading order.
    for entry in articles:
        entry['Text Block'] = textblock_ordering(entry['Text Block'], img)

    # Collect the objects that no article claimed.
    extras = []
    for obj_key, record in objects.items():
        if record[0] != 0:
            continue
        print("Extra: ", obj_key)
        extras.append({class_id_to_name[record[2]]: [record[1]]})

    # Resulting structure:
    # {'Articles': [{Headlines: [...], Sub-headlines: [...], ...}, ...],
    #  'Extra': [{class name: [box]}, ...]}
    return {'Articles': articles, 'Extra': extras}
# if __name__ == '__main__': | |
# label_path = '/Users/deveshpant/Work/WadhwaniAI/IDSP/eNewspaperPDFs/Language_wise/Results/pred/Hindi2/labels/_Jansatta-Delhi 15-11_5.txt' | |
# img_path = '/Users/deveshpant/Work/WadhwaniAI/IDSP/eNewspaperPDFs/Language_wise/Language_wise_imgs/Hindi/_Jansatta-Delhi 15-11_5.png' | |
# json_dict = get_ordered_data(label_path, img_path) | |
# # dump the json | |
# with open('json_dict.json', 'w') as f: | |
# json.dump(json_dict, f) | |
# # read the json | |
# with open('json_dict.json', 'r') as f: | |
# json_dict = json.load(f) | |
# visualize the results | |
img = cv2.imread(img_path) | |