File size: 1,665 Bytes
5b07194
0320eb2
 
bd746ed
a9eefb1
5b07194
 
 
 
b874912
5b07194
 
bd746ed
5b07194
 
0320eb2
 
 
 
 
 
 
bd746ed
0320eb2
 
 
bd746ed
0320eb2
 
bd746ed
0320eb2
 
 
bd746ed
5b07194
bd746ed
5b07194
0320eb2
5b07194
 
 
bd746ed
 
5b07194
bd746ed
5b07194
 
bd746ed
 
 
 
b443de4
03231a0
0320eb2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from PIL import Image
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
import gradio as gr
import torch

# Fine-tuned LayoutLMv3 receipt checkpoint. The token-classification model and
# its matching processor (image preprocessing + tokenizer) are both loaded from
# the same repo id, model first, then processor.
model_path = "quadranttechnologies/Receipt_Image_Analyzer"
model, processor = (
    LayoutLMv3ForTokenClassification.from_pretrained(model_path),
    LayoutLMv3Processor.from_pretrained(model_path),
)

# Class index -> label name for the token-classification head.
# NOTE(review): the order must match the label order the checkpoint was
# fine-tuned with; confirm against the model's config.
id2label = dict(enumerate(("company", "date", "address", "total", "other")))

# Define prediction function
def predict_receipt(image):
    """Run the receipt model on an uploaded image and extract labelled text.

    The processor is called with the image only. NOTE(review): this relies on
    the saved processor having OCR enabled, so it supplies the words and boxes
    itself; confirm against the checkpoint's processor config. The model then
    predicts one label per token. Tokens are grouped by predicted label (ignoring
    padding, special tokens and the "other" class) and decoded back into text.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        dict mapping each predicted label (except ``"other"``) to the decoded
        text of all tokens assigned that label, or ``{"error": message}`` if
        no image was given or any step fails.
    """
    if image is None:
        return {"error": "No image provided."}

    # Broad catch is deliberate: this is the UI boundary, and the error text is
    # shown to the user as JSON instead of crashing the Gradio worker.
    try:
        # LayoutLMv3 expects 3-channel input; uploads may be RGBA/palette/grayscale.
        image = image.convert("RGB")
        encoding = processor(
            image,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512,
        )

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(
                input_ids=encoding["input_ids"],
                attention_mask=encoding["attention_mask"],
                bbox=encoding["bbox"],
                pixel_values=encoding["pixel_values"],
            )

        # Batch of one: index row 0 instead of squeeze(), which would turn a
        # single-token sequence into a scalar and break iteration.
        predictions = outputs.logits.argmax(-1)[0].tolist()
        token_ids = encoding["input_ids"][0].tolist()
        attention_mask = encoding["attention_mask"][0].tolist()
        special_ids = set(processor.tokenizer.all_special_ids)

        # Group token ids per label. Padding (mask == 0) and special tokens carry
        # no receipt content, so their predictions are ignored.
        tokens_by_label = {}
        for token_id, mask, pred in zip(token_ids, attention_mask, predictions):
            if not mask or token_id in special_ids:
                continue
            label = id2label[pred]
            if label == "other":
                continue
            tokens_by_label.setdefault(label, []).append(token_id)

        return {
            label: processor.tokenizer.decode(ids, skip_special_tokens=True).strip()
            for label, ids in tokens_by_label.items()
        }
    except Exception as e:
        return {"error": str(e)}

# Gradio UI: one PIL image in, the prediction dict rendered as JSON out.
interface = gr.Interface(
    title="Receipt Information Analyzer",
    description="Upload a scanned receipt image to extract information like company name, date, address, and total.",
    fn=predict_receipt,
    inputs=gr.Image(type="pil"),
    outputs=gr.JSON(),
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()