Spaces:
Sleeping
Sleeping
Refactor typing import in app.py and process each page separately in RAGPipeline
Browse files
- app.py +1 -1
- rag/rag_pipeline.py +35 -28
app.py
CHANGED
@@ -6,7 +6,7 @@ import io
|
|
6 |
import json
|
7 |
import logging
|
8 |
import os
|
9 |
-
from typing import Tuple, List
|
10 |
|
11 |
import gradio as gr
|
12 |
import openai
|
|
|
6 |
import json
|
7 |
import logging
|
8 |
import os
|
9 |
+
from typing import Tuple, List, Any
|
10 |
|
11 |
import gradio as gr
|
12 |
import openai
|
rag/rag_pipeline.py
CHANGED
@@ -40,26 +40,32 @@ class RAGPipeline:
|
|
40 |
|
41 |
self.documents = []
|
42 |
for index, doc_data in enumerate(self.data):
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
doc_data.get("
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
def build_index(self):
|
65 |
sentence_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=20)
|
@@ -95,17 +101,16 @@ class RAGPipeline:
|
|
95 |
"{context_str}\n"
|
96 |
"---------------------\n"
|
97 |
"Given this information, please answer the question: {query_str}\n"
|
98 |
-
"Provide
|
|
|
99 |
"Cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc. "
|
100 |
-
"
|
101 |
-
"If you're unsure about
|
102 |
-
"Ensure that EVERY statement from the context is properly cited."
|
103 |
)
|
104 |
|
105 |
-
n_documents = len(self.index.docstore.docs)
|
106 |
query_engine = self.index.as_query_engine(
|
107 |
text_qa_template=prompt_template,
|
108 |
-
similarity_top_k=
|
109 |
response_mode="tree_summarize",
|
110 |
llm=OpenAI(model="gpt-4o-mini"),
|
111 |
)
|
@@ -115,13 +120,15 @@ class RAGPipeline:
|
|
115 |
# Extract source information from the response nodes
|
116 |
source_info = {}
|
117 |
if hasattr(response, "source_nodes") and response.source_nodes:
|
118 |
-
|
|
|
119 |
metadata = source_node.metadata
|
120 |
source_info = {
|
121 |
"source_file": metadata.get("source_file"),
|
122 |
-
"
|
123 |
"title": metadata.get("title"),
|
124 |
"authors": metadata.get("authors"),
|
|
|
125 |
}
|
126 |
|
127 |
return response.response, source_info
|
|
|
40 |
|
41 |
self.documents = []
|
42 |
for index, doc_data in enumerate(self.data):
|
43 |
+
# Process each page's content separately
|
44 |
+
pages = doc_data.get("pages", {})
|
45 |
+
for page_num, page_content in pages.items():
|
46 |
+
doc_content = (
|
47 |
+
f"Title: {doc_data['title']}\n"
|
48 |
+
f"Page {page_num} Content:\n{page_content}\n"
|
49 |
+
f"Authors: {', '.join(doc_data['authors'])}\n"
|
50 |
+
)
|
51 |
+
|
52 |
+
metadata = {
|
53 |
+
"title": doc_data.get("title"),
|
54 |
+
"authors": ", ".join(doc_data.get("authors", [])),
|
55 |
+
"year": doc_data.get("date"),
|
56 |
+
"doi": doc_data.get("doi"),
|
57 |
+
"source_file": doc_data.get("source_file"),
|
58 |
+
"page_number": page_num, # Store single page number
|
59 |
+
"total_pages": len(pages),
|
60 |
+
}
|
61 |
+
|
62 |
+
self.documents.append(
|
63 |
+
Document(
|
64 |
+
text=doc_content,
|
65 |
+
id_=f"doc_{index}_page_{page_num}",
|
66 |
+
metadata=metadata,
|
67 |
+
)
|
68 |
+
)
|
69 |
|
70 |
def build_index(self):
|
71 |
sentence_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=20)
|
|
|
101 |
"{context_str}\n"
|
102 |
"---------------------\n"
|
103 |
"Given this information, please answer the question: {query_str}\n"
|
104 |
+
"Provide a detailed answer using the content from the context above. "
|
105 |
+
"If the question asks about specific page content, make sure to include that information. "
|
106 |
"Cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc. "
|
107 |
+
"Include page numbers in citations when available, e.g. [1, p.3]. "
|
108 |
+
"If you're unsure about something, say so rather than making assumptions."
|
|
|
109 |
)
|
110 |
|
|
|
111 |
query_engine = self.index.as_query_engine(
|
112 |
text_qa_template=prompt_template,
|
113 |
+
similarity_top_k=5, # Reduced for more focused results
|
114 |
response_mode="tree_summarize",
|
115 |
llm=OpenAI(model="gpt-4o-mini"),
|
116 |
)
|
|
|
120 |
# Extract source information from the response nodes
|
121 |
source_info = {}
|
122 |
if hasattr(response, "source_nodes") and response.source_nodes:
|
123 |
+
# Get the most relevant source
|
124 |
+
source_node = response.source_nodes[0]
|
125 |
metadata = source_node.metadata
|
126 |
source_info = {
|
127 |
"source_file": metadata.get("source_file"),
|
128 |
+
"page_number": metadata.get("page_number"),
|
129 |
"title": metadata.get("title"),
|
130 |
"authors": metadata.get("authors"),
|
131 |
+
"content": source_node.text, # Include the actual content
|
132 |
}
|
133 |
|
134 |
return response.response, source_info
|