qdqn-FrozenLake-v1 / model /wrappers.py
Arnas
refactor
e7650e8
raw
history blame
449 Bytes
import gym
import gym.spaces
import numpy as np
class BinaryWrapper(gym.ObservationWrapper):
    """Encode a discrete observation as a fixed-width vector of binary digits.

    The wrapped environment's observation space must expose an ``n``
    attribute (the number of distinct observations). Each integer
    observation is rewritten as its base-2 representation, most
    significant bit first, left-padded with zeros to ``self.bits`` entries.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, env):
        """Wrap *env* and replace its observation space with a MultiBinary one.

        Args:
            env: Environment whose ``observation_space.n`` gives the number
                of distinct integer observations.
        """
        super().__init__(env)
        n_states = int(env.observation_space.n)
        # Number of binary digits needed to write the largest index
        # (n_states - 1). Integer arithmetic avoids floating-point rounding
        # in log2, and the floor of 1 keeps a single-state environment from
        # producing a zero-width space while observation() still emits "0".
        self.bits = max(1, (n_states - 1).bit_length())
        self.observation_space = gym.spaces.MultiBinary(self.bits)

    def observation(self, obs):
        """Return *obs* as a 1-D float array of 0.0/1.0 values.

        Args:
            obs: Observation convertible with ``int()``.

        Returns:
            ``numpy.ndarray`` of dtype float64 holding the binary digits of
            ``int(obs)``, most significant bit first, zero-padded on the
            left to at least ``self.bits`` entries.
        """
        digits = format(int(obs), "b").zfill(self.bits)
        # Values are deliberately floats, matching the original behaviour
        # (presumably so they can be fed straight to the model -- confirm).
        return np.array([float(d) for d in digits], dtype=np.float64)