|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Run training of one or more algorithmic tasks from CLRS.""" |
|
|
|
import os |
|
|
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' |
|
|
|
import functools |
|
import os |
|
import shutil |
|
from typing import Any, Dict, List |
|
|
|
from absl import app |
|
from absl import flags |
|
from absl import logging |
|
|
|
logging.set_verbosity(logging.ERROR) |
|
|
|
import clrs |
|
import jax |
|
import numpy as np |
|
import requests |
|
import tensorflow as tf |
|
import sys |
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../env"))) |
|
from baselines import BaselineModel, BaselineModelChunked |
|
import pickle |
|
|
|
# ---------------------------------------------------------------------------
# Command-line flags.
#
# Help-text fixes relative to the previous revision: closed the parenthesis
# in the `chunk_length` help, separated two sentences in the
# `hint_repred_mode` help, and made the `test_every` help distinct from
# `eval_every`. Flag names, types and defaults are unchanged.
# ---------------------------------------------------------------------------

# Task selection and data sampling.
flags.DEFINE_list('algorithms', ['floyd_warshall'], 'Which algorithms to run.')
flags.DEFINE_list('train_lengths', ['4', '7', '11', '13', '16'],
                  'Which training sizes to use. A size of -1 means '
                  'use the benchmark dataset.')
flags.DEFINE_integer('length_needle', -8,
                     'Length of needle for training and validation '
                     '(not testing) in string matching algorithms. '
                     'A negative value randomizes the length for each sample '
                     'between 1 and the opposite of the value. '
                     'A value of 0 means use always 1/4 of the length of '
                     'the haystack (the default sampler behavior).')
flags.DEFINE_integer('seed', 42, 'Random seed to set')

# Sampler post-processing and training schedule.
flags.DEFINE_boolean('random_pos', True,
                     'Randomize the pos input common to all algos.')
flags.DEFINE_boolean('enforce_permutations', True,
                     'Whether to enforce permutation-type node pointers.')
flags.DEFINE_boolean('enforce_pred_as_input', True,
                     'Whether to change pred_h hints into pred inputs.')
flags.DEFINE_integer('batch_size', 32, 'Batch size used for training.')
flags.DEFINE_boolean('chunked_training', False,
                     'Whether to use chunking for training.')
flags.DEFINE_integer('chunk_length', 16,
                     'Time chunk length used for training (if '
                     '`chunked_training` is True).')
flags.DEFINE_integer('train_steps', 1000, 'Number of training iterations.')
flags.DEFINE_integer('eval_every', 50, 'Evaluation frequency (in steps).')
flags.DEFINE_integer('test_every', 500, 'Test frequency (in steps).')
flags.DEFINE_integer('log_every', 50, 'Logging frequency (in steps).')

# Model and optimisation hyper-parameters.
flags.DEFINE_integer('hidden_size', 128,
                     'Number of hidden units of the model.')
flags.DEFINE_integer('nb_heads', 1, 'Number of heads for GAT processors')
flags.DEFINE_integer('nb_msg_passing_steps', 1,
                     'Number of message passing steps to run per hint.')
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate to use.')
flags.DEFINE_float('grad_clip_max_norm', 1.0,
                   'Gradient clipping by norm. 0.0 disables grad clipping')
flags.DEFINE_float('dropout_prob', 0.0, 'Dropout rate to use.')
flags.DEFINE_float('hint_teacher_forcing', 0.0,
                   'Probability that ground-truth teacher hints are encoded '
                   'during training instead of predicted hints. Only '
                   'pertinent in encoded_decoded modes.')
flags.DEFINE_enum('hint_mode', 'encoded_decoded',
                  ['encoded_decoded', 'decoded_only', 'none'],
                  'How should hints be used? Note, each mode defines a '
                  'separate task, with various difficulties. `encoded_decoded` '
                  'requires the model to explicitly materialise hint sequences '
                  'and therefore is hardest, but also most aligned to the '
                  'underlying algorithmic rule. Hence, `encoded_decoded` '
                  'should be treated as the default mode for our benchmark. '
                  'In `decoded_only`, hints are only used for defining '
                  'reconstruction losses. Often, this will perform well, but '
                  'note that we currently do not make any efforts to '
                  'counterbalance the various hint losses. Hence, for certain '
                  'tasks, the best performance will now be achievable with no '
                  'hint usage at all (`none`).')
flags.DEFINE_enum('hint_repred_mode', 'soft', ['soft', 'hard', 'hard_on_eval'],
                  'How to process predicted hints when fed back as inputs. '
                  'In soft mode, we use softmaxes for categoricals, pointers '
                  'and mask_one, and sigmoids for masks. '
                  'In hard mode, we use argmax instead of softmax, and hard '
                  'thresholding of masks. '
                  'In hard_on_eval mode, soft mode is '
                  'used for training and hard mode is used for evaluation.')
flags.DEFINE_boolean('use_ln', True,
                     'Whether to use layer normalisation in the processor.')
flags.DEFINE_boolean('use_lstm', False,
                     'Whether to insert an LSTM after message passing.')
flags.DEFINE_integer('nb_triplet_fts', 8,
                     'How many triplet features to compute?')

# Architecture choices.
flags.DEFINE_enum('encoder_init', 'xavier_on_scalars',
                  ['default', 'xavier_on_scalars'],
                  'Initialiser to use for the encoders.')
flags.DEFINE_enum('processor_type', 'triplet_gmpnn',
                  ['deepsets', 'mpnn', 'pgn', 'pgn_mask',
                   'triplet_mpnn', 'triplet_pgn', 'triplet_pgn_mask',
                   'gat', 'gatv2', 'gat_full', 'gatv2_full',
                   'gpgn', 'gpgn_mask', 'gmpnn',
                   'triplet_gpgn', 'triplet_gpgn_mask', 'triplet_gmpnn'],
                  'Processor type to use as the network P.')

# Filesystem locations.
flags.DEFINE_string('checkpoint_path', '../env/checkpoints',
                    'Path in which checkpoints are saved.')
flags.DEFINE_string('dataset_path', '/tmp/CLRS30',
                    'Path in which dataset is stored.')
flags.DEFINE_boolean('freeze_processor', False,
                     'Whether to freeze the processor of the model.')

# Global accessor for all flags defined above.
FLAGS = flags.FLAGS
|
|
|
|
|
# Algorithms for which `make_sampler` applies `clrs.process_pred_as_input`
# when its `enforce_pred_as_input` argument is set.
PRED_AS_INPUT_ALGOS = [
    'binary_search',
    'minimum',
    'find_maximum_subarray',
    'find_maximum_subarray_kadane',
    'matrix_chain_order',
    'lcs_length',
    'optimal_bst',
    'activity_selector',
    'task_scheduling',
    'naive_string_matcher',
    'kmp_matcher',
    'jarvis_march']
|
|
|
|
|
def unpack(v):
    """Return `v.item()` when that call succeeds, otherwise `v` unchanged.

    Values without an `item` method (AttributeError) and values whose
    `item()` call raises ValueError are passed through untouched.
    """
    try:
        scalar = v.item()
    except (AttributeError, ValueError):
        return v
    return scalar
|
|
|
|
|
def _iterate_sampler(sampler, batch_size): |
|
while True: |
|
yield sampler.next(batch_size) |
|
|
|
|
|
def _maybe_download_dataset(dataset_path):
    """Download CLRS30 dataset if needed.

    Args:
        dataset_path: Directory under which the dataset folder (named by
            `clrs.get_clrs_folder()`) is looked up or created.

    Returns:
        The path of the dataset folder.

    Raises:
        requests.HTTPError: If the download request returns an error status.
    """
    dataset_folder = os.path.join(dataset_path, clrs.get_clrs_folder())
    if os.path.isdir(dataset_folder):
        logging.info('Dataset found at %s. Skipping download.', dataset_folder)
        return dataset_folder
    logging.info('Dataset not found in %s. Downloading...', dataset_folder)

    clrs_url = clrs.get_dataset_gcp_url()
    request = requests.get(clrs_url, allow_redirects=True)
    # Fail before touching the filesystem: the mere existence of
    # `dataset_folder` is what makes later calls skip the download.
    request.raise_for_status()

    clrs_file = os.path.join(dataset_path, os.path.basename(clrs_url))
    os.makedirs(dataset_folder)
    try:
        with open(clrs_file, 'wb') as archive:
            archive.write(request.content)
        shutil.unpack_archive(clrs_file, extract_dir=dataset_folder)
    except Exception:
        # Do not leave a partial folder behind, otherwise the next call would
        # treat it as a complete dataset.
        shutil.rmtree(dataset_folder, ignore_errors=True)
        raise
    finally:
        # The downloaded archive is only a temporary artefact.
        if os.path.exists(clrs_file):
            os.remove(clrs_file)
    return dataset_folder
|
|
|
|
|
def make_sampler(length: int,
                 rng: Any,
                 algorithm: str,
                 split: str,
                 batch_size: int,
                 multiplier: int,
                 randomize_pos: bool,
                 enforce_pred_as_input: bool,
                 enforce_permutations: bool,
                 chunked: bool,
                 chunk_length: int,
                 sampler_kwargs: Dict[str, Any]):
    """Build a batch iterator for one algorithm, split and problem size.

    Args:
        length: Problem size (number of nodes in the graph). A negative value
            selects the benchmark dataset of the given split; other values
            build a sampler producing problems of that size.
        rng: Numpy random state.
        algorithm: Name of the algorithm to sample from.
        split: 'train', 'val' or 'test'.
        batch_size: Number of samples per batch.
        multiplier: Integer factor applied to the split's sample count. Only
            used when a sampler is built; a negative multiplier means
            infinite samples.
        randomize_pos: Whether to randomize the `pos` input.
        enforce_pred_as_input: Whether to convert fixed pred_h hints to
            inputs (only applied to algorithms in `PRED_AS_INPUT_ALGOS`).
        enforce_permutations: Whether to enforce permutation pointers.
        chunked: Whether to chunk the dataset.
        chunk_length: Unroll length of chunks, if `chunked` is True.
        sampler_kwargs: Extra keyword arguments forwarded to the sampler.

    Returns:
        A tuple `(sampler, num_samples, spec)`: the batch iterator, the
        number of samples it covers (negative if infinite) and the spec.
    """
    use_benchmark_data = length < 0

    if use_benchmark_data:
        folder = _maybe_download_dataset(FLAGS.dataset_path)
        dataset, num_samples, spec = clrs.create_dataset(
            folder=folder,
            algorithm=algorithm,
            batch_size=batch_size,
            split=split)
        sampler = dataset.as_numpy_iterator()
    else:
        num_samples = clrs.CLRS30[split]['num_samples'] * multiplier
        raw_sampler, spec = clrs.build_sampler(
            algorithm,
            seed=rng.randint(2**32),
            num_samples=num_samples,
            length=length,
            **sampler_kwargs,
        )
        sampler = _iterate_sampler(raw_sampler, batch_size)

    # Optional post-processing stages, applied in a fixed order.
    if randomize_pos:
        sampler = clrs.process_random_pos(sampler, rng)

    wants_pred_as_input = (
        enforce_pred_as_input and algorithm in PRED_AS_INPUT_ALGOS)
    if wants_pred_as_input:
        spec, sampler = clrs.process_pred_as_input(spec, sampler)

    spec, sampler = clrs.process_permutations(
        spec, sampler, enforce_permutations)

    if chunked:
        sampler = clrs.chunkify(sampler, chunk_length)

    return sampler, num_samples, spec
|
|
|
|
|
def make_multi_sampler(sizes, rng, **kwargs):
    """Create a sampler with cycling sample sizes.

    Args:
        sizes: Iterable of problem sizes; one sampler is built per entry via
            `make_sampler`.
        rng: Numpy random state forwarded to `make_sampler`.
        **kwargs: Remaining keyword arguments forwarded to `make_sampler`.

    Returns:
        A tuple `(sampler, tot_samples, spec)`: a generator that yields one
        batch from each per-size sampler in turn, forever; the sum of the
        per-size sample counts; and the spec returned for the last size.

    Raises:
        ValueError: If `sizes` is empty (there would be no spec to return).
    """
    ss = []
    tot_samples = 0
    spec = None
    for length in sizes:
        sampler, num_samples, spec = make_sampler(length, rng, **kwargs)
        ss.append(sampler)
        tot_samples += num_samples

    if not ss:
        # Previously this surfaced as an UnboundLocalError on `spec`.
        raise ValueError('`sizes` must contain at least one length.')

    def cycle_samplers():
        # Round-robin over the per-size samplers.
        while True:
            for s in ss:
                yield next(s)

    return cycle_samplers(), tot_samples, spec
|
|
|
|
|
def _concat(dps, axis): |
|
return jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis), *dps) |
|
|
|
|
|
def collect_and_eval(sampler, predict_fn, sample_count, rng_key, extras):
    """Run predictions over batches from `sampler` and evaluate them.

    Batches are drawn until at least `sample_count` samples have been
    processed; the batch size is read from the leading dimension of the
    first output's data. Ground-truth outputs and predictions are
    concatenated along axis 0 and passed to `clrs.evaluate`.

    Args:
        sampler: Iterator yielding feedback objects with `outputs` and
            `features` attributes.
        predict_fn: Callable `(rng_key, features) -> (predictions, _)`.
        sample_count: Minimum number of samples to process.
        rng_key: JAX PRNG key; split once per batch.
        extras: Optional mapping merged into the evaluation results.

    Returns:
        A dict of evaluation statistics with each value passed through
        `unpack`.
    """
    all_outputs = []
    all_preds = []
    seen = 0
    while seen < sample_count:
        batch = next(sampler)
        n_in_batch = batch.outputs[0].data.shape[0]
        all_outputs.append(batch.outputs)
        step_key, rng_key = jax.random.split(rng_key)
        batch_preds, _ = predict_fn(step_key, batch.features)
        all_preds.append(batch_preds)
        seen += n_in_batch

    merged_outputs = _concat(all_outputs, axis=0)
    merged_preds = _concat(all_preds, axis=0)
    stats = clrs.evaluate(merged_outputs, merged_preds)
    if extras:
        stats.update(extras)
    return {name: unpack(value) for name, value in stats.items()}
|
|
|
|
|
def create_samplers(rng, train_lengths: List[int]):
    """Create all the samplers.

    For every algorithm in `FLAGS.algorithms` this builds a training
    sampler cycling over the training lengths, a validation sampler at the
    largest training length, and a test sampler with size -1 (which
    `make_sampler` maps to the benchmark dataset).

    Args:
        rng: Numpy random state shared by all samplers.
        train_lengths: Problem sizes requested for training. The list is not
            modified.

    Returns:
        A tuple `(train_samplers, val_samplers, val_sample_counts,
        test_samplers, test_sample_counts, spec_list)` of lists with one
        entry per algorithm.
    """
    train_samplers = []
    val_samplers = []
    val_sample_counts = []
    test_samplers = []
    test_sample_counts = []
    spec_list = []

    requested_lengths = list(train_lengths)

    for algorithm in FLAGS.algorithms:
        # Start every algorithm from the requested lengths. The previous code
        # rebound `train_lengths` inside the loop, so the string-matcher
        # override below leaked into every algorithm that followed it.
        algo_lengths = list(requested_lengths)

        with tf.device('/cpu:0'):

            if algorithm in ['naive_string_matcher', 'kmp_matcher']:
                # String matchers train on a single length: 5/4 of the
                # largest requested one (a non-positive maximum is kept
                # as is).
                max_length = max(algo_lengths)
                if max_length > 0:
                    max_length = (max_length * 5) // 4
                if FLAGS.chunked_training:
                    # The original `train_lengths * len(train_lengths)` ran
                    # after the list had been replaced by a single element,
                    # making it a no-op. Presumably the intent was to keep one
                    # sampler per requested length -- TODO confirm.
                    algo_lengths = [max_length] * len(requested_lengths)
                else:
                    algo_lengths = [max_length]

            logging.info('Creating samplers for algo %s', algorithm)

            # `p` is forwarded to the sampler: 0.1, 0.2, ..., 0.9, halved for
            # the algorithms listed below.
            p = tuple([0.1 + 0.1 * i for i in range(9)])
            if p and algorithm in ['articulation_points', 'bridges',
                                   'mst_kruskal', 'bipartite_matching']:
                p = tuple(np.array(p) / 2)
            length_needle = FLAGS.length_needle
            sampler_kwargs = dict(p=p, length_needle=length_needle)
            if length_needle == 0:
                # 0 means "use the sampler's default", so do not pass it on.
                sampler_kwargs.pop('length_needle')

            common_sampler_args = dict(
                algorithm=algorithm,
                rng=rng,
                enforce_pred_as_input=FLAGS.enforce_pred_as_input,
                enforce_permutations=FLAGS.enforce_permutations,
                chunk_length=FLAGS.chunk_length,
            )

            train_args = dict(sizes=algo_lengths,
                              split='train',
                              batch_size=FLAGS.batch_size,
                              multiplier=-1,
                              randomize_pos=FLAGS.random_pos,
                              chunked=FLAGS.chunked_training,
                              sampler_kwargs=sampler_kwargs,
                              **common_sampler_args)
            train_sampler, _, spec = make_multi_sampler(**train_args)

            mult = clrs.CLRS_30_ALGS_SETTINGS[algorithm][
                'num_samples_multiplier']
            val_args = dict(sizes=[np.amax(algo_lengths)],
                            split='val',
                            batch_size=32,
                            multiplier=2 * mult,
                            randomize_pos=FLAGS.random_pos,
                            chunked=False,
                            sampler_kwargs=sampler_kwargs,
                            **common_sampler_args)
            val_sampler, val_samples, spec = make_multi_sampler(**val_args)

            test_args = dict(sizes=[-1],
                             split='test',
                             batch_size=32,
                             multiplier=2 * mult,
                             randomize_pos=False,
                             chunked=False,
                             sampler_kwargs={},
                             **common_sampler_args)
            test_sampler, test_samples, spec = make_multi_sampler(**test_args)

        # The spec kept for each algorithm is the one from its test sampler.
        spec_list.append(spec)
        train_samplers.append(train_sampler)
        val_samplers.append(val_sampler)
        val_sample_counts.append(val_samples)
        test_samplers.append(test_sampler)
        test_sample_counts.append(test_samples)

    return (train_samplers,
            val_samplers, val_sample_counts,
            test_samplers, test_sample_counts,
            spec_list)
|
|
|
|
|
def get_score(submission_folder):
    """Evaluate a submission's best checkpoint and return its test score.

    Args:
        submission_folder: Directory containing a `checkpoints` sub-folder
            with `spec_list.pkl` and `model_params.pkl`. The model is
            restored from a checkpoint named `best.pkl` (presumably located
            in that same sub-folder -- verify against `BaselineModel`).

    Returns:
        The `score` entry of the test statistics of the last algorithm in
        `FLAGS.algorithms`.

    Raises:
        ValueError: If the hint mode is unknown or no algorithm is
            configured.
    """
    # Parse an argument list holding only a program name so that flag values
    # can be read when this function is used as a library call.
    FLAGS(["eval.py"])
    if FLAGS.hint_mode not in ('encoded_decoded', 'decoded_only', 'none'):
        raise ValueError(
            'Hint mode not in {encoded_decoded, decoded_only, none}.')

    train_lengths = [int(x) for x in FLAGS.train_lengths]

    rng = np.random.RandomState(FLAGS.seed)
    rng_key = jax.random.PRNGKey(rng.randint(2**32))

    checkpoint_path = os.path.join(submission_folder, 'checkpoints')

    # SECURITY NOTE: unpickling executes arbitrary code; only evaluate
    # submission folders that come from a trusted source.
    # NOTE(review): the pickled spec list is read (so a missing or corrupt
    # file still fails here, as before) but the value is superseded by the
    # specs returned from `create_samplers` below -- confirm this is intended.
    with open(os.path.join(checkpoint_path, 'spec_list.pkl'), 'rb') as f:
        pickle.load(f)

    (train_samplers,
     val_samplers, val_sample_counts,
     test_samplers, test_sample_counts,
     spec_list) = create_samplers(rng, train_lengths)

    with open(os.path.join(checkpoint_path, 'model_params.pkl'), 'rb') as f:
        model_params = pickle.load(f)

    # The pickle stores the processor factory as its four constructor
    # arguments; rebuild the actual factory from them.
    processor_type, use_ln, nb_triplet_fts, nb_heads = (
        model_params["processor_factory"])
    model_params["processor_factory"] = clrs.get_processor_factory(
        processor_type,
        use_ln=use_ln,
        nb_triplet_fts=nb_triplet_fts,
        nb_heads=nb_heads
    )
    model_params["checkpoint_path"] = checkpoint_path

    # The validation samplers are advanced before the training samplers;
    # this order is preserved because the samplers share `rng`.
    eval_model = BaselineModel(
        spec=spec_list,
        dummy_trajectory=[next(t) for t in val_samplers],
        **model_params
    )

    feedback_list = [next(t) for t in train_samplers]
    all_features = [f.features for f in feedback_list]
    eval_model.init(all_features, FLAGS.seed + 1)

    logging.set_verbosity(logging.INFO)

    logging.info('Restoring best model from checkpoint...')
    eval_model.restore_model('best.pkl', only_load_processor=False)

    test_stats = None
    for algo_idx, algorithm in enumerate(FLAGS.algorithms):
        # Validation is still evaluated so that the sequence of PRNG keys
        # handed to the test evaluation is the same as before.
        new_rng_key, rng_key = jax.random.split(rng_key)
        val_stats = collect_and_eval(
            val_samplers[algo_idx],
            functools.partial(eval_model.predict, algorithm_index=algo_idx),
            val_sample_counts[algo_idx],
            new_rng_key,
            extras={})
        logging.info('(val) algo %s: %s', algorithm, val_stats)

        new_rng_key, rng_key = jax.random.split(rng_key)
        test_stats = collect_and_eval(
            test_samplers[algo_idx],
            functools.partial(eval_model.predict, algorithm_index=algo_idx),
            test_sample_counts[algo_idx],
            new_rng_key,
            extras={})
        logging.info('(test) algo %s: %s', algorithm, test_stats)

    if test_stats is None:
        # Previously this surfaced as an unbound-variable error.
        raise ValueError('No algorithm configured; nothing to evaluate.')

    return test_stats['score']
|
|
|
def _main(argv):
    """Command-line entry point: score the submission folder in `argv[1]`.

    `app.run` calls its main function with the list of remaining
    command-line arguments (program name first), so that list cannot be
    passed to `get_score` directly as the folder path. The score is printed
    and nothing is returned, because `app.run` passes the return value of
    the main function to `sys.exit`.
    """
    if len(argv) != 2:
        raise app.UsageError('Expected exactly one argument: SUBMISSION_FOLDER')
    print(get_score(argv[1]))


if __name__ == '__main__':
    app.run(_main)