import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
import networkx as nx
# Directed graph shown by the app: nodes numbered 1 through 9, connected by
# the directed (source, target) pairs listed below.
# NOTE(review): (6, 6) is a self-loop — confirm it is intentional.
G = nx.DiGraph()
G.add_nodes_from(range(1, 10))
G.add_edges_from(
    [
        (1, 2), (1, 3), (2, 4), (3, 5), (1, 6),
        (4, 7), (4, 8), (5, 7), (5, 8),
        (7, 9), (8, 9), (6, 9), (6, 6),
    ]
)
# Dash application: a single-select node dropdown above the network figure.
# The dropdown's value drives the `update_graph` callback.
app = dash.Dash(__name__)
app.layout = html.Div(
    children=[
        dcc.Dropdown(
            id='node-dropdown',
            # One entry per graph node; the selected value is the node itself.
            options=[
                {'label': f'Node {node}', 'value': node} for node in G.nodes
            ],
            value=None,
            placeholder="Select a node to filter",
        ),
        dcc.Graph(id='network-graph'),
    ]
)
# Fixed seed for the force-directed layout, so the same graph is always drawn
# the same way instead of reshuffling on every callback / page load.
_LAYOUT_SEED = 42
# Pixels by which each arrow stops short of its endpoint nodes. Node markers
# are 20 px across with a 2 px outline, so ~12 px clears the marker.
_ARROW_STANDOFF = 12


def _build_node_trace(graph, pos):
    """Return a scatter trace drawing each node of *graph* as a labelled marker.

    :param graph: the graph whose nodes are drawn.
    :param pos: mapping of node -> (x, y) position, as returned by a
        networkx layout function.
    """
    nodes = list(graph.nodes)
    return go.Scatter(
        x=[pos[n][0] for n in nodes],
        y=[pos[n][1] for n in nodes],
        text=nodes,
        mode='markers+text',
        textposition='top center',
        marker=dict(size=20, color='LightSkyBlue', line=dict(width=2)),
    )


def _build_edge_artists(graph, pos):
    """Return ``(edge_trace, annotations)`` drawing the edges of *graph*.

    The segment coordinates are collected into plain lists and the trace is
    built once. Growing a Plotly trace property edge by edge re-validates the
    whole array each time. Every edge except a self-loop also gets an arrow
    annotation from source to target.

    :param graph: the directed graph whose edges are drawn.
    :param pos: mapping of node -> (x, y) position.
    """
    xs, ys, annotations = [], [], []
    for src, dst in graph.edges:
        x0, y0 = pos[src]
        x1, y1 = pos[dst]
        # ``None`` breaks the polyline so each edge is its own segment.
        xs.extend((x0, x1, None))
        ys.extend((y0, y1, None))
        if src == dst:
            # A self-loop has zero length, so its arrow would have no
            # direction and render as a stray arrowhead; skip it.
            continue
        annotations.append(
            dict(
                ax=x0,
                ay=y0,
                axref='x',
                ayref='y',
                x=x1,
                y=y1,
                xref='x',
                yref='y',
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=2,
                arrowcolor='Gray',
                # Keep the arrow outside both node markers, so it is not
                # drawn on top of them.
                standoff=_ARROW_STANDOFF,
                startstandoff=_ARROW_STANDOFF,
            )
        )
    edge_trace = go.Scatter(
        x=xs,
        y=ys,
        line=dict(width=1.5, color='Gray'),
        hoverinfo='none',
        mode='lines',
    )
    return edge_trace, annotations


@app.callback(
    Output('network-graph', 'figure'),
    Input('node-dropdown', 'value')
)
def update_graph(selected_node):
    """Redraw the network with *selected_node* removed, if one is chosen.

    :param selected_node: current value of the ``node-dropdown`` component.
        ``None`` (nothing selected) draws the full graph.
    :returns: a ``plotly.graph_objects.Figure`` containing the edge lines,
        direction arrows (as layout annotations) and labelled node markers.
    """
    nodes_to_filter = [] if selected_node is None else [selected_node]
    filtered_graph = filter_nodes(G, nodes_to_filter)
    # Seeded so the same filtered graph always gets the same layout.
    pos = nx.spring_layout(filtered_graph, seed=_LAYOUT_SEED)

    edge_trace, annotations = _build_edge_artists(filtered_graph, pos)
    node_trace = _build_node_trace(filtered_graph, pos)

    # Spring-layout coordinates carry no meaning, so hide grid, zero line and
    # tick labels.
    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    return go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            annotations=annotations,
            xaxis=hidden_axis,
            yaxis=hidden_axis,
        ),
    )
def filter_nodes(graph, nodes_to_remove):
    """Return a copy of *graph* with the listed nodes deleted.

    The input graph is not modified. Entries of *nodes_to_remove* that are not
    (or are no longer) nodes of the copy are ignored.

    :param graph: graph to copy; must support ``copy()``, ``in`` and
        ``remove_node()`` (e.g. a networkx graph).
    :param nodes_to_remove: iterable of nodes to drop.
    :returns: the pruned copy.
    """
    pruned = graph.copy()
    # Lazy generator: membership is re-checked after each previous removal.
    for doomed in (node for node in nodes_to_remove if node in pruned):
        pruned.remove_node(doomed)
    return pruned
if __name__ == '__main__':
    # ``Dash.run_server`` was deprecated in Dash 2.x and removed in Dash 3.0;
    # ``Dash.run`` takes the same arguments.
    # NOTE(security): debug=True combined with host='0.0.0.0' exposes the
    # Flask/Werkzeug debug mode to every host that can reach this port. Use
    # this combination only on a trusted development machine/network.
    app.run(debug=True, port=8050, host='0.0.0.0')