Spaces:
Sleeping
Sleeping
import pytest | |
from easydict import EasyDict | |
from ding.framework import OnlineRLContext | |
from ding.framework.middleware.ckpt_handler import CkptSaver | |
import torch.nn as nn | |
import torch.optim as optim | |
import os | |
import shutil | |
from unittest.mock import Mock, patch | |
from ding.framework import task | |
from ding.policy.base_policy import Policy | |
class TheModelClass(nn.Module):
    """Minimal ``nn.Module`` stand-in whose ``state_dict`` is a fixed placeholder."""

    def state_dict(self):
        """Return a constant string in place of a real parameter mapping."""
        placeholder = 'fake_state_dict'
        return placeholder
class MockPolicy(Mock):
    """``Mock``-based policy double exposing ``learn_mode`` and ``eval_mode``."""

    def __init__(self, model, **kwargs) -> None:
        """Initialise the mock with ``model`` and expose it as ``learn_mode``.

        ``model`` is forwarded as the first positional argument of
        ``Mock.__init__``; any extra keyword arguments are accepted and ignored.
        """
        super().__init__(model)
        self.learn_mode = model

    def eval_mode(self):
        """Return an ``EasyDict`` whose ``state_dict`` entry is a callable yielding ``{}``."""
        stub_fields = {"state_dict": lambda: {}}
        return EasyDict(stub_fields)
def test_ckpt_saver():
    """Check the checkpoint paths ``CkptSaver`` passes to ``save_file``.

    ``save_file`` in ``ding.framework.middleware.ckpt_handler`` is patched with
    a stub that only asserts on its ``path`` argument. Three phases are run on
    the same saver and the stub expects, in order, ``eval.pth.tar``,
    ``iteration_<train_iter>.pth.tar`` and ``final.pth.tar`` under
    ``<exp_name>/ckpt``.

    The experiment directory is removed in a ``finally`` clause so that a
    failing assertion does not leave it behind on disk.

    NOTE(review): the stubs validate only the calls that actually happen; a
    phase in which ``save_file`` is never invoked passes silently.
    """
    exp_name = 'test_ckpt_saver_exp'
    prefix = '{}/ckpt'.format(exp_name)
    train_freq = 100

    ctx = OnlineRLContext()
    model = TheModelClass()

    def expect_path(make_expected):
        """Build a ``save_file`` replacement asserting on the received path.

        ``make_expected`` is called when the stub runs, so the expected path
        reflects the context state at save time rather than at patch time.
        """

        def mock_save_file(path, data, fs_type=None, use_lock=False):
            assert path == make_expected()

        return mock_save_file

    save_file_target = "ding.framework.middleware.ckpt_handler.save_file"

    # exist_ok avoids the separate existence check before creating the directory.
    os.makedirs(exp_name, exist_ok=True)
    try:
        with patch("ding.policy.Policy", MockPolicy), task.start():
            policy = MockPolicy(model)

            # Phase 1: the stub expects the eval checkpoint path.
            eval_stub = expect_path(lambda: "{}/eval.pth.tar".format(prefix))
            with patch(save_file_target, eval_stub):
                ctx.train_iter = 1
                ctx.eval_value = 9.4
                ckpt_saver = CkptSaver(policy, exp_name, train_freq)
                ckpt_saver(ctx)

            # Phase 2: the stub expects the per-iteration checkpoint path.
            iteration_stub = expect_path(
                lambda: "{}/iteration_{}.pth.tar".format(prefix, ctx.train_iter)
            )
            with patch(save_file_target, iteration_stub):
                ctx.train_iter = 100
                ctx.eval_value = 1
                ckpt_saver(ctx)

            # Phase 3: with task.finish set, the stub expects the final checkpoint path.
            final_stub = expect_path(lambda: "{}/final.pth.tar".format(prefix))
            with patch(save_file_target, final_stub):
                task.finish = True
                ckpt_saver(ctx)
    finally:
        # Always clean up, including when an assertion above fails.
        shutil.rmtree(exp_name)