"""Dash app that draws a small directed graph and lets the user hide one node.

A dropdown lists every node of the module-level graph ``G``.  Selecting a
node removes it (and its incident edges) from a copy of ``G`` and redraws
the remaining graph with Plotly; clearing the dropdown shows the full graph.
"""

import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
import networkx as nx

# Fixed seed for the force-directed layout.  Without it every callback
# produces a fresh random layout, so the same selection never renders the
# same way twice.
LAYOUT_SEED = 42

# Create a directed graph
G = nx.DiGraph()

# Add nodes
G.add_nodes_from([1, 2, 3, 4, 5, 6, 7, 8, 9])

# Add directed edges
G.add_edges_from([
    (1, 2), (1, 3), (2, 4), (3, 5), (1, 6),
    (4, 7), (4, 8), (5, 7), (5, 8),
    (7, 9), (8, 9), (6, 9), (6, 6),
])

# Initialize the Dash app
app = dash.Dash(__name__)

app.layout = html.Div([
    dcc.Dropdown(
        id='node-dropdown',
        options=[{'label': f'Node {i}', 'value': i} for i in G.nodes],
        value=None,
        placeholder="Select a node to filter",
    ),
    dcc.Graph(id='network-graph'),
])


def filter_nodes(graph, nodes_to_remove):
    """Return a copy of *graph* without the nodes in *nodes_to_remove*.

    The input graph is never modified.  Nodes that are not present in the
    graph are ignored.

    Args:
        graph: The NetworkX graph to filter.
        nodes_to_remove: Iterable of node identifiers to drop.

    Returns:
        A new graph of the same type with the given nodes (and every edge
        touching them) removed.
    """
    filtered_graph = graph.copy()
    # remove_nodes_from silently skips nodes that are not in the graph,
    # which matches the explicit membership test it replaces.
    filtered_graph.remove_nodes_from(nodes_to_remove)
    return filtered_graph


def _build_node_trace(graph, pos):
    """Build the marker+label scatter trace for every node of *graph*."""
    nodes = list(graph.nodes)
    return go.Scatter(
        x=[pos[n][0] for n in nodes],
        y=[pos[n][1] for n in nodes],
        text=nodes,
        mode='markers+text',
        textposition='top center',
        marker=dict(size=20, color='LightSkyBlue', line=dict(width=2)),
    )


def _build_edge_trace_and_arrows(graph, pos):
    """Build the line trace and the arrow annotations for the edges of *graph*.

    Returns:
        A ``(edge_trace, annotations)`` tuple: one scatter trace holding all
        edge segments, and one arrow annotation dict per edge.
    """
    edge_x = []
    edge_y = []
    annotations = []
    # NOTE(review): a self-loop such as (6, 6) yields a zero-length segment
    # and a zero-length arrow here - confirm how that should be rendered.
    for source, target in graph.edges:
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # A trailing None breaks the line so consecutive edges are not joined.
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))
        annotations.append(
            dict(
                ax=x0, ay=y0, axref='x', ayref='y',
                x=x1, y=y1, xref='x', yref='y',
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=2,
                arrowcolor='Gray',
            )
        )

    # Coordinates are collected in plain lists and handed to the trace once,
    # instead of re-assigning a growing tuple to the trace for every edge.
    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1.5, color='Gray'),
        hoverinfo='none',
        mode='lines',
    )
    return edge_trace, annotations


@app.callback(
    Output('network-graph', 'figure'),
    Input('node-dropdown', 'value')
)
def update_graph(selected_node):
    """Redraw the network with *selected_node* removed.

    Args:
        selected_node: Value of the dropdown; ``None`` when nothing is
            selected, in which case the whole graph is drawn.

    Returns:
        A Plotly figure with an edge trace, a node trace and one arrow
        annotation per edge.
    """
    nodes_to_filter = [] if selected_node is None else [selected_node]
    filtered_graph = filter_nodes(G, nodes_to_filter)

    # Seeded so that a given selection always yields the same picture.
    pos = nx.spring_layout(filtered_graph, seed=LAYOUT_SEED)

    node_trace = _build_node_trace(filtered_graph, pos)
    edge_trace, annotations = _build_edge_trace_and_arrows(filtered_graph, pos)

    return go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            annotations=annotations,
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False),
        ),
    )


if __name__ == '__main__':
    # Newer Dash releases expose `app.run` as the replacement for the legacy
    # `app.run_server`; prefer it when available and fall back otherwise so
    # the script starts on both old and new versions.
    run_app = getattr(app, 'run', None) or app.run_server
    # NOTE(review): debug=True together with host='0.0.0.0' makes the debug
    # tooling reachable from other machines - confirm this is only ever run
    # on a trusted network or inside a container.
    run_app(debug=True, port=8050, host='0.0.0.0')