|
import plotly.graph_objects as go |
|
import networkx as nx |
|
|
|
import plotly.graph_objects as go |
|
import networkx as nx |
|
|
|
def create_cytoscape_plot(entities, relationships, *, seed=None):
    """Render a knowledge graph as an interactive Plotly figure.

    Despite the name, this does not use Cytoscape: the graph is built as a
    NetworkX ``DiGraph``, laid out with ``nx.spring_layout`` and drawn with
    Plotly. Edges are drawn as plain line segments (no arrowheads), with the
    relation name printed at each edge's midpoint.

    Args:
        entities: Mapping of entity id -> dict of node attributes. Each dict
            is expected to carry ``"value"`` and ``"type"`` keys, which form
            the node label ``"<value> (<type>)"``.
        relationships: Iterable of ``(source, relation, target)`` triples.
            Endpoints missing from *entities* are still drawn and are
            labelled with their id. Several relations between the same
            ``(source, target)`` pair are shown as one comma-joined label.
        seed: Optional random seed forwarded to ``nx.spring_layout`` for
            reproducible node positions. ``None`` (default) keeps the layout
            random, as before.

    Returns:
        A ``plotly.graph_objects.Figure`` containing one trace of edge lines,
        one trace of nodes (coloured by their number of connections, i.e.
        in-degree + out-degree) and one trace of edge relation labels.
    """
    graph = nx.DiGraph()
    for entity_id, entity_data in entities.items():
        graph.add_node(entity_id, **entity_data)

    # A DiGraph holds a single edge per (source, target) pair, so accumulate
    # every relation for that pair instead of letting later ones overwrite
    # earlier ones.
    for source, relation, target in relationships:
        if graph.has_edge(source, target):
            graph.edges[source, target]["relation"] += f", {relation}"
        else:
            graph.add_edge(source, target, relation=str(relation))

    pos = nx.spring_layout(graph, k=0.5, iterations=50, seed=seed)

    # Build plain lists and create each trace once: appending to a trace
    # property re-validates the whole array on every call (quadratic).
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for source, target, relation in graph.edges(data="relation"):
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # The trailing None breaks the polyline so each edge is a separate
        # segment within the single edge trace.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(relation)

    node_x, node_y, node_text, node_connections = [], [], [], []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        # Nodes created only via relationships have no attributes; fall back
        # to the node id so they still get a label instead of a KeyError.
        attrs = graph.nodes[node]
        value = attrs.get("value", node)
        node_type = attrs.get("type")
        node_text.append(
            f"{value} ({node_type})" if node_type is not None else f"{value}"
        )
        # degree = in + out on a DiGraph; neighbors() would count only
        # outgoing edges and under-report "connections".
        node_connections.append(graph.degree(node))

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color="#888"),
        # Edge lines carry no hover text, so disable hover on them.
        hoverinfo="none",
        mode="lines",
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        hoverinfo="text",
        text=node_text,
        textposition="top center",
        marker=dict(
            showscale=True,
            colorscale="Viridis",
            reversescale=True,
            color=node_connections,
            size=15,
            colorbar=dict(
                thickness=15,
                # title.side replaces the `titleside` shortcut removed in
                # plotly.py 6.
                title=dict(text="Node Connections", side="right"),
                xanchor="left",
            ),
            line_width=2,
        ),
    )

    # One text-only trace for all relation labels instead of one per edge.
    label_trace = go.Scatter(
        x=label_x,
        y=label_y,
        mode="text",
        text=label_text,
        textposition="middle center",
        hoverinfo="none",
        showlegend=False,
        textfont=dict(size=8),
    )

    layout = go.Layout(
        # title.font replaces the `titlefont` shortcut removed in plotly.py 6.
        title=dict(text="Knowledge Graph", font=dict(size=16)),
        showlegend=False,
        hovermode="closest",
        margin=dict(b=20, l=5, r=5, t=40),
        newshape=dict(line_color="#009900"),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        # Anchoring y to x alone locks a 1:1 aspect ratio; anchoring x to y
        # as well would form a loop that Plotly ignores as redundant.
        yaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            scaleanchor="x",
            scaleratio=1,
        ),
        width=800,
        height=600,
    )

    # Labels are drawn last so they sit on top of edges and nodes.
    return go.Figure(data=[edge_trace, node_trace, label_trace], layout=layout)