File size: 2,313 Bytes
d8f642d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import sys
from subprocess import run
from typing import Dict, List, Any

import torch
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor
# Install the Tesseract OCR engine and its Python wrapper (pytesseract) at import
# time; the LayoutLMv2 processor's built-in OCR is pytesseract-backed.
# - Arguments are passed as argv lists, so no shell is involved.
# - apt-get is used because apt's command-line interface is not meant for scripts.
# - pip is run through the current interpreter (sys.executable -m pip) so that
#   pytesseract is installed into the environment that will actually import it,
#   not whichever `pip` happens to be first on PATH.
run(["apt-get", "install", "-y", "tesseract-ocr"], check=True)
run([sys.executable, "-m", "pip", "install", "pytesseract"], check=True)
def unnormalize_box(bbox, width, height):
    """Map a 0-1000 normalized box back onto the image's pixel grid (e.g. for drawing).

    Args:
        bbox: sequence whose first four items are ``x0, y0, x1, y1`` on the
            0-1000 scale; any further items are ignored.
        width: image width in pixels, applied to the x coordinates.
        height: image height in pixels, applied to the y coordinates.

    Returns:
        A list ``[x0, y0, x1, y1]`` expressed in pixels.
    """
    # x coordinates sit at even positions and scale with the width;
    # y coordinates sit at odd positions and scale with the height.
    axis_sizes = (width, height, width, height)
    return [size * (bbox[index] / 1000) for index, size in enumerate(axis_sizes)]
# Pick the compute device once at import time: CUDA when a GPU is visible,
# otherwise the CPU. The handler moves the model and each input tensor here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class EndpointHandler:
    """Inference handler running LayoutLM token classification on a document image.

    A ``LayoutLMv2Processor`` turns the image into token ids, normalized boxes,
    attention mask and token type ids. These are fed to a
    ``LayoutLMForTokenClassification`` model loaded from the same path.
    """

    def __init__(self, path=""):
        # Load model and processor from the same path; the model is moved to the
        # module-level ``device``.
        self.model = LayoutLMForTokenClassification.from_pretrained(path).to(device)
        self.processor = LayoutLMv2Processor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        """Label every token of the document and return the non-"O" ones.

        Args:
            data: request payload. The deserialized ``PIL.Image`` is taken from
                ``data["inputs"]``, and that key is popped from ``data``. If the
                key is absent, ``data`` itself is passed to the processor.

        Returns:
            ``{"predictions": [...]}``. Each entry is a dict with:
            ``label`` (predicted label name), ``score`` (softmax probability of
            that label), ``text`` (decoded token) and ``bbox``
            (``[x0, y0, x1, y1]`` in pixels of the input image).
            Special tokens and tokens labelled "O" are omitted.
        """
        image = data.pop("inputs", data)

        # truncation=True caps the sequence at the tokenizer's max length. Without
        # it, a dense page whose OCR output exceeds the model's maximum sequence
        # length cannot be run through the model and the request fails.
        encoding = self.processor(image, return_tensors="pt", truncation=True)

        with torch.inference_mode():
            outputs = self.model(
                input_ids=encoding.input_ids.to(device),
                bbox=encoding.bbox.to(device),
                attention_mask=encoding.attention_mask.to(device),
                token_type_ids=encoding.token_type_ids.to(device),
            )
        # Batch of one: drop the batch dimension and bring probabilities to CPU.
        probabilities = outputs.logits.softmax(-1).squeeze(0).cpu()
        # The processor's tensors were never moved, so they are already on the CPU.
        token_ids = encoding.input_ids.squeeze(0)
        boxes = encoding.bbox.squeeze(0)

        id2label = self.model.config.id2label
        tokenizer = self.processor.tokenizer
        # [CLS]/[SEP]/... are not document text; a label predicted for them is
        # meaningless and would otherwise surface as a bogus entity.
        special_ids = set(tokenizer.all_special_ids)
        width, height = image.width, image.height

        result = []
        for probs, token_id, box in zip(probabilities, token_ids, boxes):
            if int(token_id) in special_ids:
                continue
            # One reduction yields both the best probability and its label index.
            score, label_id = probs.max(dim=-1)
            label = id2label[int(label_id)]
            if label == "O":
                continue
            result.append(
                {
                    "label": label,
                    "score": score.item(),
                    "text": tokenizer.decode(token_id),
                    "bbox": unnormalize_box(box.tolist(), width, height),
                }
            )
        return {"predictions": result}
|