Spaces:
Sleeping
Sleeping
import os | |
import os.path as osp | |
import yaml | |
import json | |
import shutil | |
import sys | |
import time | |
import tempfile | |
import subprocess | |
import datetime | |
from importlib import import_module | |
from typing import Optional, Tuple | |
from easydict import EasyDict | |
from copy import deepcopy | |
from ding.utils import deep_merge_dicts, get_rank | |
from ding.envs import get_env_cls, get_env_manager_cls, BaseEnvManager | |
from ding.policy import get_policy_cls | |
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, Coordinator, \ | |
AdvancedReplayBuffer, get_parallel_commander_cls, get_parallel_collector_cls, get_buffer_cls, \ | |
get_serial_collector_cls, MetricSerialEvaluator, BattleInteractionSerialEvaluator | |
from ding.reward_model import get_reward_model_cls | |
from ding.world_model import get_world_model_cls | |
from .utils import parallel_transform, parallel_transform_slurm, parallel_transform_k8s, save_config_formatted | |
class Config(object): | |
r""" | |
Overview: | |
Base class for config. | |
Interface: | |
__init__, file_to_dict | |
Property: | |
cfg_dict | |
""" | |
def __init__( | |
self, | |
cfg_dict: Optional[dict] = None, | |
cfg_text: Optional[str] = None, | |
filename: Optional[str] = None | |
) -> None: | |
""" | |
Overview: | |
Init method. Create config including dict type config and text type config. | |
Arguments: | |
- cfg_dict (:obj:`Optional[dict]`): dict type config | |
- cfg_text (:obj:`Optional[str]`): text type config | |
- filename (:obj:`Optional[str]`): config file name | |
""" | |
if cfg_dict is None: | |
cfg_dict = {} | |
if not isinstance(cfg_dict, dict): | |
raise TypeError("invalid type for cfg_dict: {}".format(type(cfg_dict))) | |
self._cfg_dict = cfg_dict | |
if cfg_text: | |
text = cfg_text | |
elif filename: | |
with open(filename, 'r') as f: | |
text = f.read() | |
else: | |
text = '.' | |
self._text = text | |
self._filename = filename | |
def file_to_dict(filename: str) -> 'Config': # noqa | |
""" | |
Overview: | |
Read config file and create config. | |
Arguments: | |
- filename (:obj:`Optional[str]`): config file name. | |
Returns: | |
- cfg_dict (:obj:`Config`): config class | |
""" | |
cfg_dict, cfg_text = Config._file_to_dict(filename) | |
return Config(cfg_dict, cfg_text, filename=filename) | |
def _file_to_dict(filename: str) -> Tuple[dict, str]: | |
""" | |
Overview: | |
Read config file and convert the config file to dict type config and text type config. | |
Arguments: | |
- filename (:obj:`Optional[str]`): config file name. | |
Returns: | |
- cfg_dict (:obj:`Optional[dict]`): dict type config | |
- cfg_text (:obj:`Optional[str]`): text type config | |
""" | |
filename = osp.abspath(osp.expanduser(filename)) | |
# TODO check exist | |
# TODO check suffix | |
ext_name = osp.splitext(filename)[-1] | |
with tempfile.TemporaryDirectory() as temp_config_dir: | |
temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=ext_name) | |
temp_config_name = osp.basename(temp_config_file.name) | |
temp_config_file.close() | |
shutil.copyfile(filename, temp_config_file.name) | |
temp_module_name = osp.splitext(temp_config_name)[0] | |
sys.path.insert(0, temp_config_dir) | |
# TODO validate py syntax | |
module = import_module(temp_module_name) | |
cfg_dict = {k: v for k, v in module.__dict__.items() if not k.startswith('_')} | |
del sys.modules[temp_module_name] | |
sys.path.pop(0) | |
cfg_text = filename + '\n' | |
with open(filename, 'r') as f: | |
cfg_text += f.read() | |
return cfg_dict, cfg_text | |
def cfg_dict(self) -> dict: | |
return self._cfg_dict | |
def read_config_yaml(path: str) -> EasyDict: | |
""" | |
Overview: | |
read configuration from path | |
Arguments: | |
- path (:obj:`str`): Path of source yaml | |
Returns: | |
- (:obj:`EasyDict`): Config data from this file with dict type | |
""" | |
with open(path, "r") as f: | |
config_ = yaml.safe_load(f) | |
return EasyDict(config_) | |
def save_config_yaml(config_: dict, path: str) -> None: | |
""" | |
Overview: | |
save configuration to path | |
Arguments: | |
- config (:obj:`dict`): Config dict | |
- path (:obj:`str`): Path of target yaml | |
""" | |
config_string = json.dumps(config_) | |
with open(path, "w") as f: | |
yaml.safe_dump(json.loads(config_string), f) | |
def save_config_py(config_: dict, path: str) -> None: | |
""" | |
Overview: | |
save configuration to python file | |
Arguments: | |
- config (:obj:`dict`): Config dict | |
- path (:obj:`str`): Path of target yaml | |
""" | |
# config_string = json.dumps(config_, indent=4) | |
config_string = str(config_) | |
from yapf.yapflib.yapf_api import FormatCode | |
config_string, _ = FormatCode(config_string) | |
config_string = config_string.replace('inf,', 'float("inf"),') | |
with open(path, "w") as f: | |
f.write('exp_config = ' + config_string) | |
def read_config_directly(path: str) -> dict: | |
""" | |
Overview: | |
Read configuration from a file path(now only support python file) and directly return results. | |
Arguments: | |
- path (:obj:`str`): Path of configuration file | |
Returns: | |
- cfg (:obj:`Tuple[dict, dict]`): Configuration dict. | |
""" | |
suffix = path.split('.')[-1] | |
if suffix == 'py': | |
return Config.file_to_dict(path).cfg_dict | |
else: | |
raise KeyError("invalid config file suffix: {}".format(suffix)) | |
def read_config(path: str) -> Tuple[dict, dict]: | |
""" | |
Overview: | |
Read configuration from a file path(now only suport python file). And select some proper parts. | |
Arguments: | |
- path (:obj:`str`): Path of configuration file | |
Returns: | |
- cfg (:obj:`Tuple[dict, dict]`): A collection(tuple) of configuration dict, divided into `main_config` and \ | |
`create_cfg` two parts. | |
""" | |
suffix = path.split('.')[-1] | |
if suffix == 'py': | |
cfg = Config.file_to_dict(path).cfg_dict | |
assert "main_config" in cfg, "Please make sure a 'main_config' variable is declared in config python file!" | |
assert "create_config" in cfg, "Please make sure a 'create_config' variable is declared in config python file!" | |
return cfg['main_config'], cfg['create_config'] | |
else: | |
raise KeyError("invalid config file suffix: {}".format(suffix)) | |
def read_config_with_system(path: str) -> Tuple[dict, dict, dict]: | |
""" | |
Overview: | |
Read configuration from a file path(now only suport python file). And select some proper parts | |
Arguments: | |
- path (:obj:`str`): Path of configuration file | |
Returns: | |
- cfg (:obj:`Tuple[dict, dict]`): A collection(tuple) of configuration dict, divided into `main_config`, \ | |
`create_cfg` and `system_config` three parts. | |
""" | |
suffix = path.split('.')[-1] | |
if suffix == 'py': | |
cfg = Config.file_to_dict(path).cfg_dict | |
assert "main_config" in cfg, "Please make sure a 'main_config' variable is declared in config python file!" | |
assert "create_config" in cfg, "Please make sure a 'create_config' variable is declared in config python file!" | |
assert "system_config" in cfg, "Please make sure a 'system_config' variable is declared in config python file!" | |
return cfg['main_config'], cfg['create_config'], cfg['system_config'] | |
else: | |
raise KeyError("invalid config file suffix: {}".format(suffix)) | |
def save_config(config_: dict, path: str, type_: str = 'py', save_formatted: bool = False) -> None: | |
""" | |
Overview: | |
save configuration to python file or yaml file | |
Arguments: | |
- config (:obj:`dict`): Config dict | |
- path (:obj:`str`): Path of target yaml or target python file | |
- type (:obj:`str`): If type is ``yaml`` , save configuration to yaml file. If type is ``py`` , save\ | |
configuration to python file. | |
- save_formatted (:obj:`bool`): If save_formatted is true, save formatted config to path.\ | |
Formatted config can be read by serial_pipeline directly. | |
""" | |
assert type_ in ['yaml', 'py'], type_ | |
if type_ == 'yaml': | |
save_config_yaml(config_, path) | |
elif type_ == 'py': | |
save_config_py(config_, path) | |
if save_formatted: | |
formated_path = osp.join(osp.dirname(path), 'formatted_' + osp.basename(path)) | |
save_config_formatted(config_, formated_path) | |
def compile_buffer_config(policy_cfg: EasyDict, user_cfg: EasyDict, buffer_cls: 'IBuffer') -> EasyDict: # noqa | |
def _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, buffer_cls): | |
if buffer_cls is None: | |
assert 'type' in policy_buffer_cfg, "please indicate buffer type in create_cfg" | |
buffer_cls = get_buffer_cls(policy_buffer_cfg) | |
buffer_cfg = deep_merge_dicts(buffer_cls.default_config(), policy_buffer_cfg) | |
buffer_cfg = deep_merge_dicts(buffer_cfg, user_buffer_cfg) | |
return buffer_cfg | |
policy_multi_buffer = policy_cfg.other.replay_buffer.get('multi_buffer', False) | |
user_multi_buffer = user_cfg.policy.get('other', {}).get('replay_buffer', {}).get('multi_buffer', False) | |
assert not user_multi_buffer or user_multi_buffer == policy_multi_buffer, "For multi_buffer, \ | |
user_cfg({}) and policy_cfg({}) must be in accordance".format(user_multi_buffer, policy_multi_buffer) | |
multi_buffer = policy_multi_buffer | |
if not multi_buffer: | |
policy_buffer_cfg = policy_cfg.other.replay_buffer | |
user_buffer_cfg = user_cfg.policy.get('other', {}).get('replay_buffer', {}) | |
return _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, buffer_cls) | |
else: | |
return_cfg = EasyDict() | |
for buffer_name in policy_cfg.other.replay_buffer: # Only traverse keys in policy_cfg | |
if buffer_name == 'multi_buffer': | |
continue | |
policy_buffer_cfg = policy_cfg.other.replay_buffer[buffer_name] | |
user_buffer_cfg = user_cfg.policy.get('other', {}).get('replay_buffer', {}).get('buffer_name', {}) | |
if buffer_cls is None: | |
return_cfg[buffer_name] = _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, None) | |
else: | |
return_cfg[buffer_name] = _compile_buffer_config( | |
policy_buffer_cfg, user_buffer_cfg, buffer_cls[buffer_name] | |
) | |
return_cfg[buffer_name].name = buffer_name | |
return return_cfg | |
def compile_collector_config( | |
policy_cfg: EasyDict, | |
user_cfg: EasyDict, | |
collector_cls: 'ISerialCollector' # noqa | |
) -> EasyDict: | |
policy_collector_cfg = policy_cfg.collect.collector | |
user_collector_cfg = user_cfg.policy.get('collect', {}).get('collector', {}) | |
# step1: get collector class | |
# two cases: create cfg merged in policy_cfg, collector class, and class has higher priority | |
if collector_cls is None: | |
assert 'type' in policy_collector_cfg, "please indicate collector type in create_cfg" | |
# use type to get collector_cls | |
collector_cls = get_serial_collector_cls(policy_collector_cfg) | |
# step2: policy collector cfg merge to collector cfg | |
collector_cfg = deep_merge_dicts(collector_cls.default_config(), policy_collector_cfg) | |
# step3: user collector cfg merge to the step2 config | |
collector_cfg = deep_merge_dicts(collector_cfg, user_collector_cfg) | |
return collector_cfg | |
policy_config_template = dict( | |
model=dict(), | |
learn=dict(learner=dict()), | |
collect=dict(collector=dict()), | |
eval=dict(evaluator=dict()), | |
other=dict(replay_buffer=dict()), | |
) | |
policy_config_template = EasyDict(policy_config_template) | |
env_config_template = dict(manager=dict(), stop_value=int(1e10), n_evaluator_episode=4) | |
env_config_template = EasyDict(env_config_template) | |
def save_project_state(exp_name: str) -> None: | |
def _fn(cmd: str): | |
return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.strip().decode("utf-8") | |
if subprocess.run("git status", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0: | |
short_sha = _fn("git describe --always") | |
log = _fn("git log --stat -n 5") | |
diff = _fn("git diff") | |
with open(os.path.join(exp_name, "git_log.txt"), "w", encoding='utf-8') as f: | |
f.write(short_sha + '\n\n' + log) | |
with open(os.path.join(exp_name, "git_diff.txt"), "w", encoding='utf-8') as f: | |
f.write(diff) | |
def compile_config( | |
cfg: EasyDict, | |
env_manager: type = None, | |
policy: type = None, | |
learner: type = BaseLearner, | |
collector: type = None, | |
evaluator: type = InteractionSerialEvaluator, | |
buffer: type = None, | |
env: type = None, | |
reward_model: type = None, | |
world_model: type = None, | |
seed: int = 0, | |
auto: bool = False, | |
create_cfg: dict = None, | |
save_cfg: bool = True, | |
save_path: str = 'total_config.py', | |
renew_dir: bool = True, | |
) -> EasyDict: | |
""" | |
Overview: | |
Combine the input config information with other input information. | |
Compile config to make it easy to be called by other programs | |
Arguments: | |
- cfg (:obj:`EasyDict`): Input config dict which is to be used in the following pipeline | |
- env_manager (:obj:`type`): Env_manager class which is to be used in the following pipeline | |
- policy (:obj:`type`): Policy class which is to be used in the following pipeline | |
- learner (:obj:`type`): Input learner class, defaults to BaseLearner | |
- collector (:obj:`type`): Input collector class, defaults to BaseSerialCollector | |
- evaluator (:obj:`type`): Input evaluator class, defaults to InteractionSerialEvaluator | |
- buffer (:obj:`type`): Input buffer class, defaults to IBuffer | |
- env (:obj:`type`): Environment class which is to be used in the following pipeline | |
- reward_model (:obj:`type`): Reward model class which aims to offer various and valuable reward | |
- seed (:obj:`int`): Random number seed | |
- auto (:obj:`bool`): Compile create_config dict or not | |
- create_cfg (:obj:`dict`): Input create config dict | |
- save_cfg (:obj:`bool`): Save config or not | |
- save_path (:obj:`str`): Path of saving file | |
- renew_dir (:obj:`bool`): Whether to new a directory for saving config. | |
Returns: | |
- cfg (:obj:`EasyDict`): Config after compiling | |
""" | |
cfg, create_cfg = deepcopy(cfg), deepcopy(create_cfg) | |
if auto: | |
assert create_cfg is not None | |
# for compatibility | |
if 'collector' not in create_cfg: | |
create_cfg.collector = EasyDict(dict(type='sample')) | |
if 'replay_buffer' not in create_cfg: | |
create_cfg.replay_buffer = EasyDict(dict(type='advanced')) | |
buffer = AdvancedReplayBuffer | |
if env is None: | |
if 'env' in create_cfg: | |
env = get_env_cls(create_cfg.env) | |
else: | |
env = None | |
create_cfg.env = {'type': 'ding_env_wrapper_generated'} | |
if env_manager is None: | |
env_manager = get_env_manager_cls(create_cfg.env_manager) | |
if policy is None: | |
policy = get_policy_cls(create_cfg.policy) | |
if 'default_config' in dir(env): | |
env_config = env.default_config() | |
else: | |
env_config = EasyDict() # env does not have default_config | |
env_config = deep_merge_dicts(env_config_template, env_config) | |
env_config.update(create_cfg.env) | |
env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager) | |
env_config.manager.update(create_cfg.env_manager) | |
policy_config = policy.default_config() | |
policy_config = deep_merge_dicts(policy_config_template, policy_config) | |
policy_config.update(create_cfg.policy) | |
policy_config.collect.collector.update(create_cfg.collector) | |
if 'evaluator' in create_cfg: | |
policy_config.eval.evaluator.update(create_cfg.evaluator) | |
policy_config.other.replay_buffer.update(create_cfg.replay_buffer) | |
policy_config.other.commander = BaseSerialCommander.default_config() | |
if 'reward_model' in create_cfg: | |
reward_model = get_reward_model_cls(create_cfg.reward_model) | |
reward_model_config = reward_model.default_config() | |
else: | |
reward_model_config = EasyDict() | |
if 'world_model' in create_cfg: | |
world_model = get_world_model_cls(create_cfg.world_model) | |
world_model_config = world_model.default_config() | |
world_model_config.update(create_cfg.world_model) | |
else: | |
world_model_config = EasyDict() | |
else: | |
if 'default_config' in dir(env): | |
env_config = env.default_config() | |
else: | |
env_config = EasyDict() # env does not have default_config | |
env_config = deep_merge_dicts(env_config_template, env_config) | |
if env_manager is None: | |
env_manager = BaseEnvManager # for compatibility | |
env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager) | |
policy_config = policy.default_config() | |
policy_config = deep_merge_dicts(policy_config_template, policy_config) | |
if reward_model is None: | |
reward_model_config = EasyDict() | |
else: | |
reward_model_config = reward_model.default_config() | |
if world_model is None: | |
world_model_config = EasyDict() | |
else: | |
world_model_config = world_model.default_config() | |
world_model_config.update(create_cfg.world_model) | |
policy_config.learn.learner = deep_merge_dicts( | |
learner.default_config(), | |
policy_config.learn.learner, | |
) | |
if create_cfg is not None or collector is not None: | |
policy_config.collect.collector = compile_collector_config(policy_config, cfg, collector) | |
if evaluator: | |
policy_config.eval.evaluator = deep_merge_dicts( | |
evaluator.default_config(), | |
policy_config.eval.evaluator, | |
) | |
if create_cfg is not None or buffer is not None: | |
policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, buffer) | |
default_config = EasyDict({'env': env_config, 'policy': policy_config}) | |
if len(reward_model_config) > 0: | |
default_config['reward_model'] = reward_model_config | |
if len(world_model_config) > 0: | |
default_config['world_model'] = world_model_config | |
cfg = deep_merge_dicts(default_config, cfg) | |
if 'unroll_len' in cfg.policy: | |
cfg.policy.collect.unroll_len = cfg.policy.unroll_len | |
cfg.seed = seed | |
# check important key in config | |
if evaluator in [InteractionSerialEvaluator, BattleInteractionSerialEvaluator]: # env interaction evaluation | |
cfg.policy.eval.evaluator.stop_value = cfg.env.stop_value | |
cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode | |
if 'exp_name' not in cfg: | |
cfg.exp_name = 'default_experiment' | |
if save_cfg and get_rank() == 0: | |
if os.path.exists(cfg.exp_name) and renew_dir: | |
cfg.exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S") | |
try: | |
os.makedirs(cfg.exp_name) | |
except FileExistsError: | |
pass | |
save_project_state(cfg.exp_name) | |
save_path = os.path.join(cfg.exp_name, save_path) | |
save_config(cfg, save_path, save_formatted=True) | |
return cfg | |
def compile_config_parallel( | |
cfg: EasyDict, | |
create_cfg: EasyDict, | |
system_cfg: EasyDict, | |
seed: int = 0, | |
save_cfg: bool = True, | |
save_path: str = 'total_config.py', | |
platform: str = 'local', | |
coordinator_host: Optional[str] = None, | |
learner_host: Optional[str] = None, | |
collector_host: Optional[str] = None, | |
coordinator_port: Optional[int] = None, | |
learner_port: Optional[int] = None, | |
collector_port: Optional[int] = None, | |
) -> EasyDict: | |
""" | |
Overview: | |
Combine the input parallel mode configuration information with other input information. Compile config\ | |
to make it easy to be called by other programs | |
Arguments: | |
- cfg (:obj:`EasyDict`): Input main config dict | |
- create_cfg (:obj:`dict`): Input create config dict, including type parameters, such as environment type | |
- system_cfg (:obj:`dict`): Input system config dict, including system parameters, such as file path,\ | |
communication mode, use multiple GPUs or not | |
- seed (:obj:`int`): Random number seed | |
- save_cfg (:obj:`bool`): Save config or not | |
- save_path (:obj:`str`): Path of saving file | |
- platform (:obj:`str`): Where to run the program, 'local' or 'slurm' | |
- coordinator_host (:obj:`Optional[str]`): Input coordinator's host when platform is slurm | |
- learner_host (:obj:`Optional[str]`): Input learner's host when platform is slurm | |
- collector_host (:obj:`Optional[str]`): Input collector's host when platform is slurm | |
Returns: | |
- cfg (:obj:`EasyDict`): Config after compiling | |
""" | |
# for compatibility | |
if 'replay_buffer' not in create_cfg: | |
create_cfg.replay_buffer = EasyDict(dict(type='advanced')) | |
# env | |
env = get_env_cls(create_cfg.env) | |
if 'default_config' in dir(env): | |
env_config = env.default_config() | |
else: | |
env_config = EasyDict() # env does not have default_config | |
env_config = deep_merge_dicts(env_config_template, env_config) | |
env_config.update(create_cfg.env) | |
env_manager = get_env_manager_cls(create_cfg.env_manager) | |
env_config.manager = env_manager.default_config() | |
env_config.manager.update(create_cfg.env_manager) | |
# policy | |
policy = get_policy_cls(create_cfg.policy) | |
policy_config = policy.default_config() | |
policy_config = deep_merge_dicts(policy_config_template, policy_config) | |
cfg.policy.update(create_cfg.policy) | |
collector = get_parallel_collector_cls(create_cfg.collector) | |
policy_config.collect.collector = collector.default_config() | |
policy_config.collect.collector.update(create_cfg.collector) | |
policy_config.learn.learner = BaseLearner.default_config() | |
policy_config.learn.learner.update(create_cfg.learner) | |
commander = get_parallel_commander_cls(create_cfg.commander) | |
policy_config.other.commander = commander.default_config() | |
policy_config.other.commander.update(create_cfg.commander) | |
policy_config.other.replay_buffer.update(create_cfg.replay_buffer) | |
policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, None) | |
default_config = EasyDict({'env': env_config, 'policy': policy_config}) | |
cfg = deep_merge_dicts(default_config, cfg) | |
cfg.policy.other.commander.path_policy = system_cfg.path_policy # league may use 'path_policy' | |
# system | |
for k in ['comm_learner', 'comm_collector']: | |
system_cfg[k] = create_cfg[k] | |
if platform == 'local': | |
cfg = parallel_transform(EasyDict({'main': cfg, 'system': system_cfg})) | |
elif platform == 'slurm': | |
cfg = parallel_transform_slurm( | |
EasyDict({ | |
'main': cfg, | |
'system': system_cfg | |
}), coordinator_host, learner_host, collector_host | |
) | |
elif platform == 'k8s': | |
cfg = parallel_transform_k8s( | |
EasyDict({ | |
'main': cfg, | |
'system': system_cfg | |
}), | |
coordinator_port=coordinator_port, | |
learner_port=learner_port, | |
collector_port=collector_port | |
) | |
else: | |
raise KeyError("not support platform type: {}".format(platform)) | |
cfg.system.coordinator = deep_merge_dicts(Coordinator.default_config(), cfg.system.coordinator) | |
# seed | |
cfg.seed = seed | |
if save_cfg: | |
save_config(cfg, save_path) | |
return cfg | |