|
import os

import gradio as gr
import pytesseract
import torch
from PIL import Image
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
|
|
|
|
|
# Location of the Tesseract OCR binary used by pytesseract; the processor's
# built-in OCR (apply_ocr=True) goes through pytesseract. The TESSERACT_CMD
# environment variable overrides the default so the app also runs on systems
# where Tesseract is not installed at /usr/bin/tesseract.
pytesseract.pytesseract.tesseract_cmd = os.environ.get(
    "TESSERACT_CMD", "/usr/bin/tesseract"
)
|
|
|
|
|
|
|
|
|
# Directory containing the fine-tuned checkpoint and processor files.
# "./" is resolved relative to the current working directory.
model_path = "./"

# Token-classification head producing one logit vector per token.
model = LayoutLMv3ForTokenClassification.from_pretrained(
    model_path,
)

# Processor with OCR enabled: it extracts words and boxes from the raw image
# itself before tokenizing, so callers only pass the image.
processor = LayoutLMv3Processor.from_pretrained(
    model_path,
    apply_ocr=True,
)

# Maps a predicted class index (argmax over the logits) to a field name.
id2label = dict(enumerate(("company", "date", "address", "total", "other")))
|
|
|
|
|
def predict_receipt(image):
    """Label the text on a receipt image and return the text found per field.

    The processor OCRs the image (``apply_ocr=True``) and the model classifies
    every token. All tokens that share a label other than ``"other"`` are
    decoded back into text.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing was uploaded.

    Returns:
        A dict mapping each label name (e.g. ``"company"``, ``"total"``) to the
        decoded text of the tokens predicted with that label. On failure,
        ``{"error": <message>}`` so the JSON output panel shows the problem.
    """
    if image is None:
        return {"error": "No image provided."}

    try:
        # The processor expects 3-channel input; normalise RGBA/greyscale uploads.
        encoding = processor(
            image.convert("RGB"),
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512,
        )

        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(
                input_ids=encoding["input_ids"],
                attention_mask=encoding["attention_mask"],
                bbox=encoding["bbox"],
                pixel_values=encoding["pixel_values"],
            )

        # A single image is processed, so take row 0 of the batch. Unlike
        # squeeze(), this never collapses a length-1 sequence into a scalar.
        predictions = outputs.logits.argmax(-1)[0].tolist()
        token_ids = encoding["input_ids"][0].tolist()
        attention = encoding["attention_mask"][0].tolist()
        special_ids = set(processor.tokenizer.all_special_ids)

        # Group token ids by predicted label. Skip padding (attention 0),
        # which padding="max_length" adds, plus special tokens and "other".
        tokens_by_label = {}
        for token_id, keep, pred in zip(token_ids, attention, predictions):
            if not keep or token_id in special_ids:
                continue
            # Fall back to the checkpoint's own mapping for unknown indices
            # instead of raising KeyError.
            label = id2label.get(pred, model.config.id2label.get(pred, str(pred)))
            if label == "other":
                continue
            tokens_by_label.setdefault(label, []).append(token_id)

        return {
            label: processor.tokenizer.decode(ids).strip()
            for label, ids in tokens_by_label.items()
        }
    except Exception as e:
        # UI boundary: report the failure in the JSON panel instead of crashing.
        return {"error": str(e)}
|
|
|
|
|
# Web UI: one PIL image input; the dict returned by predict_receipt is
# rendered in a JSON output panel.
interface = gr.Interface(
    title="Receipt Information Analyzer",
    description=(
        "Upload a scanned receipt image to extract information like "
        "company name, date, address, and total."
    ),
    inputs=gr.Image(type="pil"),
    outputs="json",
    fn=predict_receipt,
)

# Launch only when executed as a script; importing this module builds the
# interface but does not start it.
if __name__ == "__main__":
    interface.launch()
|
|