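# Example launch (a sketch, not taken from this file): the script path and the
# DeepSpeed config are placeholders and depend on your checkout and setup.
#
#   deepspeed fastchat/train/train_lora.py \
#       --model_name_or_path <path_to_llama_weights> \
#       --data_path <path_to_training_json> \
#       --output_dir ./checkpoints-lora \
#       --lora_r 8 --lora_alpha 16 --lora_dropout 0.05 \
#       --deepspeed <path_to_deepspeed_config.json>
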
from dataclasses import dataclass, field
import logging
import pathlib
import typing

from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model
import transformers
from transformers import Trainer

from fastchat.train.train import (
    DataArguments,
    ModelArguments,
    TrainingArguments,
    make_supervised_data_module,
)
from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()
|
|
@dataclass
class LoraArguments:
    # LoRA rank, scaling factor, and dropout for the adapter layers.
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # Base-model modules to wrap with LoRA adapters (the attention
    # query/value projections by default).
    lora_target_modules: typing.List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    lora_weight_path: str = ""
    bias: str = "none"
|
|
def maybe_zero_3(param):
    # Under DeepSpeed ZeRO-3, parameters are partitioned across ranks and carry
    # a `ds_id`; gather the full tensor before copying it to CPU.
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.cpu().clone().detach()
    return param
|
|
def get_peft_state_maybe_zero_3(state_dict, bias):
    # Keep only the LoRA adapter weights (and, depending on `bias`, the
    # matching bias terms) so the saved checkpoint stays small.
    if bias == "none":
        to_return = {
            k: state_dict[k].cpu().clone().detach() for k in state_dict if "lora_" in k
        }
    elif bias == "all":
        to_return = {
            k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k
        }
    elif bias == "lora_only":
        to_return = {}
        for k in state_dict:
            if "lora_" in k:
                to_return[k] = state_dict[k]
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in state_dict:
                    to_return[bias_name] = state_dict[bias_name]
    else:
        raise NotImplementedError
    # Gather any ZeRO-3 partitioned tensors before returning them.
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return
|
|
def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    # Wrap the base model with LoRA adapters; only the adapter weights are
    # trainable.
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.bias,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        logging.warning(
            "gradient checkpointing with lora makes requires_grad "
            "incorrect and needs a monkey patch in Trainer or the "
            "wrapped model's forward. ref: "
            "https://github.com/lm-sys/FastChat/pull/138#issuecomment-1509172198"
        )
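        # A common workaround (not applied by this script) is to make the
        # embedding outputs require grads so checkpointed blocks still receive
        # gradients, e.g. by calling
        #   model.enable_input_require_grads()
        # before training starts.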

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # LLaMA's tokenizer has no pad token; reuse the unk token for padding.
    tokenizer.pad_token = tokenizer.unk_token

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    # Disable the KV cache during training; it is only useful for generation.
    model.config.use_cache = False

    # Resume from the latest checkpoint if one exists in the output directory.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Save only the LoRA weights (gathered from ZeRO-3 shards if needed) on
    # the main process.
    state_dict = get_peft_state_maybe_zero_3(model.state_dict(), lora_args.bias)
    if training_args.local_rank == 0:
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)


if __name__ == "__main__":
    train()
|
|
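# Loading the saved adapter for inference (a hedged sketch; the paths below are
# placeholders and the base model must match --model_name_or_path):
#
#   from transformers import AutoModelForCausalLM
#   from peft import PeftModel
#
#   base = AutoModelForCausalLM.from_pretrained("<path_to_llama_weights>")
#   model = PeftModel.from_pretrained(base, "<output_dir>")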