VPG playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/e8bc541d8b5e67bb4d3f2075282463fb61f5f2c6
10ca8cb
import gym | |
import torch | |
import torch.nn as nn | |
from abc import ABC, abstractmethod | |
from gym.spaces import Box, Discrete | |
from torch.distributions import Categorical, Distribution, Normal | |
from typing import NamedTuple, Optional, Sequence, Type, TypeVar, Union | |
from shared.module.feature_extractor import FeatureExtractor | |
from shared.module.module import mlp | |
class PiForward(NamedTuple): | |
pi: Distribution | |
logp_a: Optional[torch.Tensor] | |
entropy: Optional[torch.Tensor] | |
class Actor(nn.Module, ABC): | |
def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward: | |
... | |
class CategoricalActorHead(Actor): | |
def __init__( | |
self, | |
act_dim: int, | |
hidden_sizes: Sequence[int] = (32,), | |
activation: Type[nn.Module] = nn.Tanh, | |
init_layers_orthogonal: bool = True, | |
) -> None: | |
super().__init__() | |
layer_sizes = tuple(hidden_sizes) + (act_dim,) | |
self._fc = mlp( | |
layer_sizes, | |
activation, | |
init_layers_orthogonal=init_layers_orthogonal, | |
final_layer_gain=0.01, | |
) | |
def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward: | |
logits = self._fc(obs) | |
pi = Categorical(logits=logits) | |
logp_a = None | |
entropy = None | |
if a is not None: | |
logp_a = pi.log_prob(a) | |
entropy = pi.entropy() | |
return PiForward(pi, logp_a, entropy) | |
class GaussianDistribution(Normal): | |
def log_prob(self, a: torch.Tensor) -> torch.Tensor: | |
return super().log_prob(a).sum(axis=-1) | |
def sample(self) -> torch.Tensor: | |
return self.rsample() | |
class GaussianActorHead(Actor): | |
def __init__( | |
self, | |
act_dim: int, | |
hidden_sizes: Sequence[int] = (32,), | |
activation: Type[nn.Module] = nn.Tanh, | |
init_layers_orthogonal: bool = True, | |
log_std_init: float = -0.5, | |
) -> None: | |
super().__init__() | |
layer_sizes = tuple(hidden_sizes) + (act_dim,) | |
self.mu_net = mlp( | |
layer_sizes, | |
activation, | |
init_layers_orthogonal=init_layers_orthogonal, | |
final_layer_gain=0.01, | |
) | |
self.log_std = nn.Parameter( | |
torch.ones(act_dim, dtype=torch.float32) * log_std_init | |
) | |
def _distribution(self, obs: torch.Tensor) -> Distribution: | |
mu = self.mu_net(obs) | |
std = torch.exp(self.log_std) | |
return GaussianDistribution(mu, std) | |
def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward: | |
pi = self._distribution(obs) | |
logp_a = None | |
entropy = None | |
if a is not None: | |
logp_a = pi.log_prob(a) | |
entropy = pi.entropy() | |
return PiForward(pi, logp_a, entropy) | |
class TanhBijector: | |
def __init__(self, epsilon: float = 1e-6) -> None: | |
self.epsilon = epsilon | |
def forward(x: torch.Tensor) -> torch.Tensor: | |
return torch.tanh(x) | |
def inverse(y: torch.Tensor) -> torch.Tensor: | |
eps = torch.finfo(y.dtype).eps | |
clamped_y = y.clamp(min=-1.0 + eps, max=1.0 - eps) | |
return torch.atanh(clamped_y) | |
def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor: | |
return torch.log(1.0 - torch.tanh(x) ** 2 + self.epsilon) | |
class StateDependentNoiseDistribution(Normal): | |
def __init__( | |
self, | |
loc, | |
scale, | |
latent_sde: torch.Tensor, | |
exploration_mat: torch.Tensor, | |
exploration_matrices: torch.Tensor, | |
bijector: Optional[TanhBijector] = None, | |
validate_args=None, | |
): | |
super().__init__(loc, scale, validate_args) | |
self.latent_sde = latent_sde | |
self.exploration_mat = exploration_mat | |
self.exploration_matrices = exploration_matrices | |
self.bijector = bijector | |
def log_prob(self, a: torch.Tensor) -> torch.Tensor: | |
gaussian_a = self.bijector.inverse(a) if self.bijector else a | |
log_prob = super().log_prob(gaussian_a).sum(axis=-1) | |
if self.bijector: | |
log_prob -= torch.sum(self.bijector.log_prob_correction(gaussian_a), dim=1) | |
return log_prob | |
def sample(self) -> torch.Tensor: | |
noise = self._get_noise() | |
actions = self.mean + noise | |
return self.bijector.forward(actions) if self.bijector else actions | |
def _get_noise(self) -> torch.Tensor: | |
if len(self.latent_sde) == 1 or len(self.latent_sde) != len( | |
self.exploration_matrices | |
): | |
return torch.mm(self.latent_sde, self.exploration_mat) | |
# (batch_size, n_features) -> (batch_size, 1, n_features) | |
latent_sde = self.latent_sde.unsqueeze(dim=1) | |
# (batch_size, 1, n_actions) | |
noise = torch.bmm(latent_sde, self.exploration_matrices) | |
return noise.squeeze(dim=1) | |
def mode(self) -> torch.Tensor: | |
mean = super().mode | |
return self.bijector.forward(mean) if self.bijector else mean | |
StateDependentNoiseActorHeadSelf = TypeVar( | |
"StateDependentNoiseActorHeadSelf", bound="StateDependentNoiseActorHead" | |
) | |
class StateDependentNoiseActorHead(Actor): | |
def __init__( | |
self, | |
act_dim: int, | |
hidden_sizes: Sequence[int] = (32,), | |
activation: Type[nn.Module] = nn.Tanh, | |
init_layers_orthogonal: bool = True, | |
log_std_init: float = -0.5, | |
full_std: bool = True, | |
squash_output: bool = False, | |
learn_std: bool = False, | |
) -> None: | |
super().__init__() | |
self.act_dim = act_dim | |
layer_sizes = tuple(hidden_sizes) + (self.act_dim,) | |
if len(layer_sizes) == 2: | |
self.latent_net = nn.Identity() | |
elif len(layer_sizes) > 2: | |
self.latent_net = mlp( | |
layer_sizes[:-1], | |
activation, | |
output_activation=activation, | |
init_layers_orthogonal=init_layers_orthogonal, | |
) | |
else: | |
raise ValueError("hidden_sizes must be of at least length 1") | |
self.mu_net = mlp( | |
layer_sizes[-2:], | |
activation, | |
init_layers_orthogonal=init_layers_orthogonal, | |
final_layer_gain=0.01, | |
) | |
self.full_std = full_std | |
std_dim = (hidden_sizes[-1], act_dim if self.full_std else 1) | |
self.log_std = nn.Parameter( | |
torch.ones(std_dim, dtype=torch.float32) * log_std_init | |
) | |
self.bijector = TanhBijector() if squash_output else None | |
self.learn_std = learn_std | |
self.device = None | |
self.exploration_mat = None | |
self.exploration_matrices = None | |
self.sample_weights() | |
def to( | |
self: StateDependentNoiseActorHeadSelf, | |
device: Optional[torch.device] = None, | |
dtype: Optional[Union[torch.dtype, str]] = None, | |
non_blocking: bool = False, | |
) -> StateDependentNoiseActorHeadSelf: | |
super().to(device, dtype, non_blocking) | |
self.device = device | |
return self | |
def _distribution(self, obs: torch.Tensor) -> Distribution: | |
latent = self.latent_net(obs) | |
mu = self.mu_net(latent) | |
latent_sde = latent if self.learn_std else latent.detach() | |
variance = torch.mm(latent_sde**2, self._get_std() ** 2) | |
assert self.exploration_mat is not None | |
assert self.exploration_matrices is not None | |
return StateDependentNoiseDistribution( | |
mu, | |
torch.sqrt(variance + 1e-6), | |
latent_sde, | |
self.exploration_mat, | |
self.exploration_matrices, | |
self.bijector, | |
) | |
def _get_std(self) -> torch.Tensor: | |
std = torch.exp(self.log_std) | |
if self.full_std: | |
return std | |
ones = torch.ones(self.log_std.shape[0], self.act_dim) | |
if self.device: | |
ones = ones.to(self.device) | |
return ones * std | |
def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward: | |
pi = self._distribution(obs) | |
logp_a = None | |
entropy = None | |
if a is not None: | |
logp_a = pi.log_prob(a) | |
entropy = -logp_a | |
return PiForward(pi, logp_a, entropy) | |
def sample_weights(self, batch_size: int = 1) -> None: | |
std = self._get_std() | |
weights_dist = Normal(torch.zeros_like(std), std) | |
# Reparametrization trick to pass gradients | |
self.exploration_mat = weights_dist.rsample() | |
self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,))) | |
def actor_head( | |
action_space: gym.Space, | |
hidden_sizes: Sequence[int], | |
init_layers_orthogonal: bool, | |
activation: Type[nn.Module], | |
log_std_init: float = -0.5, | |
use_sde: bool = False, | |
full_std: bool = True, | |
squash_output: bool = False, | |
) -> Actor: | |
assert not use_sde or isinstance( | |
action_space, Box | |
), "use_sde only valid if Box action_space" | |
assert not squash_output or use_sde, "squash_output only valid if use_sde" | |
if isinstance(action_space, Discrete): | |
return CategoricalActorHead( | |
action_space.n, | |
hidden_sizes=hidden_sizes, | |
activation=activation, | |
init_layers_orthogonal=init_layers_orthogonal, | |
) | |
elif isinstance(action_space, Box): | |
if use_sde: | |
return StateDependentNoiseActorHead( | |
action_space.shape[0], | |
hidden_sizes=hidden_sizes, | |
activation=activation, | |
init_layers_orthogonal=init_layers_orthogonal, | |
log_std_init=log_std_init, | |
full_std=full_std, | |
squash_output=squash_output, | |
) | |
else: | |
return GaussianActorHead( | |
action_space.shape[0], | |
hidden_sizes=hidden_sizes, | |
activation=activation, | |
init_layers_orthogonal=init_layers_orthogonal, | |
log_std_init=log_std_init, | |
) | |
else: | |
raise ValueError(f"Unsupported action space: {action_space}") | |