jbilcke-hf HF Staff committed on
Commit
a3e57a3
·
1 Parent(s): cb66746

debugging checkpoint restoration

Browse files
vms/ui/app_ui.py CHANGED
@@ -214,8 +214,9 @@ class AppUI:
214
  outputs=[
215
  self.project_tabs["caption_tab"].components["training_dataset"],
216
  self.project_tabs["train_tab"].components["start_btn"],
 
217
  self.project_tabs["train_tab"].components["stop_btn"],
218
- self.project_tabs["train_tab"].components["pause_resume_btn"],
219
  self.project_tabs["train_tab"].components["training_preset"],
220
  self.project_tabs["train_tab"].components["model_type"],
221
  self.project_tabs["train_tab"].components["model_version"],
@@ -240,7 +241,7 @@ class AppUI:
240
  # Status update timer for text components (every 1 second)
241
  status_timer = gr.Timer(value=1)
242
  status_timer.tick(
243
- fn=self.project_tabs["train_tab"].get_status_updates, # Use a new function that returns appropriate updates
244
  outputs=[
245
  self.project_tabs["train_tab"].components["status_box"],
246
  self.project_tabs["train_tab"].components["log_box"],
@@ -252,20 +253,23 @@ class AppUI:
252
  button_timer = gr.Timer(value=1)
253
  button_outputs = [
254
  self.project_tabs["train_tab"].components["start_btn"],
255
- self.project_tabs["train_tab"].components["stop_btn"]
 
 
256
  ]
257
 
 
 
 
 
 
 
258
  # Add delete_checkpoints_btn or pause_resume_btn as the third button
259
  if "delete_checkpoints_btn" in self.project_tabs["train_tab"].components:
260
  button_outputs.append(self.project_tabs["train_tab"].components["delete_checkpoints_btn"])
261
  elif "pause_resume_btn" in self.project_tabs["train_tab"].components:
262
  button_outputs.append(self.project_tabs["train_tab"].components["pause_resume_btn"])
263
 
264
- button_timer.tick(
265
- fn=self.project_tabs["train_tab"].get_button_updates, # Use a new function for button-specific updates
266
- outputs=button_outputs
267
- )
268
-
269
  # Dataset refresh timer (every 5 seconds)
270
  dataset_timer = gr.Timer(value=5)
271
  dataset_timer.tick(
@@ -293,9 +297,10 @@ class AppUI:
293
  # Get button states based on recovery status
294
  button_states = self.get_initial_button_states()
295
  start_btn = button_states[0]
296
- stop_btn = button_states[1]
297
- delete_checkpoints_btn = button_states[2] # This replaces pause_resume_btn in the response tuple
298
-
 
299
  # Get UI form values - possibly from the recovery
300
  if self.recovery_status in ["recovered", "ready_to_recover", "running"] and "ui_updates" in self.state["recovery_result"]:
301
  recovery_ui = self.state["recovery_result"]["ui_updates"]
@@ -467,6 +472,7 @@ class AppUI:
467
  return (
468
  training_dataset,
469
  start_btn,
 
470
  stop_btn,
471
  delete_checkpoints_btn,
472
  training_preset,
@@ -543,7 +549,8 @@ class AppUI:
543
  ui_updates = recovery_result.get("ui_updates", {})
544
 
545
  # Check for checkpoints to determine start button text
546
- has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
 
547
 
548
  # Default button states if recovery didn't provide any
549
  if not ui_updates or not ui_updates.get("start_btn"):
@@ -551,27 +558,32 @@ class AppUI:
551
 
552
  if is_training:
553
  # Active training detected
554
- start_btn_props = {"interactive": False, "variant": "secondary", "value": "Continue Training" if has_checkpoints else "Start Training"}
 
555
  stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
556
  delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
557
  else:
558
  # No active training
559
- start_btn_props = {"interactive": True, "variant": "primary", "value": "Continue Training" if has_checkpoints else "Start Training"}
 
560
  stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
561
  delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
562
  else:
563
- # Use button states from recovery
564
- start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "Start Training"})
 
 
565
  stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
566
  delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
567
 
568
  # Return button states in the correct order
569
  return (
570
  gr.Button(**start_btn_props),
 
571
  gr.Button(**stop_btn_props),
572
  gr.Button(**delete_btn_props)
573
  )
574
-
575
  def update_titles(self) -> Tuple[Any]:
576
  """Update all dynamic titles with current counts
577
 
 
214
  outputs=[
215
  self.project_tabs["caption_tab"].components["training_dataset"],
216
  self.project_tabs["train_tab"].components["start_btn"],
217
+ self.project_tabs["train_tab"].components["resume_btn"],
218
  self.project_tabs["train_tab"].components["stop_btn"],
219
+ self.project_tabs["train_tab"].components["delete_checkpoints_btn"],
220
  self.project_tabs["train_tab"].components["training_preset"],
221
  self.project_tabs["train_tab"].components["model_type"],
222
  self.project_tabs["train_tab"].components["model_version"],
 
241
  # Status update timer for text components (every 1 second)
242
  status_timer = gr.Timer(value=1)
243
  status_timer.tick(
244
+ fn=self.project_tabs["train_tab"].get_status_updates,
245
  outputs=[
246
  self.project_tabs["train_tab"].components["status_box"],
247
  self.project_tabs["train_tab"].components["log_box"],
 
253
  button_timer = gr.Timer(value=1)
254
  button_outputs = [
255
  self.project_tabs["train_tab"].components["start_btn"],
256
+ self.project_tabs["train_tab"].components["resume_btn"],
257
+ self.project_tabs["train_tab"].components["stop_btn"],
258
+ self.project_tabs["train_tab"].components["delete_checkpoints_btn"]
259
  ]
260
 
261
+ button_timer.tick(
262
+ fn=self.project_tabs["train_tab"].get_button_updates,
263
+ outputs=button_outputs
264
+ )
265
+
266
+
267
  # Add delete_checkpoints_btn or pause_resume_btn as the third button
268
  if "delete_checkpoints_btn" in self.project_tabs["train_tab"].components:
269
  button_outputs.append(self.project_tabs["train_tab"].components["delete_checkpoints_btn"])
270
  elif "pause_resume_btn" in self.project_tabs["train_tab"].components:
271
  button_outputs.append(self.project_tabs["train_tab"].components["pause_resume_btn"])
272
 
 
 
 
 
 
273
  # Dataset refresh timer (every 5 seconds)
274
  dataset_timer = gr.Timer(value=5)
275
  dataset_timer.tick(
 
297
  # Get button states based on recovery status
298
  button_states = self.get_initial_button_states()
299
  start_btn = button_states[0]
300
+ resume_btn = button_states[1]
301
+ stop_btn = button_states[2]
302
+ delete_checkpoints_btn = button_states[3]
303
+
304
  # Get UI form values - possibly from the recovery
305
  if self.recovery_status in ["recovered", "ready_to_recover", "running"] and "ui_updates" in self.state["recovery_result"]:
306
  recovery_ui = self.state["recovery_result"]["ui_updates"]
 
472
  return (
473
  training_dataset,
474
  start_btn,
475
+ resume_btn,
476
  stop_btn,
477
  delete_checkpoints_btn,
478
  training_preset,
 
549
  ui_updates = recovery_result.get("ui_updates", {})
550
 
551
  # Check for checkpoints to determine start button text
552
+ checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
553
+ has_checkpoints = len(checkpoints) > 0
554
 
555
  # Default button states if recovery didn't provide any
556
  if not ui_updates or not ui_updates.get("start_btn"):
 
558
 
559
  if is_training:
560
  # Active training detected
561
+ start_btn_props = {"interactive": False, "variant": "secondary", "value": "Start new training"}
562
+ resume_btn_props = {"interactive": False, "variant": "secondary", "value": "Start from latest checkpoint"}
563
  stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
564
  delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
565
  else:
566
  # No active training
567
+ start_btn_props = {"interactive": True, "variant": "primary", "value": "Start new training"}
568
+ resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "Start from latest checkpoint"}
569
  stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
570
  delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
571
  else:
572
+ # Use button states from recovery, adding the new resume button
573
+ start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "Start new training"})
574
+ resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
575
+ "variant": "primary", "value": "Start from latest checkpoint"}
576
  stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
577
  delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
578
 
579
  # Return button states in the correct order
580
  return (
581
  gr.Button(**start_btn_props),
582
+ gr.Button(**resume_btn_props), # Add the new resume button
583
  gr.Button(**stop_btn_props),
584
  gr.Button(**delete_btn_props)
585
  )
586
+
587
  def update_titles(self) -> Tuple[Any]:
588
  """Update all dynamic titles with current counts
589
 
vms/ui/project/services/previewing.py CHANGED
@@ -36,7 +36,9 @@ class PreviewingService:
36
  return str(lora_path)
37
 
38
  # If not found in the expected location, try to find in checkpoints
39
- checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
 
 
40
  if not checkpoints:
41
  return None
42
 
 
36
  return str(lora_path)
37
 
38
  # If not found in the expected location, try to find in checkpoints
39
+ checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
40
+ has_checkpoints = len(checkpoints) > 0
41
+
42
  if not checkpoints:
43
  return None
44
 
vms/ui/project/services/training.py CHANGED
@@ -1042,7 +1042,7 @@ class TrainingService:
1042
  ui_updates = {}
1043
 
1044
  # Check for any checkpoints, even if status doesn't indicate training
1045
- checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
1046
  has_checkpoints = len(checkpoints) > 0
1047
 
1048
  # If status indicates training but process isn't running, or if we have checkpoints
@@ -1078,6 +1078,7 @@ class TrainingService:
1078
  }
1079
  logger.info("Created default session from UI state for recovery")
1080
  else:
 
1081
  # Set buttons for no active training
1082
  ui_updates = {
1083
  "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
@@ -1092,8 +1093,9 @@ class TrainingService:
1092
  checkpoint_step = 0
1093
 
1094
  if has_checkpoints:
1095
- latest_checkpoint = max(checkpoints, key=os.path.getmtime)
1096
- checkpoint_step = int(latest_checkpoint.name.split("-")[1])
 
1097
  logger.info(f"Found checkpoint at step {checkpoint_step}")
1098
  else:
1099
  logger.warning("No checkpoints found for recovery")
@@ -1226,7 +1228,7 @@ class TrainingService:
1226
 
1227
  try:
1228
  # Find all checkpoint directories
1229
- checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
1230
 
1231
  if not checkpoints:
1232
  return "No checkpoints found to delete."
 
1042
  ui_updates = {}
1043
 
1044
  # Check for any checkpoints, even if status doesn't indicate training
1045
+ checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
1046
  has_checkpoints = len(checkpoints) > 0
1047
 
1048
  # If status indicates training but process isn't running, or if we have checkpoints
 
1078
  }
1079
  logger.info("Created default session from UI state for recovery")
1080
  else:
1081
+ logger.warning(f"No checkpoints found for recovery")
1082
  # Set buttons for no active training
1083
  ui_updates = {
1084
  "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
 
1093
  checkpoint_step = 0
1094
 
1095
  if has_checkpoints:
1096
+ # Find the latest checkpoint by step number
1097
+ latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
1098
+ checkpoint_step = int(latest_checkpoint.name.split("_")[-1])
1099
  logger.info(f"Found checkpoint at step {checkpoint_step}")
1100
  else:
1101
  logger.warning("No checkpoints found for recovery")
 
1228
 
1229
  try:
1230
  # Find all checkpoint directories
1231
+ checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
1232
 
1233
  if not checkpoints:
1234
  return "No checkpoints found to delete."
vms/ui/project/tabs/manage_tab.py CHANGED
@@ -65,11 +65,43 @@ class ManageTab(BaseTab):
65
 
66
  with gr.Row():
67
  with gr.Column():
68
- gr.Markdown("## Delete your model")
69
- gr.Markdown("If something went wrong, you can trigger a full reset (model shutdown + data destruction).")
70
  gr.Markdown("Make sure you have made a backup first.")
71
  gr.Markdown("If you are deleting because of a bug, remember you can use the Developer Mode on HF to inspect the working directory (in /data or .data)")
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  with gr.Row():
74
  self.components["global_stop_btn"] = gr.Button(
75
  "Stop everything and delete my data",
@@ -103,6 +135,24 @@ class ManageTab(BaseTab):
103
  outputs=[self.components["download_model_btn"]]
104
  )
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # Global stop button
107
  self.components["global_stop_btn"].click(
108
  fn=self.handle_global_stop,
@@ -151,6 +201,91 @@ class ManageTab(BaseTab):
151
  return f"Successfully uploaded model to {repo_id}"
152
  else:
153
  return f"Failed to upload model to {repo_id}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  def handle_global_stop(self):
156
  """Handle the global stop button click"""
 
65
 
66
  with gr.Row():
67
  with gr.Column():
68
+ gr.Markdown("## Delete your data")
 
69
  gr.Markdown("Make sure you have made a backup first.")
70
  gr.Markdown("If you are deleting because of a bug, remember you can use the Developer Mode on HF to inspect the working directory (in /data or .data)")
71
 
72
+ with gr.Row():
73
+ with gr.Column():
74
+ gr.Markdown("### Delete specific data")
75
+ gr.Markdown("You can selectively delete either the dataset and/or the last model data.")
76
+
77
+ with gr.Row():
78
+ with gr.Column(scale=1):
79
+ self.components["delete_dataset_btn"] = gr.Button(
80
+ "Delete dataset (images, video, captions)",
81
+ variant="secondary"
82
+ )
83
+ self.components["delete_dataset_status"] = gr.Textbox(
84
+ label="Delete Dataset Status",
85
+ interactive=False,
86
+ visible=False
87
+ )
88
+
89
+ with gr.Column(scale=1):
90
+ self.components["delete_model_btn"] = gr.Button(
91
+ "Delete model (checkpoints, weights, config)",
92
+ variant="secondary"
93
+ )
94
+ self.components["delete_model_status"] = gr.Textbox(
95
+ label="Delete Model Status",
96
+ interactive=False,
97
+ visible=False
98
+ )
99
+
100
+ with gr.Row():
101
+ with gr.Column():
102
+ gr.Markdown("### Delete everything")
103
+ gr.Markdown("This will delete both the dataset (all images, videos and captions) AND the latest model (weights, checkpoints, settings). So use with care!")
104
+
105
  with gr.Row():
106
  self.components["global_stop_btn"] = gr.Button(
107
  "Stop everything and delete my data",
 
135
  outputs=[self.components["download_model_btn"]]
136
  )
137
 
138
+ # New delete dataset button
139
+ self.components["delete_dataset_btn"].click(
140
+ fn=self.delete_dataset,
141
+ outputs=[
142
+ self.components["delete_dataset_status"],
143
+ self.app.tabs["caption_tab"].components["training_dataset"]
144
+ ]
145
+ )
146
+
147
+ # New delete model button
148
+ self.components["delete_model_btn"].click(
149
+ fn=self.delete_model,
150
+ outputs=[
151
+ self.components["delete_model_status"],
152
+ self.app.tabs["train_tab"].components["status_box"]
153
+ ]
154
+ )
155
+
156
  # Global stop button
157
  self.components["global_stop_btn"].click(
158
  fn=self.handle_global_stop,
 
201
  return f"Successfully uploaded model to {repo_id}"
202
  else:
203
  return f"Failed to upload model to {repo_id}"
204
+
205
+ def delete_dataset(self):
206
+ """Delete dataset files (images, videos, captions)"""
207
+ status_messages = {}
208
+
209
+ try:
210
+ # Stop captioning if running
211
+ if self.app.captioning:
212
+ self.app.captioning.stop_captioning()
213
+ status_messages["captioning"] = "Captioning stopped"
214
+
215
+ # Stop scene detection if running
216
+ if self.app.splitting.is_processing():
217
+ self.app.splitting.processing = False
218
+ status_messages["splitting"] = "Scene detection stopped"
219
+
220
+ # Clear dataset directories
221
+ for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH]:
222
+ if path.exists():
223
+ try:
224
+ shutil.rmtree(path)
225
+ path.mkdir(parents=True, exist_ok=True)
226
+ except Exception as e:
227
+ status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}"
228
+ else:
229
+ status_messages[f"clear_{path.name}"] = f"Cleared {path.name}"
230
+
231
+ # Reset any relevant persistent state
232
+ self.app.tabs["caption_tab"]._should_stop_captioning = True
233
+ self.app.splitting.processing = False
234
+
235
+ # Format response
236
+ details = "\n".join(f"{k}: {v}" for k, v in status_messages.items())
237
+ message = f"Dataset deleted successfully\n\nDetails:\n{details}"
238
+
239
+ # Get fresh lists after cleanup
240
+ clips = self.app.tabs["caption_tab"].list_training_files_to_caption()
241
+
242
+ return gr.update(value=message, visible=True), clips
243
+
244
+ except Exception as e:
245
+ error_message = f"Error deleting dataset: {str(e)}\n\nDetails:\n{status_messages}"
246
+ return gr.update(value=error_message, visible=True), self.app.tabs["caption_tab"].list_training_files_to_caption()
247
+
248
+ def delete_model(self):
249
+ """Delete model files (checkpoints, weights, configuration)"""
250
+ status_messages = {}
251
+
252
+ try:
253
+ # Stop training if running
254
+ if self.app.training.is_training_running():
255
+ training_result = self.app.training.stop_training()
256
+ status_messages["training"] = training_result["status"]
257
+
258
+ # Clear model output directory
259
+ if OUTPUT_PATH.exists():
260
+ try:
261
+ shutil.rmtree(OUTPUT_PATH)
262
+ OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
263
+ except Exception as e:
264
+ status_messages[f"clear_{OUTPUT_PATH.name}"] = f"Error clearing {OUTPUT_PATH.name}: {str(e)}"
265
+ else:
266
+ status_messages[f"clear_{OUTPUT_PATH.name}"] = f"Cleared {OUTPUT_PATH.name}"
267
+
268
+ # Properly close logging before clearing log file
269
+ if self.app.training.file_handler:
270
+ self.app.training.file_handler.close()
271
+ logger.removeHandler(self.app.training.file_handler)
272
+ self.app.training.file_handler = None
273
+
274
+ if LOG_FILE_PATH.exists():
275
+ LOG_FILE_PATH.unlink()
276
+
277
+ # Reset training UI state
278
+ self.app.training.setup_logging()
279
+
280
+ # Format response
281
+ details = "\n".join(f"{k}: {v}" for k, v in status_messages.items())
282
+ message = f"Model deleted successfully\n\nDetails:\n{details}"
283
+
284
+ return gr.update(value=message, visible=True), "Model files have been deleted"
285
+
286
+ except Exception as e:
287
+ error_message = f"Error deleting model: {str(e)}\n\nDetails:\n{status_messages}"
288
+ return gr.update(value=error_message, visible=True), f"Error deleting model: {str(e)}"
289
 
290
  def handle_global_stop(self):
291
  """Handle the global stop button click"""
vms/ui/project/tabs/preview_tab.py CHANGED
@@ -219,7 +219,8 @@ class PreviewTab(BaseTab):
219
  return True
220
 
221
  # If not found in the expected location, try to find in checkpoints
222
- checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
 
223
  if not checkpoints:
224
  return False
225
 
 
219
  return True
220
 
221
  # If not found in the expected location, try to find in checkpoints
222
+ checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
223
+ has_checkpoints = len(checkpoints) > 0
224
  if not checkpoints:
225
  return False
226
 
vms/ui/project/tabs/train_tab.py CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
6
  import logging
7
  import os
8
  import json
 
9
  from typing import Dict, Any, List, Optional, Tuple
10
  from pathlib import Path
11
 
@@ -177,39 +178,58 @@ class TrainTab(BaseTab):
177
  precision=0,
178
  info="Number of warmup steps (typically 20-40% of total training steps). This helps reducing the impact of early training examples as well as giving time to optimizers to compute accurate statistics."
179
  )
180
- with gr.Column():
181
- with gr.Row():
182
- # Check for existing checkpoints to determine button text
183
- has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
184
- start_text = "Continue Training" if has_checkpoints else "Start Training"
185
-
186
- self.components["start_btn"] = gr.Button(
187
- start_text,
188
- variant="primary",
189
- interactive=not ASK_USER_TO_DUPLICATE_SPACE
190
- )
191
-
192
- # Just use stop and pause buttons for now to ensure compatibility
193
- self.components["stop_btn"] = gr.Button(
194
- "Stop at Last Checkpoint",
195
- variant="primary",
196
- interactive=False
197
- )
198
-
199
- self.components["pause_resume_btn"] = gr.Button(
200
- "Resume Training",
201
- variant="secondary",
202
- interactive=False,
203
- visible=False
204
- )
205
 
206
- # Add delete checkpoints button
207
- self.components["delete_checkpoints_btn"] = gr.Button(
208
- "Delete All Checkpoints",
209
- variant="stop",
210
- interactive=True
211
- )
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  with gr.Row():
214
  with gr.Column():
215
  self.components["status_box"] = gr.Textbox(
@@ -226,12 +246,12 @@ class TrainTab(BaseTab):
226
  elem_id="current_task_display"
227
  )
228
 
229
- with gr.Accordion("See training logs"):
230
  self.components["log_box"] = gr.TextArea(
231
- label="Finetrainers output (see HF Space logs for more details)",
232
  interactive=False,
233
- lines=40,
234
- max_lines=200,
235
  autoscroll=True
236
  )
237
 
@@ -268,6 +288,55 @@ class TrainTab(BaseTab):
268
  self.app.update_ui_state(model_type=model_type, model_version=model_version)
269
  return None
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  def connect_events(self) -> None:
272
  """Connect event handlers to UI components"""
273
  # Model type change event - Update model version dropdown choices
@@ -396,11 +465,11 @@ class TrainTab(BaseTab):
396
 
397
  # Training control events
398
  self.components["start_btn"].click(
399
- fn=self.handle_training_start,
400
  inputs=[
401
  self.components["training_preset"],
402
  self.components["model_type"],
403
- self.components["model_version"], # Add model_version to the inputs
404
  self.components["training_type"],
405
  self.components["lora_rank"],
406
  self.components["lora_alpha"],
@@ -416,6 +485,28 @@ class TrainTab(BaseTab):
416
  ]
417
  )
418
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  # Use simplified event handlers for pause/resume and stop
420
  third_btn = self.components["delete_checkpoints_btn"] if "delete_checkpoints_btn" in self.components else self.components["pause_resume_btn"]
421
 
@@ -500,7 +591,8 @@ class TrainTab(BaseTab):
500
  self.app.log_parser = TrainingLogParser()
501
 
502
  # Check for latest checkpoint
503
- checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
 
504
  resume_from = None
505
 
506
  if checkpoints:
@@ -863,43 +955,40 @@ class TrainTab(BaseTab):
863
  status, _, _ = self.get_latest_status_message_and_logs()
864
 
865
  # Add checkpoints detection
866
- has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
 
867
 
868
  is_training = status in ["training", "initializing"]
869
  is_completed = status in ["completed", "error", "stopped"]
870
 
871
- start_text = "Continue Training" if has_checkpoints else "Start Training"
872
-
873
  # Create button updates
874
  start_btn = gr.Button(
875
- value=start_text,
876
  interactive=not is_training,
877
  variant="primary" if not is_training else "secondary"
878
  )
879
 
 
 
 
 
 
 
880
  stop_btn = gr.Button(
881
  value="Stop at Last Checkpoint",
882
  interactive=is_training,
883
  variant="primary" if is_training else "secondary"
884
  )
885
 
886
- # Add delete_checkpoints_btn or pause_resume_btn
887
- if "delete_checkpoints_btn" in self.components:
888
- third_btn = gr.Button(
889
- "Delete All Checkpoints",
890
- interactive=has_checkpoints and not is_training,
891
- variant="stop"
892
- )
893
- else:
894
- third_btn = gr.Button(
895
- "Resume Training" if status == "paused" else "Pause Training",
896
- interactive=(is_training or status == "paused") and not is_completed,
897
- variant="secondary",
898
- visible=False
899
- )
900
-
901
- return start_btn, stop_btn, third_btn
902
 
 
 
903
  def update_training_ui(self, training_state: Dict[str, Any]):
904
  """Update UI components based on training state"""
905
  updates = {}
 
6
  import logging
7
  import os
8
  import json
9
+ import shutil
10
  from typing import Dict, Any, List, Optional, Tuple
11
  from pathlib import Path
12
 
 
178
  precision=0,
179
  info="Number of warmup steps (typically 20-40% of total training steps). This helps reducing the impact of early training examples as well as giving time to optimizers to compute accurate statistics."
180
  )
181
+
182
+ with gr.Row():
183
+ with gr.Column():
184
+ # Add description of the training buttons
185
+ self.components["training_buttons_info"] = gr.Markdown("""
186
+ ## Training Options
187
+ - **Start new training**: Begins training from scratch (clears previous checkpoints)
188
+ - **Start from latest checkpoint**: Continues training from the most recent checkpoint
189
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
+ with gr.Row():
192
+ # Check for existing checkpoints to determine button text
193
+ checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
194
+ has_checkpoints = len(checkpoints) > 0
 
 
195
 
196
+ # Rename "Start Training" to "Start new training"
197
+ self.components["start_btn"] = gr.Button(
198
+ "Start new training",
199
+ variant="primary",
200
+ interactive=not ASK_USER_TO_DUPLICATE_SPACE
201
+ )
202
+
203
+ # Add new button for continuing from checkpoint
204
+ self.components["resume_btn"] = gr.Button(
205
+ "Start from latest checkpoint",
206
+ variant="primary",
207
+ interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
208
+ )
209
+
210
+ with gr.Row():
211
+ # Just use stop and pause buttons for now to ensure compatibility
212
+ self.components["stop_btn"] = gr.Button(
213
+ "Stop at Last Checkpoint",
214
+ variant="primary",
215
+ interactive=False
216
+ )
217
+
218
+ self.components["pause_resume_btn"] = gr.Button(
219
+ "Resume Training",
220
+ variant="secondary",
221
+ interactive=False,
222
+ visible=False
223
+ )
224
+
225
+ # Add delete checkpoints button
226
+ self.components["delete_checkpoints_btn"] = gr.Button(
227
+ "Delete All Checkpoints",
228
+ variant="stop",
229
+ interactive=has_checkpoints
230
+ )
231
+
232
+ with gr.Column():
233
  with gr.Row():
234
  with gr.Column():
235
  self.components["status_box"] = gr.Textbox(
 
246
  elem_id="current_task_display"
247
  )
248
 
249
+ with gr.Accordion("Finetrainers output (or see app logs for more details)"):
250
  self.components["log_box"] = gr.TextArea(
251
+ #label="",
252
  interactive=False,
253
+ lines=60,
254
+ max_lines=600,
255
  autoscroll=True
256
  )
257
 
 
288
  self.app.update_ui_state(model_type=model_type, model_version=model_version)
289
  return None
290
 
291
+ def handle_new_training_start(
292
+ self, preset, model_type, model_version, training_type,
293
+ lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
294
+ save_iterations, repo_id, progress=gr.Progress()
295
+ ):
296
+ """Handle new training start with checkpoint cleanup"""
297
+ # Clear output directory to start fresh
298
+
299
+ # Delete all checkpoint directories
300
+ for checkpoint in OUTPUT_PATH.glob("finetrainers_step_*"):
301
+ if checkpoint.is_dir():
302
+ shutil.rmtree(checkpoint)
303
+
304
+ # Also delete session.json which contains previous training info
305
+ session_file = OUTPUT_PATH / "session.json"
306
+ if session_file.exists():
307
+ session_file.unlink()
308
+
309
+ self.append_log("Cleared previous checkpoints for new training session")
310
+
311
+ # Start training normally
312
+ return self.handle_training_start(
313
+ preset, model_type, model_version, training_type,
314
+ lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
315
+ save_iterations, repo_id, progress
316
+ )
317
+
318
+ def handle_resume_training(
319
+ self, preset, model_type, model_version, training_type,
320
+ lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
321
+ save_iterations, repo_id, progress=gr.Progress()
322
+ ):
323
+ """Handle resuming training from the latest checkpoint"""
324
+ # Find the latest checkpoint
325
+ checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
326
+
327
+ if not checkpoints:
328
+ return "No checkpoints found to resume from", "Please start a new training session instead"
329
+
330
+ self.append_log(f"Resuming training from latest checkpoint")
331
+
332
+ # Start training with the checkpoint
333
+ return self.handle_training_start(
334
+ preset, model_type, model_version, training_type,
335
+ lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
336
+ save_iterations, repo_id, progress,
337
+ resume_from_checkpoint="latest"
338
+ )
339
+
340
  def connect_events(self) -> None:
341
  """Connect event handlers to UI components"""
342
  # Model type change event - Update model version dropdown choices
 
465
 
466
  # Training control events
467
  self.components["start_btn"].click(
468
+ fn=self.handle_new_training_start,
469
  inputs=[
470
  self.components["training_preset"],
471
  self.components["model_type"],
472
+ self.components["model_version"],
473
  self.components["training_type"],
474
  self.components["lora_rank"],
475
  self.components["lora_alpha"],
 
485
  ]
486
  )
487
 
488
+ self.components["resume_btn"].click(
489
+ fn=self.handle_resume_training,
490
+ inputs=[
491
+ self.components["training_preset"],
492
+ self.components["model_type"],
493
+ self.components["model_version"],
494
+ self.components["training_type"],
495
+ self.components["lora_rank"],
496
+ self.components["lora_alpha"],
497
+ self.components["train_steps"],
498
+ self.components["batch_size"],
499
+ self.components["learning_rate"],
500
+ self.components["save_iterations"],
501
+ self.app.tabs["manage_tab"].components["repo_id"]
502
+ ],
503
+ outputs=[
504
+ self.components["status_box"],
505
+ self.components["log_box"]
506
+ ]
507
+ )
508
+
509
+
510
  # Use simplified event handlers for pause/resume and stop
511
  third_btn = self.components["delete_checkpoints_btn"] if "delete_checkpoints_btn" in self.components else self.components["pause_resume_btn"]
512
 
 
591
  self.app.log_parser = TrainingLogParser()
592
 
593
  # Check for latest checkpoint
594
+ checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
595
+ has_checkpoints = len(checkpoints) > 0
596
  resume_from = None
597
 
598
  if checkpoints:
 
955
  status, _, _ = self.get_latest_status_message_and_logs()
956
 
957
  # Add checkpoints detection
958
+ checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
959
+ has_checkpoints = len(checkpoints) > 0
960
 
961
  is_training = status in ["training", "initializing"]
962
  is_completed = status in ["completed", "error", "stopped"]
963
 
 
 
964
  # Create button updates
965
  start_btn = gr.Button(
966
+ value="Start new training",
967
  interactive=not is_training,
968
  variant="primary" if not is_training else "secondary"
969
  )
970
 
971
+ resume_btn = gr.Button(
972
+ value="Start from latest checkpoint",
973
+ interactive=has_checkpoints and not is_training,
974
+ variant="primary" if not is_training else "secondary"
975
+ )
976
+
977
  stop_btn = gr.Button(
978
  value="Stop at Last Checkpoint",
979
  interactive=is_training,
980
  variant="primary" if is_training else "secondary"
981
  )
982
 
983
+ # Add delete_checkpoints_btn
984
+ delete_checkpoints_btn = gr.Button(
985
+ "Delete All Checkpoints",
986
+ interactive=has_checkpoints and not is_training,
987
+ variant="stop"
988
+ )
 
 
 
 
 
 
 
 
 
 
989
 
990
+ return start_btn, resume_btn, stop_btn, delete_checkpoints_btn
991
+
992
  def update_training_ui(self, training_state: Dict[str, Any]):
993
  """Update UI components based on training state"""
994
  updates = {}