File size: 5,425 Bytes
9b19c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar, cast

from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import TLogger
from tianshou.policy import BasePolicy, DQNPolicy
from tianshou.utils.string import ToStringMixin

# Type variable for policy types; bound to BasePolicy, i.e. only BasePolicy subclasses match.
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
# Module-level logger, named after this module.
log = logging.getLogger(__name__)


class TrainingContext:
    """Bundles the objects which are handed to the trainer callbacks of this module.

    :param policy: the policy; the DQN callbacks in this module call `set_eps` on it
    :param envs: the environments; the reward threshold stop callback reads
        `envs.env.spec.reward_threshold` from it
    :param logger: the logger; the epsilon decay callback writes the current epsilon to it
    """

    def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger):
        # Plain attribute storage; the three assignments are independent of each other.
        self.logger = logger
        self.envs = envs
        self.policy = policy


class EpochTrainCallback(ToStringMixin, ABC):
    """Hook which is invoked at the start of every epoch, i.e. before the data
    collection phase of that epoch.
    """

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Handles the beginning of an epoch.

        :param epoch: the epoch number
        :param env_step: the environment step count
        :param context: the training context
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]:
        """Creates a function of `(epoch, env_step)` which forwards to :meth:`callback`
        with the given context bound.

        :param context: the training context to pass on with every invocation
        :return: the function wrapping :meth:`callback`
        """
        return lambda epoch, env_step: self.callback(epoch, env_step, context)


class EpochTestCallback(ToStringMixin, ABC):
    """Hook which is invoked when the test phase of an epoch begins."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        """Handles the beginning of the test phase of an epoch.

        :param epoch: the epoch number
        :param env_step: the environment step count; may be None
        :param context: the training context
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]:
        """Creates a function of `(epoch, env_step)` which forwards to :meth:`callback`
        with the given context bound.

        :param context: the training context to pass on with every invocation
        :return: the function wrapping :meth:`callback`
        """
        return lambda epoch, env_step: self.callback(epoch, env_step, context)


class EpochStopCallback(ToStringMixin, ABC):
    """Hook which is invoked once the test phase of an epoch has finished; its result
    decides whether training is terminated early.
    """

    @abstractmethod
    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """Decides whether training is to be stopped.

        :param mean_rewards: the average undiscounted returns of the testing result
        :param context: the training context
        :return: True if the goal has been reached and training should stop, False otherwise
        """

    def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]:
        """Creates a function of `mean_rewards` which forwards to :meth:`should_stop`
        with the given context bound.

        :param context: the training context to pass on with every invocation
        :return: the function wrapping :meth:`should_stop`
        """
        return lambda mean_rewards: self.should_stop(mean_rewards, context)


@dataclass
class TrainerCallbacks:
    """Container for callbacks used during training.

    Every field is optional and defaults to None.
    """

    # called at the beginning of each epoch, prior to data collection (see EpochTrainCallback)
    epoch_train_callback: EpochTrainCallback | None = None
    # called at the beginning of the test phase of each epoch (see EpochTestCallback)
    epoch_test_callback: EpochTestCallback | None = None
    # called after the test phase of each epoch to decide on early stopping (see EpochStopCallback)
    epoch_stop_callback: EpochStopCallback | None = None


class EpochTrainCallbackDQNSetEps(EpochTrainCallback):
    """Applies a fixed epsilon value to a DQN-based policy when the training stage of
    an epoch begins.
    """

    def __init__(self, eps_test: float):
        """:param eps_test: the epsilon value which is set on the policy"""
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        # The cast only informs the type checker; no conversion happens at runtime.
        cast(DQNPolicy, context.policy).set_eps(self.eps_test)


class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback):
    """Sets the epsilon value for DQN-based policies at the beginning of the training
    stage in each epoch, using a linear decay in the first `decay_steps` steps.

    Epsilon is interpolated linearly from `eps_train` (at environment step 0) to
    `eps_train_final` (at environment step `decay_steps`) and remains at
    `eps_train_final` from then on.
    """

    def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = 1000000):
        """:param eps_train: the epsilon value at environment step 0
        :param eps_train_final: the epsilon value used once `decay_steps` environment steps
            have been reached
        :param decay_steps: the number of environment steps over which epsilon is decayed
            linearly; if it is 0 (or negative), `eps_train_final` is used right away
        """
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final
        self.decay_steps = decay_steps

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        """Computes the epsilon value for `env_step`, sets it on the policy and logs it.

        :param epoch: the epoch number (unused)
        :param env_step: the environment step count on which the decay is based
        :param context: the training context providing the policy and the logger
        """
        policy = cast(DQNPolicy, context.policy)
        # Strict comparison on purpose: at env_step == decay_steps the result is exactly
        # eps_train_final (no floating-point round-off from the interpolation), and
        # decay_steps == 0 can never reach the division below.
        if env_step < self.decay_steps:
            progress = env_step / self.decay_steps
            eps = self.eps_train - progress * (self.eps_train - self.eps_train_final)
        else:
            eps = self.eps_train_final
        policy.set_eps(eps)
        context.logger.write("train/env_step", env_step, {"train/eps": eps})


class EpochTestCallbackDQNSetEps(EpochTestCallback):
    """Applies a fixed epsilon value to a DQN-based policy when the test stage of
    an epoch begins.
    """

    def __init__(self, eps_test: float):
        """:param eps_test: the epsilon value which is set on the policy"""
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        # The cast only informs the type checker; no conversion happens at runtime.
        cast(DQNPolicy, context.policy).set_eps(self.eps_test)


class EpochStopCallbackRewardThreshold(EpochStopCallback):
    """Stops training once the mean rewards reach or exceed the given reward threshold or the
    threshold that is specified in the gymnasium environment (i.e. `env.spec.reward_threshold`).
    """

    def __init__(self, threshold: float | None = None):
        """:param threshold: the reward threshold at or beyond which to stop training.
        If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`.
        """
        self.threshold = threshold

    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        """Determines whether the reward threshold has been reached.

        :param mean_rewards: the average undiscounted returns of the testing result
        :param context: the training context; its environment is consulted for the threshold
            if none was given at construction
        :return: True if `mean_rewards` is greater than or equal to the threshold, False otherwise
        :raises ValueError: if no threshold was given at construction and the environment
            does not provide one either
        """
        threshold = self.threshold
        if threshold is None:
            # Fall back to the environment's own threshold. `spec` itself may be None,
            # hence the getattr with a default instead of direct attribute access.
            spec = context.envs.env.spec  # type: ignore
            threshold = getattr(spec, "reward_threshold", None)
            if threshold is None:
                # Explicit exception rather than `assert`, which is stripped under `python -O`.
                raise ValueError(
                    "No reward threshold was given and the environment does not define one "
                    "(`env.spec.reward_threshold` is not set)",
                )
        is_reached = mean_rewards >= threshold
        if is_reached:
            log.info("Reward threshold (%s) exceeded", threshold)
        return is_reached