Spaces:
Sleeping
Sleeping
eaglelandsonce
commited on
Update pages/19_Graphs3.py
Browse files- pages/19_Graphs3.py +4 -2
pages/19_Graphs3.py
CHANGED
@@ -150,7 +150,7 @@ def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
|
|
150 |
if node_set_name == "institution":
|
151 |
return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
|
152 |
if node_set_name == "paper":
|
153 |
-
return tf.keras.layers.Dense(node_state_dim, "relu")(node_set["feat"])
|
154 |
if node_set_name == "author":
|
155 |
return node_set["empty_state"]
|
156 |
raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
|
@@ -171,6 +171,8 @@ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
|
171 |
normalization_type="layer",
|
172 |
next_state_type="residual",
|
173 |
)(graph)
|
|
|
|
|
174 |
return tf.keras.Model(inputs, graph)
|
175 |
|
176 |
# Check for TPU/ GPU and set strategy
|
@@ -191,7 +193,7 @@ else:
|
|
191 |
train_padding = None
|
192 |
valid_padding = None
|
193 |
|
194 |
-
st.write(f"Found {strategy.
|
195 |
|
196 |
# Define task
|
197 |
st.write("Defining the task...")
|
|
|
150 |
if node_set_name == "institution":
|
151 |
return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
|
152 |
if node_set_name == "paper":
|
153 |
+
return tf.keras.layers.Dense(node_state_dim, activation="relu")(node_set["feat"])
|
154 |
if node_set_name == "author":
|
155 |
return node_set["empty_state"]
|
156 |
raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
|
|
|
171 |
normalization_type="layer",
|
172 |
next_state_type="residual",
|
173 |
)(graph)
|
174 |
+
graph = tfgnn.keras.layers.Readout(
|
175 |
+
node_set_name="paper", feature_name="feat")(graph)
|
176 |
return tf.keras.Model(inputs, graph)
|
177 |
|
178 |
# Check for TPU/ GPU and set strategy
|
|
|
193 |
train_padding = None
|
194 |
valid_padding = None
|
195 |
|
196 |
+
st.write(f"Found {strategy.num_replicas in_sync} replicas in sync")
|
197 |
|
198 |
# Define task
|
199 |
st.write("Defining the task...")
|