Commit 85ca419 (parent: a1df559) by sgoodfriend

PPO playing MicrortsAttackShapedReward-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/fb34ab86707f5e2db85e821ff7dbdc624072d640

README.md CHANGED
@@ -10,7 +10,7 @@ model-index:
   results:
   - metrics:
     - type: mean_reward
-      value: 8.11 +/- 0.27
+      value: 19.22 +/- 0.0
       name: mean_reward
     task:
       type: reinforcement-learning
@@ -23,17 +23,17 @@ model-index:
 
 This is a trained model of a **PPO** agent playing **MicrortsAttackShapedReward-v1** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
 
-All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/hz9h6o30.
+All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/z3kioih3.
 
 ## Training Results
 
-This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [587a52b](https://github.com/sgoodfriend/rl-algo-impls/tree/587a52bc38901314c7c1b5c6892acf9315796cf3). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluated them, and selected the best model from these latest evaluations (mean - std).
+This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [fb34ab8](https://github.com/sgoodfriend/rl-algo-impls/tree/fb34ab86707f5e2db85e821ff7dbdc624072d640). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluated them, and selected the best model from these latest evaluations (mean - std).
 
 | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
 |:-------|:------------------------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
-| ppo | MicrortsAttackShapedReward-v1 | 1 | 7.826 | 0.0610015 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/67g3x4i9) |
-| ppo | MicrortsAttackShapedReward-v1 | 2 | 8.10527 | 0.266247 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/usawt7bt) |
-| ppo | MicrortsAttackShapedReward-v1 | 3 | 7.7645 | 0.318334 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/g3rmjrcf) |
+| ppo | MicrortsAttackShapedReward-v1 | 1 | 19.2195 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/7x622tan) |
+| ppo | MicrortsAttackShapedReward-v1 | 2 | 19.2195 | 0 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/i8e9nqxz) |
+| ppo | MicrortsAttackShapedReward-v1 | 3 | 19.2195 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/285khfoz) |
 
 
 ### Prerequisites: Weights & Biases (WandB)
@@ -53,10 +53,10 @@ login`.
 Note: While the model state dictionary and hyperparameters are saved, the latest
 implementation could be sufficiently different to not be able to reproduce similar
 results. You might need to check out the commit the agent was trained on:
-[587a52b](https://github.com/sgoodfriend/rl-algo-impls/tree/587a52bc38901314c7c1b5c6892acf9315796cf3).
+[fb34ab8](https://github.com/sgoodfriend/rl-algo-impls/tree/fb34ab86707f5e2db85e821ff7dbdc624072d640).
 ```
 # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
-python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/usawt7bt
+python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/i8e9nqxz
 ```
 
 Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -68,7 +68,7 @@ notebook.
 
 ## Training
 If you want the highest chance to reproduce these results, you'll want to check out the
-commit the agent was trained on: [587a52b](https://github.com/sgoodfriend/rl-algo-impls/tree/587a52bc38901314c7c1b5c6892acf9315796cf3). While
+commit the agent was trained on: [fb34ab8](https://github.com/sgoodfriend/rl-algo-impls/tree/fb34ab86707f5e2db85e821ff7dbdc624072d640). While
 training is deterministic, different hardware will give different results.
 
 ```
@@ -83,7 +83,7 @@ notebook.
 
 
 ## Benchmarking (with Lambda Labs instance)
-This and other models from https://api.wandb.ai/links/sgoodfriend/hz9h6o30 were generated by running a script on a Lambda
+This and other models from https://api.wandb.ai/links/sgoodfriend/z3kioih3 were generated by running a script on a Lambda
 Labs instance. In a Lambda Labs instance terminal:
 ```
 git clone git@github.com:sgoodfriend/rl-algo-impls.git
@@ -120,6 +120,7 @@ algo_hyperparams:
 device: auto
 env: MicrortsAttackShapedReward-v1
 env_hyperparams:
+  mask_actions: true
   n_envs: 8
   vec_env_class: sync
 env_id: null
@@ -136,7 +137,7 @@ wandb_entity: null
 wandb_group: null
 wandb_project_name: rl-algo-impls-benchmarks
 wandb_tags:
-- benchmark_587a52b
-- host_192-9-151-119
+- benchmark_fb34ab8
+- host_155-248-210-13
 
 ```
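Note on the Training Results table above: the `best` column marks the run whose weights are uploaded here. A minimal sketch (hypothetical helper, not code from the repo) of the mean - std selection rule the README describes:

```
from dataclasses import dataclass
from typing import Sequence


@dataclass
class EvalResult:
    run_id: str
    reward_mean: float
    reward_std: float


def select_best(results: Sequence[EvalResult]) -> EvalResult:
    # Rank runs by their pessimistic score (mean - std) and keep the highest.
    return max(results, key=lambda r: r.reward_mean - r.reward_std)


runs = [
    EvalResult("7x622tan", 19.2195, 0.0),
    EvalResult("i8e9nqxz", 19.2195, 0.0),
    EvalResult("285khfoz", 19.2195, 0.0),
]
# All three seeds tie here (identical mean, zero std), so any of them could be
# marked best; the table happens to mark the seed 2 run.
print(select_best(runs).run_id)
```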
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
 [project]
 name = "rl_algo_impls"
-version = "0.0.6"
+version = "0.0.7"
 description = "Implementations of reinforcement learning algorithms"
 authors = [
     {name = "Scott Goodfriend", email = "[email protected]"},
replay.meta.json CHANGED
@@ -1 +1 @@
- {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "640x640", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/tmp/tmpe690pgiu/ppo-MicrortsAttackShapedReward-v1/replay.mp4"]}, "episode": {"r": 8.219544410705566, "l": 400, "t": 8.021762}}
 
+ {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "640x640", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/tmp/tmpff1d453r/ppo-MicrortsAttackShapedReward-v1/replay.mp4"]}, "episode": {"r": 19.21954345703125, "l": 24, "t": 1.094416}}
replay.mp4 CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
 
rl_algo_impls/a2c/a2c.py CHANGED
@@ -10,6 +10,7 @@ from typing import Optional, TypeVar
 
 from rl_algo_impls.shared.algorithm import Algorithm
 from rl_algo_impls.shared.callbacks.callback import Callback
+from rl_algo_impls.shared.gae import compute_advantages
 from rl_algo_impls.shared.policy.on_policy import ActorCritic
 from rl_algo_impls.shared.schedule import schedule, update_learning_rate
 from rl_algo_impls.shared.stats import log_scalars
@@ -84,12 +85,12 @@ class A2C(Algorithm):
         obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
         actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
         rewards = np.zeros(epoch_dim, dtype=np.float32)
-        episode_starts = np.zeros(epoch_dim, dtype=np.byte)
+        episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
         values = np.zeros(epoch_dim, dtype=np.float32)
         logprobs = np.zeros(epoch_dim, dtype=np.float32)
 
         next_obs = self.env.reset()
-        next_episode_starts = np.ones(step_dim, dtype=np.byte)
+        next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
 
         timesteps_elapsed = start_timesteps
         while timesteps_elapsed < start_timesteps + train_timesteps:
@@ -126,23 +127,16 @@ class A2C(Algorithm):
                 clamped_action
             )
 
-            advantages = np.zeros(epoch_dim, dtype=np.float32)
-            last_gae_lam = 0
-            for t in reversed(range(self.n_steps)):
-                if t == self.n_steps - 1:
-                    next_nonterminal = 1.0 - next_episode_starts
-                    next_value = self.policy.value(next_obs)
-                else:
-                    next_nonterminal = 1.0 - episode_starts[t + 1]
-                    next_value = values[t + 1]
-                delta = (
-                    rewards[t] + self.gamma * next_value * next_nonterminal - values[t]
-                )
-                last_gae_lam = (
-                    delta
-                    + self.gamma * self.gae_lambda * next_nonterminal * last_gae_lam
-                )
-                advantages[t] = last_gae_lam
+            advantages = compute_advantages(
+                rewards,
+                values,
+                episode_starts,
+                next_episode_starts,
+                next_obs,
+                self.policy,
+                self.gamma,
+                self.gae_lambda,
+            )
             returns = advantages + values
 
             b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
rl_algo_impls/compare_runs.py CHANGED
@@ -194,5 +194,6 @@ def compare_runs() -> None:
     df.loc["mean"] = df.mean(numeric_only=True)
     print(df.to_markdown())
 
+
 if __name__ == "__main__":
-    compare_runs()
+    compare_runs()
rl_algo_impls/huggingface_publish.py CHANGED
@@ -162,6 +162,7 @@ def publish(
         path_in_repo="",
         commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
         token=huggingface_token,
+        delete_patterns="*",
     )
     print(f"Pushed model to the hub: {repo_url}")
 
rl_algo_impls/hyperparams/ppo.yml CHANGED
@@ -218,6 +218,7 @@ _microrts: &microrts-defaults
   env_hyperparams: &microrts-env-defaults
     n_envs: 8
     vec_env_class: sync
+    mask_actions: true
   policy_hyperparams:
     <<: *atari-policy-defaults
     cnn_style: microrts
@@ -227,10 +228,23 @@ _microrts: &microrts-defaults
     clip_range_decay: none
     clip_range_vf: 0.1
 
-debug-MicrortsMining-v1:
+_no-mask-microrts: &no-mask-microrts-defaults
   <<: *microrts-defaults
+  env_hyperparams:
+    <<: *microrts-env-defaults
+    mask_actions: false
+
+MicrortsMining-v1-NoMask:
+  <<: *no-mask-microrts-defaults
   env_id: MicrortsMining-v1
-  device: cpu
+
+MicrortsAttackShapedReward-v1-NoMask:
+  <<: *no-mask-microrts-defaults
+  env_id: MicrortsAttackShapedReward-v1
+
+MicrortsRandomEnemyShapedReward3-v1-NoMask:
+  <<: *no-mask-microrts-defaults
+  env_id: MicrortsRandomEnemyShapedReward3-v1
 
 HalfCheetahBulletEnv-v0: &pybullet-defaults
   n_timesteps: !!float 2e6
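The `*-NoMask` entries added above are built from YAML anchors and merge keys. A minimal sketch (a trimmed stand-in for `ppo.yml`, not the full file) showing how the override resolves with PyYAML:

```
import yaml

# Trimmed stand-in for the microrts defaults in ppo.yml; only the keys needed
# to show that a *-NoMask entry inherits the defaults but overrides mask_actions.
doc = """
_microrts: &microrts-defaults
  env_hyperparams: &microrts-env-defaults
    n_envs: 8
    vec_env_class: sync
    mask_actions: true

_no-mask-microrts: &no-mask-microrts-defaults
  <<: *microrts-defaults
  env_hyperparams:
    <<: *microrts-env-defaults
    mask_actions: false

MicrortsAttackShapedReward-v1-NoMask:
  <<: *no-mask-microrts-defaults
  env_id: MicrortsAttackShapedReward-v1
"""

config = yaml.safe_load(doc)
# Prints False: the explicit env_hyperparams override wins over the merged default.
print(config["MicrortsAttackShapedReward-v1-NoMask"]["env_hyperparams"]["mask_actions"])
```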
rl_algo_impls/optimize.py CHANGED
@@ -194,7 +194,7 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -
     env = make_env(
         config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
     )
-    device = get_device(config.device, env)
+    device = get_device(config, env)
     policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
     algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
 
@@ -298,7 +298,7 @@ def stepwise_optimize(
         normalize_load_path=config.model_dir_path() if i > 0 else None,
         tb_writer=tb_writer,
     )
-    device = get_device(config.device, env)
+    device = get_device(config, env)
     policy = make_policy(arg.algo, env, device, **config.policy_hyperparams)
     if i > 0:
         policy.load(config.model_dir_path())
rl_algo_impls/ppo/ppo.py CHANGED
@@ -1,8 +1,9 @@
+import logging
 import numpy as np
 import torch
 import torch.nn as nn
 
-from dataclasses import asdict, dataclass, field
+from dataclasses import asdict, dataclass
 from time import perf_counter
 from torch.optim import Adam
 from torch.utils.tensorboard.writer import SummaryWriter
@@ -10,49 +11,22 @@ from typing import List, Optional, NamedTuple, TypeVar
 
 from rl_algo_impls.shared.algorithm import Algorithm
 from rl_algo_impls.shared.callbacks.callback import Callback
-from rl_algo_impls.shared.gae import compute_advantage, compute_rtg_and_advantage
+from rl_algo_impls.shared.gae import (
+    compute_advantages,
+)
 from rl_algo_impls.shared.policy.on_policy import ActorCritic
 from rl_algo_impls.shared.schedule import (
     schedule,
     update_learning_rate,
 )
-from rl_algo_impls.shared.trajectory import Trajectory, TrajectoryAccumulator
-from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
-
-
-@dataclass
-class PPOTrajectory(Trajectory):
-    logp_a: List[float] = field(default_factory=list)
-
-    def add(
-        self,
-        obs: np.ndarray,
-        act: np.ndarray,
-        next_obs: np.ndarray,
-        rew: float,
-        terminated: bool,
-        v: float,
-        logp_a: float,
-    ):
-        super().add(obs, act, next_obs, rew, terminated, v)
-        self.logp_a.append(logp_a)
-
-
-class PPOTrajectoryAccumulator(TrajectoryAccumulator):
-    def __init__(self, num_envs: int) -> None:
-        super().__init__(num_envs, PPOTrajectory)
-
-    def step(
-        self,
-        obs: VecEnvObs,
-        action: np.ndarray,
-        next_obs: VecEnvObs,
-        reward: np.ndarray,
-        done: np.ndarray,
-        val: np.ndarray,
-        logp_a: np.ndarray,
-    ) -> None:
-        super().step(obs, action, next_obs, reward, done, val, logp_a)
+from rl_algo_impls.shared.stats import log_scalars
+from rl_algo_impls.wrappers.action_mask_wrapper import ActionMaskWrapper
+from rl_algo_impls.wrappers.vectorable_wrapper import (
+    VecEnv,
+    find_wrapper,
+    single_observation_space,
+    single_action_space,
+)
 
 
 class TrainStepStats(NamedTuple):
@@ -131,11 +105,11 @@ class PPO(Algorithm):
         vf_coef: float = 0.5,
         ppo2_vf_coef_halving: bool = True,
         max_grad_norm: float = 0.5,
-        update_rtg_between_epochs: bool = False,
         sde_sample_freq: int = -1,
     ) -> None:
         super().__init__(policy, env, device, tb_writer)
         self.policy = policy
+        self.action_masker = find_wrapper(env, ActionMaskWrapper)
 
         self.gamma = gamma
         self.gae_lambda = gae_lambda
@@ -146,7 +120,13 @@ class PPO(Algorithm):
         self.clip_range_vf_schedule = None
         if clip_range_vf:
             self.clip_range_vf_schedule = schedule(clip_range_vf_decay, clip_range_vf)
+
+        if normalize_advantage:
+            assert (
+                env.num_envs * n_steps > 1 and batch_size > 1
+            ), f"Each minibatch must be larger than 1 to support normalization"
         self.normalize_advantage = normalize_advantage
+
         self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
         self.vf_coef = vf_coef
         self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
@@ -156,181 +136,235 @@ class PPO(Algorithm):
         self.n_epochs = n_epochs
         self.sde_sample_freq = sde_sample_freq
 
-        self.update_rtg_between_epochs = update_rtg_between_epochs
-
     def learn(
         self: PPOSelf,
-        total_timesteps: int,
+        train_timesteps: int,
         callback: Optional[Callback] = None,
+        total_timesteps: Optional[int] = None,
+        start_timesteps: int = 0,
     ) -> PPOSelf:
-        obs = self.env.reset()
-        ts_elapsed = 0
-        while ts_elapsed < total_timesteps:
-            start_time = perf_counter()
-            accumulator = self._collect_trajectories(obs)
-            rollout_steps = self.n_steps * self.env.num_envs
-            ts_elapsed += rollout_steps
-            progress = ts_elapsed / total_timesteps
-            train_stats = self.train(accumulator.all_trajectories, progress, ts_elapsed)
-            train_stats.write_to_tensorboard(self.tb_writer, ts_elapsed)
-            end_time = perf_counter()
-            self.tb_writer.add_scalar(
-                "train/steps_per_second",
-                rollout_steps / (end_time - start_time),
-                ts_elapsed,
-            )
-            if callback:
-                callback.on_step(timesteps_elapsed=rollout_steps)
-
-        return self
-
-    def _collect_trajectories(self, obs: VecEnvObs) -> PPOTrajectoryAccumulator:
-        self.policy.eval()
-        accumulator = PPOTrajectoryAccumulator(self.env.num_envs)
-        self.policy.reset_noise()
-        for i in range(self.n_steps):
-            if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
-                self.policy.reset_noise()
-            action, value, logp_a, clamped_action = self.policy.step(obs)
-            next_obs, reward, done, _ = self.env.step(clamped_action)
-            accumulator.step(obs, action, next_obs, reward, done, value, logp_a)
-            obs = next_obs
-        return accumulator
-
-    def train(
-        self, trajectories: List[PPOTrajectory], progress: float, timesteps_elapsed: int
-    ) -> TrainStats:
-        self.policy.train()
-        learning_rate = self.lr_schedule(progress)
-        update_learning_rate(self.optimizer, learning_rate)
-        self.tb_writer.add_scalar(
-            "charts/learning_rate",
-            self.optimizer.param_groups[0]["lr"],
-            timesteps_elapsed,
-        )
-
-        pi_clip = self.clip_range_schedule(progress)
-        self.tb_writer.add_scalar("charts/pi_clip", pi_clip, timesteps_elapsed)
-        if self.clip_range_vf_schedule:
-            v_clip = self.clip_range_vf_schedule(progress)
-            self.tb_writer.add_scalar("charts/v_clip", v_clip, timesteps_elapsed)
-        else:
-            v_clip = None
-        ent_coef = self.ent_coef_schedule(progress)
-        self.tb_writer.add_scalar("charts/ent_coef", ent_coef, timesteps_elapsed)
-
-        obs = torch.as_tensor(
-            np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
-        )
-        act = torch.as_tensor(
-            np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
-        )
-        rtg, adv = compute_rtg_and_advantage(
-            trajectories, self.policy, self.gamma, self.gae_lambda, self.device
-        )
-        orig_v = torch.as_tensor(
-            np.concatenate([np.array(t.v) for t in trajectories]), device=self.device
-        )
-        orig_logp_a = torch.as_tensor(
-            np.concatenate([np.array(t.logp_a) for t in trajectories]),
-            device=self.device,
-        )
-
-        step_stats = []
-        for _ in range(self.n_epochs):
-            step_stats.clear()
-            if self.update_rtg_between_epochs:
-                rtg, adv = compute_rtg_and_advantage(
-                    trajectories, self.policy, self.gamma, self.gae_lambda, self.device
-                )
-            else:
-                adv = compute_advantage(
-                    trajectories, self.policy, self.gamma, self.gae_lambda, self.device
-                )
-            idxs = torch.randperm(len(obs))
-            for i in range(0, len(obs), self.batch_size):
-                mb_idxs = idxs[i : i + self.batch_size]
-                mb_adv = adv[mb_idxs]
-                if self.normalize_advantage:
-                    mb_adv = (mb_adv - mb_adv.mean(-1)) / (mb_adv.std(-1) + 1e-8)
-                self.policy.reset_noise(self.batch_size)
-                step_stats.append(
-                    self._train_step(
-                        pi_clip,
-                        v_clip,
-                        ent_coef,
-                        obs[mb_idxs],
-                        act[mb_idxs],
-                        rtg[mb_idxs],
-                        mb_adv,
-                        orig_v[mb_idxs],
-                        orig_logp_a[mb_idxs],
-                    )
-                )
-
-        y_pred, y_true = orig_v.cpu().numpy(), rtg.cpu().numpy()
-        var_y = np.var(y_true).item()
-        explained_var = (
-            np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
-        )
-
-        return TrainStats(step_stats, explained_var)
-
-    def _train_step(
-        self,
-        pi_clip: float,
-        v_clip: Optional[float],
-        ent_coef: float,
-        obs: torch.Tensor,
-        act: torch.Tensor,
-        rtg: torch.Tensor,
-        adv: torch.Tensor,
-        orig_v: torch.Tensor,
-        orig_logp_a: torch.Tensor,
-    ) -> TrainStepStats:
-        logp_a, entropy, v = self.policy(obs, act)
-        logratio = logp_a - orig_logp_a
-        ratio = torch.exp(logratio)
-        clip_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
-        pi_loss = torch.maximum(-ratio * adv, -clip_ratio * adv).mean()
-
-        v_loss_unclipped = (v - rtg) ** 2
-        if v_clip:
-            v_loss_clipped = (
-                orig_v + torch.clamp(v - orig_v, -v_clip, v_clip) - rtg
-            ) ** 2
-            v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
-        else:
-            v_loss = v_loss_unclipped.mean()
-        if self.ppo2_vf_coef_halving:
-            v_loss *= 0.5
-
-        entropy_loss = -entropy.mean()
-
-        loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss
-
-        self.optimizer.zero_grad()
-        loss.backward()
-        nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
-        self.optimizer.step()
-
-        with torch.no_grad():
-            approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
-            clipped_frac = (
-                ((ratio - 1).abs() > pi_clip).float().mean().cpu().numpy().item()
-            )
-            val_clipped_frac = (
-                (((v - orig_v).abs() > v_clip).float().mean().cpu().numpy().item())
-                if v_clip
-                else 0
-            )
-
-        return TrainStepStats(
-            loss.item(),
-            pi_loss.item(),
-            v_loss.item(),
-            entropy_loss.item(),
-            approx_kl,
-            clipped_frac,
-            val_clipped_frac,
-        )
+        if total_timesteps is None:
+            total_timesteps = train_timesteps
+        assert start_timesteps + train_timesteps <= total_timesteps
+
+        epoch_dim = (self.n_steps, self.env.num_envs)
+        step_dim = (self.env.num_envs,)
+        obs_space = single_observation_space(self.env)
+        act_space = single_action_space(self.env)
+
+        next_obs = self.env.reset()
+        next_action_masks = (
+            self.action_masker.action_masks() if self.action_masker else None
+        )
+        next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
+
+        obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)  # type: ignore
+        actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)  # type: ignore
+        rewards = np.zeros(epoch_dim, dtype=np.float32)
+        episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
+        values = np.zeros(epoch_dim, dtype=np.float32)
+        logprobs = np.zeros(epoch_dim, dtype=np.float32)
+        action_masks = (
+            np.zeros(
+                (self.n_steps,) + next_action_masks.shape, dtype=next_action_masks.dtype
+            )
+            if next_action_masks is not None
+            else None
+        )
+
+        timesteps_elapsed = start_timesteps
+        while timesteps_elapsed < start_timesteps + train_timesteps:
+            start_time = perf_counter()
+
+            progress = timesteps_elapsed / total_timesteps
+            ent_coef = self.ent_coef_schedule(progress)
+            learning_rate = self.lr_schedule(progress)
+            update_learning_rate(self.optimizer, learning_rate)
+            pi_clip = self.clip_range_schedule(progress)
+            chart_scalars = {
+                "learning_rate": self.optimizer.param_groups[0]["lr"],
+                "ent_coef": ent_coef,
+                "pi_clip": pi_clip,
+            }
+            if self.clip_range_vf_schedule:
+                v_clip = self.clip_range_vf_schedule(progress)
+                chart_scalars["v_clip"] = v_clip
+            else:
+                v_clip = None
+            log_scalars(self.tb_writer, "charts", chart_scalars, timesteps_elapsed)
+
+            self.policy.eval()
+            self.policy.reset_noise()
+            for s in range(self.n_steps):
+                timesteps_elapsed += self.env.num_envs
+                if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
+                    self.policy.reset_noise()
+
+                obs[s] = next_obs
+                episode_starts[s] = next_episode_starts
+                if action_masks is not None:
+                    action_masks[s] = next_action_masks
+
+                (
+                    actions[s],
+                    values[s],
+                    logprobs[s],
+                    clamped_action,
+                ) = self.policy.step(next_obs, action_masks=next_action_masks)
+                next_obs, rewards[s], next_episode_starts, _ = self.env.step(
+                    clamped_action
+                )
+                next_action_masks = (
+                    self.action_masker.action_masks() if self.action_masker else None
+                )
+
+            self.policy.train()
+
+            advantages = compute_advantages(
+                rewards,
+                values,
+                episode_starts,
+                next_episode_starts,
+                next_obs,
+                self.policy,
+                self.gamma,
+                self.gae_lambda,
+            )
+            returns = advantages + values
+
+            b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)  # type: ignore
+            b_actions = torch.tensor(actions.reshape((-1,) + act_space.shape)).to(  # type: ignore
+                self.device
+            )
+            b_logprobs = torch.tensor(logprobs.reshape(-1)).to(self.device)
+            b_action_masks = (
+                torch.tensor(action_masks.reshape((-1,) + next_action_masks.shape[1:])).to(  # type: ignore
+                    self.device
+                )
+                if action_masks is not None
+                else None
+            )
+
+            b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
+
+            y_pred = values.reshape(-1)
+            b_values = torch.tensor(y_pred).to(self.device)
+            y_true = returns.reshape(-1)
+            b_returns = torch.tensor(y_true).to(self.device)
+
+            step_stats = []
+            for _ in range(self.n_epochs):
+                b_idxs = torch.randperm(len(b_obs))
+                # Only record last epoch's stats
+                step_stats.clear()
+                for i in range(0, len(b_obs), self.batch_size):
+                    self.policy.reset_noise(self.batch_size)
+
+                    mb_idxs = b_idxs[i : i + self.batch_size]
+
+                    mb_obs = b_obs[mb_idxs]
+                    mb_actions = b_actions[mb_idxs]
+                    mb_values = b_values[mb_idxs]
+                    mb_logprobs = b_logprobs[mb_idxs]
+                    mb_action_masks = (
+                        b_action_masks[mb_idxs] if b_action_masks is not None else None
+                    )
+
+                    mb_adv = b_advantages[mb_idxs]
+                    if self.normalize_advantage:
+                        mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)
+                    mb_returns = b_returns[mb_idxs]
+
+                    new_logprobs, entropy, new_values = self.policy(
+                        mb_obs, mb_actions, action_masks=mb_action_masks
+                    )
+
+                    logratio = new_logprobs - mb_logprobs
+                    ratio = torch.exp(logratio)
+                    clipped_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
+                    pi_loss = torch.max(
+                        -ratio * mb_adv, -clipped_ratio * mb_adv
+                    ).mean()
+
+                    v_loss_unclipped = (new_values - mb_returns) ** 2
+                    if v_clip:
+                        v_loss_clipped = (
+                            mb_values
+                            + torch.clamp(new_values - mb_values, -v_clip, v_clip)
+                            - mb_returns
+                        ) ** 2
+                        v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
+                    else:
+                        v_loss = v_loss_unclipped.mean()
+
+                    if self.ppo2_vf_coef_halving:
+                        v_loss *= 0.5
+
+                    entropy_loss = -entropy.mean()
+
+                    loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss
+
+                    self.optimizer.zero_grad()
+                    loss.backward()
+                    nn.utils.clip_grad_norm_(
+                        self.policy.parameters(), self.max_grad_norm
+                    )
+                    self.optimizer.step()
+
+                    with torch.no_grad():
+                        approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
+                        clipped_frac = (
+                            ((ratio - 1).abs() > pi_clip)
+                            .float()
+                            .mean()
+                            .cpu()
+                            .numpy()
+                            .item()
+                        )
+                        val_clipped_frac = (
+                            ((new_values - mb_values).abs() > v_clip)
+                            .float()
+                            .mean()
+                            .cpu()
+                            .numpy()
+                            .item()
+                            if v_clip
+                            else 0
+                        )
+
+                    step_stats.append(
+                        TrainStepStats(
+                            loss.item(),
+                            pi_loss.item(),
+                            v_loss.item(),
+                            entropy_loss.item(),
+                            approx_kl,
+                            clipped_frac,
+                            val_clipped_frac,
+                        )
+                    )
+
+            var_y = np.var(y_true).item()
+            explained_var = (
+                np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
+            )
+            TrainStats(step_stats, explained_var).write_to_tensorboard(
+                self.tb_writer, timesteps_elapsed
+            )
+
+            end_time = perf_counter()
+            rollout_steps = self.n_steps * self.env.num_envs
+            self.tb_writer.add_scalar(
+                "train/steps_per_second",
+                rollout_steps / (end_time - start_time),
+                timesteps_elapsed,
+            )
+
+            if callback:
+                if not callback.on_step(timesteps_elapsed=rollout_steps):
+                    logging.info(
+                        f"Callback terminated training at {timesteps_elapsed} timesteps"
+                    )
+                    break
+
+        return self
@@ -36,7 +36,7 @@ class RunArgs:
36
 
37
  @dataclass
38
  class EnvHyperparams:
39
- env_type: str = "sb3vec"
40
  n_envs: int = 1
41
  frame_stack: int = 1
42
  make_kwargs: Optional[Dict[str, Any]] = None
@@ -50,7 +50,8 @@ class EnvHyperparams:
50
  video_step_interval: Union[int, float] = 1_000_000
51
  initial_steps_to_truncate: Optional[int] = None
52
  clip_atari_rewards: bool = True
53
- normalize_type: Optional[str] = "gymlike"
 
54
 
55
 
56
  HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
 
36
 
37
  @dataclass
38
  class EnvHyperparams:
39
+ env_type: str = "gymvec"
40
  n_envs: int = 1
41
  frame_stack: int = 1
42
  make_kwargs: Optional[Dict[str, Any]] = None
 
50
  video_step_interval: Union[int, float] = 1_000_000
51
  initial_steps_to_truncate: Optional[int] = None
52
  clip_atari_rewards: bool = True
53
+ normalize_type: Optional[str] = None
54
+ mask_actions: bool = False
55
 
56
 
57
  HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
rl_algo_impls/runner/env.py CHANGED
@@ -20,6 +20,7 @@ from typing import Callable, Optional
 
 from rl_algo_impls.runner.config import Config, EnvHyperparams
 from rl_algo_impls.shared.policy.policy import VEC_NORMALIZE_FILENAME
+from rl_algo_impls.wrappers.action_mask_wrapper import ActionMaskWrapper
 from rl_algo_impls.wrappers.atari_wrappers import (
     EpisodicLifeEnv,
     FireOnLifeStarttEnv,
@@ -113,21 +114,20 @@ def _make_vec_env(
         initial_steps_to_truncate,
         clip_atari_rewards,
         normalize_type,
+        mask_actions,
     ) = astuple(hparams)
 
     import_for_env_id(config.env_id)
 
-    spec = gym.spec(config.env_id)
     seed = config.seed(training=training)
 
     make_kwargs = make_kwargs.copy() if make_kwargs is not None else {}
-    if "BulletEnv" in config.env_id and render:
+    if is_bullet_env(config) and render:
         make_kwargs["render"] = True
-    if "CarRacing" in config.env_id:
+    if is_car_racing(config):
        make_kwargs["verbose"] = 0
-    if "procgen" in config.env_id:
-        if not render:
-            make_kwargs["render_mode"] = "rgb_array"
+    if is_gym_procgen(config) and not render:
+        make_kwargs["render_mode"] = "rgb_array"
 
     def make(idx: int) -> Callable[[], gym.Env]:
         def _make() -> gym.Env:
@@ -145,7 +145,7 @@ def _make_vec_env(
                 env = InitialStepTruncateWrapper(
                     env, idx * initial_steps_to_truncate // n_envs
                 )
-            if "AtariEnv" in spec.entry_point:  # type: ignore
+            if is_atari(config):  # type: ignore
                 env = NoopResetEnv(env, noop_max=30)
                 env = MaxAndSkipEnv(env, skip=4)
                 env = EpisodicLifeEnv(env, training=training)
@@ -157,17 +157,17 @@ def _make_vec_env(
                 env = ResizeObservation(env, (84, 84))
                 env = GrayScaleObservation(env, keep_dim=False)
                 env = FrameStack(env, frame_stack)
-            elif "CarRacing" in config.env_id:
+            elif is_car_racing(config):
                 env = ResizeObservation(env, (64, 64))
                 env = GrayScaleObservation(env, keep_dim=False)
                 env = FrameStack(env, frame_stack)
-            elif "procgen" in config.env_id:
+            elif is_gym_procgen(config):
                 # env = GrayScaleObservation(env, keep_dim=False)
                 env = NoopEnvSeed(env)
                 env = HwcToChwObservation(env)
                 if frame_stack > 1:
                     env = FrameStack(env, frame_stack)
-            elif "Microrts" in config.env_id:
+            elif is_microrts(config):
                 env = HwcToChwObservation(env)
 
             if no_reward_timeout_steps:
@@ -195,6 +195,8 @@ def _make_vec_env(
         envs = SyncVectorEnvRenderCompat(envs)
     if env_type == "sb3vec":
         envs = IsVectorEnv(envs)
+    if mask_actions:
+        envs = ActionMaskWrapper(envs)
     if training:
         assert tb_writer
         envs = EpisodeStatsWriter(
@@ -262,6 +264,8 @@ def _make_procgen_env(
         _,  # video_step_interval
         _,  # initial_steps_to_truncate
         _,  # clip_atari_rewards
+        _,  # normalize_type
+        _,  # mask_actions
     ) = astuple(hparams)
 
     seed = config.seed(training=training)
@@ -307,3 +311,24 @@ def import_for_env_id(env_id: str) -> None:
         import pybullet_envs
     if "Microrts" in env_id:
         import gym_microrts
+
+
+def is_atari(config: Config) -> bool:
+    spec = gym.spec(config.env_id)
+    return "AtariEnv" in str(spec.entry_point)
+
+
+def is_bullet_env(config: Config) -> bool:
+    return "BulletEnv" in config.env_id
+
+
+def is_car_racing(config: Config) -> bool:
+    return "CarRacing" in config.env_id
+
+
+def is_gym_procgen(config: Config) -> bool:
+    return "procgen" in config.env_id
+
+
+def is_microrts(config: Config) -> bool:
+    return "Microrts" in config.env_id
rl_algo_impls/runner/evaluate.py CHANGED
@@ -75,7 +75,7 @@ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
         render=args.render,
         normalize_load_path=model_path,
     )
-    device = get_device(config.device, env)
+    device = get_device(config, env)
     policy = make_policy(
         args.algo,
         env,
rl_algo_impls/runner/running_utils.py CHANGED
@@ -15,8 +15,8 @@ from pathlib import Path
 from torch.utils.tensorboard.writer import SummaryWriter
 from typing import Dict, Optional, Type, Union
 
-from rl_algo_impls.runner.config import Hyperparams
-from rl_algo_impls.runner.env import import_for_env_id
+from rl_algo_impls.runner.config import Config, Hyperparams
+from rl_algo_impls.runner.env import import_for_env_id, is_microrts
 from rl_algo_impls.shared.algorithm import Algorithm
 from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
 from rl_algo_impls.shared.policy.on_policy import ActorCritic
@@ -93,7 +93,8 @@ def load_hyperparams(algo: str, env_id: str) -> Hyperparams:
     raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
 
 
-def get_device(device: str, env: VecEnv) -> torch.device:
+def get_device(config: Config, env: VecEnv) -> torch.device:
+    device = config.device
     # cuda by default
     if device == "auto":
         device = "cuda"
@@ -111,6 +112,8 @@ def get_device(device: str, env: VecEnv) -> torch.device:
         device = "cpu"
     elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
         device = "cpu"
+    if is_microrts(config):
+        device = "cpu"
     print(f"Device: {device}")
     return torch.device(device)
 
rl_algo_impls/runner/train.py CHANGED
@@ -65,7 +65,7 @@ def train(args: TrainArgs):
     env = make_env(
         config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
     )
-    device = get_device(config.device, env)
+    device = get_device(config, env)
     policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
     algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
 
rl_algo_impls/shared/callbacks/eval_callback.py CHANGED
@@ -9,8 +9,9 @@ from typing import List, Optional, Union
 from rl_algo_impls.shared.callbacks.callback import Callback
 from rl_algo_impls.shared.policy.policy import Policy
 from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
+from rl_algo_impls.wrappers.action_mask_wrapper import ActionMaskWrapper
 from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
-from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, find_wrapper
 
 
 class EvaluateAccumulator(EpisodeAccumulator):
@@ -83,8 +84,13 @@ def evaluate(
     )
 
     obs = env.reset()
+    action_masker = find_wrapper(env, ActionMaskWrapper)
     while not episodes.is_done():
-        act = policy.act(obs, deterministic=deterministic)
+        act = policy.act(
+            obs,
+            deterministic=deterministic,
+            action_masks=action_masker.action_masks() if action_masker else None,
+        )
         obs, rew, done, _ = env.step(act)
         episodes.step(rew, done)
         if render:
rl_algo_impls/shared/gae.py CHANGED
@@ -5,6 +5,7 @@ from typing import NamedTuple, Sequence
 
 from rl_algo_impls.shared.policy.on_policy import OnPolicy
 from rl_algo_impls.shared.trajectory import Trajectory
+from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvObs
 
 
 class RtgAdvantage(NamedTuple):
@@ -19,7 +20,7 @@ def discounted_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
     return dc
 
 
-def compute_advantage(
+def compute_advantage_from_trajectories(
     trajectories: Sequence[Trajectory],
     policy: OnPolicy,
     gamma: float,
@@ -40,7 +41,7 @@ def compute_advantage(
     )
 
 
-def compute_rtg_and_advantage(
+def compute_rtg_and_advantage_from_trajectories(
     trajectories: Sequence[Trajectory],
     policy: OnPolicy,
     gamma: float,
@@ -65,3 +66,29 @@ def compute_rtg_and_advantage(
         ),
         torch.as_tensor(np.concatenate(advantages), dtype=torch.float32, device=device),
     )
+
+
+def compute_advantages(
+    rewards: np.ndarray,
+    values: np.ndarray,
+    episode_starts: np.ndarray,
+    next_episode_starts: np.ndarray,
+    next_obs: VecEnvObs,
+    policy: OnPolicy,
+    gamma: float,
+    gae_lambda: float,
+) -> np.ndarray:
+    advantages = np.zeros_like(rewards)
+    last_gae_lam = 0
+    n_steps = advantages.shape[0]
+    for t in reversed(range(n_steps)):
+        if t == n_steps - 1:
+            next_nonterminal = 1.0 - next_episode_starts
+            next_value = policy.value(next_obs)
+        else:
+            next_nonterminal = 1.0 - episode_starts[t + 1]
+            next_value = values[t + 1]
+        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
+        last_gae_lam = delta + gamma * gae_lambda * next_nonterminal * last_gae_lam
+        advantages[t] = last_gae_lam
+    return advantages
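For reference, `compute_advantages` above (now shared by A2C and PPO) is the standard GAE(lambda) recursion over the rollout buffer, bootstrapping the final step with `policy.value(next_obs)`:

```
\delta_t = r_t + \gamma\,(1 - d_{t+1})\,V(s_{t+1}) - V(s_t),
\qquad \hat{A}_t = \delta_t + \gamma\lambda\,(1 - d_{t+1})\,\hat{A}_{t+1}
```

where d_{t+1} is 1 when the next step starts a new episode (`next_episode_starts` at the last step, `episode_starts[t + 1]` otherwise), so advantages do not leak across episode boundaries.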
rl_algo_impls/shared/policy/actor.py CHANGED
@@ -6,8 +6,8 @@ import torch.nn as nn
 from abc import ABC, abstractmethod
 from gym.spaces import Box, Discrete, MultiDiscrete
 from numpy.typing import NDArray
-from torch.distributions import Categorical, Distribution, Normal
-from typing import NamedTuple, Optional, Sequence, Type, TypeVar, Union
+from torch.distributions import Categorical, Distribution, Normal, constraints
+from typing import Dict, NamedTuple, Optional, Sequence, Type, TypeVar, Union
 
 from rl_algo_impls.shared.module.module import mlp
 
@@ -20,7 +20,12 @@ class PiForward(NamedTuple):
 
 class Actor(nn.Module, ABC):
     @abstractmethod
-    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+    def forward(
+        self,
+        obs: torch.Tensor,
+        actions: Optional[torch.Tensor] = None,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> PiForward:
         ...
 
 
@@ -41,34 +46,53 @@ class CategoricalActorHead(Actor):
             final_layer_gain=0.01,
         )
 
-    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+    def forward(
+        self,
+        obs: torch.Tensor,
+        actions: Optional[torch.Tensor] = None,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> PiForward:
         logits = self._fc(obs)
-        pi = Categorical(logits=logits)
+        pi = MaskedCategorical(logits=logits, mask=action_masks)
         logp_a = None
         entropy = None
-        if a is not None:
-            logp_a = pi.log_prob(a)
+        if actions is not None:
+            logp_a = pi.log_prob(actions)
             entropy = pi.entropy()
         return PiForward(pi, logp_a, entropy)
 
 
-class MultiCategorical(Categorical):
+class MultiCategorical(Distribution):
     def __init__(
-        self, nvec: NDArray[np.int64], probs=None, logits=None, validate_args=None
+        self,
+        nvec: NDArray[np.int64],
+        probs=None,
+        logits=None,
+        validate_args=None,
+        masks: Optional[torch.Tensor] = None,
     ):
         # Either probs or logits should be set
-        assert (probs is not None) != (logits is not None)
+        assert (probs is None) != (logits is None)
+        masks_split = (
+            torch.split(masks, nvec.tolist(), dim=1)
+            if masks is not None
+            else [None] * len(nvec)
+        )
        if probs:
             self.dists = [
-                Categorical(probs=p, validate_args=validate_args)
-                for p in torch.split(probs, nvec.tolist(), dim=1)
+                MaskedCategorical(probs=p, validate_args=validate_args, mask=m)
+                for p, m in zip(torch.split(probs, nvec.tolist(), dim=1), masks_split)
             ]
+            param = probs
         else:
             assert logits is not None
             self.dists = [
-                Categorical(logits=lg, validate_args=validate_args)
-                for lg in torch.split(logits, nvec.tolist(), dim=1)
+                MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
+                for lg, m in zip(torch.split(logits, nvec.tolist(), dim=1), masks_split)
             ]
+            param = logits
+        batch_shape = param.size()[:-1] if param.ndimension() > 1 else torch.Size()
+        super().__init__(batch_shape=batch_shape, validate_args=validate_args)
 
     def log_prob(self, action: torch.Tensor) -> torch.Tensor:
         prob_stack = torch.stack(
@@ -82,6 +106,34 @@ class MultiCategorical(Categorical):
     def sample(self, sample_shape: torch.Size = torch.Size()):
         return torch.stack([c.sample(sample_shape) for c in self.dists], dim=-1)
 
+    @property
+    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
+        # Constraints handled by child distributions in dist
+        return {}
+
+
+class MaskedCategorical(Categorical):
+    def __init__(
+        self,
+        probs=None,
+        logits=None,
+        validate_args=None,
+        mask: Optional[torch.Tensor] = None,
+    ):
+        if mask is not None:
+            assert logits is not None, "mask requires logits and not probs"
+            logits = torch.where(mask, logits, -1e8)
+        self.mask = mask
+        super().__init__(probs, logits, validate_args)
+
+    def entropy(self) -> torch.Tensor:
+        if self.mask is None:
+            return super().entropy()
+        # If mask set, then use approximation for entropy
+        p_log_p = self.logits * self.probs
+        masked = torch.where(self.mask, p_log_p, 0)
+        return -masked.sum(-1)
+
 
 class MultiDiscreteActorHead(Actor):
     def __init__(
@@ -101,13 +153,18 @@ class MultiDiscreteActorHead(Actor):
             final_layer_gain=0.01,
         )
 
-    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+    def forward(
+        self,
+        obs: torch.Tensor,
+        actions: Optional[torch.Tensor] = None,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> PiForward:
         logits = self._fc(obs)
-        pi = MultiCategorical(self.nvec, logits=logits)
+        pi = MultiCategorical(self.nvec, logits=logits, masks=action_masks)
         logp_a = None
         entropy = None
-        if a is not None:
-            logp_a = pi.log_prob(a)
+        if actions is not None:
+            logp_a = pi.log_prob(actions)
             entropy = pi.entropy()
         return PiForward(pi, logp_a, entropy)
 
@@ -146,12 +203,20 @@ class GaussianActorHead(Actor):
         std = torch.exp(self.log_std)
         return GaussianDistribution(mu, std)
 
-    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+    def forward(
+        self,
+        obs: torch.Tensor,
+        actions: Optional[torch.Tensor] = None,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> PiForward:
+        assert (
+            not action_masks
+        ), f"{self.__class__.__name__} does not support action_masks"
         pi = self._distribution(obs)
         logp_a = None
         entropy = None
-        if a is not None:
-            logp_a = pi.log_prob(a)
+        if actions is not None:
+            logp_a = pi.log_prob(actions)
             entropy = pi.entropy()
         return PiForward(pi, logp_a, entropy)
 
@@ -311,12 +376,20 @@ class StateDependentNoiseActorHead(Actor):
             ones = ones.to(self.device)
         return ones * std
 
-    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+    def forward(
+        self,
+        obs: torch.Tensor,
+        actions: Optional[torch.Tensor] = None,
+        action_masks: Optional[torch.Tensor] = None,
+    ) -> PiForward:
+        assert (
+            not action_masks
+        ), f"{self.__class__.__name__} does not support action_masks"
        pi = self._distribution(obs)
         logp_a = None
         entropy = None
-        if a is not None:
-            logp_a = pi.log_prob(a)
+        if actions is not None:
+            logp_a = pi.log_prob(actions)
             entropy = -logp_a if self.bijector else sum_independent_dims(pi.entropy())
         return PiForward(pi, logp_a, entropy)
 
 
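For orientation, the following standalone sketch (not part of the commit; the logits and mask values are invented) shows the idea behind the `MaskedCategorical` added above: logits of illegal actions are pushed to a large negative value, so sampling only returns legal actions and the entropy approximation sums only over unmasked entries.

```python
import torch

# Hypothetical single observation with 4 discrete actions; only 0 and 2 are legal.
logits = torch.tensor([[1.0, 2.0, 0.5, -1.0]])
mask = torch.tensor([[True, False, True, False]])

# Mirrors MaskedCategorical.__init__ above: illegal logits are replaced with -1e8.
masked_logits = torch.where(mask, logits, torch.tensor(-1e8))
pi = torch.distributions.Categorical(logits=masked_logits)

print(pi.probs)         # probability mass only on actions 0 and 2
print(pi.sample((5,)))  # sampled actions are always legal

# Entropy approximation from the diff: -sum(p * log p) over unmasked entries only.
p_log_p = pi.logits * pi.probs
print(-torch.where(mask, p_log_p, torch.zeros(())).sum(-1))
```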
rl_algo_impls/shared/policy/on_policy.py CHANGED
@@ -77,7 +77,7 @@ class OnPolicy(Policy):
  ...

  @abstractmethod
- def step(self, obs: VecEnvObs) -> Step:
  ...


@@ -162,10 +162,13 @@ class ActorCritic(OnPolicy):
  )

  def _pi_forward(
- self, obs: torch.Tensor, action: Optional[torch.Tensor] = None
  ) -> Tuple[PiForward, torch.Tensor]:
  p_fe = self._feature_extractor(obs)
- pi_forward = self._pi(p_fe, action)

  return pi_forward, p_fe

@@ -173,8 +176,13 @@ class ActorCritic(OnPolicy):
  v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
  return self._v(v_fe)

- def forward(self, obs: torch.Tensor, action: torch.Tensor) -> ACForward:
- (_, logp_a, entropy), p_fc = self._pi_forward(obs, action)
  v = self._v_forward(obs, p_fc)

  assert logp_a is not None
@@ -192,10 +200,11 @@ class ActorCritic(OnPolicy):
  v = self._v(fe)
  return v.cpu().numpy()

- def step(self, obs: VecEnvObs) -> Step:
  o = self._as_tensor(obs)
  with torch.no_grad():
- (pi, _, _), p_fc = self._pi_forward(o)
  a = pi.sample()
  logp_a = pi.log_prob(a)

@@ -205,13 +214,21 @@ class ActorCritic(OnPolicy):
  clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
  return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)

- def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
  if not deterministic:
- return self.step(obs).clamped_a
  else:
  o = self._as_tensor(obs)
  with torch.no_grad():
- (pi, _, _), _ = self._pi_forward(o)
  a = pi.mode
  return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)


  ...

  @abstractmethod
+ def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
  ...


  )

  def _pi_forward(
+ self,
+ obs: torch.Tensor,
+ action_masks: Optional[torch.Tensor],
+ action: Optional[torch.Tensor] = None,
  ) -> Tuple[PiForward, torch.Tensor]:
  p_fe = self._feature_extractor(obs)
+ pi_forward = self._pi(p_fe, actions=action, action_masks=action_masks)

  return pi_forward, p_fe

  v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
  return self._v(v_fe)

+ def forward(
+ self,
+ obs: torch.Tensor,
+ action: torch.Tensor,
+ action_masks: Optional[torch.Tensor] = None,
+ ) -> ACForward:
+ (_, logp_a, entropy), p_fc = self._pi_forward(obs, action_masks, action=action)
  v = self._v_forward(obs, p_fc)

  assert logp_a is not None

  v = self._v(fe)
  return v.cpu().numpy()

+ def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
  o = self._as_tensor(obs)
+ a_masks = self._as_tensor(action_masks) if action_masks is not None else None
  with torch.no_grad():
+ (pi, _, _), p_fc = self._pi_forward(o, action_masks=a_masks)
  a = pi.sample()
  logp_a = pi.log_prob(a)

  clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
  return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)

+ def act(
+ self,
+ obs: np.ndarray,
+ deterministic: bool = True,
+ action_masks: Optional[np.ndarray] = None,
+ ) -> np.ndarray:
  if not deterministic:
+ return self.step(obs, action_masks=action_masks).clamped_a
  else:
  o = self._as_tensor(obs)
+ a_masks = (
+ self._as_tensor(action_masks) if action_masks is not None else None
+ )
  with torch.no_grad():
+ (pi, _, _), _ = self._pi_forward(o, action_masks=a_masks)
  a = pi.mode
  return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)

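The sketch below (illustrative only; the `MultiDiscrete([6, 4])` space, env count, and mask values are invented) shows the shape of the `action_masks` argument that `step` and `act` now accept, and how `MultiCategorical` splits it per sub-action:

```python
import numpy as np
import torch

# Hypothetical MultiDiscrete([6, 4]) action space across 8 vectorized envs.
nvec = np.array([6, 4])
n_envs = 8
action_masks = np.ones((n_envs, nvec.sum()), dtype=np.bool_)
action_masks[:, 5] = False  # last option of the first sub-action is illegal everywhere

# step()/act() convert the numpy masks to a tensor ...
a_masks = torch.as_tensor(action_masks)
# ... and MultiCategorical splits them per sub-action before building MaskedCategoricals.
per_subaction = torch.split(a_masks, nvec.tolist(), dim=1)
print([m.shape for m in per_subaction])  # [torch.Size([8, 6]), torch.Size([8, 4])]
```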
rl_algo_impls/shared/policy/policy.py CHANGED
@@ -46,7 +46,12 @@ class Policy(nn.Module, ABC):
  return self

  @abstractmethod
- def act(self, obs: VecEnvObs, deterministic: bool = True) -> np.ndarray:
  ...

  def save(self, path: str) -> None:

  return self

  @abstractmethod
+ def act(
+ self,
+ obs: VecEnvObs,
+ deterministic: bool = True,
+ action_masks: Optional[np.ndarray] = None,
+ ) -> np.ndarray:
  ...

  def save(self, path: str) -> None:
rl_algo_impls/vpg/vpg.py CHANGED
@@ -10,7 +10,7 @@ from typing import Optional, Sequence, TypeVar

  from rl_algo_impls.shared.algorithm import Algorithm
  from rl_algo_impls.shared.callbacks.callback import Callback
- from rl_algo_impls.shared.gae import compute_rtg_and_advantage, compute_advantage
  from rl_algo_impls.shared.trajectory import Trajectory, TrajectoryAccumulator
  from rl_algo_impls.vpg.policy import VPGActorCritic
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
@@ -58,7 +58,6 @@ class VanillaPolicyGradient(Algorithm):
  max_grad_norm: float = 10.0,
  n_steps: int = 4_000,
  sde_sample_freq: int = -1,
- update_rtg_between_v_iters: bool = False,
  ent_coef: float = 0.0,
  ) -> None:
  super().__init__(policy, env, device, tb_writer)
@@ -73,7 +72,6 @@ class VanillaPolicyGradient(Algorithm):
  self.n_steps = n_steps
  self.train_v_iters = train_v_iters
  self.sde_sample_freq = sde_sample_freq
- self.update_rtg_between_v_iters = update_rtg_between_v_iters

  self.ent_coef = ent_coef

@@ -118,7 +116,7 @@ class VanillaPolicyGradient(Algorithm):
  act = torch.as_tensor(
  np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
  )
- rtg, adv = compute_rtg_and_advantage(
  trajectories, self.policy, self.gamma, self.gae_lambda, self.device
  )

@@ -135,10 +133,6 @@ class VanillaPolicyGradient(Algorithm):

  v_loss = 0
  for _ in range(self.train_v_iters):
- if self.update_rtg_between_v_iters:
- rtg = compute_advantage(
- trajectories, self.policy, self.gamma, self.gae_lambda, self.device
- )
  v = self.policy.v(obs)
  v_loss = ((v - rtg) ** 2).mean()


  from rl_algo_impls.shared.algorithm import Algorithm
  from rl_algo_impls.shared.callbacks.callback import Callback
+ from rl_algo_impls.shared.gae import compute_rtg_and_advantage_from_trajectories
  from rl_algo_impls.shared.trajectory import Trajectory, TrajectoryAccumulator
  from rl_algo_impls.vpg.policy import VPGActorCritic
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv

  max_grad_norm: float = 10.0,
  n_steps: int = 4_000,
  sde_sample_freq: int = -1,
  ent_coef: float = 0.0,
  ) -> None:
  super().__init__(policy, env, device, tb_writer)

  self.n_steps = n_steps
  self.train_v_iters = train_v_iters
  self.sde_sample_freq = sde_sample_freq

  self.ent_coef = ent_coef

  act = torch.as_tensor(
  np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
  )
+ rtg, adv = compute_rtg_and_advantage_from_trajectories(
  trajectories, self.policy, self.gamma, self.gae_lambda, self.device
  )


  v_loss = 0
  for _ in range(self.train_v_iters):
  v = self.policy.v(obs)
  v_loss = ((v - rtg) ** 2).mean()
rl_algo_impls/wrappers/action_mask_wrapper.py ADDED
@@ -0,0 +1,22 @@
+ import numpy as np
+
+ from gym.vector.vector_env import VectorEnv
+ from stable_baselines3.common.vec_env import VecEnv as SBVecEnv
+ from typing import Optional
+
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper
+
+
+ class IncompleteArrayError(Exception):
+ pass
+
+
+ class ActionMaskWrapper(VecotarableWrapper):
+ def action_masks(self) -> Optional[np.ndarray]:
+ envs = getattr(self.env.unwrapped, "envs")
+ assert (
+ envs
+ ), f"{self.__class__.__name__} expects to wrap synchronous vectorized env"
+ masks = [getattr(e.unwrapped, "action_mask") for e in envs]
+ assert all(m is not None for m in masks)
+ return np.array(masks, dtype=np.bool8)
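A minimal sketch (with made-up stand-in sub-envs rather than gym-microrts) of the gather step `ActionMaskWrapper.action_masks` performs and how its result would feed a rollout:

```python
import numpy as np

# Stand-in sub-envs exposing an action_mask attribute, mimicking what the wrapper
# reads from each synchronous sub-env via e.unwrapped.action_mask.
class FakeSubEnv:
    def __init__(self, n_actions, legal):
        self.action_mask = np.zeros(n_actions, dtype=bool)
        self.action_mask[legal] = True

sub_envs = [FakeSubEnv(4, [0, 2]), FakeSubEnv(4, [1, 3])]

# The gather step performed by ActionMaskWrapper.action_masks:
masks = np.array([e.action_mask for e in sub_envs], dtype=bool)
print(masks.shape)  # (2, 4): one mask row per sub-env
# In a rollout these masks would be passed on, e.g. policy.act(obs, action_masks=masks).
```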
saved_models/ppo-MicrortsAttackShapedReward-v1-S2-best/model.pth CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:114140423904ae116b52975d9de49633eb55ff1c790fe97dded5c8a64f63484b
  size 294255

  version https://git-lfs.github.com/spec/v1
+ oid sha256:0eff8cfad9e5927093804e190813c112c9484c2ddbf8c97abc7713bf4a01d27c
  size 294255
scripts/benchmark.sh CHANGED
@@ -8,6 +8,7 @@
  -e) envs=$2 ;;
  --procgen) procgen=t ;;
  --microrts) microrts=t ;;
  esac
  shift
  done
@@ -58,6 +59,13 @@ for algo in $(echo $algos); do
  "MicrortsRandomEnemyShapedReward3-v1"
  )
  algo_envs=${MICRORTS_ENVS[*]}
  elif [ -z "$envs" ]; then
  if [ "$algo" = "dqn" ]; then
  BENCHMARK_ENVS="${DISCRETE_ENVS[*]}"

  -e) envs=$2 ;;
  --procgen) procgen=t ;;
  --microrts) microrts=t ;;
+ --no-mask-microrts) no_mask_microrts=t ;;
  esac
  shift
  done

  "MicrortsRandomEnemyShapedReward3-v1"
  )
  algo_envs=${MICRORTS_ENVS[*]}
+ elif [ "$no_mask_microrts" = "t" ]; then
+ NO_MASK_MICRORTS_ENVS=(
+ "MicrortsMining-v1-NoMask"
+ "MicrortsAttackShapedReward-v1-NoMask"
+ "MicrortsRandomEnemyShapedReward3-v1-NoMask"
+ )
+ algo_envs=${NO_MASK_MICRORTS_ENVS[*]}
  elif [ -z "$envs" ]; then
  if [ "$algo" = "dqn" ]; then
  BENCHMARK_ENVS="${DISCRETE_ENVS[*]}"