Spaces:
Running
Running
from typing import Any, Optional | |
from gym import error | |
from mlagents_envs.base_env import BaseEnv | |
from pettingzoo import AECEnv | |
from mlagents_envs.envs.unity_pettingzoo_base_env import UnityPettingzooBaseEnv | |
class UnityAECEnv(UnityPettingzooBaseEnv, AECEnv):
    """
    Unity AEC (PettingZoo) environment wrapper.

    Agents act one at a time: each call to `step` records the action of the
    currently selected agent. Once every agent in `self._agents` has been given
    an action, the base class's `_step()` is invoked.
    """

    def __init__(self, env: BaseEnv, seed: Optional[int] = None):
        """
        Initializes a Unity AEC environment wrapper.
        :param env: The UnityEnvironment that is being wrapped.
        :param seed: The seed for the action spaces of the agents.
        """
        super().__init__(env, seed)

    def step(self, action: Any) -> None:
        """
        Sets the action of the active agent and advances to the next agent.

        The reward of every agent in `self._rewards` is reset to 0 on each
        call. When the last agent in `self._agents` has acted and there are
        still agents (`self.num_agents > 0`), `self._step()` is called.
        :param action: The action for the active agent.
        :raises gym.error.Error: If there are no live agents; the environment
            must be reset before stepping.
        """
        self._assert_loaded()
        if not self._live_agents:
            raise error.Error(
                "You must reset the environment before you can perform a step"
            )

        # Record the action for the agent whose turn it is, then move on.
        current_agent = self._agents[self._agent_index]
        self._process_action(current_agent, action)
        self._agent_index += 1

        # Zero every agent's reward in place (keys are kept, only values reset).
        for agent_id in self._rewards:
            self._rewards[agent_id] = 0

        if self._agent_index >= len(self._agents) and self.num_agents > 0:
            # Index ran past the last agent: every agent has had its action
            # processed, so hand off to the base class's batched step.
            self._step()
        # Not required for correctness; kept only so PettingZoo's API test passes.
        self._live_agents.sort()

    def observe(self, agent_id):
        """
        Returns the stored (observation, cumulative reward, done, info) tuple
        for `agent_id`. `last()` calls this function.

        NOTE: this returns a 4-tuple rather than a bare observation; `last()`
        relies on that shape.
        """
        return (
            self._observations[agent_id],
            self._cumm_rewards[agent_id],
            self._dones[agent_id],
            self._infos[agent_id],
        )

    def last(self, observe: bool = True):
        """
        Returns observation, cumulative reward, done, info for the current
        agent (specified by `self.agent_selection`).
        :param observe: When False, None is returned in place of the observation.
        """
        obs, reward, done, info = self.observe(self.agent_selection)
        return obs if observe else None, reward, done, info

    @property
    def agent_selection(self):
        """
        The agent whose turn it currently is.

        Exposed as a property because the PettingZoo AEC API reads
        `env.agent_selection` as an attribute, not a method.
        """
        if not self._live_agents:
            # If we had an agent finish then return that agent even though it isn't alive.
            return self._agents[0]
        return self._agents[self._agent_index]