# Copyright (c) OpenMMLab. All rights reserved.
import os
import types

import torch
import transformers
from mmengine.config.lazy import LazyObject
from mmengine.utils import digit_version
from transformers.utils.import_utils import is_flash_attn_2_available

TRANSFORMERS_VERSION = digit_version(transformers.__version__)
IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.38')

# Transformers requires torch version >= 2.1.1 when using Torch SDPA.
# Refer to https://github.com/huggingface/transformers/blob/caa5c65db1f4db617cdac2ad667ba62edf94dd98/src/transformers/modeling_utils.py#L1611 # noqa: E501
SUPPORT_FLASH1 = digit_version(torch.__version__) >= digit_version('2.1.1')
SUPPORT_FLASH2 = is_flash_attn_2_available()
SUPPORT_FLASH = SUPPORT_FLASH1 or SUPPORT_FLASH2

# Interpret the env var as a flag: only the literal value '1' enables the
# Triton kernels. A bare `bool(os.getenv(...))` would treat any non-empty
# string, including '0', as true.
USE_TRITON_KERNEL = os.getenv('USE_TRITON_KERNEL', '0') == '1'

SUPPORT_TRITON = False
try:
    import triton  # pre-check # noqa: F401
    import triton.language as tl  # pre-check # noqa: F401
    SUPPORT_TRITON = True
except ImportError:
    if USE_TRITON_KERNEL:
        raise RuntimeError(
            'USE_TRITON_KERNEL is set to 1, but triton has not been installed.'
            ' Run `pip install triton==2.1.0` to install triton.')
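
# Example of opting in (illustrative; the config path is a placeholder):
#
#   USE_TRITON_KERNEL=1 xtuner train my_config.py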

NO_ATTN_WEIGHTS_MSG = (
    'Due to the implementation of the PyTorch version of flash attention, '
    'even when the `output_attentions` flag is set to True, it is not '
    'possible to return the `attn_weights`.')

LOWEST_TRANSFORMERS_VERSION = dict(
    InternLM2ForCausalLM=digit_version('4.36'),
    InternLMForCausalLM=digit_version('4.36'),
    LlamaForCausalLM=digit_version('4.36'),
    Phi3ForCausalLM=digit_version('4.39'),
    MistralForCausalLM=digit_version('4.36'),
    # Training Mixtral with a lower version may lead to an NCCL timeout.
    # Refer to https://github.com/microsoft/DeepSpeed/issues/5066
    MixtralForCausalLM=digit_version('4.40'),
    CohereForCausalLM=digit_version('4.40'),
    Qwen2ForCausalLM=digit_version('4.39'),
    Qwen2MoeForCausalLM=digit_version('4.40'),
    DeepseekV2ForCausalLM=digit_version('4.40'),
)

ATTN_DISPATCH_MAPPING = dict(
    InternLM2FlashAttention2=LazyObject(
        'xtuner.model.modules.dispatch.internlm2', 'internlm2_attn_forward'),
    InternLMAttention=LazyObject('xtuner.model.modules.dispatch.internlm',
                                 'internlm_attn_forward'),
    LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
                                    'llama_attn_forward'),
    Phi3FlashAttention2=LazyObject('xtuner.model.modules.dispatch.phi3',
                                   'phi3_attn_forward'),
    MistralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
                                      'mistral_attn_forward'),
    MixtralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
                                      'mistral_attn_forward'),
    CohereFlashAttention2=LazyObject('xtuner.model.modules.dispatch.cohere',
                                     'cohere_attn_forward'),
    Qwen2FlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
                                    'qwen2_attn_forward'),
    Qwen2MoeFlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
                                       'qwen2_attn_forward'),
    DeepseekV2FlashAttention2=LazyObject(
        'xtuner.model.modules.dispatch.deepseek_v2', 'deepseek_attn_forward'),
)

ATTN_LEGACY_DISPATCH_MAPPING = dict(
    LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
                                    'llama_attn_forward_legacy'), )

VARLEN_ATTN_DISPATCH_MAPPING = dict(
    InternLM2FlashAttention2=LazyObject(
        'xtuner.model.modules.dispatch.internlm2',
        'internlm2_varlen_attn_forward'),
    InternLMAttention=LazyObject('xtuner.model.modules.dispatch.internlm',
                                 'internlm_varlen_attn_forward'),
    LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
                                    'llama_varlen_attn_forward'),
    Phi3FlashAttention2=LazyObject('xtuner.model.modules.dispatch.phi3',
                                   'phi3_varlen_attn_forward'),
    MistralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
                                      'mistral_varlen_attn_forward'),
    MixtralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
                                      'mistral_varlen_attn_forward'),
    # No varlen implementation exists for Cohere; the None entry is skipped
    # in `dispatch_varlen_attn_forward` below.
    CohereFlashAttention2=None,
    Qwen2FlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
                                    'qwen2_varlen_attn_forward'),
    Qwen2MoeFlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
                                       'qwen2_varlen_attn_forward'),
    DeepseekV2FlashAttention2=LazyObject(
        'xtuner.model.modules.dispatch.deepseek_v2',
        'deepseek_varlen_attn_forward'),
)

VARLEN_ATTN_LEGACY_DISPATCH_MAPPING = dict(
    LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
                                    'llama_varlen_attn_forward_legacy'), )

RMS_DISPATCH_MAPPING = dict(
    InternLM2RMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
                                'rms_norm_forward'),
    InternLMRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
                               'rms_norm_forward'),
    LlamaRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
                            'rms_norm_forward'),
    Phi3RMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
                           'rms_norm_forward'),
    MistralRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
                              'rms_norm_forward'),
    MixtralRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
                              'rms_norm_forward'),
    CohereLayerNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
                               'layer_norm_forward'),
    Qwen2RMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
                            'rms_norm_forward'),
    Qwen2MoeRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
                               'rms_norm_forward'),
)

ROTE_DISPATCH_MAPPING = dict(
    InternLM2RotaryEmbedding=LazyObject(
        'xtuner.model.modules.dispatch.internlm2', 'InternLM2RotaryEmbedding'),
    InternLMRotaryEmbedding=LazyObject(
        'xtuner.model.modules.dispatch.internlm', 'InternLMRotaryEmbedding'),
    MistralRotaryEmbedding=LazyObject('xtuner.model.modules.dispatch.mistral',
                                      'MistralRotaryEmbedding'),
    MixtralRotaryEmbedding=LazyObject('xtuner.model.modules.dispatch.mistral',
                                      'MistralRotaryEmbedding'),
)
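
# The LazyObject entries above defer the model-specific imports until
# `.build()` is called, so importing this module stays cheap even if only one
# architecture is ever dispatched. A minimal sketch of the idea (hypothetical
# `LazyRef`, not mmengine's actual implementation):
#
#   import importlib
#
#   class LazyRef:
#
#       def __init__(self, module, name):
#           self.module, self.name = module, name
#
#       def build(self):
#           return getattr(importlib.import_module(self.module), self.name)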


def log_once(func):
    logged = False

    def wrapper(*args, **kwargs):
        nonlocal logged
        if not logged:
            logged = True
            func(*args, **kwargs)

    return wrapper
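

# A minimal sketch of the `log_once` behaviour: the wrapped callable fires on
# the first call only; every later call, whatever its arguments, is dropped.
#
#   log = log_once(print)
#   log('dispatched')        # printed
#   log('dispatched again')  # silently ignored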


def dispatch_attn_forward(model):

    if not SUPPORT_FLASH2:
        return

    from mmengine import print_log
    print_log = log_once(print_log)

    attn_forward = None
    for module in model.modules():
        name = type(module).__name__
        if (IS_LOW_VERSION_TRANSFORMERS
                and name in ATTN_LEGACY_DISPATCH_MAPPING):
            if attn_forward is None:
                attn_forward = ATTN_LEGACY_DISPATCH_MAPPING[name]
                attn_forward = attn_forward.build()
            print_log(f'Dispatch {name} legacy forward. {NO_ATTN_WEIGHTS_MSG}',
                      'current')
            module.forward = types.MethodType(attn_forward, module)
        elif name in ATTN_DISPATCH_MAPPING:
            if attn_forward is None:
                attn_forward = ATTN_DISPATCH_MAPPING[name]
                attn_forward = attn_forward.build()
            print_log(f'Dispatch {name} forward. {NO_ATTN_WEIGHTS_MSG}',
                      'current')
            module.forward = types.MethodType(attn_forward, module)
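

# The dispatchers here monkey-patch instances, not classes:
# `types.MethodType(fn, module)` binds `fn` to `module` so that
# `module.forward(...)` receives `module` as `self`, while other instances of
# the same class keep their original forward. A minimal sketch:
#
#   class Attn:
#
#       def forward(self, x):
#           return x
#
#   def fast_forward(self, x):
#       return 2 * x
#
#   attn = Attn()
#   attn.forward = types.MethodType(fast_forward, attn)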


def dispatch_varlen_attn_forward(model):

    if not SUPPORT_FLASH2:
        return

    from mmengine import print_log
    print_log = log_once(print_log)

    varlen_attn_forward = None
    for module in model.modules():
        name = type(module).__name__
        if (IS_LOW_VERSION_TRANSFORMERS
                and name in VARLEN_ATTN_LEGACY_DISPATCH_MAPPING):
            if varlen_attn_forward is None:
                varlen_attn_forward = VARLEN_ATTN_LEGACY_DISPATCH_MAPPING[name]
                varlen_attn_forward = varlen_attn_forward.build()
            print_log(
                f'Dispatch legacy {name} varlen forward. '
                f'{NO_ATTN_WEIGHTS_MSG}', 'current')
            module.forward = types.MethodType(varlen_attn_forward, module)
        elif name in VARLEN_ATTN_DISPATCH_MAPPING:
            # Skip None entries (e.g. CohereFlashAttention2) rather than
            # calling `.build()` on None; those modules keep their original
            # forward.
            if VARLEN_ATTN_DISPATCH_MAPPING[name] is None:
                continue
            if varlen_attn_forward is None:
                varlen_attn_forward = VARLEN_ATTN_DISPATCH_MAPPING[name]
                varlen_attn_forward = varlen_attn_forward.build()
            print_log(f'Dispatch {name} varlen forward. {NO_ATTN_WEIGHTS_MSG}',
                      'current')
            module.forward = types.MethodType(varlen_attn_forward, module)


def dispatch_rmsnorm_forward(model):

    if (not SUPPORT_TRITON) or (not USE_TRITON_KERNEL):
        return

    from mmengine import print_log
    print_log = log_once(print_log)

    rms_forward = None
    for module in model.modules():
        name = type(module).__name__
        if name in RMS_DISPATCH_MAPPING:
            if rms_forward is None:
                rms_forward = RMS_DISPATCH_MAPPING[name]
                rms_forward = rms_forward.build()
            print_log(f'Dispatch {name} forward.', 'current')
            module.forward = types.MethodType(rms_forward, module)


def replace_rote(model):
    from mmengine import print_log
    print_log = log_once(print_log)

    assert hasattr(model.config, 'rope_theta'), \
        '`rope_theta` should be in the model config.'
    rope_theta = model.config.rope_theta

    def traverse(module):
        for name, child in module.named_children():
            cls_name = type(child).__name__
            if cls_name in ROTE_DISPATCH_MAPPING:
                rote = ROTE_DISPATCH_MAPPING[cls_name]
                rote = rote.build()
                print_log(f'replace {cls_name}', 'current')
                # `inv_freq` holds one frequency per pair of rotary
                # dimensions, so the head dim is twice its length.
                dim_model = child.inv_freq.shape[0] * 2
                child_new = rote(dim_model, child.max_seq_len_cached,
                                 rope_theta).to(
                                     device=child.inv_freq.device,
                                     dtype=child.inv_freq.dtype)
                setattr(module, name, child_new)
            else:
                traverse(child)

    traverse(model)


def dispatch_modules(model, use_varlen_attn=False):

    def check(model_name):
        if 'ForCausalLM' not in model_name and model_name.endswith('Model'):
            # a workaround for reward models, whose class names end in
            # `Model` instead of `ForCausalLM`
            model_name = model_name[:-5] + 'ForCausalLM'
        if model_name not in LOWEST_TRANSFORMERS_VERSION:
            # No version constraint is registered for this model; the
            # dispatch mappings below will simply not match its modules.
            return
        msg = '{} requires transformers version at least {}, but got {}'
        assert TRANSFORMERS_VERSION >= LOWEST_TRANSFORMERS_VERSION[
            model_name], msg.format(model_name,
                                    LOWEST_TRANSFORMERS_VERSION[model_name],
                                    TRANSFORMERS_VERSION)

    check(type(model).__name__)
    if use_varlen_attn:
        dispatch_varlen_attn_forward(model)
    else:
        dispatch_attn_forward(model)
    dispatch_rmsnorm_forward(model)
    replace_rote(model)
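

# Typical call site, a sketch assuming a Flash-Attention-2 capable
# environment (the checkpoint name is a placeholder):
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained(
#       'internlm/internlm2-chat-7b',
#       attn_implementation='flash_attention_2')
#   dispatch_modules(model, use_varlen_attn=False)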

__all__ = ['dispatch_modules']