eaglelandsonce commited on
Commit
673466f
1 Parent(s): 2478080

Update pages/15_Graphs.py

Browse files
Files changed (1) hide show
  1. pages/15_Graphs.py +3 -3
pages/15_Graphs.py CHANGED
@@ -13,9 +13,9 @@ os.environ['TF_USE_LEGACY_KERAS'] = '1'
13
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
14
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
15
 
16
- # Encode input features
17
  graph = tfgnn.keras.layers.MapFeatures(
18
- node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(64)(node_set['features'])
19
  )(graph)
20
 
21
  # For each round of message passing...
@@ -39,7 +39,7 @@ def create_sample_graph():
39
  graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)
40
 
41
  # Create a GraphTensor
42
- node_features = tf.random.normal((num_nodes, 64))
43
  edge_features = tf.random.normal((num_edges, 32))
44
 
45
  graph_tensor = tfgnn.GraphTensor.from_pieces(
 
13
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
14
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
15
 
16
+ # Encode input features to match the required output shape of 128
17
  graph = tfgnn.keras.layers.MapFeatures(
18
+ node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(128)(node_set['features'])
19
  )(graph)
20
 
21
  # For each round of message passing...
 
39
  graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)
40
 
41
  # Create a GraphTensor
42
+ node_features = tf.random.normal((num_nodes, 128)) # Match the dense layer output
43
  edge_features = tf.random.normal((num_edges, 32))
44
 
45
  graph_tensor = tfgnn.GraphTensor.from_pieces(