axa / graphRAG /get_communities.py
Mayara Ayat
Upload folder using huggingface_hub
f7ab812 verified
import networkx as nx
from typing import List, Dict
import community as community_louvain
import matplotlib.pyplot as plt
from openai import OpenAI
from tqdm import tqdm
def get_partition(graph: nx.Graph) -> List[int]:
"""Get the partition of the graph.
Args:
graph (nx.Graph): The NetworkX graph.
Returns:
List[int]: The partition of the graph.
"""
partition = community_louvain.best_partition(graph)
return partition
def plot_graph_with_communities(G: nx.Graph, partition: Dict[str, int]):
"""
Plot the NetworkX graph with communities.
Args:
G (nx.Graph): The NetworkX graph.
partition (Dict[str, int]): The partition of the graph.
"""
# Draw graph with communities
plt.figure(figsize=(20, 20))
pos = nx.spring_layout(G) # Use spring layout for better visualization
# Draw nodes and edges
nx.draw_networkx_nodes(G, pos, node_color="skyblue", node_size=500, alpha=0.9)
nx.draw_networkx_edges(G, pos, edge_color="black", width=5.0, alpha=0.8)
nx.draw_networkx_labels(G, pos, font_size=8, font_color="black")
# Add edge labels (for the `type` attribute)
edge_labels = nx.get_edge_attributes(G, "type")
nx.draw_networkx_edge_labels(
G, pos, edge_labels=edge_labels, font_size=5, font_color="red"
)
# Draw communities
values = [partition.get(node) for node in G.nodes()]
nx.draw_networkx_nodes(
G, pos, node_size=500, cmap=plt.get_cmap("tab20"), node_color=values
)
# Show the plot
plt.title("Knowledge Graph with Communities")
plt.axis("off") # Turn off the axes for better visualization
plt.show()
def get_communities(graph: nx.Graph) -> List[List[str]]:
"""Get the communities in the graph.
Args:
graph (nx.Graph): The NetworkX graph.
Returns:
List[List[str]]: The list of communities.
"""
partition = get_partition(graph)
plot_graph_with_communities(graph, partition)
c = len(set(partition.values()))
communities = [[k for k, v in partition.items() if v == j] for j in range(c)]
return communities
def summarize_communities(
communities: List[List[str]], graph: nx.Graph, client: OpenAI
) -> List[str]:
"""
Summarize the communities of entities and relationships.
Args:
communities (List[List[str]]): The list of communities.
graph (nx.Graph): The NetworkX graph.
client (OpenAI): The OpenAI client.
Returns:
List[str]: The list of community summaries.
"""
community_summaries = []
for index, community in tqdm(enumerate(communities), total=len(communities)):
subgraph = graph.subgraph(community)
nodes = list(subgraph.nodes)
edges = list(subgraph.edges(data=True))
description = "Entities: " + ", ".join(nodes) + "\nRelationships: "
relationships = []
for edge in edges:
source, target, data = edge
relation = data.get("type", "")
relationships.append(f"{source} -> {data['type']} -> {target}")
description += ", ".join(relationships)
response = client.chat.completions.create(
model="gpt-4o",
messages=[
{
"role": "system",
"content": """Summarize the following community of entities and relationships.
""",
},
{"role": "user", "content": description},
],
)
summary = response.choices[0].message.content.strip()
community_summaries.append(summary)
return community_summaries