vpg-AntBulletEnv-v0 / rl_algo_impls /wrappers /action_mask_wrapper.py
sgoodfriend's picture
VPG playing AntBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
a7758a3
raw
history blame
1.29 kB
from typing import Optional, Union
import numpy as np
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnv,
VecotarableWrapper,
find_wrapper,
)
class IncompleteArrayError(Exception):
    """Exception type for an incomplete array condition.

    NOTE(review): the meaning is inferred from the name only; this class is not
    raised anywhere in this module — confirm usage at its raise sites elsewhere.
    """
class SingleActionMaskWrapper(VecotarableWrapper):
    """Collect per-sub-environment action masks from a synchronous vectorized env.

    The wrapped env's ``unwrapped`` object must expose an ``envs`` sequence, and
    each element's ``unwrapped`` object must expose an ``action_mask`` attribute.
    """

    def action_masks(self) -> Optional[np.ndarray]:
        """Return every sub-environment's ``action_mask`` stacked into a bool array.

        Returns:
            ``np.array`` of the masks in ``envs`` order, with boolean dtype.

        Raises:
            AssertionError: if the wrapped env has no (or an empty/falsy)
                ``envs`` attribute, or if any sub-environment's
                ``action_mask`` is missing or ``None``.
        """
        # Default to None so a non-vectorized env trips the descriptive assert
        # below instead of surfacing an opaque AttributeError.
        envs = getattr(self.env.unwrapped, "envs", None)
        assert (
            envs
        ), f"{self.__class__.__name__} expects to wrap synchronous vectorized env"
        masks = [getattr(e.unwrapped, "action_mask", None) for e in envs]
        assert all(
            m is not None for m in masks
        ), f"{self.__class__.__name__} requires every sub-env to expose a non-None action_mask"
        # np.bool8 was deprecated in NumPy 1.24 and removed in NumPy 2.0;
        # np.bool_ is the same boolean dtype under its canonical name.
        return np.array(masks, dtype=np.bool_)
class MicrortsMaskWrapper(VecotarableWrapper):
    """Read action masks from the ``vec_client`` of the unwrapped MicroRTS env."""

    def action_masks(self) -> np.ndarray:
        """Return ``vec_client.getMasks(0)`` converted to a boolean array.

        Raises:
            AssertionError: if the unwrapped env has no (or a falsy)
                ``vec_client`` attribute.
        """
        microrts_env = self.env.unwrapped  # type: ignore
        # Default to None so a missing attribute reaches the descriptive assert
        # instead of raising an opaque AttributeError.
        vec_client = getattr(microrts_env, "vec_client", None)
        assert (
            vec_client
        ), f"{microrts_env.__class__.__name__} must have vec_client property (as MicroRTSVecEnv does)"
        # NOTE(review): the meaning of the 0 argument is defined by the MicroRTS
        # client API, which is not visible here — confirm against that API.
        # np.bool8 was removed in NumPy 2.0; np.bool_ is the canonical bool dtype.
        return np.array(vec_client.getMasks(0), dtype=np.bool_)
def find_action_masker(
    env: VecEnv,
) -> Optional[Union[SingleActionMaskWrapper, MicrortsMaskWrapper]]:
    """Locate an action-mask wrapper applied to ``env``.

    ``SingleActionMaskWrapper`` is looked up first. Only when that lookup yields
    a falsy result is ``MicrortsMaskWrapper`` looked up, and that second result
    is returned as-is.
    """
    single_masker = find_wrapper(env, SingleActionMaskWrapper)
    if single_masker:
        return single_masker
    return find_wrapper(env, MicrortsMaskWrapper)