ThomasSimonini HF staff commited on
Commit
91cb2a2
β€’
1 Parent(s): 1f35814

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +262 -0
app.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, hf_hub_download
3
+ from huggingface_hub.repocard import metadata_load
4
+
5
+ import pandas as pd
6
+
7
+ from utils import *
8
+
9
+ api = HfApi()
10
+
11
+ def get_user_models(hf_username, env_tag, lib_tag):
12
+ """
13
+ List the Reinforcement Learning models
14
+ from user given environment and lib
15
+ :param hf_username: User HF username
16
+ :param env_tag: Environment tag
17
+ :param lib_tag: Library tag
18
+ """
19
+ api = HfApi()
20
+ models = api.list_models(author=hf_username, filter=["reinforcement-learning", env_tag, lib_tag])
21
+
22
+ user_model_ids = [x.modelId for x in models]
23
+ return user_model_ids
24
+
25
+
26
+ def get_metadata(model_id):
27
+ """
28
+ Get model metadata (contains evaluation data)
29
+ :param model_id
30
+ """
31
+ try:
32
+ readme_path = hf_hub_download(model_id, filename="README.md")
33
+ return metadata_load(readme_path)
34
+ except requests.exceptions.HTTPError:
35
+ # 404 README.md not found
36
+ return None
37
+
38
+
39
+ def parse_metrics_accuracy(meta):
40
+ """
41
+ Get model results and parse it
42
+ :param meta: model metadata
43
+ """
44
+ if "model-index" not in meta:
45
+ return None
46
+ result = meta["model-index"][0]["results"]
47
+ metrics = result[0]["metrics"]
48
+ accuracy = metrics[0]["value"]
49
+
50
+ return accuracy
51
+
52
+
53
+ def parse_rewards(accuracy):
54
+ """
55
+ Parse mean_reward and std_reward
56
+ :param accuracy: model results
57
+ """
58
+ default_std = -1000
59
+ default_reward= -1000
60
+ if accuracy != None:
61
+ accuracy = str(accuracy)
62
+ parsed = accuracy.split(' +/- ')
63
+ if len(parsed)>1:
64
+ mean_reward = float(parsed[0])
65
+ std_reward = float(parsed[1])
66
+ elif len(parsed)==1: #only mean reward
67
+ mean_reward = float(parsed[0])
68
+ std_reward = float(0)
69
+ else:
70
+ mean_reward = float(default_std)
71
+ std_reward = float(default_reward)
72
+ else:
73
+ mean_reward = float(default_std)
74
+ std_reward = float(default_reward)
75
+
76
+ return mean_reward, std_reward
77
+
78
+ def calculate_best_result(user_model_ids):
79
+ """
80
+ Calculate the best results of a unit
81
+ best_result = mean_reward - std_reward
82
+ :param user_model_ids: RL models of a user
83
+ """
84
+ best_result = -100
85
+ best_model_id = ""
86
+ for model in user_model_ids:
87
+ meta = get_metadata(model)
88
+ if meta is None:
89
+ continue
90
+ accuracy = parse_metrics_accuracy(meta)
91
+ mean_reward, std_reward = parse_rewards(accuracy)
92
+ result = mean_reward - std_reward
93
+ if result > best_result:
94
+ best_result = result
95
+ best_model_id = model
96
+
97
+ return best_result, best_model_id
98
+
99
+ def check_if_passed(model):
100
+ """
101
+ Check if result >= baseline
102
+ to know if you pass
103
+ :param model: user model
104
+ """
105
+ if model["best_result"] >= model["min_result"]:
106
+ model["passed"] = True
107
+
108
+ def test_(hf_username):
109
+ results_certification = [
110
+ {
111
+ "unit": "Unit 1",
112
+ "env": "LunarLander-v2",
113
+ "library": "stable-baselines3",
114
+ "min_result": 200,
115
+ "best_result": 0,
116
+ "best_model_id": "",
117
+ "passed": False
118
+ },
119
+ {
120
+ "unit": "Unit 2",
121
+ "env": "Taxi-v3",
122
+ "library": "q-learning",
123
+ "min_result": 4,
124
+ "best_result": 0,
125
+ "best_model_id": "",
126
+ "passed": False
127
+ },
128
+ {
129
+ "unit": "Unit 3",
130
+ "env": "SpaceInvadersNoFrameskip-v4",
131
+ "library": "stable-baselines3",
132
+ "min_result": 200,
133
+ "best_result": 0,
134
+ "best_model_id": "",
135
+ "passed": False
136
+ },
137
+ {
138
+ "unit": "Unit 4",
139
+ "env": "CartPole-v1",
140
+ "library": "reinforce",
141
+ "min_result": 350,
142
+ "best_result": 0,
143
+ "best_model_id": "",
144
+ "passed": False
145
+ },
146
+ {
147
+ "unit": "Unit 4",
148
+ "env": "Pixelcopter-PLE-v0",
149
+ "library": "reinforce",
150
+ "min_result": 5,
151
+ "best_result": 0,
152
+ "best_model_id": "",
153
+ "passed": False
154
+ },
155
+ {
156
+ "unit": "Unit 5",
157
+ "env": "ML-Agents-SnowballTarget",
158
+ "library": "ml-agents",
159
+ "min_result": -100,
160
+ "best_result": 0,
161
+ "best_model_id": "",
162
+ "passed": False
163
+ },
164
+ {
165
+ "unit": "Unit 5",
166
+ "env": "ML-Agents-Pyramids",
167
+ "library": "ml-agents",
168
+ "min_result": -100,
169
+ "best_result": 0,
170
+ "best_model_id": "",
171
+ "passed": False
172
+ },
173
+ {
174
+ "unit": "Unit 6",
175
+ "env": "AntBulletEnv-v0",
176
+ "library": "stable-baselines3",
177
+ "min_result": 650,
178
+ "best_result": 0,
179
+ "best_model_id": "",
180
+ "passed": False
181
+ },
182
+ {
183
+ "unit": "Unit 6",
184
+ "env": "PandaReachDense-v2",
185
+ "library": "stable-baselines3",
186
+ "min_result": -3.5,
187
+ "best_result": 0,
188
+ "best_model_id": "",
189
+ "passed": False
190
+ },
191
+ {
192
+ "unit": "Unit 7",
193
+ "env": "ML-Agents-SoccerTwos",
194
+ "library": "ml-agents",
195
+ "min_result": -100,
196
+ "best_result": 0,
197
+ "best_model_id": "",
198
+ "passed": False
199
+ },
200
+ {
201
+ "unit": "Unit 8 Part 1",
202
+ "env": "GodotRL-JumperHard",
203
+ "library": "cleanrl",
204
+ "min_result": -100,
205
+ "best_result": 0,
206
+ "best_model_id": "",
207
+ "passed": False
208
+ },
209
+ {
210
+ "unit": "Unit 8 Part 2",
211
+ "env": "Vizdoom-Battle",
212
+ "library": "cleanrl",
213
+ "min_result": -100,
214
+ "best_result": 0,
215
+ "best_model_id": "",
216
+ "passed": False
217
+ },
218
+ ]
219
+ for unit in results_certification:
220
+ # Get user model
221
+ user_models = get_user_models(hf_username, unit['env'], unit['library'])
222
+ print(user_models)
223
+ # Calculate the best result and get the best_model_id
224
+ best_result, best_model_id = calculate_best_result(user_models)
225
+
226
+ # Save best_result and best_model_id
227
+ unit["best_result"] = best_result
228
+ unit["best_model_id"] = make_clickable_model(best_model_id)
229
+
230
+ # Based on best_result do we pass the unit?
231
+ check_if_passed(unit)
232
+ #pass_emoji(unit["passed"])
233
+
234
+ print(results_certification)
235
+
236
+ df = pd.DataFrame (results_certification)
237
+
238
+ return df
239
+
240
+
241
+ with gr.Blocks() as demo:
242
+ gr.Markdown(f"""
243
+ # πŸ† Check your progress in the Deep Reinforcement Learning Course πŸ†
244
+ You can check your progress here.
245
+
246
+ - To get a certificate of completion, you must **pass 80% of the assignments before the end of April 2023**.
247
+ - To get an honors certificate, you must **pass 100% of the assignments before the end of April 2023**.
248
+
249
+ To pass an assignment your model result (mean_reward - std_reward) must be >= min_result
250
+
251
+ **When min_result = -100 it means that you just need to push a model to pass this hands-on. No need to reach a certain result.**
252
+
253
+ Just type your Hugging Face Username πŸ€— (in my case ThomasSimonini)
254
+ """)
255
+
256
+ hf_username = gr.Textbox(placeholder="ThomasSimonini", label="Your Hugging Face Username")
257
+ #email = gr.Textbox(placeholder="[email protected]", label="Your Email (to receive your certificate)")
258
+ check_progress_button = gr.Button(value="Check my progress")
259
+ output = gr.components.Dataframe(value= test_(hf_username), headers=["Unit", "Environment", "Library", "Baseline", "Your best result", "Your best model id", "Pass?"], datatype=["markdown", "markdown", "markdown", "number", "number", "markdown", "bool"])
260
+ check_progress_button.click(fn=test_, inputs=hf_username, outputs=output)
261
+
262
+ demo.launch()