eaglelandsonce committed on
Commit
b4be21e
·
verified ·
1 Parent(s): b1de838

Create 19_Graphs.py

Browse files
Files changed (1) hide show
  1. pages/19_Graphs.py +255 -0
pages/19_Graphs.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import functools
import os
from typing import Mapping

import streamlit as st
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner
from tensorflow_gnn.experimental import sampler
from tensorflow_gnn.models import mt_albis
import subprocess
import sys

# Opt in to legacy Keras (tf-keras), which TF-GNN requires.
# NOTE(review): to take effect this must be set *before* TensorFlow is first
# imported; TensorFlow is already imported above, so consider moving this
# assignment to the very top of the file.
os.environ['TF_USE_LEGACY_KERAS'] = '1'

# Set Streamlit title
st.title("Solving OGBN-MAG end-to-end with TF-GNN")

# Install necessary packages. The original line used IPython magic
# (`!pip install -q tensorflow-gnn || echo ...`), which is a SyntaxError in a
# plain Python/Streamlit script. Run pip through subprocess instead, and
# ignore failures just like the original `|| echo` did.
st.write("Installing necessary packages...")
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q", "tensorflow-gnn"],
    check=False,  # mirror the original "ignore package errors" behavior
)

st.write("Setting up the environment...")
tf.get_logger().setLevel('ERROR')  # keep TF info/warning spam out of the app
st.write(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")
# Dataset split sizes for OGBN-MAG. Papers are split by publication year
# (see seed_dataset below): training is <= 2017, validation is == 2018.
NUM_TRAINING_SAMPLES = 629571
NUM_VALIDATION_SAMPLES = 64879

# Pre-sampled OGBN-MAG graph tensor (one serialized tf.Example) and its
# GraphSchema, hosted on a public GCS bucket by TensorFlow.
GRAPH_TENSOR_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_tensor.example.pb'
SCHEMA_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_schema.pbtxt'
# Load the graph schema and graph tensor
st.write("Loading graph schema and tensor...")
graph_schema = tfgnn.read_schema(SCHEMA_FILE)
# Read the whole serialized tf.Example holding the full graph from GCS.
serialized_ogbn_mag_graph_tensor_string = tf.io.read_file(GRAPH_TENSOR_FILE)

# Parse the serialized example into one in-memory GraphTensor. int64 indices
# are requested explicitly; the seed/sampling code below also works in int64.
full_ogbn_mag_graph_tensor = tfgnn.parse_single_example(
    tfgnn.create_graph_spec_from_schema_pb(graph_schema, indices_dtype=tf.int64),
    serialized_ogbn_mag_graph_tensor_string)

st.write("Graph tensor loaded successfully.")
# Sampling fan-out: for every edge set involved in the sampling spec, keep
# at most 8 neighbors per hop during subgraph sampling.
train_sampling_sizes = dict.fromkeys(
    ("cites", "rev_writes", "writes", "affiliated_with", "has_topic"), 8)
# Validation samples with the same fan-out (independent copy).
validation_sample_sizes = dict(train_sampling_sizes)
# Create sampling model
def create_sampling_model(full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int]) -> tf.keras.Model:
    """Build a Keras model that samples subgraphs around seed "paper" nodes.

    Args:
      full_graph_tensor: The full OGBN-MAG graph held in memory; edges and
        node features are looked up from it during sampling.
      sizes: Maps each edge set name to the number of neighbors to sample
        along that edge set per hop.

    Returns:
      A Keras model mapping a ragged batch of seed paper indices to sampled
      subgraphs.

    NOTE(review): this closes over the module-level `graph_schema` rather than
    deriving a schema from `full_graph_tensor` — confirm before reusing this
    function with a different graph.
    """
    def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp):
        # Uniformly sample up to sizes[edge_set_name] edges, in memory.
        edge_set_name = sampling_op.edge_set_name
        sample_size = sizes[edge_set_name]
        return sampler.InMemUniformEdgesSampler.from_graph_tensor(
            full_graph_tensor, edge_set_name, sample_size=sample_size
        )

    def get_features(node_set_name: tfgnn.NodeSetName):
        # Look up node features by node index from the in-memory graph.
        return sampler.InMemIndexToFeaturesAccessor.from_graph_tensor(
            full_graph_tensor, node_set_name
        )

    # Spell out the sampling procedure in python: seed papers -> cited papers
    # -> their authors -> the authors' other papers and institutions, plus
    # fields of study for all papers collected so far.
    sampling_spec_builder = tfgnn.sampler.SamplingSpecBuilder(graph_schema)
    seed = sampling_spec_builder.seed("paper")
    papers_cited_from_seed = seed.sample(sizes["cites"], "cites")
    authors_of_papers = papers_cited_from_seed.join([seed]).sample(sizes["rev_writes"], "rev_writes")
    papers_by_authors = authors_of_papers.sample(sizes["writes"], "writes")
    # `institutions` and `fields_of_study` are never read again, but each
    # .sample() call extends the spec accumulated in sampling_spec_builder.
    institutions = authors_of_papers.sample(sizes["affiliated_with"], "affiliated_with")
    fields_of_study = seed.join([papers_cited_from_seed, papers_by_authors]).sample(sizes["has_topic"], "has_topic")
    sampling_spec = sampling_spec_builder.build()

    model = sampler.create_sampling_model_from_spec(
        graph_schema, sampling_spec, edge_sampler, get_features,
        seed_node_dtype=tf.int64)

    return model
# Create the sampling model
st.write("Creating sampling model...")
# Built here only as a smoke test / demo output; the dataset providers below
# each build their own sampling model.
sampling_model = create_sampling_model(full_ogbn_mag_graph_tensor, train_sampling_sizes)

st.write("Sampling model created successfully.")
# Define seed dataset function
def seed_dataset(years: tf.Tensor, split_name: str) -> tf.data.Dataset:
    """Seed dataset as indices of papers within split years.

    Papers are assigned to splits by publication year (the standard
    OGBN-MAG split): up to 2017 trains, 2018 validates, 2019 tests.

    Raises:
      ValueError: If `split_name` is not one of train/validation/test.
    """
    split_predicates = {
        "train": lambda y: y <= 2017,       # 629,571 examples
        "validation": lambda y: y == 2018,  # 64,879 examples
        "test": lambda y: y == 2019,        # 41,939 examples
    }
    if split_name not in split_predicates:
        raise ValueError(f"Unknown split_name: '{split_name}'")
    selected = split_predicates[split_name](years)
    # tf.where yields shape [n, 1]; squeeze to a flat index vector.
    indices = tf.squeeze(tf.where(selected), axis=-1)
    return tf.data.Dataset.from_tensor_slices(indices)
# Define SubgraphDatasetProvider
class SubgraphDatasetProvider(runner.DatasetProvider):
    """Dataset Provider based on Sampler V2.

    Yields one sampled subgraph (GraphTensor) per seed "paper" node of the
    chosen split, sampling on the fly from the in-memory full graph.
    """

    def __init__(self, full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int], split_name: str):
        super().__init__()
        # Publication year per paper, squeezed from [num_papers, 1] to
        # [num_papers]; used to assign papers to train/validation/test.
        self._years = tf.squeeze(full_graph_tensor.node_sets["paper"]["year"], axis=-1)
        self._sampling_model = create_sampling_model(full_graph_tensor, sizes)
        self._split_name = split_name
        # Spec of the (batched) sampling-model output; consumers unbatch it.
        self.input_graph_spec = self._sampling_model.output.spec

    def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
        """Creates TF dataset."""
        self._seed_dataset = seed_dataset(self._years, self._split_name)
        # Shard seeds across input pipelines for distributed input.
        ds = self._seed_dataset.shard(
            num_shards=context.num_input_pipelines, index=context.input_pipeline_id)
        if self._split_name == "train":
            # Buffer covers all training seeds -> a full shuffle; repeat
            # endlessly (the trainer bounds epochs via steps_per_epoch).
            ds = ds.shuffle(NUM_TRAINING_SAMPLES).repeat()
        # Sample subgraphs 128 seeds at a time for throughput, then unbatch
        # back to one subgraph per element.
        ds = ds.batch(128)
        ds = ds.map(
            functools.partial(self.sample),
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        return ds.unbatch().prefetch(tf.data.AUTOTUNE)

    def sample(self, seeds: tf.Tensor) -> tfgnn.GraphTensor:
        # The sampling model expects a ragged tensor with one row (of one
        # seed index) per subgraph to sample.
        seeds = tf.cast(seeds, tf.int64)
        batch_size = tf.size(seeds)
        seeds_ragged = tf.RaggedTensor.from_row_lengths(
            seeds, tf.ones([batch_size], tf.int64),
        )
        return self._sampling_model(seeds_ragged)
# Create dataset providers
st.write("Creating dataset providers...")
train_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, train_sampling_sizes, "train")
valid_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, validation_sample_sizes, "validation")
# Spec of a single (unbatched) sampled subgraph, used below as the model's
# input spec. NOTE(review): `_unbatch()` is a private API of the spec class —
# confirm it is still available in the installed tensorflow_gnn version.
example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()

st.write("Dataset providers created successfully.")
# Define the model function
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec,
             *,
             num_graph_updates: int = 4,
             node_state_dim: int = 128,
             message_dim: int = 128,
             state_dropout_rate: float = 0.2,
             l2_regularization: float = 6e-6) -> tf.keras.Model:
    """Build the GNN: initial node states plus `num_graph_updates` MtAlbis rounds.

    The original code referenced `set_initial_node_states` and several
    hyperparameter names as module globals that were never defined anywhere
    in this file (a NameError at call time). The hyperparameters are now
    keyword-only arguments with defaults taken from the TF-GNN OGBN-MAG
    tutorial (NOTE(review): tune as needed), and the initial-state mapping is
    defined here to match the features produced by `process_node_features`.
    The one-positional-argument interface expected by runner.run is unchanged.

    Args:
      graph_tensor_spec: Spec of one (unbatched) input subgraph.
      num_graph_updates: Number of message-passing rounds.
      node_state_dim: Hidden-state width per node.
      message_dim: Width of messages passed along edges.
      state_dropout_rate: Dropout rate applied to node states.
      l2_regularization: L2 weight penalty for the graph updates.

    Returns:
      A Keras model mapping an input GraphTensor to one with updated hidden
      states (the classification head is added by the runner's task).
    """

    def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
        # Turn the preprocessed features into each node set's initial state.
        if node_set_name == "field_of_study":
            return tf.keras.layers.Embedding(50_000, 32)(node_set["hashed_id"])
        if node_set_name == "institution":
            return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
        if node_set_name == "paper":
            return tf.keras.layers.Dense(node_state_dim)(node_set["feat"])
        if node_set_name == "author":
            # Authors carry no input features; start from an empty state.
            return node_set["empty_state"]
        raise KeyError(f"Unexpected node_set_name='{node_set_name}'")

    graph = inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
    graph = tfgnn.keras.layers.MapFeatures(
        node_sets_fn=set_initial_node_states)(graph)
    for i in range(num_graph_updates):
        graph = mt_albis.MtAlbisGraphUpdate(
            units=node_state_dim,
            message_dim=message_dim,
            receiver_tag=tfgnn.SOURCE,
            # Only the final round needs to update the "paper" seed nodes.
            node_set_names=None if i < num_graph_updates - 1 else ["paper"],
            simple_conv_reduce_type="mean|sum",
            state_dropout_rate=state_dropout_rate,
            l2_regularization=l2_regularization,
            normalization_type="layer",
            next_state_type="residual",
        )(graph)
    return tf.keras.Model(inputs, graph)
# Pick a distribution strategy by available accelerator, TPU first, then GPU,
# else the default single-device strategy. Only the TPU path needs statically
# padded graph tensors; the other paths train without padding.
st.write("Setting up strategy for distributed training...")
train_padding = None
valid_padding = None
if tf.config.list_physical_devices("TPU"):
    st.write("Using TPUStrategy")
    strategy = runner.TPUStrategy("local")
    train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider)
    valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider)
elif tf.config.list_physical_devices("GPU"):
    st.write("Using MirroredStrategy for GPUs")
    strategy = tf.distribute.MirroredStrategy()
else:
    st.write("Using default strategy")
    strategy = tf.distribute.get_strategy()

st.write(f"Found {strategy.num_replicas_in_sync} replicas in sync")
# Define task
st.write("Defining the task...")
# OGBN-MAG venue prediction: 349 classes. The label lives on the readout
# node set as feature "paper_venue" (see move_label_to_readout below).
task = runner.NodeMulticlassClassification(
    num_classes=349,
    label_feature_name="paper_venue")

# Set hyperparameters
st.write("Setting hyperparameters...")
global_batch_size = 128
epochs = 10
initial_learning_rate = 0.001

steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size
validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size
# Cosine-decay the learning rate over the entire training run.
learning_rate = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate, steps_per_epoch*epochs)
# The runner builds a fresh optimizer inside the strategy scope from this.
optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)
# Define trainer
st.write("Setting up the trainer...")
trainer = runner.KerasTrainer(
    strategy=strategy,
    model_dir="/tmp/gnn_model/",  # checkpoints/exports would land here
    callbacks=None,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    restore_best_weights=False,
    # Keep the demo light: no periodic checkpoints or TF summaries.
    checkpoint_every_n_steps="never",
    summarize_every_n_steps="never",
    backup_and_restore=False,
)
# Define feature processors
st.write("Defining feature processors...")
def process_node_features(node_set: tfgnn.NodeSet, node_set_name: str):
    """Map raw node features to the per-node-set features the model uses.

    Ids of fields of study and institutions are hashed into fixed-size
    vocabularies; papers keep their dense features plus the label; authors
    get an empty placeholder state.

    Raises:
      KeyError: For a node set name other than the four OGBN-MAG ones.
    """
    transforms = {
        "field_of_study": lambda ns: {"hashed_id": tf.keras.layers.Hashing(50_000)(ns["#id"])},
        "institution": lambda ns: {"hashed_id": tf.keras.layers.Hashing(6_500)(ns["#id"])},
        "paper": lambda ns: {"feat": ns["feat"], "label": ns["label"]},
        "author": lambda ns: {"empty_state": tfgnn.keras.layers.MakeEmptyFeature()(ns)},
    }
    if node_set_name not in transforms:
        raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
    return transforms[node_set_name](node_set)
def drop_all_features(_, **unused_kwargs):
    """Feature-mapping callback that discards every feature it is given."""
    return dict()
# Apply the per-node-set mapping above; context and edge features are
# dropped entirely (the model only consumes node features here).
process_features = tfgnn.keras.layers.MapFeatures(
    context_fn=drop_all_features,
    node_sets_fn=process_node_features,
    edge_sets_fn=drop_all_features)

# Add a "seed" readout: the first "paper" node of each sampled subgraph is
# the seed whose venue is predicted.
add_readout = tfgnn.keras.layers.AddReadoutFromFirstNode("seed", node_set_name="paper")
# Move the label off the input graph onto the readout (feature "paper_venue",
# matching the task above) so the model cannot see it as an input feature.
move_label_to_readout = tfgnn.keras.layers.StructuredReadoutIntoFeature(
    "seed", feature_name="label", new_feature_name="paper_venue", remove_input_feature=True)

feature_processors = [process_features, add_readout, move_label_to_readout]
# Run training
st.write("Training the model...")
# Orchestrates: dataset providers -> feature processing -> model build -> fit.
# The original call never passed feature_processors / train_padding /
# valid_padding even though all three are constructed above — without the
# feature processors the "paper_venue" label feature the task reads would
# never exist, and the TPU paddings would be dead code. Pass them through.
runner.run(
    task=task,
    model_fn=model_fn,
    trainer=trainer,
    optimizer_fn=optimizer_fn,
    epochs=epochs,
    global_batch_size=global_batch_size,
    train_ds_provider=train_ds_provider,
    train_padding=train_padding,
    valid_ds_provider=valid_ds_provider,
    valid_padding=valid_padding,
    gtspec=example_input_graph_spec,
    feature_processors=feature_processors,
)

st.write("Training completed successfully.")