import os
import faiss
import numpy as np
import requests
import streamlit as st
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
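# Dependencies (assumed pip package names for the imports above; FAISS is
# usually installed as faiss-cpu or faiss-gpu, not plain "faiss"):
#   pip install streamlit pypdf sentence-transformers numpy requests faiss-cpu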
###############################################################################
# 1. PDF Parsing and Chunking
###############################################################################
def extract_pdf_text(pdf_file) -> str:
    """
    Read and extract text from each page of an uploaded PDF file.
    """
    reader = PdfReader(pdf_file)
    all_text = []
    for page in reader.pages:
        text = page.extract_text() or ""
        all_text.append(text.strip())
    return "\n".join(all_text)
def chunk_text(text, chunk_size=300, overlap=50):
    """
    Split text into overlapping chunks of roughly 'chunk_size' words.
    'overlap' is how many words from the end of one chunk are repeated
    at the start of the next, so context is not lost at chunk boundaries.
    """
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = start + chunk_size
        chunk = words[start:end]
        chunks.append(" ".join(chunk))
        start += (chunk_size - overlap)
    return chunks
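# Illustrative example (using the defaults above): for a 400-word document,
# chunk_text(text) yields words[0:300] and words[250:400] -- the second chunk
# starts at 300 - 50 = 250, so 50 words are shared between consecutive chunks.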
###############################################################################
# 2. Embedding Model
###############################################################################
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
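# Note: Streamlit re-executes this whole script on every interaction, so the
# model above is reloaded on each rerun. A possible optimization (a sketch,
# not part of the original app) is to wrap the load in a function decorated
# with @st.cache_resource so the model is created only once per process.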
###############################################################################
# 3. Build FAISS Index
###############################################################################
def build_faiss_index(chunks):
    """
    Creates a FAISS index from embedded chunks.
    Returns (index, chunk_embeddings).
    """
    chunk_embeddings = embedding_model.encode(chunks, show_progress_bar=False)
    chunk_embeddings = np.array(chunk_embeddings, dtype='float32')
    dimension = chunk_embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(chunk_embeddings)
    return index, chunk_embeddings
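# IndexFlatL2 ranks chunks by Euclidean distance. An alternative sketch (not
# what this app does) is cosine similarity: call faiss.normalize_L2() on the
# float32 embedding matrix before index.add(), build faiss.IndexFlatIP(dimension)
# instead, and normalize the query embedding the same way before searching.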
###############################################################################
# 4. Retrieval Function
###############################################################################
def retrieve_chunks(query, index, chunks, top_k=3):
    """
    Embeds 'query' and retrieves the top_k most relevant chunks from 'index'.
    """
    query_embedding = embedding_model.encode([query], show_progress_bar=False)
    query_embedding = np.array(query_embedding, dtype='float32')
    distances, indices = index.search(query_embedding, top_k)
    # FAISS pads results with -1 when the index holds fewer than top_k vectors.
    return [chunks[i] for i in indices[0] if i >= 0]
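# Usage sketch: retrieve_chunks("What dataset was used?", index, chunks)
# returns up to 3 chunk strings. The L2 distances are discarded here; they
# could be used to drop matches beyond a distance threshold if desired.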
###############################################################################
# 5. Gemini LLM Integration
###############################################################################
def gemini_generate(prompt):
    """
    Calls Google's Gemini API, authenticating with the GEMINI_API_KEY
    environment variable.
    """
    gemini_api_key = os.environ.get("GEMINI_API_KEY", "")
    if not gemini_api_key:
        return "Error: No GEMINI_API_KEY found in environment variables."
    url = (
        "https://generativelanguage.googleapis.com/"
        "v1beta/models/gemini-1.5-flash:generateContent"
        f"?key={gemini_api_key}"
    )
    payload = {
        "contents": [
            {
                "parts": [
                    {"text": prompt}
                ]
            }
        ]
    }
    headers = {"Content-Type": "application/json"}
    try:
        # A timeout keeps the Streamlit app from hanging indefinitely if the
        # API is unreachable.
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        r_data = response.json()
        # Extract the generated text from the 'candidates' structure:
        return r_data["candidates"][0]["content"]["parts"][0]["text"]
    except requests.exceptions.RequestException as e:
        return f"Error calling Gemini API: {e}"
    except (KeyError, IndexError):
        return f"Parsing error or unexpected response format: {response.text}"
###############################################################################
# 6. RAG QA Function
###############################################################################
def answer_question_with_RAG(user_question, index, chunks):
    """
    Retrieves relevant chunks, builds an augmented prompt, and calls gemini_generate().
    """
    relevant_chunks = retrieve_chunks(user_question, index, chunks, top_k=3)
    context = "\n\n".join(relevant_chunks)
    prompt = f"""
You are an AI assistant that knows the details from the uploaded research paper.
Answer the user's question accurately using the context below.
If something is not in the context, say 'I don't know'.

Context:
{context}

User's question: {user_question}

Answer:
"""
    return gemini_generate(prompt)
###############################################################################
# Streamlit Application
###############################################################################
def main():
    # Basic page config (optional):
    st.set_page_config(
        page_title="AI-Powered Personal Research Assistant",
        layout="centered"
    )

    # Title and subheader
    st.title("AI-Powered Personal Research Assistant")
    st.write("Welcome! How may I help you?")

    # Store the FAISS index + chunks in session_state to persist across reruns
    if "faiss_index" not in st.session_state:
        st.session_state.faiss_index = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = None

    # Step 1: Upload and process the PDF
    uploaded_pdf = st.file_uploader("Upload your research paper (PDF)", type=["pdf"])
    if st.button("Process PDF"):
        if uploaded_pdf is None:
            st.warning("Please upload a PDF file first.")
        else:
            # Read and chunk
            raw_text = extract_pdf_text(uploaded_pdf)
            if not raw_text.strip():
                st.error("No text found in PDF.")
                return
            chunks = chunk_text(raw_text, chunk_size=300, overlap=50)
            if not chunks:
                st.error("No valid text to chunk.")
                return
            # Build the index and persist it for later questions
            faiss_index, _ = build_faiss_index(chunks)
            st.session_state.faiss_index = faiss_index
            st.session_state.chunks = chunks
            st.success("PDF processed successfully!")

    # Step 2: Ask a question
    user_question = st.text_input("Ask a question about your research paper:")
    if st.button("Get Answer"):
        # Explicit 'is None' check: a raw FAISS index object is always truthy,
        # so comparing against None is the reliable test here.
        if st.session_state.faiss_index is None or not st.session_state.chunks:
            st.warning("Please upload and process a PDF first.")
        elif not user_question.strip():
            st.warning("Please enter a valid question.")
        else:
            answer = answer_question_with_RAG(
                user_question,
                st.session_state.faiss_index,
                st.session_state.chunks
            )
            st.write("### Answer:")
            st.write(answer)
if __name__ == "__main__":
    main()