import os
import copy
import gym
import numpy as np
import torch
from tensorboardX import SummaryWriter
from easydict import EasyDict

from ding.config import compile_config
from ding.worker import BaseLearner, BattleInteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed, Scheduler, deep_merge_dicts
from dizoo.league_demo.game_env import GameEnv
from dizoo.league_demo.demo_league import DemoLeague
from dizoo.league_demo.league_demo_collector import LeagueDemoCollector
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
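

# Fixed evaluation opponent: samples binary actions according to the environment's
# optimal (Nash-equilibrium) policy distribution.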
class EvalPolicy1:

    def __init__(self, optimal_policy: list) -> None:
        assert len(optimal_policy) == 2
        self.optimal_policy = optimal_policy

    def forward(self, data: dict) -> dict:
        return {
            env_id: {
                'action': torch.from_numpy(np.random.choice([0, 1], p=self.optimal_policy, size=(1, )))
            }
            for env_id in data.keys()
        }

    def reset(self, data_id: list = []) -> None:
        pass
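

# Uniform-random evaluation opponent: picks each of the two actions with probability 0.5.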
class EvalPolicy2:

    def forward(self, data: dict) -> dict:
        return {
            env_id: {
                'action': torch.from_numpy(np.random.choice([0, 1], p=[0.5, 0.5], size=(1, )))
            }
            for env_id in data.keys()
        }

    def reset(self, data_id: list = []) -> None:
        pass
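

# League training entry point: compile the config, build a learner and collector for
# each active player, set up three battle evaluators for the main player, and run the
# self-play training loop.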
def main(cfg, seed=0, max_train_iter=int(1e8), max_env_step=int(1e8)):
    cfg = compile_config(
        cfg,
        BaseEnvManager,
        PPOPolicy,
        BaseLearner,
        LeagueDemoCollector,
        BattleInteractionSerialEvaluator,
        NaiveReplayBuffer,
        save_cfg=True
    )
    env_type = cfg.env.env_type
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    evaluator_env1 = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )
    evaluator_env2 = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )
    evaluator_env3 = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )
    evaluator_env1.seed(seed, dynamic_seed=False)
    evaluator_env2.seed(seed, dynamic_seed=False)
    evaluator_env3.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    league = DemoLeague(cfg.policy.other.league)
    eval_policy1 = EvalPolicy1(evaluator_env1._env_ref.optimal_policy)
    eval_policy2 = EvalPolicy2()
    policies = {}
    learners = {}
    collectors = {}
    for player_id in league.active_players_ids:
        # by default, every active player uses the same model architecture (with different initial weights)
        model = VAC(**cfg.policy.model)
        policy = PPOPolicy(cfg.policy, model=model)
        policies[player_id] = policy
        collector_env = BaseEnvManager(
            env_fn=[lambda: GameEnv(env_type) for _ in range(collector_env_num)], cfg=cfg.env.manager
        )
        collector_env.seed(seed)
        learners[player_id] = BaseLearner(
            cfg.policy.learn.learner,
            policy.learn_mode,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            instance_name=player_id + '_learner'
        )
        collectors[player_id] = LeagueDemoCollector(
            cfg.policy.collect.collector,
            collector_env,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            instance_name=player_id + '_collector',
        )
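    # a single shared 'historical' policy object; its weights are reloaded from the
    # snapshot checkpoint of whichever historical opponent the league schedules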
    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    policies['historical'] = policy
    # use the initial (untrained) policy as another evaluation opponent
    eval_policy3 = PPOPolicy(cfg.policy, model=copy.deepcopy(model)).collect_mode
    main_key = [k for k in learners.keys() if k.startswith('main_player')][0]
    main_player = league.get_player_by_id(main_key)
    main_learner = learners[main_key]
    main_collector = collectors[main_key]
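    # three battle evaluators for the main player: vs. the fixed optimal (NE) policy,
    # vs. a uniform-random policy, and vs. the untrained initial policy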
    # collect_mode PPO uses multinomial sampling to select actions
    evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator1_cfg.stop_value = cfg.env.stop_value[0]
    evaluator1 = BattleInteractionSerialEvaluator(
        evaluator1_cfg,
        evaluator_env1, [policies[main_key].collect_mode, eval_policy1],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='fixed_evaluator'
    )
    evaluator2_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator2_cfg.stop_value = cfg.env.stop_value[1]
    evaluator2 = BattleInteractionSerialEvaluator(
        evaluator2_cfg,
        evaluator_env2, [policies[main_key].collect_mode, eval_policy2],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='uniform_evaluator'
    )
    evaluator3_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator3_cfg.stop_value = 99999999  # the stop_value of evaluator3 is a placeholder
    evaluator3 = BattleInteractionSerialEvaluator(
        evaluator3_cfg,
        evaluator_env3, [policies[main_key].collect_mode, eval_policy3],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='init_evaluator'
    )
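
    # callback used by the league to reload a player's checkpoint into its policy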
    def load_checkpoint_fn(player_id: str, ckpt_path: str):
        state_dict = torch.load(ckpt_path)
        policies[player_id].learn_mode.load_state_dict(state_dict)

    torch.save(policies['historical'].learn_mode.state_dict(), league.reset_checkpoint_path)
    league.load_checkpoint = load_checkpoint_fn
    # snapshot the initial player as the first historical player
    for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
        torch.save(policies[player_id].collect_mode.state_dict(), player_ckpt_path)
        league.judge_snapshot(player_id, force=True)
    init_main_player_rating = league.metric_env.create_rating(mu=0)
    count = 0
    stop_flag1, stop_flag2 = False, False  # initialize stop flags in case the first should_eval() calls return False
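    # main loop: periodically evaluate the main player against the three reference
    # opponents (updating its TrueSkill rating), then run one collect/train job for
    # each active player in the league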
    while True:
        if evaluator1.should_eval(main_learner.train_iter):
            stop_flag1, episode_info = evaluator1.eval(
                main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
            )
            win_loss_result = [e['result'] for e in episode_info[0]]
            # pin the fixed NE policy's TrueSkill (exposure) at 10
            main_player.rating = league.metric_env.rate_1vsC(
                main_player.rating, league.metric_env.create_rating(mu=10, sigma=1e-8), win_loss_result
            )
        if evaluator2.should_eval(main_learner.train_iter):
            stop_flag2, episode_info = evaluator2.eval(
                main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
            )
            win_loss_result = [e['result'] for e in episode_info[0]]
            # pin the uniform-random policy's TrueSkill (exposure) at 0
            main_player.rating = league.metric_env.rate_1vsC(
                main_player.rating, league.metric_env.create_rating(mu=0, sigma=1e-8), win_loss_result
            )
        if evaluator3.should_eval(main_learner.train_iter):
            _, episode_info = evaluator3.eval(
                main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
            )
            win_loss_result = [e['result'] for e in episode_info[0]]
            # rate the main player against its untrained initial version as an extra metric
            main_player.rating, init_main_player_rating = league.metric_env.rate_1vs1(
                main_player.rating, init_main_player_rating, win_loss_result
            )
            tb_logger.add_scalar(
                'league/init_main_player_trueskill', init_main_player_rating.exposure, main_collector.envstep
            )
        if stop_flag1 and stop_flag2:
            break
        for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
            tb_logger.add_scalar(
                'league/{}_trueskill'.format(player_id),
                league.get_player_by_id(player_id).rating.exposure, main_collector.envstep
            )
            collector, learner = collectors[player_id], learners[player_id]
            job = league.get_job_info(player_id)
            opponent_player_id = job['player_id'][1]
            # print('job player: {}'.format(job['player_id']))
            if 'historical' in opponent_player_id:
                opponent_policy = policies['historical'].collect_mode
                opponent_path = job['checkpoint_path'][1]
                opponent_policy.load_state_dict(torch.load(opponent_path, map_location='cpu'))
            else:
                opponent_policy = policies[opponent_player_id].collect_mode
            collector.reset_policy([policies[player_id].collect_mode, opponent_policy])
            train_data, episode_info = collector.collect(train_iter=learner.train_iter)
            train_data, episode_info = train_data[0], episode_info[0]  # only use the launch player's data for training
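            # override the advantage with the raw reward for this demo (no value-based advantage estimation)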
            for d in train_data:
                d['adv'] = d['reward']
            for i in range(cfg.policy.learn.update_per_collect):
                learner.train(train_data, collector.envstep)
            torch.save(learner.policy.state_dict(), player_ckpt_path)
            player_info = learner.learn_info
            player_info['player_id'] = player_id
            league.update_active_player(player_info)
            league.judge_snapshot(player_id)
            # set eval_flag=True to enable the TrueSkill update
            job_finish_info = {
                'eval_flag': True,
                'launch_player': job['launch_player'],
                'player_id': job['player_id'],
                'result': [e['result'] for e in episode_info],
            }
            league.finish_job(job_finish_info)
        if main_collector.envstep >= max_env_step or main_learner.train_iter >= max_train_iter:
            break
        if count % 100 == 0:
            print(repr(league.payoff))
        count += 1


if __name__ == "__main__":
    main(league_demo_ppo_config)