Spaces:
Sleeping
Sleeping
eaglelandsonce
commited on
Update pages/19_Graphs3.py
Browse files- pages/19_Graphs3.py +13 -43
pages/19_Graphs3.py
CHANGED
@@ -155,10 +155,10 @@ def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
|
|
155 |
return node_set["empty_state"]
|
156 |
raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
|
157 |
|
158 |
-
|
|
|
159 |
graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
160 |
-
graph = tfgnn.keras.layers.MapFeatures(
|
161 |
-
node_sets_fn=set_initial_node_states)(graph)
|
162 |
for i in range(num_graph_updates):
|
163 |
graph = mt_albis.MtAlbisGraphUpdate(
|
164 |
units=node_state_dim,
|
@@ -171,10 +171,17 @@ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
|
171 |
normalization_type="layer",
|
172 |
next_state_type="residual",
|
173 |
)(graph)
|
174 |
-
graph = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="feat")(graph)
|
175 |
-
graph = tf.keras.layers.Dense(349, activation="softmax")(graph)
|
176 |
return tf.keras.Model(inputs, graph)
|
177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
# Check for TPU/ GPU and set strategy
|
179 |
st.write("Setting up strategy for distributed training...")
|
180 |
if tf.config.list_physical_devices("TPU"):
|
@@ -227,48 +234,11 @@ trainer = runner.KerasTrainer(
|
|
227 |
backup_and_restore=False,
|
228 |
)
|
229 |
|
230 |
-
# Define feature processors
|
231 |
-
st.write("Defining feature processors...")
|
232 |
-
def process_node_features(node_set: tfgnn.NodeSet, node_set_name: str):
|
233 |
-
if node_set_name == "field_of_study":
|
234 |
-
return {"hashed_id": tf.keras.layers.Hashing(50_000)(node_set["#id"])}
|
235 |
-
if node_set_name == "institution":
|
236 |
-
return {"hashed_id": tf.keras.layers.Hashing(6_500)(node_set["#id"])}
|
237 |
-
if node_set_name == "paper":
|
238 |
-
return {"feat": node_set["feat"], "label": node_set["label"]}
|
239 |
-
if node_set_name == "author":
|
240 |
-
return {"empty_state": tfgnn.keras.layers.MakeEmptyFeature()(node_set)}
|
241 |
-
raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
|
242 |
-
|
243 |
-
def drop_all_features(_, **unused_kwargs):
|
244 |
-
return {}
|
245 |
-
|
246 |
-
process_features = tfgnn.keras.layers.MapFeatures(
|
247 |
-
context_fn=drop_all_features,
|
248 |
-
node_sets_fn=process_node_features,
|
249 |
-
edge_sets_fn=drop_all_features)
|
250 |
-
|
251 |
-
add_readout = tfgnn.keras.layers.AddReadoutFromFirstNode("seed", node_set_name="paper")
|
252 |
-
move_label_to_readout = tfgnn.keras.layers.StructuredReadoutIntoFeature(
|
253 |
-
"seed", feature_name="label", new_feature_name="paper_venue", remove_input_feature=True)
|
254 |
-
|
255 |
-
feature_processors = [process_features, add_readout, move_label_to_readout]
|
256 |
-
|
257 |
-
# Function to create the full model with feature processing
|
258 |
-
def create_full_model(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
259 |
-
input_graph = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
260 |
-
graph = input_graph
|
261 |
-
for processor in feature_processors:
|
262 |
-
graph = processor(graph)
|
263 |
-
graph = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="feat")(graph)
|
264 |
-
output_graph = model_fn(graph_tensor_spec)(graph)
|
265 |
-
return tf.keras.Model(input_graph, output_graph)
|
266 |
-
|
267 |
# Run training
|
268 |
st.write("Training the model...")
|
269 |
runner.run(
|
270 |
task=task,
|
271 |
-
model_fn=
|
272 |
trainer=trainer,
|
273 |
optimizer_fn=optimizer_fn,
|
274 |
epochs=epochs,
|
|
|
155 |
return node_set["empty_state"]
|
156 |
raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
|
157 |
|
158 |
+
# Define the GNN model function
|
159 |
+
def gnn_model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
160 |
graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
161 |
+
graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(graph)
|
|
|
162 |
for i in range(num_graph_updates):
|
163 |
graph = mt_albis.MtAlbisGraphUpdate(
|
164 |
units=node_state_dim,
|
|
|
171 |
normalization_type="layer",
|
172 |
next_state_type="residual",
|
173 |
)(graph)
|
|
|
|
|
174 |
return tf.keras.Model(inputs, graph)
|
175 |
|
176 |
+
# Define the complete model function
|
177 |
+
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
178 |
+
gnn_model = gnn_model_fn(graph_tensor_spec)
|
179 |
+
inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
180 |
+
graph = gnn_model(inputs)
|
181 |
+
paper_state = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="hidden_state")(graph)
|
182 |
+
paper_state = tf.keras.layers.Dense(349, activation="softmax")(paper_state)
|
183 |
+
return tf.keras.Model(inputs, paper_state)
|
184 |
+
|
185 |
# Check for TPU/ GPU and set strategy
|
186 |
st.write("Setting up strategy for distributed training...")
|
187 |
if tf.config.list_physical_devices("TPU"):
|
|
|
234 |
backup_and_restore=False,
|
235 |
)
|
236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
# Run training
|
238 |
st.write("Training the model...")
|
239 |
runner.run(
|
240 |
task=task,
|
241 |
+
model_fn=model_fn,
|
242 |
trainer=trainer,
|
243 |
optimizer_fn=optimizer_fn,
|
244 |
epochs=epochs,
|