File size: 5,998 Bytes
9b19c29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
from typing import Any, Literal, Self, TypeVar
import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import (
TLearningRateScheduler,
TrainingStats,
TrainingStatsWrapper,
TTrainingStats,
)
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
class ICMTrainingStats(TrainingStatsWrapper):
    """Training statistics produced by an ICM-wrapped policy.

    Wraps the statistics returned by the inner policy's ``learn`` and adds the
    three ICM loss values as plain float attributes.

    :param wrapped_stats: the inner policy's training statistics.
    :param icm_loss: the combined (weighted and lr-scaled) ICM loss.
    :param icm_forward_loss: the forward-model loss component.
    :param icm_inverse_loss: the inverse-model loss component.
    """

    def __init__(
        self,
        wrapped_stats: TrainingStats,
        *,
        icm_loss: float,
        icm_forward_loss: float,
        icm_inverse_loss: float,
    ) -> None:
        # The ICM-specific fields are set before delegating to the wrapper's
        # __init__, which stays the final call.
        # NOTE(review): the base class presumably relies on this ordering —
        # confirm against TrainingStatsWrapper before reordering.
        self.icm_loss, self.icm_forward_loss, self.icm_inverse_loss = (
            icm_loss,
            icm_forward_loss,
            icm_inverse_loss,
        )
        super().__init__(wrapped_stats)
class ICMPolicy(BasePolicy[ICMTrainingStats]):
    """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363.

    Wraps a base policy in three ways:

    - Action selection is delegated to the wrapped policy.
    - Before the wrapped policy processes a sampled batch, the ICM model's
      per-sample ``mse_loss`` (scaled by ``reward_scale``) is added to the
      reward as an intrinsic bonus.
    - In :meth:`learn`, the ICM model is optimized after the wrapped policy.

    :param policy: a base policy to add ICM to.
    :param model: the ICM model. It is called as ``model(obs, act, obs_next)``
        and its result is unpacked as ``(mse_loss, act_hat)``.
    :param optim: a torch.optim for optimizing the model.
    :param lr_scale: the scaling factor applied to the combined ICM loss before
        backpropagation.
    :param reward_scale: the scaling factor for the intrinsic reward;
        ``mse_loss * reward_scale`` is added to ``batch.rew``.
    :param forward_loss_weight: the weight for forward model loss; the inverse
        model loss is weighted by ``1 - forward_loss_weight``.
    :param action_space: Env's action space. The inverse loss treats actions as
        integer class indices (cross-entropy against ``act_hat``), so this is
        intended for discrete action spaces.
    :param observation_space: Env's observation space.
    :param action_scaling: if True, scale the action from [-1, 1] to the range
        of action_space. Only used if the action_space is continuous.
    :param action_bound_method: method to bound action to range [-1, 1].
        Only used if the action_space is continuous.
    :param lr_scheduler: if not None, will be called in `policy.update()`.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        *,
        policy: BasePolicy[TTrainingStats],
        model: IntrinsicCuriosityModule,
        optim: torch.optim.Optimizer,
        lr_scale: float,
        reward_scale: float,
        forward_loss_weight: float,
        action_space: gym.Space,
        observation_space: gym.Space | None = None,
        action_scaling: bool = False,
        action_bound_method: Literal["clip", "tanh"] | None = "clip",
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
            lr_scheduler=lr_scheduler,
        )
        self.policy = policy
        self.model = model
        self.optim = optim
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight

    def train(self, mode: bool = True) -> Self:
        """Set this module, the wrapped policy and the ICM model to ``mode``."""
        self.policy.train(mode)
        self.training = mode
        self.model.train(mode)
        return self

    def forward(
        self,
        batch: ObsBatchProtocol,
        state: dict | BatchProtocol | np.ndarray | None = None,
        **kwargs: Any,
    ) -> ActBatchProtocol:
        """Compute action over the given batch data by inner policy.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        return self.policy.forward(batch, state, **kwargs)

    _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")

    def exploration_noise(
        self,
        act: _TArrOrActBatch,
        batch: ObsBatchProtocol,
    ) -> _TArrOrActBatch:
        """Delegate exploration noise to the wrapped policy."""
        return self.policy.exploration_noise(act, batch)

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration.

        :raises NotImplementedError: if the wrapped policy has no ``set_eps``.
        """
        if hasattr(self.policy, "set_eps"):
            self.policy.set_eps(eps)
        else:
            raise NotImplementedError

    def process_fn(
        self,
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> RolloutBatchProtocol:
        """Pre-process the data from the provided replay buffer.

        Runs the ICM model on the batch and stashes the original reward,
        ``act_hat`` and ``mse_loss`` in ``batch.policy``. It then adds the
        scaled intrinsic bonus to ``batch.rew`` and hands the batch to the
        wrapped policy's ``process_fn``.

        Used in :meth:`update`. Check out :ref:`process_fn` for more information.
        """
        mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next)
        batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss)
        # Batch stores numeric ndarrays without copying, so ``orig_rew`` is the
        # very same array object as ``batch.rew``. An in-place ``+=`` would
        # write the bonus into ``orig_rew`` too, and post_process_fn would then
        # "restore" the augmented reward. Build a new array instead.
        batch.rew = batch.rew + to_numpy(mse_loss * self.reward_scale)
        return self.policy.process_fn(batch, buffer, indices)

    def post_process_fn(
        self,
        batch: BatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> None:
        """Post-process the data from the provided replay buffer.

        Delegates to the wrapped policy, then puts back the reward saved in
        :meth:`process_fn` (without the intrinsic bonus).

        Typical usage is to update the sampling weight in prioritized
        experience replay. Used in :meth:`update`.
        """
        self.policy.post_process_fn(batch, buffer, indices)
        batch.rew = batch.policy.orig_rew  # restore original reward

    def learn(
        self,
        batch: RolloutBatchProtocol,
        *args: Any,
        **kwargs: Any,
    ) -> ICMTrainingStats:
        """Update the wrapped policy, then take one optimizer step on the ICM model.

        The ICM loss is built from two parts:

        - inverse loss: cross-entropy between ``act_hat`` and the taken actions;
        - forward loss: the mean of ``mse_loss``.

        Both were computed in :meth:`process_fn`. They are mixed as
        ``(1 - w) * inverse + w * forward`` with ``w = forward_loss_weight``,
        and the result is scaled by ``lr_scale``.

        :return: the wrapped policy's stats extended with the ICM losses.
        """
        # Forward positional extras as well; previously they were dropped.
        training_stat = self.policy.learn(batch, *args, **kwargs)
        self.optim.zero_grad()
        act_hat = batch.policy.act_hat
        # Actions are used as class indices for the inverse model.
        act = to_torch(batch.act, dtype=torch.long, device=act_hat.device)
        # F.cross_entropy already averages over the batch (reduction="mean").
        inverse_loss = F.cross_entropy(act_hat, act)
        forward_loss = batch.policy.mse_loss.mean()
        loss = (
            (1 - self.forward_loss_weight) * inverse_loss
            + self.forward_loss_weight * forward_loss
        ) * self.lr_scale
        loss.backward()
        self.optim.step()
        return ICMTrainingStats(
            training_stat,
            icm_loss=loss.item(),
            icm_forward_loss=forward_loss.item(),
            icm_inverse_loss=inverse_loss.item(),
        )
|