Kano001's picture
Upload 280 files
e11e4fe verified
raw
history blame
3.39 kB
from sys import platform
from typing import Optional, Any, List
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.base_env import BaseEnv
from mlagents_envs.registry.binary_utils import get_local_binary_path
from mlagents_envs.registry.base_registry_entry import BaseRegistryEntry
class RemoteRegistryEntry(BaseRegistryEntry):
    """
    Registry entry whose Unity executable is fetched from a per-platform url.

    __Note__: Each url must point to a `.zip` file containing a single
    compressed folder with the executable inside. There can only be one
    executable in the folder and it must be at the root of the folder.
    """

    def __init__(
        self,
        identifier: str,
        expected_reward: Optional[float],
        description: Optional[str],
        linux_url: Optional[str],
        darwin_url: Optional[str],
        win_url: Optional[str],
        additional_args: Optional[List[str]] = None,
        tmp_dir: Optional[str] = None,
    ):
        """
        Record the download locations and launch options for this entry.

        :param identifier: The name of the Unity Environment.
        :param expected_reward: The cumulative reward that an Agent must receive
        for the task to be considered solved.
        :param description: A description of the Unity Environment. Contains human
        readable information about potential special arguments that the make method
        can take as well as information regarding the observation, reward, actions,
        behaviors and number of agents in the Environment.
        :param linux_url: The url of the Unity executable for the Linux platform
        :param darwin_url: The url of the Unity executable for the OSX platform
        :param win_url: The url of the Unity executable for the Windows platform
        :param additional_args: Optional arguments appended, after any
        `additional_args` given to `make`, to the list handed to the
        UnityEnvironment constructor.
        :param tmp_dir: Optional value forwarded as `tmp_dir` to
        `get_local_binary_path` (presumably the directory used for the
        downloaded binary; confirm against that helper).
        """
        super().__init__(identifier, expected_reward, description)
        # Launch options.
        self._add_args = additional_args
        self._tmp_dir_override = tmp_dir
        # One download location per supported platform; None means unsupported.
        self._linux_url = linux_url
        self._darwin_url = darwin_url
        self._win_url = win_url

    def make(self, **kwargs: Any) -> BaseEnv:
        """
        Build the UnityEnvironment backed by the executable for this platform.

        Every keyword argument is forwarded to the UnityEnvironment constructor,
        except `file_name`, which is discarded and replaced by the path returned
        by `get_local_binary_path`. Any `additional_args` supplied by the caller
        come first in the forwarded list, followed by the entry's own.

        :raises FileNotFoundError: If no url is registered for the current
        platform.
        """
        # `sys.platform` value -> url registered for it. Unknown platforms and
        # platforms registered with None both end up as None.
        url_by_platform = {
            "linux": self._linux_url,
            "linux2": self._linux_url,
            "darwin": self._darwin_url,
            "win32": self._win_url,
        }
        url = url_by_platform.get(platform)
        if url is None:
            raise FileNotFoundError(
                f"The entry {self.identifier} does not contain a valid url for this "
                "platform"
            )

        path = get_local_binary_path(
            self.identifier, url, tmp_dir=self._tmp_dir_override
        )

        # The executable path always comes from the registry, never the caller.
        kwargs.pop("file_name", None)

        # Merge into a fresh list so neither the caller's list nor the entry's
        # stored list is modified.
        merged_args: List[str] = []
        caller_args = kwargs.get("additional_args")
        if caller_args is not None:
            merged_args.extend(caller_args)
        if self._add_args is not None:
            merged_args.extend(self._add_args)
        kwargs["additional_args"] = merged_args

        return UnityEnvironment(file_name=path, **kwargs)