# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""

import logging
import os
from typing import Mapping, Union

# required for loading a python model into composer
import transformers
from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
                                  InContextLearningLMAccuracy,
                                  InContextLearningLMExpectedCalibrationError,
                                  InContextLearningMCExpectedCalibrationError,
                                  InContextLearningMultipleChoiceAccuracy,
                                  InContextLearningQAAccuracy,
                                  LanguageCrossEntropy, LanguagePerplexity)
from composer.utils import dist
from omegaconf import DictConfig
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM,
                          PreTrainedTokenizerBase)

from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers.llama_attention_monkeypatch import \
    get_llama_attention_patch_fn
from llmfoundry.models.utils import init_empty_weights

try:
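    # peft is an optional dependency; when it is not installed, only
    # transformers.PreTrainedModel instances are accepted as pre-built models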
    from peft.peft_model import PeftModel
    model_types = PeftModel, transformers.PreTrainedModel

except ImportError:
    model_types = transformers.PreTrainedModel

__all__ = ['ComposerHFCausalLM']

log = logging.getLogger(__name__)


class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
    """Configures a :class:`.HuggingFaceModel` around a Causal LM.

    Args:
        om_model_config (DictConfig | PeftModel | transformers.PreTrainedModel): either an omegaconf dictionary used to configure the model, or an instantiated model object from the peft or transformers library.
            If a DictConfig is passed, the following keys are used:
            cfg.pretrained_model_name_or_path (str): The name of or local path to
                the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel).
            cfg.config_overrides (dict, optional): An optional dictionary of keyword
                arguments that override the default configuration associated with
                cfg.pretrained_model_name_or_path.
            cfg.pretrained (bool): Whether to instantiate the model with pre-trained
                weights coming from cfg.pretrained_model_name_or_path. If ``True``,
                cfg.config_overrides must be compatible with the pre-trained weights.
            cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
                initialize the model on. Currently, `meta` is only supported when
                cfg.pretrained is ``False``. Default: ``'cpu'``.
        tokenizer (PreTrainedTokenizerBase): The tokenizer that the model will use.
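
    Example:
        An illustrative configuration (the model name and values are examples
        only, not defaults), assuming a ``tokenizer`` has already been built::

            from omegaconf import DictConfig

            model_cfg = DictConfig({
                'pretrained_model_name_or_path': 'gpt2',
                'pretrained': True,
                'init_device': 'cpu',
            })
            model = ComposerHFCausalLM(model_cfg, tokenizer)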
    """

    def __init__(self, om_model_config: Union[DictConfig,
                                              transformers.PreTrainedModel,
                                              nn.Module],
                 tokenizer: PreTrainedTokenizerBase):
        # set up training and eval metrics
        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
        eval_metrics = [
            LanguageCrossEntropy(),
            LanguagePerplexity(),
            InContextLearningLMAccuracy(),
            InContextLearningMultipleChoiceAccuracy(),
            InContextLearningQAAccuracy(),
            InContextLearningCodeEvalAccuracy(),
            InContextLearningLMExpectedCalibrationError(),
            InContextLearningMCExpectedCalibrationError()
        ]

        # if we are passed a DictConfig, we need to instantiate the model
        if isinstance(om_model_config, DictConfig):
            if not om_model_config.get('trust_remote_code',
                                       True) and om_model_config.get(
                                           'pretrained_model_name_or_path',
                                           None).startswith('mosaicml/mpt'):
                raise ValueError(
                    'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, '
                    +
                    'which is significantly slower and not compatible with the LLM foundry training code, rather than the code released by MosaicML.'
                )

            if not om_model_config.get('use_train_metrics', True):
                train_metrics = []

            # load the model config
            trust_remote_code = om_model_config.get('trust_remote_code', True)
            use_auth_token = om_model_config.get('use_auth_token', False)
            config = AutoConfig.from_pretrained(
                om_model_config.pretrained_model_name_or_path,
                trust_remote_code=trust_remote_code,
                use_auth_token=use_auth_token,
            )

            # set config overrides
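            # e.g. a scalar override such as {'max_seq_len': 4096} replaces the
            # attribute directly, while a nested mapping such as
            # {'attn_config': {'attn_impl': 'triton'}} (keys shown are
            # illustrative) updates only the listed keys of an existing dict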
            for k, v in om_model_config.get('config_overrides', {}).items():
                if not hasattr(config, k):
                    raise ValueError(
                        f'config does not have attribute "{k}" to override ({k}: {v}).'
                    )

                attr = getattr(config, k)
                # attempt to disallow typos in nested configs
                if isinstance(attr, Mapping):
                    extra_keys = [
                        _k for _k in v.keys() if _k not in attr.keys()
                    ]
                    if extra_keys:
                        raise ValueError(
                            f'Config dict override got unknown keys. ' +
                            f'Extra keys: {extra_keys}. ' +
                            f'Expected (a subset of) keys: {list(attr.keys())}.'
                        )
                    getattr(config, k).update(v)
                # necessary case to allow for rope_scaling to be overridden in llama config
                elif attr is None and isinstance(v, Mapping):
                    setattr(config, k, {})
                    getattr(config, k).update(v)
                else:
                    setattr(config, k, v)

            load_in_8bit = om_model_config.get('load_in_8bit', False)

            # below we set up the device to initialize the model on
            init_device = om_model_config.get('init_device', 'cpu')

            # Get the device we want to initialize the model on, and use the
            # resolved version to initialize the HF model
            resolved_init_device = hf_get_init_device(init_device)

            # All non-zero local ranks should not load pretrained weights;
            # rank 0 will still load pretrained weights and distribute them appropriately
            if dist.get_local_rank() != 0 and init_device == 'mixed':
                om_model_config.pretrained = False

            # initialize the model on the correct device
            if resolved_init_device == 'cpu':
                if om_model_config.pretrained:
                    model = AutoModelForCausalLM.from_pretrained(
                        om_model_config.pretrained_model_name_or_path,
                        trust_remote_code=trust_remote_code,
                        use_auth_token=use_auth_token,
                        load_in_8bit=load_in_8bit,
                        config=config)
                else:
                    model = AutoModelForCausalLM.from_config(
                        config,
                        trust_remote_code=trust_remote_code,
                    )
            elif resolved_init_device == 'meta':
                if om_model_config.pretrained:
                    raise ValueError(
                        'Setting cfg.pretrained=True is not supported when init_device="meta".'
                    )
                with init_empty_weights(include_buffers=False):
                    model = AutoModelForCausalLM.from_config(
                        config,
                        trust_remote_code=trust_remote_code,
                    )
            else:
                raise ValueError(
                    f'init_device="{init_device}" must be either "cpu" or "meta".'
                )

            signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
            if dist.get_local_rank() == 0:
                with open(signal_file_path, 'wb') as f:
                    f.write(b'local_rank0_completed_download')

            # Avoid the collective call until the local rank zero has finished trying to download the checkpoint
            # so that we don't timeout for large downloads. This syncs all processes on the node
            with dist.local_rank_zero_download_and_wait(signal_file_path):
                # Then, wait to ensure every node has finished downloading the checkpoint
                dist.barrier()

            if dist.get_local_rank() == 0:
                os.remove(signal_file_path)

            z_loss = om_model_config.get('z_loss', 0.0)
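            # when z_loss > 0.0, HuggingFaceModelWithZLoss adds an auxiliary
            # penalty on the softmax normalizer to the loss; 0.0 disables it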

            attention_patch_type = om_model_config.get('attention_patch_type',
                                                       None)
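            # the accepted patch types are whatever get_llama_attention_patch_fn
            # supports; None leaves the stock attention implementation in place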
            if attention_patch_type is not None:
                if model.config.model_type != 'llama':
                    raise ValueError(
                        f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
                    )

                log.debug(
                    f'Patching llama attention with {attention_patch_type} attention'
                )
                from transformers.models.llama.modeling_llama import \
                    LlamaAttention
                LlamaAttention.forward = get_llama_attention_patch_fn(
                    attention_patch_type)
                model.config.use_cache = False

        # elif the model is either a PeftModel or a PreTrainedModel
        elif isinstance(om_model_config, model_types):
            model = om_model_config
            init_device = 'cpu'
            z_loss = 0.0

        # else, unsupported type
        else:
            raise ValueError(
                f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
            )

        super().__init__(model=model,
                         shift_labels=True,
                         tokenizer=tokenizer,
                         metrics=train_metrics,
                         eval_metrics=eval_metrics,
                         z_loss=z_loss,
                         init_device=init_device)