Spaces:
Paused
Paused
Hemang Thakur
commited on
Commit
·
c3d9a20
1
Parent(s):
c8abe84
made changes to graph rag file and some ui elements
Browse files- frontend/src/Components/AiComponents/ChatWindow.css +2 -1
- frontend/src/Components/AiComponents/ChatWindow.js +1 -4
- frontend/src/Components/AiComponents/Graph.css +2 -2
- frontend/src/Components/AiComponents/Sources.css +4 -0
- frontend/src/Components/AiComponents/Sources.js +1 -1
- main.py +4 -0
- requirements.txt +3 -2
- src/rag/graph_rag.py +188 -102
frontend/src/Components/AiComponents/ChatWindow.css
CHANGED
@@ -127,6 +127,7 @@
|
|
127 |
.markdown {
|
128 |
margin: -1rem 0 -0.8rem 0;
|
129 |
line-height: 2rem;
|
|
|
130 |
}
|
131 |
|
132 |
/* Post-answer icons container: placed below the bot bubble */
|
@@ -275,4 +276,4 @@
|
|
275 |
100% {
|
276 |
transform: rotate(360deg);
|
277 |
}
|
278 |
-
}
|
|
|
127 |
.markdown {
|
128 |
margin: -1rem 0 -0.8rem 0;
|
129 |
line-height: 2rem;
|
130 |
+
white-space: pre-wrap;
|
131 |
}
|
132 |
|
133 |
/* Post-answer icons container: placed below the bot bubble */
|
|
|
276 |
100% {
|
277 |
transform: rotate(360deg);
|
278 |
}
|
279 |
+
}
|
frontend/src/Components/AiComponents/ChatWindow.js
CHANGED
@@ -168,10 +168,7 @@ function ChatWindow({
|
|
168 |
)}
|
169 |
{renderSourcesRead() !== null && (
|
170 |
<div className="sources-read-container">
|
171 |
-
<p
|
172 |
-
className="sources-read"
|
173 |
-
onClick={() => openRightSidebar("sources", blockId)}
|
174 |
-
>
|
175 |
Sources Read: {renderSourcesRead()}
|
176 |
</p>
|
177 |
</div>
|
|
|
168 |
)}
|
169 |
{renderSourcesRead() !== null && (
|
170 |
<div className="sources-read-container">
|
171 |
+
<p className="sources-read">
|
|
|
|
|
|
|
172 |
Sources Read: {renderSourcesRead()}
|
173 |
</p>
|
174 |
</div>
|
frontend/src/Components/AiComponents/Graph.css
CHANGED
@@ -21,7 +21,7 @@
|
|
21 |
width: 45% !important;
|
22 |
max-width: 100% !important;
|
23 |
background-color: #1e1e1e !important;
|
24 |
-
|
25 |
overflow: hidden !important; /* Prevent scrolling */
|
26 |
}
|
27 |
|
@@ -62,7 +62,7 @@
|
|
62 |
.graph-dialog-content {
|
63 |
padding: 0 !important;
|
64 |
background-color: #1e1e1e !important;
|
65 |
-
height:
|
66 |
overflow: hidden !important;
|
67 |
}
|
68 |
|
|
|
21 |
width: 45% !important;
|
22 |
max-width: 100% !important;
|
23 |
background-color: #1e1e1e !important;
|
24 |
+
min-height: 80vh !important;
|
25 |
overflow: hidden !important; /* Prevent scrolling */
|
26 |
}
|
27 |
|
|
|
62 |
.graph-dialog-content {
|
63 |
padding: 0 !important;
|
64 |
background-color: #1e1e1e !important;
|
65 |
+
height: 550px !important;
|
66 |
overflow: hidden !important;
|
67 |
}
|
68 |
|
frontend/src/Components/AiComponents/Sources.css
CHANGED
@@ -6,6 +6,10 @@
|
|
6 |
padding: 0 !important;
|
7 |
}
|
8 |
|
|
|
|
|
|
|
|
|
9 |
/* Styling for each Card component */
|
10 |
.source-card {
|
11 |
background-color: #3e3e3eec !important;
|
|
|
6 |
padding: 0 !important;
|
7 |
}
|
8 |
|
9 |
+
.loading-sources {
|
10 |
+
padding: 1rem !important;
|
11 |
+
}
|
12 |
+
|
13 |
/* Styling for each Card component */
|
14 |
.source-card {
|
15 |
background-color: #3e3e3eec !important;
|
frontend/src/Components/AiComponents/Sources.js
CHANGED
@@ -65,7 +65,7 @@ export default function Sources({ sources, handleSourceClick }) {
|
|
65 |
if (loading) {
|
66 |
return (
|
67 |
<Box className="sources-container">
|
68 |
-
<Typography variant="body2">Loading Sources...</Typography>
|
69 |
</Box>
|
70 |
);
|
71 |
}
|
|
|
65 |
if (loading) {
|
66 |
return (
|
67 |
<Box className="sources-container">
|
68 |
+
<Typography className="loading-sources" variant="body2">Loading Sources...</Typography>
|
69 |
</Box>
|
70 |
);
|
71 |
}
|
main.py
CHANGED
@@ -445,6 +445,10 @@ async def process_query(user_query: str, sse_queue: asyncio.Queue):
|
|
445 |
|
446 |
contents = data["contents"]
|
447 |
current_search_contents.extend(contents)
|
|
|
|
|
|
|
|
|
448 |
|
449 |
state['graph_rag'].set_on_event_callback(on_event_callback)
|
450 |
|
|
|
445 |
|
446 |
contents = data["contents"]
|
447 |
current_search_contents.extend(contents)
|
448 |
+
|
449 |
+
elif event_type == "search_process_completed":
|
450 |
+
await sse_queue.put(("step", "Processing final graph tasks..."))
|
451 |
+
await asyncio.sleep(0.01) # Sleep for a short time to allow the message to be sent
|
452 |
|
453 |
state['graph_rag'].set_on_event_callback(on_event_callback)
|
454 |
|
requirements.txt
CHANGED
@@ -18,7 +18,8 @@ langchain_xai==0.1.1
|
|
18 |
langgraph==0.2.62
|
19 |
model2vec==0.3.3
|
20 |
neo4j==5.26.0
|
21 |
-
rustworkx
|
|
|
22 |
openai==1.59.3
|
23 |
protobuf==4.23.4
|
24 |
PyPDF2==3.0.1
|
@@ -32,4 +33,4 @@ transformers==4.46.2
|
|
32 |
xformers==0.0.29.post1
|
33 |
# Intall the following seperately:-
|
34 |
# latest torch cuda version compatible with xformers' version
|
35 |
-
# python -m spacy download en_core_web_sm
|
|
|
18 |
langgraph==0.2.62
|
19 |
model2vec==0.3.3
|
20 |
neo4j==5.26.0
|
21 |
+
rustworkx===0.16.0
|
22 |
+
rank_bm25==0.2.2
|
23 |
openai==1.59.3
|
24 |
protobuf==4.23.4
|
25 |
PyPDF2==3.0.1
|
|
|
33 |
xformers==0.0.29.post1
|
34 |
# Intall the following seperately:-
|
35 |
# latest torch cuda version compatible with xformers' version
|
36 |
+
# python -m spacy download en_core_web_sm
|
src/rag/graph_rag.py
CHANGED
@@ -17,6 +17,9 @@ from src.search.search_engine import SearchEngine
|
|
17 |
from src.crawl.crawler import CustomCrawler #, Crawler
|
18 |
from sentence_transformers import SentenceTransformer
|
19 |
from bert_score.scorer import BERTScorer
|
|
|
|
|
|
|
20 |
|
21 |
class GraphRAG:
|
22 |
def __init__(self, num_workers: int = 1):
|
@@ -54,6 +57,9 @@ class GraphRAG:
|
|
54 |
self.sub_node_counter = 0
|
55 |
self.cross_connections = set()
|
56 |
|
|
|
|
|
|
|
57 |
# Thread pool
|
58 |
self.executor = ThreadPoolExecutor(max_workers=self.num_workers)
|
59 |
|
@@ -379,10 +385,13 @@ class GraphRAG:
|
|
379 |
)
|
380 |
|
381 |
async def build_graph(self, query: str, data: str = None, parent_node_id: str = None,
|
382 |
-
|
383 |
-
|
|
|
|
|
384 |
"""Build a new graph structure in memory."""
|
385 |
-
async def process_node(node_id: str, sub_query: str, session_id: str,
|
|
|
386 |
try:
|
387 |
optimized_query = await self.search_engine.generate_optimized_query(sub_query)
|
388 |
results = await self.search_engine.search(
|
@@ -427,14 +436,7 @@ class GraphRAG:
|
|
427 |
elif content:
|
428 |
contents += f"Document {k}:\n{content}\n\n"
|
429 |
|
430 |
-
if contents.strip():
|
431 |
-
if depth == 0:
|
432 |
-
await self.emit_event("sub_query_processed", {
|
433 |
-
"node_id": node_id,
|
434 |
-
"sub_query": sub_query,
|
435 |
-
"contents": contents
|
436 |
-
})
|
437 |
-
|
438 |
token_count = self.llm.get_num_tokens(contents)
|
439 |
if token_count > max_tokens_allowed:
|
440 |
contents = await self.chunking.chunker(
|
@@ -444,13 +446,6 @@ class GraphRAG:
|
|
444 |
)
|
445 |
print(f"Number of tokens in the answer: {token_count}")
|
446 |
print(f"Number of tokens in the content: {self.llm.get_num_tokens(contents)}")
|
447 |
-
else:
|
448 |
-
if depth == 0:
|
449 |
-
await self.emit_event("sub_query_failed", {
|
450 |
-
"node_id": node_id,
|
451 |
-
"sub_query": sub_query,
|
452 |
-
"contents": contents
|
453 |
-
})
|
454 |
|
455 |
graph_data = self._get_current_graph_data()
|
456 |
graph = graph_data["graph"]
|
@@ -460,13 +455,15 @@ class GraphRAG:
|
|
460 |
idx = node_map[node_id]
|
461 |
node_data = graph.get_node_data(idx)
|
462 |
node_data["data"] = contents
|
463 |
-
future.
|
|
|
464 |
except Exception as e:
|
465 |
print(f"Error processing node {node_id}: {str(e)}")
|
466 |
-
future.
|
467 |
-
|
|
|
468 |
|
469 |
-
async def process_dependent_node(node_id: str, sub_query: str,
|
470 |
try:
|
471 |
dep_data = [await f for f in dep_futures]
|
472 |
modified_query = await self.query_processor.modify_query(
|
@@ -490,16 +487,16 @@ class GraphRAG:
|
|
490 |
node_data["embedding"] = embedding.tolist() if hasattr(embedding, "tolist") else embedding
|
491 |
try:
|
492 |
if not future.done():
|
493 |
-
await process_node(node_id, modified_query, session_id, future,
|
494 |
except Exception as e:
|
495 |
if not future.done():
|
496 |
future.set_exception(e)
|
497 |
-
raise
|
498 |
except Exception as e:
|
499 |
print(f"Error processing dependent node {node_id}: {str(e)}")
|
500 |
if not future.done():
|
501 |
future.set_exception(e)
|
502 |
-
raise
|
503 |
|
504 |
def create_cross_connections():
|
505 |
try:
|
@@ -548,7 +545,15 @@ class GraphRAG:
|
|
548 |
|
549 |
if context is None:
|
550 |
context = []
|
551 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
552 |
|
553 |
if parent_node_id is None:
|
554 |
self.add_node(self.root_node_id, query, data)
|
@@ -565,9 +570,6 @@ class GraphRAG:
|
|
565 |
context.append(response_data)
|
566 |
|
567 |
if len(sub_queries) > 1 and sub_queries[0] != query:
|
568 |
-
sub_query_ids = []
|
569 |
-
pre_req_nodes = {}
|
570 |
-
|
571 |
for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)):
|
572 |
if depth == 0:
|
573 |
await self.emit_event("sub_query_created", {
|
@@ -589,13 +591,13 @@ class GraphRAG:
|
|
589 |
self.add_node(sub_node_id, sub_query, role=role)
|
590 |
future = asyncio.Future()
|
591 |
node_data_futures[sub_node_id] = future
|
|
|
592 |
|
593 |
if role.lower() in ['pre-requisite', 'prerequisite']:
|
594 |
pre_req_nodes[idx] = sub_node_id
|
595 |
|
596 |
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
|
597 |
self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical')
|
598 |
-
|
599 |
elif role.lower() == 'dependent':
|
600 |
if isinstance(dependency, list) and (len(dependency) == 2 and all(isinstance(d, list) for d in dependency)):
|
601 |
print(f"Dependency: {dependency}")
|
@@ -618,12 +620,11 @@ class GraphRAG:
|
|
618 |
|
619 |
if current_deps not in [None, []]:
|
620 |
for dep_idx in current_deps:
|
621 |
-
if dep_idx < len(
|
622 |
dep_node_id = sub_query_ids[dep_idx]
|
623 |
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
|
624 |
else:
|
625 |
raise ValueError(f"Invalid dependency index: {dep_idx}")
|
626 |
-
|
627 |
elif len(dependency) > 0:
|
628 |
for dep_idx in dependency:
|
629 |
if dep_idx < len(sub_queries):
|
@@ -635,66 +636,6 @@ class GraphRAG:
|
|
635 |
raise ValueError(f"Invalid dependency: {dependency}")
|
636 |
else:
|
637 |
raise ValueError(f"Unexpected role: {role}")
|
638 |
-
|
639 |
-
tasks = []
|
640 |
-
for idx in range(len(sub_queries)):
|
641 |
-
node_id = sub_query_ids[idx]
|
642 |
-
future = node_data_futures[node_id]
|
643 |
-
|
644 |
-
if roles[idx].lower() in ('pre-requisite', 'prerequisite', 'independent'):
|
645 |
-
tasks.append(process_node(node_id, sub_queries[idx], session_id, future, depth, max_tokens_allowed))
|
646 |
-
|
647 |
-
for idx in range(len(sub_queries)):
|
648 |
-
node_id = sub_query_ids[idx]
|
649 |
-
future = node_data_futures[node_id]
|
650 |
-
|
651 |
-
if roles[idx].lower() == 'dependent':
|
652 |
-
dep_futures = []
|
653 |
-
|
654 |
-
if isinstance(dependencies[idx], list) and len(dependencies[idx]) == 2:
|
655 |
-
prev_deps, current_deps = dependencies[idx]
|
656 |
-
if context and prev_deps not in [None, []]:
|
657 |
-
for context_idx, context_data in enumerate(context):
|
658 |
-
if isinstance(prev_deps, list) and context_idx < len(prev_deps):
|
659 |
-
context_dep = prev_deps[context_idx]
|
660 |
-
if context_dep is not None and isinstance(context_data, dict) and 'subqueries' in context_data:
|
661 |
-
if context_dep < len(context_data['subqueries']):
|
662 |
-
dep_query = context_data['subqueries'][context_dep]['subquery']
|
663 |
-
matching_nodes = self.find_nodes_by_properties(query=dep_query)
|
664 |
-
if matching_nodes not in [None, []]:
|
665 |
-
dep_node_id = matching_nodes[0].get('node_id', None)
|
666 |
-
score = float(matching_nodes[0].get('score', 0))
|
667 |
-
if score == 1.0 and dep_node_id in node_data_futures:
|
668 |
-
dep_futures.append(node_data_futures[dep_node_id])
|
669 |
-
|
670 |
-
elif isinstance(prev_deps, int):
|
671 |
-
if prev_deps < len(context_data['subqueries']):
|
672 |
-
dep_query = context_data['subqueries'][prev_deps]['subquery']
|
673 |
-
matching_nodes = self.find_nodes_by_properties(query=dep_query)
|
674 |
-
if matching_nodes not in [None, []]:
|
675 |
-
dep_node_id = matching_nodes[0].get('node_id', None)
|
676 |
-
score = matching_nodes[0].get('score', 0)
|
677 |
-
if score == 1.0 and dep_node_id in node_data_futures:
|
678 |
-
dep_futures.append(node_data_futures[dep_node_id])
|
679 |
-
|
680 |
-
if current_deps not in [None, []]:
|
681 |
-
current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
|
682 |
-
for dep_idx in current_deps_list:
|
683 |
-
if dep_idx < len(sub_queries):
|
684 |
-
dep_node_id = sub_query_ids[dep_idx]
|
685 |
-
if dep_node_id in node_data_futures:
|
686 |
-
dep_futures.append(node_data_futures[dep_node_id])
|
687 |
-
|
688 |
-
tasks.append(process_dependent_node(node_id, sub_queries[idx], depth, dep_futures, future))
|
689 |
-
|
690 |
-
if depth == 0:
|
691 |
-
await self.emit_event("search_process_started", {
|
692 |
-
"depth": depth,
|
693 |
-
"sub_queries": sub_queries,
|
694 |
-
"roles": roles
|
695 |
-
})
|
696 |
-
|
697 |
-
await asyncio.gather(*tasks)
|
698 |
|
699 |
if recurse:
|
700 |
recursion_tasks = []
|
@@ -710,7 +651,11 @@ class GraphRAG:
|
|
710 |
threshold=threshold,
|
711 |
recurse=recurse,
|
712 |
context=context,
|
713 |
-
session_id=session_id
|
|
|
|
|
|
|
|
|
714 |
)
|
715 |
)
|
716 |
except Exception as e:
|
@@ -721,9 +666,154 @@ class GraphRAG:
|
|
721 |
try:
|
722 |
await asyncio.gather(*recursion_tasks)
|
723 |
except Exception as e:
|
724 |
-
|
725 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
726 |
if depth == 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
727 |
print("Graph building complete, processing final tasks...")
|
728 |
await self.emit_event("search_process_completed", {
|
729 |
"depth": depth,
|
@@ -916,6 +1006,8 @@ class GraphRAG:
|
|
916 |
self.update_pagerank()
|
917 |
self.verify_graph_integrity()
|
918 |
self.verify_graph_consistency()
|
|
|
|
|
919 |
except Exception as e:
|
920 |
print(f"Error in process_graph: {str(e)}")
|
921 |
raise
|
@@ -1183,7 +1275,7 @@ class GraphRAG:
|
|
1183 |
graph_data = self._get_current_graph_data()
|
1184 |
graph = graph_data["graph"]
|
1185 |
node_map = graph_data["node_map"]
|
1186 |
-
net = Network(height="
|
1187 |
net.options = {"physics": {"enabled": False}}
|
1188 |
all_nodes = set()
|
1189 |
all_edges = []
|
@@ -1220,17 +1312,11 @@ class GraphRAG:
|
|
1220 |
net.options["layout"] = {"improvedLayout": True}
|
1221 |
net.options["interaction"] = {"dragNodes": True}
|
1222 |
|
1223 |
-
original_dir = os.getcwd()
|
1224 |
-
os.chdir(os.getenv("WRITABLE_DIR", "/tmp"))
|
1225 |
-
|
1226 |
net.save_graph("temp_graph.html")
|
1227 |
|
1228 |
with open("temp_graph.html", "r", encoding="utf-8") as f:
|
1229 |
html_str = f.read()
|
1230 |
-
|
1231 |
os.remove("temp_graph.html")
|
1232 |
-
os.chdir(original_dir)
|
1233 |
-
|
1234 |
return html_str
|
1235 |
|
1236 |
def verify_graph_integrity(self):
|
|
|
17 |
from src.crawl.crawler import CustomCrawler #, Crawler
|
18 |
from sentence_transformers import SentenceTransformer
|
19 |
from bert_score.scorer import BERTScorer
|
20 |
+
from openai import RateLimitError
|
21 |
+
from anthropic import RateLimitError as AnthropicRateLimitError
|
22 |
+
from google.api_core.exceptions import ResourceExhausted
|
23 |
|
24 |
class GraphRAG:
|
25 |
def __init__(self, num_workers: int = 1):
|
|
|
57 |
self.sub_node_counter = 0
|
58 |
self.cross_connections = set()
|
59 |
|
60 |
+
# Semaphore protection
|
61 |
+
self.semaphore = asyncio.Semaphore(min(num_workers * 2, 12))
|
62 |
+
|
63 |
# Thread pool
|
64 |
self.executor = ThreadPoolExecutor(max_workers=self.num_workers)
|
65 |
|
|
|
385 |
)
|
386 |
|
387 |
async def build_graph(self, query: str, data: str = None, parent_node_id: str = None,
|
388 |
+
depth: int = 0, threshold: float = 0.8, recurse: bool = True,
|
389 |
+
context: list = None, session_id: str = None, max_tokens_allowed: int = 128000,
|
390 |
+
node_data_futures: dict = None, sub_nodes_info: list = None,
|
391 |
+
sub_query_ids: list = None, pre_req_nodes: list = None):
|
392 |
"""Build a new graph structure in memory."""
|
393 |
+
async def process_node(node_id: str, sub_query: str, session_id: str,
|
394 |
+
future: asyncio.Future, max_tokens_allowed: int = max_tokens_allowed):
|
395 |
try:
|
396 |
optimized_query = await self.search_engine.generate_optimized_query(sub_query)
|
397 |
results = await self.search_engine.search(
|
|
|
436 |
elif content:
|
437 |
contents += f"Document {k}:\n{content}\n\n"
|
438 |
|
439 |
+
if contents.strip():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
440 |
token_count = self.llm.get_num_tokens(contents)
|
441 |
if token_count > max_tokens_allowed:
|
442 |
contents = await self.chunking.chunker(
|
|
|
446 |
)
|
447 |
print(f"Number of tokens in the answer: {token_count}")
|
448 |
print(f"Number of tokens in the content: {self.llm.get_num_tokens(contents)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
449 |
|
450 |
graph_data = self._get_current_graph_data()
|
451 |
graph = graph_data["graph"]
|
|
|
455 |
idx = node_map[node_id]
|
456 |
node_data = graph.get_node_data(idx)
|
457 |
node_data["data"] = contents
|
458 |
+
if not future.done():
|
459 |
+
future.set_result(contents)
|
460 |
except Exception as e:
|
461 |
print(f"Error processing node {node_id}: {str(e)}")
|
462 |
+
if not future.done():
|
463 |
+
future.set_exception(e)
|
464 |
+
raise e
|
465 |
|
466 |
+
async def process_dependent_node(node_id: str, sub_query: str, dep_futures: list, future):
|
467 |
try:
|
468 |
dep_data = [await f for f in dep_futures]
|
469 |
modified_query = await self.query_processor.modify_query(
|
|
|
487 |
node_data["embedding"] = embedding.tolist() if hasattr(embedding, "tolist") else embedding
|
488 |
try:
|
489 |
if not future.done():
|
490 |
+
await process_node(node_id, modified_query, session_id, future, max_tokens_allowed)
|
491 |
except Exception as e:
|
492 |
if not future.done():
|
493 |
future.set_exception(e)
|
494 |
+
raise e
|
495 |
except Exception as e:
|
496 |
print(f"Error processing dependent node {node_id}: {str(e)}")
|
497 |
if not future.done():
|
498 |
future.set_exception(e)
|
499 |
+
raise e
|
500 |
|
501 |
def create_cross_connections():
|
502 |
try:
|
|
|
545 |
|
546 |
if context is None:
|
547 |
context = []
|
548 |
+
|
549 |
+
if node_data_futures is None:
|
550 |
+
node_data_futures = {}
|
551 |
+
if sub_nodes_info is None:
|
552 |
+
sub_nodes_info = []
|
553 |
+
if sub_query_ids is None:
|
554 |
+
sub_query_ids = []
|
555 |
+
if pre_req_nodes is None:
|
556 |
+
pre_req_nodes = {}
|
557 |
|
558 |
if parent_node_id is None:
|
559 |
self.add_node(self.root_node_id, query, data)
|
|
|
570 |
context.append(response_data)
|
571 |
|
572 |
if len(sub_queries) > 1 and sub_queries[0] != query:
|
|
|
|
|
|
|
573 |
for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)):
|
574 |
if depth == 0:
|
575 |
await self.emit_event("sub_query_created", {
|
|
|
591 |
self.add_node(sub_node_id, sub_query, role=role)
|
592 |
future = asyncio.Future()
|
593 |
node_data_futures[sub_node_id] = future
|
594 |
+
sub_nodes_info.append((sub_node_id, sub_query, role, dependency, future, depth))
|
595 |
|
596 |
if role.lower() in ['pre-requisite', 'prerequisite']:
|
597 |
pre_req_nodes[idx] = sub_node_id
|
598 |
|
599 |
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
|
600 |
self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical')
|
|
|
601 |
elif role.lower() == 'dependent':
|
602 |
if isinstance(dependency, list) and (len(dependency) == 2 and all(isinstance(d, list) for d in dependency)):
|
603 |
print(f"Dependency: {dependency}")
|
|
|
620 |
|
621 |
if current_deps not in [None, []]:
|
622 |
for dep_idx in current_deps:
|
623 |
+
if dep_idx < len(sub_query_ids):
|
624 |
dep_node_id = sub_query_ids[dep_idx]
|
625 |
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
|
626 |
else:
|
627 |
raise ValueError(f"Invalid dependency index: {dep_idx}")
|
|
|
628 |
elif len(dependency) > 0:
|
629 |
for dep_idx in dependency:
|
630 |
if dep_idx < len(sub_queries):
|
|
|
636 |
raise ValueError(f"Invalid dependency: {dependency}")
|
637 |
else:
|
638 |
raise ValueError(f"Unexpected role: {role}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
639 |
|
640 |
if recurse:
|
641 |
recursion_tasks = []
|
|
|
651 |
threshold=threshold,
|
652 |
recurse=recurse,
|
653 |
context=context,
|
654 |
+
session_id=session_id,
|
655 |
+
node_data_futures=node_data_futures,
|
656 |
+
sub_nodes_info=sub_nodes_info,
|
657 |
+
sub_query_ids=sub_query_ids,
|
658 |
+
pre_req_nodes=pre_req_nodes
|
659 |
)
|
660 |
)
|
661 |
except Exception as e:
|
|
|
666 |
try:
|
667 |
await asyncio.gather(*recursion_tasks)
|
668 |
except Exception as e:
|
669 |
+
print(f"Error during recursive processing: {e}")
|
670 |
+
raise e
|
671 |
+
|
672 |
+
futures = {}
|
673 |
+
all_child_futures = {}
|
674 |
+
process_tasks = []
|
675 |
+
graph_data = self._get_current_graph_data()
|
676 |
+
graph = graph_data["graph"]
|
677 |
+
node_map = graph_data["node_map"]
|
678 |
+
|
679 |
+
for (sub_node_id, sub_query, role, dependency, future, local_depth) in sub_nodes_info:
|
680 |
+
idx = node_map.get(sub_node_id)
|
681 |
+
has_children = False
|
682 |
+
child_futures = []
|
683 |
+
if idx is not None:
|
684 |
+
for (_, child_idx, edge_data) in graph.out_edges(idx):
|
685 |
+
if edge_data.get("type") == "hierarchical":
|
686 |
+
has_children = True
|
687 |
+
child_future = node_data_futures.get(graph.get_node_data(child_idx).get("id"))
|
688 |
+
if child_future:
|
689 |
+
child_futures.append(child_future)
|
690 |
+
if local_depth == 0:
|
691 |
+
futures[sub_query] = future
|
692 |
+
all_child_futures[sub_query] = child_futures
|
693 |
+
if has_children:
|
694 |
+
if not future.done():
|
695 |
+
future.set_result("")
|
696 |
+
else:
|
697 |
+
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
|
698 |
+
process_tasks.append(process_node(sub_node_id, sub_query, session_id, future, max_tokens_allowed))
|
699 |
+
elif role.lower() == 'dependent':
|
700 |
+
dep_futures = []
|
701 |
+
if isinstance(dependency, list) and len(dependency) == 2:
|
702 |
+
prev_deps, current_deps = dependency
|
703 |
+
if context and prev_deps not in [None, []]:
|
704 |
+
for context_idx, context_data in enumerate(context):
|
705 |
+
if isinstance(prev_deps, list) and context_idx < len(prev_deps):
|
706 |
+
context_dep = prev_deps[context_idx]
|
707 |
+
if (context_dep is not None and isinstance(context_data, dict)
|
708 |
+
and 'subqueries' in context_data):
|
709 |
+
if context_dep < len(context_data['subqueries']):
|
710 |
+
dep_query = context_data['subqueries'][context_dep]['subquery']
|
711 |
+
matching_nodes = self.find_nodes_by_properties(query=dep_query)
|
712 |
+
if matching_nodes not in [None, []]:
|
713 |
+
dep_node_id = matching_nodes[0].get('node_id', None)
|
714 |
+
score = float(matching_nodes[0].get('score', 0))
|
715 |
+
if score == 1.0 and dep_node_id in node_data_futures:
|
716 |
+
dep_futures.append(node_data_futures[dep_node_id])
|
717 |
+
elif isinstance(prev_deps, int):
|
718 |
+
if context_idx < len(context_data['subqueries']):
|
719 |
+
dep_query = context_data['subqueries'][prev_deps]['subquery']
|
720 |
+
matching_nodes = self.find_nodes_by_properties(query=dep_query)
|
721 |
+
if matching_nodes not in [None, []]:
|
722 |
+
dep_node_id = matching_nodes[0].get('node_id', None)
|
723 |
+
score = matching_nodes[0].get('score', 0)
|
724 |
+
if score == 1.0 and dep_node_id in node_data_futures:
|
725 |
+
dep_futures.append(node_data_futures[dep_node_id])
|
726 |
+
if current_deps not in [None, []]:
|
727 |
+
current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
|
728 |
+
for dep_idx in current_deps_list:
|
729 |
+
if dep_idx < len(sub_query_ids):
|
730 |
+
dep_node_id = sub_query_ids[dep_idx]
|
731 |
+
if dep_node_id in node_data_futures:
|
732 |
+
dep_futures.append(node_data_futures[dep_node_id])
|
733 |
+
process_tasks.append(process_dependent_node(sub_node_id, sub_query, dep_futures, future))
|
734 |
+
else:
|
735 |
+
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
|
736 |
+
process_tasks.append(process_node(sub_node_id, sub_query, session_id, future, max_tokens_allowed))
|
737 |
+
elif role.lower() == 'dependent':
|
738 |
+
dep_futures = []
|
739 |
+
if isinstance(dependency, list) and len(dependency) == 2:
|
740 |
+
prev_deps, current_deps = dependency
|
741 |
+
if context and prev_deps not in [None, []]:
|
742 |
+
for context_idx, context_data in enumerate(context):
|
743 |
+
if isinstance(prev_deps, list) and context_idx < len(prev_deps):
|
744 |
+
context_dep = prev_deps[context_idx]
|
745 |
+
if (context_dep is not None and isinstance(context_data, dict)
|
746 |
+
and 'subqueries' in context_data):
|
747 |
+
if context_dep < len(context_data['subqueries']):
|
748 |
+
dep_query = context_data['subqueries'][context_dep]['subquery']
|
749 |
+
matching_nodes = self.find_nodes_by_properties(query=dep_query)
|
750 |
+
if matching_nodes not in [None, []]:
|
751 |
+
dep_node_id = matching_nodes[0].get('node_id', None)
|
752 |
+
score = float(matching_nodes[0].get('score', 0))
|
753 |
+
if score == 1.0 and dep_node_id in node_data_futures:
|
754 |
+
dep_futures.append(node_data_futures[dep_node_id])
|
755 |
+
elif isinstance(prev_deps, int):
|
756 |
+
if context_idx < len(context_data['subqueries']):
|
757 |
+
dep_query = context_data['subqueries'][prev_deps]['subquery']
|
758 |
+
matching_nodes = self.find_nodes_by_properties(query=dep_query)
|
759 |
+
if matching_nodes not in [None, []]:
|
760 |
+
dep_node_id = matching_nodes[0].get('node_id', None)
|
761 |
+
score = matching_nodes[0].get('score', 0)
|
762 |
+
if score == 1.0 and dep_node_id in node_data_futures:
|
763 |
+
dep_futures.append(node_data_futures[dep_node_id])
|
764 |
+
if current_deps not in [None, []]:
|
765 |
+
current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
|
766 |
+
for dep_idx in current_deps_list:
|
767 |
+
if dep_idx < len(sub_query_ids):
|
768 |
+
dep_node_id = sub_query_ids[dep_idx]
|
769 |
+
if dep_node_id in node_data_futures:
|
770 |
+
dep_futures.append(node_data_futures[dep_node_id])
|
771 |
+
process_tasks.append(process_dependent_node(sub_node_id, sub_query, dep_futures, future))
|
772 |
+
|
773 |
+
if process_tasks:
|
774 |
+
await self.emit_event("search_process_started", {
|
775 |
+
"depth": depth,
|
776 |
+
"sub_queries": sub_queries,
|
777 |
+
"roles": roles
|
778 |
+
})
|
779 |
+
|
780 |
+
for sub_query, future in futures.items():
|
781 |
+
try:
|
782 |
+
parent_content = future.result().strip()
|
783 |
+
except:
|
784 |
+
parent_content = ""
|
785 |
+
|
786 |
+
child_futures = all_child_futures.get(sub_query)
|
787 |
+
any_child_done = any(cf.done() and cf.result().strip() for cf in child_futures)
|
788 |
+
|
789 |
+
if parent_content or any_child_done:
|
790 |
+
await self.emit_event("sub_query_processed", {"sub_query": sub_query})
|
791 |
+
|
792 |
+
await asyncio.gather(*process_tasks)
|
793 |
+
|
794 |
if depth == 0:
|
795 |
+
for sub_query, future in futures.items():
|
796 |
+
try:
|
797 |
+
parent_content = future.result().strip()
|
798 |
+
except:
|
799 |
+
parent_content = ""
|
800 |
+
|
801 |
+
child_futures = all_child_futures.get(sub_query)
|
802 |
+
no_child_done = not any(cf.done() and cf.result().strip() for cf in child_futures)
|
803 |
+
|
804 |
+
if no_child_done:
|
805 |
+
await self.emit_event("sub_query_failed", {"sub_query": sub_query})
|
806 |
+
|
807 |
+
for idx, (sub_query, future) in enumerate(futures.items(), 1):
|
808 |
+
if future.done() and future.result().strip():
|
809 |
+
print(f"Sub-query {idx} processed successfully")
|
810 |
+
else:
|
811 |
+
child_futures = all_child_futures.get(sub_query)
|
812 |
+
if any(cf.done() and cf.result().strip() for cf in child_futures):
|
813 |
+
print(f"Sub-query {idx} processed successfully because of child nodes")
|
814 |
+
else:
|
815 |
+
print(f"Sub-query {idx} failed to process because of child nodes")
|
816 |
+
|
817 |
print("Graph building complete, processing final tasks...")
|
818 |
await self.emit_event("search_process_completed", {
|
819 |
"depth": depth,
|
|
|
1006 |
self.update_pagerank()
|
1007 |
self.verify_graph_integrity()
|
1008 |
self.verify_graph_consistency()
|
1009 |
+
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError):
|
1010 |
+
pass
|
1011 |
except Exception as e:
|
1012 |
print(f"Error in process_graph: {str(e)}")
|
1013 |
raise
|
|
|
1275 |
graph_data = self._get_current_graph_data()
|
1276 |
graph = graph_data["graph"]
|
1277 |
node_map = graph_data["node_map"]
|
1278 |
+
net = Network(height="530px", width="100%", directed=True, bgcolor="#222222", font_color="white")
|
1279 |
net.options = {"physics": {"enabled": False}}
|
1280 |
all_nodes = set()
|
1281 |
all_edges = []
|
|
|
1312 |
net.options["layout"] = {"improvedLayout": True}
|
1313 |
net.options["interaction"] = {"dragNodes": True}
|
1314 |
|
|
|
|
|
|
|
1315 |
net.save_graph("temp_graph.html")
|
1316 |
|
1317 |
with open("temp_graph.html", "r", encoding="utf-8") as f:
|
1318 |
html_str = f.read()
|
|
|
1319 |
os.remove("temp_graph.html")
|
|
|
|
|
1320 |
return html_str
|
1321 |
|
1322 |
def verify_graph_integrity(self):
|