from functools import partial
import os

from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    FullStateDictConfig, StateDictType)
from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy,
                                         transformer_auto_wrap_policy)

from wenet.LLM.decoder import DecoderOnly
from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer
from wenet.paraformer.layers import AliParaformerEncoderLayer, SanmDecoderLayer
from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer
from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
                                             TransformerEncoderLayer)
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.utils.checkpoint import save_state_dict_and_infos
from wenet.utils.init_model import WENET_DECODER_CLASSES, WENET_ENCODER_CLASSES

WENET_ENCODER_LAYERS_CLASSES = {
    'transformer_encoder_layer': TransformerEncoderLayer,
    'conformer_encoder_layer': ConformerEncoderLayer,
    'paraformer_encoder_layer': AliParaformerEncoderLayer,
    'squeezeformer_encoder_layer': SqueezeformerEncoderLayer,
    'ebranchformer_encoder_layer': EBranchformerEncoderLayer,
    'efficient_conformer_encoder_layer': StrideConformerEncoderLayer,
    'branchformer_encoder_layer': BranchformerEncoderLayer,
}

WENET_DECODER_LAYERS_CLASSES = {
    'transformer_decoder_layer': DecoderLayer,
    'paraformer_decoder_layer': SanmDecoderLayer,
    # TODO(Mddct):
    #     1. wrap transducer's predictor and joint
    #     2. wrap paraformer's cif and ignore lstm
}


def wenet_fsdp_wrap_policy(mode):
    # Different wrap methods; please refer to:
    # https://openmmlab.medium.com/its-2023-is-pytorch-s-fsdp-the-best-choice-for-training-large-models-fe8d2848832f  # noqa
    assert mode in ['no_shard', 'model', 'zero2', 'zero3']
    if mode == 'no_shard':
        return None
    else:
        # TODO(Mddct): support user customization.
        # See more wrap methods:
        # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py#L13  # noqa
        if mode == 'model':
            # Wrap each whole encoder/decoder module as a single FSDP unit.
            enc_dec_wrap_policy = partial(
                lambda_auto_wrap_policy,
                lambda_fn=lambda module: isinstance(
                    module,
                    tuple(WENET_ENCODER_CLASSES.values()) + tuple(
                        WENET_DECODER_CLASSES.values())))
            return enc_dec_wrap_policy
        else:
            # zero2 / zero3: wrap every encoder/decoder layer separately.
            to_wrap_class = set()
            to_wrap_class.update(set(WENET_ENCODER_LAYERS_CLASSES.values()))
            to_wrap_class.update(set(WENET_DECODER_LAYERS_CLASSES.values()))
            layers_wrap_policy = partial(transformer_auto_wrap_policy,
                                         transformer_layer_cls=to_wrap_class)
            return layers_wrap_policy
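
# NOTE: the helper below is only an illustrative sketch of how the policy
# returned by wenet_fsdp_wrap_policy would typically be handed to FSDP. The
# name _example_wrap_with_fsdp and the mode -> ShardingStrategy mapping are
# assumptions for illustration (wenet's real wiring lives in its training
# utilities), and a torch.distributed process group is assumed to be
# initialized already.
def _example_wrap_with_fsdp(model, mode='zero3'):
    from torch.distributed.fsdp import ShardingStrategy

    # Assumed mapping: 'no_shard' keeps full replicas (DDP-like), 'zero2'
    # shards gradients/optimizer state, 'model' and 'zero3' also shard
    # parameters.
    strategy = {
        'no_shard': ShardingStrategy.NO_SHARD,
        'model': ShardingStrategy.FULL_SHARD,
        'zero2': ShardingStrategy.SHARD_GRAD_OP,
        'zero3': ShardingStrategy.FULL_SHARD,
    }[mode]
    return FSDP(model,
                auto_wrap_policy=wenet_fsdp_wrap_policy(mode),
                sharding_strategy=strategy)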

fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True,
                                            rank0_only=True)


def fsdp_save_model(model, save_model_path, info_dict):
    # TODO(Mddct): when the model is large, saving it will take a long time.
    # Saving only the shards, asynchronously, would avoid this, but the
    # current approach is good enough for now. This feature will be supported
    # when LLM training is supported in the future.
    rank = int(os.environ.get('RANK', 0))
    # Gather a full (unsharded) state dict, offloaded to CPU and kept only on
    # rank 0, then save it from rank 0.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                              fullstate_save_policy):
        state_dict = model.state_dict()
        if rank == 0:
            save_state_dict_and_infos(state_dict, save_model_path, info_dict)


def check_gradient_checkpoint(model):
    ckpt_layer_types = []
    if hasattr(model, 'encoder') and hasattr(model.encoder,
                                             'gradient_checkpointing'):
        if model.encoder.gradient_checkpointing:
            # Disable the model's built-in checkpointing; it is re-applied
            # through FSDP's wrapper in apply_fsdp_checkpointing below.
            model.encoder.gradient_checkpointing = False
            ckpt_layer_types += list(WENET_ENCODER_LAYERS_CLASSES.values())
    if hasattr(model, 'decoder') and hasattr(model.decoder,
                                             'gradient_checkpointing'):
        if model.decoder.gradient_checkpointing:
            model.decoder.gradient_checkpointing = False
            ckpt_layer_types += list(WENET_DECODER_LAYERS_CLASSES.values())
            if isinstance(model.decoder, DecoderOnly):
                ckpt_layer_types += [DecoderOnly]
    return tuple(ckpt_layer_types)


def apply_fsdp_checkpointing(model, ckpt_layer_types: tuple):
    # NOTE(Mddct): torch.utils.checkpoint is currently incompatible with
    # wenet's model mode, so activation checkpointing is applied via FSDP's
    # checkpoint wrapper instead. Please refer to:
    # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/policies/activation_checkpointing_functions.py#L21  # noqa
    if len(ckpt_layer_types) == 0:
        return
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper,
        CheckpointImpl,
        apply_activation_checkpointing,
    )
    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=lambda submodule: isinstance(submodule, ckpt_layer_types))
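
# NOTE: the sketch below is a non-authoritative illustration of how the
# utilities in this file fit together; _example_fsdp_train_and_save and its
# call order are assumptions for illustration, not wenet's actual training
# entry point.
def _example_fsdp_train_and_save(model, save_model_path, info_dict,
                                 mode='zero3'):
    # 1. Record which layer types had gradient checkpointing enabled and
    #    turn off the model's built-in (torch.utils.checkpoint) variant.
    ckpt_layer_types = check_gradient_checkpoint(model)
    # 2. Wrap/shard the model with the policy matching the chosen mode.
    wrapped = FSDP(model, auto_wrap_policy=wenet_fsdp_wrap_policy(mode))
    # 3. Re-enable activation checkpointing through FSDP's checkpoint wrapper.
    apply_fsdp_checkpointing(wrapped, ckpt_layer_types)
    # ... the training loop would run here ...
    # 4. Gather a full state dict on rank 0 and save it.
    fsdp_save_model(wrapped, save_model_path, info_dict)
    return wrapped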