Spaces:
Sleeping
Sleeping
Upload 280 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- MLPY/Lib/site-packages/mlagents/__init__.py +0 -0
- MLPY/Lib/site-packages/mlagents/__pycache__/__init__.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/plugins/__init__.py +8 -0
- MLPY/Lib/site-packages/mlagents/plugins/__pycache__/__init__.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/plugins/__pycache__/stats_writer.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/plugins/__pycache__/trainer_type.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/plugins/stats_writer.py +72 -0
- MLPY/Lib/site-packages/mlagents/plugins/trainer_type.py +80 -0
- MLPY/Lib/site-packages/mlagents/torch_utils/__init__.py +4 -0
- MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/__init__.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/cpu_utils.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/globals.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/torch.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/torch_utils/cpu_utils.py +41 -0
- MLPY/Lib/site-packages/mlagents/torch_utils/globals.py +13 -0
- MLPY/Lib/site-packages/mlagents/torch_utils/torch.py +68 -0
- MLPY/Lib/site-packages/mlagents/trainers/__init__.py +5 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/__init__.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/action_info.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/agent_processor.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/behavior_id_utils.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/buffer.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/cli_utils.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/demo_loader.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/directory_utils.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/env_manager.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/environment_parameter_manager.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/exception.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/learn.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/run_experiment.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/settings.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/simple_env_manager.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/stats.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/subprocess_env_manager.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/trainer_controller.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/training_analytics_side_channel.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/training_status.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/trajectory.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/__pycache__/upgrade_config.cpython-39.pyc +0 -0
- MLPY/Lib/site-packages/mlagents/trainers/action_info.py +25 -0
- MLPY/Lib/site-packages/mlagents/trainers/agent_processor.py +469 -0
- MLPY/Lib/site-packages/mlagents/trainers/behavior_id_utils.py +64 -0
- MLPY/Lib/site-packages/mlagents/trainers/buffer.py +521 -0
- MLPY/Lib/site-packages/mlagents/trainers/cli_utils.py +331 -0
- MLPY/Lib/site-packages/mlagents/trainers/demo_loader.py +246 -0
- MLPY/Lib/site-packages/mlagents/trainers/directory_utils.py +76 -0
- MLPY/Lib/site-packages/mlagents/trainers/env_manager.py +157 -0
- MLPY/Lib/site-packages/mlagents/trainers/environment_parameter_manager.py +186 -0
- MLPY/Lib/site-packages/mlagents/trainers/exception.py +75 -0
- MLPY/Lib/site-packages/mlagents/trainers/ghost/__init__.py +0 -0
MLPY/Lib/site-packages/mlagents/__init__.py
ADDED
File without changes
|
MLPY/Lib/site-packages/mlagents/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (145 Bytes). View file
|
|
MLPY/Lib/site-packages/mlagents/plugins/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any
|
2 |
+
|
3 |
+
ML_AGENTS_STATS_WRITER = "mlagents.stats_writer"
|
4 |
+
ML_AGENTS_TRAINER_TYPE = "mlagents.trainer_type"
|
5 |
+
|
6 |
+
# TODO: the real type is Dict[str, HyperparamSettings]
|
7 |
+
all_trainer_types: Dict[str, Any] = {}
|
8 |
+
all_trainer_settings: Dict[str, Any] = {}
|
MLPY/Lib/site-packages/mlagents/plugins/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (427 Bytes). View file
|
|
MLPY/Lib/site-packages/mlagents/plugins/__pycache__/stats_writer.cpython-39.pyc
ADDED
Binary file (2.2 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/plugins/__pycache__/trainer_type.cpython-39.pyc
ADDED
Binary file (2.42 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/plugins/stats_writer.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
# importlib.metadata is new in python3.8
|
5 |
+
# We use the backport for older python versions.
|
6 |
+
if sys.version_info < (3, 8):
|
7 |
+
import importlib_metadata
|
8 |
+
else:
|
9 |
+
import importlib.metadata as importlib_metadata # pylint: disable=E0611
|
10 |
+
|
11 |
+
from mlagents.trainers.stats import StatsWriter
|
12 |
+
|
13 |
+
from mlagents_envs import logging_util
|
14 |
+
from mlagents.plugins import ML_AGENTS_STATS_WRITER
|
15 |
+
from mlagents.trainers.settings import RunOptions
|
16 |
+
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter
|
17 |
+
|
18 |
+
|
19 |
+
logger = logging_util.get_logger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
|
23 |
+
"""
|
24 |
+
The StatsWriters that mlagents-learn always uses:
|
25 |
+
* A TensorboardWriter to write information to TensorBoard
|
26 |
+
* A GaugeWriter to record our internal stats
|
27 |
+
* A ConsoleWriter to output to stdout.
|
28 |
+
"""
|
29 |
+
checkpoint_settings = run_options.checkpoint_settings
|
30 |
+
return [
|
31 |
+
TensorboardWriter(
|
32 |
+
checkpoint_settings.write_path,
|
33 |
+
clear_past_data=not checkpoint_settings.resume,
|
34 |
+
hidden_keys=["Is Training", "Step"],
|
35 |
+
),
|
36 |
+
GaugeWriter(),
|
37 |
+
ConsoleWriter(),
|
38 |
+
]
|
39 |
+
|
40 |
+
|
41 |
+
def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]:
|
42 |
+
"""
|
43 |
+
Registers all StatsWriter plugins (including the default one),
|
44 |
+
and evaluates them, and returns the list of all the StatsWriter implementations.
|
45 |
+
"""
|
46 |
+
all_stats_writers: List[StatsWriter] = []
|
47 |
+
if ML_AGENTS_STATS_WRITER not in importlib_metadata.entry_points():
|
48 |
+
logger.warning(
|
49 |
+
f"Unable to find any entry points for {ML_AGENTS_STATS_WRITER}, even the default ones. "
|
50 |
+
"Uninstalling and reinstalling ml-agents via pip should resolve. "
|
51 |
+
"Using default plugins for now."
|
52 |
+
)
|
53 |
+
return get_default_stats_writers(run_options)
|
54 |
+
|
55 |
+
entry_points = importlib_metadata.entry_points()[ML_AGENTS_STATS_WRITER]
|
56 |
+
|
57 |
+
for entry_point in entry_points:
|
58 |
+
|
59 |
+
try:
|
60 |
+
logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}")
|
61 |
+
plugin_func = entry_point.load()
|
62 |
+
plugin_stats_writers = plugin_func(run_options)
|
63 |
+
logger.debug(
|
64 |
+
f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}"
|
65 |
+
)
|
66 |
+
all_stats_writers += plugin_stats_writers
|
67 |
+
except BaseException:
|
68 |
+
# Catch all exceptions from setting up the plugin, so that bad user code doesn't break things.
|
69 |
+
logger.exception(
|
70 |
+
f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used."
|
71 |
+
)
|
72 |
+
return all_stats_writers
|
MLPY/Lib/site-packages/mlagents/plugins/trainer_type.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from typing import Dict, Tuple, Any
|
3 |
+
|
4 |
+
# importlib.metadata is new in python3.8
|
5 |
+
# We use the backport for older python versions.
|
6 |
+
if sys.version_info < (3, 8):
|
7 |
+
import importlib_metadata
|
8 |
+
else:
|
9 |
+
import importlib.metadata as importlib_metadata # pylint: disable=E0611
|
10 |
+
|
11 |
+
|
12 |
+
from mlagents_envs import logging_util
|
13 |
+
from mlagents.plugins import ML_AGENTS_TRAINER_TYPE
|
14 |
+
from mlagents.trainers.ppo.trainer import PPOTrainer
|
15 |
+
from mlagents.trainers.sac.trainer import SACTrainer
|
16 |
+
from mlagents.trainers.poca.trainer import POCATrainer
|
17 |
+
from mlagents.trainers.ppo.optimizer_torch import PPOSettings
|
18 |
+
from mlagents.trainers.sac.optimizer_torch import SACSettings
|
19 |
+
from mlagents.trainers.poca.optimizer_torch import POCASettings
|
20 |
+
from mlagents import plugins as mla_plugins
|
21 |
+
|
22 |
+
logger = logging_util.get_logger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
def get_default_trainer_types() -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
26 |
+
"""
|
27 |
+
The Trainers that mlagents-learn always uses:
|
28 |
+
"""
|
29 |
+
|
30 |
+
mla_plugins.all_trainer_types.update(
|
31 |
+
{
|
32 |
+
PPOTrainer.get_trainer_name(): PPOTrainer,
|
33 |
+
SACTrainer.get_trainer_name(): SACTrainer,
|
34 |
+
POCATrainer.get_trainer_name(): POCATrainer,
|
35 |
+
}
|
36 |
+
)
|
37 |
+
# global all_trainer_settings
|
38 |
+
mla_plugins.all_trainer_settings.update(
|
39 |
+
{
|
40 |
+
PPOTrainer.get_trainer_name(): PPOSettings,
|
41 |
+
SACTrainer.get_trainer_name(): SACSettings,
|
42 |
+
POCATrainer.get_trainer_name(): POCASettings,
|
43 |
+
}
|
44 |
+
)
|
45 |
+
|
46 |
+
return mla_plugins.all_trainer_types, mla_plugins.all_trainer_settings
|
47 |
+
|
48 |
+
|
49 |
+
def register_trainer_plugins() -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
50 |
+
"""
|
51 |
+
Registers all Trainer plugins (including the default one),
|
52 |
+
and evaluates them, and returns the list of all the Trainer implementations.
|
53 |
+
"""
|
54 |
+
if ML_AGENTS_TRAINER_TYPE not in importlib_metadata.entry_points():
|
55 |
+
logger.warning(
|
56 |
+
f"Unable to find any entry points for {ML_AGENTS_TRAINER_TYPE}, even the default ones. "
|
57 |
+
"Uninstalling and reinstalling ml-agents via pip should resolve. "
|
58 |
+
"Using default plugins for now."
|
59 |
+
)
|
60 |
+
return get_default_trainer_types()
|
61 |
+
|
62 |
+
entry_points = importlib_metadata.entry_points()[ML_AGENTS_TRAINER_TYPE]
|
63 |
+
|
64 |
+
for entry_point in entry_points:
|
65 |
+
|
66 |
+
try:
|
67 |
+
logger.debug(f"Initializing Trainer plugins: {entry_point.name}")
|
68 |
+
plugin_func = entry_point.load()
|
69 |
+
plugin_trainer_types, plugin_trainer_settings = plugin_func()
|
70 |
+
logger.debug(
|
71 |
+
f"Found {len(plugin_trainer_types)} Trainers for plugin {entry_point.name}"
|
72 |
+
)
|
73 |
+
mla_plugins.all_trainer_types.update(plugin_trainer_types)
|
74 |
+
mla_plugins.all_trainer_settings.update(plugin_trainer_settings)
|
75 |
+
except BaseException:
|
76 |
+
# Catch all exceptions from setting up the plugin, so that bad user code doesn't break things.
|
77 |
+
logger.exception(
|
78 |
+
f"Error initializing Trainer plugins for {entry_point.name}. This plugin will not be used."
|
79 |
+
)
|
80 |
+
return mla_plugins.all_trainer_types, mla_plugins.all_trainer_settings
|
MLPY/Lib/site-packages/mlagents/torch_utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mlagents.torch_utils.torch import torch as torch # noqa
|
2 |
+
from mlagents.torch_utils.torch import nn # noqa
|
3 |
+
from mlagents.torch_utils.torch import set_torch_config # noqa
|
4 |
+
from mlagents.torch_utils.torch import default_device # noqa
|
MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (314 Bytes). View file
|
|
MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/cpu_utils.cpython-39.pyc
ADDED
Binary file (1.51 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/globals.cpython-39.pyc
ADDED
Binary file (576 Bytes). View file
|
|
MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/torch.cpython-39.pyc
ADDED
Binary file (1.64 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/torch_utils/cpu_utils.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
def get_num_threads_to_use() -> Optional[int]:
|
7 |
+
"""
|
8 |
+
Gets the number of threads to use. For most problems, 4 is all you
|
9 |
+
need, but for smaller machines, we'd like to scale to less than that.
|
10 |
+
By default, PyTorch uses 1/2 of the available cores.
|
11 |
+
"""
|
12 |
+
num_cpus = _get_num_available_cpus()
|
13 |
+
return max(min(num_cpus // 2, 4), 1) if num_cpus is not None else None
|
14 |
+
|
15 |
+
|
16 |
+
def _get_num_available_cpus() -> Optional[int]:
|
17 |
+
"""
|
18 |
+
Returns number of CPUs using cgroups if possible. This accounts
|
19 |
+
for Docker containers that are limited in cores.
|
20 |
+
"""
|
21 |
+
period = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.cfs_period_us")
|
22 |
+
quota = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.cfs_quota_us")
|
23 |
+
share = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.shares")
|
24 |
+
is_kubernetes = os.getenv("KUBERNETES_SERVICE_HOST") is not None
|
25 |
+
|
26 |
+
if period > 0 and quota > 0:
|
27 |
+
return int(quota // period)
|
28 |
+
elif period > 0 and share > 0 and is_kubernetes:
|
29 |
+
# In kubernetes, each requested CPU is 1024 CPU shares
|
30 |
+
# https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/#how-pods-with-resource-limits-are-run
|
31 |
+
return int(share // 1024)
|
32 |
+
else:
|
33 |
+
return os.cpu_count()
|
34 |
+
|
35 |
+
|
36 |
+
def _read_in_integer_file(filename: str) -> int:
|
37 |
+
try:
|
38 |
+
with open(filename) as f:
|
39 |
+
return int(f.read().rstrip())
|
40 |
+
except FileNotFoundError:
|
41 |
+
return -1
|
MLPY/Lib/site-packages/mlagents/torch_utils/globals.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
_rank: Optional[int] = None
|
4 |
+
|
5 |
+
|
6 |
+
def get_rank() -> Optional[int]:
|
7 |
+
"""
|
8 |
+
Returns the rank (in the MPI sense) of the current node.
|
9 |
+
For local training, this will always be None.
|
10 |
+
If this needs to be used, it should be done from outside ml-agents.
|
11 |
+
:return:
|
12 |
+
"""
|
13 |
+
return _rank
|
MLPY/Lib/site-packages/mlagents/torch_utils/torch.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from distutils.version import LooseVersion
|
4 |
+
import pkg_resources
|
5 |
+
from mlagents.torch_utils import cpu_utils
|
6 |
+
from mlagents.trainers.settings import TorchSettings
|
7 |
+
from mlagents_envs.logging_util import get_logger
|
8 |
+
|
9 |
+
|
10 |
+
logger = get_logger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
def assert_torch_installed():
|
14 |
+
# Check that torch version 1.6.0 or later has been installed. If not, refer
|
15 |
+
# user to the PyTorch webpage for install instructions.
|
16 |
+
torch_pkg = None
|
17 |
+
try:
|
18 |
+
torch_pkg = pkg_resources.get_distribution("torch")
|
19 |
+
except pkg_resources.DistributionNotFound:
|
20 |
+
pass
|
21 |
+
assert torch_pkg is not None and LooseVersion(torch_pkg.version) >= LooseVersion(
|
22 |
+
"1.6.0"
|
23 |
+
), (
|
24 |
+
"A compatible version of PyTorch was not installed. Please visit the PyTorch homepage "
|
25 |
+
+ "(https://pytorch.org/get-started/locally/) and follow the instructions to install. "
|
26 |
+
+ "Version 1.6.0 and later are supported."
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
assert_torch_installed()
|
31 |
+
|
32 |
+
# This should be the only place that we import torch directly.
|
33 |
+
# Everywhere else is caught by the banned-modules setting for flake8
|
34 |
+
import torch # noqa I201
|
35 |
+
|
36 |
+
|
37 |
+
torch.set_num_threads(cpu_utils.get_num_threads_to_use())
|
38 |
+
os.environ["KMP_BLOCKTIME"] = "0"
|
39 |
+
|
40 |
+
|
41 |
+
_device = torch.device("cpu")
|
42 |
+
|
43 |
+
|
44 |
+
def set_torch_config(torch_settings: TorchSettings) -> None:
|
45 |
+
global _device
|
46 |
+
|
47 |
+
if torch_settings.device is None:
|
48 |
+
device_str = "cuda" if torch.cuda.is_available() else "cpu"
|
49 |
+
else:
|
50 |
+
device_str = torch_settings.device
|
51 |
+
|
52 |
+
_device = torch.device(device_str)
|
53 |
+
|
54 |
+
if _device.type == "cuda":
|
55 |
+
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
56 |
+
else:
|
57 |
+
torch.set_default_tensor_type(torch.FloatTensor)
|
58 |
+
logger.debug(f"default Torch device: {_device}")
|
59 |
+
|
60 |
+
|
61 |
+
# Initialize to default settings
|
62 |
+
set_torch_config(TorchSettings(device=None))
|
63 |
+
|
64 |
+
nn = torch.nn
|
65 |
+
|
66 |
+
|
67 |
+
def default_device():
|
68 |
+
return _device
|
MLPY/Lib/site-packages/mlagents/trainers/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Version of the library that will be used to upload to pypi
|
2 |
+
__version__ = "0.30.0"
|
3 |
+
|
4 |
+
# Git tag that will be checked to determine whether to trigger upload to pypi
|
5 |
+
__release_tag__ = "release_20"
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (211 Bytes). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/action_info.cpython-39.pyc
ADDED
Binary file (1.27 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/agent_processor.cpython-39.pyc
ADDED
Binary file (14.1 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/behavior_id_utils.cpython-39.pyc
ADDED
Binary file (2.62 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/buffer.cpython-39.pyc
ADDED
Binary file (18.3 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/cli_utils.cpython-39.pyc
ADDED
Binary file (9.91 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/demo_loader.cpython-39.pyc
ADDED
Binary file (6.44 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/directory_utils.cpython-39.pyc
ADDED
Binary file (2.68 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/env_manager.cpython-39.pyc
ADDED
Binary file (5.44 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/environment_parameter_manager.cpython-39.pyc
ADDED
Binary file (6.54 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/exception.cpython-39.pyc
ADDED
Binary file (2.16 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/learn.cpython-39.pyc
ADDED
Binary file (8.42 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/run_experiment.cpython-39.pyc
ADDED
Binary file (1.31 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/settings.cpython-39.pyc
ADDED
Binary file (32.6 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/simple_env_manager.cpython-39.pyc
ADDED
Binary file (3.54 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/stats.cpython-39.pyc
ADDED
Binary file (14.9 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/subprocess_env_manager.cpython-39.pyc
ADDED
Binary file (16.4 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/trainer_controller.cpython-39.pyc
ADDED
Binary file (9.53 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/training_analytics_side_channel.cpython-39.pyc
ADDED
Binary file (6.14 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/training_status.cpython-39.pyc
ADDED
Binary file (5.09 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/trajectory.cpython-39.pyc
ADDED
Binary file (8.41 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/upgrade_config.cpython-39.pyc
ADDED
Binary file (5.8 kB). View file
|
|
MLPY/Lib/site-packages/mlagents/trainers/action_info.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import NamedTuple, Any, Dict, List
|
2 |
+
import numpy as np
|
3 |
+
from mlagents_envs.base_env import AgentId
|
4 |
+
|
5 |
+
ActionInfoOutputs = Dict[str, np.ndarray]
|
6 |
+
|
7 |
+
|
8 |
+
class ActionInfo(NamedTuple):
|
9 |
+
"""
|
10 |
+
A NamedTuple containing actions and related quantities to the policy forward
|
11 |
+
pass. Additionally contains the agent ids in the corresponding DecisionStep
|
12 |
+
:param action: The action output of the policy
|
13 |
+
:param env_action: The possibly clipped action to be executed in the environment
|
14 |
+
:param outputs: Dict of all quantities associated with the policy forward pass
|
15 |
+
:param agent_ids: List of int agent ids in DecisionStep
|
16 |
+
"""
|
17 |
+
|
18 |
+
action: Any
|
19 |
+
env_action: Any
|
20 |
+
outputs: ActionInfoOutputs
|
21 |
+
agent_ids: List[AgentId]
|
22 |
+
|
23 |
+
@staticmethod
|
24 |
+
def empty() -> "ActionInfo":
|
25 |
+
return ActionInfo([], [], {}, [])
|
MLPY/Lib/site-packages/mlagents/trainers/agent_processor.py
ADDED
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
from typing import List, Dict, TypeVar, Generic, Tuple, Any, Union
|
4 |
+
from collections import defaultdict, Counter
|
5 |
+
import queue
|
6 |
+
from mlagents.torch_utils import torch
|
7 |
+
|
8 |
+
from mlagents_envs.base_env import (
|
9 |
+
ActionTuple,
|
10 |
+
DecisionSteps,
|
11 |
+
DecisionStep,
|
12 |
+
TerminalSteps,
|
13 |
+
TerminalStep,
|
14 |
+
)
|
15 |
+
from mlagents_envs.side_channel.stats_side_channel import (
|
16 |
+
StatsAggregationMethod,
|
17 |
+
EnvironmentStats,
|
18 |
+
)
|
19 |
+
from mlagents.trainers.exception import UnityTrainerException
|
20 |
+
from mlagents.trainers.trajectory import AgentStatus, Trajectory, AgentExperience
|
21 |
+
from mlagents.trainers.policy import Policy
|
22 |
+
from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
|
23 |
+
from mlagents.trainers.stats import StatsReporter
|
24 |
+
from mlagents.trainers.behavior_id_utils import (
|
25 |
+
get_global_agent_id,
|
26 |
+
get_global_group_id,
|
27 |
+
GlobalAgentId,
|
28 |
+
GlobalGroupId,
|
29 |
+
)
|
30 |
+
from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple
|
31 |
+
from mlagents.trainers.torch_entities.utils import ModelUtils
|
32 |
+
|
33 |
+
T = TypeVar("T")
|
34 |
+
|
35 |
+
|
36 |
+
class AgentProcessor:
|
37 |
+
"""
|
38 |
+
AgentProcessor contains a dictionary per-agent trajectory buffers. The buffers are indexed by agent_id.
|
39 |
+
Buffer also contains an update_buffer that corresponds to the buffer used when updating the model.
|
40 |
+
One AgentProcessor should be created per agent group.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
policy: Policy,
|
46 |
+
behavior_id: str,
|
47 |
+
stats_reporter: StatsReporter,
|
48 |
+
max_trajectory_length: int = sys.maxsize,
|
49 |
+
):
|
50 |
+
"""
|
51 |
+
Create an AgentProcessor.
|
52 |
+
|
53 |
+
:param trainer: Trainer instance connected to this AgentProcessor. Trainer is given trajectory
|
54 |
+
when it is finished.
|
55 |
+
:param policy: Policy instance associated with this AgentProcessor.
|
56 |
+
:param max_trajectory_length: Maximum length of a trajectory before it is added to the trainer.
|
57 |
+
:param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
|
58 |
+
"""
|
59 |
+
self._experience_buffers: Dict[
|
60 |
+
GlobalAgentId, List[AgentExperience]
|
61 |
+
] = defaultdict(list)
|
62 |
+
self._last_step_result: Dict[GlobalAgentId, Tuple[DecisionStep, int]] = {}
|
63 |
+
# current_group_obs is used to collect the current (i.e. the most recently seen)
|
64 |
+
# obs of all the agents in the same group, and assemble the group obs.
|
65 |
+
# It is a dictionary of GlobalGroupId to dictionaries of GlobalAgentId to observation.
|
66 |
+
self._current_group_obs: Dict[
|
67 |
+
GlobalGroupId, Dict[GlobalAgentId, List[np.ndarray]]
|
68 |
+
] = defaultdict(lambda: defaultdict(list))
|
69 |
+
# group_status is used to collect the current, most recently seen
|
70 |
+
# group status of all the agents in the same group, and assemble the group's status.
|
71 |
+
# It is a dictionary of GlobalGroupId to dictionaries of GlobalAgentId to AgentStatus.
|
72 |
+
self._group_status: Dict[
|
73 |
+
GlobalGroupId, Dict[GlobalAgentId, AgentStatus]
|
74 |
+
] = defaultdict(lambda: defaultdict(None))
|
75 |
+
# last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
|
76 |
+
# grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
|
77 |
+
self._last_take_action_outputs: Dict[GlobalAgentId, ActionInfoOutputs] = {}
|
78 |
+
|
79 |
+
self._episode_steps: Counter = Counter()
|
80 |
+
self._episode_rewards: Dict[GlobalAgentId, float] = defaultdict(float)
|
81 |
+
self._stats_reporter = stats_reporter
|
82 |
+
self._max_trajectory_length = max_trajectory_length
|
83 |
+
self._trajectory_queues: List[AgentManagerQueue[Trajectory]] = []
|
84 |
+
self._behavior_id = behavior_id
|
85 |
+
|
86 |
+
# Note: In the future this policy reference will be the policy of the env_manager and not the trainer.
|
87 |
+
# We can in that case just grab the action from the policy rather than having it passed in.
|
88 |
+
self.policy = policy
|
89 |
+
|
90 |
+
def add_experiences(
|
91 |
+
self,
|
92 |
+
decision_steps: DecisionSteps,
|
93 |
+
terminal_steps: TerminalSteps,
|
94 |
+
worker_id: int,
|
95 |
+
previous_action: ActionInfo,
|
96 |
+
) -> None:
|
97 |
+
"""
|
98 |
+
Adds experiences to each agent's experience history.
|
99 |
+
:param decision_steps: current DecisionSteps.
|
100 |
+
:param terminal_steps: current TerminalSteps.
|
101 |
+
:param previous_action: The outputs of the Policy's get_action method.
|
102 |
+
"""
|
103 |
+
take_action_outputs = previous_action.outputs
|
104 |
+
if take_action_outputs:
|
105 |
+
try:
|
106 |
+
for _entropy in take_action_outputs["entropy"]:
|
107 |
+
if isinstance(_entropy, torch.Tensor):
|
108 |
+
_entropy = ModelUtils.to_numpy(_entropy)
|
109 |
+
self._stats_reporter.add_stat("Policy/Entropy", _entropy)
|
110 |
+
except KeyError:
|
111 |
+
pass
|
112 |
+
|
113 |
+
# Make unique agent_ids that are global across workers
|
114 |
+
action_global_agent_ids = [
|
115 |
+
get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids
|
116 |
+
]
|
117 |
+
for global_id in action_global_agent_ids:
|
118 |
+
if global_id in self._last_step_result: # Don't store if agent just reset
|
119 |
+
self._last_take_action_outputs[global_id] = take_action_outputs
|
120 |
+
|
121 |
+
# Iterate over all the terminal steps, first gather all the group obs
|
122 |
+
# and then create the AgentExperiences/Trajectories. _add_to_group_status
|
123 |
+
# stores Group statuses in a common data structure self.group_status
|
124 |
+
for terminal_step in terminal_steps.values():
|
125 |
+
self._add_group_status_and_obs(terminal_step, worker_id)
|
126 |
+
for terminal_step in terminal_steps.values():
|
127 |
+
local_id = terminal_step.agent_id
|
128 |
+
global_id = get_global_agent_id(worker_id, local_id)
|
129 |
+
self._process_step(
|
130 |
+
terminal_step, worker_id, terminal_steps.agent_id_to_index[local_id]
|
131 |
+
)
|
132 |
+
|
133 |
+
# Iterate over all the decision steps, first gather all the group obs
|
134 |
+
# and then create the trajectories. _add_to_group_status
|
135 |
+
# stores Group statuses in a common data structure self.group_status
|
136 |
+
for ongoing_step in decision_steps.values():
|
137 |
+
self._add_group_status_and_obs(ongoing_step, worker_id)
|
138 |
+
for ongoing_step in decision_steps.values():
|
139 |
+
local_id = ongoing_step.agent_id
|
140 |
+
self._process_step(
|
141 |
+
ongoing_step, worker_id, decision_steps.agent_id_to_index[local_id]
|
142 |
+
)
|
143 |
+
# Clear the last seen group obs when agents die, but only after all of the group
|
144 |
+
# statuses were added to the trajectory.
|
145 |
+
for terminal_step in terminal_steps.values():
|
146 |
+
local_id = terminal_step.agent_id
|
147 |
+
global_id = get_global_agent_id(worker_id, local_id)
|
148 |
+
self._clear_group_status_and_obs(global_id)
|
149 |
+
|
150 |
+
for _gid in action_global_agent_ids:
|
151 |
+
# If the ID doesn't have a last step result, the agent just reset,
|
152 |
+
# don't store the action.
|
153 |
+
if _gid in self._last_step_result:
|
154 |
+
if "action" in take_action_outputs:
|
155 |
+
self.policy.save_previous_action(
|
156 |
+
[_gid], take_action_outputs["action"]
|
157 |
+
)
|
158 |
+
|
159 |
+
def _add_group_status_and_obs(
|
160 |
+
self, step: Union[TerminalStep, DecisionStep], worker_id: int
|
161 |
+
) -> None:
|
162 |
+
"""
|
163 |
+
Takes a TerminalStep or DecisionStep and adds the information in it
|
164 |
+
to self.group_status. This information can then be retrieved
|
165 |
+
when constructing trajectories to get the status of group mates. Also stores the current
|
166 |
+
observation into current_group_obs, to be used to get the next group observations
|
167 |
+
for bootstrapping.
|
168 |
+
:param step: TerminalStep or DecisionStep
|
169 |
+
:param worker_id: Worker ID of this particular environment. Used to generate a
|
170 |
+
global group id.
|
171 |
+
"""
|
172 |
+
global_agent_id = get_global_agent_id(worker_id, step.agent_id)
|
173 |
+
stored_decision_step, idx = self._last_step_result.get(
|
174 |
+
global_agent_id, (None, None)
|
175 |
+
)
|
176 |
+
stored_take_action_outputs = self._last_take_action_outputs.get(
|
177 |
+
global_agent_id, None
|
178 |
+
)
|
179 |
+
if stored_decision_step is not None and stored_take_action_outputs is not None:
|
180 |
+
# 0, the default group_id, means that the agent doesn't belong to an agent group.
|
181 |
+
# If 0, don't add any groupmate information.
|
182 |
+
if step.group_id > 0:
|
183 |
+
global_group_id = get_global_group_id(worker_id, step.group_id)
|
184 |
+
stored_actions = stored_take_action_outputs["action"]
|
185 |
+
action_tuple = ActionTuple(
|
186 |
+
continuous=stored_actions.continuous[idx],
|
187 |
+
discrete=stored_actions.discrete[idx],
|
188 |
+
)
|
189 |
+
group_status = AgentStatus(
|
190 |
+
obs=stored_decision_step.obs,
|
191 |
+
reward=step.reward,
|
192 |
+
action=action_tuple,
|
193 |
+
done=isinstance(step, TerminalStep),
|
194 |
+
)
|
195 |
+
self._group_status[global_group_id][global_agent_id] = group_status
|
196 |
+
self._current_group_obs[global_group_id][global_agent_id] = step.obs
|
197 |
+
|
198 |
+
def _clear_group_status_and_obs(self, global_id: GlobalAgentId) -> None:
|
199 |
+
"""
|
200 |
+
Clears an agent from self._group_status and self._current_group_obs.
|
201 |
+
"""
|
202 |
+
self._delete_in_nested_dict(self._current_group_obs, global_id)
|
203 |
+
self._delete_in_nested_dict(self._group_status, global_id)
|
204 |
+
|
205 |
+
def _delete_in_nested_dict(self, nested_dict: Dict[str, Any], key: str) -> None:
|
206 |
+
for _manager_id in list(nested_dict.keys()):
|
207 |
+
_team_group = nested_dict[_manager_id]
|
208 |
+
self._safe_delete(_team_group, key)
|
209 |
+
if not _team_group: # if dict is empty
|
210 |
+
self._safe_delete(nested_dict, _manager_id)
|
211 |
+
|
212 |
+
def _process_step(
|
213 |
+
self, step: Union[TerminalStep, DecisionStep], worker_id: int, index: int
|
214 |
+
) -> None:
|
215 |
+
terminated = isinstance(step, TerminalStep)
|
216 |
+
global_agent_id = get_global_agent_id(worker_id, step.agent_id)
|
217 |
+
global_group_id = get_global_group_id(worker_id, step.group_id)
|
218 |
+
stored_decision_step, idx = self._last_step_result.get(
|
219 |
+
global_agent_id, (None, None)
|
220 |
+
)
|
221 |
+
stored_take_action_outputs = self._last_take_action_outputs.get(
|
222 |
+
global_agent_id, None
|
223 |
+
)
|
224 |
+
if not terminated:
|
225 |
+
# Index is needed to grab from last_take_action_outputs
|
226 |
+
self._last_step_result[global_agent_id] = (step, index)
|
227 |
+
|
228 |
+
# This state is the consequence of a past action
|
229 |
+
if stored_decision_step is not None and stored_take_action_outputs is not None:
|
230 |
+
obs = stored_decision_step.obs
|
231 |
+
if self.policy.use_recurrent:
|
232 |
+
memory = self.policy.retrieve_previous_memories([global_agent_id])[0, :]
|
233 |
+
else:
|
234 |
+
memory = None
|
235 |
+
done = terminated # Since this is an ongoing step
|
236 |
+
interrupted = step.interrupted if terminated else False
|
237 |
+
# Add the outputs of the last eval
|
238 |
+
stored_actions = stored_take_action_outputs["action"]
|
239 |
+
action_tuple = ActionTuple(
|
240 |
+
continuous=stored_actions.continuous[idx],
|
241 |
+
discrete=stored_actions.discrete[idx],
|
242 |
+
)
|
243 |
+
try:
|
244 |
+
stored_action_probs = stored_take_action_outputs["log_probs"]
|
245 |
+
if not isinstance(stored_action_probs, LogProbsTuple):
|
246 |
+
stored_action_probs = stored_action_probs.to_log_probs_tuple()
|
247 |
+
log_probs_tuple = LogProbsTuple(
|
248 |
+
continuous=stored_action_probs.continuous[idx],
|
249 |
+
discrete=stored_action_probs.discrete[idx],
|
250 |
+
)
|
251 |
+
except KeyError:
|
252 |
+
log_probs_tuple = LogProbsTuple.empty_log_probs()
|
253 |
+
|
254 |
+
action_mask = stored_decision_step.action_mask
|
255 |
+
prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :]
|
256 |
+
|
257 |
+
# Assemble teammate_obs. If none saved, then it will be an empty list.
|
258 |
+
group_statuses = []
|
259 |
+
for _id, _mate_status in self._group_status[global_group_id].items():
|
260 |
+
if _id != global_agent_id:
|
261 |
+
group_statuses.append(_mate_status)
|
262 |
+
|
263 |
+
experience = AgentExperience(
|
264 |
+
obs=obs,
|
265 |
+
reward=step.reward,
|
266 |
+
done=done,
|
267 |
+
action=action_tuple,
|
268 |
+
action_probs=log_probs_tuple,
|
269 |
+
action_mask=action_mask,
|
270 |
+
prev_action=prev_action,
|
271 |
+
interrupted=interrupted,
|
272 |
+
memory=memory,
|
273 |
+
group_status=group_statuses,
|
274 |
+
group_reward=step.group_reward,
|
275 |
+
)
|
276 |
+
# Add the value outputs if needed
|
277 |
+
self._experience_buffers[global_agent_id].append(experience)
|
278 |
+
self._episode_rewards[global_agent_id] += step.reward
|
279 |
+
if not terminated:
|
280 |
+
self._episode_steps[global_agent_id] += 1
|
281 |
+
|
282 |
+
# Add a trajectory segment to the buffer if terminal or the length has reached the time horizon
|
283 |
+
if (
|
284 |
+
len(self._experience_buffers[global_agent_id])
|
285 |
+
>= self._max_trajectory_length
|
286 |
+
or terminated
|
287 |
+
):
|
288 |
+
next_obs = step.obs
|
289 |
+
next_group_obs = []
|
290 |
+
for _id, _obs in self._current_group_obs[global_group_id].items():
|
291 |
+
if _id != global_agent_id:
|
292 |
+
next_group_obs.append(_obs)
|
293 |
+
|
294 |
+
trajectory = Trajectory(
|
295 |
+
steps=self._experience_buffers[global_agent_id],
|
296 |
+
agent_id=global_agent_id,
|
297 |
+
next_obs=next_obs,
|
298 |
+
next_group_obs=next_group_obs,
|
299 |
+
behavior_id=self._behavior_id,
|
300 |
+
)
|
301 |
+
for traj_queue in self._trajectory_queues:
|
302 |
+
traj_queue.put(trajectory)
|
303 |
+
self._experience_buffers[global_agent_id] = []
|
304 |
+
if terminated:
|
305 |
+
# Record episode length.
|
306 |
+
self._stats_reporter.add_stat(
|
307 |
+
"Environment/Episode Length",
|
308 |
+
self._episode_steps.get(global_agent_id, 0),
|
309 |
+
)
|
310 |
+
self._clean_agent_data(global_agent_id)
|
311 |
+
|
312 |
+
def _clean_agent_data(self, global_id: GlobalAgentId) -> None:
|
313 |
+
"""
|
314 |
+
Removes the data for an Agent.
|
315 |
+
"""
|
316 |
+
self._safe_delete(self._experience_buffers, global_id)
|
317 |
+
self._safe_delete(self._last_take_action_outputs, global_id)
|
318 |
+
self._safe_delete(self._last_step_result, global_id)
|
319 |
+
self._safe_delete(self._episode_steps, global_id)
|
320 |
+
self._safe_delete(self._episode_rewards, global_id)
|
321 |
+
self.policy.remove_previous_action([global_id])
|
322 |
+
self.policy.remove_memories([global_id])
|
323 |
+
|
324 |
+
def _safe_delete(self, my_dictionary: Dict[Any, Any], key: Any) -> None:
|
325 |
+
"""
|
326 |
+
Safe removes data from a dictionary. If not found,
|
327 |
+
don't delete.
|
328 |
+
"""
|
329 |
+
if key in my_dictionary:
|
330 |
+
del my_dictionary[key]
|
331 |
+
|
332 |
+
def publish_trajectory_queue(
|
333 |
+
self, trajectory_queue: "AgentManagerQueue[Trajectory]"
|
334 |
+
) -> None:
|
335 |
+
"""
|
336 |
+
Adds a trajectory queue to the list of queues to publish to when this AgentProcessor
|
337 |
+
assembles a Trajectory
|
338 |
+
:param trajectory_queue: Trajectory queue to publish to.
|
339 |
+
"""
|
340 |
+
self._trajectory_queues.append(trajectory_queue)
|
341 |
+
|
342 |
+
def end_episode(self) -> None:
|
343 |
+
"""
|
344 |
+
Ends the episode, terminating the current trajectory and stopping stats collection for that
|
345 |
+
episode. Used for forceful reset (e.g. in curriculum or generalization training.)
|
346 |
+
"""
|
347 |
+
all_gids = list(self._experience_buffers.keys()) # Need to make copy
|
348 |
+
for _gid in all_gids:
|
349 |
+
self._clean_agent_data(_gid)
|
350 |
+
|
351 |
+
|
352 |
+
class AgentManagerQueue(Generic[T]):
|
353 |
+
"""
|
354 |
+
Queue used by the AgentManager. Note that we make our own class here because in most implementations
|
355 |
+
deque is sufficient and faster. However, if we want to switch to multiprocessing, we'll need to change
|
356 |
+
out this implementation.
|
357 |
+
"""
|
358 |
+
|
359 |
+
class Empty(Exception):
|
360 |
+
"""
|
361 |
+
Exception for when the queue is empty.
|
362 |
+
"""
|
363 |
+
|
364 |
+
pass
|
365 |
+
|
366 |
+
def __init__(self, behavior_id: str, maxlen: int = 0):
|
367 |
+
"""
|
368 |
+
Initializes an AgentManagerQueue. Note that we can give it a behavior_id so that it can be identified
|
369 |
+
separately from an AgentManager.
|
370 |
+
"""
|
371 |
+
self._maxlen: int = maxlen
|
372 |
+
self._queue: queue.Queue = queue.Queue(maxsize=maxlen)
|
373 |
+
self._behavior_id = behavior_id
|
374 |
+
|
375 |
+
@property
|
376 |
+
def maxlen(self):
|
377 |
+
"""
|
378 |
+
The maximum length of the queue.
|
379 |
+
:return: Maximum length of the queue.
|
380 |
+
"""
|
381 |
+
return self._maxlen
|
382 |
+
|
383 |
+
@property
|
384 |
+
def behavior_id(self):
|
385 |
+
"""
|
386 |
+
The Behavior ID of this queue.
|
387 |
+
:return: Behavior ID associated with the queue.
|
388 |
+
"""
|
389 |
+
return self._behavior_id
|
390 |
+
|
391 |
+
def qsize(self) -> int:
|
392 |
+
"""
|
393 |
+
Returns the approximate size of the queue. Note that values may differ
|
394 |
+
depending on the underlying queue implementation.
|
395 |
+
"""
|
396 |
+
return self._queue.qsize()
|
397 |
+
|
398 |
+
def empty(self) -> bool:
|
399 |
+
return self._queue.empty()
|
400 |
+
|
401 |
+
def get_nowait(self) -> T:
|
402 |
+
"""
|
403 |
+
Gets the next item from the queue, throwing an AgentManagerQueue.Empty exception
|
404 |
+
if the queue is empty.
|
405 |
+
"""
|
406 |
+
try:
|
407 |
+
return self._queue.get_nowait()
|
408 |
+
except queue.Empty:
|
409 |
+
raise self.Empty("The AgentManagerQueue is empty.")
|
410 |
+
|
411 |
+
def put(self, item: T) -> None:
|
412 |
+
self._queue.put(item)
|
413 |
+
|
414 |
+
|
415 |
+
class AgentManager(AgentProcessor):
|
416 |
+
"""
|
417 |
+
An AgentManager is an AgentProcessor that also holds a single trajectory and policy queue.
|
418 |
+
Note: this leaves room for adding AgentProcessors that publish multiple trajectory queues.
|
419 |
+
"""
|
420 |
+
|
421 |
+
def __init__(
|
422 |
+
self,
|
423 |
+
policy: Policy,
|
424 |
+
behavior_id: str,
|
425 |
+
stats_reporter: StatsReporter,
|
426 |
+
max_trajectory_length: int = sys.maxsize,
|
427 |
+
threaded: bool = True,
|
428 |
+
):
|
429 |
+
super().__init__(policy, behavior_id, stats_reporter, max_trajectory_length)
|
430 |
+
trajectory_queue_len = 20 if threaded else 0
|
431 |
+
self.trajectory_queue: AgentManagerQueue[Trajectory] = AgentManagerQueue(
|
432 |
+
self._behavior_id, maxlen=trajectory_queue_len
|
433 |
+
)
|
434 |
+
# NOTE: we make policy queues of infinite length to avoid lockups of the trainers.
|
435 |
+
# In the environment manager, we make sure to empty the policy queue before continuing to produce steps.
|
436 |
+
self.policy_queue: AgentManagerQueue[Policy] = AgentManagerQueue(
|
437 |
+
self._behavior_id, maxlen=0
|
438 |
+
)
|
439 |
+
self.publish_trajectory_queue(self.trajectory_queue)
|
440 |
+
|
441 |
+
def record_environment_stats(
|
442 |
+
self, env_stats: EnvironmentStats, worker_id: int
|
443 |
+
) -> None:
|
444 |
+
"""
|
445 |
+
Pass stats from the environment to the StatsReporter.
|
446 |
+
Depending on the StatsAggregationMethod, either StatsReporter.add_stat or StatsReporter.set_stat is used.
|
447 |
+
The worker_id is used to determine whether StatsReporter.set_stat should be used.
|
448 |
+
|
449 |
+
:param env_stats:
|
450 |
+
:param worker_id:
|
451 |
+
:return:
|
452 |
+
"""
|
453 |
+
for stat_name, value_list in env_stats.items():
|
454 |
+
for val, agg_type in value_list:
|
455 |
+
if agg_type == StatsAggregationMethod.AVERAGE:
|
456 |
+
self._stats_reporter.add_stat(stat_name, val, agg_type)
|
457 |
+
elif agg_type == StatsAggregationMethod.SUM:
|
458 |
+
self._stats_reporter.add_stat(stat_name, val, agg_type)
|
459 |
+
elif agg_type == StatsAggregationMethod.HISTOGRAM:
|
460 |
+
self._stats_reporter.add_stat(stat_name, val, agg_type)
|
461 |
+
elif agg_type == StatsAggregationMethod.MOST_RECENT:
|
462 |
+
# In order to prevent conflicts between multiple environments,
|
463 |
+
# only stats from the first environment are recorded.
|
464 |
+
if worker_id == 0:
|
465 |
+
self._stats_reporter.set_stat(stat_name, val)
|
466 |
+
else:
|
467 |
+
raise UnityTrainerException(
|
468 |
+
f"Unknown StatsAggregationMethod encountered. {agg_type}"
|
469 |
+
)
|
MLPY/Lib/site-packages/mlagents/trainers/behavior_id_utils.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import NamedTuple
|
2 |
+
from urllib.parse import urlparse, parse_qs
|
3 |
+
from mlagents_envs.base_env import AgentId, GroupId
|
4 |
+
|
5 |
+
GlobalGroupId = str
|
6 |
+
GlobalAgentId = str
|
7 |
+
|
8 |
+
|
9 |
+
class BehaviorIdentifiers(NamedTuple):
|
10 |
+
"""
|
11 |
+
BehaviorIdentifiers is a named tuple of the identifiers that uniquely distinguish
|
12 |
+
an agent encountered in the trainer_controller. The named tuple consists of the
|
13 |
+
fully qualified behavior name, the name of the brain name (corresponds to trainer
|
14 |
+
in the trainer controller) and the team id. In the future, this can be extended
|
15 |
+
to support further identifiers.
|
16 |
+
"""
|
17 |
+
|
18 |
+
behavior_id: str
|
19 |
+
brain_name: str
|
20 |
+
team_id: int
|
21 |
+
|
22 |
+
@staticmethod
|
23 |
+
def from_name_behavior_id(name_behavior_id: str) -> "BehaviorIdentifiers":
|
24 |
+
"""
|
25 |
+
Parses a name_behavior_id of the form name?team=0
|
26 |
+
into a BehaviorIdentifiers NamedTuple.
|
27 |
+
This allows you to access the brain name and team id of an agent
|
28 |
+
:param name_behavior_id: String of behavior params in HTTP format.
|
29 |
+
:returns: A BehaviorIdentifiers object.
|
30 |
+
"""
|
31 |
+
|
32 |
+
parsed = urlparse(name_behavior_id)
|
33 |
+
name = parsed.path
|
34 |
+
ids = parse_qs(parsed.query)
|
35 |
+
team_id: int = 0
|
36 |
+
if "team" in ids:
|
37 |
+
team_id = int(ids["team"][0])
|
38 |
+
return BehaviorIdentifiers(
|
39 |
+
behavior_id=name_behavior_id, brain_name=name, team_id=team_id
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
def create_name_behavior_id(name: str, team_id: int) -> str:
|
44 |
+
"""
|
45 |
+
Reconstructs fully qualified behavior name from name and team_id
|
46 |
+
:param name: brain name
|
47 |
+
:param team_id: team ID
|
48 |
+
:return: name_behavior_id
|
49 |
+
"""
|
50 |
+
return name + "?team=" + str(team_id)
|
51 |
+
|
52 |
+
|
53 |
+
def get_global_agent_id(worker_id: int, agent_id: AgentId) -> GlobalAgentId:
|
54 |
+
"""
|
55 |
+
Create an agent id that is unique across environment workers using the worker_id.
|
56 |
+
"""
|
57 |
+
return f"agent_{worker_id}-{agent_id}"
|
58 |
+
|
59 |
+
|
60 |
+
def get_global_group_id(worker_id: int, group_id: GroupId) -> GlobalGroupId:
|
61 |
+
"""
|
62 |
+
Create a group id that is unique across environment workers when using the worker_id.
|
63 |
+
"""
|
64 |
+
return f"group_{worker_id}-{group_id}"
|
MLPY/Lib/site-packages/mlagents/trainers/buffer.py
ADDED
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from collections.abc import MutableMapping
|
3 |
+
import enum
|
4 |
+
import itertools
|
5 |
+
from typing import BinaryIO, DefaultDict, List, Tuple, Union, Optional
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import h5py
|
9 |
+
|
10 |
+
from mlagents_envs.exception import UnityException
|
11 |
+
|
12 |
+
# Elements in the buffer can be np.ndarray, or in the case of teammate obs, actions, rewards,
|
13 |
+
# a List of np.ndarray. This is done so that we don't have duplicated np.ndarrays, only references.
|
14 |
+
BufferEntry = Union[np.ndarray, List[np.ndarray]]
|
15 |
+
|
16 |
+
|
17 |
+
class BufferException(UnityException):
|
18 |
+
"""
|
19 |
+
Related to errors with the Buffer.
|
20 |
+
"""
|
21 |
+
|
22 |
+
pass
|
23 |
+
|
24 |
+
|
25 |
+
class BufferKey(enum.Enum):
|
26 |
+
ACTION_MASK = "action_mask"
|
27 |
+
CONTINUOUS_ACTION = "continuous_action"
|
28 |
+
NEXT_CONT_ACTION = "next_continuous_action"
|
29 |
+
CONTINUOUS_LOG_PROBS = "continuous_log_probs"
|
30 |
+
DISCRETE_ACTION = "discrete_action"
|
31 |
+
NEXT_DISC_ACTION = "next_discrete_action"
|
32 |
+
DISCRETE_LOG_PROBS = "discrete_log_probs"
|
33 |
+
DONE = "done"
|
34 |
+
ENVIRONMENT_REWARDS = "environment_rewards"
|
35 |
+
MASKS = "masks"
|
36 |
+
MEMORY = "memory"
|
37 |
+
CRITIC_MEMORY = "critic_memory"
|
38 |
+
BASELINE_MEMORY = "poca_baseline_memory"
|
39 |
+
PREV_ACTION = "prev_action"
|
40 |
+
|
41 |
+
ADVANTAGES = "advantages"
|
42 |
+
DISCOUNTED_RETURNS = "discounted_returns"
|
43 |
+
|
44 |
+
GROUP_DONES = "group_dones"
|
45 |
+
GROUPMATE_REWARDS = "groupmate_reward"
|
46 |
+
GROUP_REWARD = "group_reward"
|
47 |
+
GROUP_CONTINUOUS_ACTION = "group_continuous_action"
|
48 |
+
GROUP_DISCRETE_ACTION = "group_discrete_aaction"
|
49 |
+
GROUP_NEXT_CONT_ACTION = "group_next_cont_action"
|
50 |
+
GROUP_NEXT_DISC_ACTION = "group_next_disc_action"
|
51 |
+
|
52 |
+
|
53 |
+
class ObservationKeyPrefix(enum.Enum):
|
54 |
+
OBSERVATION = "obs"
|
55 |
+
NEXT_OBSERVATION = "next_obs"
|
56 |
+
|
57 |
+
GROUP_OBSERVATION = "group_obs"
|
58 |
+
NEXT_GROUP_OBSERVATION = "next_group_obs"
|
59 |
+
|
60 |
+
|
61 |
+
class RewardSignalKeyPrefix(enum.Enum):
|
62 |
+
# Reward signals
|
63 |
+
REWARDS = "rewards"
|
64 |
+
VALUE_ESTIMATES = "value_estimates"
|
65 |
+
RETURNS = "returns"
|
66 |
+
ADVANTAGE = "advantage"
|
67 |
+
BASELINES = "baselines"
|
68 |
+
|
69 |
+
|
70 |
+
AgentBufferKey = Union[
|
71 |
+
BufferKey, Tuple[ObservationKeyPrefix, int], Tuple[RewardSignalKeyPrefix, str]
|
72 |
+
]
|
73 |
+
|
74 |
+
|
75 |
+
class RewardSignalUtil:
|
76 |
+
@staticmethod
|
77 |
+
def rewards_key(name: str) -> AgentBufferKey:
|
78 |
+
return RewardSignalKeyPrefix.REWARDS, name
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def value_estimates_key(name: str) -> AgentBufferKey:
|
82 |
+
return RewardSignalKeyPrefix.RETURNS, name
|
83 |
+
|
84 |
+
@staticmethod
|
85 |
+
def returns_key(name: str) -> AgentBufferKey:
|
86 |
+
return RewardSignalKeyPrefix.RETURNS, name
|
87 |
+
|
88 |
+
@staticmethod
|
89 |
+
def advantage_key(name: str) -> AgentBufferKey:
|
90 |
+
return RewardSignalKeyPrefix.ADVANTAGE, name
|
91 |
+
|
92 |
+
@staticmethod
|
93 |
+
def baseline_estimates_key(name: str) -> AgentBufferKey:
|
94 |
+
return RewardSignalKeyPrefix.BASELINES, name
|
95 |
+
|
96 |
+
|
97 |
+
class AgentBufferField(list):
|
98 |
+
"""
|
99 |
+
AgentBufferField is a list of numpy arrays, or List[np.ndarray] for group entries.
|
100 |
+
When an agent collects a field, you can add it to its AgentBufferField with the append method.
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, *args, **kwargs):
|
104 |
+
self.padding_value = 0
|
105 |
+
super().__init__(*args, **kwargs)
|
106 |
+
|
107 |
+
def __str__(self) -> str:
|
108 |
+
return f"AgentBufferField: {super().__str__()}"
|
109 |
+
|
110 |
+
def __getitem__(self, index):
|
111 |
+
return_data = super().__getitem__(index)
|
112 |
+
if isinstance(return_data, list):
|
113 |
+
return AgentBufferField(return_data)
|
114 |
+
else:
|
115 |
+
return return_data
|
116 |
+
|
117 |
+
@property
|
118 |
+
def contains_lists(self) -> bool:
|
119 |
+
"""
|
120 |
+
Checks whether this AgentBufferField contains List[np.ndarray].
|
121 |
+
"""
|
122 |
+
return len(self) > 0 and isinstance(self[0], list)
|
123 |
+
|
124 |
+
def append(self, element: BufferEntry, padding_value: float = 0.0) -> None:
|
125 |
+
"""
|
126 |
+
Adds an element to this list. Also lets you change the padding
|
127 |
+
type, so that it can be set on append (e.g. action_masks should
|
128 |
+
be padded with 1.)
|
129 |
+
:param element: The element to append to the list.
|
130 |
+
:param padding_value: The value used to pad when get_batch is called.
|
131 |
+
"""
|
132 |
+
super().append(element)
|
133 |
+
self.padding_value = padding_value
|
134 |
+
|
135 |
+
def set(self, data: List[BufferEntry]) -> None:
|
136 |
+
"""
|
137 |
+
Sets the list of BufferEntry to the input data
|
138 |
+
:param data: The BufferEntry list to be set.
|
139 |
+
"""
|
140 |
+
self[:] = data
|
141 |
+
|
142 |
+
def get_batch(
|
143 |
+
self,
|
144 |
+
batch_size: int = None,
|
145 |
+
training_length: Optional[int] = 1,
|
146 |
+
sequential: bool = True,
|
147 |
+
) -> List[BufferEntry]:
|
148 |
+
"""
|
149 |
+
Retrieve the last batch_size elements of length training_length
|
150 |
+
from the list of np.array
|
151 |
+
:param batch_size: The number of elements to retrieve. If None:
|
152 |
+
All elements will be retrieved.
|
153 |
+
:param training_length: The length of the sequence to be retrieved. If
|
154 |
+
None: only takes one element.
|
155 |
+
:param sequential: If true and training_length is not None: the elements
|
156 |
+
will not repeat in the sequence. [a,b,c,d,e] with training_length = 2 and
|
157 |
+
sequential=True gives [[0,a],[b,c],[d,e]]. If sequential=False gives
|
158 |
+
[[a,b],[b,c],[c,d],[d,e]]
|
159 |
+
"""
|
160 |
+
if training_length is None:
|
161 |
+
training_length = 1
|
162 |
+
if sequential:
|
163 |
+
# The sequences will not have overlapping elements (this involves padding)
|
164 |
+
leftover = len(self) % training_length
|
165 |
+
# leftover is the number of elements in the first sequence (this sequence might need 0 padding)
|
166 |
+
if batch_size is None:
|
167 |
+
# retrieve the maximum number of elements
|
168 |
+
batch_size = len(self) // training_length + 1 * (leftover != 0)
|
169 |
+
# The maximum number of sequences taken from a list of length len(self) without overlapping
|
170 |
+
# with padding is equal to batch_size
|
171 |
+
if batch_size > (len(self) // training_length + 1 * (leftover != 0)):
|
172 |
+
raise BufferException(
|
173 |
+
"The batch size and training length requested for get_batch where"
|
174 |
+
" too large given the current number of data points."
|
175 |
+
)
|
176 |
+
if batch_size * training_length > len(self):
|
177 |
+
if self.contains_lists:
|
178 |
+
padding = []
|
179 |
+
else:
|
180 |
+
# We want to duplicate the last value in the array, multiplied by the padding_value.
|
181 |
+
padding = np.array(self[-1], dtype=np.float32) * self.padding_value
|
182 |
+
return self[:] + [padding] * (training_length - leftover)
|
183 |
+
|
184 |
+
else:
|
185 |
+
return self[len(self) - batch_size * training_length :]
|
186 |
+
else:
|
187 |
+
# The sequences will have overlapping elements
|
188 |
+
if batch_size is None:
|
189 |
+
# retrieve the maximum number of elements
|
190 |
+
batch_size = len(self) - training_length + 1
|
191 |
+
# The number of sequences of length training_length taken from a list of len(self) elements
|
192 |
+
# with overlapping is equal to batch_size
|
193 |
+
if (len(self) - training_length + 1) < batch_size:
|
194 |
+
raise BufferException(
|
195 |
+
"The batch size and training length requested for get_batch where"
|
196 |
+
" too large given the current number of data points."
|
197 |
+
)
|
198 |
+
tmp_list: List[np.ndarray] = []
|
199 |
+
for end in range(len(self) - batch_size + 1, len(self) + 1):
|
200 |
+
tmp_list += self[end - training_length : end]
|
201 |
+
return tmp_list
|
202 |
+
|
203 |
+
def reset_field(self) -> None:
|
204 |
+
"""
|
205 |
+
Resets the AgentBufferField
|
206 |
+
"""
|
207 |
+
self[:] = []
|
208 |
+
|
209 |
+
def padded_to_batch(
|
210 |
+
self, pad_value: np.float = 0, dtype: np.dtype = np.float32
|
211 |
+
) -> Union[np.ndarray, List[np.ndarray]]:
|
212 |
+
"""
|
213 |
+
Converts this AgentBufferField (which is a List[BufferEntry]) into a numpy array
|
214 |
+
with first dimension equal to the length of this AgentBufferField. If this AgentBufferField
|
215 |
+
contains a List[List[BufferEntry]] (i.e., in the case of group observations), return a List
|
216 |
+
containing numpy arrays or tensors, of length equal to the maximum length of an entry. Missing
|
217 |
+
For entries with less than that length, the array will be padded with pad_value.
|
218 |
+
:param pad_value: Value to pad List AgentBufferFields, when there are less than the maximum
|
219 |
+
number of agents present.
|
220 |
+
:param dtype: Dtype of output numpy array.
|
221 |
+
:return: Numpy array or List of numpy arrays representing this AgentBufferField, where the first
|
222 |
+
dimension is equal to the length of the AgentBufferField.
|
223 |
+
"""
|
224 |
+
if len(self) > 0 and not isinstance(self[0], list):
|
225 |
+
return np.asanyarray(self, dtype=dtype)
|
226 |
+
|
227 |
+
shape = None
|
228 |
+
for _entry in self:
|
229 |
+
# _entry could be an empty list if there are no group agents in this
|
230 |
+
# step. Find the first non-empty list and use that shape.
|
231 |
+
if _entry:
|
232 |
+
shape = _entry[0].shape
|
233 |
+
break
|
234 |
+
# If there were no groupmate agents in the entire batch, return an empty List.
|
235 |
+
if shape is None:
|
236 |
+
return []
|
237 |
+
|
238 |
+
# Convert to numpy array while padding with 0's
|
239 |
+
new_list = list(
|
240 |
+
map(
|
241 |
+
lambda x: np.asanyarray(x, dtype=dtype),
|
242 |
+
itertools.zip_longest(*self, fillvalue=np.full(shape, pad_value)),
|
243 |
+
)
|
244 |
+
)
|
245 |
+
return new_list
|
246 |
+
|
247 |
+
def to_ndarray(self):
|
248 |
+
"""
|
249 |
+
Returns the AgentBufferField which is a list of numpy ndarrays (or List[np.ndarray]) as an ndarray.
|
250 |
+
"""
|
251 |
+
return np.array(self)
|
252 |
+
|
253 |
+
|
254 |
+
class AgentBuffer(MutableMapping):
|
255 |
+
"""
|
256 |
+
AgentBuffer contains a dictionary of AgentBufferFields. Each agent has his own AgentBuffer.
|
257 |
+
The keys correspond to the name of the field. Example: state, action
|
258 |
+
"""
|
259 |
+
|
260 |
+
# Whether or not to validate the types of keys at runtime
|
261 |
+
# This should be off for training, but enabled for testing
|
262 |
+
CHECK_KEY_TYPES_AT_RUNTIME = False
|
263 |
+
|
264 |
+
def __init__(self):
|
265 |
+
self.last_brain_info = None
|
266 |
+
self.last_take_action_outputs = None
|
267 |
+
self._fields: DefaultDict[AgentBufferKey, AgentBufferField] = defaultdict(
|
268 |
+
AgentBufferField
|
269 |
+
)
|
270 |
+
|
271 |
+
def __str__(self):
|
272 |
+
return ", ".join([f"'{k}' : {str(self[k])}" for k in self._fields.keys()])
|
273 |
+
|
274 |
+
def reset_agent(self) -> None:
|
275 |
+
"""
|
276 |
+
Resets the AgentBuffer
|
277 |
+
"""
|
278 |
+
for f in self._fields.values():
|
279 |
+
f.reset_field()
|
280 |
+
self.last_brain_info = None
|
281 |
+
self.last_take_action_outputs = None
|
282 |
+
|
283 |
+
@staticmethod
|
284 |
+
def _check_key(key):
|
285 |
+
if isinstance(key, BufferKey):
|
286 |
+
return
|
287 |
+
if isinstance(key, tuple):
|
288 |
+
key0, key1 = key
|
289 |
+
if isinstance(key0, ObservationKeyPrefix):
|
290 |
+
if isinstance(key1, int):
|
291 |
+
return
|
292 |
+
raise KeyError(f"{key} has type ({type(key0)}, {type(key1)})")
|
293 |
+
if isinstance(key0, RewardSignalKeyPrefix):
|
294 |
+
if isinstance(key1, str):
|
295 |
+
return
|
296 |
+
raise KeyError(f"{key} has type ({type(key0)}, {type(key1)})")
|
297 |
+
raise KeyError(f"{key} is a {type(key)}")
|
298 |
+
|
299 |
+
@staticmethod
|
300 |
+
def _encode_key(key: AgentBufferKey) -> str:
|
301 |
+
"""
|
302 |
+
Convert the key to a string representation so that it can be used for serialization.
|
303 |
+
"""
|
304 |
+
if isinstance(key, BufferKey):
|
305 |
+
return key.value
|
306 |
+
prefix, suffix = key
|
307 |
+
return f"{prefix.value}:{suffix}"
|
308 |
+
|
309 |
+
@staticmethod
|
310 |
+
def _decode_key(encoded_key: str) -> AgentBufferKey:
|
311 |
+
"""
|
312 |
+
Convert the string representation back to a key after serialization.
|
313 |
+
"""
|
314 |
+
# Simple case: convert the string directly to a BufferKey
|
315 |
+
try:
|
316 |
+
return BufferKey(encoded_key)
|
317 |
+
except ValueError:
|
318 |
+
pass
|
319 |
+
|
320 |
+
# Not a simple key, so split into two parts
|
321 |
+
prefix_str, _, suffix_str = encoded_key.partition(":")
|
322 |
+
|
323 |
+
# See if it's an ObservationKeyPrefix first
|
324 |
+
try:
|
325 |
+
return ObservationKeyPrefix(prefix_str), int(suffix_str)
|
326 |
+
except ValueError:
|
327 |
+
pass
|
328 |
+
|
329 |
+
# If not, it had better be a RewardSignalKeyPrefix
|
330 |
+
try:
|
331 |
+
return RewardSignalKeyPrefix(prefix_str), suffix_str
|
332 |
+
except ValueError:
|
333 |
+
raise ValueError(f"Unable to convert {encoded_key} to an AgentBufferKey")
|
334 |
+
|
335 |
+
def __getitem__(self, key: AgentBufferKey) -> AgentBufferField:
|
336 |
+
if self.CHECK_KEY_TYPES_AT_RUNTIME:
|
337 |
+
self._check_key(key)
|
338 |
+
return self._fields[key]
|
339 |
+
|
340 |
+
def __setitem__(self, key: AgentBufferKey, value: AgentBufferField) -> None:
|
341 |
+
if self.CHECK_KEY_TYPES_AT_RUNTIME:
|
342 |
+
self._check_key(key)
|
343 |
+
self._fields[key] = value
|
344 |
+
|
345 |
+
def __delitem__(self, key: AgentBufferKey) -> None:
|
346 |
+
if self.CHECK_KEY_TYPES_AT_RUNTIME:
|
347 |
+
self._check_key(key)
|
348 |
+
self._fields.__delitem__(key)
|
349 |
+
|
350 |
+
def __iter__(self):
|
351 |
+
return self._fields.__iter__()
|
352 |
+
|
353 |
+
def __len__(self) -> int:
|
354 |
+
return self._fields.__len__()
|
355 |
+
|
356 |
+
def __contains__(self, key):
|
357 |
+
if self.CHECK_KEY_TYPES_AT_RUNTIME:
|
358 |
+
self._check_key(key)
|
359 |
+
return self._fields.__contains__(key)
|
360 |
+
|
361 |
+
def check_length(self, key_list: List[AgentBufferKey]) -> bool:
|
362 |
+
"""
|
363 |
+
Some methods will require that some fields have the same length.
|
364 |
+
check_length will return true if the fields in key_list
|
365 |
+
have the same length.
|
366 |
+
:param key_list: The fields which length will be compared
|
367 |
+
"""
|
368 |
+
if self.CHECK_KEY_TYPES_AT_RUNTIME:
|
369 |
+
for k in key_list:
|
370 |
+
self._check_key(k)
|
371 |
+
|
372 |
+
if len(key_list) < 2:
|
373 |
+
return True
|
374 |
+
length = None
|
375 |
+
for key in key_list:
|
376 |
+
if key not in self._fields:
|
377 |
+
return False
|
378 |
+
if (length is not None) and (length != len(self[key])):
|
379 |
+
return False
|
380 |
+
length = len(self[key])
|
381 |
+
return True
|
382 |
+
|
383 |
+
def shuffle(
|
384 |
+
self, sequence_length: int, key_list: List[AgentBufferKey] = None
|
385 |
+
) -> None:
|
386 |
+
"""
|
387 |
+
Shuffles the fields in key_list in a consistent way: The reordering will
|
388 |
+
be the same across fields.
|
389 |
+
:param key_list: The fields that must be shuffled.
|
390 |
+
"""
|
391 |
+
if key_list is None:
|
392 |
+
key_list = list(self._fields.keys())
|
393 |
+
if not self.check_length(key_list):
|
394 |
+
raise BufferException(
|
395 |
+
"Unable to shuffle if the fields are not of same length"
|
396 |
+
)
|
397 |
+
s = np.arange(len(self[key_list[0]]) // sequence_length)
|
398 |
+
np.random.shuffle(s)
|
399 |
+
for key in key_list:
|
400 |
+
buffer_field = self[key]
|
401 |
+
tmp: List[np.ndarray] = []
|
402 |
+
for i in s:
|
403 |
+
tmp += buffer_field[i * sequence_length : (i + 1) * sequence_length]
|
404 |
+
buffer_field.set(tmp)
|
405 |
+
|
406 |
+
def make_mini_batch(self, start: int, end: int) -> "AgentBuffer":
|
407 |
+
"""
|
408 |
+
Creates a mini-batch from buffer.
|
409 |
+
:param start: Starting index of buffer.
|
410 |
+
:param end: Ending index of buffer.
|
411 |
+
:return: Dict of mini batch.
|
412 |
+
"""
|
413 |
+
mini_batch = AgentBuffer()
|
414 |
+
for key, field in self._fields.items():
|
415 |
+
# slicing AgentBufferField returns a List[Any}
|
416 |
+
mini_batch[key] = field[start:end] # type: ignore
|
417 |
+
return mini_batch
|
418 |
+
|
419 |
+
def sample_mini_batch(
|
420 |
+
self, batch_size: int, sequence_length: int = 1
|
421 |
+
) -> "AgentBuffer":
|
422 |
+
"""
|
423 |
+
Creates a mini-batch from a random start and end.
|
424 |
+
:param batch_size: number of elements to withdraw.
|
425 |
+
:param sequence_length: Length of sequences to sample.
|
426 |
+
Number of sequences to sample will be batch_size/sequence_length.
|
427 |
+
"""
|
428 |
+
num_seq_to_sample = batch_size // sequence_length
|
429 |
+
mini_batch = AgentBuffer()
|
430 |
+
buff_len = self.num_experiences
|
431 |
+
num_sequences_in_buffer = buff_len // sequence_length
|
432 |
+
start_idxes = (
|
433 |
+
np.random.randint(num_sequences_in_buffer, size=num_seq_to_sample)
|
434 |
+
* sequence_length
|
435 |
+
) # Sample random sequence starts
|
436 |
+
for key in self:
|
437 |
+
buffer_field = self[key]
|
438 |
+
mb_list = (buffer_field[i : i + sequence_length] for i in start_idxes)
|
439 |
+
# See comparison of ways to make a list from a list of lists here:
|
440 |
+
# https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-list-of-lists
|
441 |
+
mini_batch[key].set(list(itertools.chain.from_iterable(mb_list)))
|
442 |
+
return mini_batch
|
443 |
+
|
444 |
+
def save_to_file(self, file_object: BinaryIO) -> None:
|
445 |
+
"""
|
446 |
+
Saves the AgentBuffer to a file-like object.
|
447 |
+
"""
|
448 |
+
with h5py.File(file_object, "w") as write_file:
|
449 |
+
for key, data in self.items():
|
450 |
+
write_file.create_dataset(
|
451 |
+
self._encode_key(key), data=data, dtype="f", compression="gzip"
|
452 |
+
)
|
453 |
+
|
454 |
+
def load_from_file(self, file_object: BinaryIO) -> None:
|
455 |
+
"""
|
456 |
+
Loads the AgentBuffer from a file-like object.
|
457 |
+
"""
|
458 |
+
with h5py.File(file_object, "r") as read_file:
|
459 |
+
for key in list(read_file.keys()):
|
460 |
+
decoded_key = self._decode_key(key)
|
461 |
+
self[decoded_key] = AgentBufferField()
|
462 |
+
# extend() will convert the numpy array's first dimension into list
|
463 |
+
self[decoded_key].extend(read_file[key][()])
|
464 |
+
|
465 |
+
def truncate(self, max_length: int, sequence_length: int = 1) -> None:
|
466 |
+
"""
|
467 |
+
Truncates the buffer to a certain length.
|
468 |
+
|
469 |
+
This can be slow for large buffers. We compensate by cutting further than we need to, so that
|
470 |
+
we're not truncating at each update. Note that we must truncate an integer number of sequence_lengths
|
471 |
+
param: max_length: The length at which to truncate the buffer.
|
472 |
+
"""
|
473 |
+
current_length = self.num_experiences
|
474 |
+
# make max_length an integer number of sequence_lengths
|
475 |
+
max_length -= max_length % sequence_length
|
476 |
+
if current_length > max_length:
|
477 |
+
for _key in self.keys():
|
478 |
+
self[_key][:] = self[_key][current_length - max_length :]
|
479 |
+
|
480 |
+
def resequence_and_append(
|
481 |
+
self,
|
482 |
+
target_buffer: "AgentBuffer",
|
483 |
+
key_list: List[AgentBufferKey] = None,
|
484 |
+
batch_size: int = None,
|
485 |
+
training_length: int = None,
|
486 |
+
) -> None:
|
487 |
+
"""
|
488 |
+
Takes in a batch size and training length (sequence length), and appends this AgentBuffer to target_buffer
|
489 |
+
properly padded for LSTM use. Optionally, use key_list to restrict which fields are inserted into the new
|
490 |
+
buffer.
|
491 |
+
:param target_buffer: The buffer which to append the samples to.
|
492 |
+
:param key_list: The fields that must be added. If None: all fields will be appended.
|
493 |
+
:param batch_size: The number of elements that must be appended. If None: All of them will be.
|
494 |
+
:param training_length: The length of the samples that must be appended. If None: only takes one element.
|
495 |
+
"""
|
496 |
+
if key_list is None:
|
497 |
+
key_list = list(self.keys())
|
498 |
+
if not self.check_length(key_list):
|
499 |
+
raise BufferException(
|
500 |
+
f"The length of the fields {key_list} were not of same length"
|
501 |
+
)
|
502 |
+
for field_key in key_list:
|
503 |
+
target_buffer[field_key].extend(
|
504 |
+
self[field_key].get_batch(
|
505 |
+
batch_size=batch_size, training_length=training_length
|
506 |
+
)
|
507 |
+
)
|
508 |
+
|
509 |
+
@property
|
510 |
+
def num_experiences(self) -> int:
|
511 |
+
"""
|
512 |
+
The number of agent experiences in the AgentBuffer, i.e. the length of the buffer.
|
513 |
+
|
514 |
+
An experience consists of one element across all of the fields of this AgentBuffer.
|
515 |
+
Note that these all have to be the same length, otherwise shuffle and append_to_update_buffer
|
516 |
+
will fail.
|
517 |
+
"""
|
518 |
+
if self.values():
|
519 |
+
return len(next(iter(self.values())))
|
520 |
+
else:
|
521 |
+
return 0
|
MLPY/Lib/site-packages/mlagents/trainers/cli_utils.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Set, Dict, Any, TextIO
|
2 |
+
import os
|
3 |
+
import yaml
|
4 |
+
from mlagents.trainers.exception import TrainerConfigError
|
5 |
+
from mlagents_envs.environment import UnityEnvironment
|
6 |
+
import argparse
|
7 |
+
from mlagents_envs import logging_util
|
8 |
+
|
9 |
+
logger = logging_util.get_logger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
class RaiseRemovedWarning(argparse.Action):
|
13 |
+
"""
|
14 |
+
Internal custom Action to raise warning when argument is called.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, nargs=0, **kwargs):
|
18 |
+
super().__init__(nargs=nargs, **kwargs)
|
19 |
+
|
20 |
+
def __call__(self, arg_parser, namespace, values, option_string=None):
|
21 |
+
logger.warning(f"The command line argument {option_string} was removed.")
|
22 |
+
|
23 |
+
|
24 |
+
class DetectDefault(argparse.Action):
|
25 |
+
"""
|
26 |
+
Internal custom Action to help detect arguments that aren't default.
|
27 |
+
"""
|
28 |
+
|
29 |
+
non_default_args: Set[str] = set()
|
30 |
+
|
31 |
+
def __call__(self, arg_parser, namespace, values, option_string=None):
|
32 |
+
setattr(namespace, self.dest, values)
|
33 |
+
DetectDefault.non_default_args.add(self.dest)
|
34 |
+
|
35 |
+
|
36 |
+
class DetectDefaultStoreTrue(DetectDefault):
|
37 |
+
"""
|
38 |
+
Internal class to help detect arguments that aren't default.
|
39 |
+
Used for store_true arguments.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self, nargs=0, **kwargs):
|
43 |
+
super().__init__(nargs=nargs, **kwargs)
|
44 |
+
|
45 |
+
def __call__(self, arg_parser, namespace, values, option_string=None):
|
46 |
+
super().__call__(arg_parser, namespace, True, option_string)
|
47 |
+
|
48 |
+
|
49 |
+
class StoreConfigFile(argparse.Action):
|
50 |
+
"""
|
51 |
+
Custom Action to store the config file location not as part of the CLI args.
|
52 |
+
This is because we want to maintain an equivalence between the config file's
|
53 |
+
contents and the args themselves.
|
54 |
+
"""
|
55 |
+
|
56 |
+
trainer_config_path: str
|
57 |
+
|
58 |
+
def __call__(self, arg_parser, namespace, values, option_string=None):
|
59 |
+
delattr(namespace, self.dest)
|
60 |
+
StoreConfigFile.trainer_config_path = values
|
61 |
+
|
62 |
+
|
63 |
+
def _create_parser() -> argparse.ArgumentParser:
|
64 |
+
argparser = argparse.ArgumentParser(
|
65 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
66 |
+
)
|
67 |
+
argparser.add_argument(
|
68 |
+
"trainer_config_path", action=StoreConfigFile, nargs="?", default=None
|
69 |
+
)
|
70 |
+
argparser.add_argument(
|
71 |
+
"--env",
|
72 |
+
default=None,
|
73 |
+
dest="env_path",
|
74 |
+
help="Path to the Unity executable to train",
|
75 |
+
action=DetectDefault,
|
76 |
+
)
|
77 |
+
argparser.add_argument(
|
78 |
+
"--load",
|
79 |
+
default=False,
|
80 |
+
dest="load_model",
|
81 |
+
action=DetectDefaultStoreTrue,
|
82 |
+
help=argparse.SUPPRESS, # Deprecated but still usable for now.
|
83 |
+
)
|
84 |
+
argparser.add_argument(
|
85 |
+
"--resume",
|
86 |
+
default=False,
|
87 |
+
dest="resume",
|
88 |
+
action=DetectDefaultStoreTrue,
|
89 |
+
help="Whether to resume training from a checkpoint. Specify a --run-id to use this option. "
|
90 |
+
"If set, the training code loads an already trained model to initialize the neural network "
|
91 |
+
"before resuming training. This option is only valid when the models exist, and have the same "
|
92 |
+
"behavior names as the current agents in your scene.",
|
93 |
+
)
|
94 |
+
argparser.add_argument(
|
95 |
+
"--deterministic",
|
96 |
+
default=False,
|
97 |
+
dest="deterministic",
|
98 |
+
action=DetectDefaultStoreTrue,
|
99 |
+
help="Whether to select actions deterministically in policy. `dist.mean` for continuous action "
|
100 |
+
"space, and `dist.argmax` for deterministic action space ",
|
101 |
+
)
|
102 |
+
argparser.add_argument(
|
103 |
+
"--force",
|
104 |
+
default=False,
|
105 |
+
dest="force",
|
106 |
+
action=DetectDefaultStoreTrue,
|
107 |
+
help="Whether to force-overwrite this run-id's existing summary and model data. (Without "
|
108 |
+
"this flag, attempting to train a model with a run-id that has been used before will throw "
|
109 |
+
"an error.",
|
110 |
+
)
|
111 |
+
argparser.add_argument(
|
112 |
+
"--run-id",
|
113 |
+
default="ppo",
|
114 |
+
help="The identifier for the training run. This identifier is used to name the "
|
115 |
+
"subdirectories in which the trained model and summary statistics are saved as well "
|
116 |
+
"as the saved model itself. If you use TensorBoard to view the training statistics, "
|
117 |
+
"always set a unique run-id for each training run. (The statistics for all runs with the "
|
118 |
+
"same id are combined as if they were produced by a the same session.)",
|
119 |
+
action=DetectDefault,
|
120 |
+
)
|
121 |
+
argparser.add_argument(
|
122 |
+
"--initialize-from",
|
123 |
+
metavar="RUN_ID",
|
124 |
+
default=None,
|
125 |
+
help="Specify a previously saved run ID from which to initialize the model from. "
|
126 |
+
"This can be used, for instance, to fine-tune an existing model on a new environment. "
|
127 |
+
"Note that the previously saved models must have the same behavior parameters as your "
|
128 |
+
"current environment.",
|
129 |
+
action=DetectDefault,
|
130 |
+
)
|
131 |
+
argparser.add_argument(
|
132 |
+
"--seed",
|
133 |
+
default=-1,
|
134 |
+
type=int,
|
135 |
+
help="A number to use as a seed for the random number generator used by the training code",
|
136 |
+
action=DetectDefault,
|
137 |
+
)
|
138 |
+
argparser.add_argument(
|
139 |
+
"--train",
|
140 |
+
default=False,
|
141 |
+
dest="train_model",
|
142 |
+
action=DetectDefaultStoreTrue,
|
143 |
+
help=argparse.SUPPRESS,
|
144 |
+
)
|
145 |
+
argparser.add_argument(
|
146 |
+
"--inference",
|
147 |
+
default=False,
|
148 |
+
dest="inference",
|
149 |
+
action=DetectDefaultStoreTrue,
|
150 |
+
help="Whether to run in Python inference mode (i.e. no training). Use with --resume to load "
|
151 |
+
"a model trained with an existing run ID.",
|
152 |
+
)
|
153 |
+
argparser.add_argument(
|
154 |
+
"--base-port",
|
155 |
+
default=UnityEnvironment.BASE_ENVIRONMENT_PORT,
|
156 |
+
type=int,
|
157 |
+
help="The starting port for environment communication. Each concurrent Unity environment "
|
158 |
+
"instance will get assigned a port sequentially, starting from the base-port. Each instance "
|
159 |
+
"will use the port (base_port + worker_id), where the worker_id is sequential IDs given to "
|
160 |
+
"each instance from 0 to (num_envs - 1). Note that when training using the Editor rather "
|
161 |
+
"than an executable, the base port will be ignored.",
|
162 |
+
action=DetectDefault,
|
163 |
+
)
|
164 |
+
argparser.add_argument(
|
165 |
+
"--num-envs",
|
166 |
+
default=1,
|
167 |
+
type=int,
|
168 |
+
help="The number of concurrent Unity environment instances to collect experiences "
|
169 |
+
"from when training",
|
170 |
+
action=DetectDefault,
|
171 |
+
)
|
172 |
+
|
173 |
+
argparser.add_argument(
|
174 |
+
"--num-areas",
|
175 |
+
default=1,
|
176 |
+
type=int,
|
177 |
+
help="The number of parallel training areas in each Unity environment instance.",
|
178 |
+
action=DetectDefault,
|
179 |
+
)
|
180 |
+
|
181 |
+
argparser.add_argument(
|
182 |
+
"--debug",
|
183 |
+
default=False,
|
184 |
+
action=DetectDefaultStoreTrue,
|
185 |
+
help="Whether to enable debug-level logging for some parts of the code",
|
186 |
+
)
|
187 |
+
argparser.add_argument(
|
188 |
+
"--env-args",
|
189 |
+
default=None,
|
190 |
+
nargs=argparse.REMAINDER,
|
191 |
+
help="Arguments passed to the Unity executable. Be aware that the standalone build will also "
|
192 |
+
"process these as Unity Command Line Arguments. You should choose different argument names if "
|
193 |
+
"you want to create environment-specific arguments. All arguments after this flag will be "
|
194 |
+
"passed to the executable.",
|
195 |
+
action=DetectDefault,
|
196 |
+
)
|
197 |
+
argparser.add_argument(
|
198 |
+
"--max-lifetime-restarts",
|
199 |
+
default=10,
|
200 |
+
help="The max number of times a single Unity executable can crash over its lifetime before ml-agents exits. "
|
201 |
+
"Can be set to -1 if no limit is desired.",
|
202 |
+
action=DetectDefault,
|
203 |
+
)
|
204 |
+
argparser.add_argument(
|
205 |
+
"--restarts-rate-limit-n",
|
206 |
+
default=1,
|
207 |
+
help="The maximum number of times a single Unity executable can crash over a period of time (period set in "
|
208 |
+
"restarts-rate-limit-period-s). Can be set to -1 to not use rate limiting with restarts.",
|
209 |
+
action=DetectDefault,
|
210 |
+
)
|
211 |
+
argparser.add_argument(
|
212 |
+
"--restarts-rate-limit-period-s",
|
213 |
+
default=60,
|
214 |
+
help="The period of time --restarts-rate-limit-n applies to.",
|
215 |
+
action=DetectDefault,
|
216 |
+
)
|
217 |
+
argparser.add_argument(
|
218 |
+
"--torch",
|
219 |
+
default=False,
|
220 |
+
action=RaiseRemovedWarning,
|
221 |
+
help="(Removed) Use the PyTorch framework.",
|
222 |
+
)
|
223 |
+
argparser.add_argument(
|
224 |
+
"--tensorflow",
|
225 |
+
default=False,
|
226 |
+
action=RaiseRemovedWarning,
|
227 |
+
help="(Removed) Use the TensorFlow framework.",
|
228 |
+
)
|
229 |
+
argparser.add_argument(
|
230 |
+
"--results-dir",
|
231 |
+
default="results",
|
232 |
+
action=DetectDefault,
|
233 |
+
help="Results base directory",
|
234 |
+
)
|
235 |
+
|
236 |
+
eng_conf = argparser.add_argument_group(title="Engine Configuration")
|
237 |
+
eng_conf.add_argument(
|
238 |
+
"--width",
|
239 |
+
default=84,
|
240 |
+
type=int,
|
241 |
+
help="The width of the executable window of the environment(s) in pixels "
|
242 |
+
"(ignored for editor training).",
|
243 |
+
action=DetectDefault,
|
244 |
+
)
|
245 |
+
eng_conf.add_argument(
|
246 |
+
"--height",
|
247 |
+
default=84,
|
248 |
+
type=int,
|
249 |
+
help="The height of the executable window of the environment(s) in pixels "
|
250 |
+
"(ignored for editor training)",
|
251 |
+
action=DetectDefault,
|
252 |
+
)
|
253 |
+
eng_conf.add_argument(
|
254 |
+
"--quality-level",
|
255 |
+
default=5,
|
256 |
+
type=int,
|
257 |
+
help="The quality level of the environment(s). Equivalent to calling "
|
258 |
+
"QualitySettings.SetQualityLevel in Unity.",
|
259 |
+
action=DetectDefault,
|
260 |
+
)
|
261 |
+
eng_conf.add_argument(
|
262 |
+
"--time-scale",
|
263 |
+
default=20,
|
264 |
+
type=float,
|
265 |
+
help="The time scale of the Unity environment(s). Equivalent to setting "
|
266 |
+
"Time.timeScale in Unity.",
|
267 |
+
action=DetectDefault,
|
268 |
+
)
|
269 |
+
eng_conf.add_argument(
|
270 |
+
"--target-frame-rate",
|
271 |
+
default=-1,
|
272 |
+
type=int,
|
273 |
+
help="The target frame rate of the Unity environment(s). Equivalent to setting "
|
274 |
+
"Application.targetFrameRate in Unity.",
|
275 |
+
action=DetectDefault,
|
276 |
+
)
|
277 |
+
eng_conf.add_argument(
|
278 |
+
"--capture-frame-rate",
|
279 |
+
default=60,
|
280 |
+
type=int,
|
281 |
+
help="The capture frame rate of the Unity environment(s). Equivalent to setting "
|
282 |
+
"Time.captureFramerate in Unity.",
|
283 |
+
action=DetectDefault,
|
284 |
+
)
|
285 |
+
eng_conf.add_argument(
|
286 |
+
"--no-graphics",
|
287 |
+
default=False,
|
288 |
+
action=DetectDefaultStoreTrue,
|
289 |
+
help="Whether to run the Unity executable in no-graphics mode (i.e. without initializing "
|
290 |
+
"the graphics driver. Use this only if your agents don't use visual observations.",
|
291 |
+
)
|
292 |
+
|
293 |
+
torch_conf = argparser.add_argument_group(title="Torch Configuration")
|
294 |
+
torch_conf.add_argument(
|
295 |
+
"--torch-device",
|
296 |
+
default=None,
|
297 |
+
dest="device",
|
298 |
+
action=DetectDefault,
|
299 |
+
help='Settings for the default torch.device used in training, for example, "cpu", "cuda", or "cuda:0"',
|
300 |
+
)
|
301 |
+
return argparser
|
302 |
+
|
303 |
+
|
304 |
+
def load_config(config_path: str) -> Dict[str, Any]:
|
305 |
+
try:
|
306 |
+
with open(config_path) as data_file:
|
307 |
+
return _load_config(data_file)
|
308 |
+
except OSError:
|
309 |
+
abs_path = os.path.abspath(config_path)
|
310 |
+
raise TrainerConfigError(f"Config file could not be found at {abs_path}.")
|
311 |
+
except UnicodeDecodeError:
|
312 |
+
raise TrainerConfigError(
|
313 |
+
f"There was an error decoding Config file from {config_path}. "
|
314 |
+
f"Make sure your file is save using UTF-8"
|
315 |
+
)
|
316 |
+
|
317 |
+
|
318 |
+
def _load_config(fp: TextIO) -> Dict[str, Any]:
|
319 |
+
"""
|
320 |
+
Load the yaml config from the file-like object.
|
321 |
+
"""
|
322 |
+
try:
|
323 |
+
return yaml.safe_load(fp)
|
324 |
+
except yaml.parser.ParserError as e:
|
325 |
+
raise TrainerConfigError(
|
326 |
+
"Error parsing yaml file. Please check for formatting errors. "
|
327 |
+
"A tool such as http://www.yamllint.com/ can be helpful with this."
|
328 |
+
) from e
|
329 |
+
|
330 |
+
|
331 |
+
parser = _create_parser()
|
MLPY/Lib/site-packages/mlagents/trainers/demo_loader.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List, Tuple
|
3 |
+
import numpy as np
|
4 |
+
from mlagents.trainers.buffer import AgentBuffer, BufferKey
|
5 |
+
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (
|
6 |
+
AgentInfoActionPairProto,
|
7 |
+
)
|
8 |
+
from mlagents.trainers.trajectory import ObsUtil
|
9 |
+
from mlagents_envs.rpc_utils import behavior_spec_from_proto, steps_from_proto
|
10 |
+
from mlagents_envs.base_env import BehaviorSpec
|
11 |
+
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
|
12 |
+
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
|
13 |
+
DemonstrationMetaProto,
|
14 |
+
)
|
15 |
+
from mlagents_envs.timers import timed, hierarchical_timer
|
16 |
+
from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore
|
17 |
+
from google.protobuf.internal.encoder import _EncodeVarint # type: ignore
|
18 |
+
|
19 |
+
|
20 |
+
INITIAL_POS = 33
|
21 |
+
SUPPORTED_DEMONSTRATION_VERSIONS = frozenset([0, 1])
|
22 |
+
|
23 |
+
|
24 |
+
@timed
|
25 |
+
def make_demo_buffer(
|
26 |
+
pair_infos: List[AgentInfoActionPairProto],
|
27 |
+
behavior_spec: BehaviorSpec,
|
28 |
+
sequence_length: int,
|
29 |
+
) -> AgentBuffer:
|
30 |
+
# Create and populate buffer using experiences
|
31 |
+
demo_raw_buffer = AgentBuffer()
|
32 |
+
demo_processed_buffer = AgentBuffer()
|
33 |
+
for idx, current_pair_info in enumerate(pair_infos):
|
34 |
+
if idx > len(pair_infos) - 2:
|
35 |
+
break
|
36 |
+
next_pair_info = pair_infos[idx + 1]
|
37 |
+
current_decision_step, current_terminal_step = steps_from_proto(
|
38 |
+
[current_pair_info.agent_info], behavior_spec
|
39 |
+
)
|
40 |
+
next_decision_step, next_terminal_step = steps_from_proto(
|
41 |
+
[next_pair_info.agent_info], behavior_spec
|
42 |
+
)
|
43 |
+
previous_action = (
|
44 |
+
np.array(
|
45 |
+
pair_infos[idx].action_info.vector_actions_deprecated, dtype=np.float32
|
46 |
+
)
|
47 |
+
* 0
|
48 |
+
)
|
49 |
+
if idx > 0:
|
50 |
+
previous_action = np.array(
|
51 |
+
pair_infos[idx - 1].action_info.vector_actions_deprecated,
|
52 |
+
dtype=np.float32,
|
53 |
+
)
|
54 |
+
|
55 |
+
next_done = len(next_terminal_step) == 1
|
56 |
+
next_reward = 0
|
57 |
+
if len(next_terminal_step) == 1:
|
58 |
+
next_reward = next_terminal_step.reward[0]
|
59 |
+
else:
|
60 |
+
next_reward = next_decision_step.reward[0]
|
61 |
+
current_obs = None
|
62 |
+
if len(current_terminal_step) == 1:
|
63 |
+
current_obs = list(current_terminal_step.values())[0].obs
|
64 |
+
else:
|
65 |
+
current_obs = list(current_decision_step.values())[0].obs
|
66 |
+
|
67 |
+
demo_raw_buffer[BufferKey.DONE].append(next_done)
|
68 |
+
demo_raw_buffer[BufferKey.ENVIRONMENT_REWARDS].append(next_reward)
|
69 |
+
for i, obs in enumerate(current_obs):
|
70 |
+
demo_raw_buffer[ObsUtil.get_name_at(i)].append(obs)
|
71 |
+
if (
|
72 |
+
len(current_pair_info.action_info.continuous_actions) == 0
|
73 |
+
and len(current_pair_info.action_info.discrete_actions) == 0
|
74 |
+
):
|
75 |
+
if behavior_spec.action_spec.continuous_size > 0:
|
76 |
+
demo_raw_buffer[BufferKey.CONTINUOUS_ACTION].append(
|
77 |
+
current_pair_info.action_info.vector_actions_deprecated
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
demo_raw_buffer[BufferKey.DISCRETE_ACTION].append(
|
81 |
+
current_pair_info.action_info.vector_actions_deprecated
|
82 |
+
)
|
83 |
+
else:
|
84 |
+
if behavior_spec.action_spec.continuous_size > 0:
|
85 |
+
demo_raw_buffer[BufferKey.CONTINUOUS_ACTION].append(
|
86 |
+
current_pair_info.action_info.continuous_actions
|
87 |
+
)
|
88 |
+
if behavior_spec.action_spec.discrete_size > 0:
|
89 |
+
demo_raw_buffer[BufferKey.DISCRETE_ACTION].append(
|
90 |
+
current_pair_info.action_info.discrete_actions
|
91 |
+
)
|
92 |
+
demo_raw_buffer[BufferKey.PREV_ACTION].append(previous_action)
|
93 |
+
if next_done:
|
94 |
+
demo_raw_buffer.resequence_and_append(
|
95 |
+
demo_processed_buffer, batch_size=None, training_length=sequence_length
|
96 |
+
)
|
97 |
+
demo_raw_buffer.reset_agent()
|
98 |
+
demo_raw_buffer.resequence_and_append(
|
99 |
+
demo_processed_buffer, batch_size=None, training_length=sequence_length
|
100 |
+
)
|
101 |
+
return demo_processed_buffer
|
102 |
+
|
103 |
+
|
104 |
+
@timed
|
105 |
+
def demo_to_buffer(
|
106 |
+
file_path: str, sequence_length: int, expected_behavior_spec: BehaviorSpec = None
|
107 |
+
) -> Tuple[BehaviorSpec, AgentBuffer]:
|
108 |
+
"""
|
109 |
+
Loads demonstration file and uses it to fill training buffer.
|
110 |
+
:param file_path: Location of demonstration file (.demo).
|
111 |
+
:param sequence_length: Length of trajectories to fill buffer.
|
112 |
+
:return:
|
113 |
+
"""
|
114 |
+
behavior_spec, info_action_pair, _ = load_demonstration(file_path)
|
115 |
+
demo_buffer = make_demo_buffer(info_action_pair, behavior_spec, sequence_length)
|
116 |
+
if expected_behavior_spec:
|
117 |
+
# check action dimensions in demonstration match
|
118 |
+
if behavior_spec.action_spec != expected_behavior_spec.action_spec:
|
119 |
+
raise RuntimeError(
|
120 |
+
"The actions {} in demonstration do not match the policy's {}.".format(
|
121 |
+
behavior_spec.action_spec, expected_behavior_spec.action_spec
|
122 |
+
)
|
123 |
+
)
|
124 |
+
# check observations match
|
125 |
+
if len(behavior_spec.observation_specs) != len(
|
126 |
+
expected_behavior_spec.observation_specs
|
127 |
+
):
|
128 |
+
raise RuntimeError(
|
129 |
+
"The demonstrations do not have the same number of observations as the policy."
|
130 |
+
)
|
131 |
+
else:
|
132 |
+
for i, (demo_obs, policy_obs) in enumerate(
|
133 |
+
zip(
|
134 |
+
behavior_spec.observation_specs,
|
135 |
+
expected_behavior_spec.observation_specs,
|
136 |
+
)
|
137 |
+
):
|
138 |
+
if demo_obs.shape != policy_obs.shape:
|
139 |
+
raise RuntimeError(
|
140 |
+
f"The shape {demo_obs} for observation {i} in demonstration \
|
141 |
+
do not match the policy's {policy_obs}."
|
142 |
+
)
|
143 |
+
return behavior_spec, demo_buffer
|
144 |
+
|
145 |
+
|
146 |
+
def get_demo_files(path: str) -> List[str]:
|
147 |
+
"""
|
148 |
+
Retrieves the demonstration file(s) from a path.
|
149 |
+
:param path: Path of demonstration file or directory.
|
150 |
+
:return: List of demonstration files
|
151 |
+
|
152 |
+
Raises errors if |path| is invalid.
|
153 |
+
"""
|
154 |
+
if os.path.isfile(path):
|
155 |
+
if not path.endswith(".demo"):
|
156 |
+
raise ValueError("The path provided is not a '.demo' file.")
|
157 |
+
return [path]
|
158 |
+
elif os.path.isdir(path):
|
159 |
+
paths = [
|
160 |
+
os.path.join(path, name)
|
161 |
+
for name in os.listdir(path)
|
162 |
+
if name.endswith(".demo")
|
163 |
+
]
|
164 |
+
if not paths:
|
165 |
+
raise ValueError("There are no '.demo' files in the provided directory.")
|
166 |
+
return paths
|
167 |
+
else:
|
168 |
+
raise FileNotFoundError(
|
169 |
+
f"The demonstration file or directory {path} does not exist."
|
170 |
+
)
|
171 |
+
|
172 |
+
|
173 |
+
@timed
|
174 |
+
def load_demonstration(
|
175 |
+
file_path: str,
|
176 |
+
) -> Tuple[BehaviorSpec, List[AgentInfoActionPairProto], int]:
|
177 |
+
"""
|
178 |
+
Loads and parses a demonstration file.
|
179 |
+
:param file_path: Location of demonstration file (.demo).
|
180 |
+
:return: BrainParameter and list of AgentInfoActionPairProto containing demonstration data.
|
181 |
+
"""
|
182 |
+
|
183 |
+
# First 32 bytes of file dedicated to meta-data.
|
184 |
+
file_paths = get_demo_files(file_path)
|
185 |
+
behavior_spec = None
|
186 |
+
brain_param_proto = None
|
187 |
+
info_action_pairs = []
|
188 |
+
total_expected = 0
|
189 |
+
for _file_path in file_paths:
|
190 |
+
with open(_file_path, "rb") as fp:
|
191 |
+
with hierarchical_timer("read_file"):
|
192 |
+
data = fp.read()
|
193 |
+
next_pos, pos, obs_decoded = 0, 0, 0
|
194 |
+
while pos < len(data):
|
195 |
+
next_pos, pos = _DecodeVarint32(data, pos)
|
196 |
+
if obs_decoded == 0:
|
197 |
+
meta_data_proto = DemonstrationMetaProto()
|
198 |
+
meta_data_proto.ParseFromString(data[pos : pos + next_pos])
|
199 |
+
if (
|
200 |
+
meta_data_proto.api_version
|
201 |
+
not in SUPPORTED_DEMONSTRATION_VERSIONS
|
202 |
+
):
|
203 |
+
raise RuntimeError(
|
204 |
+
f"Can't load Demonstration data from an unsupported version ({meta_data_proto.api_version})"
|
205 |
+
)
|
206 |
+
total_expected += meta_data_proto.number_steps
|
207 |
+
pos = INITIAL_POS
|
208 |
+
if obs_decoded == 1:
|
209 |
+
brain_param_proto = BrainParametersProto()
|
210 |
+
brain_param_proto.ParseFromString(data[pos : pos + next_pos])
|
211 |
+
pos += next_pos
|
212 |
+
if obs_decoded > 1:
|
213 |
+
agent_info_action = AgentInfoActionPairProto()
|
214 |
+
agent_info_action.ParseFromString(data[pos : pos + next_pos])
|
215 |
+
if behavior_spec is None:
|
216 |
+
behavior_spec = behavior_spec_from_proto(
|
217 |
+
brain_param_proto, agent_info_action.agent_info
|
218 |
+
)
|
219 |
+
info_action_pairs.append(agent_info_action)
|
220 |
+
if len(info_action_pairs) == total_expected:
|
221 |
+
break
|
222 |
+
pos += next_pos
|
223 |
+
obs_decoded += 1
|
224 |
+
if not behavior_spec:
|
225 |
+
raise RuntimeError(
|
226 |
+
f"No BrainParameters found in demonstration file at {file_path}."
|
227 |
+
)
|
228 |
+
return behavior_spec, info_action_pairs, total_expected
|
229 |
+
|
230 |
+
|
231 |
+
def write_delimited(f, message):
|
232 |
+
msg_string = message.SerializeToString()
|
233 |
+
msg_size = len(msg_string)
|
234 |
+
_EncodeVarint(f.write, msg_size)
|
235 |
+
f.write(msg_string)
|
236 |
+
|
237 |
+
|
238 |
+
def write_demo(demo_path, meta_data_proto, brain_param_proto, agent_info_protos):
|
239 |
+
with open(demo_path, "wb") as f:
|
240 |
+
# write metadata
|
241 |
+
write_delimited(f, meta_data_proto)
|
242 |
+
f.seek(INITIAL_POS)
|
243 |
+
write_delimited(f, brain_param_proto)
|
244 |
+
|
245 |
+
for agent in agent_info_protos:
|
246 |
+
write_delimited(f, agent)
|
MLPY/Lib/site-packages/mlagents/trainers/directory_utils.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from mlagents.trainers.exception import UnityTrainerException
|
3 |
+
from mlagents.trainers.settings import TrainerSettings
|
4 |
+
from mlagents.trainers.model_saver.torch_model_saver import DEFAULT_CHECKPOINT_NAME
|
5 |
+
|
6 |
+
|
7 |
+
def validate_existing_directories(
|
8 |
+
output_path: str, resume: bool, force: bool, init_path: str = None
|
9 |
+
) -> None:
|
10 |
+
"""
|
11 |
+
Validates that if the run_id model exists, we do not overwrite it unless --force is specified.
|
12 |
+
Throws an exception if resume isn't specified and run_id exists. Throws an exception
|
13 |
+
if --resume is specified and run-id was not found.
|
14 |
+
:param model_path: The model path specified.
|
15 |
+
:param summary_path: The summary path to be used.
|
16 |
+
:param resume: Whether or not the --resume flag was passed.
|
17 |
+
:param force: Whether or not the --force flag was passed.
|
18 |
+
:param init_path: Path to run-id dir to initialize from
|
19 |
+
"""
|
20 |
+
|
21 |
+
output_path_exists = os.path.isdir(output_path)
|
22 |
+
|
23 |
+
if output_path_exists:
|
24 |
+
if not resume and not force:
|
25 |
+
raise UnityTrainerException(
|
26 |
+
"Previous data from this run ID was found. "
|
27 |
+
"Either specify a new run ID, use --resume to resume this run, "
|
28 |
+
"or use the --force parameter to overwrite existing data."
|
29 |
+
)
|
30 |
+
else:
|
31 |
+
if resume:
|
32 |
+
raise UnityTrainerException(
|
33 |
+
"Previous data from this run ID was not found. "
|
34 |
+
"Train a new run by removing the --resume flag."
|
35 |
+
)
|
36 |
+
|
37 |
+
# Verify init path if specified.
|
38 |
+
if init_path is not None:
|
39 |
+
if not os.path.isdir(init_path):
|
40 |
+
raise UnityTrainerException(
|
41 |
+
"Could not initialize from {}. "
|
42 |
+
"Make sure models have already been saved with that run ID.".format(
|
43 |
+
init_path
|
44 |
+
)
|
45 |
+
)
|
46 |
+
|
47 |
+
|
48 |
+
def setup_init_path(
|
49 |
+
behaviors: TrainerSettings.DefaultTrainerDict, init_dir: str
|
50 |
+
) -> None:
|
51 |
+
"""
|
52 |
+
For each behavior, setup full init_path to checkpoint file to initialize policy from
|
53 |
+
:param behaviors: mapping from behavior_name to TrainerSettings
|
54 |
+
:param init_dir: Path to run-id dir to initialize from
|
55 |
+
"""
|
56 |
+
for behavior_name, ts in behaviors.items():
|
57 |
+
if ts.init_path is None:
|
58 |
+
# set default if None
|
59 |
+
ts.init_path = os.path.join(
|
60 |
+
init_dir, behavior_name, DEFAULT_CHECKPOINT_NAME
|
61 |
+
)
|
62 |
+
elif not os.path.dirname(ts.init_path):
|
63 |
+
# update to full path if just the file name
|
64 |
+
ts.init_path = os.path.join(init_dir, behavior_name, ts.init_path)
|
65 |
+
_validate_init_full_path(ts.init_path)
|
66 |
+
|
67 |
+
|
68 |
+
def _validate_init_full_path(init_file: str) -> None:
|
69 |
+
"""
|
70 |
+
Validate initialization path to be a .pt file
|
71 |
+
:param init_file: full path to initialization checkpoint file
|
72 |
+
"""
|
73 |
+
if not (os.path.isfile(init_file) and init_file.endswith(".pt")):
|
74 |
+
raise UnityTrainerException(
|
75 |
+
f"Could not initialize from {init_file}. file does not exists or is not a `.pt` file"
|
76 |
+
)
|
MLPY/Lib/site-packages/mlagents/trainers/env_manager.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
from typing import List, Dict, NamedTuple, Iterable, Tuple
|
4 |
+
from mlagents_envs.base_env import (
|
5 |
+
DecisionSteps,
|
6 |
+
TerminalSteps,
|
7 |
+
BehaviorSpec,
|
8 |
+
BehaviorName,
|
9 |
+
)
|
10 |
+
from mlagents_envs.side_channel.stats_side_channel import EnvironmentStats
|
11 |
+
|
12 |
+
from mlagents.trainers.policy import Policy
|
13 |
+
from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue
|
14 |
+
from mlagents.trainers.action_info import ActionInfo
|
15 |
+
from mlagents.trainers.settings import TrainerSettings
|
16 |
+
from mlagents_envs.logging_util import get_logger
|
17 |
+
|
18 |
+
AllStepResult = Dict[BehaviorName, Tuple[DecisionSteps, TerminalSteps]]
|
19 |
+
AllGroupSpec = Dict[BehaviorName, BehaviorSpec]
|
20 |
+
|
21 |
+
logger = get_logger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
class EnvironmentStep(NamedTuple):
|
25 |
+
current_all_step_result: AllStepResult
|
26 |
+
worker_id: int
|
27 |
+
brain_name_to_action_info: Dict[BehaviorName, ActionInfo]
|
28 |
+
environment_stats: EnvironmentStats
|
29 |
+
|
30 |
+
@property
|
31 |
+
def name_behavior_ids(self) -> Iterable[BehaviorName]:
|
32 |
+
return self.current_all_step_result.keys()
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def empty(worker_id: int) -> "EnvironmentStep":
|
36 |
+
return EnvironmentStep({}, worker_id, {}, {})
|
37 |
+
|
38 |
+
|
39 |
+
class EnvManager(ABC):
|
40 |
+
def __init__(self):
|
41 |
+
self.policies: Dict[BehaviorName, Policy] = {}
|
42 |
+
self.agent_managers: Dict[BehaviorName, AgentManager] = {}
|
43 |
+
self.first_step_infos: List[EnvironmentStep] = []
|
44 |
+
|
45 |
+
def set_policy(self, brain_name: BehaviorName, policy: Policy) -> None:
|
46 |
+
self.policies[brain_name] = policy
|
47 |
+
if brain_name in self.agent_managers:
|
48 |
+
self.agent_managers[brain_name].policy = policy
|
49 |
+
|
50 |
+
def set_agent_manager(
|
51 |
+
self, brain_name: BehaviorName, manager: AgentManager
|
52 |
+
) -> None:
|
53 |
+
self.agent_managers[brain_name] = manager
|
54 |
+
|
55 |
+
@abstractmethod
|
56 |
+
def _step(self) -> List[EnvironmentStep]:
|
57 |
+
pass
|
58 |
+
|
59 |
+
@abstractmethod
|
60 |
+
def _reset_env(self, config: Dict = None) -> List[EnvironmentStep]:
|
61 |
+
pass
|
62 |
+
|
63 |
+
def reset(self, config: Dict = None) -> int:
|
64 |
+
for manager in self.agent_managers.values():
|
65 |
+
manager.end_episode()
|
66 |
+
# Save the first step infos, after the reset.
|
67 |
+
# They will be processed on the first advance().
|
68 |
+
self.first_step_infos = self._reset_env(config)
|
69 |
+
return len(self.first_step_infos)
|
70 |
+
|
71 |
+
@abstractmethod
|
72 |
+
def set_env_parameters(self, config: Dict = None) -> None:
|
73 |
+
"""
|
74 |
+
Sends environment parameter settings to C# via the
|
75 |
+
EnvironmentParametersSideChannel.
|
76 |
+
:param config: Dict of environment parameter keys and values
|
77 |
+
"""
|
78 |
+
pass
|
79 |
+
|
80 |
+
def on_training_started(
|
81 |
+
self, behavior_name: str, trainer_settings: TrainerSettings
|
82 |
+
) -> None:
|
83 |
+
"""
|
84 |
+
Handle traing starting for a new behavior type. Generally nothing is necessary here.
|
85 |
+
:param behavior_name:
|
86 |
+
:param trainer_settings:
|
87 |
+
:return:
|
88 |
+
"""
|
89 |
+
pass
|
90 |
+
|
91 |
+
@property
|
92 |
+
@abstractmethod
|
93 |
+
def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]:
|
94 |
+
pass
|
95 |
+
|
96 |
+
@abstractmethod
|
97 |
+
def close(self):
|
98 |
+
pass
|
99 |
+
|
100 |
+
def get_steps(self) -> List[EnvironmentStep]:
|
101 |
+
"""
|
102 |
+
Updates the policies, steps the environments, and returns the step information from the environments.
|
103 |
+
Calling code should pass the returned EnvironmentSteps to process_steps() after calling this.
|
104 |
+
:return: The list of EnvironmentSteps
|
105 |
+
"""
|
106 |
+
# If we had just reset, process the first EnvironmentSteps.
|
107 |
+
# Note that we do it here instead of in reset() so that on the very first reset(),
|
108 |
+
# we can create the needed AgentManagers before calling advance() and processing the EnvironmentSteps.
|
109 |
+
if self.first_step_infos:
|
110 |
+
self._process_step_infos(self.first_step_infos)
|
111 |
+
self.first_step_infos = []
|
112 |
+
# Get new policies if found. Always get the latest policy.
|
113 |
+
for brain_name in self.agent_managers.keys():
|
114 |
+
_policy = None
|
115 |
+
try:
|
116 |
+
# We make sure to empty the policy queue before continuing to produce steps.
|
117 |
+
# This halts the trainers until the policy queue is empty.
|
118 |
+
while True:
|
119 |
+
_policy = self.agent_managers[brain_name].policy_queue.get_nowait()
|
120 |
+
except AgentManagerQueue.Empty:
|
121 |
+
if _policy is not None:
|
122 |
+
self.set_policy(brain_name, _policy)
|
123 |
+
# Step the environments
|
124 |
+
new_step_infos = self._step()
|
125 |
+
return new_step_infos
|
126 |
+
|
127 |
+
def process_steps(self, new_step_infos: List[EnvironmentStep]) -> int:
|
128 |
+
# Add to AgentProcessor
|
129 |
+
num_step_infos = self._process_step_infos(new_step_infos)
|
130 |
+
return num_step_infos
|
131 |
+
|
132 |
+
def _process_step_infos(self, step_infos: List[EnvironmentStep]) -> int:
|
133 |
+
for step_info in step_infos:
|
134 |
+
for name_behavior_id in step_info.name_behavior_ids:
|
135 |
+
if name_behavior_id not in self.agent_managers:
|
136 |
+
logger.warning(
|
137 |
+
"Agent manager was not created for behavior id {}.".format(
|
138 |
+
name_behavior_id
|
139 |
+
)
|
140 |
+
)
|
141 |
+
continue
|
142 |
+
decision_steps, terminal_steps = step_info.current_all_step_result[
|
143 |
+
name_behavior_id
|
144 |
+
]
|
145 |
+
self.agent_managers[name_behavior_id].add_experiences(
|
146 |
+
decision_steps,
|
147 |
+
terminal_steps,
|
148 |
+
step_info.worker_id,
|
149 |
+
step_info.brain_name_to_action_info.get(
|
150 |
+
name_behavior_id, ActionInfo.empty()
|
151 |
+
),
|
152 |
+
)
|
153 |
+
|
154 |
+
self.agent_managers[name_behavior_id].record_environment_stats(
|
155 |
+
step_info.environment_stats, step_info.worker_id
|
156 |
+
)
|
157 |
+
return len(step_infos)
|
MLPY/Lib/site-packages/mlagents/trainers/environment_parameter_manager.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Tuple, Optional
|
2 |
+
from mlagents.trainers.settings import (
|
3 |
+
EnvironmentParameterSettings,
|
4 |
+
ParameterRandomizationSettings,
|
5 |
+
)
|
6 |
+
from collections import defaultdict
|
7 |
+
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
|
8 |
+
|
9 |
+
from mlagents_envs.logging_util import get_logger
|
10 |
+
|
11 |
+
logger = get_logger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
class EnvironmentParameterManager:
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
settings: Optional[Dict[str, EnvironmentParameterSettings]] = None,
|
18 |
+
run_seed: int = -1,
|
19 |
+
restore: bool = False,
|
20 |
+
):
|
21 |
+
"""
|
22 |
+
EnvironmentParameterManager manages all the environment parameters of a training
|
23 |
+
session. It determines when parameters should change and gives access to the
|
24 |
+
current sampler of each parameter.
|
25 |
+
:param settings: A dictionary from environment parameter to
|
26 |
+
EnvironmentParameterSettings.
|
27 |
+
:param run_seed: When the seed is not provided for an environment parameter,
|
28 |
+
this seed will be used instead.
|
29 |
+
:param restore: If true, the EnvironmentParameterManager will use the
|
30 |
+
GlobalTrainingStatus to try and reload the lesson status of each environment
|
31 |
+
parameter.
|
32 |
+
"""
|
33 |
+
if settings is None:
|
34 |
+
settings = {}
|
35 |
+
self._dict_settings = settings
|
36 |
+
for parameter_name in self._dict_settings.keys():
|
37 |
+
initial_lesson = GlobalTrainingStatus.get_parameter_state(
|
38 |
+
parameter_name, StatusType.LESSON_NUM
|
39 |
+
)
|
40 |
+
if initial_lesson is None or not restore:
|
41 |
+
GlobalTrainingStatus.set_parameter_state(
|
42 |
+
parameter_name, StatusType.LESSON_NUM, 0
|
43 |
+
)
|
44 |
+
self._smoothed_values: Dict[str, float] = defaultdict(float)
|
45 |
+
for key in self._dict_settings.keys():
|
46 |
+
self._smoothed_values[key] = 0.0
|
47 |
+
# Update the seeds of the samplers
|
48 |
+
self._set_sampler_seeds(run_seed)
|
49 |
+
|
50 |
+
def _set_sampler_seeds(self, seed):
|
51 |
+
"""
|
52 |
+
Sets the seeds for the samplers (if no seed was already present). Note that
|
53 |
+
using the provided seed.
|
54 |
+
"""
|
55 |
+
offset = 0
|
56 |
+
for settings in self._dict_settings.values():
|
57 |
+
for lesson in settings.curriculum:
|
58 |
+
if lesson.value.seed == -1:
|
59 |
+
lesson.value.seed = seed + offset
|
60 |
+
offset += 1
|
61 |
+
|
62 |
+
def get_minimum_reward_buffer_size(self, behavior_name: str) -> int:
|
63 |
+
"""
|
64 |
+
Calculates the minimum size of the reward buffer a behavior must use. This
|
65 |
+
method uses the 'min_lesson_length' sampler_parameter to determine this value.
|
66 |
+
:param behavior_name: The name of the behavior the minimum reward buffer
|
67 |
+
size corresponds to.
|
68 |
+
"""
|
69 |
+
result = 1
|
70 |
+
for settings in self._dict_settings.values():
|
71 |
+
for lesson in settings.curriculum:
|
72 |
+
if lesson.completion_criteria is not None:
|
73 |
+
if lesson.completion_criteria.behavior == behavior_name:
|
74 |
+
result = max(
|
75 |
+
result, lesson.completion_criteria.min_lesson_length
|
76 |
+
)
|
77 |
+
return result
|
78 |
+
|
79 |
+
def get_current_samplers(self) -> Dict[str, ParameterRandomizationSettings]:
|
80 |
+
"""
|
81 |
+
Creates a dictionary from environment parameter name to their corresponding
|
82 |
+
ParameterRandomizationSettings. If curriculum is used, the
|
83 |
+
ParameterRandomizationSettings corresponds to the sampler of the current lesson.
|
84 |
+
"""
|
85 |
+
samplers: Dict[str, ParameterRandomizationSettings] = {}
|
86 |
+
for param_name, settings in self._dict_settings.items():
|
87 |
+
lesson_num = GlobalTrainingStatus.get_parameter_state(
|
88 |
+
param_name, StatusType.LESSON_NUM
|
89 |
+
)
|
90 |
+
lesson = settings.curriculum[lesson_num]
|
91 |
+
samplers[param_name] = lesson.value
|
92 |
+
return samplers
|
93 |
+
|
94 |
+
def get_current_lesson_number(self) -> Dict[str, int]:
|
95 |
+
"""
|
96 |
+
Creates a dictionary from environment parameter to the current lesson number.
|
97 |
+
If not using curriculum, this number is always 0 for that environment parameter.
|
98 |
+
"""
|
99 |
+
result: Dict[str, int] = {}
|
100 |
+
for parameter_name in self._dict_settings.keys():
|
101 |
+
result[parameter_name] = GlobalTrainingStatus.get_parameter_state(
|
102 |
+
parameter_name, StatusType.LESSON_NUM
|
103 |
+
)
|
104 |
+
return result
|
105 |
+
|
106 |
+
def log_current_lesson(self, parameter_name: Optional[str] = None) -> None:
|
107 |
+
"""
|
108 |
+
Logs the current lesson number and sampler value of the parameter with name
|
109 |
+
parameter_name. If no parameter_name is provided, the values and lesson
|
110 |
+
numbers of all parameters will be displayed.
|
111 |
+
"""
|
112 |
+
if parameter_name is not None:
|
113 |
+
settings = self._dict_settings[parameter_name]
|
114 |
+
lesson_number = GlobalTrainingStatus.get_parameter_state(
|
115 |
+
parameter_name, StatusType.LESSON_NUM
|
116 |
+
)
|
117 |
+
lesson_name = settings.curriculum[lesson_number].name
|
118 |
+
lesson_value = settings.curriculum[lesson_number].value
|
119 |
+
logger.info(
|
120 |
+
f"Parameter '{parameter_name}' is in lesson '{lesson_name}' "
|
121 |
+
f"and has value '{lesson_value}'."
|
122 |
+
)
|
123 |
+
else:
|
124 |
+
for parameter_name, settings in self._dict_settings.items():
|
125 |
+
lesson_number = GlobalTrainingStatus.get_parameter_state(
|
126 |
+
parameter_name, StatusType.LESSON_NUM
|
127 |
+
)
|
128 |
+
lesson_name = settings.curriculum[lesson_number].name
|
129 |
+
lesson_value = settings.curriculum[lesson_number].value
|
130 |
+
logger.info(
|
131 |
+
f"Parameter '{parameter_name}' is in lesson '{lesson_name}' "
|
132 |
+
f"and has value '{lesson_value}'."
|
133 |
+
)
|
134 |
+
|
135 |
+
def update_lessons(
|
136 |
+
self,
|
137 |
+
trainer_steps: Dict[str, int],
|
138 |
+
trainer_max_steps: Dict[str, int],
|
139 |
+
trainer_reward_buffer: Dict[str, List[float]],
|
140 |
+
) -> Tuple[bool, bool]:
|
141 |
+
"""
|
142 |
+
Given progress metrics, calculates if at least one environment parameter is
|
143 |
+
in a new lesson and if at least one environment parameter requires the env
|
144 |
+
to reset.
|
145 |
+
:param trainer_steps: A dictionary from behavior_name to the number of training
|
146 |
+
steps this behavior's trainer has performed.
|
147 |
+
:param trainer_max_steps: A dictionary from behavior_name to the maximum number
|
148 |
+
of training steps this behavior's trainer has performed.
|
149 |
+
:param trainer_reward_buffer: A dictionary from behavior_name to the list of
|
150 |
+
the most recent episode returns for this behavior's trainer.
|
151 |
+
:returns: A tuple of two booleans : (True if any lesson has changed, True if
|
152 |
+
environment needs to reset)
|
153 |
+
"""
|
154 |
+
must_reset = False
|
155 |
+
updated = False
|
156 |
+
for param_name, settings in self._dict_settings.items():
|
157 |
+
lesson_num = GlobalTrainingStatus.get_parameter_state(
|
158 |
+
param_name, StatusType.LESSON_NUM
|
159 |
+
)
|
160 |
+
next_lesson_num = lesson_num + 1
|
161 |
+
lesson = settings.curriculum[lesson_num]
|
162 |
+
if (
|
163 |
+
lesson.completion_criteria is not None
|
164 |
+
and len(settings.curriculum) > next_lesson_num
|
165 |
+
):
|
166 |
+
behavior_to_consider = lesson.completion_criteria.behavior
|
167 |
+
if behavior_to_consider in trainer_steps:
|
168 |
+
(
|
169 |
+
must_increment,
|
170 |
+
new_smoothing,
|
171 |
+
) = lesson.completion_criteria.need_increment(
|
172 |
+
float(trainer_steps[behavior_to_consider])
|
173 |
+
/ float(trainer_max_steps[behavior_to_consider]),
|
174 |
+
trainer_reward_buffer[behavior_to_consider],
|
175 |
+
self._smoothed_values[param_name],
|
176 |
+
)
|
177 |
+
self._smoothed_values[param_name] = new_smoothing
|
178 |
+
if must_increment:
|
179 |
+
GlobalTrainingStatus.set_parameter_state(
|
180 |
+
param_name, StatusType.LESSON_NUM, next_lesson_num
|
181 |
+
)
|
182 |
+
self.log_current_lesson(param_name)
|
183 |
+
updated = True
|
184 |
+
if lesson.completion_criteria.require_reset:
|
185 |
+
must_reset = True
|
186 |
+
return updated, must_reset
|
MLPY/Lib/site-packages/mlagents/trainers/exception.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Contains exceptions for the trainers package.
|
3 |
+
"""
|
4 |
+
|
5 |
+
|
6 |
+
class TrainerError(Exception):
|
7 |
+
"""
|
8 |
+
Any error related to the trainers in the ML-Agents Toolkit.
|
9 |
+
"""
|
10 |
+
|
11 |
+
pass
|
12 |
+
|
13 |
+
|
14 |
+
class TrainerConfigError(Exception):
|
15 |
+
"""
|
16 |
+
Any error related to the configuration of trainers in the ML-Agents Toolkit.
|
17 |
+
"""
|
18 |
+
|
19 |
+
pass
|
20 |
+
|
21 |
+
|
22 |
+
class TrainerConfigWarning(Warning):
|
23 |
+
"""
|
24 |
+
Any warning related to the configuration of trainers in the ML-Agents Toolkit.
|
25 |
+
"""
|
26 |
+
|
27 |
+
pass
|
28 |
+
|
29 |
+
|
30 |
+
class CurriculumError(TrainerError):
|
31 |
+
"""
|
32 |
+
Any error related to training with a curriculum.
|
33 |
+
"""
|
34 |
+
|
35 |
+
pass
|
36 |
+
|
37 |
+
|
38 |
+
class CurriculumLoadingError(CurriculumError):
|
39 |
+
"""
|
40 |
+
Any error related to loading the Curriculum config file.
|
41 |
+
"""
|
42 |
+
|
43 |
+
pass
|
44 |
+
|
45 |
+
|
46 |
+
class CurriculumConfigError(CurriculumError):
|
47 |
+
"""
|
48 |
+
Any error related to processing the Curriculum config file.
|
49 |
+
"""
|
50 |
+
|
51 |
+
pass
|
52 |
+
|
53 |
+
|
54 |
+
class MetaCurriculumError(TrainerError):
|
55 |
+
"""
|
56 |
+
Any error related to the configuration of a metacurriculum.
|
57 |
+
"""
|
58 |
+
|
59 |
+
pass
|
60 |
+
|
61 |
+
|
62 |
+
class SamplerException(TrainerError):
|
63 |
+
"""
|
64 |
+
Related to errors with the sampler actions.
|
65 |
+
"""
|
66 |
+
|
67 |
+
pass
|
68 |
+
|
69 |
+
|
70 |
+
class UnityTrainerException(TrainerError):
|
71 |
+
"""
|
72 |
+
Related to errors with the Trainer.
|
73 |
+
"""
|
74 |
+
|
75 |
+
pass
|
MLPY/Lib/site-packages/mlagents/trainers/ghost/__init__.py
ADDED
File without changes
|