Code refactor
Move the Worker class to its own file and fold the related utils functions into it as methods.
Move main.py's evaluation methods into the a3c module.
- a3c/discrete_A3C.py +55 -40
- a3c/utils.py +0 -66
- a3c/worker.py +107 -0
- main.py +2 -53
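
For orientation, a minimal sketch of how the refactored modules are called after this change; the env construction, episode count, and checkpoint path are illustrative assumptions, not part of the diff:

# Sketch only: assumes `env` is an already-built Wordle gym environment
# (main.py constructs it via gym / wordle_env) and that the checkpoint
# directory exists. The values below are illustrative.
from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints

max_ep = 5000                          # illustrative
model_checkpoint_dir = "checkpoints"   # illustrative

# train() spawns one Worker process per CPU (see a3c/worker.py below)
global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)

# evaluate the final global net, or rank every saved checkpoint
evaluate(gnet, env)
print(evaluate_checkpoints(model_checkpoint_dir, env))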
a3c/discrete_A3C.py
CHANGED
@@ -5,48 +5,12 @@ The most simple implementation for continuous action.
 View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
 """
 import os
+import torch
 import torch.multiprocessing as mp
-from .utils import v_wrap, push_and_pull, record, save_model
 from .shared_adam import SharedAdam
 from .net import Net
-
-
-
-class Worker(mp.Process):
-    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir):
-        super(Worker, self).__init__()
-        self.max_ep = max_ep
-        self.name = 'w%02i' % name
-        self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
-        self.gnet, self.opt = gnet, opt
-        self.word_list = words_list
-        self.lnet = Net(N_S, N_A, words_list, word_width)  # local network
-        self.env = env.unwrapped
-        self.model_checkpoint_dir = model_checkpoint_dir
-
-    def run(self):
-        while self.g_ep.value < self.max_ep:
-            s = self.env.reset()
-            buffer_s, buffer_a, buffer_r = [], [], []
-            ep_r = 0.
-            while True:
-                a = self.lnet.choose_action(v_wrap(s[None, :]))
-                s_, r, done, _ = self.env.step(a)
-                ep_r += r
-                buffer_a.append(a)
-                buffer_s.append(s)
-                buffer_r.append(r)
-
-                if done:  # update global and assign to local net
-                    # sync
-                    push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
-                    goal_word = self.word_list[self.env.goal_word]
-                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
-                    save_model(self.gnet, self.model_checkpoint_dir, self.g_ep.value, self.g_ep_r.value)
-                    buffer_s, buffer_a, buffer_r = [], [], []
-                    break
-                s = s_
-        self.res_queue.put(None)
+from .utils import v_wrap
+from .worker import Worker
 
 
 def train(env, max_ep, model_checkpoint_dir):
@@ -63,7 +27,8 @@ def train(env, max_ep, model_checkpoint_dir):
     global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
 
     # parallel training
-    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep, model_checkpoint_dir) for i in range(mp.cpu_count())]
+    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
+                      words_list, word_width, win_ep, model_checkpoint_dir) for i in range(mp.cpu_count())]
     [w.start() for w in workers]
     res = []  # record episode reward to plot
     while True:
@@ -74,3 +39,53 @@ def train(env, max_ep, model_checkpoint_dir):
             break
     [w.join() for w in workers]
     return global_ep, win_ep, gnet, res
+
+
+def evaluate_checkpoints(dir, env):
+    n_s = env.observation_space.shape[0]
+    n_a = env.action_space.n
+    words_list = env.words
+    word_width = len(env.words[0])
+    net = Net(n_s, n_a, words_list, word_width)
+    results = {}
+    for checkpoint in os.listdir(dir):
+        checkpoint_path = os.path.join(dir, checkpoint)
+        if os.path.isfile(checkpoint_path):
+            net.load_state_dict(torch.load(checkpoint_path))
+            wins, guesses = evaluate(net, env)
+            results[checkpoint] = wins, guesses
+    return dict(sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True))
+
+
+def evaluate(net, env):
+    n_wins = 0
+    n_guesses = 0
+    n_win_guesses = 0
+    env = env.unwrapped
+    N = env.allowable_words
+    for goal_word in env.words[:N]:
+        win, outcomes = play(net, env)
+        if win:
+            n_wins += 1
+            n_win_guesses += len(outcomes)
+        # else:
+        #     print("Lost!", goal_word, outcomes)
+        n_guesses += len(outcomes)
+    print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
+          f"{n_guesses / N} including losses.")
+    return n_wins/N*100, n_win_guesses/n_wins
+
+
+def play(net, env):
+    state = env.reset()
+    outcomes = []
+    win = False
+    for i in range(env.max_turns):
+        action = net.choose_action(v_wrap(state[None, :]))
+        state, reward, done, _ = env.step(action)
+        outcomes.append((env.words[action], reward))
+        if done:
+            if reward >= 0:
+                win = True
+            break
+    return win, outcomes
a3c/utils.py
CHANGED
@@ -1,8 +1,3 @@
-"""
-Functions that use multiple times
-"""
-import os
-from torch import nn
 import torch
 import numpy as np
 
@@ -11,64 +6,3 @@ def v_wrap(np_array, dtype=np.float32):
     if np_array.dtype != dtype:
         np_array = np_array.astype(dtype)
     return torch.from_numpy(np_array)
-
-
-def set_init(layers):
-    for layer in layers:
-        nn.init.normal_(layer.weight, mean=0., std=0.1)
-        nn.init.constant_(layer.bias, 0.)
-
-
-def push_and_pull(opt, lnet, gnet, done, s_, bs, ba, br, gamma):
-    if done:
-        v_s_ = 0.  # terminal
-    else:
-        v_s_ = lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
-
-    buffer_v_target = []
-    for r in br[::-1]:  # reverse buffer r
-        v_s_ = r + gamma * v_s_
-        buffer_v_target.append(v_s_)
-    buffer_v_target.reverse()
-
-    loss = lnet.loss_func(
-        v_wrap(np.vstack(bs)),
-        v_wrap(np.array(ba), dtype=np.int64) if ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
-        v_wrap(np.array(buffer_v_target)[:, None]))
-
-    # calculate local gradients and push local parameters to global
-    opt.zero_grad()
-    loss.backward()
-    for lp, gp in zip(lnet.parameters(), gnet.parameters()):
-        gp._grad = lp.grad
-    opt.step()
-
-    # pull global parameters
-    lnet.load_state_dict(gnet.state_dict())
-
-
-def save_model(gnet, dir, episode, reward):
-    if reward >= 9 and episode % 100 == 0:
-        torch.save(gnet.state_dict(), os.path.join(dir, f'model_{episode}.pth'))
-
-
-def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, action_number, winning_ep):
-    with global_ep.get_lock():
-        global_ep.value += 1
-    with global_ep_r.get_lock():
-        if global_ep_r.value == 0.:
-            global_ep_r.value = ep_r
-        else:
-            global_ep_r.value = global_ep_r.value * 0.99 + ep_r * 0.01
-    res_queue.put(global_ep_r.value)
-    if goal_word == action:
-        winning_ep.value += 1
-    if global_ep.value % 100 == 0:
-        print(
-            name,
-            "Ep:", global_ep.value,
-            "| Ep_r: %.0f" % global_ep_r.value,
-            "| Goal :", goal_word,
-            "| Action: ", action,
-            "| Actions: ", action_number
-        )
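
After this change a3c/utils.py keeps only v_wrap, which casts a NumPy array to float32 (when needed) and wraps it as a torch tensor. A minimal usage sketch, with illustrative values:

import numpy as np
from a3c.utils import v_wrap

state = np.zeros(4)              # float64 by default
t = v_wrap(state)                # cast to float32, returned as a torch tensor
batch = v_wrap(state[None, :])   # the workers add a batch dimension like this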
a3c/worker.py
ADDED
@@ -0,0 +1,107 @@
+"""
+Worker class implementation of the a3c discrete algorithm
+"""
+import os
+import torch
+import numpy as np
+import torch.multiprocessing as mp
+from torch import nn
+from .net import Net
+from .utils import v_wrap
+
+
+GAMMA = 0.65
+
+
+class Worker(mp.Process):
+    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir):
+        super(Worker, self).__init__()
+        self.max_ep = max_ep
+        self.name = 'w%02i' % name
+        self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
+        self.gnet, self.opt = gnet, opt
+        self.word_list = words_list
+        # local network
+        self.lnet = Net(N_S, N_A, words_list, word_width)
+        self.env = env.unwrapped
+        self.model_checkpoint_dir = model_checkpoint_dir
+
+    def run(self):
+        while self.g_ep.value < self.max_ep:
+            s = self.env.reset()
+            buffer_s, buffer_a, buffer_r = [], [], []
+            ep_r = 0.
+            while True:
+                a = self.lnet.choose_action(v_wrap(s[None, :]))
+                s_, r, done, _ = self.env.step(a)
+                ep_r += r
+                buffer_a.append(a)
+                buffer_s.append(s)
+                buffer_r.append(r)
+
+                if done:  # update global and assign to local net
+                    # sync
+                    self.push_and_pull(done, s_, buffer_s,
+                                       buffer_a, buffer_r, GAMMA)
+                    goal_word = self.word_list[self.env.goal_word]
+                    self.record(ep_r, goal_word,
+                                self.word_list[a], len(buffer_a))
+                    self.save_model()
+                    buffer_s, buffer_a, buffer_r = [], [], []
+                    break
+                s = s_
+        self.res_queue.put(None)
+
+    def push_and_pull(self, done, s_, bs, ba, br, gamma):
+        if done:
+            v_s_ = 0.  # terminal
+        else:
+            v_s_ = self.lnet.forward(v_wrap(
+                s_[None, :]))[-1].data.numpy()[0, 0]
+
+        buffer_v_target = []
+        for r in br[::-1]:  # reverse buffer r
+            v_s_ = r + gamma * v_s_
+            buffer_v_target.append(v_s_)
+        buffer_v_target.reverse()
+
+        loss = self.lnet.loss_func(
+            v_wrap(np.vstack(bs)),
+            v_wrap(np.array(ba), dtype=np.int64) if ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
+            v_wrap(np.array(buffer_v_target)[:, None]))
+
+        # calculate local gradients and push local parameters to global
+        self.opt.zero_grad()
+        loss.backward()
+        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
+            gp._grad = lp.grad
+        self.opt.step()
+
+        # pull global parameters
+        self.lnet.load_state_dict(self.gnet.state_dict())
+
+    def save_model(self):
+        if self.g_ep_r.value >= 9 and self.g_ep.value % 100 == 0:
+            torch.save(self.gnet.state_dict(), os.path.join(
+                self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))
+
+    def record(self, ep_r, goal_word, action, action_number):
+        with self.g_ep.get_lock():
+            self.g_ep.value += 1
+        with self.g_ep_r.get_lock():
+            if self.g_ep_r.value == 0.:
+                self.g_ep_r.value = ep_r
+            else:
+                self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
+        self.res_queue.put(self.g_ep_r.value)
+        if goal_word == action:
+            self.winning_ep.value += 1
+        if self.g_ep.value % 100 == 0:
+            print(
+                self.name,
+                "Ep:", self.g_ep.value,
+                "| Ep_r: %.0f" % self.g_ep_r.value,
+                "| Goal :", goal_word,
+                "| Action: ", action,
+                "| Actions: ", action_number
+            )
main.py
CHANGED
@@ -1,62 +1,10 @@
 import sys
 import os
 import gym
-import torch
 import matplotlib.pyplot as plt
-from a3c.discrete_A3C import train
-from a3c.utils import v_wrap
-from a3c.net import Net
+from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
 from wordle_env.wordle import WordleEnvBase
 
-def evaluate_checkpoints(dir, env):
-    n_s = env.observation_space.shape[0]
-    n_a = env.action_space.n
-    words_list = env.words
-    word_width = len(env.words[0])
-    net = Net(n_s, n_a, words_list, word_width)
-    results = {}
-    print(dir)
-    for checkpoint in os.listdir(dir):
-        checkpoint_path = os.path.join(dir, checkpoint)
-        if os.path.isfile(checkpoint_path):
-            net.load_state_dict(torch.load(checkpoint_path))
-            wins, guesses = evaluate(net, env)
-            results[checkpoint] = wins, guesses
-    return dict(sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True))
-
-
-def evaluate(net, env):
-    print("Evaluation mode")
-    n_wins = 0
-    n_guesses = 0
-    n_win_guesses = 0
-    env = env.unwrapped
-    N = env.allowable_words
-    for goal_word in env.words[:N]:
-        win, outcomes = play(net, env)
-        if win:
-            n_wins += 1
-            n_win_guesses += len(outcomes)
-        # else:
-        #     print("Lost!", goal_word, outcomes)
-        n_guesses += len(outcomes)
-    print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
-          f"{n_guesses / N} including losses.")
-    return n_wins/N*100, n_win_guesses/n_wins
-
-def play(net, env):
-    state = env.reset()
-    outcomes = []
-    win = False
-    for i in range(env.max_turns):
-        action = net.choose_action(v_wrap(state[None, :]))
-        state, reward, done, _ = env.step(action)
-        outcomes.append((env.words[action], reward))
-        if done:
-            if reward >= 0:
-                win = True
-            break
-    return win, outcomes
 
 def print_results(global_ep, win_ep, res):
     print("Jugadas:", global_ep.value)
@@ -78,5 +26,6 @@ if __name__ == "__main__":
         print_results(global_ep, win_ep, res)
         evaluate(gnet, env)
     else:
+        print("Evaluation mode")
        results = evaluate_checkpoints(model_checkpoint_dir, env)
        print(results)