Spaces:
Running
Running
File size: 2,717 Bytes
e11e4fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import sys
from typing import List
# importlib.metadata is new in python3.8
# We use the backport for older python versions.
if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata # pylint: disable=E0611
from mlagents.trainers.stats import StatsWriter
from mlagents_envs import logging_util
from mlagents.plugins import ML_AGENTS_STATS_WRITER
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter
# Module-level logger, obtained from mlagents_envs.logging_util and keyed by this module's name.
logger = logging_util.get_logger(__name__)
def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
    """
    Build the StatsWriters that mlagents-learn always uses.

    The returned list contains, in this order:
    * a TensorboardWriter rooted at ``checkpoint_settings.write_path``; past data
      is cleared unless the run is resuming, and "Is Training" and "Step" are
      passed as its ``hidden_keys``;
    * a GaugeWriter, to record internal stats;
    * a ConsoleWriter, to output to stdout.
    """
    settings = run_options.checkpoint_settings
    # Only keep previously written TensorBoard data when resuming a run.
    wipe_old_data = not settings.resume
    tensorboard_writer = TensorboardWriter(
        settings.write_path,
        clear_past_data=wipe_old_data,
        hidden_keys=["Is Training", "Step"],
    )
    gauge_writer = GaugeWriter()
    console_writer = ConsoleWriter()
    return [tensorboard_writer, gauge_writer, console_writer]
def _get_stats_writer_entry_points() -> list:
    """
    Return the entry points registered under the ML_AGENTS_STATS_WRITER group.

    The return type of ``importlib_metadata.entry_points()`` differs across
    versions:
    * Python 3.8/3.9 and old importlib_metadata backports return a plain dict
      mapping group name -> entry points.
    * Newer versions return objects exposing ``select()``.

    Python 3.12 removed the dict-style interface. There, ``group in
    entry_points()`` is always False, so prefer ``select`` when it exists.
    """
    all_entry_points = importlib_metadata.entry_points()
    if hasattr(all_entry_points, "select"):
        return list(all_entry_points.select(group=ML_AGENTS_STATS_WRITER))
    return list(all_entry_points.get(ML_AGENTS_STATS_WRITER, ()))


def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]:
    """
    Registers all StatsWriter plugins (including the default one),
    evaluates them, and returns the list of all the StatsWriter implementations.

    Each entry point in the ML_AGENTS_STATS_WRITER group is loaded and called
    with ``run_options``. The StatsWriters it returns are concatenated into
    the result.

    If no entry points are found at all, a warning is logged and
    ``get_default_stats_writers(run_options)`` is returned instead.

    A plugin that raises while loading or initializing is logged and skipped,
    so one bad plugin does not prevent the others from being used.
    """
    entry_points = _get_stats_writer_entry_points()
    if not entry_points:
        logger.warning(
            f"Unable to find any entry points for {ML_AGENTS_STATS_WRITER}, even the default ones. "
            "Uninstalling and reinstalling ml-agents via pip should resolve. "
            "Using default plugins for now."
        )
        return get_default_stats_writers(run_options)

    all_stats_writers: List[StatsWriter] = []
    for entry_point in entry_points:
        try:
            logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}")
            plugin_func = entry_point.load()
            plugin_stats_writers = plugin_func(run_options)
            logger.debug(
                f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}"
            )
            all_stats_writers += plugin_stats_writers
        except Exception:
            # Catch errors from setting up the plugin, so that bad user code doesn't
            # break things. Exception (not BaseException) lets KeyboardInterrupt and
            # SystemExit propagate instead of being swallowed here.
            logger.exception(
                f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used."
            )
    return all_stats_writers
|