ragavsachdeva commited on
Commit
b059ef6
·
verified ·
1 Parent(s): 82dec45

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +413 -0
utils.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.patches as patches
6
+ from shapely.geometry import Point, box
7
+ import networkx as nx
8
+ from copy import deepcopy
9
+ from itertools import groupby
10
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError
11
+
12
+ def move_to_device(inputs, device):
13
+ if hasattr(inputs, "keys"):
14
+ return {k: move_to_device(v, device) for k, v in inputs.items()}
15
+ elif isinstance(inputs, list):
16
+ return [move_to_device(v, device) for v in inputs]
17
+ elif isinstance(inputs, tuple):
18
+ return tuple([move_to_device(v, device) for v in inputs])
19
+ elif isinstance(inputs, np.ndarray):
20
+ return torch.from_numpy(inputs).to(device)
21
+ else:
22
+ return inputs.to(device)
23
+
24
+ class UnionFind:
25
+ def __init__(self, n):
26
+ self.parent = list(range(n))
27
+ self.size = [1] * n
28
+ self.num_components = n
29
+
30
+ @classmethod
31
+ def from_adj_matrix(cls, adj_matrix):
32
+ ufds = cls(adj_matrix.shape[0])
33
+ for i in range(adj_matrix.shape[0]):
34
+ for j in range(adj_matrix.shape[1]):
35
+ if adj_matrix[i, j] > 0:
36
+ ufds.unite(i, j)
37
+ return ufds
38
+
39
+ @classmethod
40
+ def from_adj_list(cls, adj_list):
41
+ ufds = cls(len(adj_list))
42
+ for i in range(len(adj_list)):
43
+ for j in adj_list[i]:
44
+ ufds.unite(i, j)
45
+ return ufds
46
+
47
+ @classmethod
48
+ def from_edge_list(cls, edge_list, num_nodes):
49
+ ufds = cls(num_nodes)
50
+ for edge in edge_list:
51
+ ufds.unite(edge[0], edge[1])
52
+ return ufds
53
+
54
+ def find(self, x):
55
+ if self.parent[x] == x:
56
+ return x
57
+ self.parent[x] = self.find(self.parent[x])
58
+ return self.parent[x]
59
+
60
+ def unite(self, x, y):
61
+ x = self.find(x)
62
+ y = self.find(y)
63
+ if x != y:
64
+ if self.size[x] < self.size[y]:
65
+ x, y = y, x
66
+ self.parent[y] = x
67
+ self.size[x] += self.size[y]
68
+ self.num_components -= 1
69
+
70
+ def get_components_of(self, x):
71
+ x = self.find(x)
72
+ return [i for i in range(len(self.parent)) if self.find(i) == x]
73
+
74
+ def are_connected(self, x, y):
75
+ return self.find(x) == self.find(y)
76
+
77
+ def get_size(self, x):
78
+ return self.size[self.find(x)]
79
+
80
+ def get_num_components(self):
81
+ return self.num_components
82
+
83
+ def get_labels_for_connected_components(self):
84
+ map_parent_to_label = {}
85
+ labels = []
86
+ for i in range(len(self.parent)):
87
+ parent = self.find(i)
88
+ if parent not in map_parent_to_label:
89
+ map_parent_to_label[parent] = len(map_parent_to_label)
90
+ labels.append(map_parent_to_label[parent])
91
+ return labels
92
+
93
+ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
94
+ h, w = image_as_np_array.shape[:2]
95
+ if h > w:
96
+ figure, subplot = plt.subplots(1, 1, figsize=(10, 10 * h / w))
97
+ else:
98
+ figure, subplot = plt.subplots(1, 1, figsize=(10 * w / h, 10))
99
+ subplot.imshow(image_as_np_array)
100
+ plot_bboxes(subplot, predictions["panels"], color="green")
101
+ plot_bboxes(subplot, predictions["texts"], color="red", add_index=True)
102
+ plot_bboxes(subplot, predictions["characters"], color="blue")
103
+
104
+ COLOURS = [
105
+ "#b7ff51", # green
106
+ "#f50a8f", # pink
107
+ "#4b13b6", # purple
108
+ "#ddaa34", # orange
109
+ "#bea2a2", # brown
110
+ ]
111
+ colour_index = 0
112
+ character_cluster_labels = predictions["character_cluster_labels"]
113
+ unique_label_sorted_by_frequency = sorted(list(set(character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
114
+ for label in unique_label_sorted_by_frequency:
115
+ root = None
116
+ others = []
117
+ for i in range(len(predictions["characters"])):
118
+ if character_cluster_labels[i] == label:
119
+ if root is None:
120
+ root = i
121
+ else:
122
+ others.append(i)
123
+ if colour_index >= len(COLOURS):
124
+ random_colour = COLOURS[0]
125
+ while random_colour in COLOURS:
126
+ random_colour = "#" + "".join([random.choice("0123456789ABCDEF") for j in range(6)])
127
+ else:
128
+ random_colour = COLOURS[colour_index]
129
+ colour_index += 1
130
+ bbox_i = predictions["characters"][root]
131
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
132
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
133
+ subplot.plot([x1], [y1], color=random_colour, marker="o", markersize=5)
134
+ for j in others:
135
+ # draw line from centre of bbox i to centre of bbox j
136
+ bbox_j = predictions["characters"][j]
137
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
138
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
139
+ x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
140
+ y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
141
+ subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
142
+ subplot.plot([x2], [y2], color=random_colour, marker="o", markersize=5)
143
+
144
+ for (i, j) in predictions["text_character_associations"]:
145
+ score = predictions["dialog_confidences"][i]
146
+ bbox_i = predictions["texts"][i]
147
+ bbox_j = predictions["characters"][j]
148
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
149
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
150
+ x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
151
+ y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
152
+ subplot.plot([x1, x2], [y1, y2], color="red", linewidth=2, linestyle="dashed", alpha=score)
153
+
154
+ subplot.axis("off")
155
+ if filename is not None:
156
+ plt.savefig(filename, bbox_inches="tight", pad_inches=0)
157
+
158
+ figure.canvas.draw()
159
+ image = np.array(figure.canvas.renderer._renderer)
160
+ plt.close()
161
+ return image
162
+
163
+ def plot_bboxes(subplot, bboxes, color="red", add_index=False):
164
+ for id, bbox in enumerate(bboxes):
165
+ w = bbox[2] - bbox[0]
166
+ h = bbox[3] - bbox[1]
167
+ rect = patches.Rectangle(
168
+ bbox[:2], w, h, linewidth=1, edgecolor=color, facecolor="none", linestyle="solid"
169
+ )
170
+ subplot.add_patch(rect)
171
+ if add_index:
172
+ cx, cy = bbox[0] + w / 2, bbox[1] + h / 2
173
+ subplot.text(cx, cy, str(id), color=color, fontsize=10, ha="center", va="center")
174
+
175
+ def sort_panels(rects):
176
+ before_rects = convert_to_list_of_lists(rects)
177
+ # slightly erode all rectangles initially to account for imperfect detections
178
+ rects = [erode_rectangle(rect, 0.05) for rect in before_rects]
179
+ G = nx.DiGraph()
180
+ G.add_nodes_from(range(len(rects)))
181
+ for i in range(len(rects)):
182
+ for j in range(len(rects)):
183
+ if i == j:
184
+ continue
185
+ if is_there_a_directed_edge(i, j, rects):
186
+ G.add_edge(i, j, weight=get_distance(rects[i], rects[j]))
187
+ else:
188
+ G.add_edge(j, i, weight=get_distance(rects[i], rects[j]))
189
+ while True:
190
+ with ThreadPoolExecutor(max_workers=1) as executor:
191
+ future = executor.submit(list, nx.simple_cycles(G))
192
+ try:
193
+ cycles = future.result(timeout=60)
194
+ except TimeoutError:
195
+ print("Cycle finding timed out after 60 seconds")
196
+ return list(range(len(rects)))
197
+ cycles = [cycle for cycle in cycles if len(cycle) > 1]
198
+ if len(cycles) == 0:
199
+ break
200
+ cycle = cycles[0]
201
+ edges = [e for e in zip(cycle, cycle[1:] + cycle[:1])]
202
+ max_cyclic_edge = max(edges, key=lambda x: G.edges[x]["weight"])
203
+ G.remove_edge(*max_cyclic_edge)
204
+ return list(nx.topological_sort(G))
205
+
206
+ def is_strictly_above(rectA, rectB):
207
+ x1A, y1A, x2A, y2A = rectA
208
+ x1B, y1B, x2B, y2B = rectB
209
+ return y2A < y1B
210
+
211
+ def is_strictly_below(rectA, rectB):
212
+ x1A, y1A, x2A, y2A = rectA
213
+ x1B, y1B, x2B, y2B = rectB
214
+ return y2B < y1A
215
+
216
+ def is_strictly_left_of(rectA, rectB):
217
+ x1A, y1A, x2A, y2A = rectA
218
+ x1B, y1B, x2B, y2B = rectB
219
+ return x2A < x1B
220
+
221
+ def is_strictly_right_of(rectA, rectB):
222
+ x1A, y1A, x2A, y2A = rectA
223
+ x1B, y1B, x2B, y2B = rectB
224
+ return x2B < x1A
225
+
226
+ def intersects(rectA, rectB):
227
+ return box(*rectA).intersects(box(*rectB))
228
+
229
+ def is_there_a_directed_edge(a, b, rects):
230
+ rectA = rects[a]
231
+ rectB = rects[b]
232
+ centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2, rectA[1] + (rectA[3] - rectA[1]) / 2]
233
+ centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2, rectB[1] + (rectB[3] - rectB[1]) / 2]
234
+ if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
235
+ return box(*rectA).area > (box(*rectB)).area
236
+ copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
237
+ copy_B = [rectB[0], rectB[1], rectB[2], rectB[3]]
238
+ while True:
239
+ if is_strictly_above(copy_A, copy_B) and not is_strictly_left_of(copy_A, copy_B):
240
+ return 1
241
+ if is_strictly_above(copy_B, copy_A) and not is_strictly_left_of(copy_B, copy_A):
242
+ return 0
243
+ if is_strictly_right_of(copy_A, copy_B) and not is_strictly_below(copy_A, copy_B):
244
+ return 1
245
+ if is_strictly_right_of(copy_B, copy_A) and not is_strictly_below(copy_B, copy_A):
246
+ return 0
247
+ if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
248
+ return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
249
+ if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
250
+ return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
251
+ # otherwise they intersect
252
+ copy_A = erode_rectangle(copy_A, 0.05)
253
+ copy_B = erode_rectangle(copy_B, 0.05)
254
+
255
+ def get_distance(rectA, rectB):
256
+ return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))
257
+
258
+ def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
259
+ rects = deepcopy(rects)
260
+ while True:
261
+ xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
262
+ rect_index = [i for i in range(len(rects)) if intersects(rects[i], [xmin, ymin, xmax, ymax])]
263
+ rects_copy = [rect for rect in rects if intersects(rect, [xmin, ymin, xmax, ymax])]
264
+
265
+ # try to split the panels using a "horizontal" lines
266
+ overlapping_y_ranges = merge_overlapping_ranges([(y1, y2) for x1, y1, x2, y2 in rects_copy])
267
+ panel_index_to_split = {}
268
+ for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
269
+ for i, index in enumerate(rect_index):
270
+ if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
271
+ panel_index_to_split[index] = split_index
272
+
273
+ if panel_index_to_split[a] != panel_index_to_split[b]:
274
+ return panel_index_to_split[a] < panel_index_to_split[b]
275
+
276
+ # try to split the panels using a "vertical" lines
277
+ overlapping_x_ranges = merge_overlapping_ranges([(x1, x2) for x1, y1, x2, y2 in rects_copy])
278
+ panel_index_to_split = {}
279
+ for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
280
+ for i, index in enumerate(rect_index):
281
+ if x1 <= rects_copy[i][0] <= rects_copy[i][2] <= x2:
282
+ panel_index_to_split[index] = split_index
283
+ if panel_index_to_split[a] != panel_index_to_split[b]:
284
+ return panel_index_to_split[a] < panel_index_to_split[b]
285
+
286
+ # otherwise, erode the rectangles and try again
287
+ rects = [erode_rectangle(rect, 0.05) for rect in rects]
288
+
289
+ def erode_rectangle(bbox, erosion_factor):
290
+ x1, y1, x2, y2 = bbox
291
+ w, h = x2 - x1, y2 - y1
292
+ cx, cy = x1 + w / 2, y1 + h / 2
293
+ if w < h:
294
+ aspect_ratio = w / h
295
+ erosion_factor_width = erosion_factor * aspect_ratio
296
+ erosion_factor_height = erosion_factor
297
+ else:
298
+ aspect_ratio = h / w
299
+ erosion_factor_width = erosion_factor
300
+ erosion_factor_height = erosion_factor * aspect_ratio
301
+ w = w - w * erosion_factor_width
302
+ h = h - h * erosion_factor_height
303
+ x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
304
+ return [x1, y1, x2, y2]
305
+
306
+ def merge_overlapping_ranges(ranges):
307
+ """
308
+ ranges: list of tuples (x1, x2)
309
+ """
310
+ if len(ranges) == 0:
311
+ return []
312
+ ranges = sorted(ranges, key=lambda x: x[0])
313
+ merged_ranges = []
314
+ for i, r in enumerate(ranges):
315
+ if i == 0:
316
+ prev_x1, prev_x2 = r
317
+ continue
318
+ x1, x2 = r
319
+ if x1 > prev_x2:
320
+ merged_ranges.append((prev_x1, prev_x2))
321
+ prev_x1, prev_x2 = x1, x2
322
+ else:
323
+ prev_x2 = max(prev_x2, x2)
324
+ merged_ranges.append((prev_x1, prev_x2))
325
+ return merged_ranges
326
+
327
+ def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
328
+ text_bboxes = convert_to_list_of_lists(text_bboxes)
329
+ sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
330
+
331
+ if len(text_bboxes) == 0:
332
+ return []
333
+
334
+ def indices_of_same_elements(nums):
335
+ groups = groupby(range(len(nums)), key=lambda i: nums[i])
336
+ return [list(indices) for _, indices in groups]
337
+
338
+ panel_id_for_text = get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes)
339
+ indices_of_texts = list(range(len(text_bboxes)))
340
+ indices_of_texts, panel_id_for_text = zip(*sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
341
+ indices_of_texts = list(indices_of_texts)
342
+ grouped_indices = indices_of_same_elements(panel_id_for_text)
343
+ for group in grouped_indices:
344
+ subset_of_text_indices = [indices_of_texts[i] for i in group]
345
+ text_bboxes_of_subset = [text_bboxes[i] for i in subset_of_text_indices]
346
+ sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
347
+ indices_of_texts[group[0] : group[-1] + 1] = [subset_of_text_indices[i] for i in sorted_subset_indices]
348
+ return indices_of_texts
349
+
350
+ def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
351
+ text_to_panel_mapping = []
352
+ for text_bbox in text_bboxes:
353
+ shapely_text_polygon = box(*text_bbox)
354
+ all_intersections = []
355
+ all_distances = []
356
+ if len(sorted_panel_bboxes) == 0:
357
+ text_to_panel_mapping.append(-1)
358
+ continue
359
+ for j, annotation in enumerate(sorted_panel_bboxes):
360
+ shapely_annotation_polygon = box(*annotation)
361
+ if shapely_text_polygon.intersects(shapely_annotation_polygon):
362
+ all_intersections.append((shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
363
+ all_distances.append((shapely_text_polygon.distance(shapely_annotation_polygon), j))
364
+ if len(all_intersections) == 0:
365
+ text_to_panel_mapping.append(min(all_distances, key=lambda x: x[0])[1])
366
+ else:
367
+ text_to_panel_mapping.append(max(all_intersections, key=lambda x: x[0])[1])
368
+ return text_to_panel_mapping
369
+
370
+ def sort_texts_within_panel(rects):
371
+ smallest_y = float("inf")
372
+ greatest_x = float("-inf")
373
+ for i, rect in enumerate(rects):
374
+ x1, y1, x2, y2 = rect
375
+ smallest_y = min(smallest_y, y1)
376
+ greatest_x = max(greatest_x, x2)
377
+
378
+ reference_point = Point(greatest_x, smallest_y)
379
+
380
+ polygons_and_index = []
381
+ for i, rect in enumerate(rects):
382
+ x1, y1, x2, y2 = rect
383
+ polygons_and_index.append((box(x1,y1,x2,y2), i))
384
+ # sort points by closest to reference point
385
+ polygons_and_index = sorted(polygons_and_index, key=lambda x: reference_point.distance(x[0]))
386
+ indices = [x[1] for x in polygons_and_index]
387
+ return indices
388
+
389
+ def force_to_be_valid_bboxes(bboxes):
390
+ if len(bboxes) == 0:
391
+ return bboxes
392
+ bboxes_as_xywh = [[x1, y1, x2-x1, y2-y1] for x1, y1, x2, y2 in bboxes]
393
+ bboxes_as_xywh = torch.tensor(bboxes_as_xywh)
394
+ bboxes_as_xywh[:, 2] = torch.clamp(bboxes_as_xywh[:, 2], min=1)
395
+ bboxes_as_xywh[:, 3] = torch.clamp(bboxes_as_xywh[:, 3], min=1)
396
+ bboxes_as_xywh = bboxes_as_xywh.tolist()
397
+ bboxes_as_xyxy = [[x1, y1, x1 + w, y1 + h] for x1, y1, w, h in bboxes_as_xywh]
398
+ return bboxes_as_xyxy
399
+
400
+ def x1y1wh_to_x1y1x2y2(bbox):
401
+ x1, y1, w, h = bbox
402
+ return [x1, y1, x1 + w, y1 + h]
403
+
404
+ def x1y1x2y2_to_xywh(bbox):
405
+ x1, y1, x2, y2 = bbox
406
+ return [x1, y1, x2 - x1, y2 - y1]
407
+
408
+ def convert_to_list_of_lists(rects):
409
+ if isinstance(rects, torch.Tensor):
410
+ return rects.tolist()
411
+ if isinstance(rects, np.ndarray):
412
+ return rects.tolist()
413
+ return [[a, b, c, d] for a, b, c, d in rects]