Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from collections import OrderedDict | |
from mmengine.runner import CheckpointLoader, load_state_dict | |
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI into ``model``.

    The weights are read from ``checkpoint['state_dict']`` if that key
    exists, otherwise from ``checkpoint['model']``, otherwise the checkpoint
    dict itself is treated as the state_dict. Before loading, one leading
    ``module.backbone.``, ``module.`` or ``backbone.`` prefix is stripped
    from every key (checked in that order, so the longest prefix wins).

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to strictly enforce that the prefix-stripped
            keys of the checkpoint match the keys of ``model``. Passed
            straight through to mmengine's ``load_state_dict``.
            Default: False.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint as returned by the
        loader. Its keys are not modified; only the copy that is loaded
        into ``model`` has its prefixes stripped.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """
    checkpoint = CheckpointLoader.load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict_tmp = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict_tmp = checkpoint['model']
    else:
        state_dict_tmp = checkpoint

    # Strip at most one prefix per key. 'module.backbone.' must be tried
    # before 'module.' so the longer prefix is removed as a whole.
    prefixes = ('module.backbone.', 'module.', 'backbone.')
    state_dict = OrderedDict()
    for key, value in state_dict_tmp.items():
        new_key = key
        for prefix in prefixes:
            if key.startswith(prefix):
                new_key = key[len(prefix):]
                break
        state_dict[new_key] = value

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
def get_state_dict(filename, map_location='cpu'):
    """Get state_dict from a file or URI.

    The weights are read from ``checkpoint['state_dict']`` if that key
    exists, otherwise from ``checkpoint['model']``, otherwise the checkpoint
    dict itself is treated as the state_dict. One leading
    ``module.backbone.``, ``module.`` or ``backbone.`` prefix is stripped
    from every key (checked in that order, so the longest prefix wins).

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``.
        map_location (str): Same as :func:`torch.load`.

    Returns:
        OrderedDict: The state_dict with prefixes stripped from its keys.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """
    checkpoint = CheckpointLoader.load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # get state_dict from checkpoint; a top-level 'model' entry holds the
    # weights in some checkpoint formats, so without this branch the whole
    # wrapper dict would be returned as if it were a state_dict.
    if 'state_dict' in checkpoint:
        state_dict_tmp = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict_tmp = checkpoint['model']
    else:
        state_dict_tmp = checkpoint

    # Strip at most one prefix per key. 'module.backbone.' must be tried
    # before 'module.' so the longer prefix is removed as a whole.
    prefixes = ('module.backbone.', 'module.', 'backbone.')
    state_dict = OrderedDict()
    for key, value in state_dict_tmp.items():
        new_key = key
        for prefix in prefixes:
            if key.startswith(prefix):
                new_key = key[len(prefix):]
                break
        state_dict[new_key] = value
    return state_dict