sgoodfriend's picture
VPG playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/e8bc541d8b5e67bb4d3f2075282463fb61f5f2c6
10ca8cb
raw
history blame
1.11 kB
import numpy as np
import torch.nn as nn
from typing import Sequence, Type
def mlp(
layer_sizes: Sequence[int],
activation: Type[nn.Module],
output_activation: Type[nn.Module] = nn.Identity,
init_layers_orthogonal: bool = False,
final_layer_gain: float = np.sqrt(2),
) -> nn.Module:
layers = []
for i in range(len(layer_sizes) - 2):
layers.append(
layer_init(
nn.Linear(layer_sizes[i], layer_sizes[i + 1]), init_layers_orthogonal
)
)
layers.append(activation())
layers.append(
layer_init(
nn.Linear(layer_sizes[-2], layer_sizes[-1]),
init_layers_orthogonal,
std=final_layer_gain,
)
)
layers.append(output_activation())
return nn.Sequential(*layers)
def layer_init(
layer: nn.Module, init_layers_orthogonal: bool, std: float = np.sqrt(2)
) -> nn.Module:
if not init_layers_orthogonal:
return layer
nn.init.orthogonal_(layer.weight, std) # type: ignore
nn.init.constant_(layer.bias, 0.0) # type: ignore
return layer