jbilcke-hf HF staff committed on
Commit
29d6f3c
·
1 Parent(s): 12bcca7

various fixes

Browse files
Files changed (1) hide show
  1. vms/services/trainer.py +111 -38
vms/services/trainer.py CHANGED
@@ -304,30 +304,38 @@ class TrainingService:
304
 
305
  try:
306
  # Basic validation
307
- if not config.data_root or not Path(config.data_root).exists():
308
- return f"Invalid data root path: {config.data_root}"
309
-
310
  if not config.output_dir:
311
  return "Output directory not specified"
312
 
313
- # Check for required files
314
- videos_file = Path(config.data_root) / "videos.txt"
315
- prompts_file = Path(config.data_root) / "prompts.txt"
 
316
 
317
- if not videos_file.exists():
318
- return f"Missing videos list file: {videos_file}"
319
- if not prompts_file.exists():
320
- return f"Missing prompts list file: {prompts_file}"
321
 
322
- # Validate file counts match
323
- video_lines = [l.strip() for l in open(videos_file) if l.strip()]
324
- prompt_lines = [l.strip() for l in open(prompts_file) if l.strip()]
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
- if not video_lines:
327
  return "No training files found"
328
- if len(video_lines) != len(prompt_lines):
329
- return f"Mismatch between video count ({len(video_lines)}) and prompt count ({len(prompt_lines)})"
330
-
331
  # Model-specific validation
332
  if model_type == "hunyuan_video":
333
  if config.batch_size > 2:
@@ -341,13 +349,13 @@ class TrainingService:
341
  if config.batch_size > 4:
342
  return "Wan model recommended batch size is 1-4"
343
 
344
- logger.info(f"Config validation passed with {len(video_lines)} training files")
345
  return None
346
 
347
  except Exception as e:
348
  logger.error(f"Error during config validation: {str(e)}")
349
  return f"Configuration validation failed: {str(e)}"
350
-
351
  def start_training(
352
  self,
353
  model_type: str,
@@ -427,6 +435,36 @@ class TrainingService:
427
  flow_weighting_scheme = preset.get("flow_weighting_scheme", "none")
428
  preset_training_type = preset.get("training_type", "lora")
429
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  # Get config for selected model type with preset buckets
431
  if model_type == "hunyuan_video":
432
  if training_type == "lora":
@@ -477,6 +515,9 @@ class TrainingService:
477
  config.training_type = training_type
478
  config.flow_weighting_scheme = flow_weighting_scheme
479
 
 
 
 
480
  # Update LoRA parameters if using LoRA training type
481
  if training_type == "lora":
482
  config.lora_rank = int(lora_rank)
@@ -501,26 +542,58 @@ class TrainingService:
501
  logger.error(error_msg)
502
  return "Error: Invalid configuration", error_msg
503
 
504
- # Configure accelerate parameters
505
- accelerate_args = [
506
- "accelerate", "launch",
507
- "--mixed_precision=bf16",
508
- "--num_processes=1",
509
- "--num_machines=1",
510
- "--dynamo_backend=no"
511
- ]
512
-
513
- accelerate_args.append(str(train_script))
514
-
515
- # Convert config to command line arguments
516
  config_args = config.to_args_list()
517
-
518
  logger.debug("Generated args list: %s", config_args)
519
-
520
- # Log the full command for debugging
521
- command_str = ' '.join(accelerate_args + config_args)
522
- self.append_log(f"Command: {command_str}")
523
- logger.info(f"Executing command: {command_str}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
525
  # Set environment variables
526
  env = os.environ.copy()
@@ -532,7 +605,7 @@ class TrainingService:
532
 
533
  # Start the training process
534
  process = subprocess.Popen(
535
- accelerate_args + config_args,
536
  stdout=subprocess.PIPE,
537
  stderr=subprocess.PIPE,
538
  start_new_session=True,
 
304
 
305
  try:
306
  # Basic validation
 
 
 
307
  if not config.output_dir:
308
  return "Output directory not specified"
309
 
310
+ # For the dataset_config validation, we now expect it to be a JSON file
311
+ dataset_config_path = Path(config.data_root)
312
+ if not dataset_config_path.exists():
313
+ return f"Dataset config file does not exist: {dataset_config_path}"
314
 
315
+ # Check the JSON file is valid
316
+ try:
317
+ with open(dataset_config_path, 'r') as f:
318
+ dataset_json = json.load(f)
319
 
320
+ # Basic validation of the JSON structure
321
+ if "datasets" not in dataset_json or not isinstance(dataset_json["datasets"], list) or len(dataset_json["datasets"]) == 0:
322
+ return "Invalid dataset config JSON: missing or empty 'datasets' array"
323
+
324
+ except json.JSONDecodeError:
325
+ return f"Invalid JSON in dataset config file: {dataset_config_path}"
326
+ except Exception as e:
327
+ return f"Error reading dataset config file: {str(e)}"
328
+
329
+ # Check training videos directory exists
330
+ if not TRAINING_VIDEOS_PATH.exists():
331
+ return f"Training videos directory does not exist: {TRAINING_VIDEOS_PATH}"
332
+
333
+ # Validate file counts
334
+ video_count = len(list(TRAINING_VIDEOS_PATH.glob('*.mp4')))
335
 
336
+ if video_count == 0:
337
  return "No training files found"
338
+
 
 
339
  # Model-specific validation
340
  if model_type == "hunyuan_video":
341
  if config.batch_size > 2:
 
349
  if config.batch_size > 4:
350
  return "Wan model recommended batch size is 1-4"
351
 
352
+ logger.info(f"Config validation passed with {video_count} training files")
353
  return None
354
 
355
  except Exception as e:
356
  logger.error(f"Error during config validation: {str(e)}")
357
  return f"Configuration validation failed: {str(e)}"
358
+
359
  def start_training(
360
  self,
361
  model_type: str,
 
435
  flow_weighting_scheme = preset.get("flow_weighting_scheme", "none")
436
  preset_training_type = preset.get("training_type", "lora")
437
 
438
+ # Create a proper dataset configuration JSON file
439
+ dataset_config_file = OUTPUT_PATH / "dataset_config.json"
440
+
441
+ # Determine appropriate ID token based on model type
442
+ id_token = None
443
+ if model_type == "hunyuan_video":
444
+ id_token = "afkx"
445
+ elif model_type == "ltx_video":
446
+ id_token = "BW_STYLE"
447
+ # Wan doesn't use an ID token by default, so leave it as None
448
+
449
+ dataset_config = {
450
+ "datasets": [
451
+ {
452
+ "data_root": str(TRAINING_PATH),
453
+ "dataset_type": "video",
454
+ "id_token": id_token,
455
+ "video_resolution_buckets": [[f, h, w] for f, h, w in training_buckets],
456
+ "reshape_mode": "bicubic",
457
+ "remove_common_llm_caption_prefixes": True
458
+ }
459
+ ]
460
+ }
461
+
462
+ # Write the dataset config to file
463
+ with open(dataset_config_file, 'w') as f:
464
+ json.dump(dataset_config, f, indent=2)
465
+
466
+ logger.info(f"Created dataset configuration file at {dataset_config_file}")
467
+
468
  # Get config for selected model type with preset buckets
469
  if model_type == "hunyuan_video":
470
  if training_type == "lora":
 
515
  config.training_type = training_type
516
  config.flow_weighting_scheme = flow_weighting_scheme
517
 
518
+ # CRITICAL FIX: Update the dataset_config to point to the JSON file, not the directory
519
+ config.data_root = str(dataset_config_file)
520
+
521
  # Update LoRA parameters if using LoRA training type
522
  if training_type == "lora":
523
  config.lora_rank = int(lora_rank)
 
542
  logger.error(error_msg)
543
  return "Error: Invalid configuration", error_msg
544
 
545
+ # Convert config to command line arguments for all launcher types
 
 
 
 
 
 
 
 
 
 
 
546
  config_args = config.to_args_list()
 
547
  logger.debug("Generated args list: %s", config_args)
548
+
549
+ # Use different launch commands based on model type
550
+ # For Wan models, use torchrun instead of accelerate launch
551
+ if model_type == "wan":
552
+ # Configure torchrun parameters
553
+ torchrun_args = [
554
+ "torchrun",
555
+ "--standalone",
556
+ "--nproc_per_node=1",
557
+ "--nnodes=1",
558
+ "--rdzv_backend=c10d",
559
+ "--rdzv_endpoint=localhost:0",
560
+ str(train_script)
561
+ ]
562
+
563
+ # Additional args needed for torchrun
564
+ config_args.extend([
565
+ "--parallel_backend", "ptd",
566
+ "--pp_degree", "1",
567
+ "--dp_degree", "1",
568
+ "--dp_shards", "1",
569
+ "--cp_degree", "1",
570
+ "--tp_degree", "1"
571
+ ])
572
+
573
+ # Log the full command for debugging
574
+ command_str = ' '.join(torchrun_args + config_args)
575
+ self.append_log(f"Command: {command_str}")
576
+ logger.info(f"Executing command: {command_str}")
577
+
578
+ launch_args = torchrun_args
579
+ else:
580
+ # For other models, use accelerate launch as before
581
+ # Configure accelerate parameters
582
+ accelerate_args = [
583
+ "accelerate", "launch",
584
+ "--mixed_precision=bf16",
585
+ "--num_processes=1",
586
+ "--num_machines=1",
587
+ "--dynamo_backend=no",
588
+ str(train_script)
589
+ ]
590
+
591
+ # Log the full command for debugging
592
+ command_str = ' '.join(accelerate_args + config_args)
593
+ self.append_log(f"Command: {command_str}")
594
+ logger.info(f"Executing command: {command_str}")
595
+
596
+ launch_args = accelerate_args
597
 
598
  # Set environment variables
599
  env = os.environ.copy()
 
605
 
606
  # Start the training process
607
  process = subprocess.Popen(
608
+ launch_args + config_args,
609
  stdout=subprocess.PIPE,
610
  stderr=subprocess.PIPE,
611
  start_new_session=True,