eaglelandsonce commited on
Commit
51ad6eb
·
verified ·
1 Parent(s): dc52ee1

Update pages/19_Graphs3.py

Browse files
Files changed (1) hide show
  1. pages/19_Graphs3.py +4 -2
pages/19_Graphs3.py CHANGED
@@ -150,7 +150,7 @@ def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
150
  if node_set_name == "institution":
151
  return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
152
  if node_set_name == "paper":
153
- return tf.keras.layers.Dense(node_state_dim, "relu")(node_set["feat"])
154
  if node_set_name == "author":
155
  return node_set["empty_state"]
156
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
@@ -171,6 +171,8 @@ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
171
  normalization_type="layer",
172
  next_state_type="residual",
173
  )(graph)
 
 
174
  return tf.keras.Model(inputs, graph)
175
 
176
  # Check for TPU/ GPU and set strategy
@@ -191,7 +193,7 @@ else:
191
  train_padding = None
192
  valid_padding = None
193
 
194
- st.write(f"Found {strategy.num_replicas_in_sync} replicas in sync")
195
 
196
  # Define task
197
  st.write("Defining the task...")
 
150
  if node_set_name == "institution":
151
  return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
152
  if node_set_name == "paper":
153
+ return tf.keras.layers.Dense(node_state_dim, activation="relu")(node_set["feat"])
154
  if node_set_name == "author":
155
  return node_set["empty_state"]
156
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
 
171
  normalization_type="layer",
172
  next_state_type="residual",
173
  )(graph)
174
+ graph = tfgnn.keras.layers.Readout(
175
+ node_set_name="paper", feature_name="feat")(graph)
176
  return tf.keras.Model(inputs, graph)
177
 
178
  # Check for TPU/ GPU and set strategy
 
193
  train_padding = None
194
  valid_padding = None
195
 
196
+ st.write(f"Found {strategy.num_replicas in_sync} replicas in sync")
197
 
198
  # Define task
199
  st.write("Defining the task...")