File size: 449 Bytes
52dd602
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import gym
import gym.spaces
import numpy as np


class BinaryWrapper(gym.ObservationWrapper):
    """Observation wrapper that re-encodes an integer observation as bits.

    The wrapped environment's ``observation_space`` must expose an integer
    attribute ``n`` (the number of distinct observations, as
    ``gym.spaces.Discrete`` does).  Each observation is converted to its
    big-endian binary representation, zero-padded on the left to
    ``self.bits`` digits, and returned as a NumPy array of floats.
    """

    def __init__(self, env):
        """Wrap *env* and replace its observation space with a binary one.

        Args:
            env: Environment whose ``observation_space.n`` gives the number
                of distinct integer observations.
        """
        super(BinaryWrapper, self).__init__(env)
        count = int(env.observation_space.n)
        # Width needed to write the largest observation (count - 1) in
        # binary.  Exact integer arithmetic is used instead of
        # ceil(log2(count)): the floating-point logarithm can round to a
        # whole number for very large counts and come out one bit short.
        # The width is clamped to 1 because a single-valued space still
        # produces the one-digit string "0" in observation().
        self.bits = max(1, (count - 1).bit_length())
        self.observation_space = gym.spaces.MultiBinary(self.bits)

    def observation(self, obs):
        """Return *obs* as an array of ``0.0``/``1.0`` binary digits.

        Args:
            obs: Integer-convertible observation from the wrapped env.

        Returns:
            A 1-D float ``numpy.ndarray``, most significant bit first,
            left-padded with zeros to ``self.bits`` entries.
        """
        digits = format(int(obs), "b").zfill(self.bits)
        # Floats (not ints) are kept so the output dtype matches what
        # existing consumers of this wrapper already receive.
        return np.array([float(digit) for digit in digits])