Spaces:
Build error
Build error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import faiss
|
3 |
+
import numpy as np
|
4 |
+
import requests
|
5 |
+
import streamlit as st
|
6 |
+
|
7 |
+
from pypdf import PdfReader
|
8 |
+
from sentence_transformers import SentenceTransformer
|
9 |
+
|
10 |
+
###############################################################################
|
11 |
+
# 1. PDF Parsing and Chunking
|
12 |
+
###############################################################################
|
13 |
+
def extract_pdf_text(pdf_file) -> str:
    """Return the concatenated text of every page of an uploaded PDF.

    Each page's extracted text is stripped of surrounding whitespace;
    pages with no extractable text contribute an empty line. Page texts
    are joined with newlines.
    """
    reader = PdfReader(pdf_file)
    page_texts = [(page.extract_text() or "").strip() for page in reader.pages]
    return "\n".join(page_texts)
|
23 |
+
|
24 |
+
def chunk_text(text, chunk_size=300, overlap=50):
    """Split *text* into overlapping chunks of roughly ``chunk_size`` words.

    The text is whitespace-split into words; consecutive chunks share the
    last ``overlap`` words so context isn't lost at chunk boundaries.

    Parameters
    ----------
    text : str
        Source text to split.
    chunk_size : int
        Maximum number of words per chunk. Must be positive.
    overlap : int
        Number of words repeated from the previous chunk. Must be
        strictly smaller than ``chunk_size``.

    Returns
    -------
    list[str]
        The overlapping chunks; an empty list for empty or
        whitespace-only input.

    Raises
    ------
    ValueError
        If ``chunk_size <= 0`` or ``overlap >= chunk_size`` — with those
        arguments the original ``while`` loop's step was zero or negative,
        so it never advanced and spun forever.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    words = text.split()
    step = chunk_size - overlap  # guaranteed > 0 by the checks above
    return [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), step)
    ]
|
38 |
+
|
39 |
+
###############################################################################
|
40 |
+
# 2. Embedding Model
|
41 |
+
###############################################################################
|
42 |
+
# Sentence-embedding model shared by indexing (build_faiss_index) and query
# encoding (retrieve_chunks). Loaded eagerly at import time so the first
# request doesn't pay the initialisation cost on every Streamlit rerun.
# NOTE(review): this downloads the model from the Hugging Face hub on first
# run — confirm the deployment environment has network access and disk cache.
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
43 |
+
|
44 |
+
###############################################################################
|
45 |
+
# 3. Build FAISS Index
|
46 |
+
###############################################################################
|
47 |
+
def build_faiss_index(chunks):
    """Embed *chunks* and store them in a flat (exact) L2 FAISS index.

    Returns
    -------
    tuple
        ``(index, embeddings)`` where ``embeddings`` is the float32
        matrix that was added to the index.
    """
    vectors = np.array(
        embedding_model.encode(chunks, show_progress_bar=False),
        dtype="float32",
    )
    # IndexFlatL2 performs brute-force L2 search — exact, no training needed.
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index, vectors
|
60 |
+
|
61 |
+
###############################################################################
|
62 |
+
# 4. Retrieval Function
|
63 |
+
###############################################################################
|
64 |
+
def retrieve_chunks(query, index, chunks, top_k=3):
    """Return the ``top_k`` chunks nearest to *query* in embedding space.

    The query is encoded with the same sentence-transformer used to
    build the index, then searched by L2 distance.
    """
    query_vec = np.array(
        embedding_model.encode([query], show_progress_bar=False),
        dtype="float32",
    )
    _, neighbor_ids = index.search(query_vec, top_k)
    # index.search returns a (1, top_k) id matrix for our single query row.
    return [chunks[i] for i in neighbor_ids[0]]
|
73 |
+
|
74 |
+
###############################################################################
|
75 |
+
# 5. Gemini LLM Integration
|
76 |
+
###############################################################################
|
77 |
+
def gemini_generate(prompt):
    """Send *prompt* to Google's Gemini ``generateContent`` REST endpoint.

    The API key is read from the ``GEMINI_API_KEY`` environment variable.

    Returns
    -------
    str
        The generated text on success, or a human-readable error string
        on failure (missing key, network/HTTP error, or an unexpected
        response shape). This function never raises.
    """
    gemini_api_key = os.environ.get("GEMINI_API_KEY", "")
    if not gemini_api_key:
        return "Error: No GEMINI_API_KEY found in environment variables."

    url = (
        "https://generativelanguage.googleapis.com/"
        "v1beta/models/gemini-1.5-flash:generateContent"
        f"?key={gemini_api_key}"
    )
    payload = {
        "contents": [
            {
                "parts": [
                    {"text": prompt}
                ]
            }
        ]
    }
    headers = {"Content-Type": "application/json"}

    try:
        # BUG FIX: the original call had no timeout, so a stalled connection
        # would hang the Streamlit worker forever. A timeout raises a
        # RequestException, which is handled below.
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        r_data = response.json()
        # Expected shape: candidates[0].content.parts[0].text
        return r_data["candidates"][0]["content"]["parts"][0]["text"]
    except requests.exceptions.RequestException as e:
        return f"Error calling Gemini API: {e}"
    except (KeyError, IndexError, TypeError):
        # BUG FIX: the original caught only KeyError, missing an empty
        # 'candidates' list (IndexError) and null fields (TypeError) in
        # otherwise-valid JSON responses.
        return f"Parsing error or unexpected response format: {response.text}"
|
111 |
+
|
112 |
+
###############################################################################
|
113 |
+
# 6. RAG QA Function
|
114 |
+
###############################################################################
|
115 |
+
def answer_question_with_RAG(user_question, index, chunks):
    """Answer *user_question* with retrieval-augmented generation.

    The three most relevant chunks are retrieved from *index*, stitched
    into a context block, and sent to Gemini along with the question.
    Returns whatever ``gemini_generate`` returns (answer text or an
    error string).
    """
    context = "\n\n".join(
        retrieve_chunks(user_question, index, chunks, top_k=3)
    )

    prompt = f"""
You are an AI assistant that knows the details from the uploaded research paper.
Answer the user's question accurately using the context below.
If something is not in the context, say 'I don't know'.

Context:
{context}

User's question: {user_question}

Answer:
"""
    return gemini_generate(prompt)
|
135 |
+
|
136 |
+
###############################################################################
|
137 |
+
# Streamlit Application
|
138 |
+
###############################################################################
|
139 |
+
def main():
    """Streamlit entry point: upload a PDF, index it, then answer questions.

    Flow: the user uploads a PDF and clicks "Process PDF" (text is
    extracted, chunked, and FAISS-indexed into ``st.session_state``);
    afterwards "Get Answer" runs retrieval-augmented QA over the index.
    """
    # Basic page config (optional):
    st.set_page_config(
        page_title="AI-Powered Personal Research Assistant",
        layout="centered"
    )

    # Title and Subheader
    st.title("AI-Powered Personal Research Assistant")
    st.write("Welcome! How may I help you?")

    # Store the FAISS index + chunks in session_state to persist across reruns
    if "faiss_index" not in st.session_state:
        st.session_state.faiss_index = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = None

    # Step 1: Upload and Process PDF
    uploaded_pdf = st.file_uploader("Upload your research paper (PDF)", type=["pdf"])
    if st.button("Process PDF"):
        if uploaded_pdf is None:
            st.warning("Please upload a PDF file first.")
        else:
            # Read and chunk
            raw_text = extract_pdf_text(uploaded_pdf)
            if not raw_text.strip():
                st.error("No text found in PDF.")
                return
            chunks = chunk_text(raw_text, chunk_size=300, overlap=50)
            if not chunks:
                st.error("No valid text to chunk.")
                return
            # Build index
            faiss_index, _ = build_faiss_index(chunks)
            st.session_state.faiss_index = faiss_index
            st.session_state.chunks = chunks
            st.success("PDF processed successfully!")

    # Step 2: Ask a Question
    user_question = st.text_input("Ask a question about your research paper:")
    if st.button("Get Answer"):
        # BUG FIX: the original tested `not st.session_state.faiss_index`,
        # which depends on the truthiness of a FAISS (SWIG) index object.
        # Since both slots are initialised to None above, an explicit
        # `is None` check is the correct "not processed yet" guard.
        if st.session_state.faiss_index is None or st.session_state.chunks is None:
            st.warning("Please upload and process a PDF first.")
        elif not user_question.strip():
            st.warning("Please enter a valid question.")
        else:
            answer = answer_question_with_RAG(
                user_question,
                st.session_state.faiss_index,
                st.session_state.chunks
            )
            st.write("### Answer:")
            st.write(answer)
|
192 |
+
|
193 |
+
# Run the app when executed directly (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()
|