File size: 2,494 Bytes
e11e4fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from typing import Any, Optional
from gym import error
from mlagents_envs.base_env import BaseEnv
from pettingzoo import AECEnv

from mlagents_envs.envs.unity_pettingzoo_base_env import UnityPettingzooBaseEnv


class UnityAECEnv(UnityPettingzooBaseEnv, AECEnv):
    """
    Unity AEC (PettingZoo) environment wrapper.

    Agents act one at a time through ``step``. Each call hands the active
    agent's action to the base class and moves to the next agent. Once the
    turn index passes the last agent of the cycle, the base class's
    ``_step()`` is invoked.
    """

    def __init__(self, env: BaseEnv, seed: Optional[int] = None):
        """
        Initializes a Unity AEC environment wrapper.

        :param env: The UnityEnvironment that is being wrapped.
        :param seed: The seed for the action spaces of the agents.
        """
        super().__init__(env, seed)

    def step(self, action: Any) -> None:
        """
        Sets the action of the active agent and moves on to the next agent.

        After this call, ``last()`` / ``observe()`` report the observation,
        cumulative reward, done flag and info of the next selected agent.

        :param action: The action for the active agent.
        :raises error.Error: If there are no live agents, e.g. the environment
            has not been reset yet.
        """
        self._assert_loaded()
        if not self._live_agents:
            raise error.Error(
                "You must reset the environment before you can perform a step"
            )

        # Process the action of the agent whose turn it is.
        current_agent = self._agents[self._agent_index]
        self._process_action(current_agent, action)

        self._agent_index += 1
        # Zero every agent's per-turn reward before the next agent acts.
        for agent_id in self._rewards:
            self._rewards[agent_id] = 0

        if self._agent_index >= len(self._agents) and self.num_agents > 0:
            # The index is past the last agent: time to set the actions
            # collected for the agents we have.
            self._step()
            self._live_agents.sort()  # unnecessary, only for passing API test

    def observe(self, agent_id):
        """
        Returns the data currently recorded for ``agent_id``. ``last()`` calls
        this function.

        :param agent_id: Identifier of the agent to look up.
        :return: Tuple ``(observation, cumulative_reward, done, info)``.
        """
        return (
            self._observations[agent_id],
            self._cumm_rewards[agent_id],
            self._dones[agent_id],
            self._infos[agent_id],
        )

    def last(self, observe: bool = True):
        """
        Returns observation, cumulative reward, done and info for the current
        agent, as specified by ``self.agent_selection``.

        :param observe: When False, ``None`` is returned in place of the
            observation.
        :return: Tuple ``(observation, cumulative_reward, done, info)``.
        """
        # Use agent_selection rather than indexing _agents with _agent_index.
        # step() increments _agent_index but only calls _step() when
        # num_agents > 0, so the index can run past the end of _agents.
        # agent_selection handles the no-live-agents case by falling back to
        # the first agent.
        obs, reward, done, info = self.observe(self.agent_selection)
        return obs if observe else None, reward, done, info

    @property
    def agent_selection(self):
        """
        The agent whose turn it currently is.

        When no agents are alive, this falls back to ``self._agents[0]``, so
        an agent that just finished can still be observed.
        """
        if not self._live_agents:
            # If we had an agent finish then return that agent even though it isn't alive.
            return self._agents[0]
        return self._agents[self._agent_index]