# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import glob
import os
import re
import subprocess
import traceback

import torch
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist

# Extracts the numeric step from a checkpoint file name such as
# ``model_ckpt_steps_12000.ckpt``.
_CKPT_STEP_RE = re.compile(r'steps_(\d+)\.ckpt$')


def _strip_wrapper_prefixes(key):
    """Remove every ``module.`` and ``_orig_mod.`` fragment from a state-dict key."""
    return key.replace('module.', '').replace('_orig_mod.', '')


@contextlib.contextmanager
def dist_load(path):
    """Yield a path from which ``path`` can be read, staging it in /dev/shm if needed.

    When ``torch.distributed`` is not initialized, the world size is 1, or
    ``path`` already resolves to a location under ``/dev/shm``, ``path`` itself
    is yielded. Otherwise local rank 0 copies ``path`` to
    ``/dev/shm/<exp_name>/<basename>``, all ranks synchronize on a barrier, the
    staged path is yielded, and after a second barrier local rank 0 removes the
    staged copy.

    Args:
        path: Checkpoint file (or directory) to make available.

    Yields:
        The path that should actually be opened by the caller.
    """
    no_staging_needed = (
        not dist.is_initialized()
        or dist.get_world_size() == 1
        or os.path.realpath(path).startswith('/dev/shm')
    )
    if no_staging_needed:
        yield path
        return

    from tts.utils.commons.hparams import hparams
    from tts.utils.commons.trainer import LOCAL_RANK

    tmpdir = '/dev/shm'
    assert len(os.path.basename(path)) > 0
    shm_ckpt_path = f'{tmpdir}/{hparams["exp_name"]}/{os.path.basename(path)}'
    if LOCAL_RANK == 0:
        # Argument lists (no shell) so paths with spaces or shell
        # metacharacters are passed through verbatim.
        os.makedirs(os.path.dirname(shm_ckpt_path), exist_ok=True)
        subprocess.check_call(['cp', '-Lr', path, shm_ckpt_path])
    dist.barrier()
    try:
        yield shm_ckpt_path
    finally:
        # Always reach the barrier and clean up, even if the caller's body
        # raised, so the staged copy is not leaked and peers are not left
        # waiting on this rank.
        dist.barrier()
        if LOCAL_RANK == 0:
            subprocess.check_call(['rm', '-rf', shm_ckpt_path])


def torch_load_dist(path, map_location='cpu'):
    """Load a checkpoint with ``torch.load`` through :func:`dist_load`.

    Args:
        path: Checkpoint path.
        map_location: Forwarded to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the checkpoint.
    """
    with dist_load(path) as tmp_path:
        # NOTE: torch.load unpickles data; only use it on trusted checkpoints.
        checkpoint = torch.load(tmp_path, map_location=map_location)
    return checkpoint


def get_last_checkpoint(work_dir, steps=None):
    """Load the checkpoint with the highest step count found in ``work_dir``.

    Args:
        work_dir: Directory searched for ``model_ckpt_steps_*.ckpt`` files.
        steps: Optional step number restricting the search (see
            :func:`get_all_ckpts`).

    Returns:
        ``(checkpoint, path)``; both are ``None`` when nothing was found.
    """
    checkpoint = None
    last_ckpt_path = None
    ckpt_paths = get_all_ckpts(work_dir, steps)
    if ckpt_paths:
        last_ckpt_path = ckpt_paths[0]
        checkpoint = torch_load_dist(last_ckpt_path, map_location='cpu')
    return checkpoint, last_ckpt_path


def get_all_ckpts(work_dir, steps=None):
    """List step checkpoints in ``work_dir``, highest step first.

    Args:
        work_dir: Directory containing ``model_ckpt_steps_<N>.ckpt`` files.
        steps: When ``None`` or ``0`` every step checkpoint is listed;
            otherwise only the checkpoint for exactly that step.

    Returns:
        List of matching paths sorted by descending step number. Files that
        match the glob but carry no numeric step are skipped.
    """
    if steps is None or steps == 0:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'

    stepped_paths = []
    for path in glob.glob(ckpt_path_pattern):
        match = _CKPT_STEP_RE.search(os.path.basename(path))
        if match is not None:
            stepped_paths.append((int(match.group(1)), path))
    stepped_paths.sort(key=lambda item: -item[0])
    return [path for _, path in stepped_paths]


def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True,
              silent=False, load_opt=False, opts=None, steps=None, checkpoint=None,
              ckpt_path='', delete_unmatch=True):
    """Load model (and optionally optimizer) state from a checkpoint.

    Args:
        cur_model: A module, or a list of modules, to load weights into.
            ``DistributedDataParallel`` wrappers are unwrapped.
        ckpt_base_dir: Either a checkpoint file or a directory. For a
            directory, ``model_only_last.ckpt`` is preferred unless
            ``load_opt`` is set; otherwise the highest-step checkpoint is used.
        model_name: Key (or list of keys, parallel to ``cur_model``) into the
            checkpoint's ``"state_dict"``. A dotted name ``a.b`` selects the
            entries of ``state_dict["a"]`` whose keys start with ``b.``.
        force: When no checkpoint is found, raise ``AssertionError`` if true,
            otherwise print a message.
        strict: Forwarded to ``load_state_dict``.
        silent: Suppress the "loaded" / key-count messages.
        load_opt: Also restore optimizer states from
            ``checkpoint["optimizer_states"]`` into ``opts``.
        opts: Optimizers, parallel to the stored optimizer states.
        steps: Optional step number used when searching a directory.
        checkpoint: An already-loaded checkpoint dict; skips file lookup.
        ckpt_path: Path used only in log messages when ``checkpoint`` is given.
        delete_unmatch: With ``strict=False``, drop checkpoint entries whose
            shape differs from the model's before loading.

    Returns:
        ``checkpoint.get('global_step', 0)`` on success, ``None`` when no
        checkpoint was found and ``force`` is false (or when a ``None``
        optimizer is encountered while restoring optimizer states).

    Raises:
        AssertionError: No checkpoint was found and ``force`` is true, or the
            number of optimizers does not match the stored states.
    """
    if checkpoint is None:
        if os.path.isfile(ckpt_base_dir):
            base_dir = os.path.dirname(ckpt_base_dir)
            ckpt_path = ckpt_base_dir
            checkpoint = torch_load_dist(ckpt_base_dir, map_location='cpu')
        else:
            base_dir = ckpt_base_dir
            if load_opt:
                checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
            else:
                ckpt_path = f'{ckpt_base_dir}/model_only_last.ckpt'
                if os.path.exists(ckpt_path):
                    checkpoint = torch_load_dist(ckpt_path, map_location='cpu')
                else:
                    checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)

        if checkpoint is None:
            e_msg = f"| ckpt not found in {base_dir}."
            if force:
                # Explicit raise instead of `assert False` so the check
                # survives `python -O`; the exception type is unchanged.
                raise AssertionError(e_msg)
            print(e_msg)
            return None

    state_dict_all = {
        _strip_wrapper_prefixes(k): v for k, v in checkpoint["state_dict"].items()}

    if isinstance(cur_model, list):
        models, names = cur_model, model_name
    else:
        models, names = [cur_model], [model_name]

    for name, model in zip(names, models):
        if isinstance(model, DistributedDataParallel):
            model = model.module
        # Remember where the model lives; restored below after loading.
        device = next(model.parameters()).device

        if '.' not in name:
            state_dict = state_dict_all[name]
        else:
            base_model_name = name.split('.')[0]
            rest_model_name = name[len(base_model_name) + 1:]
            state_dict = {
                k[len(rest_model_name) + 1:]: v
                for k, v in state_dict_all[base_model_name].items()
                if k.startswith(f'{rest_model_name}.')}
        state_dict = {_strip_wrapper_prefixes(k): v for k, v in state_dict.items()}

        if not strict and delete_unmatch:
            try:
                model.load_state_dict(state_dict, strict=True)
                if not silent:
                    print(f"| loaded '{name}' from '{ckpt_path}' with strict=True.")
            except RuntimeError:
                # Strict load failed: drop entries whose shapes disagree so
                # the non-strict load below can proceed.
                model_state = {
                    _strip_wrapper_prefixes(k): v for k, v in model.state_dict().items()}
                unmatched_keys = []
                for key, param in state_dict.items():
                    if key in model_state:
                        new_param = model_state[key]
                        if new_param.shape != param.shape:
                            unmatched_keys.append(key)
                            print("| Unmatched keys: ", key, "cur model: ", new_param.shape,
                                  "ckpt model: ", param.shape)
                for key in unmatched_keys:
                    del state_dict[key]

        load_results = model.load_state_dict(state_dict, strict=strict)
        model.to(device)
        if not silent:
            print(f"| loaded '{name}' from '{ckpt_path}'.")
            missing_keys = load_results.missing_keys
            unexpected_keys = load_results.unexpected_keys
            print(f"| Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")

    if load_opt:
        optimizer_states = checkpoint['optimizer_states']
        assert len(opts) == len(optimizer_states)
        for optimizer, opt_state in zip(opts, optimizer_states):
            opt_state = {k.replace('_orig_mod.', ''): v for k, v in opt_state.items()}
            if optimizer is None:
                # NOTE(review): this aborts the whole function and returns
                # None rather than the global step - confirm this is intended.
                return
            try:
                optimizer.load_state_dict(opt_state)
                # Move optimizer tensors to the device of the last loaded model.
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to(device)
            except ValueError:
                print(f"| WARMING: optimizer {optimizer} parameters not match !!!")

    return checkpoint.get('global_step', 0)


def load_with_size_mismatch(model, state_dict, prefix=""):
    """Load only the entries of ``state_dict`` whose sizes match ``model``.

    Every occurrence of ``prefix`` is removed from each key (``str.replace``
    semantics). Keys unknown to the model are ignored; keys whose tensor size
    differs are reported and skipped. The remainder is loaded with
    ``strict=False`` and a summary is printed.

    Args:
        model: Module to load weights into.
        state_dict: Mapping of parameter names to tensors.
        prefix: Substring removed from every key before lookup.
    """
    current_model_dict = model.state_dict()
    mismatch_keys = set()
    new_state_dict = {}
    for raw_key, value in state_dict.items():
        key = raw_key.replace(prefix, "")
        if key not in current_model_dict:
            continue
        if value.size() == current_model_dict[key].size():
            new_state_dict[key] = value
        else:
            mismatch_keys.add(key)

    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
    print("| mismatch keys: ", mismatch_keys)
    if len(missing_keys) > 0:
        print(f"| missing_keys in dit: {missing_keys}")
    if len(unexpected_keys) > 0:
        print(f"| unexpected_keys in dit: {unexpected_keys}")