import os
import gc
import time
import asyncio
import uuid
from contextlib import contextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any

import numpy as np
import torch
from neo4j import GraphDatabase
from pyvis.network import Network
from sentence_transformers import SentenceTransformer
from bert_score.scorer import BERTScorer

from src.query_processing.late_chunking.late_chunker import LateChunker
from src.query_processing.query_processor import QueryProcessor
from src.reasoning.reasoner import Reasoner
from src.utils.api_key_manager import APIKeyManager
from src.search.search_engine import SearchEngine
from src.crawl.crawler import CustomCrawler  # , Crawler

class Neo4jGraphRAG:
    def __init__(self, num_workers: int = 1):
        """Initialize Neo4j connection and required components."""
        # Neo4j connection setup
        self.neo4j_uri = os.getenv("NEO4J_URI")
        self.neo4j_user = os.getenv("NEO4J_USER")
        self.neo4j_password = os.getenv("NEO4J_PASSWORD")
        self.driver = GraphDatabase.driver(
            self.neo4j_uri,
            auth=(self.neo4j_user, self.neo4j_password)
        )

        # Component initialization
        self.num_workers = num_workers
        self.search_engine = SearchEngine()
        self.query_processor = QueryProcessor()
        self.reasoner = Reasoner()
        # self.crawler = Crawler(verbose=True)
        self.custom_crawler = CustomCrawler(max_concurrent_requests=1000)
        self.chunking = LateChunker()
        self.llm = APIKeyManager().get_llm()

        # Model initialization
        self.model = SentenceTransformer(
            "dunzhang/stella_en_400M_v5",
            trust_remote_code=True,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
        self.scorer = BERTScorer(
            model_type="roberta-base",
            lang="en",
            rescale_with_baseline=True,
            device="cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Counters and tracking
        self.root_node_id = "QR"
        self.node_counter = 0
        self.sub_node_counter = 0
        self.cross_connections = set()

        # Graph tracking
        self.current_graph_id = None

        # Thread pool
        self.executor = ThreadPoolExecutor(max_workers=self.num_workers)

        # Callback used to emit events to the caller
        self.on_event_callback = None

    def set_on_event_callback(self, callback):
        """Register a single callback to be triggered for various event types."""
        self.on_event_callback = callback

    async def emit_event(self, event_type: str, data: dict):
        """Helper method to safely emit an event if a callback is registered."""
        if self.on_event_callback:
            # Support both async and sync callbacks; signature: callback(event_type, data)
            if asyncio.iscoroutinefunction(self.on_event_callback):
                return await self.on_event_callback(event_type, data)
            return self.on_event_callback(event_type, data)
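
    # Example wiring (a hedged sketch, not part of the original API surface):
    # any sync or async callable taking (event_type, data) works as a callback.
    #
    #     async def on_event(event_type: str, data: dict) -> None:
    #         print(f"[{event_type}] {data.get('sub_query', '')}")
    #
    #     rag = Neo4jGraphRAG(num_workers=4)
    #     rag.set_on_event_callback(on_event)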

    @contextmanager
    def transaction(self, max_retries: int = 1):
        """Synchronous context manager for Neo4j transactions."""
        session = self.driver.session()
        retry_count = 0
        try:
            while True:
                try:
                    tx = session.begin_transaction()
                    try:
                        yield tx
                        tx.commit()
                        break
                    except Exception as e:
                        tx.rollback()
                        raise e
                except Exception as e:
                    retry_count += 1
                    if retry_count >= max_retries:
                        print(f"Transaction failed after {max_retries} attempts: {str(e)}")
                        raise e
                    print(f"Transaction failed, retrying ({retry_count}/{max_retries}): {str(e)}")
                    time.sleep(1)  # Use regular sleep for sync context
        finally:
            session.close()
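
    # Usage sketch for the context manager above (the Cypher here is
    # illustrative only):
    #
    #     with self.transaction() as tx:
    #         tx.run("MATCH (n:Node {graph_id: $gid}) RETURN count(n)",
    #                gid=self.current_graph_id)
    #
    # Commit happens automatically when the block exits; any exception raised
    # in the block triggers a rollback before being re-raised.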

    def initialize_schema(self):
        """Check and initialize the database schema."""
        constraint_node_id_per_graph = None
        index_node_role = None
        index_node_graph_id = None
        constraint_graph_id = None
        index_graph_created = None
        constraint_node_graph = None
        try:
            with self.transaction() as tx:
                # Check whether the schema already exists by looking up each
                # constraint and index by name
                constraint_node_id_per_graph = tx.run("""
                    SHOW CONSTRAINTS
                    WHERE name = 'constraint_node_id_per_graph'
                """).data()
                index_node_role = tx.run("""
                    SHOW INDEXES
                    WHERE name = 'index_node_role'
                """).data()
                index_node_graph_id = tx.run("""
                    SHOW INDEXES
                    WHERE name = 'index_node_graph_id'
                """).data()
                constraint_graph_id = tx.run("""
                    SHOW CONSTRAINTS
                    WHERE name = 'constraint_graph_id'
                """).data()
                index_graph_created = tx.run("""
                    SHOW INDEXES
                    WHERE name = 'index_graph_created'
                """).data()
                constraint_node_graph = tx.run("""
                    SHOW CONSTRAINTS
                    WHERE name = 'constraint_node_graph'
                """).data()

                if (constraint_node_id_per_graph and index_node_role and
                        index_node_graph_id and constraint_graph_id and
                        index_graph_created and constraint_node_graph):
                    print("Database schema already initialized")
                    return

                print("Initializing database schema...")
                # Composite constraint: node IDs are unique within each graph
                if not constraint_node_id_per_graph:
                    tx.run("""
                        CREATE CONSTRAINT constraint_node_id_per_graph IF NOT EXISTS
                        FOR (n:Node)
                        REQUIRE (n.id, n.graph_id) IS UNIQUE
                    """)
                if not index_node_role:
                    tx.run("""
                        CREATE INDEX index_node_role IF NOT EXISTS
                        FOR (n:Node) ON (n.role)
                    """)
                if not index_node_graph_id:
                    tx.run("""
                        CREATE INDEX index_node_graph_id IF NOT EXISTS
                        FOR (n:Node) ON (n.graph_id)
                    """)

                # Graph management constraints
                if not constraint_graph_id:
                    tx.run("""
                        CREATE CONSTRAINT constraint_graph_id IF NOT EXISTS
                        FOR (g:Graph)
                        REQUIRE g.id IS UNIQUE
                    """)
                if not index_graph_created:
                    tx.run("""
                        CREATE INDEX index_graph_created IF NOT EXISTS
                        FOR (g:Graph) ON (g.created)
                    """)
                if not constraint_node_graph:
                    tx.run("""
                        CREATE CONSTRAINT constraint_node_graph IF NOT EXISTS
                        FOR (n:Node)
                        REQUIRE n.graph_id IS NOT NULL
                    """)
                print("Database schema initialization complete")
        except Exception as e:
            print(f"Error ensuring schema exists: {str(e)}")
            raise

    def add_node(self, node_id: str, query: str, data: str = "", role: str = None):
        """Add a node to the current graph."""
        if self.current_graph_id is None:
            raise Exception("Error: No current graph selected")
        try:
            with self.transaction() as tx:
                # Generate the embedding
                embedding = self.model.encode(query).tolist()
                # Create the node with its properties, embedding, and graph ID
                tx.run(
                    """
                    MERGE (n:Node {id: $node_id, graph_id: $graph_id})
                    SET n.query = $node_query,
                        n.embedding = $embedding,
                        n.data = $data,
                        n.role = $role
                    """,
                    node_id=node_id,
                    graph_id=self.current_graph_id,
                    node_query=query,
                    embedding=embedding,
                    data=data,
                    role=role
                )
            print(f"Added node '{node_id}' to graph '{self.current_graph_id}' with role '{role}' and query: '{query}'")
        except Exception as e:
            print(f"Error adding node '{node_id}' to graph '{self.current_graph_id}' with role '{role}' and query: '{query}': {str(e)}")
            raise

    def add_edge(self, node1: str, node2: str, weight: float = 1.0, relationship_type: str = None):
        """Add an edge between two nodes while preserving the graph's DAG structure."""
        if self.current_graph_id is None:
            raise Exception("Error: No current graph selected")
        # 1) Prevent self-loops
        if node1 == node2:
            print(f"Cannot add edge to the same node {node1}!")
            return
        try:
            with self.transaction() as tx:
                # 2) Check whether there is already a path from node2 back to node1
                check_path = tx.run(
                    """
                    MATCH (start:Node {id: $node2, graph_id: $graph_id})
                    MATCH (end:Node {id: $node1, graph_id: $graph_id})
                    // If there's any path of length >= 0 from 'start' to 'end',
                    // then creating (end)->(start) would introduce a cycle.
                    WHERE (start)-[:RELATION*0..]->(end)
                    RETURN COUNT(start) AS pathExists
                    """,
                    node1=node1,
                    node2=node2,
                    graph_id=self.current_graph_id
                )
                path_count = check_path.single()["pathExists"]
                if path_count > 0:
                    print(f"Skipping edge {node1} -> {node2}: it would create a cycle!")
                    return
                # 3) Otherwise, it is safe to create the new directed edge
                tx.run(
                    """
                    MATCH (a:Node {id: $node1, graph_id: $graph_id})
                    MATCH (b:Node {id: $node2, graph_id: $graph_id})
                    MERGE (a)-[r:RELATION {type: $rel_type}]->(b)
                    SET r.weight = $weight
                    """,
                    node1=node1,
                    node2=node2,
                    graph_id=self.current_graph_id,
                    rel_type=relationship_type,
                    weight=weight
                )
            print(
                f"Added edge between '{node1}' and '{node2}' in graph "
                f"'{self.current_graph_id}' (type='{relationship_type}', weight={weight})"
            )
        except Exception as e:
            print(f"Error adding edge between '{node1}' and '{node2}': {str(e)}")
            raise
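
    # Cycle-safety example (hypothetical node IDs): after
    # add_edge("SQ1", "SQ2") and add_edge("SQ2", "SQ3"), a later call
    # add_edge("SQ3", "SQ1") is skipped, because the *0.. path check finds
    # that SQ1 already reaches SQ3, so the new edge would close a cycle and
    # break the DAG invariant.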

    def edge_exists(self, node1: str, node2: str) -> bool:
        """Check whether an edge exists between two nodes."""
        try:
            with self.transaction() as tx:
                result = tx.run(
                    """
                    MATCH (a:Node {id: $node1})-[r:RELATION]-(b:Node {id: $node2})
                    RETURN COUNT(r) as count
                    """,
                    node1=node1,
                    node2=node2
                )
                return result.single()["count"] > 0
        except Exception as e:
            print(f"Error checking edge existence between {node1} and {node2}: {str(e)}")
            raise

    def graph_exists(self) -> bool:
        """Check whether any graph exists in Neo4j."""
        try:
            with self.transaction() as tx:
                result = tx.run("""
                    MATCH (n:Node)
                    RETURN count(n) > 0 as has_nodes
                """)
                return result.single()["has_nodes"]
        except Exception as e:
            print(f"Error checking graph existence: {str(e)}")
            raise

    def get_graphs(self) -> list:
        """Get detailed information about all existing graphs and their nodes."""
        try:
            with self.transaction() as tx:
                result = tx.run(
                    """
                    MATCH (g:Graph)
                    OPTIONAL MATCH (n:Node {graph_id: g.id})-[r:RELATION]->(:Node)
                    WITH g, collect(DISTINCT n) AS nodes, collect(DISTINCT r) AS rels
                    RETURN {
                        graph_id: g.id,
                        created: g.created,
                        updated: g.updated,
                        node_count: size(nodes),
                        edge_count: size(rels),
                        nodes: [node IN nodes | {
                            id: node.id,
                            query: node.query,
                            data: node.data,
                            role: node.role,
                            pagerank: node.pagerank
                        }]
                    } as graph_info
                    ORDER BY g.created DESC
                    """
                )
                return list(result)
        except Exception as e:
            print(f"Error getting graphs: {str(e)}")
            raise

    def select_graph(self, graph_id: str) -> bool:
        """Select a specific graph as the current working graph."""
        try:
            with self.transaction() as tx:
                result = tx.run("""
                    MATCH (g:Graph {id: $graph_id})
                    RETURN g
                """, graph_id=graph_id)
                if result.single():
                    self.current_graph_id = graph_id
                    return True
                return False
        except Exception as e:
            print(f"Error selecting graph: {str(e)}")
            raise

    def create_new_graph(self) -> str:
        """Create a new graph instance and return its ID."""
        try:
            graph_id = str(uuid.uuid4())
            with self.transaction() as tx:
                tx.run("""
                    CREATE (g:Graph {
                        id: $graph_id,
                        created: datetime(),
                        updated: datetime()
                    })
                """, graph_id=graph_id)
            self.current_graph_id = graph_id
            return graph_id
        except Exception as e:
            print(f"Error creating new graph: {str(e)}")
            raise

    def load_graph(self, node_id: str) -> bool:
        """Load an existing graph structure from Neo4j based on node ID."""

        # Helper to safely extract the numeric suffix from a node ID
        def extract_number(node_id: str) -> int:
            try:
                # Extract all digits from the string
                num_str = ''.join(filter(str.isdigit, node_id))
                return int(num_str) if num_str else 0
            except ValueError:
                print(f"Warning: Could not extract number from node ID: {node_id}")
                return 0
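
        # e.g. extract_number("SQ3") -> 3, extract_number("SSQ12") -> 12,
        # and extract_number("QR") -> 0 (no digits present).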

        try:
            with self.driver.session() as session:
                # Start transaction
                tx = session.begin_transaction()
                try:
                    # Get all related nodes and relationships
                    result = tx.run("""
                        MATCH path = (n:Node)-[r:RELATION*0..]->(m:Node)
                        WHERE n.id = $node_id
                        RETURN DISTINCT n, r, m,
                               length(path) as depth,
                               [rel in r | type(rel)] as rel_types,
                               [rel in r | rel.weight] as weights
                    """, node_id=node_id)
                    # Reset internal state
                    self.node_counter = 0
                    self.sub_node_counter = 0
                    self.cross_connections.clear()
                    # Track processed nodes to avoid duplicates
                    processed_nodes = set()
                    # Process results, updating counters based on node ID
                    # patterns (check 'SSQ' before 'SQ', since every SSQ ID
                    # also contains the substring 'SQ')
                    for record in result:
                        if record["n"]["id"] not in processed_nodes:
                            node_id = record["n"]["id"]
                            if "SSQ" in node_id:
                                current_num = extract_number(node_id)
                                self.sub_node_counter = max(self.sub_node_counter, current_num)
                            elif "SQ" in node_id:
                                current_num = extract_number(node_id)
                                self.node_counter = max(self.node_counter, current_num)
                            processed_nodes.add(node_id)
                        if record["m"]["id"] not in processed_nodes:
                            node_id = record["m"]["id"]
                            if "SSQ" in node_id:
                                current_num = extract_number(node_id)
                                self.sub_node_counter = max(self.sub_node_counter, current_num)
                            elif "SQ" in node_id:
                                current_num = extract_number(node_id)
                                self.node_counter = max(self.node_counter, current_num)
                            processed_nodes.add(node_id)
                    # Increment counters for next use
                    self.node_counter += 1
                    self.sub_node_counter += 1
                    # Track cross-connections
                    result = tx.run("""
                        MATCH (n:Node)-[r:RELATION]->(m:Node)
                        WHERE r.type = 'logical'
                        RETURN n.id as source, m.id as target
                    """)
                    for record in result:
                        connection = tuple(sorted([record["source"], record["target"]]))
                        self.cross_connections.add(connection)
                    tx.commit()
                    print(f"Successfully loaded graph. Current counters - Node: {self.node_counter}, Sub: {self.sub_node_counter}")
                    return True
                except Exception as e:
                    tx.rollback()
                    print(f"Transaction error while loading graph: {str(e)}")
                    return False
        except Exception as e:
            print(f"Error loading graph: {str(e)}")
            return False

    async def modify_graph(self, new_query: str, similar_node_id: str, session_id: str = None):
        """Modify an existing graph structure by integrating a new query."""

        # Inner function: add a new node as a sibling of `node_id`
        async def add_as_sibling(node_id: str, query: str):
            with self.transaction() as tx:
                result = tx.run("""
                    MATCH (n:Node)<-[r:RELATION]-(parent:Node)
                    WHERE n.id = $node_id
                    RETURN parent.id as parent_id,
                           parent.query as parent_query,
                           r.type as rel_type
                """, node_id=node_id)
                parent_data = result.single()
            if not parent_data:
                raise ValueError(f"No parent found for node {node_id}")
            # A sibling stays at the same level; check 'SSQ' first because
            # every SSQ ID also contains the substring 'SQ'
            if node_id.startswith("SSQ"):
                self.sub_node_counter += 1
                new_node_id = f"SSQ{self.sub_node_counter}"
            else:
                self.node_counter += 1
                new_node_id = f"SQ{self.node_counter}"
            self.add_node(
                node_id=new_node_id,
                query=query,
                role="independent"
            )
            self.add_edge(
                parent_data["parent_id"],
                new_node_id,
                relationship_type=parent_data["rel_type"]
            )
            return new_node_id

        # Inner function: add a new node as a child of `node_id`
        async def add_as_child(node_id: str, query: str):
            if "SQ" in node_id:
                self.sub_node_counter += 1
                new_node_id = f"SSQ{self.sub_node_counter}"
            else:
                self.node_counter += 1
                new_node_id = f"SQ{self.node_counter}"
            self.add_node(
                node_id=new_node_id,
                query=query,
                role="dependent"
            )
            self.add_edge(
                node_id,
                new_node_id,
                relationship_type="logical"
            )
            return new_node_id

        # Inner function: collect context from existing graph nodes
        def collect_graph_context() -> list:
            try:
                with self.transaction() as tx:
                    # Get all nodes except root, ordered by depth and ID to maintain hierarchy
                    result = tx.run("""
                        MATCH (n:Node)
                        WHERE n.id <> $root_id AND n.graph_id = $graph_id
                        WITH n
                        ORDER BY
                            CASE
                                WHEN n.id STARTS WITH 'SQ' THEN 1
                                WHEN n.id STARTS WITH 'SSQ' THEN 2
                                ELSE 3
                            END,
                            n.id
                        RETURN COLLECT({
                            id: n.id,
                            query: n.query,
                            role: n.role
                        }) as nodes
                    """, root_id=self.root_node_id, graph_id=self.current_graph_id)
                    nodes = result.single()["nodes"]
                    if not nodes:
                        return []
                    # Group nodes by hierarchy level
                    level_queries = {}
                    current_sq = None
                    for node in nodes:
                        node_id = node["id"]
                        if node_id.startswith("SQ"):
                            current_sq = node_id
                            if current_sq not in level_queries:
                                level_queries[current_sq] = {
                                    "originalquery": node["query"],
                                    "subqueries": []
                                }
                            # Add the SQ node itself as a sub-query
                            level_queries[current_sq]["subqueries"].append({
                                "subquery": node["query"],
                                "role": node["role"],
                                "dependson": []  # Dependencies will be added below
                            })
                        elif node_id.startswith("SSQ") and current_sq:
                            level_queries[current_sq]["subqueries"].append({
                                "subquery": node["query"],
                                "role": node["role"],
                                "dependson": []  # Dependencies will be added below
                            })
                    # Add dependency information
                    for sq_id, query_data in level_queries.items():
                        for i, sub_query in enumerate(query_data["subqueries"]):
                            # Get dependencies for this sub-query
                            deps = tx.run("""
                                MATCH (n:Node {query: $node_query})-[r:RELATION {type: 'logical'}]->(m:Node)
                                WHERE n.graph_id = $graph_id
                                RETURN COLLECT(m.query) as dependencies
                            """, node_query=sub_query["subquery"], graph_id=self.current_graph_id)
                            dep_queries = deps.single()["dependencies"]
                            if dep_queries:
                                # Find indices of dependent queries
                                curr_deps = []
                                prev_deps = []
                                for dep_query in dep_queries:
                                    # Check current-level dependencies
                                    curr_idx = next(
                                        (idx for idx, sq in enumerate(query_data["subqueries"])
                                         if sq["subquery"] == dep_query),
                                        None
                                    )
                                    if curr_idx is not None:
                                        curr_deps.append(curr_idx)
                                    else:
                                        # Check previous-level dependencies
                                        for prev_idx, prev_data in enumerate(level_queries.values()):
                                            if dep_query in [sq["subquery"] for sq in prev_data["subqueries"]]:
                                                prev_deps.append(prev_idx)
                                                break
                                query_data["subqueries"][i]["dependson"] = [prev_deps, curr_deps]
                    # Convert to a list, maintaining order
                    return list(level_queries.values())
            except Exception as e:
                print(f"Error collecting graph context: {str(e)}")
                raise

        try:
            # Get the role and other metadata of the similar node
            with self.transaction() as tx:
                result = tx.run("""
                    MATCH (n:Node {id: $node_id})
                    RETURN n.role as role,
                           n.query as query,
                           EXISTS((n)<-[:RELATION]-()) as has_parent
                """, node_id=similar_node_id)
                node_data = result.single()
            if not node_data:
                raise Exception(f"Node {similar_node_id} not found")
            # Collect context from the existing graph
            context = collect_graph_context()
            # Determine the modification strategy
            if node_data["role"] == "independent":
                # Add as sibling if the node has a parent, else as child
                if node_data["has_parent"]:
                    new_node_id = await add_as_sibling(similar_node_id, new_query)
                else:
                    new_node_id = await add_as_child(similar_node_id, new_query)
            else:
                # Add as child for dependent or pre-requisite nodes
                new_node_id = await add_as_child(similar_node_id, new_query)
            # Recursively build a subgraph for the new node if needed
            await self.build_graph(
                query=new_query,
                parent_node_id=new_node_id,
                depth=2 if new_node_id.startswith("SSQ") else 1,
                context=context,  # Pass the collected context
                session_id=session_id
            )
        except Exception as e:
            print(f"Error modifying graph: {str(e)}")
            raise
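
    # Placement rule implemented above: an "independent" node that has a
    # parent receives the new query as a sibling (same level, role
    # "independent"); in every other case the new query becomes a "dependent"
    # child linked by a 'logical' edge, and build_graph() then expands the new
    # node one level deeper with the collected context.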

    async def build_graph(self, query: str, data: str = None, parent_node_id: str = None,
                          depth: int = 0, threshold: float = 0.8, recurse: bool = True,
                          context: list = None, session_id: str = None, max_tokens_allowed: int = 128000):
        """Build a new graph structure in Neo4j."""

        async def process_node(self, node_id: str, sub_query: str,
                               session_id: str, future: asyncio.Future,
                               depth=depth, max_tokens_allowed=max_tokens_allowed):
            """Process a node asynchronously."""
            try:
                # Generate an optimized search query
                optimized_query = await self.search_engine.generate_optimized_query(sub_query)
                # Search for the sub-query
                results = await self.search_engine.search(
                    query=optimized_query,
                    num_results=10,
                    exclude_filetypes=["pdf"]
                )
                # Emit event with the raw results
                await self.emit_event("search_results_fetched", {
                    "node_id": node_id,
                    "sub_query": sub_query,
                    "optimized_query": optimized_query,
                    "search_results": results
                })
                # Filter the URLs based on the query
                filtered_urls = await self.search_engine.filter_urls(
                    sub_query,
                    "extensive research dynamic structure",
                    results
                )
                # Emit an event with the filtered URLs
                await self.emit_event("search_results_filtered", {
                    "node_id": node_id,
                    "sub_query": sub_query,
                    "filtered_urls": filtered_urls
                })
                # Get the URLs
                urls = [result.get('link', 'No URL') for result in filtered_urls]
                # Fetch URL contents
                search_contents = await self.custom_crawler.fetch_page_contents(
                    urls,
                    sub_query,
                    session_id=session_id,
                    max_attempts=1,
                    timeout=30
                )
                # Emit an event with the fetched contents
                await self.emit_event("search_contents_fetched", {
                    "node_id": node_id,
                    "sub_query": sub_query,
                    "contents": search_contents
                })
                # Format the contents
                contents = ""
                for k, content in enumerate(search_contents, 1):
                    if isinstance(content, Exception):
                        print(f"Error fetching content: {content}")
                    elif content:
                        contents += f"Document {k}:\n{content}\n\n"

                if len(contents.strip()) > 0:
                    if depth == 0:
                        # Emit an event to indicate the completion of sub-query processing
                        await self.emit_event("sub_query_processed", {
                            "node_id": node_id,
                            "sub_query": sub_query,
                            "contents": contents
                        })
                    # Chunk the contents if they exceed the token limit
                    token_count = self.llm.get_num_tokens(contents)
                    if token_count > max_tokens_allowed:
                        contents = await self.chunking.chunker(
                            text=contents,
                            query=sub_query,
                            max_tokens=max_tokens_allowed
                        )
                        print(f"Number of tokens in the answer: {token_count}")
                        print(f"Number of tokens in the content: {self.llm.get_num_tokens(contents)}")
                else:
                    if depth == 0:
                        # Emit an event to indicate the failure of sub-query processing
                        await self.emit_event("sub_query_failed", {
                            "node_id": node_id,
                            "sub_query": sub_query,
                            "contents": contents
                        })
                # Update the node with its data atomically
                with self.transaction() as tx:
                    tx.run(
                        """
                        MATCH (n:Node {id: $node_id})
                        SET n.data = $data
                        """,
                        node_id=node_id,
                        data=contents
                    )
                # Set the result in the future
                future.set_result(contents)
            except Exception as e:
                print(f"Error processing node {node_id}: {str(e)}")
                future.set_exception(e)
                raise

        async def process_dependent_node(self, node_id: str, sub_query: str, depth, dep_futures: list, future):
            """Process a dependent node asynchronously."""
            try:
                loop = asyncio.get_running_loop()
                # Wait for dependencies
                dep_data = [await f for f in dep_futures]
                # Modify the query based on its dependencies
                modified_query = await self.query_processor.modify_query(
                    sub_query,
                    dep_data
                )
                # Generate a new embedding for the modified query
                embedding = await loop.run_in_executor(
                    self.executor,
                    self.model.encode,
                    modified_query
                )
                # Update the node query and embedding atomically
                with self.transaction() as tx:
                    tx.run(
                        """
                        MATCH (n:Node {id: $node_id})
                        SET n.query = $modified_query,
                            n.embedding = $embedding
                        """,
                        node_id=node_id,
                        modified_query=modified_query,
                        embedding=embedding.tolist()
                    )
                # Process the modified node
                try:
                    if not future.done():
                        await process_node(
                            self, node_id, modified_query, session_id, future, depth, max_tokens_allowed
                        )
                except Exception as e:
                    if not future.done():
                        future.set_exception(e)
                    raise
            except Exception as e:
                print(f"Error processing dependent node {node_id}: {str(e)}")
                if not future.done():
                    future.set_exception(e)
                raise

        def create_cross_connections(self, node_id=None, depth=None, role=None):
            """Create cross connections based on dependencies."""
            try:
                # Get all logical relationships
                relationships = self.get_node_relationships(
                    node_id=node_id,
                    depth=depth,
                    role=role,
                    relationship_type='logical'
                )
                for current_node_id, edges in relationships.items():
                    # Get the node role
                    with self.transaction() as tx:
                        result = tx.run(
                            "MATCH (n:Node {id: $node_id}) RETURN n.role as role",
                            node_id=current_node_id
                        )
                        node_data = result.single()
                    if not node_data or not node_data["role"]:
                        continue
                    node_role = node_data["role"].lower()
                    # Only process dependent nodes
                    if node_role == 'dependent':
                        # Process incoming edges (dependencies)
                        for source_id, target_id, edge_data in edges['in_edges']:
                            if not source_id or source_id == self.root_node_id:
                                continue
                            # Create the connection key
                            connection = tuple(sorted([current_node_id, source_id]))
                            # Add the cross-connection if it does not exist yet
                            if connection not in self.cross_connections:
                                if not self.edge_exists(source_id, current_node_id):
                                    print(f"Adding cross-connection edge between {source_id} and {current_node_id}")
                                    self.add_edge(
                                        source_id,
                                        current_node_id,
                                        weight=edge_data.get('weight', 1.0),
                                        relationship_type='logical'
                                    )
                                self.cross_connections.add(connection)
                        # Process outgoing edges (children)
                        for source_id, target_id, edge_data in edges['out_edges']:
                            if not target_id or target_id == self.root_node_id:
                                continue
                            # Create the connection key
                            connection = tuple(sorted([current_node_id, target_id]))
                            # Add the cross-connection if it does not exist yet
                            if connection not in self.cross_connections:
                                if not self.edge_exists(current_node_id, target_id):
                                    print(f"Adding cross-connection edge between {current_node_id} and {target_id}")
                                    self.add_edge(
                                        current_node_id,
                                        target_id,
                                        weight=edge_data.get('weight', 1.0),
                                        relationship_type='logical'
                                    )
                                self.cross_connections.add(connection)
            except Exception as e:
                print(f"Error creating cross connections: {str(e)}")
                raise

        # --- Main build_graph implementation ---

        # Limit recursion depth
        if depth > 1:
            return
        # Initialize context if not provided
        if context is None:
            context = []
        # Dictionary tracking node futures
        node_data_futures = {}

        if parent_node_id is None:
            # If no parent node, this is the root (original query)
            self.add_node(self.root_node_id, query, data)
            parent_node_id = self.root_node_id

        # Get the query intent
        intent = await self.query_processor.get_query_intent(query)
        if depth == 0:
            # Decompose the query into sub-queries
            response_data, sub_queries, roles, dependencies = \
                await self.query_processor.decompose_query_with_dependencies(query, intent)
        else:
            # Decompose the sub-query into sub-sub-queries with past context
            response_data, sub_queries, roles, dependencies = \
                await self.query_processor.decompose_query_with_dependencies(
                    query,
                    intent,
                    context
                )
        # Add the current query data to the context for the next iteration
        if response_data:
            context.append(response_data)

        # If no further decomposition is possible, sub_queries will contain only the original query
        if len(sub_queries) > 1 and sub_queries[0] != query:
            sub_query_ids = []
            pre_req_nodes = {}
            # Create the structure (nodes and edges) of the graph at the current level
            for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)):
                # At the sub-query level, fire an event letting the callback
                # know about the new sub-query
                if depth == 0:
                    await self.emit_event(
                        "sub_query_created",
                        {
                            "depth": depth,
                            "sub_query": sub_query,
                            "role": role,
                            "dependency": dependency,
                            "parent_node_id": parent_node_id,
                        }
                    )
                # Generate a unique ID for the sub-query
                if depth == 0:
                    self.node_counter += 1
                    sub_node_id = f"SQ{self.node_counter}"
                else:
                    self.sub_node_counter += 1
                    sub_node_id = f"SSQ{self.sub_node_counter}"
                # Record the node ID
                sub_query_ids.append(sub_node_id)
                # Add the node to the graph, initially without data
                self.add_node(node_id=sub_node_id, query=sub_query, role=role)
                # Create a future for the node
                future = asyncio.Future()
                node_data_futures[sub_node_id] = future

                if role.lower() in ('pre-requisite', 'prerequisite'):
                    pre_req_nodes[idx] = sub_node_id

                # Determine how to add edges based on the role
                if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
                    # Pre-requisite and independent nodes connect directly to the parent
                    self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical')
                elif role.lower() == 'dependent':
                    if isinstance(dependency, list) and (
                        len(dependency) == 2 and all(isinstance(d, list) for d in dependency)
                    ):
                        print(f"Dependency: {dependency}")
                        prev_deps, current_deps = dependency
                        # Handle previous-query dependencies
                        if context and prev_deps not in [None, []]:
                            for dep_idx in prev_deps:
                                if dep_idx is not None:
                                    # Find the corresponding context data
                                    for context_data in context:
                                        if context_data and 'subqueries' in context_data:
                                            if dep_idx < len(context_data['subqueries']):
                                                # Get the query from context
                                                sub_query_data = context_data['subqueries'][dep_idx]
                                                if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data:
                                                    dep_query = sub_query_data['subquery']
                                                    # Find matching nodes
                                                    matching_nodes = self.find_nodes_by_properties(query=dep_query)
                                                    # Get the best matching node ID and score
                                                    if matching_nodes not in [None, []]:
                                                        dep_node_id = matching_nodes[0].get('node_id')
                                                        score = matching_nodes[0].get('score', 0)
                                                        if score >= 0.9:
                                                            self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
                        # Add edges from current-query dependencies
                        if current_deps not in [None, []]:
                            for dep_idx in current_deps:
                                if dep_idx < len(sub_queries):
                                    dep_node_id = sub_query_ids[dep_idx]
                                    self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
                                else:
                                    # Dependency index is out of range
                                    raise ValueError(f"Invalid dependency index: {dep_idx}")
                    elif len(dependency) > 0:
                        for dep_idx in dependency:
                            if dep_idx < len(sub_queries):
                                # Get the node ID of the dependency
                                dep_node_id = sub_query_ids[dep_idx]
                                # Add an edge from the dependency to the current sub-query
                                self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
                            else:
                                raise ValueError(f"Invalid dependency index: {dep_idx}")
                    else:
                        # Dependency is malformed or empty
                        raise ValueError(f"Invalid dependency: {dependency}")
                else:
                    # Handle any unexpected roles
                    raise ValueError(f"Unexpected role: {role}")

            # Proceed to process the nodes
            tasks = []
            # Process pre-requisite and independent nodes concurrently
            for idx in range(len(sub_queries)):
                node_id = sub_query_ids[idx]
                future = node_data_futures[node_id]
                if roles[idx].lower() in ('pre-requisite', 'prerequisite', 'independent'):
                    tasks.append(process_node(
                        self, node_id, sub_queries[idx], session_id, future, depth, max_tokens_allowed
                    ))
            # Process dependent nodes as soon as their dependencies are ready
            for idx in range(len(sub_queries)):
                node_id = sub_query_ids[idx]
                future = node_data_futures[node_id]
                if roles[idx].lower() == 'dependent':
                    dep_futures = []
                    if isinstance(dependencies[idx], list) and len(dependencies[idx]) == 2:
                        prev_deps, current_deps = dependencies[idx]
                        # Get futures from previous context dependencies
                        if context and prev_deps not in [None, []]:
                            for context_idx, context_data in enumerate(context):
                                # If prev_deps is a list, process the corresponding dependency
                                if isinstance(prev_deps, list) and context_idx < len(prev_deps):
                                    context_dep = prev_deps[context_idx]
                                    if context_dep is not None:
                                        if context_data and 'subqueries' in context_data:
                                            if context_dep < len(context_data['subqueries']):
                                                sub_query_data = context_data['subqueries'][context_dep]
                                                if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data:
                                                    dep_query = sub_query_data['subquery']
                                                    # Find matching nodes
                                                    matching_nodes = self.find_nodes_by_properties(query=dep_query)
                                                    if matching_nodes not in [None, []]:
                                                        # Get the exact matching node ID and score
                                                        dep_node_id = matching_nodes[0].get('node_id', None)
                                                        score = float(matching_nodes[0].get('score', 0))
                                                        if score == 1.0 and dep_node_id in node_data_futures:
                                                            dep_futures.append(node_data_futures[dep_node_id])
                                # If prev_deps is an integer, process it for the current context
                                elif isinstance(prev_deps, int):
                                    if prev_deps < len(context_data['subqueries']):
                                        sub_query_data = context_data['subqueries'][prev_deps]
                                        if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data:
                                            dep_query = sub_query_data['subquery']
                                            # Find matching nodes
                                            matching_nodes = self.find_nodes_by_properties(query=dep_query)
                                            if matching_nodes not in [None, []]:
                                                # Get the exact matching node ID and score
                                                dep_node_id = matching_nodes[0].get('node_id', None)
                                                score = matching_nodes[0].get('score', 0)
                                                if score == 1.0 and dep_node_id in node_data_futures:
                                                    dep_futures.append(node_data_futures[dep_node_id])
                        # Get futures from current dependencies
                        if current_deps not in [None, []]:
                            current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
                            for dep_idx in current_deps_list:
                                if dep_idx < len(sub_queries):
                                    dep_node_id = sub_query_ids[dep_idx]
                                    if dep_node_id in node_data_futures:
                                        dep_futures.append(node_data_futures[dep_node_id])
                    # Start a coroutine that waits for the dependencies, then processes the node
                    tasks.append(process_dependent_node(
                        self, node_id, sub_queries[idx], depth, dep_futures, future
                    ))
            # Emit an event to indicate the start of the search process
            if depth == 0:
                await self.emit_event("search_process_started", {
                    "depth": depth,
                    "sub_queries": sub_queries,
                    "roles": roles
                })
            # Wait for all tasks to complete
            await asyncio.gather(*tasks)

            # Recurse into sub-queries if needed
            if recurse:
                recursion_tasks = []
                for idx, sub_query in enumerate(sub_queries):
                    try:
                        sub_node_id = sub_query_ids[idx]
                        recursion_tasks.append(
                            self.build_graph(
                                query=sub_query,
                                parent_node_id=sub_node_id,
                                depth=depth + 1,
                                threshold=threshold,
                                recurse=recurse,
                                context=context,  # Pass the context
                                session_id=session_id
                            ))
                    except Exception as e:
                        print(f"Failed to create recursion task for sub-query {sub_query}: {e}")
                        continue
                # Only proceed if there are any recursion tasks
                if recursion_tasks:
                    try:
                        await asyncio.gather(*recursion_tasks)
                    except Exception as e:
                        raise Exception(f"Error during recursive processing: {e}")

        # Process completion tasks
        if depth == 0:
            print("Graph building complete, processing final tasks...")
            # Create cross-connections
            create_cross_connections(self)
            print("All cross-connections have been created!")
            # Add similarity-based edges
            print(f"Adding similarity edges with threshold {threshold}")
            all_nodes = []
            with self.driver.session() as session:
                result = session.run(
                    "MATCH (n:Node) WHERE n.id <> $root_id RETURN n.id as id",
                    root_id=self.root_node_id
                )
                all_nodes = [record["id"] for record in result]
            for i, node1 in enumerate(all_nodes):
                for node2 in all_nodes[i + 1:]:
                    if not self.edge_exists(node1, node2):
                        self.add_edge_based_on_similarity_and_relevance(
                            node1, node2, query, threshold
                        )
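
    # Node-ID scheme used by build_graph (summarized from the logic above):
    # "QR" is the root/original query, "SQ<n>" are first-level sub-queries
    # (depth 0 decomposition), and "SSQ<n>" are second-level sub-sub-queries
    # (depth 1). A two-level decomposition therefore looks like:
    #
    #     QR -> SQ1 -> SSQ1
    #        |     \-> SSQ2
    #        -> SQ2
    #
    # with 'hierarchical' edges forming the tree and 'logical' edges linking
    # dependent nodes to the nodes they depend on.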

    async def process_graph(
        self,
        query: str,
        data: str = None,
        similarity_threshold: float = 0.8,
        relevance_threshold: float = 0.7,
        sub_sub_queries: bool = True,
        session_id: str = None,
        max_tokens_allowed: int = 128000
    ):
        """Process a query and manage graph creation/modification."""

        # Inner function: check similarity between the new query and existing queries in the graph
        def check_query_similarity(new_query: str, similarity_threshold: float = 0.8) -> Dict[str, Any]:
            if self.current_graph_id is None:
                raise Exception("Error: No current graph ID. Cannot check query similarity.")
            try:
                # Get all existing queries of the current graph and their metadata from Neo4j
                print(f"Retrieving existing queries and their metadata for graph {self.current_graph_id}")
                with self.transaction() as tx:
                    result = tx.run("""
                        MATCH (n:Node)
                        WHERE n.graph_id IS NOT NULL
                        AND n.graph_id = $graph_id
                        RETURN n.id as id,
                               n.query as query,
                               n.role as role
                    """, graph_id=self.current_graph_id)
                    # Process results and calculate similarities
                    similarities = []
                    records = list(result)  # Materialize results to avoid session timeout
                    if records == []:  # No existing queries
                        return {"should_create_new": True}
                    for record in records:
                        # Skip if missing required data
                        if not record["query"]:
                            continue
                        # Calculate query similarity
                        similarity = self.calculate_query_similarity(
                            new_query,
                            record["query"]
                        )
                        if similarity >= similarity_threshold:
                            similarities.append({
                                "node_id": record["id"],
                                "query": record["query"],
                                "score": similarity,
                                "role": record["role"]
                            })
                    # If no similar queries were found
                    if similarities == []:
                        print(f"No similar queries found above threshold {similarity_threshold}")
                        return {"should_create_new": True}
                    # Find the best match
                    best_match = max(similarities, key=lambda x: x["score"])
                    # Determine the relationship type based on the node ID pattern
                    rel_type = "root"
                    if "SSQ" in best_match["node_id"]:
                        rel_type = "sub-sub"
                    elif "SQ" in best_match["node_id"]:
                        rel_type = "sub"
                    return {
                        "most_similar_query": best_match["query"],
                        "similarity_score": best_match["score"],
                        "relationship_type": rel_type,
                        "node_id": best_match["node_id"],
                        "should_create_new": best_match["score"] < similarity_threshold
                    }
            except Exception as e:
                print(f"Error checking query similarity: {str(e)}")
                raise

        try:
            # Check if a graph already exists
            print("Checking for existing graphs...")
            result = self.get_graphs()
            graphs = list(result)
            if graphs == []:  # No existing graphs
                print("No existing graphs found. Creating new graph.")
                self.create_new_graph()
                # Emit event for creating a new graph
                await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
                await self.build_graph(
                    query=query,
                    data=data,
                    threshold=relevance_threshold,
                    recurse=sub_sub_queries,
                    session_id=session_id,
                    max_tokens_allowed=max_tokens_allowed
                )
                # Memory cleanup
                gc.collect()
                # Prune edges and update PageRank
                self.prune_edges()
                self.update_pagerank()
                # Verify graph integrity and consistency
                self.verify_graph_integrity()
                self.verify_graph_consistency()
                return

            # Check similarity with existing root queries
            max_similarity = 0
            most_similar_graph = None
            # First, consolidate nodes from graphs with the same ID
            consolidated_graphs = {}
            for graph in graphs:
                graph_info = graph.get("graph_info")
                if not graph_info:
                    continue
                graph_id = graph_info.get("graph_id")
                if not graph_id:
                    continue
                # Initialize or append nodes for this graph_id
                if graph_id not in consolidated_graphs:
                    consolidated_graphs[graph_id] = {
                        "graph_id": graph_id,
                        "nodes": []
                    }
                # Add nodes if they exist
                if graph_info.get("nodes"):
                    consolidated_graphs[graph_id]["nodes"].extend(graph_info["nodes"])

            # Now process the consolidated graphs
            for graph_id, graph_data in consolidated_graphs.items():
                nodes = graph_data["nodes"]
                # Calculate similarity with each node's query
                for node in nodes:
                    if node.get("query"):  # Skip nodes without queries
                        similarity = self.calculate_query_similarity(
                            query,
                            node["query"]
                        )
                        if node.get("id", "").startswith("SQ"):
                            await self.emit_event("retrieved_sub_query", {
                                "sub_query": node["query"]
                            })
                        if similarity > max_similarity:
                            max_similarity = similarity
                            most_similar_graph = graph_id

            if max_similarity >= similarity_threshold:
                # Use the existing graph
                print(f"Found similar query with score {round(max_similarity, 2)}")
                self.current_graph_id = most_similar_graph
                if round(max_similarity, 2) == 1.0:
                    print("Loading and using existing graph")
                    # Emit event for loading an existing graph
                    await self.emit_event("graph_operation", {"operation_type": "loading_existing_graph"})
                    success = self.load_graph(self.root_node_id)
                    if not success:
                        raise Exception("Failed to load existing graph")
                else:
                    # Check for node-level similarity
                    print("Checking for node-level similarity...")
                    similarity_info = check_query_similarity(
                        query,
                        similarity_threshold
                    )
                    if similarity_info.get("relationship_type") in ["sub", "sub-sub"]:
                        print(f"Most Similar Query: {similarity_info['most_similar_query']}")
                        print("Modifying existing graph structure")
                        # Emit event for modifying the graph
                        await self.emit_event("graph_operation", {"operation_type": "modifying_existing_graph"})
                        await self.modify_graph(
                            query,
                            similarity_info["node_id"],
                            session_id=session_id
                        )
                        # Memory cleanup
                        gc.collect()
                        # Prune edges and update PageRank
                        self.prune_edges()
                        self.update_pagerank()
                        # Verify graph integrity and consistency
                        self.verify_graph_integrity()
                        self.verify_graph_consistency()
            else:
                # Create a new graph
                print(f"Creating new graph for query: {query}")
                self.create_new_graph()
                # Emit event for creating a new graph
                await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
                await self.build_graph(
                    query=query,
                    data=data,
                    threshold=relevance_threshold,
                    recurse=sub_sub_queries,
                    session_id=session_id,
                    max_tokens_allowed=max_tokens_allowed
                )
                # Memory cleanup
                gc.collect()
                # Prune edges and update PageRank
                self.prune_edges()
                self.update_pagerank()
                # Verify graph integrity and consistency
                self.verify_graph_integrity()
                self.verify_graph_consistency()
        except Exception as e:
            print(f"Error in process_graph: {str(e)}")
            raise
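
    # End-to-end usage sketch (hedged; assumes NEO4J_URI/NEO4J_USER/
    # NEO4J_PASSWORD are set and the src.* components are configured):
    #
    #     rag = Neo4jGraphRAG(num_workers=4)
    #     rag.initialize_schema()
    #     asyncio.run(rag.process_graph("How do transformers work?",
    #                                   session_id="demo"))
    #     print(rag.query_graph("How do transformers work?"))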

    def add_edge_based_on_similarity_and_relevance(self, node1_id: str, node2_id: str, query: str, threshold: float = 0.8):
        """Add edges based on node similarity and relevance."""
        try:
            with self.transaction() as tx:
                # Get node data atomically
                result = tx.run(
                    """
                    MATCH (n1:Node {id: $node1_id})
                    WITH n1
                    MATCH (n2:Node {id: $node2_id})
                    RETURN n1.embedding as emb1, n1.data as data1,
                           n2.embedding as emb2, n2.data as data2
                    """,
                    node1_id=node1_id,
                    node2_id=node2_id
                )
                data = result.single()
                if not data or not all([data["emb1"], data["emb2"], data["data1"], data["data2"]]):
                    return
                # Calculate similarities and relevance
                similarity = self.cosine_similarity(data["emb1"], data["emb2"])
                query_relevance1 = self.calculate_relevance(query, data["data1"])
                query_relevance2 = self.calculate_relevance(query, data["data2"])
                node_relevance = self.calculate_relevance(data["data1"], data["data2"])
                # Combine into a single weight (simple average of the four scores)
                weight = (similarity + query_relevance1 + query_relevance2 + node_relevance) / 4
                # Add an edge if the weight exceeds the threshold
                if weight >= threshold:
                    tx.run(
                        """
                        MATCH (a:Node {id: $node1_id}), (b:Node {id: $node2_id})
                        MERGE (a)-[r:RELATION {type: 'similarity_and_relevance'}]->(b)
                        ON CREATE SET r.weight = $weight
                        ON MATCH SET r.weight = $weight
                        """,
                        node1_id=node1_id,
                        node2_id=node2_id,
                        weight=weight
                    )
                    print(f"Added edge between {node1_id} and {node2_id} with type similarity_and_relevance and weight {weight}")
        except Exception as e:
            print(f"Error in similarity edge creation between {node1_id} and {node2_id}: {str(e)}")
            raise

    def calculate_relevance(self, data1: str, data2: str) -> float:
        """Calculate relevance between two pieces of text using BERTScore F1."""
        try:
            if not data1 or not data2:
                return 0.0
            P, R, F1 = self.scorer.score([data1], [data2])
            return F1.mean().item()
        except Exception as e:
            print(f"Error calculating relevance: {str(e)}")
            return 0.0

    def calculate_query_similarity(self, query1: str, query2: str) -> float:
        """Calculate similarity between two queries."""
        try:
            # Generate embeddings
            embedding1 = self.model.encode(query1).tolist()
            embedding2 = self.model.encode(query2).tolist()
            # Calculate cosine similarity
            return self.cosine_similarity(embedding1, embedding2)
        except Exception as e:
            print(f"Error calculating query similarity: {str(e)}")
            return 0.0
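
    # cosine_similarity is called throughout this class but its definition is
    # not visible in this excerpt; the method below is a minimal NumPy-based
    # sketch of that helper (an assumption, kept here for completeness).
    def cosine_similarity(self, vec1, vec2) -> float:
        """Cosine similarity between two embedding vectors (sketch)."""
        a = np.asarray(vec1, dtype=float)
        b = np.asarray(vec2, dtype=float)
        denom = float(np.linalg.norm(a) * np.linalg.norm(b))
        return float(np.dot(a, b) / denom) if denom else 0.0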

    def get_similarities_and_relevance(self, threshold: float = 0.8) -> list:
        """Get similarities and relevance between nodes."""
        try:
            with self.transaction() as tx:
                # Get all nodes except the root
                result = tx.run(
                    """
                    MATCH (n:Node)
                    WHERE n.id <> $root_id
                    RETURN n.id as id, n.embedding as embedding, n.data as data
                    """,
                    root_id=self.root_node_id
                )
                nodes = list(result)
                similarities = []
                # Calculate similarities between each pair
                for i, node1 in enumerate(nodes):
                    for node2 in nodes[i + 1:]:
                        similarity = self.cosine_similarity(node1["embedding"], node2["embedding"])
                        relevance = self.calculate_relevance(node1["data"], node2["data"])
                        # Combine into a single weight
                        weight = (similarity + relevance) / 2
                        # Add to the results if the weight meets the threshold
                        if weight >= threshold:
                            similarities.append({
                                'node1': node1["id"],
                                'node2': node2["id"],
                                'similarity': similarity,
                                'relevance': relevance,
                                'weight': weight
                            })
                return similarities
        except Exception as e:
            print(f"Error getting similarities and relevance: {str(e)}")
            return []

    def get_node_relationships(self, node_id=None, depth=None, role=None, relationship_type=None):
        """Get relationships between nodes, with filtering options."""
        try:
            with self.transaction() as tx:
                # Build the base query
                cypher_query = """
                    MATCH (n:Node)
                    WHERE n.id <> $root_id
                    AND n.graph_id = $current_graph_id
                """
                params = {
                    "root_id": self.root_node_id,
                    "current_graph_id": self.current_graph_id
                }
                # Add filters
                if node_id:
                    cypher_query += " AND n.id = $node_id"
                    params["node_id"] = node_id
                if role:
                    cypher_query += " AND n.role = $role"
                    params["role"] = role
                if depth is not None:
                    cypher_query += " AND n.depth = $depth"
                    params["depth"] = depth
                # First, get the outgoing relationships
                cypher_query += """
                    WITH n
                    OPTIONAL MATCH (n)-[r1:RELATION]->(m1:Node)
                    WHERE m1.id <> $root_id
                    AND m1.graph_id = $current_graph_id
                """
                # Add the relationship-type filter if specified
                if relationship_type:
                    cypher_query += " AND r1.type = $rel_type"
                    params["rel_type"] = relationship_type
                # Then get the incoming relationships in a separate match
                cypher_query += """
                    WITH n, collect({source: n.id, target: m1.id, weight: r1.weight, type: r1.type}) as out_edges
                    OPTIONAL MATCH (n)<-[r2:RELATION]-(m2:Node)
                    WHERE m2.id <> $root_id
                    AND m2.graph_id = $current_graph_id
                """
                # Apply the same relationship-type filter to incoming edges
                if relationship_type:
                    cypher_query += " AND r2.type = $rel_type"
                # Return both collections
                cypher_query += """
                    RETURN n.id as node_id,
                           collect({source: m2.id, target: n.id, weight: r2.weight, type: r2.type}) as in_edges,
                           out_edges
                """
                result = tx.run(cypher_query, params)
                relationships = {}
                for record in result:
                    node_id = record["node_id"]
                    relationships[node_id] = {
                        'in_edges': [(edge['source'], edge['target'], {
                            'weight': edge['weight'],
                            'type': edge['type']
                        }) for edge in record["in_edges"] if edge['source'] is not None],
                        'out_edges': [(edge['source'], edge['target'], {
                            'weight': edge['weight'],
                            'type': edge['type']
                        }) for edge in record["out_edges"] if edge['target'] is not None]
                    }
                return relationships
        except Exception as e:
            print(f"Error getting node relationships: {str(e)}")
            raise

    def find_nodes_by_properties(self, query: str = None, embedding: list = None,
                                 node_data: dict = None, similarity_threshold: float = 0.8) -> list:
        """Find nodes based on properties."""
        try:
            with self.transaction() as tx:
                where_conditions = []
                params = {}
                # Build query conditions
                if query:
                    where_conditions.append("n.query CONTAINS $node_query")
                    params["node_query"] = query
                if node_data:
                    for key, value in node_data.items():
                        where_conditions.append(f"n.{key} = ${key}")
                        params[key] = value
                # Construct the base query
                cypher_query = "MATCH (n:Node)"
                if where_conditions:
                    cypher_query += " WHERE " + " AND ".join(where_conditions)
                cypher_query += " RETURN n"
                result = tx.run(cypher_query, params)
                matching_nodes = []
                # Process results and calculate similarities
                for record in result:
                    node = record["n"]
                    match_score = 0
                    matches = 0
                    # Score based on property matches
                    if query and query.lower() in node["query"].lower():
                        match_score += 1
                        matches += 1
                    # Score based on embedding similarity
                    if embedding and "embedding" in node:
                        similarity = self.cosine_similarity(embedding, node["embedding"])
                        if similarity >= similarity_threshold:
                            match_score += similarity
                            matches += 1
                    # Score based on node_data matches
                    if node_data:
                        data_matches = sum(1 for k, v in node_data.items()
                                           if k in node and node[k] == v)
                        if data_matches > 0:
                            match_score += data_matches / len(node_data)
                            matches += 1
                    # Add to the results if any match was found
                    if matches > 0:
                        matching_nodes.append({
                            "node_id": node["id"],
                            "score": match_score / matches,
                            "data": dict(node)
                        })
                # Sort by score, best match first
                matching_nodes.sort(key=lambda x: x["score"], reverse=True)
                return matching_nodes
        except Exception as e:
            print(f"Error finding nodes by properties: {str(e)}")
            raise

    def query_graph(self, query: str) -> str:
        """Query the graph for a specific query, collecting data from the entire relevant subgraph."""
        try:
            with self.transaction() as tx:
                # Find the query node
                query_node = tx.run("""
                    MATCH (n:Node {query: $node_query})
                    WHERE n.graph_id = $graph_id
                    RETURN n
                """, node_query=query, graph_id=self.current_graph_id).single()
                if not query_node:
                    raise ValueError(f"Query node not found for: {query}")
                query_node_id = query_node['n']['id']
                datas = []
                # Get the entire subgraph, including all relationship types and independent nodes
                subgraph_paths = tx.run("""
                    // First get the query node and all its connected paths
                    MATCH path = (n:Node {id: $node_id})-[r:RELATION*0..]->(m:Node)
                    WHERE n.graph_id = $graph_id
                    // Collect all nodes and relationships in these paths
                    WITH COLLECT(path) as paths
                    UNWIND paths as path
                    WITH DISTINCT path
                    // Get all nodes and relationships from the paths
                    WITH nodes(path) as nodes, relationships(path) as rels
                    // Calculate path weight considering all relationship types
                    WITH nodes, rels,
                         reduce(weight = 1.0, rel in rels |
                             CASE rel.type
                                 WHEN 'logical' THEN weight * rel.weight * 1.2
                                 WHEN 'hierarchical' THEN weight * rel.weight * 1.1
                                 WHEN 'similarity_and_relevance' THEN weight * rel.weight * 0.9
                                 ELSE weight * rel.weight
                             END
                         ) as path_weight
                    // Unwind nodes to get individual records
                    UNWIND nodes as node
                    WITH DISTINCT node, path_weight
                    WHERE node.data IS NOT NULL
                    AND node.data <> '' // Ensure data is not empty
                    // Return ordered by weight and pagerank for better context flow
                    RETURN node.data as data,
                           path_weight,
                           node.role as role,
                           node.pagerank as pagerank
                    ORDER BY
                        CASE node.role
                            WHEN 'pre-requisite' THEN 3
                            WHEN 'independent' THEN 2
                            ELSE 1
                        END DESC,
                        path_weight DESC,
                        pagerank DESC
                """, node_id=query_node_id, graph_id=self.current_graph_id)
                # Collect data in the order returned (already optimally sorted)
                for record in subgraph_paths:
                    data = record["data"]
                    if data and isinstance(data, str):
                        datas.append(data.strip())
                # If no data is found, return an empty string
                if datas == []:
                    print(f"No data found for: {query}")
                    return ""
                # Return the combined data
                return "\n\n".join([f"Data {i+1}:\n{data}" for i, data in enumerate(datas)])
        except Exception as e:
            print(f"Error querying graph for specific query: {str(e)}")
            raise
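
    # Retrieval weighting used above, worked through with illustrative
    # numbers: a path QR -(hierarchical, w=0.8)-> SQ1 -(logical, w=0.9)-> SSQ2
    # scores 1.0 * (0.8 * 1.1) * (0.9 * 1.2) ≈ 0.95, so logical chains (x1.2)
    # are boosted over similarity edges (x0.9) when ordering the node data.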

    def prune_edges(self, max_edges: int = 1000):
        """Prune excess edges while preserving node data."""
        try:
            with self.transaction() as tx:
                # Count current edges
                result = tx.run(
                    """
                    MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                    RETURN count(r) AS count
                    """,
                    graphID=self.current_graph_id
                )
                current_edges = result.single()["count"]
                if current_edges > max_edges:
                    # Mark the edges to keep. Relationships cannot carry
                    # labels in Neo4j, so a temporary property is used as the
                    # marker instead.
                    tx.run(
                        """
                        MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                        WITH r
                        ORDER BY r.weight DESC
                        LIMIT $max_edges
                        SET r.keep = true
                        """,
                        graphID=self.current_graph_id,
                        max_edges=max_edges
                    )
                    # Remove the excess (unmarked) edges
                    tx.run(
                        """
                        MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                        WHERE r.keep IS NULL
                        DELETE r
                        """,
                        graphID=self.current_graph_id
                    )
                    # Remove the temporary marker
                    tx.run(
                        """
                        MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                        REMOVE r.keep
                        """,
                        graphID=self.current_graph_id
                    )
                print(f"Pruned edges. Kept top {max_edges} edges by weight.")
        except Exception as e:
            print(f"Error pruning edges: {str(e)}")
            raise
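
    # Design note: pruning is done in two passes (mark the top-k edges by
    # weight, then delete the unmarked rest) because Cypher has no single
    # "delete all but top-k" statement.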

    def update_pagerank(self):
        """Update PageRank values using Neo4j's graph algorithms."""
        if not self.current_graph_id:
            print("No current graph selected. Cannot compute PageRank.")
            return
        try:
            with self.transaction() as tx:
                # Create a graph projection with weighted relationships
                tx.run(
                    """
                    CALL gds.graph.project.cypher(
                        'graphProjection',
                        'MATCH (n:Node) WHERE n.graph_id = $myParam RETURN id(n) AS id',
                        'MATCH (n:Node)-[r:RELATION]->(m:Node)
                         WHERE n.graph_id = $myParam AND m.graph_id = $myParam
                         RETURN id(n) AS source,
                                id(m) AS target,
                                CASE r.type
                                    WHEN "logical" THEN r.weight * 2
                                    ELSE r.weight
                                END AS weight',
                        { parameters: { myParam: $graphId } }
                    )
                    """,
                    graphId=self.current_graph_id
                )
                # Run PageRank with relationship weights
                tx.run(
                    """
                    CALL gds.pageRank.write(
                        'graphProjection',
                        {
                            relationshipWeightProperty: 'weight',
                            writeProperty: 'pagerank',
                            maxIterations: 20,
                            dampingFactor: 0.85,
                            concurrency: 4
                        }
                    )
                    """
                )
                # Clean up the projection
                tx.run("CALL gds.graph.drop('graphProjection')")
            print("PageRank updated successfully")
        except Exception as e:
            print(f"Error updating PageRank: {str(e)}")
            raise
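
    # Note: the calls above require the Neo4j Graph Data Science (GDS) plugin
    # to be installed on the server. A quick way to inspect the results
    # (illustrative Cypher):
    #
    #     MATCH (n:Node {graph_id: $graphId})
    #     RETURN n.id, n.pagerank ORDER BY n.pagerank DESC LIMIT 10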
    def display_graph(self, query: str):
        """Render the graph for a query as an HTML string (via PyVis)."""
        try:
            with self.transaction() as tx:
                # 1. Find the graph_id(s) of the nodes matching the provided query
                cypher_query = """
                    MATCH (n:Node)
                    WHERE n.query = $node_query
                    RETURN COLLECT(DISTINCT n.graph_id) AS graph_ids
                """
                result = tx.run(cypher_query, node_query=query)
                graph_ids = result.single().get("graph_ids", [])
                if not graph_ids:
                    print("No graph found for the given query.")
                    return

                # 2. Create the PyVis network once, so all data can be added to it
                net = Network(
                    height="600px",
                    width="100%",
                    directed=True,
                    bgcolor="#222222",
                    font_color="white"
                )

                # Disable physics initially
                net.options = {"physics": {"enabled": False}}

                all_nodes = set()
                all_edges = []
                for graph_id in graph_ids:
                    # 3. Fetch the graph data for this graph_id
                    # (parameterized to avoid Cypher injection)
                    result = tx.run(
                        "MATCH (n)-[r]->(m) WHERE n.graph_id = $graph_id RETURN n, r, m",
                        graph_id=graph_id
                    )
                    for record in result:
                        source_node = record["n"]
                        target_node = record["m"]
                        relationship = record["r"]
                        source_id = source_node.get("id")
                        target_id = target_node.get("id")

                        # Build a descriptive tooltip for each node
                        source_tooltip = f"Query: {source_node.get('query', 'N/A')}"
                        target_tooltip = f"Query: {target_node.get('query', 'N/A')}"

                        # Add the source node if not already in the set
                        if source_id not in all_nodes:
                            net.add_node(
                                source_id,
                                label=source_id,
                                title=source_tooltip,
                                size=20,
                                color="#00cc66"
                            )
                            all_nodes.add(source_id)

                        # Add the target node if not already in the set
                        if target_id not in all_nodes:
                            net.add_node(
                                target_id,
                                label=target_id,
                                title=target_tooltip,
                                size=20,
                                color="#00cc66"
                            )
                            all_nodes.add(target_id)

                        # Collect the edge
                        all_edges.append({
                            "from": source_id,
                            "to": target_id,
                            "label": relationship.type,
                        })

                # 4. Add all edges
                for edge in all_edges:
                    net.add_edge(
                        edge["from"],
                        edge["to"],
                        title=edge["label"],
                        color="#cccccc"
                    )

                # 5. Enable improved layout and node dragging
                net.options["layout"] = {"improvedLayout": True}
                net.options["interaction"] = {"dragNodes": True}

                # 6. Save to a temporary file, read it back, then remove the file
                net.save_graph("temp_graph.html")
                with open("temp_graph.html", "r", encoding="utf-8") as f:
                    html_str = f.read()
                os.remove("temp_graph.html")  # Clean up the temp file
                return html_str

        except Exception as e:
            print(f"Error displaying graph: {str(e)}")
            raise
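    # Illustrative usage (hypothetical caller): the method returns raw HTML,
    # which can be written to disk or embedded in any UI component that
    # accepts an HTML string:
    #
    #     html = graph_rag.display_graph("What are the causes of inflation?")
    #     if html:
    #         with open("graph.html", "w", encoding="utf-8") as f:
    #             f.write(html)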
    def verify_graph_integrity(self):
        """Verify and fix graph integrity issues."""
        try:
            with self.transaction() as tx:
                # Check for orphaned nodes (no RELATION in either direction)
                orphaned = tx.run(
                    """
                    MATCH (n:Node {graph_id: $graph_id})
                    WHERE NOT (n)-[:RELATION]-()
                    RETURN n.id as node_id
                    """,
                    graph_id=self.current_graph_id
                ).values()
                if orphaned:
                    print(f"Found orphaned nodes: {orphaned}")

                # Check for edges that point outside the current graph
                invalid_edges = tx.run(
                    """
                    MATCH (a:Node)-[r:RELATION]->(b:Node)
                    WHERE a.graph_id = $graph_id
                    AND (b.graph_id <> $graph_id OR b.graph_id IS NULL)
                    RETURN a.id as from_id, b.id as to_id
                    """,
                    graph_id=self.current_graph_id
                ).values()
                if invalid_edges:
                    print(f"Found invalid edges: {invalid_edges}")
                    # Fix the issue by deleting the invalid edges
                    tx.run(
                        """
                        MATCH (a:Node)-[r:RELATION]->(b:Node)
                        WHERE a.graph_id = $graph_id
                        AND (b.graph_id <> $graph_id OR b.graph_id IS NULL)
                        DELETE r
                        """,
                        graph_id=self.current_graph_id
                    )

                print("Graph integrity verified successfully")
                return True

        except Exception as e:
            print(f"Error verifying graph integrity: {str(e)}")
            raise
    def verify_graph_consistency(self):
        """Verify consistency of the Neo4j graph."""
        try:
            with self.driver.session() as session:
                # Check for nodes without required properties
                missing_props = session.run("""
                    MATCH (n:Node)
                    WHERE n.id IS NULL OR n.query IS NULL
                    RETURN count(n) as count
                """)
                if missing_props.single()["count"] > 0:
                    raise ValueError("Found nodes with missing required properties")

                # Check for relationship consistency
                invalid_rels = session.run("""
                    MATCH ()-[r:RELATION]->()
                    WHERE r.type IS NULL OR r.weight IS NULL
                    RETURN count(r) as count
                """)
                if invalid_rels.single()["count"] > 0:
                    raise ValueError("Found relationships with missing required properties")

                print("Graph consistency verified successfully")
                return True

        except Exception as e:
            print(f"Error verifying graph consistency: {str(e)}")
            raise
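    # Illustrative sketch (an assumption: initialize_schema() may already
    # cover this): a composite key constraint makes the consistency checks
    # above enforceable up front rather than after the fact. NODE KEY
    # constraints require Neo4j Enterprise Edition; on Community Edition a
    # single-property uniqueness constraint is the closest fallback.
    def ensure_node_key_constraint(self):
        """Create a (graph_id, id) node key constraint if it does not exist."""
        with self.driver.session() as session:
            session.run(
                """
                CREATE CONSTRAINT node_graph_key IF NOT EXISTS
                FOR (n:Node) REQUIRE (n.graph_id, n.id) IS NODE KEY
                """
            )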
    async def close(self):
        """Properly clean up all resources."""
        try:
            # Shut down the executor
            if hasattr(self, 'executor'):
                self.executor.shutdown(wait=True)

            # Close the Neo4j driver
            if hasattr(self, 'driver'):
                self.driver.close()

            # Clean up crawler resources and browser contexts. __init__ creates
            # `custom_crawler` (the old `crawler` attribute is commented out),
            # and `session_id` may not have been set yet, so guard both.
            if hasattr(self, 'custom_crawler'):
                await asyncio.shield(self.custom_crawler.cleanup_expired_sessions())
                session_id = getattr(self, 'session_id', None)
                if session_id:
                    await asyncio.shield(self.custom_crawler.cleanup_browser_context(session_id))

        except Exception as e:
            print(f"Error during cleanup: {e}")
def cosine_similarity(v1: List[float], v2: List[float]) -> float:
    """Calculate the cosine similarity between two vectors."""
    try:
        v1_array = np.array(v1)
        v2_array = np.array(v2)
        norm_product = np.linalg.norm(v1_array) * np.linalg.norm(v2_array)
        if norm_product == 0:
            return 0.0  # Guard against division by zero for zero-length vectors
        return float(np.dot(v1_array, v2_array) / norm_product)
    except Exception as e:
        print(f"Error calculating cosine similarity: {str(e)}")
        return 0.0
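# Sanity check: vectors pointing the same way score 1.0, orthogonal
# vectors score 0.0:
#
#     cosine_similarity([1.0, 1.0], [2.0, 2.0])  # -> 1.0
#     cosine_similarity([1.0, 0.0], [0.0, 1.0])  # -> 0.0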
if __name__ == "__main__":
    from dotenv import load_dotenv
    from src.reasoning.reasoner import Reasoner
    from src.evaluation.evaluator import Evaluator

    load_dotenv()
    graph_search = Neo4jGraphRAG(num_workers=24)
    evaluator = Evaluator()
    reasoner = Reasoner()

    async def test_graph_search():
        # Sample data for testing
        queries = [
            """In the context of global economic recovery and energy security concerns, provide an in-depth comparative assessment of the renewable energy policies among G20 countries.
            Specifically, examine how short-term economic stimulus measures intersect with long-term decarbonization commitments, including:
            1. Carbon pricing mechanisms
            2. Subsidies for emerging technologies (such as green hydrogen and battery storage)
            3. Cross-border climate finance initiatives
            Highlight the unique challenges faced by both advanced and emerging economies in addressing:
            1. Energy poverty
            2. Supply chain disruptions
            3. Geopolitical tensions (e.g., the Russia-Ukraine conflict)
            Discuss how these factors influence policy effectiveness, and evaluate the degree to which each country is on track to meet—or exceed—its Paris Agreement targets.
            Note any significant policy gaps, regional collaborations, or innovative best practices.
            Lastly, provide a forward-looking perspective on how these renewable energy strategies may evolve over the next decade, considering:
            1. Technological breakthroughs
            2. Global market trends
            3. Potential climate-related disasters
            Present your analysis as a detailed, well-formatted report.""",
            """Analyse the impact of 'hot-money' on the value of the Indian Rupee and answer the following questions:-
            1. How does it affect the exchange rate?
            2. How can it be mitigated/eliminated?
            3. Why is it a problem?
            4. What are the consequences?
            5. What are the alternatives?
                - Evaluate the alternatives for pros and cons.
                - Evaluate the impact of alternatives on the exchange rate.
                - How can they be implemented?
                - What are the consequences of each alternative?
                - Evaluate the feasibility of the alternatives.
                - Pick top 5 alternatives and justify your choices in detail.
            6. What are the implications for the Indian economy? Furthermore:-
                - Evaluate the impact of the chosen alternatives on the Indian economy.""",
            """Inflation has been an intrinsic part of human civilization since the very beginning. Answer the following questions:-
            1. How true is the above statement?
            2. What are the causes of inflation?
            3. What are the consequences of inflation?
            4. Can we completely eliminate inflation?""",
            """Perform a detailed comparison between the ancient Greek and Roman civilizations.
            1. What were the key differences between the two civilizations?
                - Evaluate the differences in governance, society, and culture
                - Evaluate the differences in economy, trade, and military
                - Evaluate the differences in technology and infrastructure
            2. What were the similarities between the two civilizations?
                - Evaluate the similarities in governance, society, and culture
                - Evaluate the similarities in economy, trade, and military
                - Evaluate the similarities in technology and infrastructure
            3. How did these two civilizations influence each other?
                - Evaluate the influence of one civilization on the other
            4. How did these two civilizations influence the modern world?
            5. Was there another civilization that influenced these two? If yes, how?""",
            """Evaluate the long-term effects of colonialism on economic development in Asia:-
            1. Include case studies of at least five different countries
            2. Analyze how these effects differ based on colonial power, time of independence, and resource distribution
                - Evaluate the impact of colonialism on the economy of the country
                - Evaluate the impact of colonialism on the economy of the region
                - Evaluate the impact of colonialism on the economy of the world
            3. How do these effects compare to Africa?"""
        ]

        follow_on_queries = [
            "How is 'hot-money' related to the current economic situation in India?",
            "What is inflation?",
            "Did ancient Greece and Rome have any impact on modern democracy? If yes, how?",
            "Did colonialism have any impact on the trade between Africa and Asia, both in colonial and post-colonial times? If yes, how?"
        ]

        query = queries[2]

        # Initialize the database schema
        graph_search.initialize_schema()

        # Build the graph in Neo4j
        await graph_search.process_graph(query, similarity_threshold=0.8, relevance_threshold=0.8)

        # Query the graph and generate a response
        answer = graph_search.query_graph(query)
        response = ""
        async for chunk in reasoner.reason(query, answer):
            response += chunk
            # Stream only the new chunk; printing `response` here would
            # re-print everything accumulated so far on every iteration
            print(chunk, end="", flush=True)

        # Display the graph
        graph_search.display_graph(query)

        # Evaluate the response
        evaluation = await evaluator.evaluate_response(query, response, [answer])
        print(f"Faithfulness: {evaluation['faithfulness']}")
        print(f"Answer Relevancy: {evaluation['answer relevancy']}")
        print(f"Contextual Recall: {evaluation['contextual recall']}")

        # Shut down the executor after all tasks are complete
        await graph_search.close()

    # Run the test function
    asyncio.run(test_graph_search())