# VPG playing PongNoFrameskip-v4 (sgoodfriend; commit 10ca8cb, 6.85 kB), from
# https://github.com/sgoodfriend/rl-algo-impls/tree/e8bc541d8b5e67bb4d3f2075282463fb61f5f2c6
import gym
import numpy as np
import torch
from abc import abstractmethod
from gym.spaces import Box, Discrete, Space
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
from shared.module.feature_extractor import FeatureExtractor
from shared.policy.actor import PiForward, StateDependentNoiseActorHead, actor_head
from shared.policy.critic import CriticHead
from shared.policy.policy import ACTIVATION, Policy
class Step(NamedTuple):
    """Output of one gradient-free sampling step of an on-policy policy.

    All fields are NumPy arrays converted from the policy's tensors.
    """

    a: np.ndarray  # actions sampled from the policy distribution (unclamped)
    v: np.ndarray  # critic value estimates for the observations
    logp_a: np.ndarray  # log-probabilities of the sampled actions `a`
    clamped_a: np.ndarray  # `a` after passing through clamp_actions
class ACForward(NamedTuple):
    """Tensors produced when evaluating given actions under the policy.

    Gradients are kept, unlike ``Step``, which is built under ``no_grad``.
    """

    logp_a: torch.Tensor  # log-probability of the supplied actions
    entropy: torch.Tensor  # entropy of the policy distribution
    v: torch.Tensor  # critic value estimates
# Per-module weight file names. They are not referenced in this section;
# presumably they are used by save/load logic elsewhere (confirm).
PI_FILE_NAME = "pi.pt"
V_FILE_NAME = "v.pt"
FEAT_EXT_FILE_NAME = "feat_ext.pt"
V_FEAT_EXT_FILE_NAME = "v_feat_ext.pt"
# Self-type bound to ActorCritic by forward reference (the class is defined
# later in the module); presumably meant for methods returning `self`.
ActorCriticSelf = TypeVar("ActorCriticSelf", bound="ActorCritic")
def clamp_actions(
    actions: np.ndarray, action_space: gym.Space, squash_output: bool
) -> np.ndarray:
    """Bring raw policy actions into the bounds of ``action_space``.

    Only ``Box`` action spaces are transformed. Any other space returns
    ``actions`` unchanged.

    Args:
        actions: Actions produced by the policy.
        action_space: The environment's action space.
        squash_output: If True, ``actions`` are taken to already lie in
            [-1, 1] (squashed output) and are linearly rescaled onto
            [low, high]. If False, they are clipped to [low, high].

    Returns:
        The rescaled, clipped, or untouched actions.
    """
    if not isinstance(action_space, Box):
        return actions
    low = action_space.low  # type: ignore
    high = action_space.high  # type: ignore
    if not squash_output:
        return np.clip(actions, low, high)
    # Affine map from [-1, 1] onto [low, high] for bounds other than +/-1.
    return low + 0.5 * (actions + 1) * (high - low)
def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
    """Pick default hidden-layer sizes for an observation space.

    - 3-D ``Box`` (image-like): ``[]``, i.e. no extra hidden layers after
      the feature extractor.
    - 1-D ``Box``: ``[64, 64]``.
    - ``Discrete``: ``[64]``.

    Raises:
        ValueError: For any other space, or a ``Box`` of another rank.
    """
    if isinstance(obs_space, Box):
        rank = len(obs_space.shape)
        if rank == 3:
            return []
        if rank == 1:
            return [64, 64]
    elif isinstance(obs_space, Discrete):
        return [64]
    # Shared fall-through for unsupported spaces and unsupported Box ranks.
    raise ValueError(f"Unsupported observation space: {obs_space}")
class OnPolicy(Policy):
    """Abstract policy interface: value estimation plus a sampling step.

    Concrete subclasses implement both methods below.
    """

    @abstractmethod
    def value(self, obs: VecEnvObs) -> np.ndarray:
        """Return value estimates for the vectorized-env observations ``obs``."""

    @abstractmethod
    def step(self, obs: VecEnvObs) -> Step:
        """Sample actions for ``obs``, returning them as a ``Step``."""
class ActorCritic(OnPolicy):
    """Actor-critic policy with an actor (pi) head and a critic (v) head.

    Observations go through a ``FeatureExtractor``. The actor head (built by
    ``actor_head`` for ``env.action_space``) and the ``CriticHead`` consume
    its output. With ``share_features_extractor=False``, the critic gets its
    own, separately parameterized ``FeatureExtractor``.
    """

    def __init__(
        self,
        env: VecEnv,
        pi_hidden_sizes: Sequence[int],
        v_hidden_sizes: Sequence[int],
        init_layers_orthogonal: bool = True,
        activation_fn: str = "tanh",
        log_std_init: float = -0.5,
        use_sde: bool = False,
        full_std: bool = True,
        squash_output: bool = False,
        share_features_extractor: bool = True,
        cnn_feature_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        **kwargs,
    ) -> None:
        """Build the feature extractor(s), the actor head and the critic head.

        Args:
            env: Vectorized environment. Its observation and action spaces
                size the networks.
            pi_hidden_sizes: Actor hidden sizes, appended after the feature
                extractor's output dimension.
            v_hidden_sizes: Critic hidden sizes, appended after the output
                dimension of the extractor the critic uses.
            init_layers_orthogonal: Forwarded to the extractor(s) and heads.
            activation_fn: Key into ``ACTIVATION`` selecting the activation.
            log_std_init, use_sde, full_std: Forwarded to ``actor_head``.
            squash_output: Forwarded to ``actor_head``. Also makes
                ``clamp_actions`` rescale Box actions instead of clipping them.
            share_features_extractor: If False, the critic gets its own
                ``FeatureExtractor`` instead of reusing the actor's.
            cnn_feature_dim, cnn_style, cnn_layers_init_orthogonal: Forwarded
                to ``FeatureExtractor``.
            **kwargs: Passed through to the base-class ``__init__``.
        """
        super().__init__(env, **kwargs)
        activation = ACTIVATION[activation_fn]
        observation_space = env.observation_space
        self.action_space = env.action_space
        self.squash_output = squash_output
        self.share_features_extractor = share_features_extractor

        def make_feature_extractor() -> FeatureExtractor:
            # One factory for both extractors, so an unshared critic extractor
            # can never drift from the actor's configuration.
            return FeatureExtractor(
                observation_space,
                activation,
                init_layers_orthogonal=init_layers_orthogonal,
                cnn_feature_dim=cnn_feature_dim,
                cnn_style=cnn_style,
                cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
            )

        # Construction order is unchanged: actor extractor, actor head, critic
        # extractor, critic head. This presumably matters for which RNG draws
        # each initializer sees under a fixed seed.
        self._feature_extractor = make_feature_extractor()
        self._pi = actor_head(
            self.action_space,
            (self._feature_extractor.out_dim,) + tuple(pi_hidden_sizes),
            init_layers_orthogonal,
            activation,
            log_std_init=log_std_init,
            use_sde=use_sde,
            full_std=full_std,
            squash_output=squash_output,
        )

        if share_features_extractor:
            self._v_feature_extractor = None
            v_in_dim = self._feature_extractor.out_dim
        else:
            self._v_feature_extractor = make_feature_extractor()
            v_in_dim = self._v_feature_extractor.out_dim
        self._v = CriticHead(
            hidden_sizes=(v_in_dim,) + tuple(v_hidden_sizes),
            activation=activation,
            init_layers_orthogonal=init_layers_orthogonal,
        )

    def _pi_forward(
        self, obs: torch.Tensor, action: Optional[torch.Tensor] = None
    ) -> Tuple[PiForward, torch.Tensor]:
        """Run the actor's feature extractor and head.

        Args:
            obs: Observation tensor.
            action: Optional actions, forwarded to the actor head.

        Returns:
            ``(pi_forward, p_fe)``: the actor head's ``PiForward`` and the
            actor feature-extractor output. The features are returned so the
            critic can reuse them when the extractor is shared.
        """
        p_fe = self._feature_extractor(obs)
        pi_forward = self._pi(p_fe, action)
        return pi_forward, p_fe

    def _v_forward(self, obs: torch.Tensor, p_fc: torch.Tensor) -> torch.Tensor:
        """Compute critic value estimates.

        If the critic has its own feature extractor, it runs on ``obs``.
        Otherwise ``p_fc`` (the actor features from ``_pi_forward``) is reused.
        """
        # Explicit None check: a module object's truthiness is not a reliable
        # "is configured" test (e.g. container modules define __len__).
        if self._v_feature_extractor is not None:
            v_fe = self._v_feature_extractor(obs)
        else:
            v_fe = p_fc
        return self._v(v_fe)

    def forward(self, obs: torch.Tensor, action: torch.Tensor) -> ACForward:
        """Evaluate ``action`` under the current policy, with gradients.

        Returns:
            ``ACForward(logp_a, entropy, v)``: the log-probability of
            ``action``, the policy entropy, and value estimates for ``obs``.
        """
        (_, logp_a, entropy), p_fc = self._pi_forward(obs, action)
        v = self._v_forward(obs, p_fc)
        # Guard against the actor head returning None for these fields.
        assert logp_a is not None
        assert entropy is not None
        return ACForward(logp_a, entropy, v)

    def value(self, obs: VecEnvObs) -> np.ndarray:
        """Return critic value estimates for ``obs`` as a NumPy array.

        Runs without gradient tracking. Only the feature extractor the critic
        uses is evaluated; the actor head is skipped.
        """
        o = self._as_tensor(obs)
        with torch.no_grad():
            if self._v_feature_extractor is not None:
                fe = self._v_feature_extractor(o)
            else:
                fe = self._feature_extractor(o)
            v = self._v(fe)
        return v.cpu().numpy()

    def step(self, obs: VecEnvObs) -> Step:
        """Sample actions for ``obs`` without gradient tracking.

        Returns:
            ``Step(a, v, logp_a, clamped_a)`` as NumPy arrays. These are the
            sampled actions, value estimates, log-probabilities of the
            samples, and the samples passed through ``clamp_actions``.
        """
        o = self._as_tensor(obs)
        with torch.no_grad():
            (pi, _, _), p_fc = self._pi_forward(o)
            a = pi.sample()
            logp_a = pi.log_prob(a)
            v = self._v_forward(o, p_fc)
        a_np = a.cpu().numpy()
        clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
        return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)

    def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
        """Return clamped actions for ``obs``.

        With ``deterministic=True``, the policy distribution's ``mode`` is
        used. Otherwise an action is sampled via ``step``.
        """
        if not deterministic:
            return self.step(obs).clamped_a
        o = self._as_tensor(obs)
        with torch.no_grad():
            (pi, _, _), _ = self._pi_forward(o)
            a = pi.mode
        return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)

    def load(self, path: str) -> None:
        """Load parameters via the base ``load``, then call ``reset_noise``."""
        super().load(path)
        self.reset_noise()

    def reset_noise(self, batch_size: Optional[int] = None) -> None:
        """Resample noise weights when the actor uses state-dependent noise.

        This is a no-op unless the actor head is a
        ``StateDependentNoiseActorHead``.

        Args:
            batch_size: Number of noise samples to draw. Any falsy value
                (``None`` or ``0``) falls back to ``self.env.num_envs``.
        """
        if isinstance(self._pi, StateDependentNoiseActorHead):
            self._pi.sample_weights(batch_size=batch_size or self.env.num_envs)