from typing import Dict, List, Union, cast

import attr

from mlagents.torch_utils import torch, default_device

from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil

from mlagents_envs.timers import timed
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.settings import (
    TrainerSettings,
    OnPolicyHyperparamSettings,
    ScheduleType,
)
from mlagents.trainers.torch_entities.networks import ValueNetwork
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil


@attr.s(auto_attribs=True)
class PPOSettings(OnPolicyHyperparamSettings):
    """PPO-specific hyperparameters layered on top of the on-policy settings."""

    # Entropy coefficient: decayed by ``beta_schedule`` and multiplied by the
    # masked mean entropy, which is subtracted from the loss in ``update``.
    beta: float = 5.0e-3
    # Decayed by ``epsilon_schedule`` and passed to both trust-region losses;
    # presumably the PPO clipping range -- confirm in ModelUtils.
    epsilon: float = 0.2
    # Not read in this file; presumably the GAE lambda used by the trainer.
    lambd: float = 0.95
    # Not read in this file; presumably the number of passes per update.
    num_epoch: int = 3
    # When True, the actor network is also used as the critic.
    shared_critic: bool = False
    learning_rate_schedule: ScheduleType = ScheduleType.LINEAR
    beta_schedule: ScheduleType = ScheduleType.LINEAR
    epsilon_schedule: ScheduleType = ScheduleType.LINEAR


class TorchPPOOptimizer(TorchOptimizer):
    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        """
        Takes a Policy and the trainer settings and creates an Optimizer around the policy.
        The PPO optimizer has a value estimator and a loss function.
        :param policy: A TorchPolicy object that will be updated by this PPO Optimizer.
        :param trainer_settings: Trainer settings that specify the properties of the
            trainer, including the PPO hyperparameters and the reward signals.
        """
        super().__init__(policy, trainer_settings)
        reward_signal_configs = trainer_settings.reward_signals
        # Only the keys (reward signal types) are needed to name the critic's streams.
        reward_signal_names = [key.value for key in reward_signal_configs]

        self.hyperparameters: PPOSettings = cast(
            PPOSettings, trainer_settings.hyperparameters
        )

        params = list(self.policy.actor.parameters())
        if self.hyperparameters.shared_critic:
            # The actor doubles as the critic; its parameters are already in `params`.
            self._critic = policy.actor
        else:
            self._critic = ValueNetwork(
                reward_signal_names,
                policy.behavior_spec.observation_specs,
                network_settings=trainer_settings.network_settings,
            )
            self._critic.to(default_device())
            params += list(self._critic.parameters())

        max_steps = self.trainer_settings.max_steps
        # NOTE(review): the third DecayedValue argument looks like the floor the
        # value decays toward over `max_steps` -- confirm in ModelUtils.
        self.decay_learning_rate = ModelUtils.DecayedValue(
            self.hyperparameters.learning_rate_schedule,
            self.hyperparameters.learning_rate,
            1e-10,
            max_steps,
        )
        self.decay_epsilon = ModelUtils.DecayedValue(
            self.hyperparameters.epsilon_schedule,
            self.hyperparameters.epsilon,
            0.1,
            max_steps,
        )
        self.decay_beta = ModelUtils.DecayedValue(
            self.hyperparameters.beta_schedule,
            self.hyperparameters.beta,
            1e-5,
            max_steps,
        )

        # A single Adam optimizer covers the actor and (if separate) the critic.
        self.optimizer = torch.optim.Adam(
            params, lr=self.hyperparameters.learning_rate
        )
        self.stats_name_to_update_name = {
            "Losses/Value Loss": "value_loss",
            "Losses/Policy Loss": "policy_loss",
        }

        self.stream_names = list(self.reward_signals.keys())

    @property
    def critic(self):
        """The value network, or the actor itself when ``shared_critic`` is set."""
        return self._critic

    def _get_sequence_memories(
        self, batch: AgentBuffer, key: BufferKey
    ) -> Union[List[torch.Tensor], torch.Tensor]:
        """
        Gather the memory entry at the start of each sequence stored under ``key``.
        Takes every ``sequence_length``-th entry of ``batch[key]`` (indices 0, L, 2L, ...),
        converts each to a tensor and stacks them with an extra leading dimension of 1.
        :param batch: Batch of experiences.
        :param key: Buffer key holding the memories (actor or critic).
        :return: The stacked tensor, or an empty list when ``batch[key]`` is empty.
        """
        entries = batch[key]
        memories = [
            ModelUtils.list_to_tensor(entries[i])
            for i in range(0, len(entries), self.policy.sequence_length)
        ]
        if not memories:
            # Non-recurrent case: the networks receive an empty list.
            return memories
        return torch.stack(memories).unsqueeze(0)

    @timed
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Performs a single PPO gradient step on the given batch.
        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process. Not read by this
            implementation; kept for interface compatibility.
        :return: Dict mapping stat names to the scalar losses and the decayed
            learning rate, epsilon and beta used for this step.
        """
        # Get decayed parameters; read the step once so all three agree.
        current_step = self.policy.get_current_step()
        decay_lr = self.decay_learning_rate.get_value(current_step)
        decay_eps = self.decay_epsilon.get_value(current_step)
        decay_bet = self.decay_beta.get_value(current_step)

        returns = {}
        old_values = {}
        for name in self.reward_signals:
            old_values[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.value_estimates_key(name)]
            )
            returns[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.returns_key(name)]
            )

        n_obs = len(self.policy.behavior_spec.observation_specs)
        # Convert to tensors
        current_obs = [
            ModelUtils.list_to_tensor(obs) for obs in ObsUtil.from_buffer(batch, n_obs)
        ]

        act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
        actions = AgentAction.from_buffer(batch)

        memories = self._get_sequence_memories(batch, BufferKey.MEMORY)
        value_memories = self._get_sequence_memories(batch, BufferKey.CRITIC_MEMORY)

        run_out = self.policy.actor.get_stats(
            current_obs,
            actions,
            masks=act_masks,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )
        log_probs = run_out["log_probs"].flatten()
        entropy = run_out["entropy"]

        values, _ = self.critic.critic_pass(
            current_obs,
            memories=value_memories,
            sequence_length=self.policy.sequence_length,
        )
        old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
        loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)

        value_loss = ModelUtils.trust_region_value_loss(
            values, old_values, returns, decay_eps, loss_masks
        )
        policy_loss = ModelUtils.trust_region_policy_loss(
            ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES]),
            log_probs,
            old_log_probs,
            loss_masks,
            decay_eps,
        )
        # Policy loss + 0.5 * value loss, minus a beta-scaled entropy bonus.
        loss = (
            policy_loss
            + 0.5 * value_loss
            - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
        )

        # Set optimizer learning rate
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        update_stats = {
            # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
            # TODO: After PyTorch is default, change to something more correct.
            "Losses/Policy Loss": torch.abs(policy_loss).item(),
            "Losses/Value Loss": value_loss.item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/Epsilon": decay_eps,
            "Policy/Beta": decay_bet,
        }

        return update_stats

    # TODO move module update into TorchOptimizer for reward_provider
    def get_modules(self):
        """
        Return a name -> module mapping containing the Adam optimizer, the critic,
        and every module exposed by each reward provider's ``get_modules()``.
        """
        modules = {
            "Optimizer:value_optimizer": self.optimizer,
            "Optimizer:critic": self._critic,
        }
        for reward_provider in self.reward_signals.values():
            modules.update(reward_provider.get_modules())
        return modules