"""Streamlit front end that trains a TF-GNN model on OGBN-MAG end to end.

The script loads a pre-sampled copy of the OGBN-MAG graph into memory, builds
an in-memory subgraph sampler around seed papers, and trains an MtAlbis GNN
with the TF-GNN runner, reporting progress through ``st.write``.
"""
import functools
import os
from typing import Mapping

# TensorFlow chooses between Keras 2 (the tf_keras package) and Keras 3 when it
# is first imported, and TF-GNN needs Keras 2. This variable therefore has to
# be set *before* `import tensorflow`; assigning it after the import is a no-op.
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import streamlit as st  # noqa: E402
import tensorflow as tf  # noqa: E402
import tensorflow_gnn as tfgnn  # noqa: E402
from tensorflow_gnn import runner  # noqa: E402
from tensorflow_gnn.experimental import sampler  # noqa: E402
from tensorflow_gnn.models import mt_albis  # noqa: E402

# Set Streamlit title
st.title("Solving OGBN-MAG end-to-end with TF-GNN")
st.write("Setting up the environment...")
tf.get_logger().setLevel('ERROR')
st.write(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")

# Split sizes; they size the shuffle buffer and the per-epoch step counts.
NUM_TRAINING_SAMPLES = 629571
NUM_VALIDATION_SAMPLES = 64879

GRAPH_TENSOR_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_tensor.example.pb'
SCHEMA_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_schema.pbtxt'

# Load the graph schema and the full graph as a single in-memory GraphTensor.
st.write("Loading graph schema and tensor...")
graph_schema = tfgnn.read_schema(SCHEMA_FILE)
serialized_ogbn_mag_graph_tensor_string = tf.io.read_file(GRAPH_TENSOR_FILE)
full_ogbn_mag_graph_tensor = tfgnn.parse_single_example(
    tfgnn.create_graph_spec_from_schema_pb(graph_schema, indices_dtype=tf.int64),
    serialized_ogbn_mag_graph_tensor_string)
st.write("Graph tensor loaded successfully.")

# Number of neighbors sampled per hop, keyed by edge set name.
train_sampling_sizes = {
    "cites": 8,
    "rev_writes": 8,
    "writes": 8,
    "affiliated_with": 8,
    "has_topic": 8,
}
validation_sample_sizes = train_sampling_sizes.copy()


def create_sampling_model(full_graph_tensor: tfgnn.GraphTensor,
                          sizes: Mapping[str, int]) -> tf.keras.Model:
    """Returns a Keras model that samples subgraphs around seed papers.

    Args:
      full_graph_tensor: the complete graph to sample edges and features from.
      sizes: sample size per edge set; must contain every edge set name used
        by the sampling spec below ("cites", "rev_writes", "writes",
        "affiliated_with", "has_topic").

    Note: also reads the module-level ``graph_schema``.
    """

    def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp):
        # Uniform in-memory sampler for the edge set named by this op.
        edge_set_name = sampling_op.edge_set_name
        sample_size = sizes[edge_set_name]
        return sampler.InMemUniformEdgesSampler.from_graph_tensor(
            full_graph_tensor, edge_set_name, sample_size=sample_size
        )

    def get_features(node_set_name: tfgnn.NodeSetName):
        # Looks up node features by node index from the in-memory graph.
        return sampler.InMemIndexToFeaturesAccessor.from_graph_tensor(
            full_graph_tensor, node_set_name
        )

    # Spell out the sampling procedure in python. `institutions` and
    # `fields_of_study` are not read again; the `.sample()` calls are kept for
    # their effect on `sampling_spec_builder`.
    sampling_spec_builder = tfgnn.sampler.SamplingSpecBuilder(graph_schema)
    seed = sampling_spec_builder.seed("paper")
    papers_cited_from_seed = seed.sample(sizes["cites"], "cites")
    authors_of_papers = papers_cited_from_seed.join([seed]).sample(
        sizes["rev_writes"], "rev_writes")
    papers_by_authors = authors_of_papers.sample(sizes["writes"], "writes")
    institutions = authors_of_papers.sample(
        sizes["affiliated_with"], "affiliated_with")
    fields_of_study = seed.join([papers_cited_from_seed, papers_by_authors]).sample(
        sizes["has_topic"], "has_topic")
    sampling_spec = sampling_spec_builder.build()

    model = sampler.create_sampling_model_from_spec(
        graph_schema, sampling_spec, edge_sampler, get_features,
        seed_node_dtype=tf.int64)
    return model


# Create the sampling model. NOTE: the dataset providers below build their own
# sampling models; this one only confirms that construction succeeds.
st.write("Creating sampling model...")
sampling_model = create_sampling_model(full_ogbn_mag_graph_tensor, train_sampling_sizes)
st.write("Sampling model created successfully.")


def seed_dataset(years: tf.Tensor, split_name: str) -> tf.data.Dataset:
    """Seed dataset as indices of papers within split years.

    Args:
      years: publication year per paper node (1-D after the caller's squeeze).
      split_name: one of "train" (<= 2017), "validation" (2018), "test" (2019).

    Raises:
      ValueError: for any other ``split_name``.
    """
    if split_name == "train":
        mask = years <= 2017  # 629,571 examples
    elif split_name == "validation":
        mask = years == 2018  # 64,879 examples
    elif split_name == "test":
        mask = years == 2019  # 41,939 examples
    else:
        raise ValueError(f"Unknown split_name: '{split_name}'")
    # tf.where on a 1-D mask yields shape [N, 1]; drop the trailing axis.
    seed_indices = tf.squeeze(tf.where(mask), axis=-1)
    return tf.data.Dataset.from_tensor_slices(seed_indices)


class SubgraphDatasetProvider(runner.DatasetProvider):
    """Dataset Provider based on Sampler V2."""

    def __init__(self, full_graph_tensor: tfgnn.GraphTensor,
                 sizes: Mapping[str, int], split_name: str):
        super().__init__()
        # Publication years of all papers, used to choose the split's seeds.
        self._years = tf.squeeze(full_graph_tensor.node_sets["paper"]["year"], axis=-1)
        self._sampling_model = create_sampling_model(full_graph_tensor, sizes)
        self._split_name = split_name
        self.input_graph_spec = self._sampling_model.output.spec

    def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
        """Creates TF dataset of one sampled subgraph per seed paper."""
        self._seed_dataset = seed_dataset(self._years, self._split_name)
        ds = self._seed_dataset.shard(
            num_shards=context.num_input_pipelines, index=context.input_pipeline_id)
        if self._split_name == "train":
            ds = ds.shuffle(NUM_TRAINING_SAMPLES).repeat()
        # Sample 128 subgraphs per call; larger batches are faster but use
        # more memory. The batch is undone again by `unbatch()` below.
        ds = ds.batch(128)
        ds = ds.map(
            self.sample,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        return ds.unbatch().prefetch(tf.data.AUTOTUNE)

    def sample(self, seeds: tf.Tensor) -> tfgnn.GraphTensor:
        """Samples one subgraph per seed, passed as [[seed1], [seed2], ...]."""
        seeds = tf.cast(seeds, tf.int64)
        batch_size = tf.size(seeds)
        seeds_ragged = tf.RaggedTensor.from_row_lengths(
            seeds, tf.ones([batch_size], tf.int64),
        )
        return self._sampling_model(seeds_ragged)


# Create dataset providers
st.write("Creating dataset providers...")
train_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, train_sampling_sizes, "train")
valid_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, validation_sample_sizes, "validation")
example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()
st.write("Dataset providers created successfully.")

# Model hyperparameters.
node_state_dim = 128
num_graph_updates = 4
message_dim = 128
state_dropout_rate = 0.2
l2_regularization = 1e-5


def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
    """Returns the initial hidden state for each node set from its features.

    Raises:
      KeyError: for a node set name not handled below.
    """
    if node_set_name == "field_of_study":
        return tf.keras.layers.Embedding(50_000, 32)(node_set["hashed_id"])
    if node_set_name == "institution":
        return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
    if node_set_name == "paper":
        return tf.keras.layers.Dense(node_state_dim, activation="relu")(node_set["feat"])
    if node_set_name == "author":
        return node_set["empty_state"]
    raise KeyError(f"Unexpected node_set_name='{node_set_name}'")


def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec) -> tf.keras.Model:
    """Builds the GNN base model: GraphTensor in, updated GraphTensor out.

    The runner's task (``NodeMulticlassClassification``) appends the seed-node
    readout and the 349-way logits head itself, so this model must return the
    graph rather than class probabilities.
    """
    inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
    graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
    for _ in range(num_graph_updates):
        graph = mt_albis.MtAlbisGraphUpdate(
            units=node_state_dim,
            message_dim=message_dim,
            # Subgraphs are expanded from the seed paper along each edge's
            # source -> target direction. Receiving at SOURCE sends messages
            # back towards the seed.
            receiver_tag=tfgnn.SOURCE,
            attention_type="none",
            simple_conv_reduce_type="mean|sum",
            normalization_type="layer",
            next_state_type="residual",
            state_dropout_rate=state_dropout_rate,
            l2_regularization=l2_regularization,
        )(graph)
    return tf.keras.Model(inputs, graph)


# Check for TPU/ GPU and set strategy
st.write("Setting up strategy for distributed training...")
if tf.config.list_physical_devices("TPU"):
    st.write("Using TPUStrategy")
    strategy = runner.TPUStrategy("local")
    # TPUs need fixed-size inputs; these paddings are passed to runner.run.
    train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider)
    valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider)
elif tf.config.list_physical_devices("GPU"):
    st.write("Using MirroredStrategy for GPUs")
    strategy = tf.distribute.MirroredStrategy()
    train_padding = None
    valid_padding = None
else:
    st.write("Using default strategy")
    strategy = tf.distribute.get_strategy()
    train_padding = None
    valid_padding = None
st.write(f"Found {strategy.num_replicas_in_sync} replicas in sync")

# Define task: the label is the paper's venue, one of 349 classes.
st.write("Defining the task...")
task = runner.NodeMulticlassClassification(
    num_classes=349,
    label_feature_name="paper_venue")

# Set hyperparameters
st.write("Setting hyperparameters...")
global_batch_size = 128
epochs = 10
initial_learning_rate = 0.001
steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate, steps_per_epoch * epochs)
optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)

# Define trainer
st.write("Setting up the trainer...")
trainer = runner.KerasTrainer(
    strategy=strategy,
    model_dir="/tmp/gnn_model/",
    callbacks=None,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    restore_best_weights=False,
    checkpoint_every_n_steps="never",
    summarize_every_n_steps="never",
    backup_and_restore=False,
)

# Run training (blocks until all epochs finish).
st.write("Training the model...")
runner.run(
    task=task,
    model_fn=model_fn,
    trainer=trainer,
    optimizer_fn=optimizer_fn,
    epochs=epochs,
    global_batch_size=global_batch_size,
    train_ds_provider=train_ds_provider,
    valid_ds_provider=valid_ds_provider,
    gtspec=example_input_graph_spec,
    # Non-None only on TPU, where fixed shapes are required.
    train_padding=train_padding,
    valid_padding=valid_padding,
)
st.write("Training completed successfully.")