sgoodfriend's picture
PPO playing Walker2DBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
c6e31cd
raw
history blame
1.07 kB
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Distribution
class PiForward(NamedTuple):
pi: Distribution
logp_a: Optional[torch.Tensor]
entropy: Optional[torch.Tensor]
class Actor(nn.Module, ABC):
@abstractmethod
def forward(
self,
obs: torch.Tensor,
actions: Optional[torch.Tensor] = None,
action_masks: Optional[torch.Tensor] = None,
) -> PiForward:
...
def sample_weights(self, batch_size: int = 1) -> None:
pass
@property
@abstractmethod
def action_shape(self) -> Tuple[int, ...]:
...
def pi_forward(
self, distribution: Distribution, actions: Optional[torch.Tensor] = None
) -> PiForward:
logp_a = None
entropy = None
if actions is not None:
logp_a = distribution.log_prob(actions)
entropy = distribution.entropy()
return PiForward(distribution, logp_a, entropy)