jbilcke-hf HF Staff commited on
Commit
2236e6f
·
1 Parent(s): 4f5cf39
Files changed (1) hide show
  1. vms/ui/project/tabs/train_tab.py +6 -4
vms/ui/project/tabs/train_tab.py CHANGED
@@ -580,7 +580,9 @@ class TrainTab(BaseTab):
580
  def handle_training_start(
581
  self, preset, model_type, model_version, training_type,
582
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
583
- save_iterations, repo_id, progress=gr.Progress()
 
 
584
  ):
585
  """Handle training start with proper log parser reset and checkpoint detection"""
586
  # Safely reset log parser if it exists
@@ -594,14 +596,14 @@ class TrainTab(BaseTab):
594
  # Check for latest checkpoint
595
  checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
596
  has_checkpoints = len(checkpoints) > 0
597
- resume_from = None
598
 
599
- if checkpoints:
600
  # Find the latest checkpoint
601
  latest_checkpoint = max(checkpoints, key=os.path.getmtime)
602
  resume_from = str(latest_checkpoint)
603
  logger.info(f"Found checkpoint at {resume_from}, will resume training")
604
-
605
  # Convert model_type display name to internal name
606
  model_internal_type = MODEL_TYPES.get(model_type)
607
 
 
580
  def handle_training_start(
581
  self, preset, model_type, model_version, training_type,
582
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
583
+ save_iterations, repo_id,
584
+ progress=gr.Progress(),
585
+ resume_from_checkpoint=None,
586
  ):
587
  """Handle training start with proper log parser reset and checkpoint detection"""
588
  # Safely reset log parser if it exists
 
596
  # Check for latest checkpoint
597
  checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
598
  has_checkpoints = len(checkpoints) > 0
599
+ resume_from = resume_from_checkpoint # Use the passed parameter
600
 
601
+ if resume_from == "latest" and checkpoints:
602
  # Find the latest checkpoint
603
  latest_checkpoint = max(checkpoints, key=os.path.getmtime)
604
  resume_from = str(latest_checkpoint)
605
  logger.info(f"Found checkpoint at {resume_from}, will resume training")
606
+
607
  # Convert model_type display name to internal name
608
  model_internal_type = MODEL_TYPES.get(model_type)
609