"""Worker process for the discrete-action A3C (Asynchronous Advantage Actor-Critic) algorithm."""
import os

import torch
import numpy as np
import torch.multiprocessing as mp
from torch import nn

from .net import Net
from .utils import v_wrap

# Default discount factor used when building value targets from rewards.
GAMMA = 0.65


class Worker(mp.Process):
    """One A3C worker process with its own local network and environment.

    The worker plays episodes in its own ``env``. At the end of every episode
    it pushes the gradients of its local network onto the parameters of the
    shared global network ``gnet``, steps the shared optimizer ``opt``, and
    copies the updated global weights back into the local network.

    The shared-state arguments (``global_ep``, ``global_ep_r``,
    ``winning_ep``) are used through ``.value`` and ``.get_lock()``, i.e. as
    ``multiprocessing`` synchronized values. ``res_queue`` receives the
    running reward after every episode and a final ``None`` when the worker
    stops.
    """

    def __init__(self, max_ep, gnet, opt, global_ep, global_ep_r, res_queue,
                 name, env, N_S, N_A, words_list, word_width, winning_ep,
                 model_checkpoint_dir, gamma=GAMMA):
        """Create a worker.

        Args:
            max_ep: Stop once the shared episode counter reaches this value.
            gnet: Shared global network whose parameters receive the gradients.
            opt: Optimizer stepped after the local gradients are copied onto
                ``gnet``'s parameters (presumably built over them — confirm
                at the call site).
            global_ep: Shared episode counter.
            global_ep_r: Shared running (exponentially smoothed) episode reward.
            res_queue: Queue receiving the running reward after each episode.
            name: Integer worker id; the process is named ``'wNN'``.
            env: Environment; its ``unwrapped`` instance is used.
            N_S, N_A: Forwarded to ``Net`` (presumably state and action sizes).
            words_list: Word list; actions and ``env.goal_word`` index into it.
            word_width: Forwarded to ``Net``.
            winning_ep: Shared counter of episodes whose final action was the
                goal word.
            model_checkpoint_dir: Directory in which checkpoints are written.
            gamma: Discount factor for value targets (defaults to ``GAMMA``).
        """
        super().__init__()
        self.max_ep = max_ep
        self.name = 'w%02i' % name
        self.g_ep = global_ep
        self.g_ep_r = global_ep_r
        self.res_queue = res_queue
        self.winning_ep = winning_ep
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        self.gamma = gamma
        # Local network: acts in the environment and produces the gradients
        # that are pushed into the global network.
        self.lnet = Net(N_S, N_A, words_list, word_width)
        self.env = env.unwrapped
        self.model_checkpoint_dir = model_checkpoint_dir

    def run(self):
        """Play episodes until the shared episode counter reaches ``max_ep``.

        After each finished episode:
          1. the buffered transitions update the global network;
          2. the episode is recorded;
          3. a checkpoint may be saved.
        A final ``None`` is put on ``res_queue`` as an end-of-stream marker.
        """
        while self.g_ep.value < self.max_ep:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
            while True:
                # s[None, :] adds a leading batch axis of size 1
                # (assumes s is a NumPy array — TODO confirm env contract).
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)
                if done:
                    # Push local gradients to the global net, then pull its weights.
                    self.push_and_pull(
                        done, s_, buffer_s, buffer_a, buffer_r, self.gamma)
                    goal_word = self.word_list[self.env.goal_word]
                    # `a` is the final action taken in the episode.
                    self.record(ep_r, goal_word, self.word_list[a], len(buffer_a))
                    self.save_model()
                    break
                s = s_
        self.res_queue.put(None)

    def push_and_pull(self, done, s_, bs, ba, br, gamma):
        """Update the global network from one rollout and resync the local copy.

        Args:
            done: Whether ``s_`` is terminal; if so the bootstrap value is 0.
            s_: State after the last transition (only used when not done).
            bs, ba, br: Buffered states, actions and rewards, in time order.
            gamma: Discount factor.
        """
        if done:
            v_s_ = 0.  # terminal state: nothing to bootstrap from
        else:
            # NOTE(review): assumes the last output of lnet.forward is the
            # state-value estimate with shape (1, 1) — TODO confirm against Net.
            v_s_ = self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]

        # Discounted returns computed backwards: G_t = r_t + gamma * G_{t+1}.
        buffer_v_target = []
        for r in br[::-1]:
            v_s_ = r + gamma * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()

        states = v_wrap(np.vstack(bs))
        # NOTE(review): `ba[0].dtype` assumes choose_action returns a NumPy
        # scalar/array rather than a plain Python int — confirm against Net.
        if ba[0].dtype == np.int64:
            actions = v_wrap(np.array(ba), dtype=np.int64)
        else:
            actions = v_wrap(np.vstack(ba))
        targets = v_wrap(np.array(buffer_v_target)[:, None])
        loss = self.lnet.loss_func(states, actions, targets)

        # Compute local gradients and hand them to the global parameters.
        self.opt.zero_grad()
        # Also clear the local gradients: after `gp._grad = lp.grad` the two
        # share tensors, and since PyTorch 2.0 `opt.zero_grad()` only sets the
        # global .grad to None, so without this the local .grad would keep
        # accumulating across updates.
        self.lnet.zero_grad()
        loss.backward()
        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
            gp._grad = lp.grad
        self.opt.step()

        # Pull the updated global weights into the local network.
        self.lnet.load_state_dict(self.gnet.state_dict())

    def save_model(self):
        """Save the global network's weights on every 100th episode, but only
        once the running reward is at least 9.

        NOTE(review): 9 is a magic threshold; its meaning in game terms is not
        visible here.
        """
        # Read the shared counter once so the check and the file name agree
        # even if another worker increments it in between.
        episode = self.g_ep.value
        if self.g_ep_r.value >= 9 and episode % 100 == 0:
            torch.save(self.gnet.state_dict(), os.path.join(
                self.model_checkpoint_dir, f'model_{episode}.pth'))

    def record(self, ep_r, goal_word, action, action_number):
        """Update shared statistics for a finished episode and report progress.

        Args:
            ep_r: Total reward collected in the episode.
            goal_word: The episode's target word.
            action: Word selected by the episode's final action.
            action_number: Number of actions taken in the episode.
        """
        with self.g_ep.get_lock():
            self.g_ep.value += 1
            episode = self.g_ep.value
        with self.g_ep_r.get_lock():
            if self.g_ep_r.value == 0.:
                self.g_ep_r.value = ep_r
            else:
                # Exponential moving average of episode rewards.
                self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
            running_r = self.g_ep_r.value
        self.res_queue.put(running_r)
        if goal_word == action:
            # `+=` on a shared value is a read-modify-write; guard it like the
            # other shared counters so concurrent workers don't lose updates.
            with self.winning_ep.get_lock():
                self.winning_ep.value += 1
        if episode % 100 == 0:
            print(
                self.name,
                "Ep:", episode,
                "| Ep_r: %.0f" % running_r,
                "| Goal :", goal_word,
                "| Action: ", action,
                "| Actions: ", action_number,
            )