eaglelandsonce commited on
Commit
392c4ff
1 Parent(s): 298bd94

Update pages/19_Graphs3.py

Browse files
Files changed (1) hide show
  1. pages/19_Graphs3.py +2 -4
pages/19_Graphs3.py CHANGED
@@ -155,10 +155,9 @@ def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
155
  return node_set["empty_state"]
156
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
157
 
158
- # Define the GNN model function
159
  def gnn_model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
160
- graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
161
- graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(graph)
162
  for i in range(num_graph_updates):
163
  graph = mt_albis.MtAlbisGraphUpdate(
164
  units=node_state_dim,
@@ -173,7 +172,6 @@ def gnn_model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
173
  )(graph)
174
  return tf.keras.Model(inputs, graph)
175
 
176
- # Define the complete model function
177
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
178
  gnn_model = gnn_model_fn(graph_tensor_spec)
179
  inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
 
155
  return node_set["empty_state"]
156
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
157
 
 
158
  def gnn_model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
159
+ inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
160
+ graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
161
  for i in range(num_graph_updates):
162
  graph = mt_albis.MtAlbisGraphUpdate(
163
  units=node_state_dim,
 
172
  )(graph)
173
  return tf.keras.Model(inputs, graph)
174
 
 
175
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
176
  gnn_model = gnn_model_fn(graph_tensor_spec)
177
  inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)