Dileep7729 committed on
Commit
5b07194
·
verified ·
1 Parent(s): 3bfc64a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -67
app.py CHANGED
@@ -1,83 +1,47 @@
1
  import gradio as gr
 
 
2
  import torch
3
- from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
4
- import pytesseract
5
- import os
6
 
7
- # Explicitly set the Tesseract path for Hugging Face Spaces
8
- pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
 
 
9
 
10
- # Debugging: Print Tesseract version and PATH details
11
- try:
12
- tesseract_version = pytesseract.get_tesseract_version()
13
- print("Tesseract Version:", tesseract_version)
14
- print("Tesseract Path:", pytesseract.pytesseract.tesseract_cmd)
15
- print("Environment PATH:", os.environ["PATH"])
16
- except Exception as e:
17
- print("Tesseract Debugging Error:", e)
18
 
19
- # For local development on Windows
20
- # Uncomment the line below if running locally on Windows
21
- # pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
 
 
 
 
 
22
 
23
- # Load the model and processor
24
- processor = LayoutLMv3Processor.from_pretrained("quadranttechnologies/Receipt_Image_Analyzer")
25
- model = LayoutLMv3ForTokenClassification.from_pretrained("quadranttechnologies/Receipt_Image_Analyzer")
26
- model.eval()
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- model.to(device)
29
 
30
- def process_image(image):
31
- try:
32
- # Preprocess the image using the processor
33
- encoding = processor(image, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
34
 
35
- # Move inputs to the same device as the model
36
- encoding = {key: val.to(device) for key, val in encoding.items()}
37
 
38
- # Perform inference
39
- with torch.no_grad():
40
- outputs = model(**encoding)
41
- predictions = torch.argmax(outputs.logits, dim=-1)
42
-
43
- # Extract input IDs, bounding boxes, and predicted labels
44
- words = encoding["input_ids"]
45
- bboxes = encoding["bbox"]
46
- labels = predictions.squeeze().tolist()
47
-
48
- # Format output as JSON
49
- structured_output = []
50
- for word_id, bbox, label in zip(words.squeeze().tolist(), bboxes.squeeze().tolist(), labels):
51
- # Decode the word ID to text
52
- word = processor.tokenizer.decode([word_id]).strip()
53
- if word: # Avoid adding empty words
54
- structured_output.append({
55
- "word": word,
56
- "bounding_box": bbox,
57
- "label": model.config.id2label[label] # Convert label ID to label name
58
- })
59
-
60
- return structured_output
61
-
62
- except Exception as e:
63
- # Debugging: Log any errors encountered during processing
64
- print("Error during processing:", str(e))
65
- return {"error": str(e)}
66
-
67
- # Define the Gradio interface
68
  interface = gr.Interface(
69
- fn=process_image,
70
- inputs=gr.Image(type="pil"), # Accepts image input
71
- outputs="json", # Outputs JSON structure
72
- title="Receipt_Image_Analyzer",
73
- description="Upload an image (e.g., receipt or document) to extract structured information in JSON format."
74
  )
75
 
76
- # Launch the app
77
  if __name__ == "__main__":
78
- # Debugging: Check if the app is starting correctly
79
- print("Starting Table OCR App...")
80
- interface.launch(share=True)
81
 
82
 
83
 
 
1
  import gradio as gr
2
+ from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
3
+ from PIL import Image
4
  import torch
 
 
 
5
 
6
# Checkpoint shared by the token-classification model and its processor.
model_path = "quadranttechnologies/Receipt_Image_Analyzer"
model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
processor = LayoutLMv3Processor.from_pretrained(model_path)

# Class id -> label name; the position in the tuple is the class id.
id2label = dict(enumerate(("company", "date", "address", "total", "other")))
 
 
 
 
 
 
13
 
14
+ # Define prediction function
15
def predict_receipt(image):
    """Run the LayoutLMv3 token classifier on a receipt image.

    Args:
        image: The image supplied by the Gradio input component; it is passed
            unchanged to ``processor``.

    Returns:
        dict: Maps each predicted label name (every class except id 4,
        "other") to the position of the last non-padding token that was
        assigned that label.

    Raises:
        KeyError: If the model predicts a class id missing from ``id2label``.
    """
    # Preprocess the image; the token sequence is truncated/padded to 512.
    encoding = processor(
        image,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    attention_mask = encoding["attention_mask"]

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(
            input_ids=encoding["input_ids"],
            attention_mask=attention_mask,
            bbox=encoding["bbox"],
            pixel_values=encoding["pixel_values"],
        )

    # One image per call, so row 0 holds the whole result.
    predictions = outputs.logits.argmax(-1)[0].tolist()
    is_real_token = attention_mask[0].tolist()

    # NOTE(review): values are token positions, not decoded text, and a later
    # token overwrites an earlier one carrying the same label.
    labeled_output = {}
    for idx, (pred, keep) in enumerate(zip(predictions, is_real_token)):
        # Padding positions carry meaningless predictions; class 4 is "other".
        if not keep or pred == 4:
            continue
        labeled_output[id2label[pred]] = idx

    return labeled_output
 
31
 
32
# Create Gradio Interface
interface = gr.Interface(
    fn=predict_receipt,
    # Use the top-level gr.Image component; the gr.inputs namespace is
    # deprecated and absent from current Gradio releases.
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Receipt Information Analyzer",
    description="Upload a scanned receipt image to extract information like company name, date, address, and total.",
)
40
 
41
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()
44
+
 
45
 
46
 
47