File size: 1,835 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import sys
import uuid
import mlagents_envs
from mlagents_envs.exception import UnityCommunicationException
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
from mlagents_envs.communicator_objects.training_analytics_pb2 import (
TrainingEnvironmentInitialized,
)
from google.protobuf.any_pb2 import Any
class DefaultTrainingAnalyticsSideChannel(SideChannel):
    """
    Side channel that reports information about the training setup to the
    Unity environment so that it can be logged there.

    This channel is only used to send data; any message arriving from Unity
    is rejected with a UnityCommunicationException.
    """

    # How the id below was derived:
    # >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/TrainingAnalyticsSideChannel")
    # UUID('b664a4a9-d86f-5a5f-95cb-e8353a7e8356')
    # It is deliberately the SAME channel id as the TrainingAnalyticsSideChannel.
    CHANNEL_ID = uuid.UUID("b664a4a9-d86f-5a5f-95cb-e8353a7e8356")

    def __init__(self) -> None:
        """Initialize the base SideChannel with this class's CHANNEL_ID."""
        super().__init__(DefaultTrainingAnalyticsSideChannel.CHANNEL_ID)

    def on_message_received(self, msg: IncomingMessage) -> None:
        """
        Reject a message received from Unity.

        :param msg: The incoming message; its content is never read.
        :raises UnityCommunicationException: Always, since no message is
            expected on this channel.
        """
        raise UnityCommunicationException(
            "The DefaultTrainingAnalyticsSideChannel received a message from Unity, "
            "this should not have happened."
        )

    def environment_initialized(self) -> None:
        """
        Queue a TrainingEnvironmentInitialized message for sending.

        The message carries the running Python version and the installed
        mlagents_envs version; the ml-agents version is reported as "Custom"
        and the torch fields as "Unknown".
        """
        # The first three entries of version_info are (major, minor, patch).
        major, minor, patch = sys.version_info[:3]
        payload = TrainingEnvironmentInitialized(
            python_version=f"{major}.{minor}.{patch}",
            mlagents_version="Custom",
            mlagents_envs_version=mlagents_envs.__version__,
            torch_version="Unknown",
            torch_device_type="Unknown",
        )

        # Wrap the payload in a protobuf Any before serializing it to bytes.
        packed = Any()
        packed.Pack(payload)

        outgoing = OutgoingMessage()
        outgoing.set_raw_bytes(packed.SerializeToString())  # type: ignore
        super().queue_message_to_send(outgoing)
|