aaporosh committed
Commit 058a20c · verified · 1 parent: b3ca527

Update app.py

Files changed (1): app.py (+59 −31)
app.py CHANGED
@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
 def load_embeddings_model():
     logger.info("Loading embeddings model")
     try:
-        return SentenceTransformer("all-MiniLM-L12-v2")
+        return SentenceTransformer("all-MiniLM-L6-v2")
     except Exception as e:
         logger.error(f"Embeddings load error: {str(e)}")
         st.error(f"Embedding model error: {str(e)}")
@@ -39,7 +39,7 @@ def load_qa_pipeline():
         fine_tuned_pipeline = fine_tune_qa_model(dataset)
         if fine_tuned_pipeline:
             return fine_tuned_pipeline
-        return pipeline("text2text-generation", model="google/flan-t5-base", max_length=300)
+        return pipeline("text2text-generation", model="google/flan-t5-small", max_length=300)
     except Exception as e:
         logger.error(f"QA model load error: {str(e)}")
         st.error(f"QA model error: {str(e)}")
@@ -49,7 +49,7 @@ def load_qa_pipeline():
 def load_summary_pipeline():
     logger.info("Loading summary pipeline")
     try:
-        return pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", max_length=150)
+        return pipeline("summarization", model="facebook/bart-large-cnn", max_length=250)
     except Exception as e:
         logger.error(f"Summary model load error: {str(e)}")
         st.error(f"Summary model error: {str(e)}")
@@ -79,7 +79,7 @@ def load_and_prepare_dataset(dataset_name="squad", max_samples=1000):
 def fine_tune_qa_model(dataset):
     logger.info("Starting fine-tuning")
     try:
-        model_name = "google/flan-t5-base"
+        model_name = "google/flan-t5-small"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
@@ -93,13 +93,13 @@ def fine_tune_qa_model(dataset):
 
         training_args = TrainingArguments(
             output_dir="./fine_tuned_model",
-            num_train_epochs=2,
+            num_train_epochs = 2,
             per_device_train_batch_size=4,
             save_steps=500,
             logging_steps=100,
             evaluation_strategy="no",
             learning_rate=3e-5,
-            fp16=False, # Set True if GPU available
+            fp16=False,
         )
 
         trainer = Trainer(
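
Since the base model is seq2seq (flan-t5-small), the generation-aware argument class is worth considering if evaluation is ever switched on; a sketch under that assumption (not what this commit does):

```python
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./fine_tuned_model",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    save_steps=500,
    logging_steps=100,
    learning_rate=3e-5,
    fp16=False,                  # set True only when a CUDA GPU is available
    predict_with_generate=True,  # decode with generate() during evaluation
)
```
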
@@ -118,14 +118,14 @@
         return None
 
 # Augment vector store with dataset
-def augment_vector_store(vector_store, dataset_name="squad", max_samples=500):
+def augment_vector_store(vector_store, dataset_name="squad", max_samples=300):
     logger.info(f"Augmenting vector store with dataset: {dataset_name}")
     try:
         dataset = load_dataset(dataset_name, split="train").select(range(min(max_samples, len(dataset))))
         chunks = [f"Context: {c}\nAnswer: {a['text'][0]}" for c, a in zip(dataset['context'], dataset['answers'])]
         embeddings_model = load_embeddings_model()
         if embeddings_model and vector_store:
-            embeddings = embeddings_model.encode(chunks, batch_size=32, show_progress_bar=False)
+            embeddings = embeddings_model.encode(chunks, batch_size=128, show_progress_bar=False)
             vector_store.add_embeddings(zip(chunks, embeddings))
             return vector_store
     except Exception as e:
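
One pre-existing bug survives this hunk: `len(dataset)` is evaluated inside the statement that first assigns `dataset`, so the line raises UnboundLocalError and the except path silently skips augmentation. A fixed sketch (the helper name is illustrative) that loads first and caps second:

```python
from datasets import load_dataset

def load_capped_split(dataset_name="squad", max_samples=300):
    # Load first; len(dataset) is only defined after the assignment completes.
    dataset = load_dataset(dataset_name, split="train")
    return dataset.select(range(min(max_samples, len(dataset))))
```
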
@@ -139,16 +139,19 @@ def process_pdf(uploaded_file):
         text = ""
         code_blocks = []
         with pdfplumber.open(BytesIO(uploaded_file.getvalue())) as pdf:
-            for page in pdf.pages[:20]:
+            for page in pdf.pages[:8]:
                 extracted = page.extract_text(layout=False)
-                if not extracted: # OCR fallback for scanned PDFs
+                if not extracted:
                     try:
                         img = page.to_image(resolution=150).original
                         extracted = pytesseract.image_to_string(img, config='--psm 6')
                     except Exception as ocr_e:
                         logger.warning(f"OCR failed: {str(ocr_e)}")
                 if extracted:
-                    text += extracted + "\n"
+                    # Clean text: remove headers/footers (simple heuristic)
+                    lines = extracted.split("\n")
+                    cleaned_lines = [line for line in lines if not re.match(r'^\s*(Page \d+|.*\d{4}-\d{4}|Copyright.*)\s*$', line, re.I)]
+                    text += "\n".join(cleaned_lines) + "\n"
                 for char in page.chars:
                     if 'fontname' in char and 'mono' in char['fontname'].lower():
                         code_blocks.append(char['text'])
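
The new header/footer heuristic relies on `re` being imported at the top of app.py (not visible in this diff); if it is missing, the NameError lands in the function's outer except and the whole extraction fails. A standalone check of what the pattern keeps and drops:

```python
import re

PAGE_NOISE = re.compile(r'^\s*(Page \d+|.*\d{4}-\d{4}|Copyright.*)\s*$', re.I)

for line in ["Page 12", "Copyright 2024 ACME Corp", "Annual report 2019-2024", "Regular body text."]:
    print(f"{line!r:28} -> {'drop' if PAGE_NOISE.match(line) else 'keep'}")
# Drops the first three, keeps the body text. Any line containing a
# yyyy-yyyy range is treated as footer noise, which can be aggressive.
```
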
@@ -168,20 +171,20 @@ def process_pdf(uploaded_file):
         if not text:
             raise ValueError("No text extracted from PDF")
 
-        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=400, chunk_overlap=80, keep_separator=True)
-        text_chunks = text_splitter.split_text(text)[:80]
-        code_chunks = text_splitter.split_text(code_text)[:40] if code_text else []
+        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=250, chunk_overlap=40, keep_separator=True)
+        text_chunks = text_splitter.split_text(text)[:25]
+        code_chunks = text_splitter.split_text(code_text)[:10] if code_text else []
 
         embeddings_model = load_embeddings_model()
         if not embeddings_model:
             return None, None, text, code_text
 
         text_vector_store = FAISS.from_embeddings(
-            zip(text_chunks, [embeddings_model.encode(chunk, show_progress_bar=False) for chunk in text_chunks]),
+            zip(text_chunks, [embeddings_model.encode(chunk, show_progress_bar=False, batch_size=128) for chunk in text_chunks]),
             embeddings_model.encode
         ) if text_chunks else None
         code_vector_store = FAISS.from_embeddings(
-            zip(code_chunks, [embeddings_model.encode(chunk, show_progress_bar=False) for chunk in code_chunks]),
+            zip(code_chunks, [embeddings_model.encode(chunk, show_progress_bar=False, batch_size=128) for chunk in code_chunks]),
             embeddings_model.encode
         ) if code_chunks else None
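
In both stores the new `batch_size=128` is a no-op: `encode` is called once per chunk with a single string, so every call is a batch of one. A sketch that encodes the whole chunk list in a single call (same stores, far fewer forward passes):

```python
# Drop-in replacement for the text store above; the code store is analogous.
text_embeddings = embeddings_model.encode(
    text_chunks, batch_size=128, show_progress_bar=False
)
text_vector_store = FAISS.from_embeddings(
    zip(text_chunks, text_embeddings),
    embeddings_model.encode,
) if text_chunks else None
```
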
 
@@ -195,7 +198,7 @@ def process_pdf(uploaded_file):
         st.error(f"PDF error: {str(e)}")
         return None, None, "", ""
 
-# Summarize PDF with ROUGE metrics
+# Summarize PDF with ROUGE metrics and improved topic focus
 def summarize_pdf(text):
     logger.info("Generating summary")
     try:
@@ -203,23 +206,39 @@ def summarize_pdf(text):
         if not summary_pipeline:
             return "Summary model unavailable."
 
-        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=400, chunk_overlap=50)
-        chunks = text_splitter.split_text(text)[:2]
-        summaries = []
+        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=250, chunk_overlap=40)
+        chunks = text_splitter.split_text(text)
+
+        # Hybrid search for relevant chunks
+        embeddings_model = load_embeddings_model()
+        if embeddings_model and chunks:
+            temp_vector_store = FAISS.from_embeddings(
+                zip(chunks, [embeddings_model.encode(chunk, show_progress_bar=False) for chunk in chunks]),
+                embeddings_model.encode
+            )
+            bm25 = BM25Okapi([chunk.split() for chunk in chunks])
+            query = "main topic and key points"
+            bm25_docs = bm25.get_top_n(query.split(), chunks, n=4)
+            faiss_docs = temp_vector_store.similarity_search(query, k=4)
+            selected_chunks = list(set(bm25_docs + [doc.page_content for doc in faiss_docs]))[:4]
+        else:
+            selected_chunks = chunks[:4]
 
-        for chunk in chunks:
-            summary = summary_pipeline(chunk[:400], max_length=100, min_length=30, do_sample=False)[0]['summary_text']
+        summaries = []
+        for chunk in selected_chunks:
+            summary = summary_pipeline(f"Summarize the main topic and key points in detail: {chunk[:250]}", max_length=100, min_length=50, do_sample=False)[0]['summary_text']
             summaries.append(summary.strip())
 
         combined_summary = " ".join(summaries)
-        if len(combined_summary.split()) > 150:
-            combined_summary = " ".join(combined_summary.split()[:150])
+        if len(combined_summary.split()) > 250:
+            combined_summary = " ".join(combined_summary.split()[:250])
 
+        word_count = len(combined_summary.split())
         scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
-        scores = scorer.score(text[:400], combined_summary)
+        scores = scorer.score(text[:500], combined_summary)
         logger.info(f"ROUGE scores: {scores}")
 
-        return f"**Summary**:\n{combined_summary}\n\n**ROUGE-1**: {scores['rouge1'].fmeasure:.2f}"
+        return f"**Main Topic Summary** ({word_count} words):\n{combined_summary}\n\n**ROUGE-1**: {scores['rouge1'].fmeasure:.2f}"
     except Exception as e:
         logger.error(f"Summary error: {str(e)}")
         return f"Oops, something went wrong summarizing: {str(e)}"
@@ -285,7 +304,7 @@ try:
     """, unsafe_allow_html=True)
 
     st.markdown('<div class="header"><h1>Smart PDF Q&A</h1></div>', unsafe_allow_html=True)
-    st.markdown("Upload a PDF to ask questions, summarize (~150 words), or extract code with 'give me code'. Fast, accurate, and smooth!")
+    st.markdown("Upload a PDF to ask questions, get a ~200-word summary, or extract code with 'give me code'. Optimized for speed and accuracy!")
 
     # Initialize session state
     if "messages" not in st.session_state:
@@ -307,7 +326,7 @@ try:
         if st.button("Fine-Tune Model"):
             progress_bar = st.progress(0)
             for i in range(100):
-                time.sleep(0.02)
+                time.sleep(0.008)
                 progress_bar.progress(i + 1)
             dataset = load_and_prepare_dataset(dataset_name=dataset_name)
             if dataset:
@@ -319,6 +338,15 @@ try:
         if st.button("Clear Chat"):
             st.session_state.messages = []
             st.experimental_rerun()
+        if st.button("Retry Summarization") and st.session_state.pdf_text:
+            progress_bar = st.progress(0)
+            with st.spinner("Retrying summarization..."):
+                for i in range(100):
+                    time.sleep(0.008)
+                    progress_bar.progress(i + 1)
+                summary = summarize_pdf(st.session_state.pdf_text)
+                st.session_state.messages.append({"role": "assistant", "content": summary})
+                st.markdown(summary, unsafe_allow_html=True)
     st.markdown('</div>', unsafe_allow_html=True)
 
     # PDF upload and processing
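
`st.experimental_rerun()` (unchanged context here) has been deprecated in newer Streamlit releases in favor of `st.rerun()`; a version-tolerant wrapper, assuming either attribute may be present:

```python
import streamlit as st

def safe_rerun():
    # st.rerun replaced st.experimental_rerun in newer Streamlit releases.
    if hasattr(st, "rerun"):
        st.rerun()
    else:
        st.experimental_rerun()
```
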
@@ -329,7 +357,7 @@ try:
         progress_bar = st.progress(0)
         with st.spinner("Processing PDF..."):
             for i in range(100):
-                time.sleep(0.05)
+                time.sleep(0.02)
                 progress_bar.progress(i + 1)
             st.session_state.text_vector_store, st.session_state.code_vector_store, st.session_state.pdf_text, st.session_state.code_text = process_pdf(uploaded_file)
         if st.session_state.text_vector_store or st.session_state.code_vector_store:
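
All the `time.sleep` tweaks in this commit only shorten cosmetic loops: each bar sleeps through 0 to 100% before the real work starts, so the UI shows completion and then blocks. A sketch that ties the bar to actual stages instead (no fixed sleeps):

```python
progress_bar = st.progress(0)
with st.spinner("Processing PDF..."):
    progress_bar.progress(10)   # upload received
    result = process_pdf(uploaded_file)
    progress_bar.progress(90)   # extraction and embedding finished
    (st.session_state.text_vector_store,
     st.session_state.code_vector_store,
     st.session_state.pdf_text,
     st.session_state.code_text) = result
    progress_bar.progress(100)
```
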
@@ -342,7 +370,7 @@ try:
         progress_bar = st.progress(0)
         with st.spinner("Summarizing..."):
             for i in range(100):
-                time.sleep(0.02)
+                time.sleep(0.008)
                 progress_bar.progress(i + 1)
             summary = summarize_pdf(st.session_state.pdf_text)
             st.session_state.messages.append({"role": "assistant", "content": summary})
@@ -360,7 +388,7 @@ try:
         progress_bar = st.progress(0)
         with st.spinner('<div class="spinner">⏳ Processing...</div>'):
             for i in range(100):
-                time.sleep(0.01)
+                time.sleep(0.004)
                 progress_bar.progress(i + 1)
             answer = answer_question(st.session_state.text_vector_store, st.session_state.code_vector_store, prompt)
             st.markdown(answer, unsafe_allow_html=True)
 