zetavg committed
Commit bcc3066
Parent: f69a138

finetune: add load_in_8bit and fp16 options

llama_lora/lib/finetune.py CHANGED
@@ -28,6 +28,8 @@ def train(
     tokenizer: Any,
     output_dir: str,
     train_data: List[Any],
+    load_in_8bit=True,
+    fp16=True,
     # training hyperparams
     micro_batch_size: int = 4,
     gradient_accumulation_steps: int = 32,
@@ -79,6 +81,8 @@ def train(
         'lora_target_modules': lora_target_modules,
         'train_on_inputs': train_on_inputs,
         'group_by_length': group_by_length,
+        'load_in_8bit': load_in_8bit,
+        'fp16': fp16,
         'save_steps': save_steps,
         'save_total_limit': save_total_limit,
         'logging_steps': logging_steps,
@@ -140,7 +144,7 @@ def train(
     model_name = model
     model = AutoModelForCausalLM.from_pretrained(
         base_model,
-        load_in_8bit=True,
+        load_in_8bit=load_in_8bit,
         torch_dtype=torch.float16,
         llm_int8_skip_modules=lora_modules_to_save,
         device_map=device_map,
@@ -289,7 +293,7 @@ def train(
         warmup_steps=100,
         num_train_epochs=num_train_epochs,
         learning_rate=learning_rate,
-        # fp16=True,
+        fp16=fp16,
         logging_steps=logging_steps,
         optim="adamw_torch",
         evaluation_strategy="steps" if val_set_size > 0 else "no",
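
The net effect of the change above is that the 8-bit loading and mixed-precision flags are no longer hard-coded: `load_in_8bit` is forwarded to `AutoModelForCausalLM.from_pretrained` and `fp16` to the Hugging Face `TrainingArguments`. The following is a minimal, self-contained sketch of that forwarding pattern, not the repository's actual code; the helper names `load_model_for_finetuning` and `build_training_args` are hypothetical, and the `llm_int8_skip_modules` argument shown in the diff is omitted for brevity.

```python
import torch
import transformers
from transformers import AutoModelForCausalLM


def load_model_for_finetuning(base_model, load_in_8bit=True, device_map="auto"):
    # load_in_8bit is no longer hard-coded to True: callers can opt out of
    # bitsandbytes int8 loading (e.g. on hardware without int8 support).
    return AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=load_in_8bit,
        torch_dtype=torch.float16,
        device_map=device_map,
    )


def build_training_args(output_dir, fp16=True, **kwargs):
    # fp16 was previously commented out in TrainingArguments; it is now an
    # explicit, caller-controlled option.
    return transformers.TrainingArguments(
        output_dir=output_dir,
        fp16=fp16,
        optim="adamw_torch",
        **kwargs,
    )
```

With this change, a call such as `train(..., load_in_8bit=False, fp16=False)` would, for example, load the base model without int8 quantization and run training without mixed precision.
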
llama_lora/ui/finetune_ui.py CHANGED
@@ -297,6 +297,8 @@ def do_train(
     lora_dropout,
     lora_target_modules,
     lora_modules_to_save,
+    load_in_8bit,
+    fp16,
     save_steps,
     save_total_limit,
     logging_steps,
@@ -389,6 +391,8 @@ Train options: {json.dumps({
     'lora_dropout': lora_dropout,
     'lora_target_modules': lora_target_modules,
     'lora_modules_to_save': lora_modules_to_save,
+    'load_in_8bit': load_in_8bit,
+    'fp16': fp16,
     'model_name': model_name,
     'continue_from_model': continue_from_model,
     'continue_from_checkpoint': continue_from_checkpoint,
@@ -526,6 +530,8 @@ Train data (first 10):
         lora_target_modules=lora_target_modules,
         lora_modules_to_save=lora_modules_to_save,
         train_on_inputs=train_on_inputs,
+        load_in_8bit=load_in_8bit,
+        fp16=fp16,
         group_by_length=False,
         resume_from_checkpoint=resume_from_checkpoint,
         save_steps=save_steps,
@@ -589,6 +595,8 @@ def handle_load_params_from_model(
     lora_dropout,
     lora_target_modules,
     lora_modules_to_save,
+    load_in_8bit,
+    fp16,
     save_steps,
     save_total_limit,
     logging_steps,
@@ -650,6 +658,10 @@ def handle_load_params_from_model(
             for element in value:
                 if element not in lora_modules_to_save_choices:
                     lora_modules_to_save_choices.append(element)
+        elif key == "load_in_8bit":
+            load_in_8bit = value
+        elif key == "fp16":
+            fp16 = value
         elif key == "save_steps":
             save_steps = value
         elif key == "save_total_limit":
@@ -691,6 +703,8 @@ def handle_load_params_from_model(
             choices=lora_target_module_choices),
         gr.CheckboxGroup.update(
             value=lora_modules_to_save, choices=lora_modules_to_save_choices),
+        load_in_8bit,
+        fp16,
         save_steps,
         save_total_limit,
         logging_steps,
@@ -934,6 +948,11 @@ def finetune_ui():
                 )
             )
 
+            with gr.Accordion("Advanced Options", open=False, elem_id="finetune_advance_options_accordion"):
+                with gr.Row():
+                    load_in_8bit = gr.Checkbox(label="8bit", value=True)
+                    fp16 = gr.Checkbox(label="FP16", value=True)
+
             with gr.Column():
                 lora_r = gr.Slider(
                     minimum=1, maximum=16, step=1, value=8,
@@ -1102,6 +1121,8 @@ def finetune_ui():
         lora_dropout,
         lora_target_modules,
         lora_modules_to_save,
+        load_in_8bit,
+        fp16,
         save_steps,
         save_total_limit,
         logging_steps,
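
On the UI side, the commit adds two checkboxes in a collapsed "Advanced Options" accordion and threads them through the handlers' parameter lists and the Gradio `inputs` lists. Below is a minimal, self-contained sketch (not the repository's actual UI) of that wiring pattern; `do_train_stub` is a hypothetical stand-in for `do_train`.

```python
import gradio as gr


def do_train_stub(load_in_8bit, fp16):
    # Stand-in for do_train(); the real handler receives many more inputs.
    return f"load_in_8bit={load_in_8bit}, fp16={fp16}"


with gr.Blocks() as demo:
    # Checkboxes live inside a collapsed accordion, as in the diff above.
    with gr.Accordion("Advanced Options", open=False):
        with gr.Row():
            load_in_8bit = gr.Checkbox(label="8bit", value=True)
            fp16 = gr.Checkbox(label="FP16", value=True)
    train_btn = gr.Button("Train")
    status = gr.Textbox(label="Status")
    # The order of `inputs` must match the handler's parameter order, which
    # is why the diff inserts load_in_8bit/fp16 at the same position in
    # do_train(), handle_load_params_from_model(), and the inputs lists.
    train_btn.click(do_train_stub, inputs=[load_in_8bit, fp16], outputs=status)

if __name__ == "__main__":
    demo.launch()
```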