eaglelandsonce commited on
Commit
8e03eb3
1 Parent(s): eaa1227

Update pages/15_Graphs.py

Browse files
Files changed (1) hide show
  1. pages/15_Graphs.py +21 -24
pages/15_Graphs.py CHANGED
@@ -12,39 +12,36 @@ os.environ['TF_USE_LEGACY_KERAS'] = '1'
12
  # Define the model function
13
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
14
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
15
-
16
- # Encode input features
17
  graph = tfgnn.keras.layers.MapFeatures(
18
  node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(64)(node_set['features'])
19
  )(graph)
20
-
21
- # Message passing layers
22
  for _ in range(2):
 
23
  graph = mt_albis.MtAlbisGraphUpdate(
24
- units=128,
25
- message_dim=64,
26
- receiver_tag=tfgnn.NODES, # Ensure it is correctly set to tfgnn.NODES
27
- attention_type="none",
28
- simple_conv_reduce_type="mean",
29
- normalization_type="layer",
30
- next_state_type="residual",
31
- state_dropout_rate=0.2,
32
- l2_regularization=1e-5,
33
  )(graph)
34
-
35
  return tf.keras.Model(inputs, graph)
36
 
37
  # Function to create a sample graph
38
  def create_sample_graph():
39
  num_nodes = 10
40
  num_edges = 15
41
-
42
  graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)
43
-
44
  # Create a GraphTensor
45
  node_features = tf.random.normal((num_nodes, 64))
46
  edge_features = tf.random.normal((num_edges, 32))
47
-
48
  graph_tensor = tfgnn.GraphTensor.from_pieces(
49
  node_sets={
50
  "papers": tfgnn.NodeSet.from_fields(
@@ -63,32 +60,32 @@ def create_sample_graph():
63
  )
64
  }
65
  )
66
-
67
  return graph, graph_tensor
68
 
69
  # Streamlit app
70
  def main():
71
  st.title("Graph Neural Network Architecture Visualization")
72
-
73
  # Create sample graph
74
  nx_graph, graph_tensor = create_sample_graph()
75
-
76
  # Create and compile the model
77
  model = model_fn(graph_tensor.spec)
78
  model.compile(optimizer='adam', loss='binary_crossentropy')
79
-
80
  # Display model summary
81
  st.subheader("Model Summary")
82
  model.summary(print_fn=lambda x: st.text(x))
83
-
84
  # Visualize the graph
85
  st.subheader("Sample Graph Visualization")
86
  fig, ax = plt.subplots(figsize=(10, 8))
87
  pos = nx.spring_layout(nx_graph)
88
- nx.draw(nx_graph, pos, with_labels=True, node_color='lightblue',
89
  node_size=500, arrowsize=20, ax=ax)
90
  st.pyplot(fig)
91
-
92
  # Display graph tensor info
93
  st.subheader("Graph Tensor Information")
94
  st.text(f"Number of nodes: {graph_tensor.node_sets['papers'].total_size}")
 
12
  # Define the model function
13
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
14
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
15
+
16
+ # Encode input features (callback omitted for brevity).
17
  graph = tfgnn.keras.layers.MapFeatures(
18
  node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(64)(node_set['features'])
19
  )(graph)
20
+
21
+ # For each round of message passing...
22
  for _ in range(2):
23
+ # ... create and apply a Keras layer.
24
  graph = mt_albis.MtAlbisGraphUpdate(
25
+ units=128, message_dim=64,
26
+ attention_type="none", simple_conv_reduce_type="mean",
27
+ normalization_type="layer", next_state_type="residual",
28
+ state_dropout_rate=0.2, l2_regularization=1e-5,
29
+ receiver_tag=tfgnn.NODES # Correctly use tfgnn.NODES
 
 
 
 
30
  )(graph)
31
+
32
  return tf.keras.Model(inputs, graph)
33
 
34
  # Function to create a sample graph
35
  def create_sample_graph():
36
  num_nodes = 10
37
  num_edges = 15
38
+
39
  graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)
40
+
41
  # Create a GraphTensor
42
  node_features = tf.random.normal((num_nodes, 64))
43
  edge_features = tf.random.normal((num_edges, 32))
44
+
45
  graph_tensor = tfgnn.GraphTensor.from_pieces(
46
  node_sets={
47
  "papers": tfgnn.NodeSet.from_fields(
 
60
  )
61
  }
62
  )
63
+
64
  return graph, graph_tensor
65
 
66
  # Streamlit app
67
  def main():
68
  st.title("Graph Neural Network Architecture Visualization")
69
+
70
  # Create sample graph
71
  nx_graph, graph_tensor = create_sample_graph()
72
+
73
  # Create and compile the model
74
  model = model_fn(graph_tensor.spec)
75
  model.compile(optimizer='adam', loss='binary_crossentropy')
76
+
77
  # Display model summary
78
  st.subheader("Model Summary")
79
  model.summary(print_fn=lambda x: st.text(x))
80
+
81
  # Visualize the graph
82
  st.subheader("Sample Graph Visualization")
83
  fig, ax = plt.subplots(figsize=(10, 8))
84
  pos = nx.spring_layout(nx_graph)
85
+ nx.draw(nx_graph, pos, with_labels=True, node_color='lightblue',
86
  node_size=500, arrowsize=20, ax=ax)
87
  st.pyplot(fig)
88
+
89
  # Display graph tensor info
90
  st.subheader("Graph Tensor Information")
91
  st.text(f"Number of nodes: {graph_tensor.node_sets['papers'].total_size}")