File size: 7,004 Bytes
f615b93 921780e f615b93 921780e 2f96c18 921780e f615b93 921780e f615b93 3d0f58b 921780e f615b93 921780e f615b93 921780e f615b93 921780e 2f96c18 921780e f615b93 921780e f615b93 921780e f615b93 2f96c18 f615b93 921780e f615b93 921780e f615b93 921780e f615b93 921780e f615b93 921780e f615b93 921780e f615b93 921780e 2f96c18 921780e 2f96c18 f615b93 921780e 2f96c18 f615b93 2f96c18 f615b93 921780e f615b93 921780e f615b93 921780e f615b93 921780e f615b93 2f96c18 921780e f615b93 921780e f615b93 921780e f615b93 2f96c18 f615b93 921780e f615b93 921780e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import os
import faiss
import numpy as np
import requests
import streamlit as st
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
###############################################################################
# 1. PDF Parsing and Chunking
###############################################################################
def extract_pdf_text(pdf_file) -> str:
    """
    Return the text of every page in a PDF, one page per line group.

    Args:
        pdf_file: Anything ``PdfReader`` accepts (here, the uploaded file
            object handed over by the Streamlit uploader).

    Returns:
        The stripped text of each page joined with newlines. A page for
        which ``extract_text()`` yields nothing contributes an empty string.
    """
    pdf = PdfReader(pdf_file)
    # `or ""` guards against extract_text() returning a falsy value.
    page_texts = ((page.extract_text() or "").strip() for page in pdf.pages)
    return "\n".join(page_texts)
def chunk_text(text, chunk_size=300, overlap=50):
    """
    Split text into overlapping chunks of whitespace-separated words.

    Args:
        text: The text to split; words are obtained with ``str.split()``.
        chunk_size: Maximum number of words per chunk. Must be positive.
        overlap: Number of words from the end of one chunk that are repeated
            at the start of the next. Must satisfy ``0 <= overlap < chunk_size``.

    Returns:
        A list of chunks, each the chunk's words joined by single spaces.
        Empty when ``text`` contains no words.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or ``overlap`` is
            negative or not smaller than ``chunk_size``.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    # The window advances by (chunk_size - overlap) words; a step of zero or
    # less would never reach the end of the text, so reject it up front.
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    words = text.split()
    step = chunk_size - overlap
    return [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), step)
    ]
###############################################################################
# 2. Embedding Model
###############################################################################
# Shared sentence-embedding model, instantiated once when this module is
# loaded; build_faiss_index() and retrieve_chunks() both call its encode().
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
###############################################################################
# 3. Build FAISS Index
###############################################################################
def build_faiss_index(chunks):
    """
    Embed the given chunks and store their vectors in a FAISS index.

    Args:
        chunks: Sequence of text chunks to embed with ``embedding_model``.

    Returns:
        A ``(index, vectors)`` tuple: the populated ``faiss.IndexFlatL2`` and
        the float32 embedding array that was added to it.
    """
    vectors = np.array(
        embedding_model.encode(chunks, show_progress_bar=False),
        dtype=np.float32,
    )
    # The index dimensionality is taken from the embedding width.
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    return flat_index, vectors
###############################################################################
# 4. Retrieval Function
###############################################################################
def retrieve_chunks(query, index, chunks, top_k=3):
    """
    Embed 'query' and return the most relevant chunks found in 'index'.

    Args:
        query: The question text to embed with ``embedding_model``.
        index: FAISS index whose vector positions correspond to ``chunks``.
        chunks: The chunk strings, in the order they were added to ``index``.
        top_k: Maximum number of chunks to return.

    Returns:
        Up to ``top_k`` chunk strings in the order the index returned them
        (never more than ``len(chunks)``). Empty when there is nothing to
        search or ``top_k`` is not positive.
    """
    # Never ask for more neighbours than there are chunks: FAISS pads missing
    # results with -1, and chunks[-1] would silently return the last chunk.
    k = min(top_k, len(chunks))
    if k <= 0:
        return []

    query_embedding = embedding_model.encode([query], show_progress_bar=False)
    query_embedding = np.array(query_embedding, dtype='float32')
    _, indices = index.search(query_embedding, k)
    # Keep only positions that actually exist in 'chunks' (drops -1 padding).
    return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
###############################################################################
# 5. Gemini LLM Integration
###############################################################################
def gemini_generate(prompt, timeout=60):
    """
    Send 'prompt' to Google's Gemini API and return the generated text.

    The API key is read from the GEMINI_API_KEY environment variable and is
    sent in the ``x-goog-api-key`` header rather than the URL, so it cannot
    leak through error messages that echo the request URL.

    Args:
        prompt: The full prompt text to send as a single content part.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The text of the first candidate on success. On any failure a
        human-readable error string is returned instead of raising:
        missing API key, network/HTTP error, or an unexpected response shape.
    """
    gemini_api_key = os.environ.get("GEMINI_API_KEY", "")
    if not gemini_api_key:
        return "Error: No GEMINI_API_KEY found in environment variables."

    url = (
        "https://generativelanguage.googleapis.com/"
        "v1beta/models/gemini-1.5-flash:generateContent"
    )
    payload = {
        "contents": [
            {
                "parts": [
                    {"text": prompt}
                ]
            }
        ]
    }
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": gemini_api_key,
    }

    try:
        # A timeout keeps a stalled connection from hanging the app forever.
        response = requests.post(
            url, headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()
        r_data = response.json()
    except requests.exceptions.RequestException as e:
        return f"Error calling Gemini API: {e}"
    except ValueError:
        # Body was not valid JSON (older requests versions raise ValueError).
        return f"Parsing error or unexpected response format: {response.text}"

    try:
        # Extract the text from the 'candidates' structure:
        return r_data["candidates"][0]["content"]["parts"][0]["text"]
    except (KeyError, IndexError, TypeError):
        # Missing keys, empty candidate/part lists, or a non-dict payload.
        return f"Parsing error or unexpected response format: {response.text}"
###############################################################################
# 6. RAG QA Function
###############################################################################
def answer_question_with_RAG(user_question, index, chunks, top_k=3):
    """
    Answer a question by retrieving context chunks and prompting Gemini.

    Args:
        user_question: The question typed by the user.
        index: FAISS index passed through to ``retrieve_chunks``.
        chunks: Chunk strings passed through to ``retrieve_chunks``.
        top_k: How many chunks to retrieve as context (default 3).

    Returns:
        Whatever ``gemini_generate`` returns for the augmented prompt: the
        model's answer, or an error string.
    """
    relevant_chunks = retrieve_chunks(user_question, index, chunks, top_k=top_k)
    # Blank line between chunks so they stay visually separate in the prompt.
    context = "\n\n".join(relevant_chunks)
    prompt = f"""
You are an AI assistant that knows the details from the uploaded research paper.
Answer the user's question accurately using the context below.
If something is not in the context, say 'I don't know'.
Context:
{context}
User's question: {user_question}
Answer:
"""
    return gemini_generate(prompt)
###############################################################################
# Streamlit Application
###############################################################################
def main():
    """
    Render the Streamlit page: upload and index a PDF, then answer questions.

    The FAISS index and the chunk list are kept in ``st.session_state`` so
    they persist across reruns. If the uploaded PDF yields no text or no
    chunks, an error is shown and the function returns immediately.
    """
    # Page setup and heading.
    st.set_page_config(
        page_title="AI-Powered Personal Research Assistant",
        layout="centered"
    )
    st.title("AI-Powered Personal Research Assistant")
    st.write("Welcome! How may I help you?")

    # Make sure both session slots exist before anything reads them.
    for state_key in ("faiss_index", "chunks"):
        if state_key not in st.session_state:
            st.session_state[state_key] = None

    # Step 1: upload a PDF and build the index on demand.
    uploaded_pdf = st.file_uploader("Upload your research paper (PDF)", type=["pdf"])
    if st.button("Process PDF"):
        if uploaded_pdf is not None:
            pdf_text = extract_pdf_text(uploaded_pdf)
            if not pdf_text.strip():
                st.error("No text found in PDF.")
                return

            pieces = chunk_text(pdf_text, chunk_size=300, overlap=50)
            if not pieces:
                st.error("No valid text to chunk.")
                return

            # Only the index is kept; the embeddings array is not needed here.
            built_index, _ = build_faiss_index(pieces)
            st.session_state.faiss_index = built_index
            st.session_state.chunks = pieces
            st.success("PDF processed successfully!")
        else:
            st.warning("Please upload a PDF file first.")

    # Step 2: take a question and answer it from the indexed document.
    user_question = st.text_input("Ask a question about your research paper:")
    if st.button("Get Answer"):
        document_ready = st.session_state.faiss_index and st.session_state.chunks
        if not document_ready:
            st.warning("Please upload and process a PDF first.")
        elif not user_question.strip():
            st.warning("Please enter a valid question.")
        else:
            reply = answer_question_with_RAG(
                user_question,
                st.session_state.faiss_index,
                st.session_state.chunks
            )
            st.write("### Answer:")
            st.write(reply)
# Start the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()
|