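# NOTE(review): the three lines below look like web-page banner residue, not
# code; they are kept as comments so the module can be imported.
# Spaces:
# Sleeping
# Sleeping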
import copy

import pytest
from easydict import EasyDict

from ding.utils import Scheduler
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
class TestSchedulerModule:
    """Unit tests for ``ding.utils.Scheduler``.

    Every test builds its own scheduler config through ``_make_config``, which
    deep-copies the baseline ``test_merged_scheduler_config``. No test ever
    mutates the shared baseline, so a failing assertion cannot leak a modified
    value into another test, and the tests pass in any execution order.
    """

    # Baseline scheduler config; tests only ever read it (via _make_config).
    test_merged_scheduler_config = EasyDict(
        dict(
            schedule_flag=False,
            schedule_mode='reduce',
            factor=0.05,
            change_range=[-1, 1],
            threshold=1e-4,
            optimize_mode='min',
            patience=1,
            cooldown=0,
        )
    )
    test_policy_config = EasyDict(league_demo_ppo_config.policy)
    # Starting value of the parameter driven by the scheduler in test_step.
    test_policy_config_param = test_policy_config.learn.entropy_weight

    def _make_config(self, **overrides):
        """Return a deep copy of the baseline config with ``overrides`` applied.

        Args:
            **overrides: config keys to replace in the copy.

        Returns:
            A new ``EasyDict``; mutating it never affects other tests.
        """
        config = copy.deepcopy(self.test_merged_scheduler_config)
        for key, value in overrides.items():
            # setattr keeps EasyDict's attribute and item views in sync.
            setattr(config, key, value)
        return config

    def _assert_init_rejects(self, message_fragment, **overrides):
        """Assert that building a ``Scheduler`` from the overridden config fails.

        Args:
            message_fragment: text expected in the ``AssertionError`` message.
            **overrides: invalid config values to apply to a fresh config copy.
        """
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self._make_config(**overrides))
        assert message_fragment in str(excinfo.value)

    def test_init_factor(self):
        # A non-numeric factor is rejected.
        self._assert_init_rejects('float/int', factor='hello_test')
        # A factor of 0 is rejected.
        self._assert_init_rejects('greater than 0', factor=0)

    def test_init_change_range(self):
        # A change_range that is not a list is rejected.
        self._assert_init_rejects('list', change_range=0)
        # A non-numeric bound is rejected.
        self._assert_init_rejects('float', change_range=[0, 'hello_test'])
        # A range whose first bound exceeds the second ([0, -1]) is rejected.
        self._assert_init_rejects('smaller', change_range=[0, -1])

    def test_init_patience(self):
        # A non-integer patience is rejected.
        self._assert_init_rejects('integer', patience='hello_test')
        # A negative patience is rejected.
        self._assert_init_rejects('greater', patience=-1)

    def test_is_better(self):
        test_scheduler = Scheduler(self._make_config())
        # On a freshly built scheduler, -1 is reported as better.
        assert test_scheduler.is_better(-1) is True

        # With optimize_mode='min', 0.5 improves on a last metric of 1.
        test_scheduler.last_metrics = 1
        assert test_scheduler.is_better(0.5) is True

    def test_in_cooldown(self):
        # The cooldown counter is initialised from ``cooldown`` (test_step shows
        # cooldown=1 -> cooldown_counter == 1), so pin cooldown=0 explicitly.
        test_scheduler = Scheduler(self._make_config(cooldown=0))
        assert test_scheduler.in_cooldown is False

    def test_step(self):
        config = self._make_config(cooldown=1)
        test_scheduler = Scheduler(config)
        # The counter starts at the configured cooldown length.
        assert test_scheduler.cooldown_counter == 1
        test_scheduler.last_metrics = 1.0

        old_param = self.test_policy_config.learn.entropy_weight
        param = self.test_policy_config_param

        # Step 1: a good epoch (0.9 after 1.0) while cooldown is pending.
        # The param is unchanged and the single cooldown step is consumed.
        param = test_scheduler.step(0.9, param)
        assert param == old_param
        assert test_scheduler.cooldown_counter == 0
        assert test_scheduler.last_metrics == 0.9
        assert test_scheduler.bad_epochs_num == 0

        # Step 2: 0.899999 is counted as a bad epoch (presumably because the
        # gain is below ``threshold``). One bad epoch leaves the param unchanged.
        param = test_scheduler.step(0.899999, param)
        assert param == old_param
        assert test_scheduler.cooldown_counter == 0
        assert test_scheduler.last_metrics == 0.899999
        assert test_scheduler.bad_epochs_num == 1

        # Step 3: another bad epoch. The param is reduced by ``factor``, the
        # cooldown restarts, and the bad-epoch count is reset.
        param = test_scheduler.step(0.899998, param)
        assert param == old_param - config.factor
        assert test_scheduler.cooldown_counter == 1
        assert test_scheduler.last_metrics == 0.899998
        assert test_scheduler.bad_epochs_num == 0