File size: 2,717 Bytes
e11e4fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import sys
from typing import List

# importlib.metadata is new in python3.8
# We use the backport for older python versions.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata  # pylint: disable=E0611

from mlagents.trainers.stats import StatsWriter

from mlagents_envs import logging_util
from mlagents.plugins import ML_AGENTS_STATS_WRITER
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter


# Module-level logger obtained from mlagents_envs.logging_util; used in this
# module to report problems discovering or initializing StatsWriter plugins.
logger = logging_util.get_logger(__name__)


def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
    """
    Build the StatsWriters that mlagents-learn always uses.

    The returned list contains, in this order:
    * a TensorboardWriter writing to ``checkpoint_settings.write_path``. It
      clears past data unless the run is being resumed, and hides the
      "Is Training" and "Step" keys;
    * a GaugeWriter, to record internal stats;
    * a ConsoleWriter, to output to stdout.
    """
    ckpt = run_options.checkpoint_settings
    tensorboard_writer = TensorboardWriter(
        ckpt.write_path,
        clear_past_data=not ckpt.resume,
        hidden_keys=["Is Training", "Step"],
    )
    writers: List[StatsWriter] = [tensorboard_writer]
    writers.append(GaugeWriter())
    writers.append(ConsoleWriter())
    return writers


def _stats_writer_entry_points() -> List[importlib_metadata.EntryPoint]:
    """
    Return the entry points registered under the ML_AGENTS_STATS_WRITER group.

    ``importlib_metadata.entry_points()`` has two return shapes:
    * up to Python 3.9 (and older backports) it is a plain dict keyed by
      group name;
    * from Python 3.10 it is an object exposing ``select(group=...)``.
      From 3.12 that object no longer supports ``group in ...`` or
      ``[group]`` lookups.

    Both shapes are handled here. An empty list means no entry points were
    found for the group.
    """
    all_entry_points = importlib_metadata.entry_points()
    if hasattr(all_entry_points, "select"):
        return list(all_entry_points.select(group=ML_AGENTS_STATS_WRITER))
    return list(all_entry_points.get(ML_AGENTS_STATS_WRITER, []))


def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]:
    """
    Load and evaluate every StatsWriter plugin and return all the writers they produce.

    Plugins are discovered under the ML_AGENTS_STATS_WRITER entry-point group,
    which is expected to include the default plugin as well. Each entry point
    is loaded and called with ``run_options``. The StatsWriters it returns are
    concatenated into the result.

    :param run_options: The run configuration passed to every plugin.
    :return: All StatsWriters produced by the plugins that initialized
        successfully. If no entry points are found for the group, a warning is
        logged and ``get_default_stats_writers(run_options)`` is returned
        instead.
    """
    entry_points = _stats_writer_entry_points()
    if not entry_points:
        logger.warning(
            f"Unable to find any entry points for {ML_AGENTS_STATS_WRITER}, even the default ones. "
            "Uninstalling and reinstalling ml-agents via pip should resolve. "
            "Using default plugins for now."
        )
        return get_default_stats_writers(run_options)

    all_stats_writers: List[StatsWriter] = []
    for entry_point in entry_points:
        try:
            logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}")
            plugin_func = entry_point.load()
            plugin_stats_writers = plugin_func(run_options)
            logger.debug(
                f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}"
            )
            all_stats_writers += plugin_stats_writers
        except Exception:
            # Catch errors from setting up the plugin, so that bad user code
            # doesn't break things. This is Exception rather than BaseException
            # so that Ctrl-C (KeyboardInterrupt) and SystemExit still propagate.
            logger.exception(
                f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used."
            )
    return all_stats_writers