Kano001's picture
Upload 280 files
e11e4fe verified
raw
history blame
10.2 kB
# # Unity ML-Agents Toolkit
from mlagents import torch_utils
import yaml
import os
import numpy as np
import json
from typing import Callable, Optional, List
import mlagents.trainers
import mlagents_envs
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.trainer import TrainerFactory
from mlagents.trainers.directory_utils import (
validate_existing_directories,
setup_init_path,
)
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.cli_utils import parser
from mlagents_envs.environment import UnityEnvironment
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.training_status import GlobalTrainingStatus
from mlagents_envs.base_env import BaseEnv
from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.timers import (
hierarchical_timer,
get_timer_tree,
add_metadata as add_timer_metadata,
)
from mlagents_envs import logging_util
from mlagents.plugins.stats_writer import register_stats_writer_plugins
from mlagents.plugins.trainer_type import register_trainer_plugins
# Module-level logger, obtained through the ml-agents logging utilities.
logger = logging_util.get_logger(__name__)
# File name under which write_training_status() saves the global training status.
TRAINING_STATUS_FILE_NAME = "training_status.json"
def get_version_string() -> str:
    """
    Build the multi-line version summary for this installation.
    :return: A header line followed by one line each for the ml-agents,
        ml-agents-envs, Communicator API and PyTorch versions.
    """
    version_lines = [
        " Version information:",
        f"ml-agents: {mlagents.trainers.__version__},",
        f"ml-agents-envs: {mlagents_envs.__version__},",
        f"Communicator API: {UnityEnvironment.API_VERSION},",
        f"PyTorch: {torch_utils.torch.__version__}",
    ]
    return "\n".join(version_lines)
def parse_command_line(
    argv: Optional[List[str]] = None,
) -> RunOptions:
    """
    Turn command-line arguments into a RunOptions object.
    :param argv: Argument list handed unchanged to parser.parse_args.
    :return: RunOptions built from the parsed arguments.
    """
    # Trainer plugins are registered before parsing; both returned values are
    # intentionally discarded here.
    _, _ = register_trainer_plugins()
    return RunOptions.from_argparse(parser.parse_args(argv))
def run_training(run_seed: int, options: RunOptions, num_areas: int) -> None:
    """
    Launches training session.
    :param run_seed: Random seed used for training.
    :param options: parsed command line arguments
    :param num_areas: Number of training areas to instantiate
    """
    # Everything up to the construction of the controller is timed as setup.
    with hierarchical_timer("run_training.setup"):
        torch_utils.set_torch_config(options.torch_settings)
        checkpoint_settings = options.checkpoint_settings
        env_settings = options.env_settings
        engine_settings = options.engine_settings

        run_logs_dir = checkpoint_settings.run_logs_dir
        port: Optional[int] = env_settings.base_port
        # Check if directory exists
        validate_existing_directories(
            checkpoint_settings.write_path,
            checkpoint_settings.resume,
            checkpoint_settings.force,
            checkpoint_settings.maybe_init_path,
        )
        # Make run logs directory
        os.makedirs(run_logs_dir, exist_ok=True)
        # Load any needed states in case of resume
        if checkpoint_settings.resume:
            # Use the same file name that write_training_status() saves to, so
            # the load and save paths cannot drift apart.
            GlobalTrainingStatus.load_state(
                os.path.join(run_logs_dir, TRAINING_STATUS_FILE_NAME)
            )
        # In case of initialization, set full init_path for all behaviors
        elif checkpoint_settings.maybe_init_path is not None:
            setup_init_path(options.behaviors, checkpoint_settings.maybe_init_path)

        # Configure Tensorboard Writers and StatsReporter
        stats_writers = register_stats_writer_plugins(options)
        for sw in stats_writers:
            StatsReporter.add_writer(sw)

        # Without an environment path, no explicit base port is passed on.
        if env_settings.env_path is None:
            port = None
        env_factory = create_environment_factory(
            env_settings.env_path,
            engine_settings.no_graphics,
            run_seed,
            num_areas,
            port,
            env_settings.env_args,
            os.path.abspath(run_logs_dir),  # Unity environment requires absolute path
        )

        env_manager = SubprocessEnvManager(env_factory, options, env_settings.num_envs)
        env_parameter_manager = EnvironmentParameterManager(
            options.environment_parameters, run_seed, restore=checkpoint_settings.resume
        )

        trainer_factory = TrainerFactory(
            trainer_config=options.behaviors,
            output_path=checkpoint_settings.write_path,
            train_model=not checkpoint_settings.inference,
            load_model=checkpoint_settings.resume,
            seed=run_seed,
            param_manager=env_parameter_manager,
            init_path=checkpoint_settings.maybe_init_path,
            multi_gpu=False,
        )
        # Create controller and begin training.
        tc = TrainerController(
            trainer_factory,
            checkpoint_settings.write_path,
            checkpoint_settings.run_id,
            env_parameter_manager,
            not checkpoint_settings.inference,
            run_seed,
        )

    # Begin training
    try:
        tc.start_learning(env_manager)
    finally:
        # Runs whether or not training raised: close the environments, then
        # persist the run configuration, timers and training status.
        env_manager.close()
        write_run_options(checkpoint_settings.write_path, options)
        write_timing_tree(run_logs_dir)
        write_training_status(run_logs_dir)
def write_run_options(output_dir: str, run_options: RunOptions) -> None:
    """
    Save the run options as YAML to "configuration.yaml" inside output_dir.
    A missing directory is not fatal; a warning is logged instead.
    :param output_dir: Directory in which the configuration file is written.
    :param run_options: Options to serialize; their as_dict() form is dumped.
    """
    config_path = os.path.join(output_dir, "configuration.yaml")
    try:
        with open(config_path, "w") as stream:
            try:
                # Ask pyyaml not to sort the keys when dumping.
                yaml.dump(run_options.as_dict(), stream, sort_keys=False)
            except TypeError:  # Older versions of pyyaml don't support sort_keys
                yaml.dump(run_options.as_dict(), stream)
    except FileNotFoundError:
        logger.warning(
            f"Unable to save configuration to {config_path}. Make sure the directory exists"
        )
def write_training_status(output_dir: str) -> None:
    """
    Save the global training status into output_dir.
    :param output_dir: Directory that receives the TRAINING_STATUS_FILE_NAME file.
    """
    status_path = os.path.join(output_dir, TRAINING_STATUS_FILE_NAME)
    GlobalTrainingStatus.save_state(status_path)
def write_timing_tree(output_dir: str) -> None:
    """
    Dump the current timer tree as indented JSON to "timers.json" in output_dir.
    A missing directory is not fatal; a warning is logged instead.
    :param output_dir: Directory in which the timers file is written.
    """
    timers_file = os.path.join(output_dir, "timers.json")
    try:
        with open(timers_file, "w") as out:
            json.dump(get_timer_tree(), out, indent=4)
    except FileNotFoundError:
        logger.warning(
            f"Unable to save to {timers_file}. Make sure the directory exists"
        )
def create_environment_factory(
    env_path: Optional[str],
    no_graphics: bool,
    seed: int,
    num_areas: int,
    start_port: Optional[int],
    env_args: Optional[List[str]],
    log_folder: str,
) -> Callable[[int, List[SideChannel]], BaseEnv]:
    """
    Return a factory that builds one UnityEnvironment per worker.
    :param env_path: Passed to UnityEnvironment as file_name.
    :param no_graphics: Passed to UnityEnvironment as no_graphics.
    :param seed: Base seed; each worker uses seed + worker_id.
    :param num_areas: Passed to UnityEnvironment as num_areas.
    :param start_port: Passed to UnityEnvironment as base_port.
    :param env_args: Passed to UnityEnvironment as additional_args.
    :param log_folder: Passed to UnityEnvironment as log_folder.
    :return: Callable taking (worker_id, side_channels) and returning the environment.
    """

    def create_unity_environment(
        worker_id: int, side_channels: List[SideChannel]
    ) -> UnityEnvironment:
        return UnityEnvironment(
            file_name=env_path,
            worker_id=worker_id,
            # Offsetting by the worker id gives each environment a different seed.
            seed=seed + worker_id,
            num_areas=num_areas,
            no_graphics=no_graphics,
            base_port=start_port,
            additional_args=env_args,
            side_channels=side_channels,
            log_folder=log_folder,
        )

    return create_unity_environment
def run_cli(options: RunOptions) -> None:
    """
    Run a training session from already-parsed options.
    Prints the logo and version information, sets the log level, emits
    deprecation warnings, records timer metadata, picks the seed and then
    calls run_training.
    :param options: Parsed run options.
    """
    # Logo made of Unicode box-drawing characters. If printing it raises
    # (presumably on consoles whose encoding lacks these characters), a plain
    # text line is printed instead.
    try:
        print(
            """
┐ ╖
╓╖╬│║ ││╬╖╖
╓╖╬│││││┘ ╬│││││╬╖
╖╬│││││╬╜ ╙╬│││││╖╖ ╗╗╗
╬╬╬╬╖││╦╖ ╖╬││╗╣╣╣╬ ╟╣╣╬ ╟╣╣╣ ╜╜╜ ╟╣╣
╬╬╬╬╬╬╬╬╖│╬╖╖╓╬╪│╓╣╣╣╣╣╣╣╬ ╟╣╣╬ ╟╣╣╣ ╒╣╣╖╗╣╣╣╗ ╣╣╣ ╣╣╣╣╣╣ ╟╣╣╖ ╣╣╣
╬╬╬╬┐ ╙╬╬╬╬│╓╣╣╣╝╜ ╫╣╣╣╬ ╟╣╣╬ ╟╣╣╣ ╟╣╣╣╙ ╙╣╣╣ ╣╣╣ ╙╟╣╣╜╙ ╫╣╣ ╟╣╣
╬╬╬╬┐ ╙╬╬╣╣ ╫╣╣╣╬ ╟╣╣╬ ╟╣╣╣ ╟╣╣╬ ╣╣╣ ╣╣╣ ╟╣╣ ╣╣╣┌╣╣╜
╬╬╬╜ ╬╬╣╣ ╙╝╣╣╬ ╙╣╣╣╗╖╓╗╣╣╣╜ ╟╣╣╬ ╣╣╣ ╣╣╣ ╟╣╣╦╓ ╣╣╣╣╣
╙ ╓╦╖ ╬╬╣╣ ╓╗╗╖ ╙╝╣╣╣╣╝╜ ╘╝╝╜ ╝╝╝ ╝╝╝ ╙╣╣╣ ╟╣╣╣
╩╬╬╬╬╬╬╦╦╬╬╣╣╗╣╣╣╣╣╣╣╝ ╫╣╣╣╣
╙╬╬╬╬╬╬╬╣╣╣╣╣╣╝╜
╙╬╬╬╣╣╣╜
╙
"""
        )
    except Exception:
        print("\n\n\tUnity Technologies\n")
    print(get_version_string())

    log_level = logging_util.DEBUG if options.debug else logging_util.INFO
    logging_util.set_log_level(log_level)

    logger.debug("Configuration for this run:")
    logger.debug(json.dumps(options.as_dict(), indent=4))

    # Options deprecation warnings
    if options.checkpoint_settings.load_model:
        logger.warning(
            "The --load option has been deprecated. Please use the --resume option instead."
        )
    if options.checkpoint_settings.train_model:
        logger.warning(
            "The --train option has been deprecated. Train mode is now the default. Use "
            "--inference to run in inference mode."
        )

    run_seed = options.env_settings.seed
    num_areas = options.env_settings.num_areas

    # Add some timer metadata
    add_timer_metadata("mlagents_version", mlagents.trainers.__version__)
    add_timer_metadata("mlagents_envs_version", mlagents_envs.__version__)
    add_timer_metadata("communication_protocol_version", UnityEnvironment.API_VERSION)
    add_timer_metadata("pytorch_version", torch_utils.torch.__version__)
    add_timer_metadata("numpy_version", np.__version__)

    # A configured seed of -1 means "pick one at random" for this run.
    if options.env_settings.seed == -1:
        run_seed = np.random.randint(0, 10000)
        logger.debug(f"run_seed set to {run_seed}")
    run_training(run_seed, options, num_areas)
def main():
    """Parse the command line and run the training CLI with the resulting options."""
    options = parse_command_line()
    run_cli(options)
# Entry point when this file is executed directly, e.g. so that a Python
# debugger can run this script.
if __name__ == "__main__":
    main()