"""Build and display a directed graph of the entities listed in index.json."""

import json
import os
from typing import Any, Iterable, Optional

import matplotlib.pyplot as plt
import networkx as nx

# Define the path to your index.json file
index_file_path = "graphs/index.json"

# Entity ids that are file placeholders; they must never appear in the graph,
# neither as explicitly added nodes nor as endpoints of an edge.
_EXCLUDED_NODES = frozenset(
    {"patient_protection._tmp", "phsa_sec_340b", "medicade_tmp"}
)

# Edge label used when a relationship does not name one.
_DEFAULT_RELATIONSHIP = "related_to"


# Load the data from index.json
def load_index_data(file_path: str) -> Any:
    """Load the index.json file and parse its contents.

    Args:
        file_path: Path of the JSON index file.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding so decoding does not depend on the platform default.
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)


def load_entity_file(entity_info: dict) -> Optional[Any]:
    """Load the entity-specific JSON file if file_path is provided.

    Loading is best-effort: every failure is reported on stdout and results
    in ``None`` instead of an exception.

    Args:
        entity_info: Entity record; its optional ``"file_path"`` entry names
            the JSON file to read.

    Returns:
        The decoded JSON document, or ``None`` when no path is given, the
        file does not exist, or it cannot be read or decoded.
    """
    file_path = entity_info.get("file_path")
    if not file_path:
        return None
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            return json.load(file)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"Error loading JSON file at {file_path}: {e}")
    except OSError as e:
        # The path exists but is unreadable (e.g. a directory or a
        # permission problem); keep going like the other failure modes.
        print(f"Error reading file at {file_path}: {e}")
    return None


def _add_relationship_edges(
    G: "nx.DiGraph",
    relationships: Iterable[dict],
    excluded_nodes: Iterable[str] = _EXCLUDED_NODES,
) -> None:
    """Add one labelled edge to G per relationship, skipping excluded nodes."""
    for relationship in relationships:
        source = relationship["source"]
        target = relationship["target"]
        # add_edge() creates missing endpoints, so an edge that touches an
        # excluded node would silently bring that node back into the graph.
        if source in excluded_nodes or target in excluded_nodes:
            continue
        # "attributes" may be absent or null; fall back to the default label.
        attributes = relationship.get("attributes") or {}
        G.add_edge(
            source,
            target,
            label=attributes.get("relationship", _DEFAULT_RELATIONSHIP),
        )


def build_graph(data: dict) -> "nx.DiGraph":
    """Builds a directed graph based on entities and relationships.

    Each entity in ``data["entities"]`` (except the excluded placeholders)
    becomes a node carrying ``label`` and ``domain`` attributes. Edges come
    from the ``"relationships"`` list of every entity file that loads as a
    dictionary, and from ``data["relationships"]``.

    Args:
        data: Parsed index document with ``"entities"`` (mapping of entity
            id to entity record) and ``"relationships"`` (list of records
            with ``"source"``, ``"target"`` and optional ``"attributes"``).

    Returns:
        The populated directed graph.

    Raises:
        KeyError: If ``"entities"``/``"relationships"`` is missing from
            ``data`` or a relationship lacks ``"source"``/``"target"``.
    """
    G = nx.DiGraph()

    # Add nodes for each entity, excluding file placeholders
    for entity_id, entity_info in data["entities"].items():
        if entity_id in _EXCLUDED_NODES:
            continue

        G.add_node(
            entity_id,
            label=entity_info.get("label", entity_id),
            domain=entity_info.get("inherits_from", "Default"),
        )

        # Entities without a file have nothing to load; do not report them
        # as having an invalid data format.
        if not entity_info.get("file_path"):
            continue

        # Load entity file if specified
        entity_data = load_entity_file(entity_info)
        if isinstance(entity_data, dict):
            _add_relationship_edges(G, entity_data.get("relationships", []))
        else:
            # The node stays in the graph; only the file's relationships
            # are skipped.
            print(f"Skipping entity {entity_id} due to invalid data format.")

    # Add edges for each relationship in index.json
    _add_relationship_edges(G, data["relationships"])

    return G


# Visualize the graph using Matplotlib
def visualize_graph(
    G: "nx.DiGraph", title: str = "Inferred Contextual Relationships"
) -> None:
    """Visualizes the graph with nodes and relationships, using domain colors.

    Args:
        G: Graph to draw. Node attribute ``domain`` selects the node color,
            node attribute ``label`` is the text drawn on the node and edge
            attribute ``label`` is the text drawn on the edge.
        title: Title shown above the plot.
    """
    # Use different colors for each domain
    color_map = {
        "Legislation": "lightcoral",
        "Healthcare Systems": "lightgreen",
        "Healthcare Policies": "lightblue",
        "Default": "lightgrey",
    }

    # Set node colors based on their domains
    node_colors = [
        color_map.get(G.nodes[node].get("domain", "Default"), "lightgrey")
        for node in G.nodes
    ]

    # Show the human-readable label stored on each node; nodes that were only
    # created implicitly by an edge have none and fall back to their id.
    node_labels = {
        node: attrs.get("label", node) for node, attrs in G.nodes(data=True)
    }

    # Use Kamada-Kawai layout for better spacing of nodes
    pos = nx.kamada_kawai_layout(G)

    # Draw nodes with domain-specific colors
    plt.figure(figsize=(15, 10))
    nx.draw_networkx_nodes(
        G, pos, node_size=3000, node_color=node_colors, alpha=0.8
    )
    nx.draw_networkx_labels(
        G,
        pos,
        labels=node_labels,
        font_size=9,
        font_color="black",
        font_weight="bold",
    )

    # Draw edges with labels
    nx.draw_networkx_edges(
        G,
        pos,
        arrowstyle="->",
        arrowsize=15,
        edge_color="gray",
        connectionstyle="arc3,rad=0.1",
    )
    # .get() so graphs whose edges carry no label can still be drawn.
    edge_labels = {(u, v): d.get("label", "") for u, v, d in G.edges(data=True)}
    nx.draw_networkx_edge_labels(
        G, pos, edge_labels=edge_labels, font_color="red", font_size=8
    )

    # Set plot title and display
    plt.title(title, fontsize=14)
    plt.axis("off")
    plt.show()


def main() -> None:
    """Load the index, build the graph and display it."""
    # Load data from the index.json
    data = load_index_data(index_file_path)

    # Build the graph with entities and relationships
    G = build_graph(data)

    # Visualize the graph
    visualize_graph(G)


# Main execution
if __name__ == "__main__":
    main()