|
import os |
|
import faiss |
|
import numpy as np |
|
import requests |
|
import streamlit as st |
|
|
|
from pypdf import PdfReader |
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
|
def extract_pdf_text(pdf_file) -> str:
    """
    Return the text of every page of an uploaded PDF as a single string.

    Each page's extracted text is stripped of surrounding whitespace (a page
    that yields no text contributes an empty string), and the per-page texts
    are joined with newline characters in page order.
    """
    pages = PdfReader(pdf_file).pages
    page_texts = [(page.extract_text() or "").strip() for page in pages]
    return "\n".join(page_texts)
|
|
|
def chunk_text(text, chunk_size=300, overlap=50):
    """
    Split text into overlapping chunks of at most 'chunk_size' words.

    Words are whitespace-separated pieces as produced by str.split(); they
    only approximate model tokens. Consecutive chunks share 'overlap' words.

    Args:
        text: The text to split.
        chunk_size: Maximum number of words per chunk; must be positive.
        overlap: Number of words from the end of one chunk that are repeated
            at the start of the next; must be smaller than 'chunk_size'.

    Returns:
        A list of chunk strings in document order; empty when 'text' contains
        no words.

    Raises:
        ValueError: If 'chunk_size' is not positive, or 'overlap' is not
            smaller than 'chunk_size' (the window would never advance).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        # Once a window reaches the last word, any further window would only
        # repeat words already contained in this chunk, so stop here.
        if end >= len(words):
            break
    return chunks
|
|
|
|
|
|
|
|
|
# show_spinner=False keeps the cached call from drawing any UI, so it is safe
# to run before st.set_page_config() is invoked in main().
@st.cache_resource(show_spinner=False)
def _load_embedding_model():
    """Load the sentence-embedding model, cached across Streamlit reruns."""
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


# Streamlit re-executes this script on every user interaction; going through
# the cached loader avoids re-instantiating the model on each rerun.
embedding_model = _load_embedding_model()
|
|
|
|
|
|
|
|
|
def build_faiss_index(chunks):
    """
    Create a FAISS index from the embeddings of 'chunks'.

    Args:
        chunks: Non-empty list of text chunks to embed and index.

    Returns:
        (index, chunk_embeddings): a faiss.IndexFlatL2 holding one vector per
        chunk, added in the same order as 'chunks', and the float32 embedding
        matrix that was added to it.

    Raises:
        ValueError: If 'chunks' is empty, since there would be no embedding
            from which to read the vector dimension.
    """
    if not chunks:
        raise ValueError("Cannot build a FAISS index from an empty list of chunks.")

    chunk_embeddings = embedding_model.encode(chunks, show_progress_bar=False)
    chunk_embeddings = np.array(chunk_embeddings, dtype="float32")

    # The index dimension is taken from the embedding width (columns).
    dimension = chunk_embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(chunk_embeddings)

    return index, chunk_embeddings
|
|
|
|
|
|
|
|
|
def retrieve_chunks(query, index, chunks, top_k=3):
    """
    Embed 'query' and return the chunks nearest to it in 'index'.

    Args:
        query: The question or search text to embed.
        index: FAISS index to search; position i in the index is assumed to
            correspond to chunks[i].
        chunks: The text chunks that were indexed.
        top_k: Maximum number of chunks to return.

    Returns:
        Up to 'top_k' chunks in the order returned by the index search. The
        list is shorter when fewer chunks exist, and empty when 'chunks' is
        empty or 'top_k' is not positive.
    """
    if not chunks or top_k <= 0:
        return []

    query_embedding = embedding_model.encode([query], show_progress_bar=False)
    query_embedding = np.array(query_embedding, dtype="float32")

    # Never request more neighbours than there are chunks: FAISS pads missing
    # results with the label -1, which would silently select chunks[-1].
    k = min(top_k, len(chunks))
    _, indices = index.search(query_embedding, k)

    # Guard against any remaining out-of-range label (e.g. an index holding
    # fewer vectors than 'chunks').
    return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
|
|
|
|
|
|
|
|
|
def gemini_generate(prompt, model="gemini-1.5-flash", timeout=60):
    """
    Send 'prompt' to Google's Gemini API and return the generated text.

    The API key is read from the GEMINI_API_KEY environment variable. Failures
    are reported through the returned string rather than raised.

    Args:
        prompt: The full prompt text to send.
        model: Gemini model name placed in the request URL.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The text of the first part of the first candidate in the response, or
        an error message string if the key is missing, the request fails, or
        the response does not have the expected shape.
    """
    gemini_api_key = os.environ.get("GEMINI_API_KEY", "")
    if not gemini_api_key:
        return "Error: No GEMINI_API_KEY found in environment variables."

    url = (
        "https://generativelanguage.googleapis.com/"
        f"v1beta/models/{model}:generateContent"
    )
    payload = {"contents": [{"parts": [{"text": prompt}]}]}
    # The key is sent as a header rather than in the query string so that it
    # does not appear in request-exception messages (which embed the URL) that
    # are returned to the caller below.
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": gemini_api_key,
    }

    try:
        # Without a timeout, requests can block indefinitely on a stalled
        # connection.
        response = requests.post(
            url, headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        return f"Error calling Gemini API: {e}"

    try:
        r_data = response.json()
        return r_data["candidates"][0]["content"]["parts"][0]["text"]
    except (KeyError, IndexError, TypeError, ValueError):
        # Covers a non-JSON body, a missing key, an empty candidates/parts
        # list, or a value of an unexpected type.
        return f"Parsing error or unexpected response format: {response.text}"
|
|
|
|
|
|
|
|
|
def answer_question_with_RAG(user_question, index, chunks, top_k=3):
    """
    Answer 'user_question' using retrieval-augmented generation.

    Retrieves the most relevant chunks with retrieve_chunks(), places them in
    a prompt as context together with the question, and passes that prompt to
    gemini_generate().

    Args:
        user_question: The question entered by the user.
        index: FAISS index to search.
        chunks: The text chunks that were indexed.
        top_k: Number of chunks to retrieve as context.

    Returns:
        Whatever gemini_generate() returns for the assembled prompt.
    """
    relevant_chunks = retrieve_chunks(user_question, index, chunks, top_k=top_k)
    context = "\n\n".join(relevant_chunks)

    # The prompt is assembled from explicit lines so that the text sent to the
    # model does not depend on how this source file is indented.
    prompt = (
        "\n"
        "You are an AI assistant that knows the details from the uploaded research paper.\n"
        "Answer the user's question accurately using the context below.\n"
        "If something is not in the context, say 'I don't know'.\n"
        "\n"
        "Context:\n"
        f"{context}\n"
        "\n"
        f"User's question: {user_question}\n"
        "\n"
        "Answer:\n"
    )
    return gemini_generate(prompt)
|
|
|
|
|
|
|
|
|
def _process_pdf(uploaded_pdf):
    """Extract, chunk and index 'uploaded_pdf', saving the result in session state."""
    if uploaded_pdf is None:
        st.warning("Please upload a PDF file first.")
        return

    raw_text = extract_pdf_text(uploaded_pdf)
    if not raw_text.strip():
        st.error("No text found in PDF.")
        return

    chunks = chunk_text(raw_text, chunk_size=300, overlap=50)
    if not chunks:
        st.error("No valid text to chunk.")
        return

    faiss_index, _ = build_faiss_index(chunks)
    st.session_state.faiss_index = faiss_index
    st.session_state.chunks = chunks
    st.success("PDF processed successfully!")


def main():
    """
    Render the Streamlit page: PDF upload and processing, then question answering.

    The FAISS index and the text chunks are kept in st.session_state so that
    they survive the script reruns triggered by each widget interaction.
    """
    st.set_page_config(
        page_title="AI-Powered Personal Research Assistant",
        layout="centered",
    )

    st.title("AI-Powered Personal Research Assistant")
    st.write("Welcome! How may I help you?")

    if "faiss_index" not in st.session_state:
        st.session_state.faiss_index = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = None

    uploaded_pdf = st.file_uploader("Upload your research paper (PDF)", type=["pdf"])
    if st.button("Process PDF"):
        # Failures return from the helper, not from main(), so the question
        # form below is still rendered after an error message.
        _process_pdf(uploaded_pdf)

    user_question = st.text_input("Ask a question about your research paper:")
    if st.button("Get Answer"):
        # Compare the index against None explicitly instead of relying on the
        # truthiness of a FAISS object.
        if st.session_state.faiss_index is None or not st.session_state.chunks:
            st.warning("Please upload and process a PDF first.")
        elif not user_question.strip():
            st.warning("Please enter a valid question.")
        else:
            answer = answer_question_with_RAG(
                user_question,
                st.session_state.faiss_index,
                st.session_state.chunks,
            )
            st.write("### Answer:")
            st.write(answer)
|
|
|
# Start the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|
|