File size: 11,803 Bytes
476ac07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import os
import types

import torch
import transformers
from mmengine.config.lazy import LazyObject
from mmengine.utils import digit_version
from transformers.utils.import_utils import is_flash_attn_2_available

# Parsed version of the installed transformers package.
TRANSFORMERS_VERSION = digit_version(transformers.__version__)
# transformers < 4.38 selects the *_LEGACY_* dispatch tables below.
IS_LOW_VERSION_TRANSFORMERS = digit_version('4.38') > TRANSFORMERS_VERSION
# Torch SDPA ("flash1") needs torch >= 2.1.1 on the transformers side, see
# https://github.com/huggingface/transformers/blob/caa5c65db1f4db617cdac2ad667ba62edf94dd98/src/transformers/modeling_utils.py#L1611  # noqa: E501
SUPPORT_FLASH1 = digit_version('2.1.1') <= digit_version(torch.__version__)
# FlashAttention-2 availability, as reported by transformers.
SUPPORT_FLASH2 = is_flash_attn_2_available()
SUPPORT_FLASH = SUPPORT_FLASH1 or SUPPORT_FLASH2

# Environment-variable spellings treated as "off" (compared lower-cased).
_FALSE_ENV_VALUES = ('', '0', 'false', 'no', 'off')


def _env_flag(name, default=False):
    """Return whether environment variable ``name`` holds a truthy value.

    ``bool(os.getenv(name))`` is True for *any* non-empty string, so a user
    writing ``USE_TRITON_KERNEL=0`` would have switched the feature on.
    Here the empty string and the usual "off" spellings (``0``, ``false``,
    ``no``, ``off``; case-insensitive, surrounding whitespace ignored) are
    false and every other value is true.

    Args:
        name (str): Name of the environment variable to read.
        default (bool): Value returned when the variable is not set.

    Returns:
        bool: The parsed flag.
    """
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() not in _FALSE_ENV_VALUES


# Opt-in switch for the Triton norm kernels (see RMS_DISPATCH_MAPPING).
USE_TRITON_KERNEL = _env_flag('USE_TRITON_KERNEL')
SUPPORT_TRITON = False
try:
    import triton  # pre-check # noqa: F401
    import triton.language as tl  # pre-check # noqa: F401
    SUPPORT_TRITON = True
except ImportError as err:
    if USE_TRITON_KERNEL:
        # Chain the ImportError so the real cause stays in the traceback.
        raise RuntimeError(
            'USE_TRITON_KERNEL is set to 1, but triton has not been installed.'
            ' Run `pip install triton==2.1.0` to install triton.') from err

# Suffix appended to the attention-dispatch log messages.
NO_ATTN_WEIGHTS_MSG = (
    'Due to the implementation of the PyTorch version of flash '
    'attention, even when the `output_attentions` flag is set to '
    'True, it is not possible to return the `attn_weights`.'
)

# Minimum transformers release per model class, enforced by
# ``dispatch_modules``; values are ``digit_version`` tuples.
LOWEST_TRANSFORMERS_VERSION = {
    model_cls: digit_version(min_version)
    for model_cls, min_version in (
        ('InternLM2ForCausalLM', '4.36'),
        ('InternLMForCausalLM', '4.36'),
        ('LlamaForCausalLM', '4.36'),
        ('Phi3ForCausalLM', '4.39'),
        ('MistralForCausalLM', '4.36'),
        # Training mixtral with lower version may lead to nccl timeout
        # Refer to https://github.com/microsoft/DeepSpeed/issues/5066
        ('MixtralForCausalLM', '4.40'),
        ('CohereForCausalLM', '4.40'),
        ('Qwen2ForCausalLM', '4.39'),
        ('Qwen2MoeForCausalLM', '4.40'),
        ('DeepseekV2ForCausalLM', '4.40'),
    )
}

# Attention class name -> replacement ``forward`` used when
# ``use_varlen_attn`` is False. ``LazyObject`` defers importing the target
# until ``.build()`` is called by the dispatcher.
ATTN_DISPATCH_MAPPING = {
    'InternLM2FlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.internlm2', 'internlm2_attn_forward'),
    'InternLMAttention': LazyObject(
        'xtuner.model.modules.dispatch.internlm', 'internlm_attn_forward'),
    'LlamaFlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.llama', 'llama_attn_forward'),
    'Phi3FlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.phi3', 'phi3_attn_forward'),
    'MistralFlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.mistral', 'mistral_attn_forward'),
    # Mixtral reuses the Mistral attention forward.
    'MixtralFlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.mistral', 'mistral_attn_forward'),
    'CohereFlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.cohere', 'cohere_attn_forward'),
    'Qwen2FlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.qwen2', 'qwen2_attn_forward'),
    # Qwen2-MoE reuses the Qwen2 attention forward.
    'Qwen2MoeFlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.qwen2', 'qwen2_attn_forward'),
    'DeepseekV2FlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.deepseek_v2', 'deepseek_attn_forward'),
}

# Preferred over ATTN_DISPATCH_MAPPING when IS_LOW_VERSION_TRANSFORMERS.
ATTN_LEGACY_DISPATCH_MAPPING = {
    'LlamaFlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.llama', 'llama_attn_forward_legacy'),
}

# Attention class name -> replacement ``forward`` used when
# ``use_varlen_attn`` is True (variable-length packed sequences).
VARLEN_ATTN_DISPATCH_MAPPING = {
    'InternLM2FlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.internlm2',
        'internlm2_varlen_attn_forward'),
    'InternLMAttention': LazyObject(
        'xtuner.model.modules.dispatch.internlm',
        'internlm_varlen_attn_forward'),
    'LlamaFlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.llama', 'llama_varlen_attn_forward'),
    'Phi3FlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.phi3', 'phi3_varlen_attn_forward'),
    'MistralFlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.mistral',
        'mistral_varlen_attn_forward'),
    'MixtralFlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.mistral',
        'mistral_varlen_attn_forward'),
    # No varlen implementation is registered for Cohere.
    'CohereFlashAttention2': None,
    'Qwen2FlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.qwen2', 'qwen2_varlen_attn_forward'),
    'Qwen2MoeFlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.qwen2', 'qwen2_varlen_attn_forward'),
    'DeepseekV2FlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.deepseek_v2',
        'deepseek_varlen_attn_forward'),
}

# Preferred over VARLEN_ATTN_DISPATCH_MAPPING when
# IS_LOW_VERSION_TRANSFORMERS.
VARLEN_ATTN_LEGACY_DISPATCH_MAPPING = {
    'LlamaFlashAttention2': LazyObject(
        'xtuner.model.modules.dispatch.llama',
        'llama_varlen_attn_forward_legacy'),
}

# Norm class name -> forward from the triton_kernels dispatch module.
# Cohere uses a LayerNorm rather than an RMSNorm, hence its own target.
RMS_DISPATCH_MAPPING = {
    'InternLM2RMSNorm': LazyObject(
        'xtuner.model.modules.dispatch.triton_kernels', 'rms_norm_forward'),
    'InternLMRMSNorm': LazyObject(
        'xtuner.model.modules.dispatch.triton_kernels', 'rms_norm_forward'),
    'LlamaRMSNorm': LazyObject(
        'xtuner.model.modules.dispatch.triton_kernels', 'rms_norm_forward'),
    'Phi3RMSNorm': LazyObject(
        'xtuner.model.modules.dispatch.triton_kernels', 'rms_norm_forward'),
    'MistralRMSNorm': LazyObject(
        'xtuner.model.modules.dispatch.triton_kernels', 'rms_norm_forward'),
    'MixtralRMSNorm': LazyObject(
        'xtuner.model.modules.dispatch.triton_kernels', 'rms_norm_forward'),
    'CohereLayerNorm': LazyObject(
        'xtuner.model.modules.dispatch.triton_kernels', 'layer_norm_forward'),
    'Qwen2RMSNorm': LazyObject(
        'xtuner.model.modules.dispatch.triton_kernels', 'rms_norm_forward'),
    'Qwen2MoeRMSNorm': LazyObject(
        'xtuner.model.modules.dispatch.triton_kernels', 'rms_norm_forward'),
}

# Rotary-embedding class name -> replacement class used by replace_rote.
ROTE_DISPATCH_MAPPING = {
    'InternLM2RotaryEmbedding': LazyObject(
        'xtuner.model.modules.dispatch.internlm2', 'InternLM2RotaryEmbedding'),
    'InternLMRotaryEmbedding': LazyObject(
        'xtuner.model.modules.dispatch.internlm', 'InternLMRotaryEmbedding'),
    'MistralRotaryEmbedding': LazyObject(
        'xtuner.model.modules.dispatch.mistral', 'MistralRotaryEmbedding'),
    # Mixtral's rotary embedding is swapped for the Mistral implementation.
    'MixtralRotaryEmbedding': LazyObject(
        'xtuner.model.modules.dispatch.mistral', 'MistralRotaryEmbedding'),
}


def log_once(func):
    """Wrap ``func`` so that only its first call is forwarded.

    The dispatch helpers in this module wrap ``print_log`` with it so a
    message is not repeated for every matching sub-module.

    Args:
        func (Callable): Function to wrap, typically a logging function.

    Returns:
        Callable: A wrapper (carrying ``func``'s metadata via
        ``functools.wraps``) that calls ``func`` on its first invocation
        and ignores every later one. The wrapper always returns ``None``.
    """
    logged = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal logged
        if logged:
            return
        # Mark before calling so a raising ``func`` is still not retried.
        logged = True
        func(*args, **kwargs)

    return wrapper


def dispatch_attn_forward(model):
    """Swap in xtuner's attention ``forward`` for supported sub-modules.

    Does nothing unless FlashAttention-2 is available (``SUPPORT_FLASH2``).
    Every sub-module of ``model`` whose class name is a key of
    ``ATTN_DISPATCH_MAPPING`` has its ``forward`` replaced by the mapped
    function, bound to the instance with ``types.MethodType``. On
    transformers < 4.38 (``IS_LOW_VERSION_TRANSFORMERS``) an entry in
    ``ATTN_LEGACY_DISPATCH_MAPPING`` takes precedence. One log line is
    emitted per distinct patched class.

    Args:
        model (torch.nn.Module): Model patched in place.

    Raises:
        NotImplementedError: If a matched class is mapped to ``None``.
    """
    if not SUPPORT_FLASH2:
        return

    from mmengine import print_log

    # Built forwards keyed by class name. A single shared slot would bind
    # the first class's forward to every other matching class.
    built = {}
    for module in model.modules():
        name = type(module).__name__
        if (IS_LOW_VERSION_TRANSFORMERS
                and name in ATTN_LEGACY_DISPATCH_MAPPING):
            mapping, label = ATTN_LEGACY_DISPATCH_MAPPING, 'legacy forward'
        elif name in ATTN_DISPATCH_MAPPING:
            mapping, label = ATTN_DISPATCH_MAPPING, 'forward'
        else:
            continue

        if name not in built:
            lazy_forward = mapping[name]
            if lazy_forward is None:
                raise NotImplementedError(
                    f'Attention dispatch is not implemented for {name}.')
            built[name] = lazy_forward.build()
            print_log(f'Dispatch {name} {label}. {NO_ATTN_WEIGHTS_MSG}',
                      'current')
        module.forward = types.MethodType(built[name], module)


def dispatch_varlen_attn_forward(model):
    """Swap in xtuner's variable-length attention ``forward`` where supported.

    Does nothing unless FlashAttention-2 is available (``SUPPORT_FLASH2``).
    Every sub-module of ``model`` whose class name is a key of
    ``VARLEN_ATTN_DISPATCH_MAPPING`` has its ``forward`` replaced by the
    mapped function, bound to the instance with ``types.MethodType``. On
    transformers < 4.38 (``IS_LOW_VERSION_TRANSFORMERS``) an entry in
    ``VARLEN_ATTN_LEGACY_DISPATCH_MAPPING`` takes precedence. One log line
    is emitted per distinct patched class.

    Args:
        model (torch.nn.Module): Model patched in place.

    Raises:
        NotImplementedError: If a matched class has no varlen
            implementation (its mapping entry is ``None``).
    """
    if not SUPPORT_FLASH2:
        return

    from mmengine import print_log

    # Built forwards keyed by class name. A single shared slot would bind
    # the first class's forward to every other matching class.
    built = {}
    for module in model.modules():
        name = type(module).__name__
        if (IS_LOW_VERSION_TRANSFORMERS
                and name in VARLEN_ATTN_LEGACY_DISPATCH_MAPPING):
            mapping, prefix = VARLEN_ATTN_LEGACY_DISPATCH_MAPPING, 'legacy '
        elif name in VARLEN_ATTN_DISPATCH_MAPPING:
            mapping, prefix = VARLEN_ATTN_DISPATCH_MAPPING, ''
        else:
            continue

        if name not in built:
            lazy_forward = mapping[name]
            if lazy_forward is None:
                # Fail loudly instead of calling ``.build()`` on ``None``:
                # silently keeping the stock forward would let packed
                # samples attend to each other.
                raise NotImplementedError(
                    f'{name} does not support `use_varlen_attn=True`.')
            built[name] = lazy_forward.build()
            print_log(
                f'Dispatch {prefix}{name} varlen forward. '
                f'{NO_ATTN_WEIGHTS_MSG}', 'current')
        module.forward = types.MethodType(built[name], module)


def dispatch_rmsnorm_forward(model):
    """Swap supported norm layers' ``forward`` for the Triton kernels.

    Runs only when triton imported successfully (``SUPPORT_TRITON``) and the
    user opted in via ``USE_TRITON_KERNEL``. Every sub-module of ``model``
    whose class name is a key of ``RMS_DISPATCH_MAPPING`` gets the mapped
    function bound as its ``forward`` with ``types.MethodType``. One log
    line is emitted per distinct patched class.

    Args:
        model (torch.nn.Module): Model patched in place.
    """
    if not (SUPPORT_TRITON and USE_TRITON_KERNEL):
        return

    from mmengine import print_log

    # Built forwards keyed by class name. The targets differ per class
    # (``layer_norm_forward`` vs ``rms_norm_forward``), so reusing the
    # first one built would patch other classes with the wrong kernel.
    built = {}
    for module in model.modules():
        name = type(module).__name__
        if name not in RMS_DISPATCH_MAPPING:
            continue
        if name not in built:
            built[name] = RMS_DISPATCH_MAPPING[name].build()
            print_log(f'Dispatch {name} forward.', 'current')
        module.forward = types.MethodType(built[name], module)


def replace_rote(model):
    """Replace supported rotary-embedding modules with xtuner's versions.

    The module tree is walked recursively. Each child whose class name is a
    key of ``ROTE_DISPATCH_MAPPING`` is replaced on its parent by
    ``new_cls(dim, child.max_seq_len_cached, model.config.rope_theta)``,
    moved to the device and dtype of ``child.inv_freq``. The children of a
    replaced module are not visited.

    ``model.config.rope_theta`` is only required once a module actually
    needs replacing, so models without any mapped rotary module are left
    untouched even if their config has no ``rope_theta``.

    Args:
        model (torch.nn.Module): Model patched in place.

    Raises:
        AssertionError: If a module must be replaced but the model config
            has no ``rope_theta`` attribute.
    """
    from mmengine import print_log
    print_log = log_once(print_log)

    # Built replacement classes keyed by the original class name.
    built = {}

    def traverse(module):
        for name, child in module.named_children():
            cls_name = type(child).__name__
            if cls_name not in ROTE_DISPATCH_MAPPING:
                traverse(child)
                continue

            assert hasattr(model.config, 'rope_theta'), \
                '`rope_theta` should be in the model config.'
            if cls_name not in built:
                built[cls_name] = ROTE_DISPATCH_MAPPING[cls_name].build()
            print_log(f'replace {cls_name}', 'current')
            inv_freq = child.inv_freq
            # Presumably inv_freq holds one entry per pair of rotary
            # channels, so doubling recovers the rotary dim — TODO confirm.
            dim_model = inv_freq.shape[0] * 2
            child_new = built[cls_name](dim_model, child.max_seq_len_cached,
                                        model.config.rope_theta)
            setattr(module, name,
                    child_new.to(device=inv_freq.device, dtype=inv_freq.dtype))

    traverse(model)


def dispatch_modules(model, use_varlen_attn=False):
    """Apply all of xtuner's module patches to ``model`` in place.

    Steps, in order:

    1. Check the installed transformers against the minimum registered in
       ``LOWEST_TRANSFORMERS_VERSION`` for the model class (classes without
       an entry are not checked).
    2. Dispatch attention forwards: the variable-length variants when
       ``use_varlen_attn`` is True, the regular ones otherwise.
    3. Dispatch Triton norm forwards (only when enabled).
    4. Replace rotary-embedding modules.

    Args:
        model (torch.nn.Module): Model to patch.
        use_varlen_attn (bool): Whether to dispatch variable-length
            attention. Defaults to False.

    Raises:
        AssertionError: If the model class requires a newer transformers.
    """

    def check(model_name):
        """Assert transformers is new enough for ``model_name``, if known."""
        if 'ForCausalLM' not in model_name and model_name.endswith('Model'):
            # Workaround for reward models: ``XxxModel`` is held to the
            # requirement of the matching ``XxxForCausalLM``.
            model_name = model_name[:-len('Model')] + 'ForCausalLM'
        lowest = LOWEST_TRANSFORMERS_VERSION.get(model_name)
        if lowest is None:
            # No minimum registered for this architecture; indexing the
            # table directly would raise KeyError for every other model.
            return
        msg = '{} requires transformers version at least {}, but got {}'
        assert TRANSFORMERS_VERSION >= lowest, msg.format(
            model_name, lowest, TRANSFORMERS_VERSION)

    check(type(model).__name__)
    if use_varlen_attn:
        dispatch_varlen_attn_forward(model)
    else:
        dispatch_attn_forward(model)
    dispatch_rmsnorm_forward(model)
    replace_rote(model)


# Only the top-level entry point is exported; the dispatch helpers and
# mapping tables are implementation details.
__all__ = ['dispatch_modules']