|
import uuid |
|
from typing import Tuple, List, Mapping |
|
from enum import Enum |
|
from collections import defaultdict |
|
|
|
from mlagents_envs.side_channel import SideChannel, IncomingMessage |
|
|
|
|
|
|
|
class StatsAggregationMethod(Enum):
    """
    Identifies how the values reported for one stat should be combined.

    The integer values are the codes that StatsSideChannel.on_message_received
    reads from incoming messages and converts into members of this enum, so
    they presumably have to stay in sync with whatever the sender writes.

    NOTE(review): how each method is actually applied is decided by the
    consumer of the collected stats, which is not part of this module; the
    member names are the only description available here.
    """

    AVERAGE = 0
    MOST_RECENT = 1
    SUM = 2
    HISTOGRAM = 3
|
|
|
|
|
# One reported value together with the aggregation method it was sent with.
_StatEntry = Tuple[float, StatsAggregationMethod]

# All entries received for a single stat name, in the order they were appended.
StatList = List[_StatEntry]

# Stat name -> entries received for that name.
EnvironmentStats = Mapping[str, StatList]
|
|
|
|
|
class StatsSideChannel(SideChannel):
    """
    Side channel that collects stats sent by the environment, so that they can
    eventually be passed to a StatsReporter.

    Each incoming message carries a stat name, a float value and an aggregation
    method code; entries are grouped by name in ``self.stats`` until they are
    retrieved with ``get_and_reset_stats``.
    """

    def __init__(self) -> None:
        """
        Register the channel under its fixed UUID and start with empty stats.
        """
        channel_id = uuid.UUID("a1d8f7b7-cec8-50f9-b78b-d3e165a78520")
        super().__init__(channel_id)

        # defaultdict lets on_message_received append to a name it has not
        # seen before without checking for the key first.
        self.stats: EnvironmentStats = defaultdict(list)

    def on_message_received(self, msg: IncomingMessage) -> None:
        """
        Decode one stat from the environment and store it for later retrieval.

        :param msg: Incoming message; a string, a float32 and an int32 are read
            from it, in that order.
        :return: None
        :raises ValueError: If the int32 read from the message is not the value
            of any StatsAggregationMethod member.
        """
        # The three reads are sequential; presumably the order has to match
        # the order in which the sender wrote the fields.
        stat_name = msg.read_string()
        stat_value = msg.read_float32()
        aggregation = StatsAggregationMethod(msg.read_int32())

        self.stats[stat_name].append((stat_value, aggregation))

    def get_and_reset_stats(self) -> EnvironmentStats:
        """
        Hand back everything collected so far and start a fresh collection.

        :return: The mapping of stat name to received entries accumulated since
            the previous call (or since construction).
        """
        # Swap in a new empty container; the old one is returned as-is, not copied.
        collected, self.stats = self.stats, defaultdict(list)
        return collected
|
|