File size: 1,512 Bytes
f7ab812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import matplotlib.pyplot as plt
import networkx as nx
from langchain_community.graphs.graph_document import GraphDocument


def build_nx_graph(final_graph_document: GraphDocument) -> nx.Graph:
    """
    Build a NetworkX graph from a GraphDocument.

    Every node id in the document becomes a graph node, and every
    relationship becomes an edge between its source id and target id, with
    the relationship type stored in the edge's ``type`` attribute.

    Note:
        The result is an undirected ``nx.Graph``: relationship direction is
        not kept, and when several relationships connect the same pair of
        nodes only the ``type`` of the last one is retained.

    Args:
        final_graph_document (GraphDocument): The final graph document.

    Returns:
        nx.Graph: The NetworkX graph.
    """
    G = nx.Graph()
    # Add nodes explicitly so nodes without any relationship still appear.
    G.add_nodes_from(node.id for node in final_graph_document.nodes)
    for relation in final_graph_document.relationships:
        # add_edge creates any endpoint that is not already in the graph.
        G.add_edge(relation.source.id, relation.target.id, type=relation.type)
    return G


def plot_nx_graph(G: nx.Graph):
    """
    Render the NetworkX graph with matplotlib and display it.

    Node positions come from a spring layout. Nodes are labelled with their
    ids and each edge is annotated with its `type` attribute.

    Args:
        G (nx.Graph): The NetworkX graph.
    """
    layout = nx.spring_layout(G)
    plt.figure(figsize=(20, 20))

    # Graph skeleton: nodes first, then edges, then the node labels
    nx.draw_networkx_nodes(
        G, pos=layout, node_color="skyblue", node_size=500, alpha=0.9
    )
    nx.draw_networkx_edges(
        G, pos=layout, edge_color="black", width=5.0, alpha=0.8
    )
    nx.draw_networkx_labels(G, pos=layout, font_size=8, font_color="black")

    # Annotate every edge with the value of its `type` attribute
    nx.draw_networkx_edge_labels(
        G,
        pos=layout,
        edge_labels=nx.get_edge_attributes(G, "type"),
        font_size=5,
        font_color="red",
    )

    # Final touches: title, hide the axes for a cleaner picture, then display
    plt.title("Knowledge Graph")
    plt.axis("off")
    plt.show()