File size: 649 Bytes
9b19c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from dataclasses import dataclass

import torch

from tianshou.utils.net.common import ActorCritic


@dataclass
class ModuleOpt:
    """Bundle a torch module together with an optimizer.

    A plain two-field record: it stores the objects it is given and performs
    no validation that the optimizer actually manages the module's parameters.
    """

    # The wrapped network.
    module: torch.nn.Module
    # The optimizer paired with ``module``.
    optim: torch.optim.Optimizer


@dataclass
class ActorCriticOpt:
    """Bundle an :class:`ActorCritic` instance together with an optimizer.

    The ``actor`` and ``critic`` properties are read-only shortcuts that
    forward to the attributes of the same name on the wrapped module.
    """

    # The combined actor-critic module being wrapped.
    actor_critic_module: ActorCritic
    # The optimizer paired with ``actor_critic_module``.
    optim: torch.optim.Optimizer

    @property
    def actor(self) -> torch.nn.Module:
        """Return the ``actor`` attribute of the wrapped module."""
        wrapped = self.actor_critic_module
        return wrapped.actor

    @property
    def critic(self) -> torch.nn.Module:
        """Return the ``critic`` attribute of the wrapped module."""
        wrapped = self.actor_critic_module
        return wrapped.critic