Chatbot / app.py
NaimaAqeel's picture
Update app.py
58fc57a verified
raw
history blame
4.92 kB
import os
import numpy as np
import faiss
import pickle
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import gradio as gr
import fitz # PyMuPDF for PDFs
import docx # python-docx for Word files
# On-disk locations of the pickled FAISS index and of the chunk texts it indexes.
index_path = "faiss_index.pkl"
document_texts_path = "document_texts.pkl"

# Start from an empty store unless BOTH persisted files are present; otherwise
# restore the index and its chunk list from disk.
# NOTE(review): pickle.load can execute arbitrary code from the file it reads;
# these two paths must only ever hold data written by this application.
if not (os.path.exists(index_path) and os.path.exists(document_texts_path)):
    # 384 matches the embedding size of the all-MiniLM-L6-v2 model used below.
    dim = 384
    index = faiss.IndexFlatL2(dim)
    document_texts = []
else:
    with open(index_path, "rb") as f:
        index = pickle.load(f)
    with open(document_texts_path, "rb") as f:
        document_texts = pickle.load(f)

# Sentence-embedding model used for both document chunks and queries.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Text-to-text generation pipeline that produces the final answer.
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-small")
def extract_text_from_pdf(file_path):
    """Return the concatenated text of every page of the PDF at *file_path*.

    The document is opened as a context manager so it is closed even when
    text extraction raises part-way through the pages.

    Args:
        file_path: Path of the PDF file to read.

    Returns:
        A single string holding the text of all pages, in page order.
    """
    with fitz.open(file_path) as doc:
        # join() avoids rebuilding the whole string once per page.
        return "".join(page.get_text() for page in doc)
def extract_text_from_docx(file_path):
    """Return the text of every paragraph in the DOCX file, one per line.

    Args:
        file_path: Path of the Word document to read.

    Returns:
        The paragraph texts joined with newline characters.
    """
    document = docx.Document(file_path)
    return "\n".join(paragraph.text for paragraph in document.paragraphs)
def chunk_text(text, max_len=500):
    """Split text into chunks of about max_len characters at sentence boundaries.

    Sentences are detected as text following '.', '!' or '?' and one or more
    spaces. Consecutive sentences are packed into a chunk until adding the
    next one would exceed ``max_len``. A single sentence longer than
    ``max_len`` is kept whole as its own chunk rather than being cut.

    Args:
        text: The text to split.
        max_len: Target maximum chunk length in characters.

    Returns:
        A list of non-empty, whitespace-stripped chunks. Empty or
        whitespace-only input yields an empty list.
    """
    import re

    sentences = re.split(r'(?<=[.!?]) +', text)
    chunks = []
    current_chunk = ""
    for sent in sentences:
        if len(current_chunk) + len(sent) + 1 <= max_len:
            current_chunk += sent + " "
            continue
        # Flush only when something was accumulated: an oversize first
        # sentence would otherwise emit an empty chunk ahead of itself.
        flushed = current_chunk.strip()
        if flushed:
            chunks.append(flushed)
        current_chunk = sent + " "

    # The accumulator always carries a trailing space, so test the stripped
    # value to avoid emitting an empty chunk for blank input.
    tail = current_chunk.strip()
    if tail:
        chunks.append(tail)
    return chunks
def get_embeddings(texts, is_query=False):
    """Encode one string or a list of strings into normalized embeddings.

    Args:
        texts: A single string (treated as a one-element batch) or a list
            of strings.
        is_query: Accepted for call-site readability; not used by the
            encoding itself.

    Returns:
        The array produced by the embedder, one row per input text.
    """
    batch = [texts] if isinstance(texts, str) else texts
    return embedder.encode(batch, convert_to_numpy=True, normalize_embeddings=True)
def _resolve_upload_path(file):
    """Return the filesystem path of an uploaded file, or None if unavailable."""
    if file is None:
        return None
    # Some Gradio versions hand over the path itself (a str or str subclass).
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    # Others hand over a temp-file wrapper whose ``name`` is the path.
    return getattr(file, "name", None)


def upload_document(file):
    """Extract, chunk, embed and index an uploaded PDF or DOCX file.

    The new chunks are appended to the module-level ``document_texts`` list
    and added to the module-level FAISS ``index`` under ids equal to their
    positions in that list; both are then persisted to disk.

    Args:
        file: The value delivered by the Gradio file component: a path-like
            string, an object exposing the path as ``.name``, or ``None``
            when nothing was uploaded.

    Returns:
        A human-readable status message describing success or the failure
        reason. Failures are reported through the message, not raised.
    """
    global index, document_texts

    file_path = _resolve_upload_path(file)
    if not file_path:
        return "No file provided. Please upload a PDF or DOCX file."

    ext = os.path.splitext(file_path)[-1].lower()
    try:
        if ext == ".pdf":
            text = extract_text_from_pdf(file_path)
        elif ext == ".docx":
            text = extract_text_from_docx(file_path)
        else:
            return "Unsupported file type. Please upload a PDF or DOCX file."
    except Exception as e:
        # UI boundary: report any parser failure to the user instead of crashing.
        return f"Failed to extract text: {str(e)}"

    if not text.strip():
        return "Failed to extract any text from the document."

    # Never index empty strings: they carry no content to retrieve.
    chunks = [chunk for chunk in chunk_text(text) if chunk]
    if not chunks:
        return "Failed to extract any text from the document."
    embeddings = get_embeddings(chunks)

    # Wrap the index in an IDMap so vectors can be added with explicit ids.
    if not isinstance(index, faiss.IndexIDMap):
        index = faiss.IndexIDMap(index)

    # Ids mirror positions in document_texts; FAISS requires int64 ids and
    # np.arange does not default to int64 on every platform.
    start_id = len(document_texts)
    ids = np.arange(start_id, start_id + len(chunks), dtype=np.int64)
    index.add_with_ids(embeddings.astype('float32'), ids)
    document_texts.extend(chunks)

    # Persist index and texts together so they stay in step across restarts.
    with open(index_path, "wb") as f:
        pickle.dump(index, f)
    with open(document_texts_path, "wb") as f:
        pickle.dump(document_texts, f)

    return f"Document uploaded and indexed successfully with {len(chunks)} chunks."
def generate_answer_from_file(query, top_k=5):
    """Answer *query* using the most similar indexed document chunks.

    The query is embedded, the nearest chunks are retrieved from the
    module-level FAISS ``index``, and a prompt containing those chunks and
    the question is passed to the text generation pipeline.

    Args:
        query: The user's question.
        top_k: Maximum number of chunks to retrieve as context.

    Returns:
        The generated answer, or a guidance message when the question is
        blank or no document has been indexed yet.
    """
    if not query or not query.strip():
        return "Please enter a question."
    if len(document_texts) == 0:
        return "No document uploaded yet. Please upload a PDF or DOCX file first."

    query_vec = get_embeddings(query, is_query=True).astype("float32")

    # Never request more neighbours than there are chunks.
    k = max(1, min(top_k, len(document_texts)))
    _scores, indices = index.search(query_vec, k)

    # FAISS pads missing results with -1; without the lower bound, Python's
    # negative indexing would silently pull in the last chunk instead.
    retrieved_chunks = [
        document_texts[i] for i in indices[0] if 0 <= i < len(document_texts)
    ]
    context = "\n\n".join(retrieved_chunks)

    prompt = (
        "You are a helpful assistant reading a document.\n\n"
        "Context:\n"
        f"{context}\n\n"
        f"Question: {query}\n"
        "Answer:"
    )

    # Deterministic generation, capped at max_length=256.
    result = qa_pipeline(prompt, max_length=256, do_sample=False)[0]['generated_text']
    return result.strip()
# Gradio UI: an upload row with its status box, then a question/answer pair.
# NOTE(review): the original indentation was lost, so which components sit
# inside the Row is reconstructed here -- confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Document Question Answering App\nUpload a PDF or DOCX file, then ask questions based on it.")
    with gr.Row():
        file_input = gr.File(label="Upload PDF or DOCX file", file_types=['.pdf', '.docx'])
        upload_btn = gr.Button("Upload & Index Document")
    upload_output = gr.Textbox(label="Upload Status", interactive=False)
    question = gr.Textbox(label="Enter your question here")
    answer = gr.Textbox(label="Answer", interactive=False)
    ask_btn = gr.Button("Ask")

    # Event wiring: each button feeds one input into its handler and shows
    # the returned string in the matching output box.
    upload_btn.click(upload_document, inputs=file_input, outputs=upload_output)
    ask_btn.click(generate_answer_from_file, inputs=question, outputs=answer)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or tooling) does not launch the web app.
if __name__ == "__main__":
    demo.launch()