import datetime from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set import cloudpickle import enum import time from mlagents_envs.environment import UnityEnvironment from mlagents_envs.exception import ( UnityCommunicationException, UnityTimeOutException, UnityEnvironmentException, UnityCommunicatorStoppedException, ) from multiprocessing import Process, Pipe, Queue from multiprocessing.connection import Connection from queue import Empty as EmptyQueueException from mlagents_envs.base_env import BaseEnv, BehaviorName, BehaviorSpec from mlagents_envs import logging_util from mlagents.trainers.env_manager import EnvManager, EnvironmentStep, AllStepResult from mlagents.trainers.settings import TrainerSettings from mlagents_envs.timers import ( TimerNode, timed, hierarchical_timer, reset_timers, get_timer_root, ) from mlagents.trainers.settings import ParameterRandomizationSettings, RunOptions from mlagents.trainers.action_info import ActionInfo from mlagents_envs.side_channel.environment_parameters_channel import ( EnvironmentParametersChannel, ) from mlagents_envs.side_channel.engine_configuration_channel import ( EngineConfigurationChannel, EngineConfig, ) from mlagents_envs.side_channel.stats_side_channel import ( EnvironmentStats, StatsSideChannel, ) from mlagents.trainers.training_analytics_side_channel import ( TrainingAnalyticsSideChannel, ) from mlagents_envs.side_channel.side_channel import SideChannel logger = logging_util.get_logger(__name__) WORKER_SHUTDOWN_TIMEOUT_S = 10 class EnvironmentCommand(enum.Enum): STEP = 1 BEHAVIOR_SPECS = 2 ENVIRONMENT_PARAMETERS = 3 RESET = 4 CLOSE = 5 ENV_EXITED = 6 CLOSED = 7 TRAINING_STARTED = 8 class EnvironmentRequest(NamedTuple): cmd: EnvironmentCommand payload: Any = None class EnvironmentResponse(NamedTuple): cmd: EnvironmentCommand worker_id: int payload: Any class StepResponse(NamedTuple): all_step_result: AllStepResult timer_root: Optional[TimerNode] environment_stats: EnvironmentStats class UnityEnvWorker: def __init__(self, process: Process, worker_id: int, conn: Connection): self.process = process self.worker_id = worker_id self.conn = conn self.previous_step: EnvironmentStep = EnvironmentStep.empty(worker_id) self.previous_all_action_info: Dict[str, ActionInfo] = {} self.waiting = False self.closed = False def send(self, cmd: EnvironmentCommand, payload: Any = None) -> None: try: req = EnvironmentRequest(cmd, payload) self.conn.send(req) except (BrokenPipeError, EOFError): raise UnityCommunicationException("UnityEnvironment worker: send failed.") def recv(self) -> EnvironmentResponse: try: response: EnvironmentResponse = self.conn.recv() if response.cmd == EnvironmentCommand.ENV_EXITED: env_exception: Exception = response.payload raise env_exception return response except (BrokenPipeError, EOFError): raise UnityCommunicationException("UnityEnvironment worker: recv failed.") def request_close(self): try: self.conn.send(EnvironmentRequest(EnvironmentCommand.CLOSE)) except (BrokenPipeError, EOFError): logger.debug( f"UnityEnvWorker {self.worker_id} got exception trying to close." ) pass def worker( parent_conn: Connection, step_queue: Queue, pickled_env_factory: str, worker_id: int, run_options: RunOptions, log_level: int = logging_util.INFO, ) -> None: env_factory: Callable[ [int, List[SideChannel]], UnityEnvironment ] = cloudpickle.loads(pickled_env_factory) env_parameters = EnvironmentParametersChannel() engine_config = EngineConfig( width=run_options.engine_settings.width, height=run_options.engine_settings.height, quality_level=run_options.engine_settings.quality_level, time_scale=run_options.engine_settings.time_scale, target_frame_rate=run_options.engine_settings.target_frame_rate, capture_frame_rate=run_options.engine_settings.capture_frame_rate, ) engine_configuration_channel = EngineConfigurationChannel() engine_configuration_channel.set_configuration(engine_config) stats_channel = StatsSideChannel() training_analytics_channel: Optional[TrainingAnalyticsSideChannel] = None if worker_id == 0: training_analytics_channel = TrainingAnalyticsSideChannel() env: UnityEnvironment = None # Set log level. On some platforms, the logger isn't common with the # main process, so we need to set it again. logging_util.set_log_level(log_level) def _send_response(cmd_name: EnvironmentCommand, payload: Any) -> None: parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload)) def _generate_all_results() -> AllStepResult: all_step_result: AllStepResult = {} for brain_name in env.behavior_specs: all_step_result[brain_name] = env.get_steps(brain_name) return all_step_result try: side_channels = [env_parameters, engine_configuration_channel, stats_channel] if training_analytics_channel is not None: side_channels.append(training_analytics_channel) env = env_factory(worker_id, side_channels) if ( not env.academy_capabilities or not env.academy_capabilities.trainingAnalytics ): # Make sure we don't try to send training analytics if the environment doesn't know how to process # them. This wouldn't be catastrophic, but would result in unknown SideChannel UUIDs being used. training_analytics_channel = None if training_analytics_channel: training_analytics_channel.environment_initialized(run_options) while True: req: EnvironmentRequest = parent_conn.recv() if req.cmd == EnvironmentCommand.STEP: all_action_info = req.payload for brain_name, action_info in all_action_info.items(): if len(action_info.agent_ids) > 0: env.set_actions(brain_name, action_info.env_action) env.step() all_step_result = _generate_all_results() # The timers in this process are independent from all the processes and the "main" process # So after we send back the root timer, we can safely clear them. # Note that we could randomly return timers a fraction of the time if we wanted to reduce # the data transferred. # TODO get gauges from the workers and merge them in the main process too. env_stats = stats_channel.get_and_reset_stats() step_response = StepResponse( all_step_result, get_timer_root(), env_stats ) step_queue.put( EnvironmentResponse( EnvironmentCommand.STEP, worker_id, step_response ) ) reset_timers() elif req.cmd == EnvironmentCommand.BEHAVIOR_SPECS: _send_response(EnvironmentCommand.BEHAVIOR_SPECS, env.behavior_specs) elif req.cmd == EnvironmentCommand.ENVIRONMENT_PARAMETERS: for k, v in req.payload.items(): if isinstance(v, ParameterRandomizationSettings): v.apply(k, env_parameters) elif req.cmd == EnvironmentCommand.TRAINING_STARTED: behavior_name, trainer_config = req.payload if training_analytics_channel: training_analytics_channel.training_started( behavior_name, trainer_config ) elif req.cmd == EnvironmentCommand.RESET: env.reset() all_step_result = _generate_all_results() _send_response(EnvironmentCommand.RESET, all_step_result) elif req.cmd == EnvironmentCommand.CLOSE: break except ( KeyboardInterrupt, UnityCommunicationException, UnityTimeOutException, UnityEnvironmentException, UnityCommunicatorStoppedException, ) as ex: logger.debug(f"UnityEnvironment worker {worker_id}: environment stopping.") step_queue.put( EnvironmentResponse(EnvironmentCommand.ENV_EXITED, worker_id, ex) ) _send_response(EnvironmentCommand.ENV_EXITED, ex) except Exception as ex: logger.exception( f"UnityEnvironment worker {worker_id}: environment raised an unexpected exception." ) step_queue.put( EnvironmentResponse(EnvironmentCommand.ENV_EXITED, worker_id, ex) ) _send_response(EnvironmentCommand.ENV_EXITED, ex) finally: logger.debug(f"UnityEnvironment worker {worker_id} closing.") if env is not None: env.close() logger.debug(f"UnityEnvironment worker {worker_id} done.") parent_conn.close() step_queue.put(EnvironmentResponse(EnvironmentCommand.CLOSED, worker_id, None)) step_queue.close() class SubprocessEnvManager(EnvManager): def __init__( self, env_factory: Callable[[int, List[SideChannel]], BaseEnv], run_options: RunOptions, n_env: int = 1, ): super().__init__() self.env_workers: List[UnityEnvWorker] = [] self.step_queue: Queue = Queue() self.workers_alive = 0 self.env_factory = env_factory self.run_options = run_options self.env_parameters: Optional[Dict] = None # Each worker is correlated with a list of times they restarted within the last time period. self.recent_restart_timestamps: List[List[datetime.datetime]] = [ [] for _ in range(n_env) ] self.restart_counts: List[int] = [0] * n_env for worker_idx in range(n_env): self.env_workers.append( self.create_worker( worker_idx, self.step_queue, env_factory, run_options ) ) self.workers_alive += 1 @staticmethod def create_worker( worker_id: int, step_queue: Queue, env_factory: Callable[[int, List[SideChannel]], BaseEnv], run_options: RunOptions, ) -> UnityEnvWorker: parent_conn, child_conn = Pipe() # Need to use cloudpickle for the env factory function since function objects aren't picklable # on Windows as of Python 3.6. pickled_env_factory = cloudpickle.dumps(env_factory) child_process = Process( target=worker, args=( child_conn, step_queue, pickled_env_factory, worker_id, run_options, logger.level, ), ) child_process.start() return UnityEnvWorker(child_process, worker_id, parent_conn) def _queue_steps(self) -> None: for env_worker in self.env_workers: if not env_worker.waiting: env_action_info = self._take_step(env_worker.previous_step) env_worker.previous_all_action_info = env_action_info env_worker.send(EnvironmentCommand.STEP, env_action_info) env_worker.waiting = True def _restart_failed_workers(self, first_failure: EnvironmentResponse) -> None: if first_failure.cmd != EnvironmentCommand.ENV_EXITED: return # Drain the step queue to make sure all workers are paused and we have found all concurrent errors. # Pausing all training is needed since we need to reset all pending training steps as they could be corrupted. other_failures: Dict[int, Exception] = self._drain_step_queue() # TODO: Once we use python 3.9 switch to using the | operator to combine dicts. failures: Dict[int, Exception] = { **{first_failure.worker_id: first_failure.payload}, **other_failures, } for worker_id, ex in failures.items(): self._assert_worker_can_restart(worker_id, ex) logger.warning(f"Restarting worker[{worker_id}] after '{ex}'") self.recent_restart_timestamps[worker_id].append(datetime.datetime.now()) self.restart_counts[worker_id] += 1 self.env_workers[worker_id] = self.create_worker( worker_id, self.step_queue, self.env_factory, self.run_options ) # The restarts were successful, clear all the existing training trajectories so we don't use corrupted or # outdated data. self.reset(self.env_parameters) def _drain_step_queue(self) -> Dict[int, Exception]: """ Drains all steps out of the step queue and returns all exceptions from crashed workers. This will effectively pause all workers so that they won't do anything until _queue_steps is called. """ all_failures = {} workers_still_pending = {w.worker_id for w in self.env_workers if w.waiting} deadline = datetime.datetime.now() + datetime.timedelta(minutes=1) while workers_still_pending and deadline > datetime.datetime.now(): try: while True: step: EnvironmentResponse = self.step_queue.get_nowait() if step.cmd == EnvironmentCommand.ENV_EXITED: workers_still_pending.add(step.worker_id) all_failures[step.worker_id] = step.payload else: workers_still_pending.remove(step.worker_id) self.env_workers[step.worker_id].waiting = False except EmptyQueueException: pass if deadline < datetime.datetime.now(): still_waiting = {w.worker_id for w in self.env_workers if w.waiting} raise TimeoutError(f"Workers {still_waiting} stuck in waiting state") return all_failures def _assert_worker_can_restart(self, worker_id: int, exception: Exception) -> None: """ Checks if we can recover from an exception from a worker. If the restart limit is exceeded it will raise a UnityCommunicationException. If the exception is not recoverable it re-raises the exception. """ if ( isinstance(exception, UnityCommunicationException) or isinstance(exception, UnityTimeOutException) or isinstance(exception, UnityEnvironmentException) or isinstance(exception, UnityCommunicatorStoppedException) ): if self._worker_has_restart_quota(worker_id): return else: logger.error( f"Worker {worker_id} exceeded the allowed number of restarts." ) raise exception raise exception def _worker_has_restart_quota(self, worker_id: int) -> bool: self._drop_old_restart_timestamps(worker_id) max_lifetime_restarts = self.run_options.env_settings.max_lifetime_restarts max_limit_check = ( max_lifetime_restarts == -1 or self.restart_counts[worker_id] < max_lifetime_restarts ) rate_limit_n = self.run_options.env_settings.restarts_rate_limit_n rate_limit_check = ( rate_limit_n == -1 or len(self.recent_restart_timestamps[worker_id]) < rate_limit_n ) return rate_limit_check and max_limit_check def _drop_old_restart_timestamps(self, worker_id: int) -> None: """ Drops environment restart timestamps that are outside of the current window. """ def _filter(t: datetime.datetime) -> bool: return t > datetime.datetime.now() - datetime.timedelta( seconds=self.run_options.env_settings.restarts_rate_limit_period_s ) self.recent_restart_timestamps[worker_id] = list( filter(_filter, self.recent_restart_timestamps[worker_id]) ) def _step(self) -> List[EnvironmentStep]: # Queue steps for any workers which aren't in the "waiting" state. self._queue_steps() worker_steps: List[EnvironmentResponse] = [] step_workers: Set[int] = set() # Poll the step queue for completed steps from environment workers until we retrieve # 1 or more, which we will then return as StepInfos while len(worker_steps) < 1: try: while True: step: EnvironmentResponse = self.step_queue.get_nowait() if step.cmd == EnvironmentCommand.ENV_EXITED: # If even one env exits try to restart all envs that failed. self._restart_failed_workers(step) # Clear state and restart this function. worker_steps.clear() step_workers.clear() self._queue_steps() elif step.worker_id not in step_workers: self.env_workers[step.worker_id].waiting = False worker_steps.append(step) step_workers.add(step.worker_id) except EmptyQueueException: pass step_infos = self._postprocess_steps(worker_steps) return step_infos def _reset_env(self, config: Optional[Dict] = None) -> List[EnvironmentStep]: while any(ew.waiting for ew in self.env_workers): if not self.step_queue.empty(): step = self.step_queue.get_nowait() self.env_workers[step.worker_id].waiting = False # Send config to environment self.set_env_parameters(config) # First enqueue reset commands for all workers so that they reset in parallel for ew in self.env_workers: ew.send(EnvironmentCommand.RESET, config) # Next (synchronously) collect the reset observations from each worker in sequence for ew in self.env_workers: ew.previous_step = EnvironmentStep(ew.recv().payload, ew.worker_id, {}, {}) return list(map(lambda ew: ew.previous_step, self.env_workers)) def set_env_parameters(self, config: Dict = None) -> None: """ Sends environment parameter settings to C# via the EnvironmentParametersSidehannel for each worker. :param config: Dict of environment parameter keys and values """ self.env_parameters = config for ew in self.env_workers: ew.send(EnvironmentCommand.ENVIRONMENT_PARAMETERS, config) def on_training_started( self, behavior_name: str, trainer_settings: TrainerSettings ) -> None: """ Handle traing starting for a new behavior type. Generally nothing is necessary here. :param behavior_name: :param trainer_settings: :return: """ for ew in self.env_workers: ew.send( EnvironmentCommand.TRAINING_STARTED, (behavior_name, trainer_settings) ) @property def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]: result: Dict[BehaviorName, BehaviorSpec] = {} for worker in self.env_workers: worker.send(EnvironmentCommand.BEHAVIOR_SPECS) result.update(worker.recv().payload) return result def close(self) -> None: logger.debug("SubprocessEnvManager closing.") for env_worker in self.env_workers: env_worker.request_close() # Pull messages out of the queue until every worker has CLOSED or we time out. deadline = time.time() + WORKER_SHUTDOWN_TIMEOUT_S while self.workers_alive > 0 and time.time() < deadline: try: step: EnvironmentResponse = self.step_queue.get_nowait() env_worker = self.env_workers[step.worker_id] if step.cmd == EnvironmentCommand.CLOSED and not env_worker.closed: env_worker.closed = True self.workers_alive -= 1 # Discard all other messages. except EmptyQueueException: pass self.step_queue.close() # Sanity check to kill zombie workers and report an issue if they occur. if self.workers_alive > 0: logger.error("SubprocessEnvManager had workers that didn't signal shutdown") for env_worker in self.env_workers: if not env_worker.closed and env_worker.process.is_alive(): env_worker.process.terminate() logger.error( "A SubprocessEnvManager worker did not shut down correctly so it was forcefully terminated." ) self.step_queue.join_thread() def _postprocess_steps( self, env_steps: List[EnvironmentResponse] ) -> List[EnvironmentStep]: step_infos = [] timer_nodes = [] for step in env_steps: payload: StepResponse = step.payload env_worker = self.env_workers[step.worker_id] new_step = EnvironmentStep( payload.all_step_result, step.worker_id, env_worker.previous_all_action_info, payload.environment_stats, ) step_infos.append(new_step) env_worker.previous_step = new_step if payload.timer_root: timer_nodes.append(payload.timer_root) if timer_nodes: with hierarchical_timer("workers") as main_timer_node: for worker_timer_node in timer_nodes: main_timer_node.merge( worker_timer_node, root_name="worker_root", is_parallel=True ) return step_infos @timed def _take_step(self, last_step: EnvironmentStep) -> Dict[BehaviorName, ActionInfo]: all_action_info: Dict[str, ActionInfo] = {} for brain_name, step_tuple in last_step.current_all_step_result.items(): if brain_name in self.policies: all_action_info[brain_name] = self.policies[brain_name].get_action( step_tuple[0], last_step.worker_id ) return all_action_info