Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -152,8 +152,111 @@ app = gr.mount_gradio_app(app, demo, path="/")
|
|
152 |
def home():
|
153 |
return RedirectResponse(url="/")
|
154 |
"""
|
|
|
|
|
|
|
155 |
import torch
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
def home():
|
153 |
return RedirectResponse(url="/")
|
154 |
"""
|
155 |
+
import gradio as gr
|
156 |
+
import numpy as np
|
157 |
+
import fitz # PyMuPDF
|
158 |
import torch
|
159 |
+
from fastapi import FastAPI
|
160 |
+
from transformers import pipeline
|
161 |
+
from PIL import Image
|
162 |
+
from starlette.responses import RedirectResponse
|
163 |
+
from openpyxl import load_workbook
|
164 |
+
from docx import Document
|
165 |
+
from pptx import Presentation
|
166 |
+
|
167 |
+
# ✅ Initialize FastAPI (the Gradio UI is mounted onto this app further below).
app = FastAPI()

# ✅ Execution device. Deliberately pinned to CPU: the model pipelines below
# are created with device=-1, so this name is only used for the startup log.
# NOTE(review): mojibake in the original log prefix repaired — confirm intended emoji.
device = "cpu"
print(f"✅ Running on: {device}")
+
# ✅ Lazy Load Model Function (Loads Only When Needed)
def get_qa_pipeline():
    """Build and return the text-generation pipeline used for QA.

    Constructed lazily so model weights are downloaded/loaded only on the
    first request, keeping app startup fast. ``device=-1`` pins inference
    to CPU.

    Returns:
        transformers.Pipeline: a "text-generation" pipeline over TinyLlama.
    """
    print("🔄 Loading QA Model on CPU...")
    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=-1)
|
179 |
+
def get_image_captioning_pipeline():
    """Build and return the image-to-text (captioning) pipeline.

    Lazy like :func:`get_qa_pipeline`; ``device=-1`` forces CPU.

    Returns:
        transformers.Pipeline: an "image-to-text" ViT-GPT2 captioning pipeline.
    """
    print("🔄 Loading Image Captioning Model on CPU...")
    return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning", device=-1)
|
183 |
+
# ✅ File Type Validation
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}

def validate_file_type(file):
    """Check an uploaded file's extension against ALLOWED_EXTENSIONS.

    Args:
        file: any object with a ``.name`` attribute (as provided by gr.File).

    Returns:
        None if the extension is allowed, otherwise an error message string
        suitable for direct display in the UI.
    """
    print(f"🔍 Validating file: {file.name}")
    # Compare case-insensitively; everything after the last dot is the extension.
    ext = file.name.split(".")[-1].lower()
    return None if ext in ALLOWED_EXTENSIONS else f"❌ Unsupported file format: {ext}"
|
191 |
+
# ✅ Extract Text Functions (Optimized)
def extract_text_from_pdf(file):
    """Extract and space-join the text of every page of a PDF.

    Args:
        file: object with a ``.name`` path attribute pointing at a PDF.

    Returns:
        str: concatenated page text.
    """
    print("📄 Extracting text from PDF...")
    # Context manager guarantees the PyMuPDF document handle is closed.
    with fitz.open(file.name) as doc:
        return " ".join(page.get_text() for page in doc)
+
|
197 |
+
def extract_text_from_docx(file):
    """Extract and space-join the text of every paragraph in a .docx file.

    Args:
        file: object with a ``.name`` path attribute pointing at a DOCX.

    Returns:
        str: concatenated paragraph text.
    """
    print("📄 Extracting text from DOCX...")
    doc = Document(file.name)
    return " ".join(p.text for p in doc.paragraphs)
+
|
202 |
+
def extract_text_from_pptx(file):
    """Extract and space-join the text of every text-bearing shape on every slide.

    Args:
        file: object with a ``.name`` path attribute pointing at a PPTX.

    Returns:
        str: concatenated shape text (shapes without a ``text`` attribute skipped).
    """
    print("📄 Extracting text from PPTX...")
    ppt = Presentation(file.name)
    return " ".join(shape.text for slide in ppt.slides for shape in slide.shapes if hasattr(shape, "text"))
206 |
+
|
207 |
+
def extract_text_from_excel(file):
    """Extract and space-join every non-empty cell of every worksheet in an .xlsx.

    Args:
        file: object with a ``.name`` path attribute pointing at an XLSX.

    Returns:
        str: concatenated cell values, stringified.
    """
    print("📄 Extracting text from Excel...")
    # data_only=True reads cached formula results rather than formula strings.
    wb = load_workbook(file.name, data_only=True)
    # Test 'is not None' (not truthiness) so numeric 0 and False cells are kept.
    return " ".join(
        " ".join(str(cell) for cell in row if cell is not None)
        for sheet in wb.worksheets
        for row in sheet.iter_rows(values_only=True)
    )
211 |
+
|
212 |
+
# ✅ Question Answering Function (Efficient Processing)
async def answer_question(file, question: str):
    """Answer *question* from the text of an uploaded document.

    Validates the file type, extracts text with the matching extractor,
    truncates it for fast CPU inference, and runs the QA pipeline.

    Args:
        file: uploaded file object (gr.File) with a ``.name`` path attribute.
        question: the user's question about the document.

    Returns:
        str: the model's generated answer, or a user-facing error/warning string.
    """
    print("📄 Processing file for QA...")

    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error

    # Dispatch table keeps the extension -> extractor mapping in one place.
    # Safe to index directly: validate_file_type already guaranteed membership.
    extractors = {
        "pdf": extract_text_from_pdf,
        "docx": extract_text_from_docx,
        "pptx": extract_text_from_pptx,
        "xlsx": extract_text_from_excel,
    }
    file_ext = file.name.split(".")[-1].lower()
    text = extractors[file_ext](file)

    if not text.strip():
        return "⚠️ No text extracted from the document."

    print("✂️ Truncating text for faster processing...")
    truncated_text = text[:1024]  # keep the prompt small so CPU inference stays fast

    qa_pipeline = get_qa_pipeline()
    response = qa_pipeline(f"Question: {question}\nContext: {truncated_text}")

    # text-generation pipelines return a list of dicts with "generated_text".
    return response[0]["generated_text"]
|
243 |
+
# ✅ Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 📄 AI-Powered Document & Image QA")

    with gr.Row():
        file_input = gr.File(label="Upload Document")
        question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")

    answer_output = gr.Textbox(label="Answer")
    submit_btn = gr.Button("Get Answer")

    # Gradio awaits the async answer_question handler itself.
    submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)
+
|
256 |
+
# Mount the Gradio UI onto the FastAPI app at /demo; the bare domain
# redirects there so visitors land on the interface either way.
app = gr.mount_gradio_app(app, demo, path="/demo")

@app.get("/")
def home():
    """Redirect the site root to the mounted Gradio UI."""
    return RedirectResponse(url="/demo")