eaglelandsonce commited on
Commit
469ca53
·
verified ·
1 Parent(s): 1e34469

Update pages/19_Graphs3.py

Browse files
Files changed (1) hide show
  1. pages/19_Graphs3.py +4 -4
pages/19_Graphs3.py CHANGED
@@ -171,6 +171,8 @@ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
171
  normalization_type="layer",
172
  next_state_type="residual",
173
  )(graph)
 
 
174
  return tf.keras.Model(inputs, graph)
175
 
176
  # Check for TPU/ GPU and set strategy
@@ -258,10 +260,8 @@ def create_full_model(graph_tensor_spec: tfgnn.GraphTensorSpec):
258
  graph = input_graph
259
  for processor in feature_processors:
260
  graph = processor(graph)
261
- # Ensure the Readout layer is correctly added after feature processing
262
- readout = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="feat")
263
- graph = readout(graph)
264
- output_graph = model_fn(graph.spec)(graph)
265
  return tf.keras.Model(input_graph, output_graph)
266
 
267
  # Run training
 
171
  normalization_type="layer",
172
  next_state_type="residual",
173
  )(graph)
174
+ graph = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="feat")(graph)
175
+ graph = tf.keras.layers.Dense(349, activation="softmax")(graph)
176
  return tf.keras.Model(inputs, graph)
177
 
178
  # Check for TPU/ GPU and set strategy
 
260
  graph = input_graph
261
  for processor in feature_processors:
262
  graph = processor(graph)
263
+ graph = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="feat")(graph)
264
+ output_graph = model_fn(graph_tensor_spec)(graph)
 
 
265
  return tf.keras.Model(input_graph, output_graph)
266
 
267
  # Run training