|
import networkx as nx |
|
from typing import List, Dict |
|
import community as community_louvain |
|
import matplotlib.pyplot as plt |
|
from openai import OpenAI |
|
from tqdm import tqdm |
|
|
|
|
|
def get_partition(graph: nx.Graph) -> List[int]: |
|
"""Get the partition of the graph. |
|
|
|
Args: |
|
graph (nx.Graph): The NetworkX graph. |
|
|
|
Returns: |
|
List[int]: The partition of the graph. |
|
""" |
|
partition = community_louvain.best_partition(graph) |
|
return partition |
|
|
|
|
|
def plot_graph_with_communities(G: nx.Graph, partition: Dict[str, int]): |
|
""" |
|
Plot the NetworkX graph with communities. |
|
|
|
Args: |
|
G (nx.Graph): The NetworkX graph. |
|
partition (Dict[str, int]): The partition of the graph. |
|
""" |
|
|
|
|
|
plt.figure(figsize=(20, 20)) |
|
pos = nx.spring_layout(G) |
|
|
|
|
|
nx.draw_networkx_nodes(G, pos, node_color="skyblue", node_size=500, alpha=0.9) |
|
nx.draw_networkx_edges(G, pos, edge_color="black", width=5.0, alpha=0.8) |
|
nx.draw_networkx_labels(G, pos, font_size=8, font_color="black") |
|
|
|
|
|
edge_labels = nx.get_edge_attributes(G, "type") |
|
nx.draw_networkx_edge_labels( |
|
G, pos, edge_labels=edge_labels, font_size=5, font_color="red" |
|
) |
|
|
|
|
|
values = [partition.get(node) for node in G.nodes()] |
|
nx.draw_networkx_nodes( |
|
G, pos, node_size=500, cmap=plt.get_cmap("tab20"), node_color=values |
|
) |
|
|
|
|
|
plt.title("Knowledge Graph with Communities") |
|
plt.axis("off") |
|
plt.show() |
|
|
|
|
|
def get_communities(graph: nx.Graph) -> List[List[str]]: |
|
"""Get the communities in the graph. |
|
|
|
Args: |
|
graph (nx.Graph): The NetworkX graph. |
|
|
|
Returns: |
|
List[List[str]]: The list of communities. |
|
""" |
|
partition = get_partition(graph) |
|
plot_graph_with_communities(graph, partition) |
|
c = len(set(partition.values())) |
|
communities = [[k for k, v in partition.items() if v == j] for j in range(c)] |
|
return communities |
|
|
|
|
|
def summarize_communities( |
|
communities: List[List[str]], graph: nx.Graph, client: OpenAI |
|
) -> List[str]: |
|
""" |
|
Summarize the communities of entities and relationships. |
|
|
|
Args: |
|
communities (List[List[str]]): The list of communities. |
|
graph (nx.Graph): The NetworkX graph. |
|
client (OpenAI): The OpenAI client. |
|
|
|
Returns: |
|
List[str]: The list of community summaries. |
|
""" |
|
community_summaries = [] |
|
for index, community in tqdm(enumerate(communities), total=len(communities)): |
|
subgraph = graph.subgraph(community) |
|
nodes = list(subgraph.nodes) |
|
edges = list(subgraph.edges(data=True)) |
|
description = "Entities: " + ", ".join(nodes) + "\nRelationships: " |
|
relationships = [] |
|
for edge in edges: |
|
source, target, data = edge |
|
relation = data.get("type", "") |
|
relationships.append(f"{source} -> {data['type']} -> {target}") |
|
description += ", ".join(relationships) |
|
|
|
response = client.chat.completions.create( |
|
model="gpt-4o", |
|
messages=[ |
|
{ |
|
"role": "system", |
|
"content": """Summarize the following community of entities and relationships. |
|
""", |
|
}, |
|
{"role": "user", "content": description}, |
|
], |
|
) |
|
summary = response.choices[0].message.content.strip() |
|
community_summaries.append(summary) |
|
return community_summaries |
|
|