axa / graphRAG /graph_nx.py
Mayara Ayat
Upload folder using huggingface_hub
f7ab812 verified
import matplotlib.pyplot as plt
import networkx as nx
from langchain_community.graphs.graph_document import GraphDocument
def build_nx_graph(final_graph_document: GraphDocument) -> nx.Graph:
    """
    Build a NetworkX graph from a GraphDocument.

    Every node id in the document becomes a graph node, and every
    relationship becomes an edge carrying the relationship type in its
    ``type`` attribute.

    Args:
        final_graph_document (GraphDocument): The final graph document.

    Returns:
        nx.Graph: The NetworkX graph.

    Note:
        ``nx.Graph`` is undirected and holds at most one edge per node pair,
        so relationship direction is not kept and, when several relationships
        link the same two nodes, only the last ``type`` seen is retained.
    """
    G = nx.Graph()
    # Register nodes first so that nodes without any relationship still
    # appear in the graph; duplicate ids are collapsed by NetworkX itself.
    G.add_nodes_from(node.id for node in final_graph_document.nodes)
    # Endpoints that are missing from the node list are added implicitly.
    G.add_edges_from(
        (relation.source.id, relation.target.id, {"type": relation.type})
        for relation in final_graph_document.relationships
    )
    return G
def plot_nx_graph(G: nx.Graph):
    """
    Render the NetworkX graph in a matplotlib window.

    Nodes are drawn with their labels, edges are drawn between them, and each
    edge is annotated with the value of its ``type`` attribute.

    Args:
        G (nx.Graph): The NetworkX graph to display.
    """
    # Open the canvas, then compute a spring layout shared by all draw calls
    plt.figure(figsize=(20, 20))
    layout = nx.spring_layout(G)

    # Nodes, edges and node labels
    nx.draw_networkx_nodes(
        G, layout, node_color="skyblue", node_size=500, alpha=0.9
    )
    nx.draw_networkx_edges(G, layout, edge_color="black", width=5.0, alpha=0.8)
    nx.draw_networkx_labels(G, layout, font_size=8, font_color="black")

    # Annotate each edge with its `type` attribute
    relation_types = nx.get_edge_attributes(G, "type")
    nx.draw_networkx_edge_labels(
        G,
        layout,
        edge_labels=relation_types,
        font_size=5,
        font_color="red",
    )

    # Finish the figure: title, hidden axes, then display
    plt.title("Knowledge Graph")
    plt.axis("off")
    plt.show()