zetavg committed on
Commit 5fcf47b
1 Parent(s): 0fd8c50

ui for continue fine-tuning from existing model
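For context, the "Continue from Model" flow added below assumes each saved LoRA output directory may contain checkpoint-* subfolders plus a finetune_params.json (or finetune_args.json) file whose keys mirror the training options. A minimal sketch of producing such a file follows; the save_finetune_params helper, the directory path, and the concrete values are illustrative assumptions, not part of this commit — only the key names are taken from the diff.

import json
import os


def save_finetune_params(lora_model_directory_path, params):
    # Hypothetical helper: persists the training options that the new
    # handle_load_params_from_model function (added in this commit) can restore.
    os.makedirs(lora_model_directory_path, exist_ok=True)
    with open(os.path.join(lora_model_directory_path, "finetune_params.json"), "w") as f:
        json.dump(params, f, indent=2)


# Example values; the keys match those read back in finetune_ui.py below.
save_finetune_params("data/lora_models/my-lora-model", {
    "max_seq_length": 512,
    "evaluate_data_count": 0,
    "micro_batch_size": 8,
    "gradient_accumulation_steps": 4,
    "epochs": 3,
    "learning_rate": 3e-4,
    "train_on_inputs": True,
    "lora_r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "lora_target_modules": ["q_proj", "v_proj"],
    "save_steps": 500,
    "save_total_limit": 5,
    "logging_steps": 10,
})

Selecting such a directory in the new dropdown then lists its checkpoint-* folders and shows the "Load training parameters from selected model" button.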

llama_lora/ui/finetune_ui.py CHANGED
@@ -17,7 +17,8 @@ from ..models import (
 from ..utils.data import (
     get_available_template_names,
     get_available_dataset_names,
-    get_dataset_content
+    get_dataset_content,
+    get_available_lora_model_names
 )
 from ..utils.prompter import Prompter
 
@@ -49,13 +50,16 @@ def reload_selections(current_template, current_dataset):
     current_dataset = current_dataset or next(
         iter(available_dataset_names), None)
 
+    available_lora_models = ["-"] + get_available_lora_model_names()
+
     return (
         gr.Dropdown.update(
             choices=available_template_names_with_none,
             value=current_template),
         gr.Dropdown.update(
             choices=available_dataset_names,
-            value=current_dataset)
+            value=current_dataset),
+        gr.Dropdown.update(choices=available_lora_models)
     )
 
 
@@ -228,7 +232,7 @@ def refresh_dataset_items_count(
         info_message = "This dataset contains " + info_message
         update_message = gr.Markdown.update(info_message, visible=True)
 
-        return gr.Markdown.update(preview_info_message), update_message, update_message
+        return gr.Markdown.update(preview_info_message), update_message, update_message, gr.Slider.update(maximum=math.floor(data_count / 2))
     except Exception as e:
         update_message = gr.Markdown.update(
             f"<span class=\"finetune_dataset_error_message\">Error: {e}.</span>", visible=True)
@@ -236,13 +240,14 @@
         trace = traceback.format_exc()
         traces = [s.strip() for s in re.split("\n * File ", trace)]
         templates_path = os.path.join(Global.data_dir, "templates")
-        traces_to_show = [s for s in traces if os.path.join(Global.data_dir, "templates") in s]
+        traces_to_show = [s for s in traces if os.path.join(
+            Global.data_dir, "templates") in s]
         traces_to_show = [re.sub(" *\n *", ": ", s) for s in traces_to_show]
         if len(traces_to_show) > 0:
             update_message = gr.Markdown.update(
                 f"<span class=\"finetune_dataset_error_message\">Error: {e} ({','.join(traces_to_show)}).</span>", visible=True)
 
-    return gr.Markdown.update("Set the dataset in the \"Prepare\" tab, then preview it here."), update_message, update_message
+    return gr.Markdown.update("Set the dataset in the \"Prepare\" tab, then preview it here."), update_message, update_message, gr.Slider.update(maximum=1)
 
 
 def parse_plain_text_input(
@@ -281,7 +286,7 @@ def do_train(
     dataset_plain_text_data_separator,
     # Training Options
     max_seq_length,
-    evaluate_data_percentage,
+    evaluate_data_count,
     micro_batch_size,
     gradient_accumulation_steps,
     epochs,
@@ -291,10 +296,12 @@
     lora_alpha,
     lora_dropout,
     lora_target_modules,
-    model_name,
     save_steps,
     save_total_limit,
     logging_steps,
+    model_name,
+    continue_from_model,
+    continue_from_checkpoint,
     progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
 ):
     try:
@@ -327,7 +334,6 @@ def do_train(
         train_data = prompter.get_train_data_from_dataset(data)
 
         data_count = len(train_data)
-        evaluate_data_count = math.ceil(data_count * evaluate_data_percentage)
 
         def get_progress_text(epoch, epochs, last_loss):
             progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
@@ -448,7 +454,7 @@ Train data (first 10):
         # 'epochs': epochs,
         # 'learning_rate': learning_rate,
 
-        # 'evaluate_data_percentage': evaluate_data_percentage,
+        # 'evaluate_data_count': evaluate_data_count,
 
         # 'lora_r': lora_r,
         # 'lora_alpha': lora_alpha,
@@ -517,6 +523,127 @@ def do_abort_training():
     Global.should_stop_training = True
 
 
+def handle_continue_from_model_change(model_name):
+    try:
+        lora_models_directory_path = os.path.join(
+            Global.data_dir, "lora_models")
+        lora_model_directory_path = os.path.join(
+            lora_models_directory_path, model_name)
+        all_files = os.listdir(lora_model_directory_path)
+        checkpoints = [
+            file for file in all_files if file.startswith("checkpoint-")]
+        checkpoints = ["-"] + checkpoints
+        can_load_params = "finetune_params.json" in all_files or "finetune_args.json" in all_files
+        return gr.Dropdown.update(choices=checkpoints, value="-"), gr.Button.update(visible=can_load_params), gr.Markdown.update(value="", visible=False)
+    except Exception:
+        pass
+    return gr.Dropdown.update(choices=["-"], value="-"), gr.Button.update(visible=False), gr.Markdown.update(value="", visible=False)
+
+
+def handle_load_params_from_model(
+    model_name,
+    max_seq_length,
+    evaluate_data_count,
+    micro_batch_size,
+    gradient_accumulation_steps,
+    epochs,
+    learning_rate,
+    train_on_inputs,
+    lora_r,
+    lora_alpha,
+    lora_dropout,
+    lora_target_modules,
+    save_steps,
+    save_total_limit,
+    logging_steps,
+):
+    error_message = ""
+    notice_message = ""
+    unknown_keys = []
+    try:
+        lora_models_directory_path = os.path.join(
+            Global.data_dir, "lora_models")
+        lora_model_directory_path = os.path.join(
+            lora_models_directory_path, model_name)
+
+        data = {}
+        possible_files = ["finetune_params.json", "finetune_args.json"]
+        for file in possible_files:
+            try:
+                with open(os.path.join(lora_model_directory_path, file), "r") as f:
+                    data = json.load(f)
+            except FileNotFoundError:
+                pass
+
+        for key, value in data.items():
+            if key == "max_seq_length":
+                max_seq_length = value
+            if key == "cutoff_len":
+                cutoff_len = value
+            elif key == "evaluate_data_count":
+                evaluate_data_count = value
+            elif key == "micro_batch_size":
+                micro_batch_size = value
+            elif key == "gradient_accumulation_steps":
+                gradient_accumulation_steps = value
+            elif key == "epochs":
+                epochs = value
+            elif key == "num_train_epochs":
+                epochs = value
+            elif key == "learning_rate":
+                learning_rate = value
+            elif key == "train_on_inputs":
+                train_on_inputs = value
+            elif key == "lora_r":
+                lora_r = value
+            elif key == "lora_alpha":
+                lora_alpha = value
+            elif key == "lora_dropout":
+                lora_dropout = value
+            elif key == "lora_target_modules":
+                lora_target_modules = value
+            elif key == "save_steps":
+                save_steps = value
+            elif key == "save_total_limit":
+                save_total_limit = value
+            elif key == "logging_steps":
+                logging_steps = value
+            elif key == "group_by_length":
+                pass
+            else:
+                unknown_keys.append(key)
+    except Exception as e:
+        error_message = str(e)
+
+    if len(unknown_keys) > 0:
+        notice_message = f"Note: cannot restore unknown arg: {', '.join([f'`{x}`' for x in unknown_keys])}"
+
+    message = ". ".join([x for x in [error_message, notice_message] if x])
+
+    has_message = False
+    if message:
+        message += "."
+        has_message = True
+
+    return (
+        gr.Markdown.update(value=message, visible=has_message),
+        max_seq_length,
+        evaluate_data_count,
+        micro_batch_size,
+        gradient_accumulation_steps,
+        epochs,
+        learning_rate,
+        train_on_inputs,
+        lora_r,
+        lora_alpha,
+        lora_dropout,
+        lora_target_modules,
+        save_steps,
+        save_total_limit,
+        logging_steps,
+    )
+
+
 def finetune_ui():
     things_that_might_timeout = []
 
@@ -640,33 +767,6 @@ def finetune_ui():
             ]
             dataset_preview_inputs = dataset_inputs + \
                 [finetune_dataset_preview_count]
-            for i in dataset_preview_inputs:
-                things_that_might_timeout.append(
-                    i.change(
-                        fn=refresh_preview,
-                        inputs=dataset_preview_inputs,
-                        outputs=[
-                            finetune_dataset_preview,
-                            finetune_dataset_preview_info_message,
-                            dataset_from_text_message,
-                            dataset_from_data_dir_message
-                        ]
-                    ).then(
-                        fn=refresh_dataset_items_count,
-                        inputs=dataset_preview_inputs,
-                        outputs=[
-                            finetune_dataset_preview_info_message,
-                            dataset_from_text_message,
-                            dataset_from_data_dir_message
-                        ]
-                    ))
-
-            things_that_might_timeout.append(reload_selections_button.click(
-                reload_selections,
-                inputs=[template, dataset_from_data_dir],
-                outputs=[template, dataset_from_data_dir],
-            )
-            )
 
             with gr.Row():
                 max_seq_length = gr.Slider(
@@ -719,12 +819,43 @@ def finetune_ui():
                     info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
                 )
 
-                evaluate_data_percentage = gr.Slider(
-                    minimum=0, maximum=0.5, step=0.001, value=0,
-                    label="Evaluation Data Percentage",
-                    info="The percentage of data to be used for evaluation. This percentage of data will not be used for training and will be used to assess the performance of the model during the process."
+                evaluate_data_count = gr.Slider(
+                    minimum=0, maximum=1, step=1, value=0,
+                    label="Evaluation Data Count",
+                    info="The number of data to be used for evaluation. This amount of data will not be used for training and will be used to assess the performance of the model during the process."
                 )
 
+                with gr.Box(elem_id="finetune_continue_from_model_box"):
+                    with gr.Row():
+                        continue_from_model = gr.Dropdown(
+                            value="-",
+                            label="Continue from Model",
+                            choices=["-"],
+                            elem_id="finetune_continue_from_model"
+                        )
+                        continue_from_checkpoint = gr.Dropdown(
+                            value="-", label="Checkpoint", choices=["-"])
+                        with gr.Column():
+                            load_params_from_model_btn = gr.Button(
+                                "Load training parameters from selected model", visible=False)
+                            load_params_from_model_btn.style(
+                                full_width=False,
+                                size="sm")
+                            load_params_from_model_message = gr.Markdown(
+                                "", visible=False)
+
+                    things_that_might_timeout.append(
+                        continue_from_model.change(
+                            fn=handle_continue_from_model_change,
+                            inputs=[continue_from_model],
+                            outputs=[
+                                continue_from_checkpoint,
+                                load_params_from_model_btn,
+                                load_params_from_model_message
+                            ]
+                        )
+                    )
+
             with gr.Column():
                 lora_r = gr.Slider(
                     minimum=1, maximum=16, step=1, value=8,
@@ -793,6 +924,59 @@ def finetune_ui():
                     elem_id="finetune_confirm_stop_btn"
                 )
 
+        things_that_might_timeout.append(reload_selections_button.click(
+            reload_selections,
+            inputs=[template, dataset_from_data_dir],
+            outputs=[template, dataset_from_data_dir, continue_from_model],
+        ))
+
+        for i in dataset_preview_inputs:
+            things_that_might_timeout.append(
+                i.change(
+                    fn=refresh_preview,
+                    inputs=dataset_preview_inputs,
+                    outputs=[
+                        finetune_dataset_preview,
+                        finetune_dataset_preview_info_message,
+                        dataset_from_text_message,
+                        dataset_from_data_dir_message
+                    ]
+                ).then(
+                    fn=refresh_dataset_items_count,
+                    inputs=dataset_preview_inputs,
+                    outputs=[
+                        finetune_dataset_preview_info_message,
+                        dataset_from_text_message,
+                        dataset_from_data_dir_message,
+                        evaluate_data_count,
+                    ]
+                ))
+
+        finetune_args = [
+            max_seq_length,
+            evaluate_data_count,
+            micro_batch_size,
+            gradient_accumulation_steps,
+            epochs,
+            learning_rate,
+            train_on_inputs,
+            lora_r,
+            lora_alpha,
+            lora_dropout,
+            lora_target_modules,
+            save_steps,
+            save_total_limit,
+            logging_steps,
+        ]
+
+        things_that_might_timeout.append(
+            load_params_from_model_btn.click(
+                fn=handle_load_params_from_model,
+                inputs=[continue_from_model] + finetune_args,
+                outputs=[load_params_from_model_message] + finetune_args
+            )
+        )
+
         train_output = gr.Text(
             "Training results will be shown here.",
             label="Train Output",
@@ -800,22 +984,10 @@ def finetune_ui():
 
         train_progress = train_btn.click(
            fn=do_train,
-            inputs=(dataset_inputs + [
-                max_seq_length,
-                evaluate_data_percentage,
-                micro_batch_size,
-                gradient_accumulation_steps,
-                epochs,
-                learning_rate,
-                train_on_inputs,
-                lora_r,
-                lora_alpha,
-                lora_dropout,
-                lora_target_modules,
+            inputs=(dataset_inputs + finetune_args + [
                 model_name,
-                save_steps,
-                save_total_limit,
-                logging_steps,
+                continue_from_model,
+                continue_from_checkpoint,
             ]),
             outputs=train_output
         )
llama_lora/ui/main_page.py CHANGED
@@ -515,6 +515,24 @@ def main_page_custom_css():
     margin: -32px -16px;
   }
 
+  #finetune_continue_from_model_box {
+    /* padding: 0; */
+  }
+  #finetune_continue_from_model_box .block {
+    border: 0;
+    box-shadow: none;
+    padding: 0;
+  }
+  #finetune_continue_from_model_box > * {
+    /* gap: 0; */
+  }
+  #finetune_continue_from_model_box button {
+    margin-top: 16px;
+  }
+  #finetune_continue_from_model {
+    flex-grow: 2;
+  }
+
   .finetune_dataset_error_message {
     color: var(--error-text-color) !important;
   }
llama_lora/utils/data.py CHANGED
@@ -30,19 +30,22 @@ def copy_sample_data_if_not_exists(source, destination):
 def get_available_template_names():
     templates_directory_path = os.path.join(Global.data_dir, "templates")
     all_files = os.listdir(templates_directory_path)
-    return [filename.rstrip(".json") for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.py")]
+    names = [filename.rstrip(".json") for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.py")]
+    return sorted(names)
 
 
 def get_available_dataset_names():
     datasets_directory_path = os.path.join(Global.data_dir, "datasets")
     all_files = os.listdir(datasets_directory_path)
-    return [filename for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.jsonl")]
+    names = [filename for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.jsonl")]
+    return sorted(names)
 
 
 def get_available_lora_model_names():
-    datasets_directory_path = os.path.join(Global.data_dir, "lora_models")
-    all_items = os.listdir(datasets_directory_path)
-    return [item for item in all_items if os.path.isdir(os.path.join(datasets_directory_path, item))]
+    lora_models_directory_path = os.path.join(Global.data_dir, "lora_models")
+    all_items = os.listdir(lora_models_directory_path)
+    names = [item for item in all_items if os.path.isdir(os.path.join(lora_models_directory_path, item))]
+    return sorted(names)
 
 
 def get_path_of_available_lora_model(name):