wordle-solver / a3c /discrete_A3C.py
santit96's picture
Commiting wordle solver project, state completed, a3c not completed yet
44db2f9
raw
history blame
3.23 kB
"""
Reinforcement Learning (A3C) using PyTorch + multiprocessing.
The simplest implementation for discrete actions.
View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import torch.multiprocessing as mp
from .utils import v_wrap, set_init, push_and_pull, record
# A worker syncs with the global net (push_and_pull) whenever its running step
# counter is a multiple of this value, and also at the end of every episode.
UPDATE_GLOBAL_ITER = 5

# Forwarded unchanged to push_and_pull; presumably the reward discount factor
# used there to build value targets -- confirm against utils.
GAMMA = 0.9

# Workers keep starting new episodes while the shared episode counter is below
# this limit.
MAX_EP = 3_000
class Net(nn.Module):
    """Actor-critic network with separate policy and value branches.

    Each branch is a two-layer MLP with a 128-unit tanh hidden layer. The
    policy branch outputs ``a_dim`` unnormalised action scores (logits); the
    value branch outputs a single state-value estimate.

    Args:
        s_dim: Size of the (flat) state vector fed to both branches.
        a_dim: Number of discrete actions.
    """

    def __init__(self, s_dim, a_dim):
        super(Net, self).__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        # Policy (actor) branch.
        self.pi1 = nn.Linear(s_dim, 128)
        self.pi2 = nn.Linear(128, a_dim)
        # Value (critic) branch.
        self.v1 = nn.Linear(s_dim, 128)
        self.v2 = nn.Linear(128, 1)
        # Project helper from utils; presumably applies custom weight
        # initialisation to the listed layers -- confirm against utils.
        set_init([self.pi1, self.pi2, self.v1, self.v2])
        self.distribution = torch.distributions.Categorical

    def forward(self, x):
        """Return ``(logits, values)`` for a batch of states.

        Args:
            x: Float tensor whose last dimension is ``s_dim``. The other
                methods of this class apply softmax over ``dim=1``, so they
                expect a 2-D ``(batch, s_dim)`` input.

        Returns:
            Tuple ``(logits, values)`` with last dimensions ``a_dim`` and 1.
        """
        pi1 = torch.tanh(self.pi1(x))
        logits = self.pi2(pi1)
        v1 = torch.tanh(self.v1(x))
        values = self.v2(v1)
        return logits, values

    def choose_action(self, s):
        """Sample an action for the first state in the batch ``s``.

        Puts the module in eval mode as a side effect.

        Args:
            s: Float tensor of shape ``(batch, s_dim)``.

        Returns:
            The sampled action index for row 0, as a NumPy integer scalar.
        """
        self.eval()
        # Action selection needs no gradients: no_grad() avoids building an
        # autograd graph at all (the former `.data` access only discarded it
        # after the fact and is discouraged by PyTorch).
        with torch.no_grad():
            logits, _ = self.forward(s)
            prob = F.softmax(logits, dim=1)
        m = self.distribution(prob)
        return m.sample().numpy()[0]

    def loss_func(self, s, a, v_t):
        """Compute the combined actor-critic loss for a batch.

        Puts the module in train mode as a side effect.

        Args:
            s: Float tensor of states, shape ``(N, s_dim)``.
            a: Tensor of taken action indices; expected shape ``(N,)``.
            v_t: Tensor of value targets; expected shape ``(N, 1)`` so that
                it lines up with the critic output.

        Returns:
            Scalar tensor: mean over the batch of
            ``td**2 - log_prob(a) * td.detach()`` where ``td = v_t - V(s)``.
        """
        self.train()
        logits, values = self.forward(s)
        # Drop the trailing singleton dim so td is (N,). Previously the critic
        # loss stayed (N, 1) while the actor loss was (N,), and their sum
        # silently broadcast to an (N, N) matrix before the mean -- same
        # numeric result, but quadratic memory/compute in the batch size.
        td = (v_t - values).squeeze(-1)
        c_loss = td.pow(2)
        probs = F.softmax(logits, dim=1)
        m = self.distribution(probs)
        # td is detached here so the policy term does not push gradients
        # into the value branch; td acts as the advantage weight.
        exp_v = m.log_prob(a) * td.detach()
        a_loss = -exp_v
        total_loss = (c_loss + a_loss).mean()
        return total_loss
class Worker(mp.Process):
    """One A3C worker process with its own environment and local network.

    The worker plays episodes with its local ``Net``, periodically calls
    ``push_and_pull`` with the collected transitions and the global network,
    and reports finished episodes through ``record``.

    Args:
        gnet: Shared global network.
        opt: Optimizer passed through to ``push_and_pull``.
        global_ep: Shared episode counter (read via ``.value``).
        global_ep_r: Shared running-reward holder, passed to ``record``.
        res_queue: Queue handed to ``record``; receives ``None`` when this
            worker finishes.
        name: Integer worker index, formatted into the process name.
        N_S: State dimension for the local network.
        N_A: Number of actions for the local network.
    """

    def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name, N_S, N_A):
        super().__init__()
        self.name = 'w%02i' % name
        self.gnet = gnet
        self.opt = opt
        self.g_ep = global_ep
        self.g_ep_r = global_ep_r
        self.res_queue = res_queue
        # Local copy of the model that this worker acts with.
        self.lnet = Net(N_S, N_A)
        self.env = gym.make('WordleEnv100OneAction-v0').unwrapped

    def run(self):
        """Play episodes until the shared episode counter reaches MAX_EP."""
        step_count = 1
        while self.g_ep.value < MAX_EP:
            state = self.env.reset()
            states, actions, rewards = [], [], []
            episode_return = 0.

            while True:
                # Only the first worker renders its environment.
                if self.name == 'w00':
                    self.env.render()

                # Add a leading batch axis before handing the state to the net.
                action = self.lnet.choose_action(v_wrap(state[None, :]))
                next_state, reward, done, _ = self.env.step(action)
                # NOTE(review): the terminal reward is overwritten with -1
                # whatever the environment returned -- confirm this is
                # intended for the Wordle env and not a leftover.
                if done:
                    reward = -1
                episode_return += reward

                actions.append(action)
                states.append(state)
                rewards.append(reward)

                # Sync with the global net every UPDATE_GLOBAL_ITER steps and
                # at the end of each episode, then start a fresh buffer.
                if done or step_count % UPDATE_GLOBAL_ITER == 0:
                    push_and_pull(self.opt, self.lnet, self.gnet, done, next_state,
                                  states, actions, rewards, GAMMA)
                    states, actions, rewards = [], [], []

                if done:
                    # Episode finished: report it and go start the next one.
                    record(self.g_ep, self.g_ep_r, episode_return, self.res_queue, self.name)
                    break

                state = next_state
                step_count += 1

        # Sentinel put on the queue once this worker has stopped.
        self.res_queue.put(None)