ikraamkb commited on
Commit
df3d859
·
verified ·
1 Parent(s): 15de350

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -31
app.py CHANGED
@@ -3,29 +3,16 @@ import pdfplumber
3
  import docx
4
  import openpyxl
5
  from pptx import Presentation
6
- import torch
7
- from torchvision import transforms
8
- from torchvision.models.detection import fasterrcnn_resnet50_fpn
9
- from PIL import Image
10
  from transformers import pipeline
11
  import gradio as gr
12
  from fastapi.responses import RedirectResponse
13
- import numpy as np
14
 
15
  # Initialize FastAPI
16
  app = FastAPI()
17
 
18
- # Load AI Model for Question Answering (Summarization-based approach)
19
- qa_pipeline = pipeline("summarization", model="facebook/bart-large-cnn", tokenizer="facebook/bart-large-cnn")
20
-
21
- # Load Pretrained Object Detection Model (Torchvision)
22
- model = fasterrcnn_resnet50_fpn(pretrained=True)
23
- model.eval()
24
-
25
- # Image Transformations
26
- transform = transforms.Compose([
27
- transforms.ToTensor()
28
- ])
29
 
30
  # Function to truncate text to 450 tokens
31
  def truncate_text(text, max_tokens=450):
@@ -61,18 +48,12 @@ def extract_text_from_excel(excel_file):
61
  text.append(" ".join(map(str, row)))
62
  return "\n".join(text)
63
 
64
- # Function to extract text from image
65
  def extract_text_from_image(image_file):
66
- if isinstance(image_file, np.ndarray): # Check if input is a NumPy array
67
- image = Image.fromarray(image_file) # Convert NumPy array to PIL image
68
- else:
69
- image = Image.open(image_file).convert("RGB") # Handle file input
70
-
71
  reader = easyocr.Reader(["en"])
72
- result = reader.readtext(np.array(image)) # Convert PIL image back to NumPy array
73
  return " ".join([res[1] for res in result])
74
 
75
- # Function to answer questions based on document content using BART summarization
76
  def answer_question_from_document(file, question):
77
  file_ext = file.name.split(".")[-1].lower()
78
 
@@ -91,22 +72,22 @@ def answer_question_from_document(file, question):
91
  return "No text extracted from the document."
92
 
93
  truncated_text = truncate_text(text)
94
- input_text = f"Context: {truncated_text} Question: {question}"
95
- response = qa_pipeline(input_text, max_length=100, min_length=30, do_sample=False)
96
 
97
- return response[0]["summary_text"]
98
 
99
  # Function to answer questions based on image content
100
  def answer_question_from_image(image, question):
101
  image_text = extract_text_from_image(image)
102
  if not image_text:
103
- return "No meaningful content detected in the image."
104
 
105
  truncated_text = truncate_text(image_text)
106
- input_text = f"Context: {truncated_text} Question: {question}"
107
- response = qa_pipeline(input_text, max_length=100, min_length=30, do_sample=False)
108
 
109
- return response[0]["summary_text"]
110
 
111
  # Gradio UI for Document & Image QA
112
  doc_interface = gr.Interface(
 
3
  import docx
4
  import openpyxl
5
  from pptx import Presentation
6
+ import easyocr
 
 
 
7
  from transformers import pipeline
8
  import gradio as gr
9
  from fastapi.responses import RedirectResponse
 
10
 
11
  # Initialize FastAPI
12
  app = FastAPI()
13
 
14
+ # Load AI Model for Question Answering
15
+ qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-large", tokenizer="google/flan-t5-large", use_fast=True)
 
 
 
 
 
 
 
 
 
16
 
17
  # Function to truncate text to 450 tokens
18
  def truncate_text(text, max_tokens=450):
 
48
  text.append(" ".join(map(str, row)))
49
  return "\n".join(text)
50
 
 
51
  def extract_text_from_image(image_file):
 
 
 
 
 
52
  reader = easyocr.Reader(["en"])
53
+ result = reader.readtext(image_file)
54
  return " ".join([res[1] for res in result])
55
 
56
+ # Function to answer questions based on document content
57
  def answer_question_from_document(file, question):
58
  file_ext = file.name.split(".")[-1].lower()
59
 
 
72
  return "No text extracted from the document."
73
 
74
  truncated_text = truncate_text(text)
75
+ input_text = f"Question: {question} Context: {truncated_text}"
76
+ response = qa_pipeline(input_text)
77
 
78
+ return response[0]["generated_text"]
79
 
80
  # Function to answer questions based on image content
81
  def answer_question_from_image(image, question):
82
  image_text = extract_text_from_image(image)
83
  if not image_text:
84
+ return "No text detected in the image."
85
 
86
  truncated_text = truncate_text(image_text)
87
+ input_text = f"Question: {question} Context: {truncated_text}"
88
+ response = qa_pipeline(input_text)
89
 
90
+ return response[0]["generated_text"]
91
 
92
  # Gradio UI for Document & Image QA
93
  doc_interface = gr.Interface(