"""Wrapper that converts a color observation to grayscale."""
import numpy as np

import gym
from gym.spaces import Box


class GrayScaleObservation(gym.ObservationWrapper):
    """Convert the image observation from RGB to gray scale.

    Example:
        >>> env = gym.make('CarRacing-v1')
        >>> env.observation_space
        Box(0, 255, (96, 96, 3), uint8)
        >>> env = GrayScaleObservation(gym.make('CarRacing-v1'))
        >>> env.observation_space
        Box(0, 255, (96, 96), uint8)
        >>> env = GrayScaleObservation(gym.make('CarRacing-v1'), keep_dim=True)
        >>> env.observation_space
        Box(0, 255, (96, 96, 1), uint8)
    """

    def __init__(self, env: gym.Env, keep_dim: bool = False):
        """Convert the image observation from RGB to gray scale.

        Args:
            env (Env): The environment to apply the wrapper. Its observation
                space must be a :class:`Box` of shape ``(H, W, 3)``.
            keep_dim (bool): If `True`, a singleton dimension will be added, i.e.
                observations are of the shape AxBx1. Otherwise, they are of shape AxB.

        Raises:
            AssertionError: If the wrapped observation space is not a ``Box``
                of shape ``(H, W, 3)``.
        """
        super().__init__(env)
        self.keep_dim = keep_dim

        # Validate explicitly rather than with a bare ``assert`` so the check is
        # not stripped under ``python -O``. AssertionError is kept as the
        # exception type for backward compatibility with the assert-based check.
        space = self.observation_space
        if not (
            isinstance(space, Box)
            and len(space.shape) == 3
            and space.shape[-1] == 3
        ):
            raise AssertionError(
                "GrayScaleObservation expects a Box observation space of shape "
                f"(H, W, 3), got {space!r}"
            )

        height, width = space.shape[:2]
        new_shape = (height, width, 1) if self.keep_dim else (height, width)
        # NOTE: the output space is always declared as uint8 in [0, 255],
        # independent of the wrapped space's dtype and bounds.
        self.observation_space = Box(low=0, high=255, shape=new_shape, dtype=np.uint8)

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Converts the colour observation to greyscale.

        Args:
            observation: Color observation of shape ``(H, W, 3)`` in RGB
                channel order.

        Returns:
            Grayscale observation of shape ``(H, W)``, or ``(H, W, 1)`` when
            ``keep_dim`` is set.
        """
        # OpenCV is imported lazily, so it is only needed once observations
        # are actually converted.
        import cv2

        observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        if self.keep_dim:
            # Re-add a trailing channel axis: (H, W) -> (H, W, 1).
            observation = np.expand_dims(observation, -1)
        return observation