import functools
import os
from typing import Mapping

# Select legacy Keras (v2) for TF-GNN; this must be set before TensorFlow
# and tensorflow_gnn are imported for it to take effect.
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import streamlit as st
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner
from tensorflow_gnn.experimental import sampler
from tensorflow_gnn.models import mt_albis
# Set Streamlit title
st.title("Solving OGBN-MAG end-to-end with TF-GNN")
st.write("Setting up the environment...")
tf.get_logger().setLevel('ERROR')
st.write(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")
NUM_TRAINING_SAMPLES = 629571
NUM_VALIDATION_SAMPLES = 64879
GRAPH_TENSOR_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_tensor.example.pb'
SCHEMA_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_schema.pbtxt'
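
# The two files above hold the full OGBN-MAG heterogeneous graph (papers,
# authors, institutions, fields of study) as a single serialized GraphTensor,
# plus the schema needed to parse it. The sample counts match the temporal
# split used below (training papers up to 2017, validation papers from 2018).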
# Load the graph schema and graph tensor
st.write("Loading graph schema and tensor...")
graph_schema = tfgnn.read_schema(SCHEMA_FILE)
serialized_ogbn_mag_graph_tensor_string = tf.io.read_file(GRAPH_TENSOR_FILE)
full_ogbn_mag_graph_tensor = tfgnn.parse_single_example(
    tfgnn.create_graph_spec_from_schema_pb(graph_schema, indices_dtype=tf.int64),
    serialized_ogbn_mag_graph_tensor_string)
st.write("Graph tensor loaded successfully.")
# Define sampling sizes
train_sampling_sizes = {
    "cites": 8,
    "rev_writes": 8,
    "writes": 8,
    "affiliated_with": 8,
    "has_topic": 8,
}
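
# A fan-out of 8 per edge set bounds the sampled subgraph: each seed paper
# expands into at most 8 cited papers, each paper into at most 8 authors, and
# so on, so subgraph size stays roughly constant no matter how large the full
# graph is.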
validation_sample_sizes = train_sampling_sizes.copy()
# Create sampling model
def create_sampling_model(full_graph_tensor: tfgnn.GraphTensor,
                          sizes: Mapping[str, int]) -> tf.keras.Model:
    def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp):
        edge_set_name = sampling_op.edge_set_name
        sample_size = sizes[edge_set_name]
        return sampler.InMemUniformEdgesSampler.from_graph_tensor(
            full_graph_tensor, edge_set_name, sample_size=sample_size
        )

    def get_features(node_set_name: tfgnn.NodeSetName):
        return sampler.InMemIndexToFeaturesAccessor.from_graph_tensor(
            full_graph_tensor, node_set_name
        )

    # Spell out the sampling procedure in Python
    sampling_spec_builder = tfgnn.sampler.SamplingSpecBuilder(graph_schema)
    seed = sampling_spec_builder.seed("paper")
    papers_cited_from_seed = seed.sample(sizes["cites"], "cites")
    authors_of_papers = papers_cited_from_seed.join([seed]).sample(
        sizes["rev_writes"], "rev_writes")
    papers_by_authors = authors_of_papers.sample(sizes["writes"], "writes")
    institutions = authors_of_papers.sample(
        sizes["affiliated_with"], "affiliated_with")
    fields_of_study = seed.join([papers_cited_from_seed, papers_by_authors]).sample(
        sizes["has_topic"], "has_topic")
    sampling_spec = sampling_spec_builder.build()

    model = sampler.create_sampling_model_from_spec(
        graph_schema, sampling_spec, edge_sampler, get_features,
        seed_node_dtype=tf.int64)
    return model
# Create the sampling model
st.write("Creating sampling model...")
sampling_model = create_sampling_model(full_ogbn_mag_graph_tensor, train_sampling_sizes)
st.write("Sampling model created successfully.")
# Define seed dataset function
def seed_dataset(years: tf.Tensor, split_name: str) -> tf.data.Dataset:
    """Seed dataset as indices of papers within split years."""
    if split_name == "train":
        mask = years <= 2017  # 629,571 examples
    elif split_name == "validation":
        mask = years == 2018  # 64,879 examples
    elif split_name == "test":
        mask = years == 2019  # 41,939 examples
    else:
        raise ValueError(f"Unknown split_name: '{split_name}'")
    seed_indices = tf.squeeze(tf.where(mask), axis=-1)
    return tf.data.Dataset.from_tensor_slices(seed_indices)
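
# OGBN-MAG's split is temporal rather than random: train on papers up to
# 2017, validate on 2018, test on 2019, so evaluation papers are always newer
# than anything seen in training.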
# Define SubgraphDatasetProvider
class SubgraphDatasetProvider(runner.DatasetProvider):
    """Dataset Provider based on Sampler V2."""

    def __init__(self, full_graph_tensor: tfgnn.GraphTensor,
                 sizes: Mapping[str, int], split_name: str):
        super().__init__()
        self._years = tf.squeeze(full_graph_tensor.node_sets["paper"]["year"], axis=-1)
        self._sampling_model = create_sampling_model(full_graph_tensor, sizes)
        self._split_name = split_name
        self.input_graph_spec = self._sampling_model.output.spec

    def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
        """Creates TF dataset."""
        self._seed_dataset = seed_dataset(self._years, self._split_name)
        ds = self._seed_dataset.shard(
            num_shards=context.num_input_pipelines, index=context.input_pipeline_id)
        if self._split_name == "train":
            ds = ds.shuffle(NUM_TRAINING_SAMPLES).repeat()
        ds = ds.batch(128)
        ds = ds.map(
            self.sample,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        return ds.unbatch().prefetch(tf.data.AUTOTUNE)

    def sample(self, seeds: tf.Tensor) -> tfgnn.GraphTensor:
        seeds = tf.cast(seeds, tf.int64)
        batch_size = tf.size(seeds)
        seeds_ragged = tf.RaggedTensor.from_row_lengths(
            seeds, tf.ones([batch_size], tf.int64),
        )
        return self._sampling_model(seeds_ragged)
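
# Each provider shards the seed papers across input pipelines, shuffles and
# repeats only the training split, samples subgraphs 128 seeds at a time for
# throughput, then unbatches so the runner can re-batch to the global batch
# size itself.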
# Create dataset providers
st.write("Creating dataset providers...")
train_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, train_sampling_sizes, "train")
valid_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, validation_sample_sizes, "validation")
example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()
st.write("Dataset providers created successfully.")
# Define the model function
node_state_dim = 128
num_graph_updates = 4
message_dim = 128
state_dropout_rate = 0.2
l2_regularization = 1e-5
def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
    if node_set_name == "field_of_study":
        return tf.keras.layers.Embedding(50_000, 32)(node_set["hashed_id"])
    if node_set_name == "institution":
        return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
    if node_set_name == "paper":
        return tf.keras.layers.Dense(node_state_dim, activation="relu")(node_set["feat"])
    if node_set_name == "author":
        return node_set["empty_state"]
    raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
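
# Papers get their precomputed text embeddings ("feat") projected to the
# hidden size, while the smaller node sets get trainable hashed-id
# embeddings. Authors come with no usable input features, so their
# "empty_state" feature (an empty placeholder) leaves their state to be
# built up entirely by message passing.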
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
    inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
    graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
    for _ in range(num_graph_updates):
        graph = mt_albis.MtAlbisGraphUpdate(
            units=node_state_dim,
            message_dim=message_dim,
            attention_type="none",
            simple_conv_reduce_type="mean|sum",
            normalization_type="layer",
            next_state_type="residual",
            state_dropout_rate=state_dropout_rate,
            l2_regularization=l2_regularization,
        )(graph)
    # The runner task defined below reads out the seed paper's final state
    # and adds the classification head, so the model returns the GraphTensor.
    return tf.keras.Model(inputs, graph)
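
# Four graph updates let information flow up to four hops, matching the depth
# of the sampling spec above (seed paper -> cited paper -> author -> paper ->
# field of study).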
# Check for TPU/GPU and set the distribution strategy
st.write("Setting up strategy for distributed training...")
if tf.config.list_physical_devices("TPU"):
    st.write("Using TPUStrategy")
    strategy = runner.TPUStrategy("local")
    train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider)
    valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider)
elif tf.config.list_physical_devices("GPU"):
    st.write("Using MirroredStrategy for GPUs")
    strategy = tf.distribute.MirroredStrategy()
    train_padding = None
    valid_padding = None
else:
    st.write("Using default strategy")
    strategy = tf.distribute.get_strategy()
    train_padding = None
    valid_padding = None
st.write(f"Found {strategy.num_replicas_in_sync} replicas in sync")
# Define task: classify the seed ("root") paper of each sampled subgraph into
# one of 349 venues. The runner attaches the readout and prediction head.
st.write("Defining the task...")
task = runner.RootNodeMulticlassClassification(
    node_set_name="paper",
    num_classes=349,
    label_feature_name="paper_venue")
# Set hyperparameters
st.write("Setting hyperparameters...")
global_batch_size = 128
epochs = 10
initial_learning_rate = 0.001
steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate, steps_per_epoch * epochs)
optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)
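
# CosineDecay anneals the learning rate from 1e-3 toward 0 over the whole
# run: 629571 // 128 = 4918 steps per epoch, times 10 epochs = 49180 decay
# steps with the settings above.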
# Define trainer
st.write("Setting up the trainer...")
trainer = runner.KerasTrainer(
    strategy=strategy,
    model_dir="/tmp/gnn_model/",
    callbacks=None,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    restore_best_weights=False,
    checkpoint_every_n_steps="never",
    summarize_every_n_steps="never",
    backup_and_restore=False,
)
# Run training
st.write("Training the model...")
runner.run(
    task=task,
    model_fn=model_fn,
    trainer=trainer,
    optimizer_fn=optimizer_fn,
    epochs=epochs,
    global_batch_size=global_batch_size,
    train_ds_provider=train_ds_provider,
    train_padding=train_padding,
    valid_ds_provider=valid_ds_provider,
    valid_padding=valid_padding,
    gtspec=example_input_graph_spec,
)
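
# runner.run ties everything together: it builds the input pipelines from the
# providers, applies the padding (where set), instantiates the model via
# model_fn, lets the task attach its readout head plus loss and metrics, and
# runs Keras training under the chosen distribution strategy.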
st.write("Training completed successfully.")