File size: 3,324 Bytes
44db2f9
 
 
 
 
 
1bd428f
44db2f9
676caef
1bd428f
 
44db2f9
1bd428f
44db2f9
 
676caef
44db2f9
f05ece6
44db2f9
62c6c3b
44db2f9
abff1ef
 
 
676caef
44db2f9
 
f05ece6
44db2f9
 
 
 
 
f05ece6
44db2f9
 
 
 
 
f05ece6
44db2f9
 
f05ece6
676caef
 
62c6c3b
f05ece6
44db2f9
 
 
 
676caef
1bd428f
676caef
 
1bd428f
 
 
 
 
 
 
 
 
 
676caef
1bd428f
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Reinforcement Learning (A3C) using PyTorch + multiprocessing.
The simplest implementation, adapted here to a discrete action space (each action is an index into the environment's word list).

View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""
import os
import torch.multiprocessing as mp
from .utils import v_wrap, push_and_pull, record, save_model
from .shared_adam import SharedAdam
from .net import Net

GAMMA = 0.65  # presumably the reward discount factor; passed to push_and_pull in Worker.run — confirm in utils

class Worker(mp.Process):
    """A3C worker process that plays episodes and syncs with a shared global net.

    Each worker owns a local ``Net`` and its own (unwrapped) environment.
    At the end of every episode it calls ``push_and_pull`` with the
    collected states/actions/rewards, reports the episode through
    ``record``, and saves a checkpoint of the global network.
    """

    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue, name, env, N_S, N_A, words_list, word_width, winning_ep, model_checkpoint_dir):
        """Create a worker.

        Args:
            max_ep: the worker stops starting new episodes once the shared
                episode counter ``global_ep`` reaches this value.
            gnet: the shared global network.
            opt: the shared optimizer, forwarded to ``push_and_pull``.
            global_ep: shared episode counter (an ``mp.Value`` in ``train``).
            global_ep_r: shared episode-reward value, forwarded to ``record``.
            res_queue: queue that ``record`` is given; this worker also puts a
                single ``None`` sentinel on it when it exits.
            name: integer worker index, formatted as ``'wNN'``.
            env: environment; its ``unwrapped`` attribute is used.
            N_S, N_A: state and action sizes forwarded to ``Net``.
            words_list: word vocabulary; actions and ``env.goal_word`` index it.
            word_width: word length, forwarded to ``Net``.
            winning_ep: shared value forwarded to ``record`` (presumably a win
                counter — confirm in utils).
            model_checkpoint_dir: directory passed to ``save_model``.
        """
        super(Worker, self).__init__()
        self.max_ep = max_ep
        self.name = 'w%02i' % name
        self.g_ep, self.g_ep_r, self.res_queue, self.winning_ep = global_ep, global_ep_r, res_queue, winning_ep
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        self.lnet = Net(N_S, N_A, words_list, word_width)           # local network
        self.env = env.unwrapped
        self.model_checkpoint_dir = model_checkpoint_dir

    def run(self):
        """Play episodes until the shared episode counter reaches ``max_ep``.

        A ``None`` sentinel is always put on ``res_queue`` when this method
        exits, including when it raises. The consumer can therefore tell that
        this worker is finished instead of blocking forever on the queue.
        """
        try:
            # g_ep is never incremented in this loop; presumably record() bumps it.
            while self.g_ep.value < self.max_ep:
                s = self.env.reset()
                buffer_s, buffer_a, buffer_r = [], [], []
                ep_r = 0.
                while True:
                    # s[None, :] adds a leading batch axis (assumes s is a numpy array).
                    a = self.lnet.choose_action(v_wrap(s[None, :]))
                    s_, r, done, _ = self.env.step(a)
                    ep_r += r
                    buffer_a.append(a)
                    buffer_s.append(s)
                    buffer_r.append(r)

                    if done:
                        # Update the global net and assign its weights to the local net.
                        push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
                        goal_word = self.word_list[self.env.goal_word]
                        record(self.g_ep, self.g_ep_r, ep_r, self.res_queue, self.name, goal_word, self.word_list[a], len(buffer_a), self.winning_ep)
                        save_model(self.gnet, self.model_checkpoint_dir, self.g_ep.value, self.g_ep_r.value)
                        # Buffers are re-created at the top of the next episode.
                        break
                    s = s_
        finally:
            # Signal completion even on error so train() does not wait forever.
            self.res_queue.put(None)


def train(env, max_ep, model_checkpoint_dir):
    """Train a shared global ``Net`` with one A3C ``Worker`` process per CPU.

    Args:
        env: environment exposing ``observation_space.shape``,
            ``action_space.n`` and ``words`` (a non-empty word list).
        max_ep: workers stop starting new episodes once the shared episode
            counter reaches this value.
        model_checkpoint_dir: checkpoint directory; created if missing.

    Returns:
        ``(global_ep, win_ep, gnet, res)``:
        - ``global_ep`` and ``win_ep`` are the shared ``mp.Value`` objects
          (the latter presumably counts wins).
        - ``gnet`` is the global network.
        - ``res`` is the list of non-``None`` items the workers put on the
          result queue (episode rewards for plotting).
    """
    import queue  # only needed for queue.Empty raised by a timed-out get()

    os.environ["OMP_NUM_THREADS"] = "1"
    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(model_checkpoint_dir, exist_ok=True)
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    word_width = len(words_list[0])
    gnet = Net(n_s, n_a, words_list, word_width)        # global network
    gnet.share_memory()         # share the global parameters in multiprocessing
    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))      # global optimizer
    global_ep, global_ep_r, res_queue, win_ep = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue(), mp.Value('i', 0)

    # parallel training
    workers = [
        Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, env,
               n_s, n_a, words_list, word_width, win_ep, model_checkpoint_dir)
        for i in range(mp.cpu_count())
    ]
    for w in workers:
        w.start()

    res = []                    # record episode reward to plot
    # Each worker puts exactly one None when it finishes.
    # Keep reading until every worker has reported.
    # Otherwise items left in the queue can make join() below deadlock.
    finished = 0
    while finished < len(workers):
        # Check liveness *before* the get.
        # A process flushes its queued items to the pipe before it exits.
        # So if all workers were dead and the get times out, nothing is left.
        any_alive = any(w.is_alive() for w in workers)
        try:
            r = res_queue.get(timeout=1.0)
        except queue.Empty:
            if not any_alive:
                break  # e.g. a worker died without sending its sentinel
            continue
        if r is None:
            finished += 1
        else:
            res.append(r)
    for w in workers:
        w.join()
    return global_ep, win_ep, gnet, res