# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2023 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext
import copy
from typing import List, Optional

import json
import logging
import os
import torch
import yaml
import torch.optim as optim
import torch.distributed as dist

from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    CPUOffload, MixedPrecision,
                                    sharded_grad_scaler, ShardingStrategy)

try:
    import deepspeed
    from deepspeed.runtime.zero.stage_1_and_2 import (
        estimate_zero2_model_states_mem_needs_all_live)
    from deepspeed.runtime.zero.stage3 import (
        estimate_zero3_model_states_mem_needs_all_live)
    from deepspeed.utils.zero_to_fp32 import (
        convert_zero_checkpoint_to_fp32_state_dict)
except ImportError:
    pass

from wenet.utils.checkpoint import save_checkpoint
from wenet.utils.common import (StepTimer, get_nested_attribute, lrs_to_str,
                                tensor_to_scalar)
from wenet.utils.fsdp_utils import (check_gradient_checkpoint, fsdp_save_model,
                                    apply_fsdp_checkpointing,
                                    wenet_fsdp_wrap_policy)
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE
from wenet.utils.init_dataset import init_dataset


def add_model_args(parser):
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help="override yaml config")
    parser.add_argument("--enc_init",
                        default=None,
                        type=str,
                        help="Pre-trained model to initialize encoder")
    parser.add_argument(
        '--enc_init_mods',
        default="encoder.",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help="List of encoder modules to initialize, separated by commas")
    parser.add_argument(
        '--freeze_modules',
        default="",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help='names of the modules to freeze',
    )
    return parser


def add_trace_args(parser):
    parser.add_argument('--jit',
                        action='store_true',
                        default=False,
                        help='use jit to trace the model during training')
    parser.add_argument('--print_model',
                        action='store_true',
                        default=False,
                        help='print model')
    return parser


def add_dataset_args(parser):
    parser.add_argument('--data_type',
                        default='raw',
                        # choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='Use pinned memory buffers for reading')
    parser.add_argument('--prefetch',
                        default=100,
                        type=int,
                        help='prefetch number')
    return parser


def add_lora_args(parser):
    '''Configure parameters for LoRA fine-tuning.

    Set use_lora and only_optimize_lora to true to enable LoRA.
    LoRA is injected into the model through (lora_modules, lora_attn_attr,
    lora_list). LoRA weights are merged after calling model.eval()
    (or model.train(mode=False)). LoRA weights need to be loaded after
    fine-tuning with DeepSpeed.
    '''
    parser.add_argument("--use_lora",
                        default=False,
                        type=bool,
                        help="whether to use LoRA fine-tuning.")
    parser.add_argument("--only_optimize_lora",
                        default=False,
                        type=bool,
                        help="freeze all other parameters and only optimize "
                        "LoRA-related parameters.")
    parser.add_argument(
        '--lora_modules',
        default="encoder.encoders",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help='names of the modules to inject LoRA into',
    )
    parser.add_argument(
        "--lora_attn_attr",
        default="self_attn,src_attn",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help="lora_attn_attr.")
    parser.add_argument(
        "--lora_list",
        default="linear_out,linear_q,linear_k,linear_v",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help="lora module list.")
    parser.add_argument("--lora_rank",
                        default=8,
                        type=int,
                        help="lora rank num.")
    parser.add_argument("--lora_alpha",
                        default=8,
                        type=int,
                        help="lora scale param, scale=lora_alpha/lora_rank.")
    parser.add_argument("--lora_dropout",
                        default=0,
                        type=float,
                        help="lora dropout param.")
    parser.add_argument("--lora_ckpt_path",
                        default=None,
                        type=str,
                        help="lora checkpoint path.")
    parser.add_argument("--lora_reinit",
                        default=False,
                        type=bool,
                        help="whether to reinitialize the LoRA weights; "
                        "the default is zero init.")
    parser.add_argument('--lora_init_yaml',
                        default="wenet/finetune/lora/config.yaml",
                        type=str,
                        help='Path to the configuration YAML file')
    return parser
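# A typical (hypothetical) LoRA fine-tuning invocation combining the flags
# above; the entry point, config and data paths are placeholders that depend
# on the recipe being used:
#
#   python wenet/bin/train.py --config $train_config --model_dir $lora_dir \
#       --train_data data/train.list --cv_data data/dev.list \
#       --use_lora True --only_optimize_lora True \
#       --lora_rank 8 --lora_alpha 8 --lora_dropout 0.1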
def add_ddp_args(parser):
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo', "hccl"],
                        help='distributed backend')
    parser.add_argument('--use_amp',
                        action='store_true',
                        default=False,
                        help='Use automatic mixed precision training')
    parser.add_argument('--fp16_grad_sync',
                        action='store_true',
                        default=False,
                        help='Use fp16 gradient sync for ddp')
    return parser


def add_deepspeed_args(parser):
    parser.add_argument('--timeout',
                        default=30,
                        type=int,
                        help='timeout (in seconds) of wenet_join. '
                        '30s for aishell & 300s for wenetspeech')
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help='local rank passed from distributed launcher')
    parser.add_argument('--deepspeed.save_states',
                        dest='save_states',
                        default='model_only',
                        choices=['model_only', 'model+optimizer'],
                        help='save model/optimizer states')
    # DeepSpeed automatically adds '--deepspeed' and '--deepspeed_config'
    # to the parser
    try:
        parser = deepspeed.add_config_arguments(parser)
    except Exception as e:
        print(e)
    return parser


def add_fsdp_args(parser):
    parser.add_argument(
        '--dtype',
        default='fp32',
        choices=['fp32', 'fp16', 'bf16'],
        help='when amp is used, dtype is automatically set to fp16. '
        'this arg has no effect when deepspeed is enabled.')
    parser.add_argument(
        '--fsdp_cpu_offload',
        default=False,
        type=bool,
        help='whether to offload parameters to CPU',
    )
    parser.add_argument(
        '--fsdp_sync_module_states',
        type=bool,
        default=True,
        help='each FSDP module will broadcast module parameters and buffers '
        'from rank 0 to ensure that they are replicated across ranks',
    )
    parser.add_argument(
        '--fsdp_sharding_strategy',
        default='zero2',
        # TODO(Mddct): pipeline and model parallel (3-D parallelism)
        choices=['no_shard', 'model', 'zero2', 'zero3'],
        help='Sharding strategy for FSDP. Choose from the following options:\n'
        '  - "no_shard": Equivalent to DistributedDataParallel (DDP).\n'
        '  - "model": WENET_ENC_DEC strategy, equivalent to DeepSpeed zero1.\n'
        '  - "zero2": SHARD_GRAD_OP strategy, equivalent to DeepSpeed zero2.\n'
        '  - "zero3": FULL_SHARD strategy, equivalent to DeepSpeed zero3.\n'
        'For more information, refer to the FSDP API documentation.')
    return parser


def init_distributed(args):
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
                 ', rank {}, world_size {}'.format(rank, world_size))
    if args.train_engine in ["torch_ddp", "torch_fsdp"]:
        if "cuda" in args.device:
            torch.cuda.set_device(local_rank)
        elif "npu" in args.device and TORCH_NPU_AVAILABLE:
            torch.npu.set_device(local_rank)
        else:
            logging.error("not supported device: {}".format(args.device))
        dist.init_process_group(args.dist_backend)
    elif args.train_engine == "deepspeed":
        deepspeed.init_distributed(dist_backend=args.dist_backend)
    else:
        logging.error("not supported engine: {}".format(args.train_engine))
    return world_size, local_rank, rank


def check_modify_and_save_config(args, configs, symbol_table):
    if args.train_engine in ["torch_ddp", "torch_fsdp"]:
        if args.use_amp:
            configs["dtype"] = "fp16"
            args.dtype = 'fp16'
        else:
            configs["dtype"] = args.dtype
    elif args.train_engine == "deepspeed":
        # NOTE(xcsong): DeepSpeed does not support uneven data. When using a
        #   custom dataset, we need to manually ensure that the data is evenly
        #   distributed across all processes; `train_utils.py::wenet_join` is
        #   implemented for this purpose.
        #   ref: https://github.com/microsoft/DeepSpeed/issues/2223
        #
        # NOTE(xcsong): We also need to keep:
        #   1. `train_micro_batch_size_per_gpu == 1`
        #   2. `accum_grad (in train_conformer.yaml)
        #       == gradient_accumulation_steps (in ds_config.json)`
        #   3. `grad_clip (in train_conformer.yaml)
        #       == gradient_clipping (in ds_config.json)`
        # The reason for this consistency check is that DeepSpeed's native
        # dataloader uses PyTorch's torch.utils.data.DistributedSampler, which
        # does not support IterableDataset. IterableDataset is extremely
        # useful in large-scale training because it lets you stream the data
        # without having to download the complete dataset.
        #   ref: https://github.com/microsoft/DeepSpeed/issues/1371
        #        https://github.com/microsoft/DeepSpeed/issues/285
        # To make DeepSpeed training compatible with IterableDataset, we have
        # to use a custom dataloader instead of DeepSpeed's native loader, and
        # thus we should configure the batch size in train_conformer.yaml
        # instead of ds_config.json. In contrast, gradient accumulation /
        # clipping should be configured in ds_config.json since they are
        # handled by DeepSpeed automatically.
        #   ref: https://github.com/microsoft/DeepSpeed/issues/62
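        # For illustration only (example values, not recommendations): with
        # `accum_grad: 4`, `grad_clip: 5.0` and `log_interval: 100` in the
        # training yaml, the keys checked below would have to appear in
        # ds_config.json as
        #   {
        #     "train_micro_batch_size_per_gpu": 1,
        #     "gradient_accumulation_steps": 4,
        #     "gradient_clipping": 5.0,
        #     "steps_per_print": 100
        #   }
        # alongside the usual optimizer / fp16 / zero_optimization sections.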
        with open(args.deepspeed_config, 'r') as fin:
            ds_configs = json.load(fin)
        if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
            configs["dtype"] = "fp16"
        elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
            configs["dtype"] = "bf16"
        else:
            configs["dtype"] = "fp32"
        assert ds_configs["train_micro_batch_size_per_gpu"] == 1
        assert ds_configs["gradient_accumulation_steps"] == configs[
            'accum_grad']
        assert ds_configs["gradient_clipping"] == configs['grad_clip']
        assert ds_configs["steps_per_print"] == configs['log_interval']

    if args.use_lora:
        configs['lora_conf'] = {}
        configs['lora_conf']['lora_modules'] = args.lora_modules
        configs['lora_conf']['lora_attn_attr'] = args.lora_attn_attr
        configs['lora_conf']['lora_list'] = args.lora_list
        configs['lora_conf']['lora_rank'] = args.lora_rank
        configs['lora_conf']['lora_alpha'] = args.lora_alpha
        configs['lora_conf']['lora_dropout'] = args.lora_dropout

    if configs["model"] == 'asr_model':
        if 'input_dim' not in configs:
            if 'fbank_conf' in configs['dataset_conf']:
                input_dim = configs['dataset_conf']['fbank_conf'][
                    'num_mel_bins']
            elif 'log_mel_spectrogram_conf' in configs['dataset_conf']:
                input_dim = configs['dataset_conf'][
                    'log_mel_spectrogram_conf']['num_mel_bins']
            else:
                input_dim = configs['dataset_conf']['mfcc_conf'][
                    'num_mel_bins']
        else:
            input_dim = configs['input_dim']
        configs['input_dim'] = input_dim

    configs, _ = get_blank_id(configs, symbol_table)
    configs['output_dim'] = configs['vocab_size']

    configs['train_engine'] = args.train_engine
    configs['use_amp'] = args.use_amp
    configs['model_dir'] = args.model_dir
    configs['save_states'] = args.save_states

    # Save configs to model_dir/train.yaml for inference and export
    if int(os.environ.get('RANK', 0)) == 0:
        saved_config_path = os.path.join(args.model_dir, 'train.yaml')
        with open(saved_config_path, 'w') as fout:
            data = yaml.dump(configs)
            fout.write(data)

    if configs["model_conf"].get("apply_non_blank_embedding", False):
        logging.warning('You had better load a well-trained model '
                        'if apply_non_blank_embedding is true !!!')

    return configs


def init_dataset_and_dataloader(args, configs, tokenizer, seed=777):
    generator = torch.Generator()
    generator.manual_seed(seed)
    # if save_interval is in configs, use steps mode, otherwise epoch mode
    if "save_interval" in configs:
        configs['dataset_conf']['cycle'] = configs.get('max_epoch', 100)
    conf = configs['dataset_conf']
    dataset_type = configs.get('dataset', 'asr')
    configs['vocab_size'] = tokenizer.vocab_size()
    train_dataset = init_dataset(dataset_type,
                                 args.data_type,
                                 args.train_data,
                                 tokenizer,
                                 conf,
                                 True,
                                 split='train')
    tag = configs["init_infos"].get("tag", "init")
configs["init_infos"].get("tag", "init") train_dataset.set_epoch(configs["init_infos"].get('epoch', 0) + int("epoch_" in tag) - 1) cv_conf = copy.deepcopy(conf) cv_conf['split_num'] = 1 cv_dataset = init_dataset(dataset_type, args.data_type, args.cv_data, tokenizer, cv_conf, partition=False, split='cv') # NOTE(xcsong): Why we prefer persistent_workers=True ? # https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110 train_data_loader = DataLoader(train_dataset, batch_size=None, pin_memory=args.pin_memory, num_workers=args.num_workers, persistent_workers=True, generator=generator, prefetch_factor=args.prefetch) cv_data_loader = DataLoader(cv_dataset, batch_size=None, pin_memory=args.pin_memory, num_workers=args.num_workers, persistent_workers=True, generator=generator, prefetch_factor=args.prefetch) return train_dataset, cv_dataset, train_data_loader, cv_data_loader def wrap_cuda_model(args, model, configs=None): local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1)) world_size = int(os.environ.get('WORLD_SIZE', 1)) if hasattr(model, 'encoder'): grad_ckpt = getattr(model.encoder, 'gradient_checkpointing', False) else: grad_ckpt = False if args.train_engine == "torch_ddp": # native pytorch ddp device = torch.device(args.device) model.to(device) # model = torch.nn.parallel.DistributedDataParallel( # model, find_unused_parameters=not grad_ckpt) model = torch.nn.parallel.DistributedDataParallel( model, find_unused_parameters=True) elif args.train_engine == "deepspeed": # deepspeed # NOTE(xcsong): look in detail how the memory estimator API works: # https://deepspeed.readthedocs.io/en/latest/memory.html#discussion if int(os.environ.get('RANK', 0)) == 0: logging.info("Estimating model states memory needs (zero2)...") estimate_zero2_model_states_mem_needs_all_live( model, num_gpus_per_node=local_world_size, num_nodes=world_size // local_world_size) logging.info("Estimating model states memory needs (zero3)...") estimate_zero3_model_states_mem_needs_all_live( model, num_gpus_per_node=local_world_size, num_nodes=world_size // local_world_size) device = torch.device(args.device) # Init device later pass # Init DeepSpeed later elif args.train_engine == 'torch_fsdp': assert configs is not None mixed_precision_dtype = { 'fp32': torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16, }[configs['dtype']] sharding_strategy = { 'model': ShardingStrategy.SHARD_GRAD_OP, 'zero2': ShardingStrategy.SHARD_GRAD_OP, 'zero3': ShardingStrategy.FULL_SHARD, 'no_shard': ShardingStrategy.NO_SHARD, }[args.fsdp_sharding_strategy] wrap_policy = wenet_fsdp_wrap_policy(mode=args.fsdp_sharding_strategy) layer_types = check_gradient_checkpoint(model) if "cuda" in args.device: device_id = torch.cuda.current_device() elif "npu" in args.device and TORCH_NPU_AVAILABLE: device_id = torch.npu.current_device() else: logging.error("not supported device: {}".format(args.device)) model = FSDP( model, auto_wrap_policy=wrap_policy, cpu_offload=CPUOffload(offload_params=True) if args.fsdp_cpu_offload is True else None, mixed_precision=MixedPrecision( param_dtype=mixed_precision_dtype, reduce_dtype=mixed_precision_dtype, buffer_dtype=mixed_precision_dtype, ), sharding_strategy=sharding_strategy, limit_all_gathers=True, use_orig_params=True, sync_module_states=args.fsdp_sync_module_states, # init_distributed is called (torch.cuda.set_device), # we should set device_id, see FSDP api device_id=device_id) apply_fsdp_checkpointing(model, layer_types) device = torch.device(args.device) else: 
logging.error("not supported engine: {}".format(args.train_engine)) if args.train_engine in ["torch_fsdp", "torch_ddp"]: if args.fp16_grad_sync: from torch.distributed.algorithms.ddp_comm_hooks import ( default as comm_hooks, ) model.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook) return model, device def init_optimizer_and_scheduler(args, configs, model): groups = [] lr = configs['optim_conf'].get('lr') if isinstance(lr, List): assert configs['scheduler'] == 'warmuplr' modules_m = configs['optim_conf']['modules'] assert isinstance(modules_m, List) assert len(modules_m) + 1 == len(lr) special_param_ids = set() rest_params = [] for (i, m_str) in enumerate(modules_m): sub_module = get_nested_attribute(model, m_str) subs_params = [] for _, sub_params in sub_module.named_parameters(): subs_params.append(sub_params) special_param_ids.add(id(sub_params)) groups.append({'params': subs_params, 'lr': lr[i]}) # other model's parameters for _, param in model.named_parameters(): if id(param) not in special_param_ids: rest_params.append(param) groups.append({'params': rest_params, 'lr': lr[-1]}) params = groups if len(groups) > 0 else model.parameters() optim_conf = copy.deepcopy(configs['optim_conf']) if 'modules' in optim_conf: del optim_conf['modules'] if isinstance(lr, List): optim_conf['lr'] = lr[-1] if configs['optim'] == 'adam': optimizer = optim.Adam(params, **optim_conf) elif configs['optim'] == 'adamw': optimizer = optim.AdamW(params, **optim_conf) else: raise ValueError("unknown optimizer: " + configs['optim']) scheduler_type = None if configs['scheduler'] == 'warmuplr': scheduler_type = WarmupLR scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) elif configs['scheduler'] == 'NoamHoldAnnealing': scheduler_type = NoamHoldAnnealing scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf']) else: raise ValueError("unknown scheduler: " + configs['scheduler']) # NOTE(xcsong): Custom optimizer might yield poor performance when # zero-offload is enabled, if you do want to offload optimizer to CPU, # please set optimizer in ds_config.json, see: # (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters) if args.train_engine == "deepspeed": with open(args.deepspeed_config, 'r') as fin: ds_configs = json.load(fin) if "optimizer" in ds_configs: # NOTE(xcsong): Disable custom optimizer if it is set in ds_config, # extremely useful when enable cpu_offload, DeepspeedCpuAdam # could be 4~5x faster than torch native adam optimizer = None if "scheduler" in ds_configs: scheduler = None else: def scheduler(opt): return scheduler_type(opt, **configs['scheduler_conf']) model, optimizer, _, scheduler = deepspeed.initialize( args=args, model=model, optimizer=optimizer, lr_scheduler=scheduler, model_parameters=model.parameters()) step = configs.get("init_infos", {}).get("step", -1) scheduler.set_step(step) return model, optimizer, scheduler def trace_and_print_model(args, model): # !!!IMPORTANT!!! 
def trace_and_print_model(args, model):
    # !!!IMPORTANT!!!
    # Try to export the model by script; if it fails, we should refine
    # the code to satisfy the script export requirements
    if int(os.environ.get('RANK', 0)) == 0:
        if args.jit:
            script_model = torch.jit.script(model)
            script_model.save(os.path.join(args.model_dir, 'init.zip'))
        if args.print_model:
            print(model)
            num_params = sum(p.numel() for p in model.parameters())
            print('the number of model params: {:,d}'.format(num_params))


def init_summarywriter(args):
    writer = None
    if int(os.environ.get('RANK', 0)) == 0:
        os.makedirs(args.model_dir, exist_ok=True)
        exp_id = os.path.basename(args.model_dir)
        writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))
    return writer


def init_scaler(args):
    scaler = None
    if args.use_amp:
        if "cuda" in args.device:
            scaler = torch.cuda.amp.GradScaler()
        elif "npu" in args.device and TORCH_NPU_AVAILABLE:
            scaler = torch.npu.amp.GradScaler()
        else:
            logging.error("not supported device: {}".format(args.device))
    elif args.train_engine == 'torch_fsdp':
        # why bf16 does not need a scaler:
        #   https://discuss.pytorch.org/t/why-bf16-do-not-need-loss-scaling/176596
        if args.dtype in ['fp16']:
            scaler = sharded_grad_scaler.ShardedGradScaler(enabled=True)
    return scaler


def save_model(model, info_dict):
    rank = int(os.environ.get('RANK', 0))
    tag = info_dict["tag"]
    model_dir = info_dict["model_dir"]
    save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
    # save ckpt
    if info_dict["train_engine"] == "deepspeed":
        # NOTE(xcsong): All ranks should call this API, but only rank 0
        #   saves the general model params. see:
        #   https://github.com/microsoft/DeepSpeed/issues/2993
        with torch.no_grad():
            model.save_checkpoint(save_dir=model_dir,
                                  tag=tag,
                                  client_state=info_dict)
            if info_dict["save_states"] == "model_only" and rank == 0:
                convert_zero_checkpoint_to_fp32_state_dict(model_dir,
                                                           save_model_path,
                                                           tag=tag)
                os.system("rm -rf {}/{}".format(model_dir, tag))
    elif info_dict['train_engine'] == "torch_fsdp":
        fsdp_save_model(model, save_model_path, info_dict)
    elif rank == 0:
        # NOTE(xcsong): For torch_ddp, only rank-0 should call this.
        save_checkpoint(model, save_model_path, info_dict)
    # save yaml
    if rank == 0:
        with open("{}/{}.yaml".format(model_dir, tag), 'w') as fout:
            data = yaml.dump(info_dict)
            fout.write(data)


def wenet_join(group_join, info_dict):
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    train_engine = info_dict.get('train_engine', "torch_ddp")

    if info_dict["batch_idx"] == 0 or train_engine == "torch_ddp":
        # NOTE(xcsong): skip the first batch because its processing time
        #   includes dataloader initialization time, which may exceed
        #   30 seconds
        return False

    try:
        # NOTE(xcsong): Why do we need a new group?
        #   Because DeepSpeed has its own group where all the relevant
        #   communication operations are executed. If we add a communication
        #   operation that is not managed by DeepSpeed to that group, it is
        #   highly likely to cause communication chaos, resulting in
        #   hard-to-troubleshoot hangs.
        dist.monitored_barrier(group=group_join,
                               timeout=group_join.options._timeout)
    except RuntimeError as e:
        logging.info("Detected uneven workload distribution: {}\n".format(e) +
                     "Break current worker to manually join all workers, " +
                     "world_size {}, current rank {}, current local_rank {}\n".
                     format(world_size, rank, local_rank))
        return True

    return False
def batch_forward(model, batch, scaler, info_dict, device):
    train_engine = info_dict.get('train_engine', "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1)

    dtype = info_dict.get("dtype", "fp32")
    if dtype == "fp16":
        dtype = torch.float16
    elif dtype == "bf16":
        dtype = torch.bfloat16
    else:  # fp32
        dtype = None

    # autocast context
    # More details about amp can be found at:
    #   https://pytorch.org/docs/stable/notes/amp_examples.html
    amp_autocast = torch.cuda.amp.autocast
    if "npu" in device.__str__() and TORCH_NPU_AVAILABLE:
        amp_autocast = torch.npu.amp.autocast
    autocast = {
        "deepspeed":
        amp_autocast(enabled=dtype is not None,
                     dtype=dtype,
                     cache_enabled=False),
        "torch_ddp":
        amp_autocast(enabled=scaler is not None),
        "torch_fsdp":
        amp_autocast(enabled=True, dtype=dtype)
        if dtype is not None else nullcontext()
    }[train_engine]
    with autocast:
        loss_dict = model(batch, device)
    info_dict['loss_dict'] = loss_dict

    return info_dict


def batch_backward(model, scaler, info_dict):
    train_engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1)
    use_amp = info_dict.get('use_amp', False)
    if use_amp:
        assert scaler is not None
    loss = info_dict['loss_dict']['loss']

    if train_engine == "deepspeed":
        # NOTE(xcsong): `model.backward(loss)` is equivalent to
        #   `scale_loss_wrt_accum_grad + loss.backward()`
        #   ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api
        scaled_loss = model.backward(loss)
    else:
        assert train_engine in ["torch_ddp", "torch_fsdp"]
        scaled_loss = loss / accum_grad
        if scaler is not None:
            # fp16 (amp and fsdp)
            scaler.scale(scaled_loss).backward()
        else:
            # float32 (ddp and fsdp)
            # bf16 (fsdp)
            scaled_loss.backward()
    info_dict['loss_dict']['loss'] = scaled_loss
    for loss_name, loss_value in info_dict['loss_dict'].items():
        if loss_value is not None:
            info_dict['loss_dict'][loss_name] = tensor_to_scalar(loss_value)

    return info_dict


def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
    rank = int(os.environ.get('RANK', 0))
    train_engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1)
    use_amp = info_dict.get('use_amp', False)
    clip = info_dict.get('grad_clip', 50.0)
    batch_idx = info_dict["batch_idx"]
    if use_amp:
        assert scaler is not None

    grad_norm = 0.0
    if train_engine == "deepspeed":
        # NOTE(xcsong): The step() function in the DeepSpeed engine updates
        #   the model parameters as well as the learning rate.
        #   Zeroing the gradients is handled automatically by DeepSpeed after
        #   the weights have been updated using a mini-batch.
        #   DeepSpeed also performs gradient averaging automatically at the
        #   gradient accumulation boundaries and handles clip_grad_norm
        #   internally.
        #   `ds_model.step() = clip_grad_norm_() + optimizer.step()
        #    + optimizer.zero_grad() + scheduler.step()`
        #   ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api
        info_dict["is_gradient_accumulation_boundary"] = \
            model.is_gradient_accumulation_boundary()
        model.step()
        grad_norm = model.get_global_grad_norm()
        if grad_norm is None:
            grad_norm = 0.0
    elif (batch_idx + 1) % accum_grad == 0:
        # Use mixed precision training
        # fp16 (ddp fsdp)
        if scaler is not None:
            scaler.unscale_(optimizer)
            if train_engine == "torch_ddp":
                grad_norm = clip_grad_norm_(model.parameters(), clip)
            else:  # fsdp
                grad_norm = model.clip_grad_norm_(clip)
            # Must invoke scaler.update() if unscale_() is used in
            # the iteration to avoid the following error:
            #   RuntimeError: unscale_() has already been called
            #   on this optimizer since the last update().
            # We don't check the gradient here because, if it has inf/nan
            # values, scaler.step() will skip optimizer.step().
            scaler.step(optimizer)
            scaler.update()
        else:
            if train_engine == "torch_ddp":
                grad_norm = clip_grad_norm_(model.parameters(), clip)
            else:
                grad_norm = model.clip_grad_norm_(clip)
            if torch.isfinite(grad_norm):
                optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
    info_dict["lrs"] = [group['lr'] for group in optimizer.param_groups]
    info_dict["grad_norm"] = tensor_to_scalar(grad_norm)

    return info_dict


def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
    tag = info_dict["tag"]
    step = info_dict["step"]
    batch_idx = info_dict["batch_idx"]
    loss_dict = info_dict['loss_dict']
    epoch = info_dict.get('epoch', 0)
    train_engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1) if tag != "CV" else 1
    log_interval = info_dict.get('log_interval', 10)
    lrs = info_dict.get("lrs", [0.0])
    is_gradient_accumulation_boundary = info_dict.get(
        "is_gradient_accumulation_boundary", False)

    rank = int(os.environ.get('RANK', 0))
    # TRAIN Tensorboard
    if tag == "TRAIN" and rank == 0 and writer is not None:
        if (train_engine == "deepspeed" and is_gradient_accumulation_boundary
            ) or (train_engine in ["torch_ddp", "torch_fsdp"] and
                  (batch_idx + 1) % accum_grad == 0):
            writer.add_scalar('train/train_loss',
                              tensor_to_scalar(loss_dict['loss']) * accum_grad,
                              step)
            writer.add_scalar('train/grad_norm', info_dict['grad_norm'], step)
            for name, value in loss_dict.items():
                if name != 'loss' and value is not None:
                    writer.add_scalar('train/{}'.format(name),
                                      tensor_to_scalar(value), step)
            # lr
            for i, lr in enumerate(lrs):
                writer.add_scalar('train/lr_{}'.format(i), lr, step)
    # CV Tensorboard
    elif "step_" in tag and rank == 0 and writer is not None:
        for name, value in loss_dict.items():
            writer.add_scalar('cv/{}'.format(name), tensor_to_scalar(value),
                              step)
        logging.info(
            'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
                epoch, step + 1, lrs_to_str(lrs),
                tensor_to_scalar(loss_dict["loss"]), rank,
                tensor_to_scalar(loss_dict["acc"])))
        return

    # TRAIN & CV, Shell log (stdout)
    if (batch_idx + 1) % log_interval == 0:
        log_str = '{} | '.format(tag)
        if timer is not None:
            timer_step = step
            if info_dict.get("cv_step", None) is not None:
                timer_step = info_dict['cv_step']
            steps_per_second = timer.steps_per_second(timer_step)
            log_str += 'steps/sec {:.3f}| '.format(steps_per_second)
        log_str += 'Batch {}/{} loss {:.6f} '.format(
            epoch, batch_idx + 1 if 'save_interval' not in info_dict else
            (step + 1) * accum_grad,
            tensor_to_scalar(loss_dict['loss']) * accum_grad)
        for name, value in loss_dict.items():
            if name != 'loss' and value is not None:
                log_str += '{} {:.6f} '.format(name, tensor_to_scalar(value))
        if tag == "TRAIN":
            log_str += 'lr {} grad_norm {:.6f} rank {}'.format(
                lrs_to_str(lrs), info_dict['grad_norm'], rank)
        logging.debug(log_str)


def log_per_epoch(writer, info_dict):
    epoch = info_dict["epoch"]
    loss_dict = info_dict["loss_dict"]
    lrs = info_dict['lrs']
    rank = int(os.environ.get('RANK', 0))
    step = info_dict["step"]
    logging.info(
        'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
            epoch, step, lrs_to_str(lrs), tensor_to_scalar(loss_dict["loss"]),
            rank, tensor_to_scalar(loss_dict["acc"])))

    if int(os.environ.get('RANK', 0)) == 0:
        for i, lr in enumerate(info_dict["lrs"]):
            writer.add_scalar('epoch/lr_{}'.format(i), lr, epoch)
        for name, value in loss_dict.items():
            writer.add_scalar('epoch/{}'.format(name),
                              tensor_to_scalar(value), epoch)


def freeze_modules(model, args):
    for name, param in model.named_parameters():
        for module_name in args.freeze_modules:
            if module_name in name:
                param.requires_grad = False
                logging.debug("{} module is frozen".format(name))


def reinit_lora(model, args, configs, tokenizer, seed=777):
    from tqdm import tqdm
    from wenet.finetune.lora.utils import (estimate_gradient,
                                           reinit_lora_modules)
    from wenet.finetune.lora.layers import LoRALayer
    from types import SimpleNamespace

    logging.info("reinit lora modules.")
    with open(args.lora_init_yaml, 'r') as file:
        lora_config = yaml.safe_load(file)
    generator = torch.Generator()
    generator.manual_seed(seed)
    dataset_conf = copy.deepcopy(configs['dataset_conf'])
    dataset_conf['batch_conf']['batch_size'] = lora_config['init_batch_size']
    dataset_type = configs.get('dataset', 'asr')
    dataset = init_dataset(dataset_type, args.data_type, args.train_data,
                           tokenizer, dataset_conf, True)
    dataloader = DataLoader(dataset,
                            batch_size=None,
                            pin_memory=args.pin_memory,
                            num_workers=args.num_workers,
                            persistent_workers=True,
                            generator=generator,
                            prefetch_factor=args.prefetch)
    additional_kwargs = {}
    if lora_config["init_config"]["mode"] == "gradient":
        named_grads = estimate_gradient(model, dataloader,
                                        lora_config['init_iters'])
        additional_kwargs["named_grads"] = named_grads
    lora_config = SimpleNamespace(**lora_config["init_config"])
    for name, module in tqdm(
            model.named_modules(),
            desc="Reinitializing Lora",
            total=len(list(model.named_modules())),
    ):
        if isinstance(module, LoRALayer):
            reinit_lora_modules(name, module, lora_config, **additional_kwargs)

    # the lora_init model needs to be saved: w0 = w0 - A0 * B0
    save_checkpoint(model, os.path.join(args.model_dir, "lora_init.pt"),
                    infos={"tag": "lora_init", **configs})
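# These helpers are typically composed per training step roughly as follows
# (simplified sketch; the real loop lives in the training script / executor
# and additionally handles wenet_join, checkpointing and CV):
#
#   for batch_idx, batch_dict in enumerate(train_data_loader):
#       info_dict["batch_idx"] = batch_idx
#       info_dict = batch_forward(model, batch_dict, scaler, info_dict, device)
#       info_dict = batch_backward(model, scaler, info_dict)
#       info_dict = update_parameter_and_lr(model, optimizer, scheduler,
#                                           scaler, info_dict)
#       log_per_step(writer, info_dict, timer)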