santit96 committed on
Commit 254d61f · 1 Parent(s): 676caef

Code refactor


Moved the Worker class to its own file and merged some of the utils functions into it as methods
Moved the evaluation methods from main.py into the a3c module

Files changed (4)
  1. a3c/discrete_A3C.py +55 -40
  2. a3c/utils.py +0 -66
  3. a3c/worker.py +107 -0
  4. main.py +2 -53
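
After this refactor the training and evaluation entry points live in a3c.discrete_A3C and the worker process lives in a3c.worker. A minimal usage sketch under assumed names (the gym environment id and the checkpoint directory below are illustrative, not part of this commit):

import gym
from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints

env = gym.make("WordleEnv100-v0")  # hypothetical env id; any Wordle-style gym env from wordle_env would do
# train() spawns one Worker per CPU core and returns the episode counters,
# the trained global network and the smoothed reward history
global_ep, win_ep, gnet, res = train(env, max_ep=5000, model_checkpoint_dir="checkpoints")
evaluate(gnet, env)                              # plays env.allowable_words games, prints win rate and guesses per win
print(evaluate_checkpoints("checkpoints", env))  # ranks saved checkpoints by wins, then guesses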
a3c/discrete_A3C.py CHANGED
@@ -5,48 +5,12 @@ The most simple implementation for continuous action.
  View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
  """
  import os
+ import torch
  import torch.multiprocessing as mp
- from .utils import v_wrap, push_and_pull, record, save_model
  from .shared_adam import SharedAdam
  from .net import Net
-
- GAMMA = 0.65
-
- class Worker(mp.Process):
-     def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir):
-         super(Worker, self).__init__()
-         self.max_ep = max_ep
-         self.name = 'w%02i' % name
-         self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
-         self.gnet, self.opt = gnet, opt
-         self.word_list = words_list
-         self.lnet = Net(N_S, N_A, words_list, word_width)  # local network
-         self.env = env.unwrapped
-         self.model_checkpoint_dir = model_checkpoint_dir
-
-     def run(self):
-         while self.g_ep.value < self.max_ep:
-             s = self.env.reset()
-             buffer_s, buffer_a, buffer_r = [], [], []
-             ep_r = 0.
-             while True:
-                 a = self.lnet.choose_action(v_wrap(s[None, :]))
-                 s_, r, done, _ = self.env.step(a)
-                 ep_r += r
-                 buffer_a.append(a)
-                 buffer_s.append(s)
-                 buffer_r.append(r)
-
-                 if done:  # update global and assign to local net
-                     # sync
-                     push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
-                     goal_word = self.word_list[self.env.goal_word]
-                     record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
-                     save_model(self.gnet, self.model_checkpoint_dir, self.g_ep.value, self.g_ep_r.value)
-                     buffer_s, buffer_a, buffer_r = [], [], []
-                     break
-                 s = s_
-         self.res_queue.put(None)
+ from .utils import v_wrap
+ from .worker import Worker


  def train(env, max_ep, model_checkpoint_dir):
@@ -63,7 +27,8 @@ def train(env, max_ep, model_checkpoint_dir):
      global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)

      # parallel training
-     workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep, model_checkpoint_dir) for i in range(mp.cpu_count())]
+     workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
+                       words_list, word_width, win_ep, model_checkpoint_dir) for i in range(mp.cpu_count())]
      [w.start() for w in workers]
      res = []  # record episode reward to plot
      while True:
@@ -74,3 +39,53 @@ def train(env, max_ep, model_checkpoint_dir):
              break
      [w.join() for w in workers]
      return global_ep, win_ep, gnet, res
+
+
+ def evaluate_checkpoints(dir, env):
+     n_s = env.observation_space.shape[0]
+     n_a = env.action_space.n
+     words_list = env.words
+     word_width = len(env.words[0])
+     net = Net(n_s, n_a, words_list, word_width)
+     results = {}
+     for checkpoint in os.listdir(dir):
+         checkpoint_path = os.path.join(dir, checkpoint)
+         if os.path.isfile(checkpoint_path):
+             net.load_state_dict(torch.load(checkpoint_path))
+             wins, guesses = evaluate(net, env)
+             results[checkpoint] = wins, guesses
+     return dict(sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True))
+
+
+ def evaluate(net, env):
+     n_wins = 0
+     n_guesses = 0
+     n_win_guesses = 0
+     env = env.unwrapped
+     N = env.allowable_words
+     for goal_word in env.words[:N]:
+         win, outcomes = play(net, env)
+         if win:
+             n_wins += 1
+             n_win_guesses += len(outcomes)
+         # else:
+         #     print("Lost!", goal_word, outcomes)
+         n_guesses += len(outcomes)
+     print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
+           f"{n_guesses / N} including losses.")
+     return n_wins/N*100, n_win_guesses/n_wins
+
+
+ def play(net, env):
+     state = env.reset()
+     outcomes = []
+     win = False
+     for i in range(env.max_turns):
+         action = net.choose_action(v_wrap(state[None, :]))
+         state, reward, done, _ = env.step(action)
+         outcomes.append((env.words[action], reward))
+         if done:
+             if reward >= 0:
+                 win = True
+             break
+     return win, outcomes
a3c/utils.py CHANGED
@@ -1,8 +1,3 @@
- """
- Functions that use multiple times
- """
- import os
- from torch import nn
  import torch
  import numpy as np

@@ -11,64 +6,3 @@ def v_wrap(np_array, dtype=np.float32):
      if np_array.dtype != dtype:
          np_array = np_array.astype(dtype)
      return torch.from_numpy(np_array)
-
-
- def set_init(layers):
-     for layer in layers:
-         nn.init.normal_(layer.weight, mean=0., std=0.1)
-         nn.init.constant_(layer.bias, 0.)
-
-
- def push_and_pull(opt, lnet, gnet, done, s_, bs, ba, br, gamma):
-     if done:
-         v_s_ = 0.  # terminal
-     else:
-         v_s_ = lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
-
-     buffer_v_target = []
-     for r in br[::-1]:  # reverse buffer r
-         v_s_ = r + gamma * v_s_
-         buffer_v_target.append(v_s_)
-     buffer_v_target.reverse()
-
-     loss = lnet.loss_func(
-         v_wrap(np.vstack(bs)),
-         v_wrap(np.array(ba), dtype=np.int64) if ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
-         v_wrap(np.array(buffer_v_target)[:, None]))
-
-     # calculate local gradients and push local parameters to global
-     opt.zero_grad()
-     loss.backward()
-     for lp, gp in zip(lnet.parameters(), gnet.parameters()):
-         gp._grad = lp.grad
-     opt.step()
-
-     # pull global parameters
-     lnet.load_state_dict(gnet.state_dict())
-
-
- def save_model(gnet, dir, episode, reward):
-     if reward >= 9 and episode % 100 == 0:
-         torch.save(gnet.state_dict(), os.path.join(dir, f'model_{episode}.pth'))
-
-
- def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, action_number, winning_ep):
-     with global_ep.get_lock():
-         global_ep.value += 1
-     with global_ep_r.get_lock():
-         if global_ep_r.value == 0.:
-             global_ep_r.value = ep_r
-         else:
-             global_ep_r.value = global_ep_r.value * 0.99 + ep_r * 0.01
-     res_queue.put(global_ep_r.value)
-     if goal_word == action:
-         winning_ep.value += 1
-     if global_ep.value % 100 == 0:
-         print(
-             name,
-             "Ep:", global_ep.value,
-             "| Ep_r: %.0f" % global_ep_r.value,
-             "| Goal :", goal_word,
-             "| Action: ", action,
-             "| Actions: ", action_number
-         )
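With this change a3c/utils.py keeps only v_wrap, which casts a NumPy array to float32 and wraps it as a torch tensor. A small usage sketch (the array shape is illustrative):

import numpy as np
from a3c.utils import v_wrap

state = np.zeros(5)              # e.g. an observation vector, float64 by default
tensor = v_wrap(state[None, :])  # state[None, :] adds a batch dim; result is a float32 tensor of shape (1, 5)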
a3c/worker.py ADDED
@@ -0,0 +1,107 @@
+ """
+ Worker class implementation of the a3c discrete algorithm
+ """
+ import os
+ import torch
+ import numpy as np
+ import torch.multiprocessing as mp
+ from torch import nn
+ from .net import Net
+ from .utils import v_wrap
+
+
+ GAMMA = 0.65
+
+
+ class Worker(mp.Process):
+     def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir):
+         super(Worker, self).__init__()
+         self.max_ep = max_ep
+         self.name = 'w%02i' % name
+         self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
+         self.gnet, self.opt = gnet, opt
+         self.word_list = words_list
+         # local network
+         self.lnet = Net(N_S, N_A, words_list, word_width)
+         self.env = env.unwrapped
+         self.model_checkpoint_dir = model_checkpoint_dir
+
+     def run(self):
+         while self.g_ep.value < self.max_ep:
+             s = self.env.reset()
+             buffer_s, buffer_a, buffer_r = [], [], []
+             ep_r = 0.
+             while True:
+                 a = self.lnet.choose_action(v_wrap(s[None, :]))
+                 s_, r, done, _ = self.env.step(a)
+                 ep_r += r
+                 buffer_a.append(a)
+                 buffer_s.append(s)
+                 buffer_r.append(r)
+
+                 if done:  # update global and assign to local net
+                     # sync
+                     self.push_and_pull(done, s_, buffer_s,
+                                        buffer_a, buffer_r, GAMMA)
+                     goal_word = self.word_list[self.env.goal_word]
+                     self.record(ep_r, goal_word,
+                                 self.word_list[a], len(buffer_a))
+                     self.save_model()
+                     buffer_s, buffer_a, buffer_r = [], [], []
+                     break
+                 s = s_
+         self.res_queue.put(None)
+
+     def push_and_pull(self, done, s_, bs, ba, br, gamma):
+         if done:
+             v_s_ = 0.  # terminal
+         else:
+             v_s_ = self.lnet.forward(v_wrap(
+                 s_[None, :]))[-1].data.numpy()[0, 0]
+
+         buffer_v_target = []
+         for r in br[::-1]:  # reverse buffer r
+             v_s_ = r + gamma * v_s_
+             buffer_v_target.append(v_s_)
+         buffer_v_target.reverse()
+
+         loss = self.lnet.loss_func(
+             v_wrap(np.vstack(bs)),
+             v_wrap(np.array(ba), dtype=np.int64) if ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
+             v_wrap(np.array(buffer_v_target)[:, None]))
+
+         # calculate local gradients and push local parameters to global
+         self.opt.zero_grad()
+         loss.backward()
+         for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
+             gp._grad = lp.grad
+         self.opt.step()
+
+         # pull global parameters
+         self.lnet.load_state_dict(self.gnet.state_dict())
+
+     def save_model(self):
+         if self.g_ep_r.value >= 9 and self.g_ep.value % 100 == 0:
+             torch.save(self.gnet.state_dict(), os.path.join(
+                 self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))
+
+     def record(self, ep_r, goal_word, action, action_number):
+         with self.g_ep.get_lock():
+             self.g_ep.value += 1
+         with self.g_ep_r.get_lock():
+             if self.g_ep_r.value == 0.:
+                 self.g_ep_r.value = ep_r
+             else:
+                 self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
+         self.res_queue.put(self.g_ep_r.value)
+         if goal_word == action:
+             self.winning_ep.value += 1
+         if self.g_ep.value % 100 == 0:
+             print(
+                 self.name,
+                 "Ep:", self.g_ep.value,
+                 "| Ep_r: %.0f" % self.g_ep_r.value,
+                 "| Goal :", goal_word,
+                 "| Action: ", action,
+                 "| Actions: ", action_number
+             )
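
The targets that Worker.push_and_pull feeds into loss_func are the usual discounted returns, computed backwards over the episode buffer and bootstrapped with the local network's value estimate of s_ when the episode did not terminate. A standalone sketch of just that computation (function name and signature are illustrative, not part of the commit):

def discounted_targets(rewards, bootstrap_value, gamma=0.65):
    # Walk the reward buffer backwards: G_t = r_t + gamma * G_{t+1},
    # starting from 0. at a terminal state or from v(s_) otherwise.
    targets, g = [], bootstrap_value
    for r in reversed(rewards):
        g = r + gamma * g
        targets.append(g)
    return list(reversed(targets))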
main.py CHANGED
@@ -1,62 +1,10 @@
  import sys
  import os
  import gym
- import torch
  import matplotlib.pyplot as plt
- from a3c.discrete_A3C import train
- from a3c.utils import v_wrap
- from a3c.net import Net
+ from a3c.discrete_A3C import train, evaluate, evaluate_checkpoints
  from wordle_env.wordle import WordleEnvBase

- def evaluate_checkpoints(dir, env):
-     n_s = env.observation_space.shape[0]
-     n_a = env.action_space.n
-     words_list = env.words
-     word_width = len(env.words[0])
-     net = Net(n_s, n_a, words_list, word_width)
-     results = {}
-     print(dir)
-     for checkpoint in os.listdir(dir):
-         checkpoint_path = os.path.join(dir, checkpoint)
-         if os.path.isfile(checkpoint_path):
-             net.load_state_dict(torch.load(checkpoint_path))
-             wins, guesses = evaluate(net, env)
-             results[checkpoint] = wins, guesses
-     return dict(sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True))
-
-
- def evaluate(net, env):
-     print("Evaluation mode")
-     n_wins = 0
-     n_guesses = 0
-     n_win_guesses = 0
-     env = env.unwrapped
-     N = env.allowable_words
-     for goal_word in env.words[:N]:
-         win, outcomes = play(net, env)
-         if win:
-             n_wins += 1
-             n_win_guesses += len(outcomes)
-         # else:
-         #     print("Lost!", goal_word, outcomes)
-         n_guesses += len(outcomes)
-     print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
-           f"{n_guesses / N} including losses.")
-     return n_wins/N*100, n_win_guesses/n_wins
-
- def play(net, env):
-     state = env.reset()
-     outcomes = []
-     win = False
-     for i in range(env.max_turns):
-         action = net.choose_action(v_wrap(state[None, :]))
-         state, reward, done, _ = env.step(action)
-         outcomes.append((env.words[action], reward))
-         if done:
-             if reward >= 0:
-                 win = True
-             break
-     return win, outcomes

  def print_results(global_ep, win_ep, res):
      print("Jugadas:", global_ep.value)
@@ -78,5 +26,6 @@ if __name__ == "__main__":
          print_results(global_ep, win_ep, res)
          evaluate(gnet, env)
      else:
+         print("Evaluation mode")
          results = evaluate_checkpoints(model_checkpoint_dir, env)
          print(results)