# rag/rag_pipeline.py
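"""Retrieval-augmented extraction and querying over study records.

Each record in the study JSON becomes a llama_index ``Document``; a
sentence-window vector index is built over 128-token chunks, and the class
exposes structured field extraction plus free-form question answering.
"""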

import json
from typing import Any, Dict, List, Optional

from llama_index.core import Document, PromptTemplate, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser


class RAGPipeline:
    def __init__(self, study_json: str, use_semantic_splitter: bool = False):
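        """Eagerly load documents and build the index on construction.

        ``use_semantic_splitter`` is stored for callers but is not yet
        consulted by ``build_index``.
        """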
        self.study_json = study_json
        self.use_semantic_splitter = use_semantic_splitter
        self.documents = None
        self.index = None
        self.load_documents()
        self.build_index()

    def load_documents(self):
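        """Read the study JSON and wrap each record as a ``Document``.

        Title, authors, and full text form the document body; bibliographic
        fields are attached as metadata. No-op if documents are already loaded.
        """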
        if self.documents is None:
            with open(self.study_json, "r") as f:
                self.data = json.load(f)

            self.documents = []

            for index, doc_data in enumerate(self.data):
                doc_content = (
                    f"Title: {doc_data['title']}\n"
                    f"Authors: {', '.join(doc_data['authors'])}\n"
                    f"Full Text: {doc_data['full_text']}"
                )

                metadata = {
                    "title": doc_data.get("title"),
                    "abstract": doc_data.get("abstract"),
                    "authors": doc_data.get("authors", []),
                    "year": doc_data.get("year"),
                    "doi": doc_data.get("doi"),
                }

                self.documents.append(
                    Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
                )

    def build_index(self):
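        """Chunk documents into sentence-window nodes and build a vector index."""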
        if self.index is None:
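            # SentenceWindowNodeParser expects a plain callable mapping raw
            # text to a list of "sentences"; here each sentence is a
            # ~128-token chunk produced by SentenceSplitter.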
            sentence_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=13)

            def _split(text: str) -> List[str]:
                return sentence_splitter.split_text(text)

            node_parser = SentenceWindowNodeParser.from_defaults(
                sentence_splitter=_split,
                window_size=3,
                window_metadata_key="window",
                original_text_metadata_key="original_text",
            )

            nodes = node_parser.get_nodes_from_documents(self.documents)
            self.index = VectorStoreIndex(nodes)

    def extract_study_info(self) -> Dict[str, Any]:
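        """Run a fixed extraction prompt and parse the reply into a dict.

        Keys are the model's field labels, lower-cased with spaces replaced
        by underscores; fields the model could not find come back as the
        literal string 'Not found'.
        """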
        extraction_prompt = PromptTemplate(
            "Based on the given context, please extract the following information about the study:\n"
            "1. Study ID\n"
            "2. Author(s)\n"
            "3. Year\n"
            "4. Title\n"
            "5. Study design\n"
            "6. Study area/region\n"
            "7. Study population\n"
            "8. Disease under study\n"
            "9. Duration of study\n"
            "If the information is not available, please respond with 'Not found' for that field.\n"
            "Context: {context_str}\n"
            "Extracted information:"
        )

        query_engine = self.index.as_query_engine(
            text_qa_template=extraction_prompt, similarity_top_k=5
        )

        response = query_engine.query("Extract study information")

        # Parse the response to extract key-value pairs
        lines = response.response.split("\n")
        extracted_info = {}
        for line in lines:
            if ":" in line:
                key, value = line.split(":", 1)
                # Drop any echoed list numbering ("1. Study ID" -> "study_id").
                key = key.strip().lstrip("0123456789. ")
                extracted_info[key.lower().replace(" ", "_")] = value.strip()

        return extracted_info

    def query(
        self, question: str, prompt_template: Optional[PromptTemplate] = None, **kwargs
    ) -> Dict[str, Any]:
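        """Answer a free-form question against the index.

        Returns the question, the model's answer, and the metadata of the
        retrieved source nodes. A default citation-style prompt is used when
        no template is supplied.
        """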
        if prompt_template is None:
            prompt_template = PromptTemplate(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given this information, please answer the question: {query_str}\n"
                "Include all relevant information from the provided context. "
                "If information comes from multiple sources, please mention all of them. "
                "If the information is not available in the context, please state that clearly. "
                "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
            )

        # Forward extra keyword arguments to the engine itself; query() takes
        # only the query string.
        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template, similarity_top_k=5, **kwargs
        )

        response = query_engine.query(question)

        return {
            "question": question,
            "answer": response.response,
            "sources": [node.metadata for node in response.source_nodes],
        }
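

# Minimal usage sketch; the JSON path and question below are illustrative
# placeholders, and this assumes llama_index's default LLM and embedding
# settings are configured (e.g. an OPENAI_API_KEY in the environment).
if __name__ == "__main__":
    pipeline = RAGPipeline("studies.json")  # hypothetical input file

    info = pipeline.extract_study_info()
    print(json.dumps(info, indent=2))

    result = pipeline.query("What disease does the study investigate?")
    print(result["answer"])
    for source in result["sources"]:
        print(source.get("title"), source.get("doi"))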