"""Wrapper for resizing observations.""" from typing import Union import numpy as np import gym from gym.error import DependencyNotInstalled from gym.spaces import Box class ResizeObservation(gym.ObservationWrapper): """Resize the image observation. This wrapper works on environments with image observations (or more generally observations of shape AxBxC) and resizes the observation to the shape given by the 2-tuple :attr:`shape`. The argument :attr:`shape` may also be an integer. In that case, the observation is scaled to a square of side-length :attr:`shape`. Example: >>> import gym >>> env = gym.make('CarRacing-v1') >>> env.observation_space.shape (96, 96, 3) >>> env = ResizeObservation(env, 64) >>> env.observation_space.shape (64, 64, 3) """ def __init__(self, env: gym.Env, shape: Union[tuple, int]): """Resizes image observations to shape given by :attr:`shape`. Args: env: The environment to apply the wrapper shape: The shape of the resized observations """ super().__init__(env) if isinstance(shape, int): shape = (shape, shape) assert all(x > 0 for x in shape), shape self.shape = tuple(shape) assert isinstance( env.observation_space, Box ), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}" obs_shape = self.shape + env.observation_space.shape[2:] self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8) def observation(self, observation): """Updates the observations by resizing the observation to shape given by :attr:`shape`. Args: observation: The observation to reshape Returns: The reshaped observations Raises: DependencyNotInstalled: opencv-python is not installed """ try: import cv2 except ImportError: raise DependencyNotInstalled( "opencv is not install, run `pip install gym[other]`" ) observation = cv2.resize( observation, self.shape[::-1], interpolation=cv2.INTER_AREA ) if observation.ndim == 2: observation = np.expand_dims(observation, -1) return observation