Spaces:
Running
Running
File size: 1,092 Bytes
375a1cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
"""Wrapper for flattening observations of an environment."""
import gym
import gym.spaces as spaces
class FlattenObservation(gym.ObservationWrapper):
    """Observation wrapper that flattens every observation of the wrapped env.

    The wrapper advertises ``spaces.flatten_space`` of the inner
    environment's observation space. Each observation is converted with
    ``spaces.flatten`` against the inner environment's own space.

    Example:
        >>> import gym
        >>> env = gym.make('CarRacing-v1')
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> env = FlattenObservation(env)
        >>> env.observation_space.shape
        (27648,)
        >>> obs = env.reset()
        >>> obs.shape
        (27648,)
    """

    def __init__(self, env: gym.Env):
        """Wrap ``env`` and expose a flattened observation space.

        Args:
            env: The environment whose observations should be flattened.
        """
        super().__init__(env)
        # The override must come after super().__init__ so the base
        # wrapper's setup does not overwrite it.
        original_space = env.observation_space
        flat_space = spaces.flatten_space(original_space)
        self.observation_space = flat_space

    def observation(self, observation):
        """Return ``observation`` flattened according to the inner env's space.

        Args:
            observation: An observation produced by the wrapped environment.

        Returns:
            The result of ``spaces.flatten`` applied to ``observation``.
        """
        # Flatten against the wrapped env's space. ``self.observation_space``
        # is the already-flattened one set in ``__init__`` and is not used.
        inner_space = self.env.observation_space
        flattened = spaces.flatten(inner_space, observation)
        return flattened
|