jbilcke-hf HF Staff commited on
Commit
5a793ee
·
1 Parent(s): aeb51a1

small fix (not tested yet)

Browse files
vms/ui/project/tabs/manage_tab.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
6
  import logging
7
  import shutil
8
  from pathlib import Path
9
- from typing import Dict, Any, List, Optional
10
  from gradio_modal import Modal
11
 
12
  from vms.utils import BaseTab, validate_model_repo
@@ -51,6 +51,17 @@ class ManageTab(BaseTab):
51
  """Update the download button text"""
52
  return gr.update(value=self.get_download_button_text())
53
 
 
 
 
 
 
 
 
 
 
 
 
54
  def download_and_update_button(self):
55
  """Handle download and return updated button with current text"""
56
  # Get the safetensors path for download
 
6
  import logging
7
  import shutil
8
  from pathlib import Path
9
+ from typing import Dict, Any, List, Optional, Tuple
10
  from gradio_modal import Modal
11
 
12
  from vms.utils import BaseTab, validate_model_repo
 
51
  """Update the download button text"""
52
  return gr.update(value=self.get_download_button_text())
53
 
54
+ def update_checkpoint_button_text(self) -> gr.update:
55
+ """Update the checkpoint button text"""
56
+ return gr.update(value=self.get_checkpoint_button_text())
57
+
58
+ def update_both_download_buttons(self) -> Tuple[gr.update, gr.update]:
59
+ """Update both download button texts"""
60
+ return (
61
+ gr.update(value=self.get_download_button_text()),
62
+ gr.update(value=self.get_checkpoint_button_text())
63
+ )
64
+
65
  def download_and_update_button(self):
66
  """Handle download and return updated button with current text"""
67
  # Get the safetensors path for download
vms/ui/project/tabs/train_tab.py CHANGED
@@ -341,9 +341,12 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
341
  ## ⚗️ Train your model on your dataset
342
  - **🚀 Start new training**: Begins training from scratch (clears previous checkpoints)
343
  - **🛸 Start from latest checkpoint**: Continues training from the most recent checkpoint
344
- - **🔄 Start over using latest LoRA weights**: Start fresh training but use existing LoRA weights as initialization
345
  """)
346
-
 
 
 
 
347
  with gr.Row():
348
  # Check for existing checkpoints to determine button text
349
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
@@ -485,11 +488,18 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
485
  self.app.training.append_log("Cleared previous checkpoints for new training session")
486
 
487
  # Start training normally
488
- return self.handle_training_start(
489
  model_type, model_version, training_type,
490
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
491
  save_iterations, repo_id, progress
492
  )
 
 
 
 
 
 
 
493
 
494
  def handle_resume_training(
495
  self, model_type, model_version, training_type,
@@ -501,17 +511,27 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
501
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
502
 
503
  if not checkpoints:
504
- return "No checkpoints found to resume from", "Please start a new training session instead"
 
 
 
505
 
506
  self.app.training.append_log(f"Resuming training from latest checkpoint")
507
 
508
  # Start training with the checkpoint
509
- return self.handle_training_start(
510
  model_type, model_version, training_type,
511
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
512
  save_iterations, repo_id, progress,
513
  resume_from_checkpoint="latest"
514
  )
 
 
 
 
 
 
 
515
 
516
  def handle_start_from_lora_training(
517
  self, model_type, model_version, training_type,
@@ -522,22 +542,26 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
522
  # Find the latest LoRA weights
523
  lora_weights_path = self.app.output_path / "lora_weights"
524
 
 
 
 
 
525
  if not lora_weights_path.exists():
526
- return "No LoRA weights found", "Please train a model first or start a new training session"
527
 
528
  # Find the latest LoRA checkpoint directory
529
  lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
530
  key=lambda x: int(x.name), reverse=True)
531
 
532
  if not lora_dirs:
533
- return "No LoRA weight directories found", "Please train a model first or start a new training session"
534
 
535
  latest_lora_dir = lora_dirs[0]
536
 
537
  # Verify the LoRA weights file exists
538
  lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
539
  if not lora_weights_file.exists():
540
- return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory"
541
 
542
  # Clear checkpoints to start fresh (but keep LoRA weights)
543
  for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
@@ -552,11 +576,17 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
552
  self.app.training.append_log(f"Starting training from LoRA weights: {latest_lora_dir}")
553
 
554
  # Start training with the LoRA weights
555
- return self.handle_training_start(
556
  model_type, model_version, training_type,
557
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
558
  save_iterations, repo_id, progress,
559
  )
 
 
 
 
 
 
560
 
561
  def connect_events(self) -> None:
562
  """Connect event handlers to UI components"""
@@ -739,7 +769,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
739
  ],
740
  outputs=[
741
  self.components["status_box"],
742
- self.components["log_box"]
 
 
743
  ]
744
  )
745
 
@@ -759,7 +791,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
759
  ],
760
  outputs=[
761
  self.components["status_box"],
762
- self.components["log_box"]
 
 
763
  ]
764
  )
765
 
@@ -779,7 +813,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
779
  ],
780
  outputs=[
781
  self.components["status_box"],
782
- self.components["log_box"]
 
 
783
  ]
784
  )
785
 
@@ -795,7 +831,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
795
  self.components["current_task_box"],
796
  self.components["start_btn"],
797
  self.components["stop_btn"],
798
- third_btn
 
 
799
  ]
800
  )
801
 
@@ -807,7 +845,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
807
  self.components["current_task_box"],
808
  self.components["start_btn"],
809
  self.components["stop_btn"],
810
- third_btn
 
 
811
  ]
812
  )
813
 
@@ -1200,7 +1240,12 @@ Full finetune mode trains all parameters of the model, requiring more VRAM but p
1200
  variant="stop"
1201
  )
1202
 
1203
- return start_btn, resume_btn, stop_btn, delete_checkpoints_btn
 
 
 
 
 
1204
 
1205
  def update_training_ui(self, training_state: Dict[str, Any]):
1206
  """Update UI components based on training state"""
 
341
  ## ⚗️ Train your model on your dataset
342
  - **🚀 Start new training**: Begins training from scratch (clears previous checkpoints)
343
  - **🛸 Start from latest checkpoint**: Continues training from the most recent checkpoint
 
344
  """)
345
+
346
+ #Finetrainers doesn't support recovery of a training session using a LoRA,
347
+ #so this feature doesn't work, I've disabled the line/documentation:
348
+ #- **🔄 Start over using latest LoRA weights**: Start fresh training but use existing LoRA weights as initialization
349
+
350
  with gr.Row():
351
  # Check for existing checkpoints to determine button text
352
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
 
488
  self.app.training.append_log("Cleared previous checkpoints for new training session")
489
 
490
  # Start training normally
491
+ status, logs = self.handle_training_start(
492
  model_type, model_version, training_type,
493
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
494
  save_iterations, repo_id, progress
495
  )
496
+
497
+ # Update download button texts
498
+ manage_tab = self.app.tabs["manage_tab"]
499
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
500
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
501
+
502
+ return status, logs, download_btn_text, checkpoint_btn_text
503
 
504
  def handle_resume_training(
505
  self, model_type, model_version, training_type,
 
511
  checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
512
 
513
  if not checkpoints:
514
+ manage_tab = self.app.tabs["manage_tab"]
515
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
516
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
517
+ return "No checkpoints found to resume from", "Please start a new training session instead", download_btn_text, checkpoint_btn_text
518
 
519
  self.app.training.append_log(f"Resuming training from latest checkpoint")
520
 
521
  # Start training with the checkpoint
522
+ status, logs = self.handle_training_start(
523
  model_type, model_version, training_type,
524
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
525
  save_iterations, repo_id, progress,
526
  resume_from_checkpoint="latest"
527
  )
528
+
529
+ # Update download button texts
530
+ manage_tab = self.app.tabs["manage_tab"]
531
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
532
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
533
+
534
+ return status, logs, download_btn_text, checkpoint_btn_text
535
 
536
  def handle_start_from_lora_training(
537
  self, model_type, model_version, training_type,
 
542
  # Find the latest LoRA weights
543
  lora_weights_path = self.app.output_path / "lora_weights"
544
 
545
+ manage_tab = self.app.tabs["manage_tab"]
546
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
547
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
548
+
549
  if not lora_weights_path.exists():
550
+ return "No LoRA weights found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
551
 
552
  # Find the latest LoRA checkpoint directory
553
  lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
554
  key=lambda x: int(x.name), reverse=True)
555
 
556
  if not lora_dirs:
557
+ return "No LoRA weight directories found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
558
 
559
  latest_lora_dir = lora_dirs[0]
560
 
561
  # Verify the LoRA weights file exists
562
  lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
563
  if not lora_weights_file.exists():
564
+ return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory", download_btn_text, checkpoint_btn_text
565
 
566
  # Clear checkpoints to start fresh (but keep LoRA weights)
567
  for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
 
576
  self.app.training.append_log(f"Starting training from LoRA weights: {latest_lora_dir}")
577
 
578
  # Start training with the LoRA weights
579
+ status, logs = self.handle_training_start(
580
  model_type, model_version, training_type,
581
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
582
  save_iterations, repo_id, progress,
583
  )
584
+
585
+ # Update download button texts
586
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
587
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
588
+
589
+ return status, logs, download_btn_text, checkpoint_btn_text
590
 
591
  def connect_events(self) -> None:
592
  """Connect event handlers to UI components"""
 
769
  ],
770
  outputs=[
771
  self.components["status_box"],
772
+ self.components["log_box"],
773
+ self.app.tabs["manage_tab"].components["download_model_btn"],
774
+ self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
775
  ]
776
  )
777
 
 
791
  ],
792
  outputs=[
793
  self.components["status_box"],
794
+ self.components["log_box"],
795
+ self.app.tabs["manage_tab"].components["download_model_btn"],
796
+ self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
797
  ]
798
  )
799
 
 
813
  ],
814
  outputs=[
815
  self.components["status_box"],
816
+ self.components["log_box"],
817
+ self.app.tabs["manage_tab"].components["download_model_btn"],
818
+ self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
819
  ]
820
  )
821
 
 
831
  self.components["current_task_box"],
832
  self.components["start_btn"],
833
  self.components["stop_btn"],
834
+ third_btn,
835
+ self.app.tabs["manage_tab"].components["download_model_btn"],
836
+ self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
837
  ]
838
  )
839
 
 
845
  self.components["current_task_box"],
846
  self.components["start_btn"],
847
  self.components["stop_btn"],
848
+ third_btn,
849
+ self.app.tabs["manage_tab"].components["download_model_btn"],
850
+ self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
851
  ]
852
  )
853
 
 
1240
  variant="stop"
1241
  )
1242
 
1243
+ # Update download button texts
1244
+ manage_tab = self.app.tabs["manage_tab"]
1245
+ download_btn_text = gr.update(value=manage_tab.get_download_button_text())
1246
+ checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
1247
+
1248
+ return start_btn, resume_btn, stop_btn, delete_checkpoints_btn, download_btn_text, checkpoint_btn_text
1249
 
1250
  def update_training_ui(self, training_state: Dict[str, Any]):
1251
  """Update UI components based on training state"""