Spaces:
Running
Running
"""Base class for vectorized environments.""" | |
from typing import Any, List, Optional, Tuple, Union | |
import numpy as np | |
import gym | |
from gym.vector.utils.spaces import batch_space | |
__all__ = ["VectorEnv"] | |
class VectorEnv(gym.Env):
    """Common base for environments that drive several copies of one env together.

    A vectorized environment holds ``num_envs`` independent copies of the same
    base environment (it is *not* a single environment made of several
    sub-components). Observations are returned as one batch with an entry per
    copy, and :meth:`step` likewise expects one batch of actions.

    Notes:
        Every copy must use the same observation and action space; mixing
        different environments inside one vector is not supported.
    """

    def __init__(
        self,
        num_envs: int,
        observation_space: gym.Space,
        action_space: gym.Space,
    ):
        """Store the per-copy spaces and derive the batched spaces from them.

        Args:
            num_envs: How many copies of the environment are held.
            observation_space: Observation space of one single copy.
            action_space: Action space of one single copy.
        """
        self.num_envs = num_envs
        self.is_vector_env = True

        # Spaces of one individual copy, kept exactly as passed in.
        self.single_observation_space = observation_space
        self.single_action_space = action_space

        # Public spaces describe a whole batch of ``num_envs`` elements.
        self.observation_space = batch_space(observation_space, n=num_envs)
        self.action_space = batch_space(action_space, n=num_envs)

        self.closed = False
        self.viewer = None

    def reset_async(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        options: Optional[dict] = None,
    ):
        """Begin resetting the sub-environments without waiting for the result.

        The base implementation does nothing and returns ``None``; the result
        of the reset is obtained by calling :meth:`reset_wait` afterwards.

        Args:
            seed: The reset seed
            options: Reset options
        """
        return None

    def reset_wait(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        options: Optional[dict] = None,
    ):
        """Collect the outcome of a preceding :meth:`reset_async` call.

        Args:
            seed: The reset seed
            options: Reset options

        Returns:
            The results of the reset started by :meth:`reset_async`.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                provide the implementation.
        """
        raise NotImplementedError("VectorEnv does not implement function")

    def reset(
        self,
        *,
        seed: Optional[Union[int, List[int]]] = None,
        options: Optional[dict] = None,
    ):
        """Reset every copy by running :meth:`reset_async` then :meth:`reset_wait`.

        Args:
            seed: The environment reset seeds
            options: Reset options; passed unchanged to both calls.

        Returns:
            Whatever :meth:`reset_wait` returns (expected to be a batch of
            initial observations from the vectorized environment).
        """
        reset_kwargs = {"seed": seed, "options": options}
        self.reset_async(**reset_kwargs)
        return self.reset_wait(**reset_kwargs)

    def step_async(self, actions):
        """Begin stepping the sub-environments without waiting for the result.

        The base implementation is an empty stub. Results are meant to be
        fetched with :meth:`step_wait`.

        Args:
            actions: The actions to take asynchronously
        """

    def step_wait(self, **kwargs):
        """Collect the outcome of a preceding :meth:`step_async` call.

        The base implementation is an empty stub and returns ``None``.

        Args:
            **kwargs: Additional keywords for vector implementation

        Returns:
            The results from the :meth:`step_async` call
        """

    def step(self, actions):
        """Step every copy by running :meth:`step_async` then :meth:`step_wait`.

        Args:
            actions: element of :attr:`action_space` Batch of actions.

        Returns:
            Whatever :meth:`step_wait` returns; implementations are expected to
            give a batch of (observations, rewards, terminated, truncated, infos)
            or (observations, rewards, dones, infos).
        """
        self.step_async(actions)
        result = self.step_wait()
        return result

    def call_async(self, name, *args, **kwargs):
        """Start calling ``name`` on each copy asynchronously (empty stub here)."""

    def call_wait(self, **kwargs) -> List[Any]:  # type: ignore
        """Collect the results of a preceding :meth:`call_async` (empty stub here)."""

    def call(self, name: str, *args, **kwargs) -> List[Any]:
        """Invoke a method, or read a property, on each copy.

        Runs :meth:`call_async` with the given arguments and then returns the
        value of :meth:`call_wait`.

        Args:
            name (str): Name of the method or property to call.
            *args: Arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.

        Returns:
            The per-environment results, as produced by :meth:`call_wait`.
        """
        self.call_async(name, *args, **kwargs)
        results = self.call_wait()
        return results

    def get_attr(self, name: str):
        """Read a property from each copy by delegating to :meth:`call`.

        Args:
            name (str): Name of the property to read from each environment.

        Returns:
            The value returned by ``self.call(name)``.
        """
        return self.call(name)

    def set_attr(self, name: str, values: Union[list, tuple, object]):
        """Assign a property on each sub-environment (empty stub here).

        Args:
            name (str): Name of the property to be set in each individual environment.
            values (list, tuple, or object): New value(s). A list or tuple
                supplies one value per environment; any other object is meant
                to be applied to all environments.
        """

    def close_extras(self, **kwargs):
        """Hook for releasing resources beyond those owned by this base class."""

    def close(self, **kwargs):
        """Shut the vectorized environment down once.

        Does nothing when :attr:`closed` is already true. Otherwise closes
        :attr:`viewer` if one is set, calls :meth:`close_extras` and finally
        marks the instance as closed.

        Warnings:
            This method does not close the sub-environments itself; that work
            belongs in :meth:`close_extras`, which keeps this method generic
            for synchronous and asynchronous implementations alike.

        Notes:
            :meth:`__del__` calls this method when the instance has not been
            closed yet.

        Args:
            **kwargs: Keyword arguments passed to :meth:`close_extras`
        """
        if not self.closed:
            if self.viewer is not None:
                self.viewer.close()
            self.close_extras(**kwargs)
            # Only reached when close_extras returned without raising.
            self.closed = True

    def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
        """Merge one environment's ``info`` into the batched ``infos`` dict.

        For each key of ``info`` the batched dict holds two arrays: ``key``
        with the per-environment values, and ``_key`` with a boolean mask that
        is true at the indices of the environments that reported this key.

        Args:
            infos (dict): the infos of the vectorized environment
            info (dict): the info coming from the single environment
            env_num (int): the index of the single environment

        Returns:
            infos (dict): the (updated) infos of the vectorized environment
        """
        for key, value in info.items():
            mask_key = f"_{key}"
            if key in infos:
                values, present = infos[key], infos[mask_key]
            else:
                # First time this key is seen: allocate arrays for all envs.
                values, present = self._init_info_arrays(type(value))
            values[env_num] = value
            present[env_num] = True
            infos[key] = values
            infos[mask_key] = present
        return infos

    def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]:
        """Allocate the value array and presence mask for a new info key.

        ``int``, ``float``, ``bool`` and subclasses of ``np.number`` get a
        zero-filled array of that dtype; anything else gets an object array
        filled with ``None``. The mask is a boolean array of the same length,
        initially all false.

        Args:
            dtype (type): data type of the info coming from the env.

        Returns:
            array (np.ndarray): the initialized info array.
            array_mask (np.ndarray): the initialized boolean array.
        """
        length = self.num_envs
        is_numeric = dtype in (int, float, bool) or issubclass(dtype, np.number)
        if is_numeric:
            array = np.zeros(length, dtype=dtype)
        else:
            array = np.full(length, None, dtype=object)
        array_mask = np.zeros(length, dtype=bool)
        return array, array_mask

    def __del__(self):
        """Close the vector environment if it has not been closed already."""
        # A missing ``closed`` attribute is treated as "already closed", so an
        # instance whose __init__ never ran is left alone.
        if getattr(self, "closed", True):
            return
        self.close()

    def __repr__(self) -> str:
        """Describe the instance by class name, spec id (if any) and size.

        Returns:
            A string containing the class name, number of environments and environment spec id
        """
        spec = self.spec
        cls_name = self.__class__.__name__
        return (
            f"{cls_name}({self.num_envs})"
            if spec is None
            else f"{cls_name}({spec.id}, {self.num_envs})"
        )
class VectorEnvWrapper(VectorEnv):
    """Wraps the vectorized environment to allow a modular transformation.

    This class is the base class for all wrappers for vectorized environments. The subclass
    could override some methods to change the behavior of the original vectorized environment
    without touching the original code.

    The wrapper does not run ``VectorEnv.__init__``; it only stores the wrapped
    environment in :attr:`env`. Methods defined on :class:`VectorEnv` are
    forwarded explicitly below, and any other public attribute that is not
    found on the wrapper is looked up on :attr:`env` via :meth:`__getattr__`.

    Notes:
        Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
    """

    def __init__(self, env: VectorEnv):
        """Store the environment to wrap.

        Args:
            env: The vectorized environment to wrap.

        Raises:
            AssertionError: If ``env`` is not a :class:`VectorEnv`.
        """
        assert isinstance(env, VectorEnv)
        self.env = env

    # explicitly forward the methods defined in VectorEnv
    # to self.env (instead of the base class)

    def reset_async(self, **kwargs):
        """Forward to ``env.reset_async``."""
        return self.env.reset_async(**kwargs)

    def reset_wait(self, **kwargs):
        """Forward to ``env.reset_wait``."""
        return self.env.reset_wait(**kwargs)

    def step_async(self, actions):
        """Forward to ``env.step_async``."""
        return self.env.step_async(actions)

    def step_wait(self, **kwargs):
        """Forward to ``env.step_wait``, passing along any keyword arguments."""
        return self.env.step_wait(**kwargs)

    def close(self, **kwargs):
        """Forward to ``env.close``."""
        return self.env.close(**kwargs)

    def close_extras(self, **kwargs):
        """Forward to ``env.close_extras``."""
        return self.env.close_extras(**kwargs)

    def call_async(self, name, *args, **kwargs):
        """Forward to ``env.call_async`` rather than the base-class stub."""
        return self.env.call_async(name, *args, **kwargs)

    def call_wait(self, **kwargs):
        """Forward to ``env.call_wait`` rather than the base-class stub."""
        return self.env.call_wait(**kwargs)

    def call(self, name, *args, **kwargs):
        """Forward to ``env.call``."""
        return self.env.call(name, *args, **kwargs)

    def set_attr(self, name, values):
        """Forward to ``env.set_attr``."""
        return self.env.set_attr(name, values)

    # implicitly forward all other methods and attributes to self.env
    def __getattr__(self, name):
        """Look up attributes missing on the wrapper on the wrapped environment.

        Python calls this only after normal attribute lookup has failed.

        Args:
            name: Name of the attribute being looked up.

        Returns:
            ``getattr(self.env, name)``.

        Raises:
            AttributeError: If ``name`` starts with an underscore, or if
                ``name`` is ``"env"`` (i.e. no environment has been stored yet).
        """
        if name.startswith("_"):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")
        if name == "env":
            # Reaching this point means ``env`` is not in the instance dict
            # (e.g. __init__ failed before assigning it). Evaluating
            # ``self.env`` below would re-enter this method without end.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'env'"
            )
        return getattr(self.env, name)

    @property
    def unwrapped(self):
        """The ``unwrapped`` attribute of the wrapped environment.

        Exposed as a property so that it is read the same way as it is read
        from :attr:`env` here, which keeps nested wrappers resolving to the
        innermost value instead of to a bound method.
        """
        return self.env.unwrapped

    def __repr__(self):
        """Return ``<WrapperClassName, repr-of-env>``."""
        return f"<{self.__class__.__name__}, {self.env}>"

    def __del__(self):
        """Run the wrapped environment's ``__del__``, if an environment was stored."""
        # Read the instance dict directly: ``env`` is absent when __init__
        # did not get as far as assigning it, and there is nothing to clean up.
        env = self.__dict__.get("env")
        if env is not None:
            env.__del__()