import gradio as gr
import gymnasium as gym
import numpy as np
import torch

from examples.atari.atari_network import C51, DQN
from examples.atari.atari_wrapper import wrap_deepmind
from examples.atari.tianshou.env.venvs import DummyVectorEnv
from examples.atari.tianshou.policy import C51Policy, FQFPolicy, FQF_RainbowPolicy
from examples.atari.tianshou.utils.net.discrete import (
    FractionProposalNetwork,
    FullQuantileFunction,
    FullQuantileFunctionRainbow,
)
from tianshou.data import Collector
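
# Gradio demo: load pre-trained distributional RL agents (C51, FQF, FQF-Rainbow)
# built with Tianshou, roll out one recorded evaluation episode on an Atari game,
# and report the mean episode return alongside the replay video.

# Seed helper for NumPy's global RNG (Torch is seeded inside each test function).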
def seed(seed: int) -> None:
    np.random.seed(seed)
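
# Evaluation settings for the C51 (Categorical DQN) agent.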
config_c51 = {
    "task": "PongNoFrameskip-v4",
    "seed": 3128,
    "scale_obs": 0,
    "eps_test": 0.005,
    "eps_train": 1.0,
    "eps_train_final": 0.05,
    "buffer_size": 100000,
    "lr": 0.0001,
    "gamma": 0.99,
    "num_atoms": 51,
    "v_min": -10.0,
    "v_max": 10.0,
    "n_step": 3,
    "target_update_freq": 500,
    "epoch": 100,
    "step_per_epoch": 100000,
    "step_per_collect": 10,
    "update_per_step": 0.1,
    "batch_size": 32,
    "training_num": 1,
    "test_num": 1,
    "logdir": "log",
    "render": 0.0,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "frames_stack": 4,
    "resume_path": "examples/atari/c51_pong.pth",
    "resume_id": "",
    "logger": "tensorboard",
    "wandb_project": "atari.benchmark",
    "watch": True,
    "save_buffer_name": None,
}
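
# Evaluation settings for the FQF (Fully-parameterized Quantile Function) agent.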
config_fqf = {
    "task": "SpaceInvadersNoFrameskip-v4",
    "seed": 3128,
    "scale_obs": 0,
    "eps_test": 0.005,
    "eps_train": 1.0,
    "eps_train_final": 0.05,
    "buffer_size": 100000,
    "lr": 5e-5,
    "fraction_lr": 2.5e-9,
    "gamma": 0.99,
    "num_fractions": 32,
    "num_cosines": 64,
    "ent_coef": 10.0,
    "hidden_sizes": [512],
    "n_step": 3,
    "target_update_freq": 500,
    "epoch": 100,
    "step_per_epoch": 100000,
    "step_per_collect": 10,
    "update_per_step": 0.1,
    "batch_size": 32,
    "training_num": 1,
    "test_num": 1,
    "logdir": "log",
    "render": 0.0,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "frames_stack": 4,
    "resume_path": "fqf_pong.pth",
    "resume_id": None,
    "logger": "tensorboard",
    "wandb_project": "atari.benchmark",
    "watch": True,
    "save_buffer_name": None,
}
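
# Evaluation settings for the FQF agent combined with Rainbow components
# (noisy layers, dueling heads and prioritized-replay flags).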
config_fqf_r = {
    "task": "PongNoFrameskip-v4",
    "algo_name": "RainbowFQF",
    "seed": 3128,
    "scale_obs": 0,
    "eps_test": 0.005,
    "eps_train": 1.0,
    "eps_train_final": 0.05,
    "buffer_size": 100000,
    "lr": 5e-5,
    "fraction_lr": 2.5e-9,
    "gamma": 0.99,
    "num_fractions": 32,
    "num_cosines": 64,
    "ent_coef": 10.0,
    "hidden_sizes": [512],
    "n_step": 3,
    "target_update_freq": 500,
    "epoch": 100,
    "step_per_epoch": 100000,
    "step_per_collect": 10,
    "update_per_step": 0.1,
    "batch_size": 32,
    "training_num": 1,
    "test_num": 1,
    "logdir": "log",
    "no_dueling": False,
    "no_noisy": False,
    "no_priority": False,
    "noisy_std": 0.1,
    "alpha": 0.5,
    "beta": 0.4,
    "beta_final": 1.0,
    "beta_anneal_step": 5000000,
    "no_weight_norm": False,
    "render": 0.0,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "frames_stack": 4,
    "resume_path": None,
    "resume_id": None,
    "logger": "tensorboard",
    "wandb_project": "atari.benchmark",
    "watch": False,
    "save_buffer_name": None,
    "per": False,
}
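
# Run one recorded evaluation episode with a pre-trained C51 policy and return
# the statistics collected by the Tianshou Collector.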
def test_c51(config: dict):
    # Build the Atari environment, apply DeepMind preprocessing and record the rollout.
    env_wrap = gym.make(config["task"], render_mode="rgb_array")
    env_wrap.action_space.seed(config["seed"])
    env_deep = wrap_deepmind(env_wrap)
    rec_env = DummyVectorEnv(
        [lambda: gym.wrappers.RecordVideo(env_deep, video_folder="video-app/")]
    )

    state_shape = env_deep.observation_space.shape or env_deep.observation_space.n
    action_shape = env_deep.action_space.shape or env_deep.action_space.n
    print("Observations shape:", state_shape)
    print("Actions shape:", action_shape)

    np.random.seed(config["seed"])
    torch.manual_seed(config["seed"])
    print("seed is", config["seed"])

    # Categorical DQN network and policy.
    net = C51(*state_shape, action_shape, config["num_atoms"], config["device"])
    optim = torch.optim.Adam(net.parameters(), lr=config["lr"])

    policy = C51Policy(
        model=net,
        optim=optim,
        discount_factor=config["gamma"],
        action_space=env_deep.action_space,
        num_atoms=config["num_atoms"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        estimation_step=config["n_step"],
        target_update_freq=config["target_update_freq"],
    ).to(config["device"])

    # Load the pre-trained weights.
    if config["resume_path"]:
        policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"]))
        print("Loaded agent from:", config["resume_path"])

    # Collect the evaluation episode(s) and report the statistics.
    collector = Collector(policy, rec_env, exploration_noise=True)
    result = collector.collect(n_episode=config["test_num"])

    rec_env.close()
    result.pprint_asdict()
    return result
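
# Same evaluation rollout as test_c51, but with a pre-trained FQF policy
# (full quantile function plus fraction proposal network).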
def test_FQF(config: dict):
    # Build the recorded evaluation environment as in test_c51.
    env_wrap = gym.make(config["task"], render_mode="rgb_array")
    env_wrap.action_space.seed(config["seed"])
    env_deep = wrap_deepmind(env_wrap)
    rec_env = DummyVectorEnv(
        [lambda: gym.wrappers.RecordVideo(env_deep, video_folder="video-app/")]
    )

    state_shape = env_deep.observation_space.shape or env_deep.observation_space.n
    action_shape = env_deep.action_space.shape or env_deep.action_space.n
    print("Observations shape:", state_shape)
    print("Actions shape:", action_shape)

    print(config["seed"])

    # Convolutional feature extractor, full quantile function and fraction
    # proposal network that make up the FQF agent.
    feature_net = DQN(*state_shape, action_shape, config["device"], features_only=True)

    net = FullQuantileFunction(
        feature_net,
        action_shape,
        config["hidden_sizes"],
        config["num_cosines"],
    ).to(config["device"])

    optim = torch.optim.Adam(net.parameters(), lr=config["lr"])

    fraction_net = FractionProposalNetwork(config["num_fractions"], net.input_dim)
    fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=config["fraction_lr"])

    policy: FQFPolicy = FQFPolicy(
        model=net,
        optim=optim,
        fraction_model=fraction_net,
        fraction_optim=fraction_optim,
        action_space=env_deep.action_space,
        discount_factor=config["gamma"],
        num_fractions=config["num_fractions"],
        ent_coef=config["ent_coef"],
        estimation_step=config["n_step"],
        target_update_freq=config["target_update_freq"],
    ).to(config["device"])

    # Load the pre-trained weights.
    if config["resume_path"]:
        policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"]))
        print("Loaded agent from:", config["resume_path"])

    # Collect the evaluation episode(s) and report the statistics.
    collector = Collector(policy, rec_env, exploration_noise=True)
    result = collector.collect(n_episode=config["test_num"])

    rec_env.close()
    result.pprint_asdict()
    return result
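
# Evaluation rollout for the FQF-Rainbow variant; noisy layers and dueling
# heads are toggled through the config flags.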
def test_fqf_rainbow(config: dict):
    # Build the recorded evaluation environment as in the other test functions.
    env_wrap = gym.make(config["task"], render_mode="rgb_array")
    env_wrap.action_space.seed(config["seed"])
    env_deep = wrap_deepmind(env_wrap)
    rec_env = DummyVectorEnv(
        [lambda: gym.wrappers.RecordVideo(env_deep, video_folder="video-app/")]
    )

    config["state_shape"] = env_deep.observation_space.shape or env_deep.observation_space.n
    config["action_shape"] = env_deep.action_space.shape or env_deep.action_space.n

    print(config["seed"])

    # Shared convolutional feature extractor followed by the Rainbow-style
    # full quantile function.
    feature_net = DQN(*config["state_shape"], config["action_shape"], config["device"], features_only=True)
    preprocess_net_output_dim = feature_net.output_dim

    net = FullQuantileFunctionRainbow(
        preprocess_net=feature_net,
        action_shape=config["action_shape"],
        hidden_sizes=config["hidden_sizes"],
        num_cosines=config["num_cosines"],
        preprocess_net_output_dim=preprocess_net_output_dim,
        device=config["device"],
        noisy_std=config["noisy_std"],
        is_noisy=not config["no_noisy"],
        is_dueling=not config["no_dueling"],
    ).to(config["device"])

    optim = torch.optim.Adam(net.parameters(), lr=config["lr"])
    fraction_net = FractionProposalNetwork(config["num_fractions"], net.input_dim)
    fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=config["fraction_lr"])

    policy: FQF_RainbowPolicy = FQF_RainbowPolicy(
        model=net,
        optim=optim,
        fraction_model=fraction_net,
        fraction_optim=fraction_optim,
        action_space=env_deep.action_space,
        discount_factor=config["gamma"],
        num_fractions=config["num_fractions"],
        ent_coef=config["ent_coef"],
        estimation_step=config["n_step"],
        target_update_freq=config["target_update_freq"],
        is_noisy=not config["no_noisy"],
    ).to(config["device"])

    # Load the pre-trained weights.
    if config["resume_path"]:
        policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"]))
        print("Loaded agent from:", config["resume_path"])

    # Collect the evaluation episode(s) and report the statistics.
    test_collector = Collector(policy, rec_env, exploration_noise=True)
    result = test_collector.collect(n_episode=config["test_num"])

    rec_env.close()
    result.pprint_asdict()
    return result
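
# Gradio callback: map the UI selections to a config and checkpoint, run the
# evaluation, and return the mean score plus the path of the recorded video.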
def display_choice(algo, game, slider):
    # The slider value may arrive as a float, but the seeding APIs expect an int.
    seed_value = int(slider)

    match algo:
        case "C51":
            config_c51["seed"] = seed_value
            match game:
                case "Freeway":
                    config_c51["resume_path"] = "models/c51_freeway.pth"
                    config_c51["task"] = "FreewayNoFrameskip-v4"
                    mean_scores = test_c51(config_c51)
                case "Pong":
                    config_c51["resume_path"] = "models/c51_pong.pth"
                    config_c51["task"] = "PongNoFrameskip-v4"
                    mean_scores = test_c51(config_c51)

        case "FQF":
            config_fqf["seed"] = seed_value
            match game:
                case "Freeway":
                    config_fqf["resume_path"] = "models/fqf_freeway.pth"
                    config_fqf["task"] = "FreewayNoFrameskip-v4"
                    mean_scores = test_FQF(config_fqf)
                case "Pong":
                    config_fqf["resume_path"] = "models/fqf_pong.pth"
                    config_fqf["task"] = "PongNoFrameskip-v4"
                    mean_scores = test_FQF(config_fqf)

        case "FQF-Rainbow":
            config_fqf_r["seed"] = seed_value
            match game:
                case "Freeway":
                    config_fqf_r["resume_path"] = "models/fqf-rainbow_freeway.pth"
                    config_fqf_r["task"] = "FreewayNoFrameskip-v4"
                    mean_scores = test_fqf_rainbow(config_fqf_r)
                case "Pong":
                    config_fqf_r["resume_path"] = "models/fqf-rainbow_pong.pth"
                    config_fqf_r["task"] = "PongNoFrameskip-v4"
                    mean_scores = test_fqf_rainbow(config_fqf_r)

    # Mean episode return over the collected test episodes, plus the replay
    # written by RecordVideo.
    mean_score = mean_scores.returns_stat.mean
    return [mean_score, "video-app/rl-video-episode-0.mp4"]
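
# UI choices and the Gradio interface wiring.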
algos = ["C51", "FQF", "FQF-Rainbow"] |
|
|
|
games = ["Freeway","Pong"] |
|
|
|
|
|
|
|
demo = gr.Interface( |
|
fn=display_choice, |
|
inputs=[gr.Radio(algos,label="Algorithm"), gr.Radio(games, label="Game"),gr.Slider(maximum=100,label="Seed")], |
|
outputs=[gr.Textbox(label="Score"),gr.Video(autoplay=True,height=480,width=480,label="Replay")], |
|
title="Distributional RL Algorithms Benchmark", |
|
description="Select the DRL agent and the game of your choice", |
|
theme="soft", |
|
examples=[["FQF","Pong",31], |
|
["C51","Freeway",31], |
|
["FQF-Rainbow","Freeway",31] |
|
] |
|
) |
|
|
|
|
|
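
# Launch the demo locally; set share=True to expose a temporary public link.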
if __name__ == "__main__": |
|
demo.launch(share=False) |
|
|