Dileep7729 commited on
Commit
0320eb2
·
verified ·
1 Parent(s): 03231a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -1,6 +1,6 @@
1
- import gradio as gr
2
- from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
3
  from PIL import Image
 
 
4
  import torch
5
 
6
  # Load the fine-tuned model and processor
@@ -13,26 +13,29 @@ id2label = {0: "company", 1: "date", 2: "address", 3: "total", 4: "other"}
13
 
14
  # Define prediction function
15
  def predict_receipt(image):
16
- # Preprocess the image
17
- encoding = processor(image, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
18
- input_ids = encoding["input_ids"]
19
- attention_mask = encoding["attention_mask"]
20
- bbox = encoding["bbox"]
21
- pixel_values = encoding["pixel_values"]
 
22
 
23
- # Get model predictions
24
- outputs = model(input_ids=input_ids, attention_mask=attention_mask, bbox=bbox, pixel_values=pixel_values)
25
- predictions = outputs.logits.argmax(-1).squeeze().tolist()
26
 
27
- # Map predictions to labels
28
- labeled_output = {id2label[pred]: idx for idx, pred in enumerate(predictions) if pred != 4}
29
 
30
- return labeled_output
 
 
31
 
32
  # Create Gradio Interface
33
  interface = gr.Interface(
34
  fn=predict_receipt,
35
- inputs=gr.Image(type="pil"), # Use `gr.Image` instead of `gr.inputs.Image`
36
  outputs="json",
37
  title="Receipt Information Analyzer",
38
  description="Upload a scanned receipt image to extract information like company name, date, address, and total."
@@ -48,3 +51,4 @@ if __name__ == "__main__":
48
 
49
 
50
 
 
 
 
 
1
  from PIL import Image
2
+ from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
3
+ import gradio as gr
4
  import torch
5
 
6
  # Load the fine-tuned model and processor
 
13
 
14
  # Define prediction function
15
  def predict_receipt(image):
16
+ try:
17
+ # Preprocess the image
18
+ encoding = processor(image, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
19
+ input_ids = encoding["input_ids"]
20
+ attention_mask = encoding["attention_mask"]
21
+ bbox = encoding["bbox"]
22
+ pixel_values = encoding["pixel_values"]
23
 
24
+ # Get model predictions
25
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, bbox=bbox, pixel_values=pixel_values)
26
+ predictions = outputs.logits.argmax(-1).squeeze().tolist()
27
 
28
+ # Map predictions to labels
29
+ labeled_output = {id2label[pred]: idx for idx, pred in enumerate(predictions) if pred != 4}
30
 
31
+ return labeled_output
32
+ except Exception as e:
33
+ return {"error": str(e)}
34
 
35
  # Create Gradio Interface
36
  interface = gr.Interface(
37
  fn=predict_receipt,
38
+ inputs=gr.Image(type="pil"),
39
  outputs="json",
40
  title="Receipt Information Analyzer",
41
  description="Upload a scanned receipt image to extract information like company name, date, address, and total."
 
51
 
52
 
53
 
54
+