eaglelandsonce commited on
Commit
2478080
1 Parent(s): 8e03eb3

Update pages/15_Graphs.py

Browse files
Files changed (1) hide show
  1. pages/15_Graphs.py +2 -2
pages/15_Graphs.py CHANGED
@@ -13,7 +13,7 @@ os.environ['TF_USE_LEGACY_KERAS'] = '1'
13
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
14
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
15
 
16
- # Encode input features (callback omitted for brevity).
17
  graph = tfgnn.keras.layers.MapFeatures(
18
  node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(64)(node_set['features'])
19
  )(graph)
@@ -26,7 +26,7 @@ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
26
  attention_type="none", simple_conv_reduce_type="mean",
27
  normalization_type="layer", next_state_type="residual",
28
  state_dropout_rate=0.2, l2_regularization=1e-5,
29
- receiver_tag=tfgnn.NODES # Correctly use tfgnn.NODES
30
  )(graph)
31
 
32
  return tf.keras.Model(inputs, graph)
 
13
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
14
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
15
 
16
+ # Encode input features
17
  graph = tfgnn.keras.layers.MapFeatures(
18
  node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(64)(node_set['features'])
19
  )(graph)
 
26
  attention_type="none", simple_conv_reduce_type="mean",
27
  normalization_type="layer", next_state_type="residual",
28
  state_dropout_rate=0.2, l2_regularization=1e-5,
29
+ receiver_tag=tfgnn.TARGET # Use TARGET instead of NODES
30
  )(graph)
31
 
32
  return tf.keras.Model(inputs, graph)