File size: 3,654 Bytes
f7ab812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import networkx as nx
from typing import List, Dict
import community as community_louvain
import matplotlib.pyplot as plt
from openai import OpenAI
from tqdm import tqdm


def get_partition(graph: nx.Graph) -> Dict[str, int]:
    """Detect communities in the graph with the Louvain method.

    Args:
        graph (nx.Graph): The NetworkX graph.

    Returns:
        Dict[str, int]: Mapping from each node to the integer id of the
        community it was assigned to, as returned by
        ``community_louvain.best_partition``.
    """
    return community_louvain.best_partition(graph)


def plot_graph_with_communities(G: nx.Graph, partition: Dict[str, int]):
    """
    Plot the NetworkX graph with nodes coloured by community.

    Edges are labelled with their ``type`` attribute where one is present.

    Args:
        G (nx.Graph): The NetworkX graph.
        partition (Dict[str, int]): Mapping from node to community id, e.g.
            the output of ``get_partition``.
    """
    plt.figure(figsize=(20, 20))
    pos = nx.spring_layout(G)  # Use spring layout for better visualization

    # Colour each node by its community id. Nodes absent from the partition
    # fall back to -1 so the colour array stays numeric (None would break
    # the colormap lookup).
    values = [partition.get(node, -1) for node in G.nodes()]

    # Draw the nodes once, already coloured by community; a separate
    # uniformly-coloured base layer would be fully hidden underneath.
    nx.draw_networkx_nodes(
        G, pos, node_size=500, cmap=plt.get_cmap("tab20"), node_color=values
    )
    nx.draw_networkx_edges(G, pos, edge_color="black", width=5.0, alpha=0.8)
    nx.draw_networkx_labels(G, pos, font_size=8, font_color="black")

    # Add edge labels (for the `type` attribute)
    edge_labels = nx.get_edge_attributes(G, "type")
    nx.draw_networkx_edge_labels(
        G, pos, edge_labels=edge_labels, font_size=5, font_color="red"
    )

    # Show the plot
    plt.title("Knowledge Graph with Communities")
    plt.axis("off")  # Turn off the axes for better visualization
    plt.show()


def get_communities(graph: nx.Graph, *, plot: bool = True) -> List[List[str]]:
    """Detect the communities in the graph and group nodes by community.

    Args:
        graph (nx.Graph): The NetworkX graph.
        plot (bool): When True (the default), also draw the graph coloured by
            community via ``plot_graph_with_communities``, which calls
            ``plt.show()``. Pass False to skip the plot.

    Returns:
        List[List[str]]: One list of nodes per community, ordered by
        ascending community id. Nodes within a community keep the iteration
        order of the partition.
    """
    partition = get_partition(graph)
    if plot:
        plot_graph_with_communities(graph, partition)

    # Group in a single pass instead of rescanning the whole partition for
    # every community id. Keying by the actual id also keeps communities
    # whose ids do not form a contiguous 0..c-1 range.
    grouped: Dict[int, List[str]] = {}
    for node, community_id in partition.items():
        grouped.setdefault(community_id, []).append(node)
    return [grouped[community_id] for community_id in sorted(grouped)]


def summarize_communities(
    communities: List[List[str]],
    graph: nx.Graph,
    client: OpenAI,
    model: str = "gpt-4o",
) -> List[str]:
    """
    Summarize each community of entities and relationships with an LLM.

    For every community, the induced subgraph is rendered as a plain-text
    description listing its entities and its relationships (formatted as
    ``source -> type -> target``). That description is sent to the chat
    completions endpoint, and one summary is collected per community, in
    input order.

    Args:
        communities (List[List[str]]): The list of communities (node lists).
        graph (nx.Graph): The NetworkX graph the communities were taken from.
        client (OpenAI): The OpenAI client.
        model (str): Chat model used for summarization. Defaults to "gpt-4o".

    Returns:
        List[str]: One summary per community. An empty string is stored when
        the response message carries no text content.
    """
    community_summaries = []
    for community in tqdm(communities):
        subgraph = graph.subgraph(community)
        entities = ", ".join(str(node) for node in subgraph.nodes)
        # Edges without a `type` attribute get an empty relation label
        # instead of raising KeyError.
        relationships = ", ".join(
            f"{source} -> {data.get('type', '')} -> {target}"
            for source, target, data in subgraph.edges(data=True)
        )
        description = "Entities: " + entities + "\nRelationships: " + relationships

        response = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "system",
                    "content": """Summarize the following community of entities and relationships. 
                    """,
                },
                {"role": "user", "content": description},
            ],
        )
        # `message.content` can be None, so guard before calling .strip().
        content = response.choices[0].message.content
        community_summaries.append((content or "").strip())
    return community_summaries