jbilcke-hf committed
Commit 91fb4ef · 0 Parent(s)

initial commit log 🪵🦫

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitignore +8 -0
  2. Dockerfile +44 -0
  3. README.md +96 -0
  4. accelerate_configs/compiled_1.yaml +22 -0
  5. accelerate_configs/deepspeed.yaml +23 -0
  6. accelerate_configs/uncompiled_1.yaml +17 -0
  7. accelerate_configs/uncompiled_2.yaml +17 -0
  8. accelerate_configs/uncompiled_8.yaml +17 -0
  9. app.py +1270 -0
  10. captioning_service.py +534 -0
  11. config.py +303 -0
  12. finetrainers/__init__.py +2 -0
  13. finetrainers/args.py +1191 -0
  14. finetrainers/constants.py +80 -0
  15. finetrainers/dataset.py +467 -0
  16. finetrainers/hooks/__init__.py +1 -0
  17. finetrainers/hooks/hooks.py +176 -0
  18. finetrainers/hooks/layerwise_upcasting.py +140 -0
  19. finetrainers/models/__init__.py +33 -0
  20. finetrainers/models/cogvideox/__init__.py +2 -0
  21. finetrainers/models/cogvideox/full_finetune.py +32 -0
  22. finetrainers/models/cogvideox/lora.py +334 -0
  23. finetrainers/models/cogvideox/utils.py +51 -0
  24. finetrainers/models/hunyuan_video/__init__.py +2 -0
  25. finetrainers/models/hunyuan_video/full_finetune.py +30 -0
  26. finetrainers/models/hunyuan_video/lora.py +368 -0
  27. finetrainers/models/ltx_video/__init__.py +2 -0
  28. finetrainers/models/ltx_video/full_finetune.py +30 -0
  29. finetrainers/models/ltx_video/lora.py +331 -0
  30. finetrainers/patches.py +50 -0
  31. finetrainers/state.py +24 -0
  32. finetrainers/trainer.py +1207 -0
  33. finetrainers/utils/__init__.py +13 -0
  34. finetrainers/utils/checkpointing.py +64 -0
  35. finetrainers/utils/data_utils.py +35 -0
  36. finetrainers/utils/diffusion_utils.py +145 -0
  37. finetrainers/utils/file_utils.py +44 -0
  38. finetrainers/utils/hub_utils.py +84 -0
  39. finetrainers/utils/memory_utils.py +58 -0
  40. finetrainers/utils/model_utils.py +25 -0
  41. finetrainers/utils/optimizer_utils.py +178 -0
  42. finetrainers/utils/torch_utils.py +35 -0
  43. finetrainers_utils.py +126 -0
  44. image_preprocessing.py +116 -0
  45. import_service.py +245 -0
  46. requirements.txt +43 -0
  47. requirements_without_flash_attention.txt +42 -0
  48. run.sh +5 -0
  49. setup.sh +7 -0
  50. setup_no_captions.sh +12 -0
.gitignore ADDED
@@ -0,0 +1,8 @@
+ .DS_Store
+ .venv
+ .data
+ __pycache__
+ *.mp3
+ *.mp4
+ *.zip
+ training_service.log
Dockerfile ADDED
@@ -0,0 +1,44 @@
+ FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
+
+ # Prevent interactive prompts during build
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV DEBIAN_FRONTEND=noninteractive
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     python3.10 \
+     python3-pip \
+     python3-dev \
+     git \
+     ffmpeg \
+     libsm6 \
+     libxext6 \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     && apt-get clean \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Create app directory
+ WORKDIR /app
+
+ # Install Python dependencies first for better caching
+ COPY requirements.txt .
+ RUN pip3 install --no-cache-dir -r requirements.txt
+
+ # actually we found a way to put flash attention inside the requirements.txt
+ # so we are good, we don't need this anymore:
+ # RUN pip3 install --no-cache-dir -r requirements_without_flash_attention.txt
+ # RUN pip3 install wheel setuptools flash-attn --no-build-isolation --no-cache-dir
+
+ # Copy application files
+ COPY . .
+
+ # Expose Gradio port
+ EXPOSE 7860
+
+ # Run the application
+ CMD ["python3", "app.py"]
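The Dockerfile assumes flash attention arrives through the prebuilt wheel pinned in `requirements.txt` (the commented-out `pip3 install flash-attn` lines are kept only as a reference to the old fallback path). As a hedged illustration that is not part of this commit, a small runtime check like the sketch below can confirm whether that wheel actually imported in a given environment; only the Python standard library and the `flash_attn` package name are assumed.

```python
# Hedged sketch, not part of this commit: detect whether the flash-attn wheel
# from requirements.txt is importable, so the app can fall back gracefully.
import importlib.util


def flash_attention_available() -> bool:
    """Return True if the flash_attn package can be imported in this environment."""
    return importlib.util.find_spec("flash_attn") is not None


if __name__ == "__main__":
    if flash_attention_available():
        print("flash-attn is installed; attention layers can use it.")
    else:
        print("flash-attn not found; falling back to standard PyTorch attention.")
```

A check like this is mostly relevant on the degraded install path described in the README, where `setup_no_captions.sh` presumably skips the flash-attention wheel entirely.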
README.md ADDED
@@ -0,0 +1,96 @@
+ ---
+ title: Video Model Studio
+ emoji: 🎥
+ colorFrom: gray
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 5.15.0
+ app_file: app.py
+ pinned: true
+ license: apache-2.0
+ short_description: All-in-one tool for AI video training
+ ---
+
+ # 🎥 Video Model Studio (VMS)
+
+ ## Presentation
+
+ VMS is an all-in-one tool to train LoRA models for various open-source AI video models:
+
+ - Data collection from various sources
+ - Splitting videos into short single-camera shots
+ - Automatic captioning
+ - Training HunyuanVideo or LTX-Video
+
+ ## Similar projects
+
+ I wasn't aware of it when I started this project,
+ but there is also this: https://github.com/alisson-anjos/diffusion-pipe-ui
+
+ ## Installation
+
+ VMS is built on top of Finetrainers and Gradio, and is designed to run as a Hugging Face Space (but you can deploy it elsewhere if you want to).
+
+ ### Full installation at Hugging Face
+
+ Easy peasy: create a Space (make sure to use the `Gradio` type/template), and push the repo. No Docker needed!
+
+ ### Dev mode on Hugging Face
+
+ Enable dev mode in the Space, then open VSCode locally or remotely and run:
+
+ ```
+ pip install -r requirements.txt
+ ```
+
+ As this is not automatic, you then need to click on "Restart" in the Space dev mode UI widget.
+
+ ### Full installation somewhere else
+
+ I haven't tested it, but you can try the provided Dockerfile.
+
+ ### Full installation in local
+
+ The full installation requires:
+ - Linux
+ - CUDA 12
+ - Python 3.10
+
+ This is because of flash attention, which is defined in `requirements.txt` using a URL to download a prebuilt wheel (Python bindings for a native library).
+
+ ```bash
+ ./setup.sh
+ ```
+
+ ### Degraded installation in local
+
+ If you cannot meet the requirements, you can:
+
+ - solution 1: fix `requirements.txt` to use another prebuilt wheel
+ - solution 2: manually build/install flash attention
+ - solution 3: don't use clip captioning
+
+ Here is how to do solution 3:
+ ```bash
+ ./setup_no_captions.sh
+ ```
+
+ ## Run
+
+ ### Running the Gradio app
+
+ Note: please make sure you properly define the environment variables `STORAGE_PATH` (eg. `/data/`) and `HF_HOME` (eg. `/data/huggingface/`).
+
+ ```bash
+ python app.py
+ ```
+
+ ### Running locally
+
+ See the remarks above about the environment variables.
+
+ By default `run.sh` will store stuff in `.data/` (located inside the current working directory):
+
+ ```bash
+ ./run.sh
+ ```
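The README asks for `STORAGE_PATH` (eg. `/data/`) and `HF_HOME` (eg. `/data/huggingface/`) to be set before launching `app.py`; the actual handling lives in `config.py`, which is only listed above. The standalone snippet below is a hedged sketch (an illustration, not the project's code) of one way to read those variables, using the `.data/` default that `run.sh` relies on.

```python
# Hedged sketch only: read STORAGE_PATH and HF_HOME with the .data/ default
# mentioned in the README. The real logic lives in config.py (not shown here).
import os
from pathlib import Path

STORAGE_PATH = Path(os.environ.get("STORAGE_PATH", ".data"))  # eg. /data/ on a Space
HF_HOME = Path(os.environ.get("HF_HOME", STORAGE_PATH / "huggingface"))

# Make sure both directories exist before anything tries to write to them.
STORAGE_PATH.mkdir(parents=True, exist_ok=True)
HF_HOME.mkdir(parents=True, exist_ok=True)

print(f"storing datasets/models under {STORAGE_PATH}, HF cache under {HF_HOME}")
```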
accelerate_configs/compiled_1.yaml ADDED
@@ -0,0 +1,22 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: 'NO'
+ downcast_bf16: 'no'
+ dynamo_config:
+   dynamo_backend: INDUCTOR
+   dynamo_mode: max-autotune
+   dynamo_use_dynamic: true
+   dynamo_use_fullgraph: false
+ enable_cpu_affinity: false
+ gpu_ids: '3'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 1
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
accelerate_configs/deepspeed.yaml ADDED
@@ -0,0 +1,23 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   gradient_accumulation_steps: 1
+   gradient_clipping: 1.0
+   offload_optimizer_device: cpu
+   offload_param_device: cpu
+   zero3_init_flag: false
+   zero_stage: 2
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 2
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
accelerate_configs/uncompiled_1.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: 'NO'
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: '3'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 1
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
accelerate_configs/uncompiled_2.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: 0,1
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 2
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
accelerate_configs/uncompiled_8.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
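The `uncompiled_*.yaml` files above differ mainly in `gpu_ids`/`num_processes` (1, 2, or 8 GPUs), `compiled_1.yaml` adds a torch.compile (INDUCTOR) dynamo section, and `deepspeed.yaml` enables ZeRO stage 2 with CPU offload across 2 processes. As a hedged sketch that is not part of this commit, the snippet below picks one of the bundled configs from the visible GPU count and hands it to `accelerate launch`; the `train.py` entry point name is a placeholder assumption.

```python
# Hedged sketch, not part of this commit: choose an accelerate config from the
# GPU count and launch training with it. "train.py" is a placeholder name.
import subprocess

import torch


def pick_accelerate_config(num_gpus: int) -> str:
    if num_gpus >= 8:
        return "accelerate_configs/uncompiled_8.yaml"
    if num_gpus >= 2:
        return "accelerate_configs/uncompiled_2.yaml"
    return "accelerate_configs/uncompiled_1.yaml"


if __name__ == "__main__":
    config = pick_accelerate_config(torch.cuda.device_count())
    # Equivalent CLI: accelerate launch --config_file <config> train.py ...
    subprocess.run(["accelerate", "launch", "--config_file", config, "train.py"], check=False)
```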
app.py ADDED
@@ -0,0 +1,1270 @@
1
+ import platform
2
+ import subprocess
3
+
4
+ #import sys
5
+ #print("python = ", sys.version)
6
+
7
+ # can be "Linux", "Darwin"
8
+ if platform.system() == "Linux":
9
+ # for some reason it says "pip not found"
10
+ # and also "pip3 not found"
11
+ # subprocess.run(
12
+ # "pip install flash-attn --no-build-isolation",
13
+ #
14
+ # # hmm... this should be False, since we are in a CUDA environment, no?
15
+ # env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
16
+ #
17
+ # shell=True,
18
+ # )
19
+ pass
20
+
21
+ import gradio as gr
22
+ from pathlib import Path
23
+ import logging
24
+ import mimetypes
25
+ import shutil
26
+ import os
27
+ import traceback
28
+ import asyncio
29
+ import tempfile
30
+ import zipfile
31
+ from typing import Any, Optional, Dict, List, Union, Tuple
32
+ from typing import AsyncGenerator
33
+ from training_service import TrainingService
34
+ from captioning_service import CaptioningService
35
+ from splitting_service import SplittingService
36
+ from import_service import ImportService
37
+ from config import (
38
+ STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
39
+ TRAINING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
40
+ DEFAULT_PROMPT_PREFIX, HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, TRAINING_BUCKETS
41
+ )
42
+ from utils import make_archive, count_media_files, format_media_title, is_image_file, is_video_file, validate_model_repo, format_time
43
+ from finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset
44
+ from training_log_parser import TrainingLogParser
45
+
46
+ logger = logging.getLogger(__name__)
47
+ logger.setLevel(logging.INFO)
48
+
49
+ httpx_logger = logging.getLogger('httpx')
50
+ httpx_logger.setLevel(logging.WARN)
51
+
52
+
53
+ class VideoTrainerUI:
54
+ def __init__(self):
55
+ self.trainer = TrainingService()
56
+ self.splitter = SplittingService()
57
+ self.importer = ImportService()
58
+ self.captioner = CaptioningService()
59
+ self._should_stop_captioning = False
60
+ self.log_parser = TrainingLogParser()
61
+
62
+ def update_training_ui(self, training_state: Dict[str, Any]):
63
+ """Update UI components based on training state"""
64
+ updates = {}
65
+
66
+ # Update status box with high-level information
67
+ status_text = []
68
+ if training_state["status"] != "idle":
69
+ status_text.extend([
70
+ f"Status: {training_state['status']}",
71
+ f"Progress: {training_state['progress']}",
72
+ f"Step: {training_state['current_step']}/{training_state['total_steps']}",
73
+
74
+ # Epoch information
75
+ # there is an issue with how epoch is reported because we display:
76
+ # Progress: 96.9%, Step: 872/900, Epoch: 12/50
77
+ # we should probably just show the steps
78
+ #f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",
79
+
80
+ f"Time elapsed: {training_state['elapsed']}",
81
+ f"Estimated remaining: {training_state['remaining']}",
82
+ "",
83
+ f"Current loss: {training_state['step_loss']}",
84
+ f"Learning rate: {training_state['learning_rate']}",
85
+ f"Gradient norm: {training_state['grad_norm']}",
86
+ f"Memory usage: {training_state['memory']}"
87
+ ])
88
+
89
+ if training_state["error_message"]:
90
+ status_text.append(f"\nError: {training_state['error_message']}")
91
+
92
+ updates["status_box"] = "\n".join(status_text)
93
+
94
+ # Update button states
95
+ updates["start_btn"] = gr.Button(
96
+ "Start training",
97
+ interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
98
+ variant="primary" if training_state["status"] == "idle" else "secondary"
99
+ )
100
+
101
+ updates["stop_btn"] = gr.Button(
102
+ "Stop training",
103
+ interactive=(training_state["status"] in ["training", "initializing"]),
104
+ variant="stop"
105
+ )
106
+
107
+ return updates
108
+
109
+ def stop_all_and_clear(self) -> Dict[str, str]:
110
+ """Stop all running processes and clear data
111
+
112
+ Returns:
113
+ Dict with status messages for different components
114
+ """
115
+ status_messages = {}
116
+
117
+ try:
118
+ # Stop training if running
119
+ if self.trainer.is_training_running():
120
+ training_result = self.trainer.stop_training()
121
+ status_messages["training"] = training_result["status"]
122
+
123
+ # Stop captioning if running
124
+ if self.captioner:
125
+ self.captioner.stop_captioning()
126
+ #self.captioner.close()
127
+ #self.captioner = None
128
+ status_messages["captioning"] = "Captioning stopped"
129
+
130
+ # Stop scene detection if running
131
+ if self.splitter.is_processing():
132
+ self.splitter.processing = False
133
+ status_messages["splitting"] = "Scene detection stopped"
134
+
135
+ # Clear all data directories
136
+ for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
137
+ MODEL_PATH, OUTPUT_PATH]:
138
+ if path.exists():
139
+ try:
140
+ shutil.rmtree(path)
141
+ path.mkdir(parents=True, exist_ok=True)
142
+ except Exception as e:
143
+ status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}"
144
+ else:
145
+ status_messages[f"clear_{path.name}"] = f"Cleared {path.name}"
146
+
147
+ # Reset any persistent state
148
+ self._should_stop_captioning = True
149
+ self.splitter.processing = False
150
+
151
+ return {
152
+ "status": "All processes stopped and data cleared",
153
+ "details": status_messages
154
+ }
155
+
156
+ except Exception as e:
157
+ return {
158
+ "status": f"Error during cleanup: {str(e)}",
159
+ "details": status_messages
160
+ }
161
+
162
+ def update_titles(self) -> Tuple[Any]:
163
+ """Update all dynamic titles with current counts
164
+
165
+ Returns:
166
+ Dict of Gradio updates
167
+ """
168
+ # Count files for splitting
169
+ split_videos, _, split_size = count_media_files(VIDEOS_TO_SPLIT_PATH)
170
+ split_title = format_media_title(
171
+ "split", split_videos, 0, split_size
172
+ )
173
+
174
+ # Count files for captioning
175
+ caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH)
176
+ caption_title = format_media_title(
177
+ "caption", caption_videos, caption_images, caption_size
178
+ )
179
+
180
+ # Count files for training
181
+ train_videos, train_images, train_size = count_media_files(TRAINING_VIDEOS_PATH)
182
+ train_title = format_media_title(
183
+ "train", train_videos, train_images, train_size
184
+ )
185
+
186
+ return (
187
+ gr.Markdown(value=split_title),
188
+ gr.Markdown(value=caption_title),
189
+ gr.Markdown(value=f"{train_title} available for training")
190
+ )
191
+
192
+ def copy_files_to_training_dir(self, prompt_prefix: str):
193
+ """Run auto-captioning process"""
194
+
195
+ # Initialize captioner if not already done
196
+ self._should_stop_captioning = False
197
+
198
+ try:
199
+ copy_files_to_training_dir(prompt_prefix)
200
+
201
+ except Exception as e:
202
+ traceback.print_exc()
203
+ raise gr.Error(f"Error copying assets to training dir: {str(e)}")
204
+
205
+ async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[gr.update, None]:
206
+ """Run auto-captioning process"""
207
+ try:
208
+ # Initialize captioner if not already done
209
+ self._should_stop_captioning = False
210
+
211
+ async for rows in self.captioner.start_caption_generation(captioning_bot_instructions, prompt_prefix):
212
+ # Yield UI update
213
+ yield gr.update(
214
+ value=rows,
215
+ headers=["name", "status"]
216
+ )
217
+
218
+ # Final update after completion
219
+ yield gr.update(
220
+ value=self.list_training_files_to_caption(),
221
+ headers=["name", "status"]
222
+ )
223
+
224
+ except Exception as e:
225
+ yield gr.update(
226
+ value=[[str(e), "error"]],
227
+ headers=["name", "status"]
228
+ )
229
+
230
+ def list_training_files_to_caption(self) -> List[List[str]]:
231
+ """List all clips and images - both pending and captioned"""
232
+ files = []
233
+ already_listed: Dict[str, bool] = {}
234
+
235
+ # Check files in STAGING_PATH
236
+ for file in STAGING_PATH.glob("*.*"):
237
+ if is_video_file(file) or is_image_file(file):
238
+ txt_file = file.with_suffix('.txt')
239
+ status = "captioned" if txt_file.exists() else "no caption"
240
+ file_type = "video" if is_video_file(file) else "image"
241
+ files.append([file.name, f"{status} ({file_type})", str(file)])
242
+ already_listed[str(file.name)] = True
243
+
244
+ # Check files in TRAINING_VIDEOS_PATH
245
+ for file in TRAINING_VIDEOS_PATH.glob("*.*"):
246
+ if not str(file.name) in already_listed:
247
+ if is_video_file(file) or is_image_file(file):
248
+ txt_file = file.with_suffix('.txt')
249
+ if txt_file.exists():
250
+ file_type = "video" if is_video_file(file) else "image"
251
+ files.append([file.name, f"captioned ({file_type})", str(file)])
252
+
253
+ # Sort by filename
254
+ files.sort(key=lambda x: x[0])
255
+
256
+ # Only return name and status columns for display
257
+ return [[file[0], file[1]] for file in files]
258
+
259
+ def update_training_buttons(self, training_state: Dict[str, Any]) -> Dict:
260
+ """Update training control buttons based on state"""
261
+ is_training = training_state["status"] in ["training", "initializing"]
262
+ is_paused = training_state["status"] == "paused"
263
+ is_completed = training_state["status"] in ["completed", "error", "stopped"]
264
+
265
+ return {
266
+ start_btn: gr.Button(
267
+ interactive=not is_training and not is_paused,
268
+ variant="primary" if not is_training else "secondary",
269
+ ),
270
+ stop_btn: gr.Button(
271
+ interactive=is_training or is_paused,
272
+ variant="stop",
273
+ ),
274
+ pause_resume_btn: gr.Button(
275
+ value="Resume Training" if is_paused else "Pause Training",
276
+ interactive=(is_training or is_paused) and not is_completed,
277
+ variant="secondary",
278
+ )
279
+ }
280
+
281
+ def handle_training_complete(self):
282
+ """Handle training completion"""
283
+ # Reset button states
284
+ return self.update_training_buttons({
285
+ "status": "completed",
286
+ "progress": "100%",
287
+ "current_step": 0,
288
+ "total_steps": 0
289
+ })
290
+
291
+ def handle_pause_resume(self):
292
+ status = self.trainer.get_status()
293
+ if status["state"] == "paused":
294
+ result = self.trainer.resume_training()
295
+ new_state = {"status": "training"}
296
+ else:
297
+ result = self.trainer.pause_training()
298
+ new_state = {"status": "paused"}
299
+ return (
300
+ *result,
301
+ *self.update_training_buttons(new_state).values()
302
+ )
303
+
304
+
305
+ def handle_training_dataset_select(self, evt: gr.SelectData) -> Tuple[Optional[str], Optional[str], Optional[str]]:
306
+ """Handle selection of both video clips and images"""
307
+ try:
308
+ if not evt:
309
+ return [
310
+ gr.Image(
311
+ interactive=False,
312
+ visible=False
313
+ ),
314
+ gr.Video(
315
+ interactive=False,
316
+ visible=False
317
+ ),
318
+ gr.Textbox(
319
+ visible=False
320
+ ),
321
+ "No file selected"
322
+ ]
323
+
324
+ file_name = evt.value
325
+ if not file_name:
326
+ return [
327
+ gr.Image(
328
+ interactive=False,
329
+ visible=False
330
+ ),
331
+ gr.Video(
332
+ interactive=False,
333
+ visible=False
334
+ ),
335
+ gr.Textbox(
336
+ visible=False
337
+ ),
338
+ "No file selected"
339
+ ]
340
+
341
+ # Check both possible locations for the file
342
+ possible_paths = [
343
+ STAGING_PATH / file_name,
344
+
345
+ # note: we use to look into this dir for already-captioned clips,
346
+ # but we don't do this anymore
347
+ #TRAINING_VIDEOS_PATH / file_name
348
+ ]
349
+
350
+ # Find the first existing file path
351
+ file_path = None
352
+ for path in possible_paths:
353
+ if path.exists():
354
+ file_path = path
355
+ break
356
+
357
+ if not file_path:
358
+ return [
359
+ gr.Image(
360
+ interactive=False,
361
+ visible=False
362
+ ),
363
+ gr.Video(
364
+ interactive=False,
365
+ visible=False
366
+ ),
367
+ gr.Textbox(
368
+ visible=False
369
+ ),
370
+ f"File not found: {file_name}"
371
+ ]
372
+
373
+ txt_path = file_path.with_suffix('.txt')
374
+ caption = txt_path.read_text() if txt_path.exists() else ""
375
+
376
+ # Handle video files
377
+ if is_video_file(file_path):
378
+ return [
379
+ gr.Image(
380
+ interactive=False,
381
+ visible=False
382
+ ),
383
+ gr.Video(
384
+ label="Video Preview",
385
+ interactive=False,
386
+ visible=True,
387
+ value=str(file_path)
388
+ ),
389
+ gr.Textbox(
390
+ label="Caption",
391
+ lines=6,
392
+ interactive=True,
393
+ visible=True,
394
+ value=str(caption)
395
+ ),
396
+ None
397
+ ]
398
+ # Handle image files
399
+ elif is_image_file(file_path):
400
+ return [
401
+ gr.Image(
402
+ label="Image Preview",
403
+ interactive=False,
404
+ visible=True,
405
+ value=str(file_path)
406
+ ),
407
+ gr.Video(
408
+ interactive=False,
409
+ visible=False
410
+ ),
411
+ gr.Textbox(
412
+ label="Caption",
413
+ lines=6,
414
+ interactive=True,
415
+ visible=True,
416
+ value=str(caption)
417
+ ),
418
+ None
419
+ ]
420
+ else:
421
+ return [
422
+ gr.Image(
423
+ interactive=False,
424
+ visible=False
425
+ ),
426
+ gr.Video(
427
+ interactive=False,
428
+ visible=False
429
+ ),
430
+ gr.Textbox(
431
+ interactive=False,
432
+ visible=False
433
+ ),
434
+ f"Unsupported file type: {file_path.suffix}"
435
+ ]
436
+ except Exception as e:
437
+ logger.error(f"Error handling selection: {str(e)}")
438
+ return [
439
+ gr.Image(
440
+ interactive=False,
441
+ visible=False
442
+ ),
443
+ gr.Video(
444
+ interactive=False,
445
+ visible=False
446
+ ),
447
+ gr.Textbox(
448
+ interactive=False,
449
+ visible=False
450
+ ),
451
+ f"Error handling selection: {str(e)}"
452
+ ]
453
+
454
+ def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, prompt_prefix: str):
455
+ """Save changes to caption"""
456
+ try:
457
+ # Add prefix if not already present
458
+ if prompt_prefix and not preview_caption.startswith(prompt_prefix):
459
+ full_caption = f"{prompt_prefix}{preview_caption}"
460
+ else:
461
+ full_caption = preview_caption
462
+
463
+ path = Path(preview_video if preview_video else preview_image)
464
+ if path.suffix == '.txt':
465
+ self.trainer.update_file_caption(path.with_suffix(''), full_caption)
466
+ else:
467
+ self.trainer.update_file_caption(path, full_caption)
468
+ return gr.update(value="Caption saved successfully!")
469
+ except Exception as e:
470
+ return gr.update(value=f"Error saving caption: {str(e)}")
471
+
472
+ def get_model_info(self, model_type: str) -> str:
473
+ """Get information about the selected model type"""
474
+ if model_type == "hunyuan_video":
475
+ return """### HunyuanVideo (LoRA)
476
+ - Best for learning complex video generation patterns
477
+ - Required VRAM: ~47GB minimum
478
+ - Recommended batch size: 1-2
479
+ - Typical training time: 2-4 hours
480
+ - Default resolution: 49x512x768
481
+ - Default LoRA rank: 128"""
482
+
483
+ elif model_type == "ltx_video":
484
+ return """### LTX-Video (LoRA)
485
+ - Lightweight video model
486
+ - Required VRAM: ~18GB minimum
487
+ - Recommended batch size: 1-4
488
+ - Typical training time: 1-3 hours
489
+ - Default resolution: 49x512x768
490
+ - Default LoRA rank: 128"""
491
+
492
+ return ""
493
+
494
+ def get_default_params(self, model_type: str) -> Dict[str, Any]:
495
+ """Get default training parameters for model type"""
496
+ if model_type == "hunyuan_video":
497
+ return {
498
+ "num_epochs": 70,
499
+ "batch_size": 1,
500
+ "learning_rate": 2e-5,
501
+ "save_iterations": 500,
502
+ "video_resolution_buckets": TRAINING_BUCKETS,
503
+ "video_reshape_mode": "center",
504
+ "caption_dropout_p": 0.05,
505
+ "gradient_accumulation_steps": 1,
506
+ "rank": 128,
507
+ "lora_alpha": 128
508
+ }
509
+ else: # ltx_video
510
+ return {
511
+ "num_epochs": 70,
512
+ "batch_size": 1,
513
+ "learning_rate": 3e-5,
514
+ "save_iterations": 500,
515
+ "video_resolution_buckets": TRAINING_BUCKETS,
516
+ "video_reshape_mode": "center",
517
+ "caption_dropout_p": 0.05,
518
+ "gradient_accumulation_steps": 4,
519
+ "rank": 128,
520
+ "lora_alpha": 128
521
+ }
522
+
523
+ def preview_file(self, selected_text: str) -> Dict:
524
+ """Generate preview based on selected file
525
+
526
+ Args:
527
+ selected_text: Text of the selected item containing filename
528
+
529
+ Returns:
530
+ Dict with preview content for each preview component
531
+ """
532
+ if not selected_text or "Caption:" in selected_text:
533
+ return {
534
+ "video": None,
535
+ "image": None,
536
+ "text": None
537
+ }
538
+
539
+ # Extract filename from the preview text (remove size info)
540
+ filename = selected_text.split(" (")[0].strip()
541
+ file_path = TRAINING_VIDEOS_PATH / filename
542
+
543
+ if not file_path.exists():
544
+ return {
545
+ "video": None,
546
+ "image": None,
547
+ "text": f"File not found: {filename}"
548
+ }
549
+
550
+ # Detect file type
551
+ mime_type, _ = mimetypes.guess_type(str(file_path))
552
+ if not mime_type:
553
+ return {
554
+ "video": None,
555
+ "image": None,
556
+ "text": f"Unknown file type: {filename}"
557
+ }
558
+
559
+ # Return appropriate preview
560
+ if mime_type.startswith('video/'):
561
+ return {
562
+ "video": str(file_path),
563
+ "image": None,
564
+ "text": None
565
+ }
566
+ elif mime_type.startswith('image/'):
567
+ return {
568
+ "video": None,
569
+ "image": str(file_path),
570
+ "text": None
571
+ }
572
+ elif mime_type.startswith('text/'):
573
+ try:
574
+ text_content = file_path.read_text()
575
+ return {
576
+ "video": None,
577
+ "image": None,
578
+ "text": text_content
579
+ }
580
+ except Exception as e:
581
+ return {
582
+ "video": None,
583
+ "image": None,
584
+ "text": f"Error reading file: {str(e)}"
585
+ }
586
+ else:
587
+ return {
588
+ "video": None,
589
+ "image": None,
590
+ "text": f"Unsupported file type: {mime_type}"
591
+ }
592
+
593
+ def list_unprocessed_videos(self) -> gr.Dataframe:
594
+ """Update list of unprocessed videos"""
595
+ videos = self.splitter.list_unprocessed_videos()
596
+ # videos is already in [[name, status]] format from splitting_service
597
+ return gr.Dataframe(
598
+ headers=["name", "status"],
599
+ value=videos,
600
+ interactive=False
601
+ )
602
+
603
+ async def start_scene_detection(self, enable_splitting: bool) -> str:
604
+ """Start background scene detection process
605
+
606
+ Args:
607
+ enable_splitting: Whether to split videos into scenes
608
+ """
609
+ if self.splitter.is_processing():
610
+ return "Scene detection already running"
611
+
612
+ try:
613
+ await self.splitter.start_processing(enable_splitting)
614
+ return "Scene detection completed"
615
+ except Exception as e:
616
+ return f"Error during scene detection: {str(e)}"
617
+
618
+
619
+ def refresh_training_status_and_logs(self):
620
+ """Refresh all dynamic lists and training state"""
621
+ status = self.trainer.get_status()
622
+ logs = self.trainer.get_logs()
623
+
624
+ status_update = status["message"]
625
+
626
+ # Parse new log lines
627
+ if logs:
628
+ last_state = None
629
+ for line in logs.splitlines():
630
+ state_update = self.log_parser.parse_line(line)
631
+ if state_update:
632
+ last_state = state_update
633
+
634
+ if last_state:
635
+ ui_updates = self.update_training_ui(last_state)
636
+ status_update = ui_updates.get("status_box", status["message"])
637
+
638
+ return (status_update, logs)
639
+
640
+ def refresh_training_status(self):
641
+ """Refresh training status and update UI"""
642
+ status, logs = self.refresh_training_status_and_logs()
643
+
644
+ # Parse status for training state
645
+ is_completed = "completed" in status.lower() or "100.0%" in status
646
+ current_state = {
647
+ "status": "completed" if is_completed else "training",
648
+ "message": status
649
+ }
650
+
651
+ if is_completed:
652
+ button_updates = self.handle_training_complete()
653
+ return (
654
+ status,
655
+ logs,
656
+ *button_updates.values()
657
+ )
658
+
659
+ # Update based on current training state
660
+ button_updates = self.update_training_buttons(current_state)
661
+ return (
662
+ status,
663
+ logs,
664
+ *button_updates.values()
665
+ )
666
+
667
+ def refresh_dataset(self):
668
+ """Refresh all dynamic lists and training state"""
669
+ video_list = self.splitter.list_unprocessed_videos()
670
+ training_dataset = self.list_training_files_to_caption()
671
+
672
+ return (
673
+ video_list,
674
+ training_dataset
675
+ )
676
+
677
+ def create_ui(self):
678
+ """Create Gradio interface"""
679
+
680
+ with gr.Blocks(title="🎥 Video Model Studio") as app:
681
+ gr.Markdown("# 🎥 Video Model Studio")
682
+
683
+ with gr.Tabs() as tabs:
684
+ with gr.TabItem("1️⃣ Import", id="import_tab"):
685
+
686
+ with gr.Row():
687
+ gr.Markdown("## Optional: automated data cleaning")
688
+
689
+ with gr.Row():
690
+ enable_automatic_video_split = gr.Checkbox(
691
+ label="Automatically split videos into smaller clips",
692
+ info="Note: a clip is a single camera shot, usually a few seconds",
693
+ value=True,
694
+ visible=False
695
+ )
696
+ enable_automatic_content_captioning = gr.Checkbox(
697
+ label="Automatically caption photos and videos",
698
+ info="Note: this uses LLaVA and takes some extra time to load and process",
699
+ value=False,
700
+ visible=False,
701
+ )
702
+
703
+ with gr.Row():
704
+ with gr.Column(scale=3):
705
+ with gr.Row():
706
+ with gr.Column():
707
+ gr.Markdown("## Import video files")
708
+ gr.Markdown("You can upload either:")
709
+ gr.Markdown("- A single MP4 video file")
710
+ gr.Markdown("- A ZIP archive containing multiple videos and optional caption files")
711
+ gr.Markdown("For ZIP files: Create a folder containing videos (name is not important) and optional caption files with the same name (eg. `some_video.txt` for `some_video.mp4`)")
712
+
713
+ with gr.Row():
714
+ files = gr.Files(
715
+ label="Upload Images, Videos or ZIP",
716
+ #file_count="multiple",
717
+ file_types=[".jpg", ".jpeg", ".png", ".webp", ".avif", ".heic", ".mp4", ".zip"],
718
+ type="filepath"
719
+ )
720
+
721
+ with gr.Column(scale=3):
722
+ with gr.Row():
723
+ with gr.Column():
724
+ gr.Markdown("## Import a YouTube video")
725
+ gr.Markdown("You can also use a YouTube video as reference, by pasting its URL here:")
726
+
727
+ with gr.Row():
728
+ youtube_url = gr.Textbox(
729
+ label="Import YouTube Video",
730
+ placeholder="https://www.youtube.com/watch?v=..."
731
+ )
732
+ with gr.Row():
733
+ youtube_download_btn = gr.Button("Download YouTube Video", variant="secondary")
734
+ with gr.Row():
735
+ import_status = gr.Textbox(label="Status", interactive=False)
736
+
737
+
738
+ with gr.TabItem("2️⃣ Split", id="split_tab"):
739
+ with gr.Row():
740
+ split_title = gr.Markdown("## Splitting of 0 videos (0 bytes)")
741
+
742
+ with gr.Row():
743
+ with gr.Column():
744
+ detect_btn = gr.Button("Split videos into single-camera shots", variant="primary")
745
+ detect_status = gr.Textbox(label="Status", interactive=False)
746
+
747
+ with gr.Column():
748
+
749
+ video_list = gr.Dataframe(
750
+ headers=["name", "status"],
751
+ label="Videos to split",
752
+ interactive=False,
753
+ wrap=True,
754
+ #selection_mode="cell" # Enable cell selection
755
+ )
756
+
757
+
758
+ with gr.TabItem("3️⃣ Caption"):
759
+ with gr.Row():
760
+ caption_title = gr.Markdown("## Captioning of 0 files (0 bytes)")
761
+
762
+ with gr.Row():
763
+
764
+ with gr.Column():
765
+ with gr.Row():
766
+ custom_prompt_prefix = gr.Textbox(
767
+ scale=3,
768
+ label='Prefix to add to ALL captions (eg. "In the style of TOK, ")',
769
+ placeholder="In the style of TOK, ",
770
+ lines=2,
771
+ value=DEFAULT_PROMPT_PREFIX
772
+ )
773
+ captioning_bot_instructions = gr.Textbox(
774
+ scale=6,
775
+ label="System instructions for the automatic captioning model",
776
+ placeholder="Please generate a full description of...",
777
+ lines=5,
778
+ value=DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
779
+ )
780
+ with gr.Row():
781
+ run_autocaption_btn = gr.Button(
782
+ "Automatically fill missing captions",
783
+ variant="primary" # Makes it green by default
784
+ )
785
+ copy_files_to_training_dir_btn = gr.Button(
786
+ "Copy assets to training directory",
787
+ variant="primary" # Makes it green by default
788
+ )
789
+ stop_autocaption_btn = gr.Button(
790
+ "Stop Captioning",
791
+ variant="stop", # Red when enabled
792
+ interactive=False # Disabled by default
793
+ )
794
+
795
+ with gr.Row():
796
+ with gr.Column():
797
+ training_dataset = gr.Dataframe(
798
+ headers=["name", "status"],
799
+ interactive=False,
800
+ wrap=True,
801
+ value=self.list_training_files_to_caption(),
802
+ row_count=10, # Optional: set a reasonable row count
803
+ #selection_mode="cell"
804
+ )
805
+
806
+ with gr.Column():
807
+ preview_video = gr.Video(
808
+ label="Video Preview",
809
+ interactive=False,
810
+ visible=False
811
+ )
812
+ preview_image = gr.Image(
813
+ label="Image Preview",
814
+ interactive=False,
815
+ visible=False
816
+ )
817
+ preview_caption = gr.Textbox(
818
+ label="Caption",
819
+ lines=6,
820
+ interactive=True
821
+ )
822
+ save_caption_btn = gr.Button("Save Caption")
823
+ preview_status = gr.Textbox(
824
+ label="Status",
825
+ interactive=False,
826
+ visible=True
827
+ )
828
+
829
+ with gr.TabItem("4️⃣ Train"):
830
+ with gr.Row():
831
+ with gr.Column():
832
+
833
+ with gr.Row():
834
+ train_title = gr.Markdown("## 0 files available for training (0 bytes)")
835
+
836
+ with gr.Row():
837
+ with gr.Column():
838
+ model_type = gr.Dropdown(
839
+ choices=list(MODEL_TYPES.keys()),
840
+ label="Model Type",
841
+ value=list(MODEL_TYPES.keys())[0]
842
+ )
843
+ model_info = gr.Markdown(
844
+ value=self.get_model_info(list(MODEL_TYPES.keys())[0])
845
+ )
846
+
847
+ with gr.Row():
848
+ lora_rank = gr.Dropdown(
849
+ label="LoRA Rank",
850
+ choices=["16", "32", "64", "128", "256"],
851
+ value="128",
852
+ type="value"
853
+ )
854
+ lora_alpha = gr.Dropdown(
855
+ label="LoRA Alpha",
856
+ choices=["16", "32", "64", "128", "256"],
857
+ value="128",
858
+ type="value"
859
+ )
860
+ with gr.Row():
861
+ num_epochs = gr.Number(
862
+ label="Number of Epochs",
863
+ value=70,
864
+ minimum=1,
865
+ precision=0
866
+ )
867
+ batch_size = gr.Number(
868
+ label="Batch Size",
869
+ value=1,
870
+ minimum=1,
871
+ precision=0
872
+ )
873
+ with gr.Row():
874
+ learning_rate = gr.Number(
875
+ label="Learning Rate",
876
+ value=2e-5,
877
+ minimum=1e-7
878
+ )
879
+ save_iterations = gr.Number(
880
+ label="Save checkpoint every N iterations",
881
+ value=500,
882
+ minimum=50,
883
+ precision=0,
884
+ info="Model will be saved periodically after these many steps"
885
+ )
886
+
887
+ with gr.Column():
888
+ with gr.Row():
889
+ start_btn = gr.Button(
890
+ "Start Training",
891
+ variant="primary",
892
+ interactive=not ASK_USER_TO_DUPLICATE_SPACE
893
+ )
894
+ pause_resume_btn = gr.Button(
895
+ "Resume Training",
896
+ variant="secondary",
897
+ interactive=False
898
+ )
899
+ stop_btn = gr.Button(
900
+ "Stop Training",
901
+ variant="stop",
902
+ interactive=False
903
+ )
904
+
905
+ with gr.Row():
906
+ with gr.Column():
907
+ status_box = gr.Textbox(
908
+ label="Training Status",
909
+ interactive=False,
910
+ lines=4
911
+ )
912
+ log_box = gr.TextArea(
913
+ label="Training Logs",
914
+ interactive=False,
915
+ lines=10,
916
+ max_lines=40,
917
+ autoscroll=True
918
+ )
919
+
920
+ with gr.TabItem("5️⃣ Manage"):
921
+
922
+ with gr.Column():
923
+ with gr.Row():
924
+ with gr.Column():
925
+ gr.Markdown("## Publishing")
926
+ gr.Markdown("Your model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
927
+
928
+ with gr.Row():
929
+
930
+ with gr.Column():
931
+ repo_id = gr.Textbox(
932
+ label="HuggingFace Model Repository",
933
+ placeholder="username/model-name",
934
+ info="The repository will be created if it doesn't exist"
935
+ )
936
+ gr.Checkbox(label="Check this to make your model public (ie. visible and downloadable by anyone)", info="Your model is private by default")
937
+ push_model_btn = gr.Button(
938
+ "Push my model",
939
+ #variant="stop"
940
+ )
941
+
942
+
943
+ with gr.Row():
944
+ with gr.Column():
945
+ with gr.Row():
946
+ with gr.Column():
947
+ gr.Markdown("## Storage management")
948
+ with gr.Row():
949
+ download_dataset_btn = gr.DownloadButton(
950
+ "Download dataset",
951
+ variant="secondary",
952
+ size="lg"
953
+ )
954
+ download_model_btn = gr.DownloadButton(
955
+ "Download model",
956
+ variant="secondary",
957
+ size="lg"
958
+ )
959
+
960
+
961
+ with gr.Row():
962
+ global_stop_btn = gr.Button(
963
+ "Stop everything and delete my data",
964
+ variant="stop"
965
+ )
966
+ global_status = gr.Textbox(
967
+ label="Global Status",
968
+ interactive=False,
969
+ visible=False
970
+ )
971
+
972
+
973
+
974
+ # Event handlers
975
+ def update_model_info(model):
976
+ params = self.get_default_params(MODEL_TYPES[model])
977
+ info = self.get_model_info(MODEL_TYPES[model])
978
+ return {
979
+ model_info: info,
980
+ num_epochs: params["num_epochs"],
981
+ batch_size: params["batch_size"],
982
+ learning_rate: params["learning_rate"],
983
+ save_iterations: params["save_iterations"]
984
+ }
985
+
986
+ def validate_repo(repo_id: str) -> dict:
987
+ validation = validate_model_repo(repo_id)
988
+ if validation["error"]:
989
+ return gr.update(value=repo_id, error=validation["error"])
990
+ return gr.update(value=repo_id, error=None)
991
+
992
+ # Connect events
993
+ model_type.change(
994
+ fn=update_model_info,
995
+ inputs=[model_type],
996
+ outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations]
997
+ )
998
+
999
+ async def on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
1000
+ videos = self.list_unprocessed_videos()
1001
+ # If scene detection isn't already running and there are videos to process,
1002
+ # and auto-splitting is enabled, start the detection
1003
+ if videos and not self.splitter.is_processing() and enable_splitting:
1004
+ await self.start_scene_detection(enable_splitting)
1005
+ msg = "Starting automatic scene detection..."
1006
+ else:
1007
+ # Just copy files without splitting if auto-split disabled
1008
+ for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
1009
+ await self.splitter.process_video(video_file, enable_splitting=False)
1010
+ msg = "Copying videos without splitting..."
1011
+
1012
+ copy_files_to_training_dir(prompt_prefix)
1013
+
1014
+ # Start auto-captioning if enabled
1015
+ if enable_automatic_content_captioning:
1016
+ await self.start_caption_generation(
1017
+ DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
1018
+ prompt_prefix
1019
+ )
1020
+
1021
+ return {
1022
+ tabs: gr.Tabs(selected="split_tab"),
1023
+ video_list: videos,
1024
+ detect_status: msg
1025
+ }
1026
+
1027
+
1028
+ async def update_titles_after_import(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
1029
+ """Handle post-import updates including titles"""
1030
+ import_result = await on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix)
1031
+ titles = self.update_titles()
1032
+ return (*import_result, *titles)
1033
+
1034
+ files.upload(
1035
+ fn=lambda x: self.importer.process_uploaded_files(x),
1036
+ inputs=[files],
1037
+ outputs=[import_status]
1038
+ ).success(
1039
+ fn=update_titles_after_import,
1040
+ inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
1041
+ outputs=[
1042
+ tabs, video_list, detect_status,
1043
+ split_title, caption_title, train_title
1044
+ ]
1045
+ )
1046
+
1047
+ youtube_download_btn.click(
1048
+ fn=self.importer.download_youtube_video,
1049
+ inputs=[youtube_url],
1050
+ outputs=[import_status]
1051
+ ).success(
1052
+ fn=on_import_success,
1053
+ inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
1054
+ outputs=[tabs, video_list, detect_status]
1055
+ )
1056
+
1057
+ # Scene detection events
1058
+ detect_btn.click(
1059
+ fn=self.start_scene_detection,
1060
+ inputs=[enable_automatic_video_split],
1061
+ outputs=[detect_status]
1062
+ )
1063
+
1064
+
1065
+ # Update button states based on captioning status
1066
+ def update_button_states(is_running):
1067
+ return {
1068
+ run_autocaption_btn: gr.Button(
1069
+ interactive=not is_running,
1070
+ variant="secondary" if is_running else "primary",
1071
+ ),
1072
+ stop_autocaption_btn: gr.Button(
1073
+ interactive=is_running,
1074
+ variant="secondary",
1075
+ ),
1076
+ }
1077
+
1078
+ run_autocaption_btn.click(
1079
+ fn=self.start_caption_generation,
1080
+ inputs=[captioning_bot_instructions, custom_prompt_prefix],
1081
+ outputs=[training_dataset],
1082
+ ).then(
1083
+ fn=lambda: update_button_states(True),
1084
+ outputs=[run_autocaption_btn, stop_autocaption_btn]
1085
+ )
1086
+
1087
+ copy_files_to_training_dir_btn.click(
1088
+ fn=self.copy_files_to_training_dir,
1089
+ inputs=[custom_prompt_prefix]
1090
+ )
1091
+
1092
+ stop_autocaption_btn.click(
1093
+ fn=lambda: (self.captioner.stop_captioning() if self.captioner else None, update_button_states(False)),
1094
+ outputs=[run_autocaption_btn, stop_autocaption_btn]
1095
+ )
1096
+
1097
+ training_dataset.select(
1098
+ fn=self.handle_training_dataset_select,
1099
+ outputs=[preview_image, preview_video, preview_caption, preview_status]
1100
+ )
1101
+
1102
+ save_caption_btn.click(
1103
+ fn=self.save_caption_changes,
1104
+ inputs=[preview_caption, preview_image, preview_video, custom_prompt_prefix],
1105
+ outputs=[preview_status]
1106
+ ).success(
1107
+ fn=self.list_training_files_to_caption,
1108
+ outputs=[training_dataset]
1109
+ )
1110
+
1111
+ # Training control events
1112
+ start_btn.click(
1113
+ fn=lambda model_type, *args: (
1114
+ self.log_parser.reset(),
1115
+ self.trainer.start_training(
1116
+ MODEL_TYPES[model_type],
1117
+ *args
1118
+ )
1119
+ ),
1120
+ inputs=[
1121
+ model_type,
1122
+ lora_rank,
1123
+ lora_alpha,
1124
+ num_epochs,
1125
+ batch_size,
1126
+ learning_rate,
1127
+ save_iterations,
1128
+ repo_id
1129
+ ],
1130
+ outputs=[status_box, log_box]
1131
+ ).success(
1132
+ fn=lambda: self.update_training_buttons({
1133
+ "status": "training"
1134
+ }),
1135
+ outputs=[start_btn, stop_btn, pause_resume_btn]
1136
+ )
1137
+
1138
+
1139
+ pause_resume_btn.click(
1140
+ fn=self.handle_pause_resume,
1141
+ outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
1142
+ )
1143
+
1144
+ stop_btn.click(
1145
+ fn=self.trainer.stop_training,
1146
+ outputs=[status_box, log_box]
1147
+ ).success(
1148
+ fn=self.handle_training_complete,
1149
+ outputs=[start_btn, stop_btn, pause_resume_btn]
1150
+ )
1151
+
1152
+ def handle_global_stop():
1153
+ result = self.stop_all_and_clear()
1154
+ # Update all relevant UI components
1155
+ status = result["status"]
1156
+ details = "\n".join(f"{k}: {v}" for k, v in result["details"].items())
1157
+ full_status = f"{status}\n\nDetails:\n{details}"
1158
+
1159
+ # Get fresh lists after cleanup
1160
+ videos = self.splitter.list_unprocessed_videos()
1161
+ clips = self.list_training_files_to_caption()
1162
+
1163
+ return {
1164
+ global_status: gr.update(value=full_status, visible=True),
1165
+ video_list: videos,
1166
+ training_dataset: clips,
1167
+ status_box: "Training stopped and data cleared",
1168
+ log_box: "",
1169
+ detect_status: "Scene detection stopped",
1170
+ import_status: "All data cleared",
1171
+ preview_status: "Captioning stopped"
1172
+ }
1173
+
1174
+ download_dataset_btn.click(
1175
+ fn=self.trainer.create_training_dataset_zip,
1176
+ outputs=[download_dataset_btn]
1177
+ )
1178
+
1179
+ download_model_btn.click(
1180
+ fn=self.trainer.get_model_output_safetensors,
1181
+ outputs=[download_model_btn]
1182
+ )
1183
+
1184
+ global_stop_btn.click(
1185
+ fn=handle_global_stop,
1186
+ outputs=[
1187
+ global_status,
1188
+ video_list,
1189
+ training_dataset,
1190
+ status_box,
1191
+ log_box,
1192
+ detect_status,
1193
+ import_status,
1194
+ preview_status
1195
+ ]
1196
+ )
1197
+
1198
+ # Auto-refresh timers
1199
+ app.load(
1200
+ fn=lambda: (
1201
+ self.refresh_dataset()
1202
+ ),
1203
+ outputs=[
1204
+ video_list, training_dataset
1205
+ ]
1206
+ )
1207
+
1208
+ timer = gr.Timer(value=1)
1209
+ timer.tick(
1210
+ fn=lambda: (
1211
+ self.refresh_training_status_and_logs()
1212
+ ),
1213
+ outputs=[
1214
+ status_box,
1215
+ log_box
1216
+ ]
1217
+ )
1218
+
1219
+ timer = gr.Timer(value=5)
1220
+ timer.tick(
1221
+ fn=lambda: (
1222
+ self.refresh_dataset()
1223
+ ),
1224
+ outputs=[
1225
+ video_list, training_dataset
1226
+ ]
1227
+ )
1228
+
1229
+ timer = gr.Timer(value=5)
1230
+ timer.tick(
1231
+ fn=lambda: self.update_titles(),
1232
+ outputs=[
1233
+ split_title, caption_title, train_title
1234
+ ]
1235
+ )
1236
+
1237
+ return app
1238
+
1239
+ def create_app():
1240
+ if ASK_USER_TO_DUPLICATE_SPACE:
1241
+ with gr.Blocks() as app:
1242
+ gr.Markdown("""# Finetrainers UI
1243
+
1244
+ This Hugging Face space needs to be duplicated to your own billing account to work.
1245
+
1246
+ Click the 'Duplicate Space' button at the top of the page to create your own copy.
1247
+
1248
+ It is recommended to use an Nvidia L40S and a persistent storage space.
1249
+ To avoid overpaying for your space, you can configure the auto-sleep settings to fit your personal budget.""")
1250
+ return app
1251
+
1252
+ ui = VideoTrainerUI()
1253
+ return ui.create_ui()
1254
+
1255
+ if __name__ == "__main__":
1256
+ app = create_app()
1257
+
1258
+ allowed_paths = [
1259
+ str(STORAGE_PATH), # Base storage
1260
+ str(VIDEOS_TO_SPLIT_PATH),
1261
+ str(STAGING_PATH),
1262
+ str(TRAINING_PATH),
1263
+ str(TRAINING_VIDEOS_PATH),
1264
+ str(MODEL_PATH),
1265
+ str(OUTPUT_PATH)
1266
+ ]
1267
+ app.queue(default_concurrency_limit=1).launch(
1268
+ server_name="0.0.0.0",
1269
+ allowed_paths=allowed_paths
1270
+ )
captioning_service.py ADDED
@@ -0,0 +1,534 @@
1
+ import logging
2
+ import torch
3
+ import shutil
4
+ import gradio as gr
5
+ from llava.model.builder import load_pretrained_model
6
+ from llava.mm_utils import tokenizer_image_token
7
+ import numpy as np
8
+ from decord import VideoReader, cpu
9
+ from pathlib import Path
10
+ from typing import Any, Tuple, Dict, Optional, AsyncGenerator, List
11
+ import asyncio
12
+ from dataclasses import dataclass
13
+ from datetime import datetime
14
+ import cv2
15
+ from config import TRAINING_VIDEOS_PATH, STAGING_PATH, PRELOAD_CAPTIONING_MODEL, CAPTIONING_MODEL, USE_MOCK_CAPTIONING_MODEL, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX
16
+ from utils import extract_scene_info, is_image_file, is_video_file
17
+ from finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ @dataclass
22
+ class CaptioningProgress:
23
+ video_name: str
24
+ total_frames: int
25
+ processed_frames: int
26
+ status: str
27
+ started_at: datetime
28
+ completed_at: Optional[datetime] = None
29
+ error: Optional[str] = None
30
+
31
+ class CaptioningService:
32
+ _instance = None
33
+ _model = None
34
+ _tokenizer = None
35
+ _image_processor = None
36
+ _model_loading = None
37
+ _loop = None
38
+
39
+ def __new__(cls, model_name=CAPTIONING_MODEL):
40
+ if cls._instance is not None:
41
+ return cls._instance
42
+
43
+ instance = super().__new__(cls)
44
+ if PRELOAD_CAPTIONING_MODEL:
45
+ cls._instance = instance
46
+ try:
47
+ cls._loop = asyncio.get_running_loop()
48
+ except RuntimeError:
49
+ cls._loop = asyncio.new_event_loop()
50
+ asyncio.set_event_loop(cls._loop)
51
+
52
+ if not USE_MOCK_CAPTIONING_MODEL and cls._model_loading is None:
53
+ cls._model_loading = cls._loop.create_task(cls._background_load_model(model_name))
54
+ return instance
55
+
56
+ def __init__(self, model_name=CAPTIONING_MODEL):
57
+ if hasattr(self, 'model_name'): # Already initialized
58
+ return
59
+
60
+ self.model_name = model_name
61
+ self.tokenizer = None
62
+ self.model = None
63
+ self.image_processor = None
64
+ self.active_tasks: Dict[str, CaptioningProgress] = {}
65
+ self._should_stop = False
66
+ self._model_loaded = False
67
+
68
+ @classmethod
69
+ async def _background_load_model(cls, model_name):
70
+ """Background task to load the model"""
71
+ try:
72
+ logger.info("Starting background model loading...")
73
+ if not cls._loop:
74
+ cls._loop = asyncio.get_running_loop()
75
+
76
+ def load_model():
77
+ try:
78
+ tokenizer, model, image_processor, _ = load_pretrained_model(
79
+ model_name, None, "llava_qwen",
80
+ torch_dtype="bfloat16", device_map="auto"
81
+ )
82
+ model.eval()
83
+ return tokenizer, model, image_processor
84
+ except Exception as e:
85
+ logger.error(f"Error in load_model: {str(e)}")
86
+ raise
87
+
88
+ result = await cls._loop.run_in_executor(None, load_model)
89
+
90
+ cls._tokenizer, cls._model, cls._image_processor = result
91
+ logger.info("Background model loading completed successfully!")
92
+
93
+ except Exception as e:
94
+ logger.error(f"Background model loading failed: {str(e)}")
95
+ cls._model_loading = None
96
+ raise
97
+
98
+ async def ensure_model_loaded(self):
99
+ """Ensure model is loaded before processing"""
100
+ if USE_MOCK_CAPTIONING_MODEL:
101
+ logger.info("Using mock model, skipping model loading")
102
+ self.__class__._model_loading = None
103
+ self._model_loaded = True
104
+ return
105
+
106
+ if not self._model_loaded:
107
+ try:
108
+ if PRELOAD_CAPTIONING_MODEL and self.__class__._model_loading:
109
+ logger.info("Waiting for background model loading to complete...")
110
+ if self.__class__._loop and self.__class__._loop != asyncio.get_running_loop():
111
+ logger.warning("Different event loop detected, creating new loading task")
112
+ self.__class__._model_loading = None
113
+ await self._load_model_sync()
114
+ else:
115
+ await self.__class__._model_loading
116
+ self.model = self.__class__._model
117
+ self.tokenizer = self.__class__._tokenizer
118
+ self.image_processor = self.__class__._image_processor
119
+ else:
120
+ await self._load_model_sync()
121
+
122
+ self._model_loaded = True
123
+ logger.info("Model loading completed!")
124
+ except Exception as e:
125
+ logger.error(f"Error loading model: {str(e)}")
126
+ raise
127
+
128
+ async def _load_model_sync(self):
129
+ """Synchronously load the model"""
130
+ logger.info("Loading model synchronously...")
131
+ current_loop = asyncio.get_running_loop()
132
+
133
+ def load_model():
134
+ return load_pretrained_model(
135
+ self.model_name, None, "llava_qwen",
136
+ torch_dtype="bfloat16", device_map="auto"
137
+ )
138
+
139
+ self.tokenizer, self.model, self.image_processor, _ = await current_loop.run_in_executor(
140
+ None, load_model
141
+ )
142
+ self.model.eval()
143
+
144
+ def _load_video(self, video_path: Path, max_frames_num: int = 64, fps: int = 1, force_sample: bool = True) -> tuple[np.ndarray, str, float]:
145
+ """Load and preprocess video frames"""
146
+
147
+ video_path_str = str(video_path) if hasattr(video_path, '__fspath__') else video_path
148
+
149
+ logger.debug(f"Loading video: {video_path_str}")
150
+
151
+ if max_frames_num == 0:
152
+ return np.zeros((1, 336, 336, 3)), "", 0
153
+
154
+ vr = VideoReader(video_path_str, ctx=cpu(0), num_threads=1)
155
+ total_frame_num = len(vr)
156
+ video_time = total_frame_num / vr.get_avg_fps()
157
+
158
+ # Calculate frame indices
159
+ fps = round(vr.get_avg_fps()/fps)
160
+ frame_idx = [i for i in range(0, len(vr), fps)]
161
+ frame_time = [i/fps for i in frame_idx]
162
+
163
+ if len(frame_idx) > max_frames_num or force_sample:
164
+ sample_fps = max_frames_num
165
+ uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
166
+ frame_idx = uniform_sampled_frames.tolist()
167
+ frame_time = [i/vr.get_avg_fps() for i in frame_idx]
168
+
169
+ frame_time_str = ",".join([f"{i:.2f}s" for i in frame_time])
170
+
171
+ try:
172
+ frames = vr.get_batch(frame_idx).asnumpy()
173
+ logger.debug(f"Loaded {len(frames)} frames with shape {frames.shape}")
174
+ return frames, frame_time_str, video_time
175
+ except Exception as e:
176
+ logger.error(f"Error loading video frames: {str(e)}")
177
+ raise
178
+
179
+ async def process_video(self, video_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]:
180
+ try:
181
+ video_name = video_path.name
182
+ logger.info(f"Starting processing of video: {video_name}")
183
+
184
+ # Load video metadata
185
+ logger.debug(f"Loading video metadata for {video_name}")
186
+ loop = asyncio.get_event_loop()
187
+ vr = await loop.run_in_executor(None, lambda: VideoReader(str(video_path), ctx=cpu(0)))
188
+ total_frames = len(vr)
189
+
190
+ progress = CaptioningProgress(
191
+ video_name=video_name,
192
+ total_frames=total_frames,
193
+ processed_frames=0,
194
+ status="initializing",
195
+ started_at=datetime.now()
196
+ )
197
+ self.active_tasks[video_name] = progress
198
+ yield progress, None
199
+
200
+ # Get parent caption if this is a clip
201
+ parent_caption = ""
202
+ if "___" in video_path.stem:
203
+ parent_name, _ = extract_scene_info(video_path.stem)
204
+ #print(f"parent_name is {parent_name}")
205
+ parent_txt_path = VIDEOS_TO_SPLIT_PATH / f"{parent_name}.txt"
206
+ if parent_txt_path.exists():
207
+ logger.debug(f"Found parent caption file: {parent_txt_path}")
208
+ parent_caption = parent_txt_path.read_text().strip()
209
+
210
+ # Ensure model is loaded before processing
211
+ await self.ensure_model_loaded()
212
+
213
+ if USE_MOCK_CAPTIONING_MODEL:
214
+
215
+ # Even in mock mode, we'll generate a caption that shows we processed parent info
216
+ clip_caption = f"This is a test caption for {video_name}"
217
+
218
+ # Combine clip caption with parent caption
219
+ if parent_caption and not clip_caption.endswith(parent_caption):
220
+ #print(f"we have parent_caption, so we define the full_caption as {clip_caption}\n{parent_caption}")
221
+
222
+ full_caption = f"{clip_caption}\n{parent_caption}"
223
+ else:
224
+ #print(f"we don't have a parent_caption, so we define the full_caption as {clip_caption}")
225
+
226
+ full_caption = clip_caption
227
+
228
+ if prompt_prefix and not full_caption.startswith(prompt_prefix):
229
+ full_caption = f"{prompt_prefix}{full_caption}"
230
+
231
+ # Write the caption file
232
+ txt_path = video_path.with_suffix('.txt')
233
+ txt_path.write_text(full_caption)
234
+
235
+ logger.debug(f"Mock mode: Saved caption to {txt_path}")
236
+
237
+ progress.status = "completed"
238
+ progress.processed_frames = total_frames
239
+ progress.completed_at = datetime.now()
240
+ yield progress, full_caption
241
+
242
+ else:
243
+ # Process frames in batches
244
+ max_frames_num = 64
245
+ frames, frame_times_str, video_time = await loop.run_in_executor(
246
+ None,
247
+ lambda: self._load_video(video_path, max_frames_num)
248
+ )
249
+
250
+ # Process all frames at once using the image processor
251
+ processed_frames = await loop.run_in_executor(
252
+ None,
253
+ lambda: self.image_processor.preprocess(
254
+ frames,
255
+ return_tensors="pt"
256
+ )["pixel_values"]
257
+ )
258
+
259
+ # Update progress
260
+ progress.processed_frames = len(frames)
261
+ progress.status = "generating caption"
262
+ yield progress, None
263
+
264
+ # Move processed frames to GPU
265
+ video_tensor = processed_frames.to('cuda').bfloat16()
266
+
267
+ time_instruction = (f"The video lasts for {video_time:.2f} seconds, and {len(frames)} "
268
+ f"frames are uniformly sampled from it. These frames are located at {frame_times_str}.")
269
+ full_prompt = f"<image>{time_instruction}\n{prompt}"
270
+
271
+ input_ids = await loop.run_in_executor(
272
+ None,
273
+ lambda: tokenizer_image_token(full_prompt, self.tokenizer, return_tensors="pt").unsqueeze(0).to('cuda')
274
+ )
275
+
276
+ # Generate caption
277
+ with torch.no_grad():
278
+ output = await loop.run_in_executor(
279
+ None,
280
+ lambda: self.model.generate(
281
+ input_ids,
282
+ images=[video_tensor],
283
+ modalities=["video"],
284
+ do_sample=False,
285
+ temperature=0,
286
+ max_new_tokens=4096,
287
+ )
288
+ )
289
+
290
+ clip_caption = await loop.run_in_executor(
291
+ None,
292
+ lambda: self.tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
293
+ )
294
+
295
+ # Combine clip caption with parent caption
296
+ if parent_caption:
297
+ logger.debug(f"Parent caption found; combining it with the clip caption: {clip_caption}\n{parent_caption}")
298
+
299
+ full_caption = f"{clip_caption}\n{parent_caption}"
300
+ else:
301
+ logger.debug(f"No parent caption found; using the clip caption as-is: {clip_caption}")
302
+
303
+ full_caption = clip_caption
304
+
305
+ if prompt_prefix:
306
+ full_caption = f"{prompt_prefix}{full_caption}"
307
+
308
+
309
+ # Write the caption file
310
+ txt_path = video_path.with_suffix('.txt')
311
+ txt_path.write_text(full_caption)
312
+
313
+ progress.status = "completed"
314
+ progress.completed_at = datetime.now()
315
+ gr.Info(f"Successfully generated caption for {video_name}")
316
+ yield progress, full_caption
317
+
318
+ except Exception as e:
319
+ progress.status = "error"
320
+ progress.error = str(e)
321
+ progress.completed_at = datetime.now()
322
+ yield progress, None
323
+ raise gr.Error(f"Error processing video: {str(e)}")
324
+
325
+ async def process_image(self, image_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]:
326
+ """Process a single image for captioning"""
327
+ try:
328
+ image_name = image_path.name
329
+ logger.info(f"Starting processing of image: {image_name}")
330
+
331
+ progress = CaptioningProgress(
332
+ video_name=image_name, # Reusing video_name field for images
333
+ total_frames=1,
334
+ processed_frames=0,
335
+ status="initializing",
336
+ started_at=datetime.now()
337
+ )
338
+ self.active_tasks[image_name] = progress
339
+ yield progress, None
340
+
341
+ # Ensure model is loaded
342
+ await self.ensure_model_loaded()
343
+
344
+ if USE_MOCK_CAPTIONING_MODEL:
345
+ progress.status = "completed"
346
+ progress.processed_frames = 1
347
+ progress.completed_at = datetime.now()
348
+ logger.debug("Mock mode: yielding placeholder image caption")
349
+ yield progress, "This is a test image caption"
350
+ return
351
+
352
+ # Read and process image
353
+ loop = asyncio.get_event_loop()
354
+ image = await loop.run_in_executor(
355
+ None,
356
+ lambda: cv2.imread(str(image_path))
357
+ )
358
+ if image is None:
359
+ raise ValueError(f"Could not read image: {str(image_path)}")
360
+
361
+ # Convert BGR to RGB
362
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
363
+
364
+ # Process image
365
+ processed_image = await loop.run_in_executor(
366
+ None,
367
+ lambda: self.image_processor.preprocess(
368
+ image,
369
+ return_tensors="pt"
370
+ )["pixel_values"]
371
+ )
372
+
373
+ progress.processed_frames = 1
374
+ progress.status = "generating caption"
375
+ yield progress, None
376
+
377
+ # Move to GPU and generate caption
378
+ image_tensor = processed_image.to('cuda').bfloat16()
379
+ full_prompt = f"<image>{prompt}"
380
+
381
+ input_ids = await loop.run_in_executor(
382
+ None,
383
+ lambda: tokenizer_image_token(full_prompt, self.tokenizer, return_tensors="pt").unsqueeze(0).to('cuda')
384
+ )
385
+
386
+ with torch.no_grad():
387
+ output = await loop.run_in_executor(
388
+ None,
389
+ lambda: self.model.generate(
390
+ input_ids,
391
+ images=[image_tensor],
392
+ modalities=["image"],
393
+ do_sample=False,
394
+ temperature=0,
395
+ max_new_tokens=4096,
396
+ )
397
+ )
398
+
399
+ caption = await loop.run_in_executor(
400
+ None,
401
+ lambda: self.tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
402
+ )
403
+
404
+ progress.status = "completed"
405
+ progress.completed_at = datetime.now()
406
+ gr.Info(f"Successfully generated caption for {image_name}")
407
+ yield progress, caption
408
+
409
+ except Exception as e:
410
+ progress.status = "error"
411
+ progress.error = str(e)
412
+ progress.completed_at = datetime.now()
413
+ yield progress, None
414
+ raise gr.Error(f"Error processing image: {str(e)}")
415
+
416
+
417
+ async def start_caption_generation(self, custom_prompt: str, prompt_prefix: str) -> AsyncGenerator[List[List[str]], None]:
418
+ """Iterates over clips to auto-generate captions asynchronously."""
419
+ try:
420
+ logger.info("Starting auto-caption generation")
421
+
422
+ # Use provided prompt or default
423
+ default_prompt = DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
424
+ prompt = custom_prompt.strip() or default_prompt
425
+ logger.debug(f"Using prompt: {prompt}")
426
+
427
+ # Find files needing captions
428
+ video_files = list(STAGING_PATH.glob("*.mp4"))
429
+ image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)]
430
+ all_files = video_files + image_files
431
+
432
+ # Filter for files missing captions or with empty caption files
433
+ files_to_process = []
434
+ for file_path in all_files:
435
+ caption_path = file_path.with_suffix('.txt')
436
+ needs_caption = (
437
+ not caption_path.exists() or
438
+ caption_path.stat().st_size == 0 or
439
+ caption_path.read_text().strip() == ""
440
+ )
441
+ if needs_caption:
442
+ files_to_process.append(file_path)
443
+
444
+ logger.info(f"Found {len(files_to_process)} files needing captions")
445
+
446
+ if not files_to_process:
447
+ logger.info("No files need captioning")
448
+ yield []
449
+ return
450
+
451
+ self._should_stop = False
452
+ self.active_tasks.clear()
453
+ status_update: Dict[str, Dict[str, Any]] = {}
454
+
455
+ for file_path in files_to_process:
456
+ if self._should_stop:
457
+ break
458
+
459
+ try:
460
+ logger.debug(f"Processing file: {file_path}")
461
+ # Choose appropriate processing method based on file type
462
+ if is_video_file(file_path):
463
+ process_gen = self.process_video(file_path, prompt, prompt_prefix)
464
+ else:
465
+ process_gen = self.process_image(file_path, prompt, prompt_prefix)
466
+ logger.debug(f"Created processing generator: {process_gen}")
467
+ async for progress, caption in process_gen:
468
+ logger.debug(f"Generator yielded caption: {caption}")
469
+ if caption and prompt_prefix and not caption.startswith(prompt_prefix):
470
+ caption = f"{prompt_prefix}{caption}"
471
+
472
+ # Save caption
473
+ if caption:
474
+ txt_path = file_path.with_suffix('.txt')
475
+ txt_path.write_text(caption)
476
+
477
+ logger.debug(f"Progress update: {progress.status}")
478
+
479
+ # Store progress info
480
+ status_update[file_path.name] = {
481
+ "status": progress.status,
482
+ "frames": progress.processed_frames,
483
+ "total": progress.total_frames,
+ "error": getattr(progress, "error", None)
484
+ }
485
+
486
+ # Convert to list format for Gradio DataFrame
487
+ rows = []
488
+ for file_name, info in status_update.items():
489
+ status = info["status"]
490
+ if status == "processing":
491
+ percent = (info["frames"] / info["total"]) * 100
492
+ status = f"Analyzing... {percent:.1f}% ({info['frames']}/{info['total']} frames)"
493
+ elif status == "generating caption":
494
+ status = "Generating caption..."
495
+ elif status == "error":
496
+ status = f"Error: {info.get('error') or 'unknown error'}"
497
+ elif status == "completed":
498
+ status = "Completed"
499
+
500
+ rows.append([file_name, status])
501
+
502
+ yield rows
503
+ await asyncio.sleep(0.1)
504
+
505
+
506
+ except Exception as e:
507
+ logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
508
+ rows = [[str(file_path.name), f"Error: {str(e)}"]]
509
+ yield rows
510
+ continue
511
+
512
+ logger.info("Auto-caption generation completed, copying assets to the training dir...")
513
+
514
+ copy_files_to_training_dir(prompt_prefix)
515
+ except Exception as e:
516
+ logger.error(f"Error in start_caption_generation: {str(e)}")
517
+ yield [[str(e), "error"]]
518
+ raise
519
+
520
+ def stop_captioning(self):
521
+ """Stop all ongoing captioning tasks"""
522
+ logger.info("Stopping all captioning tasks")
523
+ self._should_stop = True
524
+
525
+ def close(self):
526
+ """Clean up resources"""
527
+ logger.info("Cleaning up captioning service resources")
528
+ if hasattr(self, 'model'):
529
+ del self.model
530
+ if hasattr(self, 'tokenizer'):
531
+ del self.tokenizer
532
+ if hasattr(self, 'image_processor'):
533
+ del self.image_processor
534
+ torch.cuda.empty_cache()
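For reference, a minimal, self-contained sketch of the uniform frame-sampling math used by `_load_video` above. The clip length, fps, and frame cap are illustrative assumptions rather than values read from a real file; it only exercises the `np.linspace` index selection and the timestamp formatting.

import numpy as np

total_frame_num = 300    # hypothetical clip: 300 frames
avg_fps = 30.0           # hypothetical decoder-reported average fps
max_frames_num = 64      # same cap that process_video passes to _load_video

video_time = total_frame_num / avg_fps  # 10.0 seconds
# Uniformly sample max_frames_num indices across the whole clip
frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
frame_time = [i / avg_fps for i in frame_idx]              # timestamps in seconds
frame_time_str = ",".join(f"{t:.2f}s" for t in frame_time)

print(len(frame_idx), frame_idx[:3], frame_idx[-1])  # 64 [0, 4, 9] 299
print(frame_time_str[:17])                           # 0.00s,0.13s,0.30s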
config.py ADDED
@@ -0,0 +1,303 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Dict, Any, Optional, List, Tuple
4
+ from pathlib import Path
5
+ from utils import parse_bool_env
6
+
7
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
8
+ ASK_USER_TO_DUPLICATE_SPACE = parse_bool_env(os.getenv("ASK_USER_TO_DUPLICATE_SPACE"))
9
+
10
+ # Base storage path
11
+ STORAGE_PATH = Path(os.environ.get('STORAGE_PATH', '.data'))
12
+
13
+ # Subdirectories for different data types
14
+ VIDEOS_TO_SPLIT_PATH = STORAGE_PATH / "videos_to_split" # Raw uploaded/downloaded files
15
+ STAGING_PATH = STORAGE_PATH / "staging" # This is where files that are captioned or need captioning are waiting
16
+ TRAINING_PATH = STORAGE_PATH / "training" # Folder containing the final training dataset
17
+ TRAINING_VIDEOS_PATH = TRAINING_PATH / "videos" # Captioned clips ready for training
18
+ MODEL_PATH = STORAGE_PATH / "model" # Model checkpoints and files
19
+ OUTPUT_PATH = STORAGE_PATH / "output" # Training outputs and logs
20
+
21
+ # On the production server we can afford to preload the big model
22
+ PRELOAD_CAPTIONING_MODEL = parse_bool_env(os.environ.get('PRELOAD_CAPTIONING_MODEL'))
23
+
24
+ CAPTIONING_MODEL = "lmms-lab/LLaVA-Video-7B-Qwen2"
25
+
26
+ DEFAULT_PROMPT_PREFIX = "In the style of TOK, "
27
+
28
+ # This is only used to debug things locally
29
+ USE_MOCK_CAPTIONING_MODEL = parse_bool_env(os.environ.get('USE_MOCK_CAPTIONING_MODEL'))
30
+
31
+ DEFAULT_CAPTIONING_BOT_INSTRUCTIONS = "Please write a full description of the following video: camera (close-up shot, medium-shot..), genre (music video, horror movie scene, video game footage, go pro footage, japanese anime, noir film, science-fiction, action movie, documentary..), characters (physical appearance, look, skin, facial features, haircut, clothing), scene (action, positions, movements), location (indoor, outdoor, place, building, country..), time and lighting (natural, golden hour, night time, LED lights, kelvin temperature etc), weather and climate (dusty, rainy, fog, haze, snowing..), era/settings"
32
+
33
+ # Create directories
34
+ STORAGE_PATH.mkdir(parents=True, exist_ok=True)
35
+ VIDEOS_TO_SPLIT_PATH.mkdir(parents=True, exist_ok=True)
36
+ STAGING_PATH.mkdir(parents=True, exist_ok=True)
37
+ TRAINING_PATH.mkdir(parents=True, exist_ok=True)
38
+ TRAINING_VIDEOS_PATH.mkdir(parents=True, exist_ok=True)
39
+ MODEL_PATH.mkdir(parents=True, exist_ok=True)
40
+ OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
41
+
42
+ # Image normalization settings
43
+ NORMALIZE_IMAGES_TO = os.environ.get('NORMALIZE_IMAGES_TO', 'png').lower()
44
+ if NORMALIZE_IMAGES_TO not in ['png', 'jpg']:
45
+ raise ValueError("NORMALIZE_IMAGES_TO must be either 'png' or 'jpg'")
46
+ JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))
47
+
48
+ MODEL_TYPES = {
49
+ "HunyuanVideo (LoRA)": "hunyuan_video",
50
+ "LTX-Video (LoRA)": "ltx_video"
51
+ }
52
+
53
+
54
+ # it is best to use resolutions that are multiples of 32
55
+ # The resolution should be divisible by 32
56
+ # so we cannot use 1080, 540 etc as they are not divisible by 32
57
+ TRAINING_WIDTH = 768 # 32 * 24
58
+ TRAINING_HEIGHT = 512 # 32 * 16
59
+
60
+ # 1920 = 32 * 60 (divided by 2: 960 = 32 * 30)
62
+ # 1088 = 32 * 34 (divided by 2: 544 = 17 * 32)
63
+ # 1024 = 32 * 32 (divided by 2: 512 = 16 * 32)
64
+ # it is important that the resolution buckets properly cover the training dataset,
65
+ # or else we must exclude from the dataset any videos that fall outside this range
66
+ # right now, finetrainers will crash if that happens, so the workaround is to define more buckets here
67
+
68
+ TRAINING_BUCKETS = [
69
+ (8 * 2 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 16 + 1
70
+ (8 * 4 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 32 + 1
71
+ (8 * 6 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 48 + 1
72
+ (8 * 8 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 64 + 1
73
+ (8 * 10 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 80 + 1
74
+ (8 * 12 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 96 + 1
75
+ (8 * 14 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 112 + 1
76
+ (8 * 16 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 128 + 1
77
+ (8 * 18 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 144 + 1
78
+ (8 * 20 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 160 + 1
79
+ (8 * 22 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 176 + 1
80
+ (8 * 24 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 192 + 1
81
+ (8 * 28 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 224 + 1
82
+ (8 * 32 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 256 + 1
83
+ ]
84
+
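Every bucket above holds 8*k + 1 frames at the fixed 768x512 training resolution. As a small illustration, the hypothetical helper below picks the largest bucket whose frame count does not exceed a clip's length; it is only a sketch of the idea (it assumes the TRAINING_BUCKETS defined above), and finetrainers' own bucket assignment may differ.

def largest_fitting_bucket(num_frames: int, buckets=TRAINING_BUCKETS):
    # Keep only buckets the clip can fill, then take the longest one
    candidates = [b for b in buckets if b[0] <= num_frames]
    return max(candidates, key=lambda b: b[0]) if candidates else None

print(largest_fitting_bucket(50))   # (49, 512, 768): a 50-frame clip fills the 8*6+1 bucket
print(largest_fitting_bucket(200))  # (193, 512, 768)
print(largest_fitting_bucket(10))   # None: shorter than the smallest bucket (17 frames)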
85
+ @dataclass
86
+ class TrainingConfig:
87
+ """Configuration class for finetrainers training"""
88
+
89
+ # Required arguments must come first
90
+ model_name: str
91
+ pretrained_model_name_or_path: str
92
+ data_root: str
93
+ output_dir: str
94
+
95
+ # Optional arguments follow
96
+ revision: Optional[str] = None
97
+ variant: Optional[str] = None
98
+ cache_dir: Optional[str] = None
99
+
100
+ # Dataset arguments
101
+
102
+ # note: video_column and caption_column serve a dual purpose,
103
+ # when using the CSV mode they have to be CSV column names,
104
+ # otherwise they have to be filenames (relative to the data_root dir path)
105
+ video_column: str = "videos.txt"
106
+ caption_column: str = "prompts.txt"
107
+
108
+ id_token: Optional[str] = None
109
+ video_resolution_buckets: List[Tuple[int, int, int]] = field(default_factory=lambda: TRAINING_BUCKETS)
110
+ video_reshape_mode: str = "center"
111
+ caption_dropout_p: float = 0.05
112
+ caption_dropout_technique: str = "empty"
113
+ precompute_conditions: bool = False
114
+
115
+ # Diffusion arguments
116
+ flow_resolution_shifting: bool = False
117
+ flow_weighting_scheme: str = "none"
118
+ flow_logit_mean: float = 0.0
119
+ flow_logit_std: float = 1.0
120
+ flow_mode_scale: float = 1.29
121
+
122
+ # Training arguments
123
+ training_type: str = "lora"
124
+ seed: int = 42
125
+ mixed_precision: str = "bf16"
126
+ batch_size: int = 1
127
+ train_epochs: int = 70
128
+ lora_rank: int = 128
129
+ lora_alpha: int = 128
130
+ target_modules: List[str] = field(default_factory=lambda: ["to_q", "to_k", "to_v", "to_out.0"])
131
+ gradient_accumulation_steps: int = 1
132
+ gradient_checkpointing: bool = True
133
+ checkpointing_steps: int = 500
134
+ checkpointing_limit: Optional[int] = 2
135
+ resume_from_checkpoint: Optional[str] = None
136
+ enable_slicing: bool = True
137
+ enable_tiling: bool = True
138
+
139
+ # Optimizer arguments
140
+ optimizer: str = "adamw"
141
+ lr: float = 3e-5
142
+ scale_lr: bool = False
143
+ lr_scheduler: str = "constant_with_warmup"
144
+ lr_warmup_steps: int = 100
145
+ lr_num_cycles: int = 1
146
+ lr_power: float = 1.0
147
+ beta1: float = 0.9
148
+ beta2: float = 0.95
149
+ weight_decay: float = 1e-4
150
+ epsilon: float = 1e-8
151
+ max_grad_norm: float = 1.0
152
+
153
+ # Miscellaneous arguments
154
+ tracker_name: str = "finetrainers"
155
+ report_to: str = "wandb"
156
+ nccl_timeout: int = 1800
157
+
158
+ @classmethod
159
+ def hunyuan_video_lora(cls, data_path: str, output_path: str) -> 'TrainingConfig':
160
+ """Configuration for HunyuanVideo LoRA training"""
161
+ return cls(
162
+ model_name="hunyuan_video",
163
+ pretrained_model_name_or_path="hunyuanvideo-community/HunyuanVideo",
164
+ data_root=data_path,
165
+ output_dir=output_path,
166
+ batch_size=1,
167
+ train_epochs=70,
168
+ lr=2e-5,
169
+ gradient_checkpointing=True,
170
+ id_token="afkx",
171
+ gradient_accumulation_steps=1,
172
+ lora_rank=128,
173
+ lora_alpha=128,
174
+ video_resolution_buckets=TRAINING_BUCKETS,
175
+ caption_dropout_p=0.05,
176
+ flow_weighting_scheme="none" # Hunyuan specific
177
+ )
178
+
179
+ @classmethod
180
+ def ltx_video_lora(cls, data_path: str, output_path: str) -> 'TrainingConfig':
181
+ """Configuration for LTX-Video LoRA training"""
182
+ return cls(
183
+ model_name="ltx_video",
184
+ pretrained_model_name_or_path="Lightricks/LTX-Video",
185
+ data_root=data_path,
186
+ output_dir=output_path,
187
+ batch_size=1,
188
+ train_epochs=70,
189
+ lr=3e-5,
190
+ gradient_checkpointing=True,
191
+ id_token="BW_STYLE",
192
+ gradient_accumulation_steps=4,
193
+ lora_rank=128,
194
+ lora_alpha=128,
195
+ video_resolution_buckets=TRAINING_BUCKETS,
196
+ caption_dropout_p=0.05,
197
+ flow_weighting_scheme="logit_normal" # LTX specific
198
+ )
199
+
200
+ def to_args_list(self) -> List[str]:
201
+ """Convert config to command line arguments list"""
202
+ args = []
203
+
204
+ # Model arguments
205
+
206
+ # Add model_name (required argument)
207
+ args.extend(["--model_name", self.model_name])
208
+
209
+ args.extend(["--pretrained_model_name_or_path", self.pretrained_model_name_or_path])
210
+ if self.revision:
211
+ args.extend(["--revision", self.revision])
212
+ if self.variant:
213
+ args.extend(["--variant", self.variant])
214
+ if self.cache_dir:
215
+ args.extend(["--cache_dir", self.cache_dir])
216
+
217
+ # Dataset arguments
218
+ args.extend(["--data_root", self.data_root])
219
+ args.extend(["--video_column", self.video_column])
220
+ args.extend(["--caption_column", self.caption_column])
221
+ if self.id_token:
222
+ args.extend(["--id_token", self.id_token])
223
+
224
+ # Add video resolution buckets
225
+ if self.video_resolution_buckets:
226
+ bucket_strs = [f"{f}x{h}x{w}" for f, h, w in self.video_resolution_buckets]
227
+ args.extend(["--video_resolution_buckets"] + bucket_strs)
228
+
229
+ if self.video_reshape_mode:
230
+ args.extend(["--video_reshape_mode", self.video_reshape_mode])
231
+
232
+ args.extend(["--caption_dropout_p", str(self.caption_dropout_p)])
233
+ args.extend(["--caption_dropout_technique", self.caption_dropout_technique])
234
+ if self.precompute_conditions:
235
+ args.append("--precompute_conditions")
236
+
237
+ # Diffusion arguments
238
+ if self.flow_resolution_shifting:
239
+ args.append("--flow_resolution_shifting")
240
+ args.extend(["--flow_weighting_scheme", self.flow_weighting_scheme])
241
+ args.extend(["--flow_logit_mean", str(self.flow_logit_mean)])
242
+ args.extend(["--flow_logit_std", str(self.flow_logit_std)])
243
+ args.extend(["--flow_mode_scale", str(self.flow_mode_scale)])
244
+
245
+ # Training arguments
246
+ args.extend(["--training_type", self.training_type])
247
+ args.extend(["--seed", str(self.seed)])
248
+
249
+ # we don't use this, because mixed precision is handled by accelerate launch, not by the training script itself.
250
+ #args.extend(["--mixed_precision", self.mixed_precision])
251
+
252
+ args.extend(["--batch_size", str(self.batch_size)])
253
+ args.extend(["--train_epochs", str(self.train_epochs)])
254
+ args.extend(["--rank", str(self.lora_rank)])
255
+ args.extend(["--lora_alpha", str(self.lora_alpha)])
256
+ args.extend(["--target_modules"] + self.target_modules)
257
+ args.extend(["--gradient_accumulation_steps", str(self.gradient_accumulation_steps)])
258
+ if self.gradient_checkpointing:
259
+ args.append("--gradient_checkpointing")
260
+ args.extend(["--checkpointing_steps", str(self.checkpointing_steps)])
261
+ if self.checkpointing_limit:
262
+ args.extend(["--checkpointing_limit", str(self.checkpointing_limit)])
263
+ if self.resume_from_checkpoint:
264
+ args.extend(["--resume_from_checkpoint", self.resume_from_checkpoint])
265
+ if self.enable_slicing:
266
+ args.append("--enable_slicing")
267
+ if self.enable_tiling:
268
+ args.append("--enable_tiling")
269
+
270
+ # Optimizer arguments
271
+ args.extend(["--optimizer", self.optimizer])
272
+ args.extend(["--lr", str(self.lr)])
273
+ if self.scale_lr:
274
+ args.append("--scale_lr")
275
+ args.extend(["--lr_scheduler", self.lr_scheduler])
276
+ args.extend(["--lr_warmup_steps", str(self.lr_warmup_steps)])
277
+ args.extend(["--lr_num_cycles", str(self.lr_num_cycles)])
278
+ args.extend(["--lr_power", str(self.lr_power)])
279
+ args.extend(["--beta1", str(self.beta1)])
280
+ args.extend(["--beta2", str(self.beta2)])
281
+ args.extend(["--weight_decay", str(self.weight_decay)])
282
+ args.extend(["--epsilon", str(self.epsilon)])
283
+ args.extend(["--max_grad_norm", str(self.max_grad_norm)])
284
+
285
+ # Miscellaneous arguments
286
+ args.extend(["--tracker_name", self.tracker_name])
287
+ args.extend(["--output_dir", self.output_dir])
288
+ args.extend(["--report_to", self.report_to])
289
+ args.extend(["--nccl_timeout", str(self.nccl_timeout)])
290
+
291
+ # normally this is disabled by default, but there was a bug in finetrainers
292
+ # so I had to fix it in trainer.py to make sure we check for push_to_hub
293
+ #args.append("--push_to_hub")
294
+ #args.extend(["--hub_token", str(False)])
295
+ #args.extend(["--hub_model_id", str(False)])
296
+
297
+ # If you are using LLM-captioned videos, it is common to see many unwanted starting phrases like
298
+ # "In this video, ...", "This video features ...", etc.
299
+ # To remove a simple subset of these phrases, you can specify
300
+ # --remove_common_llm_caption_prefixes when starting training.
301
+ args.append("--remove_common_llm_caption_prefixes")
302
+
303
+ return args
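A short usage sketch (the paths and the import form are assumptions, not part of this module): build one of the preset configurations and turn it into the CLI argument list handed to the finetrainers training script. The data_root is expected to follow the directory layout implied by the defaults above, i.e. a prompts.txt and a videos.txt with line-separated entries.

from config import TrainingConfig  # assuming this file is importable as `config`

cfg = TrainingConfig.ltx_video_lora(
    data_path=".data/training",   # should contain prompts.txt and videos.txt
    output_path=".data/output",
)
args = cfg.to_args_list()
print(args[:6])
# ['--model_name', 'ltx_video', '--pretrained_model_name_or_path',
#  'Lightricks/LTX-Video', '--data_root', '.data/training']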
finetrainers/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .args import Args, parse_arguments
2
+ from .trainer import Trainer
finetrainers/args.py ADDED
@@ -0,0 +1,1191 @@
1
+ import argparse
2
+ import sys
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+
7
+ from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS
8
+ from .models import SUPPORTED_MODEL_CONFIGS
9
+
10
+
11
+ class Args:
12
+ r"""
13
+ The arguments for the finetrainers training script.
14
+
15
+ For helpful information about arguments, run `python train.py --help`.
16
+
17
+ TODO(aryan): add `python train.py --recommend_configs --model_name <model_name>` to recommend
18
+ good training configs for a model after extensive testing.
19
+ TODO(aryan): add `python train.py --memory_requirements --model_name <model_name>` to show
20
+ memory requirements per model, per training type with sensible training settings.
21
+
22
+ MODEL ARGUMENTS
23
+ ---------------
24
+ model_name (`str`):
25
+ Name of model to train. To get a list of models, run `python train.py --list_models`.
26
+ pretrained_model_name_or_path (`str`):
27
+ Path to pretrained model or model identifier from https://huggingface.co/models. The model should be
28
+ loadable based on specified `model_name`.
29
+ revision (`str`, defaults to `None`):
30
+ If provided, the model will be loaded from a specific branch of the model repository.
31
+ variant (`str`, defaults to `None`):
32
+ Variant of model weights to use. Some models provide weight variants, such as `fp16`, to reduce disk
33
+ storage requirements.
34
+ cache_dir (`str`, defaults to `None`):
35
+ The directory where the downloaded models and datasets will be stored, or loaded from.
36
+ text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
37
+ Data type for the text encoder when generating text embeddings.
38
+ text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
39
+ Data type for the text encoder 2 when generating text embeddings.
40
+ text_encoder_3_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
41
+ Data type for the text encoder 3 when generating text embeddings.
42
+ transformer_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
43
+ Data type for the transformer model.
44
+ vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
45
+ Data type for the VAE model.
46
+ layerwise_upcasting_modules (`List[str]`, defaults to `[]`):
47
+ Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer'].
48
+ layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`):
49
+ Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2'].
50
+ layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`):
51
+ Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when cast to fp8 precision
52
+ naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers
53
+ by default, and recommend adding more layers to the default list based on the model architecture.
54
+
55
+ DATASET ARGUMENTS
56
+ -----------------
57
+ data_root (`str`):
58
+ A folder containing the training data.
59
+ dataset_file (`str`, defaults to `None`):
60
+ Path to a CSV/JSON/JSONL file containing metadata for training. This should be provided if you're not using
61
+ a directory dataset format containing a simple `prompts.txt` and `videos.txt`/`images.txt` for example.
62
+ video_column (`str`):
63
+ The column of the dataset containing videos. Or, the name of the file in `data_root` folder containing the
64
+ line-separated path to video data.
65
+ caption_column (`str`):
66
+ The column of the dataset containing the instance prompt for each video. Or, the name of the file in
67
+ `data_root` folder containing the line-separated instance prompts.
68
+ id_token (`str`, defaults to `None`):
69
+ Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training.
70
+ image_resolution_buckets (`List[Tuple[int, int]]`, defaults to `None`):
71
+ Resolution buckets for images. This should be a list of integer tuples, where each tuple represents the
72
+ resolution (height, width) of the image. All images will be resized to the nearest bucket resolution.
73
+ video_resolution_buckets (`List[Tuple[int, int, int]]`, defaults to `None`):
74
+ Resolution buckets for videos. This should be a list of integer tuples, where each tuple represents the
75
+ resolution (num_frames, height, width) of the video. All videos will be resized to the nearest bucket
76
+ resolution.
77
+ video_reshape_mode (`str`, defaults to `None`):
78
+ All input videos are reshaped to this mode. Choose between ['center', 'random', 'none'].
79
+ TODO(aryan): We don't support this.
80
+ caption_dropout_p (`float`, defaults to `0.00`):
81
+ Probability of dropout for the caption tokens. This is useful to improve the unconditional generation
82
+ quality of the model.
83
+ caption_dropout_technique (`str`, defaults to `empty`):
84
+ Technique to use for caption dropout. Choose between ['empty', 'zero']. Some models apply caption dropout
85
+ by setting the prompt condition to an empty string, while others zero-out the text embedding tensors.
86
+ precompute_conditions (`bool`, defaults to `False`):
87
+ Whether or not to precompute the conditionings for the model. This is useful for faster training, and
88
+ reduces the memory requirements.
89
+ remove_common_llm_caption_prefixes (`bool`, defaults to `False`):
90
+ Whether or not to remove common LLM caption prefixes. This is useful for improving the quality of the
91
+ generated text.
92
+
93
+ DATALOADER_ARGUMENTS
94
+ --------------------
95
+ See https://pytorch.org/docs/stable/data.html for more information.
96
+
97
+ dataloader_num_workers (`int`, defaults to `0`):
98
+ Number of subprocesses to use for data loading. `0` means that the data will be loaded in a blocking manner
99
+ on the main process.
100
+ pin_memory (`bool`, defaults to `False`):
101
+ Whether or not to use the pinned memory setting in PyTorch dataloader. This is useful for faster data loading.
102
+
103
+ DIFFUSION ARGUMENTS
104
+ -------------------
105
+ flow_resolution_shifting (`bool`, defaults to `False`):
106
+ Resolution-dependent shifting of timestep schedules.
107
+ [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206).
108
+ TODO(aryan): We don't support this yet.
109
+ flow_base_seq_len (`int`, defaults to `256`):
110
+ Base number of tokens for images/video when applying resolution-dependent shifting.
111
+ flow_max_seq_len (`int`, defaults to `4096`):
112
+ Maximum number of tokens for images/video when applying resolution-dependent shifting.
113
+ flow_base_shift (`float`, defaults to `0.5`):
114
+ Base shift for timestep schedules when applying resolution-dependent shifting.
115
+ flow_max_shift (`float`, defaults to `1.15`):
116
+ Maximum shift for timestep schedules when applying resolution-dependent shifting.
117
+ flow_shift (`float`, defaults to `1.0`):
118
+ Instead of training with uniform/logit-normal sigmas, shift them as (shift * sigma) / (1 + (shift - 1) * sigma).
119
+ Setting it higher is helpful when trying to train models for high-resolution generation or to produce better
120
+ samples in lower number of inference steps.
121
+ flow_weighting_scheme (`str`, defaults to `none`):
122
+ We default to the "none" weighting scheme for uniform sampling and uniform loss.
123
+ Choose between ['sigma_sqrt', 'logit_normal', 'mode', 'cosmap', 'none'].
124
+ flow_logit_mean (`float`, defaults to `0.0`):
125
+ Mean to use when using the `'logit_normal'` weighting scheme.
126
+ flow_logit_std (`float`, defaults to `1.0`):
127
+ Standard deviation to use when using the `'logit_normal'` weighting scheme.
128
+ flow_mode_scale (`float`, defaults to `1.29`):
129
+ Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.
130
+
131
+ TRAINING ARGUMENTS
132
+ ------------------
133
+ training_type (`str`, defaults to `None`):
134
+ Type of training to perform. Choose between ['lora'].
135
+ seed (`int`, defaults to `42`):
136
+ A seed for reproducible training.
137
+ batch_size (`int`, defaults to `1`):
138
+ Per-device batch size.
139
+ train_epochs (`int`, defaults to `1`):
140
+ Number of training epochs.
141
+ train_steps (`int`, defaults to `None`):
142
+ Total number of training steps to perform. If provided, overrides `train_epochs`.
143
+ rank (`int`, defaults to `128`):
144
+ The rank for LoRA matrices.
145
+ lora_alpha (`float`, defaults to `64`):
146
+ The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.
147
+ target_modules (`List[str]`, defaults to `["to_k", "to_q", "to_v", "to_out.0"]`):
148
+ The target modules for LoRA. Make sure to modify this based on the model.
149
+ gradient_accumulation_steps (`int`, defaults to `1`):
150
+ Number of gradient steps to accumulate before performing an optimizer step.
151
+ gradient_checkpointing (`bool`, defaults to `False`):
152
+ Whether or not to use gradient/activation checkpointing to save memory at the expense of slower
153
+ backward pass.
154
+ checkpointing_steps (`int`, defaults to `500`):
155
+ Save a checkpoint of the training state every X training steps. These checkpoints can be used both
156
+ as final checkpoints in case they are better than the last checkpoint, and are also suitable for
157
+ resuming training using `resume_from_checkpoint`.
158
+ checkpointing_limit (`int`, defaults to `None`):
159
+ Max number of checkpoints to store.
160
+ resume_from_checkpoint (`str`, defaults to `None`):
161
+ Whether training should be resumed from a previous checkpoint. Use a path saved by `checkpointing_steps`,
162
+ or `"latest"` to automatically select the last available checkpoint.
163
+
164
+ OPTIMIZER ARGUMENTS
165
+ -------------------
166
+ optimizer (`str`, defaults to `adamw`):
167
+ The optimizer type to use. Choose between ['adam', 'adamw'].
168
+ use_8bit_bnb (`bool`, defaults to `False`):
169
+ Whether to use 8bit variant of the `optimizer` using `bitsandbytes`.
170
+ lr (`float`, defaults to `1e-4`):
171
+ Initial learning rate (after the potential warmup period) to use.
172
+ scale_lr (`bool`, defaults to `False`):
173
+ Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.
174
+ lr_scheduler (`str`, defaults to `cosine_with_restarts`):
175
+ The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',
176
+ 'constant', 'constant_with_warmup'].
177
+ lr_warmup_steps (`int`, defaults to `500`):
178
+ Number of steps for the warmup in the lr scheduler.
179
+ lr_num_cycles (`int`, defaults to `1`):
180
+ Number of hard resets of the lr in cosine_with_restarts scheduler.
181
+ lr_power (`float`, defaults to `1.0`):
182
+ Power factor of the polynomial scheduler.
183
+ beta1 (`float`, defaults to `0.9`):
184
+ beta2 (`float`, defaults to `0.95`):
185
+ beta3 (`float`, defaults to `0.999`):
186
+ weight_decay (`float`, defaults to `0.0001`):
187
+ Penalty for large weights in the model.
188
+ epsilon (`float`, defaults to `1e-8`):
189
+ Small value to avoid division by zero in the optimizer.
190
+ max_grad_norm (`float`, defaults to `1.0`):
191
+ Maximum gradient norm to clip the gradients.
192
+
193
+ VALIDATION ARGUMENTS
194
+ --------------------
195
+ validation_prompts (`List[str]`, defaults to `None`):
196
+ List of prompts to use for validation. If not provided, a random prompt will be selected from the training
197
+ dataset.
198
+ validation_images (`List[str]`, defaults to `None`):
199
+ List of image paths to use for validation.
200
+ validation_videos (`List[str]`, defaults to `None`):
201
+ List of video paths to use for validation.
202
+ validation_heights (`List[int]`, defaults to `None`):
203
+ List of heights for the validation videos.
204
+ validation_widths (`List[int]`, defaults to `None`):
205
+ List of widths for the validation videos.
206
+ validation_num_frames (`List[int]`, defaults to `None`):
207
+ List of number of frames for the validation videos.
208
+ num_validation_videos_per_prompt (`int`, defaults to `1`):
209
+ Number of videos to use for validation per prompt.
210
+ validation_every_n_epochs (`int`, defaults to `None`):
211
+ Perform validation every `n` training epochs.
212
+ validation_every_n_steps (`int`, defaults to `None`):
213
+ Perform validation every `n` training steps.
214
+ enable_model_cpu_offload (`bool`, defaults to `False`):
215
+ Whether or not to offload different modeling components to CPU during validation.
216
+ validation_frame_rate (`int`, defaults to `25`):
217
+ Frame rate to use for the validation videos. This defaults to 25, as used in the LTX Video pipeline.
218
+
219
+ MISCELLANEOUS ARGUMENTS
220
+ -----------------------
221
+ tracker_name (`str`, defaults to `finetrainers`):
222
+ Name of the tracker/project to use for logging training metrics.
223
+ push_to_hub (`bool`, defaults to `False`):
224
+ Whether or not to push the model to the Hugging Face Hub.
225
+ hub_token (`str`, defaults to `None`):
226
+ The API token to use for pushing the model to the Hugging Face Hub.
227
+ hub_model_id (`str`, defaults to `None`):
228
+ The model identifier to use for pushing the model to the Hugging Face Hub.
229
+ output_dir (`str`, defaults to `None`):
230
+ The directory where the model checkpoints and logs will be stored.
231
+ logging_dir (`str`, defaults to `logs`):
232
+ The directory where the logs will be stored.
233
+ allow_tf32 (`bool`, defaults to `False`):
234
+ Whether or not to allow the use of TF32 matmul on compatible hardware.
235
+ nccl_timeout (`int`, defaults to `1800`):
236
+ Timeout for the NCCL communication.
237
+ report_to (`str`, defaults to `wandb`):
238
+ The name of the logger to use for logging training metrics. Choose between ['wandb'].
239
+ """
240
+
241
+ # Model arguments
242
+ model_name: str = None
243
+ pretrained_model_name_or_path: str = None
244
+ revision: Optional[str] = None
245
+ variant: Optional[str] = None
246
+ cache_dir: Optional[str] = None
247
+ text_encoder_dtype: torch.dtype = torch.bfloat16
248
+ text_encoder_2_dtype: torch.dtype = torch.bfloat16
249
+ text_encoder_3_dtype: torch.dtype = torch.bfloat16
250
+ transformer_dtype: torch.dtype = torch.bfloat16
251
+ vae_dtype: torch.dtype = torch.bfloat16
252
+ layerwise_upcasting_modules: List[str] = []
253
+ layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn
254
+ layerwise_upcasting_skip_modules_pattern: List[str] = [
255
+ "patch_embed",
256
+ "pos_embed",
257
+ "x_embedder",
258
+ "context_embedder",
259
+ "time_embed",
260
+ "^proj_in$",
261
+ "^proj_out$",
262
+ "norm",
263
+ ]
264
+
265
+ # Dataset arguments
266
+ data_root: str = None
267
+ dataset_file: Optional[str] = None
268
+ video_column: str = None
269
+ caption_column: str = None
270
+ id_token: Optional[str] = None
271
+ image_resolution_buckets: List[Tuple[int, int]] = None
272
+ video_resolution_buckets: List[Tuple[int, int, int]] = None
273
+ video_reshape_mode: Optional[str] = None
274
+ caption_dropout_p: float = 0.00
275
+ caption_dropout_technique: str = "empty"
276
+ precompute_conditions: bool = False
277
+ remove_common_llm_caption_prefixes: bool = False
278
+
279
+ # Dataloader arguments
280
+ dataloader_num_workers: int = 0
281
+ pin_memory: bool = False
282
+
283
+ # Diffusion arguments
284
+ flow_resolution_shifting: bool = False
285
+ flow_base_seq_len: int = 256
286
+ flow_max_seq_len: int = 4096
287
+ flow_base_shift: float = 0.5
288
+ flow_max_shift: float = 1.15
289
+ flow_shift: float = 1.0
290
+ flow_weighting_scheme: str = "none"
291
+ flow_logit_mean: float = 0.0
292
+ flow_logit_std: float = 1.0
293
+ flow_mode_scale: float = 1.29
294
+
295
+ # Training arguments
296
+ training_type: str = None
297
+ seed: int = 42
298
+ batch_size: int = 1
299
+ train_epochs: int = 1
300
+ train_steps: int = None
301
+ rank: int = 128
302
+ lora_alpha: float = 64
303
+ target_modules: List[str] = ["to_k", "to_q", "to_v", "to_out.0"]
304
+ gradient_accumulation_steps: int = 1
305
+ gradient_checkpointing: bool = False
306
+ checkpointing_steps: int = 500
307
+ checkpointing_limit: Optional[int] = None
308
+ resume_from_checkpoint: Optional[str] = None
309
+ enable_slicing: bool = False
310
+ enable_tiling: bool = False
311
+
312
+ # Optimizer arguments
313
+ optimizer: str = "adamw"
314
+ use_8bit_bnb: bool = False
315
+ lr: float = 1e-4
316
+ scale_lr: bool = False
317
+ lr_scheduler: str = "cosine_with_restarts"
318
+ lr_warmup_steps: int = 0
319
+ lr_num_cycles: int = 1
320
+ lr_power: float = 1.0
321
+ beta1: float = 0.9
322
+ beta2: float = 0.95
323
+ beta3: float = 0.999
324
+ weight_decay: float = 0.0001
325
+ epsilon: float = 1e-8
326
+ max_grad_norm: float = 1.0
327
+
328
+ # Validation arguments
329
+ validation_prompts: List[str] = None
330
+ validation_images: List[str] = None
331
+ validation_videos: List[str] = None
332
+ validation_heights: List[int] = None
333
+ validation_widths: List[int] = None
334
+ validation_num_frames: List[int] = None
335
+ num_validation_videos_per_prompt: int = 1
336
+ validation_every_n_epochs: Optional[int] = None
337
+ validation_every_n_steps: Optional[int] = None
338
+ enable_model_cpu_offload: bool = False
339
+ validation_frame_rate: int = 25
340
+
341
+ # Miscellaneous arguments
342
+ tracker_name: str = "finetrainers"
343
+ push_to_hub: bool = False
344
+ hub_token: Optional[str] = None
345
+ hub_model_id: Optional[str] = None
346
+ output_dir: str = None
347
+ logging_dir: Optional[str] = "logs"
348
+ allow_tf32: bool = False
349
+ nccl_timeout: int = 1800 # 30 minutes
350
+ report_to: str = "wandb"
351
+
352
+ def to_dict(self) -> Dict[str, Any]:
353
+ return {
354
+ "model_arguments": {
355
+ "model_name": self.model_name,
356
+ "pretrained_model_name_or_path": self.pretrained_model_name_or_path,
357
+ "revision": self.revision,
358
+ "variant": self.variant,
359
+ "cache_dir": self.cache_dir,
360
+ "text_encoder_dtype": self.text_encoder_dtype,
361
+ "text_encoder_2_dtype": self.text_encoder_2_dtype,
362
+ "text_encoder_3_dtype": self.text_encoder_3_dtype,
363
+ "transformer_dtype": self.transformer_dtype,
364
+ "vae_dtype": self.vae_dtype,
365
+ "layerwise_upcasting_modules": self.layerwise_upcasting_modules,
366
+ "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
367
+ "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
368
+ },
369
+ "dataset_arguments": {
370
+ "data_root": self.data_root,
371
+ "dataset_file": self.dataset_file,
372
+ "video_column": self.video_column,
373
+ "caption_column": self.caption_column,
374
+ "id_token": self.id_token,
375
+ "image_resolution_buckets": self.image_resolution_buckets,
376
+ "video_resolution_buckets": self.video_resolution_buckets,
377
+ "video_reshape_mode": self.video_reshape_mode,
378
+ "caption_dropout_p": self.caption_dropout_p,
379
+ "caption_dropout_technique": self.caption_dropout_technique,
380
+ "precompute_conditions": self.precompute_conditions,
381
+ "remove_common_llm_caption_prefixes": self.remove_common_llm_caption_prefixes,
382
+ },
383
+ "dataloader_arguments": {
384
+ "dataloader_num_workers": self.dataloader_num_workers,
385
+ "pin_memory": self.pin_memory,
386
+ },
387
+ "diffusion_arguments": {
388
+ "flow_resolution_shifting": self.flow_resolution_shifting,
389
+ "flow_base_seq_len": self.flow_base_seq_len,
390
+ "flow_max_seq_len": self.flow_max_seq_len,
391
+ "flow_base_shift": self.flow_base_shift,
392
+ "flow_max_shift": self.flow_max_shift,
393
+ "flow_shift": self.flow_shift,
394
+ "flow_weighting_scheme": self.flow_weighting_scheme,
395
+ "flow_logit_mean": self.flow_logit_mean,
396
+ "flow_logit_std": self.flow_logit_std,
397
+ "flow_mode_scale": self.flow_mode_scale,
398
+ },
399
+ "training_arguments": {
400
+ "training_type": self.training_type,
401
+ "seed": self.seed,
402
+ "batch_size": self.batch_size,
403
+ "train_epochs": self.train_epochs,
404
+ "train_steps": self.train_steps,
405
+ "rank": self.rank,
406
+ "lora_alpha": self.lora_alpha,
407
+ "target_modules": self.target_modules,
408
+ "gradient_accumulation_steps": self.gradient_accumulation_steps,
409
+ "gradient_checkpointing": self.gradient_checkpointing,
410
+ "checkpointing_steps": self.checkpointing_steps,
411
+ "checkpointing_limit": self.checkpointing_limit,
412
+ "resume_from_checkpoint": self.resume_from_checkpoint,
413
+ "enable_slicing": self.enable_slicing,
414
+ "enable_tiling": self.enable_tiling,
415
+ },
416
+ "optimizer_arguments": {
417
+ "optimizer": self.optimizer,
418
+ "use_8bit_bnb": self.use_8bit_bnb,
419
+ "lr": self.lr,
420
+ "scale_lr": self.scale_lr,
421
+ "lr_scheduler": self.lr_scheduler,
422
+ "lr_warmup_steps": self.lr_warmup_steps,
423
+ "lr_num_cycles": self.lr_num_cycles,
424
+ "lr_power": self.lr_power,
425
+ "beta1": self.beta1,
426
+ "beta2": self.beta2,
427
+ "beta3": self.beta3,
428
+ "weight_decay": self.weight_decay,
429
+ "epsilon": self.epsilon,
430
+ "max_grad_norm": self.max_grad_norm,
431
+ },
432
+ "validation_arguments": {
433
+ "validation_prompts": self.validation_prompts,
434
+ "validation_images": self.validation_images,
435
+ "validation_videos": self.validation_videos,
436
+ "num_validation_videos_per_prompt": self.num_validation_videos_per_prompt,
437
+ "validation_every_n_epochs": self.validation_every_n_epochs,
438
+ "validation_every_n_steps": self.validation_every_n_steps,
439
+ "enable_model_cpu_offload": self.enable_model_cpu_offload,
440
+ "validation_frame_rate": self.validation_frame_rate,
441
+ },
442
+ "miscellaneous_arguments": {
443
+ "tracker_name": self.tracker_name,
444
+ "push_to_hub": self.push_to_hub,
445
+ "hub_token": self.hub_token,
446
+ "hub_model_id": self.hub_model_id,
447
+ "output_dir": self.output_dir,
448
+ "logging_dir": self.logging_dir,
449
+ "allow_tf32": self.allow_tf32,
450
+ "nccl_timeout": self.nccl_timeout,
451
+ "report_to": self.report_to,
452
+ },
453
+ }
454
+
455
+
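A small illustrative sketch of how the Args container behaves outside of CLI parsing: instantiate it, override a few fields, and read the grouped dictionary returned by to_dict(). The values are examples only.

args = Args()
args.model_name = "ltx_video"
args.training_type = "lora"
args.data_root = ".data/training"

grouped = args.to_dict()
print(grouped["model_arguments"]["model_name"])         # ltx_video
print(grouped["training_arguments"]["training_type"])   # lora
print(grouped["optimizer_arguments"]["lr"])             # 0.0001 (the 1e-4 default)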
456
+ # TODO(aryan): handle more informative messages
457
+ _IS_ARGUMENTS_REQUIRED = "--list_models" not in sys.argv
458
+
459
+
460
+ def parse_arguments() -> Args:
461
+ parser = argparse.ArgumentParser()
462
+
463
+ if _IS_ARGUMENTS_REQUIRED:
464
+ _add_model_arguments(parser)
465
+ _add_dataset_arguments(parser)
466
+ _add_dataloader_arguments(parser)
467
+ _add_diffusion_arguments(parser)
468
+ _add_training_arguments(parser)
469
+ _add_optimizer_arguments(parser)
470
+ _add_validation_arguments(parser)
471
+ _add_miscellaneous_arguments(parser)
472
+
473
+ args = parser.parse_args()
474
+ return _map_to_args_type(args)
475
+ else:
476
+ _add_helper_arguments(parser)
477
+
478
+ args = parser.parse_args()
479
+ _display_helper_messages(args)
480
+ sys.exit(0)
481
+
482
+
483
+ def validate_args(args: Args):
484
+ _validated_model_args(args)
485
+ _validate_training_args(args)
486
+ _validate_validation_args(args)
487
+
488
+
489
+ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
490
+ parser.add_argument(
491
+ "--model_name",
492
+ type=str,
493
+ required=True,
494
+ choices=list(SUPPORTED_MODEL_CONFIGS.keys()),
495
+ help="Name of model to train.",
496
+ )
497
+ parser.add_argument(
498
+ "--pretrained_model_name_or_path",
499
+ type=str,
500
+ required=True,
501
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
502
+ )
503
+ parser.add_argument(
504
+ "--revision",
505
+ type=str,
506
+ default=None,
507
+ required=False,
508
+ help="Revision of pretrained model identifier from huggingface.co/models.",
509
+ )
510
+ parser.add_argument(
511
+ "--variant",
512
+ type=str,
513
+ default=None,
514
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
515
+ )
516
+ parser.add_argument(
517
+ "--cache_dir",
518
+ type=str,
519
+ default=None,
520
+ help="The directory where the downloaded models and datasets will be stored.",
521
+ )
522
+ parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.")
523
+ parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16", help="Data type for the text encoder 2.")
524
+ parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
525
+ parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.")
526
+ parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")
527
+ parser.add_argument(
528
+ "--layerwise_upcasting_modules",
529
+ type=str,
530
+ default=[],
531
+ nargs="+",
532
+ choices=["transformer"],
533
+ help="Modules that should have fp8 storage weights but higher precision computation.",
534
+ )
535
+ parser.add_argument(
536
+ "--layerwise_upcasting_storage_dtype",
537
+ type=str,
538
+ default="float8_e4m3fn",
539
+ choices=["float8_e4m3fn", "float8_e5m2"],
540
+ help="Data type for the layerwise upcasting storage.",
541
+ )
542
+ parser.add_argument(
543
+ "--layerwise_upcasting_skip_modules_pattern",
544
+ type=str,
545
+ default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"],
546
+ nargs="+",
547
+ help="Modules to skip for layerwise upcasting.",
548
+ )
549
+
550
+
551
+ def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
552
+ def parse_resolution_bucket(resolution_bucket: str) -> Tuple[int, ...]:
553
+ return tuple(map(int, resolution_bucket.split("x")))
554
+
555
+ def parse_image_resolution_bucket(resolution_bucket: str) -> Tuple[int, int]:
556
+ resolution_bucket = parse_resolution_bucket(resolution_bucket)
557
+ assert (
558
+ len(resolution_bucket) == 2
559
+ ), f"Expected 2D resolution bucket, got {len(resolution_bucket)}D resolution bucket"
560
+ return resolution_bucket
561
+
562
+ def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int]:
563
+ resolution_bucket = parse_resolution_bucket(resolution_bucket)
564
+ assert (
565
+ len(resolution_bucket) == 3
566
+ ), f"Expected 3D resolution bucket, got {len(resolution_bucket)}D resolution bucket"
567
+ return resolution_bucket
568
+
569
+ parser.add_argument(
570
+ "--data_root",
571
+ type=str,
572
+ required=True,
573
+ help=("A folder containing the training data."),
574
+ )
575
+ parser.add_argument(
576
+ "--dataset_file",
577
+ type=str,
578
+ default=None,
579
+ help=("Path to a CSV file if loading prompts/video paths using this format."),
580
+ )
581
+ parser.add_argument(
582
+ "--video_column",
583
+ type=str,
584
+ default="video",
585
+ help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.",
586
+ )
587
+ parser.add_argument(
588
+ "--caption_column",
589
+ type=str,
590
+ default="text",
591
+ help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
592
+ )
593
+ parser.add_argument(
594
+ "--id_token",
595
+ type=str,
596
+ default=None,
597
+ help="Identifier token appended to the start of each prompt if provided.",
598
+ )
599
+ parser.add_argument(
600
+ "--image_resolution_buckets",
601
+ type=parse_image_resolution_bucket,
602
+ default=None,
603
+ nargs="+",
604
+ help="Resolution buckets for images.",
605
+ )
606
+ parser.add_argument(
607
+ "--video_resolution_buckets",
608
+ type=parse_video_resolution_bucket,
609
+ default=None,
610
+ nargs="+",
611
+ help="Resolution buckets for videos.",
612
+ )
613
+ parser.add_argument(
614
+ "--video_reshape_mode",
615
+ type=str,
616
+ default=None,
617
+ help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
618
+ )
619
+ parser.add_argument(
620
+ "--caption_dropout_p",
621
+ type=float,
622
+ default=0.00,
623
+ help="Probability of dropout for the caption tokens.",
624
+ )
625
+ parser.add_argument(
626
+ "--caption_dropout_technique",
627
+ type=str,
628
+ default="empty",
629
+ choices=["empty", "zero"],
630
+ help="Technique to use for caption dropout.",
631
+ )
632
+ parser.add_argument(
633
+ "--precompute_conditions",
634
+ action="store_true",
635
+ help="Whether or not to precompute the conditionings for the model.",
636
+ )
637
+ parser.add_argument(
638
+ "--remove_common_llm_caption_prefixes",
639
+ action="store_true",
640
+ help="Whether or not to remove common LLM caption prefixes.",
641
+ )
642
+
643
+
644
+ def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None:
645
+ parser.add_argument(
646
+ "--dataloader_num_workers",
647
+ type=int,
648
+ default=0,
649
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
650
+ )
651
+ parser.add_argument(
652
+ "--pin_memory",
653
+ action="store_true",
654
+ help="Whether or not to use the pinned memory setting in pytorch dataloader.",
655
+ )
656
+
657
+
658
+ def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None:
659
+ parser.add_argument(
660
+ "--flow_resolution_shifting",
661
+ action="store_true",
662
+ help="Resolution-dependent shifting of timestep schedules.",
663
+ )
664
+ parser.add_argument(
665
+ "--flow_base_seq_len",
666
+ type=int,
667
+ default=256,
668
+ help="Base image/video sequence length for the diffusion model.",
669
+ )
670
+ parser.add_argument(
671
+ "--flow_max_seq_len",
672
+ type=int,
673
+ default=4096,
674
+ help="Maximum image/video sequence length for the diffusion model.",
675
+ )
676
+ parser.add_argument(
677
+ "--flow_base_shift",
678
+ type=float,
679
+ default=0.5,
680
+ help="Base shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)",
681
+ )
682
+ parser.add_argument(
683
+ "--flow_max_shift",
684
+ type=float,
685
+ default=1.15,
686
+ help="Maximum shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)",
687
+ )
688
+ parser.add_argument(
689
+ "--flow_shift",
690
+ type=float,
691
+ default=1.0,
692
+ help="Shift value to use for the flow matching timestep schedule.",
693
+ )
694
+ parser.add_argument(
695
+ "--flow_weighting_scheme",
696
+ type=str,
697
+ default="none",
698
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
699
+ help='We default to the "none" weighting scheme for uniform sampling and uniform loss',
700
+ )
701
+ parser.add_argument(
702
+ "--flow_logit_mean",
703
+ type=float,
704
+ default=0.0,
705
+ help="Mean to use when using the `'logit_normal'` weighting scheme.",
706
+ )
707
+ parser.add_argument(
708
+ "--flow_logit_std",
709
+ type=float,
710
+ default=1.0,
711
+ help="Standard deviation to use when using the `'logit_normal'` weighting scheme.",
712
+ )
713
+ parser.add_argument(
714
+ "--flow_mode_scale",
715
+ type=float,
716
+ default=1.29,
717
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
718
+ )
719
+
720
+
721
+ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
722
+ # TODO: support full finetuning and other kinds
723
+ parser.add_argument(
724
+ "--training_type",
725
+ type=str,
726
+ choices=["lora", "full-finetune"],
727
+ required=True,
728
+ help="Type of training to perform. Choose between ['lora', 'full-finetune']",
729
+ )
730
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
731
+ parser.add_argument(
732
+ "--batch_size",
733
+ type=int,
734
+ default=1,
735
+ help="Batch size (per device) for the training dataloader.",
736
+ )
737
+ parser.add_argument("--train_epochs", type=int, default=1, help="Number of training epochs.")
738
+ parser.add_argument(
739
+ "--train_steps",
740
+ type=int,
741
+ default=None,
742
+ help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
743
+ )
744
+ parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.")
745
+ parser.add_argument(
746
+ "--lora_alpha",
747
+ type=int,
748
+ default=64,
749
+ help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
750
+ )
751
+ parser.add_argument(
752
+ "--target_modules",
753
+ type=str,
754
+ default=["to_k", "to_q", "to_v", "to_out.0"],
755
+ nargs="+",
756
+ help="The target modules for LoRA.",
757
+ )
758
+ parser.add_argument(
759
+ "--gradient_accumulation_steps",
760
+ type=int,
761
+ default=1,
762
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
763
+ )
764
+ parser.add_argument(
765
+ "--gradient_checkpointing",
766
+ action="store_true",
767
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
768
+ )
769
+ parser.add_argument(
770
+ "--checkpointing_steps",
771
+ type=int,
772
+ default=500,
773
+ help=(
774
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
775
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
776
+ " training using `--resume_from_checkpoint`."
777
+ ),
778
+ )
779
+ parser.add_argument(
780
+ "--checkpointing_limit",
781
+ type=int,
782
+ default=None,
783
+ help=("Max number of checkpoints to store."),
784
+ )
785
+ parser.add_argument(
786
+ "--resume_from_checkpoint",
787
+ type=str,
788
+ default=None,
789
+ help=(
790
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
791
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
792
+ ),
793
+ )
794
+ parser.add_argument(
795
+ "--enable_slicing",
796
+ action="store_true",
797
+ help="Whether or not to use VAE slicing for saving memory.",
798
+ )
799
+ parser.add_argument(
800
+ "--enable_tiling",
801
+ action="store_true",
802
+ help="Whether or not to use VAE tiling for saving memory.",
803
+ )
804
+
805
+
806
+ def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None:
807
+ parser.add_argument(
808
+ "--lr",
809
+ type=float,
810
+ default=1e-4,
811
+ help="Initial learning rate (after the potential warmup period) to use.",
812
+ )
813
+ parser.add_argument(
814
+ "--scale_lr",
815
+ action="store_true",
816
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
817
+ )
818
+ parser.add_argument(
819
+ "--lr_scheduler",
820
+ type=str,
821
+ default="constant",
822
+ help=(
823
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
824
+ ' "constant", "constant_with_warmup"]'
825
+ ),
826
+ )
827
+ parser.add_argument(
828
+ "--lr_warmup_steps",
829
+ type=int,
830
+ default=500,
831
+ help="Number of steps for the warmup in the lr scheduler.",
832
+ )
833
+ parser.add_argument(
834
+ "--lr_num_cycles",
835
+ type=int,
836
+ default=1,
837
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
838
+ )
839
+ parser.add_argument(
840
+ "--lr_power",
841
+ type=float,
842
+ default=1.0,
843
+ help="Power factor of the polynomial scheduler.",
844
+ )
845
+ parser.add_argument(
846
+ "--optimizer",
847
+ type=lambda s: s.lower(),
848
+ default="adam",
849
+ choices=["adam", "adamw"],
850
+ help=("The optimizer type to use."),
851
+ )
852
+ parser.add_argument(
853
+ "--use_8bit_bnb",
854
+ action="store_true",
855
+ help=("Whether to use 8bit variant of the `--optimizer` using `bitsandbytes`."),
856
+ )
857
+ parser.add_argument(
858
+ "--beta1",
859
+ type=float,
860
+ default=0.9,
861
+ help="The beta1 parameter for the Adam and Prodigy optimizers.",
862
+ )
863
+ parser.add_argument(
864
+ "--beta2",
865
+ type=float,
866
+ default=0.95,
867
+ help="The beta2 parameter for the Adam and Prodigy optimizers.",
868
+ )
869
+ parser.add_argument(
870
+ "--beta3",
871
+ type=float,
872
+ default=None,
873
+ help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
874
+ )
875
+ parser.add_argument(
876
+ "--weight_decay",
877
+ type=float,
878
+ default=1e-04,
879
+ help="Weight decay to use for optimizer.",
880
+ )
881
+ parser.add_argument(
882
+ "--epsilon",
883
+ type=float,
884
+ default=1e-8,
885
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
886
+ )
887
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
888
+
889
+
890
+ def _add_validation_arguments(parser: argparse.ArgumentParser) -> None:
891
+ parser.add_argument(
892
+ "--validation_prompts",
893
+ type=str,
894
+ default=None,
895
+ help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
896
+ )
897
+ parser.add_argument(
898
+ "--validation_images",
899
+ type=str,
900
+ default=None,
901
+ help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
902
+ )
903
+ parser.add_argument(
904
+ "--validation_videos",
905
+ type=str,
906
+ default=None,
907
+ help="One or more video path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
908
+ )
909
+ parser.add_argument(
910
+ "--validation_separator",
911
+ type=str,
912
+ default=":::",
913
+ help="String that separates multiple validation prompts",
914
+ )
915
+ parser.add_argument(
916
+ "--num_validation_videos",
917
+ type=int,
918
+ default=1,
919
+ help="Number of videos that should be generated during validation per `validation_prompt`.",
920
+ )
921
+ parser.add_argument(
922
+ "--validation_epochs",
923
+ type=int,
924
+ default=None,
925
+ help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.",
926
+ )
927
+ parser.add_argument(
928
+ "--validation_steps",
929
+ type=int,
930
+ default=None,
931
+ help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
932
+ )
933
+ parser.add_argument(
934
+ "--validation_frame_rate",
935
+ type=int,
936
+ default=25,
937
+ help="Frame rate to use for the validation videos.",
938
+ )
939
+ parser.add_argument(
940
+ "--enable_model_cpu_offload",
941
+ action="store_true",
942
+ help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
943
+ )
944
+
945
+
946
+ def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
947
+ parser.add_argument("--tracker_name", type=str, default="finetrainers", help="Project tracker name")
948
+ parser.add_argument(
949
+ "--push_to_hub",
950
+ action="store_true",
951
+ help="Whether or not to push the model to the Hub.",
952
+ )
953
+ parser.add_argument(
954
+ "--hub_token",
955
+ type=str,
956
+ default=None,
957
+ help="The token to use to push to the Model Hub.",
958
+ )
959
+ parser.add_argument(
960
+ "--hub_model_id",
961
+ type=str,
962
+ default=None,
963
+ help="The name of the repository to keep in sync with the local `output_dir`.",
964
+ )
965
+ parser.add_argument(
966
+ "--output_dir",
967
+ type=str,
968
+ default="finetrainers-training",
969
+ help="The output directory where the model predictions and checkpoints will be written.",
970
+ )
971
+ parser.add_argument(
972
+ "--logging_dir",
973
+ type=str,
974
+ default="logs",
975
+ help="Directory where logs are stored.",
976
+ )
977
+ parser.add_argument(
978
+ "--allow_tf32",
979
+ action="store_true",
980
+ help=(
981
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
982
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
983
+ ),
984
+ )
985
+ parser.add_argument(
986
+ "--nccl_timeout",
987
+ type=int,
988
+ default=600,
989
+ help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.",
990
+ )
991
+ parser.add_argument(
992
+ "--report_to",
993
+ type=str,
994
+ default="none",
995
+ choices=["none", "wandb"],
996
+ help="The integration to report the results and logs to.",
997
+ )
998
+
999
+
1000
+ def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
1001
+ parser.add_argument(
1002
+ "--list_models",
1003
+ action="store_true",
1004
+ help="List all the supported models.",
1005
+ )
1006
+
1007
+
1008
+ _DTYPE_MAP = {
1009
+ "bf16": torch.bfloat16,
1010
+ "fp16": torch.float16,
1011
+ "fp32": torch.float32,
1012
+ "float8_e4m3fn": torch.float8_e4m3fn,
1013
+ "float8_e5m2": torch.float8_e5m2,
1014
+ }
1015
+
1016
+
1017
+ def _map_to_args_type(args: Dict[str, Any]) -> Args:
1018
+ result_args = Args()
1019
+
1020
+ # Model arguments
1021
+ result_args.model_name = args.model_name
1022
+ result_args.pretrained_model_name_or_path = args.pretrained_model_name_or_path
1023
+ result_args.revision = args.revision
1024
+ result_args.variant = args.variant
1025
+ result_args.cache_dir = args.cache_dir
1026
+ result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]
1027
+ result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype]
1028
+ result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
1029
+ result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype]
1030
+ result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype]
1031
+ result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules
1032
+ result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype]
1033
+ result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern
1034
+
1035
+ # Dataset arguments
1036
+ if args.data_root is None and args.dataset_file is None:
1037
+ raise ValueError("At least one of `data_root` or `dataset_file` should be provided.")
1038
+
1039
+ result_args.data_root = args.data_root
1040
+ result_args.dataset_file = args.dataset_file
1041
+ result_args.video_column = args.video_column
1042
+ result_args.caption_column = args.caption_column
1043
+ result_args.id_token = args.id_token
1044
+ result_args.image_resolution_buckets = args.image_resolution_buckets or DEFAULT_IMAGE_RESOLUTION_BUCKETS
1045
+ result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS
1046
+ result_args.video_reshape_mode = args.video_reshape_mode
1047
+ result_args.caption_dropout_p = args.caption_dropout_p
1048
+ result_args.caption_dropout_technique = args.caption_dropout_technique
1049
+ result_args.precompute_conditions = args.precompute_conditions
1050
+ result_args.remove_common_llm_caption_prefixes = args.remove_common_llm_caption_prefixes
1051
+
1052
+ # Dataloader arguments
1053
+ result_args.dataloader_num_workers = args.dataloader_num_workers
1054
+ result_args.pin_memory = args.pin_memory
1055
+
1056
+ # Diffusion arguments
1057
+ result_args.flow_resolution_shifting = args.flow_resolution_shifting
1058
+ result_args.flow_base_seq_len = args.flow_base_seq_len
1059
+ result_args.flow_max_seq_len = args.flow_max_seq_len
1060
+ result_args.flow_base_shift = args.flow_base_shift
1061
+ result_args.flow_max_shift = args.flow_max_shift
1062
+ result_args.flow_shift = args.flow_shift
1063
+ result_args.flow_weighting_scheme = args.flow_weighting_scheme
1064
+ result_args.flow_logit_mean = args.flow_logit_mean
1065
+ result_args.flow_logit_std = args.flow_logit_std
1066
+ result_args.flow_mode_scale = args.flow_mode_scale
1067
+
1068
+ # Training arguments
1069
+ result_args.training_type = args.training_type
1070
+ result_args.seed = args.seed
1071
+ result_args.batch_size = args.batch_size
1072
+ result_args.train_epochs = args.train_epochs
1073
+ result_args.train_steps = args.train_steps
1074
+ result_args.rank = args.rank
1075
+ result_args.lora_alpha = args.lora_alpha
1076
+ result_args.target_modules = args.target_modules
1077
+ result_args.gradient_accumulation_steps = args.gradient_accumulation_steps
1078
+ result_args.gradient_checkpointing = args.gradient_checkpointing
1079
+ result_args.checkpointing_steps = args.checkpointing_steps
1080
+ result_args.checkpointing_limit = args.checkpointing_limit
1081
+ result_args.resume_from_checkpoint = args.resume_from_checkpoint
1082
+ result_args.enable_slicing = args.enable_slicing
1083
+ result_args.enable_tiling = args.enable_tiling
1084
+
1085
+ # Optimizer arguments
1086
+ result_args.optimizer = args.optimizer or "adamw"
1087
+ result_args.use_8bit_bnb = args.use_8bit_bnb
1088
+ result_args.lr = args.lr or 1e-4
1089
+ result_args.scale_lr = args.scale_lr
1090
+ result_args.lr_scheduler = args.lr_scheduler
1091
+ result_args.lr_warmup_steps = args.lr_warmup_steps
1092
+ result_args.lr_num_cycles = args.lr_num_cycles
1093
+ result_args.lr_power = args.lr_power
1094
+ result_args.beta1 = args.beta1
1095
+ result_args.beta2 = args.beta2
1096
+ result_args.beta3 = args.beta3
1097
+ result_args.weight_decay = args.weight_decay
1098
+ result_args.epsilon = args.epsilon
1099
+ result_args.max_grad_norm = args.max_grad_norm
1100
+
1101
+ # Validation arguments
1102
+ validation_prompts = args.validation_prompts.split(args.validation_separator) if args.validation_prompts else []
1103
+ validation_images = args.validation_images.split(args.validation_separator) if args.validation_images else None
1104
+ validation_videos = args.validation_videos.split(args.validation_separator) if args.validation_videos else None
1105
+ stripped_validation_prompts = []
1106
+ validation_heights = []
1107
+ validation_widths = []
1108
+ validation_num_frames = []
1109
+ for prompt in validation_prompts:
1110
+ prompt: str
1111
+ prompt = prompt.strip()
1112
+ actual_prompt, separator, resolution = prompt.rpartition("@@@")
1113
+ stripped_validation_prompts.append(actual_prompt)
1114
+ num_frames, height, width = None, None, None
1115
+ if len(resolution) > 0:
1116
+ num_frames, height, width = map(int, resolution.split("x"))
1117
+ validation_num_frames.append(num_frames)
1118
+ validation_heights.append(height)
1119
+ validation_widths.append(width)
1120
+
1121
+ if validation_images is None:
1122
+ validation_images = [None] * len(validation_prompts)
1123
+ if validation_videos is None:
1124
+ validation_videos = [None] * len(validation_prompts)
1125
+
1126
+ result_args.validation_prompts = stripped_validation_prompts
1127
+ result_args.validation_heights = validation_heights
1128
+ result_args.validation_widths = validation_widths
1129
+ result_args.validation_num_frames = validation_num_frames
1130
+ result_args.validation_images = validation_images
1131
+ result_args.validation_videos = validation_videos
1132
+
1133
+ result_args.num_validation_videos_per_prompt = args.num_validation_videos
1134
+ result_args.validation_every_n_epochs = args.validation_epochs
1135
+ result_args.validation_every_n_steps = args.validation_steps
1136
+ result_args.enable_model_cpu_offload = args.enable_model_cpu_offload
1137
+ result_args.validation_frame_rate = args.validation_frame_rate
1138
+
1139
+ # Miscellaneous arguments
1140
+ result_args.tracker_name = args.tracker_name
1141
+ result_args.push_to_hub = args.push_to_hub
1142
+ result_args.hub_token = args.hub_token
1143
+ result_args.hub_model_id = args.hub_model_id
1144
+ result_args.output_dir = args.output_dir
1145
+ result_args.logging_dir = args.logging_dir
1146
+ result_args.allow_tf32 = args.allow_tf32
1147
+ result_args.nccl_timeout = args.nccl_timeout
1148
+ result_args.report_to = args.report_to
1149
+
1150
+ return result_args
1151
+
1152
+
1153
+ def _validated_model_args(args: Args):
1154
+ if args.training_type == "full-finetune":
1155
+ assert (
1156
+ "transformer" not in args.layerwise_upcasting_modules
1157
+ ), "Layerwise upcasting is not supported for full-finetune training"
1158
+
1159
+
1160
+ def _validate_training_args(args: Args):
1161
+ if args.training_type == "lora":
1162
+ assert args.rank is not None, "Rank is required for LoRA training"
1163
+ assert args.lora_alpha is not None, "LoRA alpha is required for LoRA training"
1164
+ assert (
1165
+ args.target_modules is not None and len(args.target_modules) > 0
1166
+ ), "Target modules are required for LoRA training"
1167
+
1168
+
1169
+ def _validate_validation_args(args: Args):
1170
+ assert args.validation_prompts is not None, "Validation prompts are required for validation"
1171
+ if args.validation_images is not None:
1172
+ assert len(args.validation_images) == len(
1173
+ args.validation_prompts
1174
+ ), "Validation images and prompts should be of same length"
1175
+ if args.validation_videos is not None:
1176
+ assert len(args.validation_videos) == len(
1177
+ args.validation_prompts
1178
+ ), "Validation videos and prompts should be of same length"
1179
+ assert len(args.validation_prompts) == len(
1180
+ args.validation_heights
1181
+ ), "Validation prompts and heights should be of same length"
1182
+ assert len(args.validation_prompts) == len(
1183
+ args.validation_widths
1184
+ ), "Validation prompts and widths should be of same length"
1185
+
1186
+
1187
+ def _display_helper_messages(args: argparse.Namespace):
1188
+ if args.list_models:
1189
+ print("Supported models:")
1190
+ for index, model_name in enumerate(SUPPORTED_MODEL_CONFIGS.keys()):
1191
+ print(f" {index + 1}. {model_name}")
finetrainers/constants.py ADDED
@@ -0,0 +1,80 @@
1
+ import os
2
+
3
+
4
+ DEFAULT_HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
5
+ DEFAULT_WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
6
+ DEFAULT_FRAME_BUCKETS = [49]
7
+
8
+ DEFAULT_IMAGE_RESOLUTION_BUCKETS = []
9
+ for height in DEFAULT_HEIGHT_BUCKETS:
10
+ for width in DEFAULT_WIDTH_BUCKETS:
11
+ DEFAULT_IMAGE_RESOLUTION_BUCKETS.append((height, width))
12
+
13
+ DEFAULT_VIDEO_RESOLUTION_BUCKETS = []
14
+ for frames in DEFAULT_FRAME_BUCKETS:
15
+ for height in DEFAULT_HEIGHT_BUCKETS:
16
+ for width in DEFAULT_WIDTH_BUCKETS:
17
+ DEFAULT_VIDEO_RESOLUTION_BUCKETS.append((frames, height, width))
18
+
19
+
20
+ FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO")
21
+
22
+ PRECOMPUTED_DIR_NAME = "precomputed"
23
+ PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions"
24
+ PRECOMPUTED_LATENTS_DIR_NAME = "latents"
25
+
26
+ MODEL_DESCRIPTION = r"""
27
+ \# {model_id} {training_type} finetune
28
+
29
+ <Gallery />
30
+
31
+ \#\# Model Description
32
+
33
+ This model is a {training_type} of the `{model_id}` model.
34
+
35
+ This model was trained using the `fine-video-trainers` library - a repository containing memory-optimized scripts for training video models with [Diffusers](https://github.com/huggingface/diffusers).
36
+
37
+ \#\# Download model
38
+
39
+ [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
40
+
41
+ \#\# Usage
42
+
43
+ Requires [🧨 Diffusers](https://github.com/huggingface/diffusers) installed.
44
+
45
+ ```python
46
+ {model_example}
47
+ ```
48
+
49
+ For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
50
+
51
+ \#\# License
52
+
53
+ Please adhere to the license of the base model.
54
+ """.strip()
55
+
56
+ _COMMON_BEGINNING_PHRASES = (
57
+ "This video",
58
+ "The video",
59
+ "This clip",
60
+ "The clip",
61
+ "The animation",
62
+ "This image",
63
+ "The image",
64
+ "This picture",
65
+ "The picture",
66
+ )
67
+ _COMMON_CONTINUATION_WORDS = ("shows", "depicts", "features", "captures", "highlights", "introduces", "presents")
68
+
69
+ COMMON_LLM_START_PHRASES = (
70
+ "In the video,",
71
+ "In this video,",
72
+ "In this video clip,",
73
+ "In the clip,",
74
+ "Caption:",
75
+ *(
76
+ f"{beginning} {continuation}"
77
+ for beginning in _COMMON_BEGINNING_PHRASES
78
+ for continuation in _COMMON_CONTINUATION_WORDS
79
+ ),
80
+ )
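As a quick illustration of how the phrase table above is meant to be used (the caption string is made up, and the import assumes the `finetrainers` package layout shown in this commit), this mirrors the prefix stripping the dataset applies when `--remove_common_llm_caption_prefixes` is set:

```python
from finetrainers.constants import COMMON_LLM_START_PHRASES

caption = "This video shows a red car driving through the desert at sunset."
caption = caption.strip()
for phrase in COMMON_LLM_START_PHRASES:
    if caption.startswith(phrase):
        # Same cleanup as in the dataset: drop the prefix, then strip whitespace.
        caption = caption.removeprefix(phrase).strip()

print(caption)  # "a red car driving through the desert at sunset."
```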
finetrainers/dataset.py ADDED
@@ -0,0 +1,467 @@
1
+ import json
2
+ import os
3
+ import random
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ import torchvision.transforms as TT
11
+ import torchvision.transforms.functional as TTF
12
+ from accelerate.logging import get_logger
13
+ from torch.utils.data import Dataset, Sampler
14
+ from torchvision import transforms
15
+ from torchvision.transforms import InterpolationMode
16
+ from torchvision.transforms.functional import resize
17
+
18
+
19
+ # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
20
+ # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
21
+ import decord # isort:skip
22
+
23
+ decord.bridge.set_bridge("torch")
24
+
25
+ from .constants import ( # noqa
26
+ COMMON_LLM_START_PHRASES,
27
+ PRECOMPUTED_CONDITIONS_DIR_NAME,
28
+ PRECOMPUTED_DIR_NAME,
29
+ PRECOMPUTED_LATENTS_DIR_NAME,
30
+ )
31
+
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
+ # TODO(aryan): This needs a refactor with separation of concerns.
37
+ # Images should be handled separately. Videos should be handled separately.
38
+ # Loading should be handled separately.
39
+ # Preprocessing (aspect ratio, resizing) should be handled separately.
40
+ # URL loading should be handled.
41
+ # Parquet format should be handled.
42
+ # Loading from ZIP should be handled.
43
+ class ImageOrVideoDataset(Dataset):
44
+ def __init__(
45
+ self,
46
+ data_root: str,
47
+ caption_column: str,
48
+ video_column: str,
49
+ resolution_buckets: List[Tuple[int, int, int]],
50
+ dataset_file: Optional[str] = None,
51
+ id_token: Optional[str] = None,
52
+ remove_llm_prefixes: bool = False,
53
+ ) -> None:
54
+ super().__init__()
55
+
56
+ self.data_root = Path(data_root)
57
+ self.dataset_file = dataset_file
58
+ self.caption_column = caption_column
59
+ self.video_column = video_column
60
+ self.id_token = f"{id_token.strip()} " if id_token else ""
61
+ self.resolution_buckets = resolution_buckets
62
+
63
+ # Four methods of loading data are supported.
64
+ # - Using a CSV: caption_column and video_column must be some column in the CSV. One could
65
+ # make use of other columns too, such as a motion score or aesthetic score, by modifying the
66
+ # logic in CSV processing.
67
+ # - Using two files containing line-separated captions and relative paths to videos.
68
+ # - Using a JSON file containing a list of dictionaries, where each dictionary has a `caption_column` and `video_column` key.
69
+ # - Using a JSONL file containing a list of line-separated dictionaries, where each dictionary has a `caption_column` and `video_column` key.
70
+ # For a more detailed explanation about preparing dataset format, checkout the README.
71
+ if dataset_file is None:
72
+ (
73
+ self.prompts,
74
+ self.video_paths,
75
+ ) = self._load_dataset_from_local_path()
76
+ elif dataset_file.endswith(".csv"):
77
+ (
78
+ self.prompts,
79
+ self.video_paths,
80
+ ) = self._load_dataset_from_csv()
81
+ elif dataset_file.endswith(".json"):
82
+ (
83
+ self.prompts,
84
+ self.video_paths,
85
+ ) = self._load_dataset_from_json()
86
+ elif dataset_file.endswith(".jsonl"):
87
+ (
88
+ self.prompts,
89
+ self.video_paths,
90
+ ) = self._load_dataset_from_jsonl()
91
+ else:
92
+ raise ValueError(
93
+ "Expected `--dataset_file` to be a path to a CSV file or a directory containing line-separated text prompts and video paths."
94
+ )
95
+
96
+ if len(self.video_paths) != len(self.prompts):
97
+ raise ValueError(
98
+ f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
99
+ )
100
+
101
+ # Clean LLM start phrases
102
+ if remove_llm_prefixes:
103
+ for i in range(len(self.prompts)):
104
+ self.prompts[i] = self.prompts[i].strip()
105
+ for phrase in COMMON_LLM_START_PHRASES:
106
+ if self.prompts[i].startswith(phrase):
107
+ self.prompts[i] = self.prompts[i].removeprefix(phrase).strip()
108
+
109
+ self.video_transforms = transforms.Compose(
110
+ [
111
+ transforms.Lambda(self.scale_transform),
112
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
113
+ ]
114
+ )
115
+
116
+ @staticmethod
117
+ def scale_transform(x):
118
+ return x / 255.0
119
+
120
+ def __len__(self) -> int:
121
+ return len(self.video_paths)
122
+
123
+ def __getitem__(self, index: int) -> Dict[str, Any]:
124
+ if isinstance(index, list):
125
+ # Here, index is actually a list of data objects that we need to return.
126
+ # The BucketSampler should ideally return indices. But, in the sampler, we'd like
127
+ # to have information about num_frames, height and width. Since this is not stored
128
+ # as metadata, we need to read the video to get this information. You could read this
129
+ # information without loading the full video in memory, but we do it anyway. In order
130
+ # to not load the video twice (once to get the metadata, and once to return the loaded video
131
+ # based on sampled indices), we cache it in the BucketSampler. When the sampler is
132
+ # to yield, we yield the cache data instead of indices. So, this special check ensures
133
+ # that data is not loaded a second time. PRs are welcome for improvements.
134
+ return index
135
+
136
+ prompt = self.id_token + self.prompts[index]
137
+
138
+ video_path: Path = self.video_paths[index]
139
+ if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]:
140
+ video = self._preprocess_image(video_path)
141
+ else:
142
+ video = self._preprocess_video(video_path)
143
+
144
+ return {
145
+ "prompt": prompt,
146
+ "video": video,
147
+ "video_metadata": {
148
+ "num_frames": video.shape[0],
149
+ "height": video.shape[2],
150
+ "width": video.shape[3],
151
+ },
152
+ }
153
+
154
+ def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]:
155
+ if not self.data_root.exists():
156
+ raise ValueError("Root folder for videos does not exist")
157
+
158
+ prompt_path = self.data_root.joinpath(self.caption_column)
159
+ video_path = self.data_root.joinpath(self.video_column)
160
+
161
+ if not prompt_path.exists() or not prompt_path.is_file():
162
+ raise ValueError(
163
+ "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
164
+ )
165
+ if not video_path.exists() or not video_path.is_file():
166
+ raise ValueError(
167
+ "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
168
+ )
169
+
170
+ with open(prompt_path, "r", encoding="utf-8") as file:
171
+ prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
172
+ with open(video_path, "r", encoding="utf-8") as file:
173
+ video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
174
+
175
+ if any(not path.is_file() for path in video_paths):
176
+ raise ValueError(
177
+ f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
178
+ )
179
+
180
+ return prompts, video_paths
181
+
182
+ def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]:
183
+ df = pd.read_csv(self.dataset_file)
184
+ prompts = df[self.caption_column].tolist()
185
+ video_paths = df[self.video_column].tolist()
186
+ video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
187
+
188
+ if any(not path.is_file() for path in video_paths):
189
+ raise ValueError(
190
+ f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
191
+ )
192
+
193
+ return prompts, video_paths
194
+
195
+ def _load_dataset_from_json(self) -> Tuple[List[str], List[str]]:
196
+ with open(self.dataset_file, "r", encoding="utf-8") as file:
197
+ data = json.load(file)
198
+
199
+ prompts = [entry[self.caption_column] for entry in data]
200
+ video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data]
201
+
202
+ if any(not path.is_file() for path in video_paths):
203
+ raise ValueError(
204
+ f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
205
+ )
206
+
207
+ return prompts, video_paths
208
+
209
+ def _load_dataset_from_jsonl(self) -> Tuple[List[str], List[str]]:
210
+ with open(self.dataset_file, "r", encoding="utf-8") as file:
211
+ data = [json.loads(line) for line in file]
212
+
213
+ prompts = [entry[self.caption_column] for entry in data]
214
+ video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data]
215
+
216
+ if any(not path.is_file() for path in video_paths):
217
+ raise ValueError(
218
+ f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
219
+ )
220
+
221
+ return prompts, video_paths
222
+
223
+ def _preprocess_image(self, path: Path) -> torch.Tensor:
224
+ # TODO(aryan): Support alpha channel in future by whitening background
225
+ image = TTF.Image.open(path.as_posix()).convert("RGB")
226
+ image = TTF.to_tensor(image)
227
+ image = image * 2.0 - 1.0
228
+ image = image.unsqueeze(0).contiguous() # [C, H, W] -> [1, C, H, W] (1-frame video)
229
+ return image
230
+
231
+ def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
232
+ r"""
233
+ Loads a single video, or latent and prompt embedding, based on initialization parameters.
234
+
235
+ Returns a [F, C, H, W] video tensor.
236
+ """
237
+ video_reader = decord.VideoReader(uri=path.as_posix())
238
+ video_num_frames = len(video_reader)
239
+
240
+ indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames))
241
+ frames = video_reader.get_batch(indices)
242
+ frames = frames[: self.max_num_frames].float()
243
+ frames = frames.permute(0, 3, 1, 2).contiguous()
244
+ frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
245
+ return frames
246
+
247
+
248
+ class ImageOrVideoDatasetWithResizing(ImageOrVideoDataset):
249
+ def __init__(self, *args, **kwargs) -> None:
250
+ super().__init__(*args, **kwargs)
251
+
252
+ self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]
253
+
254
+ def _preprocess_image(self, path: Path) -> torch.Tensor:
255
+ # TODO(aryan): Support alpha channel in future by whitening background
256
+ image = TTF.Image.open(path.as_posix()).convert("RGB")
257
+ image = TTF.to_tensor(image)
258
+
259
+ nearest_res = self._find_nearest_resolution(image.shape[1], image.shape[2])
260
+ image = resize(image, nearest_res)
261
+
262
+ image = image * 2.0 - 1.0
263
+ image = image.unsqueeze(0).contiguous()
264
+ return image
265
+
266
+ def _preprocess_video(self, path: Path) -> torch.Tensor:
267
+ video_reader = decord.VideoReader(uri=path.as_posix())
268
+ video_num_frames = len(video_reader)
269
+ print(f"ImageOrVideoDatasetWithResizing: self.resolution_buckets = ", self.resolution_buckets)
270
+ print(f"ImageOrVideoDatasetWithResizing: self.max_num_frames = ", self.max_num_frames)
271
+ print(f"ImageOrVideoDatasetWithResizing: video_num_frames = ", video_num_frames)
272
+
273
+ video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
274
+
275
+ if not video_buckets:
276
+ _, h, w = self.resolution_buckets[0]
277
+ video_buckets = [(1, h, w)]
278
+
279
+ nearest_frame_bucket = min(
280
+ video_buckets,
281
+ key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
282
+ default=video_buckets[0],
283
+ )[0]
284
+
285
+ frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
286
+
287
+ frames = video_reader.get_batch(frame_indices)
288
+ frames = frames[:nearest_frame_bucket].float()
289
+ frames = frames.permute(0, 3, 1, 2).contiguous()
290
+
291
+ nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
292
+ frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
293
+ frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
294
+
295
+ return frames
296
+
297
+ def _find_nearest_resolution(self, height, width):
298
+ nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
299
+ return nearest_res[1], nearest_res[2]
300
+
301
+
302
+ class ImageOrVideoDatasetWithResizeAndRectangleCrop(ImageOrVideoDataset):
303
+ def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
304
+ super().__init__(*args, **kwargs)
305
+
306
+ self.video_reshape_mode = video_reshape_mode
307
+ self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]
308
+
309
+ def _resize_for_rectangle_crop(self, arr, image_size):
310
+ reshape_mode = self.video_reshape_mode
311
+ if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
312
+ arr = resize(
313
+ arr,
314
+ size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
315
+ interpolation=InterpolationMode.BICUBIC,
316
+ )
317
+ else:
318
+ arr = resize(
319
+ arr,
320
+ size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
321
+ interpolation=InterpolationMode.BICUBIC,
322
+ )
323
+
324
+ h, w = arr.shape[2], arr.shape[3]
325
+ arr = arr.squeeze(0)
326
+
327
+ delta_h = h - image_size[0]
328
+ delta_w = w - image_size[1]
329
+
330
+ if reshape_mode == "random" or reshape_mode == "none":
331
+ top = np.random.randint(0, delta_h + 1)
332
+ left = np.random.randint(0, delta_w + 1)
333
+ elif reshape_mode == "center":
334
+ top, left = delta_h // 2, delta_w // 2
335
+ else:
336
+ raise NotImplementedError
337
+ arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
338
+ return arr
339
+
340
+ def _preprocess_video(self, path: Path) -> torch.Tensor:
341
+ video_reader = decord.VideoReader(uri=path.as_posix())
342
+ video_num_frames = len(video_reader)
343
+ print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: self.resolution_buckets = ", self.resolution_buckets)
344
+ print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: self.max_num_frames = ", self.max_num_frames)
345
+ print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: video_num_frames = ", video_num_frames)
346
+
347
+ video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
348
+
349
+ if not video_buckets:
350
+ _, h, w = self.resolution_buckets[0]
351
+ video_buckets = [(1, h, w)]
352
+
353
+ nearest_frame_bucket = min(
354
+ video_buckets,
355
+ key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
356
+ default=video_buckets[0],
357
+ )[0]
358
+
359
+ frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
360
+
361
+ frames = video_reader.get_batch(frame_indices)
362
+ frames = frames[:nearest_frame_bucket].float()
363
+ frames = frames.permute(0, 3, 1, 2).contiguous()
364
+
365
+ nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
366
+ frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
367
+ frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
368
+ return frames
369
+
370
+ def _find_nearest_resolution(self, height, width):
371
+ nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
372
+ return nearest_res[1], nearest_res[2]
373
+
374
+
375
+ class PrecomputedDataset(Dataset):
376
+ def __init__(self, data_root: str, model_name: str = None, cleaned_model_id: str = None) -> None:
377
+ super().__init__()
378
+
379
+ self.data_root = Path(data_root)
380
+
381
+ if model_name and cleaned_model_id:
382
+ precomputation_dir = self.data_root / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
383
+ self.latents_path = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
384
+ self.conditions_path = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
385
+ else:
386
+ self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
387
+ self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
388
+
389
+ self.latent_conditions = sorted(os.listdir(self.latents_path))
390
+ self.text_conditions = sorted(os.listdir(self.conditions_path))
391
+
392
+ assert len(self.latent_conditions) == len(self.text_conditions), "Number of captions and videos do not match"
393
+
394
+ def __len__(self) -> int:
395
+ return len(self.latent_conditions)
396
+
397
+ def __getitem__(self, index: int) -> Dict[str, Any]:
398
+ conditions = {}
399
+ latent_path = self.latents_path / self.latent_conditions[index]
400
+ condition_path = self.conditions_path / self.text_conditions[index]
401
+ conditions["latent_conditions"] = torch.load(latent_path, map_location="cpu", weights_only=True)
402
+ conditions["text_conditions"] = torch.load(condition_path, map_location="cpu", weights_only=True)
403
+ return conditions
404
+
405
+
406
+ class BucketSampler(Sampler):
407
+ r"""
408
+ PyTorch Sampler that groups 3D data by height, width and frames.
409
+
410
+ Args:
411
+ data_source (`ImageOrVideoDataset`):
412
+ A PyTorch dataset object that is an instance of `ImageOrVideoDataset`.
413
+ batch_size (`int`, defaults to `8`):
414
+ The batch size to use for training.
415
+ shuffle (`bool`, defaults to `True`):
416
+ Whether or not to shuffle the data in each batch before dispatching to dataloader.
417
+ drop_last (`bool`, defaults to `False`):
418
+ Whether or not to drop incomplete buckets of data after completely iterating over all data
419
+ in the dataset. If set to True, only batches that have `batch_size` number of entries will
420
+ be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
421
+ and batches that do not have `batch_size` number of entries will also be yielded.
422
+ """
423
+
424
+ def __init__(
425
+ self, data_source: ImageOrVideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
426
+ ) -> None:
427
+ self.data_source = data_source
428
+ self.batch_size = batch_size
429
+ self.shuffle = shuffle
430
+ self.drop_last = drop_last
431
+
432
+ self.buckets = {resolution: [] for resolution in data_source.resolution_buckets}
433
+
434
+ self._raised_warning_for_drop_last = False
435
+
436
+ def __len__(self):
437
+ if self.drop_last and not self._raised_warning_for_drop_last:
438
+ self._raised_warning_for_drop_last = True
439
+ logger.warning(
440
+ "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
441
+ )
442
+ return (len(self.data_source) + self.batch_size - 1) // self.batch_size
443
+
444
+ def __iter__(self):
445
+ for index, data in enumerate(self.data_source):
446
+ video_metadata = data["video_metadata"]
447
+ f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
448
+
449
+ self.buckets[(f, h, w)].append(data)
450
+ if len(self.buckets[(f, h, w)]) == self.batch_size:
451
+ if self.shuffle:
452
+ random.shuffle(self.buckets[(f, h, w)])
453
+ yield self.buckets[(f, h, w)]
454
+ del self.buckets[(f, h, w)]
455
+ self.buckets[(f, h, w)] = []
456
+
457
+ if self.drop_last:
458
+ return
459
+
460
+ for fhw, bucket in list(self.buckets.items()):
461
+ if len(bucket) == 0:
462
+ continue
463
+ if self.shuffle:
464
+ random.shuffle(bucket)
465
+ yield bucket
466
+ del self.buckets[fhw]
467
+ self.buckets[fhw] = []
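A hedged usage sketch of how the dataset and `BucketSampler` above fit together (the paths, bucket values, and the pass-through `collate_fn` are placeholder assumptions, not taken from this commit). The sampler yields whole buckets of already-loaded samples, so the `DataLoader` runs with `batch_size=1` and simply unwraps the single bucket it receives:

```python
import torch
from torch.utils.data import DataLoader

from finetrainers.dataset import BucketSampler, ImageOrVideoDatasetWithResizing

dataset = ImageOrVideoDatasetWithResizing(
    data_root="/path/to/data",            # placeholder folder
    caption_column="prompts.txt",         # line-separated captions in data_root
    video_column="videos.txt",            # line-separated relative video paths
    resolution_buckets=[(49, 512, 768)],  # (frames, height, width)
)
sampler = BucketSampler(dataset, batch_size=2, shuffle=True)
loader = DataLoader(dataset, batch_size=1, sampler=sampler, collate_fn=lambda batch: batch[0])

for bucket in loader:
    # Each `bucket` is a list of dicts with "prompt", "video" and "video_metadata" keys.
    print(len(bucket), bucket[0]["video"].shape)
    break
```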
finetrainers/hooks/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .layerwise_upcasting import apply_layerwise_upcasting
finetrainers/hooks/hooks.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ from typing import Any, Dict, Optional, Tuple
17
+
18
+ import torch
19
+ from accelerate.logging import get_logger
20
+
21
+ from ..constants import FINETRAINERS_LOG_LEVEL
22
+
23
+
24
+ logger = get_logger("finetrainers") # pylint: disable=invalid-name
25
+ logger.setLevel(FINETRAINERS_LOG_LEVEL)
26
+
27
+
28
+ class ModelHook:
29
+ r"""
30
+ A hook that contains callbacks to be executed just before and after the forward method of a model.
31
+ """
32
+
33
+ _is_stateful = False
34
+
35
+ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
36
+ r"""
37
+ Hook that is executed when a model is initialized.
38
+ Args:
39
+ module (`torch.nn.Module`):
40
+ The module attached to this hook.
41
+ """
42
+ return module
43
+
44
+ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
45
+ r"""
46
+ Hook that is executed when a model is deinitialized.
47
+ Args:
48
+ module (`torch.nn.Module`):
49
+ The module attached to this hook.
50
+ """
51
+ module.forward = module._old_forward
52
+ del module._old_forward
53
+ return module
54
+
55
+ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
56
+ r"""
57
+ Hook that is executed just before the forward method of the model.
58
+ Args:
59
+ module (`torch.nn.Module`):
60
+ The module whose forward pass will be executed just after this event.
61
+ args (`Tuple[Any]`):
62
+ The positional arguments passed to the module.
63
+ kwargs (`Dict[Str, Any]`):
64
+ The keyword arguments passed to the module.
65
+ Returns:
66
+ `Tuple[Tuple[Any], Dict[Str, Any]]`:
67
+ A tuple with the treated `args` and `kwargs`.
68
+ """
69
+ return args, kwargs
70
+
71
+ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
72
+ r"""
73
+ Hook that is executed just after the forward method of the model.
74
+ Args:
75
+ module (`torch.nn.Module`):
76
+ The module whose forward pass been executed just before this event.
77
+ output (`Any`):
78
+ The output of the module.
79
+ Returns:
80
+ `Any`: The processed `output`.
81
+ """
82
+ return output
83
+
84
+ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
85
+ r"""
86
+ Hook that is executed when the hook is detached from a module.
87
+ Args:
88
+ module (`torch.nn.Module`):
89
+ The module detached from this hook.
90
+ """
91
+ return module
92
+
93
+ def reset_state(self, module: torch.nn.Module):
94
+ if self._is_stateful:
95
+ raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
96
+ return module
97
+
98
+
99
+ class HookRegistry:
100
+ def __init__(self, module_ref: torch.nn.Module) -> None:
101
+ super().__init__()
102
+
103
+ self.hooks: Dict[str, ModelHook] = {}
104
+
105
+ self._module_ref = module_ref
106
+ self._hook_order = []
107
+
108
+ def register_hook(self, hook: ModelHook, name: str) -> None:
109
+ if name in self.hooks.keys():
110
+ logger.warning(f"Hook with name {name} already exists, replacing it.")
111
+
112
+ if hasattr(self._module_ref, "_old_forward"):
113
+ old_forward = self._module_ref._old_forward
114
+ else:
115
+ old_forward = self._module_ref.forward
116
+ self._module_ref._old_forward = self._module_ref.forward
117
+
118
+ self._module_ref = hook.initialize_hook(self._module_ref)
119
+
120
+ if hasattr(hook, "new_forward"):
121
+ rewritten_forward = hook.new_forward
122
+
123
+ def new_forward(module, *args, **kwargs):
124
+ args, kwargs = hook.pre_forward(module, *args, **kwargs)
125
+ output = rewritten_forward(module, *args, **kwargs)
126
+ return hook.post_forward(module, output)
127
+ else:
128
+
129
+ def new_forward(module, *args, **kwargs):
130
+ args, kwargs = hook.pre_forward(module, *args, **kwargs)
131
+ output = old_forward(*args, **kwargs)
132
+ return hook.post_forward(module, output)
133
+
134
+ self._module_ref.forward = functools.update_wrapper(
135
+ functools.partial(new_forward, self._module_ref), old_forward
136
+ )
137
+
138
+ self.hooks[name] = hook
139
+ self._hook_order.append(name)
140
+
141
+ def get_hook(self, name: str) -> Optional[ModelHook]:
142
+ if name not in self.hooks.keys():
143
+ return None
144
+ return self.hooks[name]
145
+
146
+ def remove_hook(self, name: str) -> None:
147
+ if name not in self.hooks.keys():
148
+ raise ValueError(f"Hook with name {name} not found.")
149
+ self.hooks[name].deinitalize_hook(self._module_ref)
150
+ del self.hooks[name]
151
+ self._hook_order.remove(name)
152
+
153
+ def reset_stateful_hooks(self, recurse: bool = True) -> None:
154
+ for hook_name in self._hook_order:
155
+ hook = self.hooks[hook_name]
156
+ if hook._is_stateful:
157
+ hook.reset_state(self._module_ref)
158
+
159
+ if recurse:
160
+ for module in self._module_ref.modules():
161
+ if hasattr(module, "_diffusers_hook"):
162
+ module._diffusers_hook.reset_stateful_hooks(recurse=False)
163
+
164
+ @classmethod
165
+ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
166
+ if not hasattr(module, "_diffusers_hook"):
167
+ module._diffusers_hook = cls(module)
168
+ return module._diffusers_hook
169
+
170
+ def __repr__(self) -> str:
171
+ hook_repr = ""
172
+ for i, hook_name in enumerate(self._hook_order):
173
+ hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
174
+ if i < len(self._hook_order) - 1:
175
+ hook_repr += "\n"
176
+ return f"HookRegistry(\n{hook_repr}\n)"
finetrainers/hooks/layerwise_upcasting.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from typing import Optional, Tuple, Type
17
+
18
+ import torch
19
+ from accelerate.logging import get_logger
20
+
21
+ from ..constants import FINETRAINERS_LOG_LEVEL
22
+ from .hooks import HookRegistry, ModelHook
23
+
24
+
25
+ logger = get_logger("finetrainers") # pylint: disable=invalid-name
26
+ logger.setLevel(FINETRAINERS_LOG_LEVEL)
27
+
28
+
29
+ # fmt: off
30
+ _SUPPORTED_PYTORCH_LAYERS = (
31
+ torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
32
+ torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
33
+ torch.nn.Linear,
34
+ )
35
+
36
+ _DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm")
37
+ # fmt: on
38
+
39
+
40
+ class LayerwiseUpcastingHook(ModelHook):
41
+ r"""
42
+ A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
43
+ for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
44
+ footprint.
45
+ """
46
+
47
+ _is_stateful = False
48
+
49
+ def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
50
+ self.storage_dtype = storage_dtype
51
+ self.compute_dtype = compute_dtype
52
+ self.non_blocking = non_blocking
53
+
54
+ def initialize_hook(self, module: torch.nn.Module):
55
+ module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
56
+ return module
57
+
58
+ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
59
+ module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
60
+ return args, kwargs
61
+
62
+ def post_forward(self, module: torch.nn.Module, output):
63
+ module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
64
+ return output
65
+
66
+
67
+ def apply_layerwise_upcasting(
68
+ module: torch.nn.Module,
69
+ storage_dtype: torch.dtype,
70
+ compute_dtype: torch.dtype,
71
+ skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN,
72
+ skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None,
73
+ non_blocking: bool = False,
74
+ _prefix: str = "",
75
+ ) -> None:
76
+ r"""
77
+ Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
78
+ nn.Module using diffusers layers or pytorch primitives.
79
+ Args:
80
+ module (`torch.nn.Module`):
81
+ The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
82
+ precision dtype for storage.
83
+ storage_dtype (`torch.dtype`):
84
+ The dtype to cast the module to before/after the forward pass for storage.
85
+ compute_dtype (`torch.dtype`):
86
+ The dtype to cast the module to during the forward pass for computation.
87
+ skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
88
+ A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
89
+ skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`):
90
+ A list of module classes to skip during the layerwise upcasting process.
91
+ non_blocking (`bool`, defaults to `False`):
92
+ If `True`, the weight casting operations are non-blocking.
93
+ """
94
+ if skip_modules_classes is None and skip_modules_pattern is None:
95
+ apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
96
+ return
97
+
98
+ should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
99
+ skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
100
+ )
101
+ if should_skip:
102
+ logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
103
+ return
104
+
105
+ if isinstance(module, _SUPPORTED_PYTORCH_LAYERS):
106
+ logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
107
+ apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
108
+ return
109
+
110
+ for name, submodule in module.named_children():
111
+ layer_name = f"{_prefix}.{name}" if _prefix else name
112
+ apply_layerwise_upcasting(
113
+ submodule,
114
+ storage_dtype,
115
+ compute_dtype,
116
+ skip_modules_pattern,
117
+ skip_modules_classes,
118
+ non_blocking,
119
+ _prefix=layer_name,
120
+ )
121
+
122
+
123
+ def apply_layerwise_upcasting_hook(
124
+ module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
125
+ ) -> None:
126
+ r"""
127
+ Applies a `LayerwiseUpcastingHook` to a given module.
128
+ Args:
129
+ module (`torch.nn.Module`):
130
+ The module to attach the hook to.
131
+ storage_dtype (`torch.dtype`):
132
+ The dtype to cast the module to before the forward pass.
133
+ compute_dtype (`torch.dtype`):
134
+ The dtype to cast the module to during the forward pass.
135
+ non_blocking (`bool`):
136
+ If `True`, the weight casting operations are non-blocking.
137
+ """
138
+ registry = HookRegistry.check_if_exists_or_initialize(module)
139
+ hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking)
140
+ registry.register_hook(hook, "layerwise_upcasting")
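
For context, a minimal usage sketch of the hook machinery defined above (a sketch only: the toy module, dtypes and skip pattern are illustrative, and it assumes a PyTorch build that ships the float8 dtypes):

import torch

from finetrainers.hooks import apply_layerwise_upcasting

# Toy stand-in for a transformer block; in practice this is applied to the denoiser.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
)

# Weights live in float8 (e4m3) between calls and are upcast to bfloat16 around each forward.
apply_layerwise_upcasting(
    model,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=("norm",),  # leave norm-like layers untouched
)

with torch.no_grad():
    out = model(torch.randn(1, 64, dtype=torch.bfloat16))

As the docstring notes, the storage/compute split trades a small amount of output quality for a large reduction in resident weight memory.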
finetrainers/models/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ from typing import Any, Dict
2
+
3
+ from .cogvideox import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG, COGVIDEOX_T2V_LORA_CONFIG
4
+ from .hunyuan_video import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, HUNYUAN_VIDEO_T2V_LORA_CONFIG
5
+ from .ltx_video import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG, LTX_VIDEO_T2V_LORA_CONFIG
6
+
7
+
8
+ SUPPORTED_MODEL_CONFIGS = {
9
+ "hunyuan_video": {
10
+ "lora": HUNYUAN_VIDEO_T2V_LORA_CONFIG,
11
+ "full-finetune": HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG,
12
+ },
13
+ "ltx_video": {
14
+ "lora": LTX_VIDEO_T2V_LORA_CONFIG,
15
+ "full-finetune": LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG,
16
+ },
17
+ "cogvideox": {
18
+ "lora": COGVIDEOX_T2V_LORA_CONFIG,
19
+ "full-finetune": COGVIDEOX_T2V_FULL_FINETUNE_CONFIG,
20
+ },
21
+ }
22
+
23
+
24
+ def get_config_from_model_name(model_name: str, training_type: str) -> Dict[str, Any]:
25
+ if model_name not in SUPPORTED_MODEL_CONFIGS:
26
+ raise ValueError(
27
+ f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}"
28
+ )
29
+ if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]:
30
+ raise ValueError(
31
+ f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}"
32
+ )
33
+ return SUPPORTED_MODEL_CONFIGS[model_name][training_type]
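
A quick sketch of how this registry is meant to be consumed (the values are exactly the config dicts defined in the model subpackages):

from finetrainers.models import get_config_from_model_name

config = get_config_from_model_name("ltx_video", training_type="lora")
pipeline_cls = config["pipeline_cls"]        # LTXPipeline
prepare_latents = config["prepare_latents"]  # callable used for (pre)computing latents

# Unknown model names or training types raise a ValueError listing the supported options:
# get_config_from_model_name("ltx_video", training_type="dreambooth")  # -> ValueError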
finetrainers/models/cogvideox/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .full_finetune import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG
2
+ from .lora import COGVIDEOX_T2V_LORA_CONFIG
finetrainers/models/cogvideox/full_finetune.py ADDED
@@ -0,0 +1,32 @@
1
+ from diffusers import CogVideoXPipeline
2
+
3
+ from .lora import (
4
+ calculate_noisy_latents,
5
+ collate_fn_t2v,
6
+ forward_pass,
7
+ initialize_pipeline,
8
+ load_condition_models,
9
+ load_diffusion_models,
10
+ load_latent_models,
11
+ post_latent_preparation,
12
+ prepare_conditions,
13
+ prepare_latents,
14
+ validation,
15
+ )
16
+
17
+
18
+ # TODO(aryan): refactor into model specs for better re-use
19
+ COGVIDEOX_T2V_FULL_FINETUNE_CONFIG = {
20
+ "pipeline_cls": CogVideoXPipeline,
21
+ "load_condition_models": load_condition_models,
22
+ "load_latent_models": load_latent_models,
23
+ "load_diffusion_models": load_diffusion_models,
24
+ "initialize_pipeline": initialize_pipeline,
25
+ "prepare_conditions": prepare_conditions,
26
+ "prepare_latents": prepare_latents,
27
+ "post_latent_preparation": post_latent_preparation,
28
+ "collate_fn": collate_fn_t2v,
29
+ "calculate_noisy_latents": calculate_noisy_latents,
30
+ "forward_pass": forward_pass,
31
+ "validation": validation,
32
+ }
finetrainers/models/cogvideox/lora.py ADDED
@@ -0,0 +1,334 @@
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ import torch
4
+ from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
5
+ from PIL import Image
6
+ from transformers import T5EncoderModel, T5Tokenizer
7
+
8
+ from .utils import prepare_rotary_positional_embeddings
9
+
10
+
11
+ def load_condition_models(
12
+ model_id: str = "THUDM/CogVideoX-5b",
13
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
14
+ revision: Optional[str] = None,
15
+ cache_dir: Optional[str] = None,
16
+ **kwargs,
17
+ ):
18
+ tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
19
+ text_encoder = T5EncoderModel.from_pretrained(
20
+ model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
21
+ )
22
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
23
+
24
+
25
+ def load_latent_models(
26
+ model_id: str = "THUDM/CogVideoX-5b",
27
+ vae_dtype: torch.dtype = torch.bfloat16,
28
+ revision: Optional[str] = None,
29
+ cache_dir: Optional[str] = None,
30
+ **kwargs,
31
+ ):
32
+ vae = AutoencoderKLCogVideoX.from_pretrained(
33
+ model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
34
+ )
35
+ return {"vae": vae}
36
+
37
+
38
+ def load_diffusion_models(
39
+ model_id: str = "THUDM/CogVideoX-5b",
40
+ transformer_dtype: torch.dtype = torch.bfloat16,
41
+ revision: Optional[str] = None,
42
+ cache_dir: Optional[str] = None,
43
+ **kwargs,
44
+ ):
45
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
46
+ model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
47
+ )
48
+ scheduler = CogVideoXDDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
49
+ return {"transformer": transformer, "scheduler": scheduler}
50
+
51
+
52
+ def initialize_pipeline(
53
+ model_id: str = "THUDM/CogVideoX-5b",
54
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
55
+ transformer_dtype: torch.dtype = torch.bfloat16,
56
+ vae_dtype: torch.dtype = torch.bfloat16,
57
+ tokenizer: Optional[T5Tokenizer] = None,
58
+ text_encoder: Optional[T5EncoderModel] = None,
59
+ transformer: Optional[CogVideoXTransformer3DModel] = None,
60
+ vae: Optional[AutoencoderKLCogVideoX] = None,
61
+ scheduler: Optional[CogVideoXDDIMScheduler] = None,
62
+ device: Optional[torch.device] = None,
63
+ revision: Optional[str] = None,
64
+ cache_dir: Optional[str] = None,
65
+ enable_slicing: bool = False,
66
+ enable_tiling: bool = False,
67
+ enable_model_cpu_offload: bool = False,
68
+ is_training: bool = False,
69
+ **kwargs,
70
+ ) -> CogVideoXPipeline:
71
+ component_name_pairs = [
72
+ ("tokenizer", tokenizer),
73
+ ("text_encoder", text_encoder),
74
+ ("transformer", transformer),
75
+ ("vae", vae),
76
+ ("scheduler", scheduler),
77
+ ]
78
+ components = {}
79
+ for name, component in component_name_pairs:
80
+ if component is not None:
81
+ components[name] = component
82
+
83
+ pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
84
+ pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
85
+ pipe.vae = pipe.vae.to(dtype=vae_dtype)
86
+
87
+ # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
88
+ # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
89
+ # DDP optimizer step.
90
+ if not is_training:
91
+ pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
92
+
93
+ if enable_slicing:
94
+ pipe.vae.enable_slicing()
95
+ if enable_tiling:
96
+ pipe.vae.enable_tiling()
97
+
98
+ if enable_model_cpu_offload:
99
+ pipe.enable_model_cpu_offload(device=device)
100
+ else:
101
+ pipe.to(device=device)
102
+
103
+ return pipe
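
As a rough inference-time illustration of the function above (a sketch: it downloads the full CogVideoX-5b weights, assumes a CUDA device, and passes no pre-loaded components, so everything comes from `from_pretrained`; the prompt and frame count are arbitrary):

import torch

from finetrainers.models.cogvideox.lora import initialize_pipeline

pipe = initialize_pipeline(
    model_id="THUDM/CogVideoX-5b",
    device=torch.device("cuda"),
    enable_slicing=True,   # slice VAE batches to reduce peak memory
    enable_tiling=True,    # tile VAE decoding spatially
    is_training=False,     # cast the transformer to transformer_dtype for inference
)
video = pipe(prompt="a red panda drinking tea", num_frames=49).frames[0]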
104
+
105
+
106
+ def prepare_conditions(
107
+ tokenizer,
108
+ text_encoder,
109
+ prompt: Union[str, List[str]],
110
+ device: Optional[torch.device] = None,
111
+ dtype: Optional[torch.dtype] = None,
112
+ max_sequence_length: int = 226, # TODO: this should be configurable
113
+ **kwargs,
114
+ ):
115
+ device = device or text_encoder.device
116
+ dtype = dtype or text_encoder.dtype
117
+ return _get_t5_prompt_embeds(
118
+ tokenizer=tokenizer,
119
+ text_encoder=text_encoder,
120
+ prompt=prompt,
121
+ max_sequence_length=max_sequence_length,
122
+ device=device,
123
+ dtype=dtype,
124
+ )
125
+
126
+
127
+ def prepare_latents(
128
+ vae: AutoencoderKLCogVideoX,
129
+ image_or_video: torch.Tensor,
130
+ device: Optional[torch.device] = None,
131
+ dtype: Optional[torch.dtype] = None,
132
+ generator: Optional[torch.Generator] = None,
133
+ precompute: bool = False,
134
+ **kwargs,
135
+ ) -> torch.Tensor:
136
+ device = device or vae.device
137
+ dtype = dtype or vae.dtype
138
+
139
+ if image_or_video.ndim == 4:
140
+ image_or_video = image_or_video.unsqueeze(2)
141
+ assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
142
+
143
+ image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
144
+ image_or_video = image_or_video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
145
+ if not precompute:
146
+ latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
147
+ if not vae.config.invert_scale_latents:
148
+ latents = latents * vae.config.scaling_factor
149
+ # For training Cog 1.5, we don't need to handle the scaling factor here.
150
+ # The CogVideoX team forgot to multiply here, so we skip it as well. Inverting the scale latents
151
+ # is probably only needed for image-to-video training.
152
+ # TODO(aryan): investigate this
153
+ # else:
154
+ # latents = 1 / vae.config.scaling_factor * latents
155
+ latents = latents.to(dtype=dtype)
156
+ return {"latents": latents}
157
+ else:
158
+ # handle vae scaling in the `train()` method directly.
159
+ if vae.use_slicing and image_or_video.shape[0] > 1:
160
+ encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
161
+ h = torch.cat(encoded_slices)
162
+ else:
163
+ h = vae._encode(image_or_video)
164
+ return {"latents": h}
165
+
166
+
167
+ def post_latent_preparation(
168
+ vae_config: Dict[str, Any], latents: torch.Tensor, patch_size_t: Optional[int] = None, **kwargs
169
+ ) -> torch.Tensor:
170
+ if not vae_config.invert_scale_latents:
171
+ latents = latents * vae_config.scaling_factor
172
+ # For training Cog 1.5, we don't need to handle the scaling factor here.
173
+ # The CogVideoX team forgot to multiply here, so we skip it as well. Inverting the scale latents
174
+ # is probably only needed for image-to-video training.
175
+ # TODO(aryan): investigate this
176
+ # else:
177
+ # latents = 1 / vae_config.scaling_factor * latents
178
+ latents = _pad_frames(latents, patch_size_t)
179
+ latents = latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
180
+ return {"latents": latents}
181
+
182
+
183
+ def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
184
+ return {
185
+ "prompts": [x["prompt"] for x in batch[0]],
186
+ "videos": torch.stack([x["video"] for x in batch[0]]),
187
+ }
188
+
189
+
190
+ def calculate_noisy_latents(
191
+ scheduler: CogVideoXDDIMScheduler,
192
+ noise: torch.Tensor,
193
+ latents: torch.Tensor,
194
+ timesteps: torch.LongTensor,
195
+ ) -> torch.Tensor:
196
+ noisy_latents = scheduler.add_noise(latents, noise, timesteps)
197
+ return noisy_latents
198
+
199
+
200
+ def forward_pass(
201
+ transformer: CogVideoXTransformer3DModel,
202
+ scheduler: CogVideoXDDIMScheduler,
203
+ prompt_embeds: torch.Tensor,
204
+ latents: torch.Tensor,
205
+ noisy_latents: torch.Tensor,
206
+ timesteps: torch.LongTensor,
207
+ ofs_emb: Optional[torch.Tensor] = None,
208
+ **kwargs,
209
+ ) -> torch.Tensor:
210
+ # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself.
211
+ VAE_SPATIAL_SCALE_FACTOR = 8
212
+ transformer_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
213
+ batch_size, num_frames, num_channels, height, width = noisy_latents.shape
214
+ rope_base_height = transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR
215
+ rope_base_width = transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR
216
+
217
+ image_rotary_emb = (
218
+ prepare_rotary_positional_embeddings(
219
+ height=height * VAE_SPATIAL_SCALE_FACTOR,
220
+ width=width * VAE_SPATIAL_SCALE_FACTOR,
221
+ num_frames=num_frames,
222
+ vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR,
223
+ patch_size=transformer_config.patch_size,
224
+ patch_size_t=transformer_config.patch_size_t if hasattr(transformer_config, "patch_size_t") else None,
225
+ attention_head_dim=transformer_config.attention_head_dim,
226
+ device=transformer.device,
227
+ base_height=rope_base_height,
228
+ base_width=rope_base_width,
229
+ )
230
+ if transformer_config.use_rotary_positional_embeddings
231
+ else None
232
+ )
233
+ ofs_emb = None if transformer_config.ofs_embed_dim is None else latents.new_full((batch_size,), fill_value=2.0)
234
+
235
+ velocity = transformer(
236
+ hidden_states=noisy_latents,
237
+ timestep=timesteps,
238
+ encoder_hidden_states=prompt_embeds,
239
+ ofs=ofs_emb,
240
+ image_rotary_emb=image_rotary_emb,
241
+ return_dict=False,
242
+ )[0]
243
+ # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same
244
+ # code paths as scheduler.get_velocity(), which can be confusing at first glance.
245
+ denoised_latents = scheduler.get_velocity(velocity, noisy_latents, timesteps)
246
+
247
+ return {"latents": denoised_latents}
248
+
249
+
250
+ def validation(
251
+ pipeline: CogVideoXPipeline,
252
+ prompt: str,
253
+ image: Optional[Image.Image] = None,
254
+ video: Optional[List[Image.Image]] = None,
255
+ height: Optional[int] = None,
256
+ width: Optional[int] = None,
257
+ num_frames: Optional[int] = None,
258
+ num_videos_per_prompt: int = 1,
259
+ generator: Optional[torch.Generator] = None,
260
+ **kwargs,
261
+ ):
262
+ generation_kwargs = {
263
+ "prompt": prompt,
264
+ "height": height,
265
+ "width": width,
266
+ "num_frames": num_frames,
267
+ "num_videos_per_prompt": num_videos_per_prompt,
268
+ "generator": generator,
269
+ "return_dict": True,
270
+ "output_type": "pil",
271
+ }
272
+ generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
273
+ output = pipeline(**generation_kwargs).frames[0]
274
+ return [("video", output)]
275
+
276
+
277
+ def _get_t5_prompt_embeds(
278
+ tokenizer: T5Tokenizer,
279
+ text_encoder: T5EncoderModel,
280
+ prompt: Union[str, List[str]] = None,
281
+ max_sequence_length: int = 226,
282
+ device: Optional[torch.device] = None,
283
+ dtype: Optional[torch.dtype] = None,
284
+ ):
285
+ prompt = [prompt] if isinstance(prompt, str) else prompt
286
+
287
+ text_inputs = tokenizer(
288
+ prompt,
289
+ padding="max_length",
290
+ max_length=max_sequence_length,
291
+ truncation=True,
292
+ add_special_tokens=True,
293
+ return_tensors="pt",
294
+ )
295
+ text_input_ids = text_inputs.input_ids
296
+
297
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
298
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
299
+
300
+ return {"prompt_embeds": prompt_embeds}
301
+
302
+
303
+ def _pad_frames(latents: torch.Tensor, patch_size_t: int):
304
+ if patch_size_t is None or patch_size_t == 1:
305
+ return latents
306
+
307
+ # `latents` should be of the following format: [B, C, F, H, W].
308
+ # For CogVideoX 1.5, the latent frames should be padded so that their count is divisible by patch_size_t.
309
+ latent_num_frames = latents.shape[2]
310
+ additional_frames = patch_size_t - latent_num_frames % patch_size_t
311
+
312
+ if additional_frames > 0:
313
+ last_frame = latents[:, :, -1:, :, :]
314
+ padding_frames = last_frame.repeat(1, 1, additional_frames, 1, 1)
315
+ latents = torch.cat([latents, padding_frames], dim=2)
316
+
317
+ return latents
318
+
319
+
320
+ # TODO(aryan): refactor into model specs for better re-use
321
+ COGVIDEOX_T2V_LORA_CONFIG = {
322
+ "pipeline_cls": CogVideoXPipeline,
323
+ "load_condition_models": load_condition_models,
324
+ "load_latent_models": load_latent_models,
325
+ "load_diffusion_models": load_diffusion_models,
326
+ "initialize_pipeline": initialize_pipeline,
327
+ "prepare_conditions": prepare_conditions,
328
+ "prepare_latents": prepare_latents,
329
+ "post_latent_preparation": post_latent_preparation,
330
+ "collate_fn": collate_fn_t2v,
331
+ "calculate_noisy_latents": calculate_noisy_latents,
332
+ "forward_pass": forward_pass,
333
+ "validation": validation,
334
+ }
finetrainers/models/cogvideox/utils.py ADDED
@@ -0,0 +1,51 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
5
+ from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
6
+
7
+
8
+ def prepare_rotary_positional_embeddings(
9
+ height: int,
10
+ width: int,
11
+ num_frames: int,
12
+ vae_scale_factor_spatial: int = 8,
13
+ patch_size: int = 2,
14
+ patch_size_t: int = None,
15
+ attention_head_dim: int = 64,
16
+ device: Optional[torch.device] = None,
17
+ base_height: int = 480,
18
+ base_width: int = 720,
19
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
20
+ grid_height = height // (vae_scale_factor_spatial * patch_size)
21
+ grid_width = width // (vae_scale_factor_spatial * patch_size)
22
+ base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
23
+ base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
24
+
25
+ if patch_size_t is None:
26
+ # CogVideoX 1.0
27
+ grid_crops_coords = get_resize_crop_region_for_grid(
28
+ (grid_height, grid_width), base_size_width, base_size_height
29
+ )
30
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
31
+ embed_dim=attention_head_dim,
32
+ crops_coords=grid_crops_coords,
33
+ grid_size=(grid_height, grid_width),
34
+ temporal_size=num_frames,
35
+ )
36
+ else:
37
+ # CogVideoX 1.5
38
+ base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
39
+
40
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
41
+ embed_dim=attention_head_dim,
42
+ crops_coords=None,
43
+ grid_size=(grid_height, grid_width),
44
+ temporal_size=base_num_frames,
45
+ grid_type="slice",
46
+ max_size=(base_size_height, base_size_width),
47
+ )
48
+
49
+ freqs_cos = freqs_cos.to(device=device)
50
+ freqs_sin = freqs_sin.to(device=device)
51
+ return freqs_cos, freqs_sin
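
For illustration, the CogVideoX 1.0 branch above can be exercised on CPU with the model's default geometry (a sketch; the printed shape is what get_3d_rotary_pos_embed is expected to return, one row per latent token and attention_head_dim columns):

import torch

from finetrainers.models.cogvideox.utils import prepare_rotary_positional_embeddings

# 480x720 pixels, 13 latent frames; patch_size_t=None selects the CogVideoX 1.0 path.
freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
    height=480,
    width=720,
    num_frames=13,
    attention_head_dim=64,
    device=torch.device("cpu"),
)
print(freqs_cos.shape)  # expected: torch.Size([13 * 30 * 45, 64])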
finetrainers/models/hunyuan_video/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .full_finetune import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG
2
+ from .lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG
finetrainers/models/hunyuan_video/full_finetune.py ADDED
@@ -0,0 +1,30 @@
1
+ from diffusers import HunyuanVideoPipeline
2
+
3
+ from .lora import (
4
+ collate_fn_t2v,
5
+ forward_pass,
6
+ initialize_pipeline,
7
+ load_condition_models,
8
+ load_diffusion_models,
9
+ load_latent_models,
10
+ post_latent_preparation,
11
+ prepare_conditions,
12
+ prepare_latents,
13
+ validation,
14
+ )
15
+
16
+
17
+ # TODO(aryan): refactor into model specs for better re-use
18
+ HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG = {
19
+ "pipeline_cls": HunyuanVideoPipeline,
20
+ "load_condition_models": load_condition_models,
21
+ "load_latent_models": load_latent_models,
22
+ "load_diffusion_models": load_diffusion_models,
23
+ "initialize_pipeline": initialize_pipeline,
24
+ "prepare_conditions": prepare_conditions,
25
+ "prepare_latents": prepare_latents,
26
+ "post_latent_preparation": post_latent_preparation,
27
+ "collate_fn": collate_fn_t2v,
28
+ "forward_pass": forward_pass,
29
+ "validation": validation,
30
+ }
finetrainers/models/hunyuan_video/lora.py ADDED
@@ -0,0 +1,368 @@
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from accelerate.logging import get_logger
6
+ from diffusers import (
7
+ AutoencoderKLHunyuanVideo,
8
+ FlowMatchEulerDiscreteScheduler,
9
+ HunyuanVideoPipeline,
10
+ HunyuanVideoTransformer3DModel,
11
+ )
12
+ from PIL import Image
13
+ from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizer
14
+
15
+
16
+ logger = get_logger("finetrainers") # pylint: disable=invalid-name
17
+
18
+
19
+ def load_condition_models(
20
+ model_id: str = "hunyuanvideo-community/HunyuanVideo",
21
+ text_encoder_dtype: torch.dtype = torch.float16,
22
+ text_encoder_2_dtype: torch.dtype = torch.float16,
23
+ revision: Optional[str] = None,
24
+ cache_dir: Optional[str] = None,
25
+ **kwargs,
26
+ ) -> Dict[str, nn.Module]:
27
+ tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
28
+ text_encoder = LlamaModel.from_pretrained(
29
+ model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
30
+ )
31
+ tokenizer_2 = CLIPTokenizer.from_pretrained(
32
+ model_id, subfolder="tokenizer_2", revision=revision, cache_dir=cache_dir
33
+ )
34
+ text_encoder_2 = CLIPTextModel.from_pretrained(
35
+ model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir
36
+ )
37
+ return {
38
+ "tokenizer": tokenizer,
39
+ "text_encoder": text_encoder,
40
+ "tokenizer_2": tokenizer_2,
41
+ "text_encoder_2": text_encoder_2,
42
+ }
43
+
44
+
45
+ def load_latent_models(
46
+ model_id: str = "hunyuanvideo-community/HunyuanVideo",
47
+ vae_dtype: torch.dtype = torch.float16,
48
+ revision: Optional[str] = None,
49
+ cache_dir: Optional[str] = None,
50
+ **kwargs,
51
+ ) -> Dict[str, nn.Module]:
52
+ vae = AutoencoderKLHunyuanVideo.from_pretrained(
53
+ model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
54
+ )
55
+ return {"vae": vae}
56
+
57
+
58
+ def load_diffusion_models(
59
+ model_id: str = "hunyuanvideo-community/HunyuanVideo",
60
+ transformer_dtype: torch.dtype = torch.bfloat16,
61
+ shift: float = 1.0,
62
+ revision: Optional[str] = None,
63
+ cache_dir: Optional[str] = None,
64
+ **kwargs,
65
+ ) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]:
66
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(
67
+ model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
68
+ )
69
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
70
+ return {"transformer": transformer, "scheduler": scheduler}
71
+
72
+
73
+ def initialize_pipeline(
74
+ model_id: str = "hunyuanvideo-community/HunyuanVideo",
75
+ text_encoder_dtype: torch.dtype = torch.float16,
76
+ text_encoder_2_dtype: torch.dtype = torch.float16,
77
+ transformer_dtype: torch.dtype = torch.bfloat16,
78
+ vae_dtype: torch.dtype = torch.float16,
79
+ tokenizer: Optional[LlamaTokenizer] = None,
80
+ text_encoder: Optional[LlamaModel] = None,
81
+ tokenizer_2: Optional[CLIPTokenizer] = None,
82
+ text_encoder_2: Optional[CLIPTextModel] = None,
83
+ transformer: Optional[HunyuanVideoTransformer3DModel] = None,
84
+ vae: Optional[AutoencoderKLHunyuanVideo] = None,
85
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
86
+ device: Optional[torch.device] = None,
87
+ revision: Optional[str] = None,
88
+ cache_dir: Optional[str] = None,
89
+ enable_slicing: bool = False,
90
+ enable_tiling: bool = False,
91
+ enable_model_cpu_offload: bool = False,
92
+ is_training: bool = False,
93
+ **kwargs,
94
+ ) -> HunyuanVideoPipeline:
95
+ component_name_pairs = [
96
+ ("tokenizer", tokenizer),
97
+ ("text_encoder", text_encoder),
98
+ ("tokenizer_2", tokenizer_2),
99
+ ("text_encoder_2", text_encoder_2),
100
+ ("transformer", transformer),
101
+ ("vae", vae),
102
+ ("scheduler", scheduler),
103
+ ]
104
+ components = {}
105
+ for name, component in component_name_pairs:
106
+ if component is not None:
107
+ components[name] = component
108
+
109
+ pipe = HunyuanVideoPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
110
+ pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
111
+ pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=text_encoder_2_dtype)
112
+ pipe.vae = pipe.vae.to(dtype=vae_dtype)
113
+
114
+ # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
115
+ # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
116
+ # DDP optimizer step.
117
+ if not is_training:
118
+ pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
119
+
120
+ if enable_slicing:
121
+ pipe.vae.enable_slicing()
122
+ if enable_tiling:
123
+ pipe.vae.enable_tiling()
124
+
125
+ if enable_model_cpu_offload:
126
+ pipe.enable_model_cpu_offload(device=device)
127
+ else:
128
+ pipe.to(device=device)
129
+
130
+ return pipe
131
+
132
+
133
+ def prepare_conditions(
134
+ tokenizer: LlamaTokenizer,
135
+ text_encoder: LlamaModel,
136
+ tokenizer_2: CLIPTokenizer,
137
+ text_encoder_2: CLIPTextModel,
138
+ prompt: Union[str, List[str]],
139
+ guidance: float = 1.0,
140
+ device: Optional[torch.device] = None,
141
+ dtype: Optional[torch.dtype] = None,
142
+ max_sequence_length: int = 256,
143
+ # TODO(aryan): make configurable
144
+ prompt_template: Dict[str, Any] = {
145
+ "template": (
146
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
147
+ "1. The main content and theme of the video."
148
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
149
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
150
+ "4. background environment, light, style and atmosphere."
151
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
152
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
153
+ ),
154
+ "crop_start": 95,
155
+ },
156
+ **kwargs,
157
+ ) -> torch.Tensor:
158
+ device = device or text_encoder.device
159
+ dtype = dtype or text_encoder.dtype
160
+
161
+ if isinstance(prompt, str):
162
+ prompt = [prompt]
163
+
164
+ conditions = {}
165
+ conditions.update(
166
+ _get_llama_prompt_embeds(tokenizer, text_encoder, prompt, prompt_template, device, dtype, max_sequence_length)
167
+ )
168
+ conditions.update(_get_clip_prompt_embeds(tokenizer_2, text_encoder_2, prompt, device, dtype))
169
+
170
+ guidance = torch.tensor([guidance], device=device, dtype=dtype) * 1000.0
171
+ conditions["guidance"] = guidance
172
+
173
+ return conditions
174
+
175
+
176
+ def prepare_latents(
177
+ vae: AutoencoderKLHunyuanVideo,
178
+ image_or_video: torch.Tensor,
179
+ device: Optional[torch.device] = None,
180
+ dtype: Optional[torch.dtype] = None,
181
+ generator: Optional[torch.Generator] = None,
182
+ precompute: bool = False,
183
+ **kwargs,
184
+ ) -> torch.Tensor:
185
+ device = device or vae.device
186
+ dtype = dtype or vae.dtype
187
+
188
+ if image_or_video.ndim == 4:
189
+ image_or_video = image_or_video.unsqueeze(2)
190
+ assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
191
+
192
+ image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
193
+ image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
194
+ if not precompute:
195
+ latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
196
+ latents = latents * vae.config.scaling_factor
197
+ latents = latents.to(dtype=dtype)
198
+ return {"latents": latents}
199
+ else:
200
+ if vae.use_slicing and image_or_video.shape[0] > 1:
201
+ encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
202
+ h = torch.cat(encoded_slices)
203
+ else:
204
+ h = vae._encode(image_or_video)
205
+ return {"latents": h}
206
+
207
+
208
+ def post_latent_preparation(
209
+ vae_config: Dict[str, Any],
210
+ latents: torch.Tensor,
211
+ **kwargs,
212
+ ) -> torch.Tensor:
213
+ latents = latents * vae_config.scaling_factor
214
+ return {"latents": latents}
215
+
216
+
217
+ def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
218
+ return {
219
+ "prompts": [x["prompt"] for x in batch[0]],
220
+ "videos": torch.stack([x["video"] for x in batch[0]]),
221
+ }
222
+
223
+
224
+ def forward_pass(
225
+ transformer: HunyuanVideoTransformer3DModel,
226
+ prompt_embeds: torch.Tensor,
227
+ pooled_prompt_embeds: torch.Tensor,
228
+ prompt_attention_mask: torch.Tensor,
229
+ guidance: torch.Tensor,
230
+ latents: torch.Tensor,
231
+ noisy_latents: torch.Tensor,
232
+ timesteps: torch.LongTensor,
233
+ **kwargs,
234
+ ) -> torch.Tensor:
235
+ denoised_latents = transformer(
236
+ hidden_states=noisy_latents,
237
+ timestep=timesteps,
238
+ encoder_hidden_states=prompt_embeds,
239
+ pooled_projections=pooled_prompt_embeds,
240
+ encoder_attention_mask=prompt_attention_mask,
241
+ guidance=guidance,
242
+ return_dict=False,
243
+ )[0]
244
+
245
+ return {"latents": denoised_latents}
246
+
247
+
248
+ def validation(
249
+ pipeline: HunyuanVideoPipeline,
250
+ prompt: str,
251
+ image: Optional[Image.Image] = None,
252
+ video: Optional[List[Image.Image]] = None,
253
+ height: Optional[int] = None,
254
+ width: Optional[int] = None,
255
+ num_frames: Optional[int] = None,
256
+ num_videos_per_prompt: int = 1,
257
+ generator: Optional[torch.Generator] = None,
258
+ **kwargs,
259
+ ):
260
+ generation_kwargs = {
261
+ "prompt": prompt,
262
+ "height": height,
263
+ "width": width,
264
+ "num_frames": num_frames,
265
+ "num_inference_steps": 30,
266
+ "num_videos_per_prompt": num_videos_per_prompt,
267
+ "generator": generator,
268
+ "return_dict": True,
269
+ "output_type": "pil",
270
+ }
271
+ generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
272
+ output = pipeline(**generation_kwargs).frames[0]
273
+ return [("video", output)]
274
+
275
+
276
+ def _get_llama_prompt_embeds(
277
+ tokenizer: LlamaTokenizer,
278
+ text_encoder: LlamaModel,
279
+ prompt: List[str],
280
+ prompt_template: Dict[str, Any],
281
+ device: torch.device,
282
+ dtype: torch.dtype,
283
+ max_sequence_length: int = 256,
284
+ num_hidden_layers_to_skip: int = 2,
285
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
286
+ batch_size = len(prompt)
287
+ prompt = [prompt_template["template"].format(p) for p in prompt]
288
+
289
+ crop_start = prompt_template.get("crop_start", None)
290
+ if crop_start is None:
291
+ prompt_template_input = tokenizer(
292
+ prompt_template["template"],
293
+ padding="max_length",
294
+ return_tensors="pt",
295
+ return_length=False,
296
+ return_overflowing_tokens=False,
297
+ return_attention_mask=False,
298
+ )
299
+ crop_start = prompt_template_input["input_ids"].shape[-1]
300
+ # Remove <|eot_id|> token and placeholder {}
301
+ crop_start -= 2
302
+
303
+ max_sequence_length += crop_start
304
+ text_inputs = tokenizer(
305
+ prompt,
306
+ max_length=max_sequence_length,
307
+ padding="max_length",
308
+ truncation=True,
309
+ return_tensors="pt",
310
+ return_length=False,
311
+ return_overflowing_tokens=False,
312
+ return_attention_mask=True,
313
+ )
314
+ text_input_ids = text_inputs.input_ids.to(device=device)
315
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device)
316
+
317
+ prompt_embeds = text_encoder(
318
+ input_ids=text_input_ids,
319
+ attention_mask=prompt_attention_mask,
320
+ output_hidden_states=True,
321
+ ).hidden_states[-(num_hidden_layers_to_skip + 1)]
322
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
323
+
324
+ if crop_start is not None and crop_start > 0:
325
+ prompt_embeds = prompt_embeds[:, crop_start:]
326
+ prompt_attention_mask = prompt_attention_mask[:, crop_start:]
327
+
328
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
329
+
330
+ return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}
331
+
332
+
333
+ def _get_clip_prompt_embeds(
334
+ tokenizer_2: CLIPTokenizer,
335
+ text_encoder_2: CLIPTextModel,
336
+ prompt: Union[str, List[str]],
337
+ device: torch.device,
338
+ dtype: torch.dtype,
339
+ max_sequence_length: int = 77,
340
+ ) -> torch.Tensor:
341
+ text_inputs = tokenizer_2(
342
+ prompt,
343
+ padding="max_length",
344
+ max_length=max_sequence_length,
345
+ truncation=True,
346
+ return_tensors="pt",
347
+ )
348
+
349
+ prompt_embeds = text_encoder_2(text_inputs.input_ids.to(device), output_hidden_states=False).pooler_output
350
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
351
+
352
+ return {"pooled_prompt_embeds": prompt_embeds}
353
+
354
+
355
+ # TODO(aryan): refactor into model specs for better re-use
356
+ HUNYUAN_VIDEO_T2V_LORA_CONFIG = {
357
+ "pipeline_cls": HunyuanVideoPipeline,
358
+ "load_condition_models": load_condition_models,
359
+ "load_latent_models": load_latent_models,
360
+ "load_diffusion_models": load_diffusion_models,
361
+ "initialize_pipeline": initialize_pipeline,
362
+ "prepare_conditions": prepare_conditions,
363
+ "prepare_latents": prepare_latents,
364
+ "post_latent_preparation": post_latent_preparation,
365
+ "collate_fn": collate_fn_t2v,
366
+ "forward_pass": forward_pass,
367
+ "validation": validation,
368
+ }
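
A short sketch of how the conditioning pieces of this config chain together (it would download the Llama and CLIP text encoders on first use; the prompt text is arbitrary and the dtypes are the defaults above):

import torch

from finetrainers.models.hunyuan_video.lora import load_condition_models, prepare_conditions

cond = load_condition_models()

with torch.no_grad():
    out = prepare_conditions(
        tokenizer=cond["tokenizer"],
        text_encoder=cond["text_encoder"],
        tokenizer_2=cond["tokenizer_2"],
        text_encoder_2=cond["text_encoder_2"],
        prompt="a cat playing a tiny piano",
    )

# out holds "prompt_embeds", "prompt_attention_mask", "pooled_prompt_embeds" and "guidance".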
finetrainers/models/ltx_video/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .full_finetune import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG
2
+ from .lora import LTX_VIDEO_T2V_LORA_CONFIG
finetrainers/models/ltx_video/full_finetune.py ADDED
@@ -0,0 +1,30 @@
1
+ from diffusers import LTXPipeline
2
+
3
+ from .lora import (
4
+ collate_fn_t2v,
5
+ forward_pass,
6
+ initialize_pipeline,
7
+ load_condition_models,
8
+ load_diffusion_models,
9
+ load_latent_models,
10
+ post_latent_preparation,
11
+ prepare_conditions,
12
+ prepare_latents,
13
+ validation,
14
+ )
15
+
16
+
17
+ # TODO(aryan): refactor into model specs for better re-use
18
+ LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG = {
19
+ "pipeline_cls": LTXPipeline,
20
+ "load_condition_models": load_condition_models,
21
+ "load_latent_models": load_latent_models,
22
+ "load_diffusion_models": load_diffusion_models,
23
+ "initialize_pipeline": initialize_pipeline,
24
+ "prepare_conditions": prepare_conditions,
25
+ "prepare_latents": prepare_latents,
26
+ "post_latent_preparation": post_latent_preparation,
27
+ "collate_fn": collate_fn_t2v,
28
+ "forward_pass": forward_pass,
29
+ "validation": validation,
30
+ }
finetrainers/models/ltx_video/lora.py ADDED
@@ -0,0 +1,331 @@
1
+ from typing import Dict, List, Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from accelerate.logging import get_logger
6
+ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
7
+ from PIL import Image
8
+ from transformers import T5EncoderModel, T5Tokenizer
9
+
10
+
11
+ logger = get_logger("finetrainers") # pylint: disable=invalid-name
12
+
13
+
14
+ def load_condition_models(
15
+ model_id: str = "Lightricks/LTX-Video",
16
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
17
+ revision: Optional[str] = None,
18
+ cache_dir: Optional[str] = None,
19
+ **kwargs,
20
+ ) -> Dict[str, nn.Module]:
21
+ tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
22
+ text_encoder = T5EncoderModel.from_pretrained(
23
+ model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
24
+ )
25
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
26
+
27
+
28
+ def load_latent_models(
29
+ model_id: str = "Lightricks/LTX-Video",
30
+ vae_dtype: torch.dtype = torch.bfloat16,
31
+ revision: Optional[str] = None,
32
+ cache_dir: Optional[str] = None,
33
+ **kwargs,
34
+ ) -> Dict[str, nn.Module]:
35
+ vae = AutoencoderKLLTXVideo.from_pretrained(
36
+ model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
37
+ )
38
+ return {"vae": vae}
39
+
40
+
41
+ def load_diffusion_models(
42
+ model_id: str = "Lightricks/LTX-Video",
43
+ transformer_dtype: torch.dtype = torch.bfloat16,
44
+ revision: Optional[str] = None,
45
+ cache_dir: Optional[str] = None,
46
+ **kwargs,
47
+ ) -> Dict[str, nn.Module]:
48
+ transformer = LTXVideoTransformer3DModel.from_pretrained(
49
+ model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
50
+ )
51
+ scheduler = FlowMatchEulerDiscreteScheduler()
52
+ return {"transformer": transformer, "scheduler": scheduler}
53
+
54
+
55
+ def initialize_pipeline(
56
+ model_id: str = "Lightricks/LTX-Video",
57
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
58
+ transformer_dtype: torch.dtype = torch.bfloat16,
59
+ vae_dtype: torch.dtype = torch.bfloat16,
60
+ tokenizer: Optional[T5Tokenizer] = None,
61
+ text_encoder: Optional[T5EncoderModel] = None,
62
+ transformer: Optional[LTXVideoTransformer3DModel] = None,
63
+ vae: Optional[AutoencoderKLLTXVideo] = None,
64
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
65
+ device: Optional[torch.device] = None,
66
+ revision: Optional[str] = None,
67
+ cache_dir: Optional[str] = None,
68
+ enable_slicing: bool = False,
69
+ enable_tiling: bool = False,
70
+ enable_model_cpu_offload: bool = False,
71
+ is_training: bool = False,
72
+ **kwargs,
73
+ ) -> LTXPipeline:
74
+ component_name_pairs = [
75
+ ("tokenizer", tokenizer),
76
+ ("text_encoder", text_encoder),
77
+ ("transformer", transformer),
78
+ ("vae", vae),
79
+ ("scheduler", scheduler),
80
+ ]
81
+ components = {}
82
+ for name, component in component_name_pairs:
83
+ if component is not None:
84
+ components[name] = component
85
+
86
+ pipe = LTXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
87
+ pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
88
+ pipe.vae = pipe.vae.to(dtype=vae_dtype)
89
+ # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
90
+ # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
91
+ # DDP optimizer step.
92
+ if not is_training:
93
+ pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
94
+
95
+ if enable_slicing:
96
+ pipe.vae.enable_slicing()
97
+ if enable_tiling:
98
+ pipe.vae.enable_tiling()
99
+
100
+ if enable_model_cpu_offload:
101
+ pipe.enable_model_cpu_offload(device=device)
102
+ else:
103
+ pipe.to(device=device)
104
+
105
+ return pipe
106
+
107
+
108
+ def prepare_conditions(
109
+ tokenizer: T5Tokenizer,
110
+ text_encoder: T5EncoderModel,
111
+ prompt: Union[str, List[str]],
112
+ device: Optional[torch.device] = None,
113
+ dtype: Optional[torch.dtype] = None,
114
+ max_sequence_length: int = 128,
115
+ **kwargs,
116
+ ) -> torch.Tensor:
117
+ device = device or text_encoder.device
118
+ dtype = dtype or text_encoder.dtype
119
+
120
+ if isinstance(prompt, str):
121
+ prompt = [prompt]
122
+
123
+ return _encode_prompt_t5(tokenizer, text_encoder, prompt, device, dtype, max_sequence_length)
124
+
125
+
126
+ def prepare_latents(
127
+ vae: AutoencoderKLLTXVideo,
128
+ image_or_video: torch.Tensor,
129
+ patch_size: int = 1,
130
+ patch_size_t: int = 1,
131
+ device: Optional[torch.device] = None,
132
+ dtype: Optional[torch.dtype] = None,
133
+ generator: Optional[torch.Generator] = None,
134
+ precompute: bool = False,
135
+ ) -> torch.Tensor:
136
+ device = device or vae.device
137
+
138
+ if image_or_video.ndim == 4:
139
+ image_or_video = image_or_video.unsqueeze(2)
140
+ assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
141
+
142
+ image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
143
+ image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
144
+ if not precompute:
145
+ latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
146
+ latents = latents.to(dtype=dtype)
147
+ _, _, num_frames, height, width = latents.shape
148
+ latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std)
149
+ latents = _pack_latents(latents, patch_size, patch_size_t)
150
+ return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}
151
+ else:
152
+ if vae.use_slicing and image_or_video.shape[0] > 1:
153
+ encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
154
+ h = torch.cat(encoded_slices)
155
+ else:
156
+ h = vae._encode(image_or_video)
157
+ _, _, num_frames, height, width = h.shape
158
+
159
+ # TODO(aryan): This is very stupid that we might possibly be storing the latents_mean and latents_std in every file
160
+ # if precomputation is enabled. We should probably have a single file where re-usable properties like this are stored
161
+ # so as to reduce the disk memory requirements of the precomputed files.
162
+ return {
163
+ "latents": h,
164
+ "num_frames": num_frames,
165
+ "height": height,
166
+ "width": width,
167
+ "latents_mean": vae.latents_mean,
168
+ "latents_std": vae.latents_std,
169
+ }
170
+
171
+
172
+ def post_latent_preparation(
173
+ latents: torch.Tensor,
174
+ latents_mean: torch.Tensor,
175
+ latents_std: torch.Tensor,
176
+ num_frames: int,
177
+ height: int,
178
+ width: int,
179
+ patch_size: int = 1,
180
+ patch_size_t: int = 1,
181
+ **kwargs,
182
+ ) -> torch.Tensor:
183
+ latents = _normalize_latents(latents, latents_mean, latents_std)
184
+ latents = _pack_latents(latents, patch_size, patch_size_t)
185
+ return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}
186
+
187
+
188
+ def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
189
+ return {
190
+ "prompts": [x["prompt"] for x in batch[0]],
191
+ "videos": torch.stack([x["video"] for x in batch[0]]),
192
+ }
193
+
194
+
195
+ def forward_pass(
196
+ transformer: LTXVideoTransformer3DModel,
197
+ prompt_embeds: torch.Tensor,
198
+ prompt_attention_mask: torch.Tensor,
199
+ latents: torch.Tensor,
200
+ noisy_latents: torch.Tensor,
201
+ timesteps: torch.LongTensor,
202
+ num_frames: int,
203
+ height: int,
204
+ width: int,
205
+ **kwargs,
206
+ ) -> torch.Tensor:
207
+ # TODO(aryan): make configurable
208
+ frame_rate = 25
209
+ latent_frame_rate = frame_rate / 8
210
+ spatial_compression_ratio = 32
211
+ rope_interpolation_scale = [1 / latent_frame_rate, spatial_compression_ratio, spatial_compression_ratio]
212
+
213
+ denoised_latents = transformer(
214
+ hidden_states=noisy_latents,
215
+ encoder_hidden_states=prompt_embeds,
216
+ timestep=timesteps,
217
+ encoder_attention_mask=prompt_attention_mask,
218
+ num_frames=num_frames,
219
+ height=height,
220
+ width=width,
221
+ rope_interpolation_scale=rope_interpolation_scale,
222
+ return_dict=False,
223
+ )[0]
224
+
225
+ return {"latents": denoised_latents}
226
+
227
+
228
+ def validation(
229
+ pipeline: LTXPipeline,
230
+ prompt: str,
231
+ image: Optional[Image.Image] = None,
232
+ video: Optional[List[Image.Image]] = None,
233
+ height: Optional[int] = None,
234
+ width: Optional[int] = None,
235
+ num_frames: Optional[int] = None,
236
+ frame_rate: int = 24,
237
+ num_videos_per_prompt: int = 1,
238
+ generator: Optional[torch.Generator] = None,
239
+ **kwargs,
240
+ ):
241
+ generation_kwargs = {
242
+ "prompt": prompt,
243
+ "height": height,
244
+ "width": width,
245
+ "num_frames": num_frames,
246
+ "frame_rate": frame_rate,
247
+ "num_videos_per_prompt": num_videos_per_prompt,
248
+ "generator": generator,
249
+ "return_dict": True,
250
+ "output_type": "pil",
251
+ }
252
+ generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
253
+ video = pipeline(**generation_kwargs).frames[0]
254
+ return [("video", video)]
255
+
256
+
257
+ def _encode_prompt_t5(
258
+ tokenizer: T5Tokenizer,
259
+ text_encoder: T5EncoderModel,
260
+ prompt: List[str],
261
+ device: torch.device,
262
+ dtype: torch.dtype,
263
+ max_sequence_length,
264
+ ) -> torch.Tensor:
265
+ batch_size = len(prompt)
266
+
267
+ text_inputs = tokenizer(
268
+ prompt,
269
+ padding="max_length",
270
+ max_length=max_sequence_length,
271
+ truncation=True,
272
+ add_special_tokens=True,
273
+ return_tensors="pt",
274
+ )
275
+ text_input_ids = text_inputs.input_ids
276
+ prompt_attention_mask = text_inputs.attention_mask
277
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
278
+
279
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
280
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
281
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
282
+
283
+ return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}
284
+
285
+
286
+ def _normalize_latents(
287
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
288
+ ) -> torch.Tensor:
289
+ # Normalize latents across the channel dimension [B, C, F, H, W]
290
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
291
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
292
+ latents = (latents - latents_mean) * scaling_factor / latents_std
293
+ return latents
294
+
295
+
296
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
297
+ # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
298
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
299
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
300
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
301
+ batch_size, num_channels, num_frames, height, width = latents.shape
302
+ post_patch_num_frames = num_frames // patch_size_t
303
+ post_patch_height = height // patch_size
304
+ post_patch_width = width // patch_size
305
+ latents = latents.reshape(
306
+ batch_size,
307
+ -1,
308
+ post_patch_num_frames,
309
+ patch_size_t,
310
+ post_patch_height,
311
+ patch_size,
312
+ post_patch_width,
313
+ patch_size,
314
+ )
315
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
316
+ return latents
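
To make the comment above concrete, a small shape sketch (128 latent channels and patch sizes of 1 match LTX-Video's defaults, but the other numbers are arbitrary):

import torch

from finetrainers.models.ltx_video.lora import _pack_latents

latents = torch.randn(2, 128, 8, 16, 24)              # [B, C, F, H, W]
packed = _pack_latents(latents, patch_size=1, patch_size_t=1)
print(packed.shape)                                   # torch.Size([2, 3072, 128]) == [B, F*H*W, C]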
317
+
318
+
319
+ LTX_VIDEO_T2V_LORA_CONFIG = {
320
+ "pipeline_cls": LTXPipeline,
321
+ "load_condition_models": load_condition_models,
322
+ "load_latent_models": load_latent_models,
323
+ "load_diffusion_models": load_diffusion_models,
324
+ "initialize_pipeline": initialize_pipeline,
325
+ "prepare_conditions": prepare_conditions,
326
+ "prepare_latents": prepare_latents,
327
+ "post_latent_preparation": post_latent_preparation,
328
+ "collate_fn": collate_fn_t2v,
329
+ "forward_pass": forward_pass,
330
+ "validation": validation,
331
+ }
finetrainers/patches.py ADDED
@@ -0,0 +1,50 @@
1
+ import functools
2
+
3
+ import torch
4
+ from accelerate.logging import get_logger
5
+ from peft.tuners.tuners_utils import BaseTunerLayer
6
+
7
+ from .constants import FINETRAINERS_LOG_LEVEL
8
+
9
+
10
+ logger = get_logger("finetrainers") # pylint: disable=invalid-name
11
+ logger.setLevel(FINETRAINERS_LOG_LEVEL)
12
+
13
+
14
+ def perform_peft_patches() -> None:
15
+ _perform_patch_move_adapter_to_device_of_base_layer()
16
+
17
+
18
+ def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
19
+ # We don't patch the method for torch.float32 and torch.bfloat16 because it is okay to train with them. If the model weights
20
+ # are in torch.float16, torch.float8_e4m3fn or torch.float8_e5m2, we need to patch this method to avoid conversion of
21
+ # LoRA weights from higher precision dtype.
22
+ BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
23
+ BaseTunerLayer._move_adapter_to_device_of_base_layer
24
+ )
25
+
26
+
27
+ def _patched_move_adapter_to_device_of_base_layer(func) -> None:
28
+ @functools.wraps(func)
29
+ def wrapper(self, *args, **kwargs):
30
+ with DisableTensorToDtype():
31
+ return func(self, *args, **kwargs)
32
+
33
+ return wrapper
34
+
35
+
36
+ class DisableTensorToDtype:
37
+ def __enter__(self):
38
+ self.original_to = torch.Tensor.to
39
+
40
+ def modified_to(tensor, *args, **kwargs):
41
+ # remove dtype from args if present
42
+ args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
43
+ if "dtype" in kwargs:
44
+ kwargs.pop("dtype")
45
+ return self.original_to(tensor, *args, **kwargs)
46
+
47
+ torch.Tensor.to = modified_to
48
+
49
+ def __exit__(self, exc_type, exc_val, exc_tb):
50
+ torch.Tensor.to = self.original_to
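
A small illustration of what the context manager above does in isolation (a sketch; in training, perform_peft_patches() wraps PEFT's adapter-device move in this guard so that low-precision base weights do not drag the LoRA weights down with them):

import torch

from finetrainers.patches import DisableTensorToDtype

t = torch.randn(2, 2, dtype=torch.float32)

with DisableTensorToDtype():
    out = t.to("cpu", dtype=torch.float16)  # the dtype keyword is stripped; only the device move happens

assert out.dtype == torch.float32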
finetrainers/state.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+ from accelerate import Accelerator
3
+
4
+
5
+ class State:
6
+ # Training state
7
+ seed: int = None
8
+ model_name: str = None
9
+ accelerator: Accelerator = None
10
+ weight_dtype: torch.dtype = None
11
+ train_epochs: int = None
12
+ train_steps: int = None
13
+ overwrote_max_train_steps: bool = False
14
+ num_trainable_parameters: int = 0
15
+ learning_rate: float = None
16
+ train_batch_size: int = None
17
+ generator: torch.Generator = None
18
+ num_update_steps_per_epoch: int = None
19
+
20
+ # Hub state
21
+ repo_id: str = None
22
+
23
+ # Artifacts state
24
+ output_dir: str = None
finetrainers/trainer.py ADDED
@@ -0,0 +1,1207 @@
1
+ import json
2
+ import logging
3
+ import math
4
+ import os
5
+ import random
6
+ from datetime import datetime, timedelta
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List
9
+
10
+ import diffusers
11
+ import torch
12
+ import torch.backends
13
+ import transformers
14
+ import wandb
15
+ from accelerate import Accelerator, DistributedType
16
+ from accelerate.logging import get_logger
17
+ from accelerate.utils import (
18
+ DistributedDataParallelKwargs,
19
+ InitProcessGroupKwargs,
20
+ ProjectConfiguration,
21
+ gather_object,
22
+ set_seed,
23
+ )
24
+ from diffusers import DiffusionPipeline
25
+ from diffusers.configuration_utils import FrozenDict
26
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
27
+ from diffusers.optimization import get_scheduler
28
+ from diffusers.training_utils import cast_training_params
29
+ from diffusers.utils import export_to_video, load_image, load_video
30
+ from huggingface_hub import create_repo, upload_folder
31
+ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
32
+ from tqdm import tqdm
33
+
34
+ from .args import Args, validate_args
35
+ from .constants import (
36
+ FINETRAINERS_LOG_LEVEL,
37
+ PRECOMPUTED_CONDITIONS_DIR_NAME,
38
+ PRECOMPUTED_DIR_NAME,
39
+ PRECOMPUTED_LATENTS_DIR_NAME,
40
+ )
41
+ from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset
42
+ from .hooks import apply_layerwise_upcasting
43
+ from .models import get_config_from_model_name
44
+ from .patches import perform_peft_patches
45
+ from .state import State
46
+ from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from
47
+ from .utils.data_utils import should_perform_precomputation
48
+ from .utils.diffusion_utils import (
49
+ get_scheduler_alphas,
50
+ get_scheduler_sigmas,
51
+ prepare_loss_weights,
52
+ prepare_sigmas,
53
+ prepare_target,
54
+ )
55
+ from .utils.file_utils import string_to_filename
56
+ from .utils.hub_utils import save_model_card
57
+ from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous
58
+ from .utils.model_utils import resolve_vae_cls_from_ckpt_path
59
+ from .utils.optimizer_utils import get_optimizer
60
+ from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model
61
+
62
+
63
+ logger = get_logger("finetrainers")
64
+ logger.setLevel(FINETRAINERS_LOG_LEVEL)
65
+
66
+
67
+ class Trainer:
68
+ def __init__(self, args: Args) -> None:
69
+ validate_args(args)
70
+
71
+ self.args = args
72
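+ # Fall back to the current year as a default seed when none is provided.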
+ self.args.seed = self.args.seed or datetime.now().year
73
+ self.state = State()
74
+
75
+ # Tokenizers
76
+ self.tokenizer = None
77
+ self.tokenizer_2 = None
78
+ self.tokenizer_3 = None
79
+
80
+ # Text encoders
81
+ self.text_encoder = None
82
+ self.text_encoder_2 = None
83
+ self.text_encoder_3 = None
84
+
85
+ # Denoisers
86
+ self.transformer = None
87
+ self.unet = None
88
+
89
+ # Autoencoders
90
+ self.vae = None
91
+
92
+ # Scheduler
93
+ self.scheduler = None
94
+
95
+ self.transformer_config = None
96
+ self.vae_config = None
97
+
98
+ self._init_distributed()
99
+ self._init_logging()
100
+ self._init_directories_and_repositories()
101
+ self._init_config_options()
102
+
103
+ # Perform any patches needed for training
104
+ if len(self.args.layerwise_upcasting_modules) > 0:
105
+ perform_peft_patches()
106
+ # TODO(aryan): handle text encoders
107
+ # if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]):
108
+ # perform_text_encoder_patches()
109
+
110
+ self.state.model_name = self.args.model_name
111
+ self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type)
112
+
113
+ def prepare_dataset(self) -> None:
114
+ # TODO(aryan): Make a background process for fetching
115
+ logger.info("Initializing dataset and dataloader")
116
+
117
+ self.dataset = ImageOrVideoDatasetWithResizing(
118
+ data_root=self.args.data_root,
119
+ caption_column=self.args.caption_column,
120
+ video_column=self.args.video_column,
121
+ resolution_buckets=self.args.video_resolution_buckets,
122
+ dataset_file=self.args.dataset_file,
123
+ id_token=self.args.id_token,
124
+ remove_llm_prefixes=self.args.remove_common_llm_caption_prefixes,
125
+ )
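+ # BucketSampler groups samples from the same resolution bucket together so every batch has a consistent tensor shape.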
126
+ self.dataloader = torch.utils.data.DataLoader(
127
+ self.dataset,
128
+ batch_size=1,
129
+ sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
130
+ collate_fn=self.model_config.get("collate_fn"),
131
+ num_workers=self.args.dataloader_num_workers,
132
+ pin_memory=self.args.pin_memory,
133
+ )
134
+
135
+ def prepare_models(self) -> None:
136
+ logger.info("Initializing models")
137
+
138
+ load_components_kwargs = self._get_load_components_kwargs()
139
+ condition_components, latent_components, diffusion_components = {}, {}, {}
140
+ if not self.args.precompute_conditions:
141
+ # To download the model files first on the main process (if not already present)
142
+ # and then load the cached files afterward from the other processes.
143
+ with self.state.accelerator.main_process_first():
144
+ condition_components = self.model_config["load_condition_models"](**load_components_kwargs)
145
+ latent_components = self.model_config["load_latent_models"](**load_components_kwargs)
146
+ diffusion_components = self.model_config["load_diffusion_models"](**load_components_kwargs)
147
+
148
+ components = {}
149
+ components.update(condition_components)
150
+ components.update(latent_components)
151
+ components.update(diffusion_components)
152
+ self._set_components(components)
153
+
154
+ if self.vae is not None:
155
+ if self.args.enable_slicing:
156
+ self.vae.enable_slicing()
157
+ if self.args.enable_tiling:
158
+ self.vae.enable_tiling()
159
+
160
+ def prepare_precomputations(self) -> None:
161
+ if not self.args.precompute_conditions:
162
+ return
163
+
164
+ logger.info("Initializing precomputations")
165
+
166
+ if self.args.batch_size != 1:
167
+ raise ValueError("Precomputation is only supported with batch size 1. Larger batch sizes will be supported in the future.")
168
+
169
+ def collate_fn(batch):
170
+ latent_conditions = [x["latent_conditions"] for x in batch]
171
+ text_conditions = [x["text_conditions"] for x in batch]
172
+ batched_latent_conditions = {}
173
+ batched_text_conditions = {}
174
+ for key in list(latent_conditions[0].keys()):
175
+ if torch.is_tensor(latent_conditions[0][key]):
176
+ batched_latent_conditions[key] = torch.cat([x[key] for x in latent_conditions], dim=0)
177
+ else:
178
+ # TODO(aryan): implement batch sampler for precomputed latents
179
+ batched_latent_conditions[key] = [x[key] for x in latent_conditions][0]
180
+ for key in list(text_conditions[0].keys()):
181
+ if torch.is_tensor(text_conditions[0][key]):
182
+ batched_text_conditions[key] = torch.cat([x[key] for x in text_conditions], dim=0)
183
+ else:
184
+ # TODO(aryan): implement batch sampler for precomputed latents
185
+ batched_text_conditions[key] = [x[key] for x in text_conditions][0]
186
+ return {"latent_conditions": batched_latent_conditions, "text_conditions": batched_text_conditions}
187
+
188
+ cleaned_model_id = string_to_filename(self.args.pretrained_model_name_or_path)
189
+ precomputation_dir = (
190
+ Path(self.args.data_root) / f"{self.args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
191
+ )
192
+ should_precompute = should_perform_precomputation(precomputation_dir)
193
+ if not should_precompute:
194
+ logger.info("Precomputed conditions and latents found. Loading precomputed data.")
195
+ self.dataloader = torch.utils.data.DataLoader(
196
+ PrecomputedDataset(
197
+ data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
198
+ ),
199
+ batch_size=self.args.batch_size,
200
+ shuffle=True,
201
+ collate_fn=collate_fn,
202
+ num_workers=self.args.dataloader_num_workers,
203
+ pin_memory=self.args.pin_memory,
204
+ )
205
+ return
206
+
207
+ logger.info("Precomputed conditions and latents not found. Running precomputation.")
208
+
209
+ # At this point, no models are loaded, so we need to load and precompute conditions and latents
210
+ with self.state.accelerator.main_process_first():
211
+ condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs())
212
+ self._set_components(condition_components)
213
+ self._move_components_to_device()
214
+ self._disable_grad_for_components([self.text_encoder, self.text_encoder_2, self.text_encoder_3])
215
+
216
+ if self.args.caption_dropout_p > 0 and self.args.caption_dropout_technique == "empty":
217
+ logger.warning(
218
+ "Caption dropout is not supported with precomputation yet. This will be supported in the future."
219
+ )
220
+
221
+ conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
222
+ latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
223
+ conditions_dir.mkdir(parents=True, exist_ok=True)
224
+ latents_dir.mkdir(parents=True, exist_ok=True)
225
+
226
+ accelerator = self.state.accelerator
227
+
228
+ # Precompute conditions
229
+ progress_bar = tqdm(
230
+ range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
231
+ desc="Precomputing conditions",
232
+ disable=not accelerator.is_local_main_process,
233
+ )
234
+ index = 0
235
+ for i, data in enumerate(self.dataset):
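+ # Shard the work round-robin across processes: each rank only precomputes every num_processes-th sample.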
236
+ if i % accelerator.num_processes != accelerator.process_index:
237
+ continue
238
+
239
+ logger.debug(
240
+ f"Precomputing conditions for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
241
+ )
242
+
243
+ text_conditions = self.model_config["prepare_conditions"](
244
+ tokenizer=self.tokenizer,
245
+ tokenizer_2=self.tokenizer_2,
246
+ tokenizer_3=self.tokenizer_3,
247
+ text_encoder=self.text_encoder,
248
+ text_encoder_2=self.text_encoder_2,
249
+ text_encoder_3=self.text_encoder_3,
250
+ prompt=data["prompt"],
251
+ device=accelerator.device,
252
+ dtype=self.args.transformer_dtype,
253
+ )
254
+ filename = conditions_dir / f"conditions-{accelerator.process_index}-{index}.pt"
255
+ torch.save(text_conditions, filename.as_posix())
256
+ index += 1
257
+ progress_bar.update(1)
258
+ self._delete_components()
259
+
260
+ memory_statistics = get_memory_statistics()
261
+ logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}")
262
+ torch.cuda.reset_peak_memory_stats(accelerator.device)
263
+
264
+ # Precompute latents
265
+ with self.state.accelerator.main_process_first():
266
+ latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())
267
+ self._set_components(latent_components)
268
+ self._move_components_to_device()
269
+ self._disable_grad_for_components([self.vae])
270
+
271
+ if self.vae is not None:
272
+ if self.args.enable_slicing:
273
+ self.vae.enable_slicing()
274
+ if self.args.enable_tiling:
275
+ self.vae.enable_tiling()
276
+
277
+ progress_bar = tqdm(
278
+ range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
279
+ desc="Precomputing latents",
280
+ disable=not accelerator.is_local_main_process,
281
+ )
282
+ index = 0
283
+ for i, data in enumerate(self.dataset):
284
+ if i % accelerator.num_processes != accelerator.process_index:
285
+ continue
286
+
287
+ logger.debug(
288
+ f"Precomputing latents for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
289
+ )
290
+
291
+ latent_conditions = self.model_config["prepare_latents"](
292
+ vae=self.vae,
293
+ image_or_video=data["video"].unsqueeze(0),
294
+ device=accelerator.device,
295
+ dtype=self.args.transformer_dtype,
296
+ generator=self.state.generator,
297
+ precompute=True,
298
+ )
299
+ filename = latents_dir / f"latents-{accelerator.process_index}-{index}.pt"
300
+ torch.save(latent_conditions, filename.as_posix())
301
+ index += 1
302
+ progress_bar.update(1)
303
+ self._delete_components()
304
+
305
+ accelerator.wait_for_everyone()
306
+ logger.info("Precomputation complete")
307
+
308
+ memory_statistics = get_memory_statistics()
309
+ logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}")
310
+ torch.cuda.reset_peak_memory_stats(accelerator.device)
311
+
312
+ # Update dataloader to use precomputed conditions and latents
313
+ self.dataloader = torch.utils.data.DataLoader(
314
+ PrecomputedDataset(
315
+ data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
316
+ ),
317
+ batch_size=self.args.batch_size,
318
+ shuffle=True,
319
+ collate_fn=collate_fn,
320
+ num_workers=self.args.dataloader_num_workers,
321
+ pin_memory=self.args.pin_memory,
322
+ )
323
+
324
+ def prepare_trainable_parameters(self) -> None:
325
+ logger.info("Initializing trainable parameters")
326
+
327
+ with self.state.accelerator.main_process_first():
328
+ diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs())
329
+ self._set_components(diffusion_components)
330
+
331
+ components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.vae]
332
+ self._disable_grad_for_components(components)
333
+
334
+ if self.args.training_type == "full-finetune":
335
+ logger.info("Finetuning transformer with no additional parameters")
336
+ self._enable_grad_for_components([self.transformer])
337
+ else:
338
+ logger.info("Finetuning transformer with PEFT parameters")
339
+ self._disable_grad_for_components([self.transformer])
340
+
341
+ # Layerwise upcasting must be applied before adding the LoRA adapter.
342
+ # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on
343
+ # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly.
344
+ if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules:
345
+ apply_layerwise_upcasting(
346
+ self.transformer,
347
+ storage_dtype=self.args.layerwise_upcasting_storage_dtype,
348
+ compute_dtype=self.args.transformer_dtype,
349
+ skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
350
+ non_blocking=True,
351
+ )
352
+
353
+ self._move_components_to_device()
354
+
355
+ if self.args.gradient_checkpointing:
356
+ self.transformer.enable_gradient_checkpointing()
357
+
358
+ if self.args.training_type == "lora":
359
+ transformer_lora_config = LoraConfig(
360
+ r=self.args.rank,
361
+ lora_alpha=self.args.lora_alpha,
362
+ init_lora_weights=True,
363
+ target_modules=self.args.target_modules,
364
+ )
365
+ self.transformer.add_adapter(transformer_lora_config)
366
+ else:
367
+ transformer_lora_config = None
368
+
369
+ # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32
370
+ # even when layerwise upcasting is enabled. It would also be nice to have a test for this.
371
+
372
+ self.register_saving_loading_hooks(transformer_lora_config)
373
+
374
+ def register_saving_loading_hooks(self, transformer_lora_config):
375
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
376
+ def save_model_hook(models, weights, output_dir):
377
+ if self.state.accelerator.is_main_process:
378
+ transformer_lora_layers_to_save = None
379
+
380
+ for model in models:
381
+ if isinstance(
382
+ unwrap_model(self.state.accelerator, model),
383
+ type(unwrap_model(self.state.accelerator, self.transformer)),
384
+ ):
385
+ model = unwrap_model(self.state.accelerator, model)
386
+ if self.args.training_type == "lora":
387
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
388
+ else:
389
+ raise ValueError(f"Unexpected save model: {model.__class__}")
390
+
391
+ # make sure to pop weight so that corresponding model is not saved again
392
+ if weights:
393
+ weights.pop()
394
+
395
+ if self.args.training_type == "lora":
396
+ self.model_config["pipeline_cls"].save_lora_weights(
397
+ output_dir,
398
+ transformer_lora_layers=transformer_lora_layers_to_save,
399
+ )
400
+ else:
401
+ model.save_pretrained(os.path.join(output_dir, "transformer"))
402
+
403
+ # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
404
+ # to be able to load all diffusion components from a specific checkpoint folder during validation, we need to
405
+ # ensure the scheduler config is serialized as well.
406
+ self.scheduler.save_pretrained(os.path.join(output_dir, "scheduler"))
407
+
408
+ def load_model_hook(models, input_dir):
409
+ if not self.state.accelerator.distributed_type == DistributedType.DEEPSPEED:
410
+ while len(models) > 0:
411
+ model = models.pop()
412
+ if isinstance(
413
+ unwrap_model(self.state.accelerator, model),
414
+ type(unwrap_model(self.state.accelerator, self.transformer)),
415
+ ):
416
+ transformer_ = unwrap_model(self.state.accelerator, model)
417
+ else:
418
+ raise ValueError(
419
+ f"Unexpected save model: {unwrap_model(self.state.accelerator, model).__class__}"
420
+ )
421
+ else:
422
+ transformer_cls_ = unwrap_model(self.state.accelerator, self.transformer).__class__
423
+
424
+ if self.args.training_type == "lora":
425
+ transformer_ = transformer_cls_.from_pretrained(
426
+ self.args.pretrained_model_name_or_path, subfolder="transformer"
427
+ )
428
+ transformer_.add_adapter(transformer_lora_config)
429
+ lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
430
+ transformer_state_dict = {
431
+ f'{k.replace("transformer.", "")}': v
432
+ for k, v in lora_state_dict.items()
433
+ if k.startswith("transformer.")
434
+ }
435
+ incompatible_keys = set_peft_model_state_dict(
436
+ transformer_, transformer_state_dict, adapter_name="default"
437
+ )
438
+ if incompatible_keys is not None:
439
+ # check only for unexpected keys
440
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
441
+ if unexpected_keys:
442
+ logger.warning(
443
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
444
+ f" {unexpected_keys}. "
445
+ )
446
+ else:
447
+ transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer"))
448
+
449
+ self.state.accelerator.register_save_state_pre_hook(save_model_hook)
450
+ self.state.accelerator.register_load_state_pre_hook(load_model_hook)
451
+
452
+ def prepare_optimizer(self) -> None:
453
+ logger.info("Initializing optimizer and lr scheduler")
454
+
455
+ self.state.train_epochs = self.args.train_epochs
456
+ self.state.train_steps = self.args.train_steps
457
+
458
+ # Make sure the trainable params are in float32
459
+ if self.args.training_type == "lora":
460
+ cast_training_params([self.transformer], dtype=torch.float32)
461
+
462
+ self.state.learning_rate = self.args.lr
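+ # Optionally scale the base learning rate by gradient accumulation steps, batch size, and process count (linear scaling rule).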
463
+ if self.args.scale_lr:
464
+ self.state.learning_rate = (
465
+ self.state.learning_rate
466
+ * self.args.gradient_accumulation_steps
467
+ * self.args.batch_size
468
+ * self.state.accelerator.num_processes
469
+ )
470
+
471
+ transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, self.transformer.parameters()))
472
+ transformer_parameters_with_lr = {
473
+ "params": transformer_trainable_parameters,
474
+ "lr": self.state.learning_rate,
475
+ }
476
+ params_to_optimize = [transformer_parameters_with_lr]
477
+ self.state.num_trainable_parameters = sum(p.numel() for p in transformer_trainable_parameters)
478
+
479
+ use_deepspeed_opt = (
480
+ self.state.accelerator.state.deepspeed_plugin is not None
481
+ and "optimizer" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
482
+ )
483
+ optimizer = get_optimizer(
484
+ params_to_optimize=params_to_optimize,
485
+ optimizer_name=self.args.optimizer,
486
+ learning_rate=self.state.learning_rate,
487
+ beta1=self.args.beta1,
488
+ beta2=self.args.beta2,
489
+ beta3=self.args.beta3,
490
+ epsilon=self.args.epsilon,
491
+ weight_decay=self.args.weight_decay,
492
+ use_8bit=self.args.use_8bit_bnb,
493
+ use_deepspeed=use_deepspeed_opt,
494
+ )
495
+
496
+ num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
497
+ if self.state.train_steps is None:
498
+ self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
499
+ self.state.overwrote_max_train_steps = True
500
+
501
+ use_deepspeed_lr_scheduler = (
502
+ self.state.accelerator.state.deepspeed_plugin is not None
503
+ and "scheduler" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
504
+ )
505
+ total_training_steps = self.state.train_steps * self.state.accelerator.num_processes
506
+ num_warmup_steps = self.args.lr_warmup_steps * self.state.accelerator.num_processes
507
+
508
+ if use_deepspeed_lr_scheduler:
509
+ from accelerate.utils import DummyScheduler
510
+
511
+ lr_scheduler = DummyScheduler(
512
+ name=self.args.lr_scheduler,
513
+ optimizer=optimizer,
514
+ total_num_steps=total_training_steps,
515
+ num_warmup_steps=num_warmup_steps,
516
+ )
517
+ else:
518
+ lr_scheduler = get_scheduler(
519
+ name=self.args.lr_scheduler,
520
+ optimizer=optimizer,
521
+ num_warmup_steps=num_warmup_steps,
522
+ num_training_steps=total_training_steps,
523
+ num_cycles=self.args.lr_num_cycles,
524
+ power=self.args.lr_power,
525
+ )
526
+
527
+ self.optimizer = optimizer
528
+ self.lr_scheduler = lr_scheduler
529
+
530
+ def prepare_for_training(self) -> None:
531
+ self.transformer, self.optimizer, self.dataloader, self.lr_scheduler = self.state.accelerator.prepare(
532
+ self.transformer, self.optimizer, self.dataloader, self.lr_scheduler
533
+ )
534
+
535
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
536
+ num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
537
+ if self.state.overwrote_max_train_steps:
538
+ self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
539
+ # Afterwards we recalculate our number of training epochs
540
+ self.state.train_epochs = math.ceil(self.state.train_steps / num_update_steps_per_epoch)
541
+ self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
542
+
543
+ def prepare_trackers(self) -> None:
544
+ logger.info("Initializing trackers")
545
+
546
+ tracker_name = self.args.tracker_name or "finetrainers-experiment"
547
+ self.state.accelerator.init_trackers(tracker_name, config=self._get_training_info())
548
+
549
+ def train(self) -> None:
550
+ logger.info("Starting training")
551
+
552
+ memory_statistics = get_memory_statistics()
553
+ logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
554
+
555
+ if self.vae_config is None:
556
+ # If we've already precomputed conditions and latents and are now reusing them, we will never load
557
+ # the VAE, so self.vae_config will not be set. Load it here instead.
558
+ vae_cls = resolve_vae_cls_from_ckpt_path(
559
+ self.args.pretrained_model_name_or_path, revision=self.args.revision, cache_dir=self.args.cache_dir
560
+ )
561
+ vae_config = vae_cls.load_config(
562
+ self.args.pretrained_model_name_or_path,
563
+ subfolder="vae",
564
+ revision=self.args.revision,
565
+ cache_dir=self.args.cache_dir,
566
+ )
567
+ self.vae_config = FrozenDict(**vae_config)
568
+
569
+ # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
570
+ # to be able to load all diffusion components from a specific checkpoint folder during validation, we need to
571
+ # ensure the scheduler config is serialized as well.
572
+ if self.args.training_type == "full-finetune":
573
+ self.scheduler.save_pretrained(os.path.join(self.args.output_dir, "scheduler"))
574
+
575
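+ # Effective global batch size = per-device batch size * number of processes * gradient accumulation steps.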
+ self.state.train_batch_size = (
576
+ self.args.batch_size * self.state.accelerator.num_processes * self.args.gradient_accumulation_steps
577
+ )
578
+ info = {
579
+ "trainable parameters": self.state.num_trainable_parameters,
580
+ "total samples": len(self.dataset),
581
+ "train epochs": self.state.train_epochs,
582
+ "train steps": self.state.train_steps,
583
+ "batches per device": self.args.batch_size,
584
+ "total batches observed per epoch": len(self.dataloader),
585
+ "train batch size": self.state.train_batch_size,
586
+ "gradient accumulation steps": self.args.gradient_accumulation_steps,
587
+ }
588
+ logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
589
+
590
+ global_step = 0
591
+ first_epoch = 0
592
+ initial_global_step = 0
593
+
594
+ # Potentially load in the weights and states from a previous save
595
+ (
596
+ resume_from_checkpoint_path,
597
+ initial_global_step,
598
+ global_step,
599
+ first_epoch,
600
+ ) = get_latest_ckpt_path_to_resume_from(
601
+ resume_from_checkpoint=self.args.resume_from_checkpoint,
602
+ num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
603
+ output_dir=self.args.output_dir,
604
+ )
605
+ if resume_from_checkpoint_path:
606
+ self.state.accelerator.load_state(resume_from_checkpoint_path)
607
+
608
+ progress_bar = tqdm(
609
+ range(0, self.state.train_steps),
610
+ initial=initial_global_step,
611
+ desc="Training steps",
612
+ disable=not self.state.accelerator.is_local_main_process,
613
+ )
614
+
615
+ accelerator = self.state.accelerator
616
+ generator = torch.Generator(device=accelerator.device)
617
+ if self.args.seed is not None:
618
+ generator = generator.manual_seed(self.args.seed)
619
+ self.state.generator = generator
620
+
621
+ scheduler_sigmas = get_scheduler_sigmas(self.scheduler)
622
+ scheduler_sigmas = (
623
+ scheduler_sigmas.to(device=accelerator.device, dtype=torch.float32)
624
+ if scheduler_sigmas is not None
625
+ else None
626
+ )
627
+ scheduler_alphas = get_scheduler_alphas(self.scheduler)
628
+ scheduler_alphas = (
629
+ scheduler_alphas.to(device=accelerator.device, dtype=torch.float32)
630
+ if scheduler_alphas is not None
631
+ else None
632
+ )
633
+
634
+ for epoch in range(first_epoch, self.state.train_epochs):
635
+ logger.debug(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})")
636
+
637
+ self.transformer.train()
638
+ models_to_accumulate = [self.transformer]
639
+ epoch_loss = 0.0
640
+ num_loss_updates = 0
641
+
642
+ for step, batch in enumerate(self.dataloader):
643
+ logger.debug(f"Starting step {step + 1}")
644
+ logs = {}
645
+
646
+ with accelerator.accumulate(models_to_accumulate):
647
+ if not self.args.precompute_conditions:
648
+ videos = batch["videos"]
649
+ prompts = batch["prompts"]
650
+ batch_size = len(prompts)
651
+
652
+ if self.args.caption_dropout_technique == "empty":
653
+ if random.random() < self.args.caption_dropout_p:
654
+ prompts = [""] * batch_size
655
+
656
+ latent_conditions = self.model_config["prepare_latents"](
657
+ vae=self.vae,
658
+ image_or_video=videos,
659
+ patch_size=self.transformer_config.patch_size,
660
+ patch_size_t=self.transformer_config.patch_size_t,
661
+ device=accelerator.device,
662
+ dtype=self.args.transformer_dtype,
663
+ generator=self.state.generator,
664
+ )
665
+ text_conditions = self.model_config["prepare_conditions"](
666
+ tokenizer=self.tokenizer,
667
+ text_encoder=self.text_encoder,
668
+ tokenizer_2=self.tokenizer_2,
669
+ text_encoder_2=self.text_encoder_2,
670
+ prompt=prompts,
671
+ device=accelerator.device,
672
+ dtype=self.args.transformer_dtype,
673
+ )
674
+ else:
675
+ latent_conditions = batch["latent_conditions"]
676
+ text_conditions = batch["text_conditions"]
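+ # Precomputed latents are stored as VAE posterior parameters; sample an actual latent from the diagonal Gaussian here.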
677
+ latent_conditions["latents"] = DiagonalGaussianDistribution(
678
+ latent_conditions["latents"]
679
+ ).sample(self.state.generator)
680
+
681
+ # This method should only be called for precomputed latents.
682
+ # TODO(aryan): rename this in separate PR
683
+ latent_conditions = self.model_config["post_latent_preparation"](
684
+ vae_config=self.vae_config,
685
+ patch_size=self.transformer_config.patch_size,
686
+ patch_size_t=self.transformer_config.patch_size_t,
687
+ **latent_conditions,
688
+ )
689
+ align_device_and_dtype(latent_conditions, accelerator.device, self.args.transformer_dtype)
690
+ align_device_and_dtype(text_conditions, accelerator.device, self.args.transformer_dtype)
691
+ batch_size = latent_conditions["latents"].shape[0]
692
+
693
+ latent_conditions = make_contiguous(latent_conditions)
694
+ text_conditions = make_contiguous(text_conditions)
695
+
696
+ if self.args.caption_dropout_technique == "zero":
697
+ if random.random() < self.args.caption_dropout_p:
698
+ text_conditions["prompt_embeds"].fill_(0)
699
+ text_conditions["prompt_attention_mask"].fill_(False)
700
+
701
+ # TODO(aryan): refactor later
702
+ if "pooled_prompt_embeds" in text_conditions:
703
+ text_conditions["pooled_prompt_embeds"].fill_(0)
704
+
705
+ sigmas = prepare_sigmas(
706
+ scheduler=self.scheduler,
707
+ sigmas=scheduler_sigmas,
708
+ batch_size=batch_size,
709
+ num_train_timesteps=self.scheduler.config.num_train_timesteps,
710
+ flow_weighting_scheme=self.args.flow_weighting_scheme,
711
+ flow_logit_mean=self.args.flow_logit_mean,
712
+ flow_logit_std=self.args.flow_logit_std,
713
+ flow_mode_scale=self.args.flow_mode_scale,
714
+ device=accelerator.device,
715
+ generator=self.state.generator,
716
+ )
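+ # Map the sampled sigmas in [0, 1] to integer scheduler timesteps; note this assumes 1000 training timesteps.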
717
+ timesteps = (sigmas * 1000.0).long()
718
+
719
+ noise = torch.randn(
720
+ latent_conditions["latents"].shape,
721
+ generator=self.state.generator,
722
+ device=accelerator.device,
723
+ dtype=self.args.transformer_dtype,
724
+ )
725
+ sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim)
726
+
727
+ # TODO(aryan): We probably don't need calculate_noisy_latents because we can determine the type of
728
+ # scheduler and calculate the noisy latents accordingly. Look into this later.
729
+ if "calculate_noisy_latents" in self.model_config.keys():
730
+ noisy_latents = self.model_config["calculate_noisy_latents"](
731
+ scheduler=self.scheduler,
732
+ noise=noise,
733
+ latents=latent_conditions["latents"],
734
+ timesteps=timesteps,
735
+ )
736
+ else:
737
+ # Default to flow-matching noise addition
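+ # x_t = (1 - sigma) * x_0 + sigma * noise linearly interpolates between the clean latents and pure noise (sigma = 1 is all noise).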
738
+ noisy_latents = (1.0 - sigmas) * latent_conditions["latents"] + sigmas * noise
739
+ noisy_latents = noisy_latents.to(latent_conditions["latents"].dtype)
740
+
741
+ latent_conditions.update({"noisy_latents": noisy_latents})
742
+
743
+ weights = prepare_loss_weights(
744
+ scheduler=self.scheduler,
745
+ alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
746
+ sigmas=sigmas,
747
+ flow_weighting_scheme=self.args.flow_weighting_scheme,
748
+ )
749
+ weights = expand_tensor_dims(weights, noise.ndim)
750
+
751
+ pred = self.model_config["forward_pass"](
752
+ transformer=self.transformer,
753
+ scheduler=self.scheduler,
754
+ timesteps=timesteps,
755
+ **latent_conditions,
756
+ **text_conditions,
757
+ )
758
+ target = prepare_target(
759
+ scheduler=self.scheduler, noise=noise, latents=latent_conditions["latents"]
760
+ )
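+ # The regression target is scheduler-dependent (see prepare_target): the flow-matching velocity (noise - latents) for FlowMatch schedulers, or the clean latents for the CogVideoX DDIM formulation.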
761
+
762
+ loss = weights.float() * (pred["latents"].float() - target.float()).pow(2)
763
+ # Average loss across all but batch dimension
764
+ loss = loss.mean(list(range(1, loss.ndim)))
765
+ # Average loss across batch dimension
766
+ loss = loss.mean()
767
+ accelerator.backward(loss)
768
+
769
+ if accelerator.sync_gradients:
770
+ if accelerator.distributed_type == DistributedType.DEEPSPEED:
771
+ grad_norm = self.transformer.get_global_grad_norm()
772
+ # In some cases the grad norm may not return a float
773
+ if torch.is_tensor(grad_norm):
774
+ grad_norm = grad_norm.item()
775
+ else:
776
+ grad_norm = accelerator.clip_grad_norm_(
777
+ self.transformer.parameters(), self.args.max_grad_norm
778
+ )
779
+ if torch.is_tensor(grad_norm):
780
+ grad_norm = grad_norm.item()
781
+
782
+ logs["grad_norm"] = grad_norm
783
+
784
+ self.optimizer.step()
785
+ self.lr_scheduler.step()
786
+ self.optimizer.zero_grad()
787
+
788
+ # Checks if the accelerator has performed an optimization step behind the scenes
789
+ if accelerator.sync_gradients:
790
+ progress_bar.update(1)
791
+ global_step += 1
792
+
793
+ # Checkpointing
794
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
795
+ if global_step % self.args.checkpointing_steps == 0:
796
+ save_path = get_intermediate_ckpt_path(
797
+ checkpointing_limit=self.args.checkpointing_limit,
798
+ step=global_step,
799
+ output_dir=self.args.output_dir,
800
+ )
801
+ accelerator.save_state(save_path)
802
+
803
+ # Maybe run validation
804
+ should_run_validation = (
805
+ self.args.validation_every_n_steps is not None
806
+ and global_step % self.args.validation_every_n_steps == 0
807
+ )
808
+ if should_run_validation:
809
+ self.validate(global_step)
810
+
811
+ loss_item = loss.detach().item()
812
+ epoch_loss += loss_item
813
+ num_loss_updates += 1
814
+ logs["step_loss"] = loss_item
815
+ logs["lr"] = self.lr_scheduler.get_last_lr()[0]
816
+ progress_bar.set_postfix(logs)
817
+ accelerator.log(logs, step=global_step)
818
+
819
+ if global_step >= self.state.train_steps:
820
+ break
821
+
822
+ if num_loss_updates > 0:
823
+ epoch_loss /= num_loss_updates
824
+ accelerator.log({"epoch_loss": epoch_loss}, step=global_step)
825
+ memory_statistics = get_memory_statistics()
826
+ logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
827
+
828
+ # Maybe run validation
829
+ should_run_validation = (
830
+ self.args.validation_every_n_epochs is not None
831
+ and (epoch + 1) % self.args.validation_every_n_epochs == 0
832
+ )
833
+ if should_run_validation:
834
+ self.validate(global_step)
835
+
836
+ accelerator.wait_for_everyone()
837
+ if accelerator.is_main_process:
838
+ transformer = unwrap_model(accelerator, self.transformer)
839
+
840
+ if self.args.training_type == "lora":
841
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
842
+
843
+ self.model_config["pipeline_cls"].save_lora_weights(
844
+ save_directory=self.args.output_dir,
845
+ transformer_lora_layers=transformer_lora_layers,
846
+ )
847
+ else:
848
+ transformer.save_pretrained(os.path.join(self.args.output_dir, "transformer"))
849
+ accelerator.wait_for_everyone()
850
+ self.validate(step=global_step, final_validation=True)
851
+
852
+ if accelerator.is_main_process:
853
+ if self.args.push_to_hub:
854
+ upload_folder(
855
+ repo_id=self.state.repo_id, folder_path=self.args.output_dir, ignore_patterns=["checkpoint-*"]
856
+ )
857
+
858
+ self._delete_components()
859
+ memory_statistics = get_memory_statistics()
860
+ logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
861
+
862
+ accelerator.end_training()
863
+
864
+ def validate(self, step: int, final_validation: bool = False) -> None:
865
+ logger.info("Starting validation")
866
+
867
+ accelerator = self.state.accelerator
868
+ num_validation_samples = len(self.args.validation_prompts)
869
+
870
+ if num_validation_samples == 0:
871
+ logger.warning("No validation samples found. Skipping validation.")
872
+ if accelerator.is_main_process:
873
+ if self.args.push_to_hub:
874
+ save_model_card(
875
+ args=self.args,
876
+ repo_id=self.state.repo_id,
877
+ videos=None,
878
+ validation_prompts=None,
879
+ )
880
+ return
881
+
882
+ self.transformer.eval()
883
+
884
+ memory_statistics = get_memory_statistics()
885
+ logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
886
+
887
+ pipeline = self._get_and_prepare_pipeline_for_validation(final_validation=final_validation)
888
+
889
+ all_processes_artifacts = []
890
+ prompts_to_filenames = {}
891
+ for i in range(num_validation_samples):
892
+ # Skip current validation on all processes but one
893
+ if i % accelerator.num_processes != accelerator.process_index:
894
+ continue
895
+
896
+ prompt = self.args.validation_prompts[i]
897
+ image = self.args.validation_images[i]
898
+ video = self.args.validation_videos[i]
899
+ height = self.args.validation_heights[i]
900
+ width = self.args.validation_widths[i]
901
+ num_frames = self.args.validation_num_frames[i]
902
+ frame_rate = self.args.validation_frame_rate
903
+ if image is not None:
904
+ image = load_image(image)
905
+ if video is not None:
906
+ video = load_video(video)
907
+
908
+ logger.debug(
909
+ f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
910
+ main_process_only=False,
911
+ )
912
+ validation_artifacts = self.model_config["validation"](
913
+ pipeline=pipeline,
914
+ prompt=prompt,
915
+ image=image,
916
+ video=video,
917
+ height=height,
918
+ width=width,
919
+ num_frames=num_frames,
920
+ frame_rate=frame_rate,
921
+ num_videos_per_prompt=self.args.num_validation_videos_per_prompt,
922
+ generator=torch.Generator(device=accelerator.device).manual_seed(
923
+ self.args.seed if self.args.seed is not None else 0
924
+ ),
925
+ # TODO: support passing `fps` for supported pipelines.
926
+ )
927
+
928
+ prompt_filename = string_to_filename(prompt)[:25]
929
+ artifacts = {
930
+ "image": {"type": "image", "value": image},
931
+ "video": {"type": "video", "value": video},
932
+ }
933
+ for artifact_index, (artifact_type, artifact_value) in enumerate(validation_artifacts):
934
+ if artifact_value:
935
+ artifacts.update({f"artifact_{artifact_index}": {"type": artifact_type, "value": artifact_value}})
936
+ logger.debug(
937
+ f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
938
+ main_process_only=False,
939
+ )
940
+
941
+ for index, (key, value) in enumerate(list(artifacts.items())):
942
+ artifact_type = value["type"]
943
+ artifact_value = value["value"]
944
+ if artifact_type not in ["image", "video"] or artifact_value is None:
945
+ continue
946
+
947
+ extension = "png" if artifact_type == "image" else "mp4"
948
+ filename = "validation-" if not final_validation else "final-"
949
+ filename += f"{step}-{accelerator.process_index}-{index}-{prompt_filename}.{extension}"
950
+ if accelerator.is_main_process and extension == "mp4":
951
+ prompts_to_filenames[prompt] = filename
952
+ filename = os.path.join(self.args.output_dir, filename)
953
+
954
+ if artifact_type == "image" and artifact_value:
955
+ logger.debug(f"Saving image to {filename}")
956
+ artifact_value.save(filename)
957
+ artifact_value = wandb.Image(filename)
958
+ elif artifact_type == "video" and artifact_value:
959
+ logger.debug(f"Saving video to {filename}")
960
+ # TODO: this should be configurable here as well as in validation runs where we call the pipeline that has `fps`.
961
+ export_to_video(artifact_value, filename, fps=frame_rate)
962
+ artifact_value = wandb.Video(filename, caption=prompt)
963
+
964
+ all_processes_artifacts.append(artifact_value)
965
+
966
+ all_artifacts = gather_object(all_processes_artifacts)
967
+
968
+ if accelerator.is_main_process:
969
+ tracker_key = "final" if final_validation else "validation"
970
+ for tracker in accelerator.trackers:
971
+ if tracker.name == "wandb":
972
+ artifact_log_dict = {}
973
+
974
+ image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
975
+ if len(image_artifacts) > 0:
976
+ artifact_log_dict["images"] = image_artifacts
977
+ video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
978
+ if len(video_artifacts) > 0:
979
+ artifact_log_dict["videos"] = video_artifacts
980
+ tracker.log({tracker_key: artifact_log_dict}, step=step)
981
+
982
+ if self.args.push_to_hub and final_validation:
983
+ video_filenames = list(prompts_to_filenames.values())
984
+ prompts = list(prompts_to_filenames.keys())
985
+ save_model_card(
986
+ args=self.args,
987
+ repo_id=self.state.repo_id,
988
+ videos=video_filenames,
989
+ validation_prompts=prompts,
990
+ )
991
+
992
+ # Remove all hooks that might have been added during pipeline initialization to the models
993
+ pipeline.remove_all_hooks()
994
+ del pipeline
995
+
996
+ accelerator.wait_for_everyone()
997
+
998
+ free_memory()
999
+ memory_statistics = get_memory_statistics()
1000
+ logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
1001
+ torch.cuda.reset_peak_memory_stats(accelerator.device)
1002
+
1003
+ if not final_validation:
1004
+ self.transformer.train()
1005
+
1006
+ def evaluate(self) -> None:
1007
+ raise NotImplementedError("Evaluation has not been implemented yet.")
1008
+
1009
+ def _init_distributed(self) -> None:
1010
+ logging_dir = Path(self.args.output_dir, self.args.logging_dir)
1011
+ project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
1012
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
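+ # find_unused_parameters=True lets DDP tolerate parameters that do not receive gradients in a given step, at some extra synchronization cost.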
1013
+ init_process_group_kwargs = InitProcessGroupKwargs(
1014
+ backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
1015
+ )
1016
+ report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
1017
+
1018
+ accelerator = Accelerator(
1019
+ project_config=project_config,
1020
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
1021
+ log_with=report_to,
1022
+ kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
1023
+ )
1024
+
1025
+ # Disable AMP for MPS.
1026
+ if torch.backends.mps.is_available():
1027
+ accelerator.native_amp = False
1028
+
1029
+ self.state.accelerator = accelerator
1030
+
1031
+ if self.args.seed is not None:
1032
+ self.state.seed = self.args.seed
1033
+ set_seed(self.args.seed)
1034
+
1035
+ def _init_logging(self) -> None:
1036
+ logging.basicConfig(
1037
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
1038
+ datefmt="%m/%d/%Y %H:%M:%S",
1039
+ level=FINETRAINERS_LOG_LEVEL,
1040
+ )
1041
+ if self.state.accelerator.is_local_main_process:
1042
+ transformers.utils.logging.set_verbosity_warning()
1043
+ diffusers.utils.logging.set_verbosity_info()
1044
+ else:
1045
+ transformers.utils.logging.set_verbosity_error()
1046
+ diffusers.utils.logging.set_verbosity_error()
1047
+
1048
+ logger.info("Initialized FineTrainers")
1049
+ logger.info(self.state.accelerator.state, main_process_only=False)
1050
+
1051
+ def _init_directories_and_repositories(self) -> None:
1052
+ if self.state.accelerator.is_main_process:
1053
+ self.args.output_dir = Path(self.args.output_dir)
1054
+ self.args.output_dir.mkdir(parents=True, exist_ok=True)
1055
+ self.state.output_dir = Path(self.args.output_dir)
1056
+
1057
+ if self.args.push_to_hub:
1058
+ repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
1059
+ self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id
1060
+
1061
+ def _init_config_options(self) -> None:
1062
+ # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1063
+ if self.args.allow_tf32 and torch.cuda.is_available():
1064
+ torch.backends.cuda.matmul.allow_tf32 = True
1065
+
1066
+ def _move_components_to_device(self):
1067
+ if self.text_encoder is not None:
1068
+ self.text_encoder = self.text_encoder.to(self.state.accelerator.device)
1069
+ if self.text_encoder_2 is not None:
1070
+ self.text_encoder_2 = self.text_encoder_2.to(self.state.accelerator.device)
1071
+ if self.text_encoder_3 is not None:
1072
+ self.text_encoder_3 = self.text_encoder_3.to(self.state.accelerator.device)
1073
+ if self.transformer is not None:
1074
+ self.transformer = self.transformer.to(self.state.accelerator.device)
1075
+ if self.unet is not None:
1076
+ self.unet = self.unet.to(self.state.accelerator.device)
1077
+ if self.vae is not None:
1078
+ self.vae = self.vae.to(self.state.accelerator.device)
1079
+
1080
+ def _get_load_components_kwargs(self) -> Dict[str, Any]:
1081
+ load_component_kwargs = {
1082
+ "text_encoder_dtype": self.args.text_encoder_dtype,
1083
+ "text_encoder_2_dtype": self.args.text_encoder_2_dtype,
1084
+ "text_encoder_3_dtype": self.args.text_encoder_3_dtype,
1085
+ "transformer_dtype": self.args.transformer_dtype,
1086
+ "vae_dtype": self.args.vae_dtype,
1087
+ "shift": self.args.flow_shift,
1088
+ "revision": self.args.revision,
1089
+ "cache_dir": self.args.cache_dir,
1090
+ }
1091
+ if self.args.pretrained_model_name_or_path is not None:
1092
+ load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path
1093
+ return load_component_kwargs
1094
+
1095
+ def _set_components(self, components: Dict[str, Any]) -> None:
1096
+ # Set models
1097
+ self.tokenizer = components.get("tokenizer", self.tokenizer)
1098
+ self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2)
1099
+ self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3)
1100
+ self.text_encoder = components.get("text_encoder", self.text_encoder)
1101
+ self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2)
1102
+ self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3)
1103
+ self.transformer = components.get("transformer", self.transformer)
1104
+ self.unet = components.get("unet", self.unet)
1105
+ self.vae = components.get("vae", self.vae)
1106
+ self.scheduler = components.get("scheduler", self.scheduler)
1107
+
1108
+ # Set configs
1109
+ self.transformer_config = self.transformer.config if self.transformer is not None else self.transformer_config
1110
+ self.vae_config = self.vae.config if self.vae is not None else self.vae_config
1111
+
1112
+ def _delete_components(self) -> None:
1113
+ self.tokenizer = None
1114
+ self.tokenizer_2 = None
1115
+ self.tokenizer_3 = None
1116
+ self.text_encoder = None
1117
+ self.text_encoder_2 = None
1118
+ self.text_encoder_3 = None
1119
+ self.transformer = None
1120
+ self.unet = None
1121
+ self.vae = None
1122
+ self.scheduler = None
1123
+ free_memory()
1124
+ torch.cuda.synchronize(self.state.accelerator.device)
1125
+
1126
+ def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline:
1127
+ accelerator = self.state.accelerator
1128
+ if not final_validation:
1129
+ pipeline = self.model_config["initialize_pipeline"](
1130
+ model_id=self.args.pretrained_model_name_or_path,
1131
+ tokenizer=self.tokenizer,
1132
+ text_encoder=self.text_encoder,
1133
+ tokenizer_2=self.tokenizer_2,
1134
+ text_encoder_2=self.text_encoder_2,
1135
+ transformer=unwrap_model(accelerator, self.transformer),
1136
+ vae=self.vae,
1137
+ device=accelerator.device,
1138
+ revision=self.args.revision,
1139
+ cache_dir=self.args.cache_dir,
1140
+ enable_slicing=self.args.enable_slicing,
1141
+ enable_tiling=self.args.enable_tiling,
1142
+ enable_model_cpu_offload=self.args.enable_model_cpu_offload,
1143
+ is_training=True,
1144
+ )
1145
+ else:
1146
+ self._delete_components()
1147
+
1148
+ # Load the transformer weights from the final checkpoint if performing full-finetune
1149
+ transformer = None
1150
+ if self.args.training_type == "full-finetune":
1151
+ transformer = self.model_config["load_diffusion_models"](model_id=self.args.output_dir)["transformer"]
1152
+
1153
+ pipeline = self.model_config["initialize_pipeline"](
1154
+ model_id=self.args.pretrained_model_name_or_path,
1155
+ transformer=transformer,
1156
+ device=accelerator.device,
1157
+ revision=self.args.revision,
1158
+ cache_dir=self.args.cache_dir,
1159
+ enable_slicing=self.args.enable_slicing,
1160
+ enable_tiling=self.args.enable_tiling,
1161
+ enable_model_cpu_offload=self.args.enable_model_cpu_offload,
1162
+ is_training=False,
1163
+ )
1164
+
1165
+ # Load the LoRA weights if performing LoRA finetuning
1166
+ if self.args.training_type == "lora":
1167
+ pipeline.load_lora_weights(self.args.output_dir)
1168
+
1169
+ return pipeline
1170
+
1171
+ def _disable_grad_for_components(self, components: List[torch.nn.Module]):
1172
+ for component in components:
1173
+ if component is not None:
1174
+ component.requires_grad_(False)
1175
+
1176
+ def _enable_grad_for_components(self, components: List[torch.nn.Module]):
1177
+ for component in components:
1178
+ if component is not None:
1179
+ component.requires_grad_(True)
1180
+
1181
+ def _get_training_info(self) -> dict:
1182
+ args = self.args.to_dict()
1183
+
1184
+ training_args = args.get("training_arguments", {})
1185
+ training_type = training_args.get("training_type", "")
1186
+
1187
+ # LoRA/non-LoRA stuff.
1188
+ if training_type == "full-finetune":
1189
+ filtered_training_args = {
1190
+ k: v for k, v in training_args.items() if k not in {"rank", "lora_alpha", "target_modules"}
1191
+ }
1192
+ else:
1193
+ filtered_training_args = training_args
1194
+
1195
+ # Diffusion/flow stuff.
1196
+ diffusion_args = args.get("diffusion_arguments", {})
1197
+ scheduler_name = self.scheduler.__class__.__name__
1198
+ if scheduler_name != "FlowMatchEulerDiscreteScheduler":
1199
+ filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k}
1200
+ else:
1201
+ filtered_diffusion_args = diffusion_args
1202
+
1203
+ # Rest of the stuff.
1204
+ updated_training_info = args.copy()
1205
+ updated_training_info["training_arguments"] = filtered_training_args
1206
+ updated_training_info["diffusion_arguments"] = filtered_diffusion_args
1207
+ return updated_training_info
finetrainers/utils/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .diffusion_utils import (
2
+ default_flow_shift,
3
+ get_scheduler_alphas,
4
+ get_scheduler_sigmas,
5
+ prepare_loss_weights,
6
+ prepare_sigmas,
7
+ prepare_target,
8
+ resolution_dependent_timestep_flow_shift,
9
+ )
10
+ from .file_utils import delete_files, find_files
11
+ from .memory_utils import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous
12
+ from .optimizer_utils import get_optimizer, gradient_norm, max_gradient
13
+ from .torch_utils import unwrap_model
finetrainers/utils/checkpointing.py ADDED
@@ -0,0 +1,64 @@
1
+ import os
2
+ from typing import Tuple
3
+
4
+ from accelerate.logging import get_logger
5
+
6
+ from ..constants import FINETRAINERS_LOG_LEVEL
7
+ from ..utils.file_utils import delete_files, find_files
8
+
9
+
10
+ logger = get_logger("finetrainers")
11
+ logger.setLevel(FINETRAINERS_LOG_LEVEL)
12
+
13
+
14
+ def get_latest_ckpt_path_to_resume_from(
15
+ resume_from_checkpoint: str, num_update_steps_per_epoch: int, output_dir: str
16
+ ) -> Tuple[str, int, int, int]:
17
+ if not resume_from_checkpoint:
18
+ initial_global_step = 0
19
+ global_step = 0
20
+ first_epoch = 0
21
+ resume_from_checkpoint_path = None
22
+ else:
23
+ if resume_from_checkpoint != "latest":
24
+ path = os.path.basename(resume_from_checkpoint)
25
+ else:
26
+ # Get the most recent checkpoint
27
+ dirs = os.listdir(output_dir)
28
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
29
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
30
+ path = dirs[-1] if len(dirs) > 0 else None
31
+
32
+ if path is None:
33
+ logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
34
+ resume_from_checkpoint = None
35
+ initial_global_step = 0
36
+ global_step = 0
37
+ first_epoch = 0
38
+ resume_from_checkpoint_path = None
39
+ else:
40
+ logger.info(f"Resuming from checkpoint {path}")
41
+ resume_from_checkpoint_path = os.path.join(output_dir, path)
42
+ global_step = int(path.split("-")[1])
43
+
44
+ initial_global_step = global_step
45
+ first_epoch = global_step // num_update_steps_per_epoch
46
+
47
+ return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch
48
+
49
+
50
+ def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str:
51
+ # before saving state, check if this save would set us over the `checkpointing_limit`
52
+ if checkpointing_limit is not None:
53
+ checkpoints = find_files(output_dir, prefix="checkpoint")
54
+
55
+ # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints
56
+ if len(checkpoints) >= checkpointing_limit:
57
+ num_to_remove = len(checkpoints) - checkpointing_limit + 1
58
+ checkpoints_to_remove = [os.path.join(output_dir, x) for x in checkpoints[0:num_to_remove]]
59
+ delete_files(checkpoints_to_remove)
60
+
61
+ logger.info(f"Checkpointing at step {step}")
62
+ save_path = os.path.join(output_dir, f"checkpoint-{step}")
63
+ logger.info(f"Saving state to {save_path}")
64
+ return save_path
finetrainers/utils/data_utils.py ADDED
@@ -0,0 +1,35 @@
1
+ from pathlib import Path
2
+ from typing import Union
3
+
4
+ from accelerate.logging import get_logger
5
+
6
+ from ..constants import PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME
7
+
8
+
9
+ logger = get_logger("finetrainers")
10
+
11
+
12
+ def should_perform_precomputation(precomputation_dir: Union[str, Path]) -> bool:
13
+ if isinstance(precomputation_dir, str):
14
+ precomputation_dir = Path(precomputation_dir)
15
+ conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
16
+ latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
17
+ if conditions_dir.exists() and latents_dir.exists():
18
+ num_files_conditions = len(list(conditions_dir.glob("*.pt")))
19
+ num_files_latents = len(list(latents_dir.glob("*.pt")))
20
+ if num_files_conditions != num_files_latents:
21
+ logger.warning(
22
+ f"Number of precomputed conditions ({num_files_conditions}) does not match number of precomputed latents ({num_files_latents})."
23
+ f"Cleaning up precomputed directories and re-running precomputation."
24
+ )
25
+ # clean up precomputed directories
26
+ for file in conditions_dir.glob("*.pt"):
27
+ file.unlink()
28
+ for file in latents_dir.glob("*.pt"):
29
+ file.unlink()
30
+ return True
31
+ if num_files_conditions > 0:
32
+ logger.info(f"Found {num_files_conditions} precomputed conditions and latents.")
33
+ return False
34
+ logger.info("Precomputed data not found. Running precomputation.")
35
+ return True
finetrainers/utils/diffusion_utils.py ADDED
@@ -0,0 +1,145 @@
1
+ import math
2
+ from typing import Optional, Union
3
+
4
+ import torch
5
+ from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler
6
+ from diffusers.training_utils import compute_loss_weighting_for_sd3
7
+
8
+
9
+ # Default values copied from https://github.com/huggingface/diffusers/blob/8957324363d8b239d82db4909fbf8c0875683e3d/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L47
10
+ def resolution_dependent_timestep_flow_shift(
11
+ latents: torch.Tensor,
12
+ sigmas: torch.Tensor,
13
+ base_image_seq_len: int = 256,
14
+ max_image_seq_len: int = 4096,
15
+ base_shift: float = 0.5,
16
+ max_shift: float = 1.15,
17
+ ) -> torch.Tensor:
18
+ image_or_video_sequence_length = 0
19
+ if latents.ndim == 4:
20
+ image_or_video_sequence_length = latents.shape[2] * latents.shape[3]
21
+ elif latents.ndim == 5:
22
+ image_or_video_sequence_length = latents.shape[2] * latents.shape[3] * latents.shape[4]
23
+ else:
24
+ raise ValueError(f"Expected 4D or 5D tensor, got {latents.ndim}D tensor")
25
+
26
+ m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
27
+ b = base_shift - m * base_image_seq_len
28
+ mu = m * image_or_video_sequence_length + b
29
+ sigmas = default_flow_shift(latents, sigmas, shift=mu)
30
+ return sigmas
31
+
32
+
33
+ def default_flow_shift(sigmas: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
34
+ sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
35
+ return sigmas
36
+
37
+
38
+ def compute_density_for_timestep_sampling(
39
+ weighting_scheme: str,
40
+ batch_size: int,
41
+ logit_mean: float = None,
42
+ logit_std: float = None,
43
+ mode_scale: float = None,
44
+ device: torch.device = torch.device("cpu"),
45
+ generator: Optional[torch.Generator] = None,
46
+ ) -> torch.Tensor:
47
+ r"""
48
+ Compute the density for sampling the timesteps when doing SD3 training.
49
+
50
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
51
+
52
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
53
+ """
54
+ if weighting_scheme == "logit_normal":
55
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
56
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
57
+ u = torch.nn.functional.sigmoid(u)
58
+ elif weighting_scheme == "mode":
59
+ u = torch.rand(size=(batch_size,), device=device, generator=generator)
60
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
61
+ else:
62
+ u = torch.rand(size=(batch_size,), device=device, generator=generator)
63
+ return u
64
+
65
+
66
+ def get_scheduler_alphas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
67
+ if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
68
+ return None
69
+ elif isinstance(scheduler, CogVideoXDDIMScheduler):
70
+ return scheduler.alphas_cumprod.clone()
71
+ else:
72
+ raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
73
+
74
+
75
+ def get_scheduler_sigmas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
76
+ if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
77
+ return scheduler.sigmas.clone()
78
+ elif isinstance(scheduler, CogVideoXDDIMScheduler):
79
+ return scheduler.timesteps.clone().float() / float(scheduler.config.num_train_timesteps)
80
+ else:
81
+ raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
82
+
83
+
84
+ def prepare_sigmas(
85
+ scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
86
+ sigmas: torch.Tensor,
87
+ batch_size: int,
88
+ num_train_timesteps: int,
89
+ flow_weighting_scheme: str = "none",
90
+ flow_logit_mean: float = 0.0,
91
+ flow_logit_std: float = 1.0,
92
+ flow_mode_scale: float = 1.29,
93
+ device: torch.device = torch.device("cpu"),
94
+ generator: Optional[torch.Generator] = None,
95
+ ) -> torch.Tensor:
96
+ if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
97
+ weights = compute_density_for_timestep_sampling(
98
+ weighting_scheme=flow_weighting_scheme,
99
+ batch_size=batch_size,
100
+ logit_mean=flow_logit_mean,
101
+ logit_std=flow_logit_std,
102
+ mode_scale=flow_mode_scale,
103
+ device=device,
104
+ generator=generator,
105
+ )
106
+ indices = (weights * num_train_timesteps).long()
107
+ elif isinstance(scheduler, CogVideoXDDIMScheduler):
108
+ # TODO(aryan): Currently, only uniform sampling is supported. Add more sampling schemes.
109
+ weights = torch.rand(size=(batch_size,), device=device, generator=generator)
110
+ indices = (weights * num_train_timesteps).long()
111
+ else:
112
+ raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
113
+
114
+ return sigmas[indices]
115
+
116
+
117
+ def prepare_loss_weights(
118
+ scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
119
+ alphas: Optional[torch.Tensor] = None,
120
+ sigmas: Optional[torch.Tensor] = None,
121
+ flow_weighting_scheme: str = "none",
122
+ ) -> torch.Tensor:
123
+ if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
124
+ return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme)
125
+ elif isinstance(scheduler, CogVideoXDDIMScheduler):
126
+ # SNR is computed as (alphas / (1 - alphas)), but for some reason CogVideoX uses 1 / (1 - alphas).
127
+ # TODO(aryan): Experiment if using alphas / (1 - alphas) gives better results.
128
+ return 1 / (1 - alphas)
129
+ else:
130
+ raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
131
+
132
+
133
+ def prepare_target(
134
+ scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
135
+ noise: torch.Tensor,
136
+ latents: torch.Tensor,
137
+ ) -> torch.Tensor:
138
+ if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
139
+ target = noise - latents
140
+ elif isinstance(scheduler, CogVideoXDDIMScheduler):
141
+ target = latents
142
+ else:
143
+ raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
144
+
145
+ return target
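To make the resolution-dependent shift above concrete, here is a small standalone sketch (not part of the commit) that evaluates `mu` and the shifted sigmas for a hypothetical 32x32 latent grid, using the default constants from the function signature:

```python
import torch

# Hypothetical 4D latent with a 32x32 spatial grid -> sequence length of 1024.
seq_len = 32 * 32
base_image_seq_len, max_image_seq_len = 256, 4096
base_shift, max_shift = 0.5, 1.15

# Linear interpolation between base_shift and max_shift, as in the function above.
m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
b = base_shift - m * base_image_seq_len
mu = m * seq_len + b  # ~0.63 for a 1024-token latent

# Apply the same transform that default_flow_shift applies to the sigmas.
sigmas = torch.linspace(1.0, 0.0, steps=5)
shifted = (sigmas * mu) / (1 + (mu - 1) * sigmas)
print(mu, shifted)
```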
finetrainers/utils/file_utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import shutil
4
+ from pathlib import Path
5
+ from typing import List, Union
6
+
7
+
8
+ logger = logging.getLogger("finetrainers")
9
+ logger.setLevel(os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO"))
10
+
11
+
12
+ def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[str]:
13
+ if not isinstance(dir, Path):
14
+ dir = Path(dir)
15
+ if not dir.exists():
16
+ return []
17
+ checkpoints = os.listdir(dir.as_posix())
18
+ checkpoints = [c for c in checkpoints if c.startswith(prefix)]
19
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
20
+ return checkpoints
21
+
22
+
23
+ def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None:
24
+ if not isinstance(dirs, list):
25
+ dirs = [dirs]
26
+ dirs = [Path(d) if isinstance(d, str) else d for d in dirs]
27
+ logger.info(f"Deleting files: {dirs}")
28
+ for dir in dirs:
29
+ if not dir.exists():
30
+ continue
31
+ shutil.rmtree(dir, ignore_errors=True)
32
+
33
+
34
+ def string_to_filename(s: str) -> str:
35
+ return (
36
+ s.replace(" ", "-")
37
+ .replace("/", "-")
38
+ .replace(":", "-")
39
+ .replace(".", "-")
40
+ .replace(",", "-")
41
+ .replace(";", "-")
42
+ .replace("!", "-")
43
+ .replace("?", "-")
44
+ )
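As a quick illustration (not part of the commit): `find_files` sorts checkpoint directories numerically by their step suffix, and `string_to_filename` sanitizes free-form text. The output directory below is hypothetical:

```python
from finetrainers.utils.file_utils import find_files, string_to_filename

# Assuming an output dir containing checkpoint-500/ and checkpoint-1000/
print(find_files("output/", prefix="checkpoint"))  # ['checkpoint-500', 'checkpoint-1000']
print(string_to_filename("a cat, walking!"))       # 'a-cat--walking-'
```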
finetrainers/utils/hub_utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import wandb
6
+ from diffusers.utils import export_to_video
7
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
8
+ from PIL import Image
9
+
10
+
11
+ def save_model_card(
12
+ args,
13
+ repo_id: str,
14
+ videos: Union[List[str], List[Image.Image], List[np.ndarray]],
15
+ validation_prompts: List[str],
16
+ fps: int = 30,
17
+ ) -> None:
18
+ widget_dict = []
19
+ output_dir = str(args.output_dir)
20
+ if videos is not None and len(videos) > 0:
21
+ for i, (video, validation_prompt) in enumerate(zip(videos, validation_prompts)):
22
+ if not isinstance(video, str):
23
+ export_to_video(video, os.path.join(output_dir, f"final_video_{i}.mp4"), fps=fps)
24
+ widget_dict.append(
25
+ {
26
+ "text": validation_prompt if validation_prompt else " ",
27
+ "output": {"url": video if isinstance(video, str) else f"final_video_{i}.mp4"},
28
+ }
29
+ )
30
+
31
+ training_type = "Full" if args.training_type == "full-finetune" else "LoRA"
32
+ model_description = f"""
33
+ # {training_type} Finetune
34
+
35
+ <Gallery />
36
+
37
+ ## Model description
38
+
39
+ This is a {training_type.lower()} finetune of model: `{args.pretrained_model_name_or_path}`.
40
+
41
+ The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers).
42
+
43
+ `id_token` used: {args.id_token} (if it's not `None`, it should be used in the prompts.)
44
+
45
+ ## Download model
46
+
47
+ [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
48
+
49
+ ## Usage
50
+
51
+ Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
52
+
53
+ ```py
54
+ TODO
55
+ ```
56
+
57
+ For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
58
+ """
59
+ if wandb.run and wandb.run.url:
60
+ model_description += f"""
61
+ Find out the wandb run URL and training configurations [here]({wandb.run.url}).
62
+ """
63
+
64
+ model_card = load_or_create_model_card(
65
+ repo_id_or_path=repo_id,
66
+ from_training=True,
67
+ base_model=args.pretrained_model_name_or_path,
68
+ model_description=model_description,
69
+ widget=widget_dict,
70
+ )
71
+ tags = [
72
+ "text-to-video",
73
+ "diffusers-training",
74
+ "diffusers",
75
+ "finetrainers",
76
+ "template:sd-lora",
77
+ ]
78
+ if training_type == "Full":
79
+ tags.append("full-finetune")
80
+ else:
81
+ tags.append("lora")
82
+
83
+ model_card = populate_model_card(model_card, tags=tags)
84
+ model_card.save(os.path.join(args.output_dir, "README.md"))
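The usage snippet in the generated model card is left as a `TODO` in this commit. Purely as an illustration of what loading such a LoRA typically looks like in diffusers (the repo ids, prompt, and fps below are assumptions, not values from this commit):

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

# Hypothetical base model and finetune repo ids.
pipe = DiffusionPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("your-username/your-finetune")

frames = pipe(prompt="<id_token> a cat walking on grass").frames[0]
export_to_video(frames, "sample.mp4", fps=24)
```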
finetrainers/utils/memory_utils.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ from typing import Any, Dict, Union
3
+
4
+ import torch
5
+ from accelerate.logging import get_logger
6
+
7
+
8
+ logger = get_logger("finetrainers")
9
+
10
+
11
+ def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
12
+ memory_allocated = None
13
+ memory_reserved = None
14
+ max_memory_allocated = None
15
+ max_memory_reserved = None
16
+
17
+ if torch.cuda.is_available():
18
+ device = torch.cuda.current_device()
19
+ memory_allocated = torch.cuda.memory_allocated(device)
20
+ memory_reserved = torch.cuda.memory_reserved(device)
21
+ max_memory_allocated = torch.cuda.max_memory_allocated(device)
22
+ max_memory_reserved = torch.cuda.max_memory_reserved(device)
23
+
24
+ elif torch.backends.mps.is_available():
25
+ memory_allocated = torch.mps.current_allocated_memory()
26
+
27
+ else:
28
+ logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")
29
+
30
+ return {
31
+ "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision),
32
+ "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision),
33
+ "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision),
34
+ "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision),
35
+ }
36
+
37
+
38
+ def bytes_to_gigabytes(x: int) -> float:
39
+ if x is not None:
40
+ return x / 1024**3
41
+
42
+
43
+ def free_memory() -> None:
44
+ if torch.cuda.is_available():
45
+ gc.collect()
46
+ torch.cuda.empty_cache()
47
+ torch.cuda.ipc_collect()
48
+
49
+ # TODO(aryan): handle non-cuda devices
50
+
51
+
52
+ def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
53
+ if isinstance(x, torch.Tensor):
54
+ return x.contiguous()
55
+ elif isinstance(x, dict):
56
+ return {k: make_contiguous(v) for k, v in x.items()}
57
+ else:
58
+ return x
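A minimal usage sketch (not part of the commit), assuming a CUDA device is available:

```python
from finetrainers.utils.memory_utils import get_memory_statistics, free_memory

stats = get_memory_statistics()  # values converted to GB and rounded
print(f"allocated={stats['memory_allocated']} GB, reserved={stats['memory_reserved']} GB")

# ... run a training step ...
free_memory()  # gc.collect() + torch.cuda.empty_cache() + ipc_collect() on CUDA
```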
finetrainers/utils/model_utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import json
3
+ import os
4
+
5
+ from huggingface_hub import hf_hub_download
6
+
7
+
8
+ def resolve_vae_cls_from_ckpt_path(ckpt_path, **kwargs):
9
+ ckpt_path = str(ckpt_path)
10
+ if os.path.exists(str(ckpt_path)) and os.path.isdir(ckpt_path):
11
+ index_path = os.path.join(ckpt_path, "model_index.json")
12
+ else:
13
+ revision = kwargs.get("revision", None)
14
+ cache_dir = kwargs.get("cache_dir", None)
15
+ index_path = hf_hub_download(
16
+ repo_id=ckpt_path, filename="model_index.json", revision=revision, cache_dir=cache_dir
17
+ )
18
+
19
+ with open(index_path, "r") as f:
20
+ model_index_dict = json.load(f)
21
+ assert "vae" in model_index_dict, "No VAE found in the modelx index dict."
22
+
23
+ vae_cls_config = model_index_dict["vae"]
24
+ library = importlib.import_module(vae_cls_config[0])
25
+ return getattr(library, vae_cls_config[1])
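For illustration (not part of the commit), resolving the VAE class for a hub checkpoint looks like this; `THUDM/CogVideoX-5b` is just an example repo id:

```python
from finetrainers.utils.model_utils import resolve_vae_cls_from_ckpt_path

vae_cls = resolve_vae_cls_from_ckpt_path("THUDM/CogVideoX-5b")  # resolves the class named in model_index.json
vae = vae_cls.from_pretrained("THUDM/CogVideoX-5b", subfolder="vae")
```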
finetrainers/utils/optimizer_utils.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+
3
+ import torch
4
+ from accelerate.logging import get_logger
5
+
6
+
7
+ logger = get_logger("finetrainers")
8
+
9
+
10
+ def get_optimizer(
11
+ params_to_optimize,
12
+ optimizer_name: str = "adam",
13
+ learning_rate: float = 1e-3,
14
+ beta1: float = 0.9,
15
+ beta2: float = 0.95,
16
+ beta3: float = 0.98,
17
+ epsilon: float = 1e-8,
18
+ weight_decay: float = 1e-4,
19
+ prodigy_decouple: bool = False,
20
+ prodigy_use_bias_correction: bool = False,
21
+ prodigy_safeguard_warmup: bool = False,
22
+ use_8bit: bool = False,
23
+ use_4bit: bool = False,
24
+ use_torchao: bool = False,
25
+ use_deepspeed: bool = False,
26
+ use_cpu_offload_optimizer: bool = False,
27
+ offload_gradients: bool = False,
28
+ ) -> torch.optim.Optimizer:
29
+ optimizer_name = optimizer_name.lower()
30
+
31
+ # Use DeepSpeed optimizer
32
+ if use_deepspeed:
33
+ from accelerate.utils import DummyOptim
34
+
35
+ return DummyOptim(
36
+ params_to_optimize,
37
+ lr=learning_rate,
38
+ betas=(beta1, beta2),
39
+ eps=epsilon,
40
+ weight_decay=weight_decay,
41
+ )
42
+
43
+ # TODO: consider moving the validation logic to `args.py` when we have torchao.
44
+ if use_8bit and use_4bit:
45
+ raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
46
+
47
+ if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
48
+ try:
49
+ import torchao # noqa
50
+
51
+ except ImportError:
52
+ raise ImportError(
53
+ "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
54
+ )
55
+
56
+ if not use_torchao and use_4bit:
57
+ raise ValueError("4-bit Optimizers are only supported with torchao.")
58
+
59
+ # Optimizer creation
60
+ supported_optimizers = ["adam", "adamw", "prodigy", "came"]
61
+ if optimizer_name not in supported_optimizers:
62
+ logger.warning(
63
+ f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
64
+ )
65
+ optimizer_name = "adamw"
66
+
67
+ if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
68
+ raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
69
+
70
+ if use_8bit:
71
+ try:
72
+ import bitsandbytes as bnb
73
+ except ImportError:
74
+ raise ImportError(
75
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
76
+ )
77
+
78
+ if optimizer_name == "adamw":
79
+ if use_torchao:
80
+ from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
81
+
82
+ optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
83
+ else:
84
+ optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
85
+
86
+ init_kwargs = {
87
+ "betas": (beta1, beta2),
88
+ "eps": epsilon,
89
+ "weight_decay": weight_decay,
90
+ }
91
+
92
+ elif optimizer_name == "adam":
93
+ if use_torchao:
94
+ from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
95
+
96
+ optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
97
+ else:
98
+ optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
99
+
100
+ init_kwargs = {
101
+ "betas": (beta1, beta2),
102
+ "eps": epsilon,
103
+ "weight_decay": weight_decay,
104
+ }
105
+
106
+ elif optimizer_name == "prodigy":
107
+ try:
108
+ import prodigyopt
109
+ except ImportError:
110
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
111
+
112
+ optimizer_class = prodigyopt.Prodigy
113
+
114
+ if learning_rate <= 0.1:
115
+ logger.warning(
116
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
117
+ )
118
+
119
+ init_kwargs = {
120
+ "lr": learning_rate,
121
+ "betas": (beta1, beta2),
122
+ "beta3": beta3,
123
+ "eps": epsilon,
124
+ "weight_decay": weight_decay,
125
+ "decouple": prodigy_decouple,
126
+ "use_bias_correction": prodigy_use_bias_correction,
127
+ "safeguard_warmup": prodigy_safeguard_warmup,
128
+ }
129
+
130
+ elif optimizer_name == "came":
131
+ try:
132
+ import came_pytorch
133
+ except ImportError:
134
+ raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
135
+
136
+ optimizer_class = came_pytorch.CAME
137
+
138
+ init_kwargs = {
139
+ "lr": learning_rate,
140
+ "eps": (1e-30, 1e-16),
141
+ "betas": (beta1, beta2, beta3),
142
+ "weight_decay": weight_decay,
143
+ }
144
+
145
+ if use_cpu_offload_optimizer:
146
+ from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
147
+
148
+ if "fused" in inspect.signature(optimizer_class.__init__).parameters:
149
+ init_kwargs.update({"fused": True})
150
+
151
+ optimizer = CPUOffloadOptimizer(
152
+ params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
153
+ )
154
+ else:
155
+ optimizer = optimizer_class(params_to_optimize, **init_kwargs)
156
+
157
+ return optimizer
158
+
159
+
160
+ def gradient_norm(parameters):
161
+ norm = 0
162
+ for param in parameters:
163
+ if param.grad is None:
164
+ continue
165
+ local_norm = param.grad.detach().data.norm(2)
166
+ norm += local_norm.item() ** 2
167
+ norm = norm**0.5
168
+ return norm
169
+
170
+
171
+ def max_gradient(parameters):
172
+ max_grad_value = float("-inf")
173
+ for param in parameters:
174
+ if param.grad is None:
175
+ continue
176
+ local_max_grad = param.grad.detach().data.abs().max()
177
+ max_grad_value = max(max_grad_value, local_max_grad.item())
178
+ return max_grad_value
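A minimal sketch (not part of the commit) of requesting an 8-bit AdamW through this helper; the toy model is purely illustrative:

```python
import torch
from finetrainers.utils.optimizer_utils import get_optimizer

model = torch.nn.Linear(16, 16)
optimizer = get_optimizer(
    params_to_optimize=model.parameters(),
    optimizer_name="adamw",
    learning_rate=1e-4,
    use_8bit=True,  # requires bitsandbytes (raises ImportError otherwise)
)
```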
finetrainers/utils/torch_utils.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Union
2
+
3
+ import torch
4
+ from accelerate import Accelerator
5
+ from diffusers.utils.torch_utils import is_compiled_module
6
+
7
+
8
+ def unwrap_model(accelerator: Accelerator, model):
9
+ model = accelerator.unwrap_model(model)
10
+ model = model._orig_mod if is_compiled_module(model) else model
11
+ return model
12
+
13
+
14
+ def align_device_and_dtype(
15
+ x: Union[torch.Tensor, Dict[str, torch.Tensor]],
16
+ device: Optional[torch.device] = None,
17
+ dtype: Optional[torch.dtype] = None,
18
+ ):
19
+ if isinstance(x, torch.Tensor):
20
+ if device is not None:
21
+ x = x.to(device)
22
+ if dtype is not None:
23
+ x = x.to(dtype)
24
+ elif isinstance(x, dict):
25
+ if device is not None:
26
+ x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
27
+ if dtype is not None:
28
+ x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
29
+ return x
30
+
31
+
32
+ def expand_tensor_dims(tensor, ndim):
33
+ while len(tensor.shape) < ndim:
34
+ tensor = tensor.unsqueeze(-1)
35
+ return tensor
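For illustration (not part of the commit): `expand_tensor_dims` makes per-sample sigmas broadcastable over 5D video latents, and `align_device_and_dtype` recursively moves tensors or dicts of tensors:

```python
import torch
from finetrainers.utils.torch_utils import expand_tensor_dims, align_device_and_dtype

sigmas = torch.rand(4)                       # one sigma per batch element
sigmas = expand_tensor_dims(sigmas, ndim=5)  # shape (4, 1, 1, 1, 1)

batch = {"latents": torch.randn(1, 8), "noise": torch.randn(1, 8)}
batch = align_device_and_dtype(batch, device=torch.device("cpu"), dtype=torch.float16)
```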
finetrainers_utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ import logging
4
+ import shutil
5
+ from typing import Any, Optional, Dict, List, Union, Tuple
6
+ from config import STORAGE_PATH, TRAINING_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
7
+ from utils import extract_scene_info, make_archive, is_image_file, is_video_file
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ def prepare_finetrainers_dataset() -> Tuple[Path, Path]:
12
+ """make sure we have a Finetrainers-compatible dataset structure
13
+
14
+ Checks that we have:
15
+ training/
16
+ ├── prompts.txt # All captions, one per line
17
+ ├── videos.txt # All video paths, one per line
18
+ └── videos/ # Directory containing all mp4 files
19
+ ├── 00000.mp4
20
+ ├── 00001.mp4
21
+ └── ...
22
+ Returns:
23
+ Tuple of (videos_file_path, prompts_file_path)
24
+ """
25
+
26
+ # Verifies the videos subdirectory
27
+ TRAINING_VIDEOS_PATH.mkdir(exist_ok=True)
28
+
29
+ # Clear existing training lists
30
+ for f in TRAINING_PATH.glob("*"):
31
+ if f.is_file():
32
+ if f.name in ["videos.txt", "prompts.txt"]:
33
+ f.unlink()
34
+
35
+ videos_file = TRAINING_PATH / "videos.txt"
36
+ prompts_file = TRAINING_PATH / "prompts.txt" # Note: Changed from prompt.txt to prompts.txt to match our config
37
+
38
+ media_files = []
39
+ captions = []
40
+ # Process all video files from the videos subdirectory
41
+ for idx, file in enumerate(sorted(TRAINING_VIDEOS_PATH.glob("*.mp4"))):
42
+ caption_file = file.with_suffix('.txt')
43
+ if caption_file.exists():
44
+ # Normalize caption to single line
45
+ caption = caption_file.read_text().strip()
46
+ caption = ' '.join(caption.split())
47
+
48
+ # Use relative path from training root
49
+ relative_path = f"videos/{file.name}"
50
+ media_files.append(relative_path)
51
+ captions.append(caption)
52
+
53
+ # Clean up the caption file since it's now in prompts.txt
54
+ # EDIT: keep the caption file after all; deleting it would make a second run
55
+ # of this function fail to find the captions
56
+ # caption_file.unlink()
57
+
58
+ # Write files if we have content
59
+ if media_files and captions:
60
+ videos_file.write_text('\n'.join(media_files))
61
+ prompts_file.write_text('\n'.join(captions))
62
+
63
+ else:
64
+ raise ValueError("No valid video/caption pairs found in training directory")
65
+ # Verify file contents
66
+ with open(videos_file) as vf:
67
+ video_lines = [l.strip() for l in vf.readlines() if l.strip()]
68
+ with open(prompts_file) as pf:
69
+ prompt_lines = [l.strip() for l in pf.readlines() if l.strip()]
70
+
71
+ if len(video_lines) != len(prompt_lines):
72
+ raise ValueError(f"Mismatch in generated files: {len(video_lines)} videos vs {len(prompt_lines)} prompts")
73
+
74
+ return videos_file, prompts_file
75
+
76
+ def copy_files_to_training_dir(prompt_prefix: str) -> int:
77
+ """Just copy files over, with no destruction"""
78
+
79
+ gr.Info("Copying assets to the training dataset..")
80
+
81
+ # Find files needing captions
82
+ video_files = list(STAGING_PATH.glob("*.mp4"))
83
+ image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)]
84
+ all_files = video_files + image_files
85
+
86
+ nb_copied_pairs = 0
87
+
88
+ for file_path in all_files:
89
+
90
+ caption = ""
91
+ file_caption_path = file_path.with_suffix('.txt')
92
+ if file_caption_path.exists():
93
+ logger.debug(f"Found caption file: {file_caption_path}")
94
+ caption = file_caption_path.read_text()
95
+
96
+ # Get parent caption if this is a clip
97
+ parent_caption = ""
98
+ if "___" in file_path.stem:
99
+ parent_name, _ = extract_scene_info(file_path.stem)
100
+ #print(f"parent_name is {parent_name}")
101
+ parent_caption_path = STAGING_PATH / f"{parent_name}.txt"
102
+ if parent_caption_path.exists():
103
+ logger.debug(f"Found parent caption file: {parent_caption_path}")
104
+ parent_caption = parent_caption_path.read_text().strip()
105
+
106
+ target_file_path = TRAINING_VIDEOS_PATH / file_path.name
107
+
108
+ target_caption_path = target_file_path.with_suffix('.txt')
109
+
110
+ if parent_caption and not caption.endswith(parent_caption):
111
+ caption = f"{caption}\n{parent_caption}"
112
+
113
+ if prompt_prefix and not caption.startswith(prompt_prefix):
114
+ caption = f"{prompt_prefix}{caption}"
115
+
116
+ # make sure we only copy over VALID pairs
117
+ if caption:
118
+ target_caption_path.write_text(caption)
119
+ shutil.copy2(file_path, target_file_path)
120
+ nb_copied_pairs += 1
121
+
122
+ prepare_finetrainers_dataset()
123
+
124
+ gr.Info(f"Successfully generated the training dataset ({nb_copied_pairs} pairs)")
125
+
126
+ return nb_copied_pairs
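A quick sanity check (not part of the commit) of the files written by `prepare_finetrainers_dataset()`; `TRAINING_PATH` comes from `config.py` and its exact value depends on the deployment:

```python
from config import TRAINING_PATH

videos = (TRAINING_PATH / "videos.txt").read_text().splitlines()
prompts = (TRAINING_PATH / "prompts.txt").read_text().splitlines()
assert len(videos) == len(prompts), "each video needs exactly one caption line"
for rel_path, prompt in zip(videos[:3], prompts[:3]):
    print(rel_path, "->", prompt[:80])
```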
image_preprocessing.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from pathlib import Path
4
+ from PIL import Image
5
+ import pillow_avif
6
+ import logging
7
+ from config import NORMALIZE_IMAGES_TO, JPEG_QUALITY
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ def normalize_image(input_path: Path, output_path: Path) -> bool:
12
+ """Convert image to normalized format (PNG or JPEG) and optionally remove black bars
13
+
14
+ Args:
15
+ input_path: Source image path
16
+ output_path: Target path
17
+
18
+ Returns:
19
+ bool: True if successful, False otherwise
20
+ """
21
+ try:
22
+ # Open image with PIL
23
+ with Image.open(input_path) as img:
24
+ # Convert to RGB if needed
25
+ if img.mode in ('RGBA', 'LA'):
26
+ background = Image.new('RGB', img.size, (255, 255, 255))
27
+ if img.mode == 'RGBA':
28
+ background.paste(img, mask=img.split()[3])
29
+ else:
30
+ background.paste(img, mask=img.split()[1])
31
+ img = background
32
+ elif img.mode != 'RGB':
33
+ img = img.convert('RGB')
34
+
35
+ # Convert to numpy for black bar detection
36
+ img_np = np.array(img)
37
+
38
+ # Detect black bars
39
+ top, bottom, left, right = detect_black_bars(img_np)
40
+
41
+ # Crop if black bars detected
42
+ if any([top > 0, bottom < img_np.shape[0] - 1,
43
+ left > 0, right < img_np.shape[1] - 1]):
44
+ img = img.crop((left, top, right, bottom))
45
+
46
+ # Save as configured format
47
+ if NORMALIZE_IMAGES_TO == 'png':
48
+ img.save(output_path, 'PNG', optimize=True)
49
+ else: # jpg
50
+ img.save(output_path, 'JPEG', quality=JPEG_QUALITY, optimize=True)
51
+ return True
52
+
53
+ except Exception as e:
54
+ logger.error(f"Error converting image {input_path}: {str(e)}")
55
+ return False
56
+
57
+ def detect_black_bars(img: np.ndarray) -> tuple[int, int, int, int]:
58
+ """Detect black bars in image
59
+
60
+ Args:
61
+ img: numpy array of image (HxWxC)
62
+
63
+ Returns:
64
+ Tuple of (top, bottom, left, right) crop coordinates
65
+ """
66
+ # Convert to grayscale if needed
67
+ if len(img.shape) == 3:
68
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
69
+ else:
70
+ gray = img
71
+
72
+ # Threshold to detect black regions
73
+ threshold = 20
74
+ black_mask = gray < threshold
75
+
76
+ # Find black bars by analyzing row/column means
77
+ row_means = np.mean(black_mask, axis=1)
78
+ col_means = np.mean(black_mask, axis=0)
79
+
80
+ # Detect edges where black bars end (95% threshold)
81
+ black_threshold = 0.95
82
+
83
+ # Find top and bottom crops
84
+ top = 0
85
+ bottom = img.shape[0]
86
+
87
+ for i, mean in enumerate(row_means):
88
+ if mean > black_threshold:
89
+ top = i + 1
90
+ else:
91
+ break
92
+
93
+ for i, mean in enumerate(reversed(row_means)):
94
+ if mean > black_threshold:
95
+ bottom = img.shape[0] - i - 1
96
+ else:
97
+ break
98
+
99
+ # Find left and right crops
100
+ left = 0
101
+ right = img.shape[1]
102
+
103
+ for i, mean in enumerate(col_means):
104
+ if mean > black_threshold:
105
+ left = i + 1
106
+ else:
107
+ break
108
+
109
+ for i, mean in enumerate(reversed(col_means)):
110
+ if mean > black_threshold:
111
+ right = img.shape[1] - i - 1
112
+ else:
113
+ break
114
+
115
+ return top, bottom, left, right
116
+
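As an illustration of the detection logic (not part of the commit; running it assumes the repo's `config` imports resolve in your environment), a synthetic letterboxed frame with 20-pixel black bars on top and bottom yields crop coordinates covering only the bright region:

```python
import numpy as np
from image_preprocessing import detect_black_bars

frame = np.zeros((100, 200, 3), dtype=np.uint8)
frame[20:80, :, :] = 255  # bright content, black bars on rows 0-19 and 80-99

top, bottom, left, right = detect_black_bars(frame)
print(top, bottom, left, right)  # (20, 80, 0, 200)
```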
import_service.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import zipfile
4
+ import tempfile
5
+ import gradio as gr
6
+ from pathlib import Path
7
+ from typing import List, Dict, Optional, Tuple
8
+ from pytubefix import YouTube
9
+ import logging
10
+ from utils import is_image_file, is_video_file, add_prefix_to_caption
11
+ from image_preprocessing import normalize_image
12
+
13
+ from config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_PATH, DEFAULT_PROMPT_PREFIX
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ class ImportService:
18
+ def process_uploaded_files(self, file_paths: List[str]) -> str:
19
+ """Process uploaded file (ZIP, MP4, or image)
20
+
21
+ Args:
22
+ file_paths: File paths to the uploaded files from Gradio
23
+
24
+ Returns:
25
+ Status message string
26
+ """
27
+ for file_path in file_paths:
28
+ file_path = Path(file_path)
29
+ try:
30
+ original_name = file_path.name
31
+ print("original_name = ", original_name)
32
+
33
+ # Determine file type from name
34
+ file_ext = file_path.suffix.lower()
35
+
36
+ if file_ext == '.zip':
37
+ return self.process_zip_file(file_path)
38
+ elif file_ext == '.mp4' or file_ext == '.webm':
39
+ return self.process_mp4_file(file_path, original_name)
40
+ elif is_image_file(file_path):
41
+ return self.process_image_file(file_path, original_name)
42
+ else:
43
+ raise gr.Error(f"Unsupported file type: {file_ext}")
44
+
45
+ except Exception as e:
46
+ raise gr.Error(f"Error processing file: {str(e)}")
47
+
48
+ def process_image_file(self, file_path: Path, original_name: str) -> str:
49
+ """Process a single image file
50
+
51
+ Args:
52
+ file_path: Path to the image
53
+ original_name: Original filename
54
+
55
+ Returns:
56
+ Status message string
57
+ """
58
+ try:
59
+ # Create a unique filename with configured extension
60
+ stem = Path(original_name).stem
61
+ target_path = STAGING_PATH / f"{stem}.{NORMALIZE_IMAGES_TO}"
62
+
63
+ # If file already exists, add number suffix
64
+ counter = 1
65
+ while target_path.exists():
66
+ target_path = STAGING_PATH / f"{stem}___{counter}.{NORMALIZE_IMAGES_TO}"
67
+ counter += 1
68
+
69
+ # Convert to normalized format and remove black bars
70
+ success = normalize_image(file_path, target_path)
71
+
72
+ if not success:
73
+ raise gr.Error(f"Failed to process image: {original_name}")
74
+
75
+ # Handle caption
76
+ src_caption_path = file_path.with_suffix('.txt')
77
+ if src_caption_path.exists():
78
+ caption = src_caption_path.read_text()
79
+ caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
80
+ target_path.with_suffix('.txt').write_text(caption)
81
+
82
+ logger.info(f"Successfully stored image: {target_path.name}")
83
+ gr.Info(f"Successfully stored image: {target_path.name}")
84
+ return f"Successfully stored image: {target_path.name}"
85
+
86
+ except Exception as e:
87
+ raise gr.Error(f"Error processing image file: {str(e)}")
88
+
89
+ def process_zip_file(self, file_path: Path) -> str:
90
+ """Process uploaded ZIP file containing media files
91
+
92
+ Args:
93
+ file_path: Path to the uploaded ZIP file
94
+
95
+ Returns:
96
+ Status message string
97
+ """
98
+ try:
99
+ video_count = 0
100
+ image_count = 0
101
+
102
+ # Create temporary directory
103
+ with tempfile.TemporaryDirectory() as temp_dir:
104
+ # Extract ZIP
105
+ extract_dir = Path(temp_dir) / "extracted"
106
+ extract_dir.mkdir()
107
+ with zipfile.ZipFile(file_path, 'r') as zip_ref:
108
+ zip_ref.extractall(extract_dir)
109
+
110
+ # Process each file
111
+ for root, _, files in os.walk(extract_dir):
112
+ for file in files:
113
+ if file.startswith('._'): # Skip Mac metadata
114
+ continue
115
+
116
+ file_path = Path(root) / file
117
+
118
+ try:
119
+ if is_video_file(file_path):
120
+ # Copy video to videos_to_split
121
+ target_path = VIDEOS_TO_SPLIT_PATH / file_path.name
122
+ counter = 1
123
+ while target_path.exists():
124
+ target_path = VIDEOS_TO_SPLIT_PATH / f"{file_path.stem}___{counter}{file_path.suffix}"
125
+ counter += 1
126
+ shutil.copy2(file_path, target_path)
127
+ video_count += 1
128
+
129
+ elif is_image_file(file_path):
130
+ # Convert image and save to staging
131
+ target_path = STAGING_PATH / f"{file_path.stem}.{NORMALIZE_IMAGES_TO}"
132
+ counter = 1
133
+ while target_path.exists():
134
+ target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
135
+ counter += 1
136
+ if normalize_image(file_path, target_path):
137
+ image_count += 1
138
+
139
+ # Copy associated caption file if it exists
140
+ txt_path = file_path.with_suffix('.txt')
141
+ if txt_path.exists():
142
+ if is_video_file(file_path):
143
+ shutil.copy2(txt_path, target_path.with_suffix('.txt'))
144
+ elif is_image_file(file_path):
145
+ shutil.copy2(txt_path, target_path.with_suffix('.txt'))
146
+
147
+ except Exception as e:
148
+ logger.error(f"Error processing {file_path.name}: {str(e)}")
149
+ continue
150
+
151
+ # Generate status message
152
+ parts = []
153
+ if video_count > 0:
154
+ parts.append(f"{video_count} videos")
155
+ if image_count > 0:
156
+ parts.append(f"{image_count} images")
157
+
158
+ if not parts:
159
+ return "No supported media files found in ZIP"
160
+
161
+ status = f"Successfully stored {' and '.join(parts)}"
162
+ gr.Info(status)
163
+ return status
164
+
165
+ except Exception as e:
166
+ raise gr.Error(f"Error processing ZIP: {str(e)}")
167
+
168
+ def process_mp4_file(self, file_path: Path, original_name: str) -> str:
169
+ """Process a single video file
170
+
171
+ Args:
172
+ file_path: Path to the file
173
+ original_name: Original filename
174
+
175
+ Returns:
176
+ Status message string
177
+ """
178
+ try:
179
+ # Create a unique filename
180
+ target_path = VIDEOS_TO_SPLIT_PATH / original_name
181
+
182
+ # If file already exists, add number suffix
183
+ counter = 1
184
+ while target_path.exists():
185
+ stem = Path(original_name).stem
186
+ target_path = VIDEOS_TO_SPLIT_PATH / f"{stem}___{counter}.mp4"
187
+ counter += 1
188
+
189
+ # Copy the file to the target location
190
+ shutil.copy2(file_path, target_path)
191
+
192
+ gr.Info(f"Successfully stored video: {target_path.name}")
193
+ return f"Successfully stored video: {target_path.name}"
194
+
195
+ except Exception as e:
196
+ raise gr.Error(f"Error processing video file: {str(e)}")
197
+
198
+ def download_youtube_video(self, url: str, progress=None) -> Dict:
199
+ """Download a video from YouTube
200
+
201
+ Args:
202
+ url: YouTube video URL
203
+ progress: Optional Gradio progress indicator
204
+
205
+ Returns:
206
+ Status message string
207
+ """
208
+ try:
209
+ # Extract video ID and create YouTube object
210
+ yt = YouTube(url, on_progress_callback=lambda stream, chunk, bytes_remaining:
211
+ progress((1 - bytes_remaining / stream.filesize), desc="Downloading...")
212
+ if progress else None)
213
+
214
+ video_id = yt.video_id
215
+ output_path = VIDEOS_TO_SPLIT_PATH / f"{video_id}.mp4"
216
+
217
+ # Download highest quality progressive MP4
218
+ if progress:
219
+ print("Getting video streams...")
220
+ progress(0, desc="Getting video streams...")
221
+ video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
222
+
223
+ if not video:
224
+ print("Could not find a compatible video format")
225
+ gr.Error("Could not find a compatible video format")
226
+ return "Could not find a compatible video format"
227
+
228
+ # Download the video
229
+ if progress:
230
+ print("Starting YouTube video download...")
231
+ progress(0, desc="Starting download...")
232
+
233
+ video.download(output_path=str(VIDEOS_TO_SPLIT_PATH), filename=f"{video_id}.mp4")
234
+
235
+ # Update UI
236
+ if progress:
237
+ print("YouTube video download complete!")
238
+ gr.Info("YouTube video download complete!")
239
+ progress(1, desc="Download complete!")
240
+ return f"Successfully downloaded video: {yt.title}"
241
+
242
+ except Exception as e:
243
+ print(e)
244
+ gr.Error(f"Error downloading video: {str(e)}")
245
+ return f"Error downloading video: {str(e)}"
requirements.txt ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy>=1.26.4
2
+
3
+ # to quote a-r-r-o-w/finetrainers:
4
+ # It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested.
5
+ torch==2.5.1
6
+ torchvision==0.20.1
7
+ torchao==0.6.1
8
+
9
+ huggingface_hub
10
+ hf_transfer>=0.1.8
11
+ diffusers>=0.30.3
12
+ transformers>=4.45.2
13
+
14
+ accelerate
15
+ bitsandbytes
16
+ peft>=0.12.0
17
+ eva-decord==0.6.1
18
+ wandb
19
+ pandas
20
+ sentencepiece>=0.2.0
21
+ imageio-ffmpeg>=0.5.1
22
+
23
+ flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
24
+
25
+ # for youtube video download
26
+ pytube
27
+ pytubefix
28
+
29
+ # for scene splitting
30
+ scenedetect[opencv]
31
+
32
+ # for llava video / captioning
33
+ pillow
34
+ pillow-avif-plugin
35
+ polars
36
+ einops
37
+ open_clip_torch
38
+ av==14.1.0
39
+ git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
40
+
41
+ # for our frontend
42
+ gradio==5.15.0
43
+ gradio_toggle
requirements_without_flash_attention.txt ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy>=1.26.4
2
+
3
+ # to quote a-r-r-o-w/finetrainers:
4
+ # It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested.
5
+ torch==2.5.1
6
+ torchvision==0.20.1
7
+ torchao==0.6.1
8
+
9
+
10
+ huggingface_hub
11
+ hf_transfer>=0.1.8
12
+ diffusers>=0.30.3
13
+ transformers>=4.45.2
14
+
15
+ accelerate
16
+ bitsandbytes
17
+ peft>=0.12.0
18
+ eva-decord==0.6.1
19
+ wandb
20
+ pandas
21
+ sentencepiece>=0.2.0
22
+ imageio-ffmpeg>=0.5.1
23
+
24
+ # for youtube video download
25
+ pytube
26
+ pytubefix
27
+
28
+ # for scene splitting
29
+ scenedetect[opencv]
30
+
31
+ # for llava video / captionning
32
+ pillow
33
+ pillow-avif-plugin
34
+ polars
35
+ einops
36
+ open_clip_torch
37
+ av==14.1.0
38
+ git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
39
+
40
+ # for our frontend
41
+ gradio==5.15.0
42
+ gradio_toggle
run.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ source .venv/bin/activate
4
+
5
+ USE_MOCK_CAPTIONING_MODEL=True python app.py
setup.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ python -m venv .venv
4
+
5
+ source .venv/bin/activate
6
+
7
+ python -m pip install -r requirements.txt
setup_no_captions.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ python -m venv .venv
4
+
5
+ source .venv/bin/activate
6
+
7
+ python -m pip install -r requirements_without_flash_attention.txt
8
+
9
+ # if you require flash attention, please install it manually for your operating system
10
+
11
+ # you can try this:
12
+ # python -m pip install wheel setuptools flash-attn --no-build-isolation --no-cache-dir