# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computes the RL Reliability Metrics."""

from dataclasses import dataclass
from typing import List, Optional

import datasets
import numpy as np
from rl_reliability_metrics.evaluation import eval_metrics
from rl_reliability_metrics.metrics import metrics_offline, metrics_online

import evaluate


logger = evaluate.logging.get_logger(__name__)


# Evaluation points used by the online metrics when the config does not
# provide its own `eval_points`.
DEFAULT_EVAL_POINTS = [
    50000,
    150000,
    250000,
    350000,
    450000,
    550000,
    650000,
    750000,
    850000,
    950000,
    1050000,
    1150000,
    1250000,
    1350000,
    1450000,
    1550000,
    1650000,
    1750000,
    1850000,
    1950000,
]

# A warning is logged when fewer runs than this are supplied to `compute`.
N_RUNS_RECOMMENDED = 10


_CITATION = """\
@conference{rl_reliability_metrics,
  title = {Measuring the Reliability of Reinforcement Learning Algorithms},
  author = {Stephanie CY Chan and Sam Fishman and John Canny and Anoop Korattikara and Sergio Guadarrama},
  booktitle = {International Conference on Learning Representations, Addis Ababa, Ethiopia},
  year = 2020,
}
"""

_DESCRIPTION = """\
Computes the RL reliability metrics from a set of experiments. There is an `"online"` and `"offline"` configuration for evaluation.
"""


_KWARGS_DESCRIPTION = """
Computes the RL reliability metrics from a set of experiments. There is an `"online"` and `"offline"` configuration for evaluation.
Args:
    timesteps: list of timestep lists/arrays that serve as index, one per run.
    rewards: list of reward lists/arrays of each experiment; `rewards[i]` must have the same length as `timesteps[i]`.
Returns:
    dictionary: a set of reliability metrics, keyed by metric name (e.g. `"LowerCVaROnRaw"`).
Examples:
    >>> import numpy as np
    >>> rl_reliability = evaluate.load("rl_reliability", "online")
    >>> results = rl_reliability.compute(
    ...     timesteps=[np.linspace(0, 2000000, 1000)],
    ...     rewards=[np.linspace(0, 100, 1000)]
    ... )
    >>> print(results["LowerCVaROnRaw"].round(4))
    [0.0258]
"""


@dataclass
class RLReliabilityConfig(evaluate.info.Config):
    """Configuration for the RL reliability metrics."""

    name: str = "default"
    # "default" is resolved in `_compute`: "curve_range" for the online
    # configuration, "median_perf" for the offline one.
    baseline: str = "default"
    # Passed to the online metrics as `thresh` / `lowpass_thresh`.
    freq_thresh: float = 0.01
    # Window size for the across-runs online metrics and MedianPerfDuringTraining.
    window_size: int = 100000
    # Window size for the within-runs online metrics (IQR / MAD / Stddev).
    window_size_trimmed: int = 99000
    # Passed as `alpha` to every CVaR metric.
    alpha: float = 0.05
    # Online evaluation points; None falls back to DEFAULT_EVAL_POINTS.
    eval_points: Optional[List] = None


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class RLReliability(evaluate.Metric):
    """Computes the RL Reliability Metrics."""

    CONFIG_CLASS = RLReliabilityConfig
    ALLOWED_CONFIG_NAMES = ["online", "offline"]

    def _info(self, config):
        """Return the MetricInfo for this metric.

        Raises:
            KeyError: if the configuration name is not one of ALLOWED_CONFIG_NAMES.
        """
        if self.config_name not in self.ALLOWED_CONFIG_NAMES:
            raise KeyError("""You should supply a configuration name selected in '["online", "offline"]'""")

        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            config=config,
            features=datasets.Features(
                {
                    "timesteps": datasets.Sequence(datasets.Value("int64")),
                    "rewards": datasets.Sequence(datasets.Value("float")),
                }
            ),
            homepage="https://github.com/google-research/rl-reliability-metrics",
        )

    def _compute(
        self,
        timesteps,
        rewards,
    ):
        """Evaluate the configured set of reliability metrics on the given runs.

        Args:
            timesteps: one timestep sequence per run.
            rewards: one reward sequence per run, aligned element-wise with
                the matching entry of ``timesteps``.

        Returns:
            The result of ``eval_metrics.Evaluator.compute_metrics`` on the
            stacked curves.
        """
        if len(timesteps) < N_RUNS_RECOMMENDED:
            logger.warning(
                f"For robust statistics it is recommended to use at least {N_RUNS_RECOMMENDED} runs whereas you provided {len(timesteps)}."
            )

        baseline = self.config.baseline
        freq_thresh = self.config.freq_thresh
        window_size = self.config.window_size
        window_size_trimmed = self.config.window_size_trimmed
        alpha = self.config.alpha
        eval_points = self.config.eval_points

        # Each curve stacks one run's timesteps (row 0) above its rewards (row 1);
        # np.stack requires both sequences of a run to have the same length.
        curves = [np.stack([timestep, reward]) for timestep, reward in zip(timesteps, rewards)]

        if self.config_name == "online":
            if baseline == "default":
                baseline = "curve_range"
            if eval_points is None:
                eval_points = DEFAULT_EVAL_POINTS

            # `alpha` is forwarded to every CVaR metric (lower and upper tails)
            # so that a user-configured alpha is applied consistently.
            metrics = [
                metrics_online.HighFreqEnergyWithinRuns(thresh=freq_thresh),
                metrics_online.IqrWithinRuns(
                    window_size=window_size_trimmed, eval_points=eval_points, baseline=baseline
                ),
                metrics_online.IqrAcrossRuns(
                    lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
                ),
                metrics_online.LowerCVaROnDiffs(alpha=alpha, baseline=baseline),
                metrics_online.LowerCVaROnDrawdown(alpha=alpha, baseline=baseline),
                metrics_online.LowerCVaROnAcross(
                    alpha=alpha,
                    lowpass_thresh=freq_thresh,
                    eval_points=eval_points,
                    window_size=window_size,
                    baseline=baseline,
                ),
                metrics_online.LowerCVaROnRaw(alpha=alpha, baseline=baseline),
                metrics_online.MadAcrossRuns(
                    lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
                ),
                metrics_online.MadWithinRuns(
                    eval_points=eval_points, window_size=window_size_trimmed, baseline=baseline
                ),
                metrics_online.MaxDrawdown(),
                metrics_online.StddevAcrossRuns(
                    lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
                ),
                metrics_online.StddevWithinRuns(
                    eval_points=eval_points, window_size=window_size_trimmed, baseline=baseline
                ),
                metrics_online.UpperCVaROnAcross(
                    alpha=alpha,
                    lowpass_thresh=freq_thresh,
                    eval_points=eval_points,
                    window_size=window_size,
                    baseline=baseline,
                ),
                metrics_online.UpperCVaROnDiffs(alpha=alpha, baseline=baseline),
                metrics_online.UpperCVaROnDrawdown(alpha=alpha, baseline=baseline),
                metrics_online.UpperCVaROnRaw(alpha=alpha, baseline=baseline),
                metrics_online.MedianPerfDuringTraining(window_size=window_size, eval_points=eval_points),
            ]
        else:
            if baseline == "default":
                baseline = "median_perf"

            metrics = [
                metrics_offline.MadAcrossRollouts(baseline=baseline),
                metrics_offline.IqrAcrossRollouts(baseline=baseline),
                metrics_offline.StddevAcrossRollouts(baseline=baseline),
                metrics_offline.LowerCVaRAcrossRollouts(alpha=alpha, baseline=baseline),
                metrics_offline.UpperCVaRAcrossRollouts(alpha=alpha, baseline=baseline),
                # Median performance is reported unnormalized.
                metrics_offline.MedianPerfAcrossRollouts(baseline=None),
            ]

        evaluator = eval_metrics.Evaluator(metrics=metrics)
        result = evaluator.compute_metrics(curves)
        return result