liyy201912's picture
Upload folder using huggingface_hub
cc0dd3c
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from mmengine.runner import CheckpointLoader, load_state_dict
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Read a checkpoint and load its (prefix-stripped) weights into a model.

    The weights are taken from ``checkpoint['state_dict']`` when that key
    exists, otherwise from ``checkpoint['model']``, otherwise the checkpoint
    itself is treated as the weight mapping. A leading ``module.backbone.``,
    ``module.`` or ``backbone.`` is removed from every key (first match wins)
    before the mapping is handed to ``load_state_dict``.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Forwarded unchanged to ``load_state_dict``.
        logger (:mod:`logging.Logger` or None): Forwarded unchanged to
            ``load_state_dict``.

    Returns:
        dict or OrderedDict: The checkpoint exactly as loaded from
        ``filename`` (not the prefix-stripped mapping).

    Raises:
        RuntimeError: If the loaded checkpoint is not a ``dict``.
    """
    loaded = CheckpointLoader.load_checkpoint(filename, map_location)

    # isinstance also accepts OrderedDict, which subclasses dict.
    if not isinstance(loaded, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # Locate the weight mapping: the first container key present wins,
    # falling back to the checkpoint itself.
    raw_weights = loaded
    for container_key in ('state_dict', 'model'):
        if container_key in loaded:
            raw_weights = loaded[container_key]
            break

    # Order matters: the longest prefix is tried first so that
    # 'module.backbone.' is removed as a whole rather than just 'module.'.
    wrapper_prefixes = ('module.backbone.', 'module.', 'backbone.')
    cleaned = OrderedDict()
    for name, value in raw_weights.items():
        for prefix in wrapper_prefixes:
            if name.startswith(prefix):
                name = name[len(prefix):]
                break
        cleaned[name] = value

    load_state_dict(model, cleaned, strict, logger)
    return loaded
def get_state_dict(filename, map_location='cpu'):
    """Get state_dict from a file or URI.

    The weights are taken from ``checkpoint['state_dict']`` when that key
    exists, otherwise from ``checkpoint['model']``, otherwise the checkpoint
    itself is treated as the state_dict. A leading ``module.backbone.``,
    ``module.`` or ``backbone.`` is then stripped from every key (first
    match wins).

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``.
        map_location (str): Same as :func:`torch.load`.

    Returns:
        OrderedDict: The state_dict with the prefixes above removed.

    Raises:
        RuntimeError: If the loaded checkpoint is not a ``dict``.
    """
    checkpoint = CheckpointLoader.load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # get state_dict from checkpoint; checkpoints that nest the weights
    # under 'model' are unwrapped too, instead of returning the wrapper.
    if 'state_dict' in checkpoint:
        state_dict_tmp = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict_tmp = checkpoint['model']
    else:
        state_dict_tmp = checkpoint

    # strip prefix of state_dict; the longest prefix is listed first so
    # 'module.backbone.' is removed as a whole rather than just 'module.'.
    prefixes = ('module.backbone.', 'module.', 'backbone.')
    state_dict = OrderedDict()
    for k, v in state_dict_tmp.items():
        for prefix in prefixes:
            if k.startswith(prefix):
                k = k[len(prefix):]
                break
        state_dict[k] = v
    return state_dict