import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from utils import OCR, unnormalize_box
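
# Note: `OCR` and `unnormalize_box` are project helpers from utils.py (not shown
# here). Assumed contract, inferred from their use below: OCR(image) returns
# (boxes, words) with word bounding boxes normalized to the 0-1000 coordinate
# range that LayoutLMv3 expects, and unnormalize_box(box, width, height) maps a
# normalized box back to pixel coordinates.
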
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-sroie")

id2label = model.config.id2label
label2id = model.config.label2id

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()  # inference only: disable dropout


def blur(image, boxes):
    """Blur the regions of `image` covered by `boxes` (pixel coordinates)."""
    image = np.array(image)
    for box in boxes:
        blur_x = int(box[0])
        blur_y = int(box[1])
        blur_width = int(box[2] - box[0])
        blur_height = int(box[3] - box[1])
        # Blur only the region of interest, then write it back in place
        roi = image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width]
        blur_image = cv2.GaussianBlur(roi, (201, 201), 0)
        image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width] = blur_image
    return Image.fromarray(image, 'RGB')


def prediction(image):
    boxes, words = OCR(image)
    encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    offset_mapping = encoding.pop('offset_mapping')
    for k, v in encoding.items():
        encoding[k] = v.to(device)

    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**encoding)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()
    probabilities = torch.softmax(outputs.logits, dim=-1)
    confidence_scores = probabilities.max(-1).values.squeeze().tolist()
    inp_ids = encoding.input_ids.squeeze().tolist()
    inp_words = [tokenizer.decode(i) for i in inp_ids]

    width, height = image.size

    # A token is a subword continuation when its character offset does not start at 0
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0

    # Keep only the first token of each word
    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
    true_confidence_scores = [conf for idx, conf in enumerate(confidence_scores) if not is_subword[idx]]
    # Merge subword tokens back into whole words
    true_words = []
    for idx, word in enumerate(inp_words):
        if not is_subword[idx]:
            true_words.append(word)
        else:
            true_words[-1] = true_words[-1] + word

    # Drop the special tokens ([CLS] and [SEP]) at both ends
    true_predictions = true_predictions[1:-1]
    true_boxes = true_boxes[1:-1]
    true_words = true_words[1:-1]
    true_confidence_scores = true_confidence_scores[1:-1]
    # Optionally discard low-confidence predictions:
    # for i, conf in enumerate(true_confidence_scores):
    #     if conf < 0.8:
    #         true_predictions[i] = "O"
    # Group the recognized words by entity label
    d = {}
    for idx, label in enumerate(true_predictions):
        # Strip the positional prefix (e.g. "S-TOTAL" -> "TOTAL")
        if label != "O":
            label = label[2:]
        if label not in d:
            d[label] = true_words[idx]
        else:
            d[label] = d[label] + ", " + true_words[idx]
    d = {k: v.strip() for (k, v) in d.items()}
    if "O" in d:
        d.pop("O")
if "S-TOTAL" in d: d.pop("S-TOTAL")
    # Blur every detected entity except the total, which stays readable
    blur_boxes = []
    for label, box in zip(true_predictions, true_boxes):
        if label != 'O' and label != 'S-TOTAL':
            blur_boxes.append(box)
    image = blur(image, blur_boxes)

    # Optionally draw the predicted boxes and labels for debugging:
    # draw = ImageDraw.Draw(image, "RGBA")
    # font = ImageFont.load_default()
    # for label, box in zip(true_predictions, true_boxes):
    #     draw.rectangle(box)
    #     draw.text((box[0] + 10, box[1] - 10), text=label, font=font, fill="black")

    return d, image
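

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original Space): run the pipeline on
    # a local receipt image. "receipt.jpg" is a hypothetical path, used here for
    # illustration only.
    sample = Image.open("receipt.jpg").convert("RGB")
    entities, blurred = prediction(sample)
    print(entities)  # e.g. {"COMPANY": "...", "DATE": "...", "ADDRESS": "..."}
    blurred.save("receipt_blurred.jpg")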