Add the possibility to save checkpoints of the model, with the conditions under which the model is saved passed as arguments
Files changed:
- a3c/train.py +4 -2
- a3c/worker.py +27 -3
- main.py +9 -8
a3c/train.py
CHANGED
@@ -6,7 +6,7 @@ from .net import Net
 from .worker import Worker
 
 
-def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None):
+def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None, save=False, min_reward=9.9, every_n_save=100):
     os.environ["OMP_NUM_THREADS"] = "1"
     if not os.path.exists(model_checkpoint_dir):
         os.makedirs(model_checkpoint_dir)
@@ -23,7 +23,7 @@ def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None):
 
     # parallel training
     workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a,
-                      words_list, word_width, win_ep, model_checkpoint_dir, gamma, pretrained_model_path) for i in range(mp.cpu_count())]
+                      words_list, word_width, win_ep, model_checkpoint_dir, gamma, pretrained_model_path, save, min_reward, every_n_save) for i in range(mp.cpu_count())]
     [w.start() for w in workers]
     res = []  # record episode reward to plot
     while True:
@@ -33,4 +33,6 @@ def train(env, max_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None):
         else:
             break
     [w.join() for w in workers]
+    if save:
+        torch.save(gnet.state_dict(), os.path.join(model_checkpoint_dir, f'model_{env.unwrapped.spec.id}.pth'))
     return global_ep, win_ep, gnet, res
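The final checkpoint that train() now writes when save is enabled names the file after the gym environment id. Below is a minimal standalone sketch of that save step; the toy nn.Linear stands in for the repo's global Net, and the env id string is illustrative (the real value comes from env.unwrapped.spec.id).

import os

import torch
import torch.nn as nn

gnet = nn.Linear(4, 2)          # toy stand-in for the trained global network
model_checkpoint_dir = 'checkpoints'
env_id = 'WordleEnv100-v0'      # illustrative env id
save = True

os.makedirs(model_checkpoint_dir, exist_ok=True)
if save:
    # Same pattern as the new lines in train(): one final snapshot, keyed by env id.
    torch.save(gnet.state_dict(),
               os.path.join(model_checkpoint_dir, f'model_{env_id}.pth'))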
a3c/worker.py
CHANGED
@@ -11,7 +11,28 @@ from .utils import v_wrap
 
 
 class Worker(mp.Process):
-    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir, gamma=0., pretrained_model_path=None):
+    def __init__(
+        self,
+        max_ep,
+        gnet,
+        opt,
+        global_ep,
+        global_ep_r,
+        res_queue,
+        name,
+        env,
+        N_S,
+        N_A,
+        words_list,
+        word_width,
+        winning_ep,
+        model_checkpoint_dir,
+        gamma=0.,
+        pretrained_model_path=None,
+        save=False,
+        min_reward=9.9,
+        every_n_save=100
+    ):
         super(Worker, self).__init__()
         self.max_ep = max_ep
         self.name = 'w%02i' % name
@@ -25,6 +46,9 @@ class Worker(mp.Process):
         self.env = env.unwrapped
         self.gamma = gamma
         self.model_checkpoint_dir = model_checkpoint_dir
+        self.save = save
+        self.min_reward = min_reward
+        self.every_n_save = every_n_save
 
     def run(self):
         while self.g_ep.value < self.max_ep:
@@ -81,9 +105,9 @@ class Worker(mp.Process):
         self.lnet.load_state_dict(self.gnet.state_dict())
 
     def save_model(self):
-        if self.g_ep_r.value >= …
+        if self.save and self.g_ep_r.value >= self.min_reward and self.g_ep.value % self.every_n_save == 0:
             torch.save(self.gnet.state_dict(), os.path.join(
-                self.model_checkpoint_dir, f'model_{…
+                self.model_checkpoint_dir, f'model_{self.g_ep.value}.pth'))
 
     def record(self, ep_r, goal_word, action, action_number):
         with self.g_ep.get_lock():
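Worker.save_model now gates checkpointing on three conditions at once: the save flag, the global moving reward clearing min_reward, and the global episode counter landing on a multiple of every_n_save. Here is a standalone sketch of that gating, with plain numbers standing in for the shared mp.Value counters and a toy network in place of gnet.

import os

import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # stands in for the shared global network (gnet)
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

def maybe_save(save, ep, ep_r, min_reward=9.9, every_n_save=100):
    # Save only when enabled, the reward has cleared the threshold,
    # and the episode count hits the save interval.
    if save and ep_r >= min_reward and ep % every_n_save == 0:
        torch.save(net.state_dict(),
                   os.path.join(checkpoint_dir, f'model_{ep}.pth'))

maybe_save(save=True, ep=200, ep_r=10.2)   # writes checkpoints/model_200.pth
maybe_save(save=True, ep=201, ep_r=10.2)   # skipped: 201 is not a multiple of 100
maybe_save(save=False, ep=300, ep_r=10.2)  # skipped: saving disabled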
main.py
CHANGED
@@ -14,13 +14,8 @@ from wordle_env.wordle import WordleEnvBase
 def training_mode(args, env, model_checkpoint_dir):
     max_ep = args.games
     start_time = time.time()
-    if args.model_name:
-        pretrained_model_path = os.path.join(
-            model_checkpoint_dir, args.model_name)
-        global_ep, win_ep, gnet, res = train(
-            env, max_ep, model_checkpoint_dir, args.gamma, pretrained_model_path)
-    else:
-        global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, args.gamma)
+    pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name) if args.model_name else args.model_name
+    global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir, args.gamma, pretrained_model_path, args.save, args.min_reward, args.every_n_save)
     print("--- %.0f seconds ---" % (time.time() - start_time))
     print_results(global_ep, win_ep, res)
     evaluate(gnet, env)
@@ -56,7 +51,13 @@ if __name__ == "__main__":
     parser_train.add_argument(
         "--model_name", "-n", help="If want to train from a pretrained model, the name of the pretrained model file")
     parser_train.add_argument(
-        "--gamma", help="Gamma hyperparameter value", type=float, default=0.)
+        "--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
+    parser_train.add_argument(
+        "--save", "-s", help="Save instances of the model while training", action='store_true')
+    parser_train.add_argument(
+        "--min_reward", help="The minimum global reward value required for saving the model", type=float, default=9.9)
+    parser_train.add_argument(
+        "--every_n_save", help="Check every n training episodes whether to save the model", type=int, default=100)
     parser_train.set_defaults(func=training_mode)
 
     parser_eval = subparsers.add_parser(
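With the new flags wired through main.py, a training run that checkpoints might look like the line below. This is a hypothetical invocation: the subcommand name and the games/environment arguments are defined outside this diff, so their exact spelling may differ.

python main.py train --games 5000 --gamma 0.9 --save --min_reward 9.9 --every_n_save 100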