eaglelandsonce commited on
Commit
a786ddb
·
verified ·
1 Parent(s): b6966f8

Update pages/15_Graphs.py

Browse files
Files changed (1) hide show
  1. pages/15_Graphs.py +4 -4
pages/15_Graphs.py CHANGED
@@ -23,7 +23,7 @@ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
23
  graph = mt_albis.MtAlbisGraphUpdate(
24
  units=128,
25
  message_dim=64,
26
- receiver_tag=tfgnn.NODES, # Add this line
27
  attention_type="none",
28
  simple_conv_reduce_type="mean",
29
  normalization_type="layer",
@@ -56,8 +56,8 @@ def create_sample_graph():
56
  "cites": tfgnn.EdgeSet.from_fields(
57
  sizes=[num_edges],
58
  adjacency=tfgnn.Adjacency.from_indices(
59
- source=("papers", tf.cast([e[0] for e in graph.edges()], tf.int32)),
60
- target=("papers", tf.cast([e[1] for e in graph.edges()], tf.int32))
61
  ),
62
  features={"features": edge_features}
63
  )
@@ -97,4 +97,4 @@ def main():
97
  st.text(f"Edge feature shape: {graph_tensor.edge_sets['cites']['features'].shape}")
98
 
99
  if __name__ == "__main__":
100
- main()
 
23
  graph = mt_albis.MtAlbisGraphUpdate(
24
  units=128,
25
  message_dim=64,
26
+ receiver_tag=tfgnn.NODES, # Correctly use tfgnn.NODES
27
  attention_type="none",
28
  simple_conv_reduce_type="mean",
29
  normalization_type="layer",
 
56
  "cites": tfgnn.EdgeSet.from_fields(
57
  sizes=[num_edges],
58
  adjacency=tfgnn.Adjacency.from_indices(
59
+ source=("papers", tf.constant([e[0] for e in graph.edges()], dtype=tf.int32)),
60
+ target=("papers", tf.constant([e[1] for e in graph.edges()], dtype=tf.int32))
61
  ),
62
  features={"features": edge_features}
63
  )
 
97
  st.text(f"Edge feature shape: {graph_tensor.edge_sets['cites']['features'].shape}")
98
 
99
  if __name__ == "__main__":
100
+ main()