from typing import List, Dict, Any, Tuple, Union, Callable, Optional
from collections import namedtuple
from easydict import EasyDict
import copy
import random
import numpy as np
import torch
import treetensor.torch as ttorch
from torch.optim import AdamW
from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, gae, gae_data, ppo_error_continuous, \
    get_gae, ppo_policy_error_continuous, ArgmaxSampler, MultinomialSampler, ReparameterizationSampler, MuSampler, \
    HybridStochasticSampler, HybridDeterminsticSampler, value_transform, value_inv_transform, symlog, inv_symlog
from ding.utils import POLICY_REGISTRY, RunningMeanStd
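
# PPOFPolicy: a self-contained PPO (Proximal Policy Optimization) policy built on
# DI-engine utilities, covering discrete, continuous and hybrid action spaces with
# popart / value_rescale / symlog / baseline value normalization.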
@POLICY_REGISTRY.register('ppof')
class PPOFPolicy:
    config = dict(
        type='ppo',
        on_policy=True,
        cuda=True,
        action_space='discrete',
        discount_factor=0.99,
        gae_lambda=0.95,
        # learn
        epoch_per_collect=10,
        batch_size=64,
        learning_rate=3e-4,
        # learning rate scheduler, given as a (epoch_num, min_lr_lambda) tuple, e.g. (10000, 0.1)
        lr_scheduler=None,
        weight_decay=0,
        value_weight=0.5,
        entropy_weight=0.01,
        clip_ratio=0.2,
        adv_norm=True,
        value_norm='baseline',
        ppo_param_init=True,
        grad_norm=0.5,
        # collect
        n_sample=128,
        unroll_len=1,
        # eval
        deterministic_eval=True,
        # model
        model=dict(),
    )
    mode = ['learn', 'collect', 'eval']

    @classmethod
    def default_config(cls: type) -> EasyDict:
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg
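
    # e.g. ``cfg = PPOFPolicy.default_config()`` returns an independent deep copy,
    # so callers can mutate it without touching the class-level defaults.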

    @classmethod
    def default_model(cls: type) -> Callable:
        from .model import PPOFModel
        return PPOFModel

    def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str] = None) -> None:
        self._cfg = cfg
        if model is None:
            self._model = self.default_model()
        else:
            self._model = model
        if self._cfg.cuda and torch.cuda.is_available():
            self._device = 'cuda'
            self._model.cuda()
        else:
            self._device = 'cpu'
        assert self._cfg.action_space in ['continuous', 'discrete', 'hybrid', 'multi_discrete']
        self._action_space = self._cfg.action_space
        if self._cfg.ppo_param_init:
            self._model_param_init()
        if enable_mode is None:
            enable_mode = self.mode
        self.enable_mode = enable_mode
        if 'learn' in enable_mode:
            self._optimizer = AdamW(
                self._model.parameters(),
                lr=self._cfg.learning_rate,
                weight_decay=self._cfg.weight_decay,
            )
            # define linear lr scheduler
            if self._cfg.lr_scheduler is not None:
                epoch_num, min_lr_lambda = self._cfg.lr_scheduler
                self._lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                    self._optimizer,
                    lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda)
                )
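            # with lr_scheduler=(epoch_num, min_lr_lambda), the multiplier decays
            # linearly from 1.0 down to min_lr_lambda over epoch_num scheduler steps
            # and is clamped there, e.g. (10000, 0.1) ends at 10% of the base lr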
            if self._cfg.value_norm:
                self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
        if 'collect' in enable_mode:
            if self._action_space == 'discrete':
                self._collect_sampler = MultinomialSampler()
            elif self._action_space == 'continuous':
                self._collect_sampler = ReparameterizationSampler()
            elif self._action_space == 'hybrid':
                self._collect_sampler = HybridStochasticSampler()
        if 'eval' in enable_mode:
            if self._action_space == 'discrete':
                if self._cfg.deterministic_eval:
                    self._eval_sampler = ArgmaxSampler()
                else:
                    self._eval_sampler = MultinomialSampler()
            elif self._action_space == 'continuous':
                if self._cfg.deterministic_eval:
                    self._eval_sampler = MuSampler()
                else:
                    self._eval_sampler = ReparameterizationSampler()
            elif self._action_space == 'hybrid':
                if self._cfg.deterministic_eval:
                    self._eval_sampler = HybridDeterminsticSampler()
                else:
                    self._eval_sampler = HybridStochasticSampler()
        # for compatibility
        self.learn_mode = self
        self.collect_mode = self
        self.eval_mode = self
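
    # Orthogonal weight init, plus a small-scale mu output layer and fixed initial
    # log sigma for continuous/hybrid heads, is the initialization scheme commonly
    # recommended for PPO; the ``ppo_param_init`` flag enables it.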
    def _model_param_init(self):
        for n, m in self._model.named_modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.orthogonal_(m.weight)
                torch.nn.init.zeros_(m.bias)
        if self._action_space in ['continuous', 'hybrid']:
            for m in list(self._model.critic.modules()) + list(self._model.actor.modules()):
                if isinstance(m, torch.nn.Linear):
                    # orthogonal initialization
                    torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                    torch.nn.init.zeros_(m.bias)
            # init log sigma
            if self._action_space == 'continuous':
                torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -0.5)
                for m in self._model.actor_head.mu.modules():
                    if isinstance(m, torch.nn.Linear):
                        torch.nn.init.zeros_(m.bias)
                        m.weight.data.copy_(0.01 * m.weight.data)
            elif self._action_space == 'hybrid':  # actor_head[1]: ReparameterizationHead, for action_args
                if hasattr(self._model.actor_head[1], 'log_sigma_param'):
                    torch.nn.init.constant_(self._model.actor_head[1].log_sigma_param, -0.5)
                for m in self._model.actor_head[1].mu.modules():
                    if isinstance(m, torch.nn.Linear):
                        torch.nn.init.zeros_(m.bias)
                        m.weight.data.copy_(0.01 * m.weight.data)
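
    # The learn step: each call runs ``epoch_per_collect`` epochs over the collected
    # batch; every epoch first recomputes values and advantages under
    # ``torch.no_grad()`` (outer loop), then performs shuffled minibatch PPO
    # updates (inner loop).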
    def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
        return_infos = []
        self._model.train()
        bs = self._cfg.batch_size
        data = data[:self._cfg.n_sample // bs * bs]  # round down to a multiple of batch_size
        # outer training loop
        for epoch in range(self._cfg.epoch_per_collect):
            # recompute adv
            with torch.no_grad():
                # get the value dictionary
                # In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred'
                value = self._model.compute_critic(data.obs)
                next_value = self._model.compute_critic(data.next_obs)
                reward = data.reward
                assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline'], \
                    'Unsupported value normalization! Supported types: popart, value_rescale, symlog, baseline'
                if self._cfg.value_norm == 'popart':
                    unnormalized_value = value['unnormalized_pred']
                    unnormalized_next_value = next_value['unnormalized_pred']
                    mu = self._model.critic_head.popart.mu
                    sigma = self._model.critic_head.popart.sigma
                    reward = (reward - mu) / sigma
                    value = value['pred']
                    next_value = next_value['pred']
                elif self._cfg.value_norm == 'value_rescale':
                    value = value_inv_transform(value['pred'])
                    next_value = value_inv_transform(next_value['pred'])
                elif self._cfg.value_norm == 'symlog':
                    value = inv_symlog(value['pred'])
                    next_value = inv_symlog(next_value['pred'])
                elif self._cfg.value_norm == 'baseline':
                    value = value['pred'] * self._running_mean_std.std
                    next_value = next_value['pred'] * self._running_mean_std.std
                traj_flag = data.get('traj_flag', None)  # traj_flag indicates termination of trajectory
                adv_data = gae_data(value, next_value, reward, data.done, traj_flag)
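                # GAE: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t),
                # adv_t = sum_k (gamma * lambda)^k * delta_{t+k}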
                data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)
                unnormalized_returns = value + data.adv  # In popart, this return is normalized
                if self._cfg.value_norm == 'popart':
                    self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1))
                elif self._cfg.value_norm == 'value_rescale':
                    value = value_transform(value)
                    unnormalized_returns = value_transform(unnormalized_returns)
                elif self._cfg.value_norm == 'symlog':
                    value = symlog(value)
                    unnormalized_returns = symlog(unnormalized_returns)
                elif self._cfg.value_norm == 'baseline':
                    value /= self._running_mean_std.std
                    unnormalized_returns /= self._running_mean_std.std
                    self._running_mean_std.update(unnormalized_returns.cpu().numpy())
                data.value = value
                data.return_ = unnormalized_returns
            # inner training loop
            split_data = list(ttorch.split(data, self._cfg.batch_size))
            random.shuffle(split_data)
            for batch in split_data:
                output = self._model.compute_actor_critic(batch.obs)
                adv = batch.adv
                if self._cfg.adv_norm:
                    # Normalize advantage in a train_batch
                    adv = (adv - adv.mean()) / (adv.std() + 1e-8)
                # Calculate ppo error
                if self._action_space == 'continuous':
                    ppo_batch = ppo_data(
                        output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
                    )
                    ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._cfg.clip_ratio)
                elif self._action_space == 'discrete':
                    ppo_batch = ppo_data(
                        output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
                    )
                    ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
                elif self._action_space == 'hybrid':
                    # discrete part (discrete policy loss and entropy loss)
                    ppo_discrete_batch = ppo_policy_data(
                        output.logit.action_type, batch.logit.action_type, batch.action.action_type, adv, None
                    )
                    ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, self._cfg.clip_ratio)
                    # continuous part (continuous policy loss and entropy loss, value loss)
                    ppo_continuous_batch = ppo_data(
                        output.logit.action_args, batch.logit.action_args, batch.action.action_args, output.value,
                        batch.value, adv, batch.return_, None
                    )
                    ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous(
                        ppo_continuous_batch, self._cfg.clip_ratio
                    )
                    # sum discrete and continuous losses
                    ppo_loss = type(ppo_continuous_loss)(
                        ppo_continuous_loss.policy_loss + ppo_discrete_loss.policy_loss,
                        ppo_continuous_loss.value_loss,
                        ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss
                    )
                    ppo_info = type(ppo_continuous_info)(
                        max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
                        max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
                    )
                wv, we = self._cfg.value_weight, self._cfg.entropy_weight
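                # entropy is subtracted so that minimizing total_loss maximizes the
                # entropy bonus (exploration) alongside the clipped PPO objective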
                total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
                self._optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm)
                self._optimizer.step()
                return_info = {
                    'cur_lr': self._optimizer.param_groups[0]['lr'],
                    'total_loss': total_loss.item(),
                    'policy_loss': ppo_loss.policy_loss.item(),
                    'value_loss': ppo_loss.value_loss.item(),
                    'entropy_loss': ppo_loss.entropy_loss.item(),
                    'adv_max': adv.max().item(),
                    'adv_mean': adv.mean().item(),
                    'value_mean': output.value.mean().item(),
                    'value_max': output.value.max().item(),
                    'approx_kl': ppo_info.approx_kl,
                    'clipfrac': ppo_info.clipfrac,
                }
                if self._action_space == 'continuous':
                    return_info.update(
                        {
                            'action': batch.action.float().mean().item(),
                            'mu_mean': output.logit.mu.mean().item(),
                            'sigma_mean': output.logit.sigma.mean().item(),
                        }
                    )
                elif self._action_space == 'hybrid':
                    return_info.update(
                        {
                            'action': batch.action.action_args.float().mean().item(),
                            'mu_mean': output.logit.action_args.mu.mean().item(),
                            'sigma_mean': output.logit.action_args.sigma.mean().item(),
                        }
                    )
                return_infos.append(return_info)
        if self._cfg.lr_scheduler is not None:
            self._lr_scheduler.step()
        return return_infos

    def state_dict(self) -> Dict[str, Any]:
        state_dict = {
            'model': self._model.state_dict(),
        }
        if 'learn' in self.enable_mode:
            state_dict['optimizer'] = self._optimizer.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._model.load_state_dict(state_dict['model'])
        if 'learn' in self.enable_mode:
            self._optimizer.load_state_dict(state_dict['optimizer'])

    def collect(self, data: ttorch.Tensor) -> ttorch.Tensor:
        self._model.eval()
        with torch.no_grad():
            output = self._model.compute_actor_critic(data)
            action = self._collect_sampler(output.logit)
            output.action = action
        return output

    def process_transition(self, obs: ttorch.Tensor, inference_output: dict, timestep: namedtuple) -> ttorch.Tensor:
        return ttorch.as_tensor(
            {
                'obs': obs,
                'next_obs': timestep.obs,
                'action': inference_output.action,
                'logit': inference_output.logit,
                'value': inference_output.value,
                'reward': timestep.reward,
                'done': timestep.done,
            }
        )

    def eval(self, data: ttorch.Tensor) -> ttorch.Tensor:
        self._model.eval()
        with torch.no_grad():
            logit = self._model.compute_actor(data)
            action = self._eval_sampler(logit)
        return ttorch.as_tensor({'logit': logit, 'action': action})

    def monitor_vars(self) -> List[str]:
        variables = [
            'cur_lr',
            'policy_loss',
            'value_loss',
            'entropy_loss',
            'adv_max',
            'adv_mean',
            'approx_kl',
            'clipfrac',
            'value_max',
            'value_mean',
        ]
        if self._action_space in ['continuous', 'hybrid']:
            variables += ['mu_mean', 'sigma_mean', 'action']
        return variables

    def reset(self, env_id_list: Optional[List[int]] = None) -> None:
        pass
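
# Minimal usage sketch (illustrative only: the PPOFModel constructor arguments and
# the shape of the collected batch are assumptions, not defined in this file):
#
#   cfg = PPOFPolicy.default_config()
#   model = PPOFPolicy.default_model()(obs_shape, action_shape, action_space=cfg.action_space)
#   policy = PPOFPolicy(cfg, model)
#   out = policy.collect(obs_batch)                       # sample actions for rollout
#   transition = policy.process_transition(obs_batch, out, timestep)
#   ...                                                   # stack transitions into a ttorch batch
#   train_infos = policy.forward(train_batch)             # one full PPO update phase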