# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# helper functions from https://github.com/CarperAI/trlx/blob/main/trlx/utils/modeling.py
# which is MIT licensed

import functools
from typing import Any, Iterable, Optional

import torch
from transformers import PreTrainedModel
from transformers.models.opt.modeling_opt import OPTDecoder


# helper functions
def rhasattr(obj: Any, attr: str) -> bool:
    """A chain-able attribute version of hasattr.

    For example, to check if `obj` has the attribute `foo.bar.baz`, use:
        `rhasattr(obj, "foo.bar.baz")`
    Reference: https://stackoverflow.com/a/67303315
    """
    _nested_attrs = attr.split('.')
    _curr_obj = obj
    for _a in _nested_attrs[:-1]:
        if hasattr(_curr_obj, _a):
            _curr_obj = getattr(_curr_obj, _a)
        else:
            return False
    return hasattr(_curr_obj, _nested_attrs[-1])


def rgetattr(obj: Any, attr: str, *args: Any) -> Any:
    """A chain-able attribute version of getattr.

    For example, to get the attribute `foo.bar.baz` from `obj`, you can use:
        `rgetattr(obj, "foo.bar.baz")`
    Reference: https://stackoverflow.com/a/31174427
    """

    def _getattr(obj: Any, attr: str):
        return getattr(obj, attr, *args)

    return functools.reduce(_getattr, [obj] + attr.split('.'))


def findattr(obj: Any, attrs: Iterable[str]) -> Optional[Any]:
    """Returns the first attribute in `attrs` found on `obj`, or None."""
    for attr in attrs:
        if rhasattr(obj, attr):
            return rgetattr(obj, attr)
    return None

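# Illustrative sketch (not part of the library's API): shows how the
# dotted-attribute helpers above behave. `_ToyInner`, `_Toy`, and this function
# are hypothetical names introduced only for the example.
def _example_attr_helpers() -> None:
    class _ToyInner:
        layers = ['a', 'b']

    class _Toy:
        model = _ToyInner()

    obj = _Toy()
    assert rhasattr(obj, 'model.layers')
    assert not rhasattr(obj, 'model.missing')
    assert rgetattr(obj, 'model.layers') == ['a', 'b']
    # findattr returns the first attribute path that resolves, or None
    assert findattr(obj, ('missing', 'model.layers')) == ['a', 'b']
    assert findattr(obj, ('missing',)) is None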

def hf_get_causal_base_model(model: PreTrainedModel) -> Any:
    """Returns the causal decoder backbone of the specified HuggingFace model.

    Newer HF models have a `self.get_decoder()` method. Older models do not.

    NOTE: Different model configurations have different causal decoder attribute
    names.
        - transformer: (GPT2LMHeadModel, GPTJForCausalLM, BloomForCausalLM)
        - model.decoder: (OPTForCausalLM)
        - gpt_neox: (GPTNeoXForCausalLM)
    """
    if hasattr(model, 'get_decoder'):
        return model.get_decoder()

    decoder_attrs = ('transformer', 'model.decoder', 'gpt_neox')
    causal_base_model = findattr(model, decoder_attrs)
    if causal_base_model is None:
        raise ValueError(
            f'Unable to FSDP-wrap model {model}. Please open a GitHub issue to add support.'
        )
    return causal_base_model

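# Illustrative sketch (not part of the library's API): for a GPT-2 causal LM
# the backbone resolves to `model.transformer`. Loading 'gpt2' downloads
# pretrained weights, so this is only a usage illustration.
def _example_get_causal_base_model() -> None:
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    base = hf_get_causal_base_model(model)
    # For GPT-2 this is the `transformer` attribute (a GPT2Model)
    print(type(base).__name__)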

def hf_get_hidden_layers(model: PreTrainedModel) -> Any:
    """Returns the hidden layers of the specified model.

    NOTE: Different model configurations have different hidden layer attribute names.
        - transformer.h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM)
        - model.decoder.layers: (OPTForCausalLM)
        - gpt_neox.layers: (GPTNeoXForCausalLM)
        - block: (T5 encoder/decoder stacks)
        - layers: (BART, Pegasus, ProphetNet, Marian encoders/decoders)
        - model.layers: (LlamaForCausalLM)
        - transformer.blocks: (MPTForCausalLM)
    """
    hidden_layers_attrs = (
        'transformer.h',  # BLOOM, GPT2, GPTJ
        'model.decoder.layers',  # OPT
        'gpt_neox.layers',  # GPTNeoX
        'block',  # T5 encoder/decoder stacks
        'layers',  # BART, Pegasus, ProphetNet, Marian encoders/decoders
        'model.layers',  # Llama
        'transformer.blocks',  # MPT
    )
    layers = findattr(model, hidden_layers_attrs)
    if layers is None:
        raise ValueError(
            f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}'
        )
    return layers

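# Illustrative sketch (not part of the library's API): the value returned here
# is the per-block ModuleList, which is what the FSDP wrap policies below key
# on. Loading 'gpt2' downloads pretrained weights.
def _example_get_hidden_layers() -> None:
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    layers = hf_get_hidden_layers(model)
    # For GPT-2 this is `model.transformer.h`, a ModuleList of GPT2Block
    print(len(layers), type(layers[0]).__name__)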

def hf_get_init_device(init_device: Optional[str]) -> Optional[str]:
    """Returns the appropriate device to initialize models."""
    from composer.utils import dist
    if init_device == 'mixed':
        if dist.get_local_rank() == 0:
            return 'cpu'
        return 'meta'
    return init_device

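# Illustrative sketch (not part of the library's API): non-'mixed' values pass
# through unchanged, while 'mixed' resolves to 'cpu' on local rank 0 and 'meta'
# on every other rank, so only one copy of real weights is materialized.
def _example_init_device() -> None:
    assert hf_get_init_device('cpu') == 'cpu'
    assert hf_get_init_device(None) is None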

# /end helper functions


def prepare_hf_model_for_fsdp(model: PreTrainedModel,
                              init_device: Optional[str]) -> None:
    """FSDP wrap a HuggingFace model.

    Call specific functions
    """
    if model.config.is_encoder_decoder:
        prepare_hf_enc_dec_model_for_fsdp(model, init_device)
    else:
        # many common decoder-only models do not set the flag
        # model.config.is_decoder, so we can't trust it
        prepare_hf_causal_lm_model_for_fsdp(model, init_device)

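# Illustrative sketch (not part of the library's API): typical call pattern.
# After this call the model carries `fsdp_wrap_fn`, `activation_checkpointing_fn`,
# and `_fsdp_wrap` hints, which Composer's FSDP integration consults when
# building its wrap policy. Loading 'gpt2' downloads pretrained weights.
def _example_prepare_for_fsdp() -> None:
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    prepare_hf_model_for_fsdp(model, init_device='cpu')
    block = hf_get_hidden_layers(model)[0]
    # every transformer block is FSDP-wrapped and activation-checkpointed
    assert model.fsdp_wrap_fn(block)
    assert model.activation_checkpointing_fn(block)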

def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
                                        init_device: Optional[str]) -> None:
    """FSDP wrap a HuggingFace decoder.

    Wrap any model for FSDP which follows one of the 3 existing conventions from
    HuggingFace for decoder-only LLMs.
    """
    causal_base_model = hf_get_causal_base_model(model)

    # OPT has an extra layer of wrapping, so special case here
    if isinstance(causal_base_model, OPTDecoder):
        model.model._fsdp_wrap = False
    model_block = hf_get_hidden_layers(model)
    lm_head = model.get_output_embeddings()
    # some models (e.g. OPT) implement .get_input_embeddings on the causal LM
    # subclass, but all of them implement it on the base model, so we fetch the
    # (possibly tied) input embeddings from the base model
    tied_embeddings = causal_base_model.get_input_embeddings()
    modules = {
        'base_model': causal_base_model,
        'model_block': model_block,
        'lm_head': lm_head,
        'tied_embeddings': tied_embeddings
    }

    for mod_name, module in modules.items():
        if module is None:
            raise ValueError(
                f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
                'follow common layer/weight naming conventions.')
    block_type = type(model_block[0])
    if init_device == 'mixed':
        # For FSDP with `mixed` initialization, which initializes the model on
        # rank 0 on `cpu` and on all other ranks on `meta`, we need to tag all
        # child modules that are torch.nn.Modules with `_fsdp_wrap`.
        for child in model.children():
            if isinstance(child, type(causal_base_model)):
                continue
            if isinstance(child, torch.nn.Module):
                child._fsdp_wrap = True

        # within the backbone, tag direct children except the ModuleList of
        # blocks; the individual blocks are wrapped via fsdp_wrap_fn below
        for child in causal_base_model.children():
            if isinstance(child, torch.nn.ModuleList):
                continue
            if isinstance(child, torch.nn.Module):
                child._fsdp_wrap = True

        if model.config.tie_word_embeddings and model.config.model_type != 'mpt':
            raise ValueError(
                'The passed in HuggingFaceModel has tied word embeddings '
                'and the passed in initialization device is `mixed`. '
                'In order to support this initialization scheme, we would need to break '
                'the weight tying. As a result, either use a different initialization scheme '
                'or set `tie_word_embeddings=False` in the model config.')
    else:
        # When using HF LM models, the weights of self.lm_head and
        # self.transformer.wte are tied. This tying occurs inside the
        # `self.post_init()` function. This is a hurdle for FSDP because the
        # tied weights need to live in the same FSDP block. These lines ensure
        # that both modules stay together in the top-most block when the model
        # has tying enabled (almost all models do; this property defaults to True).
        if model.config.tie_word_embeddings:
            causal_base_model._fsdp_wrap = False
            tied_embeddings._fsdp_wrap = False
            lm_head._fsdp_wrap = False

    # FSDP Wrap and Activation Checkpoint every model block
    model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
    model.activation_checkpointing_fn = lambda module: isinstance(
        module, block_type)

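# Illustrative sketch (not part of the library's API): with a non-'mixed'
# init_device and tied word embeddings, the backbone, input embeddings, and
# lm_head are all tagged `_fsdp_wrap = False` so they stay in the same (root)
# FSDP module and the weight tying survives sharding. Loading 'gpt2' downloads
# pretrained weights.
def _example_tied_embedding_grouping() -> None:
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    prepare_hf_causal_lm_model_for_fsdp(model, init_device='cpu')
    assert model.transformer._fsdp_wrap is False
    assert model.transformer.wte._fsdp_wrap is False
    assert model.lm_head._fsdp_wrap is False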

def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel,
                                      init_device: Optional[str]) -> None:
    """Wrap an encoder/decoder HF model.

    This works for T5, BART, Pegasus, PegasusX, but not all enc/dec (ProphetNet)
    You have model.shared, model.encoder, model.decoder and model.lm_head, where
    model.shared are the embeddings which are tied to model.lm_head, and
    model.shared == model.encoder.embed_tokens and model.shared ==
    model.decoder.embed_tokens
    """
    tied_embeddings = model.get_input_embeddings()
    encoder = model.get_encoder()
    decoder = model.get_decoder()
    lm_head = model.get_output_embeddings()
    # some encoder/decoders have different layers for encoder vs decoder
    encoder_block = hf_get_hidden_layers(encoder)
    decoder_block = hf_get_hidden_layers(decoder)

    modules = {
        'encoder': encoder,
        'decoder': decoder,
        'encoder_block': encoder_block,
        'decoder_block': decoder_block,
        'lm_head': lm_head,
        'tied_embeddings': tied_embeddings
    }

    for mod_name, module in modules.items():
        if module is None:
            raise ValueError(
                f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
                'follow common layer/weight naming conventions.')
    decoder_block_type = type(decoder_block[0])
    encoder_block_type = type(encoder_block[0])

    if model.config.tie_word_embeddings:
        # it is possible to train an enc/dec without tied embeddings, hence the check
        tied_embeddings._fsdp_wrap = False
        encoder._fsdp_wrap = False
        decoder._fsdp_wrap = False
        lm_head._fsdp_wrap = False

    # FSDP Wrap and Activation Checkpoint every decoder block
    model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type)
    model.activation_checkpointing_fn = lambda module: isinstance(
        module, decoder_block_type)

    if encoder_block_type == decoder_block_type:
        return

    # ProphetNet and Marian use a different block class for the encoder, so the
    # policies must cover encoder blocks in addition to decoder blocks
    model.fsdp_wrap_fn = lambda module: isinstance(
        module, (encoder_block_type, decoder_block_type))
    model.activation_checkpointing_fn = lambda module: isinstance(
        module, (encoder_block_type, decoder_block_type))
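

# Illustrative sketch (not part of the library's API): encoder/decoder path.
# For T5 both stacks are built from T5Block, so the single decoder-block wrap
# policy already covers encoder blocks too. Loading 't5-small' downloads
# pretrained weights.
def _example_prepare_enc_dec_for_fsdp() -> None:
    from transformers import AutoModelForSeq2SeqLM
    model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
    prepare_hf_model_for_fsdp(model, init_device='cpu')
    enc_block = hf_get_hidden_layers(model.get_encoder())[0]
    dec_block = hf_get_hidden_layers(model.get_decoder())[0]
    assert model.fsdp_wrap_fn(enc_block)
    assert model.fsdp_wrap_fn(dec_block)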