"""Wrapper that stacks frames.""" from collections import deque from typing import Union import numpy as np import gym from gym.error import DependencyNotInstalled from gym.spaces import Box class LazyFrames: """Ensures common frames are only stored once to optimize memory use. To further reduce the memory use, it is optionally to turn on lz4 to compress the observations. Note: This object should only be converted to numpy array just before forward pass. """ __slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames") def __init__(self, frames: list, lz4_compress: bool = False): """Lazyframe for a set of frames and if to apply lz4. Args: frames (list): The frames to convert to lazy frames lz4_compress (bool): Use lz4 to compress the frames internally Raises: DependencyNotInstalled: lz4 is not installed """ self.frame_shape = tuple(frames[0].shape) self.shape = (len(frames),) + self.frame_shape self.dtype = frames[0].dtype if lz4_compress: try: from lz4.block import compress except ImportError: raise DependencyNotInstalled( "lz4 is not installed, run `pip install gym[other]`" ) frames = [compress(frame) for frame in frames] self._frames = frames self.lz4_compress = lz4_compress def __array__(self, dtype=None): """Gets a numpy array of stacked frames with specific dtype. Args: dtype: The dtype of the stacked frames Returns: The array of stacked frames with dtype """ arr = self[:] if dtype is not None: return arr.astype(dtype) return arr def __len__(self): """Returns the number of frame stacks. Returns: The number of frame stacks """ return self.shape[0] def __getitem__(self, int_or_slice: Union[int, slice]): """Gets the stacked frames for a particular index or slice. Args: int_or_slice: Index or slice to get items for Returns: np.stacked frames for the int or slice """ if isinstance(int_or_slice, int): return self._check_decompress(self._frames[int_or_slice]) # single frame return np.stack( [self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0 ) def __eq__(self, other): """Checks that the current frames are equal to the other object.""" return self.__array__() == other def _check_decompress(self, frame): if self.lz4_compress: from lz4.block import decompress return np.frombuffer(decompress(frame), dtype=self.dtype).reshape( self.frame_shape ) return frame class FrameStack(gym.ObservationWrapper): """Observation wrapper that stacks the observations in a rolling manner. For example, if the number of stacks is 4, then the returned observation contains the most recent 4 observations. For environment 'Pendulum-v1', the original observation is an array with shape [3], so if we stack 4 observations, the processed observation has shape [4, 3]. Note: - To be memory efficient, the stacked observations are wrapped by :class:`LazyFrame`. - The observation space must be :class:`Box` type. If one uses :class:`Dict` as observation space, it should apply :class:`FlattenObservation` wrapper first. - After :meth:`reset` is called, the frame buffer will be filled with the initial observation. I.e. the observation returned by :meth:`reset` will consist of ``num_stack`-many identical frames, Example: >>> import gym >>> env = gym.make('CarRacing-v1') >>> env = FrameStack(env, 4) >>> env.observation_space Box(4, 96, 96, 3) >>> obs = env.reset() >>> obs.shape (4, 96, 96, 3) """ def __init__( self, env: gym.Env, num_stack: int, lz4_compress: bool = False, ): """Observation wrapper that stacks the observations in a rolling manner. Args: env (Env): The environment to apply the wrapper num_stack (int): The number of frames to stack lz4_compress (bool): Use lz4 to compress the frames internally """ super().__init__(env) self.num_stack = num_stack self.lz4_compress = lz4_compress self.frames = deque(maxlen=num_stack) low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0) high = np.repeat( self.observation_space.high[np.newaxis, ...], num_stack, axis=0 ) self.observation_space = Box( low=low, high=high, dtype=self.observation_space.dtype ) def observation(self, observation): """Converts the wrappers current frames to lazy frames. Args: observation: Ignored Returns: :class:`LazyFrames` object for the wrapper's frame buffer, :attr:`self.frames` """ assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack) return LazyFrames(list(self.frames), self.lz4_compress) def step(self, action): """Steps through the environment, appending the observation to the frame buffer. Args: action: The action to step through the environment with Returns: Stacked observations, reward, terminated, truncated, and information from the environment """ observation, reward, terminated, truncated, info = self.env.step(action) self.frames.append(observation) return self.observation(None), reward, terminated, truncated, info def reset(self, **kwargs): """Reset the environment with kwargs. Args: **kwargs: The kwargs for the environment reset Returns: The stacked observations """ obs, info = self.env.reset(**kwargs) [self.frames.append(obs) for _ in range(self.num_stack)] return self.observation(None), info