ak3ra committed
Commit ff19631 · 1 Parent(s): 5695f84

Refactor RAGPipeline to extract page number from query and log requested page number

Files changed (3):
  1. app.py +47 -13
  2. rag/rag_pipeline.py +38 -7
  3. utils/pdf_processor.py +57 -40
app.py CHANGED
@@ -1,7 +1,10 @@
 # app.py
 
 import csv
+
 import datetime
+
+# from datetime import datetime
 import io
 import json
 import logging
@@ -377,6 +380,8 @@ def create_gr_interface() -> gr.Blocks:
 
         # Tab 2: PDF Chat Interface
         with gr.Tab("PDF Chat"):
+            pdf_processor = PDFProcessor()
+
             with gr.Row():
                 # Left column: Chat and Input
                 with gr.Column(scale=7):
@@ -412,8 +417,6 @@ def create_gr_interface() -> gr.Blocks:
             pdf_status = gr.Markdown()
             current_collection = gr.State(value=None)
 
-            pdf_processor = PDFProcessor()
-
         # Event handlers for Study Analysis tab
         process_zotero_btn.click(
             process_zotero_library_items,
@@ -438,7 +441,6 @@ def create_gr_interface() -> gr.Blocks:
         # Event handlers for PDF Chat tab
 
         def handle_pdf_upload(files, name):
-            """Handle PDF upload and processing."""
            if not name:
                return "Please provide a collection name", None
            if not files:
@@ -452,14 +454,20 @@ def create_gr_interface() -> gr.Blocks:
                logger.error(f"Error in handle_pdf_upload: {str(e)}")
                return f"Error: {str(e)}", None
 
+        upload_btn.click(
+            handle_pdf_upload,
+            inputs=[pdf_files, collection_name],
+            outputs=[pdf_status, current_collection],
+        )
+
         def add_message(history, message):
             """Add user message to chat history."""
             if not message.strip():
                 raise gr.Error("Please enter a message")
             history = history + [(message, None)]
-            return history, "", None  # Return empty preview
+            return history, "", None
 
-        def generate_chat_response(history, collection_id):
+        def generate_chat_response(history, collection_id, pdf_processor):
             """Generate response for the last message in history."""
             if not collection_id:
                 raise gr.Error("Please upload PDFs first")
@@ -468,13 +476,39 @@ def create_gr_interface() -> gr.Blocks:
 
            last_message = history[-1][0]
            try:
-                updated_history, preview_image = chat_response(
-                    last_message,
-                    history[:-1],
-                    collection_id,
-                    pdf_processor,
-                )
-                return updated_history, preview_image
+                # Get response and source info
+                rag = get_rag_pipeline(collection_id)
+                response, source_info = rag.query(last_message)
+
+                # Generate preview if source information is available
+                preview_image = None
+                if (
+                    source_info
+                    and source_info.get("source_file")
+                    and source_info.get("page_number") is not None
+                ):
+                    try:
+                        page_num = source_info["page_number"]
+                        logger.info(f"Attempting to render page {page_num}")
+                        preview_image = pdf_processor.render_page(
+                            source_info["source_file"], page_num
+                        )
+                        if preview_image:
+                            logger.info(
+                                f"Successfully generated preview for page {page_num}"
+                            )
+                        else:
+                            logger.warning(
+                                f"Failed to generate preview for page {page_num}"
+                            )
+                    except Exception as e:
+                        logger.error(f"Error generating PDF preview: {str(e)}")
+                        preview_image = None
+
+                # Update history with response
+                history[-1] = (last_message, response)
+                return history, preview_image
+
            except Exception as e:
                logger.error(f"Error in generate_chat_response: {str(e)}")
                history[-1] = (last_message, f"Error: {str(e)}")
@@ -493,7 +527,7 @@ def create_gr_interface() -> gr.Blocks:
            inputs=[chat_history, query_input],
            outputs=[chat_history, query_input, pdf_preview],
        ).success(
-            generate_chat_response,
+            lambda h, c: generate_chat_response(h, c, pdf_processor),
            inputs=[chat_history, current_collection],
            outputs=[chat_history, pdf_preview],
        )
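The pattern worth noting in app.py is that `pdf_processor` is a plain Python object rather than a Gradio component, so it is captured in a closure (the lambda) instead of being routed through `inputs=[...]`. Below is a minimal self-contained sketch of that chained `.click(...)/.success(...)` wiring, assuming a Gradio version that accepts (user, assistant) tuple history as the diff does; `StubProcessor` and the component names are illustrative, not from the commit:

import gradio as gr

class StubProcessor:
    """Stand-in for PDFProcessor (illustration only)."""
    def describe(self, text: str) -> str:
        return f"processed: {text}"

pdf_processor = StubProcessor()

def add_message(history, message):
    # Fast step: append the user turn and clear the textbox.
    if not message.strip():
        raise gr.Error("Please enter a message")
    return history + [(message, None)], ""

def generate_chat_response(history, processor):
    # Slow step: runs only if add_message raised no error.
    last_message = history[-1][0]
    history[-1] = (last_message, processor.describe(last_message))
    return history

with gr.Blocks() as demo:
    chat_history = gr.Chatbot()
    query_input = gr.Textbox()
    query_input.submit(
        add_message,
        inputs=[chat_history, query_input],
        outputs=[chat_history, query_input],
    ).success(
        # The closure carries the non-component object, as in the diff above.
        lambda h: generate_chat_response(h, pdf_processor),
        inputs=[chat_history],
        outputs=[chat_history],
    )

if __name__ == "__main__":
    demo.launch()

Because `.success()` only fires when the preceding step did not raise, an empty message stops the chain before the RAG query ever runs.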
rag/rag_pipeline.py CHANGED
@@ -11,9 +11,12 @@ from llama_index.llms.openai import OpenAI
 from llama_index.vector_stores.chroma import ChromaVectorStore
 import chromadb
 from typing import Dict, Any, List, Tuple
+import re
+import logging
 
 
 logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 
 class RAGPipeline:
@@ -33,6 +36,23 @@ class RAGPipeline:
         self.load_documents()
         self.build_index()
 
+    def extract_page_number_from_query(self, query: str) -> int:
+        """Extract page number from query text."""
+        # Look for patterns like "page 3", "p3", "p. 3", etc.
+        patterns = [
+            r"page\s*(\d+)",
+            r"p\.\s*(\d+)",
+            r"p\s*(\d+)",
+            r"pg\.\s*(\d+)",
+            r"pg\s*(\d+)",
+        ]
+
+        for pattern in patterns:
+            match = re.search(pattern, query.lower())
+            if match:
+                return int(match.group(1))
+        return None
+
     def load_documents(self):
         if self.documents is None:
             with open(self.study_json, "r") as f:
@@ -55,7 +75,7 @@ class RAGPipeline:
                     "year": doc_data.get("date"),
                     "doi": doc_data.get("doi"),
                     "source_file": doc_data.get("source_file"),
-                    "page_number": page_num,  # Store single page number
+                    "page_number": int(page_num),  # Store as integer
                     "total_pages": len(pages),
                 }
 
@@ -103,14 +123,17 @@ class RAGPipeline:
            "Given this information, please answer the question: {query_str}\n"
            "Provide a detailed answer using the content from the context above. "
            "If the question asks about specific page content, make sure to include that information. "
-            "Cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc. "
-            "Include page numbers in citations when available, e.g. [1, p.3]. "
+            "Cite sources using square brackets for EVERY piece of information, e.g. [1, p.3], [2, p.5], etc. "
            "If you're unsure about something, say so rather than making assumptions."
        )
 
+        # Extract page number from query if present
+        requested_page = self.extract_page_number_from_query(context)
+        logger.info(f"Requested page number: {requested_page}")
+
        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template,
-            similarity_top_k=5,  # Reduced for more focused results
+            similarity_top_k=5,
            response_mode="tree_summarize",
            llm=OpenAI(model="gpt-4o-mini"),
        )
@@ -120,15 +143,23 @@ class RAGPipeline:
        # Extract source information from the response nodes
        source_info = {}
        if hasattr(response, "source_nodes") and response.source_nodes:
-            # Get the most relevant source
            source_node = response.source_nodes[0]
            metadata = source_node.metadata
+
+            # Use requested page number if available, otherwise use the page from metadata
+            page_number = (
+                requested_page
+                if requested_page is not None
+                else metadata.get("page_number", 0)
+            )
+
            source_info = {
                "source_file": metadata.get("source_file"),
-                "page_number": metadata.get("page_number"),
+                "page_number": page_number,
                "title": metadata.get("title"),
                "authors": metadata.get("authors"),
-                "content": source_node.text,  # Include the actual content
+                "content": source_node.text,
            }
+            logger.info(f"Source info page number: {page_number}")
 
        return response.response, source_info
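As a quick illustration of the new page-number extraction, here is a standalone restatement of the patterns added above (the sample queries are made up):

import re

def extract_page_number_from_query(query: str):
    # Same patterns as the method above: "page 3", "p. 3", "p3", "pg. 3", "pg3".
    patterns = [
        r"page\s*(\d+)",
        r"p\.\s*(\d+)",
        r"p\s*(\d+)",
        r"pg\.\s*(\d+)",
        r"pg\s*(\d+)",
    ]
    for pattern in patterns:
        match = re.search(pattern, query.lower())
        if match:
            return int(match.group(1))
    return None

print(extract_page_number_from_query("What does page 3 say?"))  # 3
print(extract_page_number_from_query("Summarize p. 12"))        # 12
print(extract_page_number_from_query("No page mentioned"))      # None

Note that the bare r"p\s*(\d+)" pattern is permissive: a query like "top 10 results" matches it and returns 10, which is one reason the metadata fallback in query() still matters.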
utils/pdf_processor.py CHANGED
@@ -8,11 +8,12 @@ import os
 import fitz
 import logging
 from typing import Dict, List, Optional
-from datetime import datetime
+import datetime
 from slugify import slugify
 import json
 from PIL import Image
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -23,6 +24,60 @@ class PDFProcessor:
         os.makedirs(upload_dir, exist_ok=True)
         self.current_page = 0
 
+    def render_page(self, file_path: str, page_num: int) -> Optional[Image.Image]:
+        """Render a specific page from a PDF as an image."""
+        try:
+            logger.info(f"Attempting to render page {page_num} from {file_path}")
+            doc = fitz.open(file_path)
+
+            # Ensure page number is valid
+            if page_num < 0 or page_num >= len(doc):
+                logger.error(
+                    f"Invalid page number {page_num} for document with {len(doc)} pages"
+                )
+                return None
+
+            page = doc[page_num]
+            # Increase resolution for better quality
+            pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
+            image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+            doc.close()
+            logger.info(f"Successfully rendered page {page_num}")
+            return image
+        except Exception as e:
+            logger.error(f"Error rendering page {page_num} from {file_path}: {str(e)}")
+            return None
+
+    def process_pdfs(self, file_paths: List[str], collection_name: str) -> str:
+        """Process multiple PDF files and store their content."""
+        processed_docs = []
+
+        for file_path in file_paths:
+            try:
+                doc_data = self.extract_text_from_pdf(file_path)
+                processed_docs.append(doc_data)
+                logger.info(f"Successfully processed {file_path}")
+            except Exception as e:
+                logger.error(f"Error processing {file_path}: {str(e)}")
+                continue
+
+        if not processed_docs:
+            raise ValueError("No documents were successfully processed")
+
+        # Save to JSON file
+        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        output_filename = f"{slugify(collection_name)}_{timestamp}_documents.json"
+        output_path = os.path.join("data", output_filename)
+
+        # Ensure the data directory exists
+        os.makedirs("data", exist_ok=True)
+
+        with open(output_path, "w", encoding="utf-8") as f:
+            json.dump(processed_docs, f, indent=2, ensure_ascii=False)
+
+        logger.info(f"Saved processed documents to {output_path}")
+        return output_path
+
     def extract_text_from_pdf(self, file_path: str) -> Dict:
         """Extract text and metadata from a PDF file."""
         try:
@@ -33,7 +88,7 @@ class PDFProcessor:
            pages = {}
            for page_num in range(len(doc)):
                page_text = doc[page_num].get_text()
-                pages[page_num] = page_text
+                pages[str(page_num)] = page_text  # Convert page_num to string for JSON
                text += page_text + "\n"
 
            # Extract metadata
@@ -62,41 +117,3 @@ class PDFProcessor:
        except Exception as e:
            logger.error(f"Error processing PDF {file_path}: {str(e)}")
            raise
-
-    def process_pdfs(self, file_paths: List[str], collection_name: str) -> str:
-        """Process multiple PDF files and store their content."""
-        processed_docs = []
-
-        for file_path in file_paths:
-            try:
-                doc_data = self.extract_text_from_pdf(file_path)
-                processed_docs.append(doc_data)
-            except Exception as e:
-                logger.error(f"Error processing {file_path}: {str(e)}")
-                continue
-
-        if not processed_docs:
-            raise ValueError("No documents were successfully processed")
-
-        # Save to JSON file
-        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-        output_filename = f"{slugify(collection_name)}_{timestamp}_documents.json"
-        output_path = f"data/{output_filename}"
-
-        with open(output_path, "w", encoding="utf-8") as f:
-            json.dump(processed_docs, f, indent=2, ensure_ascii=False)
-
-        return output_path
-
-    def render_page(self, file_path: str, page_num: int) -> Optional[Image.Image]:
-        """Render a specific page from a PDF as an image."""
-        try:
-            doc = fitz.open(file_path)
-            page = doc[page_num]
-            pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
-            image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
-            doc.close()
-            return image
-        except Exception as e:
-            logger.error(f"Error rendering page {page_num} from {file_path}: {str(e)}")
-            return None
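For reference, a short usage sketch of the two relocated PDFProcessor helpers. This assumes the constructor's upload_dir has a usable default, and "uploads/example.pdf" is a hypothetical path; the Matrix(300 / 72, 300 / 72) call in render_page scales PyMuPDF's default 72 dpi output up to 300 dpi:

from utils.pdf_processor import PDFProcessor

processor = PDFProcessor()

# Batch-extract text and persist the collection as JSON under data/.
json_path = processor.process_pdfs(["uploads/example.pdf"], "demo collection")
print(json_path)  # e.g. data/demo-collection_20250101_120000_documents.json

# Render a 0-based page as a PIL image; returns None for out-of-range pages.
image = processor.render_page("uploads/example.pdf", 0)
if image is not None:
    image.save("page_0_preview.png")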