Spaces:
Running
Running
Commit
·
2236e6f
1
Parent(s):
4f5cf39
fix
Browse files
vms/ui/project/tabs/train_tab.py
CHANGED
@@ -580,7 +580,9 @@ class TrainTab(BaseTab):
|
|
580 |
def handle_training_start(
|
581 |
self, preset, model_type, model_version, training_type,
|
582 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
583 |
-
save_iterations, repo_id,
|
|
|
|
|
584 |
):
|
585 |
"""Handle training start with proper log parser reset and checkpoint detection"""
|
586 |
# Safely reset log parser if it exists
|
@@ -594,14 +596,14 @@ class TrainTab(BaseTab):
|
|
594 |
# Check for latest checkpoint
|
595 |
checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
|
596 |
has_checkpoints = len(checkpoints) > 0
|
597 |
-
resume_from =
|
598 |
|
599 |
-
if checkpoints:
|
600 |
# Find the latest checkpoint
|
601 |
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
|
602 |
resume_from = str(latest_checkpoint)
|
603 |
logger.info(f"Found checkpoint at {resume_from}, will resume training")
|
604 |
-
|
605 |
# Convert model_type display name to internal name
|
606 |
model_internal_type = MODEL_TYPES.get(model_type)
|
607 |
|
|
|
580 |
def handle_training_start(
|
581 |
self, preset, model_type, model_version, training_type,
|
582 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
583 |
+
save_iterations, repo_id,
|
584 |
+
progress=gr.Progress(),
|
585 |
+
resume_from_checkpoint=None,
|
586 |
):
|
587 |
"""Handle training start with proper log parser reset and checkpoint detection"""
|
588 |
# Safely reset log parser if it exists
|
|
|
596 |
# Check for latest checkpoint
|
597 |
checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
|
598 |
has_checkpoints = len(checkpoints) > 0
|
599 |
+
resume_from = resume_from_checkpoint # Use the passed parameter
|
600 |
|
601 |
+
if resume_from == "latest" and checkpoints:
|
602 |
# Find the latest checkpoint
|
603 |
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
|
604 |
resume_from = str(latest_checkpoint)
|
605 |
logger.info(f"Found checkpoint at {resume_from}, will resume training")
|
606 |
+
|
607 |
# Convert model_type display name to internal name
|
608 |
model_internal_type = MODEL_TYPES.get(model_type)
|
609 |
|