Remsky's picture
Super-squash branch 'main' using huggingface_hub
4289090 verified
raw
history blame
3.04 kB
import plotly.graph_objects as go
import networkx as nx
import plotly.graph_objects as go
import networkx as nx
def create_cytoscape_plot(entities, relationships, *, seed=None):
    """Build an interactive Plotly figure of a directed knowledge graph.

    Despite the name, the figure is rendered with Plotly scatter traces; node
    positions come from NetworkX's spring layout.

    Args:
        entities: Mapping of node id -> attribute dict. Each dict is stored as
            node attributes; its ``"value"`` and ``"type"`` keys form the node
            label shown on the plot.
        relationships: Iterable of ``(source, relation, target)`` triples. Each
            becomes a directed edge labelled with ``relation``. Endpoints that
            are not in ``entities`` are still drawn, labelled by their id.
        seed: Optional random seed forwarded to ``nx.spring_layout`` to make
            the layout reproducible. ``None`` (default) gives a random layout.

    Returns:
        A ``plotly.graph_objects.Figure`` with three traces: edges, nodes
        (coloured by number of connections), and edge labels.
    """
    G = nx.DiGraph()  # directed: each relationship has a source and a target
    for entity_id, entity_data in entities.items():
        G.add_node(entity_id, **entity_data)
    for source, relation, target in relationships:
        if G.has_edge(source, target):
            # A DiGraph holds one edge per ordered pair; append the label
            # instead of letting a later relation silently overwrite it.
            edge_data = G.edges[source, target]
            edge_data["relation"] = f"{edge_data['relation']}, {relation}"
        else:
            G.add_edge(source, target, relation=relation)

    pos = nx.spring_layout(G, k=0.5, iterations=50, seed=seed)

    # Collect coordinates in plain lists and create each trace once:
    # appending to a trace property re-validates the whole array on every
    # assignment, which is quadratic in the number of edges/nodes.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for source, target, relation in G.edges(data="relation"):
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # The trailing None breaks the line so each edge is its own segment.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        # The relation label sits at the segment's midpoint.
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(relation)

    node_x, node_y, node_text, node_connections = [], [], [], []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        entity = entities.get(node)
        if entity is None:
            # Node was created implicitly by a relationship endpoint that
            # has no entry in ``entities``; fall back to its id.
            node_text.append(str(node))
        else:
            node_text.append(f"{entity['value']} ({entity['type']})")
        # Count incoming and outgoing edges; G.neighbors() on a DiGraph
        # returns successors only, which would show pure targets as 0.
        node_connections.append(G.degree(node))

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(width=1, color="#888"),
        hoverinfo="none",
    )
    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        text=node_text,
        textposition="top center",
        hoverinfo="text",
        marker=dict(
            showscale=True,
            colorscale="Viridis",
            reversescale=True,
            color=node_connections,
            size=15,
            colorbar=dict(
                thickness=15,
                # ``title=dict(...)`` replaces ``titleside``, which was
                # removed in Plotly 6.
                title=dict(text="Node Connections", side="right"),
                xanchor="left",
            ),
            line_width=2,
        ),
    )
    # One text-only trace for all edge labels instead of one trace per edge.
    label_trace = go.Scatter(
        x=label_x,
        y=label_y,
        mode="text",
        text=label_text,
        textposition="middle center",
        hoverinfo="none",
        showlegend=False,
        textfont=dict(size=8),
    )

    return go.Figure(
        data=[edge_trace, node_trace, label_trace],
        layout=go.Layout(
            # ``title.font`` replaces ``titlefont``, which was removed in Plotly 6.
            title=dict(text="Knowledge Graph", font=dict(size=16)),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            # Lock a 1:1 aspect ratio so the layout is not distorted. A single
            # anchor is enough; a y->x plus x->y pair forms a loop whose last
            # link Plotly ignores.
            yaxis=dict(
                showgrid=False,
                zeroline=False,
                showticklabels=False,
                scaleanchor="x",
                scaleratio=1,
            ),
            width=800,
            height=600,
            # Line colour for shapes the user draws; only relevant if drawing
            # tools are enabled in the figure's config.
            newshape=dict(line_color="#009900"),
        ),
    )