ragavsachdeva
commited on
Commit
•
f7499c0
1
Parent(s):
7c1657f
Upload model
Browse files- config.json +490 -0
- configuration_magiv2.py +38 -0
- modelling_magiv2.py +612 -0
- processing_magiv2.py +209 -0
- pytorch_model.bin +3 -0
- utils.py +411 -0
config.json
ADDED
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "/work/rs/logs/magiv2/to_release_polished",
|
3 |
+
"architectures": [
|
4 |
+
"Magiv2Model"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "configuration_magiv2.Magiv2Config",
|
8 |
+
"AutoModel": "modelling_magiv2.Magiv2Model"
|
9 |
+
},
|
10 |
+
"crop_embedding_image_preprocessing_config": {
|
11 |
+
"_processor_class": null,
|
12 |
+
"do_normalize": true,
|
13 |
+
"do_rescale": true,
|
14 |
+
"do_resize": true,
|
15 |
+
"image_mean": [
|
16 |
+
0.485,
|
17 |
+
0.456,
|
18 |
+
0.406
|
19 |
+
],
|
20 |
+
"image_processor_type": "ViTImageProcessor",
|
21 |
+
"image_std": [
|
22 |
+
0.229,
|
23 |
+
0.224,
|
24 |
+
0.225
|
25 |
+
],
|
26 |
+
"resample": 2,
|
27 |
+
"rescale_factor": 0.00392156862745098,
|
28 |
+
"size": {
|
29 |
+
"height": 224,
|
30 |
+
"width": 224
|
31 |
+
}
|
32 |
+
},
|
33 |
+
"crop_embedding_model_config": {
|
34 |
+
"_name_or_path": "facebook/vit-mae-base",
|
35 |
+
"add_cross_attention": false,
|
36 |
+
"architectures": [
|
37 |
+
"ViTMAEForPreTraining"
|
38 |
+
],
|
39 |
+
"attention_probs_dropout_prob": 0.0,
|
40 |
+
"bad_words_ids": null,
|
41 |
+
"begin_suppress_tokens": null,
|
42 |
+
"bos_token_id": null,
|
43 |
+
"chunk_size_feed_forward": 0,
|
44 |
+
"cross_attention_hidden_size": null,
|
45 |
+
"decoder_hidden_size": 512,
|
46 |
+
"decoder_intermediate_size": 2048,
|
47 |
+
"decoder_num_attention_heads": 16,
|
48 |
+
"decoder_num_hidden_layers": 8,
|
49 |
+
"decoder_start_token_id": null,
|
50 |
+
"diversity_penalty": 0.0,
|
51 |
+
"do_sample": false,
|
52 |
+
"early_stopping": false,
|
53 |
+
"encoder_no_repeat_ngram_size": 0,
|
54 |
+
"eos_token_id": null,
|
55 |
+
"exponential_decay_length_penalty": null,
|
56 |
+
"finetuning_task": null,
|
57 |
+
"forced_bos_token_id": null,
|
58 |
+
"forced_eos_token_id": null,
|
59 |
+
"hidden_act": "gelu",
|
60 |
+
"hidden_dropout_prob": 0.0,
|
61 |
+
"hidden_size": 768,
|
62 |
+
"id2label": {
|
63 |
+
"0": "LABEL_0",
|
64 |
+
"1": "LABEL_1"
|
65 |
+
},
|
66 |
+
"image_size": 224,
|
67 |
+
"initializer_range": 0.02,
|
68 |
+
"intermediate_size": 3072,
|
69 |
+
"is_decoder": false,
|
70 |
+
"is_encoder_decoder": false,
|
71 |
+
"label2id": {
|
72 |
+
"LABEL_0": 0,
|
73 |
+
"LABEL_1": 1
|
74 |
+
},
|
75 |
+
"layer_norm_eps": 1e-12,
|
76 |
+
"length_penalty": 1.0,
|
77 |
+
"mask_ratio": 0.75,
|
78 |
+
"max_length": 20,
|
79 |
+
"min_length": 0,
|
80 |
+
"model_type": "",
|
81 |
+
"no_repeat_ngram_size": 0,
|
82 |
+
"norm_pix_loss": false,
|
83 |
+
"num_attention_heads": 12,
|
84 |
+
"num_beam_groups": 1,
|
85 |
+
"num_beams": 1,
|
86 |
+
"num_channels": 3,
|
87 |
+
"num_hidden_layers": 12,
|
88 |
+
"num_return_sequences": 1,
|
89 |
+
"output_attentions": false,
|
90 |
+
"output_hidden_states": false,
|
91 |
+
"output_scores": false,
|
92 |
+
"pad_token_id": null,
|
93 |
+
"patch_size": 16,
|
94 |
+
"prefix": null,
|
95 |
+
"problem_type": null,
|
96 |
+
"pruned_heads": {},
|
97 |
+
"qkv_bias": true,
|
98 |
+
"remove_invalid_values": false,
|
99 |
+
"repetition_penalty": 1.0,
|
100 |
+
"return_dict": true,
|
101 |
+
"return_dict_in_generate": false,
|
102 |
+
"sep_token_id": null,
|
103 |
+
"suppress_tokens": null,
|
104 |
+
"task_specific_params": null,
|
105 |
+
"temperature": 1.0,
|
106 |
+
"tf_legacy_loss": false,
|
107 |
+
"tie_encoder_decoder": false,
|
108 |
+
"tie_word_embeddings": true,
|
109 |
+
"tokenizer_class": null,
|
110 |
+
"top_k": 50,
|
111 |
+
"top_p": 1.0,
|
112 |
+
"torch_dtype": "float32",
|
113 |
+
"torchscript": false,
|
114 |
+
"typical_p": 1.0,
|
115 |
+
"use_bfloat16": false
|
116 |
+
},
|
117 |
+
"detection_image_preprocessing_config": {
|
118 |
+
"_processor_class": null,
|
119 |
+
"do_normalize": true,
|
120 |
+
"do_pad": true,
|
121 |
+
"do_rescale": true,
|
122 |
+
"do_resize": true,
|
123 |
+
"format": "coco_detection",
|
124 |
+
"image_mean": [
|
125 |
+
0.485,
|
126 |
+
0.456,
|
127 |
+
0.406
|
128 |
+
],
|
129 |
+
"image_processor_type": "ConditionalDetrImageProcessor",
|
130 |
+
"image_std": [
|
131 |
+
0.229,
|
132 |
+
0.224,
|
133 |
+
0.225
|
134 |
+
],
|
135 |
+
"resample": 2,
|
136 |
+
"rescale_factor": 0.00392156862745098,
|
137 |
+
"size": {
|
138 |
+
"longest_edge": 1333,
|
139 |
+
"shortest_edge": 800
|
140 |
+
}
|
141 |
+
},
|
142 |
+
"detection_model_config": {
|
143 |
+
"_name_or_path": "microsoft/conditional-detr-resnet-50",
|
144 |
+
"activation_dropout": 0.0,
|
145 |
+
"activation_function": "relu",
|
146 |
+
"add_cross_attention": false,
|
147 |
+
"architectures": [
|
148 |
+
"ConditionalDETRForObjectDetection"
|
149 |
+
],
|
150 |
+
"attention_dropout": 0.0,
|
151 |
+
"auxiliary_loss": false,
|
152 |
+
"backbone": "resnet50",
|
153 |
+
"backbone_config": null,
|
154 |
+
"bad_words_ids": null,
|
155 |
+
"bbox_cost": 5,
|
156 |
+
"bbox_loss_coefficient": 5,
|
157 |
+
"begin_suppress_tokens": null,
|
158 |
+
"bos_token_id": null,
|
159 |
+
"chunk_size_feed_forward": 0,
|
160 |
+
"class_cost": 2,
|
161 |
+
"cls_loss_coefficient": 2,
|
162 |
+
"cross_attention_hidden_size": null,
|
163 |
+
"d_model": 256,
|
164 |
+
"decoder_attention_heads": 8,
|
165 |
+
"decoder_ffn_dim": 2048,
|
166 |
+
"decoder_layerdrop": 0.0,
|
167 |
+
"decoder_layers": 6,
|
168 |
+
"decoder_start_token_id": null,
|
169 |
+
"dice_loss_coefficient": 1,
|
170 |
+
"dilation": false,
|
171 |
+
"diversity_penalty": 0.0,
|
172 |
+
"do_sample": false,
|
173 |
+
"dropout": 0.1,
|
174 |
+
"early_stopping": false,
|
175 |
+
"encoder_attention_heads": 8,
|
176 |
+
"encoder_ffn_dim": 2048,
|
177 |
+
"encoder_layerdrop": 0.0,
|
178 |
+
"encoder_layers": 6,
|
179 |
+
"encoder_no_repeat_ngram_size": 0,
|
180 |
+
"eos_token_id": null,
|
181 |
+
"exponential_decay_length_penalty": null,
|
182 |
+
"finetuning_task": null,
|
183 |
+
"focal_alpha": 0.25,
|
184 |
+
"forced_bos_token_id": null,
|
185 |
+
"forced_eos_token_id": null,
|
186 |
+
"giou_cost": 2,
|
187 |
+
"giou_loss_coefficient": 2,
|
188 |
+
"id2label": {
|
189 |
+
"0": "LABEL_0",
|
190 |
+
"1": "LABEL_1",
|
191 |
+
"2": "LABEL_2",
|
192 |
+
"3": "LABEL_3"
|
193 |
+
},
|
194 |
+
"init_std": 0.02,
|
195 |
+
"init_xavier_std": 1.0,
|
196 |
+
"is_decoder": false,
|
197 |
+
"is_encoder_decoder": true,
|
198 |
+
"label2id": {
|
199 |
+
"LABEL_0": 0,
|
200 |
+
"LABEL_1": 1,
|
201 |
+
"LABEL_2": 2,
|
202 |
+
"LABEL_3": 3
|
203 |
+
},
|
204 |
+
"length_penalty": 1.0,
|
205 |
+
"mask_loss_coefficient": 1,
|
206 |
+
"max_length": 20,
|
207 |
+
"max_position_embeddings": 1024,
|
208 |
+
"min_length": 0,
|
209 |
+
"model_type": "",
|
210 |
+
"no_repeat_ngram_size": 0,
|
211 |
+
"num_beam_groups": 1,
|
212 |
+
"num_beams": 1,
|
213 |
+
"num_channels": 3,
|
214 |
+
"num_hidden_layers": 6,
|
215 |
+
"num_queries": 305,
|
216 |
+
"num_return_sequences": 1,
|
217 |
+
"output_attentions": false,
|
218 |
+
"output_hidden_states": false,
|
219 |
+
"output_scores": false,
|
220 |
+
"pad_token_id": null,
|
221 |
+
"position_embedding_type": "sine",
|
222 |
+
"prefix": null,
|
223 |
+
"problem_type": null,
|
224 |
+
"pruned_heads": {},
|
225 |
+
"remove_invalid_values": false,
|
226 |
+
"repetition_penalty": 1.0,
|
227 |
+
"return_dict": true,
|
228 |
+
"return_dict_in_generate": false,
|
229 |
+
"scale_embedding": false,
|
230 |
+
"sep_token_id": null,
|
231 |
+
"suppress_tokens": null,
|
232 |
+
"task_specific_params": null,
|
233 |
+
"temperature": 1.0,
|
234 |
+
"tf_legacy_loss": false,
|
235 |
+
"tie_encoder_decoder": false,
|
236 |
+
"tie_word_embeddings": true,
|
237 |
+
"tokenizer_class": null,
|
238 |
+
"top_k": 50,
|
239 |
+
"top_p": 1.0,
|
240 |
+
"torch_dtype": "float32",
|
241 |
+
"torchscript": false,
|
242 |
+
"typical_p": 1.0,
|
243 |
+
"use_bfloat16": false,
|
244 |
+
"use_pretrained_backbone": true,
|
245 |
+
"use_timm_backbone": true
|
246 |
+
},
|
247 |
+
"disable_crop_embeddings": false,
|
248 |
+
"disable_detections": false,
|
249 |
+
"disable_ocr": false,
|
250 |
+
"kwargs": {
|
251 |
+
"_commit_hash": null,
|
252 |
+
"_name_or_path": "ragavsachdeva/magiv2",
|
253 |
+
"architectures": [
|
254 |
+
"Magiv2Model"
|
255 |
+
],
|
256 |
+
"auto_map": {
|
257 |
+
"AutoConfig": "configuration_magiv2.Magiv2Config",
|
258 |
+
"AutoModel": "modelling_magiv2.Magiv2Model"
|
259 |
+
},
|
260 |
+
"model_type": "magiv2",
|
261 |
+
"torch_dtype": "float32",
|
262 |
+
"transformers_version": "4.34.0.dev0"
|
263 |
+
},
|
264 |
+
"model_type": "magiv2",
|
265 |
+
"ocr_model_config": {
|
266 |
+
"_name_or_path": "microsoft/trocr-base-printed",
|
267 |
+
"add_cross_attention": false,
|
268 |
+
"architectures": [
|
269 |
+
"VisionEncoderDecoderModel"
|
270 |
+
],
|
271 |
+
"bad_words_ids": null,
|
272 |
+
"begin_suppress_tokens": null,
|
273 |
+
"bos_token_id": null,
|
274 |
+
"chunk_size_feed_forward": 0,
|
275 |
+
"cross_attention_hidden_size": null,
|
276 |
+
"decoder": {
|
277 |
+
"_name_or_path": "",
|
278 |
+
"activation_dropout": 0.0,
|
279 |
+
"activation_function": "gelu",
|
280 |
+
"add_cross_attention": true,
|
281 |
+
"architectures": null,
|
282 |
+
"attention_dropout": 0.0,
|
283 |
+
"bad_words_ids": null,
|
284 |
+
"begin_suppress_tokens": null,
|
285 |
+
"bos_token_id": 0,
|
286 |
+
"chunk_size_feed_forward": 0,
|
287 |
+
"classifier_dropout": 0.0,
|
288 |
+
"cross_attention_hidden_size": 768,
|
289 |
+
"d_model": 1024,
|
290 |
+
"decoder_attention_heads": 16,
|
291 |
+
"decoder_ffn_dim": 4096,
|
292 |
+
"decoder_layerdrop": 0.0,
|
293 |
+
"decoder_layers": 12,
|
294 |
+
"decoder_start_token_id": 2,
|
295 |
+
"diversity_penalty": 0.0,
|
296 |
+
"do_sample": false,
|
297 |
+
"dropout": 0.1,
|
298 |
+
"early_stopping": false,
|
299 |
+
"encoder_no_repeat_ngram_size": 0,
|
300 |
+
"eos_token_id": 2,
|
301 |
+
"exponential_decay_length_penalty": null,
|
302 |
+
"finetuning_task": null,
|
303 |
+
"forced_bos_token_id": null,
|
304 |
+
"forced_eos_token_id": null,
|
305 |
+
"id2label": {
|
306 |
+
"0": "LABEL_0",
|
307 |
+
"1": "LABEL_1"
|
308 |
+
},
|
309 |
+
"init_std": 0.02,
|
310 |
+
"is_decoder": true,
|
311 |
+
"is_encoder_decoder": false,
|
312 |
+
"label2id": {
|
313 |
+
"LABEL_0": 0,
|
314 |
+
"LABEL_1": 1
|
315 |
+
},
|
316 |
+
"layernorm_embedding": true,
|
317 |
+
"length_penalty": 1.0,
|
318 |
+
"max_length": 20,
|
319 |
+
"max_position_embeddings": 512,
|
320 |
+
"min_length": 0,
|
321 |
+
"model_type": "trocr",
|
322 |
+
"no_repeat_ngram_size": 0,
|
323 |
+
"num_beam_groups": 1,
|
324 |
+
"num_beams": 1,
|
325 |
+
"num_return_sequences": 1,
|
326 |
+
"output_attentions": false,
|
327 |
+
"output_hidden_states": false,
|
328 |
+
"output_scores": false,
|
329 |
+
"pad_token_id": 1,
|
330 |
+
"prefix": null,
|
331 |
+
"problem_type": null,
|
332 |
+
"pruned_heads": {},
|
333 |
+
"remove_invalid_values": false,
|
334 |
+
"repetition_penalty": 1.0,
|
335 |
+
"return_dict": true,
|
336 |
+
"return_dict_in_generate": false,
|
337 |
+
"scale_embedding": false,
|
338 |
+
"sep_token_id": null,
|
339 |
+
"suppress_tokens": null,
|
340 |
+
"task_specific_params": null,
|
341 |
+
"temperature": 1.0,
|
342 |
+
"tf_legacy_loss": false,
|
343 |
+
"tie_encoder_decoder": false,
|
344 |
+
"tie_word_embeddings": true,
|
345 |
+
"tokenizer_class": null,
|
346 |
+
"top_k": 50,
|
347 |
+
"top_p": 1.0,
|
348 |
+
"torch_dtype": null,
|
349 |
+
"torchscript": false,
|
350 |
+
"typical_p": 1.0,
|
351 |
+
"use_bfloat16": false,
|
352 |
+
"use_cache": false,
|
353 |
+
"use_learned_position_embeddings": true,
|
354 |
+
"vocab_size": 50265
|
355 |
+
},
|
356 |
+
"decoder_start_token_id": null,
|
357 |
+
"diversity_penalty": 0.0,
|
358 |
+
"do_sample": false,
|
359 |
+
"early_stopping": false,
|
360 |
+
"encoder": {
|
361 |
+
"_name_or_path": "",
|
362 |
+
"add_cross_attention": false,
|
363 |
+
"architectures": null,
|
364 |
+
"attention_probs_dropout_prob": 0.0,
|
365 |
+
"bad_words_ids": null,
|
366 |
+
"begin_suppress_tokens": null,
|
367 |
+
"bos_token_id": null,
|
368 |
+
"chunk_size_feed_forward": 0,
|
369 |
+
"cross_attention_hidden_size": null,
|
370 |
+
"decoder_start_token_id": null,
|
371 |
+
"diversity_penalty": 0.0,
|
372 |
+
"do_sample": false,
|
373 |
+
"early_stopping": false,
|
374 |
+
"encoder_no_repeat_ngram_size": 0,
|
375 |
+
"encoder_stride": 16,
|
376 |
+
"eos_token_id": null,
|
377 |
+
"exponential_decay_length_penalty": null,
|
378 |
+
"finetuning_task": null,
|
379 |
+
"forced_bos_token_id": null,
|
380 |
+
"forced_eos_token_id": null,
|
381 |
+
"hidden_act": "gelu",
|
382 |
+
"hidden_dropout_prob": 0.0,
|
383 |
+
"hidden_size": 768,
|
384 |
+
"id2label": {
|
385 |
+
"0": "LABEL_0",
|
386 |
+
"1": "LABEL_1"
|
387 |
+
},
|
388 |
+
"image_size": 384,
|
389 |
+
"initializer_range": 0.02,
|
390 |
+
"intermediate_size": 3072,
|
391 |
+
"is_decoder": false,
|
392 |
+
"is_encoder_decoder": false,
|
393 |
+
"label2id": {
|
394 |
+
"LABEL_0": 0,
|
395 |
+
"LABEL_1": 1
|
396 |
+
},
|
397 |
+
"layer_norm_eps": 1e-12,
|
398 |
+
"length_penalty": 1.0,
|
399 |
+
"max_length": 20,
|
400 |
+
"min_length": 0,
|
401 |
+
"model_type": "vit",
|
402 |
+
"no_repeat_ngram_size": 0,
|
403 |
+
"num_attention_heads": 12,
|
404 |
+
"num_beam_groups": 1,
|
405 |
+
"num_beams": 1,
|
406 |
+
"num_channels": 3,
|
407 |
+
"num_hidden_layers": 12,
|
408 |
+
"num_return_sequences": 1,
|
409 |
+
"output_attentions": false,
|
410 |
+
"output_hidden_states": false,
|
411 |
+
"output_scores": false,
|
412 |
+
"pad_token_id": null,
|
413 |
+
"patch_size": 16,
|
414 |
+
"prefix": null,
|
415 |
+
"problem_type": null,
|
416 |
+
"pruned_heads": {},
|
417 |
+
"qkv_bias": false,
|
418 |
+
"remove_invalid_values": false,
|
419 |
+
"repetition_penalty": 1.0,
|
420 |
+
"return_dict": true,
|
421 |
+
"return_dict_in_generate": false,
|
422 |
+
"sep_token_id": null,
|
423 |
+
"suppress_tokens": null,
|
424 |
+
"task_specific_params": null,
|
425 |
+
"temperature": 1.0,
|
426 |
+
"tf_legacy_loss": false,
|
427 |
+
"tie_encoder_decoder": false,
|
428 |
+
"tie_word_embeddings": true,
|
429 |
+
"tokenizer_class": null,
|
430 |
+
"top_k": 50,
|
431 |
+
"top_p": 1.0,
|
432 |
+
"torch_dtype": null,
|
433 |
+
"torchscript": false,
|
434 |
+
"typical_p": 1.0,
|
435 |
+
"use_bfloat16": false
|
436 |
+
},
|
437 |
+
"encoder_no_repeat_ngram_size": 0,
|
438 |
+
"eos_token_id": null,
|
439 |
+
"exponential_decay_length_penalty": null,
|
440 |
+
"finetuning_task": null,
|
441 |
+
"forced_bos_token_id": null,
|
442 |
+
"forced_eos_token_id": null,
|
443 |
+
"id2label": {
|
444 |
+
"0": "LABEL_0",
|
445 |
+
"1": "LABEL_1"
|
446 |
+
},
|
447 |
+
"is_decoder": false,
|
448 |
+
"is_encoder_decoder": true,
|
449 |
+
"label2id": {
|
450 |
+
"LABEL_0": 0,
|
451 |
+
"LABEL_1": 1
|
452 |
+
},
|
453 |
+
"length_penalty": 1.0,
|
454 |
+
"max_length": 20,
|
455 |
+
"min_length": 0,
|
456 |
+
"model_type": "vision-encoder-decoder",
|
457 |
+
"no_repeat_ngram_size": 0,
|
458 |
+
"num_beam_groups": 1,
|
459 |
+
"num_beams": 1,
|
460 |
+
"num_return_sequences": 1,
|
461 |
+
"output_attentions": false,
|
462 |
+
"output_hidden_states": false,
|
463 |
+
"output_scores": false,
|
464 |
+
"pad_token_id": null,
|
465 |
+
"prefix": null,
|
466 |
+
"problem_type": null,
|
467 |
+
"pruned_heads": {},
|
468 |
+
"remove_invalid_values": false,
|
469 |
+
"repetition_penalty": 1.0,
|
470 |
+
"return_dict": true,
|
471 |
+
"return_dict_in_generate": false,
|
472 |
+
"sep_token_id": null,
|
473 |
+
"suppress_tokens": null,
|
474 |
+
"task_specific_params": null,
|
475 |
+
"temperature": 1.0,
|
476 |
+
"tf_legacy_loss": false,
|
477 |
+
"tie_encoder_decoder": false,
|
478 |
+
"tie_word_embeddings": false,
|
479 |
+
"tokenizer_class": null,
|
480 |
+
"top_k": 50,
|
481 |
+
"top_p": 1.0,
|
482 |
+
"torch_dtype": "float32",
|
483 |
+
"torchscript": false,
|
484 |
+
"typical_p": 1.0,
|
485 |
+
"use_bfloat16": false
|
486 |
+
},
|
487 |
+
"ocr_pretrained_processor_path": "microsoft/trocr-base-printed",
|
488 |
+
"torch_dtype": "float32",
|
489 |
+
"transformers_version": "4.34.0.dev0"
|
490 |
+
}
|
configuration_magiv2.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig, VisionEncoderDecoderConfig
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
|
5 |
+
class Magiv2Config(PretrainedConfig):
    """Configuration for the Magiv2 model.

    Bundles the (optional) sub-model configurations — detection transformer,
    OCR encoder-decoder, and crop-embedding encoder — together with their
    image-preprocessing settings and per-sub-model disable flags.
    """

    model_type = "magiv2"

    def __init__(
        self,
        disable_ocr: bool = False,
        disable_crop_embeddings: bool = False,
        disable_detections: bool = False,
        detection_model_config: dict = None,
        ocr_model_config: dict = None,
        crop_embedding_model_config: dict = None,
        detection_image_preprocessing_config: dict = None,
        ocr_pretrained_processor_path: str = None,
        crop_embedding_image_preprocessing_config: dict = None,
        **kwargs,
    ):
        self.disable_ocr = disable_ocr
        self.disable_crop_embeddings = disable_crop_embeddings
        self.disable_detections = disable_detections
        # Remaining keyword arguments are kept verbatim for later inspection.
        self.kwargs = kwargs
        # Rehydrate each sub-model config from its dict form when provided;
        # otherwise the corresponding attribute stays None.
        self.detection_model_config = (
            PretrainedConfig.from_dict(detection_model_config)
            if detection_model_config is not None
            else None
        )
        self.ocr_model_config = (
            VisionEncoderDecoderConfig.from_dict(ocr_model_config)
            if ocr_model_config is not None
            else None
        )
        self.crop_embedding_model_config = (
            PretrainedConfig.from_dict(crop_embedding_model_config)
            if crop_embedding_model_config is not None
            else None
        )
        # Preprocessing settings are stored as plain dicts / strings as given.
        self.detection_image_preprocessing_config = detection_image_preprocessing_config
        self.ocr_pretrained_processor_path = ocr_pretrained_processor_path
        self.crop_embedding_image_preprocessing_config = crop_embedding_image_preprocessing_config
        super().__init__(**kwargs)
|
modelling_magiv2.py
ADDED
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel, ConditionalDetrModel
|
2 |
+
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
3 |
+
ConditionalDetrMLPPredictionHead,
|
4 |
+
ConditionalDetrModelOutput,
|
5 |
+
ConditionalDetrHungarianMatcher,
|
6 |
+
inverse_sigmoid,
|
7 |
+
)
|
8 |
+
from .configuration_magiv2 import Magiv2Config
|
9 |
+
from .processing_magiv2 import Magiv2Processor
|
10 |
+
from torch import nn
|
11 |
+
from typing import Optional, List
|
12 |
+
import torch
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
|
15 |
+
from transformers.image_transforms import center_to_corners_format
|
16 |
+
from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order
|
17 |
+
import pulp
|
18 |
+
import scipy
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
class Magiv2Model(PreTrainedModel):
|
22 |
+
config_class = Magiv2Config
|
23 |
+
|
24 |
+
def __init__(self, config):
    """Build the Magiv2 model from its configuration.

    Depending on the ``disable_*`` flags in ``config``, constructs up to
    three sub-models: an OCR encoder-decoder, a ViT-MAE crop-embedding
    encoder, and a Conditional-DETR detection transformer with the
    prediction and matching heads used for detection and association.

    Args:
        config: a ``Magiv2Config`` instance (see ``configuration_magiv2``).
    """
    super().__init__(config)
    self.config = config
    self.processor = Magiv2Processor(config)
    if not config.disable_ocr:
        self.ocr_model = VisionEncoderDecoderModel(config.ocr_model_config)
    if not config.disable_crop_embeddings:
        self.crop_embedding_model = ViTMAEModel(config.crop_embedding_model_config)
    if not config.disable_detections:
        # Extra query tokens beyond the object queries — presumably used for
        # non-object/global predictions; TODO confirm against the forward pass.
        self.num_non_obj_tokens = 5
        self.detection_transformer = ConditionalDetrModel(config.detection_model_config)
        # Regresses 4 bbox coordinates from each decoder query embedding.
        self.bbox_predictor = ConditionalDetrMLPPredictionHead(
            input_dim=config.detection_model_config.d_model,
            hidden_dim=config.detection_model_config.d_model,
            output_dim=4, num_layers=3
        )
        # Scores character-character pairs; when crop embeddings are enabled
        # the pair's two crop embeddings are concatenated onto the input.
        self.character_character_matching_head = ConditionalDetrMLPPredictionHead(
            input_dim = 3 * config.detection_model_config.d_model + (2 * config.crop_embedding_model_config.hidden_size if not config.disable_crop_embeddings else 0),
            hidden_dim=config.detection_model_config.d_model,
            output_dim=1, num_layers=3
        )
        # Scores text-character pairs (3 * d_model input — the third chunk is
        # presumably a shared/context token; verify against the forward pass).
        self.text_character_matching_head = ConditionalDetrMLPPredictionHead(
            input_dim = 3 * config.detection_model_config.d_model,
            hidden_dim=config.detection_model_config.d_model,
            output_dim=1, num_layers=3
        )
        # Scores text-tail pairs from the two boxes' query embeddings.
        self.text_tail_matching_head = ConditionalDetrMLPPredictionHead(
            input_dim = 2 * config.detection_model_config.d_model,
            hidden_dim=config.detection_model_config.d_model,
            output_dim=1, num_layers=3
        )
        # Per-query classification over the detection label set.
        self.class_labels_classifier = nn.Linear(
            config.detection_model_config.d_model, config.detection_model_config.num_labels
        )
        # Binary head: is a given text box a dialogue (vs. other text).
        self.is_this_text_a_dialogue = nn.Linear(
            config.detection_model_config.d_model, 1
        )
        # Hungarian matcher used at training time to pair predictions with
        # ground-truth boxes, with costs taken from the detection config.
        self.matcher = ConditionalDetrHungarianMatcher(
            class_cost=config.detection_model_config.class_cost,
            bbox_cost=config.detection_model_config.bbox_cost,
            giou_cost=config.detection_model_config.giou_cost
        )
|
66 |
+
|
67 |
+
def move_to_device(self, input):
    """Relocate ``input`` onto the device this model currently lives on.

    Delegates to the module-level ``move_to_device`` helper (from utils),
    which presumably handles nested containers of tensors — see that helper.
    """
    target_device = self.device
    return move_to_device(input, target_device)
|
69 |
+
|
70 |
+
@torch.no_grad()
def do_chapter_wide_prediction(self, pages_in_order, character_bank, eta=0.75, batch_size=8, use_tqdm=False, do_ocr=True):
    """Run detection, character naming, and (optionally) OCR over a chapter.

    Pages are processed in batches of ``batch_size`` through
    ``predict_detections_and_associations``; detected characters across the
    whole chapter are then matched against ``character_bank`` (with
    ``eta`` as the "none of the above" cost), and each page's result dict
    is augmented in place with ``"character_names"`` and, when ``do_ocr``
    is true, ``"ocr"``.

    Args:
        pages_in_order: chapter page images, in reading order.
        character_bank: dict with ``"images"`` and ``"names"`` entries for
            the known characters.
        eta: threshold cost for assigning "Other" instead of a bank name.
        batch_size: pages per detection batch.
        use_tqdm: show a progress bar over detection batches.
        do_ocr: also run OCR on the detected text boxes.

    Returns:
        The list of per-page result dicts (one per input page).
    """
    if use_tqdm:
        from tqdm import tqdm
        iterator = tqdm(range(0, len(pages_in_order), batch_size))
    else:
        iterator = range(0, len(pages_in_order), batch_size)
    per_page_results = []
    for i in iterator:
        pages = pages_in_order[i:i + batch_size]
        results = self.predict_detections_and_associations(pages)
        per_page_results.extend(results)

    # Gather chapter-wide lists aligned with per_page_results.
    texts = [result["texts"] for result in per_page_results]
    characters = [result["characters"] for result in per_page_results]
    character_clusters = [result["character_cluster_labels"] for result in per_page_results]
    assigned_character_names = self.assign_names_to_characters(
        pages_in_order, characters, character_bank, character_clusters, eta=eta
    )
    if do_ocr:
        ocr = self.predict_ocr(pages_in_order, texts, use_tqdm=use_tqdm)
    # Character names were computed chapter-wide; slice them back out per page.
    offset_characters = 0
    iteration_over = zip(per_page_results, ocr) if do_ocr else per_page_results
    for item in iteration_over:  # renamed from `iter` to avoid shadowing the builtin
        if do_ocr:
            result, ocr_for_page = item
            result["ocr"] = ocr_for_page
        else:
            result = item
        result["character_names"] = assigned_character_names[offset_characters:offset_characters + len(result["characters"])]
        offset_characters += len(result["characters"])
    return per_page_results
|
103 |
+
|
104 |
+
|
105 |
+
def assign_names_to_characters(self, images, character_bboxes, character_bank, character_clusters, eta=0.75):
    """Assign a bank name (or "Other") to every detected character crop.

    Embeds all character crops chapter-wide, embeds the character bank
    images, and solves a binary assignment problem (via PuLP) that maps
    each crop to either one bank character or a "none of the above"
    column with fixed cost ``eta``. Per-page cluster labels are turned
    into must-link / cannot-link constraints so crops clustered together
    on a page get the same name, and differently-clustered crops on the
    same page never share one.

    Args:
        images: the chapter pages (one entry per page).
        character_bboxes: per-page lists of character boxes, aligned with
            ``images``.
        character_bank: dict with ``"images"`` (reference crops) and
            ``"names"`` (their labels).
        character_clusters: per-page cluster labels, aligned with
            ``character_bboxes``.
        eta: cost of assigning a crop to "none of the above"; crops whose
            best bank distance exceeds this end up labelled "Other".

    Returns:
        A flat list of names, one per crop, in chapter-wide order.
    """
    chapter_wide_char_embeddings = self.predict_crop_embeddings(images, character_bboxes)
    chapter_wide_char_embeddings = torch.cat(chapter_wide_char_embeddings, dim=0)
    # L2-normalise so cdist distances below are comparable across crops.
    chapter_wide_char_embeddings = torch.nn.functional.normalize(chapter_wide_char_embeddings, p=2, dim=1).cpu().numpy()
    # create must-link and cannot link constraints from character_clusters;
    # indices are offsets into the chapter-wide (flattened) crop list, so
    # constraints only ever relate crops from the same page.
    must_link = []
    cannot_link = []
    offset = 0
    for clusters_per_image in character_clusters:
        for i in range(len(clusters_per_image)):
            for j in range(i+1, len(clusters_per_image)):
                if clusters_per_image[i] == clusters_per_image[j]:
                    must_link.append((offset + i, offset + j))
                else:
                    cannot_link.append((offset + i, offset + j))
        offset += len(clusters_per_image)
    # Embed each bank image as a single full-image "crop" ([0, 0, w, h]);
    # assumes bank images expose numpy-style .shape (h, w, ...) — TODO confirm.
    character_bank_for_this_chapter = self.predict_crop_embeddings(character_bank["images"], [[[0, 0, x.shape[1], x.shape[0]]] for x in character_bank["images"]])
    character_bank_for_this_chapter = torch.cat(character_bank_for_this_chapter, dim=0)
    character_bank_for_this_chapter = torch.nn.functional.normalize(character_bank_for_this_chapter, p=2, dim=1).cpu().numpy()
    # Cost matrix: crops (rows) x bank characters (cols), plus one extra
    # column of constant cost eta for "none of the above".
    costs = scipy.spatial.distance.cdist(chapter_wide_char_embeddings, character_bank_for_this_chapter)
    none_of_the_above = eta * np.ones((costs.shape[0],1))
    costs = np.concatenate([costs, none_of_the_above], axis=1)
    sense = pulp.LpMinimize
    num_supply, num_demand = costs.shape
    problem = pulp.LpProblem("Optimal_Transport_Problem", sense)
    x = pulp.LpVariable.dicts("x", ((i, j) for i in range(num_supply) for j in range(num_demand)), cat='Binary')
    # Objective Function to minimize
    problem += pulp.lpSum([costs[i][j] * x[(i, j)] for i in range(num_supply) for j in range(num_demand)])
    # each crop must be assigned to exactly one character
    for i in range(num_supply):
        problem += pulp.lpSum([x[(i, j)] for j in range(num_demand)]) == 1, f"Supply_{i}_Total_Assignment"
    # cannot link constraints — note range(num_demand-1): the last column
    # ("none of the above") is deliberately exempt, so two cannot-linked
    # crops may both be "Other".
    for j in range(num_demand-1):
        for (s1, s2) in cannot_link:
            problem += x[(s1, j)] + x[(s2, j)] <= 1, f"Exclusion_{s1}_{s2}_Demand_{j}"
    # must link constraints — crops in the same cluster get identical
    # assignments across every column, including "none of the above".
    for j in range(num_demand):
        for (s1, s2) in must_link:
            problem += x[(s1, j)] - x[(s2, j)] == 0, f"Inclusion_{s1}_{s2}_Demand_{j}"
    problem.solve()
    assignments = []
    # Recover (crop, column) pairs from the solver's variable names, which
    # pulp renders like "x_(i,_j)" — parsing depends on that exact format.
    for v in problem.variables():
        if v.varValue > 0:
            index, assignment = v.name.split("(")[1].split(")")[0].split(",")
            assignment = assignment[1:]
            assignments.append((int(index), int(assignment)))

    labels = np.zeros(num_supply)
    for i, j in assignments:
        labels[i] = j

    # Column index past the bank means the "none of the above" column -> "Other".
    return [character_bank["names"][int(i)] if i < len(character_bank["names"]) else "Other" for i in labels]
|
157 |
+
|
158 |
+
|
159 |
+
def predict_detections_and_associations(
    self,
    images,
    move_to_device_fn=None,
    character_detection_threshold=0.3,
    panel_detection_threshold=0.2,
    text_detection_threshold=0.3,
    tail_detection_threshold=0.34,
    character_character_matching_threshold=0.65,
    text_character_matching_threshold=0.35,
    text_tail_matching_threshold=0.3,
    text_classification_threshold=0.5,
):
    """Run the full detection + association pipeline over a batch of page images.

    For each image this detects panels, characters, texts and speech-bubble tails,
    matches texts to speaker characters and to tails, clusters character boxes into
    identities, and classifies each text as dialogue ("essential") or not.

    Args:
        images: iterable of images as numpy arrays (H, W, ...) — only shape[:2] is
            read here; preprocessing is delegated to self.processor.
        move_to_device_fn: optional callable to move tensors to the model device;
            defaults to self.move_to_device.
        *_threshold: score cutoffs for the corresponding detection / matching /
            classification stages.

    Returns:
        A list (one dict per image) with keys: "panels", "texts", "characters",
        "tails" (boxes as x1y1x2y2 lists), "text_character_associations",
        "text_tail_associations" (pairs of local indices), "character_cluster_labels",
        and "is_essential_text".
    """
    assert not self.config.disable_detections
    move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

    inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images)
    inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)

    detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
    predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)

    # (H, W) per image, used below to rescale normalized boxes to pixel coordinates.
    original_image_sizes = torch.stack([torch.tensor(img.shape[:2]) for img in images], dim=0).to(predicted_bboxes.device)

    # Per-query best class and its (sigmoided) confidence.
    batch_scores, batch_labels = predicted_class_scores.max(-1)
    batch_scores = batch_scores.sigmoid()
    batch_labels = batch_labels.long()
    batch_bboxes = center_to_corners_format(predicted_bboxes)

    # scale the bboxes back to the original image size
    # NOTE(review): original_image_sizes is always a tensor here (built by torch.stack
    # above), so the isinstance(..., List) branch appears to be dead code kept from a
    # more general post-processing routine.
    if isinstance(original_image_sizes, List):
        img_h = torch.Tensor([i[0] for i in original_image_sizes])
        img_w = torch.Tensor([i[1] for i in original_image_sizes])
    else:
        img_h, img_w = original_image_sizes.unbind(1)
    scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(batch_bboxes.device)
    batch_bboxes = batch_bboxes * scale_fct[:, None, :]

    # Per-class filtering (class ids: 0=character, 1=text, 2=panel, 3=tail — see processor).
    batch_panel_indices = self.processor._get_indices_of_panels_to_keep(batch_scores, batch_labels, batch_bboxes, panel_detection_threshold)
    batch_character_indices = self.processor._get_indices_of_characters_to_keep(batch_scores, batch_labels, batch_bboxes, character_detection_threshold)
    batch_text_indices = self.processor._get_indices_of_texts_to_keep(batch_scores, batch_labels, batch_bboxes, text_detection_threshold)
    batch_tail_indices = self.processor._get_indices_of_tails_to_keep(batch_scores, batch_labels, batch_bboxes, tail_detection_threshold)

    predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
    predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
    predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)

    # text <-> character speaker affinities (probabilities, since apply_sigmoid=True).
    text_character_affinity_matrices = self._get_text_character_affinity_matrices(
        character_obj_tokens_for_batch=[x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_character_indices)],
        text_obj_tokens_for_this_batch=[x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_text_indices)],
        t2c_tokens_for_batch=predicted_t2c_tokens_for_batch,
        apply_sigmoid=True,
    )

    # character <-> character identity affinities, augmented with crop embeddings.
    character_bboxes_in_batch = [batch_bboxes[i][j] for i, j in enumerate(batch_character_indices)]
    character_character_affinity_matrices = self._get_character_character_affinity_matrices(
        character_obj_tokens_for_batch=[x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_character_indices)],
        crop_embeddings_for_batch=self.predict_crop_embeddings(images, character_bboxes_in_batch, move_to_device_fn),
        c2c_tokens_for_batch=predicted_c2c_tokens_for_batch,
        apply_sigmoid=True,
    )

    text_tail_affinity_matrices = self._get_text_tail_affinity_matrices(
        text_obj_tokens_for_this_batch=[x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_text_indices)],
        tail_obj_tokens_for_batch=[x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_tail_indices)],
        apply_sigmoid=True,
    )

    is_this_text_a_dialogue = self._get_text_classification([x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_text_indices)])

    results = []
    for batch_index in range(len(batch_scores)):
        panel_indices = batch_panel_indices[batch_index]
        character_indices = batch_character_indices[batch_index]
        text_indices = batch_text_indices[batch_index]
        tail_indices = batch_tail_indices[batch_index]

        character_bboxes = batch_bboxes[batch_index][character_indices]
        panel_bboxes = batch_bboxes[batch_index][panel_indices]
        text_bboxes = batch_bboxes[batch_index][text_indices]
        tail_bboxes = batch_bboxes[batch_index][tail_indices]

        # Reorder panels and texts into manga reading order; affinity rows for text
        # must be permuted the same way (done below via local_sorted_text_indices).
        local_sorted_panel_indices = sort_panels(panel_bboxes)
        panel_bboxes = panel_bboxes[local_sorted_panel_indices]
        local_sorted_text_indices = sort_text_boxes_in_reading_order(text_bboxes, panel_bboxes)
        text_bboxes = text_bboxes[local_sorted_text_indices]

        character_character_matching_scores = character_character_affinity_matrices[batch_index]
        text_character_matching_scores = text_character_affinity_matrices[batch_index][local_sorted_text_indices]
        text_tail_matching_scores = text_tail_affinity_matrices[batch_index][local_sorted_text_indices]

        is_essential_text = is_this_text_a_dialogue[batch_index][local_sorted_text_indices] > text_classification_threshold
        # Connected components over thresholded affinities = character identity clusters.
        character_cluster_labels = UnionFind.from_adj_matrix(
            character_character_matching_scores > character_character_matching_threshold
        ).get_labels_for_connected_components()

        if 0 in text_character_matching_scores.shape:
            # No texts or no characters: empty (0, 2) association table.
            text_character_associations = torch.zeros((0, 2), dtype=torch.long)
        else:
            # For each text pick its argmax character, then drop pairs below threshold.
            # NOTE: `text_indices` is deliberately reused (shadowed) here as a local
            # arange over the sorted text boxes; the batch-level value above is no
            # longer needed at this point.
            most_likely_speaker_for_each_text = torch.argmax(text_character_matching_scores, dim=1)
            text_indices = torch.arange(len(text_bboxes)).type_as(most_likely_speaker_for_each_text)
            text_character_associations = torch.stack([text_indices, most_likely_speaker_for_each_text], dim=1)
            to_keep = text_character_matching_scores.max(dim=1).values > text_character_matching_threshold
            text_character_associations = text_character_associations[to_keep]

        if 0 in text_tail_matching_scores.shape:
            text_tail_associations = torch.zeros((0, 2), dtype=torch.long)
        else:
            # Same argmax-then-threshold scheme for text -> tail links.
            most_likely_tail_for_each_text = torch.argmax(text_tail_matching_scores, dim=1)
            text_indices = torch.arange(len(text_bboxes)).type_as(most_likely_tail_for_each_text)
            text_tail_associations = torch.stack([text_indices, most_likely_tail_for_each_text], dim=1)
            to_keep = text_tail_matching_scores.max(dim=1).values > text_tail_matching_threshold
            text_tail_associations = text_tail_associations[to_keep]

        results.append({
            "panels": panel_bboxes.tolist(),
            "texts": text_bboxes.tolist(),
            "characters": character_bboxes.tolist(),
            "tails": tail_bboxes.tolist(),
            "text_character_associations": text_character_associations.tolist(),
            "text_tail_associations": text_tail_associations.tolist(),
            "character_cluster_labels": character_cluster_labels,
            "is_essential_text": is_essential_text.tolist(),
        })

    return results
|
285 |
+
|
286 |
+
def get_affinity_matrices_given_annotations(
    self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
):
    """Compute affinity matrices using ground-truth boxes instead of detections.

    Ground-truth annotations are matched to predicted queries via self.matcher
    (Hungarian matching on class logits + boxes); the matched object tokens are
    then fed through the same matching heads used at inference time.

    Args:
        images: batch of images as numpy arrays.
        annotations: per-image dicts with at least "bboxes_as_x1y1x2y2" and
            "labels" (0=character, 1=text, 3=tail — see the index filters below).
        move_to_device_fn: optional device-transfer callable; defaults to
            self.move_to_device.
        apply_sigmoid: if True, affinities are probabilities; otherwise raw logits.

    Returns:
        Dict of per-image affinity matrix lists plus per-text dialogue scores.
    """
    assert not self.config.disable_detections
    move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

    # Crop embeddings are computed from the ground-truth character boxes (label 0).
    character_bboxes_in_batch = [[bbox for bbox, label in zip(a["bboxes_as_x1y1x2y2"], a["labels"]) if label == 0] for a in annotations]
    crop_embeddings_for_batch = self.predict_crop_embeddings(images, character_bboxes_in_batch, move_to_device_fn)

    inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
    inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
    # "labels" holds the processed targets; it must not be forwarded to the transformer.
    processed_targets = inputs_to_detection_transformer.pop("labels")

    detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
    predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
    predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
    predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)

    # Hungarian-match predictions to targets so each annotated box gets a query token.
    predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
    matching_dict = {
        "logits": predicted_class_scores,
        "pred_boxes": predicted_bboxes,
    }
    indices = self.matcher(matching_dict, processed_targets)

    matched_char_obj_tokens_for_batch = []
    matched_text_obj_tokens_for_batch = []
    matched_tail_obj_tokens_for_batch = []
    t2c_tokens_for_batch = []
    c2c_tokens_for_batch = []

    for j, (pred_idx, tgt_idx) in enumerate(indices):
        # Invert the matching: annotation index -> predicted query index.
        target_idx_to_pred_idx = {tgt.item(): pred.item() for pred, tgt in zip(pred_idx, tgt_idx)}
        targets_for_this_image = processed_targets[j]
        indices_of_text_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 1]
        indices_of_char_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 0]
        indices_of_tail_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 3]
        predicted_text_indices = [target_idx_to_pred_idx[i] for i in indices_of_text_boxes_in_annotation]
        predicted_char_indices = [target_idx_to_pred_idx[i] for i in indices_of_char_boxes_in_annotation]
        predicted_tail_indices = [target_idx_to_pred_idx[i] for i in indices_of_tail_boxes_in_annotation]
        matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
        matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
        matched_tail_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_tail_indices])
        t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
        c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])

    text_character_affinity_matrices = self._get_text_character_affinity_matrices(
        character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
        text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
        t2c_tokens_for_batch=t2c_tokens_for_batch,
        apply_sigmoid=apply_sigmoid,
    )

    character_character_affinity_matrices = self._get_character_character_affinity_matrices(
        character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
        crop_embeddings_for_batch=crop_embeddings_for_batch,
        c2c_tokens_for_batch=c2c_tokens_for_batch,
        apply_sigmoid=apply_sigmoid,
    )

    # Same inputs but crop_only=True: cosine similarity of crop embeddings alone.
    character_character_affinity_matrices_crop_only = self._get_character_character_affinity_matrices(
        character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
        crop_embeddings_for_batch=crop_embeddings_for_batch,
        c2c_tokens_for_batch=c2c_tokens_for_batch,
        crop_only=True,
        apply_sigmoid=apply_sigmoid,
    )

    text_tail_affinity_matrices = self._get_text_tail_affinity_matrices(
        text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
        tail_obj_tokens_for_batch=matched_tail_obj_tokens_for_batch,
        apply_sigmoid=apply_sigmoid,
    )

    is_this_text_a_dialogue = self._get_text_classification(matched_text_obj_tokens_for_batch, apply_sigmoid=apply_sigmoid)

    return {
        "text_character_affinity_matrices": text_character_affinity_matrices,
        "character_character_affinity_matrices": character_character_affinity_matrices,
        "character_character_affinity_matrices_crop_only": character_character_affinity_matrices_crop_only,
        "text_tail_affinity_matrices": text_tail_affinity_matrices,
        "is_this_text_a_dialogue": is_this_text_a_dialogue,
    }
|
369 |
+
|
370 |
+
|
371 |
+
def predict_crop_embeddings(self, images, crop_bboxes, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
    """Embed image crops (e.g. character boxes) with the crop-embedding ViT.

    Args:
        images: batch of images as numpy arrays.
        crop_bboxes: list (one entry per image) of x1y1x2y2 boxes to crop and embed.
        move_to_device_fn: optional device-transfer callable; defaults to
            self.move_to_device.
        mask_ratio: MAE mask ratio to use for this call only (default 0.0 = no
            masking); the model's configured ratio is restored before returning.
        batch_size: number of crops per forward pass, to bound memory use.

    Returns:
        None if crop embeddings are disabled; otherwise a list of (num_crops_i,
        hidden_size) tensors, one per input image (possibly with 0 rows).
    """
    if self.config.disable_crop_embeddings:
        return None

    assert isinstance(crop_bboxes, List), "please provide a list of bboxes for each image to get embeddings for"

    move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

    # temporarily change the mask ratio from default to the one specified
    old_mask_ratio = self.crop_embedding_model.embeddings.config.mask_ratio
    self.crop_embedding_model.embeddings.config.mask_ratio = mask_ratio

    # Flatten all crops across the batch; num_crops_per_batch remembers the split.
    crops_per_image = []
    num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
    for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
        crops = self.processor.crop_image(image, bboxes)
        assert len(crops) == num_crops
        crops_per_image.extend(crops)

    if len(crops_per_image) == 0:
        # No crops anywhere in the batch: return empty embedding tensors per image.
        return [move_to_device_fn(torch.zeros(0, self.config.crop_embedding_model_config.hidden_size)) for _ in crop_bboxes]

    crops_per_image = self.processor.preprocess_inputs_for_crop_embeddings(crops_per_image)
    crops_per_image = move_to_device_fn(crops_per_image)

    # process the crops in batches to avoid OOM
    embeddings = []
    for i in range(0, len(crops_per_image), batch_size):
        crops = crops_per_image[i:i+batch_size]
        # CLS token (index 0) of the ViT output is used as the crop embedding.
        embeddings_per_batch = self.crop_embedding_model(crops).last_hidden_state[:, 0]
        embeddings.append(embeddings_per_batch)
    embeddings = torch.cat(embeddings, dim=0)

    # Un-flatten: consume the embedding tensor back into per-image chunks.
    crop_embeddings_for_batch = []
    for num_crops in num_crops_per_batch:
        crop_embeddings_for_batch.append(embeddings[:num_crops])
        embeddings = embeddings[num_crops:]

    # restore the mask ratio to the default
    self.crop_embedding_model.embeddings.config.mask_ratio = old_mask_ratio

    return crop_embeddings_for_batch
|
413 |
+
|
414 |
+
def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32):
    """Run OCR over the given crops of each image.

    Args:
        images: batch of images as numpy arrays.
        crop_bboxes: list (one entry per image) of x1y1x2y2 text boxes to read.
        move_to_device_fn: optional device-transfer callable; defaults to
            self.move_to_device.
        use_tqdm: if True, wrap the batch loop in a tqdm progress bar.
        batch_size: number of crops per generate() call, to bound memory use.

    Returns:
        A list (one entry per image) of recognized strings, in the same order as
        crop_bboxes; newlines are stripped from each string.
    """
    assert not self.config.disable_ocr
    move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

    # Flatten all crops across the batch; num_crops_per_batch remembers the split.
    crops_per_image = []
    num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
    for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
        crops = self.processor.crop_image(image, bboxes)
        assert len(crops) == num_crops
        crops_per_image.extend(crops)

    if len(crops_per_image) == 0:
        return [[] for _ in crop_bboxes]

    crops_per_image = self.processor.preprocess_inputs_for_ocr(crops_per_image)
    crops_per_image = move_to_device_fn(crops_per_image)

    # process the crops in batches to avoid OOM
    all_generated_texts = []
    if use_tqdm:
        # Imported lazily so tqdm is only required when a progress bar is requested.
        from tqdm import tqdm
        pbar = tqdm(range(0, len(crops_per_image), batch_size))
    else:
        pbar = range(0, len(crops_per_image), batch_size)
    for i in pbar:
        crops = crops_per_image[i:i+batch_size]
        generated_ids = self.ocr_model.generate(crops)
        generated_texts = self.processor.postprocess_ocr_tokens(generated_ids)
        all_generated_texts.extend(generated_texts)

    # Un-flatten back into per-image lists, stripping newlines from each string.
    texts_for_images = []
    for num_crops in num_crops_per_batch:
        texts_for_images.append([x.replace("\n", "") for x in all_generated_texts[:num_crops]])
        all_generated_texts = all_generated_texts[num_crops:]

    return texts_for_images
|
450 |
+
|
451 |
+
def visualise_single_image_prediction(
    self, image_as_np_array, predictions, filename=None
):
    """Render one image's predictions via the module-level visualisation helper.

    Thin delegate: the bare name below resolves to the module-level
    `visualise_single_image_prediction` function, not this method.
    """
    rendered = visualise_single_image_prediction(image_as_np_array, predictions, filename)
    return rendered
|
455 |
+
|
456 |
+
|
457 |
+
@torch.no_grad()
def _get_detection_transformer_output(
    self,
    pixel_values: torch.FloatTensor,
    pixel_mask: Optional[torch.LongTensor] = None
):
    """Forward the images through the detection transformer (gradients disabled).

    Raises:
        ValueError: if detections are disabled in the config.
    """
    if self.config.disable_detections:
        raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")
    outputs = self.detection_transformer(
        pixel_values=pixel_values,
        pixel_mask=pixel_mask,
        return_dict=True
    )
    return outputs
|
470 |
+
|
471 |
+
def _get_predicted_obj_tokens(
    self,
    detection_transformer_output: ConditionalDetrModelOutput
):
    """Return only the per-query object tokens, dropping the trailing special tokens."""
    hidden_states = detection_transformer_output.last_hidden_state
    # The last `num_non_obj_tokens` positions are reserved special tokens (c2c, t2c, ...).
    return hidden_states[:, :-self.num_non_obj_tokens]
|
476 |
+
|
477 |
+
def _get_predicted_c2c_tokens(
    self,
    detection_transformer_output: ConditionalDetrModelOutput
):
    """Return the character-to-character matching token for each image.

    It sits at position -num_non_obj_tokens, i.e. the first of the trailing
    special tokens (cf. _get_predicted_obj_tokens / _get_predicted_t2c_tokens).
    """
    hidden_states = detection_transformer_output.last_hidden_state
    return hidden_states[:, -self.num_non_obj_tokens]
|
482 |
+
|
483 |
+
def _get_predicted_t2c_tokens(
    self,
    detection_transformer_output: ConditionalDetrModelOutput
):
    """Return the text-to-character matching token for each image.

    It sits at position -num_non_obj_tokens + 1, i.e. directly after the c2c
    token among the trailing special tokens.
    """
    hidden_states = detection_transformer_output.last_hidden_state
    return hidden_states[:, -self.num_non_obj_tokens + 1]
|
488 |
+
|
489 |
+
def _get_predicted_bboxes_and_classes(
    self,
    detection_transformer_output: ConditionalDetrModelOutput,
):
    """Decode class logits and boxes from the detection transformer output.

    Returns:
        Tuple of (predicted_class_scores, predicted_boxes). Scores are raw logits;
        boxes are in (cx, cy, w, h) format, normalized to [0, 1] via the final
        sigmoid below.

    Raises:
        ValueError: if detections are disabled in the config.
    """
    if self.config.disable_detections:
        raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")

    obj = self._get_predicted_obj_tokens(detection_transformer_output)

    predicted_class_scores = self.class_labels_classifier(obj)
    # reference_points is indexed (queries first), so slicing dim 0 drops the
    # special tokens; transpose brings it to (batch, queries, 2) to match obj.
    reference = detection_transformer_output.reference_points[:-self.num_non_obj_tokens]
    reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
    predicted_boxes = self.bbox_predictor(obj)
    # Conditional-DETR style: the head predicts an offset from the reference point
    # for the box center (first two coords), then everything is squashed to [0, 1].
    predicted_boxes[..., :2] += reference_before_sigmoid
    predicted_boxes = predicted_boxes.sigmoid()

    return predicted_class_scores, predicted_boxes
|
506 |
+
|
507 |
+
def _get_text_classification(
    self,
    text_obj_tokens_for_batch: List[torch.FloatTensor],
    apply_sigmoid=False,
):
    """Score each text token as dialogue, per image.

    Args:
        text_obj_tokens_for_batch: list of (num_texts_i, dim) token tensors.
        apply_sigmoid: if True, return probabilities instead of raw logits.

    Returns:
        List of 1-D tensors, one per image (empty bool tensor when an image
        has no text tokens — matching the original's placeholder).
    """
    assert not self.config.disable_detections
    per_image_scores = []
    for tokens in text_obj_tokens_for_batch:
        if tokens.shape[0] == 0:
            # No text boxes for this image: keep an empty placeholder.
            per_image_scores.append(torch.tensor([], dtype=torch.bool))
            continue
        scores = self.is_this_text_a_dialogue(tokens).squeeze(-1)
        per_image_scores.append(scores.sigmoid() if apply_sigmoid else scores)
    return per_image_scores
|
523 |
+
|
524 |
+
def _get_character_character_affinity_matrices(
    self,
    character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
    crop_embeddings_for_batch: List[torch.FloatTensor] = None,
    c2c_tokens_for_batch: List[torch.FloatTensor] = None,
    crop_only=False,
    apply_sigmoid=True,
):
    """Compute pairwise character-identity affinity matrices, one per image.

    Args:
        character_obj_tokens_for_batch: per-image (num_chars_i, dim) tokens.
        crop_embeddings_for_batch: per-image (num_chars_i, crop_dim) embeddings
            (required unless crop embeddings are disabled).
        c2c_tokens_for_batch: per-image c2c special token (dim,).
        crop_only: if True, skip the matching head and return cosine similarity
            of the crop embeddings alone.
        apply_sigmoid: if True, squash head outputs to probabilities (ignored in
            the crop_only path, which is already a cosine similarity).

    Returns:
        List of (num_chars_i, num_chars_i) symmetric affinity matrices.
    """
    assert self.config.disable_detections or (character_obj_tokens_for_batch is not None and c2c_tokens_for_batch is not None)
    assert self.config.disable_crop_embeddings or crop_embeddings_for_batch is not None
    assert not self.config.disable_detections or not self.config.disable_crop_embeddings

    if crop_only:
        # Cosine similarity of L2-normalized crop embeddings; no learned head.
        affinity_matrices = []
        for crop_embeddings in crop_embeddings_for_batch:
            crop_embeddings = crop_embeddings / crop_embeddings.norm(dim=-1, keepdim=True)
            affinity_matrix = crop_embeddings @ crop_embeddings.T
            affinity_matrices.append(affinity_matrix)
        return affinity_matrices
    affinity_matrices = []
    for batch_index, (character_obj_tokens, c2c) in enumerate(zip(character_obj_tokens_for_batch, c2c_tokens_for_batch)):
        if character_obj_tokens.shape[0] == 0:
            affinity_matrices.append(torch.zeros(0, 0).type_as(character_obj_tokens))
            continue
        if not self.config.disable_crop_embeddings:
            # Concatenate crop embedding onto each character token (row-aligned).
            crop_embeddings = crop_embeddings_for_batch[batch_index]
            assert character_obj_tokens.shape[0] == crop_embeddings.shape[0]
            character_obj_tokens = torch.cat([character_obj_tokens, crop_embeddings], dim=-1)
        # Build all ordered pairs (i, j) of character tokens, flattened to (n*n, 2d).
        char_i = repeat(character_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
        char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=character_obj_tokens.shape[0])
        char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
        # Append the shared c2c token to every pair before scoring.
        c2c = repeat(c2c, "d -> repeat d", repeat = char_ij.shape[0])
        char_ij_c2c = torch.cat([char_ij, c2c], dim=-1)
        character_character_affinities = self.character_character_matching_head(char_ij_c2c)
        character_character_affinities = rearrange(character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0])
        # Symmetrize: affinity(i, j) should equal affinity(j, i).
        character_character_affinities = (character_character_affinities + character_character_affinities.T) / 2
        if apply_sigmoid:
            character_character_affinities = character_character_affinities.sigmoid()
        affinity_matrices.append(character_character_affinities)
    return affinity_matrices
|
564 |
+
|
565 |
+
def _get_text_character_affinity_matrices(
    self,
    character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
    text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
    t2c_tokens_for_batch: List[torch.FloatTensor] = None,
    apply_sigmoid=True,
):
    """Compute text-to-character (speaker) affinity matrices, one per image.

    Returns:
        List of (num_texts_i, num_chars_i) matrices; a zero-sized matrix is
        emitted when an image has no texts or no characters.
    """
    assert not self.config.disable_detections
    assert character_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None and t2c_tokens_for_batch is not None
    affinity_matrices = []
    for character_obj_tokens, text_obj_tokens, t2c in zip(character_obj_tokens_for_batch, text_obj_tokens_for_this_batch, t2c_tokens_for_batch):
        if character_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
            affinity_matrices.append(torch.zeros(text_obj_tokens.shape[0], character_obj_tokens.shape[0]).type_as(character_obj_tokens))
            continue
        # Build all (text_i, char_j) pairs, flattened to (num_texts*num_chars, 2d).
        text_i = repeat(text_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
        char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=text_obj_tokens.shape[0])
        text_char = rearrange([text_i, char_j], "two i j d -> (i j) (two d)")
        # Append the shared t2c token to every pair before scoring.
        t2c = repeat(t2c, "d -> repeat d", repeat = text_char.shape[0])
        text_char_t2c = torch.cat([text_char, t2c], dim=-1)
        text_character_affinities = self.text_character_matching_head(text_char_t2c)
        text_character_affinities = rearrange(text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
        if apply_sigmoid:
            text_character_affinities = text_character_affinities.sigmoid()
        affinity_matrices.append(text_character_affinities)
    return affinity_matrices
|
590 |
+
|
591 |
+
def _get_text_tail_affinity_matrices(
    self,
    text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
    tail_obj_tokens_for_batch: List[torch.FloatTensor] = None,
    apply_sigmoid=True,
):
    """Compute text-to-tail affinity matrices, one per image.

    Unlike the text/character head, no special token is appended — pairs of
    (text, tail) tokens are scored directly.

    Returns:
        List of (num_texts_i, num_tails_i) matrices; a zero-sized matrix is
        emitted when an image has no texts or no tails.
    """
    assert not self.config.disable_detections
    assert tail_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None
    affinity_matrices = []
    for tail_obj_tokens, text_obj_tokens in zip(tail_obj_tokens_for_batch, text_obj_tokens_for_this_batch):
        if tail_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
            affinity_matrices.append(torch.zeros(text_obj_tokens.shape[0], tail_obj_tokens.shape[0]).type_as(tail_obj_tokens))
            continue
        # Build all (text_i, tail_j) pairs, flattened to (num_texts*num_tails, 2d).
        text_i = repeat(text_obj_tokens, "i d -> i repeat d", repeat=tail_obj_tokens.shape[0])
        tail_j = repeat(tail_obj_tokens, "j d -> repeat j d", repeat=text_obj_tokens.shape[0])
        text_tail = rearrange([text_i, tail_j], "two i j d -> (i j) (two d)")
        text_tail_affinities = self.text_tail_matching_head(text_tail)
        text_tail_affinities = rearrange(text_tail_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
        if apply_sigmoid:
            text_tail_affinities = text_tail_affinities.sigmoid()
        affinity_matrices.append(text_tail_affinities)
    return affinity_matrices
|
processing_magiv2.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import ConditionalDetrImageProcessor, TrOCRProcessor, ViTImageProcessor
|
2 |
+
import torch
|
3 |
+
from typing import List
|
4 |
+
from shapely.geometry import box
|
5 |
+
from .utils import x1y1x2y2_to_xywh
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
class Magiv2Processor():
|
9 |
+
def __init__(self, config):
    """Build the sub-processors enabled by the model config.

    Each of the three preprocessors is only constructed when its feature is
    enabled (disable_* flag is False); otherwise it stays None and the
    corresponding preprocess_* method must not be called.

    Args:
        config: model configuration carrying disable_* flags plus the
            preprocessing configs / processor path asserted below.
    """
    self.config = config
    # Conditional-DETR image preprocessor for the detection branch.
    self.detection_image_preprocessor = None
    # TrOCR processor (feature extractor + tokenizer) for the OCR branch.
    self.ocr_preprocessor = None
    # ViT image preprocessor for the crop-embedding branch.
    self.crop_embedding_image_preprocessor = None
    if not config.disable_detections:
        assert config.detection_image_preprocessing_config is not None
        self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(config.detection_image_preprocessing_config)
    if not config.disable_ocr:
        assert config.ocr_pretrained_processor_path is not None
        self.ocr_preprocessor = TrOCRProcessor.from_pretrained(config.ocr_pretrained_processor_path)
    if not config.disable_crop_embeddings:
        assert config.crop_embedding_image_preprocessing_config is not None
        self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(config.crop_embedding_image_preprocessing_config)
|
23 |
+
|
24 |
+
def preprocess_inputs_for_detection(self, images, annotations=None):
    """Preprocess numpy images (and optional annotations) for the detection model.

    Annotations are converted to COCO format before being handed to the
    Conditional-DETR image preprocessor; returns a "pt" tensor batch.
    """
    image_list = list(images)
    assert isinstance(image_list[0], np.ndarray)
    coco_annotations = self._convert_annotations_to_coco_format(annotations)
    return self.detection_image_preprocessor(image_list, annotations=coco_annotations, return_tensors="pt")
|
30 |
+
|
31 |
+
def preprocess_inputs_for_ocr(self, images):
    """Turn numpy image crops into TrOCR pixel-value tensors."""
    image_list = list(images)
    assert isinstance(image_list[0], np.ndarray)
    processed = self.ocr_preprocessor(image_list, return_tensors="pt")
    return processed.pixel_values
|
35 |
+
|
36 |
+
def preprocess_inputs_for_crop_embeddings(self, images):
    """Turn numpy image crops into ViT pixel-value tensors for crop embedding."""
    image_list = list(images)
    assert isinstance(image_list[0], np.ndarray)
    processed = self.crop_embedding_image_preprocessor(image_list, return_tensors="pt")
    return processed.pixel_values
|
40 |
+
|
41 |
+
def postprocess_ocr_tokens(self, generated_ids, skip_special_tokens=True):
    """Decode generated OCR token ids back into a list of strings."""
    decoded = self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)
    return decoded
|
43 |
+
|
44 |
+
def crop_image(self, image, bboxes):
    """Cut out one crop per bbox, clamping boxes to the image and enforcing a
    minimum crop size of 10px per side.

    Args:
        image: numpy array indexed as image[y, x] (H, W, ...).
        bboxes: iterable of (x1, y1, x2, y2) boxes, possibly float / unordered /
            out of bounds — all of which are repaired below.

    Returns:
        List of numpy sub-arrays, one per bbox, in input order.
    """
    crops_for_image = []
    for bbox in bboxes:
        x1, y1, x2, y2 = bbox

        # fix the bounding box in case it is out of bounds or too small
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2) # just incase
        # Clamp both corners into [0, W] x [0, H].
        x1, y1 = max(0, x1), max(0, y1)
        x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1)
        x2, y2 = max(0, x2), max(0, y2)
        x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)
        # Grow degenerate boxes to at least 10px, extending right/down if there
        # is room, otherwise left/up.
        # NOTE(review): if the image itself is narrower/shorter than 10px, the
        # x1 = x2 - 10 / y1 = y2 - 10 fallback can go negative and the numpy
        # slice then wraps around — assumes images are at least 10px per side;
        # TODO confirm.
        if x2 - x1 < 10:
            if image.shape[1] - x1 > 10:
                x2 = x1 + 10
            else:
                x1 = x2 - 10
        if y2 - y1 < 10:
            if image.shape[0] - y1 > 10:
                y2 = y1 + 10
            else:
                y1 = y2 - 10

        crop = image[y1:y2, x1:x2]
        crops_for_image.append(crop)
    return crops_for_image
|
70 |
+
|
71 |
+
def _get_indices_of_characters_to_keep(self, batch_scores, batch_labels, batch_bboxes, character_detection_threshold):
|
72 |
+
indices_of_characters_to_keep = []
|
73 |
+
for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes):
|
74 |
+
indices = torch.where((labels == 0) & (scores > character_detection_threshold))[0]
|
75 |
+
indices_of_characters_to_keep.append(indices)
|
76 |
+
return indices_of_characters_to_keep
|
77 |
+
|
78 |
+
def _get_indices_of_panels_to_keep(self, batch_scores, batch_labels, batch_bboxes, panel_detection_threshold):
    """Per image, greedily keep panel detections (class 2), suppressing overlaps.

    Panels are visited in descending score order; a candidate is dropped if it
    scores below the threshold or if more than half its area is already covered
    by the union of previously accepted panels.

    Returns:
        List of plain-int index lists, one per image (unlike the character
        helper, which returns tensors).
    """
    indices_of_panels_to_keep = []
    for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
        indices = torch.where(labels == 2)[0]
        bboxes = bboxes[indices]
        scores = scores[indices]
        labels = labels[indices]
        if len(indices) == 0:
            indices_of_panels_to_keep.append([])
            continue
        # Sort all four streams together by score, descending (0-dim tensors
        # compare element-wise, so the tuple sort is driven by the score).
        scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
        panels_to_keep = []
        union_of_panels_so_far = box(0, 0, 0, 0)
        for ps, pb, pl, pi in zip(scores, bboxes, labels, indices):
            panel_polygon = box(pb[0], pb[1], pb[2], pb[3])
            if ps < panel_detection_threshold:
                continue
            # Reject panels mostly covered by already-accepted panels.
            if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5:
                continue
            panels_to_keep.append((ps, pl, pb, pi))
            union_of_panels_so_far = union_of_panels_so_far.union(panel_polygon)
        indices_of_panels_to_keep.append([p[3].item() for p in panels_to_keep])
    return indices_of_panels_to_keep
|
101 |
+
|
102 |
+
def _get_indices_of_texts_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
    """Per image, keep confident text detections (label == 1).

    Near-duplicates are suppressed: a box is dropped when its IoU with an
    already-kept (higher-scoring) box exceeds 0.5. Returns one list of kept
    detection indices per image.
    """
    batch_kept_indices = []
    for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
        candidate_idx = torch.where((labels == 1) & (scores > text_detection_threshold))[0]
        if len(candidate_idx) == 0:
            batch_kept_indices.append([])
            continue
        # Highest score first: tuple comparison looks at the score element first.
        ranked = sorted(
            zip(scores[candidate_idx], labels[candidate_idx], candidate_idx, bboxes[candidate_idx]),
            reverse=True,
        )
        kept = []
        kept_polygons = []
        for _score, _label, detection_index, bbox in ranked:
            polygon = box(bbox[0], bbox[1], bbox[2], bbox[3])
            duplicate = any(
                existing.intersection(polygon).area / existing.union(polygon).area > 0.5
                for existing in kept_polygons
            )
            if duplicate:
                continue
            kept.append(detection_index.item())
            kept_polygons.append(polygon)
        batch_kept_indices.append(kept)
    return batch_kept_indices
def _get_indices_of_tails_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
    """Per image, keep confident tail detections (label == 3).

    Boxes are visited best-score-first; a candidate is discarded when its IoU
    with any already-kept box exceeds 0.5. Returns one list of kept detection
    indices per image.
    """
    result = []
    for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
        selected = torch.where((labels == 3) & (scores > text_detection_threshold))[0]
        if len(selected) == 0:
            result.append([])
            continue
        # Descending score order (tuple comparison looks at the score first).
        ranked = sorted(zip(scores[selected], labels[selected], selected, bboxes[selected]), reverse=True)
        kept_indices = []
        kept_boxes = []
        for _score, _label, detection_index, b in ranked:
            candidate = box(b[0], b[1], b[2], b[3])
            is_duplicate = False
            for existing in kept_boxes:
                iou = existing.intersection(candidate).area / existing.union(candidate).area
                if iou > 0.5:
                    is_duplicate = True
                    break
            if not is_duplicate:
                kept_indices.append(detection_index.item())
                kept_boxes.append(candidate)
        result.append(kept_indices)
    return result
def _convert_annotations_to_coco_format(self, annotations):
    """Translate per-image annotations into COCO-style dicts.

    Each input entry carries an ``image_id``, x1y1x2y2 boxes and labels; each
    output entry carries the same ``image_id`` plus a list of annotations with
    xywh ``bbox``, ``category_id`` and ``area``. Returns None when
    `annotations` is None. Raises ValueError on malformed input (delegated to
    _verify_annotations_are_in_correct_format).
    """
    if annotations is None:
        return None
    self._verify_annotations_are_in_correct_format(annotations)
    converted = []
    for annotation in annotations:
        entries = []
        for bbox, label in zip(annotation["bboxes_as_x1y1x2y2"], annotation["labels"]):
            area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            entries.append({
                "bbox": x1y1x2y2_to_xywh(bbox),
                "category_id": label,
                "area": area,
            })
        converted.append({
            "image_id": annotation["image_id"],
            "annotations": entries,
        })
    return converted
def _verify_annotations_are_in_correct_format(self, annotations):
|
174 |
+
error_msg = """
|
175 |
+
Annotations must be in the following format:
|
176 |
+
[
|
177 |
+
{
|
178 |
+
"image_id": 0,
|
179 |
+
"bboxes_as_x1y1x2y2": [[0, 0, 10, 10], [10, 10, 20, 20], [20, 20, 30, 30]],
|
180 |
+
"labels": [0, 1, 2],
|
181 |
+
},
|
182 |
+
...
|
183 |
+
]
|
184 |
+
Labels: 0 for characters, 1 for text, 2 for panels.
|
185 |
+
"""
|
186 |
+
if annotations is None:
|
187 |
+
return
|
188 |
+
if not isinstance(annotations, List) and not isinstance(annotations, tuple):
|
189 |
+
raise ValueError(
|
190 |
+
f"{error_msg} Expected a List/Tuple, found {type(annotations)}."
|
191 |
+
)
|
192 |
+
if len(annotations) == 0:
|
193 |
+
return
|
194 |
+
if not isinstance(annotations[0], dict):
|
195 |
+
raise ValueError(
|
196 |
+
f"{error_msg} Expected a List[Dicct], found {type(annotations[0])}."
|
197 |
+
)
|
198 |
+
if "image_id" not in annotations[0]:
|
199 |
+
raise ValueError(
|
200 |
+
f"{error_msg} Dict must contain 'image_id'."
|
201 |
+
)
|
202 |
+
if "bboxes_as_x1y1x2y2" not in annotations[0]:
|
203 |
+
raise ValueError(
|
204 |
+
f"{error_msg} Dict must contain 'bboxes_as_x1y1x2y2'."
|
205 |
+
)
|
206 |
+
if "labels" not in annotations[0]:
|
207 |
+
raise ValueError(
|
208 |
+
f"{error_msg} Dict must contain 'labels'."
|
209 |
+
)
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:56392403204d3a4cca38694a3a260a6929d741869d802d6b14de35b4eab4c4b8
|
3 |
+
size 2063693064
|
utils.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import matplotlib.patches as patches
|
6 |
+
from shapely.geometry import Point, box
|
7 |
+
import networkx as nx
|
8 |
+
from copy import deepcopy
|
9 |
+
from itertools import groupby
|
10 |
+
|
11 |
+
def move_to_device(inputs, device):
    """Recursively transfer `inputs` onto `device`.

    Mappings, lists and tuples are rebuilt with every element moved; numpy
    arrays are converted via torch.from_numpy; anything else is assumed to
    expose a ``.to(device)`` method (e.g. a torch tensor or module).
    """
    if hasattr(inputs, "keys"):
        return {key: move_to_device(value, device) for key, value in inputs.items()}
    if isinstance(inputs, list):
        return [move_to_device(item, device) for item in inputs]
    if isinstance(inputs, tuple):
        return tuple(move_to_device(item, device) for item in inputs)
    if isinstance(inputs, np.ndarray):
        return torch.from_numpy(inputs).to(device)
    return inputs.to(device)
class UnionFind:
    """Disjoint-set (union-find) with path compression and union by size."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n
        self.num_components = n

    @classmethod
    def from_adj_matrix(cls, adj_matrix):
        """Build from an (n, n) adjacency matrix; any positive entry joins i and j."""
        ufds = cls(adj_matrix.shape[0])
        for row in range(adj_matrix.shape[0]):
            for col in range(adj_matrix.shape[1]):
                if adj_matrix[row, col] > 0:
                    ufds.unite(row, col)
        return ufds

    @classmethod
    def from_adj_list(cls, adj_list):
        """Build from an adjacency list: adj_list[i] holds i's neighbours."""
        ufds = cls(len(adj_list))
        for node in range(len(adj_list)):
            for neighbour in adj_list[node]:
                ufds.unite(node, neighbour)
        return ufds

    @classmethod
    def from_edge_list(cls, edge_list, num_nodes):
        """Build from (u, v) pairs over `num_nodes` nodes."""
        ufds = cls(num_nodes)
        for edge in edge_list:
            ufds.unite(edge[0], edge[1])
        return ufds

    def find(self, x):
        """Root of x's component; compresses the path for later lookups."""
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[x] != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def unite(self, x, y):
        """Merge the components of x and y (smaller component under larger)."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return
        if self.size[root_x] < self.size[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        self.num_components -= 1

    def get_components_of(self, x):
        """All node indices sharing x's component."""
        root = self.find(x)
        return [node for node in range(len(self.parent)) if self.find(node) == root]

    def are_connected(self, x, y):
        return self.find(x) == self.find(y)

    def get_size(self, x):
        """Number of nodes in x's component."""
        return self.size[self.find(x)]

    def get_num_components(self):
        return self.num_components

    def get_labels_for_connected_components(self):
        """Dense per-node component label, numbered by first appearance."""
        label_of_root = {}
        labels = []
        for node in range(len(self.parent)):
            root = self.find(node)
            if root not in label_of_root:
                label_of_root[root] = len(label_of_root)
            labels.append(label_of_root[root])
        return labels
def visualise_single_image_prediction(image_as_np_array, predictions, filename):
    """Draw all predictions for one page over the image and return the render.

    Panels are green boxes, texts red (essential texts only), characters blue,
    tails purple. Character clusters are joined by coloured lines; text-to-
    character links are dashed red lines and text-to-tail links dashed purple.
    Saves to `filename` when it is not None, and returns the rendered figure
    as a numpy array.
    """
    figure, subplot = plt.subplots(1, 1, figsize=(10, 10))
    subplot.imshow(image_as_np_array)
    plot_bboxes(subplot, predictions["panels"], color="green")
    plot_bboxes(subplot, predictions["texts"], color="red", visibility=predictions["is_essential_text"])
    plot_bboxes(subplot, predictions["characters"], color="blue")
    plot_bboxes(subplot, predictions["tails"], color="purple")

    # Character name labels anchored just above each character box.
    for i, name in enumerate(predictions["character_names"]):
        char_bbox = predictions["characters"][i]
        x1, y1, x2, y2 = char_bbox
        subplot.text(x1, y1 - 2, name,
            verticalalignment='bottom', horizontalalignment='left',
            bbox=dict(facecolor='blue', alpha=1, edgecolor='none'),  # Background settings
            color='white', fontsize=8)

    COLOURS = [
        "#b7ff51",  # green
        "#f50a8f",  # pink
        "#4b13b6",  # purple
        "#ddaa34",  # orange
        "#bea2a2",  # brown
    ]
    colour_index = 0
    character_cluster_labels = predictions["character_cluster_labels"]
    # Most populous clusters first, so they receive the fixed palette colours.
    unique_label_sorted_by_frequency = sorted(list(set(character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
    for label in unique_label_sorted_by_frequency:
        # First member of the cluster acts as the hub all others connect to.
        root = None
        others = []
        for i in range(len(predictions["characters"])):
            if character_cluster_labels[i] == label:
                if root is None:
                    root = i
                else:
                    others.append(i)
        # Once the palette is exhausted, pick a random colour not in the palette.
        if colour_index >= len(COLOURS):
            random_colour = COLOURS[0]
            while random_colour in COLOURS:
                random_colour = "#" + "".join([random.choice("0123456789ABCDEF") for j in range(6)])
        else:
            random_colour = COLOURS[colour_index]
            colour_index += 1
        bbox_i = predictions["characters"][root]
        x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
        y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
        subplot.plot([x1], [y1], color=random_colour, marker="o", markersize=5)
        for j in others:
            # draw line from centre of bbox i to centre of bbox j
            bbox_j = predictions["characters"][j]
            x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
            y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
            x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
            y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
            subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
            subplot.plot([x2], [y2], color=random_colour, marker="o", markersize=5)

    # Dashed red lines: character each (essential) text is attributed to.
    for (i, j) in predictions["text_character_associations"]:
        bbox_i = predictions["texts"][i]
        bbox_j = predictions["characters"][j]
        if not predictions["is_essential_text"][i]:
            continue
        x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
        y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
        x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
        y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
        subplot.plot([x1, x2], [y1, y2], color="red", linewidth=2, linestyle="dashed")

    # Dashed purple lines: tail box associated with each text.
    for (i, j) in predictions["text_tail_associations"]:
        bbox_i = predictions["texts"][i]
        bbox_j = predictions["tails"][j]
        x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
        y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
        x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
        y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
        subplot.plot([x1, x2], [y1, y2], color="purple", linewidth=2, linestyle="dashed")

    subplot.axis("off")
    if filename is not None:
        plt.savefig(filename, bbox_inches="tight", pad_inches=0)

    figure.canvas.draw()
    # NOTE(review): `_renderer` is a private matplotlib attribute — may break
    # across matplotlib versions; confirm against the pinned version.
    image = np.array(figure.canvas.renderer._renderer)
    plt.close()
    return image
def plot_bboxes(subplot, bboxes, color="red", visibility=None):
    """Draw each visible x1y1x2y2 box as an unfilled rectangle on `subplot`.

    `visibility`, if given, is a parallel sequence; boxes whose entry is 0 are
    skipped. When omitted, every box is drawn.
    """
    if visibility is None:
        visibility = [1] * len(bboxes)
    for idx, bbox in enumerate(bboxes):
        if visibility[idx] == 0:
            continue
        width = bbox[2] - bbox[0]
        height = bbox[3] - bbox[1]
        rectangle = patches.Rectangle(
            bbox[:2], width, height,
            linewidth=1, edgecolor=color, facecolor="none", linestyle="solid",
        )
        subplot.add_patch(rectangle)
def sort_panels(rects):
    """Return indices of panel boxes in reading order.

    A complete directed graph is built over the panels: for each pair, edge
    direction is decided by is_there_a_directed_edge (which panel is read
    first) and the edge weight is the spatial distance between them. Cycles
    are broken by repeatedly removing the heaviest edge of a cycle, then the
    reading order is a topological sort of the remaining DAG.
    """
    before_rects = convert_to_list_of_lists(rects)
    # slightly erode all rectangles initially to account for imperfect detections
    rects = [erode_rectangle(rect, 0.05) for rect in before_rects]
    G = nx.DiGraph()
    G.add_nodes_from(range(len(rects)))
    for i in range(len(rects)):
        for j in range(len(rects)):
            if i == j:
                continue
            if is_there_a_directed_edge(i, j, rects):
                G.add_edge(i, j, weight=get_distance(rects[i], rects[j]))
            else:
                G.add_edge(j, i, weight=get_distance(rects[i], rects[j]))
    # Break cycles until the graph is a DAG: drop the heaviest (most distant)
    # edge of the first multi-node cycle found, each iteration.
    while True:
        cycles = sorted(nx.simple_cycles(G))
        cycles = [cycle for cycle in cycles if len(cycle) > 1]
        if len(cycles) == 0:
            break
        cycle = cycles[0]
        edges = [e for e in zip(cycle, cycle[1:] + cycle[:1])]
        max_cyclic_edge = max(edges, key=lambda x: G.edges[x]["weight"])
        G.remove_edge(*max_cyclic_edge)
    return list(nx.topological_sort(G))
def is_strictly_above(rectA, rectB):
    """True when rectA's bottom edge (y2) lies above rectB's top edge (y1)."""
    return rectA[3] < rectB[1]
def is_strictly_below(rectA, rectB):
    """True when rectB's bottom edge (y2) lies above rectA's top edge (y1)."""
    return rectB[3] < rectA[1]
def is_strictly_left_of(rectA, rectB):
    """True when rectA's right edge (x2) lies left of rectB's left edge (x1)."""
    return rectA[2] < rectB[0]
def is_strictly_right_of(rectA, rectB):
    """True when rectB's right edge (x2) lies left of rectA's left edge (x1)."""
    return rectB[2] < rectA[0]
def intersects(rectA, rectB):
    """True when the two x1y1x2y2 rectangles share any point (shapely `intersects`)."""
    polygon_a = box(*rectA)
    polygon_b = box(*rectB)
    return polygon_a.intersects(polygon_b)
def is_there_a_directed_edge(a, b, rects):
    """Decide reading-order precedence between panels `a` and `b`.

    Returns 1 (truthy) when `a` should be read before `b`, 0 otherwise.
    Above precedes below, and right precedes left — NOTE(review): right-first
    implies right-to-left page order (manga-style); confirm for other layouts.
    """
    rectA = rects[a]
    rectB = rects[b]
    centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2, rectA[1] + (rectA[3] - rectA[1]) / 2]
    centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2, rectB[1] + (rectB[3] - rectB[1]) / 2]
    # Coincident centres: the larger panel is taken to come first.
    if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
        return box(*rectA).area > (box(*rectB)).area
    copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
    copy_B = [rectB[0], rectB[1], rectB[2], rectB[3]]
    # Repeatedly erode both copies until a strict spatial relation emerges.
    # Assumes erosion eventually separates overlapping boxes — TODO confirm
    # this terminates for degenerate (e.g. fully nested) inputs.
    while True:
        if is_strictly_above(copy_A, copy_B) and not is_strictly_left_of(copy_A, copy_B):
            return 1
        if is_strictly_above(copy_B, copy_A) and not is_strictly_left_of(copy_B, copy_A):
            return 0
        if is_strictly_right_of(copy_A, copy_B) and not is_strictly_below(copy_A, copy_B):
            return 1
        if is_strictly_right_of(copy_B, copy_A) and not is_strictly_below(copy_B, copy_A):
            return 0
        # Diagonal (below-and-right) cases are ambiguous: defer to page "cuts".
        if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
            return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
        if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
            return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
        # otherwise they intersect
        copy_A = erode_rectangle(copy_A, 0.05)
        copy_B = erode_rectangle(copy_B, 0.05)
def get_distance(rectA, rectB):
    """Shapely distance between two x1y1x2y2 boxes (0.0 when they intersect)."""
    return box(*rectA).distance(box(*rectB))
def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
    """Order diagonally-placed panels `a` and `b` using page "cuts".

    Restricts attention to panels overlapping the bounding region of a and b,
    then tries to split that region into non-overlapping horizontal bands
    (top band read first) and, failing that, vertical bands iterated
    right-to-left (rightmost band read first). If neither split separates a
    from b, every panel is eroded and the process repeats.
    Returns True when `a` precedes `b`.
    NOTE(review): assumes erosion eventually produces a separating cut —
    confirm termination for pathological layouts.
    """
    rects = deepcopy(rects)
    while True:
        xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
        # Parallel lists: original indices and the boxes themselves, for every
        # panel touching the combined bounding region of a and b.
        rect_index = [i for i in range(len(rects)) if intersects(rects[i], [xmin, ymin, xmax, ymax])]
        rects_copy = [rect for rect in rects if intersects(rect, [xmin, ymin, xmax, ymax])]

        # try to split the panels using a "horizontal" lines
        overlapping_y_ranges = merge_overlapping_ranges([(y1, y2) for x1, y1, x2, y2 in rects_copy])
        panel_index_to_split = {}
        for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
            for i, index in enumerate(rect_index):
                # A panel belongs to a band only if it fits entirely inside it.
                if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
                    panel_index_to_split[index] = split_index

        if panel_index_to_split[a] != panel_index_to_split[b]:
            # Smaller split index = higher band = read first.
            return panel_index_to_split[a] < panel_index_to_split[b]

        # try to split the panels using a "vertical" lines
        overlapping_x_ranges = merge_overlapping_ranges([(x1, x2) for x1, y1, x2, y2 in rects_copy])
        panel_index_to_split = {}
        # Reversed: the rightmost band gets the smallest split index
        # (right-to-left reading order).
        for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
            for i, index in enumerate(rect_index):
                if x1 <= rects_copy[i][0] <= rects_copy[i][2] <= x2:
                    panel_index_to_split[index] = split_index
        if panel_index_to_split[a] != panel_index_to_split[b]:
            return panel_index_to_split[a] < panel_index_to_split[b]

        # otherwise, erode the rectangles and try again
        rects = [erode_rectangle(rect, 0.05) for rect in rects]
def erode_rectangle(bbox, erosion_factor):
    """Shrink an x1y1x2y2 box about its centre.

    The longer side is eroded by `erosion_factor` of its length; the shorter
    side by a proportionally smaller fraction (scaled by the aspect ratio),
    so narrow boxes are not squashed flat. Returns a new [x1, y1, x2, y2].
    """
    x1, y1, x2, y2 = bbox
    width, height = x2 - x1, y2 - y1
    centre_x, centre_y = x1 + width / 2, y1 + height / 2
    if width < height:
        width_factor = erosion_factor * (width / height)
        height_factor = erosion_factor
    else:
        width_factor = erosion_factor
        height_factor = erosion_factor * (height / width)
    new_width = width - width * width_factor
    new_height = height - height * height_factor
    return [
        centre_x - new_width / 2,
        centre_y - new_height / 2,
        centre_x + new_width / 2,
        centre_y + new_height / 2,
    ]
def merge_overlapping_ranges(ranges):
    """Merge overlapping (lo, hi) intervals.

    ranges: list of tuples (x1, x2). Intervals that touch (next lo == current
    hi) are merged too. Returns the merged intervals sorted by lo; empty input
    yields an empty list.
    """
    if not ranges:
        return []
    ordered = sorted(ranges, key=lambda interval: interval[0])
    merged = []
    current_lo, current_hi = ordered[0]
    for lo, hi in ordered[1:]:
        if lo > current_hi:
            # Gap found: close off the running interval and start a new one.
            merged.append((current_lo, current_hi))
            current_lo, current_hi = lo, hi
        else:
            current_hi = max(current_hi, hi)
    merged.append((current_lo, current_hi))
    return merged
def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
    """Return indices of `text_bboxes` in overall reading order.

    Texts are first grouped by the panel they belong to (panels must already
    be in reading order), then each panel's texts are ordered by
    sort_texts_within_panel. Returns a list of indices into `text_bboxes`.
    """
    text_bboxes = convert_to_list_of_lists(text_bboxes)
    sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)

    if len(text_bboxes) == 0:
        return []

    def indices_of_same_elements(nums):
        # Contiguous runs of equal values (groupby needs pre-sorted input —
        # `nums` is sorted by panel id below before this is called).
        groups = groupby(range(len(nums)), key=lambda i: nums[i])
        return [list(indices) for _, indices in groups]

    panel_id_for_text = get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes)
    indices_of_texts = list(range(len(text_bboxes)))
    # Sort text indices by their panel id, so texts of earlier panels come first.
    indices_of_texts, panel_id_for_text = zip(*sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
    indices_of_texts = list(indices_of_texts)
    grouped_indices = indices_of_same_elements(panel_id_for_text)
    for group in grouped_indices:
        subset_of_text_indices = [indices_of_texts[i] for i in group]
        text_bboxes_of_subset = [text_bboxes[i] for i in subset_of_text_indices]
        sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
        # In-place reorder of this panel's slice; groups are contiguous runs,
        # so group[0]..group[-1] is exactly the slice being replaced.
        indices_of_texts[group[0] : group[-1] + 1] = [subset_of_text_indices[i] for i in sorted_subset_indices]
    return indices_of_texts
def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
    """For each text box, the index of the panel it belongs to.

    The panel with the largest overlap wins; when no panel overlaps, the
    nearest panel is used. Returns -1 for every text when there are no panels.
    """
    mapping = []
    for text_bbox in text_bboxes:
        text_polygon = box(*text_bbox)
        if len(sorted_panel_bboxes) == 0:
            mapping.append(-1)
            continue
        overlaps = []
        distances = []
        for panel_index, panel_bbox in enumerate(sorted_panel_bboxes):
            panel_polygon = box(*panel_bbox)
            if text_polygon.intersects(panel_polygon):
                overlaps.append((text_polygon.intersection(panel_polygon).area, panel_index))
            distances.append((text_polygon.distance(panel_polygon), panel_index))
        if overlaps:
            mapping.append(max(overlaps, key=lambda pair: pair[0])[1])
        else:
            mapping.append(min(distances, key=lambda pair: pair[0])[1])
    return mapping
def sort_texts_within_panel(rects):
    """Order text boxes inside a single panel.

    Boxes are sorted by distance to a reference point at the panel's top-right
    extreme (closest first). Returns the ordering as indices into `rects`.
    """
    min_top = float("inf")
    max_right = float("-inf")
    for x1, y1, x2, y2 in rects:
        min_top = min(min_top, y1)
        max_right = max(max_right, x2)

    anchor = Point(max_right, min_top)

    indexed_polygons = [(box(*rect), idx) for idx, rect in enumerate(rects)]
    # sort boxes by how close they are to the reference point
    indexed_polygons.sort(key=lambda pair: anchor.distance(pair[0]))
    return [idx for _, idx in indexed_polygons]
def x1y1wh_to_x1y1x2y2(bbox):
    """Convert [x1, y1, w, h] to corner form [x1, y1, x2, y2]."""
    x1, y1, width, height = bbox
    return [x1, y1, x1 + width, y1 + height]
def x1y1x2y2_to_xywh(bbox):
    """Convert corner form [x1, y1, x2, y2] to [x, y, w, h]."""
    left, top, right, bottom = bbox
    return [left, top, right - left, bottom - top]
def convert_to_list_of_lists(rects):
    """Normalise boxes to a plain list of 4-element lists.

    Accepts a torch tensor, a numpy array, or any iterable of 4-item
    sequences.
    """
    if isinstance(rects, (torch.Tensor, np.ndarray)):
        return rects.tolist()
    return [[x1, y1, x2, y2] for x1, y1, x2, y2 in rects]