Spaces:
Sleeping
Sleeping
from typing import Optional, List | |
import copy | |
from easydict import EasyDict | |
from ding.utils import find_free_port, find_free_port_slurm, node_to_partition, node_to_host, pretty_print, \ | |
DEFAULT_K8S_COLLECTOR_PORT, DEFAULT_K8S_LEARNER_PORT, DEFAULT_K8S_COORDINATOR_PORT | |
from dizoo.classic_control.cartpole.config.parallel import cartpole_dqn_config | |
# Host used whenever a caller does not supply one (see parallel_transform); set_host_port_k8s also
# assigns it to the coordinator and to the base learner/collector entries.
default_host = '0.0.0.0'
# Default port value. NOTE(review): not referenced by the code visible in this module - confirm external users.
default_port = 22270
def set_host_port(cfg: EasyDict, coordinator_host: str, learner_host: str, collector_host: str) -> EasyDict:
    """
    Overview:
        Resolve every ``'auto'`` host/port of the coordinator, learners and collectors in ``cfg`` in place.
    Arguments:
        - cfg (:obj:`EasyDict`): System config holding ``coordinator`` plus ``learner*``/``collector*`` entries.
        - coordinator_host (:obj:`str`): Host that is always assigned to the coordinator.
        - learner_host (:obj:`str` or :obj:`list`): A single host shared by all learners, or a list consumed \
            one element per learner whose host is ``'auto'``.
        - collector_host (:obj:`str` or :obj:`list`): Same as ``learner_host`` but for collectors.
    Returns:
        - cfg (:obj:`EasyDict`): The same object that was passed in, after modification.
    Raises:
        - NotImplementedError: If ``cfg`` contains the key ``'learner_aggregator'``.
        - TypeError: If a host argument is neither a ``list`` nor a ``str`` when it is needed.
    """
    coordinator = cfg.coordinator
    coordinator.host = coordinator_host
    if coordinator.port == 'auto':
        coordinator.port = find_free_port(coordinator_host)

    learner_idx, collector_idx = 0, 0
    for name in cfg:
        if name == 'learner_aggregator':
            raise NotImplementedError

        if name.startswith('learner'):
            learner = cfg[name]
            if learner.host == 'auto':
                if not isinstance(learner_host, (list, str)):
                    raise TypeError("not support learner_host type: {}".format(learner_host))
                if isinstance(learner_host, list):
                    # A list hands out one host per auto-configured learner, in iteration order.
                    learner.host = learner_host[learner_idx]
                    learner_idx += 1
                else:
                    learner.host = learner_host
            if learner.port == 'auto':
                learner.port = find_free_port(learner.host)
            # This (non-slurm) setup never places an aggregator in front of a learner.
            learner.aggregator = False

        if name.startswith('collector'):
            collector = cfg[name]
            if collector.host == 'auto':
                if not isinstance(collector_host, (list, str)):
                    raise TypeError("not support collector_host type: {}".format(collector_host))
                if isinstance(collector_host, list):
                    collector.host = collector_host[collector_idx]
                    collector_idx += 1
                else:
                    collector.host = collector_host
            if collector.port == 'auto':
                collector.port = find_free_port(collector.host)
    return cfg
def set_host_port_slurm(cfg: EasyDict, coordinator_host: str, learner_node: list, collector_node: list) -> EasyDict:
    """
    Overview:
        Assign slurm nodes, partitions, hosts and ports to the learners and collectors in ``cfg`` (in place),
        and create a ``learner_aggregator*`` entry for every learner that received several auto ports.
    Arguments:
        - cfg (:obj:`EasyDict`): System config holding ``coordinator`` plus ``learner*``/``collector*`` entries.
        - coordinator_host (:obj:`str`): Host that is always assigned to the coordinator.
        - learner_node (:obj:`list` or :obj:`str` or None): Node name(s) for learners, used round-robin. \
            ``None`` leaves the learner entries untouched.
        - collector_node (:obj:`list` or :obj:`str` or None): Node name(s) for collectors, used round-robin. \
            ``None`` leaves the collector entries untouched.
    Returns:
        - cfg (:obj:`EasyDict`): The same object that was passed in, after modification.
    """
    cfg.coordinator.host = coordinator_host
    if cfg.coordinator.port == 'auto':
        cfg.coordinator.port = find_free_port(coordinator_host)
    # A bare node name is accepted as shorthand for a one-element list.
    if isinstance(learner_node, str):
        learner_node = [learner_node]
    if isinstance(collector_node, str):
        collector_node = [collector_node]
    learner_count, collector_count = 0, 0
    # learner key -> True when the learner got one port per GPU and therefore needs an aggregator.
    learner_multi = {}
    for k in cfg.keys():
        if learner_node is not None and k.startswith('learner'):
            node = learner_node[learner_count % len(learner_node)]
            cfg[k].node = node
            cfg[k].partition = node_to_partition(node)
            gpu_num = cfg[k].gpu_num
            # Always leave an ``aggregator`` flag behind: set_learner_interaction_for_coordinator reads it,
            # and learners with an explicit (non-'auto') port are never visited by the aggregator loop below.
            if 'aggregator' not in cfg[k]:
                cfg[k].aggregator = False
            if cfg[k].host == 'auto':
                cfg[k].host = node_to_host(node)
            if cfg[k].port == 'auto':
                if gpu_num == 1:
                    cfg[k].port = find_free_port_slurm(node)
                    learner_multi[k] = False
                else:
                    # One port per GPU; these learners are fronted by an aggregator below.
                    cfg[k].port = [find_free_port_slurm(node) for _ in range(gpu_num)]
                    learner_multi[k] = True
            learner_count += 1
        if collector_node is not None and k.startswith('collector'):
            node = collector_node[collector_count % len(collector_node)]
            cfg[k].node = node
            cfg[k].partition = node_to_partition(node)
            if cfg[k].host == 'auto':
                cfg[k].host = node_to_host(node)
            if cfg[k].port == 'auto':
                cfg[k].port = find_free_port_slurm(node)
            collector_count += 1
    for k, flag in learner_multi.items():
        if flag:
            host = cfg[k].host
            learner_interaction_cfg = {str(i): [str(i), host, p] for i, p in enumerate(cfg[k].port)}
            aggregator_cfg = dict(
                master=dict(
                    host=host,
                    port=find_free_port_slurm(cfg[k].node),
                ),
                slave=dict(
                    host=host,
                    port=find_free_port_slurm(cfg[k].node),
                ),
                learner=learner_interaction_cfg,
                node=cfg[k].node,
                partition=cfg[k].partition,
            )
            cfg[k].aggregator = True
            # 'learnerN' -> 'learner_aggregatorN': k[7:] is whatever follows the 'learner' prefix.
            cfg['learner_aggregator' + k[7:]] = aggregator_cfg
        else:
            cfg[k].aggregator = False
    return cfg
def set_host_port_k8s(cfg: EasyDict, coordinator_port: int, learner_port: int, collector_port: int) -> EasyDict:
    """
    Overview:
        Give the coordinator, every ``learner*`` and every ``collector*`` entry of ``cfg`` its k8s port, and
        add base ``learner``/``collector`` entries copied from the first learner/collector encountered.
    Arguments:
        - cfg (:obj:`EasyDict`): System config holding ``coordinator`` plus ``learner*``/``collector*`` entries.
        - coordinator_port (:obj:`int` or None): Coordinator port; ``None`` selects the k8s default.
        - learner_port (:obj:`int` or None): Port shared by all learners; ``None`` selects the k8s default.
        - collector_port (:obj:`int` or None): Port shared by all collectors; ``None`` selects the k8s default.
    Returns:
        - cfg (:obj:`EasyDict`): The same object that was passed in, after modification. ``cfg['learner']`` \
            or ``cfg['collector']`` is ``None`` when no matching entry exists.
    """
    coordinator = cfg.coordinator
    coordinator.host = default_host
    coordinator.port = DEFAULT_K8S_COORDINATOR_PORT if coordinator_port is None else coordinator_port

    ports = {
        'learner': DEFAULT_K8S_LEARNER_PORT if learner_port is None else learner_port,
        'collector': DEFAULT_K8S_COLLECTOR_PORT if collector_port is None else collector_port,
    }
    base_cfgs = {'learner': None, 'collector': None}

    for name in cfg.keys():
        for role in ('learner', 'collector'):
            if not name.startswith(role):
                continue
            if base_cfgs[role] is None:
                # The first entry of each role becomes the template; copy it before its port is overwritten.
                template = copy.deepcopy(cfg[name])
                template.host = default_host
                template.port = ports[role]
                base_cfgs[role] = template
            cfg[name].port = ports[role]
            break

    cfg['learner'] = base_cfgs['learner']
    cfg['collector'] = base_cfgs['collector']
    return cfg
def set_learner_interaction_for_coordinator(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Rebuild ``cfg.coordinator.learner`` so that it maps every learner key to ``[key, host, port]``.
        A learner whose ``aggregator`` flag is set is registered with the ``slave`` endpoint of its
        ``learner_aggregator*`` entry instead of its own host/port.
    Arguments:
        - cfg (:obj:`EasyDict`): System config with resolved learner hosts, ports and ``aggregator`` flags.
    Returns:
        - cfg (:obj:`EasyDict`): The same object that was passed in, after modification.
    """
    cfg.coordinator.learner = {}
    for name in cfg:
        # Aggregator entries share the 'learner' prefix but are not learners themselves.
        if not name.startswith('learner') or name.startswith('learner_aggregator'):
            continue
        if cfg[name].aggregator:
            endpoint = cfg['learner_aggregator' + name[7:]].slave
        else:
            endpoint = cfg[name]
        cfg.coordinator.learner[name] = [name, endpoint.host, endpoint.port]
    return cfg
def set_collector_interaction_for_coordinator(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Rebuild ``cfg.coordinator.collector`` so that it maps every collector key to ``[key, host, port]``.
    Arguments:
        - cfg (:obj:`EasyDict`): System config with resolved collector hosts and ports.
    Returns:
        - cfg (:obj:`EasyDict`): The same object that was passed in, after modification.
    """
    cfg.coordinator.collector = {}
    # Fetch the table back from cfg so that writes land in the object cfg actually stores.
    registry = cfg.coordinator.collector
    for name in (key for key in cfg if key.startswith('collector')):
        collector = cfg[name]
        registry[name] = [name, collector.host, collector.port]
    return cfg
def set_system_cfg(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Expand the compact ``cfg.system`` description into a full system config containing one
        ``coordinator`` entry, ``learner_num`` entries named ``learner{i}`` and ``collector_num`` entries
        named ``collector{i}``. All hosts and ports start as ``'auto'`` (coordinator values may be
        overridden by ``cfg.system.coordinator``).
    Arguments:
        - cfg (:obj:`EasyDict`): Whole config; ``cfg.main.policy`` supplies the worker counts and \
            ``cfg.system`` supplies paths, communication settings and the learner GPU number.
    Returns:
        - system_cfg (:obj:`EasyDict`): The newly built system config (``cfg`` itself is not modified).
    """
    learner_num = cfg.main.policy.learn.learner.learner_num
    collector_num = cfg.main.policy.collect.collector.collector_num
    system = cfg.system
    path_data, path_policy = system.path_data, system.path_policy
    coordinator_cfg = system.coordinator
    communication_mode = system.communication_mode
    assert communication_mode in ['auto'], communication_mode
    learner_gpu_num = system.learner_gpu_num
    learner_multi_gpu = learner_gpu_num > 1

    # User-provided coordinator fields take precedence over the 'auto' placeholders.
    coordinator = {'host': 'auto', 'port': 'auto'}
    coordinator.update(coordinator_cfg)
    new_cfg = {'coordinator': coordinator}

    for i in range(learner_num):
        comm_learner = system.comm_learner
        new_cfg[f'learner{i}'] = {
            'type': comm_learner.type,
            'import_names': comm_learner.import_names,
            'host': 'auto',
            'port': 'auto',
            'path_data': path_data,
            'path_policy': path_policy,
            'multi_gpu': learner_multi_gpu,
            'gpu_num': learner_gpu_num,
        }

    for i in range(collector_num):
        comm_collector = system.comm_collector
        new_cfg[f'collector{i}'] = {
            'type': comm_collector.type,
            'import_names': comm_collector.import_names,
            'host': 'auto',
            'port': 'auto',
            'path_data': path_data,
            'path_policy': path_policy,
        }
    return EasyDict(new_cfg)
def parallel_transform(
        cfg: dict,
        coordinator_host: Optional[str] = None,
        learner_host: Optional[List[str]] = None,
        collector_host: Optional[List[str]] = None
) -> EasyDict:
    """
    Overview:
        Turn a user config into a parallel-pipeline config: build the system config, resolve hosts and
        ports, and register every learner and collector with the coordinator.
    Arguments:
        - cfg (:obj:`dict`): User config containing at least ``main`` and ``system``.
        - coordinator_host (:obj:`Optional[str]`): Coordinator host; ``None`` means ``default_host``.
        - learner_host (:obj:`Optional[List[str]]`): Learner host(s); ``None`` means ``default_host``. \
            A single string is also accepted by ``set_host_port``.
        - collector_host (:obj:`Optional[List[str]]`): Collector host(s); ``None`` means ``default_host``. \
            A single string is also accepted by ``set_host_port``.
    Returns:
        - cfg (:obj:`EasyDict`): A new ``EasyDict`` built from ``cfg`` whose ``system`` field has been replaced.
    """
    coordinator_host = default_host if coordinator_host is None else coordinator_host
    collector_host = default_host if collector_host is None else collector_host
    learner_host = default_host if learner_host is None else learner_host
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port(cfg.system, coordinator_host, learner_host, collector_host)
    cfg.system = set_learner_interaction_for_coordinator(cfg.system)
    cfg.system = set_collector_interaction_for_coordinator(cfg.system)
    return cfg
def parallel_transform_slurm(
        cfg: dict,
        coordinator_host: Optional[str] = None,
        learner_node: Optional[List[str]] = None,
        collector_node: Optional[List[str]] = None
) -> EasyDict:
    """
    Overview:
        Slurm variant of ``parallel_transform``: build the system config, assign slurm nodes, hosts and
        ports, register every learner and collector with the coordinator, then pretty-print the result.
    Arguments:
        - cfg (:obj:`dict`): User config containing at least ``main`` and ``system``.
        - coordinator_host (:obj:`Optional[str]`): Host assigned to the coordinator (passed through as is).
        - learner_node (:obj:`Optional[List[str]]`): Slurm node name(s) for learners.
        - collector_node (:obj:`Optional[List[str]]`): Slurm node name(s) for collectors.
    Returns:
        - cfg (:obj:`EasyDict`): A new ``EasyDict`` built from ``cfg`` whose ``system`` field has been replaced.
    """
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port_slurm(cfg.system, coordinator_host, learner_node, collector_node)
    cfg.system = set_learner_interaction_for_coordinator(cfg.system)
    cfg.system = set_collector_interaction_for_coordinator(cfg.system)
    pretty_print(cfg)
    return cfg
def parallel_transform_k8s(
        cfg: dict,
        coordinator_port: Optional[int] = None,
        learner_port: Optional[int] = None,
        collector_port: Optional[int] = None
) -> EasyDict:
    """
    Overview:
        K8s variant of ``parallel_transform``: build the system config, assign the k8s ports, leave the
        coordinator's learner/collector tables empty, then pretty-print the result.
    Arguments:
        - cfg (:obj:`dict`): User config containing at least ``main`` and ``system``.
        - coordinator_port (:obj:`Optional[int]`): Coordinator port; ``None`` selects the k8s default.
        - learner_port (:obj:`Optional[int]`): Learner port; ``None`` selects the k8s default.
        - collector_port (:obj:`Optional[int]`): Collector port; ``None`` selects the k8s default.
    Returns:
        - cfg (:obj:`EasyDict`): A new ``EasyDict`` built from ``cfg`` whose ``system`` field has been replaced.
    """
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port_k8s(cfg.system, coordinator_port, learner_port, collector_port)
    # learners/collectors are created by the operator, so the following fields are only placeholders
    cfg.system.coordinator.collector = {}
    cfg.system.coordinator.learner = {}
    pretty_print(cfg)
    return cfg
def _write_entry(f, key, value) -> None:
    """Write one ``key=value,`` line: strings are quoted and infinity is spelled ``float('inf')``."""
    if isinstance(value, str):
        f.write(" {}='{}',\n".format(key, value))
    elif value == float('inf'):
        # A bare ``inf`` would be an undefined name when the generated file is imported.
        f.write(" {}=float('{}'),\n".format(key, value))
    else:
        f.write(" {}={},\n".format(key, value))


def _write_dict(f, name, mapping, nested_levels: int = 0) -> None:
    """Write ``name=dict(...)``; dict values are expanded into blocks for ``nested_levels`` further levels."""
    f.write(" {}=dict(\n".format(name))
    for key, value in mapping.items():
        if nested_levels > 0 and isinstance(value, dict):
            _write_dict(f, key, value, nested_levels - 1)
        else:
            _write_entry(f, key, value)
    f.write(" ),\n")


def _write_section(f, name, mapping, expand) -> None:
    """Write ``name=dict(...)``; children whose key is in ``expand`` become blocks (recursively, by key)."""
    f.write(" {}=dict(\n".format(name))
    for key, value in mapping.items():
        if key in expand:
            _write_section(f, key, value, expand[key])
        else:
            _write_entry(f, key, value)
    f.write(" ),\n")


def _write_main_env(f, env_cfg) -> None:
    """Write the ``env`` block of ``main_config``; creation-only fields are left to ``create_config``."""
    f.write(' env=dict(\n')
    for key, value in env_cfg.items():
        if key == 'manager':
            f.write(" manager=dict(\n")
            for m_key, m_value in value.items():
                # Filter on the key: ``type``/``cfg_type`` are emitted under ``create_config.env_manager``.
                if m_key not in ('cfg_type', 'type'):
                    _write_entry(f, m_key, m_value)
            f.write(" ),\n")
        elif key not in ('type', 'import_names'):
            _write_entry(f, key, value)
    f.write(" ),\n")


def _write_replay_buffer(f, buffer_cfg) -> None:
    """Write the ``replay_buffer`` block, including its ``monitor`` and ``thruput_controller`` children."""
    f.write(" replay_buffer=dict(\n")
    for key, value in buffer_cfg.items():
        if key == 'monitor':
            f.write(" monitor=dict(\n")
            for m_key, m_value in value.items():
                if m_key == 'log_path':
                    _write_entry(f, m_key, m_value)
                else:
                    # Every other monitor field is written as a flat dict block.
                    _write_dict(f, m_key, m_value)
            f.write(" ),\n")
        elif key == 'thruput_controller':
            _write_dict(f, key, value, nested_levels=1)
        elif isinstance(value, dict):
            _write_dict(f, key, value, nested_levels=2)
        else:
            _write_entry(f, key, value)
    f.write(" ),\n")


def _write_main_policy(f, policy_cfg) -> None:
    """Write the ``policy`` block of ``main_config`` (``type`` is left to ``create_config``)."""
    f.write(' policy=dict(\n')
    for key, value in policy_cfg.items():
        if key == 'type':
            continue
        if key == 'learn':
            _write_section(f, key, value, {'learner': {'dataloader': {}, 'hook': {}}})
        elif key == 'collect':
            _write_section(f, key, value, {'collector': {}})
        elif key == 'eval':
            _write_section(f, key, value, {'evaluator': {}})
        elif key == 'model':
            _write_section(f, key, value, {})
        elif key == 'other':
            f.write(" other=dict(\n")
            for o_key, o_value in value.items():
                # Only the replay buffer is exported from ``other``.
                if o_key == 'replay_buffer':
                    _write_replay_buffer(f, o_value)
            f.write(" ),\n")
        else:
            _write_entry(f, key, value)
    f.write(" ),\n")


def _write_create_env(f, env_cfg) -> None:
    """Write the ``env`` and ``env_manager`` blocks of ``create_config``."""
    f.write(' env=dict(\n')
    for key, value in env_cfg.items():
        if key in ('type', 'import_names'):
            _write_entry(f, key, value)
    f.write(" ),\n")
    if 'manager' in env_cfg:
        f.write(' env_manager=dict(\n')
        for m_key, m_value in env_cfg['manager'].items():
            if m_key in ('cfg_type', 'type'):
                _write_entry(f, m_key, m_value)
        f.write(" ),\n")


def save_config_formatted(config_: dict, path: str = 'formatted_total_config.py') -> None:
    """
    Overview:
        Save the formatted configuration to a python file that can be read by serial_pipeline directly.
        The file defines ``main_config`` (``exp_name``, ``env`` and ``policy`` settings) and
        ``create_config`` (env type/import names, env manager type and policy type).
    Arguments:
        - config_ (:obj:`dict`): Config dict. It must support attribute access (``config_.exp_name``, \
            ``config_.policy.type``), e.g. an ``EasyDict``.
        - path (:obj:`str`): Path of the python file to write.
    """
    # Python source files are UTF-8 by default, so write the generated module as UTF-8 explicitly.
    with open(path, "w", encoding="utf-8") as f:
        f.write('from easydict import EasyDict\n\n')
        f.write('main_config = dict(\n')
        f.write(" exp_name='{}',\n".format(config_.exp_name))
        for k, v in config_.items():
            if k == 'env':
                _write_main_env(f, v)
            if k == 'policy':
                _write_main_policy(f, v)
        f.write(")\n")
        f.write('main_config = EasyDict(main_config)\n')
        f.write('main_config = main_config\n')
        f.write('create_config = dict(\n')
        for k, v in config_.items():
            if k == 'env':
                _write_create_env(f, v)
        policy_type = config_.policy.type
        if '_command' in policy_type:
            # Drop the trailing 8 characters, i.e. the length of '_command'.
            f.write(" policy=dict(type='{}'),\n".format(policy_type[:-8]))
        else:
            f.write(" policy=dict(type='{}'),\n".format(policy_type))
        f.write(")\n")
        f.write('create_config = EasyDict(create_config)\n')
        f.write('create_config = create_config\n')
# Ready-made configs for exercising the parallel pipeline with CartPole + DQN.
parallel_test_main_config = cartpole_dqn_config

# Which implementation (registry type + module to import) to create for each component.
parallel_test_create_config = EasyDict(
    {
        'env': {
            'type': 'cartpole',
            'import_names': ['dizoo.classic_control.cartpole.envs.cartpole_env'],
        },
        'env_manager': {'type': 'subprocess'},
        'policy': {'type': 'dqn_command'},
        'comm_learner': {
            'type': 'flask_fs',
            'import_names': ['ding.worker.learner.comm.flask_fs_learner'],
        },
        'comm_collector': {
            'type': 'flask_fs',
            'import_names': ['ding.worker.collector.comm.flask_fs_collector'],
        },
        'learner': {
            'type': 'base',
            'import_names': ['ding.worker.learner.base_learner'],
        },
        'collector': {
            'type': 'zergling',
            'import_names': ['ding.worker.collector.zergling_parallel_collector'],
        },
        'commander': {
            'type': 'naive',
            'import_names': ['ding.worker.coordinator.base_parallel_commander'],
        },
    }
)

# System-level settings: the fields of this config are the ones read through ``cfg.system`` in set_system_cfg.
parallel_test_system_config = EasyDict(
    {
        'coordinator': {},
        'path_data': '.',
        'path_policy': '.',
        'communication_mode': 'auto',
        'learner_gpu_num': 1,
    }
)