Add the option to train from a pretrained model
- a3c/discrete_A3C.py +4 -2
- a3c/worker.py +3 -1
- main.py +6 -1
a3c/discrete_A3C.py
CHANGED
@@ -13,7 +13,7 @@ from .utils import v_wrap
 from .worker import Worker
 
 
-def train(env, max_ep, model_checkpoint_dir):
+def train(env, max_ep, model_checkpoint_dir, pretrained_model_path=None):
     os.environ["OMP_NUM_THREADS"] = "1"
     if not os.path.exists(model_checkpoint_dir):
         os.makedirs(model_checkpoint_dir)
@@ -22,13 +22,15 @@ def train(env, max_ep, model_checkpoint_dir):
     words_list = env.words
     word_width = len(env.words[0])
     gnet = Net(n_s, n_a, words_list, word_width)  # global network
+    if pretrained_model_path:
+        gnet.load_state_dict(torch.load(pretrained_model_path))
     gnet.share_memory()  # share the global parameters in multiprocessing
     opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))  # global optimizer
     global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
 
     # parallel training
     workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
-                      words_list, word_width, win_ep, model_checkpoint_dir) for i in range(mp.cpu_count())]
+                      words_list, word_width, win_ep, model_checkpoint_dir, pretrained_model_path) for i in range(mp.cpu_count())]
     [w.start() for w in workers]
     res = []  # record episode reward to plot
     while True:
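The checkpoint is loaded into the global network before gnet.share_memory(), so the pretrained weights are what every worker process sees in shared memory; the hunk assumes torch is already imported in this module. A minimal sketch of the state-dict round trip this relies on, with a stand-in module in place of the repo's Net (TinyNet and the path below are illustrative, not part of the commit):

import torch
import torch.nn as nn

class TinyNet(nn.Module):  # stand-in for the repo's Net
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

net = TinyNet()
torch.save(net.state_dict(), "tiny.pth")            # how a compatible checkpoint is produced

restored = TinyNet()                                # fresh network, identical architecture
restored.load_state_dict(torch.load("tiny.pth"))    # what train() now does when a path is given
restored.share_memory()                             # load *before* sharing, matching the diff's order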
a3c/worker.py
CHANGED
@@ -14,7 +14,7 @@ GAMMA = 0.65
 
 
 class Worker(mp.Process):
-    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir):
+    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir, pretrained_model_path=None):
         super(Worker, self).__init__()
         self.max_ep = max_ep
         self.name = 'w%02i' % name
@@ -23,6 +23,8 @@ class Worker(mp.Process):
         self.word_list = words_list
         # local network
         self.lnet = Net(N_S, N_A, words_list, word_width)
+        if pretrained_model_path:
+            self.lnet.load_state_dict(torch.load(pretrained_model_path))
         self.env = env.unwrapped
         self.model_checkpoint_dir = model_checkpoint_dir
 
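Each Worker also loads the checkpoint into its local network, so the first rollouts start from the pretrained policy rather than random weights until the first sync with gnet. Note that torch.load(path) restores tensors to the device they were saved on; a defensive variant (an assumption on my part, not part of this diff) pins them to CPU so a checkpoint saved on a GPU machine still loads inside CPU-only worker processes:

import torch
import torch.nn as nn

def load_pretrained(net: nn.Module, path: str) -> nn.Module:
    # map_location="cpu" keeps the load device-agnostic; load_state_dict
    # raises if the checkpoint's keys or shapes do not match the network.
    net.load_state_dict(torch.load(path, map_location="cpu"))
    return net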
main.py
CHANGED
@@ -20,11 +20,16 @@ if __name__ == "__main__":
     max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
     env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
     evaluation = True if len(sys.argv) > 3 and sys.argv[3] == 'evaluation' else False
+    pretrained = True if len(sys.argv) > 3 and sys.argv[3] == 'pretrained' else False
     env = gym.make(env_id)
     model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
     if not evaluation:
         start_time = time.time()
-        global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
+        if pretrained:
+            pretrained_model_path = os.path.join(model_checkpoint_dir, sys.argv[4]) if len(sys.argv) > 4 else ''
+            global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, pretrained_model_path)
+        else:
+            global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
         print("--- %.0f seconds ---" % (time.time() - start_time))
         print_results(global_ep, win_ep, res)
         evaluate(gnet, env)
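With this change sys.argv[3] selects a mode ('evaluation' or 'pretrained', mutually exclusive since both test the same argument), and sys.argv[4] names a checkpoint file inside the checkpoint directory. If 'pretrained' is given without a filename, pretrained_model_path is the empty string, which is falsy, so train() silently falls back to training from scratch. Hypothetical invocations (the checkpoint filename is illustrative):

# train from scratch for 100000 episodes
python main.py 100000 WordleEnv100FullAction-v0

# resume from checkpoints/WordleEnv100FullAction-v0/model.pth
python main.py 100000 WordleEnv100FullAction-v0 pretrained model.pth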