|
import requests |
|
import networkx as nx |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
base_url = "http://localhost:5000" |
|
|
|
def fetch_relationships(node_id, direction="down", timeout=30):
    """Fetch the relationship hierarchy for ``node_id`` from the traversal service.

    Args:
        node_id: Identifier of the start node. It is sent as a query parameter
            via ``params=``, so spaces and reserved characters (``&``, ``#``,
            ``+``, ``=``) are encoded instead of corrupting the URL.
        direction: Traversal direction, ``"down"`` or ``"up"``; forwarded to
            the server unchanged.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely.

    Returns:
        The ``"traversal_path"`` value from the JSON response, or ``{}`` when
        that key is absent or null.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    response = requests.get(
        f"{base_url}/traverse_node",
        params={"node_id": node_id, "direction": direction},
        timeout=timeout,
    )
    # Fail loudly rather than rendering an error response as an empty hierarchy.
    response.raise_for_status()
    # "or {}" also covers an explicit JSON null for traversal_path.
    return response.json().get("traversal_path") or {}
|
|
|
def build_graph_from_relationships(node_id):
    """Return a ``nx.DiGraph`` merging the descendant and ancestor hierarchies of ``node_id``.

    Queries the service once per direction ("down" first, then "up") through
    ``fetch_relationships``. It then folds each returned hierarchy into one
    directed graph with ``add_nodes_and_edges``.
    """
    graph = nx.DiGraph()

    # Both requests are issued before the graph is populated ("down", then "up").
    hierarchies = [fetch_relationships(node_id, direction=way) for way in ("down", "up")]

    for hierarchy in hierarchies:
        # No visited set is passed, so each hierarchy starts with a fresh one.
        # The root reached in the "down" pass is therefore still expanded in
        # the "up" pass.
        add_nodes_and_edges(graph, hierarchy)

    return graph
|
|
|
def add_nodes_and_edges(G, node, visited=None):
    """Recursively add the nodes and edges of a traversal hierarchy to ``G``.

    ``node`` is expected to be a dict with a ``"node_id"`` key and optional
    ``"descendants"`` and ``"ancestors"`` lists of nested dicts of the same
    shape. Each nested entry may carry a ``"relationship"`` used as the edge
    label (default ``"related_to"``).

    Edge directions:
        - descendant edges point ``node -> child``;
        - ancestor edges point ``ancestor -> node``.

    Args:
        G: Graph to mutate; only ``add_node`` and ``add_edge`` are used
            (e.g. an ``nx.DiGraph``).
        node: One hierarchy entry. ``None``, empty or non-dict values and
            entries without a ``node_id`` are ignored.
        visited: Node ids already expanded. It is shared across the recursive
            calls so that cycles and repeated nodes are expanded only once.
            A fresh set is created when omitted.
    """
    if visited is None:
        visited = set()

    # A null traversal_path or a malformed entry must not crash the whole build.
    if not isinstance(node, dict):
        return

    node_id = node.get("node_id")
    if not node_id or node_id in visited:
        return
    visited.add(node_id)

    G.add_node(node_id, label=node_id)

    # "or []" tolerates an explicit JSON null for the list.
    for child in node.get("descendants") or []:
        child_id = child.get("node_id") if isinstance(child, dict) else None
        if not child_id:
            # An edge to a missing id would make networkx raise
            # ("None cannot be a node"); skip it like an id-less root.
            continue
        relationship = child.get("relationship", "related_to")
        # The edge is added even if the child was already visited, so that
        # every parent -> child link is recorded.
        G.add_edge(node_id, child_id, label=relationship)
        add_nodes_and_edges(G, child, visited)

    for ancestor in node.get("ancestors") or []:
        ancestor_id = ancestor.get("node_id") if isinstance(ancestor, dict) else None
        if not ancestor_id:
            continue
        relationship = ancestor.get("relationship", "related_to")
        G.add_edge(ancestor_id, node_id, label=relationship)
        add_nodes_and_edges(G, ancestor, visited)
|
|
|
def visualize_graph(G, title="Graph Structure and Relationships", seed=None):
    """Draw ``G`` with matplotlib: node names as labels, edge ``label`` attributes in red.

    Args:
        G: The networkx graph to draw.
        title: Figure title.
        seed: Optional random seed for ``nx.spring_layout``. Pass an int to get
            the same layout on every run; ``None`` keeps the default behavior.
    """
    # One size is used for drawing nodes and for positioning edges. If
    # draw_networkx_edges fell back to its smaller default, arrowheads would
    # end hidden inside the large node circles.
    node_size = 3000

    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G, seed=seed)

    nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color="skyblue", alpha=0.8)
    nx.draw_networkx_labels(G, pos, font_size=10, font_color="black")

    nx.draw_networkx_edges(G, pos, node_size=node_size, edge_color="gray", arrows=True)
    # Edges without a "label" attribute are simply left unlabeled instead of raising KeyError.
    edge_labels = {(u, v): d["label"] for u, v, d in G.edges(data=True) if "label" in d}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color="red")

    plt.title(title)
    plt.axis("off")
    plt.show()
|
|
|
|
|
def main():
    """Load the 340B graph file on the server, then fetch and draw its relationship graph."""
    print("\n--- Loading Graph ---")
    graph_data = {"graph_file": "graphs/PHSA/phsa_sec_340b.json"}
    response = requests.post(f"{base_url}/load_graph", json=graph_data, timeout=30)
    # Stop here if loading failed rather than visualizing whatever the server holds.
    response.raise_for_status()
    print("Load Graph Response:", response.json())

    print("\n--- Building Graph for Visualization ---")
    G = build_graph_from_relationships("340B Program")
    visualize_graph(G, title="340B Program - Inferred Contextual Relationships")


# Guarded so that importing this module does not trigger network calls or open a plot.
if __name__ == "__main__":
    main()