|
from abc import ABC, abstractmethod |
|
from collections.abc import Sequence |
|
from typing import Generic, TypeVar |
|
|
|
from tianshou.highlevel.env import Environments |
|
from tianshou.highlevel.module.core import TDevice |
|
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory |
|
from tianshou.highlevel.optim import OptimizerFactory |
|
from tianshou.policy import BasePolicy, ICMPolicy |
|
from tianshou.utils.net.discrete import IntrinsicCuriosityModule |
|
from tianshou.utils.string import ToStringMixin |
|
|
|
# Type of the (wrapped) policy produced by a PolicyWrapperFactory; bounded by
# BasePolicy so factories can only declare policy-typed results.
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
|
|
|
|
|
class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC):
    """Abstract factory which wraps an existing policy, producing a policy of type ``TPolicyOut``."""

    @abstractmethod
    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> TPolicyOut:
        """Create the wrapped policy around the given base policy.

        :param policy: the policy to be wrapped
        :param envs: the environments the policy interacts with
        :param optim_factory: factory for any optimizers the wrapper requires
        :param device: the torch device to place modules on
        :return: the wrapping policy instance
        """
|
|
|
|
|
class PolicyWrapperFactoryIntrinsicCuriosity(
    PolicyWrapperFactory[ICMPolicy],
):
    """Factory which wraps a policy with an intrinsic curiosity module (ICM)."""

    def __init__(
        self,
        *,
        feature_net_factory: IntermediateModuleFactory,
        hidden_sizes: Sequence[int],
        lr: float,
        lr_scale: float,
        reward_scale: float,
        forward_loss_weight: float,
    ):
        """
        :param feature_net_factory: factory producing the ICM feature network
        :param hidden_sizes: hidden layer sizes of the curiosity module's MLPs
        :param lr: learning rate for the curiosity module's optimizer
        :param lr_scale: scaling factor applied to the ICM loss
        :param reward_scale: scaling factor applied to intrinsic rewards
        :param forward_loss_weight: weight of the forward-model loss term
        """
        # Store the configuration; it is applied when create_wrapped_policy is called.
        self.feature_net_factory = feature_net_factory
        self.hidden_sizes = hidden_sizes
        self.lr = lr
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight

    def create_wrapped_policy(
        self,
        policy: BasePolicy,
        envs: Environments,
        optim_factory: OptimizerFactory,
        device: TDevice,
    ) -> ICMPolicy:
        """Wrap the given policy with an :class:`ICMPolicy`.

        :param policy: the policy to be wrapped
        :param envs: the environments (used for action shape/space)
        :param optim_factory: factory for the curiosity module's optimizer
        :param device: the torch device to place the modules on
        :return: the ICM-wrapped policy, moved to ``device``
        :raises ValueError: if the environment's action shape is not an integer
        """
        feature_module = self.feature_net_factory.create_intermediate_module(envs, device)
        # ICM requires a discrete, flat action dimension.
        action_dim = envs.get_action_shape()
        if not isinstance(action_dim, int):
            raise ValueError(f"Environment action shape must be an integer, got {action_dim}")
        curiosity_net = IntrinsicCuriosityModule(
            feature_module.module,
            feature_module.output_dim,
            action_dim,
            hidden_sizes=self.hidden_sizes,
            device=device,
        )
        wrapped = ICMPolicy(
            policy=policy,
            model=curiosity_net,
            optim=optim_factory.create_optimizer(curiosity_net, lr=self.lr),
            action_space=envs.get_action_space(),
            lr_scale=self.lr_scale,
            reward_scale=self.reward_scale,
            forward_loss_weight=self.forward_loss_weight,
        )
        return wrapped.to(device)
|
|