File size: 3,155 Bytes
d28af7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from collections.abc import MutableMapping

import omegaconf
from omegaconf import OmegaConf
def load_config(args=None, config_file=None, overwrite_fairseq=False):
    """Load a (possibly stacked) task config and prepare it for a run.

    Args:
        args: optional argparse-style namespace. When given, its
            ``taskconfig`` attribute replaces ``config_file``.
        config_file: path to the YAML config; ignored when ``args`` is given.
        overwrite_fairseq: when True and ``args`` is given, copy every
            ``config.fairseq.<group>.<field>`` value onto ``args.<field>``.

    Returns:
        The merged OmegaConf config. ``fairseq.dataset.batch_size`` is
        divided by ``dataset.subsampling`` when the latter is set, and
        ``fairseq.common.tensorboard_logdir`` may be rewritten (see below).

    Raises:
        ValueError: if ``dataset.split`` is not "test" and no
            ``fairseq.checkpoint.save_dir`` is configured.

    Side effects (only when ``dataset.split`` is not "test"): creates
    ``save_dir``; if ``fairseq.common.tensorboard_logdir`` is set, rewrites
    it to a fresh numbered directory under ``save_dir``; writes the final
    config to ``<save_dir>/config.yaml``.
    """
    if args is not None:
        config_file = args.taskconfig
    config = recursive_config(config_file)

    _apply_subsampling(config)

    # Non-test runs must have a checkpoint directory to write into.
    if config.dataset.split != "test":
        _prepare_save_dir(config)

    if overwrite_fairseq and config.fairseq is not None and args is not None:
        _overwrite_fairseq_args(config, args)
    return config


def _apply_subsampling(config):
    """Divide ``fairseq.dataset.batch_size`` by ``dataset.subsampling`` if set."""
    subsampling = config.dataset.subsampling
    if subsampling is None:
        return
    # Integer division: the adjusted batch size is floored.
    batch_size = config.fairseq.dataset.batch_size // subsampling
    print(
        "adjusting batch_size to {} due to subsampling {}.".format(
            batch_size, subsampling
        )
    )
    config.fairseq.dataset.batch_size = batch_size


def _prepare_save_dir(config):
    """Validate/create the checkpoint dir, then snapshot the config into it."""
    checkpoint = config.fairseq.checkpoint
    if checkpoint is None or checkpoint.save_dir is None:
        raise ValueError("fairseq save_dir or save_path must be specified.")

    save_dir = checkpoint.save_dir
    # Create it first: suffix_rundir lists the contents of save_dir.
    os.makedirs(save_dir, exist_ok=True)

    common = config.fairseq.common
    if common.tensorboard_logdir is not None:
        # Give this run its own numbered tensorboard dir under save_dir.
        common.tensorboard_logdir = suffix_rundir(
            save_dir, common.tensorboard_logdir
        )
        print("update tensorboard_logdir as", common.tensorboard_logdir)

    OmegaConf.save(config=config, f=os.path.join(save_dir, "config.yaml"))


def _overwrite_fairseq_args(config, args):
    """Flatten ``config.fairseq.<group>.<field>`` onto ``args.<field>``.

    Group names are dropped, so a field appearing in several groups ends up
    with the value from the last group iterated.
    """
    for group in config.fairseq:
        group_cfg = config.fairseq[group]
        if group_cfg is None:
            # An empty YAML section loads as null; nothing to copy.
            continue
        for field in group_cfg:
            value = group_cfg[field]
            print("overwrite args." + field, "as", value)
            setattr(args, field, value)
def recursive_config(config_path):
    """Load a YAML config, resolving an ``includes`` chain of any depth.

    If the loaded config has a non-null top-level ``includes`` entry, that
    entry is removed, the referenced file is loaded recursively, and the
    current config is merged on top of it, so values in the including file
    win over the included one.

    The include path is passed to ``OmegaConf.load`` as-is; it is not
    resolved relative to the including file.

    Args:
        config_path: path of the YAML file to load.

    Returns:
        The merged OmegaConf config.
    """
    config = OmegaConf.load(config_path)
    # .get() yields None for a missing key on every OmegaConf version;
    # plain attribute access raises for missing keys on newer releases.
    includes = config.get("includes")
    if includes is not None:
        config.pop("includes")
        base_config = recursive_config(includes)
        config = OmegaConf.merge(base_config, config)
    return config
def suffix_rundir(save_dir, run_dir):
    """Return a fresh, numbered run-directory path under ``save_dir``.

    Entries of ``save_dir`` named ``<run_dir>_<N>`` (``N`` made of decimal
    digits) are scanned, and ``<save_dir>/<run_dir>_<max N + 1>`` is
    returned. An entry named exactly ``run_dir`` counts as ``N = 0``.
    Other entries, even if they share the prefix, are ignored. The returned
    directory is not created.

    Args:
        save_dir: parent directory to scan; a missing directory is treated
            as empty.
        run_dir: base name of the run directory. It may itself contain
            underscores.

    Returns:
        ``os.path.join(save_dir, run_dir + "_" + str(next_id))``.
    """
    prefix = run_dir + "_"
    try:
        entries = os.listdir(save_dir)
    except FileNotFoundError:
        entries = []

    max_id = -1
    for entry in entries:
        if entry == run_dir:
            cur_id = 0
        elif entry.startswith(prefix) and entry[len(prefix):].isdecimal():
            # Parse only the suffix after "<run_dir>_", so underscores inside
            # run_dir itself cannot shift the split position.
            cur_id = int(entry[len(prefix):])
        else:
            # Unrelated entry, e.g. "<run_dir>x" or "<run_dir>_old".
            continue
        max_id = max(max_id, cur_id)
    return os.path.join(save_dir, run_dir + "_" + str(max_id + 1))
def overwrite_dir(config, replace, basedir):
    """Rewrite a path prefix throughout a nested config, in place.

    Every string value that starts with ``basedir`` has that leading
    ``basedir`` replaced by ``replace``. Nested mappings (OmegaConf
    ``DictConfig`` or plain ``dict``) are processed recursively. Other
    values, including lists, are left untouched. The match is a plain string
    prefix and is not aware of path components.

    Args:
        config: mapping to rewrite in place.
        replace: new prefix.
        basedir: old prefix to look for.
    """
    for key in config:
        value = config[key]
        if isinstance(value, str):
            if value.startswith(basedir):
                # Swap only the leading prefix; str.replace would also rewrite
                # any later occurrence of basedir inside the same path.
                config[key] = replace + value[len(basedir):]
        elif isinstance(value, MutableMapping):
            # DictConfig is a MutableMapping, so OmegaConf configs recurse here.
            overwrite_dir(value, replace, basedir)
|