Spaces:
Runtime error
Runtime error
Update Space (evaluate main: 1c421923)
Browse files- README.md +109 -6
- app.py +6 -0
- requirements.txt +7 -0
- rl_reliability.py +186 -0
README.md
CHANGED
@@ -1,12 +1,115 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
|
5 |
-
|
|
|
|
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.0.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: RL Reliability
|
3 |
+
datasets:
|
4 |
+
-
|
5 |
+
tags:
|
6 |
+
- evaluate
|
7 |
+
- metric
|
8 |
sdk: gradio
|
9 |
+
sdk_version: 3.0.2
|
10 |
app_file: app.py
|
11 |
pinned: false
|
12 |
---
|
13 |
|
14 |
+
# Metric Card for RL Reliability
|
15 |
+
|
16 |
+
## Metric Description
|
17 |
+
The RL Reliability Metrics library provides a set of metrics for measuring the reliability of reinforcement learning (RL) algorithms.
|
18 |
+
|
19 |
+
## How to Use
|
20 |
+
|
21 |
+
```python
|
22 |
+
import evaluate
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
rl_reliability = evaluate.load("rl_reliability", "online")
|
26 |
+
results = rl_reliability.compute(
|
27 |
+
timesteps=[np.linspace(0, 2000000, 1000)],
|
28 |
+
rewards=[np.linspace(0, 100, 1000)]
|
29 |
+
)
|
30 |
+
|
31 |
+
rl_reliability = evaluate.load("rl_reliability", "offline")
|
32 |
+
results = rl_reliability.compute(
|
33 |
+
timesteps=[np.linspace(0, 2000000, 1000)],
|
34 |
+
rewards=[np.linspace(0, 100, 1000)]
|
35 |
+
)
|
36 |
+
```
|
37 |
+
|
38 |
+
|
39 |
+
### Inputs
|
40 |
+
- **timesteps** *(List[int]): For each run a an list/array with its timesteps.*
|
41 |
+
- **rewards** *(List[float]): For each run a an list/array with its rewards.*
|
42 |
+
|
43 |
+
KWARGS:
|
44 |
+
- **baseline="default"** *(Union[str, float]) Normalization used for curves. When `"default"` is passed the curves are normalized by their range in the online setting and by the median performance across runs in the offline case. When a float is passed the curves are divided by that value.*
|
45 |
+
- **eval_points=[50000, 150000, ..., 2000000]** *(List[int]) Statistics will be computed at these points*
|
46 |
+
- **freq_thresh=0.01** *(float) Frequency threshold for low-pass filtering.*
|
47 |
+
- **window_size=100000** *(int) Defines a window centered at each eval point.*
|
48 |
+
- **window_size_trimmed=99000** *(int) To handle shortened curves due to differencing*
|
49 |
+
- **alpha=0.05** *(float)The "value at risk" (VaR) cutoff point, a float in the range [0,1].*
|
50 |
+
|
51 |
+
### Output Values
|
52 |
+
|
53 |
+
In `"online"` mode:
|
54 |
+
- HighFreqEnergyWithinRuns: High Frequency across Time (DT)
|
55 |
+
- IqrWithinRuns: IQR across Time (DT)
|
56 |
+
- MadWithinRuns: 'MAD across Time (DT)
|
57 |
+
- StddevWithinRuns: Stddev across Time (DT)
|
58 |
+
- LowerCVaROnDiffs: Lower CVaR on Differences (SRT)
|
59 |
+
- UpperCVaROnDiffs: Upper CVaR on Differences (SRT)
|
60 |
+
- MaxDrawdown: Max Drawdown (LRT)
|
61 |
+
- LowerCVaROnDrawdown: Lower CVaR on Drawdown (LRT)
|
62 |
+
- UpperCVaROnDrawdown: Upper CVaR on Drawdown (LRT)
|
63 |
+
- LowerCVaROnRaw: Lower CVaR on Raw
|
64 |
+
- UpperCVaROnRaw: Upper CVaR on Raw
|
65 |
+
- IqrAcrossRuns: IQR across Runs (DR)
|
66 |
+
- MadAcrossRuns: MAD across Runs (DR)
|
67 |
+
- StddevAcrossRuns: Stddev across Runs (DR)
|
68 |
+
- LowerCVaROnAcross: Lower CVaR across Runs (RR)
|
69 |
+
- UpperCVaROnAcross: Upper CVaR across Runs (RR)
|
70 |
+
- MedianPerfDuringTraining: Median Performance across Runs
|
71 |
+
|
72 |
+
In `"offline"` mode:
|
73 |
+
- MadAcrossRollouts: MAD across rollouts (DF)
|
74 |
+
- IqrAcrossRollouts: IQR across rollouts (DF)
|
75 |
+
- LowerCVaRAcrossRollouts: Lower CVaR across rollouts (RF)
|
76 |
+
- UpperCVaRAcrossRollouts: Upper CVaR across rollouts (RF)
|
77 |
+
- MedianPerfAcrossRollouts: Median Performance across rollouts
|
78 |
+
|
79 |
+
|
80 |
+
### Examples
|
81 |
+
First get the sample data from the repository:
|
82 |
+
|
83 |
+
```bash
|
84 |
+
wget https://storage.googleapis.com/rl-reliability-metrics/data/tf_agents_example_csv_dataset.tgz
|
85 |
+
tar -xvzf tf_agents_example_csv_dataset.tgz
|
86 |
+
```
|
87 |
+
|
88 |
+
Load the sample data:
|
89 |
+
```python
|
90 |
+
dfs = [pd.read_csv(f"./csv_data/sac_humanoid_{i}_train.csv") for i in range(1, 4)]
|
91 |
+
```
|
92 |
+
|
93 |
+
Compute the metrics:
|
94 |
+
```python
|
95 |
+
rl_reliability = evaluate.load("rl_reliability", "online")
|
96 |
+
rl_reliability.compute(timesteps=[df["Metrics/EnvironmentSteps"] for df in dfs],
|
97 |
+
rewards=[df["Metrics/AverageReturn"] for df in dfs])
|
98 |
+
```
|
99 |
+
|
100 |
+
## Limitations and Bias
|
101 |
+
This implementation of RL reliability metrics does not compute permutation tests to determine whether algorithms are statistically different in their metric values and also does not compute bootstrap confidence intervals on the rankings of the algorithms. See the [original library](https://github.com/google-research/rl-reliability-metrics/) for more resources.
|
102 |
+
|
103 |
+
## Citation
|
104 |
+
|
105 |
+
```bibtex
|
106 |
+
@conference{rl_reliability_metrics,
|
107 |
+
title = {Measuring the Reliability of Reinforcement Learning Algorithms},
|
108 |
+
author = {Stephanie CY Chan, Sam Fishman, John Canny, Anoop Korattikara, and Sergio Guadarrama},
|
109 |
+
booktitle = {International Conference on Learning Representations, Addis Ababa, Ethiopia},
|
110 |
+
year = 2020,
|
111 |
+
}
|
112 |
+
```
|
113 |
+
|
114 |
+
## Further References
|
115 |
+
- Homepage: https://github.com/google-research/rl-reliability-metrics
|
app.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import evaluate
|
2 |
+
from evaluate.utils import launch_gradio_widget
|
3 |
+
|
4 |
+
|
5 |
+
module = evaluate.load("rl_reliability", "online")
|
6 |
+
launch_gradio_widget(module)
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TODO: fix github to release
|
2 |
+
git+https://github.com/huggingface/evaluate.git@main
|
3 |
+
datasets~=2.0
|
4 |
+
git+https://github.com/google-research/rl-reliability-metrics
|
5 |
+
scipy
|
6 |
+
tensorflow
|
7 |
+
gin-config
|
rl_reliability.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Computes the RL Reliability Metrics."""
|
15 |
+
|
16 |
+
import datasets
|
17 |
+
import numpy as np
|
18 |
+
from rl_reliability_metrics.evaluation import eval_metrics
|
19 |
+
from rl_reliability_metrics.metrics import metrics_offline, metrics_online
|
20 |
+
|
21 |
+
import evaluate
|
22 |
+
|
23 |
+
|
24 |
+
logger = evaluate.logging.get_logger(__name__)
|
25 |
+
|
26 |
+
DEFAULT_EVAL_POINTS = [
|
27 |
+
50000,
|
28 |
+
150000,
|
29 |
+
250000,
|
30 |
+
350000,
|
31 |
+
450000,
|
32 |
+
550000,
|
33 |
+
650000,
|
34 |
+
750000,
|
35 |
+
850000,
|
36 |
+
950000,
|
37 |
+
1050000,
|
38 |
+
1150000,
|
39 |
+
1250000,
|
40 |
+
1350000,
|
41 |
+
1450000,
|
42 |
+
1550000,
|
43 |
+
1650000,
|
44 |
+
1750000,
|
45 |
+
1850000,
|
46 |
+
1950000,
|
47 |
+
]
|
48 |
+
|
49 |
+
N_RUNS_RECOMMENDED = 10
|
50 |
+
|
51 |
+
_CITATION = """\
|
52 |
+
@conference{rl_reliability_metrics,
|
53 |
+
title = {Measuring the Reliability of Reinforcement Learning Algorithms},
|
54 |
+
author = {Stephanie CY Chan, Sam Fishman, John Canny, Anoop Korattikara, and Sergio Guadarrama},
|
55 |
+
booktitle = {International Conference on Learning Representations, Addis Ababa, Ethiopia},
|
56 |
+
year = 2020,
|
57 |
+
}
|
58 |
+
"""
|
59 |
+
|
60 |
+
_DESCRIPTION = """\
|
61 |
+
This new module is designed to solve this great NLP task and is crafted with a lot of care.
|
62 |
+
"""
|
63 |
+
|
64 |
+
|
65 |
+
_KWARGS_DESCRIPTION = """
|
66 |
+
Computes the RL reliability metrics from a set of experiments. There is an `"online"` and `"offline"` configuration for evaluation.
|
67 |
+
Args:
|
68 |
+
timestamps: list of timestep lists/arrays that serve as index.
|
69 |
+
rewards: list of reward lists/arrays of each experiment.
|
70 |
+
Returns:
|
71 |
+
dictionary: a set of reliability metrics
|
72 |
+
Examples:
|
73 |
+
>>> import numpy as np
|
74 |
+
>>> rl_reliability = evaluate.load("rl_reliability", "online")
|
75 |
+
>>> results = rl_reliability.compute(
|
76 |
+
... timesteps=[np.linspace(0, 2000000, 1000)],
|
77 |
+
... rewards=[np.linspace(0, 100, 1000)]
|
78 |
+
... )
|
79 |
+
>>> print(results["LowerCVaROnRaw"].round(4))
|
80 |
+
[0.0258]
|
81 |
+
"""
|
82 |
+
|
83 |
+
|
84 |
+
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
85 |
+
class RLReliability(evaluate.EvaluationModule):
|
86 |
+
"""Computes the RL Reliability Metrics."""
|
87 |
+
|
88 |
+
def _info(self):
|
89 |
+
if self.config_name not in ["online", "offline"]:
|
90 |
+
raise KeyError("""You should supply a configuration name selected in '["online", "offline"]'""")
|
91 |
+
|
92 |
+
return evaluate.EvaluationModuleInfo(
|
93 |
+
module_type="metric",
|
94 |
+
description=_DESCRIPTION,
|
95 |
+
citation=_CITATION,
|
96 |
+
inputs_description=_KWARGS_DESCRIPTION,
|
97 |
+
features=datasets.Features(
|
98 |
+
{
|
99 |
+
"timesteps": datasets.Sequence(datasets.Value("int64")),
|
100 |
+
"rewards": datasets.Sequence(datasets.Value("float")),
|
101 |
+
}
|
102 |
+
),
|
103 |
+
homepage="https://github.com/google-research/rl-reliability-metrics",
|
104 |
+
)
|
105 |
+
|
106 |
+
def _compute(
|
107 |
+
self,
|
108 |
+
timesteps,
|
109 |
+
rewards,
|
110 |
+
baseline="default",
|
111 |
+
freq_thresh=0.01,
|
112 |
+
window_size=100000,
|
113 |
+
window_size_trimmed=99000,
|
114 |
+
alpha=0.05,
|
115 |
+
eval_points=None,
|
116 |
+
):
|
117 |
+
if len(timesteps) < N_RUNS_RECOMMENDED:
|
118 |
+
logger.warning(
|
119 |
+
f"For robust statistics it is recommended to use at least {N_RUNS_RECOMMENDED} runs whereas you provided {len(timesteps)}."
|
120 |
+
)
|
121 |
+
|
122 |
+
curves = []
|
123 |
+
for timestep, reward in zip(timesteps, rewards):
|
124 |
+
curves.append(np.stack([timestep, reward]))
|
125 |
+
|
126 |
+
if self.config_name == "online":
|
127 |
+
if baseline == "default":
|
128 |
+
baseline = "curve_range"
|
129 |
+
if eval_points is None:
|
130 |
+
eval_points = DEFAULT_EVAL_POINTS
|
131 |
+
|
132 |
+
metrics = [
|
133 |
+
metrics_online.HighFreqEnergyWithinRuns(thresh=freq_thresh),
|
134 |
+
metrics_online.IqrWithinRuns(
|
135 |
+
window_size=window_size_trimmed, eval_points=eval_points, baseline=baseline
|
136 |
+
),
|
137 |
+
metrics_online.IqrAcrossRuns(
|
138 |
+
lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
|
139 |
+
),
|
140 |
+
metrics_online.LowerCVaROnDiffs(baseline=baseline),
|
141 |
+
metrics_online.LowerCVaROnDrawdown(baseline=baseline),
|
142 |
+
metrics_online.LowerCVaROnAcross(
|
143 |
+
lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
|
144 |
+
),
|
145 |
+
metrics_online.LowerCVaROnRaw(alpha=alpha, baseline=baseline),
|
146 |
+
metrics_online.MadAcrossRuns(
|
147 |
+
lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
|
148 |
+
),
|
149 |
+
metrics_online.MadWithinRuns(
|
150 |
+
eval_points=eval_points, window_size=window_size_trimmed, baseline=baseline
|
151 |
+
),
|
152 |
+
metrics_online.MaxDrawdown(),
|
153 |
+
metrics_online.StddevAcrossRuns(
|
154 |
+
lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
|
155 |
+
),
|
156 |
+
metrics_online.StddevWithinRuns(
|
157 |
+
eval_points=eval_points, window_size=window_size_trimmed, baseline=baseline
|
158 |
+
),
|
159 |
+
metrics_online.UpperCVaROnAcross(
|
160 |
+
alpha=alpha,
|
161 |
+
lowpass_thresh=freq_thresh,
|
162 |
+
eval_points=eval_points,
|
163 |
+
window_size=window_size,
|
164 |
+
baseline=baseline,
|
165 |
+
),
|
166 |
+
metrics_online.UpperCVaROnDiffs(alpha=alpha, baseline=baseline),
|
167 |
+
metrics_online.UpperCVaROnDrawdown(alpha=alpha, baseline=baseline),
|
168 |
+
metrics_online.UpperCVaROnRaw(alpha=alpha, baseline=baseline),
|
169 |
+
metrics_online.MedianPerfDuringTraining(window_size=window_size, eval_points=eval_points),
|
170 |
+
]
|
171 |
+
else:
|
172 |
+
if baseline == "default":
|
173 |
+
baseline = "median_perf"
|
174 |
+
|
175 |
+
metrics = [
|
176 |
+
metrics_offline.MadAcrossRollouts(baseline=baseline),
|
177 |
+
metrics_offline.IqrAcrossRollouts(baseline=baseline),
|
178 |
+
metrics_offline.StddevAcrossRollouts(baseline=baseline),
|
179 |
+
metrics_offline.LowerCVaRAcrossRollouts(alpha=alpha, baseline=baseline),
|
180 |
+
metrics_offline.UpperCVaRAcrossRollouts(alpha=alpha, baseline=baseline),
|
181 |
+
metrics_offline.MedianPerfAcrossRollouts(baseline=None),
|
182 |
+
]
|
183 |
+
|
184 |
+
evaluator = eval_metrics.Evaluator(metrics=metrics)
|
185 |
+
result = evaluator.compute_metrics(curves)
|
186 |
+
return result
|