santit96 committed on
Commit
18a7031
·
1 Parent(s): 570282c

Add possibility to train from a pretrained model

Browse files
Files changed (3) hide show
  1. a3c/discrete_A3C.py +4 -2
  2. a3c/worker.py +3 -1
  3. main.py +6 -1
a3c/discrete_A3C.py CHANGED
@@ -13,7 +13,7 @@ from .utils import v_wrap
13
  from .worker import Worker
14
 
15
 
16
- def train(env, max_ep, model_checkpoint_dir):
17
  os.environ["OMP_NUM_THREADS"] = "1"
18
  if not os.path.exists(model_checkpoint_dir):
19
  os.makedirs(model_checkpoint_dir)
@@ -22,13 +22,15 @@ def train(env, max_ep, model_checkpoint_dir):
22
  words_list = env.words
23
  word_width = len(env.words[0])
24
  gnet = Net(n_s, n_a, words_list, word_width) # global network
 
 
25
  gnet.share_memory() # share the global parameters in multiprocessing
26
  opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)) # global optimizer
27
  global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
28
 
29
  # parallel training
30
  workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
31
- words_list, word_width, win_ep, model_checkpoint_dir) for i in range(mp.cpu_count())]
32
  [w.start() for w in workers]
33
  res = [] # record episode reward to plot
34
  while True:
 
13
  from .worker import Worker
14
 
15
 
16
+ def train(env, max_ep, model_checkpoint_dir, pretrained_model_path=None):
17
  os.environ["OMP_NUM_THREADS"] = "1"
18
  if not os.path.exists(model_checkpoint_dir):
19
  os.makedirs(model_checkpoint_dir)
 
22
  words_list = env.words
23
  word_width = len(env.words[0])
24
  gnet = Net(n_s, n_a, words_list, word_width) # global network
25
+ if pretrained_model_path:
26
+ gnet.load_state_dict(torch.load(pretrained_model_path))
27
  gnet.share_memory() # share the global parameters in multiprocessing
28
  opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)) # global optimizer
29
  global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
30
 
31
  # parallel training
32
  workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
33
+ words_list, word_width, win_ep, model_checkpoint_dir, pretrained_model_path) for i in range(mp.cpu_count())]
34
  [w.start() for w in workers]
35
  res = [] # record episode reward to plot
36
  while True:
a3c/worker.py CHANGED
@@ -14,7 +14,7 @@ GAMMA = 0.65
14
 
15
 
16
  class Worker(mp.Process):
17
- def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir):
18
  super(Worker, self).__init__()
19
  self.max_ep = max_ep
20
  self.name = 'w%02i' % name
@@ -23,6 +23,8 @@ class Worker(mp.Process):
23
  self.word_list = words_list
24
  # local network
25
  self.lnet = Net(N_S, N_A, words_list, word_width)
 
 
26
  self.env = env.unwrapped
27
  self.model_checkpoint_dir = model_checkpoint_dir
28
 
 
14
 
15
 
16
  class Worker(mp.Process):
17
+ def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir, pretrained_model_path=None):
18
  super(Worker, self).__init__()
19
  self.max_ep = max_ep
20
  self.name = 'w%02i' % name
 
23
  self.word_list = words_list
24
  # local network
25
  self.lnet = Net(N_S, N_A, words_list, word_width)
26
+ if pretrained_model_path:
27
+ self.lnet.load_state_dict(torch.load(pretrained_model_path))
28
  self.env = env.unwrapped
29
  self.model_checkpoint_dir = model_checkpoint_dir
30
 
main.py CHANGED
@@ -20,11 +20,16 @@ if __name__ == "__main__":
20
  max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
21
  env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
22
  evaluation = True if len(sys.argv) > 3 and sys.argv[3] == 'evaluation' else False
 
23
  env = gym.make(env_id)
24
  model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
25
  if not evaluation:
26
  start_time = time.time()
27
- global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
 
 
 
 
28
  print("--- %.0f seconds ---" % (time.time() - start_time))
29
  print_results(global_ep, win_ep, res)
30
  evaluate(gnet, env)
 
20
  max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
21
  env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
22
  evaluation = True if len(sys.argv) > 3 and sys.argv[3] == 'evaluation' else False
23
+ pretrained = True if len(sys.argv) > 3 and sys.argv[3] == 'pretrained' else False
24
  env = gym.make(env_id)
25
  model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
26
  if not evaluation:
27
  start_time = time.time()
28
+ if pretrained:
29
+ pretrained_model_path = os.path.join(model_checkpoint_dir, sys.argv[4]) if len(sys.argv) > 4 else ''
30
+ global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, pretrained_model_path)
31
+ else:
32
+ global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
33
  print("--- %.0f seconds ---" % (time.time() - start_time))
34
  print_results(global_ep, win_ep, res)
35
  evaluate(gnet, env)