File size: 3,649 Bytes
af11e83
14a4318
 
 
 
af11e83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json
from typing import Any, Dict, List, Optional

from llama_index.core import Document, PromptTemplate, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI


class RAGPipeline:
    """Question answering over a JSON file of study records.

    On construction the JSON file is read, every record is turned into a
    ``Document`` (title, abstract and authors as text; title, authors, year
    and DOI as metadata) and a ``VectorStoreIndex`` is built over
    sentence-window nodes. ``query`` then answers a question against that
    index with an OpenAI chat model.
    """

    # Retrieval cut-off used by ``query``: while the docstore holds at most
    # ``_RETRIEVE_ALL_LIMIT`` entries, every entry is retrieved; above that
    # only the ``_CAPPED_TOP_K`` most similar entries are.
    # NOTE(review): the limit (17) is larger than the cap (15), so top-k drops
    # from 17 to 15 when the store grows from 17 to 18 entries. Behaviour is
    # kept as found -- confirm this is intended.
    _RETRIEVE_ALL_LIMIT = 17
    _CAPPED_TOP_K = 15

    # Prompt used by ``query`` when the caller does not supply one.
    _DEFAULT_QA_PROMPT = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given this information, please answer the question: {query_str}\n"
        "Provide an answer to the question using evidence from the context above. "
        "Cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc. "
        "Even if there's only one source, still include the citation. "
        "If you're unsure about a source, use [?]. "
        "Ensure that EVERY statement from the context is properly cited."
    )

    def __init__(self, study_json, use_semantic_splitter=False):
        """Load the studies in ``study_json`` and build the vector index.

        Args:
            study_json: Path of the JSON file to read. It is iterated record
                by record, so it is expected to hold a sequence of objects.
            use_semantic_splitter: Stored on the instance. Nothing in this
                class reads it; ``build_index`` always uses the
                sentence-window parser.
        """
        self.study_json = study_json
        self.use_semantic_splitter = use_semantic_splitter
        self.data = None
        self.documents = None
        self.index = None
        self.load_documents()
        self.build_index()

    @staticmethod
    def _format_document_text(doc_data: Dict[str, Any]) -> str:
        """Render one study record as the text that gets indexed.

        Missing or ``None`` fields become empty strings instead of raising,
        which matches the tolerant ``.get`` lookups used for the metadata.
        A bare string in ``authors`` is treated as a single author rather
        than being joined character by character.
        """

        def _field(key: str) -> Any:
            value = doc_data.get(key)
            return "" if value is None else value

        authors = doc_data.get("authors") or []
        if isinstance(authors, str):
            authors = [authors]

        # NOTE: the record's ``full_text`` field is currently left out of the
        # indexed text; only title, abstract and authors are embedded.
        return (
            f"Title: {_field('title')}\n"
            f"Abstract: {_field('abstract')}\n"
            f"Authors: {', '.join(str(author) for author in authors)}\n"
        )

    @staticmethod
    def _build_metadata(doc_data: Dict[str, Any]) -> Dict[str, Any]:
        """Pick the metadata stored alongside a record's document."""
        return {
            "title": doc_data.get("title"),
            "authors": doc_data.get("authors", []),
            # The record's "date" field is exposed under the key "year".
            "year": doc_data.get("date"),
            "doi": doc_data.get("doi"),
        }

    @classmethod
    def _similarity_top_k(cls, n_entries: int) -> int:
        """Return how many entries ``query`` should retrieve for a store of ``n_entries``."""
        if n_entries <= cls._RETRIEVE_ALL_LIMIT:
            return n_entries
        return cls._CAPPED_TOP_K

    def load_documents(self):
        """Read ``self.study_json`` into ``self.data`` and ``self.documents``.

        Does nothing when documents have already been loaded. Each record
        becomes one ``Document`` whose id is ``doc_<position>``.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        if self.documents is not None:
            return

        # JSON text is UTF-8 (RFC 8259); do not rely on the platform's
        # default encoding, which differs between systems.
        with open(self.study_json, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        # Built in one expression so that a failure part-way through leaves
        # ``self.documents`` as None and a later call can retry.
        self.documents = [
            Document(
                text=self._format_document_text(doc_data),
                id_=f"doc_{index}",
                metadata=self._build_metadata(doc_data),
            )
            for index, doc_data in enumerate(self.data)
        ]

    def build_index(self):
        """Parse the loaded documents into nodes and build ``self.index``.

        Does nothing when the index already exists. Nodes come from a
        ``SentenceWindowNodeParser`` (window size 5) and are embedded with
        ``OpenAIEmbedding``.
        """
        if self.index is not None:
            return

        sentence_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=20)

        def _split(text: str) -> List[str]:
            # Adapter: the node parser takes a plain ``str -> list[str]`` callable.
            return sentence_splitter.split_text(text)

        node_parser = SentenceWindowNodeParser.from_defaults(
            sentence_splitter=_split,
            window_size=5,
            window_metadata_key="window",
            original_text_metadata_key="original_text",
        )

        nodes = node_parser.get_nodes_from_documents(self.documents)
        self.index = VectorStoreIndex(
            nodes, embed_model=OpenAIEmbedding(model_name="text-embedding-3-large")
        )

    def query(
        self, context: str, prompt_template: Optional["PromptTemplate"] = None
    ) -> Any:
        """Answer a question against the index.

        Args:
            context: The question text; it is passed unchanged to the query
                engine's ``query`` method.
            prompt_template: Prompt handed to the query engine as
                ``text_qa_template``. When omitted, a default prompt that
                demands bracketed citations is used.

        Returns:
            The response object returned by the query engine (not a plain
            dict).
        """
        if prompt_template is None:
            prompt_template = PromptTemplate(self._DEFAULT_QA_PROMPT)

        # This is a hack to retrieve everything in the store when it is small.
        # NOTE(review): this counts docstore entries, which are presumably the
        # parsed nodes rather than the source documents -- confirm.
        n_documents = len(self.index.docstore.docs)
        print(f"n_documents: {n_documents}")

        # NOTE(review): with response_mode="tree_summarize" the synthesizer
        # appears to take its prompt from ``summary_template`` rather than
        # ``text_qa_template`` -- confirm the citation prompt is applied.
        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template,
            similarity_top_k=self._similarity_top_k(n_documents),
            response_mode="tree_summarize",
            llm=OpenAI(model="gpt-4o-mini"),
        )

        return query_engine.query(context)