eaglelandsonce commited on
Commit
1e34469
·
verified ·
1 Parent(s): f8cb6aa

Update pages/19_Graphs3.py

Browse files
Files changed (1) hide show
  1. pages/19_Graphs3.py +3 -2
pages/19_Graphs3.py CHANGED
@@ -171,8 +171,6 @@ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
171
  normalization_type="layer",
172
  next_state_type="residual",
173
  )(graph)
174
- graph = tfgnn.keras.layers.Readout(
175
- node_set_name="paper", feature_name="feat")(graph)
176
  return tf.keras.Model(inputs, graph)
177
 
178
  # Check for TPU/ GPU and set strategy
@@ -260,6 +258,9 @@ def create_full_model(graph_tensor_spec: tfgnn.GraphTensorSpec):
260
  graph = input_graph
261
  for processor in feature_processors:
262
  graph = processor(graph)
 
 
 
263
  output_graph = model_fn(graph.spec)(graph)
264
  return tf.keras.Model(input_graph, output_graph)
265
 
 
171
  normalization_type="layer",
172
  next_state_type="residual",
173
  )(graph)
 
 
174
  return tf.keras.Model(inputs, graph)
175
 
176
  # Check for TPU/ GPU and set strategy
 
258
  graph = input_graph
259
  for processor in feature_processors:
260
  graph = processor(graph)
261
+ # Ensure the Readout layer is correctly added after feature processing
262
+ readout = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="feat")
263
+ graph = readout(graph)
264
  output_graph = model_fn(graph.spec)(graph)
265
  return tf.keras.Model(input_graph, output_graph)
266