diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e63e24b004558acfa8c9327981ecb060b4ff0d00 --- /dev/null +++ b/.gitignore @@ -0,0 +1,161 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# .idea folder +.idea/ + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +venv/ +/ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# customize +log/ +MUJOCO_LOG.TXT +*.pth +.vscode/ +.DS_Store +*.zip +*.pstats +*.swp +*.pkl +*.hdf5 +wandb/ +videos/ + +# might be needed for IDE plugins that can't read ruff config +.flake8 + +docs/notebooks/_build/ +docs/conf.py + +# temporary scripts (for ad-hoc testing), temp folder +/temp +/temp*.py \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aa00e7474b88d50b06b53e48219a367faee29203 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,52 @@ +default_install_hook_types: [commit-msg, pre-commit] +default_stages: [commit, manual] +fail_fast: false +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-added-large-files + - repo: local + hooks: + - id: ruff + name: ruff + entry: poetry run ruff + require_serial: true + language: system + types: [python] + - id: ruff-nb + name: ruff-nb + entry: poetry run nbqa ruff . 
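+        # nbqa applies ruff to Jupyter notebooks; pass_filenames is false below, so the whole tree is linted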
+ require_serial: true + language: system + pass_filenames: false + types: [python] + - id: black + name: black + entry: poetry run black + require_serial: true + language: system + types: [python] + - id: poetry-check + name: poetry check + entry: poetry check + language: system + files: pyproject.toml + pass_filenames: false + - id: poetry-lock-check + name: poetry lock check + entry: poetry check + args: [--lock] + language: system + pass_filenames: false + - id: mypy + name: mypy + entry: poetry run mypy tianshou examples test + # filenames should not be passed as they would collide with the config in pyproject.toml + pass_filenames: false + files: '^tianshou(/[^/]*)*/[^/]*\.py$' + language: system + - id: mypy-nb + name: mypy-nb + entry: poetry run nbqa mypy + language: system diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a93e455c1a687b558065ee4d638272a6d1f7326c --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,23 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.11" + commands: + - mkdir -p $READTHEDOCS_OUTPUT/html + - curl -sSL https://install.python-poetry.org | python - +# - ~/.local/bin/poetry config virtualenvs.create false + - ~/.local/bin/poetry install --with dev +## Same as poe tasks, but unfortunately poe doesn't work with poetry not creating virtualenvs + - ~/.local/bin/poetry run python docs/autogen_rst.py + - ~/.local/bin/poetry run which jupyter-book + - ~/.local/bin/poetry run python docs/create_toc.py + - ~/.local/bin/poetry run jupyter-book config sphinx docs/ + - ~/.local/bin/poetry run sphinx-build -W -b html docs $READTHEDOCS_OUTPUT/html diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..b3fb71f67a9da7c8e63ff8780ecb849299aa2e20 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,3 @@ +# Contributing to Tianshou + +Please refer to [tianshou.readthedocs.io/en/latest/contributing.html](https://tianshou.readthedocs.io/en/latest/contributing.html). diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..322a77c33dd5800d19ad169626e446de6ad58863 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Tianshou contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..1aba38f67a2211cf5b09466d7b411206cb7223bf --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include LICENSE diff --git a/README.md b/README.md index b9d45d6fc145755d3585e2b349940e4ffe70e2de..c1a6c395215ee09664b0d8eba98c5df286ee66c8 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ sdk_version: 4.37.2 app_file: app.py pinned: false license: mit +python_version: 3.11 --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..488e7688e5d954dbe10360b07f87c9f14289351a --- /dev/null +++ b/app.py @@ -0,0 +1,440 @@ +import gradio as gr +import datetime +import os +import pprint +import sys + +import numpy as np +import torch +from examples.atari.atari_network import C51 +from examples.atari.atari_wrapper import wrap_deepmind + +from tianshou.data import Collector +from examples.atari.tianshou.policy import C51Policy + +import gymnasium as gym + +from examples.atari.tianshou.env.venvs import DummyVectorEnv + + + +from examples.atari.tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction, FullQuantileFunctionRainbow +from examples.atari.atari_network import DQN +from examples.atari.tianshou.policy import FQFPolicy,FQF_RainbowPolicy + + +def seed(self, seed): + np.random.seed(seed) + +# Define configuration parameters +config_c51 = { + "task": "PongNoFrameskip-v4", + "seed": 3128, + "scale_obs": 0, + "eps_test": 0.005, + "eps_train": 1.0, + "eps_train_final": 0.05, + "buffer_size": 100000, + "lr": 0.0001, + "gamma": 0.99, + "num_atoms": 51, + "v_min": -10.0, + "v_max": 10.0, + "n_step": 3, + "target_update_freq": 500, + "epoch": 100, + "step_per_epoch": 100000, + "step_per_collect": 10, + "update_per_step": 0.1, + "batch_size": 32, + "training_num": 1, + "test_num": 1, + "logdir": "log", + "render": 0.0, + "device": "cuda" if torch.cuda.is_available() else "cpu", + "frames_stack": 4, + "resume_path": "examples/atari/c51_pong.pth", + "resume_id": "", + "logger": "tensorboard", + "wandb_project": "atari.benchmark", + "watch": True, + "save_buffer_name": None +} + + +config_fqf = { + "task": "SpaceInvadersNoFrameskip-v4", + "seed": 3128, + "scale_obs": 0, + "eps_test": 0.005, + "eps_train": 1.0, + "eps_train_final": 0.05, + "buffer_size": 100000, + "lr": 5e-5, + "fraction_lr": 2.5e-9, + "gamma": 0.99, + "num_fractions": 32, + "num_cosines": 64, + "ent_coef": 10.0, + "hidden_sizes": [512], + "n_step": 3, + "target_update_freq": 500, + "epoch": 100, + "step_per_epoch": 100000, + "step_per_collect": 10, + "update_per_step": 0.1, + "batch_size": 32, + "training_num": 1, + "test_num": 1, + "logdir": "log", + "render": 0.0, + "device": "cuda" if torch.cuda.is_available() else "cpu", + "frames_stack": 4, + "resume_path": "fqf_pong.pth", + "resume_id": None, + "logger": "tensorboard", + "wandb_project": "atari.benchmark", + "watch": True, + "save_buffer_name": None, +} + + +config_fqf_r = { + "task": "PongNoFrameskip-v4", + "algo_name": "RainbowFQF", + "seed": 3128, + "scale_obs": 0, + "eps_test": 0.005, + "eps_train": 1.0, + "eps_train_final": 0.05, + "buffer_size": 100000, + "lr": 5e-5, + "fraction_lr": 2.5e-9, + "gamma": 0.99, + "num_fractions": 32, + "num_cosines": 64, + "ent_coef": 10.0, + "hidden_sizes": [512], + "n_step": 3, + "target_update_freq": 500, + "epoch": 100, + "step_per_epoch": 100000, + "step_per_collect": 10, + 
"update_per_step": 0.1, + "batch_size": 32, + "training_num": 1, + "test_num": 1, + "logdir": "log", + "no_dueling": False, + "no_noisy": False, + "no_priority": False, + "noisy_std": 0.1, + "alpha": 0.5, + "beta": 0.4, + "beta_final": 1.0, + "beta_anneal_step": 5000000, + "no_weight_norm": False, + "render": 0.0, + "device": "cuda" if torch.cuda.is_available() else "cpu", + "frames_stack": 4, + "resume_path": None, + "resume_id": None, + "logger": "tensorboard", + "wandb_project": "atari.benchmark", + "watch": False, + "save_buffer_name": None, + "per": False, +} + + +def test_c51(config : dict) -> None: + # _, _, test_envs,_ = make_atari_watch_env( + # config["task"], + # config["seed"], + # config["training_num"], + # config["test_num"], + # scale=config["scale_obs"], + # frame_stack=config["frames_stack"], + # ) + env_wrap = gym.make(config["task"],render_mode = 'rgb_array') + env_deep = wrap_deepmind(env_wrap) + rec_env = DummyVectorEnv( + [ + lambda: gym.wrappers.RecordVideo( + env_deep, + video_folder='video-app/' + ) + ] + ) + state_shape = env_deep.observation_space.shape or env_deep.observation_space.n + action_shape = env_deep.action_space.shape or env_deep.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", state_shape) + print("Actions shape:", action_shape) + # seed + np.random.seed(config["seed"]) + torch.manual_seed(config["seed"]) + rec_env.seed(config["seed"]) + # test_envs.seed(config["seed"]) + + + net = C51(*state_shape, action_shape, config["num_atoms"], config["device"]) + optim = torch.optim.Adam(net.parameters(), lr=config["lr"]) + # define policy + policy = C51Policy( + model=net, + optim=optim, + discount_factor=config["gamma"], + action_space=env_deep.action_space, + num_atoms=config["num_atoms"], + v_min=config["v_min"], + v_max=config["v_max"], + estimation_step=config["n_step"], + target_update_freq=config["target_update_freq"], + ).to(config["device"]) + # load a previous policy + if config["resume_path"]: + policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"])) + print("Loaded agent from:", config["resume_path"]) + + + + collector = Collector(policy, rec_env, exploration_noise=True) + # result = collector.collect(n_episode=config["test_num"], render=config["render"]) + result = collector.collect(n_episode=config["test_num"]) + # Collector(policy, rec_env, exploration_noise=True).collect(n_episode=config["test_num"]) + rec_env.close() + result.pprint_asdict() + return result + +def test_FQF(config : dict) -> None: + + # _, _, test_envs,_ = make_atari_watch_env( + # config["task"], + # config["seed"], + # config["training_num"], + # config["test_num"], + # scale=config["scale_obs"], + # frame_stack=config["frames_stack"], + # ) + + # env_wrap = gym.make(config["task"],render_mode = 'rgb_array') + env_deep = wrap_deepmind(gym.make(config["task"],render_mode = 'rgb_array')) + rec_env = DummyVectorEnv( + [ + lambda: gym.wrappers.RecordVideo( + env_deep, + video_folder='video-app/' + ) + ] + ) + + state_shape = env_deep.observation_space.shape or env_deep.observation_space.n + action_shape = env_deep.action_space.shape or env_deep.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", state_shape) + print("Actions shape:", action_shape) + # seed + np.random.seed(config["seed"]) + torch.manual_seed(config["seed"]) + rec_env.seed(config["seed"]) + feature_net = DQN(*state_shape, action_shape, config["device"], features_only=True) + + # Create FullQuantileFunction net + net = 
FullQuantileFunction( + feature_net, + action_shape, + config["hidden_sizes"], + config["num_cosines"], + ).to(config["device"]) + + # Create Adam optimizer + optim = torch.optim.Adam(net.parameters(), lr=config["lr"]) + + # Create FractionProposalNetwork + fraction_net = FractionProposalNetwork(config["num_fractions"], net.input_dim) + + # Create RMSprop optimizer for fraction_net + fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=config["fraction_lr"]) + + # Define policy using FQFPolicy + policy: FQFPolicy = FQFPolicy( + model=net, + optim=optim, + fraction_model=fraction_net, + fraction_optim=fraction_optim, + action_space=env_deep.action_space, + discount_factor=config["gamma"], + num_fractions=config["num_fractions"], + ent_coef=config["ent_coef"], + estimation_step=config["n_step"], + target_update_freq=config["target_update_freq"], + ).to(config["device"]) + + # load a previous policy + if config["resume_path"]: + policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"])) + print("Loaded agent from:", config["resume_path"]) + + + collector = Collector(policy, rec_env, exploration_noise=True) + + # result = collector.collect(n_episode=config["test_num"], render=config["render"]) + result = collector.collect(n_episode=config["test_num"]) + # Collector(policy, rec_env, exploration_noise=True).collect(n_episode=config["test_num"]) + rec_env.close() + result.pprint_asdict() + return result + + + +def test_fqf_rainbow(config: dict) -> None: + # _, _, test_envs,_ = make_atari_watch_env( + # config['task'], + # config['seed'], + # config['training_num'], + # config['test_num'], + # scale=config['scale_obs'], + # frame_stack=config['frames_stack'], + # ) + + env_deep = wrap_deepmind(gym.make(config["task"],render_mode = 'rgb_array')) + rec_env = DummyVectorEnv( + [ + lambda: gym.wrappers.RecordVideo( + env_deep, + video_folder='video-app/' + ) + ] + ) + + + config['state_shape'] = env_deep.observation_space.shape or env_deep.observation_space.n + config['action_shape'] = env_deep.action_space.shape or env_deep.action_space.n + + # print(env_deep.action_space) + # print(test_envs.action_space) + + # should be N_FRAMES x H x W + # print("Observations shape:", config['state_shape']) + # print("Actions shape:", config['action_shape']) + # seed + np.random.seed(config['seed']) + torch.manual_seed(config['seed']) + # test_envs.seed(config['seed']) + rec_env.seed(config['seed']) + # define model + feature_net = DQN(*config['state_shape'], config['action_shape'], config['device'], features_only=True) + preprocess_net_output_dim = feature_net.output_dim # Ensure this is correctly set + # print(preprocess_net_output_dim) + net = FullQuantileFunctionRainbow( + preprocess_net=feature_net, + action_shape=config['action_shape'], + hidden_sizes=config['hidden_sizes'], + num_cosines=config['num_cosines'], + preprocess_net_output_dim=preprocess_net_output_dim, + device=config['device'], + noisy_std=config['noisy_std'], + is_noisy=not config['no_noisy'], # Set to True to use noisy layers + is_dueling=not config['no_dueling'], # Set to True to use dueling layers + ).to(config['device']) + # print(net) + optim = torch.optim.Adam(net.parameters(), lr=config['lr']) + fraction_net = FractionProposalNetwork(config['num_fractions'], net.input_dim) + fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=config['fraction_lr']) + # define policy + policy: FQF_RainbowPolicy = FQF_RainbowPolicy( + model=net, + optim=optim, + fraction_model=fraction_net, + 
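+            # the fraction proposal network is trained by its own RMSprop optimizer (fraction_optim),
+            # separate from the Adam optimizer used for the quantile network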
fraction_optim=fraction_optim, + action_space=env_deep.action_space, + discount_factor=config['gamma'], + num_fractions=config['num_fractions'], + ent_coef=config['ent_coef'], + estimation_step=config['n_step'], + target_update_freq=config['target_update_freq'], + is_noisy=not config['no_noisy'] + ).to(config['device']) + # load a previous policy + if config['resume_path']: + policy.load_state_dict(torch.load(config['resume_path'], map_location=config['device'])) + print("Loaded agent from:", config['resume_path']) + # policy.eval() + test_collector = Collector(policy, rec_env, exploration_noise=True) + result = test_collector.collect(n_episode=config["test_num"]) + + #replay + # Collector(policy, rec_env, exploration_noise=True).collect(n_episode=1) + + rec_env.close() + result.pprint_asdict() + return result + + +# Define the function to display choices and mean scores +def display_choice(algo, game,slider): + # Dictionary to store mean scores for each algorithm and game + match algo: + case "C51": + match game: + case "Freeway": + config_c51["resume_path"] = "models/c51_freeway.pth" + config_c51["task"] = "FreewayNoFrameskip-v4" + mean_scores = test_c51(config_c51) + + case "Pong" : + return 19 + + case "FQF": + match game: + case "Freeway": + config_fqf["resume_path"] = "models/fqf_freeway.pth" + config_fqf["task"] = "FreewayNoFrameskip-v4" + mean_scores = test_FQF(config_fqf) + + case "Pong" : + return 20 + + case "FQF-Rainbow": + match game: + case "Freeway": + config_fqf_r["resume_path"] = "models/fqf-rainbow_freeway.pth" + config_fqf_r["task"] = "FreewayNoFrameskip-v4" + mean_scores = test_fqf_rainbow(config_fqf_r) + + case "Pong" : + return 21 + + + + # Calculate or fetch the mean score for the selected combination + mean_score = mean_scores.returns_stat.mean + + # Return the selected options and the mean score + # return f"Your {algo} agent finished {game} with a \nMean Score of ##{mean_score}" + return [mean_score,"video-app/rl-video-episode-0.mp4"] + +# Define the choices for the radio buttons +algos = ["C51", "FQF", "FQF-Rainbow"] +# games = ["Pong", "Space Invaders","Freeway","MsPacman"] +games = ["Freeway"] + + +# Create a Gradio Interface +demo = gr.Interface( + fn=display_choice, # Function to call when an option is selected + inputs=[gr.Radio(algos,label="Algorithm"), gr.Radio(games, label="Game"),gr.Slider(maximum=100,label="Seed")], # Radio buttons with the defined choices + outputs=[gr.Textbox(label="Score"),gr.Video(autoplay=True,height=480,width=480,label="Replay")], + title="Distributional RL Algorithms Benchmark", + description="Select the DRL agent and the game of your choice", + theme="soft", + # examples=[["FQF","Pong",31], + # ["C51","Space Invaders",31], + # ["FQF-Rainbow","Freeway",31] + # ] +) + +# Launch the Gradio app +if __name__ == "__main__": + demo.launch(share=False) diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/README.md b/examples/atari/README.md new file mode 100644 index 0000000000000000000000000000000000000000..62e58487b5d8cb90c7c84bd9ba3a76357c2d25bb --- /dev/null +++ b/examples/atari/README.md @@ -0,0 +1,137 @@ +# Atari Environment + +## EnvPool + +We highly recommend using envpool to run the following experiments. To install, in a linux machine, type: + +```bash +pip install envpool +``` + +After that, `atari_wrapper` will automatically switch to envpool's Atari env. 
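+For illustration, here is a minimal sketch of what the envpool-backed setup might look like once installed. The task id `Pong-v5`, the keyword arguments, and passing the result directly to a Tianshou `Collector` are assumptions based on envpool's Atari API, not part of the scripts in this repo:
+
+```python
+import envpool  # assumes `pip install envpool` succeeded (Linux only)
+
+# A batch of 10 Atari Pong envs; frame stacking, episodic life and reward
+# clipping are handled inside envpool, so no Python-side wrappers are needed.
+train_envs = envpool.make_gymnasium(
+    "Pong-v5",
+    num_envs=10,
+    stack_num=4,         # same role as --frames-stack
+    episodic_life=True,  # training-time setting; disable for evaluation
+    reward_clip=True,
+)
+# envpool envs already behave like a vectorized env, so they can typically be
+# passed to a tianshou Collector in place of DummyVectorEnv/SubprocVectorEnv.
+```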
+EnvPool's implementation is much faster (about 2\~3x faster for pure execution speed, 1.5x for the overall RL training pipeline) than the Python vectorized env implementation, and its behavior is consistent with that of the OpenAI wrappers described below.
+
+For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).
+
+## ALE-py
+
+The sampling speed is \~3000 env steps per second (\~12000 Atari frames per second, in fact, since we use frame_stack=4) in the normal mode (using a CNN policy and a collector, and storing data into the buffer).
+
+The env wrapper is crucial: without wrappers, the agent cannot perform well enough on Atari games. Many existing RL codebases use the [OpenAI wrapper](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py), but it is not the original DeepMind version ([related issue](https://github.com/openai/baselines/issues/240)). Dopamine has a different [wrapper](https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py), but unfortunately it does not work well in our codebase.
+
+# DQN (single run)
+
+One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
+
+| task | best reward | reward curve | parameters | time cost |
+| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | ------------------- |
+| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch-size 64` | ~30 min (~15 epoch) |
+| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
+| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
+| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
+| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
+| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
+| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
+
+Note: The `eps_train_final` and `eps_test` in the original DQN paper are 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that a smaller eps helps improve performance. Also, a larger batch size (say 64 instead of 32) helps convergence but slows down training.
+
+We haven't tuned these results to the best, so have fun playing with the hyperparameters!
+
+# C51 (single run)
+
+One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
+ +| task | best reward | reward curve | parameters | +| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20 | ![](results/c51/Pong_rew.png) | `python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64` | +| BreakoutNoFrameskip-v4 | 536.6 | ![](results/c51/Breakout_rew.png) | `python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1` | +| EnduroNoFrameskip-v4 | 1032 | ![](results/c51/Enduro_rew.png) | `python3 atari_c51.py --task "EnduroNoFrameskip-v4 " ` | +| QbertNoFrameskip-v4 | 16245 | ![](results/c51/Qbert_rew.png) | `python3 atari_c51.py --task "QbertNoFrameskip-v4"` | +| MsPacmanNoFrameskip-v4 | 3133 | ![](results/c51/MsPacman_rew.png) | `python3 atari_c51.py --task "MsPacmanNoFrameskip-v4"` | +| SeaquestNoFrameskip-v4 | 6226 | ![](results/c51/Seaquest_rew.png) | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` | +| SpaceInvadersNoFrameskip-v4 | 988.5 | ![](results/c51/SpaceInvader_rew.png) | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` | + +Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper. + +# QRDQN (single run) + +One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. + +| task | best reward | reward curve | parameters | +| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20 | ![](results/qrdqn/Pong_rew.png) | `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64` | +| BreakoutNoFrameskip-v4 | 409.2 | ![](results/qrdqn/Breakout_rew.png) | `python3 atari_qrdqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` | +| EnduroNoFrameskip-v4 | 1055.9 | ![](results/qrdqn/Enduro_rew.png) | `python3 atari_qrdqn.py --task "EnduroNoFrameskip-v4"` | +| QbertNoFrameskip-v4 | 14990 | ![](results/qrdqn/Qbert_rew.png) | `python3 atari_qrdqn.py --task "QbertNoFrameskip-v4"` | +| MsPacmanNoFrameskip-v4 | 2886 | ![](results/qrdqn/MsPacman_rew.png) | `python3 atari_qrdqn.py --task "MsPacmanNoFrameskip-v4"` | +| SeaquestNoFrameskip-v4 | 5676 | ![](results/qrdqn/Seaquest_rew.png) | `python3 atari_qrdqn.py --task "SeaquestNoFrameskip-v4"` | +| SpaceInvadersNoFrameskip-v4 | 938 | ![](results/qrdqn/SpaceInvader_rew.png) | `python3 atari_qrdqn.py --task "SpaceInvadersNoFrameskip-v4"` | + +# IQN (single run) + +One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
+ +| task | best reward | reward curve | parameters | +| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.3 | ![](results/iqn/Pong_rew.png) | `python3 atari_iqn.py --task "PongNoFrameskip-v4" --batch-size 64` | +| BreakoutNoFrameskip-v4 | 496.7 | ![](results/iqn/Breakout_rew.png) | `python3 atari_iqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` | +| EnduroNoFrameskip-v4 | 1545 | ![](results/iqn/Enduro_rew.png) | `python3 atari_iqn.py --task "EnduroNoFrameskip-v4"` | +| QbertNoFrameskip-v4 | 15342.5 | ![](results/iqn/Qbert_rew.png) | `python3 atari_iqn.py --task "QbertNoFrameskip-v4"` | +| MsPacmanNoFrameskip-v4 | 2915 | ![](results/iqn/MsPacman_rew.png) | `python3 atari_iqn.py --task "MsPacmanNoFrameskip-v4"` | +| SeaquestNoFrameskip-v4 | 4874 | ![](results/iqn/Seaquest_rew.png) | `python3 atari_iqn.py --task "SeaquestNoFrameskip-v4"` | +| SpaceInvadersNoFrameskip-v4 | 1498.5 | ![](results/iqn/SpaceInvaders_rew.png) | `python3 atari_iqn.py --task "SpaceInvadersNoFrameskip-v4"` | + +# FQF (single run) + +One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. + +| task | best reward | reward curve | parameters | +| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.7 | ![](results/fqf/Pong_rew.png) | `python3 atari_fqf.py --task "PongNoFrameskip-v4" --batch-size 64` | +| BreakoutNoFrameskip-v4 | 517.3 | ![](results/fqf/Breakout_rew.png) | `python3 atari_fqf.py --task "BreakoutNoFrameskip-v4" --n-step 1` | +| EnduroNoFrameskip-v4 | 2240.5 | ![](results/fqf/Enduro_rew.png) | `python3 atari_fqf.py --task "EnduroNoFrameskip-v4"` | +| QbertNoFrameskip-v4 | 16172.5 | ![](results/fqf/Qbert_rew.png) | `python3 atari_fqf.py --task "QbertNoFrameskip-v4"` | +| MsPacmanNoFrameskip-v4 | 2429 | ![](results/fqf/MsPacman_rew.png) | `python3 atari_fqf.py --task "MsPacmanNoFrameskip-v4"` | +| SeaquestNoFrameskip-v4 | 10775 | ![](results/fqf/Seaquest_rew.png) | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` | +| SpaceInvadersNoFrameskip-v4 | 2482 | ![](results/fqf/SpaceInvaders_rew.png) | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` | + +# Rainbow (single run) + +One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
+ +| task | best reward | reward curve | parameters | +| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 21 | ![](results/rainbow/Pong_rew.png) | `python3 atari_rainbow.py --task "PongNoFrameskip-v4" --batch-size 64` | +| BreakoutNoFrameskip-v4 | 684.6 | ![](results/rainbow/Breakout_rew.png) | `python3 atari_rainbow.py --task "BreakoutNoFrameskip-v4" --n-step 1` | +| EnduroNoFrameskip-v4 | 1625.9 | ![](results/rainbow/Enduro_rew.png) | `python3 atari_rainbow.py --task "EnduroNoFrameskip-v4"` | +| QbertNoFrameskip-v4 | 16192.5 | ![](results/rainbow/Qbert_rew.png) | `python3 atari_rainbow.py --task "QbertNoFrameskip-v4"` | +| MsPacmanNoFrameskip-v4 | 3101 | ![](results/rainbow/MsPacman_rew.png) | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` | +| SeaquestNoFrameskip-v4 | 2126 | ![](results/rainbow/Seaquest_rew.png) | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` | +| SpaceInvadersNoFrameskip-v4 | 1794.5 | ![](results/rainbow/SpaceInvaders_rew.png) | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` | + +# PPO (single run) + +One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. + +| task | best reward | reward curve | parameters | +| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.2 | ![](results/ppo/Pong_rew.png) | `python3 atari_ppo.py --task "PongNoFrameskip-v4"` | +| BreakoutNoFrameskip-v4 | 441.8 | ![](results/ppo/Breakout_rew.png) | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"` | +| EnduroNoFrameskip-v4 | 1245.4 | ![](results/ppo/Enduro_rew.png) | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"` | +| QbertNoFrameskip-v4 | 17395 | ![](results/ppo/Qbert_rew.png) | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"` | +| MsPacmanNoFrameskip-v4 | 2098 | ![](results/ppo/MsPacman_rew.png) | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` | +| SeaquestNoFrameskip-v4 | 882 | ![](results/ppo/Seaquest_rew.png) | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 1e-4` | +| SpaceInvadersNoFrameskip-v4 | 1340.5 | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` | + +# SAC (single run) + +One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
+ +| task | best reward | reward curve | parameters | +| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.1 | ![](results/discrete_sac/Pong_rew.png) | `python3 atari_sac.py --task "PongNoFrameskip-v4"` | +| BreakoutNoFrameskip-v4 | 211.2 | ![](results/discrete_sac/Breakout_rew.png) | `python3 atari_sac.py --task "BreakoutNoFrameskip-v4" --n-step 1 --actor-lr 1e-4 --critic-lr 1e-4` | +| EnduroNoFrameskip-v4 | 1290.7 | ![](results/discrete_sac/Enduro_rew.png) | `python3 atari_sac.py --task "EnduroNoFrameskip-v4"` | +| QbertNoFrameskip-v4 | 13157.5 | ![](results/discrete_sac/Qbert_rew.png) | `python3 atari_sac.py --task "QbertNoFrameskip-v4"` | +| MsPacmanNoFrameskip-v4 | 3836 | ![](results/discrete_sac/MsPacman_rew.png) | `python3 atari_sac.py --task "MsPacmanNoFrameskip-v4"` | +| SeaquestNoFrameskip-v4 | 1772 | ![](results/discrete_sac/Seaquest_rew.png) | `python3 atari_sac.py --task "SeaquestNoFrameskip-v4"` | +| SpaceInvadersNoFrameskip-v4 | 649 | ![](results/discrete_sac/SpaceInvaders_rew.png) | `python3 atari_sac.py --task "SpaceInvadersNoFrameskip-v4"` | diff --git a/examples/atari/__init__.py b/examples/atari/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py new file mode 100644 index 0000000000000000000000000000000000000000..d611ab196293f8f03729aa258ea747bbe55b7a07 --- /dev/null +++ b/examples/atari/atari_c51.py @@ -0,0 +1,218 @@ +import argparse +import datetime +import os +import pprint +import sys + +import numpy as np +import torch +from atari_network import C51 +from atari_wrapper import make_atari_env + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault +from tianshou.policy import C51Policy +from tianshou.policy.base import BasePolicy +from tianshou.trainer import OffpolicyTrainer + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.0) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--num-atoms", type=int, default=51) + parser.add_argument("--v-min", type=float, default=-10.0) + parser.add_argument("--v-max", type=float, default=10.0) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.0) 
+ parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + parser.add_argument("--save-buffer-name", type=str, default=None) + return parser.parse_args() + + +def test_c51(args: argparse.Namespace = get_args()) -> None: + env, train_envs, test_envs = make_atari_env( + args.task, + args.seed, + args.training_num, + args.test_num, + scale=args.scale_obs, + frame_stack=args.frames_stack, + ) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # define model + net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy: C51Policy = C51Policy( + model=net, + optim=optim, + discount_factor=args.gamma, + action_space=env.action_space, + num_atoms=args.num_atoms, + v_min=args.v_min, + v_max=args.v_max, + estimation_step=args.n_step, + target_update_freq=args.target_update_freq, + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "c51" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + logger_factory = LoggerFactoryDefault() + if args.logger == "wandb": + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards: float) -> bool: + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + if "Pong" in args.task: + return mean_rewards >= 20 + return False + + def train_fn(epoch: int, env_step: int) -> None: + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) 
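+        # e.g. with the defaults eps_train=1.0 and eps_train_final=0.05, eps is ~0.525 after 500k env steps
+        # and stays at 0.05 from 1M steps on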
+ if env_step % 1000 == 0: + logger.write("train/env_step", env_step, {"train/eps": eps}) + + def test_fn(epoch: int, env_step: int | None) -> None: + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch() -> None: + print("Setup test envs ...") + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + result.pprint_asdict() + + if args.watch: + watch() + sys.exit(0) + + # test train_collector and start filling replay buffer + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = OffpolicyTrainer( + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ).run() + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_c51(get_args()) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb9bccce9652201cb6e00aedec630f91828f214 --- /dev/null +++ b/examples/atari/atari_dqn.py @@ -0,0 +1,262 @@ +import argparse +import datetime +import os +import pprint +import sys + +import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import make_atari_env + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault +from tianshou.policy import DQNPolicy +from tianshou.policy.base import BasePolicy +from tianshou.policy.modelbased.icm import ICMPolicy +from tianshou.trainer import OffpolicyTrainer +from tianshou.utils.net.discrete import IntrinsicCuriosityModule + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.0) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + 
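+    # with step-per-collect=10 and update-per-step=0.1 below, roughly one gradient update is performed
+    # per 10 collected env steps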
parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.0) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument( + "--icm-lr-scale", + type=float, + default=0.0, + help="use intrinsic curiosity module with this lr scale", + ) + parser.add_argument( + "--icm-reward-scale", + type=float, + default=0.01, + help="scaling factor for intrinsic curiosity reward", + ) + parser.add_argument( + "--icm-forward-loss-weight", + type=float, + default=0.2, + help="weight for the forward model loss in ICM", + ) + return parser.parse_args() + + +def main(args: argparse.Namespace = get_args()) -> None: + env, train_envs, test_envs = make_atari_env( + args.task, + args.seed, + args.training_num, + args.test_num, + scale=args.scale_obs, + frame_stack=args.frames_stack, + ) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # define model + net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy: DQNPolicy | ICMPolicy + policy = DQNPolicy( + model=net, + optim=optim, + action_space=env.action_space, + discount_factor=args.gamma, + estimation_step=args.n_step, + target_update_freq=args.target_update_freq, + ) + if args.icm_lr_scale > 0: + feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + action_dim = np.prod(args.action_shape) + feature_dim = feature_net.output_dim + icm_net = IntrinsicCuriosityModule( + feature_net.net, + feature_dim, + action_dim, + hidden_sizes=[512], + device=args.device, + ) + icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) + policy = ICMPolicy( + policy=policy, + model=icm_net, + optim=icm_optim, + action_space=env.action_space, + lr_scale=args.icm_lr_scale, + reward_scale=args.icm_reward_scale, + forward_loss_weight=args.icm_forward_loss_weight, + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + 
) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "dqn_icm" if args.icm_lr_scale > 0 else "dqn" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + logger_factory = LoggerFactoryDefault() + if args.logger == "wandb": + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards: float) -> bool: + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + if "Pong" in args.task: + return mean_rewards >= 20 + return False + + def train_fn(epoch: int, env_step: int) -> None: + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + if env_step % 1000 == 0: + logger.write("train/env_step", env_step, {"train/eps": eps}) + + def test_fn(epoch: int, env_step: int | None) -> None: + policy.set_eps(args.eps_test) + + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: + # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") + torch.save({"model": policy.state_dict()}, ckpt_path) + return ckpt_path + + # watch agent's performance + def watch() -> None: + print("Setup test envs ...") + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + result.pprint_asdict() + + if args.watch: + watch() + sys.exit(0) + + # test train_collector and start filling replay buffer + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = OffpolicyTrainer( + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + resume_from_log=args.resume_id is not None, + save_checkpoint_fn=save_checkpoint_fn, + ).run() + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + 
main(get_args()) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py new file mode 100644 index 0000000000000000000000000000000000000000..aa76983bea6ad590063f21e433938ff5139b9df5 --- /dev/null +++ b/examples/atari/atari_dqn_hl.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +import os + +from examples.atari.atari_network import ( + IntermediateModuleFactoryAtariDQN, + IntermediateModuleFactoryAtariDQNFeatures, +) +from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.experiment import ( + DQNExperimentBuilder, + ExperimentConfig, +) +from tianshou.highlevel.params.policy_params import DQNParams +from tianshou.highlevel.params.policy_wrapper import ( + PolicyWrapperFactoryIntrinsicCuriosity, +) +from tianshou.highlevel.trainer import ( + EpochTestCallbackDQNSetEps, + EpochTrainCallbackDQNEpsLinearDecay, +) +from tianshou.utils import logging +from tianshou.utils.logging import datetime_tag + + +def main( + experiment_config: ExperimentConfig, + task: str = "PongNoFrameskip-v4", + scale_obs: bool = False, + eps_test: float = 0.005, + eps_train: float = 1.0, + eps_train_final: float = 0.05, + buffer_size: int = 100000, + lr: float = 0.0001, + gamma: float = 0.99, + n_step: int = 3, + target_update_freq: int = 500, + epoch: int = 100, + step_per_epoch: int = 100000, + step_per_collect: int = 10, + update_per_step: float = 0.1, + batch_size: int = 32, + training_num: int = 10, + test_num: int = 10, + frames_stack: int = 4, + save_buffer_name: str | None = None, # TODO support? + icm_lr_scale: float = 0.0, + icm_reward_scale: float = 0.01, + icm_forward_loss_weight: float = 0.2, +) -> None: + log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag()) + + sampling_config = SamplingConfig( + num_epochs=epoch, + step_per_epoch=step_per_epoch, + batch_size=batch_size, + num_train_envs=training_num, + num_test_envs=test_num, + buffer_size=buffer_size, + step_per_collect=step_per_collect, + update_per_step=update_per_step, + repeat_per_collect=None, + replay_buffer_stack_num=frames_stack, + replay_buffer_ignore_obs_next=True, + replay_buffer_save_only_last_obs=True, + ) + + env_factory = AtariEnvFactory( + task, + sampling_config.train_seed, + sampling_config.test_seed, + frames_stack, + scale=scale_obs, + ) + + builder = ( + DQNExperimentBuilder(env_factory, experiment_config, sampling_config) + .with_dqn_params( + DQNParams( + discount_factor=gamma, + estimation_step=n_step, + lr=lr, + target_update_freq=target_update_freq, + ), + ) + .with_model_factory(IntermediateModuleFactoryAtariDQN()) + .with_epoch_train_callback( + EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final), + ) + .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test)) + .with_epoch_stop_callback(AtariEpochStopCallback(task)) + ) + if icm_lr_scale > 0: + builder.with_policy_wrapper_factory( + PolicyWrapperFactoryIntrinsicCuriosity( + feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), + hidden_sizes=[512], + lr=lr, + lr_scale=icm_lr_scale, + reward_scale=icm_reward_scale, + forward_loss_weight=icm_forward_loss_weight, + ), + ) + + experiment = builder.build() + experiment.run(run_name=log_name) + + +if __name__ == "__main__": + logging.run_cli(main) diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py new file mode 100644 index 0000000000000000000000000000000000000000..58aff46ac3d7521bc25773eebf095fb7085965db --- /dev/null +++ 
b/examples/atari/atari_fqf.py @@ -0,0 +1,231 @@ +import argparse +import datetime +import os +import pprint +import sys + +import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import make_atari_env + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault +from tianshou.policy import FQFPolicy +from tianshou.policy.base import BasePolicy +from tianshou.trainer import OffpolicyTrainer +from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=3128) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.0) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=5e-5) + parser.add_argument("--fraction-lr", type=float, default=2.5e-9) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--num-fractions", type=int, default=32) + parser.add_argument("--num-cosines", type=int, default=64) + parser.add_argument("--ent-coef", type=float, default=10.0) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.0) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + parser.add_argument("--save-buffer-name", type=str, default=None) + return parser.parse_args() + + +def test_fqf(args: argparse.Namespace = get_args()) -> None: + env, train_envs, test_envs = make_atari_env( + args.task, + args.seed, + args.training_num, + args.test_num, + scale=args.scale_obs, + frame_stack=args.frames_stack, + ) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # define model + feature_net = 
DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + net = FullQuantileFunction( + feature_net, + args.action_shape, + args.hidden_sizes, + args.num_cosines, + device=args.device, + ).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) + fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) + # define policy + policy: FQFPolicy = FQFPolicy( + model=net, + optim=optim, + fraction_model=fraction_net, + fraction_optim=fraction_optim, + action_space=env.action_space, + discount_factor=args.gamma, + num_fractions=args.num_fractions, + ent_coef=args.ent_coef, + estimation_step=args.n_step, + target_update_freq=args.target_update_freq, + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "fqf" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + logger_factory = LoggerFactoryDefault() + if args.logger == "wandb": + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards: float) -> bool: + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + if "Pong" in args.task: + return mean_rewards >= 20 + return False + + def train_fn(epoch: int, env_step: int) -> None: + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + if env_step % 1000 == 0: + logger.write("train/env_step", env_step, {"train/eps": eps}) + + def test_fn(epoch: int, env_step: int | None) -> None: + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch() -> None: + print("Setup test envs ...") + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = 
test_collector.collect(n_episode=args.test_num, render=args.render) + result.pprint_asdict() + + if args.watch: + watch() + sys.exit(0) + + # test train_collector and start filling replay buffer + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = OffpolicyTrainer( + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ).run() + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_fqf(get_args()) diff --git a/examples/atari/atari_fqf_rainbow.py b/examples/atari/atari_fqf_rainbow.py new file mode 100644 index 0000000000000000000000000000000000000000..c06da8f3a4fd605815af3b47feb22d067be0d95a --- /dev/null +++ b/examples/atari/atari_fqf_rainbow.py @@ -0,0 +1,288 @@ +import argparse +import datetime +import os +import pprint +import sys + +# import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import make_atari_env + +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault +from tianshou.policy import FQFPolicy,FQF_RainbowPolicy +from tianshou.policy.base import BasePolicy +from tianshou.trainer import OffpolicyTrainer +from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction, FullQuantileFunctionRainbow + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--algo-name", type=str, default="RainbowFQF") + parser.add_argument("--seed", type=int, default=3128) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.0) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=5e-5) + parser.add_argument("--fraction-lr", type=float, default=2.5e-9) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--num-fractions", type=int, default=32) + parser.add_argument("--num-cosines", type=int, default=64) + parser.add_argument("--ent-coef", type=float, default=10.0) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + #rainbow elements + parser.add_argument("--no-dueling", action="store_true", default=False) + parser.add_argument("--no-noisy", action="store_true", default=False) + 
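+    # Prioritized experience replay (PER) settings: alpha controls how strongly TD error
+    # drives sampling, beta is the importance-sampling correction and is annealed linearly
+    # from --beta to --beta-final over --beta-anneal-step environment steps (see train_fn).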
parser.add_argument("--no-priority", action="store_true", default=False) + parser.add_argument("--noisy-std", type=float, default=0.1) + parser.add_argument("--alpha", type=float, default=0.5) + parser.add_argument("--beta", type=float, default=0.4) + parser.add_argument("--beta-final", type=float, default=1.0) + parser.add_argument("--beta-anneal-step", type=int, default=5000000) + parser.add_argument("--no-weight-norm", action="store_true", default=False) + + + parser.add_argument("--render", type=float, default=0.0) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--per", type=bool, default=False) + return parser.parse_args() + + +def test_fqf(args: argparse.Namespace = get_args()) -> None: + env, train_envs, test_envs = make_atari_env( + args.task, + args.seed, + args.training_num, + args.test_num, + scale=args.scale_obs, + frame_stack=args.frames_stack, + ) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # seed + # np.random.seed(args.seed) + torch.manual_seed(args.seed) + # define model + feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + preprocess_net_output_dim = feature_net.output_dim # Ensure this is correctly set + print(preprocess_net_output_dim) + net = FullQuantileFunctionRainbow( + preprocess_net=feature_net, + action_shape=args.action_shape, + hidden_sizes=args.hidden_sizes, + num_cosines=args.num_cosines, + preprocess_net_output_dim=preprocess_net_output_dim, + device=args.device, + noisy_std = args.noisy_std, + is_noisy=not args.no_noisy, # Set to True to use noisy layers + is_dueling = not args.no_dueling, # Set to True to use noisy layers + ).to(args.device) + print(net) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) + fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) + # define policy + policy: FQF_RainbowPolicy = FQF_RainbowPolicy( + model=net, + optim=optim, + fraction_model=fraction_net, + fraction_optim=fraction_optim, + action_space=env.action_space, + discount_factor=args.gamma, + num_fractions=args.num_fractions, + ent_coef=args.ent_coef, + estimation_step=args.n_step, + target_update_freq=args.target_update_freq, + is_noisy=not args.no_noisy + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer + if args.no_priority: + buffer = 
VectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + else: + print("Using PER") + buffer = PrioritizedVectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + alpha=args.alpha, + beta=args.beta, + weight_norm=not args.no_weight_norm, + ) + print("PER as buffer") + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + # args.algo_name = "fqf_per_noisy" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + logger_factory = LoggerFactoryDefault() + if args.logger == "wandb": + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards: float) -> bool: + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + if "Pong" in args.task: + return mean_rewards >= 20 + return False + + def train_fn(epoch: int, env_step: int) -> None: + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + if env_step % 1000 == 0: + logger.write("train/env_step", env_step, {"train/eps": eps}) + if not args.no_priority: + if env_step <= args.beta_anneal_step: + beta = args.beta - env_step / args.beta_anneal_step * (args.beta - args.beta_final) + # print("beta updated - anneal") + else: + beta = args.beta_final + # print("beta updated - final") + buffer.set_beta(beta) + if env_step % 1000 == 0: + logger.write("train/env_step", env_step, {"train/beta": beta}) + + def test_fn(epoch: int, env_step: int | None) -> None: + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch() -> None: + print("Setup test envs ...") + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + # buffer = VectorReplayBuffer( + # args.buffer_size, + # buffer_num=len(test_envs), + # ignore_obs_next=True, + # save_only_last_obs=True, + # stack_num=args.frames_stack, + # ) + buffer = PrioritizedVectorReplayBuffer( + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + alpha=args.alpha, + beta=args.beta, + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + result.pprint_asdict() + + if args.watch: + watch() + sys.exit(0) + + # test train_collector and start 
filling replay buffer + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = OffpolicyTrainer( + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ).run() + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_fqf(get_args()) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py new file mode 100644 index 0000000000000000000000000000000000000000..c6090523d62bcca207fa6b576c5843b10b18cebb --- /dev/null +++ b/examples/atari/atari_iqn.py @@ -0,0 +1,229 @@ +import argparse +import datetime +import os +import pprint +import sys + +import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import make_atari_env + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault +from tianshou.policy import IQNPolicy +from tianshou.policy.base import BasePolicy +from tianshou.trainer import OffpolicyTrainer +from tianshou.utils.net.discrete import ImplicitQuantileNetwork + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.0) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--sample-size", type=int, default=32) + parser.add_argument("--online-sample-size", type=int, default=8) + parser.add_argument("--target-sample-size", type=int, default=8) + parser.add_argument("--num-cosines", type=int, default=64) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.0) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, 
default="atari.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + parser.add_argument("--save-buffer-name", type=str, default=None) + return parser.parse_args() + + +def test_iqn(args: argparse.Namespace = get_args()) -> None: + env, train_envs, test_envs = make_atari_env( + args.task, + args.seed, + args.training_num, + args.test_num, + scale=args.scale_obs, + frame_stack=args.frames_stack, + ) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # define model + feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + net = ImplicitQuantileNetwork( + feature_net, + args.action_shape, + args.hidden_sizes, + num_cosines=args.num_cosines, + device=args.device, + ).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy: IQNPolicy = IQNPolicy( + model=net, + optim=optim, + action_space=env.action_space, + discount_factor=args.gamma, + sample_size=args.sample_size, + online_sample_size=args.online_sample_size, + target_sample_size=args.target_sample_size, + estimation_step=args.n_step, + target_update_freq=args.target_update_freq, + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "iqn" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + logger_factory = LoggerFactoryDefault() + if args.logger == "wandb": + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards: float) -> bool: + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + if "Pong" in args.task: + return mean_rewards >= 20 + return False + + def train_fn(epoch: int, env_step: int) -> None: + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + if env_step % 1000 == 0: + logger.write("train/env_step", env_step, {"train/eps": eps}) + + def test_fn(epoch: int, env_step: int | None) -> None: + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch() 
-> None: + print("Setup test envs ...") + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + result.pprint_asdict() + + if args.watch: + watch() + sys.exit(0) + + # test train_collector and start filling replay buffer + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + + result = OffpolicyTrainer( + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ).run() + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_iqn(get_args()) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py new file mode 100644 index 0000000000000000000000000000000000000000..23df1cd256e620ae7584578fd384f41b041d7fe3 --- /dev/null +++ b/examples/atari/atari_iqn_hl.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +import os +from collections.abc import Sequence + +from examples.atari.atari_network import ( + IntermediateModuleFactoryAtariDQN, +) +from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.experiment import ( + ExperimentConfig, + IQNExperimentBuilder, +) +from tianshou.highlevel.params.policy_params import IQNParams +from tianshou.highlevel.trainer import ( + EpochTestCallbackDQNSetEps, + EpochTrainCallbackDQNEpsLinearDecay, +) +from tianshou.utils import logging +from tianshou.utils.logging import datetime_tag + + +def main( + experiment_config: ExperimentConfig, + task: str = "PongNoFrameskip-v4", + scale_obs: bool = False, + eps_test: float = 0.005, + eps_train: float = 1.0, + eps_train_final: float = 0.05, + buffer_size: int = 100000, + lr: float = 0.0001, + gamma: float = 0.99, + sample_size: int = 32, + online_sample_size: int = 8, + target_sample_size: int = 8, + num_cosines: int = 64, + hidden_sizes: Sequence[int] = (512,), + n_step: int = 3, + target_update_freq: int = 500, + epoch: int = 100, + step_per_epoch: int = 100000, + step_per_collect: int = 10, + update_per_step: float = 0.1, + batch_size: int = 32, + training_num: int = 10, + test_num: int = 10, + frames_stack: int = 4, + save_buffer_name: str | None = None, # TODO support? 
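+    # NOTE: the defaults above mirror the argparse defaults in examples/atari/atari_iqn.py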
+) -> None: + log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag()) + + sampling_config = SamplingConfig( + num_epochs=epoch, + step_per_epoch=step_per_epoch, + batch_size=batch_size, + num_train_envs=training_num, + num_test_envs=test_num, + buffer_size=buffer_size, + step_per_collect=step_per_collect, + update_per_step=update_per_step, + repeat_per_collect=None, + replay_buffer_stack_num=frames_stack, + replay_buffer_ignore_obs_next=True, + replay_buffer_save_only_last_obs=True, + ) + + env_factory = AtariEnvFactory( + task, + sampling_config.train_seed, + sampling_config.test_seed, + frames_stack, + scale=scale_obs, + ) + + experiment = ( + IQNExperimentBuilder(env_factory, experiment_config, sampling_config) + .with_iqn_params( + IQNParams( + discount_factor=gamma, + estimation_step=n_step, + lr=lr, + sample_size=sample_size, + online_sample_size=online_sample_size, + target_update_freq=target_update_freq, + target_sample_size=target_sample_size, + hidden_sizes=hidden_sizes, + num_cosines=num_cosines, + ), + ) + .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True)) + .with_epoch_train_callback( + EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final), + ) + .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test)) + .with_epoch_stop_callback(AtariEpochStopCallback(task)) + .build() + ) + experiment.run(run_name=log_name) + + +if __name__ == "__main__": + logging.run_cli(main) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py new file mode 100644 index 0000000000000000000000000000000000000000..5432e792e6000ddbd880b722751170ee6e846dde --- /dev/null +++ b/examples/atari/atari_network.py @@ -0,0 +1,308 @@ +from collections.abc import Callable, Sequence +from typing import Any + +import numpy as np +import torch +from torch import nn + +from examples.atari.tianshou.highlevel.env import Environments +from examples.atari.tianshou.highlevel.module.actor import ActorFactory +from examples.atari.tianshou.highlevel.module.core import ( + TDevice, +) +from examples.atari.tianshou.highlevel.module.intermediate import ( + IntermediateModule, + IntermediateModuleFactory, +) +from examples.atari.tianshou.utils.net.common import NetBase +from examples.atari.tianshou.utils.net.discrete import Actor, NoisyLinear + + +def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module: + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class ScaledObsInputModule(torch.nn.Module): + def __init__(self, module: NetBase, denom: float = 255.0) -> None: + super().__init__() + self.module = module + self.denom = denom + # This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim) + self.output_dim = module.output_dim + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any | None = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, Any]: + if info is None: + info = {} + return self.module.forward(obs / self.denom, state, info) + + +def scale_obs(module: NetBase, denom: float = 255.0) -> ScaledObsInputModule: + return ScaledObsInputModule(module, denom=denom) + + +class DQN(NetBase[Any]): + """Reference: Human-level control through deep reinforcement learning. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. 
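+
+    Outputs Q-values for every discrete action; with ``features_only=True`` it instead
+    returns the flattened CNN feature vector (optionally projected through an extra
+    linear layer of size ``output_dim_added_layer``).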
+ """ + + def __init__( + self, + c: int, + h: int, + w: int, + action_shape: Sequence[int] | int, + device: str | int | torch.device = "cpu", + features_only: bool = False, + output_dim_added_layer: int | None = None, + layer_init: Callable[[nn.Module], nn.Module] = lambda x: x, + ) -> None: + # TODO: Add docstring + if not features_only and output_dim_added_layer is not None: + raise ValueError( + "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.", + ) + super().__init__() + self.device = device + self.net = nn.Sequential( + layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)), + nn.ReLU(inplace=True), + layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)), + nn.ReLU(inplace=True), + layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)), + nn.ReLU(inplace=True), + nn.Flatten(), + ) + with torch.no_grad(): + base_cnn_output_dim = int(np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])) + if not features_only: + action_dim = int(np.prod(action_shape)) + self.net = nn.Sequential( + self.net, + layer_init(nn.Linear(base_cnn_output_dim, 512)), + nn.ReLU(inplace=True), + layer_init(nn.Linear(512, action_dim)), + ) + self.output_dim = action_dim + elif output_dim_added_layer is not None: + self.net = nn.Sequential( + self.net, + layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)), + nn.ReLU(inplace=True), + ) + self.output_dim = output_dim_added_layer + else: + self.output_dim = base_cnn_output_dim + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any | None = None, + info: dict[str, Any] | None = None, + **kwargs: Any, + ) -> tuple[torch.Tensor, Any]: + r"""Mapping: s -> Q(s, \*).""" + obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) + return self.net(obs), state + + +class C51(DQN): + """Reference: A distributional perspective on reinforcement learning. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + def __init__( + self, + c: int, + h: int, + w: int, + action_shape: Sequence[int], + num_atoms: int = 51, + device: str | int | torch.device = "cpu", + ) -> None: + self.action_num = int(np.prod(action_shape)) + super().__init__(c, h, w, [self.action_num * num_atoms], device) + self.num_atoms = num_atoms + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any | None = None, + info: dict[str, Any] | None = None, + **kwargs: Any, + ) -> tuple[torch.Tensor, Any]: + r"""Mapping: x -> Z(x, \*).""" + obs, state = super().forward(obs) + obs = obs.view(-1, self.num_atoms).softmax(dim=-1) + obs = obs.view(-1, self.action_num, self.num_atoms) + return obs, state + + +class Rainbow(DQN): + """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. 
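+
+    Adds a distributional head (``num_atoms`` logits per action) on top of the DQN
+    feature extractor, with an optional dueling value/advantage split and optional
+    noisy linear layers for exploration.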
+ """ + + def __init__( + self, + c: int, + h: int, + w: int, + action_shape: Sequence[int], + num_atoms: int = 51, + noisy_std: float = 0.5, + device: str | int | torch.device = "cpu", + is_dueling: bool = True, + is_noisy: bool = True, + ) -> None: + super().__init__(c, h, w, action_shape, device, features_only=True) + self.action_num = int(np.prod(action_shape)) + self.num_atoms = num_atoms + + def linear(x: int, y: int) -> NoisyLinear | nn.Linear: + if is_noisy: + return NoisyLinear(x, y, noisy_std) + return nn.Linear(x, y) + + self.Q = nn.Sequential( + linear(self.output_dim, 512), + nn.ReLU(inplace=True), + linear(512, self.action_num * self.num_atoms), + ) + self._is_dueling = is_dueling + if self._is_dueling: + self.V = nn.Sequential( + linear(self.output_dim, 512), + nn.ReLU(inplace=True), + linear(512, self.num_atoms), + ) + self.output_dim = self.action_num * self.num_atoms + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any | None = None, + info: dict[str, Any] | None = None, + **kwargs: Any, + ) -> tuple[torch.Tensor, Any]: + r"""Mapping: x -> Z(x, \*).""" + obs, state = super().forward(obs) + q = self.Q(obs) + q = q.view(-1, self.action_num, self.num_atoms) + if self._is_dueling: + v = self.V(obs) + v = v.view(-1, 1, self.num_atoms) + logits = q - q.mean(dim=1, keepdim=True) + v + else: + logits = q + probs = logits.softmax(dim=2) + return probs, state + + +class QRDQN(DQN): + """Reference: Distributional Reinforcement Learning with Quantile Regression. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + def __init__( + self, + *, + c: int, + h: int, + w: int, + action_shape: Sequence[int] | int, + num_quantiles: int = 200, + device: str | int | torch.device = "cpu", + ) -> None: + self.action_num = int(np.prod(action_shape)) + super().__init__(c, h, w, [self.action_num * num_quantiles], device) + self.num_quantiles = num_quantiles + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any | None = None, + info: dict[str, Any] | None = None, + **kwargs: Any, + ) -> tuple[torch.Tensor, Any]: + r"""Mapping: x -> Z(x, \*).""" + obs, state = super().forward(obs) + obs = obs.view(-1, self.action_num, self.num_quantiles) + return obs, state + + +class ActorFactoryAtariDQN(ActorFactory): + def __init__( + self, + scale_obs: bool = True, + features_only: bool = False, + output_dim_added_layer: int | None = None, + ) -> None: + self.output_dim_added_layer = output_dim_added_layer + self.scale_obs = scale_obs + self.features_only = features_only + + def create_module(self, envs: Environments, device: TDevice) -> Actor: + c, h, w = envs.get_observation_shape() # type: ignore # only right shape is a sequence of length 3 + action_shape = envs.get_action_shape() + if isinstance(action_shape, np.int64): + action_shape = int(action_shape) + net: DQN | ScaledObsInputModule + net = DQN( + c=c, + h=h, + w=w, + action_shape=action_shape, + device=device, + features_only=self.features_only, + output_dim_added_layer=self.output_dim_added_layer, + layer_init=layer_init, + ) + if self.scale_obs: + net = scale_obs(net) + return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device) + + +class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory): + def __init__(self, features_only: bool = False, net_only: bool = False) -> None: + self.features_only = features_only + self.net_only = net_only + + def create_intermediate_module(self, envs: Environments, device: TDevice) -> 
IntermediateModule: + obs_shape = envs.get_observation_shape() + if isinstance(obs_shape, int): + obs_shape = [obs_shape] + assert len(obs_shape) == 3 + c, h, w = obs_shape + action_shape = envs.get_action_shape() + if isinstance(action_shape, np.int64): + action_shape = int(action_shape) + dqn = DQN( + c=c, + h=h, + w=w, + action_shape=action_shape, + device=device, + features_only=self.features_only, + ).to(device) + module = dqn.net if self.net_only else dqn + return IntermediateModule(module, dqn.output_dim) + + +class IntermediateModuleFactoryAtariDQNFeatures(IntermediateModuleFactoryAtariDQN): + def __init__(self) -> None: + super().__init__(features_only=True, net_only=True) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..dd75de7fb043e2e7fe73e7701f6682fe83598683 --- /dev/null +++ b/examples/atari/atari_ppo.py @@ -0,0 +1,284 @@ +import argparse +import datetime +import os +import pprint +import sys + +import numpy as np +import torch +from atari_network import DQN, layer_init, scale_obs +from atari_wrapper import make_atari_env +from torch.distributions import Categorical +from torch.optim.lr_scheduler import LambdaLR + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault +from tianshou.policy import ICMPolicy, PPOPolicy +from tianshou.policy.base import BasePolicy +from tianshou.trainer import OnpolicyTrainer +from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=4213) + parser.add_argument("--scale-obs", type=int, default=1) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=2.5e-4) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=1000) + parser.add_argument("--repeat-per-collect", type=int, default=4) + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--hidden-size", type=int, default=512) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--rew-norm", type=int, default=False) + parser.add_argument("--vf-coef", type=float, default=0.25) + parser.add_argument("--ent-coef", type=float, default=0.01) + parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--lr-decay", type=int, default=True) + parser.add_argument("--max-grad-norm", type=float, default=0.5) + parser.add_argument("--eps-clip", type=float, default=0.1) + parser.add_argument("--dual-clip", type=float, default=None) + parser.add_argument("--value-clip", type=int, default=1) + parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.0) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument("--frames-stack", type=int, default=4) + 
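+    # experiment resumption and logging: --resume-path reloads saved policy weights,
+    # --resume-id resumes an existing logger run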
parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument( + "--icm-lr-scale", + type=float, + default=0.0, + help="use intrinsic curiosity module with this lr scale", + ) + parser.add_argument( + "--icm-reward-scale", + type=float, + default=0.01, + help="scaling factor for intrinsic curiosity reward", + ) + parser.add_argument( + "--icm-forward-loss-weight", + type=float, + default=0.2, + help="weight for the forward model loss in ICM", + ) + return parser.parse_args() + + +def test_ppo(args: argparse.Namespace = get_args()) -> None: + env, train_envs, test_envs = make_atari_env( + args.task, + args.seed, + args.training_num, + args.test_num, + scale=0, + frame_stack=args.frames_stack, + ) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # define model + net = DQN( + *args.state_shape, + args.action_shape, + device=args.device, + features_only=True, + output_dim_added_layer=args.hidden_size, + layer_init=layer_init, + ) + if args.scale_obs: + net = scale_obs(net) + actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) + critic = Critic(net, device=args.device) + optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5) + + lr_scheduler = None + if args.lr_decay: + # decay learning rate to 0 linearly + max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch + + lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + + policy: PPOPolicy = PPOPolicy( + actor=actor, + critic=critic, + optim=optim, + dist_fn=Categorical, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + max_grad_norm=args.max_grad_norm, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + reward_normalization=args.rew_norm, + action_scaling=False, + lr_scheduler=lr_scheduler, + action_space=env.action_space, + eps_clip=args.eps_clip, + value_clip=args.value_clip, + dual_clip=args.dual_clip, + advantage_normalization=args.norm_adv, + recompute_advantage=args.recompute_adv, + ).to(args.device) + if args.icm_lr_scale > 0: + feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + action_dim = np.prod(args.action_shape) + feature_dim = feature_net.output_dim + icm_net = IntrinsicCuriosityModule( + feature_net.net, + feature_dim, + action_dim, + hidden_sizes=[args.hidden_size], + device=args.device, + ) + icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) + policy: ICMPolicy = ICMPolicy( # type: ignore[no-redef] + policy=policy, + model=icm_net, + optim=icm_optim, + action_space=env.action_space, + lr_scale=args.icm_lr_scale, + reward_scale=args.icm_reward_scale, + forward_loss_weight=args.icm_forward_loss_weight, + ).to(args.device) + # load a previous policy + if args.resume_path: + 
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "ppo_icm" if args.icm_lr_scale > 0 else "ppo" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + logger_factory = LoggerFactoryDefault() + if args.logger == "wandb": + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards: float) -> bool: + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + if "Pong" in args.task: + return mean_rewards >= 20 + return False + + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: + # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") + torch.save({"model": policy.state_dict()}, ckpt_path) + return ckpt_path + + # watch agent's performance + def watch() -> None: + print("Setup test envs ...") + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + result.pprint_asdict() + + if args.watch: + watch() + sys.exit(0) + + # test train_collector and start filling replay buffer + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = OnpolicyTrainer( + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + resume_from_log=args.resume_id is not None, + save_checkpoint_fn=save_checkpoint_fn, + ).run() + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_ppo(get_args()) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py new file mode 100644 index 
0000000000000000000000000000000000000000..10dcd0a7e7c6439036c96f118942fd0c941a80b6 --- /dev/null +++ b/examples/atari/atari_ppo_hl.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 + +import os +from collections.abc import Sequence + +from examples.atari.atari_network import ( + ActorFactoryAtariDQN, + IntermediateModuleFactoryAtariDQNFeatures, +) +from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.experiment import ( + ExperimentConfig, + PPOExperimentBuilder, +) +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear +from tianshou.highlevel.params.policy_params import PPOParams +from tianshou.highlevel.params.policy_wrapper import ( + PolicyWrapperFactoryIntrinsicCuriosity, +) +from tianshou.utils import logging +from tianshou.utils.logging import datetime_tag + + +def main( + experiment_config: ExperimentConfig, + task: str = "PongNoFrameskip-v4", + scale_obs: bool = True, + buffer_size: int = 100000, + lr: float = 2.5e-4, + gamma: float = 0.99, + epoch: int = 100, + step_per_epoch: int = 100000, + step_per_collect: int = 1000, + repeat_per_collect: int = 4, + batch_size: int = 256, + hidden_sizes: Sequence[int] = (512,), + training_num: int = 10, + test_num: int = 10, + rew_norm: bool = False, + vf_coef: float = 0.25, + ent_coef: float = 0.01, + gae_lambda: float = 0.95, + lr_decay: bool = True, + max_grad_norm: float = 0.5, + eps_clip: float = 0.1, + dual_clip: float | None = None, + value_clip: bool = True, + norm_adv: bool = True, + recompute_adv: bool = False, + frames_stack: int = 4, + save_buffer_name: str | None = None, # TODO add support in high-level API? + icm_lr_scale: float = 0.0, + icm_reward_scale: float = 0.01, + icm_forward_loss_weight: float = 0.2, +) -> None: + log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) + + sampling_config = SamplingConfig( + num_epochs=epoch, + step_per_epoch=step_per_epoch, + batch_size=batch_size, + num_train_envs=training_num, + num_test_envs=test_num, + buffer_size=buffer_size, + step_per_collect=step_per_collect, + repeat_per_collect=repeat_per_collect, + replay_buffer_stack_num=frames_stack, + replay_buffer_ignore_obs_next=True, + replay_buffer_save_only_last_obs=True, + ) + + env_factory = AtariEnvFactory( + task, + sampling_config.train_seed, + sampling_config.test_seed, + frames_stack, + scale=scale_obs, + ) + + builder = ( + PPOExperimentBuilder(env_factory, experiment_config, sampling_config) + .with_ppo_params( + PPOParams( + discount_factor=gamma, + gae_lambda=gae_lambda, + reward_normalization=rew_norm, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + value_clip=value_clip, + advantage_normalization=norm_adv, + eps_clip=eps_clip, + dual_clip=dual_clip, + recompute_advantage=recompute_adv, + lr=lr, + lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) + if lr_decay + else None, + ), + ) + .with_actor_factory(ActorFactoryAtariDQN(scale_obs=scale_obs, features_only=True)) + .with_critic_factory_use_actor() + .with_epoch_stop_callback(AtariEpochStopCallback(task)) + ) + if icm_lr_scale > 0: + builder.with_policy_wrapper_factory( + PolicyWrapperFactoryIntrinsicCuriosity( + feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), + hidden_sizes=hidden_sizes, + lr=lr, + lr_scale=icm_lr_scale, + reward_scale=icm_reward_scale, + forward_loss_weight=icm_forward_loss_weight, + ), + ) + experiment = builder.build() + 
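+    # if icm_lr_scale > 0, the builder wraps the PPO policy with an intrinsic curiosity module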
experiment.run(run_name=log_name) + + +if __name__ == "__main__": + logging.run_cli(main) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py new file mode 100644 index 0000000000000000000000000000000000000000..b9731316e19c8edd646e70bc878576cf68744268 --- /dev/null +++ b/examples/atari/atari_qrdqn.py @@ -0,0 +1,222 @@ +import argparse +import datetime +import os +import pprint +import sys + +import numpy as np +import torch +from atari_network import QRDQN +from atari_wrapper import make_atari_env + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault +from tianshou.policy import QRDQNPolicy +from tianshou.policy.base import BasePolicy +from tianshou.trainer import OffpolicyTrainer + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.0) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--num-quantiles", type=int, default=200) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.0) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + parser.add_argument("--save-buffer-name", type=str, default=None) + return parser.parse_args() + + +def test_qrdqn(args: argparse.Namespace = get_args()) -> None: + env, train_envs, test_envs = make_atari_env( + args.task, + args.seed, + args.training_num, + args.test_num, + scale=args.scale_obs, + frame_stack=args.frames_stack, + ) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # define model + c, h, w = args.state_shape + net = QRDQN( + c=c, + h=h, + w=w, + 
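+        # quantile-regression head: num_quantiles estimates per action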
action_shape=args.action_shape, + num_quantiles=args.num_quantiles, + device=args.device, + ) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy: QRDQNPolicy = QRDQNPolicy( + model=net, + optim=optim, + action_space=env.action_space, + discount_factor=args.gamma, + num_quantiles=args.num_quantiles, + estimation_step=args.n_step, + target_update_freq=args.target_update_freq, + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "qrdqn" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + logger_factory = LoggerFactoryDefault() + if args.logger == "wandb": + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards: float) -> bool: + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + if "Pong" in args.task: + return mean_rewards >= 20 + return False + + def train_fn(epoch: int, env_step: int) -> None: + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + if env_step % 1000 == 0: + logger.write("train/env_step", env_step, {"train/eps": eps}) + + def test_fn(epoch: int, env_step: int | None) -> None: + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch() -> None: + print("Setup test envs ...") + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + result.pprint_asdict() + + if args.watch: + watch() + sys.exit(0) + + # test train_collector and start filling replay buffer + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = OffpolicyTrainer( + policy=policy, + 
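+        # update_per_step sets the (fractional) number of gradient updates per collected env step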
train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ).run() + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_qrdqn(get_args()) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py new file mode 100644 index 0000000000000000000000000000000000000000..952d35f07146a3f92e6ed5b82d93e291ed5081b6 --- /dev/null +++ b/examples/atari/atari_rainbow.py @@ -0,0 +1,258 @@ +import argparse +import datetime +import os +import pprint +import sys + +import numpy as np +import torch +from atari_network import Rainbow +from atari_wrapper import make_atari_env + +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault +from tianshou.policy import C51Policy, RainbowPolicy +from tianshou.policy.base import BasePolicy +from tianshou.trainer import OffpolicyTrainer + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--eps-test", type=float, default=0.005) + parser.add_argument("--eps-train", type=float, default=1.0) + parser.add_argument("--eps-train-final", type=float, default=0.05) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--lr", type=float, default=0.0000625) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--num-atoms", type=int, default=51) + parser.add_argument("--v-min", type=float, default=-10.0) + parser.add_argument("--v-max", type=float, default=10.0) + parser.add_argument("--noisy-std", type=float, default=0.1) + parser.add_argument("--no-dueling", action="store_true", default=False) + parser.add_argument("--no-noisy", action="store_true", default=False) + parser.add_argument("--no-priority", action="store_true", default=False) + parser.add_argument("--alpha", type=float, default=0.5) + parser.add_argument("--beta", type=float, default=0.4) + parser.add_argument("--beta-final", type=float, default=1.0) + parser.add_argument("--beta-anneal-step", type=int, default=5000000) + parser.add_argument("--no-weight-norm", action="store_true", default=False) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.0) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument("--frames-stack", type=int, default=4) + 
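+    # checkpoint resumption and experiment logging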
parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + parser.add_argument("--save-buffer-name", type=str, default=None) + return parser.parse_args() + + +def test_rainbow(args: argparse.Namespace = get_args()) -> None: + env, train_envs, test_envs = make_atari_env( + args.task, + args.seed, + args.training_num, + args.test_num, + scale=args.scale_obs, + frame_stack=args.frames_stack, + ) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # define model + net = Rainbow( + *args.state_shape, + args.action_shape, + args.num_atoms, + args.noisy_std, + args.device, + is_dueling=not args.no_dueling, + is_noisy=not args.no_noisy, + ) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy: C51Policy = RainbowPolicy( + model=net, + optim=optim, + discount_factor=args.gamma, + action_space=env.action_space, + num_atoms=args.num_atoms, + v_min=args.v_min, + v_max=args.v_max, + estimation_step=args.n_step, + target_update_freq=args.target_update_freq, + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer + if args.no_priority: + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + else: + buffer = PrioritizedVectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + alpha=args.alpha, + beta=args.beta, + weight_norm=not args.no_weight_norm, + ) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "rainbow" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + logger_factory = LoggerFactoryDefault() + if args.logger == "wandb": + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards: float) -> bool: + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + if "Pong" in args.task: + return mean_rewards >= 20 + 
return False + + def train_fn(epoch: int, env_step: int) -> None: + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + if env_step % 1000 == 0: + logger.write("train/env_step", env_step, {"train/eps": eps}) + if not args.no_priority: + if env_step <= args.beta_anneal_step: + beta = args.beta - env_step / args.beta_anneal_step * (args.beta - args.beta_final) + else: + beta = args.beta_final + buffer.set_beta(beta) + if env_step % 1000 == 0: + logger.write("train/env_step", env_step, {"train/beta": beta}) + + def test_fn(epoch: int, env_step: int | None) -> None: + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch() -> None: + print("Setup test envs ...") + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = PrioritizedVectorReplayBuffer( + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + alpha=args.alpha, + beta=args.beta, + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + result.pprint_asdict() + + if args.watch: + watch() + sys.exit(0) + + # test train_collector and start filling replay buffer + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = OffpolicyTrainer( + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ).run() + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_rainbow(get_args()) diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py new file mode 100644 index 0000000000000000000000000000000000000000..4d01a88aa5ead67ae7ecbd670433e98d241a2e8c --- /dev/null +++ b/examples/atari/atari_sac.py @@ -0,0 +1,271 @@ +import argparse +import datetime +import os +import pprint +import sys + +import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import make_atari_env + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault +from tianshou.policy import DiscreteSACPolicy, ICMPolicy +from tianshou.policy.base import BasePolicy +from tianshou.trainer import OffpolicyTrainer +from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=4213) + parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--buffer-size", type=int, default=100000) + 
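[Editor's aside, not part of the diff] When `--auto-alpha` is passed, the script below replaces the fixed entropy coefficient with a `(target_entropy, log_alpha, alpha_optim)` triple so that alpha is tuned automatically; for a discrete action space with n actions the target entropy is set to 98% of the maximum entropy log(n), and alpha is optimized in log-space. A standalone sketch of that construction, assuming a 6-action space (e.g. Pong) purely for illustration:

```python
import numpy as np
import torch

n_actions = 6  # illustrative only; the script derives this from the env's action space
target_entropy = 0.98 * np.log(n_actions)              # ~1.756 nats
log_alpha = torch.zeros(1, requires_grad=True)         # alpha = exp(log_alpha), starts at 1.0
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)   # matches the --alpha-lr default
auto_alpha = (target_entropy, log_alpha, alpha_optim)  # what the script assigns to args.alpha
```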
parser.add_argument("--actor-lr", type=float, default=1e-5) + parser.add_argument("--critic-lr", type=float, default=1e-5) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--tau", type=float, default=0.005) + parser.add_argument("--alpha", type=float, default=0.05) + parser.add_argument("--auto-alpha", action="store_true", default=False) + parser.add_argument("--alpha-lr", type=float, default=3e-4) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=100000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--hidden-size", type=int, default=512) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--rew-norm", type=int, default=False) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.0) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument( + "--icm-lr-scale", + type=float, + default=0.0, + help="use intrinsic curiosity module with this lr scale", + ) + parser.add_argument( + "--icm-reward-scale", + type=float, + default=0.01, + help="scaling factor for intrinsic curiosity reward", + ) + parser.add_argument( + "--icm-forward-loss-weight", + type=float, + default=0.2, + help="weight for the forward model loss in ICM", + ) + return parser.parse_args() + + +def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: + env, train_envs, test_envs = make_atari_env( + args.task, + args.seed, + args.training_num, + args.test_num, + scale=args.scale_obs, + frame_stack=args.frames_stack, + ) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # define model + net = DQN( + *args.state_shape, + args.action_shape, + device=args.device, + features_only=True, + output_dim_added_layer=args.hidden_size, + ) + actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + critic1 = Critic(net, last_size=args.action_shape, device=args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic2 = Critic(net, last_size=args.action_shape, device=args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + # define policy + if args.auto_alpha: + target_entropy = 
0.98 * np.log(np.prod(args.action_shape)) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy: DiscreteSACPolicy | ICMPolicy + policy = DiscreteSACPolicy( + actor=actor, + actor_optim=actor_optim, + critic=critic1, + critic_optim=critic1_optim, + critic2=critic2, + critic2_optim=critic2_optim, + action_space=env.action_space, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + estimation_step=args.n_step, + ).to(args.device) + if args.icm_lr_scale > 0: + feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + action_dim = np.prod(args.action_shape) + feature_dim = feature_net.output_dim + icm_net = IntrinsicCuriosityModule( + feature_net.net, + feature_dim, + action_dim, + hidden_sizes=[args.hidden_size], + device=args.device, + ) + icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.actor_lr) + policy = ICMPolicy( + policy=policy, + model=icm_net, + optim=icm_optim, + action_space=env.action_space, + lr_scale=args.icm_lr_scale, + reward_scale=args.icm_reward_scale, + forward_loss_weight=args.icm_forward_loss_weight, + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "discrete_sac_icm" if args.icm_lr_scale > 0 else "discrete_sac" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + logger_factory = LoggerFactoryDefault() + if args.logger == "wandb": + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards: float) -> bool: + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + if "Pong" in args.task: + return mean_rewards >= 20 + return False + + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: + # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + ckpt_path = os.path.join(log_path, "checkpoint.pth") + torch.save({"model": policy.state_dict()}, ckpt_path) + return ckpt_path + + # watch agent's performance + def watch() -> None: + print("Setup test envs ...") + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = VectorReplayBuffer( + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + ) + collector = Collector(policy, test_envs, buffer, 
exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + result.pprint_asdict() + + if args.watch: + watch() + sys.exit(0) + + # test train_collector and start filling replay buffer + train_collector.reset() + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = OffpolicyTrainer( + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + resume_from_log=args.resume_id is not None, + save_checkpoint_fn=save_checkpoint_fn, + ).run() + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_discrete_sac(get_args()) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py new file mode 100644 index 0000000000000000000000000000000000000000..cf09b40eaf1a9b664a6fb96b5bbfaeba0fcad71d --- /dev/null +++ b/examples/atari/atari_sac_hl.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 + +import os +from collections.abc import Sequence + +from examples.atari.atari_network import ( + ActorFactoryAtariDQN, + IntermediateModuleFactoryAtariDQNFeatures, +) +from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.experiment import ( + DiscreteSACExperimentBuilder, + ExperimentConfig, +) +from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault +from tianshou.highlevel.params.policy_params import DiscreteSACParams +from tianshou.highlevel.params.policy_wrapper import ( + PolicyWrapperFactoryIntrinsicCuriosity, +) +from tianshou.utils import logging +from tianshou.utils.logging import datetime_tag + + +def main( + experiment_config: ExperimentConfig, + task: str = "PongNoFrameskip-v4", + scale_obs: bool = False, + buffer_size: int = 100000, + actor_lr: float = 1e-5, + critic_lr: float = 1e-5, + gamma: float = 0.99, + n_step: int = 3, + tau: float = 0.005, + alpha: float = 0.05, + auto_alpha: bool = False, + alpha_lr: float = 3e-4, + epoch: int = 100, + step_per_epoch: int = 100000, + step_per_collect: int = 10, + update_per_step: float = 0.1, + batch_size: int = 64, + hidden_sizes: Sequence[int] = (512,), + training_num: int = 10, + test_num: int = 10, + frames_stack: int = 4, + save_buffer_name: str | None = None, # TODO add support in high-level API? 
+ icm_lr_scale: float = 0.0, + icm_reward_scale: float = 0.01, + icm_forward_loss_weight: float = 0.2, +) -> None: + log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) + + sampling_config = SamplingConfig( + num_epochs=epoch, + step_per_epoch=step_per_epoch, + update_per_step=update_per_step, + batch_size=batch_size, + num_train_envs=training_num, + num_test_envs=test_num, + buffer_size=buffer_size, + step_per_collect=step_per_collect, + repeat_per_collect=None, + replay_buffer_stack_num=frames_stack, + replay_buffer_ignore_obs_next=True, + replay_buffer_save_only_last_obs=True, + ) + + env_factory = AtariEnvFactory( + task, + sampling_config.train_seed, + sampling_config.test_seed, + frames_stack, + scale=scale_obs, + ) + + builder = ( + DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config) + .with_sac_params( + DiscreteSACParams( + actor_lr=actor_lr, + critic1_lr=critic_lr, + critic2_lr=critic_lr, + gamma=gamma, + tau=tau, + alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha, + estimation_step=n_step, + ), + ) + .with_actor_factory(ActorFactoryAtariDQN(scale_obs=False, features_only=True)) + .with_common_critic_factory_use_actor() + .with_epoch_stop_callback(AtariEpochStopCallback(task)) + ) + if icm_lr_scale > 0: + builder.with_policy_wrapper_factory( + PolicyWrapperFactoryIntrinsicCuriosity( + feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), + hidden_sizes=hidden_sizes, + lr=actor_lr, + lr_scale=icm_lr_scale, + reward_scale=icm_reward_scale, + forward_loss_weight=icm_forward_loss_weight, + ), + ) + experiment = builder.build() + experiment.run(run_name=log_name) + + +if __name__ == "__main__": + logging.run_cli(main) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f274aefc64fb27c675ebb58f81edd8f0e08a0dbe --- /dev/null +++ b/examples/atari/atari_wrapper.py @@ -0,0 +1,469 @@ +# Borrow a lot from openai baselines: +# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py +import logging +import warnings +from collections import deque +from typing import Any, SupportsFloat + +import cv2 +import gymnasium as gym +import numpy as np +from gymnasium import Env + +from tianshou.env import BaseVectorEnv +from tianshou.highlevel.env import ( + EnvFactoryRegistered, + EnvMode, + EnvPoolFactory, + VectorEnvType, +) +from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext + +envpool_is_available = True +try: + import envpool +except ImportError: + envpool_is_available = False + envpool = None +log = logging.getLogger(__name__) + + +def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]: + contains_info = ( + isinstance(reset_result, tuple) + and len(reset_result) == 2 + and isinstance(reset_result[1], dict) + ) + if contains_info: + return reset_result[0], reset_result[1], contains_info + return reset_result, {}, contains_info + + +def get_space_dtype(obs_space: gym.spaces.Box) -> type[np.floating] | type[np.integer]: + obs_space_dtype: type[np.integer] | type[np.floating] + if np.issubdtype(obs_space.dtype, np.integer): + obs_space_dtype = np.integer + elif np.issubdtype(obs_space.dtype, np.floating): + obs_space_dtype = np.floating + else: + raise TypeError( + f"Unsupported observation space dtype: {obs_space.dtype}. 
" + f"This might be a bug in tianshou or gymnasium, please report it!", + ) + return obs_space_dtype + + +class NoopResetEnv(gym.Wrapper): + """Sample initial states by taking random number of no-ops on reset. + + No-op is assumed to be action 0. + + :param gym.Env env: the environment to wrap. + :param int noop_max: the maximum value of no-ops to run. + """ + + def __init__(self, env: gym.Env, noop_max: int = 30) -> None: + super().__init__(env) + self.noop_max = noop_max + self.noop_action = 0 + assert hasattr(env.unwrapped, "get_action_meanings") + assert env.unwrapped.get_action_meanings()[0] == "NOOP" + + def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: + _, info, return_info = _parse_reset_result(self.env.reset(**kwargs)) + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) + for _ in range(noops): + step_result = self.env.step(self.noop_action) + if len(step_result) == 4: + obs, rew, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) + else: + obs, rew, term, trunc, info = step_result + done = term or trunc + if done: + obs, info, _ = _parse_reset_result(self.env.reset()) + if return_info: + return obs, info + return obs, {} + + +class MaxAndSkipEnv(gym.Wrapper): + """Return only every `skip`-th frame (frameskipping) using most recent raw observations (for max pooling across time steps). + + :param gym.Env env: the environment to wrap. + :param int skip: number of `skip`-th frame. + """ + + def __init__(self, env: gym.Env, skip: int = 4) -> None: + super().__init__(env) + self._skip = skip + + def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: + """Step the environment with the given action. + + Repeat action, sum reward, and max over last observations. + """ + obs_list = [] + total_reward = 0.0 + new_step_api = False + for _ in range(self._skip): + step_result = self.env.step(action) + if len(step_result) == 4: + obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) + else: + obs, reward, term, trunc, info = step_result + done = term or trunc + new_step_api = True + obs_list.append(obs) + total_reward += float(reward) + if done: + break + max_frame = np.max(obs_list[-2:], axis=0) + if new_step_api: + return max_frame, total_reward, term, trunc, info + + return max_frame, total_reward, done, info.get("TimeLimit.truncated", False), info + + +class EpisodicLifeEnv(gym.Wrapper): + """Make end-of-life == end-of-episode, but only reset on true game over. + + It helps the value estimation. + + :param gym.Env env: the environment to wrap. 
+ """ + + def __init__(self, env: gym.Env) -> None: + super().__init__(env) + self.lives = 0 + self.was_real_done = True + self._return_info = False + + def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: + step_result = self.env.step(action) + if len(step_result) == 4: + obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) + new_step_api = False + else: + obs, reward, term, trunc, info = step_result + done = term or trunc + new_step_api = True + reward = float(reward) + self.was_real_done = done + # check current lives, make loss of life terminal, then update lives to + # handle bonus lives + assert hasattr(self.env.unwrapped, "ale") + lives = self.env.unwrapped.ale.lives() + if 0 < lives < self.lives: + # for Qbert sometimes we stay in lives == 0 condition for a few + # frames, so its important to keep lives > 0, so that we only reset + # once the environment is actually done. + done = True + term = True + self.lives = lives + if new_step_api: + return obs, reward, term, trunc, info + return obs, reward, done, info.get("TimeLimit.truncated", False), info + + def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: + """Calls the Gym environment reset, only when lives are exhausted. + + This way all states are still reachable even though lives are episodic, and + the learner need not know about any of this behind-the-scenes. + """ + if self.was_real_done: + obs, info, self._return_info = _parse_reset_result(self.env.reset(**kwargs)) + else: + # no-op step to advance from terminal/lost life state + step_result = self.env.step(0) + obs, info = step_result[0], step_result[-1] + assert hasattr(self.env.unwrapped, "ale") + self.lives = self.env.unwrapped.ale.lives() + if self._return_info: + return obs, info + return obs, {} + + +class FireResetEnv(gym.Wrapper): + """Take action on reset for environments that are fixed until firing. + + Related discussion: https://github.com/openai/baselines/issues/240. + + :param gym.Env env: the environment to wrap. + """ + + def __init__(self, env: gym.Env) -> None: + super().__init__(env) + assert hasattr(env.unwrapped, "get_action_meanings") + assert env.unwrapped.get_action_meanings()[1] == "FIRE" + assert len(env.unwrapped.get_action_meanings()) >= 3 + + def reset(self, **kwargs: Any) -> tuple[Any, dict]: + _, _, return_info = _parse_reset_result(self.env.reset(**kwargs)) + obs = self.env.step(1)[0] + return obs, {} + + +class WarpFrame(gym.ObservationWrapper): + """Warp frames to 84x84 as done in the Nature paper and later work. + + :param gym.Env env: the environment to wrap. + """ + + def __init__(self, env: gym.Env) -> None: + super().__init__(env) + self.size = 84 + obs_space = env.observation_space + assert isinstance(obs_space, gym.spaces.Box) + obs_space_dtype = get_space_dtype(obs_space) + self.observation_space = gym.spaces.Box( + low=np.min(obs_space.low), + high=np.max(obs_space.high), + shape=(self.size, self.size), + dtype=obs_space_dtype, + ) + + def observation(self, frame: np.ndarray) -> np.ndarray: + """Returns the current observation from a frame.""" + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA) + + +class ScaledFloatFrame(gym.ObservationWrapper): + """Normalize observations to 0~1. + + :param gym.Env env: the environment to wrap. 
+ """ + + def __init__(self, env: gym.Env) -> None: + super().__init__(env) + obs_space = env.observation_space + assert isinstance(obs_space, gym.spaces.Box) + low = np.min(obs_space.low) + high = np.max(obs_space.high) + self.bias = low + self.scale = high - low + self.observation_space = gym.spaces.Box( + low=0.0, + high=1.0, + shape=obs_space.shape, + dtype=np.float32, + ) + + def observation(self, observation: np.ndarray) -> np.ndarray: + return (observation - self.bias) / self.scale + + +class ClipRewardEnv(gym.RewardWrapper): + """clips the reward to {+1, 0, -1} by its sign. + + :param gym.Env env: the environment to wrap. + """ + + def __init__(self, env: gym.Env) -> None: + super().__init__(env) + self.reward_range = (-1, 1) + + def reward(self, reward: SupportsFloat) -> int: + """Bin reward to {+1, 0, -1} by its sign. Note: np.sign(0) == 0.""" + return np.sign(float(reward)) + + +class FrameStack(gym.Wrapper): + """Stack n_frames last frames. + + :param gym.Env env: the environment to wrap. + :param int n_frames: the number of frames to stack. + """ + + def __init__(self, env: gym.Env, n_frames: int) -> None: + super().__init__(env) + self.n_frames: int = n_frames + self.frames: deque[tuple[Any, ...]] = deque([], maxlen=n_frames) + obs_space = env.observation_space + obs_space_shape = env.observation_space.shape + assert obs_space_shape is not None + shape = (n_frames, *obs_space_shape) + assert isinstance(obs_space, gym.spaces.Box) + obs_space_dtype = get_space_dtype(obs_space) + self.observation_space = gym.spaces.Box( + low=np.min(obs_space.low), + high=np.max(obs_space.high), + shape=shape, + dtype=obs_space_dtype, + ) + + def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: + obs, info, return_info = _parse_reset_result(self.env.reset(**kwargs)) + for _ in range(self.n_frames): + self.frames.append(obs) + return (self._get_ob(), info) if return_info else (self._get_ob(), {}) + + def step(self, action: Any) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: + step_result = self.env.step(action) + done: bool + if len(step_result) == 4: + obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) + new_step_api = False + else: + obs, reward, term, trunc, info = step_result + new_step_api = True + self.frames.append(obs) + reward = float(reward) + if new_step_api: + return self._get_ob(), reward, term, trunc, info + return self._get_ob(), reward, done, info.get("TimeLimit.truncated", False), info + + def _get_ob(self) -> np.ndarray: + # the original wrapper use `LazyFrames` but since we use np buffer, + # it has no effect + return np.stack(self.frames, axis=0) + + +def wrap_deepmind( + env: gym.Env, + episode_life: bool = True, + clip_rewards: bool = True, + frame_stack: int = 4, + scale: bool = False, + warp_frame: bool = True, +) -> ( + MaxAndSkipEnv + | EpisodicLifeEnv + | FireResetEnv + | WarpFrame + | ScaledFloatFrame + | ClipRewardEnv + | FrameStack +): + """Configure environment for DeepMind-style Atari. + + The observation is channel-first: (c, h, w) instead of (h, w, c). + + :param env: the Atari environment to wrap. + :param bool episode_life: wrap the episode life wrapper. + :param bool clip_rewards: wrap the reward clipping wrapper. + :param int frame_stack: wrap the frame stacking wrapper. + :param bool scale: wrap the scaling observation wrapper. + :param bool warp_frame: wrap the grayscale + resize observation wrapper. + :return: the wrapped atari environment. 
+ """ + env = NoopResetEnv(env, noop_max=30) + env = MaxAndSkipEnv(env, skip=4) + assert hasattr(env.unwrapped, "get_action_meanings") # for mypy + + wrapped_env: MaxAndSkipEnv | EpisodicLifeEnv | FireResetEnv | WarpFrame | ScaledFloatFrame | ClipRewardEnv | FrameStack = ( + env + ) + if episode_life: + wrapped_env = EpisodicLifeEnv(wrapped_env) + if "FIRE" in env.unwrapped.get_action_meanings(): + wrapped_env = FireResetEnv(wrapped_env) + if warp_frame: + wrapped_env = WarpFrame(wrapped_env) + if scale: + wrapped_env = ScaledFloatFrame(wrapped_env) + if clip_rewards: + wrapped_env = ClipRewardEnv(wrapped_env) + if frame_stack: + wrapped_env = FrameStack(wrapped_env, frame_stack) + return wrapped_env + + +def make_atari_env( + task: str, + seed: int, + training_num: int, + test_num: int, + scale: int | bool = False, + frame_stack: int = 4, +) -> tuple[Env, BaseVectorEnv, BaseVectorEnv]: + """Wrapper function for Atari env. + + If EnvPool is installed, it will automatically switch to EnvPool's Atari env. + + :return: a tuple of (single env, training envs, test envs). + """ + env_factory = AtariEnvFactory(task, seed, seed + training_num, frame_stack, scale=bool(scale)) + envs = env_factory.create_envs(training_num, test_num) + return envs.env, envs.train_envs, envs.test_envs + + +class AtariEnvFactory(EnvFactoryRegistered): + def __init__( + self, + task: str, + train_seed: int, + test_seed: int, + frame_stack: int, + scale: bool = False, + use_envpool_if_available: bool = True, + venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO, + ) -> None: + assert "NoFrameskip" in task + self.frame_stack = frame_stack + self.scale = scale + envpool_factory = None + if use_envpool_if_available: + if envpool_is_available: + envpool_factory = self.EnvPoolFactoryAtari(self) + log.info("Using envpool, because it available") + else: + log.info("Not using envpool, because it is not available") + super().__init__( + task=task, + train_seed=train_seed, + test_seed=test_seed, + venv_type=venv_type, + envpool_factory=envpool_factory, + ) + + def create_env(self, mode: EnvMode) -> gym.Env: + env = super().create_env(mode) + is_train = mode == EnvMode.TRAIN + return wrap_deepmind( + env, + episode_life=is_train, + clip_rewards=is_train, + frame_stack=self.frame_stack, + scale=self.scale, + ) + + class EnvPoolFactoryAtari(EnvPoolFactory): + """Atari-specific envpool creation. + Since envpool internally handles the functions that are implemented through the wrappers in `wrap_deepmind`, + it sets the creation keyword arguments accordingly. + """ + + def __init__(self, parent: "AtariEnvFactory") -> None: + self.parent = parent + if self.parent.scale: + warnings.warn( + "EnvPool does not include ScaledFloatFrame wrapper, " + "please compensate by scaling inside your network's forward function (e.g. 
`x = x / 255.0` for Atari)", + ) + + def _transform_task(self, task: str) -> str: + task = super()._transform_task(task) + # TODO: Maybe warn user, explain why this is needed + return task.replace("NoFrameskip-v4", "-v5") + + def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict: + kwargs = super()._transform_kwargs(kwargs, mode) + is_train = mode == EnvMode.TRAIN + kwargs["reward_clip"] = is_train + kwargs["episodic_life"] = is_train + kwargs["stack_num"] = self.parent.frame_stack + return kwargs + + +class AtariEpochStopCallback(EpochStopCallback): + def __init__(self, task: str) -> None: + self.task = task + + def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: + env = context.envs.env + if env.spec and env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + if "Pong" in self.task: + return mean_rewards >= 20 + return False diff --git a/examples/atari/benchmark/BreakoutNoFrameskip-v4/result.json b/examples/atari/benchmark/BreakoutNoFrameskip-v4/result.json new file mode 100644 index 0000000000000000000000000000000000000000..2b326e9a74620f0e6cdc1b2c2ca1ee1ac508af79 --- /dev/null +++ b/examples/atari/benchmark/BreakoutNoFrameskip-v4/result.json @@ -0,0 +1 @@ +[{"env_step": 0, "rew": 1.489999993145466, "rew_std": 1.4842169361904596, "Agent": "c51"}, {"env_step": 100000, "rew": 7.160000157356262, "rew_std": 3.2809146348741707, "Agent": "c51"}, {"env_step": 200000, "rew": 17.44999990463257, "rew_std": 3.923327599934239, "Agent": "c51"}, {"env_step": 300000, "rew": 17.149999713897706, "rew_std": 4.067001119498604, "Agent": "c51"}, {"env_step": 400000, "rew": 25.739999961853027, "rew_std": 3.0522777522079583, "Agent": "c51"}, {"env_step": 500000, "rew": 31.940000152587892, "rew_std": 6.872292668284413, "Agent": "c51"}, {"env_step": 600000, "rew": 36.10000019073486, "rew_std": 4.652311404432131, "Agent": "c51"}, {"env_step": 700000, "rew": 44.069999885559085, "rew_std": 8.63215520480072, "Agent": "c51"}, {"env_step": 800000, "rew": 54.52999992370606, "rew_std": 6.331357911285241, "Agent": "c51"}, {"env_step": 900000, "rew": 75.43000068664551, "rew_std": 22.26365004429836, "Agent": "c51"}, {"env_step": 1000000, "rew": 73.68000030517578, "rew_std": 23.26932736029952, "Agent": "c51"}, {"env_step": 1100000, "rew": 116.59999771118164, "rew_std": 33.29786741329967, "Agent": "c51"}, {"env_step": 1200000, "rew": 118.80000152587891, "rew_std": 24.21891832224786, "Agent": "c51"}, {"env_step": 1300000, "rew": 149.72999954223633, "rew_std": 34.80965616483245, "Agent": "c51"}, {"env_step": 1400000, "rew": 155.19000167846679, "rew_std": 59.97181996860956, "Agent": "c51"}, {"env_step": 1500000, "rew": 227.85000076293946, "rew_std": 74.72283777482421, "Agent": "c51"}, {"env_step": 1600000, "rew": 224.7099994659424, "rew_std": 81.33929570370834, "Agent": "c51"}, {"env_step": 1700000, "rew": 269.3300064086914, "rew_std": 39.881227624530716, "Agent": "c51"}, {"env_step": 1800000, "rew": 294.8, "rew_std": 48.58571966090773, "Agent": "c51"}, {"env_step": 1900000, "rew": 299.05999450683595, "rew_std": 40.756084489018896, "Agent": "c51"}, {"env_step": 2000000, "rew": 306.51000213623047, "rew_std": 39.86726833847238, "Agent": "c51"}, {"env_step": 2100000, "rew": 277.07000122070315, "rew_std": 52.86110130603287, "Agent": "c51"}, {"env_step": 2200000, "rew": 329.5400024414063, "rew_std": 32.50145093140574, "Agent": "c51"}, {"env_step": 2300000, "rew": 337.4499984741211, "rew_std": 44.79243428344727, "Agent": "c51"}, {"env_step": 2400000, "rew": 
364.02000122070314, "rew_std": 18.57825612929554, "Agent": "c51"}, {"env_step": 2500000, "rew": 321.4100006103516, "rew_std": 43.22805389464291, "Agent": "c51"}, {"env_step": 2600000, "rew": 361.55, "rew_std": 27.54516925199673, "Agent": "c51"}, {"env_step": 2700000, "rew": 333.14000244140624, "rew_std": 47.81977169887054, "Agent": "c51"}, {"env_step": 2800000, "rew": 322.52000579833987, "rew_std": 76.32684991482328, "Agent": "c51"}, {"env_step": 2900000, "rew": 330.7200012207031, "rew_std": 64.05682960403955, "Agent": "c51"}, {"env_step": 3000000, "rew": 365.2100006103516, "rew_std": 15.082475188932436, "Agent": "c51"}, {"env_step": 3100000, "rew": 355.52999725341795, "rew_std": 49.67731822769064, "Agent": "c51"}, {"env_step": 3200000, "rew": 365.6499969482422, "rew_std": 50.31155528215873, "Agent": "c51"}, {"env_step": 3300000, "rew": 346.89000091552737, "rew_std": 37.82696142293252, "Agent": "c51"}, {"env_step": 3400000, "rew": 337.4800048828125, "rew_std": 62.96080895743974, "Agent": "c51"}, {"env_step": 3500000, "rew": 362.2200012207031, "rew_std": 35.653072714280874, "Agent": "c51"}, {"env_step": 3600000, "rew": 333.90000305175784, "rew_std": 81.13762513523231, "Agent": "c51"}, {"env_step": 3700000, "rew": 376.25, "rew_std": 31.655335448058853, "Agent": "c51"}, {"env_step": 3800000, "rew": 362.44000244140625, "rew_std": 18.283176431620095, "Agent": "c51"}, {"env_step": 3900000, "rew": 366.27000427246094, "rew_std": 35.04446188070728, "Agent": "c51"}, {"env_step": 4000000, "rew": 382.57000122070315, "rew_std": 14.872595676225728, "Agent": "c51"}, {"env_step": 4100000, "rew": 359.1599945068359, "rew_std": 31.658272094999887, "Agent": "c51"}, {"env_step": 4200000, "rew": 360.13999938964844, "rew_std": 25.854099915610277, "Agent": "c51"}, {"env_step": 4300000, "rew": 358.4100006103516, "rew_std": 45.640715079962746, "Agent": "c51"}, {"env_step": 4400000, "rew": 375.90000305175784, "rew_std": 31.782793945259574, "Agent": "c51"}, {"env_step": 4500000, "rew": 357.9, "rew_std": 39.46456467194559, "Agent": "c51"}, {"env_step": 4600000, "rew": 403.4100036621094, "rew_std": 37.36884707172226, "Agent": "c51"}, {"env_step": 4700000, "rew": 375.38999938964844, "rew_std": 24.63223304182013, "Agent": "c51"}, {"env_step": 4800000, "rew": 345.60999908447263, "rew_std": 64.29045939613836, "Agent": "c51"}, {"env_step": 4900000, "rew": 369.19000244140625, "rew_std": 40.53828808279498, "Agent": "c51"}, {"env_step": 5000000, "rew": 328.8799976348877, "rew_std": 98.52224821427203, "Agent": "c51"}, {"env_step": 5100000, "rew": 385.2899993896484, "rew_std": 33.74096429736962, "Agent": "c51"}, {"env_step": 5200000, "rew": 378.63999938964844, "rew_std": 23.966530681765192, "Agent": "c51"}, {"env_step": 5300000, "rew": 358.99000244140626, "rew_std": 35.26937456390392, "Agent": "c51"}, {"env_step": 5400000, "rew": 367.1299987792969, "rew_std": 37.816190250907496, "Agent": "c51"}, {"env_step": 5500000, "rew": 374.85, "rew_std": 48.78137438800481, "Agent": "c51"}, {"env_step": 5600000, "rew": 396.31000366210935, "rew_std": 16.84918405865486, "Agent": "c51"}, {"env_step": 5700000, "rew": 392.25, "rew_std": 17.89766911339113, "Agent": "c51"}, {"env_step": 5800000, "rew": 371.0399993896484, "rew_std": 39.75309508083556, "Agent": "c51"}, {"env_step": 5900000, "rew": 361.20999908447266, "rew_std": 66.4414985994371, "Agent": "c51"}, {"env_step": 6000000, "rew": 379.6600006103516, "rew_std": 33.6160148677817, "Agent": "c51"}, {"env_step": 6100000, "rew": 376.65, "rew_std": 38.38656729834399, "Agent": "c51"}, {"env_step": 
6200000, "rew": 401.14000244140624, "rew_std": 25.221350146221273, "Agent": "c51"}, {"env_step": 6300000, "rew": 373.94000244140625, "rew_std": 43.70732540398474, "Agent": "c51"}, {"env_step": 6400000, "rew": 383.6600006103516, "rew_std": 21.5665598368943, "Agent": "c51"}, {"env_step": 6500000, "rew": 396.2099945068359, "rew_std": 23.735648538855735, "Agent": "c51"}, {"env_step": 6600000, "rew": 379.5899993896484, "rew_std": 37.21383152538017, "Agent": "c51"}, {"env_step": 6700000, "rew": 369.11000366210936, "rew_std": 33.04576961471482, "Agent": "c51"}, {"env_step": 6800000, "rew": 384.3999969482422, "rew_std": 43.737378807281985, "Agent": "c51"}, {"env_step": 6900000, "rew": 392.86000366210936, "rew_std": 25.008686655391916, "Agent": "c51"}, {"env_step": 7000000, "rew": 388.18999938964845, "rew_std": 41.67885269592273, "Agent": "c51"}, {"env_step": 7100000, "rew": 400.7399963378906, "rew_std": 31.629359514939143, "Agent": "c51"}, {"env_step": 7200000, "rew": 392.60999755859376, "rew_std": 31.733938484043435, "Agent": "c51"}, {"env_step": 7300000, "rew": 381.43999633789065, "rew_std": 28.448867547263603, "Agent": "c51"}, {"env_step": 7400000, "rew": 383.7099975585937, "rew_std": 27.319865345659487, "Agent": "c51"}, {"env_step": 7500000, "rew": 400.85, "rew_std": 24.355008972204978, "Agent": "c51"}, {"env_step": 7600000, "rew": 370.6299987792969, "rew_std": 61.18848259333839, "Agent": "c51"}, {"env_step": 7700000, "rew": 377.75999755859374, "rew_std": 29.30734371334747, "Agent": "c51"}, {"env_step": 7800000, "rew": 366.9, "rew_std": 61.976462247217256, "Agent": "c51"}, {"env_step": 7900000, "rew": 381.92999877929685, "rew_std": 42.68440563855683, "Agent": "c51"}, {"env_step": 8000000, "rew": 389.1, "rew_std": 30.273154962392713, "Agent": "c51"}, {"env_step": 8100000, "rew": 377.89000244140624, "rew_std": 64.2143809894212, "Agent": "c51"}, {"env_step": 8200000, "rew": 369.98999786376953, "rew_std": 51.384270079498556, "Agent": "c51"}, {"env_step": 8300000, "rew": 363.1800018310547, "rew_std": 53.85413212986565, "Agent": "c51"}, {"env_step": 8400000, "rew": 402.32000122070315, "rew_std": 47.67640870874529, "Agent": "c51"}, {"env_step": 8500000, "rew": 412.0899993896484, "rew_std": 49.33683150503912, "Agent": "c51"}, {"env_step": 8600000, "rew": 362.84000396728516, "rew_std": 95.5870313016572, "Agent": "c51"}, {"env_step": 8700000, "rew": 374.8699951171875, "rew_std": 35.3107884710375, "Agent": "c51"}, {"env_step": 8800000, "rew": 387.94000244140625, "rew_std": 45.46937780145519, "Agent": "c51"}, {"env_step": 8900000, "rew": 390.63999938964844, "rew_std": 27.526135712106345, "Agent": "c51"}, {"env_step": 9000000, "rew": 389.9, "rew_std": 36.79899479363974, "Agent": "c51"}, {"env_step": 9100000, "rew": 401.82000122070315, "rew_std": 28.145932045294533, "Agent": "c51"}, {"env_step": 9200000, "rew": 378.2200042724609, "rew_std": 44.36847920124076, "Agent": "c51"}, {"env_step": 9300000, "rew": 378.5300048828125, "rew_std": 20.175735548110744, "Agent": "c51"}, {"env_step": 9400000, "rew": 411.37000427246096, "rew_std": 26.821631968299695, "Agent": "c51"}, {"env_step": 9500000, "rew": 398.6599945068359, "rew_std": 36.35720540569888, "Agent": "c51"}, {"env_step": 9600000, "rew": 386.3699981689453, "rew_std": 28.54824021627022, "Agent": "c51"}, {"env_step": 9700000, "rew": 389.52000122070314, "rew_std": 61.006069637114855, "Agent": "c51"}, {"env_step": 9800000, "rew": 412.8700012207031, "rew_std": 35.838668839394245, "Agent": "c51"}, {"env_step": 9900000, "rew": 399.7899993896484, "rew_std": 
36.742630957118074, "Agent": "c51"}, {"env_step": 10000000, "rew": 393.1800048828125, "rew_std": 20.823774222066966, "Agent": "c51"}, {"env_step": 0, "rew": 2.3299999952316286, "rew_std": 1.0705605633034625, "Agent": "dqn"}, {"env_step": 100000, "rew": 7.849999904632568, "rew_std": 1.8200273963417377, "Agent": "dqn"}, {"env_step": 200000, "rew": 17.30000009536743, "rew_std": 9.143740794241733, "Agent": "dqn"}, {"env_step": 300000, "rew": 20.85, "rew_std": 4.622391089254499, "Agent": "dqn"}, {"env_step": 400000, "rew": 20.7, "rew_std": 5.0768098447216925, "Agent": "dqn"}, {"env_step": 500000, "rew": 25.080000114440917, "rew_std": 3.1577840761532077, "Agent": "dqn"}, {"env_step": 600000, "rew": 28.130000114440918, "rew_std": 3.998262178859403, "Agent": "dqn"}, {"env_step": 700000, "rew": 35.89000053405762, "rew_std": 6.575477441314717, "Agent": "dqn"}, {"env_step": 800000, "rew": 40.44000015258789, "rew_std": 3.872260440437905, "Agent": "dqn"}, {"env_step": 900000, "rew": 39.329999923706055, "rew_std": 8.570536398642876, "Agent": "dqn"}, {"env_step": 1000000, "rew": 41.73000030517578, "rew_std": 9.01854193615441, "Agent": "dqn"}, {"env_step": 1100000, "rew": 46.03999996185303, "rew_std": 13.705852687158178, "Agent": "dqn"}, {"env_step": 1200000, "rew": 46.45000019073486, "rew_std": 8.782511298663483, "Agent": "dqn"}, {"env_step": 1300000, "rew": 47.709999084472656, "rew_std": 8.820708473452349, "Agent": "dqn"}, {"env_step": 1400000, "rew": 53.69000015258789, "rew_std": 8.472597260377047, "Agent": "dqn"}, {"env_step": 1500000, "rew": 56.78999996185303, "rew_std": 19.832672991347415, "Agent": "dqn"}, {"env_step": 1600000, "rew": 47.09000072479248, "rew_std": 14.559289729405599, "Agent": "dqn"}, {"env_step": 1700000, "rew": 51.7899995803833, "rew_std": 15.683395606454836, "Agent": "dqn"}, {"env_step": 1800000, "rew": 49.360000419616696, "rew_std": 14.638525024332282, "Agent": "dqn"}, {"env_step": 1900000, "rew": 52.28000049591064, "rew_std": 13.878459272042756, "Agent": "dqn"}, {"env_step": 2000000, "rew": 50.32000026702881, "rew_std": 12.656760859312767, "Agent": "dqn"}, {"env_step": 2100000, "rew": 50.69000015258789, "rew_std": 7.424345019289965, "Agent": "dqn"}, {"env_step": 2200000, "rew": 57.699999237060545, "rew_std": 26.776631520034286, "Agent": "dqn"}, {"env_step": 2300000, "rew": 51.079999923706055, "rew_std": 16.490165999706548, "Agent": "dqn"}, {"env_step": 2400000, "rew": 56.36000061035156, "rew_std": 10.570827964178985, "Agent": "dqn"}, {"env_step": 2500000, "rew": 61.59000091552734, "rew_std": 20.03838606299226, "Agent": "dqn"}, {"env_step": 2600000, "rew": 49.81999969482422, "rew_std": 11.181932063887603, "Agent": "dqn"}, {"env_step": 2700000, "rew": 50.91000061035156, "rew_std": 6.640550189836024, "Agent": "dqn"}, {"env_step": 2800000, "rew": 51.87000007629395, "rew_std": 16.15952039351572, "Agent": "dqn"}, {"env_step": 2900000, "rew": 55.97000026702881, "rew_std": 12.618800305474606, "Agent": "dqn"}, {"env_step": 3000000, "rew": 62.70000076293945, "rew_std": 15.349006905504705, "Agent": "dqn"}, {"env_step": 3100000, "rew": 62.3, "rew_std": 7.967809614396711, "Agent": "dqn"}, {"env_step": 3200000, "rew": 57.64999961853027, "rew_std": 14.362815684542023, "Agent": "dqn"}, {"env_step": 3300000, "rew": 63.35, "rew_std": 17.445471857828533, "Agent": "dqn"}, {"env_step": 3400000, "rew": 61.999999618530275, "rew_std": 9.367175853731641, "Agent": "dqn"}, {"env_step": 3500000, "rew": 58.58999938964844, "rew_std": 7.702006216494004, "Agent": "dqn"}, {"env_step": 3600000, "rew": 
59.83000030517578, "rew_std": 11.262597633065775, "Agent": "dqn"}, {"env_step": 3700000, "rew": 66.31999969482422, "rew_std": 12.887497638750386, "Agent": "dqn"}, {"env_step": 3800000, "rew": 60.06000022888183, "rew_std": 8.369014666496854, "Agent": "dqn"}, {"env_step": 3900000, "rew": 68.73999938964843, "rew_std": 16.99924625009274, "Agent": "dqn"}, {"env_step": 4000000, "rew": 65.21999893188476, "rew_std": 13.805780604536238, "Agent": "dqn"}, {"env_step": 4100000, "rew": 65.03999938964844, "rew_std": 24.91546420245774, "Agent": "dqn"}, {"env_step": 4200000, "rew": 57.69999961853027, "rew_std": 19.0262977849217, "Agent": "dqn"}, {"env_step": 4300000, "rew": 72.17000045776368, "rew_std": 17.75004520954237, "Agent": "dqn"}, {"env_step": 4400000, "rew": 69.8600009918213, "rew_std": 6.844588638105477, "Agent": "dqn"}, {"env_step": 4500000, "rew": 61.260000228881836, "rew_std": 11.948741149739522, "Agent": "dqn"}, {"env_step": 4600000, "rew": 67.55, "rew_std": 8.807752424830234, "Agent": "dqn"}, {"env_step": 4700000, "rew": 70.23999977111816, "rew_std": 8.630087095950666, "Agent": "dqn"}, {"env_step": 4800000, "rew": 72.41999969482421, "rew_std": 12.884859861619276, "Agent": "dqn"}, {"env_step": 4900000, "rew": 73.97000045776367, "rew_std": 13.834816048918906, "Agent": "dqn"}, {"env_step": 5000000, "rew": 74.23000030517578, "rew_std": 13.696208341395574, "Agent": "dqn"}, {"env_step": 5100000, "rew": 67.39999885559082, "rew_std": 15.599165755242288, "Agent": "dqn"}, {"env_step": 5200000, "rew": 85.09999961853028, "rew_std": 31.04148759266113, "Agent": "dqn"}, {"env_step": 5300000, "rew": 69.8799991607666, "rew_std": 17.921539788764512, "Agent": "dqn"}, {"env_step": 5400000, "rew": 73.52999992370606, "rew_std": 21.51474045621979, "Agent": "dqn"}, {"env_step": 5500000, "rew": 77.47999992370606, "rew_std": 13.142054243069717, "Agent": "dqn"}, {"env_step": 5600000, "rew": 81.1799991607666, "rew_std": 20.257827947382573, "Agent": "dqn"}, {"env_step": 5700000, "rew": 91.23000030517578, "rew_std": 32.801493962015556, "Agent": "dqn"}, {"env_step": 5800000, "rew": 82.73000068664551, "rew_std": 19.65411187086091, "Agent": "dqn"}, {"env_step": 5900000, "rew": 81.49000053405761, "rew_std": 18.301501103980517, "Agent": "dqn"}, {"env_step": 6000000, "rew": 81.15999946594238, "rew_std": 17.353743105571674, "Agent": "dqn"}, {"env_step": 6100000, "rew": 78.85, "rew_std": 23.848828094928958, "Agent": "dqn"}, {"env_step": 6200000, "rew": 73.11000022888183, "rew_std": 16.69673268597132, "Agent": "dqn"}, {"env_step": 6300000, "rew": 91.60000190734863, "rew_std": 27.26793083658477, "Agent": "dqn"}, {"env_step": 6400000, "rew": 91.17999992370605, "rew_std": 17.885123036862367, "Agent": "dqn"}, {"env_step": 6500000, "rew": 85.3, "rew_std": 27.91558722418007, "Agent": "dqn"}, {"env_step": 6600000, "rew": 78.09999961853028, "rew_std": 18.799307603869114, "Agent": "dqn"}, {"env_step": 6700000, "rew": 73.22999992370606, "rew_std": 15.534029981853278, "Agent": "dqn"}, {"env_step": 6800000, "rew": 98.13000106811523, "rew_std": 20.634877163905355, "Agent": "dqn"}, {"env_step": 6900000, "rew": 89.0, "rew_std": 24.171967172300295, "Agent": "dqn"}, {"env_step": 7000000, "rew": 94.03000106811524, "rew_std": 22.807151580126355, "Agent": "dqn"}, {"env_step": 7100000, "rew": 91.46000061035156, "rew_std": 23.31502638921496, "Agent": "dqn"}, {"env_step": 7200000, "rew": 83.65999984741211, "rew_std": 16.474659802642254, "Agent": "dqn"}, {"env_step": 7300000, "rew": 80.12000045776367, "rew_std": 11.930952073355787, "Agent": "dqn"}, 
{"env_step": 7400000, "rew": 77.42999877929688, "rew_std": 19.180929987122642, "Agent": "dqn"}, {"env_step": 7500000, "rew": 86.84999923706054, "rew_std": 17.592342521566607, "Agent": "dqn"}, {"env_step": 7600000, "rew": 82.52999992370606, "rew_std": 20.004302258249783, "Agent": "dqn"}, {"env_step": 7700000, "rew": 107.37000045776367, "rew_std": 34.55555067760397, "Agent": "dqn"}, {"env_step": 7800000, "rew": 96.06000061035157, "rew_std": 22.622474586994528, "Agent": "dqn"}, {"env_step": 7900000, "rew": 94.13999977111817, "rew_std": 28.41982322868775, "Agent": "dqn"}, {"env_step": 8000000, "rew": 101.21000061035156, "rew_std": 26.587306822196613, "Agent": "dqn"}, {"env_step": 8100000, "rew": 85.82999992370605, "rew_std": 25.36777673896879, "Agent": "dqn"}, {"env_step": 8200000, "rew": 87.62999954223633, "rew_std": 39.89198041691922, "Agent": "dqn"}, {"env_step": 8300000, "rew": 101.57999839782715, "rew_std": 41.77749993495092, "Agent": "dqn"}, {"env_step": 8400000, "rew": 93.03000068664551, "rew_std": 27.115569981755712, "Agent": "dqn"}, {"env_step": 8500000, "rew": 87.7099998474121, "rew_std": 36.714669130069105, "Agent": "dqn"}, {"env_step": 8600000, "rew": 90.95000076293945, "rew_std": 12.810327317331756, "Agent": "dqn"}, {"env_step": 8700000, "rew": 104.0099998474121, "rew_std": 32.22671128294261, "Agent": "dqn"}, {"env_step": 8800000, "rew": 103.99000091552735, "rew_std": 27.8962892258018, "Agent": "dqn"}, {"env_step": 8900000, "rew": 114.25000076293945, "rew_std": 32.67464640396138, "Agent": "dqn"}, {"env_step": 9000000, "rew": 106.80000038146973, "rew_std": 32.973262202331064, "Agent": "dqn"}, {"env_step": 9100000, "rew": 98.43000030517578, "rew_std": 20.662141740244365, "Agent": "dqn"}, {"env_step": 9200000, "rew": 88.96000099182129, "rew_std": 33.323092362700145, "Agent": "dqn"}, {"env_step": 9300000, "rew": 121.35000076293946, "rew_std": 25.42248087718444, "Agent": "dqn"}, {"env_step": 9400000, "rew": 114.03000106811524, "rew_std": 31.284407594353972, "Agent": "dqn"}, {"env_step": 9500000, "rew": 125.76000061035157, "rew_std": 36.260922820967124, "Agent": "dqn"}, {"env_step": 9600000, "rew": 100.2, "rew_std": 41.19708629230755, "Agent": "dqn"}, {"env_step": 9700000, "rew": 122.76999969482422, "rew_std": 33.774755897887985, "Agent": "dqn"}, {"env_step": 9800000, "rew": 133.53999938964844, "rew_std": 44.59944213785166, "Agent": "dqn"}, {"env_step": 9900000, "rew": 111.62000198364258, "rew_std": 32.61686090929321, "Agent": "dqn"}, {"env_step": 10000000, "rew": 119.93999862670898, "rew_std": 32.37928701114132, "Agent": "dqn"}, {"env_step": 0, "rew": 1.8899999886751175, "rew_std": 1.2739309645395656, "Agent": "fqf"}, {"env_step": 100000, "rew": 6.579999995231629, "rew_std": 2.0439177558947796, "Agent": "fqf"}, {"env_step": 200000, "rew": 15.449999904632568, "rew_std": 1.902235310873203, "Agent": "fqf"}, {"env_step": 300000, "rew": 15.630000019073487, "rew_std": 1.6285271252770628, "Agent": "fqf"}, {"env_step": 400000, "rew": 20.410000228881835, "rew_std": 2.727068191039453, "Agent": "fqf"}, {"env_step": 500000, "rew": 27.5, "rew_std": 5.404812646912962, "Agent": "fqf"}, {"env_step": 600000, "rew": 34.10999984741211, "rew_std": 5.110469271319042, "Agent": "fqf"}, {"env_step": 700000, "rew": 44.479999732971194, "rew_std": 9.641866944004914, "Agent": "fqf"}, {"env_step": 800000, "rew": 58.929999923706056, "rew_std": 10.773119278904277, "Agent": "fqf"}, {"env_step": 900000, "rew": 74.28000030517578, "rew_std": 18.03883545662945, "Agent": "fqf"}, {"env_step": 1000000, "rew": 
91.27000007629394, "rew_std": 19.16142200039165, "Agent": "fqf"}, {"env_step": 1100000, "rew": 89.1100009918213, "rew_std": 13.826457342888988, "Agent": "fqf"}, {"env_step": 1200000, "rew": 94.35999984741211, "rew_std": 42.61833485770056, "Agent": "fqf"}, {"env_step": 1300000, "rew": 111.60999984741211, "rew_std": 28.842728288197964, "Agent": "fqf"}, {"env_step": 1400000, "rew": 137.2, "rew_std": 25.806123286033692, "Agent": "fqf"}, {"env_step": 1500000, "rew": 157.3999984741211, "rew_std": 31.331581823972737, "Agent": "fqf"}, {"env_step": 1600000, "rew": 166.1099998474121, "rew_std": 31.462182753954036, "Agent": "fqf"}, {"env_step": 1700000, "rew": 189.7300018310547, "rew_std": 38.91149203014107, "Agent": "fqf"}, {"env_step": 1800000, "rew": 206.3300003051758, "rew_std": 38.15114780212248, "Agent": "fqf"}, {"env_step": 1900000, "rew": 224.5, "rew_std": 38.52728870047049, "Agent": "fqf"}, {"env_step": 2000000, "rew": 222.76000061035157, "rew_std": 46.865021102153406, "Agent": "fqf"}, {"env_step": 2100000, "rew": 245.28999938964844, "rew_std": 49.371034931773835, "Agent": "fqf"}, {"env_step": 2200000, "rew": 267.70999908447266, "rew_std": 28.133131370013544, "Agent": "fqf"}, {"env_step": 2300000, "rew": 268.58999786376955, "rew_std": 49.59553152130905, "Agent": "fqf"}, {"env_step": 2400000, "rew": 247.30999908447265, "rew_std": 63.79005209586264, "Agent": "fqf"}, {"env_step": 2500000, "rew": 281.91000213623045, "rew_std": 36.61752006054675, "Agent": "fqf"}, {"env_step": 2600000, "rew": 273.11999969482423, "rew_std": 53.78060736276747, "Agent": "fqf"}, {"env_step": 2700000, "rew": 299.9100006103516, "rew_std": 30.088116602832, "Agent": "fqf"}, {"env_step": 2800000, "rew": 297.63999938964844, "rew_std": 55.8203918208197, "Agent": "fqf"}, {"env_step": 2900000, "rew": 309.13000335693357, "rew_std": 49.09570374572412, "Agent": "fqf"}, {"env_step": 3000000, "rew": 315.32999725341796, "rew_std": 43.39456033965058, "Agent": "fqf"}, {"env_step": 3100000, "rew": 279.22000274658205, "rew_std": 85.23476924161254, "Agent": "fqf"}, {"env_step": 3200000, "rew": 297.58000030517576, "rew_std": 63.087017417518, "Agent": "fqf"}, {"env_step": 3300000, "rew": 301.73999938964846, "rew_std": 49.87208136848178, "Agent": "fqf"}, {"env_step": 3400000, "rew": 275.0300010681152, "rew_std": 68.53272323238663, "Agent": "fqf"}, {"env_step": 3500000, "rew": 309.43999938964845, "rew_std": 32.276252757870395, "Agent": "fqf"}, {"env_step": 3600000, "rew": 319.67999572753905, "rew_std": 42.78061813129124, "Agent": "fqf"}, {"env_step": 3700000, "rew": 332.6600006103516, "rew_std": 33.96071915762684, "Agent": "fqf"}, {"env_step": 3800000, "rew": 355.36000061035156, "rew_std": 28.662765200015865, "Agent": "fqf"}, {"env_step": 3900000, "rew": 316.99999542236327, "rew_std": 53.55531395626551, "Agent": "fqf"}, {"env_step": 4000000, "rew": 319.78999786376954, "rew_std": 80.64893371138278, "Agent": "fqf"}, {"env_step": 4100000, "rew": 345.6599975585938, "rew_std": 25.589495375720983, "Agent": "fqf"}, {"env_step": 4200000, "rew": 338.1300048828125, "rew_std": 25.35192511248875, "Agent": "fqf"}, {"env_step": 4300000, "rew": 331.38999633789064, "rew_std": 38.176389861042445, "Agent": "fqf"}, {"env_step": 4400000, "rew": 328.5199996948242, "rew_std": 88.22566393672108, "Agent": "fqf"}, {"env_step": 4500000, "rew": 356.0900024414062, "rew_std": 26.407973243051238, "Agent": "fqf"}, {"env_step": 4600000, "rew": 330.85999755859376, "rew_std": 52.67645135396039, "Agent": "fqf"}, {"env_step": 4700000, "rew": 363.3799987792969, "rew_std": 
35.14901906852587, "Agent": "fqf"}, {"env_step": 4800000, "rew": 364.2799987792969, "rew_std": 22.464898788730974, "Agent": "fqf"}, {"env_step": 4900000, "rew": 342.62999572753904, "rew_std": 31.664115177564884, "Agent": "fqf"}, {"env_step": 5000000, "rew": 326.9600006103516, "rew_std": 44.123470427909474, "Agent": "fqf"}, {"env_step": 5100000, "rew": 342.51000061035154, "rew_std": 42.72719409460211, "Agent": "fqf"}, {"env_step": 5200000, "rew": 368.5100036621094, "rew_std": 36.07728802793028, "Agent": "fqf"}, {"env_step": 5300000, "rew": 339.4, "rew_std": 37.085738197603874, "Agent": "fqf"}, {"env_step": 5400000, "rew": 339.8399963378906, "rew_std": 29.956508084411002, "Agent": "fqf"}, {"env_step": 5500000, "rew": 329.1000015258789, "rew_std": 52.50864780548441, "Agent": "fqf"}, {"env_step": 5600000, "rew": 345.49000244140626, "rew_std": 25.85855984648126, "Agent": "fqf"}, {"env_step": 5700000, "rew": 348.35999908447263, "rew_std": 53.21148820407242, "Agent": "fqf"}, {"env_step": 5800000, "rew": 344.5499984741211, "rew_std": 46.83556422879634, "Agent": "fqf"}, {"env_step": 5900000, "rew": 350.65999908447264, "rew_std": 54.664820479976136, "Agent": "fqf"}, {"env_step": 6000000, "rew": 346.4800033569336, "rew_std": 62.894860468597145, "Agent": "fqf"}, {"env_step": 6100000, "rew": 350.48999938964846, "rew_std": 26.15692816124583, "Agent": "fqf"}, {"env_step": 6200000, "rew": 362.8300048828125, "rew_std": 36.7557307025397, "Agent": "fqf"}, {"env_step": 6300000, "rew": 328.7800018310547, "rew_std": 47.65588972046255, "Agent": "fqf"}, {"env_step": 6400000, "rew": 364.5799987792969, "rew_std": 26.335103106220885, "Agent": "fqf"}, {"env_step": 6500000, "rew": 318.6900009155273, "rew_std": 56.799413057921996, "Agent": "fqf"}, {"env_step": 6600000, "rew": 371.51999816894534, "rew_std": 26.211210500163315, "Agent": "fqf"}, {"env_step": 6700000, "rew": 347.8500030517578, "rew_std": 32.235581327971936, "Agent": "fqf"}, {"env_step": 6800000, "rew": 372.7200042724609, "rew_std": 28.695146937057064, "Agent": "fqf"}, {"env_step": 6900000, "rew": 361.5699981689453, "rew_std": 43.78193476682266, "Agent": "fqf"}, {"env_step": 7000000, "rew": 317.13000030517577, "rew_std": 49.348357840104256, "Agent": "fqf"}, {"env_step": 7100000, "rew": 340.74000091552733, "rew_std": 77.68987471874023, "Agent": "fqf"}, {"env_step": 7200000, "rew": 344.6399978637695, "rew_std": 59.30083036122553, "Agent": "fqf"}, {"env_step": 7300000, "rew": 341.8500015258789, "rew_std": 40.51899094909766, "Agent": "fqf"}, {"env_step": 7400000, "rew": 359.8500030517578, "rew_std": 20.46920835454293, "Agent": "fqf"}, {"env_step": 7500000, "rew": 342.1299987792969, "rew_std": 52.97516510033074, "Agent": "fqf"}, {"env_step": 7600000, "rew": 348.57000122070315, "rew_std": 27.585034799201082, "Agent": "fqf"}, {"env_step": 7700000, "rew": 347.2500030517578, "rew_std": 27.34297686518728, "Agent": "fqf"}, {"env_step": 7800000, "rew": 345.8099990844727, "rew_std": 71.28503787650662, "Agent": "fqf"}, {"env_step": 7900000, "rew": 355.4199981689453, "rew_std": 48.30798976082001, "Agent": "fqf"}, {"env_step": 8000000, "rew": 373.85, "rew_std": 25.45906797131143, "Agent": "fqf"}, {"env_step": 8100000, "rew": 346.7200012207031, "rew_std": 34.270798637584605, "Agent": "fqf"}, {"env_step": 8200000, "rew": 365.01000213623047, "rew_std": 59.017802669324645, "Agent": "fqf"}, {"env_step": 8300000, "rew": 334.4899932861328, "rew_std": 60.103334841706335, "Agent": "fqf"}, {"env_step": 8400000, "rew": 355.2200012207031, "rew_std": 30.91270427542967, "Agent": "fqf"}, 
{"env_step": 8500000, "rew": 359.00999755859374, "rew_std": 34.8320641579841, "Agent": "fqf"}, {"env_step": 8600000, "rew": 336.2000045776367, "rew_std": 45.92332986368791, "Agent": "fqf"}, {"env_step": 8700000, "rew": 371.00999755859374, "rew_std": 35.20410344429364, "Agent": "fqf"}, {"env_step": 8800000, "rew": 345.88999938964844, "rew_std": 27.24296297805744, "Agent": "fqf"}, {"env_step": 8900000, "rew": 327.00999603271487, "rew_std": 62.46383398874547, "Agent": "fqf"}, {"env_step": 9000000, "rew": 367.0299957275391, "rew_std": 24.968425234757614, "Agent": "fqf"}, {"env_step": 9100000, "rew": 357.7699951171875, "rew_std": 31.722642179225197, "Agent": "fqf"}, {"env_step": 9200000, "rew": 351.0800018310547, "rew_std": 40.11283828329421, "Agent": "fqf"}, {"env_step": 9300000, "rew": 363.8500030517578, "rew_std": 21.079715798651307, "Agent": "fqf"}, {"env_step": 9400000, "rew": 342.8400054931641, "rew_std": 63.49006749721425, "Agent": "fqf"}, {"env_step": 9500000, "rew": 373.99000091552733, "rew_std": 45.288134172835484, "Agent": "fqf"}, {"env_step": 9600000, "rew": 356.8399993896484, "rew_std": 47.543644666957135, "Agent": "fqf"}, {"env_step": 9700000, "rew": 364.5300018310547, "rew_std": 23.580712727396694, "Agent": "fqf"}, {"env_step": 9800000, "rew": 382.63999633789064, "rew_std": 29.47945522804052, "Agent": "fqf"}, {"env_step": 9900000, "rew": 353.19000244140625, "rew_std": 41.2007373443112, "Agent": "fqf"}, {"env_step": 10000000, "rew": 335.73000183105466, "rew_std": 36.178617059142844, "Agent": "fqf"}, {"env_step": 0, "rew": 2.0200000025331972, "rew_std": 1.062826409813908, "Agent": "qrdqn"}, {"env_step": 100000, "rew": 11.1, "rew_std": 1.8363006019761408, "Agent": "qrdqn"}, {"env_step": 200000, "rew": 15.950000190734864, "rew_std": 2.2743131208480416, "Agent": "qrdqn"}, {"env_step": 300000, "rew": 21.6100004196167, "rew_std": 1.6908279939018895, "Agent": "qrdqn"}, {"env_step": 400000, "rew": 26.039999961853027, "rew_std": 3.166764513719337, "Agent": "qrdqn"}, {"env_step": 500000, "rew": 31.70999984741211, "rew_std": 4.500988535661596, "Agent": "qrdqn"}, {"env_step": 600000, "rew": 34.890000343322754, "rew_std": 8.884306606312624, "Agent": "qrdqn"}, {"env_step": 700000, "rew": 39.35000057220459, "rew_std": 11.433481420905531, "Agent": "qrdqn"}, {"env_step": 800000, "rew": 45.01999969482422, "rew_std": 7.258622758189306, "Agent": "qrdqn"}, {"env_step": 900000, "rew": 51.72000007629394, "rew_std": 12.04224138466468, "Agent": "qrdqn"}, {"env_step": 1000000, "rew": 48.96999988555908, "rew_std": 9.053513212898372, "Agent": "qrdqn"}, {"env_step": 1100000, "rew": 66.27000007629394, "rew_std": 17.327149763903382, "Agent": "qrdqn"}, {"env_step": 1200000, "rew": 67.5099998474121, "rew_std": 15.202332343099394, "Agent": "qrdqn"}, {"env_step": 1300000, "rew": 47.89999961853027, "rew_std": 13.158266496317525, "Agent": "qrdqn"}, {"env_step": 1400000, "rew": 54.50000057220459, "rew_std": 12.41821272111614, "Agent": "qrdqn"}, {"env_step": 1500000, "rew": 71.68000030517578, "rew_std": 12.78567937628033, "Agent": "qrdqn"}, {"env_step": 1600000, "rew": 65.53000030517578, "rew_std": 15.931104640556512, "Agent": "qrdqn"}, {"env_step": 1700000, "rew": 65.3899990081787, "rew_std": 14.402253328610442, "Agent": "qrdqn"}, {"env_step": 1800000, "rew": 69.67999954223633, "rew_std": 13.667245828132831, "Agent": "qrdqn"}, {"env_step": 1900000, "rew": 104.74000091552735, "rew_std": 41.49964546688277, "Agent": "qrdqn"}, {"env_step": 2000000, "rew": 64.65, "rew_std": 15.356839508305148, "Agent": "qrdqn"}, 
{"env_step": 2100000, "rew": 67.09999961853028, "rew_std": 20.951228029225863, "Agent": "qrdqn"}, {"env_step": 2200000, "rew": 81.0700008392334, "rew_std": 28.43526185117888, "Agent": "qrdqn"}, {"env_step": 2300000, "rew": 69.75000076293945, "rew_std": 15.399562151245634, "Agent": "qrdqn"}, {"env_step": 2400000, "rew": 93.77000045776367, "rew_std": 48.462461876453574, "Agent": "qrdqn"}, {"env_step": 2500000, "rew": 90.25000076293945, "rew_std": 37.74602613741299, "Agent": "qrdqn"}, {"env_step": 2600000, "rew": 90.18000106811523, "rew_std": 29.855881757158723, "Agent": "qrdqn"}, {"env_step": 2700000, "rew": 99.36000061035156, "rew_std": 40.36459324093467, "Agent": "qrdqn"}, {"env_step": 2800000, "rew": 81.06999969482422, "rew_std": 39.89363303080445, "Agent": "qrdqn"}, {"env_step": 2900000, "rew": 95.2400001525879, "rew_std": 43.06056686630658, "Agent": "qrdqn"}, {"env_step": 3000000, "rew": 104.05000076293945, "rew_std": 28.336240349879567, "Agent": "qrdqn"}, {"env_step": 3100000, "rew": 93.90999908447266, "rew_std": 27.109681802005888, "Agent": "qrdqn"}, {"env_step": 3200000, "rew": 107.66000022888184, "rew_std": 40.67018922208847, "Agent": "qrdqn"}, {"env_step": 3300000, "rew": 93.79000053405761, "rew_std": 30.64237221598378, "Agent": "qrdqn"}, {"env_step": 3400000, "rew": 108.4900001525879, "rew_std": 41.321239097052164, "Agent": "qrdqn"}, {"env_step": 3500000, "rew": 124.85, "rew_std": 42.0317792485268, "Agent": "qrdqn"}, {"env_step": 3600000, "rew": 116.13000106811523, "rew_std": 43.9727670450532, "Agent": "qrdqn"}, {"env_step": 3700000, "rew": 111.63999977111817, "rew_std": 43.274708584109035, "Agent": "qrdqn"}, {"env_step": 3800000, "rew": 106.15, "rew_std": 45.770651300224905, "Agent": "qrdqn"}, {"env_step": 3900000, "rew": 129.54000015258788, "rew_std": 46.98308583368676, "Agent": "qrdqn"}, {"env_step": 4000000, "rew": 115.51000137329102, "rew_std": 36.304806385435626, "Agent": "qrdqn"}, {"env_step": 4100000, "rew": 129.35999908447266, "rew_std": 31.217724397411512, "Agent": "qrdqn"}, {"env_step": 4200000, "rew": 135.28999938964844, "rew_std": 26.868697494551604, "Agent": "qrdqn"}, {"env_step": 4300000, "rew": 126.12999877929687, "rew_std": 41.017608258673214, "Agent": "qrdqn"}, {"env_step": 4400000, "rew": 136.70000076293945, "rew_std": 50.228299617735615, "Agent": "qrdqn"}, {"env_step": 4500000, "rew": 120.93000183105468, "rew_std": 41.47946862067106, "Agent": "qrdqn"}, {"env_step": 4600000, "rew": 138.7300003051758, "rew_std": 50.09087919140305, "Agent": "qrdqn"}, {"env_step": 4700000, "rew": 111.62000122070313, "rew_std": 28.04349582228607, "Agent": "qrdqn"}, {"env_step": 4800000, "rew": 139.99000091552733, "rew_std": 46.49149126588381, "Agent": "qrdqn"}, {"env_step": 4900000, "rew": 141.19000167846679, "rew_std": 55.482997939665566, "Agent": "qrdqn"}, {"env_step": 5000000, "rew": 141.16000061035157, "rew_std": 29.674102239383583, "Agent": "qrdqn"}, {"env_step": 5100000, "rew": 149.54000015258788, "rew_std": 40.483557547800565, "Agent": "qrdqn"}, {"env_step": 5200000, "rew": 138.81000061035155, "rew_std": 25.27221522764106, "Agent": "qrdqn"}, {"env_step": 5300000, "rew": 154.86000289916993, "rew_std": 47.20214481743621, "Agent": "qrdqn"}, {"env_step": 5400000, "rew": 152.83000259399415, "rew_std": 35.65692378910702, "Agent": "qrdqn"}, {"env_step": 5500000, "rew": 165.04999923706055, "rew_std": 52.75731800579599, "Agent": "qrdqn"}, {"env_step": 5600000, "rew": 159.33999938964843, "rew_std": 39.30949528911902, "Agent": "qrdqn"}, {"env_step": 5700000, "rew": 157.33999938964843, 
"rew_std": 31.022483741005583, "Agent": "qrdqn"}, {"env_step": 5800000, "rew": 133.2099983215332, "rew_std": 30.196568926146423, "Agent": "qrdqn"}, {"env_step": 5900000, "rew": 154.22000122070312, "rew_std": 54.375359375832986, "Agent": "qrdqn"}, {"env_step": 6000000, "rew": 184.93000030517578, "rew_std": 30.234021002954563, "Agent": "qrdqn"}, {"env_step": 6100000, "rew": 163.85000305175782, "rew_std": 37.09256658071057, "Agent": "qrdqn"}, {"env_step": 6200000, "rew": 174.16000061035157, "rew_std": 45.05503855203985, "Agent": "qrdqn"}, {"env_step": 6300000, "rew": 134.93999862670898, "rew_std": 28.63627883914366, "Agent": "qrdqn"}, {"env_step": 6400000, "rew": 158.54000244140624, "rew_std": 58.80484980319214, "Agent": "qrdqn"}, {"env_step": 6500000, "rew": 163.03000259399414, "rew_std": 43.68508069770034, "Agent": "qrdqn"}, {"env_step": 6600000, "rew": 171.29000091552734, "rew_std": 25.585327282680602, "Agent": "qrdqn"}, {"env_step": 6700000, "rew": 175.55999755859375, "rew_std": 40.4673989589574, "Agent": "qrdqn"}, {"env_step": 6800000, "rew": 170.24999923706054, "rew_std": 40.89010527420232, "Agent": "qrdqn"}, {"env_step": 6900000, "rew": 174.58000106811522, "rew_std": 34.85977017255225, "Agent": "qrdqn"}, {"env_step": 7000000, "rew": 197.0900016784668, "rew_std": 40.22520027724406, "Agent": "qrdqn"}, {"env_step": 7100000, "rew": 185.90000076293944, "rew_std": 40.68073380527696, "Agent": "qrdqn"}, {"env_step": 7200000, "rew": 163.37999954223633, "rew_std": 29.362010642178937, "Agent": "qrdqn"}, {"env_step": 7300000, "rew": 175.2400001525879, "rew_std": 33.284537997701975, "Agent": "qrdqn"}, {"env_step": 7400000, "rew": 173.48999938964843, "rew_std": 38.9055120846502, "Agent": "qrdqn"}, {"env_step": 7500000, "rew": 180.91999893188478, "rew_std": 32.606096613717376, "Agent": "qrdqn"}, {"env_step": 7600000, "rew": 182.53000030517578, "rew_std": 39.06146447307148, "Agent": "qrdqn"}, {"env_step": 7700000, "rew": 202.53999938964844, "rew_std": 34.642085489481985, "Agent": "qrdqn"}, {"env_step": 7800000, "rew": 198.09000091552736, "rew_std": 37.14933188731783, "Agent": "qrdqn"}, {"env_step": 7900000, "rew": 173.42000045776368, "rew_std": 46.51472327478004, "Agent": "qrdqn"}, {"env_step": 8000000, "rew": 187.71000137329102, "rew_std": 42.45332867238581, "Agent": "qrdqn"}, {"env_step": 8100000, "rew": 178.81999816894532, "rew_std": 25.60713985608123, "Agent": "qrdqn"}, {"env_step": 8200000, "rew": 186.19999923706055, "rew_std": 58.60175754971055, "Agent": "qrdqn"}, {"env_step": 8300000, "rew": 183.31999893188475, "rew_std": 40.961731898305466, "Agent": "qrdqn"}, {"env_step": 8400000, "rew": 189.95, "rew_std": 19.00927386779873, "Agent": "qrdqn"}, {"env_step": 8500000, "rew": 192.83999938964843, "rew_std": 33.04204192587639, "Agent": "qrdqn"}, {"env_step": 8600000, "rew": 202.9000015258789, "rew_std": 23.599364986605206, "Agent": "qrdqn"}, {"env_step": 8700000, "rew": 186.10999908447266, "rew_std": 46.4904581707004, "Agent": "qrdqn"}, {"env_step": 8800000, "rew": 213.19000015258788, "rew_std": 44.943285014544436, "Agent": "qrdqn"}, {"env_step": 8900000, "rew": 202.45999603271486, "rew_std": 54.62318213575417, "Agent": "qrdqn"}, {"env_step": 9000000, "rew": 163.0799997329712, "rew_std": 66.0252791544666, "Agent": "qrdqn"}, {"env_step": 9100000, "rew": 189.61999969482423, "rew_std": 31.533055330348382, "Agent": "qrdqn"}, {"env_step": 9200000, "rew": 190.22999954223633, "rew_std": 44.21936226238718, "Agent": "qrdqn"}, {"env_step": 9300000, "rew": 206.0199996948242, "rew_std": 34.16480026287263, 
"Agent": "qrdqn"}, {"env_step": 9400000, "rew": 194.15, "rew_std": 31.170764052493578, "Agent": "qrdqn"}, {"env_step": 9500000, "rew": 198.09000091552736, "rew_std": 34.949348176785655, "Agent": "qrdqn"}, {"env_step": 9600000, "rew": 184.69000167846679, "rew_std": 36.43675343515362, "Agent": "qrdqn"}, {"env_step": 9700000, "rew": 177.53000259399414, "rew_std": 49.71995694653478, "Agent": "qrdqn"}, {"env_step": 9800000, "rew": 188.76999893188477, "rew_std": 54.15346713496224, "Agent": "qrdqn"}, {"env_step": 9900000, "rew": 228.30999908447265, "rew_std": 27.33559443989839, "Agent": "qrdqn"}, {"env_step": 10000000, "rew": 208.64000244140624, "rew_std": 42.412903043388575, "Agent": "qrdqn"}, {"env_step": 0, "rew": 2.0299999952316283, "rew_std": 0.6356886257408865, "Agent": "iqn"}, {"env_step": 100000, "rew": 10.880000019073487, "rew_std": 1.0235234666900987, "Agent": "iqn"}, {"env_step": 200000, "rew": 16.200000095367432, "rew_std": 2.41453951471406, "Agent": "iqn"}, {"env_step": 300000, "rew": 19.000000286102296, "rew_std": 3.212164382918219, "Agent": "iqn"}, {"env_step": 400000, "rew": 24.710000038146973, "rew_std": 4.045108260414139, "Agent": "iqn"}, {"env_step": 500000, "rew": 36.970000457763675, "rew_std": 6.761220184785156, "Agent": "iqn"}, {"env_step": 600000, "rew": 57.630000305175784, "rew_std": 17.247496577908297, "Agent": "iqn"}, {"env_step": 700000, "rew": 59.85, "rew_std": 15.685677915496768, "Agent": "iqn"}, {"env_step": 800000, "rew": 79.85, "rew_std": 15.545884075479117, "Agent": "iqn"}, {"env_step": 900000, "rew": 85.95, "rew_std": 18.034815034946483, "Agent": "iqn"}, {"env_step": 1000000, "rew": 105.31999969482422, "rew_std": 23.10228556713175, "Agent": "iqn"}, {"env_step": 1100000, "rew": 114.6, "rew_std": 42.868006897695565, "Agent": "iqn"}, {"env_step": 1200000, "rew": 100.68000030517578, "rew_std": 21.099801867092722, "Agent": "iqn"}, {"env_step": 1300000, "rew": 115.17000045776368, "rew_std": 24.481870429507214, "Agent": "iqn"}, {"env_step": 1400000, "rew": 148.81999969482422, "rew_std": 46.95697798485056, "Agent": "iqn"}, {"env_step": 1500000, "rew": 174.43000259399415, "rew_std": 65.91861982472467, "Agent": "iqn"}, {"env_step": 1600000, "rew": 177.42999954223632, "rew_std": 37.79439115810842, "Agent": "iqn"}, {"env_step": 1700000, "rew": 212.91999893188478, "rew_std": 38.73584462706583, "Agent": "iqn"}, {"env_step": 1800000, "rew": 239.47000274658203, "rew_std": 44.23055960487321, "Agent": "iqn"}, {"env_step": 1900000, "rew": 224.89999542236328, "rew_std": 35.47142441734377, "Agent": "iqn"}, {"env_step": 2000000, "rew": 255.11000366210936, "rew_std": 70.83536414074327, "Agent": "iqn"}, {"env_step": 2100000, "rew": 268.1800018310547, "rew_std": 46.30323741929793, "Agent": "iqn"}, {"env_step": 2200000, "rew": 251.55999755859375, "rew_std": 84.66489362401488, "Agent": "iqn"}, {"env_step": 2300000, "rew": 272.47999420166013, "rew_std": 42.157886622987164, "Agent": "iqn"}, {"env_step": 2400000, "rew": 315.0400024414063, "rew_std": 35.556779061887944, "Agent": "iqn"}, {"env_step": 2500000, "rew": 288.16000213623045, "rew_std": 75.40411670056966, "Agent": "iqn"}, {"env_step": 2600000, "rew": 282.24999847412107, "rew_std": 67.52896065218202, "Agent": "iqn"}, {"env_step": 2700000, "rew": 314.6600006103516, "rew_std": 33.30751400267159, "Agent": "iqn"}, {"env_step": 2800000, "rew": 292.0699981689453, "rew_std": 70.06501237480857, "Agent": "iqn"}, {"env_step": 2900000, "rew": 280.75, "rew_std": 53.51295587493006, "Agent": "iqn"}, {"env_step": 3000000, "rew": 320.45, "rew_std": 
25.30736054112851, "Agent": "iqn"}, {"env_step": 3100000, "rew": 314.0599990844727, "rew_std": 44.75654636439621, "Agent": "iqn"}, {"env_step": 3200000, "rew": 331.3699981689453, "rew_std": 23.177711488651255, "Agent": "iqn"}, {"env_step": 3300000, "rew": 330.7899993896484, "rew_std": 30.686655193589775, "Agent": "iqn"}, {"env_step": 3400000, "rew": 302.63000030517577, "rew_std": 62.505184496442915, "Agent": "iqn"}, {"env_step": 3500000, "rew": 312.2699996948242, "rew_std": 60.86172947653031, "Agent": "iqn"}, {"env_step": 3600000, "rew": 337.6700012207031, "rew_std": 35.29467556366819, "Agent": "iqn"}, {"env_step": 3700000, "rew": 299.9600006103516, "rew_std": 71.5931756482553, "Agent": "iqn"}, {"env_step": 3800000, "rew": 246.19000167846679, "rew_std": 87.48656258153838, "Agent": "iqn"}, {"env_step": 3900000, "rew": 326.85999755859376, "rew_std": 51.01209852522548, "Agent": "iqn"}, {"env_step": 4000000, "rew": 307.7100006103516, "rew_std": 58.77869107819418, "Agent": "iqn"}, {"env_step": 4100000, "rew": 311.3300033569336, "rew_std": 66.51625456766375, "Agent": "iqn"}, {"env_step": 4200000, "rew": 300.1900009155273, "rew_std": 35.21126152649786, "Agent": "iqn"}, {"env_step": 4300000, "rew": 307.98000183105466, "rew_std": 31.851334098626857, "Agent": "iqn"}, {"env_step": 4400000, "rew": 318.60999755859376, "rew_std": 43.45476757896472, "Agent": "iqn"}, {"env_step": 4500000, "rew": 327.53999786376954, "rew_std": 62.97847529940396, "Agent": "iqn"}, {"env_step": 4600000, "rew": 282.2199996948242, "rew_std": 65.37763576262309, "Agent": "iqn"}, {"env_step": 4700000, "rew": 298.72000274658205, "rew_std": 53.23556836698092, "Agent": "iqn"}, {"env_step": 4800000, "rew": 338.88000030517577, "rew_std": 64.96912888006219, "Agent": "iqn"}, {"env_step": 4900000, "rew": 320.7, "rew_std": 66.17025120314409, "Agent": "iqn"}, {"env_step": 5000000, "rew": 318.7500061035156, "rew_std": 40.19410702152574, "Agent": "iqn"}, {"env_step": 5100000, "rew": 295.86999664306643, "rew_std": 70.69885222375017, "Agent": "iqn"}, {"env_step": 5200000, "rew": 344.55999755859375, "rew_std": 31.723278369793565, "Agent": "iqn"}, {"env_step": 5300000, "rew": 328.7099975585937, "rew_std": 42.69288989467297, "Agent": "iqn"}, {"env_step": 5400000, "rew": 296.60999755859376, "rew_std": 66.42924294236374, "Agent": "iqn"}, {"env_step": 5500000, "rew": 283.00999755859374, "rew_std": 70.532708487894, "Agent": "iqn"}, {"env_step": 5600000, "rew": 355.9200042724609, "rew_std": 22.674958525127415, "Agent": "iqn"}, {"env_step": 5700000, "rew": 301.33000030517576, "rew_std": 88.94783620229133, "Agent": "iqn"}, {"env_step": 5800000, "rew": 293.93999633789065, "rew_std": 52.81626973902586, "Agent": "iqn"}, {"env_step": 5900000, "rew": 353.3300018310547, "rew_std": 36.446288269528395, "Agent": "iqn"}, {"env_step": 6000000, "rew": 305.2799987792969, "rew_std": 54.018343190981945, "Agent": "iqn"}, {"env_step": 6100000, "rew": 349.9500030517578, "rew_std": 24.089635436458245, "Agent": "iqn"}, {"env_step": 6200000, "rew": 328.2199966430664, "rew_std": 48.67830877960062, "Agent": "iqn"}, {"env_step": 6300000, "rew": 343.49000244140626, "rew_std": 31.285020475109025, "Agent": "iqn"}, {"env_step": 6400000, "rew": 303.22000274658205, "rew_std": 55.57735070139985, "Agent": "iqn"}, {"env_step": 6500000, "rew": 321.5300003051758, "rew_std": 54.123212058813415, "Agent": "iqn"}, {"env_step": 6600000, "rew": 348.51000061035154, "rew_std": 38.049976345782774, "Agent": "iqn"}, {"env_step": 6700000, "rew": 352.5299987792969, "rew_std": 29.52311476032704, 
"Agent": "iqn"}, {"env_step": 6800000, "rew": 352.8199981689453, "rew_std": 27.302040489543003, "Agent": "iqn"}, {"env_step": 6900000, "rew": 325.95, "rew_std": 32.27061521262411, "Agent": "iqn"}, {"env_step": 7000000, "rew": 337.38000335693357, "rew_std": 57.94110508353129, "Agent": "iqn"}, {"env_step": 7100000, "rew": 337.1800064086914, "rew_std": 57.54656752151562, "Agent": "iqn"}, {"env_step": 7200000, "rew": 347.05, "rew_std": 29.28273921056375, "Agent": "iqn"}, {"env_step": 7300000, "rew": 336.5800018310547, "rew_std": 25.209833569339764, "Agent": "iqn"}, {"env_step": 7400000, "rew": 314.8700004577637, "rew_std": 83.15007163125841, "Agent": "iqn"}, {"env_step": 7500000, "rew": 346.7200012207031, "rew_std": 32.34303330425867, "Agent": "iqn"}, {"env_step": 7600000, "rew": 314.9699996948242, "rew_std": 51.35753429441995, "Agent": "iqn"}, {"env_step": 7700000, "rew": 312.4199966430664, "rew_std": 58.537693420573355, "Agent": "iqn"}, {"env_step": 7800000, "rew": 328.15, "rew_std": 45.2648705332799, "Agent": "iqn"}, {"env_step": 7900000, "rew": 346.00999755859374, "rew_std": 24.54332676008246, "Agent": "iqn"}, {"env_step": 8000000, "rew": 321.4400039672852, "rew_std": 50.142358701938754, "Agent": "iqn"}, {"env_step": 8100000, "rew": 337.76999816894534, "rew_std": 42.3190257331327, "Agent": "iqn"}, {"env_step": 8200000, "rew": 333.6700012207031, "rew_std": 29.37312182285308, "Agent": "iqn"}, {"env_step": 8300000, "rew": 324.7100006103516, "rew_std": 52.949946115785885, "Agent": "iqn"}, {"env_step": 8400000, "rew": 333.55999755859375, "rew_std": 33.45080520039474, "Agent": "iqn"}, {"env_step": 8500000, "rew": 331.4600036621094, "rew_std": 48.38483477275051, "Agent": "iqn"}, {"env_step": 8600000, "rew": 324.1599975585938, "rew_std": 49.630536861003044, "Agent": "iqn"}, {"env_step": 8700000, "rew": 351.7799987792969, "rew_std": 27.03962312243994, "Agent": "iqn"}, {"env_step": 8800000, "rew": 339.1300018310547, "rew_std": 53.111113054337935, "Agent": "iqn"}, {"env_step": 8900000, "rew": 329.3400024414062, "rew_std": 39.404931065916315, "Agent": "iqn"}, {"env_step": 9000000, "rew": 337.5, "rew_std": 31.20205217899507, "Agent": "iqn"}, {"env_step": 9100000, "rew": 312.82000122070315, "rew_std": 81.47756719359116, "Agent": "iqn"}, {"env_step": 9200000, "rew": 318.81000366210935, "rew_std": 61.65786917648312, "Agent": "iqn"}, {"env_step": 9300000, "rew": 355.5400024414063, "rew_std": 16.46543217809965, "Agent": "iqn"}, {"env_step": 9400000, "rew": 316.15000305175784, "rew_std": 69.99261936144228, "Agent": "iqn"}, {"env_step": 9500000, "rew": 346.1699981689453, "rew_std": 31.56960876454528, "Agent": "iqn"}, {"env_step": 9600000, "rew": 338.1199951171875, "rew_std": 45.98025387262979, "Agent": "iqn"}, {"env_step": 9700000, "rew": 329.63000335693357, "rew_std": 42.38771307072964, "Agent": "iqn"}, {"env_step": 9800000, "rew": 331.4, "rew_std": 39.38928853624626, "Agent": "iqn"}, {"env_step": 9900000, "rew": 318.9999969482422, "rew_std": 43.779264475855875, "Agent": "iqn"}, {"env_step": 10000000, "rew": 345.3199981689453, "rew_std": 30.032006070982476, "Agent": "iqn"}, {"env_step": 0, "rew": 1.7999999798834323, "rew_std": 1.461506037350442, "Agent": "rainbow"}, {"env_step": 100000, "rew": 2.789999971538782, "rew_std": 2.4865437409115354, "Agent": "rainbow"}, {"env_step": 200000, "rew": 12.420000076293945, "rew_std": 1.9046260138234519, "Agent": "rainbow"}, {"env_step": 300000, "rew": 17.679999923706056, "rew_std": 3.491647063886927, "Agent": "rainbow"}, {"env_step": 400000, "rew": 22.370000076293945, 
"rew_std": 3.6133226874735795, "Agent": "rainbow"}, {"env_step": 500000, "rew": 25.05, "rew_std": 3.6381997573897396, "Agent": "rainbow"}, {"env_step": 600000, "rew": 35.28999977111816, "rew_std": 4.300104297171605, "Agent": "rainbow"}, {"env_step": 700000, "rew": 38.70999984741211, "rew_std": 5.35918838448342, "Agent": "rainbow"}, {"env_step": 800000, "rew": 50.54999961853027, "rew_std": 8.483189257414512, "Agent": "rainbow"}, {"env_step": 900000, "rew": 60.01000061035156, "rew_std": 10.841259439227493, "Agent": "rainbow"}, {"env_step": 1000000, "rew": 72.13000030517578, "rew_std": 14.579921126080473, "Agent": "rainbow"}, {"env_step": 1100000, "rew": 105.4900016784668, "rew_std": 41.347393284288735, "Agent": "rainbow"}, {"env_step": 1200000, "rew": 99.45000076293945, "rew_std": 25.770341675180582, "Agent": "rainbow"}, {"env_step": 1300000, "rew": 104.1, "rew_std": 24.05194318506914, "Agent": "rainbow"}, {"env_step": 1400000, "rew": 122.69000091552735, "rew_std": 33.97927298843038, "Agent": "rainbow"}, {"env_step": 1500000, "rew": 188.78999786376954, "rew_std": 47.31898888179236, "Agent": "rainbow"}, {"env_step": 1600000, "rew": 230.06000213623048, "rew_std": 45.668113003632385, "Agent": "rainbow"}, {"env_step": 1700000, "rew": 222.32000122070312, "rew_std": 69.96254538894237, "Agent": "rainbow"}, {"env_step": 1800000, "rew": 264.9599975585937, "rew_std": 44.13764946857046, "Agent": "rainbow"}, {"env_step": 1900000, "rew": 304.4, "rew_std": 44.836877219962034, "Agent": "rainbow"}, {"env_step": 2000000, "rew": 327.15000305175784, "rew_std": 42.61066349224782, "Agent": "rainbow"}, {"env_step": 2100000, "rew": 329.89000091552737, "rew_std": 47.86103795081456, "Agent": "rainbow"}, {"env_step": 2200000, "rew": 332.67000274658204, "rew_std": 49.45685362258107, "Agent": "rainbow"}, {"env_step": 2300000, "rew": 291.48999786376953, "rew_std": 72.64021500933792, "Agent": "rainbow"}, {"env_step": 2400000, "rew": 334.86000061035156, "rew_std": 78.4941559533148, "Agent": "rainbow"}, {"env_step": 2500000, "rew": 333.99000091552733, "rew_std": 75.35539320141508, "Agent": "rainbow"}, {"env_step": 2600000, "rew": 353.07000122070315, "rew_std": 51.661630785667796, "Agent": "rainbow"}, {"env_step": 2700000, "rew": 335.8000061035156, "rew_std": 65.27762192648794, "Agent": "rainbow"}, {"env_step": 2800000, "rew": 381.63999938964844, "rew_std": 33.01331023409123, "Agent": "rainbow"}, {"env_step": 2900000, "rew": 370.3, "rew_std": 29.175224311445085, "Agent": "rainbow"}, {"env_step": 3000000, "rew": 396.6999969482422, "rew_std": 40.15956167187139, "Agent": "rainbow"}, {"env_step": 3100000, "rew": 378.6699981689453, "rew_std": 44.56986018654734, "Agent": "rainbow"}, {"env_step": 3200000, "rew": 367.72999877929686, "rew_std": 48.201724085448, "Agent": "rainbow"}, {"env_step": 3300000, "rew": 373.5100036621094, "rew_std": 50.91597587464675, "Agent": "rainbow"}, {"env_step": 3400000, "rew": 373.94000549316405, "rew_std": 45.228182826120296, "Agent": "rainbow"}, {"env_step": 3500000, "rew": 386.4099945068359, "rew_std": 24.17674431981728, "Agent": "rainbow"}, {"env_step": 3600000, "rew": 405.77000732421874, "rew_std": 42.80668625101679, "Agent": "rainbow"}, {"env_step": 3700000, "rew": 407.6, "rew_std": 48.41737864076781, "Agent": "rainbow"}, {"env_step": 3800000, "rew": 408.2899993896484, "rew_std": 31.628448022764232, "Agent": "rainbow"}, {"env_step": 3900000, "rew": 390.32000122070315, "rew_std": 28.6771582193927, "Agent": "rainbow"}, {"env_step": 4000000, "rew": 371.5400054931641, "rew_std": 32.82566550832653, 
"Agent": "rainbow"}, {"env_step": 4100000, "rew": 374.39000244140624, "rew_std": 44.559746097958346, "Agent": "rainbow"}, {"env_step": 4200000, "rew": 373.72999877929686, "rew_std": 42.524114470487525, "Agent": "rainbow"}, {"env_step": 4300000, "rew": 405.72999877929686, "rew_std": 21.966842238579567, "Agent": "rainbow"}, {"env_step": 4400000, "rew": 382.3500061035156, "rew_std": 30.38684579249488, "Agent": "rainbow"}, {"env_step": 4500000, "rew": 411.3499984741211, "rew_std": 69.50523878087984, "Agent": "rainbow"}, {"env_step": 4600000, "rew": 417.1599975585938, "rew_std": 63.06370162906102, "Agent": "rainbow"}, {"env_step": 4700000, "rew": 386.0800048828125, "rew_std": 28.676079172356957, "Agent": "rainbow"}, {"env_step": 4800000, "rew": 398.8399993896484, "rew_std": 30.87983841532746, "Agent": "rainbow"}, {"env_step": 4900000, "rew": 396.72999877929686, "rew_std": 33.71180556312966, "Agent": "rainbow"}, {"env_step": 5000000, "rew": 384.89000244140624, "rew_std": 33.94638292724067, "Agent": "rainbow"}, {"env_step": 5100000, "rew": 400.55, "rew_std": 46.40784802725254, "Agent": "rainbow"}, {"env_step": 5200000, "rew": 419.52000122070314, "rew_std": 28.008704651064, "Agent": "rainbow"}, {"env_step": 5300000, "rew": 407.02000122070314, "rew_std": 16.72254065818173, "Agent": "rainbow"}, {"env_step": 5400000, "rew": 414.0499969482422, "rew_std": 41.30225599948548, "Agent": "rainbow"}, {"env_step": 5500000, "rew": 398.7799987792969, "rew_std": 47.40503931200679, "Agent": "rainbow"}, {"env_step": 5600000, "rew": 403.820002746582, "rew_std": 64.9688057624937, "Agent": "rainbow"}, {"env_step": 5700000, "rew": 389.6999984741211, "rew_std": 72.28918253740537, "Agent": "rainbow"}, {"env_step": 5800000, "rew": 403.6700012207031, "rew_std": 38.4149988299815, "Agent": "rainbow"}, {"env_step": 5900000, "rew": 409.89000244140624, "rew_std": 24.034117466481515, "Agent": "rainbow"}, {"env_step": 6000000, "rew": 399.0400024414063, "rew_std": 39.877192287649464, "Agent": "rainbow"}, {"env_step": 6100000, "rew": 384.40000305175784, "rew_std": 56.328160747865354, "Agent": "rainbow"}, {"env_step": 6200000, "rew": 403.8999969482422, "rew_std": 33.781714504496996, "Agent": "rainbow"}, {"env_step": 6300000, "rew": 397.63999938964844, "rew_std": 26.051031347673337, "Agent": "rainbow"}, {"env_step": 6400000, "rew": 409.3799987792969, "rew_std": 15.132536949771454, "Agent": "rainbow"}, {"env_step": 6500000, "rew": 408.31000061035155, "rew_std": 22.57203611858315, "Agent": "rainbow"}, {"env_step": 6600000, "rew": 415.3700012207031, "rew_std": 38.56120137444495, "Agent": "rainbow"}, {"env_step": 6700000, "rew": 431.77000122070314, "rew_std": 47.137184174849416, "Agent": "rainbow"}, {"env_step": 6800000, "rew": 404.22999877929686, "rew_std": 34.66502740832278, "Agent": "rainbow"}, {"env_step": 6900000, "rew": 401.00999755859374, "rew_std": 22.66051127275429, "Agent": "rainbow"}, {"env_step": 7000000, "rew": 393.8199981689453, "rew_std": 48.31937448033659, "Agent": "rainbow"}, {"env_step": 7100000, "rew": 417.7, "rew_std": 50.47389404778245, "Agent": "rainbow"}, {"env_step": 7200000, "rew": 395.7499969482422, "rew_std": 40.349476442782, "Agent": "rainbow"}, {"env_step": 7300000, "rew": 417.0899963378906, "rew_std": 18.112780812448758, "Agent": "rainbow"}, {"env_step": 7400000, "rew": 383.18999633789065, "rew_std": 47.911194905935396, "Agent": "rainbow"}, {"env_step": 7500000, "rew": 410.9200073242188, "rew_std": 70.87141742592017, "Agent": "rainbow"}, {"env_step": 7600000, "rew": 420.8400024414062, "rew_std": 
34.12468989742961, "Agent": "rainbow"}, {"env_step": 7700000, "rew": 422.22999572753906, "rew_std": 21.98477505605195, "Agent": "rainbow"}, {"env_step": 7800000, "rew": 423.4200042724609, "rew_std": 43.2258487041201, "Agent": "rainbow"}, {"env_step": 7900000, "rew": 418.1700012207031, "rew_std": 35.33822585932713, "Agent": "rainbow"}, {"env_step": 8000000, "rew": 396.6499938964844, "rew_std": 39.98475122704287, "Agent": "rainbow"}, {"env_step": 8100000, "rew": 431.2100006103516, "rew_std": 43.29839137758222, "Agent": "rainbow"}, {"env_step": 8200000, "rew": 388.6700042724609, "rew_std": 42.25809445722602, "Agent": "rainbow"}, {"env_step": 8300000, "rew": 371.5899963378906, "rew_std": 76.53331656114038, "Agent": "rainbow"}, {"env_step": 8400000, "rew": 408.69000244140625, "rew_std": 24.624358338524335, "Agent": "rainbow"}, {"env_step": 8500000, "rew": 409.0399993896484, "rew_std": 41.573793217544875, "Agent": "rainbow"}, {"env_step": 8600000, "rew": 401.95, "rew_std": 37.11682649365457, "Agent": "rainbow"}, {"env_step": 8700000, "rew": 420.77000122070314, "rew_std": 44.74215708867204, "Agent": "rainbow"}, {"env_step": 8800000, "rew": 405.2499969482422, "rew_std": 31.691043996566787, "Agent": "rainbow"}, {"env_step": 8900000, "rew": 398.4700012207031, "rew_std": 40.46342211972141, "Agent": "rainbow"}, {"env_step": 9000000, "rew": 440.4, "rew_std": 50.116695811843044, "Agent": "rainbow"}, {"env_step": 9100000, "rew": 422.5, "rew_std": 35.31455535189779, "Agent": "rainbow"}, {"env_step": 9200000, "rew": 428.8199981689453, "rew_std": 33.58162360678032, "Agent": "rainbow"}, {"env_step": 9300000, "rew": 425.3400024414062, "rew_std": 46.15141555774316, "Agent": "rainbow"}, {"env_step": 9400000, "rew": 440.2799987792969, "rew_std": 42.420672496482084, "Agent": "rainbow"}, {"env_step": 9500000, "rew": 410.85, "rew_std": 6.367617048401184, "Agent": "rainbow"}, {"env_step": 9600000, "rew": 399.57000122070315, "rew_std": 53.69724380529421, "Agent": "rainbow"}, {"env_step": 9700000, "rew": 431.38999938964844, "rew_std": 40.52948488877201, "Agent": "rainbow"}, {"env_step": 9800000, "rew": 422.23999938964846, "rew_std": 37.25512766393269, "Agent": "rainbow"}, {"env_step": 9900000, "rew": 413.8299987792969, "rew_std": 36.612840411302244, "Agent": "rainbow"}, {"env_step": 10000000, "rew": 392.0399993896484, "rew_std": 48.50192720302268, "Agent": "rainbow"}, {"env_step": 0, "rew": 2.0299999833106996, "rew_std": 0.8764131095183902, "Agent": "ppo"}, {"env_step": 100000, "rew": 2.680000030994415, "rew_std": 1.5315352171604617, "Agent": "ppo"}, {"env_step": 200000, "rew": 6.869999933242798, "rew_std": 2.2427883234590684, "Agent": "ppo"}, {"env_step": 300000, "rew": 8.350000023841858, "rew_std": 2.62154539898698, "Agent": "ppo"}, {"env_step": 400000, "rew": 10.77000002861023, "rew_std": 3.155328811734828, "Agent": "ppo"}, {"env_step": 500000, "rew": 11.309999942779541, "rew_std": 2.5762180106126547, "Agent": "ppo"}, {"env_step": 600000, "rew": 13.859999895095825, "rew_std": 4.651494279703967, "Agent": "ppo"}, {"env_step": 700000, "rew": 14.749999761581421, "rew_std": 4.220485482312834, "Agent": "ppo"}, {"env_step": 800000, "rew": 15.349999904632568, "rew_std": 4.1898090655449005, "Agent": "ppo"}, {"env_step": 900000, "rew": 18.020000171661376, "rew_std": 4.536915203735637, "Agent": "ppo"}, {"env_step": 1000000, "rew": 17.020000171661376, "rew_std": 4.455962464888489, "Agent": "ppo"}, {"env_step": 1100000, "rew": 20.690000343322755, "rew_std": 8.05238462382638, "Agent": "ppo"}, {"env_step": 1200000, "rew": 
22.689999961853026, "rew_std": 12.304588638832776, "Agent": "ppo"}, {"env_step": 1300000, "rew": 20.55999975204468, "rew_std": 4.387299621179701, "Agent": "ppo"}, {"env_step": 1400000, "rew": 19.94000005722046, "rew_std": 4.290501381090037, "Agent": "ppo"}, {"env_step": 1500000, "rew": 26.379999923706055, "rew_std": 10.657748376402598, "Agent": "ppo"}, {"env_step": 1600000, "rew": 23.750000190734863, "rew_std": 4.385715448611054, "Agent": "ppo"}, {"env_step": 1700000, "rew": 23.659999656677247, "rew_std": 4.599826003557162, "Agent": "ppo"}, {"env_step": 1800000, "rew": 24.960000133514406, "rew_std": 4.907586145252241, "Agent": "ppo"}, {"env_step": 1900000, "rew": 27.629999732971193, "rew_std": 5.699657718554115, "Agent": "ppo"}, {"env_step": 2000000, "rew": 25.119999885559082, "rew_std": 2.981207745714502, "Agent": "ppo"}, {"env_step": 2100000, "rew": 24.629999923706055, "rew_std": 3.647204546502502, "Agent": "ppo"}, {"env_step": 2200000, "rew": 27.63999996185303, "rew_std": 3.726714114078771, "Agent": "ppo"}, {"env_step": 2300000, "rew": 28.329999923706055, "rew_std": 4.171102659408543, "Agent": "ppo"}, {"env_step": 2400000, "rew": 28.530000305175783, "rew_std": 3.5877709920260523, "Agent": "ppo"}, {"env_step": 2500000, "rew": 32.079999923706055, "rew_std": 4.9793168787475475, "Agent": "ppo"}, {"env_step": 2600000, "rew": 35.150000190734865, "rew_std": 6.757550860011463, "Agent": "ppo"}, {"env_step": 2700000, "rew": 35.24000015258789, "rew_std": 13.677953506506757, "Agent": "ppo"}, {"env_step": 2800000, "rew": 37.55, "rew_std": 9.215123739752018, "Agent": "ppo"}, {"env_step": 2900000, "rew": 36.089999771118165, "rew_std": 4.728942780362954, "Agent": "ppo"}, {"env_step": 3000000, "rew": 42.98000030517578, "rew_std": 19.57303269533935, "Agent": "ppo"}, {"env_step": 3100000, "rew": 37.360000419616696, "rew_std": 12.54338152629255, "Agent": "ppo"}, {"env_step": 3200000, "rew": 40.380000305175784, "rew_std": 10.776437119519981, "Agent": "ppo"}, {"env_step": 3300000, "rew": 38.110000419616696, "rew_std": 12.645825732034167, "Agent": "ppo"}, {"env_step": 3400000, "rew": 42.21999931335449, "rew_std": 9.1218199543614, "Agent": "ppo"}, {"env_step": 3500000, "rew": 50.119999694824216, "rew_std": 14.80836362588282, "Agent": "ppo"}, {"env_step": 3600000, "rew": 47.400000190734865, "rew_std": 10.068664307564164, "Agent": "ppo"}, {"env_step": 3700000, "rew": 45.4399995803833, "rew_std": 9.41564705130973, "Agent": "ppo"}, {"env_step": 3800000, "rew": 53.88000068664551, "rew_std": 16.57653842254833, "Agent": "ppo"}, {"env_step": 3900000, "rew": 61.64000015258789, "rew_std": 24.900129246081427, "Agent": "ppo"}, {"env_step": 4000000, "rew": 56.04999923706055, "rew_std": 14.419240647637906, "Agent": "ppo"}, {"env_step": 4100000, "rew": 55.7, "rew_std": 20.825513362350105, "Agent": "ppo"}, {"env_step": 4200000, "rew": 54.29000072479248, "rew_std": 21.086083955047158, "Agent": "ppo"}, {"env_step": 4300000, "rew": 60.97000007629394, "rew_std": 23.419352948854787, "Agent": "ppo"}, {"env_step": 4400000, "rew": 61.32000160217285, "rew_std": 24.429810367430697, "Agent": "ppo"}, {"env_step": 4500000, "rew": 62.31000022888183, "rew_std": 21.309597897936026, "Agent": "ppo"}, {"env_step": 4600000, "rew": 67.86999893188477, "rew_std": 26.248467597744845, "Agent": "ppo"}, {"env_step": 4700000, "rew": 66.7, "rew_std": 24.40237644035704, "Agent": "ppo"}, {"env_step": 4800000, "rew": 82.43000106811523, "rew_std": 39.014153045648264, "Agent": "ppo"}, {"env_step": 4900000, "rew": 82.50999946594239, "rew_std": 
29.639110602215194, "Agent": "ppo"}, {"env_step": 5000000, "rew": 89.46999969482422, "rew_std": 36.290717341380685, "Agent": "ppo"}, {"env_step": 5100000, "rew": 75.40999946594238, "rew_std": 19.553385764445366, "Agent": "ppo"}, {"env_step": 5200000, "rew": 91.71999969482422, "rew_std": 33.941354984622286, "Agent": "ppo"}, {"env_step": 5300000, "rew": 95.90999984741211, "rew_std": 34.3326784841907, "Agent": "ppo"}, {"env_step": 5400000, "rew": 101.17000045776368, "rew_std": 48.307889695724214, "Agent": "ppo"}, {"env_step": 5500000, "rew": 93.73999977111816, "rew_std": 32.80829688751657, "Agent": "ppo"}, {"env_step": 5600000, "rew": 101.56000099182128, "rew_std": 38.65323253336432, "Agent": "ppo"}, {"env_step": 5700000, "rew": 119.71000213623047, "rew_std": 58.63711387604412, "Agent": "ppo"}, {"env_step": 5800000, "rew": 121.56000022888183, "rew_std": 52.67867263607598, "Agent": "ppo"}, {"env_step": 5900000, "rew": 128.3799991607666, "rew_std": 58.80482705685627, "Agent": "ppo"}, {"env_step": 6000000, "rew": 120.53999824523926, "rew_std": 48.01154252771422, "Agent": "ppo"}, {"env_step": 6100000, "rew": 122.3899990081787, "rew_std": 49.0965679103651, "Agent": "ppo"}, {"env_step": 6200000, "rew": 133.99999961853027, "rew_std": 60.90436501831146, "Agent": "ppo"}, {"env_step": 6300000, "rew": 140.52000198364257, "rew_std": 69.07889716111299, "Agent": "ppo"}, {"env_step": 6400000, "rew": 137.21999893188476, "rew_std": 62.32881643967556, "Agent": "ppo"}, {"env_step": 6500000, "rew": 146.8299991607666, "rew_std": 55.37363915857669, "Agent": "ppo"}, {"env_step": 6600000, "rew": 139.81000213623048, "rew_std": 74.04937024552305, "Agent": "ppo"}, {"env_step": 6700000, "rew": 149.6099994659424, "rew_std": 71.83785305550174, "Agent": "ppo"}, {"env_step": 6800000, "rew": 135.49000091552733, "rew_std": 55.006754667730945, "Agent": "ppo"}, {"env_step": 6900000, "rew": 147.45999984741212, "rew_std": 59.91933138185291, "Agent": "ppo"}, {"env_step": 7000000, "rew": 165.1099994659424, "rew_std": 73.1342120883567, "Agent": "ppo"}, {"env_step": 7100000, "rew": 174.0099983215332, "rew_std": 66.24739028991273, "Agent": "ppo"}, {"env_step": 7200000, "rew": 177.29999847412108, "rew_std": 73.95540152011542, "Agent": "ppo"}, {"env_step": 7300000, "rew": 175.92999992370605, "rew_std": 76.0113898211075, "Agent": "ppo"}, {"env_step": 7400000, "rew": 164.42999877929688, "rew_std": 77.33506176445988, "Agent": "ppo"}, {"env_step": 7500000, "rew": 173.9599983215332, "rew_std": 84.44470542302061, "Agent": "ppo"}, {"env_step": 7600000, "rew": 171.43999900817872, "rew_std": 83.88048894373154, "Agent": "ppo"}, {"env_step": 7700000, "rew": 199.14999923706054, "rew_std": 94.75587743328266, "Agent": "ppo"}, {"env_step": 7800000, "rew": 186.2400001525879, "rew_std": 77.16537397952645, "Agent": "ppo"}, {"env_step": 7900000, "rew": 199.35, "rew_std": 79.58159406938299, "Agent": "ppo"}, {"env_step": 8000000, "rew": 210.05000190734864, "rew_std": 90.49879749595148, "Agent": "ppo"}, {"env_step": 8100000, "rew": 212.7500015258789, "rew_std": 88.85424210708028, "Agent": "ppo"}, {"env_step": 8200000, "rew": 210.8900001525879, "rew_std": 73.44615975377866, "Agent": "ppo"}, {"env_step": 8300000, "rew": 219.910001373291, "rew_std": 100.7312432120665, "Agent": "ppo"}, {"env_step": 8400000, "rew": 237.1999984741211, "rew_std": 88.27977955944947, "Agent": "ppo"}, {"env_step": 8500000, "rew": 235.44000015258788, "rew_std": 82.60257787044871, "Agent": "ppo"}, {"env_step": 8600000, "rew": 217.410001373291, "rew_std": 82.26707245141274, "Agent": 
"ppo"}, {"env_step": 8700000, "rew": 239.08999938964843, "rew_std": 80.68150053430362, "Agent": "ppo"}, {"env_step": 8800000, "rew": 228.78000259399414, "rew_std": 78.50268738950463, "Agent": "ppo"}, {"env_step": 8900000, "rew": 242.50999908447267, "rew_std": 88.02333190391393, "Agent": "ppo"}, {"env_step": 9000000, "rew": 230.4200012207031, "rew_std": 98.88399222725009, "Agent": "ppo"}, {"env_step": 9100000, "rew": 254.71000213623046, "rew_std": 74.41610668817124, "Agent": "ppo"}, {"env_step": 9200000, "rew": 258.9599975585937, "rew_std": 88.94478371137306, "Agent": "ppo"}, {"env_step": 9300000, "rew": 248.64999694824218, "rew_std": 86.56067064571964, "Agent": "ppo"}, {"env_step": 9400000, "rew": 247.660001373291, "rew_std": 92.0610693574401, "Agent": "ppo"}, {"env_step": 9500000, "rew": 276.51000289916993, "rew_std": 81.18026107244279, "Agent": "ppo"}, {"env_step": 9600000, "rew": 258.3400054931641, "rew_std": 82.19079470003253, "Agent": "ppo"}, {"env_step": 9700000, "rew": 265.7800018310547, "rew_std": 71.35941141595737, "Agent": "ppo"}, {"env_step": 9800000, "rew": 279.3699981689453, "rew_std": 81.73140051743914, "Agent": "ppo"}, {"env_step": 9900000, "rew": 283.01000061035154, "rew_std": 74.29673322328118, "Agent": "ppo"}, {"env_step": 10000000, "rew": 272.84999923706056, "rew_std": 87.10014073751415, "Agent": "ppo"}] \ No newline at end of file diff --git a/examples/atari/benchmark/EnduroNoFrameskip-v4/result.json b/examples/atari/benchmark/EnduroNoFrameskip-v4/result.json new file mode 100644 index 0000000000000000000000000000000000000000..0694a7694f05b42bb095f80bf755e287e58ddfc2 --- /dev/null +++ b/examples/atari/benchmark/EnduroNoFrameskip-v4/result.json @@ -0,0 +1 @@ +[{"env_step": 0, "rew": 0.0800000011920929, "rew_std": 0.24000000357627865, "Agent": "c51"}, {"env_step": 100000, "rew": 0.0, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 200000, "rew": 0.0, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 300000, "rew": 0.1, "rew_std": 0.30000000000000004, "Agent": "c51"}, {"env_step": 400000, "rew": 0.21000000238418579, "rew_std": 0.35623026856774925, "Agent": "c51"}, {"env_step": 500000, "rew": 0.05000000074505806, "rew_std": 0.12041594758226036, "Agent": "c51"}, {"env_step": 600000, "rew": 0.31999999061226847, "rew_std": 0.927146130289149, "Agent": "c51"}, {"env_step": 700000, "rew": 0.7299999989569187, "rew_std": 1.0354226172474177, "Agent": "c51"}, {"env_step": 800000, "rew": 29.999999809265137, "rew_std": 22.8517390602112, "Agent": "c51"}, {"env_step": 900000, "rew": 84.33000020980835, "rew_std": 38.11215200504169, "Agent": "c51"}, {"env_step": 1000000, "rew": 134.27000045776367, "rew_std": 33.76175580023914, "Agent": "c51"}, {"env_step": 1100000, "rew": 195.3500015258789, "rew_std": 55.070539310018276, "Agent": "c51"}, {"env_step": 1200000, "rew": 305.8000015258789, "rew_std": 63.623845163113806, "Agent": "c51"}, {"env_step": 1300000, "rew": 341.6900039672852, "rew_std": 75.15398169731095, "Agent": "c51"}, {"env_step": 1400000, "rew": 439.62000427246096, "rew_std": 35.53360923105612, "Agent": "c51"}, {"env_step": 1500000, "rew": 448.17999877929685, "rew_std": 63.31653151381171, "Agent": "c51"}, {"env_step": 1600000, "rew": 483.0900024414062, "rew_std": 63.598508175034475, "Agent": "c51"}, {"env_step": 1700000, "rew": 479.4499938964844, "rew_std": 60.719817637295435, "Agent": "c51"}, {"env_step": 1800000, "rew": 477.8299987792969, "rew_std": 88.50112258719895, "Agent": "c51"}, {"env_step": 1900000, "rew": 516.6200012207031, "rew_std": 90.65022983110897, "Agent": "c51"}, 
{"env_step": 2000000, "rew": 500.8199981689453, "rew_std": 56.8060386594065, "Agent": "c51"}, {"env_step": 2100000, "rew": 539.5200012207031, "rew_std": 80.30648599564611, "Agent": "c51"}, {"env_step": 2200000, "rew": 592.730014038086, "rew_std": 96.93099005249192, "Agent": "c51"}, {"env_step": 2300000, "rew": 573.3399963378906, "rew_std": 91.49156422648889, "Agent": "c51"}, {"env_step": 2400000, "rew": 606.1099975585937, "rew_std": 68.63315724977679, "Agent": "c51"}, {"env_step": 2500000, "rew": 652.9599975585937, "rew_std": 85.58606896755197, "Agent": "c51"}, {"env_step": 2600000, "rew": 655.4, "rew_std": 81.90140506345105, "Agent": "c51"}, {"env_step": 2700000, "rew": 625.2999969482422, "rew_std": 85.18681205945472, "Agent": "c51"}, {"env_step": 2800000, "rew": 637.7900024414063, "rew_std": 122.54666760609061, "Agent": "c51"}, {"env_step": 2900000, "rew": 644.1, "rew_std": 85.42630507564513, "Agent": "c51"}, {"env_step": 3000000, "rew": 710.5900024414062, "rew_std": 75.28909772977295, "Agent": "c51"}, {"env_step": 3100000, "rew": 671.3799987792969, "rew_std": 56.7536538058374, "Agent": "c51"}, {"env_step": 3200000, "rew": 668.2599975585938, "rew_std": 65.81738352088684, "Agent": "c51"}, {"env_step": 3300000, "rew": 690.7899993896484, "rew_std": 84.55681861003309, "Agent": "c51"}, {"env_step": 3400000, "rew": 738.3400024414062, "rew_std": 101.42283074745313, "Agent": "c51"}, {"env_step": 3500000, "rew": 730.4300048828125, "rew_std": 90.93792521834833, "Agent": "c51"}, {"env_step": 3600000, "rew": 742.3700012207031, "rew_std": 98.71111614905327, "Agent": "c51"}, {"env_step": 3700000, "rew": 708.8199981689453, "rew_std": 130.55409140026177, "Agent": "c51"}, {"env_step": 3800000, "rew": 705.9700012207031, "rew_std": 75.86588955018891, "Agent": "c51"}, {"env_step": 3900000, "rew": 755.3899963378906, "rew_std": 111.82682749974384, "Agent": "c51"}, {"env_step": 4000000, "rew": 792.8599975585937, "rew_std": 100.65659388596904, "Agent": "c51"}, {"env_step": 4100000, "rew": 780.4700012207031, "rew_std": 89.4066144108773, "Agent": "c51"}, {"env_step": 4200000, "rew": 749.0600036621094, "rew_std": 89.14610115546856, "Agent": "c51"}, {"env_step": 4300000, "rew": 735.3100067138672, "rew_std": 105.2854315881353, "Agent": "c51"}, {"env_step": 4400000, "rew": 794.9, "rew_std": 94.90966229691644, "Agent": "c51"}, {"env_step": 4500000, "rew": 775.8700012207031, "rew_std": 86.57670898287867, "Agent": "c51"}, {"env_step": 4600000, "rew": 764.4599975585937, "rew_std": 121.75587907672723, "Agent": "c51"}, {"env_step": 4700000, "rew": 761.55, "rew_std": 62.056058351897406, "Agent": "c51"}, {"env_step": 4800000, "rew": 746.5799987792968, "rew_std": 130.65757158616265, "Agent": "c51"}, {"env_step": 4900000, "rew": 792.0400085449219, "rew_std": 92.14410963183664, "Agent": "c51"}, {"env_step": 5000000, "rew": 784.8700012207031, "rew_std": 76.30846325078628, "Agent": "c51"}, {"env_step": 5100000, "rew": 806.9400024414062, "rew_std": 89.7122211065443, "Agent": "c51"}, {"env_step": 5200000, "rew": 809.7999938964844, "rew_std": 96.43909531778648, "Agent": "c51"}, {"env_step": 5300000, "rew": 801.8500061035156, "rew_std": 71.23311713783924, "Agent": "c51"}, {"env_step": 5400000, "rew": 865.3399963378906, "rew_std": 87.65904211044649, "Agent": "c51"}, {"env_step": 5500000, "rew": 800.1799987792969, "rew_std": 120.17872453215695, "Agent": "c51"}, {"env_step": 5600000, "rew": 808.1899963378906, "rew_std": 114.39077140355394, "Agent": "c51"}, {"env_step": 5700000, "rew": 787.6900024414062, "rew_std": 137.37429836391115, 
"Agent": "c51"}, {"env_step": 5800000, "rew": 817.6800048828125, "rew_std": 75.91765398683945, "Agent": "c51"}, {"env_step": 5900000, "rew": 788.95, "rew_std": 112.15058353811531, "Agent": "c51"}, {"env_step": 6000000, "rew": 824.4900024414062, "rew_std": 83.1460697372063, "Agent": "c51"}, {"env_step": 6100000, "rew": 791.3400024414062, "rew_std": 71.07509644467352, "Agent": "c51"}, {"env_step": 6200000, "rew": 852.5599975585938, "rew_std": 102.26750645543113, "Agent": "c51"}, {"env_step": 6300000, "rew": 791.0799957275391, "rew_std": 148.64477463928537, "Agent": "c51"}, {"env_step": 6400000, "rew": 799.7700012207031, "rew_std": 94.79903258127105, "Agent": "c51"}, {"env_step": 6500000, "rew": 856.9799987792969, "rew_std": 98.21994199980757, "Agent": "c51"}, {"env_step": 6600000, "rew": 830.3799987792969, "rew_std": 92.69931723386792, "Agent": "c51"}, {"env_step": 6700000, "rew": 828.5700012207031, "rew_std": 117.4840294040638, "Agent": "c51"}, {"env_step": 6800000, "rew": 836.3399963378906, "rew_std": 107.67462291584121, "Agent": "c51"}, {"env_step": 6900000, "rew": 809.05, "rew_std": 86.39843789136866, "Agent": "c51"}, {"env_step": 7000000, "rew": 802.610009765625, "rew_std": 85.19054082178417, "Agent": "c51"}, {"env_step": 7100000, "rew": 821.9500061035156, "rew_std": 84.06559770155239, "Agent": "c51"}, {"env_step": 7200000, "rew": 846.0100036621094, "rew_std": 109.6511176842256, "Agent": "c51"}, {"env_step": 7300000, "rew": 753.2699981689453, "rew_std": 118.18144461063531, "Agent": "c51"}, {"env_step": 7400000, "rew": 862.4699951171875, "rew_std": 94.61919535944561, "Agent": "c51"}, {"env_step": 7500000, "rew": 855.6, "rew_std": 83.0109175169287, "Agent": "c51"}, {"env_step": 7600000, "rew": 804.8099975585938, "rew_std": 87.84749020213859, "Agent": "c51"}, {"env_step": 7700000, "rew": 879.1099975585937, "rew_std": 124.99047443353965, "Agent": "c51"}, {"env_step": 7800000, "rew": 861.0999938964844, "rew_std": 149.99051826966058, "Agent": "c51"}, {"env_step": 7900000, "rew": 840.1100158691406, "rew_std": 74.45142853212236, "Agent": "c51"}, {"env_step": 8000000, "rew": 838.1, "rew_std": 81.488879742754, "Agent": "c51"}, {"env_step": 8100000, "rew": 853.1600036621094, "rew_std": 125.11934723061991, "Agent": "c51"}, {"env_step": 8200000, "rew": 883.8700012207031, "rew_std": 110.8634912939346, "Agent": "c51"}, {"env_step": 8300000, "rew": 843.6199951171875, "rew_std": 110.07341220660334, "Agent": "c51"}, {"env_step": 8400000, "rew": 833.9500122070312, "rew_std": 166.85491541704317, "Agent": "c51"}, {"env_step": 8500000, "rew": 883.5900085449218, "rew_std": 120.01639086371644, "Agent": "c51"}, {"env_step": 8600000, "rew": 816.3799987792969, "rew_std": 93.30947260374052, "Agent": "c51"}, {"env_step": 8700000, "rew": 816.5599975585938, "rew_std": 90.72039308816316, "Agent": "c51"}, {"env_step": 8800000, "rew": 885.2699951171875, "rew_std": 84.88027385527451, "Agent": "c51"}, {"env_step": 8900000, "rew": 940.8200012207031, "rew_std": 133.91121568644473, "Agent": "c51"}, {"env_step": 9000000, "rew": 858.2299926757812, "rew_std": 106.44362694927028, "Agent": "c51"}, {"env_step": 9100000, "rew": 843.5200012207031, "rew_std": 68.82380650364031, "Agent": "c51"}, {"env_step": 9200000, "rew": 847.75, "rew_std": 129.47580625363784, "Agent": "c51"}, {"env_step": 9300000, "rew": 898.089990234375, "rew_std": 74.57461814213535, "Agent": "c51"}, {"env_step": 9400000, "rew": 837.9999938964844, "rew_std": 75.13513732170806, "Agent": "c51"}, {"env_step": 9500000, "rew": 900.9, "rew_std": 78.32506786062443, 
"Agent": "c51"}, {"env_step": 9600000, "rew": 832.4200134277344, "rew_std": 69.69465721794408, "Agent": "c51"}, {"env_step": 9700000, "rew": 889.0700012207031, "rew_std": 126.08315874519296, "Agent": "c51"}, {"env_step": 9800000, "rew": 873.6100036621094, "rew_std": 72.68277945780702, "Agent": "c51"}, {"env_step": 9900000, "rew": 796.860009765625, "rew_std": 83.10054074269745, "Agent": "c51"}, {"env_step": 10000000, "rew": 821.8199951171875, "rew_std": 135.73986691667577, "Agent": "c51"}, {"env_step": 0, "rew": 0.010000000149011612, "rew_std": 0.03000000044703483, "Agent": "dqn"}, {"env_step": 100000, "rew": 0.28999999687075617, "rew_std": 0.6774215679654139, "Agent": "dqn"}, {"env_step": 200000, "rew": 0.44000001028180125, "rew_std": 0.8392854380207859, "Agent": "dqn"}, {"env_step": 300000, "rew": 0.3599999964237213, "rew_std": 0.8284925931025058, "Agent": "dqn"}, {"env_step": 400000, "rew": 0.33999999538064, "rew_std": 0.6873135967539007, "Agent": "dqn"}, {"env_step": 500000, "rew": 0.11000000089406967, "rew_std": 0.18138357257301754, "Agent": "dqn"}, {"env_step": 600000, "rew": 0.2100000001490116, "rew_std": 0.4548626165857307, "Agent": "dqn"}, {"env_step": 700000, "rew": 8.389999697357416, "rew_std": 23.108285476400564, "Agent": "dqn"}, {"env_step": 800000, "rew": 98.94999904632569, "rew_std": 35.660067729885725, "Agent": "dqn"}, {"env_step": 900000, "rew": 164.40999908447264, "rew_std": 59.0228520750357, "Agent": "dqn"}, {"env_step": 1000000, "rew": 210.4000030517578, "rew_std": 57.88502705780608, "Agent": "dqn"}, {"env_step": 1100000, "rew": 300.38000030517577, "rew_std": 82.23668084181749, "Agent": "dqn"}, {"env_step": 1200000, "rew": 302.7600036621094, "rew_std": 103.46174434569636, "Agent": "dqn"}, {"env_step": 1300000, "rew": 433.91000213623045, "rew_std": 126.5024948103054, "Agent": "dqn"}, {"env_step": 1400000, "rew": 416.14000244140624, "rew_std": 145.09603673682304, "Agent": "dqn"}, {"env_step": 1500000, "rew": 472.1600006103516, "rew_std": 86.80930065353702, "Agent": "dqn"}, {"env_step": 1600000, "rew": 536.65, "rew_std": 132.84745000480228, "Agent": "dqn"}, {"env_step": 1700000, "rew": 511.6300048828125, "rew_std": 113.56259400514695, "Agent": "dqn"}, {"env_step": 1800000, "rew": 559.5200012207031, "rew_std": 95.26979029881092, "Agent": "dqn"}, {"env_step": 1900000, "rew": 504.64000244140624, "rew_std": 183.2878214612098, "Agent": "dqn"}, {"env_step": 2000000, "rew": 574.7700103759765, "rew_std": 78.04244399982899, "Agent": "dqn"}, {"env_step": 2100000, "rew": 531.539998626709, "rew_std": 230.03735295552605, "Agent": "dqn"}, {"env_step": 2200000, "rew": 584.3100006103516, "rew_std": 116.97556069856971, "Agent": "dqn"}, {"env_step": 2300000, "rew": 609.7900024414063, "rew_std": 76.25380786006397, "Agent": "dqn"}, {"env_step": 2400000, "rew": 601.6499938964844, "rew_std": 156.2863846625289, "Agent": "dqn"}, {"env_step": 2500000, "rew": 614.8899978637695, "rew_std": 188.88530877823078, "Agent": "dqn"}, {"env_step": 2600000, "rew": 585.7100036621093, "rew_std": 183.29359497607433, "Agent": "dqn"}, {"env_step": 2700000, "rew": 681.8800048828125, "rew_std": 179.32764210720728, "Agent": "dqn"}, {"env_step": 2800000, "rew": 593.8799987792969, "rew_std": 178.6578949602182, "Agent": "dqn"}, {"env_step": 2900000, "rew": 685.2, "rew_std": 118.27089032922618, "Agent": "dqn"}, {"env_step": 3000000, "rew": 683.2299926757812, "rew_std": 131.34903655399373, "Agent": "dqn"}, {"env_step": 3100000, "rew": 662.1999938964843, "rew_std": 109.7032239820281, "Agent": "dqn"}, {"env_step": 3200000, 
"rew": 701.2099975585937, "rew_std": 88.04874320833147, "Agent": "dqn"}, {"env_step": 3300000, "rew": 688.9199981689453, "rew_std": 126.02541709175058, "Agent": "dqn"}, {"env_step": 3400000, "rew": 636.0799987792968, "rew_std": 172.2193654656943, "Agent": "dqn"}, {"env_step": 3500000, "rew": 653.7300018310547, "rew_std": 166.93894277813533, "Agent": "dqn"}, {"env_step": 3600000, "rew": 684.3899963378906, "rew_std": 172.87680391908185, "Agent": "dqn"}, {"env_step": 3700000, "rew": 643.3400039672852, "rew_std": 180.69860945189737, "Agent": "dqn"}, {"env_step": 3800000, "rew": 601.6399993896484, "rew_std": 220.07335285006258, "Agent": "dqn"}, {"env_step": 3900000, "rew": 787.8099914550781, "rew_std": 150.5371983171373, "Agent": "dqn"}, {"env_step": 4000000, "rew": 709.3800048828125, "rew_std": 144.69985159836313, "Agent": "dqn"}, {"env_step": 4100000, "rew": 764.1300018310546, "rew_std": 195.3602195608862, "Agent": "dqn"}, {"env_step": 4200000, "rew": 680.2700012207031, "rew_std": 210.1577440975007, "Agent": "dqn"}, {"env_step": 4300000, "rew": 705.5600036621094, "rew_std": 216.64282568222822, "Agent": "dqn"}, {"env_step": 4400000, "rew": 808.1700073242188, "rew_std": 171.82676781938977, "Agent": "dqn"}, {"env_step": 4500000, "rew": 715.6900006294251, "rew_std": 325.7469276625226, "Agent": "dqn"}, {"env_step": 4600000, "rew": 732.2800018310547, "rew_std": 201.57531656345432, "Agent": "dqn"}, {"env_step": 4700000, "rew": 786.0399963378907, "rew_std": 166.10157244455863, "Agent": "dqn"}, {"env_step": 4800000, "rew": 786.3699920654296, "rew_std": 178.73539138167627, "Agent": "dqn"}, {"env_step": 4900000, "rew": 775.7399993896485, "rew_std": 224.3126871078825, "Agent": "dqn"}, {"env_step": 5000000, "rew": 837.7899932861328, "rew_std": 153.09056362847045, "Agent": "dqn"}, {"env_step": 5100000, "rew": 830.0400085449219, "rew_std": 160.10900182067942, "Agent": "dqn"}, {"env_step": 5200000, "rew": 823.5699981689453, "rew_std": 194.98995547995364, "Agent": "dqn"}, {"env_step": 5300000, "rew": 855.3900024414063, "rew_std": 140.50763881053481, "Agent": "dqn"}, {"env_step": 5400000, "rew": 894.0799865722656, "rew_std": 140.23513752874317, "Agent": "dqn"}, {"env_step": 5500000, "rew": 833.1599899291992, "rew_std": 252.01737545461143, "Agent": "dqn"}, {"env_step": 5600000, "rew": 810.3499969482422, "rew_std": 203.2235554729935, "Agent": "dqn"}, {"env_step": 5700000, "rew": 725.3200134277344, "rew_std": 319.200193375417, "Agent": "dqn"}, {"env_step": 5800000, "rew": 766.3399963378906, "rew_std": 245.59766336856708, "Agent": "dqn"}, {"env_step": 5900000, "rew": 824.45, "rew_std": 157.8227175018482, "Agent": "dqn"}, {"env_step": 6000000, "rew": 839.1499938964844, "rew_std": 267.6931744720739, "Agent": "dqn"}, {"env_step": 6100000, "rew": 911.7200012207031, "rew_std": 149.13539534554963, "Agent": "dqn"}, {"env_step": 6200000, "rew": 865.1299987792969, "rew_std": 151.98633776803706, "Agent": "dqn"}, {"env_step": 6300000, "rew": 701.3800109863281, "rew_std": 231.76907968297863, "Agent": "dqn"}, {"env_step": 6400000, "rew": 848.5299987792969, "rew_std": 142.20331551456655, "Agent": "dqn"}, {"env_step": 6500000, "rew": 857.1699829101562, "rew_std": 245.273054002329, "Agent": "dqn"}, {"env_step": 6600000, "rew": 872.0099975585938, "rew_std": 275.1938175427823, "Agent": "dqn"}, {"env_step": 6700000, "rew": 780.0400024414063, "rew_std": 218.67946501129768, "Agent": "dqn"}, {"env_step": 6800000, "rew": 972.5299926757813, "rew_std": 84.68238747454869, "Agent": "dqn"}, {"env_step": 6900000, "rew": 839.6300048828125, 
"rew_std": 172.06233073457074, "Agent": "dqn"}, {"env_step": 7000000, "rew": 765.7000030517578, "rew_std": 228.73327046919016, "Agent": "dqn"}, {"env_step": 7100000, "rew": 803.6699981689453, "rew_std": 223.20593976154984, "Agent": "dqn"}, {"env_step": 7200000, "rew": 869.5399948120117, "rew_std": 296.05559318233026, "Agent": "dqn"}, {"env_step": 7300000, "rew": 899.0700073242188, "rew_std": 223.90129458785196, "Agent": "dqn"}, {"env_step": 7400000, "rew": 894.3500122070312, "rew_std": 124.86718242625085, "Agent": "dqn"}, {"env_step": 7500000, "rew": 844.4800048828125, "rew_std": 178.28963708114927, "Agent": "dqn"}, {"env_step": 7600000, "rew": 832.0400024414063, "rew_std": 169.4274077668312, "Agent": "dqn"}, {"env_step": 7700000, "rew": 792.3100036621094, "rew_std": 230.68153832883345, "Agent": "dqn"}, {"env_step": 7800000, "rew": 803.5799963474274, "rew_std": 309.8821414828576, "Agent": "dqn"}, {"env_step": 7900000, "rew": 673.1000030517578, "rew_std": 227.6657908497509, "Agent": "dqn"}, {"env_step": 8000000, "rew": 902.8799987792969, "rew_std": 131.25283923863736, "Agent": "dqn"}, {"env_step": 8100000, "rew": 724.4699996948242, "rew_std": 295.4643426314042, "Agent": "dqn"}, {"env_step": 8200000, "rew": 927.2100036621093, "rew_std": 143.6894016940865, "Agent": "dqn"}, {"env_step": 8300000, "rew": 942.3600006103516, "rew_std": 285.8560329981964, "Agent": "dqn"}, {"env_step": 8400000, "rew": 851.1799987792969, "rew_std": 200.8229014095921, "Agent": "dqn"}, {"env_step": 8500000, "rew": 901.9700012207031, "rew_std": 160.8610738450723, "Agent": "dqn"}, {"env_step": 8600000, "rew": 871.3400024414062, "rew_std": 190.1946167222594, "Agent": "dqn"}, {"env_step": 8700000, "rew": 833.9000122070313, "rew_std": 240.40292224599622, "Agent": "dqn"}, {"env_step": 8800000, "rew": 869.95, "rew_std": 184.55665785734556, "Agent": "dqn"}, {"env_step": 8900000, "rew": 875.4000030517578, "rew_std": 285.9085977204755, "Agent": "dqn"}, {"env_step": 9000000, "rew": 867.3599975585937, "rew_std": 361.56061577070335, "Agent": "dqn"}, {"env_step": 9100000, "rew": 856.3100036621094, "rew_std": 315.8472236982946, "Agent": "dqn"}, {"env_step": 9200000, "rew": 856.9499938964843, "rew_std": 170.14471559957494, "Agent": "dqn"}, {"env_step": 9300000, "rew": 888.9199951171875, "rew_std": 168.72139042881585, "Agent": "dqn"}, {"env_step": 9400000, "rew": 866.0400131225585, "rew_std": 299.502088920062, "Agent": "dqn"}, {"env_step": 9500000, "rew": 840.8400024414062, "rew_std": 331.5126113915152, "Agent": "dqn"}, {"env_step": 9600000, "rew": 807.6199981689454, "rew_std": 295.85304518327024, "Agent": "dqn"}, {"env_step": 9700000, "rew": 997.8700073242187, "rew_std": 180.62872163215707, "Agent": "dqn"}, {"env_step": 9800000, "rew": 902.3699951171875, "rew_std": 191.78846507427704, "Agent": "dqn"}, {"env_step": 9900000, "rew": 832.2099884033203, "rew_std": 279.22019714027266, "Agent": "dqn"}, {"env_step": 10000000, "rew": 833.1200012207031, "rew_std": 180.39020450524953, "Agent": "dqn"}, {"env_step": 0, "rew": 0.010000000149011612, "rew_std": 0.03000000044703483, "Agent": "fqf"}, {"env_step": 100000, "rew": 0.4000000074505806, "rew_std": 0.5949790043492437, "Agent": "fqf"}, {"env_step": 200000, "rew": 0.5199999868869781, "rew_std": 0.9031057265445962, "Agent": "fqf"}, {"env_step": 300000, "rew": 0.020000000298023225, "rew_std": 0.04000000059604644, "Agent": "fqf"}, {"env_step": 400000, "rew": 0.1600000001490116, "rew_std": 0.4476605856920158, "Agent": "fqf"}, {"env_step": 500000, "rew": 0.06999999880790711, "rew_std": 
0.2099999964237213, "Agent": "fqf"}, {"env_step": 600000, "rew": 3.8100000239908693, "rew_std": 5.602936729313963, "Agent": "fqf"}, {"env_step": 700000, "rew": 12.430000038444996, "rew_std": 19.787372405641634, "Agent": "fqf"}, {"env_step": 800000, "rew": 85.2499984741211, "rew_std": 63.42670047163025, "Agent": "fqf"}, {"env_step": 900000, "rew": 133.19000046253205, "rew_std": 70.73937261722692, "Agent": "fqf"}, {"env_step": 1000000, "rew": 305.75999908447267, "rew_std": 103.32499176000272, "Agent": "fqf"}, {"env_step": 1100000, "rew": 413.0800033569336, "rew_std": 109.4271524578977, "Agent": "fqf"}, {"env_step": 1200000, "rew": 447.47999572753906, "rew_std": 130.462502495816, "Agent": "fqf"}, {"env_step": 1300000, "rew": 518.3400054931641, "rew_std": 135.58270493310644, "Agent": "fqf"}, {"env_step": 1400000, "rew": 657.4699951171875, "rew_std": 146.84401743315414, "Agent": "fqf"}, {"env_step": 1500000, "rew": 748.0899963378906, "rew_std": 115.51613589716597, "Agent": "fqf"}, {"env_step": 1600000, "rew": 675.2199920654297, "rew_std": 152.8333391660496, "Agent": "fqf"}, {"env_step": 1700000, "rew": 711.7800109863281, "rew_std": 140.070698003054, "Agent": "fqf"}, {"env_step": 1800000, "rew": 768.7699981689453, "rew_std": 205.88203217918607, "Agent": "fqf"}, {"env_step": 1900000, "rew": 830.1599975585938, "rew_std": 164.96536198083228, "Agent": "fqf"}, {"env_step": 2000000, "rew": 838.0, "rew_std": 134.02208935127933, "Agent": "fqf"}, {"env_step": 2100000, "rew": 860.7899963378907, "rew_std": 130.38507918675546, "Agent": "fqf"}, {"env_step": 2200000, "rew": 977.0700073242188, "rew_std": 126.20982209871863, "Agent": "fqf"}, {"env_step": 2300000, "rew": 950.0299987792969, "rew_std": 156.14618566531715, "Agent": "fqf"}, {"env_step": 2400000, "rew": 846.3499877929687, "rew_std": 248.7892251417476, "Agent": "fqf"}, {"env_step": 2500000, "rew": 992.1199951171875, "rew_std": 115.97767285528869, "Agent": "fqf"}, {"env_step": 2600000, "rew": 987.1100036621094, "rew_std": 102.14894278448398, "Agent": "fqf"}, {"env_step": 2700000, "rew": 1004.62001953125, "rew_std": 195.0578605917069, "Agent": "fqf"}, {"env_step": 2800000, "rew": 972.389990234375, "rew_std": 155.29393383794988, "Agent": "fqf"}, {"env_step": 2900000, "rew": 948.7399932861329, "rew_std": 274.68796913449233, "Agent": "fqf"}, {"env_step": 3000000, "rew": 1058.3299926757813, "rew_std": 170.51883401478082, "Agent": "fqf"}, {"env_step": 3100000, "rew": 1085.8400024414063, "rew_std": 143.12279720739653, "Agent": "fqf"}, {"env_step": 3200000, "rew": 1149.7299865722657, "rew_std": 218.88763506679157, "Agent": "fqf"}, {"env_step": 3300000, "rew": 1151.3599914550782, "rew_std": 190.10523364130052, "Agent": "fqf"}, {"env_step": 3400000, "rew": 1080.9000122070313, "rew_std": 332.21089689623733, "Agent": "fqf"}, {"env_step": 3500000, "rew": 1113.7200073242188, "rew_std": 122.5017752266977, "Agent": "fqf"}, {"env_step": 3600000, "rew": 1153.9199951171875, "rew_std": 110.683608376239, "Agent": "fqf"}, {"env_step": 3700000, "rew": 1151.8500061035156, "rew_std": 253.38377103265688, "Agent": "fqf"}, {"env_step": 3800000, "rew": 1277.9900146484374, "rew_std": 112.66564272737654, "Agent": "fqf"}, {"env_step": 3900000, "rew": 1282.3099731445313, "rew_std": 217.32348002992126, "Agent": "fqf"}, {"env_step": 4000000, "rew": 1248.0900024414063, "rew_std": 170.67968176451825, "Agent": "fqf"}, {"env_step": 4100000, "rew": 1322.8300170898438, "rew_std": 153.41429245269322, "Agent": "fqf"}, {"env_step": 4200000, "rew": 1236.4299896240234, "rew_std": 
308.0954582465107, "Agent": "fqf"}, {"env_step": 4300000, "rew": 1342.439990234375, "rew_std": 193.51426337682267, "Agent": "fqf"}, {"env_step": 4400000, "rew": 1149.9899963378907, "rew_std": 244.50325562855912, "Agent": "fqf"}, {"env_step": 4500000, "rew": 1353.7800048828126, "rew_std": 228.50535375228858, "Agent": "fqf"}, {"env_step": 4600000, "rew": 1196.4300048828125, "rew_std": 221.77920221529982, "Agent": "fqf"}, {"env_step": 4700000, "rew": 1338.169989013672, "rew_std": 258.7844697757098, "Agent": "fqf"}, {"env_step": 4800000, "rew": 1455.7900146484376, "rew_std": 222.74215426809056, "Agent": "fqf"}, {"env_step": 4900000, "rew": 1465.7900024414062, "rew_std": 211.17771615384086, "Agent": "fqf"}, {"env_step": 5000000, "rew": 1360.9300048828125, "rew_std": 158.9558026017405, "Agent": "fqf"}, {"env_step": 5100000, "rew": 1326.210009765625, "rew_std": 153.008085562372, "Agent": "fqf"}, {"env_step": 5200000, "rew": 1324.9800048828124, "rew_std": 192.30021302076855, "Agent": "fqf"}, {"env_step": 5300000, "rew": 1373.3199951171875, "rew_std": 190.05878674516413, "Agent": "fqf"}, {"env_step": 5400000, "rew": 1444.8299926757813, "rew_std": 176.63967297772865, "Agent": "fqf"}, {"env_step": 5500000, "rew": 1380.8700012207032, "rew_std": 275.76877411799353, "Agent": "fqf"}, {"env_step": 5600000, "rew": 1449.3299926757813, "rew_std": 189.39536052457436, "Agent": "fqf"}, {"env_step": 5700000, "rew": 1399.7899780273438, "rew_std": 186.23228867562835, "Agent": "fqf"}, {"env_step": 5800000, "rew": 1526.0900024414063, "rew_std": 226.097034699169, "Agent": "fqf"}, {"env_step": 5900000, "rew": 1323.17001953125, "rew_std": 214.75273220122753, "Agent": "fqf"}, {"env_step": 6000000, "rew": 1335.5700073242188, "rew_std": 199.16455799287937, "Agent": "fqf"}, {"env_step": 6100000, "rew": 1455.260009765625, "rew_std": 188.3336061850183, "Agent": "fqf"}, {"env_step": 6200000, "rew": 1460.6299926757813, "rew_std": 168.61568012736888, "Agent": "fqf"}, {"env_step": 6300000, "rew": 1531.5299926757812, "rew_std": 223.6121634917823, "Agent": "fqf"}, {"env_step": 6400000, "rew": 1473.3, "rew_std": 157.68532141529403, "Agent": "fqf"}, {"env_step": 6500000, "rew": 1348.3100036621095, "rew_std": 336.5419653083397, "Agent": "fqf"}, {"env_step": 6600000, "rew": 1360.3900024414063, "rew_std": 385.11999323091857, "Agent": "fqf"}, {"env_step": 6700000, "rew": 1525.3900024414063, "rew_std": 223.73438160213453, "Agent": "fqf"}, {"env_step": 6800000, "rew": 1424.3700134277344, "rew_std": 227.2740165319224, "Agent": "fqf"}, {"env_step": 6900000, "rew": 1444.5199951171876, "rew_std": 206.8632128076105, "Agent": "fqf"}, {"env_step": 7000000, "rew": 1550.7000244140625, "rew_std": 243.50839787358856, "Agent": "fqf"}, {"env_step": 7100000, "rew": 1510.7899963378907, "rew_std": 258.7483723817995, "Agent": "fqf"}, {"env_step": 7200000, "rew": 1483.280010986328, "rew_std": 264.7287475532349, "Agent": "fqf"}, {"env_step": 7300000, "rew": 1499.6600219726563, "rew_std": 341.719842324931, "Agent": "fqf"}, {"env_step": 7400000, "rew": 1687.2500244140624, "rew_std": 256.1505463645012, "Agent": "fqf"}, {"env_step": 7500000, "rew": 1454.0300048828126, "rew_std": 280.16069877794627, "Agent": "fqf"}, {"env_step": 7600000, "rew": 1593.1700012207032, "rew_std": 356.6653348908056, "Agent": "fqf"}, {"env_step": 7700000, "rew": 1752.35, "rew_std": 272.67251651423953, "Agent": "fqf"}, {"env_step": 7800000, "rew": 1424.1700073242187, "rew_std": 240.21965983357364, "Agent": "fqf"}, {"env_step": 7900000, "rew": 1545.0499877929688, "rew_std": 
274.2309684007817, "Agent": "fqf"}, {"env_step": 8000000, "rew": 1491.2900024414062, "rew_std": 221.5089150627939, "Agent": "fqf"}, {"env_step": 8100000, "rew": 1686.5300170898438, "rew_std": 233.44947282580907, "Agent": "fqf"}, {"env_step": 8200000, "rew": 1654.559991455078, "rew_std": 257.67832128859897, "Agent": "fqf"}, {"env_step": 8300000, "rew": 1608.6599975585937, "rew_std": 238.0655354689683, "Agent": "fqf"}, {"env_step": 8400000, "rew": 1575.3399780273437, "rew_std": 203.41885145877984, "Agent": "fqf"}, {"env_step": 8500000, "rew": 1501.899984741211, "rew_std": 443.3008682851743, "Agent": "fqf"}, {"env_step": 8600000, "rew": 1344.8300048828125, "rew_std": 226.3636211692771, "Agent": "fqf"}, {"env_step": 8700000, "rew": 1358.25, "rew_std": 207.60759860147468, "Agent": "fqf"}, {"env_step": 8800000, "rew": 1577.8999877929687, "rew_std": 218.78633687041156, "Agent": "fqf"}, {"env_step": 8900000, "rew": 1816.8199951171875, "rew_std": 314.3398063009122, "Agent": "fqf"}, {"env_step": 9000000, "rew": 1508.2200073242188, "rew_std": 202.8977505176873, "Agent": "fqf"}, {"env_step": 9100000, "rew": 1388.3, "rew_std": 358.01253786434296, "Agent": "fqf"}, {"env_step": 9200000, "rew": 1657.6999877929688, "rew_std": 162.50743736286006, "Agent": "fqf"}, {"env_step": 9300000, "rew": 1769.5699829101563, "rew_std": 430.9407204209698, "Agent": "fqf"}, {"env_step": 9400000, "rew": 1644.0599975585938, "rew_std": 363.31351881469755, "Agent": "fqf"}, {"env_step": 9500000, "rew": 1774.7999755859375, "rew_std": 458.10963716513834, "Agent": "fqf"}, {"env_step": 9600000, "rew": 1574.6399780273437, "rew_std": 286.1478292951951, "Agent": "fqf"}, {"env_step": 9700000, "rew": 1621.3900024414063, "rew_std": 203.88568091812692, "Agent": "fqf"}, {"env_step": 9800000, "rew": 1800.6699829101562, "rew_std": 246.23150505646822, "Agent": "fqf"}, {"env_step": 9900000, "rew": 1717.560009765625, "rew_std": 272.92596987574973, "Agent": "fqf"}, {"env_step": 10000000, "rew": 1663.030029296875, "rew_std": 215.58594858353038, "Agent": "fqf"}, {"env_step": 0, "rew": 0.2, "rew_std": 0.43817804165122526, "Agent": "qrdqn"}, {"env_step": 100000, "rew": 2.950000001490116, "rew_std": 8.683806768461812, "Agent": "qrdqn"}, {"env_step": 200000, "rew": 2.750000011920929, "rew_std": 7.12281545355883, "Agent": "qrdqn"}, {"env_step": 300000, "rew": 2.139999923855066, "rew_std": 6.060560801476709, "Agent": "qrdqn"}, {"env_step": 400000, "rew": 8.719999969005585, "rew_std": 12.750513633943633, "Agent": "qrdqn"}, {"env_step": 500000, "rew": 0.020000000298023225, "rew_std": 0.06000000089406966, "Agent": "qrdqn"}, {"env_step": 600000, "rew": 4.539999961853027, "rew_std": 8.518591324669725, "Agent": "qrdqn"}, {"env_step": 700000, "rew": 20.8, "rew_std": 19.17268910037082, "Agent": "qrdqn"}, {"env_step": 800000, "rew": 64.31000022888183, "rew_std": 55.11154952769602, "Agent": "qrdqn"}, {"env_step": 900000, "rew": 117.36000137329101, "rew_std": 81.42853758823472, "Agent": "qrdqn"}, {"env_step": 1000000, "rew": 212.5300022125244, "rew_std": 134.6447147733004, "Agent": "qrdqn"}, {"env_step": 1100000, "rew": 311.71999740600586, "rew_std": 143.0127312672216, "Agent": "qrdqn"}, {"env_step": 1200000, "rew": 427.21999702453616, "rew_std": 166.6607249218663, "Agent": "qrdqn"}, {"env_step": 1300000, "rew": 450.81999626159666, "rew_std": 170.6163965376507, "Agent": "qrdqn"}, {"env_step": 1400000, "rew": 446.6700017929077, "rew_std": 171.28357388934182, "Agent": "qrdqn"}, {"env_step": 1500000, "rew": 501.95000114440916, "rew_std": 214.17296736360868, "Agent": 
"qrdqn"}, {"env_step": 1600000, "rew": 515.3500011444091, "rew_std": 228.24397876121898, "Agent": "qrdqn"}, {"env_step": 1700000, "rew": 532.7900005340576, "rew_std": 260.70330697040697, "Agent": "qrdqn"}, {"env_step": 1800000, "rew": 535.5499984741211, "rew_std": 226.46023524425055, "Agent": "qrdqn"}, {"env_step": 1900000, "rew": 614.3999969482422, "rew_std": 229.86442537632706, "Agent": "qrdqn"}, {"env_step": 2000000, "rew": 495.1, "rew_std": 277.79406604885946, "Agent": "qrdqn"}, {"env_step": 2100000, "rew": 585.6600044250488, "rew_std": 230.43919543285637, "Agent": "qrdqn"}, {"env_step": 2200000, "rew": 716.0299999237061, "rew_std": 264.0099569711277, "Agent": "qrdqn"}, {"env_step": 2300000, "rew": 542.9899975776673, "rew_std": 303.6674555033, "Agent": "qrdqn"}, {"env_step": 2400000, "rew": 715.6400062561036, "rew_std": 286.50699130558786, "Agent": "qrdqn"}, {"env_step": 2500000, "rew": 648.0000051498413, "rew_std": 301.31887025376795, "Agent": "qrdqn"}, {"env_step": 2600000, "rew": 593.8800014495849, "rew_std": 302.4725858413253, "Agent": "qrdqn"}, {"env_step": 2700000, "rew": 662.1199962615967, "rew_std": 292.98960962786776, "Agent": "qrdqn"}, {"env_step": 2800000, "rew": 695.9800054550171, "rew_std": 284.20259753123815, "Agent": "qrdqn"}, {"env_step": 2900000, "rew": 729.7400060653687, "rew_std": 261.0564239167806, "Agent": "qrdqn"}, {"env_step": 3000000, "rew": 765.3200073242188, "rew_std": 267.8584904904683, "Agent": "qrdqn"}, {"env_step": 3100000, "rew": 763.0400096893311, "rew_std": 257.2424218386392, "Agent": "qrdqn"}, {"env_step": 3200000, "rew": 775.7299938201904, "rew_std": 273.62065502831194, "Agent": "qrdqn"}, {"env_step": 3300000, "rew": 768.5900030136108, "rew_std": 293.28401432770056, "Agent": "qrdqn"}, {"env_step": 3400000, "rew": 736.4700023651124, "rew_std": 251.00099161324306, "Agent": "qrdqn"}, {"env_step": 3500000, "rew": 704.0100011825562, "rew_std": 279.16315078624314, "Agent": "qrdqn"}, {"env_step": 3600000, "rew": 711.050004196167, "rew_std": 264.62492121177553, "Agent": "qrdqn"}, {"env_step": 3700000, "rew": 801.9700037002564, "rew_std": 277.7941106570264, "Agent": "qrdqn"}, {"env_step": 3800000, "rew": 844.069990158081, "rew_std": 287.56283848408725, "Agent": "qrdqn"}, {"env_step": 3900000, "rew": 759.6999963760376, "rew_std": 279.6283181481931, "Agent": "qrdqn"}, {"env_step": 4000000, "rew": 583.4499931335449, "rew_std": 287.06731842027597, "Agent": "qrdqn"}, {"env_step": 4100000, "rew": 749.8800132751464, "rew_std": 264.1112287832134, "Agent": "qrdqn"}, {"env_step": 4200000, "rew": 793.8099964141845, "rew_std": 302.25798478353556, "Agent": "qrdqn"}, {"env_step": 4300000, "rew": 721.4700023651124, "rew_std": 280.748223525854, "Agent": "qrdqn"}, {"env_step": 4400000, "rew": 782.9300025939941, "rew_std": 341.9947531259377, "Agent": "qrdqn"}, {"env_step": 4500000, "rew": 861.8200061798095, "rew_std": 300.90433581636256, "Agent": "qrdqn"}, {"env_step": 4600000, "rew": 699.4100095748902, "rew_std": 358.53734373918525, "Agent": "qrdqn"}, {"env_step": 4700000, "rew": 763.3199853897095, "rew_std": 321.5669115617899, "Agent": "qrdqn"}, {"env_step": 4800000, "rew": 875.8500049591064, "rew_std": 322.19139996909126, "Agent": "qrdqn"}, {"env_step": 4900000, "rew": 798.370009803772, "rew_std": 320.01471740842925, "Agent": "qrdqn"}, {"env_step": 5000000, "rew": 916.0799865722656, "rew_std": 321.69201590196326, "Agent": "qrdqn"}, {"env_step": 5100000, "rew": 854.7900035858154, "rew_std": 282.8008511766018, "Agent": "qrdqn"}, {"env_step": 5200000, "rew": 778.5300037384034, 
"rew_std": 300.45106950066236, "Agent": "qrdqn"}, {"env_step": 5300000, "rew": 824.4299976348877, "rew_std": 302.3329753994756, "Agent": "qrdqn"}, {"env_step": 5400000, "rew": 888.6600048065186, "rew_std": 338.1730635095935, "Agent": "qrdqn"}, {"env_step": 5500000, "rew": 839.840009689331, "rew_std": 347.04786183662765, "Agent": "qrdqn"}, {"env_step": 5600000, "rew": 743.8000047683715, "rew_std": 372.26147487134074, "Agent": "qrdqn"}, {"env_step": 5700000, "rew": 867.8499877929687, "rew_std": 316.0468936380672, "Agent": "qrdqn"}, {"env_step": 5800000, "rew": 823.3300037384033, "rew_std": 323.5871449731984, "Agent": "qrdqn"}, {"env_step": 5900000, "rew": 840.5399921417236, "rew_std": 379.45472835781385, "Agent": "qrdqn"}, {"env_step": 6000000, "rew": 795.1999963760376, "rew_std": 305.4413360252164, "Agent": "qrdqn"}, {"env_step": 6100000, "rew": 837.2100086212158, "rew_std": 294.2474416713455, "Agent": "qrdqn"}, {"env_step": 6200000, "rew": 832.8199975967407, "rew_std": 315.05369700324695, "Agent": "qrdqn"}, {"env_step": 6300000, "rew": 758.1000026702881, "rew_std": 356.9235426531571, "Agent": "qrdqn"}, {"env_step": 6400000, "rew": 869.2500061035156, "rew_std": 302.04547033603006, "Agent": "qrdqn"}, {"env_step": 6500000, "rew": 785.2299983978271, "rew_std": 393.64854338638054, "Agent": "qrdqn"}, {"env_step": 6600000, "rew": 790.1000085830689, "rew_std": 365.117793324928, "Agent": "qrdqn"}, {"env_step": 6700000, "rew": 871.6399927139282, "rew_std": 307.94262149553043, "Agent": "qrdqn"}, {"env_step": 6800000, "rew": 769.5600109100342, "rew_std": 397.7108398542242, "Agent": "qrdqn"}, {"env_step": 6900000, "rew": 897.1599975585938, "rew_std": 326.12778476504053, "Agent": "qrdqn"}, {"env_step": 7000000, "rew": 826.01999874115, "rew_std": 303.25010067918225, "Agent": "qrdqn"}, {"env_step": 7100000, "rew": 899.7099914550781, "rew_std": 354.635686696347, "Agent": "qrdqn"}, {"env_step": 7200000, "rew": 839.5300022125244, "rew_std": 366.1003253391808, "Agent": "qrdqn"}, {"env_step": 7300000, "rew": 789.2700035095215, "rew_std": 325.517405634325, "Agent": "qrdqn"}, {"env_step": 7400000, "rew": 791.4800006866456, "rew_std": 308.0200619677834, "Agent": "qrdqn"}, {"env_step": 7500000, "rew": 753.9000038146972, "rew_std": 353.6681404822802, "Agent": "qrdqn"}, {"env_step": 7600000, "rew": 760.5899974822999, "rew_std": 373.76654972757194, "Agent": "qrdqn"}, {"env_step": 7700000, "rew": 895.6800060272217, "rew_std": 332.0611307126876, "Agent": "qrdqn"}, {"env_step": 7800000, "rew": 797.7900001525879, "rew_std": 337.71737028991475, "Agent": "qrdqn"}, {"env_step": 7900000, "rew": 863.1199945449829, "rew_std": 383.8510796441181, "Agent": "qrdqn"}, {"env_step": 8000000, "rew": 936.8600036621094, "rew_std": 335.0957512819982, "Agent": "qrdqn"}, {"env_step": 8100000, "rew": 873.0900012969971, "rew_std": 323.48754155145383, "Agent": "qrdqn"}, {"env_step": 8200000, "rew": 897.2599962234497, "rew_std": 430.1072599184845, "Agent": "qrdqn"}, {"env_step": 8300000, "rew": 925.4600048065186, "rew_std": 328.0253425434283, "Agent": "qrdqn"}, {"env_step": 8400000, "rew": 800.1099956512451, "rew_std": 369.4526365613206, "Agent": "qrdqn"}, {"env_step": 8500000, "rew": 764.4199901580811, "rew_std": 386.03387542422104, "Agent": "qrdqn"}, {"env_step": 8600000, "rew": 951.6600109100342, "rew_std": 333.5321850552157, "Agent": "qrdqn"}, {"env_step": 8700000, "rew": 746.5100072860718, "rew_std": 286.61728504504305, "Agent": "qrdqn"}, {"env_step": 8800000, "rew": 849.2099956512451, "rew_std": 376.9793778487278, "Agent": "qrdqn"}, 
{"env_step": 8900000, "rew": 804.9700115203857, "rew_std": 377.45671631618694, "Agent": "qrdqn"}, {"env_step": 9000000, "rew": 855.3400024414062, "rew_std": 335.099863589258, "Agent": "qrdqn"}, {"env_step": 9100000, "rew": 660.5900043487549, "rew_std": 356.5631885015981, "Agent": "qrdqn"}, {"env_step": 9200000, "rew": 878.2999950408936, "rew_std": 312.5758680927236, "Agent": "qrdqn"}, {"env_step": 9300000, "rew": 826.950011062622, "rew_std": 350.83283915590056, "Agent": "qrdqn"}, {"env_step": 9400000, "rew": 790.7599872589111, "rew_std": 401.66813659775227, "Agent": "qrdqn"}, {"env_step": 9500000, "rew": 849.3099939346314, "rew_std": 313.7882372406271, "Agent": "qrdqn"}, {"env_step": 9600000, "rew": 854.3600103378296, "rew_std": 367.7412676045596, "Agent": "qrdqn"}, {"env_step": 9700000, "rew": 803.1500019073486, "rew_std": 384.89482159028836, "Agent": "qrdqn"}, {"env_step": 9800000, "rew": 655.2900049209595, "rew_std": 378.5388990350636, "Agent": "qrdqn"}, {"env_step": 9900000, "rew": 778.3899938583374, "rew_std": 332.5220751161268, "Agent": "qrdqn"}, {"env_step": 10000000, "rew": 805.2999959945679, "rew_std": 376.93931456836924, "Agent": "qrdqn"}, {"env_step": 0, "rew": 0.020000000298023225, "rew_std": 0.06000000089406966, "Agent": "iqn"}, {"env_step": 100000, "rew": 1.2300000190734863, "rew_std": 3.52648555976251, "Agent": "iqn"}, {"env_step": 200000, "rew": 0.17000000029802323, "rew_std": 0.2193171198611889, "Agent": "iqn"}, {"env_step": 300000, "rew": 2.0799999237060547, "rew_std": 5.910634254899296, "Agent": "iqn"}, {"env_step": 400000, "rew": 0.04000000059604645, "rew_std": 0.08000000119209288, "Agent": "iqn"}, {"env_step": 500000, "rew": 3.5900000773370264, "rew_std": 10.53740503238059, "Agent": "iqn"}, {"env_step": 600000, "rew": 4.530000066757202, "rew_std": 9.554166769060705, "Agent": "iqn"}, {"env_step": 700000, "rew": 3.480000114440918, "rew_std": 6.673799742209458, "Agent": "iqn"}, {"env_step": 800000, "rew": 37.689999313652514, "rew_std": 42.65694315548521, "Agent": "iqn"}, {"env_step": 900000, "rew": 123.90999913215637, "rew_std": 116.8702046746665, "Agent": "iqn"}, {"env_step": 1000000, "rew": 214.27999999523163, "rew_std": 131.19766379013453, "Agent": "iqn"}, {"env_step": 1100000, "rew": 314.01000213623047, "rew_std": 130.61631856241644, "Agent": "iqn"}, {"env_step": 1200000, "rew": 447.6299987792969, "rew_std": 132.88649299306232, "Agent": "iqn"}, {"env_step": 1300000, "rew": 488.24000244140626, "rew_std": 144.3410893300805, "Agent": "iqn"}, {"env_step": 1400000, "rew": 562.2800018310547, "rew_std": 126.35490427347918, "Agent": "iqn"}, {"env_step": 1500000, "rew": 503.0600067138672, "rew_std": 126.1369021393955, "Agent": "iqn"}, {"env_step": 1600000, "rew": 590.9699981689453, "rew_std": 106.22815802626981, "Agent": "iqn"}, {"env_step": 1700000, "rew": 656.8100067138672, "rew_std": 176.85540361563625, "Agent": "iqn"}, {"env_step": 1800000, "rew": 639.1500061035156, "rew_std": 159.90553612644817, "Agent": "iqn"}, {"env_step": 1900000, "rew": 654.3699951171875, "rew_std": 166.7387170420838, "Agent": "iqn"}, {"env_step": 2000000, "rew": 633.2700103759765, "rew_std": 194.57762929404822, "Agent": "iqn"}, {"env_step": 2100000, "rew": 700.1900115966797, "rew_std": 157.40320450762798, "Agent": "iqn"}, {"env_step": 2200000, "rew": 628.6800003051758, "rew_std": 207.51145605282667, "Agent": "iqn"}, {"env_step": 2300000, "rew": 684.490007019043, "rew_std": 224.21249061996116, "Agent": "iqn"}, {"env_step": 2400000, "rew": 756.5900001525879, "rew_std": 288.4319216808326, "Agent": 
"iqn"}, {"env_step": 2500000, "rew": 675.5500030517578, "rew_std": 244.21607763740568, "Agent": "iqn"}, {"env_step": 2600000, "rew": 779.7999938964844, "rew_std": 256.5563681757168, "Agent": "iqn"}, {"env_step": 2700000, "rew": 727.7399963378906, "rew_std": 269.84431788518737, "Agent": "iqn"}, {"env_step": 2800000, "rew": 792.0200012207031, "rew_std": 116.56821154467826, "Agent": "iqn"}, {"env_step": 2900000, "rew": 859.7300109863281, "rew_std": 185.18119041013455, "Agent": "iqn"}, {"env_step": 3000000, "rew": 899.7199981689453, "rew_std": 216.57974461966018, "Agent": "iqn"}, {"env_step": 3100000, "rew": 915.3699890136719, "rew_std": 114.51627047111124, "Agent": "iqn"}, {"env_step": 3200000, "rew": 795.7599945068359, "rew_std": 240.49648753795014, "Agent": "iqn"}, {"env_step": 3300000, "rew": 880.1699981689453, "rew_std": 189.51477550814192, "Agent": "iqn"}, {"env_step": 3400000, "rew": 945.1100036621094, "rew_std": 128.7564537028809, "Agent": "iqn"}, {"env_step": 3500000, "rew": 919.9100036621094, "rew_std": 244.82328680147785, "Agent": "iqn"}, {"env_step": 3600000, "rew": 982.8000061035157, "rew_std": 159.50998642265364, "Agent": "iqn"}, {"env_step": 3700000, "rew": 837.7000122070312, "rew_std": 188.50218141170882, "Agent": "iqn"}, {"env_step": 3800000, "rew": 1006.9300048828125, "rew_std": 142.2613367175323, "Agent": "iqn"}, {"env_step": 3900000, "rew": 913.0099868774414, "rew_std": 267.489982408041, "Agent": "iqn"}, {"env_step": 4000000, "rew": 874.3000061035157, "rew_std": 173.92007373390783, "Agent": "iqn"}, {"env_step": 4100000, "rew": 910.5500030517578, "rew_std": 194.92890230647552, "Agent": "iqn"}, {"env_step": 4200000, "rew": 983.5, "rew_std": 116.83092046232777, "Agent": "iqn"}, {"env_step": 4300000, "rew": 901.1400039672851, "rew_std": 305.8871257170003, "Agent": "iqn"}, {"env_step": 4400000, "rew": 813.9199890136719, "rew_std": 259.7093781051844, "Agent": "iqn"}, {"env_step": 4500000, "rew": 975.1299987792969, "rew_std": 249.7706832098956, "Agent": "iqn"}, {"env_step": 4600000, "rew": 964.7699890136719, "rew_std": 288.6829312458577, "Agent": "iqn"}, {"env_step": 4700000, "rew": 990.8800170898437, "rew_std": 227.1040665924821, "Agent": "iqn"}, {"env_step": 4800000, "rew": 1069.3499877929687, "rew_std": 184.13221489797237, "Agent": "iqn"}, {"env_step": 4900000, "rew": 985.4000122070313, "rew_std": 185.19558967958181, "Agent": "iqn"}, {"env_step": 5000000, "rew": 888.0499984741211, "rew_std": 383.0892119253023, "Agent": "iqn"}, {"env_step": 5100000, "rew": 1122.0600036621095, "rew_std": 252.77394487644332, "Agent": "iqn"}, {"env_step": 5200000, "rew": 972.6900054931641, "rew_std": 222.1775487183736, "Agent": "iqn"}, {"env_step": 5300000, "rew": 966.9400115966797, "rew_std": 369.08832261651287, "Agent": "iqn"}, {"env_step": 5400000, "rew": 789.2899993896484, "rew_std": 320.53568647830184, "Agent": "iqn"}, {"env_step": 5500000, "rew": 1027.3899841308594, "rew_std": 133.49564343614747, "Agent": "iqn"}, {"env_step": 5600000, "rew": 872.7399963378906, "rew_std": 283.1106543105209, "Agent": "iqn"}, {"env_step": 5700000, "rew": 1003.5799987792968, "rew_std": 303.12510600006334, "Agent": "iqn"}, {"env_step": 5800000, "rew": 898.3699935913086, "rew_std": 299.9163407129428, "Agent": "iqn"}, {"env_step": 5900000, "rew": 928.5400024414063, "rew_std": 183.30899650150636, "Agent": "iqn"}, {"env_step": 6000000, "rew": 1099.45, "rew_std": 215.64728660357196, "Agent": "iqn"}, {"env_step": 6100000, "rew": 1008.9999969482421, "rew_std": 270.1578310856458, "Agent": "iqn"}, {"env_step": 6200000, 
"rew": 1065.940008544922, "rew_std": 255.40183052553036, "Agent": "iqn"}, {"env_step": 6300000, "rew": 811.1000024795533, "rew_std": 373.20585603601734, "Agent": "iqn"}, {"env_step": 6400000, "rew": 940.3700012207031, "rew_std": 246.35399546539406, "Agent": "iqn"}, {"env_step": 6500000, "rew": 1068.6700012207032, "rew_std": 97.1969648010114, "Agent": "iqn"}, {"env_step": 6600000, "rew": 1245.320001220703, "rew_std": 287.68207875342046, "Agent": "iqn"}, {"env_step": 6700000, "rew": 1029.4099975585937, "rew_std": 181.585287367347, "Agent": "iqn"}, {"env_step": 6800000, "rew": 1042.259991455078, "rew_std": 164.0986750263718, "Agent": "iqn"}, {"env_step": 6900000, "rew": 838.8700035095214, "rew_std": 355.00632184818426, "Agent": "iqn"}, {"env_step": 7000000, "rew": 1098.1199951171875, "rew_std": 197.77309679595174, "Agent": "iqn"}, {"env_step": 7100000, "rew": 929.949984741211, "rew_std": 290.7787575853067, "Agent": "iqn"}, {"env_step": 7200000, "rew": 1002.5799926757812, "rew_std": 238.4512457320423, "Agent": "iqn"}, {"env_step": 7300000, "rew": 936.2500061035156, "rew_std": 200.61891005074025, "Agent": "iqn"}, {"env_step": 7400000, "rew": 1090.2499938964843, "rew_std": 137.09873398122215, "Agent": "iqn"}, {"env_step": 7500000, "rew": 1079.7300170898438, "rew_std": 129.4222508666326, "Agent": "iqn"}, {"env_step": 7600000, "rew": 968.8100051879883, "rew_std": 469.84580201774713, "Agent": "iqn"}, {"env_step": 7700000, "rew": 1022.8900024414063, "rew_std": 251.64726234338931, "Agent": "iqn"}, {"env_step": 7800000, "rew": 1021.4299987792969, "rew_std": 243.66798220474894, "Agent": "iqn"}, {"env_step": 7900000, "rew": 1113.2900024414062, "rew_std": 199.72603151675196, "Agent": "iqn"}, {"env_step": 8000000, "rew": 1132.0199890136719, "rew_std": 263.47352587686873, "Agent": "iqn"}, {"env_step": 8100000, "rew": 1050.8499877929687, "rew_std": 191.7771377236277, "Agent": "iqn"}, {"env_step": 8200000, "rew": 1099.139990234375, "rew_std": 223.08476706246242, "Agent": "iqn"}, {"env_step": 8300000, "rew": 1095.5199951171876, "rew_std": 152.96869354522246, "Agent": "iqn"}, {"env_step": 8400000, "rew": 1059.9700012207031, "rew_std": 121.60402177574996, "Agent": "iqn"}, {"env_step": 8500000, "rew": 1119.9500122070312, "rew_std": 173.45015933529174, "Agent": "iqn"}, {"env_step": 8600000, "rew": 940.7099975585937, "rew_std": 184.9244564548404, "Agent": "iqn"}, {"env_step": 8700000, "rew": 930.7999961853027, "rew_std": 366.8234875758687, "Agent": "iqn"}, {"env_step": 8800000, "rew": 1097.2800170898438, "rew_std": 296.206762846177, "Agent": "iqn"}, {"env_step": 8900000, "rew": 1139.8899780273437, "rew_std": 255.91239554199356, "Agent": "iqn"}, {"env_step": 9000000, "rew": 1043.5500061035157, "rew_std": 173.89118497858877, "Agent": "iqn"}, {"env_step": 9100000, "rew": 929.5200164794921, "rew_std": 373.30865054928415, "Agent": "iqn"}, {"env_step": 9200000, "rew": 1205.3200134277345, "rew_std": 275.42816935716706, "Agent": "iqn"}, {"env_step": 9300000, "rew": 1150.1200012207032, "rew_std": 260.06818848705825, "Agent": "iqn"}, {"env_step": 9400000, "rew": 1100.200018310547, "rew_std": 185.52076935234098, "Agent": "iqn"}, {"env_step": 9500000, "rew": 1058.1600158691406, "rew_std": 311.87799312292907, "Agent": "iqn"}, {"env_step": 9600000, "rew": 1252.6800048828125, "rew_std": 118.09878836211058, "Agent": "iqn"}, {"env_step": 9700000, "rew": 1132.0099853515626, "rew_std": 200.64719895414822, "Agent": "iqn"}, {"env_step": 9800000, "rew": 1039.539990234375, "rew_std": 270.93414588943654, "Agent": "iqn"}, {"env_step": 
9900000, "rew": 1111.9599914550781, "rew_std": 303.33757722581527, "Agent": "iqn"}, {"env_step": 10000000, "rew": 1095.0599853515625, "rew_std": 200.86304116683058, "Agent": "iqn"}, {"env_step": 0, "rew": 0.2100000001490116, "rew_std": 0.5974110812223167, "Agent": "rainbow"}, {"env_step": 100000, "rew": 0.12999999523162842, "rew_std": 0.38999998569488525, "Agent": "rainbow"}, {"env_step": 200000, "rew": 2.7599999859929083, "rew_std": 6.376864428026797, "Agent": "rainbow"}, {"env_step": 300000, "rew": 0.7399999916553497, "rew_std": 1.967841428665537, "Agent": "rainbow"}, {"env_step": 400000, "rew": 0.7299999989569187, "rew_std": 1.9344508264320834, "Agent": "rainbow"}, {"env_step": 500000, "rew": 2.250000037252903, "rew_std": 6.453410065907903, "Agent": "rainbow"}, {"env_step": 600000, "rew": 0.3300000071525574, "rew_std": 0.7043436813838512, "Agent": "rainbow"}, {"env_step": 700000, "rew": 11.450000222027302, "rew_std": 19.350775633693363, "Agent": "rainbow"}, {"env_step": 800000, "rew": 51.87000031471253, "rew_std": 41.72313644425024, "Agent": "rainbow"}, {"env_step": 900000, "rew": 134.95999908447266, "rew_std": 26.553200308986938, "Agent": "rainbow"}, {"env_step": 1000000, "rew": 273.26000061035154, "rew_std": 40.46751980619983, "Agent": "rainbow"}, {"env_step": 1100000, "rew": 394.9699981689453, "rew_std": 35.29900920102157, "Agent": "rainbow"}, {"env_step": 1200000, "rew": 475.4699951171875, "rew_std": 51.0648386902766, "Agent": "rainbow"}, {"env_step": 1300000, "rew": 537.4499969482422, "rew_std": 87.05237149623139, "Agent": "rainbow"}, {"env_step": 1400000, "rew": 528.3800109863281, "rew_std": 74.70653568610189, "Agent": "rainbow"}, {"env_step": 1500000, "rew": 602.5700042724609, "rew_std": 63.9815013284615, "Agent": "rainbow"}, {"env_step": 1600000, "rew": 672.6400024414063, "rew_std": 75.62920855186832, "Agent": "rainbow"}, {"env_step": 1700000, "rew": 670.960009765625, "rew_std": 59.331076612532364, "Agent": "rainbow"}, {"env_step": 1800000, "rew": 704.7300048828125, "rew_std": 67.48957648360094, "Agent": "rainbow"}, {"env_step": 1900000, "rew": 787.0799987792968, "rew_std": 112.42707564022125, "Agent": "rainbow"}, {"env_step": 2000000, "rew": 823.6899963378906, "rew_std": 77.87041479137376, "Agent": "rainbow"}, {"env_step": 2100000, "rew": 840.9600036621093, "rew_std": 68.54743565383826, "Agent": "rainbow"}, {"env_step": 2200000, "rew": 822.8200012207031, "rew_std": 101.75918406873306, "Agent": "rainbow"}, {"env_step": 2300000, "rew": 846.6400024414063, "rew_std": 56.10774301137517, "Agent": "rainbow"}, {"env_step": 2400000, "rew": 935.4899963378906, "rew_std": 81.81529716155883, "Agent": "rainbow"}, {"env_step": 2500000, "rew": 871.6499938964844, "rew_std": 105.67288003470732, "Agent": "rainbow"}, {"env_step": 2600000, "rew": 935.3400085449218, "rew_std": 93.00937834754181, "Agent": "rainbow"}, {"env_step": 2700000, "rew": 962.9700134277343, "rew_std": 60.47081018959421, "Agent": "rainbow"}, {"env_step": 2800000, "rew": 939.1, "rew_std": 76.38223487303658, "Agent": "rainbow"}, {"env_step": 2900000, "rew": 983.4500122070312, "rew_std": 67.66671891975996, "Agent": "rainbow"}, {"env_step": 3000000, "rew": 1005.25, "rew_std": 46.68377918260524, "Agent": "rainbow"}, {"env_step": 3100000, "rew": 1027.85, "rew_std": 100.20803948410037, "Agent": "rainbow"}, {"env_step": 3200000, "rew": 1019.7700134277344, "rew_std": 74.38652294354544, "Agent": "rainbow"}, {"env_step": 3300000, "rew": 1025.939990234375, "rew_std": 67.48715328162088, "Agent": "rainbow"}, {"env_step": 3400000, "rew": 
1048.9500061035155, "rew_std": 63.75370223721345, "Agent": "rainbow"}, {"env_step": 3500000, "rew": 1024.8799865722656, "rew_std": 81.16544192265837, "Agent": "rainbow"}, {"env_step": 3600000, "rew": 1057.9299865722655, "rew_std": 79.86718486180474, "Agent": "rainbow"}, {"env_step": 3700000, "rew": 1059.6200012207032, "rew_std": 100.3108670590954, "Agent": "rainbow"}, {"env_step": 3800000, "rew": 1100.7599975585938, "rew_std": 81.71528182411633, "Agent": "rainbow"}, {"env_step": 3900000, "rew": 1076.8199890136718, "rew_std": 65.37815728364572, "Agent": "rainbow"}, {"env_step": 4000000, "rew": 1191.0, "rew_std": 100.77746153732592, "Agent": "rainbow"}, {"env_step": 4100000, "rew": 1085.950018310547, "rew_std": 93.44744071480993, "Agent": "rainbow"}, {"env_step": 4200000, "rew": 1124.4300048828125, "rew_std": 98.69067996332086, "Agent": "rainbow"}, {"env_step": 4300000, "rew": 1204.660009765625, "rew_std": 81.11235149073514, "Agent": "rainbow"}, {"env_step": 4400000, "rew": 1136.6400024414063, "rew_std": 78.94447133335663, "Agent": "rainbow"}, {"env_step": 4500000, "rew": 1154.5200073242188, "rew_std": 74.92851522386881, "Agent": "rainbow"}, {"env_step": 4600000, "rew": 1206.0400085449219, "rew_std": 103.3929078135147, "Agent": "rainbow"}, {"env_step": 4700000, "rew": 1204.1800170898437, "rew_std": 82.69720477051608, "Agent": "rainbow"}, {"env_step": 4800000, "rew": 1142.9900146484374, "rew_std": 100.03756445774546, "Agent": "rainbow"}, {"env_step": 4900000, "rew": 1199.2599853515626, "rew_std": 43.66213881744983, "Agent": "rainbow"}, {"env_step": 5000000, "rew": 1174.4100036621094, "rew_std": 135.42985480511723, "Agent": "rainbow"}, {"env_step": 5100000, "rew": 1206.4500244140625, "rew_std": 65.40353311217433, "Agent": "rainbow"}, {"env_step": 5200000, "rew": 1213.0400024414062, "rew_std": 56.73084174140543, "Agent": "rainbow"}, {"env_step": 5300000, "rew": 1279.0799926757813, "rew_std": 122.6794261050074, "Agent": "rainbow"}, {"env_step": 5400000, "rew": 1260.5200073242188, "rew_std": 78.9505194850195, "Agent": "rainbow"}, {"env_step": 5500000, "rew": 1181.0700073242188, "rew_std": 114.83173344170228, "Agent": "rainbow"}, {"env_step": 5600000, "rew": 1176.05, "rew_std": 83.41206986065441, "Agent": "rainbow"}, {"env_step": 5700000, "rew": 1270.2599853515626, "rew_std": 124.96829424226486, "Agent": "rainbow"}, {"env_step": 5800000, "rew": 1261.5499755859375, "rew_std": 105.05167725943326, "Agent": "rainbow"}, {"env_step": 5900000, "rew": 1254.1099853515625, "rew_std": 103.52183258934855, "Agent": "rainbow"}, {"env_step": 6000000, "rew": 1285.210009765625, "rew_std": 135.08428673916382, "Agent": "rainbow"}, {"env_step": 6100000, "rew": 1321.8599975585937, "rew_std": 98.54867262115465, "Agent": "rainbow"}, {"env_step": 6200000, "rew": 1270.4499877929688, "rew_std": 134.84840828335157, "Agent": "rainbow"}, {"env_step": 6300000, "rew": 1291.8700073242187, "rew_std": 128.573159656795, "Agent": "rainbow"}, {"env_step": 6400000, "rew": 1372.1099975585937, "rew_std": 88.26145690808981, "Agent": "rainbow"}, {"env_step": 6500000, "rew": 1354.3300170898438, "rew_std": 76.23653593249794, "Agent": "rainbow"}, {"env_step": 6600000, "rew": 1337.8300048828125, "rew_std": 111.20302612183444, "Agent": "rainbow"}, {"env_step": 6700000, "rew": 1287.5800048828125, "rew_std": 156.0572077287139, "Agent": "rainbow"}, {"env_step": 6800000, "rew": 1319.3700073242187, "rew_std": 129.23841474784112, "Agent": "rainbow"}, {"env_step": 6900000, "rew": 1279.7999877929688, "rew_std": 117.75878546918071, "Agent": "rainbow"}, 
{"env_step": 7000000, "rew": 1328.610009765625, "rew_std": 100.9171629081728, "Agent": "rainbow"}, {"env_step": 7100000, "rew": 1364.7, "rew_std": 163.70892187079815, "Agent": "rainbow"}, {"env_step": 7200000, "rew": 1308.8900024414063, "rew_std": 88.94055366414823, "Agent": "rainbow"}, {"env_step": 7300000, "rew": 1322.25, "rew_std": 94.6752858690983, "Agent": "rainbow"}, {"env_step": 7400000, "rew": 1309.5300170898438, "rew_std": 130.62605278751548, "Agent": "rainbow"}, {"env_step": 7500000, "rew": 1346.460009765625, "rew_std": 117.62017984362635, "Agent": "rainbow"}, {"env_step": 7600000, "rew": 1307.6800170898437, "rew_std": 135.51715895844544, "Agent": "rainbow"}, {"env_step": 7700000, "rew": 1370.15, "rew_std": 121.40495533214569, "Agent": "rainbow"}, {"env_step": 7800000, "rew": 1366.02001953125, "rew_std": 155.8914434634503, "Agent": "rainbow"}, {"env_step": 7900000, "rew": 1383.0500122070312, "rew_std": 119.51592283983616, "Agent": "rainbow"}, {"env_step": 8000000, "rew": 1347.489990234375, "rew_std": 107.93389821410152, "Agent": "rainbow"}, {"env_step": 8100000, "rew": 1382.2799926757812, "rew_std": 81.35011185204777, "Agent": "rainbow"}, {"env_step": 8200000, "rew": 1357.7900024414062, "rew_std": 102.5295829744403, "Agent": "rainbow"}, {"env_step": 8300000, "rew": 1308.2999877929688, "rew_std": 108.36144846385606, "Agent": "rainbow"}, {"env_step": 8400000, "rew": 1368.2300048828124, "rew_std": 109.5325932250981, "Agent": "rainbow"}, {"env_step": 8500000, "rew": 1311.7599975585938, "rew_std": 113.36432112272027, "Agent": "rainbow"}, {"env_step": 8600000, "rew": 1384.7000122070312, "rew_std": 129.2364554734506, "Agent": "rainbow"}, {"env_step": 8700000, "rew": 1377.9300170898437, "rew_std": 130.36785628181076, "Agent": "rainbow"}, {"env_step": 8800000, "rew": 1420.160009765625, "rew_std": 126.74235879373302, "Agent": "rainbow"}, {"env_step": 8900000, "rew": 1345.9299926757812, "rew_std": 110.19251868290497, "Agent": "rainbow"}, {"env_step": 9000000, "rew": 1361.6299926757813, "rew_std": 146.04748540296958, "Agent": "rainbow"}, {"env_step": 9100000, "rew": 1334.089990234375, "rew_std": 85.54697028191892, "Agent": "rainbow"}, {"env_step": 9200000, "rew": 1292.2299926757812, "rew_std": 143.56093407787503, "Agent": "rainbow"}, {"env_step": 9300000, "rew": 1363.2300048828124, "rew_std": 162.6994478889228, "Agent": "rainbow"}, {"env_step": 9400000, "rew": 1438.1799926757812, "rew_std": 130.79961009153894, "Agent": "rainbow"}, {"env_step": 9500000, "rew": 1496.1199951171875, "rew_std": 112.32410367854669, "Agent": "rainbow"}, {"env_step": 9600000, "rew": 1472.02001953125, "rew_std": 126.8561263598282, "Agent": "rainbow"}, {"env_step": 9700000, "rew": 1391.2999877929688, "rew_std": 85.84510612212074, "Agent": "rainbow"}, {"env_step": 9800000, "rew": 1311.4199951171875, "rew_std": 129.5012392428379, "Agent": "rainbow"}, {"env_step": 9900000, "rew": 1416.0599975585938, "rew_std": 92.06283588597819, "Agent": "rainbow"}, {"env_step": 10000000, "rew": 1416.15, "rew_std": 73.5937659368437, "Agent": "rainbow"}, {"env_step": 0, "rew": 0.010000000149011612, "rew_std": 0.03000000044703483, "Agent": "ppo"}, {"env_step": 100000, "rew": 5.800000095367432, "rew_std": 9.745665904076512, "Agent": "ppo"}, {"env_step": 200000, "rew": 17.870000410079957, "rew_std": 24.34313324422257, "Agent": "ppo"}, {"env_step": 300000, "rew": 33.790000438690186, "rew_std": 34.83846926788018, "Agent": "ppo"}, {"env_step": 400000, "rew": 49.810000157356264, "rew_std": 45.02722450198439, "Agent": "ppo"}, {"env_step": 500000, 
"rew": 63.84000015258789, "rew_std": 55.360387449231766, "Agent": "ppo"}, {"env_step": 600000, "rew": 70.23000016212464, "rew_std": 58.73603829366222, "Agent": "ppo"}, {"env_step": 700000, "rew": 75.51999950408936, "rew_std": 67.29662320233204, "Agent": "ppo"}, {"env_step": 800000, "rew": 81.71000061035156, "rew_std": 56.81521740785272, "Agent": "ppo"}, {"env_step": 900000, "rew": 113.76000213623047, "rew_std": 79.55546964125922, "Agent": "ppo"}, {"env_step": 1000000, "rew": 116.16000061035156, "rew_std": 84.3558915187514, "Agent": "ppo"}, {"env_step": 1100000, "rew": 122.9199995458126, "rew_std": 82.33155847187241, "Agent": "ppo"}, {"env_step": 1200000, "rew": 150.4199990928173, "rew_std": 104.54179755010438, "Agent": "ppo"}, {"env_step": 1300000, "rew": 168.3199987411499, "rew_std": 108.28887045300536, "Agent": "ppo"}, {"env_step": 1400000, "rew": 176.67999801635742, "rew_std": 94.50735937346175, "Agent": "ppo"}, {"env_step": 1500000, "rew": 210.8900005340576, "rew_std": 108.90844271319088, "Agent": "ppo"}, {"env_step": 1600000, "rew": 211.0199996948242, "rew_std": 101.37994700391509, "Agent": "ppo"}, {"env_step": 1700000, "rew": 214.6699966430664, "rew_std": 100.4095196324463, "Agent": "ppo"}, {"env_step": 1800000, "rew": 247.6599998474121, "rew_std": 104.51677601247573, "Agent": "ppo"}, {"env_step": 1900000, "rew": 279.729997253418, "rew_std": 113.7720723989544, "Agent": "ppo"}, {"env_step": 2000000, "rew": 280.9099998474121, "rew_std": 106.76945620737145, "Agent": "ppo"}, {"env_step": 2100000, "rew": 288.3299987792969, "rew_std": 102.17095623218948, "Agent": "ppo"}, {"env_step": 2200000, "rew": 271.02000427246094, "rew_std": 118.16367469140917, "Agent": "ppo"}, {"env_step": 2300000, "rew": 269.90000305175784, "rew_std": 85.66159506214764, "Agent": "ppo"}, {"env_step": 2400000, "rew": 296.75999755859374, "rew_std": 95.57390780369239, "Agent": "ppo"}, {"env_step": 2500000, "rew": 300.6899978637695, "rew_std": 87.79065128890048, "Agent": "ppo"}, {"env_step": 2600000, "rew": 320.3000015258789, "rew_std": 91.139115860422, "Agent": "ppo"}, {"env_step": 2700000, "rew": 333.9300018310547, "rew_std": 89.31125875672396, "Agent": "ppo"}, {"env_step": 2800000, "rew": 327.020002746582, "rew_std": 120.26429493395794, "Agent": "ppo"}, {"env_step": 2900000, "rew": 361.1499954223633, "rew_std": 109.55767628353465, "Agent": "ppo"}, {"env_step": 3000000, "rew": 302.6900039672852, "rew_std": 98.16769631029796, "Agent": "ppo"}, {"env_step": 3100000, "rew": 315.95, "rew_std": 70.34762657993309, "Agent": "ppo"}, {"env_step": 3200000, "rew": 318.9499984741211, "rew_std": 115.132872934207, "Agent": "ppo"}, {"env_step": 3300000, "rew": 363.1899978637695, "rew_std": 85.86564602287181, "Agent": "ppo"}, {"env_step": 3400000, "rew": 368.1300018310547, "rew_std": 95.21212558703517, "Agent": "ppo"}, {"env_step": 3500000, "rew": 350.01000518798827, "rew_std": 93.16288451590842, "Agent": "ppo"}, {"env_step": 3600000, "rew": 388.5899993896484, "rew_std": 123.07756326920718, "Agent": "ppo"}, {"env_step": 3700000, "rew": 417.2999969482422, "rew_std": 100.34292919570994, "Agent": "ppo"}, {"env_step": 3800000, "rew": 461.9800048828125, "rew_std": 110.07789422926527, "Agent": "ppo"}, {"env_step": 3900000, "rew": 426.2000030517578, "rew_std": 88.50425199374192, "Agent": "ppo"}, {"env_step": 4000000, "rew": 449.4999969482422, "rew_std": 82.22106733939164, "Agent": "ppo"}, {"env_step": 4100000, "rew": 459.7000030517578, "rew_std": 107.17836845612624, "Agent": "ppo"}, {"env_step": 4200000, "rew": 465.42999572753905, "rew_std": 
70.88704020848354, "Agent": "ppo"}, {"env_step": 4300000, "rew": 477.6600067138672, "rew_std": 132.27509463627436, "Agent": "ppo"}, {"env_step": 4400000, "rew": 410.0800048828125, "rew_std": 95.43147426617367, "Agent": "ppo"}, {"env_step": 4500000, "rew": 447.7100006103516, "rew_std": 70.30009378544847, "Agent": "ppo"}, {"env_step": 4600000, "rew": 462.73999786376953, "rew_std": 135.12483044104556, "Agent": "ppo"}, {"env_step": 4700000, "rew": 506.6499908447266, "rew_std": 102.90543516850502, "Agent": "ppo"}, {"env_step": 4800000, "rew": 504.0899993896484, "rew_std": 128.18980369342842, "Agent": "ppo"}, {"env_step": 4900000, "rew": 534.0000030517579, "rew_std": 116.97765592191337, "Agent": "ppo"}, {"env_step": 5000000, "rew": 513.1700057983398, "rew_std": 134.31567368090884, "Agent": "ppo"}, {"env_step": 5100000, "rew": 599.9900024414062, "rew_std": 114.45719219959994, "Agent": "ppo"}, {"env_step": 5200000, "rew": 602.4800048828125, "rew_std": 115.42681306636841, "Agent": "ppo"}, {"env_step": 5300000, "rew": 560.45, "rew_std": 147.17203595366996, "Agent": "ppo"}, {"env_step": 5400000, "rew": 542.2500030517579, "rew_std": 112.07122589244243, "Agent": "ppo"}, {"env_step": 5500000, "rew": 658.7099975585937, "rew_std": 119.60214061911367, "Agent": "ppo"}, {"env_step": 5600000, "rew": 624.2800048828125, "rew_std": 68.26668122502235, "Agent": "ppo"}, {"env_step": 5700000, "rew": 587.1499908447265, "rew_std": 79.0157166299757, "Agent": "ppo"}, {"env_step": 5800000, "rew": 649.9000061035156, "rew_std": 150.5787790102997, "Agent": "ppo"}, {"env_step": 5900000, "rew": 665.9100036621094, "rew_std": 119.23486176362321, "Agent": "ppo"}, {"env_step": 6000000, "rew": 706.7900024414063, "rew_std": 129.37492925213294, "Agent": "ppo"}, {"env_step": 6100000, "rew": 643.3500030517578, "rew_std": 116.35478540380245, "Agent": "ppo"}, {"env_step": 6200000, "rew": 721.2799926757813, "rew_std": 105.70935856689015, "Agent": "ppo"}, {"env_step": 6300000, "rew": 650.6100036621094, "rew_std": 148.55013326415474, "Agent": "ppo"}, {"env_step": 6400000, "rew": 768.1000030517578, "rew_std": 139.68046925775155, "Agent": "ppo"}, {"env_step": 6500000, "rew": 764.2999877929688, "rew_std": 118.32171828460929, "Agent": "ppo"}, {"env_step": 6600000, "rew": 782.5200073242188, "rew_std": 120.42663578622158, "Agent": "ppo"}, {"env_step": 6700000, "rew": 727.1499938964844, "rew_std": 130.45287574227336, "Agent": "ppo"}, {"env_step": 6800000, "rew": 783.6400024414063, "rew_std": 90.19995201850675, "Agent": "ppo"}, {"env_step": 6900000, "rew": 819.2, "rew_std": 132.53455239250277, "Agent": "ppo"}, {"env_step": 7000000, "rew": 794.2299987792969, "rew_std": 115.94847255070289, "Agent": "ppo"}, {"env_step": 7100000, "rew": 844.1199890136719, "rew_std": 122.43119457526758, "Agent": "ppo"}, {"env_step": 7200000, "rew": 889.6299987792969, "rew_std": 118.17887362331, "Agent": "ppo"}, {"env_step": 7300000, "rew": 861.8200012207031, "rew_std": 80.9622334545407, "Agent": "ppo"}, {"env_step": 7400000, "rew": 857.1600036621094, "rew_std": 95.34967659902838, "Agent": "ppo"}, {"env_step": 7500000, "rew": 892.6200073242187, "rew_std": 123.67628608131649, "Agent": "ppo"}, {"env_step": 7600000, "rew": 841.9999877929688, "rew_std": 86.62423578644787, "Agent": "ppo"}, {"env_step": 7700000, "rew": 890.85, "rew_std": 97.07490413687609, "Agent": "ppo"}, {"env_step": 7800000, "rew": 909.2900024414063, "rew_std": 86.85939542942558, "Agent": "ppo"}, {"env_step": 7900000, "rew": 906.0800048828125, "rew_std": 88.32103888799158, "Agent": "ppo"}, {"env_step": 
8000000, "rew": 931.8700073242187, "rew_std": 80.47352350435204, "Agent": "ppo"}, {"env_step": 8100000, "rew": 938.1000061035156, "rew_std": 122.04440568291714, "Agent": "ppo"}, {"env_step": 8200000, "rew": 945.4200073242188, "rew_std": 66.3337244036705, "Agent": "ppo"}, {"env_step": 8300000, "rew": 971.860009765625, "rew_std": 91.97996093457455, "Agent": "ppo"}, {"env_step": 8400000, "rew": 1005.7700073242188, "rew_std": 129.41220377608226, "Agent": "ppo"}, {"env_step": 8500000, "rew": 972.6299926757813, "rew_std": 93.78562067872805, "Agent": "ppo"}, {"env_step": 8600000, "rew": 1000.3700012207031, "rew_std": 41.05399464776146, "Agent": "ppo"}, {"env_step": 8700000, "rew": 992.8999938964844, "rew_std": 132.0640325185449, "Agent": "ppo"}, {"env_step": 8800000, "rew": 1002.5400024414063, "rew_std": 93.22239772526292, "Agent": "ppo"}, {"env_step": 8900000, "rew": 979.8400085449218, "rew_std": 81.5561815195113, "Agent": "ppo"}, {"env_step": 9000000, "rew": 999.1100036621094, "rew_std": 134.88361081259976, "Agent": "ppo"}, {"env_step": 9100000, "rew": 1019.3699890136719, "rew_std": 110.85209895729604, "Agent": "ppo"}, {"env_step": 9200000, "rew": 1059.209979248047, "rew_std": 97.60941217671626, "Agent": "ppo"}, {"env_step": 9300000, "rew": 1088.9700012207031, "rew_std": 108.06601150231093, "Agent": "ppo"}, {"env_step": 9400000, "rew": 1053.210009765625, "rew_std": 92.88492116582314, "Agent": "ppo"}, {"env_step": 9500000, "rew": 1086.1199890136718, "rew_std": 105.10582835713497, "Agent": "ppo"}, {"env_step": 9600000, "rew": 1088.6500183105468, "rew_std": 111.11191848457115, "Agent": "ppo"}, {"env_step": 9700000, "rew": 1098.8700073242187, "rew_std": 110.51410372623518, "Agent": "ppo"}, {"env_step": 9800000, "rew": 1096.7400024414062, "rew_std": 105.4631692708768, "Agent": "ppo"}, {"env_step": 9900000, "rew": 1075.4200134277344, "rew_std": 89.31879559577925, "Agent": "ppo"}, {"env_step": 10000000, "rew": 1081.520001220703, "rew_std": 84.75678447189097, "Agent": "ppo"}] \ No newline at end of file diff --git a/examples/atari/benchmark/MsPacmanNoFrameskip-v4/result.json b/examples/atari/benchmark/MsPacmanNoFrameskip-v4/result.json new file mode 100644 index 0000000000000000000000000000000000000000..924d28d063006de8741bfefce0689566c8753950 --- /dev/null +++ b/examples/atari/benchmark/MsPacmanNoFrameskip-v4/result.json @@ -0,0 +1 @@ +[{"env_step": 0, "rew": 232.0, "rew_std": 98.97373388935065, "Agent": "c51"}, {"env_step": 100000, "rew": 471.9, "rew_std": 188.48206811259263, "Agent": "c51"}, {"env_step": 200000, "rew": 674.3, "rew_std": 146.9837065800152, "Agent": "c51"}, {"env_step": 300000, "rew": 971.9, "rew_std": 232.55900326583787, "Agent": "c51"}, {"env_step": 400000, "rew": 1213.2, "rew_std": 339.9396417012879, "Agent": "c51"}, {"env_step": 500000, "rew": 1105.6, "rew_std": 205.68480741172888, "Agent": "c51"}, {"env_step": 600000, "rew": 1321.9, "rew_std": 232.04717192846803, "Agent": "c51"}, {"env_step": 700000, "rew": 1380.7, "rew_std": 526.1245194818429, "Agent": "c51"}, {"env_step": 800000, "rew": 1383.7, "rew_std": 241.8123445980374, "Agent": "c51"}, {"env_step": 900000, "rew": 1527.4, "rew_std": 273.0213178489914, "Agent": "c51"}, {"env_step": 1000000, "rew": 1433.2, "rew_std": 211.11788176277253, "Agent": "c51"}, {"env_step": 1100000, "rew": 1540.9, "rew_std": 213.74632160577642, "Agent": "c51"}, {"env_step": 1200000, "rew": 1537.3, "rew_std": 244.24375119949335, "Agent": "c51"}, {"env_step": 1300000, "rew": 1540.3, "rew_std": 281.00250888559697, "Agent": "c51"}, {"env_step": 1400000, 
"rew": 1688.0, "rew_std": 280.55480747975076, "Agent": "c51"}, {"env_step": 1500000, "rew": 1538.1, "rew_std": 241.28466590316094, "Agent": "c51"}, {"env_step": 1600000, "rew": 1563.0, "rew_std": 274.25572008619986, "Agent": "c51"}, {"env_step": 1700000, "rew": 1590.1, "rew_std": 294.4925975300568, "Agent": "c51"}, {"env_step": 1800000, "rew": 1855.9, "rew_std": 304.23755520974066, "Agent": "c51"}, {"env_step": 1900000, "rew": 1742.9, "rew_std": 189.30052826128087, "Agent": "c51"}, {"env_step": 2000000, "rew": 1842.5, "rew_std": 387.4669663339057, "Agent": "c51"}, {"env_step": 2100000, "rew": 1696.0, "rew_std": 306.7631007797385, "Agent": "c51"}, {"env_step": 2200000, "rew": 1946.0, "rew_std": 377.3793847045702, "Agent": "c51"}, {"env_step": 2300000, "rew": 1633.2, "rew_std": 193.38914137045026, "Agent": "c51"}, {"env_step": 2400000, "rew": 1732.7, "rew_std": 415.1855127530343, "Agent": "c51"}, {"env_step": 2500000, "rew": 2071.9, "rew_std": 396.92983007075696, "Agent": "c51"}, {"env_step": 2600000, "rew": 1844.4, "rew_std": 247.90207744188027, "Agent": "c51"}, {"env_step": 2700000, "rew": 1785.0, "rew_std": 315.5059428917306, "Agent": "c51"}, {"env_step": 2800000, "rew": 2009.1, "rew_std": 355.064627920045, "Agent": "c51"}, {"env_step": 2900000, "rew": 1977.9, "rew_std": 321.0471772185515, "Agent": "c51"}, {"env_step": 3000000, "rew": 1903.5, "rew_std": 249.7784017884653, "Agent": "c51"}, {"env_step": 3100000, "rew": 1831.5, "rew_std": 293.79082695005985, "Agent": "c51"}, {"env_step": 3200000, "rew": 2088.6, "rew_std": 283.78238141223636, "Agent": "c51"}, {"env_step": 3300000, "rew": 2027.6, "rew_std": 295.14003455986784, "Agent": "c51"}, {"env_step": 3400000, "rew": 2003.4, "rew_std": 174.97668416106185, "Agent": "c51"}, {"env_step": 3500000, "rew": 2107.5, "rew_std": 332.1888769962053, "Agent": "c51"}, {"env_step": 3600000, "rew": 1979.6, "rew_std": 317.34372532003846, "Agent": "c51"}, {"env_step": 3700000, "rew": 1993.1, "rew_std": 282.6320753205481, "Agent": "c51"}, {"env_step": 3800000, "rew": 1860.6, "rew_std": 249.90246097227612, "Agent": "c51"}, {"env_step": 3900000, "rew": 2034.8, "rew_std": 301.3057583253264, "Agent": "c51"}, {"env_step": 4000000, "rew": 2045.1, "rew_std": 451.0275933909144, "Agent": "c51"}, {"env_step": 4100000, "rew": 2082.3, "rew_std": 434.06291018699125, "Agent": "c51"}, {"env_step": 4200000, "rew": 2143.4, "rew_std": 325.0452891521426, "Agent": "c51"}, {"env_step": 4300000, "rew": 2178.5, "rew_std": 244.46441458829955, "Agent": "c51"}, {"env_step": 4400000, "rew": 2125.9, "rew_std": 343.1540324693854, "Agent": "c51"}, {"env_step": 4500000, "rew": 1984.6, "rew_std": 215.0196270111173, "Agent": "c51"}, {"env_step": 4600000, "rew": 2114.5, "rew_std": 265.95535339601645, "Agent": "c51"}, {"env_step": 4700000, "rew": 1941.8, "rew_std": 327.6485312037886, "Agent": "c51"}, {"env_step": 4800000, "rew": 2072.2, "rew_std": 302.6386624342633, "Agent": "c51"}, {"env_step": 4900000, "rew": 1997.1, "rew_std": 483.2809638295306, "Agent": "c51"}, {"env_step": 5000000, "rew": 2116.6, "rew_std": 315.6698275096941, "Agent": "c51"}, {"env_step": 5100000, "rew": 2136.5, "rew_std": 296.60318609212544, "Agent": "c51"}, {"env_step": 5200000, "rew": 1947.6, "rew_std": 303.33222710420995, "Agent": "c51"}, {"env_step": 5300000, "rew": 2043.2, "rew_std": 348.6327007037636, "Agent": "c51"}, {"env_step": 5400000, "rew": 2003.5, "rew_std": 181.4779600943321, "Agent": "c51"}, {"env_step": 5500000, "rew": 2042.7, "rew_std": 349.1080205323275, "Agent": "c51"}, {"env_step": 5600000, "rew": 
2124.6, "rew_std": 246.09843559031415, "Agent": "c51"}, {"env_step": 5700000, "rew": 1958.8, "rew_std": 204.2428946132521, "Agent": "c51"}, {"env_step": 5800000, "rew": 2104.0, "rew_std": 256.8260111437313, "Agent": "c51"}, {"env_step": 5900000, "rew": 2023.0, "rew_std": 217.54447821077878, "Agent": "c51"}, {"env_step": 6000000, "rew": 1912.1, "rew_std": 324.066490091154, "Agent": "c51"}, {"env_step": 6100000, "rew": 2112.9, "rew_std": 196.80114328936202, "Agent": "c51"}, {"env_step": 6200000, "rew": 2013.2, "rew_std": 245.47211654279596, "Agent": "c51"}, {"env_step": 6300000, "rew": 2207.9, "rew_std": 279.84583255785674, "Agent": "c51"}, {"env_step": 6400000, "rew": 2062.1, "rew_std": 293.391700632448, "Agent": "c51"}, {"env_step": 6500000, "rew": 2121.7, "rew_std": 276.0326248833641, "Agent": "c51"}, {"env_step": 6600000, "rew": 2086.0, "rew_std": 265.9657872734762, "Agent": "c51"}, {"env_step": 6700000, "rew": 1945.4, "rew_std": 268.4459722178748, "Agent": "c51"}, {"env_step": 6800000, "rew": 2025.5, "rew_std": 188.10648579993196, "Agent": "c51"}, {"env_step": 6900000, "rew": 2222.5, "rew_std": 192.25828980826807, "Agent": "c51"}, {"env_step": 7000000, "rew": 1962.0, "rew_std": 251.00756960697422, "Agent": "c51"}, {"env_step": 7100000, "rew": 2028.0, "rew_std": 182.15323219751002, "Agent": "c51"}, {"env_step": 7200000, "rew": 2155.4, "rew_std": 353.8415464582982, "Agent": "c51"}, {"env_step": 7300000, "rew": 2094.3, "rew_std": 366.438821633298, "Agent": "c51"}, {"env_step": 7400000, "rew": 2234.5, "rew_std": 312.953111503944, "Agent": "c51"}, {"env_step": 7500000, "rew": 2193.4, "rew_std": 310.7629965102023, "Agent": "c51"}, {"env_step": 7600000, "rew": 2107.7, "rew_std": 326.7916920608601, "Agent": "c51"}, {"env_step": 7700000, "rew": 2056.6, "rew_std": 275.43790588806036, "Agent": "c51"}, {"env_step": 7800000, "rew": 2155.4, "rew_std": 317.36546756066576, "Agent": "c51"}, {"env_step": 7900000, "rew": 2004.5, "rew_std": 340.1814956754703, "Agent": "c51"}, {"env_step": 8000000, "rew": 2161.7, "rew_std": 239.97793648583612, "Agent": "c51"}, {"env_step": 8100000, "rew": 1823.4, "rew_std": 288.0486764420208, "Agent": "c51"}, {"env_step": 8200000, "rew": 2090.5, "rew_std": 331.0233375458594, "Agent": "c51"}, {"env_step": 8300000, "rew": 1968.3, "rew_std": 320.45063582399086, "Agent": "c51"}, {"env_step": 8400000, "rew": 2106.9, "rew_std": 403.024428540008, "Agent": "c51"}, {"env_step": 8500000, "rew": 2226.4, "rew_std": 361.9058441086577, "Agent": "c51"}, {"env_step": 8600000, "rew": 1958.6, "rew_std": 128.12353413795609, "Agent": "c51"}, {"env_step": 8700000, "rew": 1775.2, "rew_std": 258.49982591870344, "Agent": "c51"}, {"env_step": 8800000, "rew": 2074.7, "rew_std": 342.712138682014, "Agent": "c51"}, {"env_step": 8900000, "rew": 1892.9, "rew_std": 167.9615729862042, "Agent": "c51"}, {"env_step": 9000000, "rew": 2107.6, "rew_std": 225.7517220310844, "Agent": "c51"}, {"env_step": 9100000, "rew": 1941.5, "rew_std": 392.5196173441526, "Agent": "c51"}, {"env_step": 9200000, "rew": 2188.9, "rew_std": 298.97974847805324, "Agent": "c51"}, {"env_step": 9300000, "rew": 2027.5, "rew_std": 442.07267502074814, "Agent": "c51"}, {"env_step": 9400000, "rew": 2137.9, "rew_std": 365.06340545171054, "Agent": "c51"}, {"env_step": 9500000, "rew": 2254.9, "rew_std": 201.1842190630269, "Agent": "c51"}, {"env_step": 9600000, "rew": 1965.5, "rew_std": 197.30040547348096, "Agent": "c51"}, {"env_step": 9700000, "rew": 2237.9, "rew_std": 387.8250765486935, "Agent": "c51"}, {"env_step": 9800000, "rew": 2177.9, 
"rew_std": 237.57838706414353, "Agent": "c51"}, {"env_step": 9900000, "rew": 2144.3, "rew_std": 298.6901571863392, "Agent": "c51"}, {"env_step": 10000000, "rew": 2196.7, "rew_std": 370.68991084193266, "Agent": "c51"}, {"env_step": 0, "rew": 131.5, "rew_std": 68.65020029104068, "Agent": "dqn"}, {"env_step": 100000, "rew": 614.7, "rew_std": 204.22930739734684, "Agent": "dqn"}, {"env_step": 200000, "rew": 701.0, "rew_std": 207.84417239845817, "Agent": "dqn"}, {"env_step": 300000, "rew": 810.8, "rew_std": 271.50977882941896, "Agent": "dqn"}, {"env_step": 400000, "rew": 834.8, "rew_std": 215.7822050123689, "Agent": "dqn"}, {"env_step": 500000, "rew": 909.7, "rew_std": 287.6004346311041, "Agent": "dqn"}, {"env_step": 600000, "rew": 1064.1, "rew_std": 212.43608450543425, "Agent": "dqn"}, {"env_step": 700000, "rew": 1294.5, "rew_std": 225.1289630411867, "Agent": "dqn"}, {"env_step": 800000, "rew": 1285.4, "rew_std": 259.38820327840665, "Agent": "dqn"}, {"env_step": 900000, "rew": 1304.9, "rew_std": 298.1504485993607, "Agent": "dqn"}, {"env_step": 1000000, "rew": 1251.2, "rew_std": 232.5570037646684, "Agent": "dqn"}, {"env_step": 1100000, "rew": 1391.9, "rew_std": 193.24205028926804, "Agent": "dqn"}, {"env_step": 1200000, "rew": 1458.9, "rew_std": 220.16468835851038, "Agent": "dqn"}, {"env_step": 1300000, "rew": 1373.2, "rew_std": 253.95267275616533, "Agent": "dqn"}, {"env_step": 1400000, "rew": 1561.5, "rew_std": 382.30334814123717, "Agent": "dqn"}, {"env_step": 1500000, "rew": 1512.3, "rew_std": 150.26180486071635, "Agent": "dqn"}, {"env_step": 1600000, "rew": 1592.5, "rew_std": 247.11060276726292, "Agent": "dqn"}, {"env_step": 1700000, "rew": 1433.2, "rew_std": 348.8689725384016, "Agent": "dqn"}, {"env_step": 1800000, "rew": 1603.8, "rew_std": 354.88640436060666, "Agent": "dqn"}, {"env_step": 1900000, "rew": 1709.6, "rew_std": 354.1762837909958, "Agent": "dqn"}, {"env_step": 2000000, "rew": 1364.8, "rew_std": 449.32745297833736, "Agent": "dqn"}, {"env_step": 2100000, "rew": 1460.0, "rew_std": 327.8566760034024, "Agent": "dqn"}, {"env_step": 2200000, "rew": 1597.0, "rew_std": 265.5627985995026, "Agent": "dqn"}, {"env_step": 2300000, "rew": 1751.1, "rew_std": 489.0773865146496, "Agent": "dqn"}, {"env_step": 2400000, "rew": 1606.6, "rew_std": 194.81129330713864, "Agent": "dqn"}, {"env_step": 2500000, "rew": 1659.0, "rew_std": 233.647169895122, "Agent": "dqn"}, {"env_step": 2600000, "rew": 1689.7, "rew_std": 161.25510844621326, "Agent": "dqn"}, {"env_step": 2700000, "rew": 1622.3, "rew_std": 451.706552974384, "Agent": "dqn"}, {"env_step": 2800000, "rew": 1908.7, "rew_std": 397.6817949064302, "Agent": "dqn"}, {"env_step": 2900000, "rew": 1810.7, "rew_std": 342.4213924391991, "Agent": "dqn"}, {"env_step": 3000000, "rew": 1618.5, "rew_std": 344.36412414768176, "Agent": "dqn"}, {"env_step": 3100000, "rew": 1750.4, "rew_std": 284.15425388334415, "Agent": "dqn"}, {"env_step": 3200000, "rew": 1895.3, "rew_std": 299.36167089325244, "Agent": "dqn"}, {"env_step": 3300000, "rew": 1750.6, "rew_std": 258.5123594724244, "Agent": "dqn"}, {"env_step": 3400000, "rew": 1768.8, "rew_std": 519.2761885547998, "Agent": "dqn"}, {"env_step": 3500000, "rew": 1923.8, "rew_std": 338.9373983496067, "Agent": "dqn"}, {"env_step": 3600000, "rew": 1848.1, "rew_std": 409.1123195407344, "Agent": "dqn"}, {"env_step": 3700000, "rew": 1954.5, "rew_std": 319.1367261848752, "Agent": "dqn"}, {"env_step": 3800000, "rew": 1761.1, "rew_std": 322.0672134819066, "Agent": "dqn"}, {"env_step": 3900000, "rew": 1843.7, "rew_std": 
310.8404896405872, "Agent": "dqn"}, {"env_step": 4000000, "rew": 2049.1, "rew_std": 425.3754694384715, "Agent": "dqn"}, {"env_step": 4100000, "rew": 1596.5, "rew_std": 433.0591760949074, "Agent": "dqn"}, {"env_step": 4200000, "rew": 1982.2, "rew_std": 302.4684446351388, "Agent": "dqn"}, {"env_step": 4300000, "rew": 1967.3, "rew_std": 453.393217858406, "Agent": "dqn"}, {"env_step": 4400000, "rew": 1863.3, "rew_std": 420.0681016216299, "Agent": "dqn"}, {"env_step": 4500000, "rew": 2167.2, "rew_std": 428.3346355362825, "Agent": "dqn"}, {"env_step": 4600000, "rew": 1978.7, "rew_std": 467.56562106296906, "Agent": "dqn"}, {"env_step": 4700000, "rew": 2055.6, "rew_std": 222.23060095315407, "Agent": "dqn"}, {"env_step": 4800000, "rew": 1964.6, "rew_std": 371.6714140205028, "Agent": "dqn"}, {"env_step": 4900000, "rew": 1900.1, "rew_std": 293.1492623221147, "Agent": "dqn"}, {"env_step": 5000000, "rew": 1980.9, "rew_std": 444.9308822727413, "Agent": "dqn"}, {"env_step": 5100000, "rew": 1902.8, "rew_std": 498.980921478968, "Agent": "dqn"}, {"env_step": 5200000, "rew": 2109.0, "rew_std": 436.7092854520041, "Agent": "dqn"}, {"env_step": 5300000, "rew": 1968.1, "rew_std": 310.17428971466995, "Agent": "dqn"}, {"env_step": 5400000, "rew": 1976.0, "rew_std": 482.8349614516331, "Agent": "dqn"}, {"env_step": 5500000, "rew": 1849.4, "rew_std": 398.51930944434804, "Agent": "dqn"}, {"env_step": 5600000, "rew": 1880.3, "rew_std": 403.730863323576, "Agent": "dqn"}, {"env_step": 5700000, "rew": 2198.5, "rew_std": 510.2962374934779, "Agent": "dqn"}, {"env_step": 5800000, "rew": 2004.4, "rew_std": 301.12960664803455, "Agent": "dqn"}, {"env_step": 5900000, "rew": 2048.1, "rew_std": 450.4685227626898, "Agent": "dqn"}, {"env_step": 6000000, "rew": 2285.8, "rew_std": 433.295003432996, "Agent": "dqn"}, {"env_step": 6100000, "rew": 2123.8, "rew_std": 344.2341644869085, "Agent": "dqn"}, {"env_step": 6200000, "rew": 2220.9, "rew_std": 532.900825670218, "Agent": "dqn"}, {"env_step": 6300000, "rew": 2083.9, "rew_std": 434.93619072227136, "Agent": "dqn"}, {"env_step": 6400000, "rew": 2324.8, "rew_std": 359.76959293414444, "Agent": "dqn"}, {"env_step": 6500000, "rew": 1959.3, "rew_std": 218.1110038489576, "Agent": "dqn"}, {"env_step": 6600000, "rew": 2100.8, "rew_std": 346.50073592995443, "Agent": "dqn"}, {"env_step": 6700000, "rew": 1936.2, "rew_std": 392.79404272468287, "Agent": "dqn"}, {"env_step": 6800000, "rew": 2225.4, "rew_std": 382.2230239009681, "Agent": "dqn"}, {"env_step": 6900000, "rew": 2039.2, "rew_std": 316.61895079101, "Agent": "dqn"}, {"env_step": 7000000, "rew": 2102.3, "rew_std": 456.89212074624356, "Agent": "dqn"}, {"env_step": 7100000, "rew": 2108.4, "rew_std": 328.34195589354704, "Agent": "dqn"}, {"env_step": 7200000, "rew": 1930.3, "rew_std": 574.398824859522, "Agent": "dqn"}, {"env_step": 7300000, "rew": 1915.3, "rew_std": 206.36184240309544, "Agent": "dqn"}, {"env_step": 7400000, "rew": 2049.0, "rew_std": 228.25950144517535, "Agent": "dqn"}, {"env_step": 7500000, "rew": 1851.8, "rew_std": 267.78043244419484, "Agent": "dqn"}, {"env_step": 7600000, "rew": 1897.6, "rew_std": 284.4177209668905, "Agent": "dqn"}, {"env_step": 7700000, "rew": 2028.1, "rew_std": 348.8350469777943, "Agent": "dqn"}, {"env_step": 7800000, "rew": 1792.0, "rew_std": 491.49120032814426, "Agent": "dqn"}, {"env_step": 7900000, "rew": 1943.1, "rew_std": 397.7744209976303, "Agent": "dqn"}, {"env_step": 8000000, "rew": 1958.2, "rew_std": 320.44587686534527, "Agent": "dqn"}, {"env_step": 8100000, "rew": 1928.8, "rew_std": 440.1206198305187, 
"Agent": "dqn"}, {"env_step": 8200000, "rew": 1939.0, "rew_std": 313.2452713130719, "Agent": "dqn"}, {"env_step": 8300000, "rew": 1952.0, "rew_std": 200.15993605114886, "Agent": "dqn"}, {"env_step": 8400000, "rew": 2045.5, "rew_std": 214.5685205243304, "Agent": "dqn"}, {"env_step": 8500000, "rew": 1957.4, "rew_std": 295.4894921989613, "Agent": "dqn"}, {"env_step": 8600000, "rew": 1757.6, "rew_std": 440.1148032048002, "Agent": "dqn"}, {"env_step": 8700000, "rew": 1913.0, "rew_std": 399.30990471061443, "Agent": "dqn"}, {"env_step": 8800000, "rew": 1883.1, "rew_std": 390.40759470071794, "Agent": "dqn"}, {"env_step": 8900000, "rew": 2099.2, "rew_std": 247.28234874329385, "Agent": "dqn"}, {"env_step": 9000000, "rew": 1986.4, "rew_std": 552.3334500100459, "Agent": "dqn"}, {"env_step": 9100000, "rew": 1933.1, "rew_std": 483.2095715111612, "Agent": "dqn"}, {"env_step": 9200000, "rew": 1978.8, "rew_std": 328.87955242003113, "Agent": "dqn"}, {"env_step": 9300000, "rew": 2072.5, "rew_std": 331.83045369585955, "Agent": "dqn"}, {"env_step": 9400000, "rew": 2054.7, "rew_std": 349.56088167871417, "Agent": "dqn"}, {"env_step": 9500000, "rew": 1718.4, "rew_std": 509.64207047691815, "Agent": "dqn"}, {"env_step": 9600000, "rew": 1724.2, "rew_std": 495.79507863632534, "Agent": "dqn"}, {"env_step": 9700000, "rew": 1770.2, "rew_std": 291.362591970898, "Agent": "dqn"}, {"env_step": 9800000, "rew": 1836.1, "rew_std": 500.204048364265, "Agent": "dqn"}, {"env_step": 9900000, "rew": 2061.6, "rew_std": 311.6386368857366, "Agent": "dqn"}, {"env_step": 10000000, "rew": 1889.7, "rew_std": 482.98779487684783, "Agent": "dqn"}, {"env_step": 0, "rew": 209.6, "rew_std": 147.36770338171112, "Agent": "fqf"}, {"env_step": 100000, "rew": 564.9, "rew_std": 161.8928349248354, "Agent": "fqf"}, {"env_step": 200000, "rew": 767.5, "rew_std": 204.61732575713134, "Agent": "fqf"}, {"env_step": 300000, "rew": 1075.9, "rew_std": 288.21535351191824, "Agent": "fqf"}, {"env_step": 400000, "rew": 1022.0, "rew_std": 265.35787156216037, "Agent": "fqf"}, {"env_step": 500000, "rew": 1071.1, "rew_std": 377.931329741264, "Agent": "fqf"}, {"env_step": 600000, "rew": 1151.7, "rew_std": 300.1049982922644, "Agent": "fqf"}, {"env_step": 700000, "rew": 1196.7, "rew_std": 103.24538730616491, "Agent": "fqf"}, {"env_step": 800000, "rew": 1514.1, "rew_std": 486.713149606624, "Agent": "fqf"}, {"env_step": 900000, "rew": 1626.9, "rew_std": 316.3809254680187, "Agent": "fqf"}, {"env_step": 1000000, "rew": 1582.2, "rew_std": 369.04926500400995, "Agent": "fqf"}, {"env_step": 1100000, "rew": 1562.1, "rew_std": 180.65572230073423, "Agent": "fqf"}, {"env_step": 1200000, "rew": 1479.4, "rew_std": 133.2525421896333, "Agent": "fqf"}, {"env_step": 1300000, "rew": 1791.1, "rew_std": 167.81623878516643, "Agent": "fqf"}, {"env_step": 1400000, "rew": 1891.5, "rew_std": 251.22788459882395, "Agent": "fqf"}, {"env_step": 1500000, "rew": 1736.2, "rew_std": 316.84216891064233, "Agent": "fqf"}, {"env_step": 1600000, "rew": 2012.1, "rew_std": 421.7516923498944, "Agent": "fqf"}, {"env_step": 1700000, "rew": 1790.0, "rew_std": 352.4060158396845, "Agent": "fqf"}, {"env_step": 1800000, "rew": 1927.9, "rew_std": 216.36147993577785, "Agent": "fqf"}, {"env_step": 1900000, "rew": 2056.4, "rew_std": 326.7314493586438, "Agent": "fqf"}, {"env_step": 2000000, "rew": 2039.8, "rew_std": 309.31789472967773, "Agent": "fqf"}, {"env_step": 2100000, "rew": 2133.9, "rew_std": 337.53442787366146, "Agent": "fqf"}, {"env_step": 2200000, "rew": 2087.1, "rew_std": 346.3381151418365, "Agent": "fqf"}, 
{"env_step": 2300000, "rew": 2096.6, "rew_std": 242.30361119884284, "Agent": "fqf"}, {"env_step": 2400000, "rew": 2233.8, "rew_std": 185.708804314712, "Agent": "fqf"}, {"env_step": 2500000, "rew": 1932.5, "rew_std": 284.03248053699775, "Agent": "fqf"}, {"env_step": 2600000, "rew": 1983.8, "rew_std": 354.9137359979182, "Agent": "fqf"}, {"env_step": 2700000, "rew": 2270.3, "rew_std": 424.8583410973592, "Agent": "fqf"}, {"env_step": 2800000, "rew": 2289.4, "rew_std": 383.8, "Agent": "fqf"}, {"env_step": 2900000, "rew": 2172.8, "rew_std": 362.5851072506978, "Agent": "fqf"}, {"env_step": 3000000, "rew": 2144.2, "rew_std": 308.21057736554076, "Agent": "fqf"}, {"env_step": 3100000, "rew": 2074.5, "rew_std": 176.32711079127907, "Agent": "fqf"}, {"env_step": 3200000, "rew": 2291.0, "rew_std": 323.1361941968123, "Agent": "fqf"}, {"env_step": 3300000, "rew": 2121.9, "rew_std": 152.47062012073016, "Agent": "fqf"}, {"env_step": 3400000, "rew": 2160.7, "rew_std": 297.12726229681454, "Agent": "fqf"}, {"env_step": 3500000, "rew": 1965.6, "rew_std": 434.0376020577019, "Agent": "fqf"}, {"env_step": 3600000, "rew": 2088.7, "rew_std": 263.4023728063208, "Agent": "fqf"}, {"env_step": 3700000, "rew": 2288.5, "rew_std": 494.5837138442793, "Agent": "fqf"}, {"env_step": 3800000, "rew": 2077.7, "rew_std": 488.92699864090133, "Agent": "fqf"}, {"env_step": 3900000, "rew": 2320.0, "rew_std": 402.8208038321755, "Agent": "fqf"}, {"env_step": 4000000, "rew": 2162.4, "rew_std": 318.08275652729117, "Agent": "fqf"}, {"env_step": 4100000, "rew": 2152.5, "rew_std": 415.441271421124, "Agent": "fqf"}, {"env_step": 4200000, "rew": 2323.0, "rew_std": 441.1031625368379, "Agent": "fqf"}, {"env_step": 4300000, "rew": 2254.9, "rew_std": 369.04347982317745, "Agent": "fqf"}, {"env_step": 4400000, "rew": 2266.5, "rew_std": 468.3191753494619, "Agent": "fqf"}, {"env_step": 4500000, "rew": 2189.8, "rew_std": 463.9204242108769, "Agent": "fqf"}, {"env_step": 4600000, "rew": 2164.3, "rew_std": 272.80764285481445, "Agent": "fqf"}, {"env_step": 4700000, "rew": 2167.5, "rew_std": 151.73611962878186, "Agent": "fqf"}, {"env_step": 4800000, "rew": 2223.9, "rew_std": 334.81470995163875, "Agent": "fqf"}, {"env_step": 4900000, "rew": 2216.2, "rew_std": 302.2647184174825, "Agent": "fqf"}, {"env_step": 5000000, "rew": 2307.7, "rew_std": 463.2478926017905, "Agent": "fqf"}, {"env_step": 5100000, "rew": 2229.6, "rew_std": 334.5005829591333, "Agent": "fqf"}, {"env_step": 5200000, "rew": 2097.1, "rew_std": 257.26540770185176, "Agent": "fqf"}, {"env_step": 5300000, "rew": 2246.4, "rew_std": 323.79691165914477, "Agent": "fqf"}, {"env_step": 5400000, "rew": 2412.3, "rew_std": 271.80399187649914, "Agent": "fqf"}, {"env_step": 5500000, "rew": 2289.4, "rew_std": 326.94929270454156, "Agent": "fqf"}, {"env_step": 5600000, "rew": 2213.5, "rew_std": 387.96320701839755, "Agent": "fqf"}, {"env_step": 5700000, "rew": 2212.0, "rew_std": 371.89353315162657, "Agent": "fqf"}, {"env_step": 5800000, "rew": 2119.3, "rew_std": 379.4562030063549, "Agent": "fqf"}, {"env_step": 5900000, "rew": 2086.9, "rew_std": 480.4985848054081, "Agent": "fqf"}, {"env_step": 6000000, "rew": 2082.4, "rew_std": 193.79483997258544, "Agent": "fqf"}, {"env_step": 6100000, "rew": 2167.7, "rew_std": 307.71611917480044, "Agent": "fqf"}, {"env_step": 6200000, "rew": 2257.0, "rew_std": 330.79963724284823, "Agent": "fqf"}, {"env_step": 6300000, "rew": 2369.5, "rew_std": 511.182990718588, "Agent": "fqf"}, {"env_step": 6400000, "rew": 2233.4, "rew_std": 333.7532621563421, "Agent": "fqf"}, {"env_step": 
6500000, "rew": 2179.4, "rew_std": 211.5184152739425, "Agent": "fqf"}, {"env_step": 6600000, "rew": 2365.5, "rew_std": 349.8263140474141, "Agent": "fqf"}, {"env_step": 6700000, "rew": 2255.2, "rew_std": 228.93309066187874, "Agent": "fqf"}, {"env_step": 6800000, "rew": 2354.2, "rew_std": 476.53138406614937, "Agent": "fqf"}, {"env_step": 6900000, "rew": 2295.9, "rew_std": 328.7827398146077, "Agent": "fqf"}, {"env_step": 7000000, "rew": 2294.0, "rew_std": 350.9675198647305, "Agent": "fqf"}, {"env_step": 7100000, "rew": 2279.9, "rew_std": 249.64392642321585, "Agent": "fqf"}, {"env_step": 7200000, "rew": 2247.8, "rew_std": 342.9891543474808, "Agent": "fqf"}, {"env_step": 7300000, "rew": 2430.6, "rew_std": 378.38821334708615, "Agent": "fqf"}, {"env_step": 7400000, "rew": 2458.8, "rew_std": 619.2716366829665, "Agent": "fqf"}, {"env_step": 7500000, "rew": 2126.3, "rew_std": 379.9394820231243, "Agent": "fqf"}, {"env_step": 7600000, "rew": 2294.9, "rew_std": 231.97821018362907, "Agent": "fqf"}, {"env_step": 7700000, "rew": 2384.8, "rew_std": 322.39317610644304, "Agent": "fqf"}, {"env_step": 7800000, "rew": 2327.4, "rew_std": 316.6234988120749, "Agent": "fqf"}, {"env_step": 7900000, "rew": 2369.9, "rew_std": 376.6129179940593, "Agent": "fqf"}, {"env_step": 8000000, "rew": 2459.9, "rew_std": 387.55888584833144, "Agent": "fqf"}, {"env_step": 8100000, "rew": 2432.6, "rew_std": 335.16748052279775, "Agent": "fqf"}, {"env_step": 8200000, "rew": 2405.6, "rew_std": 405.6466935647325, "Agent": "fqf"}, {"env_step": 8300000, "rew": 2475.6, "rew_std": 527.2349002105228, "Agent": "fqf"}, {"env_step": 8400000, "rew": 2371.0, "rew_std": 221.4050586594624, "Agent": "fqf"}, {"env_step": 8500000, "rew": 2384.4, "rew_std": 362.2496928915192, "Agent": "fqf"}, {"env_step": 8600000, "rew": 2201.5, "rew_std": 263.163162315701, "Agent": "fqf"}, {"env_step": 8700000, "rew": 2120.5, "rew_std": 442.54948875803706, "Agent": "fqf"}, {"env_step": 8800000, "rew": 2236.1, "rew_std": 213.28546598397182, "Agent": "fqf"}, {"env_step": 8900000, "rew": 2335.3, "rew_std": 308.52230065264325, "Agent": "fqf"}, {"env_step": 9000000, "rew": 2441.1, "rew_std": 316.18394962426544, "Agent": "fqf"}, {"env_step": 9100000, "rew": 2425.4, "rew_std": 412.7859493732799, "Agent": "fqf"}, {"env_step": 9200000, "rew": 2400.9, "rew_std": 393.80869721223786, "Agent": "fqf"}, {"env_step": 9300000, "rew": 2478.2, "rew_std": 338.00852060266175, "Agent": "fqf"}, {"env_step": 9400000, "rew": 2325.9, "rew_std": 250.8499352202428, "Agent": "fqf"}, {"env_step": 9500000, "rew": 2419.5, "rew_std": 283.8465254323188, "Agent": "fqf"}, {"env_step": 9600000, "rew": 2506.6, "rew_std": 402.5186206872919, "Agent": "fqf"}, {"env_step": 9700000, "rew": 2327.1, "rew_std": 375.1155688584519, "Agent": "fqf"}, {"env_step": 9800000, "rew": 2354.1, "rew_std": 263.84557983790444, "Agent": "fqf"}, {"env_step": 9900000, "rew": 2433.2, "rew_std": 550.9763697292291, "Agent": "fqf"}, {"env_step": 10000000, "rew": 2478.0, "rew_std": 184.68892765945662, "Agent": "fqf"}, {"env_step": 0, "rew": 147.9, "rew_std": 65.50030534279973, "Agent": "qrdqn"}, {"env_step": 100000, "rew": 658.8, "rew_std": 173.94355406280508, "Agent": "qrdqn"}, {"env_step": 200000, "rew": 901.9, "rew_std": 225.8364231030947, "Agent": "qrdqn"}, {"env_step": 300000, "rew": 947.9, "rew_std": 176.00366473457305, "Agent": "qrdqn"}, {"env_step": 400000, "rew": 984.4, "rew_std": 298.15405413980204, "Agent": "qrdqn"}, {"env_step": 500000, "rew": 1047.1, "rew_std": 262.9248752020242, "Agent": "qrdqn"}, {"env_step": 600000, 
"rew": 1181.9, "rew_std": 363.3566980255077, "Agent": "qrdqn"}, {"env_step": 700000, "rew": 1294.7, "rew_std": 404.93630363305283, "Agent": "qrdqn"}, {"env_step": 800000, "rew": 1202.4, "rew_std": 548.8827197134193, "Agent": "qrdqn"}, {"env_step": 900000, "rew": 1376.5, "rew_std": 180.5803145417573, "Agent": "qrdqn"}, {"env_step": 1000000, "rew": 1537.5, "rew_std": 396.10762426391136, "Agent": "qrdqn"}, {"env_step": 1100000, "rew": 1516.6, "rew_std": 246.3169502896624, "Agent": "qrdqn"}, {"env_step": 1200000, "rew": 1514.0, "rew_std": 284.75779181613274, "Agent": "qrdqn"}, {"env_step": 1300000, "rew": 1366.4, "rew_std": 237.443551186382, "Agent": "qrdqn"}, {"env_step": 1400000, "rew": 1548.5, "rew_std": 338.3723540716647, "Agent": "qrdqn"}, {"env_step": 1500000, "rew": 1512.0, "rew_std": 236.26468208346333, "Agent": "qrdqn"}, {"env_step": 1600000, "rew": 1628.0, "rew_std": 213.3972820820359, "Agent": "qrdqn"}, {"env_step": 1700000, "rew": 1573.0, "rew_std": 335.270338682085, "Agent": "qrdqn"}, {"env_step": 1800000, "rew": 1587.9, "rew_std": 292.99059711874713, "Agent": "qrdqn"}, {"env_step": 1900000, "rew": 1870.7, "rew_std": 401.1979685890745, "Agent": "qrdqn"}, {"env_step": 2000000, "rew": 1629.4, "rew_std": 364.31859683524254, "Agent": "qrdqn"}, {"env_step": 2100000, "rew": 1756.7, "rew_std": 314.8085291093619, "Agent": "qrdqn"}, {"env_step": 2200000, "rew": 1695.5, "rew_std": 151.9527887207076, "Agent": "qrdqn"}, {"env_step": 2300000, "rew": 1751.0, "rew_std": 158.11894257172352, "Agent": "qrdqn"}, {"env_step": 2400000, "rew": 1726.3, "rew_std": 512.4833753401177, "Agent": "qrdqn"}, {"env_step": 2500000, "rew": 1908.5, "rew_std": 349.27532120091166, "Agent": "qrdqn"}, {"env_step": 2600000, "rew": 1737.8, "rew_std": 288.164813952016, "Agent": "qrdqn"}, {"env_step": 2700000, "rew": 1928.7, "rew_std": 390.23686396853896, "Agent": "qrdqn"}, {"env_step": 2800000, "rew": 1859.6, "rew_std": 433.55949995358196, "Agent": "qrdqn"}, {"env_step": 2900000, "rew": 1784.5, "rew_std": 345.3847854205509, "Agent": "qrdqn"}, {"env_step": 3000000, "rew": 1677.1, "rew_std": 354.96688577950476, "Agent": "qrdqn"}, {"env_step": 3100000, "rew": 1964.8, "rew_std": 299.83221974964596, "Agent": "qrdqn"}, {"env_step": 3200000, "rew": 1798.1, "rew_std": 198.37159574898823, "Agent": "qrdqn"}, {"env_step": 3300000, "rew": 1783.1, "rew_std": 387.8730333498322, "Agent": "qrdqn"}, {"env_step": 3400000, "rew": 1856.4, "rew_std": 299.4779457656273, "Agent": "qrdqn"}, {"env_step": 3500000, "rew": 2008.8, "rew_std": 300.43894554468136, "Agent": "qrdqn"}, {"env_step": 3600000, "rew": 2021.3, "rew_std": 332.70619170673694, "Agent": "qrdqn"}, {"env_step": 3700000, "rew": 1957.0, "rew_std": 290.85013323015687, "Agent": "qrdqn"}, {"env_step": 3800000, "rew": 1850.0, "rew_std": 178.16116299575506, "Agent": "qrdqn"}, {"env_step": 3900000, "rew": 1878.4, "rew_std": 399.1170755555317, "Agent": "qrdqn"}, {"env_step": 4000000, "rew": 2028.5, "rew_std": 457.8126800340943, "Agent": "qrdqn"}, {"env_step": 4100000, "rew": 1910.1, "rew_std": 372.1944249985483, "Agent": "qrdqn"}, {"env_step": 4200000, "rew": 1774.8, "rew_std": 338.8928444213598, "Agent": "qrdqn"}, {"env_step": 4300000, "rew": 1837.0, "rew_std": 452.25545878408144, "Agent": "qrdqn"}, {"env_step": 4400000, "rew": 2068.4, "rew_std": 295.75131445185497, "Agent": "qrdqn"}, {"env_step": 4500000, "rew": 1876.6, "rew_std": 202.19851631503136, "Agent": "qrdqn"}, {"env_step": 4600000, "rew": 1891.1, "rew_std": 480.94541270293865, "Agent": "qrdqn"}, {"env_step": 4700000, "rew": 
1814.4, "rew_std": 474.1063593751934, "Agent": "qrdqn"}, {"env_step": 4800000, "rew": 2115.4, "rew_std": 403.2094244930294, "Agent": "qrdqn"}, {"env_step": 4900000, "rew": 2100.4, "rew_std": 376.1237030552581, "Agent": "qrdqn"}, {"env_step": 5000000, "rew": 1879.5, "rew_std": 195.04525115982702, "Agent": "qrdqn"}, {"env_step": 5100000, "rew": 1979.9, "rew_std": 212.22226556136846, "Agent": "qrdqn"}, {"env_step": 5200000, "rew": 1924.1, "rew_std": 592.8174170855643, "Agent": "qrdqn"}, {"env_step": 5300000, "rew": 2100.4, "rew_std": 348.4285292567186, "Agent": "qrdqn"}, {"env_step": 5400000, "rew": 1820.9, "rew_std": 254.82796157407844, "Agent": "qrdqn"}, {"env_step": 5500000, "rew": 1936.0, "rew_std": 332.0704744478196, "Agent": "qrdqn"}, {"env_step": 5600000, "rew": 1994.6, "rew_std": 341.4827082005764, "Agent": "qrdqn"}, {"env_step": 5700000, "rew": 1998.0, "rew_std": 395.96893817571095, "Agent": "qrdqn"}, {"env_step": 5800000, "rew": 2020.0, "rew_std": 326.8748996175754, "Agent": "qrdqn"}, {"env_step": 5900000, "rew": 1926.1, "rew_std": 247.97235733040887, "Agent": "qrdqn"}, {"env_step": 6000000, "rew": 2020.8, "rew_std": 313.74059348449, "Agent": "qrdqn"}, {"env_step": 6100000, "rew": 1995.5, "rew_std": 454.48965884825145, "Agent": "qrdqn"}, {"env_step": 6200000, "rew": 2165.4, "rew_std": 601.1125019495103, "Agent": "qrdqn"}, {"env_step": 6300000, "rew": 2024.4, "rew_std": 388.68424202686685, "Agent": "qrdqn"}, {"env_step": 6400000, "rew": 1701.7, "rew_std": 395.4142258442405, "Agent": "qrdqn"}, {"env_step": 6500000, "rew": 1837.4, "rew_std": 457.4256660923171, "Agent": "qrdqn"}, {"env_step": 6600000, "rew": 2029.2, "rew_std": 292.2097876526384, "Agent": "qrdqn"}, {"env_step": 6700000, "rew": 2231.2, "rew_std": 258.9786863817175, "Agent": "qrdqn"}, {"env_step": 6800000, "rew": 1961.0, "rew_std": 310.98295773241335, "Agent": "qrdqn"}, {"env_step": 6900000, "rew": 2151.2, "rew_std": 517.7655453967558, "Agent": "qrdqn"}, {"env_step": 7000000, "rew": 2058.0, "rew_std": 143.03985458605584, "Agent": "qrdqn"}, {"env_step": 7100000, "rew": 1998.7, "rew_std": 416.9923380591063, "Agent": "qrdqn"}, {"env_step": 7200000, "rew": 2028.3, "rew_std": 338.8569167067422, "Agent": "qrdqn"}, {"env_step": 7300000, "rew": 1894.4, "rew_std": 426.75219976000125, "Agent": "qrdqn"}, {"env_step": 7400000, "rew": 2052.4, "rew_std": 381.0210492873064, "Agent": "qrdqn"}, {"env_step": 7500000, "rew": 2102.5, "rew_std": 231.69905049438592, "Agent": "qrdqn"}, {"env_step": 7600000, "rew": 2112.5, "rew_std": 287.3075877870266, "Agent": "qrdqn"}, {"env_step": 7700000, "rew": 2069.4, "rew_std": 393.8327563827062, "Agent": "qrdqn"}, {"env_step": 7800000, "rew": 2143.4, "rew_std": 293.58174330158886, "Agent": "qrdqn"}, {"env_step": 7900000, "rew": 2049.1, "rew_std": 565.8735636164672, "Agent": "qrdqn"}, {"env_step": 8000000, "rew": 1993.5, "rew_std": 267.2022642119636, "Agent": "qrdqn"}, {"env_step": 8100000, "rew": 1891.7, "rew_std": 279.16018698947744, "Agent": "qrdqn"}, {"env_step": 8200000, "rew": 2026.0, "rew_std": 382.7017115195593, "Agent": "qrdqn"}, {"env_step": 8300000, "rew": 1835.2, "rew_std": 455.81439205009747, "Agent": "qrdqn"}, {"env_step": 8400000, "rew": 1992.0, "rew_std": 213.05257567088927, "Agent": "qrdqn"}, {"env_step": 8500000, "rew": 2126.1, "rew_std": 186.0109942987242, "Agent": "qrdqn"}, {"env_step": 8600000, "rew": 2112.9, "rew_std": 253.74532508008892, "Agent": "qrdqn"}, {"env_step": 8700000, "rew": 2143.2, "rew_std": 242.88507570453973, "Agent": "qrdqn"}, {"env_step": 8800000, "rew": 1859.2, 
"rew_std": 253.84436176523596, "Agent": "qrdqn"}, {"env_step": 8900000, "rew": 1909.5, "rew_std": 577.5562743144602, "Agent": "qrdqn"}, {"env_step": 9000000, "rew": 2166.2, "rew_std": 251.24044260429093, "Agent": "qrdqn"}, {"env_step": 9100000, "rew": 2221.9, "rew_std": 475.83493986885827, "Agent": "qrdqn"}, {"env_step": 9200000, "rew": 2060.8, "rew_std": 244.7904409898393, "Agent": "qrdqn"}, {"env_step": 9300000, "rew": 2114.2, "rew_std": 240.22980664355538, "Agent": "qrdqn"}, {"env_step": 9400000, "rew": 2168.0, "rew_std": 323.3663557020118, "Agent": "qrdqn"}, {"env_step": 9500000, "rew": 2012.4, "rew_std": 346.8542056830218, "Agent": "qrdqn"}, {"env_step": 9600000, "rew": 1901.9, "rew_std": 406.43263894525006, "Agent": "qrdqn"}, {"env_step": 9700000, "rew": 2197.4, "rew_std": 377.0950012927777, "Agent": "qrdqn"}, {"env_step": 9800000, "rew": 2095.6, "rew_std": 300.4899998336051, "Agent": "qrdqn"}, {"env_step": 9900000, "rew": 2027.9, "rew_std": 359.36790340819255, "Agent": "qrdqn"}, {"env_step": 10000000, "rew": 2259.3, "rew_std": 269.21108818174633, "Agent": "qrdqn"}, {"env_step": 0, "rew": 203.5, "rew_std": 125.49442218680478, "Agent": "iqn"}, {"env_step": 100000, "rew": 497.4, "rew_std": 171.87914358641655, "Agent": "iqn"}, {"env_step": 200000, "rew": 719.6, "rew_std": 175.82332041000703, "Agent": "iqn"}, {"env_step": 300000, "rew": 808.8, "rew_std": 230.12640005005943, "Agent": "iqn"}, {"env_step": 400000, "rew": 841.3, "rew_std": 215.85182417575254, "Agent": "iqn"}, {"env_step": 500000, "rew": 917.6, "rew_std": 177.8264322309819, "Agent": "iqn"}, {"env_step": 600000, "rew": 896.3, "rew_std": 348.6686249148323, "Agent": "iqn"}, {"env_step": 700000, "rew": 1141.9, "rew_std": 364.5055417960062, "Agent": "iqn"}, {"env_step": 800000, "rew": 1323.0, "rew_std": 291.3722704719857, "Agent": "iqn"}, {"env_step": 900000, "rew": 1274.1, "rew_std": 234.52439105559998, "Agent": "iqn"}, {"env_step": 1000000, "rew": 1553.1, "rew_std": 408.0204529187232, "Agent": "iqn"}, {"env_step": 1100000, "rew": 1436.4, "rew_std": 345.38245467886753, "Agent": "iqn"}, {"env_step": 1200000, "rew": 1649.8, "rew_std": 456.99295399382254, "Agent": "iqn"}, {"env_step": 1300000, "rew": 1489.3, "rew_std": 167.85592036028996, "Agent": "iqn"}, {"env_step": 1400000, "rew": 1645.2, "rew_std": 115.38006760268429, "Agent": "iqn"}, {"env_step": 1500000, "rew": 1641.9, "rew_std": 186.58534240395198, "Agent": "iqn"}, {"env_step": 1600000, "rew": 1599.5, "rew_std": 387.8569968429086, "Agent": "iqn"}, {"env_step": 1700000, "rew": 1690.2, "rew_std": 161.43902873840636, "Agent": "iqn"}, {"env_step": 1800000, "rew": 1613.6, "rew_std": 344.5162405460735, "Agent": "iqn"}, {"env_step": 1900000, "rew": 1773.5, "rew_std": 319.91131583612355, "Agent": "iqn"}, {"env_step": 2000000, "rew": 1738.7, "rew_std": 435.27900248001856, "Agent": "iqn"}, {"env_step": 2100000, "rew": 1719.0, "rew_std": 253.0829903411132, "Agent": "iqn"}, {"env_step": 2200000, "rew": 1831.4, "rew_std": 298.3257280222408, "Agent": "iqn"}, {"env_step": 2300000, "rew": 1982.3, "rew_std": 322.387980545181, "Agent": "iqn"}, {"env_step": 2400000, "rew": 1801.0, "rew_std": 123.33045041675636, "Agent": "iqn"}, {"env_step": 2500000, "rew": 1800.4, "rew_std": 263.40888367706964, "Agent": "iqn"}, {"env_step": 2600000, "rew": 1744.0, "rew_std": 312.34083946868043, "Agent": "iqn"}, {"env_step": 2700000, "rew": 2024.6, "rew_std": 493.24663202093933, "Agent": "iqn"}, {"env_step": 2800000, "rew": 1913.7, "rew_std": 184.7598711841941, "Agent": "iqn"}, {"env_step": 2900000, "rew": 
1956.7, "rew_std": 347.0703242860156, "Agent": "iqn"}, {"env_step": 3000000, "rew": 1950.4, "rew_std": 288.921165718263, "Agent": "iqn"}, {"env_step": 3100000, "rew": 1983.4, "rew_std": 243.26002548713177, "Agent": "iqn"}, {"env_step": 3200000, "rew": 2040.2, "rew_std": 305.4157166879269, "Agent": "iqn"}, {"env_step": 3300000, "rew": 2148.3, "rew_std": 352.78890288669794, "Agent": "iqn"}, {"env_step": 3400000, "rew": 1893.3, "rew_std": 572.8401260386705, "Agent": "iqn"}, {"env_step": 3500000, "rew": 2011.7, "rew_std": 243.61323855652836, "Agent": "iqn"}, {"env_step": 3600000, "rew": 1999.7, "rew_std": 199.95301948207737, "Agent": "iqn"}, {"env_step": 3700000, "rew": 2145.6, "rew_std": 185.3392565000734, "Agent": "iqn"}, {"env_step": 3800000, "rew": 2101.3, "rew_std": 386.0235873622233, "Agent": "iqn"}, {"env_step": 3900000, "rew": 1885.6, "rew_std": 300.85850494875496, "Agent": "iqn"}, {"env_step": 4000000, "rew": 2040.6, "rew_std": 263.6210158541993, "Agent": "iqn"}, {"env_step": 4100000, "rew": 2034.7, "rew_std": 204.19894710796137, "Agent": "iqn"}, {"env_step": 4200000, "rew": 2011.4, "rew_std": 203.3834801551001, "Agent": "iqn"}, {"env_step": 4300000, "rew": 1990.5, "rew_std": 242.6170851362286, "Agent": "iqn"}, {"env_step": 4400000, "rew": 1978.7, "rew_std": 212.20275681526854, "Agent": "iqn"}, {"env_step": 4500000, "rew": 1977.7, "rew_std": 178.95030036297788, "Agent": "iqn"}, {"env_step": 4600000, "rew": 1849.0, "rew_std": 308.38774294708924, "Agent": "iqn"}, {"env_step": 4700000, "rew": 1953.0, "rew_std": 273.4209209259599, "Agent": "iqn"}, {"env_step": 4800000, "rew": 2019.8, "rew_std": 216.3579441573616, "Agent": "iqn"}, {"env_step": 4900000, "rew": 1956.9, "rew_std": 152.99048990051637, "Agent": "iqn"}, {"env_step": 5000000, "rew": 2045.9, "rew_std": 282.905090092066, "Agent": "iqn"}, {"env_step": 5100000, "rew": 1971.6, "rew_std": 380.651336527274, "Agent": "iqn"}, {"env_step": 5200000, "rew": 2039.4, "rew_std": 223.29630538815462, "Agent": "iqn"}, {"env_step": 5300000, "rew": 1975.8, "rew_std": 201.08694636897744, "Agent": "iqn"}, {"env_step": 5400000, "rew": 2064.5, "rew_std": 254.50432216369134, "Agent": "iqn"}, {"env_step": 5500000, "rew": 2134.7, "rew_std": 428.5522255221643, "Agent": "iqn"}, {"env_step": 5600000, "rew": 1948.5, "rew_std": 272.5880591662078, "Agent": "iqn"}, {"env_step": 5700000, "rew": 2002.2, "rew_std": 340.7244634598461, "Agent": "iqn"}, {"env_step": 5800000, "rew": 2045.1, "rew_std": 164.70364294696094, "Agent": "iqn"}, {"env_step": 5900000, "rew": 1886.4, "rew_std": 163.4730558838367, "Agent": "iqn"}, {"env_step": 6000000, "rew": 1919.7, "rew_std": 219.83268637761765, "Agent": "iqn"}, {"env_step": 6100000, "rew": 2004.6, "rew_std": 165.23207920981932, "Agent": "iqn"}, {"env_step": 6200000, "rew": 1947.4, "rew_std": 389.1902362598528, "Agent": "iqn"}, {"env_step": 6300000, "rew": 2121.2, "rew_std": 371.044148316612, "Agent": "iqn"}, {"env_step": 6400000, "rew": 2047.5, "rew_std": 190.7633350515764, "Agent": "iqn"}, {"env_step": 6500000, "rew": 2032.7, "rew_std": 139.8721201669582, "Agent": "iqn"}, {"env_step": 6600000, "rew": 2159.6, "rew_std": 173.13416762730577, "Agent": "iqn"}, {"env_step": 6700000, "rew": 1899.0, "rew_std": 313.5557366721266, "Agent": "iqn"}, {"env_step": 6800000, "rew": 2104.5, "rew_std": 332.75554090052356, "Agent": "iqn"}, {"env_step": 6900000, "rew": 2212.5, "rew_std": 400.02956140765394, "Agent": "iqn"}, {"env_step": 7000000, "rew": 1910.5, "rew_std": 262.7969748684334, "Agent": "iqn"}, {"env_step": 7100000, "rew": 2110.5, 
"rew_std": 244.04640952081226, "Agent": "iqn"}, {"env_step": 7200000, "rew": 2069.4, "rew_std": 252.60095011697797, "Agent": "iqn"}, {"env_step": 7300000, "rew": 1997.3, "rew_std": 178.55478150976523, "Agent": "iqn"}, {"env_step": 7400000, "rew": 2102.1, "rew_std": 270.0775629333173, "Agent": "iqn"}, {"env_step": 7500000, "rew": 1930.6, "rew_std": 381.8453089930528, "Agent": "iqn"}, {"env_step": 7600000, "rew": 2114.2, "rew_std": 166.0757658419795, "Agent": "iqn"}, {"env_step": 7700000, "rew": 2000.9, "rew_std": 236.48909065747623, "Agent": "iqn"}, {"env_step": 7800000, "rew": 2138.6, "rew_std": 264.63189528097325, "Agent": "iqn"}, {"env_step": 7900000, "rew": 2128.6, "rew_std": 213.55570701809867, "Agent": "iqn"}, {"env_step": 8000000, "rew": 2109.3, "rew_std": 174.64652873733274, "Agent": "iqn"}, {"env_step": 8100000, "rew": 2009.0, "rew_std": 247.61986996200446, "Agent": "iqn"}, {"env_step": 8200000, "rew": 1983.7, "rew_std": 389.60288756630126, "Agent": "iqn"}, {"env_step": 8300000, "rew": 1994.3, "rew_std": 114.65692303563705, "Agent": "iqn"}, {"env_step": 8400000, "rew": 2095.6, "rew_std": 306.54630971518804, "Agent": "iqn"}, {"env_step": 8500000, "rew": 2008.5, "rew_std": 301.76555469436863, "Agent": "iqn"}, {"env_step": 8600000, "rew": 2129.8, "rew_std": 119.71365836862559, "Agent": "iqn"}, {"env_step": 8700000, "rew": 1975.8, "rew_std": 117.61700557317381, "Agent": "iqn"}, {"env_step": 8800000, "rew": 2123.2, "rew_std": 291.63051966486637, "Agent": "iqn"}, {"env_step": 8900000, "rew": 2044.2, "rew_std": 255.55109078225433, "Agent": "iqn"}, {"env_step": 9000000, "rew": 2228.6, "rew_std": 253.11902338623227, "Agent": "iqn"}, {"env_step": 9100000, "rew": 2149.0, "rew_std": 178.49201662819544, "Agent": "iqn"}, {"env_step": 9200000, "rew": 2148.9, "rew_std": 300.1541104166325, "Agent": "iqn"}, {"env_step": 9300000, "rew": 2022.7, "rew_std": 154.7081445819838, "Agent": "iqn"}, {"env_step": 9400000, "rew": 2217.1, "rew_std": 328.33350423007397, "Agent": "iqn"}, {"env_step": 9500000, "rew": 1985.3, "rew_std": 223.17842637674457, "Agent": "iqn"}, {"env_step": 9600000, "rew": 2110.1, "rew_std": 211.15892119444067, "Agent": "iqn"}, {"env_step": 9700000, "rew": 2162.6, "rew_std": 227.4173256372522, "Agent": "iqn"}, {"env_step": 9800000, "rew": 2212.4, "rew_std": 328.63329107076174, "Agent": "iqn"}, {"env_step": 9900000, "rew": 2094.4, "rew_std": 378.75142243957316, "Agent": "iqn"}, {"env_step": 10000000, "rew": 2151.1, "rew_std": 407.2558041329798, "Agent": "iqn"}, {"env_step": 0, "rew": 218.6, "rew_std": 99.78997945685728, "Agent": "rainbow"}, {"env_step": 100000, "rew": 395.4, "rew_std": 217.94045058226342, "Agent": "rainbow"}, {"env_step": 200000, "rew": 716.4, "rew_std": 209.39923591073583, "Agent": "rainbow"}, {"env_step": 300000, "rew": 943.7, "rew_std": 255.72096120576424, "Agent": "rainbow"}, {"env_step": 400000, "rew": 1031.6, "rew_std": 220.23632761195415, "Agent": "rainbow"}, {"env_step": 500000, "rew": 1255.4, "rew_std": 227.50701088098361, "Agent": "rainbow"}, {"env_step": 600000, "rew": 1306.0, "rew_std": 232.00991358129505, "Agent": "rainbow"}, {"env_step": 700000, "rew": 1406.3, "rew_std": 257.8658759898254, "Agent": "rainbow"}, {"env_step": 800000, "rew": 1297.9, "rew_std": 324.35488280585514, "Agent": "rainbow"}, {"env_step": 900000, "rew": 1442.4, "rew_std": 252.78734145522398, "Agent": "rainbow"}, {"env_step": 1000000, "rew": 1444.5, "rew_std": 303.269269791715, "Agent": "rainbow"}, {"env_step": 1100000, "rew": 1614.9, "rew_std": 246.82117008068815, "Agent": "rainbow"}, 
{"env_step": 1200000, "rew": 1609.4, "rew_std": 298.84952735448655, "Agent": "rainbow"}, {"env_step": 1300000, "rew": 1685.1, "rew_std": 399.2817175879707, "Agent": "rainbow"}, {"env_step": 1400000, "rew": 1548.6, "rew_std": 186.0033333034653, "Agent": "rainbow"}, {"env_step": 1500000, "rew": 1715.5, "rew_std": 250.8785562777337, "Agent": "rainbow"}, {"env_step": 1600000, "rew": 1737.4, "rew_std": 276.36541028138817, "Agent": "rainbow"}, {"env_step": 1700000, "rew": 2035.6, "rew_std": 429.08791639942507, "Agent": "rainbow"}, {"env_step": 1800000, "rew": 1743.6, "rew_std": 354.16470744556125, "Agent": "rainbow"}, {"env_step": 1900000, "rew": 1857.3, "rew_std": 287.29046277243526, "Agent": "rainbow"}, {"env_step": 2000000, "rew": 1836.8, "rew_std": 371.6750731485769, "Agent": "rainbow"}, {"env_step": 2100000, "rew": 1950.2, "rew_std": 312.3401351091467, "Agent": "rainbow"}, {"env_step": 2200000, "rew": 2048.7, "rew_std": 436.6781537929279, "Agent": "rainbow"}, {"env_step": 2300000, "rew": 1939.5, "rew_std": 278.96782968650706, "Agent": "rainbow"}, {"env_step": 2400000, "rew": 1835.0, "rew_std": 308.29239367846884, "Agent": "rainbow"}, {"env_step": 2500000, "rew": 1861.0, "rew_std": 219.7257381373425, "Agent": "rainbow"}, {"env_step": 2600000, "rew": 1996.7, "rew_std": 346.1849361251873, "Agent": "rainbow"}, {"env_step": 2700000, "rew": 2101.1, "rew_std": 340.95056826466794, "Agent": "rainbow"}, {"env_step": 2800000, "rew": 2038.1, "rew_std": 255.98728484047797, "Agent": "rainbow"}, {"env_step": 2900000, "rew": 1941.0, "rew_std": 302.7953103996163, "Agent": "rainbow"}, {"env_step": 3000000, "rew": 2099.1, "rew_std": 384.3590638972886, "Agent": "rainbow"}, {"env_step": 3100000, "rew": 2072.2, "rew_std": 272.6069698301934, "Agent": "rainbow"}, {"env_step": 3200000, "rew": 1995.5, "rew_std": 265.344021979015, "Agent": "rainbow"}, {"env_step": 3300000, "rew": 2059.7, "rew_std": 355.7518938811148, "Agent": "rainbow"}, {"env_step": 3400000, "rew": 1939.6, "rew_std": 301.1342557730688, "Agent": "rainbow"}, {"env_step": 3500000, "rew": 1921.4, "rew_std": 263.6744963017849, "Agent": "rainbow"}, {"env_step": 3600000, "rew": 2222.2, "rew_std": 170.95309298167143, "Agent": "rainbow"}, {"env_step": 3700000, "rew": 2048.7, "rew_std": 211.11515814834328, "Agent": "rainbow"}, {"env_step": 3800000, "rew": 2072.6, "rew_std": 327.48288504897477, "Agent": "rainbow"}, {"env_step": 3900000, "rew": 2167.4, "rew_std": 428.5650942389032, "Agent": "rainbow"}, {"env_step": 4000000, "rew": 2107.1, "rew_std": 285.2712568766787, "Agent": "rainbow"}, {"env_step": 4100000, "rew": 1802.2, "rew_std": 228.4192636359727, "Agent": "rainbow"}, {"env_step": 4200000, "rew": 1961.4, "rew_std": 254.83100282343983, "Agent": "rainbow"}, {"env_step": 4300000, "rew": 2048.1, "rew_std": 245.7040699703609, "Agent": "rainbow"}, {"env_step": 4400000, "rew": 2136.8, "rew_std": 292.6403253141986, "Agent": "rainbow"}, {"env_step": 4500000, "rew": 2099.8, "rew_std": 350.09107386507304, "Agent": "rainbow"}, {"env_step": 4600000, "rew": 2179.6, "rew_std": 253.28489887871328, "Agent": "rainbow"}, {"env_step": 4700000, "rew": 2250.3, "rew_std": 184.54974939023896, "Agent": "rainbow"}, {"env_step": 4800000, "rew": 1950.7, "rew_std": 262.8326653975871, "Agent": "rainbow"}, {"env_step": 4900000, "rew": 2161.1, "rew_std": 393.27940449507395, "Agent": "rainbow"}, {"env_step": 5000000, "rew": 2120.8, "rew_std": 218.70198901701832, "Agent": "rainbow"}, {"env_step": 5100000, "rew": 2207.4, "rew_std": 232.03973797606307, "Agent": "rainbow"}, {"env_step": 
5200000, "rew": 2217.3, "rew_std": 359.2347561136032, "Agent": "rainbow"}, {"env_step": 5300000, "rew": 2141.2, "rew_std": 243.542521954586, "Agent": "rainbow"}, {"env_step": 5400000, "rew": 2160.4, "rew_std": 287.35072646506393, "Agent": "rainbow"}, {"env_step": 5500000, "rew": 2235.1, "rew_std": 212.19493396403223, "Agent": "rainbow"}, {"env_step": 5600000, "rew": 2280.4, "rew_std": 318.04597151984177, "Agent": "rainbow"}, {"env_step": 5700000, "rew": 2358.9, "rew_std": 310.13430961439917, "Agent": "rainbow"}, {"env_step": 5800000, "rew": 2267.6, "rew_std": 273.4484229246898, "Agent": "rainbow"}, {"env_step": 5900000, "rew": 2193.4, "rew_std": 181.35997353330202, "Agent": "rainbow"}, {"env_step": 6000000, "rew": 2366.9, "rew_std": 578.7907134707674, "Agent": "rainbow"}, {"env_step": 6100000, "rew": 2292.2, "rew_std": 293.46372859350095, "Agent": "rainbow"}, {"env_step": 6200000, "rew": 2048.0, "rew_std": 355.46139030842716, "Agent": "rainbow"}, {"env_step": 6300000, "rew": 2311.8, "rew_std": 276.04304012236935, "Agent": "rainbow"}, {"env_step": 6400000, "rew": 2211.3, "rew_std": 304.2528718023874, "Agent": "rainbow"}, {"env_step": 6500000, "rew": 2256.9, "rew_std": 187.56622830349818, "Agent": "rainbow"}, {"env_step": 6600000, "rew": 2262.1, "rew_std": 290.55342021735004, "Agent": "rainbow"}, {"env_step": 6700000, "rew": 2175.7, "rew_std": 346.9455432773276, "Agent": "rainbow"}, {"env_step": 6800000, "rew": 2179.1, "rew_std": 243.099341833745, "Agent": "rainbow"}, {"env_step": 6900000, "rew": 2338.0, "rew_std": 367.66288907095316, "Agent": "rainbow"}, {"env_step": 7000000, "rew": 2354.4, "rew_std": 258.2797707912875, "Agent": "rainbow"}, {"env_step": 7100000, "rew": 2320.4, "rew_std": 294.2781677257081, "Agent": "rainbow"}, {"env_step": 7200000, "rew": 2389.3, "rew_std": 247.6655204100886, "Agent": "rainbow"}, {"env_step": 7300000, "rew": 2187.6, "rew_std": 325.17201601613874, "Agent": "rainbow"}, {"env_step": 7400000, "rew": 2160.5, "rew_std": 205.99866504421817, "Agent": "rainbow"}, {"env_step": 7500000, "rew": 2400.5, "rew_std": 389.60268222896, "Agent": "rainbow"}, {"env_step": 7600000, "rew": 2228.4, "rew_std": 339.70051516004503, "Agent": "rainbow"}, {"env_step": 7700000, "rew": 2230.1, "rew_std": 383.68019234774164, "Agent": "rainbow"}, {"env_step": 7800000, "rew": 2358.0, "rew_std": 292.9624549323684, "Agent": "rainbow"}, {"env_step": 7900000, "rew": 2243.1, "rew_std": 245.06833740816052, "Agent": "rainbow"}, {"env_step": 8000000, "rew": 2271.8, "rew_std": 182.58466529257052, "Agent": "rainbow"}, {"env_step": 8100000, "rew": 2178.2, "rew_std": 284.79389038390553, "Agent": "rainbow"}, {"env_step": 8200000, "rew": 2151.4, "rew_std": 386.34937556569184, "Agent": "rainbow"}, {"env_step": 8300000, "rew": 2225.5, "rew_std": 272.64161457855255, "Agent": "rainbow"}, {"env_step": 8400000, "rew": 2378.0, "rew_std": 335.8109587252923, "Agent": "rainbow"}, {"env_step": 8500000, "rew": 2290.5, "rew_std": 365.62802135503784, "Agent": "rainbow"}, {"env_step": 8600000, "rew": 2313.5, "rew_std": 400.5274647262033, "Agent": "rainbow"}, {"env_step": 8700000, "rew": 2258.8, "rew_std": 245.23898548150945, "Agent": "rainbow"}, {"env_step": 8800000, "rew": 2345.8, "rew_std": 273.3780532522682, "Agent": "rainbow"}, {"env_step": 8900000, "rew": 2222.2, "rew_std": 320.44181999233496, "Agent": "rainbow"}, {"env_step": 9000000, "rew": 2361.9, "rew_std": 291.36212863033523, "Agent": "rainbow"}, {"env_step": 9100000, "rew": 2401.0, "rew_std": 308.2463300673667, "Agent": "rainbow"}, {"env_step": 9200000, 
"rew": 2296.9, "rew_std": 327.37759544599265, "Agent": "rainbow"}, {"env_step": 9300000, "rew": 2196.2, "rew_std": 414.1250535768151, "Agent": "rainbow"}, {"env_step": 9400000, "rew": 2343.8, "rew_std": 295.55195820701306, "Agent": "rainbow"}, {"env_step": 9500000, "rew": 2109.8, "rew_std": 314.5847421602008, "Agent": "rainbow"}, {"env_step": 9600000, "rew": 2524.2, "rew_std": 338.81463958925974, "Agent": "rainbow"}, {"env_step": 9700000, "rew": 2397.1, "rew_std": 202.06011481734836, "Agent": "rainbow"}, {"env_step": 9800000, "rew": 2485.8, "rew_std": 377.7405988241137, "Agent": "rainbow"}, {"env_step": 9900000, "rew": 2244.9, "rew_std": 120.4271148869722, "Agent": "rainbow"}, {"env_step": 10000000, "rew": 2214.0, "rew_std": 176.29690865128634, "Agent": "rainbow"}, {"env_step": 0, "rew": 370.7, "rew_std": 113.11502994739469, "Agent": "ppo"}, {"env_step": 100000, "rew": 505.9, "rew_std": 129.42986517801833, "Agent": "ppo"}, {"env_step": 200000, "rew": 421.8, "rew_std": 102.49858535609162, "Agent": "ppo"}, {"env_step": 300000, "rew": 479.5, "rew_std": 92.63719555340609, "Agent": "ppo"}, {"env_step": 400000, "rew": 508.4, "rew_std": 132.38595091625092, "Agent": "ppo"}, {"env_step": 500000, "rew": 560.6, "rew_std": 100.25088528287418, "Agent": "ppo"}, {"env_step": 600000, "rew": 664.6, "rew_std": 175.08866325379265, "Agent": "ppo"}, {"env_step": 700000, "rew": 588.6, "rew_std": 162.83746497658333, "Agent": "ppo"}, {"env_step": 800000, "rew": 610.4, "rew_std": 181.44982777616517, "Agent": "ppo"}, {"env_step": 900000, "rew": 633.7, "rew_std": 107.41233634922946, "Agent": "ppo"}, {"env_step": 1000000, "rew": 697.1, "rew_std": 94.16204118433288, "Agent": "ppo"}, {"env_step": 1100000, "rew": 631.9, "rew_std": 98.8275771229873, "Agent": "ppo"}, {"env_step": 1200000, "rew": 712.6, "rew_std": 130.972668904623, "Agent": "ppo"}, {"env_step": 1300000, "rew": 727.2, "rew_std": 129.8936488054747, "Agent": "ppo"}, {"env_step": 1400000, "rew": 664.1, "rew_std": 156.49054284524672, "Agent": "ppo"}, {"env_step": 1500000, "rew": 628.1, "rew_std": 184.3379776389011, "Agent": "ppo"}, {"env_step": 1600000, "rew": 641.9, "rew_std": 127.15065866915515, "Agent": "ppo"}, {"env_step": 1700000, "rew": 647.3, "rew_std": 92.44355034289846, "Agent": "ppo"}, {"env_step": 1800000, "rew": 647.3, "rew_std": 125.52294610946637, "Agent": "ppo"}, {"env_step": 1900000, "rew": 613.0, "rew_std": 117.30387887874808, "Agent": "ppo"}, {"env_step": 2000000, "rew": 757.2, "rew_std": 211.36262678155757, "Agent": "ppo"}, {"env_step": 2100000, "rew": 698.0, "rew_std": 88.34591105421914, "Agent": "ppo"}, {"env_step": 2200000, "rew": 756.7, "rew_std": 118.22609694986974, "Agent": "ppo"}, {"env_step": 2300000, "rew": 694.6, "rew_std": 142.86441124366837, "Agent": "ppo"}, {"env_step": 2400000, "rew": 795.3, "rew_std": 180.00836091693074, "Agent": "ppo"}, {"env_step": 2500000, "rew": 637.0, "rew_std": 111.93748255164576, "Agent": "ppo"}, {"env_step": 2600000, "rew": 731.4, "rew_std": 201.773239058107, "Agent": "ppo"}, {"env_step": 2700000, "rew": 709.3, "rew_std": 171.81679196167062, "Agent": "ppo"}, {"env_step": 2800000, "rew": 643.3, "rew_std": 124.15880959480886, "Agent": "ppo"}, {"env_step": 2900000, "rew": 841.8, "rew_std": 230.30275725661647, "Agent": "ppo"}, {"env_step": 3000000, "rew": 771.9, "rew_std": 201.02659028098745, "Agent": "ppo"}, {"env_step": 3100000, "rew": 803.4, "rew_std": 195.58128744846732, "Agent": "ppo"}, {"env_step": 3200000, "rew": 756.8, "rew_std": 186.79657384438292, "Agent": "ppo"}, {"env_step": 3300000, "rew": 
761.7, "rew_std": 183.00986312218257, "Agent": "ppo"}, {"env_step": 3400000, "rew": 884.0, "rew_std": 177.51788642274894, "Agent": "ppo"}, {"env_step": 3500000, "rew": 882.3, "rew_std": 235.03235947417963, "Agent": "ppo"}, {"env_step": 3600000, "rew": 886.8, "rew_std": 165.33166665826604, "Agent": "ppo"}, {"env_step": 3700000, "rew": 887.6, "rew_std": 155.86545479996522, "Agent": "ppo"}, {"env_step": 3800000, "rew": 870.0, "rew_std": 140.03142504452347, "Agent": "ppo"}, {"env_step": 3900000, "rew": 963.2, "rew_std": 163.08267841803433, "Agent": "ppo"}, {"env_step": 4000000, "rew": 915.2, "rew_std": 198.6211469103932, "Agent": "ppo"}, {"env_step": 4100000, "rew": 954.3, "rew_std": 224.29135070260733, "Agent": "ppo"}, {"env_step": 4200000, "rew": 1005.9, "rew_std": 185.8673989703412, "Agent": "ppo"}, {"env_step": 4300000, "rew": 1021.8, "rew_std": 173.70768549491413, "Agent": "ppo"}, {"env_step": 4400000, "rew": 969.5, "rew_std": 176.3333490863257, "Agent": "ppo"}, {"env_step": 4500000, "rew": 1041.1, "rew_std": 177.89291722831462, "Agent": "ppo"}, {"env_step": 4600000, "rew": 977.5, "rew_std": 200.08660624839436, "Agent": "ppo"}, {"env_step": 4700000, "rew": 1033.2, "rew_std": 133.5520872169357, "Agent": "ppo"}, {"env_step": 4800000, "rew": 1085.6, "rew_std": 141.09018392503427, "Agent": "ppo"}, {"env_step": 4900000, "rew": 1077.5, "rew_std": 248.93543339589084, "Agent": "ppo"}, {"env_step": 5000000, "rew": 1067.3, "rew_std": 158.23656341061002, "Agent": "ppo"}, {"env_step": 5100000, "rew": 1198.8, "rew_std": 166.84831434569546, "Agent": "ppo"}, {"env_step": 5200000, "rew": 1088.0, "rew_std": 144.770853420155, "Agent": "ppo"}, {"env_step": 5300000, "rew": 1108.4, "rew_std": 154.99238690980923, "Agent": "ppo"}, {"env_step": 5400000, "rew": 1203.5, "rew_std": 257.2929264476581, "Agent": "ppo"}, {"env_step": 5500000, "rew": 1092.1, "rew_std": 100.34286222746488, "Agent": "ppo"}, {"env_step": 5600000, "rew": 1198.8, "rew_std": 151.49838282965268, "Agent": "ppo"}, {"env_step": 5700000, "rew": 1137.5, "rew_std": 123.52024125624108, "Agent": "ppo"}, {"env_step": 5800000, "rew": 1118.2, "rew_std": 153.89463928285483, "Agent": "ppo"}, {"env_step": 5900000, "rew": 1187.0, "rew_std": 157.57855184002676, "Agent": "ppo"}, {"env_step": 6000000, "rew": 1200.2, "rew_std": 167.1201962660408, "Agent": "ppo"}, {"env_step": 6100000, "rew": 1207.1, "rew_std": 205.2556698364262, "Agent": "ppo"}, {"env_step": 6200000, "rew": 1304.3, "rew_std": 198.32904477156137, "Agent": "ppo"}, {"env_step": 6300000, "rew": 1280.2, "rew_std": 114.50310039470546, "Agent": "ppo"}, {"env_step": 6400000, "rew": 1224.8, "rew_std": 189.02105702804647, "Agent": "ppo"}, {"env_step": 6500000, "rew": 1325.9, "rew_std": 179.55859767774976, "Agent": "ppo"}, {"env_step": 6600000, "rew": 1417.9, "rew_std": 262.43606840524035, "Agent": "ppo"}, {"env_step": 6700000, "rew": 1329.9, "rew_std": 153.3286992053347, "Agent": "ppo"}, {"env_step": 6800000, "rew": 1324.8, "rew_std": 237.16230729186287, "Agent": "ppo"}, {"env_step": 6900000, "rew": 1362.0, "rew_std": 162.35947770302786, "Agent": "ppo"}, {"env_step": 7000000, "rew": 1291.7, "rew_std": 179.75597347515324, "Agent": "ppo"}, {"env_step": 7100000, "rew": 1315.4, "rew_std": 236.61200307676702, "Agent": "ppo"}, {"env_step": 7200000, "rew": 1400.3, "rew_std": 257.7530019223831, "Agent": "ppo"}, {"env_step": 7300000, "rew": 1361.2, "rew_std": 186.70286553773084, "Agent": "ppo"}, {"env_step": 7400000, "rew": 1465.6, "rew_std": 229.4812410634037, "Agent": "ppo"}, {"env_step": 7500000, "rew": 1450.6, 
"rew_std": 163.0295678703713, "Agent": "ppo"}, {"env_step": 7600000, "rew": 1490.6, "rew_std": 267.5194198558303, "Agent": "ppo"}, {"env_step": 7700000, "rew": 1461.2, "rew_std": 199.32877363792716, "Agent": "ppo"}, {"env_step": 7800000, "rew": 1510.4, "rew_std": 212.96769708103622, "Agent": "ppo"}, {"env_step": 7900000, "rew": 1515.6, "rew_std": 344.78027785823247, "Agent": "ppo"}, {"env_step": 8000000, "rew": 1401.0, "rew_std": 341.4229049141255, "Agent": "ppo"}, {"env_step": 8100000, "rew": 1480.9, "rew_std": 253.02982037696665, "Agent": "ppo"}, {"env_step": 8200000, "rew": 1490.2, "rew_std": 273.54590108426044, "Agent": "ppo"}, {"env_step": 8300000, "rew": 1565.9, "rew_std": 238.09512804759362, "Agent": "ppo"}, {"env_step": 8400000, "rew": 1507.5, "rew_std": 310.9798224965729, "Agent": "ppo"}, {"env_step": 8500000, "rew": 1463.0, "rew_std": 203.8013738913455, "Agent": "ppo"}, {"env_step": 8600000, "rew": 1554.6, "rew_std": 261.8802016189846, "Agent": "ppo"}, {"env_step": 8700000, "rew": 1525.2, "rew_std": 198.0645349374794, "Agent": "ppo"}, {"env_step": 8800000, "rew": 1599.1, "rew_std": 190.6459808126046, "Agent": "ppo"}, {"env_step": 8900000, "rew": 1544.1, "rew_std": 207.58297136325993, "Agent": "ppo"}, {"env_step": 9000000, "rew": 1524.5, "rew_std": 192.14382633850093, "Agent": "ppo"}, {"env_step": 9100000, "rew": 1563.0, "rew_std": 273.20761336390314, "Agent": "ppo"}, {"env_step": 9200000, "rew": 1699.4, "rew_std": 248.01701554530484, "Agent": "ppo"}, {"env_step": 9300000, "rew": 1534.9, "rew_std": 245.79888120168488, "Agent": "ppo"}, {"env_step": 9400000, "rew": 1526.9, "rew_std": 157.2097007184989, "Agent": "ppo"}, {"env_step": 9500000, "rew": 1573.5, "rew_std": 227.85620465548, "Agent": "ppo"}, {"env_step": 9600000, "rew": 1482.6, "rew_std": 161.7721854955295, "Agent": "ppo"}, {"env_step": 9700000, "rew": 1633.9, "rew_std": 182.42667019928857, "Agent": "ppo"}, {"env_step": 9800000, "rew": 1514.0, "rew_std": 231.70110055845657, "Agent": "ppo"}, {"env_step": 9900000, "rew": 1624.4, "rew_std": 227.23696882329688, "Agent": "ppo"}, {"env_step": 10000000, "rew": 1531.6, "rew_std": 227.96455864892684, "Agent": "ppo"}] \ No newline at end of file diff --git a/examples/atari/benchmark/PongNoFrameskip-v4/result.json b/examples/atari/benchmark/PongNoFrameskip-v4/result.json new file mode 100644 index 0000000000000000000000000000000000000000..ee5aa54780f2f13f4b8d0717d310e5a8c88308d4 --- /dev/null +++ b/examples/atari/benchmark/PongNoFrameskip-v4/result.json @@ -0,0 +1 @@ +[{"env_step": 0, "rew": -20.979999923706053, "rew_std": 0.04000015258789063, "Agent": "c51"}, {"env_step": 100000, "rew": -20.869999885559082, "rew_std": 0.15524167570244413, "Agent": "c51"}, {"env_step": 200000, "rew": -20.560000038146974, "rew_std": 0.40298883737879937, "Agent": "c51"}, {"env_step": 300000, "rew": -18.95999994277954, "rew_std": 2.2632720366854833, "Agent": "c51"}, {"env_step": 400000, "rew": -16.210000228881835, "rew_std": 2.7750494802377017, "Agent": "c51"}, {"env_step": 500000, "rew": -15.040000057220459, "rew_std": 3.022648038181074, "Agent": "c51"}, {"env_step": 600000, "rew": -12.759999966621399, "rew_std": 5.666603835995492, "Agent": "c51"}, {"env_step": 700000, "rew": -8.17000013589859, "rew_std": 6.876634371998414, "Agent": "c51"}, {"env_step": 800000, "rew": -5.910000105202198, "rew_std": 5.658347880641881, "Agent": "c51"}, {"env_step": 900000, "rew": -2.0299999713897705, "rew_std": 7.5090678214603175, "Agent": "c51"}, {"env_step": 1000000, "rew": -1.05, "rew_std": 8.06576073888153, "Agent": 
"c51"}, {"env_step": 1100000, "rew": 5.750000011920929, "rew_std": 8.470448577145289, "Agent": "c51"}, {"env_step": 1200000, "rew": 11.85, "rew_std": 6.486486065226738, "Agent": "c51"}, {"env_step": 1300000, "rew": 11.839999842643739, "rew_std": 9.283232047765221, "Agent": "c51"}, {"env_step": 1400000, "rew": 10.289999675750732, "rew_std": 13.408239764024396, "Agent": "c51"}, {"env_step": 1500000, "rew": 15.300000054495674, "rew_std": 7.003264581462973, "Agent": "c51"}, {"env_step": 1600000, "rew": 15.419999885559083, "rew_std": 5.896914245313163, "Agent": "c51"}, {"env_step": 1700000, "rew": 16.47999973297119, "rew_std": 6.142116861407374, "Agent": "c51"}, {"env_step": 1800000, "rew": 18.700000381469728, "rew_std": 1.2743627474105064, "Agent": "c51"}, {"env_step": 1900000, "rew": 13.000000283122063, "rew_std": 8.335466556491935, "Agent": "c51"}, {"env_step": 2000000, "rew": 17.47499966621399, "rew_std": 3.660174086617874, "Agent": "c51"}, {"env_step": 2100000, "rew": 13.566666801770529, "rew_std": 8.47833851940157, "Agent": "c51"}, {"env_step": 2200000, "rew": 19.75, "rew_std": 0.5499992370605469, "Agent": "c51"}, {"env_step": 2300000, "rew": 14.0, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 2400000, "rew": 19.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 2500000, "rew": 15.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 2600000, "rew": 18.200000762939453, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 2700000, "rew": 19.0, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 2800000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 2900000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 3000000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 3100000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 3200000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 3300000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 3400000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 3500000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 3600000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 3700000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 3800000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 3900000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 4000000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 4100000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 4200000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 4300000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 4400000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 4500000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 4600000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 4700000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 4800000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 4900000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 5000000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 5100000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 5200000, "rew": 20.600000381469727, "rew_std": 0.0, 
"Agent": "c51"}, {"env_step": 5300000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 5400000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 5500000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 5600000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 5700000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 5800000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 5900000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 6000000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 6100000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 6200000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 6300000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 6400000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 6500000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 6600000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 6700000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 6800000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 6900000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 7000000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 7100000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 7200000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 7300000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 7400000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 7500000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 7600000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 7700000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 7800000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 7900000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 8000000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 8100000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 8200000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 8300000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 8400000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 8500000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 8600000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 8700000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 8800000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 8900000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 9000000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 9100000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 9200000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 9300000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 9400000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 9500000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, 
{"env_step": 9600000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 9700000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 9800000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 9900000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 10000000, "rew": 20.600000381469727, "rew_std": 0.0, "Agent": "c51"}, {"env_step": 0, "rew": -20.979999923706053, "rew_std": 0.04000015258789063, "Agent": "dqn"}, {"env_step": 100000, "rew": -20.689999961853026, "rew_std": 0.5281095979692643, "Agent": "dqn"}, {"env_step": 200000, "rew": -18.38000020980835, "rew_std": 2.5906757312772744, "Agent": "dqn"}, {"env_step": 300000, "rew": -18.030000019073487, "rew_std": 1.7245580854624265, "Agent": "dqn"}, {"env_step": 400000, "rew": -13.899999952316284, "rew_std": 3.808936970212056, "Agent": "dqn"}, {"env_step": 500000, "rew": -5.709999942779541, "rew_std": 9.006936246078585, "Agent": "dqn"}, {"env_step": 600000, "rew": -1.0700000286102296, "rew_std": 8.908540906843577, "Agent": "dqn"}, {"env_step": 700000, "rew": 6.160000026226044, "rew_std": 7.3988107178341656, "Agent": "dqn"}, {"env_step": 800000, "rew": 15.04000015258789, "rew_std": 5.2547504042740645, "Agent": "dqn"}, {"env_step": 900000, "rew": 19.755555470784504, "rew_std": 0.4374447957623508, "Agent": "dqn"}, {"env_step": 1000000, "rew": 19.983332951863606, "rew_std": 0.5814256919850659, "Agent": "dqn"}, {"env_step": 1100000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 1200000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 1300000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 1400000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 1500000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 1600000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 1700000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 1800000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 1900000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 2000000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 2100000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 2200000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 2300000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 2400000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 2500000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 2600000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 2700000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 2800000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 2900000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 3000000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 3100000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 3200000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 3300000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 3400000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 3500000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 3600000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 3700000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 3800000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 3900000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 4000000, "rew": 20.25, "rew_std": 0.25, 
"Agent": "dqn"}, {"env_step": 4100000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 4200000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 4300000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 4400000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 4500000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 4600000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 4700000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 4800000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 4900000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 5000000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 5100000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 5200000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 5300000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 5400000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 5500000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 5600000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 5700000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 5800000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 5900000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 6000000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 6100000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 6200000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 6300000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 6400000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 6500000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 6600000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 6700000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 6800000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 6900000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 7000000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 7100000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 7200000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 7300000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 7400000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 7500000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 7600000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 7700000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 7800000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 7900000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 8000000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 8100000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 8200000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 8300000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 8400000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 8500000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 8600000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 8700000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 8800000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 8900000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 9000000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 9100000, "rew": 20.25, 
"rew_std": 0.25, "Agent": "dqn"}, {"env_step": 9200000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 9300000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 9400000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 9500000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 9600000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 9700000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 9800000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 9900000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 10000000, "rew": 20.25, "rew_std": 0.25, "Agent": "dqn"}, {"env_step": 0, "rew": -20.979999923706053, "rew_std": 0.04000015258789063, "Agent": "fqf"}, {"env_step": 100000, "rew": -20.879999923706055, "rew_std": 0.15999994277996732, "Agent": "fqf"}, {"env_step": 200000, "rew": -19.329999923706055, "rew_std": 1.1883183401767072, "Agent": "fqf"}, {"env_step": 300000, "rew": -18.410000228881835, "rew_std": 2.3947650818103883, "Agent": "fqf"}, {"env_step": 400000, "rew": -15.789999866485596, "rew_std": 1.8124292335112842, "Agent": "fqf"}, {"env_step": 500000, "rew": -12.899999952316284, "rew_std": 3.9191835397861126, "Agent": "fqf"}, {"env_step": 600000, "rew": -7.259999930858612, "rew_std": 6.181294202818166, "Agent": "fqf"}, {"env_step": 700000, "rew": -0.2800000667572021, "rew_std": 5.578135949422739, "Agent": "fqf"}, {"env_step": 800000, "rew": 5.889999827742576, "rew_std": 6.357428520171511, "Agent": "fqf"}, {"env_step": 900000, "rew": 12.8555555873447, "rew_std": 5.842585252592779, "Agent": "fqf"}, {"env_step": 1000000, "rew": 18.875, "rew_std": 2.6085196695608577, "Agent": "fqf"}, {"env_step": 1100000, "rew": 18.749999682108562, "rew_std": 2.6196374340932294, "Agent": "fqf"}, {"env_step": 1200000, "rew": 19.65999984741211, "rew_std": 1.051855243717263, "Agent": "fqf"}, {"env_step": 1300000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 1400000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 1500000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 1600000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 1700000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 1800000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 1900000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 2000000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 2100000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 2200000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 2300000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 2400000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 2500000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 2600000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 2700000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 2800000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 2900000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, 
"Agent": "fqf"}, {"env_step": 3000000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 3100000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 3200000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 3300000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 3400000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 3500000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 3600000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 3700000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 3800000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 3900000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 4000000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 4100000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 4200000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 4300000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 4400000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 4500000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 4600000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 4700000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 4800000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 4900000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 5000000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 5100000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 5200000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 5300000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 5400000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 5500000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 5600000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 5700000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 5800000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 5900000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 6000000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 6100000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 6200000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 6300000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 6400000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 6500000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, 
{"env_step": 6600000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 6700000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 6800000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 6900000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 7000000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 7100000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 7200000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 7300000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 7400000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 7500000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 7600000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 7700000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 7800000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 7900000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 8000000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 8100000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 8200000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 8300000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 8400000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 8500000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 8600000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 8700000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 8800000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 8900000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 9000000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 9100000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 9200000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 9300000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 9400000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 9500000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 9600000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 9700000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 9800000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 9900000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 10000000, "rew": 20.399999618530273, "rew_std": 0.39999961853027344, "Agent": "fqf"}, {"env_step": 0, "rew": -21.0, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 100000, "rew": -20.870000076293945, 
"rew_std": 0.27221299202092436, "Agent": "qrdqn"}, {"env_step": 200000, "rew": -19.48000011444092, "rew_std": 1.2023307411527404, "Agent": "qrdqn"}, {"env_step": 300000, "rew": -16.780000019073487, "rew_std": 2.1613883342785347, "Agent": "qrdqn"}, {"env_step": 400000, "rew": -12.920000219345093, "rew_std": 3.473845164617662, "Agent": "qrdqn"}, {"env_step": 500000, "rew": -7.060000002384186, "rew_std": 6.094456461922503, "Agent": "qrdqn"}, {"env_step": 600000, "rew": -3.779999941587448, "rew_std": 6.045295632144355, "Agent": "qrdqn"}, {"env_step": 700000, "rew": 9.749999952316283, "rew_std": 6.640368991575429, "Agent": "qrdqn"}, {"env_step": 800000, "rew": 15.269999933242797, "rew_std": 4.12966090763813, "Agent": "qrdqn"}, {"env_step": 900000, "rew": 19.622222052680122, "rew_std": 1.2916639583656102, "Agent": "qrdqn"}, {"env_step": 1000000, "rew": 20.09999990463257, "rew_std": 0.3162274644388989, "Agent": "qrdqn"}, {"env_step": 1100000, "rew": 19.899999618530273, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 1200000, "rew": 19.899999618530273, "rew_std": 0.10000038146972656, "Agent": "qrdqn"}, {"env_step": 1300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 1400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 1500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 1600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 1700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 1800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 1900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 2000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 2100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 2200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 2300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 2400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 2500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 2600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 2700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 2800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 2900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 3000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 3100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 3200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 3300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 3400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 3500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 3600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 3700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 3800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 3900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 4000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 4100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, 
{"env_step": 4200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 4300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 4400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 4500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 4600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 4700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 4800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 4900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 5000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 5100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 5200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 5300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 5400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 5500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 5600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 5700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 5800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 5900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 6000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 6100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 6200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 6300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 6400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 6500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 6600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 6700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 6800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 6900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 7000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 7100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 7200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 7300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 7400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 7500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 7600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 7700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 7800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 7900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 8000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 8100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 8200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 8300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 8400000, 
"rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 8500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 8600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 8700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 8800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 8900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 9000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 9100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 9200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 9300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 9400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 9500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 9600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 9700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 9800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 9900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 10000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "qrdqn"}, {"env_step": 0, "rew": -20.969999885559083, "rew_std": 0.06403148663413369, "Agent": "iqn"}, {"env_step": 100000, "rew": -19.1300000667572, "rew_std": 5.345100609195572, "Agent": "iqn"}, {"env_step": 200000, "rew": -19.34000015258789, "rew_std": 1.09380072496787, "Agent": "iqn"}, {"env_step": 300000, "rew": -18.3, "rew_std": 1.1471704685545094, "Agent": "iqn"}, {"env_step": 400000, "rew": -14.660000038146972, "rew_std": 2.7383207958984883, "Agent": "iqn"}, {"env_step": 500000, "rew": -9.659999978542327, "rew_std": 5.29871682181189, "Agent": "iqn"}, {"env_step": 600000, "rew": -8.680000057816505, "rew_std": 4.040495106986447, "Agent": "iqn"}, {"env_step": 700000, "rew": 2.8499999545514583, "rew_std": 6.374519581488704, "Agent": "iqn"}, {"env_step": 800000, "rew": 7.970000147819519, "rew_std": 8.160275826601659, "Agent": "iqn"}, {"env_step": 900000, "rew": 17.166666507720947, "rew_std": 4.651164654639624, "Agent": "iqn"}, {"env_step": 1000000, "rew": 17.849999984105427, "rew_std": 4.5853935091484725, "Agent": "iqn"}, {"env_step": 1100000, "rew": 18.260000038146973, "rew_std": 1.9652988864635694, "Agent": "iqn"}, {"env_step": 1200000, "rew": 18.68000030517578, "rew_std": 2.9047550585330666, "Agent": "iqn"}, {"env_step": 1300000, "rew": 19.600000381469727, "rew_std": 0.39999961853027344, "Agent": "iqn"}, {"env_step": 1400000, "rew": 18.799999237060547, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 1500000, "rew": 17.0, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 1600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 1700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 1800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 1900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 2000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 2100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 2200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 2300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 
2400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 2500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 2600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 2700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 2800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 2900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 3000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 3100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 3200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 3300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 3400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 3500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 3600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 3700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 3800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 3900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 4000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 4100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 4200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 4300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 4400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 4500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 4600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 4700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 4800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 4900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 5000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 5100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 5200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 5300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 5400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 5500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 5600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 5700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 5800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 5900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 6000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 6100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 6200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 6300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 6400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 6500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 6600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 6700000, "rew": 
20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 6800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 6900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 7000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 7100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 7200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 7300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 7400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 7500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 7600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 7700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 7800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 7900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 8000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 8100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 8200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 8300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 8400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 8500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 8600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 8700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 8800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 8900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 9000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 9100000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 9200000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 9300000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 9400000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 9500000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 9600000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 9700000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 9800000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 9900000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 10000000, "rew": 20.700000762939453, "rew_std": 0.0, "Agent": "iqn"}, {"env_step": 0, "rew": -21.0, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 100000, "rew": -20.979999923706053, "rew_std": 0.04000015258789063, "Agent": "rainbow"}, {"env_step": 200000, "rew": -20.439999961853026, "rew_std": 0.611882543624649, "Agent": "rainbow"}, {"env_step": 300000, "rew": -20.05, "rew_std": 1.0984076053899456, "Agent": "rainbow"}, {"env_step": 400000, "rew": -18.579999923706055, "rew_std": 1.1417527594487265, "Agent": "rainbow"}, {"env_step": 500000, "rew": -16.669999980926512, "rew_std": 2.142918529337897, "Agent": "rainbow"}, {"env_step": 600000, "rew": -14.539999961853027, "rew_std": 3.4153184124021854, "Agent": "rainbow"}, {"env_step": 700000, "rew": -11.319999885559081, "rew_std": 2.876734170162213, "Agent": "rainbow"}, {"env_step": 800000, "rew": 
-10.470000064373016, "rew_std": 4.46520999148195, "Agent": "rainbow"}, {"env_step": 900000, "rew": -2.170000058412552, "rew_std": 4.194055360234164, "Agent": "rainbow"}, {"env_step": 1000000, "rew": -1.1700000524520875, "rew_std": 7.9131599288409395, "Agent": "rainbow"}, {"env_step": 1100000, "rew": 4.420000007003546, "rew_std": 8.925671038750298, "Agent": "rainbow"}, {"env_step": 1200000, "rew": 4.199999978972806, "rew_std": 7.79358685682886, "Agent": "rainbow"}, {"env_step": 1300000, "rew": 4.3666667805777655, "rew_std": 9.006787650672438, "Agent": "rainbow"}, {"env_step": 1400000, "rew": 8.224999904632568, "rew_std": 5.813507857159169, "Agent": "rainbow"}, {"env_step": 1500000, "rew": 10.48749989271164, "rew_std": 5.611915177475383, "Agent": "rainbow"}, {"env_step": 1600000, "rew": 10.325000084936619, "rew_std": 7.195441264466608, "Agent": "rainbow"}, {"env_step": 1700000, "rew": 5.216666638851166, "rew_std": 8.010496691447514, "Agent": "rainbow"}, {"env_step": 1800000, "rew": 7.8833333651224775, "rew_std": 8.73506665798113, "Agent": "rainbow"}, {"env_step": 1900000, "rew": 10.416666527589163, "rew_std": 7.064799091845216, "Agent": "rainbow"}, {"env_step": 2000000, "rew": 14.739999961853027, "rew_std": 3.559550394534507, "Agent": "rainbow"}, {"env_step": 2100000, "rew": 16.82000026702881, "rew_std": 2.790985522657408, "Agent": "rainbow"}, {"env_step": 2200000, "rew": 14.699999809265137, "rew_std": 3.199374665910315, "Agent": "rainbow"}, {"env_step": 2300000, "rew": 16.800000190734863, "rew_std": 1.987460835388907, "Agent": "rainbow"}, {"env_step": 2400000, "rew": 16.649999856948853, "rew_std": 3.9150350231246187, "Agent": "rainbow"}, {"env_step": 2500000, "rew": 18.700000127156574, "rew_std": 1.8384774512584698, "Agent": "rainbow"}, {"env_step": 2600000, "rew": 9.5, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 2700000, "rew": 16.100000381469727, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 2800000, "rew": 15.5, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 2900000, "rew": 15.600000381469727, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 3000000, "rew": -4.300000190734863, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 3100000, "rew": 17.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 3200000, "rew": 17.799999237060547, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 3300000, "rew": 19.100000381469727, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 3400000, "rew": 19.5, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 3500000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 3600000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 3700000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 3800000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 3900000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 4000000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 4100000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 4200000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 4300000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 4400000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 4500000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 4600000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 4700000, 
"rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 4800000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 4900000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 5000000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 5100000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 5200000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 5300000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 5400000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 5500000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 5600000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 5700000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 5800000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 5900000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 6000000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 6100000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 6200000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 6300000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 6400000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 6500000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 6600000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 6700000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 6800000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 6900000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 7000000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 7100000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 7200000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 7300000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 7400000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 7500000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 7600000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 7700000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 7800000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 7900000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 8000000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 8100000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 8200000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 8300000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 8400000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 8500000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 8600000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 8700000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 8800000, "rew": 20.200000762939453, 
"rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 8900000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 9000000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 9100000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 9200000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 9300000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 9400000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 9500000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 9600000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 9700000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 9800000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 9900000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 10000000, "rew": 20.200000762939453, "rew_std": 0.0, "Agent": "rainbow"}, {"env_step": 0, "rew": -20.75, "rew_std": 0.34132100802270626, "Agent": "ppo"}, {"env_step": 100000, "rew": -20.61000003814697, "rew_std": 0.32695557164785227, "Agent": "ppo"}, {"env_step": 200000, "rew": -19.98000011444092, "rew_std": 0.5793098790055098, "Agent": "ppo"}, {"env_step": 300000, "rew": -17.55, "rew_std": 2.648112510163754, "Agent": "ppo"}, {"env_step": 400000, "rew": -13.260000038146973, "rew_std": 4.553284472142394, "Agent": "ppo"}, {"env_step": 500000, "rew": -9.600000047683716, "rew_std": 5.175519361004068, "Agent": "ppo"}, {"env_step": 600000, "rew": -8.650000131130218, "rew_std": 5.414471552623239, "Agent": "ppo"}, {"env_step": 700000, "rew": -4.4400000154972075, "rew_std": 5.544041881025165, "Agent": "ppo"}, {"env_step": 800000, "rew": 0.6799999058246613, "rew_std": 7.877410761852243, "Agent": "ppo"}, {"env_step": 900000, "rew": 5.640000033378601, "rew_std": 6.771735266186935, "Agent": "ppo"}, {"env_step": 1000000, "rew": 5.6600001603364944, "rew_std": 7.132489235026172, "Agent": "ppo"}, {"env_step": 1100000, "rew": 6.7699999690055845, "rew_std": 6.8543488111854005, "Agent": "ppo"}, {"env_step": 1200000, "rew": 11.959999895095825, "rew_std": 4.759663803505452, "Agent": "ppo"}, {"env_step": 1300000, "rew": 13.499999952316283, "rew_std": 4.599999950243093, "Agent": "ppo"}, {"env_step": 1400000, "rew": 13.460000038146973, "rew_std": 4.538986609617991, "Agent": "ppo"}, {"env_step": 1500000, "rew": 13.359999942779542, "rew_std": 4.740295478528203, "Agent": "ppo"}, {"env_step": 1600000, "rew": 15.47999997138977, "rew_std": 3.309924420995019, "Agent": "ppo"}, {"env_step": 1700000, "rew": 13.88000020980835, "rew_std": 3.9776378782724717, "Agent": "ppo"}, {"env_step": 1800000, "rew": 16.680000019073486, "rew_std": 2.098475869175282, "Agent": "ppo"}, {"env_step": 1900000, "rew": 15.039999866485596, "rew_std": 3.2720634721834996, "Agent": "ppo"}, {"env_step": 2000000, "rew": 16.200000190734862, "rew_std": 2.0079841339730997, "Agent": "ppo"}, {"env_step": 2100000, "rew": 16.709999752044677, "rew_std": 2.592083999305904, "Agent": "ppo"}, {"env_step": 2200000, "rew": 17.93000011444092, "rew_std": 1.5020321173047337, "Agent": "ppo"}, {"env_step": 2300000, "rew": 16.13333299424913, "rew_std": 4.134945780702597, "Agent": "ppo"}, {"env_step": 2400000, "rew": 16.46666653951009, "rew_std": 2.9988884554980983, "Agent": "ppo"}, {"env_step": 2500000, "rew": 17.266666624281143, "rew_std": 1.7549928259554646, "Agent": "ppo"}, {"env_step": 
2600000, "rew": 17.63333363003201, "rew_std": 1.300427226972741, "Agent": "ppo"}, {"env_step": 2700000, "rew": 16.8111113442315, "rew_std": 2.6534861585374485, "Agent": "ppo"}, {"env_step": 2800000, "rew": 17.000000211927627, "rew_std": 2.82999802642146, "Agent": "ppo"}, {"env_step": 2900000, "rew": 16.97777779897054, "rew_std": 2.5494127247858547, "Agent": "ppo"}, {"env_step": 3000000, "rew": 17.81250011920929, "rew_std": 1.6593954848338575, "Agent": "ppo"}, {"env_step": 3100000, "rew": 17.06250011920929, "rew_std": 2.284697851779331, "Agent": "ppo"}, {"env_step": 3200000, "rew": 16.975000381469727, "rew_std": 2.1057958062253594, "Agent": "ppo"}, {"env_step": 3300000, "rew": 16.824999809265137, "rew_std": 2.9625790227338165, "Agent": "ppo"}, {"env_step": 3400000, "rew": 18.1875, "rew_std": 2.531519990815377, "Agent": "ppo"}, {"env_step": 3500000, "rew": 16.71666669845581, "rew_std": 2.412755415928967, "Agent": "ppo"}, {"env_step": 3600000, "rew": 16.46666669845581, "rew_std": 3.901566607031345, "Agent": "ppo"}, {"env_step": 3700000, "rew": 16.166666666666668, "rew_std": 2.739626889190835, "Agent": "ppo"}, {"env_step": 3800000, "rew": 17.300000190734863, "rew_std": 2.78926494928614, "Agent": "ppo"}, {"env_step": 3900000, "rew": 18.09999990463257, "rew_std": 2.077658824212887, "Agent": "ppo"}, {"env_step": 4000000, "rew": 17.019999504089355, "rew_std": 1.4483091764261082, "Agent": "ppo"}, {"env_step": 4100000, "rew": 18.620000076293945, "rew_std": 1.2253983415940415, "Agent": "ppo"}, {"env_step": 4200000, "rew": 18.35999984741211, "rew_std": 2.2526428387365733, "Agent": "ppo"}, {"env_step": 4300000, "rew": 19.0, "rew_std": 0.6519197285159748, "Agent": "ppo"}, {"env_step": 4400000, "rew": 18.975000381469727, "rew_std": 1.3141059206073868, "Agent": "ppo"}, {"env_step": 4500000, "rew": 19.625, "rew_std": 0.3897113582892007, "Agent": "ppo"}, {"env_step": 4600000, "rew": 19.566666920979817, "rew_std": 0.684754539003982, "Agent": "ppo"}, {"env_step": 4700000, "rew": 15.199999809265137, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 4800000, "rew": 17.200000762939453, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 4900000, "rew": 16.799999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 5000000, "rew": 18.700000762939453, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 5100000, "rew": 16.100000381469727, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 5200000, "rew": 17.700000762939453, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 5300000, "rew": 17.700000762939453, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 5400000, "rew": 17.600000381469727, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 5500000, "rew": 16.700000762939453, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 5600000, "rew": 19.399999618530273, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 5700000, "rew": 19.100000381469727, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 5800000, "rew": 18.5, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 5900000, "rew": 18.700000762939453, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 6000000, "rew": 19.600000381469727, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 6100000, "rew": 19.0, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 6200000, "rew": 19.0, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 6300000, "rew": 19.100000381469727, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 6400000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 6500000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 6600000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": 
"ppo"}, {"env_step": 6700000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 6800000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 6900000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 7000000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 7100000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 7200000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 7300000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 7400000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 7500000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 7600000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 7700000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 7800000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 7900000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 8000000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 8100000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 8200000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 8300000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 8400000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 8500000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 8600000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 8700000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 8800000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 8900000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 9000000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 9100000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 9200000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 9300000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 9400000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 9500000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 9600000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 9700000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 9800000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 9900000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}, {"env_step": 10000000, "rew": 20.299999237060547, "rew_std": 0.0, "Agent": "ppo"}] \ No newline at end of file diff --git a/examples/atari/benchmark/QbertNoFrameskip-v4/result.json b/examples/atari/benchmark/QbertNoFrameskip-v4/result.json new file mode 100644 index 0000000000000000000000000000000000000000..f4add1d0946e7846ecbdda90ebf90c0a06ec6d5e --- /dev/null +++ b/examples/atari/benchmark/QbertNoFrameskip-v4/result.json @@ -0,0 +1 @@ +[{"env_step": 0, "rew": 95.0, "rew_std": 75.92759709091287, "Agent": "c51"}, {"env_step": 100000, "rew": 251.5, "rew_std": 82.12186067059123, "Agent": "c51"}, {"env_step": 200000, "rew": 255.5, "rew_std": 117.85478352616833, "Agent": "c51"}, {"env_step": 300000, "rew": 320.0, "rew_std": 42.175229697062704, "Agent": "c51"}, {"env_step": 400000, "rew": 332.5, "rew_std": 83.60621986431393, "Agent": 
"c51"}, {"env_step": 500000, "rew": 430.25, "rew_std": 135.99839153460601, "Agent": "c51"}, {"env_step": 600000, "rew": 523.75, "rew_std": 108.9681719586045, "Agent": "c51"}, {"env_step": 700000, "rew": 1121.0, "rew_std": 516.2099863427673, "Agent": "c51"}, {"env_step": 800000, "rew": 1754.75, "rew_std": 1004.4398003364861, "Agent": "c51"}, {"env_step": 900000, "rew": 2517.0, "rew_std": 1108.9954914245593, "Agent": "c51"}, {"env_step": 1000000, "rew": 2285.0, "rew_std": 1031.815753901829, "Agent": "c51"}, {"env_step": 1100000, "rew": 2877.75, "rew_std": 1017.4357780715203, "Agent": "c51"}, {"env_step": 1200000, "rew": 3304.5, "rew_std": 804.6146593245738, "Agent": "c51"}, {"env_step": 1300000, "rew": 3511.25, "rew_std": 1172.84443235239, "Agent": "c51"}, {"env_step": 1400000, "rew": 3495.0, "rew_std": 713.4371030441296, "Agent": "c51"}, {"env_step": 1500000, "rew": 3199.0, "rew_std": 1272.923897960911, "Agent": "c51"}, {"env_step": 1600000, "rew": 3992.0, "rew_std": 1011.4944389367645, "Agent": "c51"}, {"env_step": 1700000, "rew": 4453.75, "rew_std": 1278.8536126156114, "Agent": "c51"}, {"env_step": 1800000, "rew": 3931.5, "rew_std": 1001.7822867270114, "Agent": "c51"}, {"env_step": 1900000, "rew": 4928.25, "rew_std": 1331.802842953866, "Agent": "c51"}, {"env_step": 2000000, "rew": 4457.0, "rew_std": 1296.9019816470325, "Agent": "c51"}, {"env_step": 2100000, "rew": 5236.75, "rew_std": 1800.0354336790151, "Agent": "c51"}, {"env_step": 2200000, "rew": 4757.25, "rew_std": 1431.1256277839482, "Agent": "c51"}, {"env_step": 2300000, "rew": 4738.25, "rew_std": 1369.4538373015719, "Agent": "c51"}, {"env_step": 2400000, "rew": 6592.0, "rew_std": 1420.6903251588644, "Agent": "c51"}, {"env_step": 2500000, "rew": 5894.25, "rew_std": 1735.2507203571477, "Agent": "c51"}, {"env_step": 2600000, "rew": 7282.25, "rew_std": 2613.5424833929906, "Agent": "c51"}, {"env_step": 2700000, "rew": 7078.25, "rew_std": 2062.308188050467, "Agent": "c51"}, {"env_step": 2800000, "rew": 6066.25, "rew_std": 1512.1017037554054, "Agent": "c51"}, {"env_step": 2900000, "rew": 7065.5, "rew_std": 2015.3544353289324, "Agent": "c51"}, {"env_step": 3000000, "rew": 6861.0, "rew_std": 1669.891538394036, "Agent": "c51"}, {"env_step": 3100000, "rew": 7762.75, "rew_std": 2067.515553145852, "Agent": "c51"}, {"env_step": 3200000, "rew": 7553.5, "rew_std": 2434.32644688423, "Agent": "c51"}, {"env_step": 3300000, "rew": 6468.25, "rew_std": 1466.2916873869265, "Agent": "c51"}, {"env_step": 3400000, "rew": 7396.25, "rew_std": 2111.3411762431956, "Agent": "c51"}, {"env_step": 3500000, "rew": 7398.75, "rew_std": 2466.653413534216, "Agent": "c51"}, {"env_step": 3600000, "rew": 7548.75, "rew_std": 2775.422546658436, "Agent": "c51"}, {"env_step": 3700000, "rew": 8335.5, "rew_std": 2109.992239322221, "Agent": "c51"}, {"env_step": 3800000, "rew": 6925.0, "rew_std": 1951.7191140120547, "Agent": "c51"}, {"env_step": 3900000, "rew": 7580.5, "rew_std": 2267.1120947143304, "Agent": "c51"}, {"env_step": 4000000, "rew": 8586.75, "rew_std": 2490.3042108344916, "Agent": "c51"}, {"env_step": 4100000, "rew": 8712.75, "rew_std": 2641.0264221510547, "Agent": "c51"}, {"env_step": 4200000, "rew": 9052.75, "rew_std": 1451.2083112013934, "Agent": "c51"}, {"env_step": 4300000, "rew": 7919.75, "rew_std": 1486.0133116833106, "Agent": "c51"}, {"env_step": 4400000, "rew": 9568.25, "rew_std": 2890.1516071825713, "Agent": "c51"}, {"env_step": 4500000, "rew": 8489.0, "rew_std": 1878.7070687044322, "Agent": "c51"}, {"env_step": 4600000, "rew": 8453.75, "rew_std": 
2539.73577611924, "Agent": "c51"}, {"env_step": 4700000, "rew": 8407.0, "rew_std": 2617.2267765709566, "Agent": "c51"}, {"env_step": 4800000, "rew": 8893.25, "rew_std": 2978.127568204559, "Agent": "c51"}, {"env_step": 4900000, "rew": 10263.75, "rew_std": 2290.572322040935, "Agent": "c51"}, {"env_step": 5000000, "rew": 8514.5, "rew_std": 1787.0897431298743, "Agent": "c51"}, {"env_step": 5100000, "rew": 8638.75, "rew_std": 2969.4349736102995, "Agent": "c51"}, {"env_step": 5200000, "rew": 10585.75, "rew_std": 3481.069708078251, "Agent": "c51"}, {"env_step": 5300000, "rew": 9607.5, "rew_std": 2606.770032051159, "Agent": "c51"}, {"env_step": 5400000, "rew": 9306.5, "rew_std": 2684.2033734424817, "Agent": "c51"}, {"env_step": 5500000, "rew": 9660.75, "rew_std": 2237.9474552589477, "Agent": "c51"}, {"env_step": 5600000, "rew": 9766.25, "rew_std": 2911.0542012302003, "Agent": "c51"}, {"env_step": 5700000, "rew": 10415.5, "rew_std": 1625.4448621838885, "Agent": "c51"}, {"env_step": 5800000, "rew": 9485.5, "rew_std": 3670.640407340387, "Agent": "c51"}, {"env_step": 5900000, "rew": 10269.0, "rew_std": 2380.1436931412354, "Agent": "c51"}, {"env_step": 6000000, "rew": 10933.5, "rew_std": 2768.2542332668795, "Agent": "c51"}, {"env_step": 6100000, "rew": 10309.25, "rew_std": 2190.8112224698866, "Agent": "c51"}, {"env_step": 6200000, "rew": 10257.0, "rew_std": 3413.4135773445328, "Agent": "c51"}, {"env_step": 6300000, "rew": 9958.0, "rew_std": 2849.388925717232, "Agent": "c51"}, {"env_step": 6400000, "rew": 11790.0, "rew_std": 1323.6403401226482, "Agent": "c51"}, {"env_step": 6500000, "rew": 10310.75, "rew_std": 2311.68581613073, "Agent": "c51"}, {"env_step": 6600000, "rew": 9120.75, "rew_std": 2925.0596254606503, "Agent": "c51"}, {"env_step": 6700000, "rew": 10305.5, "rew_std": 2839.6374768621436, "Agent": "c51"}, {"env_step": 6800000, "rew": 10348.75, "rew_std": 3006.7963121069574, "Agent": "c51"}, {"env_step": 6900000, "rew": 10654.25, "rew_std": 1407.1265268269233, "Agent": "c51"}, {"env_step": 7000000, "rew": 11493.75, "rew_std": 1194.8479244238574, "Agent": "c51"}, {"env_step": 7100000, "rew": 11250.5, "rew_std": 1915.5099843122719, "Agent": "c51"}, {"env_step": 7200000, "rew": 10615.75, "rew_std": 2852.864141963301, "Agent": "c51"}, {"env_step": 7300000, "rew": 10428.75, "rew_std": 1486.1473892249046, "Agent": "c51"}, {"env_step": 7400000, "rew": 11293.0, "rew_std": 2100.1969550496924, "Agent": "c51"}, {"env_step": 7500000, "rew": 10405.0, "rew_std": 2845.91066268778, "Agent": "c51"}, {"env_step": 7600000, "rew": 11912.75, "rew_std": 1889.1385106709354, "Agent": "c51"}, {"env_step": 7700000, "rew": 10792.75, "rew_std": 2319.9715650197095, "Agent": "c51"}, {"env_step": 7800000, "rew": 11481.75, "rew_std": 2059.718442530435, "Agent": "c51"}, {"env_step": 7900000, "rew": 11188.0, "rew_std": 1572.3460973971348, "Agent": "c51"}, {"env_step": 8000000, "rew": 11333.25, "rew_std": 2443.5376634093445, "Agent": "c51"}, {"env_step": 8100000, "rew": 11388.75, "rew_std": 1806.7637677626813, "Agent": "c51"}, {"env_step": 8200000, "rew": 11084.25, "rew_std": 2011.5637729139983, "Agent": "c51"}, {"env_step": 8300000, "rew": 11189.25, "rew_std": 1837.155767075835, "Agent": "c51"}, {"env_step": 8400000, "rew": 12201.5, "rew_std": 1443.038547648676, "Agent": "c51"}, {"env_step": 8500000, "rew": 12172.0, "rew_std": 2153.40886735427, "Agent": "c51"}, {"env_step": 8600000, "rew": 10667.0, "rew_std": 2920.304093754621, "Agent": "c51"}, {"env_step": 8700000, "rew": 12087.25, "rew_std": 1455.5503469478479, "Agent": "c51"}, 
{"env_step": 8800000, "rew": 11311.0, "rew_std": 2612.836868233453, "Agent": "c51"}, {"env_step": 8900000, "rew": 12494.75, "rew_std": 2119.100767424711, "Agent": "c51"}, {"env_step": 9000000, "rew": 12513.25, "rew_std": 1274.6416408151745, "Agent": "c51"}, {"env_step": 9100000, "rew": 12241.0, "rew_std": 1972.8945106112492, "Agent": "c51"}, {"env_step": 9200000, "rew": 10962.25, "rew_std": 1657.8398784261403, "Agent": "c51"}, {"env_step": 9300000, "rew": 11570.25, "rew_std": 2591.904813549294, "Agent": "c51"}, {"env_step": 9400000, "rew": 11239.25, "rew_std": 2040.6086867648094, "Agent": "c51"}, {"env_step": 9500000, "rew": 11834.25, "rew_std": 1834.925083620582, "Agent": "c51"}, {"env_step": 9600000, "rew": 11510.5, "rew_std": 1754.8346788230508, "Agent": "c51"}, {"env_step": 9700000, "rew": 10276.75, "rew_std": 2304.5601668214263, "Agent": "c51"}, {"env_step": 9800000, "rew": 12446.75, "rew_std": 1572.9002074194027, "Agent": "c51"}, {"env_step": 9900000, "rew": 10765.0, "rew_std": 2277.32930205537, "Agent": "c51"}, {"env_step": 10000000, "rew": 11854.5, "rew_std": 2126.8074078298673, "Agent": "c51"}, {"env_step": 0, "rew": 79.5, "rew_std": 76.44278906476399, "Agent": "dqn"}, {"env_step": 100000, "rew": 306.5, "rew_std": 140.31749712705113, "Agent": "dqn"}, {"env_step": 200000, "rew": 409.5, "rew_std": 96.2925230742242, "Agent": "dqn"}, {"env_step": 300000, "rew": 537.25, "rew_std": 147.0180686174322, "Agent": "dqn"}, {"env_step": 400000, "rew": 534.25, "rew_std": 124.05165254844451, "Agent": "dqn"}, {"env_step": 500000, "rew": 725.25, "rew_std": 251.90883767744236, "Agent": "dqn"}, {"env_step": 600000, "rew": 669.5, "rew_std": 160.39326669159152, "Agent": "dqn"}, {"env_step": 700000, "rew": 958.5, "rew_std": 439.7985334218385, "Agent": "dqn"}, {"env_step": 800000, "rew": 818.5, "rew_std": 111.63668751803773, "Agent": "dqn"}, {"env_step": 900000, "rew": 778.75, "rew_std": 199.5408792703891, "Agent": "dqn"}, {"env_step": 1000000, "rew": 850.0, "rew_std": 283.47618947629445, "Agent": "dqn"}, {"env_step": 1100000, "rew": 1346.0, "rew_std": 645.7797612189469, "Agent": "dqn"}, {"env_step": 1200000, "rew": 1157.5, "rew_std": 768.8619837135922, "Agent": "dqn"}, {"env_step": 1300000, "rew": 1414.5, "rew_std": 999.8636156996613, "Agent": "dqn"}, {"env_step": 1400000, "rew": 1861.25, "rew_std": 1166.1422779832656, "Agent": "dqn"}, {"env_step": 1500000, "rew": 2099.75, "rew_std": 986.7018609995625, "Agent": "dqn"}, {"env_step": 1600000, "rew": 2019.0, "rew_std": 728.7679671884598, "Agent": "dqn"}, {"env_step": 1700000, "rew": 3189.0, "rew_std": 1119.5803901462368, "Agent": "dqn"}, {"env_step": 1800000, "rew": 3215.5, "rew_std": 1019.3391241387726, "Agent": "dqn"}, {"env_step": 1900000, "rew": 4062.5, "rew_std": 644.8352502771542, "Agent": "dqn"}, {"env_step": 2000000, "rew": 3697.75, "rew_std": 775.0285881823974, "Agent": "dqn"}, {"env_step": 2100000, "rew": 4084.75, "rew_std": 369.5460898183067, "Agent": "dqn"}, {"env_step": 2200000, "rew": 4364.5, "rew_std": 82.35441700358275, "Agent": "dqn"}, {"env_step": 2300000, "rew": 3960.5, "rew_std": 493.58357954859076, "Agent": "dqn"}, {"env_step": 2400000, "rew": 4298.5, "rew_std": 337.0908631215032, "Agent": "dqn"}, {"env_step": 2500000, "rew": 3868.5, "rew_std": 810.0564795123856, "Agent": "dqn"}, {"env_step": 2600000, "rew": 3593.0, "rew_std": 1069.2274079913964, "Agent": "dqn"}, {"env_step": 2700000, "rew": 3861.5, "rew_std": 863.5603626846244, "Agent": "dqn"}, {"env_step": 2800000, "rew": 4479.75, "rew_std": 226.15108334916283, "Agent": "dqn"}, 
{"env_step": 2900000, "rew": 4399.25, "rew_std": 278.67106505699513, "Agent": "dqn"}, {"env_step": 3000000, "rew": 4731.0, "rew_std": 975.6428649869787, "Agent": "dqn"}, {"env_step": 3100000, "rew": 4451.0, "rew_std": 1066.7041529871344, "Agent": "dqn"}, {"env_step": 3200000, "rew": 4260.0, "rew_std": 1112.3870729202133, "Agent": "dqn"}, {"env_step": 3300000, "rew": 4400.75, "rew_std": 758.1804287239285, "Agent": "dqn"}, {"env_step": 3400000, "rew": 4580.5, "rew_std": 901.3668786903588, "Agent": "dqn"}, {"env_step": 3500000, "rew": 4537.0, "rew_std": 1127.5176273566635, "Agent": "dqn"}, {"env_step": 3600000, "rew": 5060.75, "rew_std": 1816.7983686969778, "Agent": "dqn"}, {"env_step": 3700000, "rew": 5504.0, "rew_std": 1962.111808740776, "Agent": "dqn"}, {"env_step": 3800000, "rew": 5938.25, "rew_std": 1861.7293875587827, "Agent": "dqn"}, {"env_step": 3900000, "rew": 5781.75, "rew_std": 1370.2176150159507, "Agent": "dqn"}, {"env_step": 4000000, "rew": 5990.25, "rew_std": 3394.9163189245182, "Agent": "dqn"}, {"env_step": 4100000, "rew": 6092.75, "rew_std": 2065.6846473990167, "Agent": "dqn"}, {"env_step": 4200000, "rew": 6176.0, "rew_std": 1842.3508080710362, "Agent": "dqn"}, {"env_step": 4300000, "rew": 6576.5, "rew_std": 2726.7487966440913, "Agent": "dqn"}, {"env_step": 4400000, "rew": 6971.25, "rew_std": 3082.8676281183402, "Agent": "dqn"}, {"env_step": 4500000, "rew": 6908.25, "rew_std": 2762.2427595162594, "Agent": "dqn"}, {"env_step": 4600000, "rew": 7546.0, "rew_std": 2864.2300885229174, "Agent": "dqn"}, {"env_step": 4700000, "rew": 7737.75, "rew_std": 3928.65680767613, "Agent": "dqn"}, {"env_step": 4800000, "rew": 8261.75, "rew_std": 3556.5873829416873, "Agent": "dqn"}, {"env_step": 4900000, "rew": 8120.5, "rew_std": 2792.5308413695275, "Agent": "dqn"}, {"env_step": 5000000, "rew": 7459.25, "rew_std": 3016.322481516192, "Agent": "dqn"}, {"env_step": 5100000, "rew": 8186.25, "rew_std": 3262.4464076058016, "Agent": "dqn"}, {"env_step": 5200000, "rew": 8457.75, "rew_std": 3065.806062441002, "Agent": "dqn"}, {"env_step": 5300000, "rew": 7461.25, "rew_std": 2633.543062586978, "Agent": "dqn"}, {"env_step": 5400000, "rew": 8212.25, "rew_std": 2857.8948655435174, "Agent": "dqn"}, {"env_step": 5500000, "rew": 8331.0, "rew_std": 2962.497088606164, "Agent": "dqn"}, {"env_step": 5600000, "rew": 8116.0, "rew_std": 3106.8304186099376, "Agent": "dqn"}, {"env_step": 5700000, "rew": 8354.0, "rew_std": 2939.679446810485, "Agent": "dqn"}, {"env_step": 5800000, "rew": 8698.25, "rew_std": 2624.4728161099324, "Agent": "dqn"}, {"env_step": 5900000, "rew": 9697.25, "rew_std": 2572.896337301602, "Agent": "dqn"}, {"env_step": 6000000, "rew": 8455.0, "rew_std": 1774.5978417658464, "Agent": "dqn"}, {"env_step": 6100000, "rew": 9885.75, "rew_std": 3028.3760190075473, "Agent": "dqn"}, {"env_step": 6200000, "rew": 8983.5, "rew_std": 2107.2515274641514, "Agent": "dqn"}, {"env_step": 6300000, "rew": 9419.75, "rew_std": 2727.142838668338, "Agent": "dqn"}, {"env_step": 6400000, "rew": 8409.25, "rew_std": 3007.3811385489535, "Agent": "dqn"}, {"env_step": 6500000, "rew": 9823.75, "rew_std": 2742.98269453163, "Agent": "dqn"}, {"env_step": 6600000, "rew": 9702.25, "rew_std": 2529.285336315379, "Agent": "dqn"}, {"env_step": 6700000, "rew": 10412.5, "rew_std": 2968.082925054487, "Agent": "dqn"}, {"env_step": 6800000, "rew": 9085.25, "rew_std": 2521.6067422379724, "Agent": "dqn"}, {"env_step": 6900000, "rew": 9624.25, "rew_std": 2870.277654252285, "Agent": "dqn"}, {"env_step": 7000000, "rew": 10178.25, "rew_std": 
2328.1741907554942, "Agent": "dqn"}, {"env_step": 7100000, "rew": 9411.75, "rew_std": 3466.6296762850225, "Agent": "dqn"}, {"env_step": 7200000, "rew": 10059.0, "rew_std": 2418.3835510522313, "Agent": "dqn"}, {"env_step": 7300000, "rew": 9972.25, "rew_std": 3165.8356815381308, "Agent": "dqn"}, {"env_step": 7400000, "rew": 9769.25, "rew_std": 3534.1402861940837, "Agent": "dqn"}, {"env_step": 7500000, "rew": 9630.75, "rew_std": 3561.6785105480812, "Agent": "dqn"}, {"env_step": 7600000, "rew": 10130.5, "rew_std": 2504.094846446516, "Agent": "dqn"}, {"env_step": 7700000, "rew": 9689.75, "rew_std": 2412.3324941848296, "Agent": "dqn"}, {"env_step": 7800000, "rew": 9682.5, "rew_std": 2696.419848614084, "Agent": "dqn"}, {"env_step": 7900000, "rew": 8600.25, "rew_std": 4069.30498519096, "Agent": "dqn"}, {"env_step": 8000000, "rew": 10808.25, "rew_std": 1838.3657994262187, "Agent": "dqn"}, {"env_step": 8100000, "rew": 10105.5, "rew_std": 3078.21819402069, "Agent": "dqn"}, {"env_step": 8200000, "rew": 9794.25, "rew_std": 3020.5171432223324, "Agent": "dqn"}, {"env_step": 8300000, "rew": 10248.5, "rew_std": 2272.298450908243, "Agent": "dqn"}, {"env_step": 8400000, "rew": 9916.5, "rew_std": 3159.7433044473723, "Agent": "dqn"}, {"env_step": 8500000, "rew": 10325.5, "rew_std": 2780.830316650047, "Agent": "dqn"}, {"env_step": 8600000, "rew": 10778.0, "rew_std": 1940.7523669958514, "Agent": "dqn"}, {"env_step": 8700000, "rew": 10993.0, "rew_std": 2580.0946688057784, "Agent": "dqn"}, {"env_step": 8800000, "rew": 10329.75, "rew_std": 2510.3706026202585, "Agent": "dqn"}, {"env_step": 8900000, "rew": 9983.0, "rew_std": 3615.9431342320636, "Agent": "dqn"}, {"env_step": 9000000, "rew": 11148.0, "rew_std": 1932.5183698997535, "Agent": "dqn"}, {"env_step": 9100000, "rew": 10034.75, "rew_std": 2345.046494741629, "Agent": "dqn"}, {"env_step": 9200000, "rew": 10810.75, "rew_std": 2402.0418527785896, "Agent": "dqn"}, {"env_step": 9300000, "rew": 10502.5, "rew_std": 2058.038811587381, "Agent": "dqn"}, {"env_step": 9400000, "rew": 10956.0, "rew_std": 1991.7147762669233, "Agent": "dqn"}, {"env_step": 9500000, "rew": 11620.25, "rew_std": 786.060947064539, "Agent": "dqn"}, {"env_step": 9600000, "rew": 10733.5, "rew_std": 2011.6753589980665, "Agent": "dqn"}, {"env_step": 9700000, "rew": 11486.25, "rew_std": 2341.8905957580514, "Agent": "dqn"}, {"env_step": 9800000, "rew": 11012.5, "rew_std": 2049.413025722243, "Agent": "dqn"}, {"env_step": 9900000, "rew": 10990.5, "rew_std": 1687.970601047305, "Agent": "dqn"}, {"env_step": 10000000, "rew": 11396.5, "rew_std": 1123.2326117060527, "Agent": "dqn"}, {"env_step": 0, "rew": 62.25, "rew_std": 64.61859252568102, "Agent": "fqf"}, {"env_step": 100000, "rew": 282.5, "rew_std": 133.41195598596101, "Agent": "fqf"}, {"env_step": 200000, "rew": 334.25, "rew_std": 97.66684442532174, "Agent": "fqf"}, {"env_step": 300000, "rew": 478.0, "rew_std": 103.5, "Agent": "fqf"}, {"env_step": 400000, "rew": 497.75, "rew_std": 127.49730389306278, "Agent": "fqf"}, {"env_step": 500000, "rew": 761.75, "rew_std": 323.0790344482291, "Agent": "fqf"}, {"env_step": 600000, "rew": 723.25, "rew_std": 85.77623505377233, "Agent": "fqf"}, {"env_step": 700000, "rew": 1184.75, "rew_std": 753.0441969101149, "Agent": "fqf"}, {"env_step": 800000, "rew": 1227.25, "rew_std": 684.0965301037567, "Agent": "fqf"}, {"env_step": 900000, "rew": 1899.75, "rew_std": 957.4160864013096, "Agent": "fqf"}, {"env_step": 1000000, "rew": 1912.5, "rew_std": 1270.665180132044, "Agent": "fqf"}, {"env_step": 1100000, "rew": 2567.5, "rew_std": 
1188.7546004117082, "Agent": "fqf"}, {"env_step": 1200000, "rew": 3371.0, "rew_std": 1017.2175283586103, "Agent": "fqf"}, {"env_step": 1300000, "rew": 3156.25, "rew_std": 890.8782534667686, "Agent": "fqf"}, {"env_step": 1400000, "rew": 3885.0, "rew_std": 888.4551198569346, "Agent": "fqf"}, {"env_step": 1500000, "rew": 3952.75, "rew_std": 590.0110698114062, "Agent": "fqf"}, {"env_step": 1600000, "rew": 3700.0, "rew_std": 1213.4516883667022, "Agent": "fqf"}, {"env_step": 1700000, "rew": 4309.75, "rew_std": 1129.6019486969735, "Agent": "fqf"}, {"env_step": 1800000, "rew": 4612.75, "rew_std": 1088.452714866383, "Agent": "fqf"}, {"env_step": 1900000, "rew": 5602.25, "rew_std": 1122.1271374046703, "Agent": "fqf"}, {"env_step": 2000000, "rew": 6148.5, "rew_std": 2185.0435350354005, "Agent": "fqf"}, {"env_step": 2100000, "rew": 6673.75, "rew_std": 1807.6529153850304, "Agent": "fqf"}, {"env_step": 2200000, "rew": 6371.75, "rew_std": 2170.01153510759, "Agent": "fqf"}, {"env_step": 2300000, "rew": 6601.0, "rew_std": 2183.5877014674725, "Agent": "fqf"}, {"env_step": 2400000, "rew": 7732.0, "rew_std": 1939.5839760113508, "Agent": "fqf"}, {"env_step": 2500000, "rew": 8078.25, "rew_std": 2086.995762453772, "Agent": "fqf"}, {"env_step": 2600000, "rew": 9642.5, "rew_std": 2714.1858816227013, "Agent": "fqf"}, {"env_step": 2700000, "rew": 10048.5, "rew_std": 2313.9531110201865, "Agent": "fqf"}, {"env_step": 2800000, "rew": 9025.75, "rew_std": 3670.8754150066165, "Agent": "fqf"}, {"env_step": 2900000, "rew": 9993.5, "rew_std": 3190.4126300527337, "Agent": "fqf"}, {"env_step": 3000000, "rew": 10725.75, "rew_std": 1486.178846067996, "Agent": "fqf"}, {"env_step": 3100000, "rew": 12443.0, "rew_std": 1860.8062096843937, "Agent": "fqf"}, {"env_step": 3200000, "rew": 11651.5, "rew_std": 1916.6462245286687, "Agent": "fqf"}, {"env_step": 3300000, "rew": 11780.25, "rew_std": 2378.499645259591, "Agent": "fqf"}, {"env_step": 3400000, "rew": 12591.25, "rew_std": 1730.6852869600527, "Agent": "fqf"}, {"env_step": 3500000, "rew": 13177.25, "rew_std": 1040.0303180676995, "Agent": "fqf"}, {"env_step": 3600000, "rew": 12289.75, "rew_std": 3415.4978498748906, "Agent": "fqf"}, {"env_step": 3700000, "rew": 12660.0, "rew_std": 1981.2193215290426, "Agent": "fqf"}, {"env_step": 3800000, "rew": 12749.0, "rew_std": 2114.099453668157, "Agent": "fqf"}, {"env_step": 3900000, "rew": 13807.25, "rew_std": 1109.9293051811903, "Agent": "fqf"}, {"env_step": 4000000, "rew": 14015.25, "rew_std": 1171.8481396921702, "Agent": "fqf"}, {"env_step": 4100000, "rew": 13752.25, "rew_std": 1630.1025466209173, "Agent": "fqf"}, {"env_step": 4200000, "rew": 14020.5, "rew_std": 1309.8782386160938, "Agent": "fqf"}, {"env_step": 4300000, "rew": 13418.75, "rew_std": 1649.8266007371806, "Agent": "fqf"}, {"env_step": 4400000, "rew": 14221.5, "rew_std": 1284.3087634988715, "Agent": "fqf"}, {"env_step": 4500000, "rew": 14305.75, "rew_std": 859.7587234218679, "Agent": "fqf"}, {"env_step": 4600000, "rew": 14158.0, "rew_std": 1344.8414404679831, "Agent": "fqf"}, {"env_step": 4700000, "rew": 12771.5, "rew_std": 1663.6489263062685, "Agent": "fqf"}, {"env_step": 4800000, "rew": 14314.0, "rew_std": 1097.285970018755, "Agent": "fqf"}, {"env_step": 4900000, "rew": 14935.25, "rew_std": 337.25074128902963, "Agent": "fqf"}, {"env_step": 5000000, "rew": 14672.0, "rew_std": 807.711117912834, "Agent": "fqf"}, {"env_step": 5100000, "rew": 14673.0, "rew_std": 571.9405563517943, "Agent": "fqf"}, {"env_step": 5200000, "rew": 14309.75, "rew_std": 1108.4434187183394, "Agent": "fqf"}, 
{"env_step": 5300000, "rew": 14757.25, "rew_std": 947.0417427442151, "Agent": "fqf"}, {"env_step": 5400000, "rew": 14685.0, "rew_std": 655.1602857316674, "Agent": "fqf"}, {"env_step": 5500000, "rew": 14524.25, "rew_std": 979.248468214273, "Agent": "fqf"}, {"env_step": 5600000, "rew": 14862.5, "rew_std": 499.7686964986903, "Agent": "fqf"}, {"env_step": 5700000, "rew": 14338.0, "rew_std": 1270.7752555035056, "Agent": "fqf"}, {"env_step": 5800000, "rew": 14777.75, "rew_std": 538.4253081904676, "Agent": "fqf"}, {"env_step": 5900000, "rew": 14932.0, "rew_std": 720.0848908288522, "Agent": "fqf"}, {"env_step": 6000000, "rew": 15026.25, "rew_std": 556.9619039934419, "Agent": "fqf"}, {"env_step": 6100000, "rew": 15113.75, "rew_std": 255.20151351432068, "Agent": "fqf"}, {"env_step": 6200000, "rew": 14408.5, "rew_std": 1393.7912325739462, "Agent": "fqf"}, {"env_step": 6300000, "rew": 15156.5, "rew_std": 590.047879413188, "Agent": "fqf"}, {"env_step": 6400000, "rew": 14545.5, "rew_std": 1392.182100157878, "Agent": "fqf"}, {"env_step": 6500000, "rew": 14554.75, "rew_std": 1060.3109508535692, "Agent": "fqf"}, {"env_step": 6600000, "rew": 13926.25, "rew_std": 1543.7536437203962, "Agent": "fqf"}, {"env_step": 6700000, "rew": 14911.25, "rew_std": 508.5976430342555, "Agent": "fqf"}, {"env_step": 6800000, "rew": 14964.0, "rew_std": 1249.9880999433556, "Agent": "fqf"}, {"env_step": 6900000, "rew": 15271.75, "rew_std": 499.26827708157066, "Agent": "fqf"}, {"env_step": 7000000, "rew": 14915.25, "rew_std": 710.6022533738548, "Agent": "fqf"}, {"env_step": 7100000, "rew": 14988.5, "rew_std": 568.0396112948463, "Agent": "fqf"}, {"env_step": 7200000, "rew": 14881.25, "rew_std": 963.4282861220133, "Agent": "fqf"}, {"env_step": 7300000, "rew": 15227.75, "rew_std": 746.1756244343553, "Agent": "fqf"}, {"env_step": 7400000, "rew": 15052.0, "rew_std": 1012.3807337163228, "Agent": "fqf"}, {"env_step": 7500000, "rew": 15262.75, "rew_std": 626.2052878250071, "Agent": "fqf"}, {"env_step": 7600000, "rew": 14771.75, "rew_std": 516.1831675868557, "Agent": "fqf"}, {"env_step": 7700000, "rew": 14902.25, "rew_std": 1191.0822022429854, "Agent": "fqf"}, {"env_step": 7800000, "rew": 15195.0, "rew_std": 983.0596370515881, "Agent": "fqf"}, {"env_step": 7900000, "rew": 15172.75, "rew_std": 897.3812247311619, "Agent": "fqf"}, {"env_step": 8000000, "rew": 14729.5, "rew_std": 1125.9345007592583, "Agent": "fqf"}, {"env_step": 8100000, "rew": 14950.75, "rew_std": 407.5706227146407, "Agent": "fqf"}, {"env_step": 8200000, "rew": 14679.25, "rew_std": 1469.804004791115, "Agent": "fqf"}, {"env_step": 8300000, "rew": 14879.75, "rew_std": 1249.1259193932372, "Agent": "fqf"}, {"env_step": 8400000, "rew": 14759.25, "rew_std": 824.2845761628687, "Agent": "fqf"}, {"env_step": 8500000, "rew": 14181.25, "rew_std": 1934.2803086678, "Agent": "fqf"}, {"env_step": 8600000, "rew": 15150.75, "rew_std": 606.5559022052296, "Agent": "fqf"}, {"env_step": 8700000, "rew": 15301.25, "rew_std": 684.131977399098, "Agent": "fqf"}, {"env_step": 8800000, "rew": 15258.75, "rew_std": 178.02826320559328, "Agent": "fqf"}, {"env_step": 8900000, "rew": 14306.75, "rew_std": 2652.5966169962594, "Agent": "fqf"}, {"env_step": 9000000, "rew": 14469.5, "rew_std": 1781.5501676910476, "Agent": "fqf"}, {"env_step": 9100000, "rew": 14648.25, "rew_std": 983.8413553515628, "Agent": "fqf"}, {"env_step": 9200000, "rew": 15119.25, "rew_std": 669.5624037384417, "Agent": "fqf"}, {"env_step": 9300000, "rew": 14687.75, "rew_std": 914.5568940749395, "Agent": "fqf"}, {"env_step": 9400000, "rew": 
14220.0, "rew_std": 3311.433790671346, "Agent": "fqf"}, {"env_step": 9500000, "rew": 15234.75, "rew_std": 382.4288332487497, "Agent": "fqf"}, {"env_step": 9600000, "rew": 14718.75, "rew_std": 632.6375838503432, "Agent": "fqf"}, {"env_step": 9700000, "rew": 14343.5, "rew_std": 1404.7336046382602, "Agent": "fqf"}, {"env_step": 9800000, "rew": 15267.5, "rew_std": 387.3209263646879, "Agent": "fqf"}, {"env_step": 9900000, "rew": 15137.75, "rew_std": 331.75, "Agent": "fqf"}, {"env_step": 10000000, "rew": 14602.75, "rew_std": 1270.1847552620052, "Agent": "fqf"}, {"env_step": 0, "rew": 63.5, "rew_std": 62.13091018164791, "Agent": "qrdqn"}, {"env_step": 100000, "rew": 270.75, "rew_std": 151.2119786921658, "Agent": "qrdqn"}, {"env_step": 200000, "rew": 330.0, "rew_std": 140.30324301312496, "Agent": "qrdqn"}, {"env_step": 300000, "rew": 482.5, "rew_std": 145.48625364617786, "Agent": "qrdqn"}, {"env_step": 400000, "rew": 655.25, "rew_std": 164.68701375639793, "Agent": "qrdqn"}, {"env_step": 500000, "rew": 624.5, "rew_std": 130.56033088193365, "Agent": "qrdqn"}, {"env_step": 600000, "rew": 676.5, "rew_std": 131.37351331223505, "Agent": "qrdqn"}, {"env_step": 700000, "rew": 628.25, "rew_std": 158.0587311729409, "Agent": "qrdqn"}, {"env_step": 800000, "rew": 1161.25, "rew_std": 710.1436914456116, "Agent": "qrdqn"}, {"env_step": 900000, "rew": 1550.25, "rew_std": 826.7983808039297, "Agent": "qrdqn"}, {"env_step": 1000000, "rew": 1962.5, "rew_std": 961.0228925473108, "Agent": "qrdqn"}, {"env_step": 1100000, "rew": 2176.0, "rew_std": 1403.9837071704214, "Agent": "qrdqn"}, {"env_step": 1200000, "rew": 2638.5, "rew_std": 1025.882668729714, "Agent": "qrdqn"}, {"env_step": 1300000, "rew": 3701.0, "rew_std": 630.7289433663243, "Agent": "qrdqn"}, {"env_step": 1400000, "rew": 3190.25, "rew_std": 947.5115104841735, "Agent": "qrdqn"}, {"env_step": 1500000, "rew": 3946.75, "rew_std": 637.7578400145309, "Agent": "qrdqn"}, {"env_step": 1600000, "rew": 4426.5, "rew_std": 815.5735711755256, "Agent": "qrdqn"}, {"env_step": 1700000, "rew": 4326.25, "rew_std": 986.4046139896143, "Agent": "qrdqn"}, {"env_step": 1800000, "rew": 4494.5, "rew_std": 949.5484453149297, "Agent": "qrdqn"}, {"env_step": 1900000, "rew": 4857.5, "rew_std": 1134.8067016016428, "Agent": "qrdqn"}, {"env_step": 2000000, "rew": 4661.0, "rew_std": 2612.279225121235, "Agent": "qrdqn"}, {"env_step": 2100000, "rew": 6238.5, "rew_std": 2523.3789846156683, "Agent": "qrdqn"}, {"env_step": 2200000, "rew": 6793.5, "rew_std": 2207.1540499022717, "Agent": "qrdqn"}, {"env_step": 2300000, "rew": 8352.75, "rew_std": 2463.5217296585797, "Agent": "qrdqn"}, {"env_step": 2400000, "rew": 10017.0, "rew_std": 1099.753836092423, "Agent": "qrdqn"}, {"env_step": 2500000, "rew": 9378.25, "rew_std": 2206.291869291096, "Agent": "qrdqn"}, {"env_step": 2600000, "rew": 9277.75, "rew_std": 2164.6920826066694, "Agent": "qrdqn"}, {"env_step": 2700000, "rew": 9680.25, "rew_std": 1852.4255889238843, "Agent": "qrdqn"}, {"env_step": 2800000, "rew": 9750.0, "rew_std": 3101.0985956592867, "Agent": "qrdqn"}, {"env_step": 2900000, "rew": 11197.0, "rew_std": 2089.198650200598, "Agent": "qrdqn"}, {"env_step": 3000000, "rew": 10168.5, "rew_std": 1820.62976741566, "Agent": "qrdqn"}, {"env_step": 3100000, "rew": 10809.0, "rew_std": 1863.6564195151423, "Agent": "qrdqn"}, {"env_step": 3200000, "rew": 11434.75, "rew_std": 1928.14951767232, "Agent": "qrdqn"}, {"env_step": 3300000, "rew": 12635.0, "rew_std": 2041.877812211103, "Agent": "qrdqn"}, {"env_step": 3400000, "rew": 11676.0, "rew_std": 
3368.622715591641, "Agent": "qrdqn"}, {"env_step": 3500000, "rew": 11960.0, "rew_std": 1950.4877595104256, "Agent": "qrdqn"}, {"env_step": 3600000, "rew": 11736.0, "rew_std": 2031.8129835198908, "Agent": "qrdqn"}, {"env_step": 3700000, "rew": 12507.25, "rew_std": 1577.2018141315968, "Agent": "qrdqn"}, {"env_step": 3800000, "rew": 12923.5, "rew_std": 4095.112208474879, "Agent": "qrdqn"}, {"env_step": 3900000, "rew": 13316.75, "rew_std": 1166.7872824555468, "Agent": "qrdqn"}, {"env_step": 4000000, "rew": 13060.0, "rew_std": 2080.1246957814815, "Agent": "qrdqn"}, {"env_step": 4100000, "rew": 12532.75, "rew_std": 1183.963919424912, "Agent": "qrdqn"}, {"env_step": 4200000, "rew": 12320.25, "rew_std": 2122.9921249265153, "Agent": "qrdqn"}, {"env_step": 4300000, "rew": 12833.5, "rew_std": 1463.2879074194525, "Agent": "qrdqn"}, {"env_step": 4400000, "rew": 12643.5, "rew_std": 1230.7717091321201, "Agent": "qrdqn"}, {"env_step": 4500000, "rew": 12753.5, "rew_std": 2467.2244526998347, "Agent": "qrdqn"}, {"env_step": 4600000, "rew": 14206.0, "rew_std": 934.261874422798, "Agent": "qrdqn"}, {"env_step": 4700000, "rew": 13566.0, "rew_std": 1616.8879058240248, "Agent": "qrdqn"}, {"env_step": 4800000, "rew": 13339.0, "rew_std": 2508.4644705476694, "Agent": "qrdqn"}, {"env_step": 4900000, "rew": 13325.5, "rew_std": 1697.8286868821601, "Agent": "qrdqn"}, {"env_step": 5000000, "rew": 13318.25, "rew_std": 1575.3479972691748, "Agent": "qrdqn"}, {"env_step": 5100000, "rew": 12695.25, "rew_std": 1818.0165875205869, "Agent": "qrdqn"}, {"env_step": 5200000, "rew": 13957.5, "rew_std": 1218.8980679285696, "Agent": "qrdqn"}, {"env_step": 5300000, "rew": 13959.75, "rew_std": 1010.0305997839868, "Agent": "qrdqn"}, {"env_step": 5400000, "rew": 13414.0, "rew_std": 1498.4079884997943, "Agent": "qrdqn"}, {"env_step": 5500000, "rew": 12775.5, "rew_std": 1296.2314608124584, "Agent": "qrdqn"}, {"env_step": 5600000, "rew": 14213.75, "rew_std": 1282.7033220897185, "Agent": "qrdqn"}, {"env_step": 5700000, "rew": 12620.5, "rew_std": 2257.2158735929534, "Agent": "qrdqn"}, {"env_step": 5800000, "rew": 12587.5, "rew_std": 1430.5497195134462, "Agent": "qrdqn"}, {"env_step": 5900000, "rew": 13289.5, "rew_std": 1792.5658286378216, "Agent": "qrdqn"}, {"env_step": 6000000, "rew": 13572.75, "rew_std": 2379.9851496385436, "Agent": "qrdqn"}, {"env_step": 6100000, "rew": 12327.75, "rew_std": 2985.5888133666363, "Agent": "qrdqn"}, {"env_step": 6200000, "rew": 13057.75, "rew_std": 2234.581182794664, "Agent": "qrdqn"}, {"env_step": 6300000, "rew": 13167.75, "rew_std": 2580.7533904850343, "Agent": "qrdqn"}, {"env_step": 6400000, "rew": 14265.0, "rew_std": 1022.666856801373, "Agent": "qrdqn"}, {"env_step": 6500000, "rew": 13314.5, "rew_std": 1621.0269892879637, "Agent": "qrdqn"}, {"env_step": 6600000, "rew": 14761.5, "rew_std": 862.8928091020344, "Agent": "qrdqn"}, {"env_step": 6700000, "rew": 12912.5, "rew_std": 2490.064005201473, "Agent": "qrdqn"}, {"env_step": 6800000, "rew": 13582.25, "rew_std": 1415.470085342675, "Agent": "qrdqn"}, {"env_step": 6900000, "rew": 14093.0, "rew_std": 1151.0564712471755, "Agent": "qrdqn"}, {"env_step": 7000000, "rew": 13608.75, "rew_std": 1454.381419195116, "Agent": "qrdqn"}, {"env_step": 7100000, "rew": 14457.25, "rew_std": 1426.2934699773396, "Agent": "qrdqn"}, {"env_step": 7200000, "rew": 14363.5, "rew_std": 1147.3579432766394, "Agent": "qrdqn"}, {"env_step": 7300000, "rew": 14335.75, "rew_std": 1048.029132467223, "Agent": "qrdqn"}, {"env_step": 7400000, "rew": 14255.0, "rew_std": 996.4361494847525, "Agent": 
"qrdqn"}, {"env_step": 7500000, "rew": 13165.0, "rew_std": 2007.4355531373853, "Agent": "qrdqn"}, {"env_step": 7600000, "rew": 13882.25, "rew_std": 1050.5673764685444, "Agent": "qrdqn"}, {"env_step": 7700000, "rew": 14029.25, "rew_std": 1288.9787866757156, "Agent": "qrdqn"}, {"env_step": 7800000, "rew": 13062.75, "rew_std": 2194.7472662017376, "Agent": "qrdqn"}, {"env_step": 7900000, "rew": 13878.75, "rew_std": 1196.7911524154915, "Agent": "qrdqn"}, {"env_step": 8000000, "rew": 14246.25, "rew_std": 1554.0568884374857, "Agent": "qrdqn"}, {"env_step": 8100000, "rew": 14211.5, "rew_std": 1194.981276003938, "Agent": "qrdqn"}, {"env_step": 8200000, "rew": 14197.0, "rew_std": 1123.807034147767, "Agent": "qrdqn"}, {"env_step": 8300000, "rew": 13508.0, "rew_std": 1345.116073058381, "Agent": "qrdqn"}, {"env_step": 8400000, "rew": 11739.5, "rew_std": 2172.110954808709, "Agent": "qrdqn"}, {"env_step": 8500000, "rew": 13295.5, "rew_std": 1875.5738455203516, "Agent": "qrdqn"}, {"env_step": 8600000, "rew": 14682.0, "rew_std": 657.4094994750228, "Agent": "qrdqn"}, {"env_step": 8700000, "rew": 13262.75, "rew_std": 2101.4068055709727, "Agent": "qrdqn"}, {"env_step": 8800000, "rew": 13034.25, "rew_std": 2962.464567298654, "Agent": "qrdqn"}, {"env_step": 8900000, "rew": 13833.25, "rew_std": 1593.2596500570771, "Agent": "qrdqn"}, {"env_step": 9000000, "rew": 13900.75, "rew_std": 1380.3591244672525, "Agent": "qrdqn"}, {"env_step": 9100000, "rew": 13849.5, "rew_std": 1837.9743605393412, "Agent": "qrdqn"}, {"env_step": 9200000, "rew": 12643.25, "rew_std": 2829.352631345199, "Agent": "qrdqn"}, {"env_step": 9300000, "rew": 13530.75, "rew_std": 1416.5393790855233, "Agent": "qrdqn"}, {"env_step": 9400000, "rew": 13982.5, "rew_std": 1845.6198145880423, "Agent": "qrdqn"}, {"env_step": 9500000, "rew": 13809.25, "rew_std": 1238.4738844642627, "Agent": "qrdqn"}, {"env_step": 9600000, "rew": 12931.5, "rew_std": 1797.4736437567033, "Agent": "qrdqn"}, {"env_step": 9700000, "rew": 14342.75, "rew_std": 649.1700951984773, "Agent": "qrdqn"}, {"env_step": 9800000, "rew": 14729.75, "rew_std": 626.4367984880837, "Agent": "qrdqn"}, {"env_step": 9900000, "rew": 13490.75, "rew_std": 1119.1950511416676, "Agent": "qrdqn"}, {"env_step": 10000000, "rew": 14191.5, "rew_std": 1683.4767595663445, "Agent": "qrdqn"}, {"env_step": 0, "rew": 74.75, "rew_std": 68.97871048374273, "Agent": "iqn"}, {"env_step": 100000, "rew": 305.25, "rew_std": 107.32223674523374, "Agent": "iqn"}, {"env_step": 200000, "rew": 278.5, "rew_std": 60.28266749240614, "Agent": "iqn"}, {"env_step": 300000, "rew": 480.75, "rew_std": 128.17980535170116, "Agent": "iqn"}, {"env_step": 400000, "rew": 580.5, "rew_std": 164.20566372692508, "Agent": "iqn"}, {"env_step": 500000, "rew": 603.5, "rew_std": 163.69254717304634, "Agent": "iqn"}, {"env_step": 600000, "rew": 681.5, "rew_std": 165.10299815569672, "Agent": "iqn"}, {"env_step": 700000, "rew": 779.5, "rew_std": 202.94642150084834, "Agent": "iqn"}, {"env_step": 800000, "rew": 1212.0, "rew_std": 518.2033867122059, "Agent": "iqn"}, {"env_step": 900000, "rew": 1937.0, "rew_std": 1077.2446333122296, "Agent": "iqn"}, {"env_step": 1000000, "rew": 2055.75, "rew_std": 1114.4051384034444, "Agent": "iqn"}, {"env_step": 1100000, "rew": 2164.0, "rew_std": 763.3292212407435, "Agent": "iqn"}, {"env_step": 1200000, "rew": 2717.0, "rew_std": 926.5607103692666, "Agent": "iqn"}, {"env_step": 1300000, "rew": 3349.25, "rew_std": 801.4120740917247, "Agent": "iqn"}, {"env_step": 1400000, "rew": 3172.25, "rew_std": 848.1453663730057, "Agent": "iqn"}, 
{"env_step": 1500000, "rew": 3463.5, "rew_std": 827.8875225536377, "Agent": "iqn"}, {"env_step": 1600000, "rew": 4035.75, "rew_std": 911.0859248720726, "Agent": "iqn"}, {"env_step": 1700000, "rew": 4497.0, "rew_std": 543.007596631944, "Agent": "iqn"}, {"env_step": 1800000, "rew": 4461.25, "rew_std": 499.12705045909905, "Agent": "iqn"}, {"env_step": 1900000, "rew": 4384.25, "rew_std": 471.4711682595236, "Agent": "iqn"}, {"env_step": 2000000, "rew": 5132.0, "rew_std": 1111.9947167140679, "Agent": "iqn"}, {"env_step": 2100000, "rew": 4575.75, "rew_std": 2275.0469912729277, "Agent": "iqn"}, {"env_step": 2200000, "rew": 5614.5, "rew_std": 1350.1304566596518, "Agent": "iqn"}, {"env_step": 2300000, "rew": 5378.75, "rew_std": 2386.0001178751018, "Agent": "iqn"}, {"env_step": 2400000, "rew": 6720.5, "rew_std": 2223.6897265580915, "Agent": "iqn"}, {"env_step": 2500000, "rew": 7193.75, "rew_std": 1491.2818521325873, "Agent": "iqn"}, {"env_step": 2600000, "rew": 8060.25, "rew_std": 2501.7125259509735, "Agent": "iqn"}, {"env_step": 2700000, "rew": 8047.0, "rew_std": 1672.755511125281, "Agent": "iqn"}, {"env_step": 2800000, "rew": 8176.0, "rew_std": 3218.092447397992, "Agent": "iqn"}, {"env_step": 2900000, "rew": 9079.25, "rew_std": 2817.5170917848927, "Agent": "iqn"}, {"env_step": 3000000, "rew": 9333.5, "rew_std": 1586.5446731813133, "Agent": "iqn"}, {"env_step": 3100000, "rew": 11244.75, "rew_std": 1804.940944324772, "Agent": "iqn"}, {"env_step": 3200000, "rew": 9774.75, "rew_std": 2385.623988079429, "Agent": "iqn"}, {"env_step": 3300000, "rew": 10427.5, "rew_std": 2821.736167681167, "Agent": "iqn"}, {"env_step": 3400000, "rew": 9773.25, "rew_std": 2530.4006723244443, "Agent": "iqn"}, {"env_step": 3500000, "rew": 10958.5, "rew_std": 1914.0373559572968, "Agent": "iqn"}, {"env_step": 3600000, "rew": 11481.25, "rew_std": 2320.765027420915, "Agent": "iqn"}, {"env_step": 3700000, "rew": 10402.0, "rew_std": 2840.605525235773, "Agent": "iqn"}, {"env_step": 3800000, "rew": 11571.25, "rew_std": 1838.5531845720427, "Agent": "iqn"}, {"env_step": 3900000, "rew": 12558.75, "rew_std": 1597.0246749815733, "Agent": "iqn"}, {"env_step": 4000000, "rew": 12249.5, "rew_std": 1836.1981102266716, "Agent": "iqn"}, {"env_step": 4100000, "rew": 12411.5, "rew_std": 1798.764228574718, "Agent": "iqn"}, {"env_step": 4200000, "rew": 12926.75, "rew_std": 1323.884459648953, "Agent": "iqn"}, {"env_step": 4300000, "rew": 11794.75, "rew_std": 2639.6958750015124, "Agent": "iqn"}, {"env_step": 4400000, "rew": 12201.0, "rew_std": 1702.2159087495334, "Agent": "iqn"}, {"env_step": 4500000, "rew": 12271.25, "rew_std": 1584.632548100663, "Agent": "iqn"}, {"env_step": 4600000, "rew": 12395.25, "rew_std": 1911.1424757196937, "Agent": "iqn"}, {"env_step": 4700000, "rew": 12780.0, "rew_std": 1188.934396844502, "Agent": "iqn"}, {"env_step": 4800000, "rew": 12680.5, "rew_std": 1798.4388368804762, "Agent": "iqn"}, {"env_step": 4900000, "rew": 11659.0, "rew_std": 1524.3105818697186, "Agent": "iqn"}, {"env_step": 5000000, "rew": 12834.25, "rew_std": 1934.9157119885094, "Agent": "iqn"}, {"env_step": 5100000, "rew": 13496.0, "rew_std": 1634.7783488901484, "Agent": "iqn"}, {"env_step": 5200000, "rew": 13142.75, "rew_std": 1530.6499640675527, "Agent": "iqn"}, {"env_step": 5300000, "rew": 12664.75, "rew_std": 2404.7719356521106, "Agent": "iqn"}, {"env_step": 5400000, "rew": 12944.25, "rew_std": 2205.103186361128, "Agent": "iqn"}, {"env_step": 5500000, "rew": 13810.25, "rew_std": 2059.245750875791, "Agent": "iqn"}, {"env_step": 5600000, "rew": 13504.0, 
"rew_std": 849.1313208214616, "Agent": "iqn"}, {"env_step": 5700000, "rew": 13502.25, "rew_std": 1435.742599667503, "Agent": "iqn"}, {"env_step": 5800000, "rew": 14175.25, "rew_std": 1070.7231493247916, "Agent": "iqn"}, {"env_step": 5900000, "rew": 13746.0, "rew_std": 1353.6211619208677, "Agent": "iqn"}, {"env_step": 6000000, "rew": 14359.75, "rew_std": 987.1046360442241, "Agent": "iqn"}, {"env_step": 6100000, "rew": 13638.25, "rew_std": 2135.9354069119227, "Agent": "iqn"}, {"env_step": 6200000, "rew": 14398.0, "rew_std": 724.5531381479208, "Agent": "iqn"}, {"env_step": 6300000, "rew": 13681.25, "rew_std": 1508.860600751441, "Agent": "iqn"}, {"env_step": 6400000, "rew": 12862.0, "rew_std": 2345.081501781974, "Agent": "iqn"}, {"env_step": 6500000, "rew": 12578.5, "rew_std": 3452.268855405094, "Agent": "iqn"}, {"env_step": 6600000, "rew": 13525.25, "rew_std": 1754.864115109771, "Agent": "iqn"}, {"env_step": 6700000, "rew": 14026.75, "rew_std": 1140.708688710663, "Agent": "iqn"}, {"env_step": 6800000, "rew": 14103.75, "rew_std": 1377.3363106010092, "Agent": "iqn"}, {"env_step": 6900000, "rew": 13723.5, "rew_std": 1402.6114928945933, "Agent": "iqn"}, {"env_step": 7000000, "rew": 13494.25, "rew_std": 997.4141880382492, "Agent": "iqn"}, {"env_step": 7100000, "rew": 14152.25, "rew_std": 709.2394253705867, "Agent": "iqn"}, {"env_step": 7200000, "rew": 13685.25, "rew_std": 1417.7761856160514, "Agent": "iqn"}, {"env_step": 7300000, "rew": 13408.25, "rew_std": 2077.5096419752185, "Agent": "iqn"}, {"env_step": 7400000, "rew": 14233.0, "rew_std": 909.7477672410084, "Agent": "iqn"}, {"env_step": 7500000, "rew": 14091.5, "rew_std": 743.8003764451857, "Agent": "iqn"}, {"env_step": 7600000, "rew": 13211.75, "rew_std": 1589.0996074821742, "Agent": "iqn"}, {"env_step": 7700000, "rew": 13444.5, "rew_std": 1892.1039215645635, "Agent": "iqn"}, {"env_step": 7800000, "rew": 13603.25, "rew_std": 2529.7435764322045, "Agent": "iqn"}, {"env_step": 7900000, "rew": 13292.25, "rew_std": 3160.3117824828614, "Agent": "iqn"}, {"env_step": 8000000, "rew": 14121.75, "rew_std": 818.3054518332382, "Agent": "iqn"}, {"env_step": 8100000, "rew": 14027.0, "rew_std": 721.3241296393737, "Agent": "iqn"}, {"env_step": 8200000, "rew": 14095.25, "rew_std": 599.5054732861078, "Agent": "iqn"}, {"env_step": 8300000, "rew": 14409.25, "rew_std": 808.6462839214684, "Agent": "iqn"}, {"env_step": 8400000, "rew": 13536.75, "rew_std": 753.1708388539747, "Agent": "iqn"}, {"env_step": 8500000, "rew": 13976.5, "rew_std": 988.6829623291786, "Agent": "iqn"}, {"env_step": 8600000, "rew": 13914.5, "rew_std": 1239.5683724587361, "Agent": "iqn"}, {"env_step": 8700000, "rew": 14257.0, "rew_std": 1150.6013645046662, "Agent": "iqn"}, {"env_step": 8800000, "rew": 13446.5, "rew_std": 1551.8111193054392, "Agent": "iqn"}, {"env_step": 8900000, "rew": 14032.5, "rew_std": 1186.8413963120768, "Agent": "iqn"}, {"env_step": 9000000, "rew": 14378.5, "rew_std": 943.7049326987753, "Agent": "iqn"}, {"env_step": 9100000, "rew": 14320.75, "rew_std": 647.3224177332344, "Agent": "iqn"}, {"env_step": 9200000, "rew": 13960.25, "rew_std": 1017.343630490701, "Agent": "iqn"}, {"env_step": 9300000, "rew": 13514.25, "rew_std": 1402.1367845185434, "Agent": "iqn"}, {"env_step": 9400000, "rew": 13712.25, "rew_std": 1607.3042065831844, "Agent": "iqn"}, {"env_step": 9500000, "rew": 14267.75, "rew_std": 724.1317645981289, "Agent": "iqn"}, {"env_step": 9600000, "rew": 14351.75, "rew_std": 780.0296869350551, "Agent": "iqn"}, {"env_step": 9700000, "rew": 13220.25, "rew_std": 
1425.2001833075942, "Agent": "iqn"}, {"env_step": 9800000, "rew": 14156.5, "rew_std": 853.8107225843443, "Agent": "iqn"}, {"env_step": 9900000, "rew": 14273.75, "rew_std": 895.3443820675931, "Agent": "iqn"}, {"env_step": 10000000, "rew": 13774.75, "rew_std": 1513.8219223211163, "Agent": "iqn"}, {"env_step": 0, "rew": 45.5, "rew_std": 47.75981574503821, "Agent": "rainbow"}, {"env_step": 100000, "rew": 284.5, "rew_std": 61.47967143698801, "Agent": "rainbow"}, {"env_step": 200000, "rew": 285.0, "rew_std": 74.47314952383846, "Agent": "rainbow"}, {"env_step": 300000, "rew": 377.75, "rew_std": 92.13746523537534, "Agent": "rainbow"}, {"env_step": 400000, "rew": 395.75, "rew_std": 96.40442157909564, "Agent": "rainbow"}, {"env_step": 500000, "rew": 446.5, "rew_std": 135.95587519485872, "Agent": "rainbow"}, {"env_step": 600000, "rew": 509.0, "rew_std": 112.18400064180275, "Agent": "rainbow"}, {"env_step": 700000, "rew": 842.0, "rew_std": 379.03957577012983, "Agent": "rainbow"}, {"env_step": 800000, "rew": 841.25, "rew_std": 334.61031439571616, "Agent": "rainbow"}, {"env_step": 900000, "rew": 1965.0, "rew_std": 1128.5698914998575, "Agent": "rainbow"}, {"env_step": 1000000, "rew": 2198.25, "rew_std": 836.4859906178943, "Agent": "rainbow"}, {"env_step": 1100000, "rew": 3015.75, "rew_std": 848.7866707836546, "Agent": "rainbow"}, {"env_step": 1200000, "rew": 2877.0, "rew_std": 996.3312702108672, "Agent": "rainbow"}, {"env_step": 1300000, "rew": 3242.0, "rew_std": 876.3666470148211, "Agent": "rainbow"}, {"env_step": 1400000, "rew": 3739.5, "rew_std": 779.8091112573641, "Agent": "rainbow"}, {"env_step": 1500000, "rew": 3878.5, "rew_std": 610.4621200369438, "Agent": "rainbow"}, {"env_step": 1600000, "rew": 3686.75, "rew_std": 1020.1911891895558, "Agent": "rainbow"}, {"env_step": 1700000, "rew": 3802.5, "rew_std": 775.4450335130144, "Agent": "rainbow"}, {"env_step": 1800000, "rew": 4826.75, "rew_std": 1208.2508276430024, "Agent": "rainbow"}, {"env_step": 1900000, "rew": 5678.25, "rew_std": 1521.4220527191, "Agent": "rainbow"}, {"env_step": 2000000, "rew": 5642.5, "rew_std": 2018.9904655545058, "Agent": "rainbow"}, {"env_step": 2100000, "rew": 7018.0, "rew_std": 2637.9750283124363, "Agent": "rainbow"}, {"env_step": 2200000, "rew": 6920.25, "rew_std": 1881.316178237991, "Agent": "rainbow"}, {"env_step": 2300000, "rew": 7435.0, "rew_std": 1537.6528379318916, "Agent": "rainbow"}, {"env_step": 2400000, "rew": 7692.5, "rew_std": 1343.7070923382075, "Agent": "rainbow"}, {"env_step": 2500000, "rew": 8006.25, "rew_std": 1876.4488568836616, "Agent": "rainbow"}, {"env_step": 2600000, "rew": 9979.75, "rew_std": 2021.7954427933603, "Agent": "rainbow"}, {"env_step": 2700000, "rew": 9089.75, "rew_std": 1605.1473647301048, "Agent": "rainbow"}, {"env_step": 2800000, "rew": 8764.75, "rew_std": 1827.6663569973596, "Agent": "rainbow"}, {"env_step": 2900000, "rew": 9663.0, "rew_std": 2015.3541252097607, "Agent": "rainbow"}, {"env_step": 3000000, "rew": 9934.5, "rew_std": 2286.617261371041, "Agent": "rainbow"}, {"env_step": 3100000, "rew": 10924.25, "rew_std": 2628.7715881186787, "Agent": "rainbow"}, {"env_step": 3200000, "rew": 9174.75, "rew_std": 1997.3590219337134, "Agent": "rainbow"}, {"env_step": 3300000, "rew": 10324.25, "rew_std": 1182.9740328933683, "Agent": "rainbow"}, {"env_step": 3400000, "rew": 10506.5, "rew_std": 1664.2221155843351, "Agent": "rainbow"}, {"env_step": 3500000, "rew": 10675.0, "rew_std": 2079.8194032175006, "Agent": "rainbow"}, {"env_step": 3600000, "rew": 10794.25, "rew_std": 2335.775848085599, 
"Agent": "rainbow"}, {"env_step": 3700000, "rew": 10830.25, "rew_std": 2143.70282746933, "Agent": "rainbow"}, {"env_step": 3800000, "rew": 11664.75, "rew_std": 1526.7708775386043, "Agent": "rainbow"}, {"env_step": 3900000, "rew": 10242.5, "rew_std": 2334.839662589275, "Agent": "rainbow"}, {"env_step": 4000000, "rew": 11877.25, "rew_std": 1986.8088263594966, "Agent": "rainbow"}, {"env_step": 4100000, "rew": 11280.25, "rew_std": 2765.102721871287, "Agent": "rainbow"}, {"env_step": 4200000, "rew": 12994.5, "rew_std": 1754.1610530393154, "Agent": "rainbow"}, {"env_step": 4300000, "rew": 10860.25, "rew_std": 1974.6045331913933, "Agent": "rainbow"}, {"env_step": 4400000, "rew": 10636.25, "rew_std": 2674.058537597859, "Agent": "rainbow"}, {"env_step": 4500000, "rew": 12535.5, "rew_std": 1929.7117012652434, "Agent": "rainbow"}, {"env_step": 4600000, "rew": 12290.5, "rew_std": 1829.934015203827, "Agent": "rainbow"}, {"env_step": 4700000, "rew": 12177.5, "rew_std": 946.7675269040442, "Agent": "rainbow"}, {"env_step": 4800000, "rew": 13175.75, "rew_std": 1413.9178945398492, "Agent": "rainbow"}, {"env_step": 4900000, "rew": 12883.5, "rew_std": 1610.216677966043, "Agent": "rainbow"}, {"env_step": 5000000, "rew": 12284.5, "rew_std": 1809.4221453270654, "Agent": "rainbow"}, {"env_step": 5100000, "rew": 12318.0, "rew_std": 2168.0633062712905, "Agent": "rainbow"}, {"env_step": 5200000, "rew": 12730.25, "rew_std": 1575.0005753967203, "Agent": "rainbow"}, {"env_step": 5300000, "rew": 11980.25, "rew_std": 1916.1492798057254, "Agent": "rainbow"}, {"env_step": 5400000, "rew": 12032.75, "rew_std": 2195.3403637021756, "Agent": "rainbow"}, {"env_step": 5500000, "rew": 12618.0, "rew_std": 2118.3926099757805, "Agent": "rainbow"}, {"env_step": 5600000, "rew": 13014.25, "rew_std": 1486.7145195026515, "Agent": "rainbow"}, {"env_step": 5700000, "rew": 12690.0, "rew_std": 1458.9743829142444, "Agent": "rainbow"}, {"env_step": 5800000, "rew": 12033.5, "rew_std": 1977.8407418192194, "Agent": "rainbow"}, {"env_step": 5900000, "rew": 12640.25, "rew_std": 2624.1961745456456, "Agent": "rainbow"}, {"env_step": 6000000, "rew": 13131.25, "rew_std": 1906.204097807997, "Agent": "rainbow"}, {"env_step": 6100000, "rew": 13501.75, "rew_std": 1226.5200008560805, "Agent": "rainbow"}, {"env_step": 6200000, "rew": 13880.0, "rew_std": 1096.272662251504, "Agent": "rainbow"}, {"env_step": 6300000, "rew": 12978.75, "rew_std": 1734.9788363262533, "Agent": "rainbow"}, {"env_step": 6400000, "rew": 12417.0, "rew_std": 1276.8250663266288, "Agent": "rainbow"}, {"env_step": 6500000, "rew": 13424.5, "rew_std": 1740.133543725883, "Agent": "rainbow"}, {"env_step": 6600000, "rew": 13237.0, "rew_std": 1644.9296337533713, "Agent": "rainbow"}, {"env_step": 6700000, "rew": 13351.75, "rew_std": 1120.8969901378093, "Agent": "rainbow"}, {"env_step": 6800000, "rew": 12263.0, "rew_std": 2282.7893573433357, "Agent": "rainbow"}, {"env_step": 6900000, "rew": 12439.0, "rew_std": 2598.2990108915487, "Agent": "rainbow"}, {"env_step": 7000000, "rew": 14034.5, "rew_std": 744.4837137238128, "Agent": "rainbow"}, {"env_step": 7100000, "rew": 13683.25, "rew_std": 901.8238810876545, "Agent": "rainbow"}, {"env_step": 7200000, "rew": 14111.25, "rew_std": 1060.766379793402, "Agent": "rainbow"}, {"env_step": 7300000, "rew": 13421.75, "rew_std": 1568.2095881928537, "Agent": "rainbow"}, {"env_step": 7400000, "rew": 14206.5, "rew_std": 607.6617068731582, "Agent": "rainbow"}, {"env_step": 7500000, "rew": 13354.25, "rew_std": 1601.6690708445362, "Agent": "rainbow"}, {"env_step": 
7600000, "rew": 13701.75, "rew_std": 1030.2663793893305, "Agent": "rainbow"}, {"env_step": 7700000, "rew": 13039.0, "rew_std": 2426.61317477673, "Agent": "rainbow"}, {"env_step": 7800000, "rew": 13988.25, "rew_std": 1832.0964937742772, "Agent": "rainbow"}, {"env_step": 7900000, "rew": 13303.0, "rew_std": 1248.9764809635128, "Agent": "rainbow"}, {"env_step": 8000000, "rew": 13551.25, "rew_std": 1319.6809510256637, "Agent": "rainbow"}, {"env_step": 8100000, "rew": 13257.25, "rew_std": 1674.4097654098891, "Agent": "rainbow"}, {"env_step": 8200000, "rew": 13652.5, "rew_std": 1983.858613913804, "Agent": "rainbow"}, {"env_step": 8300000, "rew": 13802.5, "rew_std": 1365.1304882684292, "Agent": "rainbow"}, {"env_step": 8400000, "rew": 13834.5, "rew_std": 1055.2753195256678, "Agent": "rainbow"}, {"env_step": 8500000, "rew": 14132.75, "rew_std": 759.6795788357089, "Agent": "rainbow"}, {"env_step": 8600000, "rew": 13816.5, "rew_std": 838.3877384599563, "Agent": "rainbow"}, {"env_step": 8700000, "rew": 13764.0, "rew_std": 1449.42367857021, "Agent": "rainbow"}, {"env_step": 8800000, "rew": 13053.75, "rew_std": 1003.1601629351118, "Agent": "rainbow"}, {"env_step": 8900000, "rew": 13302.75, "rew_std": 1787.7211226866455, "Agent": "rainbow"}, {"env_step": 9000000, "rew": 13252.75, "rew_std": 1108.8256456720326, "Agent": "rainbow"}, {"env_step": 9100000, "rew": 13711.75, "rew_std": 1272.3845969281458, "Agent": "rainbow"}, {"env_step": 9200000, "rew": 13983.5, "rew_std": 1598.7660085203213, "Agent": "rainbow"}, {"env_step": 9300000, "rew": 13033.25, "rew_std": 1330.3514808124958, "Agent": "rainbow"}, {"env_step": 9400000, "rew": 14224.75, "rew_std": 1230.1089636694792, "Agent": "rainbow"}, {"env_step": 9500000, "rew": 13983.25, "rew_std": 1469.7389436563217, "Agent": "rainbow"}, {"env_step": 9600000, "rew": 12979.5, "rew_std": 1610.91891478125, "Agent": "rainbow"}, {"env_step": 9700000, "rew": 13711.25, "rew_std": 1179.091413970944, "Agent": "rainbow"}, {"env_step": 9800000, "rew": 13414.0, "rew_std": 2159.0412223947924, "Agent": "rainbow"}, {"env_step": 9900000, "rew": 13838.75, "rew_std": 1349.4764030912138, "Agent": "rainbow"}, {"env_step": 10000000, "rew": 14035.0, "rew_std": 1246.866572653225, "Agent": "rainbow"}, {"env_step": 0, "rew": 120.5, "rew_std": 101.5, "Agent": "ppo"}, {"env_step": 100000, "rew": 273.0, "rew_std": 28.956864471140516, "Agent": "ppo"}, {"env_step": 200000, "rew": 355.0, "rew_std": 89.81230427953622, "Agent": "ppo"}, {"env_step": 300000, "rew": 391.5, "rew_std": 92.91931984253867, "Agent": "ppo"}, {"env_step": 400000, "rew": 474.0, "rew_std": 108.15035829806575, "Agent": "ppo"}, {"env_step": 500000, "rew": 542.75, "rew_std": 105.84452040611266, "Agent": "ppo"}, {"env_step": 600000, "rew": 621.75, "rew_std": 77.6695725493581, "Agent": "ppo"}, {"env_step": 700000, "rew": 641.75, "rew_std": 85.77623505377233, "Agent": "ppo"}, {"env_step": 800000, "rew": 672.25, "rew_std": 64.53148456373835, "Agent": "ppo"}, {"env_step": 900000, "rew": 744.75, "rew_std": 134.25465541276398, "Agent": "ppo"}, {"env_step": 1000000, "rew": 791.25, "rew_std": 143.7891251103504, "Agent": "ppo"}, {"env_step": 1100000, "rew": 995.0, "rew_std": 389.2460661329797, "Agent": "ppo"}, {"env_step": 1200000, "rew": 817.25, "rew_std": 168.63588734311568, "Agent": "ppo"}, {"env_step": 1300000, "rew": 1099.0, "rew_std": 655.2797112684018, "Agent": "ppo"}, {"env_step": 1400000, "rew": 1188.0, "rew_std": 663.5, "Agent": "ppo"}, {"env_step": 1500000, "rew": 1322.0, "rew_std": 450.0927682156202, "Agent": "ppo"}, 
{"env_step": 1600000, "rew": 1452.75, "rew_std": 704.6368302182337, "Agent": "ppo"}, {"env_step": 1700000, "rew": 1558.5, "rew_std": 423.40170051618827, "Agent": "ppo"}, {"env_step": 1800000, "rew": 1552.75, "rew_std": 663.306537356598, "Agent": "ppo"}, {"env_step": 1900000, "rew": 1814.25, "rew_std": 756.1093257591788, "Agent": "ppo"}, {"env_step": 2000000, "rew": 1824.0, "rew_std": 703.5785315087435, "Agent": "ppo"}, {"env_step": 2100000, "rew": 1752.25, "rew_std": 750.6567541160207, "Agent": "ppo"}, {"env_step": 2200000, "rew": 2510.75, "rew_std": 872.3481601402045, "Agent": "ppo"}, {"env_step": 2300000, "rew": 2298.25, "rew_std": 906.4326022931876, "Agent": "ppo"}, {"env_step": 2400000, "rew": 2231.0, "rew_std": 897.2521942018309, "Agent": "ppo"}, {"env_step": 2500000, "rew": 2028.0, "rew_std": 938.4780231843471, "Agent": "ppo"}, {"env_step": 2600000, "rew": 2503.25, "rew_std": 949.5203065232465, "Agent": "ppo"}, {"env_step": 2700000, "rew": 2804.5, "rew_std": 959.1681291619317, "Agent": "ppo"}, {"env_step": 2800000, "rew": 2946.25, "rew_std": 708.6265324555665, "Agent": "ppo"}, {"env_step": 2900000, "rew": 3231.75, "rew_std": 616.26298160769, "Agent": "ppo"}, {"env_step": 3000000, "rew": 2883.25, "rew_std": 727.0738012746712, "Agent": "ppo"}, {"env_step": 3100000, "rew": 3300.5, "rew_std": 795.6183130119617, "Agent": "ppo"}, {"env_step": 3200000, "rew": 3390.5, "rew_std": 828.5211222413102, "Agent": "ppo"}, {"env_step": 3300000, "rew": 3235.5, "rew_std": 996.1192197724126, "Agent": "ppo"}, {"env_step": 3400000, "rew": 3114.0, "rew_std": 1074.6076028020648, "Agent": "ppo"}, {"env_step": 3500000, "rew": 3412.75, "rew_std": 1089.8081539885816, "Agent": "ppo"}, {"env_step": 3600000, "rew": 3153.75, "rew_std": 1106.7566636347847, "Agent": "ppo"}, {"env_step": 3700000, "rew": 3294.75, "rew_std": 694.6846856668138, "Agent": "ppo"}, {"env_step": 3800000, "rew": 3217.0, "rew_std": 1153.753548206895, "Agent": "ppo"}, {"env_step": 3900000, "rew": 3735.5, "rew_std": 868.4992803681532, "Agent": "ppo"}, {"env_step": 4000000, "rew": 3744.0, "rew_std": 798.573885122723, "Agent": "ppo"}, {"env_step": 4100000, "rew": 3626.75, "rew_std": 879.9460565852886, "Agent": "ppo"}, {"env_step": 4200000, "rew": 3621.5, "rew_std": 977.4035758068414, "Agent": "ppo"}, {"env_step": 4300000, "rew": 3884.5, "rew_std": 623.5030072100695, "Agent": "ppo"}, {"env_step": 4400000, "rew": 3692.25, "rew_std": 711.3521016346265, "Agent": "ppo"}, {"env_step": 4500000, "rew": 3992.75, "rew_std": 715.286175247362, "Agent": "ppo"}, {"env_step": 4600000, "rew": 4163.0, "rew_std": 830.2919667201412, "Agent": "ppo"}, {"env_step": 4700000, "rew": 4100.75, "rew_std": 683.023654422012, "Agent": "ppo"}, {"env_step": 4800000, "rew": 4077.5, "rew_std": 490.6844709179209, "Agent": "ppo"}, {"env_step": 4900000, "rew": 4007.25, "rew_std": 496.95126773155533, "Agent": "ppo"}, {"env_step": 5000000, "rew": 4787.5, "rew_std": 1021.3936557468918, "Agent": "ppo"}, {"env_step": 5100000, "rew": 4553.0, "rew_std": 615.2263810988602, "Agent": "ppo"}, {"env_step": 5200000, "rew": 4548.75, "rew_std": 416.08029573629176, "Agent": "ppo"}, {"env_step": 5300000, "rew": 4595.0, "rew_std": 509.9178855462907, "Agent": "ppo"}, {"env_step": 5400000, "rew": 5037.5, "rew_std": 584.2281232532374, "Agent": "ppo"}, {"env_step": 5500000, "rew": 5001.75, "rew_std": 1064.1552107188124, "Agent": "ppo"}, {"env_step": 5600000, "rew": 5132.75, "rew_std": 1378.285370487549, "Agent": "ppo"}, {"env_step": 5700000, "rew": 5175.5, "rew_std": 1010.7384676561984, "Agent": "ppo"}, 
{"env_step": 5800000, "rew": 4833.5, "rew_std": 789.5474969879899, "Agent": "ppo"}, {"env_step": 5900000, "rew": 5724.0, "rew_std": 707.8031152799484, "Agent": "ppo"}, {"env_step": 6000000, "rew": 6142.5, "rew_std": 1675.208569104158, "Agent": "ppo"}, {"env_step": 6100000, "rew": 6317.0, "rew_std": 1503.244324785562, "Agent": "ppo"}, {"env_step": 6200000, "rew": 6381.75, "rew_std": 1400.6998473977214, "Agent": "ppo"}, {"env_step": 6300000, "rew": 6283.0, "rew_std": 1507.1785726980065, "Agent": "ppo"}, {"env_step": 6400000, "rew": 6748.0, "rew_std": 1430.778983630945, "Agent": "ppo"}, {"env_step": 6500000, "rew": 7201.75, "rew_std": 1294.4265380854952, "Agent": "ppo"}, {"env_step": 6600000, "rew": 6559.0, "rew_std": 1157.9767916499882, "Agent": "ppo"}, {"env_step": 6700000, "rew": 7433.5, "rew_std": 1716.4509896877335, "Agent": "ppo"}, {"env_step": 6800000, "rew": 7610.5, "rew_std": 1812.574412265604, "Agent": "ppo"}, {"env_step": 6900000, "rew": 8195.0, "rew_std": 1976.1841386874858, "Agent": "ppo"}, {"env_step": 7000000, "rew": 8271.5, "rew_std": 1789.7011622055788, "Agent": "ppo"}, {"env_step": 7100000, "rew": 7825.5, "rew_std": 1272.6767067877058, "Agent": "ppo"}, {"env_step": 7200000, "rew": 8352.75, "rew_std": 1310.4419340436264, "Agent": "ppo"}, {"env_step": 7300000, "rew": 8443.0, "rew_std": 1754.0131841009634, "Agent": "ppo"}, {"env_step": 7400000, "rew": 8361.25, "rew_std": 1613.0232678111001, "Agent": "ppo"}, {"env_step": 7500000, "rew": 8785.5, "rew_std": 1928.0082987373264, "Agent": "ppo"}, {"env_step": 7600000, "rew": 9088.0, "rew_std": 1135.7738551313814, "Agent": "ppo"}, {"env_step": 7700000, "rew": 8585.25, "rew_std": 1348.3320483100592, "Agent": "ppo"}, {"env_step": 7800000, "rew": 8759.25, "rew_std": 1379.0055520192802, "Agent": "ppo"}, {"env_step": 7900000, "rew": 9218.5, "rew_std": 1970.6262329523577, "Agent": "ppo"}, {"env_step": 8000000, "rew": 9573.25, "rew_std": 1635.5530601298144, "Agent": "ppo"}, {"env_step": 8100000, "rew": 10431.25, "rew_std": 1564.9469839262927, "Agent": "ppo"}, {"env_step": 8200000, "rew": 9307.5, "rew_std": 1389.4486316521384, "Agent": "ppo"}, {"env_step": 8300000, "rew": 9908.75, "rew_std": 1632.5357002222033, "Agent": "ppo"}, {"env_step": 8400000, "rew": 10750.5, "rew_std": 2245.378531562106, "Agent": "ppo"}, {"env_step": 8500000, "rew": 10358.5, "rew_std": 2260.2992611599025, "Agent": "ppo"}, {"env_step": 8600000, "rew": 10700.25, "rew_std": 1594.5439982954374, "Agent": "ppo"}, {"env_step": 8700000, "rew": 10038.25, "rew_std": 1889.7635599460584, "Agent": "ppo"}, {"env_step": 8800000, "rew": 9823.0, "rew_std": 1878.6184418343178, "Agent": "ppo"}, {"env_step": 8900000, "rew": 10836.5, "rew_std": 1715.2179890614486, "Agent": "ppo"}, {"env_step": 9000000, "rew": 10589.0, "rew_std": 1656.9747282321478, "Agent": "ppo"}, {"env_step": 9100000, "rew": 10209.75, "rew_std": 1596.2845336906576, "Agent": "ppo"}, {"env_step": 9200000, "rew": 11638.75, "rew_std": 2334.8370526655603, "Agent": "ppo"}, {"env_step": 9300000, "rew": 11236.5, "rew_std": 1308.257046608196, "Agent": "ppo"}, {"env_step": 9400000, "rew": 12341.75, "rew_std": 1760.6944830094742, "Agent": "ppo"}, {"env_step": 9500000, "rew": 11866.0, "rew_std": 1635.246617486182, "Agent": "ppo"}, {"env_step": 9600000, "rew": 11265.5, "rew_std": 1304.1528859761804, "Agent": "ppo"}, {"env_step": 9700000, "rew": 11678.5, "rew_std": 1495.5292541438298, "Agent": "ppo"}, {"env_step": 9800000, "rew": 11504.25, "rew_std": 1666.65422703691, "Agent": "ppo"}, {"env_step": 9900000, "rew": 11494.0, "rew_std": 
1494.6768881601133, "Agent": "ppo"}, {"env_step": 10000000, "rew": 12188.5, "rew_std": 1292.4967117946567, "Agent": "ppo"}] \ No newline at end of file diff --git a/examples/atari/benchmark/SeaquestNoFrameskip-v4/result.json b/examples/atari/benchmark/SeaquestNoFrameskip-v4/result.json new file mode 100644 index 0000000000000000000000000000000000000000..9d225cf2bed38901b09ef95e764b511ed60244fb --- /dev/null +++ b/examples/atari/benchmark/SeaquestNoFrameskip-v4/result.json @@ -0,0 +1 @@ +[{"env_step": 0, "rew": 32.2, "rew_std": 46.315872009495834, "Agent": "c51"}, {"env_step": 100000, "rew": 150.4, "rew_std": 44.16152171291203, "Agent": "c51"}, {"env_step": 200000, "rew": 128.6, "rew_std": 49.11252386102755, "Agent": "c51"}, {"env_step": 300000, "rew": 247.0, "rew_std": 99.90095094642493, "Agent": "c51"}, {"env_step": 400000, "rew": 316.8, "rew_std": 101.15809409038903, "Agent": "c51"}, {"env_step": 500000, "rew": 294.4, "rew_std": 154.54138604270378, "Agent": "c51"}, {"env_step": 600000, "rew": 319.4, "rew_std": 168.48513287527777, "Agent": "c51"}, {"env_step": 700000, "rew": 447.6, "rew_std": 228.1434636363707, "Agent": "c51"}, {"env_step": 800000, "rew": 584.0, "rew_std": 225.6138293633615, "Agent": "c51"}, {"env_step": 900000, "rew": 728.2, "rew_std": 275.2329195427029, "Agent": "c51"}, {"env_step": 1000000, "rew": 972.4, "rew_std": 346.5323072961596, "Agent": "c51"}, {"env_step": 1100000, "rew": 1153.0, "rew_std": 393.3764100705582, "Agent": "c51"}, {"env_step": 1200000, "rew": 1589.2, "rew_std": 267.4624459620453, "Agent": "c51"}, {"env_step": 1300000, "rew": 1583.4, "rew_std": 262.64508371564847, "Agent": "c51"}, {"env_step": 1400000, "rew": 1678.6, "rew_std": 221.38482332806828, "Agent": "c51"}, {"env_step": 1500000, "rew": 1636.2, "rew_std": 262.5969535238366, "Agent": "c51"}, {"env_step": 1600000, "rew": 1672.2, "rew_std": 191.08835652650322, "Agent": "c51"}, {"env_step": 1700000, "rew": 1610.6, "rew_std": 330.96591969566896, "Agent": "c51"}, {"env_step": 1800000, "rew": 1730.3, "rew_std": 421.8265164733009, "Agent": "c51"}, {"env_step": 1900000, "rew": 1915.1, "rew_std": 466.08463823644735, "Agent": "c51"}, {"env_step": 2000000, "rew": 1765.0, "rew_std": 223.29218526406157, "Agent": "c51"}, {"env_step": 2100000, "rew": 1774.2, "rew_std": 280.23483009790203, "Agent": "c51"}, {"env_step": 2200000, "rew": 1940.0, "rew_std": 288.7337874236405, "Agent": "c51"}, {"env_step": 2300000, "rew": 1953.0, "rew_std": 183.59793027155834, "Agent": "c51"}, {"env_step": 2400000, "rew": 1960.0, "rew_std": 160.45186193996005, "Agent": "c51"}, {"env_step": 2500000, "rew": 1876.8, "rew_std": 178.058866670548, "Agent": "c51"}, {"env_step": 2600000, "rew": 1938.3, "rew_std": 339.1395140646398, "Agent": "c51"}, {"env_step": 2700000, "rew": 2020.2, "rew_std": 316.60505365518094, "Agent": "c51"}, {"env_step": 2800000, "rew": 1987.6, "rew_std": 300.36351309704713, "Agent": "c51"}, {"env_step": 2900000, "rew": 1866.7, "rew_std": 246.95750646619348, "Agent": "c51"}, {"env_step": 3000000, "rew": 1934.9, "rew_std": 380.61383316952634, "Agent": "c51"}, {"env_step": 3100000, "rew": 2063.6, "rew_std": 397.0740988782824, "Agent": "c51"}, {"env_step": 3200000, "rew": 2049.0, "rew_std": 507.7331976540435, "Agent": "c51"}, {"env_step": 3300000, "rew": 2166.0, "rew_std": 523.4004203284518, "Agent": "c51"}, {"env_step": 3400000, "rew": 2154.2, "rew_std": 581.4135877325194, "Agent": "c51"}, {"env_step": 3500000, "rew": 2041.9, "rew_std": 658.6993927430024, "Agent": "c51"}, {"env_step": 3600000, "rew": 2267.9, "rew_std": 
511.7625328216203, "Agent": "c51"}, {"env_step": 3700000, "rew": 2240.1, "rew_std": 415.064440779983, "Agent": "c51"}, {"env_step": 3800000, "rew": 2300.7, "rew_std": 460.17693336367915, "Agent": "c51"}, {"env_step": 3900000, "rew": 2148.5, "rew_std": 605.8113980439787, "Agent": "c51"}, {"env_step": 4000000, "rew": 2083.7, "rew_std": 491.9182960614496, "Agent": "c51"}, {"env_step": 4100000, "rew": 2218.3, "rew_std": 504.38795584351533, "Agent": "c51"}, {"env_step": 4200000, "rew": 2268.6, "rew_std": 484.98725756456736, "Agent": "c51"}, {"env_step": 4300000, "rew": 2227.8, "rew_std": 526.6742446712199, "Agent": "c51"}, {"env_step": 4400000, "rew": 2411.4, "rew_std": 649.1550200067777, "Agent": "c51"}, {"env_step": 4500000, "rew": 2175.5, "rew_std": 458.5758933917046, "Agent": "c51"}, {"env_step": 4600000, "rew": 2318.9, "rew_std": 604.9889998999981, "Agent": "c51"}, {"env_step": 4700000, "rew": 2327.4, "rew_std": 395.89195495741006, "Agent": "c51"}, {"env_step": 4800000, "rew": 2369.6, "rew_std": 508.4897639087733, "Agent": "c51"}, {"env_step": 4900000, "rew": 2172.9, "rew_std": 587.6421445063313, "Agent": "c51"}, {"env_step": 5000000, "rew": 2279.5, "rew_std": 422.98540163934734, "Agent": "c51"}, {"env_step": 5100000, "rew": 2513.7, "rew_std": 757.7894232568834, "Agent": "c51"}, {"env_step": 5200000, "rew": 2347.0, "rew_std": 447.44988546204814, "Agent": "c51"}, {"env_step": 5300000, "rew": 2241.3, "rew_std": 527.0457380531599, "Agent": "c51"}, {"env_step": 5400000, "rew": 2434.3, "rew_std": 574.2675421787305, "Agent": "c51"}, {"env_step": 5500000, "rew": 2543.6, "rew_std": 672.8527624971157, "Agent": "c51"}, {"env_step": 5600000, "rew": 2479.0, "rew_std": 647.793794351258, "Agent": "c51"}, {"env_step": 5700000, "rew": 2461.7, "rew_std": 558.4800891705988, "Agent": "c51"}, {"env_step": 5800000, "rew": 2491.6, "rew_std": 587.2570476375741, "Agent": "c51"}, {"env_step": 5900000, "rew": 2470.1, "rew_std": 801.6780463502789, "Agent": "c51"}, {"env_step": 6000000, "rew": 2324.7, "rew_std": 429.707353904957, "Agent": "c51"}, {"env_step": 6100000, "rew": 2391.7, "rew_std": 472.14575080159307, "Agent": "c51"}, {"env_step": 6200000, "rew": 2537.3, "rew_std": 751.1694948545235, "Agent": "c51"}, {"env_step": 6300000, "rew": 2379.9, "rew_std": 653.9957874482068, "Agent": "c51"}, {"env_step": 6400000, "rew": 2658.8, "rew_std": 892.9185629160143, "Agent": "c51"}, {"env_step": 6500000, "rew": 2667.7, "rew_std": 955.2983879396008, "Agent": "c51"}, {"env_step": 6600000, "rew": 2439.5, "rew_std": 665.1733984458489, "Agent": "c51"}, {"env_step": 6700000, "rew": 2414.2, "rew_std": 540.237318222279, "Agent": "c51"}, {"env_step": 6800000, "rew": 2644.7, "rew_std": 830.3852178356741, "Agent": "c51"}, {"env_step": 6900000, "rew": 2430.4, "rew_std": 685.8056867655736, "Agent": "c51"}, {"env_step": 7000000, "rew": 2793.1, "rew_std": 1111.4479250059358, "Agent": "c51"}, {"env_step": 7100000, "rew": 2576.3, "rew_std": 875.817338261809, "Agent": "c51"}, {"env_step": 7200000, "rew": 2954.8, "rew_std": 1122.580758787536, "Agent": "c51"}, {"env_step": 7300000, "rew": 2826.6, "rew_std": 955.3017533742938, "Agent": "c51"}, {"env_step": 7400000, "rew": 2654.4, "rew_std": 863.3571914335341, "Agent": "c51"}, {"env_step": 7500000, "rew": 2642.0, "rew_std": 993.2044099781273, "Agent": "c51"}, {"env_step": 7600000, "rew": 2686.2, "rew_std": 1022.8722109823885, "Agent": "c51"}, {"env_step": 7700000, "rew": 2804.2, "rew_std": 1017.5520428951043, "Agent": "c51"}, {"env_step": 7800000, "rew": 2618.4, "rew_std": 754.956846448855, 
"Agent": "c51"}, {"env_step": 7900000, "rew": 2661.4, "rew_std": 895.5007761023996, "Agent": "c51"}, {"env_step": 8000000, "rew": 2549.6, "rew_std": 666.4779366190602, "Agent": "c51"}, {"env_step": 8100000, "rew": 2469.4, "rew_std": 947.9606742898145, "Agent": "c51"}, {"env_step": 8200000, "rew": 2505.2, "rew_std": 674.0418087923033, "Agent": "c51"}, {"env_step": 8300000, "rew": 2662.4, "rew_std": 742.0776509234057, "Agent": "c51"}, {"env_step": 8400000, "rew": 2667.8, "rew_std": 823.5222887086907, "Agent": "c51"}, {"env_step": 8500000, "rew": 2946.4, "rew_std": 1133.8231960936414, "Agent": "c51"}, {"env_step": 8600000, "rew": 2712.2, "rew_std": 895.4735953672783, "Agent": "c51"}, {"env_step": 8700000, "rew": 2573.9, "rew_std": 830.3705739005928, "Agent": "c51"}, {"env_step": 8800000, "rew": 2695.3, "rew_std": 1011.963146562166, "Agent": "c51"}, {"env_step": 8900000, "rew": 2988.3, "rew_std": 1189.3779929021725, "Agent": "c51"}, {"env_step": 9000000, "rew": 3090.7, "rew_std": 1242.8095630465675, "Agent": "c51"}, {"env_step": 9100000, "rew": 2933.6, "rew_std": 1181.3922464617751, "Agent": "c51"}, {"env_step": 9200000, "rew": 2749.2, "rew_std": 1097.8828534957636, "Agent": "c51"}, {"env_step": 9300000, "rew": 2900.2, "rew_std": 1171.506961140223, "Agent": "c51"}, {"env_step": 9400000, "rew": 2628.8, "rew_std": 762.4866949658859, "Agent": "c51"}, {"env_step": 9500000, "rew": 2926.8, "rew_std": 1080.2360667928099, "Agent": "c51"}, {"env_step": 9600000, "rew": 2832.0, "rew_std": 1037.878316567024, "Agent": "c51"}, {"env_step": 9700000, "rew": 3305.4, "rew_std": 1524.3043790529503, "Agent": "c51"}, {"env_step": 9800000, "rew": 2810.2, "rew_std": 1373.4536613952434, "Agent": "c51"}, {"env_step": 9900000, "rew": 2678.6, "rew_std": 794.8096879127733, "Agent": "c51"}, {"env_step": 10000000, "rew": 2879.0, "rew_std": 1374.2044243852513, "Agent": "c51"}, {"env_step": 0, "rew": 67.6, "rew_std": 52.75452587219413, "Agent": "dqn"}, {"env_step": 100000, "rew": 221.0, "rew_std": 43.148580509676094, "Agent": "dqn"}, {"env_step": 200000, "rew": 284.2, "rew_std": 53.54960317313286, "Agent": "dqn"}, {"env_step": 300000, "rew": 255.8, "rew_std": 84.65671857566888, "Agent": "dqn"}, {"env_step": 400000, "rew": 283.4, "rew_std": 66.36294146585125, "Agent": "dqn"}, {"env_step": 500000, "rew": 266.6, "rew_std": 59.49151199961218, "Agent": "dqn"}, {"env_step": 600000, "rew": 290.4, "rew_std": 82.25715774326268, "Agent": "dqn"}, {"env_step": 700000, "rew": 346.2, "rew_std": 104.86353036208536, "Agent": "dqn"}, {"env_step": 800000, "rew": 407.4, "rew_std": 125.18801859603019, "Agent": "dqn"}, {"env_step": 900000, "rew": 519.8, "rew_std": 125.2436026310326, "Agent": "dqn"}, {"env_step": 1000000, "rew": 500.8, "rew_std": 113.11304080432106, "Agent": "dqn"}, {"env_step": 1100000, "rew": 857.6, "rew_std": 217.36200219909645, "Agent": "dqn"}, {"env_step": 1200000, "rew": 909.4, "rew_std": 323.0765234429763, "Agent": "dqn"}, {"env_step": 1300000, "rew": 1074.6, "rew_std": 300.64337677720425, "Agent": "dqn"}, {"env_step": 1400000, "rew": 1264.8, "rew_std": 315.2956707600027, "Agent": "dqn"}, {"env_step": 1500000, "rew": 1273.2, "rew_std": 365.40629441759756, "Agent": "dqn"}, {"env_step": 1600000, "rew": 1206.4, "rew_std": 444.22498804097006, "Agent": "dqn"}, {"env_step": 1700000, "rew": 1501.1, "rew_std": 372.9220964222957, "Agent": "dqn"}, {"env_step": 1800000, "rew": 1625.4, "rew_std": 438.0347474801514, "Agent": "dqn"}, {"env_step": 1900000, "rew": 1565.2, "rew_std": 472.6666478608365, "Agent": "dqn"}, {"env_step": 
2000000, "rew": 1754.0, "rew_std": 312.49191989553907, "Agent": "dqn"}, {"env_step": 2100000, "rew": 1821.2, "rew_std": 352.37389233596747, "Agent": "dqn"}, {"env_step": 2200000, "rew": 1993.0, "rew_std": 548.2636227217706, "Agent": "dqn"}, {"env_step": 2300000, "rew": 1839.2, "rew_std": 397.0201002468263, "Agent": "dqn"}, {"env_step": 2400000, "rew": 2161.7, "rew_std": 328.52276937831874, "Agent": "dqn"}, {"env_step": 2500000, "rew": 2045.4, "rew_std": 817.1477467386177, "Agent": "dqn"}, {"env_step": 2600000, "rew": 1969.4, "rew_std": 666.9255130822332, "Agent": "dqn"}, {"env_step": 2700000, "rew": 2051.1, "rew_std": 532.0207608731073, "Agent": "dqn"}, {"env_step": 2800000, "rew": 2071.5, "rew_std": 529.4195406291686, "Agent": "dqn"}, {"env_step": 2900000, "rew": 1928.8, "rew_std": 357.8560604488905, "Agent": "dqn"}, {"env_step": 3000000, "rew": 2327.6, "rew_std": 610.7462975737143, "Agent": "dqn"}, {"env_step": 3100000, "rew": 2295.2, "rew_std": 519.696988638572, "Agent": "dqn"}, {"env_step": 3200000, "rew": 1959.3, "rew_std": 419.5507239893646, "Agent": "dqn"}, {"env_step": 3300000, "rew": 2432.6, "rew_std": 510.47276127135325, "Agent": "dqn"}, {"env_step": 3400000, "rew": 2435.4, "rew_std": 451.0752043728407, "Agent": "dqn"}, {"env_step": 3500000, "rew": 2519.0, "rew_std": 417.8040210433595, "Agent": "dqn"}, {"env_step": 3600000, "rew": 2485.6, "rew_std": 568.3300449562736, "Agent": "dqn"}, {"env_step": 3700000, "rew": 2359.6, "rew_std": 628.5970410366247, "Agent": "dqn"}, {"env_step": 3800000, "rew": 2478.4, "rew_std": 378.3150010242787, "Agent": "dqn"}, {"env_step": 3900000, "rew": 2657.6, "rew_std": 525.0781275200862, "Agent": "dqn"}, {"env_step": 4000000, "rew": 2616.8, "rew_std": 352.63715062369704, "Agent": "dqn"}, {"env_step": 4100000, "rew": 2332.2, "rew_std": 373.4396336759129, "Agent": "dqn"}, {"env_step": 4200000, "rew": 2553.5, "rew_std": 363.4631343066309, "Agent": "dqn"}, {"env_step": 4300000, "rew": 2390.0, "rew_std": 644.888207366207, "Agent": "dqn"}, {"env_step": 4400000, "rew": 2727.2, "rew_std": 635.8051273778783, "Agent": "dqn"}, {"env_step": 4500000, "rew": 2780.6, "rew_std": 470.3411953040048, "Agent": "dqn"}, {"env_step": 4600000, "rew": 2597.2, "rew_std": 523.4158576122813, "Agent": "dqn"}, {"env_step": 4700000, "rew": 2602.5, "rew_std": 744.5147748701835, "Agent": "dqn"}, {"env_step": 4800000, "rew": 2417.0, "rew_std": 749.8470510710835, "Agent": "dqn"}, {"env_step": 4900000, "rew": 2945.2, "rew_std": 587.9389083909995, "Agent": "dqn"}, {"env_step": 5000000, "rew": 2675.3, "rew_std": 784.2958689168265, "Agent": "dqn"}, {"env_step": 5100000, "rew": 2855.4, "rew_std": 1029.4697858606633, "Agent": "dqn"}, {"env_step": 5200000, "rew": 2557.8, "rew_std": 957.7696800379516, "Agent": "dqn"}, {"env_step": 5300000, "rew": 2583.4, "rew_std": 686.686566054703, "Agent": "dqn"}, {"env_step": 5400000, "rew": 2643.4, "rew_std": 625.6280364561678, "Agent": "dqn"}, {"env_step": 5500000, "rew": 2624.8, "rew_std": 485.9606568437408, "Agent": "dqn"}, {"env_step": 5600000, "rew": 2627.2, "rew_std": 482.2772646517769, "Agent": "dqn"}, {"env_step": 5700000, "rew": 2659.2, "rew_std": 828.4439389602655, "Agent": "dqn"}, {"env_step": 5800000, "rew": 2599.4, "rew_std": 550.9120074930297, "Agent": "dqn"}, {"env_step": 5900000, "rew": 2938.3, "rew_std": 744.541207724596, "Agent": "dqn"}, {"env_step": 6000000, "rew": 2851.0, "rew_std": 557.9026796852655, "Agent": "dqn"}, {"env_step": 6100000, "rew": 2454.4, "rew_std": 921.5052034579078, "Agent": "dqn"}, {"env_step": 6200000, "rew": 2610.6, 
"rew_std": 869.3878536073529, "Agent": "dqn"}, {"env_step": 6300000, "rew": 2773.0, "rew_std": 432.34268815373764, "Agent": "dqn"}, {"env_step": 6400000, "rew": 2506.0, "rew_std": 803.5655542642429, "Agent": "dqn"}, {"env_step": 6500000, "rew": 2808.7, "rew_std": 689.5932206743335, "Agent": "dqn"}, {"env_step": 6600000, "rew": 2985.2, "rew_std": 595.4958941923949, "Agent": "dqn"}, {"env_step": 6700000, "rew": 2698.0, "rew_std": 634.7957151714243, "Agent": "dqn"}, {"env_step": 6800000, "rew": 2821.2, "rew_std": 647.6642339978331, "Agent": "dqn"}, {"env_step": 6900000, "rew": 2988.2, "rew_std": 699.9722565930738, "Agent": "dqn"}, {"env_step": 7000000, "rew": 2854.4, "rew_std": 386.69864235603416, "Agent": "dqn"}, {"env_step": 7100000, "rew": 2749.0, "rew_std": 739.4631836677199, "Agent": "dqn"}, {"env_step": 7200000, "rew": 2854.4, "rew_std": 721.7993072870048, "Agent": "dqn"}, {"env_step": 7300000, "rew": 2570.2, "rew_std": 562.2785430727372, "Agent": "dqn"}, {"env_step": 7400000, "rew": 2909.4, "rew_std": 663.843385144418, "Agent": "dqn"}, {"env_step": 7500000, "rew": 2631.1, "rew_std": 731.3366461486802, "Agent": "dqn"}, {"env_step": 7600000, "rew": 2852.2, "rew_std": 665.8404914091662, "Agent": "dqn"}, {"env_step": 7700000, "rew": 2876.5, "rew_std": 423.60199480172423, "Agent": "dqn"}, {"env_step": 7800000, "rew": 2636.4, "rew_std": 778.3766697428694, "Agent": "dqn"}, {"env_step": 7900000, "rew": 2651.3, "rew_std": 599.0412423197588, "Agent": "dqn"}, {"env_step": 8000000, "rew": 2770.2, "rew_std": 600.3961692082987, "Agent": "dqn"}, {"env_step": 8100000, "rew": 2965.0, "rew_std": 660.43697655416, "Agent": "dqn"}, {"env_step": 8200000, "rew": 2998.4, "rew_std": 484.83795230984134, "Agent": "dqn"}, {"env_step": 8300000, "rew": 2604.2, "rew_std": 553.038479673883, "Agent": "dqn"}, {"env_step": 8400000, "rew": 2286.0, "rew_std": 568.7192629056976, "Agent": "dqn"}, {"env_step": 8500000, "rew": 2715.2, "rew_std": 530.8809282692307, "Agent": "dqn"}, {"env_step": 8600000, "rew": 2736.2, "rew_std": 531.8183524475251, "Agent": "dqn"}, {"env_step": 8700000, "rew": 2767.8, "rew_std": 546.1792379796215, "Agent": "dqn"}, {"env_step": 8800000, "rew": 2634.8, "rew_std": 725.9069912874514, "Agent": "dqn"}, {"env_step": 8900000, "rew": 2286.2, "rew_std": 622.9314247973047, "Agent": "dqn"}, {"env_step": 9000000, "rew": 2815.0, "rew_std": 796.1378021423176, "Agent": "dqn"}, {"env_step": 9100000, "rew": 2723.2, "rew_std": 613.141549725673, "Agent": "dqn"}, {"env_step": 9200000, "rew": 2820.4, "rew_std": 687.1970896329524, "Agent": "dqn"}, {"env_step": 9300000, "rew": 2704.2, "rew_std": 625.2215287400139, "Agent": "dqn"}, {"env_step": 9400000, "rew": 2331.2, "rew_std": 761.4608066079304, "Agent": "dqn"}, {"env_step": 9500000, "rew": 2712.7, "rew_std": 589.839308625663, "Agent": "dqn"}, {"env_step": 9600000, "rew": 2890.0, "rew_std": 690.3222435935264, "Agent": "dqn"}, {"env_step": 9700000, "rew": 2330.1, "rew_std": 573.4458038908298, "Agent": "dqn"}, {"env_step": 9800000, "rew": 2720.6, "rew_std": 1005.5040725924486, "Agent": "dqn"}, {"env_step": 9900000, "rew": 3213.9, "rew_std": 381.56741213054346, "Agent": "dqn"}, {"env_step": 10000000, "rew": 2365.6, "rew_std": 703.0867940731073, "Agent": "dqn"}, {"env_step": 0, "rew": 84.0, "rew_std": 39.97999499749844, "Agent": "fqf"}, {"env_step": 100000, "rew": 235.4, "rew_std": 48.70359329659363, "Agent": "fqf"}, {"env_step": 200000, "rew": 270.2, "rew_std": 64.16509954796298, "Agent": "fqf"}, {"env_step": 300000, "rew": 268.0, "rew_std": 43.174066289845804, "Agent": 
"fqf"}, {"env_step": 400000, "rew": 273.4, "rew_std": 98.09403651598807, "Agent": "fqf"}, {"env_step": 500000, "rew": 311.6, "rew_std": 46.21514903145937, "Agent": "fqf"}, {"env_step": 600000, "rew": 390.6, "rew_std": 107.00299061241232, "Agent": "fqf"}, {"env_step": 700000, "rew": 513.8, "rew_std": 207.3131930196436, "Agent": "fqf"}, {"env_step": 800000, "rew": 677.2, "rew_std": 171.06536762302298, "Agent": "fqf"}, {"env_step": 900000, "rew": 902.2, "rew_std": 367.83088505453156, "Agent": "fqf"}, {"env_step": 1000000, "rew": 1180.2, "rew_std": 368.2884195844338, "Agent": "fqf"}, {"env_step": 1100000, "rew": 1722.4, "rew_std": 516.4639774466366, "Agent": "fqf"}, {"env_step": 1200000, "rew": 2106.6, "rew_std": 570.7236108660653, "Agent": "fqf"}, {"env_step": 1300000, "rew": 2475.0, "rew_std": 793.9652385337787, "Agent": "fqf"}, {"env_step": 1400000, "rew": 2825.5, "rew_std": 731.4583036646724, "Agent": "fqf"}, {"env_step": 1500000, "rew": 3100.2, "rew_std": 422.5548011796813, "Agent": "fqf"}, {"env_step": 1600000, "rew": 3458.9, "rew_std": 911.015197458308, "Agent": "fqf"}, {"env_step": 1700000, "rew": 3497.4, "rew_std": 772.6115712309776, "Agent": "fqf"}, {"env_step": 1800000, "rew": 3650.5, "rew_std": 925.8434262876202, "Agent": "fqf"}, {"env_step": 1900000, "rew": 3701.9, "rew_std": 668.0494667313193, "Agent": "fqf"}, {"env_step": 2000000, "rew": 3597.7, "rew_std": 658.9843776600474, "Agent": "fqf"}, {"env_step": 2100000, "rew": 3653.4, "rew_std": 609.672239814148, "Agent": "fqf"}, {"env_step": 2200000, "rew": 4249.8, "rew_std": 837.1662678345324, "Agent": "fqf"}, {"env_step": 2300000, "rew": 4032.9, "rew_std": 788.8272878140056, "Agent": "fqf"}, {"env_step": 2400000, "rew": 4410.0, "rew_std": 802.1447500295691, "Agent": "fqf"}, {"env_step": 2500000, "rew": 4966.7, "rew_std": 1177.294359962707, "Agent": "fqf"}, {"env_step": 2600000, "rew": 4576.1, "rew_std": 841.7940900243955, "Agent": "fqf"}, {"env_step": 2700000, "rew": 5155.4, "rew_std": 1126.5631984047766, "Agent": "fqf"}, {"env_step": 2800000, "rew": 5071.3, "rew_std": 472.3333674429534, "Agent": "fqf"}, {"env_step": 2900000, "rew": 4688.0, "rew_std": 717.3926400514574, "Agent": "fqf"}, {"env_step": 3000000, "rew": 4985.2, "rew_std": 726.9461878296082, "Agent": "fqf"}, {"env_step": 3100000, "rew": 4975.1, "rew_std": 585.9778920744366, "Agent": "fqf"}, {"env_step": 3200000, "rew": 4920.8, "rew_std": 1034.2605861193783, "Agent": "fqf"}, {"env_step": 3300000, "rew": 5047.4, "rew_std": 724.611233697077, "Agent": "fqf"}, {"env_step": 3400000, "rew": 5616.9, "rew_std": 1700.5169478720288, "Agent": "fqf"}, {"env_step": 3500000, "rew": 5794.7, "rew_std": 1492.3058031114133, "Agent": "fqf"}, {"env_step": 3600000, "rew": 5340.5, "rew_std": 1678.8342532841054, "Agent": "fqf"}, {"env_step": 3700000, "rew": 5262.5, "rew_std": 1011.5026692994933, "Agent": "fqf"}, {"env_step": 3800000, "rew": 5265.4, "rew_std": 708.2388297742507, "Agent": "fqf"}, {"env_step": 3900000, "rew": 5469.6, "rew_std": 858.3154664807107, "Agent": "fqf"}, {"env_step": 4000000, "rew": 6005.2, "rew_std": 1882.1700667049192, "Agent": "fqf"}, {"env_step": 4100000, "rew": 5602.9, "rew_std": 1134.739304862575, "Agent": "fqf"}, {"env_step": 4200000, "rew": 5792.3, "rew_std": 707.0452672919888, "Agent": "fqf"}, {"env_step": 4300000, "rew": 5279.0, "rew_std": 1276.4659807452763, "Agent": "fqf"}, {"env_step": 4400000, "rew": 5126.0, "rew_std": 1668.219589862198, "Agent": "fqf"}, {"env_step": 4500000, "rew": 5870.5, "rew_std": 1084.202771625308, "Agent": "fqf"}, {"env_step": 4600000, 
"rew": 5440.8, "rew_std": 1613.5675876764506, "Agent": "fqf"}, {"env_step": 4700000, "rew": 5901.3, "rew_std": 586.3447876463131, "Agent": "fqf"}, {"env_step": 4800000, "rew": 5909.6, "rew_std": 1153.2979840440196, "Agent": "fqf"}, {"env_step": 4900000, "rew": 6558.0, "rew_std": 1928.9374277046936, "Agent": "fqf"}, {"env_step": 5000000, "rew": 6140.0, "rew_std": 1449.9375848635693, "Agent": "fqf"}, {"env_step": 5100000, "rew": 6061.0, "rew_std": 844.278626994667, "Agent": "fqf"}, {"env_step": 5200000, "rew": 5817.9, "rew_std": 983.7778661872811, "Agent": "fqf"}, {"env_step": 5300000, "rew": 6269.0, "rew_std": 660.6495288729116, "Agent": "fqf"}, {"env_step": 5400000, "rew": 5512.1, "rew_std": 1459.216464408211, "Agent": "fqf"}, {"env_step": 5500000, "rew": 5616.9, "rew_std": 1634.4490478445637, "Agent": "fqf"}, {"env_step": 5600000, "rew": 6840.6, "rew_std": 1181.8537303744486, "Agent": "fqf"}, {"env_step": 5700000, "rew": 6313.4, "rew_std": 1765.277666544275, "Agent": "fqf"}, {"env_step": 5800000, "rew": 6400.5, "rew_std": 2038.3985012749592, "Agent": "fqf"}, {"env_step": 5900000, "rew": 6898.0, "rew_std": 1592.9492145074807, "Agent": "fqf"}, {"env_step": 6000000, "rew": 6413.2, "rew_std": 2133.357813401212, "Agent": "fqf"}, {"env_step": 6100000, "rew": 6410.2, "rew_std": 1778.3248184738356, "Agent": "fqf"}, {"env_step": 6200000, "rew": 6357.5, "rew_std": 1512.9386140884897, "Agent": "fqf"}, {"env_step": 6300000, "rew": 6276.0, "rew_std": 815.767613968586, "Agent": "fqf"}, {"env_step": 6400000, "rew": 6026.5, "rew_std": 1442.6516038184686, "Agent": "fqf"}, {"env_step": 6500000, "rew": 6285.0, "rew_std": 1306.420376448561, "Agent": "fqf"}, {"env_step": 6600000, "rew": 6946.2, "rew_std": 1895.469588255111, "Agent": "fqf"}, {"env_step": 6700000, "rew": 6952.1, "rew_std": 1505.6863185936172, "Agent": "fqf"}, {"env_step": 6800000, "rew": 6325.1, "rew_std": 1797.747893893913, "Agent": "fqf"}, {"env_step": 6900000, "rew": 6713.2, "rew_std": 1581.9958154179803, "Agent": "fqf"}, {"env_step": 7000000, "rew": 6725.5, "rew_std": 1307.073238192872, "Agent": "fqf"}, {"env_step": 7100000, "rew": 6847.9, "rew_std": 1273.460596170922, "Agent": "fqf"}, {"env_step": 7200000, "rew": 7050.2, "rew_std": 1556.5933187573432, "Agent": "fqf"}, {"env_step": 7300000, "rew": 6831.8, "rew_std": 1364.3794047111676, "Agent": "fqf"}, {"env_step": 7400000, "rew": 6303.4, "rew_std": 1708.135252256097, "Agent": "fqf"}, {"env_step": 7500000, "rew": 7570.5, "rew_std": 2275.7164695980914, "Agent": "fqf"}, {"env_step": 7600000, "rew": 7652.3, "rew_std": 2182.971646632177, "Agent": "fqf"}, {"env_step": 7700000, "rew": 7493.9, "rew_std": 2103.570604947692, "Agent": "fqf"}, {"env_step": 7800000, "rew": 7694.6, "rew_std": 2340.8724954597587, "Agent": "fqf"}, {"env_step": 7900000, "rew": 6932.5, "rew_std": 1200.7625285625797, "Agent": "fqf"}, {"env_step": 8000000, "rew": 7276.4, "rew_std": 1941.6771204296558, "Agent": "fqf"}, {"env_step": 8100000, "rew": 6880.9, "rew_std": 1708.650546483979, "Agent": "fqf"}, {"env_step": 8200000, "rew": 6877.8, "rew_std": 1889.6905460947833, "Agent": "fqf"}, {"env_step": 8300000, "rew": 6632.9, "rew_std": 1580.6722905143872, "Agent": "fqf"}, {"env_step": 8400000, "rew": 7083.5, "rew_std": 1896.2999894531456, "Agent": "fqf"}, {"env_step": 8500000, "rew": 6696.8, "rew_std": 2655.081648462058, "Agent": "fqf"}, {"env_step": 8600000, "rew": 7298.9, "rew_std": 2318.328382693013, "Agent": "fqf"}, {"env_step": 8700000, "rew": 6763.7, "rew_std": 1158.4843589794382, "Agent": "fqf"}, {"env_step": 8800000, 
"rew": 7196.0, "rew_std": 1865.8348265588784, "Agent": "fqf"}, {"env_step": 8900000, "rew": 6880.7, "rew_std": 1600.4205728495244, "Agent": "fqf"}, {"env_step": 9000000, "rew": 7794.2, "rew_std": 2350.0790965412207, "Agent": "fqf"}, {"env_step": 9100000, "rew": 7289.3, "rew_std": 1832.8727206219203, "Agent": "fqf"}, {"env_step": 9200000, "rew": 6713.8, "rew_std": 1709.8212070272143, "Agent": "fqf"}, {"env_step": 9300000, "rew": 7391.5, "rew_std": 2495.9832631650397, "Agent": "fqf"}, {"env_step": 9400000, "rew": 7061.4, "rew_std": 1013.7232561207226, "Agent": "fqf"}, {"env_step": 9500000, "rew": 7424.5, "rew_std": 2155.5881911905158, "Agent": "fqf"}, {"env_step": 9600000, "rew": 7426.3, "rew_std": 1927.6538615633253, "Agent": "fqf"}, {"env_step": 9700000, "rew": 7352.0, "rew_std": 1948.6867372669215, "Agent": "fqf"}, {"env_step": 9800000, "rew": 7327.9, "rew_std": 1429.7993880261663, "Agent": "fqf"}, {"env_step": 9900000, "rew": 8051.5, "rew_std": 3155.5843912023647, "Agent": "fqf"}, {"env_step": 10000000, "rew": 6903.5, "rew_std": 1400.5262046816547, "Agent": "fqf"}, {"env_step": 0, "rew": 45.4, "rew_std": 52.91540418441496, "Agent": "qrdqn"}, {"env_step": 100000, "rew": 200.0, "rew_std": 31.41973901864877, "Agent": "qrdqn"}, {"env_step": 200000, "rew": 289.4, "rew_std": 83.21562353308423, "Agent": "qrdqn"}, {"env_step": 300000, "rew": 258.4, "rew_std": 86.97493891920823, "Agent": "qrdqn"}, {"env_step": 400000, "rew": 267.0, "rew_std": 81.28591513909406, "Agent": "qrdqn"}, {"env_step": 500000, "rew": 300.6, "rew_std": 89.50553055537965, "Agent": "qrdqn"}, {"env_step": 600000, "rew": 325.2, "rew_std": 81.07379354637354, "Agent": "qrdqn"}, {"env_step": 700000, "rew": 408.6, "rew_std": 74.08670595997638, "Agent": "qrdqn"}, {"env_step": 800000, "rew": 465.0, "rew_std": 195.36273953853123, "Agent": "qrdqn"}, {"env_step": 900000, "rew": 629.4, "rew_std": 227.03488718696954, "Agent": "qrdqn"}, {"env_step": 1000000, "rew": 899.2, "rew_std": 221.7118851121879, "Agent": "qrdqn"}, {"env_step": 1100000, "rew": 1039.5, "rew_std": 408.2810918962572, "Agent": "qrdqn"}, {"env_step": 1200000, "rew": 1266.2, "rew_std": 453.6681165786284, "Agent": "qrdqn"}, {"env_step": 1300000, "rew": 1240.1, "rew_std": 317.7037771258, "Agent": "qrdqn"}, {"env_step": 1400000, "rew": 1547.9, "rew_std": 501.95188016382605, "Agent": "qrdqn"}, {"env_step": 1500000, "rew": 1760.2, "rew_std": 333.3532060742779, "Agent": "qrdqn"}, {"env_step": 1600000, "rew": 1911.5, "rew_std": 621.3912213734596, "Agent": "qrdqn"}, {"env_step": 1700000, "rew": 1998.1, "rew_std": 404.5299618075279, "Agent": "qrdqn"}, {"env_step": 1800000, "rew": 2403.4, "rew_std": 561.4866338569423, "Agent": "qrdqn"}, {"env_step": 1900000, "rew": 2352.4, "rew_std": 371.63858787806197, "Agent": "qrdqn"}, {"env_step": 2000000, "rew": 2128.1, "rew_std": 730.9558741811984, "Agent": "qrdqn"}, {"env_step": 2100000, "rew": 2500.0, "rew_std": 809.2574374078993, "Agent": "qrdqn"}, {"env_step": 2200000, "rew": 2503.2, "rew_std": 550.9215552145332, "Agent": "qrdqn"}, {"env_step": 2300000, "rew": 2622.2, "rew_std": 507.82079516301815, "Agent": "qrdqn"}, {"env_step": 2400000, "rew": 2551.8, "rew_std": 607.7318158530126, "Agent": "qrdqn"}, {"env_step": 2500000, "rew": 2391.6, "rew_std": 668.416516851581, "Agent": "qrdqn"}, {"env_step": 2600000, "rew": 2284.1, "rew_std": 935.7713876797045, "Agent": "qrdqn"}, {"env_step": 2700000, "rew": 2470.0, "rew_std": 762.4539330346457, "Agent": "qrdqn"}, {"env_step": 2800000, "rew": 2389.0, "rew_std": 905.0358003968684, "Agent": "qrdqn"}, 
{"env_step": 2900000, "rew": 2890.4, "rew_std": 717.7201683107421, "Agent": "qrdqn"}, {"env_step": 3000000, "rew": 2774.2, "rew_std": 525.8086724275284, "Agent": "qrdqn"}, {"env_step": 3100000, "rew": 2885.2, "rew_std": 427.62993347051844, "Agent": "qrdqn"}, {"env_step": 3200000, "rew": 2853.4, "rew_std": 634.3021677402656, "Agent": "qrdqn"}, {"env_step": 3300000, "rew": 2818.2, "rew_std": 437.66238129407463, "Agent": "qrdqn"}, {"env_step": 3400000, "rew": 3153.4, "rew_std": 560.2157084552342, "Agent": "qrdqn"}, {"env_step": 3500000, "rew": 2667.6, "rew_std": 998.5805125276579, "Agent": "qrdqn"}, {"env_step": 3600000, "rew": 3060.6, "rew_std": 483.220901865803, "Agent": "qrdqn"}, {"env_step": 3700000, "rew": 2940.4, "rew_std": 498.1427907738905, "Agent": "qrdqn"}, {"env_step": 3800000, "rew": 3141.6, "rew_std": 600.0958590092087, "Agent": "qrdqn"}, {"env_step": 3900000, "rew": 3165.2, "rew_std": 600.9986356057724, "Agent": "qrdqn"}, {"env_step": 4000000, "rew": 2781.6, "rew_std": 783.5125014956686, "Agent": "qrdqn"}, {"env_step": 4100000, "rew": 3374.4, "rew_std": 895.0686230675277, "Agent": "qrdqn"}, {"env_step": 4200000, "rew": 2629.0, "rew_std": 847.0873626728237, "Agent": "qrdqn"}, {"env_step": 4300000, "rew": 3079.4, "rew_std": 804.6664153548352, "Agent": "qrdqn"}, {"env_step": 4400000, "rew": 3388.8, "rew_std": 672.935182614195, "Agent": "qrdqn"}, {"env_step": 4500000, "rew": 3347.8, "rew_std": 759.8049486545872, "Agent": "qrdqn"}, {"env_step": 4600000, "rew": 3110.0, "rew_std": 800.3929035167666, "Agent": "qrdqn"}, {"env_step": 4700000, "rew": 3388.4, "rew_std": 840.9579299822316, "Agent": "qrdqn"}, {"env_step": 4800000, "rew": 3641.4, "rew_std": 761.8887320337531, "Agent": "qrdqn"}, {"env_step": 4900000, "rew": 3562.0, "rew_std": 694.3111694334176, "Agent": "qrdqn"}, {"env_step": 5000000, "rew": 3529.8, "rew_std": 537.6299470825635, "Agent": "qrdqn"}, {"env_step": 5100000, "rew": 3322.8, "rew_std": 854.5160969812096, "Agent": "qrdqn"}, {"env_step": 5200000, "rew": 3274.0, "rew_std": 1038.5505283807813, "Agent": "qrdqn"}, {"env_step": 5300000, "rew": 3571.0, "rew_std": 705.0641105601675, "Agent": "qrdqn"}, {"env_step": 5400000, "rew": 3157.2, "rew_std": 1001.435449742019, "Agent": "qrdqn"}, {"env_step": 5500000, "rew": 3315.6, "rew_std": 1095.1945215348735, "Agent": "qrdqn"}, {"env_step": 5600000, "rew": 3545.9, "rew_std": 659.3630942053096, "Agent": "qrdqn"}, {"env_step": 5700000, "rew": 3607.2, "rew_std": 404.7084876797125, "Agent": "qrdqn"}, {"env_step": 5800000, "rew": 3753.6, "rew_std": 658.8455357669201, "Agent": "qrdqn"}, {"env_step": 5900000, "rew": 3261.0, "rew_std": 661.858746259351, "Agent": "qrdqn"}, {"env_step": 6000000, "rew": 3644.2, "rew_std": 767.9252307353887, "Agent": "qrdqn"}, {"env_step": 6100000, "rew": 3731.4, "rew_std": 678.6693156464347, "Agent": "qrdqn"}, {"env_step": 6200000, "rew": 4187.6, "rew_std": 725.6907330261287, "Agent": "qrdqn"}, {"env_step": 6300000, "rew": 3814.6, "rew_std": 838.6646767331983, "Agent": "qrdqn"}, {"env_step": 6400000, "rew": 3318.2, "rew_std": 769.7981293819828, "Agent": "qrdqn"}, {"env_step": 6500000, "rew": 3726.2, "rew_std": 650.0095076227732, "Agent": "qrdqn"}, {"env_step": 6600000, "rew": 3536.0, "rew_std": 672.9190144437888, "Agent": "qrdqn"}, {"env_step": 6700000, "rew": 3278.2, "rew_std": 785.3837024028446, "Agent": "qrdqn"}, {"env_step": 6800000, "rew": 3081.8, "rew_std": 919.7605992865751, "Agent": "qrdqn"}, {"env_step": 6900000, "rew": 3286.2, "rew_std": 544.497897149291, "Agent": "qrdqn"}, {"env_step": 7000000, "rew": 
3537.8, "rew_std": 511.0807763944952, "Agent": "qrdqn"}, {"env_step": 7100000, "rew": 3516.8, "rew_std": 652.3656643325122, "Agent": "qrdqn"}, {"env_step": 7200000, "rew": 3074.8, "rew_std": 885.5579935837065, "Agent": "qrdqn"}, {"env_step": 7300000, "rew": 3015.6, "rew_std": 826.9422228910554, "Agent": "qrdqn"}, {"env_step": 7400000, "rew": 3113.0, "rew_std": 853.1176941079115, "Agent": "qrdqn"}, {"env_step": 7500000, "rew": 3382.6, "rew_std": 425.42407078114417, "Agent": "qrdqn"}, {"env_step": 7600000, "rew": 3832.2, "rew_std": 956.4151609003278, "Agent": "qrdqn"}, {"env_step": 7700000, "rew": 3565.6, "rew_std": 824.5613621799168, "Agent": "qrdqn"}, {"env_step": 7800000, "rew": 3260.4, "rew_std": 1026.301242326053, "Agent": "qrdqn"}, {"env_step": 7900000, "rew": 3165.4, "rew_std": 1141.938019333799, "Agent": "qrdqn"}, {"env_step": 8000000, "rew": 3833.2, "rew_std": 872.2041962751612, "Agent": "qrdqn"}, {"env_step": 8100000, "rew": 3275.8, "rew_std": 543.5265954854463, "Agent": "qrdqn"}, {"env_step": 8200000, "rew": 3510.4, "rew_std": 1006.9696321140971, "Agent": "qrdqn"}, {"env_step": 8300000, "rew": 3475.6, "rew_std": 1033.5646278777153, "Agent": "qrdqn"}, {"env_step": 8400000, "rew": 3522.6, "rew_std": 554.0274722430288, "Agent": "qrdqn"}, {"env_step": 8500000, "rew": 3770.6, "rew_std": 639.6787005989804, "Agent": "qrdqn"}, {"env_step": 8600000, "rew": 3319.4, "rew_std": 514.5678186595038, "Agent": "qrdqn"}, {"env_step": 8700000, "rew": 3270.6, "rew_std": 850.6858644646684, "Agent": "qrdqn"}, {"env_step": 8800000, "rew": 3856.0, "rew_std": 599.3169445293533, "Agent": "qrdqn"}, {"env_step": 8900000, "rew": 3440.7, "rew_std": 711.9547808674369, "Agent": "qrdqn"}, {"env_step": 9000000, "rew": 3568.7, "rew_std": 857.8783188774502, "Agent": "qrdqn"}, {"env_step": 9100000, "rew": 3740.8, "rew_std": 602.7020491088444, "Agent": "qrdqn"}, {"env_step": 9200000, "rew": 3701.8, "rew_std": 647.5816241988341, "Agent": "qrdqn"}, {"env_step": 9300000, "rew": 3148.5, "rew_std": 721.4049140392655, "Agent": "qrdqn"}, {"env_step": 9400000, "rew": 3532.6, "rew_std": 894.7174079003939, "Agent": "qrdqn"}, {"env_step": 9500000, "rew": 3562.2, "rew_std": 658.007568345532, "Agent": "qrdqn"}, {"env_step": 9600000, "rew": 3524.8, "rew_std": 867.9291215300937, "Agent": "qrdqn"}, {"env_step": 9700000, "rew": 3570.2, "rew_std": 838.2216651936408, "Agent": "qrdqn"}, {"env_step": 9800000, "rew": 3432.2, "rew_std": 583.9133154844133, "Agent": "qrdqn"}, {"env_step": 9900000, "rew": 3285.6, "rew_std": 924.0051082109882, "Agent": "qrdqn"}, {"env_step": 10000000, "rew": 3202.8, "rew_std": 982.7004426578835, "Agent": "qrdqn"}, {"env_step": 0, "rew": 106.6, "rew_std": 87.09557968117556, "Agent": "iqn"}, {"env_step": 100000, "rew": 228.6, "rew_std": 33.87093148999596, "Agent": "iqn"}, {"env_step": 200000, "rew": 229.6, "rew_std": 75.1308192421725, "Agent": "iqn"}, {"env_step": 300000, "rew": 251.2, "rew_std": 79.04530346579739, "Agent": "iqn"}, {"env_step": 400000, "rew": 247.6, "rew_std": 64.6083585923679, "Agent": "iqn"}, {"env_step": 500000, "rew": 382.2, "rew_std": 127.84506247798544, "Agent": "iqn"}, {"env_step": 600000, "rew": 441.6, "rew_std": 243.14736272474764, "Agent": "iqn"}, {"env_step": 700000, "rew": 692.8, "rew_std": 305.80542833638515, "Agent": "iqn"}, {"env_step": 800000, "rew": 990.8, "rew_std": 394.5460175949062, "Agent": "iqn"}, {"env_step": 900000, "rew": 901.2, "rew_std": 464.5894531734443, "Agent": "iqn"}, {"env_step": 1000000, "rew": 1541.4, "rew_std": 385.48883252307064, "Agent": "iqn"}, {"env_step": 
1100000, "rew": 2180.2, "rew_std": 539.3918427266026, "Agent": "iqn"}, {"env_step": 1200000, "rew": 2100.2, "rew_std": 646.6129908995024, "Agent": "iqn"}, {"env_step": 1300000, "rew": 2565.6, "rew_std": 635.4090336153555, "Agent": "iqn"}, {"env_step": 1400000, "rew": 2368.0, "rew_std": 440.5024404018666, "Agent": "iqn"}, {"env_step": 1500000, "rew": 2376.0, "rew_std": 1055.6806335251206, "Agent": "iqn"}, {"env_step": 1600000, "rew": 3192.5, "rew_std": 493.9407353114339, "Agent": "iqn"}, {"env_step": 1700000, "rew": 2622.4, "rew_std": 1216.6698155210393, "Agent": "iqn"}, {"env_step": 1800000, "rew": 3191.4, "rew_std": 527.0340026981181, "Agent": "iqn"}, {"env_step": 1900000, "rew": 2762.6, "rew_std": 706.9370834805597, "Agent": "iqn"}, {"env_step": 2000000, "rew": 3111.6, "rew_std": 1020.695762703069, "Agent": "iqn"}, {"env_step": 2100000, "rew": 3645.5, "rew_std": 998.2598108708975, "Agent": "iqn"}, {"env_step": 2200000, "rew": 3387.8, "rew_std": 938.3493805614196, "Agent": "iqn"}, {"env_step": 2300000, "rew": 3795.2, "rew_std": 690.2713669275294, "Agent": "iqn"}, {"env_step": 2400000, "rew": 3738.6, "rew_std": 1253.164011612207, "Agent": "iqn"}, {"env_step": 2500000, "rew": 3917.4, "rew_std": 391.9112654670697, "Agent": "iqn"}, {"env_step": 2600000, "rew": 3715.0, "rew_std": 903.7580428411135, "Agent": "iqn"}, {"env_step": 2700000, "rew": 4198.8, "rew_std": 916.1332654150268, "Agent": "iqn"}, {"env_step": 2800000, "rew": 3842.8, "rew_std": 1014.2438365600256, "Agent": "iqn"}, {"env_step": 2900000, "rew": 3685.0, "rew_std": 1446.485741374591, "Agent": "iqn"}, {"env_step": 3000000, "rew": 3950.0, "rew_std": 1456.4194450775506, "Agent": "iqn"}, {"env_step": 3100000, "rew": 4272.0, "rew_std": 806.635977377652, "Agent": "iqn"}, {"env_step": 3200000, "rew": 4197.4, "rew_std": 668.5554875999449, "Agent": "iqn"}, {"env_step": 3300000, "rew": 4473.6, "rew_std": 668.1130443270808, "Agent": "iqn"}, {"env_step": 3400000, "rew": 4128.8, "rew_std": 1420.9650804998691, "Agent": "iqn"}, {"env_step": 3500000, "rew": 4091.2, "rew_std": 799.4804312802158, "Agent": "iqn"}, {"env_step": 3600000, "rew": 3836.0, "rew_std": 495.31727205903087, "Agent": "iqn"}, {"env_step": 3700000, "rew": 3937.6, "rew_std": 955.3729324195865, "Agent": "iqn"}, {"env_step": 3800000, "rew": 4366.0, "rew_std": 646.7463181186267, "Agent": "iqn"}, {"env_step": 3900000, "rew": 4184.6, "rew_std": 648.5627494699337, "Agent": "iqn"}, {"env_step": 4000000, "rew": 4264.2, "rew_std": 1133.6307864556254, "Agent": "iqn"}, {"env_step": 4100000, "rew": 3667.2, "rew_std": 1318.5249940748186, "Agent": "iqn"}, {"env_step": 4200000, "rew": 4149.6, "rew_std": 734.6094472575206, "Agent": "iqn"}, {"env_step": 4300000, "rew": 4311.5, "rew_std": 1162.0347025799185, "Agent": "iqn"}, {"env_step": 4400000, "rew": 4001.8, "rew_std": 1118.5291949698942, "Agent": "iqn"}, {"env_step": 4500000, "rew": 4658.6, "rew_std": 651.7754521305632, "Agent": "iqn"}, {"env_step": 4600000, "rew": 4676.1, "rew_std": 562.302134088072, "Agent": "iqn"}, {"env_step": 4700000, "rew": 4486.8, "rew_std": 643.8162470767571, "Agent": "iqn"}, {"env_step": 4800000, "rew": 4090.2, "rew_std": 1062.8171808923678, "Agent": "iqn"}, {"env_step": 4900000, "rew": 4424.2, "rew_std": 889.2979028424614, "Agent": "iqn"}, {"env_step": 5000000, "rew": 4119.8, "rew_std": 986.707433842474, "Agent": "iqn"}, {"env_step": 5100000, "rew": 4387.0, "rew_std": 1373.3178073556026, "Agent": "iqn"}, {"env_step": 5200000, "rew": 4230.2, "rew_std": 906.2165083466533, "Agent": "iqn"}, {"env_step": 5300000, "rew": 
4634.0, "rew_std": 1000.1627867502369, "Agent": "iqn"}, {"env_step": 5400000, "rew": 4360.6, "rew_std": 547.0027787863604, "Agent": "iqn"}, {"env_step": 5500000, "rew": 4132.2, "rew_std": 1382.0892735275822, "Agent": "iqn"}, {"env_step": 5600000, "rew": 4627.2, "rew_std": 630.4678897453858, "Agent": "iqn"}, {"env_step": 5700000, "rew": 4543.6, "rew_std": 817.3174658601149, "Agent": "iqn"}, {"env_step": 5800000, "rew": 4541.4, "rew_std": 589.8366214469902, "Agent": "iqn"}, {"env_step": 5900000, "rew": 4541.8, "rew_std": 957.1254672194237, "Agent": "iqn"}, {"env_step": 6000000, "rew": 4616.6, "rew_std": 624.175648355493, "Agent": "iqn"}, {"env_step": 6100000, "rew": 4831.4, "rew_std": 685.3571623613486, "Agent": "iqn"}, {"env_step": 6200000, "rew": 4185.4, "rew_std": 1318.965063980089, "Agent": "iqn"}, {"env_step": 6300000, "rew": 4762.2, "rew_std": 578.2092700744256, "Agent": "iqn"}, {"env_step": 6400000, "rew": 4953.0, "rew_std": 491.9441025157228, "Agent": "iqn"}, {"env_step": 6500000, "rew": 4542.0, "rew_std": 540.8837213301949, "Agent": "iqn"}, {"env_step": 6600000, "rew": 4407.3, "rew_std": 992.3926692595023, "Agent": "iqn"}, {"env_step": 6700000, "rew": 4558.4, "rew_std": 883.3956305076453, "Agent": "iqn"}, {"env_step": 6800000, "rew": 4337.2, "rew_std": 886.7474048453709, "Agent": "iqn"}, {"env_step": 6900000, "rew": 4499.8, "rew_std": 1165.5764067619075, "Agent": "iqn"}, {"env_step": 7000000, "rew": 4851.0, "rew_std": 666.8494582737546, "Agent": "iqn"}, {"env_step": 7100000, "rew": 4711.8, "rew_std": 1179.0499395699912, "Agent": "iqn"}, {"env_step": 7200000, "rew": 5200.4, "rew_std": 528.0430285497575, "Agent": "iqn"}, {"env_step": 7300000, "rew": 4526.0, "rew_std": 615.5309902839987, "Agent": "iqn"}, {"env_step": 7400000, "rew": 4689.4, "rew_std": 1031.8333392559093, "Agent": "iqn"}, {"env_step": 7500000, "rew": 4679.8, "rew_std": 1083.7780030984204, "Agent": "iqn"}, {"env_step": 7600000, "rew": 4287.0, "rew_std": 1172.4614279369705, "Agent": "iqn"}, {"env_step": 7700000, "rew": 4314.4, "rew_std": 984.6696095645484, "Agent": "iqn"}, {"env_step": 7800000, "rew": 5033.0, "rew_std": 813.2641637254159, "Agent": "iqn"}, {"env_step": 7900000, "rew": 5103.8, "rew_std": 708.6434646562402, "Agent": "iqn"}, {"env_step": 8000000, "rew": 4809.2, "rew_std": 815.4278386221555, "Agent": "iqn"}, {"env_step": 8100000, "rew": 4326.3, "rew_std": 854.7598551640103, "Agent": "iqn"}, {"env_step": 8200000, "rew": 4424.6, "rew_std": 656.1347727410886, "Agent": "iqn"}, {"env_step": 8300000, "rew": 4463.8, "rew_std": 1188.5400960842676, "Agent": "iqn"}, {"env_step": 8400000, "rew": 4601.0, "rew_std": 1020.5477940792387, "Agent": "iqn"}, {"env_step": 8500000, "rew": 4801.4, "rew_std": 724.9215405821516, "Agent": "iqn"}, {"env_step": 8600000, "rew": 4811.2, "rew_std": 703.3446950109171, "Agent": "iqn"}, {"env_step": 8700000, "rew": 4873.2, "rew_std": 966.0274116193598, "Agent": "iqn"}, {"env_step": 8800000, "rew": 4744.0, "rew_std": 747.7903449497059, "Agent": "iqn"}, {"env_step": 8900000, "rew": 4795.2, "rew_std": 1258.5916573694583, "Agent": "iqn"}, {"env_step": 9000000, "rew": 4230.6, "rew_std": 1360.8685608830854, "Agent": "iqn"}, {"env_step": 9100000, "rew": 4927.6, "rew_std": 1000.1939012011621, "Agent": "iqn"}, {"env_step": 9200000, "rew": 4662.6, "rew_std": 837.6820637927017, "Agent": "iqn"}, {"env_step": 9300000, "rew": 4471.6, "rew_std": 785.4536523563946, "Agent": "iqn"}, {"env_step": 9400000, "rew": 5254.6, "rew_std": 424.59020243053186, "Agent": "iqn"}, {"env_step": 9500000, "rew": 5147.4, 
"rew_std": 802.4213606329283, "Agent": "iqn"}, {"env_step": 9600000, "rew": 4296.0, "rew_std": 1377.9889694768967, "Agent": "iqn"}, {"env_step": 9700000, "rew": 4708.8, "rew_std": 957.8067445993477, "Agent": "iqn"}, {"env_step": 9800000, "rew": 5341.2, "rew_std": 670.1965084958291, "Agent": "iqn"}, {"env_step": 9900000, "rew": 4807.4, "rew_std": 688.9702751207776, "Agent": "iqn"}, {"env_step": 10000000, "rew": 5173.0, "rew_std": 639.4342812205176, "Agent": "iqn"}, {"env_step": 0, "rew": 55.2, "rew_std": 71.61675781547221, "Agent": "rainbow"}, {"env_step": 100000, "rew": 197.4, "rew_std": 98.0124481889928, "Agent": "rainbow"}, {"env_step": 200000, "rew": 183.8, "rew_std": 83.80190928612545, "Agent": "rainbow"}, {"env_step": 300000, "rew": 341.6, "rew_std": 115.49129837351384, "Agent": "rainbow"}, {"env_step": 400000, "rew": 478.6, "rew_std": 112.00374993722309, "Agent": "rainbow"}, {"env_step": 500000, "rew": 327.8, "rew_std": 113.28000706214668, "Agent": "rainbow"}, {"env_step": 600000, "rew": 556.2, "rew_std": 241.52424308959132, "Agent": "rainbow"}, {"env_step": 700000, "rew": 778.0, "rew_std": 244.70390270692457, "Agent": "rainbow"}, {"env_step": 800000, "rew": 952.2, "rew_std": 287.64902224759953, "Agent": "rainbow"}, {"env_step": 900000, "rew": 1213.4, "rew_std": 264.05310072029073, "Agent": "rainbow"}, {"env_step": 1000000, "rew": 1398.2, "rew_std": 236.1549491329792, "Agent": "rainbow"}, {"env_step": 1100000, "rew": 1322.4, "rew_std": 223.74056404684424, "Agent": "rainbow"}, {"env_step": 1200000, "rew": 1377.0, "rew_std": 333.29416436535456, "Agent": "rainbow"}, {"env_step": 1300000, "rew": 1495.2, "rew_std": 277.13996463880846, "Agent": "rainbow"}, {"env_step": 1400000, "rew": 1431.2, "rew_std": 408.9402890398548, "Agent": "rainbow"}, {"env_step": 1500000, "rew": 1460.6, "rew_std": 341.7953188678862, "Agent": "rainbow"}, {"env_step": 1600000, "rew": 1478.6, "rew_std": 316.41750899721086, "Agent": "rainbow"}, {"env_step": 1700000, "rew": 1522.4, "rew_std": 258.7659946747254, "Agent": "rainbow"}, {"env_step": 1800000, "rew": 1574.8, "rew_std": 316.74620755424996, "Agent": "rainbow"}, {"env_step": 1900000, "rew": 1628.8, "rew_std": 394.08445795286065, "Agent": "rainbow"}, {"env_step": 2000000, "rew": 1717.2, "rew_std": 368.0708627424887, "Agent": "rainbow"}, {"env_step": 2100000, "rew": 1660.2, "rew_std": 333.7057985711366, "Agent": "rainbow"}, {"env_step": 2200000, "rew": 1685.8, "rew_std": 311.28051657628686, "Agent": "rainbow"}, {"env_step": 2300000, "rew": 1637.6, "rew_std": 437.96237281300773, "Agent": "rainbow"}, {"env_step": 2400000, "rew": 1646.0, "rew_std": 412.8253868162664, "Agent": "rainbow"}, {"env_step": 2500000, "rew": 1597.1, "rew_std": 315.6765591550947, "Agent": "rainbow"}, {"env_step": 2600000, "rew": 1661.2, "rew_std": 427.9338266601508, "Agent": "rainbow"}, {"env_step": 2700000, "rew": 1649.2, "rew_std": 378.05629210476053, "Agent": "rainbow"}, {"env_step": 2800000, "rew": 1687.8, "rew_std": 342.1618915075143, "Agent": "rainbow"}, {"env_step": 2900000, "rew": 1625.6, "rew_std": 356.8526866929826, "Agent": "rainbow"}, {"env_step": 3000000, "rew": 1646.2, "rew_std": 357.99156414641953, "Agent": "rainbow"}, {"env_step": 3100000, "rew": 1651.4, "rew_std": 435.95141931183116, "Agent": "rainbow"}, {"env_step": 3200000, "rew": 1694.2, "rew_std": 385.55253857289017, "Agent": "rainbow"}, {"env_step": 3300000, "rew": 1634.8, "rew_std": 408.7338498338497, "Agent": "rainbow"}, {"env_step": 3400000, "rew": 1658.6, "rew_std": 404.9168309665578, "Agent": "rainbow"}, {"env_step": 
3500000, "rew": 1594.8, "rew_std": 372.0668757091929, "Agent": "rainbow"}, {"env_step": 3600000, "rew": 1646.4, "rew_std": 462.23093795201555, "Agent": "rainbow"}, {"env_step": 3700000, "rew": 1722.8, "rew_std": 375.3560443099325, "Agent": "rainbow"}, {"env_step": 3800000, "rew": 1726.6, "rew_std": 325.04775033831567, "Agent": "rainbow"}, {"env_step": 3900000, "rew": 1754.4, "rew_std": 369.6255402430952, "Agent": "rainbow"}, {"env_step": 4000000, "rew": 1707.2, "rew_std": 340.19782480198194, "Agent": "rainbow"}, {"env_step": 4100000, "rew": 1701.8, "rew_std": 354.3878666094538, "Agent": "rainbow"}, {"env_step": 4200000, "rew": 1657.6, "rew_std": 428.869024295297, "Agent": "rainbow"}, {"env_step": 4300000, "rew": 1686.4, "rew_std": 453.4417713444583, "Agent": "rainbow"}, {"env_step": 4400000, "rew": 1743.8, "rew_std": 245.49940936792498, "Agent": "rainbow"}, {"env_step": 4500000, "rew": 1714.6, "rew_std": 414.4838235685441, "Agent": "rainbow"}, {"env_step": 4600000, "rew": 1804.6, "rew_std": 416.3316466472372, "Agent": "rainbow"}, {"env_step": 4700000, "rew": 1770.2, "rew_std": 335.57884319485936, "Agent": "rainbow"}, {"env_step": 4800000, "rew": 1752.0, "rew_std": 379.45750750248703, "Agent": "rainbow"}, {"env_step": 4900000, "rew": 1681.8, "rew_std": 369.5407420028271, "Agent": "rainbow"}, {"env_step": 5000000, "rew": 1798.8, "rew_std": 338.7603282558334, "Agent": "rainbow"}, {"env_step": 5100000, "rew": 1814.6, "rew_std": 419.58651074599624, "Agent": "rainbow"}, {"env_step": 5200000, "rew": 1811.0, "rew_std": 331.25005660376877, "Agent": "rainbow"}, {"env_step": 5300000, "rew": 1854.4, "rew_std": 320.87106444801157, "Agent": "rainbow"}, {"env_step": 5400000, "rew": 1821.8, "rew_std": 351.93914246642134, "Agent": "rainbow"}, {"env_step": 5500000, "rew": 1869.2, "rew_std": 345.7851355972376, "Agent": "rainbow"}, {"env_step": 5600000, "rew": 1842.0, "rew_std": 353.2319351361086, "Agent": "rainbow"}, {"env_step": 5700000, "rew": 1894.3, "rew_std": 423.4864932911084, "Agent": "rainbow"}, {"env_step": 5800000, "rew": 1768.2, "rew_std": 373.75334112218985, "Agent": "rainbow"}, {"env_step": 5900000, "rew": 1788.2, "rew_std": 346.13690932924214, "Agent": "rainbow"}, {"env_step": 6000000, "rew": 1827.0, "rew_std": 370.17320270381543, "Agent": "rainbow"}, {"env_step": 6100000, "rew": 1777.0, "rew_std": 365.00767115226495, "Agent": "rainbow"}, {"env_step": 6200000, "rew": 1762.0, "rew_std": 395.1910930170365, "Agent": "rainbow"}, {"env_step": 6300000, "rew": 1882.0, "rew_std": 356.34477686644993, "Agent": "rainbow"}, {"env_step": 6400000, "rew": 1807.8, "rew_std": 391.83103501381817, "Agent": "rainbow"}, {"env_step": 6500000, "rew": 1864.4, "rew_std": 366.81635732338873, "Agent": "rainbow"}, {"env_step": 6600000, "rew": 1839.8, "rew_std": 329.80776218882414, "Agent": "rainbow"}, {"env_step": 6700000, "rew": 1803.2, "rew_std": 371.11852554136937, "Agent": "rainbow"}, {"env_step": 6800000, "rew": 1861.4, "rew_std": 354.3445216170274, "Agent": "rainbow"}, {"env_step": 6900000, "rew": 1861.8, "rew_std": 366.5399841763515, "Agent": "rainbow"}, {"env_step": 7000000, "rew": 1877.6, "rew_std": 345.57175810531743, "Agent": "rainbow"}, {"env_step": 7100000, "rew": 1860.6, "rew_std": 372.0484377067051, "Agent": "rainbow"}, {"env_step": 7200000, "rew": 1861.4, "rew_std": 343.88259624470675, "Agent": "rainbow"}, {"env_step": 7300000, "rew": 1931.6, "rew_std": 363.43890820879375, "Agent": "rainbow"}, {"env_step": 7400000, "rew": 1901.2, "rew_std": 362.95145680930943, "Agent": "rainbow"}, {"env_step": 7500000, 
"rew": 1897.0, "rew_std": 373.6364543242536, "Agent": "rainbow"}, {"env_step": 7600000, "rew": 1901.0, "rew_std": 359.4754511785193, "Agent": "rainbow"}, {"env_step": 7700000, "rew": 1892.6, "rew_std": 381.24957704894575, "Agent": "rainbow"}, {"env_step": 7800000, "rew": 1892.4, "rew_std": 351.7485465499467, "Agent": "rainbow"}, {"env_step": 7900000, "rew": 1915.4, "rew_std": 373.93480715226286, "Agent": "rainbow"}, {"env_step": 8000000, "rew": 1863.4, "rew_std": 332.4106496488944, "Agent": "rainbow"}, {"env_step": 8100000, "rew": 1909.0, "rew_std": 413.3543274238217, "Agent": "rainbow"}, {"env_step": 8200000, "rew": 1866.6, "rew_std": 333.5890285965652, "Agent": "rainbow"}, {"env_step": 8300000, "rew": 1887.8, "rew_std": 355.3504748836, "Agent": "rainbow"}, {"env_step": 8400000, "rew": 1912.4, "rew_std": 447.0863898621831, "Agent": "rainbow"}, {"env_step": 8500000, "rew": 1909.8, "rew_std": 365.22152181929255, "Agent": "rainbow"}, {"env_step": 8600000, "rew": 1915.4, "rew_std": 369.93734604659744, "Agent": "rainbow"}, {"env_step": 8700000, "rew": 1866.4, "rew_std": 375.5697538407479, "Agent": "rainbow"}, {"env_step": 8800000, "rew": 1877.2, "rew_std": 429.923900242822, "Agent": "rainbow"}, {"env_step": 8900000, "rew": 1829.0, "rew_std": 345.25845391532414, "Agent": "rainbow"}, {"env_step": 9000000, "rew": 1908.6, "rew_std": 339.2050117554279, "Agent": "rainbow"}, {"env_step": 9100000, "rew": 1874.0, "rew_std": 355.7178657306939, "Agent": "rainbow"}, {"env_step": 9200000, "rew": 1831.8, "rew_std": 345.1578769201132, "Agent": "rainbow"}, {"env_step": 9300000, "rew": 1870.8, "rew_std": 397.23262705875504, "Agent": "rainbow"}, {"env_step": 9400000, "rew": 1910.0, "rew_std": 358.90277234928124, "Agent": "rainbow"}, {"env_step": 9500000, "rew": 1902.4, "rew_std": 376.3017937772819, "Agent": "rainbow"}, {"env_step": 9600000, "rew": 1889.8, "rew_std": 407.2625197584477, "Agent": "rainbow"}, {"env_step": 9700000, "rew": 1912.6, "rew_std": 381.65329816470864, "Agent": "rainbow"}, {"env_step": 9800000, "rew": 1894.6, "rew_std": 346.8568004234601, "Agent": "rainbow"}, {"env_step": 9900000, "rew": 1934.6, "rew_std": 376.412592775534, "Agent": "rainbow"}, {"env_step": 10000000, "rew": 1896.0, "rew_std": 353.1526582088828, "Agent": "rainbow"}, {"env_step": 0, "rew": 149.2, "rew_std": 108.08959246847033, "Agent": "ppo"}, {"env_step": 100000, "rew": 451.8, "rew_std": 93.66087763842489, "Agent": "ppo"}, {"env_step": 200000, "rew": 548.8, "rew_std": 87.63195764103413, "Agent": "ppo"}, {"env_step": 300000, "rew": 628.6, "rew_std": 55.785661240143064, "Agent": "ppo"}, {"env_step": 400000, "rew": 712.4, "rew_std": 68.37426416423068, "Agent": "ppo"}, {"env_step": 500000, "rew": 747.4, "rew_std": 46.5536249931195, "Agent": "ppo"}, {"env_step": 600000, "rew": 758.2, "rew_std": 58.05824661492974, "Agent": "ppo"}, {"env_step": 700000, "rew": 748.8, "rew_std": 55.246357345982545, "Agent": "ppo"}, {"env_step": 800000, "rew": 781.4, "rew_std": 39.2127530275547, "Agent": "ppo"}, {"env_step": 900000, "rew": 792.4, "rew_std": 79.78370761001271, "Agent": "ppo"}, {"env_step": 1000000, "rew": 785.8, "rew_std": 37.83596172955037, "Agent": "ppo"}, {"env_step": 1100000, "rew": 824.4, "rew_std": 24.83223711227001, "Agent": "ppo"}, {"env_step": 1200000, "rew": 814.2, "rew_std": 35.104985400937004, "Agent": "ppo"}, {"env_step": 1300000, "rew": 823.0, "rew_std": 42.22084793085046, "Agent": "ppo"}, {"env_step": 1400000, "rew": 822.6, "rew_std": 22.628300864183327, "Agent": "ppo"}, {"env_step": 1500000, "rew": 824.2, "rew_std": 
23.142169301947472, "Agent": "ppo"}, {"env_step": 1600000, "rew": 841.0, "rew_std": 110.87740978215535, "Agent": "ppo"}, {"env_step": 1700000, "rew": 851.8, "rew_std": 50.91522365658429, "Agent": "ppo"}, {"env_step": 1800000, "rew": 865.0, "rew_std": 90.81960140850653, "Agent": "ppo"}, {"env_step": 1900000, "rew": 872.2, "rew_std": 84.7417252597562, "Agent": "ppo"}, {"env_step": 2000000, "rew": 843.8, "rew_std": 130.17357642778353, "Agent": "ppo"}, {"env_step": 2100000, "rew": 868.0, "rew_std": 104.34557968596465, "Agent": "ppo"}, {"env_step": 2200000, "rew": 860.4, "rew_std": 135.80809990571254, "Agent": "ppo"}, {"env_step": 2300000, "rew": 884.0, "rew_std": 147.23043163694115, "Agent": "ppo"}, {"env_step": 2400000, "rew": 922.0, "rew_std": 163.5506037897751, "Agent": "ppo"}, {"env_step": 2500000, "rew": 906.8, "rew_std": 148.42425677765746, "Agent": "ppo"}, {"env_step": 2600000, "rew": 895.0, "rew_std": 125.72111994410486, "Agent": "ppo"}, {"env_step": 2700000, "rew": 920.4, "rew_std": 155.8660963776279, "Agent": "ppo"}, {"env_step": 2800000, "rew": 934.2, "rew_std": 182.06251673532364, "Agent": "ppo"}, {"env_step": 2900000, "rew": 894.0, "rew_std": 235.24625395529682, "Agent": "ppo"}, {"env_step": 3000000, "rew": 922.6, "rew_std": 170.39964788696014, "Agent": "ppo"}, {"env_step": 3100000, "rew": 933.2, "rew_std": 172.90968740935253, "Agent": "ppo"}, {"env_step": 3200000, "rew": 931.8, "rew_std": 181.70404508430735, "Agent": "ppo"}, {"env_step": 3300000, "rew": 928.4, "rew_std": 214.5931965370757, "Agent": "ppo"}, {"env_step": 3400000, "rew": 933.0, "rew_std": 203.49201458533943, "Agent": "ppo"}, {"env_step": 3500000, "rew": 958.2, "rew_std": 223.07209596899386, "Agent": "ppo"}, {"env_step": 3600000, "rew": 951.8, "rew_std": 234.61704967883298, "Agent": "ppo"}, {"env_step": 3700000, "rew": 944.0, "rew_std": 233.30666514268296, "Agent": "ppo"}, {"env_step": 3800000, "rew": 934.8, "rew_std": 249.3434579049549, "Agent": "ppo"}, {"env_step": 3900000, "rew": 925.6, "rew_std": 264.6655247666382, "Agent": "ppo"}, {"env_step": 4000000, "rew": 910.4, "rew_std": 329.20364518030476, "Agent": "ppo"}, {"env_step": 4100000, "rew": 939.6, "rew_std": 328.4677153085216, "Agent": "ppo"}, {"env_step": 4200000, "rew": 925.8, "rew_std": 297.59227140502156, "Agent": "ppo"}, {"env_step": 4300000, "rew": 947.8, "rew_std": 306.41599174977796, "Agent": "ppo"}, {"env_step": 4400000, "rew": 938.0, "rew_std": 307.99350642505436, "Agent": "ppo"}, {"env_step": 4500000, "rew": 885.0, "rew_std": 320.93145685644464, "Agent": "ppo"}, {"env_step": 4600000, "rew": 937.4, "rew_std": 259.07226790994054, "Agent": "ppo"}, {"env_step": 4700000, "rew": 932.6, "rew_std": 310.10198322487395, "Agent": "ppo"}, {"env_step": 4800000, "rew": 906.4, "rew_std": 359.8041689586156, "Agent": "ppo"}, {"env_step": 4900000, "rew": 887.0, "rew_std": 377.4125064170503, "Agent": "ppo"}, {"env_step": 5000000, "rew": 901.6, "rew_std": 373.16462854884844, "Agent": "ppo"}, {"env_step": 5100000, "rew": 899.6, "rew_std": 394.11145631661105, "Agent": "ppo"}, {"env_step": 5200000, "rew": 871.2, "rew_std": 410.2703498913856, "Agent": "ppo"}, {"env_step": 5300000, "rew": 859.8, "rew_std": 379.5180628112449, "Agent": "ppo"}, {"env_step": 5400000, "rew": 908.4, "rew_std": 370.1597492975161, "Agent": "ppo"}, {"env_step": 5500000, "rew": 858.8, "rew_std": 425.68175906420987, "Agent": "ppo"}, {"env_step": 5600000, "rew": 917.0, "rew_std": 371.99059127886557, "Agent": "ppo"}, {"env_step": 5700000, "rew": 926.0, "rew_std": 387.98144285519635, "Agent": "ppo"}, 
{"env_step": 5800000, "rew": 924.4, "rew_std": 380.3491027989944, "Agent": "ppo"}, {"env_step": 5900000, "rew": 942.8, "rew_std": 393.7331075741536, "Agent": "ppo"}, {"env_step": 6000000, "rew": 953.0, "rew_std": 385.58864091152896, "Agent": "ppo"}, {"env_step": 6100000, "rew": 931.0, "rew_std": 386.86871158055675, "Agent": "ppo"}, {"env_step": 6200000, "rew": 947.2, "rew_std": 389.32474876380513, "Agent": "ppo"}, {"env_step": 6300000, "rew": 954.4, "rew_std": 382.7979101301365, "Agent": "ppo"}, {"env_step": 6400000, "rew": 966.4, "rew_std": 395.44081731657394, "Agent": "ppo"}, {"env_step": 6500000, "rew": 952.8, "rew_std": 410.60073063744056, "Agent": "ppo"}, {"env_step": 6600000, "rew": 952.2, "rew_std": 401.42491203212586, "Agent": "ppo"}, {"env_step": 6700000, "rew": 974.0, "rew_std": 328.4058464765815, "Agent": "ppo"}, {"env_step": 6800000, "rew": 992.8, "rew_std": 326.9534523445195, "Agent": "ppo"}, {"env_step": 6900000, "rew": 986.8, "rew_std": 301.4514222888988, "Agent": "ppo"}, {"env_step": 7000000, "rew": 1018.6, "rew_std": 286.2880367741551, "Agent": "ppo"}, {"env_step": 7100000, "rew": 995.0, "rew_std": 302.6037012331475, "Agent": "ppo"}, {"env_step": 7200000, "rew": 986.4, "rew_std": 286.39804468606275, "Agent": "ppo"}, {"env_step": 7300000, "rew": 997.6, "rew_std": 314.9346598899524, "Agent": "ppo"}, {"env_step": 7400000, "rew": 1001.6, "rew_std": 297.1273127802289, "Agent": "ppo"}, {"env_step": 7500000, "rew": 986.2, "rew_std": 317.547413782572, "Agent": "ppo"}, {"env_step": 7600000, "rew": 1022.6, "rew_std": 319.39574198789813, "Agent": "ppo"}, {"env_step": 7700000, "rew": 1019.6, "rew_std": 320.0022499920899, "Agent": "ppo"}, {"env_step": 7800000, "rew": 1003.2, "rew_std": 321.1394712582058, "Agent": "ppo"}, {"env_step": 7900000, "rew": 1024.8, "rew_std": 323.7668296783968, "Agent": "ppo"}, {"env_step": 8000000, "rew": 1018.2, "rew_std": 316.27260393527604, "Agent": "ppo"}, {"env_step": 8100000, "rew": 1014.0, "rew_std": 307.1468704056742, "Agent": "ppo"}, {"env_step": 8200000, "rew": 1004.0, "rew_std": 307.0244289954791, "Agent": "ppo"}, {"env_step": 8300000, "rew": 1017.0, "rew_std": 338.5513255032389, "Agent": "ppo"}, {"env_step": 8400000, "rew": 1018.2, "rew_std": 328.83363574914296, "Agent": "ppo"}, {"env_step": 8500000, "rew": 1000.2, "rew_std": 347.273321751038, "Agent": "ppo"}, {"env_step": 8600000, "rew": 1010.0, "rew_std": 339.7281265953704, "Agent": "ppo"}, {"env_step": 8700000, "rew": 1023.2, "rew_std": 338.9633608518773, "Agent": "ppo"}, {"env_step": 8800000, "rew": 1024.2, "rew_std": 358.4092074710135, "Agent": "ppo"}, {"env_step": 8900000, "rew": 953.8, "rew_std": 392.8872102779626, "Agent": "ppo"}, {"env_step": 9000000, "rew": 951.4, "rew_std": 427.01105372109515, "Agent": "ppo"}, {"env_step": 9100000, "rew": 979.0, "rew_std": 400.3560914985558, "Agent": "ppo"}, {"env_step": 9200000, "rew": 991.6, "rew_std": 416.51199262446204, "Agent": "ppo"}, {"env_step": 9300000, "rew": 989.2, "rew_std": 423.34780027773854, "Agent": "ppo"}, {"env_step": 9400000, "rew": 993.2, "rew_std": 427.4821165850099, "Agent": "ppo"}, {"env_step": 9500000, "rew": 978.2, "rew_std": 419.3723405280801, "Agent": "ppo"}, {"env_step": 9600000, "rew": 992.2, "rew_std": 368.224333796668, "Agent": "ppo"}, {"env_step": 9700000, "rew": 1026.4, "rew_std": 379.41723735223206, "Agent": "ppo"}, {"env_step": 9800000, "rew": 1025.2, "rew_std": 368.4651408206752, "Agent": "ppo"}, {"env_step": 9900000, "rew": 1035.2, "rew_std": 353.61696791867894, "Agent": "ppo"}, {"env_step": 10000000, "rew": 1025.8, 
"rew_std": 358.64461518333155, "Agent": "ppo"}] \ No newline at end of file diff --git a/examples/atari/benchmark/SpaceInvadersNoFrameskip-v4/result.json b/examples/atari/benchmark/SpaceInvadersNoFrameskip-v4/result.json new file mode 100644 index 0000000000000000000000000000000000000000..c5fde60c4cb11d6ebd3b7a7d4ad57a5a52825439 --- /dev/null +++ b/examples/atari/benchmark/SpaceInvadersNoFrameskip-v4/result.json @@ -0,0 +1 @@ +[{"env_step": 0, "rew": 131.5, "rew_std": 72.94964016360875, "Agent": "c51"}, {"env_step": 100000, "rew": 175.3, "rew_std": 78.69695038564073, "Agent": "c51"}, {"env_step": 200000, "rew": 162.45, "rew_std": 70.4969680199085, "Agent": "c51"}, {"env_step": 300000, "rew": 216.1, "rew_std": 43.270544253568154, "Agent": "c51"}, {"env_step": 400000, "rew": 310.5, "rew_std": 67.2647753285477, "Agent": "c51"}, {"env_step": 500000, "rew": 371.6, "rew_std": 66.80149698921424, "Agent": "c51"}, {"env_step": 600000, "rew": 429.15, "rew_std": 59.84816204362503, "Agent": "c51"}, {"env_step": 700000, "rew": 411.25, "rew_std": 67.41819116529307, "Agent": "c51"}, {"env_step": 800000, "rew": 433.9, "rew_std": 56.295115241022465, "Agent": "c51"}, {"env_step": 900000, "rew": 468.25, "rew_std": 55.36661900459518, "Agent": "c51"}, {"env_step": 1000000, "rew": 479.2, "rew_std": 77.91508198032008, "Agent": "c51"}, {"env_step": 1100000, "rew": 517.25, "rew_std": 102.26613564616588, "Agent": "c51"}, {"env_step": 1200000, "rew": 511.4, "rew_std": 65.065659145205, "Agent": "c51"}, {"env_step": 1300000, "rew": 573.65, "rew_std": 58.454277003483675, "Agent": "c51"}, {"env_step": 1400000, "rew": 556.35, "rew_std": 71.15758919468814, "Agent": "c51"}, {"env_step": 1500000, "rew": 553.5, "rew_std": 38.90629769073382, "Agent": "c51"}, {"env_step": 1600000, "rew": 608.25, "rew_std": 35.925791570959156, "Agent": "c51"}, {"env_step": 1700000, "rew": 571.35, "rew_std": 70.93273221863092, "Agent": "c51"}, {"env_step": 1800000, "rew": 624.95, "rew_std": 94.19406828457936, "Agent": "c51"}, {"env_step": 1900000, "rew": 640.2, "rew_std": 96.80113635696638, "Agent": "c51"}, {"env_step": 2000000, "rew": 617.15, "rew_std": 67.16139143883188, "Agent": "c51"}, {"env_step": 2100000, "rew": 617.65, "rew_std": 94.80454894149331, "Agent": "c51"}, {"env_step": 2200000, "rew": 613.0, "rew_std": 115.9631406956538, "Agent": "c51"}, {"env_step": 2300000, "rew": 659.9, "rew_std": 81.69969400187493, "Agent": "c51"}, {"env_step": 2400000, "rew": 698.0, "rew_std": 80.07558928912107, "Agent": "c51"}, {"env_step": 2500000, "rew": 651.8, "rew_std": 74.98106427625578, "Agent": "c51"}, {"env_step": 2600000, "rew": 619.8, "rew_std": 103.08035700365032, "Agent": "c51"}, {"env_step": 2700000, "rew": 629.75, "rew_std": 97.50211536166792, "Agent": "c51"}, {"env_step": 2800000, "rew": 655.95, "rew_std": 62.26975590123989, "Agent": "c51"}, {"env_step": 2900000, "rew": 722.85, "rew_std": 111.88946554524246, "Agent": "c51"}, {"env_step": 3000000, "rew": 689.7, "rew_std": 116.6568043450531, "Agent": "c51"}, {"env_step": 3100000, "rew": 738.45, "rew_std": 90.73572890543173, "Agent": "c51"}, {"env_step": 3200000, "rew": 742.85, "rew_std": 114.76150269145138, "Agent": "c51"}, {"env_step": 3300000, "rew": 722.95, "rew_std": 115.68069199308933, "Agent": "c51"}, {"env_step": 3400000, "rew": 775.55, "rew_std": 107.4728919309423, "Agent": "c51"}, {"env_step": 3500000, "rew": 796.25, "rew_std": 82.22902468106989, "Agent": "c51"}, {"env_step": 3600000, "rew": 742.3, "rew_std": 119.33360800713268, "Agent": "c51"}, {"env_step": 3700000, "rew": 758.55, 
"rew_std": 96.7464340428111, "Agent": "c51"}, {"env_step": 3800000, "rew": 664.4, "rew_std": 103.92925478420403, "Agent": "c51"}, {"env_step": 3900000, "rew": 738.85, "rew_std": 97.65578579889673, "Agent": "c51"}, {"env_step": 4000000, "rew": 689.45, "rew_std": 113.2767518072442, "Agent": "c51"}, {"env_step": 4100000, "rew": 832.35, "rew_std": 157.338973239309, "Agent": "c51"}, {"env_step": 4200000, "rew": 672.15, "rew_std": 78.5875467233836, "Agent": "c51"}, {"env_step": 4300000, "rew": 722.05, "rew_std": 77.35387837723458, "Agent": "c51"}, {"env_step": 4400000, "rew": 897.7, "rew_std": 116.28654264359226, "Agent": "c51"}, {"env_step": 4500000, "rew": 823.1, "rew_std": 92.00076086641893, "Agent": "c51"}, {"env_step": 4600000, "rew": 690.8, "rew_std": 81.19273366502695, "Agent": "c51"}, {"env_step": 4700000, "rew": 811.15, "rew_std": 99.55050225890375, "Agent": "c51"}, {"env_step": 4800000, "rew": 755.0, "rew_std": 107.04998832321282, "Agent": "c51"}, {"env_step": 4900000, "rew": 805.75, "rew_std": 115.61017472523774, "Agent": "c51"}, {"env_step": 5000000, "rew": 760.75, "rew_std": 112.35040053333142, "Agent": "c51"}, {"env_step": 5100000, "rew": 820.95, "rew_std": 141.30984572916356, "Agent": "c51"}, {"env_step": 5200000, "rew": 797.1, "rew_std": 154.61869227231227, "Agent": "c51"}, {"env_step": 5300000, "rew": 825.8, "rew_std": 151.04969380968635, "Agent": "c51"}, {"env_step": 5400000, "rew": 787.95, "rew_std": 107.38399554868501, "Agent": "c51"}, {"env_step": 5500000, "rew": 825.0, "rew_std": 65.9052349969257, "Agent": "c51"}, {"env_step": 5600000, "rew": 822.5, "rew_std": 149.53043168532616, "Agent": "c51"}, {"env_step": 5700000, "rew": 865.65, "rew_std": 141.1168044564502, "Agent": "c51"}, {"env_step": 5800000, "rew": 756.65, "rew_std": 111.32251569201982, "Agent": "c51"}, {"env_step": 5900000, "rew": 833.3, "rew_std": 179.73886613640354, "Agent": "c51"}, {"env_step": 6000000, "rew": 838.7, "rew_std": 96.8695514596821, "Agent": "c51"}, {"env_step": 6100000, "rew": 797.05, "rew_std": 143.24549731143384, "Agent": "c51"}, {"env_step": 6200000, "rew": 787.65, "rew_std": 143.62260441866385, "Agent": "c51"}, {"env_step": 6300000, "rew": 836.35, "rew_std": 147.87225060842215, "Agent": "c51"}, {"env_step": 6400000, "rew": 936.1, "rew_std": 174.2318570181699, "Agent": "c51"}, {"env_step": 6500000, "rew": 878.6, "rew_std": 130.01822949109868, "Agent": "c51"}, {"env_step": 6600000, "rew": 869.35, "rew_std": 131.44923925226803, "Agent": "c51"}, {"env_step": 6700000, "rew": 831.1, "rew_std": 128.2900619689616, "Agent": "c51"}, {"env_step": 6800000, "rew": 848.35, "rew_std": 173.9251347563083, "Agent": "c51"}, {"env_step": 6900000, "rew": 833.4, "rew_std": 157.15466903658958, "Agent": "c51"}, {"env_step": 7000000, "rew": 832.3, "rew_std": 154.00165583525393, "Agent": "c51"}, {"env_step": 7100000, "rew": 832.85, "rew_std": 96.22423031648525, "Agent": "c51"}, {"env_step": 7200000, "rew": 867.75, "rew_std": 155.06728378352412, "Agent": "c51"}, {"env_step": 7300000, "rew": 881.55, "rew_std": 176.80051611915616, "Agent": "c51"}, {"env_step": 7400000, "rew": 848.7, "rew_std": 122.09365257866602, "Agent": "c51"}, {"env_step": 7500000, "rew": 891.75, "rew_std": 136.24688069823839, "Agent": "c51"}, {"env_step": 7600000, "rew": 947.85, "rew_std": 155.2700953178042, "Agent": "c51"}, {"env_step": 7700000, "rew": 810.6, "rew_std": 61.26001958863546, "Agent": "c51"}, {"env_step": 7800000, "rew": 809.45, "rew_std": 132.21695239264895, "Agent": "c51"}, {"env_step": 7900000, "rew": 933.9, "rew_std": 
128.37421859547965, "Agent": "c51"}, {"env_step": 8000000, "rew": 859.35, "rew_std": 175.26181129955265, "Agent": "c51"}, {"env_step": 8100000, "rew": 922.05, "rew_std": 125.99334307811664, "Agent": "c51"}, {"env_step": 8200000, "rew": 878.3, "rew_std": 114.7343017584541, "Agent": "c51"}, {"env_step": 8300000, "rew": 895.25, "rew_std": 212.38682280216915, "Agent": "c51"}, {"env_step": 8400000, "rew": 877.2, "rew_std": 165.07804214976625, "Agent": "c51"}, {"env_step": 8500000, "rew": 872.55, "rew_std": 171.04918152391141, "Agent": "c51"}, {"env_step": 8600000, "rew": 921.95, "rew_std": 176.95118677194566, "Agent": "c51"}, {"env_step": 8700000, "rew": 881.1, "rew_std": 133.7792584820233, "Agent": "c51"}, {"env_step": 8800000, "rew": 875.65, "rew_std": 134.4557641010604, "Agent": "c51"}, {"env_step": 8900000, "rew": 865.25, "rew_std": 158.70353650753975, "Agent": "c51"}, {"env_step": 9000000, "rew": 873.9, "rew_std": 141.55260506257028, "Agent": "c51"}, {"env_step": 9100000, "rew": 923.35, "rew_std": 146.47082473994607, "Agent": "c51"}, {"env_step": 9200000, "rew": 894.5, "rew_std": 181.60740623664003, "Agent": "c51"}, {"env_step": 9300000, "rew": 873.55, "rew_std": 141.41206631684582, "Agent": "c51"}, {"env_step": 9400000, "rew": 919.55, "rew_std": 149.74168591277447, "Agent": "c51"}, {"env_step": 9500000, "rew": 886.55, "rew_std": 105.35142381572258, "Agent": "c51"}, {"env_step": 9600000, "rew": 860.15, "rew_std": 190.00829587152242, "Agent": "c51"}, {"env_step": 9700000, "rew": 919.65, "rew_std": 205.74645197426855, "Agent": "c51"}, {"env_step": 9800000, "rew": 877.75, "rew_std": 124.24014045388068, "Agent": "c51"}, {"env_step": 9900000, "rew": 880.1, "rew_std": 104.3469692899607, "Agent": "c51"}, {"env_step": 10000000, "rew": 880.55, "rew_std": 143.47516335589236, "Agent": "c51"}, {"env_step": 0, "rew": 189.2, "rew_std": 81.34807926435633, "Agent": "dqn"}, {"env_step": 100000, "rew": 245.5, "rew_std": 101.77032966439678, "Agent": "dqn"}, {"env_step": 200000, "rew": 230.55, "rew_std": 65.04092942140356, "Agent": "dqn"}, {"env_step": 300000, "rew": 289.9, "rew_std": 61.85903329344874, "Agent": "dqn"}, {"env_step": 400000, "rew": 272.75, "rew_std": 70.29838191594456, "Agent": "dqn"}, {"env_step": 500000, "rew": 295.65, "rew_std": 66.4168841485356, "Agent": "dqn"}, {"env_step": 600000, "rew": 313.0, "rew_std": 92.10130292237999, "Agent": "dqn"}, {"env_step": 700000, "rew": 321.55, "rew_std": 54.48323136525586, "Agent": "dqn"}, {"env_step": 800000, "rew": 398.9, "rew_std": 85.19647880047626, "Agent": "dqn"}, {"env_step": 900000, "rew": 368.95, "rew_std": 76.96442359947874, "Agent": "dqn"}, {"env_step": 1000000, "rew": 365.3, "rew_std": 76.99974025930217, "Agent": "dqn"}, {"env_step": 1100000, "rew": 436.35, "rew_std": 96.35638276730815, "Agent": "dqn"}, {"env_step": 1200000, "rew": 403.95, "rew_std": 62.60089855585142, "Agent": "dqn"}, {"env_step": 1300000, "rew": 449.65, "rew_std": 59.11558593129226, "Agent": "dqn"}, {"env_step": 1400000, "rew": 420.15, "rew_std": 76.2200924953519, "Agent": "dqn"}, {"env_step": 1500000, "rew": 467.6, "rew_std": 49.78092003970999, "Agent": "dqn"}, {"env_step": 1600000, "rew": 407.05, "rew_std": 62.21030863128715, "Agent": "dqn"}, {"env_step": 1700000, "rew": 471.3, "rew_std": 63.85264285838136, "Agent": "dqn"}, {"env_step": 1800000, "rew": 440.8, "rew_std": 76.10525606027484, "Agent": "dqn"}, {"env_step": 1900000, "rew": 511.2, "rew_std": 51.454931736423475, "Agent": "dqn"}, {"env_step": 2000000, "rew": 446.15, "rew_std": 89.50224857510564, "Agent": "dqn"}, 
{"env_step": 2100000, "rew": 512.9, "rew_std": 72.38224920517462, "Agent": "dqn"}, {"env_step": 2200000, "rew": 458.2, "rew_std": 57.65682960413277, "Agent": "dqn"}, {"env_step": 2300000, "rew": 443.95, "rew_std": 83.05884961880943, "Agent": "dqn"}, {"env_step": 2400000, "rew": 429.95, "rew_std": 97.29631287977978, "Agent": "dqn"}, {"env_step": 2500000, "rew": 518.2, "rew_std": 83.67950764673512, "Agent": "dqn"}, {"env_step": 2600000, "rew": 500.35, "rew_std": 86.74043174898313, "Agent": "dqn"}, {"env_step": 2700000, "rew": 472.85, "rew_std": 108.66325275823469, "Agent": "dqn"}, {"env_step": 2800000, "rew": 483.6, "rew_std": 62.72750592842027, "Agent": "dqn"}, {"env_step": 2900000, "rew": 442.75, "rew_std": 133.11221769619797, "Agent": "dqn"}, {"env_step": 3000000, "rew": 496.85, "rew_std": 91.74694817812743, "Agent": "dqn"}, {"env_step": 3100000, "rew": 488.3, "rew_std": 120.88097451625711, "Agent": "dqn"}, {"env_step": 3200000, "rew": 496.5, "rew_std": 77.99935897172489, "Agent": "dqn"}, {"env_step": 3300000, "rew": 489.05, "rew_std": 103.76065005578945, "Agent": "dqn"}, {"env_step": 3400000, "rew": 499.35, "rew_std": 79.70321511708295, "Agent": "dqn"}, {"env_step": 3500000, "rew": 518.35, "rew_std": 78.03109956933838, "Agent": "dqn"}, {"env_step": 3600000, "rew": 521.85, "rew_std": 87.429986274733, "Agent": "dqn"}, {"env_step": 3700000, "rew": 560.4, "rew_std": 98.244796299855, "Agent": "dqn"}, {"env_step": 3800000, "rew": 551.25, "rew_std": 61.10615762752556, "Agent": "dqn"}, {"env_step": 3900000, "rew": 520.0, "rew_std": 123.95079668965424, "Agent": "dqn"}, {"env_step": 4000000, "rew": 568.7, "rew_std": 94.99899999473679, "Agent": "dqn"}, {"env_step": 4100000, "rew": 540.9, "rew_std": 120.96668962983156, "Agent": "dqn"}, {"env_step": 4200000, "rew": 542.15, "rew_std": 113.86111056897346, "Agent": "dqn"}, {"env_step": 4300000, "rew": 564.95, "rew_std": 118.0547436573389, "Agent": "dqn"}, {"env_step": 4400000, "rew": 560.8, "rew_std": 101.43401796241733, "Agent": "dqn"}, {"env_step": 4500000, "rew": 572.3, "rew_std": 100.23228022947497, "Agent": "dqn"}, {"env_step": 4600000, "rew": 577.25, "rew_std": 122.57757747646998, "Agent": "dqn"}, {"env_step": 4700000, "rew": 597.7, "rew_std": 134.22410364759378, "Agent": "dqn"}, {"env_step": 4800000, "rew": 561.1, "rew_std": 95.81330805269172, "Agent": "dqn"}, {"env_step": 4900000, "rew": 556.8, "rew_std": 126.28364898117255, "Agent": "dqn"}, {"env_step": 5000000, "rew": 604.65, "rew_std": 157.45920900347494, "Agent": "dqn"}, {"env_step": 5100000, "rew": 551.65, "rew_std": 108.79270425906326, "Agent": "dqn"}, {"env_step": 5200000, "rew": 511.25, "rew_std": 115.59849696254706, "Agent": "dqn"}, {"env_step": 5300000, "rew": 551.55, "rew_std": 108.7601604448982, "Agent": "dqn"}, {"env_step": 5400000, "rew": 569.1, "rew_std": 100.45765276971187, "Agent": "dqn"}, {"env_step": 5500000, "rew": 502.8, "rew_std": 87.80495430213492, "Agent": "dqn"}, {"env_step": 5600000, "rew": 569.9, "rew_std": 101.9560689709053, "Agent": "dqn"}, {"env_step": 5700000, "rew": 518.95, "rew_std": 81.64326365353115, "Agent": "dqn"}, {"env_step": 5800000, "rew": 488.6, "rew_std": 77.79293027004447, "Agent": "dqn"}, {"env_step": 5900000, "rew": 546.3, "rew_std": 118.97693053697428, "Agent": "dqn"}, {"env_step": 6000000, "rew": 554.15, "rew_std": 153.27508766919692, "Agent": "dqn"}, {"env_step": 6100000, "rew": 515.95, "rew_std": 145.70954841739095, "Agent": "dqn"}, {"env_step": 6200000, "rew": 475.1, "rew_std": 145.77187657432418, "Agent": "dqn"}, {"env_step": 6300000, "rew": 
559.6, "rew_std": 122.7236326059492, "Agent": "dqn"}, {"env_step": 6400000, "rew": 511.9, "rew_std": 133.77664220632838, "Agent": "dqn"}, {"env_step": 6500000, "rew": 543.75, "rew_std": 91.17270699063398, "Agent": "dqn"}, {"env_step": 6600000, "rew": 521.35, "rew_std": 109.17922192432037, "Agent": "dqn"}, {"env_step": 6700000, "rew": 516.6, "rew_std": 75.09121120344244, "Agent": "dqn"}, {"env_step": 6800000, "rew": 583.9, "rew_std": 113.4287882329702, "Agent": "dqn"}, {"env_step": 6900000, "rew": 526.35, "rew_std": 107.32288898459639, "Agent": "dqn"}, {"env_step": 7000000, "rew": 540.0, "rew_std": 147.050841548085, "Agent": "dqn"}, {"env_step": 7100000, "rew": 521.75, "rew_std": 136.49674904553586, "Agent": "dqn"}, {"env_step": 7200000, "rew": 561.3, "rew_std": 84.21229126439917, "Agent": "dqn"}, {"env_step": 7300000, "rew": 491.65, "rew_std": 89.11146110349667, "Agent": "dqn"}, {"env_step": 7400000, "rew": 521.25, "rew_std": 41.02392594572099, "Agent": "dqn"}, {"env_step": 7500000, "rew": 501.35, "rew_std": 102.60459297711775, "Agent": "dqn"}, {"env_step": 7600000, "rew": 491.5, "rew_std": 119.28767748598344, "Agent": "dqn"}, {"env_step": 7700000, "rew": 503.5, "rew_std": 155.50707379408823, "Agent": "dqn"}, {"env_step": 7800000, "rew": 504.8, "rew_std": 110.91780740710664, "Agent": "dqn"}, {"env_step": 7900000, "rew": 551.55, "rew_std": 103.75245780221306, "Agent": "dqn"}, {"env_step": 8000000, "rew": 528.65, "rew_std": 35.64621298258764, "Agent": "dqn"}, {"env_step": 8100000, "rew": 521.05, "rew_std": 94.89137210516033, "Agent": "dqn"}, {"env_step": 8200000, "rew": 519.65, "rew_std": 79.50504700960813, "Agent": "dqn"}, {"env_step": 8300000, "rew": 544.2, "rew_std": 153.77226668030877, "Agent": "dqn"}, {"env_step": 8400000, "rew": 540.2, "rew_std": 130.77427116982912, "Agent": "dqn"}, {"env_step": 8500000, "rew": 524.4, "rew_std": 132.1672803684785, "Agent": "dqn"}, {"env_step": 8600000, "rew": 572.25, "rew_std": 186.15480788848834, "Agent": "dqn"}, {"env_step": 8700000, "rew": 564.35, "rew_std": 140.56280624688736, "Agent": "dqn"}, {"env_step": 8800000, "rew": 486.55, "rew_std": 130.3735881994509, "Agent": "dqn"}, {"env_step": 8900000, "rew": 576.0, "rew_std": 83.01776918226604, "Agent": "dqn"}, {"env_step": 9000000, "rew": 553.25, "rew_std": 74.99208291546515, "Agent": "dqn"}, {"env_step": 9100000, "rew": 488.35, "rew_std": 62.5244152311719, "Agent": "dqn"}, {"env_step": 9200000, "rew": 557.05, "rew_std": 100.2916870931983, "Agent": "dqn"}, {"env_step": 9300000, "rew": 541.2, "rew_std": 124.49461835758201, "Agent": "dqn"}, {"env_step": 9400000, "rew": 445.8, "rew_std": 82.59031420209031, "Agent": "dqn"}, {"env_step": 9500000, "rew": 530.65, "rew_std": 123.81034892124325, "Agent": "dqn"}, {"env_step": 9600000, "rew": 553.6, "rew_std": 85.22347094550891, "Agent": "dqn"}, {"env_step": 9700000, "rew": 590.65, "rew_std": 110.74724601542019, "Agent": "dqn"}, {"env_step": 9800000, "rew": 561.0, "rew_std": 100.72661018817222, "Agent": "dqn"}, {"env_step": 9900000, "rew": 525.75, "rew_std": 88.93038007340348, "Agent": "dqn"}, {"env_step": 10000000, "rew": 522.1, "rew_std": 134.1502515838118, "Agent": "dqn"}, {"env_step": 0, "rew": 230.2, "rew_std": 115.47709729639034, "Agent": "fqf"}, {"env_step": 100000, "rew": 197.8, "rew_std": 50.49366296873302, "Agent": "fqf"}, {"env_step": 200000, "rew": 274.45, "rew_std": 80.67882311982494, "Agent": "fqf"}, {"env_step": 300000, "rew": 331.75, "rew_std": 106.72845215780092, "Agent": "fqf"}, {"env_step": 400000, "rew": 342.3, "rew_std": 107.6478518132155, 
"Agent": "fqf"}, {"env_step": 500000, "rew": 344.15, "rew_std": 89.13586539659555, "Agent": "fqf"}, {"env_step": 600000, "rew": 418.7, "rew_std": 101.0124249783164, "Agent": "fqf"}, {"env_step": 700000, "rew": 455.75, "rew_std": 68.63499471843791, "Agent": "fqf"}, {"env_step": 800000, "rew": 513.6, "rew_std": 56.154608003261856, "Agent": "fqf"}, {"env_step": 900000, "rew": 530.0, "rew_std": 71.92808908903392, "Agent": "fqf"}, {"env_step": 1000000, "rew": 524.25, "rew_std": 79.38332633494265, "Agent": "fqf"}, {"env_step": 1100000, "rew": 552.4, "rew_std": 55.314916613875496, "Agent": "fqf"}, {"env_step": 1200000, "rew": 592.0, "rew_std": 123.01138158723363, "Agent": "fqf"}, {"env_step": 1300000, "rew": 626.2, "rew_std": 132.57247074713513, "Agent": "fqf"}, {"env_step": 1400000, "rew": 666.45, "rew_std": 91.45120283517325, "Agent": "fqf"}, {"env_step": 1500000, "rew": 633.95, "rew_std": 123.25632843793458, "Agent": "fqf"}, {"env_step": 1600000, "rew": 672.8, "rew_std": 103.08981520984506, "Agent": "fqf"}, {"env_step": 1700000, "rew": 617.5, "rew_std": 140.43023178788818, "Agent": "fqf"}, {"env_step": 1800000, "rew": 673.4, "rew_std": 67.20520813151315, "Agent": "fqf"}, {"env_step": 1900000, "rew": 668.5, "rew_std": 96.4898958440727, "Agent": "fqf"}, {"env_step": 2000000, "rew": 667.75, "rew_std": 174.6682927723289, "Agent": "fqf"}, {"env_step": 2100000, "rew": 699.8, "rew_std": 121.638028593035, "Agent": "fqf"}, {"env_step": 2200000, "rew": 714.25, "rew_std": 161.04211405716208, "Agent": "fqf"}, {"env_step": 2300000, "rew": 747.05, "rew_std": 150.87817105201134, "Agent": "fqf"}, {"env_step": 2400000, "rew": 735.6, "rew_std": 94.85483646077304, "Agent": "fqf"}, {"env_step": 2500000, "rew": 686.05, "rew_std": 107.51405722043978, "Agent": "fqf"}, {"env_step": 2600000, "rew": 727.95, "rew_std": 72.91243035312978, "Agent": "fqf"}, {"env_step": 2700000, "rew": 804.25, "rew_std": 120.79305650574456, "Agent": "fqf"}, {"env_step": 2800000, "rew": 799.6, "rew_std": 149.8772497746072, "Agent": "fqf"}, {"env_step": 2900000, "rew": 837.3, "rew_std": 153.77031573096286, "Agent": "fqf"}, {"env_step": 3000000, "rew": 825.05, "rew_std": 126.28864754996785, "Agent": "fqf"}, {"env_step": 3100000, "rew": 897.25, "rew_std": 165.72601636435965, "Agent": "fqf"}, {"env_step": 3200000, "rew": 835.4, "rew_std": 150.30997970860085, "Agent": "fqf"}, {"env_step": 3300000, "rew": 886.3, "rew_std": 185.1247147195641, "Agent": "fqf"}, {"env_step": 3400000, "rew": 787.0, "rew_std": 143.65931922433714, "Agent": "fqf"}, {"env_step": 3500000, "rew": 887.85, "rew_std": 202.29138019203884, "Agent": "fqf"}, {"env_step": 3600000, "rew": 860.05, "rew_std": 139.24967683984045, "Agent": "fqf"}, {"env_step": 3700000, "rew": 864.55, "rew_std": 175.68529392069217, "Agent": "fqf"}, {"env_step": 3800000, "rew": 982.3, "rew_std": 242.51723650083102, "Agent": "fqf"}, {"env_step": 3900000, "rew": 976.7, "rew_std": 136.36333818149217, "Agent": "fqf"}, {"env_step": 4000000, "rew": 940.7, "rew_std": 151.9544668642551, "Agent": "fqf"}, {"env_step": 4100000, "rew": 923.85, "rew_std": 171.70586041250894, "Agent": "fqf"}, {"env_step": 4200000, "rew": 1001.85, "rew_std": 166.15249772422925, "Agent": "fqf"}, {"env_step": 4300000, "rew": 1156.5, "rew_std": 218.98983994697107, "Agent": "fqf"}, {"env_step": 4400000, "rew": 1059.3, "rew_std": 177.61998761400702, "Agent": "fqf"}, {"env_step": 4500000, "rew": 1082.8, "rew_std": 197.79777551833087, "Agent": "fqf"}, {"env_step": 4600000, "rew": 1097.7, "rew_std": 135.50944616520283, "Agent": "fqf"}, 
{"env_step": 4700000, "rew": 1051.95, "rew_std": 234.61686320467248, "Agent": "fqf"}, {"env_step": 4800000, "rew": 967.5, "rew_std": 162.95152653473363, "Agent": "fqf"}, {"env_step": 4900000, "rew": 987.2, "rew_std": 236.65536123232027, "Agent": "fqf"}, {"env_step": 5000000, "rew": 1005.5, "rew_std": 246.8736518950534, "Agent": "fqf"}, {"env_step": 5100000, "rew": 1098.95, "rew_std": 251.61572387273415, "Agent": "fqf"}, {"env_step": 5200000, "rew": 1028.55, "rew_std": 254.31441661848427, "Agent": "fqf"}, {"env_step": 5300000, "rew": 1025.2, "rew_std": 199.16051315459094, "Agent": "fqf"}, {"env_step": 5400000, "rew": 1034.65, "rew_std": 224.61445300781517, "Agent": "fqf"}, {"env_step": 5500000, "rew": 1263.2, "rew_std": 179.57897983895555, "Agent": "fqf"}, {"env_step": 5600000, "rew": 1016.95, "rew_std": 171.7589080659283, "Agent": "fqf"}, {"env_step": 5700000, "rew": 1224.7, "rew_std": 191.47979005628764, "Agent": "fqf"}, {"env_step": 5800000, "rew": 1192.7, "rew_std": 211.82247283987599, "Agent": "fqf"}, {"env_step": 5900000, "rew": 1256.45, "rew_std": 423.3461024977081, "Agent": "fqf"}, {"env_step": 6000000, "rew": 1206.3, "rew_std": 349.3533454827648, "Agent": "fqf"}, {"env_step": 6100000, "rew": 1323.8, "rew_std": 327.38587324440255, "Agent": "fqf"}, {"env_step": 6200000, "rew": 1459.2, "rew_std": 292.0249304425908, "Agent": "fqf"}, {"env_step": 6300000, "rew": 1187.2, "rew_std": 343.017506841852, "Agent": "fqf"}, {"env_step": 6400000, "rew": 1257.5, "rew_std": 311.28724676735476, "Agent": "fqf"}, {"env_step": 6500000, "rew": 1111.55, "rew_std": 231.75153181802273, "Agent": "fqf"}, {"env_step": 6600000, "rew": 1306.1, "rew_std": 182.76703750950279, "Agent": "fqf"}, {"env_step": 6700000, "rew": 1163.85, "rew_std": 369.6850044294467, "Agent": "fqf"}, {"env_step": 6800000, "rew": 1146.1, "rew_std": 217.4428430645626, "Agent": "fqf"}, {"env_step": 6900000, "rew": 1197.9, "rew_std": 226.09840777856002, "Agent": "fqf"}, {"env_step": 7000000, "rew": 1633.55, "rew_std": 109.7614344840664, "Agent": "fqf"}, {"env_step": 7100000, "rew": 1409.85, "rew_std": 353.492153378261, "Agent": "fqf"}, {"env_step": 7200000, "rew": 1372.45, "rew_std": 253.83640893299764, "Agent": "fqf"}, {"env_step": 7300000, "rew": 1275.55, "rew_std": 352.75008504605637, "Agent": "fqf"}, {"env_step": 7400000, "rew": 1356.95, "rew_std": 431.4569184750663, "Agent": "fqf"}, {"env_step": 7500000, "rew": 1394.0, "rew_std": 367.9727571437864, "Agent": "fqf"}, {"env_step": 7600000, "rew": 1537.6, "rew_std": 332.01941208308887, "Agent": "fqf"}, {"env_step": 7700000, "rew": 1574.95, "rew_std": 366.9757110491102, "Agent": "fqf"}, {"env_step": 7800000, "rew": 1337.1, "rew_std": 339.9577767900008, "Agent": "fqf"}, {"env_step": 7900000, "rew": 1460.65, "rew_std": 323.72148909208977, "Agent": "fqf"}, {"env_step": 8000000, "rew": 1490.3, "rew_std": 428.0282817758658, "Agent": "fqf"}, {"env_step": 8100000, "rew": 1340.1, "rew_std": 215.97648020097003, "Agent": "fqf"}, {"env_step": 8200000, "rew": 1639.3, "rew_std": 353.53735304773664, "Agent": "fqf"}, {"env_step": 8300000, "rew": 1621.2, "rew_std": 447.873486600848, "Agent": "fqf"}, {"env_step": 8400000, "rew": 1636.0, "rew_std": 445.03949262958673, "Agent": "fqf"}, {"env_step": 8500000, "rew": 1507.95, "rew_std": 247.5409511575812, "Agent": "fqf"}, {"env_step": 8600000, "rew": 1481.65, "rew_std": 305.7656005831918, "Agent": "fqf"}, {"env_step": 8700000, "rew": 1612.35, "rew_std": 260.7944640900186, "Agent": "fqf"}, {"env_step": 8800000, "rew": 1461.6, "rew_std": 150.4408189289064, "Agent": 
"fqf"}, {"env_step": 8900000, "rew": 1593.3, "rew_std": 260.38521463401105, "Agent": "fqf"}, {"env_step": 9000000, "rew": 1542.55, "rew_std": 377.05214825007954, "Agent": "fqf"}, {"env_step": 9100000, "rew": 1562.2, "rew_std": 291.7097187273678, "Agent": "fqf"}, {"env_step": 9200000, "rew": 1645.4, "rew_std": 301.2403359445743, "Agent": "fqf"}, {"env_step": 9300000, "rew": 1787.55, "rew_std": 340.77921371468653, "Agent": "fqf"}, {"env_step": 9400000, "rew": 1669.3, "rew_std": 312.634627000913, "Agent": "fqf"}, {"env_step": 9500000, "rew": 1691.45, "rew_std": 373.35073657353354, "Agent": "fqf"}, {"env_step": 9600000, "rew": 1444.45, "rew_std": 174.3102765186264, "Agent": "fqf"}, {"env_step": 9700000, "rew": 1547.25, "rew_std": 277.2487375985687, "Agent": "fqf"}, {"env_step": 9800000, "rew": 1697.55, "rew_std": 281.5422215228117, "Agent": "fqf"}, {"env_step": 9900000, "rew": 1566.15, "rew_std": 436.8016168697181, "Agent": "fqf"}, {"env_step": 10000000, "rew": 1580.2, "rew_std": 413.62206179071256, "Agent": "fqf"}, {"env_step": 0, "rew": 104.2, "rew_std": 85.97418217116113, "Agent": "qrdqn"}, {"env_step": 100000, "rew": 222.7, "rew_std": 70.10249638921569, "Agent": "qrdqn"}, {"env_step": 200000, "rew": 284.55, "rew_std": 65.67931561762806, "Agent": "qrdqn"}, {"env_step": 300000, "rew": 298.65, "rew_std": 112.2911505863218, "Agent": "qrdqn"}, {"env_step": 400000, "rew": 401.8, "rew_std": 97.01886414507233, "Agent": "qrdqn"}, {"env_step": 500000, "rew": 307.5, "rew_std": 84.26268450506429, "Agent": "qrdqn"}, {"env_step": 600000, "rew": 300.85, "rew_std": 93.35230313173854, "Agent": "qrdqn"}, {"env_step": 700000, "rew": 326.1, "rew_std": 88.10839914559793, "Agent": "qrdqn"}, {"env_step": 800000, "rew": 373.1, "rew_std": 67.90500717914696, "Agent": "qrdqn"}, {"env_step": 900000, "rew": 435.5, "rew_std": 72.41926539257355, "Agent": "qrdqn"}, {"env_step": 1000000, "rew": 410.55, "rew_std": 76.28939965683306, "Agent": "qrdqn"}, {"env_step": 1100000, "rew": 413.0, "rew_std": 106.63043655542258, "Agent": "qrdqn"}, {"env_step": 1200000, "rew": 435.95, "rew_std": 79.69894917751425, "Agent": "qrdqn"}, {"env_step": 1300000, "rew": 429.0, "rew_std": 77.78110310351737, "Agent": "qrdqn"}, {"env_step": 1400000, "rew": 486.8, "rew_std": 72.63580384355913, "Agent": "qrdqn"}, {"env_step": 1500000, "rew": 430.25, "rew_std": 113.51965688813546, "Agent": "qrdqn"}, {"env_step": 1600000, "rew": 468.6, "rew_std": 107.82086996495623, "Agent": "qrdqn"}, {"env_step": 1700000, "rew": 475.6, "rew_std": 44.46223116308942, "Agent": "qrdqn"}, {"env_step": 1800000, "rew": 501.05, "rew_std": 96.80765723846436, "Agent": "qrdqn"}, {"env_step": 1900000, "rew": 462.0, "rew_std": 61.099099829702894, "Agent": "qrdqn"}, {"env_step": 2000000, "rew": 496.1, "rew_std": 102.8520296348108, "Agent": "qrdqn"}, {"env_step": 2100000, "rew": 519.55, "rew_std": 87.50984230359464, "Agent": "qrdqn"}, {"env_step": 2200000, "rew": 485.35, "rew_std": 75.20473721781096, "Agent": "qrdqn"}, {"env_step": 2300000, "rew": 512.45, "rew_std": 120.2733241413074, "Agent": "qrdqn"}, {"env_step": 2400000, "rew": 489.4, "rew_std": 75.57109235680004, "Agent": "qrdqn"}, {"env_step": 2500000, "rew": 511.45, "rew_std": 56.399224285445634, "Agent": "qrdqn"}, {"env_step": 2600000, "rew": 513.45, "rew_std": 94.62939553859572, "Agent": "qrdqn"}, {"env_step": 2700000, "rew": 497.9, "rew_std": 62.18231259771544, "Agent": "qrdqn"}, {"env_step": 2800000, "rew": 509.8, "rew_std": 98.18966340710207, "Agent": "qrdqn"}, {"env_step": 2900000, "rew": 481.3, "rew_std": 
49.9585828461937, "Agent": "qrdqn"}, {"env_step": 3000000, "rew": 519.35, "rew_std": 65.75106463016398, "Agent": "qrdqn"}, {"env_step": 3100000, "rew": 485.7, "rew_std": 51.05692901066417, "Agent": "qrdqn"}, {"env_step": 3200000, "rew": 518.6, "rew_std": 81.14117327226664, "Agent": "qrdqn"}, {"env_step": 3300000, "rew": 559.25, "rew_std": 62.88889011582252, "Agent": "qrdqn"}, {"env_step": 3400000, "rew": 512.15, "rew_std": 106.45211364740486, "Agent": "qrdqn"}, {"env_step": 3500000, "rew": 522.7, "rew_std": 51.48456079253275, "Agent": "qrdqn"}, {"env_step": 3600000, "rew": 565.4, "rew_std": 101.5477227711188, "Agent": "qrdqn"}, {"env_step": 3700000, "rew": 577.85, "rew_std": 100.99580436830037, "Agent": "qrdqn"}, {"env_step": 3800000, "rew": 509.1, "rew_std": 87.09069984791716, "Agent": "qrdqn"}, {"env_step": 3900000, "rew": 546.35, "rew_std": 75.80997625642684, "Agent": "qrdqn"}, {"env_step": 4000000, "rew": 516.45, "rew_std": 87.6491443198392, "Agent": "qrdqn"}, {"env_step": 4100000, "rew": 520.0, "rew_std": 117.6093108559012, "Agent": "qrdqn"}, {"env_step": 4200000, "rew": 546.85, "rew_std": 85.65396955191278, "Agent": "qrdqn"}, {"env_step": 4300000, "rew": 545.15, "rew_std": 87.88431316224755, "Agent": "qrdqn"}, {"env_step": 4400000, "rew": 489.25, "rew_std": 74.9827480157936, "Agent": "qrdqn"}, {"env_step": 4500000, "rew": 593.25, "rew_std": 62.53169196495486, "Agent": "qrdqn"}, {"env_step": 4600000, "rew": 527.6, "rew_std": 122.62091991173449, "Agent": "qrdqn"}, {"env_step": 4700000, "rew": 520.9, "rew_std": 74.23099083267041, "Agent": "qrdqn"}, {"env_step": 4800000, "rew": 598.05, "rew_std": 122.61779030793205, "Agent": "qrdqn"}, {"env_step": 4900000, "rew": 545.5, "rew_std": 72.01284607623838, "Agent": "qrdqn"}, {"env_step": 5000000, "rew": 603.55, "rew_std": 73.38203117930165, "Agent": "qrdqn"}, {"env_step": 5100000, "rew": 559.75, "rew_std": 47.47433517175359, "Agent": "qrdqn"}, {"env_step": 5200000, "rew": 558.3, "rew_std": 102.89951409020357, "Agent": "qrdqn"}, {"env_step": 5300000, "rew": 614.8, "rew_std": 85.48280528854912, "Agent": "qrdqn"}, {"env_step": 5400000, "rew": 604.75, "rew_std": 84.51220326083092, "Agent": "qrdqn"}, {"env_step": 5500000, "rew": 611.15, "rew_std": 93.74754663456532, "Agent": "qrdqn"}, {"env_step": 5600000, "rew": 520.85, "rew_std": 106.81691111429876, "Agent": "qrdqn"}, {"env_step": 5700000, "rew": 660.4, "rew_std": 123.0060567614457, "Agent": "qrdqn"}, {"env_step": 5800000, "rew": 585.9, "rew_std": 127.98120955827851, "Agent": "qrdqn"}, {"env_step": 5900000, "rew": 570.45, "rew_std": 98.77789479433139, "Agent": "qrdqn"}, {"env_step": 6000000, "rew": 641.5, "rew_std": 90.51408730136983, "Agent": "qrdqn"}, {"env_step": 6100000, "rew": 592.95, "rew_std": 103.75438544948354, "Agent": "qrdqn"}, {"env_step": 6200000, "rew": 612.3, "rew_std": 82.66323245554823, "Agent": "qrdqn"}, {"env_step": 6300000, "rew": 642.25, "rew_std": 79.97319863554289, "Agent": "qrdqn"}, {"env_step": 6400000, "rew": 652.8, "rew_std": 100.57017450516828, "Agent": "qrdqn"}, {"env_step": 6500000, "rew": 617.95, "rew_std": 120.11233283888878, "Agent": "qrdqn"}, {"env_step": 6600000, "rew": 579.0, "rew_std": 103.84363244802255, "Agent": "qrdqn"}, {"env_step": 6700000, "rew": 566.85, "rew_std": 83.09273433941141, "Agent": "qrdqn"}, {"env_step": 6800000, "rew": 572.8, "rew_std": 120.3420541622919, "Agent": "qrdqn"}, {"env_step": 6900000, "rew": 600.65, "rew_std": 63.689893232757115, "Agent": "qrdqn"}, {"env_step": 7000000, "rew": 576.3, "rew_std": 126.27830375800905, "Agent": "qrdqn"}, 
{"env_step": 7100000, "rew": 573.25, "rew_std": 80.12154828758615, "Agent": "qrdqn"}, {"env_step": 7200000, "rew": 580.7, "rew_std": 85.19865022404991, "Agent": "qrdqn"}, {"env_step": 7300000, "rew": 549.95, "rew_std": 106.52029149415617, "Agent": "qrdqn"}, {"env_step": 7400000, "rew": 559.15, "rew_std": 76.05657433779146, "Agent": "qrdqn"}, {"env_step": 7500000, "rew": 558.75, "rew_std": 96.65175890794745, "Agent": "qrdqn"}, {"env_step": 7600000, "rew": 628.15, "rew_std": 106.18169569186584, "Agent": "qrdqn"}, {"env_step": 7700000, "rew": 630.65, "rew_std": 103.51402078945634, "Agent": "qrdqn"}, {"env_step": 7800000, "rew": 617.2, "rew_std": 153.99938311564756, "Agent": "qrdqn"}, {"env_step": 7900000, "rew": 596.9, "rew_std": 75.67886098508619, "Agent": "qrdqn"}, {"env_step": 8000000, "rew": 528.0, "rew_std": 121.1131702169504, "Agent": "qrdqn"}, {"env_step": 8100000, "rew": 606.8, "rew_std": 102.84313297444804, "Agent": "qrdqn"}, {"env_step": 8200000, "rew": 591.8, "rew_std": 142.44142655842785, "Agent": "qrdqn"}, {"env_step": 8300000, "rew": 550.85, "rew_std": 137.40925187191726, "Agent": "qrdqn"}, {"env_step": 8400000, "rew": 569.7, "rew_std": 142.38858802586674, "Agent": "qrdqn"}, {"env_step": 8500000, "rew": 603.85, "rew_std": 53.07732566736949, "Agent": "qrdqn"}, {"env_step": 8600000, "rew": 570.5, "rew_std": 92.73537620563148, "Agent": "qrdqn"}, {"env_step": 8700000, "rew": 667.8, "rew_std": 81.47244933104687, "Agent": "qrdqn"}, {"env_step": 8800000, "rew": 550.0, "rew_std": 169.62207993065053, "Agent": "qrdqn"}, {"env_step": 8900000, "rew": 636.75, "rew_std": 63.59490938746591, "Agent": "qrdqn"}, {"env_step": 9000000, "rew": 586.6, "rew_std": 137.99434771033194, "Agent": "qrdqn"}, {"env_step": 9100000, "rew": 609.55, "rew_std": 123.41463649016674, "Agent": "qrdqn"}, {"env_step": 9200000, "rew": 626.9, "rew_std": 111.03260782310753, "Agent": "qrdqn"}, {"env_step": 9300000, "rew": 624.4, "rew_std": 161.2231062844281, "Agent": "qrdqn"}, {"env_step": 9400000, "rew": 633.05, "rew_std": 107.944997568206, "Agent": "qrdqn"}, {"env_step": 9500000, "rew": 552.85, "rew_std": 160.7512130591866, "Agent": "qrdqn"}, {"env_step": 9600000, "rew": 555.75, "rew_std": 80.15804700714708, "Agent": "qrdqn"}, {"env_step": 9700000, "rew": 582.25, "rew_std": 108.67411145254421, "Agent": "qrdqn"}, {"env_step": 9800000, "rew": 635.15, "rew_std": 106.95303876000905, "Agent": "qrdqn"}, {"env_step": 9900000, "rew": 597.35, "rew_std": 98.45660211484042, "Agent": "qrdqn"}, {"env_step": 10000000, "rew": 550.15, "rew_std": 221.34204864869213, "Agent": "qrdqn"}, {"env_step": 0, "rew": 193.85, "rew_std": 59.73192195133185, "Agent": "iqn"}, {"env_step": 100000, "rew": 178.3, "rew_std": 102.48638934024362, "Agent": "iqn"}, {"env_step": 200000, "rew": 275.5, "rew_std": 63.457466069801434, "Agent": "iqn"}, {"env_step": 300000, "rew": 309.5, "rew_std": 43.517812445020716, "Agent": "iqn"}, {"env_step": 400000, "rew": 321.25, "rew_std": 78.36525058978629, "Agent": "iqn"}, {"env_step": 500000, "rew": 374.9, "rew_std": 70.88398972969847, "Agent": "iqn"}, {"env_step": 600000, "rew": 344.3, "rew_std": 138.2237678548809, "Agent": "iqn"}, {"env_step": 700000, "rew": 402.65, "rew_std": 115.49524882002723, "Agent": "iqn"}, {"env_step": 800000, "rew": 502.4, "rew_std": 117.13129385437523, "Agent": "iqn"}, {"env_step": 900000, "rew": 550.95, "rew_std": 73.10145347392212, "Agent": "iqn"}, {"env_step": 1000000, "rew": 542.7, "rew_std": 35.80237422294784, "Agent": "iqn"}, {"env_step": 1100000, "rew": 579.8, "rew_std": 73.14273716508016, 
"Agent": "iqn"}, {"env_step": 1200000, "rew": 617.6, "rew_std": 98.75621499429795, "Agent": "iqn"}, {"env_step": 1300000, "rew": 650.15, "rew_std": 124.93219160808795, "Agent": "iqn"}, {"env_step": 1400000, "rew": 666.45, "rew_std": 72.53014890374898, "Agent": "iqn"}, {"env_step": 1500000, "rew": 619.95, "rew_std": 84.52113640977622, "Agent": "iqn"}, {"env_step": 1600000, "rew": 633.65, "rew_std": 143.48554805275688, "Agent": "iqn"}, {"env_step": 1700000, "rew": 659.7, "rew_std": 71.16923492633597, "Agent": "iqn"}, {"env_step": 1800000, "rew": 746.45, "rew_std": 159.69368334408222, "Agent": "iqn"}, {"env_step": 1900000, "rew": 713.35, "rew_std": 60.149418118548745, "Agent": "iqn"}, {"env_step": 2000000, "rew": 708.95, "rew_std": 140.94013800191908, "Agent": "iqn"}, {"env_step": 2100000, "rew": 723.65, "rew_std": 82.80701962998064, "Agent": "iqn"}, {"env_step": 2200000, "rew": 680.25, "rew_std": 95.15467671113176, "Agent": "iqn"}, {"env_step": 2300000, "rew": 799.4, "rew_std": 105.6581279410155, "Agent": "iqn"}, {"env_step": 2400000, "rew": 761.5, "rew_std": 83.66719787347967, "Agent": "iqn"}, {"env_step": 2500000, "rew": 796.4, "rew_std": 124.04732967702287, "Agent": "iqn"}, {"env_step": 2600000, "rew": 689.55, "rew_std": 71.40919058496603, "Agent": "iqn"}, {"env_step": 2700000, "rew": 688.0, "rew_std": 80.3962063781619, "Agent": "iqn"}, {"env_step": 2800000, "rew": 757.45, "rew_std": 125.27738223637978, "Agent": "iqn"}, {"env_step": 2900000, "rew": 756.0, "rew_std": 107.79216112500946, "Agent": "iqn"}, {"env_step": 3000000, "rew": 744.8, "rew_std": 124.25441642050394, "Agent": "iqn"}, {"env_step": 3100000, "rew": 812.6, "rew_std": 155.1967783170772, "Agent": "iqn"}, {"env_step": 3200000, "rew": 757.55, "rew_std": 168.36262203945387, "Agent": "iqn"}, {"env_step": 3300000, "rew": 770.5, "rew_std": 111.00945905642456, "Agent": "iqn"}, {"env_step": 3400000, "rew": 759.45, "rew_std": 102.85267376203694, "Agent": "iqn"}, {"env_step": 3500000, "rew": 850.85, "rew_std": 200.12159428707338, "Agent": "iqn"}, {"env_step": 3600000, "rew": 785.85, "rew_std": 128.4011390136396, "Agent": "iqn"}, {"env_step": 3700000, "rew": 787.85, "rew_std": 137.1677895863311, "Agent": "iqn"}, {"env_step": 3800000, "rew": 791.75, "rew_std": 188.92634146672083, "Agent": "iqn"}, {"env_step": 3900000, "rew": 774.5, "rew_std": 60.62507731953833, "Agent": "iqn"}, {"env_step": 4000000, "rew": 872.55, "rew_std": 194.64755970728223, "Agent": "iqn"}, {"env_step": 4100000, "rew": 782.7, "rew_std": 128.8107526567561, "Agent": "iqn"}, {"env_step": 4200000, "rew": 826.2, "rew_std": 168.36124851045741, "Agent": "iqn"}, {"env_step": 4300000, "rew": 795.05, "rew_std": 154.91133754506157, "Agent": "iqn"}, {"env_step": 4400000, "rew": 824.45, "rew_std": 175.53866383221674, "Agent": "iqn"}, {"env_step": 4500000, "rew": 912.3, "rew_std": 182.14118150489747, "Agent": "iqn"}, {"env_step": 4600000, "rew": 857.55, "rew_std": 158.4999132491876, "Agent": "iqn"}, {"env_step": 4700000, "rew": 815.05, "rew_std": 86.60411364363705, "Agent": "iqn"}, {"env_step": 4800000, "rew": 806.95, "rew_std": 147.00245746245196, "Agent": "iqn"}, {"env_step": 4900000, "rew": 912.15, "rew_std": 95.28878475455545, "Agent": "iqn"}, {"env_step": 5000000, "rew": 883.0, "rew_std": 149.37553347185073, "Agent": "iqn"}, {"env_step": 5100000, "rew": 886.6, "rew_std": 178.2464305392958, "Agent": "iqn"}, {"env_step": 5200000, "rew": 933.15, "rew_std": 92.64341584807849, "Agent": "iqn"}, {"env_step": 5300000, "rew": 874.25, "rew_std": 130.7188299366239, "Agent": "iqn"}, 
{"env_step": 5400000, "rew": 882.05, "rew_std": 148.7500336134416, "Agent": "iqn"}, {"env_step": 5500000, "rew": 801.65, "rew_std": 157.15773127657448, "Agent": "iqn"}, {"env_step": 5600000, "rew": 927.7, "rew_std": 267.2525210358174, "Agent": "iqn"}, {"env_step": 5700000, "rew": 952.6, "rew_std": 184.16715233721783, "Agent": "iqn"}, {"env_step": 5800000, "rew": 857.65, "rew_std": 235.34772677890902, "Agent": "iqn"}, {"env_step": 5900000, "rew": 836.4, "rew_std": 238.53121389034183, "Agent": "iqn"}, {"env_step": 6000000, "rew": 890.35, "rew_std": 114.51791344588845, "Agent": "iqn"}, {"env_step": 6100000, "rew": 935.55, "rew_std": 166.9722207434518, "Agent": "iqn"}, {"env_step": 6200000, "rew": 941.05, "rew_std": 148.40981268096797, "Agent": "iqn"}, {"env_step": 6300000, "rew": 965.9, "rew_std": 189.78208556130897, "Agent": "iqn"}, {"env_step": 6400000, "rew": 875.15, "rew_std": 215.37328641221964, "Agent": "iqn"}, {"env_step": 6500000, "rew": 939.0, "rew_std": 153.9548635152524, "Agent": "iqn"}, {"env_step": 6600000, "rew": 928.2, "rew_std": 232.79499565067974, "Agent": "iqn"}, {"env_step": 6700000, "rew": 847.65, "rew_std": 67.1215501906802, "Agent": "iqn"}, {"env_step": 6800000, "rew": 961.7, "rew_std": 153.62831770217363, "Agent": "iqn"}, {"env_step": 6900000, "rew": 917.75, "rew_std": 210.0694230486674, "Agent": "iqn"}, {"env_step": 7000000, "rew": 887.35, "rew_std": 105.70029564764708, "Agent": "iqn"}, {"env_step": 7100000, "rew": 958.85, "rew_std": 133.01297869005117, "Agent": "iqn"}, {"env_step": 7200000, "rew": 886.9, "rew_std": 131.38260919923914, "Agent": "iqn"}, {"env_step": 7300000, "rew": 919.7, "rew_std": 193.17805258362037, "Agent": "iqn"}, {"env_step": 7400000, "rew": 843.9, "rew_std": 261.7283706440706, "Agent": "iqn"}, {"env_step": 7500000, "rew": 865.8, "rew_std": 213.59625464881168, "Agent": "iqn"}, {"env_step": 7600000, "rew": 951.8, "rew_std": 178.61973575168003, "Agent": "iqn"}, {"env_step": 7700000, "rew": 889.7, "rew_std": 187.84557487468265, "Agent": "iqn"}, {"env_step": 7800000, "rew": 977.9, "rew_std": 223.82055312236184, "Agent": "iqn"}, {"env_step": 7900000, "rew": 909.35, "rew_std": 181.14414840121114, "Agent": "iqn"}, {"env_step": 8000000, "rew": 985.15, "rew_std": 156.5410569147915, "Agent": "iqn"}, {"env_step": 8100000, "rew": 1051.1, "rew_std": 157.7045972697055, "Agent": "iqn"}, {"env_step": 8200000, "rew": 953.55, "rew_std": 201.69462684960152, "Agent": "iqn"}, {"env_step": 8300000, "rew": 951.6, "rew_std": 148.2414921673416, "Agent": "iqn"}, {"env_step": 8400000, "rew": 930.85, "rew_std": 137.40233804415413, "Agent": "iqn"}, {"env_step": 8500000, "rew": 1027.05, "rew_std": 82.30353880605621, "Agent": "iqn"}, {"env_step": 8600000, "rew": 999.15, "rew_std": 210.70109752917756, "Agent": "iqn"}, {"env_step": 8700000, "rew": 1005.65, "rew_std": 197.8955595762573, "Agent": "iqn"}, {"env_step": 8800000, "rew": 1114.7, "rew_std": 116.91389994350544, "Agent": "iqn"}, {"env_step": 8900000, "rew": 955.6, "rew_std": 172.78857022384324, "Agent": "iqn"}, {"env_step": 9000000, "rew": 858.45, "rew_std": 104.03327592650344, "Agent": "iqn"}, {"env_step": 9100000, "rew": 887.4, "rew_std": 217.57881330681073, "Agent": "iqn"}, {"env_step": 9200000, "rew": 965.85, "rew_std": 178.4683515360637, "Agent": "iqn"}, {"env_step": 9300000, "rew": 970.6, "rew_std": 139.28151348976647, "Agent": "iqn"}, {"env_step": 9400000, "rew": 964.0, "rew_std": 120.89272103811709, "Agent": "iqn"}, {"env_step": 9500000, "rew": 993.35, "rew_std": 285.66554307441424, "Agent": "iqn"}, {"env_step": 
9600000, "rew": 965.5, "rew_std": 75.48609143411785, "Agent": "iqn"}, {"env_step": 9700000, "rew": 984.5, "rew_std": 142.40224717328024, "Agent": "iqn"}, {"env_step": 9800000, "rew": 959.0, "rew_std": 233.88319734431545, "Agent": "iqn"}, {"env_step": 9900000, "rew": 1060.1, "rew_std": 261.16113799721427, "Agent": "iqn"}, {"env_step": 10000000, "rew": 966.65, "rew_std": 156.9350263644162, "Agent": "iqn"}, {"env_step": 0, "rew": 129.2, "rew_std": 69.53567429744247, "Agent": "rainbow"}, {"env_step": 100000, "rew": 177.85, "rew_std": 75.14188246244566, "Agent": "rainbow"}, {"env_step": 200000, "rew": 198.55, "rew_std": 53.13823952672877, "Agent": "rainbow"}, {"env_step": 300000, "rew": 274.85, "rew_std": 47.645592660811765, "Agent": "rainbow"}, {"env_step": 400000, "rew": 297.05, "rew_std": 50.52496907470602, "Agent": "rainbow"}, {"env_step": 500000, "rew": 363.35, "rew_std": 90.23110605550616, "Agent": "rainbow"}, {"env_step": 600000, "rew": 377.1, "rew_std": 62.0112086642407, "Agent": "rainbow"}, {"env_step": 700000, "rew": 449.6, "rew_std": 20.95924616965028, "Agent": "rainbow"}, {"env_step": 800000, "rew": 478.55, "rew_std": 50.6070400240915, "Agent": "rainbow"}, {"env_step": 900000, "rew": 467.45, "rew_std": 44.51092562506424, "Agent": "rainbow"}, {"env_step": 1000000, "rew": 518.3, "rew_std": 61.56102663211522, "Agent": "rainbow"}, {"env_step": 1100000, "rew": 546.95, "rew_std": 26.70809802288437, "Agent": "rainbow"}, {"env_step": 1200000, "rew": 539.1, "rew_std": 56.16929766340327, "Agent": "rainbow"}, {"env_step": 1300000, "rew": 585.25, "rew_std": 74.84425495654293, "Agent": "rainbow"}, {"env_step": 1400000, "rew": 547.5, "rew_std": 76.41760268419836, "Agent": "rainbow"}, {"env_step": 1500000, "rew": 622.95, "rew_std": 78.23376828454577, "Agent": "rainbow"}, {"env_step": 1600000, "rew": 608.95, "rew_std": 82.29776728441665, "Agent": "rainbow"}, {"env_step": 1700000, "rew": 593.7, "rew_std": 66.87682707784514, "Agent": "rainbow"}, {"env_step": 1800000, "rew": 589.45, "rew_std": 61.96388060798, "Agent": "rainbow"}, {"env_step": 1900000, "rew": 616.65, "rew_std": 62.149839098745865, "Agent": "rainbow"}, {"env_step": 2000000, "rew": 625.85, "rew_std": 97.02269064502386, "Agent": "rainbow"}, {"env_step": 2100000, "rew": 625.05, "rew_std": 74.01602866947131, "Agent": "rainbow"}, {"env_step": 2200000, "rew": 604.05, "rew_std": 120.92465629473585, "Agent": "rainbow"}, {"env_step": 2300000, "rew": 645.65, "rew_std": 99.99626243015285, "Agent": "rainbow"}, {"env_step": 2400000, "rew": 700.1, "rew_std": 88.73691452828412, "Agent": "rainbow"}, {"env_step": 2500000, "rew": 651.45, "rew_std": 63.471430581010225, "Agent": "rainbow"}, {"env_step": 2600000, "rew": 664.25, "rew_std": 97.90409848417991, "Agent": "rainbow"}, {"env_step": 2700000, "rew": 724.25, "rew_std": 75.68726775356606, "Agent": "rainbow"}, {"env_step": 2800000, "rew": 670.4, "rew_std": 85.67461701110778, "Agent": "rainbow"}, {"env_step": 2900000, "rew": 741.0, "rew_std": 110.0034090380839, "Agent": "rainbow"}, {"env_step": 3000000, "rew": 760.65, "rew_std": 146.2280838279706, "Agent": "rainbow"}, {"env_step": 3100000, "rew": 758.85, "rew_std": 83.78127774150977, "Agent": "rainbow"}, {"env_step": 3200000, "rew": 781.65, "rew_std": 80.49038762485866, "Agent": "rainbow"}, {"env_step": 3300000, "rew": 759.45, "rew_std": 155.77362581643916, "Agent": "rainbow"}, {"env_step": 3400000, "rew": 764.45, "rew_std": 146.94462392343587, "Agent": "rainbow"}, {"env_step": 3500000, "rew": 801.25, "rew_std": 66.33259002933626, "Agent": "rainbow"}, 
{"env_step": 3600000, "rew": 816.85, "rew_std": 83.54371610121254, "Agent": "rainbow"}, {"env_step": 3700000, "rew": 808.7, "rew_std": 91.75380101118428, "Agent": "rainbow"}, {"env_step": 3800000, "rew": 858.65, "rew_std": 116.32176279613373, "Agent": "rainbow"}, {"env_step": 3900000, "rew": 785.0, "rew_std": 113.13421233207929, "Agent": "rainbow"}, {"env_step": 4000000, "rew": 784.1, "rew_std": 148.00249997888548, "Agent": "rainbow"}, {"env_step": 4100000, "rew": 847.8, "rew_std": 141.88837161656343, "Agent": "rainbow"}, {"env_step": 4200000, "rew": 864.45, "rew_std": 98.91623981935423, "Agent": "rainbow"}, {"env_step": 4300000, "rew": 874.8, "rew_std": 130.1499519784775, "Agent": "rainbow"}, {"env_step": 4400000, "rew": 896.05, "rew_std": 163.69735031453627, "Agent": "rainbow"}, {"env_step": 4500000, "rew": 803.35, "rew_std": 121.67580901724055, "Agent": "rainbow"}, {"env_step": 4600000, "rew": 883.9, "rew_std": 127.14004089978893, "Agent": "rainbow"}, {"env_step": 4700000, "rew": 884.0, "rew_std": 62.155450283945335, "Agent": "rainbow"}, {"env_step": 4800000, "rew": 880.95, "rew_std": 157.68995687741182, "Agent": "rainbow"}, {"env_step": 4900000, "rew": 889.8, "rew_std": 118.07925304641795, "Agent": "rainbow"}, {"env_step": 5000000, "rew": 926.05, "rew_std": 152.34573344862665, "Agent": "rainbow"}, {"env_step": 5100000, "rew": 909.15, "rew_std": 130.32882451706527, "Agent": "rainbow"}, {"env_step": 5200000, "rew": 860.3, "rew_std": 148.12835650205534, "Agent": "rainbow"}, {"env_step": 5300000, "rew": 921.85, "rew_std": 225.55631780112037, "Agent": "rainbow"}, {"env_step": 5400000, "rew": 906.55, "rew_std": 165.13622406970558, "Agent": "rainbow"}, {"env_step": 5500000, "rew": 830.05, "rew_std": 163.59407843806574, "Agent": "rainbow"}, {"env_step": 5600000, "rew": 936.0, "rew_std": 105.19054139988063, "Agent": "rainbow"}, {"env_step": 5700000, "rew": 953.05, "rew_std": 209.43226231887004, "Agent": "rainbow"}, {"env_step": 5800000, "rew": 1002.05, "rew_std": 93.83001918362801, "Agent": "rainbow"}, {"env_step": 5900000, "rew": 925.0, "rew_std": 173.92857729539446, "Agent": "rainbow"}, {"env_step": 6000000, "rew": 959.65, "rew_std": 155.2884493450817, "Agent": "rainbow"}, {"env_step": 6100000, "rew": 968.05, "rew_std": 146.51168042173293, "Agent": "rainbow"}, {"env_step": 6200000, "rew": 1050.5, "rew_std": 181.52906103431485, "Agent": "rainbow"}, {"env_step": 6300000, "rew": 949.45, "rew_std": 170.18320275514853, "Agent": "rainbow"}, {"env_step": 6400000, "rew": 989.1, "rew_std": 128.6547317435313, "Agent": "rainbow"}, {"env_step": 6500000, "rew": 1003.1, "rew_std": 173.66603006921073, "Agent": "rainbow"}, {"env_step": 6600000, "rew": 1086.4, "rew_std": 179.66382496206631, "Agent": "rainbow"}, {"env_step": 6700000, "rew": 878.0, "rew_std": 83.01204731844649, "Agent": "rainbow"}, {"env_step": 6800000, "rew": 1107.55, "rew_std": 113.496354567008, "Agent": "rainbow"}, {"env_step": 6900000, "rew": 1062.7, "rew_std": 188.38619906988941, "Agent": "rainbow"}, {"env_step": 7000000, "rew": 1025.8, "rew_std": 146.94577231074055, "Agent": "rainbow"}, {"env_step": 7100000, "rew": 969.65, "rew_std": 143.79204602480627, "Agent": "rainbow"}, {"env_step": 7200000, "rew": 1074.2, "rew_std": 236.38752082121428, "Agent": "rainbow"}, {"env_step": 7300000, "rew": 1129.1, "rew_std": 145.2378738483871, "Agent": "rainbow"}, {"env_step": 7400000, "rew": 1020.55, "rew_std": 165.74822020160576, "Agent": "rainbow"}, {"env_step": 7500000, "rew": 1026.55, "rew_std": 126.17734543094492, "Agent": "rainbow"}, {"env_step": 
7600000, "rew": 1062.0, "rew_std": 134.57005610461786, "Agent": "rainbow"}, {"env_step": 7700000, "rew": 1086.0, "rew_std": 98.97221832413376, "Agent": "rainbow"}, {"env_step": 7800000, "rew": 1066.7, "rew_std": 159.72964659073153, "Agent": "rainbow"}, {"env_step": 7900000, "rew": 1040.4, "rew_std": 127.58032763714004, "Agent": "rainbow"}, {"env_step": 8000000, "rew": 1074.7, "rew_std": 214.01894775930472, "Agent": "rainbow"}, {"env_step": 8100000, "rew": 1095.35, "rew_std": 139.572033373452, "Agent": "rainbow"}, {"env_step": 8200000, "rew": 1175.95, "rew_std": 163.02261346205933, "Agent": "rainbow"}, {"env_step": 8300000, "rew": 1147.0, "rew_std": 173.63841740813004, "Agent": "rainbow"}, {"env_step": 8400000, "rew": 1167.4, "rew_std": 160.64289589023224, "Agent": "rainbow"}, {"env_step": 8500000, "rew": 1162.35, "rew_std": 261.15216349860094, "Agent": "rainbow"}, {"env_step": 8600000, "rew": 1090.85, "rew_std": 134.61947295989538, "Agent": "rainbow"}, {"env_step": 8700000, "rew": 1165.2, "rew_std": 295.2810694914254, "Agent": "rainbow"}, {"env_step": 8800000, "rew": 1233.65, "rew_std": 176.00952389004408, "Agent": "rainbow"}, {"env_step": 8900000, "rew": 1189.25, "rew_std": 256.7154309736756, "Agent": "rainbow"}, {"env_step": 9000000, "rew": 1097.2, "rew_std": 220.08159850382765, "Agent": "rainbow"}, {"env_step": 9100000, "rew": 1151.05, "rew_std": 172.71746437462542, "Agent": "rainbow"}, {"env_step": 9200000, "rew": 1204.9, "rew_std": 126.59498410284667, "Agent": "rainbow"}, {"env_step": 9300000, "rew": 1064.65, "rew_std": 216.47644791062143, "Agent": "rainbow"}, {"env_step": 9400000, "rew": 1358.15, "rew_std": 267.57840439766437, "Agent": "rainbow"}, {"env_step": 9500000, "rew": 1092.1, "rew_std": 237.28230022485874, "Agent": "rainbow"}, {"env_step": 9600000, "rew": 1312.5, "rew_std": 333.92401830356556, "Agent": "rainbow"}, {"env_step": 9700000, "rew": 1284.1, "rew_std": 214.49811187980188, "Agent": "rainbow"}, {"env_step": 9800000, "rew": 1226.6, "rew_std": 304.27987117126236, "Agent": "rainbow"}, {"env_step": 9900000, "rew": 1122.35, "rew_std": 192.80275542636832, "Agent": "rainbow"}, {"env_step": 10000000, "rew": 1184.0, "rew_std": 231.1005192551501, "Agent": "rainbow"}, {"env_step": 0, "rew": 171.85, "rew_std": 31.587220517164848, "Agent": "ppo"}, {"env_step": 100000, "rew": 226.2, "rew_std": 53.99453676067608, "Agent": "ppo"}, {"env_step": 200000, "rew": 240.45, "rew_std": 20.0517455599257, "Agent": "ppo"}, {"env_step": 300000, "rew": 282.7, "rew_std": 25.0421644431946, "Agent": "ppo"}, {"env_step": 400000, "rew": 291.8, "rew_std": 47.00276587606308, "Agent": "ppo"}, {"env_step": 500000, "rew": 320.5, "rew_std": 31.345653606201928, "Agent": "ppo"}, {"env_step": 600000, "rew": 314.0, "rew_std": 56.18807702707043, "Agent": "ppo"}, {"env_step": 700000, "rew": 320.9, "rew_std": 56.30133213344068, "Agent": "ppo"}, {"env_step": 800000, "rew": 331.25, "rew_std": 52.53439349607074, "Agent": "ppo"}, {"env_step": 900000, "rew": 385.95, "rew_std": 86.00071220635328, "Agent": "ppo"}, {"env_step": 1000000, "rew": 396.05, "rew_std": 54.37713214210547, "Agent": "ppo"}, {"env_step": 1100000, "rew": 366.65, "rew_std": 46.54301773628349, "Agent": "ppo"}, {"env_step": 1200000, "rew": 377.25, "rew_std": 51.198266572219026, "Agent": "ppo"}, {"env_step": 1300000, "rew": 386.5, "rew_std": 81.74686538332831, "Agent": "ppo"}, {"env_step": 1400000, "rew": 400.45, "rew_std": 102.20431742348265, "Agent": "ppo"}, {"env_step": 1500000, "rew": 417.4, "rew_std": 66.27586287631418, "Agent": "ppo"}, {"env_step": 
1600000, "rew": 428.3, "rew_std": 44.33801980242239, "Agent": "ppo"}, {"env_step": 1700000, "rew": 392.8, "rew_std": 61.047604375601836, "Agent": "ppo"}, {"env_step": 1800000, "rew": 443.6, "rew_std": 57.05953031702943, "Agent": "ppo"}, {"env_step": 1900000, "rew": 424.75, "rew_std": 63.06078416892704, "Agent": "ppo"}, {"env_step": 2000000, "rew": 438.4, "rew_std": 42.03617489734288, "Agent": "ppo"}, {"env_step": 2100000, "rew": 468.85, "rew_std": 68.96160163453283, "Agent": "ppo"}, {"env_step": 2200000, "rew": 474.1, "rew_std": 64.23659393211942, "Agent": "ppo"}, {"env_step": 2300000, "rew": 467.0, "rew_std": 42.47646407129482, "Agent": "ppo"}, {"env_step": 2400000, "rew": 488.1, "rew_std": 49.019791921222996, "Agent": "ppo"}, {"env_step": 2500000, "rew": 528.75, "rew_std": 90.16602741609503, "Agent": "ppo"}, {"env_step": 2600000, "rew": 522.45, "rew_std": 87.87617709026718, "Agent": "ppo"}, {"env_step": 2700000, "rew": 504.35, "rew_std": 51.12047045949401, "Agent": "ppo"}, {"env_step": 2800000, "rew": 528.55, "rew_std": 70.213050781176, "Agent": "ppo"}, {"env_step": 2900000, "rew": 521.1, "rew_std": 44.84852282963175, "Agent": "ppo"}, {"env_step": 3000000, "rew": 565.5, "rew_std": 67.83251432757008, "Agent": "ppo"}, {"env_step": 3100000, "rew": 526.7, "rew_std": 55.79569875895453, "Agent": "ppo"}, {"env_step": 3200000, "rew": 610.9, "rew_std": 55.02990096302191, "Agent": "ppo"}, {"env_step": 3300000, "rew": 552.1, "rew_std": 85.30791288034189, "Agent": "ppo"}, {"env_step": 3400000, "rew": 594.4, "rew_std": 72.59814047205342, "Agent": "ppo"}, {"env_step": 3500000, "rew": 560.7, "rew_std": 89.29171294134747, "Agent": "ppo"}, {"env_step": 3600000, "rew": 580.25, "rew_std": 79.42205298278306, "Agent": "ppo"}, {"env_step": 3700000, "rew": 629.95, "rew_std": 90.99023299233825, "Agent": "ppo"}, {"env_step": 3800000, "rew": 593.6, "rew_std": 96.24598692932604, "Agent": "ppo"}, {"env_step": 3900000, "rew": 633.6, "rew_std": 63.77060451336494, "Agent": "ppo"}, {"env_step": 4000000, "rew": 623.85, "rew_std": 79.62538853908345, "Agent": "ppo"}, {"env_step": 4100000, "rew": 625.55, "rew_std": 71.11520582828963, "Agent": "ppo"}, {"env_step": 4200000, "rew": 631.1, "rew_std": 60.92405108001273, "Agent": "ppo"}, {"env_step": 4300000, "rew": 652.7, "rew_std": 78.92122654900898, "Agent": "ppo"}, {"env_step": 4400000, "rew": 645.25, "rew_std": 61.85194014741979, "Agent": "ppo"}, {"env_step": 4500000, "rew": 684.05, "rew_std": 84.49658277113933, "Agent": "ppo"}, {"env_step": 4600000, "rew": 696.05, "rew_std": 90.35857734603837, "Agent": "ppo"}, {"env_step": 4700000, "rew": 651.7, "rew_std": 98.24744271481065, "Agent": "ppo"}, {"env_step": 4800000, "rew": 710.2, "rew_std": 113.41190413708783, "Agent": "ppo"}, {"env_step": 4900000, "rew": 719.95, "rew_std": 103.00544888499832, "Agent": "ppo"}, {"env_step": 5000000, "rew": 702.85, "rew_std": 71.93714270111094, "Agent": "ppo"}, {"env_step": 5100000, "rew": 657.1, "rew_std": 91.01615241263497, "Agent": "ppo"}, {"env_step": 5200000, "rew": 669.75, "rew_std": 95.95891047734962, "Agent": "ppo"}, {"env_step": 5300000, "rew": 730.45, "rew_std": 102.41861403084891, "Agent": "ppo"}, {"env_step": 5400000, "rew": 707.9, "rew_std": 79.9180204960058, "Agent": "ppo"}, {"env_step": 5500000, "rew": 711.65, "rew_std": 116.25189245771442, "Agent": "ppo"}, {"env_step": 5600000, "rew": 742.6, "rew_std": 103.81541311385318, "Agent": "ppo"}, {"env_step": 5700000, "rew": 752.15, "rew_std": 98.74513912087015, "Agent": "ppo"}, {"env_step": 5800000, "rew": 791.7, "rew_std": 
111.0621897857232, "Agent": "ppo"}, {"env_step": 5900000, "rew": 806.95, "rew_std": 144.94213500566354, "Agent": "ppo"}, {"env_step": 6000000, "rew": 827.45, "rew_std": 113.05871262313224, "Agent": "ppo"}, {"env_step": 6100000, "rew": 779.5, "rew_std": 100.94874937313487, "Agent": "ppo"}, {"env_step": 6200000, "rew": 812.75, "rew_std": 158.00810896912856, "Agent": "ppo"}, {"env_step": 6300000, "rew": 839.65, "rew_std": 123.6092735194249, "Agent": "ppo"}, {"env_step": 6400000, "rew": 852.95, "rew_std": 132.6488691998541, "Agent": "ppo"}, {"env_step": 6500000, "rew": 833.05, "rew_std": 148.62073374869334, "Agent": "ppo"}, {"env_step": 6600000, "rew": 887.55, "rew_std": 111.08947969992478, "Agent": "ppo"}, {"env_step": 6700000, "rew": 793.6, "rew_std": 104.03119724390372, "Agent": "ppo"}, {"env_step": 6800000, "rew": 832.25, "rew_std": 154.1725088983117, "Agent": "ppo"}, {"env_step": 6900000, "rew": 871.05, "rew_std": 154.96926308142525, "Agent": "ppo"}, {"env_step": 7000000, "rew": 833.1, "rew_std": 101.21086898154763, "Agent": "ppo"}, {"env_step": 7100000, "rew": 885.15, "rew_std": 144.50104670901175, "Agent": "ppo"}, {"env_step": 7200000, "rew": 850.1, "rew_std": 142.2687246024227, "Agent": "ppo"}, {"env_step": 7300000, "rew": 861.5, "rew_std": 87.94373201087159, "Agent": "ppo"}, {"env_step": 7400000, "rew": 834.6, "rew_std": 195.45917732355264, "Agent": "ppo"}, {"env_step": 7500000, "rew": 880.95, "rew_std": 143.8566039499056, "Agent": "ppo"}, {"env_step": 7600000, "rew": 921.95, "rew_std": 171.26462711254766, "Agent": "ppo"}, {"env_step": 7700000, "rew": 906.05, "rew_std": 214.73034834415, "Agent": "ppo"}, {"env_step": 7800000, "rew": 934.75, "rew_std": 217.31075099957664, "Agent": "ppo"}, {"env_step": 7900000, "rew": 927.8, "rew_std": 146.93998775010158, "Agent": "ppo"}, {"env_step": 8000000, "rew": 904.5, "rew_std": 154.3149377085705, "Agent": "ppo"}, {"env_step": 8100000, "rew": 902.9, "rew_std": 179.20083705161647, "Agent": "ppo"}, {"env_step": 8200000, "rew": 941.1, "rew_std": 163.1423917931817, "Agent": "ppo"}, {"env_step": 8300000, "rew": 956.8, "rew_std": 210.935440360315, "Agent": "ppo"}, {"env_step": 8400000, "rew": 913.4, "rew_std": 155.79261856711955, "Agent": "ppo"}, {"env_step": 8500000, "rew": 907.55, "rew_std": 156.9779363477556, "Agent": "ppo"}, {"env_step": 8600000, "rew": 883.95, "rew_std": 164.77324570451358, "Agent": "ppo"}, {"env_step": 8700000, "rew": 963.85, "rew_std": 182.24695470706774, "Agent": "ppo"}, {"env_step": 8800000, "rew": 993.0, "rew_std": 205.16420253055844, "Agent": "ppo"}, {"env_step": 8900000, "rew": 961.75, "rew_std": 131.86114856165935, "Agent": "ppo"}, {"env_step": 9000000, "rew": 969.8, "rew_std": 228.0311601514144, "Agent": "ppo"}, {"env_step": 9100000, "rew": 1003.2, "rew_std": 189.2723170461016, "Agent": "ppo"}, {"env_step": 9200000, "rew": 953.9, "rew_std": 193.50074935255418, "Agent": "ppo"}, {"env_step": 9300000, "rew": 955.5, "rew_std": 164.37198666439485, "Agent": "ppo"}, {"env_step": 9400000, "rew": 989.6, "rew_std": 161.20899478627115, "Agent": "ppo"}, {"env_step": 9500000, "rew": 1055.5, "rew_std": 215.21524109597814, "Agent": "ppo"}, {"env_step": 9600000, "rew": 1071.5, "rew_std": 225.1992451141877, "Agent": "ppo"}, {"env_step": 9700000, "rew": 954.35, "rew_std": 160.38080464943425, "Agent": "ppo"}, {"env_step": 9800000, "rew": 965.05, "rew_std": 193.5224082632293, "Agent": "ppo"}, {"env_step": 9900000, "rew": 1018.95, "rew_std": 173.74959712183508, "Agent": "ppo"}, {"env_step": 10000000, "rew": 1129.5, "rew_std": 
145.34132241038677, "Agent": "ppo"}] \ No newline at end of file diff --git a/examples/atari/results/c51/Breakout_rew.png b/examples/atari/results/c51/Breakout_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..6b510a04f5df1856552d4a396fdb794ecdf488cb Binary files /dev/null and b/examples/atari/results/c51/Breakout_rew.png differ diff --git a/examples/atari/results/c51/Enduro_rew.png b/examples/atari/results/c51/Enduro_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..0e08380a135bae2b10690c5d781977a32fe1d3d5 Binary files /dev/null and b/examples/atari/results/c51/Enduro_rew.png differ diff --git a/examples/atari/results/c51/MsPacman_rew.png b/examples/atari/results/c51/MsPacman_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..001af14eef37bf483c3fa5f334cffbbef312716b Binary files /dev/null and b/examples/atari/results/c51/MsPacman_rew.png differ diff --git a/examples/atari/results/c51/Pong_rew.png b/examples/atari/results/c51/Pong_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..c835399d349d7f860dd6fd255426fc7be9158f50 Binary files /dev/null and b/examples/atari/results/c51/Pong_rew.png differ diff --git a/examples/atari/results/c51/Qbert_rew.png b/examples/atari/results/c51/Qbert_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..47ee25ed9b24d61e2bae3e47c14a2ce243c9346c Binary files /dev/null and b/examples/atari/results/c51/Qbert_rew.png differ diff --git a/examples/atari/results/c51/Seaquest_rew.png b/examples/atari/results/c51/Seaquest_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..6cc069fd85156bee35350af5c709124d465d0761 Binary files /dev/null and b/examples/atari/results/c51/Seaquest_rew.png differ diff --git a/examples/atari/results/c51/SpaceInvader_rew.png b/examples/atari/results/c51/SpaceInvader_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..108cf04be2ee16ca5044ce1b5d51d2ae6ba99a11 Binary files /dev/null and b/examples/atari/results/c51/SpaceInvader_rew.png differ diff --git a/examples/atari/results/discrete_sac/Breakout_rew.png b/examples/atari/results/discrete_sac/Breakout_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..23f0d8fd2b2cbae82e7dfe3ca8368ea9bb1a7a17 Binary files /dev/null and b/examples/atari/results/discrete_sac/Breakout_rew.png differ diff --git a/examples/atari/results/discrete_sac/Enduro_rew.png b/examples/atari/results/discrete_sac/Enduro_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..05466a69e67e9d064917c6fcd6814e9597aee0d8 Binary files /dev/null and b/examples/atari/results/discrete_sac/Enduro_rew.png differ diff --git a/examples/atari/results/discrete_sac/MsPacman_rew.png b/examples/atari/results/discrete_sac/MsPacman_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..0f8d8bc3ef93c5f14351b4c06bd65cada0d81d84 Binary files /dev/null and b/examples/atari/results/discrete_sac/MsPacman_rew.png differ diff --git a/examples/atari/results/discrete_sac/Pong_rew.png b/examples/atari/results/discrete_sac/Pong_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..3fcdad0b0aaf94dea304ba6c48f4d9418a6ea199 Binary files /dev/null and b/examples/atari/results/discrete_sac/Pong_rew.png differ diff --git a/examples/atari/results/discrete_sac/Qbert_rew.png b/examples/atari/results/discrete_sac/Qbert_rew.png new file mode 100644 index 
0000000000000000000000000000000000000000..5ac7efa91a3da77448fd1ebacbd43c54c03a2315 Binary files /dev/null and b/examples/atari/results/discrete_sac/Qbert_rew.png differ diff --git a/examples/atari/results/discrete_sac/Seaquest_rew.png b/examples/atari/results/discrete_sac/Seaquest_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..1d562744a117466665f91830840950709aa84115 Binary files /dev/null and b/examples/atari/results/discrete_sac/Seaquest_rew.png differ diff --git a/examples/atari/results/discrete_sac/SpaceInvaders_rew.png b/examples/atari/results/discrete_sac/SpaceInvaders_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..e1ee290fd95f065f556304a383fceef882aa20ba Binary files /dev/null and b/examples/atari/results/discrete_sac/SpaceInvaders_rew.png differ diff --git a/examples/atari/results/dqn/Breakout_rew.png b/examples/atari/results/dqn/Breakout_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..2deed236a27cf579206489774d887815d69cde26 Binary files /dev/null and b/examples/atari/results/dqn/Breakout_rew.png differ diff --git a/examples/atari/results/dqn/Enduro_rew.png b/examples/atari/results/dqn/Enduro_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..27e0c6f39b9e94e4605a9d630599d8dd99d37b5f Binary files /dev/null and b/examples/atari/results/dqn/Enduro_rew.png differ diff --git a/examples/atari/results/dqn/MsPacman_rew.png b/examples/atari/results/dqn/MsPacman_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..6a3a88ab41794a89468504e237efdd1b592d8b84 Binary files /dev/null and b/examples/atari/results/dqn/MsPacman_rew.png differ diff --git a/examples/atari/results/dqn/Pong_rew.png b/examples/atari/results/dqn/Pong_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..75b289873ef12c61dbbbc2ca0e1ccc8643569f7c Binary files /dev/null and b/examples/atari/results/dqn/Pong_rew.png differ diff --git a/examples/atari/results/dqn/Qbert_rew.png b/examples/atari/results/dqn/Qbert_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..8f674881a16c869798e0a2c133cdf8db51748926 Binary files /dev/null and b/examples/atari/results/dqn/Qbert_rew.png differ diff --git a/examples/atari/results/dqn/Seaquest_rew.png b/examples/atari/results/dqn/Seaquest_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..5ed0821034e657d422b3e4443dbfdc015f9b50eb Binary files /dev/null and b/examples/atari/results/dqn/Seaquest_rew.png differ diff --git a/examples/atari/results/dqn/SpaceInvader_rew.png b/examples/atari/results/dqn/SpaceInvader_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..57c51fe9a2a59e7db17fba637afee878bfbcb481 Binary files /dev/null and b/examples/atari/results/dqn/SpaceInvader_rew.png differ diff --git a/examples/atari/results/fqf/Breakout_rew.png b/examples/atari/results/fqf/Breakout_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..409a2adf522b3f653f152e5b28a83b4d113e6d43 Binary files /dev/null and b/examples/atari/results/fqf/Breakout_rew.png differ diff --git a/examples/atari/results/fqf/Enduro_rew.png b/examples/atari/results/fqf/Enduro_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..3d51259835917e9615fdd43c28a597f5e0ba45ea Binary files /dev/null and b/examples/atari/results/fqf/Enduro_rew.png differ diff --git a/examples/atari/results/fqf/MsPacman_rew.png b/examples/atari/results/fqf/MsPacman_rew.png new file 
mode 100644 index 0000000000000000000000000000000000000000..6832017fb66afb26627984b779697a5e2b123072 Binary files /dev/null and b/examples/atari/results/fqf/MsPacman_rew.png differ diff --git a/examples/atari/results/fqf/Pong_rew.png b/examples/atari/results/fqf/Pong_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..e3a4e44fdfcb6d343e43c277b063a0b669055c17 Binary files /dev/null and b/examples/atari/results/fqf/Pong_rew.png differ diff --git a/examples/atari/results/fqf/Qbert_rew.png b/examples/atari/results/fqf/Qbert_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..c0a2fec58023e99dec31ca16709c8369c95ae383 Binary files /dev/null and b/examples/atari/results/fqf/Qbert_rew.png differ diff --git a/examples/atari/results/fqf/Seaquest_rew.png b/examples/atari/results/fqf/Seaquest_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..2accab38eaedf210005115aa5abe9a676652131b Binary files /dev/null and b/examples/atari/results/fqf/Seaquest_rew.png differ diff --git a/examples/atari/results/fqf/SpaceInvaders_rew.png b/examples/atari/results/fqf/SpaceInvaders_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..fe271ecc1658199cb1b6813ba639bc87bd1a17d0 Binary files /dev/null and b/examples/atari/results/fqf/SpaceInvaders_rew.png differ diff --git a/examples/atari/results/iqn/Breakout_rew.png b/examples/atari/results/iqn/Breakout_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..ab9b9485851c6edd3938d7f43db35232f86320b2 Binary files /dev/null and b/examples/atari/results/iqn/Breakout_rew.png differ diff --git a/examples/atari/results/iqn/Enduro_rew.png b/examples/atari/results/iqn/Enduro_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..2b9129943d5bc6f2334537754fe46b0a4175c50a Binary files /dev/null and b/examples/atari/results/iqn/Enduro_rew.png differ diff --git a/examples/atari/results/iqn/MsPacman_rew.png b/examples/atari/results/iqn/MsPacman_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..af6f3aebb79b742dc3875da34cc1df414f142f02 Binary files /dev/null and b/examples/atari/results/iqn/MsPacman_rew.png differ diff --git a/examples/atari/results/iqn/Pong_rew.png b/examples/atari/results/iqn/Pong_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..bb8d31f9a02d7ea1b741be4d8387195772084829 Binary files /dev/null and b/examples/atari/results/iqn/Pong_rew.png differ diff --git a/examples/atari/results/iqn/Qbert_rew.png b/examples/atari/results/iqn/Qbert_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..085a64d9946c6d3dac34c70a1d2a41d470fe7deb Binary files /dev/null and b/examples/atari/results/iqn/Qbert_rew.png differ diff --git a/examples/atari/results/iqn/Seaquest_rew.png b/examples/atari/results/iqn/Seaquest_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..8f343cdba4ea7c0fdd29e23cbef642c7967f44e2 Binary files /dev/null and b/examples/atari/results/iqn/Seaquest_rew.png differ diff --git a/examples/atari/results/iqn/SpaceInvaders_rew.png b/examples/atari/results/iqn/SpaceInvaders_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..d8e7b338efa907318c562b88b4bb46e2a43e6d55 Binary files /dev/null and b/examples/atari/results/iqn/SpaceInvaders_rew.png differ diff --git a/examples/atari/results/ppo/Breakout_rew.png b/examples/atari/results/ppo/Breakout_rew.png new file mode 100644 index 
0000000000000000000000000000000000000000..79a6cca1b53d3851ea2e1464d185f8419c3aa69d Binary files /dev/null and b/examples/atari/results/ppo/Breakout_rew.png differ diff --git a/examples/atari/results/ppo/Enduro_rew.png b/examples/atari/results/ppo/Enduro_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..621527a719642ccaf0f69e422b4ecaf3ea38e31d Binary files /dev/null and b/examples/atari/results/ppo/Enduro_rew.png differ diff --git a/examples/atari/results/ppo/MsPacman_rew.png b/examples/atari/results/ppo/MsPacman_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..a5cc0519b6eaa70ca4c61802eb2478c9241dc028 Binary files /dev/null and b/examples/atari/results/ppo/MsPacman_rew.png differ diff --git a/examples/atari/results/ppo/Pong_rew.png b/examples/atari/results/ppo/Pong_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..3d6523b081cecb68254c03a24b3f8ec22a307780 Binary files /dev/null and b/examples/atari/results/ppo/Pong_rew.png differ diff --git a/examples/atari/results/ppo/Qbert_rew.png b/examples/atari/results/ppo/Qbert_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..db0bfd22da2f1de444ff50fd9b391d574f201d3d Binary files /dev/null and b/examples/atari/results/ppo/Qbert_rew.png differ diff --git a/examples/atari/results/ppo/Seaquest_rew.png b/examples/atari/results/ppo/Seaquest_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..4896a9c0c584812e49465098ecedb574b8f1d6b7 Binary files /dev/null and b/examples/atari/results/ppo/Seaquest_rew.png differ diff --git a/examples/atari/results/ppo/SpaceInvaders_rew.png b/examples/atari/results/ppo/SpaceInvaders_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..fb6722393905ca7e025ae2c8235cb6a0e194499a Binary files /dev/null and b/examples/atari/results/ppo/SpaceInvaders_rew.png differ diff --git a/examples/atari/results/qrdqn/Breakout_rew.png b/examples/atari/results/qrdqn/Breakout_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..5cc1916e6eb4aa90634d3c31d9f13d023d429114 Binary files /dev/null and b/examples/atari/results/qrdqn/Breakout_rew.png differ diff --git a/examples/atari/results/qrdqn/Enduro_rew.png b/examples/atari/results/qrdqn/Enduro_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..640e60d2d7cf0f0fe39b9329174e3aeacea9e48d Binary files /dev/null and b/examples/atari/results/qrdqn/Enduro_rew.png differ diff --git a/examples/atari/results/qrdqn/MsPacman_rew.png b/examples/atari/results/qrdqn/MsPacman_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..0afd25787ca5feee219d86a4be0caf068bf70167 Binary files /dev/null and b/examples/atari/results/qrdqn/MsPacman_rew.png differ diff --git a/examples/atari/results/qrdqn/Pong_rew.png b/examples/atari/results/qrdqn/Pong_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..30a02375eaab9c4041349ac98b28a09851dba39b Binary files /dev/null and b/examples/atari/results/qrdqn/Pong_rew.png differ diff --git a/examples/atari/results/qrdqn/Qbert_rew.png b/examples/atari/results/qrdqn/Qbert_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..fbd25c7320c2a615913e24c78e6e0b89a281457c Binary files /dev/null and b/examples/atari/results/qrdqn/Qbert_rew.png differ diff --git a/examples/atari/results/qrdqn/Seaquest_rew.png b/examples/atari/results/qrdqn/Seaquest_rew.png new file mode 100644 index 
0000000000000000000000000000000000000000..7e9d47af2e1a8e14cadab38b90dc7259b75f22dd Binary files /dev/null and b/examples/atari/results/qrdqn/Seaquest_rew.png differ diff --git a/examples/atari/results/qrdqn/SpaceInvader_rew.png b/examples/atari/results/qrdqn/SpaceInvader_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..4751768fba2a87997c0acf6953432d6ea8593e1e Binary files /dev/null and b/examples/atari/results/qrdqn/SpaceInvader_rew.png differ diff --git a/examples/atari/results/rainbow/Breakout_rew.png b/examples/atari/results/rainbow/Breakout_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..b2071ccd84851c50eb82e39495fc8e517d809e2c Binary files /dev/null and b/examples/atari/results/rainbow/Breakout_rew.png differ diff --git a/examples/atari/results/rainbow/Enduro_rew.png b/examples/atari/results/rainbow/Enduro_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..f6b913fc4d8dc07fdf969118365e7bf3692335e1 Binary files /dev/null and b/examples/atari/results/rainbow/Enduro_rew.png differ diff --git a/examples/atari/results/rainbow/MsPacman_rew.png b/examples/atari/results/rainbow/MsPacman_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..2b51f2d0e2ab2d39703bb47bd32c4fbbfcccbfa8 Binary files /dev/null and b/examples/atari/results/rainbow/MsPacman_rew.png differ diff --git a/examples/atari/results/rainbow/Pong_rew.png b/examples/atari/results/rainbow/Pong_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..3566cfb3b76677abb28a05aea9147acdeb84b54d Binary files /dev/null and b/examples/atari/results/rainbow/Pong_rew.png differ diff --git a/examples/atari/results/rainbow/Qbert_rew.png b/examples/atari/results/rainbow/Qbert_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..1644ab8730267b934f896183e9ba69579e4d582f Binary files /dev/null and b/examples/atari/results/rainbow/Qbert_rew.png differ diff --git a/examples/atari/results/rainbow/Seaquest_rew.png b/examples/atari/results/rainbow/Seaquest_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..9c5898a1afa8c908e0c55f85ca2d46bff9555219 Binary files /dev/null and b/examples/atari/results/rainbow/Seaquest_rew.png differ diff --git a/examples/atari/results/rainbow/SpaceInvaders_rew.png b/examples/atari/results/rainbow/SpaceInvaders_rew.png new file mode 100644 index 0000000000000000000000000000000000000000..2182ee80e82fa8469436fe0ef1049321436a1492 Binary files /dev/null and b/examples/atari/results/rainbow/SpaceInvaders_rew.png differ diff --git a/examples/atari/tianshou/__init__.py b/examples/atari/tianshou/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca302f23fb5a0fbefdd08f43a4aae0ba053fcaea --- /dev/null +++ b/examples/atari/tianshou/__init__.py @@ -0,0 +1,12 @@ +from tianshou import data, env, exploration, policy, trainer, utils + +__version__ = "1.0.0" + +__all__ = [ + "env", + "data", + "utils", + "policy", + "trainer", + "exploration", +] diff --git a/examples/atari/tianshou/data/__init__.py b/examples/atari/tianshou/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c84c2ec7d753b12a34a903882861880d1fea5e9d --- /dev/null +++ b/examples/atari/tianshou/data/__init__.py @@ -0,0 +1,60 @@ +"""Data package.""" +# isort:skip_file + +from tianshou.data.batch import Batch +from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as +from tianshou.data.utils.segtree import SegmentTree +from 
tianshou.data.buffer.base import ReplayBuffer +from tianshou.data.buffer.prio import PrioritizedReplayBuffer +from tianshou.data.buffer.her import HERReplayBuffer +from tianshou.data.buffer.manager import ( + ReplayBufferManager, + PrioritizedReplayBufferManager, + HERReplayBufferManager, +) +from tianshou.data.buffer.vecbuf import ( + HERVectorReplayBuffer, + PrioritizedVectorReplayBuffer, + VectorReplayBuffer, +) +from tianshou.data.buffer.cached import CachedReplayBuffer +from tianshou.data.stats import ( + EpochStats, + InfoStats, + SequenceSummaryStats, + TimingStats, +) +from tianshou.data.collector import ( + Collector, + AsyncCollector, + CollectStats, + CollectStatsBase, + BaseCollector, +) + +__all__ = [ + "Batch", + "to_numpy", + "to_torch", + "to_torch_as", + "SegmentTree", + "ReplayBuffer", + "PrioritizedReplayBuffer", + "HERReplayBuffer", + "ReplayBufferManager", + "PrioritizedReplayBufferManager", + "HERReplayBufferManager", + "VectorReplayBuffer", + "PrioritizedVectorReplayBuffer", + "HERVectorReplayBuffer", + "CachedReplayBuffer", + "Collector", + "CollectStats", + "CollectStatsBase", + "AsyncCollector", + "EpochStats", + "InfoStats", + "SequenceSummaryStats", + "TimingStats", + "BaseCollector", +] diff --git a/examples/atari/tianshou/data/batch.py b/examples/atari/tianshou/data/batch.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ea4a8b1dab2c4e0aa122ff6de09fa0e2923c47 --- /dev/null +++ b/examples/atari/tianshou/data/batch.py @@ -0,0 +1,1009 @@ +import pprint +import warnings +from collections.abc import Collection, Iterable, Iterator, KeysView, Sequence +from copy import deepcopy +from numbers import Number +from types import EllipsisType +from typing import ( + Any, + Protocol, + Self, + TypeVar, + Union, + cast, + overload, + runtime_checkable, +) + +import numpy as np +import torch +from deepdiff import DeepDiff + +_SingleIndexType = slice | int | EllipsisType +IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...] +TBatch = TypeVar("TBatch", bound="BatchProtocol") +arr_type = torch.Tensor | np.ndarray + + +def _is_batch_set(obj: Any) -> bool: + # Batch set is a list/tuple of dict/Batch objects, + # or 1-D np.ndarray with object type, + # where each element is a dict/Batch object + if isinstance(obj, np.ndarray): # most often case + # "for element in obj" will just unpack the first dimension, + # but obj.tolist() will flatten ndarray of objects + # so do not use obj.tolist() + if obj.shape == (): + return False + return obj.dtype == object and all(isinstance(element, dict | Batch) for element in obj) + if ( + isinstance(obj, list | tuple) + and len(obj) > 0 + and all(isinstance(element, dict | Batch) for element in obj) + ): + return True + return False + + +def _is_scalar(value: Any) -> bool: + # check if the value is a scalar + # 1. python bool object, number object: isinstance(value, Number) + # 2. numpy scalar: isinstance(value, np.generic) + # 3. python object rather than dict / Batch / tensor + # the check of dict / Batch is omitted because this only checks a value. + # a dict / Batch will eventually check their values + if isinstance(value, torch.Tensor): + return value.numel() == 1 and not value.shape + # np.asanyarray will cause dead loop in some cases + return np.isscalar(value) + + +def _is_number(value: Any) -> bool: + # isinstance(value, Number) checks 1, 1.0, np.int(1), np.float(1.0), etc. + # isinstance(value, np.nummber) checks np.int32(1), np.float64(1.0), etc. 
+ # isinstance(value, np.bool_) checks np.bool_(True), etc. + # similar to np.isscalar but np.isscalar('st') returns True + return isinstance(value, Number | np.number | np.bool_) + + +def _to_array_with_correct_type(obj: Any) -> np.ndarray: + if isinstance(obj, np.ndarray) and issubclass(obj.dtype.type, np.bool_ | np.number): + return obj # most often case + # convert the value to np.ndarray + # convert to object obj type if neither bool nor number + # raises an exception if array's elements are tensors themselves + try: + obj_array = np.asanyarray(obj) + except ValueError: + obj_array = np.asanyarray(obj, dtype=object) + if not issubclass(obj_array.dtype.type, np.bool_ | np.number): + obj_array = obj_array.astype(object) + if obj_array.dtype == object: + # scalar ndarray with object obj type is very annoying + # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)]) + # a is not array([{}, {}], dtype=object), and a[0]={} results in + # something very strange: + # array([{}, array({}, dtype=object)], dtype=object) + if not obj_array.shape: + obj_array = obj_array.item(0) + elif all(isinstance(arr, np.ndarray) for arr in obj_array.reshape(-1)): + return obj_array # various length, np.array([[1], [2, 3], [4, 5, 6]]) + elif any(isinstance(arr, torch.Tensor) for arr in obj_array.reshape(-1)): + raise ValueError("Numpy arrays of tensors are not supported yet.") + return obj_array + + +def create_value( + inst: Any, + size: int, + stack: bool = True, +) -> Union["Batch", np.ndarray, torch.Tensor]: + """Create empty place-holders according to inst's shape. + + :param stack: whether to stack or to concatenate. E.g. if inst has shape of + (3, 5), size = 10, stack=True returns an np.array with shape of (10, 3, 5), + otherwise (10, 5) + """ + has_shape = isinstance(inst, np.ndarray | torch.Tensor) + is_scalar = _is_scalar(inst) + if not stack and is_scalar: + # should never hit since it has already checked in Batch.cat_ , here we do not + # consider scalar types, following the behavior of numpy which does not support + # concatenation of zero-dimensional arrays (scalars) + raise TypeError(f"cannot concatenate with {inst} which is scalar") + if has_shape: + shape = (size, *inst.shape) if stack else (size, *inst.shape[1:]) + if isinstance(inst, np.ndarray): + target_type = ( + inst.dtype.type if issubclass(inst.dtype.type, np.bool_ | np.number) else object + ) + return np.full(shape, fill_value=None if target_type == object else 0, dtype=target_type) + if isinstance(inst, torch.Tensor): + return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype) + if isinstance(inst, dict | Batch): + zero_batch = Batch() + for key, val in inst.items(): + zero_batch.__dict__[key] = create_value(val, size, stack=stack) + return zero_batch + if is_scalar: + return create_value(np.asarray(inst), size, stack=stack) + # fall back to object + return np.array([None for _ in range(size)], object) + + +def _assert_type_keys(keys: Iterable[str]) -> None: + assert all(isinstance(key, str) for key in keys), f"keys should all be string, but got {keys}" + + +def _parse_value(obj: Any) -> Union["Batch", np.ndarray, torch.Tensor] | None: + if isinstance(obj, Batch): # most often case + return obj + if ( + (isinstance(obj, np.ndarray) and issubclass(obj.dtype.type, np.bool_ | np.number)) + or isinstance(obj, torch.Tensor) + or obj is None + ): # third often case + return obj + if _is_number(obj): # second often case, but it is more time-consuming + return np.asanyarray(obj) + if isinstance(obj, dict): + 
return Batch(obj) + if ( + not isinstance(obj, np.ndarray) + and isinstance(obj, Collection) + and len(obj) > 0 + and all(isinstance(element, torch.Tensor) for element in obj) + ): + try: + obj = cast(list[torch.Tensor], obj) + return torch.stack(obj) + except RuntimeError as exception: + raise TypeError( + "Batch does not support non-stackable iterable" + " of torch.Tensor as unique value yet.", + ) from exception + if _is_batch_set(obj): + obj = Batch(obj) # list of dict / Batch + else: + # None, scalar, normal obj list (main case) + # or an actual list of objects + try: + obj = _to_array_with_correct_type(obj) + except ValueError as exception: + raise TypeError( + "Batch does not support heterogeneous list/tuple of tensors as unique value yet.", + ) from exception + return obj + + +def alloc_by_keys_diff( + meta: "BatchProtocol", + batch: "BatchProtocol", + size: int, + stack: bool = True, +) -> None: + """Creates place-holders inside meta for keys that are in batch but not in meta. + + This mainly is an internal method, use it only if you know what you are doing. + """ + for key in batch.get_keys(): + if key in meta.get_keys(): + if isinstance(meta[key], Batch) and isinstance(batch[key], Batch): + alloc_by_keys_diff(meta[key], batch[key], size, stack) + elif isinstance(meta[key], Batch) and meta[key].is_empty(): + meta[key] = create_value(batch[key], size, stack) + else: + meta[key] = create_value(batch[key], size, stack) + + +# Note: This is implemented as a protocol because the interface +# of Batch is always extended by adding new fields. Having a hierarchy of +# protocols building off this one allows for type safety and IDE support despite +# the dynamic nature of Batch +@runtime_checkable +class BatchProtocol(Protocol): + """The internal data structure in Tianshou. + + Batch is a kind of supercharged array (of temporal data) stored individually in a + (recursive) dictionary of objects that can be either numpy arrays, torch tensors, or + batches themselves. It is designed to make it extremely easily to access, manipulate + and set partial view of the heterogeneous data conveniently. + + For a detailed description, please refer to :ref:`batch_concept`. + """ + + @property + def shape(self) -> list[int]: + ... + + def __setattr__(self, key: str, value: Any) -> None: + ... + + def __getattr__(self, key: str) -> Any: + ... + + def __contains__(self, key: str) -> bool: + ... + + def __getstate__(self) -> dict: + ... + + def __setstate__(self, state: dict) -> None: + ... + + @overload + def __getitem__(self, index: str) -> Any: + ... + + @overload + def __getitem__(self, index: IndexType) -> Self: + ... + + def __getitem__(self, index: str | IndexType) -> Any: + ... + + def __setitem__(self, index: str | IndexType, value: Any) -> None: + ... + + def __iadd__(self, other: Self | Number | np.number) -> Self: + ... + + def __add__(self, other: Self | Number | np.number) -> Self: + ... + + def __imul__(self, value: Number | np.number) -> Self: + ... + + def __mul__(self, value: Number | np.number) -> Self: + ... + + def __itruediv__(self, value: Number | np.number) -> Self: + ... + + def __truediv__(self, value: Number | np.number) -> Self: + ... + + def __repr__(self) -> str: + ... + + def __iter__(self) -> Iterator[Self]: + ... + + def __eq__(self, other: Any) -> bool: + ... + + @staticmethod + def to_numpy(batch: TBatch) -> TBatch: + """Change all torch.Tensor to numpy.ndarray and return a new Batch.""" + ... 
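Since BatchProtocol above consists mostly of method stubs, a concrete illustration may help. The following is a minimal sketch with arbitrary values (not taken from this diff) showing nested construction, indexing, and the to_numpy/to_torch conversions described in the docstrings:

import numpy as np
import torch

from tianshou.data import Batch

# A Batch is a (possibly nested) dictionary of numpy arrays, torch tensors
# and sub-batches, with attribute-style access.
data = Batch(obs=np.zeros((4, 3)), info=Batch(step=torch.arange(4)))
print(data.obs.shape)          # (4, 3)
print(data[0])                 # indexing slices every leaf along the first axis

# A list of dicts is treated as a "batch set" and stacked field-wise:
stacked = Batch([{"a": 1}, {"a": 2}])   # stacked.a == array([1, 2])

# Out-of-place conversions, as declared by to_numpy/to_torch above:
as_numpy = Batch.to_numpy(data)                       # info.step becomes an ndarray
as_torch = Batch.to_torch(data, dtype=torch.float32)  # all leaves become tensors

# In-place variants mutate the batch instead of returning a copy:
data.to_torch_(device="cpu")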
+ + def to_numpy_(self) -> None: + """Change all torch.Tensor to numpy.ndarray in-place.""" + ... + + @staticmethod + def to_torch( + batch: TBatch, + dtype: torch.dtype | None = None, + device: str | int | torch.device = "cpu", + ) -> TBatch: + """Change all numpy.ndarray to torch.Tensor and return a new Batch.""" + ... + + def to_torch_( + self, + dtype: torch.dtype | None = None, + device: str | int | torch.device = "cpu", + ) -> None: + """Change all numpy.ndarray to torch.Tensor in-place.""" + ... + + def cat_(self, batches: Self | Sequence[dict | Self]) -> None: + """Concatenate a list of (or one) Batch objects into current batch.""" + ... + + @staticmethod + def cat(batches: Sequence[dict | TBatch]) -> TBatch: + """Concatenate a list of Batch object into a single new batch. + + For keys that are not shared across all batches, batches that do not + have these keys will be padded by zeros with appropriate shapes. E.g. + :: + + >>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5]))) + >>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5]))) + >>> c = Batch.cat([a, b]) + >>> c.a.shape + (7, 4) + >>> c.b.shape + (7, 3) + >>> c.common.c.shape + (7, 5) + """ + ... + + def stack_(self, batches: Sequence[dict | Self], axis: int = 0) -> None: + """Stack a list of Batch object into current batch.""" + ... + + @staticmethod + def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch: + """Stack a list of Batch object into a single new batch. + + For keys that are not shared across all batches, batches that do not + have these keys will be padded by zeros. E.g. + :: + + >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5]))) + >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5]))) + >>> c = Batch.stack([a, b]) + >>> c.a.shape + (2, 4, 4) + >>> c.b.shape + (2, 4, 6) + >>> c.common.c.shape + (2, 4, 5) + + .. note:: + + If there are keys that are not shared across all batches, ``stack`` + with ``axis != 0`` is undefined, and will cause an exception. + """ + ... + + def empty_(self, index: slice | IndexType | None = None) -> Self: + """Return an empty Batch object with 0 or None filled. + + If "index" is specified, it will only reset the specific indexed-data. + :: + + >>> data.empty_() + >>> print(data) + Batch( + a: array([[0., 0.], + [0., 0.]]), + b: array([None, None], dtype=object), + ) + >>> b={'c': [2., 'st'], 'd': [1., 0.]} + >>> data = Batch(a=[False, True], b=b) + >>> data[0] = Batch.empty(data[1]) + >>> data + Batch( + a: array([False, True]), + b: Batch( + c: array([None, 'st']), + d: array([0., 0.]), + ), + ) + """ + ... + + @staticmethod + def empty(batch: TBatch, index: IndexType | None = None) -> TBatch: + """Return an empty Batch object with 0 or None filled. + + The shape is the same as the given Batch. + """ + ... + + def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None: + """Update this batch from another dict/Batch.""" + ... + + def __len__(self) -> int: + ... + + def is_empty(self, recurse: bool = False) -> bool: + ... + + def split( + self, + size: int, + shuffle: bool = True, + merge_last: bool = False, + ) -> Iterator[Self]: + """Split whole data into multiple small batches. + + :param size: divide the data batch with the given size, but one + batch if the length of the batch is smaller than "size". Size of -1 means + the whole batch. + :param shuffle: randomly shuffle the entire data batch if it is + True, otherwise remain in the same. Default to True. 
+ :param merge_last: merge the last batch into the previous one. + Default to False. + """ + ... + + def to_dict(self, recurse: bool = True) -> dict[str, Any]: + ... + + def to_list_of_dicts(self) -> list[dict[str, Any]]: + ... + + +class Batch(BatchProtocol): + """See :class:`~tianshou.data.batch.BatchProtocol`.""" + + __doc__ = BatchProtocol.__doc__ + + def __init__( + self, + batch_dict: dict + | BatchProtocol + | Sequence[dict | BatchProtocol] + | np.ndarray + | None = None, + copy: bool = False, + **kwargs: Any, + ) -> None: + if copy: + batch_dict = deepcopy(batch_dict) + if batch_dict is not None: + if isinstance(batch_dict, dict | BatchProtocol): + _assert_type_keys(batch_dict.keys()) + for batch_key, obj in batch_dict.items(): + self.__dict__[batch_key] = _parse_value(obj) + elif _is_batch_set(batch_dict): + batch_dict = cast(Sequence[dict | BatchProtocol], batch_dict) + self.stack_(batch_dict) + if len(kwargs) > 0: + # TODO: that's a rather weird pattern, is it really needed? + # Feels like kwargs could be just merged into batch_dict in the beginning + self.__init__(kwargs, copy=copy) # type: ignore + + def to_dict(self, recursive: bool = True) -> dict[str, Any]: + result = {} + for k, v in self.__dict__.items(): + if recursive and isinstance(v, Batch): + v = v.to_dict(recursive=recursive) + result[k] = v + return result + + def get_keys(self) -> KeysView: + return self.__dict__.keys() + + def to_list_of_dicts(self) -> list[dict[str, Any]]: + return [entry.to_dict() for entry in self] + + def __setattr__(self, key: str, value: Any) -> None: + """Set self.key = value.""" + self.__dict__[key] = _parse_value(value) + + def __getattr__(self, key: str) -> Any: + """Return self.key. The "Any" return type is needed for mypy.""" + return getattr(self.__dict__, key) + + def __contains__(self, key: str) -> bool: + """Return key in self.""" + return key in self.__dict__ + + def __getstate__(self) -> dict[str, Any]: + """Pickling interface. + + Only the actual data are serialized for both efficiency and simplicity. + """ + state = {} + for batch_key, obj in self.items(): + if isinstance(obj, Batch): + state[batch_key] = obj.__getstate__() + else: + state[batch_key] = obj + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + """Unpickling interface. + + At this point, self is an empty Batch instance that has not been + initialized, so it can safely be initialized by the pickle state. + """ + self.__init__(**state) # type: ignore + + @overload + def __getitem__(self, index: str) -> Any: + ... + + @overload + def __getitem__(self, index: IndexType) -> Self: + ... 
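The split and to_dict helpers declared above are easiest to see on a flat batch; a minimal sketch with arbitrary values:

import numpy as np

from tianshou.data import Batch

batch = Batch(obs=np.arange(10), act=np.arange(10))

# split() yields minibatches; merge_last=True folds the remainder into the
# previous chunk instead of yielding a short trailing batch.
for minibatch in batch.split(size=4, shuffle=False, merge_last=True):
    print(len(minibatch))      # 4, then 6

# to_dict()/to_list_of_dicts() convert back to plain Python containers.
print(batch[0].to_dict())      # {'obs': array(0), 'act': array(0)}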
+ + def __getitem__(self, index: str | IndexType) -> Any: + """Return self[index].""" + if isinstance(index, str): + return self.__dict__[index] + batch_items = self.items() + if len(batch_items) > 0: + new_batch = Batch() + for batch_key, obj in batch_items: + if isinstance(obj, Batch) and obj.is_empty(): + new_batch.__dict__[batch_key] = Batch() + else: + new_batch.__dict__[batch_key] = obj[index] + return new_batch + raise IndexError("Cannot access item from empty Batch object.") + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, self.__class__): + return False + + this_batch_no_torch_tensor: Batch = Batch.to_numpy(self) + other_batch_no_torch_tensor: Batch = Batch.to_numpy(other) + this_dict = this_batch_no_torch_tensor.to_dict(recursive=True) + other_dict = other_batch_no_torch_tensor.to_dict(recursive=True) + + return not DeepDiff(this_dict, other_dict) + + def __iter__(self) -> Iterator[Self]: + # TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea + if len(self.__dict__) == 0: + yield from [] + else: + for i in range(len(self)): + yield self[i] + + def __setitem__(self, index: str | IndexType, value: Any) -> None: + """Assign value to self[index].""" + value = _parse_value(value) + if isinstance(index, str): + self.__dict__[index] = value + return + if not isinstance(value, Batch): + raise ValueError( + "Batch does not supported tensor assignment. " + "Use a compatible Batch or dict instead.", + ) + if not set(value.keys()).issubset(self.__dict__.keys()): + raise ValueError("Creating keys is not supported by item assignment.") + for key, val in self.items(): + try: + self.__dict__[key][index] = value[key] + except KeyError: + if isinstance(val, Batch): + self.__dict__[key][index] = Batch() + elif isinstance(val, torch.Tensor) or ( + isinstance(val, np.ndarray) and issubclass(val.dtype.type, np.bool_ | np.number) + ): + self.__dict__[key][index] = 0 + else: + self.__dict__[key][index] = None + + def __iadd__(self, other: Self | Number | np.number) -> Self: + """Algebraic addition with another Batch instance in-place.""" + if isinstance(other, Batch): + for (batch_key, obj), value in zip( + self.__dict__.items(), + other.__dict__.values(), + strict=True, + ): # TODO are keys consistent? + if isinstance(obj, Batch) and obj.is_empty(): + continue + self.__dict__[batch_key] += value + return self + if _is_number(other): + for batch_key, obj in self.items(): + if isinstance(obj, Batch) and obj.is_empty(): + continue + self.__dict__[batch_key] += other + return self + raise TypeError("Only addition of Batch or number is supported.") + + def __add__(self, other: Self | Number | np.number) -> Self: + """Algebraic addition with another Batch instance out-of-place.""" + return deepcopy(self).__iadd__(other) + + def __imul__(self, value: Number | np.number) -> Self: + """Algebraic multiplication with a scalar value in-place.""" + assert _is_number(value), "Only multiplication by a number is supported." + for batch_key, obj in self.__dict__.items(): + if isinstance(obj, Batch) and obj.is_empty(): + continue + self.__dict__[batch_key] *= value + return self + + def __mul__(self, value: Number | np.number) -> Self: + """Algebraic multiplication with a scalar value out-of-place.""" + return deepcopy(self).__imul__(value) + + def __itruediv__(self, value: Number | np.number) -> Self: + """Algebraic division with a scalar value in-place.""" + assert _is_number(value), "Only division by a number is supported." 
+ for batch_key, obj in self.__dict__.items(): + if isinstance(obj, Batch) and obj.is_empty(): + continue + self.__dict__[batch_key] /= value + return self + + def __truediv__(self, value: Number | np.number) -> Self: + """Algebraic division with a scalar value out-of-place.""" + return deepcopy(self).__itruediv__(value) + + def __repr__(self) -> str: + """Return str(self).""" + self_str = self.__class__.__name__ + "(\n" + flag = False + for batch_key, obj in self.__dict__.items(): + rpl = "\n" + " " * (6 + len(batch_key)) + obj_name = pprint.pformat(obj).replace("\n", rpl) + self_str += f" {batch_key}: {obj_name},\n" + flag = True + if flag: + self_str += ")" + else: + self_str = self.__class__.__name__ + "()" + return self_str + + @staticmethod + def to_numpy(batch: TBatch) -> TBatch: + batch_dict = deepcopy(batch) + for batch_key, obj in batch_dict.items(): + if isinstance(obj, torch.Tensor): + batch_dict.__dict__[batch_key] = obj.detach().cpu().numpy() + elif isinstance(obj, Batch): + obj = Batch.to_numpy(obj) + batch_dict.__dict__[batch_key] = obj + + return batch_dict + + def to_numpy_(self) -> None: + for batch_key, obj in self.items(): + if isinstance(obj, torch.Tensor): + self.__dict__[batch_key] = obj.detach().cpu().numpy() + elif isinstance(obj, Batch): + obj.to_numpy_() + + @staticmethod + def to_torch( + batch: TBatch, + dtype: torch.dtype | None = None, + device: str | int | torch.device = "cpu", + ) -> TBatch: + new_batch = Batch(batch, copy=True) + new_batch.to_torch_(dtype=dtype, device=device) + + return new_batch # type: ignore[return-value] + + def to_torch_( + self, + dtype: torch.dtype | None = None, + device: str | int | torch.device = "cpu", + ) -> None: + if not isinstance(device, torch.device): + device = torch.device(device) + + for batch_key, obj in self.items(): + if isinstance(obj, torch.Tensor): + if ( + dtype is not None + and obj.dtype != dtype + or obj.device.type != device.type + or device.index != obj.device.index + ): + if dtype is not None: + self.__dict__[batch_key] = obj.type(dtype).to(device) + else: + self.__dict__[batch_key] = obj.to(device) + elif isinstance(obj, Batch): + obj.to_torch_(dtype, device) + else: + # ndarray or scalar + if not isinstance(obj, np.ndarray): + obj = np.asanyarray(obj) + obj = torch.from_numpy(obj).to(device) + if dtype is not None: + obj = obj.type(dtype) + self.__dict__[batch_key] = obj + + def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: + """Private method for Batch.cat_. + + :: + + >>> a = Batch(a=np.random.randn(3, 4)) + >>> x = Batch(a=a, b=np.random.randn(4, 4)) + >>> y = Batch(a=Batch(a=Batch()), b=np.random.randn(4, 4)) + + If we want to concatenate x and y, we want to pad y.a.a with zeros. + Without ``lens`` as a hint, when we concatenate x.a and y.a, we would + not be able to know how to pad y.a. So ``Batch.cat_`` should compute + the ``lens`` to give ``Batch.__cat`` a hint. 
+ :: + + >>> ans = Batch.cat([x, y]) + >>> # this is equivalent to the following line + >>> ans = Batch(); ans.__cat([x, y], lens=[3, 4]) + >>> # this lens is equal to [len(a), len(b)] + """ + # partial keys will be padded by zeros + # with the shape of [len, rest_shape] + sum_lens = [0] + for len_ in lens: + sum_lens.append(sum_lens[-1] + len_) + # collect non-empty keys + keys_map = [ + { + batch_key + for batch_key, obj in batch.items() + if not (isinstance(obj, Batch) and obj.is_empty()) + } + for batch in batches + ] + keys_shared = set.intersection(*keys_map) + values_shared = [[batch[key] for batch in batches] for key in keys_shared] + for key, shared_value in zip(keys_shared, values_shared, strict=True): + if all(isinstance(element, dict | Batch) for element in shared_value): + batch_holder = Batch() + batch_holder.__cat(shared_value, lens=lens) + self.__dict__[key] = batch_holder + elif all(isinstance(element, torch.Tensor) for element in shared_value): + self.__dict__[key] = torch.cat(shared_value) + else: + # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch())) + # will fail here + self.__dict__[key] = _to_array_with_correct_type(np.concatenate(shared_value)) + keys_total = set.union(*[set(batch.keys()) for batch in batches]) + keys_reserve_or_partial = set.difference(keys_total, keys_shared) + # keys that are reserved in all batches + keys_reserve = set.difference(keys_total, set.union(*keys_map)) + # keys that occur only in some batches, but not all + keys_partial = keys_reserve_or_partial.difference(keys_reserve) + for key in keys_reserve: + # reserved keys + self.__dict__[key] = Batch() + for key in keys_partial: + for i, batch in enumerate(batches): + if key not in batch.__dict__: + continue + value = batch.get(key) + if isinstance(value, Batch) and value.is_empty(): + continue + try: + self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value + except KeyError: + self.__dict__[key] = create_value(value, sum_lens[-1], stack=False) + self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value + + def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: + if isinstance(batches, BatchProtocol | dict): + batches = [batches] + # check input format + batch_list = [] + for batch in batches: + if isinstance(batch, dict): + if len(batch) > 0: + batch_list.append(Batch(batch)) + elif isinstance(batch, Batch): + # x.is_empty() means that x is Batch() and should be ignored + if not batch.is_empty(): + batch_list.append(batch) + else: + raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_") + if len(batch_list) == 0: + return + batches = batch_list + try: + # x.is_empty(recurse=True) here means x is a nested empty batch + # like Batch(a=Batch), and we have to treat it as length zero and + # keep it. + lens = [0 if batch.is_empty(recurse=True) else len(batch) for batch in batches] + except TypeError as exception: + raise ValueError( + "Batch.cat_ meets an exception. 
Maybe because there is any " + f"scalar in {batches} but Batch.cat_ does not support the " + "concatenation of scalar.", + ) from exception + if not self.is_empty(): + batches = [self, *list(batches)] + lens = [0 if self.is_empty(recurse=True) else len(self), *lens] + self.__cat(batches, lens) + + @staticmethod + def cat(batches: Sequence[dict | TBatch]) -> TBatch: + batch = Batch() + batch.cat_(batches) + return batch # type: ignore + + def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None: + # check input format + batch_list = [] + for batch in batches: + if isinstance(batch, dict): + if len(batch) > 0: + batch_list.append(Batch(batch)) + elif isinstance(batch, Batch): + # x.is_empty() means that x is Batch() and should be ignored + if not batch.is_empty(): + batch_list.append(batch) + else: + raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_") + if len(batch_list) == 0: + return + batches = batch_list + if not self.is_empty(): + batches = [self, *batches] + # collect non-empty keys + keys_map = [ + { + batch_key + for batch_key, obj in batch.items() + if not (isinstance(obj, BatchProtocol) and obj.is_empty()) + } + for batch in batches + ] + keys_shared = set.intersection(*keys_map) + values_shared = [[batch[key] for batch in batches] for key in keys_shared] + for shared_key, value in zip(keys_shared, values_shared, strict=True): + # second often + if all(isinstance(element, torch.Tensor) for element in value): + self.__dict__[shared_key] = torch.stack(value, axis) + # third often + elif all(isinstance(element, BatchProtocol | dict) for element in value): + self.__dict__[shared_key] = Batch.stack(value, axis) + else: # most often case is np.ndarray + try: + self.__dict__[shared_key] = _to_array_with_correct_type(np.stack(value, axis)) + except ValueError: + warnings.warn( + "You are using tensors with different shape," + " fallback to dtype=object by default.", + ) + self.__dict__[shared_key] = np.array(value, dtype=object) + # all the keys + keys_total = set.union(*[set(batch.keys()) for batch in batches]) + # keys that are reserved in all batches + keys_reserve = set.difference(keys_total, set.union(*keys_map)) + # keys that are either partial or reserved + keys_reserve_or_partial = set.difference(keys_total, keys_shared) + # keys that occur only in some batches, but not all + keys_partial = keys_reserve_or_partial.difference(keys_reserve) + if keys_partial and axis != 0: + raise ValueError( + f"Stack of Batch with non-shared keys {keys_partial} is only " + f"supported with axis=0, but got axis={axis}!", + ) + for key in keys_reserve: + # reserved keys + self.__dict__[key] = Batch() + for key in keys_partial: + for i, batch in enumerate(batches): + if key not in batch.__dict__: + continue + value = batch.get(key) + # TODO: fix code/annotations s.t. 
the ignores can be removed + if ( + isinstance(value, BatchProtocol) # type: ignore + and value.is_empty() # type: ignore + ): + continue # type: ignore + try: + self.__dict__[key][i] = value + except KeyError: + self.__dict__[key] = create_value(value, len(batches)) + self.__dict__[key][i] = value + + @staticmethod + def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch: + batch = Batch() + batch.stack_(batches, axis) + # can't cast to a generic type, so we have to ignore the type here + return batch # type: ignore + + def empty_(self, index: slice | IndexType | None = None) -> Self: + for batch_key, obj in self.items(): + if isinstance(obj, torch.Tensor): # most often case + self.__dict__[batch_key][index] = 0 + elif obj is None: + continue + elif isinstance(obj, np.ndarray): + if obj.dtype == object: + self.__dict__[batch_key][index] = None + else: + self.__dict__[batch_key][index] = 0 + elif isinstance(obj, Batch): + self.__dict__[batch_key].empty_(index=index) + else: # scalar value + warnings.warn( + "You are calling Batch.empty on a NumPy scalar, " + "which may cause undefined behaviors.", + ) + if _is_number(obj): + self.__dict__[batch_key] = obj.__class__(0) + else: + self.__dict__[batch_key] = None + return self + + @staticmethod + def empty(batch: TBatch, index: IndexType | None = None) -> TBatch: + return deepcopy(batch).empty_(index) + + def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None: + if batch is None: + self.update(kwargs) + return + for batch_key, obj in batch.items(): + self.__dict__[batch_key] = _parse_value(obj) + if kwargs: + self.update(kwargs) + + def __len__(self) -> int: + """Return len(self).""" + lens = [] + for obj in self.__dict__.values(): + # TODO: causes inconsistent behavior to batch with empty batches + # and batch with empty sequences of other type. Remove, but only after + # Buffer and Collectors have been improved to no longer rely on this + if isinstance(obj, Batch) and obj.is_empty(recurse=True): + continue + if hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0): + lens.append(len(obj)) + else: + raise TypeError(f"Object {obj} in {self} has no len()") + if not lens: + return 0 + return min(lens) + + def is_empty(self, recurse: bool = False) -> bool: + """Test if a Batch is empty. + + If ``recurse=True``, it further tests the values of the object; else + it only tests the existence of any key. + + ``b.is_empty(recurse=True)`` is mainly used to distinguish + ``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise + exceptions when applied to ``len()``, but the former can be used in + ``cat``, while the latter is a scalar and cannot be used in ``cat``. + + Another usage is in ``__len__``, where we have to skip checking the + length of recursively empty Batch. 
+ :: + + >>> Batch().is_empty() + True + >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty() + False + >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True) + True + >>> Batch(d=1).is_empty() + False + >>> Batch(a=np.float64(1.0)).is_empty() + False + """ + if len(self.__dict__) == 0: + return True + if not recurse: + return False + return all( + False if not isinstance(obj, Batch) else obj.is_empty(recurse=True) + for obj in self.values() + ) + + @property + def shape(self) -> list[int]: + """Return self.shape.""" + if self.is_empty(): + return [] + data_shape = [] + for obj in self.__dict__.values(): + try: + data_shape.append(list(obj.shape)) + except AttributeError: + data_shape.append([]) + return ( + list(map(min, zip(*data_shape, strict=False))) if len(data_shape) > 1 else data_shape[0] + ) + + def split( + self, + size: int, + shuffle: bool = True, + merge_last: bool = False, + ) -> Iterator[Self]: + length = len(self) + if size == -1: + size = length + assert size >= 1 # size can be greater than length, return whole batch + indices = np.random.permutation(length) if shuffle else np.arange(length) + merge_last = merge_last and length % size > 0 + for idx in range(0, length, size): + if merge_last and idx + size + size >= length: + yield self[indices[idx:]] + break + yield self[indices[idx : idx + size]] diff --git a/examples/atari/tianshou/data/buffer/__init__.py b/examples/atari/tianshou/data/buffer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/data/buffer/base.py b/examples/atari/tianshou/data/buffer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b1719be560a39371af009a5cb85e0adfce74f51f --- /dev/null +++ b/examples/atari/tianshou/data/buffer/base.py @@ -0,0 +1,421 @@ +from typing import Any, Self, TypeVar, cast + +import h5py +import numpy as np + +from tianshou.data import Batch +from tianshou.data.batch import alloc_by_keys_diff, create_value +from tianshou.data.types import RolloutBatchProtocol +from tianshou.data.utils.converter import from_hdf5, to_hdf5 + +TBuffer = TypeVar("TBuffer", bound="ReplayBuffer") + + +class ReplayBuffer: + """:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. + + ReplayBuffer can be considered as a specialized form (or management) of Batch. It + stores all the data in a batch with circular-queue style. + + For the example usage of ReplayBuffer, please check out Section Buffer in + :doc:`/01_tutorials/01_concepts`. + + :param size: the maximum size of replay buffer. + :param stack_num: the frame-stack sampling argument, should be greater than or + equal to 1. Default to 1 (no stacking). + :param ignore_obs_next: whether to not store obs_next. Default to False. + :param save_only_last_obs: only save the last obs/obs_next when it has a shape + of (timestep, ...) because of temporal stacking. Default to False. + :param sample_avail: the parameter indicating sampling only available index + when using frame-stack sampling method. Default to False. 
+ """ + + _reserved_keys = ( + "obs", + "act", + "rew", + "terminated", + "truncated", + "done", + "obs_next", + "info", + "policy", + ) + _input_keys = ( + "obs", + "act", + "rew", + "terminated", + "truncated", + "obs_next", + "info", + "policy", + ) + + def __init__( + self, + size: int, + stack_num: int = 1, + ignore_obs_next: bool = False, + save_only_last_obs: bool = False, + sample_avail: bool = False, + **kwargs: Any, # otherwise PrioritizedVectorReplayBuffer will cause TypeError + ) -> None: + self.options: dict[str, Any] = { + "stack_num": stack_num, + "ignore_obs_next": ignore_obs_next, + "save_only_last_obs": save_only_last_obs, + "sample_avail": sample_avail, + } + super().__init__() + self.maxsize = int(size) + assert stack_num > 0, "stack_num should be greater than 0" + self.stack_num = stack_num + self._indices = np.arange(size) + self._save_obs_next = not ignore_obs_next + self._save_only_last_obs = save_only_last_obs + self._sample_avail = sample_avail + self._meta = cast(RolloutBatchProtocol, Batch()) + self._ep_rew: float | np.ndarray + self.reset() + + def __len__(self) -> int: + """Return len(self).""" + return self._size + + def __repr__(self) -> str: + """Return str(self).""" + return self.__class__.__name__ + self._meta.__repr__()[5:] + + def __getattr__(self, key: str) -> Any: + """Return self.key.""" + try: + return self._meta[key] + except KeyError as exception: + raise AttributeError from exception + + def __setstate__(self, state: dict[str, Any]) -> None: + """Unpickling interface. + + We need it because pickling buffer does not work out-of-the-box + ("buffer.__getattr__" is customized). + """ + self.__dict__.update(state) + + def __setattr__(self, key: str, value: Any) -> None: + """Set self.key = value.""" + assert key not in self._reserved_keys, f"key '{key}' is reserved and cannot be assigned" + super().__setattr__(key, value) + + def save_hdf5(self, path: str, compression: str | None = None) -> None: + """Save replay buffer to HDF5 file.""" + with h5py.File(path, "w") as f: + to_hdf5(self.__dict__, f, compression=compression) + + @classmethod + def load_hdf5(cls, path: str, device: str | None = None) -> Self: + """Load replay buffer from HDF5 file.""" + with h5py.File(path, "r") as f: + buf = cls.__new__(cls) + buf.__setstate__(from_hdf5(f, device=device)) # type: ignore + return buf + + @classmethod + def from_data( + cls, + obs: h5py.Dataset, + act: h5py.Dataset, + rew: h5py.Dataset, + terminated: h5py.Dataset, + truncated: h5py.Dataset, + done: h5py.Dataset, + obs_next: h5py.Dataset, + ) -> Self: + size = len(obs) + assert all( + len(dset) == size for dset in [obs, act, rew, terminated, truncated, done, obs_next] + ), "Lengths of all hdf5 datasets need to be equal." 
+ buf = cls(size) + if size == 0: + return buf + batch = Batch( + obs=obs, + act=act, + rew=rew, + terminated=terminated, + truncated=truncated, + done=done, + obs_next=obs_next, + ) + batch = cast(RolloutBatchProtocol, batch) + buf.set_batch(batch) + buf._size = size + return buf + + def reset(self, keep_statistics: bool = False) -> None: + """Clear all the data in replay buffer and episode statistics.""" + self.last_index = np.array([0]) + self._index = self._size = 0 + if not keep_statistics: + self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0 + + def set_batch(self, batch: RolloutBatchProtocol) -> None: + """Manually choose the batch you want the ReplayBuffer to manage.""" + assert len(batch) == self.maxsize and set(batch.keys()).issubset( + self._reserved_keys, + ), "Input batch doesn't meet ReplayBuffer's data form requirement." + self._meta = batch + + def unfinished_index(self) -> np.ndarray: + """Return the index of unfinished episode.""" + last = (self._index - 1) % self._size if self._size else 0 + return np.array([last] if not self.done[last] and self._size else [], int) + + def prev(self, index: int | np.ndarray) -> np.ndarray: + """Return the index of preceding step within the same episode if it exists. + If it does not exist (because it is the first index within the episode), + the index remains unmodified. + """ + index = (index - 1) % self._size # compute preceding index with wrap-around + # end_flag will be 1 if the previous index is the last step of an episode or + # if it is the very last index of the buffer (wrap-around case), and 0 otherwise + end_flag = self.done[index] | (index == self.last_index[0]) + return (index + end_flag) % self._size + + def next(self, index: int | np.ndarray) -> np.ndarray: + """Return the index of next step if there is a next step within the episode. + If there isn't a next step, the index remains unmodified. + """ + end_flag = self.done[index] | (index == self.last_index[0]) + return (index + (1 - end_flag)) % self._size + + def update(self, buffer: "ReplayBuffer") -> np.ndarray: + """Move the data from the given buffer to current buffer. + + Return the updated indices. If update fails, return an empty array. + """ + if len(buffer) == 0 or self.maxsize == 0: + return np.array([], int) + stack_num, buffer.stack_num = buffer.stack_num, 1 + from_indices = buffer.sample_indices(0) # get all available indices + buffer.stack_num = stack_num + if len(from_indices) == 0: + return np.array([], int) + to_indices = [] + for _ in range(len(from_indices)): + to_indices.append(self._index) + self.last_index[0] = self._index + self._index = (self._index + 1) % self.maxsize + self._size = min(self._size + 1, self.maxsize) + to_indices = np.array(to_indices) + if self._meta.is_empty(): + self._meta = create_value(buffer._meta, self.maxsize, stack=False) # type: ignore + self._meta[to_indices] = buffer._meta[from_indices] + return to_indices + + def _add_index( + self, + rew: float | np.ndarray, + done: bool, + ) -> tuple[int, float | np.ndarray, int, int]: + """Maintain the buffer's state after adding one data batch. + + Return (index_to_be_modified, episode_reward, episode_length, + episode_start_index). 
+ """ + self.last_index[0] = ptr = self._index + self._size = min(self._size + 1, self.maxsize) + self._index = (self._index + 1) % self.maxsize + + self._ep_rew += rew + self._ep_len += 1 + + if done: + result = ptr, self._ep_rew, self._ep_len, self._ep_idx + self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, self._index + return result + return ptr, self._ep_rew * 0.0, 0, self._ep_idx + + def add( + self, + batch: RolloutBatchProtocol, + buffer_ids: np.ndarray | list[int] | None = None, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Add a batch of data into replay buffer. + + :param batch: the input data batch. "obs", "act", "rew", + "terminated", "truncated" are required keys. + :param buffer_ids: to make consistent with other buffer's add function; if it + is not None, we assume the input batch's first dimension is always 1. + + Return (current_index, episode_reward, episode_length, episode_start_index). If + the episode is not finished, the return value of episode_length and + episode_reward is 0. + """ + # preprocess batch + new_batch = Batch() + for key in batch.get_keys(): + new_batch.__dict__[key] = batch[key] + batch = new_batch + batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) + assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset( + batch.get_keys(), + ) # important to do after preprocess batch + stacked_batch = buffer_ids is not None + if stacked_batch: + assert len(batch) == 1 + if self._save_only_last_obs: + batch.obs = batch.obs[:, -1] if stacked_batch else batch.obs[-1] + if not self._save_obs_next: + batch.pop("obs_next", None) + elif self._save_only_last_obs: + batch.obs_next = batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1] + # get ptr + if stacked_batch: + rew, done = batch.rew[0], batch.done[0] + else: + rew, done = batch.rew, batch.done + ptr, ep_rew, ep_len, ep_idx = (np.array([x]) for x in self._add_index(rew, done)) + try: + self._meta[ptr] = batch + except ValueError: + stack = not stacked_batch + batch.rew = batch.rew.astype(float) + batch.done = batch.done.astype(bool) + batch.terminated = batch.terminated.astype(bool) + batch.truncated = batch.truncated.astype(bool) + if self._meta.is_empty(): + self._meta = create_value(batch, self.maxsize, stack) # type: ignore + else: # dynamic key pops up in batch + alloc_by_keys_diff(self._meta, batch, self.maxsize, stack) + self._meta[ptr] = batch + return ptr, ep_rew, ep_len, ep_idx + + def sample_indices(self, batch_size: int | None) -> np.ndarray: + """Get a random sample of index with size = batch_size. + + Return all available indices in the buffer if batch_size is 0; return an empty + numpy array if batch_size < 0 or no available index can be sampled. + + :param batch_size: the number of indices to be sampled. If None, it will be set + to the length of the buffer (i.e. return all available indices in a + random order). + """ + if batch_size is None: + batch_size = len(self) + if self.stack_num == 1 or not self._sample_avail: # most often case + if batch_size > 0: + return np.random.choice(self._size, batch_size) + # TODO: is this behavior really desired? + if batch_size == 0: # construct current available indices + return np.concatenate([np.arange(self._index, self._size), np.arange(self._index)]) + return np.array([], int) + # TODO: raise error on negative batch_size instead? 
+ if batch_size < 0: + return np.array([], int) + # TODO: simplify this code - shouldn't have such a large if-else + # with many returns for handling different stack nums. + # It is also not clear whether this is really necessary - frame stacking usually is handled + # by environment wrappers (e.g. FrameStack) and not by the replay buffer. + all_indices = prev_indices = np.concatenate( + [np.arange(self._index, self._size), np.arange(self._index)], + ) + for _ in range(self.stack_num - 2): + prev_indices = self.prev(prev_indices) + all_indices = all_indices[prev_indices != self.prev(prev_indices)] + if batch_size > 0: + return np.random.choice(all_indices, batch_size) + return all_indices + + def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]: + """Get a random sample from buffer with size = batch_size. + + Return all the data in the buffer if batch_size is 0. + + :return: Sample data and its corresponding index inside the buffer. + """ + indices = self.sample_indices(batch_size) + return self[indices], indices + + def get( + self, + index: int | list[int] | np.ndarray, + key: str, + default_value: Any = None, + stack_num: int | None = None, + ) -> Batch | np.ndarray: + """Return the stacked result. + + E.g., if you set ``key = "obs", stack_num = 4, index = t``, it returns the + stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``. + + :param index: the index for getting stacked data. + :param str key: the key to get, should be one of the reserved_keys. + :param default_value: if the given key's data is not found and default_value is + set, return this default_value. + :param stack_num: Default to self.stack_num. + """ + if key not in self._meta and default_value is not None: + return default_value + val = self._meta[key] + if stack_num is None: + stack_num = self.stack_num + try: + if stack_num == 1: # the most common case + return val[index] + + stack = list[Any]() + indices = np.array(index) if isinstance(index, list) else index + # NOTE: stack_num > 1, so the range is not empty and indices is turned into + # np.ndarray by self.prev + for _ in range(stack_num): + stack = [val[indices], *stack] + indices = self.prev(indices) + indices = cast(np.ndarray, indices) + if isinstance(val, Batch): + return Batch.stack(stack, axis=indices.ndim) + return np.stack(stack, axis=indices.ndim) + + except IndexError as exception: + if not (isinstance(val, Batch) and val.is_empty()): + raise exception # val != Batch() + return Batch() + + def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> RolloutBatchProtocol: + """Return a data batch: self[index]. + + If stack_num is larger than 1, return the stacked obs and obs_next with shape + (batch, len, ...). 
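A sketch of the frame-stack retrieval that get() describes above, assuming a buffer created with stack_num=4; sizes and observation values are arbitrary:

import numpy as np

from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=10, stack_num=4)
for step in range(6):
    buf.add(
        Batch(
            obs=np.array([float(step)]),
            act=0,
            rew=0.0,
            terminated=False,
            truncated=step == 5,
            obs_next=np.array([float(step + 1)]),
            info={"env_id": 0},
        ),
    )

# For index 5 and stack_num=4 this returns [obs[2], obs[3], obs[4], obs[5]],
# stacked along a new axis as described in the docstring above.
stacked = buf.get(index=np.array([5]), key="obs")
print(stacked.shape)    # (1, 4, 1)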
+ """ + if isinstance(index, slice): # change slice to np array + # buffer[:] will get all available data + indices = ( + self.sample_indices(0) + if index == slice(None) + else self._indices[: len(self)][index] + ) + else: + indices = index # type: ignore + # raise KeyError first instead of AttributeError, + # to support np.array([ReplayBuffer()]) + obs = self.get(indices, "obs") + if self._save_obs_next: + obs_next = self.get(indices, "obs_next", Batch()) + else: + obs_next = self.get(self.next(indices), "obs", Batch()) + batch_dict = { + "obs": obs, + "act": self.act[indices], + "rew": self.rew[indices], + "terminated": self.terminated[indices], + "truncated": self.truncated[indices], + "done": self.done[indices], + "obs_next": obs_next, + "info": self.get(indices, "info", Batch()), + # TODO: what's the use of this key? + "policy": self.get(indices, "policy", Batch()), + } + for key in self._meta.__dict__: + if key not in self._input_keys: + batch_dict[key] = self._meta[key][indices] + return cast(RolloutBatchProtocol, Batch(batch_dict)) diff --git a/examples/atari/tianshou/data/buffer/cached.py b/examples/atari/tianshou/data/buffer/cached.py new file mode 100644 index 0000000000000000000000000000000000000000..97e0a8054cfe5e6cf92e6dabe47d8d968d402109 --- /dev/null +++ b/examples/atari/tianshou/data/buffer/cached.py @@ -0,0 +1,82 @@ +import numpy as np + +from tianshou.data import ReplayBuffer, ReplayBufferManager +from tianshou.data.types import RolloutBatchProtocol + + +class CachedReplayBuffer(ReplayBufferManager): + """CachedReplayBuffer contains a given main buffer and n cached buffers, ``cached_buffer_num * ReplayBuffer(size=max_episode_length)``. + + The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] | ... + | cached_buffers[cached_buffer_num - 1] |``. + + The data is first stored in cached buffers. When an episode is terminated, the data + will move to the main buffer and the corresponding cached buffer will be reset. + + :param main_buffer: the main buffer whose ``.update()`` function + behaves normally. + :param cached_buffer_num: number of ReplayBuffer needs to be created for cached + buffer. + :param max_episode_length: the maximum length of one episode, used in each + cached buffer's maxsize. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__( + self, + main_buffer: ReplayBuffer, + cached_buffer_num: int, + max_episode_length: int, + ) -> None: + assert cached_buffer_num > 0 + assert max_episode_length > 0 + assert isinstance(main_buffer, ReplayBuffer) + kwargs = main_buffer.options + buffers = [main_buffer] + [ + ReplayBuffer(max_episode_length, **kwargs) for _ in range(cached_buffer_num) + ] + super().__init__(buffer_list=buffers) + self.main_buffer = self.buffers[0] + self.cached_buffers = self.buffers[1:] + self.cached_buffer_num = cached_buffer_num + + def add( + self, + batch: RolloutBatchProtocol, + buffer_ids: np.ndarray | list[int] | None = None, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Add a batch of data into CachedReplayBuffer. + + Each of the data's length (first dimension) must equal to the length of + buffer_ids. By default the buffer_ids is [0, 1, ..., cached_buffer_num - 1]. 
+ + Return (current_index, episode_reward, episode_length, episode_start_index) + with each of the shape (len(buffer_ids), ...), where (current_index[i], + episode_reward[i], episode_length[i], episode_start_index[i]) refers to the + cached_buffer_ids[i]th cached buffer's corresponding episode result. + """ + if buffer_ids is None: + buf_arr = np.arange(1, 1 + self.cached_buffer_num) + else: # make sure it is np.ndarray + buf_arr = np.asarray(buffer_ids) + 1 + ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buf_arr) + # find the terminated episode, move data from cached buf to main buf + updated_ptr, updated_ep_idx = [], [] + done = np.logical_or(batch.terminated, batch.truncated) + for buffer_idx in buf_arr[done]: + index = self.main_buffer.update(self.buffers[buffer_idx]) + if len(index) == 0: # unsuccessful move, replace with -1 + index = [-1] + updated_ep_idx.append(index[0]) + updated_ptr.append(index[-1]) + self.buffers[buffer_idx].reset() + self._lengths[0] = len(self.main_buffer) + self._lengths[buffer_idx] = 0 + self.last_index[0] = index[-1] + self.last_index[buffer_idx] = self._offset[buffer_idx] + ptr[done] = updated_ptr + ep_idx[done] = updated_ep_idx + return ptr, ep_rew, ep_len, ep_idx diff --git a/examples/atari/tianshou/data/buffer/her.py b/examples/atari/tianshou/data/buffer/her.py new file mode 100644 index 0000000000000000000000000000000000000000..1ae1e8f23d1033d2c0a05cffb8ad26a19e37d5e6 --- /dev/null +++ b/examples/atari/tianshou/data/buffer/her.py @@ -0,0 +1,195 @@ +from collections.abc import Callable +from typing import Any, Union, cast + +import numpy as np + +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import RolloutBatchProtocol + + +class HERReplayBuffer(ReplayBuffer): + """Implementation of Hindsight Experience Replay. arXiv:1707.01495. + + HERReplayBuffer is to be used with goal-based environment where the + observation is a dictionary with keys ``observation``, ``achieved_goal`` and + ``desired_goal``. Currently support only HER's future strategy, online sampling. + + :param size: the size of the replay buffer. + :param compute_reward_fn: a function that takes 2 ``np.array`` arguments, + ``acheived_goal`` and ``desired_goal``, and returns rewards as ``np.array``. + The two arguments are of shape (batch_size, ...original_shape) and the returned + rewards must be of shape (batch_size,). + :param horizon: the maximum number of steps in an episode. + :param future_k: the 'k' parameter introduced in the paper. In short, there + will be at most k episodes that are re-written for every 1 unaltered episode + during the sampling. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__( + self, + size: int, + compute_reward_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], + horizon: int, + future_k: float = 8.0, + **kwargs: Any, + ) -> None: + super().__init__(size, **kwargs) + self.horizon = horizon + self.future_p = 1 - 1 / future_k + self.compute_reward_fn = compute_reward_fn + self._original_meta = Batch() + self._altered_indices = np.array([]) + + def _restore_cache(self) -> None: + """Write cached original meta back to `self._meta`. + + It's called everytime before 'writing', 'sampling' or 'saving' the buffer. 
+ """ + if not hasattr(self, "_altered_indices"): + return + + if self._altered_indices.size == 0: + return + self._meta[self._altered_indices] = self._original_meta + # Clean + self._original_meta = Batch() + self._altered_indices = np.array([]) + + def reset(self, keep_statistics: bool = False) -> None: + self._restore_cache() + return super().reset(keep_statistics) + + def save_hdf5(self, path: str, compression: str | None = None) -> None: + self._restore_cache() + return super().save_hdf5(path, compression) + + def set_batch(self, batch: RolloutBatchProtocol) -> None: + self._restore_cache() + return super().set_batch(batch) + + def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray: + self._restore_cache() + return super().update(buffer) + + def add( + self, + batch: RolloutBatchProtocol, + buffer_ids: np.ndarray | list[int] | None = None, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + self._restore_cache() + return super().add(batch, buffer_ids) + + def sample_indices(self, batch_size: int | None) -> np.ndarray: + """Get a random sample of index with size = batch_size. + + Return all available indices in the buffer if batch_size is 0; return an \ + empty numpy array if batch_size < 0 or no available index can be sampled. \ + Additionally, some episodes of the sampled transitions will be re-written \ + according to HER. + """ + self._restore_cache() + indices = super().sample_indices(batch_size=batch_size) + self.rewrite_transitions(indices.copy()) + return indices + + def rewrite_transitions(self, indices: np.ndarray) -> None: + """Re-write the goal of some sampled transitions' episodes according to HER. + + Currently applies only HER's 'future' strategy. The new goals will be written \ + directly to the internal batch data temporarily and will be restored right \ + before the next sampling or when using some of the buffer's method (e.g. \ + `add`, `save_hdf5`, etc.). This is to make sure that n-step returns \ + calculation etc., performs correctly without additional alteration. 
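+
+        Schematically: a fraction ``future_p`` of the sampled episodes is selected; within
+        each selected episode, the ``desired_goal`` of every transition is replaced by the
+        ``achieved_goal`` observed at a randomly chosen future step of that episode, and
+        the rewards are recomputed with ``compute_reward_fn``.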
+ """ + if indices.size == 0: + return + + # Sort indices keeping chronological order + indices[indices < self._index] += self.maxsize + indices = np.sort(indices) + indices[indices >= self.maxsize] -= self.maxsize + + # Construct episode trajectories + indices = [indices] + for _ in range(self.horizon - 1): + indices.append(self.next(indices[-1])) + indices = np.stack(indices) + + # Calculate future timestep to use + current = indices[0] + terminal = indices[-1] + episodes_len = (terminal - current + self.maxsize) % self.maxsize + future_offset = np.random.uniform(size=len(indices[0])) * episodes_len + future_offset = np.round(future_offset).astype(int) + future_t = (current + future_offset) % self.maxsize + + # Compute indices + # open indices are used to find longest, unique trajectories among + # presented episodes + unique_ep_open_indices = np.sort(np.unique(terminal, return_index=True)[1]) + unique_ep_indices = indices[:, unique_ep_open_indices] + # close indices are used to find max future_t among presented episodes + unique_ep_close_indices = np.hstack([(unique_ep_open_indices - 1)[1:], len(terminal) - 1]) + # episode indices that will be altered + her_ep_indices = np.random.choice( + len(unique_ep_open_indices), + size=int(len(unique_ep_open_indices) * self.future_p), + replace=False, + ) + + # Cache original meta + self._altered_indices = unique_ep_indices.copy() + self._original_meta = self._meta[self._altered_indices].copy() + + # Copy original obs, ep_rew (and obs_next), and obs of future time step + ep_obs = self[unique_ep_indices].obs + # to satisfy mypy + # TODO: add protocol covering these batches + assert isinstance(ep_obs, BatchProtocol) + ep_rew = self[unique_ep_indices].rew + if self._save_obs_next: + ep_obs_next = self[unique_ep_indices].obs_next + # to satisfy mypy + assert isinstance(ep_obs_next, BatchProtocol) + future_obs = self[future_t[unique_ep_close_indices]].obs_next + else: + future_obs = self[self.next(future_t[unique_ep_close_indices])].obs + future_obs = cast(BatchProtocol, future_obs) + + # Re-assign goals and rewards via broadcast assignment + ep_obs.desired_goal[:, her_ep_indices] = future_obs.achieved_goal[None, her_ep_indices] + if self._save_obs_next: + ep_obs_next = cast(BatchProtocol, ep_obs_next) + ep_obs_next.desired_goal[:, her_ep_indices] = future_obs.achieved_goal[ + None, + her_ep_indices, + ] + ep_rew[:, her_ep_indices] = self._compute_reward(ep_obs_next)[:, her_ep_indices] + else: + tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs + assert isinstance(tmp_ep_obs_next, BatchProtocol) + ep_rew[:, her_ep_indices] = self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices] + + # Sanity check + assert ep_obs.desired_goal.shape[:2] == unique_ep_indices.shape + assert ep_obs.achieved_goal.shape[:2] == unique_ep_indices.shape + assert ep_rew.shape == unique_ep_indices.shape + + # Re-write meta + assert isinstance(self._meta.obs, BatchProtocol) + self._meta.obs[unique_ep_indices] = ep_obs + if self._save_obs_next: + self._meta.obs_next[unique_ep_indices] = ep_obs_next # type: ignore + self._meta.rew[unique_ep_indices] = ep_rew.astype(np.float32) + + def _compute_reward(self, obs: BatchProtocol, lead_dims: int = 2) -> np.ndarray: + lead_shape = obs.observation.shape[:lead_dims] + g = obs.desired_goal.reshape(-1, *obs.desired_goal.shape[lead_dims:]) + ag = obs.achieved_goal.reshape(-1, *obs.achieved_goal.shape[lead_dims:]) + rewards = self.compute_reward_fn(ag, g) + return rewards.reshape(*lead_shape, *rewards.shape[1:]) diff --git 
a/examples/atari/tianshou/data/buffer/manager.py b/examples/atari/tianshou/data/buffer/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..90480257ac21324a0d3cd70b8694d8d781c18d64 --- /dev/null +++ b/examples/atari/tianshou/data/buffer/manager.py @@ -0,0 +1,323 @@ +from collections.abc import Sequence +from typing import Union + +import numpy as np +from numba import njit + +from tianshou.data import Batch, HERReplayBuffer, PrioritizedReplayBuffer, ReplayBuffer +from tianshou.data.batch import alloc_by_keys_diff, create_value +from tianshou.data.types import RolloutBatchProtocol + + +class ReplayBufferManager(ReplayBuffer): + """ReplayBufferManager contains a list of ReplayBuffer with exactly the same configuration. + + These replay buffers have contiguous memory layout, and the storage space each + buffer has is a shallow copy of the topmost memory. + + :param buffer_list: a list of ReplayBuffer needed to be handled. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__(self, buffer_list: list[ReplayBuffer] | list[HERReplayBuffer]) -> None: + self.buffer_num = len(buffer_list) + self.buffers = np.array(buffer_list, dtype=object) + offset, size = [], 0 + buffer_type = type(self.buffers[0]) + kwargs = self.buffers[0].options + for buf in self.buffers: + assert buf._meta.is_empty() + assert isinstance(buf, buffer_type) + assert buf.options == kwargs + offset.append(size) + size += buf.maxsize + self._offset = np.array(offset) + self._extend_offset = np.array([*offset, size]) + self._lengths = np.zeros_like(offset) + super().__init__(size=size, **kwargs) + self._compile() + self._meta: RolloutBatchProtocol + + def _compile(self) -> None: + lens = last = index = np.array([0]) + offset = np.array([0, 1]) + done = np.array([False, False]) + _prev_index(index, offset, done, last, lens) + _next_index(index, offset, done, last, lens) + + def __len__(self) -> int: + return int(self._lengths.sum()) + + def reset(self, keep_statistics: bool = False) -> None: + self.last_index = self._offset.copy() + self._lengths = np.zeros_like(self._offset) + for buf in self.buffers: + buf.reset(keep_statistics=keep_statistics) + + def _set_batch_for_children(self) -> None: + for offset, buf in zip(self._offset, self.buffers, strict=True): + buf.set_batch(self._meta[offset : offset + buf.maxsize]) + + def set_batch(self, batch: RolloutBatchProtocol) -> None: + super().set_batch(batch) + self._set_batch_for_children() + + def unfinished_index(self) -> np.ndarray: + return np.concatenate( + [ + buf.unfinished_index() + offset + for offset, buf in zip(self._offset, self.buffers, strict=True) + ], + ) + + def prev(self, index: int | np.ndarray) -> np.ndarray: + if isinstance(index, list | np.ndarray): + return _prev_index( + np.asarray(index), + self._extend_offset, + self.done, + self.last_index, + self._lengths, + ) + return _prev_index( + np.array([index]), + self._extend_offset, + self.done, + self.last_index, + self._lengths, + )[0] + + def next(self, index: int | np.ndarray) -> np.ndarray: + if isinstance(index, list | np.ndarray): + return _next_index( + np.asarray(index), + self._extend_offset, + self.done, + self.last_index, + self._lengths, + ) + return _next_index( + np.array([index]), + self._extend_offset, + self.done, + self.last_index, + self._lengths, + )[0] + + def update(self, buffer: ReplayBuffer) -> np.ndarray: + """The ReplayBufferManager cannot be updated by any buffer.""" + raise NotImplementedError + + 
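+    # Layout note (illustrative): with e.g. three sub-buffers of maxsize 5 each,
+    # global indices 0-4 map to buffers[0], 5-9 to buffers[1] and 10-14 to buffers[2],
+    # i.e. self._offset == [0, 5, 10]; prev()/next() never cross sub-buffer boundaries.
+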
def add( + self, + batch: RolloutBatchProtocol, + buffer_ids: np.ndarray | list[int] | None = None, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Add a batch of data into ReplayBufferManager. + + Each of the data's length (first dimension) must equal to the length of + buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1]. + + Return (current_index, episode_reward, episode_length, episode_start_index). If + the episode is not finished, the return value of episode_length and + episode_reward is 0. + """ + # preprocess batch + new_batch = Batch() + for key in set(self._reserved_keys).intersection(batch.get_keys()): + new_batch.__dict__[key] = batch[key] + batch = new_batch + batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) + assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.get_keys()) + if self._save_only_last_obs: + batch.obs = batch.obs[:, -1] + if not self._save_obs_next: + batch.pop("obs_next", None) + elif self._save_only_last_obs: + batch.obs_next = batch.obs_next[:, -1] + # get index + if buffer_ids is None: + buffer_ids = np.arange(self.buffer_num) + ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], [] + for batch_idx, buffer_id in enumerate(buffer_ids): + ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index( + batch.rew[batch_idx], + batch.done[batch_idx], + ) + ptrs.append(ptr + self._offset[buffer_id]) + ep_lens.append(ep_len) + ep_rews.append(ep_rew) + ep_idxs.append(ep_idx + self._offset[buffer_id]) + self.last_index[buffer_id] = ptr + self._offset[buffer_id] + self._lengths[buffer_id] = len(self.buffers[buffer_id]) + ptrs = np.array(ptrs) + try: + self._meta[ptrs] = batch + except ValueError: + batch.rew = batch.rew.astype(float) + batch.done = batch.done.astype(bool) + batch.terminated = batch.terminated.astype(bool) + batch.truncated = batch.truncated.astype(bool) + if self._meta.is_empty(): + self._meta = create_value(batch, self.maxsize, stack=False) # type: ignore + else: # dynamic key pops up in batch + alloc_by_keys_diff(self._meta, batch, self.maxsize, False) + self._set_batch_for_children() + self._meta[ptrs] = batch + return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs) + + def sample_indices(self, batch_size: int | None) -> np.ndarray: + # TODO: simplify this code + if batch_size is not None and batch_size < 0: + # TODO: raise error instead? 
+ return np.array([], int) + if self._sample_avail and self.stack_num > 1: + all_indices = np.concatenate( + [ + buf.sample_indices(0) + offset + for offset, buf in zip(self._offset, self.buffers, strict=True) + ], + ) + if batch_size == 0: + return all_indices + if batch_size is None: + batch_size = len(all_indices) + return np.random.choice(all_indices, batch_size) + if batch_size == 0 or batch_size is None: # get all available indices + sample_num = np.zeros(self.buffer_num, int) + else: + buffer_idx = np.random.choice( + self.buffer_num, + batch_size, + p=self._lengths / self._lengths.sum(), + ) + sample_num = np.bincount(buffer_idx, minlength=self.buffer_num) + # avoid batch_size > 0 and sample_num == 0 -> get child's all data + sample_num[sample_num == 0] = -1 + + return np.concatenate( + [ + buf.sample_indices(int(bsz)) + offset + for offset, buf, bsz in zip(self._offset, self.buffers, sample_num, strict=True) + ], + ) + + +class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager): + """PrioritizedReplayBufferManager contains a list of PrioritizedReplayBuffer with exactly the same configuration. + + These replay buffers have contiguous memory layout, and the storage space each + buffer has is a shallow copy of the topmost memory. + + :param buffer_list: a list of PrioritizedReplayBuffer needed to be handled. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: + ReplayBufferManager.__init__(self, buffer_list) # type: ignore + kwargs = buffer_list[0].options + for buf in buffer_list: + del buf.weight + PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs) + + +class HERReplayBufferManager(ReplayBufferManager): + """HERReplayBufferManager contains a list of HERReplayBuffer with exactly the same configuration. + + These replay buffers have contiguous memory layout, and the storage space each + buffer has is a shallow copy of the topmost memory. + + :param buffer_list: a list of HERReplayBuffer needed to be handled. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
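+
+    Example (illustrative sketch; ``compute_reward_fn`` is a user-supplied function
+    mapping ``(achieved_goal, desired_goal)`` arrays of shape (batch_size, ...) to
+    rewards of shape (batch_size,); the distance threshold and sizes are arbitrary)::
+
+        def compute_reward_fn(achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
+            dist = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
+            return -(dist > 0.05).astype(np.float32)
+
+        buffers = [
+            HERReplayBuffer(size=1000, compute_reward_fn=compute_reward_fn, horizon=50)
+            for _ in range(4)
+        ]
+        buf = HERReplayBufferManager(buffers)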
+ """ + + def __init__(self, buffer_list: list[HERReplayBuffer]) -> None: + super().__init__(buffer_list) + + def _restore_cache(self) -> None: + for buf in self.buffers: + buf._restore_cache() + + def save_hdf5(self, path: str, compression: str | None = None) -> None: + self._restore_cache() + return super().save_hdf5(path, compression) + + def set_batch(self, batch: RolloutBatchProtocol) -> None: + self._restore_cache() + return super().set_batch(batch) + + def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray: + self._restore_cache() + return super().update(buffer) + + def add( + self, + batch: RolloutBatchProtocol, + buffer_ids: np.ndarray | list[int] | None = None, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + self._restore_cache() + return super().add(batch, buffer_ids) + + +@njit +def _prev_index( + index: np.ndarray, + offset: np.ndarray, + done: np.ndarray, + last_index: np.ndarray, + lengths: np.ndarray, +) -> np.ndarray: + index = index % offset[-1] + prev_index = np.zeros_like(index) + # disable B905 until strict=True in zip is implemented in numba + # https://github.com/numba/numba/issues/8943 + for start, end, cur_len, last in zip( # noqa: B905 + offset[:-1], + offset[1:], + lengths, + last_index, + ): + mask = (start <= index) & (index < end) + correct_cur_len = max(1, cur_len) + if np.sum(mask) > 0: + subind = index[mask] + subind = (subind - start - 1) % correct_cur_len + end_flag = done[subind + start] | (subind + start == last) + prev_index[mask] = (subind + end_flag) % correct_cur_len + start + return prev_index + + +@njit +def _next_index( + index: np.ndarray, + offset: np.ndarray, + done: np.ndarray, + last_index: np.ndarray, + lengths: np.ndarray, +) -> np.ndarray: + index = index % offset[-1] + next_index = np.zeros_like(index) + # disable B905 until strict=True in zip is implemented in numba + # https://github.com/numba/numba/issues/8943 + for start, end, cur_len, last in zip( # noqa: B905 + offset[:-1], + offset[1:], + lengths, + last_index, + ): + mask = (start <= index) & (index < end) + correct_cur_len = max(1, cur_len) + if np.sum(mask) > 0: + subind = index[mask] + end_flag = done[subind] | (subind == last) + next_index[mask] = (subind - start + 1 - end_flag) % correct_cur_len + start + return next_index diff --git a/examples/atari/tianshou/data/buffer/prio.py b/examples/atari/tianshou/data/buffer/prio.py new file mode 100644 index 0000000000000000000000000000000000000000..bef6a06a06b479567e252bb1ce11b7aa3fb99cea --- /dev/null +++ b/examples/atari/tianshou/data/buffer/prio.py @@ -0,0 +1,107 @@ +from typing import Any, cast + +import numpy as np +import torch + +from tianshou.data import ReplayBuffer, SegmentTree, to_numpy +from tianshou.data.types import PrioBatchProtocol, RolloutBatchProtocol + + +class PrioritizedReplayBuffer(ReplayBuffer): + """Implementation of Prioritized Experience Replay. arXiv:1511.05952. + + :param alpha: the prioritization exponent. + :param beta: the importance sample soft coefficient. + :param weight_norm: whether to normalize returned weights with the maximum + weight value within the batch. Default to True. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
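+
+    Example (illustrative sketch; assumes the buffer already holds transitions and
+    ``td_error`` is an array of TD errors computed by the learning algorithm)::
+
+        buf = PrioritizedReplayBuffer(size=10000, alpha=0.6, beta=0.4)
+        # ... fill the buffer via buf.add(...) ...
+        batch, indices = buf.sample(64)  # batch.weight holds the (normalized) importance weights
+        # after computing TD errors for the sampled transitions, refresh the priorities;
+        # internally the new priority becomes (|td_error| + eps) ** alpha
+        buf.update_weight(indices, td_error)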
+ """ + + def __init__( + self, + size: int, + alpha: float, + beta: float, + weight_norm: bool = True, + **kwargs: Any, + ) -> None: + # will raise KeyError in PrioritizedVectorReplayBuffer + # super().__init__(size, **kwargs) + ReplayBuffer.__init__(self, size, **kwargs) + assert alpha > 0.0 + assert beta >= 0.0 + self._alpha, self._beta = alpha, beta + self._max_prio = self._min_prio = 1.0 + # save weight directly in this class instead of self._meta + self.weight = SegmentTree(size) + self.__eps = np.finfo(np.float32).eps.item() + self.options.update(alpha=alpha, beta=beta) + self._weight_norm = weight_norm + + def init_weight(self, index: int | np.ndarray) -> None: + self.weight[index] = self._max_prio**self._alpha + + def update(self, buffer: ReplayBuffer) -> np.ndarray: + indices = super().update(buffer) + self.init_weight(indices) + return indices + + def add( + self, + batch: RolloutBatchProtocol, + buffer_ids: np.ndarray | list[int] | None = None, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids) + self.init_weight(ptr) + return ptr, ep_rew, ep_len, ep_idx + + def sample_indices(self, batch_size: int | None) -> np.ndarray: + if batch_size is not None and batch_size > 0 and len(self) > 0: + scalar = np.random.rand(batch_size) * self.weight.reduce() + return self.weight.get_prefix_sum_idx(scalar) # type: ignore + return super().sample_indices(batch_size) + + def get_weight(self, index: int | np.ndarray) -> float | np.ndarray: + """Get the importance sampling weight. + + The "weight" in the returned Batch is the weight on loss function to debias + the sampling process (some transition tuples are sampled more often so their + losses are weighted less). + """ + # important sampling weight calculation + # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) + # simplified formula: (p_j/p_min)**(-beta) + return (self.weight[index] / self._min_prio) ** (-self._beta) + + def update_weight(self, index: np.ndarray, new_weight: np.ndarray | torch.Tensor) -> None: + """Update priority weight by index in this buffer. + + :param np.ndarray index: index you want to update weight. + :param np.ndarray new_weight: new priority weight you want to update. 
+ """ + weight = np.abs(to_numpy(new_weight)) + self.__eps + self.weight[index] = weight**self._alpha + self._max_prio = max(self._max_prio, weight.max()) + self._min_prio = min(self._min_prio, weight.min()) + + def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> PrioBatchProtocol: + if isinstance(index, slice): # change slice to np array + # buffer[:] will get all available data + indices = ( + self.sample_indices(0) + if index == slice(None) + else self._indices[: len(self)][index] + ) + else: + indices = index # type: ignore + batch = super().__getitem__(indices) + weight = self.get_weight(indices) + # ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154 + batch.weight = weight / np.max(weight) if self._weight_norm else weight + return cast(PrioBatchProtocol, batch) + + def set_beta(self, beta: float) -> None: + self._beta = beta diff --git a/examples/atari/tianshou/data/buffer/vecbuf.py b/examples/atari/tianshou/data/buffer/vecbuf.py new file mode 100644 index 0000000000000000000000000000000000000000..00b2560f23ee4a6080475f68b15e8915024928b7 --- /dev/null +++ b/examples/atari/tianshou/data/buffer/vecbuf.py @@ -0,0 +1,89 @@ +from typing import Any + +import numpy as np + +from tianshou.data import ( + HERReplayBuffer, + HERReplayBufferManager, + PrioritizedReplayBuffer, + PrioritizedReplayBufferManager, + ReplayBuffer, + ReplayBufferManager, +) + + +class VectorReplayBuffer(ReplayBufferManager): + """VectorReplayBuffer contains n ReplayBuffer with the same size. + + It is used for storing transition from different environments yet keeping the order + of time. + + :param total_size: the total size of VectorReplayBuffer. + :param buffer_num: the number of ReplayBuffer it uses, which are under the same + configuration. + + Other input arguments (stack_num/ignore_obs_next/save_only_last_obs/sample_avail) + are the same as :class:`~tianshou.data.ReplayBuffer`. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: + assert buffer_num > 0 + size = int(np.ceil(total_size / buffer_num)) + buffer_list = [ReplayBuffer(size, **kwargs) for _ in range(buffer_num)] + super().__init__(buffer_list) + + +class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): + """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size. + + It is used for storing transition from different environments yet keeping the order + of time. + + :param total_size: the total size of PrioritizedVectorReplayBuffer. + :param buffer_num: the number of PrioritizedReplayBuffer it uses, which are + under the same configuration. + + Other input arguments (alpha/beta/stack_num/ignore_obs_next/save_only_last_obs/ + sample_avail) are the same as :class:`~tianshou.data.PrioritizedReplayBuffer`. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: + assert buffer_num > 0 + size = int(np.ceil(total_size / buffer_num)) + buffer_list = [PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)] + super().__init__(buffer_list) + + def set_beta(self, beta: float) -> None: + for buffer in self.buffers: + buffer.set_beta(beta) + + +class HERVectorReplayBuffer(HERReplayBufferManager): + """HERVectorReplayBuffer contains n HERReplayBuffer with same size. 
+ + It is used for storing transition from different environments yet keeping the order + of time. + + :param total_size: the total size of HERVectorReplayBuffer. + :param buffer_num: the number of HERReplayBuffer it uses, which are + under the same configuration. + + Other input arguments are the same as :class:`~tianshou.data.HERReplayBuffer`. + + .. seealso:: + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: + assert buffer_num > 0 + size = int(np.ceil(total_size / buffer_num)) + buffer_list = [HERReplayBuffer(size, **kwargs) for _ in range(buffer_num)] + super().__init__(buffer_list) diff --git a/examples/atari/tianshou/data/collector.py b/examples/atari/tianshou/data/collector.py new file mode 100644 index 0000000000000000000000000000000000000000..6773a6383e835b676d9e106b5890013be4315471 --- /dev/null +++ b/examples/atari/tianshou/data/collector.py @@ -0,0 +1,943 @@ +import logging +import time +import warnings +from abc import ABC, abstractmethod +from copy import copy +from dataclasses import dataclass +from typing import Any, Self, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +from overrides import override + +from tianshou.data import ( + Batch, + CachedReplayBuffer, + ReplayBuffer, + ReplayBufferManager, + SequenceSummaryStats, + VectorReplayBuffer, + to_numpy, +) +from tianshou.data.types import ( + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.env import BaseVectorEnv, DummyVectorEnv +from tianshou.policy import BasePolicy +from tianshou.utils.print import DataclassPPrintMixin +from tianshou.utils.torch_utils import torch_train_mode + +log = logging.getLogger(__name__) + + +@dataclass(kw_only=True) +class CollectStatsBase(DataclassPPrintMixin): + """The most basic stats, often used for offline learning.""" + + n_collected_episodes: int = 0 + """The number of collected episodes.""" + n_collected_steps: int = 0 + """The number of collected steps.""" + + +@dataclass(kw_only=True) +class CollectStats(CollectStatsBase): + """A data structure for storing the statistics of rollouts.""" + + collect_time: float = 0.0 + """The time for collecting transitions.""" + collect_speed: float = 0.0 + """The speed of collecting (env_step per second).""" + returns: np.ndarray + """The collected episode returns.""" + returns_stat: SequenceSummaryStats | None # can be None if no episode ends during the collect step + """Stats of the collected returns.""" + lens: np.ndarray + """The collected episode lengths.""" + lens_stat: SequenceSummaryStats | None # can be None if no episode ends during the collect step + """Stats of the collected episode lengths.""" + + @classmethod + def with_autogenerated_stats( + cls, + returns: np.ndarray, + lens: np.ndarray, + n_collected_episodes: int = 0, + n_collected_steps: int = 0, + collect_time: float = 0.0, + collect_speed: float = 0.0, + ) -> Self: + """Return a new instance with the stats autogenerated from the given lists.""" + returns_stat = SequenceSummaryStats.from_sequence(returns) if returns.size > 0 else None + lens_stat = SequenceSummaryStats.from_sequence(lens) if lens.size > 0 else None + return cls( + n_collected_episodes=n_collected_episodes, + n_collected_steps=n_collected_steps, + collect_time=collect_time, + collect_speed=collect_speed, + returns=returns, + returns_stat=returns_stat, + lens=np.array(lens, int), + lens_stat=lens_stat, + ) + + +_TArrLike = TypeVar("_TArrLike", bound="np.ndarray 
| torch.Tensor | Batch | None") + + +def _nullable_slice(obj: _TArrLike, indices: np.ndarray) -> _TArrLike: + """Return None, or the values at the given indices if the object is not None.""" + if obj is not None: + return obj[indices] # type: ignore[index, return-value] + return None # type: ignore[unreachable] + + +def _dict_of_arr_to_arr_of_dicts(dict_of_arr: dict[str, np.ndarray | dict]) -> np.ndarray: + return np.array(Batch(dict_of_arr).to_list_of_dicts()) + + +def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: + """TODO: this exists because of multiple bugs in Batch and to restore backwards compatibility. + Batch should be fixed and this function should be removed asap!. + """ + if info_array.dtype != np.dtype("O"): + raise ValueError( + f"Expected info_array to have dtype=object, but got {info_array.dtype}.", + ) + + truthy_info_indices = info_array.nonzero()[0] + falsy_info_indices = set(range(len(info_array))) - set(truthy_info_indices) + falsy_info_indices = np.array(list(falsy_info_indices), dtype=int) + + if len(falsy_info_indices) == len(info_array): + return Batch() + + some_nonempty_info = None + for info in info_array: + if info: + some_nonempty_info = info + break + + info_array = copy(info_array) + info_array[falsy_info_indices] = some_nonempty_info + result_batch_parent = Batch(info=info_array) + result_batch_parent.info[falsy_info_indices] = {} + return result_batch_parent.info + + +class BaseCollector(ABC): + """Used to collect data from a vector environment into a buffer using a given policy. + + .. note:: + + Please make sure the given environment has a time limitation if using `n_episode` + collect option. + + .. note:: + + In past versions of Tianshou, the replay buffer passed to `__init__` + was automatically reset. This is not done in the current implementation. + """ + + def __init__( + self, + policy: BasePolicy, + env: BaseVectorEnv | gym.Env, + buffer: ReplayBuffer | None = None, + exploration_noise: bool = False, + ) -> None: + if isinstance(env, gym.Env) and not hasattr(env, "__len__"): + warnings.warn("Single environment detected, wrap to DummyVectorEnv.") + # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy + env = DummyVectorEnv([lambda: env]) # type: ignore + + if buffer is None: + buffer = VectorReplayBuffer(len(env), len(env)) + + self.buffer: ReplayBuffer = buffer + self.policy = policy + self.env = cast(BaseVectorEnv, env) + self.exploration_noise = exploration_noise + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + + self._action_space = self.env.action_space + self._is_closed = False + + self._validate_buffer() + + def _validate_buffer(self) -> None: + buf = self.buffer + # TODO: a bit weird but true - all VectorReplayBuffers inherit from ReplayBufferManager. + # We should probably rename the manager + if isinstance(buf, ReplayBufferManager) and buf.buffer_num < self.env_num: + raise ValueError( + f"Buffer has only {buf.buffer_num} buffers, but at least {self.env_num=} are needed.", + ) + if isinstance(buf, CachedReplayBuffer) and buf.cached_buffer_num < self.env_num: + raise ValueError( + f"Buffer has only {buf.cached_buffer_num} cached buffers, but at least {self.env_num=} are needed.", + ) + # Non-VectorReplayBuffer. 
TODO: probably shouldn't rely on isinstance + if not isinstance(buf, ReplayBufferManager): + if buf.maxsize == 0: + raise ValueError("Buffer maxsize should be greater than 0.") + if self.env_num > 1: + raise ValueError( + f"Cannot use {type(buf).__name__} to collect from multiple envs ({self.env_num=}). " + f"Please use the corresponding VectorReplayBuffer instead.", + ) + + @property + def env_num(self) -> int: + return len(self.env) + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space + + def close(self) -> None: + """Close the collector and the environment.""" + self.env.close() + self._is_closed = True + + def reset( + self, + reset_buffer: bool = True, + reset_stats: bool = True, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Reset the environment, statistics, and data needed to start the collection. + + :param reset_buffer: if true, reset the replay buffer attached + to the collector. + :param reset_stats: if true, reset the statistics attached to the collector. + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Defaults to None (extra keyword arguments) + :return: The initial observation and info from the environment. + """ + obs_NO, info_N = self.reset_env(gym_reset_kwargs=gym_reset_kwargs) + if reset_buffer: + self.reset_buffer() + if reset_stats: + self.reset_stat() + self._is_closed = False + return obs_NO, info_N + + def reset_stat(self) -> None: + """Reset the statistic variables.""" + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + + def reset_buffer(self, keep_statistics: bool = False) -> None: + """Reset the data buffer.""" + self.buffer.reset(keep_statistics=keep_statistics) + + def reset_env( + self, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Reset the environments and the initial obs, info, and hidden state of the collector. + + :return: The initial observation and info from the (vectorized) environment. + """ + gym_reset_kwargs = gym_reset_kwargs or {} + obs_NO, info_N = self.env.reset(**gym_reset_kwargs) + # TODO: hack, wrap envpool envs such that they don't return a dict + if isinstance(info_N, dict): # type: ignore[unreachable] + # this can happen if the env is an envpool env. Then the thing returned by reset is a dict + # with array entries instead of an array of dicts + # We use Batch to turn it into an array of dicts + info_N = _dict_of_arr_to_arr_of_dicts(info_N) # type: ignore[unreachable] + return obs_NO, info_N + + @abstractmethod + def _collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + pass + + @torch.no_grad() + def collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + reset_before_collect: bool = False, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + """Collect a specified number of steps or episodes. + + To ensure an unbiased sampling result with the n_episode option, this function will + first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` + episodes, they will be collected evenly from each env. + + :param n_step: how many steps you want to collect. + :param n_episode: how many episodes you want to collect. + :param random: whether to use random policy for collecting data. 
+ :param render: the sleep time between rendering consecutive frames. + :param reset_before_collect: whether to reset the environment before collecting data. + (The collector needs the initial obs and info to function properly.) + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Only used if reset_before_collect is True. + + .. note:: + + One and only one collection number specification is permitted, either + ``n_step`` or ``n_episode``. + + :return: The collected stats + """ + # check that exactly one of n_step or n_episode is set and that the other is larger than 0 + self._validate_n_step_n_episode(n_episode, n_step) + + if reset_before_collect: + self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) + + with torch_train_mode(self.policy, False): + return self._collect( + n_step=n_step, + n_episode=n_episode, + random=random, + render=render, + gym_reset_kwargs=gym_reset_kwargs, + ) + + def _validate_n_step_n_episode(self, n_episode: int | None, n_step: int | None) -> None: + if not n_step and not n_episode: + raise ValueError( + f"Only one of n_step and n_episode should be set to a value larger than zero " + f"but got {n_step=}, {n_episode=}.", + ) + if n_step is None and n_episode is None: + raise ValueError( + "Exactly one of n_step and n_episode should be set but got None for both.", + ) + if n_step and n_step % self.env_num != 0: + warnings.warn( + f"{n_step=} is not a multiple of ({self.env_num=}), " + "which may cause extra transitions being collected into the buffer.", + ) + if n_episode and self.env_num > n_episode: + warnings.warn( + f"{n_episode=} should be larger than {self.env_num=} to " + f"collect at least one trajectory in each environment.", + ) + + +class Collector(BaseCollector): + # NAMING CONVENTION (mostly suffixes): + # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed, + # the corresponding env is either reset or removed from the ready envs. + # N - number of envs, always fixed and >= R. + # R - number ready env ids. Note that this might change when envs get idle. + # This can only happen in n_episode case, see explanation in the corresponding block. + # For n_step, we always use all envs to collect the data, while for n_episode, + # R will be at most n_episode at the beginning, but can decrease during the collection. + # O - dimension(s) of observations + # A - dimension(s) of actions + # H - dimension(s) of hidden state + # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. + # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. + # Only used in n_episode case. Then, R becomes R-S. + def __init__( + self, + policy: BasePolicy, + env: gym.Env | BaseVectorEnv, + buffer: ReplayBuffer | None = None, + exploration_noise: bool = False, + ) -> None: + """:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param env: a ``gym.Env`` environment or an instance of the + :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. + :param exploration_noise: determine whether the action needs to be modified + with the corresponding policy's exploration noise. If so, "policy. 
+ exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. + """ + super().__init__(policy, env, buffer, exploration_noise=exploration_noise) + self._pre_collect_obs_RO: np.ndarray | None = None + self._pre_collect_info_R: np.ndarray | None = None + self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None + + self._is_closed = False + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + + @override + def close(self) -> None: + super().close() + self._pre_collect_obs_RO = None + self._pre_collect_info_R = None + + @override + def reset_env( + self, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs) + # We assume that R = N when reset is called. + # TODO: there is currently no mechanism that ensures this and it's a public method! + self._pre_collect_obs_RO = obs_NO + self._pre_collect_info_R = info_N + self._pre_collect_hidden_state_RH = None + return obs_NO, info_N + + def _compute_action_policy_hidden( + self, + random: bool, + ready_env_ids_R: np.ndarray, + last_obs_RO: np.ndarray, + last_info_R: np.ndarray, + last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None, + ) -> tuple[np.ndarray, np.ndarray, Batch, np.ndarray | torch.Tensor | Batch | None]: + """Returns the action, the normalized action, a "policy" entry, and the hidden state.""" + if random: + try: + act_normalized_RA = np.array( + [self._action_space[i].sample() for i in ready_env_ids_R], + ) + # TODO: test whether envpool env explicitly + except TypeError: # envpool's action space is not for per-env + act_normalized_RA = np.array([self._action_space.sample() for _ in ready_env_ids_R]) + act_RA = self.policy.map_action_inverse(np.array(act_normalized_RA)) + policy_R = Batch() + hidden_state_RH = None + + else: + info_batch = _HACKY_create_info_batch(last_info_R) + obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) + + act_batch_RA = self.policy( + obs_batch_R, + last_hidden_state_RH, + ) + + act_RA = to_numpy(act_batch_RA.act) + if self.exploration_noise: + act_RA = self.policy.exploration_noise(act_RA, obs_batch_R) + act_normalized_RA = self.policy.map_action(act_RA) + + # TODO: cleanup the whole policy in batch thing + # todo policy_R can also be none, check + policy_R = act_batch_RA.get("policy", Batch()) + if not isinstance(policy_R, Batch): + raise RuntimeError( + f"The policy result should be a {Batch}, but got {type(policy_R)}", + ) + + hidden_state_RH = act_batch_RA.get("state", None) + # TODO: do we need the conditional? Would be better to just add hidden_state which could be None + if hidden_state_RH is not None: + policy_R.hidden_state = ( + hidden_state_RH # save state into buffer through policy attr + ) + return act_RA, act_normalized_RA, policy_R, hidden_state_RH + + # TODO: reduce complexity, remove the noqa + def _collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + # TODO: can't do it init since AsyncCollector is currently a subclass of Collector + if self.env.is_async: + raise ValueError( + f"Please use {AsyncCollector.__name__} for asynchronous environments. 
" + f"Env class: {self.env.__class__.__name__}.", + ) + + if n_step is not None: + ready_env_ids_R = np.arange(self.env_num) + elif n_episode is not None: + ready_env_ids_R = np.arange(min(self.env_num, n_episode)) + else: + raise ValueError("Either n_step or n_episode should be set.") + + start_time = time.time() + if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: + raise ValueError( + "Initial obs and info should not be None. " + "Either reset the collector (using reset or reset_env) or pass reset_before_collect=True to collect.", + ) + + # get the first obs to be the current obs in the n_step case as + # episodes as a new call to collect does not restart trajectories + # (which we also really don't want) + step_count = 0 + num_collected_episodes = 0 + episode_returns: list[float] = [] + episode_lens: list[int] = [] + episode_start_indices: list[int] = [] + + # in case we select fewer episodes than envs, we run only some of them + last_obs_RO = _nullable_slice(self._pre_collect_obs_RO, ready_env_ids_R) + last_info_R = _nullable_slice(self._pre_collect_info_R, ready_env_ids_R) + last_hidden_state_RH = _nullable_slice( + self._pre_collect_hidden_state_RH, + ready_env_ids_R, + ) + + while True: + # todo check if we need this when using cur_rollout_batch + # if len(cur_rollout_batch) != len(ready_env_ids): + # raise RuntimeError( + # f"The length of the collected_rollout_batch {len(cur_rollout_batch)}) is not equal to the length of ready_env_ids" + # f"{len(ready_env_ids)}. This should not happen and could be a bug!", + # ) + # restore the state: if the last state is None, it won't store + + # get the next action + ( + act_RA, + act_normalized_RA, + policy_R, + hidden_state_RH, + ) = self._compute_action_policy_hidden( + random=random, + ready_env_ids_R=ready_env_ids_R, + last_obs_RO=last_obs_RO, + last_info_R=last_info_R, + last_hidden_state_RH=last_hidden_state_RH, + ) + + obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( + act_normalized_RA, + ready_env_ids_R, + ) + if isinstance(info_R, dict): # type: ignore[unreachable] + # This can happen if the env is an envpool env. Then the info returned by step is a dict + info_R = _dict_of_arr_to_arr_of_dicts(info_R) # type: ignore[unreachable] + done_R = np.logical_or(terminated_R, truncated_R) + + current_iteration_batch = cast( + RolloutBatchProtocol, + Batch( + obs=last_obs_RO, + act=act_RA, + policy=policy_R, + obs_next=obs_next_RO, + rew=rew_R, + terminated=terminated_R, + truncated=truncated_R, + done=done_R, + info=info_R, + ), + ) + + # TODO: only makes sense if render_mode is human. 
+ # Also, doubtful whether it makes sense at all for true vectorized envs + if render: + self.env.render() + if not np.isclose(render, 0): + time.sleep(render) + + # add data into the buffer + ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( + current_iteration_batch, + buffer_ids=ready_env_ids_R, + ) + + # collect statistics + num_episodes_done_this_iter = np.sum(done_R) + num_collected_episodes += num_episodes_done_this_iter + step_count += len(ready_env_ids_R) + + # preparing for the next iteration + # obs_next, info and hidden_state will be modified inplace in the code below, + # so we copy to not affect the data in the buffer + last_obs_RO = copy(obs_next_RO) + last_info_R = copy(info_R) + last_hidden_state_RH = copy(hidden_state_RH) + + # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration + # Resetting envs that reached done, or removing some of them from the collection if needed (see below) + if num_episodes_done_this_iter > 0: + # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays + # D - number of envs that reached done in the rollout above + env_ind_local_D = np.where(done_R)[0] + env_ind_global_D = ready_env_ids_R[env_ind_local_D] + episode_lens.extend(ep_len_R[env_ind_local_D]) + episode_returns.extend(ep_rew_R[env_ind_local_D]) + episode_start_indices.extend(ep_idx_R[env_ind_local_D]) + # now we copy obs_next to obs, but since there might be + # finished episodes, we have to reset finished envs first. + + gym_reset_kwargs = gym_reset_kwargs or {} + obs_reset_DO, info_reset_D = self.env.reset( + env_id=env_ind_global_D, + **gym_reset_kwargs, + ) + + # Set the hidden state to zero or None for the envs that reached done + # TODO: does it have to be so complicated? We should have a single clear type for hidden_state instead of + # this complex logic + self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) + + # preparing for the next iteration + last_obs_RO[env_ind_local_D] = obs_reset_DO + last_info_R[env_ind_local_D] = info_reset_D + + # Handling the case when we have more ready envs than desired and are not done yet + # + # This can only happen if we are collecting a fixed number of episodes + # If we have more ready envs than there are remaining episodes to collect, + # we will remove some of them for the next rollout + # One effect of this is the following: only envs that have completed an episode + # in the last step can ever be removed from the ready envs. + # Thus, this guarantees that each env will contribute at least one episode to the + # collected data (the buffer). This effect was previous called "avoiding bias in selecting environments" + # However, it is not at all clear whether this is actually useful or necessary. + # Additional naming convention: + # S - number of surplus envs + # TODO: can the whole block be removed? If we have too many episodes, we could just strip the last ones. + # Changing R to R-S highly increases the complexity of the code. + if n_episode: + remaining_episodes_to_collect = n_episode - num_collected_episodes + surplus_env_num = len(ready_env_ids_R) - remaining_episodes_to_collect + if surplus_env_num > 0: + # R becomes R-S here, preparing for the next iteration in while loop + # Everything that was of length R needs to be filtered and become of length R-S. + # Note that this won't be the last iteration, as one iteration equals one + # step and we still need to collect the remaining episodes to reach the breaking condition. 
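+                    # Numeric illustration (hypothetical): with n_episode=10, 8 ready envs,
+                    # and 5 episodes finishing in this step, only 5 more episodes are needed,
+                    # so S = 8 - 5 = 3 of the envs that just finished an episode are dropped
+                    # from the ready set for the next iteration.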
+ + # creating the mask + env_to_be_ignored_ind_local_S = env_ind_local_D[:surplus_env_num] + env_should_remain_R = np.ones_like(ready_env_ids_R, dtype=bool) + env_should_remain_R[env_to_be_ignored_ind_local_S] = False + # stripping the "idle" indices, shortening the relevant quantities from R to R-S + ready_env_ids_R = ready_env_ids_R[env_should_remain_R] + last_obs_RO = last_obs_RO[env_should_remain_R] + last_info_R = last_info_R[env_should_remain_R] + if hidden_state_RH is not None: + last_hidden_state_RH = last_hidden_state_RH[env_should_remain_R] # type: ignore[index] + + if (n_step and step_count >= n_step) or ( + n_episode and num_collected_episodes >= n_episode + ): + break + + # generate statistics + self.collect_step += step_count + self.collect_episode += num_collected_episodes + collect_time = max(time.time() - start_time, 1e-9) + self.collect_time += collect_time + + if n_step: + # persist for future collect iterations + self._pre_collect_obs_RO = last_obs_RO + self._pre_collect_info_R = last_info_R + self._pre_collect_hidden_state_RH = last_hidden_state_RH + elif n_episode: + # reset envs and the _pre_collect fields + self.reset_env(gym_reset_kwargs) # todo still necessary? + + return CollectStats.with_autogenerated_stats( + returns=np.array(episode_returns), + lens=np.array(episode_lens), + n_collected_episodes=num_collected_episodes, + n_collected_steps=step_count, + collect_time=collect_time, + collect_speed=step_count / collect_time, + ) + + @staticmethod + def _reset_hidden_state_based_on_type( + env_ind_local_D: np.ndarray, + last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, + ) -> None: + if isinstance(last_hidden_state_RH, torch.Tensor): + last_hidden_state_RH[env_ind_local_D].zero_() # type: ignore[index] + elif isinstance(last_hidden_state_RH, np.ndarray): + last_hidden_state_RH[env_ind_local_D] = ( + None if last_hidden_state_RH.dtype == object else 0 + ) + elif isinstance(last_hidden_state_RH, Batch): + last_hidden_state_RH.empty_(env_ind_local_D) + # todo is this inplace magic and just working? + + +class AsyncCollector(Collector): + """Async Collector handles async vector environment. + + Please refer to :class:`~tianshou.data.Collector` for a more detailed explanation. + """ + + def __init__( + self, + policy: BasePolicy, + env: BaseVectorEnv, + buffer: ReplayBuffer | None = None, + exploration_noise: bool = False, + ) -> None: + if not env.is_async: + # TODO: raise an exception? + log.error( + f"Please use {Collector.__name__} if not using async venv. " + f"Env class: {env.__class__.__name__}", + ) + # assert env.is_async + warnings.warn("Using async setting may collect extra transitions into buffer.") + super().__init__( + policy, + env, + buffer, + exploration_noise, + ) + # E denotes the number of parallel environments: self.env_num + # At init, E=R but during collection R <= E + # Keep in sync with reset! 
+ self._ready_env_ids_R: np.ndarray = np.arange(self.env_num) + self._current_obs_in_all_envs_EO: np.ndarray | None = copy(self._pre_collect_obs_RO) + self._current_info_in_all_envs_E: np.ndarray | None = copy(self._pre_collect_info_R) + self._current_hidden_state_in_all_envs_EH: np.ndarray | torch.Tensor | Batch | None = copy( + self._pre_collect_hidden_state_RH, + ) + self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num) + self._current_policy_in_all_envs_E: Batch | None = None + + @override + def reset( + self, + reset_buffer: bool = True, + reset_stats: bool = True, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + # This sets the _pre_collect attrs + result = super().reset( + reset_buffer=reset_buffer, + reset_stats=reset_stats, + gym_reset_kwargs=gym_reset_kwargs, + ) + # Keep in sync with init! + self._ready_env_ids_R = np.arange(self.env_num) + # E denotes the number of parallel environments self.env_num + self._current_obs_in_all_envs_EO = copy(self._pre_collect_obs_RO) + self._current_info_in_all_envs_E = copy(self._pre_collect_info_R) + self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH) + self._current_action_in_all_envs_EA = np.empty(self.env_num) + self._current_policy_in_all_envs_E = None + return result + + @override + def reset_env( + self, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + # we need to step through the envs and wait until they are ready to be able to interact with them + if self.env.waiting_id: + self.env.step(None, id=self.env.waiting_id) + return super().reset_env(gym_reset_kwargs=gym_reset_kwargs) + + @override + def _collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + start_time = time.time() + + step_count = 0 + num_collected_episodes = 0 + episode_returns: list[float] = [] + episode_lens: list[int] = [] + episode_start_indices: list[int] = [] + + ready_env_ids_R = self._ready_env_ids_R + # last_obs_RO= self._current_obs_in_all_envs_EO[ready_env_ids_R] # type: ignore[index] + # last_info_R = self._current_info_in_all_envs_E[ready_env_ids_R] # type: ignore[index] + # last_hidden_state_RH = self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] # type: ignore[index] + # last_obs_RO = self._pre_collect_obs_RO + # last_info_R = self._pre_collect_info_R + # last_hidden_state_RH = self._pre_collect_hidden_state_RH + if self._current_obs_in_all_envs_EO is None or self._current_info_in_all_envs_E is None: + raise RuntimeError( + "Current obs or info array is None, did you call reset or pass reset_at_collect=True?", + ) + + last_obs_RO = self._current_obs_in_all_envs_EO[ready_env_ids_R] + last_info_R = self._current_info_in_all_envs_E[ready_env_ids_R] + last_hidden_state_RH = _nullable_slice( + self._current_hidden_state_in_all_envs_EH, + ready_env_ids_R, + ) + # Each iteration of the AsyncCollector is only stepping a subset of the + # envs. The last observation/ hidden state of the ones not included in + # the current iteration has to be retained. + while True: + # todo do we need this? 
+ # todo extend to all current attributes but some could be None at init + if self._current_obs_in_all_envs_EO is None: + raise RuntimeError( + "Current obs is None, did you call reset or pass reset_at_collect=True?", + ) + if ( + not len(self._current_obs_in_all_envs_EO) + == len(self._current_action_in_all_envs_EA) + == self.env_num + ): # major difference + raise RuntimeError( + f"{len(self._current_obs_in_all_envs_EO)=} and" + f"{len(self._current_action_in_all_envs_EA)=} have to equal" + f" {self.env_num=} as it tracks the current transition" + f"in all envs", + ) + + # get the next action + ( + act_RA, + act_normalized_RA, + policy_R, + hidden_state_RH, + ) = self._compute_action_policy_hidden( + random=random, + ready_env_ids_R=ready_env_ids_R, + last_obs_RO=last_obs_RO, + last_info_R=last_info_R, + last_hidden_state_RH=last_hidden_state_RH, + ) + + # save act_RA/policy_R/ hidden_state_RH before env.step + self._current_action_in_all_envs_EA[ready_env_ids_R] = act_RA + if self._current_policy_in_all_envs_E: + self._current_policy_in_all_envs_E[ready_env_ids_R] = policy_R + else: + self._current_policy_in_all_envs_E = policy_R # first iteration + if hidden_state_RH is not None: + if self._current_hidden_state_in_all_envs_EH is not None: + # Need to cast since if it's a Tensor, the assignment might in fact fail if hidden_state_RH is not + # a tensor as well. This is hard to express with proper typing, even using @overload, so we cheat + # and hope that if one of the two is a tensor, the other one is as well. + self._current_hidden_state_in_all_envs_EH = cast( + np.ndarray | Batch, + self._current_hidden_state_in_all_envs_EH, + ) + self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = hidden_state_RH + else: + self._current_hidden_state_in_all_envs_EH = hidden_state_RH + + # step in env + obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( + act_normalized_RA, + ready_env_ids_R, + ) + done_R = np.logical_or(terminated_R, truncated_R) + # Not all environments of the AsyncCollector might have performed a step in this iteration. + # Change batch_of_envs_with_step_in_this_iteration here to reflect that ready_env_ids_R has changed. + # This means especially that R is potentially changing every iteration + try: + ready_env_ids_R = cast(np.ndarray, info_R["env_id"]) + # TODO: don't use bare Exception! 
+ except Exception: + ready_env_ids_R = np.array([i["env_id"] for i in info_R]) + + current_iteration_batch = cast( + RolloutBatchProtocol, + Batch( + obs=self._current_obs_in_all_envs_EO[ready_env_ids_R], + act=self._current_action_in_all_envs_EA[ready_env_ids_R], + policy=self._current_policy_in_all_envs_E[ready_env_ids_R], + obs_next=obs_next_RO, + rew=rew_R, + terminated=terminated_R, + truncated=truncated_R, + done=done_R, + info=info_R, + ), + ) + + if render: + self.env.render() + if render > 0 and not np.isclose(render, 0): + time.sleep(render) + + # add data into the buffer + ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( + current_iteration_batch, + buffer_ids=ready_env_ids_R, + ) + + # collect statistics + num_episodes_done_this_iter = np.sum(done_R) + step_count += len(ready_env_ids_R) + num_collected_episodes += num_episodes_done_this_iter + + # preparing for the next iteration + # todo seem we can get rid of this last_sth stuff altogether + last_obs_RO = copy(obs_next_RO) + last_info_R = copy(info_R) + last_hidden_state_RH = copy( + self._current_hidden_state_in_all_envs_EH[ready_env_ids_R], # type: ignore[index] + ) + if num_episodes_done_this_iter: + env_ind_local_D = np.where(done_R)[0] + env_ind_global_D = ready_env_ids_R[env_ind_local_D] + episode_lens.extend(ep_len_R[env_ind_local_D]) + episode_returns.extend(ep_rew_R[env_ind_local_D]) + episode_start_indices.extend(ep_idx_R[env_ind_local_D]) + + # now we copy obs_next_RO to obs, but since there might be + # finished episodes, we have to reset finished envs first. + gym_reset_kwargs = gym_reset_kwargs or {} + obs_reset_DO, info_reset_D = self.env.reset( + env_id=env_ind_global_D, + **gym_reset_kwargs, + ) + last_obs_RO[env_ind_local_D] = obs_reset_DO + last_info_R[env_ind_local_D] = info_reset_D + + self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) + + # update based on the current transition in all envs + self._current_obs_in_all_envs_EO[ready_env_ids_R] = last_obs_RO + # this is a list, so loop over + for idx, ready_env_id in enumerate(ready_env_ids_R): + self._current_info_in_all_envs_E[ready_env_id] = last_info_R[idx] + if self._current_hidden_state_in_all_envs_EH is not None: + # Need to cast since if it's a Tensor, the assignment might in fact fail if hidden_state_RH is not + # a tensor as well. This is hard to express with proper typing, even using @overload, so we cheat + # and hope that if one of the two is a tensor, the other one is as well. 
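The bare `except Exception` above exists because `info_R` can come back either as a dict-like object holding a vectorized `"env_id"` array or as a NumPy object array of per-env info dicts, and only the first form supports `info_R["env_id"]` directly. A hedged sketch of the same fallback with narrower exception types (the helper name `extract_env_ids` is made up)::

    import numpy as np

    def extract_env_ids(info_R) -> np.ndarray:
        try:
            # dict-like info with a vectorized "env_id" field
            return np.asarray(info_R["env_id"])
        except (IndexError, KeyError, TypeError):
            # object array of per-env info dicts, as produced by BaseVectorEnv.step
            return np.array([i["env_id"] for i in info_R])

    info = np.array([{"env_id": 1}, {"env_id": 3}])
    print(extract_env_ids(info))  # [1 3]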
+ self._current_hidden_state_in_all_envs_EH = cast( + np.ndarray | Batch, + self._current_hidden_state_in_all_envs_EH, + ) + self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = last_hidden_state_RH + else: + self._current_hidden_state_in_all_envs_EH = last_hidden_state_RH + + if (n_step and step_count >= n_step) or ( + n_episode and num_collected_episodes >= n_episode + ): + break + + # generate statistics + self.collect_step += step_count + self.collect_episode += num_collected_episodes + collect_time = max(time.time() - start_time, 1e-9) + self.collect_time += collect_time + + # persist for future collect iterations + self._ready_env_ids_R = ready_env_ids_R + + return CollectStats.with_autogenerated_stats( + returns=np.array(episode_returns), + lens=np.array(episode_lens), + n_collected_episodes=num_collected_episodes, + n_collected_steps=step_count, + collect_time=collect_time, + collect_speed=step_count / collect_time, + ) diff --git a/examples/atari/tianshou/data/stats.py b/examples/atari/tianshou/data/stats.py new file mode 100644 index 0000000000000000000000000000000000000000..b7731860238dc06b01857480cf14beba7b7e85ae --- /dev/null +++ b/examples/atari/tianshou/data/stats.py @@ -0,0 +1,89 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import numpy as np + +from tianshou.utils.print import DataclassPPrintMixin + +if TYPE_CHECKING: + from tianshou.data import CollectStats, CollectStatsBase + from tianshou.policy.base import TrainingStats + + +@dataclass(kw_only=True) +class SequenceSummaryStats(DataclassPPrintMixin): + """A data structure for storing the statistics of a sequence.""" + + mean: float + std: float + max: float + min: float + + @classmethod + def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats": + return cls( + mean=float(np.mean(sequence)), + std=float(np.std(sequence)), + max=float(np.max(sequence)), + min=float(np.min(sequence)), + ) + + +@dataclass(kw_only=True) +class TimingStats(DataclassPPrintMixin): + """A data structure for storing timing statistics.""" + + total_time: float = 0.0 + """The total time elapsed.""" + train_time: float = 0.0 + """The total time elapsed for training (collecting samples plus model update).""" + train_time_collect: float = 0.0 + """The total time elapsed for collecting training transitions.""" + train_time_update: float = 0.0 + """The total time elapsed for updating models.""" + test_time: float = 0.0 + """The total time elapsed for testing models.""" + update_speed: float = 0.0 + """The speed of updating (env_step per second).""" + + +@dataclass(kw_only=True) +class InfoStats(DataclassPPrintMixin): + """A data structure for storing information about the learning process.""" + + gradient_step: int + """The total gradient step.""" + best_reward: float + """The best reward over the test results.""" + best_reward_std: float + """Standard deviation of the best reward over the test results.""" + train_step: int + """The total collected step of training collector.""" + train_episode: int + """The total collected episode of training collector.""" + test_step: int + """The total collected step of test collector.""" + test_episode: int + """The total collected episode of test collector.""" + + timing: TimingStats + """The timing statistics.""" + + +@dataclass(kw_only=True) +class EpochStats(DataclassPPrintMixin): + """A data structure for storing epoch statistics.""" + + epoch: int + """The current epoch.""" + + 
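`SequenceSummaryStats.from_sequence` condenses a sequence of scalars (for example per-episode returns) into mean/std/max/min. A small usage sketch, assuming the module path shown in this diff (`tianshou.data.stats`)::

    from tianshou.data.stats import SequenceSummaryStats

    episode_returns = [10.0, 12.5, 9.0, 11.0]
    summary = SequenceSummaryStats.from_sequence(episode_returns)
    print(summary.mean, summary.std, summary.min, summary.max)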
train_collect_stat: "CollectStatsBase" + """The statistics of the last call to the training collector.""" + test_collect_stat: Optional["CollectStats"] + """The statistics of the last call to the test collector.""" + training_stat: Optional["TrainingStats"] + """The statistics of the last model update step. + Can be None if no model update is performed, typically in the last training iteration.""" + info_stat: InfoStats + """The information of the collector.""" diff --git a/examples/atari/tianshou/data/types.py b/examples/atari/tianshou/data/types.py new file mode 100644 index 0000000000000000000000000000000000000000..3572e548476a41f43ef35cf66ebcced24cd07c96 --- /dev/null +++ b/examples/atari/tianshou/data/types.py @@ -0,0 +1,128 @@ +from typing import Protocol + +import numpy as np +import torch + +from tianshou.data import Batch +from tianshou.data.batch import BatchProtocol, arr_type + +TNestedDictValue = np.ndarray | dict[str, "TNestedDictValue"] + + +d: dict[str, TNestedDictValue] = {"a": {"b": np.array([1, 2, 3])}} +d["c"] = np.array([1, 2, 3]) + + +class ObsBatchProtocol(BatchProtocol, Protocol): + """Observations of an environment that a policy can turn into actions. + + Typically used inside a policy's forward + """ + + obs: arr_type | BatchProtocol + info: arr_type + + +class RolloutBatchProtocol(ObsBatchProtocol, Protocol): + """Typically, the outcome of sampling from a replay buffer.""" + + obs_next: arr_type | BatchProtocol + act: arr_type + rew: np.ndarray + terminated: arr_type + truncated: arr_type + + +class BatchWithReturnsProtocol(RolloutBatchProtocol, Protocol): + """With added returns, usually computed with GAE.""" + + returns: arr_type + + +class PrioBatchProtocol(RolloutBatchProtocol, Protocol): + """Contains weights that can be used for prioritized replay.""" + + weight: np.ndarray | torch.Tensor + + +class RecurrentStateBatch(BatchProtocol, Protocol): + """Used by RNNs in policies, contains `hidden` and `cell` fields.""" + + hidden: torch.Tensor + cell: torch.Tensor + + +class ActBatchProtocol(BatchProtocol, Protocol): + """Simplest batch, just containing the action. Useful e.g., for random policy.""" + + act: arr_type + + +class ActStateBatchProtocol(ActBatchProtocol, Protocol): + """Contains action and state (which can be None), useful for policies that can support RNNs.""" + + state: dict | BatchProtocol | np.ndarray | None + + +class ModelOutputBatchProtocol(ActStateBatchProtocol, Protocol): + """In addition to state and action, contains model output: (logits).""" + + logits: torch.Tensor + + +class FQFBatchProtocol(ModelOutputBatchProtocol, Protocol): + """Model outputs, fractions and quantiles_tau - specific to the FQF model.""" + + fractions: torch.Tensor + quantiles_tau: torch.Tensor + + +class BatchWithAdvantagesProtocol(BatchWithReturnsProtocol, Protocol): + """Contains estimated advantages and values. + + Returns are usually computed from GAE of advantages by adding the value. + """ + + adv: torch.Tensor + v_s: torch.Tensor + + +class DistBatchProtocol(ModelOutputBatchProtocol, Protocol): + """Contains dist instances for actions (created by dist_fn). + + Usually categorical or normal. + """ + + dist: torch.distributions.Distribution + + +class DistLogProbBatchProtocol(DistBatchProtocol, Protocol): + """Contains dist objects that can be sampled from and log_prob of taken action.""" + + log_prob: torch.Tensor + + +class LogpOldProtocol(BatchWithAdvantagesProtocol, Protocol): + """Contains logp_old, often needed for importance weights, in particular in PPO. 
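These protocols are purely structural: nothing subclasses them at runtime. Instead, a `Batch` that carries the right fields is `cast` to the protocol so that static type checkers can verify field access, exactly as the collector code above does. An illustrative sketch with made-up array contents::

    from typing import cast

    import numpy as np

    from tianshou.data import Batch
    from tianshou.data.types import RolloutBatchProtocol

    batch = cast(
        RolloutBatchProtocol,
        Batch(
            obs=np.zeros((2, 4)),
            obs_next=np.zeros((2, 4)),
            act=np.array([0, 1]),
            rew=np.array([1.0, 0.0]),
            terminated=np.array([False, True]),
            truncated=np.array([False, False]),
            info=np.array([{"env_id": 0}, {"env_id": 1}]),
        ),
    )
    # mypy now knows about batch.rew, batch.act, batch.obs_next, ...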
+ + Builds on batches that contain advantages and values. + """ + + logp_old: torch.Tensor + + +class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol, Protocol): + """Contains taus for algorithms using quantile regression. + + See e.g. https://arxiv.org/abs/1806.06923 + """ + + taus: torch.Tensor + + +class ImitationBatchProtocol(ActBatchProtocol, Protocol): + """Similar to other batches, but contains imitation_logits and q_value fields.""" + + state: dict | Batch | np.ndarray | None + q_value: torch.Tensor + imitation_logits: torch.Tensor diff --git a/examples/atari/tianshou/data/utils/__init__.py b/examples/atari/tianshou/data/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/data/utils/converter.py b/examples/atari/tianshou/data/utils/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..8f07e049449a5740fc54ea896f118d1a916ac9d6 --- /dev/null +++ b/examples/atari/tianshou/data/utils/converter.py @@ -0,0 +1,165 @@ +import pickle +from copy import deepcopy +from numbers import Number +from typing import Any, Union, no_type_check + +import h5py +import numpy as np +import torch + +from tianshou.data.batch import Batch, _parse_value + + +# TODO: confusing name, could actually return a batch... +# Overrides and generic types should be added +# todo check for ActBatchProtocol +@no_type_check +def to_numpy(x: Any) -> Batch | np.ndarray: + """Return an object without torch.Tensor.""" + if isinstance(x, torch.Tensor): # most often case + return x.detach().cpu().numpy() + if isinstance(x, np.ndarray): # second often case + return x + if isinstance(x, np.number | np.bool_ | Number): + return np.asanyarray(x) + if x is None: + return np.array(None, dtype=object) + if isinstance(x, dict | Batch): + x = Batch(x) if isinstance(x, dict) else deepcopy(x) + x.to_numpy_() + return x + if isinstance(x, list | tuple): + return to_numpy(_parse_value(x)) + # fallback + return np.asanyarray(x) + + +@no_type_check +def to_torch( + x: Any, + dtype: torch.dtype | None = None, + device: str | int | torch.device = "cpu", +) -> Batch | torch.Tensor: + """Return an object without np.ndarray.""" + if isinstance(x, np.ndarray) and issubclass( + x.dtype.type, + np.bool_ | np.number, + ): # most often case + x = torch.from_numpy(x).to(device) + if dtype is not None: + x = x.type(dtype) + return x + if isinstance(x, torch.Tensor): # second often case + if dtype is not None: + x = x.type(dtype) + return x.to(device) + if isinstance(x, np.number | np.bool_ | Number): + return to_torch(np.asanyarray(x), dtype, device) + if isinstance(x, dict | Batch): + x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x) + x.to_torch_(dtype, device) + return x + if isinstance(x, list | tuple): + return to_torch(_parse_value(x), dtype, device) + # fallback + raise TypeError(f"object {x} cannot be converted to torch.") + + +@no_type_check +def to_torch_as(x: Any, y: torch.Tensor) -> Batch | torch.Tensor: + """Return an object without np.ndarray. + + Same as ``to_torch(x, dtype=y.dtype, device=y.device)``. 
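`to_torch` and `to_numpy` recurse through dicts and `Batch` objects and handle plain scalars, so a single call converts a whole nested structure; unsupported objects fall back to `np.asanyarray` (for `to_numpy`) or raise (for `to_torch`). A round-trip sketch; the module path below is the file added in this diff::

    import numpy as np
    import torch

    from tianshou.data.utils.converter import to_numpy, to_torch

    arr = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    t = to_torch(arr, dtype=torch.float32, device="cpu")  # ndarray -> Tensor
    back = to_numpy(t)                                    # Tensor -> detached CPU ndarray
    assert np.allclose(arr, back)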
+ """ + assert isinstance(y, torch.Tensor) + return to_torch(x, dtype=y.dtype, device=y.device) + + +# Note: object is used as a proxy for objects that can be pickled +# Note: mypy does not support cyclic definition currently +Hdf5ConvertibleValues = Union[ + int, + float, + Batch, + np.ndarray, + torch.Tensor, + object, + "Hdf5ConvertibleType", +] + +Hdf5ConvertibleType = dict[str, Hdf5ConvertibleValues] + + +def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group, compression: str | None = None) -> None: + """Copy object into HDF5 group.""" + + def to_hdf5_via_pickle( + x: object, + y: h5py.Group, + key: str, + compression: str | None = None, + ) -> None: + """Pickle, convert to numpy array and write to HDF5 dataset.""" + data = np.frombuffer(pickle.dumps(x), dtype=np.byte) + y.create_dataset(key, data=data, compression=compression) + + for k, v in x.items(): + if isinstance(v, Batch | dict): + # dicts and batches are both represented by groups + subgrp = y.create_group(k) + if isinstance(v, Batch): + subgrp_data = v.__getstate__() + subgrp.attrs["__data_type__"] = "Batch" + else: + subgrp_data = v + to_hdf5(subgrp_data, subgrp, compression=compression) + elif isinstance(v, torch.Tensor): + # PyTorch tensors are written to datasets + y.create_dataset(k, data=to_numpy(v), compression=compression) + y[k].attrs["__data_type__"] = "Tensor" + elif isinstance(v, np.ndarray): + try: + # NumPy arrays are written to datasets + y.create_dataset(k, data=v, compression=compression) + y[k].attrs["__data_type__"] = "ndarray" + except TypeError: + # If data type is not supported by HDF5 fall back to pickle. + # This happens if dtype=object (e.g. due to entries being None) + # and possibly in other cases like structured arrays. + try: + to_hdf5_via_pickle(v, y, k, compression=compression) + except Exception as exception: + raise RuntimeError( + f"Attempted to pickle {v.__class__.__name__} due to " + "data type not supported by HDF5 and failed.", + ) from exception + y[k].attrs["__data_type__"] = "pickled_ndarray" + elif isinstance(v, int | float): + # ints and floats are stored as attributes of groups + y.attrs[k] = v + else: # resort to pickle for any other type of object + try: + to_hdf5_via_pickle(v, y, k, compression=compression) + except Exception as exception: + raise NotImplementedError( + f"No conversion to HDF5 for object of type '{type(v)}' " + "implemented and fallback to pickle failed.", + ) from exception + y[k].attrs["__data_type__"] = v.__class__.__name__ + + +def from_hdf5(x: h5py.Group, device: str | None = None) -> Hdf5ConvertibleValues: + """Restore object from HDF5 group.""" + if isinstance(x, h5py.Dataset): + # handle datasets + if x.attrs["__data_type__"] == "ndarray": + return np.array(x) + if x.attrs["__data_type__"] == "Tensor": + return torch.tensor(x, device=device) + return pickle.loads(x[()]) + # handle groups representing a dict or a Batch + y = dict(x.attrs.items()) + data_type = y.pop("__data_type__", None) + for k, v in x.items(): + y[k] = from_hdf5(v, device) + return Batch(y) if data_type == "Batch" else y diff --git a/examples/atari/tianshou/data/utils/segtree.py b/examples/atari/tianshou/data/utils/segtree.py new file mode 100644 index 0000000000000000000000000000000000000000..3e4582c7bcdfcedef83bc8a30fc3adc78e7a7673 --- /dev/null +++ b/examples/atari/tianshou/data/utils/segtree.py @@ -0,0 +1,134 @@ +import numpy as np +from numba import njit + + +class SegmentTree: + """Implementation of Segment Tree. + + The segment tree stores an array ``arr`` with size ``n``. 
It supports value + update and fast query of the sum for the interval ``[left, right)`` in + O(log n) time. The detailed procedure is as follows: + + 1. Pad the array to have length of power of 2, so that leaf nodes in the \ + segment tree have the same depth. + 2. Store the segment tree in a binary heap. + + :param size: the size of segment tree. + """ + + def __init__(self, size: int) -> None: + bound = 1 + while bound < size: + bound *= 2 + self._size = size + self._bound = bound + self._value = np.zeros([bound * 2]) + self._compile() + + def __len__(self) -> int: + return self._size + + def __getitem__(self, index: int | np.ndarray) -> float | np.ndarray: + """Return self[index].""" + return self._value[index + self._bound] + + def __setitem__(self, index: int | np.ndarray, value: float | np.ndarray) -> None: + """Update values in segment tree. + + Duplicate values in ``index`` are handled by numpy: later index + overwrites previous ones. + :: + + >>> a = np.array([1, 2, 3, 4]) + >>> a[[0, 1, 0, 1]] = [4, 5, 6, 7] + >>> print(a) + [6 7 3 4] + """ + if isinstance(index, int): + index, value = np.array([index]), np.array([value]) + assert np.all(index >= 0) + assert np.all(index < self._size) + _setitem(self._value, index + self._bound, value) + + def reduce(self, start: int = 0, end: int | None = None) -> float: + """Return operation(value[start:end]).""" + if start == 0 and end is None: + return self._value[1] + if end is None: + end = self._size + if end < 0: + end += self._size + return _reduce(self._value, start + self._bound - 1, end + self._bound) + + def get_prefix_sum_idx(self, value: float | np.ndarray) -> int | np.ndarray: + r"""Find the index with given value. + + Return the minimum index for each ``v`` in ``value`` so that + :math:`v \le \mathrm{sums}_i`, where + :math:`\mathrm{sums}_i = \sum_{j = 0}^{i} \mathrm{arr}_j`. + + .. warning:: + + Please make sure all of the values inside the segment tree are + non-negative when using this function. + """ + assert np.all(value >= 0.0) + assert np.all(value < self._value[1]) + single = False + if not isinstance(value, np.ndarray): + value = np.array([value]) + single = True + index = _get_prefix_sum_idx(value, self._bound, self._value) + return index.item() if single else index + + def _compile(self) -> None: + f64 = np.array([0, 1], dtype=np.float64) + f32 = np.array([0, 1], dtype=np.float32) + i64 = np.array([0, 1], dtype=np.int64) + _setitem(f64, i64, f64) + _setitem(f64, i64, f32) + _reduce(f64, 0, 1) + _get_prefix_sum_idx(f64, 1, f64) + _get_prefix_sum_idx(f32, 1, f64) + + +@njit +def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None: + """Numba version, 4x faster: 0.1 -> 0.024.""" + tree[index] = value + while index[0] > 1: + index //= 2 + tree[index] = tree[index * 2] + tree[index * 2 + 1] + + +@njit +def _reduce(tree: np.ndarray, start: int, end: int) -> float: + """Numba version, 2x faster: 0.009 -> 0.005.""" + # nodes in (start, end) should be aggregated + result = 0.0 + while end - start > 1: # (start, end) interval is not empty + if start % 2 == 0: + result += tree[start + 1] + start //= 2 + if end % 2 == 1: + result += tree[end - 1] + end //= 2 + return result + + +@njit +def _get_prefix_sum_idx(value: np.ndarray, bound: int, sums: np.ndarray) -> np.ndarray: + """Numba version (v0.51), 5x speed up with size=100000 and bsz=64. 
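In practice the tree is used as a sum tree: leaf ``i`` stores a non-negative priority, ``reduce(l, r)`` returns the sum over ``[l, r)``, and ``get_prefix_sum_idx`` inverts prefix sums, which is how prioritized replay samples indices proportionally to priority. A small sketch with made-up priorities::

    import numpy as np

    from tianshou.data.utils.segtree import SegmentTree

    tree = SegmentTree(size=8)
    tree[np.arange(8)] = np.array([1.0, 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 1.0])

    print(tree.reduce())        # 14.0 -- sum over all leaves
    print(tree.reduce(2, 5))    # 3 + 4 + 1 = 8.0

    # sample leaf indices proportionally to their stored values
    scalars = np.random.rand(4) * tree.reduce()
    print(tree.get_prefix_sum_idx(scalars))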
+ + vectorized np: 0.0923 (numpy best) -> 0.024 (now) + for-loop: 0.2914 -> 0.019 (but not so stable) + """ + index = np.ones(value.shape, dtype=np.int64) + while index[0] < bound: + index *= 2 + lsons = sums[index] + direct = lsons < value + value -= lsons * direct + index += direct + index -= bound + return index diff --git a/examples/atari/tianshou/env/__init__.py b/examples/atari/tianshou/env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..049ccf439755830975e94d45ecbbfba95b1b0295 --- /dev/null +++ b/examples/atari/tianshou/env/__init__.py @@ -0,0 +1,30 @@ +"""Env package.""" + +from tianshou.env.gym_wrappers import ( + ContinuousToDiscrete, + MultiDiscreteToDiscrete, + TruncatedAsTerminated, +) +from tianshou.env.pettingzoo_env import PettingZooEnv +from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper +from tianshou.env.venvs import ( + BaseVectorEnv, + DummyVectorEnv, + RayVectorEnv, + ShmemVectorEnv, + SubprocVectorEnv, +) + +__all__ = [ + "BaseVectorEnv", + "DummyVectorEnv", + "SubprocVectorEnv", + "ShmemVectorEnv", + "RayVectorEnv", + "VectorEnvWrapper", + "VectorEnvNormObs", + "PettingZooEnv", + "ContinuousToDiscrete", + "MultiDiscreteToDiscrete", + "TruncatedAsTerminated", +] diff --git a/examples/atari/tianshou/env/gym_wrappers.py b/examples/atari/tianshou/env/gym_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..1db4a286c9e0d218aa8e74730580284fdcbaa9eb --- /dev/null +++ b/examples/atari/tianshou/env/gym_wrappers.py @@ -0,0 +1,80 @@ +from typing import Any, SupportsFloat + +import gymnasium as gym +import numpy as np +from packaging import version + + +class ContinuousToDiscrete(gym.ActionWrapper): + """Gym environment wrapper to take discrete action in a continuous environment. + + :param gym.Env env: gym environment with continuous action space. + :param action_per_dim: number of discrete actions in each dimension + of the action space. + """ + + def __init__(self, env: gym.Env, action_per_dim: int | list[int]) -> None: + super().__init__(env) + assert isinstance(env.action_space, gym.spaces.Box) + low, high = env.action_space.low, env.action_space.high + if isinstance(action_per_dim, int): + action_per_dim = [action_per_dim] * env.action_space.shape[0] + assert len(action_per_dim) == env.action_space.shape[0] + self.action_space = gym.spaces.MultiDiscrete(action_per_dim) + self.mesh = np.array( + [np.linspace(lo, hi, a) for lo, hi, a in zip(low, high, action_per_dim, strict=True)], + dtype=object, + ) + + def action(self, act: np.ndarray) -> np.ndarray: # type: ignore + # modify act + assert len(act.shape) <= 2, f"Unknown action format with shape {act.shape}." + if len(act.shape) == 1: + return np.array([self.mesh[i][a] for i, a in enumerate(act)]) + return np.array([[self.mesh[i][a] for i, a in enumerate(a_)] for a_ in act]) + + +class MultiDiscreteToDiscrete(gym.ActionWrapper): + """Gym environment wrapper to take discrete action in multidiscrete environment. + + :param gym.Env env: gym environment with multidiscrete action space. 
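`ContinuousToDiscrete` lays an evenly spaced grid (`np.linspace`) over each `Box` dimension, so a discrete index per dimension is mapped back onto a continuous action. A sketch using Pendulum-v1 (11 bins over the torque range [-2, 2] is an arbitrary choice)::

    import gymnasium as gym
    import numpy as np

    from tianshou.env import ContinuousToDiscrete

    env = ContinuousToDiscrete(gym.make("Pendulum-v1"), action_per_dim=11)
    print(env.action_space)                    # MultiDiscrete([11])
    obs, info = env.reset(seed=0)
    # index 5 of linspace(-2, 2, 11) maps to a torque of 0.0
    obs, rew, terminated, truncated, info = env.step(np.array([5]))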
+ """ + + def __init__(self, env: gym.Env) -> None: + super().__init__(env) + assert isinstance(env.action_space, gym.spaces.MultiDiscrete) + nvec = env.action_space.nvec + assert nvec.ndim == 1 + self.bases = np.ones_like(nvec) + for i in range(1, len(self.bases)): + self.bases[i] = self.bases[i - 1] * nvec[-i] + self.action_space = gym.spaces.Discrete(np.prod(nvec)) + + def action(self, act: np.ndarray) -> np.ndarray: # type: ignore + converted_act = [] + for b in np.flip(self.bases): + converted_act.append(act // b) + act = act % b + return np.array(converted_act).transpose() + + +class TruncatedAsTerminated(gym.Wrapper): + """A wrapper that set ``terminated = terminated or truncated`` for ``step()``. + + It's intended to use with ``gym.wrappers.TimeLimit``. + + :param gym.Env env: gym environment. + """ + + def __init__(self, env: gym.Env): + super().__init__(env) + if not version.parse(gym.__version__) >= version.parse("0.26.0"): + raise OSError( + f"TruncatedAsTerminated is not applicable with gym version \ + {gym.__version__}", + ) + + def step(self, act: np.ndarray) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]: + observation, reward, terminated, truncated, info = super().step(act) + terminated = terminated or truncated + return observation, reward, terminated, truncated, info diff --git a/examples/atari/tianshou/env/pettingzoo_env.py b/examples/atari/tianshou/env/pettingzoo_env.py new file mode 100644 index 0000000000000000000000000000000000000000..56b052c7b438a31d2f242350df63ea71a791364c --- /dev/null +++ b/examples/atari/tianshou/env/pettingzoo_env.py @@ -0,0 +1,132 @@ +import warnings +from abc import ABC +from typing import Any + +import pettingzoo +from gymnasium import spaces +from packaging import version +from pettingzoo.utils.env import AECEnv +from pettingzoo.utils.wrappers import BaseWrapper + +if version.parse(pettingzoo.__version__) < version.parse("1.21.0"): + warnings.warn( + f"You are using PettingZoo {pettingzoo.__version__}. " + f"Future tianshou versions may not support PettingZoo<1.21.0. " + f"Consider upgrading your PettingZoo version.", + DeprecationWarning, + ) + + +class PettingZooEnv(AECEnv, ABC): + """The interface for petting zoo environments. + + Multi-agent environments must be wrapped as + :class:`~tianshou.env.PettingZooEnv`. Here is the usage: + :: + + env = PettingZooEnv(...) + # obs is a dict containing obs, agent_id, and mask + obs = env.reset() + action = policy(obs) + obs, rew, trunc, term, info = env.step(action) + env.close() + + The available action's mask is set to True, otherwise it is set to False. + Further usage can be found at :ref:`marl_example`. + """ + + def __init__(self, env: BaseWrapper): + super().__init__() + self.env = env + # agent idx list + self.agents = self.env.possible_agents + self.agent_idx = {} + for i, agent_id in enumerate(self.agents): + self.agent_idx[agent_id] = i + + self.rewards = [0] * len(self.agents) + + # Get first observation space, assuming all agents have equal space + self.observation_space: Any = self.env.observation_space(self.agents[0]) + + # Get first action space, assuming all agents have equal space + self.action_space: Any = self.env.action_space(self.agents[0]) + + assert all( + self.env.observation_space(agent) == self.observation_space for agent in self.agents + ), ( + "Observation spaces for all agents must be identical. 
Perhaps " + "SuperSuit's pad_observations wrapper can help (useage: " + "`supersuit.aec_wrappers.pad_observations(env)`" + ) + + assert all(self.env.action_space(agent) == self.action_space for agent in self.agents), ( + "Action spaces for all agents must be identical. Perhaps " + "SuperSuit's pad_action_space wrapper can help (useage: " + "`supersuit.aec_wrappers.pad_action_space(env)`" + ) + + self.reset() + + def reset(self, *args: Any, **kwargs: Any) -> tuple[dict, dict]: + self.env.reset(*args, **kwargs) + + observation, reward, terminated, truncated, info = self.env.last(self) + + if isinstance(observation, dict) and "action_mask" in observation: + observation_dict = { + "agent_id": self.env.agent_selection, + "obs": observation["observation"], + "mask": [obm == 1 for obm in observation["action_mask"]], + } + else: + if isinstance(self.action_space, spaces.Discrete): + observation_dict = { + "agent_id": self.env.agent_selection, + "obs": observation, + "mask": [True] * self.env.action_space(self.env.agent_selection).n, + } + else: + observation_dict = { + "agent_id": self.env.agent_selection, + "obs": observation, + } + + return observation_dict, info + + def step(self, action: Any) -> tuple[dict, list[int], bool, bool, dict]: + self.env.step(action) + + observation, rew, term, trunc, info = self.env.last() + + if isinstance(observation, dict) and "action_mask" in observation: + obs = { + "agent_id": self.env.agent_selection, + "obs": observation["observation"], + "mask": [obm == 1 for obm in observation["action_mask"]], + } + else: + if isinstance(self.action_space, spaces.Discrete): + obs = { + "agent_id": self.env.agent_selection, + "obs": observation, + "mask": [True] * self.env.action_space(self.env.agent_selection).n, + } + else: + obs = {"agent_id": self.env.agent_selection, "obs": observation} + + for agent_id, reward in self.env.rewards.items(): + self.rewards[self.agent_idx[agent_id]] = reward + return obs, self.rewards, term, trunc, info + + def close(self) -> None: + self.env.close() + + def seed(self, seed: Any = None) -> None: + try: + self.env.seed(seed) + except (NotImplementedError, AttributeError): + self.env.reset(seed=seed) + + def render(self) -> Any: + return self.env.render() diff --git a/examples/atari/tianshou/env/utils.py b/examples/atari/tianshou/env/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b5be12f438c4d33f289555c76fa8983f70ebea52 --- /dev/null +++ b/examples/atari/tianshou/env/utils.py @@ -0,0 +1,24 @@ +from typing import Any + +import cloudpickle +import gymnasium +import numpy as np + +from tianshou.env.pettingzoo_env import PettingZooEnv + +ENV_TYPE = gymnasium.Env | PettingZooEnv + +gym_new_venv_step_type = tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray] + + +class CloudpickleWrapper: + """A cloudpickle wrapper used in SubprocVectorEnv.""" + + def __init__(self, data: Any) -> None: + self.data = data + + def __getstate__(self) -> str: + return cloudpickle.dumps(self.data) + + def __setstate__(self, data: str) -> None: + self.data = cloudpickle.loads(data) diff --git a/examples/atari/tianshou/env/venv_wrappers.py b/examples/atari/tianshou/env/venv_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..9297ddebbaaa67e8cbe7f63a94b0b36a85ec3e2f --- /dev/null +++ b/examples/atari/tianshou/env/venv_wrappers.py @@ -0,0 +1,120 @@ +from typing import Any + +import numpy as np +import torch + +from tianshou.env.utils import gym_new_venv_step_type +from tianshou.env.venvs import 
GYM_RESERVED_KEYS, BaseVectorEnv +from tianshou.utils import RunningMeanStd + + +class VectorEnvWrapper(BaseVectorEnv): + """Base class for vectorized environments wrapper.""" + + # Note: No super call because this is a wrapper with overridden __getattribute__ + # It's not a "true" subclass of BaseVectorEnv but it does extend its interface, so + # it can be used as a drop-in replacement + # noinspection PyMissingConstructor + def __init__(self, venv: BaseVectorEnv) -> None: + self.venv = venv + self.is_async = venv.is_async + + def __len__(self) -> int: + return len(self.venv) + + def __getattribute__(self, key: str) -> Any: + if key in GYM_RESERVED_KEYS: # reserved keys in gym.Env + return getattr(self.venv, key) + return super().__getattribute__(key) + + def get_env_attr( + self, + key: str, + id: int | list[int] | np.ndarray | None = None, + ) -> list[Any]: + return self.venv.get_env_attr(key, id) + + def set_env_attr( + self, + key: str, + value: Any, + id: int | list[int] | np.ndarray | None = None, + ) -> None: + return self.venv.set_env_attr(key, value, id) + + def reset( + self, + env_id: int | list[int] | np.ndarray | None = None, + **kwargs: Any, + ) -> tuple[np.ndarray, np.ndarray]: + return self.venv.reset(env_id, **kwargs) + + def step( + self, + action: np.ndarray | torch.Tensor | None, + id: int | list[int] | np.ndarray | None = None, + ) -> gym_new_venv_step_type: + return self.venv.step(action, id) + + def seed(self, seed: int | list[int] | None = None) -> list[list[int] | None]: + return self.venv.seed(seed) + + def render(self, **kwargs: Any) -> list[Any]: + return self.venv.render(**kwargs) + + def close(self) -> None: + self.venv.close() + + +class VectorEnvNormObs(VectorEnvWrapper): + """An observation normalization wrapper for vectorized environments. + + :param update_obs_rms: whether to update obs_rms. Default to True. + """ + + def __init__(self, venv: BaseVectorEnv, update_obs_rms: bool = True) -> None: + super().__init__(venv) + # initialize observation running mean/std + self.update_obs_rms = update_obs_rms + self.obs_rms = RunningMeanStd() + + def reset( + self, + env_id: int | list[int] | np.ndarray | None = None, + **kwargs: Any, + ) -> tuple[np.ndarray, np.ndarray]: + obs, info = self.venv.reset(env_id, **kwargs) + + if isinstance(obs, tuple): # type: ignore + raise TypeError( + "Tuple observation space is not supported. 
", + "Please change it to array or dict space", + ) + + if self.obs_rms and self.update_obs_rms: + self.obs_rms.update(obs) + obs = self._norm_obs(obs) + return obs, info + + def step( + self, + action: np.ndarray | torch.Tensor | None, + id: int | list[int] | np.ndarray | None = None, + ) -> gym_new_venv_step_type: + step_results = self.venv.step(action, id) + if self.obs_rms and self.update_obs_rms: + self.obs_rms.update(step_results[0]) + return (self._norm_obs(step_results[0]), *step_results[1:]) + + def _norm_obs(self, obs: np.ndarray) -> np.ndarray: + if self.obs_rms: + return self.obs_rms.norm(obs) # type: ignore + return obs + + def set_obs_rms(self, obs_rms: RunningMeanStd) -> None: + """Set with given observation running mean/std.""" + self.obs_rms = obs_rms + + def get_obs_rms(self) -> RunningMeanStd: + """Return observation running mean/std.""" + return self.obs_rms diff --git a/examples/atari/tianshou/env/venvs.py b/examples/atari/tianshou/env/venvs.py new file mode 100644 index 0000000000000000000000000000000000000000..e9309f9ec79a0f0224a648a4920ddbefad6a7e8f --- /dev/null +++ b/examples/atari/tianshou/env/venvs.py @@ -0,0 +1,473 @@ +from collections.abc import Callable, Sequence +from typing import Any, Literal + +import gymnasium as gym +import numpy as np +import torch + +from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type +from tianshou.env.worker import ( + DummyEnvWorker, + EnvWorker, + RayEnvWorker, + SubprocEnvWorker, +) + +GYM_RESERVED_KEYS = [ + "metadata", + "reward_range", + "spec", + "action_space", + "observation_space", +] + + +class BaseVectorEnv: + """Base class for vectorized environments. + + Usage: + :: + + env_num = 8 + envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)]) + assert len(envs) == env_num + + It accepts a list of environment generators. In other words, an environment + generator ``efn`` of a specific task means that ``efn()`` returns the + environment of the given task, for example, ``gym.make(task)``. + + All of the VectorEnv must inherit :class:`~tianshou.env.BaseVectorEnv`. + Here are some other usages: + :: + + envs.seed(2) # which is equal to the next line + envs.seed([2, 3, 4, 5, 6, 7, 8, 9]) # set specific seed for each env + obs = envs.reset() # reset all environments + obs = envs.reset([0, 5, 7]) # reset 3 specific environments + obs, rew, done, info = envs.step([1] * 8) # step synchronously + envs.render() # render all environments + envs.close() # close all environments + + .. warning:: + + If you use your own environment, please make sure the ``seed`` method + is set up properly, e.g., + :: + + def seed(self, seed): + np.random.seed(seed) + + Otherwise, the outputs of these envs may be the same with each other. + + :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the i-th env. + :param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a + worker which contains the i-th env. + :param wait_num: use in asynchronous simulation if the time cost of + ``env.step`` varies with time and synchronously waiting for all + environments to finish a step is time-wasting. In that case, we can + return when ``wait_num`` environments finish a step and keep on + simulation in these environments. If ``None``, asynchronous simulation + is disabled; else, ``1 <= wait_num <= env_num``. + :param timeout: use in asynchronous simulation same as above, in each + vectorized step it only deal with those environments spending time + within ``timeout`` seconds. 
+ """ + + def __init__( + self, + env_fns: Sequence[Callable[[], ENV_TYPE]], + worker_fn: Callable[[Callable[[], ENV_TYPE]], EnvWorker], + wait_num: int | None = None, + timeout: float | None = None, + ) -> None: + self._env_fns = env_fns + # A VectorEnv contains a pool of EnvWorkers, which corresponds to + # interact with the given envs (one worker <-> one env). + self.workers = [worker_fn(fn) for fn in env_fns] + self.worker_class = type(self.workers[0]) + assert issubclass(self.worker_class, EnvWorker) + assert all(isinstance(w, self.worker_class) for w in self.workers) + + self.env_num = len(env_fns) + self.wait_num = wait_num or len(env_fns) + assert ( + 1 <= self.wait_num <= len(env_fns) + ), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}" + self.timeout = timeout + assert ( + self.timeout is None or self.timeout > 0 + ), f"timeout is {timeout}, it should be positive if provided!" + self.is_async = self.wait_num != len(env_fns) or timeout is not None + self.waiting_conn: list[EnvWorker] = [] + # environments in self.ready_id is actually ready + # but environments in self.waiting_id are just waiting when checked, + # and they may be ready now, but this is not known until we check it + # in the step() function + self.waiting_id: list[int] = [] + # all environments are ready in the beginning + self.ready_id = list(range(self.env_num)) + self.is_closed = False + + def _assert_is_not_closed(self) -> None: + assert ( + not self.is_closed + ), f"Methods of {self.__class__.__name__} cannot be called after close." + + def __len__(self) -> int: + """Return len(self), which is the number of environments.""" + return self.env_num + + def __getattribute__(self, key: str) -> Any: + """Switch the attribute getter depending on the key. + + Any class who inherits ``gym.Env`` will inherit some attributes, like + ``action_space``. However, we would like the attribute lookup to go straight + into the worker (in fact, this vector env's action_space is always None). + """ + if key in GYM_RESERVED_KEYS: # reserved keys in gym.Env + return self.get_env_attr(key) + return super().__getattribute__(key) + + def get_env_attr( + self, + key: str, + id: int | list[int] | np.ndarray | None = None, + ) -> list[Any]: + """Get an attribute from the underlying environments. + + If id is an int, retrieve the attribute denoted by key from the environment + underlying the worker at index id. The result is returned as a list with one + element. Otherwise, retrieve the attribute for all workers at indices id and + return a list that is ordered correspondingly to id. + + :param str key: The key of the desired attribute. + :param id: Indice(s) of the desired worker(s). Default to None for all env_id. + + :return list: The list of environment attributes. + """ + self._assert_is_not_closed() + id = self._wrap_id(id) + if self.is_async: + self._assert_id(id) + + return [self.workers[j].get_env_attr(key) for j in id] + + def set_env_attr( + self, + key: str, + value: Any, + id: int | list[int] | np.ndarray | None = None, + ) -> None: + """Set an attribute in the underlying environments. + + If id is an int, set the attribute denoted by key from the environment + underlying the worker at index id to value. + Otherwise, set the attribute for all workers at indices id. + + :param str key: The key of the desired attribute. + :param Any value: The new value of the attribute. + :param id: Indice(s) of the desired worker(s). Default to None for all env_id. 
+ """ + self._assert_is_not_closed() + id = self._wrap_id(id) + if self.is_async: + self._assert_id(id) + for j in id: + self.workers[j].set_env_attr(key, value) + + def _wrap_id( + self, + id: int | list[int] | np.ndarray | None = None, + ) -> list[int] | np.ndarray: + if id is None: + return list(range(self.env_num)) + return [id] if np.isscalar(id) else id # type: ignore + + def _assert_id(self, id: list[int] | np.ndarray) -> None: + for i in id: + assert ( + i not in self.waiting_id + ), f"Cannot interact with environment {i} which is stepping now." + assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}." + + # TODO: for now, has to be kept in sync with reset in EnvPoolMixin + # In particular, can't rename env_id to env_ids + def reset( + self, + env_id: int | list[int] | np.ndarray | None = None, + **kwargs: Any, + ) -> tuple[np.ndarray, np.ndarray]: + """Reset the state of some envs and return initial observations. + + If id is None, reset the state of all the environments and return + initial observations, otherwise reset the specific environments with + the given id, either an int or a list. + """ + self._assert_is_not_closed() + env_id = self._wrap_id(env_id) + if self.is_async: + self._assert_id(env_id) + + # send(None) == reset() in worker + for id in env_id: + self.workers[id].send(None, **kwargs) + ret_list = [self.workers[id].recv() for id in env_id] + + assert ( + isinstance(ret_list[0], tuple | list) + and len(ret_list[0]) == 2 + and isinstance(ret_list[0][1], dict) + ), "The environment does not adhere to the Gymnasium's API." + + obs_list = [r[0] for r in ret_list] + + if isinstance(obs_list[0], tuple): # type: ignore + raise TypeError( + "Tuple observation space is not supported. ", + "Please change it to array or dict space", + ) + try: + obs = np.stack(obs_list) + except ValueError: # different len(obs) + obs = np.array(obs_list, dtype=object) + + infos = np.array([r[1] for r in ret_list]) + return obs, infos + + def step( + self, + action: np.ndarray | torch.Tensor | None, + id: int | list[int] | np.ndarray | None = None, + ) -> gym_new_venv_step_type: + """Run one timestep of some environments' dynamics. + + If id is None, run one timestep of all the environments` dynamics; + otherwise run one timestep for some environments with given id, either + an int or a list. When the end of episode is reached, you are + responsible for calling reset(id) to reset this environment`s state. + + Accept a batch of action and return a tuple (batch_obs, batch_rew, + batch_done, batch_info) in numpy format. + + :param numpy.ndarray action: a batch of action provided by the agent. + If the venv is async, the action can be None, which will result + in all arrays in the returned tuple being empty. + + :return: A tuple consisting of either: + + * ``obs`` a numpy.ndarray, the agent's observation of current environments + * ``rew`` a numpy.ndarray, the amount of rewards returned after \ + previous actions + * ``terminated`` a numpy.ndarray, whether these episodes have been \ + terminated + * ``truncated`` a numpy.ndarray, whether these episodes have been truncated + * ``info`` a numpy.ndarray, contains auxiliary diagnostic \ + information (helpful for debugging, and sometimes learning) + + For the async simulation: + + Provide the given action to the environments. 
The action sequence + should correspond to the ``id`` argument, and the ``id`` argument + should be a subset of the ``env_id`` in the last returned ``info`` + (initially they are env_ids of all the environments). If action is + None, fetch unfinished step() calls instead. + """ + self._assert_is_not_closed() + id = self._wrap_id(id) + if not self.is_async: + if action is None: + raise ValueError("action must be not-None for non-async") + assert len(action) == len(id) + for i, j in enumerate(id): + self.workers[j].send(action[i]) + result = [] + for j in id: + env_return = self.workers[j].recv() + env_return[-1]["env_id"] = j + result.append(env_return) + else: + if action is not None: + self._assert_id(id) + assert len(action) == len(id) + for act, env_id in zip(action, id, strict=True): + self.workers[env_id].send(act) + self.waiting_conn.append(self.workers[env_id]) + self.waiting_id.append(env_id) + self.ready_id = [x for x in self.ready_id if x not in id] + ready_conns: list[EnvWorker] = [] + while not ready_conns: + ready_conns = self.worker_class.wait(self.waiting_conn, self.wait_num, self.timeout) + result = [] + for conn in ready_conns: + waiting_index = self.waiting_conn.index(conn) + self.waiting_conn.pop(waiting_index) + env_id = self.waiting_id.pop(waiting_index) + # env_return can be (obs, reward, done, info) or + # (obs, reward, terminated, truncated, info) + env_return = conn.recv() + env_return[-1]["env_id"] = env_id # Add `env_id` to info + result.append(env_return) + self.ready_id.append(env_id) + obs_list, rew_list, term_list, trunc_list, info_list = tuple(zip(*result, strict=True)) + try: + obs_stack = np.stack(obs_list) + except ValueError: # different len(obs) + obs_stack = np.array(obs_list, dtype=object) + return ( + obs_stack, + np.stack(rew_list), + np.stack(term_list), + np.stack(trunc_list), + np.stack(info_list), + ) + + def seed(self, seed: int | list[int] | None = None) -> list[list[int] | None]: + """Set the seed for all environments. + + Accept ``None``, an int (which will extend ``i`` to + ``[i, i + 1, i + 2, ...]``) or a list. + + :return: The list of seeds used in this env's random number generators. + The first value in the list should be the "main" seed, or the value + which a reproducer pass to "seed". + """ + self._assert_is_not_closed() + seed_list: list[None] | list[int] + if seed is None: + seed_list = [seed] * self.env_num + elif isinstance(seed, int): + seed_list = [seed + i for i in range(self.env_num)] + else: + seed_list = seed + return [w.seed(s) for w, s in zip(self.workers, seed_list, strict=True)] + + def render(self, **kwargs: Any) -> list[Any]: + """Render all of the environments.""" + self._assert_is_not_closed() + if self.is_async and len(self.waiting_id) > 0: + raise RuntimeError( + f"Environments {self.waiting_id} are still stepping, cannot render them now.", + ) + return [w.render(**kwargs) for w in self.workers] + + def close(self) -> None: + """Close all of the environments. + + This function will be called only once (if not, it will be called during + garbage collected). This way, ``close`` of all workers can be assured. + """ + self._assert_is_not_closed() + for w in self.workers: + w.close() + self.is_closed = True + + +class DummyVectorEnv(BaseVectorEnv): + """Dummy vectorized environment wrapper, implemented in for-loop. + + This has the same interface as true vectorized environment, but the rollout does not happen in parallel. 
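In async mode the contract is: pass actions only for envs known to be ready, and read which envs actually finished from `info["env_id"]` of the previous step, since only those may be stepped next. A hedged sketch (4 CartPole envs and `wait_num=2` are arbitrary choices)::

    import gymnasium as gym
    import numpy as np

    from tianshou.env import SubprocVectorEnv

    envs = SubprocVectorEnv(
        [lambda: gym.make("CartPole-v1") for _ in range(4)],
        wait_num=2,              # return as soon as 2 envs finished their step
    )
    obs, info = envs.reset()
    obs, rew, term, trunc, info = envs.step(np.zeros(4, dtype=np.int64), id=[0, 1, 2, 3])
    ready_ids = np.array([i["env_id"] for i in info])   # envs that actually stepped
    # only the ready envs may receive the next batch of actions
    obs, rew, term, trunc, info = envs.step(
        np.zeros(len(ready_ids), dtype=np.int64), id=ready_ids
    )
    envs.close()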
+ So, all workers just wait for each other and the environment is as efficient as using a single environment. + This can be useful for testing or for demonstration purposes. + + A rare use-case would be using vector based interface, but parallelization is not desired + (e.g. because of too much overhead). However, in such cases one should consider using a single environment. + + .. seealso:: + + Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. + """ + + def __init__( + self, + env_fns: Sequence[Callable[[], ENV_TYPE]], + wait_num: int | None = None, + timeout: float | None = None, + ) -> None: + super().__init__(env_fns, DummyEnvWorker, wait_num, timeout) + + +class SubprocVectorEnv(BaseVectorEnv): + """Vectorized environment wrapper based on subprocess. + + .. seealso:: + + Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. + + Additional arguments are: + + :param share_memory: whether to share memory between the main process and the worker process. Allows for + shared buffers to exchange observations + :param context: the context to use for multiprocessing. Usually it's fine to use the default context, but + `spawn` as well as `fork` can have non-obvious side effects, see for example + https://github.com/google-deepmind/mujoco/issues/742, or + https://github.com/Farama-Foundation/Gymnasium/issues/222. + Consider using 'fork' when using macOS and additional parallelization, for example via joblib. + Defaults to None, which will use the default system context. + """ + + def __init__( + self, + env_fns: Sequence[Callable[[], ENV_TYPE]], + wait_num: int | None = None, + timeout: float | None = None, + share_memory: bool = False, + context: Literal["fork", "spawn"] | None = None, + ) -> None: + def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: + return SubprocEnvWorker(fn, share_memory=share_memory, context=context) + + super().__init__( + env_fns, + worker_fn, + wait_num, + timeout, + ) + + +class ShmemVectorEnv(BaseVectorEnv): + """Optimized SubprocVectorEnv with shared buffers to exchange observations. + + ShmemVectorEnv has exactly the same API as SubprocVectorEnv. + + .. seealso:: + + Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. + """ + + def __init__( + self, + env_fns: Sequence[Callable[[], ENV_TYPE]], + wait_num: int | None = None, + timeout: float | None = None, + ) -> None: + def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: + return SubprocEnvWorker(fn, share_memory=True) + + super().__init__(env_fns, worker_fn, wait_num, timeout) + + +class RayVectorEnv(BaseVectorEnv): + """Vectorized environment wrapper based on ray. + + This is a choice to run distributed environments in a cluster. + + .. seealso:: + + Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. 
+ """ + + def __init__( + self, + env_fns: Sequence[Callable[[], ENV_TYPE]], + wait_num: int | None = None, + timeout: float | None = None, + ) -> None: + try: + import ray + except ImportError as exception: + raise ImportError( + "Please install ray to support RayVectorEnv: pip install ray", + ) from exception + if not ray.is_initialized(): + ray.init() + super().__init__(env_fns, lambda env_fn: RayEnvWorker(env_fn), wait_num, timeout) diff --git a/examples/atari/tianshou/env/worker/__init__.py b/examples/atari/tianshou/env/worker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b1f37510fb22fa55f9593315226f0acb9753a4c --- /dev/null +++ b/examples/atari/tianshou/env/worker/__init__.py @@ -0,0 +1,11 @@ +from tianshou.env.worker.base import EnvWorker +from tianshou.env.worker.dummy import DummyEnvWorker +from tianshou.env.worker.ray import RayEnvWorker +from tianshou.env.worker.subproc import SubprocEnvWorker + +__all__ = [ + "EnvWorker", + "DummyEnvWorker", + "SubprocEnvWorker", + "RayEnvWorker", +] diff --git a/examples/atari/tianshou/env/worker/base.py b/examples/atari/tianshou/env/worker/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ac35ccf3f4bf07c12dafb8f7e27390a14305dacd --- /dev/null +++ b/examples/atari/tianshou/env/worker/base.py @@ -0,0 +1,86 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any + +import gymnasium as gym +import numpy as np + +from tianshou.env.utils import gym_new_venv_step_type + + +class EnvWorker(ABC): + """An abstract worker for an environment.""" + + def __init__(self, env_fn: Callable[[], gym.Env]) -> None: + self._env_fn = env_fn + self.is_closed = False + self.result: gym_new_venv_step_type | tuple[np.ndarray, dict] + self.action_space = self.get_env_attr("action_space") + self.is_reset = False + + @abstractmethod + def get_env_attr(self, key: str) -> Any: + pass + + @abstractmethod + def set_env_attr(self, key: str, value: Any) -> None: + pass + + @abstractmethod + def send(self, action: np.ndarray | None) -> None: + """Send action signal to low-level worker. + + When action is None, it indicates sending "reset" signal; otherwise + it indicates "step" signal. The paired return value from "recv" + function is determined by such kind of different signal. + """ + + def recv(self) -> gym_new_venv_step_type | tuple[np.ndarray, dict]: + """Receive result from low-level worker. + + If the last "send" function sends a NULL action, it only returns a + single observation; otherwise it returns a tuple of (obs, rew, done, + info) or (obs, rew, terminated, truncated, info), based on whether + the environment is using the old step API or the new one. + """ + return self.result + + @abstractmethod + def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: + pass + + def step(self, action: np.ndarray) -> gym_new_venv_step_type: + """Perform one timestep of the environment's dynamic. + + "send" and "recv" are coupled in sync simulation, so users only call + "step" function. But they can be called separately in async + simulation, i.e. someone calls "send" first, and calls "recv" later. 
+ """ + self.send(action) + return self.recv() # type: ignore + + @staticmethod + def wait( + workers: list["EnvWorker"], + wait_num: int, + timeout: float | None = None, + ) -> list["EnvWorker"]: + """Given a list of workers, return those ready ones.""" + raise NotImplementedError + + def seed(self, seed: int | None = None) -> list[int] | None: + return self.action_space.seed(seed) # issue 299 + + @abstractmethod + def render(self, **kwargs: Any) -> Any: + """Render the environment.""" + + @abstractmethod + def close_env(self) -> None: + pass + + def close(self) -> None: + if self.is_closed: + return + self.is_closed = True + self.close_env() diff --git a/examples/atari/tianshou/env/worker/dummy.py b/examples/atari/tianshou/env/worker/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed0b03193701776070abe914fc6ab90ff104962 --- /dev/null +++ b/examples/atari/tianshou/env/worker/dummy.py @@ -0,0 +1,55 @@ +from collections.abc import Callable +from typing import Any + +import gymnasium as gym +import numpy as np + +from tianshou.env.worker import EnvWorker + + +class DummyEnvWorker(EnvWorker): + """Dummy worker used in sequential vector environments.""" + + def __init__(self, env_fn: Callable[[], gym.Env]) -> None: + self.env = env_fn() + super().__init__(env_fn) + + def get_env_attr(self, key: str) -> Any: + return getattr(self.env, key) + + def set_env_attr(self, key: str, value: Any) -> None: + setattr(self.env.unwrapped, key, value) + + def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: + if "seed" in kwargs: + super().seed(kwargs["seed"]) + return self.env.reset(**kwargs) + + @staticmethod + def wait( # type: ignore + workers: list["DummyEnvWorker"], + wait_num: int, + timeout: float | None = None, + ) -> list["DummyEnvWorker"]: + # Sequential EnvWorker objects are always ready + return workers + + def send(self, action: np.ndarray | None, **kwargs: Any) -> None: + if action is None: + self.result = self.env.reset(**kwargs) + else: + self.result = self.env.step(action) # type: ignore + + def seed(self, seed: int | None = None) -> list[int] | None: + super().seed(seed) + try: + return self.env.seed(seed) # type: ignore + except (AttributeError, NotImplementedError): + self.env.reset(seed=seed) + return [seed] # type: ignore + + def render(self, **kwargs: Any) -> Any: + return self.env.render(**kwargs) + + def close_env(self) -> None: + self.env.close() diff --git a/examples/atari/tianshou/env/worker/ray.py b/examples/atari/tianshou/env/worker/ray.py new file mode 100644 index 0000000000000000000000000000000000000000..76b842220ccaa8b82bce34a5434ee56bd6e89a30 --- /dev/null +++ b/examples/atari/tianshou/env/worker/ray.py @@ -0,0 +1,79 @@ +import contextlib +from collections.abc import Callable +from typing import Any + +import gymnasium as gym +import numpy as np + +from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type +from tianshou.env.worker import EnvWorker + +with contextlib.suppress(ImportError): + import ray + + +# mypy: disable-error-code="unused-ignore" + + +class _SetAttrWrapper(gym.Wrapper): + def set_env_attr(self, key: str, value: Any) -> None: + setattr(self.env.unwrapped, key, value) + + def get_env_attr(self, key: str) -> Any: + return getattr(self.env, key) + + +class RayEnvWorker(EnvWorker): + """Ray worker used in RayVectorEnv.""" + + def __init__( + self, + env_fn: Callable[[], ENV_TYPE], + ) -> None: # TODO: is ENV_TYPE actually correct? 
+ self.env = ray.remote(_SetAttrWrapper).options(num_cpus=0).remote(env_fn()) # type: ignore + super().__init__(env_fn) + + def get_env_attr(self, key: str) -> Any: + return ray.get(self.env.get_env_attr.remote(key)) + + def set_env_attr(self, key: str, value: Any) -> None: + ray.get(self.env.set_env_attr.remote(key, value)) + + def reset(self, **kwargs: Any) -> Any: + if "seed" in kwargs: + super().seed(kwargs["seed"]) + return ray.get(self.env.reset.remote(**kwargs)) + + @staticmethod + def wait( # type: ignore + workers: list["RayEnvWorker"], + wait_num: int, + timeout: float | None = None, + ) -> list["RayEnvWorker"]: + results = [x.result for x in workers] + ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout) + return [workers[results.index(result)] for result in ready_results] + + def send(self, action: np.ndarray | None, **kwargs: Any) -> None: + # self.result is actually a handle + if action is None: + self.result = self.env.reset.remote(**kwargs) + else: + self.result = self.env.step.remote(action) + + def recv(self) -> gym_new_venv_step_type: + return ray.get(self.result) # type: ignore + + def seed(self, seed: int | None = None) -> list[int] | None: + super().seed(seed) + try: + return ray.get(self.env.seed.remote(seed)) + except (AttributeError, NotImplementedError): + self.env.reset.remote(seed=seed) + return None + + def render(self, **kwargs: Any) -> Any: + return ray.get(self.env.render.remote(**kwargs)) + + def close_env(self) -> None: + ray.get(self.env.close.remote()) diff --git a/examples/atari/tianshou/env/worker/subproc.py b/examples/atari/tianshou/env/worker/subproc.py new file mode 100644 index 0000000000000000000000000000000000000000..74193e920c0804e1f929588fee3ee01b42408226 --- /dev/null +++ b/examples/atari/tianshou/env/worker/subproc.py @@ -0,0 +1,276 @@ +import ctypes +import multiprocessing +import time +from collections import OrderedDict +from collections.abc import Callable +from multiprocessing import connection +from multiprocessing.context import BaseContext +from typing import Any, Literal + +import gymnasium as gym +import numpy as np + +from tianshou.env.utils import CloudpickleWrapper, gym_new_venv_step_type +from tianshou.env.worker import EnvWorker + +# mypy: disable-error-code="unused-ignore" + + +_NP_TO_CT = { + np.bool_: ctypes.c_bool, + np.uint8: ctypes.c_uint8, + np.uint16: ctypes.c_uint16, + np.uint32: ctypes.c_uint32, + np.uint64: ctypes.c_uint64, + np.int8: ctypes.c_int8, + np.int16: ctypes.c_int16, + np.int32: ctypes.c_int32, + np.int64: ctypes.c_int64, + np.float32: ctypes.c_float, + np.float64: ctypes.c_double, +} + + +class ShArray: + """Wrapper of multiprocessing Array. 
+ + Example usage: + + :: + + import numpy as np + import multiprocessing as mp + from tianshou.env.worker.subproc import ShArray + ctx = mp.get_context('fork') # set an explicit context + arr = ShArray(np.dtype(np.float32), (2, 3), ctx) + arr.save(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)) + print(arr.get()) + + """ + + def __init__(self, dtype: np.generic, shape: tuple[int], ctx: BaseContext | None) -> None: + if ctx is None: + ctx = multiprocessing.get_context() + self.arr = ctx.Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore + self.dtype = dtype + self.shape = shape + + def save(self, ndarray: np.ndarray) -> None: + assert isinstance(ndarray, np.ndarray) + dst = self.arr.get_obj() + dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape) # type: ignore + np.copyto(dst_np, ndarray) + + def get(self) -> np.ndarray: + obj = self.arr.get_obj() + return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) # type: ignore + + +def _setup_buf(space: gym.Space, ctx: BaseContext) -> dict | tuple | ShArray: + if isinstance(space, gym.spaces.Dict): + assert isinstance(space.spaces, OrderedDict) + return {k: _setup_buf(v, ctx) for k, v in space.spaces.items()} + if isinstance(space, gym.spaces.Tuple): + assert isinstance(space.spaces, tuple) + return tuple([_setup_buf(t, ctx) for t in space.spaces]) + return ShArray(space.dtype, space.shape, ctx) # type: ignore + + +def _worker( + parent: connection.Connection, + p: connection.Connection, + env_fn_wrapper: CloudpickleWrapper, + obs_bufs: dict | tuple | ShArray | None = None, +) -> None: + def _encode_obs( + obs: dict | tuple | np.ndarray, + buffer: dict | tuple | ShArray, + ) -> None: + if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray): + buffer.save(obs) + elif isinstance(obs, tuple) and isinstance(buffer, tuple): + for o, b in zip(obs, buffer, strict=True): + _encode_obs(o, b) + elif isinstance(obs, dict) and isinstance(buffer, dict): + for k in obs: + _encode_obs(obs[k], buffer[k]) + + parent.close() + env = env_fn_wrapper.data() + try: + while True: + try: + cmd, data = p.recv() + except EOFError: # the pipe has been closed + p.close() + break + if cmd == "step": + env_return = env.step(data) + if obs_bufs is not None: + _encode_obs(env_return[0], obs_bufs) + env_return = (None, *env_return[1:]) + p.send(env_return) + elif cmd == "reset": + obs, info = env.reset(**data) + if obs_bufs is not None: + _encode_obs(obs, obs_bufs) + obs = None + p.send((obs, info)) + elif cmd == "close": + p.send(env.close()) + p.close() + break + elif cmd == "render": + p.send(env.render(**data) if hasattr(env, "render") else None) + elif cmd == "seed": + if hasattr(env, "seed"): + p.send(env.seed(data)) + else: + env.action_space.seed(seed=data) + env.reset(seed=data) + p.send(None) + elif cmd == "getattr": + p.send(getattr(env, data) if hasattr(env, data) else None) + elif cmd == "setattr": + setattr(env.unwrapped, data["key"], data["value"]) + else: + p.close() + raise NotImplementedError + except KeyboardInterrupt: + p.close() + + +class SubprocEnvWorker(EnvWorker): + """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv.""" + + def __init__( + self, + env_fn: Callable[[], gym.Env], + share_memory: bool = False, + context: BaseContext | Literal["fork", "spawn"] | None = None, + ) -> None: + if not isinstance(context, BaseContext): + context = multiprocessing.get_context(context) + self.parent_remote, self.child_remote = context.Pipe() + self.share_memory = share_memory + self.buffer: dict | tuple | 
ShArray | None = None + assert hasattr(context, "Process") # for mypy + if self.share_memory: + dummy = env_fn() + obs_space = dummy.observation_space + dummy.close() + del dummy + self.buffer = _setup_buf(obs_space, context) + args = ( + self.parent_remote, + self.child_remote, + CloudpickleWrapper(env_fn), + self.buffer, + ) + self.process = context.Process(target=_worker, args=args, daemon=True) + self.process.start() + self.child_remote.close() + super().__init__(env_fn) + + def get_env_attr(self, key: str) -> Any: + self.parent_remote.send(["getattr", key]) + return self.parent_remote.recv() + + def set_env_attr(self, key: str, value: Any) -> None: + self.parent_remote.send(["setattr", {"key": key, "value": value}]) + + def _decode_obs(self) -> dict | tuple | np.ndarray: + def decode_obs( + buffer: dict | tuple | ShArray | None, + ) -> dict | tuple | np.ndarray: + if isinstance(buffer, ShArray): + return buffer.get() + if isinstance(buffer, tuple): + return tuple([decode_obs(b) for b in buffer]) + if isinstance(buffer, dict): + return {k: decode_obs(v) for k, v in buffer.items()} + raise NotImplementedError + + return decode_obs(self.buffer) + + @staticmethod + def wait( # type: ignore + workers: list["SubprocEnvWorker"], + wait_num: int, + timeout: float | None = None, + ) -> list["SubprocEnvWorker"]: + remain_conns = conns = [x.parent_remote for x in workers] + ready_conns: list[connection.Connection] = [] + remain_time, t1 = timeout, time.time() + while len(remain_conns) > 0 and len(ready_conns) < wait_num: + if timeout: + remain_time = timeout - (time.time() - t1) + if remain_time <= 0: + break + # connection.wait hangs if the list is empty + new_ready_conns = connection.wait(remain_conns, timeout=remain_time) # type: ignore + ready_conns.extend(new_ready_conns) # type: ignore + remain_conns = [conn for conn in remain_conns if conn not in ready_conns] # type: ignore + return [workers[conns.index(con)] for con in ready_conns] # type: ignore + + def send(self, action: np.ndarray | None, **kwargs: Any) -> None: + if action is None: + if "seed" in kwargs: + super().seed(kwargs["seed"]) + self.parent_remote.send(["reset", kwargs]) + else: + self.parent_remote.send(["step", action]) + + def recv(self) -> gym_new_venv_step_type | tuple[np.ndarray, dict]: + result = self.parent_remote.recv() + if isinstance(result, tuple): + if len(result) == 2: + obs, info = result + if self.share_memory: + obs = self._decode_obs() + return obs, info + obs = result[0] + if self.share_memory: + obs = self._decode_obs() + # TODO: figure out the typing issue, simplify and document this method + return (obs, *result[1:]) + obs = result + if self.share_memory: + obs = self._decode_obs() + return obs + + def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: + if "seed" in kwargs: + super().seed(kwargs["seed"]) + self.parent_remote.send(["reset", kwargs]) + + result = self.parent_remote.recv() + if isinstance(result, tuple): + obs, info = result + if self.share_memory: + obs = self._decode_obs() + return obs, info + obs = result + if self.share_memory: + obs = self._decode_obs() + return obs + + def seed(self, seed: int | None = None) -> list[int] | None: + super().seed(seed) + self.parent_remote.send(["seed", seed]) + return self.parent_remote.recv() + + def render(self, **kwargs: Any) -> Any: + self.parent_remote.send(["render", kwargs]) + return self.parent_remote.recv() + + def close_env(self) -> None: + try: + self.parent_remote.send(["close", None]) + # mp may be deleted so it may raise 
AttributeError + self.parent_remote.recv() + self.process.join() + except (BrokenPipeError, EOFError, AttributeError): + pass + # ensure the subproc is terminated + self.process.terminate() diff --git a/examples/atari/tianshou/evaluation/__init__.py b/examples/atari/tianshou/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/evaluation/launcher.py b/examples/atari/tianshou/evaluation/launcher.py new file mode 100644 index 0000000000000000000000000000000000000000..534e5f835c4148e16e220066978061941279f9ee --- /dev/null +++ b/examples/atari/tianshou/evaluation/launcher.py @@ -0,0 +1,149 @@ +"""Provides a basic interface for launching experiments. The API is experimental and subject to change!.""" + +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from copy import copy +from dataclasses import asdict, dataclass +from enum import Enum +from typing import Literal + +from joblib import Parallel, delayed + +from tianshou.data import InfoStats +from tianshou.highlevel.experiment import Experiment + +log = logging.getLogger(__name__) + + +@dataclass +class JoblibConfig: + n_jobs: int = -1 + """The maximum number of concurrently running jobs. If -1, all CPUs are used.""" + backend: Literal["loky", "multiprocessing", "threading"] | None = "loky" + """Allows to hard-code backend, otherwise inferred based on prefer and require.""" + verbose: int = 10 + """If greater than zero, prints progress messages.""" + + +class ExpLauncher(ABC): + def __init__( + self, + experiment_runner: Callable[ + [Experiment], + InfoStats | None, + ] = lambda exp: exp.run().trainer_result, + ): + """:param experiment_runner: can be used to override the default way in which an experiment is executed. + Can be useful e.g., if one wants to use the high-level interfaces to setup an experiment (or an experiment + collection) and tinker with it prior to execution. This need often arises when prototyping with mechanisms + that are not yet supported by the high-level interfaces. + Passing this allows arbitrary things to happen during experiment execution, so use it with caution! + """ + self.experiment_runner = experiment_runner + + @abstractmethod + def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]: + """Should call `self.experiment_runner` for each experiment in experiments and aggregate the results.""" + + def _safe_execute(self, exp: Experiment) -> InfoStats | None | Literal["failed"]: + try: + return self.experiment_runner(exp) + except BaseException as e: + log.error(f"Failed to run experiment {exp}.", exc_info=e) + return "failed" + + @staticmethod + def _return_from_successful_and_failed_exps( + successful_exp_stats: list[InfoStats | None], + failed_exps: list[Experiment], + ) -> list[InfoStats | None]: + if not successful_exp_stats: + raise RuntimeError("All experiments failed, see error logs for more details.") + if failed_exps: + log.error( + f"Failed to run the following " + f"{len(failed_exps)}/{len(successful_exp_stats) + len(failed_exps)} experiments: {failed_exps}. " + f"See the logs for more details. " + f"Returning the results of {len(successful_exp_stats)} successful experiments.", + ) + return successful_exp_stats + + def launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]: + """Will return the results of successfully executed experiments. 
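+
+        A minimal usage sketch (illustrative only; assumes ``experiments`` is a list of
+        configured :class:`Experiment` instances built elsewhere):
+
+        ::
+
+            launcher = RegisteredExpLauncher.joblib.create_launcher()
+            results = launcher.launch(experiments)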
+
+        If a single experiment is passed, parallelism is not used and the experiment is run in the main process.
+        Failed experiments will be logged, and a RuntimeError is only raised if all experiments have failed.
+        """
+        if len(experiments) == 1:
+            log.info(
+                "A single experiment is being run; parallelism will not be used and it will run in the main process.",
+            )
+            return [self.experiment_runner(experiments[0])]
+        return self._launch(experiments)
+
+
+class SequentialExpLauncher(ExpLauncher):
+    """Convenience wrapper around a simple for loop to run experiments sequentially."""
+
+    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        successful_exp_stats = []
+        failed_exps = []
+        for exp in experiments:
+            exp_stats = self._safe_execute(exp)
+            if exp_stats == "failed":
+                failed_exps.append(exp)
+            else:
+                successful_exp_stats.append(exp_stats)
+        # noinspection PyTypeChecker
+        return self._return_from_successful_and_failed_exps(successful_exp_stats, failed_exps)
+
+
+class JoblibExpLauncher(ExpLauncher):
+    def __init__(
+        self,
+        joblib_cfg: JoblibConfig | None = None,
+        experiment_runner: Callable[
+            [Experiment],
+            InfoStats | None,
+        ] = lambda exp: exp.run().trainer_result,
+    ) -> None:
+        super().__init__(experiment_runner=experiment_runner)
+        self.joblib_cfg = copy(joblib_cfg) if joblib_cfg is not None else JoblibConfig()
+        # Joblib's backend is hard-coded to loky since the threading backend produces different results
+        if self.joblib_cfg.backend != "loky":
+            log.warning(
+                f"Ignoring the user-provided joblib backend {self.joblib_cfg.backend} and using loky instead. "
+                f"The current implementation requires loky; this restriction will be relaxed soon.",
+            )
+            self.joblib_cfg.backend = "loky"
+
+    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        results = Parallel(**asdict(self.joblib_cfg))(
+            delayed(self._safe_execute)(exp) for exp in experiments
+        )
+        successful_exps = []
+        failed_exps = []
+        for exp, result in zip(experiments, results, strict=True):
+            if result == "failed":
+                failed_exps.append(exp)
+            else:
+                successful_exps.append(result)
+        return self._return_from_successful_and_failed_exps(successful_exps, failed_exps)
+
+
+class RegisteredExpLauncher(Enum):
+    joblib = "joblib"
+    sequential = "sequential"
+
+    def create_launcher(self) -> ExpLauncher:
+        match self:
+            case RegisteredExpLauncher.joblib:
+                return JoblibExpLauncher()
+            case RegisteredExpLauncher.sequential:
+                return SequentialExpLauncher()
+            case _:
+                raise NotImplementedError(
+                    f"Launcher {self} is not yet implemented.",
+                )
diff --git a/examples/atari/tianshou/evaluation/rliable_evaluation_hl.py b/examples/atari/tianshou/evaluation/rliable_evaluation_hl.py
new file mode 100644
index 0000000000000000000000000000000000000000..884176bd1952a3086992f2e5b8bf19b9fb2ed382
--- /dev/null
+++ b/examples/atari/tianshou/evaluation/rliable_evaluation_hl.py
@@ -0,0 +1,218 @@
+"""The rliable-evaluation module provides a high-level interface to evaluate the results of an experiment with multiple runs
+on different seeds using the rliable library. The API is experimental and subject to change!
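+
+A rough usage sketch (illustrative; ``log/my_experiment`` is a hypothetical directory that
+contains one sub-directory per seeded run):
+
+::
+
+    from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult
+
+    result = RLiableExperimentResult.load_from_disk("log/my_experiment")
+    result.eval_results(algo_name="my_algo", save_plots=True, show_plots=False)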
+""" + +import os +from dataclasses import asdict, dataclass, fields + +import matplotlib.pyplot as plt +import numpy as np +import scipy.stats as sst +from rliable import library as rly +from rliable import plot_utils + +from tianshou.highlevel.experiment import Experiment +from tianshou.utils import logging +from tianshou.utils.logger.base import DataScope + +log = logging.getLogger(__name__) + + +@dataclass +class LoggedSummaryData: + mean: np.ndarray + std: np.ndarray + max: np.ndarray + min: np.ndarray + + +@dataclass +class LoggedCollectStats: + env_step: np.ndarray | None = None + n_collected_episodes: np.ndarray | None = None + n_collected_steps: np.ndarray | None = None + collect_time: np.ndarray | None = None + collect_speed: np.ndarray | None = None + returns_stat: LoggedSummaryData | None = None + lens_stat: LoggedSummaryData | None = None + + @classmethod + def from_data_dict(cls, data: dict) -> "LoggedCollectStats": + """Create a LoggedCollectStats object from a dictionary. + + Converts SequenceSummaryStats from dict format to dataclass format and ignores fields that are not present. + """ + field_names = [f.name for f in fields(cls)] + for k, v in data.items(): + if k not in field_names: + data.pop(k) + if isinstance(v, dict): + data[k] = LoggedSummaryData(**v) + return cls(**data) + + +@dataclass +class RLiableExperimentResult: + """The result of an experiment that can be used with the rliable library.""" + + exp_dir: str + """The base directory where each sub-directory contains the results of one experiment run.""" + + test_episode_returns_RE: np.ndarray + """The test episodes for each run of the experiment where each row corresponds to one run.""" + + env_steps_E: np.ndarray + """The number of environment steps at which the test episodes were evaluated.""" + + @classmethod + def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult": + """Load the experiment result from disk. + + :param exp_dir: The directory from where the experiment results are restored. + """ + test_episode_returns = [] + env_step_at_test = None + + # TODO: env_step_at_test should not be defined in a loop and overwritten at each iteration + # just for retrieving them. We might need a cleaner directory structure. 
+ for entry in os.scandir(exp_dir): + if entry.name.startswith(".") or not entry.is_dir(): + continue + + exp = Experiment.from_directory(entry.path) + logger = exp.logger_factory.create_logger( + entry.path, + entry.name, + None, + asdict(exp.config), + ) + data = logger.restore_logged_data(entry.path) + + if DataScope.TEST.value not in data or not data[DataScope.TEST.value]: + continue + restored_test_data = data[DataScope.TEST.value] + if not isinstance(restored_test_data, dict): + raise RuntimeError( + f"Expected entry with key {DataScope.TEST.value} data to be a dictionary, " + f"but got {restored_test_data=}.", + ) + test_data = LoggedCollectStats.from_data_dict(restored_test_data) + + if test_data.returns_stat is None: + continue + test_episode_returns.append(test_data.returns_stat.mean) + env_step_at_test = test_data.env_step + + if not test_episode_returns or env_step_at_test is None: + raise ValueError(f"No experiment data found in {exp_dir}.") + + return cls( + test_episode_returns_RE=np.array(test_episode_returns), + env_steps_E=np.array(env_step_at_test), + exp_dir=exp_dir, + ) + + def _get_rliable_data( + self, + algo_name: str | None = None, + score_thresholds: np.ndarray | None = None, + ) -> tuple[dict, np.ndarray, np.ndarray]: + """Return the data in the format expected by the rliable library. + + :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm + is set to the experiment dir. + :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred + from the minimum and maximum test episode returns. + + :return: A tuple score_dict, env_steps, and score_thresholds. + """ + if score_thresholds is None: + score_thresholds = np.linspace( + np.min(self.test_episode_returns_RE), + np.max(self.test_episode_returns_RE), + 101, + ) + + if algo_name is None: + algo_name = os.path.basename(self.exp_dir) + + score_dict = {algo_name: self.test_episode_returns_RE} + + return score_dict, self.env_steps_E, score_thresholds + + def eval_results( + self, + algo_name: str | None = None, + score_thresholds: np.ndarray | None = None, + save_plots: bool = False, + show_plots: bool = True, + ) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]: + """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile. + + :param algo_name: The name of the algorithm to be shown in the figure legend. If None, the name of the algorithm + is set to the experiment dir. + :param score_thresholds: The score thresholds for the performance profile. If None, the thresholds are inferred + from the minimum and maximum test episode returns. + :param save_plots: If True, the figures are saved to the experiment directory. + :param show_plots: If True, the figures are shown. + + :return: The created figures and axes. 
+ """ + score_dict, env_steps, score_thresholds = self._get_rliable_data( + algo_name, + score_thresholds, + ) + + iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0) + iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm) + + # Plot IQM sample efficiency curve + fig_iqm, ax_iqm = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True) + plot_utils.plot_sample_efficiency_curve( + env_steps, + iqm_scores, + iqm_cis, + algorithms=None, + xlabel="env step", + ylabel="IQM episode return", + ax=ax_iqm, + ) + if show_plots: + plt.show(block=False) + + if save_plots: + iqm_sample_efficiency_curve_path = os.path.abspath( + os.path.join( + self.exp_dir, + "iqm_sample_efficiency_curve.png", + ), + ) + log.info(f"Saving iqm sample efficiency curve to {iqm_sample_efficiency_curve_path}.") + fig_iqm.savefig(iqm_sample_efficiency_curve_path) + + final_score_dict = {algo: returns[:, [-1]] for algo, returns in score_dict.items()} + score_distributions, score_distributions_cis = rly.create_performance_profile( + final_score_dict, + score_thresholds, + ) + + # Plot score distributions + fig_profile, ax_profile = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True) + plot_utils.plot_performance_profiles( + score_distributions, + score_thresholds, + performance_profile_cis=score_distributions_cis, + xlabel=r"Episode return $(\tau)$", + ax=ax_profile, + ) + + if save_plots: + profile_curve_path = os.path.abspath( + os.path.join(self.exp_dir, "performance_profile.png"), + ) + log.info(f"Saving performance profile curve to {profile_curve_path}.") + fig_profile.savefig(profile_curve_path) + if show_plots: + plt.show(block=False) + + return fig_iqm, ax_iqm, fig_profile, ax_profile diff --git a/examples/atari/tianshou/exploration/__init__.py b/examples/atari/tianshou/exploration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0878d23840a76aeda9cec1319d1be598730d5a20 --- /dev/null +++ b/examples/atari/tianshou/exploration/__init__.py @@ -0,0 +1,7 @@ +from tianshou.exploration.random import BaseNoise, GaussianNoise, OUNoise + +__all__ = [ + "BaseNoise", + "GaussianNoise", + "OUNoise", +] diff --git a/examples/atari/tianshou/exploration/random.py b/examples/atari/tianshou/exploration/random.py new file mode 100644 index 0000000000000000000000000000000000000000..a797a9049bfd8a7345f59439ab041d818ef2791b --- /dev/null +++ b/examples/atari/tianshou/exploration/random.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence + +import numpy as np + + +class BaseNoise(ABC): + """The action noise base class.""" + + @abstractmethod + def reset(self) -> None: + """Reset to the initial state.""" + + @abstractmethod + def __call__(self, size: Sequence[int]) -> np.ndarray: + """Generate new noise.""" + raise NotImplementedError + + +class GaussianNoise(BaseNoise): + """The vanilla Gaussian process, for exploration in DDPG by default.""" + + def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None: + self._mu = mu + assert sigma >= 0, "Noise std should not be negative." + self._sigma = sigma + + def __call__(self, size: Sequence[int]) -> np.ndarray: + return np.random.normal(self._mu, self._sigma, size) + + def reset(self) -> None: + pass + + +class OUNoise(BaseNoise): + """Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG. 
+ + Usage: + :: + + # init + self.noise = OUNoise() + # generate noise + noise = self.noise(logits.shape, eps) + + For required parameters, you can refer to the stackoverflow page. However, + our experiment result shows that (similar to OpenAI SpinningUp) using + vanilla Gaussian process has little difference from using the + Ornstein-Uhlenbeck process. + """ + + def __init__( + self, + mu: float = 0.0, + sigma: float = 0.3, + theta: float = 0.15, + dt: float = 1e-2, + x0: float | np.ndarray | None = None, + ) -> None: + super().__init__() + self._mu = mu + self._alpha = theta * dt + self._beta = sigma * np.sqrt(dt) + self._x0 = x0 + self.reset() + + def reset(self) -> None: + """Reset to the initial state.""" + self._x = self._x0 + + def __call__(self, size: Sequence[int], mu: float | None = None) -> np.ndarray: + """Generate new noise. + + Return an numpy array which size is equal to ``size``. + """ + if self._x is None or isinstance(self._x, np.ndarray) and self._x.shape != size: + self._x = 0.0 + if mu is None: + mu = self._mu + r = self._beta * np.random.normal(size=size) + self._x = self._x + self._alpha * (mu - self._x) + r + return self._x # type: ignore diff --git a/examples/atari/tianshou/highlevel/__init__.py b/examples/atari/tianshou/highlevel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/examples/atari/tianshou/highlevel/__init__.py @@ -0,0 +1 @@ + diff --git a/examples/atari/tianshou/highlevel/agent.py b/examples/atari/tianshou/highlevel/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..c1313262eeb8ffca9f910bc3df25dfdd6f4cc790 --- /dev/null +++ b/examples/atari/tianshou/highlevel/agent.py @@ -0,0 +1,620 @@ +import logging +import typing +from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar, cast + +import gymnasium + +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data.collector import BaseCollector +from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.actor import ( + ActorFactory, +) +from tianshou.highlevel.module.core import ( + ModuleFactory, + TDevice, +) +from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory +from tianshou.highlevel.module.module_opt import ( + ActorCriticOpt, +) +from tianshou.highlevel.optim import OptimizerFactory +from tianshou.highlevel.params.policy_params import ( + A2CParams, + DDPGParams, + DiscreteSACParams, + DQNParams, + IQNParams, + NPGParams, + Params, + ParamsMixinActorAndDualCritics, + ParamsMixinLearningRateWithScheduler, + ParamTransformerData, + PGParams, + PPOParams, + REDQParams, + SACParams, + TD3Params, + TRPOParams, +) +from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory +from tianshou.highlevel.persistence import PolicyPersistence +from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext +from tianshou.highlevel.world import World +from tianshou.policy import ( + A2CPolicy, + BasePolicy, + DDPGPolicy, + DiscreteSACPolicy, + DQNPolicy, + IQNPolicy, + NPGPolicy, + PGPolicy, + PPOPolicy, + REDQPolicy, + SACPolicy, + TD3Policy, + TRPOPolicy, +) +from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer +from tianshou.utils.net.common import ActorCritic +from tianshou.utils.string import ToStringMixin + +CHECKPOINT_DICT_KEY_MODEL = "model" +CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" 
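+# The type variables below parameterize the agent factories by the hyperparameter dataclass
+# they consume (a Params subclass or mixin combination) and by the policy class they construct.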
+TParams = TypeVar("TParams", bound=Params) +TActorCriticParams = TypeVar( + "TActorCriticParams", + bound=Params | ParamsMixinLearningRateWithScheduler, +) +TActorDualCriticsParams = TypeVar( + "TActorDualCriticsParams", + bound=Params | ParamsMixinActorAndDualCritics, +) +TDiscreteCriticOnlyParams = TypeVar( + "TDiscreteCriticOnlyParams", + bound=Params | ParamsMixinLearningRateWithScheduler, +) +TPolicy = TypeVar("TPolicy", bound=BasePolicy) +log = logging.getLogger(__name__) + + +class AgentFactory(ABC, ToStringMixin): + """Factory for the creation of an agent's policy, its trainer as well as collectors.""" + + def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory): + self.sampling_config = sampling_config + self.optim_factory = optim_factory + self.policy_wrapper_factory: PolicyWrapperFactory | None = None + self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks() + + def create_train_test_collector( + self, + policy: BasePolicy, + envs: Environments, + reset_collectors: bool = True, + ) -> tuple[BaseCollector, BaseCollector]: + """:param policy: + :param envs: + :param reset_collectors: Whether to reset the collectors before returning them. + Setting to True means that the envs will be reset as well. + :return: + """ + buffer_size = self.sampling_config.buffer_size + train_envs = envs.train_envs + buffer: ReplayBuffer + if len(train_envs) > 1: + buffer = VectorReplayBuffer( + buffer_size, + len(train_envs), + stack_num=self.sampling_config.replay_buffer_stack_num, + save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, + ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, + ) + else: + buffer = ReplayBuffer( + buffer_size, + stack_num=self.sampling_config.replay_buffer_stack_num, + save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, + ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, + ) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, envs.test_envs) + if reset_collectors: + train_collector.reset() + test_collector.reset() + + if self.sampling_config.start_timesteps > 0: + log.info( + f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})", + ) + train_collector.collect( + n_step=self.sampling_config.start_timesteps, + random=self.sampling_config.start_timesteps_random, + ) + return train_collector, test_collector + + def set_policy_wrapper_factory( + self, + policy_wrapper_factory: PolicyWrapperFactory | None, + ) -> None: + self.policy_wrapper_factory = policy_wrapper_factory + + def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None: + self.trainer_callbacks = callbacks + + @abstractmethod + def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + pass + + def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + policy = self._create_policy(envs, device) + if self.policy_wrapper_factory is not None: + policy = self.policy_wrapper_factory.create_wrapped_policy( + policy, + envs, + self.optim_factory, + device, + ) + return policy + + @abstractmethod + def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> BaseTrainer: + pass + + +class OnPolicyAgentFactory(AgentFactory, ABC): + def create_trainer( + self, + world: World, + policy_persistence: PolicyPersistence, + ) -> OnpolicyTrainer: + sampling_config = self.sampling_config + 
callbacks = self.trainer_callbacks + context = TrainingContext(world.policy, world.envs, world.logger) + train_fn = ( + callbacks.epoch_train_callback.get_trainer_fn(context) + if callbacks.epoch_train_callback + else None + ) + test_fn = ( + callbacks.epoch_test_callback.get_trainer_fn(context) + if callbacks.epoch_test_callback + else None + ) + stop_fn = ( + callbacks.epoch_stop_callback.get_trainer_fn(context) + if callbacks.epoch_stop_callback + else None + ) + return OnpolicyTrainer( + policy=world.policy, + train_collector=world.train_collector, + test_collector=world.test_collector, + max_epoch=sampling_config.num_epochs, + step_per_epoch=sampling_config.step_per_epoch, + repeat_per_collect=sampling_config.repeat_per_collect, + episode_per_test=sampling_config.num_test_episodes, + batch_size=sampling_config.batch_size, + step_per_collect=sampling_config.step_per_collect, + save_best_fn=policy_persistence.get_save_best_fn(world), + logger=world.logger, + test_in_train=False, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + verbose=False, + ) + + +class OffPolicyAgentFactory(AgentFactory, ABC): + def create_trainer( + self, + world: World, + policy_persistence: PolicyPersistence, + ) -> OffpolicyTrainer: + sampling_config = self.sampling_config + callbacks = self.trainer_callbacks + context = TrainingContext(world.policy, world.envs, world.logger) + train_fn = ( + callbacks.epoch_train_callback.get_trainer_fn(context) + if callbacks.epoch_train_callback + else None + ) + test_fn = ( + callbacks.epoch_test_callback.get_trainer_fn(context) + if callbacks.epoch_test_callback + else None + ) + stop_fn = ( + callbacks.epoch_stop_callback.get_trainer_fn(context) + if callbacks.epoch_stop_callback + else None + ) + return OffpolicyTrainer( + policy=world.policy, + train_collector=world.train_collector, + test_collector=world.test_collector, + max_epoch=sampling_config.num_epochs, + step_per_epoch=sampling_config.step_per_epoch, + step_per_collect=sampling_config.step_per_collect, + episode_per_test=sampling_config.num_test_episodes, + batch_size=sampling_config.batch_size, + save_best_fn=policy_persistence.get_save_best_fn(world), + logger=world.logger, + update_per_step=sampling_config.update_per_step, + test_in_train=False, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + verbose=False, + ) + + +class PGAgentFactory(OnPolicyAgentFactory): + def __init__( + self, + params: PGParams, + sampling_config: SamplingConfig, + actor_factory: ActorFactory, + optim_factory: OptimizerFactory, + ): + super().__init__(sampling_config, optim_factory) + self.params = params + self.actor_factory = actor_factory + self.optim_factory = optim_factory + + def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy: + actor = self.actor_factory.create_module_opt( + envs, + device, + self.optim_factory, + self.params.lr, + ) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim=actor.optim, + optim_factory=self.optim_factory, + ), + ) + return PGPolicy( + actor=actor.module, + optim=actor.optim, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + **kwargs, + ) + + +class ActorCriticAgentFactory( + Generic[TActorCriticParams, TPolicy], + OnPolicyAgentFactory, + ABC, +): + def __init__( + self, + params: TActorCriticParams, + sampling_config: SamplingConfig, + actor_factory: ActorFactory, + critic_factory: CriticFactory, + optimizer_factory: OptimizerFactory, + ): + 
super().__init__(sampling_config, optim_factory=optimizer_factory) + self.params = params + self.actor_factory = actor_factory + self.critic_factory = critic_factory + self.optim_factory = optimizer_factory + self.critic_use_action = False + + @abstractmethod + def _get_policy_class(self) -> type[TPolicy]: + pass + + def create_actor_critic_module_opt( + self, + envs: Environments, + device: TDevice, + lr: float, + ) -> ActorCriticOpt: + actor = self.actor_factory.create_module(envs, device) + critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) + actor_critic = ActorCritic(actor, critic) + optim = self.optim_factory.create_optimizer(actor_critic, lr) + return ActorCriticOpt(actor_critic, optim) + + @typing.no_type_check + def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: + actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim_factory=self.optim_factory, + optim=actor_critic.optim, + ), + ) + kwargs["actor"] = actor_critic.actor + kwargs["critic"] = actor_critic.critic + kwargs["optim"] = actor_critic.optim + kwargs["action_space"] = envs.get_action_space() + return kwargs + + def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: + policy_class = self._get_policy_class() + return policy_class(**self._create_kwargs(envs, device)) + + +class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): + def _get_policy_class(self) -> type[A2CPolicy]: + return A2CPolicy + + +class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): + def _get_policy_class(self) -> type[PPOPolicy]: + return PPOPolicy + + +class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]): + def _get_policy_class(self) -> type[NPGPolicy]: + return NPGPolicy + + +class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]): + def _get_policy_class(self) -> type[TRPOPolicy]: + return TRPOPolicy + + +class DiscreteCriticOnlyAgentFactory( + OffPolicyAgentFactory, + Generic[TDiscreteCriticOnlyParams, TPolicy], +): + def __init__( + self, + params: TDiscreteCriticOnlyParams, + sampling_config: SamplingConfig, + model_factory: ModuleFactory, + optim_factory: OptimizerFactory, + ): + super().__init__(sampling_config, optim_factory) + self.params = params + self.model_factory = model_factory + self.optim_factory = optim_factory + + @abstractmethod + def _get_policy_class(self) -> type[TPolicy]: + pass + + @typing.no_type_check + def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: + model = self.model_factory.create_module(envs, device) + optim = self.optim_factory.create_optimizer(model, self.params.lr) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim=optim, + optim_factory=self.optim_factory, + ), + ) + envs.get_type().assert_discrete(self) + action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space()) + policy_class = self._get_policy_class() + return policy_class( + model=model, + optim=optim, + action_space=action_space, + observation_space=envs.get_observation_space(), + **kwargs, + ) + + +class DQNAgentFactory(DiscreteCriticOnlyAgentFactory[DQNParams, DQNPolicy]): + def _get_policy_class(self) -> type[DQNPolicy]: + return DQNPolicy + + +class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]): + def _get_policy_class(self) -> type[IQNPolicy]: + return IQNPolicy + + +class 
DDPGAgentFactory(OffPolicyAgentFactory): + def __init__( + self, + params: DDPGParams, + sampling_config: SamplingConfig, + actor_factory: ActorFactory, + critic_factory: CriticFactory, + optim_factory: OptimizerFactory, + ): + super().__init__(sampling_config, optim_factory) + self.critic_factory = critic_factory + self.actor_factory = actor_factory + self.params = params + self.optim_factory = optim_factory + + def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + actor = self.actor_factory.create_module_opt( + envs, + device, + self.optim_factory, + self.params.actor_lr, + ) + critic = self.critic_factory.create_module_opt( + envs, + device, + True, + self.optim_factory, + self.params.critic_lr, + ) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim_factory=self.optim_factory, + actor=actor, + critic1=critic, + ), + ) + return DDPGPolicy( + actor=actor.module, + actor_optim=actor.optim, + critic=critic.module, + critic_optim=critic.optim, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + **kwargs, + ) + + +class REDQAgentFactory(OffPolicyAgentFactory): + def __init__( + self, + params: REDQParams, + sampling_config: SamplingConfig, + actor_factory: ActorFactory, + critic_ensemble_factory: CriticEnsembleFactory, + optim_factory: OptimizerFactory, + ): + super().__init__(sampling_config, optim_factory) + self.critic_ensemble_factory = critic_ensemble_factory + self.actor_factory = actor_factory + self.params = params + self.optim_factory = optim_factory + + def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + envs.get_type().assert_continuous(self) + actor = self.actor_factory.create_module_opt( + envs, + device, + self.optim_factory, + self.params.actor_lr, + ) + critic_ensemble = self.critic_ensemble_factory.create_module_opt( + envs, + device, + self.params.ensemble_size, + True, + self.optim_factory, + self.params.critic_lr, + ) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim_factory=self.optim_factory, + actor=actor, + critic1=critic_ensemble, + ), + ) + action_space = cast(gymnasium.spaces.Box, envs.get_action_space()) + return REDQPolicy( + actor=actor.module, + actor_optim=actor.optim, + critic=critic_ensemble.module, + critic_optim=critic_ensemble.optim, + action_space=action_space, + observation_space=envs.get_observation_space(), + **kwargs, + ) + + +class ActorDualCriticsAgentFactory( + OffPolicyAgentFactory, + Generic[TActorDualCriticsParams, TPolicy], + ABC, +): + def __init__( + self, + params: TActorDualCriticsParams, + sampling_config: SamplingConfig, + actor_factory: ActorFactory, + critic1_factory: CriticFactory, + critic2_factory: CriticFactory, + optim_factory: OptimizerFactory, + ): + super().__init__(sampling_config, optim_factory) + self.params = params + self.actor_factory = actor_factory + self.critic1_factory = critic1_factory + self.critic2_factory = critic2_factory + self.optim_factory = optim_factory + + @abstractmethod + def _get_policy_class(self) -> type[TPolicy]: + pass + + def _get_discrete_last_size_use_action_shape(self) -> bool: + return True + + @staticmethod + def _get_critic_use_action(envs: Environments) -> bool: + return envs.get_type().is_continuous() + + @typing.no_type_check + def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: + actor = self.actor_factory.create_module_opt( + envs, + device, + self.optim_factory, + 
self.params.actor_lr, + ) + use_action_shape = self._get_discrete_last_size_use_action_shape() + critic_use_action = self._get_critic_use_action(envs) + critic1 = self.critic1_factory.create_module_opt( + envs, + device, + critic_use_action, + self.optim_factory, + self.params.critic1_lr, + discrete_last_size_use_action_shape=use_action_shape, + ) + critic2 = self.critic2_factory.create_module_opt( + envs, + device, + critic_use_action, + self.optim_factory, + self.params.critic2_lr, + discrete_last_size_use_action_shape=use_action_shape, + ) + kwargs = self.params.create_kwargs( + ParamTransformerData( + envs=envs, + device=device, + optim_factory=self.optim_factory, + actor=actor, + critic1=critic1, + critic2=critic2, + ), + ) + policy_class = self._get_policy_class() + return policy_class( + actor=actor.module, + actor_optim=actor.optim, + critic=critic1.module, + critic_optim=critic1.optim, + critic2=critic2.module, + critic2_optim=critic2.optim, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + **kwargs, + ) + + +class SACAgentFactory(ActorDualCriticsAgentFactory[SACParams, SACPolicy]): + def _get_policy_class(self) -> type[SACPolicy]: + return SACPolicy + + +class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, DiscreteSACPolicy]): + def _get_policy_class(self) -> type[DiscreteSACPolicy]: + return DiscreteSACPolicy + + +class TD3AgentFactory(ActorDualCriticsAgentFactory[TD3Params, TD3Policy]): + def _get_policy_class(self) -> type[TD3Policy]: + return TD3Policy diff --git a/examples/atari/tianshou/highlevel/config.py b/examples/atari/tianshou/highlevel/config.py new file mode 100644 index 0000000000000000000000000000000000000000..951f2f3af2d145829933baf921aa4ea3cedb65eb --- /dev/null +++ b/examples/atari/tianshou/highlevel/config.py @@ -0,0 +1,145 @@ +import multiprocessing +from dataclasses import dataclass + +from tianshou.utils.string import ToStringMixin + + +@dataclass +class SamplingConfig(ToStringMixin): + """Configuration of sampling, epochs, parallelization, buffers, collectors, and batching.""" + + num_epochs: int = 100 + """ + the number of epochs to run training for. An epoch is the outermost iteration level and each + epoch consists of a number of training steps and a test step, where each training step + + * collects environment steps/transitions (collection step), adding them to the (replay) + buffer (see :attr:`step_per_collect`) + * performs one or more gradient updates (see :attr:`update_per_step`), + + and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate + agent performance. + + The number of training steps in each epoch is indirectly determined by + :attr:`step_per_epoch`: As many training steps will be performed as are required in + order to reach :attr:`step_per_epoch` total steps in the training environments. + Specifically, if the number of transitions collected per step is `c` (see + :attr:`step_per_collect`) and :attr:`step_per_epoch` is set to `s`, then the number + of training steps per epoch is `ceil(s / c)`. + + Therefore, if `num_epochs = e`, the total number of environment steps taken during training + can be computed as `e * ceil(s / c) * c`. + """ + + step_per_epoch: int = 30000 + """ + the total number of environment steps to be made per epoch. See :attr:`num_epochs` for + an explanation of epoch semantics. 
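+
+    As an illustrative example using the default values: with ``step_per_epoch = 30000`` and a
+    collection step size of ``c = 2048`` transitions (see :attr:`step_per_collect`), each epoch
+    performs ``ceil(30000 / 2048) = 15`` training steps and therefore advances the training
+    environments by ``15 * 2048 = 30720`` steps.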
+ """ + + batch_size: int | None = 64 + """for off-policy algorithms, this is the number of environment steps/transitions to sample + from the buffer for a gradient update; for on-policy algorithms, its use is algorithm-specific. + On-policy algorithms use the full buffer that was collected in the preceding collection step + but they may use this parameter to perform the gradient update using mini-batches of this size + (causing the gradient to be less accurate, a form of regularization). + + ``batch_size=None`` means that the full buffer is used for the gradient update. This doesn't + make much sense for off-policy algorithms and is not recommended then. For on-policy or offline algorithms, + this means that the full buffer is used for the gradient update (no mini-batching), and + may make sense in some cases. + """ + + num_train_envs: int = -1 + """the number of training environments to use. If set to -1, use number of CPUs/threads.""" + + train_seed: int = 42 + """the seed to use for the training environments.""" + + num_test_envs: int = 1 + """the number of test environments to use""" + + num_test_episodes: int = 1 + """the total number of episodes to collect in each test step (across all test environments). + """ + + buffer_size: int = 4096 + """the total size of the sample/replay buffer, in which environment steps (transitions) are + stored""" + + step_per_collect: int = 2048 + """ + the number of environment steps/transitions to collect in each collection step before the + network update within each training step. + Note that the exact number can be reached only if this is a multiple of the number of + training environments being used, as each training environment will produce the same + (non-zero) number of transitions. + Specifically, if this is set to `n` and `m` training environments are used, then the total + number of transitions collected per collection step is `ceil(n / m) * m =: c`. + + See :attr:`num_epochs` for information on the total number of environment steps being + collected during training. + """ + + repeat_per_collect: int | None = 1 + """ + controls, within one gradient update step of an on-policy algorithm, the number of times an + actual gradient update is applied using the full collected dataset, i.e. if the parameter is + 5, then the collected data shall be used five times to update the policy within the same + training step. + + The parameter is ignored and may be set to None for off-policy and offline algorithms. + """ + + update_per_step: float = 1.0 + """ + for off-policy algorithms only: the number of gradient steps to perform per sample + collected (see :attr:`step_per_collect`). + Specifically, if this is set to `u` and the number of samples collected in the preceding + collection step is `n`, then `round(u * n)` gradient steps will be performed. + + Note that for on-policy algorithms, only a single gradient update is usually performed, + because thereafter, the samples no longer reflect the behavior of the updated policy. + To change the number of gradient updates for an on-policy algorithm, use parameter + :attr:`repeat_per_collect` instead. 
+ """ + + start_timesteps: int = 0 + """ + the number of environment steps to collect before the actual training loop begins + """ + + start_timesteps_random: bool = False + """ + whether to use a random policy (instead of the initial or restored policy to be trained) + when collecting the initial :attr:`start_timesteps` environment steps before training + """ + + replay_buffer_ignore_obs_next: bool = False + + replay_buffer_save_only_last_obs: bool = False + """if True, for the case where the environment outputs stacked frames (e.g. because it + is using a `FrameStack` wrapper), save only the most recent frame so as not to duplicate + observations in buffer memory. Specifically, if the environment outputs observations `obs` with + shape (N, ...), only obs[-1] of shape (...) will be stored. + Frame stacking with a fixed number of frames can then be recreated at the buffer level by setting + :attr:`replay_buffer_stack_num`. + """ + + replay_buffer_stack_num: int = 1 + """ + the number of consecutive environment observations to stack and use as the observation input + to the agent for each time step. Setting this to a value greater than 1 can help agents learn + temporal aspects (e.g. velocities of moving objects for which only positions are observed). + + If the environment already stacks frames (e.g. using a `FrameStack` wrapper), this should either not + be used or should be used in conjunction with :attr:`replay_buffer_save_only_last_obs`. + """ + + @property + def test_seed(self) -> int: + return self.train_seed + self.num_train_envs + + def __post_init__(self) -> None: + if self.num_train_envs == -1: + self.num_train_envs = multiprocessing.cpu_count() diff --git a/examples/atari/tianshou/highlevel/env.py b/examples/atari/tianshou/highlevel/env.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c692f64bd6574e92526133cb519c87956f52e0 --- /dev/null +++ b/examples/atari/tianshou/highlevel/env.py @@ -0,0 +1,488 @@ +import logging +import platform +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from enum import Enum +from typing import Any, TypeAlias, cast + +import gymnasium as gym +import gymnasium.spaces +from gymnasium import Env + +from tianshou.env import ( + BaseVectorEnv, + DummyVectorEnv, + RayVectorEnv, + SubprocVectorEnv, +) +from tianshou.highlevel.persistence import Persistence +from tianshou.utils.net.common import TActionShape +from tianshou.utils.string import ToStringMixin + +TObservationShape: TypeAlias = int | Sequence[int] + +log = logging.getLogger(__name__) + + +class EnvType(Enum): + """Enumeration of environment types.""" + + CONTINUOUS = "continuous" + DISCRETE = "discrete" + + def is_discrete(self) -> bool: + return self == EnvType.DISCRETE + + def is_continuous(self) -> bool: + return self == EnvType.CONTINUOUS + + def assert_continuous(self, requiring_entity: Any) -> None: + if not self.is_continuous(): + raise AssertionError(f"{requiring_entity} requires continuous environments") + + def assert_discrete(self, requiring_entity: Any) -> None: + if not self.is_discrete(): + raise AssertionError(f"{requiring_entity} requires discrete environments") + + @staticmethod + def from_env(env: Env) -> "EnvType": + if isinstance(env.action_space, gymnasium.spaces.Discrete): + return EnvType.DISCRETE + elif isinstance(env.action_space, gymnasium.spaces.Box): + return EnvType.CONTINUOUS + else: + raise Exception(f"Unsupported environment type with action space {env.action_space}") + + +class EnvMode(Enum): + 
"""Indicates the purpose for which an environment is created.""" + + TRAIN = "train" + TEST = "test" + WATCH = "watch" + + +class VectorEnvType(Enum): + DUMMY = "dummy" + """Vectorized environment without parallelization; environments are processed sequentially""" + SUBPROC = "subproc" + """Parallelization based on `subprocess`""" + SUBPROC_SHARED_MEM_DEFAULT_CONTEXT = "shmem" + """Parallelization based on `subprocess` with shared memory""" + SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork" + """Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn` + by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)""" + RAY = "ray" + """Parallelization based on the `ray` library""" + SUBPROC_SHARED_MEM_AUTO = "subproc_shared_mem_auto" + """Parallelization based on `subprocess` with shared memory, using default context on windows and fork context otherwise""" + + def create_venv( + self, + factories: Sequence[Callable[[], gym.Env]], + ) -> BaseVectorEnv: + match self: + case VectorEnvType.DUMMY: + return DummyVectorEnv(factories) + case VectorEnvType.SUBPROC: + return SubprocVectorEnv(factories) + case VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT: + return SubprocVectorEnv(factories, share_memory=True) + case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT: + return SubprocVectorEnv(factories, share_memory=True, context="fork") + case VectorEnvType.SUBPROC_SHARED_MEM_AUTO: + if platform.system().lower() == "windows": + selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT + else: + selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT + return selected_venv_type.create_venv(factories) + case VectorEnvType.RAY: + return RayVectorEnv(factories) + case _: + raise NotImplementedError(self) + + +class Environments(ToStringMixin, ABC): + """Represents (vectorized) environments for a learning process.""" + + def __init__( + self, + env: gym.Env, + train_envs: BaseVectorEnv, + test_envs: BaseVectorEnv, + watch_env: BaseVectorEnv | None = None, + ): + self.env = env + self.train_envs = train_envs + self.test_envs = test_envs + self.watch_env = watch_env + self.persistence: Sequence[Persistence] = [] + + @staticmethod + def from_factory_and_type( + factory_fn: Callable[[EnvMode], gym.Env], + env_type: EnvType, + venv_type: VectorEnvType, + num_training_envs: int, + num_test_envs: int, + create_watch_env: bool = False, + ) -> "Environments": + """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete). 
+ + :param factory_fn: the factory for a single environment instance + :param env_type: the type of environments created by `factory_fn` + :param venv_type: the vector environment type to use for parallelization + :param num_training_envs: the number of training environments to create + :param num_test_envs: the number of test environments to create + :param create_watch_env: whether to create an environment for watching the agent + :return: the instance + """ + train_envs = venv_type.create_venv( + [lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs, + ) + test_envs = venv_type.create_venv( + [lambda: factory_fn(EnvMode.TEST)] * num_test_envs, + ) + if create_watch_env: + watch_env = VectorEnvType.DUMMY.create_venv([lambda: factory_fn(EnvMode.WATCH)]) + else: + watch_env = None + env = factory_fn(EnvMode.TRAIN) + match env_type: + case EnvType.CONTINUOUS: + return ContinuousEnvironments(env, train_envs, test_envs, watch_env) + case EnvType.DISCRETE: + return DiscreteEnvironments(env, train_envs, test_envs, watch_env) + case _: + raise ValueError(f"Environment type {env_type} not handled") + + def _tostring_includes(self) -> list[str]: + return [] + + def _tostring_additional_entries(self) -> dict[str, Any]: + return self.info() + + def info(self) -> dict[str, Any]: + return { + "action_shape": self.get_action_shape(), + "state_shape": self.get_observation_shape(), + } + + def set_persistence(self, *p: Persistence) -> None: + """Associates the given persistence handlers which may persist and restore environment-specific information. + + :param p: persistence handlers + """ + self.persistence = p + + @abstractmethod + def get_action_shape(self) -> TActionShape: + pass + + @abstractmethod + def get_observation_shape(self) -> TObservationShape: + pass + + def get_action_space(self) -> gym.Space: + return self.env.action_space + + def get_observation_space(self) -> gym.Space: + return self.env.observation_space + + @abstractmethod + def get_type(self) -> EnvType: + pass + + +class ContinuousEnvironments(Environments): + """Represents (vectorized) continuous environments.""" + + def __init__( + self, + env: gym.Env, + train_envs: BaseVectorEnv, + test_envs: BaseVectorEnv, + watch_env: BaseVectorEnv | None = None, + ): + super().__init__(env, train_envs, test_envs, watch_env) + self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env) + + @staticmethod + def from_factory( + factory_fn: Callable[[EnvMode], gym.Env], + venv_type: VectorEnvType, + num_training_envs: int, + num_test_envs: int, + create_watch_env: bool = False, + ) -> "ContinuousEnvironments": + """Creates an instance from a factory function that creates a single instance. 
+ + :param factory_fn: the factory for a single environment instance + :param venv_type: the vector environment type to use for parallelization + :param num_training_envs: the number of training environments to create + :param num_test_envs: the number of test environments to create + :param create_watch_env: whether to create an environment for watching the agent + :return: the instance + """ + return cast( + ContinuousEnvironments, + Environments.from_factory_and_type( + factory_fn, + EnvType.CONTINUOUS, + venv_type, + num_training_envs, + num_test_envs, + create_watch_env, + ), + ) + + def info(self) -> dict[str, Any]: + d = super().info() + d["max_action"] = self.max_action + return d + + @staticmethod + def _get_continuous_env_info( + env: gym.Env, + ) -> tuple[tuple[int, ...], tuple[int, ...], float]: + if not isinstance(env.action_space, gym.spaces.Box): + raise ValueError( + "Only environments with continuous action space are supported here. " + f"But got env with action space: {env.action_space.__class__}.", + ) + state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + if not state_shape: + raise ValueError("Observation space shape is not defined") + action_shape = env.action_space.shape + max_action = env.action_space.high[0] + return state_shape, action_shape, max_action + + def get_action_shape(self) -> TActionShape: + return self.action_shape + + def get_observation_shape(self) -> TObservationShape: + return self.state_shape + + def get_type(self) -> EnvType: + return EnvType.CONTINUOUS + + +class DiscreteEnvironments(Environments): + """Represents (vectorized) discrete environments.""" + + def __init__( + self, + env: gym.Env, + train_envs: BaseVectorEnv, + test_envs: BaseVectorEnv, + watch_env: BaseVectorEnv | None = None, + ): + super().__init__(env, train_envs, test_envs, watch_env) + self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore + self.action_shape = env.action_space.shape or env.action_space.n # type: ignore + + @staticmethod + def from_factory( + factory_fn: Callable[[EnvMode], gym.Env], + venv_type: VectorEnvType, + num_training_envs: int, + num_test_envs: int, + create_watch_env: bool = False, + ) -> "DiscreteEnvironments": + """Creates an instance from a factory function that creates a single instance. + + :param factory_fn: the factory for a single environment instance + :param venv_type: the vector environment type to use for parallelization + :param num_training_envs: the number of training environments to create + :param num_test_envs: the number of test environments to create + :param create_watch_env: whether to create an environment for watching the agent + :return: the instance + """ + return cast( + DiscreteEnvironments, + Environments.from_factory_and_type( + factory_fn, + EnvType.DISCRETE, + venv_type, + num_training_envs, + num_test_envs, + create_watch_env, + ), + ) + + def get_action_shape(self) -> TActionShape: + return self.action_shape + + def get_observation_shape(self) -> TObservationShape: + return self.observation_shape + + def get_type(self) -> EnvType: + return EnvType.DISCRETE + + +class EnvPoolFactory: + """A factory for the creation of envpool-based vectorized environments, which can be used in conjunction + with :class:`EnvFactoryRegistered`. + """ + + def _transform_task(self, task: str) -> str: + return task + + def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict: + """Transforms gymnasium keyword arguments to be envpool-compatible. 
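+
+        For instance, ``render_mode`` is removed below, since it is a gymnasium-specific argument
+        that is not forwarded to envpool.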
+ + :param kwargs: keyword arguments that would normally be passed to `gymnasium.make`. + :param mode: the environment mode + :return: the transformed keyword arguments + """ + kwargs = dict(kwargs) + if "render_mode" in kwargs: + del kwargs["render_mode"] + return kwargs + + def create_venv( + self, + task: str, + num_envs: int, + mode: EnvMode, + seed: int, + kwargs: dict, + ) -> BaseVectorEnv: + import envpool + + envpool_task = self._transform_task(task) + envpool_kwargs = self._transform_kwargs(kwargs, mode) + return envpool.make_gymnasium( + envpool_task, + num_envs=num_envs, + seed=seed, + **envpool_kwargs, + ) + + +class EnvFactory(ToStringMixin, ABC): + """Main interface for the creation of environments (in various forms).""" + + def __init__(self, venv_type: VectorEnvType): + """:param venv_type: the type of vectorized environment to use for train and test environments. + watch environments are always created as dummy environments. + """ + self.venv_type = venv_type + + @abstractmethod + def create_env(self, mode: EnvMode) -> Env: + pass + + def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: + """Create vectorized environments. + + :param num_envs: the number of environments + :param mode: the mode for which to create. In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env. + + :return: the vectorized environments + """ + if mode == EnvMode.WATCH: + return VectorEnvType.DUMMY.create_venv([lambda: self.create_env(mode)]) + else: + return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs) + + def create_envs( + self, + num_training_envs: int, + num_test_envs: int, + create_watch_env: bool = False, + ) -> Environments: + """Create environments for learning. + + :param num_training_envs: the number of training environments + :param num_test_envs: the number of test environments + :param create_watch_env: whether to create an environment for watching the agent + :return: the environments + """ + env = self.create_env(EnvMode.TRAIN) + train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN) + test_envs = self.create_venv(num_test_envs, EnvMode.TEST) + watch_env = self.create_venv(1, EnvMode.WATCH) if create_watch_env else None + match EnvType.from_env(env): + case EnvType.DISCRETE: + return DiscreteEnvironments(env, train_envs, test_envs, watch_env) + case EnvType.CONTINUOUS: + return ContinuousEnvironments(env, train_envs, test_envs, watch_env) + case _: + raise ValueError + + +class EnvFactoryRegistered(EnvFactory): + """Factory for environments that are registered with gymnasium and thus can be created via `gymnasium.make` + (or via `envpool.make_gymnasium`). + """ + + def __init__( + self, + *, + task: str, + train_seed: int, + test_seed: int, + venv_type: VectorEnvType, + envpool_factory: EnvPoolFactory | None = None, + render_mode_train: str | None = None, + render_mode_test: str | None = None, + render_mode_watch: str = "human", + **make_kwargs: Any, + ): + """:param task: the gymnasium task/environment identifier + :param seed: the random seed + :param venv_type: the type of vectorized environment to use (if `envpool_factory` is not specified) + :param envpool_factory: the factory to use for vectorized environment creation based on envpool; envpool must be installed. 
+ :param render_mode_train: the render mode to use for training environments + :param render_mode_test: the render mode to use for test environments + :param render_mode_watch: the render mode to use for environments that are used to watch agent performance + :param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`. + If envpool is used, the gymnasium parameters will be appropriately translated for use with + `envpool.make_gymnasium`. + """ + super().__init__(venv_type) + self.task = task + self.envpool_factory = envpool_factory + self.train_seed = train_seed + self.test_seed = test_seed + self.render_modes = { + EnvMode.TRAIN: render_mode_train, + EnvMode.TEST: render_mode_test, + EnvMode.WATCH: render_mode_watch, + } + self.make_kwargs = make_kwargs + + def _create_kwargs(self, mode: EnvMode) -> dict: + """Adapts the keyword arguments for the given mode. + + :param mode: the mode + :return: adapted keyword arguments + """ + kwargs = dict(self.make_kwargs) + kwargs["render_mode"] = self.render_modes.get(mode) + return kwargs + + def create_env(self, mode: EnvMode) -> Env: + """Creates a single environment for the given mode. + + :param mode: the mode + :return: an environment + """ + kwargs = self._create_kwargs(mode) + return gymnasium.make(self.task, **kwargs) + + def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: + seed = self.train_seed if mode == EnvMode.TRAIN else self.test_seed + if self.envpool_factory is not None: + return self.envpool_factory.create_venv( + self.task, + num_envs, + mode, + seed, + self._create_kwargs(mode), + ) + else: + venv = super().create_venv(num_envs, mode) + venv.seed(seed) + return venv diff --git a/examples/atari/tianshou/highlevel/experiment.py b/examples/atari/tianshou/highlevel/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..dbcd3f1565a38f78833eeec17543fc1b34c9cd50 --- /dev/null +++ b/examples/atari/tianshou/highlevel/experiment.py @@ -0,0 +1,1246 @@ +import os +import pickle +from abc import abstractmethod +from collections.abc import Sequence +from copy import deepcopy +from dataclasses import dataclass +from pprint import pformat +from typing import TYPE_CHECKING, Any, Self, Union, cast + +import numpy as np +import torch + +from tianshou.data import Collector, InfoStats +from tianshou.env import BaseVectorEnv +from tianshou.highlevel.agent import ( + A2CAgentFactory, + AgentFactory, + DDPGAgentFactory, + DiscreteSACAgentFactory, + DQNAgentFactory, + IQNAgentFactory, + NPGAgentFactory, + PGAgentFactory, + PPOAgentFactory, + REDQAgentFactory, + SACAgentFactory, + TD3AgentFactory, + TRPOAgentFactory, +) +from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.env import EnvFactory +from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger +from tianshou.highlevel.module.actor import ( + ActorFactory, + ActorFactoryDefault, + ActorFactoryTransientStorageDecorator, + ActorFuture, + ActorFutureProviderProtocol, + ContinuousActorType, + IntermediateModuleFactoryFromActorFactory, +) +from tianshou.highlevel.module.core import ( + TDevice, +) +from tianshou.highlevel.module.critic import ( + CriticEnsembleFactory, + CriticEnsembleFactoryDefault, + CriticFactory, + CriticFactoryDefault, + CriticFactoryReuseActor, +) +from tianshou.highlevel.module.intermediate import IntermediateModuleFactory +from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory +from tianshou.highlevel.optim import OptimizerFactory, 
OptimizerFactoryAdam +from tianshou.highlevel.params.policy_params import ( + A2CParams, + DDPGParams, + DiscreteSACParams, + DQNParams, + IQNParams, + NPGParams, + PGParams, + PPOParams, + REDQParams, + SACParams, + TD3Params, + TRPOParams, +) +from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory +from tianshou.highlevel.persistence import ( + PersistenceGroup, + PolicyPersistence, +) +from tianshou.highlevel.trainer import ( + EpochStopCallback, + EpochTestCallback, + EpochTrainCallback, + TrainerCallbacks, +) +from tianshou.highlevel.world import World +from tianshou.policy import BasePolicy +from tianshou.utils import LazyLogger, deprecation, logging +from tianshou.utils.logging import datetime_tag +from tianshou.utils.net.common import ModuleType +from tianshou.utils.string import ToStringMixin + +if TYPE_CHECKING: + from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher + +log = logging.getLogger(__name__) + + +@dataclass +class ExperimentConfig: + """Generic config for setting up the experiment, not RL or training specific.""" + + seed: int = 42 + """The random seed with which to initialize random number generators.""" + device: TDevice = "cuda" if torch.cuda.is_available() else "cpu" + """The torch device to use""" + policy_restore_directory: str | None = None + """Directory from which to load the policy neural network parameters (persistence directory of a previous run)""" + train: bool = True + """Whether to perform training""" + watch: bool = True + """Whether to watch agent performance (after training)""" + watch_num_episodes: int = 10 + """Number of episodes for which to watch performance (if `watch` is enabled)""" + watch_render: float = 0.0 + """Milliseconds between rendered frames when watching agent performance (if `watch` is enabled)""" + persistence_base_dir: str = "log" + """Base directory in which experiment data is to be stored. Every experiment run will create a subdirectory + in this directory based on the run's experiment name""" + persistence_enabled: bool = True + """Whether persistence is enabled, allowing files to be stored""" + log_file_enabled: bool = True + """Whether to write to a log file; has no effect if `persistence_enabled` is False. + Disable this if you have externally configured log file generation.""" + policy_persistence_mode: PolicyPersistence.Mode = PolicyPersistence.Mode.POLICY + """Controls the way in which the policy is persisted""" + + +@dataclass +class ExperimentResult: + """Contains the results of an experiment.""" + + world: World + """contains all the essential instances of the experiment""" + trainer_result: InfoStats | None + """dataclass of results as returned by the trainer (if any)""" + + +class Experiment(ToStringMixin): + """Represents a reinforcement learning experiment. + + An experiment is composed only of configuration and factory objects, which themselves + should be designed to contain only configuration. Therefore, experiments can easily + be stored/pickled and later restored without any problems. 
+ """ + + LOG_FILENAME = "log.txt" + EXPERIMENT_PICKLE_FILENAME = "experiment.pkl" + + def __init__( + self, + config: ExperimentConfig, + env_factory: EnvFactory, + agent_factory: AgentFactory, + sampling_config: SamplingConfig, + name: str, + logger_factory: LoggerFactory | None = None, + ): + if logger_factory is None: + logger_factory = LoggerFactoryDefault() + self.config = config + self.sampling_config = sampling_config + self.env_factory = env_factory + self.agent_factory = agent_factory + self.logger_factory = logger_factory + self.name = name + + @classmethod + def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment": + """Restores an experiment from a previously stored pickle. + + :param directory: persistence directory of a previous run, in which a pickled experiment is found + :param restore_policy: whether the experiment shall be configured to restore the policy that was + persisted in the given directory + """ + with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f: + experiment: Experiment = pickle.load(f) + if restore_policy: + experiment.config.policy_restore_directory = directory + return experiment + + def get_seeding_info_as_str(self) -> str: + """Returns information on the seeds used in the experiment as a string. + + This can be useful for creating unique experiment names based on seeds, e.g. + A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`. + """ + return "_".join( + [ + f"exp_seed={self.config.seed}", + f"train_seed={self.sampling_config.train_seed}", + f"test_seed={self.sampling_config.test_seed}", + ], + ) + + def _set_seed(self) -> None: + seed = self.config.seed + log.info(f"Setting random seed {seed}") + np.random.seed(seed) + torch.manual_seed(seed) + + def _build_config_dict(self) -> dict: + return {"experiment": self.pprints()} + + def save(self, directory: str) -> None: + path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME) + log.info( + f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')", + ) + with open(path, "wb") as f: + pickle.dump(self, f) + + def run( + self, + run_name: str | None = None, + logger_run_id: str | None = None, + raise_error_on_dirname_collision: bool = True, + **kwargs: dict[str, Any], + ) -> ExperimentResult: + """Run the experiment and return the results. + + :param run_name: Defines a name for this run of the experiment, which determines + the subdirectory (within the persistence base directory) where all results will be saved. + If None, the experiment's name will be used. + The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case + a nested directory structure will be created. + :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when + using wandb, in particular). + :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed + experiment with the same name. + :param kwargs: for backward compatibility with old parameter names only + :return: + """ + # backward compatibility + _experiment_name = kwargs.pop("experiment_name", None) + if _experiment_name is not None: + run_name = cast(str, _experiment_name) + deprecation( + "Parameter run_name should now be used instead of experiment_name. 
" + "Support for experiment_name will be removed in the future.", + ) + assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}" + + if run_name is None: + run_name = self.name + + # initialize persistence directory + use_persistence = self.config.persistence_enabled + persistence_dir = os.path.join(self.config.persistence_base_dir, run_name) + if use_persistence: + os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision) + + with logging.FileLoggerContext( + os.path.join(persistence_dir, self.LOG_FILENAME), + enabled=use_persistence and self.config.log_file_enabled, + ): + # log initial information + log.info(f"Running experiment (name='{run_name}'):\n{self.pprints()}") + log.info(f"Working directory: {os.getcwd()}") + + self._set_seed() + + # create environments + envs = self.env_factory.create_envs( + self.sampling_config.num_train_envs, + self.sampling_config.num_test_envs, + create_watch_env=self.config.watch, + ) + log.info(f"Created {envs}") + + # initialize persistence + additional_persistence = PersistenceGroup(*envs.persistence, enabled=use_persistence) + policy_persistence = PolicyPersistence( + additional_persistence, + enabled=use_persistence, + mode=self.config.policy_persistence_mode, + ) + if use_persistence: + log.info(f"Persistence directory: {os.path.abspath(persistence_dir)}") + self.save(persistence_dir) + + # initialize logger + full_config = self._build_config_dict() + full_config.update(envs.info()) + logger: TLogger + if use_persistence: + logger = self.logger_factory.create_logger( + log_dir=persistence_dir, + experiment_name=run_name, + run_id=logger_run_id, + config_dict=full_config, + ) + else: + logger = LazyLogger() + + # create policy and collectors + log.info("Creating policy") + policy = self.agent_factory.create_policy(envs, self.config.device) + log.info("Creating collectors") + train_collector, test_collector = self.agent_factory.create_train_test_collector( + policy, + envs, + ) + + # create context object with all relevant instances (except trainer; added later) + world = World( + envs=envs, + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + logger=logger, + persist_directory=persistence_dir, + restore_directory=self.config.policy_restore_directory, + ) + + # restore policy parameters if applicable + if self.config.policy_restore_directory: + policy_persistence.restore( + policy, + world, + self.config.device, + ) + + # train policy + log.info("Starting training") + trainer_result: InfoStats | None = None + if self.config.train: + trainer = self.agent_factory.create_trainer(world, policy_persistence) + world.trainer = trainer + trainer_result = trainer.run() + log.info(f"Training result:\n{pformat(trainer_result)}") + + # watch agent performance + if self.config.watch: + assert envs.watch_env is not None + log.info("Watching agent performance") + self._watch_agent( + self.config.watch_num_episodes, + policy, + envs.watch_env, + self.config.watch_render, + ) + + return ExperimentResult(world=world, trainer_result=trainer_result) + + @staticmethod + def _watch_agent( + num_episodes: int, + policy: BasePolicy, + env: BaseVectorEnv, + render: float, + ) -> None: + collector = Collector(policy, env) + collector.reset() + result = collector.collect(n_episode=num_episodes, render=render) + assert result.returns_stat is not None # for mypy + assert result.lens_stat is not None # for mypy + log.info( + f"Watched episodes: mean reward={result.returns_stat.mean}, mean episode 
length={result.lens_stat.mean}", + ) + + +class ExperimentCollection: + """Shallow wrapper around a list of experiments providing a simple interface for running them with a launcher.""" + + def __init__(self, experiments: list[Experiment]): + self.experiments = experiments + + def run( + self, + launcher: Union["ExpLauncher", "RegisteredExpLauncher"], + ) -> list[InfoStats | None]: + from tianshou.evaluation.launcher import RegisteredExpLauncher + + if isinstance(launcher, RegisteredExpLauncher): + launcher = launcher.create_launcher() + return launcher.launch(experiments=self.experiments) + + +class ExperimentBuilder: + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + if experiment_config is None: + experiment_config = ExperimentConfig() + if sampling_config is None: + sampling_config = SamplingConfig() + self._config = experiment_config + self._env_factory = env_factory + self._sampling_config = sampling_config + self._logger_factory: LoggerFactory | None = None + self._optim_factory: OptimizerFactory | None = None + self._policy_wrapper_factory: PolicyWrapperFactory | None = None + self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() + self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag() + + def copy(self) -> Self: + return deepcopy(self) + + @property + def experiment_config(self) -> ExperimentConfig: + return self._config + + @experiment_config.setter + def experiment_config(self, experiment_config: ExperimentConfig) -> None: + self._config = experiment_config + + @property + def sampling_config(self) -> SamplingConfig: + return self._sampling_config + + @sampling_config.setter + def sampling_config(self, sampling_config: SamplingConfig) -> None: + self._sampling_config = sampling_config + + def with_logger_factory(self, logger_factory: LoggerFactory) -> Self: + """Allows to customize the logger factory to use. + + If this method is not called, the default logger factory :class:`LoggerFactoryDefault` will be used. + + :param logger_factory: the factory to use + :return: the builder + """ + self._logger_factory = logger_factory + return self + + def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self: + """Allows to define a wrapper around the policy that is created, extending the original policy. + + :param policy_wrapper_factory: the factory for the wrapper + :return: the builder + """ + self._policy_wrapper_factory = policy_wrapper_factory + return self + + def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self: + """Allows to customize the gradient-based optimizer to use. + + By default, :class:`OptimizerFactoryAdam` will be used with default parameters. + + :param optim_factory: the optimizer factory + :return: the builder + """ + self._optim_factory = optim_factory + return self + + def with_optim_factory_default( + self, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-08, + weight_decay: float = 0, + ) -> Self: + """Configures the use of the default optimizer, Adam, with the given parameters. 
+ + :param betas: coefficients used for computing running averages of gradient and its square + :param eps: term added to the denominator to improve numerical stability + :param weight_decay: weight decay (L2 penalty) + :return: the builder + """ + self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay) + return self + + def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self: + """Allows to define a callback function which is called at the beginning of every epoch during training. + + :param callback: the callback + :return: the builder + """ + self._trainer_callbacks.epoch_train_callback = callback + return self + + def with_epoch_test_callback(self, callback: EpochTestCallback) -> Self: + """Allows to define a callback function which is called at the beginning of testing in each epoch. + + :param callback: the callback + :return: the builder + """ + self._trainer_callbacks.epoch_test_callback = callback + return self + + def with_epoch_stop_callback(self, callback: EpochStopCallback) -> Self: + """Allows to define a callback that decides whether training shall stop early. + + The callback receives the undiscounted returns of the testing result. + + :param callback: the callback + :return: the builder + """ + self._trainer_callbacks.epoch_stop_callback = callback + return self + + def with_name( + self, + name: str, + ) -> Self: + """Sets the name of the experiment. + + :param name: the name to use for this experiment, which, when the experiment is run, + will determine the storage sub-folder by default + :return: the builder + """ + self._name = name + return self + + @abstractmethod + def _create_agent_factory(self) -> AgentFactory: + pass + + def _get_optim_factory(self) -> OptimizerFactory: + if self._optim_factory is None: + return OptimizerFactoryAdam() + else: + return self._optim_factory + + def build(self) -> Experiment: + """Creates the experiment based on the options specified via this builder. + + :return: the experiment + """ + agent_factory = self._create_agent_factory() + agent_factory.set_trainer_callbacks(self._trainer_callbacks) + if self._policy_wrapper_factory: + agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) + experiment: Experiment = Experiment( + config=self._config, + env_factory=self._env_factory, + agent_factory=agent_factory, + sampling_config=self._sampling_config, + name=self._name, + logger_factory=self._logger_factory, + ) + return experiment + + def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: + """Creates a collection of experiments with non-overlapping random seeds, starting from the configured seed. + + Each experiment in the collection will have a unique name that is created from the original experiment name and the seeds used. 
+ """ + num_train_envs = self.sampling_config.num_train_envs + + seeded_experiments = [] + for i in range(num_experiments): + builder = self.copy() + builder.experiment_config.seed += i + builder.sampling_config.train_seed += i * num_train_envs + experiment = builder.build() + experiment.name += f"_{experiment.get_seeding_info_as_str()}" + seeded_experiments.append(experiment) + return ExperimentCollection(seeded_experiments) + + +class _BuilderMixinActorFactory(ActorFutureProviderProtocol): + def __init__(self, continuous_actor_type: ContinuousActorType): + self._continuous_actor_type = continuous_actor_type + self._actor_future = ActorFuture() + self._actor_factory: ActorFactory | None = None + + def with_actor_factory(self, actor_factory: ActorFactory) -> Self: + """Allows to customize the actor component via the specification of a factory. + + If this function is not called, a default actor factory (with default parameters) will be used. + + :param actor_factory: the factory to use for the creation of the actor network + :return: the builder + """ + self._actor_factory = actor_factory + return self + + def _with_actor_factory_default( + self, + hidden_sizes: Sequence[int], + hidden_activation: ModuleType = torch.nn.ReLU, + continuous_unbounded: bool = False, + continuous_conditioned_sigma: bool = False, + ) -> Self: + """Adds a default actor factory with the given parameters. + + :param hidden_sizes: the sequence of hidden dimensions to use in the network structure + :param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits + :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma) + shall be computed from the input; if False, sigma is an independent parameter. + :return: the builder + """ + self._actor_factory = ActorFactoryDefault( + self._continuous_actor_type, + hidden_sizes, + hidden_activation=hidden_activation, + continuous_unbounded=continuous_unbounded, + continuous_conditioned_sigma=continuous_conditioned_sigma, + ) + return self + + def get_actor_future(self) -> ActorFuture: + """:return: an object, which, in the future, will contain the actor instance that is created for the experiment.""" + return self._actor_future + + def _get_actor_factory(self) -> ActorFactory: + actor_factory: ActorFactory + if self._actor_factory is None: + actor_factory = ActorFactoryDefault(self._continuous_actor_type) + else: + actor_factory = self._actor_factory + return ActorFactoryTransientStorageDecorator(actor_factory, self._actor_future) + + +class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory): + """Specialization of the actor mixin where, in the continuous case, the actor component outputs Gaussian distribution parameters.""" + + def __init__(self) -> None: + super().__init__(ContinuousActorType.GAUSSIAN) + + def with_actor_factory_default( + self, + hidden_sizes: Sequence[int], + hidden_activation: ModuleType = torch.nn.ReLU, + continuous_unbounded: bool = False, + continuous_conditioned_sigma: bool = False, + ) -> Self: + """Defines use of the default actor factory, allowing its parameters it to be customized. + + The default actor factory uses an MLP-style architecture. 
+ + :param hidden_sizes: dimensions of hidden layers used by the network + :param hidden_activation: the activation function to use for hidden layers + :param continuous_unbounded: whether, for continuous action spaces, to apply tanh activation on final logits + :param continuous_conditioned_sigma: whether, for continuous action spaces, the standard deviation of continuous actions (sigma) + shall be computed from the input; if False, sigma is an independent parameter. + :return: the builder + """ + return super()._with_actor_factory_default( + hidden_sizes, + hidden_activation=hidden_activation, + continuous_unbounded=continuous_unbounded, + continuous_conditioned_sigma=continuous_conditioned_sigma, + ) + + +class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory): + """Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy.""" + + def __init__(self) -> None: + super().__init__(ContinuousActorType.DETERMINISTIC) + + def with_actor_factory_default( + self, + hidden_sizes: Sequence[int], + hidden_activation: ModuleType = torch.nn.ReLU, + ) -> Self: + """Defines use of the default actor factory, allowing its parameters it to be customized. + + The default actor factory uses an MLP-style architecture. + + :param hidden_sizes: dimensions of hidden layers used by the network + :param hidden_activation: the activation function to use for hidden layers + :return: the builder + """ + return super()._with_actor_factory_default(hidden_sizes, hidden_activation) + + +class _BuilderMixinCriticsFactory: + def __init__(self, num_critics: int, actor_future_provider: ActorFutureProviderProtocol): + self._actor_future_provider = actor_future_provider + self._critic_factories: list[CriticFactory | None] = [None] * num_critics + + def _with_critic_factory(self, idx: int, critic_factory: CriticFactory) -> Self: + self._critic_factories[idx] = critic_factory + return self + + def _with_critic_factory_default( + self, + idx: int, + hidden_sizes: Sequence[int], + hidden_activation: ModuleType = torch.nn.ReLU, + ) -> Self: + self._critic_factories[idx] = CriticFactoryDefault( + hidden_sizes, + hidden_activation=hidden_activation, + ) + return self + + def _with_critic_factory_use_actor(self, idx: int) -> Self: + self._critic_factories[idx] = CriticFactoryReuseActor( + self._actor_future_provider.get_actor_future(), + ) + return self + + def _get_critic_factory(self, idx: int) -> CriticFactory: + factory = self._critic_factories[idx] + if factory is None: + return CriticFactoryDefault() + else: + return factory + + +class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): + def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None: + super().__init__(1, actor_future_provider) + + def with_critic_factory(self, critic_factory: CriticFactory) -> Self: + """Specifies that the given factory shall be used for the critic. + + :param critic_factory: the critic factory + :return: the builder + """ + self._with_critic_factory(0, critic_factory) + return self + + def with_critic_factory_default( + self, + hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = torch.nn.ReLU, + ) -> Self: + """Makes the critic use the default, MLP-style architecture with the given parameters. 
+ + :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network + :param hidden_activation: the activation function to use for hidden layers + :return: the builder + """ + self._with_critic_factory_default(0, hidden_sizes, hidden_activation) + return self + + +class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory): + def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None: + super().__init__(actor_future_provider) + + def with_critic_factory_use_actor(self) -> Self: + """Makes the first critic reuse the actor's preprocessing network (parameter sharing).""" + return self._with_critic_factory_use_actor(0) + + +class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory): + def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None: + super().__init__(2, actor_future_provider) + + def with_common_critic_factory(self, critic_factory: CriticFactory) -> Self: + """Specifies that the given factory shall be used for both critics. + + :param critic_factory: the critic factory + :return: the builder + """ + for i in range(len(self._critic_factories)): + self._with_critic_factory(i, critic_factory) + return self + + def with_common_critic_factory_default( + self, + hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = torch.nn.ReLU, + ) -> Self: + """Makes both critics use the default, MLP-style architecture with the given parameters. + + :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network + :param hidden_activation: the activation function to use for hidden layers + :return: the builder + """ + for i in range(len(self._critic_factories)): + self._with_critic_factory_default(i, hidden_sizes, hidden_activation) + return self + + def with_common_critic_factory_use_actor(self) -> Self: + """Makes both critics reuse the actor's preprocessing network (parameter sharing).""" + for i in range(len(self._critic_factories)): + self._with_critic_factory_use_actor(i) + return self + + def with_critic1_factory(self, critic_factory: CriticFactory) -> Self: + """Specifies that the given factory shall be used for the first critic. + + :param critic_factory: the critic factory + :return: the builder + """ + self._with_critic_factory(0, critic_factory) + return self + + def with_critic1_factory_default( + self, + hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = torch.nn.ReLU, + ) -> Self: + """Makes the first critic use the default, MLP-style architecture with the given parameters. + + :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network + :param hidden_activation: the activation function to use for hidden layers + :return: the builder + """ + self._with_critic_factory_default(0, hidden_sizes, hidden_activation) + return self + + def with_critic1_factory_use_actor(self) -> Self: + """Makes the first critic reuse the actor's preprocessing network (parameter sharing).""" + return self._with_critic_factory_use_actor(0) + + def with_critic2_factory(self, critic_factory: CriticFactory) -> Self: + """Specifies that the given factory shall be used for the second critic. 
+ + :param critic_factory: the critic factory + :return: the builder + """ + self._with_critic_factory(1, critic_factory) + return self + + def with_critic2_factory_default( + self, + hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = torch.nn.ReLU, + ) -> Self: + """Makes the second critic use the default, MLP-style architecture with the given parameters. + + :param hidden_sizes: the sequence of dimensions to use in hidden layers of the network + :param hidden_activation: the activation function to use for hidden layers + :return: the builder + """ + self._with_critic_factory_default(1, hidden_sizes, hidden_activation) + return self + + def with_critic2_factory_use_actor(self) -> Self: + """Makes the first critic reuse the actor's preprocessing network (parameter sharing).""" + return self._with_critic_factory_use_actor(1) + + +class _BuilderMixinCriticEnsembleFactory: + def __init__(self) -> None: + self.critic_ensemble_factory: CriticEnsembleFactory | None = None + + def with_critic_ensemble_factory(self, factory: CriticEnsembleFactory) -> Self: + """Specifies that the given factory shall be used for the critic ensemble. + + If unspecified, the default factory (:class:`CriticEnsembleFactoryDefault`) is used. + + :param factory: the critic ensemble factory + :return: the builder + """ + self.critic_ensemble_factory = factory + return self + + def with_critic_ensemble_factory_default( + self, + hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES, + ) -> Self: + """Allows to customize the parameters of the default critic ensemble factory. + + :param hidden_sizes: the sequence of sizes of hidden layers in the network architecture + :return: the builder + """ + self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes) + return self + + def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory: + if self.critic_ensemble_factory is None: + return CriticEnsembleFactoryDefault() + else: + return self.critic_ensemble_factory + + +class PGExperimentBuilder( + ExperimentBuilder, + _BuilderMixinActorFactory_ContinuousGaussian, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) + self._params: PGParams = PGParams() + self._env_config = None + + def with_pg_params(self, params: PGParams) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return PGAgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_optim_factory(), + ) + + +class A2CExperimentBuilder( + ExperimentBuilder, + _BuilderMixinActorFactory_ContinuousGaussian, + _BuilderMixinSingleCriticCanUseActorFactory, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) + _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) + self._params: A2CParams = A2CParams() + self._env_config = None + + def with_a2c_params(self, params: A2CParams) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return A2CAgentFactory( + self._params, + 
self._sampling_config, + self._get_actor_factory(), + self._get_critic_factory(0), + self._get_optim_factory(), + ) + + +class PPOExperimentBuilder( + ExperimentBuilder, + _BuilderMixinActorFactory_ContinuousGaussian, + _BuilderMixinSingleCriticCanUseActorFactory, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) + _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) + self._params: PPOParams = PPOParams() + + def with_ppo_params(self, params: PPOParams) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return PPOAgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_critic_factory(0), + self._get_optim_factory(), + ) + + +class NPGExperimentBuilder( + ExperimentBuilder, + _BuilderMixinActorFactory_ContinuousGaussian, + _BuilderMixinSingleCriticCanUseActorFactory, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) + _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) + self._params: NPGParams = NPGParams() + + def with_npg_params(self, params: NPGParams) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return NPGAgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_critic_factory(0), + self._get_optim_factory(), + ) + + +class TRPOExperimentBuilder( + ExperimentBuilder, + _BuilderMixinActorFactory_ContinuousGaussian, + _BuilderMixinSingleCriticCanUseActorFactory, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) + _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) + self._params: TRPOParams = TRPOParams() + + def with_trpo_params(self, params: TRPOParams) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return TRPOAgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_critic_factory(0), + self._get_optim_factory(), + ) + + +class DQNExperimentBuilder( + ExperimentBuilder, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + self._params: DQNParams = DQNParams() + self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory( + ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False), + ) + + def with_dqn_params(self, params: DQNParams) -> Self: + self._params = params + return self + + def with_model_factory(self, module_factory: IntermediateModuleFactory) -> Self: + """:param module_factory: factory for a module which maps environment observations to a vector of Q-values (one for each action) + :return: the builder + """ + 
self._model_factory = module_factory + return self + + def with_model_factory_default( + self, + hidden_sizes: Sequence[int], + hidden_activation: ModuleType = torch.nn.ReLU, + ) -> Self: + """Allows to configure the default factory for the model of the Q function, which maps environment observations to a vector of + Q-values (one for each action). The default model is a multi-layer perceptron. + + :param hidden_sizes: the sequence of dimensions used for hidden layers + :param hidden_activation: the activation function to use for hidden layers (not used for the output layer) + :return: the builder + """ + self._model_factory = IntermediateModuleFactoryFromActorFactory( + ActorFactoryDefault( + ContinuousActorType.UNSUPPORTED, + hidden_sizes=hidden_sizes, + hidden_activation=hidden_activation, + discrete_softmax=False, + ), + ) + return self + + def _create_agent_factory(self) -> AgentFactory: + return DQNAgentFactory( + self._params, + self._sampling_config, + self._model_factory, + self._get_optim_factory(), + ) + + +class IQNExperimentBuilder(ExperimentBuilder): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + self._params: IQNParams = IQNParams() + self._preprocess_network_factory: IntermediateModuleFactory = ( + IntermediateModuleFactoryFromActorFactory( + ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False), + ) + ) + + def with_iqn_params(self, params: IQNParams) -> Self: + self._params = params + return self + + def with_preprocess_network_factory(self, module_factory: IntermediateModuleFactory) -> Self: + self._preprocess_network_factory = module_factory + return self + + def _create_agent_factory(self) -> AgentFactory: + model_factory = ImplicitQuantileNetworkFactory( + self._preprocess_network_factory, + hidden_sizes=self._params.hidden_sizes, + num_cosines=self._params.num_cosines, + ) + return IQNAgentFactory( + self._params, + self._sampling_config, + model_factory, + self._get_optim_factory(), + ) + + +class DDPGExperimentBuilder( + ExperimentBuilder, + _BuilderMixinActorFactory_ContinuousDeterministic, + _BuilderMixinSingleCriticCanUseActorFactory, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) + _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) + self._params: DDPGParams = DDPGParams() + + def with_ddpg_params(self, params: DDPGParams) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return DDPGAgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_critic_factory(0), + self._get_optim_factory(), + ) + + +class REDQExperimentBuilder( + ExperimentBuilder, + _BuilderMixinActorFactory_ContinuousGaussian, + _BuilderMixinCriticEnsembleFactory, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) + _BuilderMixinCriticEnsembleFactory.__init__(self) + self._params: REDQParams = REDQParams() + + 
def with_redq_params(self, params: REDQParams) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return REDQAgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_critic_ensemble_factory(), + self._get_optim_factory(), + ) + + +class SACExperimentBuilder( + ExperimentBuilder, + _BuilderMixinActorFactory_ContinuousGaussian, + _BuilderMixinDualCriticFactory, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) + _BuilderMixinDualCriticFactory.__init__(self, self) + self._params: SACParams = SACParams() + + def with_sac_params(self, params: SACParams) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return SACAgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_critic_factory(0), + self._get_critic_factory(1), + self._get_optim_factory(), + ) + + +class DiscreteSACExperimentBuilder( + ExperimentBuilder, + _BuilderMixinActorFactory, + _BuilderMixinDualCriticFactory, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED) + _BuilderMixinDualCriticFactory.__init__(self, self) + self._params: DiscreteSACParams = DiscreteSACParams() + + def with_sac_params(self, params: DiscreteSACParams) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return DiscreteSACAgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_critic_factory(0), + self._get_critic_factory(1), + self._get_optim_factory(), + ) + + +class TD3ExperimentBuilder( + ExperimentBuilder, + _BuilderMixinActorFactory_ContinuousDeterministic, + _BuilderMixinDualCriticFactory, +): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + sampling_config: SamplingConfig | None = None, + ): + super().__init__(env_factory, experiment_config, sampling_config) + _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) + _BuilderMixinDualCriticFactory.__init__(self, self) + self._params: TD3Params = TD3Params() + + def with_td3_params(self, params: TD3Params) -> Self: + self._params = params + return self + + def _create_agent_factory(self) -> AgentFactory: + return TD3AgentFactory( + self._params, + self._sampling_config, + self._get_actor_factory(), + self._get_critic_factory(0), + self._get_critic_factory(1), + self._get_optim_factory(), + ) diff --git a/examples/atari/tianshou/highlevel/logger.py b/examples/atari/tianshou/highlevel/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..2223d81f7893dcfc0d080051ee0d8ddfcf6c6e95 --- /dev/null +++ b/examples/atari/tianshou/highlevel/logger.py @@ -0,0 +1,76 @@ +import os +from abc import ABC, abstractmethod +from typing import Literal, TypeAlias + +from torch.utils.tensorboard import SummaryWriter + +from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger +from tianshou.utils.string import ToStringMixin + +TLogger: TypeAlias = BaseLogger + + +class 
LoggerFactory(ToStringMixin, ABC): + @abstractmethod + def create_logger( + self, + log_dir: str, + experiment_name: str, + run_id: str | None, + config_dict: dict, + ) -> TLogger: + """Creates the logger. + + :param log_dir: path to the directory in which log data is to be stored + :param experiment_name: the name of the job, which may contain `os.path.sep` + :param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger + :param config_dict: a dictionary with data that is to be logged + :return: the logger + """ + + +class LoggerFactoryDefault(LoggerFactory): + def __init__( + self, + logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard", + wandb_project: str | None = None, + ): + if logger_type == "wandb" and wandb_project is None: + raise ValueError("Must provide 'wandb_project'") + self.logger_type = logger_type + self.wandb_project = wandb_project + + def create_logger( + self, + log_dir: str, + experiment_name: str, + run_id: str | None, + config_dict: dict, + ) -> TLogger: + if self.logger_type in ["wandb", "tensorboard"]: + writer = SummaryWriter(log_dir) + writer.add_text( + "args", + str( + dict( + log_dir=log_dir, + logger_type=self.logger_type, + wandb_project=self.wandb_project, + ), + ), + ) + match self.logger_type: + case "wandb": + wandb_logger = WandbLogger( + save_interval=1, + name=experiment_name.replace(os.path.sep, "__"), + run_id=run_id, + config=config_dict, + project=self.wandb_project, + ) + wandb_logger.load(writer) + return wandb_logger + case "tensorboard": + return TensorboardLogger(writer) + case _: + raise ValueError(f"Unknown logger type '{self.logger_type}'") diff --git a/examples/atari/tianshou/highlevel/module/__init__.py b/examples/atari/tianshou/highlevel/module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/highlevel/module/actor.py b/examples/atari/tianshou/highlevel/module/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..867ece17a09ca159f8a5abb4afef9013df8c7b6e --- /dev/null +++ b/examples/atari/tianshou/highlevel/module/actor.py @@ -0,0 +1,265 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum +from typing import Protocol + +import torch +from torch import nn + +from tianshou.highlevel.env import Environments, EnvType +from tianshou.highlevel.module.core import ( + ModuleFactory, + TDevice, + init_linear_orthogonal, +) +from tianshou.highlevel.module.intermediate import ( + IntermediateModule, + IntermediateModuleFactory, +) +from tianshou.highlevel.module.module_opt import ModuleOpt +from tianshou.highlevel.optim import OptimizerFactory +from tianshou.utils.net import continuous, discrete +from tianshou.utils.net.common import BaseActor, ModuleType, Net +from tianshou.utils.string import ToStringMixin + + +class ContinuousActorType(Enum): + GAUSSIAN = "gaussian" + DETERMINISTIC = "deterministic" + UNSUPPORTED = "unsupported" + + +@dataclass +class ActorFuture: + """Container, which, in the future, will hold an actor instance.""" + + actor: BaseActor | nn.Module | None = None + + +class ActorFutureProviderProtocol(Protocol): + def get_actor_future(self) -> ActorFuture: + pass + + +class ActorFactory(ModuleFactory, ToStringMixin, ABC): + @abstractmethod + def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: + pass + + def 
create_module_opt( + self, + envs: Environments, + device: TDevice, + optim_factory: OptimizerFactory, + lr: float, + ) -> ModuleOpt: + """Creates the actor module along with its optimizer for the given learning rate. + + :param envs: the environments + :param device: the torch device + :param optim_factory: the optimizer factory + :param lr: the learning rate + :return: a container with the actor module and its optimizer + """ + module = self.create_module(envs, device) + optim = optim_factory.create_optimizer(module, lr) + return ModuleOpt(module, optim) + + @staticmethod + def _init_linear(actor: torch.nn.Module) -> None: + """Initializes linear layers of an actor module using default mechanisms. + + :param module: the actor module. + """ + init_linear_orthogonal(actor) + if hasattr(actor, "mu"): + # For continuous action spaces with Gaussian policies + # do last policy layer scaling, this will make initial actions have (close to) + # 0 mean and std, and will help boost performances, + # see https://arxiv.org/abs/2006.05990, Fig.24 for details + for m in actor.mu.modules(): + if isinstance(m, torch.nn.Linear): + m.weight.data.copy_(0.01 * m.weight.data) + + +class ActorFactoryDefault(ActorFactory): + """An actor factory which, depending on the type of environment, creates a suitable MLP-based policy.""" + + DEFAULT_HIDDEN_SIZES = (64, 64) + + def __init__( + self, + continuous_actor_type: ContinuousActorType, + hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = nn.ReLU, + continuous_unbounded: bool = False, + continuous_conditioned_sigma: bool = False, + discrete_softmax: bool = True, + ): + self.continuous_actor_type = continuous_actor_type + self.continuous_unbounded = continuous_unbounded + self.continuous_conditioned_sigma = continuous_conditioned_sigma + self.hidden_sizes = hidden_sizes + self.hidden_activation = hidden_activation + self.discrete_softmax = discrete_softmax + + def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + env_type = envs.get_type() + factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet + if env_type == EnvType.CONTINUOUS: + match self.continuous_actor_type: + case ContinuousActorType.GAUSSIAN: + factory = ActorFactoryContinuousGaussianNet( + self.hidden_sizes, + activation=self.hidden_activation, + unbounded=self.continuous_unbounded, + conditioned_sigma=self.continuous_conditioned_sigma, + ) + case ContinuousActorType.DETERMINISTIC: + factory = ActorFactoryContinuousDeterministicNet( + self.hidden_sizes, + activation=self.hidden_activation, + ) + case ContinuousActorType.UNSUPPORTED: + raise ValueError("Continuous action spaces are not supported by the algorithm") + case _: + raise ValueError(self.continuous_actor_type) + return factory.create_module(envs, device) + elif env_type == EnvType.DISCRETE: + factory = ActorFactoryDiscreteNet( + self.DEFAULT_HIDDEN_SIZES, + softmax_output=self.discrete_softmax, + ) + return factory.create_module(envs, device) + else: + raise ValueError(f"{env_type} not supported") + + +class ActorFactoryContinuous(ActorFactory, ABC): + """Serves as a type bound for actor factories that are suitable for continuous action spaces.""" + + +class ActorFactoryContinuousDeterministicNet(ActorFactoryContinuous): + def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU): + self.hidden_sizes = hidden_sizes + self.activation = activation + + def create_module(self, envs: Environments, device: TDevice) 
-> BaseActor: + net_a = Net( + state_shape=envs.get_observation_shape(), + hidden_sizes=self.hidden_sizes, + activation=self.activation, + device=device, + ) + return continuous.Actor( + preprocess_net=net_a, + action_shape=envs.get_action_shape(), + hidden_sizes=(), + device=device, + ).to(device) + + +class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous): + def __init__( + self, + hidden_sizes: Sequence[int], + unbounded: bool = True, + conditioned_sigma: bool = False, + activation: ModuleType = nn.ReLU, + ): + """For actors with Gaussian policies. + + :param hidden_sizes: the sequence of hidden dimensions to use in the network structure + :param unbounded: whether to apply tanh activation on final logits + :param conditioned_sigma: if True, the standard deviation of continuous actions (sigma) is computed from the + input; if False, sigma is an independent parameter + """ + self.hidden_sizes = hidden_sizes + self.unbounded = unbounded + self.conditioned_sigma = conditioned_sigma + self.activation = activation + + def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + net_a = Net( + state_shape=envs.get_observation_shape(), + hidden_sizes=self.hidden_sizes, + activation=self.activation, + device=device, + ) + actor = continuous.ActorProb( + preprocess_net=net_a, + action_shape=envs.get_action_shape(), + unbounded=self.unbounded, + device=device, + conditioned_sigma=self.conditioned_sigma, + ).to(device) + + # init params + if not self.conditioned_sigma: + torch.nn.init.constant_(actor.sigma_param, -0.5) + self._init_linear(actor) + + return actor + + +class ActorFactoryDiscreteNet(ActorFactory): + def __init__( + self, + hidden_sizes: Sequence[int], + softmax_output: bool = True, + activation: ModuleType = nn.ReLU, + ): + self.hidden_sizes = hidden_sizes + self.softmax_output = softmax_output + self.activation = activation + + def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + net_a = Net( + state_shape=envs.get_observation_shape(), + hidden_sizes=self.hidden_sizes, + activation=self.activation, + device=device, + ) + return discrete.Actor( + net_a, + envs.get_action_shape(), + hidden_sizes=(), + device=device, + softmax_output=self.softmax_output, + ).to(device) + + +class ActorFactoryTransientStorageDecorator(ActorFactory): + """Wraps an actor factory, storing the most recently created actor instance such that it can be retrieved.""" + + def __init__(self, actor_factory: ActorFactory, actor_future: ActorFuture): + self.actor_factory = actor_factory + self._actor_future = actor_future + + def __getstate__(self) -> dict: + d = dict(self.__dict__) + del d["_actor_future"] + return d + + def __setstate__(self, state: dict) -> None: + self.__dict__ = state + self._actor_future = ActorFuture() + + def _tostring_excludes(self) -> list[str]: + return [*super()._tostring_excludes(), "_actor_future"] + + def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: + module = self.actor_factory.create_module(envs, device) + self._actor_future.actor = module + return module + + +class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory): + def __init__(self, actor_factory: ActorFactory): + self.actor_factory = actor_factory + + def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: + actor = self.actor_factory.create_module(envs, device) + assert isinstance(actor, BaseActor) + return IntermediateModule(actor, actor.get_output_dim()) diff --git 
a/examples/atari/tianshou/highlevel/module/core.py b/examples/atari/tianshou/highlevel/module/core.py new file mode 100644 index 0000000000000000000000000000000000000000..61f4a232bdc44749bb93ba9b0c04ef6299c75c73 --- /dev/null +++ b/examples/atari/tianshou/highlevel/module/core.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod +from typing import TypeAlias + +import numpy as np +import torch + +from tianshou.highlevel.env import Environments + +TDevice: TypeAlias = str | torch.device + + +def init_linear_orthogonal(module: torch.nn.Module) -> None: + """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0. + + :param module: the module whose submodules are to be processed + """ + for m in module.modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) + torch.nn.init.zeros_(m.bias) + + +class ModuleFactory(ABC): + """Represents a factory for the creation of a torch module given an environment and target device.""" + + @abstractmethod + def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module: + pass diff --git a/examples/atari/tianshou/highlevel/module/critic.py b/examples/atari/tianshou/highlevel/module/critic.py new file mode 100644 index 0000000000000000000000000000000000000000..4eacef1157fe19d1945dc0a9c94dfb2a9320d60a --- /dev/null +++ b/examples/atari/tianshou/highlevel/module/critic.py @@ -0,0 +1,297 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence + +import numpy as np +from torch import nn + +from tianshou.highlevel.env import Environments, EnvType +from tianshou.highlevel.module.actor import ActorFuture +from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal +from tianshou.highlevel.module.module_opt import ModuleOpt +from tianshou.highlevel.optim import OptimizerFactory +from tianshou.utils.net import continuous, discrete +from tianshou.utils.net.common import BaseActor, EnsembleLinear, ModuleType, Net +from tianshou.utils.string import ToStringMixin + + +class CriticFactory(ToStringMixin, ABC): + """Represents a factory for the generation of a critic module.""" + + @abstractmethod + def create_module( + self, + envs: Environments, + device: TDevice, + use_action: bool, + discrete_last_size_use_action_shape: bool = False, + ) -> nn.Module: + """Creates the critic module. + + :param envs: the environments + :param device: the torch device + :param use_action: whether to expect the action as an additional input (in addition to the observations) + :param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape + :return: the module + """ + + def create_module_opt( + self, + envs: Environments, + device: TDevice, + use_action: bool, + optim_factory: OptimizerFactory, + lr: float, + discrete_last_size_use_action_shape: bool = False, + ) -> ModuleOpt: + """Creates the critic module along with its optimizer for the given learning rate. 
+ + :param envs: the environments + :param device: the torch device + :param use_action: whether to expect the action as an additional input (in addition to the observations) + :param optim_factory: the optimizer factory + :param lr: the learning rate + :param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape + :return: + """ + module = self.create_module( + envs, + device, + use_action, + discrete_last_size_use_action_shape=discrete_last_size_use_action_shape, + ) + opt = optim_factory.create_optimizer(module, lr) + return ModuleOpt(module, opt) + + +class CriticFactoryDefault(CriticFactory): + """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic.""" + + DEFAULT_HIDDEN_SIZES = (64, 64) + + def __init__( + self, + hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES, + hidden_activation: ModuleType = nn.ReLU, + ): + self.hidden_sizes = hidden_sizes + self.hidden_activation = hidden_activation + + def create_module( + self, + envs: Environments, + device: TDevice, + use_action: bool, + discrete_last_size_use_action_shape: bool = False, + ) -> nn.Module: + factory: CriticFactory + env_type = envs.get_type() + match env_type: + case EnvType.CONTINUOUS: + factory = CriticFactoryContinuousNet( + self.hidden_sizes, + activation=self.hidden_activation, + ) + case EnvType.DISCRETE: + factory = CriticFactoryDiscreteNet( + self.hidden_sizes, + activation=self.hidden_activation, + ) + case _: + raise ValueError(f"{env_type} not supported") + return factory.create_module( + envs, + device, + use_action, + discrete_last_size_use_action_shape=discrete_last_size_use_action_shape, + ) + + +class CriticFactoryContinuousNet(CriticFactory): + def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU): + self.hidden_sizes = hidden_sizes + self.activation = activation + + def create_module( + self, + envs: Environments, + device: TDevice, + use_action: bool, + discrete_last_size_use_action_shape: bool = False, + ) -> nn.Module: + action_shape = envs.get_action_shape() if use_action else 0 + net_c = Net( + state_shape=envs.get_observation_shape(), + action_shape=action_shape, + hidden_sizes=self.hidden_sizes, + concat=use_action, + activation=self.activation, + device=device, + ) + critic = continuous.Critic(net_c, device=device).to(device) + init_linear_orthogonal(critic) + return critic + + +class CriticFactoryDiscreteNet(CriticFactory): + def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU): + self.hidden_sizes = hidden_sizes + self.activation = activation + + def create_module( + self, + envs: Environments, + device: TDevice, + use_action: bool, + discrete_last_size_use_action_shape: bool = False, + ) -> nn.Module: + action_shape = envs.get_action_shape() if use_action else 0 + net_c = Net( + state_shape=envs.get_observation_shape(), + action_shape=action_shape, + hidden_sizes=self.hidden_sizes, + concat=use_action, + activation=self.activation, + device=device, + ) + last_size = ( + int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1 + ) + critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device) + init_linear_orthogonal(critic) + return critic + + +class CriticFactoryReuseActor(CriticFactory): + """A critic factory which reuses the actor's preprocessing component. + + This class is for internal use in experiment builders only. 
+ """ + + def __init__(self, actor_future: ActorFuture): + """:param actor_future: the object, which will hold the actor instance later when the critic is to be created""" + self.actor_future = actor_future + + def _tostring_excludes(self) -> list[str]: + return ["actor_future"] + + def create_module( + self, + envs: Environments, + device: TDevice, + use_action: bool, + discrete_last_size_use_action_shape: bool = False, + ) -> nn.Module: + actor = self.actor_future.actor + if not isinstance(actor, BaseActor): + raise ValueError( + f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}", + ) + if envs.get_type().is_discrete(): + # TODO get rid of this prod pattern here and elsewhere + last_size = ( + int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1 + ) + return discrete.Critic( + actor.get_preprocess_net(), + device=device, + last_size=last_size, + ).to(device) + elif envs.get_type().is_continuous(): + return continuous.Critic( + actor.get_preprocess_net(), + device=device, + apply_preprocess_net_to_obs_only=True, + ).to(device) + else: + raise ValueError + + +class CriticEnsembleFactory: + @abstractmethod + def create_module( + self, + envs: Environments, + device: TDevice, + ensemble_size: int, + use_action: bool, + ) -> nn.Module: + pass + + def create_module_opt( + self, + envs: Environments, + device: TDevice, + ensemble_size: int, + use_action: bool, + optim_factory: OptimizerFactory, + lr: float, + ) -> ModuleOpt: + module = self.create_module(envs, device, ensemble_size, use_action) + opt = optim_factory.create_optimizer(module, lr) + return ModuleOpt(module, opt) + + +class CriticEnsembleFactoryDefault(CriticEnsembleFactory): + """A critic ensemble factory which, depending on the type of environment, creates a suitable MLP-based critic.""" + + DEFAULT_HIDDEN_SIZES = (64, 64) + + def __init__(self, hidden_sizes: Sequence[int] = DEFAULT_HIDDEN_SIZES): + self.hidden_sizes = hidden_sizes + + def create_module( + self, + envs: Environments, + device: TDevice, + ensemble_size: int, + use_action: bool, + ) -> nn.Module: + env_type = envs.get_type() + factory: CriticEnsembleFactory + match env_type: + case EnvType.CONTINUOUS: + factory = CriticEnsembleFactoryContinuousNet(self.hidden_sizes) + case EnvType.DISCRETE: + raise NotImplementedError("No default is implemented for the discrete case") + case _: + raise ValueError(f"{env_type} not supported") + return factory.create_module( + envs, + device, + ensemble_size, + use_action, + ) + + +class CriticEnsembleFactoryContinuousNet(CriticEnsembleFactory): + def __init__(self, hidden_sizes: Sequence[int]): + self.hidden_sizes = hidden_sizes + + def create_module( + self, + envs: Environments, + device: TDevice, + ensemble_size: int, + use_action: bool, + ) -> nn.Module: + def linear_layer(x: int, y: int) -> EnsembleLinear: + return EnsembleLinear(ensemble_size, x, y) + + action_shape = envs.get_action_shape() if use_action else 0 + net_c = Net( + state_shape=envs.get_observation_shape(), + action_shape=action_shape, + hidden_sizes=self.hidden_sizes, + concat=use_action, + activation=nn.Tanh, + device=device, + linear_layer=linear_layer, + ) + critic = continuous.Critic( + net_c, + device=device, + linear_layer=linear_layer, + flatten_input=False, + ).to(device) + init_linear_orthogonal(critic) + return critic diff --git a/examples/atari/tianshou/highlevel/module/intermediate.py b/examples/atari/tianshou/highlevel/module/intermediate.py new file mode 100644 index 
0000000000000000000000000000000000000000..a008935af38d23c0871ee05eeba4b78b8a951ace --- /dev/null +++ b/examples/atari/tianshou/highlevel/module/intermediate.py @@ -0,0 +1,27 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import torch + +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.core import ModuleFactory, TDevice +from tianshou.utils.string import ToStringMixin + + +@dataclass +class IntermediateModule: + """Container for a module which computes an intermediate representation (with a known dimension).""" + + module: torch.nn.Module + output_dim: int + + +class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC): + """Factory for the generation of a module which computes an intermediate representation.""" + + @abstractmethod + def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: + pass + + def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module: + return self.create_intermediate_module(envs, device).module diff --git a/examples/atari/tianshou/highlevel/module/module_opt.py b/examples/atari/tianshou/highlevel/module/module_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..558686aa9819b874173a4a8d4a24ae9bb68ed5e1 --- /dev/null +++ b/examples/atari/tianshou/highlevel/module/module_opt.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass + +import torch + +from tianshou.utils.net.common import ActorCritic + + +@dataclass +class ModuleOpt: + """Container for a torch module along with its optimizer.""" + + module: torch.nn.Module + optim: torch.optim.Optimizer + + +@dataclass +class ActorCriticOpt: + """Container for an :class:`ActorCritic` instance along with its optimizer.""" + + actor_critic_module: ActorCritic + optim: torch.optim.Optimizer + + @property + def actor(self) -> torch.nn.Module: + return self.actor_critic_module.actor + + @property + def critic(self) -> torch.nn.Module: + return self.actor_critic_module.critic diff --git a/examples/atari/tianshou/highlevel/module/special.py b/examples/atari/tianshou/highlevel/module/special.py new file mode 100644 index 0000000000000000000000000000000000000000..8c3c568f09227812d880cd2e580f9bd17784ebec --- /dev/null +++ b/examples/atari/tianshou/highlevel/module/special.py @@ -0,0 +1,30 @@ +from collections.abc import Sequence + +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.core import ModuleFactory, TDevice +from tianshou.highlevel.module.intermediate import IntermediateModuleFactory +from tianshou.utils.net.discrete import ImplicitQuantileNetwork +from tianshou.utils.string import ToStringMixin + + +class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin): + def __init__( + self, + preprocess_net_factory: IntermediateModuleFactory, + hidden_sizes: Sequence[int] = (), + num_cosines: int = 64, + ): + self.preprocess_net_factory = preprocess_net_factory + self.hidden_sizes = hidden_sizes + self.num_cosines = num_cosines + + def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork: + preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device) + return ImplicitQuantileNetwork( + preprocess_net=preprocess_net.module, + action_shape=envs.get_action_shape(), + hidden_sizes=self.hidden_sizes, + num_cosines=self.num_cosines, + preprocess_net_output_dim=preprocess_net.output_dim, + device=device, + ).to(device) diff --git a/examples/atari/tianshou/highlevel/optim.py 
b/examples/atari/tianshou/highlevel/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..db5fd90fff8d6108914657f8f1b8602c48a60927 --- /dev/null +++ b/examples/atari/tianshou/highlevel/optim.py @@ -0,0 +1,92 @@ +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import Any, Protocol, TypeAlias + +import torch +from torch.optim import Adam, RMSprop + +from tianshou.utils.string import ToStringMixin + +TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]] + + +class OptimizerWithLearningRateProtocol(Protocol): + def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer: + pass + + +class OptimizerFactory(ABC, ToStringMixin): + def create_optimizer( + self, + module: torch.nn.Module, + lr: float, + ) -> torch.optim.Optimizer: + return self.create_optimizer_for_params(module.parameters(), lr) + + @abstractmethod + def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: + pass + + +class OptimizerFactoryTorch(OptimizerFactory): + def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any): + """Factory for torch optimizers. + + :param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`), + which will be passed the module parameters, the learning rate as `lr` and the + kwargs provided. + :param kwargs: keyword arguments to provide at optimizer construction + """ + self.optim_class = optim_class + self.kwargs = kwargs + + def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: + return self.optim_class(params, lr=lr, **self.kwargs) + + +class OptimizerFactoryAdam(OptimizerFactory): + def __init__( + self, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-08, + weight_decay: float = 0, + ): + self.weight_decay = weight_decay + self.eps = eps + self.betas = betas + + def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: + return Adam( + params, + lr=lr, + betas=self.betas, + eps=self.eps, + weight_decay=self.weight_decay, + ) + + +class OptimizerFactoryRMSprop(OptimizerFactory): + def __init__( + self, + alpha: float = 0.99, + eps: float = 1e-08, + weight_decay: float = 0, + momentum: float = 0, + centered: bool = False, + ): + self.alpha = alpha + self.momentum = momentum + self.centered = centered + self.weight_decay = weight_decay + self.eps = eps + + def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: + return RMSprop( + params, + lr=lr, + alpha=self.alpha, + eps=self.eps, + weight_decay=self.weight_decay, + momentum=self.momentum, + centered=self.centered, + ) diff --git a/examples/atari/tianshou/highlevel/params/__init__.py b/examples/atari/tianshou/highlevel/params/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/highlevel/params/alpha.py b/examples/atari/tianshou/highlevel/params/alpha.py new file mode 100644 index 0000000000000000000000000000000000000000..4e8490de8646f3dfbad621452a493cd3ec9e973d --- /dev/null +++ b/examples/atari/tianshou/highlevel/params/alpha.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch + +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.core import TDevice +from tianshou.highlevel.optim import OptimizerFactory +from tianshou.utils.string import ToStringMixin + + +class 
AutoAlphaFactory(ToStringMixin, ABC): + @abstractmethod + def create_auto_alpha( + self, + envs: Environments, + optim_factory: OptimizerFactory, + device: TDevice, + ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]: + pass + + +class AutoAlphaFactoryDefault(AutoAlphaFactory): + def __init__(self, lr: float = 3e-4): + self.lr = lr + + def create_auto_alpha( + self, + envs: Environments, + optim_factory: OptimizerFactory, + device: TDevice, + ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]: + target_entropy = float(-np.prod(envs.get_action_shape())) + log_alpha = torch.zeros(1, requires_grad=True, device=device) + alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr) + return target_entropy, log_alpha, alpha_optim diff --git a/examples/atari/tianshou/highlevel/params/dist_fn.py b/examples/atari/tianshou/highlevel/params/dist_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d2aca9e619e1fbeeb630de9ab7a950a6b40f30 --- /dev/null +++ b/examples/atari/tianshou/highlevel/params/dist_fn.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any + +import torch + +from tianshou.highlevel.env import Environments, EnvType +from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont +from tianshou.utils.string import ToStringMixin + + +class DistributionFunctionFactory(ToStringMixin, ABC): + # True return type defined in subclasses + @abstractmethod + def create_dist_fn( + self, + envs: Environments, + ) -> Callable[[Any], torch.distributions.Distribution]: + pass + + +class DistributionFunctionFactoryCategorical(DistributionFunctionFactory): + def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete: + envs.get_type().assert_discrete(self) + return self._dist_fn + + @staticmethod + def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical: + return torch.distributions.Categorical(logits=p) + + +class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory): + def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont: + envs.get_type().assert_continuous(self) + return self._dist_fn + + @staticmethod + def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution: + loc, scale = loc_scale + return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1) + + +class DistributionFunctionFactoryDefault(DistributionFunctionFactory): + def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont: + match envs.get_type(): + case EnvType.DISCRETE: + return DistributionFunctionFactoryCategorical().create_dist_fn(envs) + case EnvType.CONTINUOUS: + return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs) + case _: + raise ValueError(envs.get_type()) diff --git a/examples/atari/tianshou/highlevel/params/env_param.py b/examples/atari/tianshou/highlevel/params/env_param.py new file mode 100644 index 0000000000000000000000000000000000000000..8696518dc6acd704efc42ea72998fb774c240b02 --- /dev/null +++ b/examples/atari/tianshou/highlevel/params/env_param.py @@ -0,0 +1,33 @@ +"""Factories for the generation of environment-dependent parameters.""" +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +from tianshou.highlevel.env import ContinuousEnvironments, Environments +from tianshou.utils.string import ToStringMixin + +TValue = TypeVar("TValue") +TEnvs = TypeVar("TEnvs", bound=Environments) + + +class EnvValueFactory(Generic[TValue, 
TEnvs], ToStringMixin, ABC): + @abstractmethod + def create_value(self, envs: TEnvs) -> TValue: + pass + + +class FloatEnvValueFactory(EnvValueFactory[float, TEnvs], Generic[TEnvs], ABC): + """Serves as a type bound for float value factories.""" + + +class FloatEnvValueFactoryMaxActionScaled(FloatEnvValueFactory[ContinuousEnvironments]): + def __init__(self, value: float): + """:param value: value with which to scale the max action value""" + self.value = value + + def create_value(self, envs: ContinuousEnvironments) -> float: + envs.get_type().assert_continuous(self) + return envs.max_action * self.value + + +class MaxActionScaled(FloatEnvValueFactoryMaxActionScaled): + pass diff --git a/examples/atari/tianshou/highlevel/params/lr_scheduler.py b/examples/atari/tianshou/highlevel/params/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..0b0cf359a4fbf9ffb4975c9fd72ffff248199f76 --- /dev/null +++ b/examples/atari/tianshou/highlevel/params/lr_scheduler.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch +from torch.optim.lr_scheduler import LambdaLR, LRScheduler + +from tianshou.highlevel.config import SamplingConfig +from tianshou.utils.string import ToStringMixin + + +class LRSchedulerFactory(ToStringMixin, ABC): + """Factory for the creation of a learning rate scheduler.""" + + @abstractmethod + def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: + pass + + +class LRSchedulerFactoryLinear(LRSchedulerFactory): + def __init__(self, sampling_config: SamplingConfig): + self.sampling_config = sampling_config + + def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: + return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute) + + class _LRLambda: + def __init__(self, sampling_config: SamplingConfig): + self.max_update_num = ( + np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect) + * sampling_config.num_epochs + ) + + def compute(self, epoch: int) -> float: + return 1.0 - epoch / self.max_update_num diff --git a/examples/atari/tianshou/highlevel/params/noise.py b/examples/atari/tianshou/highlevel/params/noise.py new file mode 100644 index 0000000000000000000000000000000000000000..66e0c53c48ff37dde2c2a513df0356b764f0c92a --- /dev/null +++ b/examples/atari/tianshou/highlevel/params/noise.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod + +from tianshou.exploration import BaseNoise, GaussianNoise +from tianshou.highlevel.env import ContinuousEnvironments, Environments +from tianshou.utils.string import ToStringMixin + + +class NoiseFactory(ToStringMixin, ABC): + @abstractmethod + def create_noise(self, envs: Environments) -> BaseNoise: + pass + + +class NoiseFactoryMaxActionScaledGaussian(NoiseFactory): + def __init__(self, std_fraction: float): + """Factory for Gaussian noise where the standard deviation is a fraction of the maximum action value. + + This factory can only be applied to continuous action spaces. 
+ + :param std_fraction: fraction (between 0 and 1) of the maximum action value that shall + be used as the standard deviation + """ + self.std_fraction = std_fraction + + def create_noise(self, envs: Environments) -> GaussianNoise: + envs.get_type().assert_continuous(self) + envs: ContinuousEnvironments + return GaussianNoise(sigma=envs.max_action * self.std_fraction) + + +class MaxActionScaledGaussian(NoiseFactoryMaxActionScaledGaussian): + pass diff --git a/examples/atari/tianshou/highlevel/params/policy_params.py b/examples/atari/tianshou/highlevel/params/policy_params.py new file mode 100644 index 0000000000000000000000000000000000000000..24674bc8c038362aec983f34546e289921ba9717 --- /dev/null +++ b/examples/atari/tianshou/highlevel/params/policy_params.py @@ -0,0 +1,656 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import asdict, dataclass +from typing import Any, Literal, Protocol + +import torch +from torch.optim.lr_scheduler import LRScheduler + +from tianshou.exploration import BaseNoise +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.core import TDevice +from tianshou.highlevel.module.module_opt import ModuleOpt +from tianshou.highlevel.optim import OptimizerFactory +from tianshou.highlevel.params.alpha import AutoAlphaFactory +from tianshou.highlevel.params.dist_fn import ( + DistributionFunctionFactory, + DistributionFunctionFactoryDefault, +) +from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory +from tianshou.highlevel.params.noise import NoiseFactory +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.utils import MultipleLRSchedulers +from tianshou.utils.string import ToStringMixin + + +@dataclass(kw_only=True) +class ParamTransformerData: + """Holds data that can be used by `ParamTransformer` instances to perform their transformation. + + The representation contains the superset of all data items that are required by different types of agent factories. + An agent factory is expected to set only the attributes that are relevant to its parameters. + """ + + envs: Environments + device: TDevice + optim_factory: OptimizerFactory + optim: torch.optim.Optimizer | None = None + """the single optimizer for the case where there is just one""" + actor: ModuleOpt | None = None + critic1: ModuleOpt | None = None + critic2: ModuleOpt | None = None + + +class ParamTransformer(ABC): + """Base class for parameter transformations from high to low-level API. + + Transforms one or more parameters from the representation used by the high-level API + to the representation required by the (low-level) policy implementation. + It operates directly on a dictionary of keyword arguments, which is initially + generated from the parameter dataclass (subclass of `Params`). 
+ """ + + @abstractmethod + def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: + pass + + @staticmethod + def get(d: dict[str, Any], key: str, drop: bool = False) -> Any: + value = d[key] + if drop: + del d[key] + return value + + +class ParamTransformerDrop(ParamTransformer): + def __init__(self, *keys: str): + self.keys = keys + + def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: + for k in self.keys: + del kwargs[k] + + +class ParamTransformerChangeValue(ParamTransformer): + def __init__(self, key: str): + self.key = key + + def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: + params[self.key] = self.change_value(params[self.key], data) + + @abstractmethod + def change_value(self, value: Any, data: ParamTransformerData) -> Any: + pass + + +class ParamTransformerLRScheduler(ParamTransformer): + """Transformer for learning rate scheduler params. + + Transforms a key containing a learning rate scheduler factory (removed) into a key containing + a learning rate scheduler (added) for the data member `optim`. + """ + + def __init__(self, key_scheduler_factory: str, key_scheduler: str): + self.key_scheduler_factory = key_scheduler_factory + self.key_scheduler = key_scheduler + + def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: + assert data.optim is not None + factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True) + params[self.key_scheduler] = ( + factory.create_scheduler(data.optim) if factory is not None else None + ) + + +class ParamTransformerMultiLRScheduler(ParamTransformer): + def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str): + """Transforms several scheduler factories into a single scheduler. + + The result may be a `MultipleLRSchedulers` instance if more than one factory is indeed given. 
+ + :param optim_key_list: a list of tuples (optimizer, key of learning rate factory) + :param key_scheduler: the key under which to store the resulting learning rate scheduler + """ + self.optim_key_list = optim_key_list + self.key_scheduler = key_scheduler + + def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: + lr_schedulers = [] + for optim, lr_scheduler_factory_key in self.optim_key_list: + lr_scheduler_factory: LRSchedulerFactory | None = self.get( + params, + lr_scheduler_factory_key, + drop=True, + ) + if lr_scheduler_factory is not None: + lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim)) + lr_scheduler: LRScheduler | MultipleLRSchedulers | None + match len(lr_schedulers): + case 0: + lr_scheduler = None + case 1: + lr_scheduler = lr_schedulers[0] + case _: + lr_scheduler = MultipleLRSchedulers(*lr_schedulers) + params[self.key_scheduler] = lr_scheduler + + +class ParamTransformerActorAndCriticLRScheduler(ParamTransformer): + def __init__( + self, + key_scheduler_factory_actor: str, + key_scheduler_factory_critic: str, + key_scheduler: str, + ): + self.key_factory_actor = key_scheduler_factory_actor + self.key_factory_critic = key_scheduler_factory_critic + self.key_scheduler = key_scheduler + + def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: + assert data.actor is not None and data.critic1 is not None + transformer = ParamTransformerMultiLRScheduler( + [ + (data.actor.optim, self.key_factory_actor), + (data.critic1.optim, self.key_factory_critic), + ], + self.key_scheduler, + ) + transformer.transform(params, data) + + +class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer): + def __init__( + self, + key_scheduler_factory_actor: str, + key_scheduler_factory_critic1: str, + key_scheduler_factory_critic2: str, + key_scheduler: str, + ): + self.key_factory_actor = key_scheduler_factory_actor + self.key_factory_critic1 = key_scheduler_factory_critic1 + self.key_factory_critic2 = key_scheduler_factory_critic2 + self.key_scheduler = key_scheduler + + def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: + assert data.actor is not None and data.critic1 is not None and data.critic2 is not None + transformer = ParamTransformerMultiLRScheduler( + [ + (data.actor.optim, self.key_factory_actor), + (data.critic1.optim, self.key_factory_critic1), + (data.critic2.optim, self.key_factory_critic2), + ], + self.key_scheduler, + ) + transformer.transform(params, data) + + +class ParamTransformerAutoAlpha(ParamTransformer): + def __init__(self, key: str): + self.key = key + + def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: + alpha = self.get(kwargs, self.key) + if isinstance(alpha, AutoAlphaFactory): + kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device) + + +class ParamTransformerNoiseFactory(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: + if isinstance(value, NoiseFactory): + value = value.create_noise(data.envs) + return value + + +class ParamTransformerFloatEnvParamFactory(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: + if isinstance(value, EnvValueFactory): + value = value.create_value(data.envs) + return value + + +class ParamTransformerDistributionFunction(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: + if value == "default": + value = 
DistributionFunctionFactoryDefault().create_dist_fn(data.envs) + elif isinstance(value, DistributionFunctionFactory): + value = value.create_dist_fn(data.envs) + return value + + +class ParamTransformerActionScaling(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: + if value == "default": + return data.envs.get_type().is_continuous() + else: + return value + + +class GetParamTransformersProtocol(Protocol): + def _get_param_transformers(self) -> list[ParamTransformer]: + pass + + +@dataclass +class Params(GetParamTransformersProtocol, ToStringMixin): + def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]: + params = asdict(self) + for transformer in self._get_param_transformers(): + transformer.transform(params, data) + return params + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [] + + +@dataclass +class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol): + lr: float = 1e-3 + """the learning rate to use in the gradient-based optimizer""" + lr_scheduler_factory: LRSchedulerFactory | None = None + """factory for the creation of a learning rate scheduler""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ + ParamTransformerDrop("lr"), + ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"), + ] + + +@dataclass +class ParamsMixinActorAndCritic(GetParamTransformersProtocol): + actor_lr: float = 1e-3 + """the learning rate to use for the actor network""" + critic_lr: float = 1e-3 + """the learning rate to use for the critic network""" + actor_lr_scheduler_factory: LRSchedulerFactory | None = None + """factory for the creation of a learning rate scheduler to use for the actor network (if any)""" + critic_lr_scheduler_factory: LRSchedulerFactory | None = None + """factory for the creation of a learning rate scheduler to use for the critic network (if any)""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ + ParamTransformerDrop("actor_lr", "critic_lr"), + ParamTransformerActorAndCriticLRScheduler( + "actor_lr_scheduler_factory", + "critic_lr_scheduler_factory", + "lr_scheduler", + ), + ] + + +@dataclass +class ParamsMixinActionScaling(GetParamTransformersProtocol): + action_scaling: bool | Literal["default"] = "default" + """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces""" + action_bound_method: Literal["clip", "tanh"] | None = "clip" + """ + method to bound action to range [-1, 1]. Only used if the action_space is continuous. + """ + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [] + + +@dataclass +class ParamsMixinExplorationNoise(GetParamTransformersProtocol): + exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None + """ + If not None, add noise to actions for exploration. + This is useful when solving "hard exploration" problems. + It can either be a distribution, a factory for the creation of a distribution or "default". + When set to "default", use Gaussian noise with standard deviation 0.1. 
+ """ + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ParamTransformerNoiseFactory("exploration_noise")] + + +@dataclass +class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithScheduler): + discount_factor: float = 0.99 + """ + discount factor (gamma) for future rewards; must be in [0, 1] + """ + reward_normalization: bool = False + """ + if True, will normalize the returns by subtracting the running mean and dividing by the running + standard deviation. + """ + deterministic_eval: bool = False + """ + whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation. + Does not affect training. + """ + dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default" + """ + This can either be a function which maps the model output to a torch distribution or a + factory for the creation of such a function. + When set to "default", a factory which creates Gaussian distributions from mean and standard + deviation will be used for the continuous case and which creates categorical distributions + for the discrete case (see :class:`DistributionFunctionFactoryDefault`) + """ + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) + transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self)) + transformers.append(ParamTransformerActionScaling("action_scaling")) + transformers.append(ParamTransformerDistributionFunction("dist_fn")) + return transformers + + +@dataclass +class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol): + gae_lambda: float = 0.95 + """ + determines the blend between Monte Carlo and one-step temporal difference (TD) estimates of the advantage + function in general advantage estimation (GAE). + A value of 0 gives a fully TD-based estimate; lambda=1 gives a fully Monte Carlo estimate. + """ + max_batchsize: int = 256 + """the maximum size of the batch when computing general advantage estimation (GAE)""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [] + + +@dataclass +class A2CParams(PGParams, ParamsMixinGeneralAdvantageEstimation): + vf_coef: float = 0.5 + """weight (coefficient) of the value loss in the loss function""" + ent_coef: float = 0.01 + """weight (coefficient) of the entropy loss in the loss function""" + max_grad_norm: float | None = None + """maximum norm for clipping gradients in backpropagation""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self)) + return transformers + + +@dataclass +class PPOParams(A2CParams): + eps_clip: float = 0.2 + """ + determines the range of allowed change in the policy during a policy update: + The ratio between the probabilities indicated by the new and old policy is + constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. + Small values thus force the new policy to stay close to the old policy. + Typical values range between 0.1 and 0.3. + The optimal epsilon depends on the environment; more stochastic environments may need larger epsilons. + """ + dual_clip: float | None = None + """ + determines the lower bound clipping for the probability ratio + (corresponds to parameter c in arXiv:1912.09729, Equation 5). 
+ If set to None, dual clipping is not used and the bounds described in parameter eps_clip apply. + If set to a float value c, the lower bound is changed from 1 - eps_clip to c, + where c < 1 - eps_clip. + Setting c > 0 reduces policy oscillation and further stabilizes training. + Typical values are between 0 and 0.5. Smaller values provide more stability. + Setting c = 0 yields PPO with only the upper bound. + """ + value_clip: bool = False + """ + whether to apply clipping of the predicted value function during policy learning. + Value clipping discourages large changes in value predictions between updates. + Inaccurate value predictions can lead to bad policy updates, which can cause training instability. + Clipping values prevents sporadic large errors from skewing policy updates too much. + """ + advantage_normalization: bool = True + """whether to apply per mini-batch advantage normalization.""" + recompute_advantage: bool = False + """ + whether to recompute advantage every update repeat as described in + https://arxiv.org/pdf/2006.05990.pdf, Sec. 3.5. + The original PPO implementation splits the data in each policy iteration + step into individual transitions and then randomly assigns them to minibatches. + This makes it impossible to compute advantages as the temporal structure is broken. + Therefore, the advantages are computed once at the beginning of each policy iteration step and + then used in minibatch policy and value function optimization. + This results in higher diversity of data in each minibatch at the cost of + using slightly stale advantage estimations. + Enabling this option will, as a remedy to this problem, recompute the advantages at the beginning + of each pass over the data instead of just once per iteration. + """ + + +@dataclass +class NPGParams(PGParams, ParamsMixinGeneralAdvantageEstimation): + optim_critic_iters: int = 5 + """number of times to optimize critic network per update.""" + actor_step_size: float = 0.5 + """step size for actor update in natural gradient direction""" + advantage_normalization: bool = True + """whether to do per mini-batch advantage normalization.""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self)) + return transformers + + +@dataclass +class TRPOParams(NPGParams): + max_kl: float = 0.01 + """ + maximum KL divergence, used to constrain each actor network update. + """ + backtrack_coeff: float = 0.8 + """ + coefficient with which to reduce the step size when constraints are not met. 
+ """ + max_backtracks: int = 10 + """maximum number of times to backtrack in line search when the constraints are not met.""" + + +@dataclass +class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol): + actor_lr: float = 1e-3 + """the learning rate to use for the actor network""" + critic1_lr: float = 1e-3 + """the learning rate to use for the first critic network""" + critic2_lr: float = 1e-3 + """the learning rate to use for the second critic network""" + actor_lr_scheduler_factory: LRSchedulerFactory | None = None + """factory for the creation of a learning rate scheduler to use for the actor network (if any)""" + critic1_lr_scheduler_factory: LRSchedulerFactory | None = None + """factory for the creation of a learning rate scheduler to use for the first critic network (if any)""" + critic2_lr_scheduler_factory: LRSchedulerFactory | None = None + """factory for the creation of a learning rate scheduler to use for the second critic network (if any)""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ + ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"), + ParamTransformerActorDualCriticsLRScheduler( + "actor_lr_scheduler_factory", + "critic1_lr_scheduler_factory", + "critic2_lr_scheduler_factory", + "lr_scheduler", + ), + ] + + +@dataclass +class _SACParams(Params, ParamsMixinActorAndDualCritics): + tau: float = 0.005 + """controls the contribution of the entropy term in the overall optimization objective, + i.e. the desired amount of randomness in the optimal policy. + Higher values mean greater target entropy and therefore more randomness in the policy. + Lower values mean lower target entropy and therefore a more deterministic policy. + """ + gamma: float = 0.99 + """discount factor (gamma) for future rewards; must be in [0, 1]""" + alpha: float | AutoAlphaFactory = 0.2 + """ + controls the relative importance (coefficient) of the entropy term in the loss function. + This can be a constant or a factory for the creation of a representation that allows the + parameter to be automatically tuned; + use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard + auto-adjusted alpha. + """ + estimation_step: int = 1 + """the number of steps to look ahead""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) + transformers.append(ParamTransformerAutoAlpha("alpha")) + return transformers + + +@dataclass +class SACParams(_SACParams, ParamsMixinExplorationNoise, ParamsMixinActionScaling): + deterministic_eval: bool = True + """ + whether to use deterministic action (mean of Gaussian policy) in evaluation mode instead of stochastic + action sampled by the policy. 
Does not affect training.""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) + transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) + return transformers + + +@dataclass +class DiscreteSACParams(_SACParams): + pass + + +@dataclass +class DQNParams(Params, ParamsMixinLearningRateWithScheduler): + discount_factor: float = 0.99 + """ + discount factor (gamma) for future rewards; must be in [0, 1] + """ + estimation_step: int = 1 + """the number of steps to look ahead""" + target_update_freq: int = 0 + """the target network update frequency (0 if no target network is to be used)""" + reward_normalization: bool = False + """whether to normalize the returns to Normal(0, 1)""" + is_double: bool = True + """whether to use double Q learning""" + clip_loss_grad: bool = False + """whether to clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber + loss instead of the MSE loss.""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self)) + return transformers + + +@dataclass +class IQNParams(DQNParams): + sample_size: int = 32 + """the number of samples for policy evaluation""" + online_sample_size: int = 8 + """the number of samples for online model in training""" + target_sample_size: int = 8 + """the number of samples for target model in training.""" + num_quantiles: int = 200 + """the number of quantile midpoints in the inverse cumulative distribution function of the value""" + hidden_sizes: Sequence[int] = () + """hidden dimensions to use in the IQN network""" + num_cosines: int = 64 + """number of cosines to use in the IQN network""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.append(ParamTransformerDrop("hidden_sizes", "num_cosines")) + return transformers + + +@dataclass +class DDPGParams( + Params, + ParamsMixinActorAndCritic, + ParamsMixinExplorationNoise, + ParamsMixinActionScaling, +): + tau: float = 0.005 + """ + controls the soft update of the target network. + It determines how slowly the target networks track the main networks. + Smaller tau means slower tracking and more stable learning. + """ + gamma: float = 0.99 + """discount factor (gamma) for future rewards; must be in [0, 1]""" + estimation_step: int = 1 + """the number of steps to look ahead.""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self)) + transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) + transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) + return transformers + + +@dataclass +class REDQParams(DDPGParams): + ensemble_size: int = 10 + """the number of sub-networks in the critic ensemble""" + subset_size: int = 2 + """the number of networks in the subset""" + alpha: float | AutoAlphaFactory = 0.2 + """ + controls the relative importance (coefficient) of the entropy term in the loss function. 
+ This can be a constant or a factory for the creation of a representation that allows the + parameter to be automatically tuned; + use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard + auto-adjusted alpha. + """ + estimation_step: int = 1 + """the number of steps to look ahead""" + actor_delay: int = 20 + """the number of critic updates before an actor update""" + deterministic_eval: bool = True + """ + whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation. + Does not affect training. + """ + target_mode: Literal["mean", "min"] = "min" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.append(ParamTransformerAutoAlpha("alpha")) + return transformers + + +@dataclass +class TD3Params( + Params, + ParamsMixinActorAndDualCritics, + ParamsMixinExplorationNoise, + ParamsMixinActionScaling, +): + tau: float = 0.005 + """ + controls the soft update of the target network. + It determines how slowly the target networks track the main networks. + Smaller tau means slower tracking and more stable learning. + """ + gamma: float = 0.99 + """discount factor (gamma) for future rewards; must be in [0, 1]""" + policy_noise: float | FloatEnvValueFactory = 0.2 + """the scale of the the noise used in updating policy network""" + noise_clip: float | FloatEnvValueFactory = 0.5 + """determines the clipping range of the noise used in updating the policy network as [-noise_clip, noise_clip]""" + update_actor_freq: int = 2 + """the update frequency of actor network""" + estimation_step: int = 1 + """the number of steps to look ahead.""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) + transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) + transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) + transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise")) + transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip")) + return transformers diff --git a/examples/atari/tianshou/highlevel/params/policy_wrapper.py b/examples/atari/tianshou/highlevel/params/policy_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..e7958224db107ac68f4bb8f96c7316715a6e0078 --- /dev/null +++ b/examples/atari/tianshou/highlevel/params/policy_wrapper.py @@ -0,0 +1,76 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Generic, TypeVar + +from tianshou.highlevel.env import Environments +from tianshou.highlevel.module.core import TDevice +from tianshou.highlevel.module.intermediate import IntermediateModuleFactory +from tianshou.highlevel.optim import OptimizerFactory +from tianshou.policy import BasePolicy, ICMPolicy +from tianshou.utils.net.discrete import IntrinsicCuriosityModule +from tianshou.utils.string import ToStringMixin + +TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy) + + +class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC): + @abstractmethod + def create_wrapped_policy( + self, + policy: BasePolicy, + envs: Environments, + optim_factory: OptimizerFactory, + device: TDevice, + ) -> TPolicyOut: + pass + + +class PolicyWrapperFactoryIntrinsicCuriosity( + PolicyWrapperFactory[ICMPolicy], +): + def __init__( + self, + *, + feature_net_factory: 
IntermediateModuleFactory, + hidden_sizes: Sequence[int], + lr: float, + lr_scale: float, + reward_scale: float, + forward_loss_weight: float, + ): + self.feature_net_factory = feature_net_factory + self.hidden_sizes = hidden_sizes + self.lr = lr + self.lr_scale = lr_scale + self.reward_scale = reward_scale + self.forward_loss_weight = forward_loss_weight + + def create_wrapped_policy( + self, + policy: BasePolicy, + envs: Environments, + optim_factory: OptimizerFactory, + device: TDevice, + ) -> ICMPolicy: + feature_net = self.feature_net_factory.create_intermediate_module(envs, device) + action_dim = envs.get_action_shape() + if not isinstance(action_dim, int): + raise ValueError(f"Environment action shape must be an integer, got {action_dim}") + feature_dim = feature_net.output_dim + icm_net = IntrinsicCuriosityModule( + feature_net.module, + feature_dim, + action_dim, + hidden_sizes=self.hidden_sizes, + device=device, + ) + icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr) + return ICMPolicy( + policy=policy, + model=icm_net, + optim=icm_optim, + action_space=envs.get_action_space(), + lr_scale=self.lr_scale, + reward_scale=self.reward_scale, + forward_loss_weight=self.forward_loss_weight, + ).to(device) diff --git a/examples/atari/tianshou/highlevel/persistence.py b/examples/atari/tianshou/highlevel/persistence.py new file mode 100644 index 0000000000000000000000000000000000000000..52b18d1b635ef6811f3b40b6a62d5bef746e1c6d --- /dev/null +++ b/examples/atari/tianshou/highlevel/persistence.py @@ -0,0 +1,130 @@ +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable +from enum import Enum +from typing import TYPE_CHECKING + +import torch + +from tianshou.highlevel.world import World + +if TYPE_CHECKING: + from tianshou.highlevel.module.core import TDevice + +log = logging.getLogger(__name__) + + +class PersistEvent(Enum): + """Enumeration of persistence events that Persistence objects can react to.""" + + PERSIST_POLICY = "persist_policy" + """Policy neural network is persisted (new best found)""" + + +class RestoreEvent(Enum): + """Enumeration of restoration events that Persistence objects can react to.""" + + RESTORE_POLICY = "restore_policy" + """Policy neural network parameters are restored""" + + +class Persistence(ABC): + @abstractmethod + def persist(self, event: PersistEvent, world: World) -> None: + pass + + @abstractmethod + def restore(self, event: RestoreEvent, world: World) -> None: + pass + + +class PersistenceGroup(Persistence): + """Groups persistence handler such that they can be applied collectively.""" + + def __init__(self, *p: Persistence, enabled: bool = True): + self.items = p + self.enabled = enabled + + def persist(self, event: PersistEvent, world: World) -> None: + if not self.enabled: + return + for item in self.items: + item.persist(event, world) + + def restore(self, event: RestoreEvent, world: World) -> None: + for item in self.items: + item.restore(event, world) + + +class PolicyPersistence: + class Mode(Enum): + """Mode of persistence.""" + + POLICY_STATE_DICT = "policy_state_dict" + """Persist only the policy's state dictionary. Note that for a policy to be restored from + such a dictionary, it is necessary to first create a structurally equivalent object which can + accept the respective state.""" + POLICY = "policy" + """Persist the entire policy. This is larger but has the advantage of the policy being loadable + without requiring an environment to be instantiated. 
+ It has the potential disadvantage that upon breaking code changes in the policy implementation + (e.g. renamed/moved class), it will no longer be loadable. + Note that a precondition is that the policy be picklable in its entirety. + """ + + def get_filename(self) -> str: + return self.value + ".pt" + + def __init__( + self, + additional_persistence: Persistence | None = None, + enabled: bool = True, + mode: Mode = Mode.POLICY, + ): + """Handles persistence of the policy. + + :param additional_persistence: a persistence instance which is to be invoked whenever + this object is used to persist/restore data + :param enabled: whether persistence is enabled (restoration is always enabled) + :param mode: the persistence mode + """ + self.additional_persistence = additional_persistence + self.enabled = enabled + self.mode = mode + + def persist(self, policy: torch.nn.Module, world: World) -> None: + if not self.enabled: + return + path = world.persist_path(self.mode.get_filename()) + match self.mode: + case self.Mode.POLICY_STATE_DICT: + log.info(f"Saving policy state dictionary in {path}") + torch.save(policy.state_dict(), path) + case self.Mode.POLICY: + log.info(f"Saving policy object in {path}") + torch.save(policy, path) + case _: + raise NotImplementedError + if self.additional_persistence is not None: + self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world) + + def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None: + path = world.restore_path(self.mode.get_filename()) + log.info(f"Restoring policy from {path}") + match self.mode: + case self.Mode.POLICY_STATE_DICT: + state_dict = torch.load(path, map_location=device) + case self.Mode.POLICY: + loaded_policy: torch.nn.Module = torch.load(path, map_location=device) + state_dict = loaded_policy.state_dict() + case _: + raise NotImplementedError + policy.load_state_dict(state_dict) + if self.additional_persistence is not None: + self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world) + + def get_save_best_fn(self, world: World) -> Callable[[torch.nn.Module], None]: + def save_best_fn(pol: torch.nn.Module) -> None: + self.persist(pol, world) + + return save_best_fn diff --git a/examples/atari/tianshou/highlevel/trainer.py b/examples/atari/tianshou/highlevel/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..4eccc6a183ceebb50e120681566821edff15df5a --- /dev/null +++ b/examples/atari/tianshou/highlevel/trainer.py @@ -0,0 +1,151 @@ +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import TypeVar, cast + +from tianshou.highlevel.env import Environments +from tianshou.highlevel.logger import TLogger +from tianshou.policy import BasePolicy, DQNPolicy +from tianshou.utils.string import ToStringMixin + +TPolicy = TypeVar("TPolicy", bound=BasePolicy) +log = logging.getLogger(__name__) + + +class TrainingContext: + def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger): + self.policy = policy + self.envs = envs + self.logger = logger + + +class EpochTrainCallback(ToStringMixin, ABC): + """Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase + of each epoch. 
+ """ + + @abstractmethod + def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: + pass + + def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int], None]: + def fn(epoch: int, env_step: int) -> None: + return self.callback(epoch, env_step, context) + + return fn + + +class EpochTestCallback(ToStringMixin, ABC): + """Callback which is called at the beginning of the test phase of each epoch.""" + + @abstractmethod + def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: + pass + + def get_trainer_fn(self, context: TrainingContext) -> Callable[[int, int | None], None]: + def fn(epoch: int, env_step: int | None) -> None: + return self.callback(epoch, env_step, context) + + return fn + + +class EpochStopCallback(ToStringMixin, ABC): + """Callback which is called after the test phase of each epoch in order to determine + whether training should stop early. + """ + + @abstractmethod + def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: + """Determines whether training should stop. + + :param mean_rewards: the average undiscounted returns of the testing result + :param context: the training context + :return: True if the goal has been reached and training should stop, False otherwise + """ + + def get_trainer_fn(self, context: TrainingContext) -> Callable[[float], bool]: + def fn(mean_rewards: float) -> bool: + return self.should_stop(mean_rewards, context) + + return fn + + +@dataclass +class TrainerCallbacks: + """Container for callbacks used during training.""" + + epoch_train_callback: EpochTrainCallback | None = None + epoch_test_callback: EpochTestCallback | None = None + epoch_stop_callback: EpochStopCallback | None = None + + +class EpochTrainCallbackDQNSetEps(EpochTrainCallback): + """Sets the epsilon value for DQN-based policies at the beginning of the training + stage in each epoch. + """ + + def __init__(self, eps_test: float): + self.eps_test = eps_test + + def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: + policy = cast(DQNPolicy, context.policy) + policy.set_eps(self.eps_test) + + +class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback): + """Sets the epsilon value for DQN-based policies at the beginning of the training + stage in each epoch, using a linear decay in the first `decay_steps` steps. + """ + + def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = 1000000): + self.eps_train = eps_train + self.eps_train_final = eps_train_final + self.decay_steps = decay_steps + + def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: + policy = cast(DQNPolicy, context.policy) + logger = context.logger + if env_step <= self.decay_steps: + eps = self.eps_train - env_step / self.decay_steps * ( + self.eps_train - self.eps_train_final + ) + else: + eps = self.eps_train_final + policy.set_eps(eps) + logger.write("train/env_step", env_step, {"train/eps": eps}) + + +class EpochTestCallbackDQNSetEps(EpochTestCallback): + """Sets the epsilon value for DQN-based policies at the beginning of the test + stage in each epoch. 
+ """ + + def __init__(self, eps_test: float): + self.eps_test = eps_test + + def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: + policy = cast(DQNPolicy, context.policy) + policy.set_eps(self.eps_test) + + +class EpochStopCallbackRewardThreshold(EpochStopCallback): + """Stops training once the mean rewards exceed the given reward threshold or the threshold that + is specified in the gymnasium environment (i.e. `env.spec.reward_threshold`). + """ + + def __init__(self, threshold: float | None = None): + """:param threshold: the reward threshold beyond which to stop training. + If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`. + """ + self.threshold = threshold + + def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: + threshold = self.threshold + if threshold is None: + threshold = context.envs.env.spec.reward_threshold # type: ignore + assert threshold is not None + is_reached = mean_rewards >= threshold + if is_reached: + log.info(f"Reward threshold ({threshold}) exceeded") + return is_reached diff --git a/examples/atari/tianshou/highlevel/world.py b/examples/atari/tianshou/highlevel/world.py new file mode 100644 index 0000000000000000000000000000000000000000..c32ef9cbc4068292c2dd95f157a4f891c191bbb8 --- /dev/null +++ b/examples/atari/tianshou/highlevel/world.py @@ -0,0 +1,34 @@ +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from tianshou.data import BaseCollector + from tianshou.highlevel.env import Environments + from tianshou.highlevel.logger import TLogger + from tianshou.policy import BasePolicy + from tianshou.trainer import BaseTrainer + + +@dataclass +class World: + """Container for instances and configuration items that are relevant to an experiment.""" + + envs: "Environments" + policy: "BasePolicy" + train_collector: "BaseCollector" + test_collector: "BaseCollector" + logger: "TLogger" + persist_directory: str + restore_directory: str | None + trainer: Optional["BaseTrainer"] = None + + def persist_path(self, filename: str) -> str: + return os.path.abspath(os.path.join(self.persist_directory, filename)) + + def restore_path(self, filename: str) -> str: + if self.restore_directory is None: + raise ValueError( + "Path cannot be formed because no directory for restoration was provided", + ) + return os.path.join(self.restore_directory, filename) diff --git a/examples/atari/tianshou/policy/__init__.py b/examples/atari/tianshou/policy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df82637b849adf8990f1a0b5ca2e651fb815ff01 --- /dev/null +++ b/examples/atari/tianshou/policy/__init__.py @@ -0,0 +1,69 @@ +"""Policy package.""" +# isort:skip_file + +from tianshou.policy.base import BasePolicy, TrainingStats +from tianshou.policy.random import RandomPolicy +from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.bdq import BranchingDQNPolicy +from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.policy.modelfree.rainbow import RainbowPolicy +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy +from tianshou.policy.modelfree.iqn import IQNPolicy +from tianshou.policy.modelfree.fqf import FQFPolicy +from examples.atari.tianshou.policy.modelfree.fqf_rainbow import FQF_RainbowPolicy +from tianshou.policy.modelfree.pg import PGPolicy +from tianshou.policy.modelfree.a2c import A2CPolicy +from tianshou.policy.modelfree.npg import NPGPolicy +from 
tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.modelfree.ppo import PPOPolicy +from tianshou.policy.modelfree.trpo import TRPOPolicy +from tianshou.policy.modelfree.td3 import TD3Policy +from tianshou.policy.modelfree.sac import SACPolicy +from tianshou.policy.modelfree.redq import REDQPolicy +from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.policy.imitation.base import ImitationPolicy +from tianshou.policy.imitation.bcq import BCQPolicy +from tianshou.policy.imitation.cql import CQLPolicy +from tianshou.policy.imitation.td3_bc import TD3BCPolicy +from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy +from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy +from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy +from tianshou.policy.imitation.gail import GAILPolicy +from tianshou.policy.modelbased.psrl import PSRLPolicy +from tianshou.policy.modelbased.icm import ICMPolicy +from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager + +__all__ = [ + "BasePolicy", + "RandomPolicy", + "DQNPolicy", + "BranchingDQNPolicy", + "C51Policy", + "RainbowPolicy", + "QRDQNPolicy", + "IQNPolicy", + "FQFPolicy", + "FQF_RainbowPolicy", + "PGPolicy", + "A2CPolicy", + "NPGPolicy", + "DDPGPolicy", + "PPOPolicy", + "TRPOPolicy", + "TD3Policy", + "SACPolicy", + "REDQPolicy", + "DiscreteSACPolicy", + "ImitationPolicy", + "BCQPolicy", + "CQLPolicy", + "TD3BCPolicy", + "DiscreteBCQPolicy", + "DiscreteCQLPolicy", + "DiscreteCRRPolicy", + "GAILPolicy", + "PSRLPolicy", + "ICMPolicy", + "MultiAgentPolicyManager", + "TrainingStats", +] diff --git a/examples/atari/tianshou/policy/base.py b/examples/atari/tianshou/policy/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b7ae5f23d516570f08185e0978586e91a96bed74 --- /dev/null +++ b/examples/atari/tianshou/policy/base.py @@ -0,0 +1,769 @@ +import logging +import time +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete +from numba import njit +from overrides import override +from torch import nn + +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as +from tianshou.data.batch import Batch, BatchProtocol, arr_type +from tianshou.data.buffer.base import TBuffer +from tianshou.data.types import ( + ActBatchProtocol, + ActStateBatchProtocol, + BatchWithReturnsProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.utils import MultipleLRSchedulers +from tianshou.utils.print import DataclassPPrintMixin +from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode + +logger = logging.getLogger(__name__) + +TLearningRateScheduler: TypeAlias = torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers + + +@dataclass(kw_only=True) +class TrainingStats(DataclassPPrintMixin): + _non_loss_fields = ("train_time", "smoothed_loss") + + train_time: float = 0.0 + """The time for learning models.""" + + # TODO: modified in the trainer but not used anywhere else. Should be refactored. 
+ smoothed_loss: dict = field(default_factory=dict) + """The smoothed loss statistics of the policy learn step.""" + + # Mainly so that we can override this in the TrainingStatsWrapper + def _get_self_dict(self) -> dict[str, Any]: + return self.__dict__ + + def get_loss_stats_dict(self) -> dict[str, float]: + """Return loss statistics as a dict for logging. + + Returns a dict with all fields except train_time and smoothed_loss. Moreover, fields with value None excluded, + and instances of SequenceSummaryStats are replaced by their mean. + """ + result = {} + for k, v in self._get_self_dict().items(): + if k.startswith("_"): + logger.debug(f"Skipping {k=} as it starts with an underscore.") + continue + if k in self._non_loss_fields or v is None: + continue + if isinstance(v, SequenceSummaryStats): + result[k] = v.mean + else: + result[k] = v + + return result + + +class TrainingStatsWrapper(TrainingStats): + _setattr_frozen = False + _training_stats_public_fields = TrainingStats.__dataclass_fields__.keys() + + def __init__(self, wrapped_stats: TrainingStats) -> None: + """In this particular case, super().__init__() should be called LAST in the subclass init.""" + self._wrapped_stats = wrapped_stats + + # HACK: special sauce for the existing attributes of the base TrainingStats class + # for some reason, delattr doesn't work here, so we need to delegate their handling + # to the wrapped stats object by always keeping the value there and in self in sync + # see also __setattr__ + for k in self._training_stats_public_fields: + super().__setattr__(k, getattr(self._wrapped_stats, k)) + + self._setattr_frozen = True + + @override + def _get_self_dict(self) -> dict[str, Any]: + return {**self._wrapped_stats._get_self_dict(), **self.__dict__} + + @property + def wrapped_stats(self) -> TrainingStats: + return self._wrapped_stats + + def __getattr__(self, name: str) -> Any: + return getattr(self._wrapped_stats, name) + + def __setattr__(self, name: str, value: Any) -> None: + """Setattr logic for wrapper of a dataclass with default values. + + 1. If name exists directly in self, set it there. + 2. If it exists in self._wrapped_stats, set it there instead. + 3. Special case: if name is in the base TrainingStats class, keep it in sync between self and the _wrapped_stats. + 4. If name doesn't exist in either and attribute setting is frozen, raise an AttributeError. + """ + # HACK: special sauce for the existing attributes of the base TrainingStats class, see init + # Need to keep them in sync with the wrapped stats object + if name in self._training_stats_public_fields: + setattr(self._wrapped_stats, name, value) + super().__setattr__(name, value) + return + + if not self._setattr_frozen: + super().__setattr__(name, value) + return + + if not hasattr(self, name): + raise AttributeError( + f"Setting new attributes on StatsWrappers outside of init is not allowed. " + f"Tried to set {name=}, {value=} on {self.__class__.__name__}. \n" + f"NOTE: you may get this error if you call super().__init__() in your subclass init too early! " + f"The call to super().__init__() should be the last call in your subclass init.", + ) + if hasattr(self._wrapped_stats, name): + setattr(self._wrapped_stats, name, value) + else: + super().__setattr__(name, value) + + +TTrainingStats = TypeVar("TTrainingStats", bound=TrainingStats) + + +class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): + """The base class for any RL policy. + + Tianshou aims to modularize RL algorithms. 
It comes into several classes of + policies in Tianshou. All policy classes must inherit from + :class:`~tianshou.policy.BasePolicy`. + + A policy class typically has the following parts: + + * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including \ + coping the target network and so on; + * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \ + observation; + * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the \ + replay buffer (this function can interact with replay buffer); + * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of \ + data. + * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the replay buffer \ + from the learning process (e.g., prioritized replay buffer needs to update \ + the weight); + * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training, \ + i.e., `process_fn -> learn -> post_process_fn`. + + Most of the policy needs a neural network to predict the action and an + optimizer to optimize the policy. The rules of self-defined networks are: + + 1. Input: observation "obs" (may be a ``numpy.ndarray``, a ``torch.Tensor``, a \ + dict or any others), hidden state "state" (for RNN usage), and other information \ + "info" provided by the environment. + 2. Output: some "logits", the next hidden state "state", and the intermediate \ + result during policy forwarding procedure "policy". The "logits" could be a tuple \ + instead of a ``torch.Tensor``. It depends on how the policy process the network \ + output. For example, in PPO, the return of the network might be \ + ``(mu, sigma), state`` for Gaussian policy. The "policy" can be a Batch of \ + torch.Tensor or other things, which will be stored in the replay buffer, and can \ + be accessed in the policy update process (e.g. in "policy.learn()", the \ + "batch.policy" is what you need). + + Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can + use :class:`~tianshou.policy.BasePolicy` almost the same as ``torch.nn.Module``, + for instance, loading and saving the model: + :: + + torch.save(policy.state_dict(), "policy.pth") + policy.load_state_dict(torch.load("policy.pth")) + + :param action_space: Env's action_space. + :param observation_space: Env's observation space. TODO: appears unused... + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + + def __init__( + self, + *, + action_space: gym.Space, + # TODO: does the policy actually need the observation space? + observation_space: gym.Space | None = None, + action_scaling: bool = False, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + allowed_action_bound_methods = ("clip", "tanh") + if ( + action_bound_method is not None + and action_bound_method not in allowed_action_bound_methods + ): + raise ValueError( + f"Got invalid {action_bound_method=}. 
" + f"Valid values are: {allowed_action_bound_methods}.", + ) + if action_scaling and not isinstance(action_space, Box): + raise ValueError( + f"action_scaling can only be True when action_space is Box but " + f"got: {action_space}", + ) + + super().__init__() + self.observation_space = observation_space + self.action_space = action_space + if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary): + action_type = "discrete" + elif isinstance(action_space, Box): + action_type = "continuous" + else: + raise ValueError(f"Unsupported action space: {action_space}.") + self._action_type = cast(Literal["discrete", "continuous"], action_type) + self.agent_id = 0 + self.updating = False + self.action_scaling = action_scaling + self.action_bound_method = action_bound_method + self.lr_scheduler = lr_scheduler + self.is_within_training_step = False + """ + flag indicating whether we are currently within a training step, + which encompasses data collection for training (in online RL algorithms) + and the policy update (gradient steps). + + It can be used, for example, to control whether a flag controlling deterministic evaluation should + indeed be applied, because within a training step, we typically always want to apply stochastic evaluation + (even if such a flag is enabled), as well as stochastic action computation for q-targets (e.g. in SAC + based algorithms). + + This flag should normally remain False and should be set to True only by the algorithm which performs + training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer, + the user should ensure that this flag is set correctly before calling update or learn. + """ + self._compile() + + def __setstate__(self, state: dict[str, Any]) -> None: + # TODO Use setstate function once merged + if "is_within_training_step" not in state: + state["is_within_training_step"] = False + self.__dict__ = state + + @property + def action_type(self) -> Literal["discrete", "continuous"]: + return self._action_type + + def set_agent_id(self, agent_id: int) -> None: + """Set self.agent_id = agent_id, for MARL.""" + self.agent_id = agent_id + + # TODO: needed, since for most of offline algorithm, the algorithm itself doesn't + # have a method to add noise to action. + # So we add the default behavior here. It's a little messy, maybe one can + # find a better way to do this. + + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + + def exploration_noise( + self, + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: + """Modify the action from policy.forward with exploration noise. + + NOTE: currently does not add any noise! Needs to be overridden by subclasses + to actually do something. + + :param act: a data batch or numpy.ndarray which is the action taken by + policy.forward. + :param batch: the input batch for policy.forward, kept for advanced usage. + :return: action in the same form of input "act" but with added exploration + noise. 
+ """ + return act + + def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None: + """Softly update the parameters of target module towards the parameters of source module.""" + for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True): + tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) + + def compute_action( + self, + obs: arr_type, + info: dict[str, Any] | None = None, + state: dict | BatchProtocol | np.ndarray | None = None, + ) -> np.ndarray | int: + """Get action as int (for discrete env's) or array (for continuous ones) from an env's observation and info. + + :param obs: observation from the gym's env. + :param info: information given by the gym's env. + :param state: the hidden state of RNN policy, used for recurrent policy. + :return: action as int (for discrete env's) or array (for continuous ones). + """ + # need to add empty batch dimension + obs = obs[None, :] + obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info)) + act = self.forward(obs_batch, state=state).act.squeeze() + if isinstance(act, torch.Tensor): + act = act.detach().cpu().numpy() + act = self.map_action(act) + if isinstance(self.action_space, Discrete): + # could be an array of shape (), easier to just convert to int + act = int(act) # type: ignore + return act + + @abstractmethod + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> ActBatchProtocol | ActStateBatchProtocol: # TODO: make consistent typing + """Compute action over the given batch data. + + :return: A :class:`~tianshou.data.Batch` which MUST have the following keys: + + * ``act`` a numpy.ndarray or a torch.Tensor, the action over \ + given batch data. + * ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \ + internal state of the policy, ``None`` as default. + + Other keys are user-defined. It depends on the algorithm. For example, + :: + + # some code + return Batch(logits=..., act=..., state=None, dist=...) + + The keyword ``policy`` is reserved and the corresponding data will be + stored into the replay buffer. For instance, + :: + + # some code + return Batch(..., policy=Batch(log_prob=dist.log_prob(act))) + # and in the sampled data batch, you can directly use + # batch.policy.log_prob to get your data. + + .. note:: + + In continuous action space, you should do another step "map_action" to get + the real action: + :: + + act = policy(batch).act # doesn't map to the target action range + act = policy.map_action(act, batch) + """ + + @staticmethod + def _action_to_numpy(act: arr_type) -> np.ndarray: + act = to_numpy(act) # NOTE: to_numpy could confusingly also return a Batch + if not isinstance(act, np.ndarray): + raise ValueError( + f"act should have been be a numpy.ndarray, but got {type(act)}.", + ) + return act + + def map_action( + self, + act: arr_type, + ) -> np.ndarray: + """Map raw network output to action range in gym's env.action_space. + + This function is called in :meth:`~tianshou.data.Collector.collect` and only + affects action sending to env. Remapped action will not be stored in buffer + and thus can be viewed as a part of env (a black box action transformation). + + Action mapping includes 2 standard procedures: bounding and scaling. Bounding + procedure expects original action range is (-inf, inf) and maps it to [-1, 1], + while scaling procedure expects original action range is (-1, 1) and maps it + to [action_space.low, action_space.high]. 
Bounding procedure is applied first. + + :param act: a data batch or numpy.ndarray which is the action taken by + policy.forward. + + :return: action in the same form of input "act" but remap to the target action + space. + """ + act = self._action_to_numpy(act) + if isinstance(self.action_space, gym.spaces.Box): + if self.action_bound_method == "clip": + act = np.clip(act, -1.0, 1.0) + elif self.action_bound_method == "tanh": + act = np.tanh(act) + if self.action_scaling: + assert ( + np.min(act) >= -1.0 and np.max(act) <= 1.0 + ), f"action scaling only accepts raw action range = [-1, 1], but got: {act}" + low, high = self.action_space.low, self.action_space.high + act = low + (high - low) * (act + 1.0) / 2.0 + return act + + def map_action_inverse( + self, + act: arr_type, + ) -> np.ndarray: + """Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`. + + This function is called in :meth:`~tianshou.data.Collector.collect` for + random initial steps. It scales [action_space.low, action_space.high] to + the value ranges of policy.forward. + + :param act: a data batch, list or numpy.ndarray which is the action taken + by gym.spaces.Box.sample(). + + :return: action remapped. + """ + act = self._action_to_numpy(act) + if isinstance(self.action_space, gym.spaces.Box): + if self.action_scaling: + low, high = self.action_space.low, self.action_space.high + scale = high - low + eps = np.finfo(np.float32).eps.item() + scale[scale < eps] += eps + act = (act - low) * 2.0 / scale - 1.0 + if self.action_bound_method == "tanh": + act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0 + + return act + + def process_buffer(self, buffer: TBuffer) -> TBuffer: + """Pre-process the replay buffer, e.g., to add new keys. + + Used in BaseTrainer initialization method, usually used by offline trainers. + + Note: this will only be called once, when the trainer is initialized! + If the buffer is empty by then, there will be nothing to process. + This method is meant to be overridden by policies which will be trained + offline at some stage, e.g., in a pre-training step. + """ + return buffer + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + """Pre-process the data from the provided replay buffer. + + Meant to be overridden by subclasses. Typical usage is to add new keys to the + batch, e.g., to add the value function of the next state. Used in :meth:`update`, + which is usually called repeatedly during training. + + For modifying the replay buffer only once at the beginning + (e.g., for offline learning) see :meth:`process_buffer`. + """ + return batch + + @abstractmethod + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTrainingStats: + """Update policy with a given batch of data. + + :return: A dataclass object, including the data needed to be logged (e.g., loss). + + .. note:: + + In order to distinguish the collecting state, updating state and + testing state, you can check the policy state by ``self.training`` + and ``self.updating``. Please refer to :ref:`policy_state` for more + detailed explanation. + + .. warning:: + + If you use ``torch.distributions.Normal`` and + ``torch.distributions.Categorical`` to calculate the log_prob, + please be careful about the shape: Categorical distribution gives + "[batch_size]" shape while Normal distribution gives "[batch_size, + 1]" shape. The auto-broadcasting of numerical operation with torch + tensors will amplify this error. 
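+
+            For illustration (shapes only, assuming a batch size of 4)::
+
+                torch.distributions.Categorical(
+                    logits=torch.zeros(4, 3)
+                ).log_prob(torch.zeros(4, dtype=torch.long)).shape   # torch.Size([4])
+                torch.distributions.Normal(
+                    torch.zeros(4, 1), torch.ones(4, 1)
+                ).log_prob(torch.zeros(4, 1)).shape                  # torch.Size([4, 1])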
+ """ + + def post_process_fn( + self, + batch: BatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> None: + """Post-process the data from the provided replay buffer. + + This will only have an effect if the buffer has the + method `update_weight` and the batch has the attribute `weight`. + + Typical usage is to update the sampling weight in prioritized + experience replay. Used in :meth:`update`. + """ + if hasattr(buffer, "update_weight"): + if hasattr(batch, "weight"): + buffer.update_weight(indices, batch.weight) + else: + logger.warning( + "batch has no attribute 'weight', but buffer has an " + "update_weight method. This is probably a mistake." + "Prioritized replay is disabled for this batch.", + ) + + def update( + self, + sample_size: int | None, + buffer: ReplayBuffer | None, + **kwargs: Any, + ) -> TTrainingStats: + """Update the policy network and replay buffer. + + It includes 3 function steps: process_fn, learn, and post_process_fn. In + addition, this function will change the value of ``self.updating``: it will be + False before this function and will be True when executing :meth:`update`. + Please refer to :ref:`policy_state` for more detailed explanation. The return + value of learn is augmented with the training time within update, while smoothed + loss values are computed in the trainer. + + :param sample_size: 0 means it will extract all the data from the buffer, + otherwise it will sample a batch with given sample_size. None also + means it will extract all the data from the buffer, but it will be shuffled + first. TODO: remove the option for 0? + :param buffer: the corresponding replay buffer. + + :return: A dataclass object containing the data needed to be logged (e.g., loss) from + ``policy.learn()``. + """ + # TODO: when does this happen? + # -> this happens never in practice as update is either called with a collector buffer or an assert before + + if not self.is_within_training_step: + raise RuntimeError( + f"update() was called outside of a training step as signalled by {self.is_within_training_step=} " + f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned " + f"flag yourself. You can to this e.g., by using the contextmanager {policy_within_training_step.__name__}.", + ) + + if buffer is None: + return TrainingStats() # type: ignore[return-value] + start_time = time.time() + batch, indices = buffer.sample(sample_size) + self.updating = True + batch = self.process_fn(batch, buffer, indices) + with torch_train_mode(self): + training_stat = self.learn(batch, **kwargs) + self.post_process_fn(batch, buffer, indices) + if self.lr_scheduler is not None: + self.lr_scheduler.step() + self.updating = False + training_stat.train_time = time.time() - start_time + return training_stat + + @staticmethod + def value_mask(buffer: ReplayBuffer, indices: np.ndarray) -> np.ndarray: + """Value mask determines whether the obs_next of buffer[indices] is valid. + + For instance, usually "obs_next" after "done" flag is considered to be invalid, + and its q/advantage value can provide meaningless (even misleading) + information, and should be set to 0 by hand. But if "done" flag is generated + because timelimit of game length (info["TimeLimit.truncated"] is set to True in + gym's settings), "obs_next" will instead be valid. Value mask is typically used + for assisting in calculating the correct q/advantage value. + + :param buffer: the corresponding replay buffer. 
+ :param numpy.ndarray indices: indices of replay buffer whose "obs_next" will be + judged. + + :return: A bool type numpy.ndarray in the same shape with indices. "True" means + "obs_next" of that buffer[indices] is valid. + """ + return ~buffer.terminated[indices] + + @staticmethod + def compute_episodic_return( + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + v_s_: np.ndarray | torch.Tensor | None = None, + v_s: np.ndarray | torch.Tensor | None = None, + gamma: float = 0.99, + gae_lambda: float = 0.95, + ) -> tuple[np.ndarray, np.ndarray]: + r"""Compute returns over given batch. + + Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) + to calculate q/advantage value of given batch. Returns are calculated as + advantage + value, which is exactly equivalent to using :math:`TD(\lambda)` + for estimating returns. + + Setting `v_s_` and `v_s` to None (or all zeros) and `gae_lambda` to 1.0 calculates the + discounted return-to-go/ Monte-Carlo return. + + :param batch: a data batch which contains several episodes of data in + sequential order. Mind that the end of each finished episode of batch + should be marked by done flag, unfinished (or collecting) episodes will be + recognized by buffer.unfinished_index(). + :param buffer: the corresponding replay buffer. + :param indices: tells the batch's location in buffer, batch is equal + to buffer[indices]. + :param v_s_: the value function of all next states :math:`V(s')`. + If None, it will be set to an array of 0. + :param v_s: the value function of all current states :math:`V(s)`. If None, + it is set based upon `v_s_` rolled by 1. + :param gamma: the discount factor, should be in [0, 1]. + :param gae_lambda: the parameter for Generalized Advantage Estimation, + should be in [0, 1]. + + :return: two numpy arrays (returns, advantage) with each shape (bsz, ). + """ + rew = batch.rew + if v_s_ is None: + assert np.isclose(gae_lambda, 1.0) + v_s_ = np.zeros_like(rew) + else: + v_s_ = to_numpy(v_s_.flatten()) + v_s_ = v_s_ * BasePolicy.value_mask(buffer, indices) + v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten()) + + end_flag = np.logical_or(batch.terminated, batch.truncated) + end_flag[np.isin(indices, buffer.unfinished_index())] = True + advantage = _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + returns = advantage + v_s + # normalization varies from each policy, so we don't do it here + return returns, advantage + + @staticmethod + def compute_nstep_return( + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor], + gamma: float = 0.99, + n_step: int = 1, + rew_norm: bool = False, + ) -> BatchWithReturnsProtocol: + r"""Compute n-step return for Q-learning targets. + + .. math:: + G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i + + \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n}) + + where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`, + :math:`d_t` is the done flag of step :math:`t`. + + :param batch: a data batch, which is equal to buffer[indices]. + :param buffer: the data buffer. + :param indices: tell batch's location in buffer + :param function target_q_fn: a function which compute target Q value + of "obs_next" given data buffer and wanted indices. + :param gamma: the discount factor, should be in [0, 1]. + :param n_step: the number of estimation step, should be an int greater + than 0. 
+ :param rew_norm: normalize the reward to Normal(0, 1). + TODO: passing True is not supported and will cause an error! + :return: a Batch. The result will be stored in batch.returns as a + torch.Tensor with the same shape as target_q_fn's return tensor. + """ + assert not rew_norm, "Reward normalization in computing n-step returns is unsupported now." + if len(indices) != len(batch): + raise ValueError(f"Batch size {len(batch)} and indices size {len(indices)} mismatch.") + + rew = buffer.rew + bsz = len(indices) + indices = [indices] + for _ in range(n_step - 1): + indices.append(buffer.next(indices[-1])) + indices = np.stack(indices) + # terminal indicates buffer indexes nstep after 'indices', + # and are truncated at the end of each episode + terminal = indices[-1] + with torch.no_grad(): + target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) + target_q = to_numpy(target_q_torch.reshape(bsz, -1)) + target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1) + end_flag = buffer.done.copy() + end_flag[buffer.unfinished_index()] = True + target_q = _nstep_return(rew, end_flag, target_q, indices, gamma, n_step) + + batch.returns = to_torch_as(target_q, target_q_torch) + if hasattr(batch, "weight"): # prio buffer update + batch.weight = to_torch_as(batch.weight, target_q_torch) + return cast(BatchWithReturnsProtocol, batch) + + @staticmethod + def _compile() -> None: + f64 = np.array([0, 1], dtype=np.float64) + f32 = np.array([0, 1], dtype=np.float32) + b = np.array([False, True], dtype=np.bool_) + i64 = np.array([[0, 1]], dtype=np.int64) + _gae_return(f64, f64, f64, b, 0.1, 0.1) + _gae_return(f32, f32, f64, b, 0.1, 0.1) + _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1) + + +# TODO: rename? See docstring +@njit +def _gae_return( + v_s: np.ndarray, + v_s_: np.ndarray, + rew: np.ndarray, + end_flag: np.ndarray, + gamma: float, + gae_lambda: float, +) -> np.ndarray: + r"""Computes advantages with GAE. + + Note: doesn't compute returns but rather advantages. The return + is given by the output of this + v_s. Note that the advantages plus v_s + is exactly the same as the TD-lambda target, which is computed by the recursive + formula: + + .. math:: + G_t^\lambda = r_t + \gamma ( \lambda G_{t+1}^\lambda + (1 - \lambda) V_{t+1} ) + + The GAE is computed recursively as: + + .. math:: + \delta_t = r_t + \gamma V_{t+1} - V_t \n + A_t^\lambda= \delta_t + \gamma \lambda A_{t+1}^\lambda + + And the following equality holds: + + .. math:: + G_t^\lambda = A_t^\lambda+ V_t + + :param v_s: values in an episode, i.e. $V_t$ + :param v_s_: next values in an episode, i.e. v_s shifted by 1, equivalent to + $V_{t+1}$ + :param rew: rewards in an episode, i.e. 
$r_t$ + :param end_flag: boolean array indicating whether the episode is done + :param gamma: discount factor + :param gae_lambda: lambda parameter for GAE, controlling the bias-variance tradeoff + :return: + """ + returns = np.zeros(rew.shape) + delta = rew + v_s_ * gamma - v_s + discount = (1.0 - end_flag) * (gamma * gae_lambda) + gae = 0.0 + for i in range(len(rew) - 1, -1, -1): + gae = delta[i] + discount[i] * gae + returns[i] = gae + return returns + + +@njit +def _nstep_return( + rew: np.ndarray, + end_flag: np.ndarray, + target_q: np.ndarray, + indices: np.ndarray, + gamma: float, + n_step: int, +) -> np.ndarray: + gamma_buffer = np.ones(n_step + 1) + for i in range(1, n_step + 1): + gamma_buffer[i] = gamma_buffer[i - 1] * gamma + target_shape = target_q.shape + bsz = target_shape[0] + # change target_q to 2d array + target_q = target_q.reshape(bsz, -1) + returns = np.zeros(target_q.shape) + gammas = np.full(indices[0].shape, n_step) + for n in range(n_step - 1, -1, -1): + now = indices[n] + gammas[end_flag[now] > 0] = n + 1 + returns[end_flag[now] > 0] = 0.0 + returns = rew[now].reshape(bsz, 1) + gamma * returns + target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns + return target_q.reshape(target_shape) diff --git a/examples/atari/tianshou/policy/imitation/__init__.py b/examples/atari/tianshou/policy/imitation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/policy/imitation/base.py b/examples/atari/tianshou/policy/imitation/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6e21016d97aefe5dae24fc2c40b9a15f6676338c --- /dev/null +++ b/examples/atari/tianshou/policy/imitation/base.py @@ -0,0 +1,115 @@ +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import Batch, to_torch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ModelOutputBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.policy import BasePolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats + +# Dimension Naming Convention +# B - Batch Size +# A - Action +# D - Dist input (usually 2, loc and scale) +# H - Dimension of hidden, can be None + + +@dataclass(kw_only=True) +class ImitationTrainingStats(TrainingStats): + loss: float = 0.0 + + +TImitationTrainingStats = TypeVar("TImitationTrainingStats", bound=ImitationTrainingStats) + + +class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTrainingStats]): + """Implementation of vanilla imitation learning. + + :param actor: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> a) + :param optim: for optimizing the model. + :param action_space: Env's action_space. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. 
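+
+    A construction sketch (``actor_net`` and ``env`` are placeholders, not defined here)::
+
+        policy = ImitationPolicy(
+            actor=actor_net,
+            optim=torch.optim.Adam(actor_net.parameters(), lr=1e-3),
+            action_space=env.action_space,
+        )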
+ """ + + def __init__( + self, + *, + actor: torch.nn.Module, + optim: torch.optim.Optimizer, + action_space: gym.Space, + observation_space: gym.Space | None = None, + action_scaling: bool = False, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + lr_scheduler=lr_scheduler, + ) + self.actor = actor + self.optim = optim + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> ModelOutputBatchProtocol: + # TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced + if self.action_type == "discrete": + # If it's discrete, the "actor" is usually a critic that maps obs to action_values + # which then could be turned into logits or a Categorigal + action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + act_B = action_values_BA.argmax(dim=1) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) + elif self.action_type == "continuous": + # If it's continuous, the actor would usually deliver something like loc, scale determining a + # Gaussian dist + dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH) + else: + raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!") + return cast(ModelOutputBatchProtocol, result) + + def learn( + self, + batch: RolloutBatchProtocol, + *ags: Any, + **kwargs: Any, + ) -> TImitationTrainingStats: + self.optim.zero_grad() + if self.action_type == "continuous": # regression + act = self(batch).act + act_target = to_torch(batch.act, dtype=torch.float32, device=act.device) + loss = F.mse_loss(act, act_target) + elif self.action_type == "discrete": # classification + act = F.log_softmax(self(batch).logits, dim=-1) + act_target = to_torch(batch.act, dtype=torch.long, device=act.device) + loss = F.nll_loss(act, act_target) + loss.backward() + self.optim.step() + + return ImitationTrainingStats(loss=loss.item()) # type: ignore diff --git a/examples/atari/tianshou/policy/imitation/bcq.py b/examples/atari/tianshou/policy/imitation/bcq.py new file mode 100644 index 0000000000000000000000000000000000000000..dee1a80a35fe9a7a3398dd28654c6881cdccfdc6 --- /dev/null +++ b/examples/atari/tianshou/policy/imitation/bcq.py @@ -0,0 +1,233 @@ +import copy +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import Batch, to_torch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.policy import BasePolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils.net.continuous import VAE +from tianshou.utils.optim import clone_optimizer + + +@dataclass(kw_only=True) +class BCQTrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + vae_loss: float + + +TBCQTrainingStats = TypeVar("TBCQTrainingStats", bound=BCQTrainingStats) + + +class BCQPolicy(BasePolicy[TBCQTrainingStats], Generic[TBCQTrainingStats]): + """Implementation of BCQ algorithm. 
arXiv:1812.02900. + + :param actor_perturbation: the actor perturbation. `(s, a -> perturbed a)` + :param actor_perturbation_optim: the optimizer for actor network. + :param critic: the first critic network. + :param critic_optim: the optimizer for the first critic network. + :param critic2: the second critic network. + :param critic2_optim: the optimizer for the second critic network. + :param vae: the VAE network, generating actions similar to those in batch. + :param vae_optim: the optimizer for the VAE network. + :param device: which device to create this model on. + :param gamma: discount factor, in [0, 1]. + :param tau: param for soft update of the target network. + :param lmbda: param for Clipped Double Q-learning. + :param forward_sampled_times: the number of sampled actions in forward function. + The policy samples many actions and takes the action with the max value. + :param num_sampled_action: the number of sampled actions in calculating target Q. + The algorithm samples several actions using VAE, and perturbs each action to get the target Q. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. + """ + + def __init__( + self, + *, + actor_perturbation: torch.nn.Module, + actor_perturbation_optim: torch.optim.Optimizer, + critic: torch.nn.Module, + critic_optim: torch.optim.Optimizer, + action_space: gym.Space, + vae: VAE, + vae_optim: torch.optim.Optimizer, + critic2: torch.nn.Module | None = None, + critic2_optim: torch.optim.Optimizer | None = None, + # TODO: remove? Many policies don't use this + device: str | torch.device = "cpu", + gamma: float = 0.99, + tau: float = 0.005, + lmbda: float = 0.75, + forward_sampled_times: int = 100, + num_sampled_action: int = 10, + observation_space: gym.Space | None = None, + action_scaling: bool = False, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + # actor is Perturbation! 
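+        # i.e. BCQ's "actor" is a perturbation network that adjusts actions decoded by the
+        # VAE (see forward() and learn() below) rather than producing actions from scratch.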
+ super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + lr_scheduler=lr_scheduler, + ) + self.actor_perturbation = actor_perturbation + self.actor_perturbation_target = copy.deepcopy(self.actor_perturbation) + self.actor_perturbation_optim = actor_perturbation_optim + + self.critic = critic + self.critic_target = copy.deepcopy(self.critic) + self.critic_optim = critic_optim + + critic2 = critic2 or copy.deepcopy(critic) + critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) + self.critic2 = critic2 + self.critic2_target = copy.deepcopy(self.critic2) + self.critic2_optim = critic2_optim + + self.vae = vae + self.vae_optim = vae_optim + + self.gamma = gamma + self.tau = tau + self.lmbda = lmbda + self.device = device + self.forward_sampled_times = forward_sampled_times + self.num_sampled_action = num_sampled_action + + def train(self, mode: bool = True) -> Self: + """Set the module in training mode, except for the target network.""" + self.training = mode + self.actor_perturbation.train(mode) + self.critic.train(mode) + self.critic2.train(mode) + return self + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> ActBatchProtocol: + """Compute action over the given batch data.""" + # There is "obs" in the Batch + # obs_group: several groups. Each group has a state. + obs_group: torch.Tensor = to_torch(batch.obs, device=self.device) + act_group = [] + for obs_orig in obs_group: + # now obs is (state_dim) + obs = (obs_orig.reshape(1, -1)).repeat(self.forward_sampled_times, 1) + # now obs is (forward_sampled_times, state_dim) + + # decode(obs) generates action and actor perturbs it + act = self.actor_perturbation(obs, self.vae.decode(obs)) + # now action is (forward_sampled_times, action_dim) + q1 = self.critic(obs, act) + # q1 is (forward_sampled_times, 1) + max_indice = q1.argmax(0) + act_group.append(act[max_indice].cpu().data.numpy().flatten()) + act_group = np.array(act_group) + return cast(ActBatchProtocol, Batch(act=act_group)) + + def sync_weight(self) -> None: + """Soft-update the weight for the target network.""" + self.soft_update(self.critic_target, self.critic, self.tau) + self.soft_update(self.critic2_target, self.critic2, self.tau) + self.soft_update(self.actor_perturbation_target, self.actor_perturbation, self.tau) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBCQTrainingStats: + # batch: obs, act, rew, done, obs_next. (numpy array) + # (batch_size, state_dim) + batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) + obs, act = batch.obs, batch.act + batch_size = obs.shape[0] + + # mean, std: (state.shape[0], latent_dim) + recon, mean, std = self.vae(obs, act) + recon_loss = F.mse_loss(act, recon) + # (....) 
is D_KL( N(mu, sigma) || N(0,1) ) + KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean() + vae_loss = recon_loss + KL_loss / 2 + + self.vae_optim.zero_grad() + vae_loss.backward() + self.vae_optim.step() + + # critic training: + with torch.no_grad(): + # repeat num_sampled_action times + obs_next = batch.obs_next.repeat_interleave(self.num_sampled_action, dim=0) + # now obs_next: (num_sampled_action * batch_size, state_dim) + + # perturbed action generated by VAE + act_next = self.vae.decode(obs_next) + # now obs_next: (num_sampled_action * batch_size, action_dim) + target_Q1 = self.critic_target(obs_next, act_next) + target_Q2 = self.critic2_target(obs_next, act_next) + + # Clipped Double Q-learning + target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1 - self.lmbda) * torch.max( + target_Q1, + target_Q2, + ) + # now target_Q: (num_sampled_action * batch_size, 1) + + # the max value of Q + target_Q = target_Q.reshape(batch_size, -1).max(dim=1)[0].reshape(-1, 1) + # now target_Q: (batch_size, 1) + + target_Q = ( + batch.rew.reshape(-1, 1) + (1 - batch.done).reshape(-1, 1) * self.gamma * target_Q + ) + + current_Q1 = self.critic(obs, act) + current_Q2 = self.critic2(obs, act) + + critic1_loss = F.mse_loss(current_Q1, target_Q) + critic2_loss = F.mse_loss(current_Q2, target_Q) + + self.critic_optim.zero_grad() + self.critic2_optim.zero_grad() + critic1_loss.backward() + critic2_loss.backward() + self.critic_optim.step() + self.critic2_optim.step() + + sampled_act = self.vae.decode(obs) + perturbed_act = self.actor_perturbation(obs, sampled_act) + + # max + actor_loss = -self.critic(obs, perturbed_act).mean() + + self.actor_perturbation_optim.zero_grad() + actor_loss.backward() + self.actor_perturbation_optim.step() + + # update target network + self.sync_weight() + + return BCQTrainingStats( # type: ignore + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + vae_loss=vae_loss.item(), + ) diff --git a/examples/atari/tianshou/policy/imitation/cql.py b/examples/atari/tianshou/policy/imitation/cql.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce6d83d4938a19f3a304dd857bbb5e09858f3fe --- /dev/null +++ b/examples/atari/tianshou/policy/imitation/cql.py @@ -0,0 +1,401 @@ +from dataclasses import dataclass +from typing import Any, Literal, Self, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F +from overrides import override +from torch.nn.utils import clip_grad_norm_ + +from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.data.buffer.base import TBuffer +from tianshou.data.types import RolloutBatchProtocol +from tianshou.exploration import BaseNoise +from tianshou.policy import SACPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.sac import SACTrainingStats +from tianshou.utils.conversion import to_optional_float +from tianshou.utils.net.continuous import ActorProb + + +@dataclass(kw_only=True) +class CQLTrainingStats(SACTrainingStats): + """A data structure for storing loss statistics of the CQL learn step.""" + + cql_alpha: float | None = None + cql_alpha_loss: float | None = None + + +TCQLTrainingStats = TypeVar("TCQLTrainingStats", bound=CQLTrainingStats) + + +class CQLPolicy(SACPolicy[TCQLTrainingStats]): + """Implementation of CQL algorithm. arXiv:2006.04779. + + :param actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. 
(s -> a) + :param actor_optim: The optimizer for actor network. + :param critic: The first critic network. + :param critic_optim: The optimizer for the first critic network. + :param action_space: Env's action space. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer for the second critic network. + If None, clone critic_optim to use for critic2.parameters(). + :param cql_alpha_lr: The learning rate of cql_log_alpha. + :param cql_weight: + :param tau: Parameter for soft update of the target network. + :param gamma: Discount factor, in [0, 1]. + :param alpha: Entropy regularization coefficient or a tuple + (target_entropy, log_alpha, alpha_optim) for automatic tuning. + :param temperature: + :param with_lagrange: Whether to use Lagrange. + TODO: extend documentation - what does this mean? + :param lagrange_threshold: The value of tau in CQL(Lagrange). + :param min_action: The minimum value of each dimension of action. + :param max_action: The maximum value of each dimension of action. + :param num_repeat_actions: The number of times the action is repeated when calculating log-sum-exp. + :param alpha_min: Lower bound for clipping cql_alpha. + :param alpha_max: Upper bound for clipping cql_alpha. + :param clip_grad: Clip_grad for updating critic network. + :param calibrated: calibrate Q-values as in CalQL paper `arXiv:2303.05479`. + Useful for offline pre-training followed by online training, + and also was observed to achieve better results than vanilla cql. + :param device: Which device to create this model on. + :param estimation_step: Estimation steps. + :param exploration_noise: Type of exploration noise. + :param deterministic_eval: Flag for deterministic evaluation. + :param action_scaling: Flag for action scaling. + :param action_bound_method: Method for action bounding. Only used if the + action_space is continuous. + :param observation_space: Env's Observation space. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + actor: ActorProb, + actor_optim: torch.optim.Optimizer, + critic: torch.nn.Module, + critic_optim: torch.optim.Optimizer, + action_space: gym.spaces.Box, + critic2: torch.nn.Module | None = None, + critic2_optim: torch.optim.Optimizer | None = None, + cql_alpha_lr: float = 1e-4, + cql_weight: float = 1.0, + tau: float = 0.005, + gamma: float = 0.99, + alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, + temperature: float = 1.0, + with_lagrange: bool = True, + lagrange_threshold: float = 10.0, + min_action: float = -1.0, + max_action: float = 1.0, + num_repeat_actions: int = 10, + alpha_min: float = 0.0, + alpha_max: float = 1e6, + clip_grad: float = 1.0, + calibrated: bool = True, + # TODO: why does this one have device? 
Almost no other policies have it + device: str | torch.device = "cpu", + estimation_step: int = 1, + exploration_noise: BaseNoise | Literal["default"] | None = None, + deterministic_eval: bool = True, + action_scaling: bool = True, + action_bound_method: Literal["clip"] | None = "clip", + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + actor=actor, + actor_optim=actor_optim, + critic=critic, + critic_optim=critic_optim, + action_space=action_space, + critic2=critic2, + critic2_optim=critic2_optim, + tau=tau, + gamma=gamma, + deterministic_eval=deterministic_eval, + alpha=alpha, + exploration_noise=exploration_noise, + estimation_step=estimation_step, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + # There are _target_entropy, _log_alpha, _alpha_optim in SACPolicy. + self.device = device + self.temperature = temperature + self.with_lagrange = with_lagrange + self.lagrange_threshold = lagrange_threshold + + self.cql_weight = cql_weight + + self.cql_log_alpha = torch.tensor([0.0], requires_grad=True) + self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr) + self.cql_log_alpha = self.cql_log_alpha.to(device) + + self.min_action = min_action + self.max_action = max_action + + self.num_repeat_actions = num_repeat_actions + + self.alpha_min = alpha_min + self.alpha_max = alpha_max + self.clip_grad = clip_grad + + self.calibrated = calibrated + + def train(self, mode: bool = True) -> Self: + """Set the module in training mode, except for the target network.""" + self.training = mode + self.actor.train(mode) + self.critic.train(mode) + self.critic2.train(mode) + return self + + def sync_weight(self) -> None: + """Soft-update the weight for the target network.""" + self.soft_update(self.critic_old, self.critic, self.tau) + self.soft_update(self.critic2_old, self.critic2, self.tau) + + def actor_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch = Batch(obs=obs, info=[None] * len(obs)) + obs_result = self(batch) + return obs_result.act, obs_result.log_prob + + def calc_actor_loss(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + act_pred, log_pi = self.actor_pred(obs) + q1 = self.critic(obs, act_pred) + q2 = self.critic2(obs, act_pred) + min_Q = torch.min(q1, q2) + # self.alpha: float | torch.Tensor + actor_loss = (self.alpha * log_pi - min_Q).mean() + # actor_loss.shape: (), log_pi.shape: (batch_size, 1) + return actor_loss, log_pi + + def calc_pi_values( + self, + obs_pi: torch.Tensor, + obs_to_pred: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + act_pred, log_pi = self.actor_pred(obs_pi) + + q1 = self.critic(obs_to_pred, act_pred) + q2 = self.critic2(obs_to_pred, act_pred) + + return q1 - log_pi.detach(), q2 - log_pi.detach() + + def calc_random_values( + self, + obs: torch.Tensor, + act: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + random_value1 = self.critic(obs, act) + random_log_prob1 = np.log(0.5 ** act.shape[-1]) + + random_value2 = self.critic2(obs, act) + random_log_prob2 = np.log(0.5 ** act.shape[-1]) + + return random_value1 - random_log_prob1, random_value2 - random_log_prob2 + + @override + def process_buffer(self, buffer: TBuffer) -> TBuffer: + """If `self.calibrated = True`, adds `calibration_returns` to buffer._meta. 
+ + :param buffer: + :return: + """ + if self.calibrated: + # otherwise _meta hack cannot work + assert isinstance(buffer, ReplayBuffer) + batch, indices = buffer.sample(0) + returns, _ = self.compute_episodic_return( + batch=batch, + buffer=buffer, + indices=indices, + gamma=self.gamma, + gae_lambda=1.0, + ) + # TODO: don't access _meta directly + buffer._meta = cast( + RolloutBatchProtocol, + Batch(**buffer._meta.__dict__, calibration_returns=returns), + ) + return buffer + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + # TODO: mypy rightly complains here b/c the design violates + # Liskov Substitution Principle + # DDPGPolicy.process_fn() results in a batch with returns but + # CQLPolicy.process_fn() doesn't add the returns. + # Should probably be fixed! + return batch + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TCQLTrainingStats: # type: ignore + batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) + obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next + batch_size = obs.shape[0] + + # compute actor loss and update actor + actor_loss, log_pi = self.calc_actor_loss(obs) + self.actor_optim.zero_grad() + actor_loss.backward() + self.actor_optim.step() + + alpha_loss = None + # compute alpha loss + if self.is_auto_alpha: + log_pi = log_pi + self.target_entropy + alpha_loss = -(self.log_alpha * log_pi.detach()).mean() + self.alpha_optim.zero_grad() + # update log_alpha + alpha_loss.backward() + self.alpha_optim.step() + # update alpha + # TODO: it's probably a bad idea to track both alpha and log_alpha in different fields + self.alpha = self.log_alpha.detach().exp() + + # compute target_Q + with torch.no_grad(): + act_next, new_log_pi = self.actor_pred(obs_next) + + target_Q1 = self.critic_old(obs_next, act_next) + target_Q2 = self.critic2_old(obs_next, act_next) + + target_Q = torch.min(target_Q1, target_Q2) - self.alpha * new_log_pi + + target_Q = rew + self.gamma * (1 - batch.done) * target_Q.flatten() + # shape: (batch_size) + + # compute critic loss + current_Q1 = self.critic(obs, act).flatten() + current_Q2 = self.critic2(obs, act).flatten() + # shape: (batch_size) + + critic1_loss = F.mse_loss(current_Q1, target_Q) + critic2_loss = F.mse_loss(current_Q2, target_Q) + + # CQL + random_actions = ( + torch.FloatTensor(batch_size * self.num_repeat_actions, act.shape[-1]) + .uniform_(-self.min_action, self.max_action) + .to(self.device) + ) + + obs_len = len(obs.shape) + repeat_size = [1, self.num_repeat_actions] + [1] * (obs_len - 1) + view_size = [batch_size * self.num_repeat_actions, *list(obs.shape[1:])] + tmp_obs = obs.unsqueeze(1).repeat(*repeat_size).view(*view_size) + tmp_obs_next = obs_next.unsqueeze(1).repeat(*repeat_size).view(*view_size) + # tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim) + + current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs) + next_pi_value1, next_pi_value2 = self.calc_pi_values(tmp_obs_next, tmp_obs) + + random_value1, random_value2 = self.calc_random_values(tmp_obs, random_actions) + + for value in [ + current_pi_value1, + current_pi_value2, + next_pi_value1, + next_pi_value2, + random_value1, + random_value2, + ]: + value.reshape(batch_size, self.num_repeat_actions, 1) + + if self.calibrated: + returns = ( + batch.calibration_returns.unsqueeze(1) + .repeat( + (1, self.num_repeat_actions), + ) + .view(-1, 1) + ) + random_value1 = torch.max(random_value1, 
returns) + random_value2 = torch.max(random_value2, returns) + + current_pi_value1 = torch.max(current_pi_value1, returns) + current_pi_value2 = torch.max(current_pi_value2, returns) + + next_pi_value1 = torch.max(next_pi_value1, returns) + next_pi_value2 = torch.max(next_pi_value2, returns) + + # cat q values + cat_q1 = torch.cat([random_value1, current_pi_value1, next_pi_value1], 1) + cat_q2 = torch.cat([random_value2, current_pi_value2, next_pi_value2], 1) + # shape: (batch_size, 3 * num_repeat, 1) + + cql1_scaled_loss = ( + torch.logsumexp(cat_q1 / self.temperature, dim=1).mean() + * self.cql_weight + * self.temperature + - current_Q1.mean() * self.cql_weight + ) + cql2_scaled_loss = ( + torch.logsumexp(cat_q2 / self.temperature, dim=1).mean() + * self.cql_weight + * self.temperature + - current_Q2.mean() * self.cql_weight + ) + # shape: (1) + + cql_alpha_loss = None + cql_alpha = None + if self.with_lagrange: + cql_alpha = torch.clamp( + self.cql_log_alpha.exp(), + self.alpha_min, + self.alpha_max, + ) + cql1_scaled_loss = cql_alpha * (cql1_scaled_loss - self.lagrange_threshold) + cql2_scaled_loss = cql_alpha * (cql2_scaled_loss - self.lagrange_threshold) + + self.cql_alpha_optim.zero_grad() + cql_alpha_loss = -(cql1_scaled_loss + cql2_scaled_loss) * 0.5 + cql_alpha_loss.backward(retain_graph=True) + self.cql_alpha_optim.step() + + critic1_loss = critic1_loss + cql1_scaled_loss + critic2_loss = critic2_loss + cql2_scaled_loss + + # update critic + self.critic_optim.zero_grad() + critic1_loss.backward(retain_graph=True) + # clip grad, prevent the vanishing gradient problem + # It doesn't seem necessary + clip_grad_norm_(self.critic.parameters(), self.clip_grad) + self.critic_optim.step() + + self.critic2_optim.zero_grad() + critic2_loss.backward() + clip_grad_norm_(self.critic2.parameters(), self.clip_grad) + self.critic2_optim.step() + + self.sync_weight() + + return CQLTrainingStats( # type: ignore[return-value] + actor_loss=to_optional_float(actor_loss), + critic1_loss=to_optional_float(critic1_loss), + critic2_loss=to_optional_float(critic2_loss), + alpha=to_optional_float(self.alpha), + alpha_loss=to_optional_float(alpha_loss), + cql_alpha_loss=to_optional_float(cql_alpha_loss), + cql_alpha=to_optional_float(cql_alpha), + ) diff --git a/examples/atari/tianshou/policy/imitation/discrete_bcq.py b/examples/atari/tianshou/policy/imitation/discrete_bcq.py new file mode 100644 index 0000000000000000000000000000000000000000..b5258c141c0556406798e7c179106beeac50d813 --- /dev/null +++ b/examples/atari/tianshou/policy/imitation/discrete_bcq.py @@ -0,0 +1,179 @@ +import math +from dataclasses import dataclass +from typing import Any, Self, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.data.types import ( + ImitationBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.policy import DQNPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats + +float_info = torch.finfo(torch.float32) +INF = float_info.max + + +@dataclass(kw_only=True) +class DiscreteBCQTrainingStats(DQNTrainingStats): + q_loss: float + i_loss: float + reg_loss: float + + +TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteBCQTrainingStats) + + +class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]): + """Implementation of discrete BCQ algorithm. arXiv:1910.01708. 
+ + :param model: a model following the rules (s_B -> action_values_BA) + :param imitator: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) + :param optim: a torch.optim for optimizing the model. + :param discount_factor: in [0, 1]. + :param estimation_step: the number of steps to look ahead + :param target_update_freq: the target network update frequency. + :param eval_eps: the epsilon-greedy noise added in evaluation. + :param unlikely_action_threshold: the threshold (tau) for unlikely + actions, as shown in Equ. (17) in the paper. + :param imitation_logits_penalty: regularization weight for imitation + logits. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + model: torch.nn.Module, + imitator: torch.nn.Module, + optim: torch.optim.Optimizer, + action_space: gym.spaces.Discrete, + discount_factor: float = 0.99, + estimation_step: int = 1, + target_update_freq: int = 8000, + eval_eps: float = 1e-3, + unlikely_action_threshold: float = 0.3, + imitation_logits_penalty: float = 1e-2, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + model=model, + optim=optim, + action_space=action_space, + discount_factor=discount_factor, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + is_double=is_double, + clip_loss_grad=clip_loss_grad, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + assert ( + target_update_freq > 0 + ), f"BCQ needs target_update_freq>0 but got: {target_update_freq}." 
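        # Editorial sketch (not part of the patch above): how the `unlikely_action_threshold`
        # tau set here is applied later in forward(). Actions whose imitation probability
        # falls below tau times the most likely action's probability are masked out before
        # the Q-value argmax. The tensors below are assumed toy values for illustration only.
        def _bcq_action_mask_demo() -> None:
            import torch

            q_value = torch.tensor([[1.0, 5.0, 3.0]])
            imitation_logits = torch.tensor([[2.0, -3.0, 1.5]])
            log_tau = torch.log(torch.tensor(0.3))
            ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
            mask = (ratio < log_tau).float()  # action 1 is deemed unlikely under the imitator
            act = (q_value - 1e8 * mask).argmax(dim=-1)  # -> tensor([2]); the unmasked argmax would be 1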
+ self.imitator = imitator + assert ( + 0.0 <= unlikely_action_threshold < 1.0 + ), f"unlikely_action_threshold should be in [0, 1) but got: {unlikely_action_threshold}" + if unlikely_action_threshold > 0: + self._log_tau = math.log(unlikely_action_threshold) + else: + self._log_tau = -np.inf + assert 0.0 <= eval_eps < 1.0 + self.eps = eval_eps + self._weight_reg = imitation_logits_penalty + + def train(self, mode: bool = True) -> Self: + self.training = mode + self.model.train(mode) + self.imitator.train(mode) + return self + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + batch = buffer[indices] # batch.obs_next: s_{t+n} + next_obs_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + act = self(next_obs_batch).act + target_q, _ = self.model_old(batch.obs_next) + return target_q[np.arange(len(act)), act] + + def forward( # type: ignore + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + **kwargs: Any, + ) -> ImitationBatchProtocol: + # TODO: Liskov substitution principle is violated here, the superclass + # produces a batch with the field logits, but this one doesn't. + # Should be fixed in the future! + q_value, state = self.model(batch.obs, state=state, info=batch.info) + if self.max_action_num is None: + self.max_action_num = q_value.shape[1] + imitation_logits, _ = self.imitator(batch.obs, state=state, info=batch.info) + + # mask actions for argmax + ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values + mask = (ratio < self._log_tau).float() + act = (q_value - INF * mask).argmax(dim=-1) + + result = Batch(act=act, state=state, q_value=q_value, imitation_logits=imitation_logits) + return cast(ImitationBatchProtocol, result) + + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TDiscreteBCQTrainingStats: + if self._iter % self.freq == 0: + self.sync_weight() + self._iter += 1 + + target_q = batch.returns.flatten() + result = self(batch) + imitation_logits = result.imitation_logits + current_q = result.q_value[np.arange(len(target_q)), batch.act] + act = to_torch(batch.act, dtype=torch.long, device=target_q.device) + q_loss = F.smooth_l1_loss(current_q, target_q) + i_loss = F.nll_loss(F.log_softmax(imitation_logits, dim=-1), act) + reg_loss = imitation_logits.pow(2).mean() + loss = q_loss + i_loss + self._weight_reg * reg_loss + + self.optim.zero_grad() + loss.backward() + self.optim.step() + + return DiscreteBCQTrainingStats( # type: ignore[return-value] + loss=loss.item(), + q_loss=q_loss.item(), + i_loss=i_loss.item(), + reg_loss=reg_loss.item(), + ) diff --git a/examples/atari/tianshou/policy/imitation/discrete_cql.py b/examples/atari/tianshou/policy/imitation/discrete_cql.py new file mode 100644 index 0000000000000000000000000000000000000000..b63f83e1159143ff0471d3981e4d9d042caa2c52 --- /dev/null +++ b/examples/atari/tianshou/policy/imitation/discrete_cql.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass +from typing import Any, TypeVar + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import to_torch +from tianshou.data.types import RolloutBatchProtocol +from tianshou.policy import QRDQNPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats + + +@dataclass(kw_only=True) +class DiscreteCQLTrainingStats(QRDQNTrainingStats): + cql_loss: float + qr_loss: 
float + + +TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteCQLTrainingStats) + + +class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]): + """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. + + :param model: a model following the rules (s_B -> action_values_BA) + :param optim: a torch.optim for optimizing the model. + :param action_space: Env's action space. + :param min_q_weight: the weight for the cql loss. + :param discount_factor: in [0, 1]. + :param num_quantiles: the number of quantile midpoints in the inverse + cumulative distribution function of the value. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + action_space: gym.spaces.Discrete, + min_q_weight: float = 10.0, + discount_factor: float = 0.99, + num_quantiles: int = 200, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + model=model, + optim=optim, + action_space=action_space, + discount_factor=discount_factor, + num_quantiles=num_quantiles, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + is_double=is_double, + clip_loss_grad=clip_loss_grad, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + self.min_q_weight = min_q_weight + + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TDiscreteCQLTrainingStats: + if self._target and self._iter % self.freq == 0: + self.sync_weight() + self.optim.zero_grad() + weight = batch.pop("weight", 1.0) + all_dist = self(batch).logits + act = to_torch(batch.act, dtype=torch.long, device=all_dist.device) + curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = ( + (dist_diff * (self.tau_hat - (target_dist - curr_dist).detach().le(0.0).float()).abs()) + .sum(-1) + .mean(1) + ) + qr_loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer + # add CQL loss + q = self.compute_q_value(all_dist, None) + dataset_expec = q.gather(1, act.unsqueeze(1)).mean() + negative_sampling = q.logsumexp(1).mean() + min_q_loss = negative_sampling - dataset_expec + loss = qr_loss + min_q_loss * self.min_q_weight + loss.backward() + self.optim.step() + self._iter += 1 + 
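        # Editorial sketch of the conservative penalty computed above (assumed toy numbers,
        # standalone): CQL pushes down a soft maximum of Q over all actions while pushing
        # up the Q-value of the action actually stored in the offline dataset.
        def _discrete_cql_penalty_demo() -> None:
            import torch

            q = torch.tensor([[1.0, 2.0, 0.5]])  # Q(s, .) for a single state
            data_act = torch.tensor([0])  # action taken in the offline dataset
            dataset_expec = q.gather(1, data_act.unsqueeze(1)).mean()
            negative_sampling = q.logsumexp(dim=1).mean()
            min_q_loss = negative_sampling - dataset_expec  # ~ 2.46 - 1.00
            # the total loss above is qr_loss + min_q_loss * min_q_weight
            assert min_q_loss > 0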
+ return DiscreteCQLTrainingStats( # type: ignore[return-value] + loss=loss.item(), + qr_loss=qr_loss.item(), + cql_loss=min_q_loss.item(), + ) diff --git a/examples/atari/tianshou/policy/imitation/discrete_crr.py b/examples/atari/tianshou/policy/imitation/discrete_crr.py new file mode 100644 index 0000000000000000000000000000000000000000..9c54129da53aaef6cb434108cc54f9a1acc4bdc4 --- /dev/null +++ b/examples/atari/tianshou/policy/imitation/discrete_crr.py @@ -0,0 +1,153 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Literal, TypeVar + +import gymnasium as gym +import torch +import torch.nn.functional as F +from torch.distributions import Categorical + +from tianshou.data import to_torch, to_torch_as +from tianshou.data.types import RolloutBatchProtocol +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats +from tianshou.utils.net.discrete import Actor, Critic + + +@dataclass +class DiscreteCRRTrainingStats(PGTrainingStats): + actor_loss: float + critic_loss: float + cql_loss: float + + +TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteCRRTrainingStats) + + +class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]): + r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. + + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). + :param critic: the action-value critic (i.e., Q function) + network. (s -> Q(s, \*)) + :param optim: a torch.optim for optimizing the model. + :param discount_factor: in [0, 1]. + :param str policy_improvement_mode: type of the weight function f. Possible + values: "binary"/"exp"/"all". + :param ratio_upper_bound: when policy_improvement_mode is "exp", the value + of the exp function is upper-bounded by this parameter. + :param beta: when policy_improvement_mode is "exp", this is the denominator + of the exp function. + :param min_q_weight: weight for CQL loss/regularizer. Default to 10. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: if True, will normalize the *returns* + by subtracting the running mean and dividing by the running standard deviation. + Can be detrimental to performance! See TODO in process_fn. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed + explanation. 
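    As a rough sketch of the weight function f applied in ``learn`` below (toy
    advantages assumed; see the implementation for the exact tensors involved)::

        import torch

        advantage = torch.tensor([-0.5, 0.1, 2.0])  # Q(s, a) - E_pi[Q(s, .)]
        beta, ratio_upper_bound = 1.0, 20.0
        f_binary = (advantage > 0).float()                            # [0.00, 1.00, 1.00]
        f_exp = (advantage / beta).exp().clamp(0, ratio_upper_bound)  # [0.61, 1.11, 7.39]
        # the actor loss is (-log_prob(a) * f).mean(); mode "all" uses f = 1,
        # which reduces to plain behavior cloning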
+ """ + + def __init__( + self, + *, + actor: torch.nn.Module | Actor, + critic: torch.nn.Module | Critic, + optim: torch.optim.Optimizer, + action_space: gym.spaces.Discrete, + discount_factor: float = 0.99, + policy_improvement_mode: Literal["exp", "binary", "all"] = "exp", + ratio_upper_bound: float = 20.0, + beta: float = 1.0, + min_q_weight: float = 10.0, + target_update_freq: int = 0, + reward_normalization: bool = False, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + actor=actor, + optim=optim, + action_space=action_space, + dist_fn=lambda x: Categorical(logits=x), + discount_factor=discount_factor, + reward_normalization=reward_normalization, + observation_space=observation_space, + action_scaling=False, + action_bound_method=None, + lr_scheduler=lr_scheduler, + ) + self.critic = critic + self._target = target_update_freq > 0 + self._freq = target_update_freq + self._iter = 0 + if self._target: + self.actor_old = deepcopy(self.actor) + self.actor_old.eval() + self.critic_old = deepcopy(self.critic) + self.critic_old.eval() + else: + self.actor_old = self.actor + self.critic_old = self.critic + self._policy_improvement_mode = policy_improvement_mode + self._ratio_upper_bound = ratio_upper_bound + self._beta = beta + self._min_q_weight = min_q_weight + + def sync_weight(self) -> None: + self.actor_old.load_state_dict(self.actor.state_dict()) + self.critic_old.load_state_dict(self.critic.state_dict()) + + def learn( # type: ignore + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TDiscreteCRRTrainingStats: + if self._target and self._iter % self._freq == 0: + self.sync_weight() + self.optim.zero_grad() + q_t = self.critic(batch.obs) + act = to_torch(batch.act, dtype=torch.long, device=q_t.device) + qa_t = q_t.gather(1, act.unsqueeze(1)) + # Critic loss + with torch.no_grad(): + target_a_t, _ = self.actor_old(batch.obs_next) + target_m = Categorical(logits=target_a_t) + q_t_target = self.critic_old(batch.obs_next) + rew = to_torch_as(batch.rew, q_t_target) + expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True) + expected_target_q[batch.done > 0] = 0.0 + target = rew.unsqueeze(1) + self.gamma * expected_target_q + critic_loss = 0.5 * F.mse_loss(qa_t, target) + # Actor loss + act_target, _ = self.actor(batch.obs) + dist = Categorical(logits=act_target) + expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True) + advantage = qa_t - expected_policy_q + if self._policy_improvement_mode == "binary": + actor_loss_coef = (advantage > 0).float() + elif self._policy_improvement_mode == "exp": + actor_loss_coef = (advantage / self._beta).exp().clamp(0, self._ratio_upper_bound) + else: + actor_loss_coef = 1.0 # effectively behavior cloning + actor_loss = (-dist.log_prob(act) * actor_loss_coef).mean() + # CQL loss/regularizer + min_q_loss = (q_t.logsumexp(1) - qa_t).mean() + loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss + loss.backward() + self.optim.step() + self._iter += 1 + + return DiscreteCRRTrainingStats( # type: ignore[return-value] + loss=loss.item(), + actor_loss=actor_loss.item(), + critic_loss=critic_loss.item(), + cql_loss=min_q_loss.item(), + ) diff --git a/examples/atari/tianshou/policy/imitation/gail.py b/examples/atari/tianshou/policy/imitation/gail.py new file mode 100644 index 0000000000000000000000000000000000000000..524f04001a528ed780ec6d81ccac89de9cc46895 --- /dev/null +++ b/examples/atari/tianshou/policy/imitation/gail.py 
@@ -0,0 +1,198 @@ +from dataclasses import dataclass +from typing import Any, Literal, TypeVar + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import ( + ReplayBuffer, + SequenceSummaryStats, + to_numpy, + to_torch, +) +from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol +from tianshou.policy import PPOPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.policy.modelfree.ppo import PPOTrainingStats +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic + + +@dataclass(kw_only=True) +class GailTrainingStats(PPOTrainingStats): + disc_loss: SequenceSummaryStats + acc_pi: SequenceSummaryStats + acc_exp: SequenceSummaryStats + + +TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats) + + +class GAILPolicy(PPOPolicy[TGailTrainingStats]): + r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476. + + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer for actor and critic network. + :param dist_fn: distribution class for computing the action. + :param action_space: env's action space + :param expert_buffer: the replay buffer containing expert experience. + :param disc_net: the discriminator network with input dim equals + state dim plus action dim and output dim equals 1. + :param disc_optim: the optimizer for the discriminator network. + :param disc_update_num: the number of discriminator grad steps per model grad step. + :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original + paper. + :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, + where c > 1 is a constant indicating the lower bound. Set to None + to disable dual-clip PPO. + :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param recompute_advantage: whether to recompute advantage every update + repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. + :param vf_coef: weight for value loss. + :param ent_coef: weight for entropy loss. + :param max_grad_norm: clipping gradients in back propagation. + :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param max_batchsize: the maximum size of the batch when computing GAE. + :param discount_factor: in [0, 1]. + :param reward_normalization: normalize estimated values to have std close to 1. + :param deterministic_eval: if True, use deterministic evaluation. + :param observation_space: the space of the observation. + :param action_scaling: if True, scale the action from [-1, 1] to the range of + action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.PPOPolicy` for more detailed + explanation. 
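    A minimal sketch of one discriminator step performed in ``learn`` (toy logits
    assumed; in the real update ``disc_net`` is applied to concatenated
    observation-action pairs from the policy batch and from ``expert_buffer``)::

        import torch
        import torch.nn.functional as F

        logits_pi = torch.tensor([0.3, -0.2])   # discriminator logits on policy data
        logits_exp = torch.tensor([0.8, 1.5])   # discriminator logits on expert data
        loss_pi = -F.logsigmoid(-logits_pi).mean()   # = softplus(logits_pi)
        loss_exp = -F.logsigmoid(logits_exp).mean()  # = softplus(-logits_exp)
        loss_disc = loss_pi + loss_exp  # drives expert logits up and policy logits down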
+ """ + + def __init__( + self, + *, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, + optim: torch.optim.Optimizer, + dist_fn: TDistFnDiscrOrCont, + action_space: gym.Space, + expert_buffer: ReplayBuffer, + disc_net: torch.nn.Module, + disc_optim: torch.optim.Optimizer, + disc_update_num: int = 4, + eps_clip: float = 0.2, + dual_clip: float | None = None, + value_clip: bool = False, + advantage_normalization: bool = True, + recompute_advantage: bool = False, + vf_coef: float = 0.5, + ent_coef: float = 0.01, + max_grad_norm: float | None = None, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + discount_factor: float = 0.99, + # TODO: rename to return_normalization? + reward_normalization: bool = False, + deterministic_eval: bool = False, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + actor=actor, + critic=critic, + optim=optim, + dist_fn=dist_fn, + action_space=action_space, + eps_clip=eps_clip, + dual_clip=dual_clip, + value_clip=value_clip, + advantage_normalization=advantage_normalization, + recompute_advantage=recompute_advantage, + vf_coef=vf_coef, + ent_coef=ent_coef, + max_grad_norm=max_grad_norm, + gae_lambda=gae_lambda, + max_batchsize=max_batchsize, + discount_factor=discount_factor, + reward_normalization=reward_normalization, + deterministic_eval=deterministic_eval, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + lr_scheduler=lr_scheduler, + ) + self.disc_net = disc_net + self.disc_optim = disc_optim + self.disc_update_num = disc_update_num + self.expert_buffer = expert_buffer + self.action_dim = actor.output_dim + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> LogpOldProtocol: + """Pre-process the data from the provided replay buffer. + + Used in :meth:`update`. Check out :ref:`process_fn` for more information. 
+ """ + # update reward + with torch.no_grad(): + batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten()) + return super().process_fn(batch, buffer, indices) + + def disc(self, batch: RolloutBatchProtocol) -> torch.Tensor: + obs = to_torch(batch.obs, device=self.disc_net.device) + act = to_torch(batch.act, device=self.disc_net.device) + return self.disc_net(torch.cat([obs, act], dim=1)) + + def learn( # type: ignore + self, + batch: RolloutBatchProtocol, + batch_size: int | None, + repeat: int, + **kwargs: Any, + ) -> TGailTrainingStats: + # update discriminator + losses = [] + acc_pis = [] + acc_exps = [] + bsz = len(batch) // self.disc_update_num + for b in batch.split(bsz, merge_last=True): + logits_pi = self.disc(b) + exp_b = self.expert_buffer.sample(bsz)[0] + logits_exp = self.disc(exp_b) + loss_pi = -F.logsigmoid(-logits_pi).mean() + loss_exp = -F.logsigmoid(logits_exp).mean() + loss_disc = loss_pi + loss_exp + self.disc_optim.zero_grad() + loss_disc.backward() + self.disc_optim.step() + losses.append(loss_disc.item()) + acc_pis.append((logits_pi < 0).float().mean().item()) + acc_exps.append((logits_exp > 0).float().mean().item()) + # update policy + ppo_loss_stat = super().learn(batch, batch_size, repeat, **kwargs) + + disc_losses_summary = SequenceSummaryStats.from_sequence(losses) + acc_pi_summary = SequenceSummaryStats.from_sequence(acc_pis) + acc_exps_summary = SequenceSummaryStats.from_sequence(acc_exps) + + return GailTrainingStats( # type: ignore[return-value] + **ppo_loss_stat.__dict__, + disc_loss=disc_losses_summary, + acc_pi=acc_pi_summary, + acc_exp=acc_exps_summary, + ) diff --git a/examples/atari/tianshou/policy/imitation/td3_bc.py b/examples/atari/tianshou/policy/imitation/td3_bc.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b2bfe91028f208d00e3c95fe14f00ad140f544 --- /dev/null +++ b/examples/atari/tianshou/policy/imitation/td3_bc.py @@ -0,0 +1,130 @@ +from dataclasses import dataclass +from typing import Any, Literal, TypeVar + +import gymnasium as gym +import torch +import torch.nn.functional as F + +from tianshou.data import to_torch_as +from tianshou.data.types import RolloutBatchProtocol +from tianshou.exploration import BaseNoise, GaussianNoise +from tianshou.policy import TD3Policy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.td3 import TD3TrainingStats + + +@dataclass(kw_only=True) +class TD3BCTrainingStats(TD3TrainingStats): + pass + + +TTD3BCTrainingStats = TypeVar("TTD3BCTrainingStats", bound=TD3BCTrainingStats) + + +class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]): + """Implementation of TD3+BC. arXiv:2106.06860. + + :param actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> actions) + :param actor_optim: the optimizer for actor network. + :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer for the first critic network. + :param action_space: Env's action space. Should be gym.spaces.Box. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer for the second critic network. + If None, clone critic_optim to use for critic2.parameters(). + :param tau: param for soft update of the target network. + :param gamma: discount factor, in [0, 1]. + :param exploration_noise: add noise to action for exploration. + This is useful when solving "hard exploration" problems. 
+ "default" is equivalent to GaussianNoise(sigma=0.1). + :param policy_noise: the noise used in updating policy network. + :param update_actor_freq: the update frequency of actor network. + :param noise_clip: the clipping range used in updating policy network. + :param alpha: the value of alpha, which controls the weight for TD3 learning + relative to behavior cloning. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate + in optimizer in each policy.update() + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + actor: torch.nn.Module, + actor_optim: torch.optim.Optimizer, + critic: torch.nn.Module, + critic_optim: torch.optim.Optimizer, + action_space: gym.Space, + critic2: torch.nn.Module | None = None, + critic2_optim: torch.optim.Optimizer | None = None, + tau: float = 0.005, + gamma: float = 0.99, + exploration_noise: BaseNoise | None = GaussianNoise(sigma=0.1), + policy_noise: float = 0.2, + update_actor_freq: int = 2, + noise_clip: float = 0.5, + # TODO: same name as alpha in SAC and REDQ, which also inherit from DDPGPolicy. Rename? + alpha: float = 2.5, + estimation_step: int = 1, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + actor=actor, + actor_optim=actor_optim, + critic=critic, + critic_optim=critic_optim, + action_space=action_space, + critic2=critic2, + critic2_optim=critic2_optim, + tau=tau, + gamma=gamma, + exploration_noise=exploration_noise, + policy_noise=policy_noise, + noise_clip=noise_clip, + update_actor_freq=update_actor_freq, + estimation_step=estimation_step, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + self.alpha = alpha + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3BCTrainingStats: # type: ignore + # critic 1&2 + td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) + td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) + batch.weight = (td1 + td2) / 2.0 # prio-buffer + + # actor + if self._cnt % self.update_actor_freq == 0: + act = self(batch, eps=0.0).act + q_value = self.critic(batch.obs, act) + lmbda = self.alpha / q_value.abs().mean().detach() + actor_loss = -lmbda * q_value.mean() + F.mse_loss(act, to_torch_as(batch.act, act)) + self.actor_optim.zero_grad() + actor_loss.backward() + self._last = actor_loss.item() + self.actor_optim.step() + self.sync_weight() + self._cnt += 1 + + return TD3BCTrainingStats( # type: ignore[return-value] + actor_loss=self._last, + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + ) diff --git a/examples/atari/tianshou/policy/modelbased/__init__.py b/examples/atari/tianshou/policy/modelbased/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/policy/modelbased/icm.py 
b/examples/atari/tianshou/policy/modelbased/icm.py new file mode 100644 index 0000000000000000000000000000000000000000..9a603b7deb8dcbe866668413b3c544a8027a88c4 --- /dev/null +++ b/examples/atari/tianshou/policy/modelbased/icm.py @@ -0,0 +1,176 @@ +from typing import Any, Literal, Self, TypeVar + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.policy import BasePolicy +from tianshou.policy.base import ( + TLearningRateScheduler, + TrainingStats, + TrainingStatsWrapper, + TTrainingStats, +) +from tianshou.utils.net.discrete import IntrinsicCuriosityModule + + +class ICMTrainingStats(TrainingStatsWrapper): + def __init__( + self, + wrapped_stats: TrainingStats, + *, + icm_loss: float, + icm_forward_loss: float, + icm_inverse_loss: float, + ) -> None: + self.icm_loss = icm_loss + self.icm_forward_loss = icm_forward_loss + self.icm_inverse_loss = icm_inverse_loss + super().__init__(wrapped_stats) + + +class ICMPolicy(BasePolicy[ICMTrainingStats]): + """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. + + :param policy: a base policy to add ICM to. + :param model: the ICM model. + :param optim: a torch.optim for optimizing the model. + :param lr_scale: the scaling factor for ICM learning. + :param forward_loss_weight: the weight for forward model loss. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + policy: BasePolicy[TTrainingStats], + model: IntrinsicCuriosityModule, + optim: torch.optim.Optimizer, + lr_scale: float, + reward_scale: float, + forward_loss_weight: float, + action_space: gym.Space, + observation_space: gym.Space | None = None, + action_scaling: bool = False, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + lr_scheduler=lr_scheduler, + ) + self.policy = policy + self.model = model + self.optim = optim + self.lr_scale = lr_scale + self.reward_scale = reward_scale + self.forward_loss_weight = forward_loss_weight + + def train(self, mode: bool = True) -> Self: + """Set the module in training mode.""" + self.policy.train(mode) + self.training = mode + self.model.train(mode) + return self + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> ActBatchProtocol: + """Compute action over the given batch data by inner policy. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. 
+ """ + return self.policy.forward(batch, state, **kwargs) + + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + + def exploration_noise( + self, + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: + return self.policy.exploration_noise(act, batch) + + def set_eps(self, eps: float) -> None: + """Set the eps for epsilon-greedy exploration.""" + if hasattr(self.policy, "set_eps"): + self.policy.set_eps(eps) + else: + raise NotImplementedError + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + """Pre-process the data from the provided replay buffer. + + Used in :meth:`update`. Check out :ref:`process_fn` for more information. + """ + mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next) + batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss) + batch.rew += to_numpy(mse_loss * self.reward_scale) + return self.policy.process_fn(batch, buffer, indices) + + def post_process_fn( + self, + batch: BatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> None: + """Post-process the data from the provided replay buffer. + + Typical usage is to update the sampling weight in prioritized + experience replay. Used in :meth:`update`. + """ + self.policy.post_process_fn(batch, buffer, indices) + batch.rew = batch.policy.orig_rew # restore original reward + + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> ICMTrainingStats: + training_stat = self.policy.learn(batch, **kwargs) + self.optim.zero_grad() + act_hat = batch.policy.act_hat + act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) + inverse_loss = F.cross_entropy(act_hat, act).mean() + forward_loss = batch.policy.mse_loss.mean() + loss = ( + (1 - self.forward_loss_weight) * inverse_loss + self.forward_loss_weight * forward_loss + ) * self.lr_scale + loss.backward() + self.optim.step() + + return ICMTrainingStats( + training_stat, + icm_loss=loss.item(), + icm_forward_loss=forward_loss.item(), + icm_inverse_loss=inverse_loss.item(), + ) diff --git a/examples/atari/tianshou/policy/modelbased/psrl.py b/examples/atari/tianshou/policy/modelbased/psrl.py new file mode 100644 index 0000000000000000000000000000000000000000..8c137470908b4b9d4ad2ed65c251a55113d80a5f --- /dev/null +++ b/examples/atari/tianshou/policy/modelbased/psrl.py @@ -0,0 +1,253 @@ +from dataclasses import dataclass +from typing import Any, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch + +from tianshou.data import Batch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.policy import BasePolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats + + +@dataclass(kw_only=True) +class PSRLTrainingStats(TrainingStats): + psrl_rew_mean: float = 0.0 + psrl_rew_std: float = 0.0 + + +TPSRLTrainingStats = TypeVar("TPSRLTrainingStats", bound=PSRLTrainingStats) + + +class PSRLModel: + """Implementation of Posterior Sampling Reinforcement Learning Model. + + :param trans_count_prior: dirichlet prior (alphas), with shape + (n_state, n_action, n_state). + :param rew_mean_prior: means of the normal priors of rewards, + with shape (n_state, n_action). + :param rew_std_prior: standard deviations of the normal priors + of rewards, with shape (n_state, n_action). + :param discount_factor: in [0, 1]. 
+ :param epsilon: for precision control in value iteration. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). + """ + + def __init__( + self, + trans_count_prior: np.ndarray, + rew_mean_prior: np.ndarray, + rew_std_prior: np.ndarray, + discount_factor: float, + epsilon: float, + ) -> None: + self.trans_count = trans_count_prior + self.n_state, self.n_action = rew_mean_prior.shape + self.rew_mean = rew_mean_prior + self.rew_std = rew_std_prior + self.rew_square_sum = np.zeros_like(rew_mean_prior) + self.rew_std_prior = rew_std_prior + self.discount_factor = discount_factor + self.rew_count = np.full(rew_mean_prior.shape, epsilon) # no weight + self.eps = epsilon + self.policy: np.ndarray + self.value = np.zeros(self.n_state) + self.updated = False + self.__eps = np.finfo(np.float32).eps.item() + + def observe( + self, + trans_count: np.ndarray, + rew_sum: np.ndarray, + rew_square_sum: np.ndarray, + rew_count: np.ndarray, + ) -> None: + """Add data into memory pool. + + For rewards, we have a normal prior at first. After we observed a + reward for a given state-action pair, we use the mean value of our + observations instead of the prior mean as the posterior mean. The + standard deviations are in inverse proportion to the number of the + corresponding observations. + + :param trans_count: the number of observations, with shape + (n_state, n_action, n_state). + :param rew_sum: total rewards, with shape + (n_state, n_action). + :param rew_square_sum: total rewards' squares, with shape + (n_state, n_action). + :param rew_count: the number of rewards, with shape + (n_state, n_action). + """ + self.updated = False + self.trans_count += trans_count + sum_count = self.rew_count + rew_count + self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count + self.rew_square_sum += rew_square_sum + raw_std2 = self.rew_square_sum / sum_count - self.rew_mean**2 + self.rew_std = np.sqrt( + 1 / (sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior**2), + ) + self.rew_count = sum_count + + def sample_trans_prob(self) -> np.ndarray: + return torch.distributions.Dirichlet(torch.from_numpy(self.trans_count)).sample().numpy() + + def sample_reward(self) -> np.ndarray: + return np.random.normal(self.rew_mean, self.rew_std) + + def solve_policy(self) -> None: + self.updated = True + self.policy, self.value = self.value_iteration( + self.sample_trans_prob(), + self.sample_reward(), + self.discount_factor, + self.eps, + self.value, + ) + + @staticmethod + def value_iteration( + trans_prob: np.ndarray, + rew: np.ndarray, + discount_factor: float, + eps: float, + value: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + """Value iteration solver for MDPs. + + :param trans_prob: transition probabilities, with shape + (n_state, n_action, n_state). + :param rew: rewards, with shape (n_state, n_action). + :param eps: for precision control. + :param discount_factor: in [0, 1]. + :param value: the initialize value of value array, with + shape (n_state, ). + + :return: the optimal policy with shape (n_state, ). 
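            The second element of the returned tuple is the value of that policy,
            also with shape (n_state, ).

        A minimal usage sketch on an assumed two-state, two-action MDP (in each state,
        one action is rewarding and also tends to keep the agent in that state)::

            import numpy as np

            trans_prob = np.array(
                [[[0.9, 0.1], [0.1, 0.9]],
                 [[0.8, 0.2], [0.2, 0.8]]],
            )  # shape (n_state, n_action, n_state)
            rew = np.array([[1.0, 0.0], [0.0, 1.0]])
            policy, value = PSRLModel.value_iteration(
                trans_prob, rew, discount_factor=0.9, eps=0.01, value=np.zeros(2),
            )
            # policy is array([0, 1]) here: each state sticks with its rewarding action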
+ """ + Q = rew + discount_factor * trans_prob.dot(value) + new_value = Q.max(axis=1) + while not np.allclose(new_value, value, eps): + value = new_value + Q = rew + discount_factor * trans_prob.dot(value) + new_value = Q.max(axis=1) + # this is to make sure if Q(s, a1) == Q(s, a2) -> choose a1/a2 randomly + Q += eps * np.random.randn(*Q.shape) + return Q.argmax(axis=1), new_value + + def __call__( + self, + obs: np.ndarray, + state: Any = None, + info: Any = None, + ) -> np.ndarray: + if not self.updated: + self.solve_policy() + return self.policy[obs] + + +class PSRLPolicy(BasePolicy[TPSRLTrainingStats]): + """Implementation of Posterior Sampling Reinforcement Learning. + + Reference: Strens M. A Bayesian framework for reinforcement learning [C] + //ICML. 2000, 2000: 943-950. + + :param trans_count_prior: dirichlet prior (alphas), with shape + (n_state, n_action, n_state). + :param rew_mean_prior: means of the normal priors of rewards, + with shape (n_state, n_action). + :param rew_std_prior: standard deviations of the normal priors + of rewards, with shape (n_state, n_action). + :param action_space: Env's action_space. + :param discount_factor: in [0, 1]. + :param epsilon: for precision control in value iteration. + :param add_done_loop: whether to add an extra self-loop for the + terminal state in MDP. Default to False. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + trans_count_prior: np.ndarray, + rew_mean_prior: np.ndarray, + rew_std_prior: np.ndarray, + action_space: gym.spaces.Discrete, + discount_factor: float = 0.99, + epsilon: float = 0.01, + add_done_loop: bool = False, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=False, + action_bound_method=None, + lr_scheduler=lr_scheduler, + ) + assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" + self.model = PSRLModel( + trans_count_prior, + rew_mean_prior, + rew_std_prior, + discount_factor, + epsilon, + ) + self._add_done_loop = add_done_loop + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> ActBatchProtocol: + """Compute action over the given batch data with PSRL model. + + :return: A :class:`~tianshou.data.Batch` with "act" key containing + the action. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. + """ + assert isinstance(batch.obs, np.ndarray), "only support np.ndarray observation" + # TODO: shouldn't the model output a state as well if state is passed (i.e. RNNs are involved)? 
+ act = self.model(batch.obs, state=state, info=batch.info) + return cast(ActBatchProtocol, Batch(act=act)) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TPSRLTrainingStats: + n_s, n_a = self.model.n_state, self.model.n_action + trans_count = np.zeros((n_s, n_a, n_s)) + rew_sum = np.zeros((n_s, n_a)) + rew_square_sum = np.zeros((n_s, n_a)) + rew_count = np.zeros((n_s, n_a)) + for minibatch in batch.split(size=1): + obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next + obs_next = cast(np.ndarray, obs_next) + assert not isinstance(obs, BatchProtocol), "Observations cannot be Batches here" + trans_count[obs, act, obs_next] += 1 + rew_sum[obs, act] += minibatch.rew + rew_square_sum[obs, act] += minibatch.rew**2 + rew_count[obs, act] += 1 + if self._add_done_loop and minibatch.done: + # special operation for terminal states: add a self-loop + trans_count[obs_next, :, obs_next] += 1 + rew_count[obs_next, :] += 1 + self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count) + + return PSRLTrainingStats( # type: ignore[return-value] + psrl_rew_mean=float(self.model.rew_mean.mean()), + psrl_rew_std=float(self.model.rew_std.mean()), + ) diff --git a/examples/atari/tianshou/policy/modelfree/__init__.py b/examples/atari/tianshou/policy/modelfree/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/policy/modelfree/a2c.py b/examples/atari/tianshou/policy/modelfree/a2c.py new file mode 100644 index 0000000000000000000000000000000000000000..d41ccb463731f3359277097d310ff9c54f29dc1e --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/a2c.py @@ -0,0 +1,206 @@ +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as +from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol +from tianshou.policy import PGPolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic + + +@dataclass(kw_only=True) +class A2CTrainingStats(TrainingStats): + loss: SequenceSummaryStats + actor_loss: SequenceSummaryStats + vf_loss: SequenceSummaryStats + ent_loss: SequenceSummaryStats + + +TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var] + """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. + + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer for actor and critic network. + :param dist_fn: distribution class for computing the action. + :param action_space: env's action space + :param vf_coef: weight for value loss. 
+ :param ent_coef: weight for entropy loss. + :param max_grad_norm: clipping gradients in back propagation. + :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param max_batchsize: the maximum size of the batch when computing GAE. + :param discount_factor: in [0, 1]. + :param reward_normalization: normalize estimated values to have std close to 1. + :param deterministic_eval: if True, use deterministic evaluation. + :param observation_space: the space of the observation. + :param action_scaling: if True, scale the action from [-1, 1] to the range of + action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, + optim: torch.optim.Optimizer, + dist_fn: TDistFnDiscrOrCont, + action_space: gym.Space, + vf_coef: float = 0.5, + ent_coef: float = 0.01, + max_grad_norm: float | None = None, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + discount_factor: float = 0.99, + # TODO: rename to return_normalization? + reward_normalization: bool = False, + deterministic_eval: bool = False, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + actor=actor, + optim=optim, + dist_fn=dist_fn, + action_space=action_space, + discount_factor=discount_factor, + reward_normalization=reward_normalization, + deterministic_eval=deterministic_eval, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + lr_scheduler=lr_scheduler, + ) + self.critic = critic + assert 0.0 <= gae_lambda <= 1.0, f"GAE lambda should be in [0, 1] but got: {gae_lambda}" + self.gae_lambda = gae_lambda + self.vf_coef = vf_coef + self.ent_coef = ent_coef + self.max_grad_norm = max_grad_norm + self.max_batchsize = max_batchsize + self._actor_critic = ActorCritic(self.actor, self.critic) + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithAdvantagesProtocol: + batch = self._compute_returns(batch, buffer, indices) + batch.act = to_torch_as(batch.act, batch.v_s) + return batch + + def _compute_returns( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithAdvantagesProtocol: + v_s, v_s_ = [], [] + with torch.no_grad(): + for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): + v_s.append(self.critic(minibatch.obs)) + v_s_.append(self.critic(minibatch.obs_next)) + batch.v_s = torch.cat(v_s, dim=0).flatten() # old value + v_s = batch.v_s.cpu().numpy() + v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy() + # when normalizing values, we do not minus self.ret_rms.mean to be numerically + # consistent with OPENAI baselines' value normalization pipeline. Empirical + # study also shows that "minus mean" will harm performances a tiny little bit + # due to unknown reasons (on Mujoco envs, not confident, though). 
+ # TODO: see todo in PGPolicy.process_fn + if self.rew_norm: # unnormalize v_s & v_s_ + v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + unnormalized_returns, advantages = self.compute_episodic_return( + batch, + buffer, + indices, + v_s_, + v_s, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + ) + if self.rew_norm: + batch.returns = unnormalized_returns / np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.update(unnormalized_returns) + else: + batch.returns = unnormalized_returns + batch.returns = to_torch_as(batch.returns, batch.v_s) + batch.adv = to_torch_as(advantages, batch.v_s) + return cast(BatchWithAdvantagesProtocol, batch) + + # TODO: mypy complains b/c signature is different from superclass, although + # it's compatible. Can this be fixed? + def learn( # type: ignore + self, + batch: RolloutBatchProtocol, + batch_size: int | None, + repeat: int, + *args: Any, + **kwargs: Any, + ) -> TA2CTrainingStats: + losses, actor_losses, vf_losses, ent_losses = [], [], [], [] + split_batch_size = batch_size or -1 + for _ in range(repeat): + for minibatch in batch.split(split_batch_size, merge_last=True): + # calculate loss for actor + dist = self(minibatch).dist + log_prob = dist.log_prob(minibatch.act) + log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1) + actor_loss = -(log_prob * minibatch.adv).mean() + # calculate loss for critic + value = self.critic(minibatch.obs).flatten() + vf_loss = F.mse_loss(minibatch.returns, value) + # calculate regularization and overall loss + ent_loss = dist.entropy().mean() + loss = actor_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss + self.optim.zero_grad() + loss.backward() + if self.max_grad_norm: # clip large gradient + nn.utils.clip_grad_norm_( + self._actor_critic.parameters(), + max_norm=self.max_grad_norm, + ) + self.optim.step() + actor_losses.append(actor_loss.item()) + vf_losses.append(vf_loss.item()) + ent_losses.append(ent_loss.item()) + losses.append(loss.item()) + + loss_summary_stat = SequenceSummaryStats.from_sequence(losses) + actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) + vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) + ent_loss_summary_stat = SequenceSummaryStats.from_sequence(ent_losses) + + return A2CTrainingStats( # type: ignore[return-value] + loss=loss_summary_stat, + actor_loss=actor_loss_summary_stat, + vf_loss=vf_loss_summary_stat, + ent_loss=ent_loss_summary_stat, + ) diff --git a/examples/atari/tianshou/policy/modelfree/bdq.py b/examples/atari/tianshou/policy/modelfree/bdq.py new file mode 100644 index 0000000000000000000000000000000000000000..d7196a92b6bdc7c6891da2bb453b442b0ab14835 --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/bdq.py @@ -0,0 +1,204 @@ +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch + +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ActBatchProtocol, + BatchWithReturnsProtocol, + ModelOutputBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.policy import DQNPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats +from tianshou.utils.net.common import BranchingNet + + +@dataclass(kw_only=True) +class BDQNTrainingStats(DQNTrainingStats): + pass + + 
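# Editorial sketch of the objective assembled in A2CPolicy.learn above (toy tensors
# assumed; the real code iterates over minibatches sampled from the buffer): the loss
# combines the advantage-weighted policy-gradient term, the value regression term
# weighted by vf_coef, and an entropy bonus weighted by ent_coef.
def _a2c_loss_demo() -> None:
    import torch
    import torch.nn.functional as F
    from torch.distributions import Categorical

    logits = torch.tensor([[0.2, 1.0], [0.5, -0.3]], requires_grad=True)  # actor output
    value = torch.tensor([0.4, 0.9], requires_grad=True)  # critic output
    act = torch.tensor([1, 0])
    adv = torch.tensor([0.7, -0.2])
    returns = torch.tensor([1.0, 0.5])
    vf_coef, ent_coef = 0.5, 0.01

    dist = Categorical(logits=logits)
    actor_loss = -(dist.log_prob(act) * adv).mean()
    vf_loss = F.mse_loss(returns, value)
    ent_loss = dist.entropy().mean()
    loss = actor_loss + vf_coef * vf_loss - ent_coef * ent_loss
    loss.backward()  # gradients reach both the actor logits and the critic values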
+TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats) + + +class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]): + """Implementation of the Branching dual Q network arXiv:1711.08946. + + :param model: BranchingNet mapping (obs, state, info) -> action_values_BA. + :param optim: a torch.optim for optimizing the model. + :param discount_factor: in [0, 1]. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + model: BranchingNet, + optim: torch.optim.Optimizer, + action_space: gym.spaces.Discrete, + discount_factor: float = 0.99, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + assert ( + estimation_step == 1 + ), f"N-step bigger than one is not supported by BDQ but got: {estimation_step}" + super().__init__( + model=model, + optim=optim, + action_space=action_space, + discount_factor=discount_factor, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + is_double=is_double, + clip_loss_grad=clip_loss_grad, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + self.model = cast(BranchingNet, self.model) + + # TODO: this used to be a public property called max_action_num, + # but it collides with an attr of the same name in base class + @property + def _action_per_branch(self) -> int: + return self.model.action_per_branch + + @property + def num_branches(self) -> int: + return self.model.num_branches + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + result = self(obs_next_batch) + if self._target: + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + target_q = self(obs_next_batch, model="model_old").logits + else: + target_q = result.logits + if self.is_double: + act = np.expand_dims(self(obs_next_batch).act, -1) + act = to_torch(act, dtype=torch.long, device=target_q.device) + else: + act = target_q.max(-1).indices.unsqueeze(-1) + return torch.gather(target_q, -1, act).squeeze() + + def _compute_return( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indice: np.ndarray, + gamma: float = 0.99, + ) -> BatchWithReturnsProtocol: + rew = batch.rew + with torch.no_grad(): + target_q_torch = self._target_q(buffer, indice) # (bsz, ?) 
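+        # Shape bookkeeping (inferred from the np.repeat calls below): average the per-branch
+        # target values to a scalar per sample, form rew + gamma * mean_Q * (1 - end_flag), then
+        # broadcast the result to (bsz, num_branches, action_per_branch) so every branch and
+        # sub-action regresses towards the same 1-step return.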
+ target_q = to_numpy(target_q_torch) + end_flag = buffer.done.copy() + end_flag[buffer.unfinished_index()] = True + end_flag = end_flag[indice] + mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q + _target_q = rew + gamma * mean_target_q * (1 - end_flag) + target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1) + target_q = np.repeat(target_q[..., None], self._action_per_branch, axis=-1) + + batch.returns = to_torch_as(target_q, target_q_torch) + if hasattr(batch, "weight"): # prio buffer update + batch.weight = to_torch_as(batch.weight, target_q_torch) + return cast(BatchWithReturnsProtocol, batch) + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithReturnsProtocol: + """Compute the 1-step return for BDQ targets.""" + return self._compute_return(batch, buffer, indices) + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + model: Literal["model", "model_old"] = "model", + **kwargs: Any, + ) -> ModelOutputBatchProtocol: + model = getattr(self, model) + obs = batch.obs + # TODO: this is very contrived, see also iqn.py + obs_next_BO = obs.obs if hasattr(obs, "obs") else obs + action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info) + act_B = to_numpy(action_values_BA.argmax(dim=-1)) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) + return cast(ModelOutputBatchProtocol, result) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats: + if self._target and self._iter % self.freq == 0: + self.sync_weight() + self.optim.zero_grad() + weight = batch.pop("weight", 1.0) + act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device) + q = self(batch).logits + act_mask = torch.zeros_like(q) + act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1) + act_q = q * act_mask + returns = batch.returns + returns = returns * act_mask + td_error = returns - act_q + loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean() + batch.weight = td_error.sum(-1).sum(-1) # prio-buffer + loss.backward() + self.optim.step() + self._iter += 1 + + return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] + + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + + def exploration_noise( + self, + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: + if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): + bsz = len(act) + rand_mask = np.random.rand(bsz) < self.eps + rand_act = np.random.randint( + low=0, + high=self._action_per_branch, + size=(bsz, act.shape[-1]), + ) + if hasattr(batch.obs, "mask"): + rand_act += batch.obs.mask + act[rand_mask] = rand_act[rand_mask] + return act diff --git a/examples/atari/tianshou/policy/modelfree/c51.py b/examples/atari/tianshou/policy/modelfree/c51.py new file mode 100644 index 0000000000000000000000000000000000000000..5bfdba0c1d97e7ee95900c442eb9041748dd320f --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/c51.py @@ -0,0 +1,137 @@ +from dataclasses import dataclass +from typing import Any, Generic, TypeVar + +import gymnasium as gym +import numpy as np +import torch + +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.types import RolloutBatchProtocol +from tianshou.policy import DQNPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats + + 
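+# Quick orientation with the default arguments used below: v_min=-10, v_max=10 and num_atoms=51
+# give support = torch.linspace(-10, 10, 51) and delta_z = 20 / 50 = 0.4; the scalar Q-value is
+# the expectation over this support, which compute_q_value implements as
+# (logits * self.support).sum(2). C51TrainingStats just reuses DQN's single `loss` field.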
+@dataclass(kw_only=True) +class C51TrainingStats(DQNTrainingStats): + pass + + +TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats) + + +class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]): + """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. + + :param model: a model following the rules (s_B -> action_values_BA) + :param optim: a torch.optim for optimizing the model. + :param discount_factor: in [0, 1]. + :param num_atoms: the number of atoms in the support set of the + value distribution. Default to 51. + :param v_min: the value of the smallest atom in the support set. + Default to -10.0. + :param v_max: the value of the largest atom in the support set. + Default to 10.0. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + action_space: gym.spaces.Discrete, + discount_factor: float = 0.99, + num_atoms: int = 51, + v_min: float = -10.0, + v_max: float = 10.0, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + assert num_atoms > 1, f"num_atoms should be greater than 1 but got: {num_atoms}" + assert v_min < v_max, f"v_max should be larger than v_min, but got {v_min=} and {v_max=}" + super().__init__( + model=model, + optim=optim, + action_space=action_space, + discount_factor=discount_factor, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + is_double=is_double, + clip_loss_grad=clip_loss_grad, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + self._num_atoms = num_atoms + self._v_min = v_min + self._v_max = v_max + self.support = torch.nn.Parameter( + torch.linspace(self._v_min, self._v_max, self._num_atoms), + requires_grad=False, + ) + self.delta_z = (v_max - v_min) / (num_atoms - 1) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + return self.support.repeat(len(indices), 1) # shape: [bsz, num_atoms] + + def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: + return super().compute_q_value((logits * self.support).sum(2), mask) + + def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: + obs_next_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) + if self._target: + act = self(obs_next_batch).act + next_dist = self(obs_next_batch, model="model_old").logits + else: + next_batch = self(obs_next_batch) + act = next_batch.act + next_dist = next_batch.logits + next_dist = next_dist[np.arange(len(act)), act, :] + target_support = batch.returns.clamp(self._v_min, self._v_max) + # An amazing trick 
for calculating the projection gracefully. + # ref: https://github.com/ShangtongZhang/DeepRL + target_dist = ( + 1 - (target_support.unsqueeze(1) - self.support.view(1, -1, 1)).abs() / self.delta_z + ).clamp(0, 1) * next_dist.unsqueeze(1) + return target_dist.sum(-1) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats: + if self._target and self._iter % self.freq == 0: + self.sync_weight() + self.optim.zero_grad() + with torch.no_grad(): + target_dist = self._target_dist(batch) + weight = batch.pop("weight", 1.0) + curr_dist = self(batch).logits + act = batch.act + curr_dist = curr_dist[np.arange(len(act)), act, :] + cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1) + loss = (cross_entropy * weight).mean() + # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100 + batch.weight = cross_entropy.detach() # prio-buffer + loss.backward() + self.optim.step() + self._iter += 1 + + return C51TrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/examples/atari/tianshou/policy/modelfree/ddpg.py b/examples/atari/tianshou/policy/modelfree/ddpg.py new file mode 100644 index 0000000000000000000000000000000000000000..f21744f72ff691f5535b7bbf5cbf8f3fea1c353b --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/ddpg.py @@ -0,0 +1,224 @@ +import warnings +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch + +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ActBatchProtocol, + ActStateBatchProtocol, + BatchWithReturnsProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.exploration import BaseNoise, GaussianNoise +from tianshou.policy import BasePolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils.net.continuous import Actor, Critic + + +@dataclass(kw_only=True) +class DDPGTrainingStats(TrainingStats): + actor_loss: float + critic_loss: float + + +TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats) + + +class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): + """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. + + :param actor: The actor network following the rules (s -> actions) + :param actor_optim: The optimizer for actor network. + :param critic: The critic network. (s, a -> Q(s, a)) + :param critic_optim: The optimizer for critic network. + :param action_space: Env's action space. + :param tau: Param for soft update of the target network. + :param gamma: Discount factor, in [0, 1]. + :param exploration_noise: The exploration noise, added to the action. Defaults + to ``GaussianNoise(sigma=0.1)``. + :param estimation_step: The number of steps to look ahead. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. 
+ """ + + def __init__( + self, + *, + actor: torch.nn.Module | Actor, + actor_optim: torch.optim.Optimizer, + critic: torch.nn.Module | Critic, + critic_optim: torch.optim.Optimizer, + action_space: gym.Space, + tau: float = 0.005, + gamma: float = 0.99, + exploration_noise: BaseNoise | Literal["default"] | None = "default", + estimation_step: int = 1, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + # tanh not supported, see assert below + action_bound_method: Literal["clip"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + assert 0.0 <= tau <= 1.0, f"tau should be in [0, 1] but got: {tau}" + assert 0.0 <= gamma <= 1.0, f"gamma should be in [0, 1] but got: {gamma}" + assert action_bound_method != "tanh", ( # type: ignore[comparison-overlap] + "tanh mapping is not supported" + "in policies where action is used as input of critic , because" + "raw action in range (-inf, inf) will cause instability in training" + ) + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + lr_scheduler=lr_scheduler, + ) + if action_scaling and not np.isclose(actor.max_action, 1.0): + warnings.warn( + "action_scaling and action_bound_method are only intended to deal" + "with unbounded model action space, but find actor model bound" + f"action space with max_action={actor.max_action}." + "Consider using unbounded=True option of the actor model," + "or set action_scaling to False and action_bound_method to None.", + ) + self.actor = actor + self.actor_old = deepcopy(actor) + self.actor_old.eval() + self.actor_optim = actor_optim + self.critic = critic + self.critic_old = deepcopy(critic) + self.critic_old.eval() + self.critic_optim = critic_optim + self.tau = tau + self.gamma = gamma + if exploration_noise == "default": + exploration_noise = GaussianNoise(sigma=0.1) + # TODO: IMPORTANT - can't call this "exploration_noise" because confusingly, + # there is already a method called exploration_noise() in the base class + # Now this method doesn't apply any noise and is also not overridden. 
See TODO there + self._exploration_noise = exploration_noise + # it is only a little difference to use GaussianNoise + # self.noise = OUNoise() + self.estimation_step = estimation_step + + def set_exp_noise(self, noise: BaseNoise | None) -> None: + """Set the exploration noise.""" + self._exploration_noise = noise + + def train(self, mode: bool = True) -> Self: + """Set the module in training mode, except for the target network.""" + self.training = mode + self.actor.train(mode) + self.critic.train(mode) + return self + + def sync_weight(self) -> None: + """Soft-update the weight for the target network.""" + self.soft_update(self.actor_old, self.actor, self.tau) + self.soft_update(self.critic_old, self.critic, self.tau) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + return self.critic_old(obs_next_batch.obs, self(obs_next_batch, model="actor_old").act) + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol | BatchWithReturnsProtocol: + return self.compute_nstep_return( + batch=batch, + buffer=buffer, + indices=indices, + target_q_fn=self._target_q, + gamma=self.gamma, + n_step=self.estimation_step, + ) + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + model: Literal["actor", "actor_old"] = "actor", + **kwargs: Any, + ) -> ActStateBatchProtocol: + """Compute action over the given batch data. + + :return: A :class:`~tianshou.data.Batch` which has 2 keys: + + * ``act`` the action. + * ``state`` the hidden state. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. 
+ """ + model = getattr(self, model) + actions, hidden = model(batch.obs, state=state, info=batch.info) + return cast(ActStateBatchProtocol, Batch(act=actions, state=hidden)) + + @staticmethod + def _mse_optimizer( + batch: RolloutBatchProtocol, + critic: torch.nn.Module, + optimizer: torch.optim.Optimizer, + ) -> tuple[torch.Tensor, torch.Tensor]: + """A simple wrapper script for updating critic network.""" + weight = getattr(batch, "weight", 1.0) + current_q = critic(batch.obs, batch.act).flatten() + target_q = batch.returns.flatten() + td = current_q - target_q + # critic_loss = F.mse_loss(current_q1, target_q) + critic_loss = (td.pow(2) * weight).mean() + optimizer.zero_grad() + critic_loss.backward() + optimizer.step() + return td, critic_loss + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDDPGTrainingStats: # type: ignore + # critic + td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) + batch.weight = td # prio-buffer + # actor + actor_loss = -self.critic(batch.obs, self(batch).act).mean() + self.actor_optim.zero_grad() + actor_loss.backward() + self.actor_optim.step() + self.sync_weight() + + return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value] + + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + + def exploration_noise( + self, + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: + if self._exploration_noise is None: + return act + if isinstance(act, np.ndarray): + return act + self._exploration_noise(act.shape) + warnings.warn("Cannot add exploration noise to non-numpy_array action.") + return act diff --git a/examples/atari/tianshou/policy/modelfree/discrete_sac.py b/examples/atari/tianshou/policy/modelfree/discrete_sac.py new file mode 100644 index 0000000000000000000000000000000000000000..d1ce28da9700425d636b8260e234c753a882afc0 --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/discrete_sac.py @@ -0,0 +1,194 @@ +from dataclasses import dataclass +from typing import Any, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +from overrides import override +from torch.distributions import Categorical + +from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.policy import SACPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.sac import SACTrainingStats +from tianshou.utils.net.discrete import Actor, Critic + + +@dataclass +class DiscreteSACTrainingStats(SACTrainingStats): + pass + + +TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteSACTrainingStats) + + +class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): + """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207. + + :param actor: the actor network following the rules (s_B -> dist_input_BD) + :param actor_optim: the optimizer for actor network. + :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer for the first critic network. + :param action_space: Env's action space. Should be gym.spaces.Box. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer for the second critic network. + If None, clone critic_optim to use for critic2.parameters(). 
+ :param tau: param for soft update of the target network. + :param gamma: discount factor, in [0, 1]. + :param alpha: entropy regularization coefficient. + If a tuple (target_entropy, log_alpha, alpha_optim) is provided, + then alpha is automatically tuned. + :param estimation_step: the number of steps to look ahead for calculating + :param observation_space: Env's observation space. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate + in optimizer in each policy.update() + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + actor: torch.nn.Module | Actor, + actor_optim: torch.optim.Optimizer, + critic: torch.nn.Module | Critic, + critic_optim: torch.optim.Optimizer, + action_space: gym.spaces.Discrete, + critic2: torch.nn.Module | Critic | None = None, + critic2_optim: torch.optim.Optimizer | None = None, + tau: float = 0.005, + gamma: float = 0.99, + alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, + estimation_step: int = 1, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + actor=actor, + actor_optim=actor_optim, + critic=critic, + critic_optim=critic_optim, + action_space=action_space, + critic2=critic2, + critic2_optim=critic2_optim, + tau=tau, + gamma=gamma, + alpha=alpha, + estimation_step=estimation_step, + # Note: inheriting from continuous sac reduces code duplication, + # but continuous stuff has to be disabled + exploration_noise=None, + action_scaling=False, + action_bound_method=None, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + + # TODO: violates Liskov substitution principle, incompatible action space with SAC + # Not too urgent, but still.. + @override + def _check_field_validity(self) -> None: + if not isinstance(self.action_space, gym.spaces.Discrete): + raise ValueError( + f"DiscreteSACPolicy only supports gym.spaces.Discrete, but got {self.action_space=}." 
+ f"Please use SACPolicy for continuous action spaces.", + ) + + def forward( # type: ignore + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + **kwargs: Any, + ) -> Batch: + logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Categorical(logits=logits_BA) + act_B = ( + dist.mode + if self.deterministic_eval and not self.is_within_training_step + else dist.sample() + ) + return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + obs_next_result = self(obs_next_batch) + dist = obs_next_result.dist + target_q = dist.probs * torch.min( + self.critic_old(obs_next_batch.obs), + self.critic2_old(obs_next_batch.obs), + ) + return target_q.sum(dim=-1) + self.alpha * dist.entropy() + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats: # type: ignore + weight = batch.pop("weight", 1.0) + target_q = batch.returns.flatten() + act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) + + # critic 1 + current_q1 = self.critic(batch.obs).gather(1, act).flatten() + td1 = current_q1 - target_q + critic1_loss = (td1.pow(2) * weight).mean() + + self.critic_optim.zero_grad() + critic1_loss.backward() + self.critic_optim.step() + + # critic 2 + current_q2 = self.critic2(batch.obs).gather(1, act).flatten() + td2 = current_q2 - target_q + critic2_loss = (td2.pow(2) * weight).mean() + + self.critic2_optim.zero_grad() + critic2_loss.backward() + self.critic2_optim.step() + batch.weight = (td1 + td2) / 2.0 # prio-buffer + + # actor + dist = self(batch).dist + entropy = dist.entropy() + with torch.no_grad(): + current_q1a = self.critic(batch.obs) + current_q2a = self.critic2(batch.obs) + q = torch.min(current_q1a, current_q2a) + actor_loss = -(self.alpha * entropy + (dist.probs * q).sum(dim=-1)).mean() + self.actor_optim.zero_grad() + actor_loss.backward() + self.actor_optim.step() + + if self.is_auto_alpha: + log_prob = -entropy.detach() + self.target_entropy + alpha_loss = -(self.log_alpha * log_prob).mean() + self.alpha_optim.zero_grad() + alpha_loss.backward() + self.alpha_optim.step() + self.alpha = self.log_alpha.detach().exp() + + self.sync_weight() + + if self.is_auto_alpha: + self.alpha = cast(torch.Tensor, self.alpha) + + return DiscreteSACTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + alpha=self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha, + alpha_loss=None if not self.is_auto_alpha else alpha_loss.item(), + ) + + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + + def exploration_noise( + self, + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: + return act diff --git a/examples/atari/tianshou/policy/modelfree/dqn.py b/examples/atari/tianshou/policy/modelfree/dqn.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ada073334fff0fcc8433fcc7e8d7aee7c806eb --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/dqn.py @@ -0,0 +1,254 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch + +from tianshou.data import 
Batch, ReplayBuffer, to_numpy, to_torch_as +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ActBatchProtocol, + BatchWithReturnsProtocol, + ModelOutputBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.policy import BasePolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils.net.common import Net + + +@dataclass(kw_only=True) +class DQNTrainingStats(TrainingStats): + loss: float + + +TDQNTrainingStats = TypeVar("TDQNTrainingStats", bound=DQNTrainingStats) + + +class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): + """Implementation of Deep Q Network. arXiv:1312.5602. + + Implementation of Double Q-Learning. arXiv:1509.06461. + + Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is + implemented in the network side, not here). + + :param model: a model following the rules (s -> action_values_BA) + :param optim: a torch.optim for optimizing the model. + :param discount_factor: in [0, 1]. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + model: torch.nn.Module | Net, + optim: torch.optim.Optimizer, + # TODO: type violates Liskov substitution principle + action_space: gym.spaces.Discrete, + discount_factor: float = 0.99, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=False, + action_bound_method=None, + lr_scheduler=lr_scheduler, + ) + self.model = model + self.optim = optim + self.eps = 0.0 + assert ( + 0.0 <= discount_factor <= 1.0 + ), f"discount factor should be in [0, 1] but got: {discount_factor}" + self.gamma = discount_factor + assert ( + estimation_step > 0 + ), f"estimation_step should be greater than 0 but got: {estimation_step}" + self.n_step = estimation_step + self._target = target_update_freq > 0 + self.freq = target_update_freq + self._iter = 0 + if self._target: + self.model_old = deepcopy(self.model) + self.model_old.eval() + self.rew_norm = reward_normalization + self.is_double = is_double + self.clip_loss_grad = clip_loss_grad + + # TODO: set in forward, fix this! 
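+        # Lazily inferred from the Q-value tensor on the first forward() call;
+        # exploration_noise() asserts that it has been set before sampling random actions.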
+ self.max_action_num: int | None = None + + def set_eps(self, eps: float) -> None: + """Set the eps for epsilon-greedy exploration.""" + self.eps = eps + + def train(self, mode: bool = True) -> Self: + """Set the module in training mode, except for the target network.""" + self.training = mode + self.model.train(mode) + return self + + def sync_weight(self) -> None: + """Synchronize the weight for the target network.""" + self.model_old.load_state_dict(self.model.state_dict()) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + result = self(obs_next_batch) + if self._target: + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + target_q = self(obs_next_batch, model="model_old").logits + else: + target_q = result.logits + if self.is_double: + return target_q[np.arange(len(result.act)), result.act] + # Nature DQN, over estimate + return target_q.max(dim=1)[0] + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithReturnsProtocol: + """Compute the n-step return for Q-learning targets. + + More details can be found at + :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`. + """ + return self.compute_nstep_return( + batch=batch, + buffer=buffer, + indices=indices, + target_q_fn=self._target_q, + gamma=self.gamma, + n_step=self.n_step, + rew_norm=self.rew_norm, + ) + + def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: + """Compute the q value based on the network's raw output and action mask.""" + if mask is not None: + # the masked q value should be smaller than logits.min() + min_value = logits.min() - logits.max() - 1.0 + logits = logits + to_torch_as(1 - mask, logits) * min_value + return logits + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + model: Literal["model", "model_old"] = "model", + **kwargs: Any, + ) -> ModelOutputBatchProtocol: + """Compute action over the given batch data. + + If you need to mask the action, please add a "mask" into batch.obs, for + example, if we have an environment that has "0/1/2" three actions: + :: + + batch == Batch( + obs=Batch( + obs="original obs, with batch_size=1 for demonstration", + mask=np.array([[False, True, False]]), + # action 1 is available + # action 0 and 2 are unavailable + ), + ... + ) + + :return: A :class:`~tianshou.data.Batch` which has 3 keys: + + * ``act`` the action. + * ``logits`` the network's raw output. + * ``state`` the hidden state. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. + """ + model = getattr(self, model) + obs = batch.obs + # TODO: this is convoluted! See also other places where this is done. 
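+        # Some environments nest the raw observation together with an action mask
+        # (batch.obs.obs / batch.obs.mask, see the docstring example above), so unwrap it
+        # before the model call.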
+ obs_next = obs.obs if hasattr(obs, "obs") else obs + action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info) + q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None)) + if self.max_action_num is None: + self.max_action_num = q.shape[1] + act_B = to_numpy(q.argmax(dim=1)) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) + return cast(ModelOutputBatchProtocol, result) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats: + if self._target and self._iter % self.freq == 0: + self.sync_weight() + self.optim.zero_grad() + weight = batch.pop("weight", 1.0) + q = self(batch).logits + q = q[np.arange(len(q)), batch.act] + returns = to_torch_as(batch.returns.flatten(), q) + td_error = returns - q + + if self.clip_loss_grad: + y = q.reshape(-1, 1) + t = returns.reshape(-1, 1) + loss = torch.nn.functional.huber_loss(y, t, reduction="mean") + else: + loss = (td_error.pow(2) * weight).mean() + + batch.weight = td_error # prio-buffer + loss.backward() + self.optim.step() + self._iter += 1 + + return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value] + + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + + def exploration_noise( + self, + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: + if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): + bsz = len(act) + rand_mask = np.random.rand(bsz) < self.eps + assert ( + self.max_action_num is not None + ), "Can't call this method before max_action_num was set in first forward" + q = np.random.rand(bsz, self.max_action_num) # [0, 1] + if hasattr(batch.obs, "mask"): + q += batch.obs.mask + rand_act = q.argmax(axis=1) + act[rand_mask] = rand_act[rand_mask] + return act diff --git a/examples/atari/tianshou/policy/modelfree/fqf.py b/examples/atari/tianshou/policy/modelfree/fqf.py new file mode 100644 index 0000000000000000000000000000000000000000..cb68e97024d75ee14ad35fe16acdf85520fcd826 --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/fqf.py @@ -0,0 +1,221 @@ +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import Batch, ReplayBuffer, to_numpy +from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.policy import DQNPolicy, QRDQNPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats +from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction + + + + +@dataclass(kw_only=True) +class FQFTrainingStats(QRDQNTrainingStats): + quantile_loss: float + fraction_loss: float + entropy_loss: float + + +TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats) + + +class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]): + """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140. + + :param model: a model following the rules (s_B -> action_values_BA) + :param optim: a torch.optim for optimizing the model. + :param fraction_model: a FractionProposalNetwork for + proposing fractions/quantiles given state. + :param fraction_optim: a torch.optim for optimizing + the fraction model above. + :param action_space: Env's action space. + :param discount_factor: in [0, 1]. + :param num_fractions: the number of fractions to use. 
+ :param ent_coef: the coefficient for entropy loss. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + model: FullQuantileFunction, + optim: torch.optim.Optimizer, + fraction_model: FractionProposalNetwork, + fraction_optim: torch.optim.Optimizer, + action_space: gym.spaces.Discrete, + discount_factor: float = 0.99, + # TODO: used as num_quantiles in QRDQNPolicy, but num_fractions in FQFPolicy. + # Rename? Or at least explain what happens here. + num_fractions: int = 32, + ent_coef: float = 0.0, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + model=model, + optim=optim, + action_space=action_space, + discount_factor=discount_factor, + num_quantiles=num_fractions, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + is_double=is_double, + clip_loss_grad=clip_loss_grad, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + self.fraction_model = fraction_model + self.ent_coef = ent_coef + self.fraction_optim = fraction_optim + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + if self._target: + result = self(obs_next_batch) + act, fractions = result.act, result.fractions + next_dist = self(obs_next_batch, model="model_old", fractions=fractions).logits + else: + next_batch = self(obs_next_batch) + act = next_batch.act + next_dist = next_batch.logits + return next_dist[np.arange(len(act)), act, :] + + # TODO: fix Liskov substitution principle violation + def forward( # type: ignore + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + model: Literal["model", "model_old"] = "model", + fractions: Batch | None = None, + **kwargs: Any, + ) -> FQFBatchProtocol: + model = getattr(self, model) + obs = batch.obs + # TODO: this is convoluted! 
See also other places where this is done + obs_next = obs.obs if hasattr(obs, "obs") else obs + if fractions is None: + (logits, fractions, quantiles_tau), hidden = model( + obs_next, + propose_model=self.fraction_model, + state=state, + info=batch.info, + ) + else: + (logits, _, quantiles_tau), hidden = model( + obs_next, + propose_model=self.fraction_model, + fractions=fractions, + state=state, + info=batch.info, + ) + weighted_logits = (fractions.taus[:, 1:] - fractions.taus[:, :-1]).unsqueeze(1) * logits + q = DQNPolicy.compute_q_value(self, weighted_logits.sum(2), getattr(obs, "mask", None)) + if self.max_action_num is None: # type: ignore + # TODO: see same thing in DQNPolicy! Also reduce code duplication. + self.max_action_num = q.shape[1] + act = to_numpy(q.max(dim=1)[1]) + result = Batch( + logits=logits, + act=act, + state=hidden, + fractions=fractions, + quantiles_tau=quantiles_tau, + ) + return cast(FQFBatchProtocol, result) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats: + if self._target and self._iter % self.freq == 0: + self.sync_weight() + weight = batch.pop("weight", 1.0) + out = self(batch) + curr_dist_orig = out.logits + taus, tau_hats = out.fractions.taus, out.fractions.tau_hats + act = batch.act + curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = ( + ( + dist_diff + * (tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.0).float()).abs() + ) + .sum(-1) + .mean(1) + ) + quantile_loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer + # calculate fraction loss + with torch.no_grad(): + sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :] + sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :] + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169 + values_1 = sa_quantiles - sa_quantile_hats[:, :-1] + signs_1 = sa_quantiles > torch.cat( + [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], + dim=1, + ) + + values_2 = sa_quantiles - sa_quantile_hats[:, 1:] + signs_2 = sa_quantiles < torch.cat( + [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], + dim=1, + ) + + gradient_of_taus = torch.where(signs_1, values_1, -values_1) + torch.where( + signs_2, + values_2, + -values_2, + ) + fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean() + # calculate entropy loss + entropy_loss = out.fractions.entropies.mean() + fraction_entropy_loss = fraction_loss - self.ent_coef * entropy_loss + self.fraction_optim.zero_grad() + fraction_entropy_loss.backward(retain_graph=True) + self.fraction_optim.step() + self.optim.zero_grad() + quantile_loss.backward() + self.optim.step() + self._iter += 1 + + return FQFTrainingStats( # type: ignore[return-value] + loss=quantile_loss.item() + fraction_entropy_loss.item(), + quantile_loss=quantile_loss.item(), + fraction_loss=fraction_loss.item(), + entropy_loss=entropy_loss.item(), + ) diff --git a/examples/atari/tianshou/policy/modelfree/fqf_rainbow.py b/examples/atari/tianshou/policy/modelfree/fqf_rainbow.py new file mode 100644 index 0000000000000000000000000000000000000000..bf391e535f0ff2434359b5620489e062bfc0e48a --- 
/dev/null +++ b/examples/atari/tianshou/policy/modelfree/fqf_rainbow.py @@ -0,0 +1,248 @@ +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from tianshou.data import Batch, ReplayBuffer, to_numpy +from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.policy import DQNPolicy, QRDQNPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats +from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction +from tianshou.utils.net.discrete import NoisyLinear + + + +# TODO: this is a hacky thing interviewing side-effects and a return. Should improve. +def _sample_noise(model: nn.Module) -> bool: + """Sample the random noises of NoisyLinear modules in the model. + + Returns True if at least one NoisyLinear submodule was found. + + :param model: a PyTorch module which may have NoisyLinear submodules. + :returns: True if model has at least one NoisyLinear submodule; + otherwise, False. + """ + sampled_any_noise = False + for m in model.modules(): + if isinstance(m, NoisyLinear): + m.sample() + sampled_any_noise = True + return sampled_any_noise + + +@dataclass(kw_only=True) +class FQFTrainingStats(QRDQNTrainingStats): + quantile_loss: float + fraction_loss: float + entropy_loss: float + + +TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats) + + +class FQF_RainbowPolicy(QRDQNPolicy[TFQFTrainingStats]): + """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140. + + :param model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param optim: a torch.optim for optimizing the model. + :param fraction_model: a FractionProposalNetwork for + proposing fractions/quantiles given state. + :param fraction_optim: a torch.optim for optimizing + the fraction model above. + :param action_space: Env's action space. + :param discount_factor: in [0, 1]. + :param num_fractions: the number of fractions to use. + :param ent_coef: the coefficient for entropy loss. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + model: FullQuantileFunction, + optim: torch.optim.Optimizer, + fraction_model: FractionProposalNetwork, + fraction_optim: torch.optim.Optimizer, + action_space: gym.spaces.Discrete, + discount_factor: float = 0.99, + # TODO: used as num_quantiles in QRDQNPolicy, but num_fractions in FQFPolicy. + # Rename? Or at least explain what happens here. 
+ num_fractions: int = 32, + ent_coef: float = 0.0, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + is_noisy: bool = False + ) -> None: + super().__init__( + model=model, + optim=optim, + action_space=action_space, + discount_factor=discount_factor, + num_quantiles=num_fractions, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + is_double=is_double, + clip_loss_grad=clip_loss_grad, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + self.fraction_model = fraction_model + self.ent_coef = ent_coef + self.fraction_optim = fraction_optim + self.is_noisy = is_noisy + print("Noisy is", self.is_noisy) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + if self._target: + result = self(obs_next_batch) + act, fractions = result.act, result.fractions + next_dist = self(obs_next_batch, model="model_old", fractions=fractions).logits + else: + next_batch = self(obs_next_batch) + act = next_batch.act + next_dist = next_batch.logits + return next_dist[np.arange(len(act)), act, :] + + # TODO: fix Liskov substitution principle violation + def forward( # type: ignore + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + model: Literal["model", "model_old"] = "model", + fractions: Batch | None = None, + **kwargs: Any, + ) -> FQFBatchProtocol: + model = getattr(self, model) + obs = batch.obs + # TODO: this is convoluted! See also other places where this is done + obs_next = obs.obs if hasattr(obs, "obs") else obs + if fractions is None: + (logits, fractions, quantiles_tau), hidden = model( + obs_next, + propose_model=self.fraction_model, + state=state, + info=batch.info, + ) + else: + (logits, _, quantiles_tau), hidden = model( + obs_next, + propose_model=self.fraction_model, + fractions=fractions, + state=state, + info=batch.info, + ) + weighted_logits = (fractions.taus[:, 1:] - fractions.taus[:, :-1]).unsqueeze(1) * logits + q = DQNPolicy.compute_q_value(self, weighted_logits.sum(2), getattr(obs, "mask", None)) + if self.max_action_num is None: # type: ignore + # TODO: see same thing in DQNPolicy! Also reduce code duplication. 
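+            # q has shape (bsz, num_actions); cache num_actions for the epsilon-greedy
+            # exploration_noise() inherited from DQNPolicy.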
+ self.max_action_num = q.shape[1] + act = to_numpy(q.max(dim=1)[1]) + result = Batch( + logits=logits, + act=act, + state=hidden, + fractions=fractions, + quantiles_tau=quantiles_tau, + ) + return cast(FQFBatchProtocol, result) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats: + if self._target and self._iter % self.freq == 0: + self.sync_weight() + if self.is_noisy: + _sample_noise(self.model) + if self._target and _sample_noise(self.model_old): + self.model_old.train() # so that NoisyLinear takes effect + weight = batch.pop("weight", 1.0) + out = self(batch) + curr_dist_orig = out.logits + taus, tau_hats = out.fractions.taus, out.fractions.tau_hats + act = batch.act + curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = ( + ( + dist_diff + * (tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.0).float()).abs() + ) + .sum(-1) + .mean(1) + ) + quantile_loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer + # calculate fraction loss + with torch.no_grad(): + sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :] + sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :] + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169 + values_1 = sa_quantiles - sa_quantile_hats[:, :-1] + signs_1 = sa_quantiles > torch.cat( + [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], + dim=1, + ) + + values_2 = sa_quantiles - sa_quantile_hats[:, 1:] + signs_2 = sa_quantiles < torch.cat( + [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], + dim=1, + ) + + gradient_of_taus = torch.where(signs_1, values_1, -values_1) + torch.where( + signs_2, + values_2, + -values_2, + ) + fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean() + # calculate entropy loss + entropy_loss = out.fractions.entropies.mean() + fraction_entropy_loss = fraction_loss - self.ent_coef * entropy_loss + self.fraction_optim.zero_grad() + fraction_entropy_loss.backward(retain_graph=True) + self.fraction_optim.step() + self.optim.zero_grad() + quantile_loss.backward() + self.optim.step() + self._iter += 1 + + return FQFTrainingStats( # type: ignore[return-value] + loss=quantile_loss.item() + fraction_entropy_loss.item(), + quantile_loss=quantile_loss.item(), + fraction_loss=fraction_loss.item(), + entropy_loss=entropy_loss.item(), + ) diff --git a/examples/atari/tianshou/policy/modelfree/iqn.py b/examples/atari/tianshou/policy/modelfree/iqn.py new file mode 100644 index 0000000000000000000000000000000000000000..75d76a2dd8070fccc5d2640e016d090f0dcf623d --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/iqn.py @@ -0,0 +1,161 @@ +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import Batch, to_numpy +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ObsBatchProtocol, + QuantileRegressionBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.policy import QRDQNPolicy +from tianshou.policy.base import TLearningRateScheduler +from 
tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats + + +@dataclass(kw_only=True) +class IQNTrainingStats(QRDQNTrainingStats): + pass + + +TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats) + + +class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]): + """Implementation of Implicit Quantile Network. arXiv:1806.06923. + + :param model: a model following the rules (s_B -> action_values_BA) + :param optim: a torch.optim for optimizing the model. + :param discount_factor: in [0, 1]. + :param sample_size: the number of samples for policy evaluation. + :param online_sample_size: the number of samples for online model + in training. + :param target_sample_size: the number of samples for target model + in training. + :param num_quantiles: the number of quantile midpoints in the inverse + cumulative distribution function of the value. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + action_space: gym.spaces.Discrete, + discount_factor: float = 0.99, + sample_size: int = 32, + online_sample_size: int = 8, + target_sample_size: int = 8, + num_quantiles: int = 200, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + assert sample_size > 1, f"sample_size should be greater than 1 but got: {sample_size}" + assert ( + online_sample_size > 1 + ), f"online_sample_size should be greater than 1 but got: {online_sample_size}" + assert ( + target_sample_size > 1 + ), f"target_sample_size should be greater than 1 but got: {target_sample_size}" + super().__init__( + model=model, + optim=optim, + action_space=action_space, + discount_factor=discount_factor, + num_quantiles=num_quantiles, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + is_double=is_double, + clip_loss_grad=clip_loss_grad, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + self.sample_size = sample_size # for policy eval + self.online_sample_size = online_sample_size + self.target_sample_size = target_sample_size + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + model: Literal["model", "model_old"] = "model", + **kwargs: Any, + ) -> QuantileRegressionBatchProtocol: + if model == "model_old": + sample_size = self.target_sample_size + elif self.training: + sample_size = self.online_sample_size + else: + sample_size = self.sample_size + model = getattr(self, model) + obs = batch.obs + # TODO: this seems very contrived! 
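+        # As in DQNPolicy.forward: unwrap nested observations (obs.obs) in case an action
+        # mask is attached alongside the raw observation.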
+ obs_next = obs.obs if hasattr(obs, "obs") else obs + (logits, taus), hidden = model( + obs_next, + sample_size=sample_size, + state=state, + info=batch.info, + ) + q = self.compute_q_value(logits, getattr(obs, "mask", None)) + if self.max_action_num is None: # type: ignore + # TODO: see same thing in DQNPolicy! + self.max_action_num = q.shape[1] + act = to_numpy(q.max(dim=1)[1]) + result = Batch(logits=logits, act=act, state=hidden, taus=taus) + return cast(QuantileRegressionBatchProtocol, result) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats: + if self._target and self._iter % self.freq == 0: + self.sync_weight() + self.optim.zero_grad() + weight = batch.pop("weight", 1.0) + action_batch = self(batch) + curr_dist, taus = action_batch.logits, action_batch.taus + act = batch.act + curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = ( + ( + dist_diff + * (taus.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.0).float()).abs() + ) + .sum(-1) + .mean(1) + ) + loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer + loss.backward() + self.optim.step() + self._iter += 1 + + return IQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/examples/atari/tianshou/policy/modelfree/npg.py b/examples/atari/tianshou/policy/modelfree/npg.py new file mode 100644 index 0000000000000000000000000000000000000000..9e04d3feba90269f6f90a207c212f3c2921dd0ec --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/npg.py @@ -0,0 +1,227 @@ +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.distributions import kl_divergence + +from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats +from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol +from tianshou.policy import A2CPolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic + + +@dataclass(kw_only=True) +class NPGTrainingStats(TrainingStats): + actor_loss: SequenceSummaryStats + vf_loss: SequenceSummaryStats + kl: SequenceSummaryStats + + +TNPGTrainingStats = TypeVar("TNPGTrainingStats", bound=NPGTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # type: ignore[type-var] + """Implementation of Natural Policy Gradient. + + https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf + + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`). + :param critic: the critic network. 
(s -> V(s)) + :param optim: the optimizer for actor and critic network. + :param dist_fn: distribution class for computing the action. + :param action_space: env's action space + :param optim_critic_iters: Number of times to optimize critic network per update. + :param actor_step_size: step size for actor update in natural gradient direction. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param max_batchsize: the maximum size of the batch when computing GAE. + :param discount_factor: in [0, 1]. + :param reward_normalization: normalize estimated values to have std close to 1. + :param deterministic_eval: if True, use deterministic evaluation. + :param observation_space: the space of the observation. + :param action_scaling: if True, scale the action from [-1, 1] to the range of + action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + + def __init__( + self, + *, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, + optim: torch.optim.Optimizer, + dist_fn: TDistFnDiscrOrCont, + action_space: gym.Space, + optim_critic_iters: int = 5, + actor_step_size: float = 0.5, + advantage_normalization: bool = True, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + discount_factor: float = 0.99, + # TODO: rename to return_normalization? + reward_normalization: bool = False, + deterministic_eval: bool = False, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + actor=actor, + critic=critic, + optim=optim, + dist_fn=dist_fn, + action_space=action_space, + # TODO: violates Liskov substitution principle, see the del statement below + vf_coef=None, # type: ignore + ent_coef=None, # type: ignore + max_grad_norm=None, + gae_lambda=gae_lambda, + max_batchsize=max_batchsize, + discount_factor=discount_factor, + reward_normalization=reward_normalization, + deterministic_eval=deterministic_eval, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + lr_scheduler=lr_scheduler, + ) + # TODO: see above, it ain't pretty... 
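+        # NPG does not optimize a joint actor-critic loss: the actor is moved along the
+        # natural gradient direction and the critic is fit by plain MSE regression in
+        # learn(), so A2C's loss coefficients and gradient clipping are never used.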
+ del self.vf_coef, self.ent_coef, self.max_grad_norm + self.norm_adv = advantage_normalization + self.optim_critic_iters = optim_critic_iters + self.actor_step_size = actor_step_size + # adjusts Hessian-vector product calculation for numerical stability + self._damping = 0.1 + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithAdvantagesProtocol: + batch = super().process_fn(batch, buffer, indices) + old_log_prob = [] + with torch.no_grad(): + for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): + old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act)) + batch.logp_old = torch.cat(old_log_prob, dim=0) + if self.norm_adv: + batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std() + return batch + + def learn( # type: ignore + self, + batch: Batch, + batch_size: int | None, + repeat: int, + **kwargs: Any, + ) -> TNPGTrainingStats: + actor_losses, vf_losses, kls = [], [], [] + split_batch_size = batch_size or -1 + for _ in range(repeat): + for minibatch in batch.split(split_batch_size, merge_last=True): + # optimize actor + # direction: calculate villia gradient + dist = self(minibatch).dist + log_prob = dist.log_prob(minibatch.act) + log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1) + actor_loss = -(log_prob * minibatch.adv).mean() + flat_grads = self._get_flat_grad(actor_loss, self.actor, retain_graph=True).detach() + + # direction: calculate natural gradient + with torch.no_grad(): + old_dist = self(minibatch).dist + + kl = kl_divergence(old_dist, dist).mean() + # calculate first order gradient of kl with respect to theta + flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True) + search_direction = -self._conjugate_gradients(flat_grads, flat_kl_grad, nsteps=10) + + # step + with torch.no_grad(): + flat_params = torch.cat( + [param.data.view(-1) for param in self.actor.parameters()], + ) + new_flat_params = flat_params + self.actor_step_size * search_direction + self._set_from_flat_params(self.actor, new_flat_params) + new_dist = self(minibatch).dist + kl = kl_divergence(old_dist, new_dist).mean() + + # optimize critic + for _ in range(self.optim_critic_iters): + value = self.critic(minibatch.obs).flatten() + vf_loss = F.mse_loss(minibatch.returns, value) + self.optim.zero_grad() + vf_loss.backward() + self.optim.step() + + actor_losses.append(actor_loss.item()) + vf_losses.append(vf_loss.item()) + kls.append(kl.item()) + + actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) + vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) + kl_summary_stat = SequenceSummaryStats.from_sequence(kls) + + return NPGTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss_summary_stat, + vf_loss=vf_loss_summary_stat, + kl=kl_summary_stat, + ) + + def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor: + """Matrix vector product.""" + # caculate second order gradient of kl with respect to theta + kl_v = (flat_kl_grad * v).sum() + flat_kl_grad_grad = self._get_flat_grad(kl_v, self.actor, retain_graph=True).detach() + return flat_kl_grad_grad + v * self._damping + + def _conjugate_gradients( + self, + minibatch: torch.Tensor, + flat_kl_grad: torch.Tensor, + nsteps: int = 10, + residual_tol: float = 1e-10, + ) -> torch.Tensor: + x = torch.zeros_like(minibatch) + r, p = minibatch.clone(), minibatch.clone() + # Note: should be 'r, p = minibatch - MVP(x)', but for x=0, MVP(x)=0. 
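+        # Conjugate gradient iteratively solves F x = g, where g is the flattened
+        # policy gradient passed in as `minibatch` and the Fisher/KL-Hessian matrix F
+        # is only ever accessed through the matrix-vector product self._MVP.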
+ # Change if doing warm start. + rdotr = r.dot(r) + for _ in range(nsteps): + z = self._MVP(p, flat_kl_grad) + alpha = rdotr / p.dot(z) + x += alpha * p + r -= alpha * z + new_rdotr = r.dot(r) + if new_rdotr < residual_tol: + break + p = r + new_rdotr / rdotr * p + rdotr = new_rdotr + return x + + def _get_flat_grad(self, y: torch.Tensor, model: nn.Module, **kwargs: Any) -> torch.Tensor: + grads = torch.autograd.grad(y, model.parameters(), **kwargs) # type: ignore + return torch.cat([grad.reshape(-1) for grad in grads]) + + def _set_from_flat_params(self, model: nn.Module, flat_params: torch.Tensor) -> nn.Module: + prev_ind = 0 + for param in model.parameters(): + flat_size = int(np.prod(list(param.size()))) + param.data.copy_(flat_params[prev_ind : prev_ind + flat_size].view(param.size())) + prev_ind += flat_size + return model diff --git a/examples/atari/tianshou/policy/modelfree/pg.py b/examples/atari/tianshou/policy/modelfree/pg.py new file mode 100644 index 0000000000000000000000000000000000000000..80bcff672d0c8146e4c9bf1440857b451c00e949 --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/pg.py @@ -0,0 +1,235 @@ +import warnings +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch + +from tianshou.data import ( + Batch, + ReplayBuffer, + SequenceSummaryStats, + to_torch, + to_torch_as, +) +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + BatchWithReturnsProtocol, + DistBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.policy import BasePolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils import RunningMeanStd +from tianshou.utils.net.continuous import ActorProb +from tianshou.utils.net.discrete import Actor + +# Dimension Naming Convention +# B - Batch Size +# A - Action +# D - Dist input (usually 2, loc and scale) +# H - Dimension of hidden, can be None + +TDistFnContinuous = Callable[ + [tuple[torch.Tensor, torch.Tensor]], + torch.distributions.Distribution, +] +TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Categorical] + +TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete + + +@dataclass(kw_only=True) +class PGTrainingStats(TrainingStats): + loss: SequenceSummaryStats + + +TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats) + + +class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): + """Implementation of REINFORCE algorithm. + + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). + :param optim: optimizer for actor network. + :param dist_fn: distribution class for computing the action. + Maps model_output -> distribution. Typically a Gaussian distribution + taking `model_output=mean,std` as input for continuous action spaces, + or a categorical distribution taking `model_output=logits` + for discrete action spaces. Note that as user, you are responsible + for ensuring that the distribution is compatible with the action space. + :param action_space: env's action space. + :param discount_factor: in [0, 1]. + :param reward_normalization: if True, will normalize the *returns* + by subtracting the running mean and dividing by the running standard deviation. + Can be detrimental to performance! See TODO in process_fn. 
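+        The running statistics are tracked by the policy's ``ret_rms``
+        (:class:`~tianshou.utils.RunningMeanStd`) attribute.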
+ :param deterministic_eval: if True, will use deterministic action (the dist's mode) + instead of stochastic one during evaluation. Does not affect training. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. + """ + + def __init__( + self, + *, + actor: torch.nn.Module | ActorProb | Actor, + optim: torch.optim.Optimizer, + dist_fn: TDistFnDiscrOrCont, + action_space: gym.Space, + discount_factor: float = 0.99, + # TODO: rename to return_normalization? + reward_normalization: bool = False, + deterministic_eval: bool = False, + observation_space: gym.Space | None = None, + # TODO: why change the default from the base? + action_scaling: bool = True, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + lr_scheduler=lr_scheduler, + ) + if action_scaling and not np.isclose(actor.max_action, 1.0): + warnings.warn( + "action_scaling and action_bound_method are only intended" + "to deal with unbounded model action space, but find actor model" + f"bound action space with max_action={actor.max_action}." + "Consider using unbounded=True option of the actor model," + "or set action_scaling to False and action_bound_method to None.", + ) + self.actor = actor + self.optim = optim + self.dist_fn = dist_fn + assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" + self.gamma = discount_factor + self.rew_norm = reward_normalization + self.ret_rms = RunningMeanStd() + self._eps = 1e-8 + self.deterministic_eval = deterministic_eval + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithReturnsProtocol: + r"""Compute the discounted returns (Monte Carlo estimates) for each transition. + + They are added to the batch under the field `returns`. + Note: this function will modify the input batch! + + .. math:: + G_t = \sum_{i=t}^T \gamma^{i-t}r_i + + where :math:`T` is the terminal time step, :math:`\gamma` is the + discount factor, :math:`\gamma \in [0, 1]`. + + :param batch: a data batch which contains several episodes of data in + sequential order. Mind that the end of each finished episode of batch + should be marked by done flag, unfinished (or collecting) episodes will be + recognized by buffer.unfinished_index(). + :param buffer: the corresponding replay buffer. + :param numpy.ndarray indices: tell batch's location in buffer, batch is equal + to buffer[indices]. + """ + v_s_ = np.full(indices.shape, self.ret_rms.mean) + # gae_lambda = 1.0 means we use Monte Carlo estimate + unnormalized_returns, _ = self.compute_episodic_return( + batch, + buffer, + indices, + v_s_=v_s_, + gamma=self.gamma, + gae_lambda=1.0, + ) + # TODO: overridden in A2C, where mean is not subtracted. Subtracting mean + # can be very detrimental! It also has no theoretical grounding. + # This should be addressed soon! 
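+        # With return normalization enabled, the Monte Carlo returns are standardized
+        # using running statistics of previously observed returns,
+        #     returns = (G - running_mean) / sqrt(running_var + eps),
+        # and the running statistics are then updated with the unnormalized returns.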
+ if self.rew_norm: + batch.returns = (unnormalized_returns - self.ret_rms.mean) / np.sqrt( + self.ret_rms.var + self._eps, + ) + self.ret_rms.update(unnormalized_returns) + else: + batch.returns = unnormalized_returns + batch: BatchWithReturnsProtocol + return batch + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> DistBatchProtocol: + """Compute action over the given batch data by applying the actor. + + Will sample from the dist_fn, if appropriate. + Returns a new object representing the processed batch data + (contrary to other methods that modify the input batch inplace). + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. + """ + # TODO - ALGO: marked for algorithm refactoring + action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A + # therefore action_dist_input_BD is equivalent to logits_BA + # If discrete, dist_fn will typically map loc, scale to a distribution (usually a Gaussian) + # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked + dist = self.dist_fn(action_dist_input_BD) + + act_B = ( + dist.mode + if self.deterministic_eval and not self.is_within_training_step + else dist.sample() + ) + # act is of dimension BA in continuous case and of dimension B in discrete + result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist) + return cast(DistBatchProtocol, result) + + # TODO: why does mypy complain? + def learn( # type: ignore + self, + batch: BatchWithReturnsProtocol, + batch_size: int | None, + repeat: int, + *args: Any, + **kwargs: Any, + ) -> TPGTrainingStats: + losses = [] + split_batch_size = batch_size or -1 + for _ in range(repeat): + for minibatch in batch.split(split_batch_size, merge_last=True): + self.optim.zero_grad() + result = self(minibatch) + dist = result.dist + act = to_torch_as(minibatch.act, result.act) + ret = to_torch(minibatch.returns, torch.float, result.act.device) + log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1) + loss = -(log_prob * ret).mean() + loss.backward() + self.optim.step() + losses.append(loss.item()) + + loss_summary_stat = SequenceSummaryStats.from_sequence(losses) + + return PGTrainingStats(loss=loss_summary_stat) # type: ignore[return-value] diff --git a/examples/atari/tianshou/policy/modelfree/ppo.py b/examples/atari/tianshou/policy/modelfree/ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..196cd72e49f5f0bb9c3408de32760e9d3e022dd5 --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/ppo.py @@ -0,0 +1,213 @@ +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar + +import gymnasium as gym +import numpy as np +import torch +from torch import nn + +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as +from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol +from tianshou.policy import A2CPolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic + + 
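+# The core of PPOPolicy.learn() below is the clipped surrogate objective from the PPO
+# paper. The helper below is only an illustrative sketch of that computation on plain
+# tensors (the name `_clipped_surrogate_loss_example` is not part of the tianshou API);
+# it assumes per-sample log-probabilities under the current and old policies together
+# with advantage estimates, and omits the optional dual clipping handled in learn().
+def _clipped_surrogate_loss_example(
+    logp: torch.Tensor,
+    logp_old: torch.Tensor,
+    adv: torch.Tensor,
+    eps_clip: float = 0.2,
+) -> torch.Tensor:
+    # importance ratio r_t(theta) = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t)
+    ratio = (logp - logp_old).exp()
+    surr1 = ratio * adv
+    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
+    # PPO maximizes E[min(surr1, surr2)]; the training loss is its negation
+    return -torch.min(surr1, surr2).mean()
+
+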
+@dataclass(kw_only=True) +class PPOTrainingStats(TrainingStats): + loss: SequenceSummaryStats + clip_loss: SequenceSummaryStats + vf_loss: SequenceSummaryStats + ent_loss: SequenceSummaryStats + + +TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] + r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. + + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`). + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer for actor and critic network. + :param dist_fn: distribution class for computing the action. + :param action_space: env's action space + :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original + paper. + :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, + where c > 1 is a constant indicating the lower bound. Set to None + to disable dual-clip PPO. + :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param recompute_advantage: whether to recompute advantage every update + repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. + :param vf_coef: weight for value loss. + :param ent_coef: weight for entropy loss. + :param max_grad_norm: clipping gradients in back propagation. + :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param max_batchsize: the maximum size of the batch when computing GAE. + :param discount_factor: in [0, 1]. + :param reward_normalization: normalize estimated values to have std close to 1. + :param deterministic_eval: if True, use deterministic evaluation. + :param observation_space: the space of the observation. + :param action_scaling: if True, scale the action from [-1, 1] to the range of + action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, + optim: torch.optim.Optimizer, + dist_fn: TDistFnDiscrOrCont, + action_space: gym.Space, + eps_clip: float = 0.2, + dual_clip: float | None = None, + value_clip: bool = False, + advantage_normalization: bool = True, + recompute_advantage: bool = False, + vf_coef: float = 0.5, + ent_coef: float = 0.01, + max_grad_norm: float | None = None, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + discount_factor: float = 0.99, + # TODO: rename to return_normalization? 
+ reward_normalization: bool = False, + deterministic_eval: bool = False, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + assert ( + dual_clip is None or dual_clip > 1.0 + ), f"Dual-clip PPO parameter should greater than 1.0 but got {dual_clip}" + + super().__init__( + actor=actor, + critic=critic, + optim=optim, + dist_fn=dist_fn, + action_space=action_space, + vf_coef=vf_coef, + ent_coef=ent_coef, + max_grad_norm=max_grad_norm, + gae_lambda=gae_lambda, + max_batchsize=max_batchsize, + discount_factor=discount_factor, + reward_normalization=reward_normalization, + deterministic_eval=deterministic_eval, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + lr_scheduler=lr_scheduler, + ) + self.eps_clip = eps_clip + self.dual_clip = dual_clip + self.value_clip = value_clip + self.norm_adv = advantage_normalization + self.recompute_adv = recompute_advantage + self._actor_critic: ActorCritic + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> LogpOldProtocol: + if self.recompute_adv: + # buffer input `buffer` and `indices` to be used in `learn()`. + self._buffer, self._indices = buffer, indices + batch = self._compute_returns(batch, buffer, indices) + batch.act = to_torch_as(batch.act, batch.v_s) + with torch.no_grad(): + batch.logp_old = self(batch).dist.log_prob(batch.act) + batch: LogpOldProtocol + return batch + + # TODO: why does mypy complain? + def learn( # type: ignore + self, + batch: RolloutBatchProtocol, + batch_size: int | None, + repeat: int, + *args: Any, + **kwargs: Any, + ) -> TPPOTrainingStats: + losses, clip_losses, vf_losses, ent_losses = [], [], [], [] + split_batch_size = batch_size or -1 + for step in range(repeat): + if self.recompute_adv and step > 0: + batch = self._compute_returns(batch, self._buffer, self._indices) + for minibatch in batch.split(split_batch_size, merge_last=True): + # calculate loss for actor + dist = self(minibatch).dist + if self.norm_adv: + mean, std = minibatch.adv.mean(), minibatch.adv.std() + minibatch.adv = (minibatch.adv - mean) / (std + self._eps) # per-batch norm + ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() + ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) + surr1 = ratio * minibatch.adv + surr2 = ratio.clamp(1.0 - self.eps_clip, 1.0 + self.eps_clip) * minibatch.adv + if self.dual_clip: + clip1 = torch.min(surr1, surr2) + clip2 = torch.max(clip1, self.dual_clip * minibatch.adv) + clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean() + else: + clip_loss = -torch.min(surr1, surr2).mean() + # calculate loss for critic + value = self.critic(minibatch.obs).flatten() + if self.value_clip: + v_clip = minibatch.v_s + (value - minibatch.v_s).clamp( + -self.eps_clip, + self.eps_clip, + ) + vf1 = (minibatch.returns - value).pow(2) + vf2 = (minibatch.returns - v_clip).pow(2) + vf_loss = torch.max(vf1, vf2).mean() + else: + vf_loss = (minibatch.returns - value).pow(2).mean() + # calculate regularization and overall loss + ent_loss = dist.entropy().mean() + loss = clip_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss + self.optim.zero_grad() + loss.backward() + if self.max_grad_norm: # clip large gradient + nn.utils.clip_grad_norm_( + self._actor_critic.parameters(), + max_norm=self.max_grad_norm, + ) + 
self.optim.step() + clip_losses.append(clip_loss.item()) + vf_losses.append(vf_loss.item()) + ent_losses.append(ent_loss.item()) + losses.append(loss.item()) + + losses_summary = SequenceSummaryStats.from_sequence(losses) + clip_losses_summary = SequenceSummaryStats.from_sequence(clip_losses) + vf_losses_summary = SequenceSummaryStats.from_sequence(vf_losses) + ent_losses_summary = SequenceSummaryStats.from_sequence(ent_losses) + + return PPOTrainingStats( # type: ignore[return-value] + loss=losses_summary, + clip_loss=clip_losses_summary, + vf_loss=vf_losses_summary, + ent_loss=ent_losses_summary, + ) diff --git a/examples/atari/tianshou/policy/modelfree/qrdqn.py b/examples/atari/tianshou/policy/modelfree/qrdqn.py new file mode 100644 index 0000000000000000000000000000000000000000..71c36de0c90082dded3c3e3b216ddeadd483b835 --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/qrdqn.py @@ -0,0 +1,131 @@ +import warnings +from dataclasses import dataclass +from typing import Any, Generic, TypeVar + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.types import RolloutBatchProtocol +from tianshou.policy import DQNPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats + + +@dataclass(kw_only=True) +class QRDQNTrainingStats(DQNTrainingStats): + pass + + +TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats) + + +class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): + """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. + + :param model: a model following the rules (s -> action_values_BA) + :param optim: a torch.optim for optimizing the model. + :param action_space: Env's action space. + :param discount_factor: in [0, 1]. + :param num_quantiles: the number of quantile midpoints in the inverse + cumulative distribution function of the value. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed + explanation. 
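+
+    The ``learn()`` step minimizes the quantile Huber loss
+
+    .. math::
+        \mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{N}
+        \left| \hat{\tau}_i - \mathbb{1}\{\delta_{ij} \le 0\} \right| \, L_1(\delta_{ij}),
+
+    averaged over the batch, where :math:`\delta_{ij}` is the difference between the
+    :math:`j`-th target quantile and the :math:`i`-th current quantile,
+    :math:`\hat{\tau}_i` are the quantile midpoints stored in ``tau_hat``, and
+    :math:`L_1` denotes the smooth L1 (Huber) loss.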
+ """ + + def __init__( + self, + *, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + action_space: gym.spaces.Discrete, + discount_factor: float = 0.99, + num_quantiles: int = 200, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + assert num_quantiles > 1, f"num_quantiles should be greater than 1 but got: {num_quantiles}" + super().__init__( + model=model, + optim=optim, + action_space=action_space, + discount_factor=discount_factor, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + is_double=is_double, + clip_loss_grad=clip_loss_grad, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + self.num_quantiles = num_quantiles + tau = torch.linspace(0, 1, self.num_quantiles + 1) + self.tau_hat = torch.nn.Parameter( + ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), + requires_grad=False, + ) + warnings.filterwarnings("ignore", message="Using a target size") + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + if self._target: + act = self(obs_next_batch).act + next_dist = self(obs_next_batch, model="model_old").logits + else: + next_batch = self(obs_next_batch) + act = next_batch.act + next_dist = next_batch.logits + return next_dist[np.arange(len(act)), act, :] + + def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: + return super().compute_q_value(logits.mean(2), mask) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats: + if self._target and self._iter % self.freq == 0: + self.sync_weight() + self.optim.zero_grad() + weight = batch.pop("weight", 1.0) + curr_dist = self(batch).logits + act = batch.act + curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + dist_diff = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = ( + (dist_diff * (self.tau_hat - (target_dist - curr_dist).detach().le(0.0).float()).abs()) + .sum(-1) + .mean(1) + ) + loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer + loss.backward() + self.optim.step() + self._iter += 1 + + return QRDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/examples/atari/tianshou/policy/modelfree/rainbow.py b/examples/atari/tianshou/policy/modelfree/rainbow.py new file mode 100644 index 0000000000000000000000000000000000000000..fad567cd2f56078d1d8ad87c944dcd0356badf46 --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/rainbow.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass +from typing import Any, TypeVar + +from torch import nn + +from tianshou.data.types import RolloutBatchProtocol +from tianshou.policy import C51Policy +from tianshou.policy.modelfree.c51 import C51TrainingStats +from tianshou.utils.net.discrete import NoisyLinear + + +# TODO: this is a hacky thing interviewing side-effects and a return. Should improve. 
+def _sample_noise(model: nn.Module) -> bool: + """Sample the random noises of NoisyLinear modules in the model. + + Returns True if at least one NoisyLinear submodule was found. + + :param model: a PyTorch module which may have NoisyLinear submodules. + :returns: True if model has at least one NoisyLinear submodule; + otherwise, False. + """ + sampled_any_noise = False + for m in model.modules(): + if isinstance(m, NoisyLinear): + m.sample() + sampled_any_noise = True + return sampled_any_noise + + +@dataclass(kw_only=True) +class RainbowTrainingStats(C51TrainingStats): + loss: float + + +TRainbowTrainingStats = TypeVar("TRainbowTrainingStats", bound=RainbowTrainingStats) + + +# TODO: is this class worth keeping? It barely does anything +class RainbowPolicy(C51Policy[TRainbowTrainingStats]): + """Implementation of Rainbow DQN. arXiv:1710.02298. + + Same parameters as :class:`~tianshou.policy.C51Policy`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.C51Policy` for more detailed + explanation. + """ + + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TRainbowTrainingStats: + _sample_noise(self.model) + if self._target and _sample_noise(self.model_old): + self.model_old.train() # so that NoisyLinear takes effect + return super().learn(batch, **kwargs) diff --git a/examples/atari/tianshou/policy/modelfree/redq.py b/examples/atari/tianshou/policy/modelfree/redq.py new file mode 100644 index 0000000000000000000000000000000000000000..25f299733ee1ff7ce84834a9822f2dce049984ec --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/redq.py @@ -0,0 +1,239 @@ +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +from torch.distributions import Independent, Normal + +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol +from tianshou.exploration import BaseNoise +from tianshou.policy import DDPGPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.ddpg import DDPGTrainingStats +from tianshou.utils.net.continuous import ActorProb + + +@dataclass +class REDQTrainingStats(DDPGTrainingStats): + """A data structure for storing loss statistics of the REDQ learn step.""" + + alpha: float | None = None + alpha_loss: float | None = None + + +TREDQTrainingStats = TypeVar("TREDQTrainingStats", bound=REDQTrainingStats) + + +class REDQPolicy(DDPGPolicy[TREDQTrainingStats]): + """Implementation of REDQ. arXiv:2101.05982. + + :param actor: The actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> model_output) + :param actor_optim: The optimizer for actor network. + :param critic: The critic network. (s, a -> Q(s, a)) + :param critic_optim: The optimizer for critic network. + :param action_space: Env's action space. + :param ensemble_size: Number of sub-networks in the critic ensemble. + :param subset_size: Number of networks in the subset. + :param tau: Param for soft update of the target network. + :param gamma: Discount factor, in [0, 1]. + :param alpha: entropy regularization coefficient. + If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then + alpha is automatically tuned. + :param exploration_noise: The exploration noise, added to the action. Defaults + to ``GaussianNoise(sigma=0.1)``. + :param estimation_step: The number of steps to look ahead. 
+ :param actor_delay: Number of critic updates before an actor update. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: if not None, will be called in `policy.update()`. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + actor: torch.nn.Module | ActorProb, + actor_optim: torch.optim.Optimizer, + critic: torch.nn.Module, + critic_optim: torch.optim.Optimizer, + action_space: gym.spaces.Box, + ensemble_size: int = 10, + subset_size: int = 2, + tau: float = 0.005, + gamma: float = 0.99, + alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, + estimation_step: int = 1, + actor_delay: int = 20, + exploration_noise: BaseNoise | Literal["default"] | None = None, + deterministic_eval: bool = True, + target_mode: Literal["mean", "min"] = "min", + action_scaling: bool = True, + action_bound_method: Literal["clip"] | None = "clip", + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + if target_mode not in ("min", "mean"): + raise ValueError(f"Unsupported target_mode: {target_mode}") + if not 0 < subset_size <= ensemble_size: + raise ValueError( + f"Invalid choice of ensemble size or subset size. " + f"Should be 0 < {subset_size=} <= {ensemble_size=}", + ) + super().__init__( + actor=actor, + actor_optim=actor_optim, + critic=critic, + critic_optim=critic_optim, + action_space=action_space, + tau=tau, + gamma=gamma, + exploration_noise=exploration_noise, + estimation_step=estimation_step, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + self.ensemble_size = ensemble_size + self.subset_size = subset_size + + self.target_mode = target_mode + self.critic_gradient_step = 0 + self.actor_delay = actor_delay + self.deterministic_eval = deterministic_eval + self.__eps = np.finfo(np.float32).eps.item() + + self._last_actor_loss = 0.0 # only for logging purposes + + # TODO: reduce duplication with SACPolicy + self.alpha: float | torch.Tensor + self._is_auto_alpha = not isinstance(alpha, float) + if self._is_auto_alpha: + # TODO: why doesn't mypy understand that this must be a tuple? + alpha = cast(tuple[float, torch.Tensor, torch.optim.Optimizer], alpha) + if alpha[1].shape != torch.Size([1]): + raise ValueError( + f"Expected log_alpha to have shape torch.Size([1]), " + f"but got {alpha[1].shape} instead.", + ) + if not alpha[1].requires_grad: + raise ValueError("Expected log_alpha to require gradient, but it doesn't.") + + self.target_entropy, self.log_alpha, self.alpha_optim = alpha + self.alpha = self.log_alpha.detach().exp() + else: + # TODO: make mypy undestand this, or switch to something like pyright... + alpha = cast(float, alpha) + self.alpha = alpha + + @property + def is_auto_alpha(self) -> bool: + return self._is_auto_alpha + + # TODO: why override from the base class? 
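+    # Unlike DDPG, only the critic ensemble has a target network to soft-update here;
+    # REDQ's stochastic actor is used directly when computing target values.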
+ def sync_weight(self) -> None: + for o, n in zip(self.critic_old.parameters(), self.critic.parameters(), strict=True): + o.data.copy_(o.data * (1.0 - self.tau) + n.data * self.tau) + + def forward( # type: ignore + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + **kwargs: Any, + ) -> Batch: + (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Independent(Normal(loc_B, scale_B), 1) + if self.deterministic_eval and not self.is_within_training_step: + act_B = dist.mode + else: + act_B = dist.rsample() + log_prob = dist.log_prob(act_B).unsqueeze(-1) + # apply correction for Tanh squashing when computing logprob from Gaussian + # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. + # in appendix C to get some understanding of this equation. + squashed_action = torch.tanh(act_B) + log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( + -1, + keepdim=True, + ) + return Batch( + logits=(loc_B, scale_B), + act=squashed_action, + state=h_BH, + dist=dist, + log_prob=log_prob, + ) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + obs_next_result = self(obs_next_batch) + a_ = obs_next_result.act + sample_ensemble_idx = np.random.choice(self.ensemble_size, self.subset_size, replace=False) + qs = self.critic_old(obs_next_batch.obs, a_)[sample_ensemble_idx, ...] + if self.target_mode == "min": + target_q, _ = torch.min(qs, dim=0) + elif self.target_mode == "mean": + target_q = torch.mean(qs, dim=0) + + target_q -= self.alpha * obs_next_result.log_prob + + return target_q + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TREDQTrainingStats: # type: ignore + # critic ensemble + weight = getattr(batch, "weight", 1.0) + current_qs = self.critic(batch.obs, batch.act).flatten(1) + target_q = batch.returns.flatten() + td = current_qs - target_q + critic_loss = (td.pow(2) * weight).mean() + self.critic_optim.zero_grad() + critic_loss.backward() + self.critic_optim.step() + batch.weight = torch.mean(td, dim=0) # prio-buffer + self.critic_gradient_step += 1 + + alpha_loss = None + # actor + if self.critic_gradient_step % self.actor_delay == 0: + obs_result = self(batch) + a = obs_result.act + current_qa = self.critic(batch.obs, a).mean(dim=0).flatten() + actor_loss = (self.alpha * obs_result.log_prob.flatten() - current_qa).mean() + self.actor_optim.zero_grad() + actor_loss.backward() + self.actor_optim.step() + + if self.is_auto_alpha: + log_prob = obs_result.log_prob.detach() + self._target_entropy + alpha_loss = -(self._log_alpha * log_prob).mean() + self.alpha_optim.zero_grad() + alpha_loss.backward() + self.alpha_optim.step() + self.alpha = self.log_alpha.detach().exp() + + self.sync_weight() + + if self.critic_gradient_step % self.actor_delay == 0: + self._last_actor_loss = actor_loss.item() + if self.is_auto_alpha: + self.alpha = cast(torch.Tensor, self.alpha) + + return REDQTrainingStats( # type: ignore[return-value] + actor_loss=self._last_actor_loss, + critic_loss=critic_loss.item(), + alpha=self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha, + alpha_loss=alpha_loss, + ) diff --git a/examples/atari/tianshou/policy/modelfree/sac.py b/examples/atari/tianshou/policy/modelfree/sac.py new file mode 100644 index 0000000000000000000000000000000000000000..a5a05c0fdc0011ae972cc6efc5bbe738e60e10c1 --- 
/dev/null +++ b/examples/atari/tianshou/policy/modelfree/sac.py @@ -0,0 +1,251 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast + +import gymnasium as gym +import numpy as np +import torch +from torch.distributions import Independent, Normal + +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.types import ( + DistLogProbBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.exploration import BaseNoise +from tianshou.policy import DDPGPolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils.conversion import to_optional_float +from tianshou.utils.net.continuous import ActorProb +from tianshou.utils.optim import clone_optimizer + + +@dataclass(kw_only=True) +class SACTrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + alpha: float | None = None + alpha_loss: float | None = None + + +TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] + """Implementation of Soft Actor-Critic. arXiv:1812.05905. + + :param actor: the actor network following the rules (s -> dist_input_BD) + :param actor_optim: the optimizer for actor network. + :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer for the first critic network. + :param action_space: Env's action space. Should be gym.spaces.Box. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer for the second critic network. + If None, clone critic_optim to use for critic2.parameters(). + :param tau: param for soft update of the target network. + :param gamma: discount factor, in [0, 1]. + :param alpha: entropy regularization coefficient. + If a tuple (target_entropy, log_alpha, alpha_optim) is provided, + then alpha is automatically tuned. + :param estimation_step: The number of steps to look ahead. + :param exploration_noise: add noise to action for exploration. + This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). + :param deterministic_eval: whether to use deterministic action + (mode of Gaussian policy) in evaluation mode instead of stochastic + action sampled by the policy. Does not affect training. + :param action_scaling: whether to map actions from range [-1, 1] + to range[action_spaces.low, action_spaces.high]. + :param action_bound_method: method to bound action to range [-1, 1], + can be either "clip" (for simply clipping the action) + or empty string for no bounding. Only used if the action_space is continuous. + :param observation_space: Env's observation space. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate + in optimizer in each policy.update() + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. 
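+
+    When ``alpha`` is passed as a ``(target_entropy, log_alpha, alpha_optim)`` tuple,
+    the entropy coefficient is tuned automatically in ``learn()`` by minimizing
+
+    .. math::
+        J(\alpha) = - \mathbb{E}_{a \sim \pi} \left[ \log \alpha \cdot
+        \left( \log \pi(a|s) + \mathcal{H}_{\mathrm{target}} \right) \right],
+
+    which drives the policy's entropy towards ``target_entropy``.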
+ """ + + def __init__( + self, + *, + actor: torch.nn.Module | ActorProb, + actor_optim: torch.optim.Optimizer, + critic: torch.nn.Module, + critic_optim: torch.optim.Optimizer, + action_space: gym.Space, + critic2: torch.nn.Module | None = None, + critic2_optim: torch.optim.Optimizer | None = None, + tau: float = 0.005, + gamma: float = 0.99, + alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, + estimation_step: int = 1, + exploration_noise: BaseNoise | Literal["default"] | None = None, + deterministic_eval: bool = True, + action_scaling: bool = True, + # TODO: some papers claim that tanh is crucial for SAC, yet DDPG will raise an + # error if tanh is used. Should be investigated. + action_bound_method: Literal["clip"] | None = "clip", + observation_space: gym.Space | None = None, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + actor=actor, + actor_optim=actor_optim, + critic=critic, + critic_optim=critic_optim, + action_space=action_space, + tau=tau, + gamma=gamma, + exploration_noise=exploration_noise, + estimation_step=estimation_step, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + critic2 = critic2 or deepcopy(critic) + critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) + self.critic2, self.critic2_old = critic2, deepcopy(critic2) + self.critic2_old.eval() + self.critic2_optim = critic2_optim + self.deterministic_eval = deterministic_eval + self.__eps = np.finfo(np.float32).eps.item() + + self.alpha: float | torch.Tensor + self._is_auto_alpha = not isinstance(alpha, float) + if self._is_auto_alpha: + # TODO: why doesn't mypy understand that this must be a tuple? + alpha = cast(tuple[float, torch.Tensor, torch.optim.Optimizer], alpha) + if alpha[1].shape != torch.Size([1]): + raise ValueError( + f"Expected log_alpha to have shape torch.Size([1]), " + f"but got {alpha[1].shape} instead.", + ) + if not alpha[1].requires_grad: + raise ValueError("Expected log_alpha to require gradient, but it doesn't.") + + self.target_entropy, self.log_alpha, self.alpha_optim = alpha + self.alpha = self.log_alpha.detach().exp() + else: + alpha = cast( + float, + alpha, + ) # can we convert alpha to a constant tensor here? then mypy wouldn't complain + self.alpha = alpha + + # TODO or not TODO: add to BasePolicy? + self._check_field_validity() + + def _check_field_validity(self) -> None: + if not isinstance(self.action_space, gym.spaces.Box): + raise ValueError( + f"SACPolicy only supports gym.spaces.Box, but got {self.action_space=}." 
+ f"Please use DiscreteSACPolicy for discrete action spaces.", + ) + + @property + def is_auto_alpha(self) -> bool: + return self._is_auto_alpha + + def train(self, mode: bool = True) -> Self: + self.training = mode + self.actor.train(mode) + self.critic.train(mode) + self.critic2.train(mode) + return self + + def sync_weight(self) -> None: + self.soft_update(self.critic_old, self.critic, self.tau) + self.soft_update(self.critic2_old, self.critic2, self.tau) + + # TODO: violates Liskov substitution principle + def forward( # type: ignore + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + **kwargs: Any, + ) -> DistLogProbBatchProtocol: + (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Independent(Normal(loc=loc_B, scale=scale_B), 1) + if self.deterministic_eval and not self.is_within_training_step: + act_B = dist.mode + else: + act_B = dist.rsample() + log_prob = dist.log_prob(act_B).unsqueeze(-1) + # apply correction for Tanh squashing when computing logprob from Gaussian + # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. + # in appendix C to get some understanding of this equation. + squashed_action = torch.tanh(act_B) + log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( + -1, + keepdim=True, + ) + result = Batch( + logits=(loc_B, scale_B), + act=squashed_action, + state=hidden_BH, + dist=dist, + log_prob=log_prob, + ) + return cast(DistLogProbBatchProtocol, result) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + obs_next_result = self(obs_next_batch) + act_ = obs_next_result.act + return ( + torch.min( + self.critic_old(obs_next_batch.obs, act_), + self.critic2_old(obs_next_batch.obs, act_), + ) + - self.alpha * obs_next_result.log_prob + ) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore + # critic 1&2 + td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) + td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) + batch.weight = (td1 + td2) / 2.0 # prio-buffer + + # actor + obs_result = self(batch) + act = obs_result.act + current_q1a = self.critic(batch.obs, act).flatten() + current_q2a = self.critic2(batch.obs, act).flatten() + actor_loss = ( + self.alpha * obs_result.log_prob.flatten() - torch.min(current_q1a, current_q2a) + ).mean() + self.actor_optim.zero_grad() + actor_loss.backward() + self.actor_optim.step() + alpha_loss = None + + if self.is_auto_alpha: + log_prob = obs_result.log_prob.detach() + self.target_entropy + # please take a look at issue #258 if you'd like to change this line + alpha_loss = -(self.log_alpha * log_prob).mean() + self.alpha_optim.zero_grad() + alpha_loss.backward() + self.alpha_optim.step() + self.alpha = self.log_alpha.detach().exp() + + self.sync_weight() + + return SACTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + alpha=to_optional_float(self.alpha), + alpha_loss=to_optional_float(alpha_loss), + ) diff --git a/examples/atari/tianshou/policy/modelfree/td3.py b/examples/atari/tianshou/policy/modelfree/td3.py new file mode 100644 index 0000000000000000000000000000000000000000..8c2ae8c98452fb58ae680fd93b5129d708255fcb --- /dev/null +++ 
b/examples/atari/tianshou/policy/modelfree/td3.py @@ -0,0 +1,163 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar + +import gymnasium as gym +import numpy as np +import torch + +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.types import RolloutBatchProtocol +from tianshou.exploration import BaseNoise +from tianshou.policy import DDPGPolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils.optim import clone_optimizer + + +@dataclass(kw_only=True) +class TD3TrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + + +TTD3TrainingStats = TypeVar("TTD3TrainingStats", bound=TD3TrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # type: ignore[type-var] + """Implementation of TD3, arXiv:1802.09477. + + :param actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> actions) + :param actor_optim: the optimizer for actor network. + :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer for the first critic network. + :param action_space: Env's action space. Should be gym.spaces.Box. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer for the second critic network. + If None, clone critic_optim to use for critic2.parameters(). + :param tau: param for soft update of the target network. + :param gamma: discount factor, in [0, 1]. + :param exploration_noise: add noise to action for exploration. + This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). + :param policy_noise: the noise used in updating policy network. + :param update_actor_freq: the update frequency of actor network. + :param noise_clip: the clipping range used in updating policy network. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate + in optimizer in each policy.update() + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + *, + actor: torch.nn.Module, + actor_optim: torch.optim.Optimizer, + critic: torch.nn.Module, + critic_optim: torch.optim.Optimizer, + action_space: gym.Space, + critic2: torch.nn.Module | None = None, + critic2_optim: torch.optim.Optimizer | None = None, + tau: float = 0.005, + gamma: float = 0.99, + exploration_noise: BaseNoise | Literal["default"] | None = "default", + policy_noise: float = 0.2, + update_actor_freq: int = 2, + noise_clip: float = 0.5, + estimation_step: int = 1, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + # TODO: reduce duplication with SAC. + # Some intermediate class, like TwoCriticPolicy? 
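+        # TD3 augments DDPG with three tricks: clipped double-Q learning (the min over
+        # both target critics in _target_q), delayed actor updates (update_actor_freq),
+        # and target policy smoothing (policy_noise / noise_clip).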
+ super().__init__( + actor=actor, + actor_optim=actor_optim, + critic=critic, + critic_optim=critic_optim, + action_space=action_space, + tau=tau, + gamma=gamma, + exploration_noise=exploration_noise, + estimation_step=estimation_step, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + observation_space=observation_space, + lr_scheduler=lr_scheduler, + ) + if critic2 and not critic2_optim: + raise ValueError("critic2_optim must be provided if critic2 is provided") + critic2 = critic2 or deepcopy(critic) + critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) + self.critic2, self.critic2_old = critic2, deepcopy(critic2) + self.critic2_old.eval() + self.critic2_optim = critic2_optim + + self.policy_noise = policy_noise + self.update_actor_freq = update_actor_freq + self.noise_clip = noise_clip + self._cnt = 0 + self._last = 0 + + def train(self, mode: bool = True) -> Self: + self.training = mode + self.actor.train(mode) + self.critic.train(mode) + self.critic2.train(mode) + return self + + def sync_weight(self) -> None: + self.soft_update(self.critic_old, self.critic, self.tau) + self.soft_update(self.critic2_old, self.critic2, self.tau) + self.soft_update(self.actor_old, self.actor, self.tau) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + act_ = self(obs_next_batch, model="actor_old").act + noise = torch.randn(size=act_.shape, device=act_.device) * self.policy_noise + if self.noise_clip > 0.0: + noise = noise.clamp(-self.noise_clip, self.noise_clip) + act_ += noise + return torch.min( + self.critic_old(obs_next_batch.obs, act_), + self.critic2_old(obs_next_batch.obs, act_), + ) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3TrainingStats: # type: ignore + # critic 1&2 + td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) + td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) + batch.weight = (td1 + td2) / 2.0 # prio-buffer + + # actor + if self._cnt % self.update_actor_freq == 0: + actor_loss = -self.critic(batch.obs, self(batch, eps=0.0).act).mean() + self.actor_optim.zero_grad() + actor_loss.backward() + self._last = actor_loss.item() + self.actor_optim.step() + self.sync_weight() + self._cnt += 1 + + return TD3TrainingStats( # type: ignore[return-value] + actor_loss=self._last, + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + ) diff --git a/examples/atari/tianshou/policy/modelfree/trpo.py b/examples/atari/tianshou/policy/modelfree/trpo.py new file mode 100644 index 0000000000000000000000000000000000000000..e7aa5cfd598efb6cb56f9e98fcc015cfc73d79a6 --- /dev/null +++ b/examples/atari/tianshou/policy/modelfree/trpo.py @@ -0,0 +1,199 @@ +import warnings +from dataclasses import dataclass +from typing import Any, Literal, TypeVar + +import gymnasium as gym +import torch +import torch.nn.functional as F +from torch.distributions import kl_divergence + +from tianshou.data import Batch, SequenceSummaryStats +from tianshou.policy import NPGPolicy +from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.npg import NPGTrainingStats +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete 
import Critic as DiscreteCritic + + +@dataclass(kw_only=True) +class TRPOTrainingStats(NPGTrainingStats): + step_size: SequenceSummaryStats + + +TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats) + + +class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): + """Implementation of Trust Region Policy Optimization. arXiv:1502.05477. + + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer for actor and critic network. + :param dist_fn: distribution class for computing the action. + :param action_space: env's action space + :param max_kl: max kl-divergence used to constrain each actor network update. + :param backtrack_coeff: Coefficient to be multiplied by step size when + constraints are not met. + :param max_backtracks: Max number of backtracking times in linesearch. + :param optim_critic_iters: Number of times to optimize critic network per update. + :param actor_step_size: step size for actor update in natural gradient direction. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param max_batchsize: the maximum size of the batch when computing GAE. + :param discount_factor: in [0, 1]. + :param reward_normalization: normalize estimated values to have std close to 1. + :param deterministic_eval: if True, use deterministic evaluation. + :param observation_space: the space of the observation. + :param action_scaling: if True, scale the action from [-1, 1] to the range of + action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + + def __init__( + self, + *, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, + optim: torch.optim.Optimizer, + dist_fn: TDistFnDiscrOrCont, + action_space: gym.Space, + max_kl: float = 0.01, + backtrack_coeff: float = 0.8, + max_backtracks: int = 10, + optim_critic_iters: int = 5, + actor_step_size: float = 0.5, + advantage_normalization: bool = True, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + discount_factor: float = 0.99, + # TODO: rename to return_normalization? 
+ reward_normalization: bool = False, + deterministic_eval: bool = False, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + actor=actor, + critic=critic, + optim=optim, + dist_fn=dist_fn, + action_space=action_space, + optim_critic_iters=optim_critic_iters, + actor_step_size=actor_step_size, + advantage_normalization=advantage_normalization, + gae_lambda=gae_lambda, + max_batchsize=max_batchsize, + discount_factor=discount_factor, + reward_normalization=reward_normalization, + deterministic_eval=deterministic_eval, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + lr_scheduler=lr_scheduler, + ) + self.max_backtracks = max_backtracks + self.max_kl = max_kl + self.backtrack_coeff = backtrack_coeff + + def learn( # type: ignore + self, + batch: Batch, + batch_size: int | None, + repeat: int, + **kwargs: Any, + ) -> TTRPOTrainingStats: + actor_losses, vf_losses, step_sizes, kls = [], [], [], [] + split_batch_size = batch_size or -1 + for _ in range(repeat): + for minibatch in batch.split(split_batch_size, merge_last=True): + # optimize actor + # direction: calculate villia gradient + dist = self(minibatch).dist # TODO could come from batch + ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() + ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) + actor_loss = -(ratio * minibatch.adv).mean() + flat_grads = self._get_flat_grad(actor_loss, self.actor, retain_graph=True).detach() + + # direction: calculate natural gradient + with torch.no_grad(): + old_dist = self(minibatch).dist + + kl = kl_divergence(old_dist, dist).mean() + # calculate first order gradient of kl with respect to theta + flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True) + search_direction = -self._conjugate_gradients(flat_grads, flat_kl_grad, nsteps=10) + + # stepsize: calculate max stepsize constrained by kl bound + step_size = torch.sqrt( + 2 + * self.max_kl + / (search_direction * self._MVP(search_direction, flat_kl_grad)).sum( + 0, + keepdim=True, + ), + ) + + # stepsize: linesearch stepsize + with torch.no_grad(): + flat_params = torch.cat( + [param.data.view(-1) for param in self.actor.parameters()], + ) + for i in range(self.max_backtracks): + new_flat_params = flat_params + step_size * search_direction + self._set_from_flat_params(self.actor, new_flat_params) + # calculate kl and if in bound, loss actually down + new_dist = self(minibatch).dist + new_dratio = ( + (new_dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() + ) + new_dratio = new_dratio.reshape(new_dratio.size(0), -1).transpose(0, 1) + new_actor_loss = -(new_dratio * minibatch.adv).mean() + kl = kl_divergence(old_dist, new_dist).mean() + + if kl < self.max_kl and new_actor_loss < actor_loss: + if i > 0: + warnings.warn(f"Backtracking to step {i}.") + break + if i < self.max_backtracks - 1: + step_size = step_size * self.backtrack_coeff + else: + self._set_from_flat_params(self.actor, new_flat_params) + step_size = torch.tensor([0.0]) + warnings.warn( + "Line search failed! 
It seems hyperparamters" + " are poor and need to be changed.", + ) + + # optimize critic + # TODO: remove type-ignore once the top-level type-ignore is removed + for _ in range(self.optim_critic_iters): # type: ignore + value = self.critic(minibatch.obs).flatten() + vf_loss = F.mse_loss(minibatch.returns, value) + self.optim.zero_grad() + vf_loss.backward() + self.optim.step() + + actor_losses.append(actor_loss.item()) + vf_losses.append(vf_loss.item()) + step_sizes.append(step_size.item()) + kls.append(kl.item()) + + actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) + vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) + kl_summary_stat = SequenceSummaryStats.from_sequence(kls) + step_size_stat = SequenceSummaryStats.from_sequence(step_sizes) + + return TRPOTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss_summary_stat, + vf_loss=vf_loss_summary_stat, + kl=kl_summary_stat, + step_size=step_size_stat, + ) diff --git a/examples/atari/tianshou/policy/multiagent/__init__.py b/examples/atari/tianshou/policy/multiagent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/policy/multiagent/mapolicy.py b/examples/atari/tianshou/policy/multiagent/mapolicy.py new file mode 100644 index 0000000000000000000000000000000000000000..81cfe0a6d9476bc1695608b656b8ec29d75b5530 --- /dev/null +++ b/examples/atari/tianshou/policy/multiagent/mapolicy.py @@ -0,0 +1,286 @@ +from typing import Any, Literal, Protocol, Self, TypeVar, cast, overload + +import numpy as np +from overrides import override + +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.batch import BatchProtocol, IndexType +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.policy import BasePolicy +from tianshou.policy.base import TLearningRateScheduler, TrainingStats + +try: + from tianshou.env.pettingzoo_env import PettingZooEnv +except ImportError: + PettingZooEnv = None # type: ignore + + +class MapTrainingStats(TrainingStats): + def __init__( + self, + agent_id_to_stats: dict[str | int, TrainingStats], + train_time_aggregator: Literal["min", "max", "mean"] = "max", + ) -> None: + self._agent_id_to_stats = agent_id_to_stats + train_times = [agent_stats.train_time for agent_stats in agent_id_to_stats.values()] + match train_time_aggregator: + case "max": + aggr_function = max + case "min": + aggr_function = min + case "mean": + aggr_function = np.mean # type: ignore + case _: + raise ValueError( + f"Unknown {train_time_aggregator=}", + ) + self.train_time = aggr_function(train_times) + self.smoothed_loss = {} + + @override + def get_loss_stats_dict(self) -> dict[str, float]: + """Collects loss_stats_dicts from all agents, prepends agent_id to all keys, and joins results.""" + result_dict = {} + for agent_id, stats in self._agent_id_to_stats.items(): + agent_loss_stats_dict = stats.get_loss_stats_dict() + for k, v in agent_loss_stats_dict.items(): + result_dict[f"{agent_id}/" + k] = v + return result_dict + + +class MAPRolloutBatchProtocol(RolloutBatchProtocol, Protocol): + # TODO: this might not be entirely correct. + # The whole MAP data processing pipeline needs more documentation and possibly some refactoring + @overload + def __getitem__(self, index: str) -> RolloutBatchProtocol: + ... + + @overload + def __getitem__(self, index: IndexType) -> Self: + ... 
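+    # Indexing with an agent_id string yields that agent's rollout batch (process_fn
+    # below builds exactly such a mapping), while the usual index types behave like
+    # regular batch indexing.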
+ + def __getitem__(self, index: str | IndexType) -> Any: + ... + + +class MultiAgentPolicyManager(BasePolicy): + """Multi-agent policy manager for MARL. + + This multi-agent policy manager accepts a list of + :class:`~tianshou.policy.BasePolicy`. It dispatches the batch data to each + of these policies when the "forward" is called. The same as "process_fn" + and "learn": it splits the data and feeds them to each policy. A figure in + :ref:`marl_example` can help you better understand this procedure. + + :param policies: a list of policies. + :param env: a PettingZooEnv. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + + def __init__( + self, + *, + policies: list[BasePolicy], + # TODO: 1 why restrict to PettingZooEnv? + # TODO: 2 This is the only policy that takes an env in init, is it really needed? + env: PettingZooEnv, + action_scaling: bool = False, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + super().__init__( + action_space=env.action_space, + observation_space=env.observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + lr_scheduler=lr_scheduler, + ) + assert len(policies) == len(env.agents), "One policy must be assigned for each agent." + + self.agent_idx = env.agent_idx + for i, policy in enumerate(policies): + # agent_id 0 is reserved for the environment proxy + # (this MultiAgentPolicyManager) + policy.set_agent_id(env.agents[i]) + + self.policies: dict[str | int, BasePolicy] = dict(zip(env.agents, policies, strict=True)) + """Maps agent_id to policy.""" + + # TODO: unused - remove it? + def replace_policy(self, policy: BasePolicy, agent_id: int) -> None: + """Replace the "agent_id"th policy in this manager.""" + policy.set_agent_id(agent_id) + self.policies[agent_id] = policy + + # TODO: violates Liskov substitution principle + def process_fn( # type: ignore + self, + batch: MAPRolloutBatchProtocol, + buffer: ReplayBuffer, + indice: np.ndarray, + ) -> MAPRolloutBatchProtocol: + """Dispatch batch data from `obs.agent_id` to every policy's process_fn. + + Save original multi-dimensional rew in "save_rew", set rew to the + reward of each agent during their "process_fn", and restore the + original reward afterwards. + """ + # TODO: maybe only str is actually allowed as agent_id? See MAPRolloutBatchProtocol + results: dict[str | int, RolloutBatchProtocol] = {} + assert isinstance( + batch.obs, + BatchProtocol, + ), f"here only observations of type Batch are permitted, but got {type(batch.obs)}" + # reward can be empty Batch (after initial reset) or nparray. + has_rew = isinstance(buffer.rew, np.ndarray) + if has_rew: # save the original reward in save_rew + # Since we do not override buffer.__setattr__, here we use _meta to + # change buffer.rew, otherwise buffer.rew = Batch() has no effect. 
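+            # The full multi-agent reward array is restored at the end of this method.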
+ save_rew, buffer._meta.rew = buffer.rew, Batch() # type: ignore + for agent, policy in self.policies.items(): + agent_index = np.nonzero(batch.obs.agent_id == agent)[0] + if len(agent_index) == 0: + results[agent] = cast(RolloutBatchProtocol, Batch()) + continue + tmp_batch, tmp_indice = batch[agent_index], indice[agent_index] + if has_rew: + tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]] + buffer._meta.rew = save_rew[:, self.agent_idx[agent]] + if not hasattr(tmp_batch.obs, "mask"): + if hasattr(tmp_batch.obs, "obs"): + tmp_batch.obs = tmp_batch.obs.obs + if hasattr(tmp_batch.obs_next, "obs"): + tmp_batch.obs_next = tmp_batch.obs_next.obs + results[agent] = policy.process_fn(tmp_batch, buffer, tmp_indice) + if has_rew: # restore from save_rew + buffer._meta.rew = save_rew + return Batch(results) + + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + + def exploration_noise( + self, + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: + """Add exploration noise from sub-policy onto act.""" + if not isinstance(batch.obs, Batch): + raise TypeError( + f"here only observations of type Batch are permitted, but got {type(batch.obs)}", + ) + for agent_id, policy in self.policies.items(): + agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] + if len(agent_index) == 0: + continue + act[agent_index] = policy.exploration_noise(act[agent_index], batch[agent_index]) + return act + + def forward( # type: ignore + self, + batch: Batch, + state: dict | Batch | None = None, + **kwargs: Any, + ) -> Batch: + """Dispatch batch data from obs.agent_id to every policy's forward. + + :param batch: TODO: document what is expected at input and make a BatchProtocol for it + :param state: if None, it means all agents have no state. If not + None, it should contain keys of "agent_1", "agent_2", ... + + :return: a Batch with the following contents: + TODO: establish a BatcProtocol for this + + :: + + { + "act": actions corresponding to the input + "state": { + "agent_1": output state of agent_1's policy for the state + "agent_2": xxx + ... + "agent_n": xxx} + "out": { + "agent_1": output of agent_1's policy for the input + "agent_2": xxx + ... + "agent_n": xxx} + } + """ + results: list[tuple[bool, np.ndarray, Batch, np.ndarray | Batch, Batch]] = [] + for agent_id, policy in self.policies.items(): + # This part of code is difficult to understand. + # Let's follow an example with two agents + # batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6) + # each agent plays for three transitions + # agent_index for agent 1 is [0, 2, 4] + # agent_index for agent 2 is [1, 3, 5] + # we separate the transition of each agent according to agent_id + agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] + if len(agent_index) == 0: + # (has_data, agent_index, out, act, state) + results.append((False, np.array([-1]), Batch(), Batch(), Batch())) + continue + tmp_batch = batch[agent_index] + if "rew" in tmp_batch.get_keys() and isinstance(tmp_batch.rew, np.ndarray): + # reward can be empty Batch (after initial reset) or nparray. 
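+                # rew holds one column per agent; keep only this agent's column so the
+                # sub-policy receives a scalar reward per transition.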
+ tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]] + if not hasattr(tmp_batch.obs, "mask"): + if hasattr(tmp_batch.obs, "obs"): + tmp_batch.obs = tmp_batch.obs.obs + if hasattr(tmp_batch.obs_next, "obs"): + tmp_batch.obs_next = tmp_batch.obs_next.obs + out = policy( + batch=tmp_batch, + state=None if state is None else state[agent_id], + **kwargs, + ) + act = out.act + each_state = out.state if (hasattr(out, "state") and out.state is not None) else Batch() + results.append((True, agent_index, out, act, each_state)) + holder: Batch = Batch.cat( + [{"act": act} for (has_data, agent_index, out, act, each_state) in results if has_data], + ) + state_dict, out_dict = {}, {} + for (agent_id, _), (has_data, agent_index, out, act, state) in zip( + self.policies.items(), + results, + strict=True, + ): + if has_data: + holder.act[agent_index] = act + state_dict[agent_id] = state + out_dict[agent_id] = out + holder["out"] = out_dict + holder["state"] = state_dict + return holder + + # Violates Liskov substitution principle + def learn( # type: ignore + self, + batch: MAPRolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> MapTrainingStats: + """Dispatch the data to all policies for learning. + + :param batch: must map agent_ids to rollout batches + """ + agent_id_to_stats = {} + for agent_id, policy in self.policies.items(): + data = batch[agent_id] + if not data.is_empty(): + train_stats = policy.learn(batch=data, **kwargs) + agent_id_to_stats[agent_id] = train_stats + return MapTrainingStats(agent_id_to_stats) + + # Need a train method that set all sub-policies to train mode. + # No need for a similar eval function, as eval internally uses the train function. + def train(self, mode: bool = True) -> Self: + """Set each internal policy in training mode.""" + for policy in self.policies.values(): + policy.train(mode) + return self diff --git a/examples/atari/tianshou/policy/random.py b/examples/atari/tianshou/policy/random.py new file mode 100644 index 0000000000000000000000000000000000000000..943ae99f2bb2c0af3a6eb30330a5247e1363d6f9 --- /dev/null +++ b/examples/atari/tianshou/policy/random.py @@ -0,0 +1,54 @@ +from typing import Any, TypeVar, cast + +import numpy as np + +from tianshou.data import Batch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.policy import BasePolicy +from tianshou.policy.base import TrainingStats + + +class RandomTrainingStats(TrainingStats): + pass + + +TRandomTrainingStats = TypeVar("TRandomTrainingStats", bound=RandomTrainingStats) + + +class RandomPolicy(BasePolicy[TRandomTrainingStats]): + """A random agent used in multi-agent learning. + + It randomly chooses an action from the legal action. + """ + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> ActBatchProtocol: + """Compute the random action over the given batch data. + + The input should contain a mask in batch.obs, with "True" to be + available and "False" to be unavailable. For example, + ``batch.obs.mask == np.array([[False, True, False]])`` means with batch + size 1, action "1" is available but action "0" and "2" are unavailable. + + :return: A :class:`~tianshou.data.Batch` with "act" key, containing + the random action. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. 
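+
+        A minimal usage sketch (the constructor call and batch layout below are
+        assumptions for illustration only)::
+
+            import gymnasium as gym
+            import numpy as np
+
+            from tianshou.data import Batch
+
+            policy = RandomPolicy(action_space=gym.spaces.Discrete(3))
+            batch = Batch(obs=Batch(mask=np.array([[False, True, False]])))
+            policy(batch).act  # -> array([1]), the only legal action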
+ """ + mask = batch.obs.mask # type: ignore + logits = np.random.rand(*mask.shape) + logits[~mask] = -np.inf + result = Batch(act=logits.argmax(axis=-1)) + return cast(ActBatchProtocol, result) + + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TRandomTrainingStats: # type: ignore + """Since a random agent learns nothing, it returns an empty dict.""" + return RandomTrainingStats() # type: ignore[return-value] diff --git a/examples/atari/tianshou/py.typed b/examples/atari/tianshou/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/trainer/__init__.py b/examples/atari/tianshou/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5946555a23295ed386650c9bf9e5e07fc1e261a0 --- /dev/null +++ b/examples/atari/tianshou/trainer/__init__.py @@ -0,0 +1,18 @@ +"""Trainer package.""" + +from tianshou.trainer.base import ( + BaseTrainer, + OfflineTrainer, + OffpolicyTrainer, + OnpolicyTrainer, +) +from tianshou.trainer.utils import gather_info, test_episode + +__all__ = [ + "BaseTrainer", + "OffpolicyTrainer", + "OnpolicyTrainer", + "OfflineTrainer", + "test_episode", + "gather_info", +] diff --git a/examples/atari/tianshou/trainer/base.py b/examples/atari/tianshou/trainer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..242f2b028685a432e5a79b25ed3e37dae24e588a --- /dev/null +++ b/examples/atari/tianshou/trainer/base.py @@ -0,0 +1,694 @@ +import logging +import time +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from collections.abc import Callable +from dataclasses import asdict + +import numpy as np +import tqdm + +from tianshou.data import ( + AsyncCollector, + CollectStats, + EpochStats, + InfoStats, + ReplayBuffer, + SequenceSummaryStats, +) +from tianshou.data.collector import BaseCollector, CollectStatsBase +from tianshou.policy import BasePolicy +from tianshou.policy.base import TrainingStats +from tianshou.trainer.utils import gather_info, test_episode +from tianshou.utils import ( + BaseLogger, + DummyTqdm, + LazyLogger, + MovAvg, + tqdm_config, +) +from tianshou.utils.logging import set_numerical_fields_to_precision +from tianshou.utils.torch_utils import policy_within_training_step + +log = logging.getLogger(__name__) + + +class BaseTrainer(ABC): + """An iterator base class for trainers. + + Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results + on every epoch. + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param batch_size: the batch size of sample data, which is going to feed in + the policy network. If None, will use the whole buffer in each gradient step. + :param train_collector: the collector used for training. + :param test_collector: the collector used for testing. If it's None, + then no testing will be performed. + :param buffer: the replay buffer used for off-policy algorithms or for pre-training. + If a policy overrides the ``process_buffer`` method, the replay buffer will + be pre-processed before training. + :param max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` + is set. + :param step_per_epoch: the number of transitions collected per epoch. 
+ :param repeat_per_collect: the number of repeat time for policy learning, + for example, set it to 2 means the policy needs to learn each given batch + data twice. Only used in on-policy algorithms + :param episode_per_test: the number of episodes for one policy evaluation. + :param update_per_step: only used in off-policy algorithms. + How many gradient steps to perform per step in the environment + (i.e., per sample added to the buffer). + :param step_per_collect: the number of transitions the collector would + collect before the network update, i.e., trainer will collect + "step_per_collect" transitions and do some policy network update repeatedly + in each epoch. + :param episode_per_collect: the number of episodes the collector would + collect before the network update, i.e., trainer will collect + "episode_per_collect" episodes and do some policy network update repeatedly + in each epoch. + :param train_fn: a hook called at the beginning of training in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param save_best_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param save_checkpoint_fn: a function to save training process and + return the saved checkpoint path, with the signature ``f(epoch: int, + env_step: int, gradient_step: int) -> str``; you can save whatever you want. + :param resume_from_log: resume env_step/gradient_step and other metadata + from existing tensorboard log. + :param stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param reward_metric: a function with signature + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray + with shape (num_episode,)``, used in multi-agent RL. We need to return a + single scalar for each episode's result to monitor training in the + multi-agent RL setting. This function specifies what is the desired metric, + e.g., the reward of agent 1 or the average reward over all agents. + :param logger: A logger that logs statistics during + training/testing/updating. To not log anything, keep the default logger. + :param verbose: whether to print status information to stdout. + If set to False, status information will still be logged (provided that + logging is enabled via the `logging` module). + :param show_progress: whether to display a progress bar when training. + :param test_in_train: whether to test in the training phase. + """ + + __doc__: str + + @staticmethod + def gen_doc(learning_type: str) -> str: + """Document string for subclass trainer.""" + step_means = f'The "step" in {learning_type} trainer means ' + if learning_type != "offline": + step_means += "an environment step (a.k.a. transition)." + else: # offline + step_means += "a gradient step." + + trainer_name = learning_type.capitalize() + "Trainer" + + return f"""An iterator class for {learning_type} trainer procedure. + + Returns an iterator that yields a 3-tuple (epoch, stats, info) of + train results on every epoch. + + {step_means} + + Example usage: + + :: + + trainer = {trainer_name}(...) 
+ for epoch, epoch_stat, info in trainer: + print("Epoch:", epoch) + print(epoch_stat) + print(info) + do_something_with_policy() + query_something_about_policy() + make_a_plot_with(epoch_stat) + display(info) + + - epoch int: the epoch number + - epoch_stat dict: a large collection of metrics of the current epoch + - info dict: result returned from :func:`~tianshou.trainer.gather_info` + + You can even iterate on several trainers at the same time: + + :: + + trainer1 = {trainer_name}(...) + trainer2 = {trainer_name}(...) + for result1, result2, ... in zip(trainer1, trainer2, ...): + compare_results(result1, result2, ...) + """ + + def __init__( + self, + policy: BasePolicy, + max_epoch: int, + batch_size: int | None, + train_collector: BaseCollector | None = None, + test_collector: BaseCollector | None = None, + buffer: ReplayBuffer | None = None, + step_per_epoch: int | None = None, + repeat_per_collect: int | None = None, + episode_per_test: int | None = None, + update_per_step: float = 1.0, + step_per_collect: int | None = None, + episode_per_collect: int | None = None, + train_fn: Callable[[int, int], None] | None = None, + test_fn: Callable[[int, int | None], None] | None = None, + stop_fn: Callable[[float], bool] | None = None, + save_best_fn: Callable[[BasePolicy], None] | None = None, + save_checkpoint_fn: Callable[[int, int, int], str] | None = None, + resume_from_log: bool = False, + reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + show_progress: bool = True, + test_in_train: bool = True, + ): + self.policy = policy + + if buffer is not None: + buffer = policy.process_buffer(buffer) + self.buffer = buffer + + self.train_collector = train_collector + self.test_collector = test_collector + + self.logger = logger + self.start_time = time.time() + self.stat: defaultdict[str, MovAvg] = defaultdict(MovAvg) + self.best_reward = 0.0 + self.best_reward_std = 0.0 + self.start_epoch = 0 + # This is only used for logging but creeps into the implementations + # of the trainers. 
I believe it would be better to remove + self._gradient_step = 0 + self.env_step = 0 + self.policy_update_time = 0.0 + self.max_epoch = max_epoch + self.step_per_epoch = step_per_epoch + + # either on of these two + self.step_per_collect = step_per_collect + self.episode_per_collect = episode_per_collect + + self.update_per_step = update_per_step + self.repeat_per_collect = repeat_per_collect + + self.episode_per_test = episode_per_test + + self.batch_size = batch_size + + self.train_fn = train_fn + self.test_fn = test_fn + self.stop_fn = stop_fn + self.save_best_fn = save_best_fn + self.save_checkpoint_fn = save_checkpoint_fn + + self.reward_metric = reward_metric + self.verbose = verbose + self.show_progress = show_progress + self.test_in_train = test_in_train + self.resume_from_log = resume_from_log + + self.is_run = False + self.last_rew, self.last_len = 0.0, 0.0 + + self.epoch = self.start_epoch + self.best_epoch = self.start_epoch + self.stop_fn_flag = False + self.iter_num = 0 + + def _reset_collectors(self, reset_buffer: bool = False) -> None: + if self.train_collector is not None: + self.train_collector.reset(reset_buffer=reset_buffer) + if self.test_collector is not None: + self.test_collector.reset(reset_buffer=reset_buffer) + + def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None: + """Initialize or reset the instance to yield a new iterator from zero.""" + self.is_run = False + self.env_step = 0 + if self.resume_from_log: + ( + self.start_epoch, + self.env_step, + self._gradient_step, + ) = self.logger.restore_data() + + self.last_rew, self.last_len = 0.0, 0.0 + self.start_time = time.time() + + if reset_collectors: + self._reset_collectors(reset_buffer=reset_buffer) + + if self.train_collector is not None and ( + self.train_collector.policy != self.policy or self.test_collector is None + ): + self.test_in_train = False + + if self.test_collector is not None: + assert self.episode_per_test is not None + assert not isinstance(self.test_collector, AsyncCollector) # Issue 700 + test_result = test_episode( + self.test_collector, + self.test_fn, + self.start_epoch, + self.episode_per_test, + self.logger, + self.env_step, + self.reward_metric, + ) + assert test_result.returns_stat is not None # for mypy + self.best_epoch = self.start_epoch + self.best_reward, self.best_reward_std = ( + test_result.returns_stat.mean, + test_result.returns_stat.std, + ) + if self.save_best_fn: + self.save_best_fn(self.policy) + + self.epoch = self.start_epoch + self.stop_fn_flag = False + self.iter_num = 0 + + def __iter__(self): # type: ignore + self.reset(reset_collectors=True, reset_buffer=False) + return self + + def __next__(self) -> EpochStats: + """Perform one epoch (both train and eval).""" + self.epoch += 1 + self.iter_num += 1 + + if self.iter_num > 1: + # iterator exhaustion check + if self.epoch > self.max_epoch: + raise StopIteration + + # exit flag 1, when stop_fn succeeds in train_step or test_step + if self.stop_fn_flag: + raise StopIteration + + progress = tqdm.tqdm if self.show_progress else DummyTqdm + + # perform n step_per_epoch + with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t: + train_stat: CollectStatsBase + while t.n < t.total and not self.stop_fn_flag: + train_stat, update_stat, self.stop_fn_flag = self.training_step() + + if isinstance(train_stat, CollectStats): + pbar_data_dict = { + "env_step": str(self.env_step), + "rew": f"{self.last_rew:.2f}", + "len": str(int(self.last_len)), + "n/ep": 
str(train_stat.n_collected_episodes), + "n/st": str(train_stat.n_collected_steps), + } + t.update(train_stat.n_collected_steps) + else: + pbar_data_dict = {} + t.update() + + pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) + pbar_data_dict["gradient_step"] = str(self._gradient_step) + t.set_postfix(**pbar_data_dict) + + if self.stop_fn_flag: + break + + if t.n <= t.total and not self.stop_fn_flag: + t.update() + + # for offline RL + if self.train_collector is None: + assert self.buffer is not None + batch_size = self.batch_size or len(self.buffer) + self.env_step = self._gradient_step * batch_size + + test_stat = None + if not self.stop_fn_flag: + self.logger.save_data( + self.epoch, + self.env_step, + self._gradient_step, + self.save_checkpoint_fn, + ) + # test + if self.test_collector is not None: + test_stat, self.stop_fn_flag = self.test_step() + + info_stat = gather_info( + start_time=self.start_time, + policy_update_time=self.policy_update_time, + gradient_step=self._gradient_step, + best_reward=self.best_reward, + best_reward_std=self.best_reward_std, + train_collector=self.train_collector, + test_collector=self.test_collector, + ) + + self.logger.log_info_data(asdict(info_stat), self.epoch) + + # in case trainer is used with run(), epoch_stat will not be returned + return EpochStats( + epoch=self.epoch, + train_collect_stat=train_stat, + test_collect_stat=test_stat, + training_stat=update_stat, + info_stat=info_stat, + ) + + def test_step(self) -> tuple[CollectStats, bool]: + """Perform one testing step.""" + assert self.episode_per_test is not None + assert self.test_collector is not None + stop_fn_flag = False + test_stat = test_episode( + self.test_collector, + self.test_fn, + self.epoch, + self.episode_per_test, + self.logger, + self.env_step, + self.reward_metric, + ) + assert test_stat.returns_stat is not None # for mypy + rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std + if self.best_epoch < 0 or self.best_reward < rew: + self.best_epoch = self.epoch + self.best_reward = float(rew) + self.best_reward_std = rew_std + if self.save_best_fn: + self.save_best_fn(self.policy) + log_msg = ( + f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}," + f" best_reward: {self.best_reward:.6f} ± " + f"{self.best_reward_std:.6f} in #{self.best_epoch}" + ) + log.info(log_msg) + if self.verbose: + print(log_msg, flush=True) + + if self.stop_fn and self.stop_fn(self.best_reward): + stop_fn_flag = True + + return test_stat, stop_fn_flag + + def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: + """Perform one training iteration. + + A training iteration includes collecting data (for online RL), determining whether to stop training, + and performing a policy update if the training iteration should continue. + + :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training. + If training is to be stopped, no gradient steps will be performed and the training stats will be `None`. + """ + with policy_within_training_step(self.policy): + should_stop_training = False + + collect_stats: CollectStatsBase | CollectStats + if self.train_collector is not None: + collect_stats = self._collect_training_data() + should_stop_training = self._update_best_reward_and_return_should_stop_training( + collect_stats, + ) + else: + assert self.buffer is not None, "Either train_collector or buffer must be provided." 
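+                # Offline setting: nothing is collected here, so the stats simply
+                # report the size of the provided buffer.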
+ collect_stats = CollectStatsBase( + n_collected_episodes=len(self.buffer), + ) + + if not should_stop_training: + training_stats = self.policy_update_fn(collect_stats) + else: + training_stats = None + + return collect_stats, training_stats, should_stop_training + + def _collect_training_data(self) -> CollectStats: + """Performs training data collection. + + :return: the data collection stats + """ + assert self.episode_per_test is not None + assert self.train_collector is not None + if self.train_fn: + self.train_fn(self.epoch, self.env_step) + collect_stats = self.train_collector.collect( + n_step=self.step_per_collect, + n_episode=self.episode_per_collect, + ) + + self.env_step += collect_stats.n_collected_steps + + if collect_stats.n_collected_episodes > 0: + assert collect_stats.returns_stat is not None # for mypy + assert collect_stats.lens_stat is not None # for mypy + self.last_rew = collect_stats.returns_stat.mean + self.last_len = collect_stats.lens_stat.mean + if self.reward_metric: # TODO: move inside collector + rew = self.reward_metric(collect_stats.returns) + collect_stats.returns = rew + collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew) + + self.logger.log_train_data(asdict(collect_stats), self.env_step) + + return collect_stats + + # TODO (maybe): separate out side effect, simplify name? + def _update_best_reward_and_return_should_stop_training( + self, + collect_stats: CollectStats, + ) -> bool: + """If `test_in_train` and `stop_fn` are set, will compute the `stop_fn` on the mean return of the training data. + Then, if the `stop_fn` is True there, will collect test data also compute the stop_fn of the mean return + on it. + Finally, if the latter is also True, will return True. + + **NOTE:** has a side effect of updating the best reward and corresponding std. 
+ + + :param collect_stats: the data collection stats + :return: flag indicating whether to stop training + """ + should_stop_training = False + + # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics + with policy_within_training_step(self.policy, enabled=False): + if ( + collect_stats.n_collected_episodes > 0 + and self.test_in_train + and self.stop_fn + and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore + ): + assert self.test_collector is not None + assert self.episode_per_test is not None and self.episode_per_test > 0 + test_result = test_episode( + self.test_collector, + self.test_fn, + self.epoch, + self.episode_per_test, + self.logger, + self.env_step, + ) + assert test_result.returns_stat is not None # for mypy + if self.stop_fn(test_result.returns_stat.mean): + should_stop_training = True + self.best_reward = test_result.returns_stat.mean + self.best_reward_std = test_result.returns_stat.std + + return should_stop_training + + # TODO: move moving average computation and logging into its own logger + # TODO: maybe think about a command line logger instead of always printing data dict + def _update_moving_avg_stats_and_log_update_data(self, update_stat: TrainingStats) -> None: + """Log losses, update moving average stats, and also modify the smoothed_loss in update_stat.""" + cur_losses_dict = update_stat.get_loss_stats_dict() + update_stat.smoothed_loss = self._update_moving_avg_stats_and_get_averaged_data( + cur_losses_dict, + ) + self.logger.log_update_data(asdict(update_stat), self._gradient_step) + + # TODO: seems convoluted, there should be a better way of dealing with the moving average stats + def _update_moving_avg_stats_and_get_averaged_data( + self, + data: dict[str, float], + ) -> dict[str, float]: + """Add entries to the moving average object in the trainer and retrieve the averaged results. + + :param data: any entries to be tracked in the moving average object. + :return: A dictionary containing the averaged values of the tracked entries. + + """ + smoothed_data = {} + for key, loss_item in data.items(): + self.stat[key].add(loss_item) + smoothed_data[key] = self.stat[key].get() + return smoothed_data + + @abstractmethod + def policy_update_fn( + self, + collect_stats: CollectStatsBase, + ) -> TrainingStats: + """Policy update function for different trainer implementation. + + :param collect_stats: provides info about the most recent collection. In the offline case, this will contain + stats of the whole dataset + """ + + def run(self, reset_prior_to_run: bool = True) -> InfoStats: + """Consume iterator. + + See itertools - recipes. Use functions that consume iterators at C speed + (feed the entire iterator into a zero-length deque). + """ + if reset_prior_to_run: + self.reset() + try: + self.is_run = True + deque(self, maxlen=0) # feed the entire iterator into a zero-length deque + info = gather_info( + start_time=self.start_time, + policy_update_time=self.policy_update_time, + gradient_step=self._gradient_step, + best_reward=self.best_reward, + best_reward_std=self.best_reward_std, + train_collector=self.train_collector, + test_collector=self.test_collector, + ) + finally: + self.is_run = False + + return info + + def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: + """Sample a mini-batch, perform one gradient step, and update the _gradient_step counter.""" + self._gradient_step += 1 + # Note: since sample_size=batch_size, this will perform + # exactly one gradient step. 
This is why we don't need to calculate the + # number of gradient steps, like in the on-policy case. + update_stat = self.policy.update(sample_size=self.batch_size, buffer=buffer) + self._update_moving_avg_stats_and_log_update_data(update_stat) + return update_stat + + +class OfflineTrainer(BaseTrainer): + """Offline trainer, samples mini-batches from buffer and passes them to update. + + Uses a buffer directly and usually does not have a collector. + """ + + # for mypy + assert isinstance(BaseTrainer.__doc__, str) + __doc__ += BaseTrainer.gen_doc("offline") + "\n".join(BaseTrainer.__doc__.split("\n")[1:]) + + def policy_update_fn( + self, + collect_stats: CollectStatsBase | None = None, + ) -> TrainingStats: + """Perform one off-line policy update.""" + assert self.buffer + update_stat = self._sample_and_update(self.buffer) + # logging + self.policy_update_time += update_stat.train_time + return update_stat + + +class OffpolicyTrainer(BaseTrainer): + """Offpolicy trainer, samples mini-batches from buffer and passes them to update. + + Note that with this trainer, it is expected that the policy's `learn` method + does not perform additional mini-batching but just updates params from the received + mini-batch. + """ + + # for mypy + assert isinstance(BaseTrainer.__doc__, str) + __doc__ += BaseTrainer.gen_doc("offpolicy") + "\n".join(BaseTrainer.__doc__.split("\n")[1:]) + + def policy_update_fn( + self, + # TODO: this is the only implementation where collect_stats is actually needed. Maybe change interface? + collect_stats: CollectStatsBase, + ) -> TrainingStats: + """Perform `update_per_step * n_collected_steps` gradient steps by sampling mini-batches from the buffer. + + :param collect_stats: the :class:`~TrainingStats` instance returned by the last gradient step. Some values + in it will be replaced by their moving averages. + """ + assert self.train_collector is not None + n_collected_steps = collect_stats.n_collected_steps + n_gradient_steps = round(self.update_per_step * n_collected_steps) + if n_gradient_steps == 0: + raise ValueError( + f"n_gradient_steps is 0, n_collected_steps={n_collected_steps}, " + f"update_per_step={self.update_per_step}", + ) + for _ in range(n_gradient_steps): + update_stat = self._sample_and_update(self.train_collector.buffer) + + # logging + self.policy_update_time += update_stat.train_time + # TODO: only the last update_stat is returned, should be improved + return update_stat + + +class OnpolicyTrainer(BaseTrainer): + """On-policy trainer, passes the entire buffer to .update and resets it after. + + Note that it is expected that the learn method of a policy will perform + batching when using this trainer. + """ + + # for mypy + assert isinstance(BaseTrainer.__doc__, str) + __doc__ = BaseTrainer.gen_doc("onpolicy") + "\n".join(BaseTrainer.__doc__.split("\n")[1:]) + + def policy_update_fn( + self, + result: CollectStatsBase | None = None, + ) -> TrainingStats: + """Perform one on-policy update by passing the entire buffer to the policy's update method.""" + assert self.train_collector is not None + training_stat = self.policy.update( + sample_size=0, + buffer=self.train_collector.buffer, + # Note: sample_size is None, so the whole buffer is used for the update. 
+ # The kwargs are in the end passed to the .learn method, which uses + # batch_size to iterate through the buffer in mini-batches + # Off-policy algos typically don't use the batch_size kwarg at all + batch_size=self.batch_size, + repeat=self.repeat_per_collect, + ) + + # just for logging, no functional role + self.policy_update_time += training_stat.train_time + # TODO: remove the gradient step counting in trainers? Doesn't seem like + # it's important and it adds complexity + self._gradient_step += 1 + if self.batch_size is None: + self._gradient_step += 1 + elif self.batch_size > 0: + self._gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size) + + # Note: this is the main difference to the off-policy trainer! + # The second difference is that batches of data are sampled without replacement + # during training, whereas in off-policy or offline training, the batches are + # sampled with replacement (and potentially custom prioritization). + self.train_collector.reset_buffer(keep_statistics=True) + + # The step is the number of mini-batches used for the update, so essentially + self._update_moving_avg_stats_and_log_update_data(training_stat) + + return training_stat diff --git a/examples/atari/tianshou/trainer/utils.py b/examples/atari/tianshou/trainer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..de730cee2ec665ae5b0d4557c296b2dc5246df39 --- /dev/null +++ b/examples/atari/tianshou/trainer/utils.py @@ -0,0 +1,85 @@ +import time +from collections.abc import Callable +from dataclasses import asdict + +import numpy as np + +from tianshou.data import ( + CollectStats, + InfoStats, + SequenceSummaryStats, + TimingStats, +) +from tianshou.data.collector import BaseCollector +from tianshou.utils import BaseLogger + + +def test_episode( + collector: BaseCollector, + test_fn: Callable[[int, int | None], None] | None, + epoch: int, + n_episode: int, + logger: BaseLogger | None = None, + global_step: int | None = None, + reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, +) -> CollectStats: + """A simple wrapper of testing policy in collector.""" + collector.reset(reset_stats=False) + if test_fn: + test_fn(epoch, global_step) + result = collector.collect(n_episode=n_episode) + if reward_metric: # TODO: move into collector + rew = reward_metric(result.returns) + result.returns = rew + result.returns_stat = SequenceSummaryStats.from_sequence(rew) + if logger and global_step is not None: + assert result.n_collected_episodes > 0 + logger.log_test_data(asdict(result), global_step) + return result + + +def gather_info( + start_time: float, + policy_update_time: float, + gradient_step: int, + best_reward: float, + best_reward_std: float, + train_collector: BaseCollector | None = None, + test_collector: BaseCollector | None = None, +) -> InfoStats: + """A simple wrapper of gathering information from collectors. + + :return: InfoStats object with times computed based on the `start_time` and + episode/step counts read off the collectors. No computation of + expensive statistics is done here. 
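+
+    A typical call (all numeric values below are purely illustrative)::
+
+        info = gather_info(
+            start_time=time.time() - 60.0,
+            policy_update_time=12.3,
+            gradient_step=1000,
+            best_reward=200.0,
+            best_reward_std=5.0,
+        )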
+ """ + duration = max(0.0, time.time() - start_time) + test_time = 0.0 + update_speed = 0.0 + train_time_collect = 0.0 + if test_collector is not None: + test_time = test_collector.collect_time + + if train_collector is not None: + train_time_collect = train_collector.collect_time + update_speed = train_collector.collect_step / (duration - test_time) + + timing_stat = TimingStats( + total_time=duration, + train_time=duration - test_time, + train_time_collect=train_time_collect, + train_time_update=policy_update_time, + test_time=test_time, + update_speed=update_speed, + ) + + return InfoStats( + gradient_step=gradient_step, + best_reward=best_reward, + best_reward_std=best_reward_std, + train_step=train_collector.collect_step if train_collector is not None else 0, + train_episode=train_collector.collect_episode if train_collector is not None else 0, + test_step=test_collector.collect_step if test_collector is not None else 0, + test_episode=test_collector.collect_episode if test_collector is not None else 0, + timing=timing_stat, + ) diff --git a/examples/atari/tianshou/utils/__init__.py b/examples/atari/tianshou/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..47a3c4497a182c2930df066022a3e8b14bd0850d --- /dev/null +++ b/examples/atari/tianshou/utils/__init__.py @@ -0,0 +1,22 @@ +"""Utils package.""" + +from tianshou.utils.logger.base import BaseLogger, LazyLogger +from tianshou.utils.logger.tensorboard import TensorboardLogger +from tianshou.utils.logger.wandb import WandbLogger +from tianshou.utils.lr_scheduler import MultipleLRSchedulers +from tianshou.utils.progress_bar import DummyTqdm, tqdm_config +from tianshou.utils.statistics import MovAvg, RunningMeanStd +from tianshou.utils.warning import deprecation + +__all__ = [ + "MovAvg", + "RunningMeanStd", + "tqdm_config", + "deprecation", + "DummyTqdm", + "BaseLogger", + "TensorboardLogger", + "LazyLogger", + "WandbLogger", + "MultipleLRSchedulers", +] diff --git a/examples/atari/tianshou/utils/conversion.py b/examples/atari/tianshou/utils/conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..bae2db3318722521d8ae6027e77710af485e066f --- /dev/null +++ b/examples/atari/tianshou/utils/conversion.py @@ -0,0 +1,25 @@ +from typing import overload + +import torch + + +@overload +def to_optional_float(x: torch.Tensor) -> float: + ... + + +@overload +def to_optional_float(x: float) -> float: + ... + + +@overload +def to_optional_float(x: None) -> None: + ... 
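+# Behaviour at a glance: to_optional_float(torch.tensor(3.0)) == 3.0,
+# to_optional_float(2.5) == 2.5, and to_optional_float(None) is None.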
+ + +def to_optional_float(x: torch.Tensor | float | None) -> float | None: + """For the common case where one needs to extract a float from a scalar Tensor, which may be None.""" + if isinstance(x, torch.Tensor): + return x.item() + return x diff --git a/examples/atari/tianshou/utils/logger/__init__.py b/examples/atari/tianshou/utils/logger/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/utils/logger/base.py b/examples/atari/tianshou/utils/logger/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c1bd737951dab3797fc54d27fe875498de0cafca --- /dev/null +++ b/examples/atari/tianshou/utils/logger/base.py @@ -0,0 +1,185 @@ +import typing +from abc import ABC, abstractmethod +from collections.abc import Callable +from enum import Enum +from numbers import Number + +import numpy as np + +VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray | float +# It's unfortunate, but we can't use Union type in isinstance, hence we resort to this +VALID_LOG_VALS = typing.get_args(VALID_LOG_VALS_TYPE) + +TRestoredData = dict[str, np.ndarray | dict[str, "TRestoredData"]] + + +class DataScope(Enum): + TRAIN = "train" + TEST = "test" + UPDATE = "update" + INFO = "info" + + +class BaseLogger(ABC): + """The base class for any logger which is compatible with trainer.""" + + def __init__( + self, + train_interval: int = 1000, + test_interval: int = 1, + update_interval: int = 1000, + info_interval: int = 1, + exclude_arrays: bool = True, + ) -> None: + """:param train_interval: the log interval in log_train_data(). Default to 1000. + :param test_interval: the log interval in log_test_data(). Default to 1. + :param update_interval: the log interval in log_update_data(). Default to 1000. + :param info_interval: the log interval in log_info_data(). Default to 1. + :param exclude_arrays: whether to exclude numpy arrays from the logger's output + """ + super().__init__() + self.train_interval = train_interval + self.test_interval = test_interval + self.update_interval = update_interval + self.info_interval = info_interval + self.exclude_arrays = exclude_arrays + self.last_log_train_step = -1 + self.last_log_test_step = -1 + self.last_log_update_step = -1 + self.last_log_info_step = -1 + + @abstractmethod + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: + """Specify how the writer is used to log data. + + :param str step_type: namespace which the data dict belongs to. + :param step: stands for the ordinate of the data dict. + :param data: the data to write with format ``{key: value}``. + """ + + @abstractmethod + def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]: + """Prepare the dict for logging by filtering out invalid data types. + + If necessary, reformulate the dict to be compatible with the writer. + + :param log_data: the dict to be prepared for logging. + :return: the prepared dict. + """ + + def log_train_data(self, log_data: dict, step: int) -> None: + """Use writer to log statistics generated during training. + + :param log_data: a dict containing the information returned by the collector during the train step. + :param step: stands for the timestep the collector result is logged. 
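+
+        Data is only written if at least ``train_interval`` steps have passed since
+        the last write; otherwise the call is a no-op.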
+ """ + # TODO: move interval check to calling method + if step - self.last_log_train_step >= self.train_interval: + log_data = self.prepare_dict_for_logging(log_data) + self.write(f"{DataScope.TRAIN.value}/env_step", step, log_data) + self.last_log_train_step = step + + def log_test_data(self, log_data: dict, step: int) -> None: + """Use writer to log statistics generated during evaluating. + + :param log_data:a dict containing the information returned by the collector during the evaluation step. + :param step: stands for the timestep the collector result is logged. + """ + # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer) + if step - self.last_log_test_step >= self.test_interval: + log_data = self.prepare_dict_for_logging(log_data) + self.write(f"{DataScope.TEST.value}/env_step", step, log_data) + self.last_log_test_step = step + + def log_update_data(self, log_data: dict, step: int) -> None: + """Use writer to log statistics generated during updating. + + :param log_data:a dict containing the information returned during the policy update step. + :param step: stands for the timestep the policy training data is logged. + """ + # TODO: move interval check to calling method + if step - self.last_log_update_step >= self.update_interval: + log_data = self.prepare_dict_for_logging(log_data) + self.write(f"{DataScope.UPDATE.value}/gradient_step", step, log_data) + self.last_log_update_step = step + + def log_info_data(self, log_data: dict, step: int) -> None: + """Use writer to log global statistics. + + :param log_data: a dict containing information of data collected at the end of an epoch. + :param step: stands for the timestep the training info is logged. + """ + if ( + step - self.last_log_info_step >= self.info_interval + ): # TODO: move interval check to calling method + log_data = self.prepare_dict_for_logging(log_data) + self.write(f"{DataScope.INFO.value}/epoch", step, log_data) + self.last_log_info_step = step + + @abstractmethod + def save_data( + self, + epoch: int, + env_step: int, + gradient_step: int, + save_checkpoint_fn: Callable[[int, int, int], str] | None = None, + ) -> None: + """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. + + :param epoch: the epoch in trainer. + :param env_step: the env_step in trainer. + :param gradient_step: the gradient_step in trainer. + :param function save_checkpoint_fn: a hook defined by user, see trainer + documentation for detail. + """ + + @abstractmethod + def restore_data(self) -> tuple[int, int, int]: + """Restore internal data if present and return the metadata from existing log for continuation of training. + + If it finds nothing or an error occurs during the recover process, it will + return the default parameters. + + :return: epoch, env_step, gradient_step. + """ + + @abstractmethod + def restore_logged_data( + self, + log_path: str, + ) -> TRestoredData: + """Load the logged data from disk for post-processing. + + :return: a dict containing the logged data. + """ + + +class LazyLogger(BaseLogger): + """A logger that does nothing. 
Used as the placeholder in trainer.""" + + def __init__(self) -> None: + super().__init__() + + def prepare_dict_for_logging( + self, + data: dict[str, VALID_LOG_VALS_TYPE], + ) -> dict[str, VALID_LOG_VALS_TYPE]: + return data + + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: + """The LazyLogger writes nothing.""" + + def save_data( + self, + epoch: int, + env_step: int, + gradient_step: int, + save_checkpoint_fn: Callable[[int, int, int], str] | None = None, + ) -> None: + pass + + def restore_data(self) -> tuple[int, int, int]: + return 0, 0, 0 + + def restore_logged_data(self, log_path: str) -> dict: + return {} diff --git a/examples/atari/tianshou/utils/logger/tensorboard.py b/examples/atari/tianshou/utils/logger/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..d824d862d1eae1ead77facea63fc394fc696fee7 --- /dev/null +++ b/examples/atari/tianshou/utils/logger/tensorboard.py @@ -0,0 +1,191 @@ +from collections.abc import Callable +from typing import Any + +import numpy as np +from matplotlib.figure import Figure +from tensorboard.backend.event_processing import event_accumulator +from torch.utils.tensorboard import SummaryWriter + +from tianshou.utils.logger.base import ( + VALID_LOG_VALS, + VALID_LOG_VALS_TYPE, + BaseLogger, + TRestoredData, +) + + +class TensorboardLogger(BaseLogger): + """A logger that relies on tensorboard SummaryWriter by default to visualize and log statistics. + + :param SummaryWriter writer: the writer to log data. + :param train_interval: the log interval in log_train_data(). Default to 1000. + :param test_interval: the log interval in log_test_data(). Default to 1. + :param update_interval: the log interval in log_update_data(). Default to 1000. + :param info_interval: the log interval in log_info_data(). Default to 1. + :param save_interval: the save interval in save_data(). Default to 1 (save at + the end of each epoch). + :param write_flush: whether to flush tensorboard result after each + add_scalar operation. Default to True. + """ + + def __init__( + self, + writer: SummaryWriter, + train_interval: int = 1000, + test_interval: int = 1, + update_interval: int = 1000, + info_interval: int = 1, + save_interval: int = 1, + write_flush: bool = True, + ) -> None: + super().__init__(train_interval, test_interval, update_interval, info_interval) + self.save_interval = save_interval + self.write_flush = write_flush + self.last_save_step = -1 + self.writer = writer + + def prepare_dict_for_logging( + self, + input_dict: dict[str, Any], + parent_key: str = "", + delimiter: str = "/", + exclude_arrays: bool = True, + ) -> dict[str, VALID_LOG_VALS_TYPE]: + """Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys. + + Filtering is performed with respect to valid logging data types. + + :param input_dict: The nested dictionary to be flattened and filtered. + :param parent_key: The parent key used as a prefix before the input_dict keys. + :param delimiter: The delimiter used to separate the keys. + :param exclude_arrays: Whether to exclude numpy arrays from the output. + :return: A flattened dictionary where the keys are compressed and values are filtered. 
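+
+        A small illustrative example (``logger`` standing in for a ``TensorboardLogger``)::
+
+            logger.prepare_dict_for_logging({"returns": {"mean": 1.5, "raw": np.zeros(3)}})
+            # -> {"returns/mean": 1.5}; the array is dropped, keys are joined with "/"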
+ """ + result = {} + + def add_to_result( + cur_dict: dict, + prefix: str = "", + ) -> None: + for key, value in cur_dict.items(): + if exclude_arrays and isinstance(value, np.ndarray): + continue + + new_key = prefix + delimiter + key + new_key = new_key.lstrip(delimiter) + + if isinstance(value, dict): + add_to_result( + value, + new_key, + ) + elif isinstance(value, VALID_LOG_VALS): + result[new_key] = value + + add_to_result(input_dict, prefix=parent_key) + return result + + def write(self, step_type: str, step: int, data: dict[str, Any]) -> None: + scope, step_name = step_type.split("/") + self.writer.add_scalar(step_type, step, global_step=step) + for k, v in data.items(): + scope_key = f"{scope}/{k}" + if isinstance(v, np.ndarray): + self.writer.add_histogram(scope_key, v, global_step=step, bins="auto") + elif isinstance(v, Figure): + self.writer.add_figure(scope_key, v, global_step=step) + else: + self.writer.add_scalar(scope_key, v, global_step=step) + if self.write_flush: # issue 580 + self.writer.flush() # issue #482 + + def save_data( + self, + epoch: int, + env_step: int, + gradient_step: int, + save_checkpoint_fn: Callable[[int, int, int], str] | None = None, + ) -> None: + if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: + self.last_save_step = epoch + save_checkpoint_fn(epoch, env_step, gradient_step) + self.write("save/epoch", epoch, {"save/epoch": epoch}) + self.write("save/env_step", env_step, {"save/env_step": env_step}) + self.write( + "save/gradient_step", + gradient_step, + {"save/gradient_step": gradient_step}, + ) + + def restore_data(self) -> tuple[int, int, int]: + ea = event_accumulator.EventAccumulator(self.writer.log_dir) + ea.Reload() + + try: # epoch / gradient_step + epoch = ea.scalars.Items("save/epoch")[-1].step + self.last_save_step = self.last_log_test_step = epoch + gradient_step = ea.scalars.Items("save/gradient_step")[-1].step + self.last_log_update_step = gradient_step + except KeyError: + epoch, gradient_step = 0, 0 + try: # offline trainer doesn't have env_step + env_step = ea.scalars.Items("save/env_step")[-1].step + self.last_log_train_step = env_step + except KeyError: + env_step = 0 + + return epoch, env_step, gradient_step + + def restore_logged_data( + self, + log_path: str, + ) -> TRestoredData: + """Restores the logged data from the tensorboard log directory. + + The result is a nested dictionary where the keys are the tensorboard keys + and the values are the corresponding numpy arrays. The keys in each level + form a nested structure, where the hierarchy is represented by the slashes + in the tensorboard key-strings. + """ + ea = event_accumulator.EventAccumulator(log_path) + ea.Reload() + + def add_value_to_innermost_nested_dict( + data_dict: dict[str, Any], + key_string: str, + value: Any, + ) -> None: + """A particular logic, walking through the keys in the + `key_string` and adding the value to the `data_dict` in a nested manner, + creating nested dictionaries on the fly if necessary, or updating existing ones. + The value is added only to the innermost-nested dictionary. 
+ + + Example: + ------- + >>> data_dict = {} + >>> add_value_to_innermost_nested_dict(data_dict, "a/b/c", 1) + >>> data_dict + {"a": {"b": {"c": 1}}} + """ + keys = key_string.split("/") + + cur_nested_dict = data_dict + # walk through the intermediate keys to reach the innermost-nested dict, + # creating nested dictionaries on the fly if necessary + for k in keys[:-1]: + cur_nested_dict = cur_nested_dict.setdefault(k, {}) + # After the loop above, + # this is the innermost-nested dict, where the value is finally set + # for the last key in the key_string + cur_nested_dict[keys[-1]] = value + + restored_data: dict[str, np.ndarray | dict] = {} + for key_string in ea.scalars.Keys(): + add_value_to_innermost_nested_dict( + restored_data, + key_string, + np.array([s.value for s in ea.scalars.Items(key_string)]), + ) + + return restored_data diff --git a/examples/atari/tianshou/utils/logger/wandb.py b/examples/atari/tianshou/utils/logger/wandb.py new file mode 100644 index 0000000000000000000000000000000000000000..74d844fa9f7d1223c03fb5b7607d55dd6cc08ba2 --- /dev/null +++ b/examples/atari/tianshou/utils/logger/wandb.py @@ -0,0 +1,177 @@ +import argparse +import contextlib +import os +from collections.abc import Callable + +from torch.utils.tensorboard import SummaryWriter + +from tianshou.utils import BaseLogger, TensorboardLogger +from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, TRestoredData + +with contextlib.suppress(ImportError): + import wandb + + +class WandbLogger(BaseLogger): + """Weights and Biases logger that sends data to https://wandb.ai/. + + This logger creates three panels with plots: train, test, and update. + Make sure to select the correct access for each panel in weights and biases: + + Example of usage: + :: + + logger = WandbLogger() + logger.load(SummaryWriter(log_path)) + result = OnpolicyTrainer(policy, train_collector, test_collector, + logger=logger).run() + + :param train_interval: the log interval in log_train_data(). Default to 1000. + :param test_interval: the log interval in log_test_data(). Default to 1. + :param update_interval: the log interval in log_update_data(). + Default to 1000. + :param info_interval: the log interval in log_info_data(). Default to 1. + :param save_interval: the save interval in save_data(). Default to 1 (save at + the end of each epoch). + :param write_flush: whether to flush tensorboard result after each + add_scalar operation. Default to True. + :param str project: W&B project name. Default to "tianshou". + :param str name: W&B run name. Default to None. If None, random name is assigned. + :param str entity: W&B team/organization name. Default to None. + :param str run_id: run id of W&B run to be resumed. Default to None. + :param argparse.Namespace config: experiment configurations. Default to None. 
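+
+    Resuming a previous run is also possible; a minimal sketch, assuming ``"abc123"``
+    is the id of an existing W&B run that has a saved checkpoint artifact:
+    ::
+
+        logger = WandbLogger(run_id="abc123", project="tianshou")
+        logger.load(SummaryWriter(log_path))
+        epoch, env_step, gradient_step = logger.restore_data()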
+ """ + + def __init__( + self, + train_interval: int = 1000, + test_interval: int = 1, + update_interval: int = 1000, + info_interval: int = 1, + save_interval: int = 1000, + write_flush: bool = True, + project: str | None = None, + name: str | None = None, + entity: str | None = None, + run_id: str | None = None, + config: argparse.Namespace | dict | None = None, + monitor_gym: bool = True, + ) -> None: + super().__init__(train_interval, test_interval, update_interval, info_interval) + self.last_save_step = -1 + self.save_interval = save_interval + self.write_flush = write_flush + self.restored = False + if project is None: + project = os.getenv("WANDB_PROJECT", "tianshou") + + self.wandb_run = ( + wandb.init( + project=project, + name=name, + id=run_id, + resume="allow", + entity=entity, + sync_tensorboard=True, + monitor_gym=monitor_gym, + config=config, # type: ignore + ) + if not wandb.run + else wandb.run + ) + # TODO: don't access private attribute! + self.wandb_run._label(repo="tianshou") # type: ignore + self.tensorboard_logger: TensorboardLogger | None = None + self.writer: SummaryWriter | None = None + + def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]: + if self.tensorboard_logger is None: + raise Exception( + "`logger` needs to load the Tensorboard Writer before " + "preparing data for logging. Try `logger.load(SummaryWriter(log_path))`", + ) + return self.tensorboard_logger.prepare_dict_for_logging(log_data) + + def load(self, writer: SummaryWriter) -> None: + self.writer = writer + self.tensorboard_logger = TensorboardLogger( + writer, + self.train_interval, + self.test_interval, + self.update_interval, + self.save_interval, + self.write_flush, + ) + + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: + if self.tensorboard_logger is None: + raise RuntimeError( + "`logger` needs to load the Tensorboard Writer before " + "writing data. Try `logger.load(SummaryWriter(log_path))`", + ) + self.tensorboard_logger.write(step_type, step, data) + + def save_data( + self, + epoch: int, + env_step: int, + gradient_step: int, + save_checkpoint_fn: Callable[[int, int, int], str] | None = None, + ) -> None: + """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. + + :param epoch: the epoch in trainer. + :param env_step: the env_step in trainer. + :param gradient_step: the gradient_step in trainer. + :param function save_checkpoint_fn: a hook defined by user, see trainer + documentation for detail. 
+ """ + if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: + self.last_save_step = epoch + checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step) + + checkpoint_artifact = wandb.Artifact( + "run_" + self.wandb_run.id + "_checkpoint", # type: ignore + type="model", + metadata={ + "save/epoch": epoch, + "save/env_step": env_step, + "save/gradient_step": gradient_step, + "checkpoint_path": str(checkpoint_path), + }, + ) + checkpoint_artifact.add_file(str(checkpoint_path)) + self.wandb_run.log_artifact(checkpoint_artifact) # type: ignore + + def restore_data(self) -> tuple[int, int, int]: + checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore + f"run_{self.wandb_run.id}_checkpoint:latest", # type: ignore + ) + assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist" + + checkpoint_artifact.download( + os.path.dirname(checkpoint_artifact.metadata["checkpoint_path"]), + ) + + try: # epoch / gradient_step + epoch = checkpoint_artifact.metadata["save/epoch"] + self.last_save_step = self.last_log_test_step = epoch + gradient_step = checkpoint_artifact.metadata["save/gradient_step"] + self.last_log_update_step = gradient_step + except KeyError: + epoch, gradient_step = 0, 0 + try: # offline trainer doesn't have env_step + env_step = checkpoint_artifact.metadata["save/env_step"] + self.last_log_train_step = env_step + except KeyError: + env_step = 0 + return epoch, env_step, gradient_step + + def restore_logged_data(self, log_path: str) -> TRestoredData: + if self.tensorboard_logger is None: + raise NotImplementedError( + "Restoring logged data directly from W&B is not yet implemented." + "Try instantiating the internal TensorboardLogger by calling something" + "like `logger.load(SummaryWriter(log_path))`", + ) + return self.tensorboard_logger.restore_logged_data(log_path) diff --git a/examples/atari/tianshou/utils/logging.py b/examples/atari/tianshou/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..b2eaf3ffc5ee6bc3dc35cf8fdc7a9c2461108a00 --- /dev/null +++ b/examples/atari/tianshou/utils/logging.py @@ -0,0 +1,183 @@ +""" +Partial copy of sensai.util.logging +""" +# ruff: noqa +import atexit +import logging as lg +import sys +from collections.abc import Callable +from datetime import datetime +from io import StringIO +from logging import * +from typing import Any, TypeVar, cast + +log = getLogger(__name__) # type: ignore + +LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(message)s" + +# Holds the log format that is configured by the user (using function `configure`), such +# that it can be reused in other places +_logFormat = LOG_DEFAULT_FORMAT + + +def set_numerical_fields_to_precision(data: dict[str, Any], precision: int = 3) -> dict[str, Any]: + """Returns a copy of the given dictionary with all numerical values rounded to the given precision. + + Note: does not recurse into nested dictionaries. 
+ + :param data: a dictionary + :param precision: the precision to be used + """ + result = {} + for k, v in data.items(): + if isinstance(v, float): + v = round(v, precision) + result[k] = v + return result + + +def remove_log_handlers() -> None: + """Removes all current log handlers.""" + logger = getLogger() + while logger.hasHandlers(): + logger.removeHandler(logger.handlers[0]) + + +def remove_log_handler(handler: Handler) -> None: + getLogger().removeHandler(handler) + + +def is_log_handler_active(handler: Handler) -> bool: + """Checks whether the given handler is active. + + :param handler: a log handler + :return: True if the handler is active, False otherwise + """ + return handler in getLogger().handlers + + +# noinspection PyShadowingBuiltins +def configure(format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG) -> None: + """Configures logging to stdout with the given format and log level, + also configuring the default log levels of some overly verbose libraries as well as some pandas output options. + + :param format: the log format + :param level: the minimum log level + """ + global _logFormat + _logFormat = format + remove_log_handlers() + basicConfig(level=level, format=format, stream=sys.stdout) + # set log levels of third-party libraries + getLogger("numba").setLevel(INFO) + + +T = TypeVar("T") + + +# noinspection PyShadowingBuiltins +def run_main( + main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG +) -> T | None: + """Configures logging with the given parameters, ensuring that any exceptions that occur during + the execution of the given function are logged. + Logs two additional messages, one before the execution of the function, and one upon its completion. + + :param main_fn: the function to be executed + :param format: the log message format + :param level: the minimum log level + :return: the result of `main_fn` + """ + configure(format=format, level=level) + log.info("Starting") # type: ignore + try: + result = main_fn() + log.info("Done") # type: ignore + return result + except Exception as e: + log.error("Exception during script execution", exc_info=e) # type: ignore + return None + + +def run_cli( + main_fn: Callable[..., T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG +) -> T | None: + """ + Configures logging with the given parameters and runs the given main function as a + CLI using `jsonargparse` (which is configured to also parse attribute docstrings, such + that dataclasses can be used as function arguments). + Using this function requires that `jsonargparse` and `docstring_parser` be available. + Like `run_main`, two additional log messages will be logged (at the beginning and end + of the execution), and it is ensured that all exceptions will be logged. 
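+
+    A minimal usage sketch (``main`` is a hypothetical user-defined function whose
+    parameters become CLI arguments):
+    ::
+
+        def main(learning_rate: float = 1e-3, epochs: int = 10) -> None:
+            ...
+
+        if __name__ == "__main__":
+            run_cli(main)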
+ + :param main_fn: the function to be executed + :param format: the log message format + :param level: the minimum log level + :return: the result of `main_fn` + """ + from jsonargparse import set_docstring_parse_options, CLI + + set_docstring_parse_options(attribute_docstrings=True) + return run_main(lambda: CLI(main_fn), format=format, level=level) + + +def datetime_tag() -> str: + """:return: a string tag for use in log file names which contains the current date and time (compact but readable)""" + return datetime.now().strftime("%Y%m%d-%H%M%S") + + +_fileLoggerPaths: list[str] = [] +_isAtExitReportFileLoggerRegistered = False +_memoryLogStream: StringIO | None = None + + +def _at_exit_report_file_logger() -> None: + for path in _fileLoggerPaths: + print(f"A log file was saved to {path}") + + +def add_file_logger(path: str, register_atexit: bool = True) -> FileHandler: + global _isAtExitReportFileLoggerRegistered + log.info(f"Logging to {path} ...") # type: ignore + handler = FileHandler(path) + handler.setFormatter(Formatter(_logFormat)) + Logger.root.addHandler(handler) + _fileLoggerPaths.append(path) + if not _isAtExitReportFileLoggerRegistered and register_atexit: + atexit.register(_at_exit_report_file_logger) + _isAtExitReportFileLoggerRegistered = True + return handler + + +def add_memory_logger() -> None: + """Enables in-memory logging (if it is not already enabled), i.e. all log statements are written to a memory buffer and can later be + read via function `get_memory_log()`. + """ + global _memoryLogStream + if _memoryLogStream is not None: + return + _memoryLogStream = StringIO() + handler = StreamHandler(_memoryLogStream) + handler.setFormatter(Formatter(_logFormat)) + Logger.root.addHandler(handler) + + +def get_memory_log() -> Any: + """:return: the in-memory log (provided that `add_memory_logger` was called beforehand)""" + assert _memoryLogStream is not None, "This should not have happened and might be a bug." + return _memoryLogStream.getvalue() + + +class FileLoggerContext: + def __init__(self, path: str, enabled: bool = True): + self.enabled = enabled + self.path = path + self._log_handler: Handler | None = None + + def __enter__(self) -> None: + if self.enabled: + self._log_handler = add_file_logger(self.path, register_atexit=False) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + if self._log_handler is not None: + remove_log_handler(self._log_handler) diff --git a/examples/atari/tianshou/utils/lr_scheduler.py b/examples/atari/tianshou/utils/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..59890c75c270bf043c2dfe20a1d11e89fe5b8f1d --- /dev/null +++ b/examples/atari/tianshou/utils/lr_scheduler.py @@ -0,0 +1,40 @@ +import torch + + +class MultipleLRSchedulers: + """A wrapper for multiple learning rate schedulers. + + Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step` is called, + it calls the step() method of each of the schedulers that it contains. 
+ Example usage: + :: + + scheduler1 = ConstantLR(opt1, factor=0.1, total_iters=2) + scheduler2 = ExponentialLR(opt2, gamma=0.9) + scheduler = MultipleLRSchedulers(scheduler1, scheduler2) + policy = PPOPolicy(..., lr_scheduler=scheduler) + """ + + def __init__(self, *args: torch.optim.lr_scheduler.LRScheduler): + self.schedulers = args + + def step(self) -> None: + """Take a step in each of the learning rate schedulers.""" + for scheduler in self.schedulers: + scheduler.step() + + def state_dict(self) -> list[dict]: + """Get state_dict for each of the learning rate schedulers. + + :return: A list of state_dict of learning rate schedulers. + """ + return [s.state_dict() for s in self.schedulers] + + def load_state_dict(self, state_dict: list[dict]) -> None: + """Load states from state_dict. + + :param state_dict: A list of learning rate scheduler + state_dict, in the same order as the schedulers. + """ + for s, sd in zip(self.schedulers, state_dict, strict=True): + s.__dict__.update(sd) diff --git a/examples/atari/tianshou/utils/net/__init__.py b/examples/atari/tianshou/utils/net/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/atari/tianshou/utils/net/common.py b/examples/atari/tianshou/utils/net/common.py new file mode 100644 index 0000000000000000000000000000000000000000..eceee100f60abde632695dd0e1c11d2423d10f86 --- /dev/null +++ b/examples/atari/tianshou/utils/net/common.py @@ -0,0 +1,668 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from typing import Any, Generic, TypeAlias, TypeVar, cast, no_type_check + +import numpy as np +import torch +from torch import nn + +from tianshou.data.batch import Batch +from tianshou.data.types import RecurrentStateBatch + +ModuleType = type[nn.Module] +ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]] +TActionShape: TypeAlias = Sequence[int] | int | np.int64 +TLinearLayer: TypeAlias = Callable[[int, int], nn.Module] +T = TypeVar("T") + + +def miniblock( + input_size: int, + output_size: int = 0, + norm_layer: ModuleType | None = None, + norm_args: tuple[Any, ...] | dict[Any, Any] | None = None, + activation: ModuleType | None = None, + act_args: tuple[Any, ...] | dict[Any, Any] | None = None, + linear_layer: TLinearLayer = nn.Linear, +) -> list[nn.Module]: + """Construct a miniblock with given input/output-size, norm layer and activation.""" + layers: list[nn.Module] = [linear_layer(input_size, output_size)] + if norm_layer is not None: + if isinstance(norm_args, tuple): + layers += [norm_layer(output_size, *norm_args)] + elif isinstance(norm_args, dict): + layers += [norm_layer(output_size, **norm_args)] + else: + layers += [norm_layer(output_size)] + if activation is not None: + if isinstance(act_args, tuple): + layers += [activation(*act_args)] + elif isinstance(act_args, dict): + layers += [activation(**act_args)] + else: + layers += [activation()] + return layers + + +class MLP(nn.Module): + """Simple MLP backbone. + + Create a MLP of size input_dim * hidden_sizes[0] * hidden_sizes[1] * ... + * hidden_sizes[-1] * output_dim + + :param input_dim: dimension of the input vector. + :param output_dim: dimension of the output vector. If set to 0, there + is no final linear layer. + :param hidden_sizes: shape of MLP passed in as a list, not including + input_dim and output_dim. 
+ :param norm_layer: use which normalization before activation, e.g., + ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. + You can also pass a list of normalization modules with the same length + of hidden_sizes, to use different normalization module in different + layers. Default to no normalization. + :param activation: which activation to use after each layer, can be both + the same activation for all layers if passed in nn.Module, or different + activation for different Modules if passed in a list. Default to + nn.ReLU. + :param device: which device to create this model on. Default to None. + :param linear_layer: use this module as linear layer. Default to nn.Linear. + :param flatten_input: whether to flatten input data. Default to True. + """ + + def __init__( + self, + input_dim: int, + output_dim: int = 0, + hidden_sizes: Sequence[int] = (), + norm_layer: ModuleType | Sequence[ModuleType] | None = None, + norm_args: ArgsType | None = None, + activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU, + act_args: ArgsType | None = None, + device: str | int | torch.device | None = None, + linear_layer: TLinearLayer = nn.Linear, + flatten_input: bool = True, + ) -> None: + super().__init__() + self.device = device + if norm_layer: + if isinstance(norm_layer, list): + assert len(norm_layer) == len(hidden_sizes) + norm_layer_list = norm_layer + if isinstance(norm_args, list): + assert len(norm_args) == len(hidden_sizes) + norm_args_list = norm_args + else: + norm_args_list = [norm_args for _ in range(len(hidden_sizes))] + else: + norm_layer_list = [norm_layer for _ in range(len(hidden_sizes))] + norm_args_list = [norm_args for _ in range(len(hidden_sizes))] + else: + norm_layer_list = [None] * len(hidden_sizes) + norm_args_list = [None] * len(hidden_sizes) + if activation: + if isinstance(activation, list): + assert len(activation) == len(hidden_sizes) + activation_list = activation + if isinstance(act_args, list): + assert len(act_args) == len(hidden_sizes) + act_args_list = act_args + else: + act_args_list = [act_args for _ in range(len(hidden_sizes))] + else: + activation_list = [activation for _ in range(len(hidden_sizes))] + act_args_list = [act_args for _ in range(len(hidden_sizes))] + else: + activation_list = [None] * len(hidden_sizes) + act_args_list = [None] * len(hidden_sizes) + hidden_sizes = [input_dim, *list(hidden_sizes)] + model = [] + for in_dim, out_dim, norm, norm_args, activ, act_args in zip( + hidden_sizes[:-1], + hidden_sizes[1:], + norm_layer_list, + norm_args_list, + activation_list, + act_args_list, + strict=True, + ): + model += miniblock(in_dim, out_dim, norm, norm_args, activ, act_args, linear_layer) + if output_dim > 0: + model += [linear_layer(hidden_sizes[-1], output_dim)] + self.output_dim = output_dim or hidden_sizes[-1] + self.model = nn.Sequential(*model) + self.flatten_input = flatten_input + + @no_type_check + def forward(self, obs: np.ndarray | torch.Tensor) -> torch.Tensor: + obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) + if self.flatten_input: + obs = obs.flatten(1) + return self.model(obs) + + +TRecurrentState = TypeVar("TRecurrentState", bound=Any) + + +class NetBase(nn.Module, Generic[TRecurrentState], ABC): + """Interface for NNs used in policies.""" + + @abstractmethod + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: TRecurrentState | None = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, TRecurrentState | None]: + pass + + +class Net(NetBase[Any]): + 
"""Wrapper of MLP to support more specific DRL usage. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + + :param state_shape: int or a sequence of int of the shape of state. + :param action_shape: int or a sequence of int of the shape of action. + :param hidden_sizes: shape of MLP passed in as a list. + :param norm_layer: use which normalization before activation, e.g., + ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. + You can also pass a list of normalization modules with the same length + of hidden_sizes, to use different normalization module in different + layers. Default to no normalization. + :param activation: which activation to use after each layer, can be both + the same activation for all layers if passed in nn.Module, or different + activation for different Modules if passed in a list. Default to + nn.ReLU. + :param device: specify the device when the network actually runs. Default + to "cpu". + :param softmax: whether to apply a softmax layer over the last layer's + output. + :param concat: whether the input shape is concatenated by state_shape + and action_shape. If it is True, ``action_shape`` is not the output + shape, but affects the input shape only. + :param num_atoms: in order to expand to the net of distributional RL. + Default to 1 (not use). + :param dueling_param: whether to use dueling network to calculate Q + values (for Dueling DQN). If you want to use dueling option, you should + pass a tuple of two dict (first for Q and second for V) stating + self-defined arguments as stated in + class:`~tianshou.utils.net.common.MLP`. Default to None. + :param linear_layer: use this module constructor, which takes the input + and output dimension as input, as linear layer. Default to nn.Linear. + + .. seealso:: + + Please refer to :class:`~tianshou.utils.net.common.MLP` for more + detailed explanation on the usage of activation, norm_layer, etc. + + You can also refer to :class:`~tianshou.utils.net.continuous.Actor`, + :class:`~tianshou.utils.net.continuous.Critic`, etc, to see how it's + suggested be used. 
+ """ + + def __init__( + self, + state_shape: int | Sequence[int], + action_shape: TActionShape = 0, + hidden_sizes: Sequence[int] = (), + norm_layer: ModuleType | Sequence[ModuleType] | None = None, + norm_args: ArgsType | None = None, + activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU, + act_args: ArgsType | None = None, + device: str | int | torch.device = "cpu", + softmax: bool = False, + concat: bool = False, + num_atoms: int = 1, + dueling_param: tuple[dict[str, Any], dict[str, Any]] | None = None, + linear_layer: TLinearLayer = nn.Linear, + ) -> None: + super().__init__() + self.device = device + self.softmax = softmax + self.num_atoms = num_atoms + self.Q: MLP | None = None + self.V: MLP | None = None + + input_dim = int(np.prod(state_shape)) + action_dim = int(np.prod(action_shape)) * num_atoms + if concat: + input_dim += action_dim + self.use_dueling = dueling_param is not None + output_dim = action_dim if not self.use_dueling and not concat else 0 + self.model = MLP( + input_dim, + output_dim, + hidden_sizes, + norm_layer, + norm_args, + activation, + act_args, + device, + linear_layer, + ) + if self.use_dueling: # dueling DQN + assert dueling_param is not None + kwargs_update = { + "input_dim": self.model.output_dim, + "device": self.device, + } + # Important: don't change the original dict (e.g., don't use .update()) + q_kwargs = {**dueling_param[0], **kwargs_update} + v_kwargs = {**dueling_param[1], **kwargs_update} + + q_kwargs["output_dim"] = 0 if concat else action_dim + v_kwargs["output_dim"] = 0 if concat else num_atoms + self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs) + self.output_dim = self.Q.output_dim + else: + self.output_dim = self.model.output_dim + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, Any]: + """Mapping: obs -> flatten (inside MLP)-> logits. + + :param obs: + :param state: unused and returned as is + :param info: unused + """ + logits = self.model(obs) + batch_size = logits.shape[0] + if self.use_dueling: # Dueling DQN + assert self.Q is not None + assert self.V is not None + q, v = self.Q(logits), self.V(logits) + if self.num_atoms > 1: + q = q.view(batch_size, -1, self.num_atoms) + v = v.view(batch_size, -1, self.num_atoms) + logits = q - q.mean(dim=1, keepdim=True) + v + elif self.num_atoms > 1: + logits = logits.view(batch_size, -1, self.num_atoms) + if self.softmax: + logits = torch.softmax(logits, dim=-1) + return logits, state + + +class Recurrent(NetBase[RecurrentStateBatch]): + """Simple Recurrent network based on LSTM. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + def __init__( + self, + layer_num: int, + state_shape: int | Sequence[int], + action_shape: TActionShape, + device: str | int | torch.device = "cpu", + hidden_layer_size: int = 128, + ) -> None: + super().__init__() + self.device = device + self.nn = nn.LSTM( + input_size=hidden_layer_size, + hidden_size=hidden_layer_size, + num_layers=layer_num, + batch_first=True, + ) + self.fc1 = nn.Linear(int(np.prod(state_shape)), hidden_layer_size) + self.fc2 = nn.Linear(hidden_layer_size, int(np.prod(action_shape))) + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: RecurrentStateBatch | None = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, RecurrentStateBatch]: + """Mapping: obs -> flatten -> logits. 
+ + In the evaluation mode, `obs` should be with shape ``[bsz, dim]``; in the + training mode, `obs` should be with shape ``[bsz, len, dim]``. See the code + and comment for more detail. + + :param obs: + :param state: either None or a dict with keys 'hidden' and 'cell' + :param info: unused + :return: predicted action, next state as dict with keys 'hidden' and 'cell' + """ + # Note: the original type of state is Batch but it might also be a dict + # If it is a Batch, .issubset(state) will not work. However, + # issubset(state.keys()) always works + if state is not None and not {"hidden", "cell"}.issubset(state.keys()): + raise ValueError( + f"Expected to find keys 'hidden' and 'cell' but instead found {state.keys()}", + ) + + obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) + # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) + # In short, the tensor's shape in training phase is longer than which + # in evaluation phase. + if len(obs.shape) == 2: + obs = obs.unsqueeze(-2) + obs = self.fc1(obs) + self.nn.flatten_parameters() + if state is None: + obs, (hidden, cell) = self.nn(obs) + else: + # we store the stack data in [bsz, len, ...] format + # but pytorch rnn needs [len, bsz, ...] + obs, (hidden, cell) = self.nn( + obs, + ( + state["hidden"].transpose(0, 1).contiguous(), + state["cell"].transpose(0, 1).contiguous(), + ), + ) + obs = self.fc2(obs[:, -1]) + # please ensure the first dim is batch size: [bsz, len, ...] + rnn_state_batch = cast( + RecurrentStateBatch, + Batch( + { + "hidden": hidden.transpose(0, 1).detach(), + "cell": cell.transpose(0, 1).detach(), + }, + ), + ) + return obs, rnn_state_batch + + +class ActorCritic(nn.Module): + """An actor-critic network for parsing parameters. + + Using ``actor_critic.parameters()`` instead of set.union or list+list to avoid + issue #449. + + :param nn.Module actor: the actor network. + :param nn.Module critic: the critic network. + """ + + def __init__(self, actor: nn.Module, critic: nn.Module) -> None: + super().__init__() + self.actor = actor + self.critic = critic + + +class DataParallelNet(nn.Module): + """DataParallel wrapper for training agent with multi-GPU. + + This class does only the conversion of input data type, from numpy array to torch's + Tensor. If the input is a nested dictionary, the user should create a similar class + to do the same thing. + + :param nn.Module net: the network to be distributed in different GPUs. + """ + + def __init__(self, net: nn.Module) -> None: + super().__init__() + self.net = nn.DataParallel(net) + + def forward( + self, + obs: np.ndarray | torch.Tensor, + *args: Any, + **kwargs: Any, + ) -> tuple[Any, Any]: + if not isinstance(obs, torch.Tensor): + obs = torch.as_tensor(obs, dtype=torch.float32) + return self.net(obs=obs.cuda(), *args, **kwargs) # noqa: B026 + + +class EnsembleLinear(nn.Module): + """Linear Layer of Ensemble network. + + :param ensemble_size: Number of subnets in the ensemble. + :param in_feature: dimension of the input vector. + :param out_feature: dimension of the output vector. + :param bias: whether to include an additive bias, default to be True. 
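+
+    A minimal usage sketch (the sizes below are hypothetical); the first input
+    dimension indexes the ensemble members:
+    ::
+
+        layer = EnsembleLinear(ensemble_size=5, in_feature=16, out_feature=32)
+        out = layer(torch.rand(5, 8, 16))  # out has shape (5, 8, 32)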
+ """ + + def __init__( + self, + ensemble_size: int, + in_feature: int, + out_feature: int, + bias: bool = True, + ) -> None: + super().__init__() + + # To be consistent with PyTorch default initializer + k = np.sqrt(1.0 / in_feature) + weight_data = torch.rand((ensemble_size, in_feature, out_feature)) * 2 * k - k + self.weight = nn.Parameter(weight_data, requires_grad=True) + + self.bias_weights: nn.Parameter | None = None + if bias: + bias_data = torch.rand((ensemble_size, 1, out_feature)) * 2 * k - k + self.bias_weights = nn.Parameter(bias_data, requires_grad=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.matmul(x, self.weight) + if self.bias_weights is not None: + x = x + self.bias_weights + return x + + +# TODO: fix docstring +class BranchingNet(NetBase[Any]): + """Branching dual Q network. + + Network for the BranchingDQNPolicy, it uses a common network module, a value module + and action "branches" one for each dimension.It allows for a linear scaling + of Q-value the output w.r.t. the number of dimensions in the action space. + For more info please refer to: arXiv:1711.08946. + :param state_shape: int or a sequence of int of the shape of state. + :param action_shape: int or a sequence of int of the shape of action. + :param action_peer_branch: int or a sequence of int of the number of actions in + each dimension. + :param common_hidden_sizes: shape of the common MLP network passed in as a list. + :param value_hidden_sizes: shape of the value MLP network passed in as a list. + :param action_hidden_sizes: shape of the action MLP network passed in as a list. + :param norm_layer: use which normalization before activation, e.g., + ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. + You can also pass a list of normalization modules with the same length + of hidden_sizes, to use different normalization module in different + layers. Default to no normalization. + :param activation: which activation to use after each layer, can be both + the same activation for all layers if passed in nn.Module, or different + activation for different Modules if passed in a list. Default to + nn.ReLU. + :param device: specify the device when the network actually runs. Default + to "cpu". + :param softmax: whether to apply a softmax layer over the last layer's + output. 
+ """ + + def __init__( + self, + state_shape: int | Sequence[int], + num_branches: int = 0, + action_per_branch: int = 2, + common_hidden_sizes: list[int] | None = None, + value_hidden_sizes: list[int] | None = None, + action_hidden_sizes: list[int] | None = None, + norm_layer: ModuleType | None = None, + norm_args: ArgsType | None = None, + activation: ModuleType | None = nn.ReLU, + act_args: ArgsType | None = None, + device: str | int | torch.device = "cpu", + ) -> None: + super().__init__() + common_hidden_sizes = common_hidden_sizes or [] + value_hidden_sizes = value_hidden_sizes or [] + action_hidden_sizes = action_hidden_sizes or [] + + self.device = device + self.num_branches = num_branches + self.action_per_branch = action_per_branch + # common network + common_input_dim = int(np.prod(state_shape)) + common_output_dim = 0 + self.common = MLP( + common_input_dim, + common_output_dim, + common_hidden_sizes, + norm_layer, + norm_args, + activation, + act_args, + device, + ) + # value network + value_input_dim = common_hidden_sizes[-1] + value_output_dim = 1 + self.value = MLP( + value_input_dim, + value_output_dim, + value_hidden_sizes, + norm_layer, + norm_args, + activation, + act_args, + device, + ) + # action branching network + action_input_dim = common_hidden_sizes[-1] + action_output_dim = action_per_branch + self.branches = nn.ModuleList( + [ + MLP( + action_input_dim, + action_output_dim, + action_hidden_sizes, + norm_layer, + norm_args, + activation, + act_args, + device, + ) + for _ in range(self.num_branches) + ], + ) + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, Any]: + """Mapping: obs -> model -> logits.""" + common_out = self.common(obs) + value_out = self.value(common_out) + value_out = torch.unsqueeze(value_out, 1) + action_out = [] + for b in self.branches: + action_out.append(b(common_out)) + action_scores = torch.stack(action_out, 1) + action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True) + logits = value_out + action_scores + return logits, state + + +def get_dict_state_decorator( + state_shape: dict[str, int | Sequence[int]], + keys: Sequence[str], +) -> tuple[Callable, int]: + """A helper function to make Net or equivalent classes (e.g. Actor, Critic) applicable to dict state. + + The first return item, ``decorator_fn``, will alter the implementation of forward + function of the given class by preprocessing the observation. The preprocessing is + basically flatten the observation and concatenate them based on the ``keys`` order. + The batch dimension is preserved if presented. The result observation shape will + be equal to ``new_state_shape``, the second return item. + + :param state_shape: A dictionary indicating each state's shape + :param keys: A list of state's keys. The flatten observation will be according to + this list order. 
+ :returns: a 2-items tuple ``decorator_fn`` and ``new_state_shape`` + """ + original_shape = state_shape + flat_state_shapes = [] + for k in keys: + flat_state_shapes.append(int(np.prod(state_shape[k]))) + new_state_shape = sum(flat_state_shapes) + + def preprocess_obs(obs: Batch | dict | torch.Tensor | np.ndarray) -> torch.Tensor: + if isinstance(obs, dict) or (isinstance(obs, Batch) and keys[0] in obs): + if original_shape[keys[0]] == obs[keys[0]].shape: + # No batch dim + new_obs = torch.Tensor([obs[k] for k in keys]).flatten() + # new_obs = torch.Tensor([obs[k] for k in keys]).reshape(1, -1) + else: + bsz = obs[keys[0]].shape[0] + new_obs = torch.cat([torch.Tensor(obs[k].reshape(bsz, -1)) for k in keys], dim=1) + else: + new_obs = torch.Tensor(obs) + return new_obs + + @no_type_check + def decorator_fn(net_class): + class new_net_class(net_class): + def forward(self, obs: np.ndarray | torch.Tensor, *args, **kwargs) -> Any: + return super().forward(preprocess_obs(obs), *args, **kwargs) + + return new_net_class + + return decorator_fn, new_state_shape + + +class BaseActor(nn.Module, ABC): + @abstractmethod + def get_preprocess_net(self) -> nn.Module: + pass + + @abstractmethod + def get_output_dim(self) -> int: + pass + + @abstractmethod + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any = None, + info: dict[str, Any] | None = None, + ) -> tuple[Any, Any]: + # TODO: ALGO-REFACTORING. Marked to be addressed as part of Algorithm abstraction. + # Return type needs to be more specific + pass + + +def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T: + """Gets the given attribute from the given object or takes the alternative value if it is not present. + If both are present, they are required to match. + + :param obj: the object from which to obtain the attribute value + :param attr_name: the attribute name + :param alt_value: the alternative value for the case where the attribute is not present, which cannot be None + if the attribute is not present + :return: the value + """ + v = getattr(obj, attr_name) + if v is not None: + if alt_value is not None and v != alt_value: + raise ValueError( + f"Attribute '{attr_name}' of {obj} is defined ({v}) but does not match alt. value ({alt_value})", + ) + return v + else: + if alt_value is None: + raise ValueError( + f"Attribute '{attr_name}' of {obj} is not defined and no fallback given", + ) + return alt_value + + +def get_output_dim(module: nn.Module, alt_value: int | None) -> int: + """Retrieves value the `output_dim` attribute of the given module or uses the given alternative value if the attribute is not present. + If both are present, they must match. 
+ + :param module: the module + :param alt_value: the alternative value + :return: the value + """ + return getattr_with_matching_alt_value(module, "output_dim", alt_value) diff --git a/examples/atari/tianshou/utils/net/continuous.py b/examples/atari/tianshou/utils/net/continuous.py new file mode 100644 index 0000000000000000000000000000000000000000..0b28f98f904e450cfbc5feecf200eb85208f7ca9 --- /dev/null +++ b/examples/atari/tianshou/utils/net/continuous.py @@ -0,0 +1,529 @@ +import warnings +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Any + +import numpy as np +import torch +from torch import nn + +from tianshou.utils.net.common import ( + MLP, + BaseActor, + Net, + TActionShape, + TLinearLayer, + get_output_dim, +) +from tianshou.utils.pickle import setstate + +SIGMA_MIN = -20 +SIGMA_MAX = 2 + + +class Actor(BaseActor): + """Simple actor network that directly outputs actions for continuous action space. + Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`. + + It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape. + + :param preprocess_net: a self-defined preprocess_net, see usage. + Typically, an instance of :class:`~tianshou.utils.net.common.Net`. + :param action_shape: a sequence of int for the shape of action. + :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. + :param max_action: the scale for the final action. + :param preprocess_net_output_dim: the output dimension of + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + def __init__( + self, + preprocess_net: nn.Module | Net, + action_shape: TActionShape, + hidden_sizes: Sequence[int] = (), + max_action: float = 1.0, + device: str | int | torch.device = "cpu", + preprocess_net_output_dim: int | None = None, + ) -> None: + super().__init__() + self.device = device + self.preprocess = preprocess_net + self.output_dim = int(np.prod(action_shape)) + input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + self.last = MLP( + input_dim, + self.output_dim, + hidden_sizes, + device=self.device, + ) + self.max_action = max_action + + def get_preprocess_net(self) -> nn.Module: + return self.preprocess + + def get_output_dim(self) -> int: + return self.output_dim + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, Any]: + """Mapping: s_B -> action_values_BA, hidden_state_BH | None. + + Returns a tensor representing the actions directly, i.e, of shape + `(n_actions, )`, and a hidden state (which may be None). + The hidden state is only not None if a recurrent net is used as part of the + learning algorithm (support for RNNs is currently experimental). + """ + action_BA, hidden_BH = self.preprocess(obs, state) + action_BA = self.max_action * torch.tanh(self.last(action_BA)) + return action_BA, hidden_BH + + +class CriticBase(nn.Module, ABC): + @abstractmethod + def forward( + self, + obs: np.ndarray | torch.Tensor, + act: np.ndarray | torch.Tensor | None = None, + info: dict[str, Any] | None = None, + ) -> torch.Tensor: + """Mapping: (s_B, a_B) -> Q(s, a)_B.""" + + +class Critic(CriticBase): + """Simple critic network. 
+ + It will create an actor operated in continuous action space with structure of preprocess_net ---> 1(q value). + + :param preprocess_net: a self-defined preprocess_net, see usage. + Typically, an instance of :class:`~tianshou.utils.net.common.Net`. + :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. + :param preprocess_net_output_dim: the output dimension of + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. + :param linear_layer: use this module as linear layer. + :param flatten_input: whether to flatten input data for the last layer. + :param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before + concatenating with the action) - and without the observations being modified in any way beforehand. + This allows the actor's preprocessing network to be reused for the critic. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + def __init__( + self, + preprocess_net: nn.Module | Net, + hidden_sizes: Sequence[int] = (), + device: str | int | torch.device = "cpu", + preprocess_net_output_dim: int | None = None, + linear_layer: TLinearLayer = nn.Linear, + flatten_input: bool = True, + apply_preprocess_net_to_obs_only: bool = False, + ) -> None: + super().__init__() + self.device = device + self.preprocess = preprocess_net + self.output_dim = 1 + self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only + input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + self.last = MLP( + input_dim, + 1, + hidden_sizes, + device=self.device, + linear_layer=linear_layer, + flatten_input=flatten_input, + ) + + def __setstate__(self, state: dict) -> None: + setstate( + Critic, + self, + state, + new_default_properties={"apply_preprocess_net_to_obs_only": False}, + ) + + def forward( + self, + obs: np.ndarray | torch.Tensor, + act: np.ndarray | torch.Tensor | None = None, + info: dict[str, Any] | None = None, + ) -> torch.Tensor: + """Mapping: (s_B, a_B) -> Q(s, a)_B.""" + obs = torch.as_tensor( + obs, + device=self.device, + dtype=torch.float32, + ) + if self.apply_preprocess_net_to_obs_only: + obs, _ = self.preprocess(obs) + obs = obs.flatten(1) + if act is not None: + act = torch.as_tensor( + act, + device=self.device, + dtype=torch.float32, + ).flatten(1) + obs = torch.cat([obs, act], dim=1) + if not self.apply_preprocess_net_to_obs_only: + obs, _ = self.preprocess(obs) + return self.last(obs) + + +class ActorProb(BaseActor): + """Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian). + + Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`. + + :param preprocess_net: a self-defined preprocess_net, see usage. + Typically, an instance of :class:`~tianshou.utils.net.common.Net`. + :param action_shape: a sequence of int for the shape of action. + :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. + :param max_action: the scale for the final action logits. + :param unbounded: whether to apply tanh activation on final logits. + :param conditioned_sigma: True when sigma is calculated from the + input, False when sigma is an independent parameter. + :param preprocess_net_output_dim: the output dimension of + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. 
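+
+    A minimal usage sketch (the shapes below are hypothetical):
+    ::
+
+        preprocess = Net(state_shape=4, hidden_sizes=[64])
+        actor = ActorProb(preprocess, action_shape=2, unbounded=True)
+        (mu, sigma), state = actor(np.random.rand(8, 4))  # mu, sigma have shape (8, 2)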
+ + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + # TODO: force kwargs, adjust downstream code + def __init__( + self, + preprocess_net: nn.Module | Net, + action_shape: TActionShape, + hidden_sizes: Sequence[int] = (), + max_action: float = 1.0, + device: str | int | torch.device = "cpu", + unbounded: bool = False, + conditioned_sigma: bool = False, + preprocess_net_output_dim: int | None = None, + ) -> None: + super().__init__() + if unbounded and not np.isclose(max_action, 1.0): + warnings.warn("Note that max_action input will be discarded when unbounded is True.") + max_action = 1.0 + self.preprocess = preprocess_net + self.device = device + self.output_dim = int(np.prod(action_shape)) + input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device) + self._c_sigma = conditioned_sigma + if conditioned_sigma: + self.sigma = MLP( + input_dim, + self.output_dim, + hidden_sizes, + device=self.device, + ) + else: + self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1)) + self.max_action = max_action + self._unbounded = unbounded + + def get_preprocess_net(self) -> nn.Module: + return self.preprocess + + def get_output_dim(self) -> int: + return self.output_dim + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any = None, + info: dict[str, Any] | None = None, + ) -> tuple[tuple[torch.Tensor, torch.Tensor], Any]: + """Mapping: obs -> logits -> (mu, sigma).""" + if info is None: + info = {} + logits, hidden = self.preprocess(obs, state) + mu = self.mu(logits) + if not self._unbounded: + mu = self.max_action * torch.tanh(mu) + if self._c_sigma: + sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp() + else: + shape = [1] * len(mu.shape) + shape[1] = -1 + sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp() + return (mu, sigma), state + + +class RecurrentActorProb(nn.Module): + """Recurrent version of ActorProb. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. 
+ """ + + def __init__( + self, + layer_num: int, + state_shape: Sequence[int], + action_shape: Sequence[int], + hidden_layer_size: int = 128, + max_action: float = 1.0, + device: str | int | torch.device = "cpu", + unbounded: bool = False, + conditioned_sigma: bool = False, + ) -> None: + super().__init__() + if unbounded and not np.isclose(max_action, 1.0): + warnings.warn("Note that max_action input will be discarded when unbounded is True.") + max_action = 1.0 + self.device = device + self.nn = nn.LSTM( + input_size=int(np.prod(state_shape)), + hidden_size=hidden_layer_size, + num_layers=layer_num, + batch_first=True, + ) + output_dim = int(np.prod(action_shape)) + self.mu = nn.Linear(hidden_layer_size, output_dim) + self._c_sigma = conditioned_sigma + if conditioned_sigma: + self.sigma = nn.Linear(hidden_layer_size, output_dim) + else: + self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1)) + self.max_action = max_action + self._unbounded = unbounded + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: dict[str, torch.Tensor] | None = None, + info: dict[str, Any] | None = None, + ) -> tuple[tuple[torch.Tensor, torch.Tensor], dict[str, torch.Tensor]]: + """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" + if info is None: + info = {} + obs = torch.as_tensor( + obs, + device=self.device, + dtype=torch.float32, + ) + # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) + # In short, the tensor's shape in training phase is longer than which + # in evaluation phase. + if len(obs.shape) == 2: + obs = obs.unsqueeze(-2) + self.nn.flatten_parameters() + if state is None: + obs, (hidden, cell) = self.nn(obs) + else: + # we store the stack data in [bsz, len, ...] format + # but pytorch rnn needs [len, bsz, ...] + obs, (hidden, cell) = self.nn( + obs, + ( + state["hidden"].transpose(0, 1).contiguous(), + state["cell"].transpose(0, 1).contiguous(), + ), + ) + logits = obs[:, -1] + mu = self.mu(logits) + if not self._unbounded: + mu = self.max_action * torch.tanh(mu) + if self._c_sigma: + sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp() + else: + shape = [1] * len(mu.shape) + shape[1] = -1 + sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp() + # please ensure the first dim is batch size: [bsz, len, ...] + return (mu, sigma), { + "hidden": hidden.transpose(0, 1).detach(), + "cell": cell.transpose(0, 1).detach(), + } + + +class RecurrentCritic(nn.Module): + """Recurrent version of Critic. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. 
+ """ + + def __init__( + self, + layer_num: int, + state_shape: Sequence[int], + action_shape: Sequence[int] = [0], + device: str | int | torch.device = "cpu", + hidden_layer_size: int = 128, + ) -> None: + super().__init__() + self.state_shape = state_shape + self.action_shape = action_shape + self.device = device + self.nn = nn.LSTM( + input_size=int(np.prod(state_shape)), + hidden_size=hidden_layer_size, + num_layers=layer_num, + batch_first=True, + ) + self.fc2 = nn.Linear(hidden_layer_size + int(np.prod(action_shape)), 1) + + def forward( + self, + obs: np.ndarray | torch.Tensor, + act: np.ndarray | torch.Tensor | None = None, + info: dict[str, Any] | None = None, + ) -> torch.Tensor: + """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" + if info is None: + info = {} + obs = torch.as_tensor( + obs, + device=self.device, + dtype=torch.float32, + ) + # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) + # In short, the tensor's shape in training phase is longer than which + # in evaluation phase. + assert len(obs.shape) == 3 + self.nn.flatten_parameters() + obs, (hidden, cell) = self.nn(obs) + obs = obs[:, -1] + if act is not None: + act = torch.as_tensor( + act, + device=self.device, + dtype=torch.float32, + ) + obs = torch.cat([obs, act], dim=1) + return self.fc2(obs) + + +class Perturbation(nn.Module): + """Implementation of perturbation network in BCQ algorithm. + + Given a state and action, it can generate perturbed action. + + :param preprocess_net: a self-defined preprocess_net which output a + flattened hidden state. + :param max_action: the maximum value of each dimension of action. + :param device: which device to create this model on. + :param phi: max perturbation parameter for BCQ. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + + .. seealso:: + + You can refer to `examples/offline/offline_bcq.py` to see how to use it. + """ + + def __init__( + self, + preprocess_net: nn.Module, + max_action: float, + device: str | int | torch.device = "cpu", + phi: float = 0.05, + ): + # preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim + super().__init__() + self.preprocess_net = preprocess_net + self.device = device + self.max_action = max_action + self.phi = phi + + def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: + # preprocess_net + logits = self.preprocess_net(torch.cat([state, action], -1))[0] + noise = self.phi * self.max_action * torch.tanh(logits) + # clip to [-max_action, max_action] + return (noise + action).clamp(-self.max_action, self.max_action) + + +class VAE(nn.Module): + """Implementation of VAE. + + It models the distribution of action. Given a state, it can generate actions similar to those in batch. + It is used in BCQ algorithm. + + :param encoder: the encoder in VAE. Its input_dim must be + state_dim + action_dim, and output_dim must be hidden_dim. + :param decoder: the decoder in VAE. Its input_dim must be + state_dim + latent_dim, and output_dim must be action_dim. + :param hidden_dim: the size of the last linear-layer in encoder. + :param latent_dim: the size of latent layer. + :param max_action: the maximum value of each dimension of action. + :param device: which device to create this model on. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + + .. seealso:: + + You can refer to `examples/offline/offline_bcq.py` to see how to use it. 
+ """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + hidden_dim: int, + latent_dim: int, + max_action: float, + device: str | torch.device = "cpu", + ): + super().__init__() + self.encoder = encoder + + self.mean = nn.Linear(hidden_dim, latent_dim) + self.log_std = nn.Linear(hidden_dim, latent_dim) + + self.decoder = decoder + + self.max_action = max_action + self.latent_dim = latent_dim + self.device = device + + def forward( + self, + state: torch.Tensor, + action: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [state, action] -> z , [state, z] -> action + latent_z = self.encoder(torch.cat([state, action], -1)) + # shape of z: (state.shape[:-1], hidden_dim) + + mean = self.mean(latent_z) + # Clamped for numerical stability + log_std = self.log_std(latent_z).clamp(-4, 15) + std = torch.exp(log_std) + # shape of mean, std: (state.shape[:-1], latent_dim) + + latent_z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim) + + reconstruction = self.decode(state, latent_z) # (state.shape[:-1], action_dim) + return reconstruction, mean, std + + def decode( + self, + state: torch.Tensor, + latent_z: torch.Tensor | None = None, + ) -> torch.Tensor: + # decode(state) -> action + if latent_z is None: + # state.shape[0] may be batch_size + # latent vector clipped to [-0.5, 0.5] + latent_z = ( + torch.randn(state.shape[:-1] + (self.latent_dim,)).to(self.device).clamp(-0.5, 0.5) + ) + + # decode z with state! + return self.max_action * torch.tanh(self.decoder(torch.cat([state, latent_z], -1))) diff --git a/examples/atari/tianshou/utils/net/discrete.py b/examples/atari/tianshou/utils/net/discrete.py new file mode 100644 index 0000000000000000000000000000000000000000..e32db85acaf223aa1e8569d3205c90b41babc411 --- /dev/null +++ b/examples/atari/tianshou/utils/net/discrete.py @@ -0,0 +1,580 @@ +from collections.abc import Sequence +from typing import Any + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from tianshou.data import Batch, to_torch +from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim + + +class Actor(BaseActor): + """Simple actor network for discrete action spaces. + + :param preprocess_net: a self-defined preprocess_net. Typically, an instance of + :class:`~tianshou.utils.net.common.Net`. + :param action_shape: a sequence of int for the shape of action. + :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. Default to empty sequence (where the MLP now contains + only a single linear layer). + :param softmax_output: whether to apply a softmax layer over the last + layer's output. + :param preprocess_net_output_dim: the output dimension of + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + def __init__( + self, + preprocess_net: nn.Module | Net, + action_shape: TActionShape, + hidden_sizes: Sequence[int] = (), + softmax_output: bool = True, + preprocess_net_output_dim: int | None = None, + device: str | int | torch.device = "cpu", + ) -> None: + super().__init__() + # TODO: reduce duplication with continuous.py. Probably introducing + # base classes is a good idea. 
+ self.device = device + self.preprocess = preprocess_net + self.output_dim = int(np.prod(action_shape)) + input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + self.last = MLP( + input_dim, + self.output_dim, + hidden_sizes, + device=self.device, + ) + self.softmax_output = softmax_output + + def get_preprocess_net(self) -> nn.Module: + return self.preprocess + + def get_output_dim(self) -> int: + return self.output_dim + + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None. + + Returns a tensor representing the values of each action, i.e, of shape + `(n_actions, )`, and + a hidden state (which may be None). If `self.softmax_output` is True, they are the + probabilities for taking each action. Otherwise, they will be action values. + The hidden state is only + not None if a recurrent net is used as part of the learning algorithm. + """ + x, hidden_BH = self.preprocess(obs, state) + x = self.last(x) + if self.softmax_output: + x = F.softmax(x, dim=-1) + # If we computed softmax, output is probabilities, otherwise it's the non-normalized action values + output_BA = x + return output_BA, hidden_BH + + +class Critic(nn.Module): + """Simple critic network for discrete action spaces. + + :param preprocess_net: a self-defined preprocess_net. Typically, an instance of + :class:`~tianshou.utils.net.common.Net`. + :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. Default to empty sequence (where the MLP now contains + only a single linear layer). + :param last_size: the output dimension of Critic network. Default to 1. + :param preprocess_net_output_dim: the output dimension of + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`.. + """ + + def __init__( + self, + preprocess_net: nn.Module | Net, + hidden_sizes: Sequence[int] = (), + last_size: int = 1, + preprocess_net_output_dim: int | None = None, + device: str | int | torch.device = "cpu", + ) -> None: + super().__init__() + self.device = device + self.preprocess = preprocess_net + self.output_dim = last_size + input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device) + + # TODO: make a proper interface! + def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor: + """Mapping: s_B -> V(s)_B.""" + # TODO: don't use this mechanism for passing state + logits, _ = self.preprocess(obs, state=kwargs.get("state", None)) + return self.last(logits) + + +class CosineEmbeddingNetwork(nn.Module): + """Cosine embedding network for IQN. Convert a scalar in [0, 1] to a list of n-dim vectors. + + :param num_cosines: the number of cosines used for the embedding. + :param embedding_dim: the dimension of the embedding/output. + + .. note:: + + From https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master + /fqf_iqn_qrdqn/network.py . 
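+
+    A minimal usage sketch (the sizes below are hypothetical):
+    ::
+
+        embed = CosineEmbeddingNetwork(num_cosines=64, embedding_dim=256)
+        taus = torch.rand(8, 32)  # a batch of 8 samples, 32 quantile fractions each
+        out = embed(taus)         # out has shape (8, 32, 256)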
+ """ + + def __init__(self, num_cosines: int, embedding_dim: int) -> None: + super().__init__() + self.net = nn.Sequential(nn.Linear(num_cosines, embedding_dim), nn.ReLU()) + self.num_cosines = num_cosines + self.embedding_dim = embedding_dim + + def forward(self, taus: torch.Tensor) -> torch.Tensor: + batch_size = taus.shape[0] + N = taus.shape[1] + # Calculate i * \pi (i=1,...,N). + i_pi = np.pi * torch.arange( + start=1, + end=self.num_cosines + 1, + dtype=taus.dtype, + device=taus.device, + ).view(1, 1, self.num_cosines) + # Calculate cos(i * \pi * \tau). + cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi).view( + batch_size * N, + self.num_cosines, + ) + # Calculate embeddings of taus. + return self.net(cosines).view(batch_size, N, self.embedding_dim) + + +class ImplicitQuantileNetwork(Critic): + """Implicit Quantile Network. + + :param preprocess_net: a self-defined preprocess_net which output a + flattened hidden state. + :param action_shape: a sequence of int for the shape of action. + :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. Default to empty sequence (where the MLP now contains + only a single linear layer). + :param num_cosines: the number of cosines to use for cosine embedding. + Default to 64. + :param preprocess_net_output_dim: the output dimension of + preprocess_net. + + .. note:: + + Although this class inherits Critic, it is actually a quantile Q-Network + with output shape (batch_size, action_dim, sample_size). + + The second item of the first return value is tau vector. + """ + + def __init__( + self, + preprocess_net: nn.Module, + action_shape: TActionShape, + hidden_sizes: Sequence[int] = (), + num_cosines: int = 64, + preprocess_net_output_dim: int | None = None, + device: str | int | torch.device = "cpu", + ) -> None: + last_size = int(np.prod(action_shape)) + super().__init__(preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device) + self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to( + device, + ) + + def forward( # type: ignore + self, + obs: np.ndarray | torch.Tensor, + sample_size: int, + **kwargs: Any, + ) -> tuple[Any, torch.Tensor]: + r"""Mapping: s -> Q(s, \*).""" + logits, hidden = self.preprocess(obs, state=kwargs.get("state", None)) + # Sample fractions. + batch_size = logits.size(0) + taus = torch.rand(batch_size, sample_size, dtype=logits.dtype, device=logits.device) + embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view( + batch_size * sample_size, + -1, + ) + out = self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2) + return (out, taus), hidden + + +class FractionProposalNetwork(nn.Module): + """Fraction proposal network for FQF. + + :param num_fractions: the number of factions to propose. + :param embedding_dim: the dimension of the embedding/input. + + .. note:: + + Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master + /fqf_iqn_qrdqn/network.py . 
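+
+    A minimal usage sketch (the batch size and embedding dimension below are
+    illustrative assumptions)::
+
+        >>> import torch
+        >>> net = FractionProposalNetwork(num_fractions=32, embedding_dim=128)
+        >>> obs_embeddings = torch.randn(16, 128)
+        >>> taus, tau_hats, entropies = net(obs_embeddings)
+        >>> taus.shape, tau_hats.shape, entropies.shape
+        (torch.Size([16, 33]), torch.Size([16, 32]), torch.Size([16]))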
+ """ + + def __init__(self, num_fractions: int, embedding_dim: int) -> None: + super().__init__() + self.net = nn.Linear(embedding_dim, num_fractions) + torch.nn.init.xavier_uniform_(self.net.weight, gain=0.01) + torch.nn.init.constant_(self.net.bias, 0) + self.num_fractions = num_fractions + self.embedding_dim = embedding_dim + + def forward( + self, + obs_embeddings: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Calculate (log of) probabilities q_i in the paper. + dist = torch.distributions.Categorical(logits=self.net(obs_embeddings)) + taus_1_N = torch.cumsum(dist.probs, dim=1) + # Calculate \tau_i (i=0,...,N). + taus = F.pad(taus_1_N, (1, 0)) + # Calculate \hat \tau_i (i=0,...,N-1). + tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0 + # Calculate entropies of value distributions. + entropies = dist.entropy() + return taus, tau_hats, entropies + + +class FullQuantileFunction(ImplicitQuantileNetwork): + """Full(y parameterized) Quantile Function. + + :param preprocess_net: a self-defined preprocess_net which output a + flattened hidden state. + :param action_shape: a sequence of int for the shape of action. + :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. Default to empty sequence (where the MLP now contains + only a single linear layer). + :param num_cosines: the number of cosines to use for cosine embedding. + Default to 64. + :param preprocess_net_output_dim: the output dimension of + preprocess_net. + + .. note:: + + The first return value is a tuple of (quantiles, fractions, quantiles_tau), + where fractions is a Batch(taus, tau_hats, entropies). + """ + + def __init__( + self, + preprocess_net: nn.Module, + action_shape: TActionShape, + hidden_sizes: Sequence[int] = (), + num_cosines: int = 64, + preprocess_net_output_dim: int | None = None, + device: str | int | torch.device = "cpu", + ) -> None: + super().__init__( + preprocess_net, + action_shape, + hidden_sizes, + num_cosines, + preprocess_net_output_dim, + device, + ) + + def _compute_quantiles(self, obs: torch.Tensor, taus: torch.Tensor) -> torch.Tensor: + batch_size, sample_size = taus.shape + embedding = (obs.unsqueeze(1) * self.embed_model(taus)).view(batch_size * sample_size, -1) + return self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2) + + def forward( # type: ignore + self, + obs: np.ndarray | torch.Tensor, + propose_model: FractionProposalNetwork, + fractions: Batch | None = None, + **kwargs: Any, + ) -> tuple[Any, torch.Tensor]: + r"""Mapping: s -> Q(s, \*).""" + logits, hidden = self.preprocess(obs, state=kwargs.get("state", None)) + # Propose fractions + if fractions is None: + taus, tau_hats, entropies = propose_model(logits.detach()) + fractions = Batch(taus=taus, tau_hats=tau_hats, entropies=entropies) + else: + taus, tau_hats = fractions.taus, fractions.tau_hats + quantiles = self._compute_quantiles(logits, tau_hats) + # Calculate quantiles_tau for computing fraction grad + quantiles_tau = None + if self.training: + with torch.no_grad(): + quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1]) + return (quantiles, fractions, quantiles_tau), hidden + + + + + +class FullQuantileFunctionRainbow(ImplicitQuantileNetwork): + """Full(y parameterized) Quantile Function with Noisy Networks and Dueling option. + + :param preprocess_net: a self-defined preprocess_net which output a + flattened hidden state. + :param action_shape: a sequence of int for the shape of action. 
+    :param hidden_sizes: a sequence of int for constructing the MLP after
+        preprocess_net. Default to empty sequence (where the MLP now contains
+        only a single linear layer).
+    :param num_cosines: the number of cosines to use for cosine embedding.
+        Default to 64.
+    :param preprocess_net_output_dim: the output dimension of
+        preprocess_net.
+    :param noisy_std: standard deviation for NoisyLinear layers. Default to 0.5.
+    :param is_noisy: whether to use noisy layers. Default to True.
+    :param is_dueling: whether to use a dueling head with separate value and
+        advantage streams. Default to True.
+
+    .. note::
+
+        The first return value is a tuple of (quantiles, fractions, quantiles_tau),
+        where fractions is a Batch(taus, tau_hats, entropies).
+    """
+
+    def __init__(
+        self,
+        preprocess_net: nn.Module,
+        action_shape: TActionShape,
+        hidden_sizes: Sequence[int] = (),
+        num_cosines: int = 64,
+        preprocess_net_output_dim: int | None = None,
+        device: str | int | torch.device = "cpu",
+        noisy_std: float = 0.5,
+        is_noisy: bool = True,
+        is_dueling: bool = True,
+    ) -> None:
+        super().__init__(
+            preprocess_net,
+            action_shape,
+            hidden_sizes,
+            num_cosines,
+            preprocess_net_output_dim,
+            device,
+        )
+
+        if preprocess_net_output_dim is None:
+            raise ValueError("preprocess_net_output_dim must be specified and not None.")
+
+        self.action_shape = action_shape
+        self.noisy_std = noisy_std
+        self.is_noisy = is_noisy
+        self.is_dueling = is_dueling
+
+        def linear(x: int, y: int) -> nn.Module:
+            if self.is_noisy:
+                return NoisyLinear(x, y, self.noisy_std)
+            return nn.Linear(x, y)
+
+        # Advantage stream: one output per action.
+        self.advantage_net = nn.Sequential(
+            linear(preprocess_net_output_dim, 512),
+            nn.ReLU(inplace=True),
+            linear(512, int(np.prod(action_shape))),
+        )
+
+        # Value stream for the dueling architecture: a single output per quantile.
+        if self.is_dueling:
+            self.value_net = nn.Sequential(
+                linear(preprocess_net_output_dim, 512),
+                nn.ReLU(inplace=True),
+                linear(512, 1),
+            )
+
+    def _compute_quantiles(self, obs: torch.Tensor, taus: torch.Tensor) -> torch.Tensor:
+        batch_size, sample_size = taus.shape
+        embedding = (obs.unsqueeze(1) * self.embed_model(taus)).view(batch_size * sample_size, -1)
+
+        # Compute advantages, shape (batch_size, action_dim, sample_size).
+        advantage = self.advantage_net(embedding).view(batch_size, sample_size, -1).transpose(1, 2)
+
+        if self.is_dueling:
+            # Compute the state value, shape (batch_size, 1, sample_size).
+            value = self.value_net(embedding).view(batch_size, sample_size, 1).transpose(1, 2)
+            # Combine value and mean-centered advantage to obtain the quantiles.
+            quantiles = value + (advantage - advantage.mean(dim=1, keepdim=True))
+        else:
+            quantiles = advantage
+
+        return quantiles
+
+    def forward(  # type: ignore
+        self,
+        obs: np.ndarray | torch.Tensor,
+        propose_model: FractionProposalNetwork,
+        fractions: Batch | None = None,
+        **kwargs: Any,
+    ) -> tuple[Any, torch.Tensor]:
+        r"""Mapping: s -> Q(s, \*)."""
+        logits, hidden = self.preprocess(obs, state=kwargs.get("state", None))
+        # Propose fractions
+        if fractions is None:
+            taus, tau_hats, 
entropies = propose_model(logits.detach()) + fractions = Batch(taus=taus, tau_hats=tau_hats, entropies=entropies) + else: + taus, tau_hats = fractions.taus, fractions.tau_hats + quantiles = self._compute_quantiles(logits, tau_hats) + # Calculate quantiles_tau for computing fraction grad + quantiles_tau = None + if self.training: + with torch.no_grad(): + quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1]) + return (quantiles, fractions, quantiles_tau), hidden + + +class NoisyLinear(nn.Module): + """Implementation of Noisy Networks. arXiv:1706.10295. + + :param in_features: the number of input features. + :param out_features: the number of output features. + :param noisy_std: initial standard deviation of noisy linear layers. + + .. note:: + + Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master + /fqf_iqn_qrdqn/network.py . + """ + + def __init__(self, in_features: int, out_features: int, noisy_std: float = 0.5) -> None: + super().__init__() + + # Learnable parameters. + self.mu_W = nn.Parameter(torch.FloatTensor(out_features, in_features)) + self.sigma_W = nn.Parameter(torch.FloatTensor(out_features, in_features)) + self.mu_bias = nn.Parameter(torch.FloatTensor(out_features)) + self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features)) + + # Factorized noise parameters. + self.register_buffer("eps_p", torch.FloatTensor(in_features)) + self.register_buffer("eps_q", torch.FloatTensor(out_features)) + + self.in_features = in_features + self.out_features = out_features + self.sigma = noisy_std + + self.reset() + self.sample() + + def reset(self) -> None: + bound = 1 / np.sqrt(self.in_features) + self.mu_W.data.uniform_(-bound, bound) + self.mu_bias.data.uniform_(-bound, bound) + self.sigma_W.data.fill_(self.sigma / np.sqrt(self.in_features)) + self.sigma_bias.data.fill_(self.sigma / np.sqrt(self.in_features)) + + def f(self, x: torch.Tensor) -> torch.Tensor: + x = torch.randn(x.size(0), device=x.device) + return x.sign().mul_(x.abs().sqrt_()) + + # TODO: rename or change functionality? Usually sample is not an inplace operation... + def sample(self) -> None: + self.eps_p.copy_(self.f(self.eps_p)) + self.eps_q.copy_(self.f(self.eps_q)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training: + weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p)) + bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() + else: + weight = self.mu_W + bias = self.mu_bias + + return F.linear(x, weight, bias) + + +class IntrinsicCuriosityModule(nn.Module): + """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. + + :param feature_net: a self-defined feature_net which output a + flattened hidden state. + :param feature_dim: input dimension of the feature net. + :param action_dim: dimension of the action space. + :param hidden_sizes: hidden layer sizes for forward and inverse models. + :param device: device for the module. 
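+
+    A minimal usage sketch (the feature network and all dimensions below are
+    illustrative assumptions)::
+
+        >>> import numpy as np
+        >>> from torch import nn
+        >>> feature_net = nn.Sequential(nn.Linear(4, 16), nn.ReLU())
+        >>> icm = IntrinsicCuriosityModule(feature_net, feature_dim=16, action_dim=3)
+        >>> s1, s2 = np.random.rand(8, 4), np.random.rand(8, 4)
+        >>> act = np.random.randint(3, size=8)
+        >>> mse_loss, act_hat = icm(s1, act, s2)
+        >>> mse_loss.shape, act_hat.shape
+        (torch.Size([8]), torch.Size([8, 3]))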
+ """ + + def __init__( + self, + feature_net: nn.Module, + feature_dim: int, + action_dim: int, + hidden_sizes: Sequence[int] = (), + device: str | torch.device = "cpu", + ) -> None: + super().__init__() + self.feature_net = feature_net + self.forward_model = MLP( + feature_dim + action_dim, + output_dim=feature_dim, + hidden_sizes=hidden_sizes, + device=device, + ) + self.inverse_model = MLP( + feature_dim * 2, + output_dim=action_dim, + hidden_sizes=hidden_sizes, + device=device, + ) + self.feature_dim = feature_dim + self.action_dim = action_dim + self.device = device + + def forward( + self, + s1: np.ndarray | torch.Tensor, + act: np.ndarray | torch.Tensor, + s2: np.ndarray | torch.Tensor, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor]: + r"""Mapping: s1, act, s2 -> mse_loss, act_hat.""" + s1 = to_torch(s1, dtype=torch.float32, device=self.device) + s2 = to_torch(s2, dtype=torch.float32, device=self.device) + phi1, phi2 = self.feature_net(s1), self.feature_net(s2) + act = to_torch(act, dtype=torch.long, device=self.device) + phi2_hat = self.forward_model( + torch.cat([phi1, F.one_hot(act, num_classes=self.action_dim)], dim=1), + ) + mse_loss = 0.5 * F.mse_loss(phi2_hat, phi2, reduction="none").sum(1) + act_hat = self.inverse_model(torch.cat([phi1, phi2], dim=1)) + return mse_loss, act_hat diff --git a/examples/atari/tianshou/utils/optim.py b/examples/atari/tianshou/utils/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..c69ef71db4eb09528439b9718d7763b02c0822f0 --- /dev/null +++ b/examples/atari/tianshou/utils/optim.py @@ -0,0 +1,69 @@ +from collections.abc import Iterator +from typing import TypeVar + +import torch +from torch import nn + + +def optim_step( + loss: torch.Tensor, + optim: torch.optim.Optimizer, + module: nn.Module | None = None, + max_grad_norm: float | None = None, +) -> None: + """Perform a single optimization step: zero_grad -> backward (-> clip_grad_norm) -> step. + + :param loss: + :param optim: + :param module: the module to optimize, required if max_grad_norm is passed + :param max_grad_norm: if passed, will clip gradients using this + """ + optim.zero_grad() + loss.backward() + if max_grad_norm: + if not module: + raise ValueError( + "module must be passed if max_grad_norm is passed. " + "Note: often the module will be the policy, i.e.`self`", + ) + nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_grad_norm) + optim.step() + + +_STANDARD_TORCH_OPTIMIZERS = [ + torch.optim.Adam, + torch.optim.SGD, + torch.optim.RMSprop, + torch.optim.Adadelta, + torch.optim.AdamW, + torch.optim.Adamax, + torch.optim.NAdam, + torch.optim.SparseAdam, + torch.optim.LBFGS, +] + +TOptim = TypeVar("TOptim", bound=torch.optim.Optimizer) + + +def clone_optimizer( + optim: TOptim, + new_params: nn.Parameter | Iterator[nn.Parameter], +) -> TOptim: + """Clone an optimizer to get a new optim instance with new parameters. + + **WARNING**: This is a temporary measure, and should not be used in downstream code! + Once tianshou interfaces have moved to optimizer factories instead of optimizers, + this will be removed. 
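+
+    A minimal usage sketch (the linear modules below are illustrative assumptions)::
+
+        >>> import torch
+        >>> from torch import nn
+        >>> optim = torch.optim.Adam(nn.Linear(4, 2).parameters(), lr=3e-4)
+        >>> new_optim = clone_optimizer(optim, nn.Linear(4, 2).parameters())
+        >>> new_optim.defaults["lr"]
+        0.0003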
+ + :param optim: the optimizer to clone + :param new_params: the new parameters to use + :return: a new optimizer with the same configuration as the old one + """ + optim_class = type(optim) + # custom optimizers may not behave as expected + if optim_class not in _STANDARD_TORCH_OPTIMIZERS: + raise ValueError( + f"Cannot clone optimizer {optim} of type {optim_class}" + f"Currently, only standard torch optimizers are supported.", + ) + return optim_class(new_params, **optim.defaults) diff --git a/examples/atari/tianshou/utils/pickle.py b/examples/atari/tianshou/utils/pickle.py new file mode 100644 index 0000000000000000000000000000000000000000..92471622202bc8ec88189f62f98661b9569bf63f --- /dev/null +++ b/examples/atari/tianshou/utils/pickle.py @@ -0,0 +1,97 @@ +"""Helper functions for persistence/pickling, which have been copied from sensAI (specifically `sensai.util.pickle`).""" + +from collections.abc import Iterable +from copy import copy +from typing import Any + + +def setstate( + cls: type, + obj: Any, + state: dict[str, Any], + renamed_properties: dict[str, str] | None = None, + new_optional_properties: list[str] | None = None, + new_default_properties: dict[str, Any] | None = None, + removed_properties: list[str] | None = None, +) -> None: + """Helper function for safe implementations of `__setstate__` in classes, which appropriately handles the cases where + a parent class already implements `__setstate__` and where it does not. Call this function whenever you would actually + like to call the super-class' implementation. + Unfortunately, `__setstate__` is not implemented in `object`, rendering `super().__setstate__(state)` invalid in the general case. + + :param cls: the class in which you are implementing `__setstate__` + :param obj: the instance of `cls` + :param state: the state dictionary + :param renamed_properties: a mapping from old property names to new property names + :param new_optional_properties: a list of names of new property names, which, if not present, shall be initialized with None + :param new_default_properties: a dictionary mapping property names to their default values, which shall be added if they are not present + :param removed_properties: a list of names of properties that are no longer being used + """ + # handle new/changed properties + if renamed_properties is not None: + for mOld, mNew in renamed_properties.items(): + if mOld in state: + state[mNew] = state[mOld] + del state[mOld] + if new_optional_properties is not None: + for mNew in new_optional_properties: + if mNew not in state: + state[mNew] = None + if new_default_properties is not None: + for mNew, mValue in new_default_properties.items(): + if mNew not in state: + state[mNew] = mValue + if removed_properties is not None: + for p in removed_properties: + if p in state: + del state[p] + # call super implementation, if any + s = super(cls, obj) + if hasattr(s, "__setstate__"): + s.__setstate__(state) + else: + obj.__dict__ = state + + +def getstate( + cls: type, + obj: Any, + transient_properties: Iterable[str] | None = None, + excluded_properties: Iterable[str] | None = None, + override_properties: dict[str, Any] | None = None, + excluded_default_properties: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Helper function for safe implementations of `__getstate__` in classes, which appropriately handles the cases where + a parent class already implements `__getstate__` and where it does not. Call this function whenever you would actually + like to call the super-class' implementation. 
+ Unfortunately, `__getstate__` is not implemented in `object`, rendering `super().__getstate__()` invalid in the general case. + + :param cls: the class in which you are implementing `__getstate__` + :param obj: the instance of `cls` + :param transient_properties: transient properties which shall be set to None in serializations + :param excluded_properties: properties which shall be completely removed from serializations + :param override_properties: a mapping from property names to values specifying (new or existing) properties which are to be set; + use this to set a fixed value for an existing property or to add a completely new property + :param excluded_default_properties: properties which shall be completely removed from serializations, if they are set + to the given default value + :return: the state dictionary, which may be modified by the receiver + """ + s = super(cls, obj) + d = s.__getstate__() if hasattr(s, "__getstate__") else obj.__dict__ + d = copy(d) + if transient_properties is not None: + for p in transient_properties: + if p in d: + d[p] = None + if excluded_properties is not None: + for p in excluded_properties: + if p in d: + del d[p] + if override_properties is not None: + for k, v in override_properties.items(): + d[k] = v + if excluded_default_properties is not None: + for p, v in excluded_default_properties.items(): + if p in d and d[p] == v: + del d[p] + return d diff --git a/examples/atari/tianshou/utils/print.py b/examples/atari/tianshou/utils/print.py new file mode 100644 index 0000000000000000000000000000000000000000..88035ba40b0724df674ff69c04b1d365208e8add --- /dev/null +++ b/examples/atari/tianshou/utils/print.py @@ -0,0 +1,29 @@ +import pprint +from collections.abc import Sequence +from dataclasses import asdict, dataclass + + +@dataclass +class DataclassPPrintMixin: + def pprint_asdict(self, exclude_fields: Sequence[str] | None = None, indent: int = 4) -> None: + """Pretty-print the object as a dict, excluding specified fields. + + :param exclude_fields: A sequence of field names to exclude from the output. + If None, no fields are excluded. + :param indent: The indentation to use when pretty-printing. + """ + print(self.pprints_asdict(exclude_fields=exclude_fields, indent=indent)) + + def pprints_asdict(self, exclude_fields: Sequence[str] | None = None, indent: int = 4) -> str: + """String corresponding to pretty-print of the object as a dict, excluding specified fields. + + :param exclude_fields: A sequence of field names to exclude from the output. + If None, no fields are excluded. + :param indent: The indentation to use when pretty-printing. + """ + prefix = f"{self.__class__.__name__}\n----------------------------------------\n" + print_dict = asdict(self) + exclude_fields = exclude_fields or [] + for field in exclude_fields: + print_dict.pop(field, None) + return prefix + pprint.pformat(print_dict, indent=indent) diff --git a/examples/atari/tianshou/utils/progress_bar.py b/examples/atari/tianshou/utils/progress_bar.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3cd039d013de6d6a5e8c1b66dd14eb6cd5f18d --- /dev/null +++ b/examples/atari/tianshou/utils/progress_bar.py @@ -0,0 +1,35 @@ +from typing import Any + +tqdm_config = { + "dynamic_ncols": True, + "ascii": True, +} + + +class DummyTqdm: + """A dummy tqdm class that keeps stats but without progress bar. + + It supports ``__enter__`` and ``__exit__``, update and a dummy + ``set_postfix``, which is the interface that trainers use. + + .. 
note:: + + Using ``disable=True`` in tqdm config results in infinite loop, thus + this class is created. See the discussion at #641 for details. + """ + + def __init__(self, total: int, **kwargs: Any): + self.total = total + self.n = 0 + + def set_postfix(self, **kwargs: Any) -> None: + pass + + def update(self, n: int = 1) -> None: + self.n += n + + def __enter__(self) -> "DummyTqdm": + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + pass diff --git a/examples/atari/tianshou/utils/space_info.py b/examples/atari/tianshou/utils/space_info.py new file mode 100644 index 0000000000000000000000000000000000000000..f8b99053f3b4e91df25d69a53612aecbec528ac6 --- /dev/null +++ b/examples/atari/tianshou/utils/space_info.py @@ -0,0 +1,113 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Self + +import gymnasium as gym +import numpy as np +from gymnasium import spaces + +from tianshou.utils.string import ToStringMixin + + +@dataclass(kw_only=True) +class ActionSpaceInfo(ToStringMixin): + """A data structure for storing the different attributes of the action space.""" + + action_shape: int | Sequence[int] + """The shape of the action space.""" + min_action: float + """The smallest allowable action or in the continuous case the lower bound for allowable action value.""" + max_action: float + """The largest allowable action or in the continuous case the upper bound for allowable action value.""" + + @property + def action_dim(self) -> int: + """Return the number of distinct actions (must be greater than zero) an agent can take it its action space.""" + if isinstance(self.action_shape, int): + return self.action_shape + else: + return int(np.prod(self.action_shape)) + + @classmethod + def from_space(cls, space: spaces.Space) -> Self: + """Instantiate the `ActionSpaceInfo` object from a `Space`, supported spaces are Box and Discrete.""" + if isinstance(space, spaces.Box): + return cls( + action_shape=space.shape, + min_action=float(np.min(space.low)), + max_action=float(np.max(space.high)), + ) + elif isinstance(space, spaces.Discrete): + return cls( + action_shape=int(space.n), + min_action=float(space.start), + max_action=float(space.start + space.n - 1), + ) + else: + raise ValueError( + f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.", + ) + + def _tostring_additional_entries(self) -> dict[str, Any]: + return {"action_dim": self.action_dim} + + +@dataclass(kw_only=True) +class ObservationSpaceInfo(ToStringMixin): + """A data structure for storing the different attributes of the observation space.""" + + obs_shape: int | Sequence[int] + """The shape of the observation space.""" + + @property + def obs_dim(self) -> int: + """Return the number of distinct features (must be greater than zero) or dimensions in the observation space.""" + if isinstance(self.obs_shape, int): + return self.obs_shape + else: + return int(np.prod(self.obs_shape)) + + @classmethod + def from_space(cls, space: spaces.Space) -> Self: + """Instantiate the `ObservationSpaceInfo` object from a `Space`, supported spaces are Box and Discrete.""" + if isinstance(space, spaces.Box): + return cls( + obs_shape=space.shape, + ) + elif isinstance(space, spaces.Discrete): + return cls( + obs_shape=int(space.n), + ) + else: + raise ValueError( + f"Unsupported space type: {space.__class__}. 
Currently supported types are Discrete and Box.", + ) + + def _tostring_additional_entries(self) -> dict[str, Any]: + return {"obs_dim": self.obs_dim} + + +@dataclass(kw_only=True) +class SpaceInfo(ToStringMixin): + """A data structure for storing the attributes of both the action and observation space.""" + + action_info: ActionSpaceInfo + """Stores the attributes of the action space.""" + observation_info: ObservationSpaceInfo + """Stores the attributes of the observation space.""" + + @classmethod + def from_env(cls, env: gym.Env) -> Self: + """Instantiate the `SpaceInfo` object from `gym.Env.action_space` and `gym.Env.observation_space`.""" + return cls.from_spaces(env.action_space, env.observation_space) + + @classmethod + def from_spaces(cls, action_space: spaces.Space, observation_space: spaces.Space) -> Self: + """Instantiate the `SpaceInfo` object from `ActionSpaceInfo` and `ObservationSpaceInfo`.""" + action_info = ActionSpaceInfo.from_space(action_space) + observation_info = ObservationSpaceInfo.from_space(observation_space) + + return cls( + action_info=action_info, + observation_info=observation_info, + ) diff --git a/examples/atari/tianshou/utils/statistics.py b/examples/atari/tianshou/utils/statistics.py new file mode 100644 index 0000000000000000000000000000000000000000..779b0babec7dbe1a9fde27a8bb159441d04524cf --- /dev/null +++ b/examples/atari/tianshou/utils/statistics.py @@ -0,0 +1,114 @@ +from numbers import Number + +import numpy as np +import torch + + +class MovAvg: + """Class for moving average. + + It will automatically exclude the infinity and NaN. Usage: + :: + + >>> stat = MovAvg(size=66) + >>> stat.add(torch.tensor(5)) + 5.0 + >>> stat.add(float('inf')) # which will not add to stat + 5.0 + >>> stat.add([6, 7, 8]) + 6.5 + >>> stat.get() + 6.5 + >>> print(f'{stat.mean():.2f}±{stat.std():.2f}') + 6.50±1.12 + """ + + def __init__(self, size: int = 100) -> None: + super().__init__() + self.size = size + self.cache: list[np.number] = [] + self.banned = [np.inf, np.nan, -np.inf] + + def add( + self, + data_array: Number | float | np.number | list | np.ndarray | torch.Tensor, + ) -> float: + """Add a scalar into :class:`MovAvg`. + + You can add ``torch.Tensor`` with only one element, a python scalar, or + a list of python scalar. + """ + if isinstance(data_array, torch.Tensor): + data_array = data_array.flatten().cpu().numpy() + if np.isscalar(data_array): + data_array = [data_array] + for number in data_array: # type: ignore + if number not in self.banned: + self.cache.append(number) + if self.size > 0 and len(self.cache) > self.size: + self.cache = self.cache[-self.size :] + return self.get() + + def get(self) -> float: + """Get the average.""" + if len(self.cache) == 0: + return 0.0 + return float(np.mean(self.cache)) # type: ignore + + def mean(self) -> float: + """Get the average. Same as :meth:`get`.""" + return self.get() + + def std(self) -> float: + """Get the standard deviation.""" + if len(self.cache) == 0: + return 0.0 + return float(np.std(self.cache)) # type: ignore + + +class RunningMeanStd: + """Calculates the running mean and std of a data stream. + + https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + + :param mean: the initial mean estimation for data array. Default to 0. + :param std: the initial standard error estimation for data array. Default to 1. + :param clip_max: the maximum absolute value for data array. Default to + 10.0. + :param epsilon: To avoid division by zero. 
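+
+    A minimal usage sketch (the data below is an illustrative assumption)::
+
+        >>> import numpy as np
+        >>> rms = RunningMeanStd()
+        >>> rms.update(np.array([[1.0], [2.0], [3.0]]))
+        >>> float(rms.mean[0]), rms.count
+        (2.0, 3)
+        >>> z = rms.norm(np.array([[4.0]]))  # normalised and clipped to clip_max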
+ """ + + def __init__( + self, + mean: float | np.ndarray = 0.0, + std: float | np.ndarray = 1.0, + clip_max: float | None = 10.0, + epsilon: float = np.finfo(np.float32).eps.item(), + ) -> None: + self.mean, self.var = mean, std + self.clip_max = clip_max + self.count = 0 + self.eps = epsilon + + def norm(self, data_array: float | np.ndarray) -> float | np.ndarray: + data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps) + if self.clip_max: + data_array = np.clip(data_array, -self.clip_max, self.clip_max) + return data_array + + def update(self, data_array: np.ndarray) -> None: + """Add a batch of item into RMS with the same shape, modify mean/var/count.""" + batch_mean, batch_var = np.mean(data_array, axis=0), np.var(data_array, axis=0) + batch_count = len(data_array) + + delta = batch_mean - self.mean + total_count = self.count + batch_count + + new_mean = self.mean + delta * batch_count / total_count + m_a = self.var * self.count + m_b = batch_var * batch_count + m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count + new_var = m_2 / total_count + + self.mean, self.var = new_mean, new_var + self.count = total_count diff --git a/examples/atari/tianshou/utils/string.py b/examples/atari/tianshou/utils/string.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d479236ded315dc80a57e135226f1897f0e47f --- /dev/null +++ b/examples/atari/tianshou/utils/string.py @@ -0,0 +1,536 @@ +"""Copy of sensai.util.string from sensAI """ +# From commit commit d7b4afcc89b4d2e922a816cb07dffde27f297354 + + +import functools +import logging +import re +import sys +import types +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Mapping, Sequence +from typing import ( + Any, + Self, +) + +reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+") + +log = logging.getLogger(__name__) + +# ruff: noqa + + +class StringConverter(ABC): + """Abstraction for a string conversion mechanism.""" + + @abstractmethod + def to_string(self, x: Any) -> str: + pass + + +def dict_string( + d: Mapping, brackets: str | None = None, converter: StringConverter | None = None +) -> str: + """Converts a dictionary to a string of the form "=, =, ...", optionally enclosed + by brackets. + + :param d: the dictionary + :param brackets: a two-character string containing the opening and closing bracket to use, e.g. ``"{}"``; + if None, do not use enclosing brackets + :param converter: the string converter to use for values + :return: the string representation + """ + s = ", ".join([f"{k}={to_string(v, converter=converter, context=k)}" for k, v in d.items()]) + if brackets is not None: + return brackets[:1] + s + brackets[-1:] + else: + return s + + +def list_string( + l: Iterable[Any], + brackets: str | None = "[]", + quote: str | None = None, + converter: StringConverter | None = None, +) -> str: + """Converts a list or any other iterable to a string of the form "[, , ...]", optionally enclosed + by different brackets or with the values quoted. + + :param l: the list + :param brackets: a two-character string containing the opening and closing bracket to use, e.g. ``"[]"``; + if None, do not use enclosing brackets + :param quote: a 1-character string defining the quote to use around each value, e.g. ``"'"``. 
+ :param converter: the string converter to use for values + :return: the string representation + """ + + def item(x: Any) -> str: + x = to_string(x, converter=converter, context="list") + if quote is not None: + return quote + x + quote + else: + return x + + s = ", ".join(item(x) for x in l) + if brackets is not None: + return brackets[:1] + s + brackets[-1:] + else: + return s + + +def to_string( + x: Any, + converter: StringConverter | None = None, + apply_converter_to_non_complex_objects: bool = True, + context: Any = None, +) -> str: + """Converts the given object to a string, with proper handling of lists, tuples and dictionaries, optionally using a converter. + The conversion also removes unwanted line breaks (as present, in particular, in sklearn's string representations). + + :param x: the object to convert + :param converter: the converter with which to convert objects to strings + :param apply_converter_to_non_complex_objects: whether to apply/pass on the converter (if any) not only when converting complex objects + but also non-complex, primitive objects; use of this flag enables converters to implement their conversion functionality using this + function for complex objects without causing an infinite recursion. + :param context: context in which the object is being converted (e.g. dictionary key for case where x is the corresponding + dictionary value), only for debugging purposes (will be reported in log messages upon recursion exception) + :return: the string representation + """ + try: + if isinstance(x, list): + return list_string(x, converter=converter) + elif isinstance(x, tuple): + return list_string(x, brackets="()", converter=converter) + elif isinstance(x, dict): + return dict_string(x, brackets="{}", converter=converter) + elif isinstance(x, types.MethodType): + # could be bound method of a ToStringMixin instance (which would print the repr of the instance, which can potentially cause + # an infinite recursion) + return f"Method[{x.__name__}]" + else: + if converter and apply_converter_to_non_complex_objects: + s = converter.to_string(x) + else: + s = str(x) + + # remove any unwanted line breaks and indentation after commas (as generated, for example, by sklearn objects) + return reCommaWhitespacePotentiallyBreaks.sub(", ", s) + + except RecursionError: + log.error(f"Recursion in string conversion detected; context={context}") + raise + + +def object_repr(obj: Any, member_names_or_dict: list[str] | dict[str, Any]) -> str: + """Creates a string representation for the given object based on the given members. + + The string takes the form "ClassName[attr1=value1, attr2=value2, ...]" + """ + if isinstance(member_names_or_dict, dict): + members_dict = member_names_or_dict + else: + members_dict = {m: to_string(getattr(obj, m)) for m in member_names_or_dict} + return f"{obj.__class__.__name__}[{dict_string(members_dict)}]" + + +def or_regex_group(allowed_names: Sequence[str]) -> str: + """:param allowed_names: strings to include as literals in the regex + :return: a regular expression string of the form `(| ...|)`, which any of the given names + """ + allowed_names = [re.escape(name) for name in allowed_names] + return r"(%s)" % "|".join(allowed_names) + + +def function_name(x: Callable) -> str: + """Attempts to retrieve the name of the given function/callable object, taking the possibility + of the function being defined via functools.partial into account. 
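+
+    A minimal usage sketch (``train`` is a stand-in for any function)::
+
+        >>> import functools
+        >>> def train(lr): ...
+        >>> function_name(functools.partial(train, lr=1e-3))
+        'train'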
+
+    :param x: a callable object
+    :return: name of the function or str(x) as a fallback
+    """
+    if isinstance(x, functools.partial):
+        return function_name(x.func)
+    elif hasattr(x, "__name__"):
+        return x.__name__
+    else:
+        return str(x)
+
+
+class ToStringMixin:
+    """Provides implementations for ``__str__`` and ``__repr__`` which are based on the format ``"<class name>[<object info>]"`` and
+    ``"<class name>[id=<object id>, <object info>]"`` respectively, where ``<object info>`` is usually a list of entries of the
+    form ``"<attr>=<value>, ..."``.
+
+    By default, ``<class name>`` will be the qualified name of the class, and ``<object info>`` will include all properties
+    of the class, including private ones starting with an underscore (though the underscore will be dropped in the string
+    representation).
+
+    * To exclude private properties, override :meth:`_toStringExcludePrivate` to return True. If there are exceptions
+      (and some private properties shall be retained), additionally override :meth:`_toStringExcludeExceptions`.
+    * To exclude a particular set of properties, override :meth:`_toStringExcludes`.
+    * To include only select properties (introducing inclusion semantics), override :meth:`_toStringIncludes`.
+    * To add values to the properties list that aren't actually properties of the object (i.e. derived properties),
+      override :meth:`_toStringAdditionalEntries`.
+    * To define a fully custom representation for ``<object info>`` which is not based on the above principles, override
+      :meth:`_toStringObjectInfo`.
+
+    For well-defined string conversions within a class hierarchy, it can be a good practice to define additional
+    inclusions/exclusions by overriding the respective method once more and basing the return value on an extended
+    version of the value returned by the superclass.
+    In some cases, the requirements of a subclass can be at odds with the definitions in the superclass: The superclass
+    may make use of exclusion semantics, but the subclass may want to use inclusion semantics (and include
+    only some of the many properties it adds). In this case, if the subclass used :meth:`_toStringIncludes`, the exclusion semantics
+    of the superclass would be void and none of its properties would actually be included.
+    In such cases, override :meth:`_toStringIncludesForced` to add inclusions regardless of the semantics otherwise used along
+    the class hierarchy.
+
+    """
+
+    _TOSTRING_INCLUDE_ALL = "__all__"
+
+    def _tostring_class_name(self) -> str:
+        """:return: the string to use for <class name> in the string representation ``"<class name>[<object info>]"``"""
+        return type(self).__qualname__
+
+    def _tostring_properties(
+        self,
+        exclude: str | Iterable[str] | None = None,
+        include: str | Iterable[str] | None = None,
+        exclude_exceptions: list[str] | None = None,
+        include_forced: list[str] | None = None,
+        additional_entries: dict[str, Any] | None = None,
+        converter: StringConverter | None = None,
+    ) -> str:
+        """Creates a string of the class attributes, with optional exclusions/inclusions/additions.
+        Exclusions take precedence over inclusions.
+ + :param exclude: attributes to be excluded + :param include: attributes to be included; if non-empty, only the specified attributes will be printed (bar the ones + excluded by ``exclude``) + :param include_forced: additional attributes to be included + :param additional_entries: additional key-value entries to be added + :param converter: the string converter to use; if None, use default (which avoids infinite recursions) + :return: a string containing entry/property names and values + """ + + def mklist(x: Any) -> list[str]: + if x is None: + return [] + if isinstance(x, str): + return [x] + return x + + exclude = mklist(exclude) + include = mklist(include) + include_forced = mklist(include_forced) + exclude_exceptions = mklist(exclude_exceptions) + + def is_excluded(k: Any) -> bool: + if k in include_forced or k in exclude_exceptions: + return False + if k in exclude: + return True + if self._tostring_exclude_private(): + return k.startswith("_") + else: + return False + + # determine relevant attribute dictionary + if ( + len(include) == 1 and include[0] == self._TOSTRING_INCLUDE_ALL + ): # exclude semantics (include everything by default) + attribute_dict = self.__dict__ + else: # include semantics (include only inclusions) + attribute_dict = { + k: getattr(self, k) + for k in set(include + include_forced) + if hasattr(self, k) and k != self._TOSTRING_INCLUDE_ALL + } + + # apply exclusions and remove underscores from attribute names + d = {k.strip("_"): v for k, v in attribute_dict.items() if not is_excluded(k)} + + if additional_entries is not None: + d.update(additional_entries) + + if converter is None: + converter = self._StringConverterAvoidToStringMixinRecursion(self) + return dict_string(d, converter=converter) + + def _tostring_object_info(self) -> str: + """Override this method to use a fully custom definition of the ```` part in the full string + representation ``"[]"`` to be generated. + As soon as this method is overridden, any property-based exclusions, inclusions, etc. will have no effect + (unless the implementation is specifically designed to make use of them - as is the default + implementation). + NOTE: Overrides must not internally use super() because of a technical limitation in the proxy + object that is used for nested object structures. + + :return: a string containing the string to use for ```` + """ + return self._tostring_properties( + exclude=self._tostring_excludes(), + include=self._tostring_includes(), + exclude_exceptions=self._tostring_exclude_exceptions(), + include_forced=self._tostring_includes_forced(), + additional_entries=self._tostring_additional_entries(), + ) + + def _tostring_excludes(self) -> list[str]: + """Makes the string representation exclude the returned attributes. + This method can be conveniently overridden by subclasses which can call super and extend the list returned. + + This method will only have no effect if :meth:`_toStringObjectInfo` is overridden to not use its result. + + :return: a list of attribute names + """ + return [] + + def _tostring_includes(self) -> list[str]: + """Makes the string representation include only the returned attributes (i.e. introduces inclusion semantics); + By default, the list contains only a marker element, which is interpreted as "all attributes included". + + This method can be conveniently overridden by sub-classes which can call super and extend the list returned. 
+ Note that it is not a problem for a list containing the aforementioned marker element (which stands for all attributes) + to be extended; the marker element will be ignored and only the user-added elements will be considered as included. + + Note: To add an included attribute in a sub-class, regardless of any super-classes using exclusion or inclusion semantics, + use _toStringIncludesForced instead. + + This method will have no effect if :meth:`_toStringObjectInfo` is overridden to not use its result. + + :return: a list of attribute names to be included in the string representation + """ + return [self._TOSTRING_INCLUDE_ALL] + + # noinspection PyMethodMayBeStatic + def _tostring_includes_forced(self) -> list[str]: + """Defines a list of attribute names that are required to be present in the string representation, regardless of the + instance using include semantics or exclude semantics, thus facilitating added inclusions in sub-classes. + + This method will have no effect if :meth:`_toStringObjectInfo` is overridden to not use its result. + + :return: a list of attribute names + """ + return [] + + def _tostring_additional_entries(self) -> dict[str, Any]: + """:return: a dictionary of entries to be included in the ```` part of the string representation""" + return {} + + def _tostring_exclude_private(self) -> bool: + """:return: whether to exclude properties that are private (start with an underscore); explicitly included attributes + will still be considered - as will properties exempt from the rule via :meth:`toStringExcludeException`. + """ + return False + + def _tostring_exclude_exceptions(self) -> list[str]: + """Defines attribute names which should not be excluded even though other rules (particularly the exclusion of private members + via :meth:`_toStringExcludePrivate`) would otherwise exclude them. + + :return: a list of attribute names + """ + return [] + + def __str__(self) -> str: + return f"{self._tostring_class_name()}[{self._tostring_object_info()}]" + + def __repr__(self) -> str: + info = f"id={id(self)}" + property_info = self._tostring_object_info() + if len(property_info) > 0: + info += ", " + property_info + return f"{self._tostring_class_name()}[{info}]" + + def pprint(self, file: Any = sys.stdout) -> None: + """Prints a prettily formatted string representation of the object (with line breaks and indentations) + to ``stdout`` or the given file. + + :param file: the file to print to + """ + print(self.pprints(), file=file) + + def pprints(self) -> str: + """:return: a prettily formatted string representation with line breaks and indentations""" + return pretty_string_repr(self) + + class _StringConverterAvoidToStringMixinRecursion(StringConverter): + """Avoids recursions when converting objects implementing :class:`ToStringMixin` which may contain themselves to strings. + Use of this object prevents infinite recursions caused by a :class:`ToStringMixin` instance recursively containing itself in + either a property of another :class:`ToStringMixin`, a list or a tuple. + It handles all :class:`ToStringMixin` instances recursively encountered. + + A previously handled instance is converted to a string of the form "[<<]". 
+ """ + + def __init__(self, *handled_objects: "ToStringMixin"): + """:param handled_objects: objects which are initially assumed to have been handled already""" + self._handled_to_string_mixin_ids = {id(o) for o in handled_objects} + + def to_string(self, x: Any) -> str: + if isinstance(x, ToStringMixin): + oid = id(x) + if oid in self._handled_to_string_mixin_ids: + return f"{x._tostring_class_name()}[<<]" + self._handled_to_string_mixin_ids.add(oid) + return str(self._ToStringMixinProxy(x, self)) + else: + return to_string( + x, + converter=self, + apply_converter_to_non_complex_objects=False, + context=x.__class__, + ) + + class _ToStringMixinProxy: + """A proxy object which wraps a ToStringMixin to ensure that the converter is applied when creating the properties string. + The proxy is to achieve that all ToStringMixin methods that aren't explicitly overwritten are bound to this proxy + (rather than the original object), such that the transitive call to _toStringProperties will call the new + implementation. + """ + + # methods where we assume that they could transitively call _toStringProperties (others are assumed not to) + TOSTRING_METHODS_TRANSITIVELY_CALLING_TOSTRINGPROPERTIES = {"_tostring_object_info"} + + def __init__(self, x: "ToStringMixin", converter: Any) -> None: + self.x = x + self.converter = converter + + def _tostring_properties(self, *args: Any, **kwargs: Any) -> str: + return self.x._tostring_properties(*args, **kwargs, converter=self.converter) # type: ignore[misc] + + def _tostring_class_name(self) -> str: + return self.x._tostring_class_name() + + def __getattr__(self, attr: str) -> Any: + if attr.startswith( + "_tostring", + ): # ToStringMixin method which we may bind to use this proxy to ensure correct transitive call + method = getattr(self.x.__class__, attr) + obj = ( + self + if attr in self.TOSTRING_METHODS_TRANSITIVELY_CALLING_TOSTRINGPROPERTIES + else self.x + ) + return lambda *args, **kwargs: method(obj, *args, **kwargs) + else: + return getattr(self.x, attr) + + def __str__(self) -> str: + return ToStringMixin.__str__(self) # type: ignore[arg-type] + + +def pretty_string_repr( + s: Any, initial_indentation_level: int = 0, indentation_string: str = " " +) -> str: + """Creates a pretty string representation (using indentations) from the given object/string representation (as generated, for example, via + ToStringMixin). An indentation level is added for every opening bracket. 
+ + :param s: an object or object string representation + :param initial_indentation_level: the initial indentation level + :param indentation_string: the string which corresponds to a single indentation level + :return: a reformatted version of the input string with added indentations and line breaks + """ + if not isinstance(s, str): + s = str(s) + indent = initial_indentation_level + result = indentation_string * indent + i = 0 + + def nl() -> None: + nonlocal result + result += "\n" + (indentation_string * indent) + + def take(cnt: int = 1) -> None: + nonlocal result, i + result += s[i : i + cnt] + i += cnt + + def find_matching(j: int) -> int | None: + start = j + op = s[j] + cl = {"[": "]", "(": ")", "'": "'"}[s[j]] + is_bracket = cl != s[j] + stack = 0 + while j < len(s): + if s[j] == op and (is_bracket or j == start): + stack += 1 + elif s[j] == cl: + stack -= 1 + if stack == 0: + return j + j += 1 + return None + + brackets = "[(" + quotes = "'" + while i < len(s): + is_bracket = s[i] in brackets + is_quote = s[i] in quotes + if is_bracket or is_quote: + i_match = find_matching(i) + take_full_match_without_break = False + if i_match is not None: + k = i_match + 1 + full_match = s[i:k] + take_full_match_without_break = is_quote or not ( + "=" in full_match and "," in full_match + ) + if take_full_match_without_break: + take(k - i) + if not take_full_match_without_break: + take(1) + indent += 1 + nl() + elif s[i] in "])": + take(1) + indent -= 1 + elif s[i : i + 2] == ", ": + take(2) + nl() + else: + take(1) + + return result + + +class TagBuilder: + """Assists in building strings made up of components that are joined via a glue string.""" + + def __init__(self, *initial_components: str, glue: str = "_"): + """:param initial_components: initial components to always include at the beginning + :param glue: the glue string which joins components + """ + self.glue = glue + self.components = list(initial_components) + + def with_component(self, component: str) -> Self: + self.components.append(component) + return self + + def with_conditional(self, cond: bool, component: str) -> Self: + """Conditionally adds the given component. + + :param cond: the condition + :param component: the component to add if the condition holds + :return: the builder + """ + if cond: + self.components.append(component) + return self + + def with_alternative(self, cond: bool, true_component: str, false_component: str) -> Self: + """Adds a component depending on a condition. 
+ + :param cond: the condition + :param true_component: the component to add if the condition holds + :param false_component: the component to add if the condition does not hold + :return: the builder + """ + self.components.append(true_component if cond else false_component) + return self + + def build(self) -> str: + """:return: the string (with all components joined)""" + return self.glue.join(self.components) diff --git a/examples/atari/tianshou/utils/torch_utils.py b/examples/atari/tianshou/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..430d174e70f6f0447f44f5a6293532888aaeb242 --- /dev/null +++ b/examples/atari/tianshou/utils/torch_utils.py @@ -0,0 +1,39 @@ +from collections.abc import Iterator +from contextlib import contextmanager +from typing import TYPE_CHECKING + +from torch import nn + +if TYPE_CHECKING: + from tianshou.policy import BasePolicy + + +@contextmanager +def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]: + """Temporarily switch to `module.training=enabled`, affecting things like `BatchNormalization`.""" + original_mode = module.training + try: + module.train(enabled) + yield + finally: + module.train(original_mode) + + +@contextmanager +def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]: + """Temporarily switch to `policy.is_within_training_step=enabled`. + + Enabling this ensures that the policy is able to adapt its behavior, + allowing it to differentiate between training and inference/evaluation, + e.g., to sample actions instead of using the most probable action (where applicable) + Note that for rollout, which also happens within a training step, one would usually want + the wrapped torch module to be in evaluation mode, which can be achieved using + `with torch_train_mode(policy, False)`. For subsequent gradient updates, the policy should be both + within training step and in torch train mode. 
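+
+    A minimal usage sketch (``policy`` stands for any ``BasePolicy`` instance; the two
+    phases shown are illustrative assumptions)::
+
+        with policy_within_training_step(policy), torch_train_mode(policy, False):
+            ...  # collect rollouts: training-step semantics, but torch eval mode
+        with policy_within_training_step(policy), torch_train_mode(policy):
+            ...  # gradient updates: training-step semantics and torch train mode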
+ """ + original_mode = policy.is_within_training_step + try: + policy.is_within_training_step = enabled + yield + finally: + policy.is_within_training_step = original_mode diff --git a/examples/atari/tianshou/utils/warning.py b/examples/atari/tianshou/utils/warning.py new file mode 100644 index 0000000000000000000000000000000000000000..93c5ccec38f2f1dd83d616d77744e0bc0bdb4cb7 --- /dev/null +++ b/examples/atari/tianshou/utils/warning.py @@ -0,0 +1,8 @@ +import warnings + +warnings.simplefilter("once", DeprecationWarning) + + +def deprecation(msg: str) -> None: + """Deprecation warning wrapper.""" + warnings.warn(msg, category=DeprecationWarning, stacklevel=2) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..503cf45873fc8c77278515dd6c77f4dd6ec853eb Binary files /dev/null and b/requirements.txt differ diff --git a/video-app/rl-video-episode-0.meta.json b/video-app/rl-video-episode-0.meta.json new file mode 100644 index 0000000000000000000000000000000000000000..03b38651743b8a1da81ca17b45a00ccc4b8a7497 --- /dev/null +++ b/video-app/rl-video-episode-0.meta.json @@ -0,0 +1 @@ +{"step_id": 0, "episode_id": 0, "content_type": "video/mp4"} \ No newline at end of file diff --git a/video-app/rl-video-episode-0.mp4 b/video-app/rl-video-episode-0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..34d7531f5f88f22b1d93129a8dc828a24d28e9b3 Binary files /dev/null and b/video-app/rl-video-episode-0.mp4 differ diff --git a/video-app/rl-video-episode-1.meta.json b/video-app/rl-video-episode-1.meta.json new file mode 100644 index 0000000000000000000000000000000000000000..2e8a6a7e196936f1715e0bbfe1c96acd3e4a4add --- /dev/null +++ b/video-app/rl-video-episode-1.meta.json @@ -0,0 +1 @@ +{"step_id": 2045, "episode_id": 1, "content_type": "video/mp4"} \ No newline at end of file diff --git a/video-app/rl-video-episode-1.mp4 b/video-app/rl-video-episode-1.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a81a999501c2c4902f81ca455e70aa5eebfbf051 Binary files /dev/null and b/video-app/rl-video-episode-1.mp4 differ