# rag/rag_pipeline.py
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple
import chromadb
from dotenv import load_dotenv
from llama_index.core import Document, PromptTemplate, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.chroma import ChromaVectorStore

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()

class RAGPipeline:
    """Retrieval-augmented generation pipeline over a study JSON file,
    backed by a Chroma collection and OpenAI embeddings."""

    def __init__(
        self,
        study_json,
        collection_name="study_files_rag_collection",
        use_semantic_splitter=False,
    ):
        self.study_json = study_json
        self.collection_name = collection_name
        self.use_semantic_splitter = use_semantic_splitter
        self.documents = None
        self.client = chromadb.Client()
        self.collection = self.client.get_or_create_collection(self.collection_name)
        self.embedding_model = OpenAIEmbedding(
            model_name="text-embedding-ada-002", api_key=os.getenv("OPENAI_API_KEY")
        )
        self.is_pdf = self._check_if_pdf_collection()
        self.load_documents()
        self.build_index()

    def _check_if_pdf_collection(self) -> bool:
        """Check if this is a PDF collection based on the JSON structure."""
        try:
            with open(self.study_json, "r") as f:
                data = json.load(f)

            # Check first document for PDF-specific fields
            if data and isinstance(data, list) and len(data) > 0:
                return "pages" in data[0] and "source_file" in data[0]
            return False
        except Exception as e:
            logger.error(f"Error checking collection type: {str(e)}")
            return False
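
    # Expected input shapes (illustrative sketch, inferred from the fields read below):
    # PDF collection entry:    {"title": ..., "authors": [...], "date": ...,
    #                           "source_file": ..., "page_count": ..., "pages": {"1": {"text": ...}}}
    # Zotero collection entry: {"title": ..., "abstract": ..., "authors": [...], "date": ..., "doi": ...}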

    def extract_page_number_from_query(self, query: str) -> Optional[int]:
        """Extract a page number from the query text, or return None if none is found."""
        # Look for patterns like "page 3", "p3", "p. 3", etc.
        patterns = [
            r"page\s*(\d+)",
            r"p\.\s*(\d+)",
            r"p\s*(\d+)",
            r"pg\.\s*(\d+)",
            r"pg\s*(\d+)",
        ]
        for pattern in patterns:
            match = re.search(pattern, query.lower())
            if match:
                return int(match.group(1))
        return None

    def load_documents(self):
        """Load the study JSON and convert each entry into llama_index Documents."""
        if self.documents is None:
            with open(self.study_json, "r") as f:
                self.data = json.load(f)

            self.documents = []

            if self.is_pdf:
                # Handle PDF documents: one Document per page
                for index, doc_data in enumerate(self.data):
                    pages = doc_data.get("pages", {})
                    for page_num, page_content in pages.items():
                        if isinstance(page_content, dict):
                            content = page_content.get("text", "")
                        else:
                            content = page_content

                        doc_content = (
                            f"Title: {doc_data['title']}\n"
                            f"Page {page_num} Content:\n{content}\n"
                            f"Authors: {', '.join(doc_data['authors'])}\n"
                        )

                        metadata = {
                            "title": doc_data.get("title"),
                            "authors": ", ".join(doc_data.get("authors", [])),
                            "year": doc_data.get("date"),
                            "source_file": doc_data.get("source_file"),
                            "page_number": int(page_num),
                            "total_pages": doc_data.get("page_count"),
                        }

                        self.documents.append(
                            Document(
                                text=doc_content,
                                id_=f"doc_{index}_page_{page_num}",
                                metadata=metadata,
                            )
                        )
            else:
                # Handle Zotero documents: one Document per bibliographic entry
                for index, doc_data in enumerate(self.data):
                    doc_content = (
                        f"Title: {doc_data.get('title', '')}\n"
                        f"Abstract: {doc_data.get('abstract', '')}\n"
                        f"Authors: {', '.join(doc_data.get('authors', []))}\n"
                    )

                    metadata = {
                        "title": doc_data.get("title"),
                        "authors": ", ".join(doc_data.get("authors", [])),
                        "year": doc_data.get("date"),
                        "doi": doc_data.get("doi"),
                    }

                    self.documents.append(
                        Document(
                            text=doc_content, id_=f"doc_{index}", metadata=metadata
                        )
                    )

    def build_index(self):
        """Split documents into sentence-window nodes and index them in Chroma."""
        sentence_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=20)

        def _split(text: str) -> List[str]:
            return sentence_splitter.split_text(text)

        node_parser = SentenceWindowNodeParser.from_defaults(
            sentence_splitter=_split,
            window_size=5,
            window_metadata_key="window",
            original_text_metadata_key="original_text",
        )

        # Parse documents into nodes for embedding
        nodes = node_parser.get_nodes_from_documents(self.documents)

        # Wrap the existing Chroma collection and build the VectorStoreIndex on top of it
        vector_store = ChromaVectorStore(chroma_collection=self.collection)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        self.index = VectorStoreIndex(
            nodes,
            storage_context=storage_context,
            embed_model=self.embedding_model,
        )

    def query(
        self, context: str, prompt_template: Optional[PromptTemplate] = None
    ) -> Tuple[str, List[Any]]:
        """Run a query against the index and return the answer text plus source nodes."""
        if prompt_template is None:
            prompt_template = PromptTemplate(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given this information, please answer the question: {query_str}\n"
                "Follow these guidelines for your response:\n"
                "1. If the answer contains multiple pieces of information (e.g., author names, dates, statistics), "
                "present it in a markdown table format.\n"
                "2. For a single piece of information or simple answers, respond in a clear sentence.\n"
                "3. Always cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc.\n"
                "4. If the information spans multiple documents or pages, organize it by source.\n"
                "5. If you're unsure about something, say so rather than making assumptions.\n"
                "\nFormat tables like this:\n"
                "| Field | Information | Source |\n"
                "|-------|-------------|--------|\n"
                "| Title | Example Title | [1] |\n"
            )

        # Extract page number for PDF documents (informational; not used downstream yet)
        requested_page = (
            self.extract_page_number_from_query(context) if self.is_pdf else None
        )

        n_documents = len(self.index.docstore.docs)
        logger.info(f"n_documents: {n_documents}")

        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template,
            similarity_top_k=n_documents if n_documents <= 17 else 15,
            response_mode="tree_summarize",
            llm=OpenAI(model="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY")),
        )

        response = query_engine.query(context)

        # Debug logging
        logger.debug(f"Response type: {type(response)}")
        logger.debug(f"Has source_nodes: {hasattr(response, 'source_nodes')}")
        if hasattr(response, "source_nodes"):
            logger.debug(f"Number of source nodes: {len(response.source_nodes)}")

        return response.response, getattr(response, "source_nodes", [])
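

# Example usage (a minimal sketch): "data/study_files.json" is a hypothetical path,
# and OPENAI_API_KEY must be available in the environment or a .env file.
if __name__ == "__main__":
    pipeline = RAGPipeline("data/study_files.json")
    answer, sources = pipeline.query("Which authors are listed on page 2?")
    print(answer)
    print(f"{len(sources)} source nodes returned")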