Spaces:
Runtime error
Runtime error
zetavg
commited on
update
Browse files- llama_lora/globals.py +6 -1
- llama_lora/lib/finetune.py +223 -0
- llama_lora/models.py +32 -2
- llama_lora/ui/finetune_ui.py +6 -1
- llama_lora/ui/inference_ui.py +1 -0
llama_lora/globals.py
CHANGED
@@ -3,6 +3,8 @@ import subprocess
|
|
3 |
|
4 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
5 |
|
|
|
|
|
6 |
|
7 |
class Global:
|
8 |
version = None
|
@@ -15,11 +17,14 @@ class Global:
|
|
15 |
loaded_base_model: Any = None
|
16 |
|
17 |
# Functions
|
18 |
-
train_fn: Any =
|
19 |
|
20 |
# Training Control
|
21 |
should_stop_training = False
|
22 |
|
|
|
|
|
|
|
23 |
# UI related
|
24 |
ui_title: str = "LLaMA-LoRA"
|
25 |
ui_emoji: str = "π¦ποΈ"
|
|
|
3 |
|
4 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
5 |
|
6 |
+
from .lib.finetune import train
|
7 |
+
|
8 |
|
9 |
class Global:
|
10 |
version = None
|
|
|
17 |
loaded_base_model: Any = None
|
18 |
|
19 |
# Functions
|
20 |
+
train_fn: Any = train
|
21 |
|
22 |
# Training Control
|
23 |
should_stop_training = False
|
24 |
|
25 |
+
# Model related
|
26 |
+
model_has_been_used = False
|
27 |
+
|
28 |
# UI related
|
29 |
ui_title: str = "LLaMA-LoRA"
|
30 |
ui_emoji: str = "π¦ποΈ"
|
llama_lora/lib/finetune.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from typing import Any, List
|
4 |
+
|
5 |
+
import fire
|
6 |
+
import torch
|
7 |
+
import transformers
|
8 |
+
from datasets import Dataset, load_dataset
|
9 |
+
|
10 |
+
|
11 |
+
from peft import (
|
12 |
+
LoraConfig,
|
13 |
+
get_peft_model,
|
14 |
+
get_peft_model_state_dict,
|
15 |
+
prepare_model_for_int8_training,
|
16 |
+
set_peft_model_state_dict,
|
17 |
+
)
|
18 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
19 |
+
|
20 |
+
|
21 |
+
def train(
|
22 |
+
# model/data params
|
23 |
+
base_model: Any,
|
24 |
+
tokenizer: Any,
|
25 |
+
output_dir: str,
|
26 |
+
train_dataset_data: List[Any],
|
27 |
+
# training hyperparams
|
28 |
+
micro_batch_size: int = 4,
|
29 |
+
gradient_accumulation_steps: int = 32,
|
30 |
+
num_epochs: int = 3,
|
31 |
+
learning_rate: float = 3e-4,
|
32 |
+
cutoff_len: int = 256,
|
33 |
+
val_set_size: int = 2000,
|
34 |
+
# lora hyperparams
|
35 |
+
lora_r: int = 8,
|
36 |
+
lora_alpha: int = 16,
|
37 |
+
lora_dropout: float = 0.05,
|
38 |
+
lora_target_modules: List[str] = [
|
39 |
+
"q_proj",
|
40 |
+
"v_proj",
|
41 |
+
],
|
42 |
+
# llm hyperparams
|
43 |
+
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
44 |
+
group_by_length: bool = False, # faster, but produces an odd training loss curve
|
45 |
+
# either training checkpoint or final adapter
|
46 |
+
resume_from_checkpoint: str = None,
|
47 |
+
# logging
|
48 |
+
callbacks: List[Any] = []
|
49 |
+
):
|
50 |
+
device_map = "auto"
|
51 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
52 |
+
ddp = world_size != 1
|
53 |
+
if ddp:
|
54 |
+
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
55 |
+
|
56 |
+
model = base_model
|
57 |
+
if isinstance(model, str):
|
58 |
+
model = LlamaForCausalLM.from_pretrained(
|
59 |
+
base_model,
|
60 |
+
load_in_8bit=True,
|
61 |
+
torch_dtype=torch.float16,
|
62 |
+
device_map=device_map,
|
63 |
+
)
|
64 |
+
|
65 |
+
if isinstance(tokenizer, str):
|
66 |
+
tokenizer = LlamaTokenizer.from_pretrained(tokenizer)
|
67 |
+
|
68 |
+
tokenizer.pad_token_id = (
|
69 |
+
0 # unk. we want this to be different from the eos token
|
70 |
+
)
|
71 |
+
tokenizer.padding_side = "left" # Allow batched inference
|
72 |
+
|
73 |
+
def tokenize(prompt, add_eos_token=True):
|
74 |
+
# there's probably a way to do this with the tokenizer settings
|
75 |
+
# but again, gotta move fast
|
76 |
+
result = tokenizer(
|
77 |
+
prompt,
|
78 |
+
truncation=True,
|
79 |
+
max_length=cutoff_len,
|
80 |
+
padding=False,
|
81 |
+
return_tensors=None,
|
82 |
+
)
|
83 |
+
if (
|
84 |
+
result["input_ids"][-1] != tokenizer.eos_token_id
|
85 |
+
and len(result["input_ids"]) < cutoff_len
|
86 |
+
and add_eos_token
|
87 |
+
):
|
88 |
+
result["input_ids"].append(tokenizer.eos_token_id)
|
89 |
+
result["attention_mask"].append(1)
|
90 |
+
|
91 |
+
result["labels"] = result["input_ids"].copy()
|
92 |
+
|
93 |
+
return result
|
94 |
+
|
95 |
+
def generate_and_tokenize_prompt(data_point):
|
96 |
+
full_prompt = data_point["prompt"] + data_point["completion"]
|
97 |
+
tokenized_full_prompt = tokenize(full_prompt)
|
98 |
+
if not train_on_inputs:
|
99 |
+
user_prompt = data_point["prompt"]
|
100 |
+
tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
|
101 |
+
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
102 |
+
|
103 |
+
tokenized_full_prompt["labels"] = [
|
104 |
+
-100
|
105 |
+
] * user_prompt_len + tokenized_full_prompt["labels"][
|
106 |
+
user_prompt_len:
|
107 |
+
] # could be sped up, probably
|
108 |
+
return tokenized_full_prompt
|
109 |
+
|
110 |
+
# will fail anyway.
|
111 |
+
try:
|
112 |
+
model = prepare_model_for_int8_training(model)
|
113 |
+
except Exception as e:
|
114 |
+
print(
|
115 |
+
f"Got error while running prepare_model_for_int8_training(model), maybe the model has already be prepared. Original error: {e}.")
|
116 |
+
|
117 |
+
# model = prepare_model_for_int8_training(model)
|
118 |
+
|
119 |
+
config = LoraConfig(
|
120 |
+
r=lora_r,
|
121 |
+
lora_alpha=lora_alpha,
|
122 |
+
target_modules=lora_target_modules,
|
123 |
+
lora_dropout=lora_dropout,
|
124 |
+
bias="none",
|
125 |
+
task_type="CAUSAL_LM",
|
126 |
+
)
|
127 |
+
model = get_peft_model(model, config)
|
128 |
+
|
129 |
+
# If train_dataset_data is a list, convert it to datasets.Dataset
|
130 |
+
if isinstance(train_dataset_data, list):
|
131 |
+
train_dataset_data = Dataset.from_list(train_dataset_data)
|
132 |
+
|
133 |
+
if resume_from_checkpoint:
|
134 |
+
# Check the available weights and load them
|
135 |
+
checkpoint_name = os.path.join(
|
136 |
+
resume_from_checkpoint, "pytorch_model.bin"
|
137 |
+
) # Full checkpoint
|
138 |
+
if not os.path.exists(checkpoint_name):
|
139 |
+
checkpoint_name = os.path.join(
|
140 |
+
resume_from_checkpoint, "adapter_model.bin"
|
141 |
+
) # only LoRA model - LoRA config above has to fit
|
142 |
+
resume_from_checkpoint = (
|
143 |
+
False # So the trainer won't try loading its state
|
144 |
+
)
|
145 |
+
# The two files above have a different name depending on how they were saved, but are actually the same.
|
146 |
+
if os.path.exists(checkpoint_name):
|
147 |
+
print(f"Restarting from {checkpoint_name}")
|
148 |
+
adapters_weights = torch.load(checkpoint_name)
|
149 |
+
model = set_peft_model_state_dict(model, adapters_weights)
|
150 |
+
else:
|
151 |
+
print(f"Checkpoint {checkpoint_name} not found")
|
152 |
+
|
153 |
+
# Be more transparent about the % of trainable params.
|
154 |
+
model.print_trainable_parameters()
|
155 |
+
|
156 |
+
if val_set_size > 0:
|
157 |
+
train_val = train_dataset_data.train_test_split(
|
158 |
+
test_size=val_set_size, shuffle=True, seed=42
|
159 |
+
)
|
160 |
+
train_data = (
|
161 |
+
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
|
162 |
+
)
|
163 |
+
val_data = (
|
164 |
+
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
train_data = train_dataset_data.shuffle().map(generate_and_tokenize_prompt)
|
168 |
+
val_data = None
|
169 |
+
|
170 |
+
if not ddp and torch.cuda.device_count() > 1:
|
171 |
+
# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
|
172 |
+
model.is_parallelizable = True
|
173 |
+
model.model_parallel = True
|
174 |
+
|
175 |
+
trainer = transformers.Trainer(
|
176 |
+
model=model,
|
177 |
+
train_dataset=train_data,
|
178 |
+
eval_dataset=val_data,
|
179 |
+
args=transformers.TrainingArguments(
|
180 |
+
per_device_train_batch_size=micro_batch_size,
|
181 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
182 |
+
warmup_steps=100,
|
183 |
+
num_train_epochs=num_epochs,
|
184 |
+
learning_rate=learning_rate,
|
185 |
+
fp16=True,
|
186 |
+
logging_steps=10,
|
187 |
+
optim="adamw_torch",
|
188 |
+
evaluation_strategy="steps" if val_set_size > 0 else "no",
|
189 |
+
save_strategy="steps",
|
190 |
+
eval_steps=200 if val_set_size > 0 else None,
|
191 |
+
save_steps=200,
|
192 |
+
output_dir=output_dir,
|
193 |
+
save_total_limit=3,
|
194 |
+
load_best_model_at_end=True if val_set_size > 0 else False,
|
195 |
+
ddp_find_unused_parameters=False if ddp else None,
|
196 |
+
group_by_length=group_by_length,
|
197 |
+
# report_to="wandb" if use_wandb else None,
|
198 |
+
# run_name=wandb_run_name if use_wandb else None,
|
199 |
+
),
|
200 |
+
data_collator=transformers.DataCollatorForSeq2Seq(
|
201 |
+
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
202 |
+
),
|
203 |
+
callbacks=callbacks,
|
204 |
+
)
|
205 |
+
model.config.use_cache = False
|
206 |
+
|
207 |
+
old_state_dict = model.state_dict
|
208 |
+
model.state_dict = (
|
209 |
+
lambda self, *_, **__: get_peft_model_state_dict(
|
210 |
+
self, old_state_dict()
|
211 |
+
)
|
212 |
+
).__get__(model, type(model))
|
213 |
+
|
214 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
215 |
+
model = torch.compile(model)
|
216 |
+
|
217 |
+
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
218 |
+
|
219 |
+
model.save_pretrained(output_dir)
|
220 |
+
|
221 |
+
print(
|
222 |
+
"\n If there's a warning about missing keys above, please disregard :)"
|
223 |
+
)
|
llama_lora/models.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
import sys
|
|
|
3 |
|
4 |
import torch
|
5 |
import transformers
|
@@ -31,11 +32,14 @@ def get_base_model():
|
|
31 |
|
32 |
|
33 |
def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
|
|
|
|
|
34 |
if device == "cuda":
|
35 |
return PeftModel.from_pretrained(
|
36 |
get_base_model(),
|
37 |
lora_weights,
|
38 |
torch_dtype=torch.float16,
|
|
|
39 |
)
|
40 |
elif device == "mps":
|
41 |
return PeftModel.from_pretrained(
|
@@ -58,16 +62,21 @@ def get_tokenizer():
|
|
58 |
|
59 |
|
60 |
def load_base_model():
|
|
|
|
|
|
|
61 |
if Global.loaded_tokenizer is None:
|
62 |
Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
|
63 |
-
Global.base_model
|
|
|
64 |
if Global.loaded_base_model is None:
|
65 |
if device == "cuda":
|
66 |
Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
|
67 |
Global.base_model,
|
68 |
load_in_8bit=Global.load_8bit,
|
69 |
torch_dtype=torch.float16,
|
70 |
-
device_map="auto",
|
|
|
71 |
)
|
72 |
elif device == "mps":
|
73 |
Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
|
@@ -79,3 +88,24 @@ def load_base_model():
|
|
79 |
model = LlamaForCausalLM.from_pretrained(
|
80 |
base_model, device_map={"": device}, low_cpu_mem_usage=True
|
81 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
+
import gc
|
4 |
|
5 |
import torch
|
6 |
import transformers
|
|
|
32 |
|
33 |
|
34 |
def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
|
35 |
+
Global.model_has_been_used = True
|
36 |
+
|
37 |
if device == "cuda":
|
38 |
return PeftModel.from_pretrained(
|
39 |
get_base_model(),
|
40 |
lora_weights,
|
41 |
torch_dtype=torch.float16,
|
42 |
+
device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
|
43 |
)
|
44 |
elif device == "mps":
|
45 |
return PeftModel.from_pretrained(
|
|
|
62 |
|
63 |
|
64 |
def load_base_model():
|
65 |
+
if Global.ui_dev_mode:
|
66 |
+
return
|
67 |
+
|
68 |
if Global.loaded_tokenizer is None:
|
69 |
Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
|
70 |
+
Global.base_model
|
71 |
+
)
|
72 |
if Global.loaded_base_model is None:
|
73 |
if device == "cuda":
|
74 |
Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
|
75 |
Global.base_model,
|
76 |
load_in_8bit=Global.load_8bit,
|
77 |
torch_dtype=torch.float16,
|
78 |
+
# device_map="auto",
|
79 |
+
device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
|
80 |
)
|
81 |
elif device == "mps":
|
82 |
Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
|
|
|
88 |
model = LlamaForCausalLM.from_pretrained(
|
89 |
base_model, device_map={"": device}, low_cpu_mem_usage=True
|
90 |
)
|
91 |
+
|
92 |
+
|
93 |
+
def unload_models():
|
94 |
+
del Global.loaded_base_model
|
95 |
+
Global.loaded_base_model = None
|
96 |
+
|
97 |
+
del Global.loaded_tokenizer
|
98 |
+
Global.loaded_tokenizer = None
|
99 |
+
|
100 |
+
gc.collect()
|
101 |
+
|
102 |
+
# if not shared.args.cpu: # will not be running on CPUs anyway
|
103 |
+
with torch.no_grad():
|
104 |
+
torch.cuda.empty_cache()
|
105 |
+
|
106 |
+
Global.model_has_been_used = False
|
107 |
+
|
108 |
+
|
109 |
+
def unload_models_if_already_used():
|
110 |
+
if Global.model_has_been_used:
|
111 |
+
unload_models()
|
llama_lora/ui/finetune_ui.py
CHANGED
@@ -9,7 +9,7 @@ from random_word import RandomWords
|
|
9 |
from transformers import TrainerCallback
|
10 |
|
11 |
from ..globals import Global
|
12 |
-
from ..models import get_base_model, get_tokenizer
|
13 |
from ..utils.data import (
|
14 |
get_available_template_names,
|
15 |
get_available_dataset_names,
|
@@ -353,6 +353,11 @@ Train data (first 10):
|
|
353 |
|
354 |
training_callbacks = [UiTrainerCallback]
|
355 |
|
|
|
|
|
|
|
|
|
|
|
356 |
Global.should_stop_training = False
|
357 |
|
358 |
return Global.train_fn(
|
|
|
9 |
from transformers import TrainerCallback
|
10 |
|
11 |
from ..globals import Global
|
12 |
+
from ..models import get_base_model, get_tokenizer, unload_models_if_already_used
|
13 |
from ..utils.data import (
|
14 |
get_available_template_names,
|
15 |
get_available_dataset_names,
|
|
|
353 |
|
354 |
training_callbacks = [UiTrainerCallback]
|
355 |
|
356 |
+
# If model has been used in inference, we need to unload it first.
|
357 |
+
# Otherwise, we'll get a 'Function MmBackward0 returned an invalid
|
358 |
+
# gradient at index 1 - expected device meta but got cuda:0' error.
|
359 |
+
unload_models_if_already_used()
|
360 |
+
|
361 |
Global.should_stop_training = False
|
362 |
|
363 |
return Global.train_fn(
|
llama_lora/ui/inference_ui.py
CHANGED
@@ -26,6 +26,7 @@ def inference(
|
|
26 |
repetition_penalty=1.2,
|
27 |
max_new_tokens=128,
|
28 |
stream_output=False,
|
|
|
29 |
**kwargs,
|
30 |
):
|
31 |
variables = [variable_0, variable_1, variable_2, variable_3,
|
|
|
26 |
repetition_penalty=1.2,
|
27 |
max_new_tokens=128,
|
28 |
stream_output=False,
|
29 |
+
progress=gr.Progress(track_tqdm=True),
|
30 |
**kwargs,
|
31 |
):
|
32 |
variables = [variable_0, variable_1, variable_2, variable_3,
|