import glob
import os
import time
import uuid

import pytest
import torch
import torch.nn as nn

from ding.torch_utils.checkpoint_helper import auto_checkpoint, build_checkpoint_helper, CountVar
from ding.utils import read_file, save_file
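
# Unit tests for ding's checkpoint helper: strict / non-strict model loading,
# prefix add / remove on state-dict keys, state_dict_mask filtering, the
# CountVar counter, and the auto_checkpoint decorator.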

# Two toy models whose fc1 / fc2 layers line up but whose last layer differs
# (fc_dst vs fc_src), so a strict load between them must fail.
class DstModel(nn.Module):

    def __init__(self):
        super(DstModel, self).__init__()
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 8)
        self.fc_dst = nn.Linear(3, 6)


class SrcModel(nn.Module):

    def __init__(self):
        super(SrcModel, self).__init__()
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 8)
        self.fc_src = nn.Linear(3, 7)
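

# Minimal stand-in for any component that exposes the state_dict() /
# load_state_dict() protocol (optimizer, dataset, collector info). Each
# state_dict() call returns the stored value and then refreshes it with a new
# uuid, so the test can tell whether a value was produced locally or copied
# in by a load.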
class HasStateDict(object):

    def __init__(self, name):
        self._name = name
        self._state_dict = name + str(uuid.uuid4())

    def state_dict(self):
        old = self._state_dict
        self._state_dict = self._name + str(uuid.uuid4())
        return old

    def load_state_dict(self, state_dict):
        self._state_dict = state_dict


class TestCkptHelper:

    def test_load_model(self):
        path = 'model.pt'
        if os.path.exists(path):
            os.remove(path)
        dst_model = DstModel()
        src_model = SrcModel()
        ckpt_state_dict = {'model': src_model.state_dict()}
        torch.save(ckpt_state_dict, path)

        ckpt_helper = build_checkpoint_helper({})
        # Mismatched keys (fc_src vs fc_dst) make a strict load fail.
        with pytest.raises(RuntimeError):
            ckpt_helper.load(path, dst_model, strict=True)
        # A non-strict load copies only the layers that do match.
        ckpt_helper.load(path, dst_model, strict=False)
        assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() < 1e-6
        assert torch.abs(dst_model.fc1.bias - src_model.fc1.bias).max() < 1e-6

        # Fresh models: parameters differ again until the next load.
        dst_model = DstModel()
        src_model = SrcModel()
        assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() > 1e-6

        src_optimizer = HasStateDict('src_optimizer')
        dst_optimizer = HasStateDict('dst_optimizer')
        src_last_epoch = CountVar(11)
        dst_last_epoch = CountVar(5)
        src_last_iter = CountVar(110)
        dst_last_iter = CountVar(50)
        src_dataset = HasStateDict('src_dataset')
        dst_dataset = HasStateDict('dst_dataset')
        src_collector_info = HasStateDict('src_collect_info')
        dst_collector_info = HasStateDict('dst_collect_info')
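
        # Save the src state with prefix_op='remove' (strip the leading "f"
        # from state-dict keys), then load it back with prefix_op='add' to
        # restore the original names. fc1 is masked out on load, so dst_model
        # keeps its own fc1 while fc2 is overwritten with the src weights.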
        ckpt_helper.save(
            path,
            src_model,
            optimizer=src_optimizer,
            dataset=src_dataset,
            collector_info=src_collector_info,
            last_iter=src_last_iter,
            last_epoch=src_last_epoch,
            prefix_op='remove',
            prefix="f"
        )
        ckpt_helper.load(
            path,
            dst_model,
            dataset=dst_dataset,
            optimizer=dst_optimizer,
            last_iter=dst_last_iter,
            last_epoch=dst_last_epoch,
            collector_info=dst_collector_info,
            strict=False,
            state_dict_mask=['fc1'],
            prefix_op='add',
            prefix="f"
        )
        # The dst components now hold the values saved from the src side.
        assert dst_dataset.state_dict().startswith('src')
        assert dst_optimizer.state_dict().startswith('src')
        assert dst_collector_info.state_dict().startswith('src')
        assert dst_last_iter.val == 110
        # Parameter names survive the remove/add prefix round-trip untouched.
        for k, v in dst_model.named_parameters():
            assert k.startswith('fc')
        print('==dst', dst_model.fc2.weight)
        print('==src', src_model.fc2.weight)
        assert torch.abs(dst_model.fc2.weight - src_model.fc2.weight).max() < 1e-6
        # fc1 was listed in state_dict_mask, so it was not loaded.
        assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() > 1e-6

        checkpoint = read_file(path)
        checkpoint.pop('dataset')
        checkpoint.pop('optimizer')
        checkpoint.pop('last_iter')
        save_file(path, checkpoint)
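
        # With dataset / optimizer / last_iter gone from the file, a
        # strict=True load must tolerate the missing entries and still
        # restore the model.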
        ckpt_helper.load(
            path,
            dst_model,
            dataset=dst_dataset,
            optimizer=dst_optimizer,
            last_iter=dst_last_iter,
            last_epoch=dst_last_epoch,
            collector_info=dst_collector_info,
            strict=True,
            state_dict_mask=['fc1'],
            prefix_op='add',
            prefix="f"
        )
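
        # Learning-rate schedulers are not supported by the helper, so
        # passing one must raise NotImplementedError (the 'lr_schduler'
        # spelling follows the helper's own keyword and is left as written).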
        with pytest.raises(NotImplementedError):
            ckpt_helper.load(
                path,
                dst_model,
                strict=False,
                lr_schduler='lr_scheduler',
                last_iter=dst_last_iter,
            )
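
        # Unknown prefix_op values must be rejected with KeyError, on both
        # save and load.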
        with pytest.raises(KeyError):
            ckpt_helper.save(path, src_model, prefix_op='key_error', prefix="f")
        with pytest.raises(KeyError):
            ckpt_helper.load(path, dst_model, strict=False, prefix_op='key_error', prefix="f")

        # Clean up the checkpoint file(s) produced by the test.
        for f in glob.glob(path + '*'):
            os.remove(f)


def test_count_var():
    var = CountVar(0)
    var.add(5)
    assert var.val == 5
    var.update(3)
    assert var.val == 3
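

# auto_checkpoint decorates a method of an object that provides
# save_checkpoint(): when the wrapped call raises, the decorator intercepts
# the exception and calls save_checkpoint() instead of letting it propagate,
# which is what lets this test finish cleanly.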
def test_auto_checkpoint():

    class AutoCkptCls:

        def __init__(self):
            pass

        @auto_checkpoint
        def start(self):
            for i in range(10):
                if i < 5:
                    time.sleep(0.2)
                else:
                    raise Exception("There is an exception")

        def save_checkpoint(self, ckpt_path):
            print('Checkpoint is saved successfully in {}!'.format(ckpt_path))

    auto_ckpt = AutoCkptCls()
    auto_ckpt.start()


if __name__ == '__main__':
    test = TestCkptHelper()
    test.test_load_model()