File size: 7,229 Bytes
854f61d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import asyncio
import os
import re
from typing import List, Optional, Dict, Type

from decouple import config
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.chains import GraphCypherQAChain
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool
from langchain_community.graphs import Neo4jGraph
from neo4j.exceptions import ServiceUnavailable
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

# from neo4j_semantic_layer import agent_executor as neo4j_semantic_layer_chain
# add_routes(app, neo4j_semantic_layer_chain, path="\neo4j-semantic-layer")
# Neo4j connection settings. The URL and password are looked up through
# `config` (python-decouple) when this module is imported; the username is
# fixed in code.
NEO4J_URL = config("NEO4J_URL")
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = config("NEO4J_PASSWORD")
# Characters that carry special meaning in Lucene's classic query syntax.
# "/" is included because Lucene treats it as a regular-expression delimiter,
# so a name such as "AC/DC" would otherwise produce an unparsable query.
_LUCENE_SPECIAL_CHARS = '+-&|!(){}[]^"~*?:\\/'

# Translation table that maps every special character to a single space.
_LUCENE_CHAR_TABLE = str.maketrans(
    _LUCENE_SPECIAL_CHARS, " " * len(_LUCENE_SPECIAL_CHARS)
)


def remove_lucene_chars(text: str) -> str:
    """Replace Lucene special characters with spaces.

    Each special character is replaced by one space (not deleted) so that
    tokens joined by punctuation, e.g. ``"Bob-Smith"``, stay separate words.
    Leading and trailing whitespace is stripped from the result.

    Args:
        text: Raw user-supplied search text.

    Returns:
        The text with every Lucene special character replaced by a space and
        surrounding whitespace removed. May be empty.
    """
    # One pass over the string instead of one `replace` call per character.
    return text.translate(_LUCENE_CHAR_TABLE).strip()
# TODO: As of 8/24/2024, this search query works only for strings written in the Latin alphabet,
# not in any other alphabet. Follow-up action item: create a custom Neo4j Lucene analyzer for
# mixed-language data (or find an open-source one), re-index the data, and then add
# {analyzer: "my_analyzer"} to the candidate_query string below.
def generate_full_text_query(input: str) -> str:
    """Build a fuzzy Lucene full-text query from free-form input.

    The input is stripped of Lucene special characters and split on
    whitespace. Each remaining word gets a ``~0.7`` similarity suffix, and the
    words are combined with ``AND``. The fuzzy suffix tolerates some
    misspellings when mapping names from a user question onto the names of
    individuals and organizations stored in the database.

    Args:
        input: Raw search text, e.g. an entity name taken from a question.

    Returns:
        A query such as ``"alice~0.7 AND bob~0.7"``. An empty string is
        returned when the input contains no searchable words (empty input or
        only special characters).
    """
    # `str.split()` with no argument never yields empty strings, so no extra
    # filtering is needed. Joining also handles the zero-word case, where
    # indexing the last word would raise IndexError.
    words = remove_lucene_chars(input).split()
    return " AND ".join(f"{word}~0.7" for word in words)
# Cypher: query the full-text index named by $index with the Lucene query
# string $fulltextQuery, returning at most $limit nodes. Each row carries the
# node's name, its summary, and the list of its labels (as `label`).
candidate_query = """
CALL db.index.fulltext.queryNodes($index, $fulltextQuery, {limit: $limit})
YIELD node
RETURN node.name AS name,
node.summary AS summary,
labels(node) AS label
"""
# Cypher: fetch details for the PERSON node whose name equals $name, together
# with one CATEGORY node linked to it by an IN_CATEGORY relationship (either
# direction). At most one row is returned; because this is a plain MATCH, a
# person with no IN_CATEGORY relationship yields no rows.
person_description_query = """
MATCH (e:PERSON)-[r:IN_CATEGORY]-(m:CATEGORY)
WHERE e.name = $name
RETURN e.name AS name,
e.gender AS gender,
e.summary AS summary,
m.name AS policy_category
LIMIT 1
"""
# Cypher: fetch details for the ORGANIZATION node whose name equals $name,
# together with one CATEGORY node linked to it by an IN_CATEGORY relationship
# (either direction). At most one row is returned; because this is a plain
# MATCH, an organization with no IN_CATEGORY relationship yields no rows.
organization_description_query = """
MATCH (o:ORGANIZATION)-[r:IN_CATEGORY]-(m:CATEGORY)
WHERE o.name = $name
RETURN o.name AS name,
o.summary AS summary,
o.description AS description,
o.twitterUri AS twitter_uri,
o.homepageUri AS homepage_uri,
m.name AS policy_category
LIMIT 1
"""
# Shared Neo4jGraph handle, created on first use by _get_graph(). Building a
# Neo4jGraph opens a database driver, so it is reused across queries instead
# of being rebuilt (and never released) on every call.
_graph = None


def _get_graph():
    """Return the shared Neo4jGraph, connecting on first use."""
    global _graph
    if _graph is None:
        _graph = Neo4jGraph(NEO4J_URL, NEO4J_USER, NEO4J_PASSWORD)
    return _graph


@retry(
    retry=retry_if_exception_type(ServiceUnavailable),
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=1, min=4, max=8),
)
def execute_query(query, params):
    """Run a Cypher query against Neo4j and return its result rows.

    The call is retried only when it fails with ``ServiceUnavailable``: up to
    four attempts in total, with an exponential wait bounded between 4 and 8
    seconds between attempts. Any other exception propagates immediately.

    Args:
        query: Cypher query text.
        params: Mapping of query parameter names to values.

    Returns:
        Whatever ``Neo4jGraph.query`` returns for the query (the callers in
        this module treat it as a list of row dictionaries).

    Raises:
        tenacity.RetryError: If every attempt fails with
            ``ServiceUnavailable``.
    """
    global _graph
    try:
        return _get_graph().query(query, params)
    except ServiceUnavailable:
        # Forget the cached handle so the next retry attempt opens a fresh
        # connection, matching the original connect-per-attempt behaviour.
        _graph = None
        raise
def get_candidates(input: str, index: str, limit: int = 3) -> List[Dict[str, str]]:
    """Look up candidate entities for a free-form name via full-text search.

    The input is turned into a fuzzy full-text query and executed against the
    given full-text index using ``candidate_query``.

    Args:
        input: Free-form entity name to search for.
        index: Name of the Neo4j full-text index to query.
        limit: Maximum number of candidates to return.

    Returns:
        The rows produced by ``candidate_query``: one dictionary per matching
        node with the keys ``name``, ``summary`` and ``label`` (``label`` is
        the list of the node's labels). An empty list is returned when the
        input contains no searchable words.
    """
    # Empty input, or input made up solely of Lucene special characters,
    # leaves nothing to search for: skip the database round-trip and report
    # no matches instead of failing while building the query.
    if not remove_lucene_chars(input).split():
        return []

    ft_query = generate_full_text_query(input)
    print(ft_query)  # debug aid: shows the generated Lucene query on stdout
    return execute_query(
        candidate_query,
        {"fulltextQuery": ft_query, "index": index, "limit": limit},
    )
def get_information(entity: str, index: str) -> Optional[dict]:
    """Return descriptive details for the best full-text match of an entity.

    The entity name is searched in the given full-text index. Only the first
    candidate returned is used; its exact name is then looked up with the
    person query when ``index`` is ``"dangerous_individuals"`` and with the
    organization query for any other index value.

    Args:
        entity: Free-form entity name to look up.
        index: Name of the full-text index to search, which also selects the
            description query.

    Returns:
        A dictionary of the matched entity's fields with human-readable keys
        (``policy_category`` becomes ``"Policy Category"``), or ``None`` when
        no candidate matches or the description query returns no rows.
    """
    candidates = get_candidates(entity, index)
    if not candidates:
        return None
    # Disambiguation between multiple matches is currently disabled; the
    # first candidate is always used.
    # elif len(candidates) > 1:
    #     newline = "\n"
    #     return (
    #         "Multiple matching people were found. They are not the same person. Need additional information to disambiguate the name. "
    #         "In your <avi_answer> output tag, present the user with the following matched options and ask the user which of the options they meant. "
    #         f"Here are the options: {newline + newline.join(str(d) for d in candidates)}"
    #     )
    top_match = candidates[0]

    if index == "dangerous_individuals":
        description_query = person_description_query
    else:
        description_query = organization_description_query

    rows = execute_query(description_query, params={"name": top_match["name"]})
    if not rows:
        return None

    # detailed_info = "\n".join([f"{key.replace('_', ' ').title()}: {value}" for key, value in candidate_data.items()])
    # Turn column aliases such as "twitter_uri" into display keys ("Twitter Uri").
    return {key.replace("_", " ").title(): value for key, value in rows[0].items()}
# Argument schema for DIOInformationTool. Documented with comments rather
# than a class docstring so the generated pydantic schema stays unchanged.
class InformationInput(BaseModel):
    # Name of the entity to look up, as free text.
    # NOTE(review): the example inside the description text is missing its
    # closing quote; left as-is because the string is part of the prompt.
    entity: str = Field(
        description=(
            "full-text search query of the name of a given entity which are "
            "mentioned in the question. Example: 'Alice Bob"
        )
    )
    # Which full-text index to search.
    entity_type: str = Field(
        description=(
            "indexed list to search for membership by the entity. "
            "Available options are 'dangerous_organizations' and 'dangerous_individuals'"
        )
    )
class DIOInformationTool(BaseTool):
    # Tool that looks up an individual or organization by name through
    # get_information() and returns the matched entity's details.
    # NOTE(review): the description text below talks about elected officials
    # and candidates for office, while the tool name refers to dangerous
    # individuals and organizations; confirm which wording is intended. The
    # text is left unchanged because it is part of the prompt.
    name = "Dangerous_Individuals_And_Organizations_Information"
    description = (
        "useful for when you need to answer questions about various elected officials or persons running for office. "
        "Never generate a final answer to the question if multiple candidates were matched by name; in those cases, "
        "always present the candidate options as a bulleted list and ask for disambiguation in your <avi_answer> tag. Never embellish your descriptions of "
        "the candidate in question or assume their motivations from their professional activities under any circumstances; "
        "remain 100% strictly fact-focused at all times with the outputs of this tool."
    )
    args_schema: Type[BaseModel] = InformationInput

    def _run(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Optional[dict]:
        """Use the tool.

        Args:
            entity: Free-form name of the entity to look up.
            entity_type: Name of the full-text index to search.
            run_manager: Callback manager supplied by the caller; unused.

        Returns:
            The dictionary produced by ``get_information``, or ``None`` when
            nothing was found.
        """
        return get_information(entity, entity_type)

    async def _arun(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Optional[dict]:
        """Use the tool asynchronously.

        Takes the same arguments and returns the same result as ``_run``.
        """
        # get_information performs blocking database calls (and may sleep
        # between retries), so run it in the default executor rather than
        # stalling the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, get_information, entity, entity_type)
|