ikraamkb committed on
Commit
39b3aed
·
verified ·
1 Parent(s): 0c9548a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -11,11 +11,12 @@ from transformers import pipeline
11
  import gradio as gr
12
  from fastapi.responses import RedirectResponse
13
  import numpy as np
 
14
  # Initialize FastAPI
15
  app = FastAPI()
16
 
17
- # Load AI Model for Question Answering
18
- qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-large", tokenizer="google/flan-t5-large", use_fast=True)
19
 
20
  # Load Pretrained Object Detection Model (Torchvision)
21
  model = fasterrcnn_resnet50_fpn(pretrained=True)
@@ -60,7 +61,7 @@ def extract_text_from_excel(excel_file):
60
  text.append(" ".join(map(str, row)))
61
  return "\n".join(text)
62
 
63
- # Function to perform object detection using Torchvision
64
  def extract_text_from_image(image_file):
65
  if isinstance(image_file, np.ndarray): # Check if input is a NumPy array
66
  image = Image.fromarray(image_file) # Convert NumPy array to PIL image
@@ -70,7 +71,8 @@ def extract_text_from_image(image_file):
70
  reader = easyocr.Reader(["en"])
71
  result = reader.readtext(np.array(image)) # Convert PIL image back to NumPy array
72
  return " ".join([res[1] for res in result])
73
- # Function to answer questions based on document content
 
74
  def answer_question_from_document(file, question):
75
  file_ext = file.name.split(".")[-1].lower()
76
 
@@ -89,10 +91,10 @@ def answer_question_from_document(file, question):
89
  return "No text extracted from the document."
90
 
91
  truncated_text = truncate_text(text)
92
- input_text = f"Question: {question} Context: {truncated_text}"
93
- response = qa_pipeline(input_text)
94
 
95
- return response[0]["generated_text"]
96
 
97
  # Function to answer questions based on image content
98
  def answer_question_from_image(image, question):
@@ -101,10 +103,10 @@ def answer_question_from_image(image, question):
101
  return "No meaningful content detected in the image."
102
 
103
  truncated_text = truncate_text(image_text)
104
- input_text = f"Question: {question} Context: {truncated_text}"
105
- response = qa_pipeline(input_text)
106
 
107
- return response[0]["generated_text"]
108
 
109
  # Gradio UI for Document & Image QA
110
  doc_interface = gr.Interface(
 
11
  import gradio as gr
12
  from fastapi.responses import RedirectResponse
13
  import numpy as np
14
+
15
  # Initialize FastAPI
16
  app = FastAPI()
17
 
18
+ # Load AI Model for Question Answering (Summarization-based approach)
19
+ qa_pipeline = pipeline("summarization", model="facebook/bart-large-cnn", tokenizer="facebook/bart-large-cnn")
20
 
21
  # Load Pretrained Object Detection Model (Torchvision)
22
  model = fasterrcnn_resnet50_fpn(pretrained=True)
 
61
  text.append(" ".join(map(str, row)))
62
  return "\n".join(text)
63
 
64
+ # Function to extract text from image
65
  def extract_text_from_image(image_file):
66
  if isinstance(image_file, np.ndarray): # Check if input is a NumPy array
67
  image = Image.fromarray(image_file) # Convert NumPy array to PIL image
 
71
  reader = easyocr.Reader(["en"])
72
  result = reader.readtext(np.array(image)) # Convert PIL image back to NumPy array
73
  return " ".join([res[1] for res in result])
74
+
75
+ # Function to answer questions based on document content using BART summarization
76
  def answer_question_from_document(file, question):
77
  file_ext = file.name.split(".")[-1].lower()
78
 
 
91
  return "No text extracted from the document."
92
 
93
  truncated_text = truncate_text(text)
94
+ input_text = f"Context: {truncated_text} Question: {question}"
95
+ response = qa_pipeline(input_text, max_length=100, min_length=30, do_sample=False)
96
 
97
+ return response[0]["summary_text"]
98
 
99
  # Function to answer questions based on image content
100
  def answer_question_from_image(image, question):
 
103
  return "No meaningful content detected in the image."
104
 
105
  truncated_text = truncate_text(image_text)
106
+ input_text = f"Context: {truncated_text} Question: {question}"
107
+ response = qa_pipeline(input_text, max_length=100, min_length=30, do_sample=False)
108
 
109
+ return response[0]["summary_text"]
110
 
111
  # Gradio UI for Document & Image QA
112
  doc_interface = gr.Interface(