Spaces:
Running
Running
File size: 6,685 Bytes
e11e4fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
# ## ML-Agent Learning (SAC)
# Contains an implementation of SAC as described in https://arxiv.org/abs/1801.01290
# and implemented in https://github.com/hill-a/stable-baselines
from typing import cast
import numpy as np
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.buffer import BufferKey
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.trainer.off_policy_trainer import OffPolicyTrainer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.policy.policy import Policy
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer, SACSettings
from mlagents.trainers.trajectory import Trajectory, ObsUtil
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.torch_entities.networks import SimpleActor
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Not referenced anywhere in the visible part of this file.
# NOTE(review): presumably the fraction of the replay buffer kept when it is
# truncated -- confirm against the off-policy trainer base class.
BUFFER_TRUNCATE_PERCENT = 0.8
# Identifier string returned by SACTrainer.get_trainer_name().
TRAINER_NAME = "sac"
class SACTrainer(OffPolicyTrainer):
    """
    Trainer implementing the Soft Actor-Critic (SAC) algorithm, with support
    for discrete actions and recurrent networks.
    """

    def __init__(
        self,
        behavior_name: str,
        reward_buff_cap: int,
        trainer_settings: TrainerSettings,
        training: bool,
        load: bool,
        seed: int,
        artifact_path: str,
    ):
        """
        Set up a trainer that collects experiences and trains a SAC model.
        :param behavior_name: Name of the behavior this trainer config belongs to.
        :param reward_buff_cap: Maximum reward history kept in the reward buffer.
        :param trainer_settings: Settings object holding the trainer parameters.
        :param training: Whether this trainer is set for training.
        :param load: Whether the model should be loaded.
        :param seed: Seed the model will be initialized with.
        :param artifact_path: Directory in which this trainer stores its artifacts.
        """
        super().__init__(
            behavior_name,
            reward_buff_cap,
            trainer_settings,
            training,
            load,
            seed,
            artifact_path,
        )
        self.seed = seed
        # Unset at construction time; annotated with their eventual classes
        # so the type checker treats later uses as non-optional.
        self.policy: TorchPolicy = None  # type: ignore
        self.optimizer: TorchSACOptimizer = None  # type: ignore
        self.hyperparameters: SACSettings = cast(
            SACSettings, trainer_settings.hyperparameters
        )
        sac_settings = self.hyperparameters
        self._step = 0

        # Counters start at 1 rather than 0 so they are safe to divide by.
        self.update_steps = 1
        self.reward_signal_update_steps = 1

        # Copy the SAC-specific knobs onto the trainer itself.
        self.steps_per_update = sac_settings.steps_per_update
        self.reward_signal_steps_per_update = (
            sac_settings.reward_signal_steps_per_update
        )
        self.checkpoint_replay_buffer = sac_settings.save_replay_buffer

    def _process_trajectory(self, trajectory: Trajectory) -> None:
        """
        Convert a trajectory to buffer form, record reward and value
        statistics for it, and append it to the replay buffer.
        :param trajectory: The trajectory to process.
        """
        super()._process_trajectory(trajectory)
        final_step = trajectory.steps[-1]
        # Every step of one trajectory is expected to share this agent ID.
        agent_id = trajectory.agent_id
        traj_buffer = trajectory.to_agentbuffer()

        # Emit a warning if group rewards were used.
        self._warn_if_group_reward(traj_buffer)

        # Normalization statistics are only refreshed while training.
        if self.is_training:
            self.policy.actor.update_normalization(traj_buffer)
            self.optimizer.critic.update_normalization(traj_buffer)

        # Accumulate rewards per agent, purely for reporting.
        self.collected_rewards["environment"][agent_id] += np.sum(
            traj_buffer[BufferKey.ENVIRONMENT_REWARDS]
        )
        for signal_name, signal in self.optimizer.reward_signals.items():
            scaled_reward = signal.evaluate(traj_buffer) * signal.strength
            self.collected_rewards[signal_name][agent_id] += np.sum(scaled_reward)

        # Value estimates are likewise gathered for reporting; the middle
        # element of the returned triple is not needed here.
        estimates, _, critic_memories = self.optimizer.get_trajectory_value_estimates(
            traj_buffer, trajectory.next_obs, trajectory.done_reached
        )
        if critic_memories is not None:
            traj_buffer[BufferKey.CRITIC_MEMORY].set(critic_memories)
        for name, values in estimates.items():
            self._stats_reporter.add_stat(
                f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value",
                np.mean(values),
            )

        # When the episode was cut off by the max-step limit, bootstrap from
        # the last step instead of the bootstrap step: overwrite the final
        # "next" observations with the last step's own observations and clear
        # its done flag.
        if final_step.interrupted:
            for obs_index, observation in enumerate(final_step.obs):
                traj_buffer[ObsUtil.get_name_at_next(obs_index)][-1] = observation
            traj_buffer[BufferKey.DONE][-1] = False

        self._append_to_update_buffer(traj_buffer)

        if trajectory.done_reached:
            self._update_end_episode_stats(agent_id, self.optimizer)

    def create_optimizer(self) -> TorchOptimizer:
        """
        Build the SAC optimizer for this trainer's policy and settings.
        :return: A new TorchSACOptimizer.
        """
        torch_policy = cast(TorchPolicy, self.policy)
        return TorchSACOptimizer(  # type: ignore
            torch_policy, self.trainer_settings  # type: ignore
        )  # type: ignore

    def create_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
    ) -> TorchPolicy:
        """
        Build a PyTorch-backed policy configured for SAC.
        :param parsed_behavior_id: Parsed behavior identifiers (not read by this method).
        :param behavior_spec: Specifications used to construct the policy.
        :return: The newly created policy.
        """
        new_policy = TorchPolicy(
            self.seed,
            behavior_spec,
            self.trainer_settings.network_settings,
            SimpleActor,
            {"conditional_sigma": True, "tanh_squash": True},
        )
        # The replay-buffer load hook runs once the policy has been built.
        self.maybe_load_replay_buffer()
        return new_policy

    def get_policy(self, name_behavior_id: str) -> Policy:
        """
        Return the policy held by this trainer.
        :param name_behavior_id: Full identifier of the policy; it is not
            consulted, the trainer's single policy is always returned.
        """
        return self.policy

    @staticmethod
    def get_trainer_name() -> str:
        """Return the identifier string of this trainer type."""
        return TRAINER_NAME
|