lvwerra HF staff committed on
Commit
289642b
1 Parent(s): 1304e76

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. rl_reliability.py +28 -7
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
2
  git+https://github.com/google-research/rl-reliability-metrics
3
  scipy
4
  tensorflow
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  git+https://github.com/google-research/rl-reliability-metrics
3
  scipy
4
  tensorflow
rl_reliability.py CHANGED
@@ -13,6 +13,9 @@
13
  # limitations under the License.
14
  """Computes the RL Reliability Metrics."""
15
 
 
 
 
16
  import datasets
17
  import numpy as np
18
  from rl_reliability_metrics.evaluation import eval_metrics
@@ -81,11 +84,27 @@ Examples:
81
  """
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
85
  class RLReliability(evaluate.Metric):
86
  """Computes the RL Reliability Metrics."""
87
 
88
- def _info(self):
 
 
 
89
  if self.config_name not in ["online", "offline"]:
90
  raise KeyError("""You should supply a configuration name selected in '["online", "offline"]'""")
91
 
@@ -94,6 +113,7 @@ class RLReliability(evaluate.Metric):
94
  description=_DESCRIPTION,
95
  citation=_CITATION,
96
  inputs_description=_KWARGS_DESCRIPTION,
 
97
  features=datasets.Features(
98
  {
99
  "timesteps": datasets.Sequence(datasets.Value("int64")),
@@ -107,18 +127,19 @@ class RLReliability(evaluate.Metric):
107
  self,
108
  timesteps,
109
  rewards,
110
- baseline="default",
111
- freq_thresh=0.01,
112
- window_size=100000,
113
- window_size_trimmed=99000,
114
- alpha=0.05,
115
- eval_points=None,
116
  ):
117
  if len(timesteps) < N_RUNS_RECOMMENDED:
118
  logger.warning(
119
  f"For robust statistics it is recommended to use at least {N_RUNS_RECOMMENDED} runs whereas you provided {len(timesteps)}."
120
  )
121
 
 
 
 
 
 
 
 
122
  curves = []
123
  for timestep, reward in zip(timesteps, rewards):
124
  curves.append(np.stack([timestep, reward]))
 
13
  # limitations under the License.
14
  """Computes the RL Reliability Metrics."""
15
 
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional
18
+
19
  import datasets
20
  import numpy as np
21
  from rl_reliability_metrics.evaluation import eval_metrics
 
84
  """
85
 
86
 
87
+ @dataclass
88
+ class RLReliabilityConfig(evaluate.info.Config):
89
+
90
+ name: str = "default"
91
+
92
+ baseline: str = "default"
93
+ freq_thresh: float = 0.01
94
+ window_size: int = 100000
95
+ window_size_trimmed: int = 99000
96
+ alpha: float = 0.05
97
+ eval_points: Optional[List] = None
98
+
99
+
100
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
101
  class RLReliability(evaluate.Metric):
102
  """Computes the RL Reliability Metrics."""
103
 
104
+ CONFIG_CLASS = RLReliabilityConfig
105
+ ALLOWED_CONFIG_NAMES = ["online", "offline"]
106
+
107
+ def _info(self, config):
108
  if self.config_name not in ["online", "offline"]:
109
  raise KeyError("""You should supply a configuration name selected in '["online", "offline"]'""")
110
 
 
113
  description=_DESCRIPTION,
114
  citation=_CITATION,
115
  inputs_description=_KWARGS_DESCRIPTION,
116
+ config=config,
117
  features=datasets.Features(
118
  {
119
  "timesteps": datasets.Sequence(datasets.Value("int64")),
 
127
  self,
128
  timesteps,
129
  rewards,
 
 
 
 
 
 
130
  ):
131
  if len(timesteps) < N_RUNS_RECOMMENDED:
132
  logger.warning(
133
  f"For robust statistics it is recommended to use at least {N_RUNS_RECOMMENDED} runs whereas you provided {len(timesteps)}."
134
  )
135
 
136
+ baseline = self.config.baseline
137
+ freq_thresh = self.config.freq_thresh
138
+ window_size = self.config.window_size
139
+ window_size_trimmed = self.config.window_size_trimmed
140
+ alpha = self.config.alpha
141
+ eval_points = self.config.eval_points
142
+
143
  curves = []
144
  for timestep, reward in zip(timesteps, rewards):
145
  curves.append(np.stack([timestep, reward]))