File size: 3,998 Bytes
64ed965
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
import networkx as nx
import matplotlib.pyplot as plt
import os

# Path of the index JSON file that the script's main entry point loads.
# It is a relative path, so it is resolved against the current working directory.
index_file_path = "graphs/index.json"

# Load the data from index.json
def load_index_data(file_path):
    """Load and parse the index JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON content, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)

def load_entity_file(entity_info):
    """Load the entity-specific JSON file named by ``entity_info["file_path"]``.

    This is a best-effort loader: problems are reported on stdout and
    signalled to the caller by returning ``None`` instead of raising.

    Args:
        entity_info: Mapping describing one entity. Only its optional
            "file_path" key is read.

    Returns:
        The parsed JSON content, or ``None`` when no path is given, the file
        does not exist, it cannot be read, or it is not valid JSON.
    """
    file_path = entity_info.get("file_path")
    if not file_path:
        # No file configured for this entity; nothing to load or report.
        return None

    # Open directly instead of checking os.path.exists() first: the check can
    # go stale before open() runs, and it does not cover directories or
    # unreadable files, which would otherwise escape as uncaught exceptions.
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        print(f"File not found: {file_path}")
    except (OSError, json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"Error loading JSON file at {file_path}: {e}")
    return None

def _add_relationship_edges(G, relationships):
    """Add one labelled edge to ``G`` for every relationship record."""
    for relationship in relationships:
        # "attributes" is optional metadata; when it is missing or null the
        # edge simply gets the default "related_to" label.
        attributes = relationship.get("attributes") or {}
        G.add_edge(
            relationship["source"],
            relationship["target"],
            label=attributes.get("relationship", "related_to"),
        )


def build_graph(data):
    """Build a directed graph from the entities and relationships in ``data``.

    Each non-excluded entity becomes a node carrying a ``label`` attribute
    (the entity's "label", falling back to its id) and a ``domain`` attribute
    (the entity's "inherits_from", falling back to "Default"). Edges come from
    the "relationships" list of each entity's own JSON file, when one can be
    loaded, and from the top-level "relationships" list of ``data``. Every
    edge carries a ``label`` attribute.

    Args:
        data: Parsed index content with an "entities" mapping and an optional
            "relationships" list.

    Returns:
        The populated ``networkx.DiGraph``.

    Raises:
        KeyError: If "entities" is missing, or a relationship record lacks
            "source" or "target".
    """
    G = nx.DiGraph()

    # Add nodes for each entity, excluding file placeholders
    excluded_nodes = {"patient_protection._tmp", "phsa_sec_340b", "medicade_tmp"}
    # NOTE(review): only node creation is filtered. A relationship that names
    # one of these ids still creates the node implicitly through add_edge.
    # Confirm whether such edges should be dropped as well.
    for entity_id, entity_info in data["entities"].items():
        if entity_id in excluded_nodes:
            continue
        G.add_node(
            entity_id,
            label=entity_info.get("label", entity_id),
            domain=entity_info.get("inherits_from", "Default"),
        )

        # Load entity file if specified
        entity_data = load_entity_file(entity_info)
        if entity_data is None:
            # Either no file is configured or load_entity_file has already
            # printed why loading failed; a "skipping" message here would be
            # misleading noise.
            continue
        if not isinstance(entity_data, dict):
            print(f"Skipping entity {entity_id} due to invalid data format.")
            continue
        _add_relationship_edges(G, entity_data.get("relationships", []))

    # Add edges for each relationship in index.json; the list is optional,
    # matching how entity files are treated.
    _add_relationship_edges(G, data.get("relationships", []))

    return G

# Visualize the graph using Matplotlib
def visualize_graph(G, title="Inferred Contextual Relationships"):
    """Draw ``G`` with Matplotlib and show the figure.

    Nodes are coloured by their ``domain`` attribute and captioned with their
    ``label`` attribute (falling back to the node id). Edges are drawn as
    curved arrows captioned with their ``label`` attribute when present.

    Args:
        G: Graph to draw. Node attributes ``label`` and ``domain`` and the
            edge attribute ``label`` are read when available.
        title: Text shown as the plot title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Use different colors for each domain
    color_map = {
        "Legislation": "lightcoral",
        "Healthcare Systems": "lightgreen",
        "Healthcare Policies": "lightblue",
        "Default": "lightgrey",
    }

    # Nodes with no domain, or a domain not in the map, are drawn light grey.
    node_colors = [
        color_map.get(attrs.get("domain", "Default"), "lightgrey")
        for _, attrs in G.nodes(data=True)
    ]

    # Use Kamada-Kawai layout for better spacing of nodes
    pos = nx.kamada_kawai_layout(G)

    # Draw nodes with domain-specific colors
    plt.figure(figsize=(15, 10))
    nx.draw_networkx_nodes(G, pos, node_size=3000, node_color=node_colors, alpha=0.8)

    # Caption nodes with their human-readable "label" attribute instead of the
    # raw id; nodes created implicitly by an edge have none, so fall back to the id.
    node_labels = {node: attrs.get("label", node) for node, attrs in G.nodes(data=True)}
    nx.draw_networkx_labels(
        G, pos, labels=node_labels, font_size=9, font_color="black", font_weight="bold"
    )

    # Draw edges with labels
    nx.draw_networkx_edges(
        G, pos, arrowstyle="->", arrowsize=15, edge_color="gray", connectionstyle="arc3,rad=0.1"
    )
    # Edges without a "label" attribute are drawn but left uncaptioned.
    edge_labels = {(u, v): d["label"] for u, v, d in G.edges(data=True) if "label" in d}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color="red", font_size=8)

    # Set plot title and display
    plt.title(title, fontsize=14)
    plt.axis("off")
    plt.show()

# Main execution
if __name__ == "__main__":
    # Pipeline: read index.json -> build the directed graph -> display it.
    visualize_graph(build_graph(load_index_data(index_file_path)))