eaglelandsonce commited on
Commit
b6966f8
·
verified ·
1 Parent(s): f43e499

Update pages/15_Graphs.py

Browse files
Files changed (1) hide show
  1. pages/15_Graphs.py +15 -6
pages/15_Graphs.py CHANGED
@@ -5,6 +5,10 @@ from tensorflow_gnn.models import mt_albis
5
  import networkx as nx
6
  import matplotlib.pyplot as plt
7
 
 
 
 
 
8
  # Define the model function
9
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
10
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
@@ -17,10 +21,15 @@ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
17
  # Message passing layers
18
  for _ in range(2):
19
  graph = mt_albis.MtAlbisGraphUpdate(
20
- units=128, message_dim=64,
21
- attention_type="none", simple_conv_reduce_type="mean",
22
- normalization_type="layer", next_state_type="residual",
23
- state_dropout_rate=0.2, l2_regularization=1e-5,
 
 
 
 
 
24
  )(graph)
25
 
26
  return tf.keras.Model(inputs, graph)
@@ -47,8 +56,8 @@ def create_sample_graph():
47
  "cites": tfgnn.EdgeSet.from_fields(
48
  sizes=[num_edges],
49
  adjacency=tfgnn.Adjacency.from_indices(
50
- source=("papers", tf.cast(list(e[0] for e in graph.edges()), tf.int32)),
51
- target=("papers", tf.cast(list(e[1] for e in graph.edges()), tf.int32))
52
  ),
53
  features={"features": edge_features}
54
  )
 
5
  import networkx as nx
6
  import matplotlib.pyplot as plt
7
 
8
+ # Set environment variable for legacy Keras
9
+ import os
10
+ os.environ['TF_USE_LEGACY_KERAS'] = '1'
11
+
12
  # Define the model function
13
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
14
  graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
 
21
  # Message passing layers
22
  for _ in range(2):
23
  graph = mt_albis.MtAlbisGraphUpdate(
24
+ units=128,
25
+ message_dim=64,
26
+ receiver_tag=tfgnn.NODES, # Add this line
27
+ attention_type="none",
28
+ simple_conv_reduce_type="mean",
29
+ normalization_type="layer",
30
+ next_state_type="residual",
31
+ state_dropout_rate=0.2,
32
+ l2_regularization=1e-5,
33
  )(graph)
34
 
35
  return tf.keras.Model(inputs, graph)
 
56
  "cites": tfgnn.EdgeSet.from_fields(
57
  sizes=[num_edges],
58
  adjacency=tfgnn.Adjacency.from_indices(
59
+ source=("papers", tf.cast([e[0] for e in graph.edges()], tf.int32)),
60
+ target=("papers", tf.cast([e[1] for e in graph.edges()], tf.int32))
61
  ),
62
  features={"features": edge_features}
63
  )