File size: 1,957 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import Optional, Union, List

from mlagents_envs import logging_util
from mlagents_envs.exception import UnityWorkerInUseException
from mlagents_envs.registry import default_registry
from mlagents_envs.side_channel.engine_configuration_channel import (
    EngineConfigurationChannel,
)
from mlagents_envs.side_channel.environment_parameters_channel import (
    EnvironmentParametersChannel,
)
from mlagents_envs.side_channel.stats_side_channel import StatsSideChannel
from mlagents_envs.envs.unity_aec_env import UnityAECEnv

# Module-level logger obtained through mlagents_envs.logging_util, keyed by this module's name.
logger = logging_util.get_logger(__name__)


class PettingZooEnvFactory:
    """Factory that creates UnityAECEnv instances for one entry of the ML-Agents default_registry."""

    # First base port tried when the caller does not pass ``base_port``.
    _DEFAULT_BASE_PORT = 6000
    # Highest valid TCP port number; the automatic scan stops here instead of
    # walking past the end of the port range.
    _MAX_PORT = 65535

    def __init__(self, env_id: str) -> None:
        """
        :param env_id: Key of the environment in ``default_registry``.
        """
        self.env_id = env_id

    def env(
        self, seed: Optional[int] = None, **kwargs: Union[List, int, bool, None]
    ) -> UnityAECEnv:
        """
        Creates the environment with env_id from unity's default_registry and wraps it in a UnityAECEnv.
        :param seed: The seed for the action spaces of the agents.
        :param kwargs: Any argument accepted by the `UnityEnvironment` class except file_name.
            If ``side_channels`` is omitted, a new EngineConfigurationChannel,
            EnvironmentParametersChannel and StatsSideChannel are supplied.
            If ``base_port`` is omitted, ports are tried in increasing order starting
            at 6000 until creating the environment does not raise
            UnityWorkerInUseException.
        :return: The created environment wrapped in a UnityAECEnv.
        :raises UnityWorkerInUseException: if ``base_port`` is omitted and every port
            from 6000 through 65535 was reported as in use (the last such error is
            re-raised).
        """
        # If no side_channels were specified, supply fresh default channels.
        if "side_channels" not in kwargs:
            kwargs["side_channels"] = [
                EngineConfigurationChannel(),
                EnvironmentParametersChannel(),
                StatsSideChannel(),
            ]

        # Look up the registry entry once; it is reused for every port attempt.
        entry = default_registry[self.env_id]

        # An explicit base_port is honored as-is, with no retries.
        if "base_port" in kwargs:
            return UnityAECEnv(entry.make(**kwargs), seed)

        # No base_port given: scan upward for a port that is not in use.
        last_error: Optional[UnityWorkerInUseException] = None
        for port in range(self._DEFAULT_BASE_PORT, self._MAX_PORT + 1):
            kwargs["base_port"] = port
            try:
                _env = entry.make(**kwargs)
            except UnityWorkerInUseException as err:
                # Port reported busy; remember the error and try the next one.
                last_error = err
                continue
            return UnityAECEnv(_env, seed)

        # The range above is non-empty, so at least one error was recorded.
        assert last_error is not None
        raise last_error