jbilcke-hf HF Staff commited on
Commit
54a2a4e
·
1 Parent(s): 9545589

working on training job failure recovery

Browse files
Files changed (3) hide show
  1. app.py +124 -4
  2. vms/training_log_parser.py +33 -34
  3. vms/training_service.py +188 -4
app.py CHANGED
@@ -59,7 +59,43 @@ class VideoTrainerUI:
59
  self.captioner = CaptioningService()
60
  self._should_stop_captioning = False
61
  self.log_parser = TrainingLogParser()
62
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def update_captioning_buttons_start(self):
64
  """Return individual button values instead of a dictionary"""
65
  return (
@@ -1120,12 +1156,55 @@ class VideoTrainerUI:
1120
  return gr.update(value=repo_id, error=None)
1121
 
1122
  # Connect events
 
 
1123
  model_type.change(
 
 
 
 
1124
  fn=update_model_info,
1125
  inputs=[model_type],
1126
  outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations]
1127
  )
1128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1129
  async def on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
1130
  videos = self.list_unprocessed_videos()
1131
  # If scene detection isn't already running and there are videos to process,
@@ -1243,8 +1322,13 @@ class VideoTrainerUI:
1243
  fn=self.list_training_files_to_caption,
1244
  outputs=[training_dataset]
1245
  )
1246
-
 
1247
  training_preset.change(
 
 
 
 
1248
  fn=self.update_training_params,
1249
  inputs=[training_preset],
1250
  outputs=[
@@ -1337,13 +1421,49 @@ class VideoTrainerUI:
1337
  ]
1338
  )
1339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1340
  # Auto-refresh timers
1341
  app.load(
1342
  fn=lambda: (
1343
- self.refresh_dataset()
 
 
 
1344
  ),
1345
  outputs=[
1346
- video_list, training_dataset
 
 
 
 
1347
  ]
1348
  )
1349
 
 
59
  self.captioner = CaptioningService()
60
  self._should_stop_captioning = False
61
  self.log_parser = TrainingLogParser()
62
+
63
+ # Try to recover any interrupted training sessions
64
+ recovery_result = self.trainer.recover_interrupted_training()
65
+
66
+ self.recovery_status = recovery_result.get("status", "unknown")
67
+ self.ui_updates = recovery_result.get("ui_updates", {})
68
+
69
+ if recovery_result["status"] == "recovered":
70
+ logger.info(f"Training recovery: {recovery_result['message']}")
71
+ # No need to do anything else - the training is already running
72
+ elif recovery_result["status"] == "running":
73
+ logger.info("Training process is already running")
74
+ # No need to do anything - the process is still alive
75
+ elif recovery_result["status"] in ["error", "idle"]:
76
+ logger.warning(f"Training status: {recovery_result['message']}")
77
+ # UI will be in ready-to-start mode
78
+
79
+
80
+ def update_ui_state(self, **kwargs):
81
+ """Update UI state with new values"""
82
+ current_state = self.trainer.load_ui_state()
83
+ current_state.update(kwargs)
84
+ self.trainer.save_ui_state(current_state)
85
+ return current_state
86
+
87
+ def load_ui_values(self):
88
+ """Load UI state values for initializing form fields"""
89
+ ui_state = self.trainer.load_ui_state()
90
+
91
+ # Convert types as needed since JSON stores everything as strings
92
+ ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
93
+ ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
94
+ ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
95
+ ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
96
+
97
+ return ui_state
98
+
99
  def update_captioning_buttons_start(self):
100
  """Return individual button values instead of a dictionary"""
101
  return (
 
1156
  return gr.update(value=repo_id, error=None)
1157
 
1158
  # Connect events
1159
+
1160
+ # Save state when model type changes
1161
  model_type.change(
1162
+ fn=lambda v: self.update_ui_state(model_type=v),
1163
+ inputs=[model_type],
1164
+ outputs=[] # No UI update needed
1165
+ ).then(
1166
  fn=update_model_info,
1167
  inputs=[model_type],
1168
  outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations]
1169
  )
1170
 
1171
+ # the following change listeners are used for UI persistence
1172
+ lora_rank.change(
1173
+ fn=lambda v: self.update_ui_state(lora_rank=v),
1174
+ inputs=[lora_rank],
1175
+ outputs=[]
1176
+ )
1177
+
1178
+ lora_alpha.change(
1179
+ fn=lambda v: self.update_ui_state(lora_alpha=v),
1180
+ inputs=[lora_alpha],
1181
+ outputs=[]
1182
+ )
1183
+
1184
+ num_epochs.change(
1185
+ fn=lambda v: self.update_ui_state(num_epochs=v),
1186
+ inputs=[num_epochs],
1187
+ outputs=[]
1188
+ )
1189
+
1190
+ batch_size.change(
1191
+ fn=lambda v: self.update_ui_state(batch_size=v),
1192
+ inputs=[batch_size],
1193
+ outputs=[]
1194
+ )
1195
+
1196
+ learning_rate.change(
1197
+ fn=lambda v: self.update_ui_state(learning_rate=v),
1198
+ inputs=[learning_rate],
1199
+ outputs=[]
1200
+ )
1201
+
1202
+ save_iterations.change(
1203
+ fn=lambda v: self.update_ui_state(save_iterations=v),
1204
+ inputs=[save_iterations],
1205
+ outputs=[]
1206
+ )
1207
+
1208
  async def on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
1209
  videos = self.list_unprocessed_videos()
1210
  # If scene detection isn't already running and there are videos to process,
 
1322
  fn=self.list_training_files_to_caption,
1323
  outputs=[training_dataset]
1324
  )
1325
+
1326
+ # Save state when training preset changes
1327
  training_preset.change(
1328
+ fn=lambda v: self.update_ui_state(training_preset=v),
1329
+ inputs=[training_preset],
1330
+ outputs=[] # No UI update needed
1331
+ ).then(
1332
  fn=self.update_training_params,
1333
  inputs=[training_preset],
1334
  outputs=[
 
1421
  ]
1422
  )
1423
 
1424
+ # Computes the initial training-button states from the recovery status:
1425
+ def get_initial_button_states(self):
1426
+ """Get the initial states for training buttons based on recovery status"""
1427
+ recovery_result = self.trainer.recover_interrupted_training()
1428
+ ui_updates = recovery_result.get("ui_updates", {})
1429
+
1430
+ # Return button states in the correct order
1431
+ return (
1432
+ gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
1433
+ gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
1434
+ gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
1435
+ )
1436
+
1437
+ def initialize_ui_from_state(self):
1438
+ """Initialize UI components from saved state"""
1439
+ ui_state = self.load_ui_values()
1440
+
1441
+ # Return values in order matching the outputs in app.load
1442
+ return (
1443
+ ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
1444
+ ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
1445
+ ui_state.get("lora_rank", "128"),
1446
+ ui_state.get("lora_alpha", "128"),
1447
+ ui_state.get("num_epochs", 70),
1448
+ ui_state.get("batch_size", 1),
1449
+ ui_state.get("learning_rate", 3e-5),
1450
+ ui_state.get("save_iterations", 500)
1451
+ )
1452
+
1453
  # Auto-refresh timers
1454
  app.load(
1455
  fn=lambda: (
1456
+ self.refresh_dataset(),
1457
+ *self.get_initial_button_states(),
1458
+ # Load saved UI state values
1459
+ *self.initialize_ui_from_state()
1460
  ),
1461
  outputs=[
1462
+ video_list, training_dataset,
1463
+ start_btn, stop_btn, pause_resume_btn,
1464
+ # Add outputs for UI fields
1465
+ training_preset, model_type, lora_rank, lora_alpha,
1466
+ num_epochs, batch_size, learning_rate, save_iterations
1467
  ]
1468
  )
1469
 
vms/training_log_parser.py CHANGED
@@ -34,7 +34,14 @@ class TrainingState:
34
 
35
  def to_dict(self) -> Dict[str, Any]:
36
  """Convert state to dictionary for UI updates"""
37
- elapsed = str(datetime.now() - self.start_time) if self.start_time else "0:00:00"
 
 
 
 
 
 
 
38
  remaining = str(self.estimated_remaining) if self.estimated_remaining else "calculating..."
39
 
40
  return {
@@ -74,10 +81,11 @@ class TrainingLogParser:
74
  if ("Started training" in line) or ("Starting training" in line):
75
  self.state.status = "training"
76
 
 
77
  if "Training steps:" in line:
78
  # Set status to training if we see this
79
  self.state.status = "training"
80
- #print("setting status to 'training'")
81
  if not self.state.start_time:
82
  self.state.start_time = datetime.now()
83
 
@@ -97,36 +105,23 @@ class TrainingLogParser:
97
  if match:
98
  setattr(self.state, attr, float(match.group(1)))
99
 
100
- # Calculate time estimates based on total elapsed time
101
- now = datetime.now()
102
- if self.state.start_time and self.state.current_step > 0:
103
- # Calculate elapsed time and average time per step
104
- elapsed_seconds = (now - self.state.start_time).total_seconds()
105
- avg_time_per_step = elapsed_seconds / self.state.current_step
106
-
107
- # Calculate remaining time
108
- remaining_steps = self.state.total_steps - self.state.current_step
109
- estimated_remaining_seconds = avg_time_per_step * remaining_steps
110
-
111
- # Format as days, hours, minutes, seconds
112
- days = int(estimated_remaining_seconds // (24 * 3600))
113
- hours = int((estimated_remaining_seconds % (24 * 3600)) // 3600)
114
- minutes = int((estimated_remaining_seconds % 3600) // 60)
115
- seconds = int(estimated_remaining_seconds % 60)
116
-
117
- # Create formatted timedelta
118
- if days > 0:
119
- formatted_time = f"{days}d {hours}h {minutes}m {seconds}s"
120
- elif hours > 0:
121
- formatted_time = f"{hours}h {minutes}m {seconds}s"
122
- elif minutes > 0:
123
- formatted_time = f"{minutes}m {seconds}s"
124
- else:
125
- formatted_time = f"{seconds}s"
126
-
127
- self.state.estimated_remaining = formatted_time
128
- self.state.last_step_time = now
129
-
130
  logger.info(f"Updated training state: step={self.state.current_step}/{self.state.total_steps}, loss={self.state.step_loss}")
131
  return self.state.to_dict()
132
 
@@ -162,12 +157,16 @@ class TrainingLogParser:
162
 
163
  # Completion states
164
  if "Training completed successfully" in line:
165
- self.state.status = "completed"
 
 
166
  logger.info("Training completed")
167
  return self.state.to_dict()
168
 
169
  if any(x in line for x in ["Training process stopped", "Training stopped"]):
170
- self.state.status = "stopped"
 
 
171
  logger.info("Training stopped")
172
  return self.state.to_dict()
173
 
 
34
 
35
  def to_dict(self) -> Dict[str, Any]:
36
  """Convert state to dictionary for UI updates"""
37
+ # Calculate elapsed time only if training is active and we have a start time
38
+ if self.start_time and self.status in ["training", "initializing"]:
39
+ elapsed = str(datetime.now() - self.start_time)
40
+ else:
41
+ # Use the last known elapsed time or show 0
42
+ elapsed = "0:00:00" if not self.last_step_time else str(self.last_step_time - self.start_time if self.start_time else "0:00:00")
43
+
44
+ # Use precomputed remaining time from logs if available
45
  remaining = str(self.estimated_remaining) if self.estimated_remaining else "calculating..."
46
 
47
  return {
 
81
  if ("Started training" in line) or ("Starting training" in line):
82
  self.state.status = "training"
83
 
84
+ # Check for "Training steps:" which contains the progress information
85
  if "Training steps:" in line:
86
  # Set status to training if we see this
87
  self.state.status = "training"
88
+
89
  if not self.state.start_time:
90
  self.state.start_time = datetime.now()
91
 
 
105
  if match:
106
  setattr(self.state, attr, float(match.group(1)))
107
 
108
+ # Extract time remaining directly from the log
109
+ # Example format: [MM:SS<H:MM:SS, SS.SSs/it]
110
+ time_remaining_match = re.search(r"<(\d+:\d+:\d+)", line)
111
+ if time_remaining_match:
112
+ remaining_str = time_remaining_match.group(1)
113
+ # Store the string directly - no need to parse it
114
+ self.state.estimated_remaining = remaining_str
115
+
116
+ # If no direct time estimate, look for hour:min format
117
+ if not time_remaining_match:
118
+ hour_min_match = re.search(r"<(\d+h\s*\d+m)", line)
119
+ if hour_min_match:
120
+ self.state.estimated_remaining = hour_min_match.group(1)
121
+
122
+ # Update last processing time
123
+ self.state.last_step_time = datetime.now()
124
+
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  logger.info(f"Updated training state: step={self.state.current_step}/{self.state.total_steps}, loss={self.state.step_loss}")
126
  return self.state.to_dict()
127
 
 
157
 
158
  # Completion states
159
  if "Training completed successfully" in line:
160
+ self.status = "completed"
161
+ # Store final elapsed time
162
+ self.last_step_time = datetime.now()
163
  logger.info("Training completed")
164
  return self.state.to_dict()
165
 
166
  if any(x in line for x in ["Training process stopped", "Training stopped"]):
167
+ self.status = "stopped"
168
+ # Store final elapsed time
169
+ self.last_step_time = datetime.now()
170
  logger.info("Training stopped")
171
  return self.state.to_dict()
172
 
vms/training_service.py CHANGED
@@ -38,7 +38,7 @@ class TrainingService:
38
  self.setup_logging()
39
 
40
  logger.info("Training service initialized")
41
-
42
  def setup_logging(self):
43
  """Set up logging with proper handler management"""
44
  global logger
@@ -96,16 +96,58 @@ class TrainingService:
96
  if self.file_handler:
97
  self.file_handler.close()
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def save_session(self, params: Dict) -> None:
100
  """Save training session parameters"""
101
  session_data = {
102
  "timestamp": datetime.now().isoformat(),
103
  "params": params,
104
- "status": self.get_status()
 
 
105
  }
106
  with open(self.session_file, 'w') as f:
107
  json.dump(session_data, f, indent=2)
108
-
109
  def load_session(self) -> Optional[Dict]:
110
  """Load saved training session"""
111
  if self.session_file.exists():
@@ -225,6 +267,7 @@ class TrainingService:
225
  save_iterations: int,
226
  repo_id: str,
227
  preset_name: str,
 
228
  ) -> Tuple[str, str]:
229
  """Start training with finetrainers"""
230
 
@@ -295,6 +338,11 @@ class TrainingService:
295
  config.lr = float(learning_rate)
296
  config.checkpointing_steps = int(save_iterations)
297
 
 
 
 
 
 
298
  # Common settings for both models
299
  config.mixed_precision = "bf16"
300
  config.seed = 42
@@ -477,10 +525,146 @@ class TrainingService:
477
  try:
478
  with open(self.pid_file, 'r') as f:
479
  pid = int(f.read().strip())
480
- return psutil.pid_exists(pid)
 
 
 
 
 
 
 
 
 
 
481
  except:
482
  return False
483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
  def clear_training_data(self) -> str:
485
  """Clear all training data"""
486
  if self.is_training_running():
 
38
  self.setup_logging()
39
 
40
  logger.info("Training service initialized")
41
+
42
  def setup_logging(self):
43
  """Set up logging with proper handler management"""
44
  global logger
 
96
  if self.file_handler:
97
  self.file_handler.close()
98
 
99
+
100
+ def save_ui_state(self, values: Dict[str, Any]) -> None:
101
+ """Save current UI state to file"""
102
+ ui_state_file = OUTPUT_PATH / "ui_state.json"
103
+ try:
104
+ with open(ui_state_file, 'w') as f:
105
+ json.dump(values, f, indent=2)
106
+ logger.debug(f"UI state saved: {values}")
107
+ except Exception as e:
108
+ logger.error(f"Error saving UI state: {str(e)}")
109
+
110
+ def load_ui_state(self) -> Dict[str, Any]:
111
+ """Load saved UI state"""
112
+ ui_state_file = OUTPUT_PATH / "ui_state.json"
113
+ default_state = {
114
+ "model_type": list(MODEL_TYPES.keys())[0],
115
+ "lora_rank": "128",
116
+ "lora_alpha": "128",
117
+ "num_epochs": 70,
118
+ "batch_size": 1,
119
+ "learning_rate": 3e-5,
120
+ "save_iterations": 500,
121
+ "training_preset": list(TRAINING_PRESETS.keys())[0]
122
+ }
123
+
124
+ if not ui_state_file.exists():
125
+ return default_state
126
+
127
+ try:
128
+ with open(ui_state_file, 'r') as f:
129
+ saved_state = json.load(f)
130
+ # Make sure we have all keys (in case structure changed)
131
+ merged_state = default_state.copy()
132
+ merged_state.update(saved_state)
133
+ return merged_state
134
+ except Exception as e:
135
+ logger.error(f"Error loading UI state: {str(e)}")
136
+ return default_state
137
+
138
+ # save_session also snapshots the UI state at the moment training starts
139
  def save_session(self, params: Dict) -> None:
140
  """Save training session parameters"""
141
  session_data = {
142
  "timestamp": datetime.now().isoformat(),
143
  "params": params,
144
+ "status": self.get_status(),
145
+ # Add UI state at the time training started
146
+ "initial_ui_state": self.load_ui_state()
147
  }
148
  with open(self.session_file, 'w') as f:
149
  json.dump(session_data, f, indent=2)
150
+
151
  def load_session(self) -> Optional[Dict]:
152
  """Load saved training session"""
153
  if self.session_file.exists():
 
267
  save_iterations: int,
268
  repo_id: str,
269
  preset_name: str,
270
+ resume_from_checkpoint: Optional[str] = None,
271
  ) -> Tuple[str, str]:
272
  """Start training with finetrainers"""
273
 
 
338
  config.lr = float(learning_rate)
339
  config.checkpointing_steps = int(save_iterations)
340
 
341
+ # Update with resume_from_checkpoint if provided
342
+ if resume_from_checkpoint:
343
+ config.resume_from_checkpoint = resume_from_checkpoint
344
+ self.append_log(f"Resuming from checkpoint: {resume_from_checkpoint}")
345
+
346
  # Common settings for both models
347
  config.mixed_precision = "bf16"
348
  config.seed = 42
 
525
  try:
526
  with open(self.pid_file, 'r') as f:
527
  pid = int(f.read().strip())
528
+
529
+ # Check if process exists AND is a Python process running train.py
530
+ if psutil.pid_exists(pid):
531
+ try:
532
+ process = psutil.Process(pid)
533
+ cmdline = process.cmdline()
534
+ # Check if it's a Python process running train.py
535
+ return any('train.py' in cmd for cmd in cmdline)
536
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
537
+ return False
538
+ return False
539
  except:
540
  return False
541
 
542
+ def recover_interrupted_training(self) -> Dict[str, Any]:
543
+ """Attempt to recover interrupted training
544
+
545
+ Returns:
546
+ Dict with recovery status and UI updates
547
+ """
548
+ status = self.get_status()
549
+ ui_updates = {}
550
+
551
+ # If status indicates training but process isn't running, try to recover
552
+ if status.get('status') == 'training' and not self.is_training_running():
553
+ logger.info("Detected interrupted training session, attempting to recover...")
554
+
555
+ # Get the latest checkpoint
556
+ last_session = self.load_session()
557
+ if not last_session:
558
+ logger.warning("No session data found for recovery")
559
+ # Set buttons for no active training
560
+ ui_updates = {
561
+ "start_btn": {"interactive": True, "variant": "primary"},
562
+ "stop_btn": {"interactive": False, "variant": "secondary"},
563
+ "pause_resume_btn": {"interactive": False, "variant": "secondary"}
564
+ }
565
+ return {"status": "error", "message": "No session data found", "ui_updates": ui_updates}
566
+
567
+ # Find the latest checkpoint
568
+ checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
569
+ if not checkpoints:
570
+ logger.warning("No checkpoints found for recovery")
571
+ # Set buttons for no active training
572
+ ui_updates = {
573
+ "start_btn": {"interactive": True, "variant": "primary"},
574
+ "stop_btn": {"interactive": False, "variant": "secondary"},
575
+ "pause_resume_btn": {"interactive": False, "variant": "secondary"}
576
+ }
577
+ return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
578
+
579
+ latest_checkpoint = max(checkpoints, key=os.path.getmtime)
580
+ checkpoint_step = int(latest_checkpoint.name.split("-")[1])
581
+
582
+ logger.info(f"Found checkpoint at step {checkpoint_step}, attempting to resume")
583
+
584
+ # Extract parameters from the saved session (not current UI state)
585
+ # This ensures we use the original training parameters
586
+ params = last_session.get('params', {})
587
+ initial_ui_state = last_session.get('initial_ui_state', {})
588
+
589
+ # Add UI updates to restore the training parameters in the UI
590
+ # This shows the user what values are being used for the resumed training
591
+ ui_updates.update({
592
+ "model_type": gr.update(value=params.get('model_type', list(MODEL_TYPES.keys())[0])),
593
+ "lora_rank": gr.update(value=params.get('lora_rank', "128")),
594
+ "lora_alpha": gr.update(value=params.get('lora_alpha', "128")),
595
+ "num_epochs": gr.update(value=params.get('num_epochs', 70)),
596
+ "batch_size": gr.update(value=params.get('batch_size', 1)),
597
+ "learning_rate": gr.update(value=params.get('learning_rate', 3e-5)),
598
+ "save_iterations": gr.update(value=params.get('save_iterations', 500)),
599
+ "training_preset": gr.update(value=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]))
600
+ })
601
+
602
+ # Attempt to resume training using the ORIGINAL parameters
603
+ try:
604
+ # Extract required parameters from the session
605
+ model_type = params.get('model_type')
606
+ lora_rank = params.get('lora_rank')
607
+ lora_alpha = params.get('lora_alpha')
608
+ num_epochs = params.get('num_epochs')
609
+ batch_size = params.get('batch_size')
610
+ learning_rate = params.get('learning_rate')
611
+ save_iterations = params.get('save_iterations')
612
+ repo_id = params.get('repo_id')
613
+ preset_name = params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
614
+
615
+ # Attempt to resume training
616
+ result = self.start_training(
617
+ model_type=model_type,
618
+ lora_rank=lora_rank,
619
+ lora_alpha=lora_alpha,
620
+ num_epochs=num_epochs,
621
+ batch_size=batch_size,
622
+ learning_rate=learning_rate,
623
+ save_iterations=save_iterations,
624
+ repo_id=repo_id,
625
+ preset_name=preset_name,
626
+ resume_from_checkpoint=str(latest_checkpoint)
627
+ )
628
+
629
+ # Set buttons for active training
630
+ ui_updates.update({
631
+ "start_btn": {"interactive": False, "variant": "secondary"},
632
+ "stop_btn": {"interactive": True, "variant": "stop"},
633
+ "pause_resume_btn": {"interactive": True, "variant": "secondary"}
634
+ })
635
+
636
+ return {
637
+ "status": "recovered",
638
+ "message": f"Training resumed from checkpoint {checkpoint_step}",
639
+ "result": result,
640
+ "ui_updates": ui_updates
641
+ }
642
+ except Exception as e:
643
+ logger.error(f"Failed to resume training: {str(e)}")
644
+ # Set buttons for no active training
645
+ ui_updates.update({
646
+ "start_btn": {"interactive": True, "variant": "primary"},
647
+ "stop_btn": {"interactive": False, "variant": "secondary"},
648
+ "pause_resume_btn": {"interactive": False, "variant": "secondary"}
649
+ })
650
+ return {"status": "error", "message": f"Failed to resume: {str(e)}", "ui_updates": ui_updates}
651
+ elif self.is_training_running():
652
+ # Process is still running, set buttons accordingly
653
+ ui_updates = {
654
+ "start_btn": {"interactive": False, "variant": "secondary"},
655
+ "stop_btn": {"interactive": True, "variant": "stop"},
656
+ "pause_resume_btn": {"interactive": True, "variant": "secondary"}
657
+ }
658
+ return {"status": "running", "message": "Training process is running", "ui_updates": ui_updates}
659
+ else:
660
+ # No training process, set buttons to default state
661
+ ui_updates = {
662
+ "start_btn": {"interactive": True, "variant": "primary"},
663
+ "stop_btn": {"interactive": False, "variant": "secondary"},
664
+ "pause_resume_btn": {"interactive": False, "variant": "secondary"}
665
+ }
666
+ return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
667
+
668
  def clear_training_data(self) -> str:
669
  """Clear all training data"""
670
  if self.is_training_running():