File size: 449 Bytes
52dd602 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import gym
import gym.spaces
import numpy as np
class BinaryWrapper(gym.ObservationWrapper):
    """Re-encode an integer observation as a fixed-width vector of bits.

    The wrapped environment's ``observation_space`` must expose an integer
    ``n`` (as ``gym.spaces.Discrete`` does). Each observation is replaced by
    its binary representation, most significant bit first, left-padded with
    zeros to ``self.bits`` entries. The wrapper's observation space becomes
    ``gym.spaces.MultiBinary(self.bits)``.
    """

    def __init__(self, env):
        super(BinaryWrapper, self).__init__(env)
        n = int(env.observation_space.n)
        # The largest observation is n - 1, so that many bits are needed.
        # int.bit_length is exact, whereas ceil(np.log2(n)) can round down for
        # very large n. max(..., 1) keeps a single-state space (n == 1) from
        # getting a zero-width encoding while observation 0 still renders as
        # one bit.
        self.bits = max(1, (n - 1).bit_length())
        self.observation_space = gym.spaces.MultiBinary(self.bits)

    def observation(self, obs):
        """Return ``obs`` as a float array of ``self.bits`` bits, MSB first.

        Args:
            obs: An integer-convertible observation from the wrapped env.

        Returns:
            A 1-D float ``np.ndarray`` of length ``self.bits`` whose entries
            are 0.0 or 1.0.

        Raises:
            ValueError: If ``obs`` is negative or needs more than
                ``self.bits`` bits. Such a value would otherwise produce a
                vector whose length does not match ``observation_space``.
        """
        value = int(obs)
        if value < 0 or value.bit_length() > self.bits:
            raise ValueError(
                "observation %d does not fit in %d bits" % (value, self.bits))
        digits = format(value, "b").zfill(self.bits)
        return np.array([float(d) for d in digits])
|