from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

from torch import nn

if TYPE_CHECKING:
    # Imported for annotations only; the name is never evaluated at runtime.
    from tianshou.policy import BasePolicy


@contextmanager
def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]:
    """Temporarily switch to `module.training=enabled`.

    This affects things like `BatchNormalization`. The mode that was active
    on entry is restored on exit, even if the managed body raises.

    :param module: the torch module whose mode is switched via `module.train`.
    :param enabled: the value passed to `module.train` for the duration of
        the `with` block.
    """
    # Remember the mode before touching it so that nested usage unwinds correctly.
    original_mode = module.training
    try:
        module.train(enabled)
        yield
    finally:
        module.train(original_mode)


@contextmanager
def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]:
    """Temporarily switch to `policy.is_within_training_step=enabled`.

    Enabling this ensures that the policy is able to adapt its behavior,
    allowing it to differentiate between training and inference/evaluation,
    e.g., to sample actions instead of using the most probable action (where applicable).

    Note that for rollout, which also happens within a training step, one would usually want
    the wrapped torch module to be in evaluation mode, which can be achieved using
    `with torch_train_mode(policy, False)`. For subsequent gradient updates, the policy should
    be both within training step and in torch train mode.

    The value that `policy.is_within_training_step` had on entry is restored
    on exit, even if the managed body raises.

    :param policy: the policy whose `is_within_training_step` attribute is set.
    :param enabled: the value assigned to the attribute for the duration of
        the `with` block.
    """
    # Remember the flag before overwriting it so that nested usage unwinds correctly.
    original_mode = policy.is_within_training_step
    try:
        policy.is_within_training_step = enabled
        yield
    finally:
        policy.is_within_training_step = original_mode