Dileep7729's picture
Update app.py
545ca3f verified
raw
history blame
1.92 kB
from PIL import Image
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
import gradio as gr
import torch
import pytesseract
import os
import os
# Point pytesseract at the Linux Tesseract binary used for OCR.
pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"

# Report the version of the *configured* binary rather than whatever
# "tesseract" happens to resolve to on PATH. This is a diagnostic only, so a
# missing or broken binary is reported instead of aborting startup.
try:
    print("Tesseract version:", pytesseract.get_tesseract_version())
except Exception as err:  # best-effort startup diagnostic
    print("Tesseract version: unavailable", f"({err})")

# Load the fine-tuned model and its processor from files in the working
# directory. apply_ocr=True lets the processor extract words/boxes from the
# image itself, so predict_receipt only has to pass the image.
model_path = "./"  # Path to the directory containing the uploaded model files
model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
processor = LayoutLMv3Processor.from_pretrained(model_path, apply_ocr=True)

# Classifier-head index -> label name. Index 4 ("other") is the catch-all
# class that predict_receipt filters out of its result.
id2label = {0: "company", 1: "date", 2: "address", 3: "total", 4: "other"}
# Define prediction function
def predict_receipt(image):
try:
# Preprocess the image
encoding = processor(image, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
input_ids = encoding["input_ids"]
attention_mask = encoding["attention_mask"]
bbox = encoding["bbox"]
pixel_values = encoding["pixel_values"]
# Get model predictions
outputs = model(input_ids=input_ids, attention_mask=attention_mask, bbox=bbox, pixel_values=pixel_values)
predictions = outputs.logits.argmax(-1).squeeze().tolist()
# Map predictions to labels
labeled_output = {id2label[pred]: idx for idx, pred in enumerate(predictions) if pred != 4}
return labeled_output
except Exception as e:
return {"error": str(e)}
# Wire the predictor into a minimal Gradio UI: a PIL image goes in and the
# extracted fields come back rendered as JSON.
interface = gr.Interface(
    predict_receipt,
    gr.Image(type="pil"),
    "json",
    title="Receipt Information Analyzer",
    description="Upload a scanned receipt image to extract information like company name, date, address, and total.",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()