import torch
import numpy as np
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image, ImageDraw, ImageFont
from utils import OCR, unnormalize_box
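# OCR and unnormalize_box are project-local helpers: OCR(image) returns the word-level
# bounding boxes and words, and unnormalize_box maps the 0-1000 normalized LayoutLM boxes
# back to pixel coordinates.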
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-base-cord", apply_ocr=False)
processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-base-cord", apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-base-cord")
id2label = model.config.id2label
label2id = model.config.label2id
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
def prediction(image):
    # Run OCR to get the words and their normalized bounding boxes, then encode them for LayoutLMv3
    boxes, words = OCR(image)
    encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    offset_mapping = encoding.pop('offset_mapping')
    for k, v in encoding.items():
        encoding[k] = v.to(device)

    # Forward pass; no gradients are needed at inference time
    with torch.no_grad():
        outputs = model(**encoding)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()
    probabilities = torch.softmax(outputs.logits, dim=-1)
    confidence_scores = probabilities.max(-1).values.squeeze().tolist()
    inp_ids = encoding.input_ids.squeeze().tolist()
    inp_words = [tokenizer.decode(i) for i in inp_ids]
    width, height = image.size

    # Tokens whose character offset does not start at 0 are subword continuations
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0

    # Keep only the first token of each word; convert its box back to pixel coordinates
    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
    true_confidence_scores = [conf for idx, conf in enumerate(confidence_scores) if not is_subword[idx]]
    # Merge subword pieces back into full words
    true_words = []
    for idx, word in enumerate(inp_words):
        if not is_subword[idx]:
            true_words.append(word)
        else:
            true_words[-1] = true_words[-1] + word

    # Drop the special tokens (<s> and </s>) at the start and end of the sequence
    true_predictions = true_predictions[1:-1]
    true_boxes = true_boxes[1:-1]
    true_words = true_words[1:-1]
    true_confidence_scores = true_confidence_scores[1:-1]
    # Discard low-confidence predictions by relabelling them as "O"
    for i, conf in enumerate(true_confidence_scores):
        if conf < 0.5:
            true_predictions[i] = "O"
    # Group the predicted words by entity type, stripping the "B-"/"I-" prefix from the labels
    d = {}
    for idx, label in enumerate(true_predictions):
        if label != "O":
            label = label[2:]
        if label not in d:
            d[label] = true_words[idx]
        else:
            d[label] = d[label] + ", " + true_words[idx]
    d = {k: v.strip() for (k, v) in d.items()}
    if "O" in d:
        d.pop("O")
    # TODO: process the json
    # Annotate the image with the predicted boxes, labels and confidence scores
    draw = ImageDraw.Draw(image, "RGBA")
    font = ImageFont.load_default()
    for label, box, confidence in zip(true_predictions, true_boxes, true_confidence_scores):
        draw.rectangle(box, outline="red")
        draw.text((box[0] + 10, box[1] - 10), text=f"{label}, {confidence}", font=font, fill="black")
    return d, image
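

# Example usage: a minimal sketch; "receipt.png" and "annotated.png" are placeholder
# file names, not part of the original Space.
if __name__ == "__main__":
    img = Image.open("receipt.png").convert("RGB")
    entities, annotated = prediction(img)
    print(entities)
    annotated.save("annotated.png")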