# Source path: gomoku/DI-engine/ding/league/base_league.py
# Web-viewer header residue: last commit 3dfe8fb ("init space"), shown with zjowowen's avatar.
from typing import Union, Dict
import uuid
import copy
import os
import os.path as osp
from abc import abstractmethod
from easydict import EasyDict
from tabulate import tabulate
from ding.league.player import ActivePlayer, HistoricalPlayer, create_player
from ding.league.shared_payoff import create_payoff
from ding.utils import import_module, read_file, save_file, LockContext, LockContextType, LEAGUE_REGISTRY, \
deep_merge_dicts
from .metric import LeagueMetricEnv
class BaseLeague:
    """
    Overview:
        League, proposed by Google Deepmind AlphaStar. Can manage multiple players in one league.
    Interface:
        get_job_info, judge_snapshot, update_active_player, finish_job, save_checkpoint
    .. note::
        In ``__init__`` method, league would also initialize players as well (in ``_init_players`` method).
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Return a deep copy of the class-level ``config`` wrapped in an ``EasyDict``, tagged with \
            ``cfg_type`` (the class name followed by ``'Dict'``).
        Returns:
            - cfg (:obj:`EasyDict`): Independent copy of the default config; mutating it does not affect ``cls.config``.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    config = dict(
        league_type='base',
        import_names=["ding.league.base_league"],
        # ---player----
        # "player_category" is just a name. Depends on the env.
        # For example, in StarCraft, this can be ['zerg', 'terran', 'protoss'].
        player_category=['default'],
        # Support different types of active players for solo and battle league.
        # For solo league, supports ['solo_active_player'].
        # For battle league, supports ['battle_active_player', 'main_player', 'main_exploiter', 'league_exploiter'].
        # active_players=dict(),
        # "use_pretrain" means whether to use pretrain model to initialize active player.
        use_pretrain=False,
        # "use_pretrain_init_historical" means whether to use pretrain model to initialize historical player.
        # "pretrain_checkpoint_path" is the pretrain checkpoint path used in "use_pretrain" and
        # "use_pretrain_init_historical". If both are False, "pretrain_checkpoint_path" can be omitted as well.
        # Otherwise, "pretrain_checkpoint_path" should list paths of all player categories.
        use_pretrain_init_historical=False,
        pretrain_checkpoint_path=dict(default='default_cate_pretrain.pth', ),
        # ---payoff---
        payoff=dict(
            # Supports ['battle']
            type='battle',
            decay=0.99,
            min_win_rate_games=8,
        ),
        metric=dict(
            mu=0,
            sigma=25 / 3,
            beta=25 / 3 / 2,
            tau=0.0,
            draw_probability=0.02,
        ),
    )

    def __init__(self, cfg: EasyDict) -> None:
        """
        Overview:
            Initialization method. Merges ``cfg`` over the default config, makes sure the policy directory \
            exists, creates the payoff and the rating environment, then initializes all players.
        Arguments:
            - cfg (:obj:`EasyDict`): League config. Must provide ``path_policy``; ``active_players`` is read \
                later by ``_init_players`` and has no default in ``config``.
        """
        self.cfg = deep_merge_dicts(self.default_config(), cfg)
        self.path_policy = cfg.path_policy
        # ``makedirs(exist_ok=True)`` creates missing parent directories too and does not fail when the
        # directory appears between an existence check and the creation (e.g. several leagues starting at once).
        os.makedirs(self.path_policy, exist_ok=True)
        self.league_uid = str(uuid.uuid1())
        # TODO dict players
        self.active_players = []
        self.historical_players = []
        self.player_path = "./league"
        self.payoff = create_payoff(self.cfg.payoff)
        metric_cfg = self.cfg.metric
        # Keyword arguments keep every value bound to the intended parameter: ``beta`` sits between ``sigma``
        # and ``tau`` in trueskill's positional order, so leaving it out positionally shifts ``tau`` and
        # ``draw_probability`` by one slot and silently drops the configured ``beta``.
        # NOTE(review): assumes ``LeagueMetricEnv`` forwards these keywords to ``trueskill.TrueSkill`` - confirm.
        self.metric_env = LeagueMetricEnv(
            mu=metric_cfg.mu,
            sigma=metric_cfg.sigma,
            beta=metric_cfg.beta,
            tau=metric_cfg.tau,
            draw_probability=metric_cfg.draw_probability,
        )
        self._active_players_lock = LockContext(type_=LockContextType.THREAD_LOCK)
        self._init_players()

    def _init_players(self) -> None:
        """
        Overview:
            Initialize players (active & historical) in the league.
        """
        # Add different types of active players for each player category, according to ``cfg.active_players``.
        for cate in self.cfg.player_category:  # Player's category (Depends on the env)
            for player_type, player_num in self.cfg.active_players.items():  # Active player's type
                for player_idx in range(player_num):  # This type's active player number
                    name = '{}_{}_{}'.format(player_type, cate, player_idx)
                    ckpt_path = osp.join(self.path_policy, '{}_ckpt.pth'.format(name))
                    player = create_player(
                        self.cfg, player_type, self.cfg[player_type], cate, self.payoff, ckpt_path, name, 0,
                        self.metric_env.create_rating()
                    )
                    if self.cfg.use_pretrain:
                        # Seed the active player's checkpoint with this category's pretrained checkpoint.
                        self.save_checkpoint(self.cfg.pretrain_checkpoint_path[cate], ckpt_path)
                    self.active_players.append(player)
                    self.payoff.add_player(player)

        # Add pretrain player as the initial HistoricalPlayer for each player category.
        if self.cfg.use_pretrain_init_historical:
            for cate in self.cfg.player_category:
                # Exactly one config key containing 'main_player' is expected; it names the parent type.
                main_player_name = [key for key in self.cfg.keys() if 'main_player' in key]
                assert len(main_player_name) == 1, main_player_name
                main_player_name = main_player_name[0]
                name = '{}_{}_0_pretrain_historical'.format(main_player_name, cate)
                parent_name = '{}_{}_0'.format(main_player_name, cate)
                hp = HistoricalPlayer(
                    self.cfg.get(main_player_name),
                    cate,
                    self.payoff,
                    self.cfg.pretrain_checkpoint_path[cate],
                    name,
                    0,
                    self.metric_env.create_rating(),
                    parent_id=parent_name
                )
                self.historical_players.append(hp)
                self.payoff.add_player(hp)

        # Save active players' ``player_id`` & ``player_ckpt``.
        self.active_players_ids = [p.player_id for p in self.active_players]
        self.active_players_ckpts = [p.checkpoint_path for p in self.active_players]
        # Validate active players are unique by ``player_id``.
        assert len(self.active_players_ids) == len(set(self.active_players_ids))

    def get_job_info(self, player_id: Union[str, None] = None, eval_flag: bool = False) -> dict:
        """
        Overview:
            Get info dict of the job which is to be launched to an active player.
        Arguments:
            - player_id (:obj:`str`): The active player's id. If ``None``, the first active player is used.
            - eval_flag (:obj:`bool`): Whether this is an evaluation job.
        Returns:
            - job_info (:obj:`dict`): Job info.
        ReturnsKeys:
            - necessary: ``launch_player`` (the active player)
        Raises:
            - ValueError: If ``player_id`` is not one of the active players' ids.
        """
        if player_id is None:
            player_id = self.active_players_ids[0]
        with self._active_players_lock:
            idx = self.active_players_ids.index(player_id)
            player = self.active_players[idx]
            job_info = self._get_job_info(player, eval_flag)
            # Sanity check on the subclass implementation: the job must be addressed to the requested player.
            assert 'launch_player' in job_info.keys() and job_info['launch_player'] == player.player_id
        return job_info

    @abstractmethod
    def _get_job_info(self, player: ActivePlayer, eval_flag: bool = False) -> dict:
        """
        Overview:
            Real `get_job` method. Called by ``get_job_info``.
        Arguments:
            - player (:obj:`ActivePlayer`): The active player to be launched a job.
            - eval_flag (:obj:`bool`): Whether this is an evaluation job.
        Returns:
            - job_info (:obj:`dict`): Job info. Should include keys ['launch_player'].
        """
        raise NotImplementedError

    def judge_snapshot(self, player_id: str, force: bool = False) -> bool:
        """
        Overview:
            Judge whether a player is trained enough for snapshot. If yes, call player's ``snapshot``, create a
            historical player(prepare the checkpoint and add it to the shared payoff), then mutate it, and return True.
            Otherwise, return False.
        Arguments:
            - player_id (:obj:`str`): The active player's id.
            - force (:obj:`bool`): If True, snapshot without consulting ``is_trained_enough``.
        Returns:
            - snapshot_or_not (:obj:`bool`): Whether the active player is snapshotted.
        Raises:
            - ValueError: If ``player_id`` is not one of the active players' ids.
        """
        with self._active_players_lock:
            idx = self.active_players_ids.index(player_id)
            player = self.active_players[idx]
            if not (force or player.is_trained_enough()):
                return False
            # Snapshot
            hp = player.snapshot(self.metric_env)
            self.save_checkpoint(player.checkpoint_path, hp.checkpoint_path)
            self.historical_players.append(hp)
            self.payoff.add_player(hp)
            # Mutate
            self._mutate_player(player)
            return True

    @abstractmethod
    def _mutate_player(self, player: ActivePlayer) -> None:
        """
        Overview:
            Players have the probability to mutate, e.g. Reset network parameters.
            Called by ``self.judge_snapshot``.
        Arguments:
            - player (:obj:`ActivePlayer`): The active player that may mutate.
        """
        raise NotImplementedError

    def update_active_player(self, player_info: dict) -> None:
        """
        Overview:
            Update an active player's info.
        Arguments:
            - player_info (:obj:`dict`): Info dict of the player which is to be updated.
        ArgumentsKeys:
            - necessary: `player_id`, `train_iteration`
        .. note::
            A ``ValueError`` (e.g. an unknown ``player_id``, or one raised inside ``_update_player``) is \
            printed instead of propagated, and ``None`` is returned in that case.
        """
        try:
            idx = self.active_players_ids.index(player_info['player_id'])
            player = self.active_players[idx]
            return self._update_player(player, player_info)
        except ValueError as e:
            print(e)

    @abstractmethod
    def _update_player(self, player: ActivePlayer, player_info: dict) -> None:
        """
        Overview:
            Update an active player. Called by ``self.update_active_player``.
        Arguments:
            - player (:obj:`ActivePlayer`): The active player that will be updated.
            - player_info (:obj:`dict`): Info dict of the active player which is to be updated.
        """
        raise NotImplementedError

    def finish_job(self, job_info: dict) -> None:
        """
        Overview:
            Finish current job. Update shared payoff to record the game results. For evaluation jobs \
            (truthy ``eval_flag``), also update both players' ratings through ``metric_env.rate_1vs1``.
        Arguments:
            - job_info (:obj:`dict`): A dict containing job result information.
        ArgumentsKeys:
            - necessary for evaluation jobs: ``player_id`` (a home/away pair of ids), ``result``
        """
        # TODO(nyz) more fine-grained job info
        self.payoff.update(job_info)
        if job_info.get('eval_flag'):
            home_id, away_id = job_info['player_id']
            home_player, away_player = self.get_player_by_id(home_id), self.get_player_by_id(away_id)
            job_info_result = job_info['result']
            if isinstance(job_info_result[0], list):
                # Flatten a list of per-group result lists into one flat list (linear, unlike ``sum(x, [])``).
                job_info_result = [item for group in job_info_result for item in group]
            home_player.rating, away_player.rating = self.metric_env.rate_1vs1(
                home_player.rating, away_player.rating, result=job_info_result
            )

    def get_player_by_id(self, player_id: str) -> 'Player':  # noqa
        """
        Overview:
            Look up a player by id. Ids containing ``'historical'`` are searched among the historical players, \
            all other ids among the active players.
        Arguments:
            - player_id (:obj:`str`): Id of the player to find.
        Returns:
            - player (:obj:`Player`): The first player in the selected pool whose ``player_id`` matches.
        Raises:
            - IndexError: If no player in the selected pool has this id.
        """
        pool = self.historical_players if 'historical' in player_id else self.active_players
        for player in pool:
            if player.player_id == player_id:
                return player
        # Same exception type as indexing an empty match list, but with an actionable message.
        raise IndexError("no player with player_id {!r} in the league".format(player_id))

    @staticmethod
    def save_checkpoint(src_checkpoint: str, dst_checkpoint: str) -> None:
        """
        Overview:
            Copy a checkpoint from path ``src_checkpoint`` to path ``dst_checkpoint``.
        Arguments:
            - src_checkpoint (:obj:`str`): Source checkpoint's path, e.g. s3://alphastar_fake_data/ckpt.pth
            - dst_checkpoint (:obj:`str`): Destination checkpoint's path, e.g. s3://alphastar_fake_data/ckpt.pth
        """
        checkpoint = read_file(src_checkpoint)
        save_file(dst_checkpoint, checkpoint)

    def player_rank(self, string: bool = False) -> Union[str, Dict[str, float]]:
        """
        Overview:
            Collect ``rating.exposure`` of every active and historical player.
        Arguments:
            - string (:obj:`bool`): If True, return a pipe-format table (prefixed with a newline) instead of a dict.
        Returns:
            - rank (:obj:`Union[str, Dict[str, float]]`): Mapping from ``player_id`` to exposure, or its \
                table rendering with exposures formatted to two decimals.
        """
        rank = {p.player_id: p.rating.exposure for p in self.active_players + self.historical_players}
        if not string:
            return rank
        headers = ["Player ID", "Rank (TrueSkill)"]
        data = [[name, "{:.2f}".format(exposure)] for name, exposure in rank.items()]
        return "\n" + tabulate(data, headers=headers, tablefmt='pipe')
def create_league(cfg: EasyDict, *args) -> BaseLeague:
    """
    Overview:
        Import every module listed in ``cfg.import_names`` (nothing when the key is absent), then ask
        ``LEAGUE_REGISTRY`` to build the league stored under ``cfg.league_type``. A derived league therefore
        has to be registered (typically as a side effect of importing its module) before this function is called.
    Arguments:
        - cfg (:obj:`EasyDict`): League config. ``league_type`` is required; ``import_names`` is optional.
        - args: Extra positional arguments handed to ``LEAGUE_REGISTRY.build`` after the league type.
    Returns:
        - league (:obj:`BaseLeague`): Whatever ``LEAGUE_REGISTRY.build`` returns for ``cfg.league_type``, \
            expected to be an instance of a registered league class.
    """
    modules_to_import = cfg.get('import_names', [])
    import_module(modules_to_import)
    league_type = cfg.league_type
    # ``cfg`` is always passed by keyword; ``args`` follow the league type positionally.
    return LEAGUE_REGISTRY.build(league_type, *args, cfg=cfg)