Spaces:
Sleeping
Sleeping
File size: 5,020 Bytes
e11e4fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from typing import Dict, Iterator, Any, List
from collections.abc import Mapping
from mlagents_envs.registry.base_registry_entry import BaseRegistryEntry
from mlagents_envs.registry.binary_utils import (
load_local_manifest,
load_remote_manifest,
)
from mlagents_envs.registry.remote_registry_entry import RemoteRegistryEntry
class UnityEnvRegistry(Mapping):
"""
### UnityEnvRegistry
Provides a library of Unity environments that can be launched without the need
of downloading the Unity Editor.
The UnityEnvRegistry implements a Map, to access an entry of the Registry, use:
```python
registry = UnityEnvRegistry()
entry = registry[<environment_identifyier>]
```
An entry has the following properties :
* `identifier` : Uniquely identifies this environment
* `expected_reward` : Corresponds to the reward an agent must obtained for the task
to be considered completed.
* `description` : A human readable description of the environment.
To launch a Unity environment from a registry entry, use the `make` method:
```python
registry = UnityEnvRegistry()
env = registry[<environment_identifyier>].make()
```
"""
def __init__(self):
self._REGISTERED_ENVS: Dict[str, BaseRegistryEntry] = {}
self._manifests: List[str] = []
self._sync = True
def register(self, new_entry: BaseRegistryEntry) -> None:
"""
Registers a new BaseRegistryEntry to the registry. The
BaseRegistryEntry.identifier value will be used as indexing key.
If two are more environments are registered under the same key, the most
recentry added will replace the others.
"""
self._REGISTERED_ENVS[new_entry.identifier] = new_entry
def register_from_yaml(self, path_to_yaml: str) -> None:
"""
Registers the environments listed in a yaml file (either local or remote). Note
that the entries are registered lazily: the registration will only happen when
an environment is accessed.
The yaml file must have the following format :
```yaml
environments:
- <identifier of the first environment>:
expected_reward: <expected reward of the environment>
description: | <a multi line description of the environment>
<continued multi line description>
linux_url: <The url for the Linux executable zip file>
darwin_url: <The url for the OSX executable zip file>
win_url: <The url for the Windows executable zip file>
- <identifier of the second environment>:
expected_reward: <expected reward of the environment>
description: | <a multi line description of the environment>
<continued multi line description>
linux_url: <The url for the Linux executable zip file>
darwin_url: <The url for the OSX executable zip file>
win_url: <The url for the Windows executable zip file>
- ...
```
:param path_to_yaml: A local path or url to the yaml file
"""
self._manifests.append(path_to_yaml)
self._sync = False
def _load_all_manifests(self) -> None:
if not self._sync:
for path_to_yaml in self._manifests:
if path_to_yaml[:4] == "http":
manifest = load_remote_manifest(path_to_yaml)
else:
manifest = load_local_manifest(path_to_yaml)
for env in manifest["environments"]:
remote_entry_args = list(env.values())[0]
remote_entry_args["identifier"] = list(env.keys())[0]
self.register(RemoteRegistryEntry(**remote_entry_args))
self._manifests = []
self._sync = True
def clear(self) -> None:
"""
Deletes all entries in the registry.
"""
self._REGISTERED_ENVS.clear()
self._manifests = []
self._sync = True
def __getitem__(self, identifier: str) -> BaseRegistryEntry:
"""
Returns the BaseRegistryEntry with the provided identifier. BaseRegistryEntry
can then be used to make a Unity Environment.
:param identifier: The identifier of the BaseRegistryEntry
:returns: The associated BaseRegistryEntry
"""
self._load_all_manifests()
if identifier not in self._REGISTERED_ENVS:
raise KeyError(f"The entry {identifier} is not present in the registry.")
return self._REGISTERED_ENVS[identifier]
def __len__(self) -> int:
self._load_all_manifests()
return len(self._REGISTERED_ENVS)
def __iter__(self) -> Iterator[Any]:
self._load_all_manifests()
yield from self._REGISTERED_ENVS
default_registry = UnityEnvRegistry()
default_registry.register_from_yaml(
"https://storage.googleapis.com/mlagents-test-environments/1.0.0/manifest.yaml"
) # noqa E501
|