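"""FSDP utilities for wenet.

Provides auto-wrap policies for the supported sharding modes, a rank-0
full-state-dict checkpoint saver, and helpers for re-applying activation
checkpointing to FSDP-wrapped models.
"""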
from functools import partial
import os
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    FullStateDictConfig, StateDictType)

from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy,
                                         transformer_auto_wrap_policy)
from wenet.LLM.decoder import DecoderOnly
from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer
from wenet.paraformer.layers import AliParaformerEncoderLayer, SanmDecoderLayer
from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer
from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
                                             TransformerEncoderLayer)
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.utils.checkpoint import save_state_dict_and_infos
from wenet.utils.init_model import WENET_DECODER_CLASSES, WENET_ENCODER_CLASSES

WENET_ENCODER_LAYERS_CLASSES = {
    'transformer_encoder_layer': TransformerEncoderLayer,
    'conformer_encoder_layer': ConformerEncoderLayer,
    'paraformer_encoder_layer': AliParaformerEncoderLayer,
    'squeezeformer_encoder_layer': SqueezeformerEncoderLayer,
    'ebranchformer_encoder_layer': EBranchformerEncoderLayer,
    'efficient_conformer_encoder_layer': StrideConformerEncoderLayer,
    'branchformer_encoder_layer': BranchformerEncoderLayer,
}

WENET_DECODER_LAYERS_CLASSES = {
    'transformer_decoder_layer': DecoderLayer,
    'paraformer_decoder_layer': SanmDecoderLayer,
    # TODO(Mddct):
    #     1. wrap the transducer's predictor and joint network
    #     2. wrap the paraformer's CIF predictor and skip its LSTM
}


def wenet_fsdp_wrap_policy(mode):
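    """Build an FSDP auto-wrap policy for the given sharding ``mode``.

    - 'no_shard': return None (no auto wrapping).
    - 'model': wrap each whole encoder/decoder module from
      WENET_ENCODER_CLASSES / WENET_DECODER_CLASSES.
    - 'zero2' / 'zero3': wrap individual encoder/decoder layers from
      WENET_ENCODER_LAYERS_CLASSES / WENET_DECODER_LAYERS_CLASSES.

    Example sketch (the FSDP constructor call is illustrative only; the
    training script also configures sharding_strategy and other arguments):

        policy = wenet_fsdp_wrap_policy(mode='zero3')
        model = FSDP(model, auto_wrap_policy=policy)
    """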
    # For a comparison of the different wrap strategies, please refer to:
    # https://openmmlab.medium.com/its-2023-is-pytorch-s-fsdp-the-best-choice-for-training-large-models-fe8d2848832f # noqa
    assert mode in ['no_shard', 'model', 'zero2', 'zero3']
    if mode == 'no_shard':
        return None
    else:
        # TODO(Mddct): support user-customized wrap policies.
        # For more wrap methods, see:
        # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py#L13 # noqa
        if mode == 'model':
            enc_dec_wrap_policy = partial(
                lambda_auto_wrap_policy,
                lambda_fn=lambda module: isinstance(
                    module,
                    tuple(WENET_ENCODER_CLASSES.values()) + tuple(
                        WENET_DECODER_CLASSES.values())))
            return enc_dec_wrap_policy
        else:
            to_wrap_class = set()
            to_wrap_class.update(set(WENET_ENCODER_LAYERS_CLASSES.values()))
            to_wrap_class.update(set(WENET_DECODER_LAYERS_CLASSES.values()))
            layers_wrap_policy = partial(transformer_auto_wrap_policy,
                                         transformer_layer_cls=to_wrap_class)
            return layers_wrap_policy


fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True,
                                            rank0_only=True)


def fsdp_save_model(model, save_model_path, info_dict):
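    """Gather the full (unsharded) state dict and save it from rank 0.

    Every rank must call this function, since gathering a FULL_STATE_DICT
    is a collective operation; only rank 0 writes the checkpoint, with
    tensors offloaded to CPU (see fullstate_save_policy above).

    Example sketch (path and info_dict contents are illustrative):

        fsdp_save_model(model, 'exp/epoch_1.pt', {'epoch': 1})
    """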
    # TODO(Mddct): Saving a large model this way can take a long time, because
    # the full state dict is gathered on rank 0. Saving the shards
    # asynchronously would avoid that, but the current approach is good enough
    # for now; async sharded saving will be added together with LLM support.

    rank = int(os.environ.get('RANK', 0))
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                              fullstate_save_policy):
        state_dict = model.state_dict()
        if rank == 0:
            save_state_dict_and_infos(state_dict, save_model_path, info_dict)


def check_gradient_checkpoint(model):
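    """Collect the layer types that had gradient checkpointing enabled.

    The encoder/decoder ``gradient_checkpointing`` flags are turned off here;
    activation checkpointing is instead re-applied to the returned layer
    types by apply_fsdp_checkpointing after the model has been wrapped with
    FSDP (see the NOTE there for why).
    """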
    ckpt_layer_types = []
    if hasattr(model, 'encoder') and hasattr(model.encoder,
                                             'gradient_checkpointing'):
        if model.encoder.gradient_checkpointing:
            model.encoder.gradient_checkpointing = False
            ckpt_layer_types += list(WENET_ENCODER_LAYERS_CLASSES.values())
    if hasattr(model, 'decoder') and hasattr(model.decoder,
                                             'gradient_checkpointing'):
        if model.decoder.gradient_checkpointing:
            model.decoder.gradient_checkpointing = False
            ckpt_layer_types += list(WENET_DECODER_LAYERS_CLASSES.values())
            if isinstance(model.decoder, DecoderOnly):
                ckpt_layer_types += [DecoderOnly]
    return tuple(ckpt_layer_types)


def apply_fsdp_checkpointing(model, ckpt_layer_types: tuple):
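    """Apply non-reentrant activation checkpointing to ``ckpt_layer_types``.

    Intended to be called on an FSDP-wrapped model with the layer types
    returned by check_gradient_checkpoint.

    Example sketch (the call order is an assumption about how these helpers
    fit together, not copied from the training script):

        layer_types = check_gradient_checkpoint(model)
        model = FSDP(model, auto_wrap_policy=wenet_fsdp_wrap_policy('zero3'))
        apply_fsdp_checkpointing(model, layer_types)
    """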
    # NOTE(Mddct): torch.utils.checkpoint is currently incompatible with
    # wenet's model mode, so activation checkpointing is applied via
    # checkpoint_wrapper / apply_activation_checkpointing instead. Please refer to:
    # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/policies/activation_checkpointing_functions.py#L21 # noqa
    if len(ckpt_layer_types) == 0:
        return
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper,
        CheckpointImpl,
        apply_activation_checkpointing,
    )
    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=lambda submodule: isinstance(submodule, ckpt_layer_types))