from typing import List, Dict, Any, Tuple, Union

import torch

from ding.policy import PPOPolicy, PPOOffPolicy
from ding.rl_utils import ppo_data, ppo_error, gae, gae_data
from ding.utils import POLICY_REGISTRY, split_data_generator
from ding.torch_utils import to_device
from ding.policy.common_utils import default_preprocess_learn


class MultiDiscretePPOPolicy(PPOPolicy):
    r"""
    Overview:
        Policy class of multi-discrete action space PPO algorithm.
    """

    def _forward_learn(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): Dict type data
        Returns:
            - info_dict (:obj:`List[Dict[str, Any]]`): A list of info dicts, one per mini-batch update, \
                including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
                adv_max, adv_mean, value_max, value_mean, approx_kl, clipfrac.
        """
        data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # PPO forward
        # ====================
        return_infos = []
        self._learn_model.train()

        for epoch in range(self._cfg.learn.epoch_per_collect):
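            # Optionally recompute values and GAE advantages with the current critic at the start of each
            # epoch, so that the training targets track the updated value function.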
            if self._recompute_adv:
                with torch.no_grad():
                    value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
                    next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
                    if self._value_norm:
                        value *= self._running_mean_std.std
                        next_value *= self._running_mean_std.std

                    compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], data['traj_flag'])
                    # GAE needs (T, B) shaped input and returns (T, B) shaped output.
                    data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)

                    # value = value[:-1]
                    unnormalized_returns = value + data['adv']
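                    # With value normalization enabled, store value and return targets in the normalized
                    # scale and update the running statistics from the unnormalized returns.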
                    if self._value_norm:
                        data['value'] = value / self._running_mean_std.std
                        data['return'] = unnormalized_returns / self._running_mean_std.std
                        self._running_mean_std.update(unnormalized_returns.cpu().numpy())
                    else:
                        data['value'] = value
                        data['return'] = unnormalized_returns
            else:  # don't recompute adv
                if self._value_norm:
                    unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
                    data['return'] = unnormalized_return / self._running_mean_std.std
                    self._running_mean_std.update(unnormalized_return.cpu().numpy())
                else:
                    data['return'] = data['adv'] + data['value']
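
            # Split the whole collected batch into shuffled mini-batches and run one PPO update on each.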
            for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
                adv = batch['adv']
                if self._adv_norm:
                    # Normalize advantage in a train_batch
                    adv = (adv - adv.mean()) / (adv.std() + 1e-8)
                # Calculate ppo error
                loss_list = []
                info_list = []
                action_num = len(batch['action'])
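                # In a multi-discrete action space, `logit` and `action` are lists with one entry per
                # action branch, so the PPO error is computed separately for each branch.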
                for i in range(action_num):
                    ppo_batch = ppo_data(
                        output['logit'][i], batch['logit'][i], batch['action'][i], output['value'], batch['value'],
                        adv, batch['return'], batch['weight']
                    )
                    ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio)
                    loss_list.append(ppo_loss)
                    info_list.append(ppo_info)
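                # Average the per-branch losses and statistics so that every action dimension
                # contributes equally to the update.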
                avg_policy_loss = sum([item.policy_loss for item in loss_list]) / action_num
                avg_value_loss = sum([item.value_loss for item in loss_list]) / action_num
                avg_entropy_loss = sum([item.entropy_loss for item in loss_list]) / action_num
                avg_approx_kl = sum([item.approx_kl for item in info_list]) / action_num
                avg_clipfrac = sum([item.clipfrac for item in info_list]) / action_num
                wv, we = self._value_weight, self._entropy_weight
                total_loss = avg_policy_loss + wv * avg_value_loss - we * avg_entropy_loss

                self._optimizer.zero_grad()
                total_loss.backward()
                self._optimizer.step()

                return_info = {
                    'cur_lr': self._optimizer.defaults['lr'],
                    'total_loss': total_loss.item(),
                    'policy_loss': avg_policy_loss.item(),
                    'value_loss': avg_value_loss.item(),
                    'entropy_loss': avg_entropy_loss.item(),
                    'adv_max': adv.max().item(),
                    'adv_mean': adv.mean().item(),
                    'value_mean': output['value'].mean().item(),
                    'value_max': output['value'].max().item(),
                    'approx_kl': avg_approx_kl,
                    'clipfrac': avg_clipfrac,
                }
                return_infos.append(return_info)
        return return_infos


class MultiDiscretePPOOffPolicy(PPOOffPolicy):
    r"""
    Overview:
        Policy class of multi-discrete action space off-policy PPO algorithm.
    """

    def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): Dict type data
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): \
                Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
                adv_abs_max, approx_kl, clipfrac.
        """
        assert not self._nstep_return
        data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=self._nstep_return)
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # PPO forward
        # ====================
        self._learn_model.train()
        # normal ppo
        output = self._learn_model.forward(data['obs'], mode='compute_actor_critic')
        adv = data['adv']
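        # The return target is rebuilt from the stored value and advantage, since this off-policy
        # variant does not recompute GAE during learning.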
        return_ = data['value'] + adv
        if self._adv_norm:
            # Normalize advantage in a total train_batch
            adv = (adv - adv.mean()) / (adv.std() + 1e-8)
        # Calculate ppo error
        loss_list = []
        info_list = []
        action_num = len(data['action'])
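        # As in the on-policy variant, compute the PPO error for each discrete action branch and
        # average the results afterwards.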
        for i in range(action_num):
            ppodata = ppo_data(
                output['logit'][i], data['logit'][i], data['action'][i], output['value'], data['value'], adv,
                return_, data['weight']
            )
            ppo_loss, ppo_info = ppo_error(ppodata, self._clip_ratio)
            loss_list.append(ppo_loss)
            info_list.append(ppo_info)
        avg_policy_loss = sum([item.policy_loss for item in loss_list]) / action_num
        avg_value_loss = sum([item.value_loss for item in loss_list]) / action_num
        avg_entropy_loss = sum([item.entropy_loss for item in loss_list]) / action_num
        avg_approx_kl = sum([item.approx_kl for item in info_list]) / action_num
        avg_clipfrac = sum([item.clipfrac for item in info_list]) / action_num
        wv, we = self._value_weight, self._entropy_weight
        total_loss = avg_policy_loss + wv * avg_value_loss - we * avg_entropy_loss
        # ====================
        # PPO update
        # ====================
        self._optimizer.zero_grad()
        total_loss.backward()
        self._optimizer.step()
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': total_loss.item(),
            'policy_loss': avg_policy_loss.item(),
            'value_loss': avg_value_loss.item(),
            'entropy_loss': avg_entropy_loss.item(),
            'adv_abs_max': adv.abs().max().item(),
            'approx_kl': avg_approx_kl,
            'clipfrac': avg_clipfrac,
        }