fhmsf committed
Commit f47a6a2 · verified · 1 Parent(s): b0ac1ef

Create app.py

Files changed (1)
app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
import os
import faiss
import numpy as np
import requests
import streamlit as st

from pypdf import PdfReader
from sentence_transformers import SentenceTransformer

###############################################################################
# 1. PDF Parsing and Chunking
###############################################################################
def extract_pdf_text(pdf_file) -> str:
    """
    Read and extract text from each page of an uploaded PDF file.
    """
    reader = PdfReader(pdf_file)
    all_text = []
    for page in reader.pages:
        text = page.extract_text() or ""
        all_text.append(text.strip())
    return "\n".join(all_text)

def chunk_text(text, chunk_size=300, overlap=50):
    """
    Split text into overlapping chunks of roughly 'chunk_size' words each.
    'overlap' is how many words from the end of the previous chunk are
    repeated at the start of the next one, so meaning is not cut off at
    chunk boundaries.
    """
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = start + chunk_size
        chunk = words[start:end]
        chunks.append(" ".join(chunk))
        start += (chunk_size - overlap)
    return chunks
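
# Example with the defaults: chunk_size=300 and overlap=50 advance 'start' by
# 250 words per iteration, so a (hypothetical) 600-word document yields chunks
# covering words 0-299, 250-549, and 500-599.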

###############################################################################
# 2. Embedding Model
###############################################################################
# Cache the model so Streamlit does not reload it on every script rerun.
@st.cache_resource
def load_embedding_model():
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

embedding_model = load_embedding_model()
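# Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings; the index
# dimension below is derived from the encoder output, so swapping in another
# sentence-transformers model requires no other changes.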

###############################################################################
# 3. Build FAISS Index
###############################################################################
def build_faiss_index(chunks):
    """
    Create a FAISS index from the embedded chunks.
    Returns (index, chunk_embeddings).
    """
    chunk_embeddings = embedding_model.encode(chunks, show_progress_bar=False)
    chunk_embeddings = np.array(chunk_embeddings, dtype='float32')

    dimension = chunk_embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(chunk_embeddings)

    return index, chunk_embeddings
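# IndexFlatL2 does exact (brute-force) L2 search, which is more than fast
# enough for the few hundred chunks a single paper produces; approximate
# indexes only pay off at much larger scale.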

###############################################################################
# 4. Retrieval Function
###############################################################################
def retrieve_chunks(query, index, chunks, top_k=3):
    """
    Embed 'query' and retrieve the top_k most relevant chunks from 'index'.
    """
    query_embedding = embedding_model.encode([query], show_progress_bar=False)
    query_embedding = np.array(query_embedding, dtype='float32')

    # Never ask for more neighbors than there are chunks; FAISS pads missing
    # results with index -1, which is skipped below.
    distances, indices = index.search(query_embedding, min(top_k, len(chunks)))
    return [chunks[i] for i in indices[0] if i != -1]

###############################################################################
# 5. Gemini LLM Integration
###############################################################################
def gemini_generate(prompt):
    """
    Call Google's Gemini API, authenticating with the GEMINI_API_KEY
    environment variable.
    """
    gemini_api_key = os.environ.get("GEMINI_API_KEY", "")
    if not gemini_api_key:
        return "Error: No GEMINI_API_KEY found in environment variables."

    url = (
        "https://generativelanguage.googleapis.com/"
        "v1beta/models/gemini-1.5-flash:generateContent"
        f"?key={gemini_api_key}"
    )
    payload = {
        "contents": [
            {
                "parts": [
                    {"text": prompt}
                ]
            }
        ]
    }
    headers = {"Content-Type": "application/json"}

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        r_data = response.json()
        # Extract the generated text from the 'candidates' structure:
        return r_data["candidates"][0]["content"]["parts"][0]["text"]
    except requests.exceptions.RequestException as e:
        return f"Error calling Gemini API: {e}"
    except KeyError:
        return f"Parsing error or unexpected response format: {response.text}"
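# For reference (the exact shape may vary by API version), a successful
# generateContent response looks roughly like:
# {"candidates": [{"content": {"parts": [{"text": "..."}], "role": "model"}}]}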

###############################################################################
# 6. RAG QA Function
###############################################################################
def answer_question_with_RAG(user_question, index, chunks):
    """
    Retrieve relevant chunks, build an augmented prompt, and call
    gemini_generate().
    """
    relevant_chunks = retrieve_chunks(user_question, index, chunks, top_k=3)
    context = "\n\n".join(relevant_chunks)

    prompt = f"""
You are an AI assistant that knows the details from the uploaded research paper.
Answer the user's question accurately using the context below.
If something is not in the context, say 'I don't know'.

Context:
{context}

User's question: {user_question}

Answer:
"""
    return gemini_generate(prompt)

###############################################################################
# Streamlit Application
###############################################################################
def main():
    # Basic page config (optional):
    st.set_page_config(
        page_title="AI-Powered Personal Research Assistant",
        layout="centered"
    )

    # Title and subheader
    st.title("AI-Powered Personal Research Assistant")
    st.write("Welcome! How may I help you?")

    # Store the FAISS index + chunks in session_state to persist across reruns
    if "faiss_index" not in st.session_state:
        st.session_state.faiss_index = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = None

    # Step 1: Upload and process the PDF
    uploaded_pdf = st.file_uploader("Upload your research paper (PDF)", type=["pdf"])
    if st.button("Process PDF"):
        if uploaded_pdf is None:
            st.warning("Please upload a PDF file first.")
        else:
            # Read and chunk
            raw_text = extract_pdf_text(uploaded_pdf)
            if not raw_text.strip():
                st.error("No text found in PDF.")
                return
            chunks = chunk_text(raw_text, chunk_size=300, overlap=50)
            if not chunks:
                st.error("No valid text to chunk.")
                return
            # Build index
            faiss_index, _ = build_faiss_index(chunks)
            st.session_state.faiss_index = faiss_index
            st.session_state.chunks = chunks
            st.success("PDF processed successfully!")

    # Step 2: Ask a question
    user_question = st.text_input("Ask a question about your research paper:")
    if st.button("Get Answer"):
        # Explicit 'is None' checks: object truthiness is not a reliable
        # "has been built" signal for a FAISS index.
        if st.session_state.faiss_index is None or st.session_state.chunks is None:
            st.warning("Please upload and process a PDF first.")
        elif not user_question.strip():
            st.warning("Please enter a valid question.")
        else:
            answer = answer_question_with_RAG(
                user_question,
                st.session_state.faiss_index,
                st.session_state.chunks
            )
            st.write("### Answer:")
            st.write(answer)

if __name__ == "__main__":
    main()
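
# To run locally (package names assumed from the imports above):
#   pip install streamlit pypdf sentence-transformers faiss-cpu numpy requests
#   export GEMINI_API_KEY=...   # key from Google AI Studio
#   streamlit run app.py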