from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import numpy as np

from tianshou.utils.print import DataclassPPrintMixin

if TYPE_CHECKING:
    from tianshou.data import CollectStats, CollectStatsBase
    from tianshou.policy.base import TrainingStats


@dataclass(kw_only=True)
class SequenceSummaryStats(DataclassPPrintMixin):
    """A data structure for storing the statistics of a sequence."""

    mean: float
    std: float
    max: float
    min: float

    @classmethod
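    def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats":
        """Compute the summary statistics of a non-empty sequence.

        The standard deviation is the population standard deviation (NumPy's default, ``ddof=0``).
        """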
        return cls(
            mean=float(np.mean(sequence)),
            std=float(np.std(sequence)),
            max=float(np.max(sequence)),
            min=float(np.min(sequence)),
        )


@dataclass(kw_only=True)
class TimingStats(DataclassPPrintMixin):
    """A data structure for storing timing statistics."""

    total_time: float = 0.0
    """The total time elapsed."""
    train_time: float = 0.0
    """The total time elapsed for training (collecting samples plus model update)."""
    train_time_collect: float = 0.0
    """The total time elapsed for collecting training transitions."""
    train_time_update: float = 0.0
    """The total time elapsed for updating models."""
    test_time: float = 0.0
    """The total time elapsed for testing models."""
    update_speed: float = 0.0
    """The speed of updating (env_step per second)."""


@dataclass(kw_only=True)
class InfoStats(DataclassPPrintMixin):
    """A data structure for storing information about the learning process."""

    gradient_step: int
    """The total gradient step."""
    best_reward: float
    """The best reward over the test results."""
    best_reward_std: float
    """Standard deviation of the best reward over the test results."""
    train_step: int
    """The total collected step of training collector."""
    train_episode: int
    """The total collected episode of training collector."""
    test_step: int
    """The total collected step of test collector."""
    test_episode: int
    """The total collected episode of test collector."""

    timing: TimingStats
    """The timing statistics."""


@dataclass(kw_only=True)
class EpochStats(DataclassPPrintMixin):
    """A data structure for storing epoch statistics."""

    epoch: int
    """The current epoch."""

    train_collect_stat: "CollectStatsBase"
    """The statistics of the last call to the training collector."""
    test_collect_stat: Optional["CollectStats"]
    """The statistics of the last call to the test collector."""
    training_stat: Optional["TrainingStats"]
    """The statistics of the last model update step.
    Can be None if no model update is performed, typically in the last training iteration."""
    info_stat: InfoStats
    """The information of the collector."""