zetavg commited on
Commit
49ce4b9
Β·
unverified Β·
1 Parent(s): 72ff821

work on finetune ui

Browse files
llama_lora/globals.py CHANGED
@@ -14,6 +14,9 @@ class Global:
14
  loaded_tokenizer: Any = None
15
  loaded_base_model: Any = None
16
 
 
 
 
17
  # UI related
18
  ui_title: str = "LLaMA-LoRA"
19
  ui_emoji: str = "πŸ¦™πŸŽ›οΈ"
 
14
  loaded_tokenizer: Any = None
15
  loaded_base_model: Any = None
16
 
17
+ # Functions
18
+ train_fn: Any = None
19
+
20
  # UI related
21
  ui_title: str = "LLaMA-LoRA"
22
  ui_emoji: str = "πŸ¦™πŸŽ›οΈ"
llama_lora/ui/finetune_ui.py CHANGED
@@ -1,8 +1,12 @@
 
1
  import json
2
  import time
 
3
  import gradio as gr
4
  from random_word import RandomWords
5
 
 
 
6
  from ..utils.data import (
7
  get_available_template_names,
8
  get_available_dataset_names,
@@ -10,15 +14,20 @@ from ..utils.data import (
10
  )
11
  from ..utils.prompter import Prompter
12
 
13
- r = RandomWords()
14
-
15
 
16
  def random_hyphenated_word():
 
17
  word1 = r.get_random_word()
18
  word2 = r.get_random_word()
19
  return word1 + '-' + word2
20
 
21
 
 
 
 
 
 
 
22
  def reload_selections(current_template, current_dataset):
23
  available_template_names = get_available_template_names()
24
  available_template_names_with_none = available_template_names + ["None"]
@@ -226,6 +235,127 @@ def parse_plain_text_input(
226
  return result
227
 
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  def finetune_ui():
230
  with gr.Blocks() as finetune_ui_blocks:
231
  with gr.Column(elem_id="finetune_ui_content"):
@@ -356,75 +486,233 @@ def finetune_ui():
356
  outputs=[template, dataset_from_data_dir],
357
  )
358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  finetune_ui_blocks.load(_js="""
360
  function finetune_ui_blocks_js() {
361
  // Auto load options
362
  setTimeout(function () {
363
- document.getElementById("finetune_reload_selections_button").click();
364
  }, 100);
365
 
366
-
367
  // Add tooltips
368
  setTimeout(function () {
369
- tippy("#finetune_reload_selections_button", {
370
  placement: 'bottom-end',
371
  delay: [500, 0],
372
  animation: 'scale-subtle',
373
  content: 'Press to reload options.',
374
  });
375
 
376
- tippy("#finetune_template", {
377
  placement: 'bottom-start',
378
  delay: [500, 0],
379
  animation: 'scale-subtle',
380
- content: 'Select a template for your prompt. <br />To see how the selected template work, select the "Preview" tab and then check "Show actual prompt". <br />Templates are loaded from the "templates" folder of your data directory.',
 
381
  allowHTML: true,
382
  });
383
 
384
- tippy("#finetune_load_dataset_from", {
385
  placement: 'bottom-start',
386
  delay: [500, 0],
387
  animation: 'scale-subtle',
388
- content: '<strong>Text Input</strong>: Paste the dataset directly in the UI.<br/><strong>Data Dir</strong>: Select a dataset in the data directory.',
 
389
  allowHTML: true,
390
  });
391
 
392
- tippy("#finetune_dataset_preview_show_actual_prompt", {
393
  placement: 'bottom-start',
394
  delay: [500, 0],
395
  animation: 'scale-subtle',
396
- content: 'Check to show the prompt that will be feed to the language model.',
 
397
  });
398
 
399
- tippy("#dataset_plain_text_input_variables_separator", {
400
  placement: 'bottom',
401
  delay: [500, 0],
402
  animation: 'scale-subtle',
403
- content: 'Define a separator to separate input variables. Use "\\\\n" for new lines.',
 
404
  });
405
 
406
- tippy("#dataset_plain_text_input_and_output_separator", {
407
  placement: 'bottom',
408
  delay: [500, 0],
409
  animation: 'scale-subtle',
410
- content: 'Define a separator to separate the input (prompt) and the output (completion). Use "\\\\n" for new lines.',
 
411
  });
412
 
413
- tippy("#dataset_plain_text_data_separator", {
414
  placement: 'bottom',
415
  delay: [500, 0],
416
  animation: 'scale-subtle',
417
- content: 'Define a separator to separate different rows of the train data. Use "\\\\n" for new lines.',
 
418
  });
419
 
420
- tippy("#finetune_dataset_text_load_sample_button", {
421
  placement: 'bottom-start',
422
  delay: [500, 0],
423
  animation: 'scale-subtle',
424
- content: 'Press to load a sample dataset of the current selected format into the textbox.',
 
425
  });
426
 
 
 
 
 
 
 
 
427
  }, 100);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  }
429
  """)
430
 
 
1
+ import os
2
  import json
3
  import time
4
+ from datetime import datetime
5
  import gradio as gr
6
  from random_word import RandomWords
7
 
8
+ from ..globals import Global
9
+ from ..models import get_base_model, get_tokenizer
10
  from ..utils.data import (
11
  get_available_template_names,
12
  get_available_dataset_names,
 
14
  )
15
  from ..utils.prompter import Prompter
16
 
 
 
17
 
18
  def random_hyphenated_word():
19
+ r = RandomWords()
20
  word1 = r.get_random_word()
21
  word2 = r.get_random_word()
22
  return word1 + '-' + word2
23
 
24
 
25
+ def random_name():
26
+ current_datetime = datetime.now()
27
+ formatted_datetime = current_datetime.strftime("%Y-%m-%d-%H-%M-%S")
28
+ return f"{random_hyphenated_word()}-{formatted_datetime}"
29
+
30
+
31
  def reload_selections(current_template, current_dataset):
32
  available_template_names = get_available_template_names()
33
  available_template_names_with_none = available_template_names + ["None"]
 
235
  return result
236
 
237
 
238
+ def do_train(
239
+ # Dataset
240
+ template,
241
+ load_dataset_from,
242
+ dataset_from_data_dir,
243
+ dataset_text,
244
+ dataset_text_format,
245
+ dataset_plain_text_input_variables_separator,
246
+ dataset_plain_text_input_and_output_separator,
247
+ dataset_plain_text_data_separator,
248
+ # Training Options
249
+ max_seq_length,
250
+ micro_batch_size,
251
+ gradient_accumulation_steps,
252
+ epochs,
253
+ learning_rate,
254
+ lora_r,
255
+ lora_alpha,
256
+ lora_dropout,
257
+ model_name,
258
+ progress=gr.Progress(track_tqdm=True),
259
+ ):
260
+ try:
261
+ prompter = Prompter(template)
262
+ variable_names = prompter.get_variable_names()
263
+
264
+ if load_dataset_from == "Text Input":
265
+ if dataset_text_format == "JSON":
266
+ data = json.loads(dataset_text)
267
+ data = process_json_dataset(data)
268
+
269
+ elif dataset_text_format == "JSON Lines":
270
+ lines = dataset_text.split('\n')
271
+ data = []
272
+ for i, line in enumerate(lines):
273
+ line_number = i + 1
274
+ try:
275
+ data.append(json.loads(line))
276
+ except Exception as e:
277
+ raise ValueError(
278
+ f"Error parsing JSON on line {line_number}: {e}")
279
+
280
+ data = process_json_dataset(data)
281
+
282
+ else: # Plain Text
283
+ data = parse_plain_text_input(
284
+ dataset_text,
285
+ (
286
+ dataset_plain_text_input_variables_separator or
287
+ default_dataset_plain_text_input_variables_separator
288
+ ).replace("\\n", "\n"),
289
+ (
290
+ dataset_plain_text_input_and_output_separator or
291
+ default_dataset_plain_text_input_and_output_separator
292
+ ).replace("\\n", "\n"),
293
+ (
294
+ dataset_plain_text_data_separator or
295
+ default_dataset_plain_text_data_separator
296
+ ).replace("\\n", "\n"),
297
+ variable_names
298
+ )
299
+
300
+ else: # Load dataset from data directory
301
+ data = get_dataset_content(dataset_from_data_dir)
302
+ data = process_json_dataset(data)
303
+
304
+ data_count = len(data)
305
+
306
+ train_data = [
307
+ {
308
+ 'prompt': prompter.generate_prompt(d['variables']),
309
+ 'completion': d['output']}
310
+ for d in data]
311
+
312
+ if Global.ui_dev_mode:
313
+ message = f"""Currently in UI dev mode, not doing the actual training.
314
+
315
+ Train options: {json.dumps({
316
+ 'max_seq_length': max_seq_length,
317
+ 'micro_batch_size': micro_batch_size,
318
+ 'gradient_accumulation_steps': gradient_accumulation_steps,
319
+ 'epochs': epochs,
320
+ 'learning_rate': learning_rate,
321
+ 'lora_r': lora_r,
322
+ 'lora_alpha': lora_alpha,
323
+ 'lora_dropout': lora_dropout,
324
+ 'model_name': model_name,
325
+ }, indent=2)}
326
+
327
+ Train data (first 10):
328
+ {json.dumps(train_data[:10], indent=2)}
329
+ """
330
+ print(message)
331
+ time.sleep(2)
332
+ return message
333
+
334
+ return Global.train_fn(
335
+ get_base_model(), # base_model
336
+ get_tokenizer(), # tokenizer
337
+ os.path.join(Global.data_dir, "lora_models",
338
+ model_name), # output_dir
339
+ train_data,
340
+ # 128, # batch_size (is not used, use gradient_accumulation_steps instead)
341
+ micro_batch_size, # micro_batch_size
342
+ gradient_accumulation_steps,
343
+ epochs, # num_epochs
344
+ learning_rate, # learning_rate
345
+ max_seq_length, # cutoff_len
346
+ 0, # val_set_size
347
+ lora_r, # lora_r
348
+ lora_alpha, # lora_alpha
349
+ lora_dropout, # lora_dropout
350
+ ["q_proj", "v_proj"], # lora_target_modules
351
+ True, # train_on_inputs
352
+ False, # group_by_length
353
+ None, # resume_from_checkpoint
354
+ )
355
+ except Exception as e:
356
+ raise gr.Error(e)
357
+
358
+
359
  def finetune_ui():
360
  with gr.Blocks() as finetune_ui_blocks:
361
  with gr.Column(elem_id="finetune_ui_content"):
 
486
  outputs=[template, dataset_from_data_dir],
487
  )
488
 
489
+ max_seq_length = gr.Slider(
490
+ minimum=1, maximum=4096, value=512,
491
+ label="Max Sequence Length",
492
+ info="The maximum length of each sample text sequence. Sequences longer than this will be truncated."
493
+ )
494
+
495
+ with gr.Row():
496
+ with gr.Column():
497
+ micro_batch_size = gr.Slider(
498
+ minimum=1, maximum=100, value=1,
499
+ label="Micro Batch Size",
500
+ info="The number of examples in each mini-batch for gradient computation. A smaller micro_batch_size reduces memory usage but may increase training time."
501
+ )
502
+
503
+ gradient_accumulation_steps = gr.Slider(
504
+ minimum=1, maximum=10, value=1,
505
+ label="Gradient Accumulation Steps",
506
+ info="The number of steps to accumulate gradients before updating model parameters. This can be used to simulate a larger effective batch size without increasing memory usage."
507
+ )
508
+
509
+ epochs = gr.Slider(
510
+ minimum=1, maximum=100, value=1,
511
+ label="Epochs",
512
+ info="The number of times to iterate over the entire training dataset. A larger number of epochs may improve model performance but also increase the risk of overfitting.")
513
+
514
+ learning_rate = gr.Slider(
515
+ minimum=0.00001, maximum=0.01, value=3e-4,
516
+ label="Learning Rate",
517
+ info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
518
+ )
519
+
520
+ with gr.Column():
521
+ lora_r = gr.Slider(
522
+ minimum=1, maximum=16, value=8,
523
+ label="LoRA R",
524
+ info="The rank parameter for LoRA, which controls the dimensionality of the rank decomposition matrices. A larger lora_r increases the expressiveness and flexibility of LoRA but also increases the number of trainable parameters and memory usage."
525
+ )
526
+
527
+ lora_alpha = gr.Slider(
528
+ minimum=1, maximum=128, value=16,
529
+ label="LoRA Alpha",
530
+ info="The scaling parameter for LoRA, which controls how much LoRA affects the original pre-trained model weights. A larger lora_alpha amplifies the impact of LoRA but may also distort or override the pre-trained knowledge."
531
+ )
532
+
533
+ lora_dropout = gr.Slider(
534
+ minimum=0, maximum=1, value=0.01,
535
+ label="LoRA Dropout",
536
+ info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
537
+ )
538
+
539
+ with gr.Column():
540
+ model_name = gr.Textbox(
541
+ lines=1, label="LoRA Model Name", value=random_name(),
542
+ elem_id="finetune_model_name",
543
+ )
544
+
545
+ with gr.Row():
546
+ train_btn = gr.Button(
547
+ "Train", variant="primary", label="Train",
548
+ elem_id="finetune_start_btn"
549
+ )
550
+
551
+ abort_button = gr.Button(
552
+ "Abort", label="Abort",
553
+ elem_id="finetune_stop_btn"
554
+ )
555
+ confirm_abort_button = gr.Button(
556
+ "Confirm Abort", label="Confirm Abort", variant="stop",
557
+ elem_id="finetune_confirm_stop_btn"
558
+ )
559
+
560
+ training_status = gr.Text(
561
+ "Training status will be shown here.",
562
+ label="Training Status/Results",
563
+ elem_id="finetune_training_status")
564
+
565
+ train_progress = train_btn.click(
566
+ fn=do_train,
567
+ inputs=(dataset_inputs + [
568
+ max_seq_length,
569
+ micro_batch_size,
570
+ gradient_accumulation_steps,
571
+ epochs,
572
+ learning_rate,
573
+ lora_r,
574
+ lora_alpha,
575
+ lora_dropout,
576
+ model_name
577
+ ]),
578
+ outputs=training_status
579
+ )
580
+
581
+ # controlled by JS, shows the confirm_abort_button
582
+ abort_button.click(None, None, None, None)
583
+ confirm_abort_button.click(None, None, None, cancels=[train_progress])
584
+
585
  finetune_ui_blocks.load(_js="""
586
  function finetune_ui_blocks_js() {
587
  // Auto load options
588
  setTimeout(function () {
589
+ document.getElementById('finetune_reload_selections_button').click();
590
  }, 100);
591
 
 
592
  // Add tooltips
593
  setTimeout(function () {
594
+ tippy('#finetune_reload_selections_button', {
595
  placement: 'bottom-end',
596
  delay: [500, 0],
597
  animation: 'scale-subtle',
598
  content: 'Press to reload options.',
599
  });
600
 
601
+ tippy('#finetune_template', {
602
  placement: 'bottom-start',
603
  delay: [500, 0],
604
  animation: 'scale-subtle',
605
+ content:
606
+ 'Select a template for your prompt. <br />To see how the selected template work, select the "Preview" tab and then check "Show actual prompt". <br />Templates are loaded from the "templates" folder of your data directory.',
607
  allowHTML: true,
608
  });
609
 
610
+ tippy('#finetune_load_dataset_from', {
611
  placement: 'bottom-start',
612
  delay: [500, 0],
613
  animation: 'scale-subtle',
614
+ content:
615
+ '<strong>Text Input</strong>: Paste the dataset directly in the UI.<br/><strong>Data Dir</strong>: Select a dataset in the data directory.',
616
  allowHTML: true,
617
  });
618
 
619
+ tippy('#finetune_dataset_preview_show_actual_prompt', {
620
  placement: 'bottom-start',
621
  delay: [500, 0],
622
  animation: 'scale-subtle',
623
+ content:
624
+ 'Check to show the prompt that will be feed to the language model.',
625
  });
626
 
627
+ tippy('#dataset_plain_text_input_variables_separator', {
628
  placement: 'bottom',
629
  delay: [500, 0],
630
  animation: 'scale-subtle',
631
+ content:
632
+ 'Define a separator to separate input variables. Use "\\\\n" for new lines.',
633
  });
634
 
635
+ tippy('#dataset_plain_text_input_and_output_separator', {
636
  placement: 'bottom',
637
  delay: [500, 0],
638
  animation: 'scale-subtle',
639
+ content:
640
+ 'Define a separator to separate the input (prompt) and the output (completion). Use "\\\\n" for new lines.',
641
  });
642
 
643
+ tippy('#dataset_plain_text_data_separator', {
644
  placement: 'bottom',
645
  delay: [500, 0],
646
  animation: 'scale-subtle',
647
+ content:
648
+ 'Define a separator to separate different rows of the train data. Use "\\\\n" for new lines.',
649
  });
650
 
651
+ tippy('#finetune_dataset_text_load_sample_button', {
652
  placement: 'bottom-start',
653
  delay: [500, 0],
654
  animation: 'scale-subtle',
655
+ content:
656
+ 'Press to load a sample dataset of the current selected format into the textbox.',
657
  });
658
 
659
+ tippy('#finetune_model_name', {
660
+ placement: 'bottom',
661
+ delay: [500, 0],
662
+ animation: 'scale-subtle',
663
+ content:
664
+ 'The name of the new LoRA model. Must be unique.',
665
+ });
666
  }, 100);
667
+
668
+ // Show/hide start and stop button base on the state.
669
+ setTimeout(function () {
670
+ // Make the '#finetune_training_status > .wrap' element appear
671
+ if (!document.querySelector('#finetune_training_status > .wrap')) {
672
+ document.getElementById('finetune_confirm_stop_btn').click();
673
+ }
674
+
675
+ setTimeout(function () {
676
+ let resetStopButtonTimer;
677
+ document
678
+ .getElementById('finetune_stop_btn')
679
+ .addEventListener('click', function () {
680
+ if (resetStopButtonTimer) clearTimeout(resetStopButtonTimer);
681
+ resetStopButtonTimer = setTimeout(function () {
682
+ document.getElementById('finetune_stop_btn').style.display = 'block';
683
+ document.getElementById('finetune_confirm_stop_btn').style.display =
684
+ 'none';
685
+ }, 5000);
686
+ document.getElementById('finetune_stop_btn').style.display = 'none';
687
+ document.getElementById('finetune_confirm_stop_btn').style.display =
688
+ 'block';
689
+ });
690
+ const output_wrap_element = document.querySelector(
691
+ '#finetune_training_status > .wrap'
692
+ );
693
+ function handle_output_wrap_element_class_change() {
694
+ if (Array.from(output_wrap_element.classList).includes('hide')) {
695
+ if (resetStopButtonTimer) clearTimeout(resetStopButtonTimer);
696
+ document.getElementById('finetune_start_btn').style.display = 'block';
697
+ document.getElementById('finetune_stop_btn').style.display = 'none';
698
+ document.getElementById('finetune_confirm_stop_btn').style.display =
699
+ 'none';
700
+ } else {
701
+ document.getElementById('finetune_start_btn').style.display = 'none';
702
+ document.getElementById('finetune_stop_btn').style.display = 'block';
703
+ document.getElementById('finetune_confirm_stop_btn').style.display =
704
+ 'none';
705
+ }
706
+ }
707
+ new MutationObserver(function (mutationsList, observer) {
708
+ handle_output_wrap_element_class_change();
709
+ }).observe(output_wrap_element, {
710
+ attributes: true,
711
+ attributeFilter: ['class'],
712
+ });
713
+ handle_output_wrap_element_class_change();
714
+ }, 500);
715
+ }, 0);
716
  }
717
  """)
718
 
llama_lora/utils/data.py CHANGED
@@ -11,6 +11,7 @@ def init_data_dir():
11
  parent_directory_path = os.path.dirname(current_file_path)
12
  project_dir_path = os.path.abspath(
13
  os.path.join(parent_directory_path, "..", ".."))
 
14
  copy_sample_data_if_not_exists(os.path.join(project_dir_path, "templates"),
15
  os.path.join(Global.data_dir, "templates"))
16
  copy_sample_data_if_not_exists(os.path.join(project_dir_path, "datasets"),
 
11
  parent_directory_path = os.path.dirname(current_file_path)
12
  project_dir_path = os.path.abspath(
13
  os.path.join(parent_directory_path, "..", ".."))
14
+ os.makedirs(os.path.join(Global.data_dir, "lora_models"), exist_ok=True)
15
  copy_sample_data_if_not_exists(os.path.join(project_dir_path, "templates"),
16
  os.path.join(Global.data_dir, "templates"))
17
  copy_sample_data_if_not_exists(os.path.join(project_dir_path, "datasets"),