jbilcke commited on
Commit
1042322
·
1 Parent(s): 57737a0

fixed some bugs with finetrainers CLI params

Browse files
vms/config.py CHANGED
@@ -485,6 +485,9 @@ class TrainingConfig:
485
  if self.precompute_conditions:
486
  args.append("--precompute_conditions")
487
 
 
 
 
488
  # Diffusion arguments
489
  if self.flow_resolution_shifting:
490
  args.append("--flow_resolution_shifting")
 
485
  if self.precompute_conditions:
486
  args.append("--precompute_conditions")
487
 
488
+ if hasattr(self, 'precomputation_items') and self.precomputation_items:
489
+ args.extend(["--precomputation_items", str(self.precomputation_items)])
490
+
491
  # Diffusion arguments
492
  if self.flow_resolution_shifting:
493
  args.append("--flow_resolution_shifting")
vms/services/trainer.py CHANGED
@@ -52,7 +52,10 @@ from ..utils import (
52
  logger = logging.getLogger(__name__)
53
 
54
  class TrainingService:
55
- def __init__(self):
 
 
 
56
  # State and log files
57
  self.session_file = OUTPUT_PATH / "session.json"
58
  self.status_file = OUTPUT_PATH / "status.json"
@@ -565,8 +568,8 @@ class TrainingService:
565
  logger.info(f"{log_prefix} training with model_type={model_type}, training_type={training_type}")
566
 
567
  # Update progress if available
568
- if progress:
569
- progress(0.15, desc="Setting up training configuration")
570
 
571
  try:
572
  # Get absolute paths - FIXED to look in project root instead of within vms directory
@@ -598,8 +601,8 @@ class TrainingService:
598
  logger.info("Training data path: %s", TRAINING_PATH)
599
 
600
  # Update progress
601
- if progress:
602
- progress(0.2, desc="Preparing training dataset")
603
 
604
  videos_file, prompts_file = prepare_finetrainers_dataset()
605
  if videos_file is None or prompts_file is None:
@@ -616,8 +619,8 @@ class TrainingService:
616
  return error_msg, "No training data available"
617
 
618
  # Update progress
619
- if progress:
620
- progress(0.25, desc="Creating dataset configuration")
621
 
622
  # Get preset configuration
623
  preset = TRAINING_PRESETS[preset_name]
@@ -627,13 +630,14 @@ class TrainingService:
627
 
628
  # Get the custom prompt prefix from the tabs
629
  custom_prompt_prefix = None
630
- if hasattr(self.app, 'tabs') and 'caption_tab' in self.app.tabs:
631
- if hasattr(self.app.tabs['caption_tab'], 'components') and 'custom_prompt_prefix' in self.app.tabs['caption_tab'].components:
632
- # Get the value and clean it
633
- prefix = self.app.tabs['caption_tab'].components['custom_prompt_prefix'].value
634
- if prefix:
635
- # Clean the prefix - remove trailing comma, space or comma+space
636
- custom_prompt_prefix = prefix.rstrip(', ')
 
637
 
638
  # Create a proper dataset configuration JSON file
639
  dataset_config_file = OUTPUT_PATH / "dataset_config.json"
@@ -725,10 +729,7 @@ class TrainingService:
725
  config.flow_weighting_scheme = flow_weighting_scheme
726
 
727
  config.lr_warmup_steps = int(lr_warmup_steps)
728
- config_args.extend([
729
- "--precomputation_items", str(precomputation_items)
730
- ])
731
-
732
  # Update the NUM_GPUS variable and CUDA_VISIBLE_DEVICES
733
  num_gpus = min(num_gpus, get_available_gpu_count())
734
  if num_gpus <= 0:
@@ -757,6 +758,8 @@ class TrainingService:
757
  config.enable_tiling = True
758
  config.caption_dropout_p = DEFAULT_CAPTION_DROPOUT_P
759
 
 
 
760
  validation_error = self.validate_training_config(config, model_type)
761
  if validation_error:
762
  error_msg = f"Configuration validation failed: {validation_error}"
@@ -843,8 +846,8 @@ class TrainingService:
843
  env["FINETRAINERS_LOG_LEVEL"] = "DEBUG" # Added for better debugging
844
  env["CUDA_VISIBLE_DEVICES"] = visible_devices
845
 
846
- if progress:
847
- progress(0.9, desc="Launching training process")
848
 
849
  # Start the training process
850
  process = subprocess.Popen(
@@ -901,8 +904,8 @@ class TrainingService:
901
  logger.info(success_msg)
902
 
903
  # Final progress update - now we'll track it through the log monitor
904
- if progress:
905
- progress(1.0, desc="Training started successfully")
906
 
907
  return success_msg, self.get_logs()
908
 
 
52
  logger = logging.getLogger(__name__)
53
 
54
  class TrainingService:
55
+ def __init__(self, app=None):
56
+ # Store reference to app
57
+ self.app = app
58
+
59
  # State and log files
60
  self.session_file = OUTPUT_PATH / "session.json"
61
  self.status_file = OUTPUT_PATH / "status.json"
 
568
  logger.info(f"{log_prefix} training with model_type={model_type}, training_type={training_type}")
569
 
570
  # Update progress if available
571
+ #if progress:
572
+ # progress(0.15, desc="Setting up training configuration")
573
 
574
  try:
575
  # Get absolute paths - FIXED to look in project root instead of within vms directory
 
601
  logger.info("Training data path: %s", TRAINING_PATH)
602
 
603
  # Update progress
604
+ #if progress:
605
+ # progress(0.2, desc="Preparing training dataset")
606
 
607
  videos_file, prompts_file = prepare_finetrainers_dataset()
608
  if videos_file is None or prompts_file is None:
 
619
  return error_msg, "No training data available"
620
 
621
  # Update progress
622
+ #if progress:
623
+ # progress(0.25, desc="Creating dataset configuration")
624
 
625
  # Get preset configuration
626
  preset = TRAINING_PRESETS[preset_name]
 
630
 
631
  # Get the custom prompt prefix from the tabs
632
  custom_prompt_prefix = None
633
+ if hasattr(self, 'app') and self.app is not None:
634
+ if hasattr(self.app, 'tabs') and 'caption_tab' in self.app.tabs:
635
+ if hasattr(self.app.tabs['caption_tab'], 'components') and 'custom_prompt_prefix' in self.app.tabs['caption_tab'].components:
636
+ # Get the value and clean it
637
+ prefix = self.app.tabs['caption_tab'].components['custom_prompt_prefix'].value
638
+ if prefix:
639
+ # Clean the prefix - remove trailing comma, space or comma+space
640
+ custom_prompt_prefix = prefix.rstrip(', ')
641
 
642
  # Create a proper dataset configuration JSON file
643
  dataset_config_file = OUTPUT_PATH / "dataset_config.json"
 
729
  config.flow_weighting_scheme = flow_weighting_scheme
730
 
731
  config.lr_warmup_steps = int(lr_warmup_steps)
732
+
 
 
 
733
  # Update the NUM_GPUS variable and CUDA_VISIBLE_DEVICES
734
  num_gpus = min(num_gpus, get_available_gpu_count())
735
  if num_gpus <= 0:
 
758
  config.enable_tiling = True
759
  config.caption_dropout_p = DEFAULT_CAPTION_DROPOUT_P
760
 
761
+ config.precomputation_items = precomputation_items
762
+
763
  validation_error = self.validate_training_config(config, model_type)
764
  if validation_error:
765
  error_msg = f"Configuration validation failed: {validation_error}"
 
846
  env["FINETRAINERS_LOG_LEVEL"] = "DEBUG" # Added for better debugging
847
  env["CUDA_VISIBLE_DEVICES"] = visible_devices
848
 
849
+ #if progress:
850
+ # progress(0.9, desc="Launching training process")
851
 
852
  # Start the training process
853
  process = subprocess.Popen(
 
904
  logger.info(success_msg)
905
 
906
  # Final progress update - now we'll track it through the log monitor
907
+ #if progress:
908
+ # progress(1.0, desc="Training started successfully")
909
 
910
  return success_msg, self.get_logs()
911
 
vms/tabs/train_tab.py CHANGED
@@ -384,7 +384,9 @@ class TrainTab(BaseTab):
384
  outputs=[self.components["status_box"]]
385
  )
386
 
387
- def handle_training_start(self, preset, model_type, training_type, *args, progress=gr.Progress()):
 
 
388
  """Handle training start with proper log parser reset and checkpoint detection"""
389
  # Safely reset log parser if it exists
390
  if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
@@ -395,7 +397,7 @@ class TrainTab(BaseTab):
395
  self.app.log_parser = TrainingLogParser()
396
 
397
  # Initialize progress
398
- progress(0, desc="Initializing training")
399
 
400
  # Check for latest checkpoint
401
  checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
@@ -406,9 +408,10 @@ class TrainTab(BaseTab):
406
  latest_checkpoint = max(checkpoints, key=os.path.getmtime)
407
  resume_from = str(latest_checkpoint)
408
  logger.info(f"Found checkpoint at {resume_from}, will resume training")
409
- progress(0.05, desc=f"Resuming from checkpoint {Path(resume_from).name}")
410
  else:
411
- progress(0.05, desc="Starting new training run")
 
412
 
413
  # Convert model_type display name to internal name
414
  model_internal_type = MODEL_TYPES.get(model_type)
@@ -424,8 +427,13 @@ class TrainTab(BaseTab):
424
  logger.error(f"Invalid training type: {training_type}")
425
  return f"Error: Invalid training type '{training_type}'", "Training type not recognized"
426
 
 
 
 
 
 
427
  # Progress update
428
- progress(0.1, desc="Preparing dataset")
429
 
430
  # Start training (it will automatically use the checkpoint if provided)
431
  try:
 
384
  outputs=[self.components["status_box"]]
385
  )
386
 
387
+ def handle_training_start(
388
+ self, preset, model_type, training_type, lora_rank, lora_alpha, train_steps, batch_size, learning_rate, save_iterations, repo_id, progress=gr.Progress()
389
+ ):
390
  """Handle training start with proper log parser reset and checkpoint detection"""
391
  # Safely reset log parser if it exists
392
  if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
 
397
  self.app.log_parser = TrainingLogParser()
398
 
399
  # Initialize progress
400
+ #progress(0, desc="Initializing training")
401
 
402
  # Check for latest checkpoint
403
  checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
 
408
  latest_checkpoint = max(checkpoints, key=os.path.getmtime)
409
  resume_from = str(latest_checkpoint)
410
  logger.info(f"Found checkpoint at {resume_from}, will resume training")
411
+ #progress(0.05, desc=f"Resuming from checkpoint {Path(resume_from).name}")
412
  else:
413
+ #progress(0.05, desc="Starting new training run")
414
+ pass
415
 
416
  # Convert model_type display name to internal name
417
  model_internal_type = MODEL_TYPES.get(model_type)
 
427
  logger.error(f"Invalid training type: {training_type}")
428
  return f"Error: Invalid training type '{training_type}'", "Training type not recognized"
429
 
430
+ # Get other parameters from UI form
431
+ num_gpus = int(self.components["num_gpus"].value)
432
+ precomputation_items = int(self.components["precomputation_items"].value)
433
+ lr_warmup_steps = int(self.components["lr_warmup_steps"].value)
434
+
435
  # Progress update
436
+ #progress(0.1, desc="Preparing dataset")
437
 
438
  # Start training (it will automatically use the checkpoint if provided)
439
  try:
vms/ui/video_trainer_ui.py CHANGED
@@ -40,7 +40,7 @@ class VideoTrainerUI:
40
  def __init__(self):
41
  """Initialize services and tabs"""
42
  # Initialize core services
43
- self.trainer = TrainingService()
44
  self.splitter = SplittingService()
45
  self.importer = ImportService()
46
  self.captioner = CaptioningService()
 
40
  def __init__(self):
41
  """Initialize services and tabs"""
42
  # Initialize core services
43
+ self.trainer = TrainingService(self)
44
  self.splitter = SplittingService()
45
  self.importer = ImportService()
46
  self.captioner = CaptioningService()