Spaces: Sleeping
Refactor chat_response function to include PDF preview generation
Browse files
- app.py +41 -13
- interface.py +1 -0
- rag/rag_pipeline.py +20 -7
- study_files.json +0 -1
- utils/pdf_processor.py +1 -0
app.py
CHANGED
@@ -274,15 +274,36 @@ def process_pdf_uploads(files: List[gr.File], collection_name: str) -> str:
|
|
274 |
|
275 |
|
276 |
def chat_response(
|
277 |
-
message: str,
|
278 |
-
|
|
|
|
|
|
|
279 |
"""Generate chat response and update history."""
|
280 |
if not message.strip():
|
281 |
-
return history, None
|
282 |
|
283 |
-
|
|
|
284 |
history.append((message, response))
|
285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
|
287 |
|
288 |
def create_gr_interface() -> gr.Blocks:
|
@@ -390,6 +411,9 @@ def create_gr_interface() -> gr.Blocks:
|
|
390 |
upload_btn = gr.Button("Process PDFs", variant="primary")
|
391 |
pdf_status = gr.Markdown()
|
392 |
current_collection = gr.State(value=None)
|
|
|
|
|
|
|
393 |
# Event handlers for Study Analysis tab
|
394 |
process_zotero_btn.click(
|
395 |
process_zotero_library_items,
|
@@ -433,24 +457,28 @@ def create_gr_interface() -> gr.Blocks:
|
|
433 |
if not message.strip():
|
434 |
raise gr.Error("Please enter a message")
|
435 |
history = history + [(message, None)]
|
436 |
-
return history, ""
|
437 |
|
438 |
def generate_chat_response(history, collection_id):
|
439 |
"""Generate response for the last message in history."""
|
440 |
if not collection_id:
|
441 |
raise gr.Error("Please upload PDFs first")
|
442 |
if len(history) == 0:
|
443 |
-
return history
|
444 |
|
445 |
last_message = history[-1][0]
|
446 |
try:
|
447 |
-
|
448 |
-
|
|
|
|
|
|
|
|
|
|
|
449 |
except Exception as e:
|
450 |
logger.error(f"Error in generate_chat_response: {str(e)}")
|
451 |
history[-1] = (last_message, f"Error: {str(e)}")
|
452 |
-
|
453 |
-
return history
|
454 |
|
455 |
# Update PDF event handlers
|
456 |
upload_btn.click( # Change from pdf_files.upload to upload_btn.click
|
@@ -463,11 +491,11 @@ def create_gr_interface() -> gr.Blocks:
|
|
463 |
chat_submit_btn.click(
|
464 |
add_message,
|
465 |
inputs=[chat_history, query_input],
|
466 |
-
outputs=[chat_history, query_input],
|
467 |
).success(
|
468 |
generate_chat_response,
|
469 |
inputs=[chat_history, current_collection],
|
470 |
-
outputs=[chat_history],
|
471 |
)
|
472 |
|
473 |
return demo
|
|
|
274 |
|
275 |
|
276 |
def chat_response(
|
277 |
+
message: str,
|
278 |
+
history: List[Tuple[str, str]],
|
279 |
+
study_name: str,
|
280 |
+
pdf_processor: PDFProcessor,
|
281 |
+
) -> Tuple[List[Tuple[str, str]], str, Any]:
|
282 |
"""Generate chat response and update history."""
|
283 |
if not message.strip():
|
284 |
+
return history, None, None
|
285 |
|
286 |
+
rag = get_rag_pipeline(study_name)
|
287 |
+
response, source_info = rag.query(message)
|
288 |
history.append((message, response))
|
289 |
+
|
290 |
+
# Generate PDF preview if source information is available
|
291 |
+
preview_image = None
|
292 |
+
if (
|
293 |
+
source_info
|
294 |
+
and source_info.get("source_file")
|
295 |
+
and source_info.get("page_numbers")
|
296 |
+
):
|
297 |
+
try:
|
298 |
+
# Get the first page number from the source
|
299 |
+
page_num = source_info["page_numbers"][0]
|
300 |
+
preview_image = pdf_processor.render_page(
|
301 |
+
source_info["source_file"], int(page_num)
|
302 |
+
)
|
303 |
+
except Exception as e:
|
304 |
+
logger.error(f"Error generating PDF preview: {str(e)}")
|
305 |
+
|
306 |
+
return history, preview_image
|
307 |
|
308 |
|
309 |
def create_gr_interface() -> gr.Blocks:
|
|
|
411 |
upload_btn = gr.Button("Process PDFs", variant="primary")
|
412 |
pdf_status = gr.Markdown()
|
413 |
current_collection = gr.State(value=None)
|
414 |
+
|
415 |
+
pdf_processor = PDFProcessor()
|
416 |
+
|
417 |
# Event handlers for Study Analysis tab
|
418 |
process_zotero_btn.click(
|
419 |
process_zotero_library_items,
|
|
|
457 |
if not message.strip():
|
458 |
raise gr.Error("Please enter a message")
|
459 |
history = history + [(message, None)]
|
460 |
+
return history, "", None # Return empty preview
|
461 |
|
462 |
def generate_chat_response(history, collection_id):
|
463 |
"""Generate response for the last message in history."""
|
464 |
if not collection_id:
|
465 |
raise gr.Error("Please upload PDFs first")
|
466 |
if len(history) == 0:
|
467 |
+
return history, None
|
468 |
|
469 |
last_message = history[-1][0]
|
470 |
try:
|
471 |
+
updated_history, preview_image = chat_response(
|
472 |
+
last_message,
|
473 |
+
history[:-1],
|
474 |
+
collection_id,
|
475 |
+
pdf_processor,
|
476 |
+
)
|
477 |
+
return updated_history, preview_image
|
478 |
except Exception as e:
|
479 |
logger.error(f"Error in generate_chat_response: {str(e)}")
|
480 |
history[-1] = (last_message, f"Error: {str(e)}")
|
481 |
+
return history, None
|
|
|
482 |
|
483 |
# Update PDF event handlers
|
484 |
upload_btn.click( # Change from pdf_files.upload to upload_btn.click
|
|
|
491 |
chat_submit_btn.click(
|
492 |
add_message,
|
493 |
inputs=[chat_history, query_input],
|
494 |
+
outputs=[chat_history, query_input, pdf_preview],
|
495 |
).success(
|
496 |
generate_chat_response,
|
497 |
inputs=[chat_history, current_collection],
|
498 |
+
outputs=[chat_history, pdf_preview],
|
499 |
)
|
500 |
|
501 |
return demo
|
interface.py
CHANGED
@@ -3,6 +3,7 @@ Gradio interface module for ACRES RAG Platform.
|
|
3 |
Defines the UI components and layout.
|
4 |
"""
|
5 |
|
|
|
6 |
import gradio as gr
|
7 |
|
8 |
|
|
|
3 |
Defines the UI components and layout.
|
4 |
"""
|
5 |
|
6 |
+
# interface.py
|
7 |
import gradio as gr
|
8 |
|
9 |
|
rag/rag_pipeline.py
CHANGED
@@ -10,6 +10,8 @@ from llama_index.embeddings.openai import OpenAIEmbedding
|
|
10 |
from llama_index.llms.openai import OpenAI
|
11 |
from llama_index.vector_stores.chroma import ChromaVectorStore
|
12 |
import chromadb
|
|
|
|
|
13 |
|
14 |
logging.basicConfig(level=logging.INFO)
|
15 |
|
@@ -27,7 +29,6 @@ class RAGPipeline:
|
|
27 |
self.documents = None
|
28 |
self.client = chromadb.Client()
|
29 |
self.collection = self.client.get_or_create_collection(self.collection_name)
|
30 |
-
# Embed and store each node in ChromaDB
|
31 |
self.embedding_model = OpenAIEmbedding(model_name="text-embedding-ada-002")
|
32 |
self.load_documents()
|
33 |
self.build_index()
|
@@ -50,9 +51,12 @@ class RAGPipeline:
|
|
50 |
"authors": ", ".join(doc_data.get("authors", [])),
|
51 |
"year": doc_data.get("date"),
|
52 |
"doi": doc_data.get("doi"),
|
|
|
|
|
|
|
|
|
53 |
}
|
54 |
|
55 |
-
# Append document data for use in ChromaDB indexing
|
56 |
self.documents.append(
|
57 |
Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
|
58 |
)
|
@@ -83,7 +87,7 @@ class RAGPipeline:
|
|
83 |
|
84 |
def query(
|
85 |
self, context: str, prompt_template: PromptTemplate = None
|
86 |
-
) -> Dict[str, Any]:
|
87 |
if prompt_template is None:
|
88 |
prompt_template = PromptTemplate(
|
89 |
"Context information is below.\n"
|
@@ -98,9 +102,7 @@ class RAGPipeline:
|
|
98 |
"Ensure that EVERY statement from the context is properly cited."
|
99 |
)
|
100 |
|
101 |
-
# This is a hack to index all the documents in the store :)
|
102 |
n_documents = len(self.index.docstore.docs)
|
103 |
-
print(f"n_documents: {n_documents}")
|
104 |
query_engine = self.index.as_query_engine(
|
105 |
text_qa_template=prompt_template,
|
106 |
similarity_top_k=n_documents if n_documents <= 17 else 15,
|
@@ -108,7 +110,18 @@ class RAGPipeline:
|
|
108 |
llm=OpenAI(model="gpt-4o-mini"),
|
109 |
)
|
110 |
|
111 |
-
# Perform the query
|
112 |
response = query_engine.query(context)
|
113 |
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
from llama_index.llms.openai import OpenAI
|
11 |
from llama_index.vector_stores.chroma import ChromaVectorStore
|
12 |
import chromadb
|
13 |
+
from typing import Dict, Any, List, Tuple
|
14 |
+
|
15 |
|
16 |
logging.basicConfig(level=logging.INFO)
|
17 |
|
|
|
29 |
self.documents = None
|
30 |
self.client = chromadb.Client()
|
31 |
self.collection = self.client.get_or_create_collection(self.collection_name)
|
|
|
32 |
self.embedding_model = OpenAIEmbedding(model_name="text-embedding-ada-002")
|
33 |
self.load_documents()
|
34 |
self.build_index()
|
|
|
51 |
"authors": ", ".join(doc_data.get("authors", [])),
|
52 |
"year": doc_data.get("date"),
|
53 |
"doi": doc_data.get("doi"),
|
54 |
+
"source_file": doc_data.get("source_file"), # Add source file path
|
55 |
+
"page_numbers": list(
|
56 |
+
doc_data.get("pages", {}).keys()
|
57 |
+
), # Add page numbers
|
58 |
}
|
59 |
|
|
|
60 |
self.documents.append(
|
61 |
Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
|
62 |
)
|
|
|
87 |
|
88 |
def query(
|
89 |
self, context: str, prompt_template: PromptTemplate = None
|
90 |
+
) -> Tuple[str, Dict[str, Any]]:
|
91 |
if prompt_template is None:
|
92 |
prompt_template = PromptTemplate(
|
93 |
"Context information is below.\n"
|
|
|
102 |
"Ensure that EVERY statement from the context is properly cited."
|
103 |
)
|
104 |
|
|
|
105 |
n_documents = len(self.index.docstore.docs)
|
|
|
106 |
query_engine = self.index.as_query_engine(
|
107 |
text_qa_template=prompt_template,
|
108 |
similarity_top_k=n_documents if n_documents <= 17 else 15,
|
|
|
110 |
llm=OpenAI(model="gpt-4o-mini"),
|
111 |
)
|
112 |
|
|
|
113 |
response = query_engine.query(context)
|
114 |
|
115 |
+
# Extract source information from the response nodes
|
116 |
+
source_info = {}
|
117 |
+
if hasattr(response, "source_nodes") and response.source_nodes:
|
118 |
+
source_node = response.source_nodes[0] # Get the most relevant source
|
119 |
+
metadata = source_node.metadata
|
120 |
+
source_info = {
|
121 |
+
"source_file": metadata.get("source_file"),
|
122 |
+
"page_numbers": metadata.get("page_numbers", []),
|
123 |
+
"title": metadata.get("title"),
|
124 |
+
"authors": metadata.get("authors"),
|
125 |
+
}
|
126 |
+
|
127 |
+
return response.response, source_info
|
study_files.json
CHANGED
@@ -10,5 +10,4 @@
|
|
10 |
"iom": "data/iom_zotero_items.json",
|
11 |
"ExportedRis_file_1_of_1 (1)": "data/exportedris-file-1-of-1-1_zotero_items.json",
|
12 |
"wb_1813-9450-6689": "data/wb-1813-9450-6689_zotero_items.json",
|
13 |
-
"kayongo papers": "data/kayongo-papers_zotero_items.json"
|
14 |
}
|
|
|
10 |
"iom": "data/iom_zotero_items.json",
|
11 |
"ExportedRis_file_1_of_1 (1)": "data/exportedris-file-1-of-1-1_zotero_items.json",
|
12 |
"wb_1813-9450-6689": "data/wb-1813-9450-6689_zotero_items.json",
|
|
|
13 |
}
|
utils/pdf_processor.py
CHANGED
@@ -3,6 +3,7 @@ PDF processing module for ACRES RAG Platform.
|
|
3 |
Handles PDF file processing, text extraction, and page rendering.
|
4 |
"""
|
5 |
|
|
|
6 |
import os
|
7 |
import fitz
|
8 |
import logging
|
|
|
3 |
Handles PDF file processing, text extraction, and page rendering.
|
4 |
"""
|
5 |
|
6 |
+
# utils/pdf_processor.py
|
7 |
import os
|
8 |
import fitz
|
9 |
import logging
|