sgoodfriend commited on
Commit
c6e31cd
·
1 Parent(s): 43e364e

PPO playing Walker2DBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +17 -14
  2. pyproject.toml +23 -2
  3. replay.meta.json +1 -1
  4. replay.mp4 +2 -2
  5. rl_algo_impls/a2c/a2c.py +13 -19
  6. rl_algo_impls/a2c/optimize.py +1 -1
  7. rl_algo_impls/benchmark_publish.py +2 -2
  8. rl_algo_impls/compare_runs.py +2 -1
  9. rl_algo_impls/dqn/policy.py +14 -7
  10. rl_algo_impls/dqn/q_net.py +6 -6
  11. rl_algo_impls/huggingface_publish.py +1 -1
  12. rl_algo_impls/hyperparams/a2c.yml +17 -13
  13. rl_algo_impls/hyperparams/dqn.yml +1 -1
  14. rl_algo_impls/hyperparams/ppo.yml +125 -5
  15. rl_algo_impls/hyperparams/vpg.yml +4 -4
  16. rl_algo_impls/optimize.py +5 -4
  17. rl_algo_impls/ppo/ppo.py +248 -227
  18. rl_algo_impls/runner/config.py +9 -3
  19. rl_algo_impls/runner/evaluate.py +2 -2
  20. rl_algo_impls/runner/running_utils.py +33 -18
  21. rl_algo_impls/runner/train.py +11 -10
  22. rl_algo_impls/shared/actor/__init__.py +2 -0
  23. rl_algo_impls/shared/actor/actor.py +42 -0
  24. rl_algo_impls/shared/actor/categorical.py +64 -0
  25. rl_algo_impls/shared/actor/gaussian.py +61 -0
  26. rl_algo_impls/shared/actor/gridnet.py +108 -0
  27. rl_algo_impls/shared/actor/gridnet_decoder.py +80 -0
  28. rl_algo_impls/shared/actor/make_actor.py +95 -0
  29. rl_algo_impls/shared/actor/multi_discrete.py +101 -0
  30. rl_algo_impls/shared/{policy/actor.py → actor/state_dependent_noise.py} +33 -143
  31. rl_algo_impls/shared/callbacks/eval_callback.py +26 -9
  32. rl_algo_impls/shared/encoder/__init__.py +2 -0
  33. rl_algo_impls/shared/encoder/cnn.py +72 -0
  34. rl_algo_impls/shared/encoder/encoder.py +73 -0
  35. rl_algo_impls/shared/encoder/gridnet_encoder.py +64 -0
  36. rl_algo_impls/shared/encoder/impala_cnn.py +92 -0
  37. rl_algo_impls/shared/encoder/microrts_cnn.py +45 -0
  38. rl_algo_impls/shared/encoder/nature_cnn.py +53 -0
  39. rl_algo_impls/shared/gae.py +29 -2
  40. rl_algo_impls/shared/module/feature_extractor.py +0 -215
  41. rl_algo_impls/shared/module/module.py +6 -3
  42. rl_algo_impls/shared/policy/critic.py +22 -10
  43. rl_algo_impls/shared/policy/on_policy.py +57 -34
  44. rl_algo_impls/shared/policy/policy.py +6 -1
  45. rl_algo_impls/shared/schedule.py +29 -1
  46. rl_algo_impls/shared/stats.py +24 -6
  47. rl_algo_impls/shared/vec_env/__init__.py +1 -0
  48. rl_algo_impls/shared/vec_env/make_env.py +66 -0
  49. rl_algo_impls/shared/vec_env/microrts.py +94 -0
  50. rl_algo_impls/shared/vec_env/microrts_compat.py +49 -0
README.md CHANGED
@@ -10,7 +10,7 @@ model-index:
10
  results:
11
  - metrics:
12
  - type: mean_reward
13
- value: 2166.35 +/- 25.49
14
  name: mean_reward
15
  task:
16
  type: reinforcement-learning
@@ -23,17 +23,17 @@ model-index:
23
 
24
  This is a trained model of a **PPO** agent playing **Walker2DBulletEnv-v0** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
25
 
26
- All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/09frjfcs.
27
 
28
  ## Training Results
29
 
30
- This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [2067e21](https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
31
 
32
  | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
33
  |:-------|:---------------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
34
- | ppo | Walker2DBulletEnv-v0 | 1 | 1002.75 | 12.2103 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/t1nah1wv) |
35
- | ppo | Walker2DBulletEnv-v0 | 2 | 2270.15 | 375.098 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/k5liyiuj) |
36
- | ppo | Walker2DBulletEnv-v0 | 3 | 2166.35 | 25.4924 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/t0x7t3no) |
37
 
38
 
39
  ### Prerequisites: Weights & Biases (WandB)
@@ -53,10 +53,10 @@ login`.
53
  Note: While the model state dictionary and hyperaparameters are saved, the latest
54
  implementation could be sufficiently different to not be able to reproduce similar
55
  results. You might need to checkout the commit the agent was trained on:
56
- [2067e21](https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0).
57
  ```
58
  # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
59
- python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/t0x7t3no
60
  ```
61
 
62
  Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -68,11 +68,11 @@ notebook.
68
 
69
  ## Training
70
  If you want the highest chance to reproduce these results, you'll want to checkout the
71
- commit the agent was trained on: [2067e21](https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0). While
72
  training is deterministic, different hardware will give different results.
73
 
74
  ```
75
- python train.py --algo ppo --env Walker2DBulletEnv-v0 --seed 3
76
  ```
77
 
78
  Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -83,7 +83,7 @@ notebook.
83
 
84
 
85
  ## Benchmarking (with Lambda Labs instance)
86
- This and other models from https://api.wandb.ai/links/sgoodfriend/09frjfcs were generated by running a script on a Lambda
87
  Labs instance. In a Lambda Labs instance terminal:
88
  ```
89
  git clone [email protected]:sgoodfriend/rl-algo-impls.git
@@ -105,6 +105,7 @@ can be used. However, this requires a Google Colab Pro+ subscription and running
105
  This isn't exactly the format of hyperparams in hyperparams/ppo.yml, but instead the Wandb Run Config. However, it's very
106
  close and has some additional data:
107
  ```
 
108
  algo: ppo
109
  algo_hyperparams:
110
  batch_size: 128
@@ -134,13 +135,15 @@ policy_hyperparams:
134
  v_hidden_sizes:
135
  - 256
136
  - 256
137
- seed: 3
138
  use_deterministic_algorithms: true
139
  wandb_entity: null
140
  wandb_group: null
141
  wandb_project_name: rl-algo-impls-benchmarks
142
  wandb_tags:
143
- - benchmark_2067e21
144
- - host_155-248-199-228
 
 
145
 
146
  ```
 
10
  results:
11
  - metrics:
12
  - type: mean_reward
13
+ value: 1943.44 +/- 6.03
14
  name: mean_reward
15
  task:
16
  type: reinforcement-learning
 
23
 
24
  This is a trained model of a **PPO** agent playing **Walker2DBulletEnv-v0** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
25
 
26
+ All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/7lx79bf0.
27
 
28
  ## Training Results
29
 
30
+ This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
31
 
32
  | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
33
  |:-------|:---------------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
34
+ | ppo | Walker2DBulletEnv-v0 | 1 | 1943.44 | 6.02595 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/rkpqqxbp) |
35
+ | ppo | Walker2DBulletEnv-v0 | 2 | 1821.93 | 13.1212 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/2dxhttk3) |
36
+ | ppo | Walker2DBulletEnv-v0 | 3 | 2109.58 | 509.27 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/ormofluw) |
37
 
38
 
39
  ### Prerequisites: Weights & Biases (WandB)
 
53
  Note: While the model state dictionary and hyperaparameters are saved, the latest
54
  implementation could be sufficiently different to not be able to reproduce similar
55
  results. You might need to checkout the commit the agent was trained on:
56
+ [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c).
57
  ```
58
  # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
59
+ python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/rkpqqxbp
60
  ```
61
 
62
  Setup hasn't been completely worked out yet, so you might be best served by using Google
 
68
 
69
  ## Training
70
  If you want the highest chance to reproduce these results, you'll want to checkout the
71
+ commit the agent was trained on: [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c). While
72
  training is deterministic, different hardware will give different results.
73
 
74
  ```
75
+ python train.py --algo ppo --env Walker2DBulletEnv-v0 --seed 1
76
  ```
77
 
78
  Setup hasn't been completely worked out yet, so you might be best served by using Google
 
83
 
84
 
85
  ## Benchmarking (with Lambda Labs instance)
86
+ This and other models from https://api.wandb.ai/links/sgoodfriend/7lx79bf0 were generated by running a script on a Lambda
87
  Labs instance. In a Lambda Labs instance terminal:
88
  ```
89
  git clone [email protected]:sgoodfriend/rl-algo-impls.git
 
105
  This isn't exactly the format of hyperparams in hyperparams/ppo.yml, but instead the Wandb Run Config. However, it's very
106
  close and has some additional data:
107
  ```
108
+ additional_keys_to_log: []
109
  algo: ppo
110
  algo_hyperparams:
111
  batch_size: 128
 
135
  v_hidden_sizes:
136
  - 256
137
  - 256
138
+ seed: 1
139
  use_deterministic_algorithms: true
140
  wandb_entity: null
141
  wandb_group: null
142
  wandb_project_name: rl-algo-impls-benchmarks
143
  wandb_tags:
144
+ - benchmark_0511de3
145
+ - host_152-67-249-42
146
+ - branch_main
147
+ - v0.0.8
148
 
149
  ```
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
1
  [project]
2
  name = "rl_algo_impls"
3
- version = "0.0.4"
4
  description = "Implementations of reinforcement learning algorithms"
5
  authors = [
6
  {name = "Scott Goodfriend", email = "[email protected]"},
@@ -35,6 +35,7 @@ dependencies = [
35
  "dash",
36
  "kaleido",
37
  "PyYAML",
 
38
  ]
39
 
40
  [tool.setuptools]
@@ -55,10 +56,30 @@ procgen = [
55
  "glfw >= 1.12.0, < 1.13",
56
  "procgen; platform_machine=='x86_64'",
57
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  [project.urls]
60
  "Homepage" = "https://github.com/sgoodfriend/rl-algo-impls"
61
 
62
  [build-system]
63
  requires = ["setuptools==65.5.0", "setuptools-scm"]
64
- build-backend = "setuptools.build_meta"
 
 
 
 
1
  [project]
2
  name = "rl_algo_impls"
3
+ version = "0.0.8"
4
  description = "Implementations of reinforcement learning algorithms"
5
  authors = [
6
  {name = "Scott Goodfriend", email = "[email protected]"},
 
35
  "dash",
36
  "kaleido",
37
  "PyYAML",
38
+ "scikit-learn",
39
  ]
40
 
41
  [tool.setuptools]
 
56
  "glfw >= 1.12.0, < 1.13",
57
  "procgen; platform_machine=='x86_64'",
58
  ]
59
+ microrts-old = [
60
+ "numpy < 1.24.0", # Support for gym-microrts < 0.6.0
61
+ "gym-microrts == 0.2.0", # Match ppo-implementation-details
62
+ ]
63
+ microrts = [
64
+ "numpy < 1.24.0", # Support for gym-microrts < 0.6.0
65
+ "gym-microrts == 0.3.2",
66
+ ]
67
+ jupyter = [
68
+ "jupyter",
69
+ "notebook"
70
+ ]
71
+ all = [
72
+ "rl-algo-impls[test]",
73
+ "rl-algo-impls[procgen]",
74
+ "rl-algo-impls[microrts]",
75
+ ]
76
 
77
  [project.urls]
78
  "Homepage" = "https://github.com/sgoodfriend/rl-algo-impls"
79
 
80
  [build-system]
81
  requires = ["setuptools==65.5.0", "setuptools-scm"]
82
+ build-backend = "setuptools.build_meta"
83
+
84
+ [tool.isort]
85
+ profile = "black"
replay.meta.json CHANGED
@@ -1 +1 @@
1
- {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "320x240", "-pix_fmt", "rgb24", "-framerate", "60", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "60", "/tmp/tmpvlzcdxnv/ppo-Walker2DBulletEnv-v0/replay.mp4"]}, "episode": {"r": 2173.694091796875, "l": 1000, "t": 28.513911}}
 
1
+ {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "320x240", "-pix_fmt", "rgb24", "-framerate", "60", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "60", "/tmp/tmpitcja6vi/ppo-Walker2DBulletEnv-v0/replay.mp4"]}, "episode": {"r": 1936.2642822265625, "l": 1000, "t": 28.126744}}
replay.mp4 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d26e365601a4eb5029fd5c8c7df0ed78f9c9ad33ca421de56d7f0fe3bd3cc019
3
- size 1148791
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a4a32646a8d7dee1c9f24f52e21be8dcd4b1f79a86ba04667c6054d729f162c
3
+ size 1180592
rl_algo_impls/a2c/a2c.py CHANGED
@@ -10,6 +10,7 @@ from typing import Optional, TypeVar
10
 
11
  from rl_algo_impls.shared.algorithm import Algorithm
12
  from rl_algo_impls.shared.callbacks.callback import Callback
 
13
  from rl_algo_impls.shared.policy.on_policy import ActorCritic
14
  from rl_algo_impls.shared.schedule import schedule, update_learning_rate
15
  from rl_algo_impls.shared.stats import log_scalars
@@ -84,12 +85,12 @@ class A2C(Algorithm):
84
  obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
85
  actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
86
  rewards = np.zeros(epoch_dim, dtype=np.float32)
87
- episode_starts = np.zeros(epoch_dim, dtype=np.byte)
88
  values = np.zeros(epoch_dim, dtype=np.float32)
89
  logprobs = np.zeros(epoch_dim, dtype=np.float32)
90
 
91
  next_obs = self.env.reset()
92
- next_episode_starts = np.ones(step_dim, dtype=np.byte)
93
 
94
  timesteps_elapsed = start_timesteps
95
  while timesteps_elapsed < start_timesteps + train_timesteps:
@@ -126,23 +127,16 @@ class A2C(Algorithm):
126
  clamped_action
127
  )
128
 
129
- advantages = np.zeros(epoch_dim, dtype=np.float32)
130
- last_gae_lam = 0
131
- for t in reversed(range(self.n_steps)):
132
- if t == self.n_steps - 1:
133
- next_nonterminal = 1.0 - next_episode_starts
134
- next_value = self.policy.value(next_obs)
135
- else:
136
- next_nonterminal = 1.0 - episode_starts[t + 1]
137
- next_value = values[t + 1]
138
- delta = (
139
- rewards[t] + self.gamma * next_value * next_nonterminal - values[t]
140
- )
141
- last_gae_lam = (
142
- delta
143
- + self.gamma * self.gae_lambda * next_nonterminal * last_gae_lam
144
- )
145
- advantages[t] = last_gae_lam
146
  returns = advantages + values
147
 
148
  b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
 
10
 
11
  from rl_algo_impls.shared.algorithm import Algorithm
12
  from rl_algo_impls.shared.callbacks.callback import Callback
13
+ from rl_algo_impls.shared.gae import compute_advantages
14
  from rl_algo_impls.shared.policy.on_policy import ActorCritic
15
  from rl_algo_impls.shared.schedule import schedule, update_learning_rate
16
  from rl_algo_impls.shared.stats import log_scalars
 
85
  obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
86
  actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
87
  rewards = np.zeros(epoch_dim, dtype=np.float32)
88
+ episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
89
  values = np.zeros(epoch_dim, dtype=np.float32)
90
  logprobs = np.zeros(epoch_dim, dtype=np.float32)
91
 
92
  next_obs = self.env.reset()
93
+ next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
94
 
95
  timesteps_elapsed = start_timesteps
96
  while timesteps_elapsed < start_timesteps + train_timesteps:
 
127
  clamped_action
128
  )
129
 
130
+ advantages = compute_advantages(
131
+ rewards,
132
+ values,
133
+ episode_starts,
134
+ next_episode_starts,
135
+ next_obs,
136
+ self.policy,
137
+ self.gamma,
138
+ self.gae_lambda,
139
+ )
 
 
 
 
 
 
 
140
  returns = advantages + values
141
 
142
  b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
rl_algo_impls/a2c/optimize.py CHANGED
@@ -3,7 +3,7 @@ import optuna
3
  from copy import deepcopy
4
 
5
  from rl_algo_impls.runner.config import Config, Hyperparams, EnvHyperparams
6
- from rl_algo_impls.runner.env import make_eval_env
7
  from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
8
  from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
9
 
 
3
  from copy import deepcopy
4
 
5
  from rl_algo_impls.runner.config import Config, Hyperparams, EnvHyperparams
6
+ from rl_algo_impls.shared.vec_env import make_eval_env
7
  from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
8
  from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
9
 
rl_algo_impls/benchmark_publish.py CHANGED
@@ -54,8 +54,8 @@ def benchmark_publish() -> None:
54
  "--virtual-display", action="store_true", help="Use headless virtual display"
55
  )
56
  # parser.set_defaults(
57
- # wandb_tags=["benchmark_2067e21", "host_155-248-199-228"],
58
- # wandb_report_url="https://api.wandb.ai/links/sgoodfriend/09frjfcs",
59
  # envs=[],
60
  # exclude_envs=[],
61
  # )
 
54
  "--virtual-display", action="store_true", help="Use headless virtual display"
55
  )
56
  # parser.set_defaults(
57
+ # wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
58
+ # wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
59
  # envs=[],
60
  # exclude_envs=[],
61
  # )
rl_algo_impls/compare_runs.py CHANGED
@@ -194,5 +194,6 @@ def compare_runs() -> None:
194
  df.loc["mean"] = df.mean(numeric_only=True)
195
  print(df.to_markdown())
196
 
 
197
  if __name__ == "__main__":
198
- compare_runs()
 
194
  df.loc["mean"] = df.mean(numeric_only=True)
195
  print(df.to_markdown())
196
 
197
+
198
  if __name__ == "__main__":
199
+ compare_runs()
rl_algo_impls/dqn/policy.py CHANGED
@@ -1,16 +1,16 @@
1
- import numpy as np
2
  import os
3
- import torch
4
-
5
  from typing import Optional, Sequence, TypeVar
6
 
 
 
 
7
  from rl_algo_impls.dqn.q_net import QNetwork
8
  from rl_algo_impls.shared.policy.policy import Policy
9
  from rl_algo_impls.wrappers.vectorable_wrapper import (
10
  VecEnv,
11
  VecEnvObs,
12
- single_observation_space,
13
  single_action_space,
 
14
  )
15
 
16
  DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
@@ -21,7 +21,7 @@ class DQNPolicy(Policy):
21
  self,
22
  env: VecEnv,
23
  hidden_sizes: Sequence[int] = [],
24
- cnn_feature_dim: int = 512,
25
  cnn_style: str = "nature",
26
  cnn_layers_init_orthogonal: Optional[bool] = None,
27
  impala_channels: Sequence[int] = (16, 32, 32),
@@ -32,16 +32,23 @@ class DQNPolicy(Policy):
32
  single_observation_space(env),
33
  single_action_space(env),
34
  hidden_sizes,
35
- cnn_feature_dim=cnn_feature_dim,
36
  cnn_style=cnn_style,
37
  cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
38
  impala_channels=impala_channels,
39
  )
40
 
41
  def act(
42
- self, obs: VecEnvObs, eps: float = 0, deterministic: bool = True
 
 
 
 
43
  ) -> np.ndarray:
44
  assert eps == 0 if deterministic else eps >= 0
 
 
 
45
  if not deterministic and np.random.random() < eps:
46
  return np.array(
47
  [
 
 
1
  import os
 
 
2
  from typing import Optional, Sequence, TypeVar
3
 
4
+ import numpy as np
5
+ import torch
6
+
7
  from rl_algo_impls.dqn.q_net import QNetwork
8
  from rl_algo_impls.shared.policy.policy import Policy
9
  from rl_algo_impls.wrappers.vectorable_wrapper import (
10
  VecEnv,
11
  VecEnvObs,
 
12
  single_action_space,
13
+ single_observation_space,
14
  )
15
 
16
  DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
 
21
  self,
22
  env: VecEnv,
23
  hidden_sizes: Sequence[int] = [],
24
+ cnn_flatten_dim: int = 512,
25
  cnn_style: str = "nature",
26
  cnn_layers_init_orthogonal: Optional[bool] = None,
27
  impala_channels: Sequence[int] = (16, 32, 32),
 
32
  single_observation_space(env),
33
  single_action_space(env),
34
  hidden_sizes,
35
+ cnn_flatten_dim=cnn_flatten_dim,
36
  cnn_style=cnn_style,
37
  cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
38
  impala_channels=impala_channels,
39
  )
40
 
41
  def act(
42
+ self,
43
+ obs: VecEnvObs,
44
+ eps: float = 0,
45
+ deterministic: bool = True,
46
+ action_masks: Optional[np.ndarray] = None,
47
  ) -> np.ndarray:
48
  assert eps == 0 if deterministic else eps >= 0
49
+ assert (
50
+ action_masks is None
51
+ ), f"action_masks not currently supported in {self.__class__.__name__}"
52
  if not deterministic and np.random.random() < eps:
53
  return np.array(
54
  [
rl_algo_impls/dqn/q_net.py CHANGED
@@ -1,11 +1,11 @@
 
 
1
  import gym
2
  import torch as th
3
  import torch.nn as nn
4
-
5
  from gym.spaces import Discrete
6
- from typing import Optional, Sequence, Type
7
 
8
- from rl_algo_impls.shared.module.feature_extractor import FeatureExtractor
9
  from rl_algo_impls.shared.module.module import mlp
10
 
11
 
@@ -16,17 +16,17 @@ class QNetwork(nn.Module):
16
  action_space: gym.Space,
17
  hidden_sizes: Sequence[int] = [],
18
  activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
19
- cnn_feature_dim: int = 512,
20
  cnn_style: str = "nature",
21
  cnn_layers_init_orthogonal: Optional[bool] = None,
22
  impala_channels: Sequence[int] = (16, 32, 32),
23
  ) -> None:
24
  super().__init__()
25
  assert isinstance(action_space, Discrete)
26
- self._feature_extractor = FeatureExtractor(
27
  observation_space,
28
  activation,
29
- cnn_feature_dim=cnn_feature_dim,
30
  cnn_style=cnn_style,
31
  cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
32
  impala_channels=impala_channels,
 
1
+ from typing import Optional, Sequence, Type
2
+
3
  import gym
4
  import torch as th
5
  import torch.nn as nn
 
6
  from gym.spaces import Discrete
 
7
 
8
+ from rl_algo_impls.shared.encoder import Encoder
9
  from rl_algo_impls.shared.module.module import mlp
10
 
11
 
 
16
  action_space: gym.Space,
17
  hidden_sizes: Sequence[int] = [],
18
  activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
19
+ cnn_flatten_dim: int = 512,
20
  cnn_style: str = "nature",
21
  cnn_layers_init_orthogonal: Optional[bool] = None,
22
  impala_channels: Sequence[int] = (16, 32, 32),
23
  ) -> None:
24
  super().__init__()
25
  assert isinstance(action_space, Discrete)
26
+ self._feature_extractor = Encoder(
27
  observation_space,
28
  activation,
29
+ cnn_flatten_dim=cnn_flatten_dim,
30
  cnn_style=cnn_style,
31
  cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
32
  impala_channels=impala_channels,
rl_algo_impls/huggingface_publish.py CHANGED
@@ -19,7 +19,7 @@ from pyvirtualdisplay.display import Display
19
  from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
20
  from rl_algo_impls.runner.config import EnvHyperparams
21
  from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
22
- from rl_algo_impls.runner.env import make_eval_env
23
  from rl_algo_impls.shared.callbacks.eval_callback import evaluate
24
  from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
25
 
 
19
  from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
20
  from rl_algo_impls.runner.config import EnvHyperparams
21
  from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
22
+ from rl_algo_impls.shared.vec_env import make_eval_env
23
  from rl_algo_impls.shared.callbacks.eval_callback import evaluate
24
  from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
25
 
rl_algo_impls/hyperparams/a2c.yml CHANGED
@@ -97,31 +97,35 @@ Walker2DBulletEnv-v0:
97
  HopperBulletEnv-v0:
98
  <<: *pybullet-defaults
99
 
 
100
  CarRacing-v0:
101
  n_timesteps: !!float 4e6
102
  env_hyperparams:
103
- n_envs: 8
104
  frame_stack: 4
105
  normalize: true
106
  normalize_kwargs:
107
  norm_obs: false
108
  norm_reward: true
109
  policy_hyperparams:
110
- use_sde: true
111
- log_std_init: -2
112
- init_layers_orthogonal: false
113
- activation_fn: relu
114
  share_features_extractor: false
115
- cnn_feature_dim: 256
116
  hidden_sizes: [256]
117
  algo_hyperparams:
118
- n_steps: 512
119
- learning_rate: !!float 1.62e-5
120
- gamma: 0.997
121
- gae_lambda: 0.975
122
- ent_coef: 0
123
- sde_sample_freq: 128
124
- vf_coef: 0.64
 
 
 
125
 
126
  _atari: &atari-defaults
127
  n_timesteps: !!float 1e7
 
97
  HopperBulletEnv-v0:
98
  <<: *pybullet-defaults
99
 
100
+ # Tuned
101
  CarRacing-v0:
102
  n_timesteps: !!float 4e6
103
  env_hyperparams:
104
+ n_envs: 16
105
  frame_stack: 4
106
  normalize: true
107
  normalize_kwargs:
108
  norm_obs: false
109
  norm_reward: true
110
  policy_hyperparams:
111
+ use_sde: false
112
+ log_std_init: -1.3502584927786276
113
+ init_layers_orthogonal: true
114
+ activation_fn: tanh
115
  share_features_extractor: false
116
+ cnn_flatten_dim: 256
117
  hidden_sizes: [256]
118
  algo_hyperparams:
119
+ n_steps: 16
120
+ learning_rate: 0.000025630993245026736
121
+ learning_rate_decay: linear
122
+ gamma: 0.99957617037542
123
+ gae_lambda: 0.949455676599436
124
+ ent_coef: !!float 1.707983205298309e-7
125
+ vf_coef: 0.10428178193833336
126
+ max_grad_norm: 0.5406643389792273
127
+ normalize_advantage: true
128
+ use_rms_prop: false
129
 
130
  _atari: &atari-defaults
131
  n_timesteps: !!float 1e7
rl_algo_impls/hyperparams/dqn.yml CHANGED
@@ -108,7 +108,7 @@ _impala-atari: &impala-atari-defaults
108
  <<: *atari-defaults
109
  policy_hyperparams:
110
  cnn_style: impala
111
- cnn_feature_dim: 256
112
  init_layers_orthogonal: true
113
  cnn_layers_init_orthogonal: false
114
 
 
108
  <<: *atari-defaults
109
  policy_hyperparams:
110
  cnn_style: impala
111
+ cnn_flatten_dim: 256
112
  init_layers_orthogonal: true
113
  cnn_layers_init_orthogonal: false
114
 
rl_algo_impls/hyperparams/ppo.yml CHANGED
@@ -112,7 +112,7 @@ CarRacing-v0: &carracing-defaults
112
  init_layers_orthogonal: false
113
  activation_fn: relu
114
  share_features_extractor: false
115
- cnn_feature_dim: 256
116
  hidden_sizes: [256]
117
  algo_hyperparams:
118
  n_steps: 512
@@ -152,7 +152,7 @@ _atari: &atari-defaults
152
  vec_env_class: async
153
  policy_hyperparams: &atari-policy-defaults
154
  activation_fn: relu
155
- algo_hyperparams:
156
  n_steps: 128
157
  batch_size: 256
158
  n_epochs: 4
@@ -192,7 +192,7 @@ _impala-atari: &impala-atari-defaults
192
  policy_hyperparams:
193
  <<: *atari-policy-defaults
194
  cnn_style: impala
195
- cnn_feature_dim: 256
196
  init_layers_orthogonal: true
197
  cnn_layers_init_orthogonal: false
198
 
@@ -212,6 +212,126 @@ impala-QbertNoFrameskip-v4:
212
  <<: *impala-atari-defaults
213
  env_id: QbertNoFrameskip-v4
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  HalfCheetahBulletEnv-v0: &pybullet-defaults
216
  n_timesteps: !!float 2e6
217
  env_hyperparams: &pybullet-env-defaults
@@ -282,7 +402,7 @@ _procgen: &procgen-defaults
282
  policy_hyperparams: &procgen-policy-defaults
283
  activation_fn: relu
284
  cnn_style: impala
285
- cnn_feature_dim: 256
286
  init_layers_orthogonal: true
287
  cnn_layers_init_orthogonal: false
288
  algo_hyperparams: &procgen-algo-defaults
@@ -368,7 +488,7 @@ procgen-starpilot-hard-2xIMPALA-fat:
368
  policy_hyperparams:
369
  <<: *procgen-policy-defaults
370
  impala_channels: [32, 64, 64]
371
- cnn_feature_dim: 512
372
  algo_hyperparams:
373
  <<: *procgen-hard-algo-defaults
374
  learning_rate: !!float 2.5e-4
 
112
  init_layers_orthogonal: false
113
  activation_fn: relu
114
  share_features_extractor: false
115
+ cnn_flatten_dim: 256
116
  hidden_sizes: [256]
117
  algo_hyperparams:
118
  n_steps: 512
 
152
  vec_env_class: async
153
  policy_hyperparams: &atari-policy-defaults
154
  activation_fn: relu
155
+ algo_hyperparams: &atari-algo-defaults
156
  n_steps: 128
157
  batch_size: 256
158
  n_epochs: 4
 
192
  policy_hyperparams:
193
  <<: *atari-policy-defaults
194
  cnn_style: impala
195
+ cnn_flatten_dim: 256
196
  init_layers_orthogonal: true
197
  cnn_layers_init_orthogonal: false
198
 
 
212
  <<: *impala-atari-defaults
213
  env_id: QbertNoFrameskip-v4
214
 
215
+ _microrts: &microrts-defaults
216
+ <<: *atari-defaults
217
+ n_timesteps: !!float 2e6
218
+ env_hyperparams: &microrts-env-defaults
219
+ n_envs: 8
220
+ vec_env_class: sync
221
+ mask_actions: true
222
+ policy_hyperparams: &microrts-policy-defaults
223
+ <<: *atari-policy-defaults
224
+ cnn_style: microrts
225
+ cnn_flatten_dim: 128
226
+ algo_hyperparams: &microrts-algo-defaults
227
+ <<: *atari-algo-defaults
228
+ clip_range_decay: none
229
+ clip_range_vf: 0.1
230
+ ppo2_vf_coef_halving: true
231
+ eval_params:
232
+ deterministic: false # Good idea because MultiCategorical mode isn't great
233
+
234
+ _no-mask-microrts: &no-mask-microrts-defaults
235
+ <<: *microrts-defaults
236
+ env_hyperparams:
237
+ <<: *microrts-env-defaults
238
+ mask_actions: false
239
+
240
+ MicrortsMining-v1-NoMask:
241
+ <<: *no-mask-microrts-defaults
242
+ env_id: MicrortsMining-v1
243
+
244
+ MicrortsAttackShapedReward-v1-NoMask:
245
+ <<: *no-mask-microrts-defaults
246
+ env_id: MicrortsAttackShapedReward-v1
247
+
248
+ MicrortsRandomEnemyShapedReward3-v1-NoMask:
249
+ <<: *no-mask-microrts-defaults
250
+ env_id: MicrortsRandomEnemyShapedReward3-v1
251
+
252
+ _microrts_ai: &microrts-ai-defaults
253
+ <<: *microrts-defaults
254
+ n_timesteps: !!float 100e6
255
+ additional_keys_to_log: ["microrts_stats"]
256
+ env_hyperparams: &microrts-ai-env-defaults
257
+ n_envs: 24
258
+ env_type: microrts
259
+ make_kwargs:
260
+ num_selfplay_envs: 0
261
+ max_steps: 2000
262
+ render_theme: 2
263
+ map_path: maps/16x16/basesWorkers16x16.xml
264
+ reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0]
265
+ policy_hyperparams: &microrts-ai-policy-defaults
266
+ <<: *microrts-policy-defaults
267
+ cnn_flatten_dim: 256
268
+ actor_head_style: gridnet
269
+ algo_hyperparams: &microrts-ai-algo-defaults
270
+ <<: *microrts-algo-defaults
271
+ learning_rate: !!float 2.5e-4
272
+ learning_rate_decay: linear
273
+ n_steps: 512
274
+ batch_size: 3072
275
+ n_epochs: 4
276
+ ent_coef: 0.01
277
+ vf_coef: 0.5
278
+ max_grad_norm: 0.5
279
+ clip_range: 0.1
280
+ clip_range_vf: 0.1
281
+
282
+ MicrortsAttackPassiveEnemySparseReward-v3:
283
+ <<: *microrts-ai-defaults
284
+ n_timesteps: !!float 2e6
285
+ env_id: MicrortsAttackPassiveEnemySparseReward-v3 # Workaround to keep model name simple
286
+ env_hyperparams:
287
+ <<: *microrts-ai-env-defaults
288
+ bots:
289
+ passiveAI: 24
290
+
291
+ MicrortsDefeatRandomEnemySparseReward-v3: &microrts-random-ai-defaults
292
+ <<: *microrts-ai-defaults
293
+ n_timesteps: !!float 2e6
294
+ env_id: MicrortsDefeatRandomEnemySparseReward-v3 # Workaround to keep model name simple
295
+ env_hyperparams:
296
+ <<: *microrts-ai-env-defaults
297
+ bots:
298
+ randomBiasedAI: 24
299
+
300
+ enc-dec-MicrortsDefeatRandomEnemySparseReward-v3:
301
+ <<: *microrts-random-ai-defaults
302
+ policy_hyperparams:
303
+ <<: *microrts-ai-policy-defaults
304
+ cnn_style: gridnet_encoder
305
+ actor_head_style: gridnet_decoder
306
+ v_hidden_sizes: [128]
307
+
308
+ MicrortsDefeatCoacAIShaped-v3: &microrts-coacai-defaults
309
+ <<: *microrts-ai-defaults
310
+ env_id: MicrortsDefeatCoacAIShaped-v3 # Workaround to keep model name simple
311
+ n_timesteps: !!float 300e6
312
+ env_hyperparams: &microrts-coacai-env-defaults
313
+ <<: *microrts-ai-env-defaults
314
+ bots:
315
+ coacAI: 24
316
+
317
+ MicrortsDefeatCoacAIShaped-v3-diverseBots: &microrts-diverse-defaults
318
+ <<: *microrts-coacai-defaults
319
+ env_hyperparams:
320
+ <<: *microrts-coacai-env-defaults
321
+ bots:
322
+ coacAI: 18
323
+ randomBiasedAI: 2
324
+ lightRushAI: 2
325
+ workerRushAI: 2
326
+
327
+ enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
328
+ <<: *microrts-diverse-defaults
329
+ policy_hyperparams:
330
+ <<: *microrts-ai-policy-defaults
331
+ cnn_style: gridnet_encoder
332
+ actor_head_style: gridnet_decoder
333
+ v_hidden_sizes: [128]
334
+
335
  HalfCheetahBulletEnv-v0: &pybullet-defaults
336
  n_timesteps: !!float 2e6
337
  env_hyperparams: &pybullet-env-defaults
 
402
  policy_hyperparams: &procgen-policy-defaults
403
  activation_fn: relu
404
  cnn_style: impala
405
+ cnn_flatten_dim: 256
406
  init_layers_orthogonal: true
407
  cnn_layers_init_orthogonal: false
408
  algo_hyperparams: &procgen-algo-defaults
 
488
  policy_hyperparams:
489
  <<: *procgen-policy-defaults
490
  impala_channels: [32, 64, 64]
491
+ cnn_flatten_dim: 512
492
  algo_hyperparams:
493
  <<: *procgen-hard-algo-defaults
494
  learning_rate: !!float 2.5e-4
rl_algo_impls/hyperparams/vpg.yml CHANGED
@@ -110,7 +110,7 @@ CarRacing-v0:
110
  log_std_init: -2
111
  init_layers_orthogonal: false
112
  activation_fn: relu
113
- cnn_feature_dim: 256
114
  hidden_sizes: [256]
115
  algo_hyperparams:
116
  n_steps: 1000
@@ -175,9 +175,9 @@ FrozenLake-v1:
175
  save_best: true
176
 
177
  _atari: &atari-defaults
178
- n_timesteps: !!float 25e6
179
  env_hyperparams:
180
- n_envs: 4
181
  frame_stack: 4
182
  no_reward_timeout_steps: 1000
183
  no_reward_fire_steps: 500
@@ -185,7 +185,7 @@ _atari: &atari-defaults
185
  policy_hyperparams:
186
  activation_fn: relu
187
  algo_hyperparams:
188
- n_steps: 2048
189
  pi_lr: !!float 5e-5
190
  gamma: 0.99
191
  gae_lambda: 0.95
 
110
  log_std_init: -2
111
  init_layers_orthogonal: false
112
  activation_fn: relu
113
+ cnn_flatten_dim: 256
114
  hidden_sizes: [256]
115
  algo_hyperparams:
116
  n_steps: 1000
 
175
  save_best: true
176
 
177
  _atari: &atari-defaults
178
+ n_timesteps: !!float 10e6
179
  env_hyperparams:
180
+ n_envs: 2
181
  frame_stack: 4
182
  no_reward_timeout_steps: 1000
183
  no_reward_fire_steps: 500
 
185
  policy_hyperparams:
186
  activation_fn: relu
187
  algo_hyperparams:
188
+ n_steps: 3072
189
  pi_lr: !!float 5e-5
190
  gamma: 0.99
191
  gae_lambda: 0.95
rl_algo_impls/optimize.py CHANGED
@@ -17,7 +17,7 @@ from typing import Callable, List, NamedTuple, Optional, Sequence, Union
17
 
18
  from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
19
  from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
20
- from rl_algo_impls.runner.env import make_env, make_eval_env
21
  from rl_algo_impls.runner.running_utils import (
22
  base_parser,
23
  load_hyperparams,
@@ -194,7 +194,7 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -
194
  env = make_env(
195
  config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
196
  )
197
- device = get_device(config.device, env)
198
  policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
199
  algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
200
 
@@ -274,7 +274,7 @@ def stepwise_optimize(
274
  project=study_args.wandb_project_name,
275
  entity=study_args.wandb_entity,
276
  config=asdict(hyperparams),
277
- name=f"{study_args.study_name}-{str(trial.number)}",
278
  tags=study_args.wandb_tags,
279
  group=study_args.wandb_group,
280
  save_code=True,
@@ -298,7 +298,7 @@ def stepwise_optimize(
298
  normalize_load_path=config.model_dir_path() if i > 0 else None,
299
  tb_writer=tb_writer,
300
  )
301
- device = get_device(config.device, env)
302
  policy = make_policy(arg.algo, env, device, **config.policy_hyperparams)
303
  if i > 0:
304
  policy.load(config.model_dir_path())
@@ -433,6 +433,7 @@ def optimize() -> None:
433
 
434
  fig1 = plot_optimization_history(study)
435
  fig1.write_image("opt_history.png")
 
436
  fig2 = plot_param_importances(study)
437
  fig2.write_image("param_importances.png")
438
 
 
17
 
18
  from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
19
  from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
20
+ from rl_algo_impls.shared.vec_env import make_env, make_eval_env
21
  from rl_algo_impls.runner.running_utils import (
22
  base_parser,
23
  load_hyperparams,
 
194
  env = make_env(
195
  config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
196
  )
197
+ device = get_device(config, env)
198
  policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
199
  algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
200
 
 
274
  project=study_args.wandb_project_name,
275
  entity=study_args.wandb_entity,
276
  config=asdict(hyperparams),
277
+ name=f"{str(trial.number)}-S{base_config.seed()}",
278
  tags=study_args.wandb_tags,
279
  group=study_args.wandb_group,
280
  save_code=True,
 
298
  normalize_load_path=config.model_dir_path() if i > 0 else None,
299
  tb_writer=tb_writer,
300
  )
301
+ device = get_device(config, env)
302
  policy = make_policy(arg.algo, env, device, **config.policy_hyperparams)
303
  if i > 0:
304
  policy.load(config.model_dir_path())
 
433
 
434
  fig1 = plot_optimization_history(study)
435
  fig1.write_image("opt_history.png")
436
+
437
  fig2 = plot_param_importances(study)
438
  fig2.write_image("param_importances.png")
439
 
rl_algo_impls/ppo/ppo.py CHANGED
@@ -1,59 +1,26 @@
 
 
 
 
 
1
  import numpy as np
2
  import torch
3
  import torch.nn as nn
4
-
5
- from dataclasses import asdict, dataclass, field
6
- from time import perf_counter
7
  from torch.optim import Adam
8
  from torch.utils.tensorboard.writer import SummaryWriter
9
- from typing import List, Optional, NamedTuple, TypeVar
10
 
11
  from rl_algo_impls.shared.algorithm import Algorithm
12
  from rl_algo_impls.shared.callbacks.callback import Callback
13
- from rl_algo_impls.shared.gae import compute_advantage, compute_rtg_and_advantage
14
  from rl_algo_impls.shared.policy.on_policy import ActorCritic
15
- from rl_algo_impls.shared.schedule import (
16
- constant_schedule,
17
- linear_schedule,
18
- update_learning_rate,
 
 
 
19
  )
20
- from rl_algo_impls.shared.trajectory import Trajectory, TrajectoryAccumulator
21
- from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
22
-
23
-
24
- @dataclass
25
- class PPOTrajectory(Trajectory):
26
- logp_a: List[float] = field(default_factory=list)
27
-
28
- def add(
29
- self,
30
- obs: np.ndarray,
31
- act: np.ndarray,
32
- next_obs: np.ndarray,
33
- rew: float,
34
- terminated: bool,
35
- v: float,
36
- logp_a: float,
37
- ):
38
- super().add(obs, act, next_obs, rew, terminated, v)
39
- self.logp_a.append(logp_a)
40
-
41
-
42
- class PPOTrajectoryAccumulator(TrajectoryAccumulator):
43
- def __init__(self, num_envs: int) -> None:
44
- super().__init__(num_envs, PPOTrajectory)
45
-
46
- def step(
47
- self,
48
- obs: VecEnvObs,
49
- action: np.ndarray,
50
- next_obs: VecEnvObs,
51
- reward: np.ndarray,
52
- done: np.ndarray,
53
- val: np.ndarray,
54
- logp_a: np.ndarray,
55
- ) -> None:
56
- super().step(obs, action, next_obs, reward, done, val, logp_a)
57
 
58
 
59
  class TrainStepStats(NamedTuple):
@@ -132,39 +99,31 @@ class PPO(Algorithm):
132
  vf_coef: float = 0.5,
133
  ppo2_vf_coef_halving: bool = False,
134
  max_grad_norm: float = 0.5,
135
- update_rtg_between_epochs: bool = False,
136
  sde_sample_freq: int = -1,
 
 
137
  ) -> None:
138
  super().__init__(policy, env, device, tb_writer)
139
  self.policy = policy
 
140
 
141
  self.gamma = gamma
142
  self.gae_lambda = gae_lambda
143
  self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
144
- self.lr_schedule = (
145
- linear_schedule(learning_rate, 0)
146
- if learning_rate_decay == "linear"
147
- else constant_schedule(learning_rate)
148
- )
149
  self.max_grad_norm = max_grad_norm
150
- self.clip_range_schedule = (
151
- linear_schedule(clip_range, 0)
152
- if clip_range_decay == "linear"
153
- else constant_schedule(clip_range)
154
- )
155
  self.clip_range_vf_schedule = None
156
  if clip_range_vf:
157
- self.clip_range_vf_schedule = (
158
- linear_schedule(clip_range_vf, 0)
159
- if clip_range_vf_decay == "linear"
160
- else constant_schedule(clip_range_vf)
161
- )
 
162
  self.normalize_advantage = normalize_advantage
163
- self.ent_coef_schedule = (
164
- linear_schedule(ent_coef, 0)
165
- if ent_coef_decay == "linear"
166
- else constant_schedule(ent_coef)
167
- )
168
  self.vf_coef = vf_coef
169
  self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
170
 
@@ -173,181 +132,243 @@ class PPO(Algorithm):
173
  self.n_epochs = n_epochs
174
  self.sde_sample_freq = sde_sample_freq
175
 
176
- self.update_rtg_between_epochs = update_rtg_between_epochs
 
177
 
178
  def learn(
179
  self: PPOSelf,
180
- total_timesteps: int,
181
  callback: Optional[Callback] = None,
 
 
182
  ) -> PPOSelf:
183
- obs = self.env.reset()
184
- ts_elapsed = 0
185
- while ts_elapsed < total_timesteps:
186
- start_time = perf_counter()
187
- accumulator = self._collect_trajectories(obs)
188
- rollout_steps = self.n_steps * self.env.num_envs
189
- ts_elapsed += rollout_steps
190
- progress = ts_elapsed / total_timesteps
191
- train_stats = self.train(accumulator.all_trajectories, progress, ts_elapsed)
192
- train_stats.write_to_tensorboard(self.tb_writer, ts_elapsed)
193
- end_time = perf_counter()
194
- self.tb_writer.add_scalar(
195
- "train/steps_per_second",
196
- rollout_steps / (end_time - start_time),
197
- ts_elapsed,
 
 
 
 
 
 
 
 
 
 
198
  )
199
- if callback:
200
- callback.on_step(timesteps_elapsed=rollout_steps)
201
-
202
- return self
203
-
204
- def _collect_trajectories(self, obs: VecEnvObs) -> PPOTrajectoryAccumulator:
205
- self.policy.eval()
206
- accumulator = PPOTrajectoryAccumulator(self.env.num_envs)
207
- self.policy.reset_noise()
208
- for i in range(self.n_steps):
209
- if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
210
- self.policy.reset_noise()
211
- action, value, logp_a, clamped_action = self.policy.step(obs)
212
- next_obs, reward, done, _ = self.env.step(clamped_action)
213
- accumulator.step(obs, action, next_obs, reward, done, value, logp_a)
214
- obs = next_obs
215
- return accumulator
216
-
217
- def train(
218
- self, trajectories: List[PPOTrajectory], progress: float, timesteps_elapsed: int
219
- ) -> TrainStats:
220
- self.policy.train()
221
- learning_rate = self.lr_schedule(progress)
222
- update_learning_rate(self.optimizer, learning_rate)
223
- self.tb_writer.add_scalar(
224
- "charts/learning_rate",
225
- self.optimizer.param_groups[0]["lr"],
226
- timesteps_elapsed,
227
  )
228
 
229
- pi_clip = self.clip_range_schedule(progress)
230
- self.tb_writer.add_scalar("charts/pi_clip", pi_clip, timesteps_elapsed)
231
- if self.clip_range_vf_schedule:
232
- v_clip = self.clip_range_vf_schedule(progress)
233
- self.tb_writer.add_scalar("charts/v_clip", v_clip, timesteps_elapsed)
234
- else:
235
- v_clip = None
236
- ent_coef = self.ent_coef_schedule(progress)
237
- self.tb_writer.add_scalar("charts/ent_coef", ent_coef, timesteps_elapsed)
238
-
239
- obs = torch.as_tensor(
240
- np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
241
- )
242
- act = torch.as_tensor(
243
- np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
244
- )
245
- rtg, adv = compute_rtg_and_advantage(
246
- trajectories, self.policy, self.gamma, self.gae_lambda, self.device
247
- )
248
- orig_v = torch.as_tensor(
249
- np.concatenate([np.array(t.v) for t in trajectories]), device=self.device
250
- )
251
- orig_logp_a = torch.as_tensor(
252
- np.concatenate([np.array(t.logp_a) for t in trajectories]),
253
- device=self.device,
254
- )
255
 
256
- step_stats = []
257
- for _ in range(self.n_epochs):
258
- step_stats.clear()
259
- if self.update_rtg_between_epochs:
260
- rtg, adv = compute_rtg_and_advantage(
261
- trajectories, self.policy, self.gamma, self.gae_lambda, self.device
262
- )
 
 
 
 
 
 
263
  else:
264
- adv = compute_advantage(
265
- trajectories, self.policy, self.gamma, self.gae_lambda, self.device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  )
267
- idxs = torch.randperm(len(obs))
268
- for i in range(0, len(obs), self.batch_size):
269
- mb_idxs = idxs[i : i + self.batch_size]
270
- mb_adv = adv[mb_idxs]
271
- if self.normalize_advantage:
272
- mb_adv = (mb_adv - mb_adv.mean(-1)) / (mb_adv.std(-1) + 1e-8)
273
- self.policy.reset_noise(self.batch_size)
274
- step_stats.append(
275
- self._train_step(
276
- pi_clip,
277
- v_clip,
278
- ent_coef,
279
- obs[mb_idxs],
280
- act[mb_idxs],
281
- rtg[mb_idxs],
282
- mb_adv,
283
- orig_v[mb_idxs],
284
- orig_logp_a[mb_idxs],
285
- )
286
  )
287
 
288
- y_pred, y_true = orig_v.cpu().numpy(), rtg.cpu().numpy()
289
- var_y = np.var(y_true).item()
290
- explained_var = (
291
- np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
292
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
- return TrainStats(step_stats, explained_var)
 
 
 
295
 
296
- def _train_step(
297
- self,
298
- pi_clip: float,
299
- v_clip: Optional[float],
300
- ent_coef: float,
301
- obs: torch.Tensor,
302
- act: torch.Tensor,
303
- rtg: torch.Tensor,
304
- adv: torch.Tensor,
305
- orig_v: torch.Tensor,
306
- orig_logp_a: torch.Tensor,
307
- ) -> TrainStepStats:
308
- logp_a, entropy, v = self.policy(obs, act)
309
- logratio = logp_a - orig_logp_a
310
- ratio = torch.exp(logratio)
311
- clip_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
312
- pi_loss = torch.maximum(-ratio * adv, -clip_ratio * adv).mean()
313
-
314
- v_loss_unclipped = (v - rtg) ** 2
315
- if v_clip:
316
- v_loss_clipped = (
317
- orig_v + torch.clamp(v - orig_v, -v_clip, v_clip) - rtg
318
- ) ** 2
319
- v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
320
- else:
321
- v_loss = v_loss_unclipped.mean()
322
- if self.ppo2_vf_coef_halving:
323
- v_loss *= 0.5
324
-
325
- entropy_loss = -entropy.mean()
326
-
327
- loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss
328
-
329
- self.optimizer.zero_grad()
330
- loss.backward()
331
- nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
332
- self.optimizer.step()
333
-
334
- with torch.no_grad():
335
- approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
336
- clipped_frac = (
337
- ((ratio - 1).abs() > pi_clip).float().mean().cpu().numpy().item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  )
339
- val_clipped_frac = (
340
- (((v - orig_v).abs() > v_clip).float().mean().cpu().numpy().item())
341
- if v_clip
342
- else 0
343
  )
344
 
345
- return TrainStepStats(
346
- loss.item(),
347
- pi_loss.item(),
348
- v_loss.item(),
349
- entropy_loss.item(),
350
- approx_kl,
351
- clipped_frac,
352
- val_clipped_frac,
353
- )
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import asdict, dataclass
3
+ from time import perf_counter
4
+ from typing import List, NamedTuple, Optional, TypeVar
5
+
6
  import numpy as np
7
  import torch
8
  import torch.nn as nn
 
 
 
9
  from torch.optim import Adam
10
  from torch.utils.tensorboard.writer import SummaryWriter
 
11
 
12
  from rl_algo_impls.shared.algorithm import Algorithm
13
  from rl_algo_impls.shared.callbacks.callback import Callback
14
+ from rl_algo_impls.shared.gae import compute_advantages
15
  from rl_algo_impls.shared.policy.on_policy import ActorCritic
16
+ from rl_algo_impls.shared.schedule import schedule, update_learning_rate
17
+ from rl_algo_impls.shared.stats import log_scalars
18
+ from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
19
+ from rl_algo_impls.wrappers.vectorable_wrapper import (
20
+ VecEnv,
21
+ single_action_space,
22
+ single_observation_space,
23
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  class TrainStepStats(NamedTuple):
 
99
  vf_coef: float = 0.5,
100
  ppo2_vf_coef_halving: bool = False,
101
  max_grad_norm: float = 0.5,
 
102
  sde_sample_freq: int = -1,
103
+ update_advantage_between_epochs: bool = True,
104
+ update_returns_between_epochs: bool = False,
105
  ) -> None:
106
  super().__init__(policy, env, device, tb_writer)
107
  self.policy = policy
108
+ self.action_masker = find_action_masker(env)
109
 
110
  self.gamma = gamma
111
  self.gae_lambda = gae_lambda
112
  self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
113
+ self.lr_schedule = schedule(learning_rate_decay, learning_rate)
 
 
 
 
114
  self.max_grad_norm = max_grad_norm
115
+ self.clip_range_schedule = schedule(clip_range_decay, clip_range)
 
 
 
 
116
  self.clip_range_vf_schedule = None
117
  if clip_range_vf:
118
+ self.clip_range_vf_schedule = schedule(clip_range_vf_decay, clip_range_vf)
119
+
120
+ if normalize_advantage:
121
+ assert (
122
+ env.num_envs * n_steps > 1 and batch_size > 1
123
+ ), f"Each minibatch must be larger than 1 to support normalization"
124
  self.normalize_advantage = normalize_advantage
125
+
126
+ self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
 
 
 
127
  self.vf_coef = vf_coef
128
  self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
129
 
 
132
  self.n_epochs = n_epochs
133
  self.sde_sample_freq = sde_sample_freq
134
 
135
+ self.update_advantage_between_epochs = update_advantage_between_epochs
136
+ self.update_returns_between_epochs = update_returns_between_epochs
137
 
138
  def learn(
139
  self: PPOSelf,
140
+ train_timesteps: int,
141
  callback: Optional[Callback] = None,
142
+ total_timesteps: Optional[int] = None,
143
+ start_timesteps: int = 0,
144
  ) -> PPOSelf:
145
+ if total_timesteps is None:
146
+ total_timesteps = train_timesteps
147
+ assert start_timesteps + train_timesteps <= total_timesteps
148
+
149
+ epoch_dim = (self.n_steps, self.env.num_envs)
150
+ step_dim = (self.env.num_envs,)
151
+ obs_space = single_observation_space(self.env)
152
+ act_space = single_action_space(self.env)
153
+ act_shape = self.policy.action_shape
154
+
155
+ next_obs = self.env.reset()
156
+ next_action_masks = (
157
+ self.action_masker.action_masks() if self.action_masker else None
158
+ )
159
+ next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
160
+
161
+ obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype) # type: ignore
162
+ actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype) # type: ignore
163
+ rewards = np.zeros(epoch_dim, dtype=np.float32)
164
+ episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
165
+ values = np.zeros(epoch_dim, dtype=np.float32)
166
+ logprobs = np.zeros(epoch_dim, dtype=np.float32)
167
+ action_masks = (
168
+ np.zeros(
169
+ (self.n_steps,) + next_action_masks.shape, dtype=next_action_masks.dtype
170
  )
171
+ if next_action_masks is not None
172
+ else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  )
174
 
175
+ timesteps_elapsed = start_timesteps
176
+ while timesteps_elapsed < start_timesteps + train_timesteps:
177
+ start_time = perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
+ progress = timesteps_elapsed / total_timesteps
180
+ ent_coef = self.ent_coef_schedule(progress)
181
+ learning_rate = self.lr_schedule(progress)
182
+ update_learning_rate(self.optimizer, learning_rate)
183
+ pi_clip = self.clip_range_schedule(progress)
184
+ chart_scalars = {
185
+ "learning_rate": self.optimizer.param_groups[0]["lr"],
186
+ "ent_coef": ent_coef,
187
+ "pi_clip": pi_clip,
188
+ }
189
+ if self.clip_range_vf_schedule:
190
+ v_clip = self.clip_range_vf_schedule(progress)
191
+ chart_scalars["v_clip"] = v_clip
192
  else:
193
+ v_clip = None
194
+ log_scalars(self.tb_writer, "charts", chart_scalars, timesteps_elapsed)
195
+
196
+ self.policy.eval()
197
+ self.policy.reset_noise()
198
+ for s in range(self.n_steps):
199
+ timesteps_elapsed += self.env.num_envs
200
+ if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
201
+ self.policy.reset_noise()
202
+
203
+ obs[s] = next_obs
204
+ episode_starts[s] = next_episode_starts
205
+ if action_masks is not None:
206
+ action_masks[s] = next_action_masks
207
+
208
+ (
209
+ actions[s],
210
+ values[s],
211
+ logprobs[s],
212
+ clamped_action,
213
+ ) = self.policy.step(next_obs, action_masks=next_action_masks)
214
+ next_obs, rewards[s], next_episode_starts, _ = self.env.step(
215
+ clamped_action
216
  )
217
+ next_action_masks = (
218
+ self.action_masker.action_masks() if self.action_masker else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  )
220
 
221
+ self.policy.train()
222
+
223
+ b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device) # type: ignore
224
+ b_actions = torch.tensor(actions.reshape((-1,) + act_shape)).to( # type: ignore
225
+ self.device
226
+ )
227
+ b_logprobs = torch.tensor(logprobs.reshape(-1)).to(self.device)
228
+ b_action_masks = (
229
+ torch.tensor(action_masks.reshape((-1,) + next_action_masks.shape[1:])).to( # type: ignore
230
+ self.device
231
+ )
232
+ if action_masks is not None
233
+ else None
234
+ )
235
+
236
+ y_pred = values.reshape(-1)
237
+ b_values = torch.tensor(y_pred).to(self.device)
238
+
239
+ step_stats = []
240
+ # Define variables that will definitely be set through the first epoch
241
+ advantages: np.ndarray = None # type: ignore
242
+ b_advantages: torch.Tensor = None # type: ignore
243
+ y_true: np.ndarray = None # type: ignore
244
+ b_returns: torch.Tensor = None # type: ignore
245
+ for e in range(self.n_epochs):
246
+ if e == 0 or self.update_advantage_between_epochs:
247
+ advantages = compute_advantages(
248
+ rewards,
249
+ values,
250
+ episode_starts,
251
+ next_episode_starts,
252
+ next_obs,
253
+ self.policy,
254
+ self.gamma,
255
+ self.gae_lambda,
256
+ )
257
+ b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
258
+ if e == 0 or self.update_returns_between_epochs:
259
+ returns = advantages + values
260
+ y_true = returns.reshape(-1)
261
+ b_returns = torch.tensor(y_true).to(self.device)
262
+
263
+ b_idxs = torch.randperm(len(b_obs))
264
+ # Only record last epoch's stats
265
+ step_stats.clear()
266
+ for i in range(0, len(b_obs), self.batch_size):
267
+ self.policy.reset_noise(self.batch_size)
268
+
269
+ mb_idxs = b_idxs[i : i + self.batch_size]
270
+
271
+ mb_obs = b_obs[mb_idxs]
272
+ mb_actions = b_actions[mb_idxs]
273
+ mb_values = b_values[mb_idxs]
274
+ mb_logprobs = b_logprobs[mb_idxs]
275
+ mb_action_masks = (
276
+ b_action_masks[mb_idxs] if b_action_masks is not None else None
277
+ )
278
 
279
+ mb_adv = b_advantages[mb_idxs]
280
+ if self.normalize_advantage:
281
+ mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)
282
+ mb_returns = b_returns[mb_idxs]
283
 
284
+ new_logprobs, entropy, new_values = self.policy(
285
+ mb_obs, mb_actions, action_masks=mb_action_masks
286
+ )
287
+
288
+ logratio = new_logprobs - mb_logprobs
289
+ ratio = torch.exp(logratio)
290
+ clipped_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
291
+ pi_loss = torch.max(-ratio * mb_adv, -clipped_ratio * mb_adv).mean()
292
+
293
+ v_loss_unclipped = (new_values - mb_returns) ** 2
294
+ if v_clip:
295
+ v_loss_clipped = (
296
+ mb_values
297
+ + torch.clamp(new_values - mb_values, -v_clip, v_clip)
298
+ - mb_returns
299
+ ) ** 2
300
+ v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
301
+ else:
302
+ v_loss = v_loss_unclipped.mean()
303
+
304
+ if self.ppo2_vf_coef_halving:
305
+ v_loss *= 0.5
306
+
307
+ entropy_loss = -entropy.mean()
308
+
309
+ loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss
310
+
311
+ self.optimizer.zero_grad()
312
+ loss.backward()
313
+ nn.utils.clip_grad_norm_(
314
+ self.policy.parameters(), self.max_grad_norm
315
+ )
316
+ self.optimizer.step()
317
+
318
+ with torch.no_grad():
319
+ approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
320
+ clipped_frac = (
321
+ ((ratio - 1).abs() > pi_clip)
322
+ .float()
323
+ .mean()
324
+ .cpu()
325
+ .numpy()
326
+ .item()
327
+ )
328
+ val_clipped_frac = (
329
+ ((new_values - mb_values).abs() > v_clip)
330
+ .float()
331
+ .mean()
332
+ .cpu()
333
+ .numpy()
334
+ .item()
335
+ if v_clip
336
+ else 0
337
+ )
338
+
339
+ step_stats.append(
340
+ TrainStepStats(
341
+ loss.item(),
342
+ pi_loss.item(),
343
+ v_loss.item(),
344
+ entropy_loss.item(),
345
+ approx_kl,
346
+ clipped_frac,
347
+ val_clipped_frac,
348
+ )
349
+ )
350
+
351
+ var_y = np.var(y_true).item()
352
+ explained_var = (
353
+ np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
354
  )
355
+ TrainStats(step_stats, explained_var).write_to_tensorboard(
356
+ self.tb_writer, timesteps_elapsed
 
 
357
  )
358
 
359
+ end_time = perf_counter()
360
+ rollout_steps = self.n_steps * self.env.num_envs
361
+ self.tb_writer.add_scalar(
362
+ "train/steps_per_second",
363
+ rollout_steps / (end_time - start_time),
364
+ timesteps_elapsed,
365
+ )
366
+
367
+ if callback:
368
+ if not callback.on_step(timesteps_elapsed=rollout_steps):
369
+ logging.info(
370
+ f"Callback terminated training at {timesteps_elapsed} timesteps"
371
+ )
372
+ break
373
+
374
+ return self
rl_algo_impls/runner/config.py CHANGED
@@ -2,12 +2,10 @@ import dataclasses
2
  import inspect
3
  import itertools
4
  import os
5
-
6
- from datetime import datetime
7
  from dataclasses import dataclass
 
8
  from typing import Any, Dict, List, Optional, Type, TypeVar, Union
9
 
10
-
11
  RunArgsSelf = TypeVar("RunArgsSelf", bound="RunArgs")
12
 
13
 
@@ -50,6 +48,9 @@ class EnvHyperparams:
50
  video_step_interval: Union[int, float] = 1_000_000
51
  initial_steps_to_truncate: Optional[int] = None
52
  clip_atari_rewards: bool = True
 
 
 
53
 
54
 
55
  HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
@@ -64,6 +65,7 @@ class Hyperparams:
64
  algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
65
  eval_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
66
  env_id: Optional[str] = None
 
67
 
68
  @classmethod
69
  def from_dict_with_extra_fields(
@@ -119,6 +121,10 @@ class Config:
119
  def env_id(self) -> str:
120
  return self.hyperparams.env_id or self.args.env
121
 
 
 
 
 
122
  def model_name(self, include_seed: bool = True) -> str:
123
  # Use arg env name instead of environment name
124
  parts = [self.algo, self.args.env]
 
2
  import inspect
3
  import itertools
4
  import os
 
 
5
  from dataclasses import dataclass
6
+ from datetime import datetime
7
  from typing import Any, Dict, List, Optional, Type, TypeVar, Union
8
 
 
9
  RunArgsSelf = TypeVar("RunArgsSelf", bound="RunArgs")
10
 
11
 
 
48
  video_step_interval: Union[int, float] = 1_000_000
49
  initial_steps_to_truncate: Optional[int] = None
50
  clip_atari_rewards: bool = True
51
+ normalize_type: Optional[str] = None
52
+ mask_actions: bool = False
53
+ bots: Optional[Dict[str, int]] = None
54
 
55
 
56
  HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
 
65
  algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
66
  eval_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
67
  env_id: Optional[str] = None
68
+ additional_keys_to_log: List[str] = dataclasses.field(default_factory=list)
69
 
70
  @classmethod
71
  def from_dict_with_extra_fields(
 
121
  def env_id(self) -> str:
122
  return self.hyperparams.env_id or self.args.env
123
 
124
+ @property
125
+ def additional_keys_to_log(self) -> List[str]:
126
+ return self.hyperparams.additional_keys_to_log
127
+
128
  def model_name(self, include_seed: bool = True) -> str:
129
  # Use arg env name instead of environment name
130
  parts = [self.algo, self.args.env]
rl_algo_impls/runner/evaluate.py CHANGED
@@ -4,7 +4,7 @@ import shutil
4
  from dataclasses import dataclass
5
  from typing import NamedTuple, Optional
6
 
7
- from rl_algo_impls.runner.env import make_eval_env
8
  from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
9
  from rl_algo_impls.runner.running_utils import (
10
  load_hyperparams,
@@ -75,7 +75,7 @@ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
75
  render=args.render,
76
  normalize_load_path=model_path,
77
  )
78
- device = get_device(config.device, env)
79
  policy = make_policy(
80
  args.algo,
81
  env,
 
4
  from dataclasses import dataclass
5
  from typing import NamedTuple, Optional
6
 
7
+ from rl_algo_impls.shared.vec_env import make_eval_env
8
  from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
9
  from rl_algo_impls.runner.running_utils import (
10
  load_hyperparams,
 
75
  render=args.render,
76
  normalize_load_path=model_path,
77
  )
78
+ device = get_device(config, env)
79
  policy = make_policy(
80
  args.algo,
81
  env,
rl_algo_impls/runner/running_utils.py CHANGED
@@ -1,32 +1,32 @@
1
  import argparse
2
- import gym
3
  import json
4
- import matplotlib.pyplot as plt
5
- import numpy as np
6
  import os
7
  import random
 
 
 
 
 
 
 
8
  import torch
9
  import torch.backends.cudnn
10
  import yaml
11
-
12
- from dataclasses import asdict
13
  from gym.spaces import Box, Discrete
14
- from pathlib import Path
15
  from torch.utils.tensorboard.writer import SummaryWriter
16
- from typing import Dict, Optional, Type, Union
17
-
18
- from rl_algo_impls.runner.config import Hyperparams
19
- from rl_algo_impls.shared.algorithm import Algorithm
20
- from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
21
- from rl_algo_impls.shared.policy.on_policy import ActorCritic
22
- from rl_algo_impls.shared.policy.policy import Policy
23
 
24
  from rl_algo_impls.a2c.a2c import A2C
25
  from rl_algo_impls.dqn.dqn import DQN
26
  from rl_algo_impls.dqn.policy import DQNPolicy
27
  from rl_algo_impls.ppo.ppo import PPO
28
- from rl_algo_impls.vpg.vpg import VanillaPolicyGradient
 
 
 
 
 
29
  from rl_algo_impls.vpg.policy import VPGActorCritic
 
30
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
31
 
32
  ALGOS: Dict[str, Type[Algorithm]] = {
@@ -81,16 +81,19 @@ def load_hyperparams(algo: str, env_id: str) -> Hyperparams:
81
  if env_id in hyperparams_dict:
82
  return Hyperparams(**hyperparams_dict[env_id])
83
 
84
- if "BulletEnv" in env_id:
85
- import pybullet_envs
86
  spec = gym.spec(env_id)
87
- if "AtariEnv" in str(spec.entry_point) and "_atari" in hyperparams_dict:
 
88
  return Hyperparams(**hyperparams_dict["_atari"])
 
 
89
  else:
90
  raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
91
 
92
 
93
- def get_device(device: str, env: VecEnv) -> torch.device:
 
94
  # cuda by default
95
  if device == "auto":
96
  device = "cuda"
@@ -108,6 +111,16 @@ def get_device(device: str, env: VecEnv) -> torch.device:
108
  device = "cpu"
109
  elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
110
  device = "cpu"
 
 
 
 
 
 
 
 
 
 
111
  print(f"Device: {device}")
112
  return torch.device(device)
113
 
@@ -187,6 +200,8 @@ def hparam_dict(
187
  flattened[key] = str(sv)
188
  else:
189
  flattened[key] = sv
 
 
190
  else:
191
  flattened[k] = v # type: ignore
192
  return flattened # type: ignore
 
1
  import argparse
 
2
  import json
 
 
3
  import os
4
  import random
5
+ from dataclasses import asdict
6
+ from pathlib import Path
7
+ from typing import Dict, Optional, Type, Union
8
+
9
+ import gym
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
  import torch
13
  import torch.backends.cudnn
14
  import yaml
 
 
15
  from gym.spaces import Box, Discrete
 
16
  from torch.utils.tensorboard.writer import SummaryWriter
 
 
 
 
 
 
 
17
 
18
  from rl_algo_impls.a2c.a2c import A2C
19
  from rl_algo_impls.dqn.dqn import DQN
20
  from rl_algo_impls.dqn.policy import DQNPolicy
21
  from rl_algo_impls.ppo.ppo import PPO
22
+ from rl_algo_impls.runner.config import Config, Hyperparams
23
+ from rl_algo_impls.shared.algorithm import Algorithm
24
+ from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
25
+ from rl_algo_impls.shared.policy.on_policy import ActorCritic
26
+ from rl_algo_impls.shared.policy.policy import Policy
27
+ from rl_algo_impls.shared.vec_env.utils import import_for_env_id, is_microrts
28
  from rl_algo_impls.vpg.policy import VPGActorCritic
29
+ from rl_algo_impls.vpg.vpg import VanillaPolicyGradient
30
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
31
 
32
  ALGOS: Dict[str, Type[Algorithm]] = {
 
81
  if env_id in hyperparams_dict:
82
  return Hyperparams(**hyperparams_dict[env_id])
83
 
84
+ import_for_env_id(env_id)
 
85
  spec = gym.spec(env_id)
86
+ entry_point_name = str(spec.entry_point) # type: ignore
87
+ if "AtariEnv" in entry_point_name and "_atari" in hyperparams_dict:
88
  return Hyperparams(**hyperparams_dict["_atari"])
89
+ elif "gym_microrts" in entry_point_name and "_microrts" in hyperparams_dict:
90
+ return Hyperparams(**hyperparams_dict["_microrts"])
91
  else:
92
  raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
93
 
94
 
95
+ def get_device(config: Config, env: VecEnv) -> torch.device:
96
+ device = config.device
97
  # cuda by default
98
  if device == "auto":
99
  device = "cuda"
 
111
  device = "cpu"
112
  elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
113
  device = "cpu"
114
+ if is_microrts(config):
115
+ try:
116
+ from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
117
+
118
+ # Models that move more than one unit at a time should use mps
119
+ if not isinstance(env.unwrapped, MicroRTSGridModeVecEnv):
120
+ device = "cpu"
121
+ except ModuleNotFoundError:
122
+ # Likely on gym_microrts v0.0.2 to match ppo-implementation-details
123
+ device = "cpu"
124
  print(f"Device: {device}")
125
  return torch.device(device)
126
 
 
200
  flattened[key] = str(sv)
201
  else:
202
  flattened[key] = sv
203
+ elif isinstance(v, list):
204
+ flattened[k] = json.dumps(v)
205
  else:
206
  flattened[k] = v # type: ignore
207
  return flattened # type: ignore
rl_algo_impls/runner/train.py CHANGED
@@ -5,26 +5,26 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
 
6
  import dataclasses
7
  import shutil
8
- import wandb
9
- import yaml
10
-
11
  from dataclasses import asdict, dataclass
12
- from torch.utils.tensorboard.writer import SummaryWriter
13
  from typing import Any, Dict, Optional, Sequence
14
 
15
- from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
 
 
 
16
  from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
17
- from rl_algo_impls.runner.env import make_env, make_eval_env
18
  from rl_algo_impls.runner.running_utils import (
19
  ALGOS,
20
- load_hyperparams,
21
- set_seeds,
22
  get_device,
 
 
23
  make_policy,
24
  plot_eval_callback,
25
- hparam_dict,
26
  )
 
27
  from rl_algo_impls.shared.stats import EpisodesStats
 
28
 
29
 
30
  @dataclass
@@ -65,7 +65,7 @@ def train(args: TrainArgs):
65
  env = make_env(
66
  config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
67
  )
68
- device = get_device(config.device, env)
69
  policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
70
  algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
71
 
@@ -94,6 +94,7 @@ def train(args: TrainArgs):
94
  if record_best_videos
95
  else None,
96
  best_video_dir=config.best_videos_dir,
 
97
  )
98
  algo.learn(config.n_timesteps, callback=callback)
99
 
 
5
 
6
  import dataclasses
7
  import shutil
 
 
 
8
  from dataclasses import asdict, dataclass
 
9
  from typing import Any, Dict, Optional, Sequence
10
 
11
+ import yaml
12
+ from torch.utils.tensorboard.writer import SummaryWriter
13
+
14
+ import wandb
15
  from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
 
16
  from rl_algo_impls.runner.running_utils import (
17
  ALGOS,
 
 
18
  get_device,
19
+ hparam_dict,
20
+ load_hyperparams,
21
  make_policy,
22
  plot_eval_callback,
23
+ set_seeds,
24
  )
25
+ from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
26
  from rl_algo_impls.shared.stats import EpisodesStats
27
+ from rl_algo_impls.shared.vec_env import make_env, make_eval_env
28
 
29
 
30
  @dataclass
 
65
  env = make_env(
66
  config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
67
  )
68
+ device = get_device(config, env)
69
  policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
70
  algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
71
 
 
94
  if record_best_videos
95
  else None,
96
  best_video_dir=config.best_videos_dir,
97
+ additional_keys_to_log=config.additional_keys_to_log,
98
  )
99
  algo.learn(config.n_timesteps, callback=callback)
100
 
rl_algo_impls/shared/actor/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from rl_algo_impls.shared.actor.actor import Actor, PiForward
2
+ from rl_algo_impls.shared.actor.make_actor import actor_head
rl_algo_impls/shared/actor/actor.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import NamedTuple, Optional, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.distributions import Distribution
8
+
9
+
10
+ class PiForward(NamedTuple):
11
+ pi: Distribution
12
+ logp_a: Optional[torch.Tensor]
13
+ entropy: Optional[torch.Tensor]
14
+
15
+
16
+ class Actor(nn.Module, ABC):
17
+ @abstractmethod
18
+ def forward(
19
+ self,
20
+ obs: torch.Tensor,
21
+ actions: Optional[torch.Tensor] = None,
22
+ action_masks: Optional[torch.Tensor] = None,
23
+ ) -> PiForward:
24
+ ...
25
+
26
+ def sample_weights(self, batch_size: int = 1) -> None:
27
+ pass
28
+
29
+ @property
30
+ @abstractmethod
31
+ def action_shape(self) -> Tuple[int, ...]:
32
+ ...
33
+
34
+ def pi_forward(
35
+ self, distribution: Distribution, actions: Optional[torch.Tensor] = None
36
+ ) -> PiForward:
37
+ logp_a = None
38
+ entropy = None
39
+ if actions is not None:
40
+ logp_a = distribution.log_prob(actions)
41
+ entropy = distribution.entropy()
42
+ return PiForward(distribution, logp_a, entropy)
rl_algo_impls/shared/actor/categorical.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Type
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.distributions import Categorical
6
+
7
+ from rl_algo_impls.shared.actor import Actor, PiForward
8
+ from rl_algo_impls.shared.module.module import mlp
9
+
10
+
11
+ class MaskedCategorical(Categorical):
12
+ def __init__(
13
+ self,
14
+ probs=None,
15
+ logits=None,
16
+ validate_args=None,
17
+ mask: Optional[torch.Tensor] = None,
18
+ ):
19
+ if mask is not None:
20
+ assert logits is not None, "mask requires logits and not probs"
21
+ logits = torch.where(mask, logits, -1e8)
22
+ self.mask = mask
23
+ super().__init__(probs, logits, validate_args)
24
+
25
+ def entropy(self) -> torch.Tensor:
26
+ if self.mask is None:
27
+ return super().entropy()
28
+ # If mask set, then use approximation for entropy
29
+ p_log_p = self.logits * self.probs # type: ignore
30
+ masked = torch.where(self.mask, p_log_p, 0)
31
+ return -masked.sum(-1)
32
+
33
+
34
+ class CategoricalActorHead(Actor):
35
+ def __init__(
36
+ self,
37
+ act_dim: int,
38
+ in_dim: int,
39
+ hidden_sizes: Tuple[int, ...] = (32,),
40
+ activation: Type[nn.Module] = nn.Tanh,
41
+ init_layers_orthogonal: bool = True,
42
+ ) -> None:
43
+ super().__init__()
44
+ layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
45
+ self._fc = mlp(
46
+ layer_sizes,
47
+ activation,
48
+ init_layers_orthogonal=init_layers_orthogonal,
49
+ final_layer_gain=0.01,
50
+ )
51
+
52
+ def forward(
53
+ self,
54
+ obs: torch.Tensor,
55
+ actions: Optional[torch.Tensor] = None,
56
+ action_masks: Optional[torch.Tensor] = None,
57
+ ) -> PiForward:
58
+ logits = self._fc(obs)
59
+ pi = MaskedCategorical(logits=logits, mask=action_masks)
60
+ return self.pi_forward(pi, actions)
61
+
62
+ @property
63
+ def action_shape(self) -> Tuple[int, ...]:
64
+ return ()
rl_algo_impls/shared/actor/gaussian.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Type
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.distributions import Distribution, Normal
6
+
7
+ from rl_algo_impls.shared.actor.actor import Actor, PiForward
8
+ from rl_algo_impls.shared.module.module import mlp
9
+
10
+
11
+ class GaussianDistribution(Normal):
12
+ def log_prob(self, a: torch.Tensor) -> torch.Tensor:
13
+ return super().log_prob(a).sum(axis=-1)
14
+
15
+ def sample(self) -> torch.Tensor:
16
+ return self.rsample()
17
+
18
+
19
+ class GaussianActorHead(Actor):
20
+ def __init__(
21
+ self,
22
+ act_dim: int,
23
+ in_dim: int,
24
+ hidden_sizes: Tuple[int, ...] = (32,),
25
+ activation: Type[nn.Module] = nn.Tanh,
26
+ init_layers_orthogonal: bool = True,
27
+ log_std_init: float = -0.5,
28
+ ) -> None:
29
+ super().__init__()
30
+ self.act_dim = act_dim
31
+ layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
32
+ self.mu_net = mlp(
33
+ layer_sizes,
34
+ activation,
35
+ init_layers_orthogonal=init_layers_orthogonal,
36
+ final_layer_gain=0.01,
37
+ )
38
+ self.log_std = nn.Parameter(
39
+ torch.ones(act_dim, dtype=torch.float32) * log_std_init
40
+ )
41
+
42
+ def _distribution(self, obs: torch.Tensor) -> Distribution:
43
+ mu = self.mu_net(obs)
44
+ std = torch.exp(self.log_std)
45
+ return GaussianDistribution(mu, std)
46
+
47
+ def forward(
48
+ self,
49
+ obs: torch.Tensor,
50
+ actions: Optional[torch.Tensor] = None,
51
+ action_masks: Optional[torch.Tensor] = None,
52
+ ) -> PiForward:
53
+ assert (
54
+ not action_masks
55
+ ), f"{self.__class__.__name__} does not support action_masks"
56
+ pi = self._distribution(obs)
57
+ return self.pi_forward(pi, actions)
58
+
59
+ @property
60
+ def action_shape(self) -> Tuple[int, ...]:
61
+ return (self.act_dim,)
rl_algo_impls/shared/actor/gridnet.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, Type
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from numpy.typing import NDArray
7
+ from torch.distributions import Distribution, constraints
8
+
9
+ from rl_algo_impls.shared.actor import Actor, PiForward
10
+ from rl_algo_impls.shared.actor.categorical import MaskedCategorical
11
+ from rl_algo_impls.shared.encoder import EncoderOutDim
12
+ from rl_algo_impls.shared.module.module import mlp
13
+
14
+
15
+ class GridnetDistribution(Distribution):
16
+ def __init__(
17
+ self,
18
+ map_size: int,
19
+ action_vec: NDArray[np.int64],
20
+ logits: torch.Tensor,
21
+ masks: torch.Tensor,
22
+ validate_args: Optional[bool] = None,
23
+ ) -> None:
24
+ self.map_size = map_size
25
+ self.action_vec = action_vec
26
+
27
+ masks = masks.view(-1, masks.shape[-1])
28
+ split_masks = torch.split(masks[:, 1:], action_vec.tolist(), dim=1)
29
+
30
+ grid_logits = logits.reshape(-1, action_vec.sum())
31
+ split_logits = torch.split(grid_logits, action_vec.tolist(), dim=1)
32
+ self.categoricals = [
33
+ MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
34
+ for lg, m in zip(split_logits, split_masks)
35
+ ]
36
+
37
+ batch_shape = logits.size()[:-1] if logits.ndimension() > 1 else torch.Size()
38
+ super().__init__(batch_shape=batch_shape, validate_args=validate_args)
39
+
40
+ def log_prob(self, action: torch.Tensor) -> torch.Tensor:
41
+ prob_stack = torch.stack(
42
+ [
43
+ c.log_prob(a)
44
+ for a, c in zip(action.view(-1, action.shape[-1]).T, self.categoricals)
45
+ ],
46
+ dim=-1,
47
+ )
48
+ logprob = prob_stack.view(-1, self.map_size, len(self.action_vec))
49
+ return logprob.sum(dim=(1, 2))
50
+
51
+ def entropy(self) -> torch.Tensor:
52
+ ent = torch.stack([c.entropy() for c in self.categoricals], dim=-1)
53
+ ent = ent.view(-1, self.map_size, len(self.action_vec))
54
+ return ent.sum(dim=(1, 2))
55
+
56
+ def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
57
+ s = torch.stack([c.sample(sample_shape) for c in self.categoricals], dim=-1)
58
+ return s.view(-1, self.map_size, len(self.action_vec))
59
+
60
+ @property
61
+ def mode(self) -> torch.Tensor:
62
+ m = torch.stack([c.mode for c in self.categoricals], dim=-1)
63
+ return m.view(-1, self.map_size, len(self.action_vec))
64
+
65
+ @property
66
+ def arg_constraints(self) -> Dict[str, constraints.Constraint]:
67
+ # Constraints handled by child distributions in dist
68
+ return {}
69
+
70
+
71
+ class GridnetActorHead(Actor):
72
+ def __init__(
73
+ self,
74
+ map_size: int,
75
+ action_vec: NDArray[np.int64],
76
+ in_dim: EncoderOutDim,
77
+ hidden_sizes: Tuple[int, ...] = (32,),
78
+ activation: Type[nn.Module] = nn.ReLU,
79
+ init_layers_orthogonal: bool = True,
80
+ ) -> None:
81
+ super().__init__()
82
+ self.map_size = map_size
83
+ self.action_vec = action_vec
84
+ assert isinstance(in_dim, int)
85
+ layer_sizes = (in_dim,) + hidden_sizes + (map_size * action_vec.sum(),)
86
+ self._fc = mlp(
87
+ layer_sizes,
88
+ activation,
89
+ init_layers_orthogonal=init_layers_orthogonal,
90
+ final_layer_gain=0.01,
91
+ )
92
+
93
+ def forward(
94
+ self,
95
+ obs: torch.Tensor,
96
+ actions: Optional[torch.Tensor] = None,
97
+ action_masks: Optional[torch.Tensor] = None,
98
+ ) -> PiForward:
99
+ assert (
100
+ action_masks is not None
101
+ ), f"No mask case unhandled in {self.__class__.__name__}"
102
+ logits = self._fc(obs)
103
+ pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
104
+ return self.pi_forward(pi, actions)
105
+
106
+ @property
107
+ def action_shape(self) -> Tuple[int, ...]:
108
+ return (self.map_size, len(self.action_vec))
rl_algo_impls/shared/actor/gridnet_decoder.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Type
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from numpy.typing import NDArray
7
+
8
+ from rl_algo_impls.shared.actor import Actor, PiForward
9
+ from rl_algo_impls.shared.actor.categorical import MaskedCategorical
10
+ from rl_algo_impls.shared.actor.gridnet import GridnetDistribution
11
+ from rl_algo_impls.shared.encoder import EncoderOutDim
12
+ from rl_algo_impls.shared.module.module import layer_init
13
+
14
+
15
+ class Transpose(nn.Module):
16
+ def __init__(self, permutation: Tuple[int, ...]) -> None:
17
+ super().__init__()
18
+ self.permutation = permutation
19
+
20
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
21
+ return x.permute(self.permutation)
22
+
23
+
24
+ class GridnetDecoder(Actor):
25
+ def __init__(
26
+ self,
27
+ map_size: int,
28
+ action_vec: NDArray[np.int64],
29
+ in_dim: EncoderOutDim,
30
+ activation: Type[nn.Module] = nn.ReLU,
31
+ init_layers_orthogonal: bool = True,
32
+ ) -> None:
33
+ super().__init__()
34
+ self.map_size = map_size
35
+ self.action_vec = action_vec
36
+ assert isinstance(in_dim, tuple)
37
+ self.deconv = nn.Sequential(
38
+ layer_init(
39
+ nn.ConvTranspose2d(
40
+ in_dim[0], 128, 3, stride=2, padding=1, output_padding=1
41
+ ),
42
+ init_layers_orthogonal=init_layers_orthogonal,
43
+ ),
44
+ activation(),
45
+ layer_init(
46
+ nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
47
+ init_layers_orthogonal=init_layers_orthogonal,
48
+ ),
49
+ activation(),
50
+ layer_init(
51
+ nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
52
+ init_layers_orthogonal=init_layers_orthogonal,
53
+ ),
54
+ activation(),
55
+ layer_init(
56
+ nn.ConvTranspose2d(
57
+ 32, action_vec.sum(), 3, stride=2, padding=1, output_padding=1
58
+ ),
59
+ init_layers_orthogonal=init_layers_orthogonal,
60
+ std=0.01,
61
+ ),
62
+ Transpose((0, 2, 3, 1)),
63
+ )
64
+
65
+ def forward(
66
+ self,
67
+ obs: torch.Tensor,
68
+ actions: Optional[torch.Tensor] = None,
69
+ action_masks: Optional[torch.Tensor] = None,
70
+ ) -> PiForward:
71
+ assert (
72
+ action_masks is not None
73
+ ), f"No mask case unhandled in {self.__class__.__name__}"
74
+ logits = self.deconv(obs)
75
+ pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
76
+ return self.pi_forward(pi, actions)
77
+
78
+ @property
79
+ def action_shape(self) -> Tuple[int, ...]:
80
+ return (self.map_size, len(self.action_vec))
rl_algo_impls/shared/actor/make_actor.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Type
2
+
3
+ import gym
4
+ import torch.nn as nn
5
+ from gym.spaces import Box, Discrete, MultiDiscrete
6
+
7
+ from rl_algo_impls.shared.actor.actor import Actor
8
+ from rl_algo_impls.shared.actor.categorical import CategoricalActorHead
9
+ from rl_algo_impls.shared.actor.gaussian import GaussianActorHead
10
+ from rl_algo_impls.shared.actor.gridnet import GridnetActorHead
11
+ from rl_algo_impls.shared.actor.gridnet_decoder import GridnetDecoder
12
+ from rl_algo_impls.shared.actor.multi_discrete import MultiDiscreteActorHead
13
+ from rl_algo_impls.shared.actor.state_dependent_noise import (
14
+ StateDependentNoiseActorHead,
15
+ )
16
+ from rl_algo_impls.shared.encoder import EncoderOutDim
17
+
18
+
19
+ def actor_head(
20
+ action_space: gym.Space,
21
+ in_dim: EncoderOutDim,
22
+ hidden_sizes: Tuple[int, ...],
23
+ init_layers_orthogonal: bool,
24
+ activation: Type[nn.Module],
25
+ log_std_init: float = -0.5,
26
+ use_sde: bool = False,
27
+ full_std: bool = True,
28
+ squash_output: bool = False,
29
+ actor_head_style: str = "single",
30
+ ) -> Actor:
31
+ assert not use_sde or isinstance(
32
+ action_space, Box
33
+ ), "use_sde only valid if Box action_space"
34
+ assert not squash_output or use_sde, "squash_output only valid if use_sde"
35
+ if isinstance(action_space, Discrete):
36
+ assert isinstance(in_dim, int)
37
+ return CategoricalActorHead(
38
+ action_space.n, # type: ignore
39
+ in_dim=in_dim,
40
+ hidden_sizes=hidden_sizes,
41
+ activation=activation,
42
+ init_layers_orthogonal=init_layers_orthogonal,
43
+ )
44
+ elif isinstance(action_space, Box):
45
+ assert isinstance(in_dim, int)
46
+ if use_sde:
47
+ return StateDependentNoiseActorHead(
48
+ action_space.shape[0], # type: ignore
49
+ in_dim=in_dim,
50
+ hidden_sizes=hidden_sizes,
51
+ activation=activation,
52
+ init_layers_orthogonal=init_layers_orthogonal,
53
+ log_std_init=log_std_init,
54
+ full_std=full_std,
55
+ squash_output=squash_output,
56
+ )
57
+ else:
58
+ return GaussianActorHead(
59
+ action_space.shape[0], # type: ignore
60
+ in_dim=in_dim,
61
+ hidden_sizes=hidden_sizes,
62
+ activation=activation,
63
+ init_layers_orthogonal=init_layers_orthogonal,
64
+ log_std_init=log_std_init,
65
+ )
66
+ elif isinstance(action_space, MultiDiscrete):
67
+ if actor_head_style == "single":
68
+ return MultiDiscreteActorHead(
69
+ action_space.nvec, # type: ignore
70
+ in_dim=in_dim,
71
+ hidden_sizes=hidden_sizes,
72
+ activation=activation,
73
+ init_layers_orthogonal=init_layers_orthogonal,
74
+ )
75
+ elif actor_head_style == "gridnet":
76
+ return GridnetActorHead(
77
+ action_space.nvec[0], # type: ignore
78
+ action_space.nvec[1:], # type: ignore
79
+ in_dim=in_dim,
80
+ hidden_sizes=hidden_sizes,
81
+ activation=activation,
82
+ init_layers_orthogonal=init_layers_orthogonal,
83
+ )
84
+ elif actor_head_style == "gridnet_decoder":
85
+ return GridnetDecoder(
86
+ action_space.nvec[0], # type: ignore
87
+ action_space.nvec[1:], # type: ignore
88
+ in_dim=in_dim,
89
+ activation=activation,
90
+ init_layers_orthogonal=init_layers_orthogonal,
91
+ )
92
+ else:
93
+ raise ValueError(f"Doesn't support actor_head_style {actor_head_style}")
94
+ else:
95
+ raise ValueError(f"Unsupported action space: {action_space}")
rl_algo_impls/shared/actor/multi_discrete.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, Type
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from numpy.typing import NDArray
7
+ from torch.distributions import Distribution, constraints
8
+
9
+ from rl_algo_impls.shared.actor.actor import Actor, PiForward
10
+ from rl_algo_impls.shared.actor.categorical import MaskedCategorical
11
+ from rl_algo_impls.shared.encoder import EncoderOutDim
12
+ from rl_algo_impls.shared.module.module import mlp
13
+
14
+
15
+ class MultiCategorical(Distribution):
16
+ def __init__(
17
+ self,
18
+ nvec: NDArray[np.int64],
19
+ probs=None,
20
+ logits=None,
21
+ validate_args=None,
22
+ masks: Optional[torch.Tensor] = None,
23
+ ):
24
+ # Either probs or logits should be set
25
+ assert (probs is None) != (logits is None)
26
+ masks_split = (
27
+ torch.split(masks, nvec.tolist(), dim=1)
28
+ if masks is not None
29
+ else [None] * len(nvec)
30
+ )
31
+ if probs:
32
+ self.dists = [
33
+ MaskedCategorical(probs=p, validate_args=validate_args, mask=m)
34
+ for p, m in zip(torch.split(probs, nvec.tolist(), dim=1), masks_split)
35
+ ]
36
+ param = probs
37
+ else:
38
+ assert logits is not None
39
+ self.dists = [
40
+ MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
41
+ for lg, m in zip(torch.split(logits, nvec.tolist(), dim=1), masks_split)
42
+ ]
43
+ param = logits
44
+ batch_shape = param.size()[:-1] if param.ndimension() > 1 else torch.Size()
45
+ super().__init__(batch_shape=batch_shape, validate_args=validate_args)
46
+
47
+ def log_prob(self, action: torch.Tensor) -> torch.Tensor:
48
+ prob_stack = torch.stack(
49
+ [c.log_prob(a) for a, c in zip(action.T, self.dists)], dim=-1
50
+ )
51
+ return prob_stack.sum(dim=-1)
52
+
53
+ def entropy(self) -> torch.Tensor:
54
+ return torch.stack([c.entropy() for c in self.dists], dim=-1).sum(dim=-1)
55
+
56
+ def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
57
+ return torch.stack([c.sample(sample_shape) for c in self.dists], dim=-1)
58
+
59
+ @property
60
+ def mode(self) -> torch.Tensor:
61
+ return torch.stack([c.mode for c in self.dists], dim=-1)
62
+
63
+ @property
64
+ def arg_constraints(self) -> Dict[str, constraints.Constraint]:
65
+ # Constraints handled by child distributions in dist
66
+ return {}
67
+
68
+
69
+ class MultiDiscreteActorHead(Actor):
70
+ def __init__(
71
+ self,
72
+ nvec: NDArray[np.int64],
73
+ in_dim: EncoderOutDim,
74
+ hidden_sizes: Tuple[int, ...] = (32,),
75
+ activation: Type[nn.Module] = nn.ReLU,
76
+ init_layers_orthogonal: bool = True,
77
+ ) -> None:
78
+ super().__init__()
79
+ self.nvec = nvec
80
+ assert isinstance(in_dim, int)
81
+ layer_sizes = (in_dim,) + hidden_sizes + (nvec.sum(),)
82
+ self._fc = mlp(
83
+ layer_sizes,
84
+ activation,
85
+ init_layers_orthogonal=init_layers_orthogonal,
86
+ final_layer_gain=0.01,
87
+ )
88
+
89
+ def forward(
90
+ self,
91
+ obs: torch.Tensor,
92
+ actions: Optional[torch.Tensor] = None,
93
+ action_masks: Optional[torch.Tensor] = None,
94
+ ) -> PiForward:
95
+ logits = self._fc(obs)
96
+ pi = MultiCategorical(self.nvec, logits=logits, masks=action_masks)
97
+ return self.pi_forward(pi, actions)
98
+
99
+ @property
100
+ def action_shape(self) -> Tuple[int, ...]:
101
+ return (len(self.nvec),)
rl_algo_impls/shared/{policy/actor.py → actor/state_dependent_noise.py} RENAMED
@@ -1,99 +1,13 @@
1
- import gym
 
2
  import torch
3
  import torch.nn as nn
 
4
 
5
- from abc import ABC, abstractmethod
6
- from gym.spaces import Box, Discrete
7
- from torch.distributions import Categorical, Distribution, Normal
8
- from typing import NamedTuple, Optional, Sequence, Type, TypeVar, Union
9
-
10
  from rl_algo_impls.shared.module.module import mlp
11
 
12
 
13
- class PiForward(NamedTuple):
14
- pi: Distribution
15
- logp_a: Optional[torch.Tensor]
16
- entropy: Optional[torch.Tensor]
17
-
18
-
19
- class Actor(nn.Module, ABC):
20
- @abstractmethod
21
- def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
22
- ...
23
-
24
-
25
- class CategoricalActorHead(Actor):
26
- def __init__(
27
- self,
28
- act_dim: int,
29
- hidden_sizes: Sequence[int] = (32,),
30
- activation: Type[nn.Module] = nn.Tanh,
31
- init_layers_orthogonal: bool = True,
32
- ) -> None:
33
- super().__init__()
34
- layer_sizes = tuple(hidden_sizes) + (act_dim,)
35
- self._fc = mlp(
36
- layer_sizes,
37
- activation,
38
- init_layers_orthogonal=init_layers_orthogonal,
39
- final_layer_gain=0.01,
40
- )
41
-
42
- def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
43
- logits = self._fc(obs)
44
- pi = Categorical(logits=logits)
45
- logp_a = None
46
- entropy = None
47
- if a is not None:
48
- logp_a = pi.log_prob(a)
49
- entropy = pi.entropy()
50
- return PiForward(pi, logp_a, entropy)
51
-
52
-
53
- class GaussianDistribution(Normal):
54
- def log_prob(self, a: torch.Tensor) -> torch.Tensor:
55
- return super().log_prob(a).sum(axis=-1)
56
-
57
- def sample(self) -> torch.Tensor:
58
- return self.rsample()
59
-
60
-
61
- class GaussianActorHead(Actor):
62
- def __init__(
63
- self,
64
- act_dim: int,
65
- hidden_sizes: Sequence[int] = (32,),
66
- activation: Type[nn.Module] = nn.Tanh,
67
- init_layers_orthogonal: bool = True,
68
- log_std_init: float = -0.5,
69
- ) -> None:
70
- super().__init__()
71
- layer_sizes = tuple(hidden_sizes) + (act_dim,)
72
- self.mu_net = mlp(
73
- layer_sizes,
74
- activation,
75
- init_layers_orthogonal=init_layers_orthogonal,
76
- final_layer_gain=0.01,
77
- )
78
- self.log_std = nn.Parameter(
79
- torch.ones(act_dim, dtype=torch.float32) * log_std_init
80
- )
81
-
82
- def _distribution(self, obs: torch.Tensor) -> Distribution:
83
- mu = self.mu_net(obs)
84
- std = torch.exp(self.log_std)
85
- return GaussianDistribution(mu, std)
86
-
87
- def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
88
- pi = self._distribution(obs)
89
- logp_a = None
90
- entropy = None
91
- if a is not None:
92
- logp_a = pi.log_prob(a)
93
- entropy = pi.entropy()
94
- return PiForward(pi, logp_a, entropy)
95
-
96
-
97
  class TanhBijector:
98
  def __init__(self, epsilon: float = 1e-6) -> None:
99
  self.epsilon = epsilon
@@ -173,7 +87,8 @@ class StateDependentNoiseActorHead(Actor):
173
  def __init__(
174
  self,
175
  act_dim: int,
176
- hidden_sizes: Sequence[int] = (32,),
 
177
  activation: Type[nn.Module] = nn.Tanh,
178
  init_layers_orthogonal: bool = True,
179
  log_std_init: float = -0.5,
@@ -183,7 +98,7 @@ class StateDependentNoiseActorHead(Actor):
183
  ) -> None:
184
  super().__init__()
185
  self.act_dim = act_dim
186
- layer_sizes = tuple(hidden_sizes) + (self.act_dim,)
187
  if len(layer_sizes) == 2:
188
  self.latent_net = nn.Identity()
189
  elif len(layer_sizes) > 2:
@@ -193,8 +108,6 @@ class StateDependentNoiseActorHead(Actor):
193
  output_activation=activation,
194
  init_layers_orthogonal=init_layers_orthogonal,
195
  )
196
- else:
197
- raise ValueError("hidden_sizes must be of at least length 1")
198
  self.mu_net = mlp(
199
  layer_sizes[-2:],
200
  activation,
@@ -202,7 +115,7 @@ class StateDependentNoiseActorHead(Actor):
202
  final_layer_gain=0.01,
203
  )
204
  self.full_std = full_std
205
- std_dim = (hidden_sizes[-1], act_dim if self.full_std else 1)
206
  self.log_std = nn.Parameter(
207
  torch.ones(std_dim, dtype=torch.float32) * log_std_init
208
  )
@@ -249,14 +162,17 @@ class StateDependentNoiseActorHead(Actor):
249
  ones = ones.to(self.device)
250
  return ones * std
251
 
252
- def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
 
 
 
 
 
 
 
 
253
  pi = self._distribution(obs)
254
- logp_a = None
255
- entropy = None
256
- if a is not None:
257
- logp_a = pi.log_prob(a)
258
- entropy = -logp_a if self.bijector else sum_independent_dims(pi.entropy())
259
- return PiForward(pi, logp_a, entropy)
260
 
261
  def sample_weights(self, batch_size: int = 1) -> None:
262
  std = self._get_std()
@@ -265,46 +181,20 @@ class StateDependentNoiseActorHead(Actor):
265
  self.exploration_mat = weights_dist.rsample()
266
  self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
267
 
 
 
 
268
 
269
- def actor_head(
270
- action_space: gym.Space,
271
- hidden_sizes: Sequence[int],
272
- init_layers_orthogonal: bool,
273
- activation: Type[nn.Module],
274
- log_std_init: float = -0.5,
275
- use_sde: bool = False,
276
- full_std: bool = True,
277
- squash_output: bool = False,
278
- ) -> Actor:
279
- assert not use_sde or isinstance(
280
- action_space, Box
281
- ), "use_sde only valid if Box action_space"
282
- assert not squash_output or use_sde, "squash_output only valid if use_sde"
283
- if isinstance(action_space, Discrete):
284
- return CategoricalActorHead(
285
- action_space.n,
286
- hidden_sizes=hidden_sizes,
287
- activation=activation,
288
- init_layers_orthogonal=init_layers_orthogonal,
289
- )
290
- elif isinstance(action_space, Box):
291
- if use_sde:
292
- return StateDependentNoiseActorHead(
293
- action_space.shape[0],
294
- hidden_sizes=hidden_sizes,
295
- activation=activation,
296
- init_layers_orthogonal=init_layers_orthogonal,
297
- log_std_init=log_std_init,
298
- full_std=full_std,
299
- squash_output=squash_output,
300
- )
301
- else:
302
- return GaussianActorHead(
303
- action_space.shape[0],
304
- hidden_sizes=hidden_sizes,
305
- activation=activation,
306
- init_layers_orthogonal=init_layers_orthogonal,
307
- log_std_init=log_std_init,
308
  )
309
- else:
310
- raise ValueError(f"Unsupported action space: {action_space}")
 
1
+ from typing import Optional, Tuple, Type, TypeVar, Union
2
+
3
  import torch
4
  import torch.nn as nn
5
+ from torch.distributions import Distribution, Normal
6
 
7
+ from rl_algo_impls.shared.actor.actor import Actor, PiForward
 
 
 
 
8
  from rl_algo_impls.shared.module.module import mlp
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class TanhBijector:
12
  def __init__(self, epsilon: float = 1e-6) -> None:
13
  self.epsilon = epsilon
 
87
  def __init__(
88
  self,
89
  act_dim: int,
90
+ in_dim: int,
91
+ hidden_sizes: Tuple[int, ...] = (32,),
92
  activation: Type[nn.Module] = nn.Tanh,
93
  init_layers_orthogonal: bool = True,
94
  log_std_init: float = -0.5,
 
98
  ) -> None:
99
  super().__init__()
100
  self.act_dim = act_dim
101
+ layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
102
  if len(layer_sizes) == 2:
103
  self.latent_net = nn.Identity()
104
  elif len(layer_sizes) > 2:
 
108
  output_activation=activation,
109
  init_layers_orthogonal=init_layers_orthogonal,
110
  )
 
 
111
  self.mu_net = mlp(
112
  layer_sizes[-2:],
113
  activation,
 
115
  final_layer_gain=0.01,
116
  )
117
  self.full_std = full_std
118
+ std_dim = (layer_sizes[-2], act_dim if self.full_std else 1)
119
  self.log_std = nn.Parameter(
120
  torch.ones(std_dim, dtype=torch.float32) * log_std_init
121
  )
 
162
  ones = ones.to(self.device)
163
  return ones * std
164
 
165
+ def forward(
166
+ self,
167
+ obs: torch.Tensor,
168
+ actions: Optional[torch.Tensor] = None,
169
+ action_masks: Optional[torch.Tensor] = None,
170
+ ) -> PiForward:
171
+ assert (
172
+ not action_masks
173
+ ), f"{self.__class__.__name__} does not support action_masks"
174
  pi = self._distribution(obs)
175
+ return self.pi_forward(pi, actions)
 
 
 
 
 
176
 
177
  def sample_weights(self, batch_size: int = 1) -> None:
178
  std = self._get_std()
 
181
  self.exploration_mat = weights_dist.rsample()
182
  self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
183
 
184
+ @property
185
+ def action_shape(self) -> Tuple[int, ...]:
186
+ return (self.act_dim,)
187
 
188
+ def pi_forward(
189
+ self, distribution: Distribution, actions: Optional[torch.Tensor] = None
190
+ ) -> PiForward:
191
+ logp_a = None
192
+ entropy = None
193
+ if actions is not None:
194
+ logp_a = distribution.log_prob(actions)
195
+ entropy = (
196
+ -logp_a
197
+ if self.bijector
198
+ else sum_independent_dims(distribution.entropy())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  )
200
+ return PiForward(distribution, logp_a, entropy)
 
rl_algo_impls/shared/callbacks/eval_callback.py CHANGED
@@ -1,14 +1,15 @@
1
  import itertools
2
- import numpy as np
3
  import os
4
-
5
  from time import perf_counter
 
 
 
6
  from torch.utils.tensorboard.writer import SummaryWriter
7
- from typing import List, Optional, Union
8
 
9
  from rl_algo_impls.shared.callbacks.callback import Callback
10
  from rl_algo_impls.shared.policy.policy import Policy
11
  from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
 
12
  from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
13
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
14
 
@@ -20,6 +21,7 @@ class EvaluateAccumulator(EpisodeAccumulator):
20
  goal_episodes: int,
21
  print_returns: bool = True,
22
  ignore_first_episode: bool = False,
 
23
  ):
24
  super().__init__(num_envs)
25
  self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
@@ -36,8 +38,11 @@ class EvaluateAccumulator(EpisodeAccumulator):
36
  self.should_record_done = should_record_done
37
  else:
38
  self.should_record_done = lambda idx: True
 
39
 
40
- def on_done(self, ep_idx: int, episode: Episode) -> None:
 
 
41
  if (
42
  self.should_record_done(ep_idx)
43
  and len(self.completed_episodes_by_env_idx[ep_idx])
@@ -74,19 +79,29 @@ def evaluate(
74
  deterministic: bool = True,
75
  print_returns: bool = True,
76
  ignore_first_episode: bool = False,
 
77
  ) -> EpisodesStats:
78
  policy.sync_normalization(env)
79
  policy.eval()
80
 
81
  episodes = EvaluateAccumulator(
82
- env.num_envs, n_episodes, print_returns, ignore_first_episode
 
 
 
 
83
  )
84
 
85
  obs = env.reset()
 
86
  while not episodes.is_done():
87
- act = policy.act(obs, deterministic=deterministic)
88
- obs, rew, done, _ = env.step(act)
89
- episodes.step(rew, done)
 
 
 
 
90
  if render:
91
  env.render()
92
  stats = EpisodesStats(episodes.episodes)
@@ -111,6 +126,7 @@ class EvalCallback(Callback):
111
  best_video_dir: Optional[str] = None,
112
  max_video_length: int = 3600,
113
  ignore_first_episode: bool = False,
 
114
  ) -> None:
115
  super().__init__()
116
  self.policy = policy
@@ -133,8 +149,8 @@ class EvalCallback(Callback):
133
  os.makedirs(best_video_dir, exist_ok=True)
134
  self.max_video_length = max_video_length
135
  self.best_video_base_path = None
136
-
137
  self.ignore_first_episode = ignore_first_episode
 
138
 
139
  def on_step(self, timesteps_elapsed: int = 1) -> bool:
140
  super().on_step(timesteps_elapsed)
@@ -153,6 +169,7 @@ class EvalCallback(Callback):
153
  deterministic=self.deterministic,
154
  print_returns=print_returns or False,
155
  ignore_first_episode=self.ignore_first_episode,
 
156
  )
157
  end_time = perf_counter()
158
  self.tb_writer.add_scalar(
 
1
  import itertools
 
2
  import os
 
3
  from time import perf_counter
4
+ from typing import Dict, List, Optional, Union
5
+
6
+ import numpy as np
7
  from torch.utils.tensorboard.writer import SummaryWriter
 
8
 
9
  from rl_algo_impls.shared.callbacks.callback import Callback
10
  from rl_algo_impls.shared.policy.policy import Policy
11
  from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
12
+ from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
13
  from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
14
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
15
 
 
21
  goal_episodes: int,
22
  print_returns: bool = True,
23
  ignore_first_episode: bool = False,
24
+ additional_keys_to_log: Optional[List[str]] = None,
25
  ):
26
  super().__init__(num_envs)
27
  self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
 
38
  self.should_record_done = should_record_done
39
  else:
40
  self.should_record_done = lambda idx: True
41
+ self.additional_keys_to_log = additional_keys_to_log
42
 
43
+ def on_done(self, ep_idx: int, episode: Episode, info: Dict) -> None:
44
+ if self.additional_keys_to_log:
45
+ episode.info = {k: info[k] for k in self.additional_keys_to_log}
46
  if (
47
  self.should_record_done(ep_idx)
48
  and len(self.completed_episodes_by_env_idx[ep_idx])
 
79
  deterministic: bool = True,
80
  print_returns: bool = True,
81
  ignore_first_episode: bool = False,
82
+ additional_keys_to_log: Optional[List[str]] = None,
83
  ) -> EpisodesStats:
84
  policy.sync_normalization(env)
85
  policy.eval()
86
 
87
  episodes = EvaluateAccumulator(
88
+ env.num_envs,
89
+ n_episodes,
90
+ print_returns,
91
+ ignore_first_episode,
92
+ additional_keys_to_log=additional_keys_to_log,
93
  )
94
 
95
  obs = env.reset()
96
+ action_masker = find_action_masker(env)
97
  while not episodes.is_done():
98
+ act = policy.act(
99
+ obs,
100
+ deterministic=deterministic,
101
+ action_masks=action_masker.action_masks() if action_masker else None,
102
+ )
103
+ obs, rew, done, info = env.step(act)
104
+ episodes.step(rew, done, info)
105
  if render:
106
  env.render()
107
  stats = EpisodesStats(episodes.episodes)
 
126
  best_video_dir: Optional[str] = None,
127
  max_video_length: int = 3600,
128
  ignore_first_episode: bool = False,
129
+ additional_keys_to_log: Optional[List[str]] = None,
130
  ) -> None:
131
  super().__init__()
132
  self.policy = policy
 
149
  os.makedirs(best_video_dir, exist_ok=True)
150
  self.max_video_length = max_video_length
151
  self.best_video_base_path = None
 
152
  self.ignore_first_episode = ignore_first_episode
153
+ self.additional_keys_to_log = additional_keys_to_log
154
 
155
  def on_step(self, timesteps_elapsed: int = 1) -> bool:
156
  super().on_step(timesteps_elapsed)
 
169
  deterministic=self.deterministic,
170
  print_returns=print_returns or False,
171
  ignore_first_episode=self.ignore_first_episode,
172
+ additional_keys_to_log=self.additional_keys_to_log,
173
  )
174
  end_time = perf_counter()
175
  self.tb_writer.add_scalar(
rl_algo_impls/shared/encoder/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from rl_algo_impls.shared.encoder.cnn import EncoderOutDim
2
+ from rl_algo_impls.shared.encoder.encoder import Encoder
rl_algo_impls/shared/encoder/cnn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Optional, Tuple, Type, Union
3
+
4
+ import gym
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from rl_algo_impls.shared.module.module import layer_init
10
+
11
+ EncoderOutDim = Union[int, Tuple[int, ...]]
12
+
13
+
14
+ class CnnEncoder(nn.Module, ABC):
15
+ @abstractmethod
16
+ def __init__(
17
+ self,
18
+ obs_space: gym.Space,
19
+ **kwargs,
20
+ ) -> None:
21
+ super().__init__()
22
+ self.range_size = np.max(obs_space.high) - np.min(obs_space.low) # type: ignore
23
+
24
+ def preprocess(self, obs: torch.Tensor) -> torch.Tensor:
25
+ if len(obs.shape) == 3:
26
+ obs = obs.unsqueeze(0)
27
+ return obs.float() / self.range_size
28
+
29
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
30
+ return self.preprocess(obs)
31
+
32
+ @property
33
+ @abstractmethod
34
+ def out_dim(self) -> EncoderOutDim:
35
+ ...
36
+
37
+
38
+ class FlattenedCnnEncoder(CnnEncoder):
39
+ def __init__(
40
+ self,
41
+ obs_space: gym.Space,
42
+ activation: Type[nn.Module],
43
+ linear_init_layers_orthogonal: bool,
44
+ cnn_flatten_dim: int,
45
+ cnn: nn.Module,
46
+ **kwargs,
47
+ ) -> None:
48
+ super().__init__(obs_space, **kwargs)
49
+ self.cnn = cnn
50
+ self.flattened_dim = cnn_flatten_dim
51
+ with torch.no_grad():
52
+ cnn_out = torch.flatten(
53
+ cnn(self.preprocess(torch.as_tensor(obs_space.sample()))), start_dim=1
54
+ )
55
+ self.fc = nn.Sequential(
56
+ nn.Flatten(),
57
+ layer_init(
58
+ nn.Linear(cnn_out.shape[1], cnn_flatten_dim),
59
+ linear_init_layers_orthogonal,
60
+ ),
61
+ activation(),
62
+ )
63
+
64
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
65
+ x = super().forward(obs)
66
+ x = self.cnn(x)
67
+ x = self.fc(x)
68
+ return x
69
+
70
+ @property
71
+ def out_dim(self) -> EncoderOutDim:
72
+ return self.flattened_dim
rl_algo_impls/shared/encoder/encoder.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Sequence, Type
2
+
3
+ import gym
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from gym.spaces import Box, Discrete
8
+ from stable_baselines3.common.preprocessing import get_flattened_obs_dim
9
+
10
+ from rl_algo_impls.shared.encoder.cnn import CnnEncoder
11
+ from rl_algo_impls.shared.encoder.gridnet_encoder import GridnetEncoder
12
+ from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn
13
+ from rl_algo_impls.shared.encoder.microrts_cnn import MicrortsCnn
14
+ from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn
15
+ from rl_algo_impls.shared.module.module import layer_init
16
+
17
+ CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnEncoder]] = {
18
+ "nature": NatureCnn,
19
+ "impala": ImpalaCnn,
20
+ "microrts": MicrortsCnn,
21
+ "gridnet_encoder": GridnetEncoder,
22
+ }
23
+
24
+
25
+ class Encoder(nn.Module):
26
+ def __init__(
27
+ self,
28
+ obs_space: gym.Space,
29
+ activation: Type[nn.Module],
30
+ init_layers_orthogonal: bool = False,
31
+ cnn_flatten_dim: int = 512,
32
+ cnn_style: str = "nature",
33
+ cnn_layers_init_orthogonal: Optional[bool] = None,
34
+ impala_channels: Sequence[int] = (16, 32, 32),
35
+ ) -> None:
36
+ super().__init__()
37
+ if isinstance(obs_space, Box):
38
+ # Conv2D: (channels, height, width)
39
+ if len(obs_space.shape) == 3: # type: ignore
40
+ self.preprocess = None
41
+ cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
42
+ obs_space,
43
+ activation=activation,
44
+ cnn_init_layers_orthogonal=cnn_layers_init_orthogonal,
45
+ linear_init_layers_orthogonal=init_layers_orthogonal,
46
+ cnn_flatten_dim=cnn_flatten_dim,
47
+ impala_channels=impala_channels,
48
+ )
49
+ self.feature_extractor = cnn
50
+ self.out_dim = cnn.out_dim
51
+ elif len(obs_space.shape) == 1: # type: ignore
52
+
53
+ def preprocess(obs: torch.Tensor) -> torch.Tensor:
54
+ if len(obs.shape) == 1:
55
+ obs = obs.unsqueeze(0)
56
+ return obs.float()
57
+
58
+ self.preprocess = preprocess
59
+ self.feature_extractor = nn.Flatten()
60
+ self.out_dim = get_flattened_obs_dim(obs_space)
61
+ else:
62
+ raise ValueError(f"Unsupported observation space: {obs_space}")
63
+ elif isinstance(obs_space, Discrete):
64
+ self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
65
+ self.feature_extractor = nn.Flatten()
66
+ self.out_dim = obs_space.n # type: ignore
67
+ else:
68
+ raise NotImplementedError
69
+
70
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
71
+ if self.preprocess:
72
+ obs = self.preprocess(obs)
73
+ return self.feature_extractor(obs)
rl_algo_impls/shared/encoder/gridnet_encoder.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Type, Union
2
+
3
+ import gym
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from rl_algo_impls.shared.encoder.cnn import CnnEncoder, EncoderOutDim
8
+ from rl_algo_impls.shared.module.module import layer_init
9
+
10
+
11
+ class GridnetEncoder(CnnEncoder):
12
+ """
13
+ Encoder for encoder-decoder for Gym-MicroRTS
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ obs_space: gym.Space,
19
+ activation: Type[nn.Module] = nn.ReLU,
20
+ cnn_init_layers_orthogonal: Optional[bool] = None,
21
+ **kwargs
22
+ ) -> None:
23
+ if cnn_init_layers_orthogonal is None:
24
+ cnn_init_layers_orthogonal = True
25
+ super().__init__(obs_space, **kwargs)
26
+ in_channels = obs_space.shape[0] # type: ignore
27
+ self.encoder = nn.Sequential(
28
+ layer_init(
29
+ nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
30
+ cnn_init_layers_orthogonal,
31
+ ),
32
+ nn.MaxPool2d(3, stride=2, padding=1),
33
+ activation(),
34
+ layer_init(
35
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
36
+ cnn_init_layers_orthogonal,
37
+ ),
38
+ nn.MaxPool2d(3, stride=2, padding=1),
39
+ activation(),
40
+ layer_init(
41
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
42
+ cnn_init_layers_orthogonal,
43
+ ),
44
+ nn.MaxPool2d(3, stride=2, padding=1),
45
+ activation(),
46
+ layer_init(
47
+ nn.Conv2d(128, 256, kernel_size=3, padding=1),
48
+ cnn_init_layers_orthogonal,
49
+ ),
50
+ nn.MaxPool2d(3, stride=2, padding=1),
51
+ activation(),
52
+ )
53
+ with torch.no_grad():
54
+ encoder_out = self.encoder(
55
+ self.preprocess(torch.as_tensor(obs_space.sample())) # type: ignore
56
+ )
57
+ self._out_dim = encoder_out.shape[1:]
58
+
59
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
60
+ return self.encoder(super().forward(obs))
61
+
62
+ @property
63
+ def out_dim(self) -> EncoderOutDim:
64
+ return self._out_dim
rl_algo_impls/shared/encoder/impala_cnn.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Sequence, Type
2
+
3
+ import gym
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
8
+ from rl_algo_impls.shared.module.module import layer_init
9
+
10
+
11
+ class ResidualBlock(nn.Module):
12
+ def __init__(
13
+ self,
14
+ channels: int,
15
+ activation: Type[nn.Module] = nn.ReLU,
16
+ init_layers_orthogonal: bool = False,
17
+ ) -> None:
18
+ super().__init__()
19
+ self.residual = nn.Sequential(
20
+ activation(),
21
+ layer_init(
22
+ nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
23
+ ),
24
+ activation(),
25
+ layer_init(
26
+ nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
27
+ ),
28
+ )
29
+
30
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
31
+ return x + self.residual(x)
32
+
33
+
34
+ class ConvSequence(nn.Module):
35
+ def __init__(
36
+ self,
37
+ in_channels: int,
38
+ out_channels: int,
39
+ activation: Type[nn.Module] = nn.ReLU,
40
+ init_layers_orthogonal: bool = False,
41
+ ) -> None:
42
+ super().__init__()
43
+ self.seq = nn.Sequential(
44
+ layer_init(
45
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
46
+ init_layers_orthogonal,
47
+ ),
48
+ nn.MaxPool2d(3, stride=2, padding=1),
49
+ ResidualBlock(out_channels, activation, init_layers_orthogonal),
50
+ ResidualBlock(out_channels, activation, init_layers_orthogonal),
51
+ )
52
+
53
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
54
+ return self.seq(x)
55
+
56
+
57
+ class ImpalaCnn(FlattenedCnnEncoder):
58
+ """
59
+ IMPALA-style CNN architecture
60
+ """
61
+
62
+ def __init__(
63
+ self,
64
+ obs_space: gym.Space,
65
+ activation: Type[nn.Module],
66
+ cnn_init_layers_orthogonal: Optional[bool],
67
+ linear_init_layers_orthogonal: bool,
68
+ cnn_flatten_dim: int,
69
+ impala_channels: Sequence[int] = (16, 32, 32),
70
+ **kwargs,
71
+ ) -> None:
72
+ if cnn_init_layers_orthogonal is None:
73
+ cnn_init_layers_orthogonal = False
74
+ in_channels = obs_space.shape[0] # type: ignore
75
+ sequences = []
76
+ for out_channels in impala_channels:
77
+ sequences.append(
78
+ ConvSequence(
79
+ in_channels, out_channels, activation, cnn_init_layers_orthogonal
80
+ )
81
+ )
82
+ in_channels = out_channels
83
+ sequences.append(activation())
84
+ cnn = nn.Sequential(*sequences)
85
+ super().__init__(
86
+ obs_space,
87
+ activation,
88
+ linear_init_layers_orthogonal,
89
+ cnn_flatten_dim,
90
+ cnn,
91
+ **kwargs,
92
+ )
rl_algo_impls/shared/encoder/microrts_cnn.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Type
2
+
3
+ import gym
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
8
+ from rl_algo_impls.shared.module.module import layer_init
9
+
10
+
11
+ class MicrortsCnn(FlattenedCnnEncoder):
12
+ """
13
+ Base CNN architecture for Gym-MicroRTS
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ obs_space: gym.Space,
19
+ activation: Type[nn.Module],
20
+ cnn_init_layers_orthogonal: Optional[bool],
21
+ linear_init_layers_orthogonal: bool,
22
+ cnn_flatten_dim: int,
23
+ **kwargs,
24
+ ) -> None:
25
+ if cnn_init_layers_orthogonal is None:
26
+ cnn_init_layers_orthogonal = True
27
+ in_channels = obs_space.shape[0] # type: ignore
28
+ cnn = nn.Sequential(
29
+ layer_init(
30
+ nn.Conv2d(in_channels, 16, kernel_size=3, stride=2),
31
+ cnn_init_layers_orthogonal,
32
+ ),
33
+ activation(),
34
+ layer_init(nn.Conv2d(16, 32, kernel_size=2), cnn_init_layers_orthogonal),
35
+ activation(),
36
+ nn.Flatten(),
37
+ )
38
+ super().__init__(
39
+ obs_space,
40
+ activation,
41
+ linear_init_layers_orthogonal,
42
+ cnn_flatten_dim,
43
+ cnn,
44
+ **kwargs,
45
+ )
rl_algo_impls/shared/encoder/nature_cnn.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Type
2
+
3
+ import gym
4
+ import torch.nn as nn
5
+
6
+ from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
7
+ from rl_algo_impls.shared.module.module import layer_init
8
+
9
+
10
+ class NatureCnn(FlattenedCnnEncoder):
11
+ """
12
+ CNN from DQN Nature paper: Mnih, Volodymyr, et al.
13
+ "Human-level control through deep reinforcement learning."
14
+ Nature 518.7540 (2015): 529-533.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ obs_space: gym.Space,
20
+ activation: Type[nn.Module],
21
+ cnn_init_layers_orthogonal: Optional[bool],
22
+ linear_init_layers_orthogonal: bool,
23
+ cnn_flatten_dim: int,
24
+ **kwargs,
25
+ ) -> None:
26
+ if cnn_init_layers_orthogonal is None:
27
+ cnn_init_layers_orthogonal = True
28
+ in_channels = obs_space.shape[0] # type: ignore
29
+ cnn = nn.Sequential(
30
+ layer_init(
31
+ nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
32
+ cnn_init_layers_orthogonal,
33
+ ),
34
+ activation(),
35
+ layer_init(
36
+ nn.Conv2d(32, 64, kernel_size=4, stride=2),
37
+ cnn_init_layers_orthogonal,
38
+ ),
39
+ activation(),
40
+ layer_init(
41
+ nn.Conv2d(64, 64, kernel_size=3, stride=1),
42
+ cnn_init_layers_orthogonal,
43
+ ),
44
+ activation(),
45
+ )
46
+ super().__init__(
47
+ obs_space,
48
+ activation,
49
+ linear_init_layers_orthogonal,
50
+ cnn_flatten_dim,
51
+ cnn,
52
+ **kwargs,
53
+ )
rl_algo_impls/shared/gae.py CHANGED
@@ -5,6 +5,7 @@ from typing import NamedTuple, Sequence
5
 
6
  from rl_algo_impls.shared.policy.on_policy import OnPolicy
7
  from rl_algo_impls.shared.trajectory import Trajectory
 
8
 
9
 
10
  class RtgAdvantage(NamedTuple):
@@ -19,7 +20,7 @@ def discounted_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
19
  return dc
20
 
21
 
22
- def compute_advantage(
23
  trajectories: Sequence[Trajectory],
24
  policy: OnPolicy,
25
  gamma: float,
@@ -40,7 +41,7 @@ def compute_advantage(
40
  )
41
 
42
 
43
- def compute_rtg_and_advantage(
44
  trajectories: Sequence[Trajectory],
45
  policy: OnPolicy,
46
  gamma: float,
@@ -65,3 +66,29 @@ def compute_rtg_and_advantage(
65
  ),
66
  torch.as_tensor(np.concatenate(advantages), dtype=torch.float32, device=device),
67
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  from rl_algo_impls.shared.policy.on_policy import OnPolicy
7
  from rl_algo_impls.shared.trajectory import Trajectory
8
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvObs
9
 
10
 
11
  class RtgAdvantage(NamedTuple):
 
20
  return dc
21
 
22
 
23
+ def compute_advantage_from_trajectories(
24
  trajectories: Sequence[Trajectory],
25
  policy: OnPolicy,
26
  gamma: float,
 
41
  )
42
 
43
 
44
+ def compute_rtg_and_advantage_from_trajectories(
45
  trajectories: Sequence[Trajectory],
46
  policy: OnPolicy,
47
  gamma: float,
 
66
  ),
67
  torch.as_tensor(np.concatenate(advantages), dtype=torch.float32, device=device),
68
  )
69
+
70
+
71
+ def compute_advantages(
72
+ rewards: np.ndarray,
73
+ values: np.ndarray,
74
+ episode_starts: np.ndarray,
75
+ next_episode_starts: np.ndarray,
76
+ next_obs: VecEnvObs,
77
+ policy: OnPolicy,
78
+ gamma: float,
79
+ gae_lambda: float,
80
+ ) -> np.ndarray:
81
+ advantages = np.zeros_like(rewards)
82
+ last_gae_lam = 0
83
+ n_steps = advantages.shape[0]
84
+ for t in reversed(range(n_steps)):
85
+ if t == n_steps - 1:
86
+ next_nonterminal = 1.0 - next_episode_starts
87
+ next_value = policy.value(next_obs)
88
+ else:
89
+ next_nonterminal = 1.0 - episode_starts[t + 1]
90
+ next_value = values[t + 1]
91
+ delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
92
+ last_gae_lam = delta + gamma * gae_lambda * next_nonterminal * last_gae_lam
93
+ advantages[t] = last_gae_lam
94
+ return advantages
rl_algo_impls/shared/module/feature_extractor.py DELETED
@@ -1,215 +0,0 @@
1
- import gym
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
-
6
- from abc import ABC, abstractmethod
7
- from gym.spaces import Box, Discrete
8
- from stable_baselines3.common.preprocessing import get_flattened_obs_dim
9
- from typing import Dict, Optional, Sequence, Type
10
-
11
- from rl_algo_impls.shared.module.module import layer_init
12
-
13
-
14
- class CnnFeatureExtractor(nn.Module, ABC):
15
- @abstractmethod
16
- def __init__(
17
- self,
18
- in_channels: int,
19
- activation: Type[nn.Module] = nn.ReLU,
20
- init_layers_orthogonal: Optional[bool] = None,
21
- **kwargs,
22
- ) -> None:
23
- super().__init__()
24
-
25
-
26
- class NatureCnn(CnnFeatureExtractor):
27
- """
28
- CNN from DQN Nature paper: Mnih, Volodymyr, et al.
29
- "Human-level control through deep reinforcement learning."
30
- Nature 518.7540 (2015): 529-533.
31
- """
32
-
33
- def __init__(
34
- self,
35
- in_channels: int,
36
- activation: Type[nn.Module] = nn.ReLU,
37
- init_layers_orthogonal: Optional[bool] = None,
38
- **kwargs,
39
- ) -> None:
40
- if init_layers_orthogonal is None:
41
- init_layers_orthogonal = True
42
- super().__init__(in_channels, activation, init_layers_orthogonal)
43
- self.cnn = nn.Sequential(
44
- layer_init(
45
- nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
46
- init_layers_orthogonal,
47
- ),
48
- activation(),
49
- layer_init(
50
- nn.Conv2d(32, 64, kernel_size=4, stride=2),
51
- init_layers_orthogonal,
52
- ),
53
- activation(),
54
- layer_init(
55
- nn.Conv2d(64, 64, kernel_size=3, stride=1),
56
- init_layers_orthogonal,
57
- ),
58
- activation(),
59
- nn.Flatten(),
60
- )
61
-
62
- def forward(self, obs: torch.Tensor) -> torch.Tensor:
63
- return self.cnn(obs)
64
-
65
-
66
- class ResidualBlock(nn.Module):
67
- def __init__(
68
- self,
69
- channels: int,
70
- activation: Type[nn.Module] = nn.ReLU,
71
- init_layers_orthogonal: bool = False,
72
- ) -> None:
73
- super().__init__()
74
- self.residual = nn.Sequential(
75
- activation(),
76
- layer_init(
77
- nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
78
- ),
79
- activation(),
80
- layer_init(
81
- nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
82
- ),
83
- )
84
-
85
- def forward(self, x: torch.Tensor) -> torch.Tensor:
86
- return x + self.residual(x)
87
-
88
-
89
- class ConvSequence(nn.Module):
90
- def __init__(
91
- self,
92
- in_channels: int,
93
- out_channels: int,
94
- activation: Type[nn.Module] = nn.ReLU,
95
- init_layers_orthogonal: bool = False,
96
- ) -> None:
97
- super().__init__()
98
- self.seq = nn.Sequential(
99
- layer_init(
100
- nn.Conv2d(in_channels, out_channels, 3, padding=1),
101
- init_layers_orthogonal,
102
- ),
103
- nn.MaxPool2d(3, stride=2, padding=1),
104
- ResidualBlock(out_channels, activation, init_layers_orthogonal),
105
- ResidualBlock(out_channels, activation, init_layers_orthogonal),
106
- )
107
-
108
- def forward(self, x: torch.Tensor) -> torch.Tensor:
109
- return self.seq(x)
110
-
111
-
112
- class ImpalaCnn(CnnFeatureExtractor):
113
- """
114
- IMPALA-style CNN architecture
115
- """
116
-
117
- def __init__(
118
- self,
119
- in_channels: int,
120
- activation: Type[nn.Module] = nn.ReLU,
121
- init_layers_orthogonal: Optional[bool] = None,
122
- impala_channels: Sequence[int] = (16, 32, 32),
123
- **kwargs,
124
- ) -> None:
125
- if init_layers_orthogonal is None:
126
- init_layers_orthogonal = False
127
- super().__init__(in_channels, activation, init_layers_orthogonal)
128
- sequences = []
129
- for out_channels in impala_channels:
130
- sequences.append(
131
- ConvSequence(
132
- in_channels, out_channels, activation, init_layers_orthogonal
133
- )
134
- )
135
- in_channels = out_channels
136
- sequences.extend(
137
- [
138
- activation(),
139
- nn.Flatten(),
140
- ]
141
- )
142
- self.seq = nn.Sequential(*sequences)
143
-
144
- def forward(self, obs: torch.Tensor) -> torch.Tensor:
145
- return self.seq(obs)
146
-
147
-
148
- CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnFeatureExtractor]] = {
149
- "nature": NatureCnn,
150
- "impala": ImpalaCnn,
151
- }
152
-
153
-
154
- class FeatureExtractor(nn.Module):
155
- def __init__(
156
- self,
157
- obs_space: gym.Space,
158
- activation: Type[nn.Module],
159
- init_layers_orthogonal: bool = False,
160
- cnn_feature_dim: int = 512,
161
- cnn_style: str = "nature",
162
- cnn_layers_init_orthogonal: Optional[bool] = None,
163
- impala_channels: Sequence[int] = (16, 32, 32),
164
- ) -> None:
165
- super().__init__()
166
- if isinstance(obs_space, Box):
167
- # Conv2D: (channels, height, width)
168
- if len(obs_space.shape) == 3:
169
- cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
170
- obs_space.shape[0],
171
- activation,
172
- init_layers_orthogonal=cnn_layers_init_orthogonal,
173
- impala_channels=impala_channels,
174
- )
175
-
176
- def preprocess(obs: torch.Tensor) -> torch.Tensor:
177
- if len(obs.shape) == 3:
178
- obs = obs.unsqueeze(0)
179
- return obs.float() / 255.0
180
-
181
- with torch.no_grad():
182
- cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
183
- self.preprocess = preprocess
184
- self.feature_extractor = nn.Sequential(
185
- cnn,
186
- layer_init(
187
- nn.Linear(cnn_out.shape[1], cnn_feature_dim),
188
- init_layers_orthogonal,
189
- ),
190
- activation(),
191
- )
192
- self.out_dim = cnn_feature_dim
193
- elif len(obs_space.shape) == 1:
194
-
195
- def preprocess(obs: torch.Tensor) -> torch.Tensor:
196
- if len(obs.shape) == 1:
197
- obs = obs.unsqueeze(0)
198
- return obs.float()
199
-
200
- self.preprocess = preprocess
201
- self.feature_extractor = nn.Flatten()
202
- self.out_dim = get_flattened_obs_dim(obs_space)
203
- else:
204
- raise ValueError(f"Unsupported observation space: {obs_space}")
205
- elif isinstance(obs_space, Discrete):
206
- self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
207
- self.feature_extractor = nn.Flatten()
208
- self.out_dim = obs_space.n
209
- else:
210
- raise NotImplementedError
211
-
212
- def forward(self, obs: torch.Tensor) -> torch.Tensor:
213
- if self.preprocess:
214
- obs = self.preprocess(obs)
215
- return self.feature_extractor(obs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rl_algo_impls/shared/module/module.py CHANGED
@@ -1,8 +1,8 @@
 
 
1
  import numpy as np
2
  import torch.nn as nn
3
 
4
- from typing import Sequence, Type
5
-
6
 
7
  def mlp(
8
  layer_sizes: Sequence[int],
@@ -10,12 +10,15 @@ def mlp(
10
  output_activation: Type[nn.Module] = nn.Identity,
11
  init_layers_orthogonal: bool = False,
12
  final_layer_gain: float = np.sqrt(2),
 
13
  ) -> nn.Module:
14
  layers = []
15
  for i in range(len(layer_sizes) - 2):
16
  layers.append(
17
  layer_init(
18
- nn.Linear(layer_sizes[i], layer_sizes[i + 1]), init_layers_orthogonal
 
 
19
  )
20
  )
21
  layers.append(activation())
 
1
+ from typing import Sequence, Type
2
+
3
  import numpy as np
4
  import torch.nn as nn
5
 
 
 
6
 
7
  def mlp(
8
  layer_sizes: Sequence[int],
 
10
  output_activation: Type[nn.Module] = nn.Identity,
11
  init_layers_orthogonal: bool = False,
12
  final_layer_gain: float = np.sqrt(2),
13
+ hidden_layer_gain: float = np.sqrt(2),
14
  ) -> nn.Module:
15
  layers = []
16
  for i in range(len(layer_sizes) - 2):
17
  layers.append(
18
  layer_init(
19
+ nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
20
+ init_layers_orthogonal,
21
+ std=hidden_layer_gain,
22
  )
23
  )
24
  layers.append(activation())
rl_algo_impls/shared/policy/critic.py CHANGED
@@ -1,27 +1,39 @@
1
- import gym
 
 
2
  import torch
3
  import torch.nn as nn
4
 
5
- from typing import Sequence, Type
6
-
7
  from rl_algo_impls.shared.module.module import mlp
8
 
9
 
10
  class CriticHead(nn.Module):
11
  def __init__(
12
  self,
13
- hidden_sizes: Sequence[int] = (32,),
 
14
  activation: Type[nn.Module] = nn.Tanh,
15
  init_layers_orthogonal: bool = True,
16
  ) -> None:
17
  super().__init__()
18
- layer_sizes = tuple(hidden_sizes) + (1,)
19
- self._fc = mlp(
20
- layer_sizes,
21
- activation,
22
- init_layers_orthogonal=init_layers_orthogonal,
23
- final_layer_gain=1.0,
 
 
 
 
 
 
 
 
 
24
  )
 
25
 
26
  def forward(self, obs: torch.Tensor) -> torch.Tensor:
27
  v = self._fc(obs)
 
1
+ from typing import Sequence, Type
2
+
3
+ import numpy as np
4
  import torch
5
  import torch.nn as nn
6
 
7
+ from rl_algo_impls.shared.encoder import EncoderOutDim
 
8
  from rl_algo_impls.shared.module.module import mlp
9
 
10
 
11
  class CriticHead(nn.Module):
12
  def __init__(
13
  self,
14
+ in_dim: EncoderOutDim,
15
+ hidden_sizes: Sequence[int] = (),
16
  activation: Type[nn.Module] = nn.Tanh,
17
  init_layers_orthogonal: bool = True,
18
  ) -> None:
19
  super().__init__()
20
+ seq = []
21
+ if isinstance(in_dim, tuple):
22
+ seq.append(nn.Flatten())
23
+ in_channels = int(np.prod(in_dim))
24
+ else:
25
+ in_channels = in_dim
26
+ layer_sizes = (in_channels,) + tuple(hidden_sizes) + (1,)
27
+ seq.append(
28
+ mlp(
29
+ layer_sizes,
30
+ activation,
31
+ init_layers_orthogonal=init_layers_orthogonal,
32
+ final_layer_gain=1.0,
33
+ hidden_layer_gain=1.0,
34
+ )
35
  )
36
+ self._fc = nn.Sequential(*seq)
37
 
38
  def forward(self, obs: torch.Tensor) -> torch.Tensor:
39
  v = self._fc(obs)
rl_algo_impls/shared/policy/on_policy.py CHANGED
@@ -1,24 +1,20 @@
 
 
 
1
  import gym
2
  import numpy as np
3
  import torch
4
-
5
- from abc import abstractmethod
6
  from gym.spaces import Box, Discrete, Space
7
- from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
8
 
9
- from rl_algo_impls.shared.module.feature_extractor import FeatureExtractor
10
- from rl_algo_impls.shared.policy.actor import (
11
- PiForward,
12
- StateDependentNoiseActorHead,
13
- actor_head,
14
- )
15
  from rl_algo_impls.shared.policy.critic import CriticHead
16
  from rl_algo_impls.shared.policy.policy import ACTIVATION, Policy
17
  from rl_algo_impls.wrappers.vectorable_wrapper import (
18
  VecEnv,
19
  VecEnvObs,
20
- single_observation_space,
21
  single_action_space,
 
22
  )
23
 
24
 
@@ -77,7 +73,12 @@ class OnPolicy(Policy):
77
  ...
78
 
79
  @abstractmethod
80
- def step(self, obs: VecEnvObs) -> Step:
 
 
 
 
 
81
  ...
82
 
83
 
@@ -94,10 +95,11 @@ class ActorCritic(OnPolicy):
94
  full_std: bool = True,
95
  squash_output: bool = False,
96
  share_features_extractor: bool = True,
97
- cnn_feature_dim: int = 512,
98
  cnn_style: str = "nature",
99
  cnn_layers_init_orthogonal: Optional[bool] = None,
100
  impala_channels: Sequence[int] = (16, 32, 32),
 
101
  **kwargs,
102
  ) -> None:
103
  super().__init__(env, **kwargs)
@@ -120,52 +122,56 @@ class ActorCritic(OnPolicy):
120
  self.action_space = action_space
121
  self.squash_output = squash_output
122
  self.share_features_extractor = share_features_extractor
123
- self._feature_extractor = FeatureExtractor(
124
  observation_space,
125
  activation,
126
  init_layers_orthogonal=init_layers_orthogonal,
127
- cnn_feature_dim=cnn_feature_dim,
128
  cnn_style=cnn_style,
129
  cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
130
  impala_channels=impala_channels,
131
  )
132
  self._pi = actor_head(
133
  self.action_space,
134
- (self._feature_extractor.out_dim,) + tuple(pi_hidden_sizes),
 
135
  init_layers_orthogonal,
136
  activation,
137
  log_std_init=log_std_init,
138
  use_sde=use_sde,
139
  full_std=full_std,
140
  squash_output=squash_output,
 
141
  )
142
 
143
  if not share_features_extractor:
144
- self._v_feature_extractor = FeatureExtractor(
145
  observation_space,
146
  activation,
147
  init_layers_orthogonal=init_layers_orthogonal,
148
- cnn_feature_dim=cnn_feature_dim,
149
  cnn_style=cnn_style,
150
  cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
151
  )
152
- v_hidden_sizes = (self._v_feature_extractor.out_dim,) + tuple(
153
- v_hidden_sizes
154
- )
155
  else:
156
  self._v_feature_extractor = None
157
- v_hidden_sizes = (self._feature_extractor.out_dim,) + tuple(v_hidden_sizes)
158
  self._v = CriticHead(
 
159
  hidden_sizes=v_hidden_sizes,
160
  activation=activation,
161
  init_layers_orthogonal=init_layers_orthogonal,
162
  )
163
 
164
  def _pi_forward(
165
- self, obs: torch.Tensor, action: Optional[torch.Tensor] = None
 
 
 
166
  ) -> Tuple[PiForward, torch.Tensor]:
167
  p_fe = self._feature_extractor(obs)
168
- pi_forward = self._pi(p_fe, action)
169
 
170
  return pi_forward, p_fe
171
 
@@ -173,8 +179,13 @@ class ActorCritic(OnPolicy):
173
  v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
174
  return self._v(v_fe)
175
 
176
- def forward(self, obs: torch.Tensor, action: torch.Tensor) -> ACForward:
177
- (_, logp_a, entropy), p_fc = self._pi_forward(obs, action)
 
 
 
 
 
178
  v = self._v_forward(obs, p_fc)
179
 
180
  assert logp_a is not None
@@ -192,10 +203,11 @@ class ActorCritic(OnPolicy):
192
  v = self._v(fe)
193
  return v.cpu().numpy()
194
 
195
- def step(self, obs: VecEnvObs) -> Step:
196
  o = self._as_tensor(obs)
 
197
  with torch.no_grad():
198
- (pi, _, _), p_fc = self._pi_forward(o)
199
  a = pi.sample()
200
  logp_a = pi.log_prob(a)
201
 
@@ -205,13 +217,21 @@ class ActorCritic(OnPolicy):
205
  clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
206
  return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
207
 
208
- def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
 
 
 
 
 
209
  if not deterministic:
210
- return self.step(obs).clamped_a
211
  else:
212
  o = self._as_tensor(obs)
 
 
 
213
  with torch.no_grad():
214
- (pi, _, _), _ = self._pi_forward(o)
215
  a = pi.mode
216
  return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
217
 
@@ -220,7 +240,10 @@ class ActorCritic(OnPolicy):
220
  self.reset_noise()
221
 
222
  def reset_noise(self, batch_size: Optional[int] = None) -> None:
223
- if isinstance(self._pi, StateDependentNoiseActorHead):
224
- self._pi.sample_weights(
225
- batch_size=batch_size if batch_size else self.env.num_envs
226
- )
 
 
 
 
1
+ from abc import abstractmethod
2
+ from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
3
+
4
  import gym
5
  import numpy as np
6
  import torch
 
 
7
  from gym.spaces import Box, Discrete, Space
 
8
 
9
+ from rl_algo_impls.shared.actor import PiForward, actor_head
10
+ from rl_algo_impls.shared.encoder import Encoder
 
 
 
 
11
  from rl_algo_impls.shared.policy.critic import CriticHead
12
  from rl_algo_impls.shared.policy.policy import ACTIVATION, Policy
13
  from rl_algo_impls.wrappers.vectorable_wrapper import (
14
  VecEnv,
15
  VecEnvObs,
 
16
  single_action_space,
17
+ single_observation_space,
18
  )
19
 
20
 
 
73
  ...
74
 
75
  @abstractmethod
76
+ def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
77
+ ...
78
+
79
+ @property
80
+ @abstractmethod
81
+ def action_shape(self) -> Tuple[int, ...]:
82
  ...
83
 
84
 
 
95
  full_std: bool = True,
96
  squash_output: bool = False,
97
  share_features_extractor: bool = True,
98
+ cnn_flatten_dim: int = 512,
99
  cnn_style: str = "nature",
100
  cnn_layers_init_orthogonal: Optional[bool] = None,
101
  impala_channels: Sequence[int] = (16, 32, 32),
102
+ actor_head_style: str = "single",
103
  **kwargs,
104
  ) -> None:
105
  super().__init__(env, **kwargs)
 
122
  self.action_space = action_space
123
  self.squash_output = squash_output
124
  self.share_features_extractor = share_features_extractor
125
+ self._feature_extractor = Encoder(
126
  observation_space,
127
  activation,
128
  init_layers_orthogonal=init_layers_orthogonal,
129
+ cnn_flatten_dim=cnn_flatten_dim,
130
  cnn_style=cnn_style,
131
  cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
132
  impala_channels=impala_channels,
133
  )
134
  self._pi = actor_head(
135
  self.action_space,
136
+ self._feature_extractor.out_dim,
137
+ tuple(pi_hidden_sizes),
138
  init_layers_orthogonal,
139
  activation,
140
  log_std_init=log_std_init,
141
  use_sde=use_sde,
142
  full_std=full_std,
143
  squash_output=squash_output,
144
+ actor_head_style=actor_head_style,
145
  )
146
 
147
  if not share_features_extractor:
148
+ self._v_feature_extractor = Encoder(
149
  observation_space,
150
  activation,
151
  init_layers_orthogonal=init_layers_orthogonal,
152
+ cnn_flatten_dim=cnn_flatten_dim,
153
  cnn_style=cnn_style,
154
  cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
155
  )
156
+ critic_in_dim = self._v_feature_extractor.out_dim
 
 
157
  else:
158
  self._v_feature_extractor = None
159
+ critic_in_dim = self._feature_extractor.out_dim
160
  self._v = CriticHead(
161
+ in_dim=critic_in_dim,
162
  hidden_sizes=v_hidden_sizes,
163
  activation=activation,
164
  init_layers_orthogonal=init_layers_orthogonal,
165
  )
166
 
167
  def _pi_forward(
168
+ self,
169
+ obs: torch.Tensor,
170
+ action_masks: Optional[torch.Tensor],
171
+ action: Optional[torch.Tensor] = None,
172
  ) -> Tuple[PiForward, torch.Tensor]:
173
  p_fe = self._feature_extractor(obs)
174
+ pi_forward = self._pi(p_fe, actions=action, action_masks=action_masks)
175
 
176
  return pi_forward, p_fe
177
 
 
179
  v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
180
  return self._v(v_fe)
181
 
182
+ def forward(
183
+ self,
184
+ obs: torch.Tensor,
185
+ action: torch.Tensor,
186
+ action_masks: Optional[torch.Tensor] = None,
187
+ ) -> ACForward:
188
+ (_, logp_a, entropy), p_fc = self._pi_forward(obs, action_masks, action=action)
189
  v = self._v_forward(obs, p_fc)
190
 
191
  assert logp_a is not None
 
203
  v = self._v(fe)
204
  return v.cpu().numpy()
205
 
206
+ def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
207
  o = self._as_tensor(obs)
208
+ a_masks = self._as_tensor(action_masks) if action_masks is not None else None
209
  with torch.no_grad():
210
+ (pi, _, _), p_fc = self._pi_forward(o, action_masks=a_masks)
211
  a = pi.sample()
212
  logp_a = pi.log_prob(a)
213
 
 
217
  clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
218
  return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
219
 
220
+ def act(
221
+ self,
222
+ obs: np.ndarray,
223
+ deterministic: bool = True,
224
+ action_masks: Optional[np.ndarray] = None,
225
+ ) -> np.ndarray:
226
  if not deterministic:
227
+ return self.step(obs, action_masks=action_masks).clamped_a
228
  else:
229
  o = self._as_tensor(obs)
230
+ a_masks = (
231
+ self._as_tensor(action_masks) if action_masks is not None else None
232
+ )
233
  with torch.no_grad():
234
+ (pi, _, _), _ = self._pi_forward(o, action_masks=a_masks)
235
  a = pi.mode
236
  return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
237
 
 
240
  self.reset_noise()
241
 
242
  def reset_noise(self, batch_size: Optional[int] = None) -> None:
243
+ self._pi.sample_weights(
244
+ batch_size=batch_size if batch_size else self.env.num_envs
245
+ )
246
+
247
+ @property
248
+ def action_shape(self) -> Tuple[int, ...]:
249
+ return self._pi.action_shape
rl_algo_impls/shared/policy/policy.py CHANGED
@@ -46,7 +46,12 @@ class Policy(nn.Module, ABC):
46
  return self
47
 
48
  @abstractmethod
49
- def act(self, obs: VecEnvObs, deterministic: bool = True) -> np.ndarray:
 
 
 
 
 
50
  ...
51
 
52
  def save(self, path: str) -> None:
 
46
  return self
47
 
48
  @abstractmethod
49
+ def act(
50
+ self,
51
+ obs: VecEnvObs,
52
+ deterministic: bool = True,
53
+ action_masks: Optional[np.ndarray] = None,
54
+ ) -> np.ndarray:
55
  ...
56
 
57
  def save(self, path: str) -> None:
rl_algo_impls/shared/schedule.py CHANGED
@@ -20,10 +20,38 @@ def constant_schedule(val: float) -> Schedule:
20
  return lambda f: val
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def schedule(name: str, start_val: float) -> Schedule:
24
  if name == "linear":
25
  return linear_schedule(start_val, 0)
26
- return constant_schedule(start_val)
 
 
 
 
 
27
 
28
 
29
  def update_learning_rate(optimizer: Optimizer, learning_rate: float) -> None:
 
20
  return lambda f: val
21
 
22
 
23
+ def spike_schedule(
24
+ max_value: float,
25
+ start_fraction: float = 1e-2,
26
+ end_fraction: float = 1e-4,
27
+ peak_progress: float = 0.1,
28
+ ) -> Schedule:
29
+ assert 0 < peak_progress < 1
30
+
31
+ def func(progress_fraction: float) -> float:
32
+ if progress_fraction < peak_progress:
33
+ fraction = (
34
+ start_fraction
35
+ + (1 - start_fraction) * progress_fraction / peak_progress
36
+ )
37
+ else:
38
+ fraction = 1 + (end_fraction - 1) * (progress_fraction - peak_progress) / (
39
+ 1 - peak_progress
40
+ )
41
+ return max_value * fraction
42
+
43
+ return func
44
+
45
+
46
  def schedule(name: str, start_val: float) -> Schedule:
47
  if name == "linear":
48
  return linear_schedule(start_val, 0)
49
+ elif name == "none":
50
+ return constant_schedule(start_val)
51
+ elif name == "spike":
52
+ return spike_schedule(start_val)
53
+ else:
54
+ raise ValueError(f"Schedule {name} not supported")
55
 
56
 
57
  def update_learning_rate(optimizer: Optimizer, learning_rate: float) -> None:
rl_algo_impls/shared/stats.py CHANGED
@@ -1,14 +1,17 @@
1
- import numpy as np
2
-
3
  from dataclasses import dataclass
 
 
 
4
  from torch.utils.tensorboard.writer import SummaryWriter
5
- from typing import Dict, List, Optional, Sequence, Union, TypeVar
6
 
7
 
8
  @dataclass
9
  class Episode:
10
  score: float = 0
11
  length: int = 0
 
12
 
13
 
14
  StatisticSelf = TypeVar("StatisticSelf", bound="Statistic")
@@ -75,12 +78,25 @@ class EpisodesStats:
75
  simple: bool
76
  score: Statistic
77
  length: Statistic
 
78
 
79
  def __init__(self, episodes: Sequence[Episode], simple: bool = False) -> None:
80
  self.episodes = episodes
81
  self.simple = simple
82
  self.score = Statistic(np.array([e.score for e in episodes]))
83
  self.length = Statistic(np.array([e.length for e in episodes]), round_digits=0)
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def __gt__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
86
  return self.score > o.score
@@ -118,6 +134,8 @@ class EpisodesStats:
118
  "length": self.length.mean,
119
  }
120
  )
 
 
121
  for name, value in stats.items():
122
  tb_writer.add_scalar(f"{main_tag}/{name}", value, global_step=global_step)
123
 
@@ -131,19 +149,19 @@ class EpisodeAccumulator:
131
  def episodes(self) -> List[Episode]:
132
  return self._episodes
133
 
134
- def step(self, reward: np.ndarray, done: np.ndarray) -> None:
135
  for idx, current in enumerate(self.current_episodes):
136
  current.score += reward[idx]
137
  current.length += 1
138
  if done[idx]:
139
  self._episodes.append(current)
140
  self.current_episodes[idx] = Episode()
141
- self.on_done(idx, current)
142
 
143
  def __len__(self) -> int:
144
  return len(self.episodes)
145
 
146
- def on_done(self, ep_idx: int, episode: Episode) -> None:
147
  pass
148
 
149
  def stats(self) -> EpisodesStats:
 
1
+ import dataclasses
2
+ from collections import defaultdict
3
  from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
5
+
6
+ import numpy as np
7
  from torch.utils.tensorboard.writer import SummaryWriter
 
8
 
9
 
10
  @dataclass
11
  class Episode:
12
  score: float = 0
13
  length: int = 0
14
+ info: Dict[str, Dict[str, Any]] = dataclasses.field(default_factory=dict)
15
 
16
 
17
  StatisticSelf = TypeVar("StatisticSelf", bound="Statistic")
 
78
  simple: bool
79
  score: Statistic
80
  length: Statistic
81
+ additional_stats: Dict[str, Statistic]
82
 
83
  def __init__(self, episodes: Sequence[Episode], simple: bool = False) -> None:
84
  self.episodes = episodes
85
  self.simple = simple
86
  self.score = Statistic(np.array([e.score for e in episodes]))
87
  self.length = Statistic(np.array([e.length for e in episodes]), round_digits=0)
88
+ additional_values = defaultdict(list)
89
+ for e in self.episodes:
90
+ if e.info:
91
+ for k, v in e.info.items():
92
+ if isinstance(v, dict):
93
+ for k2, v2 in v.items():
94
+ additional_values[f"{k}_{k2}"].append(v2)
95
+ else:
96
+ additional_values[k].append(v)
97
+ self.additional_stats = {
98
+ k: Statistic(np.array(values)) for k, values in additional_values.items()
99
+ }
100
 
101
  def __gt__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
102
  return self.score > o.score
 
134
  "length": self.length.mean,
135
  }
136
  )
137
+ for k, addl_stats in self.additional_stats.items():
138
+ stats[k] = addl_stats.mean
139
  for name, value in stats.items():
140
  tb_writer.add_scalar(f"{main_tag}/{name}", value, global_step=global_step)
141
 
 
149
  def episodes(self) -> List[Episode]:
150
  return self._episodes
151
 
152
+ def step(self, reward: np.ndarray, done: np.ndarray, info: List[Dict]) -> None:
153
  for idx, current in enumerate(self.current_episodes):
154
  current.score += reward[idx]
155
  current.length += 1
156
  if done[idx]:
157
  self._episodes.append(current)
158
  self.current_episodes[idx] = Episode()
159
+ self.on_done(idx, current, info[idx])
160
 
161
  def __len__(self) -> int:
162
  return len(self.episodes)
163
 
164
+ def on_done(self, ep_idx: int, episode: Episode, info: Dict) -> None:
165
  pass
166
 
167
  def stats(self) -> EpisodesStats:
rl_algo_impls/shared/vec_env/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from rl_algo_impls.shared.vec_env.make_env import make_env, make_eval_env
rl_algo_impls/shared/vec_env/make_env.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import asdict
2
+ from typing import Optional
3
+
4
+ from torch.utils.tensorboard.writer import SummaryWriter
5
+
6
+ from rl_algo_impls.runner.config import Config, EnvHyperparams
7
+ from rl_algo_impls.shared.vec_env.microrts import make_microrts_env
8
+ from rl_algo_impls.shared.vec_env.procgen import make_procgen_env
9
+ from rl_algo_impls.shared.vec_env.vec_env import make_vec_env
10
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
11
+
12
+
13
+ def make_env(
14
+ config: Config,
15
+ hparams: EnvHyperparams,
16
+ training: bool = True,
17
+ render: bool = False,
18
+ normalize_load_path: Optional[str] = None,
19
+ tb_writer: Optional[SummaryWriter] = None,
20
+ ) -> VecEnv:
21
+ if hparams.env_type == "procgen":
22
+ return make_procgen_env(
23
+ config,
24
+ hparams,
25
+ training=training,
26
+ render=render,
27
+ normalize_load_path=normalize_load_path,
28
+ tb_writer=tb_writer,
29
+ )
30
+ elif hparams.env_type in {"sb3vec", "gymvec"}:
31
+ return make_vec_env(
32
+ config,
33
+ hparams,
34
+ training=training,
35
+ render=render,
36
+ normalize_load_path=normalize_load_path,
37
+ tb_writer=tb_writer,
38
+ )
39
+ elif hparams.env_type == "microrts":
40
+ return make_microrts_env(
41
+ config,
42
+ hparams,
43
+ training=training,
44
+ render=render,
45
+ normalize_load_path=normalize_load_path,
46
+ tb_writer=tb_writer,
47
+ )
48
+ else:
49
+ raise ValueError(f"env_type {hparams.env_type} not supported")
50
+
51
+
52
+ def make_eval_env(
53
+ config: Config,
54
+ hparams: EnvHyperparams,
55
+ override_n_envs: Optional[int] = None,
56
+ **kwargs,
57
+ ) -> VecEnv:
58
+ kwargs = kwargs.copy()
59
+ kwargs["training"] = False
60
+ if override_n_envs is not None:
61
+ hparams_kwargs = asdict(hparams)
62
+ hparams_kwargs["n_envs"] = override_n_envs
63
+ if override_n_envs == 1:
64
+ hparams_kwargs["vec_env_class"] = "sync"
65
+ hparams = EnvHyperparams(**hparams_kwargs)
66
+ return make_env(config, hparams, **kwargs)
rl_algo_impls/shared/vec_env/microrts.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import astuple
2
+ from typing import Optional
3
+
4
+ import gym
5
+ import numpy as np
6
+ from torch.utils.tensorboard.writer import SummaryWriter
7
+
8
+ from rl_algo_impls.runner.config import Config, EnvHyperparams
9
+ from rl_algo_impls.wrappers.action_mask_wrapper import MicrortsMaskWrapper
10
+ from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
11
+ from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
12
+ from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
13
+ from rl_algo_impls.wrappers.microrts_stats_recorder import MicrortsStatsRecorder
14
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
15
+
16
+
17
+ def make_microrts_env(
18
+ config: Config,
19
+ hparams: EnvHyperparams,
20
+ training: bool = True,
21
+ render: bool = False,
22
+ normalize_load_path: Optional[str] = None,
23
+ tb_writer: Optional[SummaryWriter] = None,
24
+ ) -> VecEnv:
25
+ import gym_microrts
26
+ from gym_microrts import microrts_ai
27
+
28
+ from rl_algo_impls.shared.vec_env.microrts_compat import (
29
+ MicroRTSGridModeVecEnvCompat,
30
+ )
31
+
32
+ (
33
+ _, # env_type
34
+ n_envs,
35
+ _, # frame_stack
36
+ make_kwargs,
37
+ _, # no_reward_timeout_steps
38
+ _, # no_reward_fire_steps
39
+ _, # vec_env_class
40
+ _, # normalize
41
+ _, # normalize_kwargs,
42
+ rolling_length,
43
+ _, # train_record_video
44
+ _, # video_step_interval
45
+ _, # initial_steps_to_truncate
46
+ _, # clip_atari_rewards
47
+ _, # normalize_type
48
+ _, # mask_actions
49
+ bots,
50
+ ) = astuple(hparams)
51
+
52
+ seed = config.seed(training=training)
53
+
54
+ make_kwargs = make_kwargs or {}
55
+ if "num_selfplay_envs" not in make_kwargs:
56
+ make_kwargs["num_selfplay_envs"] = 0
57
+ if "num_bot_envs" not in make_kwargs:
58
+ make_kwargs["num_bot_envs"] = n_envs - make_kwargs["num_selfplay_envs"]
59
+ if "reward_weight" in make_kwargs:
60
+ make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"])
61
+ if bots:
62
+ ai2s = []
63
+ for ai_name, n in bots.items():
64
+ for _ in range(n):
65
+ if len(ai2s) >= make_kwargs["num_bot_envs"]:
66
+ break
67
+ ai = getattr(microrts_ai, ai_name)
68
+ assert ai, f"{ai_name} not in microrts_ai"
69
+ ai2s.append(ai)
70
+ else:
71
+ ai2s = [microrts_ai.randomAI for _ in make_kwargs["num_bot_envs"]]
72
+ make_kwargs["ai2s"] = ai2s
73
+ envs = MicroRTSGridModeVecEnvCompat(**make_kwargs)
74
+ envs = HwcToChwObservation(envs)
75
+ envs = IsVectorEnv(envs)
76
+ envs = MicrortsMaskWrapper(envs)
77
+
78
+ if seed is not None:
79
+ envs.action_space.seed(seed)
80
+ envs.observation_space.seed(seed)
81
+
82
+ envs = gym.wrappers.RecordEpisodeStatistics(envs)
83
+ envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99))
84
+ if training:
85
+ assert tb_writer
86
+ envs = EpisodeStatsWriter(
87
+ envs,
88
+ tb_writer,
89
+ training=training,
90
+ rolling_length=rolling_length,
91
+ additional_keys_to_log=config.additional_keys_to_log,
92
+ )
93
+
94
+ return envs
rl_algo_impls/shared/vec_env/microrts_compat.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypeVar
2
+
3
+ import numpy as np
4
+ from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
5
+ from jpype.types import JArray, JInt
6
+
7
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvStepReturn
8
+
9
+ MicroRTSGridModeVecEnvCompatSelf = TypeVar(
10
+ "MicroRTSGridModeVecEnvCompatSelf", bound="MicroRTSGridModeVecEnvCompat"
11
+ )
12
+
13
+
14
+ class MicroRTSGridModeVecEnvCompat(MicroRTSGridModeVecEnv):
15
+ def step(self, action: np.ndarray) -> VecEnvStepReturn:
16
+ indexed_actions = np.concatenate(
17
+ [
18
+ np.expand_dims(
19
+ np.stack(
20
+ [np.arange(0, action.shape[1]) for i in range(self.num_envs)]
21
+ ),
22
+ axis=2,
23
+ ),
24
+ action,
25
+ ],
26
+ axis=2,
27
+ )
28
+ action_mask = np.array(self.vec_client.getMasks(0), dtype=np.bool8).reshape(
29
+ indexed_actions.shape[:-1] + (-1,)
30
+ )
31
+ valid_action_mask = action_mask[:, :, 0]
32
+ valid_actions_counts = valid_action_mask.sum(1)
33
+ valid_actions = indexed_actions[valid_action_mask]
34
+ valid_actions_idx = 0
35
+
36
+ all_valid_actions = []
37
+ for env_act_cnt in valid_actions_counts:
38
+ env_valid_actions = []
39
+ for _ in range(env_act_cnt):
40
+ env_valid_actions.append(JArray(JInt)(valid_actions[valid_actions_idx]))
41
+ valid_actions_idx += 1
42
+ all_valid_actions.append(JArray(JArray(JInt))(env_valid_actions))
43
+ return super().step(JArray(JArray(JArray(JInt)))(all_valid_actions)) # type: ignore
44
+
45
+ @property
46
+ def unwrapped(
47
+ self: MicroRTSGridModeVecEnvCompatSelf,
48
+ ) -> MicroRTSGridModeVecEnvCompatSelf:
49
+ return self