eaglelandsonce committed
Commit 3c86a58 · verified · 1 Parent(s): 40eeee3

Delete pages/19_Graphs3.py

Files changed (1):
  1. pages/19_Graphs3.py +0 -244
pages/19_Graphs3.py DELETED
@@ -1,244 +0,0 @@
- import os
-
- # Opt in to legacy Keras (Keras 2). This must be set before TensorFlow is
- # imported, otherwise tf.keras has already been resolved and the flag is ignored.
- os.environ['TF_USE_LEGACY_KERAS'] = '1'
-
- import functools
- from typing import Mapping
-
- import streamlit as st
- import tensorflow as tf
- import tensorflow_gnn as tfgnn
- from tensorflow_gnn import runner
- from tensorflow_gnn.experimental import sampler
- from tensorflow_gnn.models import mt_albis
-
- # Set Streamlit title
- st.title("Solving OGBN-MAG end-to-end with TF-GNN")
-
- st.write("Setting up the environment...")
- tf.get_logger().setLevel('ERROR')
- st.write(f"Running TF-GNN {tfgnn.__version__} under TensorFlow {tf.__version__}.")
-
- NUM_TRAINING_SAMPLES = 629571
- NUM_VALIDATION_SAMPLES = 64879
-
- GRAPH_TENSOR_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_tensor.example.pb'
- SCHEMA_FILE = 'gs://download.tensorflow.org/data/ogbn-mag/sampled/v2/graph_schema.pbtxt'
-
- # Load the graph schema and graph tensor
- st.write("Loading graph schema and tensor...")
- graph_schema = tfgnn.read_schema(SCHEMA_FILE)
- serialized_ogbn_mag_graph_tensor_string = tf.io.read_file(GRAPH_TENSOR_FILE)
-
- full_ogbn_mag_graph_tensor = tfgnn.parse_single_example(
-     tfgnn.create_graph_spec_from_schema_pb(graph_schema, indices_dtype=tf.int64),
-     serialized_ogbn_mag_graph_tensor_string)
-
- st.write("Graph tensor loaded successfully.")
-
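A minimal sketch (not from the deleted file) to sanity-check the loaded graph, using only the `full_ogbn_mag_graph_tensor` defined above:

    # Print the node count of every node set in the in-memory graph.
    for node_set_name, node_set in full_ogbn_mag_graph_tensor.node_sets.items():
        st.write(f"{node_set_name}: {node_set.sizes.numpy()} nodes")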
- # Define sampling sizes per edge set
- train_sampling_sizes = {
-     "cites": 8,
-     "rev_writes": 8,
-     "writes": 8,
-     "affiliated_with": 8,
-     "has_topic": 8,
- }
- validation_sample_sizes = train_sampling_sizes.copy()
-
- # Create sampling model
- def create_sampling_model(full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int]) -> tf.keras.Model:
-     def edge_sampler(sampling_op: tfgnn.sampler.SamplingOp):
-         edge_set_name = sampling_op.edge_set_name
-         sample_size = sizes[edge_set_name]
-         return sampler.InMemUniformEdgesSampler.from_graph_tensor(
-             full_graph_tensor, edge_set_name, sample_size=sample_size
-         )
-
-     def get_features(node_set_name: tfgnn.NodeSetName):
-         return sampler.InMemIndexToFeaturesAccessor.from_graph_tensor(
-             full_graph_tensor, node_set_name
-         )
-
-     # Spell out the sampling procedure in Python
-     sampling_spec_builder = tfgnn.sampler.SamplingSpecBuilder(graph_schema)
-     seed = sampling_spec_builder.seed("paper")
-     papers_cited_from_seed = seed.sample(sizes["cites"], "cites")
-     authors_of_papers = papers_cited_from_seed.join([seed]).sample(sizes["rev_writes"], "rev_writes")
-     papers_by_authors = authors_of_papers.sample(sizes["writes"], "writes")
-     institutions = authors_of_papers.sample(sizes["affiliated_with"], "affiliated_with")
-     fields_of_study = seed.join([papers_cited_from_seed, papers_by_authors]).sample(sizes["has_topic"], "has_topic")
-     sampling_spec = sampling_spec_builder.build()
-
-     model = sampler.create_sampling_model_from_spec(
-         graph_schema, sampling_spec, edge_sampler, get_features,
-         seed_node_dtype=tf.int64)
-
-     return model
-
- # Create the sampling model
- st.write("Creating sampling model...")
- sampling_model = create_sampling_model(full_ogbn_mag_graph_tensor, train_sampling_sizes)
-
- st.write("Sampling model created successfully.")
-
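A minimal sketch (not from the deleted file) that smoke-tests the sampling model on one arbitrary seed paper (index 0); seeds are ragged, one row per example, matching how SubgraphDatasetProvider calls the model below:

    # Sample a subgraph rooted at paper 0 and show how many papers it contains.
    seeds = tf.ragged.constant([[0]], dtype=tf.int64)
    subgraph = sampling_model(seeds)
    st.write(subgraph.node_sets["paper"].sizes)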
- # Define seed dataset function
- def seed_dataset(years: tf.Tensor, split_name: str) -> tf.data.Dataset:
-     """Seed dataset as indices of papers within split years."""
-     if split_name == "train":
-         mask = years <= 2017  # 629,571 examples
-     elif split_name == "validation":
-         mask = years == 2018  # 64,879 examples
-     elif split_name == "test":
-         mask = years == 2019  # 41,939 examples
-     else:
-         raise ValueError(f"Unknown split_name: '{split_name}'")
-     seed_indices = tf.squeeze(tf.where(mask), axis=-1)
-     return tf.data.Dataset.from_tensor_slices(seed_indices)
-
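A minimal sketch (not from the deleted file) that verifies the split sizes quoted in the comments above, deriving `years` the same way SubgraphDatasetProvider does below:

    # Count seed papers per split; expect 629,571 / 64,879 / 41,939.
    years = tf.squeeze(full_ogbn_mag_graph_tensor.node_sets["paper"]["year"], axis=-1)
    for split in ("train", "validation", "test"):
        n = seed_dataset(years, split).cardinality().numpy()
        st.write(f"{split}: {n} seeds")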
- # Define SubgraphDatasetProvider
- class SubgraphDatasetProvider(runner.DatasetProvider):
-     """Dataset provider based on Sampler V2."""
-
-     def __init__(self, full_graph_tensor: tfgnn.GraphTensor, sizes: Mapping[str, int], split_name: str):
-         super().__init__()
-         self._years = tf.squeeze(full_graph_tensor.node_sets["paper"]["year"], axis=-1)
-         self._sampling_model = create_sampling_model(full_graph_tensor, sizes)
-         self._split_name = split_name
-         self.input_graph_spec = self._sampling_model.output.spec
-
-     def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
-         """Creates the TF dataset of sampled subgraphs."""
-         self._seed_dataset = seed_dataset(self._years, self._split_name)
-         ds = self._seed_dataset.shard(
-             num_shards=context.num_input_pipelines, index=context.input_pipeline_id)
-         if self._split_name == "train":
-             ds = ds.shuffle(NUM_TRAINING_SAMPLES).repeat()
-         ds = ds.batch(128)
-         ds = ds.map(
-             self.sample,
-             num_parallel_calls=tf.data.AUTOTUNE,
-             deterministic=False,
-         )
-         return ds.unbatch().prefetch(tf.data.AUTOTUNE)
-
-     def sample(self, seeds: tf.Tensor) -> tfgnn.GraphTensor:
-         seeds = tf.cast(seeds, tf.int64)
-         batch_size = tf.size(seeds)
-         seeds_ragged = tf.RaggedTensor.from_row_lengths(
-             seeds, tf.ones([batch_size], tf.int64),
-         )
-         return self._sampling_model(seeds_ragged)
-
- # Create dataset providers
- st.write("Creating dataset providers...")
- train_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, train_sampling_sizes, "train")
- valid_ds_provider = SubgraphDatasetProvider(full_ogbn_mag_graph_tensor, validation_sample_sizes, "validation")
- example_input_graph_spec = train_ds_provider.input_graph_spec._unbatch()
-
- st.write("Dataset providers created successfully.")
-
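A minimal sketch (not from the deleted file) that pulls one sampled subgraph from the training provider; `tf.distribute.InputContext()` defaults to a single, unsharded input pipeline:

    # Fetch one sampled subgraph to confirm the pipeline runs end to end.
    ds = train_ds_provider.get_dataset(tf.distribute.InputContext())
    example_graph = next(iter(ds.take(1)))
    st.write(example_graph.node_sets["paper"].sizes)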
- # Model hyperparameters
- node_state_dim = 128
- num_graph_updates = 4
- message_dim = 128
- state_dropout_rate = 0.2
- l2_regularization = 1e-5
-
- def set_initial_node_states(node_set: tfgnn.NodeSet, node_set_name: str):
-     if node_set_name == "field_of_study":
-         return tf.keras.layers.Embedding(50_000, 32)(node_set["hashed_id"])
-     if node_set_name == "institution":
-         return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
-     if node_set_name == "paper":
-         return tf.keras.layers.Dense(node_state_dim, activation="relu")(node_set["feat"])
-     if node_set_name == "author":
-         return node_set["empty_state"]
-     raise KeyError(f"Unexpected node_set_name='{node_set_name}'")
-
- def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
-     inputs = tf.keras.layers.Input(type_spec=graph_tensor_spec)
-     graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=set_initial_node_states)(inputs)
-
-     for _ in range(num_graph_updates):
-         graph = mt_albis.MtAlbisGraphUpdate(
-             units=node_state_dim,
-             message_dim=message_dim,
-             attention_type="none",
-             simple_conv_reduce_type="mean|sum",
-             normalization_type="layer",
-             next_state_type="residual",
-             state_dropout_rate=state_dropout_rate,
-             l2_regularization=l2_regularization,
-         )(graph)
-
-     # Read out the updated paper states (stored under the default
-     # tfgnn.HIDDEN_STATE feature) and classify into the 349 venue classes.
-     paper_state = tfgnn.keras.layers.Readout(node_set_name="paper", feature_name=tfgnn.HIDDEN_STATE)(graph)
-     paper_state = tf.keras.layers.Dense(349, activation="softmax")(paper_state)
-     return tf.keras.Model(inputs, paper_state)
-
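A minimal sketch (not from the deleted file) that builds the model eagerly to inspect its layers and parameter count:

    # Instantiate the Keras model from the unbatched input spec.
    model = model_fn(example_input_graph_spec)
    model.summary(print_fn=st.text)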
- # Check for TPU/GPU and set strategy
- st.write("Setting up strategy for distributed training...")
- if tf.config.list_physical_devices("TPU"):
-     st.write("Using TPUStrategy")
-     strategy = runner.TPUStrategy("local")
-     train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider)
-     valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider)
- elif tf.config.list_physical_devices("GPU"):
-     st.write("Using MirroredStrategy for GPUs")
-     strategy = tf.distribute.MirroredStrategy()
-     train_padding = None
-     valid_padding = None
- else:
-     st.write("Using default strategy")
-     strategy = tf.distribute.get_strategy()
-     train_padding = None
-     valid_padding = None
-
- st.write(f"Found {strategy.num_replicas_in_sync} replicas in sync")
-
- # Define the classification task
- st.write("Defining the task...")
- task = runner.NodeMulticlassClassification(
-     num_classes=349,
-     label_feature_name="paper_venue")
-
- # Set hyperparameters
- st.write("Setting hyperparameters...")
- global_batch_size = 128
- epochs = 10
- initial_learning_rate = 0.001
-
- steps_per_epoch = NUM_TRAINING_SAMPLES // global_batch_size
- validation_steps = NUM_VALIDATION_SAMPLES // global_batch_size
- learning_rate = tf.keras.optimizers.schedules.CosineDecay(
-     initial_learning_rate, steps_per_epoch * epochs)
- optimizer_fn = functools.partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)
-
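A minimal sketch (not from the deleted file) showing how the cosine schedule decays over the run; a CosineDecay schedule is callable on a step number:

    # Learning rate at the start, midpoint, and end of training.
    total_steps = steps_per_epoch * epochs
    for step in (0, total_steps // 2, total_steps):
        st.write(f"step {step}: lr = {learning_rate(step).numpy():.6f}")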
- # Define trainer
- st.write("Setting up the trainer...")
- trainer = runner.KerasTrainer(
-     strategy=strategy,
-     model_dir="/tmp/gnn_model/",
-     callbacks=None,
-     steps_per_epoch=steps_per_epoch,
-     validation_steps=validation_steps,
-     restore_best_weights=False,
-     checkpoint_every_n_steps="never",
-     summarize_every_n_steps="never",
-     backup_and_restore=False,
- )
-
- # Run training. The padding objects computed above are passed through so the
- # TPU path gets the fixed-shape inputs it needs (they are None elsewhere).
- st.write("Training the model...")
- runner.run(
-     task=task,
-     model_fn=model_fn,
-     trainer=trainer,
-     optimizer_fn=optimizer_fn,
-     epochs=epochs,
-     global_batch_size=global_batch_size,
-     train_ds_provider=train_ds_provider,
-     valid_ds_provider=valid_ds_provider,
-     gtspec=example_input_graph_spec,
-     train_padding=train_padding,
-     valid_padding=valid_padding,
- )
-
- st.write("Training completed successfully.")