Kano001 committed
Commit e11e4fe (1 parent: 122d3ff)

Upload 280 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. MLPY/Lib/site-packages/mlagents/__init__.py +0 -0
  2. MLPY/Lib/site-packages/mlagents/__pycache__/__init__.cpython-39.pyc +0 -0
  3. MLPY/Lib/site-packages/mlagents/plugins/__init__.py +8 -0
  4. MLPY/Lib/site-packages/mlagents/plugins/__pycache__/__init__.cpython-39.pyc +0 -0
  5. MLPY/Lib/site-packages/mlagents/plugins/__pycache__/stats_writer.cpython-39.pyc +0 -0
  6. MLPY/Lib/site-packages/mlagents/plugins/__pycache__/trainer_type.cpython-39.pyc +0 -0
  7. MLPY/Lib/site-packages/mlagents/plugins/stats_writer.py +72 -0
  8. MLPY/Lib/site-packages/mlagents/plugins/trainer_type.py +80 -0
  9. MLPY/Lib/site-packages/mlagents/torch_utils/__init__.py +4 -0
  10. MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  11. MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/cpu_utils.cpython-39.pyc +0 -0
  12. MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/globals.cpython-39.pyc +0 -0
  13. MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/torch.cpython-39.pyc +0 -0
  14. MLPY/Lib/site-packages/mlagents/torch_utils/cpu_utils.py +41 -0
  15. MLPY/Lib/site-packages/mlagents/torch_utils/globals.py +13 -0
  16. MLPY/Lib/site-packages/mlagents/torch_utils/torch.py +68 -0
  17. MLPY/Lib/site-packages/mlagents/trainers/__init__.py +5 -0
  18. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/__init__.cpython-39.pyc +0 -0
  19. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/action_info.cpython-39.pyc +0 -0
  20. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/agent_processor.cpython-39.pyc +0 -0
  21. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/behavior_id_utils.cpython-39.pyc +0 -0
  22. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/buffer.cpython-39.pyc +0 -0
  23. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/cli_utils.cpython-39.pyc +0 -0
  24. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/demo_loader.cpython-39.pyc +0 -0
  25. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/directory_utils.cpython-39.pyc +0 -0
  26. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/env_manager.cpython-39.pyc +0 -0
  27. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/environment_parameter_manager.cpython-39.pyc +0 -0
  28. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/exception.cpython-39.pyc +0 -0
  29. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/learn.cpython-39.pyc +0 -0
  30. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/run_experiment.cpython-39.pyc +0 -0
  31. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/settings.cpython-39.pyc +0 -0
  32. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/simple_env_manager.cpython-39.pyc +0 -0
  33. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/stats.cpython-39.pyc +0 -0
  34. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/subprocess_env_manager.cpython-39.pyc +0 -0
  35. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/trainer_controller.cpython-39.pyc +0 -0
  36. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/training_analytics_side_channel.cpython-39.pyc +0 -0
  37. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/training_status.cpython-39.pyc +0 -0
  38. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/trajectory.cpython-39.pyc +0 -0
  39. MLPY/Lib/site-packages/mlagents/trainers/__pycache__/upgrade_config.cpython-39.pyc +0 -0
  40. MLPY/Lib/site-packages/mlagents/trainers/action_info.py +25 -0
  41. MLPY/Lib/site-packages/mlagents/trainers/agent_processor.py +469 -0
  42. MLPY/Lib/site-packages/mlagents/trainers/behavior_id_utils.py +64 -0
  43. MLPY/Lib/site-packages/mlagents/trainers/buffer.py +521 -0
  44. MLPY/Lib/site-packages/mlagents/trainers/cli_utils.py +331 -0
  45. MLPY/Lib/site-packages/mlagents/trainers/demo_loader.py +246 -0
  46. MLPY/Lib/site-packages/mlagents/trainers/directory_utils.py +76 -0
  47. MLPY/Lib/site-packages/mlagents/trainers/env_manager.py +157 -0
  48. MLPY/Lib/site-packages/mlagents/trainers/environment_parameter_manager.py +186 -0
  49. MLPY/Lib/site-packages/mlagents/trainers/exception.py +75 -0
  50. MLPY/Lib/site-packages/mlagents/trainers/ghost/__init__.py +0 -0
MLPY/Lib/site-packages/mlagents/__init__.py ADDED
File without changes
MLPY/Lib/site-packages/mlagents/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (145 Bytes).
 
MLPY/Lib/site-packages/mlagents/plugins/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from typing import Dict, Any
+
+ ML_AGENTS_STATS_WRITER = "mlagents.stats_writer"
+ ML_AGENTS_TRAINER_TYPE = "mlagents.trainer_type"
+
+ # TODO: the real type is Dict[str, HyperparamSettings]
+ all_trainer_types: Dict[str, Any] = {}
+ all_trainer_settings: Dict[str, Any] = {}
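The two constants above are the entry-point group names that the plugin loaders in stats_writer.py and trainer_type.py look up. As a rough sketch of how a third-party package would hook into them (the package, module, and function names below are hypothetical), a plugin is advertised through ordinary setuptools entry points:

# setup.py of a hypothetical plugin package; only the group name is dictated by ml-agents.
from setuptools import setup

setup(
    name="my-mlagents-plugin",
    packages=["my_mlagents_plugin"],
    entry_points={
        # Must match ML_AGENTS_STATS_WRITER ("mlagents.stats_writer") above.
        "mlagents.stats_writer": [
            "my_writer = my_mlagents_plugin.writer:get_stats_writers"
        ]
    },
)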
MLPY/Lib/site-packages/mlagents/plugins/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (427 Bytes).
 
MLPY/Lib/site-packages/mlagents/plugins/__pycache__/stats_writer.cpython-39.pyc ADDED
Binary file (2.2 kB).
 
MLPY/Lib/site-packages/mlagents/plugins/__pycache__/trainer_type.cpython-39.pyc ADDED
Binary file (2.42 kB).
 
MLPY/Lib/site-packages/mlagents/plugins/stats_writer.py ADDED
@@ -0,0 +1,72 @@
+ import sys
+ from typing import List
+
+ # importlib.metadata is new in python3.8
+ # We use the backport for older python versions.
+ if sys.version_info < (3, 8):
+     import importlib_metadata
+ else:
+     import importlib.metadata as importlib_metadata  # pylint: disable=E0611
+
+ from mlagents.trainers.stats import StatsWriter
+
+ from mlagents_envs import logging_util
+ from mlagents.plugins import ML_AGENTS_STATS_WRITER
+ from mlagents.trainers.settings import RunOptions
+ from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter
+
+
+ logger = logging_util.get_logger(__name__)
+
+
+ def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
+     """
+     The StatsWriters that mlagents-learn always uses:
+     * A TensorboardWriter to write information to TensorBoard
+     * A GaugeWriter to record our internal stats
+     * A ConsoleWriter to output to stdout.
+     """
+     checkpoint_settings = run_options.checkpoint_settings
+     return [
+         TensorboardWriter(
+             checkpoint_settings.write_path,
+             clear_past_data=not checkpoint_settings.resume,
+             hidden_keys=["Is Training", "Step"],
+         ),
+         GaugeWriter(),
+         ConsoleWriter(),
+     ]
+
+
+ def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]:
+     """
+     Registers all StatsWriter plugins (including the default one),
+     evaluates them, and returns the list of all the StatsWriter implementations.
+     """
+     all_stats_writers: List[StatsWriter] = []
+     if ML_AGENTS_STATS_WRITER not in importlib_metadata.entry_points():
+         logger.warning(
+             f"Unable to find any entry points for {ML_AGENTS_STATS_WRITER}, even the default ones. "
+             "Uninstalling and reinstalling ml-agents via pip should resolve. "
+             "Using default plugins for now."
+         )
+         return get_default_stats_writers(run_options)
+
+     entry_points = importlib_metadata.entry_points()[ML_AGENTS_STATS_WRITER]
+
+     for entry_point in entry_points:
+
+         try:
+             logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}")
+             plugin_func = entry_point.load()
+             plugin_stats_writers = plugin_func(run_options)
+             logger.debug(
+                 f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}"
+             )
+             all_stats_writers += plugin_stats_writers
+         except BaseException:
+             # Catch all exceptions from setting up the plugin, so that bad user code doesn't break things.
+             logger.exception(
+                 f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used."
+             )
+     return all_stats_writers
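register_stats_writer_plugins expects each entry point to resolve to a callable that takes the RunOptions and returns a list of StatsWriter instances. A minimal sketch of such a plugin function, assuming the hypothetical my_mlagents_plugin package mentioned earlier:

# my_mlagents_plugin/writer.py (hypothetical); reuses ConsoleWriter so the sketch
# stays runnable without assuming any StatsWriter subclass API beyond these imports.
from typing import List

from mlagents.trainers.settings import RunOptions
from mlagents.trainers.stats import ConsoleWriter, StatsWriter


def get_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
    # A real plugin would build its own StatsWriter subclass from run_options.
    return [ConsoleWriter()]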
MLPY/Lib/site-packages/mlagents/plugins/trainer_type.py ADDED
@@ -0,0 +1,80 @@
+ import sys
+ from typing import Dict, Tuple, Any
+
+ # importlib.metadata is new in python3.8
+ # We use the backport for older python versions.
+ if sys.version_info < (3, 8):
+     import importlib_metadata
+ else:
+     import importlib.metadata as importlib_metadata  # pylint: disable=E0611
+
+
+ from mlagents_envs import logging_util
+ from mlagents.plugins import ML_AGENTS_TRAINER_TYPE
+ from mlagents.trainers.ppo.trainer import PPOTrainer
+ from mlagents.trainers.sac.trainer import SACTrainer
+ from mlagents.trainers.poca.trainer import POCATrainer
+ from mlagents.trainers.ppo.optimizer_torch import PPOSettings
+ from mlagents.trainers.sac.optimizer_torch import SACSettings
+ from mlagents.trainers.poca.optimizer_torch import POCASettings
+ from mlagents import plugins as mla_plugins
+
+ logger = logging_util.get_logger(__name__)
+
+
+ def get_default_trainer_types() -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """
+     The Trainers that mlagents-learn always uses:
+     """
+
+     mla_plugins.all_trainer_types.update(
+         {
+             PPOTrainer.get_trainer_name(): PPOTrainer,
+             SACTrainer.get_trainer_name(): SACTrainer,
+             POCATrainer.get_trainer_name(): POCATrainer,
+         }
+     )
+     # global all_trainer_settings
+     mla_plugins.all_trainer_settings.update(
+         {
+             PPOTrainer.get_trainer_name(): PPOSettings,
+             SACTrainer.get_trainer_name(): SACSettings,
+             POCATrainer.get_trainer_name(): POCASettings,
+         }
+     )
+
+     return mla_plugins.all_trainer_types, mla_plugins.all_trainer_settings
+
+
+ def register_trainer_plugins() -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """
+     Registers all Trainer plugins (including the default one),
+     evaluates them, and returns the list of all the Trainer implementations.
+     """
+     if ML_AGENTS_TRAINER_TYPE not in importlib_metadata.entry_points():
+         logger.warning(
+             f"Unable to find any entry points for {ML_AGENTS_TRAINER_TYPE}, even the default ones. "
+             "Uninstalling and reinstalling ml-agents via pip should resolve. "
+             "Using default plugins for now."
+         )
+         return get_default_trainer_types()
+
+     entry_points = importlib_metadata.entry_points()[ML_AGENTS_TRAINER_TYPE]
+
+     for entry_point in entry_points:
+
+         try:
+             logger.debug(f"Initializing Trainer plugins: {entry_point.name}")
+             plugin_func = entry_point.load()
+             plugin_trainer_types, plugin_trainer_settings = plugin_func()
+             logger.debug(
+                 f"Found {len(plugin_trainer_types)} Trainers for plugin {entry_point.name}"
+             )
+             mla_plugins.all_trainer_types.update(plugin_trainer_types)
+             mla_plugins.all_trainer_settings.update(plugin_trainer_settings)
+         except BaseException:
+             # Catch all exceptions from setting up the plugin, so that bad user code doesn't break things.
+             logger.exception(
+                 f"Error initializing Trainer plugins for {entry_point.name}. This plugin will not be used."
+             )
+     return mla_plugins.all_trainer_types, mla_plugins.all_trainer_settings
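A hedged usage sketch of the loader above: callers invoke register_trainer_plugins() once at startup and then look trainers up by the name under which they registered. The "ppo" key is only an assumption about what PPOTrainer.get_trainer_name() returns with a default install.

from mlagents.plugins.trainer_type import register_trainer_plugins

trainer_types, trainer_settings = register_trainer_plugins()
print(sorted(trainer_types))                 # e.g. ['poca', 'ppo', 'sac'] with the defaults
ppo_trainer_cls = trainer_types["ppo"]       # trainer class registered under "ppo"
ppo_settings_cls = trainer_settings["ppo"]   # matching hyperparameter settings class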
MLPY/Lib/site-packages/mlagents/torch_utils/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from mlagents.torch_utils.torch import torch as torch  # noqa
+ from mlagents.torch_utils.torch import nn  # noqa
+ from mlagents.torch_utils.torch import set_torch_config  # noqa
+ from mlagents.torch_utils.torch import default_device  # noqa
MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (314 Bytes).
 
MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/cpu_utils.cpython-39.pyc ADDED
Binary file (1.51 kB).
 
MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/globals.cpython-39.pyc ADDED
Binary file (576 Bytes).
 
MLPY/Lib/site-packages/mlagents/torch_utils/__pycache__/torch.cpython-39.pyc ADDED
Binary file (1.64 kB).
 
MLPY/Lib/site-packages/mlagents/torch_utils/cpu_utils.py ADDED
@@ -0,0 +1,41 @@
+ from typing import Optional
+
+ import os
+
+
+ def get_num_threads_to_use() -> Optional[int]:
+     """
+     Gets the number of threads to use. For most problems, 4 is all you
+     need, but for smaller machines, we'd like to scale to less than that.
+     By default, PyTorch uses 1/2 of the available cores.
+     """
+     num_cpus = _get_num_available_cpus()
+     return max(min(num_cpus // 2, 4), 1) if num_cpus is not None else None
+
+
+ def _get_num_available_cpus() -> Optional[int]:
+     """
+     Returns number of CPUs using cgroups if possible. This accounts
+     for Docker containers that are limited in cores.
+     """
+     period = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.cfs_period_us")
+     quota = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.cfs_quota_us")
+     share = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.shares")
+     is_kubernetes = os.getenv("KUBERNETES_SERVICE_HOST") is not None
+
+     if period > 0 and quota > 0:
+         return int(quota // period)
+     elif period > 0 and share > 0 and is_kubernetes:
+         # In kubernetes, each requested CPU is 1024 CPU shares
+         # https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/#how-pods-with-resource-limits-are-run
+         return int(share // 1024)
+     else:
+         return os.cpu_count()
+
+
+ def _read_in_integer_file(filename: str) -> int:
+     try:
+         with open(filename) as f:
+             return int(f.read().rstrip())
+     except FileNotFoundError:
+         return -1
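A worked example of the logic above, with illustrative cgroup values: a container limited to two CPUs via the CFS quota ends up with one torch thread.

quota, period = 200_000, 100_000              # cpu.cfs_quota_us / cpu.cfs_period_us
num_cpus = quota // period                    # 2, what _get_num_available_cpus() returns
num_threads = max(min(num_cpus // 2, 4), 1)   # 1, what get_num_threads_to_use() returns
print(num_cpus, num_threads)                  # 2 1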
MLPY/Lib/site-packages/mlagents/torch_utils/globals.py ADDED
@@ -0,0 +1,13 @@
+ from typing import Optional
+
+ _rank: Optional[int] = None
+
+
+ def get_rank() -> Optional[int]:
+     """
+     Returns the rank (in the MPI sense) of the current node.
+     For local training, this will always be None.
+     If this needs to be used, it should be done from outside ml-agents.
+     :return:
+     """
+     return _rank
MLPY/Lib/site-packages/mlagents/torch_utils/torch.py ADDED
@@ -0,0 +1,68 @@
+ import os
+
+ from distutils.version import LooseVersion
+ import pkg_resources
+ from mlagents.torch_utils import cpu_utils
+ from mlagents.trainers.settings import TorchSettings
+ from mlagents_envs.logging_util import get_logger
+
+
+ logger = get_logger(__name__)
+
+
+ def assert_torch_installed():
+     # Check that torch version 1.6.0 or later has been installed. If not, refer
+     # user to the PyTorch webpage for install instructions.
+     torch_pkg = None
+     try:
+         torch_pkg = pkg_resources.get_distribution("torch")
+     except pkg_resources.DistributionNotFound:
+         pass
+     assert torch_pkg is not None and LooseVersion(torch_pkg.version) >= LooseVersion(
+         "1.6.0"
+     ), (
+         "A compatible version of PyTorch was not installed. Please visit the PyTorch homepage "
+         + "(https://pytorch.org/get-started/locally/) and follow the instructions to install. "
+         + "Version 1.6.0 and later are supported."
+     )
+
+
+ assert_torch_installed()
+
+ # This should be the only place that we import torch directly.
+ # Everywhere else is caught by the banned-modules setting for flake8
+ import torch  # noqa I201
+
+
+ torch.set_num_threads(cpu_utils.get_num_threads_to_use())
+ os.environ["KMP_BLOCKTIME"] = "0"
+
+
+ _device = torch.device("cpu")
+
+
+ def set_torch_config(torch_settings: TorchSettings) -> None:
+     global _device
+
+     if torch_settings.device is None:
+         device_str = "cuda" if torch.cuda.is_available() else "cpu"
+     else:
+         device_str = torch_settings.device
+
+     _device = torch.device(device_str)
+
+     if _device.type == "cuda":
+         torch.set_default_tensor_type(torch.cuda.FloatTensor)
+     else:
+         torch.set_default_tensor_type(torch.FloatTensor)
+     logger.debug(f"default Torch device: {_device}")
+
+
+ # Initialize to default settings
+ set_torch_config(TorchSettings(device=None))
+
+ nn = torch.nn
+
+
+ def default_device():
+     return _device
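A sketch of overriding the default device after import, assuming a CUDA-enabled PyTorch build; "cuda:0" is an example value and set_torch_config accepts any string torch.device understands.

from mlagents.torch_utils import set_torch_config, default_device
from mlagents.trainers.settings import TorchSettings

set_torch_config(TorchSettings(device="cuda:0"))
print(default_device())  # device(type='cuda', index=0)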
MLPY/Lib/site-packages/mlagents/trainers/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Version of the library that will be used to upload to pypi
+ __version__ = "0.30.0"
+
+ # Git tag that will be checked to determine whether to trigger upload to pypi
+ __release_tag__ = "release_20"
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (211 Bytes).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/action_info.cpython-39.pyc ADDED
Binary file (1.27 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/agent_processor.cpython-39.pyc ADDED
Binary file (14.1 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/behavior_id_utils.cpython-39.pyc ADDED
Binary file (2.62 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/buffer.cpython-39.pyc ADDED
Binary file (18.3 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/cli_utils.cpython-39.pyc ADDED
Binary file (9.91 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/demo_loader.cpython-39.pyc ADDED
Binary file (6.44 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/directory_utils.cpython-39.pyc ADDED
Binary file (2.68 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/env_manager.cpython-39.pyc ADDED
Binary file (5.44 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/environment_parameter_manager.cpython-39.pyc ADDED
Binary file (6.54 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/exception.cpython-39.pyc ADDED
Binary file (2.16 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/learn.cpython-39.pyc ADDED
Binary file (8.42 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/run_experiment.cpython-39.pyc ADDED
Binary file (1.31 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/settings.cpython-39.pyc ADDED
Binary file (32.6 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/simple_env_manager.cpython-39.pyc ADDED
Binary file (3.54 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/stats.cpython-39.pyc ADDED
Binary file (14.9 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/subprocess_env_manager.cpython-39.pyc ADDED
Binary file (16.4 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/trainer_controller.cpython-39.pyc ADDED
Binary file (9.53 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/training_analytics_side_channel.cpython-39.pyc ADDED
Binary file (6.14 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/training_status.cpython-39.pyc ADDED
Binary file (5.09 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/trajectory.cpython-39.pyc ADDED
Binary file (8.41 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/__pycache__/upgrade_config.cpython-39.pyc ADDED
Binary file (5.8 kB).
 
MLPY/Lib/site-packages/mlagents/trainers/action_info.py ADDED
@@ -0,0 +1,25 @@
+ from typing import NamedTuple, Any, Dict, List
+ import numpy as np
+ from mlagents_envs.base_env import AgentId
+
+ ActionInfoOutputs = Dict[str, np.ndarray]
+
+
+ class ActionInfo(NamedTuple):
+     """
+     A NamedTuple containing actions and related quantities to the policy forward
+     pass. Additionally contains the agent ids in the corresponding DecisionStep
+     :param action: The action output of the policy
+     :param env_action: The possibly clipped action to be executed in the environment
+     :param outputs: Dict of all quantities associated with the policy forward pass
+     :param agent_ids: List of int agent ids in DecisionStep
+     """
+
+     action: Any
+     env_action: Any
+     outputs: ActionInfoOutputs
+     agent_ids: List[AgentId]
+
+     @staticmethod
+     def empty() -> "ActionInfo":
+         return ActionInfo([], [], {}, [])
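For reference, a minimal construction of the tuple above; the shapes and the single agent id are placeholder values.

import numpy as np
from mlagents.trainers.action_info import ActionInfo

info = ActionInfo(
    action=np.zeros((1, 2), dtype=np.float32),      # raw policy output
    env_action=np.zeros((1, 2), dtype=np.float32),  # clipped action sent to the env
    outputs={},
    agent_ids=[0],
)
print(info.agent_ids, ActionInfo.empty().agent_ids)  # [0] []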
MLPY/Lib/site-packages/mlagents/trainers/agent_processor.py ADDED
@@ -0,0 +1,469 @@
+ import sys
+ import numpy as np
+ from typing import List, Dict, TypeVar, Generic, Tuple, Any, Union
+ from collections import defaultdict, Counter
+ import queue
+ from mlagents.torch_utils import torch
+
+ from mlagents_envs.base_env import (
+     ActionTuple,
+     DecisionSteps,
+     DecisionStep,
+     TerminalSteps,
+     TerminalStep,
+ )
+ from mlagents_envs.side_channel.stats_side_channel import (
+     StatsAggregationMethod,
+     EnvironmentStats,
+ )
+ from mlagents.trainers.exception import UnityTrainerException
+ from mlagents.trainers.trajectory import AgentStatus, Trajectory, AgentExperience
+ from mlagents.trainers.policy import Policy
+ from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
+ from mlagents.trainers.stats import StatsReporter
+ from mlagents.trainers.behavior_id_utils import (
+     get_global_agent_id,
+     get_global_group_id,
+     GlobalAgentId,
+     GlobalGroupId,
+ )
+ from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple
+ from mlagents.trainers.torch_entities.utils import ModelUtils
+
+ T = TypeVar("T")
+
+
+ class AgentProcessor:
+     """
+     AgentProcessor contains a dictionary of per-agent trajectory buffers. The buffers are indexed by agent_id.
+     Buffer also contains an update_buffer that corresponds to the buffer used when updating the model.
+     One AgentProcessor should be created per agent group.
+     """
+
+     def __init__(
+         self,
+         policy: Policy,
+         behavior_id: str,
+         stats_reporter: StatsReporter,
+         max_trajectory_length: int = sys.maxsize,
+     ):
+         """
+         Create an AgentProcessor.
+
+         :param trainer: Trainer instance connected to this AgentProcessor. Trainer is given trajectory
+         when it is finished.
+         :param policy: Policy instance associated with this AgentProcessor.
+         :param max_trajectory_length: Maximum length of a trajectory before it is added to the trainer.
+         :param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
+         """
+         self._experience_buffers: Dict[
+             GlobalAgentId, List[AgentExperience]
+         ] = defaultdict(list)
+         self._last_step_result: Dict[GlobalAgentId, Tuple[DecisionStep, int]] = {}
+         # current_group_obs is used to collect the current (i.e. the most recently seen)
+         # obs of all the agents in the same group, and assemble the group obs.
+         # It is a dictionary of GlobalGroupId to dictionaries of GlobalAgentId to observation.
+         self._current_group_obs: Dict[
+             GlobalGroupId, Dict[GlobalAgentId, List[np.ndarray]]
+         ] = defaultdict(lambda: defaultdict(list))
+         # group_status is used to collect the current, most recently seen
+         # group status of all the agents in the same group, and assemble the group's status.
+         # It is a dictionary of GlobalGroupId to dictionaries of GlobalAgentId to AgentStatus.
+         self._group_status: Dict[
+             GlobalGroupId, Dict[GlobalAgentId, AgentStatus]
+         ] = defaultdict(lambda: defaultdict(None))
+         # last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
+         # grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
+         self._last_take_action_outputs: Dict[GlobalAgentId, ActionInfoOutputs] = {}
+
+         self._episode_steps: Counter = Counter()
+         self._episode_rewards: Dict[GlobalAgentId, float] = defaultdict(float)
+         self._stats_reporter = stats_reporter
+         self._max_trajectory_length = max_trajectory_length
+         self._trajectory_queues: List[AgentManagerQueue[Trajectory]] = []
+         self._behavior_id = behavior_id
+
+         # Note: In the future this policy reference will be the policy of the env_manager and not the trainer.
+         # We can in that case just grab the action from the policy rather than having it passed in.
+         self.policy = policy
+
+     def add_experiences(
+         self,
+         decision_steps: DecisionSteps,
+         terminal_steps: TerminalSteps,
+         worker_id: int,
+         previous_action: ActionInfo,
+     ) -> None:
+         """
+         Adds experiences to each agent's experience history.
+         :param decision_steps: current DecisionSteps.
+         :param terminal_steps: current TerminalSteps.
+         :param previous_action: The outputs of the Policy's get_action method.
+         """
+         take_action_outputs = previous_action.outputs
+         if take_action_outputs:
+             try:
+                 for _entropy in take_action_outputs["entropy"]:
+                     if isinstance(_entropy, torch.Tensor):
+                         _entropy = ModelUtils.to_numpy(_entropy)
+                     self._stats_reporter.add_stat("Policy/Entropy", _entropy)
+             except KeyError:
+                 pass
+
+         # Make unique agent_ids that are global across workers
+         action_global_agent_ids = [
+             get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids
+         ]
+         for global_id in action_global_agent_ids:
+             if global_id in self._last_step_result:  # Don't store if agent just reset
+                 self._last_take_action_outputs[global_id] = take_action_outputs
+
+         # Iterate over all the terminal steps, first gather all the group obs
+         # and then create the AgentExperiences/Trajectories. _add_to_group_status
+         # stores Group statuses in a common data structure self.group_status
+         for terminal_step in terminal_steps.values():
+             self._add_group_status_and_obs(terminal_step, worker_id)
+         for terminal_step in terminal_steps.values():
+             local_id = terminal_step.agent_id
+             global_id = get_global_agent_id(worker_id, local_id)
+             self._process_step(
+                 terminal_step, worker_id, terminal_steps.agent_id_to_index[local_id]
+             )
+
+         # Iterate over all the decision steps, first gather all the group obs
+         # and then create the trajectories. _add_to_group_status
+         # stores Group statuses in a common data structure self.group_status
+         for ongoing_step in decision_steps.values():
+             self._add_group_status_and_obs(ongoing_step, worker_id)
+         for ongoing_step in decision_steps.values():
+             local_id = ongoing_step.agent_id
+             self._process_step(
+                 ongoing_step, worker_id, decision_steps.agent_id_to_index[local_id]
+             )
+         # Clear the last seen group obs when agents die, but only after all of the group
+         # statuses were added to the trajectory.
+         for terminal_step in terminal_steps.values():
+             local_id = terminal_step.agent_id
+             global_id = get_global_agent_id(worker_id, local_id)
+             self._clear_group_status_and_obs(global_id)
+
+         for _gid in action_global_agent_ids:
+             # If the ID doesn't have a last step result, the agent just reset,
+             # don't store the action.
+             if _gid in self._last_step_result:
+                 if "action" in take_action_outputs:
+                     self.policy.save_previous_action(
+                         [_gid], take_action_outputs["action"]
+                     )
+
+     def _add_group_status_and_obs(
+         self, step: Union[TerminalStep, DecisionStep], worker_id: int
+     ) -> None:
+         """
+         Takes a TerminalStep or DecisionStep and adds the information in it
+         to self.group_status. This information can then be retrieved
+         when constructing trajectories to get the status of group mates. Also stores the current
+         observation into current_group_obs, to be used to get the next group observations
+         for bootstrapping.
+         :param step: TerminalStep or DecisionStep
+         :param worker_id: Worker ID of this particular environment. Used to generate a
+         global group id.
+         """
+         global_agent_id = get_global_agent_id(worker_id, step.agent_id)
+         stored_decision_step, idx = self._last_step_result.get(
+             global_agent_id, (None, None)
+         )
+         stored_take_action_outputs = self._last_take_action_outputs.get(
+             global_agent_id, None
+         )
+         if stored_decision_step is not None and stored_take_action_outputs is not None:
+             # 0, the default group_id, means that the agent doesn't belong to an agent group.
+             # If 0, don't add any groupmate information.
+             if step.group_id > 0:
+                 global_group_id = get_global_group_id(worker_id, step.group_id)
+                 stored_actions = stored_take_action_outputs["action"]
+                 action_tuple = ActionTuple(
+                     continuous=stored_actions.continuous[idx],
+                     discrete=stored_actions.discrete[idx],
+                 )
+                 group_status = AgentStatus(
+                     obs=stored_decision_step.obs,
+                     reward=step.reward,
+                     action=action_tuple,
+                     done=isinstance(step, TerminalStep),
+                 )
+                 self._group_status[global_group_id][global_agent_id] = group_status
+                 self._current_group_obs[global_group_id][global_agent_id] = step.obs
+
+     def _clear_group_status_and_obs(self, global_id: GlobalAgentId) -> None:
+         """
+         Clears an agent from self._group_status and self._current_group_obs.
+         """
+         self._delete_in_nested_dict(self._current_group_obs, global_id)
+         self._delete_in_nested_dict(self._group_status, global_id)
+
+     def _delete_in_nested_dict(self, nested_dict: Dict[str, Any], key: str) -> None:
+         for _manager_id in list(nested_dict.keys()):
+             _team_group = nested_dict[_manager_id]
+             self._safe_delete(_team_group, key)
+             if not _team_group:  # if dict is empty
+                 self._safe_delete(nested_dict, _manager_id)
+
+     def _process_step(
+         self, step: Union[TerminalStep, DecisionStep], worker_id: int, index: int
+     ) -> None:
+         terminated = isinstance(step, TerminalStep)
+         global_agent_id = get_global_agent_id(worker_id, step.agent_id)
+         global_group_id = get_global_group_id(worker_id, step.group_id)
+         stored_decision_step, idx = self._last_step_result.get(
+             global_agent_id, (None, None)
+         )
+         stored_take_action_outputs = self._last_take_action_outputs.get(
+             global_agent_id, None
+         )
+         if not terminated:
+             # Index is needed to grab from last_take_action_outputs
+             self._last_step_result[global_agent_id] = (step, index)
+
+         # This state is the consequence of a past action
+         if stored_decision_step is not None and stored_take_action_outputs is not None:
+             obs = stored_decision_step.obs
+             if self.policy.use_recurrent:
+                 memory = self.policy.retrieve_previous_memories([global_agent_id])[0, :]
+             else:
+                 memory = None
+             done = terminated  # Since this is an ongoing step
+             interrupted = step.interrupted if terminated else False
+             # Add the outputs of the last eval
+             stored_actions = stored_take_action_outputs["action"]
+             action_tuple = ActionTuple(
+                 continuous=stored_actions.continuous[idx],
+                 discrete=stored_actions.discrete[idx],
+             )
+             try:
+                 stored_action_probs = stored_take_action_outputs["log_probs"]
+                 if not isinstance(stored_action_probs, LogProbsTuple):
+                     stored_action_probs = stored_action_probs.to_log_probs_tuple()
+                 log_probs_tuple = LogProbsTuple(
+                     continuous=stored_action_probs.continuous[idx],
+                     discrete=stored_action_probs.discrete[idx],
+                 )
+             except KeyError:
+                 log_probs_tuple = LogProbsTuple.empty_log_probs()
+
+             action_mask = stored_decision_step.action_mask
+             prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :]
+
+             # Assemble teammate_obs. If none saved, then it will be an empty list.
+             group_statuses = []
+             for _id, _mate_status in self._group_status[global_group_id].items():
+                 if _id != global_agent_id:
+                     group_statuses.append(_mate_status)
+
+             experience = AgentExperience(
+                 obs=obs,
+                 reward=step.reward,
+                 done=done,
+                 action=action_tuple,
+                 action_probs=log_probs_tuple,
+                 action_mask=action_mask,
+                 prev_action=prev_action,
+                 interrupted=interrupted,
+                 memory=memory,
+                 group_status=group_statuses,
+                 group_reward=step.group_reward,
+             )
+             # Add the value outputs if needed
+             self._experience_buffers[global_agent_id].append(experience)
+             self._episode_rewards[global_agent_id] += step.reward
+             if not terminated:
+                 self._episode_steps[global_agent_id] += 1
+
+             # Add a trajectory segment to the buffer if terminal or the length has reached the time horizon
+             if (
+                 len(self._experience_buffers[global_agent_id])
+                 >= self._max_trajectory_length
+                 or terminated
+             ):
+                 next_obs = step.obs
+                 next_group_obs = []
+                 for _id, _obs in self._current_group_obs[global_group_id].items():
+                     if _id != global_agent_id:
+                         next_group_obs.append(_obs)
+
+                 trajectory = Trajectory(
+                     steps=self._experience_buffers[global_agent_id],
+                     agent_id=global_agent_id,
+                     next_obs=next_obs,
+                     next_group_obs=next_group_obs,
+                     behavior_id=self._behavior_id,
+                 )
+                 for traj_queue in self._trajectory_queues:
+                     traj_queue.put(trajectory)
+                 self._experience_buffers[global_agent_id] = []
+             if terminated:
+                 # Record episode length.
+                 self._stats_reporter.add_stat(
+                     "Environment/Episode Length",
+                     self._episode_steps.get(global_agent_id, 0),
+                 )
+                 self._clean_agent_data(global_agent_id)
+
+     def _clean_agent_data(self, global_id: GlobalAgentId) -> None:
+         """
+         Removes the data for an Agent.
+         """
+         self._safe_delete(self._experience_buffers, global_id)
+         self._safe_delete(self._last_take_action_outputs, global_id)
+         self._safe_delete(self._last_step_result, global_id)
+         self._safe_delete(self._episode_steps, global_id)
+         self._safe_delete(self._episode_rewards, global_id)
+         self.policy.remove_previous_action([global_id])
+         self.policy.remove_memories([global_id])
+
+     def _safe_delete(self, my_dictionary: Dict[Any, Any], key: Any) -> None:
+         """
+         Safe removes data from a dictionary. If not found,
+         don't delete.
+         """
+         if key in my_dictionary:
+             del my_dictionary[key]
+
+     def publish_trajectory_queue(
+         self, trajectory_queue: "AgentManagerQueue[Trajectory]"
+     ) -> None:
+         """
+         Adds a trajectory queue to the list of queues to publish to when this AgentProcessor
+         assembles a Trajectory
+         :param trajectory_queue: Trajectory queue to publish to.
+         """
+         self._trajectory_queues.append(trajectory_queue)
+
+     def end_episode(self) -> None:
+         """
+         Ends the episode, terminating the current trajectory and stopping stats collection for that
+         episode. Used for forceful reset (e.g. in curriculum or generalization training.)
+         """
+         all_gids = list(self._experience_buffers.keys())  # Need to make copy
+         for _gid in all_gids:
+             self._clean_agent_data(_gid)
+
+
+ class AgentManagerQueue(Generic[T]):
+     """
+     Queue used by the AgentManager. Note that we make our own class here because in most implementations
+     deque is sufficient and faster. However, if we want to switch to multiprocessing, we'll need to change
+     out this implementation.
+     """
+
+     class Empty(Exception):
+         """
+         Exception for when the queue is empty.
+         """
+
+         pass
+
+     def __init__(self, behavior_id: str, maxlen: int = 0):
+         """
+         Initializes an AgentManagerQueue. Note that we can give it a behavior_id so that it can be identified
+         separately from an AgentManager.
+         """
+         self._maxlen: int = maxlen
+         self._queue: queue.Queue = queue.Queue(maxsize=maxlen)
+         self._behavior_id = behavior_id
+
+     @property
+     def maxlen(self):
+         """
+         The maximum length of the queue.
+         :return: Maximum length of the queue.
+         """
+         return self._maxlen
+
+     @property
+     def behavior_id(self):
+         """
+         The Behavior ID of this queue.
+         :return: Behavior ID associated with the queue.
+         """
+         return self._behavior_id
+
+     def qsize(self) -> int:
+         """
+         Returns the approximate size of the queue. Note that values may differ
+         depending on the underlying queue implementation.
+         """
+         return self._queue.qsize()
+
+     def empty(self) -> bool:
+         return self._queue.empty()
+
+     def get_nowait(self) -> T:
+         """
+         Gets the next item from the queue, throwing an AgentManagerQueue.Empty exception
+         if the queue is empty.
+         """
+         try:
+             return self._queue.get_nowait()
+         except queue.Empty:
+             raise self.Empty("The AgentManagerQueue is empty.")
+
+     def put(self, item: T) -> None:
+         self._queue.put(item)
+
+
+ class AgentManager(AgentProcessor):
+     """
+     An AgentManager is an AgentProcessor that also holds a single trajectory and policy queue.
+     Note: this leaves room for adding AgentProcessors that publish multiple trajectory queues.
+     """
+
+     def __init__(
+         self,
+         policy: Policy,
+         behavior_id: str,
+         stats_reporter: StatsReporter,
+         max_trajectory_length: int = sys.maxsize,
+         threaded: bool = True,
+     ):
+         super().__init__(policy, behavior_id, stats_reporter, max_trajectory_length)
+         trajectory_queue_len = 20 if threaded else 0
+         self.trajectory_queue: AgentManagerQueue[Trajectory] = AgentManagerQueue(
+             self._behavior_id, maxlen=trajectory_queue_len
+         )
+         # NOTE: we make policy queues of infinite length to avoid lockups of the trainers.
+         # In the environment manager, we make sure to empty the policy queue before continuing to produce steps.
+         self.policy_queue: AgentManagerQueue[Policy] = AgentManagerQueue(
+             self._behavior_id, maxlen=0
+         )
+         self.publish_trajectory_queue(self.trajectory_queue)
+
+     def record_environment_stats(
+         self, env_stats: EnvironmentStats, worker_id: int
+     ) -> None:
+         """
+         Pass stats from the environment to the StatsReporter.
+         Depending on the StatsAggregationMethod, either StatsReporter.add_stat or StatsReporter.set_stat is used.
+         The worker_id is used to determine whether StatsReporter.set_stat should be used.
+
+         :param env_stats:
+         :param worker_id:
+         :return:
+         """
+         for stat_name, value_list in env_stats.items():
+             for val, agg_type in value_list:
+                 if agg_type == StatsAggregationMethod.AVERAGE:
+                     self._stats_reporter.add_stat(stat_name, val, agg_type)
+                 elif agg_type == StatsAggregationMethod.SUM:
+                     self._stats_reporter.add_stat(stat_name, val, agg_type)
+                 elif agg_type == StatsAggregationMethod.HISTOGRAM:
+                     self._stats_reporter.add_stat(stat_name, val, agg_type)
+                 elif agg_type == StatsAggregationMethod.MOST_RECENT:
+                     # In order to prevent conflicts between multiple environments,
+                     # only stats from the first environment are recorded.
+                     if worker_id == 0:
+                         self._stats_reporter.set_stat(stat_name, val)
+                 else:
+                     raise UnityTrainerException(
+                         f"Unknown StatsAggregationMethod encountered. {agg_type}"
+                     )
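A small sketch of the hand-off pattern above: the processor side put()s finished trajectories and the consumer drains with get_nowait() until the queue raises Empty. The behavior id and the string payload are placeholders; in training the items are Trajectory instances.

from mlagents.trainers.agent_processor import AgentManagerQueue

q: AgentManagerQueue = AgentManagerQueue(behavior_id="MyBehavior?team=0", maxlen=20)
q.put("trajectory placeholder")

while True:
    try:
        print(q.get_nowait())
    except AgentManagerQueue.Empty:
        break  # queue drained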
MLPY/Lib/site-packages/mlagents/trainers/behavior_id_utils.py ADDED
@@ -0,0 +1,64 @@
+ from typing import NamedTuple
+ from urllib.parse import urlparse, parse_qs
+ from mlagents_envs.base_env import AgentId, GroupId
+
+ GlobalGroupId = str
+ GlobalAgentId = str
+
+
+ class BehaviorIdentifiers(NamedTuple):
+     """
+     BehaviorIdentifiers is a named tuple of the identifiers that uniquely distinguish
+     an agent encountered in the trainer_controller. The named tuple consists of the
+     fully qualified behavior name, the brain name (corresponds to trainer
+     in the trainer controller) and the team id. In the future, this can be extended
+     to support further identifiers.
+     """
+
+     behavior_id: str
+     brain_name: str
+     team_id: int
+
+     @staticmethod
+     def from_name_behavior_id(name_behavior_id: str) -> "BehaviorIdentifiers":
+         """
+         Parses a name_behavior_id of the form name?team=0
+         into a BehaviorIdentifiers NamedTuple.
+         This allows you to access the brain name and team id of an agent
+         :param name_behavior_id: String of behavior params in HTTP format.
+         :returns: A BehaviorIdentifiers object.
+         """
+
+         parsed = urlparse(name_behavior_id)
+         name = parsed.path
+         ids = parse_qs(parsed.query)
+         team_id: int = 0
+         if "team" in ids:
+             team_id = int(ids["team"][0])
+         return BehaviorIdentifiers(
+             behavior_id=name_behavior_id, brain_name=name, team_id=team_id
+         )
+
+
+ def create_name_behavior_id(name: str, team_id: int) -> str:
+     """
+     Reconstructs fully qualified behavior name from name and team_id
+     :param name: brain name
+     :param team_id: team ID
+     :return: name_behavior_id
+     """
+     return name + "?team=" + str(team_id)
+
+
+ def get_global_agent_id(worker_id: int, agent_id: AgentId) -> GlobalAgentId:
+     """
+     Create an agent id that is unique across environment workers using the worker_id.
+     """
+     return f"agent_{worker_id}-{agent_id}"
+
+
+ def get_global_group_id(worker_id: int, group_id: GroupId) -> GlobalGroupId:
+     """
+     Create a group id that is unique across environment workers when using the worker_id.
+     """
+     return f"group_{worker_id}-{group_id}"
MLPY/Lib/site-packages/mlagents/trainers/buffer.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from collections.abc import MutableMapping
3
+ import enum
4
+ import itertools
5
+ from typing import BinaryIO, DefaultDict, List, Tuple, Union, Optional
6
+
7
+ import numpy as np
8
+ import h5py
9
+
10
+ from mlagents_envs.exception import UnityException
11
+
12
+ # Elements in the buffer can be np.ndarray, or in the case of teammate obs, actions, rewards,
13
+ # a List of np.ndarray. This is done so that we don't have duplicated np.ndarrays, only references.
14
+ BufferEntry = Union[np.ndarray, List[np.ndarray]]
15
+
16
+
17
+ class BufferException(UnityException):
18
+ """
19
+ Related to errors with the Buffer.
20
+ """
21
+
22
+ pass
23
+
24
+
25
+ class BufferKey(enum.Enum):
26
+ ACTION_MASK = "action_mask"
27
+ CONTINUOUS_ACTION = "continuous_action"
28
+ NEXT_CONT_ACTION = "next_continuous_action"
29
+ CONTINUOUS_LOG_PROBS = "continuous_log_probs"
30
+ DISCRETE_ACTION = "discrete_action"
31
+ NEXT_DISC_ACTION = "next_discrete_action"
32
+ DISCRETE_LOG_PROBS = "discrete_log_probs"
33
+ DONE = "done"
34
+ ENVIRONMENT_REWARDS = "environment_rewards"
35
+ MASKS = "masks"
36
+ MEMORY = "memory"
37
+ CRITIC_MEMORY = "critic_memory"
38
+ BASELINE_MEMORY = "poca_baseline_memory"
39
+ PREV_ACTION = "prev_action"
40
+
41
+ ADVANTAGES = "advantages"
42
+ DISCOUNTED_RETURNS = "discounted_returns"
43
+
44
+ GROUP_DONES = "group_dones"
45
+ GROUPMATE_REWARDS = "groupmate_reward"
46
+ GROUP_REWARD = "group_reward"
47
+ GROUP_CONTINUOUS_ACTION = "group_continuous_action"
48
+ GROUP_DISCRETE_ACTION = "group_discrete_aaction"
49
+ GROUP_NEXT_CONT_ACTION = "group_next_cont_action"
50
+ GROUP_NEXT_DISC_ACTION = "group_next_disc_action"
51
+
52
+
53
+ class ObservationKeyPrefix(enum.Enum):
54
+ OBSERVATION = "obs"
55
+ NEXT_OBSERVATION = "next_obs"
56
+
57
+ GROUP_OBSERVATION = "group_obs"
58
+ NEXT_GROUP_OBSERVATION = "next_group_obs"
59
+
60
+
61
+ class RewardSignalKeyPrefix(enum.Enum):
62
+ # Reward signals
63
+ REWARDS = "rewards"
64
+ VALUE_ESTIMATES = "value_estimates"
65
+ RETURNS = "returns"
66
+ ADVANTAGE = "advantage"
67
+ BASELINES = "baselines"
68
+
69
+
70
+ AgentBufferKey = Union[
71
+ BufferKey, Tuple[ObservationKeyPrefix, int], Tuple[RewardSignalKeyPrefix, str]
72
+ ]
73
+
74
+
75
+ class RewardSignalUtil:
76
+ @staticmethod
77
+ def rewards_key(name: str) -> AgentBufferKey:
78
+ return RewardSignalKeyPrefix.REWARDS, name
79
+
80
+ @staticmethod
81
+ def value_estimates_key(name: str) -> AgentBufferKey:
82
+ return RewardSignalKeyPrefix.RETURNS, name
83
+
84
+ @staticmethod
85
+ def returns_key(name: str) -> AgentBufferKey:
86
+ return RewardSignalKeyPrefix.RETURNS, name
87
+
88
+ @staticmethod
89
+ def advantage_key(name: str) -> AgentBufferKey:
90
+ return RewardSignalKeyPrefix.ADVANTAGE, name
91
+
92
+ @staticmethod
93
+ def baseline_estimates_key(name: str) -> AgentBufferKey:
94
+ return RewardSignalKeyPrefix.BASELINES, name
95
+
96
+
97
+ class AgentBufferField(list):
98
+ """
99
+ AgentBufferField is a list of numpy arrays, or List[np.ndarray] for group entries.
100
+ When an agent collects a field, you can add it to its AgentBufferField with the append method.
101
+ """
102
+
103
+ def __init__(self, *args, **kwargs):
104
+ self.padding_value = 0
105
+ super().__init__(*args, **kwargs)
106
+
107
+ def __str__(self) -> str:
108
+ return f"AgentBufferField: {super().__str__()}"
109
+
110
+ def __getitem__(self, index):
111
+ return_data = super().__getitem__(index)
112
+ if isinstance(return_data, list):
113
+ return AgentBufferField(return_data)
114
+ else:
115
+ return return_data
116
+
117
+ @property
118
+ def contains_lists(self) -> bool:
119
+ """
120
+ Checks whether this AgentBufferField contains List[np.ndarray].
121
+ """
122
+ return len(self) > 0 and isinstance(self[0], list)
123
+
124
+ def append(self, element: BufferEntry, padding_value: float = 0.0) -> None:
125
+ """
126
+ Adds an element to this list. Also lets you change the padding
127
+ type, so that it can be set on append (e.g. action_masks should
128
+ be padded with 1.)
129
+ :param element: The element to append to the list.
130
+ :param padding_value: The value used to pad when get_batch is called.
131
+ """
132
+ super().append(element)
133
+ self.padding_value = padding_value
134
+
135
+ def set(self, data: List[BufferEntry]) -> None:
136
+ """
137
+ Sets the list of BufferEntry to the input data
138
+ :param data: The BufferEntry list to be set.
139
+ """
140
+ self[:] = data
141
+
142
+ def get_batch(
143
+ self,
144
+ batch_size: int = None,
145
+ training_length: Optional[int] = 1,
146
+ sequential: bool = True,
147
+ ) -> List[BufferEntry]:
148
+ """
149
+ Retrieve the last batch_size elements of length training_length
150
+ from the list of np.array
151
+ :param batch_size: The number of elements to retrieve. If None:
152
+ All elements will be retrieved.
153
+ :param training_length: The length of the sequence to be retrieved. If
154
+ None: only takes one element.
155
+ :param sequential: If true and training_length is not None: the elements
156
+ will not repeat in the sequence. [a,b,c,d,e] with training_length = 2 and
157
+ sequential=True gives [[0,a],[b,c],[d,e]]. If sequential=False gives
158
+ [[a,b],[b,c],[c,d],[d,e]]
159
+ """
160
+ if training_length is None:
161
+ training_length = 1
162
+ if sequential:
163
+ # The sequences will not have overlapping elements (this involves padding)
164
+ leftover = len(self) % training_length
165
+ # leftover is the number of elements in the first sequence (this sequence might need 0 padding)
166
+ if batch_size is None:
167
+ # retrieve the maximum number of elements
168
+ batch_size = len(self) // training_length + 1 * (leftover != 0)
169
+ # The maximum number of sequences taken from a list of length len(self) without overlapping
170
+ # with padding is equal to batch_size
171
+ if batch_size > (len(self) // training_length + 1 * (leftover != 0)):
172
+ raise BufferException(
173
+ "The batch size and training length requested for get_batch where"
174
+ " too large given the current number of data points."
175
+ )
176
+ if batch_size * training_length > len(self):
177
+ if self.contains_lists:
178
+ padding = []
179
+ else:
180
+ # We want to duplicate the last value in the array, multiplied by the padding_value.
181
+ padding = np.array(self[-1], dtype=np.float32) * self.padding_value
182
+ return self[:] + [padding] * (training_length - leftover)
183
+
184
+ else:
185
+ return self[len(self) - batch_size * training_length :]
186
+ else:
187
+ # The sequences will have overlapping elements
188
+ if batch_size is None:
189
+ # retrieve the maximum number of elements
190
+ batch_size = len(self) - training_length + 1
191
+ # The number of sequences of length training_length taken from a list of len(self) elements
192
+ # with overlapping is equal to batch_size
193
+ if (len(self) - training_length + 1) < batch_size:
194
+ raise BufferException(
195
+ "The batch size and training length requested for get_batch where"
196
+ " too large given the current number of data points."
197
+ )
198
+ tmp_list: List[np.ndarray] = []
199
+ for end in range(len(self) - batch_size + 1, len(self) + 1):
200
+ tmp_list += self[end - training_length : end]
201
+ return tmp_list
202
+
203
+ def reset_field(self) -> None:
204
+ """
205
+ Resets the AgentBufferField
206
+ """
207
+ self[:] = []
208
+
209
+ def padded_to_batch(
210
+ self, pad_value: np.float = 0, dtype: np.dtype = np.float32
211
+ ) -> Union[np.ndarray, List[np.ndarray]]:
212
+ """
213
+ Converts this AgentBufferField (which is a List[BufferEntry]) into a numpy array
214
+ with first dimension equal to the length of this AgentBufferField. If this AgentBufferField
215
+ contains a List[List[BufferEntry]] (i.e., in the case of group observations), return a List
216
+ containing numpy arrays or tensors, of length equal to the maximum length of an entry. Missing
217
+ For entries with less than that length, the array will be padded with pad_value.
218
+ :param pad_value: Value to pad List AgentBufferFields, when there are less than the maximum
219
+ number of agents present.
220
+ :param dtype: Dtype of output numpy array.
221
+ :return: Numpy array or List of numpy arrays representing this AgentBufferField, where the first
222
+ dimension is equal to the length of the AgentBufferField.
223
+ """
224
+ if len(self) > 0 and not isinstance(self[0], list):
225
+ return np.asanyarray(self, dtype=dtype)
226
+
227
+ shape = None
228
+ for _entry in self:
229
+ # _entry could be an empty list if there are no group agents in this
230
+ # step. Find the first non-empty list and use that shape.
231
+ if _entry:
232
+ shape = _entry[0].shape
233
+ break
234
+ # If there were no groupmate agents in the entire batch, return an empty List.
235
+ if shape is None:
236
+ return []
237
+
238
+ # Convert to numpy array while padding with 0's
239
+ new_list = list(
240
+ map(
241
+ lambda x: np.asanyarray(x, dtype=dtype),
242
+ itertools.zip_longest(*self, fillvalue=np.full(shape, pad_value)),
243
+ )
244
+ )
245
+ return new_list
246
+
247
+ def to_ndarray(self):
248
+ """
249
+ Returns the AgentBufferField which is a list of numpy ndarrays (or List[np.ndarray]) as an ndarray.
250
+ """
251
+ return np.array(self)
252
+
253
+
254
+ class AgentBuffer(MutableMapping):
255
+ """
256
+ AgentBuffer contains a dictionary of AgentBufferFields. Each agent has his own AgentBuffer.
257
+ The keys correspond to the name of the field. Example: state, action
258
+ """
259
+
260
+ # Whether or not to validate the types of keys at runtime
261
+ # This should be off for training, but enabled for testing
262
+ CHECK_KEY_TYPES_AT_RUNTIME = False
263
+
264
+ def __init__(self):
265
+ self.last_brain_info = None
266
+ self.last_take_action_outputs = None
267
+ self._fields: DefaultDict[AgentBufferKey, AgentBufferField] = defaultdict(
268
+ AgentBufferField
269
+ )
270
+
271
+ def __str__(self):
272
+ return ", ".join([f"'{k}' : {str(self[k])}" for k in self._fields.keys()])
273
+
274
+ def reset_agent(self) -> None:
275
+ """
276
+ Resets the AgentBuffer
277
+ """
278
+ for f in self._fields.values():
279
+ f.reset_field()
280
+ self.last_brain_info = None
281
+ self.last_take_action_outputs = None
282
+
283
+ @staticmethod
284
+ def _check_key(key):
285
+ if isinstance(key, BufferKey):
286
+ return
287
+ if isinstance(key, tuple):
288
+ key0, key1 = key
289
+ if isinstance(key0, ObservationKeyPrefix):
290
+ if isinstance(key1, int):
291
+ return
292
+ raise KeyError(f"{key} has type ({type(key0)}, {type(key1)})")
293
+ if isinstance(key0, RewardSignalKeyPrefix):
294
+ if isinstance(key1, str):
295
+ return
296
+ raise KeyError(f"{key} has type ({type(key0)}, {type(key1)})")
297
+ raise KeyError(f"{key} is a {type(key)}")
298
+
299
+ @staticmethod
300
+ def _encode_key(key: AgentBufferKey) -> str:
301
+ """
302
+ Convert the key to a string representation so that it can be used for serialization.
303
+ """
304
+ if isinstance(key, BufferKey):
305
+ return key.value
306
+ prefix, suffix = key
307
+ return f"{prefix.value}:{suffix}"
308
+
309
+ @staticmethod
310
+ def _decode_key(encoded_key: str) -> AgentBufferKey:
311
+ """
312
+ Convert the string representation back to a key after serialization.
313
+ """
314
+ # Simple case: convert the string directly to a BufferKey
315
+ try:
316
+ return BufferKey(encoded_key)
317
+ except ValueError:
318
+ pass
319
+
320
+ # Not a simple key, so split into two parts
321
+ prefix_str, _, suffix_str = encoded_key.partition(":")
322
+
323
+ # See if it's an ObservationKeyPrefix first
324
+ try:
325
+ return ObservationKeyPrefix(prefix_str), int(suffix_str)
326
+ except ValueError:
327
+ pass
328
+
329
+ # If not, it had better be a RewardSignalKeyPrefix
330
+ try:
331
+ return RewardSignalKeyPrefix(prefix_str), suffix_str
332
+ except ValueError:
333
+ raise ValueError(f"Unable to convert {encoded_key} to an AgentBufferKey")
334
+
335
+ def __getitem__(self, key: AgentBufferKey) -> AgentBufferField:
336
+ if self.CHECK_KEY_TYPES_AT_RUNTIME:
337
+ self._check_key(key)
338
+ return self._fields[key]
339
+
340
+ def __setitem__(self, key: AgentBufferKey, value: AgentBufferField) -> None:
341
+ if self.CHECK_KEY_TYPES_AT_RUNTIME:
342
+ self._check_key(key)
343
+ self._fields[key] = value
344
+
345
+ def __delitem__(self, key: AgentBufferKey) -> None:
346
+ if self.CHECK_KEY_TYPES_AT_RUNTIME:
347
+ self._check_key(key)
348
+ self._fields.__delitem__(key)
349
+
350
+ def __iter__(self):
351
+ return self._fields.__iter__()
352
+
353
+ def __len__(self) -> int:
354
+ return self._fields.__len__()
355
+
356
+ def __contains__(self, key):
357
+ if self.CHECK_KEY_TYPES_AT_RUNTIME:
358
+ self._check_key(key)
359
+ return self._fields.__contains__(key)
360
+
361
+ def check_length(self, key_list: List[AgentBufferKey]) -> bool:
362
+ """
363
+ Some methods will require that some fields have the same length.
364
+ check_length will return true if the fields in key_list
365
+ have the same length.
366
+ :param key_list: The fields whose lengths will be compared.
367
+ """
368
+ if self.CHECK_KEY_TYPES_AT_RUNTIME:
369
+ for k in key_list:
370
+ self._check_key(k)
371
+
372
+ if len(key_list) < 2:
373
+ return True
374
+ length = None
375
+ for key in key_list:
376
+ if key not in self._fields:
377
+ return False
378
+ if (length is not None) and (length != len(self[key])):
379
+ return False
380
+ length = len(self[key])
381
+ return True
382
+
383
+ def shuffle(
384
+ self, sequence_length: int, key_list: List[AgentBufferKey] = None
385
+ ) -> None:
386
+ """
387
+ Shuffles the fields in key_list in a consistent way: The reordering will
388
+ be the same across fields.
389
+ :param key_list: The fields that must be shuffled.
390
+ """
391
+ if key_list is None:
392
+ key_list = list(self._fields.keys())
393
+ if not self.check_length(key_list):
394
+ raise BufferException(
395
+ "Unable to shuffle if the fields are not of the same length"
396
+ )
397
+ s = np.arange(len(self[key_list[0]]) // sequence_length)
398
+ np.random.shuffle(s)
399
+ for key in key_list:
400
+ buffer_field = self[key]
401
+ tmp: List[np.ndarray] = []
402
+ for i in s:
403
+ tmp += buffer_field[i * sequence_length : (i + 1) * sequence_length]
404
+ buffer_field.set(tmp)
405
+
406
+ def make_mini_batch(self, start: int, end: int) -> "AgentBuffer":
407
+ """
408
+ Creates a mini-batch from buffer.
409
+ :param start: Starting index of buffer.
410
+ :param end: Ending index of buffer.
411
+ :return: Dict of mini batch.
412
+ """
413
+ mini_batch = AgentBuffer()
414
+ for key, field in self._fields.items():
415
+ # slicing an AgentBufferField returns a List[Any]
416
+ mini_batch[key] = field[start:end] # type: ignore
417
+ return mini_batch
418
+
419
+ def sample_mini_batch(
420
+ self, batch_size: int, sequence_length: int = 1
421
+ ) -> "AgentBuffer":
422
+ """
423
+ Creates a mini-batch from a random start and end.
424
+ :param batch_size: number of elements to withdraw.
425
+ :param sequence_length: Length of sequences to sample.
426
+ Number of sequences to sample will be batch_size/sequence_length.
427
+ """
428
+ num_seq_to_sample = batch_size // sequence_length
429
+ mini_batch = AgentBuffer()
430
+ buff_len = self.num_experiences
431
+ num_sequences_in_buffer = buff_len // sequence_length
432
+ start_idxes = (
433
+ np.random.randint(num_sequences_in_buffer, size=num_seq_to_sample)
434
+ * sequence_length
435
+ ) # Sample random sequence starts
436
+ for key in self:
437
+ buffer_field = self[key]
438
+ mb_list = (buffer_field[i : i + sequence_length] for i in start_idxes)
439
+ # See comparison of ways to make a list from a list of lists here:
440
+ # https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-list-of-lists
441
+ mini_batch[key].set(list(itertools.chain.from_iterable(mb_list)))
442
+ return mini_batch
443
+
444
+ def save_to_file(self, file_object: BinaryIO) -> None:
445
+ """
446
+ Saves the AgentBuffer to a file-like object.
447
+ """
448
+ with h5py.File(file_object, "w") as write_file:
449
+ for key, data in self.items():
450
+ write_file.create_dataset(
451
+ self._encode_key(key), data=data, dtype="f", compression="gzip"
452
+ )
453
+
454
+ def load_from_file(self, file_object: BinaryIO) -> None:
455
+ """
456
+ Loads the AgentBuffer from a file-like object.
457
+ """
458
+ with h5py.File(file_object, "r") as read_file:
459
+ for key in list(read_file.keys()):
460
+ decoded_key = self._decode_key(key)
461
+ self[decoded_key] = AgentBufferField()
462
+ # extend() will convert the numpy array's first dimension into a list
463
+ self[decoded_key].extend(read_file[key][()])
464
+
465
+ def truncate(self, max_length: int, sequence_length: int = 1) -> None:
466
+ """
467
+ Truncates the buffer to a certain length.
468
+
469
+ This can be slow for large buffers. We compensate by cutting further than we need to, so that
470
+ we're not truncating at each update. Note that we must truncate an integer number of sequence_lengths.
471
+ :param max_length: The length at which to truncate the buffer.
472
+ """
473
+ current_length = self.num_experiences
474
+ # make max_length an integer number of sequence_lengths
475
+ max_length -= max_length % sequence_length
476
+ if current_length > max_length:
477
+ for _key in self.keys():
478
+ self[_key][:] = self[_key][current_length - max_length :]
479
+
480
+ def resequence_and_append(
481
+ self,
482
+ target_buffer: "AgentBuffer",
483
+ key_list: List[AgentBufferKey] = None,
484
+ batch_size: int = None,
485
+ training_length: int = None,
486
+ ) -> None:
487
+ """
488
+ Takes in a batch size and training length (sequence length), and appends this AgentBuffer to target_buffer
489
+ properly padded for LSTM use. Optionally, use key_list to restrict which fields are inserted into the new
490
+ buffer.
491
+ :param target_buffer: The buffer to which the samples will be appended.
492
+ :param key_list: The fields that must be added. If None: all fields will be appended.
493
+ :param batch_size: The number of elements that must be appended. If None: All of them will be.
494
+ :param training_length: The length of the samples that must be appended. If None: only takes one element.
495
+ """
496
+ if key_list is None:
497
+ key_list = list(self.keys())
498
+ if not self.check_length(key_list):
499
+ raise BufferException(
500
+ f"The fields {key_list} were not all of the same length"
501
+ )
502
+ for field_key in key_list:
503
+ target_buffer[field_key].extend(
504
+ self[field_key].get_batch(
505
+ batch_size=batch_size, training_length=training_length
506
+ )
507
+ )
508
+
509
+ @property
510
+ def num_experiences(self) -> int:
511
+ """
512
+ The number of agent experiences in the AgentBuffer, i.e. the length of the buffer.
513
+
514
+ An experience consists of one element across all of the fields of this AgentBuffer.
515
+ Note that these all have to be the same length, otherwise shuffle and resequence_and_append
516
+ will fail.
517
+ """
518
+ if self.values():
519
+ return len(next(iter(self.values())))
520
+ else:
521
+ return 0
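A minimal usage sketch of the AgentBuffer defined above (not part of the uploaded file): the reward values and batch size are illustrative, and the BufferKey members ENVIRONMENT_REWARDS and DONE are the same keys demo_loader.py below appends to.

    import numpy as np
    from mlagents.trainers.buffer import AgentBuffer, BufferKey

    buffer = AgentBuffer()
    for step in range(8):
        # One entry per experience; all fields must stay the same length.
        buffer[BufferKey.ENVIRONMENT_REWARDS].append(np.float32(0.1 * step))
        buffer[BufferKey.DONE].append(step == 7)

    print(buffer.num_experiences)              # 8
    mini_batch = buffer.sample_mini_batch(batch_size=4, sequence_length=1)
    print(mini_batch.num_experiences)          # 4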
MLPY/Lib/site-packages/mlagents/trainers/cli_utils.py ADDED
@@ -0,0 +1,331 @@
1
+ from typing import Set, Dict, Any, TextIO
2
+ import os
3
+ import yaml
4
+ from mlagents.trainers.exception import TrainerConfigError
5
+ from mlagents_envs.environment import UnityEnvironment
6
+ import argparse
7
+ from mlagents_envs import logging_util
8
+
9
+ logger = logging_util.get_logger(__name__)
10
+
11
+
12
+ class RaiseRemovedWarning(argparse.Action):
13
+ """
14
+ Internal custom Action to raise a warning when the argument is used.
15
+ """
16
+
17
+ def __init__(self, nargs=0, **kwargs):
18
+ super().__init__(nargs=nargs, **kwargs)
19
+
20
+ def __call__(self, arg_parser, namespace, values, option_string=None):
21
+ logger.warning(f"The command line argument {option_string} was removed.")
22
+
23
+
24
+ class DetectDefault(argparse.Action):
25
+ """
26
+ Internal custom Action to help detect arguments that aren't default.
27
+ """
28
+
29
+ non_default_args: Set[str] = set()
30
+
31
+ def __call__(self, arg_parser, namespace, values, option_string=None):
32
+ setattr(namespace, self.dest, values)
33
+ DetectDefault.non_default_args.add(self.dest)
34
+
35
+
36
+ class DetectDefaultStoreTrue(DetectDefault):
37
+ """
38
+ Internal class to help detect arguments that aren't default.
39
+ Used for store_true arguments.
40
+ """
41
+
42
+ def __init__(self, nargs=0, **kwargs):
43
+ super().__init__(nargs=nargs, **kwargs)
44
+
45
+ def __call__(self, arg_parser, namespace, values, option_string=None):
46
+ super().__call__(arg_parser, namespace, True, option_string)
47
+
48
+
49
+ class StoreConfigFile(argparse.Action):
50
+ """
51
+ Custom Action to store the config file location not as part of the CLI args.
52
+ This is because we want to maintain an equivalence between the config file's
53
+ contents and the args themselves.
54
+ """
55
+
56
+ trainer_config_path: str
57
+
58
+ def __call__(self, arg_parser, namespace, values, option_string=None):
59
+ delattr(namespace, self.dest)
60
+ StoreConfigFile.trainer_config_path = values
61
+
62
+
63
+ def _create_parser() -> argparse.ArgumentParser:
64
+ argparser = argparse.ArgumentParser(
65
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
66
+ )
67
+ argparser.add_argument(
68
+ "trainer_config_path", action=StoreConfigFile, nargs="?", default=None
69
+ )
70
+ argparser.add_argument(
71
+ "--env",
72
+ default=None,
73
+ dest="env_path",
74
+ help="Path to the Unity executable to train",
75
+ action=DetectDefault,
76
+ )
77
+ argparser.add_argument(
78
+ "--load",
79
+ default=False,
80
+ dest="load_model",
81
+ action=DetectDefaultStoreTrue,
82
+ help=argparse.SUPPRESS, # Deprecated but still usable for now.
83
+ )
84
+ argparser.add_argument(
85
+ "--resume",
86
+ default=False,
87
+ dest="resume",
88
+ action=DetectDefaultStoreTrue,
89
+ help="Whether to resume training from a checkpoint. Specify a --run-id to use this option. "
90
+ "If set, the training code loads an already trained model to initialize the neural network "
91
+ "before resuming training. This option is only valid when the models exist, and have the same "
92
+ "behavior names as the current agents in your scene.",
93
+ )
94
+ argparser.add_argument(
95
+ "--deterministic",
96
+ default=False,
97
+ dest="deterministic",
98
+ action=DetectDefaultStoreTrue,
99
+ help="Whether to select actions deterministically in policy. `dist.mean` for continuous action "
100
+ "space, and `dist.argmax` for discrete action space.",
101
+ )
102
+ argparser.add_argument(
103
+ "--force",
104
+ default=False,
105
+ dest="force",
106
+ action=DetectDefaultStoreTrue,
107
+ help="Whether to force-overwrite this run-id's existing summary and model data. (Without "
108
+ "this flag, attempting to train a model with a run-id that has been used before will throw "
109
+ "an error.)",
110
+ )
111
+ argparser.add_argument(
112
+ "--run-id",
113
+ default="ppo",
114
+ help="The identifier for the training run. This identifier is used to name the "
115
+ "subdirectories in which the trained model and summary statistics are saved as well "
116
+ "as the saved model itself. If you use TensorBoard to view the training statistics, "
117
+ "always set a unique run-id for each training run. (The statistics for all runs with the "
118
+ "same id are combined as if they were produced by the same session.)",
119
+ action=DetectDefault,
120
+ )
121
+ argparser.add_argument(
122
+ "--initialize-from",
123
+ metavar="RUN_ID",
124
+ default=None,
125
+ help="Specify a previously saved run ID from which to initialize the model. "
126
+ "This can be used, for instance, to fine-tune an existing model on a new environment. "
127
+ "Note that the previously saved models must have the same behavior parameters as your "
128
+ "current environment.",
129
+ action=DetectDefault,
130
+ )
131
+ argparser.add_argument(
132
+ "--seed",
133
+ default=-1,
134
+ type=int,
135
+ help="A number to use as a seed for the random number generator used by the training code",
136
+ action=DetectDefault,
137
+ )
138
+ argparser.add_argument(
139
+ "--train",
140
+ default=False,
141
+ dest="train_model",
142
+ action=DetectDefaultStoreTrue,
143
+ help=argparse.SUPPRESS,
144
+ )
145
+ argparser.add_argument(
146
+ "--inference",
147
+ default=False,
148
+ dest="inference",
149
+ action=DetectDefaultStoreTrue,
150
+ help="Whether to run in Python inference mode (i.e. no training). Use with --resume to load "
151
+ "a model trained with an existing run ID.",
152
+ )
153
+ argparser.add_argument(
154
+ "--base-port",
155
+ default=UnityEnvironment.BASE_ENVIRONMENT_PORT,
156
+ type=int,
157
+ help="The starting port for environment communication. Each concurrent Unity environment "
158
+ "instance will get assigned a port sequentially, starting from the base-port. Each instance "
159
+ "will use the port (base_port + worker_id), where the worker_id is sequential IDs given to "
160
+ "each instance from 0 to (num_envs - 1). Note that when training using the Editor rather "
161
+ "than an executable, the base port will be ignored.",
162
+ action=DetectDefault,
163
+ )
164
+ argparser.add_argument(
165
+ "--num-envs",
166
+ default=1,
167
+ type=int,
168
+ help="The number of concurrent Unity environment instances to collect experiences "
169
+ "from when training",
170
+ action=DetectDefault,
171
+ )
172
+
173
+ argparser.add_argument(
174
+ "--num-areas",
175
+ default=1,
176
+ type=int,
177
+ help="The number of parallel training areas in each Unity environment instance.",
178
+ action=DetectDefault,
179
+ )
180
+
181
+ argparser.add_argument(
182
+ "--debug",
183
+ default=False,
184
+ action=DetectDefaultStoreTrue,
185
+ help="Whether to enable debug-level logging for some parts of the code",
186
+ )
187
+ argparser.add_argument(
188
+ "--env-args",
189
+ default=None,
190
+ nargs=argparse.REMAINDER,
191
+ help="Arguments passed to the Unity executable. Be aware that the standalone build will also "
192
+ "process these as Unity Command Line Arguments. You should choose different argument names if "
193
+ "you want to create environment-specific arguments. All arguments after this flag will be "
194
+ "passed to the executable.",
195
+ action=DetectDefault,
196
+ )
197
+ argparser.add_argument(
198
+ "--max-lifetime-restarts",
199
+ default=10,
200
+ help="The max number of times a single Unity executable can crash over its lifetime before ml-agents exits. "
201
+ "Can be set to -1 if no limit is desired.",
202
+ action=DetectDefault,
203
+ )
204
+ argparser.add_argument(
205
+ "--restarts-rate-limit-n",
206
+ default=1,
207
+ help="The maximum number of times a single Unity executable can crash over a period of time (period set in "
208
+ "restarts-rate-limit-period-s). Can be set to -1 to not use rate limiting with restarts.",
209
+ action=DetectDefault,
210
+ )
211
+ argparser.add_argument(
212
+ "--restarts-rate-limit-period-s",
213
+ default=60,
214
+ help="The period of time --restarts-rate-limit-n applies to.",
215
+ action=DetectDefault,
216
+ )
217
+ argparser.add_argument(
218
+ "--torch",
219
+ default=False,
220
+ action=RaiseRemovedWarning,
221
+ help="(Removed) Use the PyTorch framework.",
222
+ )
223
+ argparser.add_argument(
224
+ "--tensorflow",
225
+ default=False,
226
+ action=RaiseRemovedWarning,
227
+ help="(Removed) Use the TensorFlow framework.",
228
+ )
229
+ argparser.add_argument(
230
+ "--results-dir",
231
+ default="results",
232
+ action=DetectDefault,
233
+ help="Results base directory",
234
+ )
235
+
236
+ eng_conf = argparser.add_argument_group(title="Engine Configuration")
237
+ eng_conf.add_argument(
238
+ "--width",
239
+ default=84,
240
+ type=int,
241
+ help="The width of the executable window of the environment(s) in pixels "
242
+ "(ignored for editor training).",
243
+ action=DetectDefault,
244
+ )
245
+ eng_conf.add_argument(
246
+ "--height",
247
+ default=84,
248
+ type=int,
249
+ help="The height of the executable window of the environment(s) in pixels "
250
+ "(ignored for editor training)",
251
+ action=DetectDefault,
252
+ )
253
+ eng_conf.add_argument(
254
+ "--quality-level",
255
+ default=5,
256
+ type=int,
257
+ help="The quality level of the environment(s). Equivalent to calling "
258
+ "QualitySettings.SetQualityLevel in Unity.",
259
+ action=DetectDefault,
260
+ )
261
+ eng_conf.add_argument(
262
+ "--time-scale",
263
+ default=20,
264
+ type=float,
265
+ help="The time scale of the Unity environment(s). Equivalent to setting "
266
+ "Time.timeScale in Unity.",
267
+ action=DetectDefault,
268
+ )
269
+ eng_conf.add_argument(
270
+ "--target-frame-rate",
271
+ default=-1,
272
+ type=int,
273
+ help="The target frame rate of the Unity environment(s). Equivalent to setting "
274
+ "Application.targetFrameRate in Unity.",
275
+ action=DetectDefault,
276
+ )
277
+ eng_conf.add_argument(
278
+ "--capture-frame-rate",
279
+ default=60,
280
+ type=int,
281
+ help="The capture frame rate of the Unity environment(s). Equivalent to setting "
282
+ "Time.captureFramerate in Unity.",
283
+ action=DetectDefault,
284
+ )
285
+ eng_conf.add_argument(
286
+ "--no-graphics",
287
+ default=False,
288
+ action=DetectDefaultStoreTrue,
289
+ help="Whether to run the Unity executable in no-graphics mode (i.e. without initializing "
290
+ "the graphics driver). Use this only if your agents don't use visual observations.",
291
+ )
292
+
293
+ torch_conf = argparser.add_argument_group(title="Torch Configuration")
294
+ torch_conf.add_argument(
295
+ "--torch-device",
296
+ default=None,
297
+ dest="device",
298
+ action=DetectDefault,
299
+ help='Settings for the default torch.device used in training, for example, "cpu", "cuda", or "cuda:0"',
300
+ )
301
+ return argparser
302
+
303
+
304
+ def load_config(config_path: str) -> Dict[str, Any]:
305
+ try:
306
+ with open(config_path) as data_file:
307
+ return _load_config(data_file)
308
+ except OSError:
309
+ abs_path = os.path.abspath(config_path)
310
+ raise TrainerConfigError(f"Config file could not be found at {abs_path}.")
311
+ except UnicodeDecodeError:
312
+ raise TrainerConfigError(
313
+ f"There was an error decoding Config file from {config_path}. "
314
+ f"Make sure your file is saved using UTF-8."
315
+ )
316
+
317
+
318
+ def _load_config(fp: TextIO) -> Dict[str, Any]:
319
+ """
320
+ Load the yaml config from the file-like object.
321
+ """
322
+ try:
323
+ return yaml.safe_load(fp)
324
+ except yaml.parser.ParserError as e:
325
+ raise TrainerConfigError(
326
+ "Error parsing yaml file. Please check for formatting errors. "
327
+ "A tool such as http://www.yamllint.com/ can be helpful with this."
328
+ ) from e
329
+
330
+
331
+ parser = _create_parser()
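A rough sketch of driving the module-level parser (the config path and option values are made up for illustration; note that the positional config path is captured on StoreConfigFile rather than on the returned namespace):

    from mlagents.trainers.cli_utils import DetectDefault, StoreConfigFile, parser

    args = parser.parse_args(["config/ppo/3DBall.yaml", "--run-id", "my_run", "--num-envs", "2"])
    print(StoreConfigFile.trainer_config_path)   # config/ppo/3DBall.yaml
    print(args.run_id, args.num_envs)            # my_run 2
    print(DetectDefault.non_default_args)        # {'run_id', 'num_envs'} (order may vary)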
MLPY/Lib/site-packages/mlagents/trainers/demo_loader.py ADDED
@@ -0,0 +1,246 @@
1
+ import os
2
+ from typing import List, Tuple
3
+ import numpy as np
4
+ from mlagents.trainers.buffer import AgentBuffer, BufferKey
5
+ from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (
6
+ AgentInfoActionPairProto,
7
+ )
8
+ from mlagents.trainers.trajectory import ObsUtil
9
+ from mlagents_envs.rpc_utils import behavior_spec_from_proto, steps_from_proto
10
+ from mlagents_envs.base_env import BehaviorSpec
11
+ from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
12
+ from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
13
+ DemonstrationMetaProto,
14
+ )
15
+ from mlagents_envs.timers import timed, hierarchical_timer
16
+ from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore
17
+ from google.protobuf.internal.encoder import _EncodeVarint # type: ignore
18
+
19
+
20
+ INITIAL_POS = 33
21
+ SUPPORTED_DEMONSTRATION_VERSIONS = frozenset([0, 1])
22
+
23
+
24
+ @timed
25
+ def make_demo_buffer(
26
+ pair_infos: List[AgentInfoActionPairProto],
27
+ behavior_spec: BehaviorSpec,
28
+ sequence_length: int,
29
+ ) -> AgentBuffer:
30
+ # Create and populate buffer using experiences
31
+ demo_raw_buffer = AgentBuffer()
32
+ demo_processed_buffer = AgentBuffer()
33
+ for idx, current_pair_info in enumerate(pair_infos):
34
+ if idx > len(pair_infos) - 2:
35
+ break
36
+ next_pair_info = pair_infos[idx + 1]
37
+ current_decision_step, current_terminal_step = steps_from_proto(
38
+ [current_pair_info.agent_info], behavior_spec
39
+ )
40
+ next_decision_step, next_terminal_step = steps_from_proto(
41
+ [next_pair_info.agent_info], behavior_spec
42
+ )
43
+ previous_action = (
44
+ np.array(
45
+ pair_infos[idx].action_info.vector_actions_deprecated, dtype=np.float32
46
+ )
47
+ * 0
48
+ )
49
+ if idx > 0:
50
+ previous_action = np.array(
51
+ pair_infos[idx - 1].action_info.vector_actions_deprecated,
52
+ dtype=np.float32,
53
+ )
54
+
55
+ next_done = len(next_terminal_step) == 1
56
+ next_reward = 0
57
+ if len(next_terminal_step) == 1:
58
+ next_reward = next_terminal_step.reward[0]
59
+ else:
60
+ next_reward = next_decision_step.reward[0]
61
+ current_obs = None
62
+ if len(current_terminal_step) == 1:
63
+ current_obs = list(current_terminal_step.values())[0].obs
64
+ else:
65
+ current_obs = list(current_decision_step.values())[0].obs
66
+
67
+ demo_raw_buffer[BufferKey.DONE].append(next_done)
68
+ demo_raw_buffer[BufferKey.ENVIRONMENT_REWARDS].append(next_reward)
69
+ for i, obs in enumerate(current_obs):
70
+ demo_raw_buffer[ObsUtil.get_name_at(i)].append(obs)
71
+ if (
72
+ len(current_pair_info.action_info.continuous_actions) == 0
73
+ and len(current_pair_info.action_info.discrete_actions) == 0
74
+ ):
75
+ if behavior_spec.action_spec.continuous_size > 0:
76
+ demo_raw_buffer[BufferKey.CONTINUOUS_ACTION].append(
77
+ current_pair_info.action_info.vector_actions_deprecated
78
+ )
79
+ else:
80
+ demo_raw_buffer[BufferKey.DISCRETE_ACTION].append(
81
+ current_pair_info.action_info.vector_actions_deprecated
82
+ )
83
+ else:
84
+ if behavior_spec.action_spec.continuous_size > 0:
85
+ demo_raw_buffer[BufferKey.CONTINUOUS_ACTION].append(
86
+ current_pair_info.action_info.continuous_actions
87
+ )
88
+ if behavior_spec.action_spec.discrete_size > 0:
89
+ demo_raw_buffer[BufferKey.DISCRETE_ACTION].append(
90
+ current_pair_info.action_info.discrete_actions
91
+ )
92
+ demo_raw_buffer[BufferKey.PREV_ACTION].append(previous_action)
93
+ if next_done:
94
+ demo_raw_buffer.resequence_and_append(
95
+ demo_processed_buffer, batch_size=None, training_length=sequence_length
96
+ )
97
+ demo_raw_buffer.reset_agent()
98
+ demo_raw_buffer.resequence_and_append(
99
+ demo_processed_buffer, batch_size=None, training_length=sequence_length
100
+ )
101
+ return demo_processed_buffer
102
+
103
+
104
+ @timed
105
+ def demo_to_buffer(
106
+ file_path: str, sequence_length: int, expected_behavior_spec: BehaviorSpec = None
107
+ ) -> Tuple[BehaviorSpec, AgentBuffer]:
108
+ """
109
+ Loads demonstration file and uses it to fill training buffer.
110
+ :param file_path: Location of demonstration file (.demo).
111
+ :param sequence_length: Length of trajectories to fill buffer.
112
+ :return:
113
+ """
114
+ behavior_spec, info_action_pair, _ = load_demonstration(file_path)
115
+ demo_buffer = make_demo_buffer(info_action_pair, behavior_spec, sequence_length)
116
+ if expected_behavior_spec:
117
+ # check action dimensions in demonstration match
118
+ if behavior_spec.action_spec != expected_behavior_spec.action_spec:
119
+ raise RuntimeError(
120
+ "The actions {} in demonstration do not match the policy's {}.".format(
121
+ behavior_spec.action_spec, expected_behavior_spec.action_spec
122
+ )
123
+ )
124
+ # check observations match
125
+ if len(behavior_spec.observation_specs) != len(
126
+ expected_behavior_spec.observation_specs
127
+ ):
128
+ raise RuntimeError(
129
+ "The demonstrations do not have the same number of observations as the policy."
130
+ )
131
+ else:
132
+ for i, (demo_obs, policy_obs) in enumerate(
133
+ zip(
134
+ behavior_spec.observation_specs,
135
+ expected_behavior_spec.observation_specs,
136
+ )
137
+ ):
138
+ if demo_obs.shape != policy_obs.shape:
139
+ raise RuntimeError(
140
+ f"The shape {demo_obs} for observation {i} in demonstration \
141
+ do not match the policy's {policy_obs}."
142
+ )
143
+ return behavior_spec, demo_buffer
144
+
145
+
146
+ def get_demo_files(path: str) -> List[str]:
147
+ """
148
+ Retrieves the demonstration file(s) from a path.
149
+ :param path: Path of demonstration file or directory.
150
+ :return: List of demonstration files
151
+
152
+ Raises errors if |path| is invalid.
153
+ """
154
+ if os.path.isfile(path):
155
+ if not path.endswith(".demo"):
156
+ raise ValueError("The path provided is not a '.demo' file.")
157
+ return [path]
158
+ elif os.path.isdir(path):
159
+ paths = [
160
+ os.path.join(path, name)
161
+ for name in os.listdir(path)
162
+ if name.endswith(".demo")
163
+ ]
164
+ if not paths:
165
+ raise ValueError("There are no '.demo' files in the provided directory.")
166
+ return paths
167
+ else:
168
+ raise FileNotFoundError(
169
+ f"The demonstration file or directory {path} does not exist."
170
+ )
171
+
172
+
173
+ @timed
174
+ def load_demonstration(
175
+ file_path: str,
176
+ ) -> Tuple[BehaviorSpec, List[AgentInfoActionPairProto], int]:
177
+ """
178
+ Loads and parses a demonstration file.
179
+ :param file_path: Location of demonstration file (.demo).
180
+ :return: BrainParameter and list of AgentInfoActionPairProto containing demonstration data.
181
+ """
182
+
183
+ # First 32 bytes of file dedicated to meta-data.
184
+ file_paths = get_demo_files(file_path)
185
+ behavior_spec = None
186
+ brain_param_proto = None
187
+ info_action_pairs = []
188
+ total_expected = 0
189
+ for _file_path in file_paths:
190
+ with open(_file_path, "rb") as fp:
191
+ with hierarchical_timer("read_file"):
192
+ data = fp.read()
193
+ next_pos, pos, obs_decoded = 0, 0, 0
194
+ while pos < len(data):
195
+ next_pos, pos = _DecodeVarint32(data, pos)
196
+ if obs_decoded == 0:
197
+ meta_data_proto = DemonstrationMetaProto()
198
+ meta_data_proto.ParseFromString(data[pos : pos + next_pos])
199
+ if (
200
+ meta_data_proto.api_version
201
+ not in SUPPORTED_DEMONSTRATION_VERSIONS
202
+ ):
203
+ raise RuntimeError(
204
+ f"Can't load Demonstration data from an unsupported version ({meta_data_proto.api_version})"
205
+ )
206
+ total_expected += meta_data_proto.number_steps
207
+ pos = INITIAL_POS
208
+ if obs_decoded == 1:
209
+ brain_param_proto = BrainParametersProto()
210
+ brain_param_proto.ParseFromString(data[pos : pos + next_pos])
211
+ pos += next_pos
212
+ if obs_decoded > 1:
213
+ agent_info_action = AgentInfoActionPairProto()
214
+ agent_info_action.ParseFromString(data[pos : pos + next_pos])
215
+ if behavior_spec is None:
216
+ behavior_spec = behavior_spec_from_proto(
217
+ brain_param_proto, agent_info_action.agent_info
218
+ )
219
+ info_action_pairs.append(agent_info_action)
220
+ if len(info_action_pairs) == total_expected:
221
+ break
222
+ pos += next_pos
223
+ obs_decoded += 1
224
+ if not behavior_spec:
225
+ raise RuntimeError(
226
+ f"No BrainParameters found in demonstration file at {file_path}."
227
+ )
228
+ return behavior_spec, info_action_pairs, total_expected
229
+
230
+
231
+ def write_delimited(f, message):
232
+ msg_string = message.SerializeToString()
233
+ msg_size = len(msg_string)
234
+ _EncodeVarint(f.write, msg_size)
235
+ f.write(msg_string)
236
+
237
+
238
+ def write_demo(demo_path, meta_data_proto, brain_param_proto, agent_info_protos):
239
+ with open(demo_path, "wb") as f:
240
+ # write metadata
241
+ write_delimited(f, meta_data_proto)
242
+ f.seek(INITIAL_POS)
243
+ write_delimited(f, brain_param_proto)
244
+
245
+ for agent in agent_info_protos:
246
+ write_delimited(f, agent)
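A hedged sketch of loading a demonstration with the functions above; "Demos/3DBall.demo" is a placeholder path, and sequence_length is assumed to match the trainer's configured sequence length:

    from mlagents.trainers.demo_loader import demo_to_buffer, load_demonstration

    behavior_spec, info_action_pairs, total_steps = load_demonstration("Demos/3DBall.demo")
    print(behavior_spec.action_spec, total_steps)

    behavior_spec, demo_buffer = demo_to_buffer("Demos/3DBall.demo", sequence_length=64)
    print(demo_buffer.num_experiences)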
MLPY/Lib/site-packages/mlagents/trainers/directory_utils.py ADDED
@@ -0,0 +1,76 @@
1
+ import os
2
+ from mlagents.trainers.exception import UnityTrainerException
3
+ from mlagents.trainers.settings import TrainerSettings
4
+ from mlagents.trainers.model_saver.torch_model_saver import DEFAULT_CHECKPOINT_NAME
5
+
6
+
7
+ def validate_existing_directories(
8
+ output_path: str, resume: bool, force: bool, init_path: str = None
9
+ ) -> None:
10
+ """
11
+ Validates that if the run_id model exists, we do not overwrite it unless --force is specified.
12
+ Throws an exception if resume isn't specified and run_id exists. Throws an exception
13
+ if --resume is specified and run-id was not found.
14
+ :param output_path: The output path for this run's model and summary data.
16
+ :param resume: Whether or not the --resume flag was passed.
17
+ :param force: Whether or not the --force flag was passed.
18
+ :param init_path: Path to run-id dir to initialize from
19
+ """
20
+
21
+ output_path_exists = os.path.isdir(output_path)
22
+
23
+ if output_path_exists:
24
+ if not resume and not force:
25
+ raise UnityTrainerException(
26
+ "Previous data from this run ID was found. "
27
+ "Either specify a new run ID, use --resume to resume this run, "
28
+ "or use the --force parameter to overwrite existing data."
29
+ )
30
+ else:
31
+ if resume:
32
+ raise UnityTrainerException(
33
+ "Previous data from this run ID was not found. "
34
+ "Train a new run by removing the --resume flag."
35
+ )
36
+
37
+ # Verify init path if specified.
38
+ if init_path is not None:
39
+ if not os.path.isdir(init_path):
40
+ raise UnityTrainerException(
41
+ "Could not initialize from {}. "
42
+ "Make sure models have already been saved with that run ID.".format(
43
+ init_path
44
+ )
45
+ )
46
+
47
+
48
+ def setup_init_path(
49
+ behaviors: TrainerSettings.DefaultTrainerDict, init_dir: str
50
+ ) -> None:
51
+ """
52
+ For each behavior, set up the full init_path to the checkpoint file to initialize the policy from.
53
+ :param behaviors: mapping from behavior_name to TrainerSettings
54
+ :param init_dir: Path to run-id dir to initialize from
55
+ """
56
+ for behavior_name, ts in behaviors.items():
57
+ if ts.init_path is None:
58
+ # set default if None
59
+ ts.init_path = os.path.join(
60
+ init_dir, behavior_name, DEFAULT_CHECKPOINT_NAME
61
+ )
62
+ elif not os.path.dirname(ts.init_path):
63
+ # update to full path if just the file name
64
+ ts.init_path = os.path.join(init_dir, behavior_name, ts.init_path)
65
+ _validate_init_full_path(ts.init_path)
66
+
67
+
68
+ def _validate_init_full_path(init_file: str) -> None:
69
+ """
70
+ Validate initialization path to be a .pt file
71
+ :param init_file: full path to initialization checkpoint file
72
+ """
73
+ if not (os.path.isfile(init_file) and init_file.endswith(".pt")):
74
+ raise UnityTrainerException(
75
+ f"Could not initialize from {init_file}. file does not exists or is not a `.pt` file"
76
+ )
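A small sketch of the --resume/--force guard implemented above; "results/my_run" is a placeholder output path:

    from mlagents.trainers.directory_utils import validate_existing_directories

    # Raises UnityTrainerException if results/my_run already exists and neither
    # resume nor force is set, or if resume is set but the directory is missing.
    validate_existing_directories("results/my_run", resume=False, force=True)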
MLPY/Lib/site-packages/mlagents/trainers/env_manager.py ADDED
@@ -0,0 +1,157 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ from typing import List, Dict, NamedTuple, Iterable, Tuple
4
+ from mlagents_envs.base_env import (
5
+ DecisionSteps,
6
+ TerminalSteps,
7
+ BehaviorSpec,
8
+ BehaviorName,
9
+ )
10
+ from mlagents_envs.side_channel.stats_side_channel import EnvironmentStats
11
+
12
+ from mlagents.trainers.policy import Policy
13
+ from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue
14
+ from mlagents.trainers.action_info import ActionInfo
15
+ from mlagents.trainers.settings import TrainerSettings
16
+ from mlagents_envs.logging_util import get_logger
17
+
18
+ AllStepResult = Dict[BehaviorName, Tuple[DecisionSteps, TerminalSteps]]
19
+ AllGroupSpec = Dict[BehaviorName, BehaviorSpec]
20
+
21
+ logger = get_logger(__name__)
22
+
23
+
24
+ class EnvironmentStep(NamedTuple):
25
+ current_all_step_result: AllStepResult
26
+ worker_id: int
27
+ brain_name_to_action_info: Dict[BehaviorName, ActionInfo]
28
+ environment_stats: EnvironmentStats
29
+
30
+ @property
31
+ def name_behavior_ids(self) -> Iterable[BehaviorName]:
32
+ return self.current_all_step_result.keys()
33
+
34
+ @staticmethod
35
+ def empty(worker_id: int) -> "EnvironmentStep":
36
+ return EnvironmentStep({}, worker_id, {}, {})
37
+
38
+
39
+ class EnvManager(ABC):
40
+ def __init__(self):
41
+ self.policies: Dict[BehaviorName, Policy] = {}
42
+ self.agent_managers: Dict[BehaviorName, AgentManager] = {}
43
+ self.first_step_infos: List[EnvironmentStep] = []
44
+
45
+ def set_policy(self, brain_name: BehaviorName, policy: Policy) -> None:
46
+ self.policies[brain_name] = policy
47
+ if brain_name in self.agent_managers:
48
+ self.agent_managers[brain_name].policy = policy
49
+
50
+ def set_agent_manager(
51
+ self, brain_name: BehaviorName, manager: AgentManager
52
+ ) -> None:
53
+ self.agent_managers[brain_name] = manager
54
+
55
+ @abstractmethod
56
+ def _step(self) -> List[EnvironmentStep]:
57
+ pass
58
+
59
+ @abstractmethod
60
+ def _reset_env(self, config: Dict = None) -> List[EnvironmentStep]:
61
+ pass
62
+
63
+ def reset(self, config: Dict = None) -> int:
64
+ for manager in self.agent_managers.values():
65
+ manager.end_episode()
66
+ # Save the first step infos, after the reset.
67
+ # They will be processed on the first advance().
68
+ self.first_step_infos = self._reset_env(config)
69
+ return len(self.first_step_infos)
70
+
71
+ @abstractmethod
72
+ def set_env_parameters(self, config: Dict = None) -> None:
73
+ """
74
+ Sends environment parameter settings to C# via the
75
+ EnvironmentParametersSideChannel.
76
+ :param config: Dict of environment parameter keys and values
77
+ """
78
+ pass
79
+
80
+ def on_training_started(
81
+ self, behavior_name: str, trainer_settings: TrainerSettings
82
+ ) -> None:
83
+ """
84
+ Handle training starting for a new behavior type. Generally nothing is necessary here.
85
+ :param behavior_name:
86
+ :param trainer_settings:
87
+ :return:
88
+ """
89
+ pass
90
+
91
+ @property
92
+ @abstractmethod
93
+ def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]:
94
+ pass
95
+
96
+ @abstractmethod
97
+ def close(self):
98
+ pass
99
+
100
+ def get_steps(self) -> List[EnvironmentStep]:
101
+ """
102
+ Updates the policies, steps the environments, and returns the step information from the environments.
103
+ Calling code should pass the returned EnvironmentSteps to process_steps() after calling this.
104
+ :return: The list of EnvironmentSteps
105
+ """
106
+ # If we had just reset, process the first EnvironmentSteps.
107
+ # Note that we do it here instead of in reset() so that on the very first reset(),
108
+ # we can create the needed AgentManagers before calling advance() and processing the EnvironmentSteps.
109
+ if self.first_step_infos:
110
+ self._process_step_infos(self.first_step_infos)
111
+ self.first_step_infos = []
112
+ # Get new policies if found. Always get the latest policy.
113
+ for brain_name in self.agent_managers.keys():
114
+ _policy = None
115
+ try:
116
+ # We make sure to empty the policy queue before continuing to produce steps.
117
+ # This halts the trainers until the policy queue is empty.
118
+ while True:
119
+ _policy = self.agent_managers[brain_name].policy_queue.get_nowait()
120
+ except AgentManagerQueue.Empty:
121
+ if _policy is not None:
122
+ self.set_policy(brain_name, _policy)
123
+ # Step the environments
124
+ new_step_infos = self._step()
125
+ return new_step_infos
126
+
127
+ def process_steps(self, new_step_infos: List[EnvironmentStep]) -> int:
128
+ # Add to AgentProcessor
129
+ num_step_infos = self._process_step_infos(new_step_infos)
130
+ return num_step_infos
131
+
132
+ def _process_step_infos(self, step_infos: List[EnvironmentStep]) -> int:
133
+ for step_info in step_infos:
134
+ for name_behavior_id in step_info.name_behavior_ids:
135
+ if name_behavior_id not in self.agent_managers:
136
+ logger.warning(
137
+ "Agent manager was not created for behavior id {}.".format(
138
+ name_behavior_id
139
+ )
140
+ )
141
+ continue
142
+ decision_steps, terminal_steps = step_info.current_all_step_result[
143
+ name_behavior_id
144
+ ]
145
+ self.agent_managers[name_behavior_id].add_experiences(
146
+ decision_steps,
147
+ terminal_steps,
148
+ step_info.worker_id,
149
+ step_info.brain_name_to_action_info.get(
150
+ name_behavior_id, ActionInfo.empty()
151
+ ),
152
+ )
153
+
154
+ self.agent_managers[name_behavior_id].record_environment_stats(
155
+ step_info.environment_stats, step_info.worker_id
156
+ )
157
+ return len(step_infos)
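EnvManager is abstract; concrete managers elsewhere in this package supply the actual stepping logic. The NullEnvManager below is a made-up, do-nothing subclass meant only to show the methods a subclass must implement and the reset/get_steps/process_steps driving order:

    from typing import Dict, List

    from mlagents.trainers.env_manager import EnvManager, EnvironmentStep

    class NullEnvManager(EnvManager):
        """Does nothing; illustrates the abstract surface a subclass must cover."""

        def _step(self) -> List[EnvironmentStep]:
            return [EnvironmentStep.empty(worker_id=0)]

        def _reset_env(self, config: Dict = None) -> List[EnvironmentStep]:
            return [EnvironmentStep.empty(worker_id=0)]

        def set_env_parameters(self, config: Dict = None) -> None:
            pass

        @property
        def training_behaviors(self):
            return {}

        def close(self):
            pass

    manager = NullEnvManager()
    manager.reset()               # caches the post-reset EnvironmentSteps
    steps = manager.get_steps()   # flushes the cached steps, then calls _step()
    manager.process_steps(steps)  # routes results to any registered AgentManagers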
MLPY/Lib/site-packages/mlagents/trainers/environment_parameter_manager.py ADDED
@@ -0,0 +1,186 @@
1
+ from typing import Dict, List, Tuple, Optional
2
+ from mlagents.trainers.settings import (
3
+ EnvironmentParameterSettings,
4
+ ParameterRandomizationSettings,
5
+ )
6
+ from collections import defaultdict
7
+ from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
8
+
9
+ from mlagents_envs.logging_util import get_logger
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
+ class EnvironmentParameterManager:
15
+ def __init__(
16
+ self,
17
+ settings: Optional[Dict[str, EnvironmentParameterSettings]] = None,
18
+ run_seed: int = -1,
19
+ restore: bool = False,
20
+ ):
21
+ """
22
+ EnvironmentParameterManager manages all the environment parameters of a training
23
+ session. It determines when parameters should change and gives access to the
24
+ current sampler of each parameter.
25
+ :param settings: A dictionary from environment parameter to
26
+ EnvironmentParameterSettings.
27
+ :param run_seed: When the seed is not provided for an environment parameter,
28
+ this seed will be used instead.
29
+ :param restore: If true, the EnvironmentParameterManager will use the
30
+ GlobalTrainingStatus to try and reload the lesson status of each environment
31
+ parameter.
32
+ """
33
+ if settings is None:
34
+ settings = {}
35
+ self._dict_settings = settings
36
+ for parameter_name in self._dict_settings.keys():
37
+ initial_lesson = GlobalTrainingStatus.get_parameter_state(
38
+ parameter_name, StatusType.LESSON_NUM
39
+ )
40
+ if initial_lesson is None or not restore:
41
+ GlobalTrainingStatus.set_parameter_state(
42
+ parameter_name, StatusType.LESSON_NUM, 0
43
+ )
44
+ self._smoothed_values: Dict[str, float] = defaultdict(float)
45
+ for key in self._dict_settings.keys():
46
+ self._smoothed_values[key] = 0.0
47
+ # Update the seeds of the samplers
48
+ self._set_sampler_seeds(run_seed)
49
+
50
+ def _set_sampler_seeds(self, seed):
51
+ """
52
+ Sets the seeds for the samplers (if no seed was already present). Note that
53
+ using the provided seed.
54
+ """
55
+ offset = 0
56
+ for settings in self._dict_settings.values():
57
+ for lesson in settings.curriculum:
58
+ if lesson.value.seed == -1:
59
+ lesson.value.seed = seed + offset
60
+ offset += 1
61
+
62
+ def get_minimum_reward_buffer_size(self, behavior_name: str) -> int:
63
+ """
64
+ Calculates the minimum size of the reward buffer a behavior must use. This
65
+ method uses the 'min_lesson_length' sampler_parameter to determine this value.
66
+ :param behavior_name: The name of the behavior the minimum reward buffer
67
+ size corresponds to.
68
+ """
69
+ result = 1
70
+ for settings in self._dict_settings.values():
71
+ for lesson in settings.curriculum:
72
+ if lesson.completion_criteria is not None:
73
+ if lesson.completion_criteria.behavior == behavior_name:
74
+ result = max(
75
+ result, lesson.completion_criteria.min_lesson_length
76
+ )
77
+ return result
78
+
79
+ def get_current_samplers(self) -> Dict[str, ParameterRandomizationSettings]:
80
+ """
81
+ Creates a dictionary from environment parameter name to their corresponding
82
+ ParameterRandomizationSettings. If curriculum is used, the
83
+ ParameterRandomizationSettings corresponds to the sampler of the current lesson.
84
+ """
85
+ samplers: Dict[str, ParameterRandomizationSettings] = {}
86
+ for param_name, settings in self._dict_settings.items():
87
+ lesson_num = GlobalTrainingStatus.get_parameter_state(
88
+ param_name, StatusType.LESSON_NUM
89
+ )
90
+ lesson = settings.curriculum[lesson_num]
91
+ samplers[param_name] = lesson.value
92
+ return samplers
93
+
94
+ def get_current_lesson_number(self) -> Dict[str, int]:
95
+ """
96
+ Creates a dictionary from environment parameter to the current lesson number.
97
+ If not using curriculum, this number is always 0 for that environment parameter.
98
+ """
99
+ result: Dict[str, int] = {}
100
+ for parameter_name in self._dict_settings.keys():
101
+ result[parameter_name] = GlobalTrainingStatus.get_parameter_state(
102
+ parameter_name, StatusType.LESSON_NUM
103
+ )
104
+ return result
105
+
106
+ def log_current_lesson(self, parameter_name: Optional[str] = None) -> None:
107
+ """
108
+ Logs the current lesson number and sampler value of the parameter with name
109
+ parameter_name. If no parameter_name is provided, the values and lesson
110
+ numbers of all parameters will be displayed.
111
+ """
112
+ if parameter_name is not None:
113
+ settings = self._dict_settings[parameter_name]
114
+ lesson_number = GlobalTrainingStatus.get_parameter_state(
115
+ parameter_name, StatusType.LESSON_NUM
116
+ )
117
+ lesson_name = settings.curriculum[lesson_number].name
118
+ lesson_value = settings.curriculum[lesson_number].value
119
+ logger.info(
120
+ f"Parameter '{parameter_name}' is in lesson '{lesson_name}' "
121
+ f"and has value '{lesson_value}'."
122
+ )
123
+ else:
124
+ for parameter_name, settings in self._dict_settings.items():
125
+ lesson_number = GlobalTrainingStatus.get_parameter_state(
126
+ parameter_name, StatusType.LESSON_NUM
127
+ )
128
+ lesson_name = settings.curriculum[lesson_number].name
129
+ lesson_value = settings.curriculum[lesson_number].value
130
+ logger.info(
131
+ f"Parameter '{parameter_name}' is in lesson '{lesson_name}' "
132
+ f"and has value '{lesson_value}'."
133
+ )
134
+
135
+ def update_lessons(
136
+ self,
137
+ trainer_steps: Dict[str, int],
138
+ trainer_max_steps: Dict[str, int],
139
+ trainer_reward_buffer: Dict[str, List[float]],
140
+ ) -> Tuple[bool, bool]:
141
+ """
142
+ Given progress metrics, calculates if at least one environment parameter is
143
+ in a new lesson and if at least one environment parameter requires the env
144
+ to reset.
145
+ :param trainer_steps: A dictionary from behavior_name to the number of training
146
+ steps this behavior's trainer has performed.
147
+ :param trainer_max_steps: A dictionary from behavior_name to the maximum number
148
+ of training steps this behavior's trainer has performed.
149
+ :param trainer_reward_buffer: A dictionary from behavior_name to the list of
150
+ the most recent episode returns for this behavior's trainer.
151
+ :returns: A tuple of two booleans : (True if any lesson has changed, True if
152
+ environment needs to reset)
153
+ """
154
+ must_reset = False
155
+ updated = False
156
+ for param_name, settings in self._dict_settings.items():
157
+ lesson_num = GlobalTrainingStatus.get_parameter_state(
158
+ param_name, StatusType.LESSON_NUM
159
+ )
160
+ next_lesson_num = lesson_num + 1
161
+ lesson = settings.curriculum[lesson_num]
162
+ if (
163
+ lesson.completion_criteria is not None
164
+ and len(settings.curriculum) > next_lesson_num
165
+ ):
166
+ behavior_to_consider = lesson.completion_criteria.behavior
167
+ if behavior_to_consider in trainer_steps:
168
+ (
169
+ must_increment,
170
+ new_smoothing,
171
+ ) = lesson.completion_criteria.need_increment(
172
+ float(trainer_steps[behavior_to_consider])
173
+ / float(trainer_max_steps[behavior_to_consider]),
174
+ trainer_reward_buffer[behavior_to_consider],
175
+ self._smoothed_values[param_name],
176
+ )
177
+ self._smoothed_values[param_name] = new_smoothing
178
+ if must_increment:
179
+ GlobalTrainingStatus.set_parameter_state(
180
+ param_name, StatusType.LESSON_NUM, next_lesson_num
181
+ )
182
+ self.log_current_lesson(param_name)
183
+ updated = True
184
+ if lesson.completion_criteria.require_reset:
185
+ must_reset = True
186
+ return updated, must_reset
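A minimal sketch of the manager with no curriculum or randomization configured (settings=None), in which case every query is a no-op:

    from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager

    manager = EnvironmentParameterManager(settings=None, run_seed=42)
    print(manager.get_current_samplers())        # {}
    print(manager.get_current_lesson_number())   # {}
    updated, must_reset = manager.update_lessons(
        trainer_steps={}, trainer_max_steps={}, trainer_reward_buffer={}
    )
    print(updated, must_reset)                   # False False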
MLPY/Lib/site-packages/mlagents/trainers/exception.py ADDED
@@ -0,0 +1,75 @@
1
+ """
2
+ Contains exceptions for the trainers package.
3
+ """
4
+
5
+
6
+ class TrainerError(Exception):
7
+ """
8
+ Any error related to the trainers in the ML-Agents Toolkit.
9
+ """
10
+
11
+ pass
12
+
13
+
14
+ class TrainerConfigError(Exception):
15
+ """
16
+ Any error related to the configuration of trainers in the ML-Agents Toolkit.
17
+ """
18
+
19
+ pass
20
+
21
+
22
+ class TrainerConfigWarning(Warning):
23
+ """
24
+ Any warning related to the configuration of trainers in the ML-Agents Toolkit.
25
+ """
26
+
27
+ pass
28
+
29
+
30
+ class CurriculumError(TrainerError):
31
+ """
32
+ Any error related to training with a curriculum.
33
+ """
34
+
35
+ pass
36
+
37
+
38
+ class CurriculumLoadingError(CurriculumError):
39
+ """
40
+ Any error related to loading the Curriculum config file.
41
+ """
42
+
43
+ pass
44
+
45
+
46
+ class CurriculumConfigError(CurriculumError):
47
+ """
48
+ Any error related to processing the Curriculum config file.
49
+ """
50
+
51
+ pass
52
+
53
+
54
+ class MetaCurriculumError(TrainerError):
55
+ """
56
+ Any error related to the configuration of a metacurriculum.
57
+ """
58
+
59
+ pass
60
+
61
+
62
+ class SamplerException(TrainerError):
63
+ """
64
+ Related to errors with the sampler actions.
65
+ """
66
+
67
+ pass
68
+
69
+
70
+ class UnityTrainerException(TrainerError):
71
+ """
72
+ Related to errors with the Trainer.
73
+ """
74
+
75
+ pass
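Illustrative use of the hierarchy above: UnityTrainerException, CurriculumError, MetaCurriculumError and SamplerException all derive from TrainerError and can be caught together, while TrainerConfigError derives directly from Exception:

    from mlagents.trainers.exception import TrainerError, UnityTrainerException

    try:
        raise UnityTrainerException("illustrative failure")
    except TrainerError as err:
        print(f"trainer error: {err}")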
MLPY/Lib/site-packages/mlagents/trainers/ghost/__init__.py ADDED
File without changes