ak3ra commited on
Commit
7bb0003
·
1 Parent(s): 5674d87

Refactor chat_response function to include PDF preview generation

Browse files
Files changed (5) hide show
  1. app.py +41 -13
  2. interface.py +1 -0
  3. rag/rag_pipeline.py +20 -7
  4. study_files.json +0 -1
  5. utils/pdf_processor.py +1 -0
app.py CHANGED
@@ -274,15 +274,36 @@ def process_pdf_uploads(files: List[gr.File], collection_name: str) -> str:
274
 
275
 
276
  def chat_response(
277
- message: str, history: List[Tuple[str, str]], study_name: str
278
- ) -> Tuple[List[Tuple[str, str]], str]:
 
 
 
279
  """Generate chat response and update history."""
280
  if not message.strip():
281
- return history, None
282
 
283
- response = chat_function(message, study_name, "Default")
 
284
  history.append((message, response))
285
- return history, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
 
288
  def create_gr_interface() -> gr.Blocks:
@@ -390,6 +411,9 @@ def create_gr_interface() -> gr.Blocks:
390
  upload_btn = gr.Button("Process PDFs", variant="primary")
391
  pdf_status = gr.Markdown()
392
  current_collection = gr.State(value=None)
 
 
 
393
  # Event handlers for Study Analysis tab
394
  process_zotero_btn.click(
395
  process_zotero_library_items,
@@ -433,24 +457,28 @@ def create_gr_interface() -> gr.Blocks:
433
  if not message.strip():
434
  raise gr.Error("Please enter a message")
435
  history = history + [(message, None)]
436
- return history, ""
437
 
438
  def generate_chat_response(history, collection_id):
439
  """Generate response for the last message in history."""
440
  if not collection_id:
441
  raise gr.Error("Please upload PDFs first")
442
  if len(history) == 0:
443
- return history
444
 
445
  last_message = history[-1][0]
446
  try:
447
- response = chat_function(last_message, collection_id, "Default")
448
- history[-1] = (last_message, response)
 
 
 
 
 
449
  except Exception as e:
450
  logger.error(f"Error in generate_chat_response: {str(e)}")
451
  history[-1] = (last_message, f"Error: {str(e)}")
452
-
453
- return history
454
 
455
  # Update PDF event handlers
456
  upload_btn.click( # Change from pdf_files.upload to upload_btn.click
@@ -463,11 +491,11 @@ def create_gr_interface() -> gr.Blocks:
463
  chat_submit_btn.click(
464
  add_message,
465
  inputs=[chat_history, query_input],
466
- outputs=[chat_history, query_input],
467
  ).success(
468
  generate_chat_response,
469
  inputs=[chat_history, current_collection],
470
- outputs=[chat_history],
471
  )
472
 
473
  return demo
 
274
 
275
 
276
  def chat_response(
277
+ message: str,
278
+ history: List[Tuple[str, str]],
279
+ study_name: str,
280
+ pdf_processor: PDFProcessor,
281
+ ) -> Tuple[List[Tuple[str, str]], str, Any]:
282
  """Generate chat response and update history."""
283
  if not message.strip():
284
+ return history, None, None
285
 
286
+ rag = get_rag_pipeline(study_name)
287
+ response, source_info = rag.query(message)
288
  history.append((message, response))
289
+
290
+ # Generate PDF preview if source information is available
291
+ preview_image = None
292
+ if (
293
+ source_info
294
+ and source_info.get("source_file")
295
+ and source_info.get("page_numbers")
296
+ ):
297
+ try:
298
+ # Get the first page number from the source
299
+ page_num = source_info["page_numbers"][0]
300
+ preview_image = pdf_processor.render_page(
301
+ source_info["source_file"], int(page_num)
302
+ )
303
+ except Exception as e:
304
+ logger.error(f"Error generating PDF preview: {str(e)}")
305
+
306
+ return history, preview_image
307
 
308
 
309
  def create_gr_interface() -> gr.Blocks:
 
411
  upload_btn = gr.Button("Process PDFs", variant="primary")
412
  pdf_status = gr.Markdown()
413
  current_collection = gr.State(value=None)
414
+
415
+ pdf_processor = PDFProcessor()
416
+
417
  # Event handlers for Study Analysis tab
418
  process_zotero_btn.click(
419
  process_zotero_library_items,
 
457
  if not message.strip():
458
  raise gr.Error("Please enter a message")
459
  history = history + [(message, None)]
460
+ return history, "", None # Return empty preview
461
 
462
  def generate_chat_response(history, collection_id):
463
  """Generate response for the last message in history."""
464
  if not collection_id:
465
  raise gr.Error("Please upload PDFs first")
466
  if len(history) == 0:
467
+ return history, None
468
 
469
  last_message = history[-1][0]
470
  try:
471
+ updated_history, preview_image = chat_response(
472
+ last_message,
473
+ history[:-1],
474
+ collection_id,
475
+ pdf_processor,
476
+ )
477
+ return updated_history, preview_image
478
  except Exception as e:
479
  logger.error(f"Error in generate_chat_response: {str(e)}")
480
  history[-1] = (last_message, f"Error: {str(e)}")
481
+ return history, None
 
482
 
483
  # Update PDF event handlers
484
  upload_btn.click( # Change from pdf_files.upload to upload_btn.click
 
491
  chat_submit_btn.click(
492
  add_message,
493
  inputs=[chat_history, query_input],
494
+ outputs=[chat_history, query_input, pdf_preview],
495
  ).success(
496
  generate_chat_response,
497
  inputs=[chat_history, current_collection],
498
+ outputs=[chat_history, pdf_preview],
499
  )
500
 
501
  return demo
interface.py CHANGED
@@ -3,6 +3,7 @@ Gradio interface module for ACRES RAG Platform.
3
  Defines the UI components and layout.
4
  """
5
 
 
6
  import gradio as gr
7
 
8
 
 
3
  Defines the UI components and layout.
4
  """
5
 
6
+ # interface.py
7
  import gradio as gr
8
 
9
 
rag/rag_pipeline.py CHANGED
@@ -10,6 +10,8 @@ from llama_index.embeddings.openai import OpenAIEmbedding
10
  from llama_index.llms.openai import OpenAI
11
  from llama_index.vector_stores.chroma import ChromaVectorStore
12
  import chromadb
 
 
13
 
14
  logging.basicConfig(level=logging.INFO)
15
 
@@ -27,7 +29,6 @@ class RAGPipeline:
27
  self.documents = None
28
  self.client = chromadb.Client()
29
  self.collection = self.client.get_or_create_collection(self.collection_name)
30
- # Embed and store each node in ChromaDB
31
  self.embedding_model = OpenAIEmbedding(model_name="text-embedding-ada-002")
32
  self.load_documents()
33
  self.build_index()
@@ -50,9 +51,12 @@ class RAGPipeline:
50
  "authors": ", ".join(doc_data.get("authors", [])),
51
  "year": doc_data.get("date"),
52
  "doi": doc_data.get("doi"),
 
 
 
 
53
  }
54
 
55
- # Append document data for use in ChromaDB indexing
56
  self.documents.append(
57
  Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
58
  )
@@ -83,7 +87,7 @@ class RAGPipeline:
83
 
84
  def query(
85
  self, context: str, prompt_template: PromptTemplate = None
86
- ) -> Dict[str, Any]:
87
  if prompt_template is None:
88
  prompt_template = PromptTemplate(
89
  "Context information is below.\n"
@@ -98,9 +102,7 @@ class RAGPipeline:
98
  "Ensure that EVERY statement from the context is properly cited."
99
  )
100
 
101
- # This is a hack to index all the documents in the store :)
102
  n_documents = len(self.index.docstore.docs)
103
- print(f"n_documents: {n_documents}")
104
  query_engine = self.index.as_query_engine(
105
  text_qa_template=prompt_template,
106
  similarity_top_k=n_documents if n_documents <= 17 else 15,
@@ -108,7 +110,18 @@ class RAGPipeline:
108
  llm=OpenAI(model="gpt-4o-mini"),
109
  )
110
 
111
- # Perform the query
112
  response = query_engine.query(context)
113
 
114
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from llama_index.llms.openai import OpenAI
11
  from llama_index.vector_stores.chroma import ChromaVectorStore
12
  import chromadb
13
+ from typing import Dict, Any, List, Tuple
14
+
15
 
16
  logging.basicConfig(level=logging.INFO)
17
 
 
29
  self.documents = None
30
  self.client = chromadb.Client()
31
  self.collection = self.client.get_or_create_collection(self.collection_name)
 
32
  self.embedding_model = OpenAIEmbedding(model_name="text-embedding-ada-002")
33
  self.load_documents()
34
  self.build_index()
 
51
  "authors": ", ".join(doc_data.get("authors", [])),
52
  "year": doc_data.get("date"),
53
  "doi": doc_data.get("doi"),
54
+ "source_file": doc_data.get("source_file"), # Add source file path
55
+ "page_numbers": list(
56
+ doc_data.get("pages", {}).keys()
57
+ ), # Add page numbers
58
  }
59
 
 
60
  self.documents.append(
61
  Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
62
  )
 
87
 
88
  def query(
89
  self, context: str, prompt_template: PromptTemplate = None
90
+ ) -> Tuple[str, Dict[str, Any]]:
91
  if prompt_template is None:
92
  prompt_template = PromptTemplate(
93
  "Context information is below.\n"
 
102
  "Ensure that EVERY statement from the context is properly cited."
103
  )
104
 
 
105
  n_documents = len(self.index.docstore.docs)
 
106
  query_engine = self.index.as_query_engine(
107
  text_qa_template=prompt_template,
108
  similarity_top_k=n_documents if n_documents <= 17 else 15,
 
110
  llm=OpenAI(model="gpt-4o-mini"),
111
  )
112
 
 
113
  response = query_engine.query(context)
114
 
115
+ # Extract source information from the response nodes
116
+ source_info = {}
117
+ if hasattr(response, "source_nodes") and response.source_nodes:
118
+ source_node = response.source_nodes[0] # Get the most relevant source
119
+ metadata = source_node.metadata
120
+ source_info = {
121
+ "source_file": metadata.get("source_file"),
122
+ "page_numbers": metadata.get("page_numbers", []),
123
+ "title": metadata.get("title"),
124
+ "authors": metadata.get("authors"),
125
+ }
126
+
127
+ return response.response, source_info
study_files.json CHANGED
@@ -10,5 +10,4 @@
10
  "iom": "data/iom_zotero_items.json",
11
  "ExportedRis_file_1_of_1 (1)": "data/exportedris-file-1-of-1-1_zotero_items.json",
12
  "wb_1813-9450-6689": "data/wb-1813-9450-6689_zotero_items.json",
13
- "kayongo papers": "data/kayongo-papers_zotero_items.json"
14
  }
 
10
  "iom": "data/iom_zotero_items.json",
11
  "ExportedRis_file_1_of_1 (1)": "data/exportedris-file-1-of-1-1_zotero_items.json",
12
  "wb_1813-9450-6689": "data/wb-1813-9450-6689_zotero_items.json",
 
13
  }
utils/pdf_processor.py CHANGED
@@ -3,6 +3,7 @@ PDF processing module for ACRES RAG Platform.
3
  Handles PDF file processing, text extraction, and page rendering.
4
  """
5
 
 
6
  import os
7
  import fitz
8
  import logging
 
3
  Handles PDF file processing, text extraction, and page rendering.
4
  """
5
 
6
+ # utils/pdf_processor.py
7
  import os
8
  import fitz
9
  import logging