import gradio as gr import datetime import os import pprint import sys import numpy as np import torch from examples.atari.atari_network import C51 from examples.atari.atari_wrapper import wrap_deepmind from tianshou.data import Collector from examples.atari.tianshou.policy import C51Policy import gymnasium as gym from examples.atari.tianshou.env.venvs import DummyVectorEnv from examples.atari.tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction, FullQuantileFunctionRainbow from examples.atari.atari_network import DQN from examples.atari.tianshou.policy import FQFPolicy,FQF_RainbowPolicy def seed(self, seed): np.random.seed(seed) # Define configuration parameters config_c51 = { "task": "PongNoFrameskip-v4", "seed": 3128, "scale_obs": 0, "eps_test": 0.005, "eps_train": 1.0, "eps_train_final": 0.05, "buffer_size": 100000, "lr": 0.0001, "gamma": 0.99, "num_atoms": 51, "v_min": -10.0, "v_max": 10.0, "n_step": 3, "target_update_freq": 500, "epoch": 100, "step_per_epoch": 100000, "step_per_collect": 10, "update_per_step": 0.1, "batch_size": 32, "training_num": 1, "test_num": 1, "logdir": "log", "render": 0.0, "device": "cuda" if torch.cuda.is_available() else "cpu", "frames_stack": 4, "resume_path": "examples/atari/c51_pong.pth", "resume_id": "", "logger": "tensorboard", "wandb_project": "atari.benchmark", "watch": True, "save_buffer_name": None } config_fqf = { "task": "SpaceInvadersNoFrameskip-v4", "seed": 3128, "scale_obs": 0, "eps_test": 0.005, "eps_train": 1.0, "eps_train_final": 0.05, "buffer_size": 100000, "lr": 5e-5, "fraction_lr": 2.5e-9, "gamma": 0.99, "num_fractions": 32, "num_cosines": 64, "ent_coef": 10.0, "hidden_sizes": [512], "n_step": 3, "target_update_freq": 500, "epoch": 100, "step_per_epoch": 100000, "step_per_collect": 10, "update_per_step": 0.1, "batch_size": 32, "training_num": 1, "test_num": 1, "logdir": "log", "render": 0.0, "device": "cuda" if torch.cuda.is_available() else "cpu", "frames_stack": 4, "resume_path": "fqf_pong.pth", "resume_id": None, "logger": "tensorboard", "wandb_project": "atari.benchmark", "watch": True, "save_buffer_name": None, } config_fqf_r = { "task": "PongNoFrameskip-v4", "algo_name": "RainbowFQF", "seed": 3128, "scale_obs": 0, "eps_test": 0.005, "eps_train": 1.0, "eps_train_final": 0.05, "buffer_size": 100000, "lr": 5e-5, "fraction_lr": 2.5e-9, "gamma": 0.99, "num_fractions": 32, "num_cosines": 64, "ent_coef": 10.0, "hidden_sizes": [512], "n_step": 3, "target_update_freq": 500, "epoch": 100, "step_per_epoch": 100000, "step_per_collect": 10, "update_per_step": 0.1, "batch_size": 32, "training_num": 1, "test_num": 1, "logdir": "log", "no_dueling": False, "no_noisy": False, "no_priority": False, "noisy_std": 0.1, "alpha": 0.5, "beta": 0.4, "beta_final": 1.0, "beta_anneal_step": 5000000, "no_weight_norm": False, "render": 0.0, "device": "cuda" if torch.cuda.is_available() else "cpu", "frames_stack": 4, "resume_path": None, "resume_id": None, "logger": "tensorboard", "wandb_project": "atari.benchmark", "watch": False, "save_buffer_name": None, "per": False, } def test_c51(config : dict) -> None: # _, _, test_envs,_ = make_atari_watch_env( # config["task"], # config["seed"], # config["training_num"], # config["test_num"], # scale=config["scale_obs"], # frame_stack=config["frames_stack"], # ) env_wrap = gym.make(config["task"],render_mode = 'rgb_array') env_wrap.action_space.seed(config["seed"]) env_deep = wrap_deepmind(env_wrap) rec_env = DummyVectorEnv( [ lambda: gym.wrappers.RecordVideo( env_deep, video_folder='video-app/' ) ] ) state_shape = env_deep.observation_space.shape or env_deep.observation_space.n action_shape = env_deep.action_space.shape or env_deep.action_space.n # should be N_FRAMES x H x W print("Observations shape:", state_shape) print("Actions shape:", action_shape) # seed np.random.seed(config["seed"]) torch.manual_seed(config["seed"]) # rec_env.seed(config["seed"]) # test_envs.seed(config["seed"]) print("seed is ",config["seed"]) net = C51(*state_shape, action_shape, config["num_atoms"], config["device"]) optim = torch.optim.Adam(net.parameters(), lr=config["lr"]) # define policy policy = C51Policy( model=net, optim=optim, discount_factor=config["gamma"], action_space=env_deep.action_space, num_atoms=config["num_atoms"], v_min=config["v_min"], v_max=config["v_max"], estimation_step=config["n_step"], target_update_freq=config["target_update_freq"], ).to(config["device"]) # load a previous policy if config["resume_path"]: policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"])) print("Loaded agent from:", config["resume_path"]) collector = Collector(policy, rec_env, exploration_noise=True) # result = collector.collect(n_episode=config["test_num"], render=config["render"]) result = collector.collect(n_episode=config["test_num"]) # Collector(policy, rec_env, exploration_noise=True).collect(n_episode=config["test_num"]) rec_env.close() result.pprint_asdict() return result def test_FQF(config : dict) -> None: # _, _, test_envs,_ = make_atari_watch_env( # config["task"], # config["seed"], # config["training_num"], # config["test_num"], # scale=config["scale_obs"], # frame_stack=config["frames_stack"], # ) env_wrap = gym.make(config["task"],render_mode = 'rgb_array') env_wrap.action_space.seed(config["seed"]) env_deep = wrap_deepmind(env_wrap) rec_env = DummyVectorEnv( [ lambda: gym.wrappers.RecordVideo( env_deep, video_folder='video-app/' ) ] ) state_shape = env_deep.observation_space.shape or env_deep.observation_space.n action_shape = env_deep.action_space.shape or env_deep.action_space.n # should be N_FRAMES x H x W print("Observations shape:", state_shape) print("Actions shape:", action_shape) # seed print(config["seed"]) # np.random.seed(config["seed"]) # torch.manual_seed(config["seed"]) # rec_env.seed(config["seed"]) feature_net = DQN(*state_shape, action_shape, config["device"], features_only=True) # Create FullQuantileFunction net net = FullQuantileFunction( feature_net, action_shape, config["hidden_sizes"], config["num_cosines"], ).to(config["device"]) # Create Adam optimizer optim = torch.optim.Adam(net.parameters(), lr=config["lr"]) # Create FractionProposalNetwork fraction_net = FractionProposalNetwork(config["num_fractions"], net.input_dim) # Create RMSprop optimizer for fraction_net fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=config["fraction_lr"]) # Define policy using FQFPolicy policy: FQFPolicy = FQFPolicy( model=net, optim=optim, fraction_model=fraction_net, fraction_optim=fraction_optim, action_space=env_deep.action_space, discount_factor=config["gamma"], num_fractions=config["num_fractions"], ent_coef=config["ent_coef"], estimation_step=config["n_step"], target_update_freq=config["target_update_freq"], ).to(config["device"]) # load a previous policy if config["resume_path"]: policy.load_state_dict(torch.load(config["resume_path"], map_location=config["device"])) print("Loaded agent from:", config["resume_path"]) collector = Collector(policy, rec_env, exploration_noise=True) # result = collector.collect(n_episode=config["test_num"], render=config["render"]) result = collector.collect(n_episode=config["test_num"]) # Collector(policy, rec_env, exploration_noise=True).collect(n_episode=config["test_num"]) rec_env.close() result.pprint_asdict() return result def test_fqf_rainbow(config: dict) -> None: # _, _, test_envs,_ = make_atari_watch_env( # config['task'], # config['seed'], # config['training_num'], # config['test_num'], # scale=config['scale_obs'], # frame_stack=config['frames_stack'], # ) env_wrap = gym.make(config["task"],render_mode = 'rgb_array') env_wrap.action_space.seed(config["seed"]) env_deep = wrap_deepmind(env_wrap) rec_env = DummyVectorEnv( [ lambda: gym.wrappers.RecordVideo( env_deep, video_folder='video-app/' ) ] ) config['state_shape'] = env_deep.observation_space.shape or env_deep.observation_space.n config['action_shape'] = env_deep.action_space.shape or env_deep.action_space.n # print(env_deep.action_space) # print(test_envs.action_space) # should be N_FRAMES x H x W # print("Observations shape:", config['state_shape']) # print("Actions shape:", config['action_shape']) # seed print(config["seed"]) # np.random.seed(config['seed']) # torch.manual_seed(config['seed']) # test_envs.seed(config['seed']) # rec_env.seed(config['seed']) # define model feature_net = DQN(*config['state_shape'], config['action_shape'], config['device'], features_only=True) preprocess_net_output_dim = feature_net.output_dim # Ensure this is correctly set # print(preprocess_net_output_dim) net = FullQuantileFunctionRainbow( preprocess_net=feature_net, action_shape=config['action_shape'], hidden_sizes=config['hidden_sizes'], num_cosines=config['num_cosines'], preprocess_net_output_dim=preprocess_net_output_dim, device=config['device'], noisy_std=config['noisy_std'], is_noisy=not config['no_noisy'], # Set to True to use noisy layers is_dueling=not config['no_dueling'], # Set to True to use dueling layers ).to(config['device']) # print(net) optim = torch.optim.Adam(net.parameters(), lr=config['lr']) fraction_net = FractionProposalNetwork(config['num_fractions'], net.input_dim) fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=config['fraction_lr']) # define policy policy: FQF_RainbowPolicy = FQF_RainbowPolicy( model=net, optim=optim, fraction_model=fraction_net, fraction_optim=fraction_optim, action_space=env_deep.action_space, discount_factor=config['gamma'], num_fractions=config['num_fractions'], ent_coef=config['ent_coef'], estimation_step=config['n_step'], target_update_freq=config['target_update_freq'], is_noisy=not config['no_noisy'] ).to(config['device']) # load a previous policy if config['resume_path']: policy.load_state_dict(torch.load(config['resume_path'], map_location=config['device'])) print("Loaded agent from:", config['resume_path']) # policy.eval() test_collector = Collector(policy, rec_env, exploration_noise=True) result = test_collector.collect(n_episode=config["test_num"]) #replay # Collector(policy, rec_env, exploration_noise=True).collect(n_episode=1) rec_env.close() result.pprint_asdict() return result # Define the function to display choices and mean scores def display_choice(algo, game,slider): # Dictionary to store mean scores for each algorithm and game match algo: case "C51": config_c51["seed"] = slider match game: case "Freeway": config_c51["resume_path"] = "models/c51_freeway.pth" config_c51["task"] = "FreewayNoFrameskip-v4" mean_scores = test_c51(config_c51) case "Pong" : config_c51["resume_path"] = "models/c51_pong.pth" config_c51["task"] = "PongNoFrameskip-v4" mean_scores = test_c51(config_c51) case "FQF": config_fqf["seed"] = slider match game: case "Freeway": config_fqf["resume_path"] = "models/fqf_freeway.pth" config_fqf["task"] = "FreewayNoFrameskip-v4" mean_scores = test_FQF(config_fqf) case "Pong" : config_fqf["resume_path"] = "models/fqf_pong.pth" config_fqf["task"] = "PongNoFrameskip-v4" mean_scores = test_FQF(config_fqf) case "FQF-Rainbow": config_fqf_r["seed"] = slider match game: case "Freeway": config_fqf_r["resume_path"] = "models/fqf-rainbow_freeway.pth" config_fqf_r["task"] = "FreewayNoFrameskip-v4" mean_scores = test_fqf_rainbow(config_fqf_r) case "Pong" : config_fqf_r["resume_path"] = "models/fqf-rainbow_pong.pth" config_fqf_r["task"] = "PongNoFrameskip-v4" mean_scores = test_fqf_rainbow(config_fqf_r) # Calculate or fetch the mean score for the selected combination mean_score = mean_scores.returns_stat.mean # Return the selected options and the mean score # return f"Your {algo} agent finished {game} with a \nMean Score of ##{mean_score}" return [mean_score,"video-app/rl-video-episode-0.mp4"] # Define the choices for the radio buttons algos = ["C51", "FQF", "FQF-Rainbow"] # games = ["Pong", "Space Invaders","Freeway","MsPacman"] games = ["Freeway","Pong"] # Create a Gradio Interface demo = gr.Interface( fn=display_choice, # Function to call when an option is selected inputs=[gr.Radio(algos,label="Algorithm"), gr.Radio(games, label="Game"),gr.Slider(maximum=100,label="Seed")], # Radio buttons with the defined choices outputs=[gr.Textbox(label="Score"),gr.Video(autoplay=True,height=480,width=480,label="Replay")], title="Distributional RL Algorithms Benchmark", description="Select the DRL agent and the game of your choice", theme="soft", examples=[["FQF","Pong",31], ["C51","Freeway",31], ["FQF-Rainbow","Freeway",31] ] ) # Launch the Gradio app if __name__ == "__main__": demo.launch(share=False)