seekr / src /rag /graph_rag.py
Hemang Thakur
made changes to graph rag and main files
85f093d
import os
import gc
import time
import asyncio
import torch
import uuid
import rustworkx as rx
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any
from pyvis.network import Network
from src.query_processing.late_chunking.late_chunker import LateChunker
from src.query_processing.query_processor import QueryProcessor
from src.reasoning.reasoner import Reasoner
from src.utils.api_key_manager import APIKeyManager
from src.search.search_engine import SearchEngine
from src.crawl.crawler import CustomCrawler #, Crawler
from sentence_transformers import SentenceTransformer
from bert_score.scorer import BERTScorer
from tenacity import RetryError
from openai import RateLimitError
from anthropic import RateLimitError as AnthropicRateLimitError
from google.api_core.exceptions import ResourceExhausted
class GraphRAG:
def __init__(self, num_workers: int = 1):
    """Set up graph storage, pipeline components, models and bookkeeping.

    Args:
        num_workers: Number of threads in the executor; the semaphore
            limit is twice this value, capped at 12.
    """
    # Registry of graphs keyed by graph id, plus the id currently in use
    self.graphs = {}
    self.current_graph_id = None

    # Pipeline components
    self.num_workers = num_workers
    self.search_engine = SearchEngine()
    self.query_processor = QueryProcessor()
    self.reasoner = Reasoner()
    # self.crawler = Crawler(verbose=True)
    self.custom_crawler = CustomCrawler(max_concurrent_requests=1000)
    self.chunking = LateChunker()
    self.llm = APIKeyManager().get_llm()

    # Both models use the GPU when torch reports one, otherwise the CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    self.model = SentenceTransformer(
        "dunzhang/stella_en_400M_v5",
        trust_remote_code=True,
        device=device,
    )
    self.scorer = BERTScorer(
        model_type="roberta-base",
        lang="en",
        rescale_with_baseline=True,
        device=device,
    )

    # Node id of the root query and running counters for generated ids
    self.root_node_id = "QR"
    self.node_counter = 0
    self.sub_node_counter = 0
    # Sorted id pairs that already received a cross-connection edge
    self.cross_connections = set()

    # Concurrency primitives
    self.semaphore = asyncio.Semaphore(min(num_workers * 2, 12))
    self.executor = ThreadPoolExecutor(max_workers=self.num_workers)

    # Optional event callback, installed via set_on_event_callback
    self.on_event_callback = None
def set_on_event_callback(self, callback):
    """Install the one callback that receives every emitted event.

    A later call replaces whatever callback was registered before.

    Args:
        callback: Callable invoked by emit_event as
            ``callback(event_type, data)``.
    """
    self.on_event_callback = callback
async def emit_event(self, event_type: str, data: dict):
    """Invoke the registered event callback, if any, and return its result.

    The callback is called as ``callback(event_type, data)``. When the call
    produces an awaitable (a coroutine function, an object with an async
    ``__call__``, or a plain function that returns a coroutine) the
    awaitable is awaited, so callers always receive the final value.

    Args:
        event_type: Name of the event being emitted.
        data: Payload handed to the callback.

    Returns:
        The callback's (awaited) return value, or None when no callback is
        registered.
    """
    import inspect

    callback = self.on_event_callback
    if not callback:
        return None

    outcome = callback(event_type, data)
    # Inspect the result rather than the callable: checking the callable
    # with iscoroutinefunction misses async callable objects and sync
    # wrappers around coroutines, which would leak an un-awaited coroutine.
    if inspect.isawaitable(outcome):
        return await outcome
    return outcome
def _get_current_graph_data(self):
    """Return the stored record of the currently selected graph.

    Raises:
        Exception: If no graph is selected or the selected id is not in
            ``self.graphs``.
    """
    graph_id = self.current_graph_id
    if graph_id is not None and graph_id in self.graphs:
        return self.graphs[graph_id]
    raise Exception("Error: No current graph selected")
def add_node(self, node_id: str, query: str, data: str = "", role: str = None):
    """Create a node in the current graph and index it by ``node_id``.

    The node payload stores the id, query text, data, role, the query's
    embedding (as a list), an initial pagerank of 0 and the owning graph id.

    Args:
        node_id: Identifier used as the key in the graph's node map.
        query: Query text; also the input of the embedding model.
        data: Content attached to the node.
        role: Role label of the node.
    """
    current = self._get_current_graph_data()
    graph = current["graph"]
    index_by_id = current["node_map"]

    # The embedding is computed from the query text only
    vector = self.model.encode(query).tolist()

    payload = {
        "id": node_id,
        "query": query,
        "data": data,
        "role": role,
        "embedding": vector,
        "pagerank": 0,
        "graph_id": self.current_graph_id,
    }
    index_by_id[node_id] = graph.add_node(payload)
    print(f"Added node '{node_id}' to graph '{self.current_graph_id}' with role '{role}' and query: '{query}'")
def _has_path(self, source_idx: int, target_idx: int) -> bool:
    """Report whether ``target_idx`` can be reached from ``source_idx``.

    Performs an iterative depth-first walk over ``graph.neighbors`` of the
    current graph. A node is considered reachable from itself.
    """
    graph = self._get_current_graph_data()["graph"]
    seen = set()
    pending = [source_idx]
    while pending:
        node = pending.pop()
        if node == target_idx:
            return True
        if node not in seen:
            seen.add(node)
            pending.extend(graph.neighbors(node))
    return False
def add_edge(self, node1: str, node2: str, weight: float = 1.0, relationship_type: str = None):
    """Add a directed edge ``node1 -> node2`` unless it would close a cycle.

    The call is a no-op (with a printed message) when both ids are equal,
    when either id is missing from the current graph, or when ``node1`` is
    already reachable from ``node2``.

    Args:
        node1: Id of the source node.
        node2: Id of the target node.
        weight: Edge weight stored in the edge payload; 0.0 is accepted.
        relationship_type: Edge type stored in the edge payload.

    Raises:
        Exception: If no graph is currently selected.
        ValueError: If ``relationship_type`` is empty/None or ``weight`` is
            None.
    """
    if self.current_graph_id is None:
        raise Exception("Error: No current graph selected")
    if node1 == node2:
        print(f"Cannot add edge to the same node {node1}!")
        return

    graph_data = self._get_current_graph_data()
    graph = graph_data["graph"]
    node_map = graph_data["node_map"]
    if node1 not in node_map or node2 not in node_map:
        print(f"One or both nodes {node1}, {node2} do not exist in the current graph.")
        return

    idx1 = node_map[node1]
    idx2 = node_map[node2]
    # A path from node2 back to node1 means the new edge would form a cycle
    if self._has_path(idx2, idx1):
        print(f"An edge between {node1} -> {node2} already exists or would create a cycle!")
        return

    # Only a missing weight is an error; a truthiness test would also
    # reject the valid weight 0.0.
    if not relationship_type or weight is None:
        raise ValueError("Error: Relationship type and weight must be provided")

    graph.add_edge(idx1, idx2, {"type": relationship_type, "weight": weight})
    print(f"Added edge between '{node1}' and '{node2}' in graph '{self.current_graph_id}' (type='{relationship_type}', weight={weight})")
def edge_exists(self, node1: str, node2: str) -> bool:
    """Tell whether the current graph has a direct edge ``node1 -> node2``.

    Returns False when either id is not part of the current graph.
    """
    current = self._get_current_graph_data()
    graph = current["graph"]
    index_by_id = current["node_map"]
    if not (node1 in index_by_id and node2 in index_by_id):
        return False

    target = index_by_id[node2]
    # Each outgoing edge is indexed so that position 1 is its target node
    return any(edge[1] == target for edge in graph.out_edges(index_by_id[node1]))
def graph_exists(self) -> bool:
    """Tell whether a graph is selected, known, and holds at least one node."""
    graph_id = self.current_graph_id
    if graph_id is None or graph_id not in self.graphs:
        return False
    return len(self.graphs[graph_id]["node_map"]) > 0
def get_graphs(self) -> list:
    """Summarise every stored graph, most recently created first.

    Returns:
        A list of ``{"graph_info": {...}}`` dicts holding the graph id, the
        created/updated metadata values, node and edge counts, and for each
        node its id, query, data, role and pagerank.
    """
    node_fields = ("id", "query", "data", "role", "pagerank")
    summaries = []
    for graph_id, record in self.graphs.items():
        metadata = record["metadata"]
        index_by_id = record["node_map"]
        graph = record["graph"]

        node_summaries = []
        for idx in index_by_id.values():
            payload = graph.get_node_data(idx)
            node_summaries.append({field: payload.get(field) for field in node_fields})

        info = {
            "graph_id": graph_id,
            "created": metadata.get("created"),
            "updated": metadata.get("updated"),
            "node_count": len(index_by_id),
            "edge_count": len(graph.edge_list()),
            "nodes": node_summaries,
        }
        summaries.append({"graph_info": info})

    return sorted(summaries, key=lambda entry: entry["graph_info"]["created"], reverse=True)
def select_graph(self, graph_id: str) -> bool:
    """Make ``graph_id`` the working graph when it is known.

    Returns:
        True when the graph exists and was selected, otherwise False (the
        current selection is left unchanged).
    """
    known = graph_id in self.graphs
    if known:
        self.current_graph_id = graph_id
    return known
def create_new_graph(self) -> str:
    """Register an empty directed graph under a fresh UUID and select it.

    Returns:
        The id of the newly created graph.
    """
    new_id = str(uuid.uuid4())
    self.graphs[new_id] = {
        "graph": rx.PyDiGraph(),
        "node_map": {},
        "metadata": {
            "id": new_id,
            "created": time.time(),
            "updated": time.time(),
        },
    }
    self.current_graph_id = new_id
    return new_id
def load_graph(self, node_id: str) -> bool:
    """Select the stored graph that contains ``node_id`` and restore counters.

    After selecting the graph, the id counters are raised to the highest
    number found among the graph's SQ / SSQ node ids (then incremented by
    one), and every 'logical' edge is recorded in ``self.cross_connections``
    as a sorted pair of node ids.

    Args:
        node_id: Id of any node that belongs to the graph to load.

    Returns:
        True when a graph containing ``node_id`` was found, otherwise False.
    """
    for gid, data in self.graphs.items():
        if node_id not in data["node_map"]:
            continue

        self.current_graph_id = gid
        for n_id in data["node_map"]:
            # "SSQ" must be tested first: every SSQ id also contains "SQ",
            # so the opposite order would never reach the SSQ branch.
            if "SSQ" in n_id:
                num = int(''.join(filter(str.isdigit, n_id)) or 0)
                self.sub_node_counter = max(self.sub_node_counter, num)
            elif "SQ" in n_id:
                num = int(''.join(filter(str.isdigit, n_id)) or 0)
                self.node_counter = max(self.node_counter, num)
        self.node_counter += 1
        self.sub_node_counter += 1

        graph = data["graph"]
        # edge_list() yields (source, target) pairs, edges() the payloads;
        # they are zipped so each pair is matched with its payload.
        for (u, v), edge_data in zip(graph.edge_list(), graph.edges()):
            if edge_data.get("type") == "logical":
                source_id = graph.get_node_data(u).get("id")
                target_id = graph.get_node_data(v).get("id")
                self.cross_connections.add(tuple(sorted([source_id, target_id])))

        print(f"Successfully loaded graph. Current counters - Node: {self.node_counter}, Sub: {self.sub_node_counter}")
        return True

    print(f"Graph with node_id {node_id} not found.")
    return False
async def modify_graph(self, new_query: str, similar_node_id: str, session_id: str = None):
    """Attach ``new_query`` to the current graph next to or below a similar node.

    When the similar node has the role "independent" and has a parent, the
    new node becomes a sibling (a child of the same parent, reusing the
    parent edge's type). Otherwise it becomes a child connected through a
    'logical' edge. The new node is then expanded with ``build_graph``.

    Args:
        new_query: Query text of the node to add.
        similar_node_id: Id of the existing node the new query resembles.
        session_id: Passed through to ``build_graph``.

    Raises:
        Exception: If ``similar_node_id`` is not in the current graph.
        ValueError: If a sibling is requested for a node that is unknown or
            has no parent.
    """
    graph_data = self._get_current_graph_data()
    graph = graph_data["graph"]
    node_map = graph_data["node_map"]

    async def add_as_sibling(node_id: str, query: str):
        """Add ``query`` as another child of ``node_id``'s first parent."""
        if node_id not in node_map:
            raise ValueError(f"Node {node_id} not found")
        in_edges = graph.in_edges(node_map[node_id])
        if not in_edges:
            raise ValueError(f"No parent found for node {node_id}")
        parent_idx = in_edges[0][0]
        parent_id = graph.get_node_data(parent_idx).get("id")

        # A sibling shares the level of the node it mirrors. Every SSQ id
        # also contains "SQ", so SSQ ids must be excluded explicitly or an
        # SSQ node's sibling would be numbered as a top-level SQ node.
        if "SQ" in node_id and "SSQ" not in node_id:
            self.node_counter += 1
            new_node_id = f"SQ{self.node_counter}"
        else:
            self.sub_node_counter += 1
            new_node_id = f"SSQ{self.sub_node_counter}"

        self.add_node(new_node_id, query, role="independent")
        self.add_edge(parent_id, new_node_id, relationship_type=in_edges[0][2].get("type"))
        return new_node_id

    async def add_as_child(node_id: str, query: str):
        """Add ``query`` as a dependent child of ``node_id``."""
        if "SQ" in node_id:
            self.sub_node_counter += 1
            new_node_id = f"SSQ{self.sub_node_counter}"
        else:
            self.node_counter += 1
            new_node_id = f"SQ{self.node_counter}"
        self.add_node(new_node_id, query, role="dependent")
        self.add_edge(node_id, new_node_id, relationship_type="logical")
        return new_node_id

    def collect_graph_context() -> list:
        """Collect context from existing graph nodes."""
        current = self._get_current_graph_data()
        current_graph = current["graph"]
        nodes = []
        for n_id, idx in current["node_map"].items():
            if n_id == self.root_node_id:
                continue
            node_data = current_graph.get_node_data(idx)
            nodes.append({
                "id": node_data.get("id"),
                "query": node_data.get("query"),
                "role": node_data.get("role")
            })

        # SQ nodes first, then SSQ nodes, each group ordered by id
        nodes.sort(key=lambda x: (0 if x["id"].startswith("SQ") else (1 if x["id"].startswith("SSQ") else 2), x["id"]))

        level_queries = {}
        current_sq = None
        for node in nodes:
            node_id = node["id"]
            entry = {
                "subquery": node["query"],
                "role": node["role"],
                "dependson": []
            }
            if node_id.startswith("SQ"):
                current_sq = node_id
                if current_sq not in level_queries:
                    level_queries[current_sq] = {
                        "originalquery": node["query"],
                        "subqueries": []
                    }
                level_queries[current_sq]["subqueries"].append(entry)
            elif node_id.startswith("SSQ") and current_sq:
                # SSQ nodes are grouped under the most recently seen SQ node
                level_queries[current_sq]["subqueries"].append(entry)
        return list(level_queries.values())

    if similar_node_id not in node_map:
        raise Exception(f"Node {similar_node_id} not found")

    similar_idx = node_map[similar_node_id]
    similar_node_data = graph.get_node_data(similar_idx)
    has_parent = len(graph.in_edges(similar_idx)) > 0
    # The context is captured before the new node is inserted
    context = collect_graph_context()

    if similar_node_data.get("role") == "independent" and has_parent:
        new_node_id = await add_as_sibling(similar_node_id, new_query)
    else:
        new_node_id = await add_as_child(similar_node_id, new_query)

    # NOTE(review): "SQ" is a substring of both id prefixes, so this
    # expression always yields 1. It is kept as written because build_graph
    # returns immediately for depth > 1 - confirm the intended depth.
    await self.build_graph(
        query=new_query,
        parent_node_id=new_node_id,
        depth=1 if "SQ" in new_node_id else 2,
        context=context,
        session_id=session_id
    )
async def build_graph(self, query: str, data: str = None, parent_node_id: str = None,
depth: int = 0, threshold: float = 0.8, recurse: bool = True,
context: list = None, session_id: str = None, max_tokens_allowed: int = 128000,
node_data_futures: dict = None, sub_nodes_info: list = None,
sub_query_ids: list = None, pre_req_nodes: list = None):
"""Decompose a query into sub-query nodes in the current graph and fetch their content.

Args:
query: Query to decompose; also stored on the root node when one is created.
data: Data stored on the root node when one is created.
parent_node_id: Node the new sub-query nodes are attached to. When None, the
root node is added first and used as the parent.
depth: Decomposition level; a call with depth greater than 1 returns immediately.
threshold: Minimum weight handed to add_edge_based_on_similarity_and_relevance.
recurse: Whether each sub-query is decomposed again at depth + 1.
context: Earlier decomposition responses; handed to the query processor when
depth is not 0 and extended with each new response.
session_id: Forwarded to the crawler when page contents are fetched.
max_tokens_allowed: Token count above which fetched content is passed to the chunker.
node_data_futures: Map of node id to the future that receives the node's content.
sub_nodes_info: Accumulated (node id, sub-query, role, dependency, future, depth) tuples.
sub_query_ids: Accumulated ids of the created sub-query nodes.
pre_req_nodes: Accumulated pre-requisite nodes keyed by their sub-query index.

The four accumulator arguments are created empty when None and are passed on to
the recursive calls.
"""
# Search, filter and crawl for one sub-query, store the combined text on the
# node and resolve the node's future with it.
async def process_node(node_id: str, sub_query: str, session_id: str,
future: asyncio.Future, max_tokens_allowed: int = max_tokens_allowed):
try:
optimized_query = await self.search_engine.generate_optimized_query(sub_query)
results = await self.search_engine.search(
query=optimized_query,
num_results=10,
exclude_filetypes=["pdf"]
)
await self.emit_event("search_results_fetched", {
"node_id": node_id,
"sub_query": sub_query,
"optimized_query": optimized_query,
"search_results": results
})
filtered_urls = await self.search_engine.filter_urls(
sub_query,
"extensive research dynamic structure",
results
)
await self.emit_event("search_results_filtered", {
"node_id": node_id,
"sub_query": sub_query,
"filtered_urls": filtered_urls
})
urls = [result.get('link', 'No URL') for result in filtered_urls]
search_contents = await self.custom_crawler.fetch_page_contents(
urls,
sub_query,
session_id=session_id,
max_attempts=1,
timeout=30
)
await self.emit_event("search_contents_fetched", {
"node_id": node_id,
"sub_query": sub_query,
"contents": search_contents
})
# Join the fetched pages into one numbered document string; entries that
# are exceptions are only reported.
contents = ""
for k, content in enumerate(search_contents, 1):
if isinstance(content, Exception):
print(f"Error fetching content: {content}")
elif content:
contents += f"Document {k}:\n{content}\n\n"
if contents.strip():
token_count = self.llm.get_num_tokens(contents)
# Content over the token limit is reduced by the late chunker.
if token_count > max_tokens_allowed:
contents = await self.chunking.chunker(
text=contents,
query=sub_query,
max_tokens=max_tokens_allowed
)
print(f"Number of tokens in the answer: {token_count}")
print(f"Number of tokens in the content: {self.llm.get_num_tokens(contents)}")
graph_data = self._get_current_graph_data()
graph = graph_data["graph"]
node_map = graph_data["node_map"]
if node_id in node_map:
idx = node_map[node_id]
node_data = graph.get_node_data(idx)
node_data["data"] = contents
if not future.done():
future.set_result(contents)
# Rate-limit and retry errors are stored on the future; other errors are
# stored on the future and raised again.
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e:
print(f"Error processing node {node_id}: {str(e)}")
if not future.done():
future.set_exception(e)
except Exception as e:
print(f"Error processing node {node_id}: {str(e)}")
if not future.done():
future.set_exception(e)
raise e
# Wait for the futures this sub-query depends on, rewrite the query with their
# results, re-embed it, update the node and then process it via process_node.
async def process_dependent_node(node_id: str, sub_query: str, dep_futures: list, future):
try:
dep_data = [await f for f in dep_futures]
modified_query = await self.query_processor.modify_query(
sub_query,
dep_data
)
# The embedding model is run in the thread pool executor.
loop = asyncio.get_running_loop()
embedding = await loop.run_in_executor(
self.executor,
self.model.encode,
modified_query
)
graph_data = self._get_current_graph_data()
graph = graph_data["graph"]
node_map = graph_data["node_map"]
if node_id in node_map:
idx = node_map[node_id]
node_data = graph.get_node_data(idx)
node_data["query"] = modified_query
node_data["embedding"] = embedding.tolist() if hasattr(embedding, "tolist") else embedding
try:
if not future.done():
await process_node(node_id, modified_query, session_id, future, max_tokens_allowed)
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e:
if not future.done():
future.set_exception(e)
except Exception as e:
if not future.done():
future.set_exception(e)
raise e
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e:
print(f"Error processing dependent node {node_id}: {str(e)}")
if not future.done():
future.set_exception(e)
except Exception as e:
print(f"Error processing dependent node {node_id}: {str(e)}")
if not future.done():
future.set_exception(e)
raise e
# For nodes with the role 'dependent', add any missing 'logical' edge along
# their existing logical relationships and remember the pair in
# self.cross_connections.
def create_cross_connections():
try:
relationships = self.get_node_relationships(relationship_type='logical')
for current_node_id, edges in relationships.items():
graph_data = self._get_current_graph_data()
graph = graph_data["graph"]
node_map = graph_data["node_map"]
if current_node_id not in node_map:
continue
idx = node_map[current_node_id]
node_data = graph.get_node_data(idx)
node_role = (node_data.get("role") or "").lower()
if node_role == 'dependent':
for source_id, target_id, edge_data in edges['in_edges']:
if not source_id or source_id == self.root_node_id:
continue
# Pairs are stored sorted so direction does not matter for the lookup.
connection = tuple(sorted([current_node_id, source_id]))
if connection not in self.cross_connections:
if not self.edge_exists(source_id, current_node_id):
print(f"Adding cross-connection edge between {source_id} and {current_node_id}")
self.add_edge(source_id, current_node_id, weight=edge_data.get('weight', 1.0), relationship_type='logical')
self.cross_connections.add(connection)
for source_id, target_id, edge_data in edges['out_edges']:
if not target_id or target_id == self.root_node_id:
continue
connection = tuple(sorted([current_node_id, target_id]))
if connection not in self.cross_connections:
if not self.edge_exists(current_node_id, target_id):
print(f"Adding cross-connection edge between {current_node_id} and {target_id}")
self.add_edge(current_node_id, target_id, weight=edge_data.get('weight', 1.0), relationship_type='logical')
self.cross_connections.add(connection)
except Exception as e:
print(f"Error creating cross connections: {str(e)}")
raise
# Decomposition stops once depth exceeds 1.
if depth > 1:
return
if context is None:
context = []
if node_data_futures is None:
node_data_futures = {}
if sub_nodes_info is None:
sub_nodes_info = []
if sub_query_ids is None:
sub_query_ids = []
# NOTE(review): pre_req_nodes is annotated as a list in the signature but is
# created as a dict here and assigned by index below.
if pre_req_nodes is None:
pre_req_nodes = {}
# Without a parent, the query itself becomes the root node.
if parent_node_id is None:
self.add_node(self.root_node_id, query, data)
parent_node_id = self.root_node_id
intent = await self.query_processor.get_query_intent(query)
# The accumulated context is only handed to the decomposition below depth 0.
if depth == 0:
response_data, sub_queries, roles, dependencies = await self.query_processor.decompose_query_with_dependencies(query, intent)
else:
response_data, sub_queries, roles, dependencies = await self.query_processor.decompose_query_with_dependencies(query, intent, context)
if response_data:
context.append(response_data)
# Expand only when more than one sub-query came back and the first one is not
# the query itself.
if len(sub_queries) > 1 and sub_queries[0] != query:
for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)):
if depth == 0:
await self.emit_event("sub_query_created", {
"depth": depth,
"sub_query": sub_query,
"role": role,
"dependency": dependency,
"parent_node_id": parent_node_id,
})
# Depth-0 children get SQ ids, deeper children get SSQ ids.
if depth == 0:
self.node_counter += 1
sub_node_id = f"SQ{self.node_counter}"
else:
self.sub_node_counter += 1
sub_node_id = f"SSQ{self.sub_node_counter}"
sub_query_ids.append(sub_node_id)
self.add_node(sub_node_id, sub_query, role=role)
future = asyncio.Future()
node_data_futures[sub_node_id] = future
sub_nodes_info.append((sub_node_id, sub_query, role, dependency, future, depth))
if role.lower() in ['pre-requisite', 'prerequisite']:
pre_req_nodes[idx] = sub_node_id
# Pre-requisite and independent nodes hang from the parent through a
# hierarchical edge; dependent nodes get logical edges from their dependencies.
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical')
elif role.lower() == 'dependent':
# A dependency given as two lists is unpacked into indices referring to
# earlier context entries and indices referring to sub_query_ids.
if isinstance(dependency, list) and (len(dependency) == 2 and all(isinstance(d, list) for d in dependency)):
print(f"Dependency: {dependency}")
prev_deps, current_deps = dependency
if context and prev_deps not in [None, []]:
for dep_idx in prev_deps:
if dep_idx is not None:
for context_data in context:
if 'subqueries' in context_data and dep_idx < len(context_data['subqueries']):
sub_query_data = context_data['subqueries'][dep_idx]
if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data:
dep_query = sub_query_data['subquery']
matching_nodes = self.find_nodes_by_properties(query=dep_query)
if matching_nodes:
dep_node_id = matching_nodes[0].get('node_id')
score = matching_nodes[0].get('score', 0)
if score >= 0.9:
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
if current_deps not in [None, []]:
for dep_idx in current_deps:
if dep_idx < len(sub_query_ids):
dep_node_id = sub_query_ids[dep_idx]
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
else:
raise ValueError(f"Invalid dependency index: {dep_idx}")
# Any other non-empty dependency is treated as a flat list of indices.
elif len(dependency) > 0:
for dep_idx in dependency:
if dep_idx < len(sub_queries):
dep_node_id = sub_query_ids[dep_idx]
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
else:
raise ValueError(f"Invalid dependency index: {dep_idx}")
else:
raise ValueError(f"Invalid dependency: {dependency}")
else:
raise ValueError(f"Unexpected role: {role}")
# Decompose every sub-query one level deeper, sharing the accumulators.
if recurse:
recursion_tasks = []
for idx, sub_query in enumerate(sub_queries):
try:
sub_node_id = sub_query_ids[idx]
recursion_tasks.append(
self.build_graph(
query=sub_query,
parent_node_id=sub_node_id,
depth=depth + 1,
threshold=threshold,
recurse=recurse,
context=context,
session_id=session_id,
node_data_futures=node_data_futures,
sub_nodes_info=sub_nodes_info,
sub_query_ids=sub_query_ids,
pre_req_nodes=pre_req_nodes
)
)
except Exception as e:
print(f"Failed to create recursion task for sub-query {sub_query}: {e}")
continue
if recursion_tasks:
try:
# NOTE(review): with return_exceptions=True, gather returns exceptions raised
# by the tasks in its result list, which is not inspected here.
await asyncio.gather(*recursion_tasks, return_exceptions=True)
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e:
print(f"Error during recursive processing: {e}")
except Exception as e:
print(f"Error during recursive processing: {e}")
raise e
# Collect one processing coroutine per recorded sub-node, together with the
# futures of the depth-0 nodes and of their hierarchical children.
futures = {}
all_child_futures = {}
process_tasks = []
graph_data = self._get_current_graph_data()
graph = graph_data["graph"]
node_map = graph_data["node_map"]
for (sub_node_id, sub_query, role, dependency, future, local_depth) in sub_nodes_info:
idx = node_map.get(sub_node_id)
has_children = False
child_futures = []
if idx is not None:
for (_, child_idx, edge_data) in graph.out_edges(idx):
if edge_data.get("type") == "hierarchical":
has_children = True
child_future = node_data_futures.get(graph.get_node_data(child_idx).get("id"))
if child_future:
child_futures.append(child_future)
if local_depth == 0:
futures[sub_query] = future
all_child_futures[sub_query] = child_futures
# A node with hierarchical children has its future resolved to an empty string.
if has_children:
if not future.done():
future.set_result("")
else:
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
process_tasks.append(process_node(sub_node_id, sub_query, session_id, future, max_tokens_allowed))
elif role.lower() == 'dependent':
# Gather the futures of every node this dependent sub-query waits for.
dep_futures = []
if isinstance(dependency, list) and len(dependency) == 2:
prev_deps, current_deps = dependency
if context and prev_deps not in [None, []]:
for context_idx, context_data in enumerate(context):
if isinstance(prev_deps, list) and context_idx < len(prev_deps):
context_dep = prev_deps[context_idx]
if (context_dep is not None and isinstance(context_data, dict)
and 'subqueries' in context_data):
if context_dep < len(context_data['subqueries']):
dep_query = context_data['subqueries'][context_dep]['subquery']
matching_nodes = self.find_nodes_by_properties(query=dep_query)
if matching_nodes not in [None, []]:
dep_node_id = matching_nodes[0].get('node_id', None)
score = float(matching_nodes[0].get('score', 0))
# NOTE(review): exact float comparison; only a score of exactly 1.0 is used.
if score == 1.0 and dep_node_id in node_data_futures:
dep_futures.append(node_data_futures[dep_node_id])
elif isinstance(prev_deps, int):
if context_idx < len(context_data['subqueries']):
dep_query = context_data['subqueries'][prev_deps]['subquery']
matching_nodes = self.find_nodes_by_properties(query=dep_query)
if matching_nodes not in [None, []]:
dep_node_id = matching_nodes[0].get('node_id', None)
score = matching_nodes[0].get('score', 0)
if score == 1.0 and dep_node_id in node_data_futures:
dep_futures.append(node_data_futures[dep_node_id])
if current_deps not in [None, []]:
current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
for dep_idx in current_deps_list:
if dep_idx < len(sub_query_ids):
dep_node_id = sub_query_ids[dep_idx]
if dep_node_id in node_data_futures:
dep_futures.append(node_data_futures[dep_node_id])
process_tasks.append(process_dependent_node(sub_node_id, sub_query, dep_futures, future))
# NOTE(review): the statements below repeat the role dispatch written out
# above - presumably the branch for nodes that are not at depth 0; the
# original nesting cannot be confirmed from this view.
else:
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
process_tasks.append(process_node(sub_node_id, sub_query, session_id, future, max_tokens_allowed))
elif role.lower() == 'dependent':
dep_futures = []
if isinstance(dependency, list) and len(dependency) == 2:
prev_deps, current_deps = dependency
if context and prev_deps not in [None, []]:
for context_idx, context_data in enumerate(context):
if isinstance(prev_deps, list) and context_idx < len(prev_deps):
context_dep = prev_deps[context_idx]
if (context_dep is not None and isinstance(context_data, dict)
and 'subqueries' in context_data):
if context_dep < len(context_data['subqueries']):
dep_query = context_data['subqueries'][context_dep]['subquery']
matching_nodes = self.find_nodes_by_properties(query=dep_query)
if matching_nodes not in [None, []]:
dep_node_id = matching_nodes[0].get('node_id', None)
score = float(matching_nodes[0].get('score', 0))
if score == 1.0 and dep_node_id in node_data_futures:
dep_futures.append(node_data_futures[dep_node_id])
elif isinstance(prev_deps, int):
if context_idx < len(context_data['subqueries']):
dep_query = context_data['subqueries'][prev_deps]['subquery']
matching_nodes = self.find_nodes_by_properties(query=dep_query)
if matching_nodes not in [None, []]:
dep_node_id = matching_nodes[0].get('node_id', None)
score = matching_nodes[0].get('score', 0)
if score == 1.0 and dep_node_id in node_data_futures:
dep_futures.append(node_data_futures[dep_node_id])
if current_deps not in [None, []]:
current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
for dep_idx in current_deps_list:
if dep_idx < len(sub_query_ids):
dep_node_id = sub_query_ids[dep_idx]
if dep_node_id in node_data_futures:
dep_futures.append(node_data_futures[dep_node_id])
process_tasks.append(process_dependent_node(sub_node_id, sub_query, dep_futures, future))
if process_tasks:
await self.emit_event("search_process_started", {
"depth": depth,
"sub_queries": sub_queries,
"roles": roles
})
# Before running the tasks, report the sub-queries that already have content
# of their own or from a finished child.
processed_sub_queries = set()
for sub_query, future in futures.items():
try:
parent_content = future.result().strip()
# NOTE(review): the bare except treats every failure, including a future that
# is not finished yet, as empty content.
except:
parent_content = ""
child_futures = all_child_futures.get(sub_query)
any_child_done = any(cf.done() and cf.result().strip() for cf in child_futures)
if parent_content or any_child_done:
await self.emit_event("sub_query_processed", {"sub_query": sub_query})
processed_sub_queries.add(sub_query)
# Run all node processing coroutines concurrently.
await asyncio.gather(*process_tasks)
if depth == 0:
# Report every remaining sub-query as processed or failed, depending on
# whether it or one of its children produced content.
for sub_query, future in futures.items():
if sub_query not in processed_sub_queries:
try:
parent_content = future.result().strip()
except:
parent_content = ""
child_futures = all_child_futures.get(sub_query)
any_child_done = any(cf.done() and cf.result().strip() for cf in child_futures)
if parent_content or any_child_done:
await self.emit_event("sub_query_processed", {"sub_query": sub_query})
else:
await self.emit_event("sub_query_failed", {"sub_query": sub_query})
print("Graph building complete, processing final tasks...")
await self.emit_event("search_process_completed", {
"depth": depth,
"sub_queries": sub_queries,
"roles": roles
})
create_cross_connections()
print("All cross-connections have been created!")
print(f"Adding similarity edges with threshold {threshold}")
# Try a similarity/relevance edge for every pair of nodes that is not yet
# connected in the node1 -> node2 direction.
graph_data = self._get_current_graph_data()
node_map = graph_data["node_map"]
all_node_ids = list(node_map.keys())
for i, node1 in enumerate(all_node_ids):
for node2 in all_node_ids[i+1:]:
if not self.edge_exists(node1, node2):
self.add_edge_based_on_similarity_and_relevance(node1, node2, query, threshold)
print("All similarity edges have been added!")
async def process_graph(
self,
query: str,
data: str = None,
similarity_threshold: float = 0.8,
relevance_threshold: float = 0.7,
sub_sub_queries: bool = True,
session_id: str = None,
max_tokens_allowed: int = 128000
):
"""Process a query and manage graph creation/modification.

When no graph is stored yet, a new graph is created and built for the query.
Otherwise the query is compared with the node queries of every stored graph and
the best-matching graph is selected; depending on the similarity the existing
graph is loaded, modified, or a new graph is created.

Args:
query: Query to process.
data: Passed to build_graph when a new graph is built.
similarity_threshold: Minimum query similarity for reusing an existing graph
and for node-level matches.
relevance_threshold: Passed to build_graph as its threshold.
sub_sub_queries: Passed to build_graph as its recurse flag.
session_id: Passed to build_graph and modify_graph.
max_tokens_allowed: Passed to build_graph when a new graph is built.
"""
# Find the node of the current graph whose query is most similar to new_query
# and classify it as root, sub (SQ) or sub-sub (SSQ) by its id.
def check_query_similarity(new_query: str, similarity_threshold: float = 0.8) -> Dict[str, Any]:
if self.current_graph_id is None:
raise Exception("Error: No current graph ID. Cannot check query similarity.")
graph_data = self._get_current_graph_data()
graph = graph_data["graph"]
node_map = graph_data["node_map"]
similarities = []
if not node_map:
return {"should_create_new": True}
for node_id, idx in node_map.items():
node_data = graph.get_node_data(idx)
if not node_data.get("query"):
continue
similarity = self.calculate_query_similarity(new_query, node_data.get("query"))
if similarity >= similarity_threshold:
similarities.append({
"node_id": node_id,
"query": node_data.get("query"),
"score": similarity,
"role": node_data.get("role")
})
if not similarities:
print(f"No similar queries found above threshold {similarity_threshold}")
return {"should_create_new": True}
best_match = max(similarities, key=lambda x: x["score"])
# "SSQ" is tested before "SQ" because every SSQ id also contains "SQ".
rel_type = "root"
if "SSQ" in best_match["node_id"]:
rel_type = "sub-sub"
elif "SQ" in best_match["node_id"]:
rel_type = "sub"
return {
"most_similar_query": best_match["query"],
"similarity_score": best_match["score"],
"relationship_type": rel_type,
"node_id": best_match["node_id"],
"should_create_new": best_match["score"] < similarity_threshold
}
try:
graphs = self.get_graphs()
# First query ever: create and build a graph, then finalise it.
if not graphs:
print("No existing graphs found. Creating new graph.")
self.create_new_graph()
await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
await self.build_graph(
query=query,
data=data,
threshold=relevance_threshold,
recurse=sub_sub_queries,
session_id=session_id,
max_tokens_allowed=max_tokens_allowed
)
gc.collect()
self.prune_edges()
self.update_pagerank()
self.verify_graph_integrity()
self.verify_graph_consistency()
return
max_similarity = 0
most_similar_graph = None
# Group the node summaries returned by get_graphs by graph id.
consolidated_graphs = {}
for graph_obj in graphs:
graph_info = graph_obj.get("graph_info")
if not graph_info:
continue
graph_id = graph_info.get("graph_id")
if not graph_id:
continue
if graph_id not in consolidated_graphs:
consolidated_graphs[graph_id] = {
"graph_id": graph_id,
"nodes": []
}
if graph_info.get("nodes"):
consolidated_graphs[graph_id]["nodes"].extend(graph_info["nodes"])
# Track the graph that holds the node most similar to the query.
for graph_id, graph_data in consolidated_graphs.items():
nodes = graph_data["nodes"]
for node in nodes:
if node.get("query"):
similarity = self.calculate_query_similarity(query, node["query"])
if node.get("id", "").startswith("SQ"):
# NOTE(review): the task returned by create_task is not stored or awaited here.
asyncio.create_task(self.emit_event("retrieved_sub_query", {
"sub_query": node["query"]
}))
if similarity > max_similarity:
max_similarity = similarity
most_similar_graph = graph_id
if max_similarity >= similarity_threshold:
print(f"Found similar query with score {round(max_similarity, 2)}")
self.current_graph_id = most_similar_graph
# A similarity that rounds to 1.0 reuses the stored graph as it is.
if round(max_similarity, 2) == 1.0:
print("Loading and using existing graph")
await self.emit_event("graph_operation", {"operation_type": "loading_existing_graph"})
success = self.load_graph(self.root_node_id)
if not success:
raise Exception("Failed to load existing graph")
else:
print("Checking for node-level similarity...")
similarity_info = check_query_similarity(
query,
similarity_threshold
)
# NOTE(review): the early returns of check_query_similarity contain only
# "should_create_new", so this lookup would raise KeyError for them; a best
# match on the root node ("root") is not handled by the branch below.
if similarity_info["relationship_type"] in ["sub", "sub-sub"]:
print(f"Most Similar Query: {similarity_info['most_similar_query']}")
print("Modifying existing graph structure")
await self.emit_event("graph_operation", {"operation_type": "modifying_existing_graph"})
await self.modify_graph(
query,
similarity_info["node_id"],
session_id=session_id
)
gc.collect()
self.prune_edges()
self.update_pagerank()
self.verify_graph_integrity()
self.verify_graph_consistency()
# NOTE(review): presumably the counterpart of the similarity-threshold check
# above; the original nesting cannot be confirmed from this view.
else:
print(f"Creating new graph for query: {query}")
self.create_new_graph()
await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
await self.build_graph(
query=query,
data=data,
threshold=relevance_threshold,
recurse=sub_sub_queries,
session_id=session_id,
max_tokens_allowed=max_tokens_allowed
)
gc.collect()
self.prune_edges()
self.update_pagerank()
self.verify_graph_integrity()
self.verify_graph_consistency()
# Rate-limit and retry errors are dropped without a message; every other error
# is printed and raised again.
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError):
pass
except Exception as e:
print(f"Error in process_graph: {str(e)}")
raise
def add_edge_based_on_similarity_and_relevance(self, node1_id: str, node2_id: str, query: str, threshold: float = 0.8):
    """Connect two nodes when their combined similarity/relevance is high enough.

    The weight is the mean of four scores: the cosine similarity of the two
    node embeddings, the relevance of each node's data to ``query``, and the
    relevance of the two nodes' data to each other. Nothing happens when a
    node is unknown or when either node lacks an embedding or data.

    Args:
        node1_id: Id of the edge's source node.
        node2_id: Id of the edge's target node.
        query: Query the node data is scored against.
        threshold: Minimum weight required to add the edge.
    """
    current = self._get_current_graph_data()
    graph = current["graph"]
    index_by_id = current["node_map"]
    if not (node1_id in index_by_id and node2_id in index_by_id):
        return

    first = graph.get_node_data(index_by_id[node1_id])
    second = graph.get_node_data(index_by_id[node2_id])
    required = [
        first.get("embedding"),
        second.get("embedding"),
        first.get("data"),
        second.get("data"),
    ]
    if not all(required):
        return

    embedding_score = self.cosine_similarity(first["embedding"], second["embedding"])
    first_to_query = self.calculate_relevance(query, first["data"])
    second_to_query = self.calculate_relevance(query, second["data"])
    first_to_second = self.calculate_relevance(first["data"], second["data"])
    weight = (embedding_score + first_to_query + second_to_query + first_to_second) / 4

    if weight < threshold:
        return
    self.add_edge(node1_id, node2_id, weight=weight, relationship_type='similarity_and_relevance')
    print(f"Added edge between {node1_id} and {node2_id} with type similarity_and_relevance and weight {weight}")
def calculate_relevance(self, data1: str, data2: str) -> float:
    """Score how relevant two pieces of text are to each other.

    Passes the two strings to ``self.scorer.score`` as single-element lists
    and returns the mean of the third value it returns (unpacked as F1).

    Args:
        data1: First text.
        data2: Second text.

    Returns:
        The mean F1 value as a Python number, or ``0.0`` when either input is
        empty/``None`` or when scoring raises (the error is printed).
    """
    try:
        if data1 and data2:
            _precision, _recall, f1 = self.scorer.score([data1], [data2])
            return f1.mean().item()
        # Nothing to compare against.
        return 0.0
    except Exception as e:
        print(f"Error calculating relevance: {str(e)}")
        return 0.0
def calculate_query_similarity(self, query1: str, query2: str) -> float:
    """Return the cosine similarity between the embeddings of two queries.

    Each query is encoded with ``self.model.encode`` and converted to a list
    before being compared with ``self.cosine_similarity``.

    Args:
        query1: First query text.
        query2: Second query text.

    Returns:
        The similarity value, or ``0.0`` if encoding or comparison raises
        (the error is printed).
    """
    try:
        # Encode in argument order: query1 first, then query2.
        vectors = [self.model.encode(text).tolist() for text in (query1, query2)]
        return self.cosine_similarity(vectors[0], vectors[1])
    except Exception as e:
        print(f"Error calculating query similarity: {str(e)}")
        return 0.0
def get_similarities_and_relevance(self, threshold: float = 0.8) -> list:
    """List every node pair whose averaged similarity/relevance meets ``threshold``.

    Each unordered pair of nodes in the current graph is scored once; the
    weight is the mean of the embedding cosine similarity and the relevance
    between the two nodes' data.

    Args:
        threshold: Minimum weight for a pair to be reported.

    Returns:
        A list of dicts with keys ``node1``, ``node2``, ``similarity``,
        ``relevance`` and ``weight``. Returns ``[]`` if anything raises while
        collecting or scoring (the error is printed and partial results are
        discarded).
    """
    try:
        current = self._get_current_graph_data()
        rx_graph = current["graph"]
        index_of = current["node_map"]

        # Snapshot only the fields needed for scoring.
        snapshots = []
        for idx in index_of.values():
            payload = rx_graph.get_node_data(idx)
            snapshots.append({
                "id": payload.get("id"),
                "embedding": payload.get("embedding"),
                "data": payload.get("data")
            })

        pairs = []
        for next_pos, left in enumerate(snapshots, start=1):
            # Only look at nodes after `left` so each pair is scored once.
            for right in snapshots[next_pos:]:
                similarity = self.cosine_similarity(left["embedding"], right["embedding"])
                relevance = self.calculate_relevance(left["data"], right["data"])
                weight = (similarity + relevance) / 2
                if weight >= threshold:
                    pairs.append({
                        'node1': left["id"],
                        'node2': right["id"],
                        'similarity': similarity,
                        'relevance': relevance,
                        'weight': weight
                    })
        return pairs
    except Exception as e:
        print(f"Error getting similarities and relevance: {str(e)}")
        return []
def get_node_relationships(self, node_id=None, depth=None, role=None, relationship_type=None):
    """Collect incoming and outgoing edges for nodes of the current graph.

    The root node (``self.root_node_id``) is never reported as an entry,
    although it may appear as the source or target of a reported edge.

    Args:
        node_id: If truthy, only the node with this id is reported.
        depth: Accepted for interface compatibility; this method does not
            apply any depth filtering.
        role: If truthy, only nodes whose ``"role"`` equals this value are
            reported.
        relationship_type: If truthy, only edges whose ``"type"`` equals this
            value are listed. Nodes that pass the node filters are still
            reported, possibly with empty edge lists.

    Returns:
        Dict mapping node id to ``{"in_edges": [...], "out_edges": [...]}``,
        where each edge is ``(source_id, target_id, {"weight": ..., "type": ...})``.
    """
    graph_data = self._get_current_graph_data()
    graph = graph_data["graph"]
    node_map = graph_data["node_map"]

    def _wanted(edge_data):
        # Previously this filter was accepted but silently ignored.
        return not relationship_type or edge_data.get("type") == relationship_type

    def _summary(edge_data):
        return {"weight": edge_data.get("weight"), "type": edge_data.get("type")}

    relationships = {}
    for n_id, idx in node_map.items():
        if n_id == self.root_node_id:
            continue
        if node_id and n_id != node_id:
            continue
        if role and graph.get_node_data(idx).get("role") != role:
            continue

        in_edges = [
            (graph.get_node_data(u).get("id"), n_id, _summary(edge_data))
            for u, _v, edge_data in graph.in_edges(idx)
            if _wanted(edge_data)
        ]
        out_edges = [
            (n_id, graph.get_node_data(v).get("id"), _summary(edge_data))
            for _u, v, edge_data in graph.out_edges(idx)
            if _wanted(edge_data)
        ]
        relationships[n_id] = {"in_edges": in_edges, "out_edges": out_edges}
    return relationships
def find_nodes_by_properties(self, query: str = None, embedding: list = None,
                             node_data: dict = None, similarity_threshold: float = 0.8) -> list:
    """Rank nodes of the current graph by how well they match the given criteria.

    Up to three criteria contribute to a node's score, each only when it is
    supplied and satisfied:

    * ``query``: case-insensitive substring of the node's ``"query"`` (adds 1).
    * ``embedding``: cosine similarity with the node's ``"embedding"`` at or
      above ``similarity_threshold`` (adds the similarity).
    * ``node_data``: at least one key/value pair equal to the node's own
      (adds the fraction of matching pairs).

    The final score is the mean of the contributing criteria.

    Args:
        query: Text to look for in node queries.
        embedding: Vector to compare with node embeddings.
        node_data: Key/value pairs to compare with node properties.
        similarity_threshold: Minimum cosine similarity for an embedding match.

    Returns:
        List of ``{"node_id", "score", "data"}`` dicts, best score first.
        Nodes satisfying no criterion are omitted.

    Raises:
        Exception: Any error raised while matching is printed and re-raised.
    """
    try:
        current = self._get_current_graph_data()
        rx_graph = current["graph"]
        index_of = current["node_map"]

        hits = []
        for n_id, idx in index_of.items():
            props = rx_graph.get_node_data(idx)
            # One entry per satisfied criterion; their mean is the node score.
            contributions = []

            if query and query.lower() in props.get("query", "").lower():
                contributions.append(1)

            if embedding and "embedding" in props:
                sim = self.cosine_similarity(embedding, props["embedding"])
                if sim >= similarity_threshold:
                    contributions.append(sim)

            if node_data:
                equal_fields = sum(1 for k, v in node_data.items() if k in props and props[k] == v)
                if equal_fields > 0:
                    contributions.append(equal_fields / len(node_data))

            if contributions:
                hits.append({
                    "node_id": n_id,
                    "score": sum(contributions) / len(contributions),
                    "data": props
                })

        return sorted(hits, key=lambda hit: hit["score"], reverse=True)
    except Exception as e:
        print(f"Error finding nodes by properties: {str(e)}")
        raise
def query_graph(self, query: str) -> str:
    """Gather the data of every node reachable from the node that holds ``query``.

    The start node is the first node in the current graph whose ``"query"``
    equals ``query`` exactly. From there the graph is walked depth-first via
    ``graph.neighbors``; each visited node with non-blank ``"data"``
    contributes its stripped text.

    Args:
        query: Exact query text of the node to start from.

    Returns:
        The collected texts formatted as numbered ``Data N:`` sections
        separated by blank lines, or ``""`` when no visited node has data.

    Raises:
        ValueError: If no node stores ``query``.
    """
    current = self._get_current_graph_data()
    rx_graph = current["graph"]
    index_of = current["node_map"]

    # First node (in node_map order) whose stored query matches exactly.
    target_node_id = next(
        (n_id for n_id, idx in index_of.items()
         if rx_graph.get_node_data(idx).get("query") == query),
        None,
    )
    if not target_node_id:
        raise ValueError(f"Query node not found for: {query}")

    collected = []
    seen = set()
    pending = [index_of[target_node_id]]
    while pending:
        node_idx = pending.pop()
        if node_idx in seen:
            continue
        seen.add(node_idx)

        text = rx_graph.get_node_data(node_idx).get("data")
        if text and text.strip():
            collected.append(text.strip())

        pending.extend(
            neighbor for neighbor in rx_graph.neighbors(node_idx) if neighbor not in seen
        )

    if not collected:
        print(f"No data found for: {query}")
        return ""

    sections = [f"Data {number}:\n{text}" for number, text in enumerate(collected, start=1)]
    return "\n\n".join(sections)
def prune_edges(self, max_edges: int = 1000):
    """Drop the lowest-weight edges so the current graph keeps at most ``max_edges``.

    Edges are ranked by their ``"weight"`` payload entry (missing weights
    count as ``1.0``); the top ``max_edges`` endpoint pairs are kept and every
    other edge is removed. Nodes and their data are left untouched.

    Args:
        max_edges: Maximum number of edges to keep.

    Returns:
        None. Failures to remove an individual edge are printed and skipped.
    """
    print(f"Pruning edges to maximum {max_edges} edges...")
    graph_data = self._get_current_graph_data()
    graph = graph_data["graph"]

    # weighted_edge_list() yields (source, target, payload); edge_list() only
    # yields (source, target), so it cannot be used to rank by weight.
    weighted_edges = list(graph.weighted_edge_list())
    if len(weighted_edges) <= max_edges:
        print("No pruning required. Current edge count is within limits.")
        return

    ranked = sorted(
        weighted_edges,
        key=lambda edge: edge[2].get("weight", 1.0),
        reverse=True,
    )
    # Keep/remove decisions are made per (source, target) pair.
    edges_to_keep = {(u, v) for u, v, _payload in ranked[:max_edges]}

    for u, v, _payload in weighted_edges:
        if (u, v) in edges_to_keep:
            continue
        try:
            graph.remove_edge(u, v)
        except Exception as e:
            print(f"Error removing edge from {u} to {v}: {e}")

    print(f"Pruned edges. Kept top {max_edges} edges by weight.")
def update_pagerank(self):
    """Recompute PageRank for the current graph and store it on each node.

    Runs ``rx.pagerank`` with each edge's ``"weight"`` payload entry
    (defaulting to ``1.0``) as the edge weight, then writes the resulting
    score into every mapped node's data under ``"pagerank"``.

    Returns:
        None. Returns early (after printing a notice) when no graph is
        selected.

    Raises:
        Exception: Any error from the computation or the write-back is
            printed and re-raised.
    """
    if not self.current_graph_id:
        print("No current graph selected. Cannot compute PageRank.")
        return

    current = self._get_current_graph_data()
    rx_graph = current["graph"]

    def _edge_weight(edge_payload):
        return edge_payload.get("weight", 1.0)

    try:
        scores = rx.pagerank(rx_graph, weight_fn=_edge_weight)
        for idx in current["node_map"].values():
            # Node payloads are mutated in place.
            rx_graph.get_node_data(idx)["pagerank"] = scores[idx]
        print("PageRank updated successfully")
    except Exception as e:
        print(f"Error updating PageRank: {str(e)}")
        raise
def display_graph(self):
    """Render the current graph with PyVis and return the page as an HTML string.

    Only nodes that are an endpoint of at least one edge are drawn. Each node
    is labelled with its id and shows its query as a tooltip; each edge shows
    its weight as a tooltip.

    The page is written to a uniquely named temporary file inside
    ``$WRITABLE_DIR`` (default ``/tmp``), read back, and deleted. The process
    working directory is switched to that directory while saving and is
    always restored afterwards, even if saving or reading fails.

    NOTE(review): ``os.chdir`` affects the whole process, so concurrent
    callers in other threads can still observe the temporary directory
    change — confirm this is acceptable for the deployment.

    Returns:
        The generated HTML document as a string.
    """
    graph_data = self._get_current_graph_data()
    graph = graph_data["graph"]

    net = Network(height="530px", width="100%", directed=True, bgcolor="#222222", font_color="white")
    net.options = {"physics": {"enabled": False}}

    added_nodes = set()
    # edge_list() gives (source, target) index pairs and edges() the matching
    # payloads; they are paired up positionally.
    for (u, v), edge_data in zip(graph.edge_list(), graph.edges()):
        endpoint_ids = []
        for endpoint in (u, v):
            endpoint_data = graph.get_node_data(endpoint)
            endpoint_id = endpoint_data.get("id")
            if endpoint_id not in added_nodes:
                net.add_node(
                    endpoint_id,
                    label=endpoint_id,
                    title=f"Query: {endpoint_data.get('query', 'N/A')}",
                    size=20,
                    color="#00cc66",
                )
                added_nodes.add(endpoint_id)
            endpoint_ids.append(endpoint_id)

        edge_weight = edge_data.get("weight", "N/A")
        net.add_edge(endpoint_ids[0], endpoint_ids[1], title=f"Weight: {edge_weight}", color="#cccccc")

    net.options["layout"] = {"improvedLayout": True}
    net.options["interaction"] = {"dragNodes": True}

    # A unique name keeps concurrent calls from clobbering each other's file.
    filename = f"temp_graph_{uuid.uuid4().hex}.html"
    original_dir = os.getcwd()
    os.chdir(os.getenv("WRITABLE_DIR", "/tmp"))
    try:
        net.save_graph(filename)
        with open(filename, "r", encoding="utf-8") as f:
            html_str = f.read()
    finally:
        try:
            if os.path.exists(filename):
                os.remove(filename)
        finally:
            # Always hand the original working directory back to the process.
            os.chdir(original_dir)
    return html_str
def verify_graph_integrity(self):
    """Report orphaned nodes and remove edges that point outside the current graph.

    * Orphaned nodes (no incoming and no outgoing edges) are only printed;
      they are not removed.
    * An edge is considered invalid when its target node's ``"graph_id"``
      differs from ``self.current_graph_id``; such edges are printed and then
      removed. A failure to remove one edge is printed and does not stop the
      remaining removals.

    Returns:
        ``True`` once the checks have run.
    """
    graph_data = self._get_current_graph_data()
    graph = graph_data["graph"]
    node_map = graph_data["node_map"]

    orphaned = [
        n_id
        for n_id, idx in node_map.items()
        if not graph.in_edges(idx) and not graph.out_edges(idx)
    ]
    if orphaned:
        print(f"Found orphaned nodes: {orphaned}")

    # Single pass: remember both the indices (for removal) and ids (for the report).
    foreign_edges = []
    for u, v in graph.edge_list():
        target_data = graph.get_node_data(v)
        if target_data.get("graph_id") != self.current_graph_id:
            foreign_edges.append((u, v, graph.get_node_data(u).get("id"), target_data.get("id")))

    if foreign_edges:
        invalid_edges = [(source_id, target_id) for _u, _v, source_id, target_id in foreign_edges]
        print(f"Found invalid edges: {invalid_edges}")
        for u, v, _source_id, _target_id in foreign_edges:
            try:
                graph.remove_edge(u, v)
            except Exception as e:
                # Previously an Exception object was built here and discarded,
                # hiding the failure entirely.
                print(f"Error removing invalid edge from {u} to {v}: {e}")

    print("Graph integrity verified successfully")
    return True
def verify_graph_consistency(self):
    """Check that every node and edge of the current graph has its required fields.

    Nodes must have non-``None`` ``"id"`` and ``"query"`` entries; edge
    payloads must have non-``None`` ``"type"`` and ``"weight"`` entries.
    Nodes are checked before edges.

    Returns:
        ``True`` when all checks pass.

    Raises:
        ValueError: On the first node or edge that is missing a required
            property.
    """
    current = self._get_current_graph_data()
    rx_graph = current["graph"]
    index_of = current["node_map"]

    for idx in index_of.values():
        props = rx_graph.get_node_data(idx)
        if any(props.get(field) is None for field in ("id", "query")):
            raise ValueError("Found nodes with missing required properties")

    for payload in rx_graph.edges():
        if any(payload.get(field) is None for field in ("type", "weight")):
            raise ValueError("Found relationships with missing required properties")

    print("Graph consistency verified successfully")
    return True
async def close(self):
    """Release the executor and crawler resources held by this instance.

    Shuts down ``self.executor`` (waiting for running work) when that
    attribute exists, then, when ``self.crawler`` exists, awaits its expired
    session cleanup followed by its browser-context cleanup for
    ``self.session_id`` (``None`` if unset). Both crawler calls are shielded
    from cancellation of this coroutine.

    Any exception raised along the way is printed and swallowed, so the
    remaining steps are skipped but the caller never sees an error.

    NOTE(review): the visible ``__init__`` only assigns ``custom_crawler``
    (the ``crawler`` assignment is commented out), so the crawler branch may
    never run — confirm whether that is intended.
    """
    try:
        if hasattr(self, 'executor'):
            self.executor.shutdown(wait=True)

        if not hasattr(self, 'crawler'):
            return

        crawler = self.crawler
        await asyncio.shield(crawler.cleanup_expired_sessions())
        session = getattr(self, "session_id", None)
        await asyncio.shield(crawler.cleanup_browser_context(session))
    except Exception as e:
        print(f"Error during cleanup: {e}")
@staticmethod
def cosine_similarity(v1: List[float], v2: List[float]) -> float:
"""Calculate cosine similarity between two vectors."""
try:
v1_array = np.array(v1)
v2_array = np.array(v2)
return np.dot(v1_array, v2_array) / (np.linalg.norm(v1_array) * np.linalg.norm(v2_array))
except Exception as e:
print(f"Error calculating cosine similarity: {str(e)}")
return 0.0
if __name__ == "__main__":
    # Interactive manual test: read a query from stdin, build/extend the
    # graph for it, reason over the collected data, and print evaluation
    # metrics for the answer.
    from dotenv import load_dotenv
    from src.reasoning.reasoner import Reasoner
    from src.evaluation.evaluator import Evaluator

    load_dotenv()

    graph_search = GraphRAG(num_workers=24)
    evaluator = Evaluator()
    reasoner = Reasoner()

    async def test_graph_search():
        """Run the read/process/answer/evaluate loop until the user types ``exit``."""
        # Sample data for testing. These lists are not referenced by the loop
        # below; they are kept as ready-made inputs to paste at the prompt.
        queries = [
            """In the context of global economic recovery and energy security concerns, provide an in-depth comparative assessment of the renewable energy policies among G20 countries.
Specifically, examine how short-term economic stimulus measures intersect with long-term decarbonization commitments, including:
1. Carbon pricing mechanisms
2. Subsidies for emerging technologies (such as green hydrogen and battery storage)
3. Cross-border climate finance initiatives
Highlight the unique challenges faced by both advanced and emerging economies in addressing:
1. Energy poverty
2. Supply chain disruptions
3. Geopolitical tensions (e.g., the Russia-Ukraine conflict)
Discuss how these factors influence policy effectiveness, and evaluate the degree to which each country is on track to meet—or exceed—its Paris Agreement targets.
Note any significant policy gaps, regional collaborations, or innovative best practices.
Lastly, provide a forward-looking perspective on how these renewable energy strategies may evolve over the next decade, considering:
1. Technological breakthroughs
2. Global market trends
3. Potential climate-related disasters
Present your analysis as a detailed, well-formatted report.""",
            """Analyse the impact of 'hot-money' on the value of Indian Rupee and answer the following questions:-
1. How does it affect the exchange rate?
2. How can it be mitigated/eliminated?
3. Why is it a problem?
4. What are the consequences?
5. What are the alternatives?
- Evaluate the alternatives for pros and cons.
- Evaluate the impact of alternatives on the exchange rate.
- How can they be implemented?
- What are the consequences of each alternative?
- Evaluate the feasibility of the alternatives.
- Pick top 5 alternatives and justify your choices in detail.
6. What are the implications for the Indian economy? Furthermore:-
- Evaluate the impact of the chosen alternatives on the Indian economy.""",
            """Inflation has been an intrinsic past of human civilization since the very beginning. Answer the following questions:-
1. How true is the above statement?
2. What are the causes of inflation?
3. What are the consequences of inflation?
4. Can we completely eliminate inflation?""",
            """Perform a detailed comparison between the ancient Greece and Roman civilizations.
1. What were the key differences between the two civilizations?
- Evaluate the differences in governance, society, and culture
- Evaluate the differences in economy, trade, and military
- Evaluate the differences in technology and infrastructure
2. What were the similarities between the two civilizations?
- Evaluate the similarities in governance, society, and culture
- Evaluate the similarities in economy, trade, and military
- Evaluate the similarities in technology and infrastructure
3. How did these two civilizations influence each other?
- Evaluate the influence of one civilization on the other
4. How did these two civilizations influence the modern world?
5. Was there another civilization that influenced these two? If yes, how?""",
            """Evaluate the long-term effects of colonialism on economic development in Asia:-
1. Include case studies of at least five different countries
2. Analyze how these effects differ based on colonial power, time of independence, and resource distribution
- Evaluate the impact of colonialism on the economy of the country
- Evaluate the impact of colonialism on the economy of the region
- Evaluate the impact of colonialism on the economy of the world
3. How do these effects compare to Africa?"""
        ]
        follow_on_queries = [
            "How is 'hot-money' related to the current economic situation in India?",
            "What is inflation?",
            "Did ancient Greece and Rome have any impact on modern democracy? If yes, how?",
            "Did colonialism have any impact on the trade between Africa and Asia, both in colonial and post-colonial times? If yes, how?"
        ]

        try:
            while True:
                print("\n\nEnter query (finish input with an empty line):")
                query_lines = []
                while True:
                    line = input()
                    if line.strip() == "":
                        break
                    query_lines.append(line)
                query = "\n".join(query_lines).strip()

                if not query:
                    # Nothing was typed; ask again instead of processing "".
                    continue
                if query.lower() == "exit":
                    break

                print("\n\n" + "="*15 + " Processing Query " + "="*15 + "\n\n")
                await graph_search.process_graph(query, similarity_threshold=0.8, relevance_threshold=0.8)
                answer = graph_search.query_graph(query)

                response = ""
                async for chunk in reasoner.reason(query, answer):
                    response += chunk
                    # Stream only the new chunk; printing the accumulated
                    # response here would repeat earlier text on every chunk.
                    print(chunk, end="", flush=True)
                print()

                graph_search.display_graph()

                evaluation = await evaluator.evaluate_response(query, response, [answer])
                print(f"Faithfulness: {evaluation['faithfulness']}")
                print(f"Answer Relevancy: {evaluation['answer relevancy']}")
                print(f"Context Utilization: {evaluation['contextual recall']}")
        finally:
            # Release executor/crawler resources even if a query fails.
            await graph_search.close()

    asyncio.run(test_graph_search())