Hemang Thakur committed on
Commit
c3d9a20
·
1 Parent(s): c8abe84

made changes to graph rag file and some ui elements

Browse files
frontend/src/Components/AiComponents/ChatWindow.css CHANGED
@@ -127,6 +127,7 @@
127
  .markdown {
128
  margin: -1rem 0 -0.8rem 0;
129
  line-height: 2rem;
 
130
  }
131
 
132
  /* Post-answer icons container: placed below the bot bubble */
@@ -275,4 +276,4 @@
275
  100% {
276
  transform: rotate(360deg);
277
  }
278
- }
 
127
  .markdown {
128
  margin: -1rem 0 -0.8rem 0;
129
  line-height: 2rem;
130
+ white-space: pre-wrap;
131
  }
132
 
133
  /* Post-answer icons container: placed below the bot bubble */
 
276
  100% {
277
  transform: rotate(360deg);
278
  }
279
+ }
frontend/src/Components/AiComponents/ChatWindow.js CHANGED
@@ -168,10 +168,7 @@ function ChatWindow({
168
  )}
169
  {renderSourcesRead() !== null && (
170
  <div className="sources-read-container">
171
- <p
172
- className="sources-read"
173
- onClick={() => openRightSidebar("sources", blockId)}
174
- >
175
  Sources Read: {renderSourcesRead()}
176
  </p>
177
  </div>
 
168
  )}
169
  {renderSourcesRead() !== null && (
170
  <div className="sources-read-container">
171
+ <p className="sources-read">
 
 
 
172
  Sources Read: {renderSourcesRead()}
173
  </p>
174
  </div>
frontend/src/Components/AiComponents/Graph.css CHANGED
@@ -21,7 +21,7 @@
21
  width: 45% !important;
22
  max-width: 100% !important;
23
  background-color: #1e1e1e !important;
24
- max-height: 80vh !important;
25
  overflow: hidden !important; /* Prevent scrolling */
26
  }
27
 
@@ -62,7 +62,7 @@
62
  .graph-dialog-content {
63
  padding: 0 !important;
64
  background-color: #1e1e1e !important;
65
- height: 750px !important;
66
  overflow: hidden !important;
67
  }
68
 
 
21
  width: 45% !important;
22
  max-width: 100% !important;
23
  background-color: #1e1e1e !important;
24
+ min-height: 80vh !important;
25
  overflow: hidden !important; /* Prevent scrolling */
26
  }
27
 
 
62
  .graph-dialog-content {
63
  padding: 0 !important;
64
  background-color: #1e1e1e !important;
65
+ height: 550px !important;
66
  overflow: hidden !important;
67
  }
68
 
frontend/src/Components/AiComponents/Sources.css CHANGED
@@ -6,6 +6,10 @@
6
  padding: 0 !important;
7
  }
8
 
 
 
 
 
9
  /* Styling for each Card component */
10
  .source-card {
11
  background-color: #3e3e3eec !important;
 
6
  padding: 0 !important;
7
  }
8
 
9
+ .loading-sources {
10
+ padding: 1rem !important;
11
+ }
12
+
13
  /* Styling for each Card component */
14
  .source-card {
15
  background-color: #3e3e3eec !important;
frontend/src/Components/AiComponents/Sources.js CHANGED
@@ -65,7 +65,7 @@ export default function Sources({ sources, handleSourceClick }) {
65
  if (loading) {
66
  return (
67
  <Box className="sources-container">
68
- <Typography variant="body2">Loading Sources...</Typography>
69
  </Box>
70
  );
71
  }
 
65
  if (loading) {
66
  return (
67
  <Box className="sources-container">
68
+ <Typography className="loading-sources" variant="body2">Loading Sources...</Typography>
69
  </Box>
70
  );
71
  }
main.py CHANGED
@@ -445,6 +445,10 @@ async def process_query(user_query: str, sse_queue: asyncio.Queue):
445
 
446
  contents = data["contents"]
447
  current_search_contents.extend(contents)
 
 
 
 
448
 
449
  state['graph_rag'].set_on_event_callback(on_event_callback)
450
 
 
445
 
446
  contents = data["contents"]
447
  current_search_contents.extend(contents)
448
+
449
+ elif event_type == "search_process_completed":
450
+ await sse_queue.put(("step", "Processing final graph tasks..."))
451
+ await asyncio.sleep(0.01) # Sleep for a short time to allow the message to be sent
452
 
453
  state['graph_rag'].set_on_event_callback(on_event_callback)
454
 
requirements.txt CHANGED
@@ -18,7 +18,8 @@ langchain_xai==0.1.1
18
  langgraph==0.2.62
19
  model2vec==0.3.3
20
  neo4j==5.26.0
21
- rustworkx==0.16.0
 
22
  openai==1.59.3
23
  protobuf==4.23.4
24
  PyPDF2==3.0.1
@@ -32,4 +33,4 @@ transformers==4.46.2
32
  xformers==0.0.29.post1
33
  # Install the following separately:-
34
  # latest torch cuda version compatible with xformers' version
35
- # python -m spacy download en_core_web_sm
 
18
  langgraph==0.2.62
19
  model2vec==0.3.3
20
  neo4j==5.26.0
21
+ rustworkx===0.16.0
22
+ rank_bm25==0.2.2
23
  openai==1.59.3
24
  protobuf==4.23.4
25
  PyPDF2==3.0.1
 
33
  xformers==0.0.29.post1
34
  # Install the following separately:-
35
  # latest torch cuda version compatible with xformers' version
36
+ # python -m spacy download en_core_web_sm
src/rag/graph_rag.py CHANGED
@@ -17,6 +17,9 @@ from src.search.search_engine import SearchEngine
17
  from src.crawl.crawler import CustomCrawler #, Crawler
18
  from sentence_transformers import SentenceTransformer
19
  from bert_score.scorer import BERTScorer
 
 
 
20
 
21
  class GraphRAG:
22
  def __init__(self, num_workers: int = 1):
@@ -54,6 +57,9 @@ class GraphRAG:
54
  self.sub_node_counter = 0
55
  self.cross_connections = set()
56
 
 
 
 
57
  # Thread pool
58
  self.executor = ThreadPoolExecutor(max_workers=self.num_workers)
59
 
@@ -379,10 +385,13 @@ class GraphRAG:
379
  )
380
 
381
  async def build_graph(self, query: str, data: str = None, parent_node_id: str = None,
382
- depth: int = 0, threshold: float = 0.8, recurse: bool = True,
383
- context: list = None, session_id: str = None, max_tokens_allowed: int = 128000):
 
 
384
  """Build a new graph structure in memory."""
385
- async def process_node(node_id: str, sub_query: str, session_id: str, future: asyncio.Future, depth=depth, max_tokens_allowed=max_tokens_allowed):
 
386
  try:
387
  optimized_query = await self.search_engine.generate_optimized_query(sub_query)
388
  results = await self.search_engine.search(
@@ -427,14 +436,7 @@ class GraphRAG:
427
  elif content:
428
  contents += f"Document {k}:\n{content}\n\n"
429
 
430
- if contents.strip():
431
- if depth == 0:
432
- await self.emit_event("sub_query_processed", {
433
- "node_id": node_id,
434
- "sub_query": sub_query,
435
- "contents": contents
436
- })
437
-
438
  token_count = self.llm.get_num_tokens(contents)
439
  if token_count > max_tokens_allowed:
440
  contents = await self.chunking.chunker(
@@ -444,13 +446,6 @@ class GraphRAG:
444
  )
445
  print(f"Number of tokens in the answer: {token_count}")
446
  print(f"Number of tokens in the content: {self.llm.get_num_tokens(contents)}")
447
- else:
448
- if depth == 0:
449
- await self.emit_event("sub_query_failed", {
450
- "node_id": node_id,
451
- "sub_query": sub_query,
452
- "contents": contents
453
- })
454
 
455
  graph_data = self._get_current_graph_data()
456
  graph = graph_data["graph"]
@@ -460,13 +455,15 @@ class GraphRAG:
460
  idx = node_map[node_id]
461
  node_data = graph.get_node_data(idx)
462
  node_data["data"] = contents
463
- future.set_result(contents)
 
464
  except Exception as e:
465
  print(f"Error processing node {node_id}: {str(e)}")
466
- future.set_exception(e)
467
- raise
 
468
 
469
- async def process_dependent_node(node_id: str, sub_query: str, depth, dep_futures: list, future):
470
  try:
471
  dep_data = [await f for f in dep_futures]
472
  modified_query = await self.query_processor.modify_query(
@@ -490,16 +487,16 @@ class GraphRAG:
490
  node_data["embedding"] = embedding.tolist() if hasattr(embedding, "tolist") else embedding
491
  try:
492
  if not future.done():
493
- await process_node(node_id, modified_query, session_id, future, depth, max_tokens_allowed)
494
  except Exception as e:
495
  if not future.done():
496
  future.set_exception(e)
497
- raise
498
  except Exception as e:
499
  print(f"Error processing dependent node {node_id}: {str(e)}")
500
  if not future.done():
501
  future.set_exception(e)
502
- raise
503
 
504
  def create_cross_connections():
505
  try:
@@ -548,7 +545,15 @@ class GraphRAG:
548
 
549
  if context is None:
550
  context = []
551
- node_data_futures = {}
 
 
 
 
 
 
 
 
552
 
553
  if parent_node_id is None:
554
  self.add_node(self.root_node_id, query, data)
@@ -565,9 +570,6 @@ class GraphRAG:
565
  context.append(response_data)
566
 
567
  if len(sub_queries) > 1 and sub_queries[0] != query:
568
- sub_query_ids = []
569
- pre_req_nodes = {}
570
-
571
  for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)):
572
  if depth == 0:
573
  await self.emit_event("sub_query_created", {
@@ -589,13 +591,13 @@ class GraphRAG:
589
  self.add_node(sub_node_id, sub_query, role=role)
590
  future = asyncio.Future()
591
  node_data_futures[sub_node_id] = future
 
592
 
593
  if role.lower() in ['pre-requisite', 'prerequisite']:
594
  pre_req_nodes[idx] = sub_node_id
595
 
596
  if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
597
  self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical')
598
-
599
  elif role.lower() == 'dependent':
600
  if isinstance(dependency, list) and (len(dependency) == 2 and all(isinstance(d, list) for d in dependency)):
601
  print(f"Dependency: {dependency}")
@@ -618,12 +620,11 @@ class GraphRAG:
618
 
619
  if current_deps not in [None, []]:
620
  for dep_idx in current_deps:
621
- if dep_idx < len(sub_queries):
622
  dep_node_id = sub_query_ids[dep_idx]
623
  self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
624
  else:
625
  raise ValueError(f"Invalid dependency index: {dep_idx}")
626
-
627
  elif len(dependency) > 0:
628
  for dep_idx in dependency:
629
  if dep_idx < len(sub_queries):
@@ -635,66 +636,6 @@ class GraphRAG:
635
  raise ValueError(f"Invalid dependency: {dependency}")
636
  else:
637
  raise ValueError(f"Unexpected role: {role}")
638
-
639
- tasks = []
640
- for idx in range(len(sub_queries)):
641
- node_id = sub_query_ids[idx]
642
- future = node_data_futures[node_id]
643
-
644
- if roles[idx].lower() in ('pre-requisite', 'prerequisite', 'independent'):
645
- tasks.append(process_node(node_id, sub_queries[idx], session_id, future, depth, max_tokens_allowed))
646
-
647
- for idx in range(len(sub_queries)):
648
- node_id = sub_query_ids[idx]
649
- future = node_data_futures[node_id]
650
-
651
- if roles[idx].lower() == 'dependent':
652
- dep_futures = []
653
-
654
- if isinstance(dependencies[idx], list) and len(dependencies[idx]) == 2:
655
- prev_deps, current_deps = dependencies[idx]
656
- if context and prev_deps not in [None, []]:
657
- for context_idx, context_data in enumerate(context):
658
- if isinstance(prev_deps, list) and context_idx < len(prev_deps):
659
- context_dep = prev_deps[context_idx]
660
- if context_dep is not None and isinstance(context_data, dict) and 'subqueries' in context_data:
661
- if context_dep < len(context_data['subqueries']):
662
- dep_query = context_data['subqueries'][context_dep]['subquery']
663
- matching_nodes = self.find_nodes_by_properties(query=dep_query)
664
- if matching_nodes not in [None, []]:
665
- dep_node_id = matching_nodes[0].get('node_id', None)
666
- score = float(matching_nodes[0].get('score', 0))
667
- if score == 1.0 and dep_node_id in node_data_futures:
668
- dep_futures.append(node_data_futures[dep_node_id])
669
-
670
- elif isinstance(prev_deps, int):
671
- if prev_deps < len(context_data['subqueries']):
672
- dep_query = context_data['subqueries'][prev_deps]['subquery']
673
- matching_nodes = self.find_nodes_by_properties(query=dep_query)
674
- if matching_nodes not in [None, []]:
675
- dep_node_id = matching_nodes[0].get('node_id', None)
676
- score = matching_nodes[0].get('score', 0)
677
- if score == 1.0 and dep_node_id in node_data_futures:
678
- dep_futures.append(node_data_futures[dep_node_id])
679
-
680
- if current_deps not in [None, []]:
681
- current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
682
- for dep_idx in current_deps_list:
683
- if dep_idx < len(sub_queries):
684
- dep_node_id = sub_query_ids[dep_idx]
685
- if dep_node_id in node_data_futures:
686
- dep_futures.append(node_data_futures[dep_node_id])
687
-
688
- tasks.append(process_dependent_node(node_id, sub_queries[idx], depth, dep_futures, future))
689
-
690
- if depth == 0:
691
- await self.emit_event("search_process_started", {
692
- "depth": depth,
693
- "sub_queries": sub_queries,
694
- "roles": roles
695
- })
696
-
697
- await asyncio.gather(*tasks)
698
 
699
  if recurse:
700
  recursion_tasks = []
@@ -710,7 +651,11 @@ class GraphRAG:
710
  threshold=threshold,
711
  recurse=recurse,
712
  context=context,
713
- session_id=session_id
 
 
 
 
714
  )
715
  )
716
  except Exception as e:
@@ -721,9 +666,154 @@ class GraphRAG:
721
  try:
722
  await asyncio.gather(*recursion_tasks)
723
  except Exception as e:
724
- raise Exception(f"Error during recursive processing: {e}")
725
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726
  if depth == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
727
  print("Graph building complete, processing final tasks...")
728
  await self.emit_event("search_process_completed", {
729
  "depth": depth,
@@ -916,6 +1006,8 @@ class GraphRAG:
916
  self.update_pagerank()
917
  self.verify_graph_integrity()
918
  self.verify_graph_consistency()
 
 
919
  except Exception as e:
920
  print(f"Error in process_graph: {str(e)}")
921
  raise
@@ -1183,7 +1275,7 @@ class GraphRAG:
1183
  graph_data = self._get_current_graph_data()
1184
  graph = graph_data["graph"]
1185
  node_map = graph_data["node_map"]
1186
- net = Network(height="465px", width="100%", directed=True, bgcolor="#222222", font_color="white")
1187
  net.options = {"physics": {"enabled": False}}
1188
  all_nodes = set()
1189
  all_edges = []
@@ -1220,17 +1312,11 @@ class GraphRAG:
1220
  net.options["layout"] = {"improvedLayout": True}
1221
  net.options["interaction"] = {"dragNodes": True}
1222
 
1223
- original_dir = os.getcwd()
1224
- os.chdir(os.getenv("WRITABLE_DIR", "/tmp"))
1225
-
1226
  net.save_graph("temp_graph.html")
1227
 
1228
  with open("temp_graph.html", "r", encoding="utf-8") as f:
1229
  html_str = f.read()
1230
-
1231
  os.remove("temp_graph.html")
1232
- os.chdir(original_dir)
1233
-
1234
  return html_str
1235
 
1236
  def verify_graph_integrity(self):
 
17
  from src.crawl.crawler import CustomCrawler #, Crawler
18
  from sentence_transformers import SentenceTransformer
19
  from bert_score.scorer import BERTScorer
20
+ from openai import RateLimitError
21
+ from anthropic import RateLimitError as AnthropicRateLimitError
22
+ from google.api_core.exceptions import ResourceExhausted
23
 
24
  class GraphRAG:
25
  def __init__(self, num_workers: int = 1):
 
57
  self.sub_node_counter = 0
58
  self.cross_connections = set()
59
 
60
+ # Semaphore protection
61
+ self.semaphore = asyncio.Semaphore(min(num_workers * 2, 12))
62
+
63
  # Thread pool
64
  self.executor = ThreadPoolExecutor(max_workers=self.num_workers)
65
 
 
385
  )
386
 
387
  async def build_graph(self, query: str, data: str = None, parent_node_id: str = None,
388
+ depth: int = 0, threshold: float = 0.8, recurse: bool = True,
389
+ context: list = None, session_id: str = None, max_tokens_allowed: int = 128000,
390
+ node_data_futures: dict = None, sub_nodes_info: list = None,
391
+ sub_query_ids: list = None, pre_req_nodes: list = None):
392
  """Build a new graph structure in memory."""
393
+ async def process_node(node_id: str, sub_query: str, session_id: str,
394
+ future: asyncio.Future, max_tokens_allowed: int = max_tokens_allowed):
395
  try:
396
  optimized_query = await self.search_engine.generate_optimized_query(sub_query)
397
  results = await self.search_engine.search(
 
436
  elif content:
437
  contents += f"Document {k}:\n{content}\n\n"
438
 
439
+ if contents.strip():
 
 
 
 
 
 
 
440
  token_count = self.llm.get_num_tokens(contents)
441
  if token_count > max_tokens_allowed:
442
  contents = await self.chunking.chunker(
 
446
  )
447
  print(f"Number of tokens in the answer: {token_count}")
448
  print(f"Number of tokens in the content: {self.llm.get_num_tokens(contents)}")
 
 
 
 
 
 
 
449
 
450
  graph_data = self._get_current_graph_data()
451
  graph = graph_data["graph"]
 
455
  idx = node_map[node_id]
456
  node_data = graph.get_node_data(idx)
457
  node_data["data"] = contents
458
+ if not future.done():
459
+ future.set_result(contents)
460
  except Exception as e:
461
  print(f"Error processing node {node_id}: {str(e)}")
462
+ if not future.done():
463
+ future.set_exception(e)
464
+ raise e
465
 
466
+ async def process_dependent_node(node_id: str, sub_query: str, dep_futures: list, future):
467
  try:
468
  dep_data = [await f for f in dep_futures]
469
  modified_query = await self.query_processor.modify_query(
 
487
  node_data["embedding"] = embedding.tolist() if hasattr(embedding, "tolist") else embedding
488
  try:
489
  if not future.done():
490
+ await process_node(node_id, modified_query, session_id, future, max_tokens_allowed)
491
  except Exception as e:
492
  if not future.done():
493
  future.set_exception(e)
494
+ raise e
495
  except Exception as e:
496
  print(f"Error processing dependent node {node_id}: {str(e)}")
497
  if not future.done():
498
  future.set_exception(e)
499
+ raise e
500
 
501
  def create_cross_connections():
502
  try:
 
545
 
546
  if context is None:
547
  context = []
548
+
549
+ if node_data_futures is None:
550
+ node_data_futures = {}
551
+ if sub_nodes_info is None:
552
+ sub_nodes_info = []
553
+ if sub_query_ids is None:
554
+ sub_query_ids = []
555
+ if pre_req_nodes is None:
556
+ pre_req_nodes = {}
557
 
558
  if parent_node_id is None:
559
  self.add_node(self.root_node_id, query, data)
 
570
  context.append(response_data)
571
 
572
  if len(sub_queries) > 1 and sub_queries[0] != query:
 
 
 
573
  for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)):
574
  if depth == 0:
575
  await self.emit_event("sub_query_created", {
 
591
  self.add_node(sub_node_id, sub_query, role=role)
592
  future = asyncio.Future()
593
  node_data_futures[sub_node_id] = future
594
+ sub_nodes_info.append((sub_node_id, sub_query, role, dependency, future, depth))
595
 
596
  if role.lower() in ['pre-requisite', 'prerequisite']:
597
  pre_req_nodes[idx] = sub_node_id
598
 
599
  if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
600
  self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical')
 
601
  elif role.lower() == 'dependent':
602
  if isinstance(dependency, list) and (len(dependency) == 2 and all(isinstance(d, list) for d in dependency)):
603
  print(f"Dependency: {dependency}")
 
620
 
621
  if current_deps not in [None, []]:
622
  for dep_idx in current_deps:
623
+ if dep_idx < len(sub_query_ids):
624
  dep_node_id = sub_query_ids[dep_idx]
625
  self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
626
  else:
627
  raise ValueError(f"Invalid dependency index: {dep_idx}")
 
628
  elif len(dependency) > 0:
629
  for dep_idx in dependency:
630
  if dep_idx < len(sub_queries):
 
636
  raise ValueError(f"Invalid dependency: {dependency}")
637
  else:
638
  raise ValueError(f"Unexpected role: {role}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
 
640
  if recurse:
641
  recursion_tasks = []
 
651
  threshold=threshold,
652
  recurse=recurse,
653
  context=context,
654
+ session_id=session_id,
655
+ node_data_futures=node_data_futures,
656
+ sub_nodes_info=sub_nodes_info,
657
+ sub_query_ids=sub_query_ids,
658
+ pre_req_nodes=pre_req_nodes
659
  )
660
  )
661
  except Exception as e:
 
666
  try:
667
  await asyncio.gather(*recursion_tasks)
668
  except Exception as e:
669
+ print(f"Error during recursive processing: {e}")
670
+ raise e
671
+
672
+ futures = {}
673
+ all_child_futures = {}
674
+ process_tasks = []
675
+ graph_data = self._get_current_graph_data()
676
+ graph = graph_data["graph"]
677
+ node_map = graph_data["node_map"]
678
+
679
+ for (sub_node_id, sub_query, role, dependency, future, local_depth) in sub_nodes_info:
680
+ idx = node_map.get(sub_node_id)
681
+ has_children = False
682
+ child_futures = []
683
+ if idx is not None:
684
+ for (_, child_idx, edge_data) in graph.out_edges(idx):
685
+ if edge_data.get("type") == "hierarchical":
686
+ has_children = True
687
+ child_future = node_data_futures.get(graph.get_node_data(child_idx).get("id"))
688
+ if child_future:
689
+ child_futures.append(child_future)
690
+ if local_depth == 0:
691
+ futures[sub_query] = future
692
+ all_child_futures[sub_query] = child_futures
693
+ if has_children:
694
+ if not future.done():
695
+ future.set_result("")
696
+ else:
697
+ if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
698
+ process_tasks.append(process_node(sub_node_id, sub_query, session_id, future, max_tokens_allowed))
699
+ elif role.lower() == 'dependent':
700
+ dep_futures = []
701
+ if isinstance(dependency, list) and len(dependency) == 2:
702
+ prev_deps, current_deps = dependency
703
+ if context and prev_deps not in [None, []]:
704
+ for context_idx, context_data in enumerate(context):
705
+ if isinstance(prev_deps, list) and context_idx < len(prev_deps):
706
+ context_dep = prev_deps[context_idx]
707
+ if (context_dep is not None and isinstance(context_data, dict)
708
+ and 'subqueries' in context_data):
709
+ if context_dep < len(context_data['subqueries']):
710
+ dep_query = context_data['subqueries'][context_dep]['subquery']
711
+ matching_nodes = self.find_nodes_by_properties(query=dep_query)
712
+ if matching_nodes not in [None, []]:
713
+ dep_node_id = matching_nodes[0].get('node_id', None)
714
+ score = float(matching_nodes[0].get('score', 0))
715
+ if score == 1.0 and dep_node_id in node_data_futures:
716
+ dep_futures.append(node_data_futures[dep_node_id])
717
+ elif isinstance(prev_deps, int):
718
+ if context_idx < len(context_data['subqueries']):
719
+ dep_query = context_data['subqueries'][prev_deps]['subquery']
720
+ matching_nodes = self.find_nodes_by_properties(query=dep_query)
721
+ if matching_nodes not in [None, []]:
722
+ dep_node_id = matching_nodes[0].get('node_id', None)
723
+ score = matching_nodes[0].get('score', 0)
724
+ if score == 1.0 and dep_node_id in node_data_futures:
725
+ dep_futures.append(node_data_futures[dep_node_id])
726
+ if current_deps not in [None, []]:
727
+ current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
728
+ for dep_idx in current_deps_list:
729
+ if dep_idx < len(sub_query_ids):
730
+ dep_node_id = sub_query_ids[dep_idx]
731
+ if dep_node_id in node_data_futures:
732
+ dep_futures.append(node_data_futures[dep_node_id])
733
+ process_tasks.append(process_dependent_node(sub_node_id, sub_query, dep_futures, future))
734
+ else:
735
+ if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
736
+ process_tasks.append(process_node(sub_node_id, sub_query, session_id, future, max_tokens_allowed))
737
+ elif role.lower() == 'dependent':
738
+ dep_futures = []
739
+ if isinstance(dependency, list) and len(dependency) == 2:
740
+ prev_deps, current_deps = dependency
741
+ if context and prev_deps not in [None, []]:
742
+ for context_idx, context_data in enumerate(context):
743
+ if isinstance(prev_deps, list) and context_idx < len(prev_deps):
744
+ context_dep = prev_deps[context_idx]
745
+ if (context_dep is not None and isinstance(context_data, dict)
746
+ and 'subqueries' in context_data):
747
+ if context_dep < len(context_data['subqueries']):
748
+ dep_query = context_data['subqueries'][context_dep]['subquery']
749
+ matching_nodes = self.find_nodes_by_properties(query=dep_query)
750
+ if matching_nodes not in [None, []]:
751
+ dep_node_id = matching_nodes[0].get('node_id', None)
752
+ score = float(matching_nodes[0].get('score', 0))
753
+ if score == 1.0 and dep_node_id in node_data_futures:
754
+ dep_futures.append(node_data_futures[dep_node_id])
755
+ elif isinstance(prev_deps, int):
756
+ if context_idx < len(context_data['subqueries']):
757
+ dep_query = context_data['subqueries'][prev_deps]['subquery']
758
+ matching_nodes = self.find_nodes_by_properties(query=dep_query)
759
+ if matching_nodes not in [None, []]:
760
+ dep_node_id = matching_nodes[0].get('node_id', None)
761
+ score = matching_nodes[0].get('score', 0)
762
+ if score == 1.0 and dep_node_id in node_data_futures:
763
+ dep_futures.append(node_data_futures[dep_node_id])
764
+ if current_deps not in [None, []]:
765
+ current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
766
+ for dep_idx in current_deps_list:
767
+ if dep_idx < len(sub_query_ids):
768
+ dep_node_id = sub_query_ids[dep_idx]
769
+ if dep_node_id in node_data_futures:
770
+ dep_futures.append(node_data_futures[dep_node_id])
771
+ process_tasks.append(process_dependent_node(sub_node_id, sub_query, dep_futures, future))
772
+
773
+ if process_tasks:
774
+ await self.emit_event("search_process_started", {
775
+ "depth": depth,
776
+ "sub_queries": sub_queries,
777
+ "roles": roles
778
+ })
779
+
780
+ for sub_query, future in futures.items():
781
+ try:
782
+ parent_content = future.result().strip()
783
+ except:
784
+ parent_content = ""
785
+
786
+ child_futures = all_child_futures.get(sub_query)
787
+ any_child_done = any(cf.done() and cf.result().strip() for cf in child_futures)
788
+
789
+ if parent_content or any_child_done:
790
+ await self.emit_event("sub_query_processed", {"sub_query": sub_query})
791
+
792
+ await asyncio.gather(*process_tasks)
793
+
794
  if depth == 0:
795
+ for sub_query, future in futures.items():
796
+ try:
797
+ parent_content = future.result().strip()
798
+ except:
799
+ parent_content = ""
800
+
801
+ child_futures = all_child_futures.get(sub_query)
802
+ no_child_done = not any(cf.done() and cf.result().strip() for cf in child_futures)
803
+
804
+ if no_child_done:
805
+ await self.emit_event("sub_query_failed", {"sub_query": sub_query})
806
+
807
+ for idx, (sub_query, future) in enumerate(futures.items(), 1):
808
+ if future.done() and future.result().strip():
809
+ print(f"Sub-query {idx} processed successfully")
810
+ else:
811
+ child_futures = all_child_futures.get(sub_query)
812
+ if any(cf.done() and cf.result().strip() for cf in child_futures):
813
+ print(f"Sub-query {idx} processed successfully because of child nodes")
814
+ else:
815
+ print(f"Sub-query {idx} failed to process because of child nodes")
816
+
817
  print("Graph building complete, processing final tasks...")
818
  await self.emit_event("search_process_completed", {
819
  "depth": depth,
 
1006
  self.update_pagerank()
1007
  self.verify_graph_integrity()
1008
  self.verify_graph_consistency()
1009
+ except (RateLimitError, ResourceExhausted, AnthropicRateLimitError):
1010
+ pass
1011
  except Exception as e:
1012
  print(f"Error in process_graph: {str(e)}")
1013
  raise
 
1275
  graph_data = self._get_current_graph_data()
1276
  graph = graph_data["graph"]
1277
  node_map = graph_data["node_map"]
1278
+ net = Network(height="530px", width="100%", directed=True, bgcolor="#222222", font_color="white")
1279
  net.options = {"physics": {"enabled": False}}
1280
  all_nodes = set()
1281
  all_edges = []
 
1312
  net.options["layout"] = {"improvedLayout": True}
1313
  net.options["interaction"] = {"dragNodes": True}
1314
 
 
 
 
1315
  net.save_graph("temp_graph.html")
1316
 
1317
  with open("temp_graph.html", "r", encoding="utf-8") as f:
1318
  html_str = f.read()
 
1319
  os.remove("temp_graph.html")
 
 
1320
  return html_str
1321
 
1322
  def verify_graph_integrity(self):