eaglelandsonce commited on
Commit
16fad40
·
verified ·
1 Parent(s): d34da7a

Rename pages/19_Graphs.py to pages/19_Graphs3.py

Browse files
pages/{19_Graphs.py → 19_Graphs3.py} RENAMED
@@ -6,6 +6,7 @@ from tensorflow_gnn.experimental import sampler
6
  from tensorflow_gnn.models import mt_albis
7
  import functools
8
  import os
 
9
 
10
  # Set environment variable for legacy Keras
11
  os.environ['TF_USE_LEGACY_KERAS'] = '1'
@@ -13,9 +14,6 @@ os.environ['TF_USE_LEGACY_KERAS'] = '1'
13
  # Set Streamlit title
14
  st.title("Solving OGBN-MAG end-to-end with TF-GNN")
15
 
16
- # Install necessary packages
17
- st.write("Installing necessary packages...")
18
-
19
  st.write("Setting up the environment...")
20
  tf.get_logger().setLevel('ERROR')
21
  st.write(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")
@@ -140,6 +138,23 @@ example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()
140
  st.write("Dataset providers created successfully.")
141
 
142
  # Define the model function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
144
  graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
145
  graph = tfgnn.keras.layers.MapFeatures(
 
6
  from tensorflow_gnn.models import mt_albis
7
  import functools
8
  import os
9
+ from typing import Mapping
10
 
11
  # Set environment variable for legacy Keras
12
  os.environ['TF_USE_LEGACY_KERAS'] = '1'
 
14
  # Set Streamlit title
15
  st.title("Solving OGBN-MAG end-to-end with TF-GNN")
16
 
 
 
 
17
  st.write("Setting up the environment...")
18
  tf.get_logger().setLevel('ERROR')
19
  st.write(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")
 
138
  st.write("Dataset providers created successfully.")
139
 
140
  # Define the model function
141
+ node_state_dim = 128
142
+ num_graph_updates = 4
143
+ message_dim = 128
144
+ state_dropout_rate = 0.2
145
+ l2_regularization = 1e-5
146
+
147
+ def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
148
+ if node_set_name == "field_of_study":
149
+ return tf.keras.layers.Embedding(50_000, 32)(node_set["hashed_id"])
150
+ if node_set_name == "institution":
151
+ return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
152
+ if node_set_name == "paper":
153
+ return tf.keras.layers.Dense(node_state_dim, "relu")(node_set["feat"])
154
+ if node_set_name == "author":
155
+ return node_set["empty_state"]
156
+ raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
157
+
158
  def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
159
  graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
160
  graph = tfgnn.keras.layers.MapFeatures(