from PIL import Image
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
import gradio as gr
import torch
# Load the fine-tuned model and processor
model_path = "quadranttechnologies/Receipt_Image_Analyzer"
model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
processor = LayoutLMv3Processor.from_pretrained(model_path)
# Define label mapping
id2label = {0: "company", 1: "date", 2: "address", 3: "total", 4: "other"}
# Define prediction function
def predict_receipt(image):
    try:
        # Preprocess the image (the processor runs OCR and tokenizes the recognized words)
        encoding = processor(image, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        bbox = encoding["bbox"]
        pixel_values = encoding["pixel_values"]
        # Get model predictions without tracking gradients
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, bbox=bbox, pixel_values=pixel_values)
        predictions = outputs.logits.argmax(-1).squeeze().tolist()
        # Map each non-"other" predicted label to the index of its (last) token
        labeled_output = {id2label[pred]: idx for idx, pred in enumerate(predictions) if pred != 4}
        return labeled_output
    except Exception as e:
        return {"error": str(e)}
# Create Gradio Interface
interface = gr.Interface(
    fn=predict_receipt,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Receipt Information Analyzer",
    description="Upload a scanned receipt image to extract information like company name, date, address, and total."
)
# Launch the interface
if __name__ == "__main__":
    interface.launch()