"""The rliable-evaluation module provides a high-level interface to evaluate the results of an experiment with multiple runs
on different seeds using the rliable library. The API is experimental and subject to change!
"""

import os
from dataclasses import asdict, dataclass, fields

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as sst
from rliable import library as rly
from rliable import plot_utils

from tianshou.highlevel.experiment import Experiment
from tianshou.utils import logging
from tianshou.utils.logger.base import DataScope

log = logging.getLogger(__name__)


@dataclass
class LoggedSummaryData:
    mean: np.ndarray
    std: np.ndarray
    max: np.ndarray
    min: np.ndarray


@dataclass
class LoggedCollectStats:
    env_step: np.ndarray | None = None
    n_collected_episodes: np.ndarray | None = None
    n_collected_steps: np.ndarray | None = None
    collect_time: np.ndarray | None = None
    collect_speed: np.ndarray | None = None
    returns_stat: LoggedSummaryData | None = None
    lens_stat: LoggedSummaryData | None = None

    @classmethod
    def from_data_dict(cls, data: dict) -> "LoggedCollectStats":
        """Create a LoggedCollectStats object from a dictionary.

        Converts SequenceSummaryStats from dict format to dataclass format and ignores keys in the
        dict that do not correspond to fields of this class.
        """
        field_names = [f.name for f in fields(cls)]
        # Build the kwargs in a separate dict instead of mutating `data` while iterating over it.
        kwargs = {}
        for k, v in data.items():
            if k not in field_names:
                continue
            if isinstance(v, dict):
                v = LoggedSummaryData(**v)
            kwargs[k] = v
        return cls(**kwargs)
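
    # A sketch of the kind of dict `from_data_dict` expects (the exact keys and values depend
    # on the logger that produced the data; these are illustrative):
    #   {"env_step": np.array([5000, 10000]),
    #    "returns_stat": {"mean": np.array([10.0, 12.5]), "std": np.array([1.0, 0.8]),
    #                     "max": np.array([12.0, 14.0]), "min": np.array([8.5, 11.0])}}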


@dataclass
class RLiableExperimentResult:
    """The result of an experiment that can be used with the rliable library."""

    exp_dir: str
    """The base directory where each sub-directory contains the results of one experiment run."""

    test_episode_returns_RE: np.ndarray
    """The test episodes for each run of the experiment where each row corresponds to one run."""

    env_steps_E: np.ndarray
    """The number of environment steps at which the test episodes were evaluated."""

    @classmethod
    def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult":
        """Load the experiment result from disk.

        :param exp_dir: The directory from which the experiment results are restored.
        """
        test_episode_returns = []
        env_step_at_test = None

        # TODO: env_step_at_test should not be defined in a loop and overwritten at each iteration
        #  just for retrieving them. We might need a cleaner directory structure.
        for entry in os.scandir(exp_dir):
            if entry.name.startswith(".") or not entry.is_dir():
                continue

            exp = Experiment.from_directory(entry.path)
            logger = exp.logger_factory.create_logger(
                entry.path,
                entry.name,
                None,
                asdict(exp.config),
            )
            data = logger.restore_logged_data(entry.path)

            if DataScope.TEST.value not in data or not data[DataScope.TEST.value]:
                continue
            restored_test_data = data[DataScope.TEST.value]
            if not isinstance(restored_test_data, dict):
                raise RuntimeError(
                    f"Expected entry with key {DataScope.TEST.value} data to be a dictionary, "
                    f"but got {restored_test_data=}.",
                )
            test_data = LoggedCollectStats.from_data_dict(restored_test_data)

            if test_data.returns_stat is None:
                continue
            test_episode_returns.append(test_data.returns_stat.mean)
            env_step_at_test = test_data.env_step

        if not test_episode_returns or env_step_at_test is None:
            raise ValueError(f"No experiment data found in {exp_dir}.")

        return cls(
            test_episode_returns_RE=np.array(test_episode_returns),
            env_steps_E=np.array(env_step_at_test),
            exp_dir=exp_dir,
        )
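
    # `load_from_disk` assumes that `exp_dir` contains one sub-directory per run of the same
    # experiment, each written by the high-level Experiment API; a purely illustrative layout:
    #   exp_dir/
    #       seed=0/
    #       seed=1/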

    def _get_rliable_data(
        self,
        algo_name: str | None = None,
        score_thresholds: np.ndarray | None = None,
    ) -> tuple[dict, np.ndarray, np.ndarray]:
        """Return the data in the format expected by the rliable library.

        :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the
            experiment directory is used.
        :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
            from the minimum and maximum test episode returns.

        :return: A tuple (score_dict, env_steps, score_thresholds).
        """
        if score_thresholds is None:
            score_thresholds = np.linspace(
                np.min(self.test_episode_returns_RE),
                np.max(self.test_episode_returns_RE),
                101,
            )

        if algo_name is None:
            algo_name = os.path.basename(self.exp_dir)

        score_dict = {algo_name: self.test_episode_returns_RE}

        return score_dict, self.env_steps_E, score_thresholds
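
    # The returned score_dict maps a single algorithm name to an array of shape
    # (num_runs, num_evaluation_steps); this is the layout the rliable aggregation and
    # plotting helpers below operate on.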

    def eval_results(
        self,
        algo_name: str | None = None,
        score_thresholds: np.ndarray | None = None,
        save_plots: bool = False,
        show_plots: bool = True,
    ) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]:
        """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.

        :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the
            experiment directory is used.
        :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred
            from the minimum and maximum test episode returns.
        :param save_plots: If True, the figures are saved to the experiment directory.
        :param show_plots: If True, the figures are shown.

        :return: The created figures and axes.
        """
        score_dict, env_steps, score_thresholds = self._get_rliable_data(
            algo_name,
            score_thresholds,
        )

        # IQM (interquartile mean): the mean over the middle 50% of runs at each evaluation step.
        iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
        iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)

        # Plot IQM sample efficiency curve
        fig_iqm, ax_iqm = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True)
        plot_utils.plot_sample_efficiency_curve(
            env_steps,
            iqm_scores,
            iqm_cis,
            algorithms=None,
            xlabel="env step",
            ylabel="IQM episode return",
            ax=ax_iqm,
        )
        if show_plots:
            plt.show(block=False)

        if save_plots:
            iqm_sample_efficiency_curve_path = os.path.abspath(
                os.path.join(
                    self.exp_dir,
                    "iqm_sample_efficiency_curve.png",
                ),
            )
            log.info(f"Saving iqm sample efficiency curve to {iqm_sample_efficiency_curve_path}.")
            fig_iqm.savefig(iqm_sample_efficiency_curve_path)

        # The performance profile uses only the final evaluation step; `[:, [-1]]` keeps the
        # selected column two-dimensional, as expected by rliable.
        final_score_dict = {algo: returns[:, [-1]] for algo, returns in score_dict.items()}
        score_distributions, score_distributions_cis = rly.create_performance_profile(
            final_score_dict,
            score_thresholds,
        )

        # Plot score distributions
        fig_profile, ax_profile = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True)
        plot_utils.plot_performance_profiles(
            score_distributions,
            score_thresholds,
            performance_profile_cis=score_distributions_cis,
            xlabel=r"Episode return $(\tau)$",
            ax=ax_profile,
        )

        if save_plots:
            profile_curve_path = os.path.abspath(
                os.path.join(self.exp_dir, "performance_profile.png"),
            )
            log.info(f"Saving performance profile curve to {profile_curve_path}.")
            fig_profile.savefig(profile_curve_path)
        if show_plots:
            plt.show(block=False)

        return fig_iqm, ax_iqm, fig_profile, ax_profile
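

if __name__ == "__main__":
    # Minimal usage sketch; the directory and algorithm name below are hypothetical. `exp_dir`
    # should contain one sub-directory per run of the same experiment (e.g. the same high-level
    # experiment configuration launched with several seeds).
    result = RLiableExperimentResult.load_from_disk("log/my_experiment")
    fig_iqm, ax_iqm, fig_profile, ax_profile = result.eval_results(
        algo_name="my_algo",
        save_plots=True,
        show_plots=False,
    )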