Commit c8589f9
Parent(s): adc5756

various fixes regarding session recovery

Files changed:
- finetrainers/dataset.py  +1 -1
- vms/services/trainer.py  +134 -132
- vms/tabs/train_tab.py  +55 -9
- vms/ui/video_trainer_ui.py  +30 -5
finetrainers/dataset.py
CHANGED

@@ -32,6 +32,7 @@ from .constants import ( # noqa
     PRECOMPUTED_LATENTS_DIR_NAME,
 )

+logger = get_logger(__name__)

 # Decord is causing us some issues!
 # Let's try to increase file descriptor limits to avoid this error:
@@ -49,7 +50,6 @@ try:
 except Exception as e:
     logger.warning(f"Could not check or update file descriptor limits: {e}")

-logger = get_logger(__name__)


 # TODO(aryan): This needs a refactor with separation of concerns.
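The dataset.py change fixes an order-of-definition bug: `logger` was created after the try/except block that raises the file descriptor limit, so if that block failed at import time, the handler's `logger.warning(...)` would itself raise a NameError (assuming no earlier `logger` binding). A standalone sketch of the corrected ordering, using the stdlib `logging` and `resource` modules (Unix-only) in place of finetrainers' `get_logger`:

import logging
import resource

# Define the logger before any import-time code that might need it.
logger = logging.getLogger(__name__)

try:
    # Raise the soft file-descriptor limit up to the hard limit,
    # e.g. to work around "too many open files" errors from Decord.
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
except Exception as e:
    # With the logger already defined, a failure here degrades to a
    # warning instead of a NameError at import time.
    logger.warning(f"Could not check or update file descriptor limits: {e}")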
vms/services/trainer.py
CHANGED

@@ -637,149 +637,151 @@ class TrainingService:
         return False

     def recover_interrupted_training(self) -> Dict[str, Any]:
+        """Attempt to recover interrupted training
+
+        Returns:
+            Dict with recovery status and UI updates
+        """
+        status = self.get_status()
+        ui_updates = {}
+
+        # Check for any checkpoints, even if status doesn't indicate training
+        checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
+        has_checkpoints = len(checkpoints) > 0
+
+        # If status indicates training but process isn't running, or if we have checkpoints
+        # and no active training process, try to recover
+        if (status.get('status') in ['training', 'paused'] and not self.is_training_running()) or \
+           (has_checkpoints and not self.is_training_running()):

+            logger.info("Detected interrupted training session or existing checkpoints, attempting to recover...")

+            # Get the latest checkpoint
+            last_session = self.load_session()

+            if not last_session:
+                logger.warning("No session data found for recovery, but will check for checkpoints")
+                # Try to create a default session based on UI state if we have checkpoints
                 if has_checkpoints:
+                    ui_state = self.load_ui_state()
+                    # Create a default session using UI state values
+                    last_session = {
+                        "params": {
+                            "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
+                            "lora_rank": ui_state.get("lora_rank", "128"),
+                            "lora_alpha": ui_state.get("lora_alpha", "128"),
+                            "num_epochs": ui_state.get("num_epochs", 70),
+                            "batch_size": ui_state.get("batch_size", 1),
+                            "learning_rate": ui_state.get("learning_rate", 3e-5),
+                            "save_iterations": ui_state.get("save_iterations", 500),
+                            "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
+                            "repo_id": ""  # Default empty repo ID
+                        }
+                    }
+                    logger.info("Created default session from UI state for recovery")
                 else:
                     # Set buttons for no active training
                     ui_updates = {
                         "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
                         "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
+                        "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
                         "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
                     }
+                    return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
+
+            # Find the latest checkpoint if we have checkpoints
+            latest_checkpoint = None
+            checkpoint_step = 0
+
+            if has_checkpoints:
+                latest_checkpoint = max(checkpoints, key=os.path.getmtime)
+                checkpoint_step = int(latest_checkpoint.name.split("-")[1])
+                logger.info(f"Found checkpoint at step {checkpoint_step}")
+            else:
+                logger.warning("No checkpoints found for recovery")
+                # Set buttons for no active training
+                ui_updates = {
+                    "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
+                    "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
+                    "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
+                    "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
+                }
+                return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
+
+            # Extract parameters from the saved session (not current UI state)
+            # This ensures we use the original training parameters
+            params = last_session.get('params', {})
+
+            # Map internal model type back to display name for UI
+            # This is the key fix for the "ltx_video" vs "LTX-Video (LoRA)" mismatch
+            model_type_internal = params.get('model_type')
+            model_type_display = model_type_internal
+
+            # Find the display name that maps to our internal model type
+            for display_name, internal_name in MODEL_TYPES.items():
+                if internal_name == model_type_internal:
+                    model_type_display = display_name
+                    logger.info(f"Mapped internal model type '{model_type_internal}' to display name '{model_type_display}'")
+                    break
+
+            # Add UI updates to restore the training parameters in the UI
+            # This shows the user what values are being used for the resumed training
+            ui_updates.update({
+                "model_type": model_type_display,  # Use the display name for the UI dropdown
+                "lora_rank": params.get('lora_rank', "128"),
+                "lora_alpha": params.get('lora_alpha', "128"),
+                "num_epochs": params.get('num_epochs', 70),
+                "batch_size": params.get('batch_size', 1),
+                "learning_rate": params.get('learning_rate', 3e-5),
+                "save_iterations": params.get('save_iterations', 500),
+                "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
+            })
+
+            # Check if we should auto-recover (immediate restart)
+            auto_recover = True  # Always auto-recover on startup
+
+            if auto_recover:
+                # Rest of the auto-recovery code remains unchanged
+                try:
+                    # Use the internal model_type for the actual training
+                    # But keep model_type_display for the UI
+                    result = self.start_training(
+                        model_type=model_type_internal,
+                        lora_rank=params.get('lora_rank', "128"),
+                        lora_alpha=params.get('lora_alpha', "128"),
+                        num_epochs=params.get('num_epochs', 70),
+                        batch_size=params.get('batch_size', 1),
+                        learning_rate=params.get('learning_rate', 3e-5),
+                        save_iterations=params.get('save_iterations', 500),
+                        repo_id=params.get('repo_id', ''),
+                        preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
+                        resume_from_checkpoint=str(latest_checkpoint)
+                    )
+
+                    # Set buttons for active training
+                    ui_updates.update({
+                        "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training"},
+                        "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
+                        "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
+                        "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
+                    })
+
+                    return {
+                        "status": "recovered",
+                        "message": f"Training resumed from checkpoint {checkpoint_step}",
+                        "result": result,
+                        "ui_updates": ui_updates
+                    }
+                except Exception as e:
+                    logger.error(f"Failed to auto-resume training: {str(e)}")
+                    # Set buttons for manual recovery
+                    ui_updates.update({
+                        "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
+                        "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
+                        "delete_checkpoints_btn": {"interactive": True, "variant": "stop", "value": "Delete All Checkpoints"},
+                        "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
+                    })
+                    return {"status": "error", "message": f"Failed to auto-resume: {str(e)}", "ui_updates": ui_updates}
             else:
                 # Set up UI for manual recovery
                 ui_updates.update({
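The recovery path keys everything off checkpoint directories named checkpoint-<step> under OUTPUT_PATH. A minimal standalone sketch of that discovery logic (the output/ path and helper name here are illustrative, not the app's actual config):

import os
from pathlib import Path

# Illustrative stand-in for vms.config.OUTPUT_PATH.
OUTPUT_PATH = Path("output")

def find_latest_checkpoint(output_path: Path):
    """Return (checkpoint_dir, step) for the newest checkpoint, or None."""
    checkpoints = list(output_path.glob("checkpoint-*"))
    if not checkpoints:
        return None
    # Newest by modification time; the training step is encoded in the
    # directory name, e.g. "checkpoint-500" -> 500.
    latest = max(checkpoints, key=os.path.getmtime)
    step = int(latest.name.split("-")[1])
    return latest, step

Note that selecting by mtime rather than by the parsed step number means a touched or restored older checkpoint could win; the commit keeps the mtime heuristic.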
vms/tabs/train_tab.py
CHANGED

@@ -8,7 +8,7 @@ from typing import Dict, Any, List, Optional, Tuple
 from pathlib import Path

 from .base_tab import BaseTab
-from ..config import TRAINING_PRESETS, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS
+from ..config import TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS
 from ..utils import TrainingLogParser

 logger = logging.getLogger(__name__)
@@ -279,7 +279,7 @@ class TrainTab(BaseTab):
         )

     def handle_training_start(self, preset, model_type, *args):
-        """Handle training start with proper log parser reset"""
+        """Handle training start with proper log parser reset and checkpoint detection"""
         # Safely reset log parser if it exists
         if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
             self.app.log_parser.reset()
@@ -288,12 +288,35 @@
             from ..utils import TrainingLogParser
             self.app.log_parser = TrainingLogParser()

+        # Check for latest checkpoint
+        checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
+        resume_from = None
+
+        if checkpoints:
+            # Find the latest checkpoint
+            latest_checkpoint = max(checkpoints, key=os.path.getmtime)
+            resume_from = str(latest_checkpoint)
+            logger.info(f"Found checkpoint at {resume_from}, will resume training")
+
+        # Convert model_type display name to internal name
+        model_internal_type = MODEL_TYPES.get(model_type)
+
+        if not model_internal_type:
+            logger.error(f"Invalid model type: {model_type}")
+            return f"Error: Invalid model type '{model_type}'", "Model type not recognized"
+
+        # Start training (it will automatically use the checkpoint if provided)
+        try:
+            return self.app.trainer.start_training(
+                model_internal_type,  # Use internal model type
+                *args,
+                preset_name=preset,
+                resume_from_checkpoint=resume_from
+            )
+        except Exception as e:
+            logger.exception("Error starting training")
+            return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
+

     def get_model_info(self, model_type: str) -> str:
         """Get information about the selected model type"""
@@ -455,6 +478,23 @@
         state = self.app.trainer.get_status()
         logs = self.app.trainer.get_logs()

+        # Check if training process died unexpectedly
+        training_died = False
+
+        if state["status"] == "training" and not self.app.trainer.is_training_running():
+            state["status"] = "error"
+            state["message"] = "Training process terminated unexpectedly."
+            training_died = True
+
+            # Look for error in logs
+            error_lines = []
+            for line in logs.splitlines():
+                if "Error:" in line or "Exception:" in line or "Traceback" in line:
+                    error_lines.append(line)
+
+            if error_lines:
+                state["message"] += f"\n\nPossible error: {error_lines[-1]}"
+
         # Ensure log parser is initialized
         if not hasattr(self.app, 'log_parser') or self.app.log_parser is None:
             from ..utils import TrainingLogParser
@@ -462,7 +502,7 @@
             logger.info("Initialized missing log parser")

         # Parse new log lines
-        if logs:
+        if logs and not training_died:
             last_state = None
             for line in logs.splitlines():
                 try:
@@ -480,6 +520,12 @@
         # Parse status for training state
         if "completed" in state["message"].lower():
             state["status"] = "completed"
+        elif "error" in state["message"].lower():
+            state["status"] = "error"
+        elif "failed" in state["message"].lower():
+            state["status"] = "error"
+        elif "stopped" in state["message"].lower():
+            state["status"] = "stopped"

         return (state["status"], state["message"], logs)

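handle_training_start now resumes transparently from the newest checkpoint, and the status handler flags a dead training process instead of reporting stale "training" state. A condensed sketch of that liveness check as a free function (the function name is illustrative; in the app this logic lives inline in the status handler):

def check_training_liveness(state: dict, logs: str, process_running: bool) -> dict:
    """Mark the status as an error when the recorded state says 'training'
    but no training process is actually alive, and surface the most recent
    error-looking log line to the user."""
    if state["status"] == "training" and not process_running:
        state["status"] = "error"
        state["message"] = "Training process terminated unexpectedly."
        error_lines = [line for line in logs.splitlines()
                       if "Error:" in line or "Exception:" in line or "Traceback" in line]
        if error_lines:
            state["message"] += f"\n\nPossible error: {error_lines[-1]}"
    return state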
vms/ui/video_trainer_ui.py
CHANGED

@@ -7,7 +7,7 @@ from typing import Any, Optional, Dict, List, Union, Tuple

 from ..services import TrainingService, CaptioningService, SplittingService, ImportService
 from ..config import (
-    STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
+    STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
     TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
     MODEL_TYPES, SMALL_TRAINING_BUCKETS
 )
@@ -160,7 +160,24 @@ class VideoTrainerUI:

         # If we recovered training parameters from the original session
         ui_state = {}
-        for param in ["model_type", "lora_rank", "lora_alpha", "num_epochs",
+
+        # Handle model_type specifically - could be internal or display name
+        if "model_type" in recovery_ui:
+            model_type_value = recovery_ui["model_type"]
+
+            # If it's an internal name, convert to display name
+            if model_type_value not in MODEL_TYPES:
+                # Find the display name for this internal model type
+                for display_name, internal_name in MODEL_TYPES.items():
+                    if internal_name == model_type_value:
+                        model_type_value = display_name
+                        logger.info(f"Converted internal model type '{recovery_ui['model_type']}' to display name '{model_type_value}'")
+                        break
+
+            ui_state["model_type"] = model_type_value
+
+        # Copy other parameters
+        for param in ["lora_rank", "lora_alpha", "num_epochs",
                       "batch_size", "learning_rate", "save_iterations", "training_preset"]:
             if param in recovery_ui:
                 ui_state[param] = recovery_ui[param]
@@ -175,8 +192,16 @@
         # Load values (potentially with recovery updates applied)
         ui_state = self.load_ui_values()

-        training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
+        # Ensure model_type is a display name, not internal name
         model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
+        if model_type_val not in MODEL_TYPES:
+            # Convert from internal to display name
+            for display_name, internal_name in MODEL_TYPES.items():
+                if internal_name == model_type_val:
+                    model_type_val = display_name
+                    break
+
+        training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
         lora_rank_val = ui_state.get("lora_rank", "128")
         lora_alpha_val = ui_state.get("lora_alpha", "128")
         num_epochs_val = int(ui_state.get("num_epochs", 70))
@@ -190,9 +215,9 @@
             training_dataset,
             start_btn,
             stop_btn,
-            delete_checkpoints_btn,
+            delete_checkpoints_btn,
             training_preset,
-            model_type_val,
+            model_type_val,
             lora_rank_val,
             lora_alpha_val,
             num_epochs_val,
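Both trainer.py and video_trainer_ui.py repeat the same internal-to-display-name conversion for model types. A minimal sketch of that pattern with a hypothetical two-entry MODEL_TYPES mapping (only the "LTX-Video (LoRA)" / "ltx_video" pairing is confirmed by the commit comments; the second entry is invented for illustration):

# Hypothetical stand-in for vms.config.MODEL_TYPES (display name -> internal name).
MODEL_TYPES = {
    "LTX-Video (LoRA)": "ltx_video",          # pairing mentioned in the commit
    "Example Model (LoRA)": "example_model",  # invented entry for illustration
}

def to_display_name(value: str) -> str:
    """Accept either a display name or an internal name; return the display name.

    Mirrors the normalization added in this commit: values already present as
    keys pass through, internal names are reverse-looked-up, and unknown
    values are returned unchanged.
    """
    if value in MODEL_TYPES:
        return value
    for display_name, internal_name in MODEL_TYPES.items():
        if internal_name == value:
            return display_name
    return value

assert to_display_name("ltx_video") == "LTX-Video (LoRA)"
assert to_display_name("LTX-Video (LoRA)") == "LTX-Video (LoRA)"  # round trip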