Dileep7729 commited on
Commit
7e5af81
·
verified ·
1 Parent(s): 235c3b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -7
app.py CHANGED
@@ -1,26 +1,40 @@
1
  import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 
3
 
4
  # Load your fine-tuned model and tokenizer
5
  model_name = "quadranttechnologies/Receipt_Image_Analyzer"
6
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
 
9
- # Define a prediction function
10
- def analyze_receipt(receipt_text):
11
- inputs = tokenizer(receipt_text, return_tensors="pt", truncation=True, padding=True)
 
 
 
 
 
 
12
  outputs = model(**inputs)
13
  logits = outputs.logits
14
  predicted_class = logits.argmax(-1).item()
15
- return f"Predicted Class: {predicted_class}"
 
 
 
 
 
 
16
 
17
  # Create a Gradio interface
18
  interface = gr.Interface(
19
  fn=analyze_receipt,
20
- inputs="text",
21
- outputs="text",
22
  title="Receipt Image Analyzer",
23
- description="Analyze receipts for relevant information using a fine-tuned LLM model.",
24
  )
25
 
26
  # Launch the Gradio app
 
1
  import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ from PIL import Image
4
+ import pytesseract # Install using `pip install pytesseract` and ensure Tesseract is installed
5
 
6
  # Load your fine-tuned model and tokenizer
7
  model_name = "quadranttechnologies/Receipt_Image_Analyzer"
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
 
11
+ # Define a function to preprocess the image and predict
12
+ def analyze_receipt(image):
13
+ # Perform OCR to extract text from the image
14
+ extracted_text = pytesseract.image_to_string(image)
15
+
16
+ # Tokenize the extracted text
17
+ inputs = tokenizer(extracted_text, return_tensors="pt", truncation=True, padding=True)
18
+
19
+ # Get model predictions
20
  outputs = model(**inputs)
21
  logits = outputs.logits
22
  predicted_class = logits.argmax(-1).item()
23
+
24
+ # Optionally return extracted text and prediction as JSON
25
+ result = {
26
+ "extracted_text": extracted_text,
27
+ "predicted_class": predicted_class
28
+ }
29
+ return result
30
 
31
  # Create a Gradio interface
32
  interface = gr.Interface(
33
  fn=analyze_receipt,
34
+ inputs=gr.inputs.Image(type="pil"), # Accept image input
35
+ outputs="json", # Return JSON output
36
  title="Receipt Image Analyzer",
37
+ description="Upload a receipt image to analyze and classify its contents.",
38
  )
39
 
40
  # Launch the Gradio app