eaglelandsonce commited on
Commit
a7a0009
1 Parent(s): 92d4dc7

Create pages/15_Graphs.py

Browse files
Files changed (1) hide show
  1. pages/15_Graphs.py +91 -0
pages/15_Graphs.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import tensorflow_gnn as tfgnn
4
+ from tensorflow_gnn.models import mt_albis
5
+ import networkx as nx
6
+ import matplotlib.pyplot as plt
7
+
8
+ # Define the model function
9
+ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
10
+ graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
11
+
12
+ # Encode input features
13
+ graph = tfgnn.keras.layers.MapFeatures(
14
+ node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(64)(node_set['features'])
15
+ )(graph)
16
+
17
+ # Message passing layers
18
+ for _ in range(2):
19
+ graph = mt_albis.MtAlbisGraphUpdate(
20
+ units=128, message_dim=64,
21
+ attention_type="none", simple_conv_reduce_type="mean",
22
+ normalization_type="layer", next_state_type="residual",
23
+ state_dropout_rate=0.2, l2_regularization=1e-5,
24
+ )(graph)
25
+
26
+ return tf.keras.Model(inputs, graph)
27
+
28
+ # Function to create a sample graph
29
+ def create_sample_graph():
30
+ num_nodes = 10
31
+ num_edges = 15
32
+
33
+ graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)
34
+
35
+ # Create a GraphTensor
36
+ node_features = tf.random.normal((num_nodes, 64))
37
+ edge_features = tf.random.normal((num_edges, 32))
38
+
39
+ graph_tensor = tfgnn.GraphTensor.from_pieces(
40
+ node_sets={
41
+ "papers": tfgnn.NodeSet.from_fields(
42
+ sizes=[num_nodes],
43
+ features={"features": node_features}
44
+ )
45
+ },
46
+ edge_sets={
47
+ "cites": tfgnn.EdgeSet.from_fields(
48
+ sizes=[num_edges],
49
+ adjacency=tfgnn.Adjacency.from_indices(
50
+ source=("papers", tf.cast(list(e[0] for e in graph.edges()), tf.int32)),
51
+ target=("papers", tf.cast(list(e[1] for e in graph.edges()), tf.int32))
52
+ ),
53
+ features={"features": edge_features}
54
+ )
55
+ }
56
+ )
57
+
58
+ return graph, graph_tensor
59
+
60
+ # Streamlit app
61
+ def main():
62
+ st.title("Graph Neural Network Architecture Visualization")
63
+
64
+ # Create sample graph
65
+ nx_graph, graph_tensor = create_sample_graph()
66
+
67
+ # Create and compile the model
68
+ model = model_fn(graph_tensor.spec)
69
+ model.compile(optimizer='adam', loss='binary_crossentropy')
70
+
71
+ # Display model summary
72
+ st.subheader("Model Summary")
73
+ model.summary(print_fn=lambda x: st.text(x))
74
+
75
+ # Visualize the graph
76
+ st.subheader("Sample Graph Visualization")
77
+ fig, ax = plt.subplots(figsize=(10, 8))
78
+ pos = nx.spring_layout(nx_graph)
79
+ nx.draw(nx_graph, pos, with_labels=True, node_color='lightblue',
80
+ node_size=500, arrowsize=20, ax=ax)
81
+ st.pyplot(fig)
82
+
83
+ # Display graph tensor info
84
+ st.subheader("Graph Tensor Information")
85
+ st.text(f"Number of nodes: {graph_tensor.node_sets['papers'].total_size}")
86
+ st.text(f"Number of edges: {graph_tensor.edge_sets['cites'].total_size}")
87
+ st.text(f"Node feature shape: {graph_tensor.node_sets['papers']['features'].shape}")
88
+ st.text(f"Edge feature shape: {graph_tensor.edge_sets['cites']['features'].shape}")
89
+
90
+ if __name__ == "__main__":
91
+ main()