from functools import partial
import os
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    FullStateDictConfig, StateDictType)
from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy,
                                         transformer_auto_wrap_policy)
from wenet.LLM.decoder import DecoderOnly
from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer
from wenet.paraformer.layers import AliParaformerEncoderLayer, SanmDecoderLayer
from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer
from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
                                             TransformerEncoderLayer)
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.utils.checkpoint import save_state_dict_and_infos
from wenet.utils.init_model import WENET_DECODER_CLASSES, WENET_ENCODER_CLASSES

WENET_ENCODER_LAYERS_CLASSES = {
    'transformer_encoder_layer': TransformerEncoderLayer,
    'conformer_encoder_layer': ConformerEncoderLayer,
    'paraformer_encoder_layer': AliParaformerEncoderLayer,
    'squeezeformer_encoder_layer': SqueezeformerEncoderLayer,
    'ebranchformer_encoder_layer': EBranchformerEncoderLayer,
    'efficient_conformer_encoder_layer': StrideConformerEncoderLayer,
    'branchformer_encoder_layer': BranchformerEncoderLayer,
}

WENET_DECODER_LAYERS_CLASSES = {
    'transformer_decoder_layer': DecoderLayer,
    'paraformer_decoder_layer': SanmDecoderLayer,
    # TODO(Mddct):
    #     1 wrap transducer's predictor and joint
    #     2 wrap paraformer's cif and ignore lstm
}


def wenet_fsdp_wrap_policy(mode):
    # different wrap methods
    # please refer: https://openmmlab.medium.com/its-2023-is-pytorch-s-fsdp-the-best-choice-for-training-large-models-fe8d2848832f # noqa
    assert mode in ['no_shard', 'model', 'zero2', 'zero3']
    if mode == 'no_shard':
        return None
    else:
        # TODO(Mddct): Support user customization
        # see more wrap methods:
        # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py#L13 # noqa
        if mode == 'model':
            enc_dec_wrap_policy = partial(
                lambda_auto_wrap_policy,
                lambda_fn=lambda module: isinstance(
                    module,
                    tuple(WENET_ENCODER_CLASSES.values()) + tuple(
                        WENET_DECODER_CLASSES.values())))
            return enc_dec_wrap_policy
        else:
            to_wrap_class = set()
            to_wrap_class.update(set(WENET_ENCODER_LAYERS_CLASSES.values()))
            to_wrap_class.update(set(WENET_DECODER_LAYERS_CLASSES.values()))
            layers_wrap_policy = partial(transformer_auto_wrap_policy,
                                         transformer_layer_cls=to_wrap_class)
            return layers_wrap_policy
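

# Usage sketch (illustrative only, not part of WeNet's trainer): the policy
# returned above is meant to be passed to the FSDP constructor as
# ``auto_wrap_policy``. The mode-to-ShardingStrategy mapping below is an
# assumption made for this example; WeNet wires the sharding strategy up in
# its training utilities.
def _example_wrap_with_fsdp(model, mode='zero3'):
    from torch.distributed.fsdp import ShardingStrategy
    strategy = {
        'no_shard': ShardingStrategy.NO_SHARD,
        'model': ShardingStrategy.FULL_SHARD,  # assumed mapping
        'zero2': ShardingStrategy.SHARD_GRAD_OP,
        'zero3': ShardingStrategy.FULL_SHARD,
    }[mode]
    return FSDP(model,
                auto_wrap_policy=wenet_fsdp_wrap_policy(mode),
                sharding_strategy=strategy)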


fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True,
                                            rank0_only=True)


def fsdp_save_model(model, save_model_path, info_dict):
    # TODO(Mddct): Saving a large model this way can take a long time. Ideally
    # the sharded state dict would be saved asynchronously, but this is good
    # enough for now; async saving will be added when LLM training is supported.
    rank = int(os.environ.get('RANK', 0))
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                              fullstate_save_policy):
        state_dict = model.state_dict()
        if rank == 0:
            save_state_dict_and_infos(state_dict, save_model_path, info_dict)
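

# Usage sketch (illustrative; the path and info_dict contents are made up for
# this example): gathering a FULL_STATE_DICT is a collective operation, so
# every rank must call fsdp_save_model, even though only rank 0 writes a file.
def _example_save_checkpoint(fsdp_model, epoch=0):
    save_path = 'exp/example/epoch_{}.pt'.format(epoch)  # hypothetical path
    fsdp_save_model(fsdp_model, save_path, info_dict={'epoch': epoch})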


def check_gradient_checkpoint(model):
    # Disable the model's built-in gradient checkpointing (it conflicts with
    # FSDP wrapping) and return the layer classes whose activations should
    # instead be checkpointed via apply_fsdp_checkpointing().
    ckpt_layer_types = []
    if hasattr(model, 'encoder') and hasattr(model.encoder,
                                             'gradient_checkpointing'):
        if model.encoder.gradient_checkpointing:
            model.encoder.gradient_checkpointing = False
            ckpt_layer_types += list(WENET_ENCODER_LAYERS_CLASSES.values())
    if hasattr(model, 'decoder') and hasattr(model.decoder,
                                             'gradient_checkpointing'):
        if model.decoder.gradient_checkpointing:
            model.decoder.gradient_checkpointing = False
            ckpt_layer_types += list(WENET_DECODER_LAYERS_CLASSES.values())
            if isinstance(model.decoder, DecoderOnly):
                ckpt_layer_types += [DecoderOnly]
    return tuple(ckpt_layer_types)


def apply_fsdp_checkpointing(model, ckpt_layer_types: tuple):
    # NOTE(Mddct): torch.utils.checkpoint, as used by WeNet's built-in gradient
    # checkpointing, is currently incompatible with FSDP-wrapped models, so
    # activation checkpointing is re-applied here with the non-reentrant
    # checkpoint wrapper. Please refer to
    # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/policies/activation_checkpointing_functions.py#L21 # noqa
    if len(ckpt_layer_types) == 0:
        return
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper,
        CheckpointImpl,
        apply_activation_checkpointing,
    )
    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=lambda submodule: isinstance(submodule, ckpt_layer_types))
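

# End-to-end sketch (the order of operations is assumed for illustration; the
# real call sites live in WeNet's training utilities): disable the model's own
# gradient checkpointing before FSDP wraps the modules, then re-apply
# activation checkpointing on the wrapped layers.
def _example_fsdp_setup(model, mode='zero3'):
    ckpt_layer_types = check_gradient_checkpoint(model)
    fsdp_model = FSDP(model, auto_wrap_policy=wenet_fsdp_wrap_policy(mode))
    apply_fsdp_checkpointing(fsdp_model, ckpt_layer_types)
    return fsdp_model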