|
import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

from agent.model import MLP, Critic
from agent.diffusion import Diffusion
from agent.vae import VAE
from agent.helpers import EMA
|
class DiPo(object):
    def __init__(self,
                 args,
                 state_dim,
                 action_space,
                 memory,
                 diffusion_memory,
                 device,
                 ):
        action_dim = np.prod(action_space.shape)

        # Choose the actor parameterization: a diffusion model, a VAE, or a plain MLP.
        self.policy_type = args.policy_type
        if self.policy_type == 'Diffusion':
            self.actor = Diffusion(state_dim=state_dim, action_dim=action_dim, noise_ratio=args.noise_ratio,
                                   beta_schedule=args.beta_schedule, n_timesteps=args.n_timesteps).to(device)
        elif self.policy_type == 'VAE':
            self.actor = VAE(state_dim=state_dim, action_dim=action_dim, device=device).to(device)
        else:
            self.actor = MLP(state_dim=state_dim, action_dim=action_dim).to(device)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=args.diffusion_lr, eps=1e-5)

        self.memory = memory
        self.diffusion_memory = diffusion_memory
        self.action_gradient_steps = args.action_gradient_steps

        self.action_grad_norm = action_dim * args.ratio
        self.ac_grad_norm = args.ac_grad_norm

        self.step = 0
        self.tau = args.tau
        self.actor_target = copy.deepcopy(self.actor)
        self.update_actor_target_every = args.update_actor_target_every

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=args.critic_lr, eps=1e-5)

        self.action_dim = action_dim
        self.action_lr = args.action_lr
        self.device = device

        # Affine map between normalized actions in [-1, 1] and the environment's range:
        # env_action = normalized_action * action_scale + action_bias.
        if action_space is None:
            self.action_scale = 1.
            self.action_bias = 0.
        else:
            self.action_scale = (action_space.high - action_space.low) / 2.
            self.action_bias = (action_space.high + action_space.low) / 2.
|
    def append_memory(self, state, action, reward, next_state, mask):
        # Store actions in normalized [-1, 1] coordinates.
        action = (action - self.action_bias) / self.action_scale

        self.memory.append(state, action, reward, next_state, mask)
        self.diffusion_memory.append(state, action)
|
    def sample_action(self, state, eval=False):
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)

        action = self.actor(state, eval).cpu().data.numpy().flatten()
        action = action.clip(-1, 1)
        action = action * self.action_scale + self.action_bias
        return action
|
    def action_gradient(self, batch_size, log_writer):
        # Improve the buffered actions by gradient ascent on the conservative critic value min(Q1, Q2).
        states, best_actions, idxs = self.diffusion_memory.sample(batch_size)

        actions_optim = torch.optim.Adam([best_actions], lr=self.action_lr, eps=1e-5)

        for i in range(self.action_gradient_steps):
            best_actions.requires_grad_(True)
            q1, q2 = self.critic(states, best_actions)
            loss = -torch.min(q1, q2)

            actions_optim.zero_grad()

            loss.backward(torch.ones_like(loss))
            if self.action_grad_norm > 0:
                actions_grad_norms = nn.utils.clip_grad_norm_([best_actions], max_norm=self.action_grad_norm, norm_type=2)

            actions_optim.step()

            best_actions.requires_grad_(False)
            best_actions.clamp_(-1., 1.)

        best_actions = best_actions.detach()

        # Write the improved actions back into the diffusion buffer so the actor learns to imitate them.
        self.diffusion_memory.replace(idxs, best_actions.cpu().numpy())

        return states, best_actions
|
    def train(self, iterations, batch_size=256, log_writer=None):
        for _ in range(iterations):
            states, actions, rewards, next_states, masks = self.memory.sample(batch_size)

            """ Q Training """
            current_q1, current_q2 = self.critic(states, actions)

            next_actions = self.actor_target(next_states, eval=False)
            target_q1, target_q2 = self.critic_target(next_states, next_actions)
            target_q = torch.min(target_q1, target_q2)

            target_q = (rewards + masks * target_q).detach()

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if self.ac_grad_norm > 0:
                critic_grad_norms = nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.ac_grad_norm, norm_type=2)
            self.critic_optimizer.step()

            """ Policy Training """
            states, best_actions = self.action_gradient(batch_size, log_writer)

            actor_loss = self.actor.loss(best_actions, states)

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            if self.ac_grad_norm > 0:
                actor_grad_norms = nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.ac_grad_norm, norm_type=2)
            self.actor_optimizer.step()

            """ Step Target network """
            # Polyak averaging: target <- tau * online + (1 - tau) * target.
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            if self.step % self.update_actor_target_every == 0:
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            self.step += 1
|
    def save_model(self, dir, id=None):
        if id is not None:
            torch.save(self.actor.state_dict(), f'{dir}/actor_{id}.pth')
            torch.save(self.critic.state_dict(), f'{dir}/critic_{id}.pth')
        else:
            torch.save(self.actor.state_dict(), f'{dir}/actor.pth')
            torch.save(self.critic.state_dict(), f'{dir}/critic.pth')
|
    def load_model(self, dir, id=None):
        # map_location keeps checkpoints loadable when the current device differs from the saving device.
        if id is not None:
            self.actor.load_state_dict(torch.load(f'{dir}/actor_{id}.pth', map_location=self.device))
            self.critic.load_state_dict(torch.load(f'{dir}/critic_{id}.pth', map_location=self.device))
        else:
            self.actor.load_state_dict(torch.load(f'{dir}/actor.pth', map_location=self.device))
            self.critic.load_state_dict(torch.load(f'{dir}/critic.pth', map_location=self.device))
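# Example usage (a minimal sketch, not executed here). It assumes the repo's replay-buffer
# classes for `memory` / `diffusion_memory`, an `args` namespace exposing the fields this
# class reads (policy_type, diffusion_lr, critic_lr, action_lr, action_gradient_steps,
# ratio, ac_grad_norm, tau, update_actor_target_every, noise_ratio, beta_schedule,
# n_timesteps), the classic gym step API, and mask = 1 for non-terminal transitions;
# none of these names are defined in this file.
#
#     agent = DiPo(args, state_dim, env.action_space, memory, diffusion_memory, device)
#     state = env.reset()
#     action = agent.sample_action(state)
#     next_state, reward, done, _ = env.step(action)
#     agent.append_memory(state, action, reward, next_state, mask=0. if done else 1.)
#     agent.train(iterations=1, batch_size=256)
#     agent.save_model('checkpoints')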
|
|
|
|