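"""Gradio demo that benchmarks pretrained distributional RL agents (C51, FQF,
FQF-Rainbow) on Atari games. For the selected agent, game, and seed, the app
rolls out one evaluation episode with Tianshou, records a replay video, and
reports the mean episode return."""
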
import gradio as gr
import gymnasium as gym
import numpy as np
import torch

from examples.atari.atari_network import C51, DQN
from examples.atari.atari_wrapper import wrap_deepmind
from examples.atari.tianshou.env.venvs import DummyVectorEnv
from examples.atari.tianshou.policy import C51Policy, FQFPolicy, FQF_RainbowPolicy
from examples.atari.tianshou.utils.net.discrete import (
    FractionProposalNetwork,
    FullQuantileFunction,
    FullQuantileFunctionRainbow,
)
from tianshou.data import Collector
def set_seed(seed: int) -> None:
    # Seed NumPy and PyTorch so each rollout is reproducible for a given seed.
    np.random.seed(seed)
    torch.manual_seed(seed)
# Define configuration parameters
config_c51 = {
    "task": "PongNoFrameskip-v4",
    "seed": 3128,
    "scale_obs": 0,
    "eps_test": 0.005,
    "eps_train": 1.0,
    "eps_train_final": 0.05,
    "buffer_size": 100000,
    "lr": 0.0001,
    "gamma": 0.99,
    "num_atoms": 51,
    "v_min": -10.0,
    "v_max": 10.0,
    "n_step": 3,
    "target_update_freq": 500,
    "epoch": 100,
    "step_per_epoch": 100000,
    "step_per_collect": 10,
    "update_per_step": 0.1,
    "batch_size": 32,
    "training_num": 1,
    "test_num": 1,
    "logdir": "log",
    "render": 0.0,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "frames_stack": 4,
    "resume_path": "examples/atari/c51_pong.pth",
    "resume_id": "",
    "logger": "tensorboard",
    "wandb_project": "atari.benchmark",
    "watch": True,
    "save_buffer_name": None,
}
config_fqf = {
    "task": "SpaceInvadersNoFrameskip-v4",
    "seed": 3128,
    "scale_obs": 0,
    "eps_test": 0.005,
    "eps_train": 1.0,
    "eps_train_final": 0.05,
    "buffer_size": 100000,
    "lr": 5e-5,
    "fraction_lr": 2.5e-9,
    "gamma": 0.99,
    "num_fractions": 32,
    "num_cosines": 64,
    "ent_coef": 10.0,
    "hidden_sizes": [512],
    "n_step": 3,
    "target_update_freq": 500,
    "epoch": 100,
    "step_per_epoch": 100000,
    "step_per_collect": 10,
    "update_per_step": 0.1,
    "batch_size": 32,
    "training_num": 1,
    "test_num": 1,
    "logdir": "log",
    "render": 0.0,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "frames_stack": 4,
    "resume_path": "fqf_pong.pth",
    "resume_id": None,
    "logger": "tensorboard",
    "wandb_project": "atari.benchmark",
    "watch": True,
    "save_buffer_name": None,
}
config_fqf_r = {
    "task": "PongNoFrameskip-v4",
    "algo_name": "RainbowFQF",
    "seed": 3128,
    "scale_obs": 0,
    "eps_test": 0.005,
    "eps_train": 1.0,
    "eps_train_final": 0.05,
    "buffer_size": 100000,
    "lr": 5e-5,
    "fraction_lr": 2.5e-9,
    "gamma": 0.99,
    "num_fractions": 32,
    "num_cosines": 64,
    "ent_coef": 10.0,
    "hidden_sizes": [512],
    "n_step": 3,
    "target_update_freq": 500,
    "epoch": 100,
    "step_per_epoch": 100000,
    "step_per_collect": 10,
    "update_per_step": 0.1,
    "batch_size": 32,
    "training_num": 1,
    "test_num": 1,
    "logdir": "log",
    "no_dueling": False,
    "no_noisy": False,
    "no_priority": False,
    "noisy_std": 0.1,
    "alpha": 0.5,
    "beta": 0.4,
    "beta_final": 1.0,
    "beta_anneal_step": 5000000,
    "no_weight_norm": False,
    "render": 0.0,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "frames_stack": 4,
    "resume_path": None,
    "resume_id": None,
    "logger": "tensorboard",
    "wandb_project": "atari.benchmark",
    "watch": False,
    "save_buffer_name": None,
    "per": False,
}
def test_c51(config: dict):
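    """Evaluate a pretrained C51 agent on the configured Atari task, record a replay
    video, and return the collected evaluation statistics."""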
    env_wrap = gym.make(config["task"], render_mode="rgb_array")
    env_wrap.action_space.seed(config["seed"])
    env_deep = wrap_deepmind(env_wrap)
    rec_env = DummyVectorEnv(
        [
            lambda: gym.wrappers.RecordVideo(
                env_deep,
                video_folder="video-app/",
            )
        ]
    )
    state_shape = env_deep.observation_space.shape or env_deep.observation_space.n
    action_shape = env_deep.action_space.shape or env_deep.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", state_shape)
    print("Actions shape:", action_shape)
    # seed everything so the rollout is reproducible
    set_seed(config["seed"])
    print("seed is", config["seed"])
    net = C51(*state_shape, action_shape, config["num_atoms"], config["device"])
    optim = torch.optim.Adam(net.parameters(), lr=config["lr"])
    # define policy
    policy = C51Policy(
        model=net,
        optim=optim,
        discount_factor=config["gamma"],
        action_space=env_deep.action_space,
        num_atoms=config["num_atoms"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        estimation_step=config["n_step"],
        target_update_freq=config["target_update_freq"],
    ).to(config["device"])
    # load a previous policy
    if config["resume_path"]:
        policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"]))
        print("Loaded agent from:", config["resume_path"])
    # roll out the evaluation episodes and record the replay video
    collector = Collector(policy, rec_env, exploration_noise=True)
    result = collector.collect(n_episode=config["test_num"])
    rec_env.close()
    result.pprint_asdict()
    return result
def test_FQF(config: dict):
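    """Evaluate a pretrained FQF agent on the configured Atari task, record a replay
    video, and return the collected evaluation statistics."""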
    env_wrap = gym.make(config["task"], render_mode="rgb_array")
    env_wrap.action_space.seed(config["seed"])
    env_deep = wrap_deepmind(env_wrap)
    rec_env = DummyVectorEnv(
        [
            lambda: gym.wrappers.RecordVideo(
                env_deep,
                video_folder="video-app/",
            )
        ]
    )
    state_shape = env_deep.observation_space.shape or env_deep.observation_space.n
    action_shape = env_deep.action_space.shape or env_deep.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", state_shape)
    print("Actions shape:", action_shape)
    # seed everything so the rollout is reproducible
    set_seed(config["seed"])
    print("seed is", config["seed"])
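    # FQF learns the full quantile function of the return: FullQuantileFunction
    # predicts return quantiles at fractions proposed by FractionProposalNetwork.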
    feature_net = DQN(*state_shape, action_shape, config["device"], features_only=True)
    # Create FullQuantileFunction net
    net = FullQuantileFunction(
        feature_net,
        action_shape,
        config["hidden_sizes"],
        config["num_cosines"],
    ).to(config["device"])
    # Create Adam optimizer
    optim = torch.optim.Adam(net.parameters(), lr=config["lr"])
    # Create FractionProposalNetwork
    fraction_net = FractionProposalNetwork(config["num_fractions"], net.input_dim)
    # Create RMSprop optimizer for fraction_net
    fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=config["fraction_lr"])
    # Define policy using FQFPolicy
    policy: FQFPolicy = FQFPolicy(
        model=net,
        optim=optim,
        fraction_model=fraction_net,
        fraction_optim=fraction_optim,
        action_space=env_deep.action_space,
        discount_factor=config["gamma"],
        num_fractions=config["num_fractions"],
        ent_coef=config["ent_coef"],
        estimation_step=config["n_step"],
        target_update_freq=config["target_update_freq"],
    ).to(config["device"])
    # load a previous policy
    if config["resume_path"]:
        policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"]))
        print("Loaded agent from:", config["resume_path"])
    # roll out the evaluation episodes and record the replay video
    collector = Collector(policy, rec_env, exploration_noise=True)
    result = collector.collect(n_episode=config["test_num"])
    rec_env.close()
    result.pprint_asdict()
    return result
def test_fqf_rainbow(config: dict):
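    """Evaluate a pretrained FQF-Rainbow agent on the configured Atari task, record a
    replay video, and return the collected evaluation statistics."""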
    env_wrap = gym.make(config["task"], render_mode="rgb_array")
    env_wrap.action_space.seed(config["seed"])
    env_deep = wrap_deepmind(env_wrap)
    rec_env = DummyVectorEnv(
        [
            lambda: gym.wrappers.RecordVideo(
                env_deep,
                video_folder="video-app/",
            )
        ]
    )
    config["state_shape"] = env_deep.observation_space.shape or env_deep.observation_space.n
    config["action_shape"] = env_deep.action_space.shape or env_deep.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", config["state_shape"])
    print("Actions shape:", config["action_shape"])
    # seed everything so the rollout is reproducible
    set_seed(config["seed"])
    print("seed is", config["seed"])
    # define model
    feature_net = DQN(*config["state_shape"], config["action_shape"], config["device"], features_only=True)
    preprocess_net_output_dim = feature_net.output_dim  # feature dimension fed to the quantile head
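    # FQF quantile head with Rainbow-style additions: optional noisy linear layers
    # and a dueling architecture, controlled by the no_noisy / no_dueling flags.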
    net = FullQuantileFunctionRainbow(
        preprocess_net=feature_net,
        action_shape=config["action_shape"],
        hidden_sizes=config["hidden_sizes"],
        num_cosines=config["num_cosines"],
        preprocess_net_output_dim=preprocess_net_output_dim,
        device=config["device"],
        noisy_std=config["noisy_std"],
        is_noisy=not config["no_noisy"],  # use noisy linear layers unless disabled
        is_dueling=not config["no_dueling"],  # use a dueling head unless disabled
    ).to(config["device"])
    optim = torch.optim.Adam(net.parameters(), lr=config["lr"])
    fraction_net = FractionProposalNetwork(config["num_fractions"], net.input_dim)
    fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=config["fraction_lr"])
    # define policy
    policy: FQF_RainbowPolicy = FQF_RainbowPolicy(
        model=net,
        optim=optim,
        fraction_model=fraction_net,
        fraction_optim=fraction_optim,
        action_space=env_deep.action_space,
        discount_factor=config["gamma"],
        num_fractions=config["num_fractions"],
        ent_coef=config["ent_coef"],
        estimation_step=config["n_step"],
        target_update_freq=config["target_update_freq"],
        is_noisy=not config["no_noisy"],
    ).to(config["device"])
    # load a previous policy
    if config["resume_path"]:
        policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"]))
        print("Loaded agent from:", config["resume_path"])
    # roll out the evaluation episodes and record the replay video
    test_collector = Collector(policy, rec_env, exploration_noise=True)
    result = test_collector.collect(n_episode=config["test_num"])
    rec_env.close()
    result.pprint_asdict()
    return result
# Run one evaluation episode for the selected agent and game, and report the result
def display_choice(algo, game, slider):
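    """Evaluate the selected pretrained agent on the selected game with the given seed.

    Returns a two-element list: the mean episode return and the path of the recorded
    replay video. Assumes both an algorithm and a game have been selected.
    """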
    # Pick the checkpoint and task matching the selection, then run one recorded episode
    match algo:
        case "C51":
            config_c51["seed"] = slider
            match game:
                case "Freeway":
                    config_c51["resume_path"] = "models/c51_freeway.pth"
                    config_c51["task"] = "FreewayNoFrameskip-v4"
                    mean_scores = test_c51(config_c51)
                case "Pong":
                    config_c51["resume_path"] = "models/c51_pong.pth"
                    config_c51["task"] = "PongNoFrameskip-v4"
                    mean_scores = test_c51(config_c51)
        case "FQF":
            config_fqf["seed"] = slider
            match game:
                case "Freeway":
                    config_fqf["resume_path"] = "models/fqf_freeway.pth"
                    config_fqf["task"] = "FreewayNoFrameskip-v4"
                    mean_scores = test_FQF(config_fqf)
                case "Pong":
                    config_fqf["resume_path"] = "models/fqf_pong.pth"
                    config_fqf["task"] = "PongNoFrameskip-v4"
                    mean_scores = test_FQF(config_fqf)
        case "FQF-Rainbow":
            config_fqf_r["seed"] = slider
            match game:
                case "Freeway":
                    config_fqf_r["resume_path"] = "models/fqf-rainbow_freeway.pth"
                    config_fqf_r["task"] = "FreewayNoFrameskip-v4"
                    mean_scores = test_fqf_rainbow(config_fqf_r)
                case "Pong":
                    config_fqf_r["resume_path"] = "models/fqf-rainbow_pong.pth"
                    config_fqf_r["task"] = "PongNoFrameskip-v4"
                    mean_scores = test_fqf_rainbow(config_fqf_r)
    # Mean episode return over the recorded evaluation episodes
    mean_score = mean_scores.returns_stat.mean
    # Return the score and the path RecordVideo uses for the first recorded episode
    return [mean_score, "video-app/rl-video-episode-0.mp4"]
# Define the choices for the radio buttons
algos = ["C51", "FQF", "FQF-Rainbow"]
# games = ["Pong", "Space Invaders", "Freeway", "MsPacman"]
games = ["Freeway", "Pong"]
# Create the Gradio interface
demo = gr.Interface(
    fn=display_choice,  # function called when the user submits a selection
    inputs=[
        gr.Radio(algos, label="Algorithm"),
        gr.Radio(games, label="Game"),
        gr.Slider(maximum=100, label="Seed"),
    ],
    outputs=[
        gr.Textbox(label="Score"),
        gr.Video(autoplay=True, height=480, width=480, label="Replay"),
    ],
    title="Distributional RL Algorithms Benchmark",
    description="Select the DRL agent and the game of your choice",
    theme="soft",
    examples=[
        ["FQF", "Pong", 31],
        ["C51", "Freeway", 31],
        ["FQF-Rainbow", "Freeway", 31],
    ],
)
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(share=False)