import matplotlib.pyplot as plt
import networkx as nx
import requests
# Root URL of the graph service that every HTTP call in this file targets.
base_url = "http://localhost:5000"
def fetch_relationships(node_id, direction="down"):
    """Fetch the traversal hierarchy for a node from the graph service.

    Sends ``GET {base_url}/traverse_node`` with ``node_id`` and ``direction``
    as query parameters and returns the ``"traversal_path"`` entry of the
    JSON reply.

    Args:
        node_id: Identifier of the node to start the traversal from.
        direction: Traversal direction, ``"down"`` or ``"up"``; forwarded to
            the service unchanged.

    Returns:
        The value stored under ``"traversal_path"`` in the JSON response, or
        an empty dict when that key is absent.

    Raises:
        requests.RequestException: If the connection fails or the request
            times out. A body that is not valid JSON also raises from
            ``response.json()``.
    """
    response = requests.get(
        f"{base_url}/traverse_node",
        # Let requests build the query string so identifiers containing
        # characters such as "&", "#", "+" or "%" are percent-encoded
        # instead of corrupting the URL.
        params={"node_id": node_id, "direction": direction},
        # Without a timeout an unresponsive server blocks the caller forever.
        timeout=30,
    )
    return response.json().get("traversal_path", {})
def build_graph_from_relationships(node_id):
    """Build a directed graph from a node's downward and upward traversals.

    Fetches the ``"down"`` and ``"up"`` traversal hierarchies for ``node_id``
    and merges both into a single ``networkx.DiGraph``.

    Args:
        node_id: Identifier of the node whose relationships are requested.

    Returns:
        A ``networkx.DiGraph`` holding the nodes and labelled edges found in
        both hierarchies; empty when neither hierarchy yields a node.
    """
    G = nx.DiGraph()

    descendants_data = fetch_relationships(node_id, direction="down")
    ancestors_data = fetch_relationships(node_id, direction="up")

    # Each call starts with its own "visited" set, so a node that appears in
    # both hierarchies is still expanded in each of them.
    add_nodes_and_edges(G, descendants_data)
    add_nodes_and_edges(G, ancestors_data)

    return G
def add_nodes_and_edges(G, node, visited=None):
    """Recursively copy a traversal hierarchy into a graph.

    ``node`` is a mapping with a ``"node_id"`` and optional ``"descendants"``
    and ``"ancestors"`` lists of mappings of the same shape. Each nested
    entry may carry a ``"relationship"`` string used as the edge label
    (default ``"related_to"``).

    Edges point from parent to child: ``node -> descendant`` and
    ``ancestor -> node``.

    Args:
        G: Graph object exposing ``add_node(id, **attrs)`` and
            ``add_edge(u, v, **attrs)`` (e.g. ``networkx.DiGraph``). It is
            mutated in place.
        node: The hierarchy entry to process. Empty or ``None`` is a no-op.
        visited: Set of node ids already expanded during this traversal;
            created on the first call and shared down the recursion so that
            cycles terminate.

    Returns:
        None.
    """
    if visited is None:
        visited = set()

    if not node:
        return

    node_id = node.get("node_id")
    if not node_id or node_id in visited:
        return

    visited.add(node_id)
    G.add_node(node_id, label=node_id)

    # "or ()" tolerates a key that is present but set to None.
    for child in node.get("descendants") or ():
        child_id = child.get("node_id")
        if not child_id:
            # An entry without an id would otherwise create an edge to a
            # spurious None node.
            continue
        relationship = child.get("relationship", "related_to")
        G.add_edge(node_id, child_id, label=relationship)
        add_nodes_and_edges(G, child, visited)

    for ancestor in node.get("ancestors") or ():
        ancestor_id = ancestor.get("node_id")
        if not ancestor_id:
            continue
        relationship = ancestor.get("relationship", "related_to")
        G.add_edge(ancestor_id, node_id, label=relationship)
        add_nodes_and_edges(G, ancestor, visited)
def visualize_graph(G, title="Graph Structure and Relationships"):
    """Draw the graph in a matplotlib window.

    Nodes are laid out with ``networkx.spring_layout`` and drawn with their
    ids as labels; each edge is drawn with an arrow and annotated with its
    ``"label"`` attribute.

    Args:
        G: The networkx graph to draw.
        title: Title shown above the plot.

    Returns:
        None. Calls ``plt.show()`` to display the figure.
    """
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G)

    nx.draw_networkx_nodes(G, pos, node_size=3000, node_color="skyblue", alpha=0.8)
    nx.draw_networkx_labels(G, pos, font_size=10, font_color="black")
    nx.draw_networkx_edges(G, pos, edge_color="gray", arrows=True)

    # Use .get so an edge without a "label" attribute is drawn unlabelled
    # instead of raising KeyError.
    edge_labels = {(u, v): d.get("label", "") for u, v, d in G.edges(data=True)}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color="red")

    plt.title(title)
    plt.axis("off")
    plt.show()
# --- Driver: load a graph file into the service, then plot one node's neighbourhood ---

print("\n--- Loading Graph ---")
graph_data = {"graph_file": "graphs/PHSA/phsa_sec_340b.json"}
# A timeout keeps the script from hanging forever if the service is unresponsive.
response = requests.post(f"{base_url}/load_graph", json=graph_data, timeout=60)
print("Load Graph Response:", response.json())

print("\n--- Building Graph for Visualization ---")
G = build_graph_from_relationships("340B Program")
visualize_graph(G, title="340B Program - Inferred Contextual Relationships")