import plotly.graph_objects as go
import networkx as nx


def _node_label(node_id, attrs):
    """Return the hover/display label for a node.

    Uses the ``value`` and ``type`` attributes when present and falls back
    to the node id, so nodes that only appear in a relationship (and thus
    carry no attributes) do not raise ``KeyError``.
    """
    value = attrs.get("value", node_id)
    if "type" in attrs:
        return f"{value} ({attrs['type']})"
    return str(value)


def create_cytoscape_plot(entities, relationships):
    """Build a Plotly figure that visualises a directed knowledge graph.

    Despite the name, the figure is rendered with Plotly scatter traces;
    node positions come from ``networkx.spring_layout``. The layout is not
    seeded, so positions differ between calls.

    Args:
        entities: Mapping of entity id -> dict of node attributes. The
            ``value`` and ``type`` keys, when present, are used to build the
            node label as ``"<value> (<type>)"``.
        relationships: Iterable of ``(source, relation, target)`` triples.
            Each becomes a directed edge ``source -> target`` labelled with
            ``relation``. Endpoints missing from ``entities`` are still drawn
            and labelled with their id. Because a ``DiGraph`` is used, a
            repeated ``(source, target)`` pair keeps only the last relation.

    Returns:
        plotly.graph_objects.Figure: three traces (edge lines, nodes, edge
        labels). Nodes are coloured by their total degree (incoming plus
        outgoing edges).
    """
    graph = nx.DiGraph()
    for entity_id, entity_data in entities.items():
        graph.add_node(entity_id, **entity_data)
    for source, relation, target in relationships:
        graph.add_edge(source, target, relation=relation)

    pos = nx.spring_layout(graph, k=0.5, iterations=50)

    # Collect coordinates in plain lists and hand them to the traces once;
    # growing trace properties with ``+=`` re-copies the data on every step.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for source, target, relation in graph.edges(data="relation"):
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # ``None`` breaks the line so consecutive edges are not joined.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        # The relation label sits at the midpoint of its edge.
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(relation)

    node_x, node_y, node_text, node_degree = [], [], [], []
    for node, attrs in graph.nodes(data=True):
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(_node_label(node, attrs))
        # ``neighbors()`` on a DiGraph only sees successors; use the total
        # degree so nodes with only incoming edges are not shown as isolated.
        node_degree.append(graph.degree[node])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color="#888"),
        hoverinfo="text",
        mode="lines",
        text=[],
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        hoverinfo="text",
        marker=dict(
            showscale=True,
            colorscale="Viridis",
            reversescale=True,
            color=node_degree,
            size=15,
            colorbar=dict(
                thickness=15,
                # ``titleside`` is a removed alias; use the nested title.
                title=dict(text="Node Connections", side="right"),
                xanchor="left",
            ),
            line=dict(width=2),
        ),
        text=node_text,
        textposition="top center",
    )

    # One text trace for all edge labels instead of one trace per edge.
    edge_label_trace = go.Scatter(
        x=label_x,
        y=label_y,
        mode="text",
        text=label_text,
        textposition="middle center",
        hoverinfo="none",
        showlegend=False,
        textfont=dict(size=8),
    )

    return go.Figure(
        data=[edge_trace, node_trace, edge_label_trace],
        layout=go.Layout(
            # ``titlefont_size`` is a removed alias; use the nested title.
            title=dict(text="Knowledge Graph", font=dict(size=16)),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            # Lock the aspect ratio to 1:1. Anchoring y to x is sufficient;
            # also anchoring x to y forms a loop that Plotly ignores.
            yaxis=dict(
                showgrid=False,
                zeroline=False,
                showticklabels=False,
                scaleanchor="x",
                scaleratio=1,
            ),
            width=800,
            height=600,
            # Line colour for shapes the user draws with the modebar tools.
            newshape=dict(line=dict(color="#009900")),
        ),
    )