jbilcke-hf (HF Staff) committed
Commit f3d03c6 · 1 parent: 4d3d0e8

Files changed (2):
  1. vms/services/importer.py  +1 -2
  2. vms/tabs/train_tab.py     +94 -58
vms/services/importer.py CHANGED
@@ -10,8 +10,7 @@ from pytubefix import YouTube
  import logging

  from ..config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
- from ..utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption
- from ..webdataset import webdataset_handler
+ from ..utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler

  logger = logging.getLogger(__name__)

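For the consolidated import above to resolve, the vms.utils package must expose webdataset_handler next to the existing helpers; that side of the refactor is not shown in this commit. The sketch below is an assumed layout for vms/utils/__init__.py, included only to illustrate the re-export pattern the new import relies on. The submodule names other than webdataset_handler are hypothetical.

    # vms/utils/__init__.py -- assumed layout, not part of this commit
    # Re-export the helpers that importer.py now pulls from a single namespace.
    from .image_utils import normalize_image, is_image_file, is_video_file  # hypothetical module name
    from .prompting import add_prefix_to_caption  # hypothetical module name
    from . import webdataset_handler  # expose the handler as a submodule of vms.utils

With a re-export like this, "from ..utils import webdataset_handler" binds the submodule object itself, so existing webdataset_handler.<function>(...) call sites keep working without changes.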
vms/tabs/train_tab.py CHANGED
@@ -4,12 +4,12 @@ Train tab for Video Model Studio UI

  import gradio as gr
  import logging
+ import os
  from typing import Dict, Any, List, Optional, Tuple
  from pathlib import Path

  from .base_tab import BaseTab
- from ..config import TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS
- from ..utils import TrainingLogParser
+ from ..config import TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES

  logger = logging.getLogger(__name__)

@@ -156,7 +156,7 @@ class TrainTab(BaseTab):
          # Model type change event
          def update_model_info(model, training_type):
              params = self.get_default_params(MODEL_TYPES[model], TRAINING_TYPES[training_type])
-             info = self.get_model_info(MODEL_TYPES[model], TRAINING_TYPES[training_type])
+             info = self.get_model_info(model, training_type)
              show_lora_params = training_type == list(TRAINING_TYPES.keys())[0] # Show if LoRA Finetune

              return {
@@ -313,6 +313,21 @@ class TrainTab(BaseTab):
                  self.components["pause_resume_btn"]
              ]
          )
+
+         # Add an event handler for delete_checkpoints_btn
+         self.components["delete_checkpoints_btn"].click(
+             fn=lambda: self.app.trainer.delete_all_checkpoints(),
+             outputs=[self.components["status_box"]]
+         ).then(
+             fn=self.get_latest_status_message_logs_and_button_labels,
+             outputs=[
+                 self.components["status_box"],
+                 self.components["log_box"],
+                 self.components["start_btn"],
+                 self.components["stop_btn"],
+                 self.components["delete_checkpoints_btn"]
+             ]
+         )

      def handle_training_start(self, preset, model_type, training_type, *args):
          """Handle training start with proper log parser reset and checkpoint detection"""
@@ -360,86 +375,103 @@ class TrainTab(BaseTab):
          except Exception as e:
              logger.exception("Error starting training")
              return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
-

      def get_model_info(self, model_type: str, training_type: str) -> str:
          """Get information about the selected model type and training method"""
-         training_method = "LoRA finetune" if training_type == "lora" else "Full finetune"
-
-         if model_type == "hunyuan_video":
+         if model_type == "HunyuanVideo (LoRA)":
              base_info = """### HunyuanVideo
  - Required VRAM: ~48GB minimum
  - Recommended batch size: 1-2
  - Typical training time: 2-4 hours
  - Default resolution: 49x512x768"""

-             if training_type == "lora":
+             if training_type == "LoRA Finetune":
                  return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
              else:
-                 return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
-
-         elif model_type == "wan":
-             base_info = """### Wan-2.1-T2V
- - Recommended batch size: 1-2
+                 return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
+
+         elif model_type == "LTX-Video (LoRA)":
+             base_info = """### LTX-Video
+ - Recommended batch size: 1-4
  - Typical training time: 1-3 hours
  - Default resolution: 49x512x768"""

-             if training_type == "lora":
-                 return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
-             else:
-                 return base_info + "\n- **Full finetune not supported in this UI**" + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
+             if training_type == "LoRA Finetune":
+                 return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
              else:
                  return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"

-         elif model_type == "wan":
+         elif model_type == "Wan-2.1-T2V (LoRA)":
              base_info = """### Wan-2.1-T2V
  - Recommended batch size: 1-2
  - Typical training time: 1-3 hours
  - Default resolution: 49x512x768"""

-             if training_type == "lora":
+             if training_type == "LoRA Finetune":
                  return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
-             else:
-                 return base_info + "\n- **Full finetune not supported in this UI**" + "\n- Default LoRA rank: 128 (~600 MB)"
              else:
                  return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
-
-         elif model_type == "ltx_video":
-             base_info = """### LTX-Video
- - Recommended batch size: 1-4
- - Typical training time: 1-3 hours
- - Default resolution: 49x512x768"""
-
-             if training_type == "lora":
-                 return base_
+
+         # Default fallback
+         return f"### {model_type}\nPlease check documentation for VRAM requirements and recommended settings."

-     def get_default_params(self, model_type: str) -> Dict[str, Any]:
+     def get_default_params(self, model_type: str, training_type: str) -> Dict[str, Any]:
          """Get default training parameters for model type"""
+         # Find preset that matches model type and training type
+         matching_presets = [
+             preset for preset_name, preset in TRAINING_PRESETS.items()
+             if preset["model_type"] == model_type and preset["training_type"] == training_type
+         ]
+
+         if matching_presets:
+             # Use the first matching preset
+             preset = matching_presets[0]
+             return {
+                 "num_epochs": preset.get("num_epochs", 70),
+                 "batch_size": preset.get("batch_size", 1),
+                 "learning_rate": preset.get("learning_rate", 3e-5),
+                 "save_iterations": preset.get("save_iterations", 500),
+                 "lora_rank": preset.get("lora_rank", "128"),
+                 "lora_alpha": preset.get("lora_alpha", "128")
+             }
+
+         # Default fallbacks
          if model_type == "hunyuan_video":
              return {
                  "num_epochs": 70,
                  "batch_size": 1,
                  "learning_rate": 2e-5,
                  "save_iterations": 500,
-                 "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
-                 "video_reshape_mode": "center",
-                 "caption_dropout_p": 0.05,
-                 "gradient_accumulation_steps": 1,
-                 "rank": 128,
-                 "lora_alpha": 128
+                 "lora_rank": "128",
+                 "lora_alpha": "128"
+             }
+         elif model_type == "ltx_video":
+             return {
+                 "num_epochs": 70,
+                 "batch_size": 1,
+                 "learning_rate": 3e-5,
+                 "save_iterations": 500,
+                 "lora_rank": "128",
+                 "lora_alpha": "128"
              }
-         else: # ltx_video
+         elif model_type == "wan":
+             return {
+                 "num_epochs": 70,
+                 "batch_size": 1,
+                 "learning_rate": 5e-5,
+                 "save_iterations": 500,
+                 "lora_rank": "32",
+                 "lora_alpha": "32"
+             }
+         else:
+             # Generic defaults
              return {
                  "num_epochs": 70,
                  "batch_size": 1,
                  "learning_rate": 3e-5,
                  "save_iterations": 500,
-                 "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
-                 "video_reshape_mode": "center",
-                 "caption_dropout_p": 0.05,
-                 "gradient_accumulation_steps": 4,
-                 "rank": 128,
-                 "lora_alpha": 128
+                 "lora_rank": "128",
+                 "lora_alpha": "128"
              }

      def update_training_params(self, preset_name: str) -> Tuple:
@@ -454,6 +486,12 @@ class TrainTab(BaseTab):
              key for key, value in MODEL_TYPES.items()
              if value == preset["model_type"]
          )
+
+         # Find the display name that maps to our training type
+         training_display_name = next(
+             key for key, value in TRAINING_TYPES.items()
+             if value == preset["training_type"]
+         )

          # Get preset description for display
          description = preset.get("description", "")
@@ -467,24 +505,29 @@ class TrainTab(BaseTab):

          info_text = f"{description}{bucket_info}"

-         # Return values in the same order as the output components
+         # Check if LoRA params should be visible
+         show_lora_params = preset["training_type"] == "lora"
+
          # Use preset defaults but preserve user-modified values if they exist
-         lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset["lora_rank"]
-         lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset["lora_alpha"]
-         num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset["num_epochs"]
-         batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset["batch_size"]
-         learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset["learning_rate"]
-         save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset["save_iterations"]
+         lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset.get("lora_rank", "128")
+         lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset.get("lora_alpha", "128")
+         num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset.get("num_epochs", 70)
+         batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset.get("batch_size", 1)
+         learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset.get("learning_rate", 3e-5)
+         save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset.get("save_iterations", 500)

+         # Return values in the same order as the output components
          return (
              model_display_name,
+             training_display_name,
              lora_rank_val,
              lora_alpha_val,
              num_epochs_val,
              batch_size_val,
              learning_rate_val,
              save_iterations_val,
-             info_text
+             info_text,
+             gr.Row(visible=show_lora_params)
          )

      def update_training_ui(self, training_state: Dict[str, Any]):
@@ -498,13 +541,6 @@ class TrainTab(BaseTab):
              f"Status: {training_state['status']}",
              f"Progress: {training_state['progress']}",
              f"Step: {training_state['current_step']}/{training_state['total_steps']}",
-
-             # Epoch information
-             # there is an issue with how epoch is reported because we display:
-             # Progress: 96.9%, Step: 872/900, Epoch: 12/50
-             # we should probably just show the steps
-             #f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",
-
              f"Time elapsed: {training_state['elapsed']}",
              f"Estimated remaining: {training_state['remaining']}",
              "",
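The behavioural core of this commit is the reworked get_default_params(): instead of hard-coding two parameter sets, it first looks for a TRAINING_PRESETS entry whose model_type and training_type both match, and only falls back to per-model defaults when nothing matches. The standalone sketch below replays that lookup outside the UI. Only the lookup-then-fallback logic mirrors the code added above; the preset dictionary is a made-up two-entry example (the real presets live in vms/config.py), and the hunyuan/ltx/wan special cases are collapsed into a single generic fallback here.

from typing import Any, Dict

# Toy stand-in for vms.config.TRAINING_PRESETS -- the two presets below are
# invented for illustration; the real presets live in vms/config.py.
TRAINING_PRESETS: Dict[str, Dict[str, Any]] = {
    "LTX-Video (demo)": {
        "model_type": "ltx_video", "training_type": "lora",
        "num_epochs": 70, "batch_size": 1, "learning_rate": 3e-5,
        "save_iterations": 500, "lora_rank": "128", "lora_alpha": "128",
    },
    "Wan-2.1 (demo)": {
        "model_type": "wan", "training_type": "lora",
        "num_epochs": 10, "batch_size": 1, "learning_rate": 5e-5,
        "save_iterations": 200, "lora_rank": "32", "lora_alpha": "32",
    },
}

def get_default_params(model_type: str, training_type: str) -> Dict[str, Any]:
    # Prefer the first preset whose model_type and training_type both match.
    matching_presets = [
        preset for preset in TRAINING_PRESETS.values()
        if preset["model_type"] == model_type and preset["training_type"] == training_type
    ]
    if matching_presets:
        preset = matching_presets[0]
        return {
            "num_epochs": preset.get("num_epochs", 70),
            "batch_size": preset.get("batch_size", 1),
            "learning_rate": preset.get("learning_rate", 3e-5),
            "save_iterations": preset.get("save_iterations", 500),
            "lora_rank": preset.get("lora_rank", "128"),
            "lora_alpha": preset.get("lora_alpha", "128"),
        }
    # No preset matched: fall back to generic defaults (the real method also
    # special-cases hunyuan_video / ltx_video / wan before reaching this branch).
    return {
        "num_epochs": 70, "batch_size": 1, "learning_rate": 3e-5,
        "save_iterations": 500, "lora_rank": "128", "lora_alpha": "128",
    }

print(get_default_params("wan", "lora"))  # resolved from the toy "Wan-2.1 (demo)" preset
print(get_default_params("wan", "full"))  # no preset matches -> generic fallback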