# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Run training of one or more algorithmic tasks from CLRS."""
import functools
import os

# Disable TensorFlow logging until training starts.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import copy
import pickle
import shutil
from typing import Any, Dict, List

from absl import app
from absl import flags
from absl import logging

# Disable absl logging until training starts.
logging.set_verbosity(logging.ERROR)

import clrs
import jax
import numpy as np
import requests
import tensorflow as tf

from baselines import BaselineModel, BaselineModelChunked

flags.DEFINE_list('algorithms', ['floyd_warshall'], 'Which algorithms to run.')
flags.DEFINE_list('train_lengths', ['4', '7', '11', '13', '16'],
                  'Which training sizes to use. A size of -1 means '
                  'use the benchmark dataset.')
flags.DEFINE_integer('length_needle', -8,
                     'Length of needle for training and validation '
                     '(not testing) in string matching algorithms. '
                     'A negative value randomizes the length for each sample '
                     'between 1 and the opposite of the value. '
                     'A value of 0 means always use 1/4 of the length of '
                     'the haystack (the default sampler behavior).')
flags.DEFINE_integer('seed', 42, 'Random seed to set.')
flags.DEFINE_boolean('random_pos', True,
                     'Randomize the `pos` input common to all algos.')
flags.DEFINE_boolean('enforce_permutations', True,
                     'Whether to enforce permutation-type node pointers.')
flags.DEFINE_boolean('enforce_pred_as_input', True,
                     'Whether to change pred_h hints into pred inputs.')
flags.DEFINE_integer('batch_size', 32, 'Batch size used for training.')
flags.DEFINE_boolean('chunked_training', False,
                     'Whether to use chunking for training.')
flags.DEFINE_integer('chunk_length', 16,
                     'Time chunk length used for training (if '
                     '`chunked_training` is True).')
flags.DEFINE_integer('train_steps', 500, 'Number of training iterations.')
flags.DEFINE_integer('eval_every', 50, 'Evaluation frequency (in steps).')
flags.DEFINE_integer('test_every', 500, 'Test frequency (in steps).')
flags.DEFINE_integer('log_every', 50, 'Logging frequency (in steps).')
flags.DEFINE_integer('hidden_size', 128,
                     'Number of hidden units of the model.')
flags.DEFINE_integer('nb_heads', 1, 'Number of heads for GAT processors.')
flags.DEFINE_integer('nb_msg_passing_steps', 1,
                     'Number of message passing steps to run per hint.')
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate to use.')
flags.DEFINE_float('grad_clip_max_norm', 1.0,
                   'Gradient clipping by norm. 0.0 disables grad clipping.')
flags.DEFINE_float('dropout_prob', 0.0, 'Dropout rate to use.')
flags.DEFINE_float('hint_teacher_forcing', 0.0,
                   'Probability that ground-truth teacher hints are encoded '
                   'during training instead of predicted hints. Only '
                   'pertinent in encoded_decoded mode.')
flags.DEFINE_enum('hint_mode', 'encoded_decoded',
                  ['encoded_decoded', 'decoded_only', 'none'],
                  'How should hints be used? Note, each mode defines a '
                  'separate task, with various difficulties. `encoded_decoded` '
                  'requires the model to explicitly materialise hint sequences '
                  'and therefore is hardest, but also most aligned to the '
                  'underlying algorithmic rule. Hence, `encoded_decoded` '
                  'should be treated as the default mode for our benchmark. '
                  'In `decoded_only`, hints are only used for defining '
                  'reconstruction losses. Often, this will perform well, but '
                  'note that we currently do not make any efforts to '
                  'counterbalance the various hint losses. Hence, for certain '
                  'tasks, the best performance will now be achievable with no '
                  'hint usage at all (`none`).')
flags.DEFINE_enum('hint_repred_mode', 'soft', ['soft', 'hard', 'hard_on_eval'],
                  'How to process predicted hints when fed back as inputs. '
                  'In soft mode, we use softmaxes for categoricals, pointers '
                  'and mask_one, and sigmoids for masks. '
                  'In hard mode, we use argmax instead of softmax, and hard '
                  'thresholding of masks. '
                  'In hard_on_eval mode, soft mode is '
                  'used for training and hard mode is used for evaluation.')
flags.DEFINE_boolean('use_ln', True,
                     'Whether to use layer normalisation in the processor.')
flags.DEFINE_boolean('use_lstm', False,
                     'Whether to insert an LSTM after message passing.')
flags.DEFINE_integer('nb_triplet_fts', 8,
                     'How many triplet features to compute?')
flags.DEFINE_enum('encoder_init', 'xavier_on_scalars',
                  ['default', 'xavier_on_scalars'],
                  'Initialiser to use for the encoders.')
flags.DEFINE_enum('processor_type', 'triplet_gmpnn',
                  ['deepsets', 'mpnn', 'pgn', 'pgn_mask',
                   'triplet_mpnn', 'triplet_pgn', 'triplet_pgn_mask',
                   'gat', 'gatv2', 'gat_full', 'gatv2_full',
                   'gpgn', 'gpgn_mask', 'gmpnn',
                   'triplet_gpgn', 'triplet_gpgn_mask', 'triplet_gmpnn'],
                  'Processor type to use as the network P.')
flags.DEFINE_string('checkpoint_path', './checkpoints',
                    'Path in which checkpoints are saved.')
flags.DEFINE_string('dataset_path', '/tmp/CLRS30',
                    'Path in which dataset is stored.')
flags.DEFINE_boolean('freeze_processor', False,
                     'Whether to freeze the processor of the model.')

FLAGS = flags.FLAGS

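# Algorithms whose `pred_h` hint stays fixed over the trajectory, so it can be
# converted into a static `pred` input when `--enforce_pred_as_input` is set.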
PRED_AS_INPUT_ALGOS = [
    'binary_search',
    'minimum',
    'find_maximum_subarray',
    'find_maximum_subarray_kadane',
    'matrix_chain_order',
    'lcs_length',
    'optimal_bst',
    'activity_selector',
    'task_scheduling',
    'naive_string_matcher',
    'kmp_matcher',
    'jarvis_march']


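# Convert 0-d device arrays to plain Python scalars so stats log cleanly.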
def unpack(v):
  try:
    return v.item()  # DeviceArray
  except (AttributeError, ValueError):
    return v


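# Yield batches from a CLRS sampler indefinitely.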
def _iterate_sampler(sampler, batch_size):
  while True:
    yield sampler.next(batch_size)


def _maybe_download_dataset(dataset_path):
  """Download CLRS30 dataset if needed."""
  dataset_folder = os.path.join(dataset_path, clrs.get_clrs_folder())
  if os.path.isdir(dataset_folder):
    logging.info('Dataset found at %s. Skipping download.', dataset_folder)
    return dataset_folder
  logging.info('Dataset not found in %s. Downloading...', dataset_folder)
  clrs_url = clrs.get_dataset_gcp_url()
  request = requests.get(clrs_url, allow_redirects=True)
  clrs_file = os.path.join(dataset_path, os.path.basename(clrs_url))
  os.makedirs(dataset_folder)
  with open(clrs_file, 'wb') as f:
    f.write(request.content)
  shutil.unpack_archive(clrs_file, extract_dir=dataset_folder)
  os.remove(clrs_file)
  return dataset_folder


def make_sampler(length: int,
                 rng: Any,
                 algorithm: str,
                 split: str,
                 batch_size: int,
                 multiplier: int,
                 randomize_pos: bool,
                 enforce_pred_as_input: bool,
                 enforce_permutations: bool,
                 chunked: bool,
                 chunk_length: int,
                 sampler_kwargs: Dict[str, Any]):
  """Create a sampler with given options.

  Args:
    length: Size of samples (i.e., number of nodes in the graph).
      A length of -1 will mean that the benchmark
      dataset (for the given split) is used. Positive sizes will instantiate
      samplers of the corresponding size.
    rng: Numpy random state.
    algorithm: The name of the algorithm to sample from.
    split: 'train', 'val' or 'test'.
    batch_size: Samples per batch.
    multiplier: Integer multiplier for the number of samples in the dataset,
      only used for positive sizes. Negative multiplier means infinite samples.
    randomize_pos: Whether to randomize the `pos` input.
    enforce_pred_as_input: Whether to convert fixed pred_h hints to inputs.
    enforce_permutations: Whether to enforce permutation pointers.
    chunked: Whether to chunk the dataset.
    chunk_length: Unroll length of chunks, if `chunked` is True.
    sampler_kwargs: Extra args passed to the sampler.

  Returns:
    A sampler (iterator), the number of samples in the iterator (negative
    if infinite samples), and the spec.
  """
  if length < 0:  # load from file
    dataset_folder = _maybe_download_dataset(FLAGS.dataset_path)
    sampler, num_samples, spec = clrs.create_dataset(folder=dataset_folder,
                                                     algorithm=algorithm,
                                                     batch_size=batch_size,
                                                     split=split)
    sampler = sampler.as_numpy_iterator()
  else:
    num_samples = clrs.CLRS30[split]['num_samples'] * multiplier
    sampler, spec = clrs.build_sampler(
        algorithm,
        seed=rng.randint(2**32),
        num_samples=num_samples,
        length=length,
        **sampler_kwargs,
    )
    sampler = _iterate_sampler(sampler, batch_size)

  if randomize_pos:
    sampler = clrs.process_random_pos(sampler, rng)
  if enforce_pred_as_input and algorithm in PRED_AS_INPUT_ALGOS:
    spec, sampler = clrs.process_pred_as_input(spec, sampler)
  spec, sampler = clrs.process_permutations(spec, sampler, enforce_permutations)
  if chunked:
    sampler = clrs.chunkify(sampler, chunk_length)
  return sampler, num_samples, spec


def make_multi_sampler(sizes, rng, **kwargs):
  """Create a sampler with cycling sample sizes."""
  ss = []
  tot_samples = 0
  for length in sizes:
    sampler, num_samples, spec = make_sampler(length, rng, **kwargs)
    ss.append(sampler)
    tot_samples += num_samples

  def cycle_samplers():
    while True:
      for s in ss:
        yield next(s)

  return cycle_samplers(), tot_samples, spec


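# Concatenate matching leaves of a list of pytrees along `axis`; used to merge
# per-batch outputs and predictions before evaluation.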
def _concat(dps, axis):
  return jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis), *dps)


def collect_and_eval(sampler, predict_fn, sample_count, rng_key, extras):
  """Collect batches of output and hint preds and evaluate them."""
  processed_samples = 0
  preds = []
  outputs = []
  while processed_samples < sample_count:
    feedback = next(sampler)
    batch_size = feedback.outputs[0].data.shape[0]
    outputs.append(feedback.outputs)
    new_rng_key, rng_key = jax.random.split(rng_key)
    cur_preds, _ = predict_fn(new_rng_key, feedback.features)
    preds.append(cur_preds)
    processed_samples += batch_size
  outputs = _concat(outputs, axis=0)
  preds = _concat(preds, axis=0)
  out = clrs.evaluate(outputs, preds)
  if extras:
    out.update(extras)
  return {k: unpack(v) for k, v in out.items()}


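# Build train/val/test samplers (and specs) for every requested algorithm.
# Validation uses the largest training length; testing always uses the
# benchmark dataset (length -1).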
def create_samplers(rng, train_lengths: List[int]):
  """Create all the samplers."""
  train_samplers = []
  val_samplers = []
  val_sample_counts = []
  test_samplers = []
  test_sample_counts = []
  spec_list = []

  for algo_idx, algorithm in enumerate(FLAGS.algorithms):
    # Make full dataset pipeline run on CPU (including prefetching).
    with tf.device('/cpu:0'):
      if algorithm in ['naive_string_matcher', 'kmp_matcher']:
        # Fixed haystack + needle; variability will be in needle.
        # Still, for chunked training, we maintain as many samplers
        # as train lengths, since, for each length, there is a separate state,
        # and we must keep the 1:1 relationship between states and samplers.
        max_length = max(train_lengths)
        if max_length > 0:  # if < 0, we are using the benchmark data
          max_length = (max_length * 5) // 4
        train_lengths = [max_length]
        if FLAGS.chunked_training:
          train_lengths = train_lengths * len(train_lengths)

      logging.info('Creating samplers for algo %s', algorithm)

      p = tuple([0.1 + 0.1 * i for i in range(9)])
      if p and algorithm in ['articulation_points', 'bridges',
                             'mst_kruskal', 'bipartite_matching']:
        # Choose a lower connection probability for the above algorithms,
        # otherwise trajectories are very long.
        p = tuple(np.array(p) / 2)
      length_needle = FLAGS.length_needle
      sampler_kwargs = dict(p=p, length_needle=length_needle)
      if length_needle == 0:
        sampler_kwargs.pop('length_needle')

      common_sampler_args = dict(
          algorithm=FLAGS.algorithms[algo_idx],
          rng=rng,
          enforce_pred_as_input=FLAGS.enforce_pred_as_input,
          enforce_permutations=FLAGS.enforce_permutations,
          chunk_length=FLAGS.chunk_length,
          )

      train_args = dict(sizes=train_lengths,
                        split='train',
                        batch_size=FLAGS.batch_size,
                        multiplier=-1,
                        randomize_pos=FLAGS.random_pos,
                        chunked=FLAGS.chunked_training,
                        sampler_kwargs=sampler_kwargs,
                        **common_sampler_args)
      train_sampler, _, spec = make_multi_sampler(**train_args)

      mult = clrs.CLRS_30_ALGS_SETTINGS[algorithm]['num_samples_multiplier']
      val_args = dict(sizes=[np.amax(train_lengths)],
                      split='val',
                      batch_size=32,
                      multiplier=2 * mult,
                      randomize_pos=FLAGS.random_pos,
                      chunked=False,
                      sampler_kwargs=sampler_kwargs,
                      **common_sampler_args)
      val_sampler, val_samples, spec = make_multi_sampler(**val_args)

      test_args = dict(sizes=[-1],
                       split='test',
                       batch_size=32,
                       multiplier=2 * mult,
                       randomize_pos=False,
                       chunked=False,
                       sampler_kwargs={},
                       **common_sampler_args)
      test_sampler, test_samples, spec = make_multi_sampler(**test_args)

    spec_list.append(spec)
    train_samplers.append(train_sampler)
    val_samplers.append(val_sampler)
    val_sample_counts.append(val_samples)
    test_samplers.append(test_sampler)
    test_sample_counts.append(test_samples)

  return (train_samplers,
          val_samplers, val_sample_counts,
          test_samplers, test_sample_counts,
          spec_list)


def main(unused_argv):
  if FLAGS.hint_mode == 'encoded_decoded':
    encode_hints = True
    decode_hints = True
  elif FLAGS.hint_mode == 'decoded_only':
    encode_hints = False
    decode_hints = True
  elif FLAGS.hint_mode == 'none':
    encode_hints = False
    decode_hints = False
  else:
    raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.')

  train_lengths = [int(x) for x in FLAGS.train_lengths]

  rng = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(rng.randint(2**32))

  # Create samplers.
  (train_samplers,
   val_samplers, val_sample_counts,
   test_samplers, test_sample_counts,
   spec_list) = create_samplers(rng, train_lengths)

  processor_factory = clrs.get_processor_factory(
      FLAGS.processor_type,
      use_ln=FLAGS.use_ln,
      nb_triplet_fts=FLAGS.nb_triplet_fts,
      nb_heads=FLAGS.nb_heads,
  )
  model_params = dict(
      processor_factory=processor_factory,
      hidden_dim=FLAGS.hidden_size,
      encode_hints=encode_hints,
      decode_hints=decode_hints,
      encoder_init=FLAGS.encoder_init,
      use_lstm=FLAGS.use_lstm,
      learning_rate=FLAGS.learning_rate,
      grad_clip_max_norm=FLAGS.grad_clip_max_norm,
      checkpoint_path=FLAGS.checkpoint_path,
      freeze_processor=FLAGS.freeze_processor,
      dropout_prob=FLAGS.dropout_prob,
      hint_teacher_forcing=FLAGS.hint_teacher_forcing,
      hint_repred_mode=FLAGS.hint_repred_mode,
      nb_msg_passing_steps=FLAGS.nb_msg_passing_steps,
  )

  # Save spec_list and model_params; do not change or delete!!
  os.makedirs(FLAGS.checkpoint_path, exist_ok=True)
  with open(os.path.join(FLAGS.checkpoint_path, 'spec_list.pkl'), 'wb') as f:
    pickle.dump(spec_list, f)
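  # Replace the processor factory (a closure, which does not pickle cleanly)
  # with the arguments needed to rebuild it via `clrs.get_processor_factory`.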
  model_params_save = copy.deepcopy(model_params)
  model_params_save['processor_factory'] = (
      FLAGS.processor_type, FLAGS.use_ln, FLAGS.nb_triplet_fts, FLAGS.nb_heads)
  with open(os.path.join(FLAGS.checkpoint_path, 'model_params.pkl'), 'wb') as f:
    pickle.dump(model_params_save, f)

  eval_model = BaselineModel(
      spec=spec_list,
      dummy_trajectory=[next(t) for t in val_samplers],
      **model_params
  )
  if FLAGS.chunked_training:
    train_model = BaselineModelChunked(
        spec=spec_list,
        dummy_trajectory=[next(t) for t in train_samplers],
        **model_params
    )
  else:
    train_model = eval_model

  # Training loop.
  best_score = -1.0
  current_train_items = [0] * len(FLAGS.algorithms)
  step = 0
  next_eval = 0
  # Make sure scores improve on the first step, but do not beat the best score
  # until all algos have had at least one evaluation.
  val_scores = [-99999.9] * len(FLAGS.algorithms)
  length_idx = 0

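  # Each step draws one batch per algorithm and takes one gradient step per
  # algorithm; every `eval_every` steps, all algorithms are re-evaluated on
  # their validation sets.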
  while step < FLAGS.train_steps:
    feedback_list = [next(t) for t in train_samplers]

    # Initialize model.
    if step == 0:
      all_features = [f.features for f in feedback_list]
      if FLAGS.chunked_training:
        # We need to initialize the model with samples of all lengths for
        # all algorithms. Also, we need to make sure that the order of these
        # sample sizes is the same as the order of the actual training sizes.
        all_length_features = [all_features] + [
            [next(t).features for t in train_samplers]
            for _ in range(len(train_lengths))]
        train_model.init(all_length_features[:-1], FLAGS.seed + 1)
      else:
        train_model.init(all_features, FLAGS.seed + 1)
      # Enable logging now that we have initialized the model.
      logging.set_verbosity(logging.INFO)

    # Training step.
    for algo_idx in range(len(train_samplers)):
      feedback = feedback_list[algo_idx]
      rng_key, new_rng_key = jax.random.split(rng_key)
      if FLAGS.chunked_training:
        # In chunked training, we must indicate which training length we are
        # using, so the model uses the correct state.
        length_and_algo_idx = (length_idx, algo_idx)
      else:
        # In non-chunked training, all training lengths can be treated equally,
        # since there is no state to maintain between batches.
        length_and_algo_idx = algo_idx
      cur_loss = train_model.feedback(rng_key, feedback, length_and_algo_idx)
      rng_key = new_rng_key

      if FLAGS.chunked_training:
        examples_in_chunk = np.sum(feedback.features.is_last).item()
      else:
        examples_in_chunk = len(feedback.features.lengths)
      current_train_items[algo_idx] += examples_in_chunk
      if step % FLAGS.log_every == 0:
        logging.info('Algo %s step %i current loss %f, current_train_items %i.',
                     FLAGS.algorithms[algo_idx], step,
                     cur_loss, current_train_items[algo_idx])

    # Periodically evaluate model.
    if step >= next_eval:
      eval_model.params = train_model.params
      for algo_idx in range(len(train_samplers)):
        common_extras = {'examples_seen': current_train_items[algo_idx],
                         'step': step,
                         'algorithm': FLAGS.algorithms[algo_idx]}

        # Validation info.
        new_rng_key, rng_key = jax.random.split(rng_key)
        val_stats = collect_and_eval(
            val_samplers[algo_idx],
            functools.partial(eval_model.predict, algorithm_index=algo_idx),
            val_sample_counts[algo_idx],
            new_rng_key,
            extras=common_extras)
        logging.info('(val) algo %s step %d: %s',
                     FLAGS.algorithms[algo_idx], step, val_stats)
        val_scores[algo_idx] = val_stats['score']

      next_eval += FLAGS.eval_every

      # If best total score, update best checkpoint.
      # Also save a best checkpoint on the first step.
      msg = (f'best avg val score was '
             f'{best_score/len(FLAGS.algorithms):.3f}, '
             f'current avg val score is {np.mean(val_scores):.3f}, '
             f'val scores are: ')
      msg += ', '.join(
          ['%s: %.3f' % (x, y) for (x, y) in zip(FLAGS.algorithms, val_scores)])
      if (sum(val_scores) > best_score) or step == 0:
        best_score = sum(val_scores)
        logging.info('Checkpointing best model, %s', msg)
        train_model.save_model('best.pkl')
      else:
        logging.info('Not saving new best model, %s', msg)

    step += 1
    length_idx = (length_idx + 1) % len(train_lengths)

  logging.info('Restoring best model from checkpoint...')
  eval_model.restore_model('best.pkl', only_load_processor=False)

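  # Final test evaluation with the best checkpoint, one algorithm at a time.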
  for algo_idx in range(len(train_samplers)):
    common_extras = {'examples_seen': current_train_items[algo_idx],
                     'step': step,
                     'algorithm': FLAGS.algorithms[algo_idx]}

    new_rng_key, rng_key = jax.random.split(rng_key)
    test_stats = collect_and_eval(
        test_samplers[algo_idx],
        functools.partial(eval_model.predict, algorithm_index=algo_idx),
        test_sample_counts[algo_idx],
        new_rng_key,
        extras=common_extras)
    logging.info('(test) algo %s : %s', FLAGS.algorithms[algo_idx], test_stats)

  logging.info('Done!')


if __name__ == '__main__':
  app.run(main)
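
# Example invocation (assuming this file is saved as run.py; flag values
# shown are the defaults defined above):
#   python run.py --algorithms=floyd_warshall --train_steps=500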