Spaces:
Sleeping
Sleeping
File size: 5,631 Bytes
3dfe8fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
from typing import Optional, Union, Tuple
import time
import pickle
from ditk import logging
from multiprocessing import Process, Event
import threading
from easydict import EasyDict
from ding.worker import create_comm_learner, create_comm_collector, Coordinator
from ding.config import read_config_with_system, compile_config_parallel
from ding.utils import set_pkg_seed
def parallel_pipeline(
        input_cfg: Union[str, Tuple[dict, dict, dict]],
        seed: int,
        enable_total_log: Optional[bool] = False,
        disable_flask_log: Optional[bool] = True,
) -> None:
    r"""
    Overview:
        Parallel pipeline entry. Compiles the parallel config, launches one process per learner / collector
        entry found in ``config.system``, then runs the coordinator in the current process (which blocks until
        the coordinator reports system shutdown).
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict, dict]]`): Either a config file path, or an already \
            loaded ``(main_cfg, create_cfg, system_cfg)`` triple (a list of three items is also accepted).
        - seed (:obj:`int`): Random seed.
        - enable_total_log (:obj:`Optional[bool]`): Whether to enable the total DI-engine system log. \
            When falsy, the ``coordinator_logger`` logger is disabled.
        - disable_flask_log (:obj:`Optional[bool]`): Whether to disable the flask (``werkzeug``) log.
    Raises:
        - TypeError: If ``input_cfg`` is neither a ``str`` nor a ``tuple`` / ``list``.
    """
    # Disable some part of DI-engine log.
    if not enable_total_log:
        logging.getLogger('coordinator_logger').disabled = True
    # Disable flask logger.
    if disable_flask_log:
        logging.getLogger('werkzeug').disabled = True
    # Obtain the (main, create, system) config triple.
    if isinstance(input_cfg, str):
        main_cfg, create_cfg, system_cfg = read_config_with_system(input_cfg)
    elif isinstance(input_cfg, (tuple, list)):
        main_cfg, create_cfg, system_cfg = input_cfg
    else:
        # Report the offending type, not the (possibly huge) config value itself.
        raise TypeError("invalid config type: {}".format(type(input_cfg)))
    config = compile_config_parallel(main_cfg, create_cfg=create_cfg, system_cfg=system_cfg, seed=seed)
    # Parallel job launch: the key name in ``config.system`` decides whether an entry is a learner or a collector.
    learner_handle = []
    collector_handle = []
    for k, v in config.system.items():
        if 'learner' in k:
            learner_handle.append(launch_learner(config.seed, v))
        elif 'collector' in k:
            collector_handle.append(launch_collector(config.seed, v))
    launch_coordinator(config.seed, config, learner_handle=learner_handle, collector_handle=collector_handle)
# The following functions are used to launch the different components (learner, learner aggregator, collector,
# coordinator). Argument ``config`` is a dict-type config. If it is None, then ``filename`` and ``name`` must be
# passed so that the corresponding config can be read from file (``launch_coordinator`` only needs ``filename``).
def run_learner(config, seed, start_learner_event, close_learner_event):
    """
    Overview:
        Body of a learner process. Seeds the packages, silences the ``werkzeug`` logger, creates and starts a \
        comm learner, signals ``start_learner_event``, then blocks until ``close_learner_event`` is set and \
        closes the learner.
    Arguments:
        - config (:obj:`dict`): Learner config passed to ``create_comm_learner``.
        - seed (:obj:`int`): Random seed.
        - start_learner_event: Event set once the learner has been started.
        - close_learner_event: Event this process waits on before closing the learner.
    """
    set_pkg_seed(seed)
    logging.getLogger('werkzeug').disabled = True

    comm_learner = create_comm_learner(config)
    comm_learner.start()
    # Report readiness to whoever is waiting on the start event.
    start_learner_event.set()

    # Park until the parent asks this learner to shut down.
    close_learner_event.wait()
    comm_learner.close()
def launch_learner(
        seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
) -> tuple:
    """
    Overview:
        Start ``run_learner`` in a new process and return its handle.
    Arguments:
        - seed (:obj:`int`): Random seed, forwarded to the child process.
        - config (:obj:`Optional[dict]`): Learner config. If ``None``, it is read from ``filename``.
        - filename (:obj:`Optional[str]`): Path to a pickled mapping, used only when ``config`` is ``None``.
        - name (:obj:`Optional[str]`): Key of this learner's config inside the pickled mapping.
    Returns:
        - handle (:obj:`tuple`): ``(process, start_event, close_event)``. The child sets ``start_event`` once \
            the learner has started; setting ``close_event`` makes the child close the learner.
    """
    if config is None:
        # NOTE: pickle.load can execute arbitrary code; only load config files from a trusted source.
        with open(filename, 'rb') as f:
            config = pickle.load(f)[name]
    start_learner_event = Event()
    close_learner_event = Event()
    learner_process = Process(
        target=run_learner, args=(config, seed, start_learner_event, close_learner_event), name='learner_entry_process'
    )
    learner_process.start()
    return learner_process, start_learner_event, close_learner_event
def run_collector(config, seed, start_collector_event, close_collector_event):
    """
    Overview:
        Body of a collector process. Seeds the packages, silences the ``werkzeug`` logger, creates and starts \
        a comm collector, signals ``start_collector_event``, then blocks until ``close_collector_event`` is set \
        and closes the collector.
    Arguments:
        - config (:obj:`dict`): Collector config passed to ``create_comm_collector``.
        - seed (:obj:`int`): Random seed.
        - start_collector_event: Event set once the collector has been started.
        - close_collector_event: Event this process waits on before closing the collector.
    """
    set_pkg_seed(seed)
    logging.getLogger('werkzeug').disabled = True

    comm_collector = create_comm_collector(config)
    comm_collector.start()
    # Report readiness to whoever is waiting on the start event.
    start_collector_event.set()

    # Park until the parent asks this collector to shut down.
    close_collector_event.wait()
    comm_collector.close()
def launch_collector(
        seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
) -> tuple:
    """
    Overview:
        Start ``run_collector`` in a new process and return its handle.
    Arguments:
        - seed (:obj:`int`): Random seed, forwarded to the child process.
        - config (:obj:`Optional[dict]`): Collector config. If ``None``, it is read from ``filename``.
        - filename (:obj:`Optional[str]`): Path to a pickled mapping, used only when ``config`` is ``None``.
        - name (:obj:`Optional[str]`): Key of this collector's config inside the pickled mapping.
    Returns:
        - handle (:obj:`tuple`): ``(process, start_event, close_event)``. The child sets ``start_event`` once \
            the collector has started; setting ``close_event`` makes the child close the collector.
    """
    if config is None:
        # NOTE: pickle.load can execute arbitrary code; only load config files from a trusted source.
        with open(filename, 'rb') as f:
            config = pickle.load(f)[name]
    start_collector_event = Event()
    close_collector_event = Event()
    collector_process = Process(
        target=run_collector,
        args=(config, seed, start_collector_event, close_collector_event),
        name='collector_entry_process'
    )
    collector_process.start()
    return collector_process, start_collector_event, close_collector_event
def launch_coordinator(
        seed: int,
        config: Optional[EasyDict] = None,
        filename: Optional[str] = None,
        learner_handle: Optional[list] = None,
        collector_handle: Optional[list] = None
) -> None:
    """
    Overview:
        Run the coordinator in the current process. Waits until every learner and collector process has set \
        its start event, starts the coordinator, then blocks until the coordinator's ``system_shutdown_flag`` \
        becomes truthy. On shutdown it closes the coordinator and sets every worker's close event.
    Arguments:
        - seed (:obj:`int`): Random seed.
        - config (:obj:`Optional[EasyDict]`): Compiled parallel config. If ``None``, it is unpickled from \
            ``filename``.
        - filename (:obj:`Optional[str]`): Path to a pickled config, used only when ``config`` is ``None``.
        - learner_handle (:obj:`Optional[list]`): ``(process, start_event, close_event)`` triples as returned \
            by ``launch_learner``. ``None`` is treated as an empty list.
        - collector_handle (:obj:`Optional[list]`): Same as ``learner_handle``, as returned by \
            ``launch_collector``. ``None`` is treated as an empty list.
    """
    set_pkg_seed(seed)
    if config is None:
        # NOTE: pickle.load can execute arbitrary code; only load config files from a trusted source.
        with open(filename, 'rb') as f:
            config = pickle.load(f)
    # Both handle lists default to ``None``; iterating over ``None`` below would raise TypeError.
    learner_handle = [] if learner_handle is None else learner_handle
    collector_handle = [] if collector_handle is None else collector_handle
    coordinator = Coordinator(config)
    # Do not start coordinating before every worker process has started its component.
    for _, start_event, _ in learner_handle:
        start_event.wait()
    for _, start_event, _ in collector_handle:
        start_event.wait()
    coordinator.start()
    system_shutdown_event = threading.Event()

    # Monitor thread: polls every 3 seconds; the coordinator keeps running until its ``system_shutdown_flag``
    # becomes truthy, at which point the whole system is torn down.
    def shutdown_monitor():
        while True:
            time.sleep(3)
            if coordinator.system_shutdown_flag:
                try:
                    coordinator.close()
                finally:
                    # Always release the workers and the waiting main thread, even if ``close()`` raises;
                    # otherwise the worker processes and ``system_shutdown_event.wait()`` would block forever.
                    for _, _, close_event in learner_handle:
                        close_event.set()
                    for _, _, close_event in collector_handle:
                        close_event.set()
                    system_shutdown_event.set()
                break

    shutdown_monitor_thread = threading.Thread(target=shutdown_monitor, args=(), daemon=True, name='shutdown_monitor')
    shutdown_monitor_thread.start()
    system_shutdown_event.wait()
    print(
        "[DI-engine parallel pipeline]Your RL agent is converged, you can refer to 'log' and 'tensorboard' for details"
    )
|