File size: 1,250 Bytes
9b19c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from collections.abc import Sequence

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.highlevel.module.intermediate import IntermediateModuleFactory
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
from tianshou.utils.string import ToStringMixin


class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
    """Builds an :class:`ImplicitQuantileNetwork` on top of a separately-created preprocessing network.

    The preprocessing network comes from ``preprocess_net_factory``. Its module and its
    output dimension become the input of the IQN head. The action shape is read from
    the environments.
    """

    def __init__(
        self,
        preprocess_net_factory: IntermediateModuleFactory,
        hidden_sizes: Sequence[int] = (),
        num_cosines: int = 64,
    ):
        """
        :param preprocess_net_factory: factory whose ``create_intermediate_module`` yields the
            preprocessing network (its ``module`` and ``output_dim`` are used)
        :param hidden_sizes: hidden layer sizes, forwarded unchanged to ``ImplicitQuantileNetwork``
        :param num_cosines: forwarded unchanged to ``ImplicitQuantileNetwork`` as ``num_cosines``
        """
        # Stored as public attributes under the constructor's parameter names
        # (presumably relied upon by ToStringMixin for the string representation — verify).
        self.preprocess_net_factory = preprocess_net_factory
        self.hidden_sizes = hidden_sizes
        self.num_cosines = num_cosines

    def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
        """Create the IQN module for the given environments on ``device``.

        :param envs: environments, used by the preprocessing factory and for the action shape
        :param device: device forwarded to both the preprocessing factory and the network
        :return: the constructed network, moved to ``device``
        """
        intermediate = self.preprocess_net_factory.create_intermediate_module(envs, device)
        network = ImplicitQuantileNetwork(
            preprocess_net=intermediate.module,
            action_shape=envs.get_action_shape(),
            hidden_sizes=self.hidden_sizes,
            num_cosines=self.num_cosines,
            preprocess_net_output_dim=intermediate.output_dim,
            device=device,
        )
        # The device is passed to the constructor and the module is also moved explicitly.
        return network.to(device)