from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

r"""Run training.

Choose a training algorithm and task(s) and follow these examples.

Run synchronous policy gradient training locally:

CONFIG="agent=c(algorithm='pg'),env=c(task='reverse')"
OUT_DIR="/tmp/bf_pg_local"
rm -rf $OUT_DIR
bazel run -c opt single_task:run -- \
    --alsologtostderr \
    --config="$CONFIG" \
    --max_npe=0 \
    --logdir="$OUT_DIR" \
    --summary_interval=1 \
    --model_v=0
tensorboard --port 12345 --logdir "$OUT_DIR"
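
(Note: --summary_interval and --model_v are not flags defined in this file;
they belong to the policy gradient implementation, presumably pg_train.py.)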

Run the genetic algorithm locally:

CONFIG="agent=c(algorithm='ga'),env=c(task='reverse')"
OUT_DIR="/tmp/bf_ga_local"
rm -rf $OUT_DIR
bazel run -c opt single_task:run -- \
    --alsologtostderr \
    --config="$CONFIG" \
    --max_npe=0 \
    --logdir="$OUT_DIR"

Run uniform random search locally:

CONFIG="agent=c(algorithm='rand'),env=c(task='reverse')"
OUT_DIR="/tmp/bf_rand_local"
rm -rf $OUT_DIR
bazel run -c opt single_task:run -- \
    --alsologtostderr \
    --config="$CONFIG" \
    --max_npe=0 \
    --logdir="$OUT_DIR"
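
In all of the examples above, --max_npe=0 places no limit on the number of
programs executed; set it to a positive value to stop each run once that many
programs have been executed (see the max_npe flag below).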
"""

from absl import app
from absl import flags
from absl import logging

from single_task import defaults
from single_task import ga_train
from single_task import pg_train

FLAGS = flags.FLAGS
flags.DEFINE_string(
    'config', '',
    'Configuration string. See the module docstring for examples.')
flags.DEFINE_string(
    'logdir', None, 'Absolute path where to write results.')
flags.DEFINE_integer('task_id', 0, 'ID for this worker.')
flags.DEFINE_integer('num_workers', 1, 'How many workers there are.')
flags.DEFINE_integer(
    'max_npe', 0,
    'NPE = number of programs executed. Maximum number of programs to execute '
    'in each run. Training will complete when this threshold is reached. Set '
    'to 0 for unlimited training.')
flags.DEFINE_integer(
    'num_repetitions', 1,
    'Number of times the same experiment will be run (globally across all '
    'workers). Each run is independent.')
flags.DEFINE_string(
    'log_level', 'INFO',
    'The threshold for what messages will be logged. One of DEBUG, INFO, WARN, '
    'ERROR, or FATAL.')
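
# task_id and num_workers describe this process's position in a (possibly
# single-process) pool of workers: main() below requires
# 0 <= task_id < num_workers, and the task_id 0 worker runs as chief.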


ALGORITHM_REGISTRATION = {
    'pg': pg_train,
    'ga': ga_train,
    'rand': ga_train,
}
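
# Note that 'rand' (uniform random search) reuses the ga_train module. To
# register a new algorithm, import its module above and add an entry here; as
# described in get_namespace() below, the module must expose run_training,
# define_tuner_hparam_space, and write_hparams_to_config. A hypothetical
# example entry:
#   'my_algo': my_algo_train,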


def get_namespace(config_string):
  """Get the namespace (module) for the selected algorithm.

  Users who want to add additional algorithm types should modify this function.
  The algorithm's namespace should contain the following functions:
    run_training: Run the main training loop.
    define_tuner_hparam_space: Return the hparam tuning space for the
        algorithm.
    write_hparams_to_config: Helper for tuning. Write hparams chosen for tuning
        to the Config object.
  Look at pg_train.py and ga_train.py for function signatures and
  implementations.

  Args:
    config_string: String representation of a Config object. This will get
        parsed into a Config in order to determine what algorithm to use.

  Returns:
    algorithm_namespace: The module corresponding to the algorithm given in the
        config.
    config: The Config object resulting from parsing `config_string`.

  Raises:
    ValueError: If config.agent.algorithm is not one of the registered
        algorithms.
  """
  config = defaults.default_config_with_updates(config_string)
  if config.agent.algorithm not in ALGORITHM_REGISTRATION:
    raise ValueError('Unknown algorithm type "%s"' % (config.agent.algorithm,))
  return ALGORITHM_REGISTRATION[config.agent.algorithm], config
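
# Example usage (sketch; the config string format follows the module docstring
# above):
#   ns, config = get_namespace("agent=c(algorithm='ga'),env=c(task='reverse')")
#   ns.run_training(is_chief=True)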


def main(argv):
  del argv  # Unused.

  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.num_workers <= 0:
    raise ValueError('num_workers flag must be greater than 0.')
  if FLAGS.task_id < 0:
    raise ValueError('task_id flag must be greater than or equal to 0.')
  if FLAGS.task_id >= FLAGS.num_workers:
    raise ValueError(
        'task_id flag must be strictly less than num_workers flag.')

  ns, _ = get_namespace(FLAGS.config)
  ns.run_training(is_chief=(FLAGS.task_id == 0))


if __name__ == '__main__':
  flags.mark_flag_as_required('logdir')
  app.run(main)