jbilcke-hf HF Staff commited on
Commit
adc5756
·
1 Parent(s): 892fa67

working on a fix

Browse files
vms/services/trainer.py CHANGED
@@ -353,7 +353,7 @@ class TrainingService:
353
  resume_from_checkpoint: Optional[str] = None,
354
  ) -> Tuple[str, str]:
355
  """Start training with finetrainers"""
356
-
357
  self.clear_logs()
358
 
359
  if not model_type:
@@ -365,22 +365,31 @@ class TrainingService:
365
  is_resuming = resume_from_checkpoint is not None
366
  log_prefix = "Resuming" if is_resuming else "Initializing"
367
  logger.info(f"{log_prefix} training with model_type={model_type}")
368
- self.append_log(f"{log_prefix} training with model_type={model_type}")
369
-
370
- if is_resuming:
371
- self.append_log(f"Resuming from checkpoint: {resume_from_checkpoint}")
372
 
373
  try:
374
- # Get absolute paths
375
- current_dir = Path(__file__).parent.absolute()
376
- train_script = current_dir.parent / "train.py"
377
-
378
 
379
  if not train_script.exists():
380
- error_msg = f"Training script not found at {train_script}"
381
- logger.error(error_msg)
382
- return error_msg, "Training script not found"
 
 
 
 
 
 
 
 
 
383
 
 
 
 
 
 
384
  # Log paths for debugging
385
  logger.info("Current working directory: %s", current_dir)
386
  logger.info("Training script path: %s", train_script)
 
353
  resume_from_checkpoint: Optional[str] = None,
354
  ) -> Tuple[str, str]:
355
  """Start training with finetrainers"""
356
+
357
  self.clear_logs()
358
 
359
  if not model_type:
 
365
  is_resuming = resume_from_checkpoint is not None
366
  log_prefix = "Resuming" if is_resuming else "Initializing"
367
  logger.info(f"{log_prefix} training with model_type={model_type}")
 
 
 
 
368
 
369
  try:
370
+ # Get absolute paths - FIXED to look in project root instead of within vms directory
371
+ current_dir = Path(__file__).parent.parent.parent.absolute() # Go up to project root
372
+ train_script = current_dir / "train.py"
 
373
 
374
  if not train_script.exists():
375
+ # Try alternative locations
376
+ alt_locations = [
377
+ current_dir.parent / "train.py", # One level up from project root
378
+ Path("/home/user/app/train.py"), # Absolute path
379
+ Path("train.py") # Current working directory
380
+ ]
381
+
382
+ for alt_path in alt_locations:
383
+ if alt_path.exists():
384
+ train_script = alt_path
385
+ logger.info(f"Found train.py at alternative location: {train_script}")
386
+ break
387
 
388
+ if not train_script.exists():
389
+ error_msg = f"Training script not found at {train_script} or any alternative locations"
390
+ logger.error(error_msg)
391
+ return error_msg, "Training script not found"
392
+
393
  # Log paths for debugging
394
  logger.info("Current working directory: %s", current_dir)
395
  logger.info("Training script path: %s", train_script)
vms/tabs/train_tab.py CHANGED
@@ -91,20 +91,35 @@ class TrainTab(BaseTab):
91
 
92
  with gr.Column():
93
  with gr.Row():
 
 
 
 
94
  self.components["start_btn"] = gr.Button(
95
- "Start Training",
96
  variant="primary",
97
  interactive=not ASK_USER_TO_DUPLICATE_SPACE
98
  )
 
 
 
 
 
 
 
 
99
  self.components["pause_resume_btn"] = gr.Button(
100
  "Resume Training",
101
  variant="secondary",
102
- interactive=False
 
103
  )
104
- self.components["stop_btn"] = gr.Button(
105
- "Stop Training",
 
 
106
  variant="stop",
107
- interactive=False
108
  )
109
 
110
  with gr.Row():
@@ -468,31 +483,56 @@ class TrainTab(BaseTab):
468
 
469
  return (state["status"], state["message"], logs)
470
 
471
- def get_latest_status_message_logs_and_button_labels(self) -> Tuple[str, str, Any, Any, Any]:
 
472
  status, message, logs = self.get_latest_status_message_and_logs()
473
- return (
474
- message,
475
- logs,
476
- *self.update_training_buttons(status).values()
477
- )
478
-
479
- def update_training_buttons(self, status: str) -> Dict:
 
 
 
480
  """Update training control buttons based on state"""
 
 
 
481
  is_training = status in ["training", "initializing"]
482
- is_paused = status == "paused"
483
  is_completed = status in ["completed", "error", "stopped"]
484
- return {
 
 
 
 
485
  "start_btn": gr.Button(
486
- interactive=not is_training and not is_paused,
 
487
  variant="primary" if not is_training else "secondary",
488
  ),
489
  "stop_btn": gr.Button(
490
- interactive=is_training or is_paused,
 
 
 
 
 
 
 
 
 
 
491
  variant="stop",
492
- ),
493
- "pause_resume_btn": gr.Button(
494
- value="Resume Training" if is_paused else "Pause Training",
495
- interactive=(is_training or is_paused) and not is_completed,
 
 
496
  variant="secondary",
 
497
  )
498
- }
 
 
91
 
92
  with gr.Column():
93
  with gr.Row():
94
+ # Check for existing checkpoints to determine button text
95
+ has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
96
+ start_text = "Continue Training" if has_checkpoints else "Start Training"
97
+
98
  self.components["start_btn"] = gr.Button(
99
+ start_text,
100
  variant="primary",
101
  interactive=not ASK_USER_TO_DUPLICATE_SPACE
102
  )
103
+
104
+ # Just use stop and pause buttons for now to ensure compatibility
105
+ self.components["stop_btn"] = gr.Button(
106
+ "Stop at Last Checkpoint",
107
+ variant="primary",
108
+ interactive=False
109
+ )
110
+
111
  self.components["pause_resume_btn"] = gr.Button(
112
  "Resume Training",
113
  variant="secondary",
114
+ interactive=False,
115
+ visible=False
116
  )
117
+
118
+ # Add delete checkpoints button - THIS IS THE KEY FIX
119
+ self.components["delete_checkpoints_btn"] = gr.Button(
120
+ "Delete All Checkpoints",
121
  variant="stop",
122
+ interactive=True
123
  )
124
 
125
  with gr.Row():
 
483
 
484
  return (state["status"], state["message"], logs)
485
 
486
+ def get_latest_status_message_logs_and_button_labels(self) -> Tuple:
487
+ """Get latest status message, logs and button states"""
488
  status, message, logs = self.get_latest_status_message_and_logs()
489
+
490
+ # Add checkpoints detection
491
+ has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
492
+
493
+ button_updates = self.update_training_buttons(status, has_checkpoints).values()
494
+
495
+ # Return in order expected by timer
496
+ return (message, logs, *button_updates)
497
+
498
+ def update_training_buttons(self, status: str, has_checkpoints: bool = None) -> Dict:
499
  """Update training control buttons based on state"""
500
+ if has_checkpoints is None:
501
+ has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
502
+
503
  is_training = status in ["training", "initializing"]
 
504
  is_completed = status in ["completed", "error", "stopped"]
505
+
506
+ start_text = "Continue Training" if has_checkpoints else "Start Training"
507
+
508
+ # Only include buttons that we know exist in components
509
+ result = {
510
  "start_btn": gr.Button(
511
+ value=start_text,
512
+ interactive=not is_training,
513
  variant="primary" if not is_training else "secondary",
514
  ),
515
  "stop_btn": gr.Button(
516
+ value="Stop at Last Checkpoint",
517
+ interactive=is_training,
518
+ variant="primary" if is_training else "secondary",
519
+ )
520
+ }
521
+
522
+ # Add delete_checkpoints_btn only if it exists in components
523
+ if "delete_checkpoints_btn" in self.components:
524
+ result["delete_checkpoints_btn"] = gr.Button(
525
+ value="Delete All Checkpoints",
526
+ interactive=has_checkpoints and not is_training,
527
  variant="stop",
528
+ )
529
+ else:
530
+ # Add pause_resume_btn as fallback
531
+ result["pause_resume_btn"] = gr.Button(
532
+ value="Resume Training" if status == "paused" else "Pause Training",
533
+ interactive=(is_training or status == "paused") and not is_completed,
534
  variant="secondary",
535
+ visible=False
536
  )
537
+
538
+ return result
vms/ui/video_trainer_ui.py CHANGED
@@ -100,15 +100,25 @@ class VideoTrainerUI:
100
  """Add auto-refresh timers to the UI"""
101
  # Status update timer (every 1 second)
102
  status_timer = gr.Timer(value=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  status_timer.tick(
104
  fn=self.tabs["train_tab"].get_latest_status_message_logs_and_button_labels,
105
- outputs=[
106
- self.tabs["train_tab"].components["status_box"],
107
- self.tabs["train_tab"].components["log_box"],
108
- self.tabs["train_tab"].components["start_btn"],
109
- self.tabs["train_tab"].components["stop_btn"],
110
- self.tabs["train_tab"].components["delete_checkpoints_btn"] # Replace pause_resume_btn
111
- ]
112
  )
113
 
114
  # Dataset refresh timer (every 5 seconds)
 
100
  """Add auto-refresh timers to the UI"""
101
  # Status update timer (every 1 second)
102
  status_timer = gr.Timer(value=1)
103
+
104
+ # Use a safer approach - check if the component exists before using it
105
+ outputs = [
106
+ self.tabs["train_tab"].components["status_box"],
107
+ self.tabs["train_tab"].components["log_box"],
108
+ self.tabs["train_tab"].components["start_btn"],
109
+ self.tabs["train_tab"].components["stop_btn"]
110
+ ]
111
+
112
+ # Add delete_checkpoints_btn only if it exists
113
+ if "delete_checkpoints_btn" in self.tabs["train_tab"].components:
114
+ outputs.append(self.tabs["train_tab"].components["delete_checkpoints_btn"])
115
+ else:
116
+ # Add pause_resume_btn as fallback
117
+ outputs.append(self.tabs["train_tab"].components["pause_resume_btn"])
118
+
119
  status_timer.tick(
120
  fn=self.tabs["train_tab"].get_latest_status_message_logs_and_button_labels,
121
+ outputs=outputs
 
 
 
 
 
 
122
  )
123
 
124
  # Dataset refresh timer (every 5 seconds)