"""Fetch a node's relationship hierarchy from a local graph API and plot it.

When executed, this module asks the server to load a graph file, builds a
NetworkX DiGraph around the "340B Program" node from its descendant and
ancestor traversals, and displays the result with matplotlib.
"""

import requests
import networkx as nx
import matplotlib.pyplot as plt

# API Base URL
base_url = "http://localhost:5000"

# Seconds to wait for the API before giving up. Without a timeout,
# requests.get/post would block forever on an unresponsive server.
REQUEST_TIMEOUT = 30


def fetch_relationships(node_id, direction="down", timeout=REQUEST_TIMEOUT):
    """Fetch the traversal hierarchy for *node_id* in the given direction.

    Args:
        node_id: Identifier of the node to traverse from.
        direction: "down" or "up"; passed through to the server unchanged.
        timeout: Seconds to wait for the server's response.

    Returns:
        The ``traversal_path`` value of the JSON response, or ``{}`` when the
        key is absent. Presumably a nested dict of the shape
        ``add_nodes_and_edges`` consumes — confirm against the API.
    """
    # Let requests URL-encode the query: ids containing "&", "#", "+" or "="
    # would otherwise corrupt a hand-built query string.
    response = requests.get(
        f"{base_url}/traverse_node",
        params={"node_id": node_id, "direction": direction},
        timeout=timeout,
    )
    return response.json().get("traversal_path", {})


def build_graph_from_relationships(node_id):
    """Build a directed graph of *node_id*'s descendants and ancestors.

    Returns:
        A new ``nx.DiGraph`` whose edges point from parent to child.
    """
    G = nx.DiGraph()

    descendants_data = fetch_relationships(node_id, direction="down")
    ancestors_data = fetch_relationships(node_id, direction="up")

    # Each call gets its own fresh visited set. Both hierarchies are
    # presumably rooted at node_id, so a shared set would stop the second
    # walk at the root and drop the whole other direction.
    add_nodes_and_edges(G, descendants_data)
    add_nodes_and_edges(G, ancestors_data)
    return G


def add_nodes_and_edges(G, node, visited=None):
    """Recursively add a traversal hierarchy's nodes and edges to *G*.

    Args:
        G: Graph mutated in place.
        node: Dict with a ``node_id`` key and optional ``descendants`` /
            ``ancestors`` lists of nested dicts of the same shape. Nested
            dicts may carry a ``relationship`` string used as the edge label
            (default "related_to").
        visited: Node ids already expanded in this walk. It is created on the
            first call and guards against cycles.

    Edges are ``node -> descendant`` and ``ancestor -> node``. Entries with
    a missing or empty ``node_id`` are skipped.

    NOTE: recursion depth grows with hierarchy depth; extremely deep
    hierarchies could exceed Python's recursion limit.
    """
    if visited is None:
        visited = set()

    node_id = node.get("node_id")
    if not node_id or node_id in visited:
        return
    visited.add(node_id)

    G.add_node(node_id, label=node_id)

    # `or []` also covers a key that is present but null in the JSON.
    for child in node.get("descendants") or []:
        child_id = child.get("node_id")
        if not child_id:
            # NetworkX rejects None as a node, so this entry cannot be linked.
            continue
        relationship = child.get("relationship", "related_to")
        # The edge is added even when the child was already visited, so
        # shared nodes and cycles keep all their links.
        G.add_edge(node_id, child_id, label=relationship)
        add_nodes_and_edges(G, child, visited)

    for ancestor in node.get("ancestors") or []:
        ancestor_id = ancestor.get("node_id")
        if not ancestor_id:
            continue
        relationship = ancestor.get("relationship", "related_to")
        G.add_edge(ancestor_id, node_id, label=relationship)
        add_nodes_and_edges(G, ancestor, visited)


def visualize_graph(G, title="Graph Structure and Relationships"):
    """Draw *G* with its node and edge labels and display it via plt.show()."""
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G)

    # Draw nodes and labels
    nx.draw_networkx_nodes(G, pos, node_size=3000, node_color="skyblue", alpha=0.8)
    nx.draw_networkx_labels(G, pos, font_size=10, font_color="black")

    # Draw edges with labels. Edges without a "label" attribute get an empty
    # label instead of raising KeyError.
    nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True)
    edge_labels = {(u, v): d.get("label", "") for u, v, d in G.edges(data=True)}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color="red")

    plt.title(title)
    plt.axis("off")
    plt.show()


# NOTE(review): the statements below run on import as well as on direct
# execution; consider moving them under `if __name__ == "__main__":`.

# Step 1: Load Graph (Specify the graph to load, e.g., PHSA/340B section)
print("\n--- Loading Graph ---")
graph_data = {"graph_file": "graphs/PHSA/phsa_sec_340b.json"}
response = requests.post(f"{base_url}/load_graph", json=graph_data, timeout=REQUEST_TIMEOUT)
print("Load Graph Response:", response.json())

# Step 2: Build and visualize the graph for 340B Program
print("\n--- Building Graph for Visualization ---")
G = build_graph_from_relationships("340B Program")
visualize_graph(G, title="340B Program - Inferred Contextual Relationships")