vpg-PongNoFrameskip-v4 / shared /module /feature_extractor.py
sgoodfriend's picture
VPG playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/e8bc541d8b5e67bb4d3f2075282463fb61f5f2c6
10ca8cb
raw
history blame
6.85 kB
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from gym.spaces import Box, Discrete
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from typing import Dict, Optional, Sequence, Type
from shared.module.module import layer_init
class CnnFeatureExtractor(nn.Module, ABC):
@abstractmethod
def __init__(
self,
in_channels: int,
activation: Type[nn.Module] = nn.ReLU,
init_layers_orthogonal: Optional[bool] = None,
**kwargs,
) -> None:
super().__init__()
class NatureCnn(CnnFeatureExtractor):
"""
CNN from DQN Nature paper: Mnih, Volodymyr, et al.
"Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529-533.
"""
def __init__(
self,
in_channels: int,
activation: Type[nn.Module] = nn.ReLU,
init_layers_orthogonal: Optional[bool] = None,
**kwargs,
) -> None:
if init_layers_orthogonal is None:
init_layers_orthogonal = True
super().__init__(in_channels, activation, init_layers_orthogonal)
self.cnn = nn.Sequential(
layer_init(
nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
init_layers_orthogonal,
),
activation(),
layer_init(
nn.Conv2d(32, 64, kernel_size=4, stride=2),
init_layers_orthogonal,
),
activation(),
layer_init(
nn.Conv2d(64, 64, kernel_size=3, stride=1),
init_layers_orthogonal,
),
activation(),
nn.Flatten(),
)
def forward(self, obs: torch.Tensor) -> torch.Tensor:
return self.cnn(obs)
class ResidualBlock(nn.Module):
def __init__(
self,
channels: int,
activation: Type[nn.Module] = nn.ReLU,
init_layers_orthogonal: bool = False,
) -> None:
super().__init__()
self.residual = nn.Sequential(
activation(),
layer_init(
nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
),
activation(),
layer_init(
nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.residual(x)
class ConvSequence(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
activation: Type[nn.Module] = nn.ReLU,
init_layers_orthogonal: bool = False,
) -> None:
super().__init__()
self.seq = nn.Sequential(
layer_init(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
init_layers_orthogonal,
),
nn.MaxPool2d(3, stride=2, padding=1),
ResidualBlock(out_channels, activation, init_layers_orthogonal),
ResidualBlock(out_channels, activation, init_layers_orthogonal),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.seq(x)
class ImpalaCnn(CnnFeatureExtractor):
"""
IMPALA-style CNN architecture
"""
def __init__(
self,
in_channels: int,
activation: Type[nn.Module] = nn.ReLU,
init_layers_orthogonal: Optional[bool] = None,
impala_channels: Sequence[int] = (16, 32, 32),
**kwargs,
) -> None:
if init_layers_orthogonal is None:
init_layers_orthogonal = False
super().__init__(in_channels, activation, init_layers_orthogonal)
sequences = []
for out_channels in impala_channels:
sequences.append(
ConvSequence(
in_channels, out_channels, activation, init_layers_orthogonal
)
)
in_channels = out_channels
sequences.extend(
[
activation(),
nn.Flatten(),
]
)
self.seq = nn.Sequential(*sequences)
def forward(self, obs: torch.Tensor) -> torch.Tensor:
return self.seq(obs)
CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnFeatureExtractor]] = {
"nature": NatureCnn,
"impala": ImpalaCnn,
}
class FeatureExtractor(nn.Module):
def __init__(
self,
obs_space: gym.Space,
activation: Type[nn.Module],
init_layers_orthogonal: bool = False,
cnn_feature_dim: int = 512,
cnn_style: str = "nature",
cnn_layers_init_orthogonal: Optional[bool] = None,
impala_channels: Sequence[int] = (16, 32, 32),
) -> None:
super().__init__()
if isinstance(obs_space, Box):
# Conv2D: (channels, height, width)
if len(obs_space.shape) == 3:
cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
obs_space.shape[0],
activation,
init_layers_orthogonal=cnn_layers_init_orthogonal,
impala_channels=impala_channels,
)
def preprocess(obs: torch.Tensor) -> torch.Tensor:
if len(obs.shape) == 3:
obs = obs.unsqueeze(0)
return obs.float() / 255.0
with torch.no_grad():
cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
self.preprocess = preprocess
self.feature_extractor = nn.Sequential(
cnn,
layer_init(
nn.Linear(cnn_out.shape[1], cnn_feature_dim),
init_layers_orthogonal,
),
activation(),
)
self.out_dim = cnn_feature_dim
elif len(obs_space.shape) == 1:
def preprocess(obs: torch.Tensor) -> torch.Tensor:
if len(obs.shape) == 1:
obs = obs.unsqueeze(0)
return obs.float()
self.preprocess = preprocess
self.feature_extractor = nn.Flatten()
self.out_dim = get_flattened_obs_dim(obs_space)
else:
raise ValueError(f"Unsupported observation space: {obs_space}")
elif isinstance(obs_space, Discrete):
self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
self.feature_extractor = nn.Flatten()
self.out_dim = obs_space.n
else:
raise NotImplementedError
def forward(self, obs: torch.Tensor) -> torch.Tensor:
if self.preprocess:
obs = self.preprocess(obs)
return self.feature_extractor(obs)