eaglelandsonce commited on
Commit
81e6cca
1 Parent(s): 392c4ff

Update pages/19_Graphs3.py

Browse files
Files changed (1) hide show
  1. pages/19_Graphs3.py +9 -14
pages/19_Graphs3.py CHANGED
@@ -155,28 +155,23 @@ def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
155
  return node_set["empty_state"]
156
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
157
 
158
- def gnn_model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
159
  inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
160
  graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
161
- for i in range(num_graph_updates):
 
162
  graph = mt_albis.MtAlbisGraphUpdate(
163
  units=node_state_dim,
164
  message_dim=message_dim,
165
- receiver_tag=tfgnn.SOURCE,
166
- node_set_names=None if i < num_graph_updates-1 else ["paper"],
167
  simple_conv_reduce_type="mean|sum",
168
- state_dropout_rate=state_dropout_rate,
169
- l2_regularization=l2_regularization,
170
  normalization_type="layer",
171
  next_state_type="residual",
 
 
172
  )(graph)
173
- return tf.keras.Model(inputs, graph)
174
-
175
- def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
176
- gnn_model = gnn_model_fn(graph_tensor_spec)
177
- inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
178
- graph = gnn_model(inputs)
179
- paper_state = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="hidden_state")(graph)
180
  paper_state = tf.keras.layers.Dense(349, activation="softmax")(paper_state)
181
  return tf.keras.Model(inputs, paper_state)
182
 
@@ -215,7 +210,7 @@ initial_learning_rate = 0.001
215
  steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size
216
  validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size
217
  learning_rate = tf.keras.optimizers.schedules.CosineDecay(
218
- initial_learning_rate, steps_per_epoch*epochs)
219
  optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)
220
 
221
  # Define trainer
 
155
  return node_set["empty_state"]
156
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
157
 
158
+ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
159
  inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
160
  graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
161
+
162
+ for _ in range(num_graph_updates):
163
  graph = mt_albis.MtAlbisGraphUpdate(
164
  units=node_state_dim,
165
  message_dim=message_dim,
166
+ attention_type="none",
 
167
  simple_conv_reduce_type="mean|sum",
 
 
168
  normalization_type="layer",
169
  next_state_type="residual",
170
+ state_dropout_rate=state_dropout_rate,
171
+ l2_regularization=l2_regularization
172
  )(graph)
173
+
174
+ paper_state = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="state")(graph)
 
 
 
 
 
175
  paper_state = tf.keras.layers.Dense(349, activation="softmax")(paper_state)
176
  return tf.keras.Model(inputs, paper_state)
177
 
 
210
  steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size
211
  validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size
212
  learning_rate = tf.keras.optimizers.schedules.CosineDecay(
213
+ initial_learning_rate, steps_per_epoch * epochs)
214
  optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)
215
 
216
  # Define trainer