Spaces:
Runtime error
Runtime error
from fastapi import FastAPI, File, UploadFile | |
import pdfplumber | |
import docx | |
import openpyxl | |
from pptx import Presentation | |
import easyocr | |
from transformers import pipeline | |
import gradio as gr | |
from fastapi.responses import RedirectResponse | |
# FastAPI application; the Gradio UI is mounted onto it at the end of the file.
app = FastAPI()

# Generative model used to answer questions about extracted text. It is loaded
# once at import time, and every request reuses the same pipeline object. The
# same checkpoint id supplies both the weights and the (fast) tokenizer.
_QA_CHECKPOINT = "google/flan-t5-large"
qa_pipeline = pipeline(
    "text2text-generation",
    model=_QA_CHECKPOINT,
    tokenizer=_QA_CHECKPOINT,
    use_fast=True,
)
# Function to cap the context passed to the model (limit counted in words)
def truncate_text(text, max_tokens=450):
    """Return at most the first ``max_tokens`` whitespace-separated words of ``text``.

    The limit counts *words*, not model tokens. Subword tokenizers usually
    produce more tokens than words, so the result can still exceed the
    model's input limit (presumably ~512 tokens for flan-t5; verify if exact
    fitting matters).

    Args:
        text: Text to shorten. Any run of whitespace counts as one separator.
        max_tokens: Maximum number of words to keep. Values <= 0 yield "".

    Returns:
        The kept words joined by single spaces.
    """
    if max_tokens <= 0:
        # A negative slice bound would otherwise drop words from the END
        # instead of keeping none.
        return ""
    return " ".join(text.split()[:max_tokens])
# Functions to extract text from different file formats
def extract_text_from_pdf(pdf_file):
    """Return the text of every page of a PDF, pages separated by newlines.

    Args:
        pdf_file: A path or binary file object accepted by ``pdfplumber.open``.

    Returns:
        The page texts joined by newlines, with surrounding whitespace stripped.
        Pages without extractable text contribute an empty line.
    """
    with pdfplumber.open(pdf_file) as pdf:
        # Some pdfplumber releases return None from extract_text() for pages
        # without a text layer (e.g. scanned images). Treat those pages as empty
        # instead of crashing on ``None + "\n"``.
        page_texts = [page.extract_text() or "" for page in pdf.pages]
    return "\n".join(page_texts).strip()
def extract_text_from_docx(docx_file):
    """Return the text of each body paragraph of a .docx file, one per line.

    Only ``Document.paragraphs`` is read, so text that lives inside tables is
    not included. Empty paragraphs still produce (empty) lines.
    """
    document = docx.Document(docx_file)
    paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
    return "\n".join(paragraph_texts)
def extract_text_from_pptx(pptx_file):
    """Return the text of every slide shape that has a ``text`` attribute.

    Slides are visited in order, and within each slide the shapes of
    ``slide.shapes`` are visited in order. The code does not recurse into
    nested shapes. Each shape's text becomes one newline-separated entry.
    """
    presentation = Presentation(pptx_file)
    return "\n".join(
        shape.text
        for slide in presentation.slides
        for shape in slide.shapes
        if hasattr(shape, "text")
    )
def extract_text_from_excel(excel_file):
    """Flatten every worksheet of an .xlsx workbook into plain text.

    Each non-blank row becomes one line containing its non-empty cell values
    separated by spaces, in sheet order and then row order. The workbook is
    opened with openpyxl's defaults, so formula cells yield their formula text
    rather than a computed value.

    Args:
        excel_file: A path or file object accepted by ``openpyxl.load_workbook``.

    Returns:
        The rows joined by newlines ("" if every cell is empty).
    """
    workbook = openpyxl.load_workbook(excel_file)
    lines = []
    for sheet in workbook.worksheets:
        for row in sheet.iter_rows(values_only=True):
            # Empty cells come back as None. str(None) would inject literal
            # "None" words into the model's context, so skip them, and skip
            # rows that end up with no cells at all.
            cells = [str(value) for value in row if value is not None]
            if cells:
                lines.append(" ".join(cells))
    return "\n".join(lines)
# Shared EasyOCR reader, created on first use. Constructing a Reader loads the
# OCR models (downloading them if absent), which is far too expensive to
# repeat on every request.
_OCR_READER = None


def _get_ocr_reader():
    """Return the shared English EasyOCR reader, creating it on first call."""
    global _OCR_READER
    if _OCR_READER is None:
        _OCR_READER = easyocr.Reader(["en"])
    return _OCR_READER


def extract_text_from_image(image_file):
    """Run English OCR on ``image_file`` and return the recognised text.

    Args:
        image_file: Anything ``easyocr.Reader.readtext`` accepts (a path,
            bytes, or an image array). From the ``gr.Image`` input in this
            file this is presumably a numpy array; confirm if the input type
            changes.

    Returns:
        The recognised text fragments joined by single spaces, in the order
        ``readtext`` reports them ("" when nothing is detected).
    """
    detections = _get_ocr_reader().readtext(image_file)
    # With readtext's default detail level each detection is
    # (bounding_box, text, confidence); keep only the text.
    return " ".join(detection[1] for detection in detections)
# Function to answer questions based on document content
def answer_question_from_document(file, question):
    """Answer ``question`` from the text of an uploaded PDF/DOCX/PPTX/XLSX file.

    Args:
        file: The value produced by ``gr.File``. This is a filesystem path
            string in current Gradio releases, or a tempfile-like object whose
            ``.name`` is the path in older ones. It is ``None`` when nothing
            was uploaded.
        question: The user's free-text question.

    Returns:
        The model's generated answer. Instead, a short status message is
        returned when there is no upload, no question, an unsupported file
        extension, or no extractable text.
    """
    if file is None:
        return "Please upload a document."
    if not question or not question.strip():
        return "Please enter a question."

    # Accept both Gradio upload shapes: a bare ``file.name`` access raises
    # AttributeError when Gradio hands over a plain path string.
    path = str(getattr(file, "name", file))
    file_ext = path.split(".")[-1].lower()

    if file_ext == "pdf":
        text = extract_text_from_pdf(path)
    elif file_ext == "docx":
        text = extract_text_from_docx(path)
    elif file_ext == "pptx":
        text = extract_text_from_pptx(path)
    elif file_ext == "xlsx":
        text = extract_text_from_excel(path)
    else:
        return "Unsupported file format!"

    # DOCX/PPTX/XLSX extraction does not strip its output, so an empty
    # document can come back as bare newlines. Treat whitespace-only text as
    # "nothing extracted" rather than prompting the model with no context.
    if not text.strip():
        return "No text extracted from the document."

    truncated_text = truncate_text(text)
    input_text = f"Question: {question} Context: {truncated_text}"
    response = qa_pipeline(input_text)
    return response[0]["generated_text"]
# Function to answer questions based on image content
def answer_question_from_image(image, question):
    """Answer ``question`` using the text that OCR finds in ``image``.

    Args:
        image: The value from the ``gr.Image`` input (passed unchanged to the
            OCR helper), or ``None`` when nothing was uploaded.
        question: The user's free-text question.

    Returns:
        The model's generated answer. Instead, a short status message is
        returned when there is no image, no question, or no detected text.
    """
    # ``is None`` rather than truthiness: an image array has no unambiguous
    # boolean value.
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please enter a question."

    image_text = extract_text_from_image(image)
    if not image_text.strip():
        return "No text detected in the image."

    truncated_text = truncate_text(image_text)
    input_text = f"Question: {question} Context: {truncated_text}"
    response = qa_pipeline(input_text)
    return response[0]["generated_text"]
# Gradio UI for Document & Image QA: one Interface per input kind. Each pairs
# an upload widget with a free-text question box and renders the answer as
# plain text.
doc_interface = gr.Interface(
    title="AI Document Question Answering",
    fn=answer_question_from_document,
    inputs=[
        gr.File(label="Upload Document"),
        gr.Textbox(label="Ask a Question"),
    ],
    outputs="text",
)

img_interface = gr.Interface(
    title="AI Image Question Answering",
    fn=answer_question_from_image,
    inputs=[
        gr.Image(label="Upload Image"),
        gr.Textbox(label="Ask a Question"),
    ],
    outputs="text",
)
# Mount Gradio Interfaces: both UIs become tabs of one Gradio app, mounted onto
# the FastAPI app at the site root.
demo = gr.TabbedInterface([doc_interface, img_interface], ["Document QA", "Image QA"])
app = gr.mount_gradio_app(app, demo, path="/")


def home():
    """Return a redirect response pointing at the site root.

    NOTE(review): this function has no route decorator, so it is never served.
    It also cannot usefully be routed at ``/``, where the Gradio app is
    mounted. Such a route would either redirect to itself or be shadowed by
    the mount.
    """
    return RedirectResponse(url="/")


if __name__ == "__main__":
    # Building ``app`` does not start a server. When this file is executed
    # directly (``python app.py``), nothing would listen and the process would
    # simply exit, so serve the UI with Gradio's built-in server instead.
    # Under an ASGI server (e.g. ``uvicorn app:app``) this branch does not run
    # and the mounted ``app`` above is used.
    # NOTE(review): assumes the host (e.g. a Gradio-SDK Space) runs the script
    # directly; confirm against the deployment config.
    demo.launch()