Add possibility to save and load models
Also add an evaluation task to evaluate saved models
- .gitignore +4 -1
- a3c/discrete_A3C.py +9 -6
- a3c/utils.py +6 -0
- main.py +29 -4
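
Usage note: main.py now accepts an optional third argument that switches from training to evaluating saved checkpoints. A minimal invocation sketch (the episode count and environment id below are the script's defaults, written out explicitly):

    # Train, writing checkpoints under checkpoints/WordleEnv100FullAction-v0/
    python main.py 100000 WordleEnv100FullAction-v0
    # Evaluate every checkpoint saved for that environment
    python main.py 100000 WordleEnv100FullAction-v0 evaluation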
.gitignore
CHANGED
@@ -113,4 +113,7 @@ GitHub.sublime-settings
 !.vscode/tasks.json
 !.vscode/launch.json
 !.vscode/extensions.json
-.history
+.history
+
+# PyTorch model files
+*.pth
a3c/discrete_A3C.py
CHANGED
@@ -6,14 +6,14 @@ View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.i
 """
 import os
 import torch.multiprocessing as mp
-from .utils import v_wrap, push_and_pull, record
+from .utils import v_wrap, push_and_pull, record, save_model
 from .shared_adam import SharedAdam
 from .net import Net
 
 GAMMA = 0.65
 
 class Worker(mp.Process):
-    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep):
+    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir):
         super(Worker, self).__init__()
         self.max_ep = max_ep
         self.name = 'w%02i' % name
@@ -22,6 +22,7 @@ class Worker(mp.Process):
         self.word_list = words_list
         self.lnet = Net(N_S, N_A, words_list, word_width) # local network
         self.env = env.unwrapped
+        self.model_checkpoint_dir = model_checkpoint_dir
 
     def run(self):
         while self.g_ep.value < self.max_ep:
@@ -40,16 +41,18 @@
                     # sync
                     push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
                     goal_word = self.word_list[self.env.goal_word]
-                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
+                    record( self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
+                    save_model(self.gnet, self.model_checkpoint_dir, self.g_ep.value, self.g_ep_r.value)
                     buffer_s, buffer_a, buffer_r = [], [], []
                     break
                 s = s_
         self.res_queue.put(None)
 
 
-def train(env, max_ep):
+def train(env, max_ep, model_checkpoint_dir):
     os.environ["OMP_NUM_THREADS"] = "1"
-
+    if not os.path.exists(model_checkpoint_dir):
+        os.makedirs(model_checkpoint_dir)
     n_s = env.observation_space.shape[0]
     n_a = env.action_space.n
     words_list = env.words
@@ -60,7 +63,7 @@ def train(env, max_ep):
     global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)
 
     # parallel training
-    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep) for i in range(mp.cpu_count())]
+    workers = [Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env, n_s, n_a, words_list, word_width, win_ep, model_checkpoint_dir) for i in range(mp.cpu_count())]
     [w.start() for w in workers]
     res = [] # record episode reward to plot
     while True:
a3c/utils.py
CHANGED
@@ -1,6 +1,7 @@
 """
 Functions that use multiple times
 """
+import os
 from torch import nn
 import torch
 import numpy as np
@@ -46,6 +47,11 @@ def push_and_pull(opt, lnet, gnet, done, s_, bs, ba, br, gamma):
     lnet.load_state_dict(gnet.state_dict())
 
 
+def save_model(gnet, dir, episode, reward):
+    if reward >= 9 and episode % 100 == 0:
+        torch.save(gnet.state_dict(), os.path.join(dir, f'model_{episode}.pth'))
+
+
 def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, action_number, winning_ep):
     with global_ep.get_lock():
         global_ep.value += 1
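
As added above, save_model only writes a file when the shared reward value has reached 9 and the global episode counter is a multiple of 100, so checkpoints stay sparse. A minimal round-trip sketch for restoring one, assuming the checkpoint file below exists (the path is illustrative):

    import os
    import gym
    import torch
    from a3c.net import Net

    env = gym.make('WordleEnv100FullAction-v0')
    # Rebuild the network with the same constructor arguments the workers use
    net = Net(env.observation_space.shape[0], env.action_space.n,
              env.words, len(env.words[0]))
    # save_model names files model_<episode>.pth inside the checkpoint dir
    net.load_state_dict(torch.load(
        os.path.join('checkpoints', 'WordleEnv100FullAction-v0', 'model_100.pth')))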
main.py
CHANGED
@@ -1,10 +1,29 @@
 import sys
+import os
 import gym
+import torch
 import matplotlib.pyplot as plt
 from a3c.discrete_A3C import train
 from a3c.utils import v_wrap
+from a3c.net import Net
 from wordle_env.wordle import WordleEnvBase
 
+def evaluate_checkpoints(dir, env):
+    n_s = env.observation_space.shape[0]
+    n_a = env.action_space.n
+    words_list = env.words
+    word_width = len(env.words[0])
+    net = Net(n_s, n_a, words_list, word_width)
+    results = {}
+    print(dir)
+    for checkpoint in os.listdir(dir):
+        checkpoint_path = os.path.join(dir, checkpoint)
+        if os.path.isfile(checkpoint_path):
+            net.load_state_dict(torch.load(checkpoint_path))
+            wins, guesses = evaluate(net, env)
+            results[checkpoint] = wins, guesses
+    return dict(sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True))
+
 
 def evaluate(net, env):
     print("Evaluation mode")
@@ -21,9 +40,9 @@ def evaluate(net, env):
         # else:
         #     print("Lost!", goal_word, outcomes)
         n_guesses += len(outcomes)
-
     print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
           f"{n_guesses / N} including losses.")
+    return n_wins/N*100, n_win_guesses/n_wins
 
 def play(net, env):
     state = env.reset()
@@ -51,7 +70,13 @@ def print_results(global_ep, win_ep, res):
 if __name__ == "__main__":
     max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
     env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'
+    evaluation = True if len(sys.argv) > 3 and sys.argv[3] == 'evaluation' else False
     env = gym.make(env_id)
-    global_ep, win_ep, gnet, res = train(env, max_ep)
-    print_results(global_ep, win_ep, res)
-    evaluate(gnet, env)
+    model_checkpoint_dir = os.path.join('checkpoints', env.unwrapped.spec.id)
+    if not evaluation:
+        global_ep, win_ep, gnet, res = train(env, max_ep, model_checkpoint_dir)
+        print_results(global_ep, win_ep, res)
+        evaluate(gnet, env)
+    else:
+        results = evaluate_checkpoints(model_checkpoint_dir, env)
+        print(results)
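
The dictionary returned by evaluate_checkpoints is ordered best-first: higher win rate first, with ties broken by fewer guesses per win (hence the negated guess count in the sort key). A small illustration with made-up numbers:

    results = {'model_100.pth': (95.0, 4.2), 'model_200.pth': (95.0, 3.8), 'model_300.pth': (90.0, 3.5)}
    ordered = dict(sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True))
    # {'model_200.pth': (95.0, 3.8), 'model_100.pth': (95.0, 4.2), 'model_300.pth': (90.0, 3.5)}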