File size: 7,229 Bytes
854f61d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import functools
import logging
import os
import re
from typing import Any, Dict, List, Optional, Type

from decouple import config
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.chains import GraphCypherQAChain
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool
from langchain_community.graphs import Neo4jGraph
from neo4j.exceptions import ServiceUnavailable
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

# from neo4j_semantic_layer import agent_executor as neo4j_semantic_layer_chain

# add_routes(app, neo4j_semantic_layer_chain, path="\neo4j-semantic-layer")

# Neo4j connection settings, read through python-decouple's `config`.
NEO4J_URL = config("NEO4J_URL")
# Username is now configurable; it falls back to the previously hard-coded "neo4j".
NEO4J_USER = config("NEO4J_USER", default="neo4j")
NEO4J_PASSWORD = config("NEO4J_PASSWORD")


# Characters with special meaning in Lucene query syntax. "/" delimits regular
# expressions, so an unmatched slash (e.g. "AC/DC") makes a query unparseable.
_LUCENE_SPECIAL_CHARS = '+-&|!(){}[]^"~*?:\\/'
# Translation table mapping each special character to a single space.
_LUCENE_TO_SPACE = str.maketrans(_LUCENE_SPECIAL_CHARS, " " * len(_LUCENE_SPECIAL_CHARS))


def remove_lucene_chars(text: str) -> str:
    """Replace Lucene query-syntax special characters with spaces.

    Each special character is replaced by one space in a single pass, and
    leading/trailing whitespace is stripped from the result.

    Args:
        text: Raw text, typically an entity name taken from a user question.

    Returns:
        The sanitized text; an empty string if nothing but special characters
        and whitespace was supplied.
    """
    return text.translate(_LUCENE_TO_SPACE).strip()


# TODO: For now (8/24/2024), this search query works for strings written in Latin alphabet, but not
# any other alphabet. Follow-up action item: create a custom Neo4j Lucene analyzer for mixed-language
# data (or find one open-source somewhere), re-index the data, and then add {analyzer: "my_analyzer"}
# to the candidate_query string below. 
def generate_full_text_query(input: str, similarity: float = 0.7) -> str:
    """
    Generate a fuzzy full-text search query for a given input string.

    The input is sanitized with ``remove_lucene_chars`` and split on
    whitespace. Each word gets a fuzzy-match suffix (``~0.7`` by default), and
    the terms are joined with the Lucene ``AND`` operator, so every word must
    approximately match. This maps people and organizations named in user
    questions to database values while tolerating some misspellings.

    Args:
        input: Free-text name or phrase to search for.
        similarity: Fuzzy-match parameter appended to each word after ``~``.

    Returns:
        The query string, e.g. ``"Alice~0.7 AND Bob~0.7"``. Returns an empty
        string when the input contains no searchable words; previously this
        case raised ``IndexError``.
    """
    words = remove_lucene_chars(input).split()
    return " AND ".join(f"{word}~{similarity}" for word in words)


# Cypher: fuzzy full-text lookup of candidate entities.
# Parameters: $index (name of the full-text index to search), $fulltextQuery
# (Lucene query string, as built by generate_full_text_query) and $limit
# (maximum number of hits, passed as the procedure's `limit` option).
# Each row carries the node's name and summary plus `label`. `label` is the
# list returned by labels(node), e.g. ["PERSON"], not a single string.
candidate_query = """

CALL db.index.fulltext.queryNodes($index, $fulltextQuery, {limit: $limit})

YIELD node

RETURN node.name AS name, 

        node.summary AS summary,

        labels(node) AS label

        

"""

# Cypher: detailed record for one PERSON node matched by exact name ($name).
# NOTE(review): this uses MATCH, not OPTIONAL MATCH. A person needs an
# IN_CATEGORY relationship (either direction) to a CATEGORY node, or the query
# yields no row. If a person has several categories, LIMIT 1 without ORDER BY
# makes policy_category an unspecified one of them.
person_description_query = """

MATCH (e:PERSON)-[r:IN_CATEGORY]-(m:CATEGORY)

WHERE e.name = $name

RETURN e.name AS name,

       e.gender AS gender, 

       e.summary AS summary, 

       m.name AS policy_category 

LIMIT 1

"""

# Cypher: detailed record for one ORGANIZATION node matched by exact name ($name).
# NOTE(review): this uses MATCH, not OPTIONAL MATCH. An organization needs an
# IN_CATEGORY relationship (either direction) to a CATEGORY node, or the query
# yields no row. If it has several categories, LIMIT 1 without ORDER BY makes
# policy_category an unspecified one of them.
organization_description_query = """

MATCH (o:ORGANIZATION)-[r:IN_CATEGORY]-(m:CATEGORY)

WHERE o.name = $name

RETURN o.name AS name, 

       o.summary AS summary, 

       o.description AS description, 

       o.twitterUri AS twitter_uri,

       o.homepageUri AS homepage_uri,

       m.name AS policy_category

LIMIT 1

"""

@functools.lru_cache(maxsize=1)
def _get_graph() -> Neo4jGraph:
    """Return a shared Neo4jGraph, creating it on first use.

    Each Neo4jGraph owns a Neo4j driver with its own connection pool. Drivers
    are expensive to create and meant to be shared, so one instance is reused
    rather than building (and never closing) a new one per query. If
    construction raises, nothing is cached and the next call tries again.
    """
    return Neo4jGraph(NEO4J_URL, NEO4J_USER, NEO4J_PASSWORD)


@retry(
    retry=retry_if_exception_type(ServiceUnavailable),
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=1, min=4, max=8),
)
def execute_query(
    query: str, params: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """Run a Cypher query against Neo4j and return the result rows.

    Retries on ``ServiceUnavailable`` for up to 4 attempts in total. Waits
    between attempts use exponential backoff clamped to 4-8 seconds. If every
    attempt fails, tenacity raises ``RetryError``.

    Args:
        query: Cypher statement to execute.
        params: Query parameters; ``None`` means no parameters.

    Returns:
        The rows returned by ``Neo4jGraph.query``, one dict per record.
    """
    return _get_graph().query(query, params or {})


def get_candidates(input: str, index: str, limit: int = 3) -> List[Dict[str, Any]]:
    """
    Retrieve candidate entities from the database that match the input string.

    A fuzzy full-text query is built from ``input`` and run against the
    full-text index named ``index``.

    Args:
        input: Name (or partial name) to search for.
        index: Name of the Neo4j full-text index to query.
        limit: Maximum number of candidates to return.

    Returns:
        One dict per hit with keys ``name``, ``summary`` and ``label``, where
        ``label`` is the node's list of labels (e.g. ``["PERSON"]``). Returns
        an empty list, without querying the database, when ``input`` holds no
        searchable words (empty, or only Lucene special characters).
    """
    if not remove_lucene_chars(input):
        # No words remain after sanitizing, so there is nothing to search for.
        return []
    ft_query = generate_full_text_query(input)
    logging.getLogger(__name__).debug("Full-text query: %s", ft_query)
    return execute_query(
        candidate_query, {"fulltextQuery": ft_query, "index": index, "limit": limit}
    )


def get_information(entity: str, index: str) -> Optional[Dict[str, Any]]:
    """
    Look up an entity by name and return its detailed record.

    Runs a fuzzy full-text search for ``entity`` on ``index`` and takes the
    first hit. It then fetches that hit's description with the person query
    when ``index`` is ``"dangerous_individuals"``, and with the organization
    query otherwise.

    Args:
        entity: Name (or partial name) of the person or organization.
        index: Full-text index to search, e.g. ``"dangerous_individuals"`` or
            ``"dangerous_organizations"``.

    Returns:
        A dict keyed by the query's column aliases, rewritten with spaces and
        title case (e.g. ``"Policy Category"``). Returns ``None`` when there is
        no candidate or the description query returns no row.
    """
    candidates = get_candidates(entity, index)
    if not candidates:
        return None
    # NOTE(review): the tool description asks the agent to disambiguate
    # multiple matches, but only the first hit is looked up here.
    top_candidate = candidates[0]
    description_query = (
        person_description_query
        if index == "dangerous_individuals"
        else organization_description_query
    )
    rows = execute_query(description_query, params={"name": top_candidate["name"]})
    if not rows:
        return None
    return {key.replace("_", " ").title(): value for key, value in rows[0].items()}
 

class InformationInput(BaseModel):
    # Argument schema for DIOInformationTool (used as its args_schema).
    # Field descriptions are shown to the LLM, so keep them precise.
    entity: str = Field(description="full-text search query of the name of a given entity mentioned in the question. Example: 'Alice Bob'")
    entity_type: str = Field(
        description="indexed list to search for membership by the entity. Available options are 'dangerous_organizations' and 'dangerous_individuals'"
    )


class DIOInformationTool(BaseTool):
    """LangChain tool that looks up a person or organization via get_information."""

    name = "Dangerous_Individuals_And_Organizations_Information"
    description = (
        "useful for when you need to answer questions about various elected officials or persons running for office. "
        "Never generate a final answer to the question if multiple candidates were matched by name; in those cases, "
        "always present the candidate options as a bulleted list and ask for disambiguation in your <avi_answer> tag. Never embellish your descriptions of "
        "the candidate in question or assume their motivations from their professional activities under any circumstances; "
        "remain 100% strictly fact-focused at all times with the outputs of this tool."
    )
    args_schema: Type[BaseModel] = InformationInput

    def _run(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Optional[dict]:
        """Look up ``entity`` in the ``entity_type`` index (blocking)."""
        return get_information(entity, entity_type)

    async def _arun(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Optional[dict]:
        """Look up ``entity`` without blocking the event loop.

        ``get_information`` does synchronous database I/O, and its retry
        policy sleeps for several seconds between attempts. It therefore runs
        in the loop's default thread-pool executor, not on the event loop.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, get_information, entity, entity_type)