File size: 8,294 Bytes
7ddc93d
8121eff
af11e83
632bf0d
14a4318
 
af11e83
14a4318
632bf0d
14a4318
 
d762ede
 
af11e83
8121eff
af11e83
ff19631
8121eff
632bf0d
 
05d5b78
8121eff
05d5b78
 
 
 
 
 
b117341
af11e83
b117341
669d93a
af11e83
 
632bf0d
 
 
286d467
122cee1
 
8121eff
286d467
 
 
 
 
 
 
 
 
 
 
 
 
ff19631
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8121eff
669d93a
 
 
8121eff
669d93a
286d467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5695f84
286d467
 
 
5695f84
 
 
 
 
 
 
 
 
 
 
286d467
5695f84
 
8121eff
122cee1
af11e83
122cee1
af11e83
 
122cee1
af11e83
 
 
 
 
 
122cee1
af11e83
 
05d5b78
af11e83
 
 
 
05d5b78
 
 
5f52091
6a076b8
9a9bac9
286d467
b117341
 
8121eff
 
 
 
660cea6
5695f84
 
286d467
5695f84
8121eff
05d5b78
286d467
 
 
 
ff19631
7b9cfed
 
 
b117341
d762ede
7b9cfed
d762ede
632bf0d
b117341
122cee1
9a9bac9
b117341
286d467
 
7bb0003
5695f84
7bb0003
ff19631
286d467
 
 
 
 
 
 
 
 
 
 
 
 
7bb0003
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# rag/rag_pipeline.py
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import chromadb
from dotenv import load_dotenv
from llama_index.core import Document, PromptTemplate, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.chroma import ChromaVectorStore

# Configure the root logger at INFO level; this runs as a side effect of
# importing the module.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Load variables from a .env file (if any) before RAGPipeline reads
# OPENAI_API_KEY through os.getenv.
load_dotenv()


class RAGPipeline:
    """Question answering over a JSON export of study documents.

    The JSON file is expected to contain a list of records. When the first
    record has both ``pages`` and ``source_file`` keys the collection is
    treated as PDF-derived and one ``Document`` is created per page;
    otherwise each record is treated as a Zotero-style entry (title,
    abstract, authors) and yields a single ``Document``.

    Creating an instance loads the documents and builds the index right away.
    """

    # Accepted spellings of a page reference, tried in this order so that an
    # explicit "page N" wins over the abbreviated forms. The leading ``\b``
    # keeps words that merely end in "p"/"pg"/"page" ("top 5", "step 2")
    # from being read as page references.
    _PAGE_NUMBER_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"\bpage\s*(\d+)",
            r"\bp\.\s*(\d+)",
            r"\bp\s*(\d+)",
            r"\bpg\.\s*(\d+)",
            r"\bpg\s*(\d+)",
        )
    )

    def __init__(
        self,
        study_json: str,
        collection_name: str = "study_files_rag_collection",
        use_semantic_splitter: bool = False,
    ) -> None:
        """Load the study file and build the index.

        Args:
            study_json: Path of the JSON file holding the study records.
            collection_name: Name of the Chroma collection to get or create.
            use_semantic_splitter: Stored on the instance; no method of this
                class currently reads it.
        """
        self.study_json = study_json
        self.collection_name = collection_name
        self.use_semantic_splitter = use_semantic_splitter
        self.documents = None
        self.client = chromadb.Client()
        self.collection = self.client.get_or_create_collection(self.collection_name)
        self.embedding_model = OpenAIEmbedding(
            model_name="text-embedding-ada-002", api_key=os.getenv("OPENAI_API_KEY")
        )
        # The collection type must be known before load_documents() runs,
        # because it decides how records are turned into documents.
        self.is_pdf = self._check_if_pdf_collection()
        self.load_documents()
        self.build_index()

    def _check_if_pdf_collection(self) -> bool:
        """Return True when the first JSON record looks like a PDF record.

        A record counts as PDF-derived when it is a mapping containing both
        ``pages`` and ``source_file``. An unreadable or malformed file is
        logged and reported as ``False``.
        """
        try:
            with open(self.study_json, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error("Error checking collection type: %s", e)
            return False

        if not (isinstance(data, list) and data):
            return False
        first = data[0]
        # Require a mapping: on a plain string, ``in`` would be a substring
        # test and could report a false positive.
        return isinstance(first, dict) and "pages" in first and "source_file" in first

    def extract_page_number_from_query(self, query: str) -> Optional[int]:
        """Extract a page number such as "page 3", "p3", "p. 3" or "pg 3".

        Args:
            query: Free-text user query.

        Returns:
            The first page number found (patterns are tried in the order of
            ``_PAGE_NUMBER_PATTERNS``), or ``None`` when the query does not
            mention a page.
        """
        for pattern in self._PAGE_NUMBER_PATTERNS:
            match = pattern.search(query)
            if match:
                return int(match.group(1))
        return None

    def load_documents(self) -> None:
        """Read ``study_json`` and populate ``self.documents`` once.

        The parsed JSON is kept on ``self.data``. Calling this again after
        the documents were loaded does nothing.
        """
        if self.documents is not None:
            return

        with open(self.study_json, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        if self.is_pdf:
            self.documents = self._build_pdf_documents()
        else:
            self.documents = self._build_zotero_documents()

    def _build_pdf_documents(self) -> "List[Document]":
        """Create one Document per page of every PDF record in ``self.data``."""
        documents = []
        for index, doc_data in enumerate(self.data):
            # ``or`` guards against keys that are present but null.
            authors = ", ".join(doc_data.get("authors") or [])
            pages = doc_data.get("pages") or {}
            for page_num, page_content in pages.items():
                # A page is either a mapping with a "text" field or the
                # page text itself.
                if isinstance(page_content, dict):
                    content = page_content.get("text", "")
                else:
                    content = page_content

                doc_content = (
                    f"Title: {doc_data.get('title', '')}\n"
                    f"Page {page_num} Content:\n{content}\n"
                    f"Authors: {authors}\n"
                )
                metadata = {
                    "title": doc_data.get("title"),
                    "authors": authors,
                    "year": doc_data.get("date"),
                    "source_file": doc_data.get("source_file"),
                    "page_number": int(page_num),
                    "total_pages": doc_data.get("page_count"),
                }
                documents.append(
                    Document(
                        text=doc_content,
                        id_=f"doc_{index}_page_{page_num}",
                        metadata=metadata,
                    )
                )
        return documents

    def _build_zotero_documents(self) -> "List[Document]":
        """Create one Document per Zotero-style record in ``self.data``."""
        documents = []
        for index, doc_data in enumerate(self.data):
            # ``or`` guards against an "authors" key that is present but null.
            authors = ", ".join(doc_data.get("authors") or [])
            doc_content = (
                f"Title: {doc_data.get('title', '')}\n"
                f"Abstract: {doc_data.get('abstract', '')}\n"
                f"Authors: {authors}\n"
            )
            metadata = {
                "title": doc_data.get("title"),
                "authors": authors,
                "year": doc_data.get("date"),
                "doi": doc_data.get("doi"),
            }
            documents.append(
                Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
            )
        return documents

    def build_index(self) -> None:
        """Split the loaded documents into nodes and build ``self.index``."""
        sentence_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=20)

        node_parser = SentenceWindowNodeParser.from_defaults(
            sentence_splitter=sentence_splitter.split_text,
            window_size=5,
            window_metadata_key="window",
            original_text_metadata_key="original_text",
        )

        # Parse documents into nodes for embedding
        nodes = node_parser.get_nodes_from_documents(self.documents)

        # Wrap the existing Chroma collection
        vector_store = ChromaVectorStore(chroma_collection=self.collection)

        # NOTE(review): confirm that VectorStoreIndex honors a bare
        # ``vector_store=`` keyword; the llama-index docs attach external
        # stores through a StorageContext. query() relies on
        # ``index.docstore`` being populated, so verify that still holds
        # before changing how the store is attached.
        self.index = VectorStoreIndex(
            nodes, vector_store=vector_store, embed_model=self.embedding_model
        )

    def query(
        self, context: str, prompt_template: "Optional[PromptTemplate]" = None
    ) -> Tuple[str, Optional[Dict[str, Any]]]:
        """Answer a question against the indexed documents.

        Args:
            context: The user's question; despite the name it is passed to
                the query engine as the query string.
            prompt_template: Optional QA template with ``{context_str}`` and
                ``{query_str}`` placeholders. A citation-oriented default is
                used when omitted.

        Returns:
            A ``(answer, source_info)`` tuple. ``source_info`` describes the
            first source node (file, page, title, authors, text) for PDF
            collections and is ``None`` for non-PDF collections or when the
            response has no source nodes.
        """
        if prompt_template is None:
            prompt_template = PromptTemplate(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given this information, please answer the question: {query_str}\n"
                "Provide a detailed answer using the content from the context above. "
                "If the question asks about specific page content, make sure to include that information. "
                "Cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc. "
                "If you're unsure about something, say so rather than making assumptions."
            )

        # Extract page number for PDF documents
        requested_page = (
            self.extract_page_number_from_query(context) if self.is_pdf else None
        )

        # Retrieve every node for small stores, a fixed 15 for larger ones.
        # NOTE(review): 16-17 nodes retrieve more than 18+ nodes do; confirm
        # whether min(n_documents, 15) was intended.
        n_documents = len(self.index.docstore.docs)
        logger.info("n_documents: %s", n_documents)
        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template,
            similarity_top_k=n_documents if n_documents <= 17 else 15,
            response_mode="tree_summarize",
            llm=OpenAI(model="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY")),
        )

        response = query_engine.query(context)

        # Source details are only reported for PDF collections.
        source_info = None
        source_nodes = getattr(response, "source_nodes", None)
        if source_nodes and self.is_pdf:
            source_node = source_nodes[0]
            metadata = source_node.metadata
            # A page named in the question takes precedence over the page of
            # the retrieved node.
            page_number = (
                requested_page
                if requested_page is not None
                else metadata.get("page_number", 0)
            )
            source_info = {
                "source_file": metadata.get("source_file"),
                "page_number": page_number,
                "title": metadata.get("title"),
                "authors": metadata.get("authors"),
                "content": source_node.text,
            }

        return response.response, source_info