import gym
import gym.spaces
import numpy as np


class BinaryWrapper(gym.ObservationWrapper):
    """Observation wrapper that re-encodes an integer observation as bits.

    The wrapped environment's observation space must expose an integer
    attribute ``n`` (the number of distinct observations). The wrapper
    replaces the observation space with ``gym.spaces.MultiBinary(bits)``,
    where ``bits`` is the number of binary digits needed to represent every
    value in ``0 .. n - 1`` (and never fewer than one).

    Attributes:
        bits: Width, in binary digits, of the encoded observation.
    """

    def __init__(self, env):
        """Wrap *env* and derive the bit width from ``env.observation_space.n``.

        Args:
            env: Environment whose ``observation_space`` has an ``n``
                attribute convertible to ``int``.
        """
        super(BinaryWrapper, self).__init__(env)
        n = int(env.observation_space.n)
        # Exact integer arithmetic instead of ceil(log2(n)): no floating-point
        # rounding for large n. The lower bound of 1 matters when n == 1,
        # where log2 would yield a zero-width space even though
        # ``observation`` always emits at least one digit.
        self.bits = max(1, (n - 1).bit_length())
        self.observation_space = gym.spaces.MultiBinary(self.bits)

    def observation(self, obs):
        """Return *obs* as an array of binary digits, most significant first.

        Args:
            obs: Observation convertible to ``int``.

        Returns:
            A NumPy array of floats (each ``0.0`` or ``1.0``), left-padded
            with zeros to ``self.bits`` entries. Padding never truncates, so
            a value needing more than ``self.bits`` digits yields a longer
            array.

        Raises:
            ValueError: If ``int(obs)`` is negative, because the sign
                character cannot be converted to a digit.
        """
        digits = "{0:b}".format(int(obs)).zfill(self.bits)
        return np.array([float(digit) for digit in digits])