Spaces:
Sleeping
Sleeping
File size: 2,073 Bytes
e0c3c75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import multiprocessing
import time
import gymnasium as gym
import numpy as np
from gymnasium.envs.toy_text.frozen_lake import generate_random_map
from DPAgent import DPAgent
from MCAgent import MCAgent
# Gymnasium environment id shared by every environment and agent built in this script.
env_ver = "FrozenLake-v1"
def test_mc(i, seed):
    """Train an MCAgent on a random non-slippery FrozenLake map and time it.

    Args:
        i: Side length of the square map handed to ``generate_random_map``.
        seed: Seed used for map generation.

    Returns:
        A ``(trained, seconds)`` tuple: the value returned by
        ``MCAgent.train`` and the wall-clock duration of that call measured
        with ``time.perf_counter``. Callers in this file use ``trained`` as
        a truthy success flag -- confirm against MCAgent.
    """
    lake_map = generate_random_map(size=i, p=0.4, seed=seed)
    lake_env = gym.make(env_ver, desc=lake_map, is_slippery=False)

    # The agent is constructed from the env id, then re-pointed at the
    # custom-map environment and re-initialised for its state/action sizes.
    mc_agent = MCAgent(env=env_ver, gamma=1.0, epsilon=0.4)
    mc_agent.env = lake_env
    state_count = mc_agent.env.observation_space.n
    action_count = mc_agent.env.action_space.n
    mc_agent.n_states = state_count
    mc_agent.n_actions = action_count
    mc_agent.initialize()

    started = time.perf_counter()
    trained = mc_agent.train(
        # Step budget per episode scales with the number of tiles (i * i).
        max_steps=int((i**2) * 3),
        n_train_episodes=10_000,
        save_best=False,
        early_stopping=True,
        update_type="every_visit",
    )
    elapsed = time.perf_counter() - started
    return trained, elapsed
def test_dp(i, seed, gamma=0.99):
    """Train a DPAgent on a random non-slippery FrozenLake map.

    Args:
        i: Side length of the square map handed to ``generate_random_map``.
        seed: Seed used for map generation.
        gamma: Discount factor forwarded to ``DPAgent``.

    Returns:
        Whatever ``DPAgent.train()`` returns. Callers in this file store it
        as the DP training time -- confirm against DPAgent.
    """
    # NOTE(review): test_mc generates its map with p=0.4 while this call uses
    # the library default for ``p``, so the two functions are not guaranteed
    # to build the same map for the same seed -- confirm this is intended.
    env = gym.make(env_ver, desc=generate_random_map(i, seed=seed), is_slippery=False)

    # The agent is constructed from the env id, then re-pointed at the
    # custom-map environment; its tables are resized to match below.
    agent = DPAgent(env=env_ver, gamma=gamma)
    agent.env = env

    n_states = agent.env.observation_space.n
    n_actions = agent.env.action_space.n

    agent.V = np.zeros(n_states)
    # The shape must be a single tuple: np.zeros(n_states, n_actions) would
    # treat n_actions as the dtype argument and raise TypeError.
    agent.Pi = np.zeros((n_states, n_actions))
    agent.n_states, agent.n_actions = n_states, n_actions
    return agent.train()
def run_test(i):
    """Time MC and DP training on freshly seeded maps until the MC run succeeds.

    Each attempt draws a new seed from NumPy's global RNG, runs ``test_mc``
    and then ``test_dp`` with that seed, and stops as soon as ``test_mc``
    reports a truthy ``trained`` value.

    Args:
        i: Side length of the square map forwarded to both test functions.

    Returns:
        A ``(mc_time, dp_time)`` tuple taken from the final attempt.
    """
    while True:
        seed = np.random.randint(0, 100000)
        mc_trained, mc_time = test_mc(i, seed)
        # DP runs on every attempt, including those where MC did not train.
        dp_time = test_dp(i, seed)
        if mc_trained:
            return mc_time, dp_time
def run_exp(gamma):
    """Run ``test_dp`` over a sweep of map sizes and save the results.

    Map sizes run from 8 up to (but not including) 512 in steps of 8, all
    with seed 0. The ``(size, result)`` pairs are written to
    ``times_{gamma}.npy`` in the current working directory.

    Args:
        gamma: Discount factor forwarded to ``test_dp`` and embedded in the
            output file name.
    """
    results = [
        (size, test_dp(size, 0, gamma=gamma))
        for size in range(8, 512, 8)
    ]
    np.save(f"times_{gamma}.npy", np.array(results))
def main():
    """Run ``run_exp`` once per discount factor on a pool of 8 worker processes."""
    gammas = [0.95, 0.9, 0.8, 0.7, 0.6, 0.5, 0.1]
    with multiprocessing.Pool(processes=8) as pool:
        # map() blocks until every run_exp call has finished; results are
        # discarded because run_exp writes its own output file.
        pool.map(run_exp, gammas)
# Script entry point: start the experiment sweep only when run directly.
# multiprocessing may re-import this module in worker processes; the guard
# keeps those imports from launching main() again.
if __name__ == "__main__":
    main()
|