import os

# TF-GNN requires the legacy Keras (v2) implementation; this environment
# variable must be set before TensorFlow is imported to take effect.
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import functools
from typing import Mapping

import streamlit as st
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner
from tensorflow_gnn.experimental import sampler
from tensorflow_gnn.models import mt_albis
# Set Streamlit title
st.title("Solving OGBN-MAG end-to-end with TF-GNN")

st.write("Setting up the environment...")
tf.get_logger().setLevel('ERROR')
st.write(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")

NUM_TRAINING_SAMPLES = 629571
NUM_VALIDATION_SAMPLES = 64879
GRAPH_TENSOR_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_tensor.example.pb'
SCHEMA_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_schema.pbtxt'
# Load the graph schema and graph tensor
st.write("Loading graph schema and tensor...")
graph_schema = tfgnn.read_schema(SCHEMA_FILE)
serialized_ogbn_mag_graph_tensor_string = tf.io.read_file(GRAPH_TENSOR_FILE)
full_ogbn_mag_graph_tensor = tfgnn.parse_single_example(
    tfgnn.create_graph_spec_from_schema_pb(graph_schema, indices_dtype=tf.int64),
    serialized_ogbn_mag_graph_tensor_string)
st.write("Graph tensor loaded successfully.")
# Define sampling sizes
train_sampling_sizes = {
    "cites": 8,
    "rev_writes": 8,
    "writes": 8,
    "affiliated_with": 8,
    "has_topic": 8,
}
validation_sample_sizes = train_sampling_sizes.copy()
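# Each entry in these dicts is a per-hop fanout: at every sampling step along
# that edge set, up to 8 neighbors are drawn uniformly at random per source node.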
# Create sampling model
def create_sampling_model(full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int]) -> tf.keras.Model:
    """Builds a Keras model that samples subgraphs around seed papers."""

    def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp):
        edge_set_name = sampling_op.edge_set_name
        sample_size = sizes[edge_set_name]
        return sampler.InMemUniformEdgesSampler.from_graph_tensor(
            full_graph_tensor, edge_set_name, sample_size=sample_size
        )

    def get_features(node_set_name: tfgnn.NodeSetName):
        return sampler.InMemIndexToFeaturesAccessor.from_graph_tensor(
            full_graph_tensor, node_set_name
        )

    # Spell out the sampling procedure in Python, using the module-level
    # graph_schema loaded above.
    sampling_spec_builder = tfgnn.sampler.SamplingSpecBuilder(graph_schema)
    seed = sampling_spec_builder.seed("paper")
    papers_cited_from_seed = seed.sample(sizes["cites"], "cites")
    authors_of_papers = papers_cited_from_seed.join([seed]).sample(sizes["rev_writes"], "rev_writes")
    papers_by_authors = authors_of_papers.sample(sizes["writes"], "writes")
    institutions = authors_of_papers.sample(sizes["affiliated_with"], "affiliated_with")
    fields_of_study = seed.join([papers_cited_from_seed, papers_by_authors]).sample(sizes["has_topic"], "has_topic")
    sampling_spec = sampling_spec_builder.build()

    model = sampler.create_sampling_model_from_spec(
        graph_schema, sampling_spec, edge_sampler, get_features,
        seed_node_dtype=tf.int64)
    return model
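# Illustrative usage (not executed here): the returned model maps a ragged
# batch of seed paper indices to one sampled subgraph per batch row, e.g.
#   demo_seeds = tf.ragged.constant([[0], [1]], dtype=tf.int64)
#   demo_subgraphs = sampling_model(demo_seeds)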
# Create the sampling model
st.write("Creating sampling model...")
sampling_model = create_sampling_model(full_ogbn_mag_graph_tensor, train_sampling_sizes)
st.write("Sampling model created successfully.")
# Define seed dataset function
def seed_dataset(years: tf.Tensor, split_name: str) -> tf.data.Dataset:
    """Seed dataset as indices of papers within split years."""
    if split_name == "train":
        mask = years <= 2017  # 629,571 examples
    elif split_name == "validation":
        mask = years == 2018  # 64,879 examples
    elif split_name == "test":
        mask = years == 2019  # 41,939 examples
    else:
        raise ValueError(f"Unknown split_name: '{split_name}'")
    seed_indices = tf.squeeze(tf.where(mask), axis=-1)
    return tf.data.Dataset.from_tensor_slices(seed_indices)
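# The resulting dataset yields scalar int64 indices into the "paper" node set,
# one element per seed paper in the chosen split.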
# Define SubgraphDatasetProvider
class SubgraphDatasetProvider(runner.DatasetProvider):
    """Dataset provider based on Sampler V2."""

    def __init__(self, full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int], split_name: str):
        super().__init__()
        self._years = tf.squeeze(full_graph_tensor.node_sets["paper"]["year"], axis=-1)
        self._sampling_model = create_sampling_model(full_graph_tensor, sizes)
        self._split_name = split_name
        self.input_graph_spec = self._sampling_model.output.spec

    def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
        """Creates the TF dataset of sampled subgraphs for this split."""
        ds = seed_dataset(self._years, self._split_name)
        ds = ds.shard(
            num_shards=context.num_input_pipelines, index=context.input_pipeline_id)
        if self._split_name == "train":
            ds = ds.shuffle(NUM_TRAINING_SAMPLES).repeat()
        ds = ds.batch(128)
        ds = ds.map(
            self.sample,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        return ds.unbatch().prefetch(tf.data.AUTOTUNE)

    def sample(self, seeds: tf.Tensor) -> tfgnn.GraphTensor:
        """Samples one subgraph per seed, with each seed as its own batch row."""
        seeds = tf.cast(seeds, tf.int64)
        batch_size = tf.size(seeds)
        seeds_ragged = tf.RaggedTensor.from_row_lengths(
            seeds, tf.ones([batch_size], tf.int64),
        )
        return self._sampling_model(seeds_ragged)
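# Note on the batch/sample/unbatch pattern above: the sampling model is
# vectorized, so sampling 128 seeds per call amortizes its overhead; unbatching
# afterwards restores a stream of one subgraph per seed, which the runner
# re-batches at training time.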
# Create dataset providers
st.write("Creating dataset providers...")
train_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, train_sampling_sizes, "train")
valid_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, validation_sample_sizes, "validation")
example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()
st.write("Dataset providers created successfully.")
# Model hyperparameters
node_state_dim = 128
num_graph_updates = 4
message_dim = 128
state_dropout_rate = 0.2
l2_regularization = 1e-5

def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
    """Maps raw node features to initial hidden states, per node set."""
    if node_set_name == "field_of_study":
        return tf.keras.layers.Embedding(50_000, 32)(node_set["hashed_id"])
    if node_set_name == "institution":
        return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
    if node_set_name == "paper":
        return tf.keras.layers.Dense(node_state_dim, activation="relu")(node_set["feat"])
    if node_set_name == "author":
        return node_set["empty_state"]
    raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
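# "author" nodes carry no input features, so they start from the dataset's
# "empty_state" feature and pick up information purely through message passing.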
# Define the model function
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
    """Builds the GNN: initial states, MtAlbis graph updates, venue classifier."""
    inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
    graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
    for _ in range(num_graph_updates):
        graph = mt_albis.MtAlbisGraphUpdate(
            units=node_state_dim,
            message_dim=message_dim,
            attention_type="none",
            simple_conv_reduce_type="mean|sum",
            normalization_type="layer",
            next_state_type="residual",
            state_dropout_rate=state_dropout_rate,
            l2_regularization=l2_regularization,
        )(graph)
    # MapFeatures and MtAlbis store node states under tfgnn.HIDDEN_STATE.
    paper_state = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name=tfgnn.HIDDEN_STATE)(graph)
    paper_state = tf.keras.layers.Dense(349, activation="softmax")(paper_state)
    return tf.keras.Model(inputs, paper_state)
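# 349 is the number of venue classes in OGBN-MAG, matching the task definition
# below.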
# Check for TPU/GPU and set the distribution strategy
st.write("Setting up strategy for distributed training...")
if tf.config.list_physical_devices("TPU"):
    st.write("Using TPUStrategy")
    strategy = runner.TPUStrategy("local")
    train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider)
    valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider)
elif tf.config.list_physical_devices("GPU"):
    st.write("Using MirroredStrategy for GPUs")
    strategy = tf.distribute.MirroredStrategy()
    train_padding = None
    valid_padding = None
else:
    st.write("Using default strategy")
    strategy = tf.distribute.get_strategy()
    train_padding = None
    valid_padding = None
st.write(f"Found {strategy.num_replicas_in_sync} replicas in sync")
# Define task
st.write("Defining the task...")
task = runner.NodeMulticlassClassification(
    num_classes=349,
    label_feature_name="paper_venue")
# Set hyperparameters
st.write("Setting hyperparameters...")
global_batch_size = 128
epochs = 10
initial_learning_rate = 0.001
steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate, steps_per_epoch * epochs)
optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)
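# Worked out for this configuration: 629,571 // 128 = 4,918 steps per epoch,
# so the cosine schedule decays over 4,918 * 10 = 49,180 total steps.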
# Define trainer
st.write("Setting up the trainer...")
trainer = runner.KerasTrainer(
    strategy=strategy,
    model_dir="/tmp/gnn_model/",
    callbacks=None,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    restore_best_weights=False,
    checkpoint_every_n_steps="never",
    summarize_every_n_steps="never",
    backup_and_restore=False,
)
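# Checkpointing and TensorBoard summaries are disabled ("never") to keep this
# demo lightweight; for real training runs you would likely re-enable both.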
# Run training (padding specs are passed through so the TPU branch takes effect)
st.write("Training the model...")
runner.run(
    task=task,
    model_fn=model_fn,
    trainer=trainer,
    optimizer_fn=optimizer_fn,
    epochs=epochs,
    global_batch_size=global_batch_size,
    train_ds_provider=train_ds_provider,
    train_padding=train_padding,
    valid_ds_provider=valid_ds_provider,
    valid_padding=valid_padding,
    gtspec=example_input_graph_spec,
)
st.write("Training completed successfully.")