File size: 6,161 Bytes
9b19c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""Provides a basic interface for launching experiments. The API is experimental and subject to change!"""

import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from copy import copy
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Literal

from joblib import Parallel, delayed

from tianshou.data import InfoStats
from tianshou.highlevel.experiment import Experiment

# Module-level logger named after this module, so its level/handlers can be configured independently.
log = logging.getLogger(name=__name__)


@dataclass
class JoblibConfig:
    n_jobs: int = -1
    """The maximum number of concurrently running jobs. If -1, all CPUs are used."""
    backend: Literal["loky", "multiprocessing", "threading"] | None = "loky"
    """Allows to hard-code backend, otherwise inferred based on prefer and require."""
    verbose: int = 10
    """If greater than zero, prints progress messages."""


class ExpLauncher(ABC):
    """Base class for launchers that run a sequence of experiments and collect their results.

    Subclasses implement `_launch`, which `launch` delegates to whenever more than one experiment is passed.
    """

    def __init__(
        self,
        experiment_runner: Callable[
            [Experiment],
            InfoStats | None,
        ] = lambda exp: exp.run().trainer_result,
    ):
        """:param experiment_runner: can be used to override the default way in which an experiment is executed.
        Can be useful e.g., if one wants to use the high-level interfaces to setup an experiment (or an experiment
        collection) and tinker with it prior to execution. This need often arises when prototyping with mechanisms
        that are not yet supported by the high-level interfaces.
        Passing this allows arbitrary things to happen during experiment execution, so use it with caution!
        By default, the experiment is run and its `trainer_result` is returned.
        """
        self.experiment_runner = experiment_runner

    @abstractmethod
    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
        """Should call `self.experiment_runner` for each experiment in experiments and aggregate the results."""

    def _safe_execute(self, exp: Experiment) -> InfoStats | None | Literal["failed"]:
        """Run a single experiment, turning an ordinary exception into the sentinel string ``"failed"``.

        Only `Exception` subclasses are caught and logged. `KeyboardInterrupt`, `SystemExit` and other
        `BaseException`s propagate, so a batch of experiments can still be interrupted by the user instead of
        each interrupt being recorded as a failed experiment.

        :param exp: the experiment to run.
        :return: the value returned by `self.experiment_runner`, or ``"failed"`` if it raised an exception.
        """
        try:
            return self.experiment_runner(exp)
        except Exception as e:
            log.error(f"Failed to run experiment {exp}.", exc_info=e)
            return "failed"

    @staticmethod
    def _return_from_successful_and_failed_exps(
        successful_exp_stats: list[InfoStats | None],
        failed_exps: list[Experiment],
    ) -> list[InfoStats | None]:
        """Report failed experiments and return the results of the successful ones.

        :param successful_exp_stats: results of the experiments that ran without raising (entries may be None).
        :param failed_exps: the experiments whose execution raised.
        :return: `successful_exp_stats`, unchanged.
        :raises RuntimeError: if `successful_exp_stats` is empty.
        """
        if not successful_exp_stats:
            raise RuntimeError("All experiments failed, see error logs for more details.")
        if failed_exps:
            log.error(
                f"Failed to run the following "
                f"{len(failed_exps)}/{len(successful_exp_stats) + len(failed_exps)} experiments: {failed_exps}. "
                f"See the logs for more details. "
                f"Returning the results of {len(successful_exp_stats)} successful experiments.",
            )
        return successful_exp_stats

    def launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
        """Will return the results of successfully executed experiments.

        If a single experiment is passed, will not use parallelism and run it in the main process. In that case
        the run is not wrapped by `_safe_execute`, so any exception it raises propagates to the caller unchanged.
        Otherwise, execution is delegated to `_launch`: failed experiments will be logged, and a RuntimeError is
        only raised if all experiments have failed.
        """
        if len(experiments) == 1:
            log.info(
                "A single experiment is being run, will not use parallelism and run it in the main process.",
            )
            return [self.experiment_runner(experiments[0])]
        return self._launch(experiments)


class SequentialExpLauncher(ExpLauncher):
    """Convenience wrapper around a simple for loop to run experiments sequentially in the current process."""

    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
        """Run each experiment exactly once, in order, and collect the results of the successful ones.

        Failures are isolated via `_safe_execute`, so one failing experiment does not stop the remaining ones.

        :param experiments: the experiments to run.
        :return: the results of the successfully executed experiments, in input order.
        :raises RuntimeError: if no experiment succeeded.
        """
        successful_exp_stats: list[InfoStats | None] = []
        failed_exps: list[Experiment] = []
        # Single pass over the experiments: each one must be executed exactly once.
        for exp in experiments:
            exp_stats = self._safe_execute(exp)
            if exp_stats == "failed":
                failed_exps.append(exp)
            else:
                successful_exp_stats.append(exp_stats)
        # noinspection PyTypeChecker
        return self._return_from_successful_and_failed_exps(successful_exp_stats, failed_exps)


class JoblibExpLauncher(ExpLauncher):
    """Runs experiments in parallel through `joblib.Parallel`, always using the loky backend.

    Every experiment is executed via `_safe_execute`, so individual failures are logged and skipped rather than
    aborting the whole batch.
    """

    def __init__(
        self,
        joblib_cfg: JoblibConfig | None = None,
        experiment_runner: Callable[
            [Experiment],
            InfoStats | None,
        ] = lambda exp: exp.run().trainer_result,
    ) -> None:
        """:param joblib_cfg: keyword arguments for `joblib.Parallel`. A shallow copy is stored, so the caller's
            object is never mutated; if None, `JoblibConfig()` is used. A backend other than "loky" is replaced
            by "loky" and a warning is logged.
        :param experiment_runner: overrides how a single experiment is executed; see `ExpLauncher.__init__`.
        """
        super().__init__(experiment_runner=experiment_runner)
        cfg = JoblibConfig() if joblib_cfg is None else copy(joblib_cfg)
        # The backend is pinned to loky since the threading backend produces different results.
        if cfg.backend != "loky":
            log.warning(
                f"Ignoring the user provided joblib backend {cfg.backend} and using loky instead. "
                f"The current implementation requires loky to work and will be relaxed soon",
            )
            cfg.backend = "loky"
        self.joblib_cfg = cfg

    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
        """Dispatch all experiments to joblib workers and split the outcomes into successes and failures.

        :param experiments: the experiments to run.
        :return: the results of the successfully executed experiments, in input order.
        :raises RuntimeError: if every experiment failed.
        """
        run_in_parallel = Parallel(**asdict(self.joblib_cfg))
        outcomes = run_in_parallel(delayed(self._safe_execute)(exp) for exp in experiments)
        # strict=True guards against joblib returning a different number of results than submitted jobs.
        paired = list(zip(experiments, outcomes, strict=True))
        failed_exps = [exp for exp, outcome in paired if outcome == "failed"]
        successful_exp_stats = [outcome for _, outcome in paired if outcome != "failed"]
        return self._return_from_successful_and_failed_exps(successful_exp_stats, failed_exps)


class RegisteredExpLauncher(Enum):
    """Names of the built-in launchers, usable e.g. to select a launcher from configuration."""

    joblib = "joblib"
    sequential = "sequential"

    def create_launcher(self) -> ExpLauncher:
        """Instantiate the launcher corresponding to this member, using its default constructor arguments.

        :raises NotImplementedError: if the member has no associated launcher.
        """
        if self is RegisteredExpLauncher.joblib:
            return JoblibExpLauncher()
        if self is RegisteredExpLauncher.sequential:
            return SequentialExpLauncher()
        raise NotImplementedError(
            f"Launcher {self} is not yet implemented.",
        )