# wordle-solver / a3c / utils.py
# (HuggingFace page header: santit96 — "Add posibility to save and load models",
#  commit 676caef, raw / history / blame, 2.21 kB)
"""
Functions that use multiple times
"""
import os
from torch import nn
import torch
import numpy as np
def v_wrap(np_array, dtype=np.float32):
    """Convert a numpy array to a torch tensor, casting to `dtype` if needed.

    The input array is left untouched when its dtype already matches; otherwise
    a cast copy is wrapped instead.
    """
    converted = np_array if np_array.dtype == dtype else np_array.astype(dtype)
    return torch.from_numpy(converted)
def set_init(layers):
    """Initialize each layer's parameters in place.

    Weights are drawn from N(0, 0.1) and biases are zeroed — a common
    small-variance initialization for policy/value heads.
    """
    for lyr in layers:
        nn.init.normal_(lyr.weight, mean=0., std=0.1)
        nn.init.constant_(lyr.bias, 0.)
def push_and_pull(opt, lnet, gnet, done, s_, bs, ba, br, gamma):
    """A3C worker update: compute local gradients, push them to the global
    net's optimizer, then pull the updated global weights back into the
    local net.

    opt: optimizer over gnet's parameters (shared across workers).
    lnet: worker-local network; must expose forward() -> (..., value) and
          loss_func(states, actions, value_targets).
    gnet: global (shared) network.
    done: whether the episode terminated at s_.
    s_: next state (numpy array); only used to bootstrap when not done.
    bs, ba, br: buffers of states, actions, rewards collected this rollout.
    gamma: discount factor.
    """
    if done:
        v_s_ = 0.  # terminal
    else:
        # Bootstrap from the local net's value head; forward() is assumed to
        # return the value estimate as its last element.
        v_s_ = lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
    # Build discounted value targets by walking the reward buffer backwards:
    # v_target[t] = r[t] + gamma * v_target[t+1].
    buffer_v_target = []
    for r in br[::-1]:  # reverse buffer r
        v_s_ = r + gamma * v_s_
        buffer_v_target.append(v_s_)
    buffer_v_target.reverse()
    loss = lnet.loss_func(
        v_wrap(np.vstack(bs)),
        # Discrete actions go in as int64 indices; continuous actions are
        # stacked as float vectors.
        v_wrap(np.array(ba), dtype=np.int64) if ba[0].dtype == np.int64 else v_wrap(np.vstack(ba)),
        v_wrap(np.array(buffer_v_target)[:, None]))
    # calculate local gradients and push local parameters to global
    opt.zero_grad()
    loss.backward()
    for lp, gp in zip(lnet.parameters(), gnet.parameters()):
        # Graft the local gradient onto the global parameter so opt.step()
        # (which owns gnet's params) applies this worker's update.
        gp._grad = lp.grad
    opt.step()
    # pull global parameters
    lnet.load_state_dict(gnet.state_dict())
def save_model(gnet, dir, episode, reward):
    """Checkpoint the global net's weights to `dir` as model_<episode>.pth.

    Saves only on every 100th episode and only once the (moving-average)
    reward has reached 9; otherwise does nothing.
    """
    should_save = reward >= 9 and episode % 100 == 0
    if not should_save:
        return
    checkpoint_path = os.path.join(dir, f'model_{episode}.pth')
    torch.save(gnet.state_dict(), checkpoint_path)
def record(global_ep, global_ep_r, ep_r, res_queue, name, goal_word, action, action_number, winning_ep):
    """Record one finished episode into the shared (cross-process) counters.

    Increments the global episode count, folds ep_r into an exponential
    moving average of reward (seeded with the first episode's reward),
    publishes the average on res_queue, counts a win when the final action
    equals the goal word, and prints a progress line every 100 episodes.
    """
    with global_ep.get_lock():
        global_ep.value += 1
    with global_ep_r.get_lock():
        # EMA with a 0.01 step; the very first episode seeds the average.
        global_ep_r.value = (
            ep_r if global_ep_r.value == 0.
            else global_ep_r.value * 0.99 + ep_r * 0.01
        )
    res_queue.put(global_ep_r.value)
    if goal_word == action:
        winning_ep.value += 1
    if global_ep.value % 100 == 0:
        print(
            name,
            "Ep:", global_ep.value,
            "| Ep_r: %.0f" % global_ep_r.value,
            "| Goal :", goal_word,
            "| Action: ", action,
            "| Actions: ", action_number
        )