jbilcke-hf HF Staff committed on
Commit
c8589f9
·
1 Parent(s): adc5756

various fixes regarding session recovery

Browse files
finetrainers/dataset.py CHANGED
@@ -32,6 +32,7 @@ from .constants import ( # noqa
32
  PRECOMPUTED_LATENTS_DIR_NAME,
33
  )
34
 
 
35
 
36
  # Decord is causing us some issues!
37
  # Let's try to increase file descriptor limits to avoid this error:
@@ -49,7 +50,6 @@ try:
49
  except Exception as e:
50
  logger.warning(f"Could not check or update file descriptor limits: {e}")
51
 
52
- logger = get_logger(__name__)
53
 
54
 
55
  # TODO(aryan): This needs a refactor with separation of concerns.
 
32
  PRECOMPUTED_LATENTS_DIR_NAME,
33
  )
34
 
35
+ logger = get_logger(__name__)
36
 
37
  # Decord is causing us some issues!
38
  # Let's try to increase file descriptor limits to avoid this error:
 
50
  except Exception as e:
51
  logger.warning(f"Could not check or update file descriptor limits: {e}")
52
 
 
53
 
54
 
55
  # TODO(aryan): This needs a refactor with separation of concerns.
vms/services/trainer.py CHANGED
@@ -637,149 +637,151 @@ class TrainingService:
637
  return False
638
 
639
  def recover_interrupted_training(self) -> Dict[str, Any]:
640
- """Attempt to recover interrupted training
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641
 
642
- Returns:
643
- Dict with recovery status and UI updates
644
- """
645
- status = self.get_status()
646
- ui_updates = {}
647
 
648
- # Check for any checkpoints, even if status doesn't indicate training
649
- checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
650
- has_checkpoints = len(checkpoints) > 0
651
 
652
- # If status indicates training but process isn't running, or if we have checkpoints
653
- # and no active training process, try to recover
654
- if (status.get('status') in ['training', 'paused'] and not self.is_training_running()) or \
655
- (has_checkpoints and not self.is_training_running()):
656
-
657
- logger.info("Detected interrupted training session or existing checkpoints, attempting to recover...")
658
-
659
- # Get the latest checkpoint
660
- last_session = self.load_session()
661
-
662
- if not last_session:
663
- logger.warning("No session data found for recovery, but will check for checkpoints")
664
- # Try to create a default session based on UI state if we have checkpoints
665
- if has_checkpoints:
666
- ui_state = self.load_ui_state()
667
- # Create a default session using UI state values
668
- last_session = {
669
- "params": {
670
- "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
671
- "lora_rank": ui_state.get("lora_rank", "128"),
672
- "lora_alpha": ui_state.get("lora_alpha", "128"),
673
- "num_epochs": ui_state.get("num_epochs", 70),
674
- "batch_size": ui_state.get("batch_size", 1),
675
- "learning_rate": ui_state.get("learning_rate", 3e-5),
676
- "save_iterations": ui_state.get("save_iterations", 500),
677
- "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
678
- "repo_id": "" # Default empty repo ID
679
- }
680
- }
681
- logger.info("Created default session from UI state for recovery")
682
- else:
683
- # Set buttons for no active training
684
- ui_updates = {
685
- "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
686
- "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
687
- "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
688
- }
689
- return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
690
-
691
- # Find the latest checkpoint if we have checkpoints
692
- latest_checkpoint = None
693
- checkpoint_step = 0
694
-
695
  if has_checkpoints:
696
- latest_checkpoint = max(checkpoints, key=os.path.getmtime)
697
- checkpoint_step = int(latest_checkpoint.name.split("-")[1])
698
- logger.info(f"Found checkpoint at step {checkpoint_step}")
 
 
 
 
 
 
 
 
 
 
 
 
 
699
  else:
700
- logger.warning("No checkpoints found for recovery")
701
  # Set buttons for no active training
702
  ui_updates = {
703
  "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
704
  "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
 
705
  "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
706
  }
707
- return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
708
-
709
- # Extract parameters from the saved session (not current UI state)
710
- # This ensures we use the original training parameters
711
- params = last_session.get('params', {})
712
-
713
- # Add UI updates to restore the training parameters in the UI
714
- # This shows the user what values are being used for the resumed training
715
- ui_updates.update({
716
- "model_type": params.get('model_type', list(MODEL_TYPES.keys())[0]),
717
- "lora_rank": params.get('lora_rank', "128"),
718
- "lora_alpha": params.get('lora_alpha', "128"),
719
- "num_epochs": params.get('num_epochs', 70),
720
- "batch_size": params.get('batch_size', 1),
721
- "learning_rate": params.get('learning_rate', 3e-5),
722
- "save_iterations": params.get('save_iterations', 500),
723
- "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
724
- })
725
-
726
- # Check if we should auto-recover (immediate restart)
727
- auto_recover = True # Always auto-recover on startup
728
-
729
- if auto_recover:
730
- # Attempt to resume training using the ORIGINAL parameters
731
- try:
732
- # Extract required parameters from the session
733
- model_type = params.get('model_type')
734
- lora_rank = params.get('lora_rank')
735
- lora_alpha = params.get('lora_alpha')
736
- num_epochs = params.get('num_epochs')
737
- batch_size = params.get('batch_size')
738
- learning_rate = params.get('learning_rate')
739
- save_iterations = params.get('save_iterations')
740
- repo_id = params.get('repo_id', '')
741
- preset_name = params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
742
-
743
- # Log the recovery attempt
744
- self.append_log(f"Auto-recovering training from checkpoint {checkpoint_step}")
745
- gr.Info(f"Automatically resuming training from checkpoint {checkpoint_step}")
746
-
747
- # Attempt to resume training
748
- result = self.start_training(
749
- model_type=model_type,
750
- lora_rank=lora_rank,
751
- lora_alpha=lora_alpha,
752
- num_epochs=num_epochs,
753
- batch_size=batch_size,
754
- learning_rate=learning_rate,
755
- save_iterations=save_iterations,
756
- repo_id=repo_id,
757
- preset_name=preset_name,
758
- resume_from_checkpoint=str(latest_checkpoint)
759
- )
760
-
761
- # Set buttons for active training
762
- ui_updates.update({
763
- "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training"},
764
- "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
765
- "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
766
- })
767
-
768
- return {
769
- "status": "recovered",
770
- "message": f"Training resumed from checkpoint {checkpoint_step}",
771
- "result": result,
772
- "ui_updates": ui_updates
773
- }
774
- except Exception as e:
775
- logger.error(f"Failed to auto-resume training: {str(e)}")
776
- # Set buttons for manual recovery
777
- ui_updates.update({
778
- "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
779
- "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
780
- "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
781
- })
782
- return {"status": "error", "message": f"Failed to auto-resume: {str(e)}", "ui_updates": ui_updates}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
783
  else:
784
  # Set up UI for manual recovery
785
  ui_updates.update({
 
637
  return False
638
 
639
  def recover_interrupted_training(self) -> Dict[str, Any]:
640
+ """Attempt to recover interrupted training
641
+
642
+ Returns:
643
+ Dict with recovery status and UI updates
644
+ """
645
+ status = self.get_status()
646
+ ui_updates = {}
647
+
648
+ # Check for any checkpoints, even if status doesn't indicate training
649
+ checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
650
+ has_checkpoints = len(checkpoints) > 0
651
+
652
+ # If status indicates training but process isn't running, or if we have checkpoints
653
+ # and no active training process, try to recover
654
+ if (status.get('status') in ['training', 'paused'] and not self.is_training_running()) or \
655
+ (has_checkpoints and not self.is_training_running()):
656
 
657
+ logger.info("Detected interrupted training session or existing checkpoints, attempting to recover...")
 
 
 
 
658
 
659
+ # Get the latest checkpoint
660
+ last_session = self.load_session()
 
661
 
662
+ if not last_session:
663
+ logger.warning("No session data found for recovery, but will check for checkpoints")
664
+ # Try to create a default session based on UI state if we have checkpoints
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
  if has_checkpoints:
666
+ ui_state = self.load_ui_state()
667
+ # Create a default session using UI state values
668
+ last_session = {
669
+ "params": {
670
+ "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
671
+ "lora_rank": ui_state.get("lora_rank", "128"),
672
+ "lora_alpha": ui_state.get("lora_alpha", "128"),
673
+ "num_epochs": ui_state.get("num_epochs", 70),
674
+ "batch_size": ui_state.get("batch_size", 1),
675
+ "learning_rate": ui_state.get("learning_rate", 3e-5),
676
+ "save_iterations": ui_state.get("save_iterations", 500),
677
+ "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
678
+ "repo_id": "" # Default empty repo ID
679
+ }
680
+ }
681
+ logger.info("Created default session from UI state for recovery")
682
  else:
 
683
  # Set buttons for no active training
684
  ui_updates = {
685
  "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
686
  "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
687
+ "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
688
  "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
689
  }
690
+ return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
691
+
692
+ # Find the latest checkpoint if we have checkpoints
693
+ latest_checkpoint = None
694
+ checkpoint_step = 0
695
+
696
+ if has_checkpoints:
697
+ latest_checkpoint = max(checkpoints, key=os.path.getmtime)
698
+ checkpoint_step = int(latest_checkpoint.name.split("-")[1])
699
+ logger.info(f"Found checkpoint at step {checkpoint_step}")
700
+ else:
701
+ logger.warning("No checkpoints found for recovery")
702
+ # Set buttons for no active training
703
+ ui_updates = {
704
+ "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
705
+ "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
706
+ "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
707
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
708
+ }
709
+ return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
710
+
711
+ # Extract parameters from the saved session (not current UI state)
712
+ # This ensures we use the original training parameters
713
+ params = last_session.get('params', {})
714
+
715
+ # Map internal model type back to display name for UI
716
+ # This is the key fix for the "ltx_video" vs "LTX-Video (LoRA)" mismatch
717
+ model_type_internal = params.get('model_type')
718
+ model_type_display = model_type_internal
719
+
720
+ # Find the display name that maps to our internal model type
721
+ for display_name, internal_name in MODEL_TYPES.items():
722
+ if internal_name == model_type_internal:
723
+ model_type_display = display_name
724
+ logger.info(f"Mapped internal model type '{model_type_internal}' to display name '{model_type_display}'")
725
+ break
726
+
727
+ # Add UI updates to restore the training parameters in the UI
728
+ # This shows the user what values are being used for the resumed training
729
+ ui_updates.update({
730
+ "model_type": model_type_display, # Use the display name for the UI dropdown
731
+ "lora_rank": params.get('lora_rank', "128"),
732
+ "lora_alpha": params.get('lora_alpha', "128"),
733
+ "num_epochs": params.get('num_epochs', 70),
734
+ "batch_size": params.get('batch_size', 1),
735
+ "learning_rate": params.get('learning_rate', 3e-5),
736
+ "save_iterations": params.get('save_iterations', 500),
737
+ "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
738
+ })
739
+
740
+ # Check if we should auto-recover (immediate restart)
741
+ auto_recover = True # Always auto-recover on startup
742
+
743
+ if auto_recover:
744
+ # Rest of the auto-recovery code remains unchanged
745
+ try:
746
+ # Use the internal model_type for the actual training
747
+ # But keep model_type_display for the UI
748
+ result = self.start_training(
749
+ model_type=model_type_internal,
750
+ lora_rank=params.get('lora_rank', "128"),
751
+ lora_alpha=params.get('lora_alpha', "128"),
752
+ num_epochs=params.get('num_epochs', 70),
753
+ batch_size=params.get('batch_size', 1),
754
+ learning_rate=params.get('learning_rate', 3e-5),
755
+ save_iterations=params.get('save_iterations', 500),
756
+ repo_id=params.get('repo_id', ''),
757
+ preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
758
+ resume_from_checkpoint=str(latest_checkpoint)
759
+ )
760
+
761
+ # Set buttons for active training
762
+ ui_updates.update({
763
+ "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training"},
764
+ "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
765
+ "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
766
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
767
+ })
768
+
769
+ return {
770
+ "status": "recovered",
771
+ "message": f"Training resumed from checkpoint {checkpoint_step}",
772
+ "result": result,
773
+ "ui_updates": ui_updates
774
+ }
775
+ except Exception as e:
776
+ logger.error(f"Failed to auto-resume training: {str(e)}")
777
+ # Set buttons for manual recovery
778
+ ui_updates.update({
779
+ "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
780
+ "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
781
+ "delete_checkpoints_btn": {"interactive": True, "variant": "stop", "value": "Delete All Checkpoints"},
782
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
783
+ })
784
+ return {"status": "error", "message": f"Failed to auto-resume: {str(e)}", "ui_updates": ui_updates}
785
  else:
786
  # Set up UI for manual recovery
787
  ui_updates.update({
vms/tabs/train_tab.py CHANGED
@@ -8,7 +8,7 @@ from typing import Dict, Any, List, Optional, Tuple
8
  from pathlib import Path
9
 
10
  from .base_tab import BaseTab
11
- from ..config import TRAINING_PRESETS, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS
12
  from ..utils import TrainingLogParser
13
 
14
  logger = logging.getLogger(__name__)
@@ -279,7 +279,7 @@ class TrainTab(BaseTab):
279
  )
280
 
281
  def handle_training_start(self, preset, model_type, *args):
282
- """Handle training start with proper log parser reset"""
283
  # Safely reset log parser if it exists
284
  if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
285
  self.app.log_parser.reset()
@@ -288,12 +288,35 @@ class TrainTab(BaseTab):
288
  from ..utils import TrainingLogParser
289
  self.app.log_parser = TrainingLogParser()
290
 
291
- # Start training
292
- return self.app.trainer.start_training(
293
- MODEL_TYPES[model_type],
294
- *args,
295
- preset_name=preset
296
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
  def get_model_info(self, model_type: str) -> str:
299
  """Get information about the selected model type"""
@@ -455,6 +478,23 @@ class TrainTab(BaseTab):
455
  state = self.app.trainer.get_status()
456
  logs = self.app.trainer.get_logs()
457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  # Ensure log parser is initialized
459
  if not hasattr(self.app, 'log_parser') or self.app.log_parser is None:
460
  from ..utils import TrainingLogParser
@@ -462,7 +502,7 @@ class TrainTab(BaseTab):
462
  logger.info("Initialized missing log parser")
463
 
464
  # Parse new log lines
465
- if logs:
466
  last_state = None
467
  for line in logs.splitlines():
468
  try:
@@ -480,6 +520,12 @@ class TrainTab(BaseTab):
480
  # Parse status for training state
481
  if "completed" in state["message"].lower():
482
  state["status"] = "completed"
 
 
 
 
 
 
483
 
484
  return (state["status"], state["message"], logs)
485
 
 
8
  from pathlib import Path
9
 
10
  from .base_tab import BaseTab
11
+ from ..config import TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS
12
  from ..utils import TrainingLogParser
13
 
14
  logger = logging.getLogger(__name__)
 
279
  )
280
 
281
  def handle_training_start(self, preset, model_type, *args):
282
+ """Handle training start with proper log parser reset and checkpoint detection"""
283
  # Safely reset log parser if it exists
284
  if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
285
  self.app.log_parser.reset()
 
288
  from ..utils import TrainingLogParser
289
  self.app.log_parser = TrainingLogParser()
290
 
291
+ # Check for latest checkpoint
292
+ checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
293
+ resume_from = None
294
+
295
+ if checkpoints:
296
+ # Find the latest checkpoint
297
+ latest_checkpoint = max(checkpoints, key=os.path.getmtime)
298
+ resume_from = str(latest_checkpoint)
299
+ logger.info(f"Found checkpoint at {resume_from}, will resume training")
300
+
301
+ # Convert model_type display name to internal name
302
+ model_internal_type = MODEL_TYPES.get(model_type)
303
+
304
+ if not model_internal_type:
305
+ logger.error(f"Invalid model type: {model_type}")
306
+ return f"Error: Invalid model type '{model_type}'", "Model type not recognized"
307
+
308
+ # Start training (it will automatically use the checkpoint if provided)
309
+ try:
310
+ return self.app.trainer.start_training(
311
+ model_internal_type, # Use internal model type
312
+ *args,
313
+ preset_name=preset,
314
+ resume_from_checkpoint=resume_from
315
+ )
316
+ except Exception as e:
317
+ logger.exception("Error starting training")
318
+ return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
319
+
320
 
321
  def get_model_info(self, model_type: str) -> str:
322
  """Get information about the selected model type"""
 
478
  state = self.app.trainer.get_status()
479
  logs = self.app.trainer.get_logs()
480
 
481
+ # Check if training process died unexpectedly
482
+ training_died = False
483
+
484
+ if state["status"] == "training" and not self.app.trainer.is_training_running():
485
+ state["status"] = "error"
486
+ state["message"] = "Training process terminated unexpectedly."
487
+ training_died = True
488
+
489
+ # Look for error in logs
490
+ error_lines = []
491
+ for line in logs.splitlines():
492
+ if "Error:" in line or "Exception:" in line or "Traceback" in line:
493
+ error_lines.append(line)
494
+
495
+ if error_lines:
496
+ state["message"] += f"\n\nPossible error: {error_lines[-1]}"
497
+
498
  # Ensure log parser is initialized
499
  if not hasattr(self.app, 'log_parser') or self.app.log_parser is None:
500
  from ..utils import TrainingLogParser
 
502
  logger.info("Initialized missing log parser")
503
 
504
  # Parse new log lines
505
+ if logs and not training_died:
506
  last_state = None
507
  for line in logs.splitlines():
508
  try:
 
520
  # Parse status for training state
521
  if "completed" in state["message"].lower():
522
  state["status"] = "completed"
523
+ elif "error" in state["message"].lower():
524
+ state["status"] = "error"
525
+ elif "failed" in state["message"].lower():
526
+ state["status"] = "error"
527
+ elif "stopped" in state["message"].lower():
528
+ state["status"] = "stopped"
529
 
530
  return (state["status"], state["message"], logs)
531
 
vms/ui/video_trainer_ui.py CHANGED
@@ -7,7 +7,7 @@ from typing import Any, Optional, Dict, List, Union, Tuple
7
 
8
  from ..services import TrainingService, CaptioningService, SplittingService, ImportService
9
  from ..config import (
10
- STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
11
  TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
12
  MODEL_TYPES, SMALL_TRAINING_BUCKETS
13
  )
@@ -160,7 +160,24 @@ class VideoTrainerUI:
160
 
161
  # If we recovered training parameters from the original session
162
  ui_state = {}
163
- for param in ["model_type", "lora_rank", "lora_alpha", "num_epochs",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  "batch_size", "learning_rate", "save_iterations", "training_preset"]:
165
  if param in recovery_ui:
166
  ui_state[param] = recovery_ui[param]
@@ -175,8 +192,16 @@ class VideoTrainerUI:
175
  # Load values (potentially with recovery updates applied)
176
  ui_state = self.load_ui_values()
177
 
178
- training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
179
  model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
 
 
 
 
 
 
 
 
180
  lora_rank_val = ui_state.get("lora_rank", "128")
181
  lora_alpha_val = ui_state.get("lora_alpha", "128")
182
  num_epochs_val = int(ui_state.get("num_epochs", 70))
@@ -190,9 +215,9 @@ class VideoTrainerUI:
190
  training_dataset,
191
  start_btn,
192
  stop_btn,
193
- delete_checkpoints_btn, # Replaces pause_resume_btn
194
  training_preset,
195
- model_type_val,
196
  lora_rank_val,
197
  lora_alpha_val,
198
  num_epochs_val,
 
7
 
8
  from ..services import TrainingService, CaptioningService, SplittingService, ImportService
9
  from ..config import (
10
+ STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
11
  TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
12
  MODEL_TYPES, SMALL_TRAINING_BUCKETS
13
  )
 
160
 
161
  # If we recovered training parameters from the original session
162
  ui_state = {}
163
+
164
+ # Handle model_type specifically - could be internal or display name
165
+ if "model_type" in recovery_ui:
166
+ model_type_value = recovery_ui["model_type"]
167
+
168
+ # If it's an internal name, convert to display name
169
+ if model_type_value not in MODEL_TYPES:
170
+ # Find the display name for this internal model type
171
+ for display_name, internal_name in MODEL_TYPES.items():
172
+ if internal_name == model_type_value:
173
+ model_type_value = display_name
174
+ logger.info(f"Converted internal model type '{recovery_ui['model_type']}' to display name '{model_type_value}'")
175
+ break
176
+
177
+ ui_state["model_type"] = model_type_value
178
+
179
+ # Copy other parameters
180
+ for param in ["lora_rank", "lora_alpha", "num_epochs",
181
  "batch_size", "learning_rate", "save_iterations", "training_preset"]:
182
  if param in recovery_ui:
183
  ui_state[param] = recovery_ui[param]
 
192
  # Load values (potentially with recovery updates applied)
193
  ui_state = self.load_ui_values()
194
 
195
+ # Ensure model_type is a display name, not internal name
196
  model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
197
+ if model_type_val not in MODEL_TYPES:
198
+ # Convert from internal to display name
199
+ for display_name, internal_name in MODEL_TYPES.items():
200
+ if internal_name == model_type_val:
201
+ model_type_val = display_name
202
+ break
203
+
204
+ training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
205
  lora_rank_val = ui_state.get("lora_rank", "128")
206
  lora_alpha_val = ui_state.get("lora_alpha", "128")
207
  num_epochs_val = int(ui_state.get("num_epochs", 70))
 
215
  training_dataset,
216
  start_btn,
217
  stop_btn,
218
+ delete_checkpoints_btn,
219
  training_preset,
220
+ model_type_val,
221
  lora_rank_val,
222
  lora_alpha_val,
223
  num_epochs_val,