skar0 committed on
Commit 820bb68 · 1 Parent(s): 52c188e

Initial commit

Files changed (5)
  1. .gitignore +136 -0
  2. README.md +3 -12
  3. app.py +54 -0
  4. cartpole.py +635 -0
  5. requirements.txt +438 -0
.gitignore ADDED
@@ -0,0 +1,136 @@
1
+ generate/
2
+ videos/
3
+ token.txt
4
+ pat.txt
5
+ *.ipynb
6
+ runs/
7
+ wandb/
8
+ # Byte-compiled / optimized / DLL files
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ pip-wheel-metadata/
31
+ share/python-wheels/
32
+ *.egg-info/
33
+ .installed.cfg
34
+ *.egg
35
+ MANIFEST
36
+
37
+ # PyInstaller
38
+ # Usually these files are written by a python script from a template
39
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
40
+ *.manifest
41
+ *.spec
42
+
43
+ # Installer logs
44
+ pip-log.txt
45
+ pip-delete-this-directory.txt
46
+
47
+ # Unit test / coverage reports
48
+ htmlcov/
49
+ .tox/
50
+ .nox/
51
+ .coverage
52
+ .coverage.*
53
+ .cache
54
+ nosetests.xml
55
+ coverage.xml
56
+ *.cover
57
+ *.py,cover
58
+ .hypothesis/
59
+ .pytest_cache/
60
+
61
+ # Translations
62
+ *.mo
63
+ *.pot
64
+
65
+ # Django stuff:
66
+ *.log
67
+ local_settings.py
68
+ db.sqlite3
69
+ db.sqlite3-journal
70
+
71
+ # Flask stuff:
72
+ instance/
73
+ .webassets-cache
74
+
75
+ # Scrapy stuff:
76
+ .scrapy
77
+
78
+ # Sphinx documentation
79
+ docs/_build/
80
+
81
+ # PyBuilder
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102
+ __pypackages__/
103
+
104
+ # Celery stuff
105
+ celerybeat-schedule
106
+ celerybeat.pid
107
+
108
+ # SageMath parsed files
109
+ *.sage.py
110
+
111
+ # Environments
112
+ .env
113
+ .venv
114
+ env/
115
+ venv/
116
+ ENV/
117
+ env.bak/
118
+ venv.bak/
119
+
120
+ # Spyder project settings
121
+ .spyderproject
122
+ .spyproject
123
+
124
+ # Rope project settings
125
+ .ropeproject
126
+
127
+ # mkdocs documentation
128
+ /site
129
+
130
+ # mypy
131
+ .mypy_cache/
132
+ .dmypy.json
133
+ dmypy.json
134
+
135
+ # Pyre type checker
136
+ .pyre/
README.md CHANGED
@@ -1,13 +1,4 @@
1
- ---
2
- title: Cartpole Demo
3
- emoji: 🔥
4
- colorFrom: blue
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 3.19.1
8
- app_file: app.py
9
- pinned: false
10
- license: wtfpl
11
- ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # cartpole-demo
2
+ This project publishes my solution to the CartPole environment from OpenAI Gym.
3
 
4
+ I want to deploy it to Hugging Face Spaces.
app.py ADDED
@@ -0,0 +1,54 @@
1
+ import glob
2
+ import gradio as gr
3
+ import gym
4
+ import sys
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ import yaml
7
+ import torch
8
+ from cartpole import (
9
+ make_env, reset_env, Agent, rollout_phase, get_action_shape
10
+ )
11
+
12
+ MAIN = __name__ == "__main__"
13
+ examples = [0, 1, 31415, 'Hello, World!', 'This is a seed...']
14
+
15
+ def generate_video(
16
+ string: str, wandb_path='wandb/run-20230303_211416-ox4d1p0u/files'
17
+ ):
18
+ with open(f'{wandb_path}/config.yaml') as f_cfg:
19
+ config = yaml.safe_load(f_cfg)
20
+ seed = hash(string) % ((sys.maxsize + 1) * 2)
21
+ num_envs = config['num_envs']['value']
22
+ num_steps = config['num_steps']['value']
23
+ assert seed >= 0
24
+ assert isinstance(seed, int)
25
+ run_name = f'seed{seed}'
26
+ log_dir = f'generate/{run_name}'
27
+ writer = SummaryWriter(log_dir)
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ envs = gym.vector.SyncVectorEnv([
30
+ make_env("CartPole-v1", seed, i, True, run_name)
31
+ for i in range(num_envs)
32
+ ])
33
+ action_shape = get_action_shape(envs)
34
+ next_obs, next_done = reset_env(envs, device)
35
+ global_step = 0
36
+ agent = Agent(envs).to(device)
37
+ agent.load_state_dict(torch.load(f'{wandb_path}/model_state_dict.pt'))
38
+ rollout_phase(
39
+ next_obs, next_done, agent, envs, writer, device,
40
+ global_step, action_shape, num_envs, num_steps,
41
+ )
42
+ video_path = glob.glob(f'videos/{run_name}/*.mp4')[0]
43
+ return video_path
44
+
45
+ if MAIN:
46
+ demo = gr.Interface(
47
+ fn=generate_video,
48
+ inputs=[
49
+ gr.components.Textbox(lines=1, label="Seed"),
50
+ ],
51
+ outputs=gr.components.Video(label="Generated Video"),
52
+ examples=examples,
53
+ )
54
+ demo.launch(share=True)
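The generate_video function above can also be exercised without the Gradio UI. A minimal sketch, assuming the wandb run directory hard-coded in app.py (with its config.yaml and model_state_dict.pt) is present locally and that the recorded rollout writes at least one .mp4 under videos/:

# Hypothetical local smoke test; not part of the commit.
# Assumes app.py's default wandb_path exists and RecordVideo produces an .mp4.
from app import generate_video

video_path = generate_video("31415")  # any string works; it is hashed into a non-negative seed
print(video_path)                      # path to the first video found under videos/seed<...>/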
cartpole.py ADDED
@@ -0,0 +1,635 @@
1
+ # ---
2
+ # jupyter:
3
+ # jupytext:
4
+ # text_representation:
5
+ # extension: .py
6
+ # format_name: light
7
+ # format_version: '1.5'
8
+ # jupytext_version: 1.14.5
9
+ # kernelspec:
10
+ # display_name: Python 3
11
+ # name: python3
12
+ # ---
13
+
14
+ # + id="QAY_RQOLcRtA" executionInfo={"status": "ok", "timestamp": 1677942285188, "user_tz": 0, "elapsed": 1942, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} colab={"base_uri": "https://localhost:8080/"} outputId="ee4de327-947e-4f4e-9d34-514460da288a"
15
+ MAIN = __name__ == "__main__"
16
+ if MAIN:
17
+ print('Mounting drive...')
18
+ from google.colab import drive
19
+ drive.mount('/content/drive')
20
+ # %cd /content/drive/MyDrive/Colab Notebooks/cartpole-demo
21
+
22
+ # + colab={"base_uri": "https://localhost:8080/"} id="GgSNZRJh4EjV" executionInfo={"status": "ok", "timestamp": 1677942324397, "user_tz": 0, "elapsed": 39212, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} outputId="8fd1eecc-12d1-4bae-cd15-dd541f1d84c7"
23
+ # !pip install einops
24
+ # !pip install wandb
25
+ # !pip install jupytext
26
+ # !pip install pygame
27
+ # !pip install torchtyping
28
+ # !pip install gradio
29
+
30
+ # + colab={"base_uri": "https://localhost:8080/"} id="1g58HZUb8Ltl" executionInfo={"status": "ok", "timestamp": 1677942492332, "user_tz": 0, "elapsed": 2440, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} outputId="d2f2ab57-c2c0-49aa-fdef-323556a2e4b6"
31
+ # !git config --global user.email "[email protected]"
32
+ # !git config --global user.name "ojh31"
33
+ # !cat pat.txt | xargs git remote set-url origin
34
+ # !jupytext --to py cartpole.ipynb
35
+ # !git fetch
36
+ # !git status
37
+
38
+ # + id="vEczQ48wC40O" executionInfo={"status": "ok", "timestamp": 1677942330521, "user_tz": 0, "elapsed": 4062, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
39
+ import os
40
+ import glob
41
+ import sys
42
+ import argparse
43
+ import random
44
+ import time
45
+ from distutils.util import strtobool
46
+ from dataclasses import dataclass
47
+ from typing import Optional
48
+ import numpy as np
49
+ import torch
50
+ import torch as t
51
+ from torchtyping import TensorType as TT
52
+ from typeguard import typechecked
53
+ import gym
54
+ import torch.nn as nn
55
+ import torch.optim as optim
56
+ from torch.distributions.categorical import Categorical
57
+ from torch.utils.tensorboard import SummaryWriter
58
+ from gym.spaces import Discrete
59
+ from typing import Any, List, Optional, Union, Tuple, Iterable
60
+ from einops import rearrange
61
+ import importlib
62
+ import wandb
63
+ from typeguard import typechecked
64
+
65
+
66
+ # + id="K7T8bs1Y76ZK" executionInfo={"status": "ok", "timestamp": 1677942330521, "user_tz": 0, "elapsed": 8, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} colab={"base_uri": "https://localhost:8080/"} outputId="f59ffef0-7156-4f27-d992-a392d59a1c73"
67
+ # %env "WANDB_NOTEBOOK_NAME" "cartpole.py"
68
+
69
+ # + id="Q5E93-BGRjuy" executionInfo={"status": "ok", "timestamp": 1677942330522, "user_tz": 0, "elapsed": 8, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
70
+ def make_env(
71
+ env_id: str, seed: int, idx: int, capture_video: bool, run_name: str
72
+ ):
73
+ """
74
+ Return a function that returns an environment after setting up boilerplate.
75
+ """
76
+
77
+ def thunk():
78
+ env = gym.make(env_id, new_step_api=True)
79
+ env = gym.wrappers.RecordEpisodeStatistics(env)
80
+ if capture_video:
81
+ if idx == 0:
82
+ # Record a video every 50 episodes for the first env (idx == 0)
83
+ env = gym.wrappers.RecordVideo(
84
+ env,
85
+ f"videos/{run_name}",
86
+ episode_trigger=lambda x : x % 50 == 0
87
+ )
88
+ obs = env.reset(seed=seed)
89
+ env.action_space.seed(seed)
90
+ env.observation_space.seed(seed)
91
+ return env
92
+
93
+ return thunk
94
+
95
+
96
+ # + id="Kf152ROwHjM_" executionInfo={"status": "ok", "timestamp": 1677942330522, "user_tz": 0, "elapsed": 7, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
97
+ def test_minibatch_indexes(minibatch_indexes):
98
+ for n in range(5):
99
+ frac, minibatch_size = np.random.randint(1, 8, size=(2,))
100
+ batch_size = frac * minibatch_size
101
+ indices = minibatch_indexes(batch_size, minibatch_size)
102
+ assert any([isinstance(indices, list), isinstance(indices, np.ndarray)])
103
+ assert isinstance(indices[0], np.ndarray)
104
+ assert len(indices) == frac
105
+ np.testing.assert_equal(np.sort(np.stack(indices).flatten()), np.arange(batch_size))
106
+
107
+
108
+ # + id="mhvduVeOHkln" executionInfo={"status": "ok", "timestamp": 1677942330522, "user_tz": 0, "elapsed": 7, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
109
+ def test_calc_entropy_bonus(calc_entropy_bonus):
110
+ probs = Categorical(logits=t.randn((3, 4)))
111
+ ent_coef = 0.5
112
+ expected = ent_coef * probs.entropy().mean()
113
+ actual = calc_entropy_bonus(probs, ent_coef)
114
+ t.testing.assert_close(expected, actual)
115
+
116
+
117
+ # + id="Aya60GeCGA5X" executionInfo={"status": "ok", "timestamp": 1677942330875, "user_tz": 0, "elapsed": 360, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
118
+ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
119
+ t.nn.init.orthogonal_(layer.weight, std)
120
+ t.nn.init.constant_(layer.bias, bias_const)
121
+ return layer
122
+
123
+ class Agent(nn.Module):
124
+ critic: nn.Sequential
125
+ actor: nn.Sequential
126
+
127
+ def __init__(self, envs: gym.vector.SyncVectorEnv):
128
+ super().__init__()
129
+ obs_shape = np.array(
130
+ (envs.num_envs, ) + envs.single_action_space.shape
131
+ ).prod().astype(int)
132
+ self.actor = nn.Sequential(
133
+ layer_init(nn.Linear(obs_shape, 64)),
134
+ nn.Tanh(),
135
+ layer_init(nn.Linear(64, 64)),
136
+ nn.Tanh(),
137
+ layer_init(nn.Linear(64, envs.single_action_space.n), std=.01),
138
+ )
139
+ self.critic = nn.Sequential(
140
+ layer_init(nn.Linear(obs_shape, 64)),
141
+ nn.Tanh(),
142
+ layer_init(nn.Linear(64, 64)),
143
+ nn.Tanh(),
144
+ layer_init(nn.Linear(64, 1), std=1),
145
+ )
146
+
147
+
148
+
149
+ # + id="6PwPZHlLGDYu" executionInfo={"status": "ok", "timestamp": 1677942330875, "user_tz": 0, "elapsed": 4, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
150
+ # %%
151
+ @t.inference_mode()
152
+ def compute_advantages(
153
+ next_value: t.Tensor,
154
+ next_done: t.Tensor,
155
+ rewards: t.Tensor,
156
+ values: t.Tensor,
157
+ dones: t.Tensor,
158
+ device: t.device,
159
+ gamma: float,
160
+ gae_lambda: float,
161
+ ) -> t.Tensor:
162
+ '''Compute advantages using Generalized Advantage Estimation.
163
+
164
+ next_value: shape (1, env) -
165
+ represents V(s_{t+1}) which is needed for the last advantage term
166
+ next_done: shape (env,)
167
+ rewards: shape (t, env)
168
+ values: shape (t, env)
169
+ dones: shape (t, env)
170
+
171
+ Return: shape (t, env)
172
+ '''
173
+ assert isinstance(next_value, t.Tensor)
174
+ assert isinstance(next_done, t.Tensor)
175
+ assert isinstance(rewards, t.Tensor)
176
+ assert isinstance(values, t.Tensor)
177
+ assert isinstance(dones, t.Tensor)
178
+ t_max, n_env = values.shape
179
+ next_values = t.concat((values[1:, ], next_value))
180
+ next_dones = t.concat((dones[1:, ], next_done.unsqueeze(0)))
181
+ deltas = rewards + gamma * next_values * (1.0 - next_dones) - values
182
+ adv = deltas.clone().to(device)
183
+ for to_go in range(1, t_max):
184
+ t_idx = t_max - to_go - 1
185
+ t.testing.assert_close(adv[t_idx], deltas[t_idx])
186
+ adv[t_idx] += (
187
+ gamma * gae_lambda * adv[t_idx + 1] * (1.0 - next_dones[t_idx])
188
+ )
189
+ return adv
190
+
191
+
192
+
193
+ # + id="uYSSMnF-GPvm" executionInfo={"status": "ok", "timestamp": 1677942330875, "user_tz": 0, "elapsed": 3, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
194
+ # %%
195
+ @dataclass
196
+ class Minibatch:
197
+ obs: t.Tensor
198
+ logprobs: t.Tensor
199
+ actions: t.Tensor
200
+ advantages: t.Tensor
201
+ returns: t.Tensor
202
+ values: t.Tensor
203
+
204
+ def minibatch_indexes(
205
+ batch_size: int, minibatch_size: int
206
+ ) -> List[np.ndarray]:
207
+ '''
208
+ Return a list of length (batch_size // minibatch_size) where
209
+ each element is an array of indexes into the batch.
210
+
211
+ Each index should appear exactly once.
212
+ '''
213
+ assert batch_size % minibatch_size == 0
214
+ n = batch_size // minibatch_size
215
+ indices = np.arange(batch_size)
216
+ np.random.shuffle(indices)
217
+ return [indices[i::n] for i in range(n)]
218
+
219
+ if MAIN:
220
+ test_minibatch_indexes(minibatch_indexes)
221
+
222
+ def make_minibatches(
223
+ obs: t.Tensor,
224
+ logprobs: t.Tensor,
225
+ actions: t.Tensor,
226
+ advantages: t.Tensor,
227
+ values: t.Tensor,
228
+ obs_shape: tuple,
229
+ action_shape: tuple,
230
+ batch_size: int,
231
+ minibatch_size: int,
232
+ ) -> List[Minibatch]:
233
+ '''
234
+ Flatten the environment and steps dimension into one batch dimension,
235
+ then shuffle and split into minibatches.
236
+ '''
237
+ n_steps, n_env = values.shape
238
+ n_dim = n_steps * n_env
239
+ indexes = minibatch_indexes(batch_size=batch_size, minibatch_size=minibatch_size)
240
+ obs_flat = obs.reshape((batch_size,) + obs_shape)
241
+ act_flat = actions.reshape((batch_size,) + action_shape)
242
+ probs_flat = logprobs.reshape((batch_size,) + action_shape)
243
+ adv_flat = advantages.reshape(n_dim)
244
+ val_flat = values.reshape(n_dim)
245
+ return [
246
+ Minibatch(
247
+ obs_flat[idx], probs_flat[idx], act_flat[idx], adv_flat[idx],
248
+ adv_flat[idx] + val_flat[idx], val_flat[idx]
249
+ )
250
+ for idx in indexes
251
+ ]
252
+
253
+
254
+
255
+ # + id="K7wXDJ9MGOWu" executionInfo={"status": "ok", "timestamp": 1677942330876, "user_tz": 0, "elapsed": 4, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
256
+ # %%
257
+ def calc_policy_loss(
258
+ probs: Categorical, mb_action: t.Tensor, mb_advantages: t.Tensor,
259
+ mb_logprobs: t.Tensor, clip_coef: float
260
+ ) -> t.Tensor:
261
+ '''
262
+ Return the policy loss, suitable for maximisation with gradient ascent.
263
+
264
+ probs:
265
+ a distribution containing the actor's unnormalized logits of
266
+ shape (minibatch, num_actions)
267
+
268
+ clip_coef: amount of clipping, denoted by epsilon in Eq 7.
269
+
270
+ mb_advantages is normalised to mean 0, variance 1 inside this function before use.
271
+ '''
272
+ adv_norm = (mb_advantages - mb_advantages.mean()) / mb_advantages.std()
273
+ ratio = t.exp(probs.log_prob(mb_action)) / t.exp(mb_logprobs)
274
+ min_left = ratio * adv_norm
275
+ min_right = t.clip(ratio, 1 - clip_coef, 1 + clip_coef) * adv_norm
276
+ return t.minimum(min_left, min_right).mean()
277
+
278
+
279
+
280
+ # + id="CmyxU6JWGMsG" executionInfo={"status": "ok", "timestamp": 1677942330876, "user_tz": 0, "elapsed": 4, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
281
+ # %%
282
+ def calc_value_function_loss(
283
+ critic: nn.Sequential, mb_obs: t.Tensor, mb_returns: t.Tensor, v_coef: float
284
+ ) -> t.Tensor:
285
+ '''Compute the value function portion of the loss function.
286
+ Need to minimise this
287
+
288
+ v_coef:
289
+ the coefficient for the value loss, which weights its contribution to
290
+ the overall loss. Denoted by c_1 in the paper.
291
+ '''
292
+ output = critic(mb_obs)
293
+ return v_coef * (output - mb_returns).pow(2).mean() / 2
294
+
295
+
296
+
297
+ # + id="npyWs6xjGLkP" executionInfo={"status": "ok", "timestamp": 1677942331469, "user_tz": 0, "elapsed": 597, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
298
+ # %%
299
+ def calc_entropy_loss(probs: Categorical, ent_coef: float):
300
+ '''Return the entropy loss term.
301
+ Need to maximise this
302
+
303
+ ent_coef:
304
+ The coefficient for the entropy loss, which weights its contribution to the overall loss.
305
+ Denoted by c_2 in the paper.
306
+ '''
307
+ return probs.entropy().mean() * ent_coef
308
+
309
+ if MAIN:
310
+ test_calc_entropy_bonus(calc_entropy_loss)
311
+
312
+
313
+ # + id="nqJeg1kZGKSG" executionInfo={"status": "ok", "timestamp": 1677942331470, "user_tz": 0, "elapsed": 5, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
314
+ # %%
315
+ class PPOScheduler:
316
+ def __init__(self, optimizer: optim.Adam, initial_lr: float, end_lr: float, num_updates: int):
317
+ self.optimizer = optimizer
318
+ self.initial_lr = initial_lr
319
+ self.end_lr = end_lr
320
+ self.num_updates = num_updates
321
+ self.n_step_calls = 0
322
+
323
+ def step(self):
324
+ '''
325
+ Implement linear learning rate decay so that after num_updates calls to step,
326
+ the learning rate is end_lr.
327
+ '''
328
+ lr = (
329
+ self.initial_lr +
330
+ (self.end_lr - self.initial_lr) * self.n_step_calls / self.num_updates
331
+ )
332
+ for param in self.optimizer.param_groups:
333
+ param['lr'] = lr
334
+ self.n_step_calls += 1
335
+
336
+ def make_optimizer(
337
+ agent: Agent, num_updates: int, initial_lr: float, end_lr: float
338
+ ) -> Tuple[optim.Adam, PPOScheduler]:
339
+ '''Return an appropriately configured Adam with its attached scheduler.'''
340
+ optimizer = optim.Adam(agent.parameters(), lr=initial_lr, maximize=True)
341
+ scheduler = PPOScheduler(
342
+ optimizer=optimizer, initial_lr=initial_lr, end_lr=end_lr, num_updates=num_updates
343
+ )
344
+ return optimizer, scheduler
345
+
346
+
347
+
348
+ # + id="mgZ7-wsRCxJW" executionInfo={"status": "ok", "timestamp": 1677942331470, "user_tz": 0, "elapsed": 5, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
349
+ @dataclass
350
+ class PPOArgs:
351
+ exp_name: str = 'cartpole.py'
352
+ seed: int = 1
353
+ torch_deterministic: bool = True
354
+ cuda: bool = True
355
+ track: bool = True
356
+ wandb_project_name: str = "PPOCart"
357
+ wandb_entity: str = None
358
+ capture_video: bool = True
359
+ env_id: str = "CartPole-v1"
360
+ total_timesteps: int = 40_000
361
+ learning_rate: float = 0.00025
362
+ num_envs: int = 4
363
+ num_steps: int = 128
364
+ gamma: float = 0.99
365
+ gae_lambda: float = 0.95
366
+ num_minibatches: int = 4
367
+ update_epochs: int = 4
368
+ clip_coef: float = 0.2
369
+ ent_coef: float = 0.01
370
+ vf_coef: float = 0.5
371
+ max_grad_norm: float = 0.5
372
+ batch_size: int = 512
373
+ minibatch_size: int = 128
374
+
375
+
376
+ # + id="xeIu-J3ZwGyq" executionInfo={"status": "ok", "timestamp": 1677942356492, "user_tz": 0, "elapsed": 218, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
377
+ def wandb_init(name: str, args: PPOArgs):
378
+ wandb.init(
379
+ project=args.wandb_project_name,
380
+ entity=args.wandb_entity,
381
+ sync_tensorboard=True,
382
+ config=vars(args),
383
+ name=name,
384
+ monitor_gym=True,
385
+ save_code=True,
386
+ settings=wandb.Settings(symlink=False)
387
+ )
388
+
389
+
390
+ # + id="gMYWqhsryYHy" executionInfo={"status": "ok", "timestamp": 1677942331470, "user_tz": 0, "elapsed": 4, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
391
+ def set_seed(seed: int):
392
+ random.seed(seed)
393
+ np.random.seed(seed)
394
+ torch.manual_seed(seed)
395
+
396
+
397
+ # + id="T9j_L0Wpyrgz" executionInfo={"status": "ok", "timestamp": 1677942331471, "user_tz": 0, "elapsed": 5, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
398
+ @typechecked
399
+ def rollout_phase(
400
+ next_obs: t.Tensor, next_done: t.Tensor,
401
+ agent: Agent, envs: gym.vector.SyncVectorEnv,
402
+ writer: SummaryWriter, device: torch.device,
403
+ global_step: int, action_shape: Tuple,
404
+ num_envs: int, num_steps: int,
405
+ ) -> Tuple[
406
+ TT['envs'],
407
+ TT['envs'],
408
+ TT['steps', 'envs'],
409
+ TT['steps', 'envs'],
410
+ TT['steps', 'envs'],
411
+ TT['steps', 'envs'],
412
+ TT['steps', 'envs'],
413
+ TT['steps', 'envs'],
414
+ ]:
415
+ '''
416
+ Output:
417
+
418
+ next_obs, next_done, actions, dones, logprobs, obs, rewards, values
419
+ '''
420
+ obs = torch.zeros(
421
+ (num_steps, num_envs) +
422
+ envs.single_observation_space.shape
423
+ ).to(device)
424
+ actions = torch.zeros(
425
+ (num_steps, num_envs) +
426
+ action_shape
427
+ ).to(device)
428
+ logprobs = torch.zeros((num_steps, num_envs)).to(device)
429
+ rewards = torch.zeros((num_steps, num_envs)).to(device)
430
+ dones = torch.zeros((num_steps, num_envs)).to(device)
431
+ values = torch.zeros((num_steps, num_envs)).to(device)
432
+ for i in range(0, num_steps):
433
+ # Rollout phase
434
+ global_step += 1
435
+ curr_obs = next_obs
436
+ done = next_done
437
+ with t.inference_mode():
438
+ logits = agent.actor(curr_obs).detach()
439
+ q_values = agent.critic(curr_obs).detach().squeeze(-1)
440
+ prob = Categorical(logits=logits)
441
+ action = prob.sample()
442
+ logprob = prob.log_prob(action)
443
+ next_obs, reward, next_done, info = envs.step(action.numpy())
444
+ next_obs = t.tensor(next_obs, device=device)
445
+ next_done = t.tensor(next_done, device=device)
446
+ actions[i] = action
447
+ dones[i] = done.detach().clone()
448
+ logprobs[i] = logprob
449
+ obs[i] = curr_obs
450
+ rewards[i] = t.tensor(reward, device=device)
451
+ values[i] = q_values
452
+
453
+ if writer is not None and "episode" in info.keys():
454
+ for item in info['episode']:
455
+ if item is None or 'r' not in item.keys():
456
+ continue
457
+ writer.add_scalar(
458
+ "charts/episodic_return", item["r"], global_step
459
+ )
460
+ writer.add_scalar(
461
+ "charts/episodic_length", item["l"], global_step
462
+ )
463
+ if global_step % 10 != 0:
464
+ continue
465
+ print(
466
+ f"global_step={global_step}, episodic_return={item['r']}"
467
+ )
468
+ print("charts/episodic_return", item["r"], global_step)
469
+ print("charts/episodic_length", item["l"], global_step)
470
+ return (
471
+ next_obs, next_done, actions, dones, logprobs, obs, rewards, values
472
+ )
473
+
474
+
475
+ # + id="xdDhABIk5jyb" executionInfo={"status": "ok", "timestamp": 1677942331471, "user_tz": 0, "elapsed": 5, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
476
+ def reset_env(envs, device):
477
+ next_obs = torch.Tensor(envs.reset()).to(device)
478
+ next_done = torch.zeros(envs.num_envs).to(device)
479
+ return next_obs, next_done
480
+
481
+
482
+ # + id="5CoMpUVU7rFT" executionInfo={"status": "ok", "timestamp": 1677942331471, "user_tz": 0, "elapsed": 5, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
483
+ def get_action_shape(envs: gym.vector.SyncVectorEnv):
484
+ action_shape = envs.single_action_space.shape
485
+ assert action_shape is not None
486
+ assert isinstance(
487
+ envs.single_action_space, Discrete
488
+ ), "only discrete action space is supported"
489
+ return action_shape
490
+
491
+
492
+ # + id="FHmn5kSUGFFu" executionInfo={"status": "ok", "timestamp": 1677942366007, "user_tz": 0, "elapsed": 251, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
493
+ # %%
494
+ def train_ppo(args: PPOArgs):
495
+ t0 = int(time.time())
496
+ run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{t0}"
497
+ if args.track:
498
+ wandb_init(run_name, args)
499
+ log_dir = wandb.run.dir
500
+ writer = SummaryWriter(log_dir)
501
+ writer.add_text(
502
+ "hyperparameters",
503
+ "|param|value|\n|-|-|\n%s" % "\n".join([f"|{key}|{value}|"
504
+ for (key, value) in vars(args).items()]),
505
+ )
506
+ set_seed(args.seed)
507
+ torch.backends.cudnn.deterministic = args.torch_deterministic
508
+ device = torch.device(
509
+ "cuda" if torch.cuda.is_available() and args.cuda else "cpu"
510
+ )
511
+ envs = gym.vector.SyncVectorEnv([
512
+ make_env(args.env_id, args.seed + i, i, args.capture_video, run_name)
513
+ for i in range(args.num_envs)
514
+ ])
515
+ agent = Agent(envs).to(device)
516
+ num_updates = args.total_timesteps // args.batch_size
517
+ (optimizer, scheduler) = make_optimizer(
518
+ agent, num_updates, args.learning_rate, 0.0
519
+ )
520
+ global_step = 0
521
+ old_approx_kl = 0.0
522
+ approx_kl = 0.0
523
+ value_loss = t.tensor(0.0)
524
+ policy_loss = t.tensor(0.0)
525
+ entropy_loss = t.tensor(0.0)
526
+ clipfracs = []
527
+ info = []
528
+ action_shape = get_action_shape(envs)
529
+ next_obs, next_done = reset_env(envs, device)
530
+ start_time = time.time()
531
+ for _ in range(num_updates):
532
+ rp = rollout_phase(
533
+ next_obs, next_done, agent, envs, writer, device, global_step,
534
+ action_shape, args.num_envs, args.num_steps,
535
+ )
536
+ next_obs, next_done, actions, dones, logprobs, obs, rewards, values = rp
537
+ with t.inference_mode():
538
+ next_value = rearrange(agent.critic(next_obs), "env 1 -> 1 env")
539
+ advantages = compute_advantages(
540
+ next_value, next_done, rewards, values, dones, device, args.gamma,
541
+ args.gae_lambda
542
+ )
543
+ clipfracs.clear()
544
+ mb: Minibatch
545
+ for _ in range(args.update_epochs):
546
+ minibatches = make_minibatches(
547
+ obs,
548
+ logprobs,
549
+ actions,
550
+ advantages,
551
+ values,
552
+ envs.single_observation_space.shape,
553
+ action_shape,
554
+ args.batch_size,
555
+ args.minibatch_size,
556
+ )
557
+ for mb in minibatches:
558
+ probs = Categorical(logits=agent.actor(mb.obs))
559
+ value_loss = calc_value_function_loss(
560
+ agent.critic, mb.obs, mb.returns, args.vf_coef
561
+ )
562
+ policy_loss = calc_policy_loss(
563
+ probs, mb.actions, mb.advantages, mb.logprobs,
564
+ args.clip_coef
565
+ )
566
+ entropy_loss = calc_entropy_loss(probs, args.ent_coef)
567
+ loss = policy_loss + entropy_loss - value_loss
568
+ loss.backward()
569
+ nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
570
+ optimizer.step()
571
+ optimizer.zero_grad()
572
+
573
+ scheduler.step()
574
+ (y_pred, y_true) = (mb.values.cpu().numpy(), mb.returns.cpu().numpy())
575
+ var_y = np.var(y_true)
576
+ explained_var = (
577
+ np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
578
+ )
579
+ with torch.no_grad():
580
+ newlogprob: t.Tensor = probs.log_prob(mb.actions)
581
+ logratio = newlogprob - mb.logprobs
582
+ ratio = logratio.exp()
583
+ old_approx_kl = (-logratio).mean().item()
584
+ approx_kl = (ratio - 1 - logratio).mean().item()
585
+ clipfracs += [
586
+ ((ratio - 1.0).abs() > args.clip_coef).float().mean().item()
587
+ ]
588
+ writer.add_scalar(
589
+ "charts/learning_rate", optimizer.param_groups[0]["lr"],
590
+ global_step
591
+ )
592
+ writer.add_scalar("losses/value_loss", value_loss.item(), global_step)
593
+ writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
594
+ writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
595
+ writer.add_scalar("losses/old_approx_kl", old_approx_kl, global_step)
596
+ writer.add_scalar("losses/approx_kl", approx_kl, global_step)
597
+ writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
598
+ writer.add_scalar(
599
+ "losses/explained_variance", explained_var, global_step
600
+ )
601
+ writer.add_scalar(
602
+ "charts/SPS",
603
+ int(global_step / (time.time() - start_time)),
604
+ global_step
605
+ )
606
+ if global_step % 1000 == 0:
607
+ print(
608
+ "steps per second (SPS):",
609
+ int(global_step / (time.time() - start_time))
610
+ )
611
+ print("losses/value_loss", value_loss.item())
612
+ print("losses/policy_loss", policy_loss.item())
613
+ print("losses/entropy", entropy_loss.item())
614
+ print(f'... training complete after {global_step} steps')
615
+ envs.close()
616
+ writer.close()
617
+ if args.track:
618
+ model_path = f'{wandb.run.dir}/model_state_dict.pt'
619
+ print(f'Saving model to {model_path}')
620
+ t.save(agent.state_dict(), model_path)
621
+ wandb.finish()
622
+ print('...wandb finished.')
623
+
624
+
625
+ # + id="-oZHTffJZP17" executionInfo={"status": "ok", "timestamp": 1677942433344, "user_tz": 0, "elapsed": 66678, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} colab={"base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": ["c966d31ee30d43e0a8cc269a8a22b717", "294a378e56c44e4c9a3c58e8bf5b5f62", "473cc94ea22746f3a51e2186d973f741", "e3bb8c5a2c3841c2b33a7b8afb66a88f", "6133d8cbba964b7e8755e1c0691caf27", "1bf18f5fae9c4f58b2e360bc35251a94", "e820d38826494e248ca8974cccc1f338", "05eebe964b4b4c93b4aa0eac9ff865cb"]} outputId="0cfbb11c-831a-4622-8c01-afebae209d04"
626
+ # #%%wandb
627
+ if MAIN:
628
+ args = PPOArgs()
629
+ train_ppo(args)
630
+
631
+ # + colab={"base_uri": "https://localhost:8080/"} id="xJW6KL7QIj4s" outputId="7c529849-6d46-4a6a-def5-e1c0ef652c64"
632
+ # !python demo.py
633
+
634
+ # + id="P7ZfUlAqImIr" executionInfo={"status": "aborted", "timestamp": 1677942332655, "user_tz": 0, "elapsed": 4, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
635
+
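For reference, the advantage and policy-loss code in cartpole.py above matches the standard PPO formulation. In the paper's notation (the symbols below are not defined in the code itself), compute_advantages implements the GAE recursion and calc_policy_loss the clipped surrogate objective:

\delta_t = r_t + \gamma\,(1 - d_{t+1})\,V(s_{t+1}) - V(s_t),
\qquad \hat{A}_t = \delta_t + \gamma\lambda\,(1 - d_{t+1})\,\hat{A}_{t+1}

L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\!\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\;
\operatorname{clip}\!\left(r_t(\theta),\,1-\varepsilon,\,1+\varepsilon\right)\hat{A}_t\right)\right],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}

Here d_t is the done flag, \varepsilon is clip_coef, and the advantages are normalised to mean 0, std 1 per minibatch (adv_norm) before entering the clipped objective, which the Adam optimizer maximises (maximize=True).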
requirements.txt ADDED
@@ -0,0 +1,438 @@
1
+ absl-py==1.4.0
2
+ aeppl==0.0.33
3
+ aesara==2.7.9
4
+ aiofiles==23.1.0
5
+ aiohttp==3.8.4
6
+ aiosignal==1.3.1
7
+ alabaster==0.7.13
8
+ albumentations==1.2.1
9
+ altair==4.2.2
10
+ anyio==3.6.2
11
+ appdirs==1.4.4
12
+ argon2-cffi==21.3.0
13
+ argon2-cffi-bindings==21.2.0
14
+ arviz==0.12.1
15
+ astor==0.8.1
16
+ astropy==4.3.1
17
+ astunparse==1.6.3
18
+ async-timeout==4.0.2
19
+ atomicwrites==1.4.1
20
+ attrs==22.2.0
21
+ audioread==3.0.0
22
+ autograd==1.5
23
+ Babel==2.12.1
24
+ backcall==0.2.0
25
+ backports.zoneinfo==0.2.1
26
+ beautifulsoup4==4.6.3
27
+ bleach==6.0.0
28
+ blis==0.7.9
29
+ bokeh==2.4.3
30
+ branca==0.6.0
31
+ bs4==0.0.1
32
+ CacheControl==0.12.11
33
+ cachetools==5.3.0
34
+ catalogue==2.0.8
35
+ certifi==2022.12.7
36
+ cffi==1.15.1
37
+ cftime==1.6.2
38
+ chardet==4.0.0
39
+ charset-normalizer==3.0.1
40
+ click==8.1.3
41
+ clikit==0.6.2
42
+ cloudpickle==2.2.1
43
+ cmake==3.22.6
44
+ cmdstanpy==1.1.0
45
+ colorcet==3.0.1
46
+ colorlover==0.3.0
47
+ community==1.0.0b1
48
+ confection==0.0.4
49
+ cons==0.4.5
50
+ contextlib2==0.5.5
51
+ convertdate==2.4.0
52
+ crashtest==0.3.1
53
+ crcmod==1.7
54
+ cufflinks==0.17.3
55
+ cvxopt==1.3.0
56
+ cvxpy==1.2.3
57
+ cycler==0.11.0
58
+ cymem==2.0.7
59
+ Cython==0.29.33
60
+ dask==2022.2.1
61
+ datascience==0.17.6
62
+ db-dtypes==1.0.5
63
+ dbus-python==1.2.16
64
+ debugpy==1.6.4
65
+ decorator==4.4.2
66
+ defusedxml==0.7.1
67
+ distributed==2022.2.1
68
+ dlib==19.24.0
69
+ dm-tree==0.1.8
70
+ dnspython==2.3.0
71
+ docker-pycreds==0.4.0
72
+ docutils==0.16
73
+ dopamine-rl==1.0.5
74
+ earthengine-api==0.1.342
75
+ easydict==1.10
76
+ ecos==2.0.12
77
+ editdistance==0.5.3
78
+ einops==0.6.0
79
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl
80
+ entrypoints==0.4
81
+ ephem==4.1.4
82
+ et-xmlfile==1.1.0
83
+ etils==1.0.0
84
+ etuples==0.3.8
85
+ fa2==0.3.5
86
+ fastai==2.7.11
87
+ fastapi==0.92.0
88
+ fastcore==1.5.28
89
+ fastdownload==0.0.7
90
+ fastdtw==0.3.4
91
+ fastjsonschema==2.16.3
92
+ fastprogress==1.0.3
93
+ fastrlock==0.8.1
94
+ feather-format==0.4.1
95
+ ffmpy==0.3.0
96
+ filelock==3.9.0
97
+ firebase-admin==5.3.0
98
+ fix-yahoo-finance==0.0.22
99
+ Flask==2.2.3
100
+ flatbuffers==23.1.21
101
+ folium==0.12.1.post1
102
+ fonttools==4.38.0
103
+ frozenlist==1.3.3
104
+ fsspec==2023.1.0
105
+ future==0.16.0
106
+ gast==0.4.0
107
+ GDAL==3.3.2
108
+ gdown==4.4.0
109
+ gensim==3.6.0
110
+ geographiclib==1.52
111
+ geopy==1.17.0
112
+ gin-config==0.5.0
113
+ gitdb==4.0.10
114
+ GitPython==3.1.31
115
+ glob2==0.7
116
+ google==2.0.3
117
+ google-api-core==2.11.0
118
+ google-api-python-client==2.70.0
119
+ google-auth==2.16.1
120
+ google-auth-httplib2==0.1.0
121
+ google-auth-oauthlib==0.4.6
122
+ google-cloud-bigquery==3.4.2
123
+ google-cloud-bigquery-storage==2.18.1
124
+ google-cloud-core==2.3.2
125
+ google-cloud-datastore==2.11.1
126
+ google-cloud-firestore==2.7.3
127
+ google-cloud-language==2.6.1
128
+ google-cloud-storage==2.7.0
129
+ google-cloud-translate==3.8.4
130
+ google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz
131
+ google-crc32c==1.5.0
132
+ google-pasta==0.2.0
133
+ google-resumable-media==2.4.1
134
+ googleapis-common-protos==1.58.0
135
+ googledrivedownloader==0.4
136
+ gradio==3.20.0
137
+ graphviz==0.10.1
138
+ greenlet==2.0.2
139
+ grpcio==1.51.3
140
+ grpcio-status==1.48.2
141
+ gspread==3.4.2
142
+ gspread-dataframe==3.0.8
143
+ gym==0.25.2
144
+ gym-notices==0.0.8
145
+ h11==0.14.0
146
+ h5py==3.1.0
147
+ HeapDict==1.0.1
148
+ hijri-converter==2.2.4
149
+ holidays==0.20
150
+ holoviews==1.14.9
151
+ html5lib==1.0.1
152
+ httpcore==0.16.3
153
+ httpimport==0.5.18
154
+ httplib2==0.17.4
155
+ httpstan==4.6.1
156
+ httpx==0.23.3
157
+ humanize==0.5.1
158
+ hyperopt==0.1.2
159
+ idna==2.10
160
+ imageio==2.9.0
161
+ imagesize==1.4.1
162
+ imbalanced-learn==0.8.1
163
+ imblearn==0.0
164
+ imgaug==0.4.0
165
+ importlib-metadata==6.0.0
166
+ importlib-resources==5.12.0
167
+ imutils==0.5.4
168
+ inflect==2.1.0
169
+ intel-openmp==2023.0.0
170
+ ipykernel==5.3.4
171
+ ipython==7.9.0
172
+ ipython-genutils==0.2.0
173
+ ipython-sql==0.3.9
174
+ ipywidgets==7.7.1
175
+ itsdangerous==2.1.2
176
+ jax==0.4.4
177
+ jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.4+cuda11.cudnn82-cp38-cp38-manylinux2014_x86_64.whl
178
+ jieba==0.42.1
179
+ Jinja2==3.1.2
180
+ joblib==1.2.0
181
+ jsonschema==4.3.3
182
+ jupyter-client==6.1.12
183
+ jupyter-console==6.1.0
184
+ jupyter_core==5.2.0
185
+ jupyterlab-pygments==0.2.2
186
+ jupyterlab-widgets==3.0.5
187
+ jupytext==1.14.5
188
+ kaggle==1.5.12
189
+ keras==2.11.0
190
+ keras-vis==0.4.1
191
+ kiwisolver==1.4.4
192
+ korean-lunar-calendar==0.3.1
193
+ langcodes==3.3.0
194
+ libclang==15.0.6.1
195
+ librosa==0.8.1
196
+ lightgbm==2.2.3
197
+ linkify-it-py==2.0.0
198
+ llvmlite==0.39.1
199
+ lmdb==0.99
200
+ locket==1.0.0
201
+ logical-unification==0.4.5
202
+ LunarCalendar==0.0.9
203
+ lxml==4.9.2
204
+ Markdown==3.4.1
205
+ markdown-it-py==2.2.0
206
+ MarkupSafe==2.1.2
207
+ marshmallow==3.19.0
208
+ matplotlib==3.5.3
209
+ matplotlib-venn==0.11.9
210
+ mdit-py-plugins==0.3.3
211
+ mdurl==0.1.2
212
+ miniKanren==1.0.3
213
+ missingno==0.5.2
214
+ mistune==0.8.4
215
+ mizani==0.8.1
216
+ mkl==2019.0
217
+ mlxtend==0.14.0
218
+ more-itertools==9.1.0
219
+ moviepy==0.2.3.5
220
+ mpmath==1.2.1
221
+ msgpack==1.0.4
222
+ multidict==6.0.4
223
+ multipledispatch==0.6.0
224
+ multitasking==0.0.11
225
+ murmurhash==1.0.9
226
+ music21==5.5.0
227
+ natsort==5.5.0
228
+ nbclient==0.7.2
229
+ nbconvert==6.5.4
230
+ nbformat==5.7.3
231
+ netCDF4==1.6.2
232
+ networkx==3.0
233
+ nibabel==3.0.2
234
+ nltk==3.7
235
+ notebook==6.3.0
236
+ numba==0.56.4
237
+ numexpr==2.8.4
238
+ numpy==1.22.4
239
+ oauth2client==4.1.3
240
+ oauthlib==3.2.2
241
+ opencv-contrib-python==4.6.0.66
242
+ opencv-python==4.6.0.66
243
+ opencv-python-headless==4.7.0.72
244
+ openpyxl==3.0.10
245
+ opt-einsum==3.3.0
246
+ orjson==3.8.7
247
+ osqp==0.6.2.post0
248
+ packaging==23.0
249
+ palettable==3.3.0
250
+ pandas==1.3.5
251
+ pandas-datareader==0.9.0
252
+ pandas-gbq==0.17.9
253
+ pandas-profiling==1.4.1
254
+ pandocfilters==1.5.0
255
+ panel==0.14.3
256
+ param==1.12.3
257
+ parso==0.8.3
258
+ partd==1.3.0
259
+ pastel==0.2.1
260
+ pathlib==1.0.1
261
+ pathtools==0.1.2
262
+ pathy==0.10.1
263
+ patsy==0.5.3
264
+ pep517==0.13.0
265
+ pexpect==4.8.0
266
+ pickleshare==0.7.5
267
+ Pillow==8.4.0
268
+ pip-tools==6.6.2
269
+ platformdirs==3.0.0
270
+ plotly==5.5.0
271
+ plotnine==0.10.1
272
+ pluggy==0.7.1
273
+ pooch==1.7.0
274
+ portpicker==1.3.9
275
+ prefetch-generator==1.0.3
276
+ preshed==3.0.8
277
+ prettytable==3.6.0
278
+ progressbar2==3.38.0
279
+ prometheus-client==0.16.0
280
+ promise==2.3
281
+ prompt-toolkit==2.0.10
282
+ prophet==1.1.2
283
+ proto-plus==1.22.2
284
+ protobuf==3.19.6
285
+ psutil==5.4.8
286
+ psycopg2==2.9.5
287
+ ptyprocess==0.7.0
288
+ py==1.11.0
289
+ pyarrow==9.0.0
290
+ pyasn1==0.4.8
291
+ pyasn1-modules==0.2.8
292
+ pycocotools==2.0.6
293
+ pycparser==2.21
294
+ pycryptodome==3.17
295
+ pyct==0.5.0
296
+ pydantic==1.10.5
297
+ pydata-google-auth==1.7.0
298
+ pydot==1.3.0
299
+ pydot-ng==2.0.0
300
+ pydotplus==2.0.2
301
+ PyDrive==1.3.1
302
+ pydub==0.25.1
303
+ pyerfa==2.0.0.1
304
+ pygame==2.2.0
305
+ Pygments==2.6.1
306
+ PyGObject==3.36.0
307
+ pylev==1.4.0
308
+ pymc==4.1.4
309
+ PyMeeus==0.5.12
310
+ pymongo==4.3.3
311
+ pymystem3==0.2.0
312
+ PyOpenGL==3.1.6
313
+ pyparsing==3.0.9
314
+ pyrsistent==0.19.3
315
+ pysimdjson==3.2.0
316
+ PySocks==1.7.1
317
+ pystan==3.3.0
318
+ pytest==3.6.4
319
+ python-apt==2.0.1
320
+ python-dateutil==2.8.2
321
+ python-louvain==0.16
322
+ python-multipart==0.0.6
323
+ python-slugify==8.0.1
324
+ python-utils==3.5.2
325
+ pytz==2022.7.1
326
+ pyviz-comms==2.2.1
327
+ PyWavelets==1.4.1
328
+ PyYAML==6.0
329
+ pyzmq==23.2.1
330
+ qdldl==0.1.5.post3
331
+ qudida==0.0.4
332
+ regex==2022.6.2
333
+ requests==2.25.1
334
+ requests-oauthlib==1.3.1
335
+ requests-unixsocket==0.2.0
336
+ resampy==0.4.2
337
+ rfc3986==1.5.0
338
+ rpy2==3.5.5
339
+ rsa==4.9
340
+ scikit-image==0.19.3
341
+ scikit-learn==1.2.1
342
+ scipy==1.10.1
343
+ screen-resolution-extra==0.0.0
344
+ scs==3.2.2
345
+ seaborn==0.11.2
346
+ Send2Trash==1.8.0
347
+ sentry-sdk==1.16.0
348
+ setproctitle==1.3.2
349
+ shapely==2.0.1
350
+ six==1.15.0
351
+ sklearn-pandas==2.2.0
352
+ smart-open==6.3.0
353
+ smmap==5.0.0
354
+ sniffio==1.3.0
355
+ snowballstemmer==2.2.0
356
+ sortedcontainers==2.4.0
357
+ soundfile==0.12.1
358
+ spacy==3.4.4
359
+ spacy-legacy==3.0.12
360
+ spacy-loggers==1.0.4
361
+ Sphinx==3.5.4
362
+ sphinxcontrib-applehelp==1.0.4
363
+ sphinxcontrib-devhelp==1.0.2
364
+ sphinxcontrib-htmlhelp==2.0.1
365
+ sphinxcontrib-jsmath==1.0.1
366
+ sphinxcontrib-qthelp==1.0.3
367
+ sphinxcontrib-serializinghtml==1.1.5
368
+ SQLAlchemy==1.4.46
369
+ sqlparse==0.4.3
370
+ srsly==2.4.6
371
+ starlette==0.25.0
372
+ statsmodels==0.13.5
373
+ sympy==1.7.1
374
+ tables==3.7.0
375
+ tabulate==0.8.10
376
+ tblib==1.7.0
377
+ tenacity==8.2.2
378
+ tensorboard==2.11.2
379
+ tensorboard-data-server==0.6.1
380
+ tensorboard-plugin-wit==1.8.1
381
+ tensorflow==2.11.0
382
+ tensorflow-datasets==4.8.3
383
+ tensorflow-estimator==2.11.0
384
+ tensorflow-gcs-config==2.11.0
385
+ tensorflow-hub==0.12.0
386
+ tensorflow-io-gcs-filesystem==0.31.0
387
+ tensorflow-metadata==1.12.0
388
+ tensorflow-probability==0.19.0
389
+ termcolor==2.2.0
390
+ terminado==0.13.3
391
+ text-unidecode==1.3
392
+ textblob==0.15.3
393
+ thinc==8.1.7
394
+ threadpoolctl==3.1.0
395
+ tifffile==2023.2.27
396
+ tinycss2==1.2.1
397
+ toml==0.10.2
398
+ tomli==2.0.1
399
+ toolz==0.12.0
400
+ torch @ https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
401
+ torchaudio @ https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
402
+ torchsummary==1.5.1
403
+ torchtext==0.14.1
404
+ torchtyping==0.1.4
405
+ torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp38-cp38-linux_x86_64.whl
406
+ tornado==6.2
407
+ tqdm==4.64.1
408
+ traitlets==5.7.1
409
+ tweepy==3.10.0
410
+ typeguard==2.13.3
411
+ typer==0.7.0
412
+ typing_extensions==4.5.0
413
+ tzlocal==1.5.1
414
+ uc-micro-py==1.0.1
415
+ uritemplate==4.1.1
416
+ urllib3==1.26.14
417
+ uvicorn==0.20.0
418
+ vega-datasets==0.9.0
419
+ wandb==0.13.10
420
+ wasabi==0.10.1
421
+ wcwidth==0.2.6
422
+ webargs==8.2.0
423
+ webencodings==0.5.1
424
+ websockets==10.4
425
+ Werkzeug==2.2.3
426
+ widgetsnbextension==3.6.2
427
+ wordcloud==1.8.2.2
428
+ wrapt==1.15.0
429
+ xarray==2022.12.0
430
+ xarray-einstats==0.5.1
431
+ xgboost==1.7.4
432
+ xkit==0.0.0
433
+ xlrd==1.2.0
434
+ xlwt==1.3.0
435
+ yarl==1.8.2
436
+ yellowbrick==1.5
437
+ zict==2.2.0
438
+ zipp==3.15.0