File size: 5,571 Bytes
1b870f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import cv2
import json
import numpy as np

# Detector class id -> layout class name. The position in the list is the
# class id, so the order of the names below must not change.
class_id_to_name = dict(enumerate([
    "Articles",
    "Advertisement",
    "Headlines",
    "Sub-headlines",
    "Graphics",
    "Images",
    "Tables",
    "Text Block",
    "Header",
]))


def findIntersection(box1, box2):
    """
    Fraction of box1's area that lies inside box2.

    This is NOT a symmetric intersection-over-union: the intersection area
    is divided by the area of box1 only, so a small box fully contained in
    a large one scores 1.0 while the reverse call scores much less.

    args:
        box1: [x1, y1, x2, y2] corner coordinates of the object box
        box2: [x1, y1, x2, y2] corner coordinates of the container box

    returns:
        float: intersection area / box1 area; 0.0 when box1 has no
               positive area or the boxes do not overlap
    """
    w1 = box1[2] - box1[0]
    h1 = box1[3] - box1[1]

    # A degenerate (or inverted) box1 has no area to cover; this also avoids
    # dividing by zero below.
    if w1 <= 0 or h1 <= 0:
        return 0.0

    # Work directly on the corners instead of converting to width/height and
    # back, which only adds floating-point round-off.
    inter_w = min(box1[2], box2[2]) - max(box1[0], box2[0])
    inter_h = min(box1[3], box2[3]) - max(box1[1], box2[1])
    if inter_w <= 0 or inter_h <= 0:
        return 0.0

    return (inter_w * inter_h) / (w1 * h1)
                
def get_hierarchy(labels):
    """
    Group detected layout objects under the article boxes that contain them.

    args:
        labels: iterable of detections, each indexed as
                label[0] -> box [x1, y1, x2, y2]
                label[1] -> confidence (not used here)
                label[2] -> class id (a key of class_id_to_name)

    returns:
        article_list: one dict per class-0 ("Articles") detection, mapping
                      class name -> list of boxes whose area lies more than
                      70% inside that article. An article always contains
                      itself under 'Articles'.
        object_dict:  {obj_key: [assigned, [x1, y1, x2, y2], class_id]} where
                      obj_key = class_id * 1000000 + index of the detection
                      within its class, and assigned is 1 if the object was
                      placed in at least one article, else 0.
    """
    num_classes = len(class_id_to_name)
    class_list = [[] for _ in range(num_classes)]
    object_dict = {}

    for label in labels:
        x1, y1, x2, y2 = label[0][0], label[0][1], label[0][2], label[0][3]
        class_id = int(label[2])
        class_list[class_id].append([x1, y1, x2, y2])
        obj_key = class_id * 1000000 + len(class_list[class_id]) - 1
        object_dict[obj_key] = [0, [x1, y1, x2, y2], class_id]

    # Minimum fraction of an object's area that must fall inside an article.
    overlap_thresh = 0.70

    article_list = []
    for article in class_list[0]:
        # Key order is part of the output format; "Header" is included so a
        # header overlapping an article no longer raises KeyError.
        article_dict = {
            'Articles': [], 'Headlines': [], 'Sub-headlines': [],
            'Graphics': [], 'Images': [], 'Text Block': [],
            "Advertisement": [], "Tables": [], "Header": [],
        }
        for class_id in range(num_classes):
            key = class_id_to_name[class_id]
            # enumerate() yields the exact index used to build obj_key above;
            # list.index() would return the first equal box and leave
            # duplicate detections flagged as unassigned.
            for idx, obj in enumerate(class_list[class_id]):
                if findIntersection(obj, article) > overlap_thresh:
                    article_dict[key].append(obj)
                    object_dict[class_id * 1000000 + idx][0] = 1

        article_list.append(article_dict)

    return article_list, object_dict
                
def textblock_ordering(textblocks, img):
    """
    Order text-block boxes column by column, top to bottom within a column.

    Boxes are sorted by x1 and greedily grouped into columns: a box joins
    the first column whose left-most box starts within 15% of the image
    width of its own x1, otherwise it opens a new column. Each column is
    then sorted by y1 and the columns are concatenated left to right.

    args:
        textblocks: iterable of [x1, y1, x2, y2] boxes (values convertible
                    to float)
        img: image object; only img.shape[1] (the width) is read

    returns:
        list of [x1, y1, x2, y2] float lists in reading order; every input
        box yields its own new list
    """
    boxes = [[float(v) for v in obj] for obj in textblocks]
    for box in boxes:
        # Same contract as unpacking into x1, y1, x2, y2.
        if len(box) != 4:
            raise ValueError("each textblock must have exactly 4 coordinates")

    # Left-to-right, so the first box of every column is its left-most one.
    boxes.sort(key=lambda b: b[0])

    # Column tolerance: 15% of the image width, in pixels.
    bucket_size = int(0.15 * img.shape[1])

    columns = []
    for box in boxes:
        for column in columns:
            if abs(column[0][0] - box[0]) < bucket_size:
                column.append(box)
                break
        else:
            # No existing column is close enough (also covers the first box).
            columns.append([box])

    # Top-to-bottom inside each column; list.sort is stable, so boxes with
    # equal y1 keep their x1 order.
    for column in columns:
        column.sort(key=lambda b: b[1])

    return [box for column in columns for box in column]

def get_ordered_data(labels, img):
    """
    Build the article hierarchy for a page and order each article's text.

    args:
        labels: detections, passed unchanged to get_hierarchy
        img: page image, passed to textblock_ordering

    returns:
        dict with two keys:
            'Articles': list of per-article dicts
                        {class name: [box, ...], ...} whose 'Text Block'
                        list has been put through textblock_ordering
            'Extra':    list of {class name: [box]} for every object that
                        was not assigned to any article
    """
    articles, objects = get_hierarchy(labels)

    # Replace each article's text blocks with their ordered version.
    for entry in articles:
        entry['Text Block'] = textblock_ordering(entry['Text Block'], img)

    # Collect objects whose "assigned" flag was never set.
    extras = []
    for obj_key, (assigned, box, class_id) in objects.items():
        if assigned == 0:
            print("Extra: ", obj_key)
            extras.append({class_id_to_name[class_id]: [box]})

    return {'Articles': articles, 'Extra': extras}
        
# if __name__ == '__main__':
#     label_path = '/Users/deveshpant/Work/WadhwaniAI/IDSP/eNewspaperPDFs/Language_wise/Results/pred/Hindi2/labels/_Jansatta-Delhi  15-11_5.txt'  
#     img_path = '/Users/deveshpant/Work/WadhwaniAI/IDSP/eNewspaperPDFs/Language_wise/Language_wise_imgs/Hindi/_Jansatta-Delhi  15-11_5.png'
#     json_dict = get_ordered_data(label_path, img_path)
    
#     # dump the json
#     with open('json_dict.json', 'w') as f:
#         json.dump(json_dict, f)
        
#     # read the json
#     with open('json_dict.json', 'r') as f:
#         json_dict = json.load(f)
        
    # visualize the results
    img = cv2.imread(img_path)