Spaces:
Sleeping
Sleeping
import os | |
import numpy as np | |
import pytest | |
from easydict import EasyDict | |
from ding.league.player import Player, HistoricalPlayer, ActivePlayer, create_player | |
from ding.league.shared_payoff import create_payoff | |
from ding.league.starcraft_player import MainPlayer, MainExploiter, LeagueExploiter | |
from ding.league.tests.league_test_default_config import league_test_config | |
from ding.league.metric import LeagueMetricEnv | |
ONE_PHASE_STEP = 2000 | |
env = LeagueMetricEnv() | |
def setup_payoff(): | |
cfg = EasyDict({'type': 'battle', 'decay': 0.99}) | |
return create_payoff(cfg) | |
def setup_league(setup_payoff): | |
players = [] | |
for category in ['zerg', 'terran', 'protoss']: | |
# main_player | |
main_player_name = '{}_{}'.format('MainPlayer', category) | |
players.append( | |
create_player( | |
league_test_config.league, 'main_player', league_test_config.league.main_player, category, setup_payoff, | |
'ckpt_{}.pth'.format(main_player_name), main_player_name, 0, env.create_rating() | |
) | |
) | |
# main_exloiter | |
main_exploiter_name = '{}_{}'.format('MainExploiter', category) | |
players.append( | |
create_player( | |
league_test_config.league, 'main_exploiter', league_test_config.league.main_exploiter, category, | |
setup_payoff, 'ckpt_{}.pth'.format(main_exploiter_name), main_exploiter_name, 0, env.create_rating() | |
) | |
) | |
# league_exploiter | |
league_exploiter_name = '{}_{}'.format('LeagueExploiter', category) | |
for i in range(2): | |
players.append( | |
create_player( | |
league_test_config.league, | |
'league_exploiter', | |
league_test_config.league.league_exploiter, | |
category, | |
setup_payoff, | |
'ckpt_{}.pth'.format(league_exploiter_name), | |
league_exploiter_name, | |
0, | |
env.create_rating(), | |
) | |
) | |
# historical player: sl player is used as initial HistoricalPlayer | |
sl_hp_name = '{}_{}_sl'.format('MainPlayer', category) | |
players.append( | |
create_player( | |
league_test_config.league, | |
'historical_player', | |
EasyDict(), | |
category, | |
setup_payoff, | |
'ckpt_sl_{}'.format(sl_hp_name), | |
sl_hp_name, | |
0, | |
env.create_rating(), | |
parent_id=main_player_name, | |
) | |
) | |
for p in players: | |
setup_payoff.add_player(p) | |
return players | |
class TestMainPlayer: | |
def test_get_job(self, setup_league, setup_payoff): | |
N = 10 | |
# no indicated p | |
# test get_job | |
for p in setup_league: | |
if isinstance(p, MainPlayer): | |
for i in range(N): | |
job_dict = p.get_job() | |
assert isinstance(job_dict, dict) | |
opponent = job_dict['opponent'] | |
assert isinstance(opponent, Player) | |
assert opponent in setup_league | |
# payoff = setup_league[np.random.randint(0, len(setup_league))].payoff # random select reference | |
hp_list = [] | |
for p in setup_league: | |
if isinstance(p, ActivePlayer): | |
p.total_agent_step = 2 * ONE_PHASE_STEP | |
hp = p.snapshot(env) | |
hp_list.append(hp) | |
setup_payoff.add_player(hp) | |
setup_league += hp_list # 12+3 + 12 | |
# test get_job with branch prob | |
pfsp, sp, veri = False, False, False | |
for p in setup_league: | |
if isinstance(p, MainPlayer): | |
while True: | |
job_dict = p.get_job() | |
opponent = job_dict['opponent'] | |
if isinstance(opponent, HistoricalPlayer) and 'MainPlayer' in opponent.parent_id: | |
veri = True | |
elif isinstance(opponent, HistoricalPlayer): | |
pfsp = True | |
elif isinstance(opponent, MainPlayer): | |
sp = True | |
else: | |
raise Exception("Main Player selects a wrong opponent {}", type(opponent)) | |
if veri and pfsp and sp: | |
break | |
def test_snapshot(self, setup_league, setup_payoff): | |
N = 10 | |
for p in setup_league: | |
for i in range(N): | |
if isinstance(p, ActivePlayer): | |
hp = p.snapshot(env) | |
assert isinstance(hp, HistoricalPlayer) | |
assert id(hp.payoff) == id(p.payoff) | |
assert hp.parent_id == p.player_id | |
def test_is_trained_enough(self, setup_league, setup_payoff): | |
for p in setup_league: | |
if isinstance(p, ActivePlayer): | |
assert not p.is_trained_enough() | |
assert p._last_enough_step == 0 | |
# step_passed < ONE_PHASE_STEP | |
p.total_agent_step = ONE_PHASE_STEP * 0.99 | |
assert not p.is_trained_enough() | |
assert p._last_enough_step == 0 | |
# ONE_PHASE_STEP < step_passed < 2*ONE_PHASE_STEP, but low win rate | |
p.total_agent_step = ONE_PHASE_STEP + 1 | |
assert not p.is_trained_enough() | |
assert p._last_enough_step == 0 | |
# prepare HistoricalPlayer | |
# payoff = setup_league[np.random.randint(0, len(setup_league))].payoff # random select reference | |
hp_list = [] | |
for p in setup_league: | |
if isinstance(p, MainPlayer): | |
hp = p.snapshot(env) | |
setup_payoff.add_player(hp) | |
hp_list.append(hp) | |
setup_league += hp_list | |
# update 10 wins against all historical players, should be trained enough | |
N = 10 | |
assert isinstance(setup_league[0], MainPlayer) | |
for n in range(N): | |
for hp in [p for p in setup_league if isinstance(p, HistoricalPlayer)]: | |
match_info = { | |
'player_id': [setup_league[0].player_id, hp.player_id], | |
'result': [['wins']], | |
} | |
result = setup_payoff.update(match_info) | |
assert result | |
assert setup_league[0]._total_agent_step > ONE_PHASE_STEP | |
assert setup_league[0]._last_enough_step == 0 | |
assert setup_league[0]._last_enough_step != setup_league[0]._total_agent_step | |
assert setup_league[0].is_trained_enough() | |
assert setup_league[0]._last_enough_step == setup_league[0]._total_agent_step | |
# update 10 draws against all historical players, should be not trained enough; | |
# then update ``total_agent_step`` to 2*ONE_PHASE_STEP, should be trained enough | |
assert isinstance(setup_league[5], MainPlayer) | |
for n in range(N): | |
for hp in hp_list: | |
match_info = { | |
'player_id': [setup_league[5].player_id, hp.player_id], | |
'result': [['draws']], | |
} | |
result = setup_payoff.update(match_info) | |
assert result | |
assert setup_league[5]._total_agent_step > ONE_PHASE_STEP | |
assert not setup_league[5].is_trained_enough() | |
setup_league[5].total_agent_step = 2 * ONE_PHASE_STEP | |
assert setup_league[5].is_trained_enough() | |
def test_mutate(self, setup_league, setup_payoff): | |
# main players do not mutate | |
assert isinstance(setup_league[0], MainPlayer) | |
for _ in range(10): | |
assert setup_league[0].mutate({}) is None | |
def test_sp_historical(self, setup_league, setup_payoff): | |
N = 10 | |
main1 = setup_league[0] # 'zerg' | |
main2 = setup_league[5] # 'terran' | |
assert isinstance(main1, MainPlayer) | |
assert isinstance(main2, MainPlayer) | |
for n in range(N): | |
match_info = { | |
'player_id': [main1.player_id, main2.player_id], | |
'result': [['wins']], | |
} | |
result = setup_payoff.update(match_info) | |
assert result | |
for _ in range(200): | |
opponent = main2._sp_branch() | |
condition1 = opponent.category == 'terran' or opponent.category == 'protoss' | |
# condition2 means: zerg_main_opponent is too strong, so that must choose a historical weaker one | |
condition2 = opponent.category == 'zerg' and isinstance( | |
opponent, HistoricalPlayer | |
) and opponent.parent_id == main1.player_id | |
assert condition1 or condition2, (condition1, condition2) | |
class TestMainExploiter: | |
def test_get_job(self, setup_league, random_job_result, setup_payoff): | |
assert isinstance(setup_league[1], MainExploiter) | |
job_dict = setup_league[1].get_job() | |
opponent = job_dict['opponent'] | |
assert isinstance(opponent, MainPlayer) | |
N = 10 | |
# payoff = setup_league[np.random.randint(0, len(setup_league))].payoff # random select reference | |
for n in range(N): | |
for p in setup_league: | |
if isinstance(p, MainPlayer): | |
match_info = { | |
'player_id': [setup_league[1].player_id, p.player_id], | |
'result': [['losses']], | |
} | |
assert setup_payoff.update(match_info) | |
job_dict = setup_league[1].get_job() | |
opponent = job_dict['opponent'] | |
# as long as main player, both active and historical are ok | |
assert (isinstance(opponent, HistoricalPlayer) | |
and 'MainPlayer' in opponent.parent_id) or isinstance(opponent, MainPlayer) | |
hp_list = [] | |
for i in range(3): | |
for p in setup_league: | |
if isinstance(p, MainPlayer): | |
p.total_agent_step = (i + 1) * 2 * ONE_PHASE_STEP | |
hp = p.snapshot(env) | |
setup_payoff.add_player(hp) | |
hp_list.append(hp) | |
setup_league += hp_list | |
no_main_player_league = [p for p in setup_league if not isinstance(p, MainPlayer)] | |
for i in range(10000): | |
home = np.random.choice(no_main_player_league) | |
away = np.random.choice(no_main_player_league) | |
result = random_job_result() | |
match_info = { | |
'player_id': [home.player_id, away.player_id], | |
'result': [[result]], | |
} | |
assert setup_payoff.update(match_info) | |
for i in range(10): | |
job_dict = setup_league[1].get_job() | |
opponent = job_dict['opponent'] | |
# as long as main player, both active and historical are ok | |
assert (isinstance(opponent, HistoricalPlayer) | |
and 'MainPlayer' in opponent.parent_id) or isinstance(opponent, MainPlayer) | |
def test_is_trained_enough(self, setup_league): | |
# only a few differences from `is_trained_enough` of MainPlayer | |
pass | |
def test_mutate(self, setup_league): | |
assert isinstance(setup_league[1], MainExploiter) | |
info = {'reset_checkpoint_path': 'pretrain_checkpoint.pth'} | |
for _ in range(10): | |
assert setup_league[1].mutate(info) == info['reset_checkpoint_path'] | |
class TestLeagueExploiter: | |
def test_get_job(self, setup_league): | |
assert isinstance(setup_league[2], LeagueExploiter) | |
job_dict = setup_league[2].get_job() | |
opponent = job_dict['opponent'] | |
assert isinstance(opponent, HistoricalPlayer) | |
assert isinstance(setup_league[3], LeagueExploiter) | |
job_dict = setup_league[3].get_job() | |
opponent = job_dict['opponent'] | |
assert isinstance(opponent, HistoricalPlayer) | |
def test_is_trained_enough(self, setup_league): | |
# this function is the same as `is_trained_enough` of MainPlayer | |
pass | |
def test_mutate(self, setup_league): | |
assert isinstance(setup_league[2], LeagueExploiter) | |
info = {'reset_checkpoint_path': 'pretrain_checkpoint.pth'} | |
results = [] | |
for _ in range(1000): | |
results.append(setup_league[2].mutate(info)) | |
freq = len([t for t in results if t]) * 1.0 / len(results) | |
assert 0.2 <= freq <= 0.3 # approximate | |
if __name__ == '__main__': | |
pytest.main(["-sv", os.path.basename(__file__)]) | |