Spaces:
Sleeping
Sleeping
import os | |
import shutil | |
from easydict import EasyDict | |
from ding.league import BaseLeague, ActivePlayer | |
class DemoLeague(BaseLeague): | |
def __init__(self, cfg): | |
super(DemoLeague, self).__init__(cfg) | |
self.reset_checkpoint_path = os.path.join(self.path_policy, 'reset_ckpt.pth') | |
# override | |
def _get_job_info(self, player: ActivePlayer, eval_flag: bool = False) -> dict: | |
assert isinstance(player, ActivePlayer), player.__class__ | |
player_job_info = EasyDict(player.get_job(eval_flag)) | |
return { | |
'agent_num': 2, | |
'launch_player': player.player_id, | |
'player_id': [player.player_id, player_job_info.opponent.player_id], | |
'checkpoint_path': [player.checkpoint_path, player_job_info.opponent.checkpoint_path], | |
'player_active_flag': [isinstance(p, ActivePlayer) for p in [player, player_job_info.opponent]], | |
} | |
# override | |
def _mutate_player(self, player: ActivePlayer): | |
for p in self.active_players: | |
result = p.mutate({'reset_checkpoint_path': self.reset_checkpoint_path}) | |
if result is not None: | |
p.rating = self.metric_env.create_rating() | |
self.load_checkpoint(p.player_id, result) # load_checkpoint is set by the caller of league | |
self.save_checkpoint(result, p.checkpoint_path) | |
# override | |
def _update_player(self, player: ActivePlayer, player_info: dict) -> None: | |
assert isinstance(player, ActivePlayer) | |
if 'learner_step' in player_info: | |
player.total_agent_step = player_info['learner_step'] | |
# override | |
def save_checkpoint(src_checkpoint_path: str, dst_checkpoint_path: str) -> None: | |
shutil.copy(src_checkpoint_path, dst_checkpoint_path) | |