eaglelandsonce commited on
Commit
298bd94
·
verified ·
1 Parent(s): 469ca53

Update pages/19_Graphs3.py

Browse files
Files changed (1) hide show
  1. pages/19_Graphs3.py +13 -43
pages/19_Graphs3.py CHANGED
@@ -155,10 +155,10 @@ def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
155
  return node_set["empty_state"]
156
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
157
 
158
- def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
 
159
  graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
160
- graph = tfgnn.keras.layers.MapFeatures(
161
- node_sets_fn=set_initial_node_states)(graph)
162
  for i in range(num_graph_updates):
163
  graph = mt_albis.MtAlbisGraphUpdate(
164
  units=node_state_dim,
@@ -171,10 +171,17 @@ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
171
  normalization_type="layer",
172
  next_state_type="residual",
173
  )(graph)
174
- graph = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="feat")(graph)
175
- graph = tf.keras.layers.Dense(349, activation="softmax")(graph)
176
  return tf.keras.Model(inputs, graph)
177
 
 
 
 
 
 
 
 
 
 
178
  # Check for TPU/ GPU and set strategy
179
  st.write("Setting up strategy for distributed training...")
180
  if tf.config.list_physical_devices("TPU"):
@@ -227,48 +234,11 @@ trainer = runner.KerasTrainer(
227
  backup_and_restore=False,
228
  )
229
 
230
- # Define feature processors
231
- st.write("Defining feature processors...")
232
- def process_node_features(node_set: tfgnn.NodeSet, node_set_name: str):
233
- if node_set_name == "field_of_study":
234
- return {"hashed_id": tf.keras.layers.Hashing(50_000)(node_set["#id"])}
235
- if node_set_name == "institution":
236
- return {"hashed_id": tf.keras.layers.Hashing(6_500)(node_set["#id"])}
237
- if node_set_name == "paper":
238
- return {"feat": node_set["feat"], "label": node_set["label"]}
239
- if node_set_name == "author":
240
- return {"empty_state": tfgnn.keras.layers.MakeEmptyFeature()(node_set)}
241
- raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
242
-
243
- def drop_all_features(_, **unused_kwargs):
244
- return {}
245
-
246
- process_features = tfgnn.keras.layers.MapFeatures(
247
- context_fn=drop_all_features,
248
- node_sets_fn=process_node_features,
249
- edge_sets_fn=drop_all_features)
250
-
251
- add_readout = tfgnn.keras.layers.AddReadoutFromFirstNode("seed", node_set_name="paper")
252
- move_label_to_readout = tfgnn.keras.layers.StructuredReadoutIntoFeature(
253
- "seed", feature_name="label", new_feature_name="paper_venue", remove_input_feature=True)
254
-
255
- feature_processors = [process_features, add_readout, move_label_to_readout]
256
-
257
- # Function to create the full model with feature processing
258
- def create_full_model(graph_tensor_spec: tfgnn.GraphTensorSpec):
259
- input_graph = tf.keras.layers.Input(type_spec=graph_tensor_spec)
260
- graph = input_graph
261
- for processor in feature_processors:
262
- graph = processor(graph)
263
- graph = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="feat")(graph)
264
- output_graph = model_fn(graph_tensor_spec)(graph)
265
- return tf.keras.Model(input_graph, output_graph)
266
-
267
  # Run training
268
  st.write("Training the model...")
269
  runner.run(
270
  task=task,
271
- model_fn=create_full_model,
272
  trainer=trainer,
273
  optimizer_fn=optimizer_fn,
274
  epochs=epochs,
 
155
  return node_set["empty_state"]
156
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
157
 
158
+ # Define the GNN model function
159
+ def gnn_model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
160
  graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
161
+ graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(graph)
 
162
  for i in range(num_graph_updates):
163
  graph = mt_albis.MtAlbisGraphUpdate(
164
  units=node_state_dim,
 
171
  normalization_type="layer",
172
  next_state_type="residual",
173
  )(graph)
 
 
174
  return tf.keras.Model(inputs, graph)
175
 
176
+ # Define the complete model function
177
+ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
178
+ gnn_model = gnn_model_fn(graph_tensor_spec)
179
+ inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
180
+ graph = gnn_model(inputs)
181
+ paper_state = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name="hidden_state")(graph)
182
+ paper_state = tf.keras.layers.Dense(349, activation="softmax")(paper_state)
183
+ return tf.keras.Model(inputs, paper_state)
184
+
185
  # Check for TPU/ GPU and set strategy
186
  st.write("Setting up strategy for distributed training...")
187
  if tf.config.list_physical_devices("TPU"):
 
234
  backup_and_restore=False,
235
  )
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  # Run training
238
  st.write("Training the model...")
239
  runner.run(
240
  task=task,
241
+ model_fn=model_fn,
242
  trainer=trainer,
243
  optimizer_fn=optimizer_fn,
244
  epochs=epochs,