"""Gradio demo that extracts key fields from a scanned receipt with a fine-tuned LayoutLMv3."""

import gradio as gr
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
from PIL import Image
import torch

# Load the fine-tuned model and processor
model_path = "quadranttechnologies/Receipt_Image_Analyzer"
model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
processor = LayoutLMv3Processor.from_pretrained(model_path)

# Define label mapping.
# NOTE(review): assumed to match the label ids the model was fine-tuned with —
# confirm against model.config.id2label.
id2label = {0: "company", 1: "date", 2: "address", 3: "total", 4: "other"}

# Label whose tokens are not reported back to the user.
OTHER_LABEL = "other"


def _label_for(pred_id):
    """Return the field name for a predicted class id.

    Uses the local ``id2label`` map first. For ids missing from it, falls back
    to the model config's own mapping instead of raising ``KeyError``.
    """
    if pred_id in id2label:
        return id2label[pred_id]
    return model.config.id2label.get(pred_id, f"LABEL_{pred_id}")


# Define prediction function
def predict_receipt(image: Image.Image | None) -> dict:
    """Classify every token of a receipt image and return the text found per field.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A dict mapping each predicted field name (e.g. ``"company"``, ``"total"``)
        to the decoded text of all tokens tagged with that field, in reading
        order. Tokens tagged ``"other"``, special tokens and padding are omitted.
        An empty dict is returned when no image is given.
    """
    if image is None:
        return {}

    # Uploads may be RGBA or greyscale; the model takes 3-channel RGB images.
    image = image.convert("RGB")

    # Preprocess the image. No words/boxes are passed, so this relies on the
    # processor's built-in OCR to produce the tokens and bounding boxes.
    encoding = processor(
        image,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )

    # Get model predictions; inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            bbox=encoding["bbox"],
            pixel_values=encoding["pixel_values"],
        )

    # Take batch element 0 explicitly (squeeze() would also collapse a length-1 sequence).
    predictions = outputs.logits.argmax(-1)[0].tolist()
    token_ids = encoding["input_ids"][0].tolist()
    is_real_token = encoding["attention_mask"][0].tolist()

    # Group token ids by predicted field, ignoring padding and the "other" class.
    tokens_by_label = {}
    for token_id, pred, real in zip(token_ids, predictions, is_real_token):
        if not real:
            continue  # padding added by padding="max_length"
        label = _label_for(pred)
        if label == OTHER_LABEL:
            continue
        tokens_by_label.setdefault(label, []).append(token_id)

    # Map predictions to labels: decode each field's tokens back into text
    # (skip_special_tokens drops tokens such as CLS/SEP).
    labeled_output = {}
    for label, ids in tokens_by_label.items():
        text = processor.tokenizer.decode(ids, skip_special_tokens=True).strip()
        if text:
            labeled_output[label] = text
    return labeled_output


# Create Gradio Interface (gr.inputs.* was removed in Gradio 4; use gr.Image).
interface = gr.Interface(
    fn=predict_receipt,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Receipt Information Analyzer",
    description="Upload a scanned receipt image to extract information like company name, date, address, and total.",
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()