Upload example_sft_qlora.py with huggingface_hub
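The commit message says the file was pushed with huggingface_hub. As a minimal sketch of such an upload (the target repo id and paths below are illustrative assumptions, not taken from the commit itself), assuming you are already authenticated:

from huggingface_hub import HfApi

api = HfApi()  # assumes prior login, e.g. via `huggingface-cli login`, or pass token=...
api.upload_file(
    path_or_fileobj="example_sft_qlora.py",  # local file to upload
    path_in_repo="example_sft_qlora.py",     # destination path inside the repo
    repo_id="google/gemma-7b",               # illustrative target repo
    commit_message="Upload example_sft_qlora.py with huggingface_hub",
)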
example_sft_qlora.py +146 -0
example_sft_qlora.py
ADDED
@@ -0,0 +1,146 @@
from dataclasses import dataclass, field
from typing import Optional

import torch

from transformers import AutoTokenizer, HfArgumentParser, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer


@dataclass
class ScriptArguments:
    """
    These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
    """
    per_device_train_batch_size: Optional[int] = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=4)
    learning_rate: Optional[float] = field(default=2e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    weight_decay: Optional[float] = field(default=0.001)
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=8)
    max_seq_length: Optional[int] = field(default=2048)
    model_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model that you want to train, from the Hugging Face Hub. E.g. gpt2, gpt2-xl, bert, etc."
        }
    )
    dataset_name: Optional[str] = field(
        default="stingning/ultrachat",
        metadata={"help": "The dataset to use for supervised fine-tuning."},
    )
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables fp16 training."},
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables bf16 training."},
    )
    packing: Optional[bool] = field(
        default=True,
        metadata={"help": "Use packing when creating the dataset."},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    use_flash_attention_2: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash Attention 2."},
    )
    optim: Optional[str] = field(
        default="paged_adamw_32bit",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={"help": "Learning rate schedule. Constant is a bit better than cosine and is easier to analyze."},
    )
    max_steps: int = field(default=1000, metadata={"help": "How many optimizer update steps to take."})
    warmup_ratio: float = field(default=0.03, metadata={"help": "Fraction of steps to do a warmup for."})
    save_steps: int = field(default=10, metadata={"help": "Save a checkpoint every X update steps."})
    logging_steps: int = field(default=10, metadata={"help": "Log every X update steps."})
    output_dir: str = field(
        default="./results",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )

parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]


# Format the first user/assistant turn of each UltraChat conversation (stored in the "data" field).
def formatting_func(example):
    text = f"### USER: {example['data'][0]}\n### ASSISTANT: {example['data'][1]}"
    return text

# Base model to fine-tune (Gemma 7B from the Hub)
model_id = "google/gemma-7b"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4"
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float32,
    attn_implementation="sdpa" if not script_args.use_flash_attention_2 else "flash_attention_2"
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

lora_config = LoraConfig(
    r=script_args.lora_r,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout
)

train_dataset = load_dataset(script_args.dataset_name, split="train[:5%]")

# TODO: make this configurable
YOUR_HF_USERNAME = "xxx"  # placeholder: replace with your Hub username
output_dir = f"{YOUR_HF_USERNAME}/gemma-qlora-ultrachat"

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    optim=script_args.optim,
    save_steps=script_args.save_steps,
    logging_steps=script_args.logging_steps,
    learning_rate=script_args.learning_rate,
    max_grad_norm=script_args.max_grad_norm,
    max_steps=script_args.max_steps,
    warmup_ratio=script_args.warmup_ratio,
    lr_scheduler_type=script_args.lr_scheduler_type,
    gradient_checkpointing=script_args.gradient_checkpointing,
    fp16=script_args.fp16,
    bf16=script_args.bf16,
)

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    peft_config=lora_config,
    packing=script_args.packing,
    dataset_text_field="id",  # not used here, since formatting_func is provided
    tokenizer=tokenizer,
    max_seq_length=script_args.max_seq_length,
    formatting_func=formatting_func,
)

trainer.train()
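Because the arguments are parsed with HfArgumentParser, every ScriptArguments field can be overridden from the command line. An illustrative launch (flag values are examples, not recommendations from the script author):

python example_sft_qlora.py \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --bf16 \
    --max_steps 500 \
    --use_flash_attention_2

After training, the LoRA adapter checkpoints are written under the hardcoded output_dir. A minimal sketch of loading one back on top of the quantized base model for inference (the checkpoint path is hypothetical, and `model` / `output_dir` are the objects defined in the script above):

from peft import PeftModel

# "checkpoint-1000" is a placeholder; pick an actual checkpoint directory from output_dir
inference_model = PeftModel.from_pretrained(model, f"{output_dir}/checkpoint-1000")
inference_model.eval()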