Commit
·
819fc2b
1
Parent(s):
c35bb2c
Update files/functions.py
Browse files- files/functions.py +74 -35
files/functions.py
CHANGED
@@ -70,36 +70,36 @@ label2color = {
|
|
70 |
|
71 |
# bounding boxes start and end of a sequence
|
72 |
cls_box = [0, 0, 0, 0]
|
|
|
|
|
73 |
sep_box_lilt = cls_box
|
|
|
|
|
74 |
sep_box_layoutxlm = [1000, 1000, 1000, 1000]
|
|
|
75 |
|
76 |
# models
|
77 |
model_id_lilt = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-paragraphlevel-ml512"
|
|
|
78 |
model_id_layoutxlm = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-paragraphlevel-ml512"
|
|
|
79 |
|
80 |
# tokenizer for LayoutXLM
|
81 |
tokenizer_id_layoutxlm = "xlm-roberta-base"
|
82 |
|
83 |
# (tokenization) The maximum length of a feature (sequence)
|
84 |
-
if str(384) in model_id_lilt:
|
85 |
-
|
86 |
-
elif str(512) in model_id_lilt:
|
87 |
-
|
88 |
-
else:
|
89 |
-
print("Error with max_length_lilt of chunks!")
|
90 |
-
|
91 |
-
if str(384) in model_id_layoutxlm:
|
92 |
-
max_length_layoutxlm = 384
|
93 |
-
elif str(512) in model_id_layoutxlm:
|
94 |
-
max_length_layoutxlm = 512
|
95 |
else:
|
96 |
-
print("Error with
|
97 |
|
98 |
# (tokenization) overlap
|
99 |
doc_stride = 128 # The authorized overlap between two part of the context when splitting it is needed.
|
100 |
|
101 |
# max PDF page images that will be displayed
|
102 |
-
max_imgboxes =
|
103 |
|
104 |
# get files
|
105 |
examples_dir = 'files/'
|
@@ -159,6 +159,9 @@ tokenizer_lilt = AutoTokenizer.from_pretrained(model_id_lilt)
|
|
159 |
model_lilt = AutoModelForTokenClassification.from_pretrained(model_id_lilt);
|
160 |
model_lilt.to(device);
|
161 |
|
|
|
|
|
|
|
162 |
## model LayoutXLM
|
163 |
from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast,
|
164 |
model_layoutxlm = LayoutLMv2ForTokenClassification.from_pretrained(model_id_layoutxlm);
|
@@ -172,14 +175,8 @@ feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
|
|
172 |
from transformers import AutoTokenizer
|
173 |
tokenizer_layoutxlm = AutoTokenizer.from_pretrained(tokenizer_id_layoutxlm)
|
174 |
|
175 |
-
|
176 |
-
|
177 |
-
label2id_lilt = model_lilt.config.label2id
|
178 |
-
num_labels_lilt = len(id2label_lilt)
|
179 |
-
|
180 |
-
id2label_layoutxlm = model_layoutxlm.config.id2label
|
181 |
-
label2id_layoutxlm = model_layoutxlm.config.label2id
|
182 |
-
num_labels_layoutxlm = len(id2label_layoutxlm)
|
183 |
|
184 |
|
185 |
# General
|
@@ -519,14 +516,10 @@ def extraction_data_from_image(images):
|
|
519 |
from datasets import Dataset
|
520 |
dataset = Dataset.from_dict({"images_ids": images_ids_list, "images": images_list, "images_pixels": images_pixels_list, "page_no": page_no_list, "num_pages": num_pages_list, "texts_line": texts_lines_list, "texts_par": texts_pars_list, "texts_lines_par": texts_lines_par_list, "bboxes_par": par_boxes_list, "bboxes_lines_par":lines_par_boxes_list})
|
521 |
|
522 |
-
|
523 |
# print(f"The text data was successfully extracted by the OCR!")
|
524 |
|
525 |
return dataset, texts_lines, texts_pars, texts_lines_par, row_indexes, par_boxes, line_boxes, lines_par_boxes
|
526 |
|
527 |
-
|
528 |
-
# Inference
|
529 |
-
|
530 |
def prepare_inference_features_paragraph(example, tokenizer, max_length, cls_box, sep_box):
|
531 |
|
532 |
images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list, images_pixels_list = list(), list(), list(), list(), list(), list()
|
@@ -711,8 +704,8 @@ def predictions_token_level(images, custom_encoded_dataset, model_id, model):
|
|
711 |
|
712 |
from functools import reduce
|
713 |
|
714 |
-
# Get predictions (
|
715 |
-
def
|
716 |
|
717 |
ten_probs_dict, ten_input_ids_dict, ten_bboxes_dict = dict(), dict(), dict()
|
718 |
bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = dict(), dict(), dict(), dict()
|
@@ -788,24 +781,69 @@ def predictions_paragraph_level(max_length, tokenizer, id2label, dataset, output
|
|
788 |
prob_label = reduce(lambda x, y: x*y, probs_list)
|
789 |
prob_label = prob_label**(1./(len(probs_list))) # normalization
|
790 |
probs_label.append(prob_label)
|
791 |
-
max_value = max(probs_label)
|
792 |
-
max_index = probs_label.index(max_value)
|
793 |
-
probs_bbox[str(bbox)] = max_index
|
|
|
794 |
|
795 |
bboxes_list_dict[image_id] = bboxes_list
|
796 |
input_ids_dict_dict[image_id] = input_ids_dict
|
797 |
probs_dict_dict[image_id] = probs_bbox
|
798 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
799 |
df[image_id] = pd.DataFrame()
|
800 |
-
df[image_id]["bboxes"] =
|
801 |
-
df[image_id]["texts"] = [
|
802 |
-
df[image_id]["labels"] = [id2label[probs_bbox[str(bbox)]] for bbox in
|
803 |
|
804 |
-
return
|
805 |
|
806 |
else:
|
807 |
print("An error occurred while getting predictions!")
|
808 |
|
|
|
809 |
# Get labeled images with lines bounding boxes
|
810 |
def get_labeled_images(id2label, dataset, images_ids_list, bboxes_list_dict, probs_dict_dict):
|
811 |
|
@@ -925,4 +963,5 @@ def display_chunk_lines_inference(dataset, encoded_dataset, index_chunk=None):
|
|
925 |
print("\n>> Dataframe of annotated lines\n")
|
926 |
cols = ["texts", "bboxes"]
|
927 |
df = df[cols]
|
928 |
-
display(df)
|
|
|
|
70 |
|
71 |
# bounding boxes start and end of a sequence
|
72 |
cls_box = [0, 0, 0, 0]
|
73 |
+
cls_box1, cls_box2 = cls_box, cls_box
|
74 |
+
|
75 |
sep_box_lilt = cls_box
|
76 |
+
sep_box1 = sep_box_lilt
|
77 |
+
|
78 |
sep_box_layoutxlm = [1000, 1000, 1000, 1000]
|
79 |
+
sep_box2 = sep_box_layoutxlm
|
80 |
|
81 |
# models
|
82 |
model_id_lilt = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-paragraphlevel-ml512"
|
83 |
+
model_id1 = model_id_lilt
|
84 |
model_id_layoutxlm = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-paragraphlevel-ml512"
|
85 |
+
model_id2 = model_id_layoutxlm
|
86 |
|
87 |
# tokenizer for LayoutXLM
|
88 |
tokenizer_id_layoutxlm = "xlm-roberta-base"
|
89 |
|
90 |
# (tokenization) The maximum length of a feature (sequence)
|
91 |
+
if (str(384) in model_id_lilt) and (str(384) in model_id_layoutxlm):
|
92 |
+
max_length = 384
|
93 |
+
elif (str(512) in model_id_lilt) and (str(512) in model_id_layoutxlm):
|
94 |
+
max_length = 512
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
else:
|
96 |
+
print("Error with max_length of chunks!")
|
97 |
|
98 |
# (tokenization) overlap
|
99 |
doc_stride = 128 # The authorized overlap between two part of the context when splitting it is needed.
|
100 |
|
101 |
# max PDF page images that will be displayed
|
102 |
+
max_imgboxes = 2
|
103 |
|
104 |
# get files
|
105 |
examples_dir = 'files/'
|
|
|
159 |
model_lilt = AutoModelForTokenClassification.from_pretrained(model_id_lilt);
|
160 |
model_lilt.to(device);
|
161 |
|
162 |
+
tokenizer1 = tokenizer_lilt
|
163 |
+
model1 = model_lilt
|
164 |
+
|
165 |
## model LayoutXLM
|
166 |
from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast,
|
167 |
model_layoutxlm = LayoutLMv2ForTokenClassification.from_pretrained(model_id_layoutxlm);
|
|
|
175 |
from transformers import AutoTokenizer
|
176 |
tokenizer_layoutxlm = AutoTokenizer.from_pretrained(tokenizer_id_layoutxlm)
|
177 |
|
178 |
+
tokenizer2 = tokenizer_layoutxlm
|
179 |
+
model2 = model_layoutxlm
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
|
182 |
# General
|
|
|
516 |
from datasets import Dataset
|
517 |
dataset = Dataset.from_dict({"images_ids": images_ids_list, "images": images_list, "images_pixels": images_pixels_list, "page_no": page_no_list, "num_pages": num_pages_list, "texts_line": texts_lines_list, "texts_par": texts_pars_list, "texts_lines_par": texts_lines_par_list, "bboxes_par": par_boxes_list, "bboxes_lines_par":lines_par_boxes_list})
|
518 |
|
|
|
519 |
# print(f"The text data was successfully extracted by the OCR!")
|
520 |
|
521 |
return dataset, texts_lines, texts_pars, texts_lines_par, row_indexes, par_boxes, line_boxes, lines_par_boxes
|
522 |
|
|
|
|
|
|
|
523 |
def prepare_inference_features_paragraph(example, tokenizer, max_length, cls_box, sep_box):
|
524 |
|
525 |
images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list, images_pixels_list = list(), list(), list(), list(), list(), list()
|
|
|
704 |
|
705 |
from functools import reduce
|
706 |
|
707 |
+
# Get predictions (paragraph level)
|
708 |
+
def predictions_probs_paragraph_level(max_length, tokenizer, id2label, dataset, outputs, images_ids_list, chunk_ids, input_ids, bboxes, cls_box, sep_box):
|
709 |
|
710 |
ten_probs_dict, ten_input_ids_dict, ten_bboxes_dict = dict(), dict(), dict()
|
711 |
bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = dict(), dict(), dict(), dict()
|
|
|
781 |
prob_label = reduce(lambda x, y: x*y, probs_list)
|
782 |
prob_label = prob_label**(1./(len(probs_list))) # normalization
|
783 |
probs_label.append(prob_label)
|
784 |
+
# max_value = max(probs_label)
|
785 |
+
# max_index = probs_label.index(max_value)
|
786 |
+
# probs_bbox[str(bbox)] = max_index
|
787 |
+
probs_bbox[str(bbox)] = probs_label
|
788 |
|
789 |
bboxes_list_dict[image_id] = bboxes_list
|
790 |
input_ids_dict_dict[image_id] = input_ids_dict
|
791 |
probs_dict_dict[image_id] = probs_bbox
|
792 |
|
793 |
+
# df[image_id] = pd.DataFrame()
|
794 |
+
# df[image_id]["bboxes"] = bboxes_list
|
795 |
+
# df[image_id]["texts"] = [tokenizer.decode(input_ids_dict[str(bbox)]) for bbox in bboxes_list]
|
796 |
+
# df[image_id]["labels"] = [id2label[probs_bbox[str(bbox)]] for bbox in bboxes_list]
|
797 |
+
|
798 |
+
return probs_bbox, bboxes_list_dict, input_ids_dict_dict, probs_dict_dict #, df
|
799 |
+
|
800 |
+
else:
|
801 |
+
print("An error occurred while getting predictions!")
|
802 |
+
|
803 |
+
from functools import reduce
|
804 |
+
|
805 |
+
# Get predictions (paragraph level)
|
806 |
+
def predictions_paragraph_level(max_length, tokenizer1, id2label, dataset, outputs1, images_ids_list1, chunk_ids1, input_ids1, bboxes1, cls_box1, sep_box1, tokenizer2, outputs2, images_ids_list2, chunk_ids2, input_ids2, bboxes2, cls_box2, sep_box2):
|
807 |
+
|
808 |
+
bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = dict(), dict(), dict(), dict()
|
809 |
+
|
810 |
+
probs_bbox1, bboxes_list_dict1, input_ids_dict_dict1, probs_dict_dict1 = predictions_probs_paragraph_level(max_length, tokenizer1, id2label, dataset, outputs1, images_ids_list1, chunk_ids1, input_ids1, bboxes1, cls_box1, sep_box1)
|
811 |
+
probs_bbox2, bboxes_list_dict2, input_ids_dict_dict2, probs_dict_dict2 = predictions_probs_paragraph_level(max_length, tokenizer2, id2label, dataset, outputs2, images_ids_list2, chunk_ids2, input_ids2, bboxes2, cls_box2, sep_box2)
|
812 |
+
|
813 |
+
if len(images_ids_list1) > 0:
|
814 |
+
|
815 |
+
for i, image_id in enumerate(images_ids_list1):
|
816 |
+
|
817 |
+
bboxes_list1 = bboxes_list_dict1[image_id]
|
818 |
+
input_ids_dict1 = input_ids_dict_dict1[image_id]
|
819 |
+
probs_bbox1 = probs_dict_dict1[image_id]
|
820 |
+
|
821 |
+
bboxes_list2 = bboxes_list_dict2[image_id]
|
822 |
+
input_ids_dict2 = input_ids_dict_dict2[image_id]
|
823 |
+
probs_bbox2 = probs_dict_dict2[image_id]
|
824 |
+
|
825 |
+
probs_bbox = dict()
|
826 |
+
for bbox in bboxes_list1:
|
827 |
+
prob_bbox = [(p1+p2)/2 for p1,p2 in zip(probs_bbox1[str(bbox)], probs_bbox2[str(bbox)])]
|
828 |
+
max_value = max(prob_bbox)
|
829 |
+
max_index = prob_bbox.index(max_value)
|
830 |
+
probs_bbox[str(bbox)] = max_index
|
831 |
+
|
832 |
+
bboxes_list_dict[image_id] = bboxes_list1
|
833 |
+
input_ids_dict_dict[image_id] = input_ids_dict1
|
834 |
+
probs_dict_dict[image_id] = probs_bbox
|
835 |
+
|
836 |
df[image_id] = pd.DataFrame()
|
837 |
+
df[image_id]["bboxes"] = bboxes_list1
|
838 |
+
df[image_id]["texts"] = [tokenizer1.decode(input_ids_dict1[str(bbox)]) for bbox in bboxes_list1]
|
839 |
+
df[image_id]["labels"] = [id2label[probs_bbox[str(bbox)]] for bbox in bboxes_list1]
|
840 |
|
841 |
+
return bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df
|
842 |
|
843 |
else:
|
844 |
print("An error occurred while getting predictions!")
|
845 |
|
846 |
+
|
847 |
# Get labeled images with lines bounding boxes
|
848 |
def get_labeled_images(id2label, dataset, images_ids_list, bboxes_list_dict, probs_dict_dict):
|
849 |
|
|
|
963 |
print("\n>> Dataframe of annotated lines\n")
|
964 |
cols = ["texts", "bboxes"]
|
965 |
df = df[cols]
|
966 |
+
display(df)
|
967 |
+
|