Spaces:
Running
Running
import sys | |
from typing import List | |
# importlib.metadata is new in python3.8 | |
# We use the backport for older python versions. | |
if sys.version_info < (3, 8): | |
import importlib_metadata | |
else: | |
import importlib.metadata as importlib_metadata # pylint: disable=E0611 | |
from mlagents.trainers.stats import StatsWriter | |
from mlagents_envs import logging_util | |
from mlagents.plugins import ML_AGENTS_STATS_WRITER | |
from mlagents.trainers.settings import RunOptions | |
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter | |
logger = logging_util.get_logger(__name__) | |
def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]: | |
""" | |
The StatsWriters that mlagents-learn always uses: | |
* A TensorboardWriter to write information to TensorBoard | |
* A GaugeWriter to record our internal stats | |
* A ConsoleWriter to output to stdout. | |
""" | |
checkpoint_settings = run_options.checkpoint_settings | |
return [ | |
TensorboardWriter( | |
checkpoint_settings.write_path, | |
clear_past_data=not checkpoint_settings.resume, | |
hidden_keys=["Is Training", "Step"], | |
), | |
GaugeWriter(), | |
ConsoleWriter(), | |
] | |
def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]: | |
""" | |
Registers all StatsWriter plugins (including the default one), | |
and evaluates them, and returns the list of all the StatsWriter implementations. | |
""" | |
all_stats_writers: List[StatsWriter] = [] | |
if ML_AGENTS_STATS_WRITER not in importlib_metadata.entry_points(): | |
logger.warning( | |
f"Unable to find any entry points for {ML_AGENTS_STATS_WRITER}, even the default ones. " | |
"Uninstalling and reinstalling ml-agents via pip should resolve. " | |
"Using default plugins for now." | |
) | |
return get_default_stats_writers(run_options) | |
entry_points = importlib_metadata.entry_points()[ML_AGENTS_STATS_WRITER] | |
for entry_point in entry_points: | |
try: | |
logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}") | |
plugin_func = entry_point.load() | |
plugin_stats_writers = plugin_func(run_options) | |
logger.debug( | |
f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}" | |
) | |
all_stats_writers += plugin_stats_writers | |
except BaseException: | |
# Catch all exceptions from setting up the plugin, so that bad user code doesn't break things. | |
logger.exception( | |
f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used." | |
) | |
return all_stats_writers | |