g-ronimo committed
Commit e1f30e6 · verified · 1 Parent(s): aa1b636

Upload qlora.py

Files changed (1):
  code/qlora.py +183 -0
code/qlora.py ADDED
@@ -0,0 +1,183 @@
import torch, os, wandb, uuid, json
import bitsandbytes as bnb
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, BitsAndBytesConfig, TrainerCallback
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset, DatasetDict, Dataset, load_from_disk
from functools import partial

set_seed(42)

accelerator = Accelerator()
run_id = str(uuid.uuid4())
modelpath = "microsoft/phi-2"
dataset_name = "teknium/OpenHermes-2.5"
lr = 0.00002
bs = 10          # batch size
bs_eval = 16     # batch size for evals
ga_steps = 4     # gradient accumulation steps
epochs = 1
max_length = 1024
output_dir = f"out_{run_id}"

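# Effective batch size = bs * ga_steps * number of GPUs
# (with the settings above: 10 * 4 * accelerator.num_processes).
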
# Load model
model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    device_map={"": accelerator.process_index},
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    torch_dtype=torch.bfloat16,
    # does not work yet
    # attn_implementation="flash_attention_2",
)

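# The base weights are loaded as 4-bit NF4 and matmuls run in bfloat16
# (bnb_4bit_compute_dtype); device_map pins each accelerate process to its own GPU.
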
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False)  # the fast tokenizer sometimes ignores the added tokens

# Add ChatML tokens: <|im_start|> and <PAD>; <|im_end|> is registered as the eos token
tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens(dict(eos_token="<|im_end|>"))
model.config.eos_token_id = tokenizer.eos_token_id

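# Note: the script never calls model.resize_token_embeddings(); this relies on phi-2's
# embedding matrix having spare rows beyond the tokenizer's vocabulary
# (worth verifying: len(tokenizer) <= model.config.vocab_size).
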
# Add adapters to model
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "dense"],
    modules_to_save=["lm_head", "embed_tokens"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

model.config.use_cache = False

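# use_cache is turned off because the KV cache is unnecessary during training and
# conflicts with gradient checkpointing; modules_to_save keeps full (non-LoRA) trainable
# copies of lm_head and embed_tokens so the newly added ChatML tokens can be learned.
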
# Print stats
if accelerator.is_main_process:
    model.print_trainable_parameters()

# Load dataset
with accelerator.main_process_first():
    dataset = load_dataset(dataset_name)
    dataset = dataset["train"].train_test_split(test_size=0.1)

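# main_process_first() lets rank 0 download and cache the dataset before the other
# ranks read it from the local cache.
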
# Format (ChatML) and tokenize dataset
templates = {
    "system": "<|im_start|>system\n{msg}<|im_end|>",
    "human": "<|im_start|>user\n{msg}<|im_end|>",
    "gpt": "<|im_start|>assistant\n{msg}<|im_end|>",
}
IGNORE_INDEX = -100

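# -100 is the default ignore_index of PyTorch's cross-entropy loss, so positions
# labelled IGNORE_INDEX do not contribute to the training loss.
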
def tokenize(input, max_length):
    input_ids, attention_mask, labels = [], [], []

    for i, msg in enumerate(input["conversations"]):
        msg_role = msg["from"]
        msg_content = msg["value"]
        isHuman = msg_role == "human"
        if msg_role not in templates:
            raise ValueError(f"unknown role: {msg_role}")  # fail loudly instead of silently returning None
        msg_chatml = templates[msg_role].format(msg=msg_content)
        msg_tokenized = tokenizer(msg_chatml, truncation=False, add_special_tokens=False)

        input_ids += msg_tokenized["input_ids"]
        attention_mask += msg_tokenized["attention_mask"]
        labels += [IGNORE_INDEX] * len(msg_tokenized["input_ids"]) if isHuman else msg_tokenized["input_ids"]

    return {
        "input_ids": input_ids[:max_length],
        "attention_mask": attention_mask[:max_length],
        "labels": labels[:max_length],
    }

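# Only "human" turns are masked with IGNORE_INDEX; "system" and "gpt" turns are kept
# as targets. Each sample is simply truncated to max_length after concatenation.
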
dataset_tokenized = dataset.map(
    partial(tokenize, max_length=max_length),
    batched=False,
    # num_proc=os.cpu_count()//accelerator.num_processes,
    num_proc=os.cpu_count(),    # one tokenization worker process per CPU core
    remove_columns=dataset["train"].column_names,    # don't need the raw columns anymore, we have tokens from here on
)

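# Every rank runs this map; the commented-out num_proc variant divides the worker
# count by the number of GPUs so concurrent ranks don't oversubscribe the CPU.
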
# Collate function: turn a list of dicts [ {input_ids: [123, ..], ..}, .. ] into a single
# padded batch dict { input_ids: [..], labels: [..], attention_mask: [..] }
def collate(elements):
    tokens = [e["input_ids"] for e in elements]
    tokens_maxlen = max([len(t) for t in tokens])

    for i, sample in enumerate(elements):
        input_ids = sample["input_ids"]
        labels = sample["labels"]
        attention_mask = sample["attention_mask"]

        # pad each sample up to the longest sequence in the batch
        pad_len = tokens_maxlen - len(input_ids)

        input_ids.extend(pad_len * [tokenizer.pad_token_id])
        labels.extend(pad_len * [IGNORE_INDEX])
        attention_mask.extend(pad_len * [0])

    batch = {
        "input_ids": torch.tensor([e["input_ids"] for e in elements]),
        "labels": torch.tensor([e["labels"] for e in elements]),
        "attention_mask": torch.tensor([e["attention_mask"] for e in elements]),
    }

    return batch

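# Padding is dynamic (to the longest sequence in each batch) rather than to a fixed
# max_length; padded positions are masked out via attention_mask and IGNORE_INDEX.
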
steps_per_epoch = len(dataset_tokenized["train"]) // (accelerator.num_processes * bs * ga_steps)

args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=bs,
    per_device_eval_batch_size=bs_eval,
    evaluation_strategy="steps",
    logging_steps=1,
    eval_steps=steps_per_epoch//3,    # roughly 3 evals per epoch
    save_steps=steps_per_epoch//3,    # roughly 3 checkpoints per epoch
    gradient_accumulation_steps=ga_steps,
    num_train_epochs=epochs,
    lr_scheduler_type="constant",
    optim="paged_adamw_32bit",        # val_loss goes NaN with paged_adamw_8bit
    learning_rate=lr,
    group_by_length=False,
    bf16=True,
    ddp_find_unused_parameters=False,
)

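# paged_adamw_32bit keeps the optimizer states in full precision but in paged memory
# that can spill to CPU RAM when GPU memory runs short; per the comment above, the
# 8-bit variant made the validation loss go NaN in this setup.
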
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=collate,
    train_dataset=dataset_tokenized["train"],
    eval_dataset=dataset_tokenized["test"],
)

if accelerator.is_main_process:
    run = wandb.init(
        project="phi2-teknium1",
        name=modelpath + "_" + dataset_name + f"_bs-{bs}_LR-{lr}_GPUs-{accelerator.num_processes}_maxlen-{max_length}_{run_id}",
        config={
            "model_name": modelpath,
            "run_id": run_id,
            "dataset": dataset_name,
            "output_dir": output_dir,
            "lr": lr,
            "max_length": max_length,
            "train_batch_size": bs,
            "validation_batch_size": bs_eval,
            "ga_steps": ga_steps,
            "lora_config": lora_config,
            "training_args": args,
            "GPUs": accelerator.num_processes,
        },
    )
    run.log_code()

trainer.train()
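# A typical multi-GPU launch looks something like:
#   accelerate launch qlora.py
# with the GPU count set via `accelerate config` or the --num_processes flag.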