# DIPO/agent/DiPo.py
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from agent.model import MLP, Critic
from agent.diffusion import Diffusion
from agent.vae import VAE

class DiPo(object):
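    """DIPO agent (model-free RL with a generative policy).

    A twin-critic Q-learning agent whose actor (a diffusion model, a VAE, or a
    plain MLP, selected by args.policy_type) is trained by regression toward
    actions that have been improved by gradient ascent on the learned
    Q-function (the "action gradient" step).
    """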
def __init__(self,
args,
state_dim,
action_space,
memory,
diffusion_memory,
device,
):
        action_dim = int(np.prod(action_space.shape))
self.policy_type = args.policy_type
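        # Select the actor family: a diffusion policy, a VAE policy, or a plain MLP.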
if self.policy_type == 'Diffusion':
self.actor = Diffusion(state_dim=state_dim, action_dim=action_dim, noise_ratio=args.noise_ratio,
beta_schedule=args.beta_schedule, n_timesteps=args.n_timesteps).to(device)
elif self.policy_type == 'VAE':
self.actor = VAE(state_dim=state_dim, action_dim=action_dim, device=device).to(device)
else:
self.actor = MLP(state_dim=state_dim, action_dim=action_dim).to(device)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=args.diffusion_lr, eps=1e-5)
self.memory = memory
self.diffusion_memory = diffusion_memory
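        # Two buffers: `memory` is a standard transition replay buffer, while
        # `diffusion_memory` holds (state, action) pairs that action_gradient()
        # iteratively refines and that serve as regression targets for the actor.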
self.action_gradient_steps = args.action_gradient_steps
self.action_grad_norm = action_dim * args.ratio
self.ac_grad_norm = args.ac_grad_norm
self.step = 0
self.tau = args.tau
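        # The critic target is soft-updated every step; the actor target is only
        # refreshed every `update_actor_target_every` steps (see train()).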
self.actor_target = copy.deepcopy(self.actor)
self.update_actor_target_every = args.update_actor_target_every
self.critic = Critic(state_dim, action_dim).to(device)
self.critic_target = copy.deepcopy(self.critic)
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=args.critic_lr, eps=1e-5)
self.action_dim = action_dim
self.action_lr = args.action_lr
self.device = device
        if action_space is None:
            # Defensive default; in practice this branch is unreachable because
            # action_space.shape is already dereferenced above.
            self.action_scale = 1.
            self.action_bias = 0.
        else:
            # Affine map between the environment's action bounds and [-1, 1].
            self.action_scale = (action_space.high - action_space.low) / 2.
            self.action_bias = (action_space.high + action_space.low) / 2.
def append_memory(self, state, action, reward, next_state, mask):
action = (action - self.action_bias) / self.action_scale
self.memory.append(state, action, reward, next_state, mask)
self.diffusion_memory.append(state, action)
    def sample_action(self, state, eval=False):
        state = torch.as_tensor(state.reshape(1, -1), dtype=torch.float32, device=self.device)
        with torch.no_grad():
            action = self.actor(state, eval).cpu().numpy().flatten()
        action = action.clip(-1., 1.)
        # Rescale from the normalized [-1, 1] range back to the environment's bounds.
        return action * self.action_scale + self.action_bias
def action_gradient(self, batch_size, log_writer):
states, best_actions, idxs = self.diffusion_memory.sample(batch_size)
actions_optim = torch.optim.Adam([best_actions], lr=self.action_lr, eps=1e-5)
for i in range(self.action_gradient_steps):
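            # One ascent step on min(Q1, Q2): re-enable grads on the action
            # tensor, back-propagate all per-sample losses in one pass via
            # backward(torch.ones_like(loss)), optionally clip, then step.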
best_actions.requires_grad_(True)
q1, q2 = self.critic(states, best_actions)
loss = -torch.min(q1, q2)
actions_optim.zero_grad()
loss.backward(torch.ones_like(loss))
if self.action_grad_norm > 0:
actions_grad_norms = nn.utils.clip_grad_norm_([best_actions], max_norm=self.action_grad_norm, norm_type=2)
actions_optim.step()
best_actions.requires_grad_(False)
best_actions.clamp_(-1., 1.)
# if self.step % 10 == 0:
# log_writer.add_scalar('Action Grad Norm', actions_grad_norms.max().item(), self.step)
best_actions = best_actions.detach()
self.diffusion_memory.replace(idxs, best_actions.cpu().numpy())
return states, best_actions
def train(self, iterations, batch_size=256, log_writer=None):
for _ in range(iterations):
            # Sample a batch of transitions from the replay buffer
states, actions, rewards, next_states, masks = self.memory.sample(batch_size)
""" Q Training """
current_q1, current_q2 = self.critic(states, actions)
next_actions = self.actor_target(next_states, eval=False)
target_q1, target_q2 = self.critic_target(next_states, next_actions)
target_q = torch.min(target_q1, target_q2)
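            # NOTE: no explicit discount appears below; `masks` are assumed to
            # already fold in gamma (e.g. gamma * (1 - done)).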
target_q = (rewards + masks * target_q).detach()
critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
if self.ac_grad_norm > 0:
critic_grad_norms = nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.ac_grad_norm, norm_type=2)
# if self.step % 10 == 0:
# log_writer.add_scalar('Critic Grad Norm', critic_grad_norms.max().item(), self.step)
self.critic_optimizer.step()
""" Policy Training """
states, best_actions = self.action_gradient(batch_size, log_writer)
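            # The actor's own loss (e.g. a denoising objective for the diffusion
            # policy) regresses the policy toward the refined best_actions.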
actor_loss = self.actor.loss(best_actions, states)
self.actor_optimizer.zero_grad()
actor_loss.backward()
if self.ac_grad_norm > 0:
actor_grad_norms = nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.ac_grad_norm, norm_type=2)
# if self.step % 10 == 0:
# log_writer.add_scalar('Actor Grad Norm', actor_grad_norms.max().item(), self.step)
self.actor_optimizer.step()
""" Step Target network """
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
if self.step % self.update_actor_target_every == 0:
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
self.step += 1
def save_model(self, dir, id=None):
if id is not None:
torch.save(self.actor.state_dict(), f'{dir}/actor_{id}.pth')
torch.save(self.critic.state_dict(), f'{dir}/critic_{id}.pth')
else:
torch.save(self.actor.state_dict(), f'{dir}/actor.pth')
torch.save(self.critic.state_dict(), f'{dir}/critic.pth')
    def load_model(self, dir, id=None):
        # map_location keeps checkpoints loadable across CPU/GPU machines.
        if id is not None:
            self.actor.load_state_dict(torch.load(f'{dir}/actor_{id}.pth', map_location=self.device))
            self.critic.load_state_dict(torch.load(f'{dir}/critic_{id}.pth', map_location=self.device))
        else:
            self.actor.load_state_dict(torch.load(f'{dir}/actor.pth', map_location=self.device))
            self.critic.load_state_dict(torch.load(f'{dir}/critic.pth', map_location=self.device))
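# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The Namespace fields mirror exactly the
# `args` attributes consumed in __init__ above; the hyperparameter values and
# the `ReplayMemory` / `DiffusionMemory` names are placeholders for whatever
# buffer classes the surrounding project provides.
#
#   from types import SimpleNamespace
#   import gym
#
#   env = gym.make('Hopper-v3')
#   args = SimpleNamespace(
#       policy_type='Diffusion', beta_schedule='linear', n_timesteps=100,
#       noise_ratio=1.0, diffusion_lr=3e-4, critic_lr=3e-4, action_lr=3e-2,
#       action_gradient_steps=20, ratio=0.1, ac_grad_norm=2.0, tau=0.005,
#       update_actor_target_every=1)
#   agent = DiPo(args, env.observation_space.shape[0], env.action_space,
#                ReplayMemory(...), DiffusionMemory(...), device='cuda')
#   agent.train(iterations=1, batch_size=256)
# ---------------------------------------------------------------------------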