VPG playing AntBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
a7758a3
from typing import Optional, Union | |
import numpy as np | |
from rl_algo_impls.wrappers.vectorable_wrapper import ( | |
VecEnv, | |
VecotarableWrapper, | |
find_wrapper, | |
) | |
class IncompleteArrayError(Exception):
    """Module-specific exception type.

    It adds nothing to ``Exception`` and is not raised by the mask wrappers
    defined alongside it in this file.
    """
class SingleActionMaskWrapper(VecotarableWrapper):
    """Expose the action masks of a synchronous vectorized env as one array.

    The wrapped env's ``unwrapped`` object must provide a non-empty ``envs``
    collection, and each entry's ``unwrapped`` object must provide an
    ``action_mask`` attribute that is not ``None``.
    """

    def action_masks(self) -> Optional[np.ndarray]:
        """Return every sub-environment's ``action_mask`` as a boolean array.

        Returns:
            A boolean ``np.ndarray`` built from one mask per entry of
            ``envs``, in ``envs`` order.

        Raises:
            AssertionError: if the wrapped env has no (or an empty) ``envs``
                attribute, or if any sub-environment's mask is ``None``.
            AttributeError: if a sub-environment has no ``action_mask``.
        """
        # Default to None so that a missing attribute is reported through the
        # descriptive assertion below instead of a bare AttributeError.
        envs = getattr(self.env.unwrapped, "envs", None)
        assert (
            envs
        ), f"{self.__class__.__name__} expects to wrap synchronous vectorized env"
        masks = [getattr(e.unwrapped, "action_mask") for e in envs]
        assert all(
            m is not None for m in masks
        ), f"{self.__class__.__name__} found a sub-environment whose action_mask is None"
        # np.bool_ is the canonical boolean scalar type; the np.bool8 alias was
        # deprecated in NumPy 1.24 and removed in NumPy 2.0.
        return np.array(masks, dtype=np.bool_)
class MicrortsMaskWrapper(VecotarableWrapper):
    """Expose the action masks produced by a MicroRTS vectorized env's client.

    The wrapped env's ``unwrapped`` object must provide a truthy
    ``vec_client`` whose ``getMasks`` method supplies the masks.
    """

    def action_masks(self) -> np.ndarray:
        """Return ``vec_client.getMasks(0)`` converted to a boolean array.

        Returns:
            A boolean ``np.ndarray`` built from the value returned by
            ``vec_client.getMasks(0)``.

        Raises:
            AssertionError: if the unwrapped env has no (or a falsy)
                ``vec_client`` attribute.
        """
        microrts_env = self.env.unwrapped  # type: ignore
        # Default to None so that a missing attribute is reported through the
        # descriptive assertion below instead of a bare AttributeError.
        vec_client = getattr(microrts_env, "vec_client", None)
        assert (
            vec_client
        ), f"{microrts_env.__class__.__name__} must have vec_client property (as MicroRTSVecEnv does)"
        # NOTE(review): the meaning of the 0 argument to getMasks is defined by
        # the MicroRTS client and is not visible here - confirm before changing.
        # np.bool_ replaces the np.bool8 alias (deprecated in NumPy 1.24,
        # removed in NumPy 2.0).
        return np.array(vec_client.getMasks(0), dtype=np.bool_)
def find_action_masker(
    env: VecEnv,
) -> Optional[Union[SingleActionMaskWrapper, MicrortsMaskWrapper]]:
    """Look up the action-mask wrapper applied to ``env``, if any.

    A ``SingleActionMaskWrapper`` lookup is tried first; when that yields a
    falsy result, the outcome of the ``MicrortsMaskWrapper`` lookup is
    returned as-is.
    """
    single_masker = find_wrapper(env, SingleActionMaskWrapper)
    if single_masker:
        return single_masker
    return find_wrapper(env, MicrortsMaskWrapper)