# Spaces:
# Running
# Running
"""Wrapper for flattening observations of an environment."""
import gym | |
import gym.spaces as spaces | |
class FlattenObservation(gym.ObservationWrapper):
    """Observation wrapper that flattens the observation.

    Example:
        >>> import gym
        >>> env = gym.make('CarRacing-v1')
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> env = FlattenObservation(env)
        >>> env.observation_space.shape
        (27648,)
        >>> obs = env.reset()
        >>> obs.shape
        (27648,)
    """

    def __init__(self, env: gym.Env):
        """Flattens the observations of an environment.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        # Advertise the flattened space to callers; the wrapped env keeps its
        # original (unflattened) space on ``self.env.observation_space``.
        self.observation_space = spaces.flatten_space(env.observation_space)

    def observation(self, observation):
        """Flattens an observation.

        Args:
            observation: The observation to flatten

        Returns:
            The flattened observation
        """
        # Flatten against the wrapped env's ORIGINAL space, not
        # ``self.observation_space`` -- that one was already replaced with the
        # flattened space in ``__init__`` and no longer describes the input.
        return spaces.flatten(self.env.observation_space, observation)