from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

from torch import nn

if TYPE_CHECKING:
    from tianshou.policy import BasePolicy


@contextmanager
def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]:
    """Temporarily switch to `module.training=enabled`, affecting things like `BatchNormalization`."""
    original_mode = module.training
    try:
        module.train(enabled)
        yield
    finally:
        module.train(original_mode)
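
# Usage sketch (illustrative only; `net` below is a hypothetical module with a
# BatchNorm layer, not something defined in this file):
#
#     net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
#     net.eval()
#     with torch_train_mode(net):
#         ...  # net.training is True here; BatchNorm uses batch statistics
#     # net.training is restored to False on exit, even if an exception was raised
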
@contextmanager
def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]:
    """Temporarily switch to `policy.is_within_training_step=enabled`.

    Enabling this ensures that the policy is able to adapt its behavior,
    allowing it to differentiate between training and inference/evaluation,
    e.g., to sample actions instead of using the most probable action (where applicable).

    Note that for rollout, which also happens within a training step, one would usually want
    the wrapped torch module to be in evaluation mode, which can be achieved using
    `with torch_train_mode(policy, False)`. For subsequent gradient updates, the policy should be
    both within a training step and in torch train mode.
    """
    original_mode = policy.is_within_training_step
    try:
        policy.is_within_training_step = enabled
        yield
    finally:
        policy.is_within_training_step = original_mode
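

if __name__ == "__main__":
    # Minimal runnable sketch of how the two context managers compose, as
    # described in the docstring above. `DummyPolicy` is a hypothetical
    # stand-in for tianshou's BasePolicy, which is both an nn.Module and a
    # carrier of the `is_within_training_step` flag.

    class DummyPolicy(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.is_within_training_step = False

    policy = DummyPolicy()
    with policy_within_training_step(policy):  # type: ignore[arg-type]
        assert policy.is_within_training_step
        # rollout: collect experience with the torch module in eval mode
        with torch_train_mode(policy, enabled=False):
            assert not policy.training
        # gradient update: within the training step and in torch train mode
        with torch_train_mode(policy):
            assert policy.training
    assert not policy.is_within_training_step  # flag restored on exit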