File size: 5,631 Bytes
3dfe8fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from typing import Optional, Union, Tuple
import time
import pickle
from ditk import logging
from multiprocessing import Process, Event
import threading
from easydict import EasyDict

from ding.worker import create_comm_learner, create_comm_collector, Coordinator
from ding.config import read_config_with_system, compile_config_parallel
from ding.utils import set_pkg_seed


def parallel_pipeline(
        input_cfg: Union[str, Tuple[dict, dict, dict]],
        seed: int,
        enable_total_log: Optional[bool] = False,
        disable_flask_log: Optional[bool] = True,
) -> None:
    r"""
    Overview:
        Parallel pipeline entry. Compiles the parallel config, launches one child process for every
        learner/collector entry found in ``config.system``, then runs the coordinator in the calling process.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict, dict]]`): Either a config file path, or a
          ``(main_cfg, create_cfg, system_cfg)`` tuple (a list of the same three items is accepted too).
        - seed (:obj:`int`): Random seed.
        - enable_total_log (:obj:`Optional[bool]`): Whether to enable the total DI-engine system log.
          When falsy, the ``coordinator_logger`` logger is disabled.
        - disable_flask_log (:obj:`Optional[bool]`): Whether to disable the flask (``werkzeug``) log.
    Raises:
        - TypeError: If ``input_cfg`` is neither a str nor a tuple/list.
    """
    # Disable some part of DI-engine log.
    if not enable_total_log:
        logging.getLogger('coordinator_logger').disabled = True
    # Disable flask logger.
    if disable_flask_log:
        logging.getLogger('werkzeug').disabled = True
    # Resolve the three config parts from whichever input form was given.
    if isinstance(input_cfg, str):
        main_cfg, create_cfg, system_cfg = read_config_with_system(input_cfg)
    elif isinstance(input_cfg, (tuple, list)):
        main_cfg, create_cfg, system_cfg = input_cfg
    else:
        # Report the offending type (not the value), which is what the message promises.
        raise TypeError("invalid config type: {}".format(type(input_cfg)))
    config = compile_config_parallel(main_cfg, create_cfg=create_cfg, system_cfg=system_cfg, seed=seed)
    # Parallel job launch: dispatch on the key name of every entry in the system config.
    learner_handle = []
    collector_handle = []
    for k, v in config.system.items():
        if 'learner' in k:
            learner_handle.append(launch_learner(config.seed, v))
        elif 'collector' in k:
            collector_handle.append(launch_collector(config.seed, v))
    launch_coordinator(config.seed, config, learner_handle=learner_handle, collector_handle=collector_handle)


# The following functions launch the different components (learner, learner aggregator, collector, coordinator).
# Argument ``config`` is the dict-type config. If it is None, then ``filename`` and ``name`` must be passed,
# so that the corresponding config can be read from file.
def run_learner(config, seed, start_learner_event, close_learner_event):
    r"""
    Overview:
        Create and start a comm learner, signal that it has started, then keep it alive until a close
        request arrives.
    Arguments:
        - config: Config passed to ``create_comm_learner``.
        - seed: Random seed passed to ``set_pkg_seed``.
        - start_learner_event: Event that is set right after the learner has been started.
        - close_learner_event: Event that is waited on; once it is set the learner is closed.
    """
    set_pkg_seed(seed)
    # Silence the flask (werkzeug) logger for this caller as well.
    logging.getLogger('werkzeug').disabled = True

    comm_learner = create_comm_learner(config)
    comm_learner.start()
    # Announce readiness only after ``start`` has returned.
    start_learner_event.set()

    # Block here until somebody asks for shutdown.
    close_learner_event.wait()
    comm_learner.close()


def launch_learner(
        seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
) -> tuple:
    r"""
    Overview:
        Start ``run_learner`` in a child process and return the handles used to synchronize with it.
    Arguments:
        - seed (:obj:`int`): Random seed forwarded to ``run_learner``.
        - config (:obj:`Optional[dict]`): Learner config. If None, it is read from ``filename``.
        - filename (:obj:`Optional[str]`): Path of a pickle file; only used when ``config`` is None.
        - name (:obj:`Optional[str]`): Key under which the learner config is stored in the unpickled object.
    Returns:
        - handle (:obj:`tuple`): ``(process, start_event, close_event)``. ``start_event`` is set by
          ``run_learner`` after the learner has started; setting ``close_event`` asks it to close.
    Raises:
        - TypeError: If both ``config`` and ``filename`` are None.
    """
    if config is None:
        if filename is None:
            # Same exception type ``open(None)`` used to raise, but with an actionable message.
            raise TypeError("launch_learner needs either ``config`` or ``filename`` (together with ``name``)")
        # NOTE(security): pickle.load can execute arbitrary code; ``filename`` must be a trusted file.
        with open(filename, 'rb') as f:
            config = pickle.load(f)[name]
    start_learner_event = Event()
    close_learner_event = Event()

    learner_process = Process(
        target=run_learner, args=(config, seed, start_learner_event, close_learner_event), name='learner_entry_process'
    )
    learner_process.start()
    return learner_process, start_learner_event, close_learner_event


def run_collector(config, seed, start_collector_event, close_collector_event):
    r"""
    Overview:
        Create and start a comm collector, signal that it has started, then keep it alive until a close
        request arrives.
    Arguments:
        - config: Config passed to ``create_comm_collector``.
        - seed: Random seed passed to ``set_pkg_seed``.
        - start_collector_event: Event that is set right after the collector has been started.
        - close_collector_event: Event that is waited on; once it is set the collector is closed.
    """
    set_pkg_seed(seed)
    # Silence the flask (werkzeug) logger for this caller as well.
    logging.getLogger('werkzeug').disabled = True

    comm_collector = create_comm_collector(config)
    comm_collector.start()
    # Announce readiness only after ``start`` has returned.
    start_collector_event.set()

    # Block here until somebody asks for shutdown.
    close_collector_event.wait()
    comm_collector.close()


def launch_collector(
        seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
) -> tuple:
    r"""
    Overview:
        Start ``run_collector`` in a child process and return the handles used to synchronize with it.
    Arguments:
        - seed (:obj:`int`): Random seed forwarded to ``run_collector``.
        - config (:obj:`Optional[dict]`): Collector config. If None, it is read from ``filename``.
        - filename (:obj:`Optional[str]`): Path of a pickle file; only used when ``config`` is None.
        - name (:obj:`Optional[str]`): Key under which the collector config is stored in the unpickled object.
    Returns:
        - handle (:obj:`tuple`): ``(process, start_event, close_event)``. ``start_event`` is set by
          ``run_collector`` after the collector has started; setting ``close_event`` asks it to close.
    Raises:
        - TypeError: If both ``config`` and ``filename`` are None.
    """
    if config is None:
        if filename is None:
            # Same exception type ``open(None)`` used to raise, but with an actionable message.
            raise TypeError("launch_collector needs either ``config`` or ``filename`` (together with ``name``)")
        # NOTE(security): pickle.load can execute arbitrary code; ``filename`` must be a trusted file.
        with open(filename, 'rb') as f:
            config = pickle.load(f)[name]
    start_collector_event = Event()
    close_collector_event = Event()

    collector_process = Process(
        target=run_collector,
        args=(config, seed, start_collector_event, close_collector_event),
        name='collector_entry_process'
    )
    collector_process.start()
    return collector_process, start_collector_event, close_collector_event


def launch_coordinator(
        seed: int,
        config: Optional[EasyDict] = None,
        filename: Optional[str] = None,
        learner_handle: Optional[list] = None,
        collector_handle: Optional[list] = None
) -> None:
    r"""
    Overview:
        Run the coordinator in the calling thread's process: wait for every launched learner/collector to
        report that it has started, start the coordinator, and block until the coordinator raises its
        ``system_shutdown_flag``; then close the coordinator and ask all learners/collectors to close.
    Arguments:
        - seed (:obj:`int`): Random seed passed to ``set_pkg_seed``.
        - config (:obj:`Optional[EasyDict]`): Coordinator config. If None, it is read from ``filename``.
        - filename (:obj:`Optional[str]`): Path of a pickle file; only used when ``config`` is None.
        - learner_handle (:obj:`Optional[list]`): ``(process, start_event, close_event)`` triples as
          returned by ``launch_learner``. None is treated as an empty list.
        - collector_handle (:obj:`Optional[list]`): ``(process, start_event, close_event)`` triples as
          returned by ``launch_collector``. None is treated as an empty list.
    Raises:
        - TypeError: If both ``config`` and ``filename`` are None.
    """
    set_pkg_seed(seed)
    if config is None:
        if filename is None:
            # Same exception type ``open(None)`` used to raise, but with an actionable message.
            raise TypeError("launch_coordinator needs either ``config`` or ``filename``")
        # NOTE(security): pickle.load can execute arbitrary code; ``filename`` must be a trusted file.
        with open(filename, 'rb') as f:
            config = pickle.load(f)
    # The handle arguments default to None; normalize so the loops below never iterate over None.
    learner_handle = [] if learner_handle is None else learner_handle
    collector_handle = [] if collector_handle is None else collector_handle

    coordinator = Coordinator(config)
    # Do not start the coordinator before every learner and collector has signalled that it started.
    for _, start_event, _ in learner_handle:
        start_event.wait()
    for _, start_event, _ in collector_handle:
        start_event.wait()
    coordinator.start()
    system_shutdown_event = threading.Event()

    # Monitor thread: the coordinator keeps running until its ``system_shutdown_flag`` becomes truthy.
    def shutdown_monitor():
        while True:
            # Poll the flag every 3 seconds.
            time.sleep(3)
            if coordinator.system_shutdown_flag:
                coordinator.close()
                for _, _, close_event in learner_handle:
                    close_event.set()
                for _, _, close_event in collector_handle:
                    close_event.set()
                # Release the caller blocked on ``system_shutdown_event.wait()`` below.
                system_shutdown_event.set()
                break

    shutdown_monitor_thread = threading.Thread(target=shutdown_monitor, args=(), daemon=True, name='shutdown_monitor')
    shutdown_monitor_thread.start()
    system_shutdown_event.wait()
    print(
        "[DI-engine parallel pipeline]Your RL agent is converged, you can refer to 'log' and 'tensorboard' for details"
    )