ikraamkb commited on
Commit
6c0ceb9
·
verified ·
1 Parent(s): 10e2a27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -10,7 +10,7 @@ from PIL import Image
10
  from transformers import pipeline
11
  import gradio as gr
12
  from fastapi.responses import RedirectResponse
13
- import numpy as np
14
  # Initialize FastAPI
15
  app = FastAPI()
16
 
@@ -62,14 +62,19 @@ def extract_text_from_excel(excel_file):
62
 
63
  # Function to perform object detection using Torchvision
64
  def extract_text_from_image(image_file):
65
- if isinstance(image_file, np.ndarray): # Check if input is a NumPy array
66
- image = Image.fromarray(image_file) # Convert NumPy array to PIL image
67
- else:
68
- image = Image.open(image_file).convert("RGB") # Handle file input
 
 
 
 
 
 
 
 
69
 
70
- reader = easyocr.Reader(["en"])
71
- result = reader.readtext(np.array(image)) # Convert PIL image back to NumPy array
72
- return " ".join([res[1] for res in result])
73
  # Function to answer questions based on document content
74
  def answer_question_from_document(file, question):
75
  file_ext = file.name.split(".")[-1].lower()
 
10
  from transformers import pipeline
11
  import gradio as gr
12
  from fastapi.responses import RedirectResponse
13
+
14
  # Initialize FastAPI
15
  app = FastAPI()
16
 
 
62
 
63
  # Function to perform object detection using Torchvision
64
  def extract_text_from_image(image_file):
65
+ image = Image.open(image_file).convert("RGB")
66
+ image_tensor = transform(image).unsqueeze(0)
67
+
68
+ with torch.no_grad():
69
+ predictions = model(image_tensor)
70
+
71
+ detected_objects = []
72
+ for label, score in zip(predictions[0]['labels'], predictions[0]['scores']):
73
+ if score > 0.7:
74
+ detected_objects.append(f"Object {label.item()} detected with confidence {score.item():.2f}")
75
+
76
+ return "\n".join(detected_objects) if detected_objects else "No objects detected."
77
 
 
 
 
78
  # Function to answer questions based on document content
79
  def answer_question_from_document(file, question):
80
  file_ext = file.name.split(".")[-1].lower()