import os
import gc
import time
import asyncio
import torch
import uuid
import rustworkx as rx
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any
from pyvis.network import Network
from src.query_processing.late_chunking.late_chunker import LateChunker
from src.query_processing.query_processor import QueryProcessor
from src.reasoning.reasoner import Reasoner
from src.utils.api_key_manager import APIKeyManager
from src.search.search_engine import SearchEngine
from src.crawl.crawler import CustomCrawler  # , Crawler
from sentence_transformers import SentenceTransformer
from bert_score.scorer import BERTScorer
from tenacity import RetryError
from openai import RateLimitError
from anthropic import RateLimitError as AnthropicRateLimitError
from google.api_core.exceptions import ResourceExhausted

class GraphRAG:
    def __init__(self, num_workers: int = 1):
        """Initialize graph and required components."""
        # Dictionary to store multiple graphs
        self.graphs = {}
        self.current_graph_id = None

        # Component initialization
        self.num_workers = num_workers
        self.search_engine = SearchEngine()
        self.query_processor = QueryProcessor()
        self.reasoner = Reasoner()
        # self.crawler = Crawler(verbose=True)
        self.custom_crawler = CustomCrawler(max_concurrent_requests=1000)
        self.chunking = LateChunker()
        self.llm = APIKeyManager().get_llm()

        # Model initialization
        self.model = SentenceTransformer(
            "dunzhang/stella_en_400M_v5",
            trust_remote_code=True,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
        self.scorer = BERTScorer(
            model_type="roberta-base",
            lang="en",
            rescale_with_baseline=True,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )

        # Counters and tracking
        self.root_node_id = "QR"
        self.node_counter = 0
        self.sub_node_counter = 0
        self.cross_connections = set()

        # Semaphore protection
        self.semaphore = asyncio.Semaphore(min(num_workers * 2, 12))

        # Thread pool
        self.executor = ThreadPoolExecutor(max_workers=self.num_workers)

        # Event callback
        self.on_event_callback = None

    def set_on_event_callback(self, callback):
        """Register a single callback to be triggered for various event types."""
        self.on_event_callback = callback

    async def emit_event(self, event_type: str, data: dict):
        """Helper method to safely emit an event if a callback is registered."""
        if self.on_event_callback:
            if asyncio.iscoroutinefunction(self.on_event_callback):
                return await self.on_event_callback(event_type, data)
            else:
                return self.on_event_callback(event_type, data)
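
    # Usage sketch (hypothetical handler): either a coroutine or a plain function
    # works, since emit_event awaits coroutines and calls regular functions directly:
    #
    #     async def on_event(event_type, data):
    #         print(event_type, data.get("sub_query"))
    #
    #     rag = GraphRAG()
    #     rag.set_on_event_callback(on_event)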

    def _get_current_graph_data(self):
        if self.current_graph_id is None or self.current_graph_id not in self.graphs:
            raise Exception("Error: No current graph selected")
        return self.graphs[self.current_graph_id]

    def add_node(self, node_id: str, query: str, data: str = "", role: str = None):
        """Add a node to the current graph."""
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        node_map = graph_data["node_map"]

        # Generate embedding
        embedding = self.model.encode(query).tolist()
        node_data = {
            "id": node_id,
            "query": query,
            "data": data,
            "role": role,
            "embedding": embedding,
            "pagerank": 0,
            "graph_id": self.current_graph_id
        }
        node_index = graph.add_node(node_data)
        node_map[node_id] = node_index
        print(f"Added node '{node_id}' to graph '{self.current_graph_id}' with role '{role}' and query: '{query}'")

    def _has_path(self, source_idx: int, target_idx: int) -> bool:
        """Helper method to check if there is a path from source to target in the current graph."""
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        visited = set()
        stack = [source_idx]
        while stack:
            current = stack.pop()
            if current == target_idx:
                return True
            if current in visited:
                continue
            visited.add(current)
            for neighbor in graph.neighbors(current):
                stack.append(neighbor)
        return False
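
    # add_edge uses _has_path as a cycle guard: inserting u -> v keeps the graph a
    # DAG exactly when no directed path v -> u already exists. A minimal sketch of
    # the invariant (hypothetical ids, assuming a freshly created current graph):
    #
    #     rag.add_node("A", "query a"); rag.add_node("B", "query b")
    #     rag.add_edge("A", "B", relationship_type="hierarchical")  # accepted
    #     rag.add_edge("B", "A", relationship_type="hierarchical")  # rejected: would close a cycle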

    def add_edge(self, node1: str, node2: str, weight: float = 1.0, relationship_type: str = None):
        """Add an edge between two nodes in a way that preserves a DAG structure."""
        if self.current_graph_id is None:
            raise Exception("Error: No current graph selected")
        if node1 == node2:
            print(f"Cannot add edge to the same node {node1}!")
            return
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        node_map = graph_data["node_map"]
        if node1 not in node_map or node2 not in node_map:
            print(f"One or both nodes {node1}, {node2} do not exist in the current graph.")
            return
        idx1 = node_map[node1]
        idx2 = node_map[node2]

        # Check if adding this edge would create a cycle (i.e. if there is a path from node2 to node1)
        if self._has_path(idx2, idx1):
            print(f"An edge {node1} -> {node2} would create a cycle!")
            return
        if relationship_type is not None and weight is not None:
            edge_data = {"type": relationship_type, "weight": weight}
            graph.add_edge(idx1, idx2, edge_data)
        else:
            raise ValueError("Error: Relationship type and weight must be provided")
        print(f"Added edge between '{node1}' and '{node2}' in graph '{self.current_graph_id}' (type='{relationship_type}', weight={weight})")

    def edge_exists(self, node1: str, node2: str) -> bool:
        """Check if an edge exists between two nodes."""
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        node_map = graph_data["node_map"]
        if node1 not in node_map or node2 not in node_map:
            return False
        idx1 = node_map[node1]
        idx2 = node_map[node2]
        for edge in graph.out_edges(idx1):  # out_edges yields (source, target, payload) triples
            if edge[1] == idx2:
                return True
        return False

    def graph_exists(self) -> bool:
        """Check if a graph exists."""
        return (
            self.current_graph_id is not None
            and self.current_graph_id in self.graphs
            and len(self.graphs[self.current_graph_id]["node_map"]) > 0
        )

    def get_graphs(self) -> list:
        """Get detailed information about all existing graphs and their nodes."""
        result = []
        for graph_id, data in self.graphs.items():
            metadata = data["metadata"]
            node_map = data["node_map"]
            graph = data["graph"]
            nodes_info = []
            for node_id, idx in node_map.items():
                node_data = graph.get_node_data(idx)
                nodes_info.append({
                    "id": node_data.get("id"),
                    "query": node_data.get("query"),
                    "data": node_data.get("data"),
                    "role": node_data.get("role"),
                    "pagerank": node_data.get("pagerank")
                })
            edge_count = len(graph.edge_list())
            result.append({
                "graph_info": {
                    "graph_id": graph_id,
                    "created": metadata.get("created"),
                    "updated": metadata.get("updated"),
                    "node_count": len(node_map),
                    "edge_count": edge_count,
                    "nodes": nodes_info
                }
            })
        result.sort(key=lambda x: x["graph_info"]["created"], reverse=True)
        return result

    def select_graph(self, graph_id: str) -> bool:
        """Select a specific graph as the current working graph."""
        if graph_id in self.graphs:
            self.current_graph_id = graph_id
            return True
        return False

    def create_new_graph(self) -> str:
        """Create a new graph instance and return its ID."""
        graph_id = str(uuid.uuid4())
        graph = rx.PyDiGraph()
        node_map = {}
        metadata = {
            "id": graph_id,
            "created": time.time(),
            "updated": time.time()
        }
        self.graphs[graph_id] = {"graph": graph, "node_map": node_map, "metadata": metadata}
        self.current_graph_id = graph_id
        return graph_id
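
    # Each registry entry pairs a rustworkx.PyDiGraph with a node_map translating
    # string node ids ("QR", "SQ1", "SSQ3", ...) to rustworkx integer indices:
    #
    #     self.graphs[graph_id] = {
    #         "graph": rx.PyDiGraph(),      # payloads are the node_data dicts built in add_node
    #         "node_map": {"QR": 0, ...},   # str id -> int index
    #         "metadata": {"id": ..., "created": ..., "updated": ...},
    #     }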

    def load_graph(self, node_id: str) -> bool:
        """Load an existing graph structure from memory based on a node ID."""
        for gid, data in self.graphs.items():
            if node_id in data["node_map"]:
                self.current_graph_id = gid

                # Restore the SQ/SSQ counters from the highest ids present.
                # Check "SSQ" before "SQ", since "SQ" is a substring of every
                # "SSQ" id and would otherwise swallow all sub-sub nodes.
                for n_id in data["node_map"].keys():
                    if "SSQ" in n_id:
                        num = int(''.join(filter(str.isdigit, n_id)) or 0)
                        self.sub_node_counter = max(self.sub_node_counter, num)
                    elif "SQ" in n_id:
                        num = int(''.join(filter(str.isdigit, n_id)) or 0)
                        self.node_counter = max(self.node_counter, num)
                self.node_counter += 1
                self.sub_node_counter += 1

                # Rebuild the cross-connection set from existing logical edges
                graph = data["graph"]
                node_map = data["node_map"]
                for (u, v), edge_data in zip(graph.edge_list(), graph.edges()):
                    if edge_data.get("type") == "logical":
                        source_id = graph.get_node_data(u).get("id")
                        target_id = graph.get_node_data(v).get("id")
                        connection = tuple(sorted([source_id, target_id]))
                        self.cross_connections.add(connection)
                print(f"Successfully loaded graph. Current counters - Node: {self.node_counter}, Sub: {self.sub_node_counter}")
                return True
        print(f"Graph with node_id {node_id} not found.")
        return False

    async def modify_graph(self, new_query: str, similar_node_id: str, session_id: str = None):
        """Modify an existing graph structure by integrating a new query."""
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        node_map = graph_data["node_map"]

        async def add_as_sibling(node_id: str, query: str):
            if node_id not in node_map:
                raise ValueError(f"Node {node_id} not found")
            idx = node_map[node_id]
            in_edges = graph.in_edges(idx)
            if not in_edges:
                raise ValueError(f"No parent found for node {node_id}")
            parent_idx = in_edges[0][0]
            parent_data = graph.get_node_data(parent_idx)
            parent_id = parent_data.get("id")

            # A sibling stays on the same level: SQ next to SQ, SSQ next to SSQ
            # (startswith avoids matching the "SQ" substring inside "SSQ" ids).
            if node_id.startswith("SQ"):
                self.node_counter += 1
                new_node_id = f"SQ{self.node_counter}"
            else:
                self.sub_node_counter += 1
                new_node_id = f"SSQ{self.sub_node_counter}"
            self.add_node(new_node_id, query, role="independent")
            self.add_edge(parent_id, new_node_id, relationship_type=in_edges[0][2].get("type"))
            return new_node_id

        async def add_as_child(node_id: str, query: str):
            # SQ and SSQ nodes both get SSQ children; only the root gets SQ children.
            if "SQ" in node_id:
                self.sub_node_counter += 1
                new_node_id = f"SSQ{self.sub_node_counter}"
            else:
                self.node_counter += 1
                new_node_id = f"SQ{self.node_counter}"
            self.add_node(new_node_id, query, role="dependent")
            self.add_edge(node_id, new_node_id, relationship_type="logical")
            return new_node_id

        def collect_graph_context() -> list:
            """Collect context from existing graph nodes."""
            graph_data = self._get_current_graph_data()
            graph = graph_data["graph"]
            node_map = graph_data["node_map"]
            nodes = []
            for n_id, idx in node_map.items():
                if n_id == self.root_node_id:
                    continue
                node_data = graph.get_node_data(idx)
                nodes.append({
                    "id": node_data.get("id"),
                    "query": node_data.get("query"),
                    "role": node_data.get("role")
                })
            nodes.sort(key=lambda x: (0 if x["id"].startswith("SQ") else (1 if x["id"].startswith("SSQ") else 2), x["id"]))

            level_queries = {}
            current_sq = None
            for node in nodes:
                node_id = node["id"]
                if node_id.startswith("SQ"):
                    current_sq = node_id
                    if current_sq not in level_queries:
                        level_queries[current_sq] = {
                            "originalquery": node["query"],
                            "subqueries": []
                        }
                    level_queries[current_sq]["subqueries"].append({
                        "subquery": node["query"],
                        "role": node["role"],
                        "dependson": []
                    })
                elif node_id.startswith("SSQ") and current_sq:
                    level_queries[current_sq]["subqueries"].append({
                        "subquery": node["query"],
                        "role": node["role"],
                        "dependson": []
                    })
            return list(level_queries.values())

        if similar_node_id not in node_map:
            raise Exception(f"Node {similar_node_id} not found")
        similar_node_data = graph.get_node_data(node_map[similar_node_id])
        has_parent = len(graph.in_edges(node_map[similar_node_id])) > 0
        context = collect_graph_context()

        if similar_node_data.get("role") == "independent":
            if has_parent:
                new_node_id = await add_as_sibling(similar_node_id, new_query)
            else:
                new_node_id = await add_as_child(similar_node_id, new_query)
        else:
            new_node_id = await add_as_child(similar_node_id, new_query)

        await self.build_graph(
            query=new_query,
            parent_node_id=new_node_id,
            depth=1 if new_node_id.startswith("SQ") else 2,
            context=context,
            session_id=session_id
        )
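
    # Decision sketch for modify_graph (ids hypothetical): a new query most similar
    # to an existing *independent* node with a parent is attached next to it as a
    # sibling; anything else becomes a dependent child. Either way, build_graph is
    # re-run rooted at the new node so it gets its own sub-queries and data.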

    async def build_graph(self, query: str, data: str = None, parent_node_id: str = None,
                          depth: int = 0, threshold: float = 0.8, recurse: bool = True,
                          context: list = None, session_id: str = None, max_tokens_allowed: int = 128000,
                          node_data_futures: dict = None, sub_nodes_info: list = None,
                          sub_query_ids: list = None, pre_req_nodes: list = None):
        """Build a new graph structure in memory."""

        async def process_node(node_id: str, sub_query: str, session_id: str,
                               future: asyncio.Future, max_tokens_allowed: int = max_tokens_allowed):
            try:
                optimized_query = await self.search_engine.generate_optimized_query(sub_query)
                results = await self.search_engine.search(
                    query=optimized_query,
                    num_results=10,
                    exclude_filetypes=["pdf"]
                )
                await self.emit_event("search_results_fetched", {
                    "node_id": node_id,
                    "sub_query": sub_query,
                    "optimized_query": optimized_query,
                    "search_results": results
                })
                filtered_urls = await self.search_engine.filter_urls(
                    sub_query,
                    "extensive research dynamic structure",
                    results
                )
                await self.emit_event("search_results_filtered", {
                    "node_id": node_id,
                    "sub_query": sub_query,
                    "filtered_urls": filtered_urls
                })
                urls = [result.get('link', 'No URL') for result in filtered_urls]
                search_contents = await self.custom_crawler.fetch_page_contents(
                    urls,
                    sub_query,
                    session_id=session_id,
                    max_attempts=1,
                    timeout=30
                )
                await self.emit_event("search_contents_fetched", {
                    "node_id": node_id,
                    "sub_query": sub_query,
                    "contents": search_contents
                })

                contents = ""
                for k, content in enumerate(search_contents, 1):
                    if isinstance(content, Exception):
                        print(f"Error fetching content: {content}")
                    elif content:
                        contents += f"Document {k}:\n{content}\n\n"

                if contents.strip():
                    token_count = self.llm.get_num_tokens(contents)
                    if token_count > max_tokens_allowed:
                        # Late-chunk oversized content down to the token budget
                        contents = await self.chunking.chunker(
                            text=contents,
                            query=sub_query,
                            max_tokens=max_tokens_allowed
                        )
                    print(f"Number of tokens in the fetched content: {token_count}")
                    print(f"Number of tokens kept after chunking: {self.llm.get_num_tokens(contents)}")

                graph_data = self._get_current_graph_data()
                graph = graph_data["graph"]
                node_map = graph_data["node_map"]
                if node_id in node_map:
                    idx = node_map[node_id]
                    node_data = graph.get_node_data(idx)
                    node_data["data"] = contents
                if not future.done():
                    future.set_result(contents)
            except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e:
                print(f"Error processing node {node_id}: {str(e)}")
                if not future.done():
                    future.set_exception(e)
            except Exception as e:
                print(f"Error processing node {node_id}: {str(e)}")
                if not future.done():
                    future.set_exception(e)
                raise e

        async def process_dependent_node(node_id: str, sub_query: str, dep_futures: list, future):
            try:
                # Wait for every dependency's data before rewriting the query
                dep_data = [await f for f in dep_futures]
                modified_query = await self.query_processor.modify_query(
                    sub_query,
                    dep_data
                )

                # Encode off the event loop; SentenceTransformer.encode is blocking
                loop = asyncio.get_running_loop()
                embedding = await loop.run_in_executor(
                    self.executor,
                    self.model.encode,
                    modified_query
                )

                graph_data = self._get_current_graph_data()
                graph = graph_data["graph"]
                node_map = graph_data["node_map"]
                if node_id in node_map:
                    idx = node_map[node_id]
                    node_data = graph.get_node_data(idx)
                    node_data["query"] = modified_query
                    node_data["embedding"] = embedding.tolist() if hasattr(embedding, "tolist") else embedding

                try:
                    if not future.done():
                        await process_node(node_id, modified_query, session_id, future, max_tokens_allowed)
                except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e:
                    if not future.done():
                        future.set_exception(e)
                except Exception as e:
                    if not future.done():
                        future.set_exception(e)
                    raise e
            except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e:
                print(f"Error processing dependent node {node_id}: {str(e)}")
                if not future.done():
                    future.set_exception(e)
            except Exception as e:
                print(f"Error processing dependent node {node_id}: {str(e)}")
                if not future.done():
                    future.set_exception(e)
                raise e

        def create_cross_connections():
            try:
                relationships = self.get_node_relationships(relationship_type='logical')
                for current_node_id, edges in relationships.items():
                    graph_data = self._get_current_graph_data()
                    graph = graph_data["graph"]
                    node_map = graph_data["node_map"]
                    if current_node_id not in node_map:
                        continue
                    idx = node_map[current_node_id]
                    node_data = graph.get_node_data(idx)
                    node_role = (node_data.get("role") or "").lower()
                    if node_role == 'dependent':
                        # Mirror logical in-/out-edges of dependent nodes into the
                        # cross-connection set, adding any edge not yet present
                        for source_id, target_id, edge_data in edges['in_edges']:
                            if not source_id or source_id == self.root_node_id:
                                continue
                            connection = tuple(sorted([current_node_id, source_id]))
                            if connection not in self.cross_connections:
                                if not self.edge_exists(source_id, current_node_id):
                                    print(f"Adding cross-connection edge between {source_id} and {current_node_id}")
                                    self.add_edge(source_id, current_node_id,
                                                  weight=edge_data.get('weight', 1.0),
                                                  relationship_type='logical')
                                self.cross_connections.add(connection)
                        for source_id, target_id, edge_data in edges['out_edges']:
                            if not target_id or target_id == self.root_node_id:
                                continue
                            connection = tuple(sorted([current_node_id, target_id]))
                            if connection not in self.cross_connections:
                                if not self.edge_exists(current_node_id, target_id):
                                    print(f"Adding cross-connection edge between {current_node_id} and {target_id}")
                                    self.add_edge(current_node_id, target_id,
                                                  weight=edge_data.get('weight', 1.0),
                                                  relationship_type='logical')
                                self.cross_connections.add(connection)
            except Exception as e:
                print(f"Error creating cross connections: {str(e)}")
                raise

        if depth > 1:
            return
        if context is None:
            context = []
        if node_data_futures is None:
            node_data_futures = {}
        if sub_nodes_info is None:
            sub_nodes_info = []
        if sub_query_ids is None:
            sub_query_ids = []
        if pre_req_nodes is None:
            pre_req_nodes = {}

        if parent_node_id is None:
            self.add_node(self.root_node_id, query, data)
            parent_node_id = self.root_node_id

        intent = await self.query_processor.get_query_intent(query)
        if depth == 0:
            response_data, sub_queries, roles, dependencies = await self.query_processor.decompose_query_with_dependencies(query, intent)
        else:
            response_data, sub_queries, roles, dependencies = await self.query_processor.decompose_query_with_dependencies(query, intent, context)
        if response_data:
            context.append(response_data)

        if len(sub_queries) > 1 and sub_queries[0] != query:
            for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)):
                if depth == 0:
                    await self.emit_event("sub_query_created", {
                        "depth": depth,
                        "sub_query": sub_query,
                        "role": role,
                        "dependency": dependency,
                        "parent_node_id": parent_node_id,
                    })
                if depth == 0:
                    self.node_counter += 1
                    sub_node_id = f"SQ{self.node_counter}"
                else:
                    self.sub_node_counter += 1
                    sub_node_id = f"SSQ{self.sub_node_counter}"
                sub_query_ids.append(sub_node_id)
                self.add_node(sub_node_id, sub_query, role=role)

                future = asyncio.Future()
                node_data_futures[sub_node_id] = future
                sub_nodes_info.append((sub_node_id, sub_query, role, dependency, future, depth))

                if role.lower() in ('pre-requisite', 'prerequisite'):
                    pre_req_nodes[idx] = sub_node_id

                if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
                    self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical')
                elif role.lower() == 'dependent':
                    if isinstance(dependency, list) and len(dependency) == 2 and all(isinstance(d, list) for d in dependency):
                        print(f"Dependency: {dependency}")
                        prev_deps, current_deps = dependency
                        # Dependencies on sub-queries from earlier build passes (via context)
                        if context and prev_deps not in [None, []]:
                            for dep_idx in prev_deps:
                                if dep_idx is None:
                                    continue
                                for context_data in context:
                                    if 'subqueries' in context_data and dep_idx < len(context_data['subqueries']):
                                        sub_query_data = context_data['subqueries'][dep_idx]
                                        if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data:
                                            dep_query = sub_query_data['subquery']
                                            matching_nodes = self.find_nodes_by_properties(query=dep_query)
                                            if matching_nodes:
                                                dep_node_id = matching_nodes[0].get('node_id')
                                                score = matching_nodes[0].get('score', 0)
                                                if score >= 0.9:
                                                    self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
                        # Dependencies on sub-queries created in this pass
                        if current_deps not in [None, []]:
                            for dep_idx in current_deps:
                                if dep_idx < len(sub_query_ids):
                                    dep_node_id = sub_query_ids[dep_idx]
                                    self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
                                else:
                                    raise ValueError(f"Invalid dependency index: {dep_idx}")
                    elif isinstance(dependency, list) and len(dependency) > 0:
                        for dep_idx in dependency:
                            if dep_idx < len(sub_queries):
                                dep_node_id = sub_query_ids[dep_idx]
                                self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
                            else:
                                raise ValueError(f"Invalid dependency index: {dep_idx}")
                    else:
                        raise ValueError(f"Invalid dependency: {dependency}")
                else:
                    raise ValueError(f"Unexpected role: {role}")

        if recurse:
            recursion_tasks = []
            for idx, sub_query in enumerate(sub_queries):
                try:
                    sub_node_id = sub_query_ids[idx]
                    recursion_tasks.append(
                        self.build_graph(
                            query=sub_query,
                            parent_node_id=sub_node_id,
                            depth=depth + 1,
                            threshold=threshold,
                            recurse=recurse,
                            context=context,
                            session_id=session_id,
                            node_data_futures=node_data_futures,
                            sub_nodes_info=sub_nodes_info,
                            sub_query_ids=sub_query_ids,
                            pre_req_nodes=pre_req_nodes
                        )
                    )
                except Exception as e:
                    print(f"Failed to create recursion task for sub-query {sub_query}: {e}")
                    continue
            if recursion_tasks:
                try:
                    await asyncio.gather(*recursion_tasks, return_exceptions=True)
                except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e:
                    print(f"Error during recursive processing: {e}")
                except Exception as e:
                    print(f"Error during recursive processing: {e}")
                    raise e

        def collect_dep_futures(dependency) -> list:
            """Resolve a node's dependency spec into the futures it must await.

            Factored out of the loop below, where the identical logic previously
            appeared twice (once per local_depth branch)."""
            dep_futures = []
            if isinstance(dependency, list) and len(dependency) == 2:
                prev_deps, current_deps = dependency
                # Dependencies resolved against earlier passes via the shared context
                if context and prev_deps not in [None, []]:
                    for context_idx, context_data in enumerate(context):
                        if isinstance(prev_deps, list) and context_idx < len(prev_deps):
                            context_dep = prev_deps[context_idx]
                            if (context_dep is not None and isinstance(context_data, dict)
                                    and 'subqueries' in context_data):
                                if context_dep < len(context_data['subqueries']):
                                    dep_query = context_data['subqueries'][context_dep]['subquery']
                                    matching_nodes = self.find_nodes_by_properties(query=dep_query)
                                    if matching_nodes not in [None, []]:
                                        dep_node_id = matching_nodes[0].get('node_id', None)
                                        score = float(matching_nodes[0].get('score', 0))
                                        if score == 1.0 and dep_node_id in node_data_futures:
                                            dep_futures.append(node_data_futures[dep_node_id])
                        elif isinstance(prev_deps, int):
                            if context_idx < len(context_data['subqueries']):
                                dep_query = context_data['subqueries'][prev_deps]['subquery']
                                matching_nodes = self.find_nodes_by_properties(query=dep_query)
                                if matching_nodes not in [None, []]:
                                    dep_node_id = matching_nodes[0].get('node_id', None)
                                    score = float(matching_nodes[0].get('score', 0))
                                    if score == 1.0 and dep_node_id in node_data_futures:
                                        dep_futures.append(node_data_futures[dep_node_id])
                # Dependencies on nodes created in this pass
                if current_deps not in [None, []]:
                    current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
                    for dep_idx in current_deps_list:
                        if dep_idx < len(sub_query_ids):
                            dep_node_id = sub_query_ids[dep_idx]
                            if dep_node_id in node_data_futures:
                                dep_futures.append(node_data_futures[dep_node_id])
            return dep_futures

        futures = {}
        all_child_futures = {}
        process_tasks = []
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        node_map = graph_data["node_map"]
        for (sub_node_id, sub_query, role, dependency, future, local_depth) in sub_nodes_info:
            idx = node_map.get(sub_node_id)
            has_children = False
            child_futures = []
            if idx is not None:
                for (_, child_idx, edge_data) in graph.out_edges(idx):
                    if edge_data.get("type") == "hierarchical":
                        has_children = True
                        child_future = node_data_futures.get(graph.get_node_data(child_idx).get("id"))
                        if child_future:
                            child_futures.append(child_future)
            if local_depth == 0:
                futures[sub_query] = future
                all_child_futures[sub_query] = child_futures
                if has_children:
                    # Parents with children act as aggregators; resolve them empty
                    if not future.done():
                        future.set_result("")
                elif role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
                    process_tasks.append(process_node(sub_node_id, sub_query, session_id, future, max_tokens_allowed))
                elif role.lower() == 'dependent':
                    process_tasks.append(process_dependent_node(sub_node_id, sub_query, collect_dep_futures(dependency), future))
            else:
                if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
                    process_tasks.append(process_node(sub_node_id, sub_query, session_id, future, max_tokens_allowed))
                elif role.lower() == 'dependent':
                    process_tasks.append(process_dependent_node(sub_node_id, sub_query, collect_dep_futures(dependency), future))

        if process_tasks:
            await self.emit_event("search_process_started", {
                "depth": depth,
                "sub_queries": sub_queries,
                "roles": roles
            })

        # First pass: futures may already be resolved by deeper recursive calls
        processed_sub_queries = set()
        for sub_query, future in futures.items():
            try:
                parent_content = future.result().strip()
            except Exception:
                parent_content = ""
            child_futures = all_child_futures.get(sub_query, [])
            any_child_done = any(
                cf.done() and not cf.cancelled() and cf.exception() is None and cf.result().strip()
                for cf in child_futures
            )
            if parent_content or any_child_done:
                await self.emit_event("sub_query_processed", {"sub_query": sub_query})
                processed_sub_queries.add(sub_query)

        await asyncio.gather(*process_tasks)

        if depth == 0:
            # Second pass: report the remaining sub-queries now that all tasks ran
            for sub_query, future in futures.items():
                if sub_query not in processed_sub_queries:
                    try:
                        parent_content = future.result().strip()
                    except Exception:
                        parent_content = ""
                    child_futures = all_child_futures.get(sub_query, [])
                    any_child_done = any(
                        cf.done() and not cf.cancelled() and cf.exception() is None and cf.result().strip()
                        for cf in child_futures
                    )
                    if parent_content or any_child_done:
                        await self.emit_event("sub_query_processed", {"sub_query": sub_query})
                    else:
                        await self.emit_event("sub_query_failed", {"sub_query": sub_query})

        print("Graph building complete, processing final tasks...")
        await self.emit_event("search_process_completed", {
            "depth": depth,
            "sub_queries": sub_queries,
            "roles": roles
        })

        create_cross_connections()
        print("All cross-connections have been created!")

        print(f"Adding similarity edges with threshold {threshold}")
        graph_data = self._get_current_graph_data()
        node_map = graph_data["node_map"]
        all_node_ids = list(node_map.keys())
        for i, node1 in enumerate(all_node_ids):
            for node2 in all_node_ids[i + 1:]:
                if not self.edge_exists(node1, node2):
                    self.add_edge_based_on_similarity_and_relevance(node1, node2, query, threshold)
        print("All similarity edges have been added!")

    async def process_graph(
        self,
        query: str,
        data: str = None,
        similarity_threshold: float = 0.8,
        relevance_threshold: float = 0.7,
        sub_sub_queries: bool = True,
        session_id: str = None,
        max_tokens_allowed: int = 128000
    ):
        """Process a query and manage graph creation/modification."""

        def check_query_similarity(new_query: str, similarity_threshold: float = 0.8) -> Dict[str, Any]:
            if self.current_graph_id is None:
                raise Exception("Error: No current graph ID. Cannot check query similarity.")
            graph_data = self._get_current_graph_data()
            graph = graph_data["graph"]
            node_map = graph_data["node_map"]
            similarities = []
            if not node_map:
                return {"should_create_new": True}
            for node_id, idx in node_map.items():
                node_data = graph.get_node_data(idx)
                if not node_data.get("query"):
                    continue
                similarity = self.calculate_query_similarity(new_query, node_data.get("query"))
                if similarity >= similarity_threshold:
                    similarities.append({
                        "node_id": node_id,
                        "query": node_data.get("query"),
                        "score": similarity,
                        "role": node_data.get("role")
                    })
            if not similarities:
                print(f"No similar queries found above threshold {similarity_threshold}")
                return {"should_create_new": True}
            best_match = max(similarities, key=lambda x: x["score"])

            # Classify the match by id prefix; "SSQ" must be checked before "SQ",
            # since "SQ" is a substring of every "SSQ" id
            rel_type = "root"
            if "SSQ" in best_match["node_id"]:
                rel_type = "sub-sub"
            elif "SQ" in best_match["node_id"]:
                rel_type = "sub"
            return {
                "most_similar_query": best_match["query"],
                "similarity_score": best_match["score"],
                "relationship_type": rel_type,
                "node_id": best_match["node_id"],
                "should_create_new": best_match["score"] < similarity_threshold
            }

        try:
            graphs = self.get_graphs()
            if not graphs:
                print("No existing graphs found. Creating new graph.")
                self.create_new_graph()
                await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
                await self.build_graph(
                    query=query,
                    data=data,
                    threshold=relevance_threshold,
                    recurse=sub_sub_queries,
                    session_id=session_id,
                    max_tokens_allowed=max_tokens_allowed
                )
                gc.collect()
                self.prune_edges()
                self.update_pagerank()
                self.verify_graph_integrity()
                self.verify_graph_consistency()
                return

            # Find the most similar existing graph, node by node
            max_similarity = 0
            most_similar_graph = None
            consolidated_graphs = {}
            for graph_obj in graphs:
                graph_info = graph_obj.get("graph_info")
                if not graph_info:
                    continue
                graph_id = graph_info.get("graph_id")
                if not graph_id:
                    continue
                if graph_id not in consolidated_graphs:
                    consolidated_graphs[graph_id] = {
                        "graph_id": graph_id,
                        "nodes": []
                    }
                if graph_info.get("nodes"):
                    consolidated_graphs[graph_id]["nodes"].extend(graph_info["nodes"])

            for graph_id, graph_data in consolidated_graphs.items():
                nodes = graph_data["nodes"]
                for node in nodes:
                    if node.get("query"):
                        similarity = self.calculate_query_similarity(query, node["query"])
                        if node.get("id", "").startswith("SQ"):
                            asyncio.create_task(self.emit_event("retrieved_sub_query", {
                                "sub_query": node["query"]
                            }))
                        if similarity > max_similarity:
                            max_similarity = similarity
                            most_similar_graph = graph_id

            if max_similarity >= similarity_threshold:
                print(f"Found similar query with score {round(max_similarity, 2)}")
                self.current_graph_id = most_similar_graph
                if round(max_similarity, 2) == 1.0:
                    print("Loading and using existing graph")
                    await self.emit_event("graph_operation", {"operation_type": "loading_existing_graph"})
                    success = self.load_graph(self.root_node_id)
                    if not success:
                        raise Exception("Failed to load existing graph")
                else:
                    print("Checking for node-level similarity...")
                    similarity_info = check_query_similarity(
                        query,
                        similarity_threshold
                    )
                    if similarity_info.get("relationship_type") in ["sub", "sub-sub"]:
                        print(f"Most Similar Query: {similarity_info['most_similar_query']}")
                        print("Modifying existing graph structure")
                        await self.emit_event("graph_operation", {"operation_type": "modifying_existing_graph"})
                        await self.modify_graph(
                            query,
                            similarity_info["node_id"],
                            session_id=session_id
                        )
                        gc.collect()
                        self.prune_edges()
                        self.update_pagerank()
                        self.verify_graph_integrity()
                        self.verify_graph_consistency()
            else:
                print(f"Creating new graph for query: {query}")
                self.create_new_graph()
                await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
                await self.build_graph(
                    query=query,
                    data=data,
                    threshold=relevance_threshold,
                    recurse=sub_sub_queries,
                    session_id=session_id,
                    max_tokens_allowed=max_tokens_allowed
                )
                gc.collect()
                self.prune_edges()
                self.update_pagerank()
                self.verify_graph_integrity()
                self.verify_graph_consistency()
        except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError):
            pass
        except Exception as e:
            print(f"Error in process_graph: {str(e)}")
            raise
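
    # process_graph routing summary: no graphs yet -> build a new one; best
    # cross-graph similarity of 1.0 -> load the matching graph as-is; above the
    # threshold with a similar SQ/SSQ node -> graft the query onto it via
    # modify_graph; otherwise -> create a fresh graph. Every mutating path finishes
    # with prune_edges, update_pagerank, and the two verification passes.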

    def add_edge_based_on_similarity_and_relevance(self, node1_id: str, node2_id: str, query: str, threshold: float = 0.8):
        """Add edges based on node similarity and relevance."""
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        node_map = graph_data["node_map"]
        if node1_id not in node_map or node2_id not in node_map:
            return
        idx1 = node_map[node1_id]
        idx2 = node_map[node2_id]
        node1_data = graph.get_node_data(idx1)
        node2_data = graph.get_node_data(idx2)
        if not all([node1_data.get("embedding"), node2_data.get("embedding"), node1_data.get("data"), node2_data.get("data")]):
            return

        similarity = self.cosine_similarity(node1_data["embedding"], node2_data["embedding"])
        query_relevance1 = self.calculate_relevance(query, node1_data["data"])
        query_relevance2 = self.calculate_relevance(query, node2_data["data"])
        node_relevance = self.calculate_relevance(node1_data["data"], node2_data["data"])
        weight = (similarity + query_relevance1 + query_relevance2 + node_relevance) / 4
        if weight >= threshold:
            self.add_edge(node1_id, node2_id, weight=weight, relationship_type='similarity_and_relevance')
            print(f"Added edge between {node1_id} and {node2_id} with type similarity_and_relevance and weight {weight}")

    def calculate_relevance(self, data1: str, data2: str) -> float:
        """Calculate relevance between two data strings as BERTScore F1."""
        try:
            if not data1 or not data2:
                return 0.0
            P, R, F1 = self.scorer.score([data1], [data2])
            return F1.mean().item()
        except Exception as e:
            print(f"Error calculating relevance: {str(e)}")
            return 0.0

    def calculate_query_similarity(self, query1: str, query2: str) -> float:
        """Calculate similarity between two queries."""
        try:
            embedding1 = self.model.encode(query1).tolist()
            embedding2 = self.model.encode(query2).tolist()
            return self.cosine_similarity(embedding1, embedding2)
        except Exception as e:
            print(f"Error calculating query similarity: {str(e)}")
            return 0.0

    def get_similarities_and_relevance(self, threshold: float = 0.8) -> list:
        """Get similarities and relevance between nodes."""
        try:
            graph_data = self._get_current_graph_data()
            graph = graph_data["graph"]
            node_map = graph_data["node_map"]
            nodes = []
            for node_id, idx in node_map.items():
                node_data = graph.get_node_data(idx)
                nodes.append({
                    "id": node_data.get("id"),
                    "embedding": node_data.get("embedding"),
                    "data": node_data.get("data")
                })
            similarities = []
            for i, node1 in enumerate(nodes):
                for node2 in nodes[i + 1:]:
                    similarity = self.cosine_similarity(node1["embedding"], node2["embedding"])
                    relevance = self.calculate_relevance(node1["data"], node2["data"])
                    weight = (similarity + relevance) / 2
                    if weight >= threshold:
                        similarities.append({
                            'node1': node1["id"],
                            'node2': node2["id"],
                            'similarity': similarity,
                            'relevance': relevance,
                            'weight': weight
                        })
            return similarities
        except Exception as e:
            print(f"Error getting similarities and relevance: {str(e)}")
            return []

    def get_node_relationships(self, node_id=None, depth=None, role=None, relationship_type=None):
        """Get relationships between nodes with filtering options."""
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        node_map = graph_data["node_map"]
        relationships = {}
        for n_id, idx in node_map.items():
            if n_id == self.root_node_id:
                continue
            node_data = graph.get_node_data(idx)
            if node_id and n_id != node_id:
                continue
            if role and node_data.get("role") != role:
                continue
            in_edges = []
            for u, v, edge_data in graph.in_edges(idx):
                source_id = graph.get_node_data(u).get("id")
                in_edges.append((source_id, n_id, {"weight": edge_data.get("weight"), "type": edge_data.get("type")}))
            out_edges = []
            for u, v, edge_data in graph.out_edges(idx):
                target_id = graph.get_node_data(v).get("id")
                out_edges.append((n_id, target_id, {"weight": edge_data.get("weight"), "type": edge_data.get("type")}))
            relationships[n_id] = {"in_edges": in_edges, "out_edges": out_edges}
        return relationships
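
    # Shape of the returned mapping (ids hypothetical):
    #
    #     {"SQ2": {"in_edges":  [("SQ1", "SQ2", {"weight": 1.0, "type": "logical"})],
    #              "out_edges": [("SQ2", "SSQ4", {"weight": 1.0, "type": "hierarchical"})]}}
    #
    # The depth and relationship_type parameters are accepted but not yet used for
    # filtering; callers such as create_cross_connections filter on node role instead.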

    def find_nodes_by_properties(self, query: str = None, embedding: list = None,
                                 node_data: dict = None, similarity_threshold: float = 0.8) -> list:
        """Find nodes based on properties."""
        try:
            graph_data = self._get_current_graph_data()
            graph = graph_data["graph"]
            node_map = graph_data["node_map"]
            matching_nodes = []
            for n_id, idx in node_map.items():
                data = graph.get_node_data(idx)
                match_score = 0
                matches = 0
                if query and query.lower() in data.get("query", "").lower():
                    match_score += 1
                    matches += 1
                if embedding and "embedding" in data:
                    sim = self.cosine_similarity(embedding, data["embedding"])
                    if sim >= similarity_threshold:
                        match_score += sim
                        matches += 1
                if node_data:
                    data_matches = sum(1 for k, v in node_data.items() if k in data and data[k] == v)
                    if data_matches > 0:
                        match_score += data_matches / len(node_data)
                        matches += 1
                if matches > 0:
                    matching_nodes.append({
                        "node_id": n_id,
                        "score": match_score / matches,
                        "data": data
                    })
            matching_nodes.sort(key=lambda x: x["score"], reverse=True)
            return matching_nodes
        except Exception as e:
            print(f"Error finding nodes by properties: {str(e)}")
            raise

    def query_graph(self, query: str) -> str:
        """Query the graph for a specific query, collecting data from the entire relevant subgraph."""
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        node_map = graph_data["node_map"]
        target_node_id = None
        for n_id, idx in node_map.items():
            if graph.get_node_data(idx).get("query") == query:
                target_node_id = n_id
                break
        if not target_node_id:
            raise ValueError(f"Query node not found for: {query}")

        datas = []
        start_idx = node_map[target_node_id]
        visited = set()
        stack = [start_idx]
        while stack:
            current = stack.pop()
            if current in visited:
                continue
            visited.add(current)
            current_data = graph.get_node_data(current)
            if current_data.get("data") and current_data.get("data").strip():
                datas.append(current_data.get("data").strip())
            for neighbor in graph.neighbors(current):
                if neighbor not in visited:
                    stack.append(neighbor)
        if not datas:
            print(f"No data found for: {query}")
            return ""
        return "\n\n".join([f"Data {i+1}:\n{data}" for i, data in enumerate(datas)])

    def prune_edges(self, max_edges: int = 1000):
        """Prune excess edges while preserving node data."""
        print(f"Pruning edges to maximum {max_edges} edges...")
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]

        # weighted_edge_list() yields (source, target, payload) triples;
        # edge_list() alone carries no payload to sort by
        all_edges = list(graph.weighted_edge_list())
        current_edges = len(all_edges)
        if current_edges > max_edges:
            sorted_edges = sorted(all_edges, key=lambda x: x[2].get("weight", 1.0), reverse=True)
            edges_to_keep = set()
            for edge in sorted_edges[:max_edges]:
                edges_to_keep.add((edge[0], edge[1]))
            edges_to_remove = []
            for edge in all_edges:
                if (edge[0], edge[1]) not in edges_to_keep:
                    edges_to_remove.append((edge[0], edge[1]))
            for u, v in edges_to_remove:
                try:
                    graph.remove_edge(u, v)
                except Exception as e:
                    print(f"Error removing edge from {u} to {v}: {e}")
            print(f"Pruned edges. Kept top {max_edges} edges by weight.")
        else:
            print("No pruning required. Current edge count is within limits.")

    def update_pagerank(self):
        """Update PageRank values using Rustworkx's pagerank algorithm."""
        if not self.current_graph_id:
            print("No current graph selected. Cannot compute PageRank.")
            return
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        try:
            pr = rx.pagerank(graph, weight_fn=lambda e: e.get("weight", 1.0))
            node_map = graph_data["node_map"]
            for n_id, idx in node_map.items():
                node_data = graph.get_node_data(idx)
                node_data["pagerank"] = pr[idx]
            print("PageRank updated successfully")
        except Exception as e:
            print(f"Error updating PageRank: {str(e)}")
            raise

    def display_graph(self):
        """Display the graph using PyVis."""
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        node_map = graph_data["node_map"]

        net = Network(height="530px", width="100%", directed=True, bgcolor="#222222", font_color="white")
        net.options = {"physics": {"enabled": False}}

        all_nodes = set()
        all_edges = []
        for (u, v), edge_data in zip(graph.edge_list(), graph.edges()):
            source_data = graph.get_node_data(u)
            target_data = graph.get_node_data(v)
            source_id = source_data.get("id")
            target_id = target_data.get("id")
            source_tooltip = f"Query: {source_data.get('query', 'N/A')}"
            target_tooltip = f"Query: {target_data.get('query', 'N/A')}"
            if source_id not in all_nodes:
                net.add_node(source_id, label=source_id, title=source_tooltip, size=20, color="#00cc66")
                all_nodes.add(source_id)
            if target_id not in all_nodes:
                net.add_node(target_id, label=target_id, title=target_tooltip, size=20, color="#00cc66")
                all_nodes.add(target_id)
            edge_type = edge_data.get("type", "N/A")
            edge_weight = edge_data.get("weight", "N/A")
            edge_tooltip = f"Weight: {edge_weight}"
            all_edges.append({
                "from": source_id,
                "to": target_id,
                "label": edge_type,
                "title": edge_tooltip
            })
        for edge in all_edges:
            net.add_edge(edge["from"], edge["to"], title=edge["title"], color="#cccccc")

        net.options["layout"] = {"improvedLayout": True}
        net.options["interaction"] = {"dragNodes": True}

        # Render into a writable directory, then return the HTML as a string
        original_dir = os.getcwd()
        os.chdir(os.getenv("WRITABLE_DIR", "/tmp"))
        net.save_graph("temp_graph.html")
        with open("temp_graph.html", "r", encoding="utf-8") as f:
            html_str = f.read()
        os.remove("temp_graph.html")
        os.chdir(original_dir)
        return html_str

    def verify_graph_integrity(self):
        """Verify and fix graph integrity issues."""
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        node_map = graph_data["node_map"]

        # Report orphaned nodes (no in- or out-edges)
        orphaned = []
        for n_id, idx in node_map.items():
            if not graph.in_edges(idx) and not graph.out_edges(idx):
                orphaned.append(n_id)
        if orphaned:
            print(f"Found orphaned nodes: {orphaned}")

        # Remove edges whose target belongs to a different graph
        invalid_edges = []
        for u, v in graph.edge_list():
            target_data = graph.get_node_data(v)
            if target_data.get("graph_id") != self.current_graph_id:
                invalid_edges.append((graph.get_node_data(u).get("id"), target_data.get("id")))
        if invalid_edges:
            print(f"Found invalid edges: {invalid_edges}")
            edges_to_remove = []
            for u, v in graph.edge_list():
                if graph.get_node_data(v).get("graph_id") != self.current_graph_id:
                    edges_to_remove.append((u, v))
            for u, v in edges_to_remove:
                try:
                    graph.remove_edge(u, v)
                except Exception as e:
                    print(f"Error removing invalid edge from {u} to {v}: {e}")
        print("Graph integrity verified successfully")
        return True

    def verify_graph_consistency(self):
        """Verify consistency of the in-memory graph."""
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        node_map = graph_data["node_map"]
        for n_id, idx in node_map.items():
            node_data = graph.get_node_data(idx)
            if node_data.get("id") is None or node_data.get("query") is None:
                raise ValueError("Found nodes with missing required properties")
        for edge_data in graph.edges():
            if edge_data.get("type") is None or edge_data.get("weight") is None:
                raise ValueError("Found relationships with missing required properties")
        print("Graph consistency verified successfully")
        return True

    async def close(self):
        """Properly cleanup all resources."""
        try:
            if hasattr(self, 'executor'):
                self.executor.shutdown(wait=True)
            # The Crawler instance is commented out in __init__, so cleanup targets
            # the CustomCrawler (assumed to expose the same session-cleanup interface)
            if hasattr(self, 'custom_crawler'):
                await asyncio.shield(self.custom_crawler.cleanup_expired_sessions())
                await asyncio.shield(self.custom_crawler.cleanup_browser_context(getattr(self, "session_id", None)))
        except Exception as e:
            print(f"Error during cleanup: {e}")

    def cosine_similarity(self, v1: List[float], v2: List[float]) -> float:
        """Calculate cosine similarity between two vectors."""
        try:
            v1_array = np.array(v1)
            v2_array = np.array(v2)
            denom = np.linalg.norm(v1_array) * np.linalg.norm(v2_array)
            if denom == 0:
                return 0.0
            return float(np.dot(v1_array, v2_array) / denom)
        except Exception as e:
            print(f"Error calculating cosine similarity: {str(e)}")
            return 0.0
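
    # Quick sanity check of the metric (hypothetical vectors):
    #
    #     rag.cosine_similarity([1.0, 0.0], [1.0, 0.0])   # -> 1.0
    #     rag.cosine_similarity([1.0, 0.0], [0.0, 1.0])   # -> 0.0
    #
    # Zero vectors return 0.0 instead of dividing by zero.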

if __name__ == "__main__":
    from dotenv import load_dotenv
    from src.evaluation.evaluator import Evaluator

    load_dotenv()
    graph_search = GraphRAG(num_workers=24)
    evaluator = Evaluator()
    reasoner = Reasoner()

    async def test_graph_search():
        # Sample data for testing
        queries = [
            """In the context of global economic recovery and energy security concerns, provide an in-depth comparative assessment of the renewable energy policies among G20 countries.
            Specifically, examine how short-term economic stimulus measures intersect with long-term decarbonization commitments, including:
            1. Carbon pricing mechanisms
            2. Subsidies for emerging technologies (such as green hydrogen and battery storage)
            3. Cross-border climate finance initiatives
            Highlight the unique challenges faced by both advanced and emerging economies in addressing:
            1. Energy poverty
            2. Supply chain disruptions
            3. Geopolitical tensions (e.g., the Russia-Ukraine conflict)
            Discuss how these factors influence policy effectiveness, and evaluate the degree to which each country is on track to meet—or exceed—its Paris Agreement targets.
            Note any significant policy gaps, regional collaborations, or innovative best practices.
            Lastly, provide a forward-looking perspective on how these renewable energy strategies may evolve over the next decade, considering:
            1. Technological breakthroughs
            2. Global market trends
            3. Potential climate-related disasters
            Present your analysis as a detailed, well-formatted report.""",
            """Analyse the impact of 'hot-money' on the value of Indian Rupee and answer the following questions:-
            1. How does it affect the exchange rate?
            2. How can it be mitigated/eliminated?
            3. Why is it a problem?
            4. What are the consequences?
            5. What are the alternatives?
            - Evaluate the alternatives for pros and cons.
            - Evaluate the impact of alternatives on the exchange rate.
            - How can they be implemented?
            - What are the consequences of each alternative?
            - Evaluate the feasibility of the alternatives.
            - Pick top 5 alternatives and justify your choices in detail.
            6. What are the implications for the Indian economy? Furthermore:-
            - Evaluate the impact of the chosen alternatives on the Indian economy.""",
            """Inflation has been an intrinsic part of human civilization since the very beginning. Answer the following questions:-
            1. How true is the above statement?
            2. What are the causes of inflation?
            3. What are the consequences of inflation?
            4. Can we completely eliminate inflation?""",
            """Perform a detailed comparison between the ancient Greek and Roman civilizations.
            1. What were the key differences between the two civilizations?
            - Evaluate the differences in governance, society, and culture
            - Evaluate the differences in economy, trade, and military
            - Evaluate the differences in technology and infrastructure
            2. What were the similarities between the two civilizations?
            - Evaluate the similarities in governance, society, and culture
            - Evaluate the similarities in economy, trade, and military
            - Evaluate the similarities in technology and infrastructure
            3. How did these two civilizations influence each other?
            - Evaluate the influence of one civilization on the other
            4. How did these two civilizations influence the modern world?
            5. Was there another civilization that influenced these two? If yes, how?""",
            """Evaluate the long-term effects of colonialism on economic development in Asia:-
            1. Include case studies of at least five different countries
            2. Analyze how these effects differ based on colonial power, time of independence, and resource distribution
            - Evaluate the impact of colonialism on the economy of the country
            - Evaluate the impact of colonialism on the economy of the region
            - Evaluate the impact of colonialism on the economy of the world
            3. How do these effects compare to Africa?"""
        ]
        follow_on_queries = [
            "How is 'hot-money' related to the current economic situation in India?",
            "What is inflation?",
            "Did ancient Greece and Rome have any impact on modern democracy? If yes, how?",
            "Did colonialism have any impact on the trade between Africa and Asia, both in colonial and post-colonial times? If yes, how?"
        ]
        while True:
            print("\n\nEnter query (finish input with an empty line):")
            query_lines = []
            while True:
                line = input()
                if line.strip() == "":
                    break
                query_lines.append(line)
            query = "\n".join(query_lines).strip()
            if query.lower() == "exit":
                break
            print("\n\n" + "=" * 15 + " Processing Query " + "=" * 15 + "\n\n")
            await graph_search.process_graph(query, similarity_threshold=0.8, relevance_threshold=0.8)
            answer = graph_search.query_graph(query)
            response = ""
            async for chunk in reasoner.reason(query, answer):
                response += chunk
                # Stream each new chunk once instead of reprinting the accumulated response
                print(chunk, end="", flush=True)
            graph_search.display_graph()
            evaluation = await evaluator.evaluate_response(query, response, [answer])
            print(f"Faithfulness: {evaluation['faithfulness']}")
            print(f"Answer Relevancy: {evaluation['answer relevancy']}")
            print(f"Contextual Recall: {evaluation['contextual recall']}")
        await graph_search.close()

    asyncio.run(test_graph_search())