File size: 10,463 Bytes
58c81e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9618bfc
58c81e4
 
 
 
 
 
 
f81204d
8db9985
58c81e4
5abd48d
58c81e4
 
 
a2ff208
a26ec13
58c81e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f81204d
58c81e4
 
 
 
 
 
 
 
 
5abd48d
e1fa991
 
58c81e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1fbaf9
58c81e4
 
 
 
 
8af0ee8
5abd48d
a2ff208
 
5abd48d
58c81e4
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import os
from neo4j import GraphDatabase, Result
import pandas as pd
import numpy as np

from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.graphs import Neo4jGraph
from langchain_community.vectorstores import Neo4jVector

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from langchain_huggingface import HuggingFaceEndpoint

from typing import Dict, Any
from tqdm import tqdm
from transformers import AutoTokenizer

# Neo4j connection settings, read from the environment at import time.
# os.getenv returns None for any variable that is not set; nothing here
# validates them, so a missing value only surfaces when a retriever runs.
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
# Name of the Neo4j vector index handed to Neo4jVector.from_existing_index
# in local_retriever.
vector_index = os.getenv("VECTOR_INDEX")

# Shared text-generation endpoint used by the map and reduce chains in
# global_retriever.  Sampling is switched off (do_sample=False) and each
# call may generate up to 4096 new tokens.
chat_llm = HuggingFaceEndpoint(
    # repo_id="HuggingFaceH4/zephyr-7b-beta",
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    task="text-generation",
    max_new_tokens=4096,
    do_sample=False,
)

# global_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")


def local_retriever(query: str):
    """Retrieve local graph context for ``query`` through the Neo4j vector index.

    The query text is embedded with ``nomic-ai/nomic-embed-text-v1`` (on CPU,
    normalized embeddings) and matched against the index named by the
    module-level ``vector_index``.  The best-matching nodes are collected
    and expanded by a Cypher retrieval query into a single record with four
    sections: ``Chunks``, ``Reports`` (community summaries),
    ``Relationships`` (outside + inside) and ``Entities``.

    Args:
        query: Natural-language search text.

    Returns:
        The first document returned by the similarity search.  If anything
        fails (embedding model load, connection, query) or the search
        returns no documents, a string of the form ``"Error: <message>"``
        is returned instead of raising.
    """
    # Per-section limits, bound to the $-parameters of the Cypher query below.
    top_chunks = 3
    top_communities = 3
    top_outside_rels = 10
    top_inside_rels = 10
    # Number of nearest-neighbour nodes requested from the vector index.
    top_entities = 10

    try:
        # `node` is supplied by Neo4jVector's similarity search; this query is
        # appended to it and must return `text`, `score` and `metadata`.
        lc_retrieval_query = """
        WITH collect(node) as nodes
        // Entity - Text Unit Mapping
        WITH
        collect {
            UNWIND nodes as n
            MATCH (n)<-[:HAS_ENTITY]->(c:__Chunk__)
            WITH c, count(distinct n) as freq
            RETURN c.text AS chunkText
            ORDER BY freq DESC
            LIMIT $topChunks
        } AS text_mapping,
        // Entity - Report Mapping
        collect {
            UNWIND nodes as n
            MATCH (n)-[:IN_COMMUNITY]->(c:__Community__)
            WITH c, c.rank as rank, c.weight AS weight
            RETURN c.summary 
            ORDER BY rank, weight DESC
            LIMIT $topCommunities
        } AS report_mapping,
        // Outside Relationships 
        collect {
            UNWIND nodes as n
            MATCH (n)-[r:RELATED]-(m) 
            WHERE NOT m IN nodes
            RETURN r.description AS descriptionText
            ORDER BY r.rank, r.weight DESC 
            LIMIT $topOutsideRels
        } as outsideRels,
        // Inside Relationships 
        collect {
            UNWIND nodes as n
            MATCH (n)-[r:RELATED]-(m) 
            WHERE m IN nodes
            RETURN r.description AS descriptionText
            ORDER BY r.rank, r.weight DESC 
            LIMIT $topInsideRels
        } as insideRels,
        // Entities description
        collect {
            UNWIND nodes as n
            RETURN n.description AS descriptionText
        } as entities
        // We don't have covariates or claims here
        RETURN {Chunks: text_mapping, Reports: report_mapping, 
            Relationships: outsideRels + insideRels, 
            Entities: entities} AS text, 1.0 AS score, {} AS metadata
        """

        embedding_model = HuggingFaceBgeEmbeddings(
            model_name="nomic-ai/nomic-embed-text-v1",
            model_kwargs={"device": "cpu", "trust_remote_code": True},
            encode_kwargs={"normalize_embeddings": True},
        )

        # Neo4jVector opens its own connection from these credentials, so no
        # separate neo4j driver is created here.
        vector_store = Neo4jVector.from_existing_index(
            embedding_model,
            url=NEO4J_URI,
            username=NEO4J_USERNAME,
            password=NEO4J_PASSWORD,
            index_name=vector_index,
            retrieval_query=lc_retrieval_query,
        )
        docs = vector_store.similarity_search(
            query,
            k=top_entities,
            params={
                "topChunks": top_chunks,
                "topCommunities": top_communities,
                "topOutsideRels": top_outside_rels,
                "topInsideRels": top_inside_rels,
            },
        )

        # Report an empty result explicitly instead of letting docs[0] raise
        # an IndexError with an unhelpful message.
        if not docs:
            return "Error: no documents were returned for the query"
        return docs[0]
    except Exception as err:
        # Failures are reported as a string rather than raised.
        return f"Error: {err}"


def global_retriever(query: str, level: int, response_type: str):
    """Answer ``query`` with a map-reduce pass over community reports.

    Communities whose ``level`` property equals ``level`` are read from
    Neo4j.  For each processed community the "map" chain asks ``chat_llm``
    for JSON key points based on that community's ``full_content``; the
    "reduce" chain then merges all map outputs into one markdown answer.

    Args:
        query: The user's question; sent to both the map and reduce prompts.
        level: Value matched against ``c.level`` when selecting communities.
        response_type: Free-text description of the target length and
            format, inserted into the reduce prompt.

    Returns:
        The reduce chain's output string.

    Raises:
        Nothing is caught here: errors from the Neo4j query or from the LLM
        endpoint propagate to the caller.
    """
    MAP_SYSTEM_PROMPT = """
    ---Role---

    You are a helpful assistant responding to questions about data in the tables provided.

    ---Goal---

    Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables.

    You should use the data provided in the data tables below as the primary context for generating the response.
    If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.

    Each key point in the response should have the following element:
    - Description: A comprehensive description of the point.
    - Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0.

    The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".

    Points supported by data should list the relevant reports as references as follows:
    "This is an example sentence supported by data references [Data: Reports (report ids)]"

    **Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

    For example:
    "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"

    where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables.

    Do not include information where the supporting evidence for it is not provided. Always start with {{ and end with }}. 

    The response can only be JSON formatted. Do not add any text before or after the JSON-formatted string in the output.

    The response should adhere to the following format:
    {{
        "points": [
            {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}},
            {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}}
        ]
    }}

    ---Data tables---

    """
    # The community content is sent as a second system message right after
    # the "---Data tables---" heading that ends the map prompt.
    map_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                MAP_SYSTEM_PROMPT,
            ),
            ("system", "{context_data}"),
            (
                "human",
                "{question}",
            ),
        ]
    )

    map_chain = map_prompt | chat_llm | StrOutputParser()

    REDUCE_SYSTEM_PROMPT = """
    ---Role---

    You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts.


    ---Goal---

    Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset.

    Note that the analysts' reports provided below are ranked in the **descending order of importance**.

    If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up.

    The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format.

    Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.

    The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".

    The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process.

    **Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

    For example:

    "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"

    where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record.

    Do not include information where the supporting evidence for it is not provided. Style the response in markdown.


    ---Target response length and format---
    {response_type}


    ---Analyst Reports---
    {report_data}

    Add sections and commentary to the response as appropriate for the length and format. Do not add references in your answer.

    ---Real Data---
    """

    reduce_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                REDUCE_SYSTEM_PROMPT,
            ),
            (
                "human",
                "{question}",
            ),
        ]
    )

    reduce_chain = reduce_prompt | chat_llm | StrOutputParser()

    graph = Neo4jGraph(
        url=NEO4J_URI,
        username=NEO4J_USERNAME,
        password=NEO4J_PASSWORD,
        refresh_schema=False,
    )

    community_data = graph.query(
        """
    MATCH (c:__Community__)
    WHERE c.level = $level
    RETURN c.full_content AS output
    """,
        params={"level": level},
    )

    # Only the first few communities returned by the query are mapped.
    # NOTE(review): presumably a cost/latency cap on LLM calls; the query has
    # no ORDER BY, so which communities are picked is up to Neo4j. Confirm.
    max_communities = 3

    # Map step: one LLM call per community, each yielding a JSON key-point list.
    intermediate_results = [
        map_chain.invoke({"question": query, "context_data": community["output"]})
        for community in tqdm(
            community_data[:max_communities], desc="Processing communities"
        )
    ]
    print(intermediate_results)

    # Reduce step.  The map outputs are joined into plain text so the prompt
    # receives the reports themselves rather than the repr() of a Python list
    # (brackets, quotes and escaped newlines).
    final_response = reduce_chain.invoke(
        {
            "report_data": "\n\n".join(intermediate_results),
            "question": query,
            "response_type": response_type,
        }
    )
    return final_response