import os
import time
import pytest
import torch
from easydict import EasyDict
from typing import Any

from ding.worker import BaseLearner
from ding.worker.learner import LearnerHook, add_learner_hook, create_learner


class FakeLearner(BaseLearner):

    @staticmethod
    def random_data():
        return {
            'obs': torch.randn(2),
            'replay_buffer_idx': 0,
            'replay_unique_id': 0,
        }

    def get_data(self, batch_size):
        # Return zero-argument callables; the learner's dataloader invokes
        # each one to materialize a sample.
        return [self.random_data for _ in range(batch_size)]
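

# Illustrative check (added here for clarity, not part of the original test):
# `get_data` hands back callables rather than ready-made samples, so a batch
# is materialized by calling each entry. This exercises that contract via the
# static `random_data` directly, without going through the dataloader.
def test_random_data_protocol():
    batch = [FakeLearner.random_data() for _ in range(2)]
    assert len(batch) == 2
    for sample in batch:
        assert sample['obs'].shape == (2, )
        assert sample['replay_buffer_idx'] == 0
        assert sample['replay_unique_id'] == 0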


class FakePolicy:

    def __init__(self):
        self._model = torch.nn.Identity()

    def forward(self, x):
        # The '[histogram]' and '[scalars]' prefixes mark variables for the
        # learner's logging hooks.
        return {
            'total_loss': torch.randn(1).squeeze(),
            'cur_lr': 0.1,
            'priority': [1., 2., 3.],
            '[histogram]h_example': [1.2, 2.3, 3.4],
            '[scalars]s_example': {
                'a': 5.,
                'b': 4.
            },
        }

    def data_preprocess(self, x):
        return x

    def state_dict(self):
        return {'model': self._model}

    def load_state_dict(self, state_dict):
        pass

    def info(self):
        return 'FakePolicy'

    def monitor_vars(self):
        return ['total_loss', 'cur_lr']

    def get_attribute(self, name):
        if name == 'cuda':
            return False
        elif name == 'device':
            return 'cpu'
        elif name == 'batch_size':
            return 2
        elif name == 'on_policy':
            return False
        else:
            raise KeyError(name)

    def reset(self):
        pass
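

# Illustrative check (added for clarity, not part of the original test):
# BaseLearner reads policy properties through `get_attribute`; this shows the
# expected queries and the KeyError raised for unknown names.
def test_fake_policy_get_attribute():
    policy = FakePolicy()
    assert policy.get_attribute('cuda') is False
    assert policy.get_attribute('device') == 'cpu'
    assert policy.get_attribute('batch_size') == 2
    assert policy.get_attribute('on_policy') is False
    with pytest.raises(KeyError):
        policy.get_attribute('unknown_attribute')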


class TestBaseLearner:

    def _get_cfg(self, path):
        cfg = BaseLearner.default_config()
        cfg.import_names = []
        cfg.learner_type = 'fake'
        cfg.train_iterations = 10
        # One way to build a hook: a bare value (a path / an iteration frequency)
        cfg.hook.load_ckpt_before_run = path
        cfg.hook.log_show_after_iter = 5
        # Another way to build a hook: a complete config dict
        cfg.hook.save_ckpt_after_iter = dict(
            name='save_ckpt_after_iter', type='save_ckpt', priority=40, position='after_iter', ext_args={'freq': 5}
        )
        return cfg
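
    def test_cfg_hook_styles(self):
        # Illustrative check (added, not part of the original test): both hook
        # styles set in `_get_cfg` (the bare value and the complete dict spec)
        # should survive unchanged in the returned config.
        cfg = self._get_cfg('dummy_path')
        assert cfg.hook.load_ckpt_before_run == 'dummy_path'
        assert cfg.hook.save_ckpt_after_iter['ext_args'] == {'freq': 5}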

    def test_naive(self):
        os.popen('rm -rf iteration_5.pth.tar*')
        time.sleep(1.0)
        # An unregistered learner type should raise KeyError
        with pytest.raises(KeyError):
            create_learner(EasyDict({'type': 'placeholder', 'import_names': []}))
        # Prepare a fake checkpoint so the load_ckpt_before_run hook resumes from iteration 5
        path = os.path.join(os.path.dirname(__file__), './iteration_5.pth.tar')
        torch.save({'model': {}, 'last_iter': 5}, path)
        time.sleep(0.5)
        cfg = self._get_cfg(path)
        learner = FakeLearner(cfg, exp_name='exp_test')
        learner.policy = FakePolicy()
        learner.setup_dataloader()
        learner.start()
        time.sleep(2)
        # 10 training iterations on top of the 5 restored from the checkpoint
        assert learner.last_iter.val == 10 + 5

        # Test checkpoint hooks: save_ckpt_after_iter fires every 5 iterations
        dir_name = '{}/ckpt'.format(learner.exp_name)
        for n in [5, 10, 15]:
            assert os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
        for n in [0, 4, 7, 12]:
            assert not os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
        learner.debug('iter [5, 10, 15] exists; iter [0, 4, 7, 12] does not exist.')
        learner.save_checkpoint('best')

        info = learner.learn_info
        for info_name in ['learner_step', 'priority_info', 'learner_done']:
            assert info_name in info

        class FakeHook(LearnerHook):

            def __call__(self, engine: Any) -> Any:
                pass

        original_hook_num = len(learner._hooks['after_run'])
        add_learner_hook(learner._hooks, FakeHook(name='fake_hook', priority=30, position='after_run'))
        assert len(learner._hooks['after_run']) == original_hook_num + 1

        # Clean up checkpoints and logs
        os.popen('rm -rf iteration_5.pth.tar*')
        os.popen('rm -rf ' + dir_name)
        os.popen('rm -rf learner')
        os.popen('rm -rf log')
        learner.close()
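

if __name__ == '__main__':
    # Convenience entry point (added for illustration; assumes this module is
    # saved as a standalone file): run the tests here directly with python.
    pytest.main(['-sv', os.path.abspath(__file__)])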