# Page header captured along with this file: "Spaces: Sleeping" (a hosting
# status banner, not part of the program).
import dash | |
from dash import dcc, html | |
from dash.dependencies import Input, Output | |
import plotly.graph_objects as go | |
import networkx as nx | |
# Directed graph rendered by the dashboard: nodes 1..9 plus a fixed edge list.
# Edge order is significant only for drawing order; (6, 6) is a self-loop.
G = nx.DiGraph()
G.add_nodes_from(range(1, 10))
G.add_edges_from([
    (1, 2), (1, 3), (2, 4), (3, 5), (1, 6),
    (4, 7), (4, 8), (5, 7), (5, 8),
    (7, 9), (8, 9), (6, 9),
    (6, 6),  # self-loop on node 6
])
# Dash application: a node selector above the network figure. The dropdown
# value is the node id that update_graph filters out of G.
app = dash.Dash(__name__)
app.layout = html.Div(
    children=[
        dcc.Dropdown(
            id='node-dropdown',
            options=[dict(label=f'Node {node}', value=node) for node in G.nodes],
            value=None,
            placeholder="Select a node to filter",
        ),
        dcc.Graph(id='network-graph'),
    ]
)
@app.callback(
    Output('network-graph', 'figure'),
    Input('node-dropdown', 'value'),
)
def update_graph(selected_node):
    """Return a Plotly figure of ``G`` with *selected_node* removed.

    Registered as the Dash callback wiring the ``node-dropdown`` value to the
    ``network-graph`` figure; without this registration the dropdown had no
    effect on the plot.

    Args:
        selected_node: the dropdown's current value -- a node id of ``G``, or
            ``None`` when nothing is selected, in which case the whole graph
            is drawn.

    Returns:
        A ``go.Figure`` with one line trace for the edges, one marker+text
        trace for the nodes, and one arrow annotation per directed edge.
        A self-loop such as (6, 6) yields a zero-length segment and arrow.
    """
    nodes_to_filter = [] if selected_node is None else [selected_node]
    filtered_graph = filter_nodes(G, nodes_to_filter)

    # spring_layout is randomised; a fixed seed keeps positions stable across
    # callbacks instead of reshuffling the nodes on every selection.
    pos = nx.spring_layout(filtered_graph, seed=42)

    nodes = list(filtered_graph.nodes)
    node_trace = go.Scatter(
        x=[pos[n][0] for n in nodes],
        y=[pos[n][1] for n in nodes],
        text=nodes,
        mode='markers+text',
        textposition='top center',
        marker=dict(size=20, color='LightSkyBlue', line=dict(width=2)),
    )

    # Gather coordinates in plain lists and build the edge trace once, rather
    # than repeatedly concatenating onto a trace property (quadratic, and
    # dependent on plotly's tuple coercion). A None after each segment breaks
    # the polyline so separate edges are not joined together.
    edge_x, edge_y, annotations = [], [], []
    for src, dst in filtered_graph.edges:
        x0, y0 = pos[src]
        x1, y1 = pos[dst]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        annotations.append(
            dict(
                ax=x0,
                ay=y0,
                axref='x',
                ayref='y',
                x=x1,
                y=y1,
                xref='x',
                yref='y',
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=2,
                arrowcolor='Gray',
                # Stop the arrowhead at the marker edge (size-20 marker plus
                # its 2px outline) so it is not drawn over the node centre.
                standoff=12,
            )
        )

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1.5, color='Gray'),
        hoverinfo='none',
        mode='lines',
    )

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            annotations=annotations,
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False),
        ),
    )
    return fig
def filter_nodes(graph, nodes_to_remove):
    """Return a copy of *graph* from which each node in *nodes_to_remove* is removed.

    Removals are applied to ``graph.copy()``, never to *graph* itself.
    Entries that are not (or are no longer) nodes of the copy are skipped
    rather than raising, so unknown ids and duplicates are harmless.
    """
    pruned = graph.copy()
    for victim in nodes_to_remove:
        if victim not in pruned:
            continue  # unknown id, or a duplicate that was already removed
        pruned.remove_node(victim)
    return pruned
if __name__ == '__main__':
    # Dash 3 removed ``app.run_server``; ``app.run`` accepts the same
    # debug/port/host keywords.
    # NOTE(review): debug=True while bound to 0.0.0.0 exposes the Flask
    # debugger on every network interface -- turn debug off for public deploys.
    app.run(debug=True, port=8050, host='0.0.0.0')