# Source: gemma-7b / examples / example_sft_qlora.py
# Uploaded by satpalsr via huggingface_hub (commit dc3b54a, verified).
from dataclasses import dataclass, field
from typing import Optional
import torch
from transformers import AutoTokenizer, HfArgumentParser, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer
@dataclass
class ScriptArguments:
    """Command-line options for the QLoRA supervised fine-tuning example.

    These arguments vary depending on how many GPUs you have, what their
    capacity and features are, and what size model you want to train.
    Every field is exposed as a ``--<field_name>`` flag by ``HfArgumentParser``,
    which derives each flag's parser type from the field's annotation.
    """

    # Batching and optimisation hyper-parameters.
    per_device_train_batch_size: Optional[int] = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=4)
    learning_rate: Optional[float] = field(default=2e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    # Annotated as float (the default is fractional): with an int annotation
    # the argument parser would reject values such as "--weight_decay 0.01".
    weight_decay: Optional[float] = field(default=0.001)

    # LoRA adapter hyper-parameters.
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=8)

    # Maximum tokenised sequence length fed to the trainer.
    max_seq_length: Optional[int] = field(default=2048)

    model_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
        },
    )
    dataset_name: Optional[str] = field(
        default="stingning/ultrachat",
        metadata={"help": "The dataset to use for supervised fine-tuning."},
    )

    # Mixed-precision switches.
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables fp16 training."},
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables bf16 training."},
    )

    packing: Optional[bool] = field(
        default=True,
        metadata={"help": "Use packing dataset creating."},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    use_flash_attention_2: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash Attention 2."},
    )
    optim: Optional[str] = field(
        default="paged_adamw_32bit",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={"help": "Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis"},
    )
    max_steps: int = field(default=1000, metadata={"help": "How many optimizer update steps to take"})
    warmup_ratio: float = field(default=0.03, metadata={"help": "Fraction of steps to do a warmup for"})
    save_steps: int = field(default=10, metadata={"help": "Save checkpoint every X updates steps."})
    logging_steps: int = field(default=10, metadata={"help": "Log every X updates steps."})
    output_dir: str = field(
        default="./results",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
# Turn the process's command-line flags into a ScriptArguments instance.
# The parser returns one object per registered dataclass; only the first
# (ScriptArguments) is needed here.
parser = HfArgumentParser(ScriptArguments)
script_args, *_ = parser.parse_args_into_dataclasses()
def formatting_func(example):
    """Render a conversation record as a single-turn USER/ASSISTANT prompt.

    Only the first two entries of ``example["data"]`` are used: the first is
    rendered as the user message and the second as the assistant reply.

    TRL's ``SFTTrainer`` (0.7.x) calls this once per example when packing, but
    with a whole batch (``dataset.map(..., batched=True)``) when not packing,
    and in that case it requires a list of strings back. Both call shapes are
    handled so the ``--packing`` flag works either way.

    Args:
        example: A mapping whose ``"data"`` value is either one conversation
            (a sequence of message strings) or, in batched mode, a list of
            such conversations.

    Returns:
        The formatted prompt string for a single example, or a list of
        formatted strings (one per conversation) for a batch.

    Raises:
        IndexError: If a conversation has fewer than two messages.
    """

    def _render(turns):
        # One user turn followed by one assistant turn.
        return f"### USER: {turns[0]}\n### ASSISTANT: {turns[1]}"

    conversations = example["data"]
    # A single conversation holds message strings; a batch holds
    # conversations, i.e. its elements are themselves sequences of messages.
    if conversations and isinstance(conversations[0], (list, tuple)):
        return [_render(turns) for turns in conversations]
    return _render(conversations)
# Model to fine-tune: --model_name if given, otherwise Gemma-7B from the Hub.
model_id = script_args.model_name or "google/gemma-7b"

# 4-bit NF4 quantisation (QLoRA); quantised matmuls are computed in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float32,
    attn_implementation="flash_attention_2" if script_args.use_flash_attention_2 else "sdpa",
)

# Load tokenizer; the EOS token is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# LoRA adapters on every attention and MLP projection.
lora_config = LoraConfig(
    r=script_args.lora_r,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
)

# Only the first 5% of the training split is used.
train_dataset = load_dataset(script_args.dataset_name, split="train[:5%]")

# Checkpoint directory comes from --output_dir (default "./results"); pass e.g.
# --output_dir <your_hf_username>/gemma-qlora-ultrachat for a Hub-style name.
output_dir = script_args.output_dir

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    per_device_eval_batch_size=script_args.per_device_eval_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    optim=script_args.optim,
    save_steps=script_args.save_steps,
    logging_steps=script_args.logging_steps,
    learning_rate=script_args.learning_rate,
    weight_decay=script_args.weight_decay,
    max_grad_norm=script_args.max_grad_norm,
    max_steps=script_args.max_steps,
    warmup_ratio=script_args.warmup_ratio,
    lr_scheduler_type=script_args.lr_scheduler_type,
    gradient_checkpointing=script_args.gradient_checkpointing,
    fp16=script_args.fp16,
    bf16=script_args.bf16,
)

# dataset_text_field is deliberately left unset: training text comes from
# formatting_func. In TRL 0.7.x a non-None dataset_text_field makes the
# non-packed path tokenize that raw column and ignore formatting_func.
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    peft_config=lora_config,
    packing=script_args.packing,
    tokenizer=tokenizer,
    max_seq_length=script_args.max_seq_length,
    formatting_func=formatting_func,
)

trainer.train()