chirbard committed
Commit 1ec1c03 · verified · 1 Parent(s): 1f2ae87

Upload ppo_tetris_v5.ipynb

Files changed (1)
  1. ppo_tetris_v5.ipynb +405 -0
ppo_tetris_v5.ipynb ADDED
@@ -0,0 +1,405 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### The environment 🎮\n",
+ "\n",
+ "- https://gymnasium.farama.org/environments/atari/tetris/\n",
+ "\n",
+ "### The library used 📚\n",
+ "\n",
+ "- [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/)"
+ ],
+ "metadata": {
+ "id": "x7oR6R-ZIbeS"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jeDAH0h0EBiG"
+ },
+ "source": [
+ "## Install dependencies and create a virtual screen 🔽\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!apt install swig cmake"
+ ],
+ "metadata": {
+ "id": "yQIGLPDkGhgG"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9XaULfDZDvrC"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -r https://raw.githubusercontent.com/huggingface/deep-rl-class/main/notebooks/unit1/requirements-unit1.txt\n",
+ "\n",
+ "# The unit 1 requirements target LunarLander; ALE/Tetris-v5 also needs the Atari extras and ROMs\n",
+ "!pip install \"gymnasium[atari,accept-rom-license]\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "During the notebook, we'll need to generate a replay video. To do so on Colab, **we need a virtual screen to be able to render the environment** (and thus record the frames).\n",
+ "\n",
+ "Hence, the following cell installs the virtual screen libraries, then creates and runs a virtual screen 🖥"
+ ],
+ "metadata": {
+ "id": "BEKeXQJsQCYm"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!sudo apt-get update\n",
+ "!sudo apt-get install -y python3-opengl\n",
+ "!apt install ffmpeg\n",
+ "!apt install xvfb\n",
+ "!pip3 install pyvirtualdisplay"
+ ],
+ "metadata": {
+ "id": "j5f2cGkdP-mb"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "To make sure the newly installed libraries are used, **it's sometimes necessary to restart the notebook runtime**. The next cell forces the **runtime to crash, so you'll need to connect again and run the code starting from here**. Thanks to this trick, **we will be able to run our virtual screen.**"
+ ],
+ "metadata": {
+ "id": "TCwBTAwAW9JJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "\n",
+ "# Force the Colab runtime to restart by killing the current process\n",
+ "os.kill(os.getpid(), 9)"
+ ],
+ "metadata": {
+ "id": "cYvkbef7XEMi"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Virtual display\n",
+ "from pyvirtualdisplay import Display\n",
+ "\n",
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
+ "virtual_display.start()"
+ ],
+ "metadata": {
+ "id": "BE5JWP5rQIKf"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wrgpVFqyENVf"
+ },
+ "source": [
+ "## Import the packages 📦\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cygWLPGsEQ0m"
+ },
+ "outputs": [],
+ "source": [
+ "import gymnasium as gym\n",
+ "\n",
+ "from huggingface_sb3 import load_from_hub, package_to_hub\n",
+ "from huggingface_hub import notebook_login # To log in to our Hugging Face account so we can upload models to the Hub.\n",
+ "\n",
+ "from stable_baselines3 import PPO\n",
+ "from stable_baselines3.common.env_util import make_vec_env\n",
+ "from stable_baselines3.common.evaluation import evaluate_policy\n",
+ "from stable_baselines3.common.monitor import Monitor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "w7vOFlpA_ONz"
+ },
+ "outputs": [],
+ "source": [
+ "import gymnasium as gym\n",
+ "\n",
+ "# First, we create our environment\n",
+ "env = gym.make(\"ALE/Tetris-v5\")\n",
+ "\n",
+ "# Then we reset this environment\n",
+ "observation, info = env.reset()\n",
+ "\n",
+ "for _ in range(20):\n",
+ "    # Take a random action\n",
+ "    action = env.action_space.sample()\n",
+ "    print(\"Action taken:\", action)\n",
+ "\n",
+ "    # Do this action in the environment and get\n",
+ "    # next_state, reward, terminated, truncated and info\n",
+ "    observation, reward, terminated, truncated, info = env.step(action)\n",
+ "\n",
+ "    # If the game is terminated (in our case, the board fills up and the game is over) or truncated (timeout)\n",
+ "    if terminated or truncated:\n",
+ "        # Reset the environment\n",
+ "        print(\"Environment is reset\")\n",
+ "        observation, info = env.reset()\n",
+ "\n",
+ "env.close()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "poLBgRocF9aT"
+ },
+ "source": [
+ "Let's see what the environment looks like:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZNPG0g_UGCfh"
+ },
+ "outputs": [],
+ "source": [
+ "# We create our environment with gym.make(\"<name_of_the_environment>\")\n",
+ "env = gym.make(\"ALE/Tetris-v5\")\n",
+ "env.reset()\n",
+ "print(\"_____OBSERVATION SPACE_____ \\n\")\n",
+ "print(\"Observation Space Shape\", env.observation_space.shape)\n",
+ "print(\"Sample observation\", env.observation_space.sample()) # Get a random observation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "We5WqOBGLoSm"
+ },
+ "outputs": [],
+ "source": [
+ "print(\"\\n _____ACTION SPACE_____ \\n\")\n",
+ "print(\"Action Space Shape\", env.action_space.n)\n",
+ "print(\"Action Space Sample\", env.action_space.sample()) # Take a random action"
+ ]
+ },
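+ {
+ "cell_type": "markdown",
+ "source": [
+ "Since the observation is a raw RGB frame, we can also display one directly. This is a minimal sketch, assuming `matplotlib` is available (it ships with Colab); `frame_env` is a throwaway environment instance used only for rendering:"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# Render a single frame of the environment as an RGB array and display it\n",
+ "frame_env = gym.make(\"ALE/Tetris-v5\", render_mode=\"rgb_array\")\n",
+ "frame_env.reset()\n",
+ "plt.imshow(frame_env.render())\n",
+ "plt.axis(\"off\")\n",
+ "plt.show()\n",
+ "frame_env.close()"
+ ]
+ },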
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dFD9RAFjG8aq"
+ },
+ "source": [
+ "#### Vectorized Environment\n",
+ "\n",
+ "- We create a vectorized environment (a method for stacking multiple independent environments into a single environment) with 16 environments. This way, **we'll have more diverse experiences during training.**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "99hqQ_etEy1N"
+ },
+ "outputs": [],
+ "source": [
+ "# Create the environment\n",
+ "env = make_vec_env('ALE/Tetris-v5', n_envs=16)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QAN7B0_HCVZC"
+ },
+ "source": [
+ "#### Model and hyperparameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "543OHYDfcjK4"
+ },
+ "outputs": [],
+ "source": [
+ "# Note: ALE observations are raw RGB frames; SB3's 'CnnPolicy' is usually a better fit\n",
+ "# for pixel inputs than 'MlpPolicy', which flattens the image into a vector.\n",
+ "model = PPO(\n",
+ "    policy='MlpPolicy',\n",
+ "    env=env,\n",
+ "    n_steps=1024,\n",
+ "    batch_size=64,\n",
+ "    n_epochs=4,\n",
+ "    gamma=0.99,\n",
+ "    gae_lambda=0.98,\n",
+ "    ent_coef=0.01,\n",
+ "    verbose=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ClJJk88yoBUi"
+ },
+ "source": [
+ "## Train the PPO agent 🏃\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "poBCy9u_csyR"
+ },
+ "outputs": [],
+ "source": [
+ "# 100k timesteps is a short training run for an Atari game, so expect modest scores\n",
+ "model.learn(total_timesteps=100000)\n",
+ "\n",
+ "# Save the model\n",
+ "model_name = \"Tetris-v5\"\n",
+ "model.save(model_name)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BqPKw3jt_pG5"
+ },
+ "source": [
+ "#### Evaluate"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zpz8kHlt_a_m"
+ },
+ "outputs": [],
+ "source": [
+ "#@title\n",
+ "# Wrap the eval environment in a Monitor so episode statistics are recorded correctly\n",
+ "eval_env = Monitor(gym.make(\"ALE/Tetris-v5\"))\n",
+ "mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)\n",
+ "print(f\"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#### Upload to the Hub"
+ ],
+ "metadata": {
+ "id": "7YFBLHXDPuH5"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GZiFBBlzxzxY"
+ },
+ "outputs": [],
+ "source": [
+ "notebook_login()\n",
+ "!git config --global credential.helper store"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import gymnasium as gym\n",
+ "\n",
+ "from stable_baselines3 import PPO\n",
+ "from stable_baselines3.common.vec_env import DummyVecEnv\n",
+ "from stable_baselines3.common.env_util import make_vec_env\n",
+ "\n",
+ "from huggingface_sb3 import package_to_hub\n",
+ "\n",
+ "# Define the name of the environment\n",
+ "env_id = \"ALE/Tetris-v5\"\n",
+ "\n",
+ "# Define the model architecture we used\n",
+ "model_architecture = \"PPO\"\n",
+ "\n",
+ "## Define a repo_id\n",
+ "## repo_id is the id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})\n",
+ "## CHANGE WITH YOUR REPO ID\n",
+ "repo_id = \"chirbard/ppo-Tetris-v5\" # Change to your own repo id; you can't push with mine 😄\n",
+ "\n",
+ "## Define the commit message\n",
+ "commit_message = \"Upload PPO Tetris-v5 trained agent\"\n",
+ "\n",
+ "# Create the evaluation env and set the render_mode=\"rgb_array\"\n",
+ "eval_env = DummyVecEnv([lambda: gym.make(env_id, render_mode=\"rgb_array\")])\n",
+ "\n",
+ "# Push the trained model, evaluation results, and a replay video to the Hub\n",
+ "package_to_hub(model=model, # Our trained model\n",
+ "               model_name=model_name, # The name of our trained model\n",
+ "               model_architecture=model_architecture, # The model architecture we used: in our case PPO\n",
+ "               env_id=env_id, # Name of the environment\n",
+ "               eval_env=eval_env, # Evaluation Environment\n",
+ "               repo_id=repo_id, # id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})\n",
+ "               commit_message=commit_message)\n"
+ ],
+ "metadata": {
+ "id": "I2E--IJu8JYq"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
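+ {
+ "cell_type": "markdown",
+ "source": [
+ "#### Load the model back from the Hub\n",
+ "\n",
+ "As a sanity check, we can reload the agent we just pushed with `load_from_hub` (imported earlier). This is a minimal sketch: it assumes `package_to_hub` stored the checkpoint as `Tetris-v5.zip` (our `model_name`); adjust `repo_id` and `filename` to match your own repository."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Download the checkpoint from the Hub (filename assumed to match model_name)\n",
+ "checkpoint = load_from_hub(repo_id=repo_id, filename=f\"{model_name}.zip\")\n",
+ "loaded_model = PPO.load(checkpoint)\n",
+ "\n",
+ "# Re-evaluate the reloaded agent\n",
+ "eval_env = Monitor(gym.make(\"ALE/Tetris-v5\"))\n",
+ "mean_reward, std_reward = evaluate_policy(loaded_model, eval_env, n_eval_episodes=10, deterministic=True)\n",
+ "print(f\"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}\")"
+ ]
+ }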
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "private_outputs": true,
+ "provenance": [],
+ "collapsed_sections": [
+ "QAN7B0_HCVZC",
+ "BqPKw3jt_pG5"
+ ]
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3.9.7",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.9.7"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "ed7f8024e43d3b8f5ca3c5e1a8151ab4d136b3ecee1e3fd59e0766ccc55e1b10"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }