File size: 3,035 Bytes
4289090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import plotly.graph_objects as go
import networkx as nx

import plotly.graph_objects as go
import networkx as nx

def create_cytoscape_plot(entities, relationships, *, seed=None):
    """Build an interactive Plotly figure of a knowledge graph.

    Despite the name, the figure is rendered with Plotly (not Cytoscape).
    Nodes are positioned with a NetworkX spring layout and drawn as markers
    coloured by their number of connections (in + out degree). Edges are drawn
    as plain line segments, and each edge is labelled at its midpoint with its
    relation name. Edge direction is stored in the graph but not drawn.

    Args:
        entities: Mapping of entity id -> attribute dict. Each attribute dict
            is stored on the corresponding graph node. Its ``"value"`` and
            ``"type"`` keys, when present, form the node label
            ``"<value> (<type>)"``.
        relationships: Iterable of ``(source, relation, target)`` triples.
            Endpoints that are missing from ``entities`` are still drawn and
            are labelled by their id.
        seed: Optional random seed forwarded to ``nx.spring_layout`` so the
            layout is reproducible. ``None`` (the default) keeps a
            non-deterministic layout.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    G = nx.DiGraph()  # Directed graph: edges go source -> target.
    for entity_id, entity_data in entities.items():
        G.add_node(entity_id, **entity_data)
    # A DiGraph stores at most one edge per (source, target) pair, so a
    # repeated pair keeps the relation of its last occurrence.
    for source, relation, target in relationships:
        G.add_edge(source, target, relation=relation)

    pos = nx.spring_layout(G, k=0.5, iterations=50, seed=seed)

    # Collect coordinates in plain lists and build each trace once. Growing a
    # trace property with ``+=`` re-validates and copies it on every append.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for source, target, relation in G.edges(data="relation"):
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # The trailing None breaks the polyline so each edge is its own segment.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        # The relation label sits at the midpoint of the edge.
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(relation)

    node_x, node_y, node_text, node_degree = [], [], [], []
    for node, data in G.nodes(data=True):
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        # Nodes created implicitly by add_edge carry no attributes; fall back
        # to the node id instead of raising KeyError.
        value = data.get("value", node)
        kind = data.get("type")
        node_text.append(f"{value} ({kind})" if kind is not None else str(value))
        # degree = in-degree + out-degree. neighbors() on a DiGraph would
        # count only outgoing edges.
        node_degree.append(G.degree(node))

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color="#888"),
        hoverinfo="none",
        mode="lines",
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        hoverinfo="text",
        text=node_text,
        textposition="top center",
        marker=dict(
            showscale=True,
            colorscale="Viridis",
            reversescale=True,
            color=node_degree,
            size=15,
            colorbar=dict(
                thickness=15,
                # title.side replaces the removed colorbar.titleside attribute.
                title=dict(text="Node Connections", side="right"),
                xanchor="left",
            ),
            line=dict(width=2),
        ),
    )

    # All edge labels go into a single text trace rather than one trace per edge.
    label_trace = go.Scatter(
        x=label_x,
        y=label_y,
        mode="text",
        text=label_text,
        textposition="middle center",
        hoverinfo="none",
        showlegend=False,
        textfont=dict(size=8),
    )

    layout = go.Layout(
        # title.font replaces the removed layout.titlefont attribute.
        title=dict(text="Knowledge Graph", font=dict(size=16)),
        showlegend=False,
        hovermode="closest",
        margin=dict(b=20, l=5, r=5, t=40),
        # Lock a 1:1 aspect ratio so the spring layout is not stretched. Only
        # one axis is anchored, because anchoring both to each other is a
        # circular constraint that plotly.js ignores with a warning.
        xaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            scaleanchor="y",
            scaleratio=1,
        ),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=800,
        height=600,
        # Style for shapes the user draws with the modebar drawing tools.
        newshape=dict(line=dict(color="#009900")),
    )

    return go.Figure(data=[edge_trace, node_trace, label_trace], layout=layout)