ikraamkb committed
Commit a768964 · verified · 1 parent: e2fade1

Update app.py

Files changed (1):
  1. app.py +107 -4
app.py CHANGED
@@ -152,8 +152,111 @@ app = gr.mount_gradio_app(app, demo, path="/")
 def home():
     return RedirectResponse(url="/")
 """
+import gradio as gr
+import numpy as np
+import fitz  # PyMuPDF
 import torch
-print("CUDA Available:", torch.cuda.is_available())
-print("Torch Device Count:", torch.cuda.device_count())
-print("Current Device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
-print("CUDA Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")
+from fastapi import FastAPI
+from transformers import pipeline
+from PIL import Image
+from starlette.responses import RedirectResponse
+from openpyxl import load_workbook
+from docx import Document
+from pptx import Presentation
+
+# ✅ Initialize FastAPI
+app = FastAPI()
+
+# ✅ Check if CUDA is Available (For Debugging)
+device = "cpu"
+print(f"✅ Running on: {device}")
+
+# ✅ Lazy Load Model Function (Loads Only When Needed)
+def get_qa_pipeline():
+    print("🔄 Loading QA Model on CPU...")
+    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=-1)
+
+def get_image_captioning_pipeline():
+    print("🔄 Loading Image Captioning Model on CPU...")
+    return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning", device=-1)
+
+# ✅ File Type Validation
+ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
+
+def validate_file_type(file):
+    print(f"📂 Validating file: {file.name}")
+    ext = file.name.split(".")[-1].lower()
+    return None if ext in ALLOWED_EXTENSIONS else f"❌ Unsupported file format: {ext}"
+
+# ✅ Extract Text Functions (Optimized)
+def extract_text_from_pdf(file):
+    print("📄 Extracting text from PDF...")
+    with fitz.open(file.name) as doc:
+        return " ".join(page.get_text() for page in doc)
+
+def extract_text_from_docx(file):
+    print("📄 Extracting text from DOCX...")
+    doc = Document(file.name)
+    return " ".join(p.text for p in doc.paragraphs)
+
+def extract_text_from_pptx(file):
+    print("📄 Extracting text from PPTX...")
+    ppt = Presentation(file.name)
+    return " ".join(shape.text for slide in ppt.slides for shape in slide.shapes if hasattr(shape, "text"))
+
+def extract_text_from_excel(file):
+    print("📊 Extracting text from Excel...")
+    wb = load_workbook(file.name, data_only=True)
+    return " ".join(" ".join(str(cell) for cell in row if cell) for sheet in wb.worksheets for row in sheet.iter_rows(values_only=True))
+
+# ✅ Question Answering Function (Efficient Processing)
+async def answer_question(file, question: str):
+    print("🔍 Processing file for QA...")
+
+    validation_error = validate_file_type(file)
+    if validation_error:
+        return validation_error
+
+    file_ext = file.name.split(".")[-1].lower()
+    text = ""
+
+    if file_ext == "pdf":
+        text = extract_text_from_pdf(file)
+    elif file_ext == "docx":
+        text = extract_text_from_docx(file)
+    elif file_ext == "pptx":
+        text = extract_text_from_pptx(file)
+    elif file_ext == "xlsx":
+        text = extract_text_from_excel(file)
+
+    if not text.strip():
+        return "⚠️ No text extracted from the document."
+
+    print("✂️ Truncating text for faster processing...")
+    truncated_text = text[:1024]  # Reduce to 1024 characters for better speed
+
+    qa_pipeline = get_qa_pipeline()
+    response = qa_pipeline(f"Question: {question}\nContext: {truncated_text}")
+
+    return response[0]["generated_text"]
+
+# ✅ Gradio UI
+with gr.Blocks() as demo:
+    gr.Markdown("## 📄 AI-Powered Document & Image QA")
+
+    with gr.Row():
+        file_input = gr.File(label="Upload Document")
+        question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
+
+    answer_output = gr.Textbox(label="Answer")
+    submit_btn = gr.Button("Get Answer")
+
+    submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)
+
+# ✅ Mount Gradio with FastAPI
+app = gr.mount_gradio_app(app, demo, path="/demo")
+
+@app.get("/")
+def home():
+    return RedirectResponse(url="/demo")
+
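Two side notes on the new code, as sketches rather than part of the commit.

First, get_qa_pipeline() builds a fresh transformers pipeline on every call, so each question reloads TinyLlama from disk. If the intent of the lazy-load pattern is "load once, on first use", one minimal variant caches the pipeline with functools.lru_cache (the caching is an assumption here, not something this commit does; model id and CPU device are unchanged):

    from functools import lru_cache
    from transformers import pipeline

    @lru_cache(maxsize=1)
    def get_qa_pipeline():
        # Constructed on the first call only; later calls reuse the cached object.
        print("🔄 Loading QA Model on CPU...")
        return pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            device=-1,
        )

Second, the Gradio app now mounts at /demo while / redirects to it. A quick way to confirm the redirect, assuming the server is running locally on port 7860 (the usual Spaces port) and requests is installed:

    import requests

    # Starlette's RedirectResponse defaults to HTTP 307.
    r = requests.get("http://localhost:7860/", allow_redirects=False)
    print(r.status_code, r.headers["location"])  # expected: 307 /demo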