from typing import Dict, List, Optional
from collections import defaultdict
import abc
import time

import attr
import numpy as np

from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import timed, hierarchical_timer

from mlagents.trainers.policy.checkpoint_manager import (
    ModelCheckpoint,
    ModelCheckpointManager,
)
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.torch_entities.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)
from mlagents.trainers.model_saver.model_saver import BaseModelSaver
from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.stats import StatsPropertyType

logger = get_logger(__name__)


class RLTrainer(Trainer):
    """
    This class is the base class for trainers that use Reward Signals.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cumulative_returns_since_policy_update: List[float] = []
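        # Maps each reward signal name ("environment" plus any optional reward
        # providers added by subclasses) to a dict of agent_id -> cumulative reward.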
        self.collected_rewards: Dict[str, Dict[str, int]] = {
            "environment": defaultdict(lambda: 0)
        }
        self.update_buffer: AgentBuffer = AgentBuffer()
        self._stats_reporter.add_property(
            StatsPropertyType.HYPERPARAMETERS, self.trainer_settings.as_dict()
        )
        self._next_save_step = 0
        self._next_summary_step = 0
        self.model_saver = self.create_model_saver(
            self.trainer_settings, self.artifact_path, self.load
        )
        self._has_warned_group_rewards = False

    def end_episode(self) -> None:
        """
        A signal that the episode has ended. The collected rewards must be reset.
        Only gets called when the Academy resets.
        """
        for rewards in self.collected_rewards.values():
            for agent_id in rewards:
                rewards[agent_id] = 0

    def _update_end_episode_stats(self, agent_id: str, optimizer: Optimizer) -> None:
        for name, rewards in self.collected_rewards.items():
            if name == "environment":
                self.stats_reporter.add_stat(
                    "Environment/Cumulative Reward",
                    rewards.get(agent_id, 0),
                    aggregation=StatsAggregationMethod.HISTOGRAM,
                )
                self.cumulative_returns_since_policy_update.append(
                    rewards.get(agent_id, 0)
                )
                self.reward_buffer.appendleft(rewards.get(agent_id, 0))
                rewards[agent_id] = 0
            else:
                if isinstance(optimizer.reward_signals[name], BaseRewardProvider):
                    self.stats_reporter.add_stat(
                        f"Policy/{optimizer.reward_signals[name].name.capitalize()} Reward",
                        rewards.get(agent_id, 0),
                    )
                else:
                    self.stats_reporter.add_stat(
                        optimizer.reward_signals[name].stat_name,
                        rewards.get(agent_id, 0),
                    )
                rewards[agent_id] = 0

    def _clear_update_buffer(self) -> None:
        """
        Clear the buffers that have been built up during inference.
        """
        self.update_buffer.reset_agent()

    @abc.abstractmethod
    def _is_ready_update(self) -> bool:
        """
        Returns whether or not the trainer has enough elements to run an update
        of the model.
        :return: A boolean corresponding to whether or not _update_policy() can be run.
        """
        return False

    @abc.abstractmethod
    def create_optimizer(self) -> TorchOptimizer:
        """
        Creates an Optimizer object.
        """
        pass

    @staticmethod
    def create_model_saver(
        trainer_settings: TrainerSettings, model_path: str, load: bool
    ) -> BaseModelSaver:
        model_saver = TorchModelSaver(trainer_settings, model_path, load)
        return model_saver

    def _policy_mean_reward(self) -> Optional[float]:
        """Returns the mean episode reward for the current policy."""
        rewards = self.cumulative_returns_since_policy_update
        if len(rewards) == 0:
            return None
        else:
            return sum(rewards) / len(rewards)

    @timed
    def _checkpoint(self) -> ModelCheckpoint:
        """
        Checkpoints the policy associated with this trainer.
        """
        n_policies = len(self.policies.keys())
        if n_policies > 1:
            logger.warning(
                "Trainer has multiple policies, but default behavior only saves the first."
            )
        export_path, auxillary_paths = self.model_saver.save_checkpoint(
            self.brain_name, self._step
        )
        new_checkpoint = ModelCheckpoint(
            int(self._step),
            export_path,
            self._policy_mean_reward(),
            time.time(),
            auxillary_file_paths=auxillary_paths,
        )
        ModelCheckpointManager.add_checkpoint(
            self.brain_name, new_checkpoint, self.trainer_settings.keep_checkpoints
        )
        return new_checkpoint

    def save_model(self) -> None:
        """
        Saves the policy associated with this trainer.
        """
        n_policies = len(self.policies.keys())
        if n_policies > 1:
            logger.warning(
                "Trainer has multiple policies, but default behavior only saves the first."
            )
        elif n_policies == 0:
            logger.warning("Trainer has no policies, not saving anything.")
            return

        model_checkpoint = self._checkpoint()
        self.model_saver.copy_final_model(model_checkpoint.file_path)
        export_ext = "onnx"
        final_checkpoint = attr.evolve(
            model_checkpoint, file_path=f"{self.model_saver.model_path}.{export_ext}"
        )
        ModelCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)

    @abc.abstractmethod
    def _update_policy(self) -> bool:
        """
        Uses the update buffer to update the model.
        :return: Whether or not the policy was updated.
        """
        pass

    def _increment_step(self, n_steps: int, name_behavior_id: str) -> None:
        """
        Increment the step count of the trainer.
        :param n_steps: Number of steps to increment the step count by.
        :param name_behavior_id: The behavior id of the policy whose step count
            should also be incremented.
        """
        self._step += n_steps
        self._next_summary_step = self._get_next_interval_step(self.summary_freq)
        self._next_save_step = self._get_next_interval_step(
            self.trainer_settings.checkpoint_interval
        )
        p = self.get_policy(name_behavior_id)
        if p:
            p.increment_step(n_steps)
        self.stats_reporter.set_stat("Step", float(self.get_step))

    def _get_next_interval_step(self, interval: int) -> int:
        """
        Get the next step count that should result in an action.
        :param interval: The interval between actions.
        """
        return self._step + (interval - self._step % interval)

    def _write_summary(self, step: int) -> None:
        """
        Saves training statistics to TensorBoard.
        """
        self.stats_reporter.add_stat("Is Training", float(self.should_still_train))
        self.stats_reporter.write_stats(int(step))

    @abc.abstractmethod
    def _process_trajectory(self, trajectory: Trajectory) -> None:
        """
        Takes a trajectory and processes it, putting it into the update buffer.
        :param trajectory: The Trajectory tuple containing the steps to be processed.
        """
        self._maybe_write_summary(self.get_step + len(trajectory.steps))
        self._maybe_save_model(self.get_step + len(trajectory.steps))
        self._increment_step(len(trajectory.steps), trajectory.behavior_id)

    def _maybe_write_summary(self, step_after_process: int) -> None:
        """
        If processing the trajectory will make the step exceed the next summary write,
        write the summary. This logic ensures summaries are written on the update step
        and not in between.
        :param step_after_process: The step count after processing the next trajectory.
        """
        if self._next_summary_step == 0:
            self._next_summary_step = self._get_next_interval_step(self.summary_freq)
        if step_after_process >= self._next_summary_step and self.get_step != 0:
            self._write_summary(self._next_summary_step)

    def _append_to_update_buffer(self, agentbuffer_trajectory: AgentBuffer) -> None:
        """
        Append an AgentBuffer to the update buffer. If the trainer isn't training,
        don't update to avoid a memory leak.
        """
        if self.should_still_train:
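            # Use the recurrent sequence length when the network has memory;
            # otherwise append single steps (sequence length 1).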
            seq_len = (
                self.trainer_settings.network_settings.memory.sequence_length
                if self.trainer_settings.network_settings.memory is not None
                else 1
            )
            agentbuffer_trajectory.resequence_and_append(
                self.update_buffer, training_length=seq_len
            )

    def _maybe_save_model(self, step_after_process: int) -> None:
        """
        If processing the trajectory will make the step exceed the next model write,
        save the model. This logic ensures models are written on the update step
        and not in between.
        :param step_after_process: The step count after processing the next trajectory.
        """
        if self._next_save_step == 0:
            self._next_save_step = self._get_next_interval_step(
                self.trainer_settings.checkpoint_interval
            )
        if step_after_process >= self._next_save_step and self.get_step != 0:
            self._checkpoint()

    def _warn_if_group_reward(self, buffer: AgentBuffer) -> None:
        """
        Warn if the trainer receives a Group Reward but isn't a multi-agent trainer
        (e.g. POCA).
        """
        if not self._has_warned_group_rewards:
            if np.any(buffer[BufferKey.GROUP_REWARD]):
                logger.warning(
                    "An agent received a Group Reward, but you are not using a multi-agent trainer. "
                    "Please use the POCA trainer for best results."
                )
                self._has_warned_group_rewards = True

    def advance(self) -> None:
        """
        Steps the trainer, taking in trajectories and updating the policy if ready.
        Will block and wait briefly if there are no trajectories.
        """
        with hierarchical_timer("process_trajectory"):
            for traj_queue in self.trajectory_queues:
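                # Drain at most qsize() items so that a queue that is being
                # appended to concurrently does not keep us looping forever.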
                _queried = False
                for _ in range(traj_queue.qsize()):
                    _queried = True
                    try:
                        t = traj_queue.get_nowait()
                        self._process_trajectory(t)
                    except AgentManagerQueue.Empty:
                        break
                if self.threaded and not _queried:
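                    # Yield the thread briefly to avoid busy-waiting on an empty queue.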
                    time.sleep(0.0001)
        if self.should_still_train:
            if self._is_ready_update():
                with hierarchical_timer("_update_policy"):
                    if self._update_policy():
                        for q in self.policy_queues:
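                            # Push the new policy onto each policy queue so that
                            # inference switches to the updated weights.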
                            q.put(self.get_policy(q.behavior_id))