Spaces:
Running
Running
"""Wrapper that converts a color observation to grayscale.""" | |
import numpy as np | |
import gym | |
from gym.spaces import Box | |
class GrayScaleObservation(gym.ObservationWrapper):
    """Convert the image observation from RGB to gray scale.

    The wrapper replaces the environment's observation space with a
    ``uint8`` ``Box`` in ``[0, 255]`` that has the same height and width as
    the wrapped space but no colour axis (or a singleton colour axis when
    ``keep_dim`` is set).

    Example:
        >>> env = gym.make('CarRacing-v1')
        >>> env.observation_space
        Box(0, 255, (96, 96, 3), uint8)
        >>> env = GrayScaleObservation(gym.make('CarRacing-v1'))
        >>> env.observation_space
        Box(0, 255, (96, 96), uint8)
        >>> env = GrayScaleObservation(gym.make('CarRacing-v1'), keep_dim=True)
        >>> env.observation_space
        Box(0, 255, (96, 96, 1), uint8)
    """

    def __init__(self, env: gym.Env, keep_dim: bool = False):
        """Convert the image observation from RGB to gray scale.

        Args:
            env (Env): The environment to apply the wrapper
            keep_dim (bool): If `True`, a singleton dimension will be added,
                i.e. observations are of the shape AxBx1.
                Otherwise, they are of shape AxB.

        Raises:
            AssertionError: If the wrapped environment's observation space is
                not a ``Box`` with exactly three dimensions whose last
                dimension has size 3.
        """
        super().__init__(env)
        self.keep_dim = keep_dim

        source_space = self.observation_space
        # Raised explicitly instead of via an `assert` statement so the check
        # is not stripped when Python runs with -O. AssertionError is kept so
        # callers that already catch it keep working.
        if not (
            isinstance(source_space, Box)
            and len(source_space.shape) == 3
            and source_space.shape[-1] == 3
        ):
            raise AssertionError(
                "GrayScaleObservation expects a Box observation space of shape "
                f"(A, B, 3), got {source_space!r}"
            )

        height, width = source_space.shape[:2]
        # Only the shape differs between the two modes; bounds and dtype are
        # the same fixed uint8 range in both cases.
        gray_shape = (height, width, 1) if self.keep_dim else (height, width)
        self.observation_space = Box(
            low=0, high=255, shape=gray_shape, dtype=np.uint8
        )

    def observation(self, observation):
        """Converts the colour observation to greyscale.

        Args:
            observation: Color observations

        Returns:
            Grayscale observations, with a trailing singleton axis appended
            when ``keep_dim`` is set.
        """
        # Imported inside the method rather than at module level, so OpenCV
        # is only needed once an observation is actually converted.
        import cv2

        gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        if self.keep_dim:
            gray = np.expand_dims(gray, -1)
        return gray