Spaces:
Sleeping
Sleeping
eaglelandsonce
commited on
Commit
•
81e6cca
1
Parent(s):
392c4ff
Update pages/19_Graphs3.py
Browse files- pages/19_Graphs3.py +9 -14
pages/19_Graphs3.py
CHANGED
@@ -155,28 +155,23 @@ def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
|
|
155 |
return node_set["empty_state"]
|
156 |
raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
|
157 |
|
158 |
-
def
|
159 |
inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
160 |
graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
|
161 |
-
|
|
|
162 |
graph = mt_albis.MtAlbisGraphUpdate(
|
163 |
units=node_state_dim,
|
164 |
message_dim=message_dim,
|
165 |
-
|
166 |
-
node_set_names=None if i < num_graph_updates-1 else ["paper"],
|
167 |
simple_conv_reduce_type="mean|sum",
|
168 |
-
state_dropout_rate=state_dropout_rate,
|
169 |
-
l2_regularization=l2_regularization,
|
170 |
normalization_type="layer",
|
171 |
next_state_type="residual",
|
|
|
|
|
172 |
)(graph)
|
173 |
-
|
174 |
-
|
175 |
-
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
176 |
-
gnn_model = gnn_model_fn(graph_tensor_spec)
|
177 |
-
inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
178 |
-
graph = gnn_model(inputs)
|
179 |
-
paper_state = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="hidden_state")(graph)
|
180 |
paper_state = tf.keras.layers.Dense(349, activation="softmax")(paper_state)
|
181 |
return tf.keras.Model(inputs, paper_state)
|
182 |
|
@@ -215,7 +210,7 @@ initial_learning_rate = 0.001
|
|
215 |
steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size
|
216 |
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size
|
217 |
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
|
218 |
-
initial_learning_rate, steps_per_epoch*epochs)
|
219 |
optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)
|
220 |
|
221 |
# Define trainer
|
|
|
155 |
return node_set["empty_state"]
|
156 |
raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
|
157 |
|
158 |
+
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
159 |
inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
160 |
graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
|
161 |
+
|
162 |
+
for _ in range(num_graph_updates):
|
163 |
graph = mt_albis.MtAlbisGraphUpdate(
|
164 |
units=node_state_dim,
|
165 |
message_dim=message_dim,
|
166 |
+
attention_type="none",
|
|
|
167 |
simple_conv_reduce_type="mean|sum",
|
|
|
|
|
168 |
normalization_type="layer",
|
169 |
next_state_type="residual",
|
170 |
+
state_dropout_rate=state_dropout_rate,
|
171 |
+
l2_regularization=l2_regularization
|
172 |
)(graph)
|
173 |
+
|
174 |
+
paper_state = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="state")(graph)
|
|
|
|
|
|
|
|
|
|
|
175 |
paper_state = tf.keras.layers.Dense(349, activation="softmax")(paper_state)
|
176 |
return tf.keras.Model(inputs, paper_state)
|
177 |
|
|
|
210 |
steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size
|
211 |
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size
|
212 |
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
|
213 |
+
initial_learning_rate, steps_per_epoch * epochs)
|
214 |
optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)
|
215 |
|
216 |
# Define trainer
|