Update app.py
Browse files
app.py
CHANGED
@@ -104,10 +104,12 @@ def pred_resume(pdf_path) -> dict:
|
|
104 |
if text.replace(" ","") != "":
|
105 |
bboxes.append(normalize_bbox([xmin, ymin, xmax, ymax], image.size))
|
106 |
words.append(decontracted(text))
|
|
|
107 |
fake_label = ["O"] * len(words)
|
108 |
encoding = processor(image, words, boxes=bboxes, word_labels=fake_label, truncation=True, stride=256,
|
109 |
padding="max_length", max_length=512, return_overflowing_tokens=True, return_offsets_mapping=True)
|
110 |
labels = encoding["labels"]
|
|
|
111 |
offset_mapping = encoding.pop('offset_mapping')
|
112 |
overflow_to_sample_mapping = encoding.pop('overflow_to_sample_mapping')
|
113 |
encoding = {k: torch.tensor(v) for k,v in encoding.items() if k != "labels"}
|
@@ -128,12 +130,16 @@ def pred_resume(pdf_path) -> dict:
|
|
128 |
if i>0:
|
129 |
labels[i] = labels[i][256:]
|
130 |
predictions[i] = predictions[i][256:]
|
|
|
131 |
predictions = [j for i in predictions for j in i]
|
|
|
132 |
labels = [j for i in labels for j in i]
|
133 |
true_predictions = [id2label[pred] for pred, label in zip(predictions, labels) if label != -100]
|
134 |
-
for
|
|
|
135 |
if pred in key_list:
|
136 |
-
result[pred].append(
|
|
|
137 |
return str(result)
|
138 |
def norm(result: str) -> str:
|
139 |
result = ast.literal_eval(result)
|
|
|
104 |
if text.replace(" ","") != "":
|
105 |
bboxes.append(normalize_bbox([xmin, ymin, xmax, ymax], image.size))
|
106 |
words.append(decontracted(text))
|
107 |
+
text_reverse = {str(bboxes[i]): words[i] for i,_ in enumerate(words)}
|
108 |
fake_label = ["O"] * len(words)
|
109 |
encoding = processor(image, words, boxes=bboxes, word_labels=fake_label, truncation=True, stride=256,
|
110 |
padding="max_length", max_length=512, return_overflowing_tokens=True, return_offsets_mapping=True)
|
111 |
labels = encoding["labels"]
|
112 |
+
key_box = encoding["bbox"]
|
113 |
offset_mapping = encoding.pop('offset_mapping')
|
114 |
overflow_to_sample_mapping = encoding.pop('overflow_to_sample_mapping')
|
115 |
encoding = {k: torch.tensor(v) for k,v in encoding.items() if k != "labels"}
|
|
|
130 |
if i>0:
|
131 |
labels[i] = labels[i][256:]
|
132 |
predictions[i] = predictions[i][256:]
|
133 |
+
key_box[i] = key_box[i][256:]
|
134 |
predictions = [j for i in predictions for j in i]
|
135 |
+
key_box = [j for i in key_box for j in i]
|
136 |
labels = [j for i in labels for j in i]
|
137 |
true_predictions = [id2label[pred] for pred, label in zip(predictions, labels) if label != -100]
|
138 |
+
key_box = [box for box, label in zip(key_box, labels) if label != -100]
|
139 |
+
for box, pred in zip(key_box, true_predictions):
|
140 |
if pred in key_list:
|
141 |
+
result[pred].append(text_reverse[str(box)])
|
142 |
+
result = {k: list(set(v)) for k, v in result.items()}
|
143 |
return str(result)
|
144 |
def norm(result: str) -> str:
|
145 |
result = ast.literal_eval(result)
|