Spaces:
Sleeping
Sleeping
eaglelandsonce
commited on
Update pages/15_Graphs.py
Browse files- pages/15_Graphs.py +15 -6
pages/15_Graphs.py
CHANGED
@@ -5,6 +5,10 @@ from tensorflow_gnn.models import mt_albis
|
|
5 |
import networkx as nx
|
6 |
import matplotlib.pyplot as plt
|
7 |
|
|
|
|
|
|
|
|
|
8 |
# Define the model function
|
9 |
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
10 |
graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
|
@@ -17,10 +21,15 @@ def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
|
17 |
# Message passing layers
|
18 |
for _ in range(2):
|
19 |
graph = mt_albis.MtAlbisGraphUpdate(
|
20 |
-
units=128,
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
24 |
)(graph)
|
25 |
|
26 |
return tf.keras.Model(inputs, graph)
|
@@ -47,8 +56,8 @@ def create_sample_graph():
|
|
47 |
"cites": tfgnn.EdgeSet.from_fields(
|
48 |
sizes=[num_edges],
|
49 |
adjacency=tfgnn.Adjacency.from_indices(
|
50 |
-
source=("papers", tf.cast(
|
51 |
-
target=("papers", tf.cast(
|
52 |
),
|
53 |
features={"features": edge_features}
|
54 |
)
|
|
|
5 |
import networkx as nx
|
6 |
import matplotlib.pyplot as plt
|
7 |
|
8 |
+
# Set environment variable for legacy Keras
|
9 |
+
import os
|
10 |
+
os.environ['TF_USE_LEGACY_KERAS'] = '1'
|
11 |
+
|
12 |
# Define the model function
|
13 |
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
14 |
graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
|
|
|
21 |
# Message passing layers
|
22 |
for _ in range(2):
|
23 |
graph = mt_albis.MtAlbisGraphUpdate(
|
24 |
+
units=128,
|
25 |
+
message_dim=64,
|
26 |
+
receiver_tag=tfgnn.NODES, # Add this line
|
27 |
+
attention_type="none",
|
28 |
+
simple_conv_reduce_type="mean",
|
29 |
+
normalization_type="layer",
|
30 |
+
next_state_type="residual",
|
31 |
+
state_dropout_rate=0.2,
|
32 |
+
l2_regularization=1e-5,
|
33 |
)(graph)
|
34 |
|
35 |
return tf.keras.Model(inputs, graph)
|
|
|
56 |
"cites": tfgnn.EdgeSet.from_fields(
|
57 |
sizes=[num_edges],
|
58 |
adjacency=tfgnn.Adjacency.from_indices(
|
59 |
+
source=("papers", tf.cast([e[0] for e in graph.edges()], tf.int32)),
|
60 |
+
target=("papers", tf.cast([e[1] for e in graph.edges()], tf.int32))
|
61 |
),
|
62 |
features={"features": edge_features}
|
63 |
)
|