zjowowen's picture
init space
3dfe8fb
from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple, deque
import torch
import copy
from ding.torch_utils import Adam, to_device
from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, get_gae_with_default_last_value, \
v_nstep_td_data, v_nstep_td_error, get_nstep_return_data, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY, deep_merge_dicts
from ding.utils.data import default_collate, default_decollate
from ding.policy.base_policy import Policy
from ding.policy.common_utils import default_preprocess_learn
from ding.policy.command_mode_policy_instance import DummyCommandModePolicy
@POLICY_REGISTRY.register('ppo_lstm')
class PPOPolicy(Policy):
r"""
Overview:
Policy class of PPO algorithm.
"""
config = dict(
# (str) RL policy register name (refer to function "POLICY_REGISTRY").
type='ppo_lstm',
# (bool) Whether to use cuda for network.
cuda=False,
# (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be off-policy used)
on_policy=True,
# (bool) Whether to use priority(priority sample, IS weight, update priority)
priority=False,
# (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
priority_IS_weight=False,
# (bool) Whether to use nstep_return for value loss
nstep_return=False,
nstep=3,
learn=dict(
# How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
update_per_collect=5,
batch_size=64,
learning_rate=0.001,
# ==============================================================
# The following configs is algorithm-specific
# ==============================================================
# (float) The loss weight of value network, policy network weight is set to 1
value_weight=0.5,
# (float) The loss weight of entropy regularization, policy network weight is set to 1
entropy_weight=0.01,
# (float) PPO clip ratio, defaults to 0.2
clip_ratio=0.2,
# (bool) Whether to use advantage norm in a whole training batch
adv_norm=False,
ignore_done=False,
),
collect=dict(
# (int) Only one of [n_sample, n_episode] shoule be set
# n_sample=64,
# (int) Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
# ==============================================================
# The following configs is algorithm-specific
# ==============================================================
# (float) Reward's future discount factor, aka. gamma.
discount_factor=0.99,
# (float) GAE lambda factor for the balance of bias and variance(1-step td and mc)
gae_lambda=0.95,
),
eval=dict(),
# Although ppo is an on-policy algorithm, ding reuses the buffer mechanism, and clear buffer after update.
# Note replay_buffer_size must be greater than n_sample.
other=dict(replay_buffer=dict(replay_buffer_size=1000, ), ),
)
def _init_learn(self) -> None:
r"""
Overview:
Learn mode init method. Called by ``self.__init__``.
Init the optimizer, algorithm config and the main model.
"""
self._priority = self._cfg.priority
self._priority_IS_weight = self._cfg.priority_IS_weight
assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPO"
# Orthogonal init
for m in self._model.modules():
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.orthogonal_(m.weight)
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
# Optimizer
self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
self._learn_model = model_wrap(self._model, wrapper_name='base')
# self._learn_model = model_wrap(self._learn_model, wrapper_name='hidden_state', state_num=self._cfg.learn.batch_size)
# Algorithm config
self._value_weight = self._cfg.learn.value_weight
self._entropy_weight = self._cfg.learn.entropy_weight
self._clip_ratio = self._cfg.learn.clip_ratio
self._adv_norm = self._cfg.learn.adv_norm
self._nstep = self._cfg.nstep
self._nstep_return = self._cfg.nstep_return
# Main model
self._learn_model.reset()
def _forward_learn(self, data: dict) -> Dict[str, Any]:
r"""
Overview:
Forward and backward function of learn mode.
Arguments:
- data (:obj:`dict`): Dict type data
Returns:
- info_dict (:obj:`Dict[str, Any]`):
Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
adv_abs_max, approx_kl, clipfrac
"""
data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=self._nstep_return)
if self._cuda:
data = to_device(data, self._device)
# ====================
# PPO forward
# ====================
self._learn_model.train()
# normal ppo
if not self._nstep_return:
output = self._learn_model.forward(data['obs'])
adv = data['adv']
if self._adv_norm:
# Normalize advantage in a total train_batch
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
return_ = data['value'] + adv
# Calculate ppo error
ppodata = ppo_data(
output['logit'], data['logit'], data['action'], output['value'], data['value'], adv, return_,
data['weight']
)
ppo_loss, ppo_info = ppo_error(ppodata, self._clip_ratio)
wv, we = self._value_weight, self._entropy_weight
total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
else:
output = self._learn_model.forward(data['obs'])
adv = data['adv']
if self._adv_norm:
# Normalize advantage in a total train_batch
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
# Calculate ppo error
ppodata = ppo_policy_data(output['logit'], data['logit'], data['action'], adv, data['weight'])
ppo_policy_loss, ppo_info = ppo_policy_error(ppodata, self._clip_ratio)
wv, we = self._value_weight, self._entropy_weight
next_obs = data.get('next_obs')
value_gamma = data.get('value_gamma')
reward = data.get('reward')
# current value
value = self._learn_model.forward(data['obs'])
# target value
next_data = {'obs': next_obs}
target_value = self._learn_model.forward(next_data['obs'])
# TODO what should we do here to keep shape
assert self._nstep > 1
td_data = v_nstep_td_data(
value['value'], target_value['value'], reward.t(), data['done'], data['weight'], value_gamma
)
#calculate v_nstep_td critic_loss
critic_loss, td_error_per_sample = v_nstep_td_error(td_data, self._gamma, self._nstep)
ppo_loss_data = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
ppo_loss = ppo_loss_data(ppo_policy_loss.policy_loss, critic_loss, ppo_policy_loss.entropy_loss)
total_loss = ppo_policy_loss.policy_loss + wv * critic_loss - we * ppo_policy_loss.entropy_loss
# ====================
# PPO update
# ====================
self._optimizer.zero_grad()
total_loss.backward()
self._optimizer.step()
return {
'cur_lr': self._optimizer.defaults['lr'],
'total_loss': total_loss.item(),
'policy_loss': ppo_loss.policy_loss.item(),
'value_loss': ppo_loss.value_loss.item(),
'entropy_loss': ppo_loss.entropy_loss.item(),
'adv_abs_max': adv.abs().max().item(),
'approx_kl': ppo_info.approx_kl,
'clipfrac': ppo_info.clipfrac,
}
def _state_dict_learn(self) -> Dict[str, Any]:
return {
'model': self._learn_model.state_dict(),
'optimizer': self._optimizer.state_dict(),
}
def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
self._learn_model.load_state_dict(state_dict['model'])
self._optimizer.load_state_dict(state_dict['optimizer'])
def _init_collect(self) -> None:
r"""
Overview:
Collect mode init method. Called by ``self.__init__``.
Init traj and unroll length, collect model.
"""
self._unroll_len = self._cfg.collect.unroll_len
self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
# self._collect_model = model_wrap(
# self._collect_model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
# )
self._collect_model.reset()
self._gamma = self._cfg.collect.discount_factor
self._gae_lambda = self._cfg.collect.gae_lambda
self._nstep = self._cfg.nstep
self._nstep_return = self._cfg.nstep_return
def _forward_collect(self, data: dict) -> dict:
r"""
Overview:
Forward function of collect mode.
Arguments:
- data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
Returns:
- output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
ReturnsKeys
- necessary: ``action``
"""
data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)
self._collect_model.eval()
with torch.no_grad():
output = self._collect_model.forward(data)
if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}
def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
"""
Overview:
Generate dict type transition data from inputs.
Arguments:
- obs (:obj:`Any`): Env observation
- model_output (:obj:`dict`): Output of collect model, including at least ['action']
- timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\
(here 'obs' indicates obs after env step).
Returns:
- transition (:obj:`dict`): Dict type transition data.
"""
if not self._nstep_return:
transition = {
'obs': obs,
'logit': model_output['logit'],
'action': model_output['action'],
'value': model_output['value'],
'prev_state': model_output['prev_state'],
'reward': timestep.reward,
'done': timestep.done,
}
else:
transition = {
'obs': obs,
'next_obs': timestep.obs,
'logit': model_output['logit'],
'action': model_output['action'],
'prev_state': model_output['prev_state'],
'value': model_output['value'],
'reward': timestep.reward,
'done': timestep.done,
}
return transition
def _get_train_sample(self, data: deque) -> Union[None, List[Any]]:
r"""
Overview:
Get the trajectory and calculate GAE, return one data to cache for next time calculation
Arguments:
- data (:obj:`deque`): The trajectory's cache
Returns:
- samples (:obj:`dict`): The training samples generated
"""
data = get_gae_with_default_last_value(
data,
data[-1]['done'],
gamma=self._gamma,
gae_lambda=self._gae_lambda,
cuda=self._cuda,
)
if not self._nstep_return:
return get_train_sample(data, self._unroll_len)
else:
return get_nstep_return_data(data, self._nstep)
def _init_eval(self) -> None:
r"""
Overview:
Evaluate mode init method. Called by ``self.__init__``.
Init eval model with argmax strategy.
"""
self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
# self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num)
self._eval_model.reset()
def _forward_eval(self, data: dict) -> dict:
r"""
Overview:
Forward function of eval mode, similar to ``self._forward_collect``.
Arguments:
- data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
Returns:
- output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
ReturnsKeys
- necessary: ``action``
"""
data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)
# data = {'obs': data}
self._eval_model.eval()
with torch.no_grad():
output = self._eval_model.forward(data[0])
if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}
def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
self._eval_model.reset(data_id=data_id)
def default_model(self) -> Tuple[str, List[str]]:
return 'vac', ['ding.model.template.vac']
def _monitor_vars_learn(self) -> List[str]:
return super()._monitor_vars_learn() + [
'policy_loss', 'value_loss', 'entropy_loss', 'adv_abs_max', 'approx_kl', 'clipfrac'
]
@POLICY_REGISTRY.register('ppo_lstm_command')
class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
pass