Spaces:
Running
Running
fixed some bugs with finetrainers CLI params
Browse files
- vms/config.py +3 -0
- vms/services/trainer.py +25 -22
- vms/tabs/train_tab.py +13 -5
- vms/ui/video_trainer_ui.py +1 -1
vms/config.py
CHANGED
@@ -485,6 +485,9 @@ class TrainingConfig:
|
|
485 |
if self.precompute_conditions:
|
486 |
args.append("--precompute_conditions")
|
487 |
|
|
|
|
|
|
|
488 |
# Diffusion arguments
|
489 |
if self.flow_resolution_shifting:
|
490 |
args.append("--flow_resolution_shifting")
|
|
|
485 |
if self.precompute_conditions:
|
486 |
args.append("--precompute_conditions")
|
487 |
|
488 |
+
if hasattr(self, 'precomputation_items') and self.precomputation_items:
|
489 |
+
args.extend(["--precomputation_items", str(self.precomputation_items)])
|
490 |
+
|
491 |
# Diffusion arguments
|
492 |
if self.flow_resolution_shifting:
|
493 |
args.append("--flow_resolution_shifting")
|
vms/services/trainer.py
CHANGED
@@ -52,7 +52,10 @@ from ..utils import (
|
|
52 |
logger = logging.getLogger(__name__)
|
53 |
|
54 |
class TrainingService:
|
55 |
-
def __init__(self):
|
|
|
|
|
|
|
56 |
# State and log files
|
57 |
self.session_file = OUTPUT_PATH / "session.json"
|
58 |
self.status_file = OUTPUT_PATH / "status.json"
|
@@ -565,8 +568,8 @@ class TrainingService:
|
|
565 |
logger.info(f"{log_prefix} training with model_type={model_type}, training_type={training_type}")
|
566 |
|
567 |
# Update progress if available
|
568 |
-
if progress:
|
569 |
-
|
570 |
|
571 |
try:
|
572 |
# Get absolute paths - FIXED to look in project root instead of within vms directory
|
@@ -598,8 +601,8 @@ class TrainingService:
|
|
598 |
logger.info("Training data path: %s", TRAINING_PATH)
|
599 |
|
600 |
# Update progress
|
601 |
-
if progress:
|
602 |
-
|
603 |
|
604 |
videos_file, prompts_file = prepare_finetrainers_dataset()
|
605 |
if videos_file is None or prompts_file is None:
|
@@ -616,8 +619,8 @@ class TrainingService:
|
|
616 |
return error_msg, "No training data available"
|
617 |
|
618 |
# Update progress
|
619 |
-
if progress:
|
620 |
-
|
621 |
|
622 |
# Get preset configuration
|
623 |
preset = TRAINING_PRESETS[preset_name]
|
@@ -627,13 +630,14 @@ class TrainingService:
|
|
627 |
|
628 |
# Get the custom prompt prefix from the tabs
|
629 |
custom_prompt_prefix = None
|
630 |
-
if hasattr(self
|
631 |
-
if hasattr(self.app
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
|
|
637 |
|
638 |
# Create a proper dataset configuration JSON file
|
639 |
dataset_config_file = OUTPUT_PATH / "dataset_config.json"
|
@@ -725,10 +729,7 @@ class TrainingService:
|
|
725 |
config.flow_weighting_scheme = flow_weighting_scheme
|
726 |
|
727 |
config.lr_warmup_steps = int(lr_warmup_steps)
|
728 |
-
|
729 |
-
"--precomputation_items", str(precomputation_items)
|
730 |
-
])
|
731 |
-
|
732 |
# Update the NUM_GPUS variable and CUDA_VISIBLE_DEVICES
|
733 |
num_gpus = min(num_gpus, get_available_gpu_count())
|
734 |
if num_gpus <= 0:
|
@@ -757,6 +758,8 @@ class TrainingService:
|
|
757 |
config.enable_tiling = True
|
758 |
config.caption_dropout_p = DEFAULT_CAPTION_DROPOUT_P
|
759 |
|
|
|
|
|
760 |
validation_error = self.validate_training_config(config, model_type)
|
761 |
if validation_error:
|
762 |
error_msg = f"Configuration validation failed: {validation_error}"
|
@@ -843,8 +846,8 @@ class TrainingService:
|
|
843 |
env["FINETRAINERS_LOG_LEVEL"] = "DEBUG" # Added for better debugging
|
844 |
env["CUDA_VISIBLE_DEVICES"] = visible_devices
|
845 |
|
846 |
-
if progress:
|
847 |
-
|
848 |
|
849 |
# Start the training process
|
850 |
process = subprocess.Popen(
|
@@ -901,8 +904,8 @@ class TrainingService:
|
|
901 |
logger.info(success_msg)
|
902 |
|
903 |
# Final progress update - now we'll track it through the log monitor
|
904 |
-
if progress:
|
905 |
-
|
906 |
|
907 |
return success_msg, self.get_logs()
|
908 |
|
|
|
52 |
logger = logging.getLogger(__name__)
|
53 |
|
54 |
class TrainingService:
|
55 |
+
def __init__(self, app=None):
|
56 |
+
# Store reference to app
|
57 |
+
self.app = app
|
58 |
+
|
59 |
# State and log files
|
60 |
self.session_file = OUTPUT_PATH / "session.json"
|
61 |
self.status_file = OUTPUT_PATH / "status.json"
|
|
|
568 |
logger.info(f"{log_prefix} training with model_type={model_type}, training_type={training_type}")
|
569 |
|
570 |
# Update progress if available
|
571 |
+
#if progress:
|
572 |
+
# progress(0.15, desc="Setting up training configuration")
|
573 |
|
574 |
try:
|
575 |
# Get absolute paths - FIXED to look in project root instead of within vms directory
|
|
|
601 |
logger.info("Training data path: %s", TRAINING_PATH)
|
602 |
|
603 |
# Update progress
|
604 |
+
#if progress:
|
605 |
+
# progress(0.2, desc="Preparing training dataset")
|
606 |
|
607 |
videos_file, prompts_file = prepare_finetrainers_dataset()
|
608 |
if videos_file is None or prompts_file is None:
|
|
|
619 |
return error_msg, "No training data available"
|
620 |
|
621 |
# Update progress
|
622 |
+
#if progress:
|
623 |
+
# progress(0.25, desc="Creating dataset configuration")
|
624 |
|
625 |
# Get preset configuration
|
626 |
preset = TRAINING_PRESETS[preset_name]
|
|
|
630 |
|
631 |
# Get the custom prompt prefix from the tabs
|
632 |
custom_prompt_prefix = None
|
633 |
+
if hasattr(self, 'app') and self.app is not None:
|
634 |
+
if hasattr(self.app, 'tabs') and 'caption_tab' in self.app.tabs:
|
635 |
+
if hasattr(self.app.tabs['caption_tab'], 'components') and 'custom_prompt_prefix' in self.app.tabs['caption_tab'].components:
|
636 |
+
# Get the value and clean it
|
637 |
+
prefix = self.app.tabs['caption_tab'].components['custom_prompt_prefix'].value
|
638 |
+
if prefix:
|
639 |
+
# Clean the prefix - remove trailing comma, space or comma+space
|
640 |
+
custom_prompt_prefix = prefix.rstrip(', ')
|
641 |
|
642 |
# Create a proper dataset configuration JSON file
|
643 |
dataset_config_file = OUTPUT_PATH / "dataset_config.json"
|
|
|
729 |
config.flow_weighting_scheme = flow_weighting_scheme
|
730 |
|
731 |
config.lr_warmup_steps = int(lr_warmup_steps)
|
732 |
+
|
|
|
|
|
|
|
733 |
# Update the NUM_GPUS variable and CUDA_VISIBLE_DEVICES
|
734 |
num_gpus = min(num_gpus, get_available_gpu_count())
|
735 |
if num_gpus <= 0:
|
|
|
758 |
config.enable_tiling = True
|
759 |
config.caption_dropout_p = DEFAULT_CAPTION_DROPOUT_P
|
760 |
|
761 |
+
config.precomputation_items = precomputation_items
|
762 |
+
|
763 |
validation_error = self.validate_training_config(config, model_type)
|
764 |
if validation_error:
|
765 |
error_msg = f"Configuration validation failed: {validation_error}"
|
|
|
846 |
env["FINETRAINERS_LOG_LEVEL"] = "DEBUG" # Added for better debugging
|
847 |
env["CUDA_VISIBLE_DEVICES"] = visible_devices
|
848 |
|
849 |
+
#if progress:
|
850 |
+
# progress(0.9, desc="Launching training process")
|
851 |
|
852 |
# Start the training process
|
853 |
process = subprocess.Popen(
|
|
|
904 |
logger.info(success_msg)
|
905 |
|
906 |
# Final progress update - now we'll track it through the log monitor
|
907 |
+
#if progress:
|
908 |
+
# progress(1.0, desc="Training started successfully")
|
909 |
|
910 |
return success_msg, self.get_logs()
|
911 |
|
vms/tabs/train_tab.py
CHANGED
@@ -384,7 +384,9 @@ class TrainTab(BaseTab):
|
|
384 |
outputs=[self.components["status_box"]]
|
385 |
)
|
386 |
|
387 |
-
def handle_training_start(
|
|
|
|
|
388 |
"""Handle training start with proper log parser reset and checkpoint detection"""
|
389 |
# Safely reset log parser if it exists
|
390 |
if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
|
@@ -395,7 +397,7 @@ class TrainTab(BaseTab):
|
|
395 |
self.app.log_parser = TrainingLogParser()
|
396 |
|
397 |
# Initialize progress
|
398 |
-
progress(0, desc="Initializing training")
|
399 |
|
400 |
# Check for latest checkpoint
|
401 |
checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
|
@@ -406,9 +408,10 @@ class TrainTab(BaseTab):
|
|
406 |
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
|
407 |
resume_from = str(latest_checkpoint)
|
408 |
logger.info(f"Found checkpoint at {resume_from}, will resume training")
|
409 |
-
progress(0.05, desc=f"Resuming from checkpoint {Path(resume_from).name}")
|
410 |
else:
|
411 |
-
progress(0.05, desc="Starting new training run")
|
|
|
412 |
|
413 |
# Convert model_type display name to internal name
|
414 |
model_internal_type = MODEL_TYPES.get(model_type)
|
@@ -424,8 +427,13 @@ class TrainTab(BaseTab):
|
|
424 |
logger.error(f"Invalid training type: {training_type}")
|
425 |
return f"Error: Invalid training type '{training_type}'", "Training type not recognized"
|
426 |
|
|
|
|
|
|
|
|
|
|
|
427 |
# Progress update
|
428 |
-
progress(0.1, desc="Preparing dataset")
|
429 |
|
430 |
# Start training (it will automatically use the checkpoint if provided)
|
431 |
try:
|
|
|
384 |
outputs=[self.components["status_box"]]
|
385 |
)
|
386 |
|
387 |
+
def handle_training_start(
|
388 |
+
self, preset, model_type, training_type, lora_rank, lora_alpha, train_steps, batch_size, learning_rate, save_iterations, repo_id, progress=gr.Progress()
|
389 |
+
):
|
390 |
"""Handle training start with proper log parser reset and checkpoint detection"""
|
391 |
# Safely reset log parser if it exists
|
392 |
if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
|
|
|
397 |
self.app.log_parser = TrainingLogParser()
|
398 |
|
399 |
# Initialize progress
|
400 |
+
#progress(0, desc="Initializing training")
|
401 |
|
402 |
# Check for latest checkpoint
|
403 |
checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
|
|
|
408 |
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
|
409 |
resume_from = str(latest_checkpoint)
|
410 |
logger.info(f"Found checkpoint at {resume_from}, will resume training")
|
411 |
+
#progress(0.05, desc=f"Resuming from checkpoint {Path(resume_from).name}")
|
412 |
else:
|
413 |
+
#progress(0.05, desc="Starting new training run")
|
414 |
+
pass
|
415 |
|
416 |
# Convert model_type display name to internal name
|
417 |
model_internal_type = MODEL_TYPES.get(model_type)
|
|
|
427 |
logger.error(f"Invalid training type: {training_type}")
|
428 |
return f"Error: Invalid training type '{training_type}'", "Training type not recognized"
|
429 |
|
430 |
+
# Get other parameters from UI form
|
431 |
+
num_gpus = int(self.components["num_gpus"].value)
|
432 |
+
precomputation_items = int(self.components["precomputation_items"].value)
|
433 |
+
lr_warmup_steps = int(self.components["lr_warmup_steps"].value)
|
434 |
+
|
435 |
# Progress update
|
436 |
+
#progress(0.1, desc="Preparing dataset")
|
437 |
|
438 |
# Start training (it will automatically use the checkpoint if provided)
|
439 |
try:
|
vms/ui/video_trainer_ui.py
CHANGED
@@ -40,7 +40,7 @@ class VideoTrainerUI:
|
|
40 |
def __init__(self):
|
41 |
"""Initialize services and tabs"""
|
42 |
# Initialize core services
|
43 |
-
self.trainer = TrainingService()
|
44 |
self.splitter = SplittingService()
|
45 |
self.importer = ImportService()
|
46 |
self.captioner = CaptioningService()
|
|
|
40 |
def __init__(self):
|
41 |
"""Initialize services and tabs"""
|
42 |
# Initialize core services
|
43 |
+
self.trainer = TrainingService(self)
|
44 |
self.splitter = SplittingService()
|
45 |
self.importer = ImportService()
|
46 |
self.captioner = CaptioningService()
|