PPO playing procgen-bigfish-easy from https://github.com/sgoodfriend/rl-algo-impls/tree/21ee1ab96a186676e5ed2f8c3185902f7c7bca7a
a678795
import gym | |
import torch | |
import torch.nn as nn | |
from typing import Sequence, Type | |
from shared.module.feature_extractor import FeatureExtractor | |
from shared.module.module import mlp | |
class CriticHead(nn.Module):
    """Value head: an MLP with a single output unit, returned without its last axis.

    The network is built by the project-level ``mlp`` helper from the layer
    sizes ``hidden_sizes + (1,)``.
    NOTE(review): presumably ``hidden_sizes[0]`` must equal the width of the
    features fed to ``forward`` -- confirm against ``mlp`` and the caller.
    """

    def __init__(
        self,
        hidden_sizes: Sequence[int] = (32,),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
    ) -> None:
        """Build the underlying network.

        Args:
            hidden_sizes: Layer sizes handed to ``mlp``; a trailing size of 1
                is appended for the scalar output.
            activation: Activation module class forwarded to ``mlp``.
            init_layers_orthogonal: Forwarded to ``mlp`` unchanged.
        """
        super().__init__()
        # Append the single output unit to the caller-provided sizes.
        layer_sizes = tuple(hidden_sizes) + (1,)
        self._fc = mlp(
            layer_sizes,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=1.0,
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Run ``obs`` through the network and squeeze the last dimension.

        Args:
            obs: Input tensor passed directly to the underlying network.

        Returns:
            The network output with its last dimension removed when that
            dimension has size 1 (``squeeze(-1)`` semantics).
        """
        v = self._fc(obs)
        return v.squeeze(-1)