Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang) | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import logging | |
import os | |
import re | |
import yaml | |
import torch | |
from collections import OrderedDict | |
import datetime | |
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    """Load a saved state dict into ``model`` and return its side-car infos.

    The checkpoint at ``path`` is read onto the CPU and applied with
    ``strict=False``; rank 0 (taken from the ``RANK`` environment variable,
    default 0) logs every missing and unexpected key.

    The infos live in a YAML file next to the checkpoint: a trailing ``.pt``
    is replaced by ``.yaml``; any other path gets ``.yaml`` appended.

    Args:
        model: module that receives the loaded parameters.
        path: path of the checkpoint file.

    Returns:
        Whatever the YAML file parses to (normally a dict), or an empty dict
        when the file does not exist or is empty.
    """
    rank = int(os.environ.get('RANK', 0))
    logging.info('[Rank {}] Checkpoint: loading from checkpoint {}'.format(
        rank, path))
    checkpoint = torch.load(path, map_location='cpu')
    missing_keys, unexpected_keys = model.load_state_dict(checkpoint,
                                                          strict=False)
    if rank == 0:
        for key in missing_keys:
            logging.info("missing tensor: {}".format(key))
        for key in unexpected_keys:
            logging.info("unexpected tensor: {}".format(key))

    # The dot must be escaped: a bare '.pt$' also matches e.g. 'kpt' in
    # 'model.ckpt' and would yield 'model.c.yaml'.
    info_path = re.sub(r'\.pt$', '.yaml', path)
    if info_path == path:
        # No '.pt' suffix: never treat the checkpoint itself as the YAML file.
        info_path = path + '.yaml'

    configs = {}
    if os.path.exists(info_path):
        with open(info_path, 'r') as fin:
            # NOTE(review): FullLoader is not meant for untrusted input;
            # only load info files that come from a trusted source.
            configs = yaml.load(fin, Loader=yaml.FullLoader)
    if configs is None:
        # An empty YAML document parses to None.
        configs = {}
    return configs
def save_state_dict_and_infos(state_dict, path: str, infos=None):
    """Write ``state_dict`` to ``path`` and its infos to a side-car YAML file.

    The YAML path is ``path`` with a trailing ``.pt`` replaced by ``.yaml``;
    when ``path`` has no ``.pt`` suffix, ``.yaml`` is appended instead so the
    info file can never overwrite the checkpoint that was just saved.

    Args:
        state_dict: object handed to ``torch.save``.
        path: destination of the checkpoint file.
        infos (dict or None): extra information to store. A ``save_time``
            entry is added to it in place before it is dumped; ``None`` is
            treated as an empty dict.
    """
    rank = int(os.environ.get('RANK', 0))
    logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(
        rank, path))
    torch.save(state_dict, path)

    # The dot must be escaped: a bare '.pt$' also matches e.g. 'kpt' in
    # 'model.ckpt' and would yield 'model.c.yaml'.
    info_path = re.sub(r'\.pt$', '.yaml', path)
    if info_path == path:
        # No '.pt' suffix: writing the YAML to `path` would clobber the
        # weights saved above.
        info_path = path + '.yaml'

    if infos is None:
        infos = {}
    infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
    with open(info_path, 'w') as fout:
        data = yaml.dump(infos)
        fout.write(data)
def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    """Save the state dict of ``model`` via ``save_state_dict_and_infos``.

    When ``model`` is a ``DataParallel`` or ``DistributedDataParallel``
    wrapper, the state dict is taken from ``model.module`` rather than from
    the wrapper itself.

    Args:
        model: the module (or parallel wrapper) to save.
        path: destination of the checkpoint file.
        infos (dict or None): any info you want to save.
    """
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    target = model.module if isinstance(model, parallel_wrappers) else model
    save_state_dict_and_infos(target.state_dict(), path, infos)
def filter_modules(model_state_dict, modules):
    """Return the entries of ``modules`` that prefix a key of the state dict.

    An entry is kept when at least one key of ``model_state_dict`` starts
    with it. Entries that match nothing are dropped; on rank 0 (``RANK``
    environment variable, default 0) they are reported in a warning together
    with the keys that are available.

    Args:
        model_state_dict: mapping whose keys are checked against the prefixes.
        modules: iterable of prefix strings to look for.

    Returns:
        List of the matching prefixes, in their original order.
    """
    rank = int(os.environ.get('RANK', 0))
    available_keys = model_state_dict.keys()

    matched = []
    unmatched = []
    for prefix in modules:
        found = False
        for name in available_keys:
            if name.startswith(prefix):
                found = True
                break
        if found:
            matched.append(prefix)
        else:
            unmatched.append(prefix)

    if rank == 0 and unmatched:
        logging.warning(
            "module(s) %s don't match or (partially match) "
            "available modules in model.",
            unmatched,
        )
        logging.warning("for information, the existing modules in model are:")
        logging.warning("%s", available_keys)
    return matched
def load_trained_modules(model: torch.nn.Module, args):
    """Initialise selected parts of ``model`` from a pre-trained checkpoint.

    The state dict stored at ``args.enc_init`` is loaded onto the CPU, the
    entries whose key starts with one of the prefixes in
    ``args.enc_init_mods`` (after ``filter_modules`` has dropped prefixes
    that match nothing) are copied over the model's own state dict, and the
    merged state dict is loaded back into ``model``. When the checkpoint
    path is unset or is not a file, a warning is logged and the model keeps
    its current parameters.

    Args:
        model: module to initialise in place.
        args: object exposing ``enc_init`` (checkpoint path) and
            ``enc_init_mods`` (iterable of key prefixes).

    Returns:
        An empty dict.
    """
    # Load encoder modules with pre-trained model(s).
    enc_model_path = args.enc_init
    enc_modules = args.enc_init_mods
    main_state_dict = model.state_dict()
    logging.warning("model(s) found for pre-initialization")
    # os.path.isfile(None) raises TypeError, so treat an unset path like a
    # missing file instead of crashing.
    if enc_model_path is not None and os.path.isfile(enc_model_path):
        logging.info('Checkpoint: loading from checkpoint %s for CPU',
                     enc_model_path)
        model_state_dict = torch.load(enc_model_path, map_location='cpu')
        modules = filter_modules(model_state_dict, enc_modules)
        partial_state_dict = OrderedDict(
            (key, value) for key, value in model_state_dict.items()
            if any(key.startswith(m) for m in modules))
        main_state_dict.update(partial_state_dict)
    else:
        logging.warning("model was not found : %s", enc_model_path)
    model.load_state_dict(main_state_dict)
    configs = {}
    return configs