Spaces:
Sleeping
Sleeping
from zoo.board_games.connect4.config.connect4_muzero_bot_mode_config import main_config, create_config | |
from lzero.entry import eval_muzero | |
import numpy as np | |
def outcome_rates(returns):
    """Return the ``(win, draw, lose)`` fractions of a set of episode returns.

    Each recorded return is classified by exact comparison against ``1.``
    (win), ``0.`` (draw) and ``-1.`` (lose), the same convention the summary
    printout below uses. The denominator is the number of returns actually
    recorded (``returns.size``), not a precomputed episode count. The three
    rates therefore always describe the data that was collected, whatever
    shape ``eval_muzero`` returns per seed.

    Args:
        returns: Array-like of episode returns, of any shape.

    Returns:
        A ``(win_rate, draw_rate, lose_rate)`` tuple of floats. All three are
        ``0.0`` when ``returns`` is empty. Returns that are not exactly 1, 0 or
        -1 are counted in none of the three buckets.
    """
    returns = np.asarray(returns, dtype=float)
    total = returns.size
    if total == 0:
        # Nothing was recorded: avoid a ZeroDivisionError.
        return 0.0, 0.0, 0.0
    win_rate = np.count_nonzero(returns == 1.) / total
    draw_rate = np.count_nonzero(returns == 0.) / total
    lose_rate = np.count_nonzero(returns == -1.) / total
    return float(win_rate), float(draw_rate), float(lose_rate)


if __name__ == '__main__':
    # Entry point for the evaluation of the MuZero model on the Connect4
    # environment against the built-in bot.
    #
    # Variables:
    # - model_path (Optional[str]): path to the ckpt file of the pretrained
    #   model. An absolute path is recommended. In LightZero the path usually
    #   looks like ``exp_name/ckpt/ckpt_best.pth.tar``. Left as ``None`` here;
    #   NOTE(review): what eval_muzero does with ``None`` is not visible in this
    #   file. Presumably it evaluates an untrained model; confirm.
    # - seeds (List[int]): seeds for the environment.
    # - num_episodes_each_seed (int): value passed to ``eval_muzero`` for each
    #   seed.
    # - returns_mean_seeds: mean return reported by ``eval_muzero`` per seed.
    # - returns_seeds: raw returns reported by ``eval_muzero`` per seed.

    # e.g. model_path = './ckpt/ckpt_best.pth.tar'
    model_path = None
    seeds = [0]
    num_episodes_each_seed = 1

    # Set to True to play against the agent yourself.
    main_config.env.agent_vs_human = False
    # Alternative render mode: 'image_realtime_mode'.
    main_config.env.render_mode = 'image_savefile_mode'
    main_config.env.replay_path = './video'

    # Opponent bot: rule-based, with the random-action probability set to 0.
    main_config.env.prob_random_action_in_bot = 0.
    main_config.env.bot_action_type = 'rule'

    # Single evaluator environment running one episode per evaluator call,
    # on the 'base' env manager.
    create_config.env_manager.type = 'base'
    main_config.env.evaluator_env_num = 1
    main_config.env.n_evaluator_episode = 1

    returns_mean_seeds = []
    returns_seeds = []
    for seed in seeds:
        returns_mean, returns = eval_muzero(
            [main_config, create_config],
            seed=seed,
            num_episodes_each_seed=num_episodes_each_seed,
            print_seed_details=True,
            model_path=model_path,
        )
        returns_mean_seeds.append(returns_mean)
        returns_seeds.append(returns)

    returns_mean_seeds = np.array(returns_mean_seeds)
    returns_seeds = np.array(returns_seeds)

    # Rates are computed over the returns actually recorded rather than over
    # num_episodes_each_seed * len(seeds). They stay consistent with the data
    # even if eval_muzero yields a different number of returns per seed.
    win_rate, draw_rate, lose_rate = outcome_rates(returns_seeds)

    print("=" * 20)
    print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).")
    print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.")
    print("Across all seeds, the mean reward is:", returns_mean_seeds.mean())
    print(f'win rate: {win_rate}, draw rate: {draw_rate}, lose rate: {lose_rate}')
    print("=" * 20)