Spaces:
Running
Running
Commit
·
f3d03c6
1
Parent(s):
4d3d0e8
fixes
Browse files- vms/services/importer.py +1 -2
- vms/tabs/train_tab.py +94 -58
vms/services/importer.py
CHANGED
@@ -10,8 +10,7 @@ from pytubefix import YouTube
|
|
10 |
import logging
|
11 |
|
12 |
from ..config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
|
13 |
-
from ..utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption
|
14 |
-
from ..webdataset import webdataset_handler
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
|
|
10 |
import logging
|
11 |
|
12 |
from ..config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
|
13 |
+
from ..utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler
|
|
|
14 |
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
vms/tabs/train_tab.py
CHANGED
@@ -4,12 +4,12 @@ Train tab for Video Model Studio UI
|
|
4 |
|
5 |
import gradio as gr
|
6 |
import logging
|
|
|
7 |
from typing import Dict, Any, List, Optional, Tuple
|
8 |
from pathlib import Path
|
9 |
|
10 |
from .base_tab import BaseTab
|
11 |
-
from ..config import TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS
|
12 |
-
from ..utils import TrainingLogParser
|
13 |
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
@@ -156,7 +156,7 @@ class TrainTab(BaseTab):
|
|
156 |
# Model type change event
|
157 |
def update_model_info(model, training_type):
|
158 |
params = self.get_default_params(MODEL_TYPES[model], TRAINING_TYPES[training_type])
|
159 |
-
info = self.get_model_info(
|
160 |
show_lora_params = training_type == list(TRAINING_TYPES.keys())[0] # Show if LoRA Finetune
|
161 |
|
162 |
return {
|
@@ -313,6 +313,21 @@ class TrainTab(BaseTab):
|
|
313 |
self.components["pause_resume_btn"]
|
314 |
]
|
315 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
|
317 |
def handle_training_start(self, preset, model_type, training_type, *args):
|
318 |
"""Handle training start with proper log parser reset and checkpoint detection"""
|
@@ -360,86 +375,103 @@ class TrainTab(BaseTab):
|
|
360 |
except Exception as e:
|
361 |
logger.exception("Error starting training")
|
362 |
return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
|
363 |
-
|
364 |
|
365 |
def get_model_info(self, model_type: str, training_type: str) -> str:
|
366 |
"""Get information about the selected model type and training method"""
|
367 |
-
|
368 |
-
|
369 |
-
if model_type == "hunyuan_video":
|
370 |
base_info = """### HunyuanVideo
|
371 |
- Required VRAM: ~48GB minimum
|
372 |
- Recommended batch size: 1-2
|
373 |
- Typical training time: 2-4 hours
|
374 |
- Default resolution: 49x512x768"""
|
375 |
|
376 |
-
if training_type == "
|
377 |
return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
|
378 |
else:
|
379 |
-
return base_info + "\n- Required VRAM: ~
|
380 |
|
381 |
-
elif model_type == "
|
382 |
-
base_info = """###
|
383 |
-
- Recommended batch size: 1-
|
384 |
- Typical training time: 1-3 hours
|
385 |
- Default resolution: 49x512x768"""
|
386 |
|
387 |
-
if training_type == "
|
388 |
-
return base_info + "\n- Required VRAM: ~
|
389 |
-
else:
|
390 |
-
return base_info + "\n- **Full finetune not supported in this UI**" + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
|
391 |
else:
|
392 |
return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
|
393 |
|
394 |
-
elif model_type == "
|
395 |
base_info = """### Wan-2.1-T2V
|
396 |
- Recommended batch size: 1-2
|
397 |
- Typical training time: 1-3 hours
|
398 |
- Default resolution: 49x512x768"""
|
399 |
|
400 |
-
if training_type == "
|
401 |
return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
|
402 |
-
else:
|
403 |
-
return base_info + "\n- **Full finetune not supported in this UI**" + "\n- Default LoRA rank: 128 (~600 MB)"
|
404 |
else:
|
405 |
return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
- Recommended batch size: 1-4
|
410 |
-
- Typical training time: 1-3 hours
|
411 |
-
- Default resolution: 49x512x768"""
|
412 |
-
|
413 |
-
if training_type == "lora":
|
414 |
-
return base_
|
415 |
|
416 |
-
def get_default_params(self, model_type: str) -> Dict[str, Any]:
|
417 |
"""Get default training parameters for model type"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
418 |
if model_type == "hunyuan_video":
|
419 |
return {
|
420 |
"num_epochs": 70,
|
421 |
"batch_size": 1,
|
422 |
"learning_rate": 2e-5,
|
423 |
"save_iterations": 500,
|
424 |
-
"
|
425 |
-
"
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
"
|
|
|
|
|
|
|
|
|
|
|
430 |
}
|
431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
return {
|
433 |
"num_epochs": 70,
|
434 |
"batch_size": 1,
|
435 |
"learning_rate": 3e-5,
|
436 |
"save_iterations": 500,
|
437 |
-
"
|
438 |
-
"
|
439 |
-
"caption_dropout_p": 0.05,
|
440 |
-
"gradient_accumulation_steps": 4,
|
441 |
-
"rank": 128,
|
442 |
-
"lora_alpha": 128
|
443 |
}
|
444 |
|
445 |
def update_training_params(self, preset_name: str) -> Tuple:
|
@@ -454,6 +486,12 @@ class TrainTab(BaseTab):
|
|
454 |
key for key, value in MODEL_TYPES.items()
|
455 |
if value == preset["model_type"]
|
456 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
457 |
|
458 |
# Get preset description for display
|
459 |
description = preset.get("description", "")
|
@@ -467,24 +505,29 @@ class TrainTab(BaseTab):
|
|
467 |
|
468 |
info_text = f"{description}{bucket_info}"
|
469 |
|
470 |
-
#
|
|
|
|
|
471 |
# Use preset defaults but preserve user-modified values if they exist
|
472 |
-
lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset
|
473 |
-
lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset
|
474 |
-
num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset
|
475 |
-
batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset
|
476 |
-
learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset
|
477 |
-
save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset
|
478 |
|
|
|
479 |
return (
|
480 |
model_display_name,
|
|
|
481 |
lora_rank_val,
|
482 |
lora_alpha_val,
|
483 |
num_epochs_val,
|
484 |
batch_size_val,
|
485 |
learning_rate_val,
|
486 |
save_iterations_val,
|
487 |
-
info_text
|
|
|
488 |
)
|
489 |
|
490 |
def update_training_ui(self, training_state: Dict[str, Any]):
|
@@ -498,13 +541,6 @@ class TrainTab(BaseTab):
|
|
498 |
f"Status: {training_state['status']}",
|
499 |
f"Progress: {training_state['progress']}",
|
500 |
f"Step: {training_state['current_step']}/{training_state['total_steps']}",
|
501 |
-
|
502 |
-
# Epoch information
|
503 |
-
# there is an issue with how epoch is reported because we display:
|
504 |
-
# Progress: 96.9%, Step: 872/900, Epoch: 12/50
|
505 |
-
# we should probably just show the steps
|
506 |
-
#f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",
|
507 |
-
|
508 |
f"Time elapsed: {training_state['elapsed']}",
|
509 |
f"Estimated remaining: {training_state['remaining']}",
|
510 |
"",
|
|
|
4 |
|
5 |
import gradio as gr
|
6 |
import logging
|
7 |
+
import os
|
8 |
from typing import Dict, Any, List, Optional, Tuple
|
9 |
from pathlib import Path
|
10 |
|
11 |
from .base_tab import BaseTab
|
12 |
+
from ..config import TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES
|
|
|
13 |
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
|
|
156 |
# Model type change event
|
157 |
def update_model_info(model, training_type):
|
158 |
params = self.get_default_params(MODEL_TYPES[model], TRAINING_TYPES[training_type])
|
159 |
+
info = self.get_model_info(model, training_type)
|
160 |
show_lora_params = training_type == list(TRAINING_TYPES.keys())[0] # Show if LoRA Finetune
|
161 |
|
162 |
return {
|
|
|
313 |
self.components["pause_resume_btn"]
|
314 |
]
|
315 |
)
|
316 |
+
|
317 |
+
# Add an event handler for delete_checkpoints_btn
|
318 |
+
self.components["delete_checkpoints_btn"].click(
|
319 |
+
fn=lambda: self.app.trainer.delete_all_checkpoints(),
|
320 |
+
outputs=[self.components["status_box"]]
|
321 |
+
).then(
|
322 |
+
fn=self.get_latest_status_message_logs_and_button_labels,
|
323 |
+
outputs=[
|
324 |
+
self.components["status_box"],
|
325 |
+
self.components["log_box"],
|
326 |
+
self.components["start_btn"],
|
327 |
+
self.components["stop_btn"],
|
328 |
+
self.components["delete_checkpoints_btn"]
|
329 |
+
]
|
330 |
+
)
|
331 |
|
332 |
def handle_training_start(self, preset, model_type, training_type, *args):
|
333 |
"""Handle training start with proper log parser reset and checkpoint detection"""
|
|
|
375 |
except Exception as e:
|
376 |
logger.exception("Error starting training")
|
377 |
return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
|
|
|
378 |
|
379 |
def get_model_info(self, model_type: str, training_type: str) -> str:
|
380 |
"""Get information about the selected model type and training method"""
|
381 |
+
if model_type == "HunyuanVideo (LoRA)":
|
|
|
|
|
382 |
base_info = """### HunyuanVideo
|
383 |
- Required VRAM: ~48GB minimum
|
384 |
- Recommended batch size: 1-2
|
385 |
- Typical training time: 2-4 hours
|
386 |
- Default resolution: 49x512x768"""
|
387 |
|
388 |
+
if training_type == "LoRA Finetune":
|
389 |
return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
|
390 |
else:
|
391 |
+
return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
|
392 |
|
393 |
+
elif model_type == "LTX-Video (LoRA)":
|
394 |
+
base_info = """### LTX-Video
|
395 |
+
- Recommended batch size: 1-4
|
396 |
- Typical training time: 1-3 hours
|
397 |
- Default resolution: 49x512x768"""
|
398 |
|
399 |
+
if training_type == "LoRA Finetune":
|
400 |
+
return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
|
|
|
|
|
401 |
else:
|
402 |
return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
|
403 |
|
404 |
+
elif model_type == "Wan-2.1-T2V (LoRA)":
|
405 |
base_info = """### Wan-2.1-T2V
|
406 |
- Recommended batch size: 1-2
|
407 |
- Typical training time: 1-3 hours
|
408 |
- Default resolution: 49x512x768"""
|
409 |
|
410 |
+
if training_type == "LoRA Finetune":
|
411 |
return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
|
|
|
|
|
412 |
else:
|
413 |
return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
|
414 |
+
|
415 |
+
# Default fallback
|
416 |
+
return f"### {model_type}\nPlease check documentation for VRAM requirements and recommended settings."
|
|
|
|
|
|
|
|
|
|
|
|
|
417 |
|
418 |
+
def get_default_params(self, model_type: str, training_type: str) -> Dict[str, Any]:
|
419 |
"""Get default training parameters for model type"""
|
420 |
+
# Find preset that matches model type and training type
|
421 |
+
matching_presets = [
|
422 |
+
preset for preset_name, preset in TRAINING_PRESETS.items()
|
423 |
+
if preset["model_type"] == model_type and preset["training_type"] == training_type
|
424 |
+
]
|
425 |
+
|
426 |
+
if matching_presets:
|
427 |
+
# Use the first matching preset
|
428 |
+
preset = matching_presets[0]
|
429 |
+
return {
|
430 |
+
"num_epochs": preset.get("num_epochs", 70),
|
431 |
+
"batch_size": preset.get("batch_size", 1),
|
432 |
+
"learning_rate": preset.get("learning_rate", 3e-5),
|
433 |
+
"save_iterations": preset.get("save_iterations", 500),
|
434 |
+
"lora_rank": preset.get("lora_rank", "128"),
|
435 |
+
"lora_alpha": preset.get("lora_alpha", "128")
|
436 |
+
}
|
437 |
+
|
438 |
+
# Default fallbacks
|
439 |
if model_type == "hunyuan_video":
|
440 |
return {
|
441 |
"num_epochs": 70,
|
442 |
"batch_size": 1,
|
443 |
"learning_rate": 2e-5,
|
444 |
"save_iterations": 500,
|
445 |
+
"lora_rank": "128",
|
446 |
+
"lora_alpha": "128"
|
447 |
+
}
|
448 |
+
elif model_type == "ltx_video":
|
449 |
+
return {
|
450 |
+
"num_epochs": 70,
|
451 |
+
"batch_size": 1,
|
452 |
+
"learning_rate": 3e-5,
|
453 |
+
"save_iterations": 500,
|
454 |
+
"lora_rank": "128",
|
455 |
+
"lora_alpha": "128"
|
456 |
}
|
457 |
+
elif model_type == "wan":
|
458 |
+
return {
|
459 |
+
"num_epochs": 70,
|
460 |
+
"batch_size": 1,
|
461 |
+
"learning_rate": 5e-5,
|
462 |
+
"save_iterations": 500,
|
463 |
+
"lora_rank": "32",
|
464 |
+
"lora_alpha": "32"
|
465 |
+
}
|
466 |
+
else:
|
467 |
+
# Generic defaults
|
468 |
return {
|
469 |
"num_epochs": 70,
|
470 |
"batch_size": 1,
|
471 |
"learning_rate": 3e-5,
|
472 |
"save_iterations": 500,
|
473 |
+
"lora_rank": "128",
|
474 |
+
"lora_alpha": "128"
|
|
|
|
|
|
|
|
|
475 |
}
|
476 |
|
477 |
def update_training_params(self, preset_name: str) -> Tuple:
|
|
|
486 |
key for key, value in MODEL_TYPES.items()
|
487 |
if value == preset["model_type"]
|
488 |
)
|
489 |
+
|
490 |
+
# Find the display name that maps to our training type
|
491 |
+
training_display_name = next(
|
492 |
+
key for key, value in TRAINING_TYPES.items()
|
493 |
+
if value == preset["training_type"]
|
494 |
+
)
|
495 |
|
496 |
# Get preset description for display
|
497 |
description = preset.get("description", "")
|
|
|
505 |
|
506 |
info_text = f"{description}{bucket_info}"
|
507 |
|
508 |
+
# Check if LoRA params should be visible
|
509 |
+
show_lora_params = preset["training_type"] == "lora"
|
510 |
+
|
511 |
# Use preset defaults but preserve user-modified values if they exist
|
512 |
+
lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset.get("lora_rank", "128")
|
513 |
+
lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset.get("lora_alpha", "128")
|
514 |
+
num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset.get("num_epochs", 70)
|
515 |
+
batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset.get("batch_size", 1)
|
516 |
+
learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset.get("learning_rate", 3e-5)
|
517 |
+
save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset.get("save_iterations", 500)
|
518 |
|
519 |
+
# Return values in the same order as the output components
|
520 |
return (
|
521 |
model_display_name,
|
522 |
+
training_display_name,
|
523 |
lora_rank_val,
|
524 |
lora_alpha_val,
|
525 |
num_epochs_val,
|
526 |
batch_size_val,
|
527 |
learning_rate_val,
|
528 |
save_iterations_val,
|
529 |
+
info_text,
|
530 |
+
gr.Row(visible=show_lora_params)
|
531 |
)
|
532 |
|
533 |
def update_training_ui(self, training_state: Dict[str, Any]):
|
|
|
541 |
f"Status: {training_state['status']}",
|
542 |
f"Progress: {training_state['progress']}",
|
543 |
f"Step: {training_state['current_step']}/{training_state['total_steps']}",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
544 |
f"Time elapsed: {training_state['elapsed']}",
|
545 |
f"Estimated remaining: {training_state['remaining']}",
|
546 |
"",
|