# ChatVID / model / fastchat / train / train_lora.py
# Last commit: 6ef31de "init" (Yiqin); file size 4.96 kB.
# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>
# Adapted from tatsu-lab/stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
import logging
import pathlib
import typing
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model
import transformers
from transformers import Trainer
from fastchat.train.train import (
DataArguments,
ModelArguments,
TrainingArguments,
make_supervised_data_module,
)
from fastchat.train.llama_flash_attn_monkey_patch import (
replace_llama_attn_with_flash_attn,
)
# Applied at import time, i.e. before train() loads any model.
# NOTE(review): presumably patches LLaMA attention to use FlashAttention (per the
# helper's name/module); confirm in fastchat/train/llama_flash_attn_monkey_patch.py.
replace_llama_attn_with_flash_attn()
@dataclass
class LoraArguments:
    """LoRA options, parsed from the command line by ``HfArgumentParser``.

    ``train()`` forwards ``lora_r``, ``lora_alpha``, ``lora_target_modules``,
    ``lora_dropout`` and ``bias`` to ``peft.LoraConfig``; ``bias`` also selects
    which tensors ``get_peft_state_maybe_zero_3`` keeps when saving.
    HfArgumentParser passes each field's ``metadata`` (``help``, ``choices``)
    through to argparse, so an unsupported ``--bias`` is rejected at parse time.
    """

    lora_r: int = field(default=8, metadata={"help": "LoRA rank (LoraConfig.r)."})
    lora_alpha: int = field(
        default=16,
        metadata={"help": "LoRA scaling factor (LoraConfig.lora_alpha)."},
    )
    lora_dropout: float = field(
        default=0.05,
        metadata={"help": "Dropout inside the LoRA layers (LoraConfig.lora_dropout)."},
    )
    lora_target_modules: typing.List[str] = field(
        # default_factory gives every instance its own list.
        default_factory=lambda: ["q_proj", "v_proj"],
        metadata={
            "help": "Module names that receive LoRA adapters "
            "(LoraConfig.target_modules)."
        },
    )
    # NOTE(review): not read by train() in this file.
    lora_weight_path: str = ""
    bias: str = field(
        default="none",
        metadata={
            "help": "Bias handling for LoRA (LoraConfig.bias); also decides "
            "which bias tensors are saved.",
            # Exactly the values get_peft_state_maybe_zero_3 supports.
            "choices": ["none", "all", "lora_only"],
        },
    )
def maybe_zero_3(param):
    """Return a detached CPU copy of ``param``, gathering it first under ZeRO-3.

    A tensor carrying a ``ds_id`` attribute is treated as a DeepSpeed ZeRO-3
    partition: it must be in the ``NOT_AVAILABLE`` state and is materialised
    with ``zero.GatheredParameters`` before being copied.
    NOTE(review): the gather is a DeepSpeed collective, so presumably every
    rank must reach this call for ZeRO-3 parameters.

    Any other tensor is copied the same way, so both paths return a tensor
    that lives on CPU, has no autograd history and does not share storage
    with ``param``.

    Raises:
        AssertionError: if a ZeRO-3 parameter is not ``NOT_AVAILABLE``
            (the check is skipped under ``python -O``).
    """
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            # Copy while the full value is gathered; it is re-partitioned on exit.
            return param.data.cpu().clone().detach()
    # Copy plain tensors too, so callers never receive the live (possibly GPU,
    # grad-tracking) weight on one path and a CPU snapshot on the other.
    return param.detach().cpu().clone()
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(state_dict, bias):
    """Select the LoRA (and optionally bias) tensors to save, via ``maybe_zero_3``.

    Args:
        state_dict: mapping of parameter name -> tensor. For ZeRO-3 gathering
            to happen, the values must be the objects carrying ``ds_id``
            (e.g. the live parameters), since ``maybe_zero_3`` checks for it.
        bias: which bias tensors to keep next to the ``lora_`` weights:
            ``"none"`` keeps none, ``"all"`` keeps every key containing
            ``"bias"``, ``"lora_only"`` keeps only the bias of a module that
            has ``lora_`` weights.

    Returns:
        dict of name -> tensor, each value passed through ``maybe_zero_3``.

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias == "none":
        # Pass values through untouched, like the other branches. Copying them
        # here (``.cpu().clone()``) would create new tensor objects without
        # ``ds_id``, so ZeRO-3 placeholders would be saved without a gather.
        to_return = {k: v for k, v in state_dict.items() if "lora_" in k}
    elif bias == "all":
        to_return = {
            k: v for k, v in state_dict.items() if "lora_" in k or "bias" in k
        }
    elif bias == "lora_only":
        to_return = {}
        for k, v in state_dict.items():
            if "lora_" in k:
                to_return[k] = v
                # e.g. "...q_proj.lora_A.weight" -> "...q_proj.bias"
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in state_dict:
                    to_return[bias_name] = state_dict[bias_name]
    else:
        raise NotImplementedError(f"Unsupported bias option: {bias!r}")
    return {k: maybe_zero_3(v) for k, v in to_return.items()}
def train():
    """Fine-tune a causal LM with LoRA adapters and save the adapter weights.

    Reads ModelArguments, DataArguments, TrainingArguments and LoraArguments
    from the command line, wraps the pretrained model with ``peft``, trains it
    with ``transformers.Trainer`` (resuming when ``output_dir`` already holds
    ``checkpoint-*`` directories) and writes the LoRA weights to
    ``output_dir`` with ``save_pretrained``.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.bias,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    # Print once per node; local_rank is -1 for a plain single-process run,
    # which previously never printed because DeepSpeed was also required.
    if training_args.local_rank in (-1, 0):
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        logging.warning(
            "gradient checkpointing with lora makes requires_grad "
            "incorrect and needs a monkey patch in Trainer or the "
            "wrapped model's forward. ref: "
            "https://github.com/lm-sys/FastChat/pull/138#issuecomment-1509172198"
        )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # NOTE(review): reuses the unk token for padding; assumes the tokenizer
    # has no dedicated pad token — confirm for the model family in use.
    tokenizer.pad_token = tokenizer.unk_token

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    model.config.use_cache = False
    if any(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Save states. Weights might be a placeholder in ZeRO-3 and need a gather.
    # Collect from the live parameters: model.state_dict() returns detached
    # tensors, which no longer carry the ``ds_id`` attribute that maybe_zero_3
    # uses to detect (and gather) ZeRO-3 partitions. The gather must run on
    # every rank, so it stays outside the save guard below.
    state_dict = get_peft_state_maybe_zero_3(
        dict(model.named_parameters()), lora_args.bias
    )
    # should_save marks the process(es) that write to disk (the main process,
    # or each node's main process with --save_on_each_node). Unlike
    # ``local_rank == 0`` it is also true for single-process runs, where
    # local_rank is -1 and the adapter was previously never saved.
    if training_args.should_save:
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)


if __name__ == "__main__":
    train()