|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
def load_state_dict(model, state_dict):
    """Load state_dict into model, handling DataParallel and DistributedDataParallel.

    If ``state_dict`` contains a top-level "model" entry, that entry is used as
    the actual state dict.

    DataParallel prefixes state_dict keys with 'module.' when saving.
    If the checkpoint keys already match the model's own keys, they are loaded unchanged.
    Otherwise, if the model is not a DataParallel model but the state_dict is, prefixes are removed;
    if the model is a DataParallel model but the state_dict is not, prefixes are added.

    Args:
        model (torch.nn.Module): Model to load the weights into.
        state_dict (dict): State dict, or a checkpoint dict holding it under "model".

    Returns:
        torch.nn.Module: The same ``model`` object, with weights loaded.
    """
    state_dict = state_dict.get('model', state_dict)

    prefix = 'module.'
    wrapped = isinstance(
        model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))

    if set(state_dict) == set(model.state_dict()):
        # Keys already line up with the model (e.g. a plain model that owns a
        # child literally named "module"); rewriting them would break loading.
        state = dict(state_dict)
    elif wrapped:
        # Wrapped model, unwrapped checkpoint: add the missing prefix.
        state = {
            (k if k.startswith(prefix) else prefix + k): v
            for k, v in state_dict.items()
        }
    else:
        # Plain model, wrapped checkpoint: drop the leading prefix.
        state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

    model.load_state_dict(state)
    print("Loaded successfully")
    return model
|
|
|
|
|
|
def load_wts(model, checkpoint_path):
    """Load weights into ``model`` from a checkpoint file on disk.

    The checkpoint is read with ``torch.load`` onto the CPU and then handed to
    ``load_state_dict``.

    Args:
        model (torch.nn.Module): Model to load the weights into.
        checkpoint_path: Path of the checkpoint file, as accepted by ``torch.load``.

    Returns:
        Whatever ``load_state_dict`` returns for this model and checkpoint.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version's defaults; only load checkpoints you trust.
    return load_state_dict(model, torch.load(checkpoint_path, map_location='cpu'))
|
|
|
|
|
|
def load_state_dict_from_url(model, url, **kwargs):
    """Download a checkpoint from ``url`` and load it into ``model``.

    Args:
        model (torch.nn.Module): Model to load the weights into.
        url (str): Location of the checkpoint to download.
        **kwargs: Extra keyword arguments forwarded to
            ``torch.hub.load_state_dict_from_url``. ``map_location`` defaults
            to 'cpu' but may be overridden by the caller.

    Returns:
        Whatever ``load_state_dict`` returns for this model and checkpoint.
    """
    # Use setdefault so an explicit map_location from the caller is honoured
    # instead of colliding with a hard-coded keyword (which raised TypeError).
    kwargs.setdefault('map_location', 'cpu')
    state_dict = torch.hub.load_state_dict_from_url(url, **kwargs)
    return load_state_dict(model, state_dict)
|
|
|
|
|
|
def load_state_from_resource(model, resource: str):
    """Loads weights to the model from a given resource. A resource can be of following types:

        1. URL. Prefixed with "url::"
            e.g. url::http(s)://url.resource.com/ckpt.pt

        2. Local path. Prefixed with "local::"
            e.g. local::/path/to/ckpt.pt

    Only the leading prefix is removed; everything after it is used verbatim
    as the URL or path, even if it happens to contain the prefix text again.

    Args:
        model (torch.nn.Module): Model
        resource (str): resource string

    Returns:
        torch.nn.Module: Model with loaded weights

    Raises:
        ValueError: If ``resource`` starts with neither "url::" nor "local::".
    """
    print(f"Using pretrained resource {resource}")

    url_prefix = 'url::'
    local_prefix = 'local::'

    if resource.startswith(url_prefix):
        # Slice rather than split: split('url::')[1] would truncate a URL that
        # contains the marker a second time.
        url = resource[len(url_prefix):]
        return load_state_dict_from_url(model, url, progress=True)

    if resource.startswith(local_prefix):
        path = resource[len(local_prefix):]
        return load_wts(model, path)

    raise ValueError("Invalid resource type, only url:: and local:: are supported")
|
|
|