# Captured from a "Spaces" file viewer (status: Paused).
# File size: 11,487 bytes; revision: 375a1cf.
"""Base class for vectorized environments."""
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import gym
from gym.vector.utils.spaces import batch_space
# Public names exported by ``from <this module> import *``.
# NOTE(review): VectorEnvWrapper is also defined in this module but is not listed here;
# confirm whether leaving it out of the star-export is intentional.
__all__ = ["VectorEnv"]
class VectorEnv(gym.Env):
    """Base class for vectorized environments. Runs multiple independent copies of the same environment in parallel.

    This is not the same as 1 environment that has multiple subcomponents, but it is many copies of the same base env.

    Each observation returned from vectorized environment is a batch of observations for each parallel environment.
    And :meth:`step` is also expected to receive a batch of actions for each parallel environment.

    Notes:
        All parallel environments should share the identical observation and action spaces.
        In other words, a vector of multiple different environments is not supported.
    """

    def __init__(
        self,
        num_envs: int,
        observation_space: gym.Space,
        action_space: gym.Space,
    ):
        """Base class for vectorized environments.

        Args:
            num_envs: Number of environments in the vectorized environment.
            observation_space: Observation space of a single environment.
            action_space: Action space of a single environment.
        """
        self.num_envs = num_envs
        self.is_vector_env = True
        # The public spaces describe the whole batch: the single-environment spaces
        # passed through ``batch_space`` with ``n=num_envs``.
        self.observation_space = batch_space(observation_space, n=num_envs)
        self.action_space = batch_space(action_space, n=num_envs)

        self.closed = False
        self.viewer = None

        # The observation and action spaces of a single environment are
        # kept in separate properties
        self.single_observation_space = observation_space
        self.single_action_space = action_space

    def reset_async(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        options: Optional[dict] = None,
    ):
        """Reset the sub-environments asynchronously.

        This method will return ``None``. A call to :meth:`reset_async` should be followed
        by a call to :meth:`reset_wait` to retrieve the results.

        The base implementation does nothing; subclasses override it.

        Args:
            seed: The reset seed
            options: Reset options
        """
        pass

    def reset_wait(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        options: Optional[dict] = None,
    ):
        """Retrieves the results of a :meth:`reset_async` call.

        A call to this method must always be preceded by a call to :meth:`reset_async`.

        Args:
            seed: The reset seed
            options: Reset options

        Returns:
            The results from :meth:`reset_async`

        Raises:
            NotImplementedError: VectorEnv does not implement function
        """
        raise NotImplementedError("VectorEnv does not implement function")

    def reset(
        self,
        *,
        seed: Optional[Union[int, List[int]]] = None,
        options: Optional[dict] = None,
    ):
        """Reset all parallel environments and return a batch of initial observations.

        Calls :meth:`reset_async` and then :meth:`reset_wait`, passing ``seed`` and
        ``options`` to both.

        Args:
            seed: The environment reset seeds
            options: Reset options, forwarded to :meth:`reset_async` and :meth:`reset_wait`

        Returns:
            Whatever :meth:`reset_wait` returns (the batch of initial observations
            from the vectorized environment).
        """
        self.reset_async(seed=seed, options=options)
        return self.reset_wait(seed=seed, options=options)

    def step_async(self, actions):
        """Asynchronously performs steps in the sub-environments.

        The results can be retrieved via a call to :meth:`step_wait`.
        The base implementation does nothing and returns ``None``; subclasses override it.

        Args:
            actions: The actions to take asynchronously
        """

    def step_wait(self, **kwargs):
        """Retrieves the results of a :meth:`step_async` call.

        A call to this method must always be preceded by a call to :meth:`step_async`.
        The base implementation does nothing and returns ``None``; subclasses override it.

        Args:
            **kwargs: Additional keywords for vector implementation

        Returns:
            The results from the :meth:`step_async` call
        """

    def step(self, actions):
        """Take an action for each parallel environment.

        Calls :meth:`step_async` with ``actions`` and then :meth:`step_wait`.

        Args:
            actions: element of :attr:`action_space` Batch of actions.

        Returns:
            Batch of (observations, rewards, terminated, truncated, infos) or (observations, rewards, dones, infos)
        """
        self.step_async(actions)
        return self.step_wait()

    def call_async(self, name, *args, **kwargs):
        """Calls a method name for each parallel environment asynchronously.

        The base implementation does nothing and returns ``None``; subclasses override it.
        """

    def call_wait(self, **kwargs) -> List[Any]:  # type: ignore
        """After calling a method in :meth:`call_async`, this function collects the results.

        The base implementation does nothing and returns ``None``; subclasses override it.
        """

    def call(self, name: str, *args, **kwargs) -> List[Any]:
        """Call a method, or get a property, from each parallel environment.

        Args:
            name (str): Name of the method or property to call.
            *args: Arguments to apply to the method call.
            **kwargs: Keyword arguments to apply to the method call.

        Returns:
            List of the results of the individual calls to the method or property for each environment.
        """
        self.call_async(name, *args, **kwargs)
        return self.call_wait()

    def get_attr(self, name: str):
        """Get a property from each parallel environment.

        Implemented as ``self.call(name)`` with no extra arguments.

        Args:
            name (str): Name of the property to get from each individual environment.

        Returns:
            The value of the property for each environment, as returned by :meth:`call`.
        """
        return self.call(name)

    def set_attr(self, name: str, values: Union[list, tuple, object]):
        """Set a property in each sub-environment.

        The base implementation does nothing and returns ``None``; subclasses override it.

        Args:
            name (str): Name of the property to be set in each individual environment.
            values (list, tuple, or object): Values of the property to be set to. If `values` is a list or
                tuple, then it corresponds to the values for each individual environment, otherwise a single value
                is set for all environments.
        """

    def close_extras(self, **kwargs):
        """Clean up the extra resources e.g. beyond what's in this base class."""
        pass

    def close(self, **kwargs):
        """Close all parallel environments and release resources.

        It also closes all the existing image viewers, then calls :meth:`close_extras` and set
        :attr:`closed` as ``True``.

        Warnings:
            This function itself does not close the environments, it should be handled
            in :meth:`close_extras`. This is generic for both synchronous and asynchronous
            vectorized environments.

        Notes:
            This will be automatically called when garbage collected or program exited.

        Args:
            **kwargs: Keyword arguments passed to :meth:`close_extras`
        """
        # Idempotent: a second call (e.g. from __del__) is a no-op.
        if self.closed:
            return
        if self.viewer is not None:
            self.viewer.close()
        self.close_extras(**kwargs)
        self.closed = True

    def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
        """Add env info to the info dictionary of the vectorized environment.

        Given the `info` of a single environment add it to the `infos` dictionary
        which represents all the infos of the vectorized environment.
        Every `key` of `info` is paired with a boolean mask `_key` representing
        whether or not the i-indexed environment has this `info`.

        The dtype of each info array is chosen from the first value reported for that
        key. If a later environment reports a value that does not fit (e.g. a ``float``
        after an ``int``, or ``None`` after a number), the array is widened first so the
        value is neither truncated nor rejected.

        Args:
            infos (dict): the infos of the vectorized environment (updated in place)
            info (dict): the info coming from the single environment
            env_num (int): the index of the single environment

        Returns:
            infos (dict): the (updated) infos of the vectorized environment
        """
        for key, value in info.items():
            if key in infos:
                info_array, array_mask = infos[key], infos[f"_{key}"]
            else:
                info_array, array_mask = self._init_info_arrays(type(value))
            info_array = self._fit_info_array(info_array, array_mask, value)
            info_array[env_num], array_mask[env_num] = value, True
            infos[key], infos[f"_{key}"] = info_array, array_mask
        return infos

    @staticmethod
    def _is_numeric_info_type(dtype: type) -> bool:
        """Return whether values of ``dtype`` are stored in a typed (non-object) info array."""
        return dtype in [int, float, bool] or issubclass(dtype, np.number)

    def _fit_info_array(self, info_array: np.ndarray, array_mask: np.ndarray, value: Any) -> np.ndarray:
        """Return an info array able to hold ``value`` without truncation or error.

        Numeric values promote a typed array with :func:`numpy.result_type`. Any other
        value converts it to an ``object`` array; slots of environments that have not
        reported the key yet are reset to ``None``, matching :meth:`_init_info_arrays`.

        Args:
            info_array: the existing info array for one key
            array_mask: boolean mask of which environments have reported that key
            value: the value about to be stored

        Returns:
            ``info_array`` itself when it already fits, otherwise a widened copy.
        """
        if not isinstance(info_array, np.ndarray) or info_array.dtype.kind == "O":
            # Object arrays (and non-array containers) already accept any value.
            return info_array
        if self._is_numeric_info_type(type(value)):
            try:
                target = np.result_type(info_array.dtype, np.asarray(value).dtype)
            except (TypeError, OverflowError):
                # No common numeric dtype (e.g. a string-typed array or a huge int).
                target = np.dtype(object)
        else:
            target = np.dtype(object)
        if target == info_array.dtype:
            return info_array
        widened = info_array.astype(target)
        if target.kind == "O":
            # Zeros copied from the numeric array must not look like reported data.
            widened[~np.asarray(array_mask, dtype=bool)] = None
        return widened

    def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]:
        """Initialize the info array.

        Initialize the info array. If the dtype is numeric
        the info array will have the same dtype, otherwise
        will be an array of `None`. Also, a boolean array
        of the same length is returned. It will be used for
        assessing which environment has info data.

        Args:
            dtype (type): data type of the info coming from the env.

        Returns:
            array (np.ndarray): the initialized info array.
            array_mask (np.ndarray): the initialized boolean array.
        """
        if self._is_numeric_info_type(dtype):
            array = np.zeros(self.num_envs, dtype=dtype)
        else:
            array = np.zeros(self.num_envs, dtype=object)
            array[:] = None
        array_mask = np.zeros(self.num_envs, dtype=bool)
        return array, array_mask

    def __del__(self):
        """Closes the vector environment."""
        # ``closed`` may be missing if __init__ did not run to completion; skip in that case.
        if not getattr(self, "closed", True):
            self.close()

    def __repr__(self) -> str:
        """Returns a string representation of the vector environment.

        Returns:
            A string containing the class name, number of environments and environment spec id
        """
        if self.spec is None:
            return f"{self.__class__.__name__}({self.num_envs})"
        else:
            return f"{self.__class__.__name__}({self.spec.id}, {self.num_envs})"
class VectorEnvWrapper(VectorEnv):
    """Wraps the vectorized environment to allow a modular transformation.

    This class is the base class for all wrappers for vectorized environments. The subclass
    could override some methods to change the behavior of the original vectorized environment
    without touching the original code.

    Notes:
        Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
    """

    def __init__(self, env: VectorEnv):
        """Wrap ``env``.

        ``VectorEnv.__init__`` is not called; public attributes the wrapper does not
        define itself (``num_envs``, ``closed``, ...) are looked up on ``self.env``
        through :meth:`__getattr__`.

        Args:
            env: The vectorized environment to wrap; must be a :class:`VectorEnv`.
        """
        assert isinstance(env, VectorEnv), f"expected a VectorEnv, got {type(env).__name__}"
        self.env = env

    # explicitly forward the methods defined in VectorEnv
    # to self.env (instead of the base class)

    def reset_async(self, **kwargs):
        return self.env.reset_async(**kwargs)

    def reset_wait(self, **kwargs):
        return self.env.reset_wait(**kwargs)

    def step_async(self, actions):
        return self.env.step_async(actions)

    def step_wait(self, **kwargs):
        # VectorEnv.step_wait accepts **kwargs; forward them instead of dropping them.
        return self.env.step_wait(**kwargs)

    def close(self, **kwargs):
        return self.env.close(**kwargs)

    def close_extras(self, **kwargs):
        return self.env.close_extras(**kwargs)

    def call_async(self, name, *args, **kwargs):
        # Without this override the base-class no-op would be used instead of self.env.
        return self.env.call_async(name, *args, **kwargs)

    def call_wait(self, **kwargs):
        # Without this override the base-class no-op (returning None) would be used.
        return self.env.call_wait(**kwargs)

    def call(self, name, *args, **kwargs):
        return self.env.call(name, *args, **kwargs)

    def set_attr(self, name, values):
        return self.env.set_attr(name, values)

    # implicitly forward all other methods and attributes to self.env
    def __getattr__(self, name):
        if name.startswith("_"):
            raise AttributeError(f"attempted to get missing private attribute '{name}'")
        if name == "env":
            # ``env`` is only missing when __init__ did not complete; forwarding the
            # lookup to ``self.env`` would re-enter __getattr__ without end.
            raise AttributeError("wrapped environment 'env' has not been set")
        return getattr(self.env, name)

    @property
    def unwrapped(self):
        return self.env.unwrapped

    def __repr__(self):
        return f"<{self.__class__.__name__}, {self.env}>"

    def __del__(self):
        # Read the instance dict directly so a wrapper whose __init__ failed
        # (e.g. on the isinstance assertion) does not trigger attribute forwarding.
        env = self.__dict__.get("env")
        if env is not None:
            env.__del__()
|