|
import matplotlib.pyplot as plt |
|
import networkx as nx |
|
from langchain_community.graphs.graph_document import GraphDocument |
|
|
|
|
|
def build_nx_graph(final_graph_document: GraphDocument) -> nx.Graph: |
|
""" |
|
Build a NetworkX graph from a GraphDocument. |
|
|
|
Args: |
|
final_graph_document (GraphDocument): The final graph document. |
|
|
|
Returns: |
|
nx.Graph: The NetworkX graph. |
|
""" |
|
nodes = [node.id for node in final_graph_document.nodes] |
|
unique_nodes = set(nodes) |
|
relationships = final_graph_document.relationships |
|
G = nx.Graph() |
|
G.add_nodes_from(nodes) |
|
for relation in relationships: |
|
G.add_edge(relation.source.id, relation.target.id, type=relation.type) |
|
return G |
|
|
|
|
|
def plot_nx_graph(G: nx.Graph): |
|
""" |
|
Plot the NetworkX graph. |
|
|
|
Args: |
|
G (nx.Graph): The NetworkX graph. |
|
""" |
|
pos = nx.spring_layout(G) |
|
plt.figure(figsize=(20, 20)) |
|
|
|
nx.draw_networkx_nodes(G, pos, node_color="skyblue", node_size=500, alpha=0.9) |
|
nx.draw_networkx_edges(G, pos, edge_color="black", width=5.0, alpha=0.8) |
|
nx.draw_networkx_labels(G, pos, font_size=8, font_color="black") |
|
|
|
|
|
edge_labels = nx.get_edge_attributes(G, "type") |
|
nx.draw_networkx_edge_labels( |
|
G, pos, edge_labels=edge_labels, font_size=5, font_color="red" |
|
) |
|
|
|
|
|
plt.title("Knowledge Graph") |
|
plt.axis("off") |
|
plt.show() |
|
|