aaporosh committed
Commit 7c6674a · verified · 1 parent: 574210c

Update app.py

Files changed (1): app.py (+275, -235)
app.py CHANGED
@@ -3,49 +3,137 @@ import logging
 import os
 from io import BytesIO
 import pdfplumber
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.text_splitter import CharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from sentence_transformers import SentenceTransformer
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments
+from datasets import load_dataset
 import re

-# Setup logging
+# Setup logging for Spaces
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)

-# ----------- Load Models -----------
-
+# Lazy load models
 @st.cache_resource(ttl=1800)
 def load_embeddings_model():
+    logger.info("Loading embeddings model")
     try:
         return SentenceTransformer("all-MiniLM-L12-v2")
     except Exception as e:
+        logger.error(f"Embeddings load error: {str(e)}")
         st.error(f"Embedding model error: {str(e)}")
         return None

 @st.cache_resource(ttl=1800)
 def load_qa_pipeline():
+    logger.info("Loading QA pipeline")
     try:
+        dataset = load_and_prepare_dataset()
+        if dataset:
+            fine_tuned_pipeline = fine_tune_qa_model(dataset)
+            if fine_tuned_pipeline:
+                return fine_tuned_pipeline
         return pipeline("text2text-generation", model="google/flan-t5-small", max_length=300)
     except Exception as e:
+        logger.error(f"QA model load error: {str(e)}")
         st.error(f"QA model error: {str(e)}")
         return None

 @st.cache_resource(ttl=1800)
 def load_summary_pipeline():
+    logger.info("Loading summary pipeline")
     try:
         return pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", max_length=150)
     except Exception as e:
+        logger.error(f"Summary model load error: {str(e)}")
         st.error(f"Summary model error: {str(e)}")
         return None

-# ----------- PDF Processing -----------
+# Load and prepare dataset (e.g., SQuAD)
+@st.cache_resource(ttl=3600)
+def load_and_prepare_dataset(dataset_name="squad", max_samples=1000):
+    logger.info(f"Loading dataset: {dataset_name}")
+    try:
+        dataset = load_dataset(dataset_name, split="train")
+        dataset = dataset.shuffle(seed=42).select(range(max_samples))
+
+        def preprocess(examples):
+            inputs = [f"question: {q} context: {c}" for q, c in zip(examples['question'], examples['context'])]
+            targets = examples['answers']['text']
+            return {'input_text': inputs, 'target_text': [t[0] if t else "" for t in targets]}
+
+        dataset = dataset.map(preprocess, batched=True)
+        return dataset
+    except Exception as e:
+        logger.error(f"Dataset load error: {str(e)}")
+        return None
+
+# Fine-tune QA model
+@st.cache_resource(ttl=3600)
+def fine_tune_qa_model(dataset):
+    logger.info("Starting fine-tuning")
+    try:
+        model_name = "google/flan-t5-small"
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+        def tokenize_function(examples):
+            model_inputs = tokenizer(examples['input_text'], max_length=512, truncation=True, padding="max_length")
+            labels = tokenizer(examples['target_text'], max_length=128, truncation=True, padding="max_length")
+            model_inputs["labels"] = labels["input_ids"]
+            return model_inputs
+
+        tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+        training_args = TrainingArguments(
+            output_dir="./fine_tuned_model",
+            num_train_epochs=1,
+            per_device_train_batch_size=4,
+            save_steps=500,
+            logging_steps=100,
+            evaluation_strategy="no",
+            learning_rate=5e-5,
+            fp16=False,
+        )
+
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=tokenized_dataset,
+        )
+        trainer.train()
+
+        model.save_pretrained("./fine_tuned_model")
+        tokenizer.save_pretrained("./fine_tuned_model")
+        logger.info("Fine-tuning complete")
+        return pipeline("text2text-generation", model="./fine_tuned_model", tokenizer="./fine_tuned_model", max_length=300)
+    except Exception as e:
+        logger.error(f"Fine-tuning error: {str(e)}")
+        return None
+
+# Augment vector store with dataset
+def augment_vector_store(vector_store, dataset_name="squad", max_samples=500):
+    logger.info(f"Augmenting vector store with dataset: {dataset_name}")
+    try:
+        dataset = load_dataset(dataset_name, split="train").select(range(max_samples))
+        chunks = [f"Context: {c}\nAnswer: {a['text'][0]}" for c, a in zip(dataset['context'], dataset['answers'])]
+        embeddings_model = load_embeddings_model()
+        if embeddings_model and vector_store:
+            embeddings = embeddings_model.encode(chunks)
+            vector_store.add_embeddings(zip(chunks, embeddings))
+        return vector_store
+    except Exception as e:
+        logger.error(f"Vector store augmentation error: {str(e)}")
+        return vector_store

+# Process PDF with enhanced extraction
 def process_pdf(uploaded_file):
-    text = ""
-    code_blocks = []
+    logger.info("Processing PDF with enhanced extraction")
     try:
-        with pdfplumber.open(BytesIO(uploaded_file.read())) as pdf:
+        text = ""
+        code_blocks = []
+        with pdfplumber.open(BytesIO(uploaded_file.getvalue())) as pdf:
             for page in pdf.pages[:20]:
                 extracted = page.extract_text(layout=False)
                 if extracted:
@@ -53,248 +141,200 @@ def process_pdf(uploaded_file):
                 for char in page.chars:
                     if 'fontname' in char and 'mono' in char['fontname'].lower():
                         code_blocks.append(char['text'])
-                code_text_page = page.extract_text() or ""
-                code_matches = re.finditer(r'(^\s{2,}.*?(?:\n\s{2,}.*?)*)', code_text_page, re.MULTILINE)
+                code_text = page.extract_text()
+                code_matches = re.finditer(r'(^\s{2,}.*?(?:\n\s{2,}.*?)*)', code_text, re.MULTILINE)
                 for match in code_matches:
                     code_blocks.append(match.group().strip())
                 tables = page.extract_tables()
                 if tables:
                     for table in tables:
                         text += "\n".join([" | ".join(map(str, row)) for row in table if row]) + "\n"
-        code_text = "\n".join(code_blocks).strip()
+                for obj in page.extract_words():
+                    if obj.get('size', 0) > 12:
+                        text += f"\n{obj['text']}\n"

-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=500, chunk_overlap=100, separators=["\n\n", "\n", ".", " "]
-        )
+        code_text = "\n".join(code_blocks).strip()
+        if not text:
+            raise ValueError("No text extracted from PDF")
+
+        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=500, chunk_overlap=100, keep_separator=True)
         text_chunks = text_splitter.split_text(text)[:50]
         code_chunks = text_splitter.split_text(code_text)[:25] if code_text else []
-
+
         embeddings_model = load_embeddings_model()
         if not embeddings_model:
             return None, None, text, code_text
-
-        text_vectors = [embeddings_model.encode(chunk) for chunk in text_chunks]
-        code_vectors = [embeddings_model.encode(chunk) for chunk in code_chunks]
-
-        text_vector_store = FAISS.from_embeddings(zip(text_chunks, text_vectors), embeddings_model.encode) if text_chunks else None
-        code_vector_store = FAISS.from_embeddings(zip(code_chunks, code_vectors), embeddings_model.encode) if code_chunks else None
-
+
+        text_vector_store = FAISS.from_embeddings(
+            zip(text_chunks, [embeddings_model.encode(chunk) for chunk in text_chunks]),
+            embeddings_model.encode
+        ) if text_chunks else None
+        code_vector_store = FAISS.from_embeddings(
+            zip(code_chunks, [embeddings_model.encode(chunk) for chunk in code_chunks]),
+            embeddings_model.encode
+        ) if code_chunks else None
+
+        # Augment text vector store with dataset
+        if text_vector_store:
+            text_vector_store = augment_vector_store(text_vector_store)
+
+        logger.info("PDF processed successfully with enhanced extraction")
         return text_vector_store, code_vector_store, text, code_text
-
     except Exception as e:
+        logger.error(f"PDF processing error: {str(e)}")
         st.error(f"PDF error: {str(e)}")
         return None, None, "", ""

-# ----------- Preload Dataset -----------
-
-def preload_dataset():
-    dataset_path = "data"
-    combined_text = ""
-    combined_code = ""
-    text_vector_store = None
-    code_vector_store = None
-
-    if not os.path.exists(dataset_path):
-        return text_vector_store, code_vector_store, combined_text, combined_code
-
-    embeddings_model = load_embeddings_model()
-    if not embeddings_model:
-        return text_vector_store, code_vector_store, combined_text, combined_code
-
-    all_text_chunks = []
-    all_text_vectors = []
-    all_code_chunks = []
-    all_code_vectors = []
-
-    for file_name in os.listdir(dataset_path):
-        file_path = os.path.join(dataset_path, file_name)
-        if file_name.lower().endswith(".pdf"):
-            with open(file_path, "rb") as f:
-                t_store, c_store, t_text, c_text = process_pdf(f)
-                combined_text += t_text + "\n"
-                combined_code += c_text + "\n"
-                if t_store:
-                    for chunk in t_store.index_to_docstore().values():
-                        all_text_chunks.append(chunk)
-                        all_text_vectors.append(embeddings_model.encode(chunk))
-                if c_store:
-                    for chunk in c_store.index_to_docstore().values():
-                        all_code_chunks.append(chunk)
-                        all_code_vectors.append(embeddings_model.encode(chunk))
-        elif file_name.lower().endswith(".txt"):
-            with open(file_path, "r", encoding="utf-8") as f:
-                text_content = f.read()
-                combined_text += text_content + "\n"
-                chunks = text_content.split("\n\n")
-                for chunk in chunks:
-                    all_text_chunks.append(chunk)
-                    all_text_vectors.append(embeddings_model.encode(chunk))
-
-    if all_text_chunks:
-        text_vector_store = FAISS.from_embeddings(zip(all_text_chunks, all_text_vectors), embeddings_model.encode)
-    if all_code_chunks:
-        code_vector_store = FAISS.from_embeddings(zip(all_code_chunks, all_code_vectors), embeddings_model.encode)
-
-    return text_vector_store, code_vector_store, combined_text, combined_code
-
-# ----------- Streamlit UI -----------
-
-st.set_page_config(page_title="Smart PDF Q&A", page_icon="📄", layout="wide")
-
-# Fixed CSS for chat colors
-st.markdown("""
-<style>
-/* Chat container */
-.chat-container {
-    border: 1px solid #ddd;
-    border-radius: 10px;
-    padding: 10px;
-    height: 60vh;
-    overflow-y: auto;
-    margin-top: 20px;
-}
-
-/* Chat bubbles */
-.stChatMessage {
-    border-radius: 15px;
-    padding: 10px;
-    margin: 5px;
-    max-width: 70%;
-    word-wrap: break-word;
-}
-
-/* User message */
-.user {
-    background-color: #e6f3ff !important;
-    color: #000 !important;
-    align-self: flex-end;
-    text-align: right;
-}
-
-/* Assistant message */
-.assistant {
-    background-color: #f0f0f0 !important;
-    color: #000 !important;
-    text-align: left;
-}
-
-/* Dark mode support */
-body[data-theme="dark"] .user {
-    background-color: #2a2a72 !important;
-    color: #fff !important;
-}
-body[data-theme="dark"] .assistant {
-    background-color: #2e2e2e !important;
-    color: #fff !important;
-}
-
-/* Buttons */
-.stButton>button {
-    background-color: #4CAF50;
-    color: white;
-    border: none;
-    padding: 8px 16px;
-    border-radius: 5px;
-}
-.stButton>button:hover {
-    background-color: #45a049;
-}
-
-/* Preformatted code */
-pre {
-    background-color: #f8f8f8;
-    padding: 10px;
-    border-radius: 5px;
-    overflow-x: auto;
-}
-
-/* Header */
-.header {
-    background: linear-gradient(90deg, #4CAF50, #81C784);
-    color: white;
-    padding: 10px;
-    border-radius: 5px;
-    text-align: center;
-}
-</style>
-""", unsafe_allow_html=True)
-
-st.markdown('<div class="header"><h1>Smart PDF Q&A</h1></div>', unsafe_allow_html=True)
-st.markdown("Upload a PDF to ask questions, summarize (~150 words), or extract code with 'give me code'.")
-
-# Session state
-if "messages" not in st.session_state:
-    st.session_state.messages = []
-if "text_vector_store" not in st.session_state:
-    st.session_state.text_vector_store = None
-if "code_vector_store" not in st.session_state:
-    st.session_state.code_vector_store = None
-if "pdf_text" not in st.session_state:
-    st.session_state.pdf_text = ""
-if "code_text" not in st.session_state:
-    st.session_state.code_text = ""
-
-# Preload dataset at start
-if st.session_state.text_vector_store is None and st.session_state.code_vector_store is None:
-    st.session_state.text_vector_store, st.session_state.code_vector_store, st.session_state.pdf_text, st.session_state.code_text = preload_dataset()
-    if st.session_state.text_vector_store or st.session_state.code_vector_store:
-        st.info("Preloaded sample dataset loaded for better QA and code retrieval.")
-
-# PDF upload & buttons
-uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"])
-col1, col2 = st.columns([1,1])
-with col1:
-    if st.button("Process PDF") and uploaded_file:
-        with st.spinner("Processing PDF..."):
-            st.session_state.text_vector_store, st.session_state.code_vector_store, st.session_state.pdf_text, st.session_state.code_text = process_pdf(uploaded_file)
-            if st.session_state.text_vector_store or st.session_state.code_vector_store:
-                st.success("PDF processed! Ask away or summarize.")
-                st.session_state.messages = []
-            else:
-                st.error("Failed to process PDF.")
-
-with col2:
-    if st.button("Summarize PDF") and st.session_state.pdf_text:
-        with st.spinner("Summarizing..."):
-            summary_pipeline = load_summary_pipeline()
-            text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50, separators=["\n\n", "\n", ".", " "])
-            chunks = text_splitter.split_text(st.session_state.pdf_text)[:2]
-            summaries = []
-            for chunk in chunks:
-                summary = summary_pipeline(chunk[:500], max_length=100, min_length=30, do_sample=False)[0]['summary_text']
-                summaries.append(summary.strip())
-            combined_summary = " ".join(summaries)
-            st.session_state.messages.append({"role":"assistant","content":combined_summary})
-            st.markdown(combined_summary)
+# Summarize PDF
+def summarize_pdf(text):
+    logger.info("Generating summary")
+    try:
+        summary_pipeline = load_summary_pipeline()
+        if not summary_pipeline:
+            return "Summary model unavailable."
+
+        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=500, chunk_overlap=50)
+        chunks = text_splitter.split_text(text)[:2]
+        summaries = []
+
+        for chunk in chunks:
+            summary = summary_pipeline(chunk[:500], max_length=100, min_length=30, do_sample=False)[0]['summary_text']
+            summaries.append(summary.strip())
+
+        combined_summary = " ".join(summaries)
+        if len(combined_summary.split()) > 150:
+            combined_summary = " ".join(combined_summary.split()[:150])
+        logger.info("Summary generated")
+        return f"Sure, here's a concise summary of the PDF:\n{combined_summary}"
+    except Exception as e:
+        logger.error(f"Summary error: {str(e)}")
+        return f"Oops, something went wrong summarizing: {str(e)}"

-# Chat interface
-st.markdown('<div class="chat-container">', unsafe_allow_html=True)
-prompt = st.chat_input("Ask a question (e.g., 'Give me code' or 'What’s the main idea?'):")
-if prompt:
-    st.session_state.messages.append({"role":"user","content":prompt})
-    with st.chat_message("user"):
-        st.markdown(f"<div class='user'>{prompt}</div>", unsafe_allow_html=True)
-    with st.chat_message("assistant"):
+# Answer question with improved response
+def answer_question(text_vector_store, code_vector_store, query):
+    logger.info(f"Processing query: {query}")
+    try:
+        if not text_vector_store and not code_vector_store:
+            return "Please upload a PDF first!"
+
         qa_pipeline = load_qa_pipeline()
-        is_code_query = any(k in prompt.lower() for k in ["code","script","function","programming","give me code","show code"])
-        if is_code_query and st.session_state.code_vector_store:
-            answer = f"Here's the code from the PDF:\n```python\n{st.session_state.code_text}\n```"
-        elif st.session_state.text_vector_store:
-            docs = st.session_state.text_vector_store.similarity_search(prompt, k=5)
-            context = "\n".join(doc.page_content for doc in docs)
-            answer = qa_pipeline(f"Context: {context}\nQuestion: {prompt}\nProvide a detailed answer.")[0]['generated_text']
-        else:
-            answer = "Please upload a PDF first!"
-        st.markdown(f"<div class='assistant'>{answer}</div>", unsafe_allow_html=True)
-    st.session_state.messages.append({"role":"assistant","content":answer})
-
-# Display chat history
-for msg in st.session_state.messages:
-    cls = "user" if msg["role"]=="user" else "assistant"
-    st.markdown(f"<div class='{cls}' style='margin:5px;padding:10px;border-radius:15px;'>{msg['content']}</div>", unsafe_allow_html=True)
-st.markdown('</div>', unsafe_allow_html=True)
-
-# Download chat
-if st.session_state.messages:
-    chat_text = "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in st.session_state.messages)
-    st.download_button("Download Chat History", chat_text, "chat_history.txt")
+        if not qa_pipeline:
+            return "Sorry, the QA model is unavailable right now."
+
+        is_code_query = any(keyword in query.lower() for keyword in ["code", "script", "function", "programming", "give me code", "show code"])
+        if is_code_query and code_vector_store:
+            return f"Here's the code from the PDF:\n```python\n{st.session_state.code_text}\n```"
+
+        vector_store = text_vector_store
+        if not vector_store:
+            return "No relevant content found for your query."
+
+        docs = vector_store.similarity_search(query, k=5)
+        context = "\n".join(doc.page_content for doc in docs)
+        prompt = f"Context: {context}\nQuestion: {query}\nProvide a detailed, accurate answer based on the context, prioritizing relevant information. Respond as a helpful assistant:"
+        response = qa_pipeline(prompt)[0]['generated_text']
+        logger.info("Answer generated")
+        return f"Got it! Here's a detailed answer:\n{response.strip()}"
+    except Exception as e:
+        logger.error(f"Query error: {str(e)}")
+        return f"Sorry, something went wrong: {str(e)}"
+
+# Streamlit UI
+try:
+    st.set_page_config(page_title="Smart PDF Q&A", page_icon="📄", layout="wide")
+    st.markdown("""
+    <style>
+    .main { max-width: 900px; margin: 0 auto; padding: 20px; }
+    .sidebar { background-color: #f8f9fa; padding: 10px; border-radius: 5px; }
+    .chat-container { border: 1px solid #ddd; border-radius: 10px; padding: 10px; height: 60vh; overflow-y: auto; margin-top: 20px; }
+    .stChatMessage { border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; }
+    .user { background-color: #e6f3ff; align-self: flex-end; }
+    .assistant { background-color: #f0f0f0; }
+    .dark .user { background-color: #2a2a72; color: #fff; }
+    .dark .assistant { background-color: #2e2e2e; color: #fff; }
+    .stButton>button { background-color: #4CAF50; color: white; border: none; padding: 8px 16px; border-radius: 5px; }
+    .stButton>button:hover { background-color: #45a049; }
+    pre { background-color: #f8f8f8; padding: 10px; border-radius: 5px; overflow-x: auto; }
+    .header { background: linear-gradient(90deg, #4CAF50, #81C784); color: white; padding: 10px; border-radius: 5px; text-align: center; }
+    </style>
+    """, unsafe_allow_html=True)
+
+    st.markdown('<div class="header"><h1>Smart PDF Q&A</h1></div>', unsafe_allow_html=True)
+    st.markdown("Upload a PDF to ask questions, summarize (~150 words), or extract code with 'give me code'. Fast and friendly responses!")
+
+    # Initialize session state
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+    if "text_vector_store" not in st.session_state:
+        st.session_state.text_vector_store = None
+    if "code_vector_store" not in st.session_state:
+        st.session_state.code_vector_store = None
+    if "pdf_text" not in st.session_state:
+        st.session_state.pdf_text = ""
+    if "code_text" not in st.session_state:
+        st.session_state.code_text = ""
+
+    # Sidebar with toggle and dataset options
+    with st.sidebar:
+        st.markdown('<div class="sidebar">', unsafe_allow_html=True)
+        theme = st.radio("Theme", ["Light", "Dark"], index=0)
+        dataset_name = st.selectbox("Select Dataset for Fine-Tuning", ["squad", "cnn_dailymail", "bigcode/the-stack"], index=0)
+        if st.button("Fine-Tune Model"):
+            with st.spinner("Fine-tuning model..."):
+                dataset = load_and_prepare_dataset(dataset_name=dataset_name)
+                if dataset:
+                    fine_tuned_pipeline = fine_tune_qa_model(dataset)
+                    if fine_tuned_pipeline:
+                        st.success("Model fine-tuned successfully!")
+                    else:
+                        st.error("Fine-tuning failed.")
+        st.markdown('</div>', unsafe_allow_html=True)
+
+    # PDF upload and processing
+    uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"])
+    col1, col2 = st.columns([1, 1])
+    with col1:
+        if st.button("Process PDF"):
+            with st.spinner("Processing PDF..."):
+                st.session_state.text_vector_store, st.session_state.code_vector_store, st.session_state.pdf_text, st.session_state.code_text = process_pdf(uploaded_file)
+                if st.session_state.text_vector_store or st.session_state.code_vector_store:
+                    st.success("PDF processed! Ask away or summarize.")
+                    st.session_state.messages = []
+                else:
+                    st.error("Failed to process PDF.")
+    with col2:
+        if st.button("Summarize PDF") and st.session_state.pdf_text:
+            with st.spinner("Summarizing..."):
+                summary = summarize_pdf(st.session_state.pdf_text)
+                st.session_state.messages.append({"role": "assistant", "content": summary})
+                st.markdown(summary, unsafe_allow_html=True)
+
+    # Chat interface
+    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
+    if st.session_state.text_vector_store or st.session_state.code_vector_store:
+        prompt = st.chat_input("Ask a question (e.g., 'Give me code' or 'What’s the main idea?'):")
+        if prompt:
+            st.session_state.messages.append({"role": "user", "content": prompt})
+            with st.chat_message("user"):
+                st.markdown(prompt)
+            with st.chat_message("assistant"):
+                with st.spinner('<div class="spinner">⏳</div>'):
+                    answer = answer_question(st.session_state.text_vector_store, st.session_state.code_vector_store, prompt)
+                    st.markdown(answer, unsafe_allow_html=True)
+            st.session_state.messages.append({"role": "assistant", "content": answer})
+
+    # Display chat history
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"], unsafe_allow_html=True)
+
+    st.markdown('</div>', unsafe_allow_html=True)
+
+    # Download chat history
+    if st.session_state.messages:
+        chat_text = "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in st.session_state.messages)
+        st.download_button("Download Chat History", chat_text, "chat_history.txt")
+
+except Exception as e:
+    logger.error(f"App initialization failed: {str(e)}")
+    st.error(f"App failed to start: {str(e)}. Check Spaces logs or contact support.")