Spaces:
Sleeping
Sleeping
File size: 8,133 Bytes
7ddc93d 8121eff af11e83 632bf0d 14a4318 af11e83 14a4318 632bf0d 14a4318 d762ede af11e83 8121eff af11e83 ff19631 8121eff 632bf0d 05d5b78 8121eff 05d5b78 b117341 af11e83 b117341 669d93a af11e83 632bf0d 286d467 122cee1 8121eff 286d467 ff19631 8121eff 669d93a 8121eff 669d93a 286d467 5695f84 286d467 5695f84 286d467 5695f84 8121eff 122cee1 af11e83 122cee1 af11e83 122cee1 af11e83 122cee1 af11e83 05d5b78 af11e83 05d5b78 d0a03de 5f52091 6a076b8 9a9bac9 d0a03de b117341 d0a03de 05d5b78 286d467 ff19631 7b9cfed b117341 d762ede 7b9cfed d762ede 632bf0d b117341 122cee1 9a9bac9 d0a03de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
# rag/rag_pipeline.py
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple
import chromadb
from dotenv import load_dotenv
from llama_index.core import Document, PromptTemplate, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.chroma import ChromaVectorStore
# Module-level setup: configure root logging once and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load OPENAI_API_KEY (and any other settings) from a local .env file, if present.
load_dotenv()
class RAGPipeline:
    """Retrieval-augmented generation pipeline over a study JSON file.

    Accepts either a PDF-extraction JSON (a list of entries carrying
    "pages" and "source_file") or a Zotero-export JSON (entries with
    title/abstract/authors), embeds the documents into a Chroma
    collection, and answers natural-language queries with an OpenAI LLM.
    """

    # Retrieval breadth: use every chunk when the corpus is small,
    # otherwise fall back to a fixed top-k.
    _SMALL_CORPUS_LIMIT = 17
    _DEFAULT_TOP_K = 15

    def __init__(
        self,
        study_json,
        collection_name="study_files_rag_collection",
        use_semantic_splitter=False,
    ):
        """
        Args:
            study_json: Path to the study JSON file (PDF or Zotero format).
            collection_name: Chroma collection to create or reuse.
            use_semantic_splitter: Reserved flag; not currently used by the
                pipeline (kept for interface compatibility).
        """
        self.study_json = study_json
        self.collection_name = collection_name
        self.use_semantic_splitter = use_semantic_splitter
        self.documents = None
        self.client = chromadb.Client()
        self.collection = self.client.get_or_create_collection(self.collection_name)
        self.embedding_model = OpenAIEmbedding(
            model_name="text-embedding-ada-002", api_key=os.getenv("OPENAI_API_KEY")
        )
        self.is_pdf = self._check_if_pdf_collection()
        self.load_documents()
        self.build_index()

    def _check_if_pdf_collection(self) -> bool:
        """Return True if the study JSON looks like a PDF collection.

        A PDF collection is a non-empty list whose first entry carries both
        "pages" and "source_file"; anything else (including unreadable
        files) is treated as Zotero data.
        """
        try:
            with open(self.study_json, "r", encoding="utf-8") as f:
                data = json.load(f)
            # Truthiness already guarantees the list is non-empty.
            if data and isinstance(data, list):
                return "pages" in data[0] and "source_file" in data[0]
            return False
        except Exception as e:
            # Best-effort probe: fall back to the Zotero path on any error.
            logger.error(f"Error checking collection type: {str(e)}")
            return False

    def extract_page_number_from_query(self, query: str) -> Optional[int]:
        """Extract an explicit page reference from *query*, if any.

        Recognizes case-insensitive forms like "page 3", "p. 3", "p3",
        "pg. 3", and "pg 3".

        Returns:
            The referenced page number, or None when the query contains
            no page reference.
        """
        patterns = [
            r"page\s*(\d+)",
            r"p\.\s*(\d+)",
            r"p\s*(\d+)",
            r"pg\.\s*(\d+)",
            r"pg\s*(\d+)",
        ]
        lowered = query.lower()
        for pattern in patterns:
            match = re.search(pattern, lowered)
            if match:
                return int(match.group(1))
        return None

    def load_documents(self):
        """Load the study JSON into LlamaIndex Documents.

        Idempotent: does nothing if documents were already loaded.
        """
        if self.documents is not None:
            return
        with open(self.study_json, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.documents = []
        if self.is_pdf:
            self._load_pdf_documents()
        else:
            self._load_zotero_documents()

    def _load_pdf_documents(self):
        """Build one Document per PDF page from self.data."""
        for index, doc_data in enumerate(self.data):
            # Use .get throughout so a malformed entry degrades gracefully
            # instead of raising KeyError (matches the Zotero path).
            title = doc_data.get("title", "")
            authors = ", ".join(doc_data.get("authors", []))
            for page_num, page_content in doc_data.get("pages", {}).items():
                # Pages may be stored as {"text": ...} dicts or raw strings.
                if isinstance(page_content, dict):
                    content = page_content.get("text", "")
                else:
                    content = page_content
                doc_content = (
                    f"Title: {title}\n"
                    f"Page {page_num} Content:\n{content}\n"
                    f"Authors: {authors}\n"
                )
                metadata = {
                    "title": doc_data.get("title"),
                    "authors": authors,
                    "year": doc_data.get("date"),
                    "source_file": doc_data.get("source_file"),
                    # assumes page keys are numeric strings — TODO confirm
                    "page_number": int(page_num),
                    "total_pages": doc_data.get("page_count"),
                }
                self.documents.append(
                    Document(
                        text=doc_content,
                        id_=f"doc_{index}_page_{page_num}",
                        metadata=metadata,
                    )
                )

    def _load_zotero_documents(self):
        """Build one Document per Zotero entry from self.data."""
        for index, doc_data in enumerate(self.data):
            authors = ", ".join(doc_data.get("authors", []))
            doc_content = (
                f"Title: {doc_data.get('title', '')}\n"
                f"Abstract: {doc_data.get('abstract', '')}\n"
                f"Authors: {authors}\n"
            )
            metadata = {
                "title": doc_data.get("title"),
                "authors": authors,
                "year": doc_data.get("date"),
                "doi": doc_data.get("doi"),
            }
            self.documents.append(
                Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
            )

    def build_index(self):
        """Chunk the loaded documents and build the Chroma-backed index."""
        sentence_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=20)

        def _split(text: str) -> List[str]:
            return sentence_splitter.split_text(text)

        node_parser = SentenceWindowNodeParser.from_defaults(
            sentence_splitter=_split,
            window_size=5,
            window_metadata_key="window",
            original_text_metadata_key="original_text",
        )
        # Parse documents into nodes for embedding.
        nodes = node_parser.get_nodes_from_documents(self.documents)
        # Wrap the existing Chroma collection and index the nodes into it.
        vector_store = ChromaVectorStore(chroma_collection=self.collection)
        self.index = VectorStoreIndex(
            nodes, vector_store=vector_store, embed_model=self.embedding_model
        )

    def query(
        self, context: str, prompt_template: Optional["PromptTemplate"] = None
    ) -> Tuple[str, List[Any]]:
        """Answer a user question against the index.

        Args:
            context: The user's question.
            prompt_template: Optional custom QA prompt; a default
                citation-oriented prompt is used when omitted.

        Returns:
            A ``(response_text, source_nodes)`` tuple; ``source_nodes`` is
            an empty list when the engine provides none.
        """
        if prompt_template is None:
            prompt_template = PromptTemplate(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given this information, please answer the question: {query_str}\n"
                "Follow these guidelines for your response:\n"
                "1. If the answer contains multiple pieces of information (e.g., author names, dates, statistics), "
                "present it in a markdown table format.\n"
                "2. For single piece information or simple answers, respond in a clear sentence.\n"
                "3. Always cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc.\n"
                "4. If the information spans multiple documents or pages, organize it by source.\n"
                "5. If you're unsure about something, say so rather than making assumptions.\n"
                "\nFormat tables like this:\n"
                "| Field | Information | Source |\n"
                "|-------|-------------|--------|\n"
                "| Title | Example Title | [1] |\n"
            )
        # Page references only make sense for PDF corpora.
        requested_page = (
            self.extract_page_number_from_query(context) if self.is_pdf else None
        )
        # NOTE(review): requested_page is computed but not yet applied as a
        # metadata filter on retrieval — TODO wire it into the query engine.
        if requested_page is not None:
            logger.info("Query references page %s", requested_page)
        n_documents = len(self.index.docstore.docs)
        logger.info("n_documents: %s", n_documents)
        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template,
            similarity_top_k=(
                n_documents
                if n_documents <= self._SMALL_CORPUS_LIMIT
                else self._DEFAULT_TOP_K
            ),
            response_mode="tree_summarize",
            llm=OpenAI(model="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY")),
        )
        response = query_engine.query(context)
        # Debug logging of the engine's response shape.
        logger.debug("Response type: %s", type(response))
        logger.debug("Has source_nodes: %s", hasattr(response, "source_nodes"))
        if hasattr(response, "source_nodes"):
            logger.debug("Number of source nodes: %s", len(response.source_nodes))
        return response.response, getattr(response, "source_nodes", [])