import gradio as gr
import gymnasium as gym
import numpy as np
import torch

from tianshou.data import Collector

from examples.atari.atari_network import C51, DQN
from examples.atari.atari_wrapper import wrap_deepmind
from examples.atari.tianshou.env.venvs import DummyVectorEnv
from examples.atari.tianshou.policy import C51Policy, FQFPolicy, FQF_RainbowPolicy
from examples.atari.tianshou.utils.net.discrete import (
    FractionProposalNetwork,
    FullQuantileFunction,
    FullQuantileFunctionRainbow,
)
def set_seed(seed: int) -> None:
    """Seed numpy and torch so repeated rollouts with the same seed are reproducible."""
    np.random.seed(seed)
    torch.manual_seed(seed)
# Define configuration parameters
config_c51 = {
"task": "PongNoFrameskip-v4",
"seed": 3128,
"scale_obs": 0,
"eps_test": 0.005,
"eps_train": 1.0,
"eps_train_final": 0.05,
"buffer_size": 100000,
"lr": 0.0001,
"gamma": 0.99,
"num_atoms": 51,
"v_min": -10.0,
"v_max": 10.0,
"n_step": 3,
"target_update_freq": 500,
"epoch": 100,
"step_per_epoch": 100000,
"step_per_collect": 10,
"update_per_step": 0.1,
"batch_size": 32,
"training_num": 1,
"test_num": 1,
"logdir": "log",
"render": 0.0,
"device": "cuda" if torch.cuda.is_available() else "cpu",
"frames_stack": 4,
"resume_path": "examples/atari/c51_pong.pth",
"resume_id": "",
"logger": "tensorboard",
"wandb_project": "atari.benchmark",
"watch": True,
"save_buffer_name": None
}
config_fqf = {
"task": "SpaceInvadersNoFrameskip-v4",
"seed": 3128,
"scale_obs": 0,
"eps_test": 0.005,
"eps_train": 1.0,
"eps_train_final": 0.05,
"buffer_size": 100000,
"lr": 5e-5,
"fraction_lr": 2.5e-9,
"gamma": 0.99,
"num_fractions": 32,
"num_cosines": 64,
"ent_coef": 10.0,
"hidden_sizes": [512],
"n_step": 3,
"target_update_freq": 500,
"epoch": 100,
"step_per_epoch": 100000,
"step_per_collect": 10,
"update_per_step": 0.1,
"batch_size": 32,
"training_num": 1,
"test_num": 1,
"logdir": "log",
"render": 0.0,
"device": "cuda" if torch.cuda.is_available() else "cpu",
"frames_stack": 4,
"resume_path": "fqf_pong.pth",
"resume_id": None,
"logger": "tensorboard",
"wandb_project": "atari.benchmark",
"watch": True,
"save_buffer_name": None,
}
config_fqf_r = {
"task": "PongNoFrameskip-v4",
"algo_name": "RainbowFQF",
"seed": 3128,
"scale_obs": 0,
"eps_test": 0.005,
"eps_train": 1.0,
"eps_train_final": 0.05,
"buffer_size": 100000,
"lr": 5e-5,
"fraction_lr": 2.5e-9,
"gamma": 0.99,
"num_fractions": 32,
"num_cosines": 64,
"ent_coef": 10.0,
"hidden_sizes": [512],
"n_step": 3,
"target_update_freq": 500,
"epoch": 100,
"step_per_epoch": 100000,
"step_per_collect": 10,
"update_per_step": 0.1,
"batch_size": 32,
"training_num": 1,
"test_num": 1,
"logdir": "log",
"no_dueling": False,
"no_noisy": False,
"no_priority": False,
"noisy_std": 0.1,
"alpha": 0.5,
"beta": 0.4,
"beta_final": 1.0,
"beta_anneal_step": 5000000,
"no_weight_norm": False,
"render": 0.0,
"device": "cuda" if torch.cuda.is_available() else "cpu",
"frames_stack": 4,
"resume_path": None,
"resume_id": None,
"logger": "tensorboard",
"wandb_project": "atari.benchmark",
"watch": False,
"save_buffer_name": None,
"per": False,
}
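# The three test functions below build the same video-recording evaluation
# environment inline. The helper below is only an illustrative sketch of how
# that construction could be factored out; it mirrors the inline code and is
# not called anywhere in this file.
def make_record_env(config: dict):
    """Sketch: DeepMind-wrapped Atari env wrapped for video recording inside a DummyVectorEnv."""
    env = gym.make(config["task"], render_mode="rgb_array")
    env.action_space.seed(config["seed"])
    env_deep = wrap_deepmind(env)
    rec_env = DummyVectorEnv(
        [lambda: gym.wrappers.RecordVideo(env_deep, video_folder="video-app/")]
    )
    return env_deep, rec_env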
def test_c51(config: dict):
    # build a video-recording evaluation environment: a DeepMind-wrapped Atari env inside a DummyVectorEnv
    env_wrap = gym.make(config["task"], render_mode="rgb_array")
env_wrap.action_space.seed(config["seed"])
env_deep = wrap_deepmind(env_wrap)
rec_env = DummyVectorEnv(
[
lambda: gym.wrappers.RecordVideo(
env_deep,
video_folder='video-app/'
)
]
)
state_shape = env_deep.observation_space.shape or env_deep.observation_space.n
action_shape = env_deep.action_space.shape or env_deep.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", state_shape)
print("Actions shape:", action_shape)
    # seed numpy/torch so repeated runs with the same slider value are reproducible
    print("seed is", config["seed"])
    set_seed(config["seed"])
net = C51(*state_shape, action_shape, config["num_atoms"], config["device"])
optim = torch.optim.Adam(net.parameters(), lr=config["lr"])
# define policy
policy = C51Policy(
model=net,
optim=optim,
discount_factor=config["gamma"],
action_space=env_deep.action_space,
num_atoms=config["num_atoms"],
v_min=config["v_min"],
v_max=config["v_max"],
estimation_step=config["n_step"],
target_update_freq=config["target_update_freq"],
).to(config["device"])
# load a previous policy
if config["resume_path"]:
policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"]))
print("Loaded agent from:", config["resume_path"])
collector = Collector(policy, rec_env, exploration_noise=True)
    # roll out test_num episodes; RecordVideo writes the replay under video-app/
result = collector.collect(n_episode=config["test_num"])
rec_env.close()
result.pprint_asdict()
return result
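# Usage sketch (assumes the checkpoint referenced by config_c51["resume_path"]
# exists on disk):
#   config_c51["resume_path"] = "models/c51_pong.pth"
#   result = test_c51(config_c51)
#   print(result.returns_stat.mean)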
def test_FQF(config: dict):
    # build a video-recording evaluation environment: a DeepMind-wrapped Atari env inside a DummyVectorEnv
    env_wrap = gym.make(config["task"], render_mode="rgb_array")
env_wrap.action_space.seed(config["seed"])
env_deep = wrap_deepmind(env_wrap)
rec_env = DummyVectorEnv(
[
lambda: gym.wrappers.RecordVideo(
env_deep,
video_folder='video-app/'
)
]
)
state_shape = env_deep.observation_space.shape or env_deep.observation_space.n
action_shape = env_deep.action_space.shape or env_deep.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", state_shape)
print("Actions shape:", action_shape)
    # seed numpy/torch so repeated runs with the same slider value are reproducible
    print("seed is", config["seed"])
    set_seed(config["seed"])
feature_net = DQN(*state_shape, action_shape, config["device"], features_only=True)
# Create FullQuantileFunction net
net = FullQuantileFunction(
feature_net,
action_shape,
config["hidden_sizes"],
config["num_cosines"],
).to(config["device"])
# Create Adam optimizer
optim = torch.optim.Adam(net.parameters(), lr=config["lr"])
# Create FractionProposalNetwork
fraction_net = FractionProposalNetwork(config["num_fractions"], net.input_dim)
# Create RMSprop optimizer for fraction_net
fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=config["fraction_lr"])
# Define policy using FQFPolicy
policy: FQFPolicy = FQFPolicy(
model=net,
optim=optim,
fraction_model=fraction_net,
fraction_optim=fraction_optim,
action_space=env_deep.action_space,
discount_factor=config["gamma"],
num_fractions=config["num_fractions"],
ent_coef=config["ent_coef"],
estimation_step=config["n_step"],
target_update_freq=config["target_update_freq"],
).to(config["device"])
# load a previous policy
if config["resume_path"]:
policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"]))
print("Loaded agent from:", config["resume_path"])
collector = Collector(policy, rec_env, exploration_noise=True)
    # roll out test_num episodes; RecordVideo writes the replay under video-app/
result = collector.collect(n_episode=config["test_num"])
rec_env.close()
result.pprint_asdict()
return result
def test_fqf_rainbow(config: dict):
    # build a video-recording evaluation environment: a DeepMind-wrapped Atari env inside a DummyVectorEnv
    env_wrap = gym.make(config["task"], render_mode="rgb_array")
env_wrap.action_space.seed(config["seed"])
env_deep = wrap_deepmind(env_wrap)
rec_env = DummyVectorEnv(
[
lambda: gym.wrappers.RecordVideo(
env_deep,
video_folder='video-app/'
)
]
)
config['state_shape'] = env_deep.observation_space.shape or env_deep.observation_space.n
config['action_shape'] = env_deep.action_space.shape or env_deep.action_space.n
    # state shape should be N_FRAMES x H x W
    print("Observations shape:", config['state_shape'])
    print("Actions shape:", config['action_shape'])
    # seed numpy/torch so repeated runs with the same slider value are reproducible
    print("seed is", config["seed"])
    set_seed(config["seed"])
# define model
feature_net = DQN(*config['state_shape'], config['action_shape'], config['device'], features_only=True)
preprocess_net_output_dim = feature_net.output_dim # Ensure this is correctly set
# print(preprocess_net_output_dim)
net = FullQuantileFunctionRainbow(
preprocess_net=feature_net,
action_shape=config['action_shape'],
hidden_sizes=config['hidden_sizes'],
num_cosines=config['num_cosines'],
preprocess_net_output_dim=preprocess_net_output_dim,
device=config['device'],
noisy_std=config['noisy_std'],
is_noisy=not config['no_noisy'], # Set to True to use noisy layers
is_dueling=not config['no_dueling'], # Set to True to use dueling layers
).to(config['device'])
# print(net)
optim = torch.optim.Adam(net.parameters(), lr=config['lr'])
fraction_net = FractionProposalNetwork(config['num_fractions'], net.input_dim)
fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=config['fraction_lr'])
# define policy
policy: FQF_RainbowPolicy = FQF_RainbowPolicy(
model=net,
optim=optim,
fraction_model=fraction_net,
fraction_optim=fraction_optim,
action_space=env_deep.action_space,
discount_factor=config['gamma'],
num_fractions=config['num_fractions'],
ent_coef=config['ent_coef'],
estimation_step=config['n_step'],
target_update_freq=config['target_update_freq'],
is_noisy=not config['no_noisy']
).to(config['device'])
# load a previous policy
if config['resume_path']:
policy.load_state_dict(torch.load(config['resume_path'], map_location=config['device']))
print("Loaded agent from:", config['resume_path'])
# policy.eval()
test_collector = Collector(policy, rec_env, exploration_noise=True)
result = test_collector.collect(n_episode=config["test_num"])
rec_env.close()
result.pprint_asdict()
return result
# Define the function to display choices and mean scores
def display_choice(algo, game, slider):
    # Dispatch on the selected algorithm/game, point the config at the matching
    # checkpoint, and run one recorded evaluation episode.
match algo:
case "C51":
config_c51["seed"] = slider
match game:
case "Freeway":
config_c51["resume_path"] = "models/c51_freeway.pth"
config_c51["task"] = "FreewayNoFrameskip-v4"
mean_scores = test_c51(config_c51)
case "Pong" :
config_c51["resume_path"] = "models/c51_pong.pth"
config_c51["task"] = "PongNoFrameskip-v4"
mean_scores = test_c51(config_c51)
case "FQF":
config_fqf["seed"] = slider
match game:
case "Freeway":
config_fqf["resume_path"] = "models/fqf_freeway.pth"
config_fqf["task"] = "FreewayNoFrameskip-v4"
mean_scores = test_FQF(config_fqf)
case "Pong" :
config_fqf["resume_path"] = "models/fqf_pong.pth"
config_fqf["task"] = "PongNoFrameskip-v4"
mean_scores = test_FQF(config_fqf)
case "FQF-Rainbow":
config_fqf_r["seed"] = slider
match game:
case "Freeway":
config_fqf_r["resume_path"] = "models/fqf-rainbow_freeway.pth"
config_fqf_r["task"] = "FreewayNoFrameskip-v4"
mean_scores = test_fqf_rainbow(config_fqf_r)
case "Pong" :
config_fqf_r["resume_path"] = "models/fqf-rainbow_pong.pth"
config_fqf_r["task"] = "PongNoFrameskip-v4"
mean_scores = test_fqf_rainbow(config_fqf_r)
# Calculate or fetch the mean score for the selected combination
mean_score = mean_scores.returns_stat.mean
# Return the selected options and the mean score
# return f"Your {algo} agent finished {game} with a \nMean Score of ##{mean_score}"
return [mean_score,"video-app/rl-video-episode-0.mp4"]
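# Example of calling the handler directly, outside the Gradio UI (a sketch; the
# returned video path points at the first episode recorded by RecordVideo):
#   score, video_path = display_choice("C51", "Pong", 31)
#   print(score, video_path)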
# Define the choices for the radio buttons
algos = ["C51", "FQF", "FQF-Rainbow"]
# games = ["Pong", "Space Invaders","Freeway","MsPacman"]
games = ["Freeway","Pong"]
# Create a Gradio Interface
demo = gr.Interface(
fn=display_choice, # Function to call when an option is selected
    inputs=[gr.Radio(algos, label="Algorithm"), gr.Radio(games, label="Game"), gr.Slider(minimum=0, maximum=100, step=1, label="Seed")],  # algorithm/game selectors plus an integer seed slider
outputs=[gr.Textbox(label="Score"),gr.Video(autoplay=True,height=480,width=480,label="Replay")],
title="Distributional RL Algorithms Benchmark",
description="Select the DRL agent and the game of your choice",
theme="soft",
examples=[["FQF","Pong",31],
["C51","Freeway",31],
["FQF-Rainbow","Freeway",31]
]
)
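# A possible variation (assumption, not part of the original configuration):
# expose the interface on a fixed host/port instead of the default local address:
#   demo.launch(server_name="0.0.0.0", server_port=7860)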
# Launch the Gradio app
if __name__ == "__main__":
demo.launch(share=False)