Spaces:
Sleeping
Sleeping
from typing import List, Dict, Any, Tuple, Union | |
import copy | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.distributions import Normal, Independent | |
from ding.torch_utils import Adam, to_device | |
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ | |
qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data | |
from ding.model import model_wrap | |
from ding.utils import POLICY_REGISTRY | |
from ding.utils.data import default_collate, default_decollate | |
from .sac import SACPolicy | |
from .dqn import DQNPolicy | |
from .common_utils import default_preprocess_learn | |
# NOTE(review): ``POLICY_REGISTRY`` is imported in this file but never used, so this class is not registered
# under ``type='edac'`` here. Upstream DI-engine decorates it with ``@POLICY_REGISTRY.register('edac')`` --
# confirm whether registration happens elsewhere before relying on the registry lookup.
class EDACPolicy(SACPolicy):
    """
    Overview:
        Policy class of the EDAC algorithm, built on top of ``SACPolicy``. https://arxiv.org/pdf/2110.01548.pdf
        Compared with SAC, it keeps an ensemble of ``model.ensemble_num`` Q-networks, uses the ensemble minimum
        both in the TD target and in the policy objective, and (when ``learn.eta > 0``) adds a penalty on the
        pairwise similarity of the ensemble members' action-gradients to the critic loss.
    Config:
        == ==================== ======== ============= ================================= =======================
        ID Symbol               Type     Default Value Description                       Other(Shape)
        == ==================== ======== ============= ================================= =======================
        1  ``type``             str      edac          | RL policy register name, refer  | this arg is optional,
                                                       | to registry ``POLICY_REGISTRY`` | a placeholder
        2  ``cuda``             bool     False         | Whether to use cuda for network
        3  ``random_``          int      10000         | Number of randomly collected    | Default to 10000 for
           ``collect_size``                            | training samples in replay      | SAC, 25000 for DDPG/
                                                       | buffer when training starts.    | TD3.
        4  ``model.actor_``     int      256           | Hidden size of the actor head.
           ``head_hidden_size``
        5  ``model.critic_``    int      256           | Hidden size of the critic head.
           ``head_hidden_size``
        6  ``model.ensemble_``  int      10            | Number of Q-networks in the
           ``num``                                     | ensemble.
        7  ``learn.learning``   float    3e-4          | Learning rate for soft q
           ``_rate_q``                                 | network.
        8  ``learn.learning``   float    3e-4          | Learning rate for policy
           ``_rate_policy``                            | network.
        9  ``learn.learning``   float    3e-4          | Learning rate for value
           ``_rate_value``                             | network.
        10 ``learn.alpha``      float    1.0           | Entropy regularization          | alpha is initiali-
                                                       | coefficient.                    | zation for auto
                                                                                         | `alpha`, when
                                                                                         | auto_alpha is True
        11 ``learn.eta``        float    0.1           | Weight of the EDAC ensemble
                                                       | gradient diversity penalty.
        12 ``learn.``           bool     True          | Determine whether to use        | Temperature parameter
           ``auto_alpha``                              | auto temperature parameter      | determines the
                                                       | `alpha`.                        | relative importance
                                                                                         | of the entropy term
                                                                                         | against the reward.
        13 ``learn.-``          bool     False         | Determine whether to ignore     | Use ignore_done only
           ``ignore_done``                             | done flag.                      | in halfcheetah env.
        14 ``learn.-``          float    0.005         | Used for soft update of the     | aka. Interpolation
           ``target_theta``                            | target network.                 | factor in polyak aver
                                                                                         | aging for target
                                                                                         | networks.
        15 ``learn.with_``      bool     False         | Whether to subtract the entropy
           ``q_entropy``                               | term in the target q value.
        == ==================== ======== ============= ================================= =======================
    """
    config = dict(
        # (str) RL policy register name
        type='edac',
        cuda=False,
        on_policy=False,
        multi_agent=False,
        priority=False,
        priority_IS_weight=False,
        random_collect_size=10000,
        model=dict(
            # (int) ensemble_num: Number of Q-networks in the ensemble.
            ensemble_num=10,
            # (bool) value_network: Determine whether to use a value network as in the
            # original EDAC paper (arXiv 2110.01548).
            # Using value_network needs learning_rate_value, learning_rate_q
            # and learning_rate_policy to be set in `cfg.policy.learn`.
            # Default to False.
            # value_network=False,
            # (int) Hidden size for actor network head.
            actor_head_hidden_size=256,
            # (int) Hidden size for critic network head.
            critic_head_hidden_size=256,
        ),
        learn=dict(
            multi_gpu=False,
            update_per_collect=1,
            batch_size=256,
            learning_rate_q=3e-4,
            learning_rate_policy=3e-4,
            learning_rate_value=3e-4,
            learning_rate_alpha=3e-4,
            target_theta=0.005,
            discount_factor=0.99,
            alpha=1,
            auto_alpha=True,
            # (bool) log_space: Determine whether to optimize auto `\alpha` in log space.
            log_space=True,
            # (bool) Whether to ignore the done flag (usually for envs that only end at max steps, e.g. pendulum).
            # Note: Gym wraps the MuJoCo envs with TimeLimit wrappers by default, which cap HalfCheetah (and
            # several other MuJoCo envs) at 1000 steps. HalfCheetah never terminates by itself, so a done==True
            # produced only by the time limit is replaced with done==False to keep the TD target
            # (``gamma * (1 - done) * next_v + reward``) accurate once the max episode step is reached.
            ignore_done=False,
            # (float) Weight uniform initialization range in the last output layer.
            init_w=3e-3,
            # (float) Loss weight for the conservative term (not read by this class).
            min_q_weight=1.0,
            # (bool) Whether to subtract the entropy term ``alpha * log_prob`` from the target q value.
            with_q_entropy=False,
            # (float) Weight of the ensemble gradient diversity penalty; a value <= 0 disables it.
            eta=0.1,
        ),
        collect=dict(
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),
        other=dict(
            replay_buffer=dict(
                # (int) replay_buffer_size: Max size of replay buffer.
                replay_buffer_size=1000000,
                # (int) max_use: Max use times of one data in the buffer.
                # Data will be removed once used for too many times.
                # Default to infinite.
                # max_use=256,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
                automatically call this method to get the default model setting and create model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
        """
        return 'edac', ['ding.model.template.edac']

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Runs ``SACPolicy._init_learn`` (optimizers, algorithm config, main and target models), then reads the
            EDAC-specific settings ``learn.eta`` and ``learn.with_q_entropy`` and resets the update counter.
        """
        super()._init_learn()
        # EDAC special implementation
        self._eta = self._cfg.learn.eta
        self._with_q_entropy = self._cfg.learn.with_q_entropy
        self._forward_learn_cnt = 0

    def _sample_tanh_gaussian(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Draw a reparameterized sample from the actor's diagonal Gaussian, squash it with ``tanh`` and compute
            the log-probability of the squashed action (change-of-variables correction for ``tanh``).
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations fed to the actor.
        Returns:
            - action (:obj:`torch.Tensor`): ``tanh`` of the sampled pre-activation.
            - log_prob (:obj:`torch.Tensor`): Log-probability of ``action`` with a trailing singleton dim.
        """
        mu, sigma = self._learn_model.forward(obs, mode='compute_actor')['logit']
        dist = Independent(Normal(mu, sigma), 1)
        pred = dist.rsample()
        action = torch.tanh(pred)
        # log|d tanh(u)/du| = log(1 - tanh(u)^2); the 1e-6 keeps log() finite when |action| -> 1.
        y = 1 - action.pow(2) + 1e-6
        log_prob = dist.log_prob(pred).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
        return action, log_prob

    def _ensemble_gradient_penalty(self, obs: torch.Tensor, acs: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            EDAC diversity penalty: the mean (over the batch) of the summed pairwise inner products between the
            L2-normalized action-gradients of the ensemble members' Q values, excluding self-pairs, divided by
            ``ensemble_num - 1``. Differentiable w.r.t. the critic parameters (built with ``create_graph=True``).
            Requires ``ensemble_num > 1``.
        Arguments:
            - obs (:obj:`torch.Tensor`): Batch of observations, ``[batch_size, obs_dim]``.
            - acs (:obj:`torch.Tensor`): Batch of actions, ``[batch_size, act_dim]``.
        Returns:
            - penalty (:obj:`torch.Tensor`): Scalar penalty term.
        """
        ensemble_num = self._cfg.model.ensemble_num
        # [batch_size, dim] -> [ensemble_num, batch_size, dim]
        pre_obs = obs.unsqueeze(0).repeat_interleave(ensemble_num, dim=0)
        pre_acs = acs.unsqueeze(0).repeat_interleave(ensemble_num, dim=0).requires_grad_(True)
        # [ensemble_num, batch_size]
        q_pred_tile = self._learn_model.forward({'obs': pre_obs, 'action': pre_acs}, mode='compute_critic')['q_value']
        # create_graph=True so the penalty can itself be back-propagated into the critic parameters.
        q_pred_grads = torch.autograd.grad(q_pred_tile.sum(), pre_acs, retain_graph=True, create_graph=True)[0]
        q_pred_grads = q_pred_grads / (torch.norm(q_pred_grads, p=2, dim=2).unsqueeze(-1) + 1e-10)
        # [ensemble_num, batch_size, act_dim] -> [batch_size, ensemble_num, act_dim]
        q_pred_grads = q_pred_grads.transpose(0, 1)
        # pairwise inner products: [batch_size, ensemble_num, ensemble_num]
        q_pred_grads = q_pred_grads @ q_pred_grads.permute(0, 2, 1)
        # drop the diagonal (each member's similarity with itself)
        masks = torch.eye(ensemble_num, device=obs.device).unsqueeze(dim=0).repeat(q_pred_grads.size(0), 1, 1)
        q_pred_grads = (1 - masks) * q_pred_grads
        return torch.mean(torch.sum(q_pred_grads, dim=(1, 2))) / (ensemble_num - 1)

    def _update_temperature(self, log_prob: torch.Tensor, loss_dict: Dict[str, Any]) -> None:
        """
        Overview:
            When ``auto_alpha`` is enabled, take one optimizer step on the entropy temperature and store the
            resulting loss under ``loss_dict['alpha_loss']``. Does nothing when ``auto_alpha`` is disabled.
        Arguments:
            - log_prob (:obj:`torch.Tensor`): Log-probabilities of freshly sampled actions (treated as constant).
            - loss_dict (:obj:`Dict[str, Any]`): Loss dict to update in place.
        """
        if not self._auto_alpha:
            return
        log_prob = log_prob + self._target_entropy
        if self._log_space:
            loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean()
            self._alpha_optim.zero_grad()
            loss_dict['alpha_loss'].backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()
        else:
            loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean()
            self._alpha_optim.zero_grad()
            loss_dict['alpha_loss'].backward()
            self._alpha_optim.step()
            # Clamp in place: the previous ``max(0, self._alpha)`` rebound ``_alpha`` to the int ``0`` whenever
            # alpha <= 0, which broke ``self._alpha.item()`` below and detached it from (presumably) the tensor
            # registered with ``_alpha_optim``.
            with torch.no_grad():
                self._alpha.clamp_(min=0)

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            One EDAC training iteration: update the Q-ensemble (TD loss plus optional diversity penalty), then the
            policy, then (optionally) the entropy temperature, and finally the target network.
        Arguments:
            - data (:obj:`dict`): Training batch; after ``default_preprocess_learn`` it must provide ``obs``, \
                ``next_obs``, ``action``, ``reward``, ``done`` and ``weight`` (``weight`` may be None).
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Learning rates, per-sample ``priority`` list, mean ``td_error``, \
                ``alpha``, mean ``target_q_value`` and all entries of the loss dict.
        """
        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        # 1-D discrete-like action batches are turned into [batch_size, 1].
        if len(data.get('action').shape) == 1:
            data['action'] = data['action'].reshape(-1, 1)

        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        self._target_model.train()
        obs = data['obs']
        next_obs = data['next_obs']
        reward = data['reward']
        done = data['done']
        acs = data['action']

        # 1. predict q value of every ensemble member: [ensemble_num, batch_size]
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']

        # 2. target q value (no gradient): pessimistic minimum over the target ensemble
        with torch.no_grad():
            next_action, next_log_prob = self._sample_tanh_gaussian(next_obs)
            next_data = {'obs': next_obs, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
            target_q_value, _ = torch.min(target_q_value, dim=0)
            if self._with_q_entropy:
                # the value of a policy according to the maximum entropy objective
                target_q_value -= self._alpha * next_log_prob.squeeze(-1)
            target_q_value = self._gamma * (1 - done) * target_q_value + reward

        # 3. critic loss: per-sample squared TD error summed over the ensemble, weighted over the batch.
        # Expanding the target explicitly avoids MSELoss's broadcasting warning.
        td_error_per_sample = nn.MSELoss(reduction='none')(q_value, target_q_value.expand_as(q_value)).sum(dim=0)
        weight = data['weight']
        if weight is None:
            weight = torch.ones_like(td_error_per_sample)
        loss_dict['critic_loss'] = (td_error_per_sample * weight).mean()

        # 4. EDAC diversity penalty; undefined (0/0) for a single critic, so it needs at least two members.
        if self._eta > 0 and self._cfg.model.ensemble_num > 1:
            loss_dict['critic_loss'] = loss_dict['critic_loss'] + self._ensemble_gradient_penalty(obs, acs) * self._eta

        self._optimizer_q.zero_grad()
        loss_dict['critic_loss'].backward()
        self._optimizer_q.step()

        # 5. policy loss: maximize the ensemble-minimum Q value plus entropy
        action, log_prob = self._sample_tanh_gaussian(obs)
        eval_data = {'obs': obs, 'action': action}
        new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
        new_q_value, _ = torch.min(new_q_value, dim=0)
        loss_dict['policy_loss'] = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()

        # 6. update policy network
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        self._optimizer_policy.step()

        # 7. entropy temperature
        self._update_temperature(log_prob, loss_dict)

        loss_dict['total_loss'] = sum(loss_dict.values())

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        # target update
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'priority': td_error_per_sample.detach().abs().tolist(),
            'td_error': td_error_per_sample.detach().mean().item(),
            'alpha': self._alpha.item(),
            'target_q_value': target_q_value.detach().mean().item(),
            **loss_dict
        }