File size: 3,394 Bytes
e11e4fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from sys import platform
from typing import Optional, Any, List
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.base_env import BaseEnv
from mlagents_envs.registry.binary_utils import get_local_binary_path
from mlagents_envs.registry.base_registry_entry import BaseRegistryEntry


class RemoteRegistryEntry(BaseRegistryEntry):
    def __init__(
        self,
        identifier: str,
        expected_reward: Optional[float],
        description: Optional[str],
        linux_url: Optional[str],
        darwin_url: Optional[str],
        win_url: Optional[str],
        additional_args: Optional[List[str]] = None,
        tmp_dir: Optional[str] = None,
    ):
        """
        A RemoteRegistryEntry is an implementation of BaseRegistryEntry that uses a
        Unity executable downloaded from the internet to launch a UnityEnvironment.
        __Note__: The url provided must be a link to a `.zip` file containing a single
        compressed folder with the executable inside. There can only be one executable
        in the folder and it must be at the root of the folder.
        :param identifier: The name of the Unity Environment.
        :param expected_reward: The cumulative reward that an Agent must receive
        for the task to be considered solved.
        :param description: A description of the Unity Environment. Contains human
        readable information about potential special arguments that the make method can
        take as well as information regarding the observation, reward, actions,
        behaviors and number of agents in the Environment.
        :param linux_url: The url of the Unity executable for the Linux platform
        :param darwin_url: The url of the Unity executable for the OSX platform
        :param win_url: The url of the Unity executable for the Windows platform
        :param additional_args: Optional list of extra arguments that `make` always
        forwards to the UnityEnvironment through its `additional_args` argument,
        placed after any `additional_args` supplied by the caller of `make`.
        :param tmp_dir: Optional directory forwarded to `get_local_binary_path` as
        its `tmp_dir` argument (presumably where the downloaded executable is
        stored; confirm against `get_local_binary_path`).
        """
        super().__init__(identifier, expected_reward, description)
        self._linux_url = linux_url
        self._darwin_url = darwin_url
        self._win_url = win_url
        # Keep a private copy so that later mutation of the caller's list cannot
        # silently change the arguments every future `make` call forwards.
        self._add_args = (
            list(additional_args) if additional_args is not None else None
        )
        self._tmp_dir_override = tmp_dir

    def _get_platform_url(self) -> Optional[str]:
        """
        Returns the url registered for the platform reported by `sys.platform`,
        or None when the platform is unrecognized or has no url registered.
        """
        # `sys.platform` is "linux" on modern Pythons ("linux2" on legacy ones).
        if platform.startswith("linux"):
            return self._linux_url
        if platform == "darwin":
            return self._darwin_url
        if platform == "win32":
            return self._win_url
        return None

    def make(self, **kwargs: Any) -> BaseEnv:
        """
        Returns the UnityEnvironment that corresponds to the Unity executable found at
        the provided url. The arguments passed to this method will be passed to the
        constructor of the UnityEnvironment (except for the file_name argument)
        :raises FileNotFoundError: If this entry has no (non-empty) url for the
        current platform.
        """
        url = self._get_platform_url()
        # Reject both a missing and an empty url before attempting any download.
        if not url:
            raise FileNotFoundError(
                f"The entry {self.identifier} does not contain a valid url for this "
                "platform"
            )
        path = get_local_binary_path(
            self.identifier, url, tmp_dir=self._tmp_dir_override
        )
        # The executable path is dictated by this entry; a caller-supplied
        # file_name is intentionally discarded.
        kwargs.pop("file_name", None)
        # Caller-supplied arguments first, then the entry's own arguments.
        args: List[str] = list(kwargs.get("additional_args") or [])
        if self._add_args is not None:
            args.extend(self._add_args)
        kwargs["additional_args"] = args
        return UnityEnvironment(file_name=path, **kwargs)