Chatbot / app.py
NaimaAqeel's picture
Update app.py
2737463 verified
raw
history blame
5.32 kB
import os
import pickle
import numpy as np
import gradio as gr
import fitz # PyMuPDF
from docx import Document
from transformers import AutoModel, AutoTokenizer, pipeline
import faiss
import torch
# ===============================
# EMBEDDING MODEL
# ===============================
# Checkpoint name handed to both from_pretrained() calls below.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
# Tokenizer and encoder are created once at module import and reused by
# get_embeddings() for both document chunks and user queries.
tokenizer = AutoTokenizer.from_pretrained(model_name)
embedding_model = AutoModel.from_pretrained(model_name)
def get_embeddings(texts):
    """Embed one string or a sequence of strings as unit-length sentence vectors.

    The token embeddings produced by ``embedding_model`` are averaged over the
    real (non-padding) tokens and then L2-normalised, which is the pooling
    recipe documented for sentence-transformers checkpoints. With unit-length
    vectors, the inner product computed by an ``IndexFlatIP`` index equals
    cosine similarity.

    NOTE: vectors produced by the previous first-token pooling are not
    comparable with these; an index persisted with the old vectors should be
    rebuilt.

    Args:
        texts: A single string or an iterable of strings.

    Returns:
        A float32 NumPy array of shape ``(len(texts), hidden_size)``; the
        first dimension is 0 when ``texts`` is empty.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        # Nothing to encode: skip the model call and return an empty matrix.
        return np.empty((0, embedding_model.config.hidden_size), dtype=np.float32)

    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        outputs = embedding_model(**inputs)

    token_embeddings = outputs.last_hidden_state
    # Zero out padding positions so they do not dilute the average.
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
    sentence_embeddings = torch.nn.functional.normalize(summed / token_counts, p=2, dim=1)
    return sentence_embeddings.cpu().numpy()
# ===============================
# TEXT CHUNKING
# ===============================
def chunk_text(text, chunk_size=800, overlap=100):
    """Split ``text`` into overlapping character windows.

    Consecutive chunks start ``chunk_size - overlap`` characters apart, so each
    chunk repeats the last ``overlap`` characters of its predecessor. Chunking
    stops as soon as a chunk reaches the end of ``text``, so no trailing chunk
    that is entirely contained in the previous one is produced.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by neighbouring chunks; must be
            in ``[0, chunk_size)``.

    Returns:
        A list of chunks in document order (empty for empty ``text``).

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range. A
            non-positive stride would otherwise never advance through the text.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            # This chunk already covers the tail of the text.
            break
    return chunks
# ===============================
# FAISS INDEX SETUP
# ===============================
# On-disk locations of the pickled FAISS index and of the chunk texts whose
# list positions correspond to the index's vector ids.
index_path = "faiss_index.pkl"
document_texts_path = "document_texts.pkl"
document_texts = []
embedding_dim = 384

# Start from an empty inner-product index; it is replaced below only when a
# consistent persisted state can be restored.
index = faiss.IndexFlatIP(embedding_dim)

if os.path.exists(index_path) and os.path.exists(document_texts_path):
    try:
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file. Only load files this application wrote itself.
        with open(index_path, "rb") as f:
            loaded_index = pickle.load(f)
        with open(document_texts_path, "rb") as f:
            loaded_texts = pickle.load(f)
        # The index and the text list are looked up by the same position, so
        # they must describe the same number of vectors of the expected size.
        if loaded_index.d != embedding_dim:
            raise ValueError(
                f"index dimension {loaded_index.d} does not match {embedding_dim}"
            )
        if loaded_index.ntotal != len(loaded_texts):
            raise ValueError(
                f"index holds {loaded_index.ntotal} vectors but "
                f"{len(loaded_texts)} texts were stored"
            )
    except Exception as e:
        # Best effort: fall back to the empty index and empty text list.
        print(f"Error loading index: {e}")
    else:
        index = loaded_index
        document_texts = loaded_texts
# ===============================
# FILE EXTRACTORS
# ===============================
def extract_text_from_pdf(path):
    """Return the concatenated text of every page of the PDF at ``path``.

    Extraction is best effort: any error is printed and whatever text was
    collected before the failure is returned (an empty string if nothing
    could be read).

    Args:
        path: Filesystem path of the PDF file.

    Returns:
        The page texts joined in page order.
    """
    page_texts = []
    try:
        # The context manager closes the document even if a page fails.
        with fitz.open(path) as doc:
            for page in doc:
                page_texts.append(page.get_text())
    except Exception as e:
        print(f"PDF error: {e}")
    # Join once at the end instead of growing a string page by page.
    return "".join(page_texts)
def extract_text_from_docx(path):
    """Return the paragraph texts of the DOCX file at ``path``, newline-joined.

    Any failure while opening or reading the document is printed and an
    empty string is returned instead.
    """
    try:
        paragraphs = Document(path).paragraphs
        return "\n".join([paragraph.text for paragraph in paragraphs])
    except Exception as e:
        print(f"DOCX error: {e}")
        return ""
# ===============================
# UPLOAD HANDLER
# ===============================
def upload_document(file):
    """Extract, chunk, embed and index an uploaded PDF or DOCX file.

    The new chunks are appended to the module-level ``index`` and
    ``document_texts``, and both are then written back to disk.

    Args:
        file: The uploaded file as delivered by the UI. Either a filesystem
            path (``str``) or an object exposing the path as ``.name`` is
            accepted; ``None`` means nothing was uploaded.

    Returns:
        A human-readable status message.
    """
    if file is None:
        return "No file uploaded."

    # The upload component may hand over a plain path or a temp-file wrapper.
    path = file if isinstance(file, str) else file.name
    ext = os.path.splitext(path)[-1].lower()
    if ext == ".pdf":
        text = extract_text_from_pdf(path)
    elif ext == ".docx":
        text = extract_text_from_docx(path)
    else:
        return "Unsupported file type."

    # The extractors return "" on failure (and image-only documents yield no
    # text); there is nothing to embed in that case.
    if not text.strip():
        return "No extractable text found in the document."

    chunks = chunk_text(text)
    chunk_embeddings = get_embeddings(chunks)
    index.add(np.asarray(chunk_embeddings, dtype="float32"))
    # Keep the text list aligned with the vector ids just added to the index.
    document_texts.extend(chunks)

    with open(index_path, "wb") as f:
        pickle.dump(index, f)
    with open(document_texts_path, "wb") as f:
        pickle.dump(document_texts, f)
    return "Document uploaded and indexed successfully."
# ===============================
# GENERATION PIPELINE (FLAN-T5)
# ===============================
# Text-to-text generation pipeline backed by google/flan-t5-base. It is built
# once at module import and reused for every question.
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
def generate_answer_from_file(query, top_k=10):
    """Answer ``query`` from the indexed document chunks.

    The query is embedded, the closest chunks are fetched from ``index``,
    and a few-shot prompt containing those chunks is passed to
    ``qa_pipeline``.

    Args:
        query: The user's question.
        top_k: Maximum number of chunks to retrieve as context.

    Returns:
        The generated answer, or a short status message when there is
        nothing to search or no question was given.
    """
    if not document_texts:
        return "No documents indexed yet."
    if not query or not query.strip():
        return "Please enter a question."

    # Never ask for more neighbours than exist: FAISS pads missing results
    # with the label -1, which would silently index the last chunk.
    k = min(top_k, index.ntotal, len(document_texts))
    if k <= 0:
        return "No documents indexed yet."

    query_vector = get_embeddings(query).astype("float32")
    _, indices = index.search(query_vector, k=k)
    retrieved_chunks = [
        document_texts[i] for i in indices[0] if 0 <= i < len(document_texts)
    ]
    # NOTE(review): all retrieved chunks are concatenated; with the default
    # chunk size this may exceed what the generator handles well. TODO confirm
    # and cap the context if needed.
    context = "\n\n".join(retrieved_chunks)
    print("\n--- Retrieved Context ---\n", context)  # Debugging print

    # Few-shot prompt: one worked example, then the real context and question.
    prompt = (
        "You are a helpful assistant reading student notes or textbook passages.\n\n"
        "Based on the context provided, answer the question accurately and clearly.\n\n"
        "### Example\n"
        "Context:\nArtificial systems are created by people. These systems are designed to perform specific tasks, improve efficiency, and solve problems. Examples include knowledge systems, engineering systems, and social systems.\n\n"
        "Question: What is an Artificial System?\n"
        "Answer: Artificial systems are systems created by humans to perform specific tasks, improve efficiency, and solve problems. They include systems like knowledge systems, engineering systems, and social systems.\n\n"
        "### Now answer this\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n"
        "Answer:"
    )
    result = qa_pipeline(prompt, max_length=512, do_sample=False)[0]['generated_text']
    return result.strip()
# ===============================
# GRADIO INTERFACES
# ===============================
# Tab 1: accepts a PDF or DOCX file and passes it to upload_document().
upload_interface = gr.Interface(
    title="Upload Document",
    description="Upload your Word or PDF document for question answering.",
    fn=upload_document,
    inputs=gr.File(file_types=[".pdf", ".docx"]),
    outputs="text",
)

# Tab 2: takes a free-text question and shows generate_answer_from_file()'s reply.
search_interface = gr.Interface(
    title="Ask the Document",
    description="Ask questions about the uploaded content. The chatbot will answer based on the document.",
    fn=generate_answer_from_file,
    inputs=gr.Textbox(placeholder="Ask your question about the uploaded document..."),
    outputs="text",
)

# Combine both interfaces into one tabbed app and start serving it.
app = gr.TabbedInterface(
    [upload_interface, search_interface],
    ["Upload", "Ask"],
)
app.launch()