sabretoothedhugs's picture
v2
9b19c29
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import numpy as np
from tianshou.utils.print import DataclassPPrintMixin
if TYPE_CHECKING:
from tianshou.data import CollectStats, CollectStatsBase
from tianshou.policy.base import TrainingStats
@dataclass(kw_only=True)
class SequenceSummaryStats(DataclassPPrintMixin):
"""A data structure for storing the statistics of a sequence."""
mean: float
std: float
max: float
min: float
@classmethod
def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats":
return cls(
mean=float(np.mean(sequence)),
std=float(np.std(sequence)),
max=float(np.max(sequence)),
min=float(np.min(sequence)),
)
@dataclass(kw_only=True)
class TimingStats(DataclassPPrintMixin):
"""A data structure for storing timing statistics."""
total_time: float = 0.0
"""The total time elapsed."""
train_time: float = 0.0
"""The total time elapsed for training (collecting samples plus model update)."""
train_time_collect: float = 0.0
"""The total time elapsed for collecting training transitions."""
train_time_update: float = 0.0
"""The total time elapsed for updating models."""
test_time: float = 0.0
"""The total time elapsed for testing models."""
update_speed: float = 0.0
"""The speed of updating (env_step per second)."""
@dataclass(kw_only=True)
class InfoStats(DataclassPPrintMixin):
"""A data structure for storing information about the learning process."""
gradient_step: int
"""The total gradient step."""
best_reward: float
"""The best reward over the test results."""
best_reward_std: float
"""Standard deviation of the best reward over the test results."""
train_step: int
"""The total collected step of training collector."""
train_episode: int
"""The total collected episode of training collector."""
test_step: int
"""The total collected step of test collector."""
test_episode: int
"""The total collected episode of test collector."""
timing: TimingStats
"""The timing statistics."""
@dataclass(kw_only=True)
class EpochStats(DataclassPPrintMixin):
"""A data structure for storing epoch statistics."""
epoch: int
"""The current epoch."""
train_collect_stat: "CollectStatsBase"
"""The statistics of the last call to the training collector."""
test_collect_stat: Optional["CollectStats"]
"""The statistics of the last call to the test collector."""
training_stat: Optional["TrainingStats"]
"""The statistics of the last model update step.
Can be None if no model update is performed, typically in the last training iteration."""
info_stat: InfoStats
"""The information of the collector."""