Spaces:
Sleeping
Sleeping
eaglelandsonce
commited on
Rename pages/19_Graphs.py to pages/19_Graphs3.py
Browse files
pages/{19_Graphs.py → 19_Graphs3.py}
RENAMED
@@ -6,6 +6,7 @@ from tensorflow_gnn.experimental import sampler
|
|
6 |
from tensorflow_gnn.models import mt_albis
|
7 |
import functools
|
8 |
import os
|
|
|
9 |
|
10 |
# Set environment variable for legacy Keras
|
11 |
os.environ['TF_USE_LEGACY_KERAS'] = '1'
|
@@ -13,9 +14,6 @@ os.environ['TF_USE_LEGACY_KERAS'] = '1'
|
|
13 |
# Set Streamlit title
|
14 |
st.title("Solving OGBN-MAG end-to-end with TF-GNN")
|
15 |
|
16 |
-
# Install necessary packages
|
17 |
-
st.write("Installing necessary packages...")
|
18 |
-
|
19 |
st.write("Setting up the environment...")
|
20 |
tf.get_logger().setLevel('ERROR')
|
21 |
st.write(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")
|
@@ -140,6 +138,23 @@ example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()
|
|
140 |
st.write("Dataset providers created successfully.")
|
141 |
|
142 |
# Define the model function
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
144 |
graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
145 |
graph = tfgnn.keras.layers.MapFeatures(
|
|
|
6 |
from tensorflow_gnn.models import mt_albis
|
7 |
import functools
|
8 |
import os
|
9 |
+
from typing import Mapping
|
10 |
|
11 |
# Set environment variable for legacy Keras
|
12 |
os.environ['TF_USE_LEGACY_KERAS'] = '1'
|
|
|
14 |
# Set Streamlit title
|
15 |
st.title("Solving OGBN-MAG end-to-end with TF-GNN")
|
16 |
|
|
|
|
|
|
|
17 |
st.write("Setting up the environment...")
|
18 |
tf.get_logger().setLevel('ERROR')
|
19 |
st.write(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")
|
|
|
138 |
st.write("Dataset providers created successfully.")
|
139 |
|
140 |
# Define the model function
|
141 |
+
node_state_dim = 128
|
142 |
+
num_graph_updates = 4
|
143 |
+
message_dim = 128
|
144 |
+
state_dropout_rate = 0.2
|
145 |
+
l2_regularization = 1e-5
|
146 |
+
|
147 |
+
def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
|
148 |
+
if node_set_name == "field_of_study":
|
149 |
+
return tf.keras.layers.Embedding(50_000, 32)(node_set["hashed_id"])
|
150 |
+
if node_set_name == "institution":
|
151 |
+
return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
|
152 |
+
if node_set_name == "paper":
|
153 |
+
return tf.keras.layers.Dense(node_state_dim, "relu")(node_set["feat"])
|
154 |
+
if node_set_name == "author":
|
155 |
+
return node_set["empty_state"]
|
156 |
+
raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
|
157 |
+
|
158 |
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
159 |
graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
|
160 |
graph = tfgnn.keras.layers.MapFeatures(
|