|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
"""Default configuration for agent and environment.""" |
|
|
|
from absl import logging |
|
|
|
from common import config_lib |
|
|
|
|
|
def default_config():
    """Build the default configuration for the agent and the environment.

    The result is assembled bottom-up from smaller `config_lib.Config`
    objects so that each group of settings has a name.

    Returns:
      A `config_lib.Config` with four top-level entries: `agent`, `env`,
      `batch_size` and `timestep_limit`.
    """
    # Settings used when the agent's algorithm is 'pg'.
    pg_agent = config_lib.Config(
        algorithm='pg',
        policy_lstm_sizes=[35, 35],
        value_lstm_sizes=[35, 35],
        obs_embedding_size=10,
        grad_clip_threshold=10.0,
        param_init_factor=1.0,
        lr=5e-5,
        pi_loss_hparam=1.0,
        vf_loss_hparam=0.5,
        entropy_beta=1e-2,
        regularizer=0.0,
        softmax_tr=1.0,
        optimizer='rmsprop',
        topk=0,
        topk_loss_hparam=0.0,
        topk_batch_size=1,
        ema_baseline_decay=0.99,
        eos_token=False,
        replay_temperature=1.0,
        alpha=0.0,
        iw_normalize=True,
    )

    # Settings used when the agent's algorithm is 'ga'.
    ga_agent = config_lib.Config(
        algorithm='ga',
        crossover_rate=0.99,
        mutation_rate=0.086,
    )

    # The 'rand' algorithm carries no settings beyond its name.
    rand_agent = config_lib.Config(algorithm='rand')

    # NOTE(review): the trailing `algorithm='pg'` keyword presumably selects
    # which alternative is active by default -- confirm against
    # config_lib.OneOf.
    agent = config_lib.OneOf(
        [pg_agent, ga_agent, rand_agent],
        algorithm='pg',
    )

    task_manager = config_lib.Config(
        correct_bonus=2.0,
        code_length_bonus=1.0,
    )
    env = config_lib.Config(
        task='',
        task_cycle=[],
        task_kwargs='{}',
        task_manager_config=task_manager,
        correct_syntax=False,
    )

    return config_lib.Config(
        agent=agent,
        env=env,
        batch_size=64,
        timestep_limit=32,
    )
|
|
|
|
|
def default_config_with_updates(config_string, do_logging=True):
    """Return the default config with overrides from `config_string` applied.

    Args:
      config_string: Text describing the overrides. It is handed to
        `config_lib.Config.parse`, which defines the accepted format.
      do_logging: If true, the raw `config_string` and the final config are
        both written to the INFO log.

    Returns:
      The config produced by `default_config()` after `strict_update` has
      been called on it with the parsed overrides.
    """
    if do_logging:
        logging.info('Config string: "%s"', config_string)

    config = default_config()
    overrides = config_lib.Config.parse(config_string)
    config.strict_update(overrides)

    if do_logging:
        logging.info('Config:\n%s', config.pretty_str())
    return config
|
|