# Copyright (c) OpenMMLab. All rights reserved.
import math
from collections import OrderedDict
from contextlib import nullcontext
import torch
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint
from peft import get_peft_model, prepare_model_for_kbit_training
from torch import nn
from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled
from xtuner.parallel.sequence import (get_sequence_parallel_group,
get_sequence_parallel_world_size,
reduce_sequence_parallel_loss,
split_for_sequence_parallel)
from xtuner.registry import BUILDER
from .modules import dispatch_modules
from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, make_inputs_require_grad,
                    traverse_dict)


def smart_tokenizer_and_embedding_resize(
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
):
"""Resize embedding."""
if is_deepspeed_zero3_enabled():
import deepspeed
params = [model.get_input_embeddings().weight]
        if (model.get_output_embeddings() is not None
                and not model.config.tie_word_embeddings):
            params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(
params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
assert isinstance(model.get_output_embeddings(), nn.Linear)
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
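        # The rows appended by the resize are initialized below to the mean of
        # the pre-existing input/output embeddings.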
with context_maybe_zero3:
num_new_tokens = len(tokenizer) - current_embedding_size
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
print_log(
f'Resized token embeddings from {current_embedding_size} to '
        f'{len(tokenizer)}.', 'current')


class SupervisedFinetune(BaseModel):
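    """Supervised fine-tuning wrapper around a causal language model.

    Builds the inner ``llm`` from a config or an ``nn.Module``, optionally
    resizes its embeddings to match ``tokenizer``, dispatches optimized
    attention implementations, enables activation checkpointing and, when a
    LoRA config is provided, wraps the model with PEFT adapters.
    """
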
def __init__(self,
llm,
lora=None,
peft_model=None,
use_activation_checkpointing=True,
use_varlen_attn=False,
tokenizer=None,
max_position_embeddings=None):
super().__init__()
with LoadWoInit():
if isinstance(llm, dict):
llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
self.llm = self._build_from_cfg_or_module(llm)
if tokenizer is not None:
if isinstance(tokenizer, dict):
tokenizer = BUILDER.build(tokenizer)
smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
self.llm.config.use_cache = False
dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)
if use_activation_checkpointing:
# For backward compatibility
if hasattr(self.llm, 'enable_input_require_grads'):
self.llm.enable_input_require_grads()
else:
self.llm.get_input_embeddings().register_forward_hook(
make_inputs_require_grad)
# enable gradient checkpointing for memory efficiency
self.gradient_checkpointing_enable()
        if isinstance(lora, (dict, Config, ConfigDict)):
            self.lora = BUILDER.build(lora)
        else:
            self.lora = lora
self.peft_model = peft_model
self.use_lora = lora is not None
if self.use_lora:
self._prepare_for_lora(peft_model, use_activation_checkpointing)
self._is_init = True
        # When `use_varlen_attn` is True, attention is computed over the
        # actual (variable) lengths of the packed sequences instead of the
        # padded `seq_len` dimension.
        self.use_varlen_attn = use_varlen_attn

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()

    def _prepare_for_lora(self,
peft_model=None,
use_activation_checkpointing=True):
self.llm = prepare_model_for_kbit_training(
self.llm, use_activation_checkpointing)
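        # If the LoRA config does not specify target modules, apply adapters
        # to every linear layer found in the base model.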
if self.lora.target_modules is None:
modules = find_all_linear_names(self.llm)
self.lora.target_modules = modules
self.llm = get_peft_model(self.llm, self.lora)
if peft_model is not None:
            _ = load_checkpoint(self, peft_model)

    def init_weights(self):
        pass

    @staticmethod
def _prepare_for_long_context_training(cfg, llm_cfg,
max_position_embeddings):
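        """Enable linear RoPE scaling when the requested
        ``max_position_embeddings`` exceeds the model's original limit."""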
if not hasattr(llm_cfg, 'rope_scaling'):
print_log('Current model does not support RoPE scaling.',
'current')
return
current_max_length = getattr(llm_cfg, 'max_position_embeddings', None)
if current_max_length and max_position_embeddings > current_max_length:
print_log(
f'Enlarge max model length from {current_max_length} '
f'to {max_position_embeddings}.', 'current')
scaling_factor = float(
math.ceil(max_position_embeddings / current_max_length))
else:
print_log(
                'The input `max_position_embeddings` is smaller than the '
                'original max length. Consider increasing the input length.',
'current')
scaling_factor = 1.0
cfg.rope_scaling = {'type': 'linear', 'factor': scaling_factor}
        return cfg

    @staticmethod
def _prepare_for_flash_attn(cfg, llm_cfg):
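        """Choose an attention implementation (``flash_attention_2`` or
        ``sdpa``) that is both available and supported by the model class."""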
cls_name = type(llm_cfg).__name__
        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
                             'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
                             'Starcoder2Config', 'Phi3Config')
        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
                               'MistralConfig', 'MixtralConfig', 'Qwen2Config',
                               'Qwen2MoeConfig', 'Starcoder2Config',
                               'Phi3Config', 'DeepseekV2Config')
torch_dtype = torch.bfloat16 if (
torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
else torch.float16
if getattr(cfg, 'attn_implementation', None) is not None:
# Flash Attention 2.0 only supports torch.float16 and
# torch.bfloat16 dtypes
if cfg.attn_implementation == 'flash_attention_2':
cfg.torch_dtype = torch_dtype
elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
cfg.torch_dtype = torch_dtype
cfg.attn_implementation = 'flash_attention_2'
elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
cfg.attn_implementation = 'sdpa'
        return cfg

    @staticmethod
def _prepare_for_qlora_zero3(cfg):
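        """For QLoRA under DeepSpeed ZeRO-3, align the bitsandbytes 4-bit
        compute and storage dtypes with the model's half-precision dtype."""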
if (not is_deepspeed_zero3_enabled()) or (not hasattr(
cfg, 'quantization_config')):
return cfg
torch_dtype = torch.bfloat16 if (
torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
else torch.float16
cfg.torch_dtype = torch_dtype
quantization_config = cfg.quantization_config
quantization_config.bnb_4bit_compute_dtype = torch_dtype
quantization_config.bnb_4bit_quant_storage = torch_dtype
        return cfg

    def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
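        """Adjust the LLM build config (quantization dtypes, attention
        implementation, RoPE scaling) before the model is instantiated."""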
cfg = self._prepare_for_qlora_zero3(cfg)
pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
llm_cfg = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=True)
cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
if max_position_embeddings is not None:
cfg = self._prepare_for_long_context_training(
cfg, llm_cfg, max_position_embeddings)
        return cfg

    def _build_from_cfg_or_module(self, cfg_or_mod):
if isinstance(cfg_or_mod, nn.Module):
return cfg_or_mod
elif isinstance(cfg_or_mod, dict):
traverse_dict(cfg_or_mod)
return BUILDER.build(cfg_or_mod)
else:
            raise NotImplementedError

    def forward(self, data, data_samples=None, mode='loss'):
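        """Dispatch following the mmengine ``BaseModel`` interface: ``loss``
        for training, ``predict`` for evaluation and ``tensor`` for raw
        model outputs."""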
if mode == 'loss':
return self.compute_loss(data, data_samples)
elif mode == 'predict':
return self.predict(data, data_samples)
elif mode == 'tensor':
return self._forward(data, data_samples)
else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
outputs = self.llm(**data)
        return outputs

    def predict(self, data, data_samples=None):
outputs = self.llm(**data)
logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    @staticmethod
def _split_for_sequence_parallel(data):
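        """Slice per-token tensors along the sequence dimension so that each
        sequence-parallel rank processes its own chunk."""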
# attention mask should not be split
ARGS_NEED_TO_SPLIT = ('input_ids', 'labels', 'position_ids')
sp_group = get_sequence_parallel_group()
for key in ARGS_NEED_TO_SPLIT:
val = data.get(key, None)
if val is not None:
# `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
data[key] = split_for_sequence_parallel(
val, dim=1, sp_group=sp_group)
        return data

    def _compute_sequence_parallel_loss(self, data):
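        # Each rank runs the forward pass on its sequence shard; the loss is
        # then reduced across the sequence-parallel group, weighted by the
        # number of supervised (label != -100) tokens on each rank.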
data = self._split_for_sequence_parallel(data)
outputs = self.llm(**data)
labels = data['labels']
num_tokens = (labels != -100).sum()
sp_group = get_sequence_parallel_group()
loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens,
sp_group)
        return {'loss': loss}

    def compute_loss(self, data, data_samples=None):
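        # With sequence parallelism, inputs must be split across ranks and the
        # loss reduced afterwards; otherwise a plain forward pass suffices.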
if get_sequence_parallel_world_size() > 1:
return self._compute_sequence_parallel_loss(data)
else:
outputs = self.llm(**data)
loss_dict = {'loss': outputs.loss}
            return loss_dict

    def state_dict(self, *args, **kwargs):
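        # When LoRA is enabled, keep only the PEFT (adapter) weights so that
        # checkpoints stay lightweight.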
state_dict = super().state_dict(*args, **kwargs)
if not self.use_lora:
return state_dict
to_return = get_peft_model_state_dict(self.llm, state_dict=state_dict)
        return OrderedDict(to_return)

    def __getattr__(self, name: str):
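        # Delegate attribute lookups that are not defined on the wrapper to
        # the underlying ``llm``.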
try:
return super().__getattr__(name)
except AttributeError:
            return getattr(self.llm, name)

    def to_hf(self,
              cfg,
              save_dir,
              fp32=False,
              save_pretrained_kwargs=None,
              **kwargs):
        # Avoid a mutable default argument for `save_pretrained_kwargs`.
        save_pretrained_kwargs = save_pretrained_kwargs or {}
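        # Export in Hugging Face format: either the LoRA adapter only, or the
        # full model together with its tokenizer.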
self.llm.config.use_cache = True
if not fp32:
print_log('Convert LLM to float16', 'current')
self.llm.half()
if self.use_lora:
print_log(f'Saving adapter to {save_dir}', 'current')
else:
print_log(f'Saving LLM tokenizer to {save_dir}', 'current')
tokenizer = BUILDER.build(cfg.tokenizer)
tokenizer.save_pretrained(save_dir)
print_log(f'Saving LLM to {save_dir}', 'current')
self.llm.save_pretrained(save_dir, **save_pretrained_kwargs)
self.llm.config.use_cache = False