santit96 committed
Commit fa34b1d · 1 Parent(s): f899dd3

Add the possibility to save checkpoints of the model during training, and expose the conditions under which the model is saved as arguments

Files changed (3)
  1. a3c/train.py +4 -2
  2. a3c/worker.py +27 -3
  3. main.py +9 -8
a3c/train.py CHANGED
@@ -6,7 +6,7 @@ from .net import Net
 from .worker import Worker
 
 
-def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None):
+def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None, save=False, min_reward=9.9, every_n_save=100):
     os.environ["OMP_NUM_THREADS"] = "1"
     if not os.path.exists(model_checkpoint_dir):
         os.makedirs(model_checkpoint_dir)
@@ -23,7 +23,7 @@ def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None):
 
     # parallel training
     workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
-                      words_list, word_width, win_ep, model_checkpoint_dir, gamma, pretrained_model_path) for i in range(mp.cpu_count())]
+                      words_list, word_width, win_ep, model_checkpoint_dir, gamma, pretrained_model_path, save, min_reward, every_n_save) for i in range(mp.cpu_count())]
     [w.start() for w in workers]
     res = []  # record episode reward to plot
     while True:
@@ -33,4 +33,6 @@ def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None):
         else:
             break
     [w.join() for w in workers]
+    if save:
+        torch.save(gnet.state_dict(), os.path.join(model_checkpoint_dir, f'model_{env.unwrapped.spec.id}.pth'))
     return global_ep, win_ep, gnet, res
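With save enabled, train now also writes one final checkpoint, named after the environment's spec id, after all workers have joined. A minimal sketch of a call site for the extended signature (the argument values and the env construction are illustrative assumptions, not part of the commit):

    # Hypothetical usage; env is assumed to be a registered gym Wordle
    # environment created elsewhere in the repo.
    from a3c.train import train

    global_ep, win_ep, gnet, res = train(
        env,                                 # gym environment to train on
        max_ep=10000,                        # total episodes shared across workers
        model_checkpoint_dir="checkpoints",  # where .pth files are written
        gamma=0.9,                           # discount factor
        pretrained_model_path=None,          # or a path to resume from
        save=True,                           # enable checkpointing
        min_reward=9.9,                      # global reward threshold for saving
        every_n_save=100,                    # save on every 100th global episode
    )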
a3c/worker.py CHANGED
@@ -11,7 +11,28 @@ from .utils import v_wrap
 
 
 class Worker(mp.Process):
-    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir, gamma, pretrained_model_path=None):
+    def __init__(
+        self,
+        max_ep,
+        gnet,
+        opt,
+        global_ep,
+        global_ep_r,
+        res_queue,
+        name,
+        env,
+        N_S,
+        N_A,
+        words_list,
+        word_width,
+        winning_ep,
+        model_checkpoint_dir,
+        gamma=0.,
+        pretrained_model_path=None,
+        save=False,
+        min_reward=9.9,
+        every_n_save=100
+    ):
         super(Worker, self).__init__()
         self.max_ep = max_ep
         self.name = 'w%02i' % name
@@ -25,6 +46,9 @@ class Worker(mp.Process):
         self.env = env.unwrapped
         self.gamma = gamma
         self.model_checkpoint_dir = model_checkpoint_dir
+        self.save = save
+        self.min_reward = min_reward
+        self.every_n_save = every_n_save
 
     def run(self):
         while self.g_ep.value < self.max_ep:
@@ -81,9 +105,9 @@ class Worker(mp.Process):
         self.lnet.load_state_dict(self.gnet.state_dict())
 
     def save_model(self):
-        if self.g_ep_r.value >= 9.9 and self.g_ep.value % 100 == 0:
+        if self.save and self.g_ep_r.value >= self.min_reward and self.g_ep.value % self.every_n_save == 0:
             torch.save(self.gnet.state_dict(), os.path.join(
-                self.model_checkpoint_dir, f'model_{ self.g_ep.value}.pth'))
+                self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))
 
     def record(self, ep_r, goal_word, action, action_number):
         with self.g_ep.get_lock():
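save_model now gates checkpointing on the three new attributes instead of the previously hard-coded 9.9 and 100. The gate, written out as a standalone predicate for clarity (a sketch; the real code reads the shared multiprocessing.Value counters g_ep and g_ep_r, which every worker updates):

    def should_save(save, global_reward, global_episode,
                    min_reward=9.9, every_n_save=100):
        # Checkpoint only when saving is enabled, the running global reward
        # has reached the threshold, and the global episode counter falls
        # on a save interval.
        return (save
                and global_reward >= min_reward
                and global_episode % every_n_save == 0)

    assert should_save(True, 9.95, 200)       # enabled, reward met, on interval
    assert not should_save(True, 9.95, 150)   # not a multiple of every_n_save
    assert not should_save(True, 9.5, 200)    # reward below threshold
    assert not should_save(False, 9.95, 200)  # saving disabled

Because the counters are shared, the condition tracks global training progress regardless of which worker happens to evaluate it; checkpoints are named model_{g_ep.value}.pth, so each qualifying episode maps to one file.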
main.py CHANGED
@@ -14,13 +14,8 @@ from wordle_env.wordle import WordleEnvBase
 def training_mode(args, env, model_checkpoint_dir):
     max_ep = args.games
     start_time = time.time()
-    if args.model_name:
-        pretrained_model_path = os.path.join(
-            model_checkpoint_dir, args.model_name)
-        global_ep, win_ep, gnet, res = train(
-            env, max_ep, model_checkpoint_dir, args.gamma, pretrained_model_path)
-    else:
-        global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, args.gamma)
+    pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name) if args.model_name else args.model_name
+    global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, args.gamma, pretrained_model_path, args.save, args.min_reward, args.every_n_save)
     print("--- %.0f seconds ---" % (time.time() - start_time))
     print_results(global_ep, win_ep, res)
     evaluate(gnet, env)
@@ -56,7 +51,13 @@ if __name__ == "__main__":
     parser_train.add_argument(
         "--model_name", "-n", help="If you want to train from a pretrained model, the name of the pretrained model file")
     parser_train.add_argument(
-        "--gamma", help="Gamma hyperparameter value", type=float, default=0.)
+        "--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
+    parser_train.add_argument(
+        "--save", "-s", help="Save checkpoints of the model while training", action='store_true')
+    parser_train.add_argument(
+        "--min_reward", help="The minimum global reward that must be reached before the model is saved", type=float, default=9.9)
+    parser_train.add_argument(
+        "--every_n_save", help="Save the model every n global episodes (once the reward condition is met)", type=int, default=100)
     parser_train.set_defaults(func=training_mode)
 
     parser_eval = subparsers.add_parser(
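Assuming the training subcommand registered on parser_train is named train and that the episode count is passed via --games (neither name is visible in this diff, both inferred from args.games and parser_train), a run with the new checkpointing options might look like:

    python main.py train --games 5000 --gamma 0.9 --save --min_reward 9.9 --every_n_save 100

Note that without --save no checkpoints are written at all, which reverses the previous behavior of always checkpointing once the hard-coded thresholds were met; --min_reward and --every_n_save only take effect when --save is passed.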