File size: 2,881 Bytes
1089f07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
import networkx as nx

# Demo directed graph rendered by the app; each of its nodes also becomes a
# dropdown option. Edge insertion order below matches the original listing.
G = nx.DiGraph()
G.add_nodes_from(range(1, 10))  # nodes 1..9
G.add_edges_from([
    (1, 2), (1, 3),
    (2, 4),
    (3, 5),
    (1, 6),
    (4, 7), (4, 8),
    (5, 7), (5, 8),
    (7, 9), (8, 9), (6, 9),
    (6, 6),  # self-loop on node 6
])

# Dash application: a node picker above the graph view. The callback on
# 'node-dropdown' redraws 'network-graph' whenever the selection changes.
app = dash.Dash(__name__)

app.layout = html.Div(
    children=[
        dcc.Dropdown(
            id='node-dropdown',
            # One entry per node of G; the option value is the node itself.
            options=[dict(label=f'Node {node}', value=node) for node in G.nodes],
            placeholder="Select a node to filter",
            value=None,  # start with nothing selected
        ),
        dcc.Graph(id='network-graph'),
    ]
)

@app.callback(
    Output('network-graph', 'figure'),
    Input('node-dropdown', 'value')
)
def update_graph(selected_node):
    """Build the network figure, hiding the node chosen in the dropdown.

    Args:
        selected_node: Current value of ``node-dropdown``: one of ``G``'s
            nodes, or ``None`` when the selection is empty (nothing hidden).

    Returns:
        A ``go.Figure`` holding a line trace for the edges, a marker/text
        trace for the nodes, and one arrow annotation per edge so the edge
        direction is visible.
    """
    nodes_to_filter = [] if selected_node is None else [selected_node]
    filtered_graph = filter_nodes(G, nodes_to_filter)

    # Lay out the full graph with a fixed seed. Unseeded, spring_layout starts
    # from random positions, so laying out the filtered graph on each call
    # reshuffled the whole drawing on every dropdown change (even when
    # re-selecting the same node). Positions computed on the full graph also
    # keep the remaining nodes in place when one of them is hidden.
    pos = nx.spring_layout(G, seed=42)

    nodes = list(filtered_graph.nodes)
    node_trace = go.Scatter(
        x=[pos[n][0] for n in nodes],
        y=[pos[n][1] for n in nodes],
        text=nodes,
        mode='markers+text',
        textposition='top center',
        marker=dict(size=20, color='LightSkyBlue', line=dict(width=2)),
    )

    # Gather all segment coordinates in plain lists and create the edge trace
    # once. Doing `trace['x'] += ...` per edge rebuilds the whole coordinate
    # sequence and re-assigns it through Plotly's property setter every time,
    # which is quadratic in the number of edges. A None entry breaks the line
    # so consecutive edges are not joined together.
    edge_x = []
    edge_y = []
    annotations = []
    for source, target in filtered_graph.edges:
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))
        # Arrow from the source position (ax, ay) to the target (x, y), both
        # in data coordinates. NOTE(review): self-loops such as (6, 6) have
        # identical endpoints, so their segment and arrow have zero length.
        annotations.append(
            dict(
                ax=x0,
                ay=y0,
                axref='x',
                ayref='y',
                x=x1,
                y=y1,
                xref='x',
                yref='y',
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=2,
                arrowcolor='Gray',
            )
        )

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1.5, color='Gray'),
        hoverinfo='none',
        mode='lines',
    )

    # The edge trace is listed first so node markers are drawn over the lines.
    return go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            annotations=annotations,
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False),
        ),
    )

def filter_nodes(graph, nodes_to_remove):
    """Return a copy of *graph* with the given nodes removed.

    The input graph is never modified; all removal happens on a copy.
    Values in *nodes_to_remove* that are not nodes of the graph are skipped.

    Args:
        graph: NetworkX graph to copy.
        nodes_to_remove: Iterable of nodes to drop from the copy.

    Returns:
        The pruned copy of *graph*.
    """
    pruned = graph.copy()
    # Keep only values the copy actually contains before removing them.
    pruned.remove_nodes_from([node for node in nodes_to_remove if node in pruned])
    return pruned

if __name__ == '__main__':
    # `run_server` is deprecated in favour of `Dash.run` (and removed in
    # Dash 3); fall back to it only on old Dash releases that lack `run`.
    # NOTE(review): debug=True while listening on 0.0.0.0 makes the
    # development server (and its debugger) reachable from any host that can
    # reach port 8050 — restrict host/debug before running anywhere shared.
    run_app = getattr(app, 'run', None) or app.run_server
    run_app(debug=True, port=8050, host='0.0.0.0')