santit96 committed
Commit 676caef · Parent(s): 1bd428f

Add possibility to save and load models

Also add an evaluation task to evaluate saved models
Files changed (4)
  1. .gitignore +4 -1
  2. a3c/discrete_A3C.py +9 -6
  3. a3c/utils.py +6 -0
  4. main.py +29 -4
.gitignore CHANGED
@@ -113,4 +113,7 @@ GitHub.sublime-settings
 !.vscode/tasks.json
 !.vscode/launch.json
 !.vscode/extensions.json
-.history
+.history
+
+# PyTorch model files
+*.pth
a3c/discrete_A3C.py CHANGED
@@ -6,14 +6,14 @@ View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.i
 """
 import os
 import torch.multiprocessing as mp
-from .utils import v_wrap, push_and_pull, record
+from .utils import v_wrap, push_and_pull, record, save_model
 from .shared_adam import SharedAdam
 from .net import Net
 
 GAMMA = 0.65
 
 class Worker(mp.Process):
-    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
+    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir):
         super(Worker, self).__init__()
         self.max_ep = max_ep
         self.name = 'w%02i' % name
@@ -22,6 +22,7 @@ class Worker(mp.Process):
         self.word_list = words_list
         self.lnet = Net(N_S, N_A, words_list, word_width)  # local network
         self.env = env.unwrapped
+        self.model_checkpoint_dir = model_checkpoint_dir
 
     def run(self):
         while self.g_ep.value < self.max_ep:
@@ -40,16 +41,18 @@ class Worker(mp.Process):
                 # sync
                 push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
                 goal_word = self.word_list[self.env.goal_word]
-                record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
+                record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
+                save_model(self.gnet, self.model_checkpoint_dir, self.g_ep.value, self.g_ep_r.value)
                 buffer_s, buffer_a, buffer_r = [], [], []
                 break
             s = s_
         self.res_queue.put(None)
 
 
-def train(env, max_ep):
+def train(env, max_ep, model_checkpoint_dir):
     os.environ["OMP_NUM_THREADS"] = "1"
-
+    if not os.path.exists(model_checkpoint_dir):
+        os.makedirs(model_checkpoint_dir)
     n_s = env.observation_space.shape[0]
     n_a = env.action_space.n
     words_list = env.words
@@ -60,7 +63,7 @@ def train(env, max_ep):
     global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
 
     # parallel training
-    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
+    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep, model_checkpoint_dir) for i in range(mp.cpu_count())]
     [w.start() for w in workers]
     res = []  # record episode reward to plot
     while True:
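
The diff above wires saving into the worker loop but leaves loading to main.py. To resume training from one of the saved .pth files instead of starting the global network from scratch, a minimal sketch could look like the following; `load_global_net` is a hypothetical helper, not part of this commit, and it assumes the Net constructor arguments that train() derives from the environment:

```python
import torch
from a3c.net import Net

def load_global_net(env, checkpoint_path):
    # Hypothetical helper (not in this commit): rebuild the global network
    # with the same constructor arguments train() uses, then restore the
    # state_dict written by save_model().
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    word_width = len(env.words[0])
    gnet = Net(n_s, n_a, words_list, word_width)
    gnet.load_state_dict(torch.load(checkpoint_path))
    return gnet
```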
a3c/utils.py CHANGED
@@ -1,6 +1,7 @@
 """
 Functions that use multiple times
 """
+import os
 from torch import nn
 import torch
 import numpy as np
@@ -46,6 +47,11 @@ def push_and_pull(opt, lnet, gnet, done, s_, bs, ba, br, gamma):
     lnet.load_state_dict(gnet.state_dict())
 
 
+def save_model(gnet, dir, episode, reward):
+    if reward >= 9 and episode % 100 == 0:
+        torch.save(gnet.state_dict(), os.path.join(dir, f'model_{episode}.pth'))
+
+
 def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, action_number, winning_ep):
     with global_ep.get_lock():
         global_ep.value += 1
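
save_model() runs after every finished episode, but it only writes a checkpoint when two conditions hold at once: the shared reward value has reached 9 and the global episode counter is a multiple of 100. A self-contained sketch of that gate, with made-up episode/reward pairs:

```python
# The gate inside save_model(): a checkpoint is written only when the
# shared reward is at least 9 AND the episode count is a multiple of 100.
for episode, reward in [(100, 9.2), (150, 9.5), (200, 8.0), (300, 10.0)]:
    will_save = reward >= 9 and episode % 100 == 0
    print(f'episode {episode}: save={will_save}')
# -> True, False (not a multiple of 100), False (reward below 9), True
```

Note that the filename only encodes the episode number, so if two workers hit the same episode value concurrently, the later write simply overwrites the earlier one.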
main.py CHANGED
@@ -1,10 +1,29 @@
 import sys
+import os
 import gym
+import torch
 import matplotlib.pyplot as plt
 from a3c.discrete_A3C import train
 from a3c.utils import v_wrap
+from a3c.net import Net
 from wordle_env.wordle import WordleEnvBase
 
+def evaluate_checkpoints(dir, env):
+    n_s = env.observation_space.shape[0]
+    n_a = env.action_space.n
+    words_list = env.words
+    word_width = len(env.words[0])
+    net = Net(n_s, n_a, words_list, word_width)
+    results = {}
+    print(dir)
+    for checkpoint in os.listdir(dir):
+        checkpoint_path = os.path.join(dir, checkpoint)
+        if os.path.isfile(checkpoint_path):
+            net.load_state_dict(torch.load(checkpoint_path))
+            wins, guesses = evaluate(net, env)
+            results[checkpoint] = wins, guesses
+    return dict(sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True))
+
 
 def evaluate(net, env):
     print("Evaluation mode")
@@ -21,9 +40,9 @@ def evaluate(net, env):
         # else:
         #     print("Lost!", goal_word, outcomes)
         n_guesses += len(outcomes)
-
     print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
           f"{n_guesses / N} including losses.")
+    return n_wins/N*100, n_win_guesses/n_wins
 
 def play(net, env):
     state = env.reset()
@@ -51,7 +70,13 @@ def print_results(global_ep, win_ep, res):
 if __name__ == "__main__":
     max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
     env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
+    evaluation = True if len(sys.argv) > 3 and sys.argv[3] == 'evaluation' else False
     env = gym.make(env_id)
-    global_ep, win_ep, gnet, res = train(env, max_ep)
-    print_results(global_ep, win_ep, res)
-    evaluate(gnet, env)
+    model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
+    if not evaluation:
+        global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
+        print_results(global_ep, win_ep, res)
+        evaluate(gnet, env)
+    else:
+        results = evaluate_checkpoints(model_checkpoint_dir, env)
+        print(results)
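
The ranking returned by evaluate_checkpoints() comes entirely from its sort key: each checkpoint maps to (win rate, guesses per win), and sorting by (win_rate, -guesses_per_win) with reverse=True orders checkpoints by win rate descending, breaking ties by fewest guesses per win. A standalone sketch of the same sort, with illustrative numbers only:

```python
# Same ordering as evaluate_checkpoints(); names and scores are made up.
results = {
    'model_100.pth': (95.0, 4.1),
    'model_200.pth': (97.0, 4.5),
    'model_300.pth': (97.0, 3.9),
}
ranked = dict(sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True))
print(list(ranked))
# ['model_300.pth', 'model_200.pth', 'model_100.pth']
```

With the new dispatch in __main__, `python main.py 100000 WordleEnv100FullAction-v0` trains and then evaluates the final network, while `python main.py 100000 WordleEnv100FullAction-v0 evaluation` skips training and ranks every checkpoint saved under checkpoints/<env id>.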