Spaces:
Sleeping
Sleeping
eaglelandsonce
commited on
Commit
•
392c4ff
1
Parent(s):
298bd94
Update pages/19_Graphs3.py
Browse files- pages/19_Graphs3.py +2 -4
pages/19_Graphs3.py
CHANGED
@@ -155,10 +155,9 @@ def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
|
|
155 |
return node_set["empty_state"]
|
156 |
raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
|
157 |
|
158 |
-
# Define the GNN model function
|
159 |
def gnn_model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
160 |
-
|
161 |
-
graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(
|
162 |
for i in range(num_graph_updates):
|
163 |
graph = mt_albis.MtAlbisGraphUpdate(
|
164 |
units=node_state_dim,
|
@@ -173,7 +172,6 @@ def gnn_model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
|
173 |
)(graph)
|
174 |
return tf.keras.Model(inputs, graph)
|
175 |
|
176 |
-
# Define the complete model function
|
177 |
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
178 |
gnn_model = gnn_model_fn(graph_tensor_spec)
|
179 |
inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
|
|
155 |
return node_set["empty_state"]
|
156 |
raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
|
157 |
|
|
|
158 |
def gnn_model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
159 |
+
inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
160 |
+
graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
|
161 |
for i in range(num_graph_updates):
|
162 |
graph = mt_albis.MtAlbisGraphUpdate(
|
163 |
units=node_state_dim,
|
|
|
172 |
)(graph)
|
173 |
return tf.keras.Model(inputs, graph)
|
174 |
|
|
|
175 |
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
176 |
gnn_model = gnn_model_fn(graph_tensor_spec)
|
177 |
inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|