Commit bcc3066 by zetavg
Parent: f69a138

finetune: add load_in_8bit and fp16 options

Files changed:
- llama_lora/lib/finetune.py (+6, -2)
- llama_lora/ui/finetune_ui.py (+21, -0)
llama_lora/lib/finetune.py CHANGED

@@ -28,6 +28,8 @@ def train(
     tokenizer: Any,
     output_dir: str,
     train_data: List[Any],
+    load_in_8bit=True,
+    fp16=True,
     # training hyperparams
     micro_batch_size: int = 4,
     gradient_accumulation_steps: int = 32,
@@ -79,6 +81,8 @@ def train(
         'lora_target_modules': lora_target_modules,
         'train_on_inputs': train_on_inputs,
         'group_by_length': group_by_length,
+        'load_in_8bit': load_in_8bit,
+        'fp16': fp16,
         'save_steps': save_steps,
         'save_total_limit': save_total_limit,
         'logging_steps': logging_steps,
@@ -140,7 +144,7 @@ def train(
     model_name = model
     model = AutoModelForCausalLM.from_pretrained(
         base_model,
-        load_in_8bit=True,
+        load_in_8bit=load_in_8bit,
         torch_dtype=torch.float16,
         llm_int8_skip_modules=lora_modules_to_save,
         device_map=device_map,
@@ -289,7 +293,7 @@ def train(
         warmup_steps=100,
         num_train_epochs=num_train_epochs,
         learning_rate=learning_rate,
-        fp16=True,
+        fp16=fp16,
         logging_steps=logging_steps,
         optim="adamw_torch",
         evaluation_strategy="steps" if val_set_size > 0 else "no",
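In short, the previously hard-coded int8 loading and fp16 training flags in train() become caller-controlled parameters, with defaults that keep the old behavior. A minimal sketch of how such flags are typically threaded through the transformers API follows; the function name, output path, and hyperparameter values are illustrative assumptions, not the repository's exact code:

import torch
from transformers import AutoModelForCausalLM, TrainingArguments

def load_model_and_args(base_model: str, load_in_8bit: bool = True, fp16: bool = True):
    # load_in_8bit is forwarded to from_pretrained(); with load_in_8bit=False the
    # weights load in plain fp16 instead of int8 (int8 requires bitsandbytes).
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=load_in_8bit,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # fp16 is forwarded to the trainer arguments; disabling it falls back to
    # full-precision training on GPUs without usable fp16 support.
    training_args = TrainingArguments(
        output_dir="./lora_output",  # hypothetical path
        per_device_train_batch_size=4,
        warmup_steps=100,
        learning_rate=3e-4,
        fp16=fp16,
        logging_steps=10,
        optim="adamw_torch",
    )
    return model, training_args

Because both new parameters default to True, existing callers of train() that do not pass them keep the previous 8-bit, fp16 behavior.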
llama_lora/ui/finetune_ui.py CHANGED

@@ -297,6 +297,8 @@ def do_train(
     lora_dropout,
     lora_target_modules,
     lora_modules_to_save,
+    load_in_8bit,
+    fp16,
     save_steps,
     save_total_limit,
     logging_steps,
@@ -389,6 +391,8 @@ Train options: {json.dumps({
     'lora_dropout': lora_dropout,
     'lora_target_modules': lora_target_modules,
     'lora_modules_to_save': lora_modules_to_save,
+    'load_in_8bit': load_in_8bit,
+    'fp16': fp16,
     'model_name': model_name,
     'continue_from_model': continue_from_model,
     'continue_from_checkpoint': continue_from_checkpoint,
@@ -526,6 +530,8 @@ Train data (first 10):
         lora_target_modules=lora_target_modules,
         lora_modules_to_save=lora_modules_to_save,
         train_on_inputs=train_on_inputs,
+        load_in_8bit=load_in_8bit,
+        fp16=fp16,
         group_by_length=False,
         resume_from_checkpoint=resume_from_checkpoint,
         save_steps=save_steps,
@@ -589,6 +595,8 @@ def handle_load_params_from_model(
     lora_dropout,
     lora_target_modules,
     lora_modules_to_save,
+    load_in_8bit,
+    fp16,
     save_steps,
     save_total_limit,
     logging_steps,
@@ -650,6 +658,10 @@ def handle_load_params_from_model(
                 for element in value:
                     if element not in lora_modules_to_save_choices:
                         lora_modules_to_save_choices.append(element)
+            elif key == "load_in_8bit":
+                load_in_8bit = value
+            elif key == "fp16":
+                fp16 = value
             elif key == "save_steps":
                 save_steps = value
             elif key == "save_total_limit":
@@ -691,6 +703,8 @@ def handle_load_params_from_model(
             choices=lora_target_module_choices),
         gr.CheckboxGroup.update(
             value=lora_modules_to_save, choices=lora_modules_to_save_choices),
+        load_in_8bit,
+        fp16,
         save_steps,
         save_total_limit,
         logging_steps,
@@ -934,6 +948,11 @@ def finetune_ui():
                 )
             )

+            with gr.Accordion("Advanced Options", open=False, elem_id="finetune_advance_options_accordion"):
+                with gr.Row():
+                    load_in_8bit = gr.Checkbox(label="8bit", value=True)
+                    fp16 = gr.Checkbox(label="FP16", value=True)
+
             with gr.Column():
                 lora_r = gr.Slider(
                     minimum=1, maximum=16, step=1, value=8,
@@ -1102,6 +1121,8 @@ def finetune_ui():
                     lora_dropout,
                     lora_target_modules,
                     lora_modules_to_save,
+                    load_in_8bit,
+                    fp16,
                     save_steps,
                     save_total_limit,
                     logging_steps,
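On the UI side, the commit exposes the same two flags as checkboxes inside a collapsed "Advanced Options" accordion and appends them to the inputs of the train handler. A stripped-down sketch of that Gradio pattern is below; run_training is a hypothetical stand-in for the app's do_train, and every other control is omitted:

import gradio as gr

def run_training(load_in_8bit: bool, fp16: bool) -> str:
    # In the real app these booleans are forwarded to train(..., load_in_8bit=..., fp16=...);
    # here they are only echoed back.
    return f"load_in_8bit={load_in_8bit}, fp16={fp16}"

with gr.Blocks() as demo:
    # Defaults mirror the commit: both options enabled, hidden behind a closed accordion.
    with gr.Accordion("Advanced Options", open=False):
        with gr.Row():
            load_in_8bit = gr.Checkbox(label="8bit", value=True)
            fp16 = gr.Checkbox(label="FP16", value=True)
    train_btn = gr.Button("Train")
    status = gr.Textbox(label="Status")
    # The two checkboxes are added to the click handler's inputs, matching the
    # extra parameters added to do_train()'s signature.
    train_btn.click(run_training, inputs=[load_in_8bit, fp16], outputs=status)

if __name__ == "__main__":
    demo.launch()

The extra elif branches in handle_load_params_from_model restore saved values for both flags when parameters are loaded back from a previously trained model, and the two values are returned in the same position as the new checkboxes so the UI reflects them.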