# TensorFlowClass/pages/19_Graphs3.py
import functools
import os
from typing import Mapping

# TF-GNN needs the legacy Keras 2 (tf_keras). This must be set before
# TensorFlow is imported so that tf.keras resolves to it rather than Keras 3.
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import streamlit as st
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner
from tensorflow_gnn.experimental import sampler
from tensorflow_gnn.models import mt_albis
# Set Streamlit title
st.title("Solving OGBN-MAG end-to-end with TF-GNN")
st.write("Setting up the environment...")
tf.get_logger().setLevel('ERROR')
st.write(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")
NUM_TRAINING_SAMPLES = 629571
NUM_VALIDATION_SAMPLES = 64879
GRAPH_TENSOR_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_tensor.example.pb'
SCHEMA_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_schema.pbtxt'
# Load the graph schema and graph tensor
st.write("Loading graph schema and tensor...")
graph_schema = tfgnn.read_schema(SCHEMA_FILE)
serialized_ogbn_mag_graph_tensor_string = tf.io.read_file(GRAPH_TENSOR_FILE)
full_ogbn_mag_graph_tensor = tfgnn.parse_single_example(
    tfgnn.create_graph_spec_from_schema_pb(graph_schema, indices_dtype=tf.int64),
    serialized_ogbn_mag_graph_tensor_string)
st.write("Graph tensor loaded successfully.")
# Define sampling sizes
train_sampling_sizes = {
    "cites": 8,
    "rev_writes": 8,
    "writes": 8,
    "affiliated_with": 8,
    "has_topic": 8,
}
validation_sample_sizes = train_sampling_sizes.copy()
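# These are per-edge-set fan-outs for neighborhood sampling: at most 8 neighbors
# are drawn uniformly along each listed edge set ("cites" between papers,
# "rev_writes"/"writes" between papers and authors, "affiliated_with" to
# institutions, "has_topic" to fields of study). Smaller fan-outs give smaller
# sampled subgraphs and faster training at the cost of less context per seed.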
# Create sampling model
def create_sampling_model(full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int]) -> tf.keras.Model:
    def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp):
        edge_set_name = sampling_op.edge_set_name
        sample_size = sizes[edge_set_name]
        return sampler.InMemUniformEdgesSampler.from_graph_tensor(
            full_graph_tensor, edge_set_name, sample_size=sample_size
        )

    def get_features(node_set_name: tfgnn.NodeSetName):
        return sampler.InMemIndexToFeaturesAccessor.from_graph_tensor(
            full_graph_tensor, node_set_name
        )

    # Spell out the sampling procedure in Python. The intermediate variables are
    # not used again directly; each .sample()/.join() call records another step
    # in the spec that the builder assembles below.
    sampling_spec_builder = tfgnn.sampler.SamplingSpecBuilder(graph_schema)
    seed = sampling_spec_builder.seed("paper")
    papers_cited_from_seed = seed.sample(sizes["cites"], "cites")
    authors_of_papers = papers_cited_from_seed.join([seed]).sample(sizes["rev_writes"], "rev_writes")
    papers_by_authors = authors_of_papers.sample(sizes["writes"], "writes")
    institutions = authors_of_papers.sample(sizes["affiliated_with"], "affiliated_with")
    fields_of_study = seed.join([papers_cited_from_seed, papers_by_authors]).sample(sizes["has_topic"], "has_topic")
    sampling_spec = sampling_spec_builder.build()

    model = sampler.create_sampling_model_from_spec(
        graph_schema, sampling_spec, edge_sampler, get_features,
        seed_node_dtype=tf.int64)
    return model
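# The builder above records a fixed sampling "program": start from seed papers,
# take cited papers, take the authors of both (via the reverse "writes" edges),
# then those authors' other papers, their institutions, and the fields of study
# of all sampled papers. A rough sketch of calling the model directly, with
# hypothetical seed indices (kept as a comment so it does not run in the app):
#   demo_model = create_sampling_model(full_ogbn_mag_graph_tensor, train_sampling_sizes)
#   demo_seeds = tf.ragged.constant([[0], [1]], dtype=tf.int64)  # two seed papers
#   demo_subgraphs = demo_model(demo_seeds)  # one sampled subgraph per seed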
# Create the sampling model
st.write("Creating sampling model...")
sampling_model = create_sampling_model(full_ogbn_mag_graph_tensor, train_sampling_sizes)
st.write("Sampling model created successfully.")
# Define seed dataset function
def seed_dataset(years: tf.Tensor, split_name: str) -> tf.data.Dataset:
    """Seed dataset as indices of papers within split years."""
    if split_name == "train":
        mask = years <= 2017  # 629,571 examples
    elif split_name == "validation":
        mask = years == 2018  # 64,879 examples
    elif split_name == "test":
        mask = years == 2019  # 41,939 examples
    else:
        raise ValueError(f"Unknown split_name: '{split_name}'")
    seed_indices = tf.squeeze(tf.where(mask), axis=-1)
    return tf.data.Dataset.from_tensor_slices(seed_indices)
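# The split follows the standard OGBN-MAG protocol by publication year:
# papers up to 2017 for training, 2018 for validation, 2019 for testing.
# Each dataset element is a scalar int64 index into the "paper" node set,
# which later becomes the seed of one sampled subgraph.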
# Define SubgraphDatasetProvider
class SubgraphDatasetProvider(runner.DatasetProvider):
    """Dataset Provider based on Sampler V2."""

    def __init__(self, full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int], split_name: str):
        super().__init__()
        self._years = tf.squeeze(full_graph_tensor.node_sets["paper"]["year"], axis=-1)
        self._sampling_model = create_sampling_model(full_graph_tensor, sizes)
        self._split_name = split_name
        self.input_graph_spec = self._sampling_model.output.spec

    def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
        """Creates TF dataset."""
        self._seed_dataset = seed_dataset(self._years, self._split_name)
        ds = self._seed_dataset.shard(
            num_shards=context.num_input_pipelines, index=context.input_pipeline_id)
        if self._split_name == "train":
            ds = ds.shuffle(NUM_TRAINING_SAMPLES).repeat()
        ds = ds.batch(128)
        ds = ds.map(
            functools.partial(self.sample),
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        return ds.unbatch().prefetch(tf.data.AUTOTUNE)

    def sample(self, seeds: tf.Tensor) -> tfgnn.GraphTensor:
        seeds = tf.cast(seeds, tf.int64)
        batch_size = tf.size(seeds)
        seeds_ragged = tf.RaggedTensor.from_row_lengths(
            seeds, tf.ones([batch_size], tf.int64),
        )
        return self._sampling_model(seeds_ragged)
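# End to end, get_dataset() turns paper indices into a stream of sampled
# subgraphs: shard across input pipelines, shuffle and repeat for training,
# batch 128 seeds, run the in-process sampling model on each batch, and
# unbatch so that downstream code sees one GraphTensor per seed paper.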
# Create dataset providers
st.write("Creating dataset providers...")
train_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, train_sampling_sizes, "train")
valid_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, validation_sample_sizes, "validation")
example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()
st.write("Dataset providers created successfully.")
# Model hyperparameters and the model-building function
node_state_dim = 128
num_graph_updates = 4
message_dim = 128
state_dropout_rate = 0.2
l2_regularization = 1e-5
def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
    if node_set_name == "field_of_study":
        return tf.keras.layers.Embedding(50_000, 32)(node_set["hashed_id"])
    if node_set_name == "institution":
        return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
    if node_set_name == "paper":
        return tf.keras.layers.Dense(node_state_dim, activation="relu")(node_set["feat"])
    if node_set_name == "author":
        return node_set["empty_state"]
    raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
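# Each node set gets a different initial hidden state: "paper" nodes project
# their dense "feat" features to node_state_dim; "field_of_study" and
# "institution" only carry hashed IDs, so they get trainable embedding tables;
# "author" nodes have no input features and start from the placeholder
# "empty_state", picking up information purely through message passing.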
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
    inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
    graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
    for _ in range(num_graph_updates):
        graph = mt_albis.MtAlbisGraphUpdate(
            units=node_state_dim,
            message_dim=message_dim,
            attention_type="none",
            simple_conv_reduce_type="mean|sum",
            normalization_type="layer",
            next_state_type="residual",
            state_dropout_rate=state_dropout_rate,
            l2_regularization=l2_regularization,
        )(graph)
    # Return a GraphTensor-in, GraphTensor-out model; the runner's task below
    # appends the readout of the seed paper's state and the classification head.
    return tf.keras.Model(inputs, graph)
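# The model maps a GraphTensor to a GraphTensor: MapFeatures builds the initial
# node states, and each MtAlbisGraphUpdate performs one round of message
# passing with mean/sum pooling, layer normalization, dropout, and residual
# state updates. With num_graph_updates = 4, information can reach the seed
# paper from nodes up to four hops away in the sampled subgraph.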
# Check for TPU/GPU and set the distribution strategy
st.write("Setting up strategy for distributed training...")
if tf.config.list_physical_devices("TPU"):
    st.write("Using TPUStrategy")
    strategy = runner.TPUStrategy("local")
    train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider)
    valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider)
elif tf.config.list_physical_devices("GPU"):
    st.write("Using MirroredStrategy for GPUs")
    strategy = tf.distribute.MirroredStrategy()
    train_padding = None
    valid_padding = None
else:
    st.write("Using default strategy")
    strategy = tf.distribute.get_strategy()
    train_padding = None
    valid_padding = None
st.write(f"Found {strategy.num_replicas_in_sync} replicas in sync")
# Define task
st.write("Defining the task...")
task = runner.NodeMulticlassClassification(
    num_classes=349,
    label_feature_name="paper_venue")
# Set hyperparameters
st.write("Setting hyperparameters...")
global_batch_size = 128
epochs = 10
initial_learning_rate = 0.001
steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate, steps_per_epoch * epochs)
optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)
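# With the values above, steps_per_epoch = 629_571 // 128 = 4_918 and
# validation_steps = 64_879 // 128 = 506, so the cosine schedule decays the
# learning rate from 1e-3 towards 0 over 4_918 * 10 = 49_180 training steps.
# optimizer_fn is a factory rather than an optimizer instance, so that the
# runner can construct the optimizer itself (e.g. under the strategy's scope).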
# Define trainer
st.write("Setting up the trainer...")
trainer = runner.KerasTrainer(
    strategy=strategy,
    model_dir="/tmp/gnn_model/",
    callbacks=None,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    restore_best_weights=False,
    checkpoint_every_n_steps="never",
    summarize_every_n_steps="never",
    backup_and_restore=False,
)
# Run training
st.write("Training the model...")
runner.run(
    task=task,
    model_fn=model_fn,
    trainer=trainer,
    optimizer_fn=optimizer_fn,
    epochs=epochs,
    global_batch_size=global_batch_size,
    train_ds_provider=train_ds_provider,
    train_padding=train_padding,  # None except on TPU, where padding is required
    valid_ds_provider=valid_ds_provider,
    valid_padding=valid_padding,
    gtspec=example_input_graph_spec,
)
st.write("Training completed successfully.")