import copy
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from agent.model import MLP, Critic
from agent.diffusion import Diffusion
from agent.vae import VAE
from agent.helpers import EMA


class DiPo(object):
    def __init__(self,
                 args,
                 state_dim,
                 action_space,
                 memory,
                 diffusion_memory,
                 device,
                 ):
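        """DiPo agent: twin Q critics plus a generative actor (diffusion,
        VAE or MLP, selected by `args.policy_type`).

        Reads the following hyperparameters from `args`: policy_type,
        noise_ratio, beta_schedule, n_timesteps, diffusion_lr,
        action_gradient_steps, ratio, ac_grad_norm, tau,
        update_actor_target_every, critic_lr and action_lr.
        """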
        action_dim = np.prod(action_space.shape)

        self.policy_type = args.policy_type
        # Build the policy network according to the configured policy type.
        if self.policy_type == 'Diffusion':
            self.actor = Diffusion(state_dim=state_dim, action_dim=action_dim, noise_ratio=args.noise_ratio,
                                   beta_schedule=args.beta_schedule, n_timesteps=args.n_timesteps).to(device)
        elif self.policy_type == 'VAE':
            self.actor = VAE(state_dim=state_dim, action_dim=action_dim, device=device).to(device)
        else:
            self.actor = MLP(state_dim=state_dim, action_dim=action_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=args.diffusion_lr, eps=1e-5)

        self.memory = memory
        self.diffusion_memory = diffusion_memory

        self.action_gradient_steps = args.action_gradient_steps
        self.action_grad_norm = action_dim * args.ratio
        self.ac_grad_norm = args.ac_grad_norm

        self.step = 0
        self.tau = args.tau
        self.actor_target = copy.deepcopy(self.actor)
        self.update_actor_target_every = args.update_actor_target_every

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=args.critic_lr, eps=1e-5)

        self.action_dim = action_dim
        self.action_lr = args.action_lr
        self.device = device

        # Rescaling between the environment's action range and the [-1, 1]
        # range the networks operate in. Note that `action_space.shape` is
        # already dereferenced above, so in practice the None branch is never
        # reached with a None action_space.
        if action_space is None:
            self.action_scale = 1.
            self.action_bias = 0.
        else:
            self.action_scale = (action_space.high - action_space.low) / 2.
            self.action_bias = (action_space.high + action_space.low) / 2.

    def append_memory(self, state, action, reward, next_state, mask):
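        """Store a transition in both buffers; the action is normalised to
        [-1, 1] before it is written."""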
        action = (action - self.action_bias) / self.action_scale
        self.memory.append(state, action, reward, next_state, mask)
        self.diffusion_memory.append(state, action)

    def sample_action(self, state, eval=False):
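        """Sample a single action from the actor and rescale it from [-1, 1]
        into the environment's action range."""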
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        action = self.actor(state, eval).cpu().data.numpy().flatten()
        action = action.clip(-1, 1)
        action = action * self.action_scale + self.action_bias
        return action

    def action_gradient(self, batch_size, log_writer):
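        """One-batch action improvement: sample (state, action) pairs from the
        diffusion buffer, take `action_gradient_steps` Adam steps on the
        actions to increase min(Q1, Q2), clamp them to [-1, 1], and write the
        improved actions back into the buffer."""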
        states, best_actions, idxs = self.diffusion_memory.sample(batch_size)
        actions_optim = torch.optim.Adam([best_actions], lr=self.action_lr, eps=1e-5)

        for i in range(self.action_gradient_steps):
            best_actions.requires_grad_(True)
            q1, q2 = self.critic(states, best_actions)
            loss = -torch.min(q1, q2)

            actions_optim.zero_grad()
            loss.backward(torch.ones_like(loss))
            if self.action_grad_norm > 0:
                actions_grad_norms = nn.utils.clip_grad_norm_([best_actions], max_norm=self.action_grad_norm, norm_type=2)

            actions_optim.step()

            best_actions.requires_grad_(False)
            best_actions.clamp_(-1., 1.)

            # if self.step % 10 == 0:
            #     log_writer.add_scalar('Action Grad Norm', actions_grad_norms.max().item(), self.step)

        best_actions = best_actions.detach()
        self.diffusion_memory.replace(idxs, best_actions.cpu().numpy())

        return states, best_actions

    def train(self, iterations, batch_size=256, log_writer=None):
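        """Run `iterations` update steps. Each step: (1) TD-update the twin
        critics against the target networks, (2) regress the actor onto the
        gradient-refined actions from `action_gradient`, (3) Polyak-average
        the critic target every step and the actor target every
        `update_actor_target_every` steps."""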
        for _ in range(iterations):
            # Sample replay buffer / batch
            states, actions, rewards, next_states, masks = self.memory.sample(batch_size)

            """ Q Training """
            current_q1, current_q2 = self.critic(states, actions)

            next_actions = self.actor_target(next_states, eval=False)
            target_q1, target_q2 = self.critic_target(next_states, next_actions)
            target_q = torch.min(target_q1, target_q2)

            target_q = (rewards + masks * target_q).detach()

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if self.ac_grad_norm > 0:
                critic_grad_norms = nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.ac_grad_norm, norm_type=2)
                # if self.step % 10 == 0:
                #     log_writer.add_scalar('Critic Grad Norm', critic_grad_norms.max().item(), self.step)
            self.critic_optimizer.step()

            """ Policy Training """
            states, best_actions = self.action_gradient(batch_size, log_writer)

            actor_loss = self.actor.loss(best_actions, states)

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            if self.ac_grad_norm > 0:
                actor_grad_norms = nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.ac_grad_norm, norm_type=2)
                # if self.step % 10 == 0:
                #     log_writer.add_scalar('Actor Grad Norm', actor_grad_norms.max().item(), self.step)
            self.actor_optimizer.step()

            """ Step Target network """
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            if self.step % self.update_actor_target_every == 0:
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            self.step += 1

    def save_model(self, dir, id=None):
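        """Save actor and critic weights under `dir`, optionally tagged with `id`."""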
        if id is not None:
            torch.save(self.actor.state_dict(), f'{dir}/actor_{id}.pth')
            torch.save(self.critic.state_dict(), f'{dir}/critic_{id}.pth')
        else:
            torch.save(self.actor.state_dict(), f'{dir}/actor.pth')
            torch.save(self.critic.state_dict(), f'{dir}/critic.pth')

    def load_model(self, dir, id=None):
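        """Load actor and critic weights saved by `save_model`."""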
        # map_location keeps checkpoints loadable when saved on another device.
        if id is not None:
            self.actor.load_state_dict(torch.load(f'{dir}/actor_{id}.pth', map_location=self.device))
            self.critic.load_state_dict(torch.load(f'{dir}/critic_{id}.pth', map_location=self.device))
        else:
            self.actor.load_state_dict(torch.load(f'{dir}/actor.pth', map_location=self.device))
            self.critic.load_state_dict(torch.load(f'{dir}/critic.pth', map_location=self.device))
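
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original training code):
# `env`, `args`, `total_steps` and the two replay buffers are assumptions
# about the surrounding script; the point is the call order of DiPo methods.
#
#   agent = DiPo(args, state_dim, env.action_space, memory, diffusion_memory, device)
#   state = env.reset()
#   for t in range(total_steps):
#       action = agent.sample_action(state)
#       next_state, reward, done, info = env.step(action)
#       # `mask` is the non-terminal indicator used in the TD target
#       # (rewards + masks * target_q); fold in the discount factor here if
#       # the buffer does not do so already.
#       agent.append_memory(state, action, reward, next_state, mask=0.0 if done else 1.0)
#       agent.train(iterations=1, batch_size=256)
#       state = env.reset() if done else next_state
# ---------------------------------------------------------------------------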