|
from sys import platform |
|
from typing import Optional, Any, List |
|
from mlagents_envs.environment import UnityEnvironment |
|
from mlagents_envs.base_env import BaseEnv |
|
from mlagents_envs.registry.binary_utils import get_local_binary_path |
|
from mlagents_envs.registry.base_registry_entry import BaseRegistryEntry |
|
|
|
|
|
class RemoteRegistryEntry(BaseRegistryEntry):
    """Registry entry whose Unity executable is fetched from a per-platform url."""

    def __init__(
        self,
        identifier: str,
        expected_reward: Optional[float],
        description: Optional[str],
        linux_url: Optional[str],
        darwin_url: Optional[str],
        win_url: Optional[str],
        additional_args: Optional[List[str]] = None,
        tmp_dir: Optional[str] = None,
    ):
        """
        Create a BaseRegistryEntry implementation that launches a
        UnityEnvironment from a Unity executable downloaded from the internet.

        __Note__: Each url must point to a `.zip` file holding a single
        compressed folder. That folder must contain exactly one executable,
        located at the root of the folder.

        :param identifier: The name of the Unity Environment.
        :param expected_reward: The cumulative reward an Agent must receive for
        the task to be considered solved.
        :param description: A human readable description of the Unity
        Environment: special arguments accepted by the make method, plus
        details about observations, rewards, actions, behaviors and the number
        of agents in the Environment.
        :param linux_url: The url of the Unity executable for the Linux platform
        :param darwin_url: The url of the Unity executable for the OSX platform
        :param win_url: The url of the Unity executable for the Windows platform
        :param additional_args: Optional list of arguments that `make` appends
        to the `additional_args` it passes to the UnityEnvironment constructor.
        :param tmp_dir: Optional value forwarded as `tmp_dir` to
        `get_local_binary_path` (presumably the directory used for the
        downloaded binary -- confirm against binary_utils).
        """
        super().__init__(identifier, expected_reward, description)
        self._linux_url = linux_url
        self._darwin_url = darwin_url
        self._win_url = win_url
        self._add_args = additional_args
        self._tmp_dir_override = tmp_dir

    def make(self, **kwargs: Any) -> BaseEnv:
        """
        Build the UnityEnvironment backed by the executable found at the url
        registered for the current platform. All keyword arguments are handed
        to the UnityEnvironment constructor, except `file_name`, which is
        discarded and replaced by the local path of the binary.

        :raises FileNotFoundError: if no url is registered for this platform.
        """
        # Map every recognised sys.platform value to its url; unknown
        # platforms fall through to None.
        url_by_platform = {
            "linux": self._linux_url,
            "linux2": self._linux_url,
            "darwin": self._darwin_url,
            "win32": self._win_url,
        }
        binary_url = url_by_platform.get(platform)
        if binary_url is None:
            raise FileNotFoundError(
                f"The entry {self.identifier} does not contain a valid url for this "
                "platform"
            )

        local_path = get_local_binary_path(
            self.identifier, binary_url, tmp_dir=self._tmp_dir_override
        )

        # The executable path always comes from the registry entry.
        kwargs.pop("file_name", None)

        # Caller-supplied arguments come first, then the entry's own; a new
        # list is built so neither source list is modified.
        merged_args: List[str] = []
        for extra in (kwargs.get("additional_args"), self._add_args):
            if extra is not None:
                merged_args.extend(extra)
        kwargs["additional_args"] = merged_args

        return UnityEnvironment(file_name=local_path, **kwargs)
|
|