ragavsachdeva commited on
Commit
f7499c0
1 Parent(s): 7c1657f

Upload model

Browse files
Files changed (6) hide show
  1. config.json +490 -0
  2. configuration_magiv2.py +38 -0
  3. modelling_magiv2.py +612 -0
  4. processing_magiv2.py +209 -0
  5. pytorch_model.bin +3 -0
  6. utils.py +411 -0
config.json ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/work/rs/logs/magiv2/to_release_polished",
3
+ "architectures": [
4
+ "Magiv2Model"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_magiv2.Magiv2Config",
8
+ "AutoModel": "modelling_magiv2.Magiv2Model"
9
+ },
10
+ "crop_embedding_image_preprocessing_config": {
11
+ "_processor_class": null,
12
+ "do_normalize": true,
13
+ "do_rescale": true,
14
+ "do_resize": true,
15
+ "image_mean": [
16
+ 0.485,
17
+ 0.456,
18
+ 0.406
19
+ ],
20
+ "image_processor_type": "ViTImageProcessor",
21
+ "image_std": [
22
+ 0.229,
23
+ 0.224,
24
+ 0.225
25
+ ],
26
+ "resample": 2,
27
+ "rescale_factor": 0.00392156862745098,
28
+ "size": {
29
+ "height": 224,
30
+ "width": 224
31
+ }
32
+ },
33
+ "crop_embedding_model_config": {
34
+ "_name_or_path": "facebook/vit-mae-base",
35
+ "add_cross_attention": false,
36
+ "architectures": [
37
+ "ViTMAEForPreTraining"
38
+ ],
39
+ "attention_probs_dropout_prob": 0.0,
40
+ "bad_words_ids": null,
41
+ "begin_suppress_tokens": null,
42
+ "bos_token_id": null,
43
+ "chunk_size_feed_forward": 0,
44
+ "cross_attention_hidden_size": null,
45
+ "decoder_hidden_size": 512,
46
+ "decoder_intermediate_size": 2048,
47
+ "decoder_num_attention_heads": 16,
48
+ "decoder_num_hidden_layers": 8,
49
+ "decoder_start_token_id": null,
50
+ "diversity_penalty": 0.0,
51
+ "do_sample": false,
52
+ "early_stopping": false,
53
+ "encoder_no_repeat_ngram_size": 0,
54
+ "eos_token_id": null,
55
+ "exponential_decay_length_penalty": null,
56
+ "finetuning_task": null,
57
+ "forced_bos_token_id": null,
58
+ "forced_eos_token_id": null,
59
+ "hidden_act": "gelu",
60
+ "hidden_dropout_prob": 0.0,
61
+ "hidden_size": 768,
62
+ "id2label": {
63
+ "0": "LABEL_0",
64
+ "1": "LABEL_1"
65
+ },
66
+ "image_size": 224,
67
+ "initializer_range": 0.02,
68
+ "intermediate_size": 3072,
69
+ "is_decoder": false,
70
+ "is_encoder_decoder": false,
71
+ "label2id": {
72
+ "LABEL_0": 0,
73
+ "LABEL_1": 1
74
+ },
75
+ "layer_norm_eps": 1e-12,
76
+ "length_penalty": 1.0,
77
+ "mask_ratio": 0.75,
78
+ "max_length": 20,
79
+ "min_length": 0,
80
+ "model_type": "",
81
+ "no_repeat_ngram_size": 0,
82
+ "norm_pix_loss": false,
83
+ "num_attention_heads": 12,
84
+ "num_beam_groups": 1,
85
+ "num_beams": 1,
86
+ "num_channels": 3,
87
+ "num_hidden_layers": 12,
88
+ "num_return_sequences": 1,
89
+ "output_attentions": false,
90
+ "output_hidden_states": false,
91
+ "output_scores": false,
92
+ "pad_token_id": null,
93
+ "patch_size": 16,
94
+ "prefix": null,
95
+ "problem_type": null,
96
+ "pruned_heads": {},
97
+ "qkv_bias": true,
98
+ "remove_invalid_values": false,
99
+ "repetition_penalty": 1.0,
100
+ "return_dict": true,
101
+ "return_dict_in_generate": false,
102
+ "sep_token_id": null,
103
+ "suppress_tokens": null,
104
+ "task_specific_params": null,
105
+ "temperature": 1.0,
106
+ "tf_legacy_loss": false,
107
+ "tie_encoder_decoder": false,
108
+ "tie_word_embeddings": true,
109
+ "tokenizer_class": null,
110
+ "top_k": 50,
111
+ "top_p": 1.0,
112
+ "torch_dtype": "float32",
113
+ "torchscript": false,
114
+ "typical_p": 1.0,
115
+ "use_bfloat16": false
116
+ },
117
+ "detection_image_preprocessing_config": {
118
+ "_processor_class": null,
119
+ "do_normalize": true,
120
+ "do_pad": true,
121
+ "do_rescale": true,
122
+ "do_resize": true,
123
+ "format": "coco_detection",
124
+ "image_mean": [
125
+ 0.485,
126
+ 0.456,
127
+ 0.406
128
+ ],
129
+ "image_processor_type": "ConditionalDetrImageProcessor",
130
+ "image_std": [
131
+ 0.229,
132
+ 0.224,
133
+ 0.225
134
+ ],
135
+ "resample": 2,
136
+ "rescale_factor": 0.00392156862745098,
137
+ "size": {
138
+ "longest_edge": 1333,
139
+ "shortest_edge": 800
140
+ }
141
+ },
142
+ "detection_model_config": {
143
+ "_name_or_path": "microsoft/conditional-detr-resnet-50",
144
+ "activation_dropout": 0.0,
145
+ "activation_function": "relu",
146
+ "add_cross_attention": false,
147
+ "architectures": [
148
+ "ConditionalDETRForObjectDetection"
149
+ ],
150
+ "attention_dropout": 0.0,
151
+ "auxiliary_loss": false,
152
+ "backbone": "resnet50",
153
+ "backbone_config": null,
154
+ "bad_words_ids": null,
155
+ "bbox_cost": 5,
156
+ "bbox_loss_coefficient": 5,
157
+ "begin_suppress_tokens": null,
158
+ "bos_token_id": null,
159
+ "chunk_size_feed_forward": 0,
160
+ "class_cost": 2,
161
+ "cls_loss_coefficient": 2,
162
+ "cross_attention_hidden_size": null,
163
+ "d_model": 256,
164
+ "decoder_attention_heads": 8,
165
+ "decoder_ffn_dim": 2048,
166
+ "decoder_layerdrop": 0.0,
167
+ "decoder_layers": 6,
168
+ "decoder_start_token_id": null,
169
+ "dice_loss_coefficient": 1,
170
+ "dilation": false,
171
+ "diversity_penalty": 0.0,
172
+ "do_sample": false,
173
+ "dropout": 0.1,
174
+ "early_stopping": false,
175
+ "encoder_attention_heads": 8,
176
+ "encoder_ffn_dim": 2048,
177
+ "encoder_layerdrop": 0.0,
178
+ "encoder_layers": 6,
179
+ "encoder_no_repeat_ngram_size": 0,
180
+ "eos_token_id": null,
181
+ "exponential_decay_length_penalty": null,
182
+ "finetuning_task": null,
183
+ "focal_alpha": 0.25,
184
+ "forced_bos_token_id": null,
185
+ "forced_eos_token_id": null,
186
+ "giou_cost": 2,
187
+ "giou_loss_coefficient": 2,
188
+ "id2label": {
189
+ "0": "LABEL_0",
190
+ "1": "LABEL_1",
191
+ "2": "LABEL_2",
192
+ "3": "LABEL_3"
193
+ },
194
+ "init_std": 0.02,
195
+ "init_xavier_std": 1.0,
196
+ "is_decoder": false,
197
+ "is_encoder_decoder": true,
198
+ "label2id": {
199
+ "LABEL_0": 0,
200
+ "LABEL_1": 1,
201
+ "LABEL_2": 2,
202
+ "LABEL_3": 3
203
+ },
204
+ "length_penalty": 1.0,
205
+ "mask_loss_coefficient": 1,
206
+ "max_length": 20,
207
+ "max_position_embeddings": 1024,
208
+ "min_length": 0,
209
+ "model_type": "",
210
+ "no_repeat_ngram_size": 0,
211
+ "num_beam_groups": 1,
212
+ "num_beams": 1,
213
+ "num_channels": 3,
214
+ "num_hidden_layers": 6,
215
+ "num_queries": 305,
216
+ "num_return_sequences": 1,
217
+ "output_attentions": false,
218
+ "output_hidden_states": false,
219
+ "output_scores": false,
220
+ "pad_token_id": null,
221
+ "position_embedding_type": "sine",
222
+ "prefix": null,
223
+ "problem_type": null,
224
+ "pruned_heads": {},
225
+ "remove_invalid_values": false,
226
+ "repetition_penalty": 1.0,
227
+ "return_dict": true,
228
+ "return_dict_in_generate": false,
229
+ "scale_embedding": false,
230
+ "sep_token_id": null,
231
+ "suppress_tokens": null,
232
+ "task_specific_params": null,
233
+ "temperature": 1.0,
234
+ "tf_legacy_loss": false,
235
+ "tie_encoder_decoder": false,
236
+ "tie_word_embeddings": true,
237
+ "tokenizer_class": null,
238
+ "top_k": 50,
239
+ "top_p": 1.0,
240
+ "torch_dtype": "float32",
241
+ "torchscript": false,
242
+ "typical_p": 1.0,
243
+ "use_bfloat16": false,
244
+ "use_pretrained_backbone": true,
245
+ "use_timm_backbone": true
246
+ },
247
+ "disable_crop_embeddings": false,
248
+ "disable_detections": false,
249
+ "disable_ocr": false,
250
+ "kwargs": {
251
+ "_commit_hash": null,
252
+ "_name_or_path": "ragavsachdeva/magiv2",
253
+ "architectures": [
254
+ "Magiv2Model"
255
+ ],
256
+ "auto_map": {
257
+ "AutoConfig": "configuration_magiv2.Magiv2Config",
258
+ "AutoModel": "modelling_magiv2.Magiv2Model"
259
+ },
260
+ "model_type": "magiv2",
261
+ "torch_dtype": "float32",
262
+ "transformers_version": "4.34.0.dev0"
263
+ },
264
+ "model_type": "magiv2",
265
+ "ocr_model_config": {
266
+ "_name_or_path": "microsoft/trocr-base-printed",
267
+ "add_cross_attention": false,
268
+ "architectures": [
269
+ "VisionEncoderDecoderModel"
270
+ ],
271
+ "bad_words_ids": null,
272
+ "begin_suppress_tokens": null,
273
+ "bos_token_id": null,
274
+ "chunk_size_feed_forward": 0,
275
+ "cross_attention_hidden_size": null,
276
+ "decoder": {
277
+ "_name_or_path": "",
278
+ "activation_dropout": 0.0,
279
+ "activation_function": "gelu",
280
+ "add_cross_attention": true,
281
+ "architectures": null,
282
+ "attention_dropout": 0.0,
283
+ "bad_words_ids": null,
284
+ "begin_suppress_tokens": null,
285
+ "bos_token_id": 0,
286
+ "chunk_size_feed_forward": 0,
287
+ "classifier_dropout": 0.0,
288
+ "cross_attention_hidden_size": 768,
289
+ "d_model": 1024,
290
+ "decoder_attention_heads": 16,
291
+ "decoder_ffn_dim": 4096,
292
+ "decoder_layerdrop": 0.0,
293
+ "decoder_layers": 12,
294
+ "decoder_start_token_id": 2,
295
+ "diversity_penalty": 0.0,
296
+ "do_sample": false,
297
+ "dropout": 0.1,
298
+ "early_stopping": false,
299
+ "encoder_no_repeat_ngram_size": 0,
300
+ "eos_token_id": 2,
301
+ "exponential_decay_length_penalty": null,
302
+ "finetuning_task": null,
303
+ "forced_bos_token_id": null,
304
+ "forced_eos_token_id": null,
305
+ "id2label": {
306
+ "0": "LABEL_0",
307
+ "1": "LABEL_1"
308
+ },
309
+ "init_std": 0.02,
310
+ "is_decoder": true,
311
+ "is_encoder_decoder": false,
312
+ "label2id": {
313
+ "LABEL_0": 0,
314
+ "LABEL_1": 1
315
+ },
316
+ "layernorm_embedding": true,
317
+ "length_penalty": 1.0,
318
+ "max_length": 20,
319
+ "max_position_embeddings": 512,
320
+ "min_length": 0,
321
+ "model_type": "trocr",
322
+ "no_repeat_ngram_size": 0,
323
+ "num_beam_groups": 1,
324
+ "num_beams": 1,
325
+ "num_return_sequences": 1,
326
+ "output_attentions": false,
327
+ "output_hidden_states": false,
328
+ "output_scores": false,
329
+ "pad_token_id": 1,
330
+ "prefix": null,
331
+ "problem_type": null,
332
+ "pruned_heads": {},
333
+ "remove_invalid_values": false,
334
+ "repetition_penalty": 1.0,
335
+ "return_dict": true,
336
+ "return_dict_in_generate": false,
337
+ "scale_embedding": false,
338
+ "sep_token_id": null,
339
+ "suppress_tokens": null,
340
+ "task_specific_params": null,
341
+ "temperature": 1.0,
342
+ "tf_legacy_loss": false,
343
+ "tie_encoder_decoder": false,
344
+ "tie_word_embeddings": true,
345
+ "tokenizer_class": null,
346
+ "top_k": 50,
347
+ "top_p": 1.0,
348
+ "torch_dtype": null,
349
+ "torchscript": false,
350
+ "typical_p": 1.0,
351
+ "use_bfloat16": false,
352
+ "use_cache": false,
353
+ "use_learned_position_embeddings": true,
354
+ "vocab_size": 50265
355
+ },
356
+ "decoder_start_token_id": null,
357
+ "diversity_penalty": 0.0,
358
+ "do_sample": false,
359
+ "early_stopping": false,
360
+ "encoder": {
361
+ "_name_or_path": "",
362
+ "add_cross_attention": false,
363
+ "architectures": null,
364
+ "attention_probs_dropout_prob": 0.0,
365
+ "bad_words_ids": null,
366
+ "begin_suppress_tokens": null,
367
+ "bos_token_id": null,
368
+ "chunk_size_feed_forward": 0,
369
+ "cross_attention_hidden_size": null,
370
+ "decoder_start_token_id": null,
371
+ "diversity_penalty": 0.0,
372
+ "do_sample": false,
373
+ "early_stopping": false,
374
+ "encoder_no_repeat_ngram_size": 0,
375
+ "encoder_stride": 16,
376
+ "eos_token_id": null,
377
+ "exponential_decay_length_penalty": null,
378
+ "finetuning_task": null,
379
+ "forced_bos_token_id": null,
380
+ "forced_eos_token_id": null,
381
+ "hidden_act": "gelu",
382
+ "hidden_dropout_prob": 0.0,
383
+ "hidden_size": 768,
384
+ "id2label": {
385
+ "0": "LABEL_0",
386
+ "1": "LABEL_1"
387
+ },
388
+ "image_size": 384,
389
+ "initializer_range": 0.02,
390
+ "intermediate_size": 3072,
391
+ "is_decoder": false,
392
+ "is_encoder_decoder": false,
393
+ "label2id": {
394
+ "LABEL_0": 0,
395
+ "LABEL_1": 1
396
+ },
397
+ "layer_norm_eps": 1e-12,
398
+ "length_penalty": 1.0,
399
+ "max_length": 20,
400
+ "min_length": 0,
401
+ "model_type": "vit",
402
+ "no_repeat_ngram_size": 0,
403
+ "num_attention_heads": 12,
404
+ "num_beam_groups": 1,
405
+ "num_beams": 1,
406
+ "num_channels": 3,
407
+ "num_hidden_layers": 12,
408
+ "num_return_sequences": 1,
409
+ "output_attentions": false,
410
+ "output_hidden_states": false,
411
+ "output_scores": false,
412
+ "pad_token_id": null,
413
+ "patch_size": 16,
414
+ "prefix": null,
415
+ "problem_type": null,
416
+ "pruned_heads": {},
417
+ "qkv_bias": false,
418
+ "remove_invalid_values": false,
419
+ "repetition_penalty": 1.0,
420
+ "return_dict": true,
421
+ "return_dict_in_generate": false,
422
+ "sep_token_id": null,
423
+ "suppress_tokens": null,
424
+ "task_specific_params": null,
425
+ "temperature": 1.0,
426
+ "tf_legacy_loss": false,
427
+ "tie_encoder_decoder": false,
428
+ "tie_word_embeddings": true,
429
+ "tokenizer_class": null,
430
+ "top_k": 50,
431
+ "top_p": 1.0,
432
+ "torch_dtype": null,
433
+ "torchscript": false,
434
+ "typical_p": 1.0,
435
+ "use_bfloat16": false
436
+ },
437
+ "encoder_no_repeat_ngram_size": 0,
438
+ "eos_token_id": null,
439
+ "exponential_decay_length_penalty": null,
440
+ "finetuning_task": null,
441
+ "forced_bos_token_id": null,
442
+ "forced_eos_token_id": null,
443
+ "id2label": {
444
+ "0": "LABEL_0",
445
+ "1": "LABEL_1"
446
+ },
447
+ "is_decoder": false,
448
+ "is_encoder_decoder": true,
449
+ "label2id": {
450
+ "LABEL_0": 0,
451
+ "LABEL_1": 1
452
+ },
453
+ "length_penalty": 1.0,
454
+ "max_length": 20,
455
+ "min_length": 0,
456
+ "model_type": "vision-encoder-decoder",
457
+ "no_repeat_ngram_size": 0,
458
+ "num_beam_groups": 1,
459
+ "num_beams": 1,
460
+ "num_return_sequences": 1,
461
+ "output_attentions": false,
462
+ "output_hidden_states": false,
463
+ "output_scores": false,
464
+ "pad_token_id": null,
465
+ "prefix": null,
466
+ "problem_type": null,
467
+ "pruned_heads": {},
468
+ "remove_invalid_values": false,
469
+ "repetition_penalty": 1.0,
470
+ "return_dict": true,
471
+ "return_dict_in_generate": false,
472
+ "sep_token_id": null,
473
+ "suppress_tokens": null,
474
+ "task_specific_params": null,
475
+ "temperature": 1.0,
476
+ "tf_legacy_loss": false,
477
+ "tie_encoder_decoder": false,
478
+ "tie_word_embeddings": false,
479
+ "tokenizer_class": null,
480
+ "top_k": 50,
481
+ "top_p": 1.0,
482
+ "torch_dtype": "float32",
483
+ "torchscript": false,
484
+ "typical_p": 1.0,
485
+ "use_bfloat16": false
486
+ },
487
+ "ocr_pretrained_processor_path": "microsoft/trocr-base-printed",
488
+ "torch_dtype": "float32",
489
+ "transformers_version": "4.34.0.dev0"
490
+ }
configuration_magiv2.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, VisionEncoderDecoderConfig
2
+ from typing import List
3
+
4
+
5
+ class Magiv2Config(PretrainedConfig):
6
+ model_type = "magiv2"
7
+
8
+ def __init__(
9
+ self,
10
+ disable_ocr: bool = False,
11
+ disable_crop_embeddings: bool = False,
12
+ disable_detections: bool = False,
13
+ detection_model_config: dict = None,
14
+ ocr_model_config: dict = None,
15
+ crop_embedding_model_config: dict = None,
16
+ detection_image_preprocessing_config: dict = None,
17
+ ocr_pretrained_processor_path: str = None,
18
+ crop_embedding_image_preprocessing_config: dict = None,
19
+ **kwargs,
20
+ ):
21
+ self.disable_ocr = disable_ocr
22
+ self.disable_crop_embeddings = disable_crop_embeddings
23
+ self.disable_detections = disable_detections
24
+ self.kwargs = kwargs
25
+ self.detection_model_config = None
26
+ self.ocr_model_config = None
27
+ self.crop_embedding_model_config = None
28
+ if detection_model_config is not None:
29
+ self.detection_model_config = PretrainedConfig.from_dict(detection_model_config)
30
+ if ocr_model_config is not None:
31
+ self.ocr_model_config = VisionEncoderDecoderConfig.from_dict(ocr_model_config)
32
+ if crop_embedding_model_config is not None:
33
+ self.crop_embedding_model_config = PretrainedConfig.from_dict(crop_embedding_model_config)
34
+
35
+ self.detection_image_preprocessing_config = detection_image_preprocessing_config
36
+ self.ocr_pretrained_processor_path = ocr_pretrained_processor_path
37
+ self.crop_embedding_image_preprocessing_config = crop_embedding_image_preprocessing_config
38
+ super().__init__(**kwargs)
modelling_magiv2.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel, ConditionalDetrModel
2
+ from transformers.models.conditional_detr.modeling_conditional_detr import (
3
+ ConditionalDetrMLPPredictionHead,
4
+ ConditionalDetrModelOutput,
5
+ ConditionalDetrHungarianMatcher,
6
+ inverse_sigmoid,
7
+ )
8
+ from .configuration_magiv2 import Magiv2Config
9
+ from .processing_magiv2 import Magiv2Processor
10
+ from torch import nn
11
+ from typing import Optional, List
12
+ import torch
13
+ from einops import rearrange, repeat
14
+ from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
15
+ from transformers.image_transforms import center_to_corners_format
16
+ from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order
17
+ import pulp
18
+ import scipy
19
+ import numpy as np
20
+
21
+ class Magiv2Model(PreTrainedModel):
22
+ config_class = Magiv2Config
23
+
24
+ def __init__(self, config):
25
+ super().__init__(config)
26
+ self.config = config
27
+ self.processor = Magiv2Processor(config)
28
+ if not config.disable_ocr:
29
+ self.ocr_model = VisionEncoderDecoderModel(config.ocr_model_config)
30
+ if not config.disable_crop_embeddings:
31
+ self.crop_embedding_model = ViTMAEModel(config.crop_embedding_model_config)
32
+ if not config.disable_detections:
33
+ self.num_non_obj_tokens = 5
34
+ self.detection_transformer = ConditionalDetrModel(config.detection_model_config)
35
+ self.bbox_predictor = ConditionalDetrMLPPredictionHead(
36
+ input_dim=config.detection_model_config.d_model,
37
+ hidden_dim=config.detection_model_config.d_model,
38
+ output_dim=4, num_layers=3
39
+ )
40
+ self.character_character_matching_head = ConditionalDetrMLPPredictionHead(
41
+ input_dim = 3 * config.detection_model_config.d_model + (2 * config.crop_embedding_model_config.hidden_size if not config.disable_crop_embeddings else 0),
42
+ hidden_dim=config.detection_model_config.d_model,
43
+ output_dim=1, num_layers=3
44
+ )
45
+ self.text_character_matching_head = ConditionalDetrMLPPredictionHead(
46
+ input_dim = 3 * config.detection_model_config.d_model,
47
+ hidden_dim=config.detection_model_config.d_model,
48
+ output_dim=1, num_layers=3
49
+ )
50
+ self.text_tail_matching_head = ConditionalDetrMLPPredictionHead(
51
+ input_dim = 2 * config.detection_model_config.d_model,
52
+ hidden_dim=config.detection_model_config.d_model,
53
+ output_dim=1, num_layers=3
54
+ )
55
+ self.class_labels_classifier = nn.Linear(
56
+ config.detection_model_config.d_model, config.detection_model_config.num_labels
57
+ )
58
+ self.is_this_text_a_dialogue = nn.Linear(
59
+ config.detection_model_config.d_model, 1
60
+ )
61
+ self.matcher = ConditionalDetrHungarianMatcher(
62
+ class_cost=config.detection_model_config.class_cost,
63
+ bbox_cost=config.detection_model_config.bbox_cost,
64
+ giou_cost=config.detection_model_config.giou_cost
65
+ )
66
+
67
+ def move_to_device(self, input):
68
+ return move_to_device(input, self.device)
69
+
70
+ @torch.no_grad()
71
+ def do_chapter_wide_prediction(self, pages_in_order, character_bank, eta=0.75, batch_size=8, use_tqdm=False, do_ocr=True):
72
+ texts = []
73
+ characters = []
74
+ character_clusters = []
75
+ if use_tqdm:
76
+ from tqdm import tqdm
77
+ iterator = tqdm(range(0, len(pages_in_order), batch_size))
78
+ else:
79
+ iterator = range(0, len(pages_in_order), batch_size)
80
+ per_page_results = []
81
+ for i in iterator:
82
+ pages = pages_in_order[i:i+batch_size]
83
+ results = self.predict_detections_and_associations(pages)
84
+ per_page_results.extend([result for result in results])
85
+
86
+ texts = [result["texts"] for result in per_page_results]
87
+ characters = [result["characters"] for result in per_page_results]
88
+ character_clusters = [result["character_cluster_labels"] for result in per_page_results]
89
+ assigned_character_names = self.assign_names_to_characters(pages_in_order, characters, character_bank, character_clusters, eta=eta)
90
+ if do_ocr:
91
+ ocr = self.predict_ocr(pages_in_order, texts, use_tqdm=use_tqdm)
92
+ offset_characters = 0
93
+ iteration_over = zip(per_page_results, ocr) if do_ocr else per_page_results
94
+ for iter in iteration_over:
95
+ if do_ocr:
96
+ result, ocr_for_page = iter
97
+ result["ocr"] = ocr_for_page
98
+ else:
99
+ result = iter
100
+ result["character_names"] = assigned_character_names[offset_characters:offset_characters + len(result["characters"])]
101
+ offset_characters += len(result["characters"])
102
+ return per_page_results
103
+
104
+
105
+ def assign_names_to_characters(self, images, character_bboxes, character_bank, character_clusters, eta=0.75):
106
+ chapter_wide_char_embeddings = self.predict_crop_embeddings(images, character_bboxes)
107
+ chapter_wide_char_embeddings = torch.cat(chapter_wide_char_embeddings, dim=0)
108
+ chapter_wide_char_embeddings = torch.nn.functional.normalize(chapter_wide_char_embeddings, p=2, dim=1).cpu().numpy()
109
+ # create must-link and cannot link constraints from character_clusters
110
+ must_link = []
111
+ cannot_link = []
112
+ offset = 0
113
+ for clusters_per_image in character_clusters:
114
+ for i in range(len(clusters_per_image)):
115
+ for j in range(i+1, len(clusters_per_image)):
116
+ if clusters_per_image[i] == clusters_per_image[j]:
117
+ must_link.append((offset + i, offset + j))
118
+ else:
119
+ cannot_link.append((offset + i, offset + j))
120
+ offset += len(clusters_per_image)
121
+ character_bank_for_this_chapter = self.predict_crop_embeddings(character_bank["images"], [[[0, 0, x.shape[1], x.shape[0]]] for x in character_bank["images"]])
122
+ character_bank_for_this_chapter = torch.cat(character_bank_for_this_chapter, dim=0)
123
+ character_bank_for_this_chapter = torch.nn.functional.normalize(character_bank_for_this_chapter, p=2, dim=1).cpu().numpy()
124
+ costs = scipy.spatial.distance.cdist(chapter_wide_char_embeddings, character_bank_for_this_chapter)
125
+ none_of_the_above = eta * np.ones((costs.shape[0],1))
126
+ costs = np.concatenate([costs, none_of_the_above], axis=1)
127
+ sense = pulp.LpMinimize
128
+ num_supply, num_demand = costs.shape
129
+ problem = pulp.LpProblem("Optimal_Transport_Problem", sense)
130
+ x = pulp.LpVariable.dicts("x", ((i, j) for i in range(num_supply) for j in range(num_demand)), cat='Binary')
131
+ # Objective Function to minimize
132
+ problem += pulp.lpSum([costs[i][j] * x[(i, j)] for i in range(num_supply) for j in range(num_demand)])
133
+ # each crop must be assigned to exactly one character
134
+ for i in range(num_supply):
135
+ problem += pulp.lpSum([x[(i, j)] for j in range(num_demand)]) == 1, f"Supply_{i}_Total_Assignment"
136
+ # cannot link constraints
137
+ for j in range(num_demand-1):
138
+ for (s1, s2) in cannot_link:
139
+ problem += x[(s1, j)] + x[(s2, j)] <= 1, f"Exclusion_{s1}_{s2}_Demand_{j}"
140
+ # must link constraints
141
+ for j in range(num_demand):
142
+ for (s1, s2) in must_link:
143
+ problem += x[(s1, j)] - x[(s2, j)] == 0, f"Inclusion_{s1}_{s2}_Demand_{j}"
144
+ problem.solve()
145
+ assignments = []
146
+ for v in problem.variables():
147
+ if v.varValue > 0:
148
+ index, assignment = v.name.split("(")[1].split(")")[0].split(",")
149
+ assignment = assignment[1:]
150
+ assignments.append((int(index), int(assignment)))
151
+
152
+ labels = np.zeros(num_supply)
153
+ for i, j in assignments:
154
+ labels[i] = j
155
+
156
+ return [character_bank["names"][int(i)] if i < len(character_bank["names"]) else "Other" for i in labels]
157
+
158
+
159
+ def predict_detections_and_associations(
160
+ self,
161
+ images,
162
+ move_to_device_fn=None,
163
+ character_detection_threshold=0.3,
164
+ panel_detection_threshold=0.2,
165
+ text_detection_threshold=0.3,
166
+ tail_detection_threshold=0.34,
167
+ character_character_matching_threshold=0.65,
168
+ text_character_matching_threshold=0.35,
169
+ text_tail_matching_threshold=0.3,
170
+ text_classification_threshold=0.5,
171
+ ):
172
+ assert not self.config.disable_detections
173
+ move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
174
+
175
+ inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images)
176
+ inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
177
+
178
+ detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
179
+ predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
180
+
181
+ original_image_sizes = torch.stack([torch.tensor(img.shape[:2]) for img in images], dim=0).to(predicted_bboxes.device)
182
+
183
+ batch_scores, batch_labels = predicted_class_scores.max(-1)
184
+ batch_scores = batch_scores.sigmoid()
185
+ batch_labels = batch_labels.long()
186
+ batch_bboxes = center_to_corners_format(predicted_bboxes)
187
+
188
+ # scale the bboxes back to the original image size
189
+ if isinstance(original_image_sizes, List):
190
+ img_h = torch.Tensor([i[0] for i in original_image_sizes])
191
+ img_w = torch.Tensor([i[1] for i in original_image_sizes])
192
+ else:
193
+ img_h, img_w = original_image_sizes.unbind(1)
194
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(batch_bboxes.device)
195
+ batch_bboxes = batch_bboxes * scale_fct[:, None, :]
196
+
197
+ batch_panel_indices = self.processor._get_indices_of_panels_to_keep(batch_scores, batch_labels, batch_bboxes, panel_detection_threshold)
198
+ batch_character_indices = self.processor._get_indices_of_characters_to_keep(batch_scores, batch_labels, batch_bboxes, character_detection_threshold)
199
+ batch_text_indices = self.processor._get_indices_of_texts_to_keep(batch_scores, batch_labels, batch_bboxes, text_detection_threshold)
200
+ batch_tail_indices = self.processor._get_indices_of_tails_to_keep(batch_scores, batch_labels, batch_bboxes, tail_detection_threshold)
201
+
202
+ predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
203
+ predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
204
+ predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
205
+
206
+ text_character_affinity_matrices = self._get_text_character_affinity_matrices(
207
+ character_obj_tokens_for_batch=[x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_character_indices)],
208
+ text_obj_tokens_for_this_batch=[x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_text_indices)],
209
+ t2c_tokens_for_batch=predicted_t2c_tokens_for_batch,
210
+ apply_sigmoid=True,
211
+ )
212
+
213
+ character_bboxes_in_batch = [batch_bboxes[i][j] for i, j in enumerate(batch_character_indices)]
214
+ character_character_affinity_matrices = self._get_character_character_affinity_matrices(
215
+ character_obj_tokens_for_batch=[x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_character_indices)],
216
+ crop_embeddings_for_batch=self.predict_crop_embeddings(images, character_bboxes_in_batch, move_to_device_fn),
217
+ c2c_tokens_for_batch=predicted_c2c_tokens_for_batch,
218
+ apply_sigmoid=True,
219
+ )
220
+
221
+ text_tail_affinity_matrices = self._get_text_tail_affinity_matrices(
222
+ text_obj_tokens_for_this_batch=[x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_text_indices)],
223
+ tail_obj_tokens_for_batch=[x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_tail_indices)],
224
+ apply_sigmoid=True,
225
+ )
226
+
227
+ is_this_text_a_dialogue = self._get_text_classification([x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_text_indices)])
228
+
229
+ results = []
230
+ for batch_index in range(len(batch_scores)):
231
+ panel_indices = batch_panel_indices[batch_index]
232
+ character_indices = batch_character_indices[batch_index]
233
+ text_indices = batch_text_indices[batch_index]
234
+ tail_indices = batch_tail_indices[batch_index]
235
+
236
+ character_bboxes = batch_bboxes[batch_index][character_indices]
237
+ panel_bboxes = batch_bboxes[batch_index][panel_indices]
238
+ text_bboxes = batch_bboxes[batch_index][text_indices]
239
+ tail_bboxes = batch_bboxes[batch_index][tail_indices]
240
+
241
+ local_sorted_panel_indices = sort_panels(panel_bboxes)
242
+ panel_bboxes = panel_bboxes[local_sorted_panel_indices]
243
+ local_sorted_text_indices = sort_text_boxes_in_reading_order(text_bboxes, panel_bboxes)
244
+ text_bboxes = text_bboxes[local_sorted_text_indices]
245
+
246
+ character_character_matching_scores = character_character_affinity_matrices[batch_index]
247
+ text_character_matching_scores = text_character_affinity_matrices[batch_index][local_sorted_text_indices]
248
+ text_tail_matching_scores = text_tail_affinity_matrices[batch_index][local_sorted_text_indices]
249
+
250
+ is_essential_text = is_this_text_a_dialogue[batch_index][local_sorted_text_indices] > text_classification_threshold
251
+ character_cluster_labels = UnionFind.from_adj_matrix(
252
+ character_character_matching_scores > character_character_matching_threshold
253
+ ).get_labels_for_connected_components()
254
+
255
+ if 0 in text_character_matching_scores.shape:
256
+ text_character_associations = torch.zeros((0, 2), dtype=torch.long)
257
+ else:
258
+ most_likely_speaker_for_each_text = torch.argmax(text_character_matching_scores, dim=1)
259
+ text_indices = torch.arange(len(text_bboxes)).type_as(most_likely_speaker_for_each_text)
260
+ text_character_associations = torch.stack([text_indices, most_likely_speaker_for_each_text], dim=1)
261
+ to_keep = text_character_matching_scores.max(dim=1).values > text_character_matching_threshold
262
+ text_character_associations = text_character_associations[to_keep]
263
+
264
+ if 0 in text_tail_matching_scores.shape:
265
+ text_tail_associations = torch.zeros((0, 2), dtype=torch.long)
266
+ else:
267
+ most_likely_tail_for_each_text = torch.argmax(text_tail_matching_scores, dim=1)
268
+ text_indices = torch.arange(len(text_bboxes)).type_as(most_likely_tail_for_each_text)
269
+ text_tail_associations = torch.stack([text_indices, most_likely_tail_for_each_text], dim=1)
270
+ to_keep = text_tail_matching_scores.max(dim=1).values > text_tail_matching_threshold
271
+ text_tail_associations = text_tail_associations[to_keep]
272
+
273
+ results.append({
274
+ "panels": panel_bboxes.tolist(),
275
+ "texts": text_bboxes.tolist(),
276
+ "characters": character_bboxes.tolist(),
277
+ "tails": tail_bboxes.tolist(),
278
+ "text_character_associations": text_character_associations.tolist(),
279
+ "text_tail_associations": text_tail_associations.tolist(),
280
+ "character_cluster_labels": character_cluster_labels,
281
+ "is_essential_text": is_essential_text.tolist(),
282
+ })
283
+
284
+ return results
285
+
286
+ def get_affinity_matrices_given_annotations(
287
+ self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
288
+ ):
289
+ assert not self.config.disable_detections
290
+ move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
291
+
292
+ character_bboxes_in_batch = [[bbox for bbox, label in zip(a["bboxes_as_x1y1x2y2"], a["labels"]) if label == 0] for a in annotations]
293
+ crop_embeddings_for_batch = self.predict_crop_embeddings(images, character_bboxes_in_batch, move_to_device_fn)
294
+
295
+ inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
296
+ inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
297
+ processed_targets = inputs_to_detection_transformer.pop("labels")
298
+
299
+ detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
300
+ predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
301
+ predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
302
+ predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
303
+
304
+ predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
305
+ matching_dict = {
306
+ "logits": predicted_class_scores,
307
+ "pred_boxes": predicted_bboxes,
308
+ }
309
+ indices = self.matcher(matching_dict, processed_targets)
310
+
311
+ matched_char_obj_tokens_for_batch = []
312
+ matched_text_obj_tokens_for_batch = []
313
+ matched_tail_obj_tokens_for_batch = []
314
+ t2c_tokens_for_batch = []
315
+ c2c_tokens_for_batch = []
316
+
317
+ for j, (pred_idx, tgt_idx) in enumerate(indices):
318
+ target_idx_to_pred_idx = {tgt.item(): pred.item() for pred, tgt in zip(pred_idx, tgt_idx)}
319
+ targets_for_this_image = processed_targets[j]
320
+ indices_of_text_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 1]
321
+ indices_of_char_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 0]
322
+ indices_of_tail_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 3]
323
+ predicted_text_indices = [target_idx_to_pred_idx[i] for i in indices_of_text_boxes_in_annotation]
324
+ predicted_char_indices = [target_idx_to_pred_idx[i] for i in indices_of_char_boxes_in_annotation]
325
+ predicted_tail_indices = [target_idx_to_pred_idx[i] for i in indices_of_tail_boxes_in_annotation]
326
+ matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
327
+ matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
328
+ matched_tail_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_tail_indices])
329
+ t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
330
+ c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])
331
+
332
+ text_character_affinity_matrices = self._get_text_character_affinity_matrices(
333
+ character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
334
+ text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
335
+ t2c_tokens_for_batch=t2c_tokens_for_batch,
336
+ apply_sigmoid=apply_sigmoid,
337
+ )
338
+
339
+ character_character_affinity_matrices = self._get_character_character_affinity_matrices(
340
+ character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
341
+ crop_embeddings_for_batch=crop_embeddings_for_batch,
342
+ c2c_tokens_for_batch=c2c_tokens_for_batch,
343
+ apply_sigmoid=apply_sigmoid,
344
+ )
345
+
346
+ character_character_affinity_matrices_crop_only = self._get_character_character_affinity_matrices(
347
+ character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
348
+ crop_embeddings_for_batch=crop_embeddings_for_batch,
349
+ c2c_tokens_for_batch=c2c_tokens_for_batch,
350
+ crop_only=True,
351
+ apply_sigmoid=apply_sigmoid,
352
+ )
353
+
354
+ text_tail_affinity_matrices = self._get_text_tail_affinity_matrices(
355
+ text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
356
+ tail_obj_tokens_for_batch=matched_tail_obj_tokens_for_batch,
357
+ apply_sigmoid=apply_sigmoid,
358
+ )
359
+
360
+ is_this_text_a_dialogue = self._get_text_classification(matched_text_obj_tokens_for_batch, apply_sigmoid=apply_sigmoid)
361
+
362
+ return {
363
+ "text_character_affinity_matrices": text_character_affinity_matrices,
364
+ "character_character_affinity_matrices": character_character_affinity_matrices,
365
+ "character_character_affinity_matrices_crop_only": character_character_affinity_matrices_crop_only,
366
+ "text_tail_affinity_matrices": text_tail_affinity_matrices,
367
+ "is_this_text_a_dialogue": is_this_text_a_dialogue,
368
+ }
369
+
370
+
371
+ def predict_crop_embeddings(self, images, crop_bboxes, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
372
+ if self.config.disable_crop_embeddings:
373
+ return None
374
+
375
+ assert isinstance(crop_bboxes, List), "please provide a list of bboxes for each image to get embeddings for"
376
+
377
+ move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
378
+
379
+ # temporarily change the mask ratio from default to the one specified
380
+ old_mask_ratio = self.crop_embedding_model.embeddings.config.mask_ratio
381
+ self.crop_embedding_model.embeddings.config.mask_ratio = mask_ratio
382
+
383
+ crops_per_image = []
384
+ num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
385
+ for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
386
+ crops = self.processor.crop_image(image, bboxes)
387
+ assert len(crops) == num_crops
388
+ crops_per_image.extend(crops)
389
+
390
+ if len(crops_per_image) == 0:
391
+ return [move_to_device_fn(torch.zeros(0, self.config.crop_embedding_model_config.hidden_size)) for _ in crop_bboxes]
392
+
393
+ crops_per_image = self.processor.preprocess_inputs_for_crop_embeddings(crops_per_image)
394
+ crops_per_image = move_to_device_fn(crops_per_image)
395
+
396
+ # process the crops in batches to avoid OOM
397
+ embeddings = []
398
+ for i in range(0, len(crops_per_image), batch_size):
399
+ crops = crops_per_image[i:i+batch_size]
400
+ embeddings_per_batch = self.crop_embedding_model(crops).last_hidden_state[:, 0]
401
+ embeddings.append(embeddings_per_batch)
402
+ embeddings = torch.cat(embeddings, dim=0)
403
+
404
+ crop_embeddings_for_batch = []
405
+ for num_crops in num_crops_per_batch:
406
+ crop_embeddings_for_batch.append(embeddings[:num_crops])
407
+ embeddings = embeddings[num_crops:]
408
+
409
+ # restore the mask ratio to the default
410
+ self.crop_embedding_model.embeddings.config.mask_ratio = old_mask_ratio
411
+
412
+ return crop_embeddings_for_batch
413
+
414
+ def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32):
415
+ assert not self.config.disable_ocr
416
+ move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
417
+
418
+ crops_per_image = []
419
+ num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
420
+ for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
421
+ crops = self.processor.crop_image(image, bboxes)
422
+ assert len(crops) == num_crops
423
+ crops_per_image.extend(crops)
424
+
425
+ if len(crops_per_image) == 0:
426
+ return [[] for _ in crop_bboxes]
427
+
428
+ crops_per_image = self.processor.preprocess_inputs_for_ocr(crops_per_image)
429
+ crops_per_image = move_to_device_fn(crops_per_image)
430
+
431
+ # process the crops in batches to avoid OOM
432
+ all_generated_texts = []
433
+ if use_tqdm:
434
+ from tqdm import tqdm
435
+ pbar = tqdm(range(0, len(crops_per_image), batch_size))
436
+ else:
437
+ pbar = range(0, len(crops_per_image), batch_size)
438
+ for i in pbar:
439
+ crops = crops_per_image[i:i+batch_size]
440
+ generated_ids = self.ocr_model.generate(crops)
441
+ generated_texts = self.processor.postprocess_ocr_tokens(generated_ids)
442
+ all_generated_texts.extend(generated_texts)
443
+
444
+ texts_for_images = []
445
+ for num_crops in num_crops_per_batch:
446
+ texts_for_images.append([x.replace("\n", "") for x in all_generated_texts[:num_crops]])
447
+ all_generated_texts = all_generated_texts[num_crops:]
448
+
449
+ return texts_for_images
450
+
451
+ def visualise_single_image_prediction(
452
+ self, image_as_np_array, predictions, filename=None
453
+ ):
454
+ return visualise_single_image_prediction(image_as_np_array, predictions, filename)
455
+
456
+
457
+ @torch.no_grad()
458
+ def _get_detection_transformer_output(
459
+ self,
460
+ pixel_values: torch.FloatTensor,
461
+ pixel_mask: Optional[torch.LongTensor] = None
462
+ ):
463
+ if self.config.disable_detections:
464
+ raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")
465
+ return self.detection_transformer(
466
+ pixel_values=pixel_values,
467
+ pixel_mask=pixel_mask,
468
+ return_dict=True
469
+ )
470
+
471
+ def _get_predicted_obj_tokens(
472
+ self,
473
+ detection_transformer_output: ConditionalDetrModelOutput
474
+ ):
475
+ return detection_transformer_output.last_hidden_state[:, :-self.num_non_obj_tokens]
476
+
477
+ def _get_predicted_c2c_tokens(
478
+ self,
479
+ detection_transformer_output: ConditionalDetrModelOutput
480
+ ):
481
+ return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens]
482
+
483
+ def _get_predicted_t2c_tokens(
484
+ self,
485
+ detection_transformer_output: ConditionalDetrModelOutput
486
+ ):
487
+ return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens+1]
488
+
489
+ def _get_predicted_bboxes_and_classes(
490
+ self,
491
+ detection_transformer_output: ConditionalDetrModelOutput,
492
+ ):
493
+ if self.config.disable_detections:
494
+ raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")
495
+
496
+ obj = self._get_predicted_obj_tokens(detection_transformer_output)
497
+
498
+ predicted_class_scores = self.class_labels_classifier(obj)
499
+ reference = detection_transformer_output.reference_points[:-self.num_non_obj_tokens]
500
+ reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
501
+ predicted_boxes = self.bbox_predictor(obj)
502
+ predicted_boxes[..., :2] += reference_before_sigmoid
503
+ predicted_boxes = predicted_boxes.sigmoid()
504
+
505
+ return predicted_class_scores, predicted_boxes
506
+
507
+ def _get_text_classification(
508
+ self,
509
+ text_obj_tokens_for_batch: List[torch.FloatTensor],
510
+ apply_sigmoid=False,
511
+ ):
512
+ assert not self.config.disable_detections
513
+ is_this_text_a_dialogue = []
514
+ for text_obj_tokens in text_obj_tokens_for_batch:
515
+ if text_obj_tokens.shape[0] == 0:
516
+ is_this_text_a_dialogue.append(torch.tensor([], dtype=torch.bool))
517
+ continue
518
+ classification = self.is_this_text_a_dialogue(text_obj_tokens).squeeze(-1)
519
+ if apply_sigmoid:
520
+ classification = classification.sigmoid()
521
+ is_this_text_a_dialogue.append(classification)
522
+ return is_this_text_a_dialogue
523
+
524
+ def _get_character_character_affinity_matrices(
525
+ self,
526
+ character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
527
+ crop_embeddings_for_batch: List[torch.FloatTensor] = None,
528
+ c2c_tokens_for_batch: List[torch.FloatTensor] = None,
529
+ crop_only=False,
530
+ apply_sigmoid=True,
531
+ ):
532
+ assert self.config.disable_detections or (character_obj_tokens_for_batch is not None and c2c_tokens_for_batch is not None)
533
+ assert self.config.disable_crop_embeddings or crop_embeddings_for_batch is not None
534
+ assert not self.config.disable_detections or not self.config.disable_crop_embeddings
535
+
536
+ if crop_only:
537
+ affinity_matrices = []
538
+ for crop_embeddings in crop_embeddings_for_batch:
539
+ crop_embeddings = crop_embeddings / crop_embeddings.norm(dim=-1, keepdim=True)
540
+ affinity_matrix = crop_embeddings @ crop_embeddings.T
541
+ affinity_matrices.append(affinity_matrix)
542
+ return affinity_matrices
543
+ affinity_matrices = []
544
+ for batch_index, (character_obj_tokens, c2c) in enumerate(zip(character_obj_tokens_for_batch, c2c_tokens_for_batch)):
545
+ if character_obj_tokens.shape[0] == 0:
546
+ affinity_matrices.append(torch.zeros(0, 0).type_as(character_obj_tokens))
547
+ continue
548
+ if not self.config.disable_crop_embeddings:
549
+ crop_embeddings = crop_embeddings_for_batch[batch_index]
550
+ assert character_obj_tokens.shape[0] == crop_embeddings.shape[0]
551
+ character_obj_tokens = torch.cat([character_obj_tokens, crop_embeddings], dim=-1)
552
+ char_i = repeat(character_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
553
+ char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=character_obj_tokens.shape[0])
554
+ char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
555
+ c2c = repeat(c2c, "d -> repeat d", repeat = char_ij.shape[0])
556
+ char_ij_c2c = torch.cat([char_ij, c2c], dim=-1)
557
+ character_character_affinities = self.character_character_matching_head(char_ij_c2c)
558
+ character_character_affinities = rearrange(character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0])
559
+ character_character_affinities = (character_character_affinities + character_character_affinities.T) / 2
560
+ if apply_sigmoid:
561
+ character_character_affinities = character_character_affinities.sigmoid()
562
+ affinity_matrices.append(character_character_affinities)
563
+ return affinity_matrices
564
+
565
+ def _get_text_character_affinity_matrices(
566
+ self,
567
+ character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
568
+ text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
569
+ t2c_tokens_for_batch: List[torch.FloatTensor] = None,
570
+ apply_sigmoid=True,
571
+ ):
572
+ assert not self.config.disable_detections
573
+ assert character_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None and t2c_tokens_for_batch is not None
574
+ affinity_matrices = []
575
+ for character_obj_tokens, text_obj_tokens, t2c in zip(character_obj_tokens_for_batch, text_obj_tokens_for_this_batch, t2c_tokens_for_batch):
576
+ if character_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
577
+ affinity_matrices.append(torch.zeros(text_obj_tokens.shape[0], character_obj_tokens.shape[0]).type_as(character_obj_tokens))
578
+ continue
579
+ text_i = repeat(text_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
580
+ char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=text_obj_tokens.shape[0])
581
+ text_char = rearrange([text_i, char_j], "two i j d -> (i j) (two d)")
582
+ t2c = repeat(t2c, "d -> repeat d", repeat = text_char.shape[0])
583
+ text_char_t2c = torch.cat([text_char, t2c], dim=-1)
584
+ text_character_affinities = self.text_character_matching_head(text_char_t2c)
585
+ text_character_affinities = rearrange(text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
586
+ if apply_sigmoid:
587
+ text_character_affinities = text_character_affinities.sigmoid()
588
+ affinity_matrices.append(text_character_affinities)
589
+ return affinity_matrices
590
+
591
+ def _get_text_tail_affinity_matrices(
592
+ self,
593
+ text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
594
+ tail_obj_tokens_for_batch: List[torch.FloatTensor] = None,
595
+ apply_sigmoid=True,
596
+ ):
597
+ assert not self.config.disable_detections
598
+ assert tail_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None
599
+ affinity_matrices = []
600
+ for tail_obj_tokens, text_obj_tokens in zip(tail_obj_tokens_for_batch, text_obj_tokens_for_this_batch):
601
+ if tail_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
602
+ affinity_matrices.append(torch.zeros(text_obj_tokens.shape[0], tail_obj_tokens.shape[0]).type_as(tail_obj_tokens))
603
+ continue
604
+ text_i = repeat(text_obj_tokens, "i d -> i repeat d", repeat=tail_obj_tokens.shape[0])
605
+ tail_j = repeat(tail_obj_tokens, "j d -> repeat j d", repeat=text_obj_tokens.shape[0])
606
+ text_tail = rearrange([text_i, tail_j], "two i j d -> (i j) (two d)")
607
+ text_tail_affinities = self.text_tail_matching_head(text_tail)
608
+ text_tail_affinities = rearrange(text_tail_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
609
+ if apply_sigmoid:
610
+ text_tail_affinities = text_tail_affinities.sigmoid()
611
+ affinity_matrices.append(text_tail_affinities)
612
+ return affinity_matrices
processing_magiv2.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ConditionalDetrImageProcessor, TrOCRProcessor, ViTImageProcessor
2
+ import torch
3
+ from typing import List
4
+ from shapely.geometry import box
5
+ from .utils import x1y1x2y2_to_xywh
6
+ import numpy as np
7
+
8
+ class Magiv2Processor():
9
+ def __init__(self, config):
10
+ self.config = config
11
+ self.detection_image_preprocessor = None
12
+ self.ocr_preprocessor = None
13
+ self.crop_embedding_image_preprocessor = None
14
+ if not config.disable_detections:
15
+ assert config.detection_image_preprocessing_config is not None
16
+ self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(config.detection_image_preprocessing_config)
17
+ if not config.disable_ocr:
18
+ assert config.ocr_pretrained_processor_path is not None
19
+ self.ocr_preprocessor = TrOCRProcessor.from_pretrained(config.ocr_pretrained_processor_path)
20
+ if not config.disable_crop_embeddings:
21
+ assert config.crop_embedding_image_preprocessing_config is not None
22
+ self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(config.crop_embedding_image_preprocessing_config)
23
+
24
+ def preprocess_inputs_for_detection(self, images, annotations=None):
25
+ images = list(images)
26
+ assert isinstance(images[0], np.ndarray)
27
+ annotations = self._convert_annotations_to_coco_format(annotations)
28
+ inputs = self.detection_image_preprocessor(images, annotations=annotations, return_tensors="pt")
29
+ return inputs
30
+
31
+ def preprocess_inputs_for_ocr(self, images):
32
+ images = list(images)
33
+ assert isinstance(images[0], np.ndarray)
34
+ return self.ocr_preprocessor(images, return_tensors="pt").pixel_values
35
+
36
+ def preprocess_inputs_for_crop_embeddings(self, images):
37
+ images = list(images)
38
+ assert isinstance(images[0], np.ndarray)
39
+ return self.crop_embedding_image_preprocessor(images, return_tensors="pt").pixel_values
40
+
41
+ def postprocess_ocr_tokens(self, generated_ids, skip_special_tokens=True):
42
+ return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)
43
+
44
+ def crop_image(self, image, bboxes):
45
+ crops_for_image = []
46
+ for bbox in bboxes:
47
+ x1, y1, x2, y2 = bbox
48
+
49
+ # fix the bounding box in case it is out of bounds or too small
50
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
51
+ x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2) # just incase
52
+ x1, y1 = max(0, x1), max(0, y1)
53
+ x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1)
54
+ x2, y2 = max(0, x2), max(0, y2)
55
+ x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)
56
+ if x2 - x1 < 10:
57
+ if image.shape[1] - x1 > 10:
58
+ x2 = x1 + 10
59
+ else:
60
+ x1 = x2 - 10
61
+ if y2 - y1 < 10:
62
+ if image.shape[0] - y1 > 10:
63
+ y2 = y1 + 10
64
+ else:
65
+ y1 = y2 - 10
66
+
67
+ crop = image[y1:y2, x1:x2]
68
+ crops_for_image.append(crop)
69
+ return crops_for_image
70
+
71
+ def _get_indices_of_characters_to_keep(self, batch_scores, batch_labels, batch_bboxes, character_detection_threshold):
72
+ indices_of_characters_to_keep = []
73
+ for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes):
74
+ indices = torch.where((labels == 0) & (scores > character_detection_threshold))[0]
75
+ indices_of_characters_to_keep.append(indices)
76
+ return indices_of_characters_to_keep
77
+
78
+ def _get_indices_of_panels_to_keep(self, batch_scores, batch_labels, batch_bboxes, panel_detection_threshold):
79
+ indices_of_panels_to_keep = []
80
+ for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
81
+ indices = torch.where(labels == 2)[0]
82
+ bboxes = bboxes[indices]
83
+ scores = scores[indices]
84
+ labels = labels[indices]
85
+ if len(indices) == 0:
86
+ indices_of_panels_to_keep.append([])
87
+ continue
88
+ scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
89
+ panels_to_keep = []
90
+ union_of_panels_so_far = box(0, 0, 0, 0)
91
+ for ps, pb, pl, pi in zip(scores, bboxes, labels, indices):
92
+ panel_polygon = box(pb[0], pb[1], pb[2], pb[3])
93
+ if ps < panel_detection_threshold:
94
+ continue
95
+ if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5:
96
+ continue
97
+ panels_to_keep.append((ps, pl, pb, pi))
98
+ union_of_panels_so_far = union_of_panels_so_far.union(panel_polygon)
99
+ indices_of_panels_to_keep.append([p[3].item() for p in panels_to_keep])
100
+ return indices_of_panels_to_keep
101
+
102
+ def _get_indices_of_texts_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
103
+ indices_of_texts_to_keep = []
104
+ for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
105
+ indices = torch.where((labels == 1) & (scores > text_detection_threshold))[0]
106
+ bboxes = bboxes[indices]
107
+ scores = scores[indices]
108
+ labels = labels[indices]
109
+ if len(indices) == 0:
110
+ indices_of_texts_to_keep.append([])
111
+ continue
112
+ scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
113
+ texts_to_keep = []
114
+ texts_to_keep_as_shapely_objects = []
115
+ for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
116
+ text_polygon = box(tb[0], tb[1], tb[2], tb[3])
117
+ should_append = True
118
+ for t in texts_to_keep_as_shapely_objects:
119
+ if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5:
120
+ should_append = False
121
+ break
122
+ if should_append:
123
+ texts_to_keep.append((ts, tl, tb, ti))
124
+ texts_to_keep_as_shapely_objects.append(text_polygon)
125
+ indices_of_texts_to_keep.append([t[3].item() for t in texts_to_keep])
126
+ return indices_of_texts_to_keep
127
+
128
+ def _get_indices_of_tails_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
129
+ indices_of_texts_to_keep = []
130
+ for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
131
+ indices = torch.where((labels == 3) & (scores > text_detection_threshold))[0]
132
+ bboxes = bboxes[indices]
133
+ scores = scores[indices]
134
+ labels = labels[indices]
135
+ if len(indices) == 0:
136
+ indices_of_texts_to_keep.append([])
137
+ continue
138
+ scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
139
+ texts_to_keep = []
140
+ texts_to_keep_as_shapely_objects = []
141
+ for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
142
+ text_polygon = box(tb[0], tb[1], tb[2], tb[3])
143
+ should_append = True
144
+ for t in texts_to_keep_as_shapely_objects:
145
+ if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5:
146
+ should_append = False
147
+ break
148
+ if should_append:
149
+ texts_to_keep.append((ts, tl, tb, ti))
150
+ texts_to_keep_as_shapely_objects.append(text_polygon)
151
+ indices_of_texts_to_keep.append([t[3].item() for t in texts_to_keep])
152
+ return indices_of_texts_to_keep
153
+
154
+ def _convert_annotations_to_coco_format(self, annotations):
155
+ if annotations is None:
156
+ return None
157
+ self._verify_annotations_are_in_correct_format(annotations)
158
+ coco_annotations = []
159
+ for annotation in annotations:
160
+ coco_annotation = {
161
+ "image_id": annotation["image_id"],
162
+ "annotations": [],
163
+ }
164
+ for bbox, label in zip(annotation["bboxes_as_x1y1x2y2"], annotation["labels"]):
165
+ coco_annotation["annotations"].append({
166
+ "bbox": x1y1x2y2_to_xywh(bbox),
167
+ "category_id": label,
168
+ "area": (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]),
169
+ })
170
+ coco_annotations.append(coco_annotation)
171
+ return coco_annotations
172
+
173
+ def _verify_annotations_are_in_correct_format(self, annotations):
174
+ error_msg = """
175
+ Annotations must be in the following format:
176
+ [
177
+ {
178
+ "image_id": 0,
179
+ "bboxes_as_x1y1x2y2": [[0, 0, 10, 10], [10, 10, 20, 20], [20, 20, 30, 30]],
180
+ "labels": [0, 1, 2],
181
+ },
182
+ ...
183
+ ]
184
+ Labels: 0 for characters, 1 for text, 2 for panels.
185
+ """
186
+ if annotations is None:
187
+ return
188
+ if not isinstance(annotations, (list, tuple)):
189
+ raise ValueError(
190
+ f"{error_msg} Expected a List/Tuple, found {type(annotations)}."
191
+ )
192
+ if len(annotations) == 0:
193
+ return
194
+ if not isinstance(annotations[0], dict):
195
+ raise ValueError(
196
+ f"{error_msg} Expected a List[Dicct], found {type(annotations[0])}."
197
+ )
198
+ if "image_id" not in annotations[0]:
199
+ raise ValueError(
200
+ f"{error_msg} Dict must contain 'image_id'."
201
+ )
202
+ if "bboxes_as_x1y1x2y2" not in annotations[0]:
203
+ raise ValueError(
204
+ f"{error_msg} Dict must contain 'bboxes_as_x1y1x2y2'."
205
+ )
206
+ if "labels" not in annotations[0]:
207
+ raise ValueError(
208
+ f"{error_msg} Dict must contain 'labels'."
209
+ )
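As a reference for the expected shapes, an annotation dict in the format described by the error message above converts to COCO-style entries by turning each x1y1x2y2 box into xywh and attaching its area; the values below are purely illustrative:

annotations = [
    {
        "image_id": 0,
        "bboxes_as_x1y1x2y2": [[0, 0, 10, 10], [10, 10, 20, 20]],
        "labels": [0, 1],  # 0 = character, 1 = text
    },
]
# _convert_annotations_to_coco_format(annotations) would yield:
# [{"image_id": 0,
#   "annotations": [{"bbox": [0, 0, 10, 10], "category_id": 0, "area": 100},
#                   {"bbox": [10, 10, 10, 10], "category_id": 1, "area": 100}]}]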
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56392403204d3a4cca38694a3a260a6929d741869d802d6b14de35b4eab4c4b8
3
+ size 2063693064
utils.py ADDED
@@ -0,0 +1,411 @@
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.patches as patches
6
+ from shapely.geometry import Point, box
7
+ import networkx as nx
8
+ from copy import deepcopy
9
+ from itertools import groupby
10
+
11
+ def move_to_device(inputs, device):
12
+ if hasattr(inputs, "keys"):
13
+ return {k: move_to_device(v, device) for k, v in inputs.items()}
14
+ elif isinstance(inputs, list):
15
+ return [move_to_device(v, device) for v in inputs]
16
+ elif isinstance(inputs, tuple):
17
+ return tuple([move_to_device(v, device) for v in inputs])
18
+ elif isinstance(inputs, np.ndarray):
19
+ return torch.from_numpy(inputs).to(device)
20
+ else:
21
+ return inputs.to(device)
22
+
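move_to_device recurses through dicts, lists and tuples, converting numpy arrays to tensors along the way; a small usage sketch (the batch contents and device are illustrative):

import numpy as np
import torch

batch = {
    "images": np.zeros((2, 3, 224, 224), dtype=np.float32),
    "masks": [torch.ones(2, 10), torch.zeros(2, 10)],
}
batch = move_to_device(batch, torch.device("cpu"))
# every leaf is now a torch.Tensor on the requested device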
23
+ class UnionFind:
24
+ def __init__(self, n):
25
+ self.parent = list(range(n))
26
+ self.size = [1] * n
27
+ self.num_components = n
28
+
29
+ @classmethod
30
+ def from_adj_matrix(cls, adj_matrix):
31
+ ufds = cls(adj_matrix.shape[0])
32
+ for i in range(adj_matrix.shape[0]):
33
+ for j in range(adj_matrix.shape[1]):
34
+ if adj_matrix[i, j] > 0:
35
+ ufds.unite(i, j)
36
+ return ufds
37
+
38
+ @classmethod
39
+ def from_adj_list(cls, adj_list):
40
+ ufds = cls(len(adj_list))
41
+ for i in range(len(adj_list)):
42
+ for j in adj_list[i]:
43
+ ufds.unite(i, j)
44
+ return ufds
45
+
46
+ @classmethod
47
+ def from_edge_list(cls, edge_list, num_nodes):
48
+ ufds = cls(num_nodes)
49
+ for edge in edge_list:
50
+ ufds.unite(edge[0], edge[1])
51
+ return ufds
52
+
53
+ def find(self, x):
54
+ if self.parent[x] == x:
55
+ return x
56
+ self.parent[x] = self.find(self.parent[x])
57
+ return self.parent[x]
58
+
59
+ def unite(self, x, y):
60
+ x = self.find(x)
61
+ y = self.find(y)
62
+ if x != y:
63
+ if self.size[x] < self.size[y]:
64
+ x, y = y, x
65
+ self.parent[y] = x
66
+ self.size[x] += self.size[y]
67
+ self.num_components -= 1
68
+
69
+ def get_components_of(self, x):
70
+ x = self.find(x)
71
+ return [i for i in range(len(self.parent)) if self.find(i) == x]
72
+
73
+ def are_connected(self, x, y):
74
+ return self.find(x) == self.find(y)
75
+
76
+ def get_size(self, x):
77
+ return self.size[self.find(x)]
78
+
79
+ def get_num_components(self):
80
+ return self.num_components
81
+
82
+ def get_labels_for_connected_components(self):
83
+ map_parent_to_label = {}
84
+ labels = []
85
+ for i in range(len(self.parent)):
86
+ parent = self.find(i)
87
+ if parent not in map_parent_to_label:
88
+ map_parent_to_label[parent] = len(map_parent_to_label)
89
+ labels.append(map_parent_to_label[parent])
90
+ return labels
91
+
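A quick usage sketch of the UnionFind helper, building it from a made-up edge list and reading off per-node component labels:

ufds = UnionFind.from_edge_list([(0, 1), (2, 3)], num_nodes=5)
ufds.are_connected(0, 1)                    # True
ufds.get_num_components()                   # 3, i.e. {0, 1}, {2, 3}, {4}
ufds.get_labels_for_connected_components()  # [0, 0, 1, 1, 2]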
92
+ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
93
+ figure, subplot = plt.subplots(1, 1, figsize=(10, 10))
94
+ subplot.imshow(image_as_np_array)
95
+ plot_bboxes(subplot, predictions["panels"], color="green")
96
+ plot_bboxes(subplot, predictions["texts"], color="red", visibility=predictions["is_essential_text"])
97
+ plot_bboxes(subplot, predictions["characters"], color="blue")
98
+ plot_bboxes(subplot, predictions["tails"], color="purple")
99
+
100
+ for i, name in enumerate(predictions["character_names"]):
101
+ char_bbox = predictions["characters"][i]
102
+ x1, y1, x2, y2 = char_bbox
103
+ subplot.text(x1, y1 - 2, name,
104
+ verticalalignment='bottom', horizontalalignment='left',
105
+ bbox=dict(facecolor='blue', alpha=1, edgecolor='none'), # Background settings
106
+ color='white', fontsize=8)
107
+
108
+ COLOURS = [
109
+ "#b7ff51", # green
110
+ "#f50a8f", # pink
111
+ "#4b13b6", # purple
112
+ "#ddaa34", # orange
113
+ "#bea2a2", # brown
114
+ ]
115
+ colour_index = 0
116
+ character_cluster_labels = predictions["character_cluster_labels"]
117
+ unique_label_sorted_by_frequency = sorted(list(set(character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
118
+ for label in unique_label_sorted_by_frequency:
119
+ root = None
120
+ others = []
121
+ for i in range(len(predictions["characters"])):
122
+ if character_cluster_labels[i] == label:
123
+ if root is None:
124
+ root = i
125
+ else:
126
+ others.append(i)
127
+ if colour_index >= len(COLOURS):
128
+ random_colour = COLOURS[0]
129
+ while random_colour in COLOURS:
130
+ random_colour = "#" + "".join([random.choice("0123456789ABCDEF") for j in range(6)])
131
+ else:
132
+ random_colour = COLOURS[colour_index]
133
+ colour_index += 1
134
+ bbox_i = predictions["characters"][root]
135
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
136
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
137
+ subplot.plot([x1], [y1], color=random_colour, marker="o", markersize=5)
138
+ for j in others:
139
+ # draw line from centre of bbox i to centre of bbox j
140
+ bbox_j = predictions["characters"][j]
141
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
142
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
143
+ x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
144
+ y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
145
+ subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
146
+ subplot.plot([x2], [y2], color=random_colour, marker="o", markersize=5)
147
+
148
+ for (i, j) in predictions["text_character_associations"]:
149
+ bbox_i = predictions["texts"][i]
150
+ bbox_j = predictions["characters"][j]
151
+ if not predictions["is_essential_text"][i]:
152
+ continue
153
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
154
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
155
+ x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
156
+ y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
157
+ subplot.plot([x1, x2], [y1, y2], color="red", linewidth=2, linestyle="dashed")
158
+
159
+ for (i, j) in predictions["text_tail_associations"]:
160
+ bbox_i = predictions["texts"][i]
161
+ bbox_j = predictions["tails"][j]
162
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
163
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
164
+ x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
165
+ y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
166
+ subplot.plot([x1, x2], [y1, y2], color="purple", linewidth=2, linestyle="dashed")
167
+
168
+ subplot.axis("off")
169
+ if filename is not None:
170
+ plt.savefig(filename, bbox_inches="tight", pad_inches=0)
171
+
172
+ figure.canvas.draw()
173
+ image = np.array(figure.canvas.renderer._renderer)
174
+ plt.close()
175
+ return image
176
+
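The visualiser expects a predictions dict containing every key it indexes above; a minimal illustrative call (all boxes, names and associations are dummy values, and converting the figure to an array may depend on the matplotlib backend in use):

import numpy as np

page = np.full((512, 512, 3), 255, dtype=np.uint8)  # blank white "page"
predictions = {
    "panels": [[0, 0, 256, 512]],
    "texts": [[20, 20, 120, 60]],
    "characters": [[30, 100, 200, 400]],
    "tails": [[110, 55, 130, 75]],
    "is_essential_text": [1],
    "character_names": ["Character A"],
    "character_cluster_labels": [0],
    "text_character_associations": [(0, 0)],
    "text_tail_associations": [(0, 0)],
}
overlay = visualise_single_image_prediction(page, predictions, filename=None)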
177
+ def plot_bboxes(subplot, bboxes, color="red", visibility=None):
178
+ if visibility is None:
179
+ visibility = [1] * len(bboxes)
180
+ for id, bbox in enumerate(bboxes):
181
+ if visibility[id] == 0:
182
+ continue
183
+ w = bbox[2] - bbox[0]
184
+ h = bbox[3] - bbox[1]
185
+ rect = patches.Rectangle(
186
+ bbox[:2], w, h, linewidth=1, edgecolor=color, facecolor="none", linestyle="solid"
187
+ )
188
+ subplot.add_patch(rect)
189
+
190
+ def sort_panels(rects):
191
+ before_rects = convert_to_list_of_lists(rects)
192
+ # slightly erode all rectangles initially to account for imperfect detections
193
+ rects = [erode_rectangle(rect, 0.05) for rect in before_rects]
194
+ G = nx.DiGraph()
195
+ G.add_nodes_from(range(len(rects)))
196
+ for i in range(len(rects)):
197
+ for j in range(len(rects)):
198
+ if i == j:
199
+ continue
200
+ if is_there_a_directed_edge(i, j, rects):
201
+ G.add_edge(i, j, weight=get_distance(rects[i], rects[j]))
202
+ else:
203
+ G.add_edge(j, i, weight=get_distance(rects[i], rects[j]))
204
+ while True:
205
+ cycles = sorted(nx.simple_cycles(G))
206
+ cycles = [cycle for cycle in cycles if len(cycle) > 1]
207
+ if len(cycles) == 0:
208
+ break
209
+ cycle = cycles[0]
210
+ edges = [e for e in zip(cycle, cycle[1:] + cycle[:1])]
211
+ max_cyclic_edge = max(edges, key=lambda x: G.edges[x]["weight"])
212
+ G.remove_edge(*max_cyclic_edge)
213
+ return list(nx.topological_sort(G))
214
+
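sort_panels builds a directed "reads-before" graph over the panel boxes, breaks any cycle by removing its longest-distance edge, and returns a topological order; for typical manga layouts this runs right-to-left, top-to-bottom. An illustrative call on a simple 2x2 grid:

panel_boxes = [
    [0, 0, 100, 100],      # top-left
    [100, 0, 200, 100],    # top-right
    [0, 100, 100, 200],    # bottom-left
    [100, 100, 200, 200],  # bottom-right
]
reading_order = sort_panels(panel_boxes)
# with the edge rules below this comes out as [1, 0, 3, 2]:
# top-right, top-left, bottom-right, bottom-left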
215
+ def is_strictly_above(rectA, rectB):
216
+ x1A, y1A, x2A, y2A = rectA
217
+ x1B, y1B, x2B, y2B = rectB
218
+ return y2A < y1B
219
+
220
+ def is_strictly_below(rectA, rectB):
221
+ x1A, y1A, x2A, y2A = rectA
222
+ x1B, y1B, x2B, y2B = rectB
223
+ return y2B < y1A
224
+
225
+ def is_strictly_left_of(rectA, rectB):
226
+ x1A, y1A, x2A, y2A = rectA
227
+ x1B, y1B, x2B, y2B = rectB
228
+ return x2A < x1B
229
+
230
+ def is_strictly_right_of(rectA, rectB):
231
+ x1A, y1A, x2A, y2A = rectA
232
+ x1B, y1B, x2B, y2B = rectB
233
+ return x2B < x1A
234
+
235
+ def intersects(rectA, rectB):
236
+ return box(*rectA).intersects(box(*rectB))
237
+
238
+ def is_there_a_directed_edge(a, b, rects):
239
+ rectA = rects[a]
240
+ rectB = rects[b]
241
+ centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2, rectA[1] + (rectA[3] - rectA[1]) / 2]
242
+ centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2, rectB[1] + (rectB[3] - rectB[1]) / 2]
243
+ if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
244
+ return box(*rectA).area > (box(*rectB)).area
245
+ copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
246
+ copy_B = [rectB[0], rectB[1], rectB[2], rectB[3]]
247
+ while True:
248
+ if is_strictly_above(copy_A, copy_B) and not is_strictly_left_of(copy_A, copy_B):
249
+ return 1
250
+ if is_strictly_above(copy_B, copy_A) and not is_strictly_left_of(copy_B, copy_A):
251
+ return 0
252
+ if is_strictly_right_of(copy_A, copy_B) and not is_strictly_below(copy_A, copy_B):
253
+ return 1
254
+ if is_strictly_right_of(copy_B, copy_A) and not is_strictly_below(copy_B, copy_A):
255
+ return 0
256
+ if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
257
+ return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
258
+ if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
259
+ return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
260
+ # otherwise they intersect
261
+ copy_A = erode_rectangle(copy_A, 0.05)
262
+ copy_B = erode_rectangle(copy_B, 0.05)
263
+
264
+ def get_distance(rectA, rectB):
265
+ return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))
266
+
267
+ def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
268
+ rects = deepcopy(rects)
269
+ while True:
270
+ xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
271
+ rect_index = [i for i in range(len(rects)) if intersects(rects[i], [xmin, ymin, xmax, ymax])]
272
+ rects_copy = [rect for rect in rects if intersects(rect, [xmin, ymin, xmax, ymax])]
273
+
274
+ # try to split the panels using a "horizontal" lines
275
+ overlapping_y_ranges = merge_overlapping_ranges([(y1, y2) for x1, y1, x2, y2 in rects_copy])
276
+ panel_index_to_split = {}
277
+ for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
278
+ for i, index in enumerate(rect_index):
279
+ if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
280
+ panel_index_to_split[index] = split_index
281
+
282
+ if panel_index_to_split[a] != panel_index_to_split[b]:
283
+ return panel_index_to_split[a] < panel_index_to_split[b]
284
+
285
+ # try to split the panels using a "vertical" lines
286
+ overlapping_x_ranges = merge_overlapping_ranges([(x1, x2) for x1, y1, x2, y2 in rects_copy])
287
+ panel_index_to_split = {}
288
+ for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
289
+ for i, index in enumerate(rect_index):
290
+ if x1 <= rects_copy[i][0] <= rects_copy[i][2] <= x2:
291
+ panel_index_to_split[index] = split_index
292
+ if panel_index_to_split[a] != panel_index_to_split[b]:
293
+ return panel_index_to_split[a] < panel_index_to_split[b]
294
+
295
+ # otherwise, erode the rectangles and try again
296
+ rects = [erode_rectangle(rect, 0.05) for rect in rects]
297
+
298
+ def erode_rectangle(bbox, erosion_factor):
299
+ x1, y1, x2, y2 = bbox
300
+ w, h = x2 - x1, y2 - y1
301
+ cx, cy = x1 + w / 2, y1 + h / 2
302
+ if w < h:
303
+ aspect_ratio = w / h
304
+ erosion_factor_width = erosion_factor * aspect_ratio
305
+ erosion_factor_height = erosion_factor
306
+ else:
307
+ aspect_ratio = h / w
308
+ erosion_factor_width = erosion_factor
309
+ erosion_factor_height = erosion_factor * aspect_ratio
310
+ w = w - w * erosion_factor_width
311
+ h = h - h * erosion_factor_height
312
+ x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
313
+ return [x1, y1, x2, y2]
314
+
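erode_rectangle shrinks a box about its centre, scaling the erosion factor on the shorter side by the aspect ratio so the narrow dimension is eroded proportionally less; a quick worked example:

erode_rectangle([0, 0, 100, 200], 0.1)
# w=100, h=200, centre=(50, 100); width shrinks by 5%, height by 10%
# -> [2.5, 10.0, 97.5, 190.0]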
315
+ def merge_overlapping_ranges(ranges):
316
+ """
317
+ ranges: list of tuples (x1, x2)
318
+ """
319
+ if len(ranges) == 0:
320
+ return []
321
+ ranges = sorted(ranges, key=lambda x: x[0])
322
+ merged_ranges = []
323
+ for i, r in enumerate(ranges):
324
+ if i == 0:
325
+ prev_x1, prev_x2 = r
326
+ continue
327
+ x1, x2 = r
328
+ if x1 > prev_x2:
329
+ merged_ranges.append((prev_x1, prev_x2))
330
+ prev_x1, prev_x2 = x1, x2
331
+ else:
332
+ prev_x2 = max(prev_x2, x2)
333
+ merged_ranges.append((prev_x1, prev_x2))
334
+ return merged_ranges
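merge_overlapping_ranges is the 1-D interval merge used by the cut-based tie-breaker above; for example:

merge_overlapping_ranges([(0, 5), (4, 9), (12, 15)])
# -> [(0, 9), (12, 15)]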
335
+
336
+ def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
337
+ text_bboxes = convert_to_list_of_lists(text_bboxes)
338
+ sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
339
+
340
+ if len(text_bboxes) == 0:
341
+ return []
342
+
343
+ def indices_of_same_elements(nums):
344
+ groups = groupby(range(len(nums)), key=lambda i: nums[i])
345
+ return [list(indices) for _, indices in groups]
346
+
347
+ panel_id_for_text = get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes)
348
+ indices_of_texts = list(range(len(text_bboxes)))
349
+ indices_of_texts, panel_id_for_text = zip(*sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
350
+ indices_of_texts = list(indices_of_texts)
351
+ grouped_indices = indices_of_same_elements(panel_id_for_text)
352
+ for group in grouped_indices:
353
+ subset_of_text_indices = [indices_of_texts[i] for i in group]
354
+ text_bboxes_of_subset = [text_bboxes[i] for i in subset_of_text_indices]
355
+ sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
356
+ indices_of_texts[group[0] : group[-1] + 1] = [subset_of_text_indices[i] for i in sorted_subset_indices]
357
+ return indices_of_texts
358
+
359
+ def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
360
+ text_to_panel_mapping = []
361
+ for text_bbox in text_bboxes:
362
+ shapely_text_polygon = box(*text_bbox)
363
+ all_intersections = []
364
+ all_distances = []
365
+ if len(sorted_panel_bboxes) == 0:
366
+ text_to_panel_mapping.append(-1)
367
+ continue
368
+ for j, annotation in enumerate(sorted_panel_bboxes):
369
+ shapely_annotation_polygon = box(*annotation)
370
+ if shapely_text_polygon.intersects(shapely_annotation_polygon):
371
+ all_intersections.append((shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
372
+ all_distances.append((shapely_text_polygon.distance(shapely_annotation_polygon), j))
373
+ if len(all_intersections) == 0:
374
+ text_to_panel_mapping.append(min(all_distances, key=lambda x: x[0])[1])
375
+ else:
376
+ text_to_panel_mapping.append(max(all_intersections, key=lambda x: x[0])[1])
377
+ return text_to_panel_mapping
378
+
379
+ def sort_texts_within_panel(rects):
380
+ smallest_y = float("inf")
381
+ greatest_x = float("-inf")
382
+ for i, rect in enumerate(rects):
383
+ x1, y1, x2, y2 = rect
384
+ smallest_y = min(smallest_y, y1)
385
+ greatest_x = max(greatest_x, x2)
386
+
387
+ reference_point = Point(greatest_x, smallest_y)
388
+
389
+ polygons_and_index = []
390
+ for i, rect in enumerate(rects):
391
+ x1, y1, x2, y2 = rect
392
+ polygons_and_index.append((box(x1,y1,x2,y2), i))
393
+ # sort text boxes by their distance to the top-right reference point
394
+ polygons_and_index = sorted(polygons_and_index, key=lambda x: reference_point.distance(x[0]))
395
+ indices = [x[1] for x in polygons_and_index]
396
+ return indices
397
+
398
+ def x1y1wh_to_x1y1x2y2(bbox):
399
+ x1, y1, w, h = bbox
400
+ return [x1, y1, x1 + w, y1 + h]
401
+
402
+ def x1y1x2y2_to_xywh(bbox):
403
+ x1, y1, x2, y2 = bbox
404
+ return [x1, y1, x2 - x1, y2 - y1]
405
+
406
+ def convert_to_list_of_lists(rects):
407
+ if isinstance(rects, torch.Tensor):
408
+ return rects.tolist()
409
+ if isinstance(rects, np.ndarray):
410
+ return rects.tolist()
411
+ return [[a, b, c, d] for a, b, c, d in rects]
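The two bbox conversions above are inverses of each other; a small sanity check:

x1y1x2y2_to_xywh([10, 20, 50, 80])    # -> [10, 20, 40, 60]
x1y1wh_to_x1y1x2y2([10, 20, 40, 60])  # -> [10, 20, 50, 80]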