|
import os |
|
from pathlib import Path |
|
|
|
import pytest |
|
|
|
from mlagents_envs.registry import default_registry, UnityEnvRegistry |
|
from mlagents_envs.registry.remote_registry_entry import RemoteRegistryEntry |
|
|
|
BASIC_ID = "Basic" |
|
|
|
|
|
def create_registry(tmp_dir: str) -> UnityEnvRegistry: |
|
reg = UnityEnvRegistry() |
|
entry = RemoteRegistryEntry( |
|
BASIC_ID, |
|
0.0, |
|
"Description", |
|
"https://storage.googleapis.com/mlagents-test-environments/1.0.0/linux/Basic.zip", |
|
"https://storage.googleapis.com/mlagents-test-environments/1.0.0/darwin/Basic.zip", |
|
"https://storage.googleapis.com/mlagents-test-environments/1.0.0/windows/Basic.zip", |
|
tmp_dir=tmp_dir, |
|
) |
|
reg.register(entry) |
|
return reg |
|
|
|
|
|
@pytest.mark.parametrize("n_ports", [2]) |
|
def test_basic_in_registry(base_port: int, tmp_path: Path) -> None: |
|
assert BASIC_ID in default_registry |
|
os.environ["TERM"] = "xterm" |
|
registry = create_registry(str(tmp_path)) |
|
for worker_id in range(2): |
|
assert BASIC_ID in registry |
|
env = registry[BASIC_ID].make( |
|
base_port=base_port, worker_id=worker_id, no_graphics=True |
|
) |
|
env.reset() |
|
env.step() |
|
assert len(env.behavior_specs) == 1 |
|
env.close() |
|
|