aaporosh committed on
Commit afc3005 · verified · 1 Parent(s): 7c6674a

Update app.py

Files changed (1)
  1. app.py +96 -53
app.py CHANGED
@@ -3,12 +3,17 @@ import logging
 import os
 from io import BytesIO
 import pdfplumber
+from PIL import Image
+import pytesseract
 from langchain.text_splitter import CharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from sentence_transformers import SentenceTransformer
 from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments
 from datasets import load_dataset
+from rank_bm25 import BM25Okapi
+from rouge_score import rouge_scorer
 import re
+import time
 
 # Setup logging for Spaces
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -34,7 +39,7 @@ def load_qa_pipeline():
         fine_tuned_pipeline = fine_tune_qa_model(dataset)
         if fine_tuned_pipeline:
             return fine_tuned_pipeline
-        return pipeline("text2text-generation", model="google/flan-t5-small", max_length=300)
+        return pipeline("text2text-generation", model="google/flan-t5-base", max_length=300)
     except Exception as e:
         logger.error(f"QA model load error: {str(e)}")
         st.error(f"QA model error: {str(e)}")
@@ -51,19 +56,19 @@ def load_summary_pipeline():
         return None
 
 # Load and prepare dataset (e.g., SQuAD)
-@st.cache_resource(ttl=3600)
+@st.cache_data(ttl=3600)
 def load_and_prepare_dataset(dataset_name="squad", max_samples=1000):
     logger.info(f"Loading dataset: {dataset_name}")
     try:
-        dataset = load_dataset(dataset_name, split="train")
-        dataset = dataset.shuffle(seed=42).select(range(max_samples))
+        dataset = load_dataset(dataset_name, split="train[:80%]")
+        dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))
 
         def preprocess(examples):
             inputs = [f"question: {q} context: {c}" for q, c in zip(examples['question'], examples['context'])]
             targets = examples['answers']['text']
             return {'input_text': inputs, 'target_text': [t[0] if t else "" for t in targets]}
 
-        dataset = dataset.map(preprocess, batched=True)
+        dataset = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)
         return dataset
     except Exception as e:
         logger.error(f"Dataset load error: {str(e)}")
@@ -74,7 +79,7 @@ def load_and_prepare_dataset(dataset_name="squad", max_samples=1000):
 def fine_tune_qa_model(dataset):
     logger.info("Starting fine-tuning")
     try:
-        model_name = "google/flan-t5-small"
+        model_name = "google/flan-t5-base"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
@@ -84,17 +89,17 @@ def fine_tune_qa_model(dataset):
             model_inputs["labels"] = labels["input_ids"]
             return model_inputs
 
-        tokenized_dataset = dataset.map(tokenize_function, batched=True)
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['input_text', 'target_text'])
 
         training_args = TrainingArguments(
             output_dir="./fine_tuned_model",
-            num_train_epochs=1,
+            num_train_epochs=2,
             per_device_train_batch_size=4,
             save_steps=500,
             logging_steps=100,
             evaluation_strategy="no",
-            learning_rate=5e-5,
-            fp16=False,
+            learning_rate=3e-5,
+            fp16=False,  # Set True if GPU available
         )
 
         trainer = Trainer(
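A compact sketch of the tokenize-and-train flow, filling in the tokenizer calls that sit outside this hunk. The `max_length` values here are assumptions, not taken from the commit, and `dataset` is the prepared dataset from the sketch above:

```python
# Assumed shape of the fine-tuning step; dataset has input_text/target_text.
from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM,
                          Trainer, TrainingArguments)

model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

def tokenize_function(examples):
    model_inputs = tokenizer(examples["input_text"], max_length=512,
                             truncation=True, padding="max_length")
    labels = tokenizer(examples["target_text"], max_length=128,
                       truncation=True, padding="max_length")
    model_inputs["labels"] = labels["input_ids"]  # decoder targets
    return model_inputs

tokenized = dataset.map(tokenize_function, batched=True,
                        remove_columns=["input_text", "target_text"])
args = TrainingArguments(output_dir="./fine_tuned_model", num_train_epochs=2,
                         per_device_train_batch_size=4, learning_rate=3e-5)
Trainer(model=model, args=args, train_dataset=tokenized).train()
```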
@@ -116,18 +121,19 @@ def fine_tune_qa_model(dataset):
 def augment_vector_store(vector_store, dataset_name="squad", max_samples=500):
     logger.info(f"Augmenting vector store with dataset: {dataset_name}")
     try:
-        dataset = load_dataset(dataset_name, split="train").select(range(max_samples))
+        dataset = load_dataset(dataset_name, split="train")
+        dataset = dataset.select(range(min(max_samples, len(dataset))))
         chunks = [f"Context: {c}\nAnswer: {a['text'][0]}" for c, a in zip(dataset['context'], dataset['answers'])]
         embeddings_model = load_embeddings_model()
         if embeddings_model and vector_store:
-            embeddings = embeddings_model.encode(chunks)
+            embeddings = embeddings_model.encode(chunks, batch_size=32, show_progress_bar=False)
             vector_store.add_embeddings(zip(chunks, embeddings))
         return vector_store
     except Exception as e:
         logger.error(f"Vector store augmentation error: {str(e)}")
         return vector_store
 
-# Process PDF with enhanced extraction
+# Process PDF with enhanced extraction and OCR fallback
 def process_pdf(uploaded_file):
     logger.info("Processing PDF with enhanced extraction")
     try:
@@ -136,6 +141,12 @@ def process_pdf(uploaded_file):
         with pdfplumber.open(BytesIO(uploaded_file.getvalue())) as pdf:
             for page in pdf.pages[:20]:
                 extracted = page.extract_text(layout=False)
+                if not extracted:  # OCR fallback for scanned PDFs
+                    try:
+                        img = page.to_image(resolution=150).original
+                        extracted = pytesseract.image_to_string(img, config='--psm 6')
+                    except Exception as ocr_e:
+                        logger.warning(f"OCR failed: {str(ocr_e)}")
                 if extracted:
                     text += extracted + "\n"
                 for char in page.chars:
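For scanned pages, the fallback above rasterizes the page and runs Tesseract. A standalone sketch of the same pattern, assuming `pytesseract` plus a system `tesseract` binary are installed alongside `pdfplumber` (`sample.pdf` is a placeholder):

```python
# Minimal OCR-fallback sketch: use the text layer when present, OCR otherwise.
import pdfplumber
import pytesseract

def extract_page_text(page):
    text = page.extract_text(layout=False)
    if not text:  # likely a scanned page with no text layer
        img = page.to_image(resolution=150).original  # PIL image of the page
        text = pytesseract.image_to_string(img, config="--psm 6")
    return text or ""

with pdfplumber.open("sample.pdf") as pdf:  # hypothetical input file
    full_text = "\n".join(extract_page_text(p) for p in pdf.pages[:20])
```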
@@ -157,35 +168,34 @@ def process_pdf(uploaded_file):
         if not text:
             raise ValueError("No text extracted from PDF")
 
-        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=500, chunk_overlap=100, keep_separator=True)
-        text_chunks = text_splitter.split_text(text)[:50]
-        code_chunks = text_splitter.split_text(code_text)[:25] if code_text else []
+        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=400, chunk_overlap=80, keep_separator=True)
+        text_chunks = text_splitter.split_text(text)[:80]
+        code_chunks = text_splitter.split_text(code_text)[:40] if code_text else []
 
         embeddings_model = load_embeddings_model()
         if not embeddings_model:
             return None, None, text, code_text
 
         text_vector_store = FAISS.from_embeddings(
-            zip(text_chunks, [embeddings_model.encode(chunk) for chunk in text_chunks]),
+            zip(text_chunks, [embeddings_model.encode(chunk, show_progress_bar=False) for chunk in text_chunks]),
             embeddings_model.encode
         ) if text_chunks else None
         code_vector_store = FAISS.from_embeddings(
-            zip(code_chunks, [embeddings_model.encode(chunk) for chunk in code_chunks]),
+            zip(code_chunks, [embeddings_model.encode(chunk, show_progress_bar=False) for chunk in code_chunks]),
             embeddings_model.encode
         ) if code_chunks else None
 
-        # Augment text vector store with dataset
         if text_vector_store:
             text_vector_store = augment_vector_store(text_vector_store)
 
-        logger.info("PDF processed successfully with enhanced extraction")
+        logger.info("PDF processed successfully")
         return text_vector_store, code_vector_store, text, code_text
     except Exception as e:
         logger.error(f"PDF processing error: {str(e)}")
         st.error(f"PDF error: {str(e)}")
         return None, None, "", ""
 
-# Summarize PDF
+# Summarize PDF with ROUGE metrics
 def summarize_pdf(text):
     logger.info("Generating summary")
     try:
@@ -193,24 +203,28 @@ def summarize_pdf(text):
         if not summary_pipeline:
             return "Summary model unavailable."
 
-        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=500, chunk_overlap=50)
+        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=400, chunk_overlap=50)
         chunks = text_splitter.split_text(text)[:2]
         summaries = []
 
         for chunk in chunks:
-            summary = summary_pipeline(chunk[:500], max_length=100, min_length=30, do_sample=False)[0]['summary_text']
+            summary = summary_pipeline(chunk[:400], max_length=100, min_length=30, do_sample=False)[0]['summary_text']
             summaries.append(summary.strip())
 
         combined_summary = " ".join(summaries)
         if len(combined_summary.split()) > 150:
             combined_summary = " ".join(combined_summary.split()[:150])
-        logger.info("Summary generated")
-        return f"Sure, here's a concise summary of the PDF:\n{combined_summary}"
+
+        scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
+        scores = scorer.score(text[:400], combined_summary)
+        logger.info(f"ROUGE scores: {scores}")
+
+        return f"**Summary**:\n{combined_summary}\n\n**ROUGE-1**: {scores['rouge1'].fmeasure:.2f}"
     except Exception as e:
         logger.error(f"Summary error: {str(e)}")
         return f"Oops, something went wrong summarizing: {str(e)}"
 
-# Answer question with improved response
+# Answer question with hybrid search
 def answer_question(text_vector_store, code_vector_store, query):
     logger.info(f"Processing query: {query}")
     try:
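The `rouge-score` usage above compares the summary against only the first 400 characters of the source, which is a rough self-check rather than a standard evaluation. Minimal usage of the scorer itself, with toy strings:

```python
# rouge_scorer.score(target, prediction) returns Score tuples per metric.
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
scores = scorer.score("The cat sat on the mat near the door.",
                      "A cat sat on a mat.")
print(f"ROUGE-1 F1: {scores['rouge1'].fmeasure:.2f}")  # precision/recall also available
```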
@@ -223,18 +237,27 @@ def answer_question(text_vector_store, code_vector_store, query):
 
         is_code_query = any(keyword in query.lower() for keyword in ["code", "script", "function", "programming", "give me code", "show code"])
         if is_code_query and code_vector_store:
-            return f"Here's the code from the PDF:\n```python\n{st.session_state.code_text}\n```"
+            docs = code_vector_store.similarity_search(query, k=3)
+            code = "\n".join(doc.page_content for doc in docs)
+            explanation = qa_pipeline(f"Explain this code: {code[:500]}")[0]['generated_text']
+            return f"**Code**:\n```python\n{code}\n```\n**Explanation**:\n{explanation}"
 
         vector_store = text_vector_store
         if not vector_store:
             return "No relevant content found for your query."
 
-        docs = vector_store.similarity_search(query, k=5)
-        context = "\n".join(doc.page_content for doc in docs)
-        prompt = f"Context: {context}\nQuestion: {query}\nProvide a detailed, accurate answer based on the context, prioritizing relevant information. Respond as a helpful assistant:"
+        # Hybrid search: FAISS + BM25
+        text_chunks = [doc.page_content for doc in vector_store.similarity_search(query, k=10)]
+        bm25 = BM25Okapi([chunk.split() for chunk in text_chunks])
+        bm25_docs = bm25.get_top_n(query.split(), text_chunks, n=5)
+        faiss_docs = vector_store.similarity_search(query, k=5)
+        combined_docs = list(set(bm25_docs + [doc.page_content for doc in faiss_docs]))[:5]
+        context = "\n".join(combined_docs)
+
+        prompt = f"Use the following PDF content to answer the question accurately and concisely. Avoid speculation and focus on the provided context:\n\n{context}\n\nQuestion: {query}\nAnswer:"
         response = qa_pipeline(prompt)[0]['generated_text']
         logger.info("Answer generated")
-        return f"Got it! Here's a detailed answer:\n{response.strip()}"
+        return f"**Answer**:\n{response.strip()}\n\n**Source Context**:\n{context[:500]}..."
     except Exception as e:
         logger.error(f"Query error: {str(e)}")
         return f"Sorry, something went wrong: {str(e)}"
@@ -245,22 +268,24 @@ try:
     st.markdown("""
     <style>
     .main { max-width: 900px; margin: 0 auto; padding: 20px; }
-    .sidebar { background-color: #f8f9fa; padding: 10px; border-radius: 5px; }
-    .chat-container { border: 1px solid #ddd; border-radius: 10px; padding: 10px; height: 60vh; overflow-y: auto; margin-top: 20px; }
-    .stChatMessage { border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; }
-    .user { background-color: #e6f3ff; align-self: flex-end; }
-    .assistant { background-color: #f0f0f0; }
-    .dark .user { background-color: #2a2a72; color: #fff; }
-    .dark .assistant { background-color: #2e2e2e; color: #fff; }
-    .stButton>button { background-color: #4CAF50; color: white; border: none; padding: 8px 16px; border-radius: 5px; }
-    .stButton>button:hover { background-color: #45a049; }
-    pre { background-color: #f8f8f8; padding: 10px; border-radius: 5px; overflow-x: auto; }
-    .header { background: linear-gradient(90deg, #4CAF50, #81C784); color: white; padding: 10px; border-radius: 5px; text-align: center; }
+    .sidebar { background-color: #f8f9fa; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
+    .chat-container { border: 1px solid #ddd; border-radius: 12px; padding: 15px; height: 60vh; overflow-y: auto; margin-top: 20px; background-color: #fafafa; }
+    .stChatMessage { border-radius: 12px; padding: 12px; margin: 8px; max-width: 75%; transition: all 0.3s ease; }
+    .user { background-color: #e6f3ff; align-self: flex-end; border: 1px solid #b3d4fc; }
+    .assistant { background-color: #f0f0f0; border: 1px solid #ccc; }
+    .dark .user { background-color: #2a2a72; color: #fff; border: 1px solid #4a4ab2; }
+    .dark .assistant { background-color: #2e2e2e; color: #fff; border: 1px solid #4a4a4a; }
+    .stButton>button { background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 8px; font-weight: bold; }
+    .stButton>button:hover { background-color: #45a049; transform: scale(1.05); }
+    pre { background-color: #f8f8f8; padding: 12px; border-radius: 8px; overflow-x: auto; }
+    .header { background: linear-gradient(90deg, #4CAF50, #81C784); color: white; padding: 15px; border-radius: 8px; text-align: center; box-shadow: 0 2px 4px rgba(0,0,0,0.2); }
+    .progress-bar { background-color: #e0e0e0; border-radius: 5px; height: 10px; }
+    .progress-fill { background-color: #4CAF50; height: 100%; border-radius: 5px; transition: width 0.5s ease; }
     </style>
     """, unsafe_allow_html=True)
 
     st.markdown('<div class="header"><h1>Smart PDF Q&A</h1></div>', unsafe_allow_html=True)
-    st.markdown("Upload a PDF to ask questions, summarize (~150 words), or extract code with 'give me code'. Fast and friendly responses!")
+    st.markdown("Upload a PDF to ask questions, summarize (~150 words), or extract code with 'give me code'. Fast, accurate, and smooth!")
 
     # Initialize session state
     if "messages" not in st.session_state:
@@ -274,20 +299,26 @@ try:
     if "code_text" not in st.session_state:
         st.session_state.code_text = ""
 
-    # Sidebar with toggle and dataset options
+    # Sidebar with controls
     with st.sidebar:
         st.markdown('<div class="sidebar">', unsafe_allow_html=True)
         theme = st.radio("Theme", ["Light", "Dark"], index=0)
         dataset_name = st.selectbox("Select Dataset for Fine-Tuning", ["squad", "cnn_dailymail", "bigcode/the-stack"], index=0)
         if st.button("Fine-Tune Model"):
-            with st.spinner("Fine-tuning model..."):
-                dataset = load_and_prepare_dataset(dataset_name=dataset_name)
-                if dataset:
-                    fine_tuned_pipeline = fine_tune_qa_model(dataset)
-                    if fine_tuned_pipeline:
-                        st.success("Model fine-tuned successfully!")
-                    else:
-                        st.error("Fine-tuning failed.")
+            progress_bar = st.progress(0)
+            for i in range(100):
+                time.sleep(0.02)
+                progress_bar.progress(i + 1)
+            dataset = load_and_prepare_dataset(dataset_name=dataset_name)
+            if dataset:
+                fine_tuned_pipeline = fine_tune_qa_model(dataset)
+                if fine_tuned_pipeline:
+                    st.success("Model fine-tuned successfully!")
+                else:
+                    st.error("Fine-tuning failed.")
+        if st.button("Clear Chat"):
+            st.session_state.messages = []
+            st.experimental_rerun()
         st.markdown('</div>', unsafe_allow_html=True)
 
     # PDF upload and processing
@@ -295,7 +326,11 @@ try:
     col1, col2 = st.columns([1, 1])
     with col1:
         if st.button("Process PDF"):
+            progress_bar = st.progress(0)
             with st.spinner("Processing PDF..."):
+                for i in range(100):
+                    time.sleep(0.05)
+                    progress_bar.progress(i + 1)
                 st.session_state.text_vector_store, st.session_state.code_vector_store, st.session_state.pdf_text, st.session_state.code_text = process_pdf(uploaded_file)
                 if st.session_state.text_vector_store or st.session_state.code_vector_store:
                     st.success("PDF processed! Ask away or summarize.")
@@ -304,7 +339,11 @@ try:
                 st.error("Failed to process PDF.")
     with col2:
         if st.button("Summarize PDF") and st.session_state.pdf_text:
+            progress_bar = st.progress(0)
             with st.spinner("Summarizing..."):
+                for i in range(100):
+                    time.sleep(0.02)
+                    progress_bar.progress(i + 1)
                 summary = summarize_pdf(st.session_state.pdf_text)
                 st.session_state.messages.append({"role": "assistant", "content": summary})
                 st.markdown(summary, unsafe_allow_html=True)
@@ -318,7 +357,11 @@ try:
         with st.chat_message("user"):
             st.markdown(prompt)
         with st.chat_message("assistant"):
-            with st.spinner('<div class="spinner">⏳</div>'):
+            progress_bar = st.progress(0)
+            with st.spinner('<div class="spinner">⏳ Processing...</div>'):
+                for i in range(100):
+                    time.sleep(0.01)
+                    progress_bar.progress(i + 1)
                 answer = answer_question(st.session_state.text_vector_store, st.session_state.code_vector_store, prompt)
                 st.markdown(answer, unsafe_allow_html=True)
                 st.session_state.messages.append({"role": "assistant", "content": answer})
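The three `st.progress` loops above animate the bar on a timer before the heavy call starts, so the bar completes before the work does. A sketch of driving the bar from actual stages instead; the stage functions are hypothetical placeholders for real pipeline steps:

```python
# Advance the bar as each real stage finishes, rather than on a fixed timer.
import time
import streamlit as st

def extract_stage():  # placeholder standing in for real extraction work
    time.sleep(0.5)

def embed_stage():  # placeholder standing in for real embedding work
    time.sleep(0.5)

stages = [("Extracting text", extract_stage), ("Building index", embed_stage)]
bar = st.progress(0)
for i, (label, fn) in enumerate(stages, start=1):
    fn()  # do the real work for this stage
    bar.progress(int(i / len(stages) * 100))
```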
 